
Commit

update langchain
qingzhong1 committed Jan 23, 2024
1 parent ab88b79 commit ab8bfe4
Showing 3 changed files with 37 additions and 22 deletions.
@@ -9,6 +9,7 @@
from erniebot_agent.prompt import PromptTemplate

logger = logging.getLogger(__name__)
MAX_RETRY = 10
PLAN_VERIFICATIONS_PROMPT = """
To verify the accuracy of the numerical statements in the given content, you need to create a series of verification questions
to test the factual claims in the original baseline response. For example, if part of the long-form response contains
@@ -222,11 +223,24 @@ async def report_fact(self, report: str):
messages: List[Message] = [
HumanMessage(content=self.prompt_plan_verifications.format(base_context=item))
]
response = await self.llm.chat(messages)
result: List[dict] = self.parse_json(response.content)
fact_check_result: List[dict] = await self.verifications(result)
new_item: str = await self.generate_final_response(item, fact_check_result)
text.append(new_item)
retry_count = 0
while True:
try:
response = await self.llm.chat(messages)
result: List[dict] = self.parse_json(response.content)
fact_check_result: List[dict] = await self.verifications(result)
new_item: str = await self.generate_final_response(item, fact_check_result)
text.append(new_item)
break
except Exception as e:
retry_count += 1
logger.error(e)
if retry_count > MAX_RETRY:
raise Exception(
f"Failed to edit research for {report} after {MAX_RETRY} times."
)
continue

else:
text.append(item)
return "\n\n".join(text)
@@ -14,7 +14,9 @@
SimpleDirectoryReader,
StorageContext,
VectorStoreIndex,
load_index_from_storage
)
from llama_index.node_parser import SentenceSplitter
from llama_index.vector_stores.faiss import FaissVectorStore

from erniebot_agent.memory import HumanMessage, Message
@@ -90,7 +92,6 @@ async def run(self, data_dir, url_path=None):
url = url_dict.get(item.metadata["source"], item.metadata["source"])
if url_dict and item.metadata["source"] in url_dict:
item.metadata["source"] = url_dict[item.metadata["source"]]

summary_list.append({"page_content": summary, "url": url, "name": item.metadata["source"]})
self.write_json(summary_list)
return self.path
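The hunk above maps each document's local source path back to its original URL via url_dict, falling back to the path itself when no mapping exists. A small self-contained illustration of that fallback (paths and URLs are placeholders, not taken from the repository):

url_dict = {"data/report_a.md": "https://example.com/report_a"}

for source in ["data/report_a.md", "data/local_only.md"]:
    # dict.get falls back to the local path when no URL mapping exists
    url = url_dict.get(source, source)
    print(source, "->", url)
# data/report_a.md -> https://example.com/report_a
# data/local_only.md -> data/local_only.md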
@@ -179,21 +180,22 @@ def build_index_llama(
if os.path.exists(faiss_name):
vector_store = FaissVectorStore.from_persist_dir(persist_dir=faiss_name)
storage_context = StorageContext.from_defaults(vector_store=vector_store, persist_dir=faiss_name)
service_context = ServiceContext.from_defaults(embed_model=embeddings)
index = load_index_from_storage(storage_context=storage_context, service_context=service_context)
return index
if not abstract and not use_data:
documents = SimpleDirectoryReader(path).load_data()
text_splitter = SentenceSplitter(chunk_size=1024, chunk_overlap=20)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
service_context = ServiceContext.from_defaults(embed_model=embeddings, text_splitter=text_splitter)
index = VectorStoreIndex.from_documents(
documents,
storage_context=storage_context,
show_progress=True,
service_context=service_context,
)
index.storage_context.persist(persist_dir=faiss_name)
return storage_context
from llama_index.node_parser import SentenceSplitter

documents = SimpleDirectoryReader(path).load_data()
text_splitter = SentenceSplitter(chunk_size=1024, chunk_overlap=20)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
service_context = ServiceContext.from_defaults(embed_model=embeddings, text_splitter=text_splitter)
index = VectorStoreIndex.from_documents(
documents,
storage_context=storage_context,
show_progress=True,
service_context=service_context,
)
index.storage_context.persist(persist_dir=faiss_name)
return storage_context
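The load branch at the top of build_index_llama relies on llama_index's persisted-storage API with a FAISS-backed vector store. A minimal sketch of reloading and querying such a persisted index (the directory name and query text are placeholders, embeddings stands in for whatever embedding model built the index, and the as_retriever call is an assumed way of consuming the returned index, not code from this repository):

from llama_index import ServiceContext, StorageContext, load_index_from_storage
from llama_index.vector_stores.faiss import FaissVectorStore

embeddings = ...  # placeholder: the same embedding model used when the index was built

# Reload a FAISS-backed index persisted under ./faiss_index.
vector_store = FaissVectorStore.from_persist_dir(persist_dir="faiss_index")
storage_context = StorageContext.from_defaults(vector_store=vector_store, persist_dir="faiss_index")
service_context = ServiceContext.from_defaults(embed_model=embeddings)
index = load_index_from_storage(storage_context=storage_context, service_context=service_context)

# Query the reloaded index for the most similar chunks.
retriever = index.as_retriever(similarity_top_k=3)
nodes = retriever.retrieve("example query")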


def get_retriver_by_type(frame_type):
@@ -58,10 +58,9 @@ async def add_url_sentences(self, sententces: str, citation_faiss_research):
sentence += "。"
for item in query_result["documents"]:
source = item["meta"]["url"]
breakpoint()
if item["score"] >= self.theta_min and item["score"] <= self.theta_max:
if source not in self.recoder_cite_list:
self.recoder_cite_title.append(item["name"])
self.recoder_cite_title.append(item["meta"]["name"])
self.recoder_cite_list.append(source)
self.recoder_cite_dict[source] = 1
index = len(self.recoder_cite_list)
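The hunk above registers a citation only when the retrieval score falls inside the [theta_min, theta_max] window and the source URL has not been cited before. A stripped-down sketch of that bookkeeping with the surrounding class state inlined (the threshold values and query_result contents are illustrative):

theta_min, theta_max = 0.3, 0.95  # illustrative thresholds
recoder_cite_list, recoder_cite_title = [], []
recoder_cite_dict = {}

query_result = {"documents": [
    {"meta": {"url": "https://example.com/a", "name": "Doc A"}, "score": 0.80},
    {"meta": {"url": "https://example.com/b", "name": "Doc B"}, "score": 0.10},  # below theta_min, skipped
]}

for item in query_result["documents"]:
    source = item["meta"]["url"]
    if theta_min <= item["score"] <= theta_max:
        if source not in recoder_cite_list:
            recoder_cite_title.append(item["meta"]["name"])
            recoder_cite_list.append(source)
            recoder_cite_dict[source] = 1
            index = len(recoder_cite_list)  # 1-based citation number for the new source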