From 1f1679e9600d2f4aa750fafe90978c513dab1808 Mon Sep 17 00:00:00 2001 From: Rajendra Kadam Date: Thu, 22 Aug 2024 21:16:52 +0530 Subject: [PATCH] community: Refactor PebbloSafeLoader (#25582) **Refactor PebbloSafeLoader** - Created `APIWrapper` and moved API logic into it. - Moved helper functions to the utility file. - Created smaller functions and methods for better readability. - Properly read environment variables. - Removed unused code. **Issue:** NA **Dependencies:** NA **tests**: Updated --- .../document_loaders/pebblo.py | 339 +-------------- .../langchain_community/utilities/pebblo.py | 407 +++++++++++++++++- .../document_loaders/test_pebblo.py | 3 +- 3 files changed, 422 insertions(+), 327 deletions(-) diff --git a/libs/community/langchain_community/document_loaders/pebblo.py b/libs/community/langchain_community/document_loaders/pebblo.py index b3bd447516129..772a206a803d8 100644 --- a/libs/community/langchain_community/document_loaders/pebblo.py +++ b/libs/community/langchain_community/document_loaders/pebblo.py @@ -1,31 +1,25 @@ """Pebblo's safe dataloader is a wrapper for document loaders""" -import json import logging import os import uuid -from http import HTTPStatus -from typing import Any, Dict, Iterator, List, Optional +from typing import Dict, Iterator, List, Optional -import requests # type: ignore from langchain_core.documents import Document from langchain_community.document_loaders.base import BaseLoader from langchain_community.utilities.pebblo import ( - APP_DISCOVER_URL, BATCH_SIZE_BYTES, - CLASSIFIER_URL, - LOADER_DOC_URL, - PEBBLO_CLOUD_URL, PLUGIN_VERSION, App, - Doc, IndexedDocument, + PebbloLoaderAPIWrapper, generate_size_based_batches, get_full_path, get_loader_full_path, get_loader_type, get_runtime, + get_source_size, ) logger = logging.getLogger(__name__) @@ -37,7 +31,6 @@ class PebbloSafeLoader(BaseLoader): """ _discover_sent: bool = False - _loader_sent: bool = False def __init__( self, @@ -54,22 +47,17 @@ def __init__( if not name or not isinstance(name, str): raise NameError("Must specify a valid name.") self.app_name = name - self.api_key = os.environ.get("PEBBLO_API_KEY") or api_key self.load_id = str(uuid.uuid4()) self.loader = langchain_loader self.load_semantic = os.environ.get("PEBBLO_LOAD_SEMANTIC") or load_semantic self.owner = owner self.description = description self.source_path = get_loader_full_path(self.loader) - self.source_owner = PebbloSafeLoader.get_file_owner_from_path(self.source_path) self.docs: List[Document] = [] self.docs_with_id: List[IndexedDocument] = [] loader_name = str(type(self.loader)).split(".")[-1].split("'")[0] self.source_type = get_loader_type(loader_name) - self.source_path_size = self.get_source_size(self.source_path) - self.source_aggregate_size = 0 - self.classifier_url = classifier_url or CLASSIFIER_URL - self.classifier_location = classifier_location + self.source_path_size = get_source_size(self.source_path) self.batch_size = BATCH_SIZE_BYTES self.loader_details = { "loader": loader_name, @@ -83,7 +71,13 @@ def __init__( } # generate app self.app = self._get_app_details() - self._send_discover() + # initialize Pebblo Loader API client + self.pb_client = PebbloLoaderAPIWrapper( + api_key=api_key, + classifier_location=classifier_location, + classifier_url=classifier_url, + ) + self.pb_client.send_loader_discover(self.app) def load(self) -> List[Document]: """Load Documents. @@ -113,7 +107,12 @@ def classify_in_batches(self) -> None: is_last_batch: bool = i == total_batches - 1 self.docs = batch self.docs_with_id = self._index_docs() - classified_docs = self._classify_doc(loading_end=is_last_batch) + classified_docs = self.pb_client.classify_documents( + self.docs_with_id, + self.app, + self.loader_details, + loading_end=is_last_batch, + ) self._add_pebblo_specific_metadata(classified_docs) if self.load_semantic: batch_processed_docs = self._add_semantic_to_docs(classified_docs) @@ -147,7 +146,9 @@ def lazy_load(self) -> Iterator[Document]: break self.docs = list((doc,)) self.docs_with_id = self._index_docs() - classified_doc = self._classify_doc() + classified_doc = self.pb_client.classify_documents( + self.docs_with_id, self.app, self.loader_details + ) self._add_pebblo_specific_metadata(classified_doc) if self.load_semantic: self.docs = self._add_semantic_to_docs(classified_doc) @@ -159,263 +160,6 @@ def lazy_load(self) -> Iterator[Document]: def set_discover_sent(cls) -> None: cls._discover_sent = True - @classmethod - def set_loader_sent(cls) -> None: - cls._loader_sent = True - - def _classify_doc(self, loading_end: bool = False) -> dict: - """Send documents fetched from loader to pebblo-server. Then send - classified documents to Daxa cloud(If api_key is present). Internal method. - - Args: - - loading_end (bool, optional): Flag indicating the halt of data - loading by loader. Defaults to False. - """ - headers = { - "Accept": "application/json", - "Content-Type": "application/json", - } - if loading_end is True: - PebbloSafeLoader.set_loader_sent() - doc_content = [doc.dict() for doc in self.docs_with_id] - docs = [] - for doc in doc_content: - doc_metadata = doc.get("metadata", {}) - doc_authorized_identities = doc_metadata.get("authorized_identities", []) - doc_source_path = get_full_path( - doc_metadata.get( - "full_path", doc_metadata.get("source", self.source_path) - ) - ) - doc_source_owner = doc_metadata.get( - "owner", PebbloSafeLoader.get_file_owner_from_path(doc_source_path) - ) - doc_source_size = doc_metadata.get( - "size", self.get_source_size(doc_source_path) - ) - page_content = str(doc.get("page_content")) - page_content_size = self.calculate_content_size(page_content) - self.source_aggregate_size += page_content_size - doc_id = doc.get("pb_id", None) or 0 - docs.append( - { - "doc": page_content, - "source_path": doc_source_path, - "pb_id": doc_id, - "last_modified": doc.get("metadata", {}).get("last_modified"), - "file_owner": doc_source_owner, - **( - {"authorized_identities": doc_authorized_identities} - if doc_authorized_identities - else {} - ), - **( - {"source_path_size": doc_source_size} - if doc_source_size is not None - else {} - ), - } - ) - payload: Dict[str, Any] = { - "name": self.app_name, - "owner": self.owner, - "docs": docs, - "plugin_version": PLUGIN_VERSION, - "load_id": self.load_id, - "loader_details": self.loader_details, - "loading_end": "false", - "source_owner": self.source_owner, - "classifier_location": self.classifier_location, - } - if loading_end is True: - payload["loading_end"] = "true" - if "loader_details" in payload: - payload["loader_details"]["source_aggregate_size"] = ( - self.source_aggregate_size - ) - payload = Doc(**payload).dict(exclude_unset=True) - classified_docs = {} - # Raw payload to be sent to classifier - if self.classifier_location == "local": - load_doc_url = f"{self.classifier_url}{LOADER_DOC_URL}" - try: - pebblo_resp = requests.post( - load_doc_url, headers=headers, json=payload, timeout=300 - ) - - # Updating the structure of pebblo response docs for efficient searching - for classified_doc in json.loads(pebblo_resp.text).get("docs", []): - classified_docs.update({classified_doc["pb_id"]: classified_doc}) - if pebblo_resp.status_code not in [ - HTTPStatus.OK, - HTTPStatus.BAD_GATEWAY, - ]: - logger.warning( - "Received unexpected HTTP response code: %s", - pebblo_resp.status_code, - ) - logger.debug( - "send_loader_doc[local]: request url %s, body %s len %s\ - response status %s body %s", - pebblo_resp.request.url, - str(pebblo_resp.request.body), - str( - len( - pebblo_resp.request.body if pebblo_resp.request.body else [] - ) - ), - str(pebblo_resp.status_code), - pebblo_resp.json(), - ) - except requests.exceptions.RequestException: - logger.warning("Unable to reach pebblo server.") - except Exception as e: - logger.warning("An Exception caught in _send_loader_doc: local %s", e) - - if self.api_key: - if self.classifier_location == "local": - docs = payload["docs"] - for doc_data in docs: - classified_data = classified_docs.get(doc_data["pb_id"], {}) - doc_data.update( - { - "pb_checksum": classified_data.get("pb_checksum", None), - "loader_source_path": classified_data.get( - "loader_source_path", None - ), - "entities": classified_data.get("entities", {}), - "topics": classified_data.get("topics", {}), - } - ) - doc_data.pop("doc") - - headers.update({"x-api-key": self.api_key}) - pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{LOADER_DOC_URL}" - try: - pebblo_cloud_response = requests.post( - pebblo_cloud_url, headers=headers, json=payload, timeout=20 - ) - logger.debug( - "send_loader_doc[cloud]: request url %s, body %s len %s\ - response status %s body %s", - pebblo_cloud_response.request.url, - str(pebblo_cloud_response.request.body), - str( - len( - pebblo_cloud_response.request.body - if pebblo_cloud_response.request.body - else [] - ) - ), - str(pebblo_cloud_response.status_code), - pebblo_cloud_response.json(), - ) - except requests.exceptions.RequestException: - logger.warning("Unable to reach Pebblo cloud server.") - except Exception as e: - logger.warning("An Exception caught in _send_loader_doc: cloud %s", e) - elif self.classifier_location == "pebblo-cloud": - logger.warning("API key is missing for sending docs to Pebblo cloud.") - raise NameError("API key is missing for sending docs to Pebblo cloud.") - - return classified_docs - - @staticmethod - def calculate_content_size(page_content: str) -> int: - """Calculate the content size in bytes: - - Encode the string to bytes using a specific encoding (e.g., UTF-8) - - Get the length of the encoded bytes. - - Args: - page_content (str): Data string. - - Returns: - int: Size of string in bytes. - """ - - # Encode the content to bytes using UTF-8 - encoded_content = page_content.encode("utf-8") - size = len(encoded_content) - return size - - def _send_discover(self) -> None: - """Send app discovery payload to pebblo-server. Internal method.""" - pebblo_resp = None - headers = { - "Accept": "application/json", - "Content-Type": "application/json", - } - payload = self.app.dict(exclude_unset=True) - # Raw discover payload to be sent to classifier - if self.classifier_location == "local": - app_discover_url = f"{self.classifier_url}{APP_DISCOVER_URL}" - try: - pebblo_resp = requests.post( - app_discover_url, headers=headers, json=payload, timeout=20 - ) - logger.debug( - "send_discover[local]: request url %s, body %s len %s\ - response status %s body %s", - pebblo_resp.request.url, - str(pebblo_resp.request.body), - str( - len( - pebblo_resp.request.body if pebblo_resp.request.body else [] - ) - ), - str(pebblo_resp.status_code), - pebblo_resp.json(), - ) - if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]: - PebbloSafeLoader.set_discover_sent() - else: - logger.warning( - f"Received unexpected HTTP response code:\ - {pebblo_resp.status_code}" - ) - except requests.exceptions.RequestException: - logger.warning("Unable to reach pebblo server.") - except Exception as e: - logger.warning("An Exception caught in _send_discover: local %s", e) - - if self.api_key: - try: - headers.update({"x-api-key": self.api_key}) - # If the pebblo_resp is None, - # then the pebblo server version is not available - if pebblo_resp: - pebblo_server_version = json.loads(pebblo_resp.text).get( - "pebblo_server_version" - ) - payload.update({"pebblo_server_version": pebblo_server_version}) - - payload.update({"pebblo_client_version": PLUGIN_VERSION}) - pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{APP_DISCOVER_URL}" - pebblo_cloud_response = requests.post( - pebblo_cloud_url, headers=headers, json=payload, timeout=20 - ) - - logger.debug( - "send_discover[cloud]: request url %s, body %s len %s\ - response status %s body %s", - pebblo_cloud_response.request.url, - str(pebblo_cloud_response.request.body), - str( - len( - pebblo_cloud_response.request.body - if pebblo_cloud_response.request.body - else [] - ) - ), - str(pebblo_cloud_response.status_code), - pebblo_cloud_response.json(), - ) - except requests.exceptions.RequestException: - logger.warning("Unable to reach Pebblo cloud server.") - except Exception as e: - logger.warning("An Exception caught in _send_discover: cloud %s", e) - def _get_app_details(self) -> App: """Fetch app details. Internal method. @@ -434,49 +178,6 @@ def _get_app_details(self) -> App: ) return app - @staticmethod - def get_file_owner_from_path(file_path: str) -> str: - """Fetch owner of local file path. - - Args: - file_path (str): Local file path. - - Returns: - str: Name of owner. - """ - try: - import pwd - - file_owner_uid = os.stat(file_path).st_uid - file_owner_name = pwd.getpwuid(file_owner_uid).pw_name - except Exception: - file_owner_name = "unknown" - return file_owner_name - - def get_source_size(self, source_path: str) -> int: - """Fetch size of source path. Source can be a directory or a file. - - Args: - source_path (str): Local path of data source. - - Returns: - int: Source size in bytes. - """ - if not source_path: - return 0 - size = 0 - if os.path.isfile(source_path): - size = os.path.getsize(source_path) - elif os.path.isdir(source_path): - total_size = 0 - for dirpath, _, filenames in os.walk(source_path): - for f in filenames: - fp = os.path.join(dirpath, f) - if not os.path.islink(fp): - total_size += os.path.getsize(fp) - size = total_size - return size - def _index_docs(self) -> List[IndexedDocument]: """ Indexes the documents and returns a list of IndexedDocument objects. diff --git a/libs/community/langchain_community/utilities/pebblo.py b/libs/community/langchain_community/utilities/pebblo.py index c61ce5bc000a0..50e5b408b99cb 100644 --- a/libs/community/langchain_community/utilities/pebblo.py +++ b/libs/community/langchain_community/utilities/pebblo.py @@ -1,25 +1,29 @@ from __future__ import annotations +import json import logging import os import pathlib import platform -from typing import List, Optional, Tuple +from enum import Enum +from http import HTTPStatus +from typing import Any, Dict, List, Optional, Tuple from langchain_core.documents import Document from langchain_core.env import get_runtime_environment from langchain_core.pydantic_v1 import BaseModel +from langchain_core.utils import get_from_dict_or_env +from requests import Response, request +from requests.exceptions import RequestException from langchain_community.document_loaders.base import BaseLoader logger = logging.getLogger(__name__) PLUGIN_VERSION = "0.1.1" -CLASSIFIER_URL = os.getenv("PEBBLO_CLASSIFIER_URL", "http://localhost:8000") -PEBBLO_CLOUD_URL = os.getenv("PEBBLO_CLOUD_URL", "https://api.daxa.ai") -LOADER_DOC_URL = "/v1/loader/doc" -APP_DISCOVER_URL = "/v1/app/discover" +_DEFAULT_CLASSIFIER_URL = "http://localhost:8000" +_DEFAULT_PEBBLO_CLOUD_URL = "https://api.daxa.ai" BATCH_SIZE_BYTES = 100 * 1024 # 100 KB # Supported loaders for Pebblo safe data loading @@ -59,9 +63,15 @@ "cloud-folder": cloud_folder, } -SUPPORTED_LOADERS = (*file_loader, *dir_loader, *in_memory) -logger = logging.getLogger(__name__) +class Routes(str, Enum): + """Routes available for the Pebblo API as enumerator.""" + + loader_doc = "/v1/loader/doc" + loader_app_discover = "/v1/app/discover" + retrieval_app_discover = "/v1/app/discover" + prompt = "/v1/prompt" + prompt_governance = "/v1/prompt/governance" class IndexedDocument(Document): @@ -342,3 +352,386 @@ def generate_size_based_batches( batches.append(current_batch) return batches + + +def get_file_owner_from_path(file_path: str) -> str: + """Fetch owner of local file path. + + Args: + file_path (str): Local file path. + + Returns: + str: Name of owner. + """ + try: + import pwd + + file_owner_uid = os.stat(file_path).st_uid + file_owner_name = pwd.getpwuid(file_owner_uid).pw_name + except Exception: + file_owner_name = "unknown" + return file_owner_name + + +def get_source_size(source_path: str) -> int: + """Fetch size of source path. Source can be a directory or a file. + + Args: + source_path (str): Local path of data source. + + Returns: + int: Source size in bytes. + """ + if not source_path: + return 0 + size = 0 + if os.path.isfile(source_path): + size = os.path.getsize(source_path) + elif os.path.isdir(source_path): + total_size = 0 + for dirpath, _, filenames in os.walk(source_path): + for f in filenames: + fp = os.path.join(dirpath, f) + if not os.path.islink(fp): + total_size += os.path.getsize(fp) + size = total_size + return size + + +def calculate_content_size(data: str) -> int: + """Calculate the content size in bytes: + - Encode the string to bytes using a specific encoding (e.g., UTF-8) + - Get the length of the encoded bytes. + + Args: + data (str): Data string. + + Returns: + int: Size of string in bytes. + """ + encoded_content = data.encode("utf-8") + size = len(encoded_content) + return size + + +class PebbloLoaderAPIWrapper(BaseModel): + """Wrapper for Pebblo Loader API.""" + + api_key: Optional[str] # Use SecretStr + """API key for Pebblo Cloud""" + classifier_location: str = "local" + """Location of the classifier, local or cloud. Defaults to 'local'""" + classifier_url: Optional[str] + """URL of the Pebblo Classifier""" + cloud_url: Optional[str] + """URL of the Pebblo Cloud""" + + def __init__(self, **kwargs: Any): + """Validate that api key in environment.""" + kwargs["api_key"] = get_from_dict_or_env( + kwargs, "api_key", "PEBBLO_API_KEY", "" + ) + kwargs["classifier_url"] = get_from_dict_or_env( + kwargs, "classifier_url", "PEBBLO_CLASSIFIER_URL", _DEFAULT_CLASSIFIER_URL + ) + kwargs["cloud_url"] = get_from_dict_or_env( + kwargs, "cloud_url", "PEBBLO_CLOUD_URL", _DEFAULT_PEBBLO_CLOUD_URL + ) + super().__init__(**kwargs) + + def send_loader_discover(self, app: App) -> None: + """ + Send app discovery request to Pebblo server & cloud. + + Args: + app (App): App instance to be discovered. + """ + pebblo_resp = None + payload = app.dict(exclude_unset=True) + + if self.classifier_location == "local": + # Send app details to local classifier + headers = self._make_headers() + app_discover_url = f"{self.classifier_url}{Routes.loader_app_discover}" + pebblo_resp = self.make_request("POST", app_discover_url, headers, payload) + + if self.api_key: + # Send app details to Pebblo cloud if api_key is present + headers = self._make_headers(cloud_request=True) + if pebblo_resp: + pebblo_server_version = json.loads(pebblo_resp.text).get( + "pebblo_server_version" + ) + payload.update({"pebblo_server_version": pebblo_server_version}) + + payload.update({"pebblo_client_version": PLUGIN_VERSION}) + pebblo_cloud_url = f"{self.cloud_url}{Routes.loader_app_discover}" + _ = self.make_request("POST", pebblo_cloud_url, headers, payload) + + def classify_documents( + self, + docs_with_id: List[IndexedDocument], + app: App, + loader_details: dict, + loading_end: bool = False, + ) -> dict: + """ + Send documents to Pebblo server for classification. + Then send classified documents to Daxa cloud(If api_key is present). + + Args: + docs_with_id (List[IndexedDocument]): List of documents to be classified. + app (App): App instance. + loader_details (dict): Loader details. + loading_end (bool): Boolean, indicating the halt of data loading by loader. + """ + source_path = loader_details.get("source_path", "") + source_owner = get_file_owner_from_path(source_path) + # Prepare docs for classification + docs, source_aggregate_size = self.prepare_docs_for_classification( + docs_with_id, source_path + ) + # Build payload for classification + payload = self.build_classification_payload( + app, docs, loader_details, source_owner, source_aggregate_size, loading_end + ) + + classified_docs = {} + if self.classifier_location == "local": + # Send docs to local classifier + headers = self._make_headers() + load_doc_url = f"{self.classifier_url}{Routes.loader_doc}" + try: + pebblo_resp = self.make_request( + "POST", load_doc_url, headers, payload, 300 + ) + + if pebblo_resp: + # Updating structure of pebblo response docs for efficient searching + for classified_doc in json.loads(pebblo_resp.text).get("docs", []): + classified_docs.update( + {classified_doc["pb_id"]: classified_doc} + ) + except Exception as e: + logger.warning("An Exception caught in classify_documents: local %s", e) + + if self.api_key: + # Send docs to Pebblo cloud if api_key is present + if self.classifier_location == "local": + # If local classifier is used add the classified information + # and remove doc content + self.update_doc_data(payload["docs"], classified_docs) + self.send_docs_to_pebblo_cloud(payload) + elif self.classifier_location == "pebblo-cloud": + logger.warning("API key is missing for sending docs to Pebblo cloud.") + raise NameError("API key is missing for sending docs to Pebblo cloud.") + + return classified_docs + + def send_docs_to_pebblo_cloud(self, payload: dict) -> None: + """ + Send documents to Pebblo cloud. + + Args: + payload (dict): The payload containing documents to be sent. + """ + headers = self._make_headers(cloud_request=True) + pebblo_cloud_url = f"{self.cloud_url}{Routes.loader_doc}" + try: + _ = self.make_request("POST", pebblo_cloud_url, headers, payload) + except Exception as e: + logger.warning("An Exception caught in classify_documents: cloud %s", e) + + def _make_headers(self, cloud_request: bool = False) -> dict: + """ + Generate headers for the request. + + args: + cloud_request (bool): flag indicating whether the request is for Pebblo + cloud. + returns: + dict: Headers for the request. + + """ + headers = { + "Accept": "application/json", + "Content-Type": "application/json", + } + if cloud_request: + # Add API key for Pebblo cloud request + if self.api_key: + headers.update({"x-api-key": self.api_key}) + else: + logger.warning("API key is missing for Pebblo cloud request.") + return headers + + def build_classification_payload( + self, + app: App, + docs: List[dict], + loader_details: dict, + source_owner: str, + source_aggregate_size: int, + loading_end: bool, + ) -> dict: + """ + Build the payload for document classification. + + Args: + app (App): App instance. + docs (List[dict]): List of documents to be classified. + loader_details (dict): Loader details. + source_owner (str): Owner of the source. + source_aggregate_size (int): Aggregate size of the source. + loading_end (bool): Boolean indicating the halt of data loading by loader. + + Returns: + dict: Payload for document classification. + """ + payload: Dict[str, Any] = { + "name": app.name, + "owner": app.owner, + "docs": docs, + "plugin_version": PLUGIN_VERSION, + "load_id": app.load_id, + "loader_details": loader_details, + "loading_end": "false", + "source_owner": source_owner, + "classifier_location": self.classifier_location, + } + if loading_end is True: + payload["loading_end"] = "true" + if "loader_details" in payload: + payload["loader_details"]["source_aggregate_size"] = ( + source_aggregate_size + ) + payload = Doc(**payload).dict(exclude_unset=True) + return payload + + @staticmethod + def make_request( + method: str, + url: str, + headers: dict, + payload: Optional[dict] = None, + timeout: int = 20, + ) -> Optional[Response]: + """ + Make a request to the Pebblo API + + Args: + method (str): HTTP method (GET, POST, PUT, DELETE, etc.). + url (str): URL for the request. + headers (dict): Headers for the request. + payload (Optional[dict]): Payload for the request (for POST, PUT, etc.). + timeout (int): Timeout for the request in seconds. + + Returns: + Optional[Response]: Response object if the request is successful. + """ + try: + response = request( + method=method, url=url, headers=headers, json=payload, timeout=timeout + ) + logger.debug( + "Request: method %s, url %s, len %s response status %s", + method, + response.request.url, + str(len(response.request.body if response.request.body else [])), + str(response.status_code), + ) + + if response.status_code >= HTTPStatus.INTERNAL_SERVER_ERROR: + logger.warning(f"Pebblo Server: Error {response.status_code}") + elif response.status_code >= HTTPStatus.BAD_REQUEST: + logger.warning(f"Pebblo received an invalid payload: {response.text}") + elif response.status_code != HTTPStatus.OK: + logger.warning( + f"Pebblo returned an unexpected response code: " + f"{response.status_code}" + ) + + return response + except RequestException: + logger.warning("Unable to reach server %s", url) + except Exception as e: + logger.warning("An Exception caught in make_request: %s", e) + return None + + @staticmethod + def prepare_docs_for_classification( + docs_with_id: List[IndexedDocument], source_path: str + ) -> Tuple[List[dict], int]: + """ + Prepare documents for classification. + + Args: + docs_with_id (List[IndexedDocument]): List of documents to be classified. + source_path (str): Source path of the documents. + + Returns: + Tuple[List[dict], int]: Documents and the aggregate size of the source. + """ + docs = [] + source_aggregate_size = 0 + doc_content = [doc.dict() for doc in docs_with_id] + for doc in doc_content: + doc_metadata = doc.get("metadata", {}) + doc_authorized_identities = doc_metadata.get("authorized_identities", []) + doc_source_path = get_full_path( + doc_metadata.get( + "full_path", + doc_metadata.get("source", source_path), + ) + ) + doc_source_owner = doc_metadata.get( + "owner", get_file_owner_from_path(doc_source_path) + ) + doc_source_size = doc_metadata.get("size", get_source_size(doc_source_path)) + page_content = str(doc.get("page_content")) + page_content_size = calculate_content_size(page_content) + source_aggregate_size += page_content_size + doc_id = doc.get("pb_id", None) or 0 + docs.append( + { + "doc": page_content, + "source_path": doc_source_path, + "pb_id": doc_id, + "last_modified": doc.get("metadata", {}).get("last_modified"), + "file_owner": doc_source_owner, + **( + {"authorized_identities": doc_authorized_identities} + if doc_authorized_identities + else {} + ), + **( + {"source_path_size": doc_source_size} + if doc_source_size is not None + else {} + ), + } + ) + return docs, source_aggregate_size + + @staticmethod + def update_doc_data(docs: List[dict], classified_docs: dict) -> None: + """ + Update the document data with classified information. + + Args: + docs (List[dict]): List of document data to be updated. + classified_docs (dict): The dictionary containing classified documents. + """ + for doc_data in docs: + classified_data = classified_docs.get(doc_data["pb_id"], {}) + # Update the document data with classified information + doc_data.update( + { + "pb_checksum": classified_data.get("pb_checksum"), + "loader_source_path": classified_data.get("loader_source_path"), + "entities": classified_data.get("entities", {}), + "topics": classified_data.get("topics", {}), + } + ) + # Remove the document content + doc_data.pop("doc") diff --git a/libs/community/tests/unit_tests/document_loaders/test_pebblo.py b/libs/community/tests/unit_tests/document_loaders/test_pebblo.py index 2d6256b5044de..89617b9cd5fa3 100644 --- a/libs/community/tests/unit_tests/document_loaders/test_pebblo.py +++ b/libs/community/tests/unit_tests/document_loaders/test_pebblo.py @@ -144,4 +144,5 @@ def test_pebblo_safe_loader_api_key() -> None: ) # Assert - assert loader.api_key == api_key + assert loader.pb_client.api_key == api_key + assert loader.pb_client.classifier_location == "local"