diff --git a/llama_parse/base.py b/llama_parse/base.py
index b1727b4..bda8b9c 100644
--- a/llama_parse/base.py
+++ b/llama_parse/base.py
@@ -4,7 +4,8 @@
 import mimetypes
 import time
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import AsyncGenerator, List, Optional, Union
+from contextlib import asynccontextmanager
 from io import BufferedIOBase
 
 from llama_index.core.async_utils import run_jobs
@@ -141,6 +142,9 @@ class LlamaParse(BasePydanticReader):
         default=False,
         description="Whether to take screenshot of each page of the document.",
     )
+    custom_client: Optional[httpx.AsyncClient] = Field(
+        default=None, description="A custom HTTPX client to use for sending requests."
+    )
 
     @field_validator("api_key", mode="before", check_fields=True)
     @classmethod
@@ -163,6 +167,20 @@ def validate_base_url(cls, v: str) -> str:
         url = os.getenv("LLAMA_CLOUD_BASE_URL", None)
         return url or v or DEFAULT_BASE_URL
 
+    @asynccontextmanager
+    async def client_context(self) -> AsyncGenerator[httpx.AsyncClient, None]:
+        """Yield the HTTPX client to use for a request.
+
+        A caller-supplied ``custom_client`` is yielded as-is and is not closed
+        here; otherwise a temporary client using ``max_timeout`` is created and
+        closed when the context exits.
+        """
+        if self.custom_client is not None:
+            yield self.custom_client
+        else:
+            async with httpx.AsyncClient(timeout=self.max_timeout) as client:
+                yield client
+
     # upload a document and get back a job_id
     async def _create_job(
         self, file_input: FileInput, extra_info: Optional[dict] = None
@@ -231,7 +249,7 @@ async def _create_job(
             data["target_pages"] = self.target_pages
 
         try:
-            async with httpx.AsyncClient(timeout=self.max_timeout) as client:
+            async with self.client_context() as client:
                 response = await client.post(
                     url,
                     files=files,
@@ -257,7 +275,7 @@ async def _get_job_result(
         tries = 0
         while True:
             await asyncio.sleep(self.check_interval)
-            async with httpx.AsyncClient(timeout=self.max_timeout) as client:
+            async with self.client_context() as client:
                 tries += 1
 
                 result = await client.get(status_url, headers=headers)
@@ -447,7 +465,9 @@ def get_json_result(
             else:
                 raise e
 
-    def get_images(self, json_result: List[dict], download_path: str) -> List[dict]:
+    async def aget_images(
+        self, json_result: List[dict], download_path: str
+    ) -> List[dict]:
         """Download images from the parsed result."""
         headers = {"Authorization": f"Bearer {self.api_key}"}
 
@@ -481,11 +501,13 @@ def get_images(self, json_result: List[dict], download_path: str) -> List[dict]:
                         image["page_number"] = page["page"]
-                        with open(image_path, "wb") as f:
-                            image_url = f"{self.base_url}/api/parsing/job/{job_id}/result/image/{image_name}"
-                            f.write(
-                                httpx.get(
-                                    image_url, headers=headers, timeout=self.max_timeout
-                                ).content
-                            )
+                        image_url = f"{self.base_url}/api/parsing/job/{job_id}/result/image/{image_name}"
+                        # Fetch first: a failed request must not leave an empty file.
+                        async with self.client_context() as client:
+                            res = await client.get(
+                                image_url, headers=headers, timeout=self.max_timeout
+                            )
+                            res.raise_for_status()
+                        with open(image_path, "wb") as f:
+                            f.write(res.content)
                         images.append(image)
             return images
         except Exception as e:
@@ -495,6 +517,16 @@ def get_images(self, json_result: List[dict], download_path: str) -> List[dict]:
             else:
                 raise e
 
+    def get_images(self, json_result: List[dict], download_path: str) -> List[dict]:
+        """Blocking wrapper around ``aget_images``; downloads the parsed images."""
+        try:
+            return asyncio.run(self.aget_images(json_result, download_path))
+        except RuntimeError as e:
+            if nest_asyncio_err in str(e):
+                raise RuntimeError(nest_asyncio_msg)
+            else:
+                raise e
+
     def _get_sub_docs(self, docs: List[Document]) -> List[Document]:
         """Split docs into pages, by separator."""
         sub_docs = []
diff --git a/tests/test_reader.py b/tests/test_reader.py
index 5ebc2ed..091da24 100644
--- a/tests/test_reader.py
+++ b/tests/test_reader.py
@@ -1,5 +1,6 @@
 import os
 import pytest
+from httpx import AsyncClient
 from llama_parse import LlamaParse
 
 
@@ -93,3 +94,18 @@ def test_simple_page_progress_workers() -> None:
     result = parser.load_data([filepath, filepath])
     assert len(result) == 2
     assert len(result[0].text) > 0
+
+
+@pytest.mark.skipif(
+    os.environ.get("LLAMA_CLOUD_API_KEY", "") == "",
+    reason="LLAMA_CLOUD_API_KEY not set",
+)
+def test_custom_client() -> None:
+    custom_client = AsyncClient(verify=False, timeout=10)
+    parser = LlamaParse(result_type="markdown", custom_client=custom_client)
+    filepath = os.path.join(
+        os.path.dirname(__file__), "test_files/attention_is_all_you_need.pdf"
+    )
+    result = parser.load_data(filepath)
+    assert len(result) == 1
+    assert len(result[0].text) > 0