Skip to content

Commit

Permalink
postpone the import of external different providers to lift the requi…
Browse files Browse the repository at this point in the history
…rement forcing user installing different non-related provider apis.
  • Loading branch information
davidleon committed Nov 6, 2024
1 parent f11e671 commit 453b932
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 14 deletions.
42 changes: 29 additions & 13 deletions metagpt/provider/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,36 @@
@Author : alexanderwu
@File : __init__.py
"""
import importlib
class LLMFactory:
def __init__(self, module_name, instance_name):
self.module_name = module_name
self.instance_name = instance_name
self._module = None

from metagpt.provider.google_gemini_api import GeminiLLM
from metagpt.provider.ollama_api import OllamaLLM
from metagpt.provider.openai_api import OpenAILLM
from metagpt.provider.zhipuai_api import ZhiPuAILLM
from metagpt.provider.azure_openai_api import AzureOpenAILLM
from metagpt.provider.metagpt_api import MetaGPTLLM
from metagpt.provider.human_provider import HumanProvider
from metagpt.provider.spark_api import SparkLLM
from metagpt.provider.qianfan_api import QianFanLLM
from metagpt.provider.dashscope_api import DashScopeLLM
from metagpt.provider.anthropic_api import AnthropicLLM
from metagpt.provider.bedrock_api import BedrockLLM
from metagpt.provider.ark_api import ArkLLM
def __getattr__(self, name):
if self._module is None:
self._module = importlib.import_module(self.module_name)
return getattr(self._module, name)
def __instancecheck__(self, instance):
if self._module is None:
self._module = importlib.import_module(self.module_name)
return isinstance(instance, getattr(self._module, self.instance_name))


GeminiLLM = LLMFactory("metagpt.provider.google_gemini_api ", "GeminiLLM")
OllamaLLM = LLMFactory("metagpt.provider.ollama_api ", "OllamaLLM")
OpenAILLM = LLMFactory("metagpt.provider.openai_api ", "OpenAILLM")
ZhiPuAILLM = LLMFactory("metagpt.provider.zhipuai_api ", "ZhiPuAILLM")
AzureOpenAILLM = LLMFactory("metagpt.provider.azure_openai_api ", "AzureOpenAILLM")
MetaGPTLLM = LLMFactory("metagpt.provider.metagpt_api ", "MetaGPTLLM")
HumanProvider = LLMFactory("metagpt.provider.human_provider ", "HumanProvider")
SparkLLM = LLMFactory("metagpt.provider.spark_api ", "SparkLLM")
QianFanLLM = LLMFactory("metagpt.provider.qianfan_api ", "QianFanLLM")
DashScopeLLM = LLMFactory("metagpt.provider.dashscope_api ", "DashScopeLLM")
AnthropicLLM = LLMFactory("metagpt.provider.anthropic_api ", "AnthropicLLM")
BedrockLLM = LLMFactory("metagpt.provider.bedrock_api ", "BedrockLLM")
ArkLLM = LLMFactory("metagpt.provider.ark_api ", "ArkLLM")

__all__ = [
"GeminiLLM",
Expand Down
27 changes: 26 additions & 1 deletion metagpt/provider/llm_provider_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,42 @@
"""
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.provider.base_llm import BaseLLM

import importlib

class LLMProviderRegistry:
def __init__(self):
self.providers = {}
self._module_map = {
LLMType.OPENAI: "metagpt.provider.openai_api",
LLMType.ANTHROPIC: "metagpt.provider.anthropic_api",
LLMType.CLAUDE: "metagpt.provider.anthropic_api", # Same module as Anthropic
LLMType.SPARK: "metagpt.provider.spark_api",
LLMType.ZHIPUAI: "metagpt.provider.zhipuai_api",
LLMType.FIREWORKS: "metagpt.provider.fireworks_api",
LLMType.OPEN_LLM: "metagpt.provider.open_llm_api",
LLMType.GEMINI: "metagpt.provider.google_gemini_api",
LLMType.METAGPT: "metagpt.provider.metagpt_api",
LLMType.AZURE: "metagpt.provider.azure_openai_api",
LLMType.OLLAMA: "metagpt.provider.ollama_api",
LLMType.QIANFAN: "metagpt.provider.qianfan_api", # Baidu BCE
LLMType.DASHSCOPE: "metagpt.provider.dashscope_api", # Aliyun LingJi DashScope
LLMType.MOONSHOT: "metagpt.provider.moonshot_api",
LLMType.MISTRAL: "metagpt.provider.mistral_api",
LLMType.YI: "metagpt.provider.yi_api", # lingyiwanwu
LLMType.OPENROUTER: "metagpt.provider.openrouter_api",
LLMType.BEDROCK: "metagpt.provider.bedrock_api",
LLMType.ARK: "metagpt.provider.ark_api",
}

def register(self, key, provider_cls):
self.providers[key] = provider_cls

def get_provider(self, enum: LLMType):
"""get provider instance according to the enum"""
if enum not in self.providers:
# Import and register the provider if not already registered
module_name = self._module_map[enum]
importlib.import_module(module_name)
return self.providers[enum]


Expand Down

0 comments on commit 453b932

Please sign in to comment.