Commit

fix: Address lint issues
Signed-off-by: Marcel Klehr <[email protected]>
marcelklehr committed Jul 25, 2024
1 parent fc88ffc commit 15041b4
Showing 6 changed files with 133 additions and 136 deletions.
66 changes: 33 additions & 33 deletions context_chat_backend/models/__init__.py
@@ -4,56 +4,56 @@
 from langchain.llms.base import LLM
 from langchain.schema.embeddings import Embeddings

-_embedding_models = ['llama', 'hugging_face', 'instructor']
-_llm_models = ['nc_texttotext', 'llama', 'hugging_face', 'ctransformer']
+_embedding_models = ["llama", "hugging_face", "instructor"]
+_llm_models = ["nc_texttotext", "llama", "hugging_face", "ctransformer"]

 models = {
-    'embedding': _embedding_models,
-    'llm': _llm_models,
+    "embedding": _embedding_models,
+    "llm": _llm_models,
 }

-__all__ = ['init_model', 'load_model', 'models', 'LlmException']
+__all__ = ["init_model", "load_model", "models", "LlmException"]


 def load_model(model_type: str, model_info: tuple[str, dict]) -> Embeddings | LLM | None:
-    model_name, model_config = model_info
+    model_name, model_config = model_info

-    try:
-        module = import_module(f'.{model_name}', 'context_chat_backend.models')
-    except Exception as e:
-        raise AssertionError(f'Error: could not load {model_name} model from context_chat_backend/models') from e
+    try:
+        module = import_module(f".{model_name}", "context_chat_backend.models")
+    except Exception as e:
+        raise AssertionError(f"Error: could not load {model_name} model from context_chat_backend/models") from e

-    if module is None or not hasattr(module, 'get_model_for'):
-        raise AssertionError(f'Error: could not load {model_name} model')
+    if module is None or not hasattr(module, "get_model_for"):
+        raise AssertionError(f"Error: could not load {model_name} model")

-    get_model_for = module.get_model_for
+    get_model_for = module.get_model_for

-    if not isinstance(get_model_for, Callable):
-        raise AssertionError(f'Error: {model_name} does not have a valid loader function')
+    if not isinstance(get_model_for, Callable):
+        raise AssertionError(f"Error: {model_name} does not have a valid loader function")

-    return get_model_for(model_type, model_config)
+    return get_model_for(model_type, model_config)


 def init_model(model_type: str, model_info: tuple[str, dict]):
-    '''
-    Initializes a given model. This function assumes that the model is implemented in a module with
-    the same name as the model in the models dir.
-    '''
-    model_name, _ = model_info
-    available_models = models.get(model_type, [])
+    """
+    Initializes a given model. This function assumes that the model is implemented in a module with
+    the same name as the model in the models dir.
+    """
+    model_name, _ = model_info
+    available_models = models.get(model_type, [])

-    if model_name not in available_models:
-        raise AssertionError(f'Error: {model_type}_model should be one of {available_models}')
+    if model_name not in available_models:
+        raise AssertionError(f"Error: {model_type}_model should be one of {available_models}")

-    try:
-        model = load_model(model_type, model_info)
-    except Exception as e:
-        raise AssertionError(f'Error: {model_name} failed to load') from e
+    try:
+        model = load_model(model_type, model_info)
+    except Exception as e:
+        raise AssertionError(f"Error: {model_name} failed to load") from e

-    if model_type == 'llm' and not isinstance(model, LLM):
-        raise AssertionError(f'Error: {model} does not implement "llm" type or has returned an invalid object')
+    if model_type == "llm" and not isinstance(model, LLM):
+        raise AssertionError(f'Error: {model} does not implement "llm" type or has returned an invalid object')

-    return model
+    return model

-class LlmException(Exception):
-    ...
+
+class LlmException(Exception): ...
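
For orientation, a minimal usage sketch of the loader pair changed above; the calls below are illustrative and not part of this commit:

    # Hypothetical example: init_model() dispatches on a (model_name, model_config) tuple
    # and imports context_chat_backend/models/<model_name>.py to call its get_model_for().
    from context_chat_backend.models import init_model

    llm = init_model("llm", ("nc_texttotext", {}))  # returns the CustomLLM defined in nc_texttotext.py
    # embedder = init_model("embedding", ("llama", {"model_path": "example.gguf"}))  # placeholder file name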
18 changes: 9 additions & 9 deletions context_chat_backend/models/ctransformer.py
@@ -4,16 +4,16 @@


 def get_model_for(model_type: str, model_config: dict):
-    model_dir = getenv('MODEL_DIR', 'persistent_storage/model_files')
-    if str(model_config.get('model')).startswith('/'):
-        model_dir = ''
+    model_dir = getenv("MODEL_DIR", "persistent_storage/model_files")
+    if str(model_config.get("model")).startswith("/"):
+        model_dir = ""

-    model_path = path.join(model_dir, model_config.get('model', ''))
+    model_path = path.join(model_dir, model_config.get("model", ""))

-    if model_config is None:
-        return None
+    if model_config is None:
+        return None

-    if model_type == 'llm':
-        return CTransformers(**{ **model_config, 'model': model_path })
+    if model_type == "llm":
+        return CTransformers(**{**model_config, "model": model_path})

-    return None
+    return None
28 changes: 14 additions & 14 deletions context_chat_backend/models/hugging_face.py
@@ -5,22 +5,22 @@


 def get_model_for(model_type: str, model_config: dict):
-    if model_config.get('model_path') is not None:
-        model_dir = getenv('MODEL_DIR', 'persistent_storage/model_files')
-        if str(model_config.get('model_path')).startswith('/'):
-            model_dir = ''
+    if model_config.get("model_path") is not None:
+        model_dir = getenv("MODEL_DIR", "persistent_storage/model_files")
+        if str(model_config.get("model_path")).startswith("/"):
+            model_dir = ""

-        model_path = path.join(model_dir, model_config.get('model_path', ''))
-    else:
-        model_path = model_config.get('model_id', '')
+        model_path = path.join(model_dir, model_config.get("model_path", ""))
+    else:
+        model_path = model_config.get("model_id", "")

-    if model_config is None:
-        return None
+    if model_config is None:
+        return None

-    if model_type == 'embedding':
-        return HuggingFaceEmbeddings(**model_config)
+    if model_type == "embedding":
+        return HuggingFaceEmbeddings(**model_config)

-    if model_type == 'llm':
-        return HuggingFacePipeline.from_model_id(**{ **model_config, 'model_id': model_path })
+    if model_type == "llm":
+        return HuggingFacePipeline.from_model_id(**{**model_config, "model_id": model_path})

-    return None
+    return None
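
As a rough illustration of the path handling in this loader (key names come from the code above, values are invented): a relative "model_path" is joined onto MODEL_DIR, an absolute path bypasses it, and a config without "model_path" falls back to the Hugging Face "model_id".

    # Illustrative only — resolving "model_path" the same way get_model_for() does.
    from os import getenv, path

    model_config = {"model_path": "all-MiniLM-L6-v2"}  # invented value
    model_dir = getenv("MODEL_DIR", "persistent_storage/model_files")
    if str(model_config.get("model_path")).startswith("/"):
        model_dir = ""  # absolute paths are used verbatim
    print(path.join(model_dir, model_config.get("model_path", "")))
    # -> persistent_storage/model_files/all-MiniLM-L6-v2 (unless MODEL_DIR is set)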
10 changes: 5 additions & 5 deletions context_chat_backend/models/instructor.py
@@ -2,10 +2,10 @@


 def get_model_for(model_type: str, model_config: dict):
-    if model_config is None:
-        return None
+    if model_config is None:
+        return None

-    if model_type == 'embedding':
-        return HuggingFaceInstructEmbeddings(**model_config)
+    if model_type == "embedding":
+        return HuggingFaceInstructEmbeddings(**model_config)

-    return None
+    return None
22 changes: 11 additions & 11 deletions context_chat_backend/models/llama.py
@@ -5,19 +5,19 @@


 def get_model_for(model_type: str, model_config: dict):
-    model_dir = getenv('MODEL_DIR', 'persistent_storage/model_files')
-    if str(model_config.get('model_path')).startswith('/'):
-        model_dir = ''
+    model_dir = getenv("MODEL_DIR", "persistent_storage/model_files")
+    if str(model_config.get("model_path")).startswith("/"):
+        model_dir = ""

-    model_path = path.join(model_dir, model_config.get('model_path', ''))
+    model_path = path.join(model_dir, model_config.get("model_path", ""))

-    if model_config is None:
-        return None
+    if model_config is None:
+        return None

-    if model_type == 'embedding':
-        return LlamaCppEmbeddings(**{ **model_config, 'model_path': model_path })
+    if model_type == "embedding":
+        return LlamaCppEmbeddings(**{**model_config, "model_path": model_path})

-    if model_type == 'llm':
-        return LlamaCpp(**{ **model_config, 'model_path': model_path })
+    if model_type == "llm":
+        return LlamaCpp(**{**model_config, "model_path": model_path})

-    return None
+    return None
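
For reference, a sketch of how this loader might be called for both model types; the GGUF file name and the extra n_ctx key are placeholders, not values from this commit:

    # Hypothetical calls — the same llama.cpp loader serves embeddings and text generation.
    from context_chat_backend.models.llama import get_model_for

    embedder = get_model_for("embedding", {"model_path": "example.gguf"})  # -> LlamaCppEmbeddings
    llm = get_model_for("llm", {"model_path": "example.gguf", "n_ctx": 4096})  # -> LlamaCpp
    # Relative paths resolve under MODEL_DIR (default: persistent_storage/model_files).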
125 changes: 61 additions & 64 deletions context_chat_backend/models/nc_texttotext.py
@@ -1,43 +1,42 @@
 import json
 import time
-from typing import Any, Dict, List, Optional
-from pydantic import BaseModel, ValidationError
+from typing import Any

-from nc_py_api import Nextcloud
 from langchain_core.callbacks.manager import CallbackManagerForLLMRun
 from langchain_core.language_models.llms import LLM
+from nc_py_api import Nextcloud
+from pydantic import BaseModel, ValidationError

 from context_chat_backend.models import LlmException


 def get_model_for(model_type: str, model_config: dict):
-    if model_config is None:
-        return None
+    if model_config is None:
+        return None

-    if model_type == 'llm':
-        return CustomLLM()
+    if model_type == "llm":
+        return CustomLLM()

-    return None
+    return None


 class Task(BaseModel):
-    id: int
-    status: str
-    output: Optional[Dict[str, str]] = None
+    id: int
+    status: str
+    output: dict[str, str] | None = None


 class CustomLLM(LLM):
-    """A custom chat model that queries Nextcloud's TextToText provider
-    """
-
-    def _call(
-        self,
-        prompt: str,
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> str:
-        """Run the LLM on the given input.
+    """A custom chat model that queries Nextcloud's TextToText provider"""
+
+    def _call(
+        self,
+        prompt: str,
+        stop: list[str] | None = None,
+        run_manager: CallbackManagerForLLMRun | None = None,
+        **kwargs: Any,
+    ) -> str:
+        """Run the LLM on the given input.

         Override this method to implement the LLM logic.
@@ -53,45 +52,43 @@ def _call(
         Returns:
             The model output as a string. Actual completions SHOULD NOT include the prompt.
         """
-        nc = Nextcloud()
-
-        print(json.dumps(prompt))
-
-        response = nc.ocs("POST", "/ocs/v1.php/taskprocessing/schedule", json={
-            "type": "core:text2text",
-            "appId": "context_chat_backend",
-            "input": {
-                "input": prompt
-            }
-        })
-
-        try:
-            task = Task.model_validate(response["task"])
-
-            while task.status != 'STATUS_SUCCESSFUL' and task.status != 'STATUS_FAILED':
-                time.sleep(5)
-                response = nc.ocs("GET", f"/ocs/v1.php/taskprocessing/task/{task.id}")
-                task = Task.model_validate(response["task"])
-        except ValidationError as e:
-            raise LlmException('Failed to parse Nextcloud TaskProcessing task result')
-
-        if task.status == 'STATUS_FAILED':
-            raise LlmException('Nextcloud TaskProcessing Task failed')
-
-        return task.output['output']
-
-    @property
-    def _identifying_params(self) -> Dict[str, Any]:
-        """Return a dictionary of identifying parameters."""
-        return {
-            # The model name allows users to specify custom token counting
-            # rules in LLM monitoring applications (e.g., in LangSmith users
-            # can provide per token pricing for their model and monitor
-            # costs for the given LLM.)
-            "model_name": "NextcloudTextToTextProvider",
-        }
-
-    @property
-    def _llm_type(self) -> str:
-        """Get the type of language model used by this chat model. Used for logging purposes only."""
-        return "nc_texttotetx"
+        nc = Nextcloud()
+
+        print(json.dumps(prompt))
+
+        response = nc.ocs(
+            "POST",
+            "/ocs/v1.php/taskprocessing/schedule",
+            json={"type": "core:text2text", "appId": "context_chat_backend", "input": {"input": prompt}},
+        )
+
+        try:
+            task = Task.model_validate(response["task"])
+
+            while task.status != "STATUS_SUCCESSFUL" and task.status != "STATUS_FAILED":
+                time.sleep(5)
+                response = nc.ocs("GET", f"/ocs/v1.php/taskprocessing/task/{task.id}")
+                task = Task.model_validate(response["task"])
+        except ValidationError as e:
+            raise LlmException("Failed to parse Nextcloud TaskProcessing task result") from e
+
+        if task.status == "STATUS_FAILED":
+            raise LlmException("Nextcloud TaskProcessing Task failed")
+
+        return task.output["output"]
+
+    @property
+    def _identifying_params(self) -> dict[str, Any]:
+        """Return a dictionary of identifying parameters."""
+        return {
+            # The model name allows users to specify custom token counting
+            # rules in LLM monitoring applications (e.g., in LangSmith users
+            # can provide per token pricing for their model and monitor
+            # costs for the given LLM.)
+            "model_name": "NextcloudTextToTextProvider",
+        }
+
+    @property
+    def _llm_type(self) -> str:
+        """Get the type of language model used by this chat model. Used for logging purposes only."""
+        return "nc_texttotetx"
