
Commit

update langchain
qingzhong1 committed Jan 23, 2024
1 parent 600a443 commit 1743083
Showing 2 changed files with 57 additions and 86 deletions.
139 changes: 56 additions & 83 deletions erniebot-agent/applications/erniebot_researcher/tools/preprocessing.py
@@ -1,12 +1,13 @@
import json
import os
from typing import Any, Dict, List, Optional
from typing import Any, List

import faiss
import jsonlines
from langchain.docstore.document import Document
from langchain.text_splitter import SpacyTextSplitter
from langchain.vectorstores import FAISS
from llama_index.schema import TextNode
from langchain_community.document_loaders import DirectoryLoader
from llama_index import (
ServiceContext,
@@ -73,10 +74,7 @@ def write_json(self, data: List[dict]):

async def run(self, data_dir, url_path=None):
if url_path:
_, suffix = os.path.splitext(url_path)
assert suffix == ".json"
with open(url_path) as f:
url_dict = json.load(f)
url_dict = get_url(url_path)
else:
url_dict = None
docs = self.data_load(data_dir)
@@ -91,23 +89,23 @@ async def run(self, data_dir, url_path=None):
return self.path


def add_url(url_dict: Dict, path: Optional[str] = None):
if not path:
path = "./url.json"
json_str = json.dumps(url_dict, ensure_ascii=False)
with open(path, "w") as json_file:
json_file.write(json_str)
def get_url(url_path):
with open(url_path, "r") as f:
data = f.read()
data_list = data.split("\n")
url_dict = {}
for item in data_list:
url, path = item.split(" ")
url_dict[path] = url
return url_dict
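
A minimal usage sketch for the new get_url helper. The file name and contents below are illustrative; the only assumption, taken from the code above, is that each line of the url file holds a URL and a local path separated by a single space:

# urls.txt (hypothetical contents):
#   https://example.com/paper_a data/paper_a.md
#   https://example.com/paper_b data/paper_b.md

url_dict = get_url("urls.txt")
# get_url maps each local path to its source URL, e.g.
# {"data/paper_a.md": "https://example.com/paper_a",
#  "data/paper_b.md": "https://example.com/paper_b"}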


def preprocess(data_dir, url_path=None, use_langchain=True):
if use_langchain:
loader = DirectoryLoader(path=data_dir)
docs = loader.load()
if url_path:
_, suffix = os.path.splitext(url_path)
assert suffix == ".json"
with open(url_path) as f:
url_dict = json.load(f)
url_dict = get_url(url_path)
for item in docs:
if "source" not in item.metadata:
item.metadata["source"] = ""
@@ -119,10 +117,7 @@ def preprocess(data_dir, url_path=None, use_langchain=True):
else:
docs = SimpleDirectoryReader(data_dir).load_data()
if url_path:
_, suffix = os.path.splitext(url_path)
assert suffix == ".json"
with open(url_path) as f:
url_dict = json.load(f)
url_dict = get_url(url_path)
for item in docs:
if "source" not in item.metadata:
item.metadata["source"] = item.metadata["file_path"]
@@ -134,33 +129,45 @@ def preprocess(data_dir, url_path=None, use_langchain=True):
return docs
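
For orientation, a hypothetical call to preprocess as it stands after this change; the directory and file names are placeholders, and only the parameters visible in the signature above are assumed:

# Load the corpus with langchain's DirectoryLoader and attach source URLs
# taken from the "<url> <path>" file consumed by get_url.
docs = preprocess("data/full_texts", url_path="urls.txt", use_langchain=True)
print(len(docs), docs[0].metadata)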


def get_abstract_data(path, use_langchain=True):
all_docs = []
with jsonlines.open(path) as reader:
for obj in reader:
if type(obj) is list:
for item in obj:
if "url" in item:
metadata = {"url": item["url"], "name": item["name"]}
else:
metadata = {"name": item["name"]}
if use_langchain:
doc = Document(page_content=item["page_content"], metadata=metadata)
else:
doc = TextNode(text=item["page_content"], metadata=metadata)
all_docs.append(doc)
elif type(obj) is dict:
if "url" in obj:
metadata = {"url": obj["url"], "name": obj["name"]}
else:
metadata = {"name": obj["name"]}
if use_langchain:
doc = Document(page_content=obj["page_content"], metadata=metadata)
else:
doc = TextNode(text=obj["page_content"], metadata=metadata)
all_docs.append(doc)
return all_docs
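
The jsonlines file read by get_abstract_data is expected to hold records shaped like the example below; the values are illustrative, and only the keys used in the code above ("page_content", "name", and the optional "url") are assumed:

# abstract.jsonl (hypothetical record):
# {"name": "paper_a", "url": "https://example.com/paper_a",
#  "page_content": "This paper studies ..."}

langchain_docs = get_abstract_data("abstract.jsonl", use_langchain=True)   # langchain Document objects
llama_nodes = get_abstract_data("abstract.jsonl", use_langchain=False)     # llama_index TextNode objects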


def build_index_langchain(
index_name, embeddings, path=None, url_path=None, abstract=False, origin_data=None, use_data=False
index_name, embeddings, path=None, url_path=None, abstract=False, origin_data=None
):
if os.path.exists(index_name):
db = FAISS.load_local(index_name, embeddings)
elif abstract and not use_data:
all_docs = []
with jsonlines.open(path) as reader:
for obj in reader:
if type(obj) is list:
for item in obj:
if "url" in item:
metadata = {"url": item["url"], "name": item["name"]}
else:
metadata = {"name": item["name"]}
doc = Document(page_content=item["page_content"], metadata=metadata)
all_docs.append(doc)
elif type(obj) is dict:
if "url" in obj:
metadata = {"url": obj["url"], "name": obj["name"]}
else:
metadata = {"name": obj["name"]}
doc = Document(page_content=obj["page_content"], metadata=metadata)
all_docs.append(doc)
elif abstract:
all_docs = get_abstract_data(path)
db = FAISS.from_documents(all_docs, embeddings)
db.save_local(index_name)
elif not abstract and not use_data:

elif not abstract and not origin_data:
documents = preprocess(path, url_path)
text_splitter = SpacyTextSplitter(pipeline="zh_core_web_sm", chunk_size=1500, chunk_overlap=0)
docs = text_splitter.split_documents(documents)
@@ -170,15 +177,13 @@ def build_index_langchain(
docs_tackle.append(item)
db = FAISS.from_documents(docs_tackle, embeddings)
db.save_local(index_name)
elif use_data:
elif origin_data:
db = FAISS.from_documents(origin_data, embeddings)
db.save_local(index_name)
return db
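
A hedged usage sketch for the reworked build_index_langchain; the index names, paths and embeddings object are placeholders, and only the keyword arguments shown in the signature above are assumed:

# Build (or load, if "index_full_text" already exists on disk) a FAISS index
# over the full-text corpus.
db = build_index_langchain(
    index_name="index_full_text",
    embeddings=embeddings,      # e.g. an ErnieEmbeddings or OpenAIEmbeddings instance
    path="data/full_texts",
    url_path="urls.txt",
)

# With abstract=True, `path` is read as a jsonlines abstract file instead.
abstract_db = build_index_langchain(
    index_name="index_abstract",
    embeddings=embeddings,
    path="abstract.jsonl",
    abstract=True,
)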


def build_index_llama(
index_name, embeddings, path=None, url_path=None, abstract=False, origin_data=None, use_data=False
):
def build_index_llama(index_name, embeddings, path=None, url_path=None, abstract=False, origin_data=None):
if embeddings.model == "text-embedding-ada-002":
d = 1536
elif embeddings.model == "ernie-text-embedding":
@@ -194,7 +199,7 @@ def build_index_llama(
service_context = ServiceContext.from_defaults(embed_model=embeddings)
index = load_index_from_storage(storage_context=storage_context, service_context=service_context)
return index
if not abstract and not use_data:
if not abstract and not origin_data:
documents = preprocess(path, url_path=url_path, use_langchain=False)
text_splitter = SentenceSplitter(chunk_size=1024, chunk_overlap=20)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
@@ -208,27 +213,7 @@ def build_index_llama(
index.storage_context.persist(persist_dir=index_name)
return index
elif abstract:
all_docs = []
from llama_index.schema import TextNode

with jsonlines.open(path) as reader:
for obj in reader:
if type(obj) is list:
for item in obj:
if "url" in item:
metadata = {"url": item["url"], "name": item["name"]}
else:
metadata = {"name": item["name"]}
doc = {"text": item["page_content"], "metadata": metadata}
all_docs.append(doc)
elif type(obj) is dict:
if "url" in obj:
metadata = {"url": obj["url"], "name": obj["name"]}
else:
metadata = {"name": obj["name"]}
doc = {"text": item["page_content"], "metadata": metadata}
all_docs.append(doc)
nodes = [TextNode(text=item["text"], metadata=item["metadata"]) for item in all_docs]
nodes = get_abstract_data(path, use_langchain=False)
text_splitter = SentenceSplitter(chunk_size=1024, chunk_overlap=20)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
service_context = ServiceContext.from_defaults(embed_model=embeddings, text_splitter=text_splitter)
@@ -297,41 +282,29 @@ def parse_arguments():
embeddings: Any = AzureOpenAIEmbeddings(azure_deployment="text-embedding-ada")
else:
embeddings = ErnieEmbeddings(aistudio_access_token=access_token)
_, suffix = os.path.splitext(args.url_path)
if suffix == ".txt":
url_path = args.url_path.replace(".txt", ".json")
with open(args.url_path, "r") as f:
data = f.read()
data_list = data.split("\n")
url_dict = {}
for item in data_list:
url, path = item.split(" ")
url_dict[path] = url
url_path
add_url(url_dict, path=url_path)
else:
url_path = args.url_path
if not args.path_abstract:
from erniebot_agent.chat_models import ERNIEBot

llm = ERNIEBot(model="ernie-4.0")
generate_abstract = GenerateAbstract(llm=llm)
abstract_path = asyncio.run(generate_abstract.run(args.path_full_text, url_path))
abstract_path = asyncio.run(generate_abstract.run(args.path_full_text, args.url_path))
else:
abstract_path = args.path_abstract

build_index_fuction, retrieval_tool = get_retriver_by_type(args.framework)

full_text_db = build_index_fuction(
index_name=args.index_name_full_text,
embeddings=embeddings,
path=args.path_full_text,
url_path=url_path,
url_path=args.url_path,
)
abstract_db = build_index_fuction(
index_name=args.index_name_abstract,
embeddings=embeddings,
path=abstract_path,
abstract=True,
url_path=url_path,
url_path=args.url_path,
)
retrieval_full = retrieval_tool(full_text_db)
retrieval_abstract = retrieval_tool(abstract_db)
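
A hypothetical invocation of the script after this change, assuming the argparse flags mirror the args.* attributes used above (the exact flag names and defaults live in parse_arguments(), which is not shown in this diff):

# python preprocessing.py \
#     --path_full_text data/full_texts \
#     --url_path urls.txt \
#     --index_name_full_text index_full_text \
#     --index_name_abstract index_abstract \
#     --framework langchain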
@@ -153,9 +153,7 @@ def add_citation(paragraphs, index_name, embeddings, build_index, SearchTool):
for item in paragraphs:
example = Document(page_content=item["summary"], metadata={"url": item["url"], "name": item["name"]})
list_data.append(example)
faiss_db = build_index(
index_name=index_name, use_data=True, embeddings=embeddings, origin_data=list_data
)
faiss_db = build_index(index_name=index_name, embeddings=embeddings, origin_data=list_data)
faiss_search = SearchTool(db=faiss_db)
return faiss_search
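
A short sketch of the call site this change simplifies; the paragraph dicts are illustrative, and only the keys used above ("summary", "url", "name") are assumed. Note that use_data is no longer passed; supplying origin_data alone now selects that branch of build_index:

paragraphs = [
    {"summary": "Key findings of paper A ...", "url": "https://example.com/a", "name": "paper_a"},
    {"summary": "Key findings of paper B ...", "url": "https://example.com/b", "name": "paper_b"},
]
# SearchTool stands for whatever retrieval tool class the caller already uses.
faiss_search = add_citation(paragraphs, "citation_index", embeddings, build_index_langchain, SearchTool)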

