From 6f59a6cffedae41e63ae8e4eaca55f19b42c65c2 Mon Sep 17 00:00:00 2001 From: david Date: Wed, 13 Nov 2024 13:27:09 +0800 Subject: [PATCH] postpone the import of external different providers to lift the requirement forcing user installing different non-related provider apis. --- metagpt/configs/llm_config.py | 23 +++++++ metagpt/provider/__init__.py | 83 +++++++++++++++-------- metagpt/provider/llm_provider_registry.py | 8 ++- 3 files changed, 85 insertions(+), 29 deletions(-) diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py index ef034ca49..caf6d40c3 100644 --- a/metagpt/configs/llm_config.py +++ b/metagpt/configs/llm_config.py @@ -25,6 +25,7 @@ class LLMType(Enum): OPEN_LLM = "open_llm" GEMINI = "gemini" METAGPT = "metagpt" + HUMAN = "human" AZURE = "azure" OLLAMA = "ollama" # /chat at ollama api OLLAMA_GENERATE = "ollama.generate" # /generate at ollama api @@ -42,6 +43,28 @@ class LLMType(Enum): def __missing__(self, key): return self.OPENAI +LLMModuleMap = { + LLMType.OPENAI: "metagpt.provider.openai_api", + LLMType.ANTHROPIC: "metagpt.provider.anthropic_api", + LLMType.CLAUDE: "metagpt.provider.anthropic_api", # Same module as Anthropic + LLMType.SPARK: "metagpt.provider.spark_api", + LLMType.ZHIPUAI: "metagpt.provider.zhipuai_api", + LLMType.FIREWORKS: "metagpt.provider.fireworks_api", + LLMType.OPEN_LLM: "metagpt.provider.open_llm_api", + LLMType.GEMINI: "metagpt.provider.google_gemini_api", + LLMType.METAGPT: "metagpt.provider.metagpt_api", + LLMType.HUMAN: "metagpt.provider.human_provider", + LLMType.AZURE: "metagpt.provider.azure_openai_api", + LLMType.OLLAMA: "metagpt.provider.ollama_api", + LLMType.QIANFAN: "metagpt.provider.qianfan_api", # Baidu BCE + LLMType.DASHSCOPE: "metagpt.provider.dashscope_api", # Aliyun LingJi DashScope + LLMType.MOONSHOT: "metagpt.provider.moonshot_api", + LLMType.MISTRAL: "metagpt.provider.mistral_api", + LLMType.YI: "metagpt.provider.yi_api", # lingyiwanwu + LLMType.OPENROUTER: 
"metagpt.provider.openrouter_api", + LLMType.BEDROCK: "metagpt.provider.bedrock_api", + LLMType.ARK: "metagpt.provider.ark_api", +} class LLMConfig(YamlModel): """Config for LLM diff --git a/metagpt/provider/__init__.py b/metagpt/provider/__init__.py index c90f5774a..6e57083e5 100644 --- a/metagpt/provider/__init__.py +++ b/metagpt/provider/__init__.py @@ -5,33 +5,62 @@ @Author : alexanderwu @File : __init__.py """ +import importlib +from metagpt.configs.llm_config import LLMType, LLMModuleMap -from metagpt.provider.google_gemini_api import GeminiLLM -from metagpt.provider.ollama_api import OllamaLLM -from metagpt.provider.openai_api import OpenAILLM -from metagpt.provider.zhipuai_api import ZhiPuAILLM -from metagpt.provider.azure_openai_api import AzureOpenAILLM -from metagpt.provider.metagpt_api import MetaGPTLLM -from metagpt.provider.human_provider import HumanProvider -from metagpt.provider.spark_api import SparkLLM -from metagpt.provider.qianfan_api import QianFanLLM -from metagpt.provider.dashscope_api import DashScopeLLM -from metagpt.provider.anthropic_api import AnthropicLLM -from metagpt.provider.bedrock_api import BedrockLLM -from metagpt.provider.ark_api import ArkLLM +class LLMFactory: + def __init__(self, module_name, instance_name): + self.module_name = module_name + self.instance_name = instance_name + self._module = None -__all__ = [ - "GeminiLLM", - "OpenAILLM", - "ZhiPuAILLM", - "AzureOpenAILLM", - "MetaGPTLLM", - "OllamaLLM", - "HumanProvider", - "SparkLLM", - "QianFanLLM", - "DashScopeLLM", - "AnthropicLLM", - "BedrockLLM", - "ArkLLM", + def __getattr__(self, name): + if self._module is None: + self._module = importlib.import_module(self.module_name) + return getattr(self._module, name) + + def __instancecheck__(self, instance): + if self._module is None: + self._module = importlib.import_module(self.module_name) + return isinstance(instance, getattr(self._module, self.instance_name)) + + def __call__(self, config): + # Import the module when 
it's called for the first time + if self._module is None: + self._module = importlib.import_module(self.module_name) + + # Create an instance of the specified class from the module with the given config + return getattr(self._module, self.instance_name)(config) + +def create_llm_symbol(llm_configurations): + factories = {name: LLMFactory(LLMModuleMap[llm_type], name) for llm_type, name in llm_configurations} + # Add the factory created llm objects to the global namespace + globals().update(factories) + return factories.keys() + +# List of LLM configurations +llm_configurations = [ + (LLMType.GEMINI, "GeminiLLM"), + (LLMType.OLLAMA, "OllamaLLM"), + (LLMType.OPENAI, "OpenAILLM"), + (LLMType.ZHIPUAI, "ZhiPuAILLM"), + (LLMType.AZURE, "AzureOpenAILLM"), + (LLMType.METAGPT, "MetaGPTLLM"), + (LLMType.HUMAN, "HumanProvider"), + (LLMType.SPARK, "SparkLLM"), + (LLMType.QIANFAN, "QianFanLLM"), + (LLMType.DASHSCOPE, "DashScopeLLM"), + (LLMType.ANTHROPIC, "AnthropicLLM"), + (LLMType.BEDROCK, "BedrockLLM"), + (LLMType.ARK, "ArkLLM"), + (LLMType.FIREWORKS, "FireworksLLM"), + (LLMType.OPEN_LLM, "OpenLLM"), + (LLMType.MOONSHOT, "MoonshotLLM"), + (LLMType.MISTRAL, "MistralLLM"), + (LLMType.YI, "YiLLM"), + (LLMType.OPENROUTER, "OpenRouterLLM"), + (LLMType.CLAUDE, "ClaudeLLM"), ] + +# Create all LLMFactory instances and get created symbols +__all__ = create_llm_symbol(llm_configurations) \ No newline at end of file diff --git a/metagpt/provider/llm_provider_registry.py b/metagpt/provider/llm_provider_registry.py index 7f8618590..a9e739f44 100644 --- a/metagpt/provider/llm_provider_registry.py +++ b/metagpt/provider/llm_provider_registry.py @@ -5,9 +5,9 @@ @Author : alexanderwu @File : llm_provider_registry.py """ -from metagpt.configs.llm_config import LLMConfig, LLMType +from metagpt.configs.llm_config import LLMConfig, LLMType, LLMModuleMap from metagpt.provider.base_llm import BaseLLM - +import importlib class LLMProviderRegistry: def __init__(self): @@ -18,6 +18,10 @@ def 
register(self, key, provider_cls): def get_provider(self, enum: LLMType): """get provider instance according to the enum""" + if enum not in self.providers: + # Import and register the provider if not already registered + module_name = LLMModuleMap[enum] + importlib.import_module(module_name) return self.providers[enum]