Skip to content

Commit

Permalink
Merge pull request #444 from davidleon/fix/lazy_import
Browse files (browse the repository at this point in the history)
Fix/lazy import
  • Loading branch information
LarFii authored Dec 11, 2024
2 parents 65e0e67 + 288d4b8 commit 7fbd9aa
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 11 deletions.
19 changes: 13 additions & 6 deletions lightrag/lightrag.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,25 @@


def lazy_external_import(module_name: str, class_name: str):
"""Lazily import an external module and return a class from it."""
"""Lazily import a class from an external module based on the package of the caller."""

def import_class():
# Get the caller's module and package
import inspect

caller_frame = inspect.currentframe().f_back
module = inspect.getmodule(caller_frame)
package = module.__package__ if module else None

def import_class(*args, **kwargs):
import importlib

# Import the module using importlib
module = importlib.import_module(module_name)
module = importlib.import_module(module_name, package=package)

# Get the class from the module
return getattr(module, class_name)
# Get the class from the module and instantiate it
cls = getattr(module, class_name)
return cls(*args, **kwargs)

# Return the import_class function itself, not its result
return import_class


Expand Down
1 change: 1 addition & 0 deletions lightrag/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ async def openai_complete_if_cache(
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
)
kwargs.pop("hashing_kv", None)
kwargs.pop("keyword_extraction", None)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
Expand Down
14 changes: 10 additions & 4 deletions lightrag/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,16 @@ async def upsert(self, data: dict[str, dict]):
embeddings = await f
embeddings_list.append(embeddings)
embeddings = np.concatenate(embeddings_list)
for i, d in enumerate(list_data):
d["__vector__"] = embeddings[i]
results = self._client.upsert(datas=list_data)
return results
if len(embeddings) == len(list_data):
for i, d in enumerate(list_data):
d["__vector__"] = embeddings[i]
results = self._client.upsert(datas=list_data)
return results
else:
# sometimes the embedding is not returned correctly. just log it.
logger.error(
f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}"
)

async def query(self, query: str, top_k=5):
embedding = await self.embedding_func([query])
Expand Down
21 changes: 20 additions & 1 deletion lightrag/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,17 @@

from lightrag.prompt import PROMPTS


class UnlimitedSemaphore:
"""A no-op async context manager that allows unlimited access.

It exposes only the ``async with`` protocol and never blocks or counts
entries. ``EmbeddingFunc.__post_init__`` uses it in place of
``asyncio.Semaphore`` when ``concurrent_limit`` is 0.
"""

# Entering acquires nothing; ``async with ... as x`` binds ``x`` to None.
async def __aenter__(self):
pass

# Exiting releases nothing; the implicit None return means an exception
# raised inside the ``async with`` body is not suppressed.
async def __aexit__(self, exc_type, exc, tb):
pass


ENCODER = None

logger = logging.getLogger("lightrag")
Expand All @@ -42,9 +53,17 @@ class EmbeddingFunc:
embedding_dim: int
max_token_size: int
func: callable
concurrent_limit: int = 16

def __post_init__(self):
if self.concurrent_limit != 0:
self._semaphore = asyncio.Semaphore(self.concurrent_limit)
else:
self._semaphore = UnlimitedSemaphore()

async def __call__(self, *args, **kwargs) -> np.ndarray:
return await self.func(*args, **kwargs)
async with self._semaphore:
return await self.func(*args, **kwargs)


def locate_json_string_body_from_string(content: str) -> Union[str, None]:
Expand Down

0 comments on commit 7fbd9aa

Please sign in to comment.