diff --git a/src/llama_stack_client/_base_client.py b/src/llama_stack_client/_base_client.py
index 9640b52..4f4afde 100644
--- a/src/llama_stack_client/_base_client.py
+++ b/src/llama_stack_client/_base_client.py
@@ -792,6 +792,7 @@ def __init__(
         custom_query: Mapping[str, object] | None = None,
         _strict_response_validation: bool,
     ) -> None:
+        kwargs: dict[str, Any] = {}
         if limits is not None:
             warnings.warn(
                 "The `connection_pool_limits` argument is deprecated. The `http_client` argument should be passed instead",
@@ -804,6 +805,7 @@
             limits = DEFAULT_CONNECTION_LIMITS
 
         if transport is not None:
+            kwargs["transport"] = transport
             warnings.warn(
                 "The `transport` argument is deprecated. The `http_client` argument should be passed instead",
                 category=DeprecationWarning,
@@ -813,6 +815,7 @@
                 raise ValueError("The `http_client` argument is mutually exclusive with `transport`")
 
         if proxies is not None:
+            kwargs["proxies"] = proxies
             warnings.warn(
                 "The `proxies` argument is deprecated. The `http_client` argument should be passed instead",
                 category=DeprecationWarning,
@@ -856,10 +859,9 @@
             base_url=base_url,
             # cast to a valid type because mypy doesn't understand our type narrowing
             timeout=cast(Timeout, timeout),
-            proxies=proxies,
-            transport=transport,
             limits=limits,
             follow_redirects=True,
+            **kwargs,  # type: ignore
         )
 
     def is_closed(self) -> bool:
@@ -1358,6 +1360,7 @@
         custom_headers: Mapping[str, str] | None = None,
         custom_query: Mapping[str, object] | None = None,
     ) -> None:
+        kwargs: dict[str, Any] = {}
         if limits is not None:
             warnings.warn(
                 "The `connection_pool_limits` argument is deprecated. The `http_client` argument should be passed instead",
@@ -1370,6 +1373,7 @@
             limits = DEFAULT_CONNECTION_LIMITS
 
         if transport is not None:
+            kwargs["transport"] = transport
             warnings.warn(
                 "The `transport` argument is deprecated. The `http_client` argument should be passed instead",
                 category=DeprecationWarning,
@@ -1379,6 +1383,7 @@
                 raise ValueError("The `http_client` argument is mutually exclusive with `transport`")
 
         if proxies is not None:
+            kwargs["proxies"] = proxies
             warnings.warn(
                 "The `proxies` argument is deprecated. The `http_client` argument should be passed instead",
                 category=DeprecationWarning,
@@ -1422,10 +1427,9 @@
             base_url=base_url,
             # cast to a valid type because mypy doesn't understand our type narrowing
             timeout=cast(Timeout, timeout),
-            proxies=proxies,
-            transport=transport,
             limits=limits,
             follow_redirects=True,
+            **kwargs,  # type: ignore
         )
 
     def is_closed(self) -> bool:
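Note on the `_base_client.py` hunks: instead of always passing `proxies=`/`transport=` to the httpx constructor, the deprecated arguments are collected into a `kwargs` dict and splatted, so they only reach `httpx.Client` when the caller actually supplied them. A minimal sketch of the pattern, assuming (as the diff suggests) the goal is to stop forwarding explicit `None` values for keywords that newer httpx releases deprecate or reject; `build_client` is a hypothetical stand-in for the SDK's `__init__`:

    from typing import Any

    import httpx

    def build_client(proxies: Any = None, transport: Any = None) -> httpx.Client:
        # Collect deprecated arguments only when set, so an unset argument
        # never reaches the httpx.Client constructor at all.
        kwargs: dict[str, Any] = {}
        if transport is not None:
            kwargs["transport"] = transport
        if proxies is not None:
            kwargs["proxies"] = proxies  # deprecated in httpx, hence the gate
        return httpx.Client(follow_redirects=True, **kwargs)
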
diff --git a/src/llama_stack_client/_compat.py b/src/llama_stack_client/_compat.py
index 4794129..92d9ee6 100644
--- a/src/llama_stack_client/_compat.py
+++ b/src/llama_stack_client/_compat.py
@@ -145,7 +145,8 @@ def model_dump(
             exclude=exclude,
             exclude_unset=exclude_unset,
             exclude_defaults=exclude_defaults,
-            warnings=warnings,
+            # warnings are not supported in Pydantic v1
+            warnings=warnings if PYDANTIC_V2 else True,
         )
     return cast(
         "dict[str, Any]",
@@ -213,9 +214,6 @@ def __set_name__(self, owner: type[Any], name: str) -> None: ...
         # __set__ is not defined at runtime, but @cached_property is designed to be settable
         def __set__(self, instance: object, value: _T) -> None: ...
 else:
-    try:
-        from functools import cached_property as cached_property
-    except ImportError:
-        from cached_property import cached_property as cached_property
+    from functools import cached_property as cached_property
 
     typed_cached_property = cached_property
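Note on the `_compat.py` hunks: `model_dump` now forwards the `warnings=` flag only under Pydantic v2, since the v1 serialization path has no such parameter, and the `cached_property` import drops its third-party fallback because `functools.cached_property` has been in the stdlib since Python 3.8. A rough sketch of the version-gated idea, assuming a `PYDANTIC_V2` flag like the one in `_compat.py` (not the library's exact helper):

    from typing import Any

    import pydantic

    PYDANTIC_V2 = pydantic.VERSION.startswith("2.")

    def dump(model: pydantic.BaseModel, *, warnings: bool = True) -> dict[str, Any]:
        if PYDANTIC_V2:
            # Pydantic v2's model_dump accepts a `warnings` flag.
            return model.model_dump(warnings=warnings)
        # Pydantic v1 has no equivalent flag on .dict(), so it is dropped.
        return model.dict()
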
diff --git a/tests/test_client.py b/tests/test_client.py
index c1f8496..6d92271 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -675,12 +675,25 @@ def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str
     @mock.patch("llama_stack_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
     @pytest.mark.respx(base_url=base_url)
     def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
-        respx_mock.post("/alpha/models/register").mock(side_effect=httpx.TimeoutException("Test timeout error"))
+        respx_mock.post("/alpha/inference/chat-completion").mock(
+            side_effect=httpx.TimeoutException("Test timeout error")
+        )
 
         with pytest.raises(APITimeoutError):
             self.client.post(
-                "/alpha/models/register",
-                body=cast(object, dict(model_id="model_id")),
+                "/alpha/inference/chat-completion",
+                body=cast(
+                    object,
+                    dict(
+                        messages=[
+                            {
+                                "content": "string",
+                                "role": "user",
+                            }
+                        ],
+                        model_id="model_id",
+                    ),
+                ),
                 cast_to=httpx.Response,
                 options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
             )
@@ -690,12 +703,23 @@ def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> No
     @mock.patch("llama_stack_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
     @pytest.mark.respx(base_url=base_url)
     def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
-        respx_mock.post("/alpha/models/register").mock(return_value=httpx.Response(500))
+        respx_mock.post("/alpha/inference/chat-completion").mock(return_value=httpx.Response(500))
 
         with pytest.raises(APIStatusError):
             self.client.post(
-                "/alpha/models/register",
-                body=cast(object, dict(model_id="model_id")),
+                "/alpha/inference/chat-completion",
+                body=cast(
+                    object,
+                    dict(
+                        messages=[
+                            {
+                                "content": "string",
+                                "role": "user",
+                            }
+                        ],
+                        model_id="model_id",
+                    ),
+                ),
                 cast_to=httpx.Response,
                 options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
             )
@@ -726,9 +750,17 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
                 return httpx.Response(500)
             return httpx.Response(200)
 
-        respx_mock.post("/alpha/models/register").mock(side_effect=retry_handler)
+        respx_mock.post("/alpha/inference/chat-completion").mock(side_effect=retry_handler)
 
-        response = client.models.with_raw_response.register(model_id="model_id")
+        response = client.inference.with_raw_response.chat_completion(
+            messages=[
+                {
+                    "content": "string",
+                    "role": "user",
+                }
+            ],
+            model_id="model_id",
+        )
 
         assert response.retries_taken == failures_before_success
         assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
@@ -750,10 +782,17 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
                 return httpx.Response(500)
             return httpx.Response(200)
 
-        respx_mock.post("/alpha/models/register").mock(side_effect=retry_handler)
-
-        response = client.models.with_raw_response.register(
-            model_id="model_id", extra_headers={"x-stainless-retry-count": Omit()}
+        respx_mock.post("/alpha/inference/chat-completion").mock(side_effect=retry_handler)
+
+        response = client.inference.with_raw_response.chat_completion(
+            messages=[
+                {
+                    "content": "string",
+                    "role": "user",
+                }
+            ],
+            model_id="model_id",
+            extra_headers={"x-stainless-retry-count": Omit()},
         )
 
         assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0
@@ -775,10 +814,17 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
                 return httpx.Response(500)
             return httpx.Response(200)
 
-        respx_mock.post("/alpha/models/register").mock(side_effect=retry_handler)
-
-        response = client.models.with_raw_response.register(
-            model_id="model_id", extra_headers={"x-stainless-retry-count": "42"}
+        respx_mock.post("/alpha/inference/chat-completion").mock(side_effect=retry_handler)
+
+        response = client.inference.with_raw_response.chat_completion(
+            messages=[
+                {
+                    "content": "string",
+                    "role": "user",
+                }
+            ],
+            model_id="model_id",
+            extra_headers={"x-stainless-retry-count": "42"},
        )
 
         assert response.http_request.headers.get("x-stainless-retry-count") == "42"
@@ -1416,12 +1462,25 @@ async def test_parse_retry_after_header(self, remaining_retries: int, retry_afte
     @mock.patch("llama_stack_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
     @pytest.mark.respx(base_url=base_url)
     async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
-        respx_mock.post("/alpha/models/register").mock(side_effect=httpx.TimeoutException("Test timeout error"))
+        respx_mock.post("/alpha/inference/chat-completion").mock(
+            side_effect=httpx.TimeoutException("Test timeout error")
+        )
 
         with pytest.raises(APITimeoutError):
             await self.client.post(
-                "/alpha/models/register",
-                body=cast(object, dict(model_id="model_id")),
+                "/alpha/inference/chat-completion",
+                body=cast(
+                    object,
+                    dict(
+                        messages=[
+                            {
+                                "content": "string",
+                                "role": "user",
+                            }
+                        ],
+                        model_id="model_id",
+                    ),
+                ),
                 cast_to=httpx.Response,
                 options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
             )
@@ -1431,12 +1490,23 @@ async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter)
     @mock.patch("llama_stack_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
     @pytest.mark.respx(base_url=base_url)
     async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
-        respx_mock.post("/alpha/models/register").mock(return_value=httpx.Response(500))
+        respx_mock.post("/alpha/inference/chat-completion").mock(return_value=httpx.Response(500))
 
         with pytest.raises(APIStatusError):
             await self.client.post(
-                "/alpha/models/register",
-                body=cast(object, dict(model_id="model_id")),
+                "/alpha/inference/chat-completion",
+                body=cast(
+                    object,
+                    dict(
+                        messages=[
+                            {
+                                "content": "string",
+                                "role": "user",
+                            }
+                        ],
+                        model_id="model_id",
+                    ),
+                ),
                 cast_to=httpx.Response,
                 options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
             )
@@ -1468,9 +1538,17 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
                 return httpx.Response(500)
             return httpx.Response(200)
 
-        respx_mock.post("/alpha/models/register").mock(side_effect=retry_handler)
+        respx_mock.post("/alpha/inference/chat-completion").mock(side_effect=retry_handler)
 
-        response = await client.models.with_raw_response.register(model_id="model_id")
+        response = await client.inference.with_raw_response.chat_completion(
+            messages=[
+                {
+                    "content": "string",
+                    "role": "user",
+                }
+            ],
+            model_id="model_id",
+        )
 
         assert response.retries_taken == failures_before_success
         assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
@@ -1493,10 +1571,17 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
                 return httpx.Response(500)
             return httpx.Response(200)
 
-        respx_mock.post("/alpha/models/register").mock(side_effect=retry_handler)
-
-        response = await client.models.with_raw_response.register(
-            model_id="model_id", extra_headers={"x-stainless-retry-count": Omit()}
+        respx_mock.post("/alpha/inference/chat-completion").mock(side_effect=retry_handler)
+
+        response = await client.inference.with_raw_response.chat_completion(
+            messages=[
+                {
+                    "content": "string",
+                    "role": "user",
+                }
+            ],
+            model_id="model_id",
+            extra_headers={"x-stainless-retry-count": Omit()},
         )
 
         assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0
@@ -1519,10 +1604,17 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
                 return httpx.Response(500)
             return httpx.Response(200)
 
-        respx_mock.post("/alpha/models/register").mock(side_effect=retry_handler)
-
-        response = await client.models.with_raw_response.register(
-            model_id="model_id", extra_headers={"x-stainless-retry-count": "42"}
+        respx_mock.post("/alpha/inference/chat-completion").mock(side_effect=retry_handler)
+
+        response = await client.inference.with_raw_response.chat_completion(
+            messages=[
+                {
+                    "content": "string",
+                    "role": "user",
+                }
+            ],
+            model_id="model_id",
+            extra_headers={"x-stainless-retry-count": "42"},
         )
 
         assert response.http_request.headers.get("x-stainless-retry-count") == "42"
@@ -1539,7 +1631,7 @@ def test_get_platform(self) -> None:
         import threading
 
         from llama_stack_client._utils import asyncify
-        from llama_stack_client._base_client import get_platform 
+        from llama_stack_client._base_client import get_platform
 
         async def test_main() -> None:
             result = await asyncify(get_platform)()
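Note on the `tests/test_client.py` hunks: the mocked endpoint moves from `/alpha/models/register` to `/alpha/inference/chat-completion`, which requires the fuller request body (`messages` plus `model_id`). All of the retry tests share one respx idiom: a side-effect handler that returns 500 for the first N calls and 200 afterwards. A condensed, self-contained version of that idiom, with a plain `httpx.Client` loop standing in for the SDK's automatic retry logic; the URL and counts are illustrative only:

    import httpx
    import respx

    @respx.mock
    def test_retry_handler_pattern() -> None:
        attempts = {"n": 0}

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            attempts["n"] += 1
            # Fail the first two calls, succeed on the third.
            return httpx.Response(500 if attempts["n"] <= 2 else 200)

        route = respx.post("https://example.test/chat").mock(side_effect=retry_handler)

        with httpx.Client() as client:
            for _ in range(3):  # stand-in for the SDK's automatic retries
                resp = client.post("https://example.test/chat", json={"model_id": "model_id"})
                if resp.status_code < 500:
                    break

        assert resp.status_code == 200
        assert route.call_count == 3
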
diff --git a/tests/test_models.py b/tests/test_models.py
index 59458bf..4a98774 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -561,6 +561,14 @@ class Model(BaseModel):
     m.model_dump(warnings=False)
 
 
+def test_compat_method_no_error_for_warnings() -> None:
+    class Model(BaseModel):
+        foo: Optional[str]
+
+    m = Model(foo="hello")
+    assert isinstance(model_dump(m, warnings=False), dict)
+
+
 def test_to_json() -> None:
     class Model(BaseModel):
         foo: Optional[str] = Field(alias="FOO", default=None)
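Note on the `tests/test_models.py` hunk: the new `test_compat_method_no_error_for_warnings` pins down the `_compat.model_dump` fix above. For context, a hedged repro of the failure mode it guards against, using plain Pydantic with no SDK imports; the version check is illustrative:

    import pydantic

    class M(pydantic.BaseModel):
        foo: str

    m = M(foo="hello")
    if pydantic.VERSION.startswith("1."):
        # v1's .dict() has no `warnings` parameter, so forwarding it raises.
        try:
            m.dict(warnings=False)  # type: ignore[call-arg]
        except TypeError:
            print("Pydantic v1 rejects warnings=, hence the compat gate")
    else:
        assert isinstance(m.model_dump(warnings=False), dict)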