Merge remote-tracking branch 'origin/main'
liqiang-fit2cloud committed Mar 25, 2024
2 parents 5ea26a4 + a91739a commit 1cf8008
Showing 3 changed files with 55 additions and 19 deletions.
@@ -10,7 +10,7 @@
from typing import Dict

from langchain.schema import HumanMessage
from langchain_community.chat_models import AzureChatOpenAI
from langchain_community.chat_models.azure_openai import AzureChatOpenAI

from common import froms
from common.exception.app_exception import AppApiException
@@ -29,9 +29,6 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

if model_name not in model_dict:
raise AppApiException(ValidCode.valid_error.value, f'{model_name} 模型名称不支持')

for key in ['api_base', 'api_key', 'deployment_name']:
if key not in model_credential:
if raise_exception:
@@ -40,7 +37,43 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
return False
try:
model = AzureModelProvider().get_model(model_type, model_name, model_credential)
model.invoke([HumanMessage(content='valid')])
model.invoke([HumanMessage(content='你好')])
except Exception as e:
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, '校验失败,请检查参数是否正确')
else:
return False

return True

def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

api_base = froms.TextInputField('API 域名', required=True)

api_key = froms.PasswordInputField("API Key", required=True)

deployment_name = froms.TextInputField("部署名", required=True)


class DefaultAzureLLMModelCredential(BaseForm, BaseModelCredential):

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
model_type_list = AzureModelProvider().get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

for key in ['api_base', 'api_key', 'deployment_name', 'api_version']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
else:
return False
try:
model = AzureModelProvider().get_model(model_type, model_name, model_credential)
model.invoke([HumanMessage(content='你好')])
except Exception as e:
if isinstance(e, AppApiException):
raise e
@@ -54,6 +87,8 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

api_version = froms.TextInputField("api_version", required=True)

api_base = froms.TextInputField('API 域名', required=True)

api_key = froms.PasswordInputField("API Key", required=True)
@@ -63,6 +98,8 @@ def encryption_dict(self, model: Dict[str, object]):

azure_llm_model_credential = AzureLLMModelCredential()

base_azure_llm_model_credential = DefaultAzureLLMModelCredential()

model_dict = {
'gpt-3.5-turbo-0613': ModelInfo('gpt-3.5-turbo-0613', '', ModelTypeConst.LLM, azure_llm_model_credential,
api_version='2023-07-01-preview'),
@@ -84,18 +121,18 @@ def get_model(self, model_type, model_name, model_credential: Dict[str, object],
model_info: ModelInfo = model_dict.get(model_name)
azure_chat_open_ai = AzureChatOpenAI(
openai_api_base=model_credential.get('api_base'),
openai_api_version=model_info.api_version,
openai_api_version=model_credential.get(
'api_version') if 'api_version' in model_credential else model_info.api_version,
deployment_name=model_credential.get('deployment_name'),
openai_api_key=model_credential.get('api_key'),
openai_api_type="azure",
tiktoken_model_name=model_name
openai_api_type="azure"
)
return azure_chat_open_ai

def get_model_credential(self, model_type, model_name):
if model_name in model_dict:
return model_dict.get(model_name).model_credential
raise AppApiException(500, f'不支持的模型:{model_name}')
return base_azure_llm_model_credential

def get_model_provide_info(self):
return ModelProvideInfo(provider='model_azure_provider', name='Azure OpenAI', icon=get_file_content(
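For reference, a minimal usage sketch of the revised Azure credential handling in the file above: api_version is read from the credential dict when present and otherwise falls back to the ModelInfo default. The endpoint, key, deployment name and the 'LLM' model_type value below are hypothetical placeholders, not values from this commit; AzureModelProvider is the class edited above and its import path is repo-internal.

from langchain.schema import HumanMessage

# Placeholder credential dict; only fields the changed code already reads are assumed.
credential = {
    'api_base': 'https://<your-resource>.openai.azure.com',  # placeholder endpoint
    'api_key': '<azure-api-key>',                             # placeholder key
    'deployment_name': '<deployment-name>',                   # placeholder deployment
    # Optional: omit 'api_version' to fall back to the ModelInfo default ('2023-07-01-preview').
    'api_version': '2023-07-01-preview',
}

model = AzureModelProvider().get_model('LLM', 'gpt-3.5-turbo-0613', credential)
print(model.invoke([HumanMessage(content='你好')]))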
@@ -9,8 +9,9 @@
import os
from typing import Dict

from langchain_community.chat_models import QianfanChatEndpoint
from langchain.schema import HumanMessage
from langchain_community.chat_models import QianfanChatEndpoint
from qianfan import ChatCompletion

from common import froms
from common.exception.app_exception import AppApiException
@@ -27,10 +28,9 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
model_type_list = WenxinModelProvider().get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

if model_name not in model_dict:
raise AppApiException(ValidCode.valid_error.value, f'{model_name} 模型名称不支持')

model_info = [model.lower() for model in ChatCompletion.models()]
if not model_info.__contains__(model_name.lower()):
raise AppApiException(ValidCode.valid_error.value, f'{model_name} 模型不支持')
for key in ['api_key', 'secret_key']:
if key not in model_credential:
if raise_exception:
@@ -39,10 +39,9 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
return False
try:
WenxinModelProvider().get_model(model_type, model_name, model_credential).invoke(
[HumanMessage(content='valid')])
[HumanMessage(content='你好')])
except Exception as e:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, "校验失败,请检查 api_key secret_key 是否正确")
raise e
return True

def encryption_dict(self, model_info: Dict[str, object]):
@@ -121,7 +120,7 @@ def get_model_list(self, model_type):
def get_model_credential(self, model_type, model_name):
if model_name in model_dict:
return model_dict.get(model_name).model_credential
raise AppApiException(500, f'不支持的模型:{model_name}')
return win_xin_llm_model_credential

def get_model_provide_info(self):
return ModelProvideInfo(provider='model_wenxin_provider', name='千帆大模型', icon=get_file_content(
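For reference, a small sketch of the model-name check introduced above for the Wenxin provider: the requested name is validated case-insensitively against whatever the installed qianfan SDK reports via ChatCompletion.models(). The model name below is a placeholder, and the resulting list depends on the installed qianfan version.

from qianfan import ChatCompletion

requested = 'ERNIE-Bot-turbo'  # placeholder model name
supported = [name.lower() for name in ChatCompletion.models()]
if requested.lower() not in supported:
    raise ValueError(f'{requested} is not reported as a supported Qianfan chat model')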
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -23,7 +23,7 @@ sentence-transformers = "^2.2.2"
blinker = "^1.6.3"
openai = "^1.13.3"
tiktoken = "^0.5.1"
qianfan = "^0.1.1"
qianfan = "^0.3.6.1"
pycryptodome = "^3.19.0"
beautifulsoup4 = "^4.12.2"
html2text = "^2024.2.26"
