From 453b93220ac1d2075b4f5942febd227390f5f4c3 Mon Sep 17 00:00:00 2001 From: david Date: Wed, 6 Nov 2024 21:06:52 +0800 Subject: [PATCH] postpone the import of external different providers to lift the requirement forcing user installing different non-related provider apis. --- metagpt/provider/__init__.py | 42 ++++++++++++++++------- metagpt/provider/llm_provider_registry.py | 27 ++++++++++++++- 2 files changed, 55 insertions(+), 14 deletions(-) diff --git a/metagpt/provider/__init__.py b/metagpt/provider/__init__.py index c90f5774a..1abc18862 100644 --- a/metagpt/provider/__init__.py +++ b/metagpt/provider/__init__.py @@ -5,20 +5,36 @@ @Author : alexanderwu @File : __init__.py """ +import importlib +class LLMFactory: + def __init__(self, module_name, instance_name): + self.module_name = module_name + self.instance_name = instance_name + self._module = None -from metagpt.provider.google_gemini_api import GeminiLLM -from metagpt.provider.ollama_api import OllamaLLM -from metagpt.provider.openai_api import OpenAILLM -from metagpt.provider.zhipuai_api import ZhiPuAILLM -from metagpt.provider.azure_openai_api import AzureOpenAILLM -from metagpt.provider.metagpt_api import MetaGPTLLM -from metagpt.provider.human_provider import HumanProvider -from metagpt.provider.spark_api import SparkLLM -from metagpt.provider.qianfan_api import QianFanLLM -from metagpt.provider.dashscope_api import DashScopeLLM -from metagpt.provider.anthropic_api import AnthropicLLM -from metagpt.provider.bedrock_api import BedrockLLM -from metagpt.provider.ark_api import ArkLLM + def __getattr__(self, name): + if self._module is None: + self._module = importlib.import_module(self.module_name) + return getattr(self._module, name) + def __instancecheck__(self, instance): + if self._module is None: + self._module = importlib.import_module(self.module_name) + return isinstance(instance, getattr(self._module, self.instance_name)) + + +GeminiLLM = LLMFactory("metagpt.provider.google_gemini_api ", "GeminiLLM") +OllamaLLM = LLMFactory("metagpt.provider.ollama_api ", "OllamaLLM") +OpenAILLM = LLMFactory("metagpt.provider.openai_api ", "OpenAILLM") +ZhiPuAILLM = LLMFactory("metagpt.provider.zhipuai_api ", "ZhiPuAILLM") +AzureOpenAILLM = LLMFactory("metagpt.provider.azure_openai_api ", "AzureOpenAILLM") +MetaGPTLLM = LLMFactory("metagpt.provider.metagpt_api ", "MetaGPTLLM") +HumanProvider = LLMFactory("metagpt.provider.human_provider ", "HumanProvider") +SparkLLM = LLMFactory("metagpt.provider.spark_api ", "SparkLLM") +QianFanLLM = LLMFactory("metagpt.provider.qianfan_api ", "QianFanLLM") +DashScopeLLM = LLMFactory("metagpt.provider.dashscope_api ", "DashScopeLLM") +AnthropicLLM = LLMFactory("metagpt.provider.anthropic_api ", "AnthropicLLM") +BedrockLLM = LLMFactory("metagpt.provider.bedrock_api ", "BedrockLLM") +ArkLLM = LLMFactory("metagpt.provider.ark_api ", "ArkLLM") __all__ = [ "GeminiLLM", diff --git a/metagpt/provider/llm_provider_registry.py b/metagpt/provider/llm_provider_registry.py index 7f8618590..8e43451bd 100644 --- a/metagpt/provider/llm_provider_registry.py +++ b/metagpt/provider/llm_provider_registry.py @@ -7,17 +7,42 @@ """ from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.provider.base_llm import BaseLLM - +import importlib class LLMProviderRegistry: def __init__(self): self.providers = {} + self._module_map = { + LLMType.OPENAI: "metagpt.provider.openai_api", + LLMType.ANTHROPIC: "metagpt.provider.anthropic_api", + LLMType.CLAUDE: "metagpt.provider.anthropic_api", # Same module as Anthropic + LLMType.SPARK: "metagpt.provider.spark_api", + LLMType.ZHIPUAI: "metagpt.provider.zhipuai_api", + LLMType.FIREWORKS: "metagpt.provider.fireworks_api", + LLMType.OPEN_LLM: "metagpt.provider.open_llm_api", + LLMType.GEMINI: "metagpt.provider.google_gemini_api", + LLMType.METAGPT: "metagpt.provider.metagpt_api", + LLMType.AZURE: "metagpt.provider.azure_openai_api", + LLMType.OLLAMA: "metagpt.provider.ollama_api", + LLMType.QIANFAN: "metagpt.provider.qianfan_api", # Baidu BCE + LLMType.DASHSCOPE: "metagpt.provider.dashscope_api", # Aliyun LingJi DashScope + LLMType.MOONSHOT: "metagpt.provider.moonshot_api", + LLMType.MISTRAL: "metagpt.provider.mistral_api", + LLMType.YI: "metagpt.provider.yi_api", # lingyiwanwu + LLMType.OPENROUTER: "metagpt.provider.openrouter_api", + LLMType.BEDROCK: "metagpt.provider.bedrock_api", + LLMType.ARK: "metagpt.provider.ark_api", + } def register(self, key, provider_cls): self.providers[key] = provider_cls def get_provider(self, enum: LLMType): """get provider instance according to the enum""" + if enum not in self.providers: + # Import and register the provider if not already registered + module_name = self._module_map[enum] + importlib.import_module(module_name) return self.providers[enum]