Change entity and relationship extraction into a two-step process: first extract entities, then extract relationships based on those entities. #401

Status: Open. Wants to merge 17 commits into base branch main.
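The PR's core change, per the title, is to split entity and relationship extraction into two LLM passes, gated by the new entity_relationship_extraction_step flag added below; the prompt-side changes live in lightrag/operate.py and are not included in this excerpt. As a rough, hypothetical sketch of what such a two-step flow can look like (the function name, prompts, and JSON shapes are illustrative assumptions, not the PR's code):

```python
import json
from typing import Awaitable, Callable

async def extract_entities_then_relationships(
    chunk_text: str,
    llm_func: Callable[[str], Awaitable[str]],
) -> tuple[list[dict], list[dict]]:
    """Two-pass extraction: entities first, then relationships over those entities."""
    # Pass 1: ask the model only for the entities in this chunk.
    entity_prompt = (
        "List the named entities in the text below as a JSON array of objects "
        'with "name" and "type" fields.\n\nText:\n' + chunk_text
    )
    entities = json.loads(await llm_func(entity_prompt))

    # Pass 2: ask for relationships, constrained to the entities found in pass 1.
    names = [e["name"] for e in entities]
    relation_prompt = (
        "Given the text and this entity list, return a JSON array of objects "
        'with "source", "target" and "description" fields describing how the '
        f"entities relate.\n\nEntities: {names}\n\nText:\n{chunk_text}"
    )
    relationships = json.loads(await llm_func(relation_prompt))
    return entities, relationships
```

Compared with single-pass extraction, the second prompt can reference the confirmed entity list, which tends to keep relationships anchored to entities the model actually extracted.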
6 changes: 3 additions & 3 deletions lightrag/kg/tidb_impl.py
@@ -183,7 +183,7 @@ async def upsert(self, data: dict[str, dict]):
                     "tokens": item["tokens"],
                     "chunk_order_index": item["chunk_order_index"],
                     "full_doc_id": item["full_doc_id"],
-                    "content_vector": f"{item["__vector__"].tolist()}",
+                    "content_vector": f"{item['__vector__'].tolist()}",
                     "workspace": self.db.workspace,
                 }
             )
@@ -286,7 +286,7 @@ async def upsert(self, data: dict[str, dict]):
                 "id": item["id"],
                 "name": item["entity_name"],
                 "content": item["content"],
-                "content_vector": f"{item["content_vector"].tolist()}",
+                "content_vector": f"{item['content_vector'].tolist()}",
                 "workspace": self.db.workspace,
             }
             # update entity_id if node inserted by graph_storage_instance before
@@ -308,7 +308,7 @@ async def upsert(self, data: dict[str, dict]):
                 "source_name": item["src_id"],
                 "target_name": item["tgt_id"],
                 "content": item["content"],
-                "content_vector": f"{item["content_vector"].tolist()}",
+                "content_vector": f"{item['content_vector'].tolist()}",
                 "workspace": self.db.workspace,
             }
             # update relation_id if node inserted by graph_storage_instance before
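All three hunks in this file are the same one-character fix: the f-string expression reused double quotes, which is a SyntaxError on Python versions before 3.12. A minimal standalone illustration (the dict literal is made up for the example):

```python
import numpy as np

item = {"content_vector": np.array([0.1, 0.2])}

# SyntaxError before Python 3.12: the inner double quote ends the f-string early.
# value = f"{item["content_vector"].tolist()}"

# Portable across supported Python versions: use a different quote style inside.
value = f"{item['content_vector'].tolist()}"
print(value)  # [0.1, 0.2]
```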
79 changes: 65 additions & 14 deletions lightrag/lightrag.py
@@ -1,5 +1,7 @@
 import asyncio
 import os
+
+from lightrag.operate import chunking_by_markdown_header
 from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
@@ -12,6 +14,8 @@
 )
 from .operate import (
     chunking_by_token_size,
+    chunking_by_markdown_header,
+    chunking_by_markdown_text,
     extract_entities,
     # local_query,global_query,hybrid_query,
     kg_query,
@@ -46,6 +50,13 @@
 # GraphStorage as ArangoDBStorage
 # )

+# Path issues prevent dynamic imports, so they are not used here. bumaple 2024-12-10
+# from lightrag.kg.mongo_impl import MongoKVStorage
+# from lightrag.kg.neo4j_impl import Neo4JStorage
+# from lightrag.kg.milvus_impl import MilvusVectorDBStorge
+# from lightrag.kg.oracle_impl import OracleKVStorage, OracleGraphStorage, OracleVectorDBStorage
+# from lightrag.kg.chroma_impl import ChromaVectorDBStorage
+

 def lazy_external_import(module_name: str, class_name: str):
     """Lazily import a class from an external module based on the package of the caller."""
@@ -172,10 +183,20 @@ class LightRAG:
     addon_params: dict = field(default_factory=dict)
     convert_response_to_json_func: callable = convert_response_to_json

+    # Custom addition: primary entity serial number and title, by bumaple 2024-12-03
+    extend_entity_title: str = ''
+    extend_entity_sn: str = ''
+    # Custom addition: chunk type, by bumaple 2024-12-11
+    chunk_type: str = 'token_size'
+    # Custom addition: chunk header level, by bumaple 2024-12-11
+    chunk_header_level: int = 2
+    # Extract entities and relationships in separate steps. True: two-step extraction; False: combined extraction
+    entity_relationship_extraction_step: bool = False
+
     def __post_init__(self):
-        log_file = os.path.join("lightrag.log")
-        set_logger(log_file)
-        logger.setLevel(self.log_level)
+        log_file = os.path.join(self.working_dir, "lightrag.log")
+        set_logger(log_file, self.log_level)
+        # logger.setLevel(self.log_level)

         logger.info(f"Logger initialized for working directory: {self.working_dir}")

@@ -315,18 +336,48 @@ async def ainsert(self, string_or_strings):
             for doc_key, doc in tqdm_async(
                 new_docs.items(), desc="Chunking documents", unit="doc"
             ):
-                chunks = {
-                    compute_mdhash_id(dp["content"], prefix="chunk-"): {
-                        **dp,
-                        "full_doc_id": doc_key,
-                    }
-                    for dp in chunking_by_token_size(
-                        doc["content"],
-                        overlap_token_size=self.chunk_overlap_token_size,
-                        max_token_size=self.chunk_token_size,
-                        tiktoken_model=self.tiktoken_model_name,
-                    )
-                }
+                if self.chunk_type == "markdown_header":
+                    chunks = {
+                        compute_mdhash_id(dp["content"], prefix="chunk-"): {
+                            **dp,
+                            "full_doc_id": doc_key,
+                        }
+                        for dp in chunking_by_markdown_header(
+                            doc["content"],
+                            overlap_token_size=self.chunk_overlap_token_size,
+                            max_token_size=self.chunk_token_size,
+                            extend_entity_title=self.extend_entity_title,
+                            extend_entity_sn=self.extend_entity_sn,
+                            chunk_header_level=self.chunk_header_level,
+                        )
+                    }
+                elif self.chunk_type == "markdown_text":
+                    chunks = {
+                        compute_mdhash_id(dp["content"], prefix="chunk-"): {
+                            **dp,
+                            "full_doc_id": doc_key,
+                        }
+                        for dp in chunking_by_markdown_text(
+                            doc["content"],
+                            overlap_token_size=self.chunk_overlap_token_size,
+                            max_token_size=self.chunk_token_size,
+                            extend_entity_title=self.extend_entity_title,
+                            extend_entity_sn=self.extend_entity_sn,
+                        )
+                    }
+                else:
+                    chunks = {
+                        compute_mdhash_id(dp["content"], prefix="chunk-"): {
+                            **dp,
+                            "full_doc_id": doc_key,
+                        }
+                        for dp in chunking_by_token_size(
+                            doc["content"],
+                            overlap_token_size=self.chunk_overlap_token_size,
+                            max_token_size=self.chunk_token_size,
+                            tiktoken_model=self.tiktoken_model_name,
+                        )
+                    }
                 inserting_chunks.update(chunks)
                 _add_chunk_keys = await self.text_chunks.filter_keys(
                     list(inserting_chunks.keys())
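chunking_by_markdown_header and chunking_by_markdown_text come from this PR's changes to lightrag/operate.py, which are not shown in this excerpt. To give a rough idea of header-based chunking, here is a standalone sketch; the function name, return shape, and the absence of token counting are assumptions rather than the PR's implementation:

```python
import re

def split_by_markdown_header(text: str, header_level: int = 2) -> list[dict]:
    """Split markdown into sections at headings of the given level or shallower."""
    # Match heading lines such as "# Title" or "## Section" with up to header_level '#'s.
    pattern = re.compile(rf"^#{{1,{header_level}}}\s+.*$", re.MULTILINE)
    starts = [m.start() for m in pattern.finditer(text)]
    boundaries = [0] + starts + [len(text)]
    sections = []
    for start, end in zip(boundaries, boundaries[1:]):
        content = text[start:end].strip()
        if content:
            sections.append(
                {"content": content, "chunk_order_index": len(sections)}
            )
    return sections

# Example: three sections when splitting at level-2 headings.
doc = "# Guide\nIntro.\n\n## Setup\nSteps.\n\n## Usage\nMore steps."
print([s["content"].splitlines()[0] for s in split_by_markdown_header(doc)])
# ['# Guide', '## Setup', '## Usage']
```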
3 changes: 2 additions & 1 deletion lightrag/llm.py
@@ -855,6 +855,7 @@ async def openai_embedding(
     model: str = "text-embedding-3-small",
     base_url: str = None,
     api_key: str = None,
+    timeout: float = 60,
 ) -> np.ndarray:
     if api_key:
         os.environ["OPENAI_API_KEY"] = api_key
@@ -863,7 +864,7 @@ async def openai_embedding(
         AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
     )
     response = await openai_async_client.embeddings.create(
-        model=model, input=texts, encoding_format="float"
+        model=model, input=texts, encoding_format="float", timeout=timeout
     )
     return np.array([dp.embedding for dp in response.data])
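A quick usage sketch for the new timeout parameter, assuming OPENAI_API_KEY is set in the environment; 60 seconds remains the default when the argument is omitted:

```python
import asyncio

from lightrag.llm import openai_embedding

async def main() -> None:
    vectors = await openai_embedding(
        ["hello world", "knowledge graphs"],
        model="text-embedding-3-small",
        timeout=30,  # fail faster than the 60-second default on a slow endpoint
    )
    print(vectors.shape)  # e.g. (2, 1536) for text-embedding-3-small

asyncio.run(main())
```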
