Skip to content

Commit

Permalink
Updates to handle model_* fields
Browse files Browse the repository at this point in the history
  • Loading branch information
chiragjn committed Jun 24, 2024
1 parent 5c2406c commit d0d8fd4
Show file tree
Hide file tree
Showing 10 changed files with 38 additions and 26 deletions.
5 changes: 2 additions & 3 deletions backend/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from dotenv import load_dotenv
from backend.settings import settings

# load environment variables
load_dotenv()
settings.LOG_LEVEL
3 changes: 1 addition & 2 deletions backend/indexer/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ async def _sync_data_source_to_collection(
Returns:
None
"""

client = await get_client()

failed_data_point_fqns = []
Expand Down Expand Up @@ -219,7 +218,7 @@ async def ingest_data_points(
"""
embeddings = model_gateway.get_embedder_from_model_config(
model_name=inputs.embedder_config.model_config.name
model_name=inputs.embedder_config.model_configuration.name
)
documents_to_be_upserted = []
logger.info(
Expand Down
25 changes: 15 additions & 10 deletions backend/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,24 @@

from backend.settings import settings

LOG_LEVEL = logging.getLevelName(settings.LOG_LEVEL.upper())

logger = logging.getLogger(__name__)
logging.getLogger("boto3").setLevel(logging.CRITICAL)
logging.getLogger("botocore").setLevel(logging.CRITICAL)
logging.getLogger("nose").setLevel(logging.CRITICAL)
logging.getLogger("s3transfer").setLevel(logging.CRITICAL)
logging.getLogger("urllib3").setLevel(logging.CRITICAL)
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter(
"%(levelname)s: %(asctime)s - %(module)s:%(funcName)s:%(lineno)d - %(message)s"
)
handler = logging.StreamHandler(stream=sys.stdout)
handler.setLevel(LOG_LEVEL)
handler.setFormatter(formatter)
logger.addHandler(handler)


def setup_logging(level):
logger.setLevel(logging.DEBUG)
log_level = logging.getLevelName(level.upper())
formatter = logging.Formatter(
"%(levelname)s: %(asctime)s - %(module)s:%(funcName)s:%(lineno)d - %(message)s"
)
handler = logging.StreamHandler(stream=sys.stdout)
handler.setLevel(log_level)
handler.setFormatter(formatter)
logger.addHandler(handler)


setup_logging(level=settings.LOG_LEVEL)
2 changes: 1 addition & 1 deletion backend/modules/query_controllers/example/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ async def _get_vector_store(self, collection_name: str):
return VECTOR_STORE_CLIENT.get_vector_store(
collection_name=collection.name,
embeddings=model_gateway.get_embedder_from_model_config(
model_name=collection.embedder_config.model_config.name
model_name=collection.embedder_config.model_configuration.name
),
)

Expand Down
6 changes: 4 additions & 2 deletions backend/modules/query_controllers/example/types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, ClassVar, Optional, Sequence

from pydantic import BaseModel, Field, field_validator, model_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from qdrant_client.models import Filter as QdrantFilter

from backend.logger import logger
Expand Down Expand Up @@ -123,14 +123,16 @@ class ExampleQueryInput(BaseModel):
Requires a collection name, retriever configuration, query, LLM configuration and prompt template.
"""

# TODO (chiragjn): This is not the best idea
model_config = ConfigDict(protected_namespaces=tuple())

collection_name: str = Field(
default=None,
title="Collection name on which to search",
)

query: str = Field(title="Question to search for")

# TODO (chiragjn): Pydantic v2 does not like fields that begin with model_*
model_configuration: ModelConfig

prompt_template: str = Field(
Expand Down
2 changes: 1 addition & 1 deletion backend/modules/query_controllers/multimodal/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ async def _get_vector_store(self, collection_name: str):
return VECTOR_STORE_CLIENT.get_vector_store(
collection_name=collection.name,
embeddings=model_gateway.get_embedder_from_model_config(
model_name=collection.embedder_config.model_config.name
model_name=collection.embedder_config.model_configuration.name
),
)

Expand Down
6 changes: 4 additions & 2 deletions backend/modules/query_controllers/multimodal/types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, ClassVar, Optional, Sequence

from pydantic import BaseModel, Field, field_validator, model_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from qdrant_client.models import Filter as QdrantFilter

from backend.logger import logger
Expand Down Expand Up @@ -123,14 +123,16 @@ class ExampleQueryInput(BaseModel):
Requires a collection name, retriever configuration, query, LLM configuration and prompt template.
"""

# TODO (chiragjn): This is not the best idea
model_config = ConfigDict(protected_namespaces=tuple())

collection_name: str = Field(
default=None,
title="Collection name on which to search",
)

query: str = Field(title="Question to search for")

# TODO (chiragjn): Pydantic v2 does not like fields that begin with model_*
model_configuration: ModelConfig

prompt_template: str = Field(
Expand Down
2 changes: 1 addition & 1 deletion backend/server/routers/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ async def create_collection(collection: CreateCollectionDto):
VECTOR_STORE_CLIENT.create_collection(
collection_name=collection.name,
embeddings=model_gateway.get_embedder_from_model_config(
model_name=collection.embedder_config.model_config.name
model_name=collection.embedder_config.model_configuration.name
),
)
logger.info(f"Created collection... {created_collection}")
Expand Down
4 changes: 2 additions & 2 deletions backend/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from pydantic import ConfigDict, model_validator
from pydantic_settings import BaseSettings

from backend.logger import logger
from backend.types import MetadataStoreConfig, VectorDBConfig


Expand Down Expand Up @@ -57,7 +56,8 @@ def _validate_values(cls, values: Any) -> Any:
tfy_llm_gateway_url = f"{tfy_host.rstrip('/')}/api/llm"
values["TFY_LLM_GATEWAY_URL"] = tfy_llm_gateway_url
else:
logger.warning(
# logger has not been initialized at this point, hence the print
print(
f"[Validation Skipped] Pydantic v2 validator received "
f"non dict values of type {type(values)}"
)
Expand Down
9 changes: 7 additions & 2 deletions backend/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,13 @@ class EmbedderConfig(BaseModel):
Embedder configuration
"""

# TODO (chiragjn): Pydantic v2 does not like fields that begin with model_*
model_config: ModelConfig
# TODO (chiragjn): This is not the best idea
model_config = ConfigDict(protected_namespaces=tuple())

# Pydantic v2 reserves model_config for itself
model_configuration: ModelConfig = Field(
validation_alias="model_config", serialization_alias="model_config"
)
config: Optional[dict[str, Any]] = Field(
title="Configuration for the embedder", default_factory=dict
)
Expand Down

0 comments on commit d0d8fd4

Please sign in to comment.