From 7763403ebfa723c8d6ad2d8cef733f46aef90452 Mon Sep 17 00:00:00 2001 From: Dario Lopez Padial Date: Fri, 29 Sep 2023 19:51:57 +0200 Subject: [PATCH] bump pydantic to v2 --- src/email/send_email.py | 6 +- src/etls/etl_common.py | 10 +-- src/etls/etl_daily.py | 5 +- src/etls/etl_initial.py | 13 ++-- src/etls/scrapper/base.py | 11 ++-- src/etls/scrapper/boe.py | 128 +++++++++++++++++++++----------------- src/etls/utils.py | 38 ++++++----- src/initialize.py | 75 ++++++++++++---------- src/service/main.py | 10 +-- src/utils.py | 8 +-- 10 files changed, 163 insertions(+), 141 deletions(-) diff --git a/src/email/send_email.py b/src/email/send_email.py index d0408e6..9cda21f 100644 --- a/src/email/send_email.py +++ b/src/email/send_email.py @@ -12,9 +12,9 @@ def send_email(config_loader, subject: str, content: str) -> None: logger = lg.getLogger(send_email.__name__) logger.info("Sending email") - sg = SendGridAPIClient(api_key=os.environ.get('SENDGRID_API_KEY')) - from_email = Email(config_loader['admin_email']) - to_email = To(config_loader['admin_email']) + sg = SendGridAPIClient(api_key=os.environ.get("SENDGRID_API_KEY")) + from_email = Email(config_loader["admin_email"]) + to_email = To(config_loader["admin_email"]) content = Content("text/plain", content) mail = Mail(from_email, to_email, subject, content) response = sg.client.mail.send.post(request_body=mail.get()) diff --git a/src/etls/etl_common.py b/src/etls/etl_common.py index 4607f5b..e0010ed 100644 --- a/src/etls/etl_common.py +++ b/src/etls/etl_common.py @@ -36,13 +36,13 @@ def _split_documents(self, docs: tp.List[BOEMetadataDocument]) -> tp.List[Docume loader = BOETextLoader(file_path=doc.filepath, metadata=doc.dict()) documents = loader.load() text_splitter = CharacterTextSplitter( - separator=self._config_loader['separator'], - chunk_size=self._config_loader['chunk_size'], - chunk_overlap=self._config_loader['chunk_overlap'] + separator=self._config_loader["separator"], + chunk_size=self._config_loader["chunk_size"], + chunk_overlap=self._config_loader["chunk_overlap"], ) docs_chunks += text_splitter.split_documents(documents) if doc: - logger.info('Removing file %s', doc.filepath) + logger.info("Removing file %s", doc.filepath) os.remove(doc.filepath) logger.info("Splitted %s documents in %s chunks", len(docs), len(docs_chunks)) return docs_chunks @@ -56,7 +56,7 @@ def _load_database(self, docs_chunks: tp.List[Document]) -> None: def _log_database_stats(self) -> None: logger = lg.getLogger(self._log_database_stats.__name__) - index_name = self._config_loader['vector_store_index_name'] + index_name = self._config_loader["vector_store_index_name"] logger.info(pinecone.describe_index(index_name)) index = pinecone.Index(index_name) logger.info(index.describe_index_stats()) diff --git a/src/etls/etl_daily.py b/src/etls/etl_daily.py index 3eabfcf..20b183d 100644 --- a/src/etls/etl_daily.py +++ b/src/etls/etl_daily.py @@ -5,11 +5,10 @@ from src.etls.scrapper.boe import BOEScrapper from src.initialize import initialize_app -if __name__ == '__main__': +if __name__ == "__main__": INIT_OBJECTS = initialize_app() etl_job = ETL( - config_loader=INIT_OBJECTS.config_loader, - vector_store=INIT_OBJECTS.vector_store + config_loader=INIT_OBJECTS.config_loader, vector_store=INIT_OBJECTS.vector_store ) boe_scrapper = BOEScrapper() day = date.today() diff --git a/src/etls/etl_initial.py b/src/etls/etl_initial.py index 56915c8..2fb3a34 100644 --- a/src/etls/etl_initial.py +++ b/src/etls/etl_initial.py @@ -5,16 +5,19 @@ from src.etls.scrapper.boe 
import BOEScrapper from src.initialize import initialize_app -if __name__ == '__main__': +if __name__ == "__main__": INIT_OBJECTS = initialize_app() etl_job = ETL( - config_loader=INIT_OBJECTS.config_loader, - vector_store=INIT_OBJECTS.vector_store + config_loader=INIT_OBJECTS.config_loader, vector_store=INIT_OBJECTS.vector_store ) boe_scrapper = BOEScrapper() docs = boe_scrapper.download_days( - date_start=datetime.strptime(INIT_OBJECTS.config_loader['date_start'], '%Y/%m/%d').date(), - date_end=datetime.strptime(INIT_OBJECTS.config_loader['date_end'], '%Y/%m/%d').date(), + date_start=datetime.strptime( + INIT_OBJECTS.config_loader["date_start"], "%Y/%m/%d" + ).date(), + date_end=datetime.strptime( + INIT_OBJECTS.config_loader["date_end"], "%Y/%m/%d" + ).date(), ) if docs: etl_job.run(docs) diff --git a/src/etls/scrapper/base.py b/src/etls/scrapper/base.py index b82e1e1..f7deeec 100644 --- a/src/etls/scrapper/base.py +++ b/src/etls/scrapper/base.py @@ -6,17 +6,16 @@ class BaseScrapper(ABC): - @abstractmethod - def download_days(self, date_start: date, date_end: date) -> tp.List[BOEMetadataDocument]: - """Download all the documents between two dates (from date_start to date_end) - """ + def download_days( + self, date_start: date, date_end: date + ) -> tp.List[BOEMetadataDocument]: + """Download all the documents between two dates (from date_start to date_end)""" pass @abstractmethod def download_day(self, day: date) -> tp.List[BOEMetadataDocument]: - """Download all the documents for a specific date. - """ + """Download all the documents for a specific date.""" pass @abstractmethod diff --git a/src/etls/scrapper/boe.py b/src/etls/scrapper/boe.py index 58ca521..d6c2dfc 100644 --- a/src/etls/scrapper/boe.py +++ b/src/etls/scrapper/boe.py @@ -10,8 +10,11 @@ from requests.exceptions import HTTPError from src.etls.scrapper.base import BaseScrapper -from src.etls.utils import (BOEMetadataDocument, BOEMetadataDocument2, - BOEMetadataReferencia) +from src.etls.utils import ( + BOEMetadataDocument, + BOEMetadataDocument2, + BOEMetadataReferencia, +) from src.initialize import initialize_logging initialize_logging() @@ -23,63 +26,79 @@ def _extract_metadata(soup) -> tp.Dict: # Metadatos identificador = soup.documento.metadatos.identificador if identificador: - metadata_dict['identificador'] = identificador.get_text() + metadata_dict["identificador"] = identificador.get_text() if numero_oficial := soup.documento.metadatos.numero_oficial: - metadata_dict['numero_oficial'] = numero_oficial.get_text() + metadata_dict["numero_oficial"] = numero_oficial.get_text() if departamento := soup.documento.metadatos.departamento: - metadata_dict['departamento'] = departamento.get_text() + metadata_dict["departamento"] = departamento.get_text() if rango := soup.documento.metadatos.rango: - metadata_dict['rango'] = rango.get_text() + metadata_dict["rango"] = rango.get_text() if titulo := soup.documento.metadatos.titulo: - metadata_dict['titulo'] = titulo.get_text() + metadata_dict["titulo"] = titulo.get_text() if url_pdf := soup.documento.metadatos.url_pdf: - metadata_dict['url_pdf'] = url_pdf.get_text() + metadata_dict["url_pdf"] = url_pdf.get_text() if origen_legislativo := soup.documento.metadatos.origen_legislativo: - metadata_dict['origen_legislativo'] = origen_legislativo.get_text() + metadata_dict["origen_legislativo"] = origen_legislativo.get_text() if fecha_publicacion := soup.documento.metadatos.fecha_publicacion: - metadata_dict['fecha_publicacion'] = fecha_publicacion.get_text() + 
metadata_dict["fecha_publicacion"] = fecha_publicacion.get_text() if fecha_disposicion := soup.documento.metadatos.fecha_disposicion: - metadata_dict['fecha_disposicion'] = fecha_disposicion.get_text() + metadata_dict["fecha_disposicion"] = fecha_disposicion.get_text() - metadata_dict['anio'] = datetime.strptime(fecha_publicacion.get_text(), '%Y%m%d').strftime('%Y') + metadata_dict["anio"] = datetime.strptime( + fecha_publicacion.get_text(), "%Y%m%d" + ).strftime("%Y") # Analisis if observaciones := soup.documento.analisis.observaciones: - metadata_dict['observaciones'] = observaciones.get_text() + metadata_dict["observaciones"] = observaciones.get_text() if ambito_geografico := soup.documento.analisis.ambito_geografico: - metadata_dict['ambito_geografico'] = ambito_geografico.get_text() + metadata_dict["ambito_geografico"] = ambito_geografico.get_text() if modalidad := soup.documento.analisis.modalidad: - metadata_dict['modalidad'] = modalidad.get_text() + metadata_dict["modalidad"] = modalidad.get_text() if tipo := soup.documento.analisis.tipo: - metadata_dict['tipo'] = tipo.get_text() + metadata_dict["tipo"] = tipo.get_text() - metadata_dict['materias'] = [ - materia.get_text() for materia in soup.select('documento > analisis > materias > materia') + metadata_dict["materias"] = [ + materia.get_text() + for materia in soup.select("documento > analisis > materias > materia") ] - metadata_dict['alertas'] = [ - alerta.get_text() for alerta in soup.select('documento > analisis > alertas > alerta') + metadata_dict["alertas"] = [ + alerta.get_text() + for alerta in soup.select("documento > analisis > alertas > alerta") ] - metadata_dict['notas'] = [ - nota.get_text() for nota in soup.select('documento > analisis > notas > nota') + metadata_dict["notas"] = [ + nota.get_text() for nota in soup.select("documento > analisis > notas > nota") ] - metadata_dict['ref_posteriores'] = [ - BOEMetadataReferencia(id=ref['referencia'], palabra=ref.palabra.get_text(), texto=ref.texto.get_text()) - for ref in soup.select('documento > analisis > referencias > posteriores > posterior') + metadata_dict["ref_posteriores"] = [ + BOEMetadataReferencia( + id=ref["referencia"], + palabra=ref.palabra.get_text(), + texto=ref.texto.get_text(), + ) + for ref in soup.select( + "documento > analisis > referencias > posteriores > posterior" + ) ] - metadata_dict['ref_anteriores'] = [ - BOEMetadataReferencia(id=ref['referencia'], palabra=ref.palabra.get_text(), texto=ref.texto.get_text()) - for ref in soup.select('documento > analisis > referencias > anteriores > anterior') + metadata_dict["ref_anteriores"] = [ + BOEMetadataReferencia( + id=ref["referencia"], + palabra=ref.palabra.get_text(), + texto=ref.texto.get_text(), + ) + for ref in soup.select( + "documento > analisis > referencias > anteriores > anterior" + ) ] return metadata_dict @@ -94,25 +113,25 @@ def _list_links_day(url: str) -> tp.List[str]: logger.info("Scrapping day: %s", url) response = requests.get(url) response.raise_for_status() - soup = BeautifulSoup(response.text, 'lxml') + soup = BeautifulSoup(response.text, "lxml") id_links = [ - url.text.split('?id=')[-1] + url.text.split("?id=")[-1] for section in soup.find_all( - lambda tag: tag.name == "seccion" and 'num' in tag.attrs and ( - tag.attrs['num'] == '1' or tag.attrs['num'] == 'T' - ) + lambda tag: tag.name == "seccion" + and "num" in tag.attrs + and (tag.attrs["num"] == "1" or tag.attrs["num"] == "T") ) - for url in section.find_all('urlxml') + for url in section.find_all("urlxml") ] 
logger.info("Scrapped day successfully %s (%s BOE documents)", url, len(id_links)) return id_links class BOEScrapper(BaseScrapper): - - def download_days(self, date_start: date, date_end: date) -> tp.List[BOEMetadataDocument]: - """Download all the documents between two dates (from date_start to date_end) - """ + def download_days( + self, date_start: date, date_end: date + ) -> tp.List[BOEMetadataDocument]: + """Download all the documents between two dates (from date_start to date_end)""" logger = lg.getLogger(self.download_days.__name__) logger.info("Downloading BOE content from day %s to %s", date_start, date_end) delta = timedelta(days=1) @@ -126,8 +145,7 @@ def download_days(self, date_start: date, date_end: date) -> tp.List[BOEMetadata return metadata_documents def download_day(self, day: date) -> tp.List[BOEMetadataDocument]: - """Download all the documents for a specific date. - """ + """Download all the documents for a specific date.""" logger = lg.getLogger(self.download_day.__name__) logger.info("Downloading BOE content for day %s", day) day_str = day.strftime("%Y%m%d") @@ -141,7 +159,9 @@ def download_day(self, day: date) -> tp.List[BOEMetadataDocument]: metadata_doc = self.download_document(url_document) metadata_documents.append(metadata_doc) except HTTPError: - logger.error("Not scrapped document %s on day %s", url_document, day_url) + logger.error( + "Not scrapped document %s on day %s", url_document, day_url + ) except HTTPError: logger.error("Not scrapped document on day %s", day_url) logger.info("Downloaded BOE content for day %s", day) @@ -159,14 +179,11 @@ def download_document(self, url: str) -> BOEMetadataDocument: logger.info("Scrapping document: %s", url) response = requests.get(url) response.raise_for_status() - soup = BeautifulSoup(response.text, 'lxml') - with tempfile.NamedTemporaryFile('w', delete=False) as fn: + soup = BeautifulSoup(response.text, "lxml") + with tempfile.NamedTemporaryFile("w", delete=False) as fn: text = soup.select_one("documento > texto").get_text() fn.write(text) - metadata_doc = BOEMetadataDocument( - filepath=fn.name, - **_extract_metadata(soup) - ) + metadata_doc = BOEMetadataDocument(filepath=fn.name, **_extract_metadata(soup)) logger.info("Scrapped document successfully %s", url) return metadata_doc @@ -182,21 +199,18 @@ def download_document_txt(self, url: str) -> BOEMetadataDocument2: logger.info("Scrapping document: %s", url) response = requests.get(url) response.raise_for_status() - with tempfile.NamedTemporaryFile('w', delete=False) as fn: - soup = BeautifulSoup(response.text, 'html.parser') # 'html5lib' - text = soup.find('div', id='textoxslt').get_text() - text = unicodedata.normalize('NFKC', text) + with tempfile.NamedTemporaryFile("w", delete=False) as fn: + soup = BeautifulSoup(response.text, "html.parser") # 'html5lib' + text = soup.find("div", id="textoxslt").get_text() + text = unicodedata.normalize("NFKC", text) fn.write(text) - span_tag = soup.find('span', class_='puntoConso') + span_tag = soup.find("span", class_="puntoConso") if span_tag: span_tag = span_tag.extract() # TODO: link to span_tag.a['href'] to improve the split by articles -> https://www.boe.es/buscar/act.php?id=BOE-A-2022-14630 - title = soup.find('h3', class_='documento-tit').get_text() + title = soup.find("h3", class_="documento-tit").get_text() metadata_doc = BOEMetadataDocument2( - filepath=fn.name, - title=title, - url=url, - document_id=url.split('?id=')[-1] + filepath=fn.name, title=title, url=url, document_id=url.split("?id=")[-1] ) 
logger.info("Scrapped document successfully %s", url) return metadata_doc diff --git a/src/etls/utils.py b/src/etls/utils.py index 7a4fa9b..9b7aa90 100644 --- a/src/etls/utils.py +++ b/src/etls/utils.py @@ -4,7 +4,7 @@ from langchain.docstore.document import Document from langchain.document_loaders.base import BaseLoader -from pydantic import field_validator, BaseModel +from pydantic import BaseModel, field_validator @dataclass @@ -20,11 +20,11 @@ class BOEMetadataDocument2: def load_metadata(self) -> dict: metadata_dict = { - 'title': self.title, - 'url': self.url, - 'document_id': self.document_id, - 'date_doc': self.date_doc, - 'datetime_insert': self.datetime_insert + "title": self.title, + "url": self.url, + "document_id": self.document_id, + "date_doc": self.date_doc, + "datetime_insert": self.datetime_insert, } return metadata_dict @@ -37,26 +37,27 @@ class BOEMetadataReferencia(BaseModel): class BOEMetadataDocument(BaseModel): """Class for keeping metadata of a BOE Document scrapped.""" + # Text filepath: str # Metadatos identificador: str - numero_oficial: str = '' + numero_oficial: str = "" departamento: str - rango: str = '' + rango: str = "" titulo: str url_pdf: str - origen_legislativo: str = '' + origen_legislativo: str = "" fecha_publicacion: str - fecha_disposicion: str = '' + fecha_disposicion: str = "" anio: str # Analisis - observaciones: str = '' - ambito_geografico: str = '' - modalidad: str = '' - tipo: str = '' + observaciones: str = "" + ambito_geografico: str = "" + modalidad: str = "" + tipo: str = "" materias: tp.List[str] alertas: tp.List[str] notas: tp.List[str] @@ -79,14 +80,19 @@ def ref_anteriores_to_json(cls, validators): @classmethod def isoformat(cls, v): if v: - return datetime.strptime(v, '%Y%m%d').strftime('%Y-%m-%d') + return datetime.strptime(v, "%Y%m%d").strftime("%Y-%m-%d") return v class BOETextLoader(BaseLoader): """Load text files.""" - def __init__(self, file_path: str, encoding: tp.Optional[str] = None, metadata: tp.Optional[dict] = None): + def __init__( + self, + file_path: str, + encoding: tp.Optional[str] = None, + metadata: tp.Optional[dict] = None, + ): """Initialize with file path.""" self.file_path = file_path self.encoding = encoding diff --git a/src/initialize.py b/src/initialize.py index 212fcb2..56624d5 100644 --- a/src/initialize.py +++ b/src/initialize.py @@ -7,8 +7,11 @@ from langchain.chains import RetrievalQA from langchain.chat_models import ChatOpenAI from langchain.embeddings import HuggingFaceEmbeddings -from langchain.prompts import (ChatPromptTemplate, HumanMessagePromptTemplate, - SystemMessagePromptTemplate) +from langchain.prompts import ( + ChatPromptTemplate, + HumanMessagePromptTemplate, + SystemMessagePromptTemplate, +) from langchain.vectorstores.pinecone import Pinecone from langchain.vectorstores.qdrant import Qdrant from qdrant_client import QdrantClient @@ -20,34 +23,35 @@ def initialize_logging(): logger = lg.getLogger() - logger.info('Initializing logging') + logger.info("Initializing logging") logger.handlers = [] handler = lg.StreamHandler() - formatter = ( - lg.Formatter('[%(asctime)s] [%(process)d] [%(levelname)s] [%(name)s] %(message)s') + formatter = lg.Formatter( + "[%(asctime)s] [%(process)d] [%(levelname)s] [%(name)s] %(message)s" ) handler.setFormatter(formatter) logger.addHandler(handler) logger.setLevel(lg.INFO) - logger.info('Initialized logging') - lg.getLogger('uvicorn.error').handlers = logger.handlers + logger.info("Initialized logging") + lg.getLogger("uvicorn.error").handlers = 
logger.handlers def initialize_app(): - """Initializes the application - """ + """Initializes the application""" logger = lg.getLogger(initialize_app.__name__) - logger.info('Initializing application') + logger.info("Initializing application") config_loader = _init_config() vector_store = _init_vector_store(config_loader) retrieval_qa = _init_retrieval_qa_llm(vector_store, config_loader) - logger.info('Initialized application') - init_objects = collections.namedtuple('init_objects', ['config_loader', 'vector_store', 'retrieval_qa']) + logger.info("Initialized application") + init_objects = collections.namedtuple( + "init_objects", ["config_loader", "vector_store", "retrieval_qa"] + ) return init_objects(config_loader, vector_store, retrieval_qa) def _init_config(): - yaml_config_path = os.path.join(os.environ['APP_PATH'], 'config', 'config.yaml') + yaml_config_path = os.path.join(os.environ["APP_PATH"], "config", "config.yaml") with open(yaml_config_path, "r") as stream: config_loader = yaml.safe_load(stream) return config_loader @@ -56,14 +60,14 @@ def _init_config(): def _init_vector_store(config_loader): logger = lg.getLogger(_init_vector_store.__name__) logger.info("Initializing vector store") - if config_loader['vector_store'] == 'pinecone': + if config_loader["vector_store"] == "pinecone": vector_store = _init_vector_store_pinecone(config_loader) - elif config_loader['vector_store'] == 'supabase': + elif config_loader["vector_store"] == "supabase": vector_store = _init_vector_store_supabase(config_loader) - elif config_loader['vector_store'] == 'qdrant': + elif config_loader["vector_store"] == "qdrant": vector_store = _init_vector_store_qdrant(config_loader) else: - raise ValueError('Vector Database not configured') + raise ValueError("Vector Database not configured") return vector_store @@ -71,13 +75,14 @@ def _init_vector_store_pinecone(config_loader): logger = lg.getLogger(_init_vector_store_pinecone.__name__) logger.info("Initializing vector store") pinecone.init( - api_key=os.environ['PINECONE_API_KEY'], - environment=os.environ['PINECONE_ENV'], + api_key=os.environ["PINECONE_API_KEY"], + environment=os.environ["PINECONE_ENV"], ) - index_name = config_loader['vector_store_index_name'] + index_name = config_loader["vector_store_index_name"] index = pinecone.Index(index_name) embeddings = HuggingFaceEmbeddings( - model_name=config_loader['embeddings_model_name'], model_kwargs={'device': 'cpu'} + model_name=config_loader["embeddings_model_name"], + model_kwargs={"device": "cpu"}, ) vector_store = Pinecone(index, embeddings.embed_query, "text") logger.info(pinecone.describe_index(index_name)) @@ -97,13 +102,14 @@ def _init_vector_store_supabase(config_loader): options=ClientOptions(postgrest_client_timeout=60), ) embeddings = HuggingFaceEmbeddings( - model_name=config_loader['embeddings_model_name'], model_kwargs={'device': 'cpu'} + model_name=config_loader["embeddings_model_name"], + model_kwargs={"device": "cpu"}, ) vector_store = StandardSupabaseVectorStore( client=supabase_client, embedding=embeddings, table_name=config_loader["table_name"], - query_name=config_loader["query_name"] + query_name=config_loader["query_name"], ) logger.info("Initialized vector store") return vector_store @@ -113,22 +119,23 @@ def _init_vector_store_qdrant(config_loader): logger = lg.getLogger(_init_vector_store_qdrant.__name__) logger.info("Initializing vector store") qdrant_client = QdrantClient( - url=os.environ['QDRANT_API_URL'], - api_key=os.environ['QDRANT_API_KEY'], - prefer_grpc=True + 
url=os.environ["QDRANT_API_URL"], + api_key=os.environ["QDRANT_API_KEY"], + prefer_grpc=True, ) embeddings = HuggingFaceEmbeddings( - model_name=config_loader['embeddings_model_name'], model_kwargs={'device': 'cpu'} + model_name=config_loader["embeddings_model_name"], + model_kwargs={"device": "cpu"}, ) if len(qdrant_client.get_collections().collections) == 0: logger.info("Creating collection for vector store") qdrant_client.recreate_collection( - collection_name=config_loader['collection_name'], + collection_name=config_loader["collection_name"], vectors_config=VectorParams(size=768, distance=Distance.COSINE), - on_disk_payload=True + on_disk_payload=True, ) logger.info("Created collection for vector store") - vector_store = Qdrant(qdrant_client, config_loader['collection_name'], embeddings) + vector_store = Qdrant(qdrant_client, config_loader["collection_name"], embeddings) logger.info("Initialized vector store") return vector_store @@ -144,15 +151,15 @@ def _init_retrieval_qa_llm(vector_store, config_loader): ] retrieval_qa = RetrievalQA.from_chain_type( llm=ChatOpenAI( - model_name=config_loader['llm_model_name'], - temperature=config_loader['temperature'], - max_tokens=config_loader['max_tokens'] + model_name=config_loader["llm_model_name"], + temperature=config_loader["temperature"], + max_tokens=config_loader["max_tokens"], ), chain_type="stuff", retriever=retriever, chain_type_kwargs={ "prompt": ChatPromptTemplate.from_messages(messages), - "verbose": True # TODO: remove in production + "verbose": True, # TODO: remove in production }, ) logger.info(retrieval_qa.combine_documents_chain.llm_chain.prompt.format) diff --git a/src/service/main.py b/src/service/main.py index 9965a95..f546335 100644 --- a/src/service/main.py +++ b/src/service/main.py @@ -32,8 +32,7 @@ async def semantic_search(input_query: str = DEFAULT_INPUT_QUERY): logger = lg.getLogger(semantic_search.__name__) logger.info(input_query) docs = INIT_OBJECTS.vector_store.similarity_search_with_score( - query=input_query, - k=INIT_OBJECTS.config_loader['top_k_results'] + query=input_query, k=INIT_OBJECTS.config_loader["top_k_results"] ) logger.info(docs) return docs @@ -45,13 +44,10 @@ async def qa(input_query: str = DEFAULT_INPUT_QUERY): logger = lg.getLogger(qa.__name__) logger.info(input_query) docs = INIT_OBJECTS.vector_store.similarity_search_with_score( - query=input_query, - k=INIT_OBJECTS.config_loader['top_k_results'] + query=input_query, k=INIT_OBJECTS.config_loader["top_k_results"] ) answer = INIT_OBJECTS.retrieval_qa.run(input_query) response_payload = QAResponsePayloadModel( - scoring_id=str(uuid.uuid4()), - context=docs, - answer=answer + scoring_id=str(uuid.uuid4()), context=docs, answer=answer ) return response_payload diff --git a/src/utils.py b/src/utils.py index 117fa53..199817c 100644 --- a/src/utils.py +++ b/src/utils.py @@ -9,7 +9,6 @@ class StandardSupabaseVectorStore(SupabaseVectorStore): - def similarity_search_with_score( self, query: str, k: int = 4, **kwargs: tp.Any ) -> tp.List[tp.Tuple[Document, float]]: @@ -23,17 +22,16 @@ class QAResponsePayloadModel(BaseModel): def timeit(func): - @wraps(func) async def wrapper(*args, **kwargs): logger = lg.getLogger(func.__name__) - logger.info('<<< Starting >>>') + logger.info("<<< Starting >>>") start_time = time.time() result = await func(*args, **kwargs) end_time = time.time() delta = end_time - start_time - msg = f'{delta:2.2f}s' if delta > 1 else f'{1000 * delta:2.1f}ms' - logger.info('<<< Completed >>> in %s', msg) + msg = f"{delta:2.2f}s" if delta 
> 1 else f"{1000 * delta:2.1f}ms" + logger.info("<<< Completed >>> in %s", msg) return result return wrapper
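
Note on the pydantic v2 migration: the src/etls/utils.py hunks switch the import from pydantic v1's validator to v2's field_validator, but the decorator lines themselves fall outside the visible diff context. Below is a minimal sketch of the v2 validator style the model is assumed to use after this patch; the field and validator names follow BOEMetadataDocument, and the exact decorator arguments are an assumption, not taken from the diff.

    # Sketch only: pydantic v2 validator style assumed by this patch.
    # Field/validator names follow src/etls/utils.py; the decorator
    # arguments are an assumption, since the hunk does not show them.
    from datetime import datetime

    from pydantic import BaseModel, field_validator


    class BOEMetadataDocumentSketch(BaseModel):
        fecha_publicacion: str

        # In v2, @field_validator replaces v1's @validator and is
        # stacked with an explicit @classmethod.
        @field_validator("fecha_publicacion")
        @classmethod
        def isoformat(cls, v: str) -> str:
            # Convert BOE-style dates ("YYYYMMDD") to ISO format.
            if v:
                return datetime.strptime(v, "%Y%m%d").strftime("%Y-%m-%d")
            return v

    # Usage:
    #   BOEMetadataDocumentSketch(fecha_publicacion="20230929").fecha_publicacion
    #   -> "2023-09-29"

Under pydantic v1 the equivalent validator would have been declared with @validator("fecha_publicacion") and did not require the stacked @classmethod.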