-
Notifications
You must be signed in to change notification settings - Fork 0
/
rag_demo.py
290 lines (256 loc) · 13.1 KB
/
rag_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
"""
参考博客:https://mp.weixin.qq.com/s/RUdZjQMSlVOfHfhErSNXnA
"""
# 导入必要的库与模块
import json
import os
import textwrap
import requests
from dotenv import load_dotenv
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import TextLoader
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings, TensorflowHubEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Weaviate
from weaviate import Client
from weaviate.embedded import EmbeddedOptions
from zhipuai import ZhipuAI
from openai import AzureOpenAI
# 环境设置与文档下载
load_dotenv() # 加载环境变量
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") # 从环境变量获取 OpenAI API 密钥
MIMIMAX_API_KEY = os.getenv("MIMIMAX_API_KEY")
MIMIMAX_GROUP_ID = os.getenv("MIMIMAX_GROUP_ID")
ZHIPUAI_API_KEY = os.getenv("ZHIPUAI_API_KEY")
KIMI_OPENAI_API_KEY = os.getenv("KIMI_OPENAI_API_KEY")
AZURE_OPENAI_KEY = os.getenv("AZURE_OPENAI_KEY")
AZURE_ENDPOINT = os.getenv("AZURE_ENDPOINT")
# 确保 OPENAI_API_KEY 被正确设置
if not OPENAI_API_KEY:
raise ValueError("OpenAI API Key not found in the environment variables.")
# 文档加载与分割
def load_and_split_document(file_path, chunk_size=500, chunk_overlap=50):
"""加载文档并分割成小块"""
loader = TextLoader(file_path)
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
chunks = text_splitter.split_documents(documents)
return chunks
# 向量存储建立
def create_vector_store(chunks, model="OpenAI"):
"""将文档块转换为向量并存储到 Weaviate 中"""
client = Client(embedded_options=EmbeddedOptions())
if model == "OpenAI":
embedding_model = OpenAIEmbeddings()
elif model == "HuggingFace":
embedding_model = HuggingFaceEmbeddings()
elif model == "TensorflowHub":
embedding_model = TensorflowHubEmbeddings()
else:
raise ValueError(f"Unsupported embedding model: {model}")
vectorstore = Weaviate.from_documents(
client=client,
documents=chunks,
embedding=embedding_model,
by_text=False
)
return vectorstore
def get_retriever(vectorstore, k=4):
return vectorstore.as_retriever(search_kwargs={'k': k})
def setup_rag_chain(model_name="gpt-4", temperature=0):
"""设置检索增强生成流程"""
if model_name.startswith("gpt"):
# 如果是以gpt开头的模型,使用原来的逻辑
prompt_template = """
您是一个擅长问答任务的专业助手。在执行问答任务时,应优先考虑所提供的**上下文信息**来形成回答,并适当参照**对话历史**。
如果**上下文信息**与**问题**无直接相关性,您应依据自己的知识库向提问者提供准确的信息。务必确保您的答案在相关性、准确性和可读性方面达到高标准。
**对话历史**: {conversation_history}
**问题**: {question}
**上下文信息**: {context}
**回答**:
"""
prompt = ChatPromptTemplate.from_template(prompt_template)
llm = ChatOpenAI(model_name=model_name, temperature=temperature)
# 创建 RAG 链,参考 https://python.langchain.com/docs/expression_language/
rag_chain = (
prompt
| llm
| StrOutputParser()
)
else:
# 如果不是以gpt开头的模型,返回None
rag_chain = None
return rag_chain
# 执行查询并打印结果
def execute_query(retriever, rag_chain, query, model_name="gpt-4", temperature=0):
"""
执行查询并返回结果及检索到的文档块
参数:
retriever: 文档检索器对象
rag_chain: 检索增强生成链对象,如果为None则不使用RAG链
query: 查询问题
model_name: 使用的语言模型名称,默认为"gpt-4"
temperature: 生成温度,默认为0
返回:
retrieved_documents: 检索到的文档块列表
response_text: 生成的回答文本
"""
if isinstance(query, list):
[conversation_history, question] = query
else:
conversation_history = ''
question = query
# 使用检索器检索相关文档块
retrieved_documents = retriever.invoke(question)
if rag_chain is not None:
# 如果有RAG链,则使用RAG链生成回答
rag_chain_response = rag_chain.invoke({"context": retrieved_documents, "question": question})
response_text = rag_chain_response
else:
prompt_template = """
【对话历史】: {conversation_history}
【上下文信息】: {context}
您是一个擅长问答任务的专业助手。在执行问答任务时,应优先考虑所提供的【上下文信息】来形成回答,并适当参照【对话历史】。
如果【上下文信息】与【问题】无直接相关性,您应依据自己的知识库向提问者提供准确的信息。务必确保您的答案在相关性、准确性和可读性方面达到高标准。
【问题】: {question}
【回答】:
"""
context = '\n'.join(
[retrieved_documents[i].page_content for i in range(len(retrieved_documents))])
prompt = prompt_template.format(conversation_history=conversation_history, question=question, context=context)
response_text = execute_query_no_rag(model_name=model_name, temperature=temperature, query=prompt)
return retrieved_documents, response_text
def execute_query_no_rag(model_name="gpt-4", temperature=0, query=""):
"""执行无 RAG 链的查询"""
if model_name.startswith("gpt"):
# 如果是以gpt开头的模型,使用原来的逻辑
llm = ChatOpenAI(model_name=model_name, temperature=temperature)
response = llm.invoke(query)
return response.content
elif model_name.startswith("azure_gpt"):
client = AzureOpenAI(
azure_endpoint=AZURE_ENDPOINT,
api_key=AZURE_OPENAI_KEY,
api_version="2024-02-15-preview"
)
message_text = [{"role": "user", "content": query}, ]
completion = client.chat.completions.create(
model=model_name[6:], # model_name = 'azure_gpt-4', 'azure_gpt-35-turbo-16k', 'azure_gpt-35-turbo'
messages=message_text,
temperature=temperature,
top_p=0.95,
frequency_penalty=0,
presence_penalty=0,
stop=None
)
return completion.choices[0].message.content
elif model_name == 'abab6-chat':
# 如果是'abab6-chat'模型,使用专门的API调用方式
url = "https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId=" + MIMIMAX_GROUP_ID
headers = {"Content-Type": "application/json", "Authorization": "Bearer " + MIMIMAX_API_KEY}
payload = {
"bot_setting": [
{
"bot_name": "MM智能助理",
"content": "MM智能助理是一款由MiniMax自研的,没有调用其他产品的接口的大型语言模型。MiniMax是一家中国科技公司,一直致力于进行大模型相关的研究。",
}
],
"messages": [{"sender_type": "USER", "sender_name": "小明", "text": query}],
"reply_constraints": {"sender_type": "BOT", "sender_name": "MM智能助理"},
"model": model_name,
"tokens_to_generate": 1034,
"temperature": temperature,
"top_p": 0.9,
}
response = requests.request("POST", url, headers=headers, json=payload)
# 将 JSON 字符串解析为字典
response_dict = json.loads(response.text)
# 提取 'reply' 键对应的值
return response_dict['reply']
elif model_name == 'glm-4':
# 如果是'glm-4'模型,使用专门的API调用方式
client = ZhipuAI(api_key=ZHIPUAI_API_KEY) # 填写您自己的APIKey
response = client.chat.completions.create(
model=model_name, # 填写需要调用的模型名称
messages=[{"role": "user", "content": query}]
)
return response.choices[0].message.content
elif model_name == 'kimi':
# 如果是'kimi'模型,使用专门的API调用方式
from openai import OpenAI
client = OpenAI(
api_key=KIMI_OPENAI_API_KEY,
base_url="https://api.moonshot.cn/v1",
)
messages = [
{
"role": "system",
"content": "你是 Kimi,由 Moonshot AI 提供的人工智能助手,你更擅长中文和英文的对话。你会为用户提供安全,有帮助,准确的回答。同时,你会拒绝一切涉及恐怖主义,种族歧视,黄色暴力等问题的回答。Moonshot AI 为专有名词,不可翻译成其他语言。",
},
{"role": "user",
"content": query},
]
completion = client.chat.completions.create(
# model="moonshot-v1-128k",
model="moonshot-v1-32k",
messages=messages,
temperature=temperature,
top_p=1.0,
n=1, # 为每条输入消息生成多少个结果
stream=False # 流式输出
)
return completion.choices[0].message.content
else:
# 如果模型不支持,抛出异常
raise ValueError(f"Unsupported model: {model_name}")
if __name__ == "__main__":
# 假设文档已存在于本地
file_path = './documents/LightZero_README_zh.md'
# model_name = "glm-4" # model_name=['abab6-chat', 'glm-4', 'gpt-3.5-turbo', 'gpt-4', 'gpt-4-turbo', 'azure_gpt-4', 'azure_gpt-35-turbo-16k', 'azure_gpt-35-turbo']
# model_name = 'azure_gpt-4'
model_name = 'kimi'
temperature = 0.01
embedding_model = 'OpenAI' # embedding_model=['HuggingFace', 'TensorflowHub', 'OpenAI']
# 加载和分割文档
chunks = load_and_split_document(file_path, chunk_size=5000, chunk_overlap=500)
# 创建向量存储
vectorstore = create_vector_store(chunks, model=embedding_model)
retriever = get_retriever(vectorstore, k=5)
# 设置 RAG 流程
rag_chain = setup_rag_chain(model_name=model_name, temperature=temperature)
# 提出问题并获取答案
query = ("请回答下面的问题:(1)请简要介绍一下 LightZero。(2)请详细介绍 LightZero 的框架结构。 (3)请给出安装 LightZero,运行他们的示例代码的详细步骤。(4)- 请问 LightZero 具体支持什么任务(tasks/environments)? (5)请问 LightZero 具体支持什么算法?(6)请问 LightZero 具体支持什么算法,各自支持在哪些任务上运行? (7)请问 LightZero 里面实现的 MuZero 算法支持在 Atari 任务上运行吗?(8)请问 LightZero 里面实现的 AlphaZero 算法支持在 Atari 任务上运行吗?(9)LightZero 支持哪些算法? 各自的优缺点是什么? 我应该如何根据任务特点进行选择呢?(10)请结合 LightZero 中的代码介绍他们是如何实现 MCTS 的。(11)请问对这个仓库提出详细的改进建议")
"""
(1)请简要介绍一下 LightZero。
(2)请详细介绍 LightZero 的框架结构。
(3)请给出安装 LightZero,运行他们的示例代码的详细步骤 。
(4)请问 LightZero 具体支持什么任务(tasks/environments)?
(5)请问 LightZero 具体支持什么算法?
(6)请问 LightZero 具体支持什么算法,各自支持在哪些任务上运行?
(7)请问 LightZero 里面实现的 MuZero 算法支持在 Atari 任务上运行吗?
(8)请问 LightZero 里面实现的 AlphaZero 算法支持在 Atari 任务上运行吗?
(9)LightZero 支持哪些算法? 各自的优缺点是什么? 我应该如何根据任务特点进行选择呢?
(10)请结合 LightZero 中的代码介绍他们是如何实现 MCTS 的。
(11)请问对这个仓库提出详细的改进建议。
"""
# 使用 RAG 链获取参考的文档与答案
retrieved_documents, result_with_rag = execute_query(retriever, rag_chain, query, model_name=model_name,
temperature=temperature)
# 不使用 RAG 链获取答案
result_without_rag = execute_query_no_rag(model_name=model_name, query=query, temperature=temperature)
# 打印并对比两种方法的结果
# 使用textwrap.fill来自动分段文本,width参数可以根据你的屏幕宽度进行调整
wrapped_result_with_rag = textwrap.fill(result_with_rag, width=80)
wrapped_result_without_rag = textwrap.fill(result_without_rag, width=80)
context = '\n'.join(
[f'**Document {i}**: ' + retrieved_documents[i].page_content for i in range(len(retrieved_documents))])
# 打印自动分段后的文本
print("=" * 40)
print(f"我的问题是:\n{query}")
print("=" * 40)
print(f"Result with RAG:\n{wrapped_result_with_rag}\n检索得到的context是: \n{context}")
print("=" * 40)
print(f"Result without RAG:\n{wrapped_result_without_rag}")
print("=" * 40)