diff --git a/.stats.yml b/.stats.yml
index 68a0a53..4517049 100644
--- a/.stats.yml
+++ b/.stats.yml
@@ -1,2 +1,2 @@
-configured_endpoints: 55
-openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/meta%2Fllama-stack-a836f3ef44b0852623ef583d04501ee7fcfc4a72af8b4b54cc9d3b415ed038aa.yml
+configured_endpoints: 51
+openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/meta%2Fllama-stack-d52e4c19360cc636336d6a60ba6af1db89736fc0a3025c2b1d11870a5f1a1e3d.yml
diff --git a/api.md b/api.md
index 5561fa2..636407d 100644
--- a/api.md
+++ b/api.md
@@ -32,14 +32,10 @@ Types:
```python
from llama_stack.types import (
- CustomQueryGeneratorConfig,
- DefaultQueryGeneratorConfig,
InferenceStep,
- LlmQueryGeneratorConfig,
MemoryRetrievalStep,
RestAPIExecutionConfig,
ShieldCallStep,
- ShieldDefinition,
ToolExecutionStep,
ToolParamDefinition,
AgentCreateResponse,
@@ -196,49 +192,49 @@ Methods:
Types:
```python
-from llama_stack.types import SheidResponse, SafetyRunShieldsResponse
+from llama_stack.types import RunShieldResponse
```
Methods:
-- client.safety.run_shields(\*\*params) -> SafetyRunShieldsResponse
+- client.safety.run_shield(\*\*params) -> RunShieldResponse
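
The safety endpoint is now singular: one shield per call instead of a batch. A hedged usage sketch of the renamed method; the `base_url`, shield identifier, and message payload are illustrative assumptions, not taken from the spec:

```python
from llama_stack import LlamaStack

client = LlamaStack(base_url="http://localhost:5000")  # hypothetical server

# formerly: client.safety.run_shields(...) with a list of shields
response = client.safety.run_shield(
    shield_type="llama_guard",  # hypothetical shield identifier
    messages=[{"role": "user", "content": "How do I bake bread?"}],
)
print(response)
```
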
-# MemoryBanks
+# Memory
Types:
```python
from llama_stack.types import (
QueryDocuments,
- MemoryBankCreateResponse,
- MemoryBankRetrieveResponse,
- MemoryBankListResponse,
- MemoryBankDropResponse,
+ MemoryCreateResponse,
+ MemoryRetrieveResponse,
+ MemoryListResponse,
+ MemoryDropResponse,
)
```
Methods:
-- client.memory_banks.create(\*\*params) -> object
-- client.memory_banks.retrieve(\*\*params) -> object
-- client.memory_banks.update(\*\*params) -> None
-- client.memory_banks.list() -> object
-- client.memory_banks.drop(\*\*params) -> str
-- client.memory_banks.insert(\*\*params) -> None
-- client.memory_banks.query(\*\*params) -> QueryDocuments
+- client.memory.create(\*\*params) -> object
+- client.memory.retrieve(\*\*params) -> object
+- client.memory.update(\*\*params) -> None
+- client.memory.list() -> object
+- client.memory.drop(\*\*params) -> str
+- client.memory.insert(\*\*params) -> None
+- client.memory.query(\*\*params) -> QueryDocuments
## Documents
Types:
```python
-from llama_stack.types.memory_banks import DocumentRetrieveResponse
+from llama_stack.types.memory import DocumentRetrieveResponse
```
Methods:
-- client.memory_banks.documents.retrieve(\*\*params) -> DocumentRetrieveResponse
-- client.memory_banks.documents.delete(\*\*params) -> None
+- client.memory.documents.retrieve(\*\*params) -> DocumentRetrieveResponse
+- client.memory.documents.delete(\*\*params) -> None
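
Call sites move from `client.memory_banks.*` to `client.memory.*` with the same method surface, including the nested `documents` sub-resource. A hedged sketch against a running server; the bank and document payloads are illustrative assumptions:

```python
from llama_stack import LlamaStack

client = LlamaStack(base_url="http://localhost:5000")  # hypothetical server

# formerly client.memory_banks.insert / client.memory_banks.query
client.memory.insert(
    bank_id="docs",  # hypothetical bank
    documents=[{"document_id": "d1", "content": "hello world"}],
)
hits = client.memory.query(bank_id="docs", query="hello")

# the nested documents resource moves the same way
doc = client.memory.documents.retrieve(bank_id="docs", document_ids=["d1"])
```
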
# PostTraining
@@ -309,3 +305,42 @@ Methods:
- client.batch_inference.chat_completion(\*\*params) -> BatchChatCompletion
- client.batch_inference.completion(\*\*params) -> BatchCompletion
+
+# Models
+
+Types:
+
+```python
+from llama_stack.types import ModelServingSpec
+```
+
+Methods:
+
+- client.models.list() -> ModelServingSpec
+- client.models.get(\*\*params) -> Optional[ModelServingSpec]
+
+# MemoryBanks
+
+Types:
+
+```python
+from llama_stack.types import MemoryBankSpec
+```
+
+Methods:
+
+- client.memory_banks.list() -> MemoryBankSpec
+- client.memory_banks.get(\*\*params) -> Optional[MemoryBankSpec]
+
+# Shields
+
+Types:
+
+```python
+from llama_stack.types import ShieldSpec
+```
+
+Methods:
+
+- client.shields.list() -> ShieldSpec
+- client.shields.get(\*\*params) -> Optional[ShieldSpec]
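
All three new top-level resources share a read-only registry shape: `list()` enumerates the serving specs and `get(\*\*params)` returns a single spec, or `None` when nothing matches. A hedged sketch; the identifier parameter name is an assumption:

```python
from llama_stack import LlamaStack

client = LlamaStack(base_url="http://localhost:5000")  # hypothetical server

print(client.models.list())        # ModelServingSpec
print(client.memory_banks.list())  # MemoryBankSpec
print(client.shields.list())       # ShieldSpec

model = client.models.get(core_model_id="Meta-Llama3.1-8B-Instruct")  # assumed param name
if model is None:
    print("model is not being served by this stack")
```
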
diff --git a/src/llama_stack/_base_client.py b/src/llama_stack/_base_client.py
index f6e5f9a..25a1c43 100644
--- a/src/llama_stack/_base_client.py
+++ b/src/llama_stack/_base_client.py
@@ -400,14 +400,7 @@ def _make_status_error(
) -> _exceptions.APIStatusError:
raise NotImplementedError()
- def _remaining_retries(
- self,
- remaining_retries: Optional[int],
- options: FinalRequestOptions,
- ) -> int:
- return remaining_retries if remaining_retries is not None else options.get_max_retries(self.max_retries)
-
- def _build_headers(self, options: FinalRequestOptions) -> httpx.Headers:
+ def _build_headers(self, options: FinalRequestOptions, *, retries_taken: int = 0) -> httpx.Headers:
custom_headers = options.headers or {}
headers_dict = _merge_mappings(self.default_headers, custom_headers)
self._validate_headers(headers_dict, custom_headers)
@@ -419,6 +412,8 @@ def _build_headers(self, options: FinalRequestOptions) -> httpx.Headers:
if idempotency_header and options.method.lower() != "get" and idempotency_header not in headers:
headers[idempotency_header] = options.idempotency_key or self._idempotency_key()
+ headers.setdefault("x-stainless-retry-count", str(retries_taken))
+
return headers
def _prepare_url(self, url: str) -> URL:
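
Because the new header is added with `setdefault`, a retry count supplied by the caller through `extra_headers` still takes precedence over the computed default. A minimal sketch of that merge behavior using plain httpx:

```python
import httpx

# caller already set the header: setdefault leaves it alone
headers = httpx.Headers({"x-stainless-retry-count": "7"})
headers.setdefault("x-stainless-retry-count", str(0))
assert headers["x-stainless-retry-count"] == "7"

# header absent: the computed default is applied
headers = httpx.Headers()
headers.setdefault("x-stainless-retry-count", str(2))
assert headers["x-stainless-retry-count"] == "2"
```
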
@@ -440,6 +435,8 @@ def _make_sse_decoder(self) -> SSEDecoder | SSEBytesDecoder:
def _build_request(
self,
options: FinalRequestOptions,
+ *,
+ retries_taken: int = 0,
) -> httpx.Request:
if log.isEnabledFor(logging.DEBUG):
log.debug("Request options: %s", model_dump(options, exclude_unset=True))
@@ -455,7 +452,7 @@ def _build_request(
else:
raise RuntimeError(f"Unexpected JSON data type, {type(json_data)}, cannot merge with `extra_body`")
- headers = self._build_headers(options)
+ headers = self._build_headers(options, retries_taken=retries_taken)
params = _merge_mappings(self.default_query, options.params)
content_type = headers.get("Content-Type")
files = options.files
@@ -489,12 +486,17 @@ def _build_request(
if not files:
files = cast(HttpxRequestFiles, ForceMultipartDict())
+ prepared_url = self._prepare_url(options.url)
+ if "_" in prepared_url.host:
+ # work around https://github.com/encode/httpx/discussions/2880
+ kwargs["extensions"] = {"sni_hostname": prepared_url.host.replace("_", "-")}
+
# TODO: report this error to httpx
return self._client.build_request( # pyright: ignore[reportUnknownMemberType]
headers=headers,
timeout=self.timeout if isinstance(options.timeout, NotGiven) else options.timeout,
method=options.method,
- url=self._prepare_url(options.url),
+ url=prepared_url,
# the `Query` type that we use is incompatible with qs'
# `Params` type as it needs to be typed as `Mapping[str, object]`
# so that passing a `TypedDict` doesn't cause an error.
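
Hostnames containing underscores trip httpx's SNI handling (see the linked discussion), so the workaround sends a dash-substituted name via the `sni_hostname` extension while leaving the request URL untouched. A standalone sketch of the same idea, with a hypothetical host:

```python
import httpx

url = httpx.URL("https://my_stack.example.com/models/list")  # hypothetical host

extensions = {}
if "_" in url.host:
    # only the TLS handshake sees the substituted name; the URL is unchanged
    extensions["sni_hostname"] = url.host.replace("_", "-")

request = httpx.Client().build_request("GET", url, extensions=extensions)
```
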
@@ -933,12 +935,17 @@ def request(
stream: bool = False,
stream_cls: type[_StreamT] | None = None,
) -> ResponseT | _StreamT:
+ if remaining_retries is not None:
+ retries_taken = options.get_max_retries(self.max_retries) - remaining_retries
+ else:
+ retries_taken = 0
+
return self._request(
cast_to=cast_to,
options=options,
stream=stream,
stream_cls=stream_cls,
- remaining_retries=remaining_retries,
+ retries_taken=retries_taken,
)
def _request(
@@ -946,7 +953,7 @@ def _request(
*,
cast_to: Type[ResponseT],
options: FinalRequestOptions,
- remaining_retries: int | None,
+ retries_taken: int,
stream: bool,
stream_cls: type[_StreamT] | None,
) -> ResponseT | _StreamT:
@@ -958,8 +965,8 @@ def _request(
cast_to = self._maybe_override_cast_to(cast_to, options)
options = self._prepare_options(options)
- retries = self._remaining_retries(remaining_retries, options)
- request = self._build_request(options)
+ remaining_retries = options.get_max_retries(self.max_retries) - retries_taken
+ request = self._build_request(options, retries_taken=retries_taken)
self._prepare_request(request)
kwargs: HttpxSendArgs = {}
@@ -977,11 +984,11 @@ def _request(
except httpx.TimeoutException as err:
log.debug("Encountered httpx.TimeoutException", exc_info=True)
- if retries > 0:
+ if remaining_retries > 0:
return self._retry_request(
input_options,
cast_to,
- retries,
+ retries_taken=retries_taken,
stream=stream,
stream_cls=stream_cls,
response_headers=None,
@@ -992,11 +999,11 @@ def _request(
except Exception as err:
log.debug("Encountered Exception", exc_info=True)
- if retries > 0:
+ if remaining_retries > 0:
return self._retry_request(
input_options,
cast_to,
- retries,
+ retries_taken=retries_taken,
stream=stream,
stream_cls=stream_cls,
response_headers=None,
@@ -1019,13 +1026,13 @@ def _request(
except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code
log.debug("Encountered httpx.HTTPStatusError", exc_info=True)
- if retries > 0 and self._should_retry(err.response):
+ if remaining_retries > 0 and self._should_retry(err.response):
err.response.close()
return self._retry_request(
input_options,
cast_to,
- retries,
- err.response.headers,
+ retries_taken=retries_taken,
+ response_headers=err.response.headers,
stream=stream,
stream_cls=stream_cls,
)
@@ -1044,26 +1051,26 @@ def _request(
response=response,
stream=stream,
stream_cls=stream_cls,
- retries_taken=options.get_max_retries(self.max_retries) - retries,
+ retries_taken=retries_taken,
)
def _retry_request(
self,
options: FinalRequestOptions,
cast_to: Type[ResponseT],
- remaining_retries: int,
- response_headers: httpx.Headers | None,
*,
+ retries_taken: int,
+ response_headers: httpx.Headers | None,
stream: bool,
stream_cls: type[_StreamT] | None,
) -> ResponseT | _StreamT:
- remaining = remaining_retries - 1
- if remaining == 1:
+ remaining_retries = options.get_max_retries(self.max_retries) - retries_taken
+ if remaining_retries == 1:
log.debug("1 retry left")
else:
- log.debug("%i retries left", remaining)
+ log.debug("%i retries left", remaining_retries)
- timeout = self._calculate_retry_timeout(remaining, options, response_headers)
+ timeout = self._calculate_retry_timeout(remaining_retries, options, response_headers)
log.info("Retrying request to %s in %f seconds", options.url, timeout)
# In a synchronous context we are blocking the entire thread. Up to the library user to run the client in a
@@ -1073,7 +1080,7 @@ def _retry_request(
return self._request(
options=options,
cast_to=cast_to,
- remaining_retries=remaining,
+ retries_taken=retries_taken + 1,
stream=stream,
stream_cls=stream_cls,
)
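
The refactor inverts the retry bookkeeping: rather than threading a decremented `remaining_retries` through each recursive call, every call carries the monotonically increasing `retries_taken` and derives the remainder from `options.get_max_retries`. That also gives `_build_headers` the value it needs for `x-stainless-retry-count`. A worked sketch of the arithmetic with `max_retries = 2`:

```python
max_retries = 2

for retries_taken in range(max_retries + 1):
    remaining_retries = max_retries - retries_taken
    print(f"attempt {retries_taken + 1}: "
          f"retries_taken={retries_taken}, remaining_retries={remaining_retries}")

# attempt 1: retries_taken=0, remaining_retries=2
# attempt 2: retries_taken=1, remaining_retries=1
# attempt 3: retries_taken=2, remaining_retries=0 -> no retries left, the error propagates
```
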
@@ -1491,12 +1498,17 @@ async def request(
stream_cls: type[_AsyncStreamT] | None = None,
remaining_retries: Optional[int] = None,
) -> ResponseT | _AsyncStreamT:
+ if remaining_retries is not None:
+ retries_taken = options.get_max_retries(self.max_retries) - remaining_retries
+ else:
+ retries_taken = 0
+
return await self._request(
cast_to=cast_to,
options=options,
stream=stream,
stream_cls=stream_cls,
- remaining_retries=remaining_retries,
+ retries_taken=retries_taken,
)
async def _request(
@@ -1506,7 +1518,7 @@ async def _request(
*,
stream: bool,
stream_cls: type[_AsyncStreamT] | None,
- remaining_retries: int | None,
+ retries_taken: int,
) -> ResponseT | _AsyncStreamT:
if self._platform is None:
# `get_platform` can make blocking IO calls so we
@@ -1521,8 +1533,8 @@ async def _request(
cast_to = self._maybe_override_cast_to(cast_to, options)
options = await self._prepare_options(options)
- retries = self._remaining_retries(remaining_retries, options)
- request = self._build_request(options)
+ remaining_retries = options.get_max_retries(self.max_retries) - retries_taken
+ request = self._build_request(options, retries_taken=retries_taken)
await self._prepare_request(request)
kwargs: HttpxSendArgs = {}
@@ -1538,11 +1550,11 @@ async def _request(
except httpx.TimeoutException as err:
log.debug("Encountered httpx.TimeoutException", exc_info=True)
- if retries > 0:
+ if remaining_retries > 0:
return await self._retry_request(
input_options,
cast_to,
- retries,
+ retries_taken=retries_taken,
stream=stream,
stream_cls=stream_cls,
response_headers=None,
@@ -1553,11 +1565,11 @@ async def _request(
except Exception as err:
log.debug("Encountered Exception", exc_info=True)
- if retries > 0:
+ if remaining_retries > 0:
return await self._retry_request(
input_options,
cast_to,
- retries,
+ retries_taken=retries_taken,
stream=stream,
stream_cls=stream_cls,
response_headers=None,
@@ -1575,13 +1587,13 @@ async def _request(
except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code
log.debug("Encountered httpx.HTTPStatusError", exc_info=True)
- if retries > 0 and self._should_retry(err.response):
+ if remaining_retries > 0 and self._should_retry(err.response):
await err.response.aclose()
return await self._retry_request(
input_options,
cast_to,
- retries,
- err.response.headers,
+ retries_taken=retries_taken,
+ response_headers=err.response.headers,
stream=stream,
stream_cls=stream_cls,
)
@@ -1600,26 +1612,26 @@ async def _request(
response=response,
stream=stream,
stream_cls=stream_cls,
- retries_taken=options.get_max_retries(self.max_retries) - retries,
+ retries_taken=retries_taken,
)
async def _retry_request(
self,
options: FinalRequestOptions,
cast_to: Type[ResponseT],
- remaining_retries: int,
- response_headers: httpx.Headers | None,
*,
+ retries_taken: int,
+ response_headers: httpx.Headers | None,
stream: bool,
stream_cls: type[_AsyncStreamT] | None,
) -> ResponseT | _AsyncStreamT:
- remaining = remaining_retries - 1
- if remaining == 1:
+ remaining_retries = options.get_max_retries(self.max_retries) - retries_taken
+ if remaining_retries == 1:
log.debug("1 retry left")
else:
- log.debug("%i retries left", remaining)
+ log.debug("%i retries left", remaining_retries)
- timeout = self._calculate_retry_timeout(remaining, options, response_headers)
+ timeout = self._calculate_retry_timeout(remaining_retries, options, response_headers)
log.info("Retrying request to %s in %f seconds", options.url, timeout)
await anyio.sleep(timeout)
@@ -1627,7 +1639,7 @@ async def _retry_request(
return await self._request(
options=options,
cast_to=cast_to,
- remaining_retries=remaining,
+ retries_taken=retries_taken + 1,
stream=stream,
stream_cls=stream_cls,
)
diff --git a/src/llama_stack/_client.py b/src/llama_stack/_client.py
index 80d45bd..420716d 100644
--- a/src/llama_stack/_client.py
+++ b/src/llama_stack/_client.py
@@ -59,11 +59,14 @@ class LlamaStack(SyncAPIClient):
evaluations: resources.EvaluationsResource
inference: resources.InferenceResource
safety: resources.SafetyResource
- memory_banks: resources.MemoryBanksResource
+ memory: resources.MemoryResource
post_training: resources.PostTrainingResource
reward_scoring: resources.RewardScoringResource
synthetic_data_generation: resources.SyntheticDataGenerationResource
batch_inference: resources.BatchInferenceResource
+ models: resources.ModelsResource
+ memory_banks: resources.MemoryBanksResource
+ shields: resources.ShieldsResource
with_raw_response: LlamaStackWithRawResponse
with_streaming_response: LlamaStackWithStreamedResponse
@@ -139,11 +142,14 @@ def __init__(
self.evaluations = resources.EvaluationsResource(self)
self.inference = resources.InferenceResource(self)
self.safety = resources.SafetyResource(self)
- self.memory_banks = resources.MemoryBanksResource(self)
+ self.memory = resources.MemoryResource(self)
self.post_training = resources.PostTrainingResource(self)
self.reward_scoring = resources.RewardScoringResource(self)
self.synthetic_data_generation = resources.SyntheticDataGenerationResource(self)
self.batch_inference = resources.BatchInferenceResource(self)
+ self.models = resources.ModelsResource(self)
+ self.memory_banks = resources.MemoryBanksResource(self)
+ self.shields = resources.ShieldsResource(self)
self.with_raw_response = LlamaStackWithRawResponse(self)
self.with_streaming_response = LlamaStackWithStreamedResponse(self)
@@ -254,11 +260,14 @@ class AsyncLlamaStack(AsyncAPIClient):
evaluations: resources.AsyncEvaluationsResource
inference: resources.AsyncInferenceResource
safety: resources.AsyncSafetyResource
- memory_banks: resources.AsyncMemoryBanksResource
+ memory: resources.AsyncMemoryResource
post_training: resources.AsyncPostTrainingResource
reward_scoring: resources.AsyncRewardScoringResource
synthetic_data_generation: resources.AsyncSyntheticDataGenerationResource
batch_inference: resources.AsyncBatchInferenceResource
+ models: resources.AsyncModelsResource
+ memory_banks: resources.AsyncMemoryBanksResource
+ shields: resources.AsyncShieldsResource
with_raw_response: AsyncLlamaStackWithRawResponse
with_streaming_response: AsyncLlamaStackWithStreamedResponse
@@ -334,11 +343,14 @@ def __init__(
self.evaluations = resources.AsyncEvaluationsResource(self)
self.inference = resources.AsyncInferenceResource(self)
self.safety = resources.AsyncSafetyResource(self)
- self.memory_banks = resources.AsyncMemoryBanksResource(self)
+ self.memory = resources.AsyncMemoryResource(self)
self.post_training = resources.AsyncPostTrainingResource(self)
self.reward_scoring = resources.AsyncRewardScoringResource(self)
self.synthetic_data_generation = resources.AsyncSyntheticDataGenerationResource(self)
self.batch_inference = resources.AsyncBatchInferenceResource(self)
+ self.models = resources.AsyncModelsResource(self)
+ self.memory_banks = resources.AsyncMemoryBanksResource(self)
+ self.shields = resources.AsyncShieldsResource(self)
self.with_raw_response = AsyncLlamaStackWithRawResponse(self)
self.with_streaming_response = AsyncLlamaStackWithStreamedResponse(self)
@@ -450,13 +462,16 @@ def __init__(self, client: LlamaStack) -> None:
self.evaluations = resources.EvaluationsResourceWithRawResponse(client.evaluations)
self.inference = resources.InferenceResourceWithRawResponse(client.inference)
self.safety = resources.SafetyResourceWithRawResponse(client.safety)
- self.memory_banks = resources.MemoryBanksResourceWithRawResponse(client.memory_banks)
+ self.memory = resources.MemoryResourceWithRawResponse(client.memory)
self.post_training = resources.PostTrainingResourceWithRawResponse(client.post_training)
self.reward_scoring = resources.RewardScoringResourceWithRawResponse(client.reward_scoring)
self.synthetic_data_generation = resources.SyntheticDataGenerationResourceWithRawResponse(
client.synthetic_data_generation
)
self.batch_inference = resources.BatchInferenceResourceWithRawResponse(client.batch_inference)
+ self.models = resources.ModelsResourceWithRawResponse(client.models)
+ self.memory_banks = resources.MemoryBanksResourceWithRawResponse(client.memory_banks)
+ self.shields = resources.ShieldsResourceWithRawResponse(client.shields)
class AsyncLlamaStackWithRawResponse:
@@ -468,13 +483,16 @@ def __init__(self, client: AsyncLlamaStack) -> None:
self.evaluations = resources.AsyncEvaluationsResourceWithRawResponse(client.evaluations)
self.inference = resources.AsyncInferenceResourceWithRawResponse(client.inference)
self.safety = resources.AsyncSafetyResourceWithRawResponse(client.safety)
- self.memory_banks = resources.AsyncMemoryBanksResourceWithRawResponse(client.memory_banks)
+ self.memory = resources.AsyncMemoryResourceWithRawResponse(client.memory)
self.post_training = resources.AsyncPostTrainingResourceWithRawResponse(client.post_training)
self.reward_scoring = resources.AsyncRewardScoringResourceWithRawResponse(client.reward_scoring)
self.synthetic_data_generation = resources.AsyncSyntheticDataGenerationResourceWithRawResponse(
client.synthetic_data_generation
)
self.batch_inference = resources.AsyncBatchInferenceResourceWithRawResponse(client.batch_inference)
+ self.models = resources.AsyncModelsResourceWithRawResponse(client.models)
+ self.memory_banks = resources.AsyncMemoryBanksResourceWithRawResponse(client.memory_banks)
+ self.shields = resources.AsyncShieldsResourceWithRawResponse(client.shields)
class LlamaStackWithStreamedResponse:
@@ -486,13 +504,16 @@ def __init__(self, client: LlamaStack) -> None:
self.evaluations = resources.EvaluationsResourceWithStreamingResponse(client.evaluations)
self.inference = resources.InferenceResourceWithStreamingResponse(client.inference)
self.safety = resources.SafetyResourceWithStreamingResponse(client.safety)
- self.memory_banks = resources.MemoryBanksResourceWithStreamingResponse(client.memory_banks)
+ self.memory = resources.MemoryResourceWithStreamingResponse(client.memory)
self.post_training = resources.PostTrainingResourceWithStreamingResponse(client.post_training)
self.reward_scoring = resources.RewardScoringResourceWithStreamingResponse(client.reward_scoring)
self.synthetic_data_generation = resources.SyntheticDataGenerationResourceWithStreamingResponse(
client.synthetic_data_generation
)
self.batch_inference = resources.BatchInferenceResourceWithStreamingResponse(client.batch_inference)
+ self.models = resources.ModelsResourceWithStreamingResponse(client.models)
+ self.memory_banks = resources.MemoryBanksResourceWithStreamingResponse(client.memory_banks)
+ self.shields = resources.ShieldsResourceWithStreamingResponse(client.shields)
class AsyncLlamaStackWithStreamedResponse:
@@ -504,13 +525,16 @@ def __init__(self, client: AsyncLlamaStack) -> None:
self.evaluations = resources.AsyncEvaluationsResourceWithStreamingResponse(client.evaluations)
self.inference = resources.AsyncInferenceResourceWithStreamingResponse(client.inference)
self.safety = resources.AsyncSafetyResourceWithStreamingResponse(client.safety)
- self.memory_banks = resources.AsyncMemoryBanksResourceWithStreamingResponse(client.memory_banks)
+ self.memory = resources.AsyncMemoryResourceWithStreamingResponse(client.memory)
self.post_training = resources.AsyncPostTrainingResourceWithStreamingResponse(client.post_training)
self.reward_scoring = resources.AsyncRewardScoringResourceWithStreamingResponse(client.reward_scoring)
self.synthetic_data_generation = resources.AsyncSyntheticDataGenerationResourceWithStreamingResponse(
client.synthetic_data_generation
)
self.batch_inference = resources.AsyncBatchInferenceResourceWithStreamingResponse(client.batch_inference)
+ self.models = resources.AsyncModelsResourceWithStreamingResponse(client.models)
+ self.memory_banks = resources.AsyncMemoryBanksResourceWithStreamingResponse(client.memory_banks)
+ self.shields = resources.AsyncShieldsResourceWithStreamingResponse(client.shields)
Client = LlamaStack
diff --git a/src/llama_stack/_compat.py b/src/llama_stack/_compat.py
index 21fe694..162a6fb 100644
--- a/src/llama_stack/_compat.py
+++ b/src/llama_stack/_compat.py
@@ -136,12 +136,14 @@ def model_dump(
exclude: IncEx = None,
exclude_unset: bool = False,
exclude_defaults: bool = False,
+ warnings: bool = True,
) -> dict[str, Any]:
if PYDANTIC_V2:
return model.model_dump(
exclude=exclude,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
+ warnings=warnings,
)
return cast(
"dict[str, Any]",
diff --git a/src/llama_stack/resources/__init__.py b/src/llama_stack/resources/__init__.py
index 484c881..e981ad1 100644
--- a/src/llama_stack/resources/__init__.py
+++ b/src/llama_stack/resources/__init__.py
@@ -8,6 +8,22 @@
AgentsResourceWithStreamingResponse,
AsyncAgentsResourceWithStreamingResponse,
)
+from .memory import (
+ MemoryResource,
+ AsyncMemoryResource,
+ MemoryResourceWithRawResponse,
+ AsyncMemoryResourceWithRawResponse,
+ MemoryResourceWithStreamingResponse,
+ AsyncMemoryResourceWithStreamingResponse,
+)
+from .models import (
+ ModelsResource,
+ AsyncModelsResource,
+ ModelsResourceWithRawResponse,
+ AsyncModelsResourceWithRawResponse,
+ ModelsResourceWithStreamingResponse,
+ AsyncModelsResourceWithStreamingResponse,
+)
from .safety import (
SafetyResource,
AsyncSafetyResource,
@@ -16,6 +32,14 @@
SafetyResourceWithStreamingResponse,
AsyncSafetyResourceWithStreamingResponse,
)
+from .shields import (
+ ShieldsResource,
+ AsyncShieldsResource,
+ ShieldsResourceWithRawResponse,
+ AsyncShieldsResourceWithRawResponse,
+ ShieldsResourceWithStreamingResponse,
+ AsyncShieldsResourceWithStreamingResponse,
+)
from .datasets import (
DatasetsResource,
AsyncDatasetsResource,
@@ -140,12 +164,12 @@
"AsyncSafetyResourceWithRawResponse",
"SafetyResourceWithStreamingResponse",
"AsyncSafetyResourceWithStreamingResponse",
- "MemoryBanksResource",
- "AsyncMemoryBanksResource",
- "MemoryBanksResourceWithRawResponse",
- "AsyncMemoryBanksResourceWithRawResponse",
- "MemoryBanksResourceWithStreamingResponse",
- "AsyncMemoryBanksResourceWithStreamingResponse",
+ "MemoryResource",
+ "AsyncMemoryResource",
+ "MemoryResourceWithRawResponse",
+ "AsyncMemoryResourceWithRawResponse",
+ "MemoryResourceWithStreamingResponse",
+ "AsyncMemoryResourceWithStreamingResponse",
"PostTrainingResource",
"AsyncPostTrainingResource",
"PostTrainingResourceWithRawResponse",
@@ -170,4 +194,22 @@
"AsyncBatchInferenceResourceWithRawResponse",
"BatchInferenceResourceWithStreamingResponse",
"AsyncBatchInferenceResourceWithStreamingResponse",
+ "ModelsResource",
+ "AsyncModelsResource",
+ "ModelsResourceWithRawResponse",
+ "AsyncModelsResourceWithRawResponse",
+ "ModelsResourceWithStreamingResponse",
+ "AsyncModelsResourceWithStreamingResponse",
+ "MemoryBanksResource",
+ "AsyncMemoryBanksResource",
+ "MemoryBanksResourceWithRawResponse",
+ "AsyncMemoryBanksResourceWithRawResponse",
+ "MemoryBanksResourceWithStreamingResponse",
+ "AsyncMemoryBanksResourceWithStreamingResponse",
+ "ShieldsResource",
+ "AsyncShieldsResource",
+ "ShieldsResourceWithRawResponse",
+ "AsyncShieldsResourceWithRawResponse",
+ "ShieldsResourceWithStreamingResponse",
+ "AsyncShieldsResourceWithStreamingResponse",
]
diff --git a/src/llama_stack/resources/agents/agents.py b/src/llama_stack/resources/agents/agents.py
index 48871a2..7a865a5 100644
--- a/src/llama_stack/resources/agents/agents.py
+++ b/src/llama_stack/resources/agents/agents.py
@@ -24,6 +24,7 @@
from ..._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven
from ..._utils import (
maybe_transform,
+ strip_not_given,
async_maybe_transform,
)
from .sessions import (
@@ -84,6 +85,7 @@ def create(
self,
*,
agent_config: agent_create_params.AgentConfig,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -101,6 +103,10 @@ def create(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._post(
"/agents/create",
body=maybe_transform({"agent_config": agent_config}, agent_create_params.AgentCreateParams),
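
Every resource method gains the same optional `x_llama_stack_provider_data` argument, forwarded as the `X-LlamaStack-ProviderData` header; `strip_not_given` drops the entry entirely when the caller never supplied it, so the header is only sent when explicitly set. A sketch of that behavior using the package's own helpers (the semantics of `strip_not_given` are inferred from its usage here):

```python
from llama_stack._types import NOT_GIVEN
from llama_stack._utils import strip_not_given

# omitted argument: the key is stripped and no header is sent
assert strip_not_given({"X-LlamaStack-ProviderData": NOT_GIVEN}) == {}

# supplied argument: the header passes through unchanged
given = strip_not_given({"X-LlamaStack-ProviderData": '{"api_key": "..."}'})
assert given == {"X-LlamaStack-ProviderData": '{"api_key": "..."}'}
```
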
@@ -114,6 +120,7 @@ def delete(
self,
*,
agent_id: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -132,6 +139,10 @@ def delete(
timeout: Override the client-level default timeout for this request, in seconds
"""
extra_headers = {"Accept": "*/*", **(extra_headers or {})}
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._post(
"/agents/delete",
body=maybe_transform({"agent_id": agent_id}, agent_delete_params.AgentDeleteParams),
@@ -178,6 +189,7 @@ async def create(
self,
*,
agent_config: agent_create_params.AgentConfig,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -195,6 +207,10 @@ async def create(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._post(
"/agents/create",
body=await async_maybe_transform({"agent_config": agent_config}, agent_create_params.AgentCreateParams),
@@ -208,6 +224,7 @@ async def delete(
self,
*,
agent_id: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -226,6 +243,10 @@ async def delete(
timeout: Override the client-level default timeout for this request, in seconds
"""
extra_headers = {"Accept": "*/*", **(extra_headers or {})}
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._post(
"/agents/delete",
body=await async_maybe_transform({"agent_id": agent_id}, agent_delete_params.AgentDeleteParams),
diff --git a/src/llama_stack/resources/agents/sessions.py b/src/llama_stack/resources/agents/sessions.py
index 43eb41b..b8c6ae4 100644
--- a/src/llama_stack/resources/agents/sessions.py
+++ b/src/llama_stack/resources/agents/sessions.py
@@ -9,6 +9,7 @@
from ..._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven
from ..._utils import (
maybe_transform,
+ strip_not_given,
async_maybe_transform,
)
from ..._compat import cached_property
@@ -52,6 +53,7 @@ def create(
*,
agent_id: str,
session_name: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -69,6 +71,10 @@ def create(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._post(
"/agents/session/create",
body=maybe_transform(
@@ -90,6 +96,7 @@ def retrieve(
agent_id: str,
session_id: str,
turn_ids: List[str] | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -107,6 +114,10 @@ def retrieve(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._post(
"/agents/session/get",
body=maybe_transform({"turn_ids": turn_ids}, session_retrieve_params.SessionRetrieveParams),
@@ -131,6 +142,7 @@ def delete(
*,
agent_id: str,
session_id: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -149,6 +161,10 @@ def delete(
timeout: Override the client-level default timeout for this request, in seconds
"""
extra_headers = {"Accept": "*/*", **(extra_headers or {})}
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._post(
"/agents/session/delete",
body=maybe_transform(
@@ -190,6 +206,7 @@ async def create(
*,
agent_id: str,
session_name: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -207,6 +224,10 @@ async def create(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._post(
"/agents/session/create",
body=await async_maybe_transform(
@@ -228,6 +249,7 @@ async def retrieve(
agent_id: str,
session_id: str,
turn_ids: List[str] | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -245,6 +267,10 @@ async def retrieve(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._post(
"/agents/session/get",
body=await async_maybe_transform({"turn_ids": turn_ids}, session_retrieve_params.SessionRetrieveParams),
@@ -269,6 +295,7 @@ async def delete(
*,
agent_id: str,
session_id: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -287,6 +314,10 @@ async def delete(
timeout: Override the client-level default timeout for this request, in seconds
"""
extra_headers = {"Accept": "*/*", **(extra_headers or {})}
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._post(
"/agents/session/delete",
body=await async_maybe_transform(
diff --git a/src/llama_stack/resources/agents/steps.py b/src/llama_stack/resources/agents/steps.py
index df6649f..7afaa90 100644
--- a/src/llama_stack/resources/agents/steps.py
+++ b/src/llama_stack/resources/agents/steps.py
@@ -7,6 +7,7 @@
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from ..._utils import (
maybe_transform,
+ strip_not_given,
async_maybe_transform,
)
from ..._compat import cached_property
@@ -50,6 +51,7 @@ def retrieve(
agent_id: str,
step_id: str,
turn_id: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -67,6 +69,10 @@ def retrieve(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._get(
"/agents/step/get",
options=make_request_options(
@@ -113,6 +119,7 @@ async def retrieve(
agent_id: str,
step_id: str,
turn_id: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -130,6 +137,10 @@ async def retrieve(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._get(
"/agents/step/get",
options=make_request_options(
diff --git a/src/llama_stack/resources/agents/turns.py b/src/llama_stack/resources/agents/turns.py
index 217fff1..ac6767f 100644
--- a/src/llama_stack/resources/agents/turns.py
+++ b/src/llama_stack/resources/agents/turns.py
@@ -11,6 +11,7 @@
from ..._utils import (
required_args,
maybe_transform,
+ strip_not_given,
async_maybe_transform,
)
from ..._compat import cached_property
@@ -60,6 +61,7 @@ def create(
session_id: str,
attachments: Iterable[Attachment] | NotGiven = NOT_GIVEN,
stream: Literal[False] | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -88,6 +90,7 @@ def create(
session_id: str,
stream: Literal[True],
attachments: Iterable[Attachment] | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -116,6 +119,7 @@ def create(
session_id: str,
stream: bool,
attachments: Iterable[Attachment] | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -144,6 +148,7 @@ def create(
session_id: str,
attachments: Iterable[Attachment] | NotGiven = NOT_GIVEN,
stream: Literal[False] | Literal[True] | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -151,6 +156,10 @@ def create(
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> AgentsTurnStreamChunk | Stream[AgentsTurnStreamChunk]:
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._post(
"/agents/turn/create",
body=maybe_transform(
@@ -176,6 +185,7 @@ def retrieve(
*,
agent_id: str,
turn_id: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -193,6 +203,10 @@ def retrieve(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._get(
"/agents/turn/get",
options=make_request_options(
@@ -241,6 +255,7 @@ async def create(
session_id: str,
attachments: Iterable[Attachment] | NotGiven = NOT_GIVEN,
stream: Literal[False] | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -269,6 +284,7 @@ async def create(
session_id: str,
stream: Literal[True],
attachments: Iterable[Attachment] | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -297,6 +313,7 @@ async def create(
session_id: str,
stream: bool,
attachments: Iterable[Attachment] | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -325,6 +342,7 @@ async def create(
session_id: str,
attachments: Iterable[Attachment] | NotGiven = NOT_GIVEN,
stream: Literal[False] | Literal[True] | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -332,6 +350,10 @@ async def create(
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> AgentsTurnStreamChunk | AsyncStream[AgentsTurnStreamChunk]:
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._post(
"/agents/turn/create",
body=await async_maybe_transform(
@@ -357,6 +379,7 @@ async def retrieve(
*,
agent_id: str,
turn_id: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -374,6 +397,10 @@ async def retrieve(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._get(
"/agents/turn/get",
options=make_request_options(
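
The overload set on `turns.create` mirrors the other streaming endpoints: `stream=True` narrows the return type to a `Stream`/`AsyncStream` of `AgentsTurnStreamChunk`, and the new provider-data header rides along on every variant. A hedged usage sketch; the IDs are placeholders and required turn fields not shown in this hunk are elided:

```python
from llama_stack import LlamaStack

client = LlamaStack(base_url="http://localhost:5000")  # hypothetical server

chunks = client.agents.turns.create(
    agent_id="agent-123",      # placeholder
    session_id="session-456",  # placeholder
    # ...other required turn fields elided...
    stream=True,
    x_llama_stack_provider_data='{"api_key": "..."}',
)
for chunk in chunks:
    print(chunk)
```
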
diff --git a/src/llama_stack/resources/batch_inference.py b/src/llama_stack/resources/batch_inference.py
index 36028b3..7c2c563 100644
--- a/src/llama_stack/resources/batch_inference.py
+++ b/src/llama_stack/resources/batch_inference.py
@@ -11,6 +11,7 @@
from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from .._utils import (
maybe_transform,
+ strip_not_given,
async_maybe_transform,
)
from .._compat import cached_property
@@ -59,6 +60,7 @@ def chat_completion(
tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN,
tool_prompt_format: Literal["json", "function_tag"] | NotGiven = NOT_GIVEN,
tools: Iterable[batch_inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -86,6 +88,10 @@ def chat_completion(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._post(
"/batch_inference/chat_completion",
body=maybe_transform(
@@ -113,6 +119,7 @@ def completion(
model: str,
logprobs: batch_inference_completion_params.Logprobs | NotGiven = NOT_GIVEN,
sampling_params: SamplingParams | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -130,6 +137,10 @@ def completion(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._post(
"/batch_inference/completion",
body=maybe_transform(
@@ -178,6 +189,7 @@ async def chat_completion(
tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN,
tool_prompt_format: Literal["json", "function_tag"] | NotGiven = NOT_GIVEN,
tools: Iterable[batch_inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -205,6 +217,10 @@ async def chat_completion(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._post(
"/batch_inference/chat_completion",
body=await async_maybe_transform(
@@ -232,6 +248,7 @@ async def completion(
model: str,
logprobs: batch_inference_completion_params.Logprobs | NotGiven = NOT_GIVEN,
sampling_params: SamplingParams | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -249,6 +266,10 @@ async def completion(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._post(
"/batch_inference/completion",
body=await async_maybe_transform(
diff --git a/src/llama_stack/resources/datasets.py b/src/llama_stack/resources/datasets.py
index 321e301..dcf1005 100644
--- a/src/llama_stack/resources/datasets.py
+++ b/src/llama_stack/resources/datasets.py
@@ -8,6 +8,7 @@
from .._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven
from .._utils import (
maybe_transform,
+ strip_not_given,
async_maybe_transform,
)
from .._compat import cached_property
@@ -50,6 +51,7 @@ def create(
*,
dataset: TrainEvalDatasetParam,
uuid: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -68,6 +70,10 @@ def create(
timeout: Override the client-level default timeout for this request, in seconds
"""
extra_headers = {"Accept": "*/*", **(extra_headers or {})}
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._post(
"/datasets/create",
body=maybe_transform(
@@ -87,6 +93,7 @@ def delete(
self,
*,
dataset_uuid: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -105,6 +112,10 @@ def delete(
timeout: Override the client-level default timeout for this request, in seconds
"""
extra_headers = {"Accept": "*/*", **(extra_headers or {})}
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._post(
"/datasets/delete",
body=maybe_transform({"dataset_uuid": dataset_uuid}, dataset_delete_params.DatasetDeleteParams),
@@ -118,6 +129,7 @@ def get(
self,
*,
dataset_uuid: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -135,6 +147,10 @@ def get(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._get(
"/datasets/get",
options=make_request_options(
@@ -173,6 +189,7 @@ async def create(
*,
dataset: TrainEvalDatasetParam,
uuid: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -191,6 +208,10 @@ async def create(
timeout: Override the client-level default timeout for this request, in seconds
"""
extra_headers = {"Accept": "*/*", **(extra_headers or {})}
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._post(
"/datasets/create",
body=await async_maybe_transform(
@@ -210,6 +231,7 @@ async def delete(
self,
*,
dataset_uuid: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -228,6 +250,10 @@ async def delete(
timeout: Override the client-level default timeout for this request, in seconds
"""
extra_headers = {"Accept": "*/*", **(extra_headers or {})}
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._post(
"/datasets/delete",
body=await async_maybe_transform({"dataset_uuid": dataset_uuid}, dataset_delete_params.DatasetDeleteParams),
@@ -241,6 +267,7 @@ async def get(
self,
*,
dataset_uuid: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -258,6 +285,10 @@ async def get(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._get(
"/datasets/get",
options=make_request_options(
diff --git a/src/llama_stack/resources/evaluate/jobs/artifacts.py b/src/llama_stack/resources/evaluate/jobs/artifacts.py
index b5e0a3d..ce03116 100644
--- a/src/llama_stack/resources/evaluate/jobs/artifacts.py
+++ b/src/llama_stack/resources/evaluate/jobs/artifacts.py
@@ -7,6 +7,7 @@
from ...._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from ...._utils import (
maybe_transform,
+ strip_not_given,
async_maybe_transform,
)
from ...._compat import cached_property
@@ -48,6 +49,7 @@ def list(
self,
*,
job_uuid: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -65,6 +67,10 @@ def list(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._get(
"/evaluate/job/artifacts",
options=make_request_options(
@@ -102,6 +108,7 @@ async def list(
self,
*,
job_uuid: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -119,6 +126,10 @@ async def list(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._get(
"/evaluate/job/artifacts",
options=make_request_options(
diff --git a/src/llama_stack/resources/evaluate/jobs/jobs.py b/src/llama_stack/resources/evaluate/jobs/jobs.py
index bfbeb42..9e64c18 100644
--- a/src/llama_stack/resources/evaluate/jobs/jobs.py
+++ b/src/llama_stack/resources/evaluate/jobs/jobs.py
@@ -23,6 +23,7 @@
from ...._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven
from ...._utils import (
maybe_transform,
+ strip_not_given,
async_maybe_transform,
)
from .artifacts import (
@@ -83,6 +84,7 @@ def with_streaming_response(self) -> JobsResourceWithStreamingResponse:
def list(
self,
*,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -90,7 +92,21 @@ def list(
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> EvaluationJob:
+ """
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})}
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._get(
"/evaluate/jobs",
options=make_request_options(
@@ -103,6 +119,7 @@ def cancel(
self,
*,
job_uuid: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -121,6 +138,10 @@ def cancel(
timeout: Override the client-level default timeout for this request, in seconds
"""
extra_headers = {"Accept": "*/*", **(extra_headers or {})}
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._post(
"/evaluate/job/cancel",
body=maybe_transform({"job_uuid": job_uuid}, job_cancel_params.JobCancelParams),
@@ -166,6 +187,7 @@ def with_streaming_response(self) -> AsyncJobsResourceWithStreamingResponse:
async def list(
self,
*,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -173,7 +195,21 @@ async def list(
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> EvaluationJob:
+ """
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})}
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._get(
"/evaluate/jobs",
options=make_request_options(
@@ -186,6 +222,7 @@ async def cancel(
self,
*,
job_uuid: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -204,6 +241,10 @@ async def cancel(
timeout: Override the client-level default timeout for this request, in seconds
"""
extra_headers = {"Accept": "*/*", **(extra_headers or {})}
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._post(
"/evaluate/job/cancel",
body=await async_maybe_transform({"job_uuid": job_uuid}, job_cancel_params.JobCancelParams),
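
Every method in this diff gains the same optional `x_llama_stack_provider_data` keyword, merged into the outgoing headers with `strip_not_given` so the `X-LlamaStack-ProviderData` header is only sent when the caller actually supplies a value. A minimal sketch of that merge semantics — `NOT_GIVEN` and `strip_not_given` are stand-in re-implementations here, not the SDK's internal helpers:

```python
# Stand-in sketch of the header-merge pattern used throughout this diff.
NOT_GIVEN = object()  # sentinel meaning "argument was not passed"

def strip_not_given(headers: dict) -> dict:
    # Drop entries whose value is the NOT_GIVEN sentinel.
    return {k: v for k, v in headers.items() if v is not NOT_GIVEN}

def merged_headers(x_llama_stack_provider_data=NOT_GIVEN, extra_headers=None) -> dict:
    # extra_headers is spread last, so caller-supplied headers win.
    return {
        **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
        **(extra_headers or {}),
    }

assert merged_headers() == {}  # header omitted when not given
assert merged_headers('{"key": "v"}') == {"X-LlamaStack-ProviderData": '{"key": "v"}'}
```
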
diff --git a/src/llama_stack/resources/evaluate/jobs/logs.py b/src/llama_stack/resources/evaluate/jobs/logs.py
index 2aae53c..c1db747 100644
--- a/src/llama_stack/resources/evaluate/jobs/logs.py
+++ b/src/llama_stack/resources/evaluate/jobs/logs.py
@@ -7,6 +7,7 @@
from ...._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from ...._utils import (
maybe_transform,
+ strip_not_given,
async_maybe_transform,
)
from ...._compat import cached_property
@@ -48,6 +49,7 @@ def list(
self,
*,
job_uuid: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -65,6 +67,10 @@ def list(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._get(
"/evaluate/job/logs",
options=make_request_options(
@@ -102,6 +108,7 @@ async def list(
self,
*,
job_uuid: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -119,6 +126,10 @@ async def list(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._get(
"/evaluate/job/logs",
options=make_request_options(
diff --git a/src/llama_stack/resources/evaluate/jobs/status.py b/src/llama_stack/resources/evaluate/jobs/status.py
index 4719ede..2c3aca8 100644
--- a/src/llama_stack/resources/evaluate/jobs/status.py
+++ b/src/llama_stack/resources/evaluate/jobs/status.py
@@ -7,6 +7,7 @@
from ...._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from ...._utils import (
maybe_transform,
+ strip_not_given,
async_maybe_transform,
)
from ...._compat import cached_property
@@ -48,6 +49,7 @@ def list(
self,
*,
job_uuid: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -65,6 +67,10 @@ def list(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._get(
"/evaluate/job/status",
options=make_request_options(
@@ -102,6 +108,7 @@ async def list(
self,
*,
job_uuid: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -119,6 +126,10 @@ async def list(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._get(
"/evaluate/job/status",
options=make_request_options(
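
The logs and status sub-resources receive the identical treatment. A hypothetical call sketch — the `client.evaluate.jobs.*` accessor paths and the client constructor are assumptions inferred from the module layout, not confirmed by this diff:

```python
from llama_stack import LlamaStack  # client class name is an assumption

client = LlamaStack(base_url="http://localhost:5000")

# job_uuid is required; the provider-data keyword becomes the
# X-LlamaStack-ProviderData header only because it was passed here.
logs = client.evaluate.jobs.logs.list(
    job_uuid="job-123",
    x_llama_stack_provider_data='{"api_key": "..."}',
)
status = client.evaluate.jobs.status.list(job_uuid="job-123")
```
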
diff --git a/src/llama_stack/resources/evaluate/question_answering.py b/src/llama_stack/resources/evaluate/question_answering.py
index ca5169f..50b4a0c 100644
--- a/src/llama_stack/resources/evaluate/question_answering.py
+++ b/src/llama_stack/resources/evaluate/question_answering.py
@@ -10,6 +10,7 @@
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from ..._utils import (
maybe_transform,
+ strip_not_given,
async_maybe_transform,
)
from ..._compat import cached_property
@@ -51,6 +52,7 @@ def create(
self,
*,
metrics: List[Literal["em", "f1"]],
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -68,6 +70,10 @@ def create(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._post(
"/evaluate/question_answering/",
body=maybe_transform({"metrics": metrics}, question_answering_create_params.QuestionAnsweringCreateParams),
@@ -102,6 +108,7 @@ async def create(
self,
*,
metrics: List[Literal["em", "f1"]],
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -119,6 +126,10 @@ async def create(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._post(
"/evaluate/question_answering/",
body=await async_maybe_transform(
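
For question answering, the only body field is the `metrics` list, whose `"em"`/`"f1"` literals come straight from the signature above; the accessor path is assumed from the file location:

```python
# Hypothetical usage; the client.evaluate.question_answering path is assumed.
client.evaluate.question_answering.create(
    metrics=["em", "f1"],
    x_llama_stack_provider_data='{"api_key": "..."}',
)
```
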
diff --git a/src/llama_stack/resources/evaluations.py b/src/llama_stack/resources/evaluations.py
index 6328060..cebe2ba 100644
--- a/src/llama_stack/resources/evaluations.py
+++ b/src/llama_stack/resources/evaluations.py
@@ -11,6 +11,7 @@
from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from .._utils import (
maybe_transform,
+ strip_not_given,
async_maybe_transform,
)
from .._compat import cached_property
@@ -51,6 +52,7 @@ def summarization(
self,
*,
metrics: List[Literal["rouge", "bleu"]],
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -68,6 +70,10 @@ def summarization(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._post(
"/evaluate/summarization/",
body=maybe_transform({"metrics": metrics}, evaluation_summarization_params.EvaluationSummarizationParams),
@@ -81,6 +87,7 @@ def text_generation(
self,
*,
metrics: List[Literal["perplexity", "rouge", "bleu"]],
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -98,6 +105,10 @@ def text_generation(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._post(
"/evaluate/text_generation/",
body=maybe_transform(
@@ -134,6 +145,7 @@ async def summarization(
self,
*,
metrics: List[Literal["rouge", "bleu"]],
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -151,6 +163,10 @@ async def summarization(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._post(
"/evaluate/summarization/",
body=await async_maybe_transform(
@@ -166,6 +182,7 @@ async def text_generation(
self,
*,
metrics: List[Literal["perplexity", "rouge", "bleu"]],
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -183,6 +200,10 @@ async def text_generation(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._post(
"/evaluate/text_generation/",
body=await async_maybe_transform(
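
The two evaluations methods differ only in their metric literals. A hedged sketch, reusing the client from the earlier example:

```python
# Metric literals are taken verbatim from the signatures above.
client.evaluations.summarization(metrics=["rouge", "bleu"])
client.evaluations.text_generation(
    metrics=["perplexity", "rouge", "bleu"],
    x_llama_stack_provider_data='{"provider_id": "..."}',
)
```
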
diff --git a/src/llama_stack/resources/inference/embeddings.py b/src/llama_stack/resources/inference/embeddings.py
index 77e7e5e..0b332cf 100644
--- a/src/llama_stack/resources/inference/embeddings.py
+++ b/src/llama_stack/resources/inference/embeddings.py
@@ -9,6 +9,7 @@
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from ..._utils import (
maybe_transform,
+ strip_not_given,
async_maybe_transform,
)
from ..._compat import cached_property
@@ -51,6 +52,7 @@ def create(
*,
contents: List[Union[str, List[str]]],
model: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -68,6 +70,10 @@ def create(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._post(
"/inference/embeddings",
body=maybe_transform(
@@ -109,6 +115,7 @@ async def create(
*,
contents: List[Union[str, List[str]]],
model: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -126,6 +133,10 @@ async def create(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._post(
"/inference/embeddings",
body=await async_maybe_transform(
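
`contents` is typed `List[Union[str, List[str]]]`, so both flat strings and pre-chunked string lists are accepted. A sketch with an illustrative model id:

```python
# Model id is illustrative, not taken from this diff.
embeddings = client.inference.embeddings.create(
    model="Llama3.1-8B-Instruct",
    contents=["first passage", ["chunk a", "chunk b"]],
    x_llama_stack_provider_data='{"api_key": "..."}',
)
```
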
diff --git a/src/llama_stack/resources/inference/inference.py b/src/llama_stack/resources/inference/inference.py
index 648736e..1be5b06 100644
--- a/src/llama_stack/resources/inference/inference.py
+++ b/src/llama_stack/resources/inference/inference.py
@@ -12,6 +12,7 @@
from ..._utils import (
required_args,
maybe_transform,
+ strip_not_given,
async_maybe_transform,
)
from ..._compat import cached_property
@@ -75,6 +76,7 @@ def chat_completion(
tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN,
tool_prompt_format: Literal["json", "function_tag"] | NotGiven = NOT_GIVEN,
tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -116,6 +118,7 @@ def chat_completion(
tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN,
tool_prompt_format: Literal["json", "function_tag"] | NotGiven = NOT_GIVEN,
tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -157,6 +160,7 @@ def chat_completion(
tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN,
tool_prompt_format: Literal["json", "function_tag"] | NotGiven = NOT_GIVEN,
tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -198,6 +202,7 @@ def chat_completion(
tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN,
tool_prompt_format: Literal["json", "function_tag"] | NotGiven = NOT_GIVEN,
tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -206,6 +211,10 @@ def chat_completion(
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> InferenceChatCompletionResponse | Stream[InferenceChatCompletionResponse]:
extra_headers = {"Accept": "text/event-stream", **(extra_headers or {})}
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return cast(
InferenceChatCompletionResponse,
self._post(
@@ -242,6 +251,7 @@ def completion(
logprobs: inference_completion_params.Logprobs | NotGiven = NOT_GIVEN,
sampling_params: SamplingParams | NotGiven = NOT_GIVEN,
stream: bool | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -259,6 +269,10 @@ def completion(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return cast(
InferenceCompletionResponse,
self._post(
@@ -319,6 +333,7 @@ async def chat_completion(
tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN,
tool_prompt_format: Literal["json", "function_tag"] | NotGiven = NOT_GIVEN,
tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -360,6 +375,7 @@ async def chat_completion(
tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN,
tool_prompt_format: Literal["json", "function_tag"] | NotGiven = NOT_GIVEN,
tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -401,6 +417,7 @@ async def chat_completion(
tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN,
tool_prompt_format: Literal["json", "function_tag"] | NotGiven = NOT_GIVEN,
tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -442,6 +459,7 @@ async def chat_completion(
tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN,
tool_prompt_format: Literal["json", "function_tag"] | NotGiven = NOT_GIVEN,
tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -450,6 +468,10 @@ async def chat_completion(
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> InferenceChatCompletionResponse | AsyncStream[InferenceChatCompletionResponse]:
extra_headers = {"Accept": "text/event-stream", **(extra_headers or {})}
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return cast(
InferenceChatCompletionResponse,
await self._post(
@@ -486,6 +508,7 @@ async def completion(
logprobs: inference_completion_params.Logprobs | NotGiven = NOT_GIVEN,
sampling_params: SamplingParams | NotGiven = NOT_GIVEN,
stream: bool | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -503,6 +526,10 @@ async def completion(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return cast(
InferenceCompletionResponse,
await self._post(
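
One subtlety visible in `chat_completion`: the provider-data dict is spread first and `extra_headers` second, so an explicit `extra_headers={"X-LlamaStack-ProviderData": ...}` still overrides the keyword, matching the "extra values take precedence" comment. A sketch — `model` and `messages` are assumed parameter names, since only the tool/stream parameters appear in these hunks:

```python
# Sketch only: `model` and `messages` are assumptions, not shown in this diff.
response = client.inference.chat_completion(
    model="Llama3.1-8B-Instruct",
    messages=[{"role": "user", "content": "Hello"}],
    stream=False,
    x_llama_stack_provider_data='{"api_key": "..."}',
    # An explicit extra_headers entry would win over the keyword above.
)
```
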
diff --git a/src/llama_stack/resources/memory_banks/__init__.py b/src/llama_stack/resources/memory/__init__.py
similarity index 53%
rename from src/llama_stack/resources/memory_banks/__init__.py
rename to src/llama_stack/resources/memory/__init__.py
index f26b272..1438115 100644
--- a/src/llama_stack/resources/memory_banks/__init__.py
+++ b/src/llama_stack/resources/memory/__init__.py
@@ -1,5 +1,13 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+from .memory import (
+ MemoryResource,
+ AsyncMemoryResource,
+ MemoryResourceWithRawResponse,
+ AsyncMemoryResourceWithRawResponse,
+ MemoryResourceWithStreamingResponse,
+ AsyncMemoryResourceWithStreamingResponse,
+)
from .documents import (
DocumentsResource,
AsyncDocumentsResource,
@@ -8,14 +16,6 @@
DocumentsResourceWithStreamingResponse,
AsyncDocumentsResourceWithStreamingResponse,
)
-from .memory_banks import (
- MemoryBanksResource,
- AsyncMemoryBanksResource,
- MemoryBanksResourceWithRawResponse,
- AsyncMemoryBanksResourceWithRawResponse,
- MemoryBanksResourceWithStreamingResponse,
- AsyncMemoryBanksResourceWithStreamingResponse,
-)
__all__ = [
"DocumentsResource",
@@ -24,10 +24,10 @@
"AsyncDocumentsResourceWithRawResponse",
"DocumentsResourceWithStreamingResponse",
"AsyncDocumentsResourceWithStreamingResponse",
- "MemoryBanksResource",
- "AsyncMemoryBanksResource",
- "MemoryBanksResourceWithRawResponse",
- "AsyncMemoryBanksResourceWithRawResponse",
- "MemoryBanksResourceWithStreamingResponse",
- "AsyncMemoryBanksResourceWithStreamingResponse",
+ "MemoryResource",
+ "AsyncMemoryResource",
+ "MemoryResourceWithRawResponse",
+ "AsyncMemoryResourceWithRawResponse",
+ "MemoryResourceWithStreamingResponse",
+ "AsyncMemoryResourceWithStreamingResponse",
]
diff --git a/src/llama_stack/resources/memory_banks/documents.py b/src/llama_stack/resources/memory/documents.py
similarity index 88%
rename from src/llama_stack/resources/memory_banks/documents.py
rename to src/llama_stack/resources/memory/documents.py
index 995bc7a..546ffd4 100644
--- a/src/llama_stack/resources/memory_banks/documents.py
+++ b/src/llama_stack/resources/memory/documents.py
@@ -9,6 +9,7 @@
from ..._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven
from ..._utils import (
maybe_transform,
+ strip_not_given,
async_maybe_transform,
)
from ..._compat import cached_property
@@ -20,8 +21,8 @@
async_to_streamed_response_wrapper,
)
from ..._base_client import make_request_options
-from ...types.memory_banks import document_delete_params, document_retrieve_params
-from ...types.memory_banks.document_retrieve_response import DocumentRetrieveResponse
+from ...types.memory import document_delete_params, document_retrieve_params
+from ...types.memory.document_retrieve_response import DocumentRetrieveResponse
__all__ = ["DocumentsResource", "AsyncDocumentsResource"]
@@ -51,6 +52,7 @@ def retrieve(
*,
bank_id: str,
document_ids: List[str],
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -69,8 +71,12 @@ def retrieve(
timeout: Override the client-level default timeout for this request, in seconds
"""
extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})}
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._post(
- "/memory_bank/documents/get",
+ "/memory/documents/get",
body=maybe_transform({"document_ids": document_ids}, document_retrieve_params.DocumentRetrieveParams),
options=make_request_options(
extra_headers=extra_headers,
@@ -87,6 +93,7 @@ def delete(
*,
bank_id: str,
document_ids: List[str],
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -105,8 +112,12 @@ def delete(
timeout: Override the client-level default timeout for this request, in seconds
"""
extra_headers = {"Accept": "*/*", **(extra_headers or {})}
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._post(
- "/memory_bank/documents/delete",
+ "/memory/documents/delete",
body=maybe_transform(
{
"bank_id": bank_id,
@@ -146,6 +157,7 @@ async def retrieve(
*,
bank_id: str,
document_ids: List[str],
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -164,8 +176,12 @@ async def retrieve(
timeout: Override the client-level default timeout for this request, in seconds
"""
extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})}
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._post(
- "/memory_bank/documents/get",
+ "/memory/documents/get",
body=await async_maybe_transform(
{"document_ids": document_ids}, document_retrieve_params.DocumentRetrieveParams
),
@@ -186,6 +202,7 @@ async def delete(
*,
bank_id: str,
document_ids: List[str],
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -204,8 +221,12 @@ async def delete(
timeout: Override the client-level default timeout for this request, in seconds
"""
extra_headers = {"Accept": "*/*", **(extra_headers or {})}
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._post(
- "/memory_bank/documents/delete",
+ "/memory/documents/delete",
body=await async_maybe_transform(
{
"bank_id": bank_id,
diff --git a/src/llama_stack/resources/memory_banks/memory_banks.py b/src/llama_stack/resources/memory/memory.py
similarity index 73%
rename from src/llama_stack/resources/memory_banks/memory_banks.py
rename to src/llama_stack/resources/memory/memory.py
index 52ddeba..761ab5d 100644
--- a/src/llama_stack/resources/memory_banks/memory_banks.py
+++ b/src/llama_stack/resources/memory/memory.py
@@ -7,16 +7,17 @@
import httpx
from ...types import (
- memory_bank_drop_params,
- memory_bank_query_params,
- memory_bank_create_params,
- memory_bank_insert_params,
- memory_bank_update_params,
- memory_bank_retrieve_params,
+ memory_drop_params,
+ memory_query_params,
+ memory_create_params,
+ memory_insert_params,
+ memory_update_params,
+ memory_retrieve_params,
)
from ..._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven
from ..._utils import (
maybe_transform,
+ strip_not_given,
async_maybe_transform,
)
from ..._compat import cached_property
@@ -38,37 +39,38 @@
from ..._base_client import make_request_options
from ...types.query_documents import QueryDocuments
-__all__ = ["MemoryBanksResource", "AsyncMemoryBanksResource"]
+__all__ = ["MemoryResource", "AsyncMemoryResource"]
-class MemoryBanksResource(SyncAPIResource):
+class MemoryResource(SyncAPIResource):
@cached_property
def documents(self) -> DocumentsResource:
return DocumentsResource(self._client)
@cached_property
- def with_raw_response(self) -> MemoryBanksResourceWithRawResponse:
+ def with_raw_response(self) -> MemoryResourceWithRawResponse:
"""
        This property can be used as a prefix for any HTTP method call to return
        the raw response object instead of the parsed content.
For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers
"""
- return MemoryBanksResourceWithRawResponse(self)
+ return MemoryResourceWithRawResponse(self)
@cached_property
- def with_streaming_response(self) -> MemoryBanksResourceWithStreamingResponse:
+ def with_streaming_response(self) -> MemoryResourceWithStreamingResponse:
"""
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response
"""
- return MemoryBanksResourceWithStreamingResponse(self)
+ return MemoryResourceWithStreamingResponse(self)
def create(
self,
*,
body: object,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -86,9 +88,13 @@ def create(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._post(
- "/memory_banks/create",
- body=maybe_transform(body, memory_bank_create_params.MemoryBankCreateParams),
+ "/memory/create",
+ body=maybe_transform(body, memory_create_params.MemoryCreateParams),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
@@ -99,6 +105,7 @@ def retrieve(
self,
*,
bank_id: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -116,14 +123,18 @@ def retrieve(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._get(
- "/memory_banks/get",
+ "/memory/get",
options=make_request_options(
extra_headers=extra_headers,
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
- query=maybe_transform({"bank_id": bank_id}, memory_bank_retrieve_params.MemoryBankRetrieveParams),
+ query=maybe_transform({"bank_id": bank_id}, memory_retrieve_params.MemoryRetrieveParams),
),
cast_to=object,
)
@@ -132,7 +143,8 @@ def update(
self,
*,
bank_id: str,
- documents: Iterable[memory_bank_update_params.Document],
+ documents: Iterable[memory_update_params.Document],
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -151,14 +163,18 @@ def update(
timeout: Override the client-level default timeout for this request, in seconds
"""
extra_headers = {"Accept": "*/*", **(extra_headers or {})}
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._post(
- "/memory_bank/update",
+ "/memory/update",
body=maybe_transform(
{
"bank_id": bank_id,
"documents": documents,
},
- memory_bank_update_params.MemoryBankUpdateParams,
+ memory_update_params.MemoryUpdateParams,
),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
@@ -169,6 +185,7 @@ def update(
def list(
self,
*,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -176,9 +193,23 @@ def list(
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> object:
+ """
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})}
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._get(
- "/memory_banks/list",
+ "/memory/list",
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
@@ -189,6 +220,7 @@ def drop(
self,
*,
bank_id: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -206,9 +238,13 @@ def drop(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._post(
- "/memory_banks/drop",
- body=maybe_transform({"bank_id": bank_id}, memory_bank_drop_params.MemoryBankDropParams),
+ "/memory/drop",
+ body=maybe_transform({"bank_id": bank_id}, memory_drop_params.MemoryDropParams),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
@@ -219,8 +255,9 @@ def insert(
self,
*,
bank_id: str,
- documents: Iterable[memory_bank_insert_params.Document],
+ documents: Iterable[memory_insert_params.Document],
ttl_seconds: int | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -239,15 +276,19 @@ def insert(
timeout: Override the client-level default timeout for this request, in seconds
"""
extra_headers = {"Accept": "*/*", **(extra_headers or {})}
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._post(
- "/memory_bank/insert",
+ "/memory/insert",
body=maybe_transform(
{
"bank_id": bank_id,
"documents": documents,
"ttl_seconds": ttl_seconds,
},
- memory_bank_insert_params.MemoryBankInsertParams,
+ memory_insert_params.MemoryInsertParams,
),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
@@ -261,6 +302,7 @@ def query(
bank_id: str,
query: Union[str, List[str]],
params: Dict[str, Union[bool, float, str, Iterable[object], object, None]] | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -278,15 +320,19 @@ def query(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._post(
- "/memory_bank/query",
+ "/memory/query",
body=maybe_transform(
{
"bank_id": bank_id,
"query": query,
"params": params,
},
- memory_bank_query_params.MemoryBankQueryParams,
+ memory_query_params.MemoryQueryParams,
),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
@@ -295,34 +341,35 @@ def query(
)
-class AsyncMemoryBanksResource(AsyncAPIResource):
+class AsyncMemoryResource(AsyncAPIResource):
@cached_property
def documents(self) -> AsyncDocumentsResource:
return AsyncDocumentsResource(self._client)
@cached_property
- def with_raw_response(self) -> AsyncMemoryBanksResourceWithRawResponse:
+ def with_raw_response(self) -> AsyncMemoryResourceWithRawResponse:
"""
        This property can be used as a prefix for any HTTP method call to return
        the raw response object instead of the parsed content.
For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers
"""
- return AsyncMemoryBanksResourceWithRawResponse(self)
+ return AsyncMemoryResourceWithRawResponse(self)
@cached_property
- def with_streaming_response(self) -> AsyncMemoryBanksResourceWithStreamingResponse:
+ def with_streaming_response(self) -> AsyncMemoryResourceWithStreamingResponse:
"""
An alternative to `.with_raw_response` that doesn't eagerly read the response body.
For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response
"""
- return AsyncMemoryBanksResourceWithStreamingResponse(self)
+ return AsyncMemoryResourceWithStreamingResponse(self)
async def create(
self,
*,
body: object,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -340,9 +387,13 @@ async def create(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._post(
- "/memory_banks/create",
- body=await async_maybe_transform(body, memory_bank_create_params.MemoryBankCreateParams),
+ "/memory/create",
+ body=await async_maybe_transform(body, memory_create_params.MemoryCreateParams),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
@@ -353,6 +404,7 @@ async def retrieve(
self,
*,
bank_id: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -370,16 +422,18 @@ async def retrieve(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._get(
- "/memory_banks/get",
+ "/memory/get",
options=make_request_options(
extra_headers=extra_headers,
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
- query=await async_maybe_transform(
- {"bank_id": bank_id}, memory_bank_retrieve_params.MemoryBankRetrieveParams
- ),
+ query=await async_maybe_transform({"bank_id": bank_id}, memory_retrieve_params.MemoryRetrieveParams),
),
cast_to=object,
)
@@ -388,7 +442,8 @@ async def update(
self,
*,
bank_id: str,
- documents: Iterable[memory_bank_update_params.Document],
+ documents: Iterable[memory_update_params.Document],
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -407,14 +462,18 @@ async def update(
timeout: Override the client-level default timeout for this request, in seconds
"""
extra_headers = {"Accept": "*/*", **(extra_headers or {})}
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._post(
- "/memory_bank/update",
+ "/memory/update",
body=await async_maybe_transform(
{
"bank_id": bank_id,
"documents": documents,
},
- memory_bank_update_params.MemoryBankUpdateParams,
+ memory_update_params.MemoryUpdateParams,
),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
@@ -425,6 +484,7 @@ async def update(
async def list(
self,
*,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -432,9 +492,23 @@ async def list(
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> object:
+ """
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})}
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._get(
- "/memory_banks/list",
+ "/memory/list",
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
@@ -445,6 +519,7 @@ async def drop(
self,
*,
bank_id: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -462,9 +537,13 @@ async def drop(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._post(
- "/memory_banks/drop",
- body=await async_maybe_transform({"bank_id": bank_id}, memory_bank_drop_params.MemoryBankDropParams),
+ "/memory/drop",
+ body=await async_maybe_transform({"bank_id": bank_id}, memory_drop_params.MemoryDropParams),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
@@ -475,8 +554,9 @@ async def insert(
self,
*,
bank_id: str,
- documents: Iterable[memory_bank_insert_params.Document],
+ documents: Iterable[memory_insert_params.Document],
ttl_seconds: int | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -495,15 +575,19 @@ async def insert(
timeout: Override the client-level default timeout for this request, in seconds
"""
extra_headers = {"Accept": "*/*", **(extra_headers or {})}
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._post(
- "/memory_bank/insert",
+ "/memory/insert",
body=await async_maybe_transform(
{
"bank_id": bank_id,
"documents": documents,
"ttl_seconds": ttl_seconds,
},
- memory_bank_insert_params.MemoryBankInsertParams,
+ memory_insert_params.MemoryInsertParams,
),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
@@ -517,6 +601,7 @@ async def query(
bank_id: str,
query: Union[str, List[str]],
params: Dict[str, Union[bool, float, str, Iterable[object], object, None]] | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -534,15 +619,19 @@ async def query(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._post(
- "/memory_bank/query",
+ "/memory/query",
body=await async_maybe_transform(
{
"bank_id": bank_id,
"query": query,
"params": params,
},
- memory_bank_query_params.MemoryBankQueryParams,
+ memory_query_params.MemoryQueryParams,
),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
@@ -551,125 +640,125 @@ async def query(
)
-class MemoryBanksResourceWithRawResponse:
- def __init__(self, memory_banks: MemoryBanksResource) -> None:
- self._memory_banks = memory_banks
+class MemoryResourceWithRawResponse:
+ def __init__(self, memory: MemoryResource) -> None:
+ self._memory = memory
self.create = to_raw_response_wrapper(
- memory_banks.create,
+ memory.create,
)
self.retrieve = to_raw_response_wrapper(
- memory_banks.retrieve,
+ memory.retrieve,
)
self.update = to_raw_response_wrapper(
- memory_banks.update,
+ memory.update,
)
self.list = to_raw_response_wrapper(
- memory_banks.list,
+ memory.list,
)
self.drop = to_raw_response_wrapper(
- memory_banks.drop,
+ memory.drop,
)
self.insert = to_raw_response_wrapper(
- memory_banks.insert,
+ memory.insert,
)
self.query = to_raw_response_wrapper(
- memory_banks.query,
+ memory.query,
)
@cached_property
def documents(self) -> DocumentsResourceWithRawResponse:
- return DocumentsResourceWithRawResponse(self._memory_banks.documents)
+ return DocumentsResourceWithRawResponse(self._memory.documents)
-class AsyncMemoryBanksResourceWithRawResponse:
- def __init__(self, memory_banks: AsyncMemoryBanksResource) -> None:
- self._memory_banks = memory_banks
+class AsyncMemoryResourceWithRawResponse:
+ def __init__(self, memory: AsyncMemoryResource) -> None:
+ self._memory = memory
self.create = async_to_raw_response_wrapper(
- memory_banks.create,
+ memory.create,
)
self.retrieve = async_to_raw_response_wrapper(
- memory_banks.retrieve,
+ memory.retrieve,
)
self.update = async_to_raw_response_wrapper(
- memory_banks.update,
+ memory.update,
)
self.list = async_to_raw_response_wrapper(
- memory_banks.list,
+ memory.list,
)
self.drop = async_to_raw_response_wrapper(
- memory_banks.drop,
+ memory.drop,
)
self.insert = async_to_raw_response_wrapper(
- memory_banks.insert,
+ memory.insert,
)
self.query = async_to_raw_response_wrapper(
- memory_banks.query,
+ memory.query,
)
@cached_property
def documents(self) -> AsyncDocumentsResourceWithRawResponse:
- return AsyncDocumentsResourceWithRawResponse(self._memory_banks.documents)
+ return AsyncDocumentsResourceWithRawResponse(self._memory.documents)
-class MemoryBanksResourceWithStreamingResponse:
- def __init__(self, memory_banks: MemoryBanksResource) -> None:
- self._memory_banks = memory_banks
+class MemoryResourceWithStreamingResponse:
+ def __init__(self, memory: MemoryResource) -> None:
+ self._memory = memory
self.create = to_streamed_response_wrapper(
- memory_banks.create,
+ memory.create,
)
self.retrieve = to_streamed_response_wrapper(
- memory_banks.retrieve,
+ memory.retrieve,
)
self.update = to_streamed_response_wrapper(
- memory_banks.update,
+ memory.update,
)
self.list = to_streamed_response_wrapper(
- memory_banks.list,
+ memory.list,
)
self.drop = to_streamed_response_wrapper(
- memory_banks.drop,
+ memory.drop,
)
self.insert = to_streamed_response_wrapper(
- memory_banks.insert,
+ memory.insert,
)
self.query = to_streamed_response_wrapper(
- memory_banks.query,
+ memory.query,
)
@cached_property
def documents(self) -> DocumentsResourceWithStreamingResponse:
- return DocumentsResourceWithStreamingResponse(self._memory_banks.documents)
+ return DocumentsResourceWithStreamingResponse(self._memory.documents)
-class AsyncMemoryBanksResourceWithStreamingResponse:
- def __init__(self, memory_banks: AsyncMemoryBanksResource) -> None:
- self._memory_banks = memory_banks
+class AsyncMemoryResourceWithStreamingResponse:
+ def __init__(self, memory: AsyncMemoryResource) -> None:
+ self._memory = memory
self.create = async_to_streamed_response_wrapper(
- memory_banks.create,
+ memory.create,
)
self.retrieve = async_to_streamed_response_wrapper(
- memory_banks.retrieve,
+ memory.retrieve,
)
self.update = async_to_streamed_response_wrapper(
- memory_banks.update,
+ memory.update,
)
self.list = async_to_streamed_response_wrapper(
- memory_banks.list,
+ memory.list,
)
self.drop = async_to_streamed_response_wrapper(
- memory_banks.drop,
+ memory.drop,
)
self.insert = async_to_streamed_response_wrapper(
- memory_banks.insert,
+ memory.insert,
)
self.query = async_to_streamed_response_wrapper(
- memory_banks.query,
+ memory.query,
)
@cached_property
def documents(self) -> AsyncDocumentsResourceWithStreamingResponse:
- return AsyncDocumentsResourceWithStreamingResponse(self._memory_banks.documents)
+ return AsyncDocumentsResourceWithStreamingResponse(self._memory.documents)
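
Putting the renamed resource together, a hedged end-to-end sketch; `create()` takes an opaque `body: object`, so the payload shapes shown are assumptions:

```python
# Payload shapes are assumptions; only the parameter names come from the diff.
client.memory.create(body={"name": "notes"})
client.memory.insert(
    bank_id="bank-1",
    documents=[{"document_id": "doc-1", "content": "hello world"}],
    ttl_seconds=3600,
)
hits = client.memory.query(bank_id="bank-1", query="hello")
client.memory.drop(bank_id="bank-1")
```
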
diff --git a/src/llama_stack/resources/memory_banks.py b/src/llama_stack/resources/memory_banks.py
new file mode 100644
index 0000000..294c309
--- /dev/null
+++ b/src/llama_stack/resources/memory_banks.py
@@ -0,0 +1,262 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from __future__ import annotations
+
+from typing import Optional
+from typing_extensions import Literal
+
+import httpx
+
+from ..types import memory_bank_get_params
+from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
+from .._utils import (
+ maybe_transform,
+ strip_not_given,
+ async_maybe_transform,
+)
+from .._compat import cached_property
+from .._resource import SyncAPIResource, AsyncAPIResource
+from .._response import (
+ to_raw_response_wrapper,
+ to_streamed_response_wrapper,
+ async_to_raw_response_wrapper,
+ async_to_streamed_response_wrapper,
+)
+from .._base_client import make_request_options
+from ..types.memory_bank_spec import MemoryBankSpec
+
+__all__ = ["MemoryBanksResource", "AsyncMemoryBanksResource"]
+
+
+class MemoryBanksResource(SyncAPIResource):
+ @cached_property
+ def with_raw_response(self) -> MemoryBanksResourceWithRawResponse:
+ """
+        This property can be used as a prefix for any HTTP method call to return
+        the raw response object instead of the parsed content.
+
+ For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers
+ """
+ return MemoryBanksResourceWithRawResponse(self)
+
+ @cached_property
+ def with_streaming_response(self) -> MemoryBanksResourceWithStreamingResponse:
+ """
+ An alternative to `.with_raw_response` that doesn't eagerly read the response body.
+
+ For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response
+ """
+ return MemoryBanksResourceWithStreamingResponse(self)
+
+ def list(
+ self,
+ *,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> MemoryBankSpec:
+ """
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})}
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
+ return self._get(
+ "/memory_banks/list",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=MemoryBankSpec,
+ )
+
+ def get(
+ self,
+ *,
+ bank_type: Literal["vector", "keyvalue", "keyword", "graph"],
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> Optional[MemoryBankSpec]:
+ """
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
+ return self._get(
+ "/memory_banks/get",
+ options=make_request_options(
+ extra_headers=extra_headers,
+ extra_query=extra_query,
+ extra_body=extra_body,
+ timeout=timeout,
+ query=maybe_transform({"bank_type": bank_type}, memory_bank_get_params.MemoryBankGetParams),
+ ),
+ cast_to=MemoryBankSpec,
+ )
+
+
+class AsyncMemoryBanksResource(AsyncAPIResource):
+ @cached_property
+ def with_raw_response(self) -> AsyncMemoryBanksResourceWithRawResponse:
+ """
+        This property can be used as a prefix for any HTTP method call to return
+        the raw response object instead of the parsed content.
+
+ For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers
+ """
+ return AsyncMemoryBanksResourceWithRawResponse(self)
+
+ @cached_property
+ def with_streaming_response(self) -> AsyncMemoryBanksResourceWithStreamingResponse:
+ """
+ An alternative to `.with_raw_response` that doesn't eagerly read the response body.
+
+ For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response
+ """
+ return AsyncMemoryBanksResourceWithStreamingResponse(self)
+
+ async def list(
+ self,
+ *,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> MemoryBankSpec:
+ """
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})}
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
+ return await self._get(
+ "/memory_banks/list",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=MemoryBankSpec,
+ )
+
+ async def get(
+ self,
+ *,
+ bank_type: Literal["vector", "keyvalue", "keyword", "graph"],
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> Optional[MemoryBankSpec]:
+ """
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
+ return await self._get(
+ "/memory_banks/get",
+ options=make_request_options(
+ extra_headers=extra_headers,
+ extra_query=extra_query,
+ extra_body=extra_body,
+ timeout=timeout,
+ query=await async_maybe_transform({"bank_type": bank_type}, memory_bank_get_params.MemoryBankGetParams),
+ ),
+ cast_to=MemoryBankSpec,
+ )
+
+
+class MemoryBanksResourceWithRawResponse:
+ def __init__(self, memory_banks: MemoryBanksResource) -> None:
+ self._memory_banks = memory_banks
+
+ self.list = to_raw_response_wrapper(
+ memory_banks.list,
+ )
+ self.get = to_raw_response_wrapper(
+ memory_banks.get,
+ )
+
+
+class AsyncMemoryBanksResourceWithRawResponse:
+ def __init__(self, memory_banks: AsyncMemoryBanksResource) -> None:
+ self._memory_banks = memory_banks
+
+ self.list = async_to_raw_response_wrapper(
+ memory_banks.list,
+ )
+ self.get = async_to_raw_response_wrapper(
+ memory_banks.get,
+ )
+
+
+class MemoryBanksResourceWithStreamingResponse:
+ def __init__(self, memory_banks: MemoryBanksResource) -> None:
+ self._memory_banks = memory_banks
+
+ self.list = to_streamed_response_wrapper(
+ memory_banks.list,
+ )
+ self.get = to_streamed_response_wrapper(
+ memory_banks.get,
+ )
+
+
+class AsyncMemoryBanksResourceWithStreamingResponse:
+ def __init__(self, memory_banks: AsyncMemoryBanksResource) -> None:
+ self._memory_banks = memory_banks
+
+ self.list = async_to_streamed_response_wrapper(
+ memory_banks.list,
+ )
+ self.get = async_to_streamed_response_wrapper(
+ memory_banks.get,
+ )
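The new `MemoryBanks` resource is read-only registry access: `list` enumerates the registered bank specs and `get` looks one up by bank type. A minimal usage sketch, assuming the package exposes a `LlamaStack` client class (the client name and constructor are illustrative, not confirmed by this diff):

```python
from llama_stack import LlamaStack  # assumed client entry point

client = LlamaStack()

# Enumerate registered memory bank specs.
banks = client.memory_banks.list()

# Fetch a single spec by bank type; returns None when nothing matches.
spec = client.memory_banks.get(bank_type="vector")
```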
diff --git a/src/llama_stack/resources/models.py b/src/llama_stack/resources/models.py
new file mode 100644
index 0000000..29f435c
--- /dev/null
+++ b/src/llama_stack/resources/models.py
@@ -0,0 +1,261 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from __future__ import annotations
+
+from typing import Optional
+
+import httpx
+
+from ..types import model_get_params
+from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
+from .._utils import (
+ maybe_transform,
+ strip_not_given,
+ async_maybe_transform,
+)
+from .._compat import cached_property
+from .._resource import SyncAPIResource, AsyncAPIResource
+from .._response import (
+ to_raw_response_wrapper,
+ to_streamed_response_wrapper,
+ async_to_raw_response_wrapper,
+ async_to_streamed_response_wrapper,
+)
+from .._base_client import make_request_options
+from ..types.model_serving_spec import ModelServingSpec
+
+__all__ = ["ModelsResource", "AsyncModelsResource"]
+
+
+class ModelsResource(SyncAPIResource):
+ @cached_property
+ def with_raw_response(self) -> ModelsResourceWithRawResponse:
+ """
+ This property can be used as a prefix for any HTTP method call to return
+ the raw response object instead of the parsed content.
+
+ For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers
+ """
+ return ModelsResourceWithRawResponse(self)
+
+ @cached_property
+ def with_streaming_response(self) -> ModelsResourceWithStreamingResponse:
+ """
+ An alternative to `.with_raw_response` that doesn't eagerly read the response body.
+
+ For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response
+ """
+ return ModelsResourceWithStreamingResponse(self)
+
+ def list(
+ self,
+ *,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> ModelServingSpec:
+ """
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})}
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
+ return self._get(
+ "/models/list",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=ModelServingSpec,
+ )
+
+ def get(
+ self,
+ *,
+ core_model_id: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> Optional[ModelServingSpec]:
+ """
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
+ return self._get(
+ "/models/get",
+ options=make_request_options(
+ extra_headers=extra_headers,
+ extra_query=extra_query,
+ extra_body=extra_body,
+ timeout=timeout,
+ query=maybe_transform({"core_model_id": core_model_id}, model_get_params.ModelGetParams),
+ ),
+ cast_to=ModelServingSpec,
+ )
+
+
+class AsyncModelsResource(AsyncAPIResource):
+ @cached_property
+ def with_raw_response(self) -> AsyncModelsResourceWithRawResponse:
+ """
+ This property can be used as a prefix for any HTTP method call to return
+ the raw response object instead of the parsed content.
+
+ For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers
+ """
+ return AsyncModelsResourceWithRawResponse(self)
+
+ @cached_property
+ def with_streaming_response(self) -> AsyncModelsResourceWithStreamingResponse:
+ """
+ An alternative to `.with_raw_response` that doesn't eagerly read the response body.
+
+ For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response
+ """
+ return AsyncModelsResourceWithStreamingResponse(self)
+
+ async def list(
+ self,
+ *,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> ModelServingSpec:
+ """
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})}
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
+ return await self._get(
+ "/models/list",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=ModelServingSpec,
+ )
+
+ async def get(
+ self,
+ *,
+ core_model_id: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> Optional[ModelServingSpec]:
+ """
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
+ return await self._get(
+ "/models/get",
+ options=make_request_options(
+ extra_headers=extra_headers,
+ extra_query=extra_query,
+ extra_body=extra_body,
+ timeout=timeout,
+ query=await async_maybe_transform({"core_model_id": core_model_id}, model_get_params.ModelGetParams),
+ ),
+ cast_to=ModelServingSpec,
+ )
+
+
+class ModelsResourceWithRawResponse:
+ def __init__(self, models: ModelsResource) -> None:
+ self._models = models
+
+ self.list = to_raw_response_wrapper(
+ models.list,
+ )
+ self.get = to_raw_response_wrapper(
+ models.get,
+ )
+
+
+class AsyncModelsResourceWithRawResponse:
+ def __init__(self, models: AsyncModelsResource) -> None:
+ self._models = models
+
+ self.list = async_to_raw_response_wrapper(
+ models.list,
+ )
+ self.get = async_to_raw_response_wrapper(
+ models.get,
+ )
+
+
+class ModelsResourceWithStreamingResponse:
+ def __init__(self, models: ModelsResource) -> None:
+ self._models = models
+
+ self.list = to_streamed_response_wrapper(
+ models.list,
+ )
+ self.get = to_streamed_response_wrapper(
+ models.get,
+ )
+
+
+class AsyncModelsResourceWithStreamingResponse:
+ def __init__(self, models: AsyncModelsResource) -> None:
+ self._models = models
+
+ self.list = async_to_streamed_response_wrapper(
+ models.list,
+ )
+ self.get = async_to_streamed_response_wrapper(
+ models.get,
+ )
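The `Models` resource follows the same registry pattern, keyed by `core_model_id`. A sketch under the same client assumption (the model id shown is illustrative):

```python
# Registry lookups for served models.
models = client.models.list()

# `get` filters by core model id and may return None.
model = client.models.get(core_model_id="Meta-Llama3.1-8B-Instruct")
```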
diff --git a/src/llama_stack/resources/post_training/jobs.py b/src/llama_stack/resources/post_training/jobs.py
index e3e4632..840b2e7 100644
--- a/src/llama_stack/resources/post_training/jobs.py
+++ b/src/llama_stack/resources/post_training/jobs.py
@@ -7,6 +7,7 @@
from ..._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven
from ..._utils import (
maybe_transform,
+ strip_not_given,
async_maybe_transform,
)
from ..._compat import cached_property
@@ -50,6 +51,7 @@ def with_streaming_response(self) -> JobsResourceWithStreamingResponse:
def list(
self,
*,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -57,7 +59,21 @@ def list(
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> PostTrainingJob:
+ """
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})}
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._get(
"/post_training/jobs",
options=make_request_options(
@@ -70,6 +86,7 @@ def artifacts(
self,
*,
job_uuid: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -87,6 +104,10 @@ def artifacts(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._get(
"/post_training/job/artifacts",
options=make_request_options(
@@ -103,6 +124,7 @@ def cancel(
self,
*,
job_uuid: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -121,6 +143,10 @@ def cancel(
timeout: Override the client-level default timeout for this request, in seconds
"""
extra_headers = {"Accept": "*/*", **(extra_headers or {})}
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._post(
"/post_training/job/cancel",
body=maybe_transform({"job_uuid": job_uuid}, job_cancel_params.JobCancelParams),
@@ -134,6 +160,7 @@ def logs(
self,
*,
job_uuid: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -151,6 +178,10 @@ def logs(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._get(
"/post_training/job/logs",
options=make_request_options(
@@ -167,6 +198,7 @@ def status(
self,
*,
job_uuid: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -184,6 +216,10 @@ def status(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._get(
"/post_training/job/status",
options=make_request_options(
@@ -220,6 +256,7 @@ def with_streaming_response(self) -> AsyncJobsResourceWithStreamingResponse:
async def list(
self,
*,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -227,7 +264,21 @@ async def list(
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> PostTrainingJob:
+ """
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})}
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._get(
"/post_training/jobs",
options=make_request_options(
@@ -240,6 +291,7 @@ async def artifacts(
self,
*,
job_uuid: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -257,6 +309,10 @@ async def artifacts(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._get(
"/post_training/job/artifacts",
options=make_request_options(
@@ -273,6 +329,7 @@ async def cancel(
self,
*,
job_uuid: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -291,6 +348,10 @@ async def cancel(
timeout: Override the client-level default timeout for this request, in seconds
"""
extra_headers = {"Accept": "*/*", **(extra_headers or {})}
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._post(
"/post_training/job/cancel",
body=await async_maybe_transform({"job_uuid": job_uuid}, job_cancel_params.JobCancelParams),
@@ -304,6 +365,7 @@ async def logs(
self,
*,
job_uuid: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -321,6 +383,10 @@ async def logs(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._get(
"/post_training/job/logs",
options=make_request_options(
@@ -337,6 +403,7 @@ async def status(
self,
*,
job_uuid: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -354,6 +421,10 @@ async def status(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._get(
"/post_training/job/status",
options=make_request_options(
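Every job method on `post_training.jobs` now threads an optional `x_llama_stack_provider_data` argument into the `X-LlamaStack-ProviderData` request header. A hedged sketch (the job UUID and header payload are illustrative):

```python
status = client.post_training.jobs.status(
    job_uuid="job-1234",  # illustrative id
    x_llama_stack_provider_data='{"api_key": "..."}',  # illustrative payload
)

# Omitting the argument sends no header at all (see the strip_not_given sketch below).
jobs = client.post_training.jobs.list()
```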
diff --git a/src/llama_stack/resources/post_training/post_training.py b/src/llama_stack/resources/post_training/post_training.py
index db5aac8..8863e6b 100644
--- a/src/llama_stack/resources/post_training/post_training.py
+++ b/src/llama_stack/resources/post_training/post_training.py
@@ -22,6 +22,7 @@
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from ..._utils import (
maybe_transform,
+ strip_not_given,
async_maybe_transform,
)
from ..._compat import cached_property
@@ -76,6 +77,7 @@ def preference_optimize(
optimizer_config: post_training_preference_optimize_params.OptimizerConfig,
training_config: post_training_preference_optimize_params.TrainingConfig,
validation_dataset: TrainEvalDatasetParam,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -93,6 +95,10 @@ def preference_optimize(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._post(
"/post_training/preference_optimize",
body=maybe_transform(
@@ -129,6 +135,7 @@ def supervised_fine_tune(
optimizer_config: post_training_supervised_fine_tune_params.OptimizerConfig,
training_config: post_training_supervised_fine_tune_params.TrainingConfig,
validation_dataset: TrainEvalDatasetParam,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -146,6 +153,10 @@ def supervised_fine_tune(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._post(
"/post_training/supervised_fine_tune",
body=maybe_transform(
@@ -207,6 +218,7 @@ async def preference_optimize(
optimizer_config: post_training_preference_optimize_params.OptimizerConfig,
training_config: post_training_preference_optimize_params.TrainingConfig,
validation_dataset: TrainEvalDatasetParam,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -224,6 +236,10 @@ async def preference_optimize(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._post(
"/post_training/preference_optimize",
body=await async_maybe_transform(
@@ -260,6 +276,7 @@ async def supervised_fine_tune(
optimizer_config: post_training_supervised_fine_tune_params.OptimizerConfig,
training_config: post_training_supervised_fine_tune_params.TrainingConfig,
validation_dataset: TrainEvalDatasetParam,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -277,6 +294,10 @@ async def supervised_fine_tune(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._post(
"/post_training/supervised_fine_tune",
body=await async_maybe_transform(
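The repeated header merge relies on `strip_not_given`, which drops entries whose value is the `NOT_GIVEN` sentinel, so an omitted argument produces no header rather than an empty one. A behavioral sketch against the package's internal helpers (`_utils`/`_types` are private modules, imported here for illustration only):

```python
from llama_stack._types import NOT_GIVEN
from llama_stack._utils import strip_not_given

# The sentinel entry is removed entirely...
assert strip_not_given({"X-LlamaStack-ProviderData": NOT_GIVEN}) == {}

# ...while real values pass through untouched.
assert strip_not_given({"X-LlamaStack-ProviderData": "token"}) == {
    "X-LlamaStack-ProviderData": "token"
}
```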
diff --git a/src/llama_stack/resources/reward_scoring.py b/src/llama_stack/resources/reward_scoring.py
index bed1479..3e55287 100644
--- a/src/llama_stack/resources/reward_scoring.py
+++ b/src/llama_stack/resources/reward_scoring.py
@@ -10,6 +10,7 @@
from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from .._utils import (
maybe_transform,
+ strip_not_given,
async_maybe_transform,
)
from .._compat import cached_property
@@ -51,6 +52,7 @@ def score(
*,
dialog_generations: Iterable[reward_scoring_score_params.DialogGeneration],
model: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -68,6 +70,10 @@ def score(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._post(
"/reward_scoring/score",
body=maybe_transform(
@@ -109,6 +115,7 @@ async def score(
*,
dialog_generations: Iterable[reward_scoring_score_params.DialogGeneration],
model: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -126,6 +133,10 @@ async def score(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._post(
"/reward_scoring/score",
body=await async_maybe_transform(
diff --git a/src/llama_stack/resources/safety.py b/src/llama_stack/resources/safety.py
index 3943f13..2bc3022 100644
--- a/src/llama_stack/resources/safety.py
+++ b/src/llama_stack/resources/safety.py
@@ -2,14 +2,15 @@
from __future__ import annotations
-from typing import Iterable
+from typing import Dict, Union, Iterable
import httpx
-from ..types import safety_run_shields_params
+from ..types import safety_run_shield_params
from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from .._utils import (
maybe_transform,
+ strip_not_given,
async_maybe_transform,
)
from .._compat import cached_property
@@ -21,8 +22,7 @@
async_to_streamed_response_wrapper,
)
from .._base_client import make_request_options
-from ..types.shield_definition_param import ShieldDefinitionParam
-from ..types.safety_run_shields_response import SafetyRunShieldsResponse
+from ..types.run_sheid_response import RunSheidResponse
__all__ = ["SafetyResource", "AsyncSafetyResource"]
@@ -47,18 +47,20 @@ def with_streaming_response(self) -> SafetyResourceWithStreamingResponse:
"""
return SafetyResourceWithStreamingResponse(self)
- def run_shields(
+ def run_shield(
self,
*,
- messages: Iterable[safety_run_shields_params.Message],
- shields: Iterable[ShieldDefinitionParam],
+ messages: Iterable[safety_run_shield_params.Message],
+ params: Dict[str, Union[bool, float, str, Iterable[object], object, None]],
+ shield_type: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> SafetyRunShieldsResponse:
+ ) -> RunSheidResponse:
"""
Args:
extra_headers: Send extra headers
@@ -69,19 +71,24 @@ def run_shields(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._post(
- "/safety/run_shields",
+ "/safety/run_shield",
body=maybe_transform(
{
"messages": messages,
- "shields": shields,
+ "params": params,
+ "shield_type": shield_type,
},
- safety_run_shields_params.SafetyRunShieldsParams,
+ safety_run_shield_params.SafetyRunShieldParams,
),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
- cast_to=SafetyRunShieldsResponse,
+ cast_to=RunSheidResponse,
)
@@ -105,18 +112,20 @@ def with_streaming_response(self) -> AsyncSafetyResourceWithStreamingResponse:
"""
return AsyncSafetyResourceWithStreamingResponse(self)
- async def run_shields(
+ async def run_shield(
self,
*,
- messages: Iterable[safety_run_shields_params.Message],
- shields: Iterable[ShieldDefinitionParam],
+ messages: Iterable[safety_run_shield_params.Message],
+ params: Dict[str, Union[bool, float, str, Iterable[object], object, None]],
+ shield_type: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> SafetyRunShieldsResponse:
+ ) -> RunSheidResponse:
"""
Args:
extra_headers: Send extra headers
@@ -127,19 +136,24 @@ async def run_shields(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._post(
- "/safety/run_shields",
+ "/safety/run_shield",
body=await async_maybe_transform(
{
"messages": messages,
- "shields": shields,
+ "params": params,
+ "shield_type": shield_type,
},
- safety_run_shields_params.SafetyRunShieldsParams,
+ safety_run_shield_params.SafetyRunShieldParams,
),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
- cast_to=SafetyRunShieldsResponse,
+ cast_to=RunSheidResponse,
)
@@ -147,8 +161,8 @@ class SafetyResourceWithRawResponse:
def __init__(self, safety: SafetyResource) -> None:
self._safety = safety
- self.run_shields = to_raw_response_wrapper(
- safety.run_shields,
+ self.run_shield = to_raw_response_wrapper(
+ safety.run_shield,
)
@@ -156,8 +170,8 @@ class AsyncSafetyResourceWithRawResponse:
def __init__(self, safety: AsyncSafetyResource) -> None:
self._safety = safety
- self.run_shields = async_to_raw_response_wrapper(
- safety.run_shields,
+ self.run_shield = async_to_raw_response_wrapper(
+ safety.run_shield,
)
@@ -165,8 +179,8 @@ class SafetyResourceWithStreamingResponse:
def __init__(self, safety: SafetyResource) -> None:
self._safety = safety
- self.run_shields = to_streamed_response_wrapper(
- safety.run_shields,
+ self.run_shield = to_streamed_response_wrapper(
+ safety.run_shield,
)
@@ -174,6 +188,6 @@ class AsyncSafetyResourceWithStreamingResponse:
def __init__(self, safety: AsyncSafetyResource) -> None:
self._safety = safety
- self.run_shields = async_to_streamed_response_wrapper(
- safety.run_shields,
+ self.run_shield = async_to_streamed_response_wrapper(
+ safety.run_shield,
)
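The renamed `run_shield` endpoint takes a single `shield_type` plus a free-form `params` dict instead of a list of `ShieldDefinition` objects. A sketch (the shield name and message shape are illustrative):

```python
response = client.safety.run_shield(
    shield_type="llama_guard",  # illustrative shield identifier
    messages=[{"role": "user", "content": "Is this message safe?"}],
    params={},
)
```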
diff --git a/src/llama_stack/resources/shields.py b/src/llama_stack/resources/shields.py
new file mode 100644
index 0000000..bc800de
--- /dev/null
+++ b/src/llama_stack/resources/shields.py
@@ -0,0 +1,261 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from __future__ import annotations
+
+from typing import Optional
+
+import httpx
+
+from ..types import shield_get_params
+from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
+from .._utils import (
+ maybe_transform,
+ strip_not_given,
+ async_maybe_transform,
+)
+from .._compat import cached_property
+from .._resource import SyncAPIResource, AsyncAPIResource
+from .._response import (
+ to_raw_response_wrapper,
+ to_streamed_response_wrapper,
+ async_to_raw_response_wrapper,
+ async_to_streamed_response_wrapper,
+)
+from .._base_client import make_request_options
+from ..types.shield_spec import ShieldSpec
+
+__all__ = ["ShieldsResource", "AsyncShieldsResource"]
+
+
+class ShieldsResource(SyncAPIResource):
+ @cached_property
+ def with_raw_response(self) -> ShieldsResourceWithRawResponse:
+ """
+ This property can be used as a prefix for any HTTP method call to return
+ the raw response object instead of the parsed content.
+
+ For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers
+ """
+ return ShieldsResourceWithRawResponse(self)
+
+ @cached_property
+ def with_streaming_response(self) -> ShieldsResourceWithStreamingResponse:
+ """
+ An alternative to `.with_raw_response` that doesn't eagerly read the response body.
+
+ For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response
+ """
+ return ShieldsResourceWithStreamingResponse(self)
+
+ def list(
+ self,
+ *,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> ShieldSpec:
+ """
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})}
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
+ return self._get(
+ "/shields/list",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=ShieldSpec,
+ )
+
+ def get(
+ self,
+ *,
+ shield_type: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> Optional[ShieldSpec]:
+ """
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
+ return self._get(
+ "/shields/get",
+ options=make_request_options(
+ extra_headers=extra_headers,
+ extra_query=extra_query,
+ extra_body=extra_body,
+ timeout=timeout,
+ query=maybe_transform({"shield_type": shield_type}, shield_get_params.ShieldGetParams),
+ ),
+ cast_to=ShieldSpec,
+ )
+
+
+class AsyncShieldsResource(AsyncAPIResource):
+ @cached_property
+ def with_raw_response(self) -> AsyncShieldsResourceWithRawResponse:
+ """
+ This property can be used as a prefix for any HTTP method call to return
+ the raw response object instead of the parsed content.
+
+ For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers
+ """
+ return AsyncShieldsResourceWithRawResponse(self)
+
+ @cached_property
+ def with_streaming_response(self) -> AsyncShieldsResourceWithStreamingResponse:
+ """
+ An alternative to `.with_raw_response` that doesn't eagerly read the response body.
+
+ For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response
+ """
+ return AsyncShieldsResourceWithStreamingResponse(self)
+
+ async def list(
+ self,
+ *,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> ShieldSpec:
+ """
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})}
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
+ return await self._get(
+ "/shields/list",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=ShieldSpec,
+ )
+
+ async def get(
+ self,
+ *,
+ shield_type: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> Optional[ShieldSpec]:
+ """
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
+ return await self._get(
+ "/shields/get",
+ options=make_request_options(
+ extra_headers=extra_headers,
+ extra_query=extra_query,
+ extra_body=extra_body,
+ timeout=timeout,
+ query=await async_maybe_transform({"shield_type": shield_type}, shield_get_params.ShieldGetParams),
+ ),
+ cast_to=ShieldSpec,
+ )
+
+
+class ShieldsResourceWithRawResponse:
+ def __init__(self, shields: ShieldsResource) -> None:
+ self._shields = shields
+
+ self.list = to_raw_response_wrapper(
+ shields.list,
+ )
+ self.get = to_raw_response_wrapper(
+ shields.get,
+ )
+
+
+class AsyncShieldsResourceWithRawResponse:
+ def __init__(self, shields: AsyncShieldsResource) -> None:
+ self._shields = shields
+
+ self.list = async_to_raw_response_wrapper(
+ shields.list,
+ )
+ self.get = async_to_raw_response_wrapper(
+ shields.get,
+ )
+
+
+class ShieldsResourceWithStreamingResponse:
+ def __init__(self, shields: ShieldsResource) -> None:
+ self._shields = shields
+
+ self.list = to_streamed_response_wrapper(
+ shields.list,
+ )
+ self.get = to_streamed_response_wrapper(
+ shields.get,
+ )
+
+
+class AsyncShieldsResourceWithStreamingResponse:
+ def __init__(self, shields: AsyncShieldsResource) -> None:
+ self._shields = shields
+
+ self.list = async_to_streamed_response_wrapper(
+ shields.list,
+ )
+ self.get = async_to_streamed_response_wrapper(
+ shields.get,
+ )
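`Shields` completes the trio of registry resources, keyed by a free-form `shield_type` string rather than an enum. A sketch (the shield name is illustrative):

```python
shields = client.shields.list()
shield = client.shields.get(shield_type="llama_guard")  # may return None
```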
diff --git a/src/llama_stack/resources/synthetic_data_generation.py b/src/llama_stack/resources/synthetic_data_generation.py
index 0847308..d13532c 100644
--- a/src/llama_stack/resources/synthetic_data_generation.py
+++ b/src/llama_stack/resources/synthetic_data_generation.py
@@ -11,6 +11,7 @@
from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from .._utils import (
maybe_transform,
+ strip_not_given,
async_maybe_transform,
)
from .._compat import cached_property
@@ -53,6 +54,7 @@ def generate(
dialogs: Iterable[synthetic_data_generation_generate_params.Dialog],
filtering_function: Literal["none", "random", "top_k", "top_p", "top_k_top_p", "sigmoid"],
model: str | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -70,6 +72,10 @@ def generate(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._post(
"/synthetic_data_generation/generate",
body=maybe_transform(
@@ -113,6 +119,7 @@ async def generate(
dialogs: Iterable[synthetic_data_generation_generate_params.Dialog],
filtering_function: Literal["none", "random", "top_k", "top_p", "top_k_top_p", "sigmoid"],
model: str | NotGiven = NOT_GIVEN,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -130,6 +137,10 @@ async def generate(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._post(
"/synthetic_data_generation/generate",
body=await async_maybe_transform(
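For `synthetic_data_generation.generate`, the diff shows only the new header argument; the dialog item shape below is an assumption for illustration (the real type is `synthetic_data_generation_generate_params.Dialog`):

```python
result = client.synthetic_data_generation.generate(
    dialogs=[{"role": "user", "content": "Write a haiku about the sea."}],  # assumed shape
    filtering_function="top_k",
)
```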
diff --git a/src/llama_stack/resources/telemetry.py b/src/llama_stack/resources/telemetry.py
index b3e0524..4526dd7 100644
--- a/src/llama_stack/resources/telemetry.py
+++ b/src/llama_stack/resources/telemetry.py
@@ -8,6 +8,7 @@
from .._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven
from .._utils import (
maybe_transform,
+ strip_not_given,
async_maybe_transform,
)
from .._compat import cached_property
@@ -48,6 +49,7 @@ def get_trace(
self,
*,
trace_id: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -65,6 +67,10 @@ def get_trace(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._get(
"/telemetry/get_trace",
options=make_request_options(
@@ -81,6 +87,7 @@ def log(
self,
*,
event: telemetry_log_params.Event,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -99,6 +106,10 @@ def log(
timeout: Override the client-level default timeout for this request, in seconds
"""
extra_headers = {"Accept": "*/*", **(extra_headers or {})}
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return self._post(
"/telemetry/log_event",
body=maybe_transform({"event": event}, telemetry_log_params.TelemetryLogParams),
@@ -133,6 +144,7 @@ async def get_trace(
self,
*,
trace_id: str,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -150,6 +162,10 @@ async def get_trace(
timeout: Override the client-level default timeout for this request, in seconds
"""
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._get(
"/telemetry/get_trace",
options=make_request_options(
@@ -168,6 +184,7 @@ async def log(
self,
*,
event: telemetry_log_params.Event,
+ x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
@@ -186,6 +203,10 @@ async def log(
timeout: Override the client-level default timeout for this request, in seconds
"""
extra_headers = {"Accept": "*/*", **(extra_headers or {})}
+ extra_headers = {
+ **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}),
+ **(extra_headers or {}),
+ }
return await self._post(
"/telemetry/log_event",
body=await async_maybe_transform({"event": event}, telemetry_log_params.TelemetryLogParams),
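Telemetry gains the same per-request provider data; fetching a trace is otherwise unchanged (the trace id is illustrative):

```python
trace = client.telemetry.get_trace(trace_id="trace-abc123")
```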
diff --git a/src/llama_stack/types/__init__.py b/src/llama_stack/types/__init__.py
index 8f43ce3..452da12 100644
--- a/src/llama_stack/types/__init__.py
+++ b/src/llama_stack/types/__init__.py
@@ -12,42 +12,46 @@
CompletionMessage as CompletionMessage,
ToolResponseMessage as ToolResponseMessage,
)
+from .shield_spec import ShieldSpec as ShieldSpec
from .evaluation_job import EvaluationJob as EvaluationJob
from .inference_step import InferenceStep as InferenceStep
from .reward_scoring import RewardScoring as RewardScoring
-from .sheid_response import SheidResponse as SheidResponse
from .query_documents import QueryDocuments as QueryDocuments
from .token_log_probs import TokenLogProbs as TokenLogProbs
+from .memory_bank_spec import MemoryBankSpec as MemoryBankSpec
+from .model_get_params import ModelGetParams as ModelGetParams
from .shield_call_step import ShieldCallStep as ShieldCallStep
from .post_training_job import PostTrainingJob as PostTrainingJob
+from .shield_get_params import ShieldGetParams as ShieldGetParams
from .dataset_get_params import DatasetGetParams as DatasetGetParams
+from .memory_drop_params import MemoryDropParams as MemoryDropParams
+from .model_serving_spec import ModelServingSpec as ModelServingSpec
+from .run_sheid_response import RunSheidResponse as RunSheidResponse
from .train_eval_dataset import TrainEvalDataset as TrainEvalDataset
from .agent_create_params import AgentCreateParams as AgentCreateParams
from .agent_delete_params import AgentDeleteParams as AgentDeleteParams
+from .memory_query_params import MemoryQueryParams as MemoryQueryParams
from .tool_execution_step import ToolExecutionStep as ToolExecutionStep
+from .memory_create_params import MemoryCreateParams as MemoryCreateParams
+from .memory_drop_response import MemoryDropResponse as MemoryDropResponse
+from .memory_insert_params import MemoryInsertParams as MemoryInsertParams
+from .memory_update_params import MemoryUpdateParams as MemoryUpdateParams
from .telemetry_log_params import TelemetryLogParams as TelemetryLogParams
from .agent_create_response import AgentCreateResponse as AgentCreateResponse
from .batch_chat_completion import BatchChatCompletion as BatchChatCompletion
from .dataset_create_params import DatasetCreateParams as DatasetCreateParams
from .dataset_delete_params import DatasetDeleteParams as DatasetDeleteParams
from .memory_retrieval_step import MemoryRetrievalStep as MemoryRetrievalStep
+from .memory_bank_get_params import MemoryBankGetParams as MemoryBankGetParams
+from .memory_retrieve_params import MemoryRetrieveParams as MemoryRetrieveParams
from .completion_stream_chunk import CompletionStreamChunk as CompletionStreamChunk
-from .memory_bank_drop_params import MemoryBankDropParams as MemoryBankDropParams
-from .shield_definition_param import ShieldDefinitionParam as ShieldDefinitionParam
-from .memory_bank_query_params import MemoryBankQueryParams as MemoryBankQueryParams
+from .safety_run_shield_params import SafetyRunShieldParams as SafetyRunShieldParams
from .train_eval_dataset_param import TrainEvalDatasetParam as TrainEvalDatasetParam
-from .memory_bank_create_params import MemoryBankCreateParams as MemoryBankCreateParams
-from .memory_bank_drop_response import MemoryBankDropResponse as MemoryBankDropResponse
-from .memory_bank_insert_params import MemoryBankInsertParams as MemoryBankInsertParams
-from .memory_bank_update_params import MemoryBankUpdateParams as MemoryBankUpdateParams
-from .safety_run_shields_params import SafetyRunShieldsParams as SafetyRunShieldsParams
from .scored_dialog_generations import ScoredDialogGenerations as ScoredDialogGenerations
from .synthetic_data_generation import SyntheticDataGeneration as SyntheticDataGeneration
from .telemetry_get_trace_params import TelemetryGetTraceParams as TelemetryGetTraceParams
from .inference_completion_params import InferenceCompletionParams as InferenceCompletionParams
-from .memory_bank_retrieve_params import MemoryBankRetrieveParams as MemoryBankRetrieveParams
from .reward_scoring_score_params import RewardScoringScoreParams as RewardScoringScoreParams
-from .safety_run_shields_response import SafetyRunShieldsResponse as SafetyRunShieldsResponse
from .tool_param_definition_param import ToolParamDefinitionParam as ToolParamDefinitionParam
from .chat_completion_stream_chunk import ChatCompletionStreamChunk as ChatCompletionStreamChunk
from .telemetry_get_trace_response import TelemetryGetTraceResponse as TelemetryGetTraceResponse
@@ -55,12 +59,9 @@
from .evaluation_summarization_params import EvaluationSummarizationParams as EvaluationSummarizationParams
from .rest_api_execution_config_param import RestAPIExecutionConfigParam as RestAPIExecutionConfigParam
from .inference_chat_completion_params import InferenceChatCompletionParams as InferenceChatCompletionParams
-from .llm_query_generator_config_param import LlmQueryGeneratorConfigParam as LlmQueryGeneratorConfigParam
from .batch_inference_completion_params import BatchInferenceCompletionParams as BatchInferenceCompletionParams
from .evaluation_text_generation_params import EvaluationTextGenerationParams as EvaluationTextGenerationParams
from .inference_chat_completion_response import InferenceChatCompletionResponse as InferenceChatCompletionResponse
-from .custom_query_generator_config_param import CustomQueryGeneratorConfigParam as CustomQueryGeneratorConfigParam
-from .default_query_generator_config_param import DefaultQueryGeneratorConfigParam as DefaultQueryGeneratorConfigParam
from .batch_inference_chat_completion_params import (
BatchInferenceChatCompletionParams as BatchInferenceChatCompletionParams,
)
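After the re-export shuffle, the new public types import directly from `llama_stack.types`:

```python
from llama_stack.types import (
    MemoryBankSpec,
    ModelServingSpec,
    RunSheidResponse,  # note: the name is spelled this way in the SDK
    ShieldSpec,
)
```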
diff --git a/src/llama_stack/types/agent_create_params.py b/src/llama_stack/types/agent_create_params.py
index 39738a3..27d4705 100644
--- a/src/llama_stack/types/agent_create_params.py
+++ b/src/llama_stack/types/agent_create_params.py
@@ -3,15 +3,12 @@
from __future__ import annotations
from typing import Dict, List, Union, Iterable
-from typing_extensions import Literal, Required, TypeAlias, TypedDict
+from typing_extensions import Literal, Required, Annotated, TypeAlias, TypedDict
-from .shield_definition_param import ShieldDefinitionParam
+from .._utils import PropertyInfo
from .tool_param_definition_param import ToolParamDefinitionParam
from .shared_params.sampling_params import SamplingParams
from .rest_api_execution_config_param import RestAPIExecutionConfigParam
-from .llm_query_generator_config_param import LlmQueryGeneratorConfigParam
-from .custom_query_generator_config_param import CustomQueryGeneratorConfigParam
-from .default_query_generator_config_param import DefaultQueryGeneratorConfigParam
__all__ = [
"AgentCreateParams",
@@ -22,19 +19,24 @@
"AgentConfigToolPhotogenToolDefinition",
"AgentConfigToolCodeInterpreterToolDefinition",
"AgentConfigToolFunctionCallToolDefinition",
- "AgentConfigToolShield",
- "AgentConfigToolShieldMemoryBankConfig",
- "AgentConfigToolShieldMemoryBankConfigVector",
- "AgentConfigToolShieldMemoryBankConfigKeyvalue",
- "AgentConfigToolShieldMemoryBankConfigKeyword",
- "AgentConfigToolShieldMemoryBankConfigGraph",
- "AgentConfigToolShieldQueryGeneratorConfig",
+ "AgentConfigToolMemoryToolDefinition",
+ "AgentConfigToolMemoryToolDefinitionMemoryBankConfig",
+ "AgentConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember0",
+ "AgentConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember1",
+ "AgentConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember2",
+ "AgentConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember3",
+ "AgentConfigToolMemoryToolDefinitionQueryGeneratorConfig",
+ "AgentConfigToolMemoryToolDefinitionQueryGeneratorConfigUnionMember0",
+ "AgentConfigToolMemoryToolDefinitionQueryGeneratorConfigUnionMember1",
+ "AgentConfigToolMemoryToolDefinitionQueryGeneratorConfigType",
]
class AgentCreateParams(TypedDict, total=False):
agent_config: Required[AgentConfig]
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
+
class AgentConfigToolSearchToolDefinition(TypedDict, total=False):
api_key: Required[str]
@@ -43,9 +45,9 @@ class AgentConfigToolSearchToolDefinition(TypedDict, total=False):
type: Required[Literal["brave_search"]]
- input_shields: Iterable[ShieldDefinitionParam]
+ input_shields: List[str]
- output_shields: Iterable[ShieldDefinitionParam]
+ output_shields: List[str]
remote_execution: RestAPIExecutionConfigParam
@@ -55,9 +57,9 @@ class AgentConfigToolWolframAlphaToolDefinition(TypedDict, total=False):
type: Required[Literal["wolfram_alpha"]]
- input_shields: Iterable[ShieldDefinitionParam]
+ input_shields: List[str]
- output_shields: Iterable[ShieldDefinitionParam]
+ output_shields: List[str]
remote_execution: RestAPIExecutionConfigParam
@@ -65,9 +67,9 @@ class AgentConfigToolWolframAlphaToolDefinition(TypedDict, total=False):
class AgentConfigToolPhotogenToolDefinition(TypedDict, total=False):
type: Required[Literal["photogen"]]
- input_shields: Iterable[ShieldDefinitionParam]
+ input_shields: List[str]
- output_shields: Iterable[ShieldDefinitionParam]
+ output_shields: List[str]
remote_execution: RestAPIExecutionConfigParam
@@ -77,9 +79,9 @@ class AgentConfigToolCodeInterpreterToolDefinition(TypedDict, total=False):
type: Required[Literal["code_interpreter"]]
- input_shields: Iterable[ShieldDefinitionParam]
+ input_shields: List[str]
- output_shields: Iterable[ShieldDefinitionParam]
+ output_shields: List[str]
remote_execution: RestAPIExecutionConfigParam
@@ -93,20 +95,20 @@ class AgentConfigToolFunctionCallToolDefinition(TypedDict, total=False):
type: Required[Literal["function_call"]]
- input_shields: Iterable[ShieldDefinitionParam]
+ input_shields: List[str]
- output_shields: Iterable[ShieldDefinitionParam]
+ output_shields: List[str]
remote_execution: RestAPIExecutionConfigParam
-class AgentConfigToolShieldMemoryBankConfigVector(TypedDict, total=False):
+class AgentConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember0(TypedDict, total=False):
bank_id: Required[str]
type: Required[Literal["vector"]]
-class AgentConfigToolShieldMemoryBankConfigKeyvalue(TypedDict, total=False):
+class AgentConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember1(TypedDict, total=False):
bank_id: Required[str]
keys: Required[List[str]]
@@ -114,13 +116,13 @@ class AgentConfigToolShieldMemoryBankConfigKeyvalue(TypedDict, total=False):
type: Required[Literal["keyvalue"]]
-class AgentConfigToolShieldMemoryBankConfigKeyword(TypedDict, total=False):
+class AgentConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember2(TypedDict, total=False):
bank_id: Required[str]
type: Required[Literal["keyword"]]
-class AgentConfigToolShieldMemoryBankConfigGraph(TypedDict, total=False):
+class AgentConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember3(TypedDict, total=False):
bank_id: Required[str]
entities: Required[List[str]]
@@ -128,32 +130,53 @@ class AgentConfigToolShieldMemoryBankConfigGraph(TypedDict, total=False):
type: Required[Literal["graph"]]
-AgentConfigToolShieldMemoryBankConfig: TypeAlias = Union[
- AgentConfigToolShieldMemoryBankConfigVector,
- AgentConfigToolShieldMemoryBankConfigKeyvalue,
- AgentConfigToolShieldMemoryBankConfigKeyword,
- AgentConfigToolShieldMemoryBankConfigGraph,
+AgentConfigToolMemoryToolDefinitionMemoryBankConfig: TypeAlias = Union[
+ AgentConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember0,
+ AgentConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember1,
+ AgentConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember2,
+ AgentConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember3,
]
-AgentConfigToolShieldQueryGeneratorConfig: TypeAlias = Union[
- DefaultQueryGeneratorConfigParam, LlmQueryGeneratorConfigParam, CustomQueryGeneratorConfigParam
+
+class AgentConfigToolMemoryToolDefinitionQueryGeneratorConfigUnionMember0(TypedDict, total=False):
+ sep: Required[str]
+
+ type: Required[Literal["default"]]
+
+
+class AgentConfigToolMemoryToolDefinitionQueryGeneratorConfigUnionMember1(TypedDict, total=False):
+ model: Required[str]
+
+ template: Required[str]
+
+ type: Required[Literal["llm"]]
+
+
+class AgentConfigToolMemoryToolDefinitionQueryGeneratorConfigType(TypedDict, total=False):
+ type: Required[Literal["custom"]]
+
+
+AgentConfigToolMemoryToolDefinitionQueryGeneratorConfig: TypeAlias = Union[
+ AgentConfigToolMemoryToolDefinitionQueryGeneratorConfigUnionMember0,
+ AgentConfigToolMemoryToolDefinitionQueryGeneratorConfigUnionMember1,
+ AgentConfigToolMemoryToolDefinitionQueryGeneratorConfigType,
]
-class AgentConfigToolShield(TypedDict, total=False):
+class AgentConfigToolMemoryToolDefinition(TypedDict, total=False):
max_chunks: Required[int]
max_tokens_in_context: Required[int]
- memory_bank_configs: Required[Iterable[AgentConfigToolShieldMemoryBankConfig]]
+ memory_bank_configs: Required[Iterable[AgentConfigToolMemoryToolDefinitionMemoryBankConfig]]
- query_generator_config: Required[AgentConfigToolShieldQueryGeneratorConfig]
+ query_generator_config: Required[AgentConfigToolMemoryToolDefinitionQueryGeneratorConfig]
type: Required[Literal["memory"]]
- input_shields: Iterable[ShieldDefinitionParam]
+ input_shields: List[str]
- output_shields: Iterable[ShieldDefinitionParam]
+ output_shields: List[str]
AgentConfigTool: TypeAlias = Union[
@@ -162,18 +185,22 @@ class AgentConfigToolShield(TypedDict, total=False):
AgentConfigToolPhotogenToolDefinition,
AgentConfigToolCodeInterpreterToolDefinition,
AgentConfigToolFunctionCallToolDefinition,
- AgentConfigToolShield,
+ AgentConfigToolMemoryToolDefinition,
]
class AgentConfig(TypedDict, total=False):
+ enable_session_persistence: Required[bool]
+
instructions: Required[str]
+ max_infer_iters: Required[int]
+
model: Required[str]
- input_shields: Iterable[ShieldDefinitionParam]
+ input_shields: List[str]
- output_shields: Iterable[ShieldDefinitionParam]
+ output_shields: List[str]
sampling_params: SamplingParams
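
Taken together, the `AgentConfig` changes above replace inline `ShieldDefinitionParam` objects with plain string identifiers and add two new required fields. A minimal sketch of a config dict under the new types follows; the model ID and shield name are illustrative placeholders, not values from this diff.

```python
agent_config = {
    # New required fields introduced in this diff.
    "enable_session_persistence": True,
    "max_infer_iters": 10,
    # Pre-existing required fields.
    "model": "Meta-Llama3.1-8B-Instruct",  # placeholder model ID
    "instructions": "You are a helpful assistant.",
    # Shields are now referenced by string identifier rather than by
    # full ShieldDefinitionParam objects.
    "input_shields": ["llama_guard"],  # placeholder shield name
    "output_shields": ["llama_guard"],
}
```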
diff --git a/src/llama_stack/types/agent_delete_params.py b/src/llama_stack/types/agent_delete_params.py
index cfc4cd0..ba601b9 100644
--- a/src/llama_stack/types/agent_delete_params.py
+++ b/src/llama_stack/types/agent_delete_params.py
@@ -2,10 +2,14 @@
from __future__ import annotations
-from typing_extensions import Required, TypedDict
+from typing_extensions import Required, Annotated, TypedDict
+
+from .._utils import PropertyInfo
__all__ = ["AgentDeleteParams"]
class AgentDeleteParams(TypedDict, total=False):
agent_id: Required[str]
+
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
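
This `x_llama_stack_provider_data` addition is the same pattern repeated across nearly every params file in this diff: a keyword argument annotated with `PropertyInfo(alias="X-LlamaStack-ProviderData")`, which the generated client presumably serializes as the `X-LlamaStack-ProviderData` request header rather than a body field. A sketch of passing it, assuming the `LlamaStack` client class used in the test files; the agent ID and payload are placeholders:

```python
from llama_stack import LlamaStack  # client class as used in the tests below

client = LlamaStack()

# The PropertyInfo alias routes this keyword onto the
# X-LlamaStack-ProviderData header of the outgoing request.
client.agents.delete(
    agent_id="agent_id",
    x_llama_stack_provider_data='{"provider_key": "value"}',
)
```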
diff --git a/src/llama_stack/types/agents/session_create_params.py b/src/llama_stack/types/agents/session_create_params.py
index d64bee0..42e19fe 100644
--- a/src/llama_stack/types/agents/session_create_params.py
+++ b/src/llama_stack/types/agents/session_create_params.py
@@ -2,7 +2,9 @@
from __future__ import annotations
-from typing_extensions import Required, TypedDict
+from typing_extensions import Required, Annotated, TypedDict
+
+from ..._utils import PropertyInfo
__all__ = ["SessionCreateParams"]
@@ -11,3 +13,5 @@ class SessionCreateParams(TypedDict, total=False):
agent_id: Required[str]
session_name: Required[str]
+
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
diff --git a/src/llama_stack/types/agents/session_delete_params.py b/src/llama_stack/types/agents/session_delete_params.py
index 1474f72..45864d6 100644
--- a/src/llama_stack/types/agents/session_delete_params.py
+++ b/src/llama_stack/types/agents/session_delete_params.py
@@ -2,7 +2,9 @@
from __future__ import annotations
-from typing_extensions import Required, TypedDict
+from typing_extensions import Required, Annotated, TypedDict
+
+from ..._utils import PropertyInfo
__all__ = ["SessionDeleteParams"]
@@ -11,3 +13,5 @@ class SessionDeleteParams(TypedDict, total=False):
agent_id: Required[str]
session_id: Required[str]
+
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
diff --git a/src/llama_stack/types/agents/session_retrieve_params.py b/src/llama_stack/types/agents/session_retrieve_params.py
index 4c0b691..974c95f 100644
--- a/src/llama_stack/types/agents/session_retrieve_params.py
+++ b/src/llama_stack/types/agents/session_retrieve_params.py
@@ -3,7 +3,9 @@
from __future__ import annotations
from typing import List
-from typing_extensions import Required, TypedDict
+from typing_extensions import Required, Annotated, TypedDict
+
+from ..._utils import PropertyInfo
__all__ = ["SessionRetrieveParams"]
@@ -14,3 +16,5 @@ class SessionRetrieveParams(TypedDict, total=False):
session_id: Required[str]
turn_ids: List[str]
+
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
diff --git a/src/llama_stack/types/agents/step_retrieve_params.py b/src/llama_stack/types/agents/step_retrieve_params.py
index 35b49c5..cccdc19 100644
--- a/src/llama_stack/types/agents/step_retrieve_params.py
+++ b/src/llama_stack/types/agents/step_retrieve_params.py
@@ -2,7 +2,9 @@
from __future__ import annotations
-from typing_extensions import Required, TypedDict
+from typing_extensions import Required, Annotated, TypedDict
+
+from ..._utils import PropertyInfo
__all__ = ["StepRetrieveParams"]
@@ -13,3 +15,5 @@ class StepRetrieveParams(TypedDict, total=False):
step_id: Required[str]
turn_id: Required[str]
+
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
diff --git a/src/llama_stack/types/agents/turn_create_params.py b/src/llama_stack/types/agents/turn_create_params.py
index ffbb54c..349d12d 100644
--- a/src/llama_stack/types/agents/turn_create_params.py
+++ b/src/llama_stack/types/agents/turn_create_params.py
@@ -3,8 +3,9 @@
from __future__ import annotations
from typing import Union, Iterable
-from typing_extensions import Literal, Required, TypeAlias, TypedDict
+from typing_extensions import Literal, Required, Annotated, TypeAlias, TypedDict
+from ..._utils import PropertyInfo
from ..shared_params.attachment import Attachment
from ..shared_params.user_message import UserMessage
from ..shared_params.tool_response_message import ToolResponseMessage
@@ -21,6 +22,8 @@ class TurnCreateParamsBase(TypedDict, total=False):
attachments: Iterable[Attachment]
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
+
Message: TypeAlias = Union[UserMessage, ToolResponseMessage]
diff --git a/src/llama_stack/types/agents/turn_retrieve_params.py b/src/llama_stack/types/agents/turn_retrieve_params.py
index 0352f69..7f3349a 100644
--- a/src/llama_stack/types/agents/turn_retrieve_params.py
+++ b/src/llama_stack/types/agents/turn_retrieve_params.py
@@ -2,7 +2,9 @@
from __future__ import annotations
-from typing_extensions import Required, TypedDict
+from typing_extensions import Required, Annotated, TypedDict
+
+from ..._utils import PropertyInfo
__all__ = ["TurnRetrieveParams"]
@@ -11,3 +13,5 @@ class TurnRetrieveParams(TypedDict, total=False):
agent_id: Required[str]
turn_id: Required[str]
+
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
diff --git a/src/llama_stack/types/batch_inference_chat_completion_params.py b/src/llama_stack/types/batch_inference_chat_completion_params.py
index 28ad798..0aae2d2 100644
--- a/src/llama_stack/types/batch_inference_chat_completion_params.py
+++ b/src/llama_stack/types/batch_inference_chat_completion_params.py
@@ -3,8 +3,9 @@
from __future__ import annotations
from typing import Dict, Union, Iterable
-from typing_extensions import Literal, Required, TypeAlias, TypedDict
+from typing_extensions import Literal, Required, Annotated, TypeAlias, TypedDict
+from .._utils import PropertyInfo
from .shared_params.user_message import UserMessage
from .tool_param_definition_param import ToolParamDefinitionParam
from .shared_params.system_message import SystemMessage
@@ -41,6 +42,8 @@ class BatchInferenceChatCompletionParams(TypedDict, total=False):
tools: Iterable[Tool]
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
+
MessagesBatch: TypeAlias = Union[UserMessage, SystemMessage, ToolResponseMessage, CompletionMessage]
diff --git a/src/llama_stack/types/batch_inference_completion_params.py b/src/llama_stack/types/batch_inference_completion_params.py
index 398531a..e205b85 100644
--- a/src/llama_stack/types/batch_inference_completion_params.py
+++ b/src/llama_stack/types/batch_inference_completion_params.py
@@ -3,8 +3,9 @@
from __future__ import annotations
from typing import List, Union
-from typing_extensions import Required, TypedDict
+from typing_extensions import Required, Annotated, TypedDict
+from .._utils import PropertyInfo
from .shared_params.sampling_params import SamplingParams
__all__ = ["BatchInferenceCompletionParams", "Logprobs"]
@@ -19,6 +20,8 @@ class BatchInferenceCompletionParams(TypedDict, total=False):
sampling_params: SamplingParams
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
+
class Logprobs(TypedDict, total=False):
top_k: int
diff --git a/src/llama_stack/types/custom_query_generator_config_param.py b/src/llama_stack/types/custom_query_generator_config_param.py
deleted file mode 100644
index 432450c..0000000
--- a/src/llama_stack/types/custom_query_generator_config_param.py
+++ /dev/null
@@ -1,11 +0,0 @@
-# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
-
-from __future__ import annotations
-
-from typing_extensions import Literal, Required, TypedDict
-
-__all__ = ["CustomQueryGeneratorConfigParam"]
-
-
-class CustomQueryGeneratorConfigParam(TypedDict, total=False):
- type: Required[Literal["custom"]]
diff --git a/src/llama_stack/types/dataset_create_params.py b/src/llama_stack/types/dataset_create_params.py
index 971ddb5..ec81175 100644
--- a/src/llama_stack/types/dataset_create_params.py
+++ b/src/llama_stack/types/dataset_create_params.py
@@ -2,8 +2,9 @@
from __future__ import annotations
-from typing_extensions import Required, TypedDict
+from typing_extensions import Required, Annotated, TypedDict
+from .._utils import PropertyInfo
from .train_eval_dataset_param import TrainEvalDatasetParam
__all__ = ["DatasetCreateParams"]
@@ -13,3 +14,5 @@ class DatasetCreateParams(TypedDict, total=False):
dataset: Required[TrainEvalDatasetParam]
uuid: Required[str]
+
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
diff --git a/src/llama_stack/types/dataset_delete_params.py b/src/llama_stack/types/dataset_delete_params.py
index ca46388..66d0670 100644
--- a/src/llama_stack/types/dataset_delete_params.py
+++ b/src/llama_stack/types/dataset_delete_params.py
@@ -2,10 +2,14 @@
from __future__ import annotations
-from typing_extensions import Required, TypedDict
+from typing_extensions import Required, Annotated, TypedDict
+
+from .._utils import PropertyInfo
__all__ = ["DatasetDeleteParams"]
class DatasetDeleteParams(TypedDict, total=False):
dataset_uuid: Required[str]
+
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
diff --git a/src/llama_stack/types/dataset_get_params.py b/src/llama_stack/types/dataset_get_params.py
index b1423f1..d0d6695 100644
--- a/src/llama_stack/types/dataset_get_params.py
+++ b/src/llama_stack/types/dataset_get_params.py
@@ -2,10 +2,14 @@
from __future__ import annotations
-from typing_extensions import Required, TypedDict
+from typing_extensions import Required, Annotated, TypedDict
+
+from .._utils import PropertyInfo
__all__ = ["DatasetGetParams"]
class DatasetGetParams(TypedDict, total=False):
dataset_uuid: Required[str]
+
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
diff --git a/src/llama_stack/types/default_query_generator_config_param.py b/src/llama_stack/types/default_query_generator_config_param.py
deleted file mode 100644
index 2aaaa81..0000000
--- a/src/llama_stack/types/default_query_generator_config_param.py
+++ /dev/null
@@ -1,13 +0,0 @@
-# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
-
-from __future__ import annotations
-
-from typing_extensions import Literal, Required, TypedDict
-
-__all__ = ["DefaultQueryGeneratorConfigParam"]
-
-
-class DefaultQueryGeneratorConfigParam(TypedDict, total=False):
- sep: Required[str]
-
- type: Required[Literal["default"]]
diff --git a/src/llama_stack/types/evaluate/job_cancel_params.py b/src/llama_stack/types/evaluate/job_cancel_params.py
index c9c30d8..9321c3b 100644
--- a/src/llama_stack/types/evaluate/job_cancel_params.py
+++ b/src/llama_stack/types/evaluate/job_cancel_params.py
@@ -2,10 +2,14 @@
from __future__ import annotations
-from typing_extensions import Required, TypedDict
+from typing_extensions import Required, Annotated, TypedDict
+
+from ..._utils import PropertyInfo
__all__ = ["JobCancelParams"]
class JobCancelParams(TypedDict, total=False):
job_uuid: Required[str]
+
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
diff --git a/src/llama_stack/types/evaluate/jobs/artifact_list_params.py b/src/llama_stack/types/evaluate/jobs/artifact_list_params.py
index f52228e..579033e 100644
--- a/src/llama_stack/types/evaluate/jobs/artifact_list_params.py
+++ b/src/llama_stack/types/evaluate/jobs/artifact_list_params.py
@@ -2,10 +2,14 @@
from __future__ import annotations
-from typing_extensions import Required, TypedDict
+from typing_extensions import Required, Annotated, TypedDict
+
+from ...._utils import PropertyInfo
__all__ = ["ArtifactListParams"]
class ArtifactListParams(TypedDict, total=False):
job_uuid: Required[str]
+
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
diff --git a/src/llama_stack/types/evaluate/jobs/log_list_params.py b/src/llama_stack/types/evaluate/jobs/log_list_params.py
index 1035005..4b2df45 100644
--- a/src/llama_stack/types/evaluate/jobs/log_list_params.py
+++ b/src/llama_stack/types/evaluate/jobs/log_list_params.py
@@ -2,10 +2,14 @@
from __future__ import annotations
-from typing_extensions import Required, TypedDict
+from typing_extensions import Required, Annotated, TypedDict
+
+from ...._utils import PropertyInfo
__all__ = ["LogListParams"]
class LogListParams(TypedDict, total=False):
job_uuid: Required[str]
+
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
diff --git a/src/llama_stack/types/evaluate/jobs/status_list_params.py b/src/llama_stack/types/evaluate/jobs/status_list_params.py
index c2a0dc5..a7d5165 100644
--- a/src/llama_stack/types/evaluate/jobs/status_list_params.py
+++ b/src/llama_stack/types/evaluate/jobs/status_list_params.py
@@ -2,10 +2,14 @@
from __future__ import annotations
-from typing_extensions import Required, TypedDict
+from typing_extensions import Required, Annotated, TypedDict
+
+from ...._utils import PropertyInfo
__all__ = ["StatusListParams"]
class StatusListParams(TypedDict, total=False):
job_uuid: Required[str]
+
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
diff --git a/src/llama_stack/types/evaluate/question_answering_create_params.py b/src/llama_stack/types/evaluate/question_answering_create_params.py
index 8477717..de8caa0 100644
--- a/src/llama_stack/types/evaluate/question_answering_create_params.py
+++ b/src/llama_stack/types/evaluate/question_answering_create_params.py
@@ -3,10 +3,14 @@
from __future__ import annotations
from typing import List
-from typing_extensions import Literal, Required, TypedDict
+from typing_extensions import Literal, Required, Annotated, TypedDict
+
+from ..._utils import PropertyInfo
__all__ = ["QuestionAnsweringCreateParams"]
class QuestionAnsweringCreateParams(TypedDict, total=False):
metrics: Required[List[Literal["em", "f1"]]]
+
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
diff --git a/src/llama_stack/types/evaluation_summarization_params.py b/src/llama_stack/types/evaluation_summarization_params.py
index 34542d6..80dd8f5 100644
--- a/src/llama_stack/types/evaluation_summarization_params.py
+++ b/src/llama_stack/types/evaluation_summarization_params.py
@@ -3,10 +3,14 @@
from __future__ import annotations
from typing import List
-from typing_extensions import Literal, Required, TypedDict
+from typing_extensions import Literal, Required, Annotated, TypedDict
+
+from .._utils import PropertyInfo
__all__ = ["EvaluationSummarizationParams"]
class EvaluationSummarizationParams(TypedDict, total=False):
metrics: Required[List[Literal["rouge", "bleu"]]]
+
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
diff --git a/src/llama_stack/types/evaluation_text_generation_params.py b/src/llama_stack/types/evaluation_text_generation_params.py
index deaec66..1cd3a56 100644
--- a/src/llama_stack/types/evaluation_text_generation_params.py
+++ b/src/llama_stack/types/evaluation_text_generation_params.py
@@ -3,10 +3,14 @@
from __future__ import annotations
from typing import List
-from typing_extensions import Literal, Required, TypedDict
+from typing_extensions import Literal, Required, Annotated, TypedDict
+
+from .._utils import PropertyInfo
__all__ = ["EvaluationTextGenerationParams"]
class EvaluationTextGenerationParams(TypedDict, total=False):
metrics: Required[List[Literal["perplexity", "rouge", "bleu"]]]
+
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
diff --git a/src/llama_stack/types/inference/embedding_create_params.py b/src/llama_stack/types/inference/embedding_create_params.py
index 272f199..0e4e970 100644
--- a/src/llama_stack/types/inference/embedding_create_params.py
+++ b/src/llama_stack/types/inference/embedding_create_params.py
@@ -3,7 +3,9 @@
from __future__ import annotations
from typing import List, Union
-from typing_extensions import Required, TypedDict
+from typing_extensions import Required, Annotated, TypedDict
+
+from ..._utils import PropertyInfo
__all__ = ["EmbeddingCreateParams"]
@@ -12,3 +14,5 @@ class EmbeddingCreateParams(TypedDict, total=False):
contents: Required[List[Union[str, List[str]]]]
model: Required[str]
+
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
diff --git a/src/llama_stack/types/inference_chat_completion_params.py b/src/llama_stack/types/inference_chat_completion_params.py
index af21934..2cac635 100644
--- a/src/llama_stack/types/inference_chat_completion_params.py
+++ b/src/llama_stack/types/inference_chat_completion_params.py
@@ -3,8 +3,9 @@
from __future__ import annotations
from typing import Dict, Union, Iterable
-from typing_extensions import Literal, Required, TypeAlias, TypedDict
+from typing_extensions import Literal, Required, Annotated, TypeAlias, TypedDict
+from .._utils import PropertyInfo
from .shared_params.user_message import UserMessage
from .tool_param_definition_param import ToolParamDefinitionParam
from .shared_params.system_message import SystemMessage
@@ -48,6 +49,8 @@ class InferenceChatCompletionParamsBase(TypedDict, total=False):
tools: Iterable[Tool]
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
+
Message: TypeAlias = Union[UserMessage, SystemMessage, ToolResponseMessage, CompletionMessage]
diff --git a/src/llama_stack/types/inference_completion_params.py b/src/llama_stack/types/inference_completion_params.py
index 79544b0..d4145a5 100644
--- a/src/llama_stack/types/inference_completion_params.py
+++ b/src/llama_stack/types/inference_completion_params.py
@@ -3,8 +3,9 @@
from __future__ import annotations
from typing import List, Union
-from typing_extensions import Required, TypedDict
+from typing_extensions import Required, Annotated, TypedDict
+from .._utils import PropertyInfo
from .shared_params.sampling_params import SamplingParams
__all__ = ["InferenceCompletionParams", "Logprobs"]
@@ -21,6 +22,8 @@ class InferenceCompletionParams(TypedDict, total=False):
stream: bool
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
+
class Logprobs(TypedDict, total=False):
top_k: int
diff --git a/src/llama_stack/types/llm_query_generator_config_param.py b/src/llama_stack/types/llm_query_generator_config_param.py
deleted file mode 100644
index 8d6bd31..0000000
--- a/src/llama_stack/types/llm_query_generator_config_param.py
+++ /dev/null
@@ -1,15 +0,0 @@
-# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
-
-from __future__ import annotations
-
-from typing_extensions import Literal, Required, TypedDict
-
-__all__ = ["LlmQueryGeneratorConfigParam"]
-
-
-class LlmQueryGeneratorConfigParam(TypedDict, total=False):
- model: Required[str]
-
- template: Required[str]
-
- type: Required[Literal["llm"]]
diff --git a/src/llama_stack/types/memory_banks/__init__.py b/src/llama_stack/types/memory/__init__.py
similarity index 100%
rename from src/llama_stack/types/memory_banks/__init__.py
rename to src/llama_stack/types/memory/__init__.py
diff --git a/src/llama_stack/types/memory_banks/document_delete_params.py b/src/llama_stack/types/memory/document_delete_params.py
similarity index 60%
rename from src/llama_stack/types/memory_banks/document_delete_params.py
rename to src/llama_stack/types/memory/document_delete_params.py
index 01d66b6..9ec4bf1 100644
--- a/src/llama_stack/types/memory_banks/document_delete_params.py
+++ b/src/llama_stack/types/memory/document_delete_params.py
@@ -3,7 +3,9 @@
from __future__ import annotations
from typing import List
-from typing_extensions import Required, TypedDict
+from typing_extensions import Required, Annotated, TypedDict
+
+from ..._utils import PropertyInfo
__all__ = ["DocumentDeleteParams"]
@@ -12,3 +14,5 @@ class DocumentDeleteParams(TypedDict, total=False):
bank_id: Required[str]
document_ids: Required[List[str]]
+
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
diff --git a/src/llama_stack/types/memory_banks/document_retrieve_params.py b/src/llama_stack/types/memory/document_retrieve_params.py
similarity index 61%
rename from src/llama_stack/types/memory_banks/document_retrieve_params.py
rename to src/llama_stack/types/memory/document_retrieve_params.py
index d746395..3f30f9b 100644
--- a/src/llama_stack/types/memory_banks/document_retrieve_params.py
+++ b/src/llama_stack/types/memory/document_retrieve_params.py
@@ -3,7 +3,9 @@
from __future__ import annotations
from typing import List
-from typing_extensions import Required, TypedDict
+from typing_extensions import Required, Annotated, TypedDict
+
+from ..._utils import PropertyInfo
__all__ = ["DocumentRetrieveParams"]
@@ -12,3 +14,5 @@ class DocumentRetrieveParams(TypedDict, total=False):
bank_id: Required[str]
document_ids: Required[List[str]]
+
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
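
With the `memory_banks` to `memory` rename, document operations move to `client.memory.documents`. A sketch, reusing the client constructed above; bank and document IDs are placeholders:

```python
# Retrieve, then delete, documents from a memory bank under the new path.
docs = client.memory.documents.retrieve(
    bank_id="bank_id",
    document_ids=["doc-1", "doc-2"],
)
client.memory.documents.delete(
    bank_id="bank_id",
    document_ids=["doc-1"],
)
```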
diff --git a/src/llama_stack/types/memory_banks/document_retrieve_response.py b/src/llama_stack/types/memory/document_retrieve_response.py
similarity index 100%
rename from src/llama_stack/types/memory_banks/document_retrieve_response.py
rename to src/llama_stack/types/memory/document_retrieve_response.py
diff --git a/src/llama_stack/types/memory_bank_create_params.py b/src/llama_stack/types/memory_bank_create_params.py
deleted file mode 100644
index 421fca0..0000000
--- a/src/llama_stack/types/memory_bank_create_params.py
+++ /dev/null
@@ -1,11 +0,0 @@
-# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
-
-from __future__ import annotations
-
-from typing_extensions import Required, TypedDict
-
-__all__ = ["MemoryBankCreateParams"]
-
-
-class MemoryBankCreateParams(TypedDict, total=False):
- body: Required[object]
diff --git a/src/llama_stack/types/memory_bank_drop_params.py b/src/llama_stack/types/memory_bank_drop_params.py
deleted file mode 100644
index e19be55..0000000
--- a/src/llama_stack/types/memory_bank_drop_params.py
+++ /dev/null
@@ -1,11 +0,0 @@
-# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
-
-from __future__ import annotations
-
-from typing_extensions import Required, TypedDict
-
-__all__ = ["MemoryBankDropParams"]
-
-
-class MemoryBankDropParams(TypedDict, total=False):
- bank_id: Required[str]
diff --git a/src/llama_stack/types/memory_bank_get_params.py b/src/llama_stack/types/memory_bank_get_params.py
new file mode 100644
index 0000000..de5b43e
--- /dev/null
+++ b/src/llama_stack/types/memory_bank_get_params.py
@@ -0,0 +1,15 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from __future__ import annotations
+
+from typing_extensions import Literal, Required, Annotated, TypedDict
+
+from .._utils import PropertyInfo
+
+__all__ = ["MemoryBankGetParams"]
+
+
+class MemoryBankGetParams(TypedDict, total=False):
+ bank_type: Required[Literal["vector", "keyvalue", "keyword", "graph"]]
+
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
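
The new registry-style lookup takes one of the four `bank_type` literals. A sketch; per the API summary, `get` is typed as returning `Optional`, so a `None` result presumably means that bank type is not being served:

```python
# "vector" is one of the four allowed Literal values.
spec = client.memory_banks.get(bank_type="vector")
if spec is None:
    print("no vector memory bank is configured")
```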
diff --git a/src/llama_stack/types/memory_bank_retrieve_params.py b/src/llama_stack/types/memory_bank_retrieve_params.py
deleted file mode 100644
index 21436f3..0000000
--- a/src/llama_stack/types/memory_bank_retrieve_params.py
+++ /dev/null
@@ -1,11 +0,0 @@
-# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
-
-from __future__ import annotations
-
-from typing_extensions import Required, TypedDict
-
-__all__ = ["MemoryBankRetrieveParams"]
-
-
-class MemoryBankRetrieveParams(TypedDict, total=False):
- bank_id: Required[str]
diff --git a/src/llama_stack/types/memory_bank_spec.py b/src/llama_stack/types/memory_bank_spec.py
new file mode 100644
index 0000000..b116082
--- /dev/null
+++ b/src/llama_stack/types/memory_bank_spec.py
@@ -0,0 +1,20 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from typing import Dict, List, Union
+from typing_extensions import Literal
+
+from .._models import BaseModel
+
+__all__ = ["MemoryBankSpec", "ProviderConfig"]
+
+
+class ProviderConfig(BaseModel):
+ config: Dict[str, Union[bool, float, str, List[object], object, None]]
+
+ provider_id: str
+
+
+class MemoryBankSpec(BaseModel):
+ bank_type: Literal["vector", "keyvalue", "keyword", "graph"]
+
+ provider_config: ProviderConfig
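
`MemoryBankSpec` pairs a bank type with its provider configuration. A sketch of reading the nested `ProviderConfig`; per the API summary the list method is typed as returning `MemoryBankSpec`:

```python
spec = client.memory_banks.list()
print(spec.bank_type)                    # "vector" | "keyvalue" | "keyword" | "graph"
print(spec.provider_config.provider_id)  # which backend serves this bank
print(spec.provider_config.config)       # free-form JSON-compatible dict
```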
diff --git a/src/llama_stack/types/memory_create_params.py b/src/llama_stack/types/memory_create_params.py
new file mode 100644
index 0000000..01f496e
--- /dev/null
+++ b/src/llama_stack/types/memory_create_params.py
@@ -0,0 +1,15 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from __future__ import annotations
+
+from typing_extensions import Required, Annotated, TypedDict
+
+from .._utils import PropertyInfo
+
+__all__ = ["MemoryCreateParams"]
+
+
+class MemoryCreateParams(TypedDict, total=False):
+ body: Required[object]
+
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
diff --git a/src/llama_stack/types/memory_drop_params.py b/src/llama_stack/types/memory_drop_params.py
new file mode 100644
index 0000000..b15ec34
--- /dev/null
+++ b/src/llama_stack/types/memory_drop_params.py
@@ -0,0 +1,15 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from __future__ import annotations
+
+from typing_extensions import Required, Annotated, TypedDict
+
+from .._utils import PropertyInfo
+
+__all__ = ["MemoryDropParams"]
+
+
+class MemoryDropParams(TypedDict, total=False):
+ bank_id: Required[str]
+
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
diff --git a/src/llama_stack/types/memory_bank_drop_response.py b/src/llama_stack/types/memory_drop_response.py
similarity index 62%
rename from src/llama_stack/types/memory_bank_drop_response.py
rename to src/llama_stack/types/memory_drop_response.py
index d3f5c3f..f032e04 100644
--- a/src/llama_stack/types/memory_bank_drop_response.py
+++ b/src/llama_stack/types/memory_drop_response.py
@@ -2,6 +2,6 @@
from typing_extensions import TypeAlias
-__all__ = ["MemoryBankDropResponse"]
+__all__ = ["MemoryDropResponse"]
-MemoryBankDropResponse: TypeAlias = str
+MemoryDropResponse: TypeAlias = str
diff --git a/src/llama_stack/types/memory_bank_insert_params.py b/src/llama_stack/types/memory_insert_params.py
similarity index 63%
rename from src/llama_stack/types/memory_bank_insert_params.py
rename to src/llama_stack/types/memory_insert_params.py
index c460fc2..b09f8a7 100644
--- a/src/llama_stack/types/memory_bank_insert_params.py
+++ b/src/llama_stack/types/memory_insert_params.py
@@ -3,18 +3,22 @@
from __future__ import annotations
from typing import Dict, List, Union, Iterable
-from typing_extensions import Required, TypedDict
+from typing_extensions import Required, Annotated, TypedDict
-__all__ = ["MemoryBankInsertParams", "Document"]
+from .._utils import PropertyInfo
+__all__ = ["MemoryInsertParams", "Document"]
-class MemoryBankInsertParams(TypedDict, total=False):
+
+class MemoryInsertParams(TypedDict, total=False):
bank_id: Required[str]
documents: Required[Iterable[Document]]
ttl_seconds: int
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
+
class Document(TypedDict, total=False):
content: Required[Union[str, List[str]]]
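
A sketch of the renamed insert call. Only the fields visible in this hunk are used; real `Document` payloads likely carry additional fields (an ID, metadata) that this excerpt truncates:

```python
client.memory.insert(
    bank_id="bank_id",
    documents=[{"content": "The capital of France is Paris."}],
    ttl_seconds=3600,
)
```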
diff --git a/src/llama_stack/types/memory_bank_query_params.py b/src/llama_stack/types/memory_query_params.py
similarity index 54%
rename from src/llama_stack/types/memory_bank_query_params.py
rename to src/llama_stack/types/memory_query_params.py
index 05e4a56..d885600 100644
--- a/src/llama_stack/types/memory_bank_query_params.py
+++ b/src/llama_stack/types/memory_query_params.py
@@ -3,14 +3,18 @@
from __future__ import annotations
from typing import Dict, List, Union, Iterable
-from typing_extensions import Required, TypedDict
+from typing_extensions import Required, Annotated, TypedDict
-__all__ = ["MemoryBankQueryParams"]
+from .._utils import PropertyInfo
+__all__ = ["MemoryQueryParams"]
-class MemoryBankQueryParams(TypedDict, total=False):
+
+class MemoryQueryParams(TypedDict, total=False):
bank_id: Required[str]
query: Required[Union[str, List[str]]]
params: Dict[str, Union[bool, float, str, Iterable[object], object, None]]
+
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
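
Querying follows the same rename. A sketch; `params` accepts arbitrary JSON-compatible values, and the `"max_chunks"` key shown here is an illustrative assumption, not a documented parameter:

```python
results = client.memory.query(
    bank_id="bank_id",
    query="What is the capital of France?",
    params={"max_chunks": 5},
)
```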
diff --git a/src/llama_stack/types/memory_retrieve_params.py b/src/llama_stack/types/memory_retrieve_params.py
new file mode 100644
index 0000000..62f6496
--- /dev/null
+++ b/src/llama_stack/types/memory_retrieve_params.py
@@ -0,0 +1,15 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from __future__ import annotations
+
+from typing_extensions import Required, Annotated, TypedDict
+
+from .._utils import PropertyInfo
+
+__all__ = ["MemoryRetrieveParams"]
+
+
+class MemoryRetrieveParams(TypedDict, total=False):
+ bank_id: Required[str]
+
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
diff --git a/src/llama_stack/types/memory_bank_update_params.py b/src/llama_stack/types/memory_update_params.py
similarity index 62%
rename from src/llama_stack/types/memory_bank_update_params.py
rename to src/llama_stack/types/memory_update_params.py
index 2ab747b..ac9b8e4 100644
--- a/src/llama_stack/types/memory_bank_update_params.py
+++ b/src/llama_stack/types/memory_update_params.py
@@ -3,16 +3,20 @@
from __future__ import annotations
from typing import Dict, List, Union, Iterable
-from typing_extensions import Required, TypedDict
+from typing_extensions import Required, Annotated, TypedDict
-__all__ = ["MemoryBankUpdateParams", "Document"]
+from .._utils import PropertyInfo
+__all__ = ["MemoryUpdateParams", "Document"]
-class MemoryBankUpdateParams(TypedDict, total=False):
+
+class MemoryUpdateParams(TypedDict, total=False):
bank_id: Required[str]
documents: Required[Iterable[Document]]
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
+
class Document(TypedDict, total=False):
content: Required[Union[str, List[str]]]
diff --git a/src/llama_stack/types/model_get_params.py b/src/llama_stack/types/model_get_params.py
new file mode 100644
index 0000000..f3dc87d
--- /dev/null
+++ b/src/llama_stack/types/model_get_params.py
@@ -0,0 +1,15 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from __future__ import annotations
+
+from typing_extensions import Required, Annotated, TypedDict
+
+from .._utils import PropertyInfo
+
+__all__ = ["ModelGetParams"]
+
+
+class ModelGetParams(TypedDict, total=False):
+ core_model_id: Required[str]
+
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
diff --git a/src/llama_stack/types/model_serving_spec.py b/src/llama_stack/types/model_serving_spec.py
new file mode 100644
index 0000000..87b75a9
--- /dev/null
+++ b/src/llama_stack/types/model_serving_spec.py
@@ -0,0 +1,23 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from typing import Dict, List, Union
+
+from .._models import BaseModel
+
+__all__ = ["ModelServingSpec", "ProviderConfig"]
+
+
+class ProviderConfig(BaseModel):
+ config: Dict[str, Union[bool, float, str, List[object], object, None]]
+
+ provider_id: str
+
+
+class ModelServingSpec(BaseModel):
+ llama_model: object
+ """
+ The model family and SKU of the model along with other parameters corresponding
+ to the model.
+ """
+
+ provider_config: ProviderConfig
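
`ModelServingSpec` leaves `llama_model` typed as a bare `object`, so its shape has to be inspected at runtime. A sketch of the new `models.get` lookup; the model ID is a placeholder and, per the API summary, the return type is `Optional`:

```python
spec = client.models.get(core_model_id="Meta-Llama3.1-8B-Instruct")  # placeholder ID
if spec is not None:
    print(spec.provider_config.provider_id)
    print(spec.llama_model)  # untyped: the model family/SKU descriptor
```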
diff --git a/src/llama_stack/types/post_training/job_artifacts_params.py b/src/llama_stack/types/post_training/job_artifacts_params.py
index 4f75a13..1f7ae65 100644
--- a/src/llama_stack/types/post_training/job_artifacts_params.py
+++ b/src/llama_stack/types/post_training/job_artifacts_params.py
@@ -2,10 +2,14 @@
from __future__ import annotations
-from typing_extensions import Required, TypedDict
+from typing_extensions import Required, Annotated, TypedDict
+
+from ..._utils import PropertyInfo
__all__ = ["JobArtifactsParams"]
class JobArtifactsParams(TypedDict, total=False):
job_uuid: Required[str]
+
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
diff --git a/src/llama_stack/types/post_training/job_cancel_params.py b/src/llama_stack/types/post_training/job_cancel_params.py
index c9c30d8..9321c3b 100644
--- a/src/llama_stack/types/post_training/job_cancel_params.py
+++ b/src/llama_stack/types/post_training/job_cancel_params.py
@@ -2,10 +2,14 @@
from __future__ import annotations
-from typing_extensions import Required, TypedDict
+from typing_extensions import Required, Annotated, TypedDict
+
+from ..._utils import PropertyInfo
__all__ = ["JobCancelParams"]
class JobCancelParams(TypedDict, total=False):
job_uuid: Required[str]
+
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
diff --git a/src/llama_stack/types/post_training/job_logs_params.py b/src/llama_stack/types/post_training/job_logs_params.py
index a550be5..42f7e07 100644
--- a/src/llama_stack/types/post_training/job_logs_params.py
+++ b/src/llama_stack/types/post_training/job_logs_params.py
@@ -2,10 +2,14 @@
from __future__ import annotations
-from typing_extensions import Required, TypedDict
+from typing_extensions import Required, Annotated, TypedDict
+
+from ..._utils import PropertyInfo
__all__ = ["JobLogsParams"]
class JobLogsParams(TypedDict, total=False):
job_uuid: Required[str]
+
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
diff --git a/src/llama_stack/types/post_training/job_status_params.py b/src/llama_stack/types/post_training/job_status_params.py
index 8cf17b0..f1f8b20 100644
--- a/src/llama_stack/types/post_training/job_status_params.py
+++ b/src/llama_stack/types/post_training/job_status_params.py
@@ -2,10 +2,14 @@
from __future__ import annotations
-from typing_extensions import Required, TypedDict
+from typing_extensions import Required, Annotated, TypedDict
+
+from ..._utils import PropertyInfo
__all__ = ["JobStatusParams"]
class JobStatusParams(TypedDict, total=False):
job_uuid: Required[str]
+
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
diff --git a/src/llama_stack/types/post_training_preference_optimize_params.py b/src/llama_stack/types/post_training_preference_optimize_params.py
index 9e6f3cc..805e6cf 100644
--- a/src/llama_stack/types/post_training_preference_optimize_params.py
+++ b/src/llama_stack/types/post_training_preference_optimize_params.py
@@ -3,8 +3,9 @@
from __future__ import annotations
from typing import Dict, Union, Iterable
-from typing_extensions import Literal, Required, TypedDict
+from typing_extensions import Literal, Required, Annotated, TypedDict
+from .._utils import PropertyInfo
from .train_eval_dataset_param import TrainEvalDatasetParam
__all__ = ["PostTrainingPreferenceOptimizeParams", "AlgorithmConfig", "OptimizerConfig", "TrainingConfig"]
@@ -31,6 +32,8 @@ class PostTrainingPreferenceOptimizeParams(TypedDict, total=False):
validation_dataset: Required[TrainEvalDatasetParam]
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
+
class AlgorithmConfig(TypedDict, total=False):
epsilon: Required[float]
diff --git a/src/llama_stack/types/post_training_supervised_fine_tune_params.py b/src/llama_stack/types/post_training_supervised_fine_tune_params.py
index 36f776d..084e1ed 100644
--- a/src/llama_stack/types/post_training_supervised_fine_tune_params.py
+++ b/src/llama_stack/types/post_training_supervised_fine_tune_params.py
@@ -3,8 +3,9 @@
from __future__ import annotations
from typing import Dict, List, Union, Iterable
-from typing_extensions import Literal, Required, TypeAlias, TypedDict
+from typing_extensions import Literal, Required, Annotated, TypeAlias, TypedDict
+from .._utils import PropertyInfo
from .train_eval_dataset_param import TrainEvalDatasetParam
__all__ = [
@@ -39,6 +40,8 @@ class PostTrainingSupervisedFineTuneParams(TypedDict, total=False):
validation_dataset: Required[TrainEvalDatasetParam]
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
+
class AlgorithmConfigLoraFinetuningConfig(TypedDict, total=False):
alpha: Required[int]
diff --git a/src/llama_stack/types/reward_scoring_score_params.py b/src/llama_stack/types/reward_scoring_score_params.py
index b969b75..bb7bfb6 100644
--- a/src/llama_stack/types/reward_scoring_score_params.py
+++ b/src/llama_stack/types/reward_scoring_score_params.py
@@ -3,8 +3,9 @@
from __future__ import annotations
from typing import Union, Iterable
-from typing_extensions import Required, TypeAlias, TypedDict
+from typing_extensions import Required, Annotated, TypeAlias, TypedDict
+from .._utils import PropertyInfo
from .shared_params.user_message import UserMessage
from .shared_params.system_message import SystemMessage
from .shared_params.completion_message import CompletionMessage
@@ -23,6 +24,8 @@ class RewardScoringScoreParams(TypedDict, total=False):
model: Required[str]
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
+
DialogGenerationDialog: TypeAlias = Union[UserMessage, SystemMessage, ToolResponseMessage, CompletionMessage]
diff --git a/src/llama_stack/types/run_sheid_response.py b/src/llama_stack/types/run_sheid_response.py
new file mode 100644
index 0000000..478b023
--- /dev/null
+++ b/src/llama_stack/types/run_sheid_response.py
@@ -0,0 +1,20 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from typing import Dict, List, Union, Optional
+from typing_extensions import Literal
+
+from .._models import BaseModel
+
+__all__ = ["RunSheidResponse", "Violation"]
+
+
+class Violation(BaseModel):
+ metadata: Dict[str, Union[bool, float, str, List[object], object, None]]
+
+ violation_level: Literal["info", "warn", "error"]
+
+ user_message: Optional[str] = None
+
+
+class RunSheidResponse(BaseModel):
+ violation: Optional[Violation] = None
diff --git a/src/llama_stack/types/safety_run_shields_params.py b/src/llama_stack/types/safety_run_shield_params.py
similarity index 52%
rename from src/llama_stack/types/safety_run_shields_params.py
rename to src/llama_stack/types/safety_run_shield_params.py
index 59498f6..430473b 100644
--- a/src/llama_stack/types/safety_run_shields_params.py
+++ b/src/llama_stack/types/safety_run_shield_params.py
@@ -2,22 +2,26 @@
from __future__ import annotations
-from typing import Union, Iterable
-from typing_extensions import Required, TypeAlias, TypedDict
+from typing import Dict, Union, Iterable
+from typing_extensions import Required, Annotated, TypeAlias, TypedDict
-from .shield_definition_param import ShieldDefinitionParam
+from .._utils import PropertyInfo
from .shared_params.user_message import UserMessage
from .shared_params.system_message import SystemMessage
from .shared_params.completion_message import CompletionMessage
from .shared_params.tool_response_message import ToolResponseMessage
-__all__ = ["SafetyRunShieldsParams", "Message"]
+__all__ = ["SafetyRunShieldParams", "Message"]
-class SafetyRunShieldsParams(TypedDict, total=False):
+class SafetyRunShieldParams(TypedDict, total=False):
messages: Required[Iterable[Message]]
- shields: Required[Iterable[ShieldDefinitionParam]]
+ params: Required[Dict[str, Union[bool, float, str, Iterable[object], object, None]]]
+
+ shield_type: Required[str]
+
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
Message: TypeAlias = Union[UserMessage, SystemMessage, ToolResponseMessage, CompletionMessage]
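
The run-shields endpoint thus becomes a single-shield call: the old `shields: Iterable[ShieldDefinitionParam]` argument is replaced by a `shield_type` string plus a free-form `params` dict, and the response (`RunSheidResponse` above) carries at most one violation. A sketch; the shield type is a placeholder:

```python
response = client.safety.run_shield(
    messages=[{"role": "user", "content": "How do I make a bomb?"}],
    shield_type="llama_guard",  # placeholder shield identifier
    params={},
)
if response.violation is not None:
    print(response.violation.violation_level)  # "info" | "warn" | "error"
    print(response.violation.user_message)
```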
diff --git a/src/llama_stack/types/safety_run_shields_response.py b/src/llama_stack/types/safety_run_shields_response.py
deleted file mode 100644
index 24f87f2..0000000
--- a/src/llama_stack/types/safety_run_shields_response.py
+++ /dev/null
@@ -1,12 +0,0 @@
-# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
-
-from typing import List
-
-from .._models import BaseModel
-from .sheid_response import SheidResponse
-
-__all__ = ["SafetyRunShieldsResponse"]
-
-
-class SafetyRunShieldsResponse(BaseModel):
- responses: List[SheidResponse]
diff --git a/src/llama_stack/types/sheid_response.py b/src/llama_stack/types/sheid_response.py
deleted file mode 100644
index 99cdcf5..0000000
--- a/src/llama_stack/types/sheid_response.py
+++ /dev/null
@@ -1,20 +0,0 @@
-# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
-
-from typing import Union, Optional
-from typing_extensions import Literal
-
-from .._models import BaseModel
-
-__all__ = ["SheidResponse"]
-
-
-class SheidResponse(BaseModel):
- is_violation: bool
-
- shield_type: Union[
- Literal["llama_guard", "code_scanner_guard", "third_party_shield", "injection_shield", "jailbreak_shield"], str
- ]
-
- violation_return_message: Optional[str] = None
-
- violation_type: Optional[str] = None
diff --git a/src/llama_stack/types/shield_call_step.py b/src/llama_stack/types/shield_call_step.py
index e360825..d4b90d8 100644
--- a/src/llama_stack/types/shield_call_step.py
+++ b/src/llama_stack/types/shield_call_step.py
@@ -1,18 +1,23 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
-from typing import Optional
+from typing import Dict, List, Union, Optional
from datetime import datetime
from typing_extensions import Literal
from .._models import BaseModel
-from .sheid_response import SheidResponse
-__all__ = ["ShieldCallStep"]
+__all__ = ["ShieldCallStep", "Violation"]
-class ShieldCallStep(BaseModel):
- response: SheidResponse
+class Violation(BaseModel):
+ metadata: Dict[str, Union[bool, float, str, List[object], object, None]]
+
+ violation_level: Literal["info", "warn", "error"]
+
+ user_message: Optional[str] = None
+
+class ShieldCallStep(BaseModel):
step_id: str
step_type: Literal["shield_call"]
@@ -22,3 +27,5 @@ class ShieldCallStep(BaseModel):
completed_at: Optional[datetime] = None
started_at: Optional[datetime] = None
+
+ violation: Optional[Violation] = None
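
Code that inspected shield results on a step changes accordingly: the required `response: SheidResponse` field is gone, and a step now carries an optional `violation`. A minimal migration sketch, assuming only the model defined above:

```python
from llama_stack.types import ShieldCallStep


def step_blocked(step: ShieldCallStep) -> bool:
    # Previously this read step.response.is_violation; now a shield call
    # step is clean unless an error-level violation was recorded.
    return step.violation is not None and step.violation.violation_level == "error"
```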
diff --git a/src/llama_stack/types/shield_definition_param.py b/src/llama_stack/types/shield_definition_param.py
deleted file mode 100644
index 9672e03..0000000
--- a/src/llama_stack/types/shield_definition_param.py
+++ /dev/null
@@ -1,28 +0,0 @@
-# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
-
-from __future__ import annotations
-
-from typing import Dict, Union
-from typing_extensions import Literal, Required, TypedDict
-
-from .tool_param_definition_param import ToolParamDefinitionParam
-from .rest_api_execution_config_param import RestAPIExecutionConfigParam
-
-__all__ = ["ShieldDefinitionParam"]
-
-
-class ShieldDefinitionParam(TypedDict, total=False):
- on_violation_action: Required[Literal[0, 1, 2]]
-
- shield_type: Required[
- Union[
- Literal["llama_guard", "code_scanner_guard", "third_party_shield", "injection_shield", "jailbreak_shield"],
- str,
- ]
- ]
-
- description: str
-
- execution_config: RestAPIExecutionConfigParam
-
- parameters: Dict[str, ToolParamDefinitionParam]
diff --git a/src/llama_stack/types/shield_get_params.py b/src/llama_stack/types/shield_get_params.py
new file mode 100644
index 0000000..cb9ce90
--- /dev/null
+++ b/src/llama_stack/types/shield_get_params.py
@@ -0,0 +1,15 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from __future__ import annotations
+
+from typing_extensions import Required, Annotated, TypedDict
+
+from .._utils import PropertyInfo
+
+__all__ = ["ShieldGetParams"]
+
+
+class ShieldGetParams(TypedDict, total=False):
+ shield_type: Required[str]
+
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
diff --git a/src/llama_stack/types/shield_spec.py b/src/llama_stack/types/shield_spec.py
new file mode 100644
index 0000000..d83cd51
--- /dev/null
+++ b/src/llama_stack/types/shield_spec.py
@@ -0,0 +1,19 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from typing import Dict, List, Union
+
+from .._models import BaseModel
+
+__all__ = ["ShieldSpec", "ProviderConfig"]
+
+
+class ProviderConfig(BaseModel):
+ config: Dict[str, Union[bool, float, str, List[object], object, None]]
+
+ provider_id: str
+
+
+class ShieldSpec(BaseModel):
+ provider_config: ProviderConfig
+
+ shield_type: str
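
Shields get the same registry treatment as models and memory banks: a `ShieldSpec` pairs the shield type with its provider config, and `shields.get` is typed `Optional`. A sketch; the shield type is a placeholder:

```python
shield = client.shields.get(shield_type="llama_guard")  # placeholder type
if shield is not None:
    print(shield.shield_type, shield.provider_config.provider_id)
```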
diff --git a/src/llama_stack/types/synthetic_data_generation_generate_params.py b/src/llama_stack/types/synthetic_data_generation_generate_params.py
index 4238473..2514992 100644
--- a/src/llama_stack/types/synthetic_data_generation_generate_params.py
+++ b/src/llama_stack/types/synthetic_data_generation_generate_params.py
@@ -3,8 +3,9 @@
from __future__ import annotations
from typing import Union, Iterable
-from typing_extensions import Literal, Required, TypeAlias, TypedDict
+from typing_extensions import Literal, Required, Annotated, TypeAlias, TypedDict
+from .._utils import PropertyInfo
from .shared_params.user_message import UserMessage
from .shared_params.system_message import SystemMessage
from .shared_params.completion_message import CompletionMessage
@@ -20,5 +21,7 @@ class SyntheticDataGenerationGenerateParams(TypedDict, total=False):
model: str
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
+
Dialog: TypeAlias = Union[UserMessage, SystemMessage, ToolResponseMessage, CompletionMessage]
diff --git a/src/llama_stack/types/telemetry_get_trace_params.py b/src/llama_stack/types/telemetry_get_trace_params.py
index 520724b..dbee698 100644
--- a/src/llama_stack/types/telemetry_get_trace_params.py
+++ b/src/llama_stack/types/telemetry_get_trace_params.py
@@ -2,10 +2,14 @@
from __future__ import annotations
-from typing_extensions import Required, TypedDict
+from typing_extensions import Required, Annotated, TypedDict
+
+from .._utils import PropertyInfo
__all__ = ["TelemetryGetTraceParams"]
class TelemetryGetTraceParams(TypedDict, total=False):
trace_id: Required[str]
+
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
diff --git a/src/llama_stack/types/telemetry_log_params.py b/src/llama_stack/types/telemetry_log_params.py
index 6e6eb61..a2e4d9b 100644
--- a/src/llama_stack/types/telemetry_log_params.py
+++ b/src/llama_stack/types/telemetry_log_params.py
@@ -23,6 +23,8 @@
class TelemetryLogParams(TypedDict, total=False):
event: Required[Event]
+ x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
+
class EventUnstructuredLogEvent(TypedDict, total=False):
message: Required[str]
diff --git a/tests/api_resources/agents/test_sessions.py b/tests/api_resources/agents/test_sessions.py
index 4040ef0..97836d7 100644
--- a/tests/api_resources/agents/test_sessions.py
+++ b/tests/api_resources/agents/test_sessions.py
@@ -28,6 +28,15 @@ def test_method_create(self, client: LlamaStack) -> None:
)
assert_matches_type(SessionCreateResponse, session, path=["response"])
+ @parametrize
+ def test_method_create_with_all_params(self, client: LlamaStack) -> None:
+ session = client.agents.sessions.create(
+ agent_id="agent_id",
+ session_name="session_name",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(SessionCreateResponse, session, path=["response"])
+
@parametrize
def test_raw_response_create(self, client: LlamaStack) -> None:
response = client.agents.sessions.with_raw_response.create(
@@ -68,6 +77,7 @@ def test_method_retrieve_with_all_params(self, client: LlamaStack) -> None:
agent_id="agent_id",
session_id="session_id",
turn_ids=["string", "string", "string"],
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
assert_matches_type(Session, session, path=["response"])
@@ -105,6 +115,15 @@ def test_method_delete(self, client: LlamaStack) -> None:
)
assert session is None
+ @parametrize
+ def test_method_delete_with_all_params(self, client: LlamaStack) -> None:
+ session = client.agents.sessions.delete(
+ agent_id="agent_id",
+ session_id="session_id",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert session is None
+
@parametrize
def test_raw_response_delete(self, client: LlamaStack) -> None:
response = client.agents.sessions.with_raw_response.delete(
@@ -143,6 +162,15 @@ async def test_method_create(self, async_client: AsyncLlamaStack) -> None:
)
assert_matches_type(SessionCreateResponse, session, path=["response"])
+ @parametrize
+ async def test_method_create_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ session = await async_client.agents.sessions.create(
+ agent_id="agent_id",
+ session_name="session_name",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(SessionCreateResponse, session, path=["response"])
+
@parametrize
async def test_raw_response_create(self, async_client: AsyncLlamaStack) -> None:
response = await async_client.agents.sessions.with_raw_response.create(
@@ -183,6 +211,7 @@ async def test_method_retrieve_with_all_params(self, async_client: AsyncLlamaStack) -> None:
agent_id="agent_id",
session_id="session_id",
turn_ids=["string", "string", "string"],
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
assert_matches_type(Session, session, path=["response"])
@@ -220,6 +249,15 @@ async def test_method_delete(self, async_client: AsyncLlamaStack) -> None:
)
assert session is None
+ @parametrize
+ async def test_method_delete_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ session = await async_client.agents.sessions.delete(
+ agent_id="agent_id",
+ session_id="session_id",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert session is None
+
@parametrize
async def test_raw_response_delete(self, async_client: AsyncLlamaStack) -> None:
response = await async_client.agents.sessions.with_raw_response.delete(
diff --git a/tests/api_resources/agents/test_steps.py b/tests/api_resources/agents/test_steps.py
index 3dda39e..64345d9 100644
--- a/tests/api_resources/agents/test_steps.py
+++ b/tests/api_resources/agents/test_steps.py
@@ -26,6 +26,16 @@ def test_method_retrieve(self, client: LlamaStack) -> None:
)
assert_matches_type(AgentsStep, step, path=["response"])
+ @parametrize
+ def test_method_retrieve_with_all_params(self, client: LlamaStack) -> None:
+ step = client.agents.steps.retrieve(
+ agent_id="agent_id",
+ step_id="step_id",
+ turn_id="turn_id",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(AgentsStep, step, path=["response"])
+
@parametrize
def test_raw_response_retrieve(self, client: LlamaStack) -> None:
response = client.agents.steps.with_raw_response.retrieve(
@@ -67,6 +77,16 @@ async def test_method_retrieve(self, async_client: AsyncLlamaStack) -> None:
)
assert_matches_type(AgentsStep, step, path=["response"])
+ @parametrize
+ async def test_method_retrieve_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ step = await async_client.agents.steps.retrieve(
+ agent_id="agent_id",
+ step_id="step_id",
+ turn_id="turn_id",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(AgentsStep, step, path=["response"])
+
@parametrize
async def test_raw_response_retrieve(self, async_client: AsyncLlamaStack) -> None:
response = await async_client.agents.steps.with_raw_response.retrieve(
diff --git a/tests/api_resources/agents/test_turns.py b/tests/api_resources/agents/test_turns.py
index 47614e2..ec890a0 100644
--- a/tests/api_resources/agents/test_turns.py
+++ b/tests/api_resources/agents/test_turns.py
@@ -76,6 +76,7 @@ def test_method_create_with_all_params_overload_1(self, client: LlamaStack) -> None:
},
],
stream=False,
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
assert_matches_type(AgentsTurnStreamChunk, turn, path=["response"])
@@ -193,6 +194,7 @@ def test_method_create_with_all_params_overload_2(self, client: LlamaStack) -> None:
"mime_type": "mime_type",
},
],
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
turn_stream.response.close()
@@ -259,6 +261,15 @@ def test_method_retrieve(self, client: LlamaStack) -> None:
)
assert_matches_type(Turn, turn, path=["response"])
+ @parametrize
+ def test_method_retrieve_with_all_params(self, client: LlamaStack) -> None:
+ turn = client.agents.turns.retrieve(
+ agent_id="agent_id",
+ turn_id="turn_id",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(Turn, turn, path=["response"])
+
@parametrize
def test_raw_response_retrieve(self, client: LlamaStack) -> None:
response = client.agents.turns.with_raw_response.retrieve(
@@ -348,6 +359,7 @@ async def test_method_create_with_all_params_overload_1(self, async_client: AsyncLlamaStack) -> None:
},
],
stream=False,
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
assert_matches_type(AgentsTurnStreamChunk, turn, path=["response"])
@@ -465,6 +477,7 @@ async def test_method_create_with_all_params_overload_2(self, async_client: AsyncLlamaStack) -> None:
"mime_type": "mime_type",
},
],
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
await turn_stream.response.aclose()
@@ -531,6 +544,15 @@ async def test_method_retrieve(self, async_client: AsyncLlamaStack) -> None:
)
assert_matches_type(Turn, turn, path=["response"])
+ @parametrize
+ async def test_method_retrieve_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ turn = await async_client.agents.turns.retrieve(
+ agent_id="agent_id",
+ turn_id="turn_id",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(Turn, turn, path=["response"])
+
@parametrize
async def test_raw_response_retrieve(self, async_client: AsyncLlamaStack) -> None:
response = await async_client.agents.turns.with_raw_response.retrieve(
diff --git a/tests/api_resources/evaluate/jobs/test_artifacts.py b/tests/api_resources/evaluate/jobs/test_artifacts.py
index 3cef0d7..44d205f 100755
--- a/tests/api_resources/evaluate/jobs/test_artifacts.py
+++ b/tests/api_resources/evaluate/jobs/test_artifacts.py
@@ -24,6 +24,14 @@ def test_method_list(self, client: LlamaStack) -> None:
)
assert_matches_type(EvaluationJobArtifacts, artifact, path=["response"])
+ @parametrize
+ def test_method_list_with_all_params(self, client: LlamaStack) -> None:
+ artifact = client.evaluate.jobs.artifacts.list(
+ job_uuid="job_uuid",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(EvaluationJobArtifacts, artifact, path=["response"])
+
@parametrize
def test_raw_response_list(self, client: LlamaStack) -> None:
response = client.evaluate.jobs.artifacts.with_raw_response.list(
@@ -59,6 +67,14 @@ async def test_method_list(self, async_client: AsyncLlamaStack) -> None:
)
assert_matches_type(EvaluationJobArtifacts, artifact, path=["response"])
+ @parametrize
+ async def test_method_list_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ artifact = await async_client.evaluate.jobs.artifacts.list(
+ job_uuid="job_uuid",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(EvaluationJobArtifacts, artifact, path=["response"])
+
@parametrize
async def test_raw_response_list(self, async_client: AsyncLlamaStack) -> None:
response = await async_client.evaluate.jobs.artifacts.with_raw_response.list(
diff --git a/tests/api_resources/evaluate/jobs/test_logs.py b/tests/api_resources/evaluate/jobs/test_logs.py
index 472ac90..51af56e 100755
--- a/tests/api_resources/evaluate/jobs/test_logs.py
+++ b/tests/api_resources/evaluate/jobs/test_logs.py
@@ -24,6 +24,14 @@ def test_method_list(self, client: LlamaStack) -> None:
)
assert_matches_type(EvaluationJobLogStream, log, path=["response"])
 
+ @parametrize
+ def test_method_list_with_all_params(self, client: LlamaStack) -> None:
+ log = client.evaluate.jobs.logs.list(
+ job_uuid="job_uuid",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(EvaluationJobLogStream, log, path=["response"])
+
@parametrize
def test_raw_response_list(self, client: LlamaStack) -> None:
response = client.evaluate.jobs.logs.with_raw_response.list(
@@ -59,6 +67,14 @@ async def test_method_list(self, async_client: AsyncLlamaStack) -> None:
)
assert_matches_type(EvaluationJobLogStream, log, path=["response"])
 
+ @parametrize
+ async def test_method_list_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ log = await async_client.evaluate.jobs.logs.list(
+ job_uuid="job_uuid",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(EvaluationJobLogStream, log, path=["response"])
+
@parametrize
async def test_raw_response_list(self, async_client: AsyncLlamaStack) -> None:
response = await async_client.evaluate.jobs.logs.with_raw_response.list(
diff --git a/tests/api_resources/evaluate/jobs/test_status.py b/tests/api_resources/evaluate/jobs/test_status.py
index 3a6000e..c8bc510 100755
--- a/tests/api_resources/evaluate/jobs/test_status.py
+++ b/tests/api_resources/evaluate/jobs/test_status.py
@@ -24,6 +24,14 @@ def test_method_list(self, client: LlamaStack) -> None:
)
assert_matches_type(EvaluationJobStatus, status, path=["response"])
 
+ @parametrize
+ def test_method_list_with_all_params(self, client: LlamaStack) -> None:
+ status = client.evaluate.jobs.status.list(
+ job_uuid="job_uuid",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(EvaluationJobStatus, status, path=["response"])
+
@parametrize
def test_raw_response_list(self, client: LlamaStack) -> None:
response = client.evaluate.jobs.status.with_raw_response.list(
@@ -59,6 +67,14 @@ async def test_method_list(self, async_client: AsyncLlamaStack) -> None:
)
assert_matches_type(EvaluationJobStatus, status, path=["response"])
 
+ @parametrize
+ async def test_method_list_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ status = await async_client.evaluate.jobs.status.list(
+ job_uuid="job_uuid",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(EvaluationJobStatus, status, path=["response"])
+
@parametrize
async def test_raw_response_list(self, async_client: AsyncLlamaStack) -> None:
response = await async_client.evaluate.jobs.status.with_raw_response.list(
diff --git a/tests/api_resources/evaluate/test_jobs.py b/tests/api_resources/evaluate/test_jobs.py
index 601bf38..3e296af 100644
--- a/tests/api_resources/evaluate/test_jobs.py
+++ b/tests/api_resources/evaluate/test_jobs.py
@@ -22,6 +22,13 @@ def test_method_list(self, client: LlamaStack) -> None:
job = client.evaluate.jobs.list()
assert_matches_type(EvaluationJob, job, path=["response"])
 
+ @parametrize
+ def test_method_list_with_all_params(self, client: LlamaStack) -> None:
+ job = client.evaluate.jobs.list(
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(EvaluationJob, job, path=["response"])
+
@parametrize
def test_raw_response_list(self, client: LlamaStack) -> None:
response = client.evaluate.jobs.with_raw_response.list()
@@ -49,6 +56,14 @@ def test_method_cancel(self, client: LlamaStack) -> None:
)
assert job is None
 
+ @parametrize
+ def test_method_cancel_with_all_params(self, client: LlamaStack) -> None:
+ job = client.evaluate.jobs.cancel(
+ job_uuid="job_uuid",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert job is None
+
@parametrize
def test_raw_response_cancel(self, client: LlamaStack) -> None:
response = client.evaluate.jobs.with_raw_response.cancel(
@@ -82,6 +97,13 @@ async def test_method_list(self, async_client: AsyncLlamaStack) -> None:
job = await async_client.evaluate.jobs.list()
assert_matches_type(EvaluationJob, job, path=["response"])
 
+ @parametrize
+ async def test_method_list_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ job = await async_client.evaluate.jobs.list(
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(EvaluationJob, job, path=["response"])
+
@parametrize
async def test_raw_response_list(self, async_client: AsyncLlamaStack) -> None:
response = await async_client.evaluate.jobs.with_raw_response.list()
@@ -109,6 +131,14 @@ async def test_method_cancel(self, async_client: AsyncLlamaStack) -> None:
)
assert job is None
 
+ @parametrize
+ async def test_method_cancel_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ job = await async_client.evaluate.jobs.cancel(
+ job_uuid="job_uuid",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert job is None
+
@parametrize
async def test_raw_response_cancel(self, async_client: AsyncLlamaStack) -> None:
response = await async_client.evaluate.jobs.with_raw_response.cancel(
diff --git a/tests/api_resources/evaluate/test_question_answering.py b/tests/api_resources/evaluate/test_question_answering.py
index b229620..d7b4a23 100644
--- a/tests/api_resources/evaluate/test_question_answering.py
+++ b/tests/api_resources/evaluate/test_question_answering.py
@@ -24,6 +24,14 @@ def test_method_create(self, client: LlamaStack) -> None:
)
assert_matches_type(EvaluationJob, question_answering, path=["response"])
 
+ @parametrize
+ def test_method_create_with_all_params(self, client: LlamaStack) -> None:
+ question_answering = client.evaluate.question_answering.create(
+ metrics=["em", "f1"],
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(EvaluationJob, question_answering, path=["response"])
+
@parametrize
def test_raw_response_create(self, client: LlamaStack) -> None:
response = client.evaluate.question_answering.with_raw_response.create(
@@ -59,6 +67,14 @@ async def test_method_create(self, async_client: AsyncLlamaStack) -> None:
)
assert_matches_type(EvaluationJob, question_answering, path=["response"])
 
+ @parametrize
+ async def test_method_create_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ question_answering = await async_client.evaluate.question_answering.create(
+ metrics=["em", "f1"],
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(EvaluationJob, question_answering, path=["response"])
+
@parametrize
async def test_raw_response_create(self, async_client: AsyncLlamaStack) -> None:
response = await async_client.evaluate.question_answering.with_raw_response.create(
diff --git a/tests/api_resources/inference/test_embeddings.py b/tests/api_resources/inference/test_embeddings.py
index 62c90ce..b83a8ef 100644
--- a/tests/api_resources/inference/test_embeddings.py
+++ b/tests/api_resources/inference/test_embeddings.py
@@ -25,6 +25,15 @@ def test_method_create(self, client: LlamaStack) -> None:
)
assert_matches_type(Embeddings, embedding, path=["response"])
 
+ @parametrize
+ def test_method_create_with_all_params(self, client: LlamaStack) -> None:
+ embedding = client.inference.embeddings.create(
+ contents=["string", "string", "string"],
+ model="model",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(Embeddings, embedding, path=["response"])
+
@parametrize
def test_raw_response_create(self, client: LlamaStack) -> None:
response = client.inference.embeddings.with_raw_response.create(
@@ -63,6 +72,15 @@ async def test_method_create(self, async_client: AsyncLlamaStack) -> None:
)
assert_matches_type(Embeddings, embedding, path=["response"])
 
+ @parametrize
+ async def test_method_create_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ embedding = await async_client.inference.embeddings.create(
+ contents=["string", "string", "string"],
+ model="model",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(Embeddings, embedding, path=["response"])
+
@parametrize
async def test_raw_response_create(self, async_client: AsyncLlamaStack) -> None:
response = await async_client.inference.embeddings.with_raw_response.create(
diff --git a/tests/api_resources/memory_banks/__init__.py b/tests/api_resources/memory/__init__.py
similarity index 100%
rename from tests/api_resources/memory_banks/__init__.py
rename to tests/api_resources/memory/__init__.py
diff --git a/tests/api_resources/memory_banks/test_documents.py b/tests/api_resources/memory/test_documents.py
similarity index 68%
rename from tests/api_resources/memory_banks/test_documents.py
rename to tests/api_resources/memory/test_documents.py
index 04514eb..842efac 100644
--- a/tests/api_resources/memory_banks/test_documents.py
+++ b/tests/api_resources/memory/test_documents.py
@@ -9,7 +9,7 @@
 
from llama_stack import LlamaStack, AsyncLlamaStack
from tests.utils import assert_matches_type
-from llama_stack.types.memory_banks import DocumentRetrieveResponse
+from llama_stack.types.memory import DocumentRetrieveResponse
 
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
 
@@ -19,15 +19,24 @@ class TestDocuments:
 
@parametrize
def test_method_retrieve(self, client: LlamaStack) -> None:
- document = client.memory_banks.documents.retrieve(
+ document = client.memory.documents.retrieve(
bank_id="bank_id",
document_ids=["string", "string", "string"],
)
assert_matches_type(DocumentRetrieveResponse, document, path=["response"])
 
+ @parametrize
+ def test_method_retrieve_with_all_params(self, client: LlamaStack) -> None:
+ document = client.memory.documents.retrieve(
+ bank_id="bank_id",
+ document_ids=["string", "string", "string"],
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(DocumentRetrieveResponse, document, path=["response"])
+
@parametrize
def test_raw_response_retrieve(self, client: LlamaStack) -> None:
- response = client.memory_banks.documents.with_raw_response.retrieve(
+ response = client.memory.documents.with_raw_response.retrieve(
bank_id="bank_id",
document_ids=["string", "string", "string"],
)
@@ -39,7 +48,7 @@ def test_raw_response_retrieve(self, client: LlamaStack) -> None:
 
@parametrize
def test_streaming_response_retrieve(self, client: LlamaStack) -> None:
- with client.memory_banks.documents.with_streaming_response.retrieve(
+ with client.memory.documents.with_streaming_response.retrieve(
bank_id="bank_id",
document_ids=["string", "string", "string"],
) as response:
@@ -53,15 +62,24 @@ def test_streaming_response_retrieve(self, client: LlamaStack) -> None:
 
@parametrize
def test_method_delete(self, client: LlamaStack) -> None:
- document = client.memory_banks.documents.delete(
+ document = client.memory.documents.delete(
+ bank_id="bank_id",
+ document_ids=["string", "string", "string"],
+ )
+ assert document is None
+
+ @parametrize
+ def test_method_delete_with_all_params(self, client: LlamaStack) -> None:
+ document = client.memory.documents.delete(
bank_id="bank_id",
document_ids=["string", "string", "string"],
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
assert document is None
 
@parametrize
def test_raw_response_delete(self, client: LlamaStack) -> None:
- response = client.memory_banks.documents.with_raw_response.delete(
+ response = client.memory.documents.with_raw_response.delete(
bank_id="bank_id",
document_ids=["string", "string", "string"],
)
@@ -73,7 +91,7 @@ def test_raw_response_delete(self, client: LlamaStack) -> None:
 
@parametrize
def test_streaming_response_delete(self, client: LlamaStack) -> None:
- with client.memory_banks.documents.with_streaming_response.delete(
+ with client.memory.documents.with_streaming_response.delete(
bank_id="bank_id",
document_ids=["string", "string", "string"],
) as response:
@@ -91,15 +109,24 @@ class TestAsyncDocuments:
 
@parametrize
async def test_method_retrieve(self, async_client: AsyncLlamaStack) -> None:
- document = await async_client.memory_banks.documents.retrieve(
+ document = await async_client.memory.documents.retrieve(
bank_id="bank_id",
document_ids=["string", "string", "string"],
)
assert_matches_type(DocumentRetrieveResponse, document, path=["response"])
 
+ @parametrize
+ async def test_method_retrieve_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ document = await async_client.memory.documents.retrieve(
+ bank_id="bank_id",
+ document_ids=["string", "string", "string"],
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(DocumentRetrieveResponse, document, path=["response"])
+
@parametrize
async def test_raw_response_retrieve(self, async_client: AsyncLlamaStack) -> None:
- response = await async_client.memory_banks.documents.with_raw_response.retrieve(
+ response = await async_client.memory.documents.with_raw_response.retrieve(
bank_id="bank_id",
document_ids=["string", "string", "string"],
)
@@ -111,7 +138,7 @@ async def test_raw_response_retrieve(self, async_client: AsyncLlamaStack) -> Non
 
@parametrize
async def test_streaming_response_retrieve(self, async_client: AsyncLlamaStack) -> None:
- async with async_client.memory_banks.documents.with_streaming_response.retrieve(
+ async with async_client.memory.documents.with_streaming_response.retrieve(
bank_id="bank_id",
document_ids=["string", "string", "string"],
) as response:
@@ -125,15 +152,24 @@ async def test_streaming_response_retrieve(self, async_client: AsyncLlamaStack)
 
@parametrize
async def test_method_delete(self, async_client: AsyncLlamaStack) -> None:
- document = await async_client.memory_banks.documents.delete(
+ document = await async_client.memory.documents.delete(
+ bank_id="bank_id",
+ document_ids=["string", "string", "string"],
+ )
+ assert document is None
+
+ @parametrize
+ async def test_method_delete_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ document = await async_client.memory.documents.delete(
bank_id="bank_id",
document_ids=["string", "string", "string"],
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
assert document is None
 
@parametrize
async def test_raw_response_delete(self, async_client: AsyncLlamaStack) -> None:
- response = await async_client.memory_banks.documents.with_raw_response.delete(
+ response = await async_client.memory.documents.with_raw_response.delete(
bank_id="bank_id",
document_ids=["string", "string", "string"],
)
@@ -145,7 +181,7 @@ async def test_raw_response_delete(self, async_client: AsyncLlamaStack) -> None:
 
@parametrize
async def test_streaming_response_delete(self, async_client: AsyncLlamaStack) -> None:
- async with async_client.memory_banks.documents.with_streaming_response.delete(
+ async with async_client.memory.documents.with_streaming_response.delete(
bank_id="bank_id",
document_ids=["string", "string", "string"],
) as response:
diff --git a/tests/api_resources/post_training/test_jobs.py b/tests/api_resources/post_training/test_jobs.py
index 2580031..ec2f8f0 100644
--- a/tests/api_resources/post_training/test_jobs.py
+++ b/tests/api_resources/post_training/test_jobs.py
@@ -27,6 +27,13 @@ def test_method_list(self, client: LlamaStack) -> None:
job = client.post_training.jobs.list()
assert_matches_type(PostTrainingJob, job, path=["response"])
 
+ @parametrize
+ def test_method_list_with_all_params(self, client: LlamaStack) -> None:
+ job = client.post_training.jobs.list(
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(PostTrainingJob, job, path=["response"])
+
@parametrize
def test_raw_response_list(self, client: LlamaStack) -> None:
response = client.post_training.jobs.with_raw_response.list()
@@ -54,6 +61,14 @@ def test_method_artifacts(self, client: LlamaStack) -> None:
)
assert_matches_type(PostTrainingJobArtifacts, job, path=["response"])
 
+ @parametrize
+ def test_method_artifacts_with_all_params(self, client: LlamaStack) -> None:
+ job = client.post_training.jobs.artifacts(
+ job_uuid="job_uuid",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(PostTrainingJobArtifacts, job, path=["response"])
+
@parametrize
def test_raw_response_artifacts(self, client: LlamaStack) -> None:
response = client.post_training.jobs.with_raw_response.artifacts(
@@ -85,6 +100,14 @@ def test_method_cancel(self, client: LlamaStack) -> None:
)
assert job is None
 
+ @parametrize
+ def test_method_cancel_with_all_params(self, client: LlamaStack) -> None:
+ job = client.post_training.jobs.cancel(
+ job_uuid="job_uuid",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert job is None
+
@parametrize
def test_raw_response_cancel(self, client: LlamaStack) -> None:
response = client.post_training.jobs.with_raw_response.cancel(
@@ -116,6 +139,14 @@ def test_method_logs(self, client: LlamaStack) -> None:
)
assert_matches_type(PostTrainingJobLogStream, job, path=["response"])
 
+ @parametrize
+ def test_method_logs_with_all_params(self, client: LlamaStack) -> None:
+ job = client.post_training.jobs.logs(
+ job_uuid="job_uuid",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(PostTrainingJobLogStream, job, path=["response"])
+
@parametrize
def test_raw_response_logs(self, client: LlamaStack) -> None:
response = client.post_training.jobs.with_raw_response.logs(
@@ -147,6 +178,14 @@ def test_method_status(self, client: LlamaStack) -> None:
)
assert_matches_type(PostTrainingJobStatus, job, path=["response"])
 
+ @parametrize
+ def test_method_status_with_all_params(self, client: LlamaStack) -> None:
+ job = client.post_training.jobs.status(
+ job_uuid="job_uuid",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(PostTrainingJobStatus, job, path=["response"])
+
@parametrize
def test_raw_response_status(self, client: LlamaStack) -> None:
response = client.post_training.jobs.with_raw_response.status(
@@ -180,6 +219,13 @@ async def test_method_list(self, async_client: AsyncLlamaStack) -> None:
job = await async_client.post_training.jobs.list()
assert_matches_type(PostTrainingJob, job, path=["response"])
 
+ @parametrize
+ async def test_method_list_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ job = await async_client.post_training.jobs.list(
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(PostTrainingJob, job, path=["response"])
+
@parametrize
async def test_raw_response_list(self, async_client: AsyncLlamaStack) -> None:
response = await async_client.post_training.jobs.with_raw_response.list()
@@ -207,6 +253,14 @@ async def test_method_artifacts(self, async_client: AsyncLlamaStack) -> None:
)
assert_matches_type(PostTrainingJobArtifacts, job, path=["response"])
 
+ @parametrize
+ async def test_method_artifacts_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ job = await async_client.post_training.jobs.artifacts(
+ job_uuid="job_uuid",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(PostTrainingJobArtifacts, job, path=["response"])
+
@parametrize
async def test_raw_response_artifacts(self, async_client: AsyncLlamaStack) -> None:
response = await async_client.post_training.jobs.with_raw_response.artifacts(
@@ -238,6 +292,14 @@ async def test_method_cancel(self, async_client: AsyncLlamaStack) -> None:
)
assert job is None
 
+ @parametrize
+ async def test_method_cancel_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ job = await async_client.post_training.jobs.cancel(
+ job_uuid="job_uuid",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert job is None
+
@parametrize
async def test_raw_response_cancel(self, async_client: AsyncLlamaStack) -> None:
response = await async_client.post_training.jobs.with_raw_response.cancel(
@@ -269,6 +331,14 @@ async def test_method_logs(self, async_client: AsyncLlamaStack) -> None:
)
assert_matches_type(PostTrainingJobLogStream, job, path=["response"])
 
+ @parametrize
+ async def test_method_logs_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ job = await async_client.post_training.jobs.logs(
+ job_uuid="job_uuid",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(PostTrainingJobLogStream, job, path=["response"])
+
@parametrize
async def test_raw_response_logs(self, async_client: AsyncLlamaStack) -> None:
response = await async_client.post_training.jobs.with_raw_response.logs(
@@ -300,6 +370,14 @@ async def test_method_status(self, async_client: AsyncLlamaStack) -> None:
)
assert_matches_type(PostTrainingJobStatus, job, path=["response"])
 
+ @parametrize
+ async def test_method_status_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ job = await async_client.post_training.jobs.status(
+ job_uuid="job_uuid",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(PostTrainingJobStatus, job, path=["response"])
+
@parametrize
async def test_raw_response_status(self, async_client: AsyncLlamaStack) -> None:
response = await async_client.post_training.jobs.with_raw_response.status(
diff --git a/tests/api_resources/test_agents.py b/tests/api_resources/test_agents.py
index a9a46c9..355a0b1 100644
--- a/tests/api_resources/test_agents.py
+++ b/tests/api_resources/test_agents.py
@@ -21,7 +21,9 @@ class TestAgents:
def test_method_create(self, client: LlamaStack) -> None:
agent = client.agents.create(
agent_config={
+ "enable_session_persistence": True,
"instructions": "instructions",
+ "max_infer_iters": 0,
"model": "model",
},
)
@@ -31,126 +33,12 @@ def test_method_create(self, client: LlamaStack) -> None:
def test_method_create_with_all_params(self, client: LlamaStack) -> None:
agent = client.agents.create(
agent_config={
+ "enable_session_persistence": True,
"instructions": "instructions",
+ "max_infer_iters": 0,
"model": "model",
- "input_shields": [
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- ],
- "output_shields": [
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- ],
+ "input_shields": ["string", "string", "string"],
+ "output_shields": ["string", "string", "string"],
"sampling_params": {
"strategy": "greedy",
"max_tokens": 0,
@@ -166,124 +54,8 @@ def test_method_create_with_all_params(self, client: LlamaStack) -> None:
"api_key": "api_key",
"engine": "bing",
"type": "brave_search",
- "input_shields": [
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- ],
- "output_shields": [
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- ],
+ "input_shields": ["string", "string", "string"],
+ "output_shields": ["string", "string", "string"],
"remote_execution": {
"method": "GET",
"url": "https://example.com",
@@ -296,124 +68,8 @@ def test_method_create_with_all_params(self, client: LlamaStack) -> None:
"api_key": "api_key",
"engine": "bing",
"type": "brave_search",
- "input_shields": [
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- ],
- "output_shields": [
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- ],
+ "input_shields": ["string", "string", "string"],
+ "output_shields": ["string", "string", "string"],
"remote_execution": {
"method": "GET",
"url": "https://example.com",
@@ -426,124 +82,8 @@ def test_method_create_with_all_params(self, client: LlamaStack) -> None:
"api_key": "api_key",
"engine": "bing",
"type": "brave_search",
- "input_shields": [
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- ],
- "output_shields": [
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- ],
+ "input_shields": ["string", "string", "string"],
+ "output_shields": ["string", "string", "string"],
"remote_execution": {
"method": "GET",
"url": "https://example.com",
@@ -554,6 +94,7 @@ def test_method_create_with_all_params(self, client: LlamaStack) -> None:
},
],
},
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
assert_matches_type(AgentCreateResponse, agent, path=["response"])
 
@@ -561,7 +102,9 @@ def test_method_create_with_all_params(self, client: LlamaStack) -> None:
def test_raw_response_create(self, client: LlamaStack) -> None:
response = client.agents.with_raw_response.create(
agent_config={
+ "enable_session_persistence": True,
"instructions": "instructions",
+ "max_infer_iters": 0,
"model": "model",
},
)
@@ -575,7 +118,9 @@ def test_raw_response_create(self, client: LlamaStack) -> None:
def test_streaming_response_create(self, client: LlamaStack) -> None:
with client.agents.with_streaming_response.create(
agent_config={
+ "enable_session_persistence": True,
"instructions": "instructions",
+ "max_infer_iters": 0,
"model": "model",
},
) as response:
@@ -594,6 +139,14 @@ def test_method_delete(self, client: LlamaStack) -> None:
)
assert agent is None
 
+ @parametrize
+ def test_method_delete_with_all_params(self, client: LlamaStack) -> None:
+ agent = client.agents.delete(
+ agent_id="agent_id",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert agent is None
+
@parametrize
def test_raw_response_delete(self, client: LlamaStack) -> None:
response = client.agents.with_raw_response.delete(
@@ -626,7 +179,9 @@ class TestAsyncAgents:
async def test_method_create(self, async_client: AsyncLlamaStack) -> None:
agent = await async_client.agents.create(
agent_config={
+ "enable_session_persistence": True,
"instructions": "instructions",
+ "max_infer_iters": 0,
"model": "model",
},
)
@@ -636,126 +191,12 @@ async def test_method_create(self, async_client: AsyncLlamaStack) -> None:
async def test_method_create_with_all_params(self, async_client: AsyncLlamaStack) -> None:
agent = await async_client.agents.create(
agent_config={
+ "enable_session_persistence": True,
"instructions": "instructions",
+ "max_infer_iters": 0,
"model": "model",
- "input_shields": [
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- ],
- "output_shields": [
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- ],
+ "input_shields": ["string", "string", "string"],
+ "output_shields": ["string", "string", "string"],
"sampling_params": {
"strategy": "greedy",
"max_tokens": 0,
@@ -771,124 +212,8 @@ async def test_method_create_with_all_params(self, async_client: AsyncLlamaStack
"api_key": "api_key",
"engine": "bing",
"type": "brave_search",
- "input_shields": [
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- ],
- "output_shields": [
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- ],
+ "input_shields": ["string", "string", "string"],
+ "output_shields": ["string", "string", "string"],
"remote_execution": {
"method": "GET",
"url": "https://example.com",
@@ -901,124 +226,8 @@ async def test_method_create_with_all_params(self, async_client: AsyncLlamaStack
"api_key": "api_key",
"engine": "bing",
"type": "brave_search",
- "input_shields": [
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- ],
- "output_shields": [
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- ],
+ "input_shields": ["string", "string", "string"],
+ "output_shields": ["string", "string", "string"],
"remote_execution": {
"method": "GET",
"url": "https://example.com",
@@ -1031,124 +240,8 @@ async def test_method_create_with_all_params(self, async_client: AsyncLlamaStack
"api_key": "api_key",
"engine": "bing",
"type": "brave_search",
- "input_shields": [
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- ],
- "output_shields": [
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- "description": "description",
- "execution_config": {
- "method": "GET",
- "url": "https://example.com",
- "body": {"foo": True},
- "headers": {"foo": True},
- "params": {"foo": True},
- },
- "parameters": {
- "foo": {
- "param_type": "param_type",
- "description": "description",
- "required": True,
- }
- },
- },
- ],
+ "input_shields": ["string", "string", "string"],
+ "output_shields": ["string", "string", "string"],
"remote_execution": {
"method": "GET",
"url": "https://example.com",
@@ -1159,6 +252,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncLlamaStack
},
],
},
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
assert_matches_type(AgentCreateResponse, agent, path=["response"])
 
@@ -1166,7 +260,9 @@ async def test_method_create_with_all_params(self, async_client: AsyncLlamaStack
async def test_raw_response_create(self, async_client: AsyncLlamaStack) -> None:
response = await async_client.agents.with_raw_response.create(
agent_config={
+ "enable_session_persistence": True,
"instructions": "instructions",
+ "max_infer_iters": 0,
"model": "model",
},
)
@@ -1180,7 +276,9 @@ async def test_raw_response_create(self, async_client: AsyncLlamaStack) -> None:
async def test_streaming_response_create(self, async_client: AsyncLlamaStack) -> None:
async with async_client.agents.with_streaming_response.create(
agent_config={
+ "enable_session_persistence": True,
"instructions": "instructions",
+ "max_infer_iters": 0,
"model": "model",
},
) as response:
@@ -1199,6 +297,14 @@ async def test_method_delete(self, async_client: AsyncLlamaStack) -> None:
)
assert agent is None
 
+ @parametrize
+ async def test_method_delete_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ agent = await async_client.agents.delete(
+ agent_id="agent_id",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert agent is None
+
@parametrize
async def test_raw_response_delete(self, async_client: AsyncLlamaStack) -> None:
response = await async_client.agents.with_raw_response.delete(
diff --git a/tests/api_resources/test_batch_inference.py b/tests/api_resources/test_batch_inference.py
index 13fd6bb..3277598 100644
--- a/tests/api_resources/test_batch_inference.py
+++ b/tests/api_resources/test_batch_inference.py
@@ -174,6 +174,7 @@ def test_method_chat_completion_with_all_params(self, client: LlamaStack) -> Non
},
},
],
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
assert_matches_type(BatchChatCompletion, batch_inference, path=["response"])
 
@@ -311,6 +312,7 @@ def test_method_completion_with_all_params(self, client: LlamaStack) -> None:
"top_k": 0,
"top_p": 0,
},
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
assert_matches_type(BatchCompletion, batch_inference, path=["response"])
 
@@ -498,6 +500,7 @@ async def test_method_chat_completion_with_all_params(self, async_client: AsyncL
},
},
],
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
assert_matches_type(BatchChatCompletion, batch_inference, path=["response"])
 
@@ -635,6 +638,7 @@ async def test_method_completion_with_all_params(self, async_client: AsyncLlamaS
"top_k": 0,
"top_p": 0,
},
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
assert_matches_type(BatchCompletion, batch_inference, path=["response"])
 
diff --git a/tests/api_resources/test_datasets.py b/tests/api_resources/test_datasets.py
index 4aec07d..8a0433b 100644
--- a/tests/api_resources/test_datasets.py
+++ b/tests/api_resources/test_datasets.py
@@ -37,6 +37,7 @@ def test_method_create_with_all_params(self, client: LlamaStack) -> None:
"metadata": {"foo": True},
},
uuid="uuid",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
assert dataset is None
 
@@ -79,6 +80,14 @@ def test_method_delete(self, client: LlamaStack) -> None:
)
assert dataset is None
 
+ @parametrize
+ def test_method_delete_with_all_params(self, client: LlamaStack) -> None:
+ dataset = client.datasets.delete(
+ dataset_uuid="dataset_uuid",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert dataset is None
+
@parametrize
def test_raw_response_delete(self, client: LlamaStack) -> None:
response = client.datasets.with_raw_response.delete(
@@ -110,6 +119,14 @@ def test_method_get(self, client: LlamaStack) -> None:
)
assert_matches_type(TrainEvalDataset, dataset, path=["response"])
 
+ @parametrize
+ def test_method_get_with_all_params(self, client: LlamaStack) -> None:
+ dataset = client.datasets.get(
+ dataset_uuid="dataset_uuid",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(TrainEvalDataset, dataset, path=["response"])
+
@parametrize
def test_raw_response_get(self, client: LlamaStack) -> None:
response = client.datasets.with_raw_response.get(
@@ -158,6 +175,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncLlamaStack
"metadata": {"foo": True},
},
uuid="uuid",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
assert dataset is None
 
@@ -200,6 +218,14 @@ async def test_method_delete(self, async_client: AsyncLlamaStack) -> None:
)
assert dataset is None
 
+ @parametrize
+ async def test_method_delete_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ dataset = await async_client.datasets.delete(
+ dataset_uuid="dataset_uuid",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert dataset is None
+
@parametrize
async def test_raw_response_delete(self, async_client: AsyncLlamaStack) -> None:
response = await async_client.datasets.with_raw_response.delete(
@@ -231,6 +257,14 @@ async def test_method_get(self, async_client: AsyncLlamaStack) -> None:
)
assert_matches_type(TrainEvalDataset, dataset, path=["response"])
 
+ @parametrize
+ async def test_method_get_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ dataset = await async_client.datasets.get(
+ dataset_uuid="dataset_uuid",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(TrainEvalDataset, dataset, path=["response"])
+
@parametrize
async def test_raw_response_get(self, async_client: AsyncLlamaStack) -> None:
response = await async_client.datasets.with_raw_response.get(
diff --git a/tests/api_resources/test_evaluations.py b/tests/api_resources/test_evaluations.py
index dbdf834..7d6a275 100644
--- a/tests/api_resources/test_evaluations.py
+++ b/tests/api_resources/test_evaluations.py
@@ -24,6 +24,14 @@ def test_method_summarization(self, client: LlamaStack) -> None:
)
assert_matches_type(EvaluationJob, evaluation, path=["response"])
 
+ @parametrize
+ def test_method_summarization_with_all_params(self, client: LlamaStack) -> None:
+ evaluation = client.evaluations.summarization(
+ metrics=["rouge", "bleu"],
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(EvaluationJob, evaluation, path=["response"])
+
@parametrize
def test_raw_response_summarization(self, client: LlamaStack) -> None:
response = client.evaluations.with_raw_response.summarization(
@@ -55,6 +63,14 @@ def test_method_text_generation(self, client: LlamaStack) -> None:
)
assert_matches_type(EvaluationJob, evaluation, path=["response"])
 
+ @parametrize
+ def test_method_text_generation_with_all_params(self, client: LlamaStack) -> None:
+ evaluation = client.evaluations.text_generation(
+ metrics=["perplexity", "rouge", "bleu"],
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(EvaluationJob, evaluation, path=["response"])
+
@parametrize
def test_raw_response_text_generation(self, client: LlamaStack) -> None:
response = client.evaluations.with_raw_response.text_generation(
@@ -90,6 +106,14 @@ async def test_method_summarization(self, async_client: AsyncLlamaStack) -> None
)
assert_matches_type(EvaluationJob, evaluation, path=["response"])
 
+ @parametrize
+ async def test_method_summarization_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ evaluation = await async_client.evaluations.summarization(
+ metrics=["rouge", "bleu"],
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(EvaluationJob, evaluation, path=["response"])
+
@parametrize
async def test_raw_response_summarization(self, async_client: AsyncLlamaStack) -> None:
response = await async_client.evaluations.with_raw_response.summarization(
@@ -121,6 +145,14 @@ async def test_method_text_generation(self, async_client: AsyncLlamaStack) -> No
)
assert_matches_type(EvaluationJob, evaluation, path=["response"])
 
+ @parametrize
+ async def test_method_text_generation_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ evaluation = await async_client.evaluations.text_generation(
+ metrics=["perplexity", "rouge", "bleu"],
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(EvaluationJob, evaluation, path=["response"])
+
@parametrize
async def test_raw_response_text_generation(self, async_client: AsyncLlamaStack) -> None:
response = await async_client.evaluations.with_raw_response.text_generation(
diff --git a/tests/api_resources/test_inference.py b/tests/api_resources/test_inference.py
index a7b80b6..7b20e11 100644
--- a/tests/api_resources/test_inference.py
+++ b/tests/api_resources/test_inference.py
@@ -109,6 +109,7 @@ def test_method_chat_completion_with_all_params_overload_1(self, client: LlamaSt
},
},
],
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
assert_matches_type(InferenceChatCompletionResponse, inference, path=["response"])
 
@@ -254,6 +255,7 @@ def test_method_chat_completion_with_all_params_overload_2(self, client: LlamaSt
},
},
],
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
inference_stream.response.close()
 
@@ -333,6 +335,7 @@ def test_method_completion_with_all_params(self, client: LlamaStack) -> None:
"top_p": 0,
},
stream=True,
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
assert_matches_type(InferenceCompletionResponse, inference, path=["response"])
 
@@ -455,6 +458,7 @@ async def test_method_chat_completion_with_all_params_overload_1(self, async_cli
},
},
],
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
assert_matches_type(InferenceChatCompletionResponse, inference, path=["response"])
 
@@ -600,6 +604,7 @@ async def test_method_chat_completion_with_all_params_overload_2(self, async_cli
},
},
],
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
await inference_stream.response.aclose()
 
@@ -679,6 +684,7 @@ async def test_method_completion_with_all_params(self, async_client: AsyncLlamaS
"top_p": 0,
},
stream=True,
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
assert_matches_type(InferenceCompletionResponse, inference, path=["response"])
 
diff --git a/tests/api_resources/test_memory.py b/tests/api_resources/test_memory.py
new file mode 100644
index 0000000..2d1cc05
--- /dev/null
+++ b/tests/api_resources/test_memory.py
@@ -0,0 +1,852 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from __future__ import annotations
+
+import os
+from typing import Any, cast
+
+import pytest
+
+from llama_stack import LlamaStack, AsyncLlamaStack
+from tests.utils import assert_matches_type
+from llama_stack.types import (
+ QueryDocuments,
+)
+
+base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
+
+
+class TestMemory:
+ parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"])
+
+ @parametrize
+ def test_method_create(self, client: LlamaStack) -> None:
+ memory = client.memory.create(
+ body={},
+ )
+ assert_matches_type(object, memory, path=["response"])
+
+ @parametrize
+ def test_method_create_with_all_params(self, client: LlamaStack) -> None:
+ memory = client.memory.create(
+ body={},
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(object, memory, path=["response"])
+
+ @parametrize
+ def test_raw_response_create(self, client: LlamaStack) -> None:
+ response = client.memory.with_raw_response.create(
+ body={},
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ memory = response.parse()
+ assert_matches_type(object, memory, path=["response"])
+
+ @parametrize
+ def test_streaming_response_create(self, client: LlamaStack) -> None:
+ with client.memory.with_streaming_response.create(
+ body={},
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ memory = response.parse()
+ assert_matches_type(object, memory, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+ @parametrize
+ def test_method_retrieve(self, client: LlamaStack) -> None:
+ memory = client.memory.retrieve(
+ bank_id="bank_id",
+ )
+ assert_matches_type(object, memory, path=["response"])
+
+ @parametrize
+ def test_method_retrieve_with_all_params(self, client: LlamaStack) -> None:
+ memory = client.memory.retrieve(
+ bank_id="bank_id",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(object, memory, path=["response"])
+
+ @parametrize
+ def test_raw_response_retrieve(self, client: LlamaStack) -> None:
+ response = client.memory.with_raw_response.retrieve(
+ bank_id="bank_id",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ memory = response.parse()
+ assert_matches_type(object, memory, path=["response"])
+
+ @parametrize
+ def test_streaming_response_retrieve(self, client: LlamaStack) -> None:
+ with client.memory.with_streaming_response.retrieve(
+ bank_id="bank_id",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ memory = response.parse()
+ assert_matches_type(object, memory, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+ @parametrize
+ def test_method_update(self, client: LlamaStack) -> None:
+ memory = client.memory.update(
+ bank_id="bank_id",
+ documents=[
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ },
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ },
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ },
+ ],
+ )
+ assert memory is None
+
+ @parametrize
+ def test_method_update_with_all_params(self, client: LlamaStack) -> None:
+ memory = client.memory.update(
+ bank_id="bank_id",
+ documents=[
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ "mime_type": "mime_type",
+ },
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ "mime_type": "mime_type",
+ },
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ "mime_type": "mime_type",
+ },
+ ],
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert memory is None
+
+ @parametrize
+ def test_raw_response_update(self, client: LlamaStack) -> None:
+ response = client.memory.with_raw_response.update(
+ bank_id="bank_id",
+ documents=[
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ },
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ },
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ },
+ ],
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ memory = response.parse()
+ assert memory is None
+
+ @parametrize
+ def test_streaming_response_update(self, client: LlamaStack) -> None:
+ with client.memory.with_streaming_response.update(
+ bank_id="bank_id",
+ documents=[
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ },
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ },
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ },
+ ],
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ memory = response.parse()
+ assert memory is None
+
+ assert cast(Any, response.is_closed) is True
+
+ @parametrize
+ def test_method_list(self, client: LlamaStack) -> None:
+ memory = client.memory.list()
+ assert_matches_type(object, memory, path=["response"])
+
+ @parametrize
+ def test_method_list_with_all_params(self, client: LlamaStack) -> None:
+ memory = client.memory.list(
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(object, memory, path=["response"])
+
+ @parametrize
+ def test_raw_response_list(self, client: LlamaStack) -> None:
+ response = client.memory.with_raw_response.list()
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ memory = response.parse()
+ assert_matches_type(object, memory, path=["response"])
+
+ @parametrize
+ def test_streaming_response_list(self, client: LlamaStack) -> None:
+ with client.memory.with_streaming_response.list() as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ memory = response.parse()
+ assert_matches_type(object, memory, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+ @parametrize
+ def test_method_drop(self, client: LlamaStack) -> None:
+ memory = client.memory.drop(
+ bank_id="bank_id",
+ )
+ assert_matches_type(str, memory, path=["response"])
+
+ @parametrize
+ def test_method_drop_with_all_params(self, client: LlamaStack) -> None:
+ memory = client.memory.drop(
+ bank_id="bank_id",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(str, memory, path=["response"])
+
+ @parametrize
+ def test_raw_response_drop(self, client: LlamaStack) -> None:
+ response = client.memory.with_raw_response.drop(
+ bank_id="bank_id",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ memory = response.parse()
+ assert_matches_type(str, memory, path=["response"])
+
+ @parametrize
+ def test_streaming_response_drop(self, client: LlamaStack) -> None:
+ with client.memory.with_streaming_response.drop(
+ bank_id="bank_id",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ memory = response.parse()
+ assert_matches_type(str, memory, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+ @parametrize
+ def test_method_insert(self, client: LlamaStack) -> None:
+ memory = client.memory.insert(
+ bank_id="bank_id",
+ documents=[
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ },
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ },
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ },
+ ],
+ )
+ assert memory is None
+
+ @parametrize
+ def test_method_insert_with_all_params(self, client: LlamaStack) -> None:
+ memory = client.memory.insert(
+ bank_id="bank_id",
+ documents=[
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ "mime_type": "mime_type",
+ },
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ "mime_type": "mime_type",
+ },
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ "mime_type": "mime_type",
+ },
+ ],
+ ttl_seconds=0,
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert memory is None
+
+ @parametrize
+ def test_raw_response_insert(self, client: LlamaStack) -> None:
+ response = client.memory.with_raw_response.insert(
+ bank_id="bank_id",
+ documents=[
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ },
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ },
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ },
+ ],
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ memory = response.parse()
+ assert memory is None
+
+ @parametrize
+ def test_streaming_response_insert(self, client: LlamaStack) -> None:
+ with client.memory.with_streaming_response.insert(
+ bank_id="bank_id",
+ documents=[
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ },
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ },
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ },
+ ],
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ memory = response.parse()
+ assert memory is None
+
+ assert cast(Any, response.is_closed) is True
+
+ @parametrize
+ def test_method_query(self, client: LlamaStack) -> None:
+ memory = client.memory.query(
+ bank_id="bank_id",
+ query="string",
+ )
+ assert_matches_type(QueryDocuments, memory, path=["response"])
+
+ @parametrize
+ def test_method_query_with_all_params(self, client: LlamaStack) -> None:
+ memory = client.memory.query(
+ bank_id="bank_id",
+ query="string",
+ params={"foo": True},
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(QueryDocuments, memory, path=["response"])
+
+ @parametrize
+ def test_raw_response_query(self, client: LlamaStack) -> None:
+ response = client.memory.with_raw_response.query(
+ bank_id="bank_id",
+ query="string",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ memory = response.parse()
+ assert_matches_type(QueryDocuments, memory, path=["response"])
+
+ @parametrize
+ def test_streaming_response_query(self, client: LlamaStack) -> None:
+ with client.memory.with_streaming_response.query(
+ bank_id="bank_id",
+ query="string",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ memory = response.parse()
+ assert_matches_type(QueryDocuments, memory, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+
+class TestAsyncMemory:
+ parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
+
+ @parametrize
+ async def test_method_create(self, async_client: AsyncLlamaStack) -> None:
+ memory = await async_client.memory.create(
+ body={},
+ )
+ assert_matches_type(object, memory, path=["response"])
+
+ @parametrize
+ async def test_method_create_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ memory = await async_client.memory.create(
+ body={},
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(object, memory, path=["response"])
+
+ @parametrize
+ async def test_raw_response_create(self, async_client: AsyncLlamaStack) -> None:
+ response = await async_client.memory.with_raw_response.create(
+ body={},
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ memory = await response.parse()
+ assert_matches_type(object, memory, path=["response"])
+
+ @parametrize
+ async def test_streaming_response_create(self, async_client: AsyncLlamaStack) -> None:
+ async with async_client.memory.with_streaming_response.create(
+ body={},
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ memory = await response.parse()
+ assert_matches_type(object, memory, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+ @parametrize
+ async def test_method_retrieve(self, async_client: AsyncLlamaStack) -> None:
+ memory = await async_client.memory.retrieve(
+ bank_id="bank_id",
+ )
+ assert_matches_type(object, memory, path=["response"])
+
+ @parametrize
+ async def test_method_retrieve_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ memory = await async_client.memory.retrieve(
+ bank_id="bank_id",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(object, memory, path=["response"])
+
+ @parametrize
+ async def test_raw_response_retrieve(self, async_client: AsyncLlamaStack) -> None:
+ response = await async_client.memory.with_raw_response.retrieve(
+ bank_id="bank_id",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ memory = await response.parse()
+ assert_matches_type(object, memory, path=["response"])
+
+ @parametrize
+ async def test_streaming_response_retrieve(self, async_client: AsyncLlamaStack) -> None:
+ async with async_client.memory.with_streaming_response.retrieve(
+ bank_id="bank_id",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ memory = await response.parse()
+ assert_matches_type(object, memory, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+ @parametrize
+ async def test_method_update(self, async_client: AsyncLlamaStack) -> None:
+ memory = await async_client.memory.update(
+ bank_id="bank_id",
+ documents=[
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ },
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ },
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ },
+ ],
+ )
+ assert memory is None
+
+ @parametrize
+ async def test_method_update_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ memory = await async_client.memory.update(
+ bank_id="bank_id",
+ documents=[
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ "mime_type": "mime_type",
+ },
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ "mime_type": "mime_type",
+ },
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ "mime_type": "mime_type",
+ },
+ ],
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert memory is None
+
+ @parametrize
+ async def test_raw_response_update(self, async_client: AsyncLlamaStack) -> None:
+ response = await async_client.memory.with_raw_response.update(
+ bank_id="bank_id",
+ documents=[
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ },
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ },
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ },
+ ],
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ memory = await response.parse()
+ assert memory is None
+
+ @parametrize
+ async def test_streaming_response_update(self, async_client: AsyncLlamaStack) -> None:
+ async with async_client.memory.with_streaming_response.update(
+ bank_id="bank_id",
+ documents=[
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ },
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ },
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ },
+ ],
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ memory = await response.parse()
+ assert memory is None
+
+ assert cast(Any, response.is_closed) is True
+
+ @parametrize
+ async def test_method_list(self, async_client: AsyncLlamaStack) -> None:
+ memory = await async_client.memory.list()
+ assert_matches_type(object, memory, path=["response"])
+
+ @parametrize
+ async def test_method_list_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ memory = await async_client.memory.list(
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(object, memory, path=["response"])
+
+ @parametrize
+ async def test_raw_response_list(self, async_client: AsyncLlamaStack) -> None:
+ response = await async_client.memory.with_raw_response.list()
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ memory = await response.parse()
+ assert_matches_type(object, memory, path=["response"])
+
+ @parametrize
+ async def test_streaming_response_list(self, async_client: AsyncLlamaStack) -> None:
+ async with async_client.memory.with_streaming_response.list() as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ memory = await response.parse()
+ assert_matches_type(object, memory, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+ @parametrize
+ async def test_method_drop(self, async_client: AsyncLlamaStack) -> None:
+ memory = await async_client.memory.drop(
+ bank_id="bank_id",
+ )
+ assert_matches_type(str, memory, path=["response"])
+
+ @parametrize
+ async def test_method_drop_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ memory = await async_client.memory.drop(
+ bank_id="bank_id",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(str, memory, path=["response"])
+
+ @parametrize
+ async def test_raw_response_drop(self, async_client: AsyncLlamaStack) -> None:
+ response = await async_client.memory.with_raw_response.drop(
+ bank_id="bank_id",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ memory = await response.parse()
+ assert_matches_type(str, memory, path=["response"])
+
+ @parametrize
+ async def test_streaming_response_drop(self, async_client: AsyncLlamaStack) -> None:
+ async with async_client.memory.with_streaming_response.drop(
+ bank_id="bank_id",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ memory = await response.parse()
+ assert_matches_type(str, memory, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+ @parametrize
+ async def test_method_insert(self, async_client: AsyncLlamaStack) -> None:
+ memory = await async_client.memory.insert(
+ bank_id="bank_id",
+ documents=[
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ },
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ },
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ },
+ ],
+ )
+ assert memory is None
+
+ @parametrize
+ async def test_method_insert_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ memory = await async_client.memory.insert(
+ bank_id="bank_id",
+ documents=[
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ "mime_type": "mime_type",
+ },
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ "mime_type": "mime_type",
+ },
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ "mime_type": "mime_type",
+ },
+ ],
+ ttl_seconds=0,
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert memory is None
+
+ @parametrize
+ async def test_raw_response_insert(self, async_client: AsyncLlamaStack) -> None:
+ response = await async_client.memory.with_raw_response.insert(
+ bank_id="bank_id",
+ documents=[
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ },
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ },
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ },
+ ],
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ memory = await response.parse()
+ assert memory is None
+
+ @parametrize
+ async def test_streaming_response_insert(self, async_client: AsyncLlamaStack) -> None:
+ async with async_client.memory.with_streaming_response.insert(
+ bank_id="bank_id",
+ documents=[
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ },
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ },
+ {
+ "content": "string",
+ "document_id": "document_id",
+ "metadata": {"foo": True},
+ },
+ ],
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ memory = await response.parse()
+ assert memory is None
+
+ assert cast(Any, response.is_closed) is True
+
+ @parametrize
+ async def test_method_query(self, async_client: AsyncLlamaStack) -> None:
+ memory = await async_client.memory.query(
+ bank_id="bank_id",
+ query="string",
+ )
+ assert_matches_type(QueryDocuments, memory, path=["response"])
+
+ @parametrize
+ async def test_method_query_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ memory = await async_client.memory.query(
+ bank_id="bank_id",
+ query="string",
+ params={"foo": True},
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(QueryDocuments, memory, path=["response"])
+
+ @parametrize
+ async def test_raw_response_query(self, async_client: AsyncLlamaStack) -> None:
+ response = await async_client.memory.with_raw_response.query(
+ bank_id="bank_id",
+ query="string",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ memory = await response.parse()
+ assert_matches_type(QueryDocuments, memory, path=["response"])
+
+ @parametrize
+ async def test_streaming_response_query(self, async_client: AsyncLlamaStack) -> None:
+ async with async_client.memory.with_streaming_response.query(
+ bank_id="bank_id",
+ query="string",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ memory = await response.parse()
+ assert_matches_type(QueryDocuments, memory, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
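
tests/api_resources/test_memory.py arrives wholesale with the `memory_banks` → `memory` rename. Collapsing the repeated three-document fixtures, the round trip the suite exercises looks roughly like this sketch (values are the test placeholders, not a real deployment):

```python
from llama_stack import LlamaStack

client = LlamaStack(base_url="http://127.0.0.1:4010")

client.memory.create(body={})  # create returns an untyped `object`
client.memory.insert(
    bank_id="bank_id",
    documents=[
        {
            "content": "string",
            "document_id": "document_id",
            "metadata": {"foo": True},
        }
    ],
    ttl_seconds=0,  # optional, covered by the _with_all_params case
)
docs = client.memory.query(bank_id="bank_id", query="string")  # QueryDocuments
client.memory.drop(bank_id="bank_id")  # returns a plain str
```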
diff --git a/tests/api_resources/test_memory_banks.py b/tests/api_resources/test_memory_banks.py
index 9b3f433..93708b1 100644
--- a/tests/api_resources/test_memory_banks.py
+++ b/tests/api_resources/test_memory_banks.py
@@ -3,15 +3,13 @@
from __future__ import annotations
import os
-from typing import Any, cast
+from typing import Any, Optional, cast
import pytest
from llama_stack import LlamaStack, AsyncLlamaStack
from tests.utils import assert_matches_type
-from llama_stack.types import (
- QueryDocuments,
-)
+from llama_stack.types import MemoryBankSpec
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
@@ -20,153 +18,16 @@ class TestMemoryBanks:
parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"])
@parametrize
- def test_method_create(self, client: LlamaStack) -> None:
- memory_bank = client.memory_banks.create(
- body={},
- )
- assert_matches_type(object, memory_bank, path=["response"])
-
- @parametrize
- def test_raw_response_create(self, client: LlamaStack) -> None:
- response = client.memory_banks.with_raw_response.create(
- body={},
- )
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- memory_bank = response.parse()
- assert_matches_type(object, memory_bank, path=["response"])
-
- @parametrize
- def test_streaming_response_create(self, client: LlamaStack) -> None:
- with client.memory_banks.with_streaming_response.create(
- body={},
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- memory_bank = response.parse()
- assert_matches_type(object, memory_bank, path=["response"])
-
- assert cast(Any, response.is_closed) is True
-
- @parametrize
- def test_method_retrieve(self, client: LlamaStack) -> None:
- memory_bank = client.memory_banks.retrieve(
- bank_id="bank_id",
- )
- assert_matches_type(object, memory_bank, path=["response"])
-
- @parametrize
- def test_raw_response_retrieve(self, client: LlamaStack) -> None:
- response = client.memory_banks.with_raw_response.retrieve(
- bank_id="bank_id",
- )
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- memory_bank = response.parse()
- assert_matches_type(object, memory_bank, path=["response"])
-
- @parametrize
- def test_streaming_response_retrieve(self, client: LlamaStack) -> None:
- with client.memory_banks.with_streaming_response.retrieve(
- bank_id="bank_id",
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- memory_bank = response.parse()
- assert_matches_type(object, memory_bank, path=["response"])
-
- assert cast(Any, response.is_closed) is True
-
- @parametrize
- def test_method_update(self, client: LlamaStack) -> None:
- memory_bank = client.memory_banks.update(
- bank_id="bank_id",
- documents=[
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- },
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- },
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- },
- ],
- )
- assert memory_bank is None
+ def test_method_list(self, client: LlamaStack) -> None:
+ memory_bank = client.memory_banks.list()
+ assert_matches_type(MemoryBankSpec, memory_bank, path=["response"])
@parametrize
- def test_raw_response_update(self, client: LlamaStack) -> None:
- response = client.memory_banks.with_raw_response.update(
- bank_id="bank_id",
- documents=[
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- },
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- },
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- },
- ],
+ def test_method_list_with_all_params(self, client: LlamaStack) -> None:
+ memory_bank = client.memory_banks.list(
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- memory_bank = response.parse()
- assert memory_bank is None
-
- @parametrize
- def test_streaming_response_update(self, client: LlamaStack) -> None:
- with client.memory_banks.with_streaming_response.update(
- bank_id="bank_id",
- documents=[
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- },
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- },
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- },
- ],
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- memory_bank = response.parse()
- assert memory_bank is None
-
- assert cast(Any, response.is_closed) is True
-
- @parametrize
- def test_method_list(self, client: LlamaStack) -> None:
- memory_bank = client.memory_banks.list()
- assert_matches_type(object, memory_bank, path=["response"])
+ assert_matches_type(MemoryBankSpec, memory_bank, path=["response"])
@parametrize
def test_raw_response_list(self, client: LlamaStack) -> None:
@@ -175,7 +36,7 @@ def test_raw_response_list(self, client: LlamaStack) -> None:
assert response.is_closed is True
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
memory_bank = response.parse()
- assert_matches_type(object, memory_bank, path=["response"])
+ assert_matches_type(MemoryBankSpec, memory_bank, path=["response"])
@parametrize
def test_streaming_response_list(self, client: LlamaStack) -> None:
@@ -184,191 +45,46 @@ def test_streaming_response_list(self, client: LlamaStack) -> None:
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
memory_bank = response.parse()
- assert_matches_type(object, memory_bank, path=["response"])
-
- assert cast(Any, response.is_closed) is True
-
- @parametrize
- def test_method_drop(self, client: LlamaStack) -> None:
- memory_bank = client.memory_banks.drop(
- bank_id="bank_id",
- )
- assert_matches_type(str, memory_bank, path=["response"])
-
- @parametrize
- def test_raw_response_drop(self, client: LlamaStack) -> None:
- response = client.memory_banks.with_raw_response.drop(
- bank_id="bank_id",
- )
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- memory_bank = response.parse()
- assert_matches_type(str, memory_bank, path=["response"])
-
- @parametrize
- def test_streaming_response_drop(self, client: LlamaStack) -> None:
- with client.memory_banks.with_streaming_response.drop(
- bank_id="bank_id",
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- memory_bank = response.parse()
- assert_matches_type(str, memory_bank, path=["response"])
-
- assert cast(Any, response.is_closed) is True
-
- @parametrize
- def test_method_insert(self, client: LlamaStack) -> None:
- memory_bank = client.memory_banks.insert(
- bank_id="bank_id",
- documents=[
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- },
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- },
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- },
- ],
- )
- assert memory_bank is None
-
- @parametrize
- def test_method_insert_with_all_params(self, client: LlamaStack) -> None:
- memory_bank = client.memory_banks.insert(
- bank_id="bank_id",
- documents=[
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- "mime_type": "mime_type",
- },
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- "mime_type": "mime_type",
- },
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- "mime_type": "mime_type",
- },
- ],
- ttl_seconds=0,
- )
- assert memory_bank is None
-
- @parametrize
- def test_raw_response_insert(self, client: LlamaStack) -> None:
- response = client.memory_banks.with_raw_response.insert(
- bank_id="bank_id",
- documents=[
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- },
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- },
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- },
- ],
- )
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- memory_bank = response.parse()
- assert memory_bank is None
-
- @parametrize
- def test_streaming_response_insert(self, client: LlamaStack) -> None:
- with client.memory_banks.with_streaming_response.insert(
- bank_id="bank_id",
- documents=[
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- },
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- },
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- },
- ],
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- memory_bank = response.parse()
- assert memory_bank is None
+ assert_matches_type(MemoryBankSpec, memory_bank, path=["response"])
assert cast(Any, response.is_closed) is True
@parametrize
- def test_method_query(self, client: LlamaStack) -> None:
- memory_bank = client.memory_banks.query(
- bank_id="bank_id",
- query="string",
+ def test_method_get(self, client: LlamaStack) -> None:
+ memory_bank = client.memory_banks.get(
+ bank_type="vector",
)
- assert_matches_type(QueryDocuments, memory_bank, path=["response"])
+ assert_matches_type(Optional[MemoryBankSpec], memory_bank, path=["response"])
@parametrize
- def test_method_query_with_all_params(self, client: LlamaStack) -> None:
- memory_bank = client.memory_banks.query(
- bank_id="bank_id",
- query="string",
- params={"foo": True},
+ def test_method_get_with_all_params(self, client: LlamaStack) -> None:
+ memory_bank = client.memory_banks.get(
+ bank_type="vector",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
- assert_matches_type(QueryDocuments, memory_bank, path=["response"])
+ assert_matches_type(Optional[MemoryBankSpec], memory_bank, path=["response"])
@parametrize
- def test_raw_response_query(self, client: LlamaStack) -> None:
- response = client.memory_banks.with_raw_response.query(
- bank_id="bank_id",
- query="string",
+ def test_raw_response_get(self, client: LlamaStack) -> None:
+ response = client.memory_banks.with_raw_response.get(
+ bank_type="vector",
)
assert response.is_closed is True
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
memory_bank = response.parse()
- assert_matches_type(QueryDocuments, memory_bank, path=["response"])
+ assert_matches_type(Optional[MemoryBankSpec], memory_bank, path=["response"])
@parametrize
- def test_streaming_response_query(self, client: LlamaStack) -> None:
- with client.memory_banks.with_streaming_response.query(
- bank_id="bank_id",
- query="string",
+ def test_streaming_response_get(self, client: LlamaStack) -> None:
+ with client.memory_banks.with_streaming_response.get(
+ bank_type="vector",
) as response:
assert not response.is_closed
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
memory_bank = response.parse()
- assert_matches_type(QueryDocuments, memory_bank, path=["response"])
+ assert_matches_type(Optional[MemoryBankSpec], memory_bank, path=["response"])
assert cast(Any, response.is_closed) is True
@@ -377,153 +93,16 @@ class TestAsyncMemoryBanks:
parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
@parametrize
- async def test_method_create(self, async_client: AsyncLlamaStack) -> None:
- memory_bank = await async_client.memory_banks.create(
- body={},
- )
- assert_matches_type(object, memory_bank, path=["response"])
-
- @parametrize
- async def test_raw_response_create(self, async_client: AsyncLlamaStack) -> None:
- response = await async_client.memory_banks.with_raw_response.create(
- body={},
- )
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- memory_bank = await response.parse()
- assert_matches_type(object, memory_bank, path=["response"])
-
- @parametrize
- async def test_streaming_response_create(self, async_client: AsyncLlamaStack) -> None:
- async with async_client.memory_banks.with_streaming_response.create(
- body={},
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- memory_bank = await response.parse()
- assert_matches_type(object, memory_bank, path=["response"])
-
- assert cast(Any, response.is_closed) is True
-
- @parametrize
- async def test_method_retrieve(self, async_client: AsyncLlamaStack) -> None:
- memory_bank = await async_client.memory_banks.retrieve(
- bank_id="bank_id",
- )
- assert_matches_type(object, memory_bank, path=["response"])
-
- @parametrize
- async def test_raw_response_retrieve(self, async_client: AsyncLlamaStack) -> None:
- response = await async_client.memory_banks.with_raw_response.retrieve(
- bank_id="bank_id",
- )
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- memory_bank = await response.parse()
- assert_matches_type(object, memory_bank, path=["response"])
-
- @parametrize
- async def test_streaming_response_retrieve(self, async_client: AsyncLlamaStack) -> None:
- async with async_client.memory_banks.with_streaming_response.retrieve(
- bank_id="bank_id",
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- memory_bank = await response.parse()
- assert_matches_type(object, memory_bank, path=["response"])
-
- assert cast(Any, response.is_closed) is True
-
- @parametrize
- async def test_method_update(self, async_client: AsyncLlamaStack) -> None:
- memory_bank = await async_client.memory_banks.update(
- bank_id="bank_id",
- documents=[
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- },
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- },
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- },
- ],
- )
- assert memory_bank is None
+ async def test_method_list(self, async_client: AsyncLlamaStack) -> None:
+ memory_bank = await async_client.memory_banks.list()
+ assert_matches_type(MemoryBankSpec, memory_bank, path=["response"])
@parametrize
- async def test_raw_response_update(self, async_client: AsyncLlamaStack) -> None:
- response = await async_client.memory_banks.with_raw_response.update(
- bank_id="bank_id",
- documents=[
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- },
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- },
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- },
- ],
+ async def test_method_list_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ memory_bank = await async_client.memory_banks.list(
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- memory_bank = await response.parse()
- assert memory_bank is None
-
- @parametrize
- async def test_streaming_response_update(self, async_client: AsyncLlamaStack) -> None:
- async with async_client.memory_banks.with_streaming_response.update(
- bank_id="bank_id",
- documents=[
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- },
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- },
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- },
- ],
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- memory_bank = await response.parse()
- assert memory_bank is None
-
- assert cast(Any, response.is_closed) is True
-
- @parametrize
- async def test_method_list(self, async_client: AsyncLlamaStack) -> None:
- memory_bank = await async_client.memory_banks.list()
- assert_matches_type(object, memory_bank, path=["response"])
+ assert_matches_type(MemoryBankSpec, memory_bank, path=["response"])
@parametrize
async def test_raw_response_list(self, async_client: AsyncLlamaStack) -> None:
@@ -532,7 +111,7 @@ async def test_raw_response_list(self, async_client: AsyncLlamaStack) -> None:
assert response.is_closed is True
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
memory_bank = await response.parse()
- assert_matches_type(object, memory_bank, path=["response"])
+ assert_matches_type(MemoryBankSpec, memory_bank, path=["response"])
@parametrize
async def test_streaming_response_list(self, async_client: AsyncLlamaStack) -> None:
@@ -541,190 +120,45 @@ async def test_streaming_response_list(self, async_client: AsyncLlamaStack) -> N
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
memory_bank = await response.parse()
- assert_matches_type(object, memory_bank, path=["response"])
-
- assert cast(Any, response.is_closed) is True
-
- @parametrize
- async def test_method_drop(self, async_client: AsyncLlamaStack) -> None:
- memory_bank = await async_client.memory_banks.drop(
- bank_id="bank_id",
- )
- assert_matches_type(str, memory_bank, path=["response"])
-
- @parametrize
- async def test_raw_response_drop(self, async_client: AsyncLlamaStack) -> None:
- response = await async_client.memory_banks.with_raw_response.drop(
- bank_id="bank_id",
- )
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- memory_bank = await response.parse()
- assert_matches_type(str, memory_bank, path=["response"])
-
- @parametrize
- async def test_streaming_response_drop(self, async_client: AsyncLlamaStack) -> None:
- async with async_client.memory_banks.with_streaming_response.drop(
- bank_id="bank_id",
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- memory_bank = await response.parse()
- assert_matches_type(str, memory_bank, path=["response"])
-
- assert cast(Any, response.is_closed) is True
-
- @parametrize
- async def test_method_insert(self, async_client: AsyncLlamaStack) -> None:
- memory_bank = await async_client.memory_banks.insert(
- bank_id="bank_id",
- documents=[
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- },
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- },
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- },
- ],
- )
- assert memory_bank is None
-
- @parametrize
- async def test_method_insert_with_all_params(self, async_client: AsyncLlamaStack) -> None:
- memory_bank = await async_client.memory_banks.insert(
- bank_id="bank_id",
- documents=[
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- "mime_type": "mime_type",
- },
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- "mime_type": "mime_type",
- },
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- "mime_type": "mime_type",
- },
- ],
- ttl_seconds=0,
- )
- assert memory_bank is None
-
- @parametrize
- async def test_raw_response_insert(self, async_client: AsyncLlamaStack) -> None:
- response = await async_client.memory_banks.with_raw_response.insert(
- bank_id="bank_id",
- documents=[
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- },
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- },
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- },
- ],
- )
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- memory_bank = await response.parse()
- assert memory_bank is None
-
- @parametrize
- async def test_streaming_response_insert(self, async_client: AsyncLlamaStack) -> None:
- async with async_client.memory_banks.with_streaming_response.insert(
- bank_id="bank_id",
- documents=[
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- },
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- },
- {
- "content": "string",
- "document_id": "document_id",
- "metadata": {"foo": True},
- },
- ],
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- memory_bank = await response.parse()
- assert memory_bank is None
+ assert_matches_type(MemoryBankSpec, memory_bank, path=["response"])
assert cast(Any, response.is_closed) is True
@parametrize
- async def test_method_query(self, async_client: AsyncLlamaStack) -> None:
- memory_bank = await async_client.memory_banks.query(
- bank_id="bank_id",
- query="string",
+ async def test_method_get(self, async_client: AsyncLlamaStack) -> None:
+ memory_bank = await async_client.memory_banks.get(
+ bank_type="vector",
)
- assert_matches_type(QueryDocuments, memory_bank, path=["response"])
+ assert_matches_type(Optional[MemoryBankSpec], memory_bank, path=["response"])
@parametrize
- async def test_method_query_with_all_params(self, async_client: AsyncLlamaStack) -> None:
- memory_bank = await async_client.memory_banks.query(
- bank_id="bank_id",
- query="string",
- params={"foo": True},
+ async def test_method_get_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ memory_bank = await async_client.memory_banks.get(
+ bank_type="vector",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
- assert_matches_type(QueryDocuments, memory_bank, path=["response"])
+ assert_matches_type(Optional[MemoryBankSpec], memory_bank, path=["response"])
@parametrize
- async def test_raw_response_query(self, async_client: AsyncLlamaStack) -> None:
- response = await async_client.memory_banks.with_raw_response.query(
- bank_id="bank_id",
- query="string",
+ async def test_raw_response_get(self, async_client: AsyncLlamaStack) -> None:
+ response = await async_client.memory_banks.with_raw_response.get(
+ bank_type="vector",
)
assert response.is_closed is True
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
memory_bank = await response.parse()
- assert_matches_type(QueryDocuments, memory_bank, path=["response"])
+ assert_matches_type(Optional[MemoryBankSpec], memory_bank, path=["response"])
@parametrize
- async def test_streaming_response_query(self, async_client: AsyncLlamaStack) -> None:
- async with async_client.memory_banks.with_streaming_response.query(
- bank_id="bank_id",
- query="string",
+ async def test_streaming_response_get(self, async_client: AsyncLlamaStack) -> None:
+ async with async_client.memory_banks.with_streaming_response.get(
+ bank_type="vector",
) as response:
assert not response.is_closed
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
memory_bank = await response.parse()
- assert_matches_type(QueryDocuments, memory_bank, path=["response"])
+ assert_matches_type(Optional[MemoryBankSpec], memory_bank, path=["response"])
assert cast(Any, response.is_closed) is True
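
Net effect of this rewrite: `memory_banks` loses its CRUD and query surface (now on `memory`) and becomes a read-only registry keyed by `bank_type` instead of `bank_id`. A sketch of the two remaining calls, mirroring the tests above:

```python
from llama_stack import LlamaStack

client = LlamaStack(base_url="http://127.0.0.1:4010")

specs = client.memory_banks.list()                    # MemoryBankSpec
vector = client.memory_banks.get(bank_type="vector")  # Optional[MemoryBankSpec]
```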
diff --git a/tests/api_resources/test_models.py b/tests/api_resources/test_models.py
new file mode 100644
index 0000000..47be556
--- /dev/null
+++ b/tests/api_resources/test_models.py
@@ -0,0 +1,164 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from __future__ import annotations
+
+import os
+from typing import Any, Optional, cast
+
+import pytest
+
+from llama_stack import LlamaStack, AsyncLlamaStack
+from tests.utils import assert_matches_type
+from llama_stack.types import ModelServingSpec
+
+base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
+
+
+class TestModels:
+ parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"])
+
+ @parametrize
+ def test_method_list(self, client: LlamaStack) -> None:
+ model = client.models.list()
+ assert_matches_type(ModelServingSpec, model, path=["response"])
+
+ @parametrize
+ def test_method_list_with_all_params(self, client: LlamaStack) -> None:
+ model = client.models.list(
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(ModelServingSpec, model, path=["response"])
+
+ @parametrize
+ def test_raw_response_list(self, client: LlamaStack) -> None:
+ response = client.models.with_raw_response.list()
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ model = response.parse()
+ assert_matches_type(ModelServingSpec, model, path=["response"])
+
+ @parametrize
+ def test_streaming_response_list(self, client: LlamaStack) -> None:
+ with client.models.with_streaming_response.list() as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ model = response.parse()
+ assert_matches_type(ModelServingSpec, model, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+ @parametrize
+ def test_method_get(self, client: LlamaStack) -> None:
+ model = client.models.get(
+ core_model_id="core_model_id",
+ )
+ assert_matches_type(Optional[ModelServingSpec], model, path=["response"])
+
+ @parametrize
+ def test_method_get_with_all_params(self, client: LlamaStack) -> None:
+ model = client.models.get(
+ core_model_id="core_model_id",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(Optional[ModelServingSpec], model, path=["response"])
+
+ @parametrize
+ def test_raw_response_get(self, client: LlamaStack) -> None:
+ response = client.models.with_raw_response.get(
+ core_model_id="core_model_id",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ model = response.parse()
+ assert_matches_type(Optional[ModelServingSpec], model, path=["response"])
+
+ @parametrize
+ def test_streaming_response_get(self, client: LlamaStack) -> None:
+ with client.models.with_streaming_response.get(
+ core_model_id="core_model_id",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ model = response.parse()
+ assert_matches_type(Optional[ModelServingSpec], model, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+
+class TestAsyncModels:
+ parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
+
+ @parametrize
+ async def test_method_list(self, async_client: AsyncLlamaStack) -> None:
+ model = await async_client.models.list()
+ assert_matches_type(ModelServingSpec, model, path=["response"])
+
+ @parametrize
+ async def test_method_list_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ model = await async_client.models.list(
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(ModelServingSpec, model, path=["response"])
+
+ @parametrize
+ async def test_raw_response_list(self, async_client: AsyncLlamaStack) -> None:
+ response = await async_client.models.with_raw_response.list()
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ model = await response.parse()
+ assert_matches_type(ModelServingSpec, model, path=["response"])
+
+ @parametrize
+ async def test_streaming_response_list(self, async_client: AsyncLlamaStack) -> None:
+ async with async_client.models.with_streaming_response.list() as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ model = await response.parse()
+ assert_matches_type(ModelServingSpec, model, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+ @parametrize
+ async def test_method_get(self, async_client: AsyncLlamaStack) -> None:
+ model = await async_client.models.get(
+ core_model_id="core_model_id",
+ )
+ assert_matches_type(Optional[ModelServingSpec], model, path=["response"])
+
+ @parametrize
+ async def test_method_get_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ model = await async_client.models.get(
+ core_model_id="core_model_id",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(Optional[ModelServingSpec], model, path=["response"])
+
+ @parametrize
+ async def test_raw_response_get(self, async_client: AsyncLlamaStack) -> None:
+ response = await async_client.models.with_raw_response.get(
+ core_model_id="core_model_id",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ model = await response.parse()
+ assert_matches_type(Optional[ModelServingSpec], model, path=["response"])
+
+ @parametrize
+ async def test_streaming_response_get(self, async_client: AsyncLlamaStack) -> None:
+ async with async_client.models.with_streaming_response.get(
+ core_model_id="core_model_id",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ model = await response.parse()
+ assert_matches_type(Optional[ModelServingSpec], model, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
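
tests/api_resources/test_models.py covers the new top-level `models` registry with the same list/get pattern. A sketch of the two calls it exercises:

```python
from llama_stack import LlamaStack

client = LlamaStack(base_url="http://127.0.0.1:4010")

serving = client.models.list()                            # ModelServingSpec
model = client.models.get(core_model_id="core_model_id")  # Optional[ModelServingSpec]
```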
diff --git a/tests/api_resources/test_post_training.py b/tests/api_resources/test_post_training.py
index 588f159..9cee8d0 100644
--- a/tests/api_resources/test_post_training.py
+++ b/tests/api_resources/test_post_training.py
@@ -98,6 +98,7 @@ def test_method_preference_optimize_with_all_params(self, client: LlamaStack) ->
"content_url": "https://example.com",
"metadata": {"foo": True},
},
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
assert_matches_type(PostTrainingJob, post_training, path=["response"])
@@ -272,6 +273,7 @@ def test_method_supervised_fine_tune_with_all_params(self, client: LlamaStack) -
"content_url": "https://example.com",
"metadata": {"foo": True},
},
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
assert_matches_type(PostTrainingJob, post_training, path=["response"])
@@ -450,6 +452,7 @@ async def test_method_preference_optimize_with_all_params(self, async_client: As
"content_url": "https://example.com",
"metadata": {"foo": True},
},
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
assert_matches_type(PostTrainingJob, post_training, path=["response"])
@@ -624,6 +627,7 @@ async def test_method_supervised_fine_tune_with_all_params(self, async_client: A
"content_url": "https://example.com",
"metadata": {"foo": True},
},
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
assert_matches_type(PostTrainingJob, post_training, path=["response"])
diff --git a/tests/api_resources/test_reward_scoring.py b/tests/api_resources/test_reward_scoring.py
index 83f6983..eaa9e43 100644
--- a/tests/api_resources/test_reward_scoring.py
+++ b/tests/api_resources/test_reward_scoring.py
@@ -116,6 +116,124 @@ def test_method_score(self, client: LlamaStack) -> None:
)
assert_matches_type(RewardScoring, reward_scoring, path=["response"])
+ @parametrize
+ def test_method_score_with_all_params(self, client: LlamaStack) -> None:
+ reward_scoring = client.reward_scoring.score(
+ dialog_generations=[
+ {
+ "dialog": [
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ },
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ },
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ },
+ ],
+ "sampled_generations": [
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ },
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ },
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ },
+ ],
+ },
+ {
+ "dialog": [
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ },
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ },
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ },
+ ],
+ "sampled_generations": [
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ },
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ },
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ },
+ ],
+ },
+ {
+ "dialog": [
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ },
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ },
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ },
+ ],
+ "sampled_generations": [
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ },
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ },
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ },
+ ],
+ },
+ ],
+ model="model",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(RewardScoring, reward_scoring, path=["response"])
+
@parametrize
def test_raw_response_score(self, client: LlamaStack) -> None:
response = client.reward_scoring.with_raw_response.score(
@@ -427,6 +545,124 @@ async def test_method_score(self, async_client: AsyncLlamaStack) -> None:
)
assert_matches_type(RewardScoring, reward_scoring, path=["response"])
+ @parametrize
+ async def test_method_score_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ reward_scoring = await async_client.reward_scoring.score(
+ dialog_generations=[
+ {
+ "dialog": [
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ },
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ },
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ },
+ ],
+ "sampled_generations": [
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ },
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ },
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ },
+ ],
+ },
+ {
+ "dialog": [
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ },
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ },
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ },
+ ],
+ "sampled_generations": [
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ },
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ },
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ },
+ ],
+ },
+ {
+ "dialog": [
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ },
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ },
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ },
+ ],
+ "sampled_generations": [
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ },
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ },
+ {
+ "content": "string",
+ "role": "user",
+ "context": "string",
+ },
+ ],
+ },
+ ],
+ model="model",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(RewardScoring, reward_scoring, path=["response"])
+
@parametrize
async def test_raw_response_score(self, async_client: AsyncLlamaStack) -> None:
response = await async_client.reward_scoring.with_raw_response.score(
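
Note: the new `test_method_score_with_all_params` cases above exercise the optional `context` field on each message plus the provider-data header. A trimmed sketch of the same call (one dialog/generation pair instead of the three the tests send; local base URL assumed):

```python
from llama_stack import LlamaStack

client = LlamaStack(base_url="http://127.0.0.1:4010")

# Same shape as the tests above, reduced to a single dialog.
reward_scoring = client.reward_scoring.score(
    dialog_generations=[
        {
            "dialog": [{"content": "string", "role": "user", "context": "string"}],
            "sampled_generations": [{"content": "string", "role": "user", "context": "string"}],
        }
    ],
    model="model",
    x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
```
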
diff --git a/tests/api_resources/test_safety.py b/tests/api_resources/test_safety.py
index f4b44dc..eccaf8a 100644
--- a/tests/api_resources/test_safety.py
+++ b/tests/api_resources/test_safety.py
@@ -9,7 +9,7 @@
from llama_stack import LlamaStack, AsyncLlamaStack
from tests.utils import assert_matches_type
-from llama_stack.types import SafetyRunShieldsResponse
+from llama_stack.types import RunSheidResponse

base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
@@ -18,8 +18,8 @@ class TestSafety:
    parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"])

    @parametrize
- def test_method_run_shields(self, client: LlamaStack) -> None:
- safety = client.safety.run_shields(
+ def test_method_run_shield(self, client: LlamaStack) -> None:
+ safety = client.safety.run_shield(
messages=[
{
"content": "string",
@@ -34,64 +34,66 @@ def test_method_run_shields(self, client: LlamaStack) -> None:
"role": "user",
},
],
- shields=[
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- },
- ],
+ params={"foo": True},
+ shield_type="shield_type",
)
- assert_matches_type(SafetyRunShieldsResponse, safety, path=["response"])
+        assert_matches_type(RunSheidResponse, safety, path=["response"])

    @parametrize
- def test_raw_response_run_shields(self, client: LlamaStack) -> None:
- response = client.safety.with_raw_response.run_shields(
+ def test_method_run_shield_with_all_params(self, client: LlamaStack) -> None:
+ safety = client.safety.run_shield(
messages=[
{
"content": "string",
"role": "user",
+ "context": "string",
},
{
"content": "string",
"role": "user",
+ "context": "string",
},
{
"content": "string",
"role": "user",
+ "context": "string",
},
],
- shields=[
+ params={"foo": True},
+ shield_type="shield_type",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(RunSheidResponse, safety, path=["response"])
+
+ @parametrize
+ def test_raw_response_run_shield(self, client: LlamaStack) -> None:
+ response = client.safety.with_raw_response.run_shield(
+ messages=[
{
- "on_violation_action": 0,
- "shield_type": "llama_guard",
+ "content": "string",
+ "role": "user",
},
{
- "on_violation_action": 0,
- "shield_type": "llama_guard",
+ "content": "string",
+ "role": "user",
},
{
- "on_violation_action": 0,
- "shield_type": "llama_guard",
+ "content": "string",
+ "role": "user",
},
],
+ params={"foo": True},
+ shield_type="shield_type",
        )

        assert response.is_closed is True
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
safety = response.parse()
-        assert_matches_type(SafetyRunShieldsResponse, safety, path=["response"])
+        assert_matches_type(RunSheidResponse, safety, path=["response"])

    @parametrize
- def test_streaming_response_run_shields(self, client: LlamaStack) -> None:
- with client.safety.with_streaming_response.run_shields(
+ def test_streaming_response_run_shield(self, client: LlamaStack) -> None:
+ with client.safety.with_streaming_response.run_shield(
messages=[
{
"content": "string",
@@ -106,26 +108,14 @@ def test_streaming_response_run_shields(self, client: LlamaStack) -> None:
"role": "user",
},
],
- shields=[
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- },
- ],
+ params={"foo": True},
+ shield_type="shield_type",
) as response:
assert not response.is_closed
            assert response.http_request.headers.get("X-Stainless-Lang") == "python"

            safety = response.parse()
- assert_matches_type(SafetyRunShieldsResponse, safety, path=["response"])
+            assert_matches_type(RunSheidResponse, safety, path=["response"])

        assert cast(Any, response.is_closed) is True
@@ -134,8 +124,8 @@ class TestAsyncSafety:
    parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])

    @parametrize
- async def test_method_run_shields(self, async_client: AsyncLlamaStack) -> None:
- safety = await async_client.safety.run_shields(
+ async def test_method_run_shield(self, async_client: AsyncLlamaStack) -> None:
+ safety = await async_client.safety.run_shield(
messages=[
{
"content": "string",
@@ -150,64 +140,66 @@ async def test_method_run_shields(self, async_client: AsyncLlamaStack) -> None:
"role": "user",
},
],
- shields=[
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- },
- ],
+ params={"foo": True},
+ shield_type="shield_type",
)
- assert_matches_type(SafetyRunShieldsResponse, safety, path=["response"])
+        assert_matches_type(RunSheidResponse, safety, path=["response"])

    @parametrize
- async def test_raw_response_run_shields(self, async_client: AsyncLlamaStack) -> None:
- response = await async_client.safety.with_raw_response.run_shields(
+ async def test_method_run_shield_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ safety = await async_client.safety.run_shield(
messages=[
{
"content": "string",
"role": "user",
+ "context": "string",
},
{
"content": "string",
"role": "user",
+ "context": "string",
},
{
"content": "string",
"role": "user",
+ "context": "string",
},
],
- shields=[
+ params={"foo": True},
+ shield_type="shield_type",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(RunSheidResponse, safety, path=["response"])
+
+ @parametrize
+ async def test_raw_response_run_shield(self, async_client: AsyncLlamaStack) -> None:
+ response = await async_client.safety.with_raw_response.run_shield(
+ messages=[
{
- "on_violation_action": 0,
- "shield_type": "llama_guard",
+ "content": "string",
+ "role": "user",
},
{
- "on_violation_action": 0,
- "shield_type": "llama_guard",
+ "content": "string",
+ "role": "user",
},
{
- "on_violation_action": 0,
- "shield_type": "llama_guard",
+ "content": "string",
+ "role": "user",
},
],
+ params={"foo": True},
+ shield_type="shield_type",
        )

        assert response.is_closed is True
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
safety = await response.parse()
- assert_matches_type(SafetyRunShieldsResponse, safety, path=["response"])
+        assert_matches_type(RunSheidResponse, safety, path=["response"])

    @parametrize
- async def test_streaming_response_run_shields(self, async_client: AsyncLlamaStack) -> None:
- async with async_client.safety.with_streaming_response.run_shields(
+ async def test_streaming_response_run_shield(self, async_client: AsyncLlamaStack) -> None:
+ async with async_client.safety.with_streaming_response.run_shield(
messages=[
{
"content": "string",
@@ -222,25 +214,13 @@ async def test_streaming_response_run_shields(self, async_client: AsyncLlamaStack) -> None:
"role": "user",
},
],
- shields=[
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- },
- {
- "on_violation_action": 0,
- "shield_type": "llama_guard",
- },
- ],
+ params={"foo": True},
+ shield_type="shield_type",
) as response:
assert not response.is_closed
            assert response.http_request.headers.get("X-Stainless-Lang") == "python"

            safety = await response.parse()
- assert_matches_type(SafetyRunShieldsResponse, safety, path=["response"])
+            assert_matches_type(RunSheidResponse, safety, path=["response"])

        assert cast(Any, response.is_closed) is True
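
Note: the plural `run_shields` call, which took a list of inline shield definitions (`on_violation_action`/`shield_type` dicts), is replaced by `run_shield`, which names a single registered shield via `shield_type` and forwards provider-specific options through `params`. A minimal sketch of the new call shape, using the same arguments the tests above send:

```python
from llama_stack import LlamaStack

client = LlamaStack(base_url="http://127.0.0.1:4010")

# One shield per call now; shield_type selects a registered shield.
safety_response = client.safety.run_shield(
    messages=[{"content": "string", "role": "user"}],
    params={"foo": True},
    shield_type="shield_type",
)
```
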
diff --git a/tests/api_resources/test_shields.py b/tests/api_resources/test_shields.py
new file mode 100644
index 0000000..fdc6728
--- /dev/null
+++ b/tests/api_resources/test_shields.py
@@ -0,0 +1,164 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from __future__ import annotations
+
+import os
+from typing import Any, Optional, cast
+
+import pytest
+
+from llama_stack import LlamaStack, AsyncLlamaStack
+from tests.utils import assert_matches_type
+from llama_stack.types import ShieldSpec
+
+base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
+
+
+class TestShields:
+ parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"])
+
+ @parametrize
+ def test_method_list(self, client: LlamaStack) -> None:
+ shield = client.shields.list()
+ assert_matches_type(ShieldSpec, shield, path=["response"])
+
+ @parametrize
+ def test_method_list_with_all_params(self, client: LlamaStack) -> None:
+ shield = client.shields.list(
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(ShieldSpec, shield, path=["response"])
+
+ @parametrize
+ def test_raw_response_list(self, client: LlamaStack) -> None:
+ response = client.shields.with_raw_response.list()
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ shield = response.parse()
+ assert_matches_type(ShieldSpec, shield, path=["response"])
+
+ @parametrize
+ def test_streaming_response_list(self, client: LlamaStack) -> None:
+ with client.shields.with_streaming_response.list() as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ shield = response.parse()
+ assert_matches_type(ShieldSpec, shield, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+ @parametrize
+ def test_method_get(self, client: LlamaStack) -> None:
+ shield = client.shields.get(
+ shield_type="shield_type",
+ )
+ assert_matches_type(Optional[ShieldSpec], shield, path=["response"])
+
+ @parametrize
+ def test_method_get_with_all_params(self, client: LlamaStack) -> None:
+ shield = client.shields.get(
+ shield_type="shield_type",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(Optional[ShieldSpec], shield, path=["response"])
+
+ @parametrize
+ def test_raw_response_get(self, client: LlamaStack) -> None:
+ response = client.shields.with_raw_response.get(
+ shield_type="shield_type",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ shield = response.parse()
+ assert_matches_type(Optional[ShieldSpec], shield, path=["response"])
+
+ @parametrize
+ def test_streaming_response_get(self, client: LlamaStack) -> None:
+ with client.shields.with_streaming_response.get(
+ shield_type="shield_type",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ shield = response.parse()
+ assert_matches_type(Optional[ShieldSpec], shield, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+
+class TestAsyncShields:
+ parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
+
+ @parametrize
+ async def test_method_list(self, async_client: AsyncLlamaStack) -> None:
+ shield = await async_client.shields.list()
+ assert_matches_type(ShieldSpec, shield, path=["response"])
+
+ @parametrize
+ async def test_method_list_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ shield = await async_client.shields.list(
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(ShieldSpec, shield, path=["response"])
+
+ @parametrize
+ async def test_raw_response_list(self, async_client: AsyncLlamaStack) -> None:
+ response = await async_client.shields.with_raw_response.list()
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ shield = await response.parse()
+ assert_matches_type(ShieldSpec, shield, path=["response"])
+
+ @parametrize
+ async def test_streaming_response_list(self, async_client: AsyncLlamaStack) -> None:
+ async with async_client.shields.with_streaming_response.list() as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ shield = await response.parse()
+ assert_matches_type(ShieldSpec, shield, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+ @parametrize
+ async def test_method_get(self, async_client: AsyncLlamaStack) -> None:
+ shield = await async_client.shields.get(
+ shield_type="shield_type",
+ )
+ assert_matches_type(Optional[ShieldSpec], shield, path=["response"])
+
+ @parametrize
+ async def test_method_get_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ shield = await async_client.shields.get(
+ shield_type="shield_type",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(Optional[ShieldSpec], shield, path=["response"])
+
+ @parametrize
+ async def test_raw_response_get(self, async_client: AsyncLlamaStack) -> None:
+ response = await async_client.shields.with_raw_response.get(
+ shield_type="shield_type",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ shield = await response.parse()
+ assert_matches_type(Optional[ShieldSpec], shield, path=["response"])
+
+ @parametrize
+ async def test_streaming_response_get(self, async_client: AsyncLlamaStack) -> None:
+ async with async_client.shields.with_streaming_response.get(
+ shield_type="shield_type",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ shield = await response.parse()
+ assert_matches_type(Optional[ShieldSpec], shield, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
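
Note: the new top-level `shields` resource pairs `list()`, returning `ShieldSpec`, with `get()`, which looks a shield up by `shield_type` and returns `Optional[ShieldSpec]` (None when no such shield is registered). A sketch of both calls, mirroring the generated tests above:

```python
from llama_stack import LlamaStack

client = LlamaStack(base_url="http://127.0.0.1:4010")

specs = client.shields.list()  # ShieldSpec
maybe_spec = client.shields.get(shield_type="shield_type")  # Optional[ShieldSpec]
```
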
diff --git a/tests/api_resources/test_synthetic_data_generation.py b/tests/api_resources/test_synthetic_data_generation.py
index 7f5d73b..5c4bd93 100644
--- a/tests/api_resources/test_synthetic_data_generation.py
+++ b/tests/api_resources/test_synthetic_data_generation.py
@@ -60,6 +60,7 @@ def test_method_generate_with_all_params(self, client: LlamaStack) -> None:
],
filtering_function="none",
model="model",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
assert_matches_type(SyntheticDataGeneration, synthetic_data_generation, path=["response"])
@@ -162,6 +163,7 @@ async def test_method_generate_with_all_params(self, async_client: AsyncLlamaStack) -> None:
],
filtering_function="none",
model="model",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
assert_matches_type(SyntheticDataGeneration, synthetic_data_generation, path=["response"])
diff --git a/tests/api_resources/test_telemetry.py b/tests/api_resources/test_telemetry.py
index 317bb8b..b7a7bda 100644
--- a/tests/api_resources/test_telemetry.py
+++ b/tests/api_resources/test_telemetry.py
@@ -25,6 +25,14 @@ def test_method_get_trace(self, client: LlamaStack) -> None:
)
        assert_matches_type(TelemetryGetTraceResponse, telemetry, path=["response"])

+    @parametrize
+ def test_method_get_trace_with_all_params(self, client: LlamaStack) -> None:
+ telemetry = client.telemetry.get_trace(
+ trace_id="trace_id",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(TelemetryGetTraceResponse, telemetry, path=["response"])
+
@parametrize
def test_raw_response_get_trace(self, client: LlamaStack) -> None:
response = client.telemetry.with_raw_response.get_trace(
@@ -75,6 +83,7 @@ def test_method_log_with_all_params(self, client: LlamaStack) -> None:
"type": "unstructured_log",
"attributes": {"foo": True},
},
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
assert telemetry is None
@@ -127,6 +136,14 @@ async def test_method_get_trace(self, async_client: AsyncLlamaStack) -> None:
)
        assert_matches_type(TelemetryGetTraceResponse, telemetry, path=["response"])

+    @parametrize
+ async def test_method_get_trace_with_all_params(self, async_client: AsyncLlamaStack) -> None:
+ telemetry = await async_client.telemetry.get_trace(
+ trace_id="trace_id",
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
+ )
+ assert_matches_type(TelemetryGetTraceResponse, telemetry, path=["response"])
+
@parametrize
async def test_raw_response_get_trace(self, async_client: AsyncLlamaStack) -> None:
response = await async_client.telemetry.with_raw_response.get_trace(
@@ -177,6 +194,7 @@ async def test_method_log_with_all_params(self, async_client: AsyncLlamaStack) -> None:
"type": "unstructured_log",
"attributes": {"foo": True},
},
+ x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
assert telemetry is None
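
Note: telemetry gains the same optional header parameter on both `get_trace` and `log`. A sketch of a `get_trace` call with it, matching the new test above (local base URL assumed):

```python
from llama_stack import LlamaStack

client = LlamaStack(base_url="http://127.0.0.1:4010")

trace = client.telemetry.get_trace(
    trace_id="trace_id",
    x_llama_stack_provider_data="X-LlamaStack-ProviderData",
)
```
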
diff --git a/tests/test_client.py b/tests/test_client.py
index ea89ceb..adbd9ad 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -720,6 +720,7 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
response = client.agents.sessions.with_raw_response.create(agent_id="agent_id", session_name="session_name")
assert response.retries_taken == failures_before_success
+    assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success


class TestAsyncLlamaStack:
@@ -1405,3 +1406,4 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
)
assert response.retries_taken == failures_before_success
+ assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
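
Note: the two new assertions check that the client stamps each outgoing request with an `x-stainless-retry-count` header, so after `failures_before_success` failed attempts the request that finally succeeds carries that count. In isolation the check is a fragment, not a standalone script; `response` is assumed to come from a `with_raw_response` call behind a retry-inducing mock handler, as in the tests:

```python
# `response` is the raw API response; the header records how many retries
# preceded the request that ultimately succeeded.
retry_count = int(response.http_request.headers.get("x-stainless-retry-count"))
assert retry_count == failures_before_success
```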