Commit fb4ba99: sync SDK

dineshyv committed Nov 29, 2024
1 parent 9c8f7b1

Showing 4 changed files with 144 additions and 42 deletions.
src/llama_stack_client/_base_client.py (12 changes: 8 additions & 4 deletions)

@@ -792,6 +792,7 @@ def __init__(
         custom_query: Mapping[str, object] | None = None,
         _strict_response_validation: bool,
     ) -> None:
+        kwargs: dict[str, Any] = {}
         if limits is not None:
             warnings.warn(
                 "The `connection_pool_limits` argument is deprecated. The `http_client` argument should be passed instead",
@@ -804,6 +805,7 @@ def __init__(
             limits = DEFAULT_CONNECTION_LIMITS
 
         if transport is not None:
+            kwargs["transport"] = transport
             warnings.warn(
                 "The `transport` argument is deprecated. The `http_client` argument should be passed instead",
                 category=DeprecationWarning,
@@ -813,6 +815,7 @@ def __init__(
                 raise ValueError("The `http_client` argument is mutually exclusive with `transport`")
 
         if proxies is not None:
+            kwargs["proxies"] = proxies
             warnings.warn(
                 "The `proxies` argument is deprecated. The `http_client` argument should be passed instead",
                 category=DeprecationWarning,
@@ -856,10 +859,9 @@ def __init__(
             base_url=base_url,
             # cast to a valid type because mypy doesn't understand our type narrowing
             timeout=cast(Timeout, timeout),
-            proxies=proxies,
-            transport=transport,
             limits=limits,
             follow_redirects=True,
+            **kwargs,  # type: ignore
         )
 
     def is_closed(self) -> bool:
@@ -1358,6 +1360,7 @@ def __init__(
         custom_headers: Mapping[str, str] | None = None,
         custom_query: Mapping[str, object] | None = None,
     ) -> None:
+        kwargs: dict[str, Any] = {}
         if limits is not None:
             warnings.warn(
                 "The `connection_pool_limits` argument is deprecated. The `http_client` argument should be passed instead",
@@ -1370,6 +1373,7 @@ def __init__(
             limits = DEFAULT_CONNECTION_LIMITS
 
         if transport is not None:
+            kwargs["transport"] = transport
             warnings.warn(
                 "The `transport` argument is deprecated. The `http_client` argument should be passed instead",
                 category=DeprecationWarning,
@@ -1379,6 +1383,7 @@ def __init__(
                 raise ValueError("The `http_client` argument is mutually exclusive with `transport`")
 
         if proxies is not None:
+            kwargs["proxies"] = proxies
             warnings.warn(
                 "The `proxies` argument is deprecated. The `http_client` argument should be passed instead",
                 category=DeprecationWarning,
@@ -1422,10 +1427,9 @@ def __init__(
             base_url=base_url,
             # cast to a valid type because mypy doesn't understand our type narrowing
             timeout=cast(Timeout, timeout),
-            proxies=proxies,
-            transport=transport,
             limits=limits,
             follow_redirects=True,
+            **kwargs,  # type: ignore
         )
 
     def is_closed(self) -> bool:
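Both constructors now collect the deprecated `transport` and `proxies` arguments into a `kwargs` dict and splat it into the underlying httpx client, instead of passing those keywords unconditionally. A minimal sketch of the pattern outside the SDK (the `build_http_client` helper is hypothetical, and the motivation in the comments is an assumption; the commit itself doesn't state why):

```python
from __future__ import annotations

from typing import Any

import httpx


def build_http_client(
    base_url: str,
    transport: httpx.BaseTransport | None = None,
    proxies: Any | None = None,
) -> httpx.Client:
    # Collect deprecated arguments only when the caller actually supplied them,
    # so the keyword is never seen by httpx versions that have dropped it.
    kwargs: dict[str, Any] = {}
    if transport is not None:
        kwargs["transport"] = transport
    if proxies is not None:
        # Newer httpx releases removed `proxies`; forwarding it unconditionally
        # (even as None) would raise a TypeError there.
        kwargs["proxies"] = proxies
    return httpx.Client(base_url=base_url, follow_redirects=True, **kwargs)
```

The key design point is that `**kwargs` with an empty dict is a no-op, so omitted arguments leave the httpx call signature untouched.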
src/llama_stack_client/_compat.py (8 changes: 3 additions & 5 deletions)

@@ -145,7 +145,8 @@ def model_dump(
             exclude=exclude,
             exclude_unset=exclude_unset,
             exclude_defaults=exclude_defaults,
-            warnings=warnings,
+            # warnings are not supported in Pydantic v1
+            warnings=warnings if PYDANTIC_V2 else True,
         )
     return cast(
         "dict[str, Any]",
@@ -213,9 +214,6 @@ def __set_name__(self, owner: type[Any], name: str) -> None: ...
         # __set__ is not defined at runtime, but @cached_property is designed to be settable
         def __set__(self, instance: object, value: _T) -> None: ...
 else:
-    try:
-        from functools import cached_property as cached_property
-    except ImportError:
-        from cached_property import cached_property as cached_property
+    from functools import cached_property as cached_property
 
     typed_cached_property = cached_property
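The `model_dump` change gates the `warnings` keyword on the Pydantic major version, since Pydantic v1's serializers don't accept it. A rough standalone illustration of the same gating (the `dump_model` helper and version check are illustrative, not the SDK's actual `_compat` code):

```python
from typing import Any

import pydantic

# Both Pydantic v1 and v2 expose a VERSION string such as "1.10.13" or "2.9.2".
PYDANTIC_V2 = pydantic.VERSION.startswith("2.")


def dump_model(model: pydantic.BaseModel, *, warnings: bool = True) -> dict[str, Any]:
    if PYDANTIC_V2:
        # Only Pydantic v2 understands the `warnings` keyword.
        return model.model_dump(warnings=warnings)
    # Pydantic v1: `.dict()` is the serialization API and takes no `warnings`.
    return model.dict()
```

The second hunk drops the fallback import of the third-party `cached_property` package, which was only needed on Python versions before 3.8; `functools.cached_property` has been in the standard library since 3.8.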
tests/test_client.py (158 changes: 125 additions & 33 deletions)

@@ -675,12 +675,25 @@ def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str
     @mock.patch("llama_stack_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
     @pytest.mark.respx(base_url=base_url)
     def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
-        respx_mock.post("/alpha/models/register").mock(side_effect=httpx.TimeoutException("Test timeout error"))
+        respx_mock.post("/alpha/inference/chat-completion").mock(
+            side_effect=httpx.TimeoutException("Test timeout error")
+        )
 
         with pytest.raises(APITimeoutError):
             self.client.post(
-                "/alpha/models/register",
-                body=cast(object, dict(model_id="model_id")),
+                "/alpha/inference/chat-completion",
+                body=cast(
+                    object,
+                    dict(
+                        messages=[
+                            {
+                                "content": "string",
+                                "role": "user",
+                            }
+                        ],
+                        model_id="model_id",
+                    ),
+                ),
                 cast_to=httpx.Response,
                 options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
             )
@@ -690,12 +703,23 @@ def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
     @mock.patch("llama_stack_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
     @pytest.mark.respx(base_url=base_url)
     def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
-        respx_mock.post("/alpha/models/register").mock(return_value=httpx.Response(500))
+        respx_mock.post("/alpha/inference/chat-completion").mock(return_value=httpx.Response(500))
 
         with pytest.raises(APIStatusError):
             self.client.post(
-                "/alpha/models/register",
-                body=cast(object, dict(model_id="model_id")),
+                "/alpha/inference/chat-completion",
+                body=cast(
+                    object,
+                    dict(
+                        messages=[
+                            {
+                                "content": "string",
+                                "role": "user",
+                            }
+                        ],
+                        model_id="model_id",
+                    ),
+                ),
                 cast_to=httpx.Response,
                 options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
             )
@@ -726,9 +750,17 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
                 return httpx.Response(500)
             return httpx.Response(200)
 
-        respx_mock.post("/alpha/models/register").mock(side_effect=retry_handler)
+        respx_mock.post("/alpha/inference/chat-completion").mock(side_effect=retry_handler)
 
-        response = client.models.with_raw_response.register(model_id="model_id")
+        response = client.inference.with_raw_response.chat_completion(
+            messages=[
+                {
+                    "content": "string",
+                    "role": "user",
+                }
+            ],
+            model_id="model_id",
+        )
 
         assert response.retries_taken == failures_before_success
         assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
@@ -750,10 +782,17 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
                 return httpx.Response(500)
             return httpx.Response(200)
 
-        respx_mock.post("/alpha/models/register").mock(side_effect=retry_handler)
-
-        response = client.models.with_raw_response.register(
-            model_id="model_id", extra_headers={"x-stainless-retry-count": Omit()}
+        respx_mock.post("/alpha/inference/chat-completion").mock(side_effect=retry_handler)
+
+        response = client.inference.with_raw_response.chat_completion(
+            messages=[
+                {
+                    "content": "string",
+                    "role": "user",
+                }
+            ],
+            model_id="model_id",
+            extra_headers={"x-stainless-retry-count": Omit()},
         )
 
         assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0
@@ -775,10 +814,17 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
                 return httpx.Response(500)
             return httpx.Response(200)
 
-        respx_mock.post("/alpha/models/register").mock(side_effect=retry_handler)
-
-        response = client.models.with_raw_response.register(
-            model_id="model_id", extra_headers={"x-stainless-retry-count": "42"}
+        respx_mock.post("/alpha/inference/chat-completion").mock(side_effect=retry_handler)
+
+        response = client.inference.with_raw_response.chat_completion(
+            messages=[
+                {
+                    "content": "string",
+                    "role": "user",
+                }
+            ],
+            model_id="model_id",
+            extra_headers={"x-stainless-retry-count": "42"},
        )
 
         assert response.http_request.headers.get("x-stainless-retry-count") == "42"
@@ -1416,12 +1462,25 @@ async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str
     @mock.patch("llama_stack_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
     @pytest.mark.respx(base_url=base_url)
     async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
-        respx_mock.post("/alpha/models/register").mock(side_effect=httpx.TimeoutException("Test timeout error"))
+        respx_mock.post("/alpha/inference/chat-completion").mock(
+            side_effect=httpx.TimeoutException("Test timeout error")
+        )
 
         with pytest.raises(APITimeoutError):
             await self.client.post(
-                "/alpha/models/register",
-                body=cast(object, dict(model_id="model_id")),
+                "/alpha/inference/chat-completion",
+                body=cast(
+                    object,
+                    dict(
+                        messages=[
+                            {
+                                "content": "string",
+                                "role": "user",
+                            }
+                        ],
+                        model_id="model_id",
+                    ),
+                ),
                 cast_to=httpx.Response,
                 options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
             )
@@ -1431,12 +1490,23 @@ async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
     @mock.patch("llama_stack_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
     @pytest.mark.respx(base_url=base_url)
     async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
-        respx_mock.post("/alpha/models/register").mock(return_value=httpx.Response(500))
+        respx_mock.post("/alpha/inference/chat-completion").mock(return_value=httpx.Response(500))
 
         with pytest.raises(APIStatusError):
             await self.client.post(
-                "/alpha/models/register",
-                body=cast(object, dict(model_id="model_id")),
+                "/alpha/inference/chat-completion",
+                body=cast(
+                    object,
+                    dict(
+                        messages=[
+                            {
+                                "content": "string",
+                                "role": "user",
+                            }
+                        ],
+                        model_id="model_id",
+                    ),
+                ),
                 cast_to=httpx.Response,
                 options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
             )
@@ -1468,9 +1538,17 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
                 return httpx.Response(500)
             return httpx.Response(200)
 
-        respx_mock.post("/alpha/models/register").mock(side_effect=retry_handler)
+        respx_mock.post("/alpha/inference/chat-completion").mock(side_effect=retry_handler)
 
-        response = await client.models.with_raw_response.register(model_id="model_id")
+        response = await client.inference.with_raw_response.chat_completion(
+            messages=[
+                {
+                    "content": "string",
+                    "role": "user",
+                }
+            ],
+            model_id="model_id",
+        )
 
         assert response.retries_taken == failures_before_success
         assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
@@ -1493,10 +1571,17 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
                 return httpx.Response(500)
             return httpx.Response(200)
 
-        respx_mock.post("/alpha/models/register").mock(side_effect=retry_handler)
-
-        response = await client.models.with_raw_response.register(
-            model_id="model_id", extra_headers={"x-stainless-retry-count": Omit()}
+        respx_mock.post("/alpha/inference/chat-completion").mock(side_effect=retry_handler)
+
+        response = await client.inference.with_raw_response.chat_completion(
+            messages=[
+                {
+                    "content": "string",
+                    "role": "user",
+                }
+            ],
+            model_id="model_id",
+            extra_headers={"x-stainless-retry-count": Omit()},
         )
 
         assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0
@@ -1519,10 +1604,17 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
                 return httpx.Response(500)
             return httpx.Response(200)
 
-        respx_mock.post("/alpha/models/register").mock(side_effect=retry_handler)
-
-        response = await client.models.with_raw_response.register(
-            model_id="model_id", extra_headers={"x-stainless-retry-count": "42"}
+        respx_mock.post("/alpha/inference/chat-completion").mock(side_effect=retry_handler)
+
+        response = await client.inference.with_raw_response.chat_completion(
+            messages=[
+                {
+                    "content": "string",
+                    "role": "user",
+                }
+            ],
+            model_id="model_id",
+            extra_headers={"x-stainless-retry-count": "42"},
         )
 
         assert response.http_request.headers.get("x-stainless-retry-count") == "42"
@@ -1539,7 +1631,7 @@ def test_get_platform(self) -> None:
         import threading
         from llama_stack_client._utils import asyncify
         from llama_stack_client._base_client import get_platform
 
         async def test_main() -> None:
             result = await asyncify(get_platform)()
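These tests all share one moving part: a `retry_handler` that fails a fixed number of times before returning 200, so the test can assert both `retries_taken` and the `x-stainless-retry-count` header the client stamps on each attempt. A self-contained sketch of that mechanic with plain `httpx` plus `respx` (the explicit retry loop is a simplification standing in for the SDK's real backoff logic, and `example.test` is a placeholder host):

```python
import httpx
import respx


@respx.mock
def test_retries_until_success() -> None:
    failures_before_success = 2
    nb_retries = 0

    def retry_handler(_request: httpx.Request) -> httpx.Response:
        # Fail the first N attempts, then succeed, mirroring the tests above.
        nonlocal nb_retries
        if nb_retries < failures_before_success:
            nb_retries += 1
            return httpx.Response(500)
        return httpx.Response(200)

    respx.post("https://example.test/alpha/inference/chat-completion").mock(side_effect=retry_handler)

    # A simplified stand-in for the SDK's internal retry loop.
    with httpx.Client(base_url="https://example.test") as http_client:
        for attempt in range(5):
            response = http_client.post(
                "/alpha/inference/chat-completion",
                headers={"x-stainless-retry-count": str(attempt)},
                json={"messages": [{"content": "string", "role": "user"}], "model_id": "model_id"},
            )
            if response.status_code == 200:
                break

    assert nb_retries == failures_before_success
```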
tests/test_models.py (8 changes: 8 additions & 0 deletions)

@@ -561,6 +561,14 @@ class Model(BaseModel):
     m.model_dump(warnings=False)
 
 
+def test_compat_method_no_error_for_warnings() -> None:
+    class Model(BaseModel):
+        foo: Optional[str]
+
+    m = Model(foo="hello")
+    assert isinstance(model_dump(m, warnings=False), dict)
+
+
 def test_to_json() -> None:
     class Model(BaseModel):
         foo: Optional[str] = Field(alias="FOO", default=None)
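This new test pairs with the `_compat.py` change above: calling `model_dump(m, warnings=False)` must not raise under Pydantic v1, which has no `warnings` parameter, so the shim only forwards the flag on v2.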
