From 328904e091a8412020be85e82c8235cd1d961169 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Tue, 24 Sep 2024 01:11:17 -0700 Subject: [PATCH] sync repo for api_updates_3 --- .stats.yml | 4 +- api.md | 77 +- src/llama_stack/_base_client.py | 108 +- src/llama_stack/_client.py | 40 +- src/llama_stack/_compat.py | 2 + src/llama_stack/resources/__init__.py | 54 +- src/llama_stack/resources/agents/agents.py | 21 + src/llama_stack/resources/agents/sessions.py | 31 + src/llama_stack/resources/agents/steps.py | 11 + src/llama_stack/resources/agents/turns.py | 27 + src/llama_stack/resources/batch_inference.py | 21 + src/llama_stack/resources/datasets.py | 31 + .../resources/evaluate/jobs/artifacts.py | 11 + .../resources/evaluate/jobs/jobs.py | 41 + .../resources/evaluate/jobs/logs.py | 11 + .../resources/evaluate/jobs/status.py | 11 + .../resources/evaluate/question_answering.py | 11 + src/llama_stack/resources/evaluations.py | 21 + .../resources/inference/embeddings.py | 11 + .../resources/inference/inference.py | 27 + .../{memory_banks => memory}/__init__.py | 28 +- .../{memory_banks => memory}/documents.py | 33 +- .../memory_banks.py => memory/memory.py} | 275 +++-- src/llama_stack/resources/memory_banks.py | 262 +++++ src/llama_stack/resources/models.py | 261 +++++ .../resources/post_training/jobs.py | 71 ++ .../resources/post_training/post_training.py | 21 + src/llama_stack/resources/reward_scoring.py | 11 + src/llama_stack/resources/safety.py | 70 +- src/llama_stack/resources/shields.py | 261 +++++ .../resources/synthetic_data_generation.py | 11 + src/llama_stack/resources/telemetry.py | 21 + src/llama_stack/types/__init__.py | 29 +- src/llama_stack/types/agent_create_params.py | 109 +- src/llama_stack/types/agent_delete_params.py | 6 +- .../types/agents/session_create_params.py | 6 +- .../types/agents/session_delete_params.py | 6 +- .../types/agents/session_retrieve_params.py | 6 +- .../types/agents/step_retrieve_params.py | 6 +- .../types/agents/turn_create_params.py | 5 +- .../types/agents/turn_retrieve_params.py | 6 +- .../batch_inference_chat_completion_params.py | 5 +- .../batch_inference_completion_params.py | 5 +- .../custom_query_generator_config_param.py | 11 - .../types/dataset_create_params.py | 5 +- .../types/dataset_delete_params.py | 6 +- src/llama_stack/types/dataset_get_params.py | 6 +- .../default_query_generator_config_param.py | 13 - .../types/evaluate/job_cancel_params.py | 6 +- .../evaluate/jobs/artifact_list_params.py | 6 +- .../types/evaluate/jobs/log_list_params.py | 6 +- .../types/evaluate/jobs/status_list_params.py | 6 +- .../question_answering_create_params.py | 6 +- .../types/evaluation_summarization_params.py | 6 +- .../evaluation_text_generation_params.py | 6 +- .../inference/embedding_create_params.py | 6 +- .../types/inference_chat_completion_params.py | 5 +- .../types/inference_completion_params.py | 5 +- .../types/llm_query_generator_config_param.py | 15 - .../{memory_banks => memory}/__init__.py | 0 .../document_delete_params.py | 6 +- .../document_retrieve_params.py | 6 +- .../document_retrieve_response.py | 0 .../types/memory_bank_create_params.py | 11 - .../types/memory_bank_drop_params.py | 11 - .../types/memory_bank_get_params.py | 15 + .../types/memory_bank_retrieve_params.py | 11 - src/llama_stack/types/memory_bank_spec.py | 20 + src/llama_stack/types/memory_create_params.py | 15 + src/llama_stack/types/memory_drop_params.py | 15 + ...op_response.py => memory_drop_response.py} | 4 +- ...sert_params.py => memory_insert_params.py} | 10 +- 
...query_params.py => memory_query_params.py} | 10 +- .../types/memory_retrieve_params.py | 15 + ...date_params.py => memory_update_params.py} | 10 +- src/llama_stack/types/model_get_params.py | 15 + src/llama_stack/types/model_serving_spec.py | 23 + .../post_training/job_artifacts_params.py | 6 +- .../types/post_training/job_cancel_params.py | 6 +- .../types/post_training/job_logs_params.py | 6 +- .../types/post_training/job_status_params.py | 6 +- ...ost_training_preference_optimize_params.py | 5 +- ...st_training_supervised_fine_tune_params.py | 5 +- .../types/reward_scoring_score_params.py | 5 +- src/llama_stack/types/run_sheid_response.py | 20 + ..._params.py => safety_run_shield_params.py} | 16 +- .../types/safety_run_shields_response.py | 12 - src/llama_stack/types/sheid_response.py | 20 - src/llama_stack/types/shield_call_step.py | 17 +- .../types/shield_definition_param.py | 28 - src/llama_stack/types/shield_get_params.py | 15 + src/llama_stack/types/shield_spec.py | 19 + ...nthetic_data_generation_generate_params.py | 5 +- .../types/telemetry_get_trace_params.py | 6 +- src/llama_stack/types/telemetry_log_params.py | 2 + tests/api_resources/agents/test_sessions.py | 38 + tests/api_resources/agents/test_steps.py | 20 + tests/api_resources/agents/test_turns.py | 22 + .../evaluate/jobs/test_artifacts.py | 16 + .../api_resources/evaluate/jobs/test_logs.py | 16 + .../evaluate/jobs/test_status.py | 16 + tests/api_resources/evaluate/test_jobs.py | 30 + .../evaluate/test_question_answering.py | 16 + .../inference/test_embeddings.py | 18 + .../{memory_banks => memory}/__init__.py | 0 .../test_documents.py | 62 +- .../api_resources/post_training/test_jobs.py | 78 ++ tests/api_resources/test_agents.py | 994 +----------------- tests/api_resources/test_batch_inference.py | 4 + tests/api_resources/test_datasets.py | 34 + tests/api_resources/test_evaluations.py | 32 + tests/api_resources/test_inference.py | 6 + tests/api_resources/test_memory.py | 852 +++++++++++++++ tests/api_resources/test_memory_banks.py | 674 +----------- tests/api_resources/test_models.py | 164 +++ tests/api_resources/test_post_training.py | 4 + tests/api_resources/test_reward_scoring.py | 236 +++++ tests/api_resources/test_safety.py | 158 ++- tests/api_resources/test_shields.py | 164 +++ .../test_synthetic_data_generation.py | 2 + tests/api_resources/test_telemetry.py | 18 + tests/test_client.py | 2 + 122 files changed, 4137 insertions(+), 2134 deletions(-) rename src/llama_stack/resources/{memory_banks => memory}/__init__.py (53%) rename src/llama_stack/resources/{memory_banks => memory}/documents.py (88%) rename src/llama_stack/resources/{memory_banks/memory_banks.py => memory/memory.py} (73%) create mode 100644 src/llama_stack/resources/memory_banks.py create mode 100644 src/llama_stack/resources/models.py create mode 100644 src/llama_stack/resources/shields.py delete mode 100644 src/llama_stack/types/custom_query_generator_config_param.py delete mode 100644 src/llama_stack/types/default_query_generator_config_param.py delete mode 100644 src/llama_stack/types/llm_query_generator_config_param.py rename src/llama_stack/types/{memory_banks => memory}/__init__.py (100%) rename src/llama_stack/types/{memory_banks => memory}/document_delete_params.py (60%) rename src/llama_stack/types/{memory_banks => memory}/document_retrieve_params.py (61%) rename src/llama_stack/types/{memory_banks => memory}/document_retrieve_response.py (100%) delete mode 100644 src/llama_stack/types/memory_bank_create_params.py delete mode 100644 
src/llama_stack/types/memory_bank_drop_params.py create mode 100644 src/llama_stack/types/memory_bank_get_params.py delete mode 100644 src/llama_stack/types/memory_bank_retrieve_params.py create mode 100644 src/llama_stack/types/memory_bank_spec.py create mode 100644 src/llama_stack/types/memory_create_params.py create mode 100644 src/llama_stack/types/memory_drop_params.py rename src/llama_stack/types/{memory_bank_drop_response.py => memory_drop_response.py} (62%) rename src/llama_stack/types/{memory_bank_insert_params.py => memory_insert_params.py} (63%) rename src/llama_stack/types/{memory_bank_query_params.py => memory_query_params.py} (54%) create mode 100644 src/llama_stack/types/memory_retrieve_params.py rename src/llama_stack/types/{memory_bank_update_params.py => memory_update_params.py} (62%) create mode 100644 src/llama_stack/types/model_get_params.py create mode 100644 src/llama_stack/types/model_serving_spec.py create mode 100644 src/llama_stack/types/run_sheid_response.py rename src/llama_stack/types/{safety_run_shields_params.py => safety_run_shield_params.py} (52%) delete mode 100644 src/llama_stack/types/safety_run_shields_response.py delete mode 100644 src/llama_stack/types/sheid_response.py delete mode 100644 src/llama_stack/types/shield_definition_param.py create mode 100644 src/llama_stack/types/shield_get_params.py create mode 100644 src/llama_stack/types/shield_spec.py rename tests/api_resources/{memory_banks => memory}/__init__.py (100%) rename tests/api_resources/{memory_banks => memory}/test_documents.py (68%) create mode 100644 tests/api_resources/test_memory.py create mode 100644 tests/api_resources/test_models.py create mode 100644 tests/api_resources/test_shields.py diff --git a/.stats.yml b/.stats.yml index 68a0a53..4517049 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1,2 +1,2 @@ -configured_endpoints: 55 -openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/meta%2Fllama-stack-a836f3ef44b0852623ef583d04501ee7fcfc4a72af8b4b54cc9d3b415ed038aa.yml +configured_endpoints: 51 +openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/meta%2Fllama-stack-d52e4c19360cc636336d6a60ba6af1db89736fc0a3025c2b1d11870a5f1a1e3d.yml diff --git a/api.md b/api.md index 5561fa2..636407d 100644 --- a/api.md +++ b/api.md @@ -32,14 +32,10 @@ Types: ```python from llama_stack.types import ( - CustomQueryGeneratorConfig, - DefaultQueryGeneratorConfig, InferenceStep, - LlmQueryGeneratorConfig, MemoryRetrievalStep, RestAPIExecutionConfig, ShieldCallStep, - ShieldDefinition, ToolExecutionStep, ToolParamDefinition, AgentCreateResponse, @@ -196,49 +192,49 @@ Methods: Types: ```python -from llama_stack.types import SheidResponse, SafetyRunShieldsResponse +from llama_stack.types import RunSheidResponse ``` Methods: -- client.safety.run_shields(\*\*params) -> SafetyRunShieldsResponse +- client.safety.run_shield(\*\*params) -> RunSheidResponse -# MemoryBanks +# Memory Types: ```python from llama_stack.types import ( QueryDocuments, - MemoryBankCreateResponse, - MemoryBankRetrieveResponse, - MemoryBankListResponse, - MemoryBankDropResponse, + MemoryCreateResponse, + MemoryRetrieveResponse, + MemoryListResponse, + MemoryDropResponse, ) ``` Methods: -- client.memory_banks.create(\*\*params) -> object -- client.memory_banks.retrieve(\*\*params) -> object -- client.memory_banks.update(\*\*params) -> None -- client.memory_banks.list() -> object -- client.memory_banks.drop(\*\*params) -> str -- client.memory_banks.insert(\*\*params) -> None -- 
client.memory_banks.query(\*\*params) -> QueryDocuments +- client.memory.create(\*\*params) -> object +- client.memory.retrieve(\*\*params) -> object +- client.memory.update(\*\*params) -> None +- client.memory.list() -> object +- client.memory.drop(\*\*params) -> str +- client.memory.insert(\*\*params) -> None +- client.memory.query(\*\*params) -> QueryDocuments ## Documents Types: ```python -from llama_stack.types.memory_banks import DocumentRetrieveResponse +from llama_stack.types.memory import DocumentRetrieveResponse ``` Methods: -- client.memory_banks.documents.retrieve(\*\*params) -> DocumentRetrieveResponse -- client.memory_banks.documents.delete(\*\*params) -> None +- client.memory.documents.retrieve(\*\*params) -> DocumentRetrieveResponse +- client.memory.documents.delete(\*\*params) -> None # PostTraining @@ -309,3 +305,42 @@ Methods: - client.batch_inference.chat_completion(\*\*params) -> BatchChatCompletion - client.batch_inference.completion(\*\*params) -> BatchCompletion + +# Models + +Types: + +```python +from llama_stack.types import ModelServingSpec +``` + +Methods: + +- client.models.list() -> ModelServingSpec +- client.models.get(\*\*params) -> Optional + +# MemoryBanks + +Types: + +```python +from llama_stack.types import MemoryBankSpec +``` + +Methods: + +- client.memory_banks.list() -> MemoryBankSpec +- client.memory_banks.get(\*\*params) -> Optional + +# Shields + +Types: + +```python +from llama_stack.types import ShieldSpec +``` + +Methods: + +- client.shields.list() -> ShieldSpec +- client.shields.get(\*\*params) -> Optional diff --git a/src/llama_stack/_base_client.py b/src/llama_stack/_base_client.py index f6e5f9a..25a1c43 100644 --- a/src/llama_stack/_base_client.py +++ b/src/llama_stack/_base_client.py @@ -400,14 +400,7 @@ def _make_status_error( ) -> _exceptions.APIStatusError: raise NotImplementedError() - def _remaining_retries( - self, - remaining_retries: Optional[int], - options: FinalRequestOptions, - ) -> int: - return remaining_retries if remaining_retries is not None else options.get_max_retries(self.max_retries) - - def _build_headers(self, options: FinalRequestOptions) -> httpx.Headers: + def _build_headers(self, options: FinalRequestOptions, *, retries_taken: int = 0) -> httpx.Headers: custom_headers = options.headers or {} headers_dict = _merge_mappings(self.default_headers, custom_headers) self._validate_headers(headers_dict, custom_headers) @@ -419,6 +412,8 @@ def _build_headers(self, options: FinalRequestOptions) -> httpx.Headers: if idempotency_header and options.method.lower() != "get" and idempotency_header not in headers: headers[idempotency_header] = options.idempotency_key or self._idempotency_key() + headers.setdefault("x-stainless-retry-count", str(retries_taken)) + return headers def _prepare_url(self, url: str) -> URL: @@ -440,6 +435,8 @@ def _make_sse_decoder(self) -> SSEDecoder | SSEBytesDecoder: def _build_request( self, options: FinalRequestOptions, + *, + retries_taken: int = 0, ) -> httpx.Request: if log.isEnabledFor(logging.DEBUG): log.debug("Request options: %s", model_dump(options, exclude_unset=True)) @@ -455,7 +452,7 @@ def _build_request( else: raise RuntimeError(f"Unexpected JSON data type, {type(json_data)}, cannot merge with `extra_body`") - headers = self._build_headers(options) + headers = self._build_headers(options, retries_taken=retries_taken) params = _merge_mappings(self.default_query, options.params) content_type = headers.get("Content-Type") files = options.files @@ -489,12 +486,17 @@ def _build_request( if 
not files: files = cast(HttpxRequestFiles, ForceMultipartDict()) + prepared_url = self._prepare_url(options.url) + if "_" in prepared_url.host: + # work around https://github.com/encode/httpx/discussions/2880 + kwargs["extensions"] = {"sni_hostname": prepared_url.host.replace("_", "-")} + # TODO: report this error to httpx return self._client.build_request( # pyright: ignore[reportUnknownMemberType] headers=headers, timeout=self.timeout if isinstance(options.timeout, NotGiven) else options.timeout, method=options.method, - url=self._prepare_url(options.url), + url=prepared_url, # the `Query` type that we use is incompatible with qs' # `Params` type as it needs to be typed as `Mapping[str, object]` # so that passing a `TypedDict` doesn't cause an error. @@ -933,12 +935,17 @@ def request( stream: bool = False, stream_cls: type[_StreamT] | None = None, ) -> ResponseT | _StreamT: + if remaining_retries is not None: + retries_taken = options.get_max_retries(self.max_retries) - remaining_retries + else: + retries_taken = 0 + return self._request( cast_to=cast_to, options=options, stream=stream, stream_cls=stream_cls, - remaining_retries=remaining_retries, + retries_taken=retries_taken, ) def _request( @@ -946,7 +953,7 @@ def _request( *, cast_to: Type[ResponseT], options: FinalRequestOptions, - remaining_retries: int | None, + retries_taken: int, stream: bool, stream_cls: type[_StreamT] | None, ) -> ResponseT | _StreamT: @@ -958,8 +965,8 @@ def _request( cast_to = self._maybe_override_cast_to(cast_to, options) options = self._prepare_options(options) - retries = self._remaining_retries(remaining_retries, options) - request = self._build_request(options) + remaining_retries = options.get_max_retries(self.max_retries) - retries_taken + request = self._build_request(options, retries_taken=retries_taken) self._prepare_request(request) kwargs: HttpxSendArgs = {} @@ -977,11 +984,11 @@ def _request( except httpx.TimeoutException as err: log.debug("Encountered httpx.TimeoutException", exc_info=True) - if retries > 0: + if remaining_retries > 0: return self._retry_request( input_options, cast_to, - retries, + retries_taken=retries_taken, stream=stream, stream_cls=stream_cls, response_headers=None, @@ -992,11 +999,11 @@ def _request( except Exception as err: log.debug("Encountered Exception", exc_info=True) - if retries > 0: + if remaining_retries > 0: return self._retry_request( input_options, cast_to, - retries, + retries_taken=retries_taken, stream=stream, stream_cls=stream_cls, response_headers=None, @@ -1019,13 +1026,13 @@ def _request( except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code log.debug("Encountered httpx.HTTPStatusError", exc_info=True) - if retries > 0 and self._should_retry(err.response): + if remaining_retries > 0 and self._should_retry(err.response): err.response.close() return self._retry_request( input_options, cast_to, - retries, - err.response.headers, + retries_taken=retries_taken, + response_headers=err.response.headers, stream=stream, stream_cls=stream_cls, ) @@ -1044,26 +1051,26 @@ def _request( response=response, stream=stream, stream_cls=stream_cls, - retries_taken=options.get_max_retries(self.max_retries) - retries, + retries_taken=retries_taken, ) def _retry_request( self, options: FinalRequestOptions, cast_to: Type[ResponseT], - remaining_retries: int, - response_headers: httpx.Headers | None, *, + retries_taken: int, + response_headers: httpx.Headers | None, stream: bool, stream_cls: type[_StreamT] | None, ) -> ResponseT | _StreamT: - remaining = 
remaining_retries - 1 - if remaining == 1: + remaining_retries = options.get_max_retries(self.max_retries) - retries_taken + if remaining_retries == 1: log.debug("1 retry left") else: - log.debug("%i retries left", remaining) + log.debug("%i retries left", remaining_retries) - timeout = self._calculate_retry_timeout(remaining, options, response_headers) + timeout = self._calculate_retry_timeout(remaining_retries, options, response_headers) log.info("Retrying request to %s in %f seconds", options.url, timeout) # In a synchronous context we are blocking the entire thread. Up to the library user to run the client in a @@ -1073,7 +1080,7 @@ def _retry_request( return self._request( options=options, cast_to=cast_to, - remaining_retries=remaining, + retries_taken=retries_taken + 1, stream=stream, stream_cls=stream_cls, ) @@ -1491,12 +1498,17 @@ async def request( stream_cls: type[_AsyncStreamT] | None = None, remaining_retries: Optional[int] = None, ) -> ResponseT | _AsyncStreamT: + if remaining_retries is not None: + retries_taken = options.get_max_retries(self.max_retries) - remaining_retries + else: + retries_taken = 0 + return await self._request( cast_to=cast_to, options=options, stream=stream, stream_cls=stream_cls, - remaining_retries=remaining_retries, + retries_taken=retries_taken, ) async def _request( @@ -1506,7 +1518,7 @@ async def _request( *, stream: bool, stream_cls: type[_AsyncStreamT] | None, - remaining_retries: int | None, + retries_taken: int, ) -> ResponseT | _AsyncStreamT: if self._platform is None: # `get_platform` can make blocking IO calls so we @@ -1521,8 +1533,8 @@ async def _request( cast_to = self._maybe_override_cast_to(cast_to, options) options = await self._prepare_options(options) - retries = self._remaining_retries(remaining_retries, options) - request = self._build_request(options) + remaining_retries = options.get_max_retries(self.max_retries) - retries_taken + request = self._build_request(options, retries_taken=retries_taken) await self._prepare_request(request) kwargs: HttpxSendArgs = {} @@ -1538,11 +1550,11 @@ async def _request( except httpx.TimeoutException as err: log.debug("Encountered httpx.TimeoutException", exc_info=True) - if retries > 0: + if remaining_retries > 0: return await self._retry_request( input_options, cast_to, - retries, + retries_taken=retries_taken, stream=stream, stream_cls=stream_cls, response_headers=None, @@ -1553,11 +1565,11 @@ async def _request( except Exception as err: log.debug("Encountered Exception", exc_info=True) - if retries > 0: + if remaining_retries > 0: return await self._retry_request( input_options, cast_to, - retries, + retries_taken=retries_taken, stream=stream, stream_cls=stream_cls, response_headers=None, @@ -1575,13 +1587,13 @@ async def _request( except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code log.debug("Encountered httpx.HTTPStatusError", exc_info=True) - if retries > 0 and self._should_retry(err.response): + if remaining_retries > 0 and self._should_retry(err.response): await err.response.aclose() return await self._retry_request( input_options, cast_to, - retries, - err.response.headers, + retries_taken=retries_taken, + response_headers=err.response.headers, stream=stream, stream_cls=stream_cls, ) @@ -1600,26 +1612,26 @@ async def _request( response=response, stream=stream, stream_cls=stream_cls, - retries_taken=options.get_max_retries(self.max_retries) - retries, + retries_taken=retries_taken, ) async def _retry_request( self, options: FinalRequestOptions, cast_to: Type[ResponseT], -
remaining_retries: int, - response_headers: httpx.Headers | None, *, + retries_taken: int, + response_headers: httpx.Headers | None, stream: bool, stream_cls: type[_AsyncStreamT] | None, ) -> ResponseT | _AsyncStreamT: - remaining = remaining_retries - 1 - if remaining == 1: + remaining_retries = options.get_max_retries(self.max_retries) - retries_taken + if remaining_retries == 1: log.debug("1 retry left") else: - log.debug("%i retries left", remaining) + log.debug("%i retries left", remaining_retries) - timeout = self._calculate_retry_timeout(remaining, options, response_headers) + timeout = self._calculate_retry_timeout(remaining_retries, options, response_headers) log.info("Retrying request to %s in %f seconds", options.url, timeout) await anyio.sleep(timeout) @@ -1627,7 +1639,7 @@ async def _retry_request( return await self._request( options=options, cast_to=cast_to, - remaining_retries=remaining, + retries_taken=retries_taken + 1, stream=stream, stream_cls=stream_cls, ) diff --git a/src/llama_stack/_client.py b/src/llama_stack/_client.py index 80d45bd..420716d 100644 --- a/src/llama_stack/_client.py +++ b/src/llama_stack/_client.py @@ -59,11 +59,14 @@ class LlamaStack(SyncAPIClient): evaluations: resources.EvaluationsResource inference: resources.InferenceResource safety: resources.SafetyResource - memory_banks: resources.MemoryBanksResource + memory: resources.MemoryResource post_training: resources.PostTrainingResource reward_scoring: resources.RewardScoringResource synthetic_data_generation: resources.SyntheticDataGenerationResource batch_inference: resources.BatchInferenceResource + models: resources.ModelsResource + memory_banks: resources.MemoryBanksResource + shields: resources.ShieldsResource with_raw_response: LlamaStackWithRawResponse with_streaming_response: LlamaStackWithStreamedResponse @@ -139,11 +142,14 @@ def __init__( self.evaluations = resources.EvaluationsResource(self) self.inference = resources.InferenceResource(self) self.safety = resources.SafetyResource(self) - self.memory_banks = resources.MemoryBanksResource(self) + self.memory = resources.MemoryResource(self) self.post_training = resources.PostTrainingResource(self) self.reward_scoring = resources.RewardScoringResource(self) self.synthetic_data_generation = resources.SyntheticDataGenerationResource(self) self.batch_inference = resources.BatchInferenceResource(self) + self.models = resources.ModelsResource(self) + self.memory_banks = resources.MemoryBanksResource(self) + self.shields = resources.ShieldsResource(self) self.with_raw_response = LlamaStackWithRawResponse(self) self.with_streaming_response = LlamaStackWithStreamedResponse(self) @@ -254,11 +260,14 @@ class AsyncLlamaStack(AsyncAPIClient): evaluations: resources.AsyncEvaluationsResource inference: resources.AsyncInferenceResource safety: resources.AsyncSafetyResource - memory_banks: resources.AsyncMemoryBanksResource + memory: resources.AsyncMemoryResource post_training: resources.AsyncPostTrainingResource reward_scoring: resources.AsyncRewardScoringResource synthetic_data_generation: resources.AsyncSyntheticDataGenerationResource batch_inference: resources.AsyncBatchInferenceResource + models: resources.AsyncModelsResource + memory_banks: resources.AsyncMemoryBanksResource + shields: resources.AsyncShieldsResource with_raw_response: AsyncLlamaStackWithRawResponse with_streaming_response: AsyncLlamaStackWithStreamedResponse @@ -334,11 +343,14 @@ def __init__( self.evaluations = resources.AsyncEvaluationsResource(self) self.inference = 
resources.AsyncInferenceResource(self) self.safety = resources.AsyncSafetyResource(self) - self.memory_banks = resources.AsyncMemoryBanksResource(self) + self.memory = resources.AsyncMemoryResource(self) self.post_training = resources.AsyncPostTrainingResource(self) self.reward_scoring = resources.AsyncRewardScoringResource(self) self.synthetic_data_generation = resources.AsyncSyntheticDataGenerationResource(self) self.batch_inference = resources.AsyncBatchInferenceResource(self) + self.models = resources.AsyncModelsResource(self) + self.memory_banks = resources.AsyncMemoryBanksResource(self) + self.shields = resources.AsyncShieldsResource(self) self.with_raw_response = AsyncLlamaStackWithRawResponse(self) self.with_streaming_response = AsyncLlamaStackWithStreamedResponse(self) @@ -450,13 +462,16 @@ def __init__(self, client: LlamaStack) -> None: self.evaluations = resources.EvaluationsResourceWithRawResponse(client.evaluations) self.inference = resources.InferenceResourceWithRawResponse(client.inference) self.safety = resources.SafetyResourceWithRawResponse(client.safety) - self.memory_banks = resources.MemoryBanksResourceWithRawResponse(client.memory_banks) + self.memory = resources.MemoryResourceWithRawResponse(client.memory) self.post_training = resources.PostTrainingResourceWithRawResponse(client.post_training) self.reward_scoring = resources.RewardScoringResourceWithRawResponse(client.reward_scoring) self.synthetic_data_generation = resources.SyntheticDataGenerationResourceWithRawResponse( client.synthetic_data_generation ) self.batch_inference = resources.BatchInferenceResourceWithRawResponse(client.batch_inference) + self.models = resources.ModelsResourceWithRawResponse(client.models) + self.memory_banks = resources.MemoryBanksResourceWithRawResponse(client.memory_banks) + self.shields = resources.ShieldsResourceWithRawResponse(client.shields) class AsyncLlamaStackWithRawResponse: @@ -468,13 +483,16 @@ def __init__(self, client: AsyncLlamaStack) -> None: self.evaluations = resources.AsyncEvaluationsResourceWithRawResponse(client.evaluations) self.inference = resources.AsyncInferenceResourceWithRawResponse(client.inference) self.safety = resources.AsyncSafetyResourceWithRawResponse(client.safety) - self.memory_banks = resources.AsyncMemoryBanksResourceWithRawResponse(client.memory_banks) + self.memory = resources.AsyncMemoryResourceWithRawResponse(client.memory) self.post_training = resources.AsyncPostTrainingResourceWithRawResponse(client.post_training) self.reward_scoring = resources.AsyncRewardScoringResourceWithRawResponse(client.reward_scoring) self.synthetic_data_generation = resources.AsyncSyntheticDataGenerationResourceWithRawResponse( client.synthetic_data_generation ) self.batch_inference = resources.AsyncBatchInferenceResourceWithRawResponse(client.batch_inference) + self.models = resources.AsyncModelsResourceWithRawResponse(client.models) + self.memory_banks = resources.AsyncMemoryBanksResourceWithRawResponse(client.memory_banks) + self.shields = resources.AsyncShieldsResourceWithRawResponse(client.shields) class LlamaStackWithStreamedResponse: @@ -486,13 +504,16 @@ def __init__(self, client: LlamaStack) -> None: self.evaluations = resources.EvaluationsResourceWithStreamingResponse(client.evaluations) self.inference = resources.InferenceResourceWithStreamingResponse(client.inference) self.safety = resources.SafetyResourceWithStreamingResponse(client.safety) - self.memory_banks = resources.MemoryBanksResourceWithStreamingResponse(client.memory_banks) + self.memory = 
resources.MemoryResourceWithStreamingResponse(client.memory) self.post_training = resources.PostTrainingResourceWithStreamingResponse(client.post_training) self.reward_scoring = resources.RewardScoringResourceWithStreamingResponse(client.reward_scoring) self.synthetic_data_generation = resources.SyntheticDataGenerationResourceWithStreamingResponse( client.synthetic_data_generation ) self.batch_inference = resources.BatchInferenceResourceWithStreamingResponse(client.batch_inference) + self.models = resources.ModelsResourceWithStreamingResponse(client.models) + self.memory_banks = resources.MemoryBanksResourceWithStreamingResponse(client.memory_banks) + self.shields = resources.ShieldsResourceWithStreamingResponse(client.shields) class AsyncLlamaStackWithStreamedResponse: @@ -504,13 +525,16 @@ def __init__(self, client: AsyncLlamaStack) -> None: self.evaluations = resources.AsyncEvaluationsResourceWithStreamingResponse(client.evaluations) self.inference = resources.AsyncInferenceResourceWithStreamingResponse(client.inference) self.safety = resources.AsyncSafetyResourceWithStreamingResponse(client.safety) - self.memory_banks = resources.AsyncMemoryBanksResourceWithStreamingResponse(client.memory_banks) + self.memory = resources.AsyncMemoryResourceWithStreamingResponse(client.memory) self.post_training = resources.AsyncPostTrainingResourceWithStreamingResponse(client.post_training) self.reward_scoring = resources.AsyncRewardScoringResourceWithStreamingResponse(client.reward_scoring) self.synthetic_data_generation = resources.AsyncSyntheticDataGenerationResourceWithStreamingResponse( client.synthetic_data_generation ) self.batch_inference = resources.AsyncBatchInferenceResourceWithStreamingResponse(client.batch_inference) + self.models = resources.AsyncModelsResourceWithStreamingResponse(client.models) + self.memory_banks = resources.AsyncMemoryBanksResourceWithStreamingResponse(client.memory_banks) + self.shields = resources.AsyncShieldsResourceWithStreamingResponse(client.shields) Client = LlamaStack diff --git a/src/llama_stack/_compat.py b/src/llama_stack/_compat.py index 21fe694..162a6fb 100644 --- a/src/llama_stack/_compat.py +++ b/src/llama_stack/_compat.py @@ -136,12 +136,14 @@ def model_dump( exclude: IncEx = None, exclude_unset: bool = False, exclude_defaults: bool = False, + warnings: bool = True, ) -> dict[str, Any]: if PYDANTIC_V2: return model.model_dump( exclude=exclude, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, + warnings=warnings, ) return cast( "dict[str, Any]", diff --git a/src/llama_stack/resources/__init__.py b/src/llama_stack/resources/__init__.py index 484c881..e981ad1 100644 --- a/src/llama_stack/resources/__init__.py +++ b/src/llama_stack/resources/__init__.py @@ -8,6 +8,22 @@ AgentsResourceWithStreamingResponse, AsyncAgentsResourceWithStreamingResponse, ) +from .memory import ( + MemoryResource, + AsyncMemoryResource, + MemoryResourceWithRawResponse, + AsyncMemoryResourceWithRawResponse, + MemoryResourceWithStreamingResponse, + AsyncMemoryResourceWithStreamingResponse, +) +from .models import ( + ModelsResource, + AsyncModelsResource, + ModelsResourceWithRawResponse, + AsyncModelsResourceWithRawResponse, + ModelsResourceWithStreamingResponse, + AsyncModelsResourceWithStreamingResponse, +) from .safety import ( SafetyResource, AsyncSafetyResource, @@ -16,6 +32,14 @@ SafetyResourceWithStreamingResponse, AsyncSafetyResourceWithStreamingResponse, ) +from .shields import ( + ShieldsResource, + AsyncShieldsResource, + ShieldsResourceWithRawResponse, + 
AsyncShieldsResourceWithRawResponse, + ShieldsResourceWithStreamingResponse, + AsyncShieldsResourceWithStreamingResponse, +) from .datasets import ( DatasetsResource, AsyncDatasetsResource, @@ -140,12 +164,12 @@ "AsyncSafetyResourceWithRawResponse", "SafetyResourceWithStreamingResponse", "AsyncSafetyResourceWithStreamingResponse", - "MemoryBanksResource", - "AsyncMemoryBanksResource", - "MemoryBanksResourceWithRawResponse", - "AsyncMemoryBanksResourceWithRawResponse", - "MemoryBanksResourceWithStreamingResponse", - "AsyncMemoryBanksResourceWithStreamingResponse", + "MemoryResource", + "AsyncMemoryResource", + "MemoryResourceWithRawResponse", + "AsyncMemoryResourceWithRawResponse", + "MemoryResourceWithStreamingResponse", + "AsyncMemoryResourceWithStreamingResponse", "PostTrainingResource", "AsyncPostTrainingResource", "PostTrainingResourceWithRawResponse", @@ -170,4 +194,22 @@ "AsyncBatchInferenceResourceWithRawResponse", "BatchInferenceResourceWithStreamingResponse", "AsyncBatchInferenceResourceWithStreamingResponse", + "ModelsResource", + "AsyncModelsResource", + "ModelsResourceWithRawResponse", + "AsyncModelsResourceWithRawResponse", + "ModelsResourceWithStreamingResponse", + "AsyncModelsResourceWithStreamingResponse", + "MemoryBanksResource", + "AsyncMemoryBanksResource", + "MemoryBanksResourceWithRawResponse", + "AsyncMemoryBanksResourceWithRawResponse", + "MemoryBanksResourceWithStreamingResponse", + "AsyncMemoryBanksResourceWithStreamingResponse", + "ShieldsResource", + "AsyncShieldsResource", + "ShieldsResourceWithRawResponse", + "AsyncShieldsResourceWithRawResponse", + "ShieldsResourceWithStreamingResponse", + "AsyncShieldsResourceWithStreamingResponse", ] diff --git a/src/llama_stack/resources/agents/agents.py b/src/llama_stack/resources/agents/agents.py index 48871a2..7a865a5 100644 --- a/src/llama_stack/resources/agents/agents.py +++ b/src/llama_stack/resources/agents/agents.py @@ -24,6 +24,7 @@ from ..._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven from ..._utils import ( maybe_transform, + strip_not_given, async_maybe_transform, ) from .sessions import ( @@ -84,6 +85,7 @@ def create( self, *, agent_config: agent_create_params.AgentConfig, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -101,6 +103,10 @@ def create( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._post( "/agents/create", body=maybe_transform({"agent_config": agent_config}, agent_create_params.AgentCreateParams), @@ -114,6 +120,7 @@ def delete( self, *, agent_id: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, @@ -132,6 +139,10 @@ def delete( timeout: Override the client-level default timeout for this request, in seconds """ extra_headers = {"Accept": "*/*", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._post( "/agents/delete", body=maybe_transform({"agent_id": agent_id}, agent_delete_params.AgentDeleteParams), @@ -178,6 +189,7 @@ async def create( self, *, agent_config: agent_create_params.AgentConfig, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -195,6 +207,10 @@ async def create( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._post( "/agents/create", body=await async_maybe_transform({"agent_config": agent_config}, agent_create_params.AgentCreateParams), @@ -208,6 +224,7 @@ async def delete( self, *, agent_id: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -226,6 +243,10 @@ async def delete( timeout: Override the client-level default timeout for this request, in seconds """ extra_headers = {"Accept": "*/*", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._post( "/agents/delete", body=await async_maybe_transform({"agent_id": agent_id}, agent_delete_params.AgentDeleteParams), diff --git a/src/llama_stack/resources/agents/sessions.py b/src/llama_stack/resources/agents/sessions.py index 43eb41b..b8c6ae4 100644 --- a/src/llama_stack/resources/agents/sessions.py +++ b/src/llama_stack/resources/agents/sessions.py @@ -9,6 +9,7 @@ from ..._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven from ..._utils import ( maybe_transform, + strip_not_given, async_maybe_transform, ) from ..._compat import cached_property @@ -52,6 +53,7 @@ def create( *, agent_id: str, session_name: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, @@ -69,6 +71,10 @@ def create( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._post( "/agents/session/create", body=maybe_transform( @@ -90,6 +96,7 @@ def retrieve( agent_id: str, session_id: str, turn_ids: List[str] | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -107,6 +114,10 @@ def retrieve( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._post( "/agents/session/get", body=maybe_transform({"turn_ids": turn_ids}, session_retrieve_params.SessionRetrieveParams), @@ -131,6 +142,7 @@ def delete( *, agent_id: str, session_id: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -149,6 +161,10 @@ def delete( timeout: Override the client-level default timeout for this request, in seconds """ extra_headers = {"Accept": "*/*", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._post( "/agents/session/delete", body=maybe_transform( @@ -190,6 +206,7 @@ async def create( *, agent_id: str, session_name: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -207,6 +224,10 @@ async def create( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._post( "/agents/session/create", body=await async_maybe_transform( @@ -228,6 +249,7 @@ async def retrieve( agent_id: str, session_id: str, turn_ids: List[str] | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, @@ -245,6 +267,10 @@ async def retrieve( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._post( "/agents/session/get", body=await async_maybe_transform({"turn_ids": turn_ids}, session_retrieve_params.SessionRetrieveParams), @@ -269,6 +295,7 @@ async def delete( *, agent_id: str, session_id: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -287,6 +314,10 @@ async def delete( timeout: Override the client-level default timeout for this request, in seconds """ extra_headers = {"Accept": "*/*", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._post( "/agents/session/delete", body=await async_maybe_transform( diff --git a/src/llama_stack/resources/agents/steps.py b/src/llama_stack/resources/agents/steps.py index df6649f..7afaa90 100644 --- a/src/llama_stack/resources/agents/steps.py +++ b/src/llama_stack/resources/agents/steps.py @@ -7,6 +7,7 @@ from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven from ..._utils import ( maybe_transform, + strip_not_given, async_maybe_transform, ) from ..._compat import cached_property @@ -50,6 +51,7 @@ def retrieve( agent_id: str, step_id: str, turn_id: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -67,6 +69,10 @@ def retrieve( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._get( "/agents/step/get", options=make_request_options( @@ -113,6 +119,7 @@ async def retrieve( agent_id: str, step_id: str, turn_id: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, @@ -130,6 +137,10 @@ async def retrieve( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._get( "/agents/step/get", options=make_request_options( diff --git a/src/llama_stack/resources/agents/turns.py b/src/llama_stack/resources/agents/turns.py index 217fff1..ac6767f 100644 --- a/src/llama_stack/resources/agents/turns.py +++ b/src/llama_stack/resources/agents/turns.py @@ -11,6 +11,7 @@ from ..._utils import ( required_args, maybe_transform, + strip_not_given, async_maybe_transform, ) from ..._compat import cached_property @@ -60,6 +61,7 @@ def create( session_id: str, attachments: Iterable[Attachment] | NotGiven = NOT_GIVEN, stream: Literal[False] | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -88,6 +90,7 @@ def create( session_id: str, stream: Literal[True], attachments: Iterable[Attachment] | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -116,6 +119,7 @@ def create( session_id: str, stream: bool, attachments: Iterable[Attachment] | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -144,6 +148,7 @@ def create( session_id: str, attachments: Iterable[Attachment] | NotGiven = NOT_GIVEN, stream: Literal[False] | Literal[True] | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -151,6 +156,10 @@ def create( extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> AgentsTurnStreamChunk | Stream[AgentsTurnStreamChunk]: + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._post( "/agents/turn/create", body=maybe_transform( @@ -176,6 +185,7 @@ def retrieve( *, agent_id: str, turn_id: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, @@ -193,6 +203,10 @@ def retrieve( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._get( "/agents/turn/get", options=make_request_options( @@ -241,6 +255,7 @@ async def create( session_id: str, attachments: Iterable[Attachment] | NotGiven = NOT_GIVEN, stream: Literal[False] | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -269,6 +284,7 @@ async def create( session_id: str, stream: Literal[True], attachments: Iterable[Attachment] | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -297,6 +313,7 @@ async def create( session_id: str, stream: bool, attachments: Iterable[Attachment] | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -325,6 +342,7 @@ async def create( session_id: str, attachments: Iterable[Attachment] | NotGiven = NOT_GIVEN, stream: Literal[False] | Literal[True] | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -332,6 +350,10 @@ async def create( extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> AgentsTurnStreamChunk | AsyncStream[AgentsTurnStreamChunk]: + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._post( "/agents/turn/create", body=await async_maybe_transform( @@ -357,6 +379,7 @@ async def retrieve( *, agent_id: str, turn_id: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, @@ -374,6 +397,10 @@ async def retrieve( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._get( "/agents/turn/get", options=make_request_options( diff --git a/src/llama_stack/resources/batch_inference.py b/src/llama_stack/resources/batch_inference.py index 36028b3..7c2c563 100644 --- a/src/llama_stack/resources/batch_inference.py +++ b/src/llama_stack/resources/batch_inference.py @@ -11,6 +11,7 @@ from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven from .._utils import ( maybe_transform, + strip_not_given, async_maybe_transform, ) from .._compat import cached_property @@ -59,6 +60,7 @@ def chat_completion( tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN, tool_prompt_format: Literal["json", "function_tag"] | NotGiven = NOT_GIVEN, tools: Iterable[batch_inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -86,6 +88,10 @@ def chat_completion( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._post( "/batch_inference/chat_completion", body=maybe_transform( @@ -113,6 +119,7 @@ def completion( model: str, logprobs: batch_inference_completion_params.Logprobs | NotGiven = NOT_GIVEN, sampling_params: SamplingParams | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -130,6 +137,10 @@ def completion( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._post( "/batch_inference/completion", body=maybe_transform( @@ -178,6 +189,7 @@ async def chat_completion( tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN, tool_prompt_format: Literal["json", "function_tag"] | NotGiven = NOT_GIVEN, tools: Iterable[batch_inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, @@ -205,6 +217,10 @@ async def chat_completion( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._post( "/batch_inference/chat_completion", body=await async_maybe_transform( @@ -232,6 +248,7 @@ async def completion( model: str, logprobs: batch_inference_completion_params.Logprobs | NotGiven = NOT_GIVEN, sampling_params: SamplingParams | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -249,6 +266,10 @@ async def completion( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._post( "/batch_inference/completion", body=await async_maybe_transform( diff --git a/src/llama_stack/resources/datasets.py b/src/llama_stack/resources/datasets.py index 321e301..dcf1005 100644 --- a/src/llama_stack/resources/datasets.py +++ b/src/llama_stack/resources/datasets.py @@ -8,6 +8,7 @@ from .._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven from .._utils import ( maybe_transform, + strip_not_given, async_maybe_transform, ) from .._compat import cached_property @@ -50,6 +51,7 @@ def create( *, dataset: TrainEvalDatasetParam, uuid: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -68,6 +70,10 @@ def create( timeout: Override the client-level default timeout for this request, in seconds """ extra_headers = {"Accept": "*/*", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._post( "/datasets/create", body=maybe_transform( @@ -87,6 +93,7 @@ def delete( self, *, dataset_uuid: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -105,6 +112,10 @@ def delete( timeout: Override the client-level default timeout for this request, in seconds """ extra_headers = {"Accept": "*/*", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._post( "/datasets/delete", body=maybe_transform({"dataset_uuid": dataset_uuid}, dataset_delete_params.DatasetDeleteParams), @@ -118,6 +129,7 @@ def get( self, *, dataset_uuid: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. 
# The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -135,6 +147,10 @@ def get( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._get( "/datasets/get", options=make_request_options( @@ -173,6 +189,7 @@ async def create( *, dataset: TrainEvalDatasetParam, uuid: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -191,6 +208,10 @@ async def create( timeout: Override the client-level default timeout for this request, in seconds """ extra_headers = {"Accept": "*/*", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._post( "/datasets/create", body=await async_maybe_transform( @@ -210,6 +231,7 @@ async def delete( self, *, dataset_uuid: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -228,6 +250,10 @@ async def delete( timeout: Override the client-level default timeout for this request, in seconds """ extra_headers = {"Accept": "*/*", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._post( "/datasets/delete", body=await async_maybe_transform({"dataset_uuid": dataset_uuid}, dataset_delete_params.DatasetDeleteParams), @@ -241,6 +267,7 @@ async def get( self, *, dataset_uuid: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -258,6 +285,10 @@ async def get( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._get( "/datasets/get", options=make_request_options( diff --git a/src/llama_stack/resources/evaluate/jobs/artifacts.py b/src/llama_stack/resources/evaluate/jobs/artifacts.py index b5e0a3d..ce03116 100644 --- a/src/llama_stack/resources/evaluate/jobs/artifacts.py +++ b/src/llama_stack/resources/evaluate/jobs/artifacts.py @@ -7,6 +7,7 @@ from ...._types import NOT_GIVEN, Body, Query, Headers, NotGiven from ...._utils import ( maybe_transform, + strip_not_given, async_maybe_transform, ) from ...._compat import cached_property @@ -48,6 +49,7 @@ def list( self, *, job_uuid: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. 
# The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -65,6 +67,10 @@ def list( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._get( "/evaluate/job/artifacts", options=make_request_options( @@ -102,6 +108,7 @@ async def list( self, *, job_uuid: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -119,6 +126,10 @@ async def list( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._get( "/evaluate/job/artifacts", options=make_request_options( diff --git a/src/llama_stack/resources/evaluate/jobs/jobs.py b/src/llama_stack/resources/evaluate/jobs/jobs.py index bfbeb42..9e64c18 100644 --- a/src/llama_stack/resources/evaluate/jobs/jobs.py +++ b/src/llama_stack/resources/evaluate/jobs/jobs.py @@ -23,6 +23,7 @@ from ...._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven from ...._utils import ( maybe_transform, + strip_not_given, async_maybe_transform, ) from .artifacts import ( @@ -83,6 +84,7 @@ def with_streaming_response(self) -> JobsResourceWithStreamingResponse: def list( self, *, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -90,7 +92,21 @@ def list( extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> EvaluationJob: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._get( "/evaluate/jobs", options=make_request_options( @@ -103,6 +119,7 @@ def cancel( self, *, job_uuid: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, @@ -121,6 +138,10 @@ def cancel( timeout: Override the client-level default timeout for this request, in seconds """ extra_headers = {"Accept": "*/*", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._post( "/evaluate/job/cancel", body=maybe_transform({"job_uuid": job_uuid}, job_cancel_params.JobCancelParams), @@ -166,6 +187,7 @@ def with_streaming_response(self) -> AsyncJobsResourceWithStreamingResponse: async def list( self, *, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -173,7 +195,21 @@ async def list( extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> EvaluationJob: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._get( "/evaluate/jobs", options=make_request_options( @@ -186,6 +222,7 @@ async def cancel( self, *, job_uuid: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -204,6 +241,10 @@ async def cancel( timeout: Override the client-level default timeout for this request, in seconds """ extra_headers = {"Accept": "*/*", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._post( "/evaluate/job/cancel", body=await async_maybe_transform({"job_uuid": job_uuid}, job_cancel_params.JobCancelParams), diff --git a/src/llama_stack/resources/evaluate/jobs/logs.py b/src/llama_stack/resources/evaluate/jobs/logs.py index 2aae53c..c1db747 100644 --- a/src/llama_stack/resources/evaluate/jobs/logs.py +++ b/src/llama_stack/resources/evaluate/jobs/logs.py @@ -7,6 +7,7 @@ from ...._types import NOT_GIVEN, Body, Query, Headers, NotGiven from ...._utils import ( maybe_transform, + strip_not_given, async_maybe_transform, ) from ...._compat import cached_property @@ -48,6 +49,7 @@ def list( self, *, job_uuid: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, @@ -65,6 +67,10 @@ def list( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._get( "/evaluate/job/logs", options=make_request_options( @@ -102,6 +108,7 @@ async def list( self, *, job_uuid: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -119,6 +126,10 @@ async def list( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._get( "/evaluate/job/logs", options=make_request_options( diff --git a/src/llama_stack/resources/evaluate/jobs/status.py b/src/llama_stack/resources/evaluate/jobs/status.py index 4719ede..2c3aca8 100644 --- a/src/llama_stack/resources/evaluate/jobs/status.py +++ b/src/llama_stack/resources/evaluate/jobs/status.py @@ -7,6 +7,7 @@ from ...._types import NOT_GIVEN, Body, Query, Headers, NotGiven from ...._utils import ( maybe_transform, + strip_not_given, async_maybe_transform, ) from ...._compat import cached_property @@ -48,6 +49,7 @@ def list( self, *, job_uuid: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -65,6 +67,10 @@ def list( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._get( "/evaluate/job/status", options=make_request_options( @@ -102,6 +108,7 @@ async def list( self, *, job_uuid: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, @@ -119,6 +126,10 @@ async def list( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._get( "/evaluate/job/status", options=make_request_options( diff --git a/src/llama_stack/resources/evaluate/question_answering.py b/src/llama_stack/resources/evaluate/question_answering.py index ca5169f..50b4a0c 100644 --- a/src/llama_stack/resources/evaluate/question_answering.py +++ b/src/llama_stack/resources/evaluate/question_answering.py @@ -10,6 +10,7 @@ from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven from ..._utils import ( maybe_transform, + strip_not_given, async_maybe_transform, ) from ..._compat import cached_property @@ -51,6 +52,7 @@ def create( self, *, metrics: List[Literal["em", "f1"]], + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -68,6 +70,10 @@ def create( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._post( "/evaluate/question_answering/", body=maybe_transform({"metrics": metrics}, question_answering_create_params.QuestionAnsweringCreateParams), @@ -102,6 +108,7 @@ async def create( self, *, metrics: List[Literal["em", "f1"]], + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -119,6 +126,10 @@ async def create( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._post( "/evaluate/question_answering/", body=await async_maybe_transform( diff --git a/src/llama_stack/resources/evaluations.py b/src/llama_stack/resources/evaluations.py index 6328060..cebe2ba 100644 --- a/src/llama_stack/resources/evaluations.py +++ b/src/llama_stack/resources/evaluations.py @@ -11,6 +11,7 @@ from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven from .._utils import ( maybe_transform, + strip_not_given, async_maybe_transform, ) from .._compat import cached_property @@ -51,6 +52,7 @@ def summarization( self, *, metrics: List[Literal["rouge", "bleu"]], + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
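A hedged usage sketch of the evaluate-jobs endpoints touched above. The LlamaStack client class, the base_url, and the client.evaluate.jobs.* attribute paths are inferred from the resource layout rather than confirmed by this patch, and the job UUID is hypothetical:

from llama_stack import LlamaStack

client = LlamaStack(base_url="http://localhost:5000")  # illustrative endpoint

# The optional per-request header rides on any of these calls.
jobs = client.evaluate.jobs.list(x_llama_stack_provider_data='{"api_key": "..."}')

# Status, logs, and artifacts all key off a job UUID.
status = client.evaluate.jobs.status.list(job_uuid="job-1234")
logs = client.evaluate.jobs.logs.list(job_uuid="job-1234")
artifacts = client.evaluate.jobs.artifacts.list(job_uuid="job-1234")
client.evaluate.jobs.cancel(job_uuid="job-1234")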
extra_headers: Headers | None = None, @@ -68,6 +70,10 @@ def summarization( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._post( "/evaluate/summarization/", body=maybe_transform({"metrics": metrics}, evaluation_summarization_params.EvaluationSummarizationParams), @@ -81,6 +87,7 @@ def text_generation( self, *, metrics: List[Literal["perplexity", "rouge", "bleu"]], + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -98,6 +105,10 @@ def text_generation( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._post( "/evaluate/text_generation/", body=maybe_transform( @@ -134,6 +145,7 @@ async def summarization( self, *, metrics: List[Literal["rouge", "bleu"]], + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -151,6 +163,10 @@ async def summarization( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._post( "/evaluate/summarization/", body=await async_maybe_transform( @@ -166,6 +182,7 @@ async def text_generation( self, *, metrics: List[Literal["perplexity", "rouge", "bleu"]], + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -183,6 +200,10 @@ async def text_generation( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._post( "/evaluate/text_generation/", body=await async_maybe_transform( diff --git a/src/llama_stack/resources/inference/embeddings.py b/src/llama_stack/resources/inference/embeddings.py index 77e7e5e..0b332cf 100644 --- a/src/llama_stack/resources/inference/embeddings.py +++ b/src/llama_stack/resources/inference/embeddings.py @@ -9,6 +9,7 @@ from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven from ..._utils import ( maybe_transform, + strip_not_given, async_maybe_transform, ) from ..._compat import cached_property @@ -51,6 +52,7 @@ def create( *, contents: List[Union[str, List[str]]], model: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. 
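The summarization and text-generation entry points take only a list of metric names; the literals below come straight from the signatures above, while the client construction is assumed as in the previous sketch:

from llama_stack import LlamaStack

client = LlamaStack(base_url="http://localhost:5000")

# Metric names are constrained by the Literal types in the params modules.
client.evaluations.summarization(metrics=["rouge", "bleu"])
client.evaluations.text_generation(metrics=["perplexity", "rouge", "bleu"])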
# The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -68,6 +70,10 @@ def create( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._post( "/inference/embeddings", body=maybe_transform( @@ -109,6 +115,7 @@ async def create( *, contents: List[Union[str, List[str]]], model: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -126,6 +133,10 @@ async def create( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._post( "/inference/embeddings", body=await async_maybe_transform( diff --git a/src/llama_stack/resources/inference/inference.py b/src/llama_stack/resources/inference/inference.py index 648736e..1be5b06 100644 --- a/src/llama_stack/resources/inference/inference.py +++ b/src/llama_stack/resources/inference/inference.py @@ -12,6 +12,7 @@ from ..._utils import ( required_args, maybe_transform, + strip_not_given, async_maybe_transform, ) from ..._compat import cached_property @@ -75,6 +76,7 @@ def chat_completion( tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN, tool_prompt_format: Literal["json", "function_tag"] | NotGiven = NOT_GIVEN, tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -116,6 +118,7 @@ def chat_completion( tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN, tool_prompt_format: Literal["json", "function_tag"] | NotGiven = NOT_GIVEN, tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -157,6 +160,7 @@ def chat_completion( tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN, tool_prompt_format: Literal["json", "function_tag"] | NotGiven = NOT_GIVEN, tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, @@ -198,6 +202,7 @@ def chat_completion( tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN, tool_prompt_format: Literal["json", "function_tag"] | NotGiven = NOT_GIVEN, tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -206,6 +211,10 @@ def chat_completion( timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> InferenceChatCompletionResponse | Stream[InferenceChatCompletionResponse]: extra_headers = {"Accept": "text/event-stream", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return cast( InferenceChatCompletionResponse, self._post( @@ -242,6 +251,7 @@ def completion( logprobs: inference_completion_params.Logprobs | NotGiven = NOT_GIVEN, sampling_params: SamplingParams | NotGiven = NOT_GIVEN, stream: bool | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -259,6 +269,10 @@ def completion( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return cast( InferenceCompletionResponse, self._post( @@ -319,6 +333,7 @@ async def chat_completion( tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN, tool_prompt_format: Literal["json", "function_tag"] | NotGiven = NOT_GIVEN, tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -360,6 +375,7 @@ async def chat_completion( tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN, tool_prompt_format: Literal["json", "function_tag"] | NotGiven = NOT_GIVEN, tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -401,6 +417,7 @@ async def chat_completion( tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN, tool_prompt_format: Literal["json", "function_tag"] | NotGiven = NOT_GIVEN, tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. 
# The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -442,6 +459,7 @@ async def chat_completion( tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN, tool_prompt_format: Literal["json", "function_tag"] | NotGiven = NOT_GIVEN, tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -450,6 +468,10 @@ async def chat_completion( timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> InferenceChatCompletionResponse | AsyncStream[InferenceChatCompletionResponse]: extra_headers = {"Accept": "text/event-stream", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return cast( InferenceChatCompletionResponse, await self._post( @@ -486,6 +508,7 @@ async def completion( logprobs: inference_completion_params.Logprobs | NotGiven = NOT_GIVEN, sampling_params: SamplingParams | NotGiven = NOT_GIVEN, stream: bool | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -503,6 +526,10 @@ async def completion( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return cast( InferenceCompletionResponse, await self._post( diff --git a/src/llama_stack/resources/memory_banks/__init__.py b/src/llama_stack/resources/memory/__init__.py similarity index 53% rename from src/llama_stack/resources/memory_banks/__init__.py rename to src/llama_stack/resources/memory/__init__.py index f26b272..1438115 100644 --- a/src/llama_stack/resources/memory_banks/__init__.py +++ b/src/llama_stack/resources/memory/__init__.py @@ -1,5 +1,13 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
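A sketch of the embeddings and chat-completion calls as they look with the new header parameter. The model id and the message shape are illustrative assumptions, and stream=True matches the "text/event-stream" Accept header that chat_completion sets above:

from llama_stack import LlamaStack

client = LlamaStack(base_url="http://localhost:5000")

embeddings = client.inference.embeddings.create(
    model="Llama3.1-8B",  # hypothetical model id
    contents=["first document", "second document"],
)

stream = client.inference.chat_completion(
    model="Llama3.1-8B",
    messages=[{"role": "user", "content": "Hello"}],  # assumed message shape
    stream=True,
    x_llama_stack_provider_data='{"api_key": "..."}',  # optional per-request header
)
for chunk in stream:
    ...  # each chunk parses as an InferenceChatCompletionResponse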
+from .memory import ( + MemoryResource, + AsyncMemoryResource, + MemoryResourceWithRawResponse, + AsyncMemoryResourceWithRawResponse, + MemoryResourceWithStreamingResponse, + AsyncMemoryResourceWithStreamingResponse, +) from .documents import ( DocumentsResource, AsyncDocumentsResource, @@ -8,14 +16,6 @@ DocumentsResourceWithStreamingResponse, AsyncDocumentsResourceWithStreamingResponse, ) -from .memory_banks import ( - MemoryBanksResource, - AsyncMemoryBanksResource, - MemoryBanksResourceWithRawResponse, - AsyncMemoryBanksResourceWithRawResponse, - MemoryBanksResourceWithStreamingResponse, - AsyncMemoryBanksResourceWithStreamingResponse, -) __all__ = [ "DocumentsResource", @@ -24,10 +24,10 @@ "AsyncDocumentsResourceWithRawResponse", "DocumentsResourceWithStreamingResponse", "AsyncDocumentsResourceWithStreamingResponse", - "MemoryBanksResource", - "AsyncMemoryBanksResource", - "MemoryBanksResourceWithRawResponse", - "AsyncMemoryBanksResourceWithRawResponse", - "MemoryBanksResourceWithStreamingResponse", - "AsyncMemoryBanksResourceWithStreamingResponse", + "MemoryResource", + "AsyncMemoryResource", + "MemoryResourceWithRawResponse", + "AsyncMemoryResourceWithRawResponse", + "MemoryResourceWithStreamingResponse", + "AsyncMemoryResourceWithStreamingResponse", ] diff --git a/src/llama_stack/resources/memory_banks/documents.py b/src/llama_stack/resources/memory/documents.py similarity index 88% rename from src/llama_stack/resources/memory_banks/documents.py rename to src/llama_stack/resources/memory/documents.py index 995bc7a..546ffd4 100644 --- a/src/llama_stack/resources/memory_banks/documents.py +++ b/src/llama_stack/resources/memory/documents.py @@ -9,6 +9,7 @@ from ..._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven from ..._utils import ( maybe_transform, + strip_not_given, async_maybe_transform, ) from ..._compat import cached_property @@ -20,8 +21,8 @@ async_to_streamed_response_wrapper, ) from ..._base_client import make_request_options -from ...types.memory_banks import document_delete_params, document_retrieve_params -from ...types.memory_banks.document_retrieve_response import DocumentRetrieveResponse +from ...types.memory import document_delete_params, document_retrieve_params +from ...types.memory.document_retrieve_response import DocumentRetrieveResponse __all__ = ["DocumentsResource", "AsyncDocumentsResource"] @@ -51,6 +52,7 @@ def retrieve( *, bank_id: str, document_ids: List[str], + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, @@ -69,8 +71,12 @@ def retrieve( timeout: Override the client-level default timeout for this request, in seconds """ extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._post( - "/memory_bank/documents/get", + "/memory/documents/get", body=maybe_transform({"document_ids": document_ids}, document_retrieve_params.DocumentRetrieveParams), options=make_request_options( extra_headers=extra_headers, @@ -87,6 +93,7 @@ def delete( *, bank_id: str, document_ids: List[str], + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -105,8 +112,12 @@ def delete( timeout: Override the client-level default timeout for this request, in seconds """ extra_headers = {"Accept": "*/*", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._post( - "/memory_bank/documents/delete", + "/memory/documents/delete", body=maybe_transform( { "bank_id": bank_id, @@ -146,6 +157,7 @@ async def retrieve( *, bank_id: str, document_ids: List[str], + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -164,8 +176,12 @@ async def retrieve( timeout: Override the client-level default timeout for this request, in seconds """ extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._post( - "/memory_bank/documents/get", + "/memory/documents/get", body=await async_maybe_transform( {"document_ids": document_ids}, document_retrieve_params.DocumentRetrieveParams ), @@ -186,6 +202,7 @@ async def delete( *, bank_id: str, document_ids: List[str], + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, @@ -204,8 +221,12 @@ async def delete( timeout: Override the client-level default timeout for this request, in seconds """ extra_headers = {"Accept": "*/*", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._post( - "/memory_bank/documents/delete", + "/memory/documents/delete", body=await async_maybe_transform( { "bank_id": bank_id, diff --git a/src/llama_stack/resources/memory_banks/memory_banks.py b/src/llama_stack/resources/memory/memory.py similarity index 73% rename from src/llama_stack/resources/memory_banks/memory_banks.py rename to src/llama_stack/resources/memory/memory.py index 52ddeba..761ab5d 100644 --- a/src/llama_stack/resources/memory_banks/memory_banks.py +++ b/src/llama_stack/resources/memory/memory.py @@ -7,16 +7,17 @@ import httpx from ...types import ( - memory_bank_drop_params, - memory_bank_query_params, - memory_bank_create_params, - memory_bank_insert_params, - memory_bank_update_params, - memory_bank_retrieve_params, + memory_drop_params, + memory_query_params, + memory_create_params, + memory_insert_params, + memory_update_params, + memory_retrieve_params, ) from ..._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven from ..._utils import ( maybe_transform, + strip_not_given, async_maybe_transform, ) from ..._compat import cached_property @@ -38,37 +39,38 @@ from ..._base_client import make_request_options from ...types.query_documents import QueryDocuments -__all__ = ["MemoryBanksResource", "AsyncMemoryBanksResource"] +__all__ = ["MemoryResource", "AsyncMemoryResource"] -class MemoryBanksResource(SyncAPIResource): +class MemoryResource(SyncAPIResource): @cached_property def documents(self) -> DocumentsResource: return DocumentsResource(self._client) @cached_property - def with_raw_response(self) -> MemoryBanksResourceWithRawResponse: + def with_raw_response(self) -> MemoryResourceWithRawResponse: """ This property can be used as a prefix for any HTTP method call to return the the raw response object instead of the parsed content. For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers """ - return MemoryBanksResourceWithRawResponse(self) + return MemoryResourceWithRawResponse(self) @cached_property - def with_streaming_response(self) -> MemoryBanksResourceWithStreamingResponse: + def with_streaming_response(self) -> MemoryResourceWithStreamingResponse: """ An alternative to `.with_raw_response` that doesn't eagerly read the response body. For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response """ - return MemoryBanksResourceWithStreamingResponse(self) + return MemoryResourceWithStreamingResponse(self) def create( self, *, body: object, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
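With the rename, the document sub-resource posts to /memory/documents/get and /memory/documents/delete instead of the old /memory_bank/documents/* paths. A sketch under the same client assumptions, with hypothetical bank and document ids:

from llama_stack import LlamaStack

client = LlamaStack(base_url="http://localhost:5000")

docs = client.memory.documents.retrieve(
    bank_id="bank-0",
    document_ids=["doc-1", "doc-2"],
)
client.memory.documents.delete(bank_id="bank-0", document_ids=["doc-1"])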
extra_headers: Headers | None = None, @@ -86,9 +88,13 @@ def create( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._post( - "/memory_banks/create", - body=maybe_transform(body, memory_bank_create_params.MemoryBankCreateParams), + "/memory/create", + body=maybe_transform(body, memory_create_params.MemoryCreateParams), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -99,6 +105,7 @@ def retrieve( self, *, bank_id: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -116,14 +123,18 @@ def retrieve( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._get( - "/memory_banks/get", + "/memory/get", options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout, - query=maybe_transform({"bank_id": bank_id}, memory_bank_retrieve_params.MemoryBankRetrieveParams), + query=maybe_transform({"bank_id": bank_id}, memory_retrieve_params.MemoryRetrieveParams), ), cast_to=object, ) @@ -132,7 +143,8 @@ def update( self, *, bank_id: str, - documents: Iterable[memory_bank_update_params.Document], + documents: Iterable[memory_update_params.Document], + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -151,14 +163,18 @@ def update( timeout: Override the client-level default timeout for this request, in seconds """ extra_headers = {"Accept": "*/*", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._post( - "/memory_bank/update", + "/memory/update", body=maybe_transform( { "bank_id": bank_id, "documents": documents, }, - memory_bank_update_params.MemoryBankUpdateParams, + memory_update_params.MemoryUpdateParams, ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout @@ -169,6 +185,7 @@ def update( def list( self, *, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, @@ -176,9 +193,23 @@ def list( extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> object: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._get( - "/memory_banks/list", + "/memory/list", options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -189,6 +220,7 @@ def drop( self, *, bank_id: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -206,9 +238,13 @@ def drop( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._post( - "/memory_banks/drop", - body=maybe_transform({"bank_id": bank_id}, memory_bank_drop_params.MemoryBankDropParams), + "/memory/drop", + body=maybe_transform({"bank_id": bank_id}, memory_drop_params.MemoryDropParams), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -219,8 +255,9 @@ def insert( self, *, bank_id: str, - documents: Iterable[memory_bank_insert_params.Document], + documents: Iterable[memory_insert_params.Document], ttl_seconds: int | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -239,15 +276,19 @@ def insert( timeout: Override the client-level default timeout for this request, in seconds """ extra_headers = {"Accept": "*/*", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._post( - "/memory_bank/insert", + "/memory/insert", body=maybe_transform( { "bank_id": bank_id, "documents": documents, "ttl_seconds": ttl_seconds, }, - memory_bank_insert_params.MemoryBankInsertParams, + memory_insert_params.MemoryInsertParams, ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout @@ -261,6 +302,7 @@ def query( bank_id: str, query: Union[str, List[str]], params: Dict[str, Union[bool, float, str, Iterable[object], object, None]] | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, @@ -278,15 +320,19 @@ def query( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._post( - "/memory_bank/query", + "/memory/query", body=maybe_transform( { "bank_id": bank_id, "query": query, "params": params, }, - memory_bank_query_params.MemoryBankQueryParams, + memory_query_params.MemoryQueryParams, ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout @@ -295,34 +341,35 @@ def query( ) -class AsyncMemoryBanksResource(AsyncAPIResource): +class AsyncMemoryResource(AsyncAPIResource): @cached_property def documents(self) -> AsyncDocumentsResource: return AsyncDocumentsResource(self._client) @cached_property - def with_raw_response(self) -> AsyncMemoryBanksResourceWithRawResponse: + def with_raw_response(self) -> AsyncMemoryResourceWithRawResponse: """ This property can be used as a prefix for any HTTP method call to return the the raw response object instead of the parsed content. For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers """ - return AsyncMemoryBanksResourceWithRawResponse(self) + return AsyncMemoryResourceWithRawResponse(self) @cached_property - def with_streaming_response(self) -> AsyncMemoryBanksResourceWithStreamingResponse: + def with_streaming_response(self) -> AsyncMemoryResourceWithStreamingResponse: """ An alternative to `.with_raw_response` that doesn't eagerly read the response body. For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response """ - return AsyncMemoryBanksResourceWithStreamingResponse(self) + return AsyncMemoryResourceWithStreamingResponse(self) async def create( self, *, body: object, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -340,9 +387,13 @@ async def create( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._post( - "/memory_banks/create", - body=await async_maybe_transform(body, memory_bank_create_params.MemoryBankCreateParams), + "/memory/create", + body=await async_maybe_transform(body, memory_create_params.MemoryCreateParams), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -353,6 +404,7 @@ async def retrieve( self, *, bank_id: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, @@ -370,16 +422,18 @@ async def retrieve( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._get( - "/memory_banks/get", + "/memory/get", options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout, - query=await async_maybe_transform( - {"bank_id": bank_id}, memory_bank_retrieve_params.MemoryBankRetrieveParams - ), + query=await async_maybe_transform({"bank_id": bank_id}, memory_retrieve_params.MemoryRetrieveParams), ), cast_to=object, ) @@ -388,7 +442,8 @@ async def update( self, *, bank_id: str, - documents: Iterable[memory_bank_update_params.Document], + documents: Iterable[memory_update_params.Document], + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -407,14 +462,18 @@ async def update( timeout: Override the client-level default timeout for this request, in seconds """ extra_headers = {"Accept": "*/*", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._post( - "/memory_bank/update", + "/memory/update", body=await async_maybe_transform( { "bank_id": bank_id, "documents": documents, }, - memory_bank_update_params.MemoryBankUpdateParams, + memory_update_params.MemoryUpdateParams, ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout @@ -425,6 +484,7 @@ async def update( async def list( self, *, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -432,9 +492,23 @@ async def list( extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> object: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._get( - "/memory_banks/list", + "/memory/list", options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -445,6 +519,7 @@ async def drop( self, *, bank_id: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, @@ -462,9 +537,13 @@ async def drop( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._post( - "/memory_banks/drop", - body=await async_maybe_transform({"bank_id": bank_id}, memory_bank_drop_params.MemoryBankDropParams), + "/memory/drop", + body=await async_maybe_transform({"bank_id": bank_id}, memory_drop_params.MemoryDropParams), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -475,8 +554,9 @@ async def insert( self, *, bank_id: str, - documents: Iterable[memory_bank_insert_params.Document], + documents: Iterable[memory_insert_params.Document], ttl_seconds: int | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -495,15 +575,19 @@ async def insert( timeout: Override the client-level default timeout for this request, in seconds """ extra_headers = {"Accept": "*/*", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._post( - "/memory_bank/insert", + "/memory/insert", body=await async_maybe_transform( { "bank_id": bank_id, "documents": documents, "ttl_seconds": ttl_seconds, }, - memory_bank_insert_params.MemoryBankInsertParams, + memory_insert_params.MemoryInsertParams, ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout @@ -517,6 +601,7 @@ async def query( bank_id: str, query: Union[str, List[str]], params: Dict[str, Union[bool, float, str, Iterable[object], object, None]] | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, @@ -534,15 +619,19 @@ async def query( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._post( - "/memory_bank/query", + "/memory/query", body=await async_maybe_transform( { "bank_id": bank_id, "query": query, "params": params, }, - memory_bank_query_params.MemoryBankQueryParams, + memory_query_params.MemoryQueryParams, ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout @@ -551,125 +640,125 @@ async def query( ) -class MemoryBanksResourceWithRawResponse: - def __init__(self, memory_banks: MemoryBanksResource) -> None: - self._memory_banks = memory_banks +class MemoryResourceWithRawResponse: + def __init__(self, memory: MemoryResource) -> None: + self._memory = memory self.create = to_raw_response_wrapper( - memory_banks.create, + memory.create, ) self.retrieve = to_raw_response_wrapper( - memory_banks.retrieve, + memory.retrieve, ) self.update = to_raw_response_wrapper( - memory_banks.update, + memory.update, ) self.list = to_raw_response_wrapper( - memory_banks.list, + memory.list, ) self.drop = to_raw_response_wrapper( - memory_banks.drop, + memory.drop, ) self.insert = to_raw_response_wrapper( - memory_banks.insert, + memory.insert, ) self.query = to_raw_response_wrapper( - memory_banks.query, + memory.query, ) @cached_property def documents(self) -> DocumentsResourceWithRawResponse: - return DocumentsResourceWithRawResponse(self._memory_banks.documents) + return DocumentsResourceWithRawResponse(self._memory.documents) -class AsyncMemoryBanksResourceWithRawResponse: - def __init__(self, memory_banks: AsyncMemoryBanksResource) -> None: - self._memory_banks = memory_banks +class AsyncMemoryResourceWithRawResponse: + def __init__(self, memory: AsyncMemoryResource) -> None: + self._memory = memory self.create = async_to_raw_response_wrapper( - memory_banks.create, + memory.create, ) self.retrieve = async_to_raw_response_wrapper( - memory_banks.retrieve, + memory.retrieve, ) self.update = async_to_raw_response_wrapper( - memory_banks.update, + memory.update, ) self.list = async_to_raw_response_wrapper( - memory_banks.list, + memory.list, ) self.drop = async_to_raw_response_wrapper( - memory_banks.drop, + memory.drop, ) self.insert = async_to_raw_response_wrapper( - memory_banks.insert, + memory.insert, ) self.query = async_to_raw_response_wrapper( - memory_banks.query, + memory.query, ) @cached_property def documents(self) -> AsyncDocumentsResourceWithRawResponse: - return AsyncDocumentsResourceWithRawResponse(self._memory_banks.documents) + return AsyncDocumentsResourceWithRawResponse(self._memory.documents) -class MemoryBanksResourceWithStreamingResponse: - def __init__(self, memory_banks: MemoryBanksResource) -> None: - self._memory_banks = memory_banks +class MemoryResourceWithStreamingResponse: + def __init__(self, memory: MemoryResource) -> None: + self._memory = memory self.create = to_streamed_response_wrapper( - memory_banks.create, + memory.create, ) self.retrieve = to_streamed_response_wrapper( - memory_banks.retrieve, + memory.retrieve, ) self.update = to_streamed_response_wrapper( - memory_banks.update, + memory.update, ) self.list = to_streamed_response_wrapper( - memory_banks.list, + memory.list, ) self.drop = to_streamed_response_wrapper( - memory_banks.drop, + memory.drop, ) 
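Taken together, the renamed methods above give the memory resource a full bank lifecycle: create, insert with an optional TTL, query, and drop. A sketch under the same client assumptions; the create body and the document field names are assumptions, since this patch types them only as object and Document:

from llama_stack import LlamaStack

client = LlamaStack(base_url="http://localhost:5000")

bank = client.memory.create(body={"name": "docs"})  # body shape assumed

client.memory.insert(
    bank_id="bank-0",  # hypothetical id
    documents=[{"document_id": "doc-1", "content": "hello world"}],  # field names assumed
    ttl_seconds=3600,
)
hits = client.memory.query(bank_id="bank-0", query="hello", params={"top_k": 2})  # params keys assumed
client.memory.drop(bank_id="bank-0")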
self.insert = to_streamed_response_wrapper( -            memory_banks.insert, +            memory.insert, ) self.query = to_streamed_response_wrapper( -            memory_banks.query, +            memory.query, ) @cached_property def documents(self) -> DocumentsResourceWithStreamingResponse: -        return DocumentsResourceWithStreamingResponse(self._memory_banks.documents) +        return DocumentsResourceWithStreamingResponse(self._memory.documents) -class AsyncMemoryBanksResourceWithStreamingResponse: -    def __init__(self, memory_banks: AsyncMemoryBanksResource) -> None: -        self._memory_banks = memory_banks +class AsyncMemoryResourceWithStreamingResponse: +    def __init__(self, memory: AsyncMemoryResource) -> None: +        self._memory = memory self.create = async_to_streamed_response_wrapper( -            memory_banks.create, +            memory.create, ) self.retrieve = async_to_streamed_response_wrapper( -            memory_banks.retrieve, +            memory.retrieve, ) self.update = async_to_streamed_response_wrapper( -            memory_banks.update, +            memory.update, ) self.list = async_to_streamed_response_wrapper( -            memory_banks.list, +            memory.list, ) self.drop = async_to_streamed_response_wrapper( -            memory_banks.drop, +            memory.drop, ) self.insert = async_to_streamed_response_wrapper( -            memory_banks.insert, +            memory.insert, ) self.query = async_to_streamed_response_wrapper( -            memory_banks.query, +            memory.query, ) @cached_property def documents(self) -> AsyncDocumentsResourceWithStreamingResponse: -        return AsyncDocumentsResourceWithStreamingResponse(self._memory_banks.documents) +        return AsyncDocumentsResourceWithStreamingResponse(self._memory.documents) diff --git a/src/llama_stack/resources/memory_banks.py b/src/llama_stack/resources/memory_banks.py new file mode 100644 index 0000000..294c309 --- /dev/null +++ b/src/llama_stack/resources/memory_banks.py @@ -0,0 +1,262 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing import Optional +from typing_extensions import Literal + +import httpx + +from ..types import memory_bank_get_params +from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven +from .._utils import ( +    maybe_transform, +    strip_not_given, +    async_maybe_transform, +) +from .._compat import cached_property +from .._resource import SyncAPIResource, AsyncAPIResource +from .._response import ( +    to_raw_response_wrapper, +    to_streamed_response_wrapper, +    async_to_raw_response_wrapper, +    async_to_streamed_response_wrapper, +) +from .._base_client import make_request_options +from ..types.memory_bank_spec import MemoryBankSpec + +__all__ = ["MemoryBanksResource", "AsyncMemoryBanksResource"] + + +class MemoryBanksResource(SyncAPIResource): +    @cached_property +    def with_raw_response(self) -> MemoryBanksResourceWithRawResponse: +        """ +        This property can be used as a prefix for any HTTP method call to return +        the raw response object instead of the parsed content. + +        For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers +        """ +        return MemoryBanksResourceWithRawResponse(self) + +    @cached_property +    def with_streaming_response(self) -> MemoryBanksResourceWithStreamingResponse: +        """ +        An alternative to `.with_raw_response` that doesn't eagerly read the response body.
+ +        For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response +        """ +        return MemoryBanksResourceWithStreamingResponse(self) + +    def list( +        self, +        *, +        x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, +        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. +        # The extra values given here take precedence over values defined on the client or passed to this method. +        extra_headers: Headers | None = None, +        extra_query: Query | None = None, +        extra_body: Body | None = None, +        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, +    ) -> MemoryBankSpec: +        """ +        Args: +          extra_headers: Send extra headers + +          extra_query: Add additional query parameters to the request + +          extra_body: Add additional JSON properties to the request + +          timeout: Override the client-level default timeout for this request, in seconds +        """ +        extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})} +        extra_headers = { +            **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), +            **(extra_headers or {}), +        } +        return self._get( +            "/memory_banks/list", +            options=make_request_options( +                extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout +            ), +            cast_to=MemoryBankSpec, +        ) + +    def get( +        self, +        *, +        bank_type: Literal["vector", "keyvalue", "keyword", "graph"], +        x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, +        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. +        # The extra values given here take precedence over values defined on the client or passed to this method. +        extra_headers: Headers | None = None, +        extra_query: Query | None = None, +        extra_body: Body | None = None, +        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, +    ) -> Optional[MemoryBankSpec]: +        """ +        Args: +          extra_headers: Send extra headers + +          extra_query: Add additional query parameters to the request + +          extra_body: Add additional JSON properties to the request + +          timeout: Override the client-level default timeout for this request, in seconds +        """ +        extra_headers = { +            **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), +            **(extra_headers or {}), +        } +        return self._get( +            "/memory_banks/get", +            options=make_request_options( +                extra_headers=extra_headers, +                extra_query=extra_query, +                extra_body=extra_body, +                timeout=timeout, +                query=maybe_transform({"bank_type": bank_type}, memory_bank_get_params.MemoryBankGetParams), +            ), +            cast_to=MemoryBankSpec, +        ) + + +class AsyncMemoryBanksResource(AsyncAPIResource): +    @cached_property +    def with_raw_response(self) -> AsyncMemoryBanksResourceWithRawResponse: +        """ +        This property can be used as a prefix for any HTTP method call to return +        the raw response object instead of the parsed content. + +        For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers +        """ +        return AsyncMemoryBanksResourceWithRawResponse(self) + +    @cached_property +    def with_streaming_response(self) -> AsyncMemoryBanksResourceWithStreamingResponse: +        """ +        An alternative to `.with_raw_response` that doesn't eagerly read the response body.
+ + For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response + """ + return AsyncMemoryBanksResourceWithStreamingResponse(self) + + async def list( + self, + *, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> MemoryBankSpec: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return await self._get( + "/memory_banks/list", + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=MemoryBankSpec, + ) + + async def get( + self, + *, + bank_type: Literal["vector", "keyvalue", "keyword", "graph"], + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Optional[MemoryBankSpec]: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return await self._get( + "/memory_banks/get", + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=await async_maybe_transform({"bank_type": bank_type}, memory_bank_get_params.MemoryBankGetParams), + ), + cast_to=MemoryBankSpec, + ) + + +class MemoryBanksResourceWithRawResponse: + def __init__(self, memory_banks: MemoryBanksResource) -> None: + self._memory_banks = memory_banks + + self.list = to_raw_response_wrapper( + memory_banks.list, + ) + self.get = to_raw_response_wrapper( + memory_banks.get, + ) + + +class AsyncMemoryBanksResourceWithRawResponse: + def __init__(self, memory_banks: AsyncMemoryBanksResource) -> None: + self._memory_banks = memory_banks + + self.list = async_to_raw_response_wrapper( + memory_banks.list, + ) + self.get = async_to_raw_response_wrapper( + memory_banks.get, + ) + + +class MemoryBanksResourceWithStreamingResponse: + def __init__(self, memory_banks: MemoryBanksResource) -> None: + self._memory_banks = memory_banks + + self.list = to_streamed_response_wrapper( + memory_banks.list, + ) + self.get = to_streamed_response_wrapper( + memory_banks.get, + 
) + + +class AsyncMemoryBanksResourceWithStreamingResponse: + def __init__(self, memory_banks: AsyncMemoryBanksResource) -> None: + self._memory_banks = memory_banks + + self.list = async_to_streamed_response_wrapper( + memory_banks.list, + ) + self.get = async_to_streamed_response_wrapper( + memory_banks.get, + ) diff --git a/src/llama_stack/resources/models.py b/src/llama_stack/resources/models.py new file mode 100644 index 0000000..29f435c --- /dev/null +++ b/src/llama_stack/resources/models.py @@ -0,0 +1,261 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing import Optional + +import httpx + +from ..types import model_get_params +from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven +from .._utils import ( + maybe_transform, + strip_not_given, + async_maybe_transform, +) +from .._compat import cached_property +from .._resource import SyncAPIResource, AsyncAPIResource +from .._response import ( + to_raw_response_wrapper, + to_streamed_response_wrapper, + async_to_raw_response_wrapper, + async_to_streamed_response_wrapper, +) +from .._base_client import make_request_options +from ..types.model_serving_spec import ModelServingSpec + +__all__ = ["ModelsResource", "AsyncModelsResource"] + + +class ModelsResource(SyncAPIResource): + @cached_property + def with_raw_response(self) -> ModelsResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers + """ + return ModelsResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> ModelsResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response + """ + return ModelsResourceWithStreamingResponse(self) + + def list( + self, + *, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> ModelServingSpec: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return self._get( + "/models/list", + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=ModelServingSpec, + ) + + def get( + self, + *, + core_model_id: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
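
For orientation, here is a minimal usage sketch of the memory_banks registry added above. It assumes the package exposes a LlamaStack client class and a local base URL; neither detail is confirmed by this patch, and only the list and get methods shown are actually added by the new file.

from llama_stack import LlamaStack  # client class name assumed, not shown in this patch

client = LlamaStack(base_url="http://localhost:5000")  # illustrative base URL

# GET /memory_banks/list — the server answers with application/jsonl.
banks = client.memory_banks.list()

# GET /memory_banks/get?bank_type=vector — returns None when no bank matches.
vector_bank = client.memory_banks.get(bank_type="vector")
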
+ # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Optional[ModelServingSpec]: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return self._get( + "/models/get", + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=maybe_transform({"core_model_id": core_model_id}, model_get_params.ModelGetParams), + ), + cast_to=ModelServingSpec, + ) + + +class AsyncModelsResource(AsyncAPIResource): + @cached_property + def with_raw_response(self) -> AsyncModelsResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers + """ + return AsyncModelsResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> AsyncModelsResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response + """ + return AsyncModelsResourceWithStreamingResponse(self) + + async def list( + self, + *, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> ModelServingSpec: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return await self._get( + "/models/list", + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=ModelServingSpec, + ) + + async def get( + self, + *, + core_model_id: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Optional[ModelServingSpec]: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return await self._get( + "/models/get", + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=await async_maybe_transform({"core_model_id": core_model_id}, model_get_params.ModelGetParams), + ), + cast_to=ModelServingSpec, + ) + + +class ModelsResourceWithRawResponse: + def __init__(self, models: ModelsResource) -> None: + self._models = models + + self.list = to_raw_response_wrapper( + models.list, + ) + self.get = to_raw_response_wrapper( + models.get, + ) + + +class AsyncModelsResourceWithRawResponse: + def __init__(self, models: AsyncModelsResource) -> None: + self._models = models + + self.list = async_to_raw_response_wrapper( + models.list, + ) + self.get = async_to_raw_response_wrapper( + models.get, + ) + + +class ModelsResourceWithStreamingResponse: + def __init__(self, models: ModelsResource) -> None: + self._models = models + + self.list = to_streamed_response_wrapper( + models.list, + ) + self.get = to_streamed_response_wrapper( + models.get, + ) + + +class AsyncModelsResourceWithStreamingResponse: + def __init__(self, models: AsyncModelsResource) -> None: + self._models = models + + self.list = async_to_streamed_response_wrapper( + models.list, + ) + self.get = async_to_streamed_response_wrapper( + models.get, + ) diff --git a/src/llama_stack/resources/post_training/jobs.py b/src/llama_stack/resources/post_training/jobs.py index e3e4632..840b2e7 100644 --- a/src/llama_stack/resources/post_training/jobs.py +++ b/src/llama_stack/resources/post_training/jobs.py @@ -7,6 +7,7 @@ from ..._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven from ..._utils import ( maybe_transform, + strip_not_given, async_maybe_transform, ) from ..._compat import cached_property @@ -50,6 +51,7 @@ def with_streaming_response(self) -> JobsResourceWithStreamingResponse: def list( self, *, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
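
The models registry just defined mirrors memory_banks: list enumerates ModelServingSpec entries and get filters on core_model_id. A sketch under the same assumed client; the model identifier is a placeholder, not taken from this patch:

from llama_stack import LlamaStack  # client class name assumed

client = LlamaStack(base_url="http://localhost:5000")

models = client.models.list()  # GET /models/list

spec = client.models.get(
    core_model_id="Llama3.1-8B-Instruct",  # placeholder identifier
    # Optional per-request metadata, sent as the X-LlamaStack-ProviderData header:
    x_llama_stack_provider_data='{"example_key": "example_value"}',
)
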
extra_headers: Headers | None = None, @@ -57,7 +59,21 @@ def list( extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> PostTrainingJob: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._get( "/post_training/jobs", options=make_request_options( @@ -70,6 +86,7 @@ def artifacts( self, *, job_uuid: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -87,6 +104,10 @@ def artifacts( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._get( "/post_training/job/artifacts", options=make_request_options( @@ -103,6 +124,7 @@ def cancel( self, *, job_uuid: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -121,6 +143,10 @@ def cancel( timeout: Override the client-level default timeout for this request, in seconds """ extra_headers = {"Accept": "*/*", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._post( "/post_training/job/cancel", body=maybe_transform({"job_uuid": job_uuid}, job_cancel_params.JobCancelParams), @@ -134,6 +160,7 @@ def logs( self, *, job_uuid: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -151,6 +178,10 @@ def logs( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._get( "/post_training/job/logs", options=make_request_options( @@ -167,6 +198,7 @@ def status( self, *, job_uuid: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, @@ -184,6 +216,10 @@ def status( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._get( "/post_training/job/status", options=make_request_options( @@ -220,6 +256,7 @@ def with_streaming_response(self) -> AsyncJobsResourceWithStreamingResponse: async def list( self, *, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -227,7 +264,21 @@ async def list( extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> PostTrainingJob: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._get( "/post_training/jobs", options=make_request_options( @@ -240,6 +291,7 @@ async def artifacts( self, *, job_uuid: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -257,6 +309,10 @@ async def artifacts( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._get( "/post_training/job/artifacts", options=make_request_options( @@ -273,6 +329,7 @@ async def cancel( self, *, job_uuid: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -291,6 +348,10 @@ async def cancel( timeout: Override the client-level default timeout for this request, in seconds """ extra_headers = {"Accept": "*/*", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._post( "/post_training/job/cancel", body=await async_maybe_transform({"job_uuid": job_uuid}, job_cancel_params.JobCancelParams), @@ -304,6 +365,7 @@ async def logs( self, *, job_uuid: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
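
Taken together, the job endpoints in this file compose into a simple monitoring flow. A sketch, with a placeholder job UUID and the assumed client from the earlier examples:

from llama_stack import LlamaStack  # client class name assumed

client = LlamaStack(base_url="http://localhost:5000")

jobs = client.post_training.jobs.list()                        # GET /post_training/jobs
status = client.post_training.jobs.status(job_uuid="job-123")  # GET /post_training/job/status
logs = client.post_training.jobs.logs(job_uuid="job-123")      # GET /post_training/job/logs
client.post_training.jobs.cancel(job_uuid="job-123")           # POST /post_training/job/cancel
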
extra_headers: Headers | None = None, @@ -321,6 +383,10 @@ async def logs( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._get( "/post_training/job/logs", options=make_request_options( @@ -337,6 +403,7 @@ async def status( self, *, job_uuid: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -354,6 +421,10 @@ async def status( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._get( "/post_training/job/status", options=make_request_options( diff --git a/src/llama_stack/resources/post_training/post_training.py b/src/llama_stack/resources/post_training/post_training.py index db5aac8..8863e6b 100644 --- a/src/llama_stack/resources/post_training/post_training.py +++ b/src/llama_stack/resources/post_training/post_training.py @@ -22,6 +22,7 @@ from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven from ..._utils import ( maybe_transform, + strip_not_given, async_maybe_transform, ) from ..._compat import cached_property @@ -76,6 +77,7 @@ def preference_optimize( optimizer_config: post_training_preference_optimize_params.OptimizerConfig, training_config: post_training_preference_optimize_params.TrainingConfig, validation_dataset: TrainEvalDatasetParam, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -93,6 +95,10 @@ def preference_optimize( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._post( "/post_training/preference_optimize", body=maybe_transform( @@ -129,6 +135,7 @@ def supervised_fine_tune( optimizer_config: post_training_supervised_fine_tune_params.OptimizerConfig, training_config: post_training_supervised_fine_tune_params.TrainingConfig, validation_dataset: TrainEvalDatasetParam, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, @@ -146,6 +153,10 @@ def supervised_fine_tune( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._post( "/post_training/supervised_fine_tune", body=maybe_transform( @@ -207,6 +218,7 @@ async def preference_optimize( optimizer_config: post_training_preference_optimize_params.OptimizerConfig, training_config: post_training_preference_optimize_params.TrainingConfig, validation_dataset: TrainEvalDatasetParam, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -224,6 +236,10 @@ async def preference_optimize( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._post( "/post_training/preference_optimize", body=await async_maybe_transform( @@ -260,6 +276,7 @@ async def supervised_fine_tune( optimizer_config: post_training_supervised_fine_tune_params.OptimizerConfig, training_config: post_training_supervised_fine_tune_params.TrainingConfig, validation_dataset: TrainEvalDatasetParam, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -277,6 +294,10 @@ async def supervised_fine_tune( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._post( "/post_training/supervised_fine_tune", body=await async_maybe_transform( diff --git a/src/llama_stack/resources/reward_scoring.py b/src/llama_stack/resources/reward_scoring.py index bed1479..3e55287 100644 --- a/src/llama_stack/resources/reward_scoring.py +++ b/src/llama_stack/resources/reward_scoring.py @@ -10,6 +10,7 @@ from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven from .._utils import ( maybe_transform, + strip_not_given, async_maybe_transform, ) from .._compat import cached_property @@ -51,6 +52,7 @@ def score( *, dialog_generations: Iterable[reward_scoring_score_params.DialogGeneration], model: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
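
Every method touched by this patch repeats the same header merge: strip_not_given drops mapping entries whose value is NOT_GIVEN, so X-LlamaStack-ProviderData is emitted only when the caller supplied it, and explicit extra_headers still win in the dict-unpacking order. A behavioral sketch, assuming the helper keeps the semantics its call sites here rely on:

from llama_stack._types import NOT_GIVEN
from llama_stack._utils import strip_not_given

# Omitted argument: the entry is stripped and no header is sent.
assert strip_not_given({"X-LlamaStack-ProviderData": NOT_GIVEN}) == {}

# Supplied argument: the entry survives and is merged beneath extra_headers.
assert strip_not_given({"X-LlamaStack-ProviderData": "tenant-a"}) == {"X-LlamaStack-ProviderData": "tenant-a"}
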
extra_headers: Headers | None = None, @@ -68,6 +70,10 @@ def score( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._post( "/reward_scoring/score", body=maybe_transform( @@ -109,6 +115,7 @@ async def score( *, dialog_generations: Iterable[reward_scoring_score_params.DialogGeneration], model: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -126,6 +133,10 @@ async def score( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._post( "/reward_scoring/score", body=await async_maybe_transform( diff --git a/src/llama_stack/resources/safety.py b/src/llama_stack/resources/safety.py index 3943f13..2bc3022 100644 --- a/src/llama_stack/resources/safety.py +++ b/src/llama_stack/resources/safety.py @@ -2,14 +2,15 @@ from __future__ import annotations -from typing import Iterable +from typing import Dict, Union, Iterable import httpx -from ..types import safety_run_shields_params +from ..types import safety_run_shield_params from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven from .._utils import ( maybe_transform, + strip_not_given, async_maybe_transform, ) from .._compat import cached_property @@ -21,8 +22,7 @@ async_to_streamed_response_wrapper, ) from .._base_client import make_request_options -from ..types.shield_definition_param import ShieldDefinitionParam -from ..types.safety_run_shields_response import SafetyRunShieldsResponse +from ..types.run_sheid_response import RunSheidResponse __all__ = ["SafetyResource", "AsyncSafetyResource"] @@ -47,18 +47,20 @@ def with_streaming_response(self) -> SafetyResourceWithStreamingResponse: """ return SafetyResourceWithStreamingResponse(self) - def run_shields( + def run_shield( self, *, - messages: Iterable[safety_run_shields_params.Message], - shields: Iterable[ShieldDefinitionParam], + messages: Iterable[safety_run_shield_params.Message], + params: Dict[str, Union[bool, float, str, Iterable[object], object, None]], + shield_type: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
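
The rename from run_shields to run_shield also changes the calling convention: instead of inline ShieldDefinitionParam objects, callers now name a registered shield_type and pass a free-form params dict. A sketch with a placeholder shield_type, reusing the assumed client from the sketches above:

response = client.safety.run_shield(  # POST /safety/run_shield, returns RunSheidResponse
    messages=[{"role": "user", "content": "Is this request safe to serve?"}],
    shield_type="llama_guard",  # placeholder; any shield the stack registers
    params={},
)
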
extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> SafetyRunShieldsResponse: + ) -> RunSheidResponse: """ Args: extra_headers: Send extra headers @@ -69,19 +71,24 @@ def run_shields( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._post( - "/safety/run_shields", + "/safety/run_shield", body=maybe_transform( { "messages": messages, - "shields": shields, + "params": params, + "shield_type": shield_type, }, - safety_run_shields_params.SafetyRunShieldsParams, + safety_run_shield_params.SafetyRunShieldParams, ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), - cast_to=SafetyRunShieldsResponse, + cast_to=RunSheidResponse, ) @@ -105,18 +112,20 @@ def with_streaming_response(self) -> AsyncSafetyResourceWithStreamingResponse: """ return AsyncSafetyResourceWithStreamingResponse(self) - async def run_shields( + async def run_shield( self, *, - messages: Iterable[safety_run_shields_params.Message], - shields: Iterable[ShieldDefinitionParam], + messages: Iterable[safety_run_shield_params.Message], + params: Dict[str, Union[bool, float, str, Iterable[object], object, None]], + shield_type: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> SafetyRunShieldsResponse: + ) -> RunSheidResponse: """ Args: extra_headers: Send extra headers @@ -127,19 +136,24 @@ async def run_shields( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._post( - "/safety/run_shields", + "/safety/run_shield", body=await async_maybe_transform( { "messages": messages, - "shields": shields, + "params": params, + "shield_type": shield_type, }, - safety_run_shields_params.SafetyRunShieldsParams, + safety_run_shield_params.SafetyRunShieldParams, ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), - cast_to=SafetyRunShieldsResponse, + cast_to=RunSheidResponse, ) @@ -147,8 +161,8 @@ class SafetyResourceWithRawResponse: def __init__(self, safety: SafetyResource) -> None: self._safety = safety - self.run_shields = to_raw_response_wrapper( - safety.run_shields, + self.run_shield = to_raw_response_wrapper( + safety.run_shield, ) @@ -156,8 +170,8 @@ class AsyncSafetyResourceWithRawResponse: def __init__(self, safety: AsyncSafetyResource) -> None: self._safety = safety - self.run_shields = async_to_raw_response_wrapper( - safety.run_shields, + self.run_shield = async_to_raw_response_wrapper( + safety.run_shield, ) @@ -165,8 +179,8 @@ class SafetyResourceWithStreamingResponse: def __init__(self, safety: SafetyResource) -> None: self._safety = safety - self.run_shields = to_streamed_response_wrapper( - 
safety.run_shields, + self.run_shield = to_streamed_response_wrapper( + safety.run_shield, ) @@ -174,6 +188,6 @@ class AsyncSafetyResourceWithStreamingResponse: def __init__(self, safety: AsyncSafetyResource) -> None: self._safety = safety - self.run_shields = async_to_streamed_response_wrapper( - safety.run_shields, + self.run_shield = async_to_streamed_response_wrapper( + safety.run_shield, ) diff --git a/src/llama_stack/resources/shields.py b/src/llama_stack/resources/shields.py new file mode 100644 index 0000000..bc800de --- /dev/null +++ b/src/llama_stack/resources/shields.py @@ -0,0 +1,261 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing import Optional + +import httpx + +from ..types import shield_get_params +from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven +from .._utils import ( + maybe_transform, + strip_not_given, + async_maybe_transform, +) +from .._compat import cached_property +from .._resource import SyncAPIResource, AsyncAPIResource +from .._response import ( + to_raw_response_wrapper, + to_streamed_response_wrapper, + async_to_raw_response_wrapper, + async_to_streamed_response_wrapper, +) +from .._base_client import make_request_options +from ..types.shield_spec import ShieldSpec + +__all__ = ["ShieldsResource", "AsyncShieldsResource"] + + +class ShieldsResource(SyncAPIResource): + @cached_property + def with_raw_response(self) -> ShieldsResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers + """ + return ShieldsResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> ShieldsResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response + """ + return ShieldsResourceWithStreamingResponse(self) + + def list( + self, + *, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method.
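
The new shields registry pairs naturally with safety.run_shield above: discover what the distribution serves, then address a shield by its type string. A sketch, again with the assumed client and a placeholder shield_type:

shield_specs = client.shields.list()                  # GET /shields/list (application/jsonl)
spec = client.shields.get(shield_type="llama_guard")  # GET /shields/get; None if unregistered
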
+ extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> ShieldSpec: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return self._get( + "/shields/list", + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=ShieldSpec, + ) + + def get( + self, + *, + shield_type: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Optional[ShieldSpec]: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return self._get( + "/shields/get", + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=maybe_transform({"shield_type": shield_type}, shield_get_params.ShieldGetParams), + ), + cast_to=ShieldSpec, + ) + + +class AsyncShieldsResource(AsyncAPIResource): + @cached_property + def with_raw_response(self) -> AsyncShieldsResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers + """ + return AsyncShieldsResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> AsyncShieldsResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response + """ + return AsyncShieldsResourceWithStreamingResponse(self) + + async def list( + self, + *, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> ShieldSpec: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return await self._get( + "/shields/list", + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=ShieldSpec, + ) + + async def get( + self, + *, + shield_type: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Optional[ShieldSpec]: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return await self._get( + "/shields/get", + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=await async_maybe_transform({"shield_type": shield_type}, shield_get_params.ShieldGetParams), + ), + cast_to=ShieldSpec, + ) + + +class ShieldsResourceWithRawResponse: + def __init__(self, shields: ShieldsResource) -> None: + self._shields = shields + + self.list = to_raw_response_wrapper( + shields.list, + ) + self.get = to_raw_response_wrapper( + shields.get, + ) + + +class AsyncShieldsResourceWithRawResponse: + def __init__(self, shields: AsyncShieldsResource) -> None: + self._shields = shields + + self.list = async_to_raw_response_wrapper( + shields.list, + ) + self.get = async_to_raw_response_wrapper( + shields.get, + ) + + +class ShieldsResourceWithStreamingResponse: + def __init__(self, shields: ShieldsResource) -> None: + self._shields = shields + + self.list = to_streamed_response_wrapper( + shields.list, + ) + self.get = to_streamed_response_wrapper( + shields.get, + ) + + +class AsyncShieldsResourceWithStreamingResponse: + def __init__(self, shields: AsyncShieldsResource) -> None: + self._shields = shields + + self.list = async_to_streamed_response_wrapper( + shields.list, + ) + self.get = async_to_streamed_response_wrapper( + shields.get, + ) diff --git a/src/llama_stack/resources/synthetic_data_generation.py b/src/llama_stack/resources/synthetic_data_generation.py index 0847308..d13532c 100644 --- a/src/llama_stack/resources/synthetic_data_generation.py +++ b/src/llama_stack/resources/synthetic_data_generation.py @@ -11,6 +11,7 @@ from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven from .._utils import ( 
maybe_transform, + strip_not_given, async_maybe_transform, ) from .._compat import cached_property @@ -53,6 +54,7 @@ def generate( dialogs: Iterable[synthetic_data_generation_generate_params.Dialog], filtering_function: Literal["none", "random", "top_k", "top_p", "top_k_top_p", "sigmoid"], model: str | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -70,6 +72,10 @@ def generate( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._post( "/synthetic_data_generation/generate", body=maybe_transform( @@ -113,6 +119,7 @@ async def generate( dialogs: Iterable[synthetic_data_generation_generate_params.Dialog], filtering_function: Literal["none", "random", "top_k", "top_p", "top_k_top_p", "sigmoid"], model: str | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -130,6 +137,10 @@ async def generate( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._post( "/synthetic_data_generation/generate", body=await async_maybe_transform( diff --git a/src/llama_stack/resources/telemetry.py b/src/llama_stack/resources/telemetry.py index b3e0524..4526dd7 100644 --- a/src/llama_stack/resources/telemetry.py +++ b/src/llama_stack/resources/telemetry.py @@ -8,6 +8,7 @@ from .._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven from .._utils import ( maybe_transform, + strip_not_given, async_maybe_transform, ) from .._compat import cached_property @@ -48,6 +49,7 @@ def get_trace( self, *, trace_id: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -65,6 +67,10 @@ def get_trace( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._get( "/telemetry/get_trace", options=make_request_options( @@ -81,6 +87,7 @@ def log( self, *, event: telemetry_log_params.Event, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
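
Telemetry picks up the same optional provider-data header. Retrieving a trace is a plain GET keyed by trace_id; a sketch with a placeholder id and the assumed client from the earlier examples:

trace = client.telemetry.get_trace(trace_id="trace-123")  # GET /telemetry/get_trace
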
extra_headers: Headers | None = None, @@ -99,6 +106,10 @@ def log( timeout: Override the client-level default timeout for this request, in seconds """ extra_headers = {"Accept": "*/*", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return self._post( "/telemetry/log_event", body=maybe_transform({"event": event}, telemetry_log_params.TelemetryLogParams), @@ -133,6 +144,7 @@ async def get_trace( self, *, trace_id: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -150,6 +162,10 @@ async def get_trace( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._get( "/telemetry/get_trace", options=make_request_options( @@ -168,6 +184,7 @@ async def log( self, *, event: telemetry_log_params.Event, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -186,6 +203,10 @@ async def log( timeout: Override the client-level default timeout for this request, in seconds """ extra_headers = {"Accept": "*/*", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } return await self._post( "/telemetry/log_event", body=await async_maybe_transform({"event": event}, telemetry_log_params.TelemetryLogParams), diff --git a/src/llama_stack/types/__init__.py b/src/llama_stack/types/__init__.py index 8f43ce3..452da12 100644 --- a/src/llama_stack/types/__init__.py +++ b/src/llama_stack/types/__init__.py @@ -12,42 +12,46 @@ CompletionMessage as CompletionMessage, ToolResponseMessage as ToolResponseMessage, ) +from .shield_spec import ShieldSpec as ShieldSpec from .evaluation_job import EvaluationJob as EvaluationJob from .inference_step import InferenceStep as InferenceStep from .reward_scoring import RewardScoring as RewardScoring -from .sheid_response import SheidResponse as SheidResponse from .query_documents import QueryDocuments as QueryDocuments from .token_log_probs import TokenLogProbs as TokenLogProbs +from .memory_bank_spec import MemoryBankSpec as MemoryBankSpec +from .model_get_params import ModelGetParams as ModelGetParams from .shield_call_step import ShieldCallStep as ShieldCallStep from .post_training_job import PostTrainingJob as PostTrainingJob +from .shield_get_params import ShieldGetParams as ShieldGetParams from .dataset_get_params import DatasetGetParams as DatasetGetParams +from .memory_drop_params import MemoryDropParams as MemoryDropParams +from .model_serving_spec import ModelServingSpec as ModelServingSpec +from .run_sheid_response import RunSheidResponse as RunSheidResponse from .train_eval_dataset import TrainEvalDataset as TrainEvalDataset from .agent_create_params import AgentCreateParams as AgentCreateParams from .agent_delete_params import AgentDeleteParams as 
AgentDeleteParams +from .memory_query_params import MemoryQueryParams as MemoryQueryParams from .tool_execution_step import ToolExecutionStep as ToolExecutionStep +from .memory_create_params import MemoryCreateParams as MemoryCreateParams +from .memory_drop_response import MemoryDropResponse as MemoryDropResponse +from .memory_insert_params import MemoryInsertParams as MemoryInsertParams +from .memory_update_params import MemoryUpdateParams as MemoryUpdateParams from .telemetry_log_params import TelemetryLogParams as TelemetryLogParams from .agent_create_response import AgentCreateResponse as AgentCreateResponse from .batch_chat_completion import BatchChatCompletion as BatchChatCompletion from .dataset_create_params import DatasetCreateParams as DatasetCreateParams from .dataset_delete_params import DatasetDeleteParams as DatasetDeleteParams from .memory_retrieval_step import MemoryRetrievalStep as MemoryRetrievalStep +from .memory_bank_get_params import MemoryBankGetParams as MemoryBankGetParams +from .memory_retrieve_params import MemoryRetrieveParams as MemoryRetrieveParams from .completion_stream_chunk import CompletionStreamChunk as CompletionStreamChunk -from .memory_bank_drop_params import MemoryBankDropParams as MemoryBankDropParams -from .shield_definition_param import ShieldDefinitionParam as ShieldDefinitionParam -from .memory_bank_query_params import MemoryBankQueryParams as MemoryBankQueryParams +from .safety_run_shield_params import SafetyRunShieldParams as SafetyRunShieldParams from .train_eval_dataset_param import TrainEvalDatasetParam as TrainEvalDatasetParam -from .memory_bank_create_params import MemoryBankCreateParams as MemoryBankCreateParams -from .memory_bank_drop_response import MemoryBankDropResponse as MemoryBankDropResponse -from .memory_bank_insert_params import MemoryBankInsertParams as MemoryBankInsertParams -from .memory_bank_update_params import MemoryBankUpdateParams as MemoryBankUpdateParams -from .safety_run_shields_params import SafetyRunShieldsParams as SafetyRunShieldsParams from .scored_dialog_generations import ScoredDialogGenerations as ScoredDialogGenerations from .synthetic_data_generation import SyntheticDataGeneration as SyntheticDataGeneration from .telemetry_get_trace_params import TelemetryGetTraceParams as TelemetryGetTraceParams from .inference_completion_params import InferenceCompletionParams as InferenceCompletionParams -from .memory_bank_retrieve_params import MemoryBankRetrieveParams as MemoryBankRetrieveParams from .reward_scoring_score_params import RewardScoringScoreParams as RewardScoringScoreParams -from .safety_run_shields_response import SafetyRunShieldsResponse as SafetyRunShieldsResponse from .tool_param_definition_param import ToolParamDefinitionParam as ToolParamDefinitionParam from .chat_completion_stream_chunk import ChatCompletionStreamChunk as ChatCompletionStreamChunk from .telemetry_get_trace_response import TelemetryGetTraceResponse as TelemetryGetTraceResponse @@ -55,12 +59,9 @@ from .evaluation_summarization_params import EvaluationSummarizationParams as EvaluationSummarizationParams from .rest_api_execution_config_param import RestAPIExecutionConfigParam as RestAPIExecutionConfigParam from .inference_chat_completion_params import InferenceChatCompletionParams as InferenceChatCompletionParams -from .llm_query_generator_config_param import LlmQueryGeneratorConfigParam as LlmQueryGeneratorConfigParam from .batch_inference_completion_params import BatchInferenceCompletionParams as BatchInferenceCompletionParams from 
.evaluation_text_generation_params import EvaluationTextGenerationParams as EvaluationTextGenerationParams from .inference_chat_completion_response import InferenceChatCompletionResponse as InferenceChatCompletionResponse -from .custom_query_generator_config_param import CustomQueryGeneratorConfigParam as CustomQueryGeneratorConfigParam -from .default_query_generator_config_param import DefaultQueryGeneratorConfigParam as DefaultQueryGeneratorConfigParam from .batch_inference_chat_completion_params import ( BatchInferenceChatCompletionParams as BatchInferenceChatCompletionParams, ) diff --git a/src/llama_stack/types/agent_create_params.py b/src/llama_stack/types/agent_create_params.py index 39738a3..27d4705 100644 --- a/src/llama_stack/types/agent_create_params.py +++ b/src/llama_stack/types/agent_create_params.py @@ -3,15 +3,12 @@ from __future__ import annotations from typing import Dict, List, Union, Iterable -from typing_extensions import Literal, Required, TypeAlias, TypedDict +from typing_extensions import Literal, Required, Annotated, TypeAlias, TypedDict -from .shield_definition_param import ShieldDefinitionParam +from .._utils import PropertyInfo from .tool_param_definition_param import ToolParamDefinitionParam from .shared_params.sampling_params import SamplingParams from .rest_api_execution_config_param import RestAPIExecutionConfigParam -from .llm_query_generator_config_param import LlmQueryGeneratorConfigParam -from .custom_query_generator_config_param import CustomQueryGeneratorConfigParam -from .default_query_generator_config_param import DefaultQueryGeneratorConfigParam __all__ = [ "AgentCreateParams", @@ -22,19 +19,24 @@ "AgentConfigToolPhotogenToolDefinition", "AgentConfigToolCodeInterpreterToolDefinition", "AgentConfigToolFunctionCallToolDefinition", - "AgentConfigToolShield", - "AgentConfigToolShieldMemoryBankConfig", - "AgentConfigToolShieldMemoryBankConfigVector", - "AgentConfigToolShieldMemoryBankConfigKeyvalue", - "AgentConfigToolShieldMemoryBankConfigKeyword", - "AgentConfigToolShieldMemoryBankConfigGraph", - "AgentConfigToolShieldQueryGeneratorConfig", + "AgentConfigToolMemoryToolDefinition", + "AgentConfigToolMemoryToolDefinitionMemoryBankConfig", + "AgentConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember0", + "AgentConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember1", + "AgentConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember2", + "AgentConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember3", + "AgentConfigToolMemoryToolDefinitionQueryGeneratorConfig", + "AgentConfigToolMemoryToolDefinitionQueryGeneratorConfigUnionMember0", + "AgentConfigToolMemoryToolDefinitionQueryGeneratorConfigUnionMember1", + "AgentConfigToolMemoryToolDefinitionQueryGeneratorConfigType", ] class AgentCreateParams(TypedDict, total=False): agent_config: Required[AgentConfig] + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] + class AgentConfigToolSearchToolDefinition(TypedDict, total=False): api_key: Required[str] @@ -43,9 +45,9 @@ class AgentConfigToolSearchToolDefinition(TypedDict, total=False): type: Required[Literal["brave_search"]] - input_shields: Iterable[ShieldDefinitionParam] + input_shields: List[str] - output_shields: Iterable[ShieldDefinitionParam] + output_shields: List[str] remote_execution: RestAPIExecutionConfigParam @@ -55,9 +57,9 @@ class AgentConfigToolWolframAlphaToolDefinition(TypedDict, total=False): type: Required[Literal["wolfram_alpha"]] - input_shields: Iterable[ShieldDefinitionParam] + 
input_shields: List[str] - output_shields: Iterable[ShieldDefinitionParam] + output_shields: List[str] remote_execution: RestAPIExecutionConfigParam @@ -65,9 +67,9 @@ class AgentConfigToolWolframAlphaToolDefinition(TypedDict, total=False): class AgentConfigToolPhotogenToolDefinition(TypedDict, total=False): type: Required[Literal["photogen"]] - input_shields: Iterable[ShieldDefinitionParam] + input_shields: List[str] - output_shields: Iterable[ShieldDefinitionParam] + output_shields: List[str] remote_execution: RestAPIExecutionConfigParam @@ -77,9 +79,9 @@ class AgentConfigToolCodeInterpreterToolDefinition(TypedDict, total=False): type: Required[Literal["code_interpreter"]] - input_shields: Iterable[ShieldDefinitionParam] + input_shields: List[str] - output_shields: Iterable[ShieldDefinitionParam] + output_shields: List[str] remote_execution: RestAPIExecutionConfigParam @@ -93,20 +95,20 @@ class AgentConfigToolFunctionCallToolDefinition(TypedDict, total=False): type: Required[Literal["function_call"]] - input_shields: Iterable[ShieldDefinitionParam] + input_shields: List[str] - output_shields: Iterable[ShieldDefinitionParam] + output_shields: List[str] remote_execution: RestAPIExecutionConfigParam -class AgentConfigToolShieldMemoryBankConfigVector(TypedDict, total=False): +class AgentConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember0(TypedDict, total=False): bank_id: Required[str] type: Required[Literal["vector"]] -class AgentConfigToolShieldMemoryBankConfigKeyvalue(TypedDict, total=False): +class AgentConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember1(TypedDict, total=False): bank_id: Required[str] keys: Required[List[str]] @@ -114,13 +116,13 @@ class AgentConfigToolShieldMemoryBankConfigKeyvalue(TypedDict, total=False): type: Required[Literal["keyvalue"]] -class AgentConfigToolShieldMemoryBankConfigKeyword(TypedDict, total=False): +class AgentConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember2(TypedDict, total=False): bank_id: Required[str] type: Required[Literal["keyword"]] -class AgentConfigToolShieldMemoryBankConfigGraph(TypedDict, total=False): +class AgentConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember3(TypedDict, total=False): bank_id: Required[str] entities: Required[List[str]] @@ -128,32 +130,53 @@ class AgentConfigToolShieldMemoryBankConfigGraph(TypedDict, total=False): type: Required[Literal["graph"]] -AgentConfigToolShieldMemoryBankConfig: TypeAlias = Union[ - AgentConfigToolShieldMemoryBankConfigVector, - AgentConfigToolShieldMemoryBankConfigKeyvalue, - AgentConfigToolShieldMemoryBankConfigKeyword, - AgentConfigToolShieldMemoryBankConfigGraph, +AgentConfigToolMemoryToolDefinitionMemoryBankConfig: TypeAlias = Union[ + AgentConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember0, + AgentConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember1, + AgentConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember2, + AgentConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember3, ] -AgentConfigToolShieldQueryGeneratorConfig: TypeAlias = Union[ - DefaultQueryGeneratorConfigParam, LlmQueryGeneratorConfigParam, CustomQueryGeneratorConfigParam + +class AgentConfigToolMemoryToolDefinitionQueryGeneratorConfigUnionMember0(TypedDict, total=False): + sep: Required[str] + + type: Required[Literal["default"]] + + +class AgentConfigToolMemoryToolDefinitionQueryGeneratorConfigUnionMember1(TypedDict, total=False): + model: Required[str] + + template: Required[str] + + type: Required[Literal["llm"]] + + +class 
AgentConfigToolMemoryToolDefinitionQueryGeneratorConfigType(TypedDict, total=False): + type: Required[Literal["custom"]] + + +AgentConfigToolMemoryToolDefinitionQueryGeneratorConfig: TypeAlias = Union[ + AgentConfigToolMemoryToolDefinitionQueryGeneratorConfigUnionMember0, + AgentConfigToolMemoryToolDefinitionQueryGeneratorConfigUnionMember1, + AgentConfigToolMemoryToolDefinitionQueryGeneratorConfigType, ] -class AgentConfigToolShield(TypedDict, total=False): +class AgentConfigToolMemoryToolDefinition(TypedDict, total=False): max_chunks: Required[int] max_tokens_in_context: Required[int] - memory_bank_configs: Required[Iterable[AgentConfigToolShieldMemoryBankConfig]] + memory_bank_configs: Required[Iterable[AgentConfigToolMemoryToolDefinitionMemoryBankConfig]] - query_generator_config: Required[AgentConfigToolShieldQueryGeneratorConfig] + query_generator_config: Required[AgentConfigToolMemoryToolDefinitionQueryGeneratorConfig] type: Required[Literal["memory"]] - input_shields: Iterable[ShieldDefinitionParam] + input_shields: List[str] - output_shields: Iterable[ShieldDefinitionParam] + output_shields: List[str] AgentConfigTool: TypeAlias = Union[ @@ -162,18 +185,22 @@ class AgentConfigToolShield(TypedDict, total=False): AgentConfigToolPhotogenToolDefinition, AgentConfigToolCodeInterpreterToolDefinition, AgentConfigToolFunctionCallToolDefinition, - AgentConfigToolShield, + AgentConfigToolMemoryToolDefinition, ] class AgentConfig(TypedDict, total=False): + enable_session_persistence: Required[bool] + instructions: Required[str] + max_infer_iters: Required[int] + model: Required[str] - input_shields: Iterable[ShieldDefinitionParam] + input_shields: List[str] - output_shields: Iterable[ShieldDefinitionParam] + output_shields: List[str] sampling_params: SamplingParams diff --git a/src/llama_stack/types/agent_delete_params.py b/src/llama_stack/types/agent_delete_params.py index cfc4cd0..ba601b9 100644 --- a/src/llama_stack/types/agent_delete_params.py +++ b/src/llama_stack/types/agent_delete_params.py @@ -2,10 +2,14 @@ from __future__ import annotations -from typing_extensions import Required, TypedDict +from typing_extensions import Required, Annotated, TypedDict + +from .._utils import PropertyInfo __all__ = ["AgentDeleteParams"] class AgentDeleteParams(TypedDict, total=False): agent_id: Required[str] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack/types/agents/session_create_params.py b/src/llama_stack/types/agents/session_create_params.py index d64bee0..42e19fe 100644 --- a/src/llama_stack/types/agents/session_create_params.py +++ b/src/llama_stack/types/agents/session_create_params.py @@ -2,7 +2,9 @@ from __future__ import annotations -from typing_extensions import Required, TypedDict +from typing_extensions import Required, Annotated, TypedDict + +from ..._utils import PropertyInfo __all__ = ["SessionCreateParams"] @@ -11,3 +13,5 @@ class SessionCreateParams(TypedDict, total=False): agent_id: Required[str] session_name: Required[str] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack/types/agents/session_delete_params.py b/src/llama_stack/types/agents/session_delete_params.py index 1474f72..45864d6 100644 --- a/src/llama_stack/types/agents/session_delete_params.py +++ b/src/llama_stack/types/agents/session_delete_params.py @@ -2,7 +2,9 @@ from __future__ import annotations -from typing_extensions import Required, TypedDict +from 
typing_extensions import Required, Annotated, TypedDict + +from ..._utils import PropertyInfo __all__ = ["SessionDeleteParams"] @@ -11,3 +13,5 @@ class SessionDeleteParams(TypedDict, total=False): agent_id: Required[str] session_id: Required[str] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack/types/agents/session_retrieve_params.py b/src/llama_stack/types/agents/session_retrieve_params.py index 4c0b691..974c95f 100644 --- a/src/llama_stack/types/agents/session_retrieve_params.py +++ b/src/llama_stack/types/agents/session_retrieve_params.py @@ -3,7 +3,9 @@ from __future__ import annotations from typing import List -from typing_extensions import Required, TypedDict +from typing_extensions import Required, Annotated, TypedDict + +from ..._utils import PropertyInfo __all__ = ["SessionRetrieveParams"] @@ -14,3 +16,5 @@ class SessionRetrieveParams(TypedDict, total=False): session_id: Required[str] turn_ids: List[str] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack/types/agents/step_retrieve_params.py b/src/llama_stack/types/agents/step_retrieve_params.py index 35b49c5..cccdc19 100644 --- a/src/llama_stack/types/agents/step_retrieve_params.py +++ b/src/llama_stack/types/agents/step_retrieve_params.py @@ -2,7 +2,9 @@ from __future__ import annotations -from typing_extensions import Required, TypedDict +from typing_extensions import Required, Annotated, TypedDict + +from ..._utils import PropertyInfo __all__ = ["StepRetrieveParams"] @@ -13,3 +15,5 @@ class StepRetrieveParams(TypedDict, total=False): step_id: Required[str] turn_id: Required[str] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack/types/agents/turn_create_params.py b/src/llama_stack/types/agents/turn_create_params.py index ffbb54c..349d12d 100644 --- a/src/llama_stack/types/agents/turn_create_params.py +++ b/src/llama_stack/types/agents/turn_create_params.py @@ -3,8 +3,9 @@ from __future__ import annotations from typing import Union, Iterable -from typing_extensions import Literal, Required, TypeAlias, TypedDict +from typing_extensions import Literal, Required, Annotated, TypeAlias, TypedDict +from ..._utils import PropertyInfo from ..shared_params.attachment import Attachment from ..shared_params.user_message import UserMessage from ..shared_params.tool_response_message import ToolResponseMessage @@ -21,6 +22,8 @@ class TurnCreateParamsBase(TypedDict, total=False): attachments: Iterable[Attachment] + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] + Message: TypeAlias = Union[UserMessage, ToolResponseMessage] diff --git a/src/llama_stack/types/agents/turn_retrieve_params.py b/src/llama_stack/types/agents/turn_retrieve_params.py index 0352f69..7f3349a 100644 --- a/src/llama_stack/types/agents/turn_retrieve_params.py +++ b/src/llama_stack/types/agents/turn_retrieve_params.py @@ -2,7 +2,9 @@ from __future__ import annotations -from typing_extensions import Required, TypedDict +from typing_extensions import Required, Annotated, TypedDict + +from ..._utils import PropertyInfo __all__ = ["TurnRetrieveParams"] @@ -11,3 +13,5 @@ class TurnRetrieveParams(TypedDict, total=False): agent_id: Required[str] turn_id: Required[str] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git 
a/src/llama_stack/types/batch_inference_chat_completion_params.py b/src/llama_stack/types/batch_inference_chat_completion_params.py index 28ad798..0aae2d2 100644 --- a/src/llama_stack/types/batch_inference_chat_completion_params.py +++ b/src/llama_stack/types/batch_inference_chat_completion_params.py @@ -3,8 +3,9 @@ from __future__ import annotations from typing import Dict, Union, Iterable -from typing_extensions import Literal, Required, TypeAlias, TypedDict +from typing_extensions import Literal, Required, Annotated, TypeAlias, TypedDict +from .._utils import PropertyInfo from .shared_params.user_message import UserMessage from .tool_param_definition_param import ToolParamDefinitionParam from .shared_params.system_message import SystemMessage @@ -41,6 +42,8 @@ class BatchInferenceChatCompletionParams(TypedDict, total=False): tools: Iterable[Tool] + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] + MessagesBatch: TypeAlias = Union[UserMessage, SystemMessage, ToolResponseMessage, CompletionMessage] diff --git a/src/llama_stack/types/batch_inference_completion_params.py b/src/llama_stack/types/batch_inference_completion_params.py index 398531a..e205b85 100644 --- a/src/llama_stack/types/batch_inference_completion_params.py +++ b/src/llama_stack/types/batch_inference_completion_params.py @@ -3,8 +3,9 @@ from __future__ import annotations from typing import List, Union -from typing_extensions import Required, TypedDict +from typing_extensions import Required, Annotated, TypedDict +from .._utils import PropertyInfo from .shared_params.sampling_params import SamplingParams __all__ = ["BatchInferenceCompletionParams", "Logprobs"] @@ -19,6 +20,8 @@ class BatchInferenceCompletionParams(TypedDict, total=False): sampling_params: SamplingParams + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] + class Logprobs(TypedDict, total=False): top_k: int diff --git a/src/llama_stack/types/custom_query_generator_config_param.py b/src/llama_stack/types/custom_query_generator_config_param.py deleted file mode 100644 index 432450c..0000000 --- a/src/llama_stack/types/custom_query_generator_config_param.py +++ /dev/null @@ -1,11 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
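# --- Editorial sketch (not part of the patch): the x_llama_stack_provider_data
# keyword added throughout this change is serialized by the SDK to the
# X-LlamaStack-ProviderData request header via PropertyInfo(alias=...). A minimal
# call-site sketch, assuming a configured client; the method name `completion` and
# the `content_batch` field mirror batch_inference_completion_params but are not
# shown in this hunk, so treat them as assumptions:
import json
from llama_stack import LlamaStack

client = LlamaStack()
client.batch_inference.completion(
    content_batch=["Say hello.", "Say goodbye."],  # assumed required field name
    model="llama3-8b",                             # assumed model identifier
    x_llama_stack_provider_data=json.dumps({"api_key": "provider-secret"}),  # hypothetical payload
)
# --- end editorial sketch ---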
- -from __future__ import annotations - -from typing_extensions import Literal, Required, TypedDict - -__all__ = ["CustomQueryGeneratorConfigParam"] - - -class CustomQueryGeneratorConfigParam(TypedDict, total=False): - type: Required[Literal["custom"]] diff --git a/src/llama_stack/types/dataset_create_params.py b/src/llama_stack/types/dataset_create_params.py index 971ddb5..ec81175 100644 --- a/src/llama_stack/types/dataset_create_params.py +++ b/src/llama_stack/types/dataset_create_params.py @@ -2,8 +2,9 @@ from __future__ import annotations -from typing_extensions import Required, TypedDict +from typing_extensions import Required, Annotated, TypedDict +from .._utils import PropertyInfo from .train_eval_dataset_param import TrainEvalDatasetParam __all__ = ["DatasetCreateParams"] @@ -13,3 +14,5 @@ class DatasetCreateParams(TypedDict, total=False): dataset: Required[TrainEvalDatasetParam] uuid: Required[str] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack/types/dataset_delete_params.py b/src/llama_stack/types/dataset_delete_params.py index ca46388..66d0670 100644 --- a/src/llama_stack/types/dataset_delete_params.py +++ b/src/llama_stack/types/dataset_delete_params.py @@ -2,10 +2,14 @@ from __future__ import annotations -from typing_extensions import Required, TypedDict +from typing_extensions import Required, Annotated, TypedDict + +from .._utils import PropertyInfo __all__ = ["DatasetDeleteParams"] class DatasetDeleteParams(TypedDict, total=False): dataset_uuid: Required[str] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack/types/dataset_get_params.py b/src/llama_stack/types/dataset_get_params.py index b1423f1..d0d6695 100644 --- a/src/llama_stack/types/dataset_get_params.py +++ b/src/llama_stack/types/dataset_get_params.py @@ -2,10 +2,14 @@ from __future__ import annotations -from typing_extensions import Required, TypedDict +from typing_extensions import Required, Annotated, TypedDict + +from .._utils import PropertyInfo __all__ = ["DatasetGetParams"] class DatasetGetParams(TypedDict, total=False): dataset_uuid: Required[str] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack/types/default_query_generator_config_param.py b/src/llama_stack/types/default_query_generator_config_param.py deleted file mode 100644 index 2aaaa81..0000000 --- a/src/llama_stack/types/default_query_generator_config_param.py +++ /dev/null @@ -1,13 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
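# --- Editorial sketch (not part of the patch): the three standalone
# *QueryGeneratorConfigParam modules deleted here now live inline in
# agent_create_params.py as TypedDict union members, so callers keep passing plain
# dicts. Illustrative values only:
default_cfg = {"type": "default", "sep": " "}                       # ex-DefaultQueryGeneratorConfigParam
llm_cfg = {"type": "llm", "model": "llama3-8b", "template": "{q}"}  # ex-LlmQueryGeneratorConfigParam
custom_cfg = {"type": "custom"}                                     # ex-CustomQueryGeneratorConfigParam
# --- end editorial sketch ---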
- -from __future__ import annotations - -from typing_extensions import Literal, Required, TypedDict - -__all__ = ["DefaultQueryGeneratorConfigParam"] - - -class DefaultQueryGeneratorConfigParam(TypedDict, total=False): - sep: Required[str] - - type: Required[Literal["default"]] diff --git a/src/llama_stack/types/evaluate/job_cancel_params.py b/src/llama_stack/types/evaluate/job_cancel_params.py index c9c30d8..9321c3b 100644 --- a/src/llama_stack/types/evaluate/job_cancel_params.py +++ b/src/llama_stack/types/evaluate/job_cancel_params.py @@ -2,10 +2,14 @@ from __future__ import annotations -from typing_extensions import Required, TypedDict +from typing_extensions import Required, Annotated, TypedDict + +from ..._utils import PropertyInfo __all__ = ["JobCancelParams"] class JobCancelParams(TypedDict, total=False): job_uuid: Required[str] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack/types/evaluate/jobs/artifact_list_params.py b/src/llama_stack/types/evaluate/jobs/artifact_list_params.py index f52228e..579033e 100644 --- a/src/llama_stack/types/evaluate/jobs/artifact_list_params.py +++ b/src/llama_stack/types/evaluate/jobs/artifact_list_params.py @@ -2,10 +2,14 @@ from __future__ import annotations -from typing_extensions import Required, TypedDict +from typing_extensions import Required, Annotated, TypedDict + +from ...._utils import PropertyInfo __all__ = ["ArtifactListParams"] class ArtifactListParams(TypedDict, total=False): job_uuid: Required[str] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack/types/evaluate/jobs/log_list_params.py b/src/llama_stack/types/evaluate/jobs/log_list_params.py index 1035005..4b2df45 100644 --- a/src/llama_stack/types/evaluate/jobs/log_list_params.py +++ b/src/llama_stack/types/evaluate/jobs/log_list_params.py @@ -2,10 +2,14 @@ from __future__ import annotations -from typing_extensions import Required, TypedDict +from typing_extensions import Required, Annotated, TypedDict + +from ...._utils import PropertyInfo __all__ = ["LogListParams"] class LogListParams(TypedDict, total=False): job_uuid: Required[str] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack/types/evaluate/jobs/status_list_params.py b/src/llama_stack/types/evaluate/jobs/status_list_params.py index c2a0dc5..a7d5165 100644 --- a/src/llama_stack/types/evaluate/jobs/status_list_params.py +++ b/src/llama_stack/types/evaluate/jobs/status_list_params.py @@ -2,10 +2,14 @@ from __future__ import annotations -from typing_extensions import Required, TypedDict +from typing_extensions import Required, Annotated, TypedDict + +from ...._utils import PropertyInfo __all__ = ["StatusListParams"] class StatusListParams(TypedDict, total=False): job_uuid: Required[str] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack/types/evaluate/question_answering_create_params.py b/src/llama_stack/types/evaluate/question_answering_create_params.py index 8477717..de8caa0 100644 --- a/src/llama_stack/types/evaluate/question_answering_create_params.py +++ b/src/llama_stack/types/evaluate/question_answering_create_params.py @@ -3,10 +3,14 @@ from __future__ import annotations from typing import List -from typing_extensions import Literal, Required, TypedDict +from typing_extensions import Literal, Required, Annotated, TypedDict 
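# --- Editorial sketch (not part of the patch): the same optional field is added to
# every *Params class in this patch. Reduced to a self-contained example, the alias
# tells the serializer to emit the snake_case key as the X-LlamaStack-ProviderData
# header rather than as a body field:
from typing_extensions import Annotated, Required, TypedDict

from llama_stack._utils import PropertyInfo


class ExampleParams(TypedDict, total=False):
    job_uuid: Required[str]

    x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")]
# --- end editorial sketch ---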
+ +from ..._utils import PropertyInfo __all__ = ["QuestionAnsweringCreateParams"] class QuestionAnsweringCreateParams(TypedDict, total=False): metrics: Required[List[Literal["em", "f1"]]] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack/types/evaluation_summarization_params.py b/src/llama_stack/types/evaluation_summarization_params.py index 34542d6..80dd8f5 100644 --- a/src/llama_stack/types/evaluation_summarization_params.py +++ b/src/llama_stack/types/evaluation_summarization_params.py @@ -3,10 +3,14 @@ from __future__ import annotations from typing import List -from typing_extensions import Literal, Required, TypedDict +from typing_extensions import Literal, Required, Annotated, TypedDict + +from .._utils import PropertyInfo __all__ = ["EvaluationSummarizationParams"] class EvaluationSummarizationParams(TypedDict, total=False): metrics: Required[List[Literal["rouge", "bleu"]]] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack/types/evaluation_text_generation_params.py b/src/llama_stack/types/evaluation_text_generation_params.py index deaec66..1cd3a56 100644 --- a/src/llama_stack/types/evaluation_text_generation_params.py +++ b/src/llama_stack/types/evaluation_text_generation_params.py @@ -3,10 +3,14 @@ from __future__ import annotations from typing import List -from typing_extensions import Literal, Required, TypedDict +from typing_extensions import Literal, Required, Annotated, TypedDict + +from .._utils import PropertyInfo __all__ = ["EvaluationTextGenerationParams"] class EvaluationTextGenerationParams(TypedDict, total=False): metrics: Required[List[Literal["perplexity", "rouge", "bleu"]]] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack/types/inference/embedding_create_params.py b/src/llama_stack/types/inference/embedding_create_params.py index 272f199..0e4e970 100644 --- a/src/llama_stack/types/inference/embedding_create_params.py +++ b/src/llama_stack/types/inference/embedding_create_params.py @@ -3,7 +3,9 @@ from __future__ import annotations from typing import List, Union -from typing_extensions import Required, TypedDict +from typing_extensions import Required, Annotated, TypedDict + +from ..._utils import PropertyInfo __all__ = ["EmbeddingCreateParams"] @@ -12,3 +14,5 @@ class EmbeddingCreateParams(TypedDict, total=False): contents: Required[List[Union[str, List[str]]]] model: Required[str] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack/types/inference_chat_completion_params.py b/src/llama_stack/types/inference_chat_completion_params.py index af21934..2cac635 100644 --- a/src/llama_stack/types/inference_chat_completion_params.py +++ b/src/llama_stack/types/inference_chat_completion_params.py @@ -3,8 +3,9 @@ from __future__ import annotations from typing import Dict, Union, Iterable -from typing_extensions import Literal, Required, TypeAlias, TypedDict +from typing_extensions import Literal, Required, Annotated, TypeAlias, TypedDict +from .._utils import PropertyInfo from .shared_params.user_message import UserMessage from .tool_param_definition_param import ToolParamDefinitionParam from .shared_params.system_message import SystemMessage @@ -48,6 +49,8 @@ class InferenceChatCompletionParamsBase(TypedDict, total=False): tools: Iterable[Tool] + x_llama_stack_provider_data: 
Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] + Message: TypeAlias = Union[UserMessage, SystemMessage, ToolResponseMessage, CompletionMessage] diff --git a/src/llama_stack/types/inference_completion_params.py b/src/llama_stack/types/inference_completion_params.py index 79544b0..d4145a5 100644 --- a/src/llama_stack/types/inference_completion_params.py +++ b/src/llama_stack/types/inference_completion_params.py @@ -3,8 +3,9 @@ from __future__ import annotations from typing import List, Union -from typing_extensions import Required, TypedDict +from typing_extensions import Required, Annotated, TypedDict +from .._utils import PropertyInfo from .shared_params.sampling_params import SamplingParams __all__ = ["InferenceCompletionParams", "Logprobs"] @@ -21,6 +22,8 @@ class InferenceCompletionParams(TypedDict, total=False): stream: bool + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] + class Logprobs(TypedDict, total=False): top_k: int diff --git a/src/llama_stack/types/llm_query_generator_config_param.py b/src/llama_stack/types/llm_query_generator_config_param.py deleted file mode 100644 index 8d6bd31..0000000 --- a/src/llama_stack/types/llm_query_generator_config_param.py +++ /dev/null @@ -1,15 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -from __future__ import annotations - -from typing_extensions import Literal, Required, TypedDict - -__all__ = ["LlmQueryGeneratorConfigParam"] - - -class LlmQueryGeneratorConfigParam(TypedDict, total=False): - model: Required[str] - - template: Required[str] - - type: Required[Literal["llm"]] diff --git a/src/llama_stack/types/memory_banks/__init__.py b/src/llama_stack/types/memory/__init__.py similarity index 100% rename from src/llama_stack/types/memory_banks/__init__.py rename to src/llama_stack/types/memory/__init__.py diff --git a/src/llama_stack/types/memory_banks/document_delete_params.py b/src/llama_stack/types/memory/document_delete_params.py similarity index 60% rename from src/llama_stack/types/memory_banks/document_delete_params.py rename to src/llama_stack/types/memory/document_delete_params.py index 01d66b6..9ec4bf1 100644 --- a/src/llama_stack/types/memory_banks/document_delete_params.py +++ b/src/llama_stack/types/memory/document_delete_params.py @@ -3,7 +3,9 @@ from __future__ import annotations from typing import List -from typing_extensions import Required, TypedDict +from typing_extensions import Required, Annotated, TypedDict + +from ..._utils import PropertyInfo __all__ = ["DocumentDeleteParams"] @@ -12,3 +14,5 @@ class DocumentDeleteParams(TypedDict, total=False): bank_id: Required[str] document_ids: Required[List[str]] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack/types/memory_banks/document_retrieve_params.py b/src/llama_stack/types/memory/document_retrieve_params.py similarity index 61% rename from src/llama_stack/types/memory_banks/document_retrieve_params.py rename to src/llama_stack/types/memory/document_retrieve_params.py index d746395..3f30f9b 100644 --- a/src/llama_stack/types/memory_banks/document_retrieve_params.py +++ b/src/llama_stack/types/memory/document_retrieve_params.py @@ -3,7 +3,9 @@ from __future__ import annotations from typing import List -from typing_extensions import Required, TypedDict +from typing_extensions import Required, Annotated, TypedDict + +from ..._utils import PropertyInfo __all__ = ["DocumentRetrieveParams"] 
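# --- Editorial sketch (not part of the patch; hypothetical IDs, configured
# `client` assumed): the memory_banks document routes move to
# client.memory.documents in this rename, so only the resource path changes at
# call sites while the params keep their shape:
docs = client.memory.documents.retrieve(
    bank_id="bank_id",
    document_ids=["doc-1", "doc-2"],
)
# previously: client.memory_banks.documents.retrieve(bank_id=..., document_ids=...)
# --- end editorial sketch ---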
@@ -12,3 +14,5 @@ class DocumentRetrieveParams(TypedDict, total=False): bank_id: Required[str] document_ids: Required[List[str]] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack/types/memory_banks/document_retrieve_response.py b/src/llama_stack/types/memory/document_retrieve_response.py similarity index 100% rename from src/llama_stack/types/memory_banks/document_retrieve_response.py rename to src/llama_stack/types/memory/document_retrieve_response.py diff --git a/src/llama_stack/types/memory_bank_create_params.py b/src/llama_stack/types/memory_bank_create_params.py deleted file mode 100644 index 421fca0..0000000 --- a/src/llama_stack/types/memory_bank_create_params.py +++ /dev/null @@ -1,11 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -from __future__ import annotations - -from typing_extensions import Required, TypedDict - -__all__ = ["MemoryBankCreateParams"] - - -class MemoryBankCreateParams(TypedDict, total=False): - body: Required[object] diff --git a/src/llama_stack/types/memory_bank_drop_params.py b/src/llama_stack/types/memory_bank_drop_params.py deleted file mode 100644 index e19be55..0000000 --- a/src/llama_stack/types/memory_bank_drop_params.py +++ /dev/null @@ -1,11 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -from __future__ import annotations - -from typing_extensions import Required, TypedDict - -__all__ = ["MemoryBankDropParams"] - - -class MemoryBankDropParams(TypedDict, total=False): - bank_id: Required[str] diff --git a/src/llama_stack/types/memory_bank_get_params.py b/src/llama_stack/types/memory_bank_get_params.py new file mode 100644 index 0000000..de5b43e --- /dev/null +++ b/src/llama_stack/types/memory_bank_get_params.py @@ -0,0 +1,15 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import Literal, Required, Annotated, TypedDict + +from .._utils import PropertyInfo + +__all__ = ["MemoryBankGetParams"] + + +class MemoryBankGetParams(TypedDict, total=False): + bank_type: Required[Literal["vector", "keyvalue", "keyword", "graph"]] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack/types/memory_bank_retrieve_params.py b/src/llama_stack/types/memory_bank_retrieve_params.py deleted file mode 100644 index 21436f3..0000000 --- a/src/llama_stack/types/memory_bank_retrieve_params.py +++ /dev/null @@ -1,11 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -from __future__ import annotations - -from typing_extensions import Required, TypedDict - -__all__ = ["MemoryBankRetrieveParams"] - - -class MemoryBankRetrieveParams(TypedDict, total=False): - bank_id: Required[str] diff --git a/src/llama_stack/types/memory_bank_spec.py b/src/llama_stack/types/memory_bank_spec.py new file mode 100644 index 0000000..b116082 --- /dev/null +++ b/src/llama_stack/types/memory_bank_spec.py @@ -0,0 +1,20 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
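# --- Editorial sketch (not part of the patch): MemoryBankGetParams keys the lookup
# on bank_type (one of "vector", "keyvalue", "keyword", "graph") rather than a
# bank_id, and the new MemoryBankSpec model is the natural return type. Assuming
# the resource method is named `get` to mirror the params module:
spec = client.memory_banks.get(bank_type="vector")
print(spec.bank_type, spec.provider_config.provider_id)
# --- end editorial sketch ---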
+ +from typing import Dict, List, Union +from typing_extensions import Literal + +from .._models import BaseModel + +__all__ = ["MemoryBankSpec", "ProviderConfig"] + + +class ProviderConfig(BaseModel): + config: Dict[str, Union[bool, float, str, List[object], object, None]] + + provider_id: str + + +class MemoryBankSpec(BaseModel): + bank_type: Literal["vector", "keyvalue", "keyword", "graph"] + + provider_config: ProviderConfig diff --git a/src/llama_stack/types/memory_create_params.py b/src/llama_stack/types/memory_create_params.py new file mode 100644 index 0000000..01f496e --- /dev/null +++ b/src/llama_stack/types/memory_create_params.py @@ -0,0 +1,15 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import Required, Annotated, TypedDict + +from .._utils import PropertyInfo + +__all__ = ["MemoryCreateParams"] + + +class MemoryCreateParams(TypedDict, total=False): + body: Required[object] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack/types/memory_drop_params.py b/src/llama_stack/types/memory_drop_params.py new file mode 100644 index 0000000..b15ec34 --- /dev/null +++ b/src/llama_stack/types/memory_drop_params.py @@ -0,0 +1,15 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import Required, Annotated, TypedDict + +from .._utils import PropertyInfo + +__all__ = ["MemoryDropParams"] + + +class MemoryDropParams(TypedDict, total=False): + bank_id: Required[str] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack/types/memory_bank_drop_response.py b/src/llama_stack/types/memory_drop_response.py similarity index 62% rename from src/llama_stack/types/memory_bank_drop_response.py rename to src/llama_stack/types/memory_drop_response.py index d3f5c3f..f032e04 100644 --- a/src/llama_stack/types/memory_bank_drop_response.py +++ b/src/llama_stack/types/memory_drop_response.py @@ -2,6 +2,6 @@ from typing_extensions import TypeAlias -__all__ = ["MemoryBankDropResponse"] +__all__ = ["MemoryDropResponse"] -MemoryBankDropResponse: TypeAlias = str +MemoryDropResponse: TypeAlias = str diff --git a/src/llama_stack/types/memory_bank_insert_params.py b/src/llama_stack/types/memory_insert_params.py similarity index 63% rename from src/llama_stack/types/memory_bank_insert_params.py rename to src/llama_stack/types/memory_insert_params.py index c460fc2..b09f8a7 100644 --- a/src/llama_stack/types/memory_bank_insert_params.py +++ b/src/llama_stack/types/memory_insert_params.py @@ -3,18 +3,22 @@ from __future__ import annotations from typing import Dict, List, Union, Iterable -from typing_extensions import Required, TypedDict +from typing_extensions import Required, Annotated, TypedDict -__all__ = ["MemoryBankInsertParams", "Document"] +from .._utils import PropertyInfo +__all__ = ["MemoryInsertParams", "Document"] -class MemoryBankInsertParams(TypedDict, total=False): + +class MemoryInsertParams(TypedDict, total=False): bank_id: Required[str] documents: Required[Iterable[Document]] ttl_seconds: int + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] + class Document(TypedDict, total=False): content: Required[Union[str, List[str]]] diff --git a/src/llama_stack/types/memory_bank_query_params.py 
b/src/llama_stack/types/memory_query_params.py similarity index 54% rename from src/llama_stack/types/memory_bank_query_params.py rename to src/llama_stack/types/memory_query_params.py index 05e4a56..d885600 100644 --- a/src/llama_stack/types/memory_bank_query_params.py +++ b/src/llama_stack/types/memory_query_params.py @@ -3,14 +3,18 @@ from __future__ import annotations from typing import Dict, List, Union, Iterable -from typing_extensions import Required, TypedDict +from typing_extensions import Required, Annotated, TypedDict -__all__ = ["MemoryBankQueryParams"] +from .._utils import PropertyInfo +__all__ = ["MemoryQueryParams"] -class MemoryBankQueryParams(TypedDict, total=False): + +class MemoryQueryParams(TypedDict, total=False): bank_id: Required[str] query: Required[Union[str, List[str]]] params: Dict[str, Union[bool, float, str, Iterable[object], object, None]] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack/types/memory_retrieve_params.py b/src/llama_stack/types/memory_retrieve_params.py new file mode 100644 index 0000000..62f6496 --- /dev/null +++ b/src/llama_stack/types/memory_retrieve_params.py @@ -0,0 +1,15 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import Required, Annotated, TypedDict + +from .._utils import PropertyInfo + +__all__ = ["MemoryRetrieveParams"] + + +class MemoryRetrieveParams(TypedDict, total=False): + bank_id: Required[str] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack/types/memory_bank_update_params.py b/src/llama_stack/types/memory_update_params.py similarity index 62% rename from src/llama_stack/types/memory_bank_update_params.py rename to src/llama_stack/types/memory_update_params.py index 2ab747b..ac9b8e4 100644 --- a/src/llama_stack/types/memory_bank_update_params.py +++ b/src/llama_stack/types/memory_update_params.py @@ -3,16 +3,20 @@ from __future__ import annotations from typing import Dict, List, Union, Iterable -from typing_extensions import Required, TypedDict +from typing_extensions import Required, Annotated, TypedDict -__all__ = ["MemoryBankUpdateParams", "Document"] +from .._utils import PropertyInfo +__all__ = ["MemoryUpdateParams", "Document"] -class MemoryBankUpdateParams(TypedDict, total=False): + +class MemoryUpdateParams(TypedDict, total=False): bank_id: Required[str] documents: Required[Iterable[Document]] + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] + class Document(TypedDict, total=False): content: Required[Union[str, List[str]]] diff --git a/src/llama_stack/types/model_get_params.py b/src/llama_stack/types/model_get_params.py new file mode 100644 index 0000000..f3dc87d --- /dev/null +++ b/src/llama_stack/types/model_get_params.py @@ -0,0 +1,15 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
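# --- Editorial sketch (not part of the patch; bank_id and any Document fields
# beyond `content` are assumptions): the renamed Memory*Params keep their old
# shapes, so an insert-then-query flow reads:
client.memory.insert(
    bank_id="bank_id",
    documents=[{"content": "hello world", "document_id": "doc-1"}],  # document_id assumed
    ttl_seconds=3600,
)
hits = client.memory.query(
    bank_id="bank_id",
    query="hello",
    params={"max_chunks": 5},  # free-form dict per MemoryQueryParams.params
)
# --- end editorial sketch ---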
+ +from __future__ import annotations + +from typing_extensions import Required, Annotated, TypedDict + +from .._utils import PropertyInfo + +__all__ = ["ModelGetParams"] + + +class ModelGetParams(TypedDict, total=False): + core_model_id: Required[str] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack/types/model_serving_spec.py b/src/llama_stack/types/model_serving_spec.py new file mode 100644 index 0000000..87b75a9 --- /dev/null +++ b/src/llama_stack/types/model_serving_spec.py @@ -0,0 +1,23 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Dict, List, Union + +from .._models import BaseModel + +__all__ = ["ModelServingSpec", "ProviderConfig"] + + +class ProviderConfig(BaseModel): + config: Dict[str, Union[bool, float, str, List[object], object, None]] + + provider_id: str + + +class ModelServingSpec(BaseModel): + llama_model: object + """ + The model family and SKU of the model along with other parameters corresponding + to the model. + """ + + provider_config: ProviderConfig diff --git a/src/llama_stack/types/post_training/job_artifacts_params.py b/src/llama_stack/types/post_training/job_artifacts_params.py index 4f75a13..1f7ae65 100644 --- a/src/llama_stack/types/post_training/job_artifacts_params.py +++ b/src/llama_stack/types/post_training/job_artifacts_params.py @@ -2,10 +2,14 @@ from __future__ import annotations -from typing_extensions import Required, TypedDict +from typing_extensions import Required, Annotated, TypedDict + +from ..._utils import PropertyInfo __all__ = ["JobArtifactsParams"] class JobArtifactsParams(TypedDict, total=False): job_uuid: Required[str] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack/types/post_training/job_cancel_params.py b/src/llama_stack/types/post_training/job_cancel_params.py index c9c30d8..9321c3b 100644 --- a/src/llama_stack/types/post_training/job_cancel_params.py +++ b/src/llama_stack/types/post_training/job_cancel_params.py @@ -2,10 +2,14 @@ from __future__ import annotations -from typing_extensions import Required, TypedDict +from typing_extensions import Required, Annotated, TypedDict + +from ..._utils import PropertyInfo __all__ = ["JobCancelParams"] class JobCancelParams(TypedDict, total=False): job_uuid: Required[str] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack/types/post_training/job_logs_params.py b/src/llama_stack/types/post_training/job_logs_params.py index a550be5..42f7e07 100644 --- a/src/llama_stack/types/post_training/job_logs_params.py +++ b/src/llama_stack/types/post_training/job_logs_params.py @@ -2,10 +2,14 @@ from __future__ import annotations -from typing_extensions import Required, TypedDict +from typing_extensions import Required, Annotated, TypedDict + +from ..._utils import PropertyInfo __all__ = ["JobLogsParams"] class JobLogsParams(TypedDict, total=False): job_uuid: Required[str] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack/types/post_training/job_status_params.py b/src/llama_stack/types/post_training/job_status_params.py index 8cf17b0..f1f8b20 100644 --- a/src/llama_stack/types/post_training/job_status_params.py +++ b/src/llama_stack/types/post_training/job_status_params.py @@ -2,10 +2,14 @@ from __future__ import annotations -from 
typing_extensions import Required, TypedDict +from typing_extensions import Required, Annotated, TypedDict + +from ..._utils import PropertyInfo __all__ = ["JobStatusParams"] class JobStatusParams(TypedDict, total=False): job_uuid: Required[str] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack/types/post_training_preference_optimize_params.py b/src/llama_stack/types/post_training_preference_optimize_params.py index 9e6f3cc..805e6cf 100644 --- a/src/llama_stack/types/post_training_preference_optimize_params.py +++ b/src/llama_stack/types/post_training_preference_optimize_params.py @@ -3,8 +3,9 @@ from __future__ import annotations from typing import Dict, Union, Iterable -from typing_extensions import Literal, Required, TypedDict +from typing_extensions import Literal, Required, Annotated, TypedDict +from .._utils import PropertyInfo from .train_eval_dataset_param import TrainEvalDatasetParam __all__ = ["PostTrainingPreferenceOptimizeParams", "AlgorithmConfig", "OptimizerConfig", "TrainingConfig"] @@ -31,6 +32,8 @@ class PostTrainingPreferenceOptimizeParams(TypedDict, total=False): validation_dataset: Required[TrainEvalDatasetParam] + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] + class AlgorithmConfig(TypedDict, total=False): epsilon: Required[float] diff --git a/src/llama_stack/types/post_training_supervised_fine_tune_params.py b/src/llama_stack/types/post_training_supervised_fine_tune_params.py index 36f776d..084e1ed 100644 --- a/src/llama_stack/types/post_training_supervised_fine_tune_params.py +++ b/src/llama_stack/types/post_training_supervised_fine_tune_params.py @@ -3,8 +3,9 @@ from __future__ import annotations from typing import Dict, List, Union, Iterable -from typing_extensions import Literal, Required, TypeAlias, TypedDict +from typing_extensions import Literal, Required, Annotated, TypeAlias, TypedDict +from .._utils import PropertyInfo from .train_eval_dataset_param import TrainEvalDatasetParam __all__ = [ @@ -39,6 +40,8 @@ class PostTrainingSupervisedFineTuneParams(TypedDict, total=False): validation_dataset: Required[TrainEvalDatasetParam] + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] + class AlgorithmConfigLoraFinetuningConfig(TypedDict, total=False): alpha: Required[int] diff --git a/src/llama_stack/types/reward_scoring_score_params.py b/src/llama_stack/types/reward_scoring_score_params.py index b969b75..bb7bfb6 100644 --- a/src/llama_stack/types/reward_scoring_score_params.py +++ b/src/llama_stack/types/reward_scoring_score_params.py @@ -3,8 +3,9 @@ from __future__ import annotations from typing import Union, Iterable -from typing_extensions import Required, TypeAlias, TypedDict +from typing_extensions import Required, Annotated, TypeAlias, TypedDict +from .._utils import PropertyInfo from .shared_params.user_message import UserMessage from .shared_params.system_message import SystemMessage from .shared_params.completion_message import CompletionMessage @@ -23,6 +24,8 @@ class RewardScoringScoreParams(TypedDict, total=False): model: Required[str] + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] + DialogGenerationDialog: TypeAlias = Union[UserMessage, SystemMessage, ToolResponseMessage, CompletionMessage] diff --git a/src/llama_stack/types/run_sheid_response.py b/src/llama_stack/types/run_sheid_response.py new file mode 100644 index 0000000..478b023 --- 
/dev/null +++ b/src/llama_stack/types/run_sheid_response.py @@ -0,0 +1,20 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Dict, List, Union, Optional +from typing_extensions import Literal + +from .._models import BaseModel + +__all__ = ["RunSheidResponse", "Violation"] + + +class Violation(BaseModel): + metadata: Dict[str, Union[bool, float, str, List[object], object, None]] + + violation_level: Literal["info", "warn", "error"] + + user_message: Optional[str] = None + + +class RunSheidResponse(BaseModel): + violation: Optional[Violation] = None diff --git a/src/llama_stack/types/safety_run_shields_params.py b/src/llama_stack/types/safety_run_shield_params.py similarity index 52% rename from src/llama_stack/types/safety_run_shields_params.py rename to src/llama_stack/types/safety_run_shield_params.py index 59498f6..430473b 100644 --- a/src/llama_stack/types/safety_run_shields_params.py +++ b/src/llama_stack/types/safety_run_shield_params.py @@ -2,22 +2,26 @@ from __future__ import annotations -from typing import Union, Iterable -from typing_extensions import Required, TypeAlias, TypedDict +from typing import Dict, Union, Iterable +from typing_extensions import Required, Annotated, TypeAlias, TypedDict -from .shield_definition_param import ShieldDefinitionParam +from .._utils import PropertyInfo from .shared_params.user_message import UserMessage from .shared_params.system_message import SystemMessage from .shared_params.completion_message import CompletionMessage from .shared_params.tool_response_message import ToolResponseMessage -__all__ = ["SafetyRunShieldsParams", "Message"] +__all__ = ["SafetyRunShieldParams", "Message"] -class SafetyRunShieldsParams(TypedDict, total=False): +class SafetyRunShieldParams(TypedDict, total=False): messages: Required[Iterable[Message]] - shields: Required[Iterable[ShieldDefinitionParam]] + params: Required[Dict[str, Union[bool, float, str, Iterable[object], object, None]]] + + shield_type: Required[str] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] Message: TypeAlias = Union[UserMessage, SystemMessage, ToolResponseMessage, CompletionMessage] diff --git a/src/llama_stack/types/safety_run_shields_response.py b/src/llama_stack/types/safety_run_shields_response.py deleted file mode 100644 index 24f87f2..0000000 --- a/src/llama_stack/types/safety_run_shields_response.py +++ /dev/null @@ -1,12 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -from typing import List - -from .._models import BaseModel -from .sheid_response import SheidResponse - -__all__ = ["SafetyRunShieldsResponse"] - - -class SafetyRunShieldsResponse(BaseModel): - responses: List[SheidResponse] diff --git a/src/llama_stack/types/sheid_response.py b/src/llama_stack/types/sheid_response.py deleted file mode 100644 index 99cdcf5..0000000 --- a/src/llama_stack/types/sheid_response.py +++ /dev/null @@ -1,20 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
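# --- Editorial sketch (not part of the patch): the safety API now takes a
# registered shield_type string plus a free-form params dict instead of inline
# ShieldDefinitionParam objects, and returns RunSheidResponse ("Sheid" is the
# spelling the generator emits here) whose only field is an optional Violation.
# Assuming the resource method is renamed run_shield to match
# safety_run_shield_params, and a standard role/content UserMessage shape:
result = client.safety.run_shield(
    shield_type="llama_guard",
    messages=[{"role": "user", "content": "hello"}],
    params={},
)
if result.violation is not None:
    print(result.violation.violation_level, result.violation.user_message)
# --- end editorial sketch ---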
- -from typing import Union, Optional -from typing_extensions import Literal - -from .._models import BaseModel - -__all__ = ["SheidResponse"] - - -class SheidResponse(BaseModel): - is_violation: bool - - shield_type: Union[ - Literal["llama_guard", "code_scanner_guard", "third_party_shield", "injection_shield", "jailbreak_shield"], str - ] - - violation_return_message: Optional[str] = None - - violation_type: Optional[str] = None diff --git a/src/llama_stack/types/shield_call_step.py b/src/llama_stack/types/shield_call_step.py index e360825..d4b90d8 100644 --- a/src/llama_stack/types/shield_call_step.py +++ b/src/llama_stack/types/shield_call_step.py @@ -1,18 +1,23 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. -from typing import Optional +from typing import Dict, List, Union, Optional from datetime import datetime from typing_extensions import Literal from .._models import BaseModel -from .sheid_response import SheidResponse -__all__ = ["ShieldCallStep"] +__all__ = ["ShieldCallStep", "Violation"] -class ShieldCallStep(BaseModel): - response: SheidResponse +class Violation(BaseModel): + metadata: Dict[str, Union[bool, float, str, List[object], object, None]] + + violation_level: Literal["info", "warn", "error"] + + user_message: Optional[str] = None + +class ShieldCallStep(BaseModel): step_id: str step_type: Literal["shield_call"] @@ -22,3 +27,5 @@ class ShieldCallStep(BaseModel): completed_at: Optional[datetime] = None started_at: Optional[datetime] = None + + violation: Optional[Violation] = None diff --git a/src/llama_stack/types/shield_definition_param.py b/src/llama_stack/types/shield_definition_param.py deleted file mode 100644 index 9672e03..0000000 --- a/src/llama_stack/types/shield_definition_param.py +++ /dev/null @@ -1,28 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -from __future__ import annotations - -from typing import Dict, Union -from typing_extensions import Literal, Required, TypedDict - -from .tool_param_definition_param import ToolParamDefinitionParam -from .rest_api_execution_config_param import RestAPIExecutionConfigParam - -__all__ = ["ShieldDefinitionParam"] - - -class ShieldDefinitionParam(TypedDict, total=False): - on_violation_action: Required[Literal[0, 1, 2]] - - shield_type: Required[ - Union[ - Literal["llama_guard", "code_scanner_guard", "third_party_shield", "injection_shield", "jailbreak_shield"], - str, - ] - ] - - description: str - - execution_config: RestAPIExecutionConfigParam - - parameters: Dict[str, ToolParamDefinitionParam] diff --git a/src/llama_stack/types/shield_get_params.py b/src/llama_stack/types/shield_get_params.py new file mode 100644 index 0000000..cb9ce90 --- /dev/null +++ b/src/llama_stack/types/shield_get_params.py @@ -0,0 +1,15 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import Required, Annotated, TypedDict + +from .._utils import PropertyInfo + +__all__ = ["ShieldGetParams"] + + +class ShieldGetParams(TypedDict, total=False): + shield_type: Required[str] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack/types/shield_spec.py b/src/llama_stack/types/shield_spec.py new file mode 100644 index 0000000..d83cd51 --- /dev/null +++ b/src/llama_stack/types/shield_spec.py @@ -0,0 +1,19 @@ +# File generated from our OpenAPI spec by Stainless. 
See CONTRIBUTING.md for details. + +from typing import Dict, List, Union + +from .._models import BaseModel + +__all__ = ["ShieldSpec", "ProviderConfig"] + + +class ProviderConfig(BaseModel): + config: Dict[str, Union[bool, float, str, List[object], object, None]] + + provider_id: str + + +class ShieldSpec(BaseModel): + provider_config: ProviderConfig + + shield_type: str diff --git a/src/llama_stack/types/synthetic_data_generation_generate_params.py b/src/llama_stack/types/synthetic_data_generation_generate_params.py index 4238473..2514992 100644 --- a/src/llama_stack/types/synthetic_data_generation_generate_params.py +++ b/src/llama_stack/types/synthetic_data_generation_generate_params.py @@ -3,8 +3,9 @@ from __future__ import annotations from typing import Union, Iterable -from typing_extensions import Literal, Required, TypeAlias, TypedDict +from typing_extensions import Literal, Required, Annotated, TypeAlias, TypedDict +from .._utils import PropertyInfo from .shared_params.user_message import UserMessage from .shared_params.system_message import SystemMessage from .shared_params.completion_message import CompletionMessage @@ -20,5 +21,7 @@ class SyntheticDataGenerationGenerateParams(TypedDict, total=False): model: str + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] + Dialog: TypeAlias = Union[UserMessage, SystemMessage, ToolResponseMessage, CompletionMessage] diff --git a/src/llama_stack/types/telemetry_get_trace_params.py b/src/llama_stack/types/telemetry_get_trace_params.py index 520724b..dbee698 100644 --- a/src/llama_stack/types/telemetry_get_trace_params.py +++ b/src/llama_stack/types/telemetry_get_trace_params.py @@ -2,10 +2,14 @@ from __future__ import annotations -from typing_extensions import Required, TypedDict +from typing_extensions import Required, Annotated, TypedDict + +from .._utils import PropertyInfo __all__ = ["TelemetryGetTraceParams"] class TelemetryGetTraceParams(TypedDict, total=False): trace_id: Required[str] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack/types/telemetry_log_params.py b/src/llama_stack/types/telemetry_log_params.py index 6e6eb61..a2e4d9b 100644 --- a/src/llama_stack/types/telemetry_log_params.py +++ b/src/llama_stack/types/telemetry_log_params.py @@ -23,6 +23,8 @@ class TelemetryLogParams(TypedDict, total=False): event: Required[Event] + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] + class EventUnstructuredLogEvent(TypedDict, total=False): message: Required[str] diff --git a/tests/api_resources/agents/test_sessions.py b/tests/api_resources/agents/test_sessions.py index 4040ef0..97836d7 100644 --- a/tests/api_resources/agents/test_sessions.py +++ b/tests/api_resources/agents/test_sessions.py @@ -28,6 +28,15 @@ def test_method_create(self, client: LlamaStack) -> None: ) assert_matches_type(SessionCreateResponse, session, path=["response"]) + @parametrize + def test_method_create_with_all_params(self, client: LlamaStack) -> None: + session = client.agents.sessions.create( + agent_id="agent_id", + session_name="session_name", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(SessionCreateResponse, session, path=["response"]) + @parametrize def test_raw_response_create(self, client: LlamaStack) -> None: response = client.agents.sessions.with_raw_response.create( @@ -68,6 +77,7 @@ def test_method_retrieve_with_all_params(self, client: 
LlamaStack) -> None: agent_id="agent_id", session_id="session_id", turn_ids=["string", "string", "string"], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) assert_matches_type(Session, session, path=["response"]) @@ -105,6 +115,15 @@ def test_method_delete(self, client: LlamaStack) -> None: ) assert session is None + @parametrize + def test_method_delete_with_all_params(self, client: LlamaStack) -> None: + session = client.agents.sessions.delete( + agent_id="agent_id", + session_id="session_id", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert session is None + @parametrize def test_raw_response_delete(self, client: LlamaStack) -> None: response = client.agents.sessions.with_raw_response.delete( @@ -143,6 +162,15 @@ async def test_method_create(self, async_client: AsyncLlamaStack) -> None: ) assert_matches_type(SessionCreateResponse, session, path=["response"]) + @parametrize + async def test_method_create_with_all_params(self, async_client: AsyncLlamaStack) -> None: + session = await async_client.agents.sessions.create( + agent_id="agent_id", + session_name="session_name", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(SessionCreateResponse, session, path=["response"]) + @parametrize async def test_raw_response_create(self, async_client: AsyncLlamaStack) -> None: response = await async_client.agents.sessions.with_raw_response.create( @@ -183,6 +211,7 @@ async def test_method_retrieve_with_all_params(self, async_client: AsyncLlamaSta agent_id="agent_id", session_id="session_id", turn_ids=["string", "string", "string"], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) assert_matches_type(Session, session, path=["response"]) @@ -220,6 +249,15 @@ async def test_method_delete(self, async_client: AsyncLlamaStack) -> None: ) assert session is None + @parametrize + async def test_method_delete_with_all_params(self, async_client: AsyncLlamaStack) -> None: + session = await async_client.agents.sessions.delete( + agent_id="agent_id", + session_id="session_id", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert session is None + @parametrize async def test_raw_response_delete(self, async_client: AsyncLlamaStack) -> None: response = await async_client.agents.sessions.with_raw_response.delete( diff --git a/tests/api_resources/agents/test_steps.py b/tests/api_resources/agents/test_steps.py index 3dda39e..64345d9 100644 --- a/tests/api_resources/agents/test_steps.py +++ b/tests/api_resources/agents/test_steps.py @@ -26,6 +26,16 @@ def test_method_retrieve(self, client: LlamaStack) -> None: ) assert_matches_type(AgentsStep, step, path=["response"]) + @parametrize + def test_method_retrieve_with_all_params(self, client: LlamaStack) -> None: + step = client.agents.steps.retrieve( + agent_id="agent_id", + step_id="step_id", + turn_id="turn_id", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(AgentsStep, step, path=["response"]) + @parametrize def test_raw_response_retrieve(self, client: LlamaStack) -> None: response = client.agents.steps.with_raw_response.retrieve( @@ -67,6 +77,16 @@ async def test_method_retrieve(self, async_client: AsyncLlamaStack) -> None: ) assert_matches_type(AgentsStep, step, path=["response"]) + @parametrize + async def test_method_retrieve_with_all_params(self, async_client: AsyncLlamaStack) -> None: + step = await async_client.agents.steps.retrieve( + agent_id="agent_id", + step_id="step_id", + turn_id="turn_id", + 
x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(AgentsStep, step, path=["response"]) + @parametrize async def test_raw_response_retrieve(self, async_client: AsyncLlamaStack) -> None: response = await async_client.agents.steps.with_raw_response.retrieve( diff --git a/tests/api_resources/agents/test_turns.py b/tests/api_resources/agents/test_turns.py index 47614e2..ec890a0 100644 --- a/tests/api_resources/agents/test_turns.py +++ b/tests/api_resources/agents/test_turns.py @@ -76,6 +76,7 @@ def test_method_create_with_all_params_overload_1(self, client: LlamaStack) -> N }, ], stream=False, + x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) assert_matches_type(AgentsTurnStreamChunk, turn, path=["response"]) @@ -193,6 +194,7 @@ def test_method_create_with_all_params_overload_2(self, client: LlamaStack) -> N "mime_type": "mime_type", }, ], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) turn_stream.response.close() @@ -259,6 +261,15 @@ def test_method_retrieve(self, client: LlamaStack) -> None: ) assert_matches_type(Turn, turn, path=["response"]) + @parametrize + def test_method_retrieve_with_all_params(self, client: LlamaStack) -> None: + turn = client.agents.turns.retrieve( + agent_id="agent_id", + turn_id="turn_id", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(Turn, turn, path=["response"]) + @parametrize def test_raw_response_retrieve(self, client: LlamaStack) -> None: response = client.agents.turns.with_raw_response.retrieve( @@ -348,6 +359,7 @@ async def test_method_create_with_all_params_overload_1(self, async_client: Asyn }, ], stream=False, + x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) assert_matches_type(AgentsTurnStreamChunk, turn, path=["response"]) @@ -465,6 +477,7 @@ async def test_method_create_with_all_params_overload_2(self, async_client: Asyn "mime_type": "mime_type", }, ], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) await turn_stream.response.aclose() @@ -531,6 +544,15 @@ async def test_method_retrieve(self, async_client: AsyncLlamaStack) -> None: ) assert_matches_type(Turn, turn, path=["response"]) + @parametrize + async def test_method_retrieve_with_all_params(self, async_client: AsyncLlamaStack) -> None: + turn = await async_client.agents.turns.retrieve( + agent_id="agent_id", + turn_id="turn_id", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(Turn, turn, path=["response"]) + @parametrize async def test_raw_response_retrieve(self, async_client: AsyncLlamaStack) -> None: response = await async_client.agents.turns.with_raw_response.retrieve( diff --git a/tests/api_resources/evaluate/jobs/test_artifacts.py b/tests/api_resources/evaluate/jobs/test_artifacts.py index 3cef0d7..44d205f 100755 --- a/tests/api_resources/evaluate/jobs/test_artifacts.py +++ b/tests/api_resources/evaluate/jobs/test_artifacts.py @@ -24,6 +24,14 @@ def test_method_list(self, client: LlamaStack) -> None: ) assert_matches_type(EvaluationJobArtifacts, artifact, path=["response"]) + @parametrize + def test_method_list_with_all_params(self, client: LlamaStack) -> None: + artifact = client.evaluate.jobs.artifacts.list( + job_uuid="job_uuid", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(EvaluationJobArtifacts, artifact, path=["response"]) + @parametrize def test_raw_response_list(self, client: LlamaStack) -> None: response = client.evaluate.jobs.artifacts.with_raw_response.list( @@ -59,6 +67,14 @@ async def 
test_method_list(self, async_client: AsyncLlamaStack) -> None: ) assert_matches_type(EvaluationJobArtifacts, artifact, path=["response"]) + @parametrize + async def test_method_list_with_all_params(self, async_client: AsyncLlamaStack) -> None: + artifact = await async_client.evaluate.jobs.artifacts.list( + job_uuid="job_uuid", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(EvaluationJobArtifacts, artifact, path=["response"]) + @parametrize async def test_raw_response_list(self, async_client: AsyncLlamaStack) -> None: response = await async_client.evaluate.jobs.artifacts.with_raw_response.list( diff --git a/tests/api_resources/evaluate/jobs/test_logs.py b/tests/api_resources/evaluate/jobs/test_logs.py index 472ac90..51af56e 100755 --- a/tests/api_resources/evaluate/jobs/test_logs.py +++ b/tests/api_resources/evaluate/jobs/test_logs.py @@ -24,6 +24,14 @@ def test_method_list(self, client: LlamaStack) -> None: ) assert_matches_type(EvaluationJobLogStream, log, path=["response"]) + @parametrize + def test_method_list_with_all_params(self, client: LlamaStack) -> None: + log = client.evaluate.jobs.logs.list( + job_uuid="job_uuid", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(EvaluationJobLogStream, log, path=["response"]) + @parametrize def test_raw_response_list(self, client: LlamaStack) -> None: response = client.evaluate.jobs.logs.with_raw_response.list( @@ -59,6 +67,14 @@ async def test_method_list(self, async_client: AsyncLlamaStack) -> None: ) assert_matches_type(EvaluationJobLogStream, log, path=["response"]) + @parametrize + async def test_method_list_with_all_params(self, async_client: AsyncLlamaStack) -> None: + log = await async_client.evaluate.jobs.logs.list( + job_uuid="job_uuid", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(EvaluationJobLogStream, log, path=["response"]) + @parametrize async def test_raw_response_list(self, async_client: AsyncLlamaStack) -> None: response = await async_client.evaluate.jobs.logs.with_raw_response.list( diff --git a/tests/api_resources/evaluate/jobs/test_status.py b/tests/api_resources/evaluate/jobs/test_status.py index 3a6000e..c8bc510 100755 --- a/tests/api_resources/evaluate/jobs/test_status.py +++ b/tests/api_resources/evaluate/jobs/test_status.py @@ -24,6 +24,14 @@ def test_method_list(self, client: LlamaStack) -> None: ) assert_matches_type(EvaluationJobStatus, status, path=["response"]) + @parametrize + def test_method_list_with_all_params(self, client: LlamaStack) -> None: + status = client.evaluate.jobs.status.list( + job_uuid="job_uuid", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(EvaluationJobStatus, status, path=["response"]) + @parametrize def test_raw_response_list(self, client: LlamaStack) -> None: response = client.evaluate.jobs.status.with_raw_response.list( @@ -59,6 +67,14 @@ async def test_method_list(self, async_client: AsyncLlamaStack) -> None: ) assert_matches_type(EvaluationJobStatus, status, path=["response"]) + @parametrize + async def test_method_list_with_all_params(self, async_client: AsyncLlamaStack) -> None: + status = await async_client.evaluate.jobs.status.list( + job_uuid="job_uuid", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(EvaluationJobStatus, status, path=["response"]) + @parametrize async def test_raw_response_list(self, async_client: AsyncLlamaStack) -> None: response = await 
async_client.evaluate.jobs.status.with_raw_response.list( diff --git a/tests/api_resources/evaluate/test_jobs.py b/tests/api_resources/evaluate/test_jobs.py index 601bf38..3e296af 100644 --- a/tests/api_resources/evaluate/test_jobs.py +++ b/tests/api_resources/evaluate/test_jobs.py @@ -22,6 +22,13 @@ def test_method_list(self, client: LlamaStack) -> None: job = client.evaluate.jobs.list() assert_matches_type(EvaluationJob, job, path=["response"]) + @parametrize + def test_method_list_with_all_params(self, client: LlamaStack) -> None: + job = client.evaluate.jobs.list( + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(EvaluationJob, job, path=["response"]) + @parametrize def test_raw_response_list(self, client: LlamaStack) -> None: response = client.evaluate.jobs.with_raw_response.list() @@ -49,6 +56,14 @@ def test_method_cancel(self, client: LlamaStack) -> None: ) assert job is None + @parametrize + def test_method_cancel_with_all_params(self, client: LlamaStack) -> None: + job = client.evaluate.jobs.cancel( + job_uuid="job_uuid", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert job is None + @parametrize def test_raw_response_cancel(self, client: LlamaStack) -> None: response = client.evaluate.jobs.with_raw_response.cancel( @@ -82,6 +97,13 @@ async def test_method_list(self, async_client: AsyncLlamaStack) -> None: job = await async_client.evaluate.jobs.list() assert_matches_type(EvaluationJob, job, path=["response"]) + @parametrize + async def test_method_list_with_all_params(self, async_client: AsyncLlamaStack) -> None: + job = await async_client.evaluate.jobs.list( + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(EvaluationJob, job, path=["response"]) + @parametrize async def test_raw_response_list(self, async_client: AsyncLlamaStack) -> None: response = await async_client.evaluate.jobs.with_raw_response.list() @@ -109,6 +131,14 @@ async def test_method_cancel(self, async_client: AsyncLlamaStack) -> None: ) assert job is None + @parametrize + async def test_method_cancel_with_all_params(self, async_client: AsyncLlamaStack) -> None: + job = await async_client.evaluate.jobs.cancel( + job_uuid="job_uuid", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert job is None + @parametrize async def test_raw_response_cancel(self, async_client: AsyncLlamaStack) -> None: response = await async_client.evaluate.jobs.with_raw_response.cancel( diff --git a/tests/api_resources/evaluate/test_question_answering.py b/tests/api_resources/evaluate/test_question_answering.py index b229620..d7b4a23 100644 --- a/tests/api_resources/evaluate/test_question_answering.py +++ b/tests/api_resources/evaluate/test_question_answering.py @@ -24,6 +24,14 @@ def test_method_create(self, client: LlamaStack) -> None: ) assert_matches_type(EvaluationJob, question_answering, path=["response"]) + @parametrize + def test_method_create_with_all_params(self, client: LlamaStack) -> None: + question_answering = client.evaluate.question_answering.create( + metrics=["em", "f1"], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(EvaluationJob, question_answering, path=["response"]) + @parametrize def test_raw_response_create(self, client: LlamaStack) -> None: response = client.evaluate.question_answering.with_raw_response.create( @@ -59,6 +67,14 @@ async def test_method_create(self, async_client: AsyncLlamaStack) -> None: ) assert_matches_type(EvaluationJob, question_answering, 
path=["response"]) + @parametrize + async def test_method_create_with_all_params(self, async_client: AsyncLlamaStack) -> None: + question_answering = await async_client.evaluate.question_answering.create( + metrics=["em", "f1"], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(EvaluationJob, question_answering, path=["response"]) + @parametrize async def test_raw_response_create(self, async_client: AsyncLlamaStack) -> None: response = await async_client.evaluate.question_answering.with_raw_response.create( diff --git a/tests/api_resources/inference/test_embeddings.py b/tests/api_resources/inference/test_embeddings.py index 62c90ce..b83a8ef 100644 --- a/tests/api_resources/inference/test_embeddings.py +++ b/tests/api_resources/inference/test_embeddings.py @@ -25,6 +25,15 @@ def test_method_create(self, client: LlamaStack) -> None: ) assert_matches_type(Embeddings, embedding, path=["response"]) + @parametrize + def test_method_create_with_all_params(self, client: LlamaStack) -> None: + embedding = client.inference.embeddings.create( + contents=["string", "string", "string"], + model="model", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(Embeddings, embedding, path=["response"]) + @parametrize def test_raw_response_create(self, client: LlamaStack) -> None: response = client.inference.embeddings.with_raw_response.create( @@ -63,6 +72,15 @@ async def test_method_create(self, async_client: AsyncLlamaStack) -> None: ) assert_matches_type(Embeddings, embedding, path=["response"]) + @parametrize + async def test_method_create_with_all_params(self, async_client: AsyncLlamaStack) -> None: + embedding = await async_client.inference.embeddings.create( + contents=["string", "string", "string"], + model="model", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(Embeddings, embedding, path=["response"]) + @parametrize async def test_raw_response_create(self, async_client: AsyncLlamaStack) -> None: response = await async_client.inference.embeddings.with_raw_response.create( diff --git a/tests/api_resources/memory_banks/__init__.py b/tests/api_resources/memory/__init__.py similarity index 100% rename from tests/api_resources/memory_banks/__init__.py rename to tests/api_resources/memory/__init__.py diff --git a/tests/api_resources/memory_banks/test_documents.py b/tests/api_resources/memory/test_documents.py similarity index 68% rename from tests/api_resources/memory_banks/test_documents.py rename to tests/api_resources/memory/test_documents.py index 04514eb..842efac 100644 --- a/tests/api_resources/memory_banks/test_documents.py +++ b/tests/api_resources/memory/test_documents.py @@ -9,7 +9,7 @@ from llama_stack import LlamaStack, AsyncLlamaStack from tests.utils import assert_matches_type -from llama_stack.types.memory_banks import DocumentRetrieveResponse +from llama_stack.types.memory import DocumentRetrieveResponse base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -19,15 +19,24 @@ class TestDocuments: @parametrize def test_method_retrieve(self, client: LlamaStack) -> None: - document = client.memory_banks.documents.retrieve( + document = client.memory.documents.retrieve( bank_id="bank_id", document_ids=["string", "string", "string"], ) assert_matches_type(DocumentRetrieveResponse, document, path=["response"]) + @parametrize + def test_method_retrieve_with_all_params(self, client: LlamaStack) -> None: + document = client.memory.documents.retrieve( + bank_id="bank_id", + 
document_ids=["string", "string", "string"], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(DocumentRetrieveResponse, document, path=["response"]) + @parametrize def test_raw_response_retrieve(self, client: LlamaStack) -> None: - response = client.memory_banks.documents.with_raw_response.retrieve( + response = client.memory.documents.with_raw_response.retrieve( bank_id="bank_id", document_ids=["string", "string", "string"], ) @@ -39,7 +48,7 @@ def test_raw_response_retrieve(self, client: LlamaStack) -> None: @parametrize def test_streaming_response_retrieve(self, client: LlamaStack) -> None: - with client.memory_banks.documents.with_streaming_response.retrieve( + with client.memory.documents.with_streaming_response.retrieve( bank_id="bank_id", document_ids=["string", "string", "string"], ) as response: @@ -53,15 +62,24 @@ def test_streaming_response_retrieve(self, client: LlamaStack) -> None: @parametrize def test_method_delete(self, client: LlamaStack) -> None: - document = client.memory_banks.documents.delete( + document = client.memory.documents.delete( + bank_id="bank_id", + document_ids=["string", "string", "string"], + ) + assert document is None + + @parametrize + def test_method_delete_with_all_params(self, client: LlamaStack) -> None: + document = client.memory.documents.delete( bank_id="bank_id", document_ids=["string", "string", "string"], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) assert document is None @parametrize def test_raw_response_delete(self, client: LlamaStack) -> None: - response = client.memory_banks.documents.with_raw_response.delete( + response = client.memory.documents.with_raw_response.delete( bank_id="bank_id", document_ids=["string", "string", "string"], ) @@ -73,7 +91,7 @@ def test_raw_response_delete(self, client: LlamaStack) -> None: @parametrize def test_streaming_response_delete(self, client: LlamaStack) -> None: - with client.memory_banks.documents.with_streaming_response.delete( + with client.memory.documents.with_streaming_response.delete( bank_id="bank_id", document_ids=["string", "string", "string"], ) as response: @@ -91,15 +109,24 @@ class TestAsyncDocuments: @parametrize async def test_method_retrieve(self, async_client: AsyncLlamaStack) -> None: - document = await async_client.memory_banks.documents.retrieve( + document = await async_client.memory.documents.retrieve( bank_id="bank_id", document_ids=["string", "string", "string"], ) assert_matches_type(DocumentRetrieveResponse, document, path=["response"]) + @parametrize + async def test_method_retrieve_with_all_params(self, async_client: AsyncLlamaStack) -> None: + document = await async_client.memory.documents.retrieve( + bank_id="bank_id", + document_ids=["string", "string", "string"], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(DocumentRetrieveResponse, document, path=["response"]) + @parametrize async def test_raw_response_retrieve(self, async_client: AsyncLlamaStack) -> None: - response = await async_client.memory_banks.documents.with_raw_response.retrieve( + response = await async_client.memory.documents.with_raw_response.retrieve( bank_id="bank_id", document_ids=["string", "string", "string"], ) @@ -111,7 +138,7 @@ async def test_raw_response_retrieve(self, async_client: AsyncLlamaStack) -> Non @parametrize async def test_streaming_response_retrieve(self, async_client: AsyncLlamaStack) -> None: - async with async_client.memory_banks.documents.with_streaming_response.retrieve( + async with 
async_client.memory.documents.with_streaming_response.retrieve( bank_id="bank_id", document_ids=["string", "string", "string"], ) as response: @@ -125,15 +152,24 @@ async def test_streaming_response_retrieve(self, async_client: AsyncLlamaStack) @parametrize async def test_method_delete(self, async_client: AsyncLlamaStack) -> None: - document = await async_client.memory_banks.documents.delete( + document = await async_client.memory.documents.delete( + bank_id="bank_id", + document_ids=["string", "string", "string"], + ) + assert document is None + + @parametrize + async def test_method_delete_with_all_params(self, async_client: AsyncLlamaStack) -> None: + document = await async_client.memory.documents.delete( bank_id="bank_id", document_ids=["string", "string", "string"], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) assert document is None @parametrize async def test_raw_response_delete(self, async_client: AsyncLlamaStack) -> None: - response = await async_client.memory_banks.documents.with_raw_response.delete( + response = await async_client.memory.documents.with_raw_response.delete( bank_id="bank_id", document_ids=["string", "string", "string"], ) @@ -145,7 +181,7 @@ async def test_raw_response_delete(self, async_client: AsyncLlamaStack) -> None: @parametrize async def test_streaming_response_delete(self, async_client: AsyncLlamaStack) -> None: - async with async_client.memory_banks.documents.with_streaming_response.delete( + async with async_client.memory.documents.with_streaming_response.delete( bank_id="bank_id", document_ids=["string", "string", "string"], ) as response: diff --git a/tests/api_resources/post_training/test_jobs.py b/tests/api_resources/post_training/test_jobs.py index 2580031..ec2f8f0 100644 --- a/tests/api_resources/post_training/test_jobs.py +++ b/tests/api_resources/post_training/test_jobs.py @@ -27,6 +27,13 @@ def test_method_list(self, client: LlamaStack) -> None: job = client.post_training.jobs.list() assert_matches_type(PostTrainingJob, job, path=["response"]) + @parametrize + def test_method_list_with_all_params(self, client: LlamaStack) -> None: + job = client.post_training.jobs.list( + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(PostTrainingJob, job, path=["response"]) + @parametrize def test_raw_response_list(self, client: LlamaStack) -> None: response = client.post_training.jobs.with_raw_response.list() @@ -54,6 +61,14 @@ def test_method_artifacts(self, client: LlamaStack) -> None: ) assert_matches_type(PostTrainingJobArtifacts, job, path=["response"]) + @parametrize + def test_method_artifacts_with_all_params(self, client: LlamaStack) -> None: + job = client.post_training.jobs.artifacts( + job_uuid="job_uuid", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(PostTrainingJobArtifacts, job, path=["response"]) + @parametrize def test_raw_response_artifacts(self, client: LlamaStack) -> None: response = client.post_training.jobs.with_raw_response.artifacts( @@ -85,6 +100,14 @@ def test_method_cancel(self, client: LlamaStack) -> None: ) assert job is None + @parametrize + def test_method_cancel_with_all_params(self, client: LlamaStack) -> None: + job = client.post_training.jobs.cancel( + job_uuid="job_uuid", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert job is None + @parametrize def test_raw_response_cancel(self, client: LlamaStack) -> None: response = client.post_training.jobs.with_raw_response.cancel( @@ -116,6 +139,14 @@ def 
test_method_logs(self, client: LlamaStack) -> None: ) assert_matches_type(PostTrainingJobLogStream, job, path=["response"]) + @parametrize + def test_method_logs_with_all_params(self, client: LlamaStack) -> None: + job = client.post_training.jobs.logs( + job_uuid="job_uuid", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(PostTrainingJobLogStream, job, path=["response"]) + @parametrize def test_raw_response_logs(self, client: LlamaStack) -> None: response = client.post_training.jobs.with_raw_response.logs( @@ -147,6 +178,14 @@ def test_method_status(self, client: LlamaStack) -> None: ) assert_matches_type(PostTrainingJobStatus, job, path=["response"]) + @parametrize + def test_method_status_with_all_params(self, client: LlamaStack) -> None: + job = client.post_training.jobs.status( + job_uuid="job_uuid", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(PostTrainingJobStatus, job, path=["response"]) + @parametrize def test_raw_response_status(self, client: LlamaStack) -> None: response = client.post_training.jobs.with_raw_response.status( @@ -180,6 +219,13 @@ async def test_method_list(self, async_client: AsyncLlamaStack) -> None: job = await async_client.post_training.jobs.list() assert_matches_type(PostTrainingJob, job, path=["response"]) + @parametrize + async def test_method_list_with_all_params(self, async_client: AsyncLlamaStack) -> None: + job = await async_client.post_training.jobs.list( + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(PostTrainingJob, job, path=["response"]) + @parametrize async def test_raw_response_list(self, async_client: AsyncLlamaStack) -> None: response = await async_client.post_training.jobs.with_raw_response.list() @@ -207,6 +253,14 @@ async def test_method_artifacts(self, async_client: AsyncLlamaStack) -> None: ) assert_matches_type(PostTrainingJobArtifacts, job, path=["response"]) + @parametrize + async def test_method_artifacts_with_all_params(self, async_client: AsyncLlamaStack) -> None: + job = await async_client.post_training.jobs.artifacts( + job_uuid="job_uuid", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(PostTrainingJobArtifacts, job, path=["response"]) + @parametrize async def test_raw_response_artifacts(self, async_client: AsyncLlamaStack) -> None: response = await async_client.post_training.jobs.with_raw_response.artifacts( @@ -238,6 +292,14 @@ async def test_method_cancel(self, async_client: AsyncLlamaStack) -> None: ) assert job is None + @parametrize + async def test_method_cancel_with_all_params(self, async_client: AsyncLlamaStack) -> None: + job = await async_client.post_training.jobs.cancel( + job_uuid="job_uuid", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert job is None + @parametrize async def test_raw_response_cancel(self, async_client: AsyncLlamaStack) -> None: response = await async_client.post_training.jobs.with_raw_response.cancel( @@ -269,6 +331,14 @@ async def test_method_logs(self, async_client: AsyncLlamaStack) -> None: ) assert_matches_type(PostTrainingJobLogStream, job, path=["response"]) + @parametrize + async def test_method_logs_with_all_params(self, async_client: AsyncLlamaStack) -> None: + job = await async_client.post_training.jobs.logs( + job_uuid="job_uuid", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(PostTrainingJobLogStream, job, path=["response"]) + @parametrize async def test_raw_response_logs(self, 
async_client: AsyncLlamaStack) -> None: response = await async_client.post_training.jobs.with_raw_response.logs( @@ -300,6 +370,14 @@ async def test_method_status(self, async_client: AsyncLlamaStack) -> None: ) assert_matches_type(PostTrainingJobStatus, job, path=["response"]) + @parametrize + async def test_method_status_with_all_params(self, async_client: AsyncLlamaStack) -> None: + job = await async_client.post_training.jobs.status( + job_uuid="job_uuid", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(PostTrainingJobStatus, job, path=["response"]) + @parametrize async def test_raw_response_status(self, async_client: AsyncLlamaStack) -> None: response = await async_client.post_training.jobs.with_raw_response.status( diff --git a/tests/api_resources/test_agents.py b/tests/api_resources/test_agents.py index a9a46c9..355a0b1 100644 --- a/tests/api_resources/test_agents.py +++ b/tests/api_resources/test_agents.py @@ -21,7 +21,9 @@ class TestAgents: def test_method_create(self, client: LlamaStack) -> None: agent = client.agents.create( agent_config={ + "enable_session_persistence": True, "instructions": "instructions", + "max_infer_iters": 0, "model": "model", }, ) @@ -31,126 +33,12 @@ def test_method_create(self, client: LlamaStack) -> None: def test_method_create_with_all_params(self, client: LlamaStack) -> None: agent = client.agents.create( agent_config={ + "enable_session_persistence": True, "instructions": "instructions", + "max_infer_iters": 0, "model": "model", - "input_shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - ], - "output_shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - 
"description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - ], + "input_shields": ["string", "string", "string"], + "output_shields": ["string", "string", "string"], "sampling_params": { "strategy": "greedy", "max_tokens": 0, @@ -166,124 +54,8 @@ def test_method_create_with_all_params(self, client: LlamaStack) -> None: "api_key": "api_key", "engine": "bing", "type": "brave_search", - "input_shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - ], - "output_shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - ], + "input_shields": ["string", "string", "string"], + "output_shields": ["string", "string", "string"], "remote_execution": { "method": "GET", "url": "https://example.com", @@ -296,124 +68,8 @@ def test_method_create_with_all_params(self, client: LlamaStack) -> None: "api_key": "api_key", "engine": "bing", "type": "brave_search", - "input_shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": 
True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - ], - "output_shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - ], + "input_shields": ["string", "string", "string"], + "output_shields": ["string", "string", "string"], "remote_execution": { "method": "GET", "url": "https://example.com", @@ -426,124 +82,8 @@ def test_method_create_with_all_params(self, client: LlamaStack) -> None: "api_key": "api_key", "engine": "bing", "type": "brave_search", - "input_shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": 
"description", - "required": True, - } - }, - }, - ], - "output_shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - ], + "input_shields": ["string", "string", "string"], + "output_shields": ["string", "string", "string"], "remote_execution": { "method": "GET", "url": "https://example.com", @@ -554,6 +94,7 @@ def test_method_create_with_all_params(self, client: LlamaStack) -> None: }, ], }, + x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) assert_matches_type(AgentCreateResponse, agent, path=["response"]) @@ -561,7 +102,9 @@ def test_method_create_with_all_params(self, client: LlamaStack) -> None: def test_raw_response_create(self, client: LlamaStack) -> None: response = client.agents.with_raw_response.create( agent_config={ + "enable_session_persistence": True, "instructions": "instructions", + "max_infer_iters": 0, "model": "model", }, ) @@ -575,7 +118,9 @@ def test_raw_response_create(self, client: LlamaStack) -> None: def test_streaming_response_create(self, client: LlamaStack) -> None: with client.agents.with_streaming_response.create( agent_config={ + "enable_session_persistence": True, "instructions": "instructions", + "max_infer_iters": 0, "model": "model", }, ) as response: @@ -594,6 +139,14 @@ def test_method_delete(self, client: LlamaStack) -> None: ) assert agent is None + @parametrize + def test_method_delete_with_all_params(self, client: LlamaStack) -> None: + agent = client.agents.delete( + agent_id="agent_id", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert agent is None + @parametrize def test_raw_response_delete(self, client: LlamaStack) -> None: response = client.agents.with_raw_response.delete( @@ -626,7 +179,9 @@ class TestAsyncAgents: async def test_method_create(self, async_client: AsyncLlamaStack) -> None: agent = await async_client.agents.create( agent_config={ + "enable_session_persistence": True, "instructions": "instructions", + "max_infer_iters": 0, "model": "model", }, ) @@ -636,126 +191,12 @@ async def test_method_create(self, async_client: AsyncLlamaStack) -> None: async def test_method_create_with_all_params(self, async_client: AsyncLlamaStack) -> None: agent = await async_client.agents.create( agent_config={ + "enable_session_persistence": True, "instructions": "instructions", + "max_infer_iters": 0, "model": "model", - "input_shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - 
"execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - ], - "output_shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - ], + "input_shields": ["string", "string", "string"], + "output_shields": ["string", "string", "string"], "sampling_params": { "strategy": "greedy", "max_tokens": 0, @@ -771,124 +212,8 @@ async def test_method_create_with_all_params(self, async_client: AsyncLlamaStack "api_key": "api_key", "engine": "bing", "type": "brave_search", - "input_shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": 
True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - ], - "output_shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - ], + "input_shields": ["string", "string", "string"], + "output_shields": ["string", "string", "string"], "remote_execution": { "method": "GET", "url": "https://example.com", @@ -901,124 +226,8 @@ async def test_method_create_with_all_params(self, async_client: AsyncLlamaStack "api_key": "api_key", "engine": "bing", "type": "brave_search", - "input_shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - ], - "output_shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": 
"param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - ], + "input_shields": ["string", "string", "string"], + "output_shields": ["string", "string", "string"], "remote_execution": { "method": "GET", "url": "https://example.com", @@ -1031,124 +240,8 @@ async def test_method_create_with_all_params(self, async_client: AsyncLlamaStack "api_key": "api_key", "engine": "bing", "type": "brave_search", - "input_shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - ], - "output_shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - ], + "input_shields": ["string", "string", "string"], + "output_shields": ["string", "string", "string"], "remote_execution": { "method": "GET", "url": "https://example.com", @@ -1159,6 +252,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncLlamaStack }, ], }, + x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) assert_matches_type(AgentCreateResponse, agent, 
path=["response"]) @@ -1166,7 +260,9 @@ async def test_method_create_with_all_params(self, async_client: AsyncLlamaStack async def test_raw_response_create(self, async_client: AsyncLlamaStack) -> None: response = await async_client.agents.with_raw_response.create( agent_config={ + "enable_session_persistence": True, "instructions": "instructions", + "max_infer_iters": 0, "model": "model", }, ) @@ -1180,7 +276,9 @@ async def test_raw_response_create(self, async_client: AsyncLlamaStack) -> None: async def test_streaming_response_create(self, async_client: AsyncLlamaStack) -> None: async with async_client.agents.with_streaming_response.create( agent_config={ + "enable_session_persistence": True, "instructions": "instructions", + "max_infer_iters": 0, "model": "model", }, ) as response: @@ -1199,6 +297,14 @@ async def test_method_delete(self, async_client: AsyncLlamaStack) -> None: ) assert agent is None + @parametrize + async def test_method_delete_with_all_params(self, async_client: AsyncLlamaStack) -> None: + agent = await async_client.agents.delete( + agent_id="agent_id", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert agent is None + @parametrize async def test_raw_response_delete(self, async_client: AsyncLlamaStack) -> None: response = await async_client.agents.with_raw_response.delete( diff --git a/tests/api_resources/test_batch_inference.py b/tests/api_resources/test_batch_inference.py index 13fd6bb..3277598 100644 --- a/tests/api_resources/test_batch_inference.py +++ b/tests/api_resources/test_batch_inference.py @@ -174,6 +174,7 @@ def test_method_chat_completion_with_all_params(self, client: LlamaStack) -> Non }, }, ], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) assert_matches_type(BatchChatCompletion, batch_inference, path=["response"]) @@ -311,6 +312,7 @@ def test_method_completion_with_all_params(self, client: LlamaStack) -> None: "top_k": 0, "top_p": 0, }, + x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) assert_matches_type(BatchCompletion, batch_inference, path=["response"]) @@ -498,6 +500,7 @@ async def test_method_chat_completion_with_all_params(self, async_client: AsyncL }, }, ], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) assert_matches_type(BatchChatCompletion, batch_inference, path=["response"]) @@ -635,6 +638,7 @@ async def test_method_completion_with_all_params(self, async_client: AsyncLlamaS "top_k": 0, "top_p": 0, }, + x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) assert_matches_type(BatchCompletion, batch_inference, path=["response"]) diff --git a/tests/api_resources/test_datasets.py b/tests/api_resources/test_datasets.py index 4aec07d..8a0433b 100644 --- a/tests/api_resources/test_datasets.py +++ b/tests/api_resources/test_datasets.py @@ -37,6 +37,7 @@ def test_method_create_with_all_params(self, client: LlamaStack) -> None: "metadata": {"foo": True}, }, uuid="uuid", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) assert dataset is None @@ -79,6 +80,14 @@ def test_method_delete(self, client: LlamaStack) -> None: ) assert dataset is None + @parametrize + def test_method_delete_with_all_params(self, client: LlamaStack) -> None: + dataset = client.datasets.delete( + dataset_uuid="dataset_uuid", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert dataset is None + @parametrize def test_raw_response_delete(self, client: LlamaStack) -> None: response = client.datasets.with_raw_response.delete( @@ -110,6 +119,14 @@ def test_method_get(self, client: LlamaStack) 
-> None: ) assert_matches_type(TrainEvalDataset, dataset, path=["response"]) + @parametrize + def test_method_get_with_all_params(self, client: LlamaStack) -> None: + dataset = client.datasets.get( + dataset_uuid="dataset_uuid", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(TrainEvalDataset, dataset, path=["response"]) + @parametrize def test_raw_response_get(self, client: LlamaStack) -> None: response = client.datasets.with_raw_response.get( @@ -158,6 +175,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncLlamaStack "metadata": {"foo": True}, }, uuid="uuid", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) assert dataset is None @@ -200,6 +218,14 @@ async def test_method_delete(self, async_client: AsyncLlamaStack) -> None: ) assert dataset is None + @parametrize + async def test_method_delete_with_all_params(self, async_client: AsyncLlamaStack) -> None: + dataset = await async_client.datasets.delete( + dataset_uuid="dataset_uuid", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert dataset is None + @parametrize async def test_raw_response_delete(self, async_client: AsyncLlamaStack) -> None: response = await async_client.datasets.with_raw_response.delete( @@ -231,6 +257,14 @@ async def test_method_get(self, async_client: AsyncLlamaStack) -> None: ) assert_matches_type(TrainEvalDataset, dataset, path=["response"]) + @parametrize + async def test_method_get_with_all_params(self, async_client: AsyncLlamaStack) -> None: + dataset = await async_client.datasets.get( + dataset_uuid="dataset_uuid", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(TrainEvalDataset, dataset, path=["response"]) + @parametrize async def test_raw_response_get(self, async_client: AsyncLlamaStack) -> None: response = await async_client.datasets.with_raw_response.get( diff --git a/tests/api_resources/test_evaluations.py b/tests/api_resources/test_evaluations.py index dbdf834..7d6a275 100644 --- a/tests/api_resources/test_evaluations.py +++ b/tests/api_resources/test_evaluations.py @@ -24,6 +24,14 @@ def test_method_summarization(self, client: LlamaStack) -> None: ) assert_matches_type(EvaluationJob, evaluation, path=["response"]) + @parametrize + def test_method_summarization_with_all_params(self, client: LlamaStack) -> None: + evaluation = client.evaluations.summarization( + metrics=["rouge", "bleu"], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(EvaluationJob, evaluation, path=["response"]) + @parametrize def test_raw_response_summarization(self, client: LlamaStack) -> None: response = client.evaluations.with_raw_response.summarization( @@ -55,6 +63,14 @@ def test_method_text_generation(self, client: LlamaStack) -> None: ) assert_matches_type(EvaluationJob, evaluation, path=["response"]) + @parametrize + def test_method_text_generation_with_all_params(self, client: LlamaStack) -> None: + evaluation = client.evaluations.text_generation( + metrics=["perplexity", "rouge", "bleu"], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(EvaluationJob, evaluation, path=["response"]) + @parametrize def test_raw_response_text_generation(self, client: LlamaStack) -> None: response = client.evaluations.with_raw_response.text_generation( @@ -90,6 +106,14 @@ async def test_method_summarization(self, async_client: AsyncLlamaStack) -> None ) assert_matches_type(EvaluationJob, evaluation, path=["response"]) + @parametrize + async def 
test_method_summarization_with_all_params(self, async_client: AsyncLlamaStack) -> None: + evaluation = await async_client.evaluations.summarization( + metrics=["rouge", "bleu"], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(EvaluationJob, evaluation, path=["response"]) + @parametrize async def test_raw_response_summarization(self, async_client: AsyncLlamaStack) -> None: response = await async_client.evaluations.with_raw_response.summarization( @@ -121,6 +145,14 @@ async def test_method_text_generation(self, async_client: AsyncLlamaStack) -> No ) assert_matches_type(EvaluationJob, evaluation, path=["response"]) + @parametrize + async def test_method_text_generation_with_all_params(self, async_client: AsyncLlamaStack) -> None: + evaluation = await async_client.evaluations.text_generation( + metrics=["perplexity", "rouge", "bleu"], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(EvaluationJob, evaluation, path=["response"]) + @parametrize async def test_raw_response_text_generation(self, async_client: AsyncLlamaStack) -> None: response = await async_client.evaluations.with_raw_response.text_generation( diff --git a/tests/api_resources/test_inference.py b/tests/api_resources/test_inference.py index a7b80b6..7b20e11 100644 --- a/tests/api_resources/test_inference.py +++ b/tests/api_resources/test_inference.py @@ -109,6 +109,7 @@ def test_method_chat_completion_with_all_params_overload_1(self, client: LlamaSt }, }, ], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) assert_matches_type(InferenceChatCompletionResponse, inference, path=["response"]) @@ -254,6 +255,7 @@ def test_method_chat_completion_with_all_params_overload_2(self, client: LlamaSt }, }, ], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) inference_stream.response.close() @@ -333,6 +335,7 @@ def test_method_completion_with_all_params(self, client: LlamaStack) -> None: "top_p": 0, }, stream=True, + x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) assert_matches_type(InferenceCompletionResponse, inference, path=["response"]) @@ -455,6 +458,7 @@ async def test_method_chat_completion_with_all_params_overload_1(self, async_cli }, }, ], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) assert_matches_type(InferenceChatCompletionResponse, inference, path=["response"]) @@ -600,6 +604,7 @@ async def test_method_chat_completion_with_all_params_overload_2(self, async_cli }, }, ], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) await inference_stream.response.aclose() @@ -679,6 +684,7 @@ async def test_method_completion_with_all_params(self, async_client: AsyncLlamaS "top_p": 0, }, stream=True, + x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) assert_matches_type(InferenceCompletionResponse, inference, path=["response"]) diff --git a/tests/api_resources/test_memory.py b/tests/api_resources/test_memory.py new file mode 100644 index 0000000..2d1cc05 --- /dev/null +++ b/tests/api_resources/test_memory.py @@ -0,0 +1,852 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
+ +from __future__ import annotations + +import os +from typing import Any, cast + +import pytest + +from llama_stack import LlamaStack, AsyncLlamaStack +from tests.utils import assert_matches_type +from llama_stack.types import ( + QueryDocuments, +) + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + + +class TestMemory: + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + def test_method_create(self, client: LlamaStack) -> None: + memory = client.memory.create( + body={}, + ) + assert_matches_type(object, memory, path=["response"]) + + @parametrize + def test_method_create_with_all_params(self, client: LlamaStack) -> None: + memory = client.memory.create( + body={}, + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(object, memory, path=["response"]) + + @parametrize + def test_raw_response_create(self, client: LlamaStack) -> None: + response = client.memory.with_raw_response.create( + body={}, + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + memory = response.parse() + assert_matches_type(object, memory, path=["response"]) + + @parametrize + def test_streaming_response_create(self, client: LlamaStack) -> None: + with client.memory.with_streaming_response.create( + body={}, + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + memory = response.parse() + assert_matches_type(object, memory, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_method_retrieve(self, client: LlamaStack) -> None: + memory = client.memory.retrieve( + bank_id="bank_id", + ) + assert_matches_type(object, memory, path=["response"]) + + @parametrize + def test_method_retrieve_with_all_params(self, client: LlamaStack) -> None: + memory = client.memory.retrieve( + bank_id="bank_id", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(object, memory, path=["response"]) + + @parametrize + def test_raw_response_retrieve(self, client: LlamaStack) -> None: + response = client.memory.with_raw_response.retrieve( + bank_id="bank_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + memory = response.parse() + assert_matches_type(object, memory, path=["response"]) + + @parametrize + def test_streaming_response_retrieve(self, client: LlamaStack) -> None: + with client.memory.with_streaming_response.retrieve( + bank_id="bank_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + memory = response.parse() + assert_matches_type(object, memory, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_method_update(self, client: LlamaStack) -> None: + memory = client.memory.update( + bank_id="bank_id", + documents=[ + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + }, + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + }, + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + }, + ], + ) + assert memory is None + + @parametrize + def test_method_update_with_all_params(self, client: LlamaStack) -> None: + memory = client.memory.update( + bank_id="bank_id", + documents=[ + { + "content": 
"string", + "document_id": "document_id", + "metadata": {"foo": True}, + "mime_type": "mime_type", + }, + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + "mime_type": "mime_type", + }, + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + "mime_type": "mime_type", + }, + ], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert memory is None + + @parametrize + def test_raw_response_update(self, client: LlamaStack) -> None: + response = client.memory.with_raw_response.update( + bank_id="bank_id", + documents=[ + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + }, + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + }, + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + }, + ], + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + memory = response.parse() + assert memory is None + + @parametrize + def test_streaming_response_update(self, client: LlamaStack) -> None: + with client.memory.with_streaming_response.update( + bank_id="bank_id", + documents=[ + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + }, + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + }, + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + }, + ], + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + memory = response.parse() + assert memory is None + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_method_list(self, client: LlamaStack) -> None: + memory = client.memory.list() + assert_matches_type(object, memory, path=["response"]) + + @parametrize + def test_method_list_with_all_params(self, client: LlamaStack) -> None: + memory = client.memory.list( + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(object, memory, path=["response"]) + + @parametrize + def test_raw_response_list(self, client: LlamaStack) -> None: + response = client.memory.with_raw_response.list() + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + memory = response.parse() + assert_matches_type(object, memory, path=["response"]) + + @parametrize + def test_streaming_response_list(self, client: LlamaStack) -> None: + with client.memory.with_streaming_response.list() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + memory = response.parse() + assert_matches_type(object, memory, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_method_drop(self, client: LlamaStack) -> None: + memory = client.memory.drop( + bank_id="bank_id", + ) + assert_matches_type(str, memory, path=["response"]) + + @parametrize + def test_method_drop_with_all_params(self, client: LlamaStack) -> None: + memory = client.memory.drop( + bank_id="bank_id", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(str, memory, path=["response"]) + + @parametrize + def test_raw_response_drop(self, client: LlamaStack) -> None: + response = client.memory.with_raw_response.drop( + bank_id="bank_id", + ) + + assert response.is_closed is True + assert 
response.http_request.headers.get("X-Stainless-Lang") == "python" + memory = response.parse() + assert_matches_type(str, memory, path=["response"]) + + @parametrize + def test_streaming_response_drop(self, client: LlamaStack) -> None: + with client.memory.with_streaming_response.drop( + bank_id="bank_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + memory = response.parse() + assert_matches_type(str, memory, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_method_insert(self, client: LlamaStack) -> None: + memory = client.memory.insert( + bank_id="bank_id", + documents=[ + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + }, + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + }, + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + }, + ], + ) + assert memory is None + + @parametrize + def test_method_insert_with_all_params(self, client: LlamaStack) -> None: + memory = client.memory.insert( + bank_id="bank_id", + documents=[ + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + "mime_type": "mime_type", + }, + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + "mime_type": "mime_type", + }, + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + "mime_type": "mime_type", + }, + ], + ttl_seconds=0, + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert memory is None + + @parametrize + def test_raw_response_insert(self, client: LlamaStack) -> None: + response = client.memory.with_raw_response.insert( + bank_id="bank_id", + documents=[ + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + }, + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + }, + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + }, + ], + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + memory = response.parse() + assert memory is None + + @parametrize + def test_streaming_response_insert(self, client: LlamaStack) -> None: + with client.memory.with_streaming_response.insert( + bank_id="bank_id", + documents=[ + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + }, + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + }, + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + }, + ], + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + memory = response.parse() + assert memory is None + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_method_query(self, client: LlamaStack) -> None: + memory = client.memory.query( + bank_id="bank_id", + query="string", + ) + assert_matches_type(QueryDocuments, memory, path=["response"]) + + @parametrize + def test_method_query_with_all_params(self, client: LlamaStack) -> None: + memory = client.memory.query( + bank_id="bank_id", + query="string", + params={"foo": True}, + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(QueryDocuments, memory, path=["response"]) + + @parametrize + def test_raw_response_query(self, 
client: LlamaStack) -> None: + response = client.memory.with_raw_response.query( + bank_id="bank_id", + query="string", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + memory = response.parse() + assert_matches_type(QueryDocuments, memory, path=["response"]) + + @parametrize + def test_streaming_response_query(self, client: LlamaStack) -> None: + with client.memory.with_streaming_response.query( + bank_id="bank_id", + query="string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + memory = response.parse() + assert_matches_type(QueryDocuments, memory, path=["response"]) + + assert cast(Any, response.is_closed) is True + + +class TestAsyncMemory: + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + async def test_method_create(self, async_client: AsyncLlamaStack) -> None: + memory = await async_client.memory.create( + body={}, + ) + assert_matches_type(object, memory, path=["response"]) + + @parametrize + async def test_method_create_with_all_params(self, async_client: AsyncLlamaStack) -> None: + memory = await async_client.memory.create( + body={}, + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(object, memory, path=["response"]) + + @parametrize + async def test_raw_response_create(self, async_client: AsyncLlamaStack) -> None: + response = await async_client.memory.with_raw_response.create( + body={}, + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + memory = await response.parse() + assert_matches_type(object, memory, path=["response"]) + + @parametrize + async def test_streaming_response_create(self, async_client: AsyncLlamaStack) -> None: + async with async_client.memory.with_streaming_response.create( + body={}, + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + memory = await response.parse() + assert_matches_type(object, memory, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_method_retrieve(self, async_client: AsyncLlamaStack) -> None: + memory = await async_client.memory.retrieve( + bank_id="bank_id", + ) + assert_matches_type(object, memory, path=["response"]) + + @parametrize + async def test_method_retrieve_with_all_params(self, async_client: AsyncLlamaStack) -> None: + memory = await async_client.memory.retrieve( + bank_id="bank_id", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(object, memory, path=["response"]) + + @parametrize + async def test_raw_response_retrieve(self, async_client: AsyncLlamaStack) -> None: + response = await async_client.memory.with_raw_response.retrieve( + bank_id="bank_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + memory = await response.parse() + assert_matches_type(object, memory, path=["response"]) + + @parametrize + async def test_streaming_response_retrieve(self, async_client: AsyncLlamaStack) -> None: + async with async_client.memory.with_streaming_response.retrieve( + bank_id="bank_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + memory = await response.parse() + assert_matches_type(object, 
memory, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_method_update(self, async_client: AsyncLlamaStack) -> None: + memory = await async_client.memory.update( + bank_id="bank_id", + documents=[ + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + }, + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + }, + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + }, + ], + ) + assert memory is None + + @parametrize + async def test_method_update_with_all_params(self, async_client: AsyncLlamaStack) -> None: + memory = await async_client.memory.update( + bank_id="bank_id", + documents=[ + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + "mime_type": "mime_type", + }, + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + "mime_type": "mime_type", + }, + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + "mime_type": "mime_type", + }, + ], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert memory is None + + @parametrize + async def test_raw_response_update(self, async_client: AsyncLlamaStack) -> None: + response = await async_client.memory.with_raw_response.update( + bank_id="bank_id", + documents=[ + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + }, + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + }, + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + }, + ], + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + memory = await response.parse() + assert memory is None + + @parametrize + async def test_streaming_response_update(self, async_client: AsyncLlamaStack) -> None: + async with async_client.memory.with_streaming_response.update( + bank_id="bank_id", + documents=[ + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + }, + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + }, + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + }, + ], + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + memory = await response.parse() + assert memory is None + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_method_list(self, async_client: AsyncLlamaStack) -> None: + memory = await async_client.memory.list() + assert_matches_type(object, memory, path=["response"]) + + @parametrize + async def test_method_list_with_all_params(self, async_client: AsyncLlamaStack) -> None: + memory = await async_client.memory.list( + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(object, memory, path=["response"]) + + @parametrize + async def test_raw_response_list(self, async_client: AsyncLlamaStack) -> None: + response = await async_client.memory.with_raw_response.list() + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + memory = await response.parse() + assert_matches_type(object, memory, path=["response"]) + + @parametrize + async def test_streaming_response_list(self, async_client: AsyncLlamaStack) -> None: + async with 
async_client.memory.with_streaming_response.list() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + memory = await response.parse() + assert_matches_type(object, memory, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_method_drop(self, async_client: AsyncLlamaStack) -> None: + memory = await async_client.memory.drop( + bank_id="bank_id", + ) + assert_matches_type(str, memory, path=["response"]) + + @parametrize + async def test_method_drop_with_all_params(self, async_client: AsyncLlamaStack) -> None: + memory = await async_client.memory.drop( + bank_id="bank_id", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(str, memory, path=["response"]) + + @parametrize + async def test_raw_response_drop(self, async_client: AsyncLlamaStack) -> None: + response = await async_client.memory.with_raw_response.drop( + bank_id="bank_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + memory = await response.parse() + assert_matches_type(str, memory, path=["response"]) + + @parametrize + async def test_streaming_response_drop(self, async_client: AsyncLlamaStack) -> None: + async with async_client.memory.with_streaming_response.drop( + bank_id="bank_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + memory = await response.parse() + assert_matches_type(str, memory, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_method_insert(self, async_client: AsyncLlamaStack) -> None: + memory = await async_client.memory.insert( + bank_id="bank_id", + documents=[ + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + }, + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + }, + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + }, + ], + ) + assert memory is None + + @parametrize + async def test_method_insert_with_all_params(self, async_client: AsyncLlamaStack) -> None: + memory = await async_client.memory.insert( + bank_id="bank_id", + documents=[ + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + "mime_type": "mime_type", + }, + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + "mime_type": "mime_type", + }, + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + "mime_type": "mime_type", + }, + ], + ttl_seconds=0, + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert memory is None + + @parametrize + async def test_raw_response_insert(self, async_client: AsyncLlamaStack) -> None: + response = await async_client.memory.with_raw_response.insert( + bank_id="bank_id", + documents=[ + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + }, + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + }, + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + }, + ], + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + memory = await response.parse() + assert memory is None + + @parametrize + async def test_streaming_response_insert(self, async_client: 
AsyncLlamaStack) -> None: + async with async_client.memory.with_streaming_response.insert( + bank_id="bank_id", + documents=[ + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + }, + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + }, + { + "content": "string", + "document_id": "document_id", + "metadata": {"foo": True}, + }, + ], + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + memory = await response.parse() + assert memory is None + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_method_query(self, async_client: AsyncLlamaStack) -> None: + memory = await async_client.memory.query( + bank_id="bank_id", + query="string", + ) + assert_matches_type(QueryDocuments, memory, path=["response"]) + + @parametrize + async def test_method_query_with_all_params(self, async_client: AsyncLlamaStack) -> None: + memory = await async_client.memory.query( + bank_id="bank_id", + query="string", + params={"foo": True}, + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(QueryDocuments, memory, path=["response"]) + + @parametrize + async def test_raw_response_query(self, async_client: AsyncLlamaStack) -> None: + response = await async_client.memory.with_raw_response.query( + bank_id="bank_id", + query="string", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + memory = await response.parse() + assert_matches_type(QueryDocuments, memory, path=["response"]) + + @parametrize + async def test_streaming_response_query(self, async_client: AsyncLlamaStack) -> None: + async with async_client.memory.with_streaming_response.query( + bank_id="bank_id", + query="string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + memory = await response.parse() + assert_matches_type(QueryDocuments, memory, path=["response"]) + + assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_memory_banks.py b/tests/api_resources/test_memory_banks.py index 9b3f433..93708b1 100644 --- a/tests/api_resources/test_memory_banks.py +++ b/tests/api_resources/test_memory_banks.py @@ -3,15 +3,13 @@ from __future__ import annotations import os -from typing import Any, cast +from typing import Any, Optional, cast import pytest from llama_stack import LlamaStack, AsyncLlamaStack from tests.utils import assert_matches_type -from llama_stack.types import ( - QueryDocuments, -) +from llama_stack.types import MemoryBankSpec base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -20,153 +18,16 @@ class TestMemoryBanks: parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize - def test_method_create(self, client: LlamaStack) -> None: - memory_bank = client.memory_banks.create( - body={}, - ) - assert_matches_type(object, memory_bank, path=["response"]) - - @parametrize - def test_raw_response_create(self, client: LlamaStack) -> None: - response = client.memory_banks.with_raw_response.create( - body={}, - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - memory_bank = response.parse() - assert_matches_type(object, memory_bank, path=["response"]) - - @parametrize - def test_streaming_response_create(self, client: LlamaStack) 
-> None: - with client.memory_banks.with_streaming_response.create( - body={}, - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - memory_bank = response.parse() - assert_matches_type(object, memory_bank, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @parametrize - def test_method_retrieve(self, client: LlamaStack) -> None: - memory_bank = client.memory_banks.retrieve( - bank_id="bank_id", - ) - assert_matches_type(object, memory_bank, path=["response"]) - - @parametrize - def test_raw_response_retrieve(self, client: LlamaStack) -> None: - response = client.memory_banks.with_raw_response.retrieve( - bank_id="bank_id", - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - memory_bank = response.parse() - assert_matches_type(object, memory_bank, path=["response"]) - - @parametrize - def test_streaming_response_retrieve(self, client: LlamaStack) -> None: - with client.memory_banks.with_streaming_response.retrieve( - bank_id="bank_id", - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - memory_bank = response.parse() - assert_matches_type(object, memory_bank, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @parametrize - def test_method_update(self, client: LlamaStack) -> None: - memory_bank = client.memory_banks.update( - bank_id="bank_id", - documents=[ - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - ], - ) - assert memory_bank is None + def test_method_list(self, client: LlamaStack) -> None: + memory_bank = client.memory_banks.list() + assert_matches_type(MemoryBankSpec, memory_bank, path=["response"]) @parametrize - def test_raw_response_update(self, client: LlamaStack) -> None: - response = client.memory_banks.with_raw_response.update( - bank_id="bank_id", - documents=[ - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - ], + def test_method_list_with_all_params(self, client: LlamaStack) -> None: + memory_bank = client.memory_banks.list( + x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - memory_bank = response.parse() - assert memory_bank is None - - @parametrize - def test_streaming_response_update(self, client: LlamaStack) -> None: - with client.memory_banks.with_streaming_response.update( - bank_id="bank_id", - documents=[ - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - ], - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - memory_bank = response.parse() - assert memory_bank is None - - assert cast(Any, response.is_closed) is True - - 
@parametrize - def test_method_list(self, client: LlamaStack) -> None: - memory_bank = client.memory_banks.list() - assert_matches_type(object, memory_bank, path=["response"]) + assert_matches_type(MemoryBankSpec, memory_bank, path=["response"]) @parametrize def test_raw_response_list(self, client: LlamaStack) -> None: @@ -175,7 +36,7 @@ def test_raw_response_list(self, client: LlamaStack) -> None: assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" memory_bank = response.parse() - assert_matches_type(object, memory_bank, path=["response"]) + assert_matches_type(MemoryBankSpec, memory_bank, path=["response"]) @parametrize def test_streaming_response_list(self, client: LlamaStack) -> None: @@ -184,191 +45,46 @@ def test_streaming_response_list(self, client: LlamaStack) -> None: assert response.http_request.headers.get("X-Stainless-Lang") == "python" memory_bank = response.parse() - assert_matches_type(object, memory_bank, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @parametrize - def test_method_drop(self, client: LlamaStack) -> None: - memory_bank = client.memory_banks.drop( - bank_id="bank_id", - ) - assert_matches_type(str, memory_bank, path=["response"]) - - @parametrize - def test_raw_response_drop(self, client: LlamaStack) -> None: - response = client.memory_banks.with_raw_response.drop( - bank_id="bank_id", - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - memory_bank = response.parse() - assert_matches_type(str, memory_bank, path=["response"]) - - @parametrize - def test_streaming_response_drop(self, client: LlamaStack) -> None: - with client.memory_banks.with_streaming_response.drop( - bank_id="bank_id", - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - memory_bank = response.parse() - assert_matches_type(str, memory_bank, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @parametrize - def test_method_insert(self, client: LlamaStack) -> None: - memory_bank = client.memory_banks.insert( - bank_id="bank_id", - documents=[ - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - ], - ) - assert memory_bank is None - - @parametrize - def test_method_insert_with_all_params(self, client: LlamaStack) -> None: - memory_bank = client.memory_banks.insert( - bank_id="bank_id", - documents=[ - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - "mime_type": "mime_type", - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - "mime_type": "mime_type", - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - "mime_type": "mime_type", - }, - ], - ttl_seconds=0, - ) - assert memory_bank is None - - @parametrize - def test_raw_response_insert(self, client: LlamaStack) -> None: - response = client.memory_banks.with_raw_response.insert( - bank_id="bank_id", - documents=[ - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - { - "content": "string", - "document_id": "document_id", - 
"metadata": {"foo": True}, - }, - ], - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - memory_bank = response.parse() - assert memory_bank is None - - @parametrize - def test_streaming_response_insert(self, client: LlamaStack) -> None: - with client.memory_banks.with_streaming_response.insert( - bank_id="bank_id", - documents=[ - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - ], - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - memory_bank = response.parse() - assert memory_bank is None + assert_matches_type(MemoryBankSpec, memory_bank, path=["response"]) assert cast(Any, response.is_closed) is True @parametrize - def test_method_query(self, client: LlamaStack) -> None: - memory_bank = client.memory_banks.query( - bank_id="bank_id", - query="string", + def test_method_get(self, client: LlamaStack) -> None: + memory_bank = client.memory_banks.get( + bank_type="vector", ) - assert_matches_type(QueryDocuments, memory_bank, path=["response"]) + assert_matches_type(Optional[MemoryBankSpec], memory_bank, path=["response"]) @parametrize - def test_method_query_with_all_params(self, client: LlamaStack) -> None: - memory_bank = client.memory_banks.query( - bank_id="bank_id", - query="string", - params={"foo": True}, + def test_method_get_with_all_params(self, client: LlamaStack) -> None: + memory_bank = client.memory_banks.get( + bank_type="vector", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert_matches_type(QueryDocuments, memory_bank, path=["response"]) + assert_matches_type(Optional[MemoryBankSpec], memory_bank, path=["response"]) @parametrize - def test_raw_response_query(self, client: LlamaStack) -> None: - response = client.memory_banks.with_raw_response.query( - bank_id="bank_id", - query="string", + def test_raw_response_get(self, client: LlamaStack) -> None: + response = client.memory_banks.with_raw_response.get( + bank_type="vector", ) assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" memory_bank = response.parse() - assert_matches_type(QueryDocuments, memory_bank, path=["response"]) + assert_matches_type(Optional[MemoryBankSpec], memory_bank, path=["response"]) @parametrize - def test_streaming_response_query(self, client: LlamaStack) -> None: - with client.memory_banks.with_streaming_response.query( - bank_id="bank_id", - query="string", + def test_streaming_response_get(self, client: LlamaStack) -> None: + with client.memory_banks.with_streaming_response.get( + bank_type="vector", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" memory_bank = response.parse() - assert_matches_type(QueryDocuments, memory_bank, path=["response"]) + assert_matches_type(Optional[MemoryBankSpec], memory_bank, path=["response"]) assert cast(Any, response.is_closed) is True @@ -377,153 +93,16 @@ class TestAsyncMemoryBanks: parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize - async def test_method_create(self, async_client: AsyncLlamaStack) -> None: - memory_bank = await async_client.memory_banks.create( - body={}, - ) - 
assert_matches_type(object, memory_bank, path=["response"]) - - @parametrize - async def test_raw_response_create(self, async_client: AsyncLlamaStack) -> None: - response = await async_client.memory_banks.with_raw_response.create( - body={}, - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - memory_bank = await response.parse() - assert_matches_type(object, memory_bank, path=["response"]) - - @parametrize - async def test_streaming_response_create(self, async_client: AsyncLlamaStack) -> None: - async with async_client.memory_banks.with_streaming_response.create( - body={}, - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - memory_bank = await response.parse() - assert_matches_type(object, memory_bank, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @parametrize - async def test_method_retrieve(self, async_client: AsyncLlamaStack) -> None: - memory_bank = await async_client.memory_banks.retrieve( - bank_id="bank_id", - ) - assert_matches_type(object, memory_bank, path=["response"]) - - @parametrize - async def test_raw_response_retrieve(self, async_client: AsyncLlamaStack) -> None: - response = await async_client.memory_banks.with_raw_response.retrieve( - bank_id="bank_id", - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - memory_bank = await response.parse() - assert_matches_type(object, memory_bank, path=["response"]) - - @parametrize - async def test_streaming_response_retrieve(self, async_client: AsyncLlamaStack) -> None: - async with async_client.memory_banks.with_streaming_response.retrieve( - bank_id="bank_id", - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - memory_bank = await response.parse() - assert_matches_type(object, memory_bank, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @parametrize - async def test_method_update(self, async_client: AsyncLlamaStack) -> None: - memory_bank = await async_client.memory_banks.update( - bank_id="bank_id", - documents=[ - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - ], - ) - assert memory_bank is None + async def test_method_list(self, async_client: AsyncLlamaStack) -> None: + memory_bank = await async_client.memory_banks.list() + assert_matches_type(MemoryBankSpec, memory_bank, path=["response"]) @parametrize - async def test_raw_response_update(self, async_client: AsyncLlamaStack) -> None: - response = await async_client.memory_banks.with_raw_response.update( - bank_id="bank_id", - documents=[ - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - ], + async def test_method_list_with_all_params(self, async_client: AsyncLlamaStack) -> None: + memory_bank = await async_client.memory_banks.list( + x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == 
"python" - memory_bank = await response.parse() - assert memory_bank is None - - @parametrize - async def test_streaming_response_update(self, async_client: AsyncLlamaStack) -> None: - async with async_client.memory_banks.with_streaming_response.update( - bank_id="bank_id", - documents=[ - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - ], - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - memory_bank = await response.parse() - assert memory_bank is None - - assert cast(Any, response.is_closed) is True - - @parametrize - async def test_method_list(self, async_client: AsyncLlamaStack) -> None: - memory_bank = await async_client.memory_banks.list() - assert_matches_type(object, memory_bank, path=["response"]) + assert_matches_type(MemoryBankSpec, memory_bank, path=["response"]) @parametrize async def test_raw_response_list(self, async_client: AsyncLlamaStack) -> None: @@ -532,7 +111,7 @@ async def test_raw_response_list(self, async_client: AsyncLlamaStack) -> None: assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" memory_bank = await response.parse() - assert_matches_type(object, memory_bank, path=["response"]) + assert_matches_type(MemoryBankSpec, memory_bank, path=["response"]) @parametrize async def test_streaming_response_list(self, async_client: AsyncLlamaStack) -> None: @@ -541,190 +120,45 @@ async def test_streaming_response_list(self, async_client: AsyncLlamaStack) -> N assert response.http_request.headers.get("X-Stainless-Lang") == "python" memory_bank = await response.parse() - assert_matches_type(object, memory_bank, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @parametrize - async def test_method_drop(self, async_client: AsyncLlamaStack) -> None: - memory_bank = await async_client.memory_banks.drop( - bank_id="bank_id", - ) - assert_matches_type(str, memory_bank, path=["response"]) - - @parametrize - async def test_raw_response_drop(self, async_client: AsyncLlamaStack) -> None: - response = await async_client.memory_banks.with_raw_response.drop( - bank_id="bank_id", - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - memory_bank = await response.parse() - assert_matches_type(str, memory_bank, path=["response"]) - - @parametrize - async def test_streaming_response_drop(self, async_client: AsyncLlamaStack) -> None: - async with async_client.memory_banks.with_streaming_response.drop( - bank_id="bank_id", - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - memory_bank = await response.parse() - assert_matches_type(str, memory_bank, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @parametrize - async def test_method_insert(self, async_client: AsyncLlamaStack) -> None: - memory_bank = await async_client.memory_banks.insert( - bank_id="bank_id", - documents=[ - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - ], 
- ) - assert memory_bank is None - - @parametrize - async def test_method_insert_with_all_params(self, async_client: AsyncLlamaStack) -> None: - memory_bank = await async_client.memory_banks.insert( - bank_id="bank_id", - documents=[ - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - "mime_type": "mime_type", - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - "mime_type": "mime_type", - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - "mime_type": "mime_type", - }, - ], - ttl_seconds=0, - ) - assert memory_bank is None - - @parametrize - async def test_raw_response_insert(self, async_client: AsyncLlamaStack) -> None: - response = await async_client.memory_banks.with_raw_response.insert( - bank_id="bank_id", - documents=[ - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - ], - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - memory_bank = await response.parse() - assert memory_bank is None - - @parametrize - async def test_streaming_response_insert(self, async_client: AsyncLlamaStack) -> None: - async with async_client.memory_banks.with_streaming_response.insert( - bank_id="bank_id", - documents=[ - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - ], - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - memory_bank = await response.parse() - assert memory_bank is None + assert_matches_type(MemoryBankSpec, memory_bank, path=["response"]) assert cast(Any, response.is_closed) is True @parametrize - async def test_method_query(self, async_client: AsyncLlamaStack) -> None: - memory_bank = await async_client.memory_banks.query( - bank_id="bank_id", - query="string", + async def test_method_get(self, async_client: AsyncLlamaStack) -> None: + memory_bank = await async_client.memory_banks.get( + bank_type="vector", ) - assert_matches_type(QueryDocuments, memory_bank, path=["response"]) + assert_matches_type(Optional[MemoryBankSpec], memory_bank, path=["response"]) @parametrize - async def test_method_query_with_all_params(self, async_client: AsyncLlamaStack) -> None: - memory_bank = await async_client.memory_banks.query( - bank_id="bank_id", - query="string", - params={"foo": True}, + async def test_method_get_with_all_params(self, async_client: AsyncLlamaStack) -> None: + memory_bank = await async_client.memory_banks.get( + bank_type="vector", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert_matches_type(QueryDocuments, memory_bank, path=["response"]) + assert_matches_type(Optional[MemoryBankSpec], memory_bank, path=["response"]) @parametrize - async def test_raw_response_query(self, async_client: AsyncLlamaStack) -> None: - response = await async_client.memory_banks.with_raw_response.query( - bank_id="bank_id", - query="string", + async def test_raw_response_get(self, async_client: AsyncLlamaStack) -> None: + response = await 
async_client.memory_banks.with_raw_response.get( + bank_type="vector", ) assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" memory_bank = await response.parse() - assert_matches_type(QueryDocuments, memory_bank, path=["response"]) + assert_matches_type(Optional[MemoryBankSpec], memory_bank, path=["response"]) @parametrize - async def test_streaming_response_query(self, async_client: AsyncLlamaStack) -> None: - async with async_client.memory_banks.with_streaming_response.query( - bank_id="bank_id", - query="string", + async def test_streaming_response_get(self, async_client: AsyncLlamaStack) -> None: + async with async_client.memory_banks.with_streaming_response.get( + bank_type="vector", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" memory_bank = await response.parse() - assert_matches_type(QueryDocuments, memory_bank, path=["response"]) + assert_matches_type(Optional[MemoryBankSpec], memory_bank, path=["response"]) assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_models.py b/tests/api_resources/test_models.py new file mode 100644 index 0000000..47be556 --- /dev/null +++ b/tests/api_resources/test_models.py @@ -0,0 +1,164 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import os +from typing import Any, Optional, cast + +import pytest + +from llama_stack import LlamaStack, AsyncLlamaStack +from tests.utils import assert_matches_type +from llama_stack.types import ModelServingSpec + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + + +class TestModels: + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + def test_method_list(self, client: LlamaStack) -> None: + model = client.models.list() + assert_matches_type(ModelServingSpec, model, path=["response"]) + + @parametrize + def test_method_list_with_all_params(self, client: LlamaStack) -> None: + model = client.models.list( + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(ModelServingSpec, model, path=["response"]) + + @parametrize + def test_raw_response_list(self, client: LlamaStack) -> None: + response = client.models.with_raw_response.list() + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + model = response.parse() + assert_matches_type(ModelServingSpec, model, path=["response"]) + + @parametrize + def test_streaming_response_list(self, client: LlamaStack) -> None: + with client.models.with_streaming_response.list() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + model = response.parse() + assert_matches_type(ModelServingSpec, model, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_method_get(self, client: LlamaStack) -> None: + model = client.models.get( + core_model_id="core_model_id", + ) + assert_matches_type(Optional[ModelServingSpec], model, path=["response"]) + + @parametrize + def test_method_get_with_all_params(self, client: LlamaStack) -> None: + model = client.models.get( + core_model_id="core_model_id", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(Optional[ModelServingSpec], model, path=["response"]) + + @parametrize + def 
test_raw_response_get(self, client: LlamaStack) -> None: + response = client.models.with_raw_response.get( + core_model_id="core_model_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + model = response.parse() + assert_matches_type(Optional[ModelServingSpec], model, path=["response"]) + + @parametrize + def test_streaming_response_get(self, client: LlamaStack) -> None: + with client.models.with_streaming_response.get( + core_model_id="core_model_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + model = response.parse() + assert_matches_type(Optional[ModelServingSpec], model, path=["response"]) + + assert cast(Any, response.is_closed) is True + + +class TestAsyncModels: + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + async def test_method_list(self, async_client: AsyncLlamaStack) -> None: + model = await async_client.models.list() + assert_matches_type(ModelServingSpec, model, path=["response"]) + + @parametrize + async def test_method_list_with_all_params(self, async_client: AsyncLlamaStack) -> None: + model = await async_client.models.list( + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(ModelServingSpec, model, path=["response"]) + + @parametrize + async def test_raw_response_list(self, async_client: AsyncLlamaStack) -> None: + response = await async_client.models.with_raw_response.list() + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + model = await response.parse() + assert_matches_type(ModelServingSpec, model, path=["response"]) + + @parametrize + async def test_streaming_response_list(self, async_client: AsyncLlamaStack) -> None: + async with async_client.models.with_streaming_response.list() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + model = await response.parse() + assert_matches_type(ModelServingSpec, model, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_method_get(self, async_client: AsyncLlamaStack) -> None: + model = await async_client.models.get( + core_model_id="core_model_id", + ) + assert_matches_type(Optional[ModelServingSpec], model, path=["response"]) + + @parametrize + async def test_method_get_with_all_params(self, async_client: AsyncLlamaStack) -> None: + model = await async_client.models.get( + core_model_id="core_model_id", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(Optional[ModelServingSpec], model, path=["response"]) + + @parametrize + async def test_raw_response_get(self, async_client: AsyncLlamaStack) -> None: + response = await async_client.models.with_raw_response.get( + core_model_id="core_model_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + model = await response.parse() + assert_matches_type(Optional[ModelServingSpec], model, path=["response"]) + + @parametrize + async def test_streaming_response_get(self, async_client: AsyncLlamaStack) -> None: + async with async_client.models.with_streaming_response.get( + core_model_id="core_model_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + model = 
await response.parse() + assert_matches_type(Optional[ModelServingSpec], model, path=["response"]) + + assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_post_training.py b/tests/api_resources/test_post_training.py index 588f159..9cee8d0 100644 --- a/tests/api_resources/test_post_training.py +++ b/tests/api_resources/test_post_training.py @@ -98,6 +98,7 @@ def test_method_preference_optimize_with_all_params(self, client: LlamaStack) -> "content_url": "https://example.com", "metadata": {"foo": True}, }, + x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) assert_matches_type(PostTrainingJob, post_training, path=["response"]) @@ -272,6 +273,7 @@ def test_method_supervised_fine_tune_with_all_params(self, client: LlamaStack) - "content_url": "https://example.com", "metadata": {"foo": True}, }, + x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) assert_matches_type(PostTrainingJob, post_training, path=["response"]) @@ -450,6 +452,7 @@ async def test_method_preference_optimize_with_all_params(self, async_client: As "content_url": "https://example.com", "metadata": {"foo": True}, }, + x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) assert_matches_type(PostTrainingJob, post_training, path=["response"]) @@ -624,6 +627,7 @@ async def test_method_supervised_fine_tune_with_all_params(self, async_client: A "content_url": "https://example.com", "metadata": {"foo": True}, }, + x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) assert_matches_type(PostTrainingJob, post_training, path=["response"]) diff --git a/tests/api_resources/test_reward_scoring.py b/tests/api_resources/test_reward_scoring.py index 83f6983..eaa9e43 100644 --- a/tests/api_resources/test_reward_scoring.py +++ b/tests/api_resources/test_reward_scoring.py @@ -116,6 +116,124 @@ def test_method_score(self, client: LlamaStack) -> None: ) assert_matches_type(RewardScoring, reward_scoring, path=["response"]) + @parametrize + def test_method_score_with_all_params(self, client: LlamaStack) -> None: + reward_scoring = client.reward_scoring.score( + dialog_generations=[ + { + "dialog": [ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + ], + "sampled_generations": [ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + ], + }, + { + "dialog": [ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + ], + "sampled_generations": [ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + ], + }, + { + "dialog": [ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + ], + "sampled_generations": [ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", 
+ "role": "user", + "context": "string", + }, + ], + }, + ], + model="model", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(RewardScoring, reward_scoring, path=["response"]) + @parametrize def test_raw_response_score(self, client: LlamaStack) -> None: response = client.reward_scoring.with_raw_response.score( @@ -427,6 +545,124 @@ async def test_method_score(self, async_client: AsyncLlamaStack) -> None: ) assert_matches_type(RewardScoring, reward_scoring, path=["response"]) + @parametrize + async def test_method_score_with_all_params(self, async_client: AsyncLlamaStack) -> None: + reward_scoring = await async_client.reward_scoring.score( + dialog_generations=[ + { + "dialog": [ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + ], + "sampled_generations": [ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + ], + }, + { + "dialog": [ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + ], + "sampled_generations": [ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + ], + }, + { + "dialog": [ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + ], + "sampled_generations": [ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + ], + }, + ], + model="model", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(RewardScoring, reward_scoring, path=["response"]) + @parametrize async def test_raw_response_score(self, async_client: AsyncLlamaStack) -> None: response = await async_client.reward_scoring.with_raw_response.score( diff --git a/tests/api_resources/test_safety.py b/tests/api_resources/test_safety.py index f4b44dc..eccaf8a 100644 --- a/tests/api_resources/test_safety.py +++ b/tests/api_resources/test_safety.py @@ -9,7 +9,7 @@ from llama_stack import LlamaStack, AsyncLlamaStack from tests.utils import assert_matches_type -from llama_stack.types import SafetyRunShieldsResponse +from llama_stack.types import RunSheidResponse base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -18,8 +18,8 @@ class TestSafety: parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize - def test_method_run_shields(self, client: LlamaStack) -> None: - safety = client.safety.run_shields( + def test_method_run_shield(self, client: LlamaStack) -> None: + safety = client.safety.run_shield( messages=[ { "content": "string", @@ -34,64 +34,66 @@ def test_method_run_shields(self, client: LlamaStack) -> None: "role": "user", }, ], - shields=[ - { - "on_violation_action": 0, - 
"shield_type": "llama_guard", - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - }, - ], + params={"foo": True}, + shield_type="shield_type", ) - assert_matches_type(SafetyRunShieldsResponse, safety, path=["response"]) + assert_matches_type(RunSheidResponse, safety, path=["response"]) @parametrize - def test_raw_response_run_shields(self, client: LlamaStack) -> None: - response = client.safety.with_raw_response.run_shields( + def test_method_run_shield_with_all_params(self, client: LlamaStack) -> None: + safety = client.safety.run_shield( messages=[ { "content": "string", "role": "user", + "context": "string", }, { "content": "string", "role": "user", + "context": "string", }, { "content": "string", "role": "user", + "context": "string", }, ], - shields=[ + params={"foo": True}, + shield_type="shield_type", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(RunSheidResponse, safety, path=["response"]) + + @parametrize + def test_raw_response_run_shield(self, client: LlamaStack) -> None: + response = client.safety.with_raw_response.run_shield( + messages=[ { - "on_violation_action": 0, - "shield_type": "llama_guard", + "content": "string", + "role": "user", }, { - "on_violation_action": 0, - "shield_type": "llama_guard", + "content": "string", + "role": "user", }, { - "on_violation_action": 0, - "shield_type": "llama_guard", + "content": "string", + "role": "user", }, ], + params={"foo": True}, + shield_type="shield_type", ) assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" safety = response.parse() - assert_matches_type(SafetyRunShieldsResponse, safety, path=["response"]) + assert_matches_type(RunSheidResponse, safety, path=["response"]) @parametrize - def test_streaming_response_run_shields(self, client: LlamaStack) -> None: - with client.safety.with_streaming_response.run_shields( + def test_streaming_response_run_shield(self, client: LlamaStack) -> None: + with client.safety.with_streaming_response.run_shield( messages=[ { "content": "string", @@ -106,26 +108,14 @@ def test_streaming_response_run_shields(self, client: LlamaStack) -> None: "role": "user", }, ], - shields=[ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - }, - ], + params={"foo": True}, + shield_type="shield_type", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" safety = response.parse() - assert_matches_type(SafetyRunShieldsResponse, safety, path=["response"]) + assert_matches_type(RunSheidResponse, safety, path=["response"]) assert cast(Any, response.is_closed) is True @@ -134,8 +124,8 @@ class TestAsyncSafety: parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize - async def test_method_run_shields(self, async_client: AsyncLlamaStack) -> None: - safety = await async_client.safety.run_shields( + async def test_method_run_shield(self, async_client: AsyncLlamaStack) -> None: + safety = await async_client.safety.run_shield( messages=[ { "content": "string", @@ -150,64 +140,66 @@ async def test_method_run_shields(self, async_client: AsyncLlamaStack) -> None: "role": "user", }, ], - shields=[ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - }, 
- { - "on_violation_action": 0, - "shield_type": "llama_guard", - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - }, - ], + params={"foo": True}, + shield_type="shield_type", ) - assert_matches_type(SafetyRunShieldsResponse, safety, path=["response"]) + assert_matches_type(RunSheidResponse, safety, path=["response"]) @parametrize - async def test_raw_response_run_shields(self, async_client: AsyncLlamaStack) -> None: - response = await async_client.safety.with_raw_response.run_shields( + async def test_method_run_shield_with_all_params(self, async_client: AsyncLlamaStack) -> None: + safety = await async_client.safety.run_shield( messages=[ { "content": "string", "role": "user", + "context": "string", }, { "content": "string", "role": "user", + "context": "string", }, { "content": "string", "role": "user", + "context": "string", }, ], - shields=[ + params={"foo": True}, + shield_type="shield_type", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(RunSheidResponse, safety, path=["response"]) + + @parametrize + async def test_raw_response_run_shield(self, async_client: AsyncLlamaStack) -> None: + response = await async_client.safety.with_raw_response.run_shield( + messages=[ { - "on_violation_action": 0, - "shield_type": "llama_guard", + "content": "string", + "role": "user", }, { - "on_violation_action": 0, - "shield_type": "llama_guard", + "content": "string", + "role": "user", }, { - "on_violation_action": 0, - "shield_type": "llama_guard", + "content": "string", + "role": "user", }, ], + params={"foo": True}, + shield_type="shield_type", ) assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" safety = await response.parse() - assert_matches_type(SafetyRunShieldsResponse, safety, path=["response"]) + assert_matches_type(RunSheidResponse, safety, path=["response"]) @parametrize - async def test_streaming_response_run_shields(self, async_client: AsyncLlamaStack) -> None: - async with async_client.safety.with_streaming_response.run_shields( + async def test_streaming_response_run_shield(self, async_client: AsyncLlamaStack) -> None: + async with async_client.safety.with_streaming_response.run_shield( messages=[ { "content": "string", @@ -222,25 +214,13 @@ async def test_streaming_response_run_shields(self, async_client: AsyncLlamaStac "role": "user", }, ], - shields=[ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - }, - ], + params={"foo": True}, + shield_type="shield_type", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" safety = await response.parse() - assert_matches_type(SafetyRunShieldsResponse, safety, path=["response"]) + assert_matches_type(RunSheidResponse, safety, path=["response"]) assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_shields.py b/tests/api_resources/test_shields.py new file mode 100644 index 0000000..fdc6728 --- /dev/null +++ b/tests/api_resources/test_shields.py @@ -0,0 +1,164 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
+ +from __future__ import annotations + +import os +from typing import Any, Optional, cast + +import pytest + +from llama_stack import LlamaStack, AsyncLlamaStack +from tests.utils import assert_matches_type +from llama_stack.types import ShieldSpec + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + + +class TestShields: + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + def test_method_list(self, client: LlamaStack) -> None: + shield = client.shields.list() + assert_matches_type(ShieldSpec, shield, path=["response"]) + + @parametrize + def test_method_list_with_all_params(self, client: LlamaStack) -> None: + shield = client.shields.list( + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(ShieldSpec, shield, path=["response"]) + + @parametrize + def test_raw_response_list(self, client: LlamaStack) -> None: + response = client.shields.with_raw_response.list() + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + shield = response.parse() + assert_matches_type(ShieldSpec, shield, path=["response"]) + + @parametrize + def test_streaming_response_list(self, client: LlamaStack) -> None: + with client.shields.with_streaming_response.list() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + shield = response.parse() + assert_matches_type(ShieldSpec, shield, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_method_get(self, client: LlamaStack) -> None: + shield = client.shields.get( + shield_type="shield_type", + ) + assert_matches_type(Optional[ShieldSpec], shield, path=["response"]) + + @parametrize + def test_method_get_with_all_params(self, client: LlamaStack) -> None: + shield = client.shields.get( + shield_type="shield_type", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(Optional[ShieldSpec], shield, path=["response"]) + + @parametrize + def test_raw_response_get(self, client: LlamaStack) -> None: + response = client.shields.with_raw_response.get( + shield_type="shield_type", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + shield = response.parse() + assert_matches_type(Optional[ShieldSpec], shield, path=["response"]) + + @parametrize + def test_streaming_response_get(self, client: LlamaStack) -> None: + with client.shields.with_streaming_response.get( + shield_type="shield_type", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + shield = response.parse() + assert_matches_type(Optional[ShieldSpec], shield, path=["response"]) + + assert cast(Any, response.is_closed) is True + + +class TestAsyncShields: + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + async def test_method_list(self, async_client: AsyncLlamaStack) -> None: + shield = await async_client.shields.list() + assert_matches_type(ShieldSpec, shield, path=["response"]) + + @parametrize + async def test_method_list_with_all_params(self, async_client: AsyncLlamaStack) -> None: + shield = await async_client.shields.list( + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(ShieldSpec, shield, path=["response"]) + + @parametrize + async 
def test_raw_response_list(self, async_client: AsyncLlamaStack) -> None: + response = await async_client.shields.with_raw_response.list() + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + shield = await response.parse() + assert_matches_type(ShieldSpec, shield, path=["response"]) + + @parametrize + async def test_streaming_response_list(self, async_client: AsyncLlamaStack) -> None: + async with async_client.shields.with_streaming_response.list() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + shield = await response.parse() + assert_matches_type(ShieldSpec, shield, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_method_get(self, async_client: AsyncLlamaStack) -> None: + shield = await async_client.shields.get( + shield_type="shield_type", + ) + assert_matches_type(Optional[ShieldSpec], shield, path=["response"]) + + @parametrize + async def test_method_get_with_all_params(self, async_client: AsyncLlamaStack) -> None: + shield = await async_client.shields.get( + shield_type="shield_type", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(Optional[ShieldSpec], shield, path=["response"]) + + @parametrize + async def test_raw_response_get(self, async_client: AsyncLlamaStack) -> None: + response = await async_client.shields.with_raw_response.get( + shield_type="shield_type", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + shield = await response.parse() + assert_matches_type(Optional[ShieldSpec], shield, path=["response"]) + + @parametrize + async def test_streaming_response_get(self, async_client: AsyncLlamaStack) -> None: + async with async_client.shields.with_streaming_response.get( + shield_type="shield_type", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + shield = await response.parse() + assert_matches_type(Optional[ShieldSpec], shield, path=["response"]) + + assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_synthetic_data_generation.py b/tests/api_resources/test_synthetic_data_generation.py index 7f5d73b..5c4bd93 100644 --- a/tests/api_resources/test_synthetic_data_generation.py +++ b/tests/api_resources/test_synthetic_data_generation.py @@ -60,6 +60,7 @@ def test_method_generate_with_all_params(self, client: LlamaStack) -> None: ], filtering_function="none", model="model", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) assert_matches_type(SyntheticDataGeneration, synthetic_data_generation, path=["response"]) @@ -162,6 +163,7 @@ async def test_method_generate_with_all_params(self, async_client: AsyncLlamaSta ], filtering_function="none", model="model", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) assert_matches_type(SyntheticDataGeneration, synthetic_data_generation, path=["response"]) diff --git a/tests/api_resources/test_telemetry.py b/tests/api_resources/test_telemetry.py index 317bb8b..b7a7bda 100644 --- a/tests/api_resources/test_telemetry.py +++ b/tests/api_resources/test_telemetry.py @@ -25,6 +25,14 @@ def test_method_get_trace(self, client: LlamaStack) -> None: ) assert_matches_type(TelemetryGetTraceResponse, telemetry, path=["response"]) + @parametrize + def test_method_get_trace_with_all_params(self, client: LlamaStack) -> None: + telemetry 
= client.telemetry.get_trace( + trace_id="trace_id", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(TelemetryGetTraceResponse, telemetry, path=["response"]) + @parametrize def test_raw_response_get_trace(self, client: LlamaStack) -> None: response = client.telemetry.with_raw_response.get_trace( @@ -75,6 +83,7 @@ def test_method_log_with_all_params(self, client: LlamaStack) -> None: "type": "unstructured_log", "attributes": {"foo": True}, }, + x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) assert telemetry is None @@ -127,6 +136,14 @@ async def test_method_get_trace(self, async_client: AsyncLlamaStack) -> None: ) assert_matches_type(TelemetryGetTraceResponse, telemetry, path=["response"]) + @parametrize + async def test_method_get_trace_with_all_params(self, async_client: AsyncLlamaStack) -> None: + telemetry = await async_client.telemetry.get_trace( + trace_id="trace_id", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(TelemetryGetTraceResponse, telemetry, path=["response"]) + @parametrize async def test_raw_response_get_trace(self, async_client: AsyncLlamaStack) -> None: response = await async_client.telemetry.with_raw_response.get_trace( @@ -177,6 +194,7 @@ async def test_method_log_with_all_params(self, async_client: AsyncLlamaStack) - "type": "unstructured_log", "attributes": {"foo": True}, }, + x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) assert telemetry is None diff --git a/tests/test_client.py b/tests/test_client.py index ea89ceb..adbd9ad 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -720,6 +720,7 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: response = client.agents.sessions.with_raw_response.create(agent_id="agent_id", session_name="session_name") assert response.retries_taken == failures_before_success + assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success class TestAsyncLlamaStack: @@ -1405,3 +1406,4 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: ) assert response.retries_taken == failures_before_success + assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
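
---

Note on the surface exercised above: safety's run_shields becomes run_shield, selecting a shield by shield_type plus a free-form params dict instead of an inline shields list, and the new shields resource exposes list/get. A minimal usage sketch under those assumptions — base_url and the "llama_guard" shield_type are illustrative placeholders, not values this patch defines:

    # Sketch of the reworked safety/shields surface exercised by the tests
    # above; assumes a reachable Llama Stack server and an illustrative
    # "llama_guard" shield_type that the patch itself does not register.
    from llama_stack import LlamaStack

    client = LlamaStack(base_url="http://127.0.0.1:4010")

    # run_shield replaces run_shields: the shield is named via shield_type
    # and configured via params, rather than passed as an inline list of
    # {"on_violation_action", "shield_type"} dicts.
    safety = client.safety.run_shield(
        messages=[{"content": "hello", "role": "user"}],
        params={"foo": True},
        shield_type="llama_guard",
    )

    # The new shields resource: list all specs, or fetch one by type.
    specs = client.shields.list()
    maybe_spec = client.shields.get(shield_type="llama_guard")  # Optional[ShieldSpec]

The retry tests in test_client.py additionally assert the x-stainless-retry-count request header, which should match the number of failed attempts before the request finally succeeds.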