diff --git a/erniebot-agent/applications/erniebot_researcher/tools/preprocessing.py b/erniebot-agent/applications/erniebot_researcher/tools/preprocessing.py
index 740b5d98..a94f44a9 100644
--- a/erniebot-agent/applications/erniebot_researcher/tools/preprocessing.py
+++ b/erniebot-agent/applications/erniebot_researcher/tools/preprocessing.py
@@ -1,12 +1,13 @@
 import json
 import os
-from typing import Any, Dict, List, Optional
+from typing import Any, List
 
 import faiss
 import jsonlines
 from langchain.docstore.document import Document
 from langchain.text_splitter import SpacyTextSplitter
 from langchain.vectorstores import FAISS
+from llama_index.schema import TextNode
 from langchain_community.document_loaders import DirectoryLoader
 from llama_index import (
     ServiceContext,
@@ -73,10 +74,7 @@ def write_json(self, data: List[dict]):
 
     async def run(self, data_dir, url_path=None):
         if url_path:
-            _, suffix = os.path.splitext(url_path)
-            assert suffix == ".json"
-            with open(url_path) as f:
-                url_dict = json.load(f)
+            url_dict = get_url(url_path)
         else:
             url_dict = None
         docs = self.data_load(data_dir)
@@ -91,12 +89,15 @@ async def run(self, data_dir, url_path=None):
         return self.path
 
 
-def add_url(url_dict: Dict, path: Optional[str] = None):
-    if not path:
-        path = "./url.json"
-    json_str = json.dumps(url_dict, ensure_ascii=False)
-    with open(path, "w") as json_file:
-        json_file.write(json_str)
+def get_url(url_path):
+    with open(url_path, "r") as f:
+        data = f.read()
+    data_list = data.split("\n")
+    url_dict = {}
+    for item in data_list:
+        url, path = item.split(" ")
+        url_dict[path] = url
+    return url_dict
 
 
 def preprocess(data_dir, url_path=None, use_langchain=True):
@@ -104,10 +105,7 @@ def preprocess(data_dir, url_path=None, use_langchain=True):
         loader = DirectoryLoader(path=data_dir)
         docs = loader.load()
         if url_path:
-            _, suffix = os.path.splitext(url_path)
-            assert suffix == ".json"
-            with open(url_path) as f:
-                url_dict = json.load(f)
+            url_dict = get_url(url_path)
             for item in docs:
                 if "source" not in item.metadata:
                     item.metadata["source"] = ""
@@ -119,10 +117,7 @@
     else:
         docs = SimpleDirectoryReader(data_dir).load_data()
         if url_path:
-            _, suffix = os.path.splitext(url_path)
-            assert suffix == ".json"
-            with open(url_path) as f:
-                url_dict = json.load(f)
+            url_dict = get_url(url_path)
             for item in docs:
                 if "source" not in item.metadata:
                     item.metadata["source"] = item.metadata["file_path"]
@@ -134,33 +129,45 @@
     return docs
 
 
+def get_abstract_data(path, use_langchain=True):
+    all_docs = []
+    with jsonlines.open(path) as reader:
+        for obj in reader:
+            if type(obj) is list:
+                for item in obj:
+                    if "url" in item:
+                        metadata = {"url": item["url"], "name": item["name"]}
+                    else:
+                        metadata = {"name": item["name"]}
+                    if use_langchain:
+                        doc = Document(page_content=item["page_content"], metadata=metadata)
+                    else:
+                        doc = TextNode(text=item["page_content"], metadata=metadata)
+                    all_docs.append(doc)
+            elif type(obj) is dict:
+                if "url" in obj:
+                    metadata = {"url": obj["url"], "name": obj["name"]}
+                else:
+                    metadata = {"name": obj["name"]}
+                if use_langchain:
+                    doc = Document(page_content=obj["page_content"], metadata=metadata)
+                else:
+                    doc = TextNode(text=obj["page_content"], metadata=metadata)
+                all_docs.append(doc)
+    return all_docs
+
+
 def build_index_langchain(
-    index_name, embeddings, path=None, url_path=None, abstract=False, origin_data=None, use_data=False
+    index_name, embeddings, path=None, url_path=None, abstract=False, origin_data=None
 ):
     if os.path.exists(index_name):
         db = FAISS.load_local(index_name, embeddings)
-    elif abstract and not use_data:
-        all_docs = []
-        with jsonlines.open(path) as reader:
-            for obj in reader:
-                if type(obj) is list:
-                    for item in obj:
-                        if "url" in item:
-                            metadata = {"url": item["url"], "name": item["name"]}
-                        else:
-                            metadata = {"name": item["name"]}
-                        doc = Document(page_content=item["page_content"], metadata=metadata)
-                        all_docs.append(doc)
-                elif type(obj) is dict:
-                    if "url" in obj:
-                        metadata = {"url": obj["url"], "name": obj["name"]}
-                    else:
-                        metadata = {"name": obj["name"]}
-                    doc = Document(page_content=obj["page_content"], metadata=metadata)
-                    all_docs.append(doc)
+    elif abstract:
+        all_docs = get_abstract_data(path)
         db = FAISS.from_documents(all_docs, embeddings)
         db.save_local(index_name)
-    elif not abstract and not use_data:
+
+    elif not abstract and not origin_data:
         documents = preprocess(path, url_path)
         text_splitter = SpacyTextSplitter(pipeline="zh_core_web_sm", chunk_size=1500, chunk_overlap=0)
         docs = text_splitter.split_documents(documents)
@@ -170,15 +177,13 @@ def build_index_langchain(
             docs_tackle.append(item)
         db = FAISS.from_documents(docs_tackle, embeddings)
         db.save_local(index_name)
-    elif use_data:
+    elif origin_data:
         db = FAISS.from_documents(origin_data, embeddings)
         db.save_local(index_name)
     return db
 
 
-def build_index_llama(
-    index_name, embeddings, path=None, url_path=None, abstract=False, origin_data=None, use_data=False
-):
+def build_index_llama(index_name, embeddings, path=None, url_path=None, abstract=False, origin_data=None):
     if embeddings.model == "text-embedding-ada-002":
         d = 1536
     elif embeddings.model == "ernie-text-embedding":
@@ -194,7 +199,7 @@ def build_index_llama(
         service_context = ServiceContext.from_defaults(embed_model=embeddings)
         index = load_index_from_storage(storage_context=storage_context, service_context=service_context)
         return index
-    if not abstract and not use_data:
+    if not abstract and not origin_data:
         documents = preprocess(path, url_path=url_path, use_langchain=False)
         text_splitter = SentenceSplitter(chunk_size=1024, chunk_overlap=20)
         storage_context = StorageContext.from_defaults(vector_store=vector_store)
@@ -208,27 +213,7 @@ def build_index_llama(
         index.storage_context.persist(persist_dir=index_name)
         return index
     elif abstract:
-        all_docs = []
-        from llama_index.schema import TextNode
-
-        with jsonlines.open(path) as reader:
-            for obj in reader:
-                if type(obj) is list:
-                    for item in obj:
-                        if "url" in item:
-                            metadata = {"url": item["url"], "name": item["name"]}
-                        else:
-                            metadata = {"name": item["name"]}
-                        doc = {"text": item["page_content"], "metadata": metadata}
-                        all_docs.append(doc)
-                elif type(obj) is dict:
-                    if "url" in obj:
-                        metadata = {"url": obj["url"], "name": obj["name"]}
-                    else:
-                        metadata = {"name": obj["name"]}
-                    doc = {"text": item["page_content"], "metadata": metadata}
-                    all_docs.append(doc)
-        nodes = [TextNode(text=item["text"], metadata=item["metadata"]) for item in all_docs]
+        nodes = get_abstract_data(path, use_langchain=False)
         text_splitter = SentenceSplitter(chunk_size=1024, chunk_overlap=20)
         storage_context = StorageContext.from_defaults(vector_store=vector_store)
         service_context = ServiceContext.from_defaults(embed_model=embeddings, text_splitter=text_splitter)
@@ -297,41 +282,29 @@ def parse_arguments():
         embeddings: Any = AzureOpenAIEmbeddings(azure_deployment="text-embedding-ada")
     else:
         embeddings = ErnieEmbeddings(aistudio_access_token=access_token)
 
-    _, suffix = os.path.splitext(args.url_path)
-    if suffix == ".txt":
-        url_path = args.url_path.replace(".txt", ".json")
-        with open(args.url_path, "r") as f:
-            data = f.read()
-        data_list = data.split("\n")
-        url_dict = {}
-        for item in data_list:
-            url, path = item.split(" ")
-            url_dict[path] = url
-            url_path
-        add_url(url_dict, path=url_path)
-    else:
-        url_path = args.url_path
     if not args.path_abstract:
         from erniebot_agent.chat_models import ERNIEBot
 
         llm = ERNIEBot(model="ernie-4.0")
         generate_abstract = GenerateAbstract(llm=llm)
-        abstract_path = asyncio.run(generate_abstract.run(args.path_full_text, url_path))
+        abstract_path = asyncio.run(generate_abstract.run(args.path_full_text, args.url_path))
     else:
        abstract_path = args.path_abstract
+    build_index_fuction, retrieval_tool = get_retriver_by_type(args.framework)
+
     full_text_db = build_index_fuction(
         index_name=args.index_name_full_text,
         embeddings=embeddings,
         path=args.path_full_text,
-        url_path=url_path,
+        url_path=args.url_path,
     )
     abstract_db = build_index_fuction(
         index_name=args.index_name_abstract,
         embeddings=embeddings,
         path=abstract_path,
         abstract=True,
-        url_path=url_path,
+        url_path=args.url_path,
     )
     retrieval_full = retrieval_tool(full_text_db)
     retrieval_abstract = retrieval_tool(abstract_db)
diff --git a/erniebot-agent/applications/erniebot_researcher/tools/utils.py b/erniebot-agent/applications/erniebot_researcher/tools/utils.py
index 4fcc2b70..cfdf6594 100644
--- a/erniebot-agent/applications/erniebot_researcher/tools/utils.py
+++ b/erniebot-agent/applications/erniebot_researcher/tools/utils.py
@@ -153,9 +153,7 @@ def add_citation(paragraphs, index_name, embeddings, build_index, SearchTool):
     for item in paragraphs:
         example = Document(page_content=item["summary"], metadata={"url": item["url"], "name": item["name"]})
         list_data.append(example)
-    faiss_db = build_index(
-        index_name=index_name, use_data=True, embeddings=embeddings, origin_data=list_data
-    )
+    faiss_db = build_index(index_name=index_name, embeddings=embeddings, origin_data=list_data)
     faiss_search = SearchTool(db=faiss_db)
     return faiss_search
 