update llama_index
qingzhong1 committed Jan 24, 2024
1 parent c2284b3 commit 686b496
Showing 3 changed files with 75 additions and 6 deletions.
@@ -2,13 +2,22 @@
 import os
 from typing import Any, List

 import faiss
 import jsonlines
 from langchain.docstore.document import Document
 from langchain.text_splitter import SpacyTextSplitter
 from langchain.vectorstores import FAISS
 from langchain_community.document_loaders import DirectoryLoader
-from llama_index import SimpleDirectoryReader
+from llama_index import (
+    ServiceContext,
+    SimpleDirectoryReader,
+    StorageContext,
+    VectorStoreIndex,
+    load_index_from_storage,
+)
+from llama_index.node_parser import SentenceSplitter
+from llama_index.schema import TextNode
+from llama_index.vector_stores.faiss import FaissVectorStore

 from erniebot_agent.memory import HumanMessage, Message
 from erniebot_agent.prompt import PromptTemplate
@@ -175,8 +184,60 @@ def build_index_langchain(


 def build_index_llama(index_name, embeddings, path=None, url_path=None, abstract=False, origin_data=None):
-    # TODO: Adapt to llamaindex
-    pass
+    # The FAISS index dimensionality has to match the embedding model.
+    if embeddings.model == "text-embedding-ada-002":
+        d = 1536
+    elif embeddings.model == "ernie-text-embedding":
+        d = 384
+    else:
+        raise ValueError(f"model {embeddings.model} is not supported")
+
+    faiss_index = faiss.IndexFlatIP(d)
+    vector_store = FaissVectorStore(faiss_index=faiss_index)
+    # Reload a previously persisted index instead of rebuilding it.
+    if os.path.exists(index_name):
+        vector_store = FaissVectorStore.from_persist_dir(persist_dir=index_name)
+        storage_context = StorageContext.from_defaults(vector_store=vector_store, persist_dir=index_name)
+        service_context = ServiceContext.from_defaults(embed_model=embeddings)
+        index = load_index_from_storage(storage_context=storage_context, service_context=service_context)
+        return index
+    # Full-text branch: split raw documents loaded from `path`/`url_path`.
+    if not abstract and not origin_data:
+        documents = preprocess(path, url_path=url_path, use_langchain=False)
+        text_splitter = SentenceSplitter(chunk_size=1024, chunk_overlap=20)
+        storage_context = StorageContext.from_defaults(vector_store=vector_store)
+        service_context = ServiceContext.from_defaults(embed_model=embeddings, text_splitter=text_splitter)
+        index = VectorStoreIndex.from_documents(
+            documents,
+            storage_context=storage_context,
+            show_progress=True,
+            service_context=service_context,
+        )
+        index.storage_context.persist(persist_dir=index_name)
+        return index
+    # Abstract branch: index pre-extracted abstracts as nodes.
+    elif abstract:
+        nodes = get_abstract_data(path, use_langchain=False)
+        text_splitter = SentenceSplitter(chunk_size=1024, chunk_overlap=20)
+        storage_context = StorageContext.from_defaults(vector_store=vector_store)
+        service_context = ServiceContext.from_defaults(embed_model=embeddings, text_splitter=text_splitter)
+        index = VectorStoreIndex(
+            nodes,
+            storage_context=storage_context,
+            show_progress=True,
+            service_context=service_context,
+        )
+        index.storage_context.persist(persist_dir=index_name)
+        return index
+    # Origin-data branch: wrap already-loaded langchain Documents as TextNodes.
+    elif origin_data:
+        nodes = [TextNode(text=item.page_content, metadata=item.metadata) for item in origin_data]
+        text_splitter = SentenceSplitter(chunk_size=1024, chunk_overlap=20)
+        storage_context = StorageContext.from_defaults(vector_store=vector_store)
+        service_context = ServiceContext.from_defaults(embed_model=embeddings, text_splitter=text_splitter)
+        index = VectorStoreIndex(
+            nodes,
+            storage_context=storage_context,
+            show_progress=True,
+            service_context=service_context,
+        )
+        index.storage_context.persist(persist_dir=index_name)
+        return index


 def get_retriver_by_type(frame_type):
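
For orientation, a minimal sketch of how the new function might be driven. This is not part of the commit: the embeddings object, paths, and index name below are placeholder assumptions; the only requirement visible in the diff is that embeddings.model be "text-embedding-ada-002" or "ernie-text-embedding".

index = build_index_llama(
    index_name="abstract_faiss_index",  # persist dir; reloaded on later calls if it exists
    embeddings=embeddings,              # placeholder: any embed model exposing a .model attribute
    path="./papers",                    # placeholder source directory for get_abstract_data()
    abstract=True,                      # routes through the abstract branch above
)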
2 changes: 1 addition & 1 deletion erniebot-agent/applications/erniebot_researcher/ui.py
@@ -223,7 +223,7 @@ def generate_report(query, history=[]):
     agent_sets = get_agents(
         retriever_sets, tool_sets, llm, llm_long, dir_path, target_path, build_index_function, retrieval_tool
     )
-    team_actor = ResearchTeam(**agent_sets, use_reflection=False)
+    team_actor = ResearchTeam(**agent_sets, use_reflection=True)
     report, path = asyncio.run(team_actor.run(query, args.iterations))
     return report, path

@@ -42,5 +42,13 @@ def __init__(
         self.threshold = threshold

     async def __call__(self, query: str, top_k: int = 3, filters: Optional[Dict[str, Any]] = None):
-        # TODO: Adapt to llamaindex
-        pass
+        retriever = self.db.as_retriever(similarity_top_k=top_k)
+        nodes = retriever.retrieve(query)
+        docs = []
+        for doc in nodes:
+            # Keep only hits that clear the similarity threshold.
+            if doc.score > self.threshold:
+                new_doc = {"content": doc.node.text, "score": doc.score}
+                if self.return_meta_data:
+                    new_doc["meta"] = doc.metadata
+                docs.append(new_doc)
+        return {"documents": docs}
