From d0d8fd4bf7d0b3fa2e8b318d401429777b4000c8 Mon Sep 17 00:00:00 2001 From: Chirag Jain Date: Tue, 25 Jun 2024 05:06:16 +0530 Subject: [PATCH] Updates to handle model_* fields --- backend/__init__.py | 5 ++-- backend/indexer/indexer.py | 3 +-- backend/logger.py | 25 +++++++++++-------- .../query_controllers/example/controller.py | 2 +- .../query_controllers/example/types.py | 6 +++-- .../multimodal/controller.py | 2 +- .../query_controllers/multimodal/types.py | 6 +++-- backend/server/routers/collection.py | 2 +- backend/settings.py | 4 +-- backend/types.py | 9 +++++-- 10 files changed, 38 insertions(+), 26 deletions(-) diff --git a/backend/__init__.py b/backend/__init__.py index 2e14c3aa..419156a0 100644 --- a/backend/__init__.py +++ b/backend/__init__.py @@ -1,4 +1,3 @@ -from dotenv import load_dotenv +from backend.settings import settings -# load environment variables -load_dotenv() +settings.LOG_LEVEL diff --git a/backend/indexer/indexer.py b/backend/indexer/indexer.py index 12b9563c..d7bb5e54 100644 --- a/backend/indexer/indexer.py +++ b/backend/indexer/indexer.py @@ -148,7 +148,6 @@ async def _sync_data_source_to_collection( Returns: None """ - client = await get_client() failed_data_point_fqns = [] @@ -219,7 +218,7 @@ async def ingest_data_points( """ embeddings = model_gateway.get_embedder_from_model_config( - model_name=inputs.embedder_config.model_config.name + model_name=inputs.embedder_config.model_configuration.name ) documents_to_be_upserted = [] logger.info( diff --git a/backend/logger.py b/backend/logger.py index ebc3b601..d20a3f93 100644 --- a/backend/logger.py +++ b/backend/logger.py @@ -3,19 +3,24 @@ from backend.settings import settings -LOG_LEVEL = logging.getLevelName(settings.LOG_LEVEL.upper()) - logger = logging.getLogger(__name__) logging.getLogger("boto3").setLevel(logging.CRITICAL) logging.getLogger("botocore").setLevel(logging.CRITICAL) logging.getLogger("nose").setLevel(logging.CRITICAL) logging.getLogger("s3transfer").setLevel(logging.CRITICAL) logging.getLogger("urllib3").setLevel(logging.CRITICAL) -logger.setLevel(logging.DEBUG) -formatter = logging.Formatter( - "%(levelname)s: %(asctime)s - %(module)s:%(funcName)s:%(lineno)d - %(message)s" -) -handler = logging.StreamHandler(stream=sys.stdout) -handler.setLevel(LOG_LEVEL) -handler.setFormatter(formatter) -logger.addHandler(handler) + + +def setup_logging(level): + logger.setLevel(logging.DEBUG) + log_level = logging.getLevelName(level.upper()) + formatter = logging.Formatter( + "%(levelname)s: %(asctime)s - %(module)s:%(funcName)s:%(lineno)d - %(message)s" + ) + handler = logging.StreamHandler(stream=sys.stdout) + handler.setLevel(log_level) + handler.setFormatter(formatter) + logger.addHandler(handler) + + +setup_logging(level=settings.LOG_LEVEL) diff --git a/backend/modules/query_controllers/example/controller.py b/backend/modules/query_controllers/example/controller.py index 7077bfed..41200659 100644 --- a/backend/modules/query_controllers/example/controller.py +++ b/backend/modules/query_controllers/example/controller.py @@ -92,7 +92,7 @@ async def _get_vector_store(self, collection_name: str): return VECTOR_STORE_CLIENT.get_vector_store( collection_name=collection.name, embeddings=model_gateway.get_embedder_from_model_config( - model_name=collection.embedder_config.model_config.name + model_name=collection.embedder_config.model_configuration.name ), ) diff --git a/backend/modules/query_controllers/example/types.py b/backend/modules/query_controllers/example/types.py index d73c8b03..01bc8a48 100644 --- a/backend/modules/query_controllers/example/types.py +++ b/backend/modules/query_controllers/example/types.py @@ -1,6 +1,6 @@ from typing import Any, ClassVar, Optional, Sequence -from pydantic import BaseModel, Field, field_validator, model_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator from qdrant_client.models import Filter as QdrantFilter from backend.logger import logger @@ -123,6 +123,9 @@ class ExampleQueryInput(BaseModel): Requires a collection name, retriever configuration, query, LLM configuration and prompt template. """ + # TODO (chiragjn): This is not the best idea + model_config = ConfigDict(protected_namespaces=tuple()) + collection_name: str = Field( default=None, title="Collection name on which to search", @@ -130,7 +133,6 @@ class ExampleQueryInput(BaseModel): query: str = Field(title="Question to search for") - # TODO (chiragjn): Pydantic v2 does not like fields that begin with model_* model_configuration: ModelConfig prompt_template: str = Field( diff --git a/backend/modules/query_controllers/multimodal/controller.py b/backend/modules/query_controllers/multimodal/controller.py index b403c32e..fae66867 100644 --- a/backend/modules/query_controllers/multimodal/controller.py +++ b/backend/modules/query_controllers/multimodal/controller.py @@ -105,7 +105,7 @@ async def _get_vector_store(self, collection_name: str): return VECTOR_STORE_CLIENT.get_vector_store( collection_name=collection.name, embeddings=model_gateway.get_embedder_from_model_config( - model_name=collection.embedder_config.model_config.name + model_name=collection.embedder_config.model_configuration.name ), ) diff --git a/backend/modules/query_controllers/multimodal/types.py b/backend/modules/query_controllers/multimodal/types.py index d73c8b03..01bc8a48 100644 --- a/backend/modules/query_controllers/multimodal/types.py +++ b/backend/modules/query_controllers/multimodal/types.py @@ -1,6 +1,6 @@ from typing import Any, ClassVar, Optional, Sequence -from pydantic import BaseModel, Field, field_validator, model_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator from qdrant_client.models import Filter as QdrantFilter from backend.logger import logger @@ -123,6 +123,9 @@ class ExampleQueryInput(BaseModel): Requires a collection name, retriever configuration, query, LLM configuration and prompt template. """ + # TODO (chiragjn): This is not the best idea + model_config = ConfigDict(protected_namespaces=tuple()) + collection_name: str = Field( default=None, title="Collection name on which to search", @@ -130,7 +133,6 @@ class ExampleQueryInput(BaseModel): query: str = Field(title="Question to search for") - # TODO (chiragjn): Pydantic v2 does not like fields that begin with model_* model_configuration: ModelConfig prompt_template: str = Field( diff --git a/backend/server/routers/collection.py b/backend/server/routers/collection.py index 3cd8684a..91379662 100644 --- a/backend/server/routers/collection.py +++ b/backend/server/routers/collection.py @@ -80,7 +80,7 @@ async def create_collection(collection: CreateCollectionDto): VECTOR_STORE_CLIENT.create_collection( collection_name=collection.name, embeddings=model_gateway.get_embedder_from_model_config( - model_name=collection.embedder_config.model_config.name + model_name=collection.embedder_config.model_configuration.name ), ) logger.info(f"Created collection... {created_collection}") diff --git a/backend/settings.py b/backend/settings.py index cafc8d84..5cdc7d22 100644 --- a/backend/settings.py +++ b/backend/settings.py @@ -4,7 +4,6 @@ from pydantic import ConfigDict, model_validator from pydantic_settings import BaseSettings -from backend.logger import logger from backend.types import MetadataStoreConfig, VectorDBConfig @@ -57,7 +56,8 @@ def _validate_values(cls, values: Any) -> Any: tfy_llm_gateway_url = f"{tfy_host.rstrip('/')}/api/llm" values["TFY_LLM_GATEWAY_URL"] = tfy_llm_gateway_url else: - logger.warning( + # logger has not been initialized at this point, hence the print + print( f"[Validation Skipped] Pydantic v2 validator received " f"non dict values of type {type(values)}" ) diff --git a/backend/types.py b/backend/types.py index 97a228f4..f0fd37d1 100644 --- a/backend/types.py +++ b/backend/types.py @@ -123,8 +123,13 @@ class EmbedderConfig(BaseModel): Embedder configuration """ - # TODO (chiragjn): Pydantic v2 does not like fields that begin with model_* - model_config: ModelConfig + # TODO (chiragjn): This is not the best idea + model_config = ConfigDict(protected_namespaces=tuple()) + + # Pydantic v2 reserves model_config for itself + model_configuration: ModelConfig = Field( + validation_alias="model_config", serialization_alias="model_config" + ) config: Optional[dict[str, Any]] = Field( title="Configuration for the embedder", default_factory=dict )