From 686b4960f2a467485510514aa27fc499e5fe96e7 Mon Sep 17 00:00:00 2001
From: qingzhong1
Date: Wed, 24 Jan 2024 05:47:40 +0000
Subject: [PATCH] update llama_index

---
Review notes (text between "---" and the first "diff --git" is not part of the commit):
- build_index_llama: the FAISS index is now created only when the index is built from
  scratch, the three build modes share one set of index arguments, and a docstring
  was added. Branch selection is unchanged.
- LlamaIndex retrieval tool __call__: results whose score is None are skipped instead
  of raising TypeError on the threshold comparison; a docstring was added.
- Hunk sizes and the diffstat were recomputed for the edited hunks. The "index" blob
  ids below still describe the original post-images.

 .../tools/preprocessing.py                    | 63 ++++++++++++++++++-
 .../applications/erniebot_researcher/ui.py    |  2 +-
 .../tools/llama_index_retrieval_tool.py       | 23 ++++++-
 3 files changed, 82 insertions(+), 6 deletions(-)

diff --git a/erniebot-agent/applications/erniebot_researcher/tools/preprocessing.py b/erniebot-agent/applications/erniebot_researcher/tools/preprocessing.py
index a503016a..cbf8a8bb 100644
--- a/erniebot-agent/applications/erniebot_researcher/tools/preprocessing.py
+++ b/erniebot-agent/applications/erniebot_researcher/tools/preprocessing.py
@@ -2,13 +2,22 @@
 import os
 from typing import Any, List
 
+import faiss
 import jsonlines
 from langchain.docstore.document import Document
 from langchain.text_splitter import SpacyTextSplitter
 from langchain.vectorstores import FAISS
 from langchain_community.document_loaders import DirectoryLoader
-from llama_index import SimpleDirectoryReader
+from llama_index import (
+    ServiceContext,
+    SimpleDirectoryReader,
+    StorageContext,
+    VectorStoreIndex,
+    load_index_from_storage,
+)
+from llama_index.node_parser import SentenceSplitter
 from llama_index.schema import TextNode
+from llama_index.vector_stores.faiss import FaissVectorStore
 
 from erniebot_agent.memory import HumanMessage, Message
 from erniebot_agent.prompt import PromptTemplate
@@ -175,8 +184,56 @@ def build_index_langchain(
 
 
 def build_index_llama(index_name, embeddings, path=None, url_path=None, abstract=False, origin_data=None):
-    # TODO: Adapt to llamaindex
-    pass
+    """Build a FAISS-backed llama_index vector index, or load it if already persisted.
+
+    Args:
+        index_name: Directory the index is persisted to; when it already exists the
+            stored index is loaded and returned without rebuilding.
+        embeddings: Embedding model; ``embeddings.model`` selects the vector dimension.
+        path: Data location forwarded to ``preprocess`` / ``get_abstract_data``.
+        url_path: Forwarded to ``preprocess`` when indexing full documents.
+        abstract: When true, index the nodes returned by ``get_abstract_data``.
+        origin_data: Optional items exposing ``page_content`` and ``metadata`` that are
+            indexed directly; only consulted when ``abstract`` is false.
+
+    Returns:
+        The loaded or newly built index.
+
+    Raises:
+        ValueError: If ``embeddings.model`` is not a supported embedding model.
+    """
+    if embeddings.model == "text-embedding-ada-002":
+        d = 1536
+    elif embeddings.model == "ernie-text-embedding":
+        d = 384
+    else:
+        raise ValueError(f"model {embeddings.model} not support")
+
+    if os.path.exists(index_name):
+        # Reuse the persisted index; no fresh FAISS index is needed on this path.
+        vector_store = FaissVectorStore.from_persist_dir(persist_dir=index_name)
+        storage_context = StorageContext.from_defaults(vector_store=vector_store, persist_dir=index_name)
+        service_context = ServiceContext.from_defaults(embed_model=embeddings)
+        return load_index_from_storage(storage_context=storage_context, service_context=service_context)
+
+    # Building from scratch: the three modes differ only in where the input comes from.
+    vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatIP(d))
+    text_splitter = SentenceSplitter(chunk_size=1024, chunk_overlap=20)
+    index_kwargs = {
+        "storage_context": StorageContext.from_defaults(vector_store=vector_store),
+        "service_context": ServiceContext.from_defaults(embed_model=embeddings, text_splitter=text_splitter),
+        "show_progress": True,
+    }
+    if abstract:
+        index = VectorStoreIndex(get_abstract_data(path, use_langchain=False), **index_kwargs)
+    elif origin_data:
+        nodes = [TextNode(text=item.page_content, metadata=item.metadata) for item in origin_data]
+        index = VectorStoreIndex(nodes, **index_kwargs)
+    else:
+        documents = preprocess(path, url_path=url_path, use_langchain=False)
+        index = VectorStoreIndex.from_documents(documents, **index_kwargs)
+    index.storage_context.persist(persist_dir=index_name)
+    return index
 
 
 def get_retriver_by_type(frame_type):
diff --git a/erniebot-agent/applications/erniebot_researcher/ui.py b/erniebot-agent/applications/erniebot_researcher/ui.py
index a21bc667..6f4b157c 100644
--- a/erniebot-agent/applications/erniebot_researcher/ui.py
+++ b/erniebot-agent/applications/erniebot_researcher/ui.py
@@ -223,7 +223,7 @@ def generate_report(query, history=[]):
     agent_sets = get_agents(
         retriever_sets, tool_sets, llm, llm_long, dir_path, target_path, build_index_function, retrieval_tool
     )
-    team_actor = ResearchTeam(**agent_sets, use_reflection=False)
+    team_actor = ResearchTeam(**agent_sets, use_reflection=True)
     report, path = asyncio.run(team_actor.run(query, args.iterations))
     return report, path
 
diff --git a/erniebot-agent/src/erniebot_agent/tools/llama_index_retrieval_tool.py b/erniebot-agent/src/erniebot_agent/tools/llama_index_retrieval_tool.py
index ffb6e68b..4c4d13fa 100644
--- a/erniebot-agent/src/erniebot_agent/tools/llama_index_retrieval_tool.py
+++ b/erniebot-agent/src/erniebot_agent/tools/llama_index_retrieval_tool.py
@@ -42,5 +42,24 @@ def __init__(
         self.threshold = threshold
 
     async def __call__(self, query: str, top_k: int = 3, filters: Optional[Dict[str, Any]] = None):
-        # TODO: Adapt to llamaindex
-        pass
+        """Retrieve nodes for ``query`` and keep those scoring above ``self.threshold``.
+
+        Args:
+            query: Text passed to the retriever.
+            top_k: Passed as ``similarity_top_k`` when creating the retriever.
+            filters: Currently unused by this implementation.
+
+        Returns:
+            ``{"documents": [...]}`` where each entry has ``content`` and ``score``, plus
+            ``meta`` when ``self.return_meta_data`` is truthy.
+        """
+        retriever = self.db.as_retriever(similarity_top_k=top_k)
+        docs = []
+        for doc in retriever.retrieve(query):
+            # A result may carry no score; comparing None with a number raises TypeError.
+            if doc.score is not None and doc.score > self.threshold:
+                new_doc = {"content": doc.node.text, "score": doc.score}
+                if self.return_meta_data:
+                    new_doc["meta"] = doc.metadata
+                docs.append(new_doc)
+        return {"documents": docs}