40 changes: 19 additions & 21 deletions bot/openai_helper.py
@@ -30,31 +30,29 @@
 GPT_ALL_MODELS = GPT_3_MODELS + GPT_3_16K_MODELS + GPT_4_MODELS + GPT_4_32K_MODELS + GPT_4_VISION_MODELS + GPT_4_128K_MODELS + GPT_4O_MODELS + O_MODELS
 
 def default_max_tokens(model: str) -> int:
-    """
-    Gets the default number of max tokens for the given model.
-    :param model: The model name
-    :return: The default number of max tokens
+    """Return the default ``max_tokens`` value for ``model``.
+
+    Models are grouped by category and mapped to their default token count.
     """
     base = 1200
-    if model in GPT_3_MODELS:
-        return base
-    elif model in GPT_4_MODELS:
-        return base * 2
-    elif model in GPT_3_16K_MODELS:
-        if model == "gpt-3.5-turbo-1106":
-            return 4096
-        return base * 4
-    elif model in GPT_4_32K_MODELS:
-        return base * 8
-    elif model in GPT_4_VISION_MODELS:
-        return 4096
-    elif model in GPT_4_128K_MODELS:
-        return 4096
-    elif model in GPT_4O_MODELS:
-        return 4096
-    elif model in O_MODELS:
+    if model == "gpt-3.5-turbo-1106":
         return 4096
+
+    model_defaults = {
+        GPT_3_MODELS: base,
+        GPT_4_MODELS: base * 2,
+        GPT_3_16K_MODELS: base * 4,
+        GPT_4_32K_MODELS: base * 8,
+        GPT_4_VISION_MODELS: 4096,
+        GPT_4_128K_MODELS: 4096,
+        GPT_4O_MODELS: 4096,
+        O_MODELS: 4096,
+    }
+
+    for models, value in model_defaults.items():
+        if model in models:
+            return value
 
 
 def are_functions_available(model: str) -> bool:
     """
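For quick reference, here is a minimal, self-contained sketch of the lookup pattern used in the new version. The model tuples and the final fallback are illustrative placeholders, not the real values from bot/openai_helper.py:

# Standalone sketch of the dict-based default lookup (illustrative groups only).
GPT_3_MODELS = ("gpt-3.5-turbo",)
GPT_3_16K_MODELS = ("gpt-3.5-turbo-16k", "gpt-3.5-turbo-1106")
GPT_4_MODELS = ("gpt-4",)

def default_max_tokens(model: str) -> int:
    base = 1200
    if model == "gpt-3.5-turbo-1106":
        return 4096  # special case handled before the generic lookup

    # Tuples are hashable, so each model group can serve directly as a dict key.
    model_defaults = {
        GPT_3_MODELS: base,
        GPT_4_MODELS: base * 2,
        GPT_3_16K_MODELS: base * 4,
    }
    for models, value in model_defaults.items():
        if model in models:
            return value
    return base  # illustrative fallback; the diff above returns None for unknown models

print(default_max_tokens("gpt-3.5-turbo-1106"))  # 4096
print(default_max_tokens("gpt-4"))               # 2400

Because the groups are tuples they can be used as dictionary keys, and dict iteration order mirrors the order of the original elif chain.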
58 changes: 25 additions & 33 deletions bot/plugin_manager.py
@@ -1,20 +1,10 @@
 import json
+import importlib
+import inspect
+import os
+import pkgutil
 
-from plugins.gtts_text_to_speech import GTTSTextToSpeech
-from plugins.auto_tts import AutoTextToSpeech
-from plugins.dice import DicePlugin
-from plugins.youtube_audio_extractor import YouTubeAudioExtractorPlugin
-from plugins.ddg_image_search import DDGImageSearchPlugin
-from plugins.spotify import SpotifyPlugin
-from plugins.crypto import CryptoPlugin
-from plugins.weather import WeatherPlugin
-from plugins.ddg_web_search import DDGWebSearchPlugin
-from plugins.wolfram_alpha import WolframAlphaPlugin
-from plugins.deepl import DeeplTranslatePlugin
-from plugins.worldtimeapi import WorldTimeApiPlugin
-from plugins.whois_ import WhoisPlugin
-from plugins.webshot import WebshotPlugin
-from plugins.iplocation import IpLocationPlugin
+from plugins.plugin import Plugin
 
 
 class PluginManager:
@@ -23,25 +13,27 @@ class PluginManager:
     """
 
     def __init__(self, config):
-        enabled_plugins = config.get('plugins', [])
-        plugin_mapping = {
-            'wolfram': WolframAlphaPlugin,
-            'weather': WeatherPlugin,
-            'crypto': CryptoPlugin,
-            'ddg_web_search': DDGWebSearchPlugin,
-            'ddg_image_search': DDGImageSearchPlugin,
-            'spotify': SpotifyPlugin,
-            'worldtimeapi': WorldTimeApiPlugin,
-            'youtube_audio_extractor': YouTubeAudioExtractorPlugin,
-            'dice': DicePlugin,
-            'deepl_translate': DeeplTranslatePlugin,
-            'gtts_text_to_speech': GTTSTextToSpeech,
-            'auto_tts': AutoTextToSpeech,
-            'whois': WhoisPlugin,
-            'webshot': WebshotPlugin,
-            'iplocation': IpLocationPlugin,
+        enabled_plugins = [p for p in config.get('plugins', []) if p]
+
+        alias = {
+            'wolfram': 'wolfram_alpha',
+            'deepl_translate': 'deepl',
+            'whois': 'whois_',
         }
-        self.plugins = [plugin_mapping[plugin]() for plugin in enabled_plugins if plugin in plugin_mapping]
+
+        plugin_dir = os.path.join(os.path.dirname(__file__), 'plugins')
+        available = {}
+        for _, module_name, _ in pkgutil.iter_modules([plugin_dir]):
+            if module_name == 'plugin':
+                continue
+            module = importlib.import_module(f'plugins.{module_name}')
+            for _, obj in inspect.getmembers(module, inspect.isclass):
+                if issubclass(obj, Plugin) and obj is not Plugin:
+                    plugin_name = next((k for k, v in alias.items() if v == module_name), module_name)
+                    available[plugin_name] = obj
+                    break
+
+        self.plugins = [available[name]() for name in enabled_plugins if name in available]
 
     def get_functions_specs(self):
         """
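A hedged usage sketch of the new constructor, assuming it is run from the repository's bot/ directory so that plugin_manager and the plugins package are importable; config names resolve to module names unless an alias applies, and plugins whose constructors need API keys may still fail without the corresponding environment variables:

# Illustrative only; the config values below are assumptions, not part of the diff.
from plugin_manager import PluginManager

config = {'plugins': ['whois', 'dice', '']}   # '' is dropped by the new filter
manager = PluginManager(config)

# 'whois' resolves through the alias map to the plugins.whois_ module,
# while 'dice' matches the plugins.dice module directly.
print([type(p).__name__ for p in manager.plugins])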