Skip to content

Commit

Permalink
allow custom httpx client (#384)
Browse files Browse the repository at this point in the history
* allow custom httpx client

* split into aget_images + unit test

* typo
  • Loading branch information
sourabhdesai authored Sep 7, 2024
1 parent f304c2d commit fd3836e
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 8 deletions.
42 changes: 34 additions & 8 deletions llama_parse/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import mimetypes
import time
from pathlib import Path
from typing import List, Optional, Union
from typing import AsyncGenerator, List, Optional, Union
from contextlib import asynccontextmanager
from io import BufferedIOBase

from llama_index.core.async_utils import run_jobs
Expand Down Expand Up @@ -141,6 +142,9 @@ class LlamaParse(BasePydanticReader):
default=False,
description="Whether to take screenshot of each page of the document.",
)
custom_client: Optional[httpx.AsyncClient] = Field(
default=None, description="A custom HTTPX client to use for sending requests."
)

@field_validator("api_key", mode="before", check_fields=True)
@classmethod
Expand All @@ -163,6 +167,15 @@ def validate_base_url(cls, v: str) -> str:
url = os.getenv("LLAMA_CLOUD_BASE_URL", None)
return url or v or DEFAULT_BASE_URL

@asynccontextmanager
async def client_context(self) -> AsyncGenerator[httpx.AsyncClient, None]:
"""Create a context for the HTTPX client."""
if self.custom_client is not None:
yield self.custom_client
else:
async with httpx.AsyncClient(timeout=self.max_timeout) as client:
yield client

# upload a document and get back a job_id
async def _create_job(
self, file_input: FileInput, extra_info: Optional[dict] = None
Expand Down Expand Up @@ -231,7 +244,7 @@ async def _create_job(
data["target_pages"] = self.target_pages

try:
async with httpx.AsyncClient(timeout=self.max_timeout) as client:
async with self.client_context() as client:
response = await client.post(
url,
files=files,
Expand All @@ -257,7 +270,7 @@ async def _get_job_result(
tries = 0
while True:
await asyncio.sleep(self.check_interval)
async with httpx.AsyncClient(timeout=self.max_timeout) as client:
async with self.client_context() as client:
tries += 1

result = await client.get(status_url, headers=headers)
Expand Down Expand Up @@ -447,7 +460,9 @@ def get_json_result(
else:
raise e

def get_images(self, json_result: List[dict], download_path: str) -> List[dict]:
async def aget_images(
self, json_result: List[dict], download_path: str
) -> List[dict]:
"""Download images from the parsed result."""
headers = {"Authorization": f"Bearer {self.api_key}"}

Expand Down Expand Up @@ -481,11 +496,12 @@ def get_images(self, json_result: List[dict], download_path: str) -> List[dict]:
image["page_number"] = page["page"]
with open(image_path, "wb") as f:
image_url = f"{self.base_url}/api/parsing/job/{job_id}/result/image/{image_name}"
f.write(
httpx.get(
async with self.client_context() as client:
res = await client.get(
image_url, headers=headers, timeout=self.max_timeout
).content
)
)
res.raise_for_status()
f.write(res.content)
images.append(image)
return images
except Exception as e:
Expand All @@ -495,6 +511,16 @@ def get_images(self, json_result: List[dict], download_path: str) -> List[dict]:
else:
raise e

def get_images(self, json_result: List[dict], download_path: str) -> List[dict]:
"""Download images from the parsed result."""
try:
return asyncio.run(self.aget_images(json_result, download_path))
except RuntimeError as e:
if nest_asyncio_err in str(e):
raise RuntimeError(nest_asyncio_msg)
else:
raise e

def _get_sub_docs(self, docs: List[Document]) -> List[Document]:
"""Split docs into pages, by separator."""
sub_docs = []
Expand Down
16 changes: 16 additions & 0 deletions tests/test_reader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import pytest
from httpx import AsyncClient
from llama_parse import LlamaParse


Expand Down Expand Up @@ -93,3 +94,18 @@ def test_simple_page_progress_workers() -> None:
result = parser.load_data([filepath, filepath])
assert len(result) == 2
assert len(result[0].text) > 0


@pytest.mark.skipif(
os.environ.get("LLAMA_CLOUD_API_KEY", "") == "",
reason="LLAMA_CLOUD_API_KEY not set",
)
def test_custom_client() -> None:
custom_client = AsyncClient(verify=False, timeout=10)
parser = LlamaParse(result_type="markdown", custom_client=custom_client)
filepath = os.path.join(
os.path.dirname(__file__), "test_files/attention_is_all_you_need.pdf"
)
result = parser.load_data(filepath)
assert len(result) == 1
assert len(result[0].text) > 0

0 comments on commit fd3836e

Please sign in to comment.