diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d201bfb..89a70a8 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -31,13 +31,13 @@ $ pip install -r requirements-dev.lock ## Modifying/Adding code -Most of the SDK is generated code, and any modified code will be overridden on the next generation. The -`src/llama_stack/lib/` and `examples/` directories are exceptions and will never be overridden. +Most of the SDK is generated code. Modifications to code will be persisted between generations, but may +result in merge conflicts between manual patches and changes from the generator. The generator will never +modify the contents of the `src/llama_stack/lib/` and `examples/` directories. ## Adding and running examples -All files in the `examples/` directory are not modified by the Stainless generator and can be freely edited or -added to. +All files in the `examples/` directory are not modified by the generator and can be freely edited or added to. ```bash # add an example to examples/.py diff --git a/README.md b/README.md index 2efd1fd..59d7355 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,8 @@ The REST API documentation can be found on [docs.llama-stack.todo](https://docs. ## Installation ```sh -pip install llama-stack +# install from this staging repo +pip install llama-stack-client ``` ## Usage @@ -30,13 +31,11 @@ client = LlamaStack( environment="sandbox", ) -agentic_system_create_response = client.agentic_system.create( - agent_config={ - "instructions": "instructions", - "model": "model", - }, +session = client.agentic_system.sessions.create( + agent_id="agent_id", + session_name="session_name", ) -print(agentic_system_create_response.agent_id) +print(session.session_id) ``` ## Async usage @@ -54,13 +53,11 @@ client = AsyncLlamaStack( async def main() -> None: - agentic_system_create_response = await client.agentic_system.create( - agent_config={ - "instructions": "instructions", - "model": "model", - }, + session = await client.agentic_system.sessions.create( + agent_id="agent_id", + session_name="session_name", ) - print(agentic_system_create_response.agent_id) + print(session.session_id) asyncio.run(main()) @@ -93,11 +90,9 @@ from llama_stack import LlamaStack client = LlamaStack() try: - client.agentic_system.create( - agent_config={ - "instructions": "instructions", - "model": "model", - }, + client.agentic_system.sessions.create( + agent_id="agent_id", + session_name="session_name", ) except llama_stack.APIConnectionError as e: print("The server could not be reached") @@ -141,11 +136,9 @@ client = LlamaStack( ) # Or, configure per-request: -client.with_options(max_retries=5).agentic_system.create( - agent_config={ - "instructions": "instructions", - "model": "model", - }, +client.with_options(max_retries=5).agentic_system.sessions.create( + agent_id="agent_id", + session_name="session_name", ) ``` @@ -169,11 +162,9 @@ client = LlamaStack( ) # Override per-request: -client.with_options(timeout=5.0).agentic_system.create( - agent_config={ - "instructions": "instructions", - "model": "model", - }, +client.with_options(timeout=5.0).agentic_system.sessions.create( + agent_id="agent_id", + session_name="session_name", ) ``` @@ -213,16 +204,14 @@ The "raw" Response object can be accessed by prefixing `.with_raw_response.` to from llama_stack import LlamaStack client = LlamaStack() -response = client.agentic_system.with_raw_response.create( - agent_config={ - "instructions": "instructions", - "model": "model", - }, +response = 
client.agentic_system.sessions.with_raw_response.create( + agent_id="agent_id", + session_name="session_name", ) print(response.headers.get('X-My-Header')) -agentic_system = response.parse() # get the object that `agentic_system.create()` would have returned -print(agentic_system.agent_id) +session = response.parse() # get the object that `agentic_system.sessions.create()` would have returned +print(session.session_id) ``` These methods return an [`APIResponse`](https://github.com/stainless-sdks/llama-stack-python/tree/main/src/llama_stack/_response.py) object. @@ -236,11 +225,9 @@ The above interface eagerly reads the full response body when you make the reque To stream the response body, use `.with_streaming_response` instead, which requires a context manager and only reads the response body once you call `.read()`, `.text()`, `.json()`, `.iter_bytes()`, `.iter_text()`, `.iter_lines()` or `.parse()`. In the async client, these are async methods. ```python -with client.agentic_system.with_streaming_response.create( - agent_config={ - "instructions": "instructions", - "model": "model", - }, +with client.agentic_system.sessions.with_streaming_response.create( + agent_id="agent_id", + session_name="session_name", ) as response: print(response.headers.get("X-My-Header")) diff --git a/api.md b/api.md index 9ee4c6a..6ca4424 100644 --- a/api.md +++ b/api.md @@ -2,11 +2,9 @@ ```python from llama_stack.types import ( - Artifact, Attachment, BatchCompletion, CompletionMessage, - Run, SamplingParams, SystemMessage, ToolCall, @@ -15,13 +13,29 @@ from llama_stack.types import ( ) ``` +# Telemetry + +Types: + +```python +from llama_stack.types import TelemetryGetTraceResponse +``` + +Methods: + +- client.telemetry.get_trace(\*\*params) -> TelemetryGetTraceResponse +- client.telemetry.log(\*\*params) -> None + # AgenticSystem Types: ```python from llama_stack.types import ( + CustomQueryGeneratorConfig, + DefaultQueryGeneratorConfig, InferenceStep, + LlmQueryGeneratorConfig, MemoryRetrievalStep, RestAPIExecutionConfig, ShieldCallStep, @@ -68,7 +82,7 @@ Methods: Types: ```python -from llama_stack.types.agentic_system import AgenticSystemTurnStreamChunk, Turn +from llama_stack.types.agentic_system import AgenticSystemTurnStreamChunk, Turn, TurnStreamEvent ``` Methods: @@ -76,12 +90,6 @@ Methods: - client.agentic_system.turns.create(\*\*params) -> AgenticSystemTurnStreamChunk - client.agentic_system.turns.retrieve(\*\*params) -> Turn -# Artifacts - -Methods: - -- client.artifacts.get(\*\*params) -> Artifact - # Datasets Types: @@ -152,41 +160,24 @@ Methods: - client.evaluations.summarization(\*\*params) -> EvaluationJob - client.evaluations.text_generation(\*\*params) -> EvaluationJob -# Experiments - -Types: - -```python -from llama_stack.types import Experiment -``` - -Methods: - -- client.experiments.create(\*\*params) -> Experiment -- client.experiments.retrieve(\*\*params) -> Experiment -- client.experiments.update(\*\*params) -> Experiment -- client.experiments.list() -> Experiment -- client.experiments.create_run(\*\*params) -> Run - -## Artifacts - -Methods: - -- client.experiments.artifacts.retrieve(\*\*params) -> Artifact -- client.experiments.artifacts.upload(\*\*params) -> Artifact - # Inference Types: ```python -from llama_stack.types import ChatCompletionStreamChunk, CompletionStreamChunk +from llama_stack.types import ( + ChatCompletionStreamChunk, + CompletionStreamChunk, + TokenLogProbs, + InferenceChatCompletionResponse, + InferenceCompletionResponse, +) ``` Methods: -- 
client.inference.chat_completion(\*\*params) -> ChatCompletionStreamChunk -- client.inference.completion(\*\*params) -> CompletionStreamChunk +- client.inference.chat_completion(\*\*params) -> InferenceChatCompletionResponse +- client.inference.completion(\*\*params) -> InferenceCompletionResponse ## Embeddings @@ -200,19 +191,6 @@ Methods: - client.inference.embeddings.create(\*\*params) -> Embeddings -# Logging - -Types: - -```python -from llama_stack.types import LoggingGetLogsResponse -``` - -Methods: - -- client.logging.get_logs(\*\*params) -> LoggingGetLogsResponse -- client.logging.log_messages(\*\*params) -> None - # Safety Types: @@ -307,25 +285,6 @@ Methods: - client.reward_scoring.score(\*\*params) -> RewardScoring -# Runs - -Methods: - -- client.runs.update(\*\*params) -> Run -- client.runs.log_metrics(\*\*params) -> None - -## Metrics - -Types: - -```python -from llama_stack.types.runs import MetricListResponse -``` - -Methods: - -- client.runs.metrics.list(\*\*params) -> MetricListResponse - # SyntheticDataGeneration Types: diff --git a/pyproject.toml b/pyproject.toml index 06c1ced..9a15fbd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] -name = "llama_stack" -version = "0.0.1-alpha.4" +name = "llama_stack_client" +version = "0.0.1-alpha.0" description = "The official Python library for the llama-stack API" dynamic = ["readme"] license = "Apache-2.0" diff --git a/requirements-dev.lock b/requirements-dev.lock index 327ad80..c8815cf 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -49,7 +49,7 @@ markdown-it-py==3.0.0 # via rich mdurl==0.1.2 # via markdown-it-py -mypy==1.10.1 +mypy==1.11.2 mypy-extensions==1.0.0 # via mypy nodeenv==1.8.0 @@ -70,7 +70,7 @@ pydantic-core==2.18.2 # via pydantic pygments==2.18.0 # via rich -pyright==1.1.374 +pyright==1.1.380 pytest==7.1.1 # via pytest-asyncio pytest-asyncio==0.21.1 @@ -80,7 +80,7 @@ pytz==2023.3.post1 # via dirty-equals respx==0.20.2 rich==13.7.1 -ruff==0.5.6 +ruff==0.6.5 setuptools==68.2.2 # via nodeenv six==1.16.0 diff --git a/src/llama_stack/_client.py b/src/llama_stack/_client.py index 9af6d5d..72fd4a9 100644 --- a/src/llama_stack/_client.py +++ b/src/llama_stack/_client.py @@ -52,19 +52,16 @@ class LlamaStack(SyncAPIClient): + telemetry: resources.TelemetryResource agentic_system: resources.AgenticSystemResource - artifacts: resources.ArtifactsResource datasets: resources.DatasetsResource evaluate: resources.EvaluateResource evaluations: resources.EvaluationsResource - experiments: resources.ExperimentsResource inference: resources.InferenceResource - logging: resources.LoggingResource safety: resources.SafetyResource memory_banks: resources.MemoryBanksResource post_training: resources.PostTrainingResource reward_scoring: resources.RewardScoringResource - runs: resources.RunsResource synthetic_data_generation: resources.SyntheticDataGenerationResource batch_inference: resources.BatchInferenceResource with_raw_response: LlamaStackWithRawResponse @@ -135,19 +132,16 @@ def __init__( _strict_response_validation=_strict_response_validation, ) + self.telemetry = resources.TelemetryResource(self) self.agentic_system = resources.AgenticSystemResource(self) - self.artifacts = resources.ArtifactsResource(self) self.datasets = resources.DatasetsResource(self) self.evaluate = resources.EvaluateResource(self) self.evaluations = resources.EvaluationsResource(self) - self.experiments = resources.ExperimentsResource(self) self.inference = resources.InferenceResource(self) - self.logging = 
resources.LoggingResource(self) self.safety = resources.SafetyResource(self) self.memory_banks = resources.MemoryBanksResource(self) self.post_training = resources.PostTrainingResource(self) self.reward_scoring = resources.RewardScoringResource(self) - self.runs = resources.RunsResource(self) self.synthetic_data_generation = resources.SyntheticDataGenerationResource(self) self.batch_inference = resources.BatchInferenceResource(self) self.with_raw_response = LlamaStackWithRawResponse(self) @@ -253,19 +247,16 @@ def _make_status_error( class AsyncLlamaStack(AsyncAPIClient): + telemetry: resources.AsyncTelemetryResource agentic_system: resources.AsyncAgenticSystemResource - artifacts: resources.AsyncArtifactsResource datasets: resources.AsyncDatasetsResource evaluate: resources.AsyncEvaluateResource evaluations: resources.AsyncEvaluationsResource - experiments: resources.AsyncExperimentsResource inference: resources.AsyncInferenceResource - logging: resources.AsyncLoggingResource safety: resources.AsyncSafetyResource memory_banks: resources.AsyncMemoryBanksResource post_training: resources.AsyncPostTrainingResource reward_scoring: resources.AsyncRewardScoringResource - runs: resources.AsyncRunsResource synthetic_data_generation: resources.AsyncSyntheticDataGenerationResource batch_inference: resources.AsyncBatchInferenceResource with_raw_response: AsyncLlamaStackWithRawResponse @@ -336,19 +327,16 @@ def __init__( _strict_response_validation=_strict_response_validation, ) + self.telemetry = resources.AsyncTelemetryResource(self) self.agentic_system = resources.AsyncAgenticSystemResource(self) - self.artifacts = resources.AsyncArtifactsResource(self) self.datasets = resources.AsyncDatasetsResource(self) self.evaluate = resources.AsyncEvaluateResource(self) self.evaluations = resources.AsyncEvaluationsResource(self) - self.experiments = resources.AsyncExperimentsResource(self) self.inference = resources.AsyncInferenceResource(self) - self.logging = resources.AsyncLoggingResource(self) self.safety = resources.AsyncSafetyResource(self) self.memory_banks = resources.AsyncMemoryBanksResource(self) self.post_training = resources.AsyncPostTrainingResource(self) self.reward_scoring = resources.AsyncRewardScoringResource(self) - self.runs = resources.AsyncRunsResource(self) self.synthetic_data_generation = resources.AsyncSyntheticDataGenerationResource(self) self.batch_inference = resources.AsyncBatchInferenceResource(self) self.with_raw_response = AsyncLlamaStackWithRawResponse(self) @@ -455,19 +443,16 @@ def _make_status_error( class LlamaStackWithRawResponse: def __init__(self, client: LlamaStack) -> None: + self.telemetry = resources.TelemetryResourceWithRawResponse(client.telemetry) self.agentic_system = resources.AgenticSystemResourceWithRawResponse(client.agentic_system) - self.artifacts = resources.ArtifactsResourceWithRawResponse(client.artifacts) self.datasets = resources.DatasetsResourceWithRawResponse(client.datasets) self.evaluate = resources.EvaluateResourceWithRawResponse(client.evaluate) self.evaluations = resources.EvaluationsResourceWithRawResponse(client.evaluations) - self.experiments = resources.ExperimentsResourceWithRawResponse(client.experiments) self.inference = resources.InferenceResourceWithRawResponse(client.inference) - self.logging = resources.LoggingResourceWithRawResponse(client.logging) self.safety = resources.SafetyResourceWithRawResponse(client.safety) self.memory_banks = resources.MemoryBanksResourceWithRawResponse(client.memory_banks) self.post_training = 
resources.PostTrainingResourceWithRawResponse(client.post_training) self.reward_scoring = resources.RewardScoringResourceWithRawResponse(client.reward_scoring) - self.runs = resources.RunsResourceWithRawResponse(client.runs) self.synthetic_data_generation = resources.SyntheticDataGenerationResourceWithRawResponse( client.synthetic_data_generation ) @@ -476,19 +461,16 @@ def __init__(self, client: LlamaStack) -> None: class AsyncLlamaStackWithRawResponse: def __init__(self, client: AsyncLlamaStack) -> None: + self.telemetry = resources.AsyncTelemetryResourceWithRawResponse(client.telemetry) self.agentic_system = resources.AsyncAgenticSystemResourceWithRawResponse(client.agentic_system) - self.artifacts = resources.AsyncArtifactsResourceWithRawResponse(client.artifacts) self.datasets = resources.AsyncDatasetsResourceWithRawResponse(client.datasets) self.evaluate = resources.AsyncEvaluateResourceWithRawResponse(client.evaluate) self.evaluations = resources.AsyncEvaluationsResourceWithRawResponse(client.evaluations) - self.experiments = resources.AsyncExperimentsResourceWithRawResponse(client.experiments) self.inference = resources.AsyncInferenceResourceWithRawResponse(client.inference) - self.logging = resources.AsyncLoggingResourceWithRawResponse(client.logging) self.safety = resources.AsyncSafetyResourceWithRawResponse(client.safety) self.memory_banks = resources.AsyncMemoryBanksResourceWithRawResponse(client.memory_banks) self.post_training = resources.AsyncPostTrainingResourceWithRawResponse(client.post_training) self.reward_scoring = resources.AsyncRewardScoringResourceWithRawResponse(client.reward_scoring) - self.runs = resources.AsyncRunsResourceWithRawResponse(client.runs) self.synthetic_data_generation = resources.AsyncSyntheticDataGenerationResourceWithRawResponse( client.synthetic_data_generation ) @@ -497,19 +479,16 @@ def __init__(self, client: AsyncLlamaStack) -> None: class LlamaStackWithStreamedResponse: def __init__(self, client: LlamaStack) -> None: + self.telemetry = resources.TelemetryResourceWithStreamingResponse(client.telemetry) self.agentic_system = resources.AgenticSystemResourceWithStreamingResponse(client.agentic_system) - self.artifacts = resources.ArtifactsResourceWithStreamingResponse(client.artifacts) self.datasets = resources.DatasetsResourceWithStreamingResponse(client.datasets) self.evaluate = resources.EvaluateResourceWithStreamingResponse(client.evaluate) self.evaluations = resources.EvaluationsResourceWithStreamingResponse(client.evaluations) - self.experiments = resources.ExperimentsResourceWithStreamingResponse(client.experiments) self.inference = resources.InferenceResourceWithStreamingResponse(client.inference) - self.logging = resources.LoggingResourceWithStreamingResponse(client.logging) self.safety = resources.SafetyResourceWithStreamingResponse(client.safety) self.memory_banks = resources.MemoryBanksResourceWithStreamingResponse(client.memory_banks) self.post_training = resources.PostTrainingResourceWithStreamingResponse(client.post_training) self.reward_scoring = resources.RewardScoringResourceWithStreamingResponse(client.reward_scoring) - self.runs = resources.RunsResourceWithStreamingResponse(client.runs) self.synthetic_data_generation = resources.SyntheticDataGenerationResourceWithStreamingResponse( client.synthetic_data_generation ) @@ -518,19 +497,16 @@ def __init__(self, client: LlamaStack) -> None: class AsyncLlamaStackWithStreamedResponse: def __init__(self, client: AsyncLlamaStack) -> None: + self.telemetry = 
resources.AsyncTelemetryResourceWithStreamingResponse(client.telemetry) self.agentic_system = resources.AsyncAgenticSystemResourceWithStreamingResponse(client.agentic_system) - self.artifacts = resources.AsyncArtifactsResourceWithStreamingResponse(client.artifacts) self.datasets = resources.AsyncDatasetsResourceWithStreamingResponse(client.datasets) self.evaluate = resources.AsyncEvaluateResourceWithStreamingResponse(client.evaluate) self.evaluations = resources.AsyncEvaluationsResourceWithStreamingResponse(client.evaluations) - self.experiments = resources.AsyncExperimentsResourceWithStreamingResponse(client.experiments) self.inference = resources.AsyncInferenceResourceWithStreamingResponse(client.inference) - self.logging = resources.AsyncLoggingResourceWithStreamingResponse(client.logging) self.safety = resources.AsyncSafetyResourceWithStreamingResponse(client.safety) self.memory_banks = resources.AsyncMemoryBanksResourceWithStreamingResponse(client.memory_banks) self.post_training = resources.AsyncPostTrainingResourceWithStreamingResponse(client.post_training) self.reward_scoring = resources.AsyncRewardScoringResourceWithStreamingResponse(client.reward_scoring) - self.runs = resources.AsyncRunsResourceWithStreamingResponse(client.runs) self.synthetic_data_generation = resources.AsyncSyntheticDataGenerationResourceWithStreamingResponse( client.synthetic_data_generation ) diff --git a/src/llama_stack/_utils/_utils.py b/src/llama_stack/_utils/_utils.py index 2fc5a1c..0bba17c 100644 --- a/src/llama_stack/_utils/_utils.py +++ b/src/llama_stack/_utils/_utils.py @@ -363,12 +363,13 @@ def file_from_path(path: str) -> FileTypes: def get_required_header(headers: HeadersLike, header: str) -> str: lower_header = header.lower() - if isinstance(headers, Mapping): - for k, v in headers.items(): + if is_mapping_t(headers): + # mypy doesn't understand the type narrowing here + for k, v in headers.items(): # type: ignore if k.lower() == lower_header and isinstance(v, str): return v - """ to deal with the case where the header looks like Stainless-Event-Id """ + # to deal with the case where the header looks like Stainless-Event-Id intercaps_header = re.sub(r"([^\w])(\w)", lambda pat: pat.group(1) + pat.group(2).upper(), header.capitalize()) for normalized_header in [header, lower_header, header.upper(), intercaps_header]: diff --git a/src/llama_stack/resources/__init__.py b/src/llama_stack/resources/__init__.py index 5d3fccb..a9a971f 100644 --- a/src/llama_stack/resources/__init__.py +++ b/src/llama_stack/resources/__init__.py @@ -1,13 +1,5 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
-from .runs import ( - RunsResource, - AsyncRunsResource, - RunsResourceWithRawResponse, - AsyncRunsResourceWithRawResponse, - RunsResourceWithStreamingResponse, - AsyncRunsResourceWithStreamingResponse, -) from .safety import ( SafetyResource, AsyncSafetyResource, @@ -16,14 +8,6 @@ SafetyResourceWithStreamingResponse, AsyncSafetyResourceWithStreamingResponse, ) -from .logging import ( - LoggingResource, - AsyncLoggingResource, - LoggingResourceWithRawResponse, - AsyncLoggingResourceWithRawResponse, - LoggingResourceWithStreamingResponse, - AsyncLoggingResourceWithStreamingResponse, -) from .datasets import ( DatasetsResource, AsyncDatasetsResource, @@ -40,14 +24,6 @@ EvaluateResourceWithStreamingResponse, AsyncEvaluateResourceWithStreamingResponse, ) -from .artifacts import ( - ArtifactsResource, - AsyncArtifactsResource, - ArtifactsResourceWithRawResponse, - AsyncArtifactsResourceWithRawResponse, - ArtifactsResourceWithStreamingResponse, - AsyncArtifactsResourceWithStreamingResponse, -) from .inference import ( InferenceResource, AsyncInferenceResource, @@ -56,6 +32,14 @@ InferenceResourceWithStreamingResponse, AsyncInferenceResourceWithStreamingResponse, ) +from .telemetry import ( + TelemetryResource, + AsyncTelemetryResource, + TelemetryResourceWithRawResponse, + AsyncTelemetryResourceWithRawResponse, + TelemetryResourceWithStreamingResponse, + AsyncTelemetryResourceWithStreamingResponse, +) from .evaluations import ( EvaluationsResource, AsyncEvaluationsResource, @@ -64,14 +48,6 @@ EvaluationsResourceWithStreamingResponse, AsyncEvaluationsResourceWithStreamingResponse, ) -from .experiments import ( - ExperimentsResource, - AsyncExperimentsResource, - ExperimentsResourceWithRawResponse, - AsyncExperimentsResourceWithRawResponse, - ExperimentsResourceWithStreamingResponse, - AsyncExperimentsResourceWithStreamingResponse, -) from .memory_banks import ( MemoryBanksResource, AsyncMemoryBanksResource, @@ -122,18 +98,18 @@ ) __all__ = [ + "TelemetryResource", + "AsyncTelemetryResource", + "TelemetryResourceWithRawResponse", + "AsyncTelemetryResourceWithRawResponse", + "TelemetryResourceWithStreamingResponse", + "AsyncTelemetryResourceWithStreamingResponse", "AgenticSystemResource", "AsyncAgenticSystemResource", "AgenticSystemResourceWithRawResponse", "AsyncAgenticSystemResourceWithRawResponse", "AgenticSystemResourceWithStreamingResponse", "AsyncAgenticSystemResourceWithStreamingResponse", - "ArtifactsResource", - "AsyncArtifactsResource", - "ArtifactsResourceWithRawResponse", - "AsyncArtifactsResourceWithRawResponse", - "ArtifactsResourceWithStreamingResponse", - "AsyncArtifactsResourceWithStreamingResponse", "DatasetsResource", "AsyncDatasetsResource", "DatasetsResourceWithRawResponse", @@ -152,24 +128,12 @@ "AsyncEvaluationsResourceWithRawResponse", "EvaluationsResourceWithStreamingResponse", "AsyncEvaluationsResourceWithStreamingResponse", - "ExperimentsResource", - "AsyncExperimentsResource", - "ExperimentsResourceWithRawResponse", - "AsyncExperimentsResourceWithRawResponse", - "ExperimentsResourceWithStreamingResponse", - "AsyncExperimentsResourceWithStreamingResponse", "InferenceResource", "AsyncInferenceResource", "InferenceResourceWithRawResponse", "AsyncInferenceResourceWithRawResponse", "InferenceResourceWithStreamingResponse", "AsyncInferenceResourceWithStreamingResponse", - "LoggingResource", - "AsyncLoggingResource", - "LoggingResourceWithRawResponse", - "AsyncLoggingResourceWithRawResponse", - "LoggingResourceWithStreamingResponse", - 
"AsyncLoggingResourceWithStreamingResponse", "SafetyResource", "AsyncSafetyResource", "SafetyResourceWithRawResponse", @@ -194,12 +158,6 @@ "AsyncRewardScoringResourceWithRawResponse", "RewardScoringResourceWithStreamingResponse", "AsyncRewardScoringResourceWithStreamingResponse", - "RunsResource", - "AsyncRunsResource", - "RunsResourceWithRawResponse", - "AsyncRunsResourceWithRawResponse", - "RunsResourceWithStreamingResponse", - "AsyncRunsResourceWithStreamingResponse", "SyntheticDataGenerationResource", "AsyncSyntheticDataGenerationResource", "SyntheticDataGenerationResourceWithRawResponse", diff --git a/src/llama_stack/resources/agentic_system/turns.py b/src/llama_stack/resources/agentic_system/turns.py index f788383..9df70f7 100644 --- a/src/llama_stack/resources/agentic_system/turns.py +++ b/src/llama_stack/resources/agentic_system/turns.py @@ -2,10 +2,14 @@ from __future__ import annotations +from typing import Iterable, overload +from typing_extensions import Literal + import httpx from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven from ..._utils import ( + required_args, maybe_transform, async_maybe_transform, ) @@ -17,9 +21,11 @@ async_to_raw_response_wrapper, async_to_streamed_response_wrapper, ) +from ..._streaming import Stream, AsyncStream from ..._base_client import make_request_options from ...types.agentic_system import turn_create_params, turn_retrieve_params from ...types.agentic_system.turn import Turn +from ...types.shared_params.attachment import Attachment from ...types.agentic_system.agentic_system_turn_stream_chunk import AgenticSystemTurnStreamChunk __all__ = ["TurnsResource", "AsyncTurnsResource"] @@ -45,10 +51,15 @@ def with_streaming_response(self) -> TurnsResourceWithStreamingResponse: """ return TurnsResourceWithStreamingResponse(self) + @overload def create( self, *, - request: turn_create_params.Request, + agent_id: str, + messages: Iterable[turn_create_params.Message], + session_id: str, + attachments: Iterable[Attachment] | NotGiven = NOT_GIVEN, + stream: Literal[False] | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -66,14 +77,99 @@ def create( timeout: Override the client-level default timeout for this request, in seconds """ + ... + + @overload + def create( + self, + *, + agent_id: str, + messages: Iterable[turn_create_params.Message], + session_id: str, + stream: Literal[True], + attachments: Iterable[Attachment] | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Stream[AgenticSystemTurnStreamChunk]: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... 
+ + @overload + def create( + self, + *, + agent_id: str, + messages: Iterable[turn_create_params.Message], + session_id: str, + stream: bool, + attachments: Iterable[Attachment] | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> AgenticSystemTurnStreamChunk | Stream[AgenticSystemTurnStreamChunk]: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @required_args(["agent_id", "messages", "session_id"], ["agent_id", "messages", "session_id", "stream"]) + def create( + self, + *, + agent_id: str, + messages: Iterable[turn_create_params.Message], + session_id: str, + attachments: Iterable[Attachment] | NotGiven = NOT_GIVEN, + stream: Literal[False] | Literal[True] | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> AgenticSystemTurnStreamChunk | Stream[AgenticSystemTurnStreamChunk]: extra_headers = {"Accept": "text/event-stream", **(extra_headers or {})} return self._post( "/agentic_system/turn/create", - body=maybe_transform({"request": request}, turn_create_params.TurnCreateParams), + body=maybe_transform( + { + "agent_id": agent_id, + "messages": messages, + "session_id": session_id, + "attachments": attachments, + "stream": stream, + }, + turn_create_params.TurnCreateParams, + ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), cast_to=AgenticSystemTurnStreamChunk, + stream=stream or False, + stream_cls=Stream[AgenticSystemTurnStreamChunk], ) def retrieve( @@ -137,10 +233,15 @@ def with_streaming_response(self) -> AsyncTurnsResourceWithStreamingResponse: """ return AsyncTurnsResourceWithStreamingResponse(self) + @overload async def create( self, *, - request: turn_create_params.Request, + agent_id: str, + messages: Iterable[turn_create_params.Message], + session_id: str, + attachments: Iterable[Attachment] | NotGiven = NOT_GIVEN, + stream: Literal[False] | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -158,14 +259,99 @@ async def create( timeout: Override the client-level default timeout for this request, in seconds """ + ... 
+ + @overload + async def create( + self, + *, + agent_id: str, + messages: Iterable[turn_create_params.Message], + session_id: str, + stream: Literal[True], + attachments: Iterable[Attachment] | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> AsyncStream[AgenticSystemTurnStreamChunk]: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @overload + async def create( + self, + *, + agent_id: str, + messages: Iterable[turn_create_params.Message], + session_id: str, + stream: bool, + attachments: Iterable[Attachment] | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> AgenticSystemTurnStreamChunk | AsyncStream[AgenticSystemTurnStreamChunk]: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @required_args(["agent_id", "messages", "session_id"], ["agent_id", "messages", "session_id", "stream"]) + async def create( + self, + *, + agent_id: str, + messages: Iterable[turn_create_params.Message], + session_id: str, + attachments: Iterable[Attachment] | NotGiven = NOT_GIVEN, + stream: Literal[False] | Literal[True] | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. 
+ extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> AgenticSystemTurnStreamChunk | AsyncStream[AgenticSystemTurnStreamChunk]: extra_headers = {"Accept": "text/event-stream", **(extra_headers or {})} return await self._post( "/agentic_system/turn/create", - body=await async_maybe_transform({"request": request}, turn_create_params.TurnCreateParams), + body=await async_maybe_transform( + { + "agent_id": agent_id, + "messages": messages, + "session_id": session_id, + "attachments": attachments, + "stream": stream, + }, + turn_create_params.TurnCreateParams, + ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), cast_to=AgenticSystemTurnStreamChunk, + stream=stream or False, + stream_cls=AsyncStream[AgenticSystemTurnStreamChunk], ) async def retrieve( diff --git a/src/llama_stack/resources/artifacts.py b/src/llama_stack/resources/artifacts.py deleted file mode 100644 index 3737238..0000000 --- a/src/llama_stack/resources/artifacts.py +++ /dev/null @@ -1,168 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -from __future__ import annotations - -import httpx - -from ..types import artifact_get_params -from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven -from .._utils import ( - maybe_transform, - async_maybe_transform, -) -from .._compat import cached_property -from .._resource import SyncAPIResource, AsyncAPIResource -from .._response import ( - to_raw_response_wrapper, - to_streamed_response_wrapper, - async_to_raw_response_wrapper, - async_to_streamed_response_wrapper, -) -from .._base_client import make_request_options -from ..types.shared.artifact import Artifact - -__all__ = ["ArtifactsResource", "AsyncArtifactsResource"] - - -class ArtifactsResource(SyncAPIResource): - @cached_property - def with_raw_response(self) -> ArtifactsResourceWithRawResponse: - """ - This property can be used as a prefix for any HTTP method call to return the - the raw response object instead of the parsed content. - - For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers - """ - return ArtifactsResourceWithRawResponse(self) - - @cached_property - def with_streaming_response(self) -> ArtifactsResourceWithStreamingResponse: - """ - An alternative to `.with_raw_response` that doesn't eagerly read the response body. - - For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response - """ - return ArtifactsResourceWithStreamingResponse(self) - - def get( - self, - *, - artifact_id: str, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. 
- extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> Artifact: - """ - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - return self._get( - "/artifacts/get", - options=make_request_options( - extra_headers=extra_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, - query=maybe_transform({"artifact_id": artifact_id}, artifact_get_params.ArtifactGetParams), - ), - cast_to=Artifact, - ) - - -class AsyncArtifactsResource(AsyncAPIResource): - @cached_property - def with_raw_response(self) -> AsyncArtifactsResourceWithRawResponse: - """ - This property can be used as a prefix for any HTTP method call to return the - the raw response object instead of the parsed content. - - For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers - """ - return AsyncArtifactsResourceWithRawResponse(self) - - @cached_property - def with_streaming_response(self) -> AsyncArtifactsResourceWithStreamingResponse: - """ - An alternative to `.with_raw_response` that doesn't eagerly read the response body. - - For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response - """ - return AsyncArtifactsResourceWithStreamingResponse(self) - - async def get( - self, - *, - artifact_id: str, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. 
- extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> Artifact: - """ - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - return await self._get( - "/artifacts/get", - options=make_request_options( - extra_headers=extra_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, - query=await async_maybe_transform({"artifact_id": artifact_id}, artifact_get_params.ArtifactGetParams), - ), - cast_to=Artifact, - ) - - -class ArtifactsResourceWithRawResponse: - def __init__(self, artifacts: ArtifactsResource) -> None: - self._artifacts = artifacts - - self.get = to_raw_response_wrapper( - artifacts.get, - ) - - -class AsyncArtifactsResourceWithRawResponse: - def __init__(self, artifacts: AsyncArtifactsResource) -> None: - self._artifacts = artifacts - - self.get = async_to_raw_response_wrapper( - artifacts.get, - ) - - -class ArtifactsResourceWithStreamingResponse: - def __init__(self, artifacts: ArtifactsResource) -> None: - self._artifacts = artifacts - - self.get = to_streamed_response_wrapper( - artifacts.get, - ) - - -class AsyncArtifactsResourceWithStreamingResponse: - def __init__(self, artifacts: AsyncArtifactsResource) -> None: - self._artifacts = artifacts - - self.get = async_to_streamed_response_wrapper( - artifacts.get, - ) diff --git a/src/llama_stack/resources/batch_inference.py b/src/llama_stack/resources/batch_inference.py index bfa8e25..36028b3 100644 --- a/src/llama_stack/resources/batch_inference.py +++ b/src/llama_stack/resources/batch_inference.py @@ -2,6 +2,9 @@ from __future__ import annotations +from typing import List, Union, Iterable +from typing_extensions import Literal + import httpx from ..types import batch_inference_completion_params, batch_inference_chat_completion_params @@ -21,6 +24,7 @@ from .._base_client import make_request_options from ..types.batch_chat_completion import BatchChatCompletion from ..types.shared.batch_completion import BatchCompletion +from ..types.shared_params.sampling_params import SamplingParams __all__ = ["BatchInferenceResource", "AsyncBatchInferenceResource"] @@ -48,7 +52,13 @@ def with_streaming_response(self) -> BatchInferenceResourceWithStreamingResponse def chat_completion( self, *, - request: batch_inference_chat_completion_params.Request, + messages_batch: Iterable[Iterable[batch_inference_chat_completion_params.MessagesBatch]], + model: str, + logprobs: batch_inference_chat_completion_params.Logprobs | NotGiven = NOT_GIVEN, + sampling_params: SamplingParams | NotGiven = NOT_GIVEN, + tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN, + tool_prompt_format: Literal["json", "function_tag"] | NotGiven = NOT_GIVEN, + tools: Iterable[batch_inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -58,6 +68,16 @@ def chat_completion( ) -> BatchChatCompletion: """ Args: + tool_prompt_format: `json` -- Refers to the json format for calling tools. 
The json format takes the + form like { "type": "function", "function" : { "name": "function_name", + "description": "function_description", "parameters": {...} } } + + `function_tag` -- This is an example of how you could define your own user + defined format for making tool calls. The function_tag format looks like this, + (parameters) + + The detailed prompts for each of these formats are added to llama cli + extra_headers: Send extra headers extra_query: Add additional query parameters to the request @@ -69,7 +89,16 @@ def chat_completion( return self._post( "/batch_inference/chat_completion", body=maybe_transform( - {"request": request}, batch_inference_chat_completion_params.BatchInferenceChatCompletionParams + { + "messages_batch": messages_batch, + "model": model, + "logprobs": logprobs, + "sampling_params": sampling_params, + "tool_choice": tool_choice, + "tool_prompt_format": tool_prompt_format, + "tools": tools, + }, + batch_inference_chat_completion_params.BatchInferenceChatCompletionParams, ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout @@ -80,7 +109,10 @@ def chat_completion( def completion( self, *, - request: batch_inference_completion_params.Request, + content_batch: List[Union[str, List[str]]], + model: str, + logprobs: batch_inference_completion_params.Logprobs | NotGiven = NOT_GIVEN, + sampling_params: SamplingParams | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -101,7 +133,13 @@ def completion( return self._post( "/batch_inference/completion", body=maybe_transform( - {"request": request}, batch_inference_completion_params.BatchInferenceCompletionParams + { + "content_batch": content_batch, + "model": model, + "logprobs": logprobs, + "sampling_params": sampling_params, + }, + batch_inference_completion_params.BatchInferenceCompletionParams, ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout @@ -133,7 +171,13 @@ def with_streaming_response(self) -> AsyncBatchInferenceResourceWithStreamingRes async def chat_completion( self, *, - request: batch_inference_chat_completion_params.Request, + messages_batch: Iterable[Iterable[batch_inference_chat_completion_params.MessagesBatch]], + model: str, + logprobs: batch_inference_chat_completion_params.Logprobs | NotGiven = NOT_GIVEN, + sampling_params: SamplingParams | NotGiven = NOT_GIVEN, + tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN, + tool_prompt_format: Literal["json", "function_tag"] | NotGiven = NOT_GIVEN, + tools: Iterable[batch_inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -143,6 +187,16 @@ async def chat_completion( ) -> BatchChatCompletion: """ Args: + tool_prompt_format: `json` -- Refers to the json format for calling tools. 
The json format takes the + form like { "type": "function", "function" : { "name": "function_name", + "description": "function_description", "parameters": {...} } } + + `function_tag` -- This is an example of how you could define your own user + defined format for making tool calls. The function_tag format looks like this, + (parameters) + + The detailed prompts for each of these formats are added to llama cli + extra_headers: Send extra headers extra_query: Add additional query parameters to the request @@ -154,7 +208,16 @@ async def chat_completion( return await self._post( "/batch_inference/chat_completion", body=await async_maybe_transform( - {"request": request}, batch_inference_chat_completion_params.BatchInferenceChatCompletionParams + { + "messages_batch": messages_batch, + "model": model, + "logprobs": logprobs, + "sampling_params": sampling_params, + "tool_choice": tool_choice, + "tool_prompt_format": tool_prompt_format, + "tools": tools, + }, + batch_inference_chat_completion_params.BatchInferenceChatCompletionParams, ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout @@ -165,7 +228,10 @@ async def chat_completion( async def completion( self, *, - request: batch_inference_completion_params.Request, + content_batch: List[Union[str, List[str]]], + model: str, + logprobs: batch_inference_completion_params.Logprobs | NotGiven = NOT_GIVEN, + sampling_params: SamplingParams | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -186,7 +252,13 @@ async def completion( return await self._post( "/batch_inference/completion", body=await async_maybe_transform( - {"request": request}, batch_inference_completion_params.BatchInferenceCompletionParams + { + "content_batch": content_batch, + "model": model, + "logprobs": logprobs, + "sampling_params": sampling_params, + }, + batch_inference_completion_params.BatchInferenceCompletionParams, ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout diff --git a/src/llama_stack/resources/datasets.py b/src/llama_stack/resources/datasets.py index 2cc2421..321e301 100644 --- a/src/llama_stack/resources/datasets.py +++ b/src/llama_stack/resources/datasets.py @@ -4,7 +4,7 @@ import httpx -from ..types import dataset_get_params, dataset_create_params, dataset_delete_params +from ..types import TrainEvalDataset, dataset_get_params, dataset_create_params, dataset_delete_params from .._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven from .._utils import ( maybe_transform, @@ -20,6 +20,7 @@ ) from .._base_client import make_request_options from ..types.train_eval_dataset import TrainEvalDataset +from ..types.train_eval_dataset_param import TrainEvalDatasetParam __all__ = ["DatasetsResource", "AsyncDatasetsResource"] @@ -47,7 +48,8 @@ def with_streaming_response(self) -> DatasetsResourceWithStreamingResponse: def create( self, *, - request: dataset_create_params.Request, + dataset: TrainEvalDatasetParam, + uuid: str, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, @@ -68,7 +70,13 @@ def create( extra_headers = {"Accept": "*/*", **(extra_headers or {})} return self._post( "/datasets/create", - body=maybe_transform({"request": request}, dataset_create_params.DatasetCreateParams), + body=maybe_transform( + { + "dataset": dataset, + "uuid": uuid, + }, + dataset_create_params.DatasetCreateParams, + ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -163,7 +171,8 @@ def with_streaming_response(self) -> AsyncDatasetsResourceWithStreamingResponse: async def create( self, *, - request: dataset_create_params.Request, + dataset: TrainEvalDatasetParam, + uuid: str, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -184,7 +193,13 @@ async def create( extra_headers = {"Accept": "*/*", **(extra_headers or {})} return await self._post( "/datasets/create", - body=await async_maybe_transform({"request": request}, dataset_create_params.DatasetCreateParams), + body=await async_maybe_transform( + { + "dataset": dataset, + "uuid": uuid, + }, + dataset_create_params.DatasetCreateParams, + ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), diff --git a/src/llama_stack/resources/evaluate/jobs/__init__.py b/src/llama_stack/resources/evaluate/jobs/__init__.py old mode 100755 new mode 100644 diff --git a/src/llama_stack/resources/evaluate/jobs/artifacts.py b/src/llama_stack/resources/evaluate/jobs/artifacts.py old mode 100755 new mode 100644 diff --git a/src/llama_stack/resources/evaluate/jobs/jobs.py b/src/llama_stack/resources/evaluate/jobs/jobs.py old mode 100755 new mode 100644 diff --git a/src/llama_stack/resources/evaluate/jobs/logs.py b/src/llama_stack/resources/evaluate/jobs/logs.py old mode 100755 new mode 100644 diff --git a/src/llama_stack/resources/evaluate/jobs/status.py b/src/llama_stack/resources/evaluate/jobs/status.py old mode 100755 new mode 100644 diff --git a/src/llama_stack/resources/evaluate/question_answering.py b/src/llama_stack/resources/evaluate/question_answering.py index 162072a..ca5169f 100644 --- a/src/llama_stack/resources/evaluate/question_answering.py +++ b/src/llama_stack/resources/evaluate/question_answering.py @@ -2,6 +2,9 @@ from __future__ import annotations +from typing import List +from typing_extensions import Literal + import httpx from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven @@ -47,7 +50,7 @@ def with_streaming_response(self) -> QuestionAnsweringResourceWithStreamingRespo def create( self, *, - request: question_answering_create_params.Request, + metrics: List[Literal["em", "f1"]], # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, @@ -67,7 +70,7 @@ def create( """ return self._post( "/evaluate/question_answering/", - body=maybe_transform({"request": request}, question_answering_create_params.QuestionAnsweringCreateParams), + body=maybe_transform({"metrics": metrics}, question_answering_create_params.QuestionAnsweringCreateParams), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -98,7 +101,7 @@ def with_streaming_response(self) -> AsyncQuestionAnsweringResourceWithStreaming async def create( self, *, - request: question_answering_create_params.Request, + metrics: List[Literal["em", "f1"]], # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -119,7 +122,7 @@ async def create( return await self._post( "/evaluate/question_answering/", body=await async_maybe_transform( - {"request": request}, question_answering_create_params.QuestionAnsweringCreateParams + {"metrics": metrics}, question_answering_create_params.QuestionAnsweringCreateParams ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout diff --git a/src/llama_stack/resources/evaluations.py b/src/llama_stack/resources/evaluations.py index cff5ce2..6328060 100644 --- a/src/llama_stack/resources/evaluations.py +++ b/src/llama_stack/resources/evaluations.py @@ -2,6 +2,9 @@ from __future__ import annotations +from typing import List +from typing_extensions import Literal + import httpx from ..types import evaluation_summarization_params, evaluation_text_generation_params @@ -47,7 +50,7 @@ def with_streaming_response(self) -> EvaluationsResourceWithStreamingResponse: def summarization( self, *, - request: evaluation_summarization_params.Request, + metrics: List[Literal["rouge", "bleu"]], # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -67,7 +70,7 @@ def summarization( """ return self._post( "/evaluate/summarization/", - body=maybe_transform({"request": request}, evaluation_summarization_params.EvaluationSummarizationParams), + body=maybe_transform({"metrics": metrics}, evaluation_summarization_params.EvaluationSummarizationParams), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -77,7 +80,7 @@ def summarization( def text_generation( self, *, - request: evaluation_text_generation_params.Request, + metrics: List[Literal["perplexity", "rouge", "bleu"]], # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, @@ -98,7 +101,7 @@ def text_generation( return self._post( "/evaluate/text_generation/", body=maybe_transform( - {"request": request}, evaluation_text_generation_params.EvaluationTextGenerationParams + {"metrics": metrics}, evaluation_text_generation_params.EvaluationTextGenerationParams ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout @@ -130,7 +133,7 @@ def with_streaming_response(self) -> AsyncEvaluationsResourceWithStreamingRespon async def summarization( self, *, - request: evaluation_summarization_params.Request, + metrics: List[Literal["rouge", "bleu"]], # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -151,7 +154,7 @@ async def summarization( return await self._post( "/evaluate/summarization/", body=await async_maybe_transform( - {"request": request}, evaluation_summarization_params.EvaluationSummarizationParams + {"metrics": metrics}, evaluation_summarization_params.EvaluationSummarizationParams ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout @@ -162,7 +165,7 @@ async def summarization( async def text_generation( self, *, - request: evaluation_text_generation_params.Request, + metrics: List[Literal["perplexity", "rouge", "bleu"]], # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -183,7 +186,7 @@ async def text_generation( return await self._post( "/evaluate/text_generation/", body=await async_maybe_transform( - {"request": request}, evaluation_text_generation_params.EvaluationTextGenerationParams + {"metrics": metrics}, evaluation_text_generation_params.EvaluationTextGenerationParams ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout diff --git a/src/llama_stack/resources/experiments/__init__.py b/src/llama_stack/resources/experiments/__init__.py deleted file mode 100644 index 0bb3ee0..0000000 --- a/src/llama_stack/resources/experiments/__init__.py +++ /dev/null @@ -1,33 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
- -from .artifacts import ( - ArtifactsResource, - AsyncArtifactsResource, - ArtifactsResourceWithRawResponse, - AsyncArtifactsResourceWithRawResponse, - ArtifactsResourceWithStreamingResponse, - AsyncArtifactsResourceWithStreamingResponse, -) -from .experiments import ( - ExperimentsResource, - AsyncExperimentsResource, - ExperimentsResourceWithRawResponse, - AsyncExperimentsResourceWithRawResponse, - ExperimentsResourceWithStreamingResponse, - AsyncExperimentsResourceWithStreamingResponse, -) - -__all__ = [ - "ArtifactsResource", - "AsyncArtifactsResource", - "ArtifactsResourceWithRawResponse", - "AsyncArtifactsResourceWithRawResponse", - "ArtifactsResourceWithStreamingResponse", - "AsyncArtifactsResourceWithStreamingResponse", - "ExperimentsResource", - "AsyncExperimentsResource", - "ExperimentsResourceWithRawResponse", - "AsyncExperimentsResourceWithRawResponse", - "ExperimentsResourceWithStreamingResponse", - "AsyncExperimentsResourceWithStreamingResponse", -] diff --git a/src/llama_stack/resources/experiments/artifacts.py b/src/llama_stack/resources/experiments/artifacts.py deleted file mode 100644 index 4dac012..0000000 --- a/src/llama_stack/resources/experiments/artifacts.py +++ /dev/null @@ -1,238 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -from __future__ import annotations - -import httpx - -from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven -from ..._utils import ( - maybe_transform, - async_maybe_transform, -) -from ..._compat import cached_property -from ..._resource import SyncAPIResource, AsyncAPIResource -from ..._response import ( - to_raw_response_wrapper, - to_streamed_response_wrapper, - async_to_raw_response_wrapper, - async_to_streamed_response_wrapper, -) -from ..._base_client import make_request_options -from ...types.experiments import artifact_upload_params, artifact_retrieve_params -from ...types.shared.artifact import Artifact - -__all__ = ["ArtifactsResource", "AsyncArtifactsResource"] - - -class ArtifactsResource(SyncAPIResource): - @cached_property - def with_raw_response(self) -> ArtifactsResourceWithRawResponse: - """ - This property can be used as a prefix for any HTTP method call to return the - the raw response object instead of the parsed content. - - For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers - """ - return ArtifactsResourceWithRawResponse(self) - - @cached_property - def with_streaming_response(self) -> ArtifactsResourceWithStreamingResponse: - """ - An alternative to `.with_raw_response` that doesn't eagerly read the response body. - - For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response - """ - return ArtifactsResourceWithStreamingResponse(self) - - def retrieve( - self, - *, - experiment_id: str, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. 
- extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> Artifact: - """ - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})} - return self._post( - "/experiments/artifacts/get", - body=maybe_transform({"experiment_id": experiment_id}, artifact_retrieve_params.ArtifactRetrieveParams), - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=Artifact, - ) - - def upload( - self, - *, - request: artifact_upload_params.Request, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> Artifact: - """ - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - return self._post( - "/experiments/artifacts/upload", - body=maybe_transform({"request": request}, artifact_upload_params.ArtifactUploadParams), - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=Artifact, - ) - - -class AsyncArtifactsResource(AsyncAPIResource): - @cached_property - def with_raw_response(self) -> AsyncArtifactsResourceWithRawResponse: - """ - This property can be used as a prefix for any HTTP method call to return the - the raw response object instead of the parsed content. - - For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers - """ - return AsyncArtifactsResourceWithRawResponse(self) - - @cached_property - def with_streaming_response(self) -> AsyncArtifactsResourceWithStreamingResponse: - """ - An alternative to `.with_raw_response` that doesn't eagerly read the response body. - - For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response - """ - return AsyncArtifactsResourceWithStreamingResponse(self) - - async def retrieve( - self, - *, - experiment_id: str, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. 
- extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> Artifact: - """ - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})} - return await self._post( - "/experiments/artifacts/get", - body=await async_maybe_transform( - {"experiment_id": experiment_id}, artifact_retrieve_params.ArtifactRetrieveParams - ), - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=Artifact, - ) - - async def upload( - self, - *, - request: artifact_upload_params.Request, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> Artifact: - """ - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - return await self._post( - "/experiments/artifacts/upload", - body=await async_maybe_transform({"request": request}, artifact_upload_params.ArtifactUploadParams), - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=Artifact, - ) - - -class ArtifactsResourceWithRawResponse: - def __init__(self, artifacts: ArtifactsResource) -> None: - self._artifacts = artifacts - - self.retrieve = to_raw_response_wrapper( - artifacts.retrieve, - ) - self.upload = to_raw_response_wrapper( - artifacts.upload, - ) - - -class AsyncArtifactsResourceWithRawResponse: - def __init__(self, artifacts: AsyncArtifactsResource) -> None: - self._artifacts = artifacts - - self.retrieve = async_to_raw_response_wrapper( - artifacts.retrieve, - ) - self.upload = async_to_raw_response_wrapper( - artifacts.upload, - ) - - -class ArtifactsResourceWithStreamingResponse: - def __init__(self, artifacts: ArtifactsResource) -> None: - self._artifacts = artifacts - - self.retrieve = to_streamed_response_wrapper( - artifacts.retrieve, - ) - self.upload = to_streamed_response_wrapper( - artifacts.upload, - ) - - -class AsyncArtifactsResourceWithStreamingResponse: - def __init__(self, artifacts: AsyncArtifactsResource) -> None: - self._artifacts = artifacts - - self.retrieve = async_to_streamed_response_wrapper( - artifacts.retrieve, - ) - self.upload = async_to_streamed_response_wrapper( - artifacts.upload, - ) diff --git a/src/llama_stack/resources/experiments/experiments.py b/src/llama_stack/resources/experiments/experiments.py deleted file mode 100644 index 7b699df..0000000 --- a/src/llama_stack/resources/experiments/experiments.py +++ /dev/null @@ -1,478 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
- -from __future__ import annotations - -import httpx - -from ...types import ( - experiment_create_params, - experiment_update_params, - experiment_retrieve_params, - experiment_create_run_params, -) -from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven -from ..._utils import ( - maybe_transform, - async_maybe_transform, -) -from ..._compat import cached_property -from .artifacts import ( - ArtifactsResource, - AsyncArtifactsResource, - ArtifactsResourceWithRawResponse, - AsyncArtifactsResourceWithRawResponse, - ArtifactsResourceWithStreamingResponse, - AsyncArtifactsResourceWithStreamingResponse, -) -from ..._resource import SyncAPIResource, AsyncAPIResource -from ..._response import ( - to_raw_response_wrapper, - to_streamed_response_wrapper, - async_to_raw_response_wrapper, - async_to_streamed_response_wrapper, -) -from ..._base_client import make_request_options -from ...types.experiment import Experiment -from ...types.shared.run import Run - -__all__ = ["ExperimentsResource", "AsyncExperimentsResource"] - - -class ExperimentsResource(SyncAPIResource): - @cached_property - def artifacts(self) -> ArtifactsResource: - return ArtifactsResource(self._client) - - @cached_property - def with_raw_response(self) -> ExperimentsResourceWithRawResponse: - """ - This property can be used as a prefix for any HTTP method call to return the - the raw response object instead of the parsed content. - - For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers - """ - return ExperimentsResourceWithRawResponse(self) - - @cached_property - def with_streaming_response(self) -> ExperimentsResourceWithStreamingResponse: - """ - An alternative to `.with_raw_response` that doesn't eagerly read the response body. - - For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response - """ - return ExperimentsResourceWithStreamingResponse(self) - - def create( - self, - *, - request: experiment_create_params.Request, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> Experiment: - """ - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - return self._post( - "/experiments/create", - body=maybe_transform({"request": request}, experiment_create_params.ExperimentCreateParams), - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=Experiment, - ) - - def retrieve( - self, - *, - experiment_id: str, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. 
- extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> Experiment: - """ - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - return self._get( - "/experiments/get", - options=make_request_options( - extra_headers=extra_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, - query=maybe_transform( - {"experiment_id": experiment_id}, experiment_retrieve_params.ExperimentRetrieveParams - ), - ), - cast_to=Experiment, - ) - - def update( - self, - *, - request: experiment_update_params.Request, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> Experiment: - """ - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - return self._post( - "/experiments/update", - body=maybe_transform({"request": request}, experiment_update_params.ExperimentUpdateParams), - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=Experiment, - ) - - def list( - self, - *, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> Experiment: - extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})} - return self._get( - "/experiments/list", - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=Experiment, - ) - - def create_run( - self, - *, - request: experiment_create_run_params.Request, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. 
- extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> Run: - """ - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - return self._post( - "/experiments/create_run", - body=maybe_transform({"request": request}, experiment_create_run_params.ExperimentCreateRunParams), - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=Run, - ) - - -class AsyncExperimentsResource(AsyncAPIResource): - @cached_property - def artifacts(self) -> AsyncArtifactsResource: - return AsyncArtifactsResource(self._client) - - @cached_property - def with_raw_response(self) -> AsyncExperimentsResourceWithRawResponse: - """ - This property can be used as a prefix for any HTTP method call to return the - the raw response object instead of the parsed content. - - For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers - """ - return AsyncExperimentsResourceWithRawResponse(self) - - @cached_property - def with_streaming_response(self) -> AsyncExperimentsResourceWithStreamingResponse: - """ - An alternative to `.with_raw_response` that doesn't eagerly read the response body. - - For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response - """ - return AsyncExperimentsResourceWithStreamingResponse(self) - - async def create( - self, - *, - request: experiment_create_params.Request, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> Experiment: - """ - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - return await self._post( - "/experiments/create", - body=await async_maybe_transform({"request": request}, experiment_create_params.ExperimentCreateParams), - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=Experiment, - ) - - async def retrieve( - self, - *, - experiment_id: str, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. 
- extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> Experiment: - """ - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - return await self._get( - "/experiments/get", - options=make_request_options( - extra_headers=extra_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, - query=await async_maybe_transform( - {"experiment_id": experiment_id}, experiment_retrieve_params.ExperimentRetrieveParams - ), - ), - cast_to=Experiment, - ) - - async def update( - self, - *, - request: experiment_update_params.Request, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> Experiment: - """ - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - return await self._post( - "/experiments/update", - body=await async_maybe_transform({"request": request}, experiment_update_params.ExperimentUpdateParams), - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=Experiment, - ) - - async def list( - self, - *, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> Experiment: - extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})} - return await self._get( - "/experiments/list", - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=Experiment, - ) - - async def create_run( - self, - *, - request: experiment_create_run_params.Request, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. 
- extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> Run: - """ - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - return await self._post( - "/experiments/create_run", - body=await async_maybe_transform( - {"request": request}, experiment_create_run_params.ExperimentCreateRunParams - ), - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=Run, - ) - - -class ExperimentsResourceWithRawResponse: - def __init__(self, experiments: ExperimentsResource) -> None: - self._experiments = experiments - - self.create = to_raw_response_wrapper( - experiments.create, - ) - self.retrieve = to_raw_response_wrapper( - experiments.retrieve, - ) - self.update = to_raw_response_wrapper( - experiments.update, - ) - self.list = to_raw_response_wrapper( - experiments.list, - ) - self.create_run = to_raw_response_wrapper( - experiments.create_run, - ) - - @cached_property - def artifacts(self) -> ArtifactsResourceWithRawResponse: - return ArtifactsResourceWithRawResponse(self._experiments.artifacts) - - -class AsyncExperimentsResourceWithRawResponse: - def __init__(self, experiments: AsyncExperimentsResource) -> None: - self._experiments = experiments - - self.create = async_to_raw_response_wrapper( - experiments.create, - ) - self.retrieve = async_to_raw_response_wrapper( - experiments.retrieve, - ) - self.update = async_to_raw_response_wrapper( - experiments.update, - ) - self.list = async_to_raw_response_wrapper( - experiments.list, - ) - self.create_run = async_to_raw_response_wrapper( - experiments.create_run, - ) - - @cached_property - def artifacts(self) -> AsyncArtifactsResourceWithRawResponse: - return AsyncArtifactsResourceWithRawResponse(self._experiments.artifacts) - - -class ExperimentsResourceWithStreamingResponse: - def __init__(self, experiments: ExperimentsResource) -> None: - self._experiments = experiments - - self.create = to_streamed_response_wrapper( - experiments.create, - ) - self.retrieve = to_streamed_response_wrapper( - experiments.retrieve, - ) - self.update = to_streamed_response_wrapper( - experiments.update, - ) - self.list = to_streamed_response_wrapper( - experiments.list, - ) - self.create_run = to_streamed_response_wrapper( - experiments.create_run, - ) - - @cached_property - def artifacts(self) -> ArtifactsResourceWithStreamingResponse: - return ArtifactsResourceWithStreamingResponse(self._experiments.artifacts) - - -class AsyncExperimentsResourceWithStreamingResponse: - def __init__(self, experiments: AsyncExperimentsResource) -> None: - self._experiments = experiments - - self.create = async_to_streamed_response_wrapper( - experiments.create, - ) - self.retrieve = async_to_streamed_response_wrapper( - experiments.retrieve, - ) - self.update = async_to_streamed_response_wrapper( - experiments.update, - ) - self.list = async_to_streamed_response_wrapper( - experiments.list, - ) - self.create_run = async_to_streamed_response_wrapper( - experiments.create_run, - ) - - @cached_property - def artifacts(self) -> AsyncArtifactsResourceWithStreamingResponse: - return AsyncArtifactsResourceWithStreamingResponse(self._experiments.artifacts) diff --git 
a/src/llama_stack/resources/inference/inference.py b/src/llama_stack/resources/inference/inference.py index 77fede8..648736e 100644 --- a/src/llama_stack/resources/inference/inference.py +++ b/src/llama_stack/resources/inference/inference.py @@ -2,11 +2,15 @@ from __future__ import annotations +from typing import Any, List, Union, Iterable, cast, overload +from typing_extensions import Literal + import httpx from ...types import inference_completion_params, inference_chat_completion_params from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven from ..._utils import ( + required_args, maybe_transform, async_maybe_transform, ) @@ -26,9 +30,11 @@ async_to_raw_response_wrapper, async_to_streamed_response_wrapper, ) +from ..._streaming import Stream, AsyncStream from ..._base_client import make_request_options -from ...types.completion_stream_chunk import CompletionStreamChunk -from ...types.chat_completion_stream_chunk import ChatCompletionStreamChunk +from ...types.inference_completion_response import InferenceCompletionResponse +from ...types.shared_params.sampling_params import SamplingParams +from ...types.inference_chat_completion_response import InferenceChatCompletionResponse __all__ = ["InferenceResource", "AsyncInferenceResource"] @@ -57,19 +63,37 @@ def with_streaming_response(self) -> InferenceResourceWithStreamingResponse: """ return InferenceResourceWithStreamingResponse(self) + @overload def chat_completion( self, *, - request: inference_chat_completion_params.Request, + messages: Iterable[inference_chat_completion_params.Message], + model: str, + logprobs: inference_chat_completion_params.Logprobs | NotGiven = NOT_GIVEN, + sampling_params: SamplingParams | NotGiven = NOT_GIVEN, + stream: Literal[False] | NotGiven = NOT_GIVEN, + tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN, + tool_prompt_format: Literal["json", "function_tag"] | NotGiven = NOT_GIVEN, + tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> ChatCompletionStreamChunk: + ) -> InferenceChatCompletionResponse: """ Args: + tool_prompt_format: `json` -- Refers to the json format for calling tools. The json format takes the + form like { "type": "function", "function" : { "name": "function_name", + "description": "function_description", "parameters": {...} } } + + `function_tag` -- This is an example of how you could define your own user + defined format for making tool calls. The function_tag format looks like this, + (parameters) + + The detailed prompts for each of these formats are added to llama cli + extra_headers: Send extra headers extra_query: Add additional query parameters to the request @@ -78,27 +102,153 @@ def chat_completion( timeout: Override the client-level default timeout for this request, in seconds """ + ... 
+ + @overload + def chat_completion( + self, + *, + messages: Iterable[inference_chat_completion_params.Message], + model: str, + stream: Literal[True], + logprobs: inference_chat_completion_params.Logprobs | NotGiven = NOT_GIVEN, + sampling_params: SamplingParams | NotGiven = NOT_GIVEN, + tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN, + tool_prompt_format: Literal["json", "function_tag"] | NotGiven = NOT_GIVEN, + tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Stream[InferenceChatCompletionResponse]: + """ + Args: + tool_prompt_format: `json` -- Refers to the json format for calling tools. The json format takes the + form like { "type": "function", "function" : { "name": "function_name", + "description": "function_description", "parameters": {...} } } + + `function_tag` -- This is an example of how you could define your own user + defined format for making tool calls. The function_tag format looks like this, + (parameters) + + The detailed prompts for each of these formats are added to llama cli + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @overload + def chat_completion( + self, + *, + messages: Iterable[inference_chat_completion_params.Message], + model: str, + stream: bool, + logprobs: inference_chat_completion_params.Logprobs | NotGiven = NOT_GIVEN, + sampling_params: SamplingParams | NotGiven = NOT_GIVEN, + tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN, + tool_prompt_format: Literal["json", "function_tag"] | NotGiven = NOT_GIVEN, + tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> InferenceChatCompletionResponse | Stream[InferenceChatCompletionResponse]: + """ + Args: + tool_prompt_format: `json` -- Refers to the json format for calling tools. The json format takes the + form like { "type": "function", "function" : { "name": "function_name", + "description": "function_description", "parameters": {...} } } + + `function_tag` -- This is an example of how you could define your own user + defined format for making tool calls. The function_tag format looks like this, + (parameters) + + The detailed prompts for each of these formats are added to llama cli + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... 
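Taken together, these overloads mean the sync client returns a parsed `InferenceChatCompletionResponse` when `stream` is omitted or `False`, and a `Stream[InferenceChatCompletionResponse]` to iterate over when `stream=True`. A rough usage sketch; the `{"role": ..., "content": ...}` message dict is an assumption for illustration (the accepted fields are defined by `inference_chat_completion_params.Message`, not shown in this hunk):

```python
from llama_stack import LlamaStack

client = LlamaStack()

# Non-streaming: the first overload applies and a parsed response is returned.
response = client.inference.chat_completion(
    messages=[{"role": "user", "content": "Hello"}],  # assumed message shape
    model="model",
)
print(response)

# Streaming: stream=True selects the Stream[...] overload, so the result
# is consumed chunk by chunk.
stream = client.inference.chat_completion(
    messages=[{"role": "user", "content": "Hello"}],  # assumed message shape
    model="model",
    stream=True,
)
for chunk in stream:
    print(chunk)
```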
+ + @required_args(["messages", "model"], ["messages", "model", "stream"]) + def chat_completion( + self, + *, + messages: Iterable[inference_chat_completion_params.Message], + model: str, + logprobs: inference_chat_completion_params.Logprobs | NotGiven = NOT_GIVEN, + sampling_params: SamplingParams | NotGiven = NOT_GIVEN, + stream: Literal[False] | Literal[True] | NotGiven = NOT_GIVEN, + tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN, + tool_prompt_format: Literal["json", "function_tag"] | NotGiven = NOT_GIVEN, + tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> InferenceChatCompletionResponse | Stream[InferenceChatCompletionResponse]: extra_headers = {"Accept": "text/event-stream", **(extra_headers or {})} - return self._post( - "/inference/chat_completion", - body=maybe_transform({"request": request}, inference_chat_completion_params.InferenceChatCompletionParams), - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + return cast( + InferenceChatCompletionResponse, + self._post( + "/inference/chat_completion", + body=maybe_transform( + { + "messages": messages, + "model": model, + "logprobs": logprobs, + "sampling_params": sampling_params, + "stream": stream, + "tool_choice": tool_choice, + "tool_prompt_format": tool_prompt_format, + "tools": tools, + }, + inference_chat_completion_params.InferenceChatCompletionParams, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=cast( + Any, InferenceChatCompletionResponse + ), # Union types cannot be passed in as arguments in the type system + stream=stream or False, + stream_cls=Stream[InferenceChatCompletionResponse], ), - cast_to=ChatCompletionStreamChunk, ) def completion( self, *, - request: inference_completion_params.Request, + content: Union[str, List[str]], + model: str, + logprobs: inference_completion_params.Logprobs | NotGiven = NOT_GIVEN, + sampling_params: SamplingParams | NotGiven = NOT_GIVEN, + stream: bool | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> CompletionStreamChunk: + ) -> InferenceCompletionResponse: """ Args: extra_headers: Send extra headers @@ -109,13 +259,27 @@ def completion( timeout: Override the client-level default timeout for this request, in seconds """ - return self._post( - "/inference/completion", - body=maybe_transform({"request": request}, inference_completion_params.InferenceCompletionParams), - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + return cast( + InferenceCompletionResponse, + self._post( + "/inference/completion", + body=maybe_transform( + { + "content": content, + "model": model, + "logprobs": logprobs, + "sampling_params": sampling_params, + "stream": stream, + }, + inference_completion_params.InferenceCompletionParams, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=cast( + Any, InferenceCompletionResponse + ), # Union types cannot be passed in as arguments in the type system ), - cast_to=CompletionStreamChunk, ) @@ -143,19 +307,119 @@ def with_streaming_response(self) -> AsyncInferenceResourceWithStreamingResponse """ return AsyncInferenceResourceWithStreamingResponse(self) + @overload async def chat_completion( self, *, - request: inference_chat_completion_params.Request, + messages: Iterable[inference_chat_completion_params.Message], + model: str, + logprobs: inference_chat_completion_params.Logprobs | NotGiven = NOT_GIVEN, + sampling_params: SamplingParams | NotGiven = NOT_GIVEN, + stream: Literal[False] | NotGiven = NOT_GIVEN, + tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN, + tool_prompt_format: Literal["json", "function_tag"] | NotGiven = NOT_GIVEN, + tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> ChatCompletionStreamChunk: + ) -> InferenceChatCompletionResponse: """ Args: + tool_prompt_format: `json` -- Refers to the json format for calling tools. The json format takes the + form like { "type": "function", "function" : { "name": "function_name", + "description": "function_description", "parameters": {...} } } + + `function_tag` -- This is an example of how you could define your own user + defined format for making tool calls. The function_tag format looks like this, + (parameters) + + The detailed prompts for each of these formats are added to llama cli + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... 
+ + @overload + async def chat_completion( + self, + *, + messages: Iterable[inference_chat_completion_params.Message], + model: str, + stream: Literal[True], + logprobs: inference_chat_completion_params.Logprobs | NotGiven = NOT_GIVEN, + sampling_params: SamplingParams | NotGiven = NOT_GIVEN, + tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN, + tool_prompt_format: Literal["json", "function_tag"] | NotGiven = NOT_GIVEN, + tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> AsyncStream[InferenceChatCompletionResponse]: + """ + Args: + tool_prompt_format: `json` -- Refers to the json format for calling tools. The json format takes the + form like { "type": "function", "function" : { "name": "function_name", + "description": "function_description", "parameters": {...} } } + + `function_tag` -- This is an example of how you could define your own user + defined format for making tool calls. The function_tag format looks like this, + (parameters) + + The detailed prompts for each of these formats are added to llama cli + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @overload + async def chat_completion( + self, + *, + messages: Iterable[inference_chat_completion_params.Message], + model: str, + stream: bool, + logprobs: inference_chat_completion_params.Logprobs | NotGiven = NOT_GIVEN, + sampling_params: SamplingParams | NotGiven = NOT_GIVEN, + tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN, + tool_prompt_format: Literal["json", "function_tag"] | NotGiven = NOT_GIVEN, + tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> InferenceChatCompletionResponse | AsyncStream[InferenceChatCompletionResponse]: + """ + Args: + tool_prompt_format: `json` -- Refers to the json format for calling tools. The json format takes the + form like { "type": "function", "function" : { "name": "function_name", + "description": "function_description", "parameters": {...} } } + + `function_tag` -- This is an example of how you could define your own user + defined format for making tool calls. The function_tag format looks like this, + (parameters) + + The detailed prompts for each of these formats are added to llama cli + extra_headers: Send extra headers extra_query: Add additional query parameters to the request @@ -164,29 +428,71 @@ async def chat_completion( timeout: Override the client-level default timeout for this request, in seconds """ + ... 
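The async overloads mirror this: awaiting the call yields either a parsed response or an `AsyncStream` that is consumed with `async for`. A sketch under the same message-shape assumption as the sync example:

```python
import asyncio

from llama_stack import AsyncLlamaStack

client = AsyncLlamaStack()


async def main() -> None:
    # stream=True selects the AsyncStream[...] overload.
    stream = await client.inference.chat_completion(
        messages=[{"role": "user", "content": "Hello"}],  # assumed message shape
        model="model",
        stream=True,
    )
    async for chunk in stream:
        print(chunk)


asyncio.run(main())
```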
+ + @required_args(["messages", "model"], ["messages", "model", "stream"]) + async def chat_completion( + self, + *, + messages: Iterable[inference_chat_completion_params.Message], + model: str, + logprobs: inference_chat_completion_params.Logprobs | NotGiven = NOT_GIVEN, + sampling_params: SamplingParams | NotGiven = NOT_GIVEN, + stream: Literal[False] | Literal[True] | NotGiven = NOT_GIVEN, + tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN, + tool_prompt_format: Literal["json", "function_tag"] | NotGiven = NOT_GIVEN, + tools: Iterable[inference_chat_completion_params.Tool] | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> InferenceChatCompletionResponse | AsyncStream[InferenceChatCompletionResponse]: extra_headers = {"Accept": "text/event-stream", **(extra_headers or {})} - return await self._post( - "/inference/chat_completion", - body=await async_maybe_transform( - {"request": request}, inference_chat_completion_params.InferenceChatCompletionParams - ), - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + return cast( + InferenceChatCompletionResponse, + await self._post( + "/inference/chat_completion", + body=await async_maybe_transform( + { + "messages": messages, + "model": model, + "logprobs": logprobs, + "sampling_params": sampling_params, + "stream": stream, + "tool_choice": tool_choice, + "tool_prompt_format": tool_prompt_format, + "tools": tools, + }, + inference_chat_completion_params.InferenceChatCompletionParams, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=cast( + Any, InferenceChatCompletionResponse + ), # Union types cannot be passed in as arguments in the type system + stream=stream or False, + stream_cls=AsyncStream[InferenceChatCompletionResponse], ), - cast_to=ChatCompletionStreamChunk, ) async def completion( self, *, - request: inference_completion_params.Request, + content: Union[str, List[str]], + model: str, + logprobs: inference_completion_params.Logprobs | NotGiven = NOT_GIVEN, + sampling_params: SamplingParams | NotGiven = NOT_GIVEN, + stream: bool | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> CompletionStreamChunk: + ) -> InferenceCompletionResponse: """ Args: extra_headers: Send extra headers @@ -197,15 +503,27 @@ async def completion( timeout: Override the client-level default timeout for this request, in seconds """ - return await self._post( - "/inference/completion", - body=await async_maybe_transform( - {"request": request}, inference_completion_params.InferenceCompletionParams - ), - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + return cast( + InferenceCompletionResponse, + await self._post( + "/inference/completion", + body=await async_maybe_transform( + { + "content": content, + "model": model, + "logprobs": logprobs, + "sampling_params": sampling_params, + "stream": stream, + }, + inference_completion_params.InferenceCompletionParams, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=cast( + Any, InferenceCompletionResponse + ), # Union types cannot be passed in as arguments in the type system ), - cast_to=CompletionStreamChunk, ) diff --git a/src/llama_stack/resources/post_training/post_training.py b/src/llama_stack/resources/post_training/post_training.py index dbca3c8..db5aac8 100644 --- a/src/llama_stack/resources/post_training/post_training.py +++ b/src/llama_stack/resources/post_training/post_training.py @@ -2,6 +2,9 @@ from __future__ import annotations +from typing import Dict, Union, Iterable +from typing_extensions import Literal + import httpx from .jobs import ( @@ -12,7 +15,10 @@ JobsResourceWithStreamingResponse, AsyncJobsResourceWithStreamingResponse, ) -from ...types import post_training_preference_optimize_params, post_training_supervised_fine_tune_params +from ...types import ( + post_training_preference_optimize_params, + post_training_supervised_fine_tune_params, +) from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven from ..._utils import ( maybe_transform, @@ -28,6 +34,7 @@ ) from ..._base_client import make_request_options from ...types.post_training_job import PostTrainingJob +from ...types.train_eval_dataset_param import TrainEvalDatasetParam __all__ = ["PostTrainingResource", "AsyncPostTrainingResource"] @@ -59,7 +66,16 @@ def with_streaming_response(self) -> PostTrainingResourceWithStreamingResponse: def preference_optimize( self, *, - request: post_training_preference_optimize_params.Request, + algorithm: Literal["dpo"], + algorithm_config: post_training_preference_optimize_params.AlgorithmConfig, + dataset: TrainEvalDatasetParam, + finetuned_model: str, + hyperparam_search_config: Dict[str, Union[bool, float, str, Iterable[object], object, None]], + job_uuid: str, + logger_config: Dict[str, Union[bool, float, str, Iterable[object], object, None]], + optimizer_config: post_training_preference_optimize_params.OptimizerConfig, + training_config: post_training_preference_optimize_params.TrainingConfig, + validation_dataset: TrainEvalDatasetParam, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, @@ -80,7 +96,19 @@ def preference_optimize( return self._post( "/post_training/preference_optimize", body=maybe_transform( - {"request": request}, post_training_preference_optimize_params.PostTrainingPreferenceOptimizeParams + { + "algorithm": algorithm, + "algorithm_config": algorithm_config, + "dataset": dataset, + "finetuned_model": finetuned_model, + "hyperparam_search_config": hyperparam_search_config, + "job_uuid": job_uuid, + "logger_config": logger_config, + "optimizer_config": optimizer_config, + "training_config": training_config, + "validation_dataset": validation_dataset, + }, + post_training_preference_optimize_params.PostTrainingPreferenceOptimizeParams, ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout @@ -91,7 +119,16 @@ def preference_optimize( def supervised_fine_tune( self, *, - request: post_training_supervised_fine_tune_params.Request, + algorithm: Literal["full", "lora", "qlora", "dora"], + algorithm_config: post_training_supervised_fine_tune_params.AlgorithmConfig, + dataset: TrainEvalDatasetParam, + hyperparam_search_config: Dict[str, Union[bool, float, str, Iterable[object], object, None]], + job_uuid: str, + logger_config: Dict[str, Union[bool, float, str, Iterable[object], object, None]], + model: str, + optimizer_config: post_training_supervised_fine_tune_params.OptimizerConfig, + training_config: post_training_supervised_fine_tune_params.TrainingConfig, + validation_dataset: TrainEvalDatasetParam, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -112,7 +149,19 @@ def supervised_fine_tune( return self._post( "/post_training/supervised_fine_tune", body=maybe_transform( - {"request": request}, post_training_supervised_fine_tune_params.PostTrainingSupervisedFineTuneParams + { + "algorithm": algorithm, + "algorithm_config": algorithm_config, + "dataset": dataset, + "hyperparam_search_config": hyperparam_search_config, + "job_uuid": job_uuid, + "logger_config": logger_config, + "model": model, + "optimizer_config": optimizer_config, + "training_config": training_config, + "validation_dataset": validation_dataset, + }, + post_training_supervised_fine_tune_params.PostTrainingSupervisedFineTuneParams, ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout @@ -148,7 +197,16 @@ def with_streaming_response(self) -> AsyncPostTrainingResourceWithStreamingRespo async def preference_optimize( self, *, - request: post_training_preference_optimize_params.Request, + algorithm: Literal["dpo"], + algorithm_config: post_training_preference_optimize_params.AlgorithmConfig, + dataset: TrainEvalDatasetParam, + finetuned_model: str, + hyperparam_search_config: Dict[str, Union[bool, float, str, Iterable[object], object, None]], + job_uuid: str, + logger_config: Dict[str, Union[bool, float, str, Iterable[object], object, None]], + optimizer_config: post_training_preference_optimize_params.OptimizerConfig, + training_config: post_training_preference_optimize_params.TrainingConfig, + validation_dataset: TrainEvalDatasetParam, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. 
# The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -169,7 +227,19 @@ async def preference_optimize( return await self._post( "/post_training/preference_optimize", body=await async_maybe_transform( - {"request": request}, post_training_preference_optimize_params.PostTrainingPreferenceOptimizeParams + { + "algorithm": algorithm, + "algorithm_config": algorithm_config, + "dataset": dataset, + "finetuned_model": finetuned_model, + "hyperparam_search_config": hyperparam_search_config, + "job_uuid": job_uuid, + "logger_config": logger_config, + "optimizer_config": optimizer_config, + "training_config": training_config, + "validation_dataset": validation_dataset, + }, + post_training_preference_optimize_params.PostTrainingPreferenceOptimizeParams, ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout @@ -180,7 +250,16 @@ async def preference_optimize( async def supervised_fine_tune( self, *, - request: post_training_supervised_fine_tune_params.Request, + algorithm: Literal["full", "lora", "qlora", "dora"], + algorithm_config: post_training_supervised_fine_tune_params.AlgorithmConfig, + dataset: TrainEvalDatasetParam, + hyperparam_search_config: Dict[str, Union[bool, float, str, Iterable[object], object, None]], + job_uuid: str, + logger_config: Dict[str, Union[bool, float, str, Iterable[object], object, None]], + model: str, + optimizer_config: post_training_supervised_fine_tune_params.OptimizerConfig, + training_config: post_training_supervised_fine_tune_params.TrainingConfig, + validation_dataset: TrainEvalDatasetParam, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -201,7 +280,19 @@ async def supervised_fine_tune( return await self._post( "/post_training/supervised_fine_tune", body=await async_maybe_transform( - {"request": request}, post_training_supervised_fine_tune_params.PostTrainingSupervisedFineTuneParams + { + "algorithm": algorithm, + "algorithm_config": algorithm_config, + "dataset": dataset, + "hyperparam_search_config": hyperparam_search_config, + "job_uuid": job_uuid, + "logger_config": logger_config, + "model": model, + "optimizer_config": optimizer_config, + "training_config": training_config, + "validation_dataset": validation_dataset, + }, + post_training_supervised_fine_tune_params.PostTrainingSupervisedFineTuneParams, ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout diff --git a/src/llama_stack/resources/reward_scoring.py b/src/llama_stack/resources/reward_scoring.py index 405a911..bed1479 100644 --- a/src/llama_stack/resources/reward_scoring.py +++ b/src/llama_stack/resources/reward_scoring.py @@ -2,6 +2,8 @@ from __future__ import annotations +from typing import Iterable + import httpx from ..types import reward_scoring_score_params @@ -47,7 +49,8 @@ def with_streaming_response(self) -> RewardScoringResourceWithStreamingResponse: def score( self, *, - request: reward_scoring_score_params.Request, + dialog_generations: Iterable[reward_scoring_score_params.DialogGeneration], + model: str, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. 
# The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -67,7 +70,13 @@ def score( """ return self._post( "/reward_scoring/score", - body=maybe_transform({"request": request}, reward_scoring_score_params.RewardScoringScoreParams), + body=maybe_transform( + { + "dialog_generations": dialog_generations, + "model": model, + }, + reward_scoring_score_params.RewardScoringScoreParams, + ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -98,7 +107,8 @@ def with_streaming_response(self) -> AsyncRewardScoringResourceWithStreamingResp async def score( self, *, - request: reward_scoring_score_params.Request, + dialog_generations: Iterable[reward_scoring_score_params.DialogGeneration], + model: str, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -119,7 +129,11 @@ async def score( return await self._post( "/reward_scoring/score", body=await async_maybe_transform( - {"request": request}, reward_scoring_score_params.RewardScoringScoreParams + { + "dialog_generations": dialog_generations, + "model": model, + }, + reward_scoring_score_params.RewardScoringScoreParams, ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout diff --git a/src/llama_stack/resources/runs/__init__.py b/src/llama_stack/resources/runs/__init__.py deleted file mode 100644 index cfacb8a..0000000 --- a/src/llama_stack/resources/runs/__init__.py +++ /dev/null @@ -1,33 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -from .runs import ( - RunsResource, - AsyncRunsResource, - RunsResourceWithRawResponse, - AsyncRunsResourceWithRawResponse, - RunsResourceWithStreamingResponse, - AsyncRunsResourceWithStreamingResponse, -) -from .metrics import ( - MetricsResource, - AsyncMetricsResource, - MetricsResourceWithRawResponse, - AsyncMetricsResourceWithRawResponse, - MetricsResourceWithStreamingResponse, - AsyncMetricsResourceWithStreamingResponse, -) - -__all__ = [ - "MetricsResource", - "AsyncMetricsResource", - "MetricsResourceWithRawResponse", - "AsyncMetricsResourceWithRawResponse", - "MetricsResourceWithStreamingResponse", - "AsyncMetricsResourceWithStreamingResponse", - "RunsResource", - "AsyncRunsResource", - "RunsResourceWithRawResponse", - "AsyncRunsResourceWithRawResponse", - "RunsResourceWithStreamingResponse", - "AsyncRunsResourceWithStreamingResponse", -] diff --git a/src/llama_stack/resources/runs/metrics.py b/src/llama_stack/resources/runs/metrics.py deleted file mode 100644 index f0e9abb..0000000 --- a/src/llama_stack/resources/runs/metrics.py +++ /dev/null @@ -1,170 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
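Following the same pattern, `score` now takes `dialog_generations` and `model` as top-level arguments. A minimal sketch, assuming the resource is exposed as `client.reward_scoring`; the element shape of `dialog_generations` is defined by `reward_scoring_score_params.DialogGeneration` and is not spelled out in this diff, so the list below is a placeholder rather than a concrete payload:

```python
from llama_stack import LlamaStack

client = LlamaStack()

# Placeholder: fill with DialogGeneration-shaped dicts as defined in
# reward_scoring_score_params (not shown in this diff).
dialog_generations = []

scoring = client.reward_scoring.score(
    dialog_generations=dialog_generations,
    model="model",
)
print(scoring)
```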
- -from __future__ import annotations - -import httpx - -from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven -from ..._utils import ( - maybe_transform, - async_maybe_transform, -) -from ..._compat import cached_property -from ..._resource import SyncAPIResource, AsyncAPIResource -from ..._response import ( - to_raw_response_wrapper, - to_streamed_response_wrapper, - async_to_raw_response_wrapper, - async_to_streamed_response_wrapper, -) -from ...types.runs import metric_list_params -from ..._base_client import make_request_options -from ...types.runs.metric_list_response import MetricListResponse - -__all__ = ["MetricsResource", "AsyncMetricsResource"] - - -class MetricsResource(SyncAPIResource): - @cached_property - def with_raw_response(self) -> MetricsResourceWithRawResponse: - """ - This property can be used as a prefix for any HTTP method call to return the - the raw response object instead of the parsed content. - - For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers - """ - return MetricsResourceWithRawResponse(self) - - @cached_property - def with_streaming_response(self) -> MetricsResourceWithStreamingResponse: - """ - An alternative to `.with_raw_response` that doesn't eagerly read the response body. - - For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response - """ - return MetricsResourceWithStreamingResponse(self) - - def list( - self, - *, - run_id: str, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> MetricListResponse: - """ - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})} - return self._get( - "/runs/metrics", - options=make_request_options( - extra_headers=extra_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, - query=maybe_transform({"run_id": run_id}, metric_list_params.MetricListParams), - ), - cast_to=MetricListResponse, - ) - - -class AsyncMetricsResource(AsyncAPIResource): - @cached_property - def with_raw_response(self) -> AsyncMetricsResourceWithRawResponse: - """ - This property can be used as a prefix for any HTTP method call to return the - the raw response object instead of the parsed content. - - For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers - """ - return AsyncMetricsResourceWithRawResponse(self) - - @cached_property - def with_streaming_response(self) -> AsyncMetricsResourceWithStreamingResponse: - """ - An alternative to `.with_raw_response` that doesn't eagerly read the response body. 
- - For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response - """ - return AsyncMetricsResourceWithStreamingResponse(self) - - async def list( - self, - *, - run_id: str, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> MetricListResponse: - """ - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})} - return await self._get( - "/runs/metrics", - options=make_request_options( - extra_headers=extra_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, - query=await async_maybe_transform({"run_id": run_id}, metric_list_params.MetricListParams), - ), - cast_to=MetricListResponse, - ) - - -class MetricsResourceWithRawResponse: - def __init__(self, metrics: MetricsResource) -> None: - self._metrics = metrics - - self.list = to_raw_response_wrapper( - metrics.list, - ) - - -class AsyncMetricsResourceWithRawResponse: - def __init__(self, metrics: AsyncMetricsResource) -> None: - self._metrics = metrics - - self.list = async_to_raw_response_wrapper( - metrics.list, - ) - - -class MetricsResourceWithStreamingResponse: - def __init__(self, metrics: MetricsResource) -> None: - self._metrics = metrics - - self.list = to_streamed_response_wrapper( - metrics.list, - ) - - -class AsyncMetricsResourceWithStreamingResponse: - def __init__(self, metrics: AsyncMetricsResource) -> None: - self._metrics = metrics - - self.list = async_to_streamed_response_wrapper( - metrics.list, - ) diff --git a/src/llama_stack/resources/runs/runs.py b/src/llama_stack/resources/runs/runs.py deleted file mode 100644 index 0bd3b4b..0000000 --- a/src/llama_stack/resources/runs/runs.py +++ /dev/null @@ -1,268 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
- -from __future__ import annotations - -import httpx - -from ...types import run_update_params, run_log_metrics_params -from .metrics import ( - MetricsResource, - AsyncMetricsResource, - MetricsResourceWithRawResponse, - AsyncMetricsResourceWithRawResponse, - MetricsResourceWithStreamingResponse, - AsyncMetricsResourceWithStreamingResponse, -) -from ..._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven -from ..._utils import ( - maybe_transform, - async_maybe_transform, -) -from ..._compat import cached_property -from ..._resource import SyncAPIResource, AsyncAPIResource -from ..._response import ( - to_raw_response_wrapper, - to_streamed_response_wrapper, - async_to_raw_response_wrapper, - async_to_streamed_response_wrapper, -) -from ..._base_client import make_request_options -from ...types.shared.run import Run - -__all__ = ["RunsResource", "AsyncRunsResource"] - - -class RunsResource(SyncAPIResource): - @cached_property - def metrics(self) -> MetricsResource: - return MetricsResource(self._client) - - @cached_property - def with_raw_response(self) -> RunsResourceWithRawResponse: - """ - This property can be used as a prefix for any HTTP method call to return the - the raw response object instead of the parsed content. - - For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers - """ - return RunsResourceWithRawResponse(self) - - @cached_property - def with_streaming_response(self) -> RunsResourceWithStreamingResponse: - """ - An alternative to `.with_raw_response` that doesn't eagerly read the response body. - - For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response - """ - return RunsResourceWithStreamingResponse(self) - - def update( - self, - *, - request: run_update_params.Request, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> Run: - """ - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - return self._post( - "/runs/update", - body=maybe_transform({"request": request}, run_update_params.RunUpdateParams), - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=Run, - ) - - def log_metrics( - self, - *, - request: run_log_metrics_params.Request, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. 
- extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> None: - """ - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - extra_headers = {"Accept": "*/*", **(extra_headers or {})} - return self._post( - "/runs/log_metrics", - body=maybe_transform({"request": request}, run_log_metrics_params.RunLogMetricsParams), - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=NoneType, - ) - - -class AsyncRunsResource(AsyncAPIResource): - @cached_property - def metrics(self) -> AsyncMetricsResource: - return AsyncMetricsResource(self._client) - - @cached_property - def with_raw_response(self) -> AsyncRunsResourceWithRawResponse: - """ - This property can be used as a prefix for any HTTP method call to return the - the raw response object instead of the parsed content. - - For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers - """ - return AsyncRunsResourceWithRawResponse(self) - - @cached_property - def with_streaming_response(self) -> AsyncRunsResourceWithStreamingResponse: - """ - An alternative to `.with_raw_response` that doesn't eagerly read the response body. - - For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response - """ - return AsyncRunsResourceWithStreamingResponse(self) - - async def update( - self, - *, - request: run_update_params.Request, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> Run: - """ - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - return await self._post( - "/runs/update", - body=await async_maybe_transform({"request": request}, run_update_params.RunUpdateParams), - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=Run, - ) - - async def log_metrics( - self, - *, - request: run_log_metrics_params.Request, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. 
- extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> None: - """ - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - extra_headers = {"Accept": "*/*", **(extra_headers or {})} - return await self._post( - "/runs/log_metrics", - body=await async_maybe_transform({"request": request}, run_log_metrics_params.RunLogMetricsParams), - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=NoneType, - ) - - -class RunsResourceWithRawResponse: - def __init__(self, runs: RunsResource) -> None: - self._runs = runs - - self.update = to_raw_response_wrapper( - runs.update, - ) - self.log_metrics = to_raw_response_wrapper( - runs.log_metrics, - ) - - @cached_property - def metrics(self) -> MetricsResourceWithRawResponse: - return MetricsResourceWithRawResponse(self._runs.metrics) - - -class AsyncRunsResourceWithRawResponse: - def __init__(self, runs: AsyncRunsResource) -> None: - self._runs = runs - - self.update = async_to_raw_response_wrapper( - runs.update, - ) - self.log_metrics = async_to_raw_response_wrapper( - runs.log_metrics, - ) - - @cached_property - def metrics(self) -> AsyncMetricsResourceWithRawResponse: - return AsyncMetricsResourceWithRawResponse(self._runs.metrics) - - -class RunsResourceWithStreamingResponse: - def __init__(self, runs: RunsResource) -> None: - self._runs = runs - - self.update = to_streamed_response_wrapper( - runs.update, - ) - self.log_metrics = to_streamed_response_wrapper( - runs.log_metrics, - ) - - @cached_property - def metrics(self) -> MetricsResourceWithStreamingResponse: - return MetricsResourceWithStreamingResponse(self._runs.metrics) - - -class AsyncRunsResourceWithStreamingResponse: - def __init__(self, runs: AsyncRunsResource) -> None: - self._runs = runs - - self.update = async_to_streamed_response_wrapper( - runs.update, - ) - self.log_metrics = async_to_streamed_response_wrapper( - runs.log_metrics, - ) - - @cached_property - def metrics(self) -> AsyncMetricsResourceWithStreamingResponse: - return AsyncMetricsResourceWithStreamingResponse(self._runs.metrics) diff --git a/src/llama_stack/resources/safety.py b/src/llama_stack/resources/safety.py index 54f2daf..3943f13 100644 --- a/src/llama_stack/resources/safety.py +++ b/src/llama_stack/resources/safety.py @@ -2,6 +2,8 @@ from __future__ import annotations +from typing import Iterable + import httpx from ..types import safety_run_shields_params @@ -19,6 +21,7 @@ async_to_streamed_response_wrapper, ) from .._base_client import make_request_options +from ..types.shield_definition_param import ShieldDefinitionParam from ..types.safety_run_shields_response import SafetyRunShieldsResponse __all__ = ["SafetyResource", "AsyncSafetyResource"] @@ -47,7 +50,8 @@ def with_streaming_response(self) -> SafetyResourceWithStreamingResponse: def run_shields( self, *, - request: safety_run_shields_params.Request, + messages: Iterable[safety_run_shields_params.Message], + shields: Iterable[ShieldDefinitionParam], # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. 
# The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -67,7 +71,13 @@ def run_shields( """ return self._post( "/safety/run_shields", - body=maybe_transform({"request": request}, safety_run_shields_params.SafetyRunShieldsParams), + body=maybe_transform( + { + "messages": messages, + "shields": shields, + }, + safety_run_shields_params.SafetyRunShieldsParams, + ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -98,7 +108,8 @@ def with_streaming_response(self) -> AsyncSafetyResourceWithStreamingResponse: async def run_shields( self, *, - request: safety_run_shields_params.Request, + messages: Iterable[safety_run_shields_params.Message], + shields: Iterable[ShieldDefinitionParam], # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -118,7 +129,13 @@ async def run_shields( """ return await self._post( "/safety/run_shields", - body=await async_maybe_transform({"request": request}, safety_run_shields_params.SafetyRunShieldsParams), + body=await async_maybe_transform( + { + "messages": messages, + "shields": shields, + }, + safety_run_shields_params.SafetyRunShieldsParams, + ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), diff --git a/src/llama_stack/resources/synthetic_data_generation.py b/src/llama_stack/resources/synthetic_data_generation.py index a602ab8..0847308 100644 --- a/src/llama_stack/resources/synthetic_data_generation.py +++ b/src/llama_stack/resources/synthetic_data_generation.py @@ -2,6 +2,9 @@ from __future__ import annotations +from typing import Iterable +from typing_extensions import Literal + import httpx from ..types import synthetic_data_generation_generate_params @@ -47,7 +50,9 @@ def with_streaming_response(self) -> SyntheticDataGenerationResourceWithStreamin def generate( self, *, - request: synthetic_data_generation_generate_params.Request, + dialogs: Iterable[synthetic_data_generation_generate_params.Dialog], + filtering_function: Literal["none", "random", "top_k", "top_p", "top_k_top_p", "sigmoid"], + model: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
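With the wrapper removed, `safety.run_shields` takes `messages` and `shields` directly. A minimal sketch; the message dict shape and the empty `ShieldDefinitionParam` are assumptions, not values taken from this diff:

```python
from llama_stack import LlamaStack

client = LlamaStack()

response = client.safety.run_shields(
    messages=[
        {"role": "user", "content": "Is this prompt safe to run?"},  # assumed UserMessage shape
    ],
    shields=[
        {},  # ShieldDefinitionParam stub; see shield_definition_param.py
    ],
)
```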
extra_headers: Headers | None = None, @@ -68,7 +73,12 @@ def generate( return self._post( "/synthetic_data_generation/generate", body=maybe_transform( - {"request": request}, synthetic_data_generation_generate_params.SyntheticDataGenerationGenerateParams + { + "dialogs": dialogs, + "filtering_function": filtering_function, + "model": model, + }, + synthetic_data_generation_generate_params.SyntheticDataGenerationGenerateParams, ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout @@ -100,7 +110,9 @@ def with_streaming_response(self) -> AsyncSyntheticDataGenerationResourceWithStr async def generate( self, *, - request: synthetic_data_generation_generate_params.Request, + dialogs: Iterable[synthetic_data_generation_generate_params.Dialog], + filtering_function: Literal["none", "random", "top_k", "top_p", "top_k_top_p", "sigmoid"], + model: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -121,7 +133,12 @@ async def generate( return await self._post( "/synthetic_data_generation/generate", body=await async_maybe_transform( - {"request": request}, synthetic_data_generation_generate_params.SyntheticDataGenerationGenerateParams + { + "dialogs": dialogs, + "filtering_function": filtering_function, + "model": model, + }, + synthetic_data_generation_generate_params.SyntheticDataGenerationGenerateParams, ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout diff --git a/src/llama_stack/resources/logging.py b/src/llama_stack/resources/telemetry.py similarity index 61% rename from src/llama_stack/resources/logging.py rename to src/llama_stack/resources/telemetry.py index 5ae1e6b..b3e0524 100644 --- a/src/llama_stack/resources/logging.py +++ b/src/llama_stack/resources/telemetry.py @@ -4,7 +4,7 @@ import httpx -from ..types import logging_get_logs_params, logging_log_messages_params +from ..types import telemetry_log_params, telemetry_get_trace_params from .._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven from .._utils import ( maybe_transform, @@ -19,42 +19,42 @@ async_to_streamed_response_wrapper, ) from .._base_client import make_request_options -from ..types.logging_get_logs_response import LoggingGetLogsResponse +from ..types.telemetry_get_trace_response import TelemetryGetTraceResponse -__all__ = ["LoggingResource", "AsyncLoggingResource"] +__all__ = ["TelemetryResource", "AsyncTelemetryResource"] -class LoggingResource(SyncAPIResource): +class TelemetryResource(SyncAPIResource): @cached_property - def with_raw_response(self) -> LoggingResourceWithRawResponse: + def with_raw_response(self) -> TelemetryResourceWithRawResponse: """ This property can be used as a prefix for any HTTP method call to return the the raw response object instead of the parsed content. 
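`synthetic_data_generation.generate` follows suit: `dialogs` and `filtering_function` are required, and `model` is optional (it defaults to `NOT_GIVEN`). A sketch with an assumed dialog-message shape:

```python
from llama_stack import LlamaStack

client = LlamaStack()

generation = client.synthetic_data_generation.generate(
    dialogs=[
        {"role": "user", "content": "Write a short dialog about hiking."},  # assumed Dialog shape
    ],
    # One of "none", "random", "top_k", "top_p", "top_k_top_p", "sigmoid".
    filtering_function="top_k",
)
```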
For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers """ - return LoggingResourceWithRawResponse(self) + return TelemetryResourceWithRawResponse(self) @cached_property - def with_streaming_response(self) -> LoggingResourceWithStreamingResponse: + def with_streaming_response(self) -> TelemetryResourceWithStreamingResponse: """ An alternative to `.with_raw_response` that doesn't eagerly read the response body. For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response """ - return LoggingResourceWithStreamingResponse(self) + return TelemetryResourceWithStreamingResponse(self) - def get_logs( + def get_trace( self, *, - request: logging_get_logs_params.Request, + trace_id: str, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> LoggingGetLogsResponse: + ) -> TelemetryGetTraceResponse: """ Args: extra_headers: Send extra headers @@ -65,20 +65,22 @@ def get_logs( timeout: Override the client-level default timeout for this request, in seconds """ - extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})} - return self._post( - "/logging/get_logs", - body=maybe_transform({"request": request}, logging_get_logs_params.LoggingGetLogsParams), + return self._get( + "/telemetry/get_trace", options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=maybe_transform({"trace_id": trace_id}, telemetry_get_trace_params.TelemetryGetTraceParams), ), - cast_to=LoggingGetLogsResponse, + cast_to=TelemetryGetTraceResponse, ) - def log_messages( + def log( self, *, - request: logging_log_messages_params.Request, + event: telemetry_log_params.Event, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -98,8 +100,8 @@ def log_messages( """ extra_headers = {"Accept": "*/*", **(extra_headers or {})} return self._post( - "/logging/log_messages", - body=maybe_transform({"request": request}, logging_log_messages_params.LoggingLogMessagesParams), + "/telemetry/log_event", + body=maybe_transform({"event": event}, telemetry_log_params.TelemetryLogParams), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -107,37 +109,37 @@ def log_messages( ) -class AsyncLoggingResource(AsyncAPIResource): +class AsyncTelemetryResource(AsyncAPIResource): @cached_property - def with_raw_response(self) -> AsyncLoggingResourceWithRawResponse: + def with_raw_response(self) -> AsyncTelemetryResourceWithRawResponse: """ This property can be used as a prefix for any HTTP method call to return the the raw response object instead of the parsed content. 
For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers """ - return AsyncLoggingResourceWithRawResponse(self) + return AsyncTelemetryResourceWithRawResponse(self) @cached_property - def with_streaming_response(self) -> AsyncLoggingResourceWithStreamingResponse: + def with_streaming_response(self) -> AsyncTelemetryResourceWithStreamingResponse: """ An alternative to `.with_raw_response` that doesn't eagerly read the response body. For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response """ - return AsyncLoggingResourceWithStreamingResponse(self) + return AsyncTelemetryResourceWithStreamingResponse(self) - async def get_logs( + async def get_trace( self, *, - request: logging_get_logs_params.Request, + trace_id: str, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> LoggingGetLogsResponse: + ) -> TelemetryGetTraceResponse: """ Args: extra_headers: Send extra headers @@ -148,20 +150,24 @@ async def get_logs( timeout: Override the client-level default timeout for this request, in seconds """ - extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})} - return await self._post( - "/logging/get_logs", - body=await async_maybe_transform({"request": request}, logging_get_logs_params.LoggingGetLogsParams), + return await self._get( + "/telemetry/get_trace", options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=await async_maybe_transform( + {"trace_id": trace_id}, telemetry_get_trace_params.TelemetryGetTraceParams + ), ), - cast_to=LoggingGetLogsResponse, + cast_to=TelemetryGetTraceResponse, ) - async def log_messages( + async def log( self, *, - request: logging_log_messages_params.Request, + event: telemetry_log_params.Event, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
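The `logging` resource is renamed to `telemetry`: `get_logs` becomes a GET `get_trace` keyed by `trace_id` and returning a `TelemetryGetTraceResponse`, while `log_messages` becomes `log`, which posts a single `event` to `/telemetry/log_event` and returns nothing. A short sketch of the new surface (the event dict is a stub; its fields are defined in `telemetry_log_params`):

```python
from llama_stack import LlamaStack

client = LlamaStack()

# GET /telemetry/get_trace
trace = client.telemetry.get_trace(trace_id="trace-id")

# POST /telemetry/log_event; returns None
client.telemetry.log(event={})  # Event stub; see telemetry_log_params
```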
extra_headers: Headers | None = None, @@ -181,10 +187,8 @@ async def log_messages( """ extra_headers = {"Accept": "*/*", **(extra_headers or {})} return await self._post( - "/logging/log_messages", - body=await async_maybe_transform( - {"request": request}, logging_log_messages_params.LoggingLogMessagesParams - ), + "/telemetry/log_event", + body=await async_maybe_transform({"event": event}, telemetry_log_params.TelemetryLogParams), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -192,49 +196,49 @@ async def log_messages( ) -class LoggingResourceWithRawResponse: - def __init__(self, logging: LoggingResource) -> None: - self._logging = logging +class TelemetryResourceWithRawResponse: + def __init__(self, telemetry: TelemetryResource) -> None: + self._telemetry = telemetry - self.get_logs = to_raw_response_wrapper( - logging.get_logs, + self.get_trace = to_raw_response_wrapper( + telemetry.get_trace, ) - self.log_messages = to_raw_response_wrapper( - logging.log_messages, + self.log = to_raw_response_wrapper( + telemetry.log, ) -class AsyncLoggingResourceWithRawResponse: - def __init__(self, logging: AsyncLoggingResource) -> None: - self._logging = logging +class AsyncTelemetryResourceWithRawResponse: + def __init__(self, telemetry: AsyncTelemetryResource) -> None: + self._telemetry = telemetry - self.get_logs = async_to_raw_response_wrapper( - logging.get_logs, + self.get_trace = async_to_raw_response_wrapper( + telemetry.get_trace, ) - self.log_messages = async_to_raw_response_wrapper( - logging.log_messages, + self.log = async_to_raw_response_wrapper( + telemetry.log, ) -class LoggingResourceWithStreamingResponse: - def __init__(self, logging: LoggingResource) -> None: - self._logging = logging +class TelemetryResourceWithStreamingResponse: + def __init__(self, telemetry: TelemetryResource) -> None: + self._telemetry = telemetry - self.get_logs = to_streamed_response_wrapper( - logging.get_logs, + self.get_trace = to_streamed_response_wrapper( + telemetry.get_trace, ) - self.log_messages = to_streamed_response_wrapper( - logging.log_messages, + self.log = to_streamed_response_wrapper( + telemetry.log, ) -class AsyncLoggingResourceWithStreamingResponse: - def __init__(self, logging: AsyncLoggingResource) -> None: - self._logging = logging +class AsyncTelemetryResourceWithStreamingResponse: + def __init__(self, telemetry: AsyncTelemetryResource) -> None: + self._telemetry = telemetry - self.get_logs = async_to_streamed_response_wrapper( - logging.get_logs, + self.get_trace = async_to_streamed_response_wrapper( + telemetry.get_trace, ) - self.log_messages = async_to_streamed_response_wrapper( - logging.log_messages, + self.log = async_to_streamed_response_wrapper( + telemetry.log, ) diff --git a/src/llama_stack/types/__init__.py b/src/llama_stack/types/__init__.py index b445c57..cc354da 100644 --- a/src/llama_stack/types/__init__.py +++ b/src/llama_stack/types/__init__.py @@ -3,8 +3,6 @@ from __future__ import annotations from .shared import ( - Run as Run, - Artifact as Artifact, ToolCall as ToolCall, Attachment as Attachment, UserMessage as UserMessage, @@ -14,33 +12,27 @@ CompletionMessage as CompletionMessage, ToolResponseMessage as ToolResponseMessage, ) -from .experiment import Experiment as Experiment from .evaluation_job import EvaluationJob as EvaluationJob from .inference_step import InferenceStep as InferenceStep from .reward_scoring import RewardScoring as RewardScoring from .sheid_response import 
SheidResponse as SheidResponse from .query_documents import QueryDocuments as QueryDocuments +from .token_log_probs import TokenLogProbs as TokenLogProbs from .shield_call_step import ShieldCallStep as ShieldCallStep from .post_training_job import PostTrainingJob as PostTrainingJob -from .run_update_params import RunUpdateParams as RunUpdateParams from .dataset_get_params import DatasetGetParams as DatasetGetParams from .train_eval_dataset import TrainEvalDataset as TrainEvalDataset -from .artifact_get_params import ArtifactGetParams as ArtifactGetParams from .tool_execution_step import ToolExecutionStep as ToolExecutionStep +from .telemetry_log_params import TelemetryLogParams as TelemetryLogParams from .batch_chat_completion import BatchChatCompletion as BatchChatCompletion from .dataset_create_params import DatasetCreateParams as DatasetCreateParams from .dataset_delete_params import DatasetDeleteParams as DatasetDeleteParams from .memory_retrieval_step import MemoryRetrievalStep as MemoryRetrievalStep -from .run_log_metrics_params import RunLogMetricsParams as RunLogMetricsParams from .completion_stream_chunk import CompletionStreamChunk as CompletionStreamChunk -from .logging_get_logs_params import LoggingGetLogsParams as LoggingGetLogsParams from .memory_bank_drop_params import MemoryBankDropParams as MemoryBankDropParams from .shield_definition_param import ShieldDefinitionParam as ShieldDefinitionParam -from .experiment_create_params import ExperimentCreateParams as ExperimentCreateParams -from .experiment_update_params import ExperimentUpdateParams as ExperimentUpdateParams from .memory_bank_query_params import MemoryBankQueryParams as MemoryBankQueryParams from .train_eval_dataset_param import TrainEvalDatasetParam as TrainEvalDatasetParam -from .logging_get_logs_response import LoggingGetLogsResponse as LoggingGetLogsResponse from .memory_bank_create_params import MemoryBankCreateParams as MemoryBankCreateParams from .memory_bank_drop_response import MemoryBankDropResponse as MemoryBankDropResponse from .memory_bank_insert_params import MemoryBankInsertParams as MemoryBankInsertParams @@ -48,9 +40,8 @@ from .safety_run_shields_params import SafetyRunShieldsParams as SafetyRunShieldsParams from .scored_dialog_generations import ScoredDialogGenerations as ScoredDialogGenerations from .synthetic_data_generation import SyntheticDataGeneration as SyntheticDataGeneration -from .experiment_retrieve_params import ExperimentRetrieveParams as ExperimentRetrieveParams +from .telemetry_get_trace_params import TelemetryGetTraceParams as TelemetryGetTraceParams from .inference_completion_params import InferenceCompletionParams as InferenceCompletionParams -from .logging_log_messages_params import LoggingLogMessagesParams as LoggingLogMessagesParams from .memory_bank_retrieve_params import MemoryBankRetrieveParams as MemoryBankRetrieveParams from .reward_scoring_score_params import RewardScoringScoreParams as RewardScoringScoreParams from .safety_run_shields_response import SafetyRunShieldsResponse as SafetyRunShieldsResponse @@ -58,13 +49,18 @@ from .agentic_system_create_params import AgenticSystemCreateParams as AgenticSystemCreateParams from .agentic_system_delete_params import AgenticSystemDeleteParams as AgenticSystemDeleteParams from .chat_completion_stream_chunk import ChatCompletionStreamChunk as ChatCompletionStreamChunk -from .experiment_create_run_params import ExperimentCreateRunParams as ExperimentCreateRunParams +from .telemetry_get_trace_response import 
TelemetryGetTraceResponse as TelemetryGetTraceResponse +from .inference_completion_response import InferenceCompletionResponse as InferenceCompletionResponse from .agentic_system_create_response import AgenticSystemCreateResponse as AgenticSystemCreateResponse from .evaluation_summarization_params import EvaluationSummarizationParams as EvaluationSummarizationParams from .rest_api_execution_config_param import RestAPIExecutionConfigParam as RestAPIExecutionConfigParam from .inference_chat_completion_params import InferenceChatCompletionParams as InferenceChatCompletionParams +from .llm_query_generator_config_param import LlmQueryGeneratorConfigParam as LlmQueryGeneratorConfigParam from .batch_inference_completion_params import BatchInferenceCompletionParams as BatchInferenceCompletionParams from .evaluation_text_generation_params import EvaluationTextGenerationParams as EvaluationTextGenerationParams +from .inference_chat_completion_response import InferenceChatCompletionResponse as InferenceChatCompletionResponse +from .custom_query_generator_config_param import CustomQueryGeneratorConfigParam as CustomQueryGeneratorConfigParam +from .default_query_generator_config_param import DefaultQueryGeneratorConfigParam as DefaultQueryGeneratorConfigParam from .batch_inference_chat_completion_params import ( BatchInferenceChatCompletionParams as BatchInferenceChatCompletionParams, ) diff --git a/src/llama_stack/types/agentic_system/__init__.py b/src/llama_stack/types/agentic_system/__init__.py index dfbecc4..5db1683 100644 --- a/src/llama_stack/types/agentic_system/__init__.py +++ b/src/llama_stack/types/agentic_system/__init__.py @@ -4,6 +4,7 @@ from .turn import Turn as Turn from .session import Session as Session +from .turn_stream_event import TurnStreamEvent as TurnStreamEvent from .turn_create_params import TurnCreateParams as TurnCreateParams from .agentic_system_step import AgenticSystemStep as AgenticSystemStep from .step_retrieve_params import StepRetrieveParams as StepRetrieveParams diff --git a/src/llama_stack/types/agentic_system/agentic_system_turn_stream_chunk.py b/src/llama_stack/types/agentic_system/agentic_system_turn_stream_chunk.py index fff362c..0b6a2cd 100644 --- a/src/llama_stack/types/agentic_system/agentic_system_turn_stream_chunk.py +++ b/src/llama_stack/types/agentic_system/agentic_system_turn_stream_chunk.py @@ -1,103 +1,12 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
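After the re-export changes above, the added and relocated types can be imported from the public `types` modules, for example:

```python
from llama_stack.types import TokenLogProbs, TelemetryGetTraceResponse
from llama_stack.types.agentic_system import TurnStreamEvent
```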
-from typing import Dict, List, Union, Optional -from typing_extensions import Literal, TypeAlias -from pydantic import Field as FieldInfo -from .turn import Turn from ..._models import BaseModel -from ..inference_step import InferenceStep -from ..shared.tool_call import ToolCall -from ..shield_call_step import ShieldCallStep -from ..tool_execution_step import ToolExecutionStep -from ..memory_retrieval_step import MemoryRetrievalStep +from .turn_stream_event import TurnStreamEvent -__all__ = [ - "AgenticSystemTurnStreamChunk", - "Event", - "EventPayload", - "EventPayloadAgenticSystemTurnResponseStepStartPayload", - "EventPayloadAgenticSystemTurnResponseStepProgressPayload", - "EventPayloadAgenticSystemTurnResponseStepProgressPayloadToolCallDelta", - "EventPayloadAgenticSystemTurnResponseStepProgressPayloadToolCallDeltaContent", - "EventPayloadAgenticSystemTurnResponseStepCompletePayload", - "EventPayloadAgenticSystemTurnResponseStepCompletePayloadStepDetails", - "EventPayloadAgenticSystemTurnResponseTurnStartPayload", - "EventPayloadAgenticSystemTurnResponseTurnCompletePayload", -] - - -class EventPayloadAgenticSystemTurnResponseStepStartPayload(BaseModel): - event_type: Literal["step_start"] - - step_id: str - - step_type: Literal["inference", "tool_execution", "shield_call", "memory_retrieval"] - - metadata: Optional[Dict[str, Union[bool, float, str, List[object], object, None]]] = None - - -EventPayloadAgenticSystemTurnResponseStepProgressPayloadToolCallDeltaContent: TypeAlias = Union[str, ToolCall] - - -class EventPayloadAgenticSystemTurnResponseStepProgressPayloadToolCallDelta(BaseModel): - content: EventPayloadAgenticSystemTurnResponseStepProgressPayloadToolCallDeltaContent - - parse_status: Literal["started", "in_progress", "failure", "success"] - - -class EventPayloadAgenticSystemTurnResponseStepProgressPayload(BaseModel): - event_type: Literal["step_progress"] - - step_id: str - - step_type: Literal["inference", "tool_execution", "shield_call", "memory_retrieval"] - - text_delta_model_response: Optional[str] = FieldInfo(alias="model_response_text_delta", default=None) - - tool_call_delta: Optional[EventPayloadAgenticSystemTurnResponseStepProgressPayloadToolCallDelta] = None - - tool_response_text_delta: Optional[str] = None - - -EventPayloadAgenticSystemTurnResponseStepCompletePayloadStepDetails: TypeAlias = Union[ - InferenceStep, ToolExecutionStep, ShieldCallStep, MemoryRetrievalStep -] - - -class EventPayloadAgenticSystemTurnResponseStepCompletePayload(BaseModel): - event_type: Literal["step_complete"] - - step_details: EventPayloadAgenticSystemTurnResponseStepCompletePayloadStepDetails - - step_type: Literal["inference", "tool_execution", "shield_call", "memory_retrieval"] - - -class EventPayloadAgenticSystemTurnResponseTurnStartPayload(BaseModel): - event_type: Literal["turn_start"] - - turn_id: str - - -class EventPayloadAgenticSystemTurnResponseTurnCompletePayload(BaseModel): - event_type: Literal["turn_complete"] - - turn: Turn - - -EventPayload: TypeAlias = Union[ - EventPayloadAgenticSystemTurnResponseStepStartPayload, - EventPayloadAgenticSystemTurnResponseStepProgressPayload, - EventPayloadAgenticSystemTurnResponseStepCompletePayload, - EventPayloadAgenticSystemTurnResponseTurnStartPayload, - EventPayloadAgenticSystemTurnResponseTurnCompletePayload, -] - - -class Event(BaseModel): - payload: EventPayload +__all__ = ["AgenticSystemTurnStreamChunk"] class AgenticSystemTurnStreamChunk(BaseModel): - event: Event + event: TurnStreamEvent diff --git 
a/src/llama_stack/types/agentic_system/turn_create_params.py b/src/llama_stack/types/agentic_system/turn_create_params.py index b0c1969..ffbb54c 100644 --- a/src/llama_stack/types/agentic_system/turn_create_params.py +++ b/src/llama_stack/types/agentic_system/turn_create_params.py @@ -2,226 +2,35 @@ from __future__ import annotations -from typing import Dict, List, Union, Iterable +from typing import Union, Iterable from typing_extensions import Literal, Required, TypeAlias, TypedDict -from ..shield_definition_param import ShieldDefinitionParam from ..shared_params.attachment import Attachment from ..shared_params.user_message import UserMessage -from ..tool_param_definition_param import ToolParamDefinitionParam -from ..shared_params.sampling_params import SamplingParams -from ..rest_api_execution_config_param import RestAPIExecutionConfigParam from ..shared_params.tool_response_message import ToolResponseMessage -__all__ = [ - "TurnCreateParams", - "Request", - "RequestMessage", - "RequestTool", - "RequestToolSearchToolDefinition", - "RequestToolWolframAlphaToolDefinition", - "RequestToolPhotogenToolDefinition", - "RequestToolCodeInterpreterToolDefinition", - "RequestToolFunctionCallToolDefinition", - "RequestToolUnionMember5", - "RequestToolUnionMember5MemoryBankConfig", - "RequestToolUnionMember5MemoryBankConfigUnionMember0", - "RequestToolUnionMember5MemoryBankConfigUnionMember1", - "RequestToolUnionMember5MemoryBankConfigUnionMember2", - "RequestToolUnionMember5MemoryBankConfigUnionMember3", - "RequestToolUnionMember5QueryGeneratorConfig", - "RequestToolUnionMember5QueryGeneratorConfigUnionMember0", - "RequestToolUnionMember5QueryGeneratorConfigUnionMember1", - "RequestToolUnionMember5QueryGeneratorConfigType", -] +__all__ = ["TurnCreateParamsBase", "Message", "TurnCreateParamsNonStreaming", "TurnCreateParamsStreaming"] -class TurnCreateParams(TypedDict, total=False): - request: Required[Request] - - -RequestMessage: TypeAlias = Union[UserMessage, ToolResponseMessage] - - -class RequestToolSearchToolDefinition(TypedDict, total=False): - engine: Required[Literal["bing", "brave"]] - - type: Required[Literal["brave_search"]] - - input_shields: Iterable[ShieldDefinitionParam] - - output_shields: Iterable[ShieldDefinitionParam] - - remote_execution: RestAPIExecutionConfigParam - - -class RequestToolWolframAlphaToolDefinition(TypedDict, total=False): - type: Required[Literal["wolfram_alpha"]] - - input_shields: Iterable[ShieldDefinitionParam] - - output_shields: Iterable[ShieldDefinitionParam] - - remote_execution: RestAPIExecutionConfigParam - - -class RequestToolPhotogenToolDefinition(TypedDict, total=False): - type: Required[Literal["photogen"]] - - input_shields: Iterable[ShieldDefinitionParam] - - output_shields: Iterable[ShieldDefinitionParam] - - remote_execution: RestAPIExecutionConfigParam - - -class RequestToolCodeInterpreterToolDefinition(TypedDict, total=False): - enable_inline_code_execution: Required[bool] - - type: Required[Literal["code_interpreter"]] - - input_shields: Iterable[ShieldDefinitionParam] - - output_shields: Iterable[ShieldDefinitionParam] - - remote_execution: RestAPIExecutionConfigParam - - -class RequestToolFunctionCallToolDefinition(TypedDict, total=False): - description: Required[str] - - function_name: Required[str] - - parameters: Required[Dict[str, ToolParamDefinitionParam]] - - type: Required[Literal["function_call"]] - - input_shields: Iterable[ShieldDefinitionParam] - - output_shields: Iterable[ShieldDefinitionParam] - - remote_execution: 
RestAPIExecutionConfigParam - - -class RequestToolUnionMember5MemoryBankConfigUnionMember0(TypedDict, total=False): - bank_id: Required[str] - - type: Required[Literal["vector"]] - - -class RequestToolUnionMember5MemoryBankConfigUnionMember1(TypedDict, total=False): - bank_id: Required[str] - - keys: Required[List[str]] - - type: Required[Literal["keyvalue"]] - - -class RequestToolUnionMember5MemoryBankConfigUnionMember2(TypedDict, total=False): - bank_id: Required[str] - - type: Required[Literal["keyword"]] - - -class RequestToolUnionMember5MemoryBankConfigUnionMember3(TypedDict, total=False): - bank_id: Required[str] - - entities: Required[List[str]] - - type: Required[Literal["graph"]] - - -RequestToolUnionMember5MemoryBankConfig: TypeAlias = Union[ - RequestToolUnionMember5MemoryBankConfigUnionMember0, - RequestToolUnionMember5MemoryBankConfigUnionMember1, - RequestToolUnionMember5MemoryBankConfigUnionMember2, - RequestToolUnionMember5MemoryBankConfigUnionMember3, -] - - -class RequestToolUnionMember5QueryGeneratorConfigUnionMember0(TypedDict, total=False): - sep: Required[str] - - type: Required[Literal["default"]] - - -class RequestToolUnionMember5QueryGeneratorConfigUnionMember1(TypedDict, total=False): - model: Required[str] - - template: Required[str] - - type: Required[Literal["llm"]] - - -class RequestToolUnionMember5QueryGeneratorConfigType(TypedDict, total=False): - type: Required[Literal["custom"]] - - -RequestToolUnionMember5QueryGeneratorConfig: TypeAlias = Union[ - RequestToolUnionMember5QueryGeneratorConfigUnionMember0, - RequestToolUnionMember5QueryGeneratorConfigUnionMember1, - RequestToolUnionMember5QueryGeneratorConfigType, -] - - -class RequestToolUnionMember5(TypedDict, total=False): - max_chunks: Required[int] - - max_tokens_in_context: Required[int] - - memory_bank_configs: Required[Iterable[RequestToolUnionMember5MemoryBankConfig]] - - query_generator_config: Required[RequestToolUnionMember5QueryGeneratorConfig] - - type: Required[Literal["memory"]] - - input_shields: Iterable[ShieldDefinitionParam] - - output_shields: Iterable[ShieldDefinitionParam] - - -RequestTool: TypeAlias = Union[ - RequestToolSearchToolDefinition, - RequestToolWolframAlphaToolDefinition, - RequestToolPhotogenToolDefinition, - RequestToolCodeInterpreterToolDefinition, - RequestToolFunctionCallToolDefinition, - RequestToolUnionMember5, -] - - -class Request(TypedDict, total=False): +class TurnCreateParamsBase(TypedDict, total=False): agent_id: Required[str] - messages: Required[Iterable[RequestMessage]] + messages: Required[Iterable[Message]] session_id: Required[str] attachments: Iterable[Attachment] - input_shields: Iterable[ShieldDefinitionParam] - - instructions: str - - output_shields: Iterable[ShieldDefinitionParam] - sampling_params: SamplingParams +Message: TypeAlias = Union[UserMessage, ToolResponseMessage] - stream: bool - tool_choice: Literal["auto", "required"] +class TurnCreateParamsNonStreaming(TurnCreateParamsBase, total=False): + stream: Literal[False] - tool_prompt_format: Literal["json", "function_tag"] - """ - `json` -- Refers to the json format for calling tools. The json format takes the - form like { "type": "function", "function" : { "name": "function_name", - "description": "function_description", "parameters": {...} } } - `function_tag` -- This is an example of how you could define your own user - defined format for making tool calls. 
The function_tag format looks like this, - (parameters) +class TurnCreateParamsStreaming(TurnCreateParamsBase): + stream: Required[Literal[True]] - The detailed prompts for each of these formats are added to llama cli - """ - tools: Iterable[RequestTool] +TurnCreateParams = Union[TurnCreateParamsNonStreaming, TurnCreateParamsStreaming] diff --git a/src/llama_stack/types/agentic_system/turn_stream_event.py b/src/llama_stack/types/agentic_system/turn_stream_event.py new file mode 100644 index 0000000..7e627ec --- /dev/null +++ b/src/llama_stack/types/agentic_system/turn_stream_event.py @@ -0,0 +1,98 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Dict, List, Union, Optional +from typing_extensions import Literal, TypeAlias + +from pydantic import Field as FieldInfo + +from .turn import Turn +from ..._models import BaseModel +from ..inference_step import InferenceStep +from ..shared.tool_call import ToolCall +from ..shield_call_step import ShieldCallStep +from ..tool_execution_step import ToolExecutionStep +from ..memory_retrieval_step import MemoryRetrievalStep + +__all__ = [ + "TurnStreamEvent", + "Payload", + "PayloadAgenticSystemTurnResponseStepStartPayload", + "PayloadAgenticSystemTurnResponseStepProgressPayload", + "PayloadAgenticSystemTurnResponseStepProgressPayloadToolCallDelta", + "PayloadAgenticSystemTurnResponseStepProgressPayloadToolCallDeltaContent", + "PayloadAgenticSystemTurnResponseStepCompletePayload", + "PayloadAgenticSystemTurnResponseStepCompletePayloadStepDetails", + "PayloadAgenticSystemTurnResponseTurnStartPayload", + "PayloadAgenticSystemTurnResponseTurnCompletePayload", +] + + +class PayloadAgenticSystemTurnResponseStepStartPayload(BaseModel): + event_type: Literal["step_start"] + + step_id: str + + step_type: Literal["inference", "tool_execution", "shield_call", "memory_retrieval"] + + metadata: Optional[Dict[str, Union[bool, float, str, List[object], object, None]]] = None + + +PayloadAgenticSystemTurnResponseStepProgressPayloadToolCallDeltaContent: TypeAlias = Union[str, ToolCall] + + +class PayloadAgenticSystemTurnResponseStepProgressPayloadToolCallDelta(BaseModel): + content: PayloadAgenticSystemTurnResponseStepProgressPayloadToolCallDeltaContent + + parse_status: Literal["started", "in_progress", "failure", "success"] + + +class PayloadAgenticSystemTurnResponseStepProgressPayload(BaseModel): + event_type: Literal["step_progress"] + + step_id: str + + step_type: Literal["inference", "tool_execution", "shield_call", "memory_retrieval"] + + text_delta_model_response: Optional[str] = FieldInfo(alias="model_response_text_delta", default=None) + + tool_call_delta: Optional[PayloadAgenticSystemTurnResponseStepProgressPayloadToolCallDelta] = None + + tool_response_text_delta: Optional[str] = None + + +PayloadAgenticSystemTurnResponseStepCompletePayloadStepDetails: TypeAlias = Union[ + InferenceStep, ToolExecutionStep, ShieldCallStep, MemoryRetrievalStep +] + + +class PayloadAgenticSystemTurnResponseStepCompletePayload(BaseModel): + event_type: Literal["step_complete"] + + step_details: PayloadAgenticSystemTurnResponseStepCompletePayloadStepDetails + + step_type: Literal["inference", "tool_execution", "shield_call", "memory_retrieval"] + + +class PayloadAgenticSystemTurnResponseTurnStartPayload(BaseModel): + event_type: Literal["turn_start"] + + turn_id: str + + +class PayloadAgenticSystemTurnResponseTurnCompletePayload(BaseModel): + event_type: Literal["turn_complete"] + + turn: Turn + + +Payload: TypeAlias 
= Union[ + PayloadAgenticSystemTurnResponseStepStartPayload, + PayloadAgenticSystemTurnResponseStepProgressPayload, + PayloadAgenticSystemTurnResponseStepCompletePayload, + PayloadAgenticSystemTurnResponseTurnStartPayload, + PayloadAgenticSystemTurnResponseTurnCompletePayload, +] + + +class TurnStreamEvent(BaseModel): + payload: Payload diff --git a/src/llama_stack/types/agentic_system_create_params.py b/src/llama_stack/types/agentic_system_create_params.py index 9fb9767..69290da 100644 --- a/src/llama_stack/types/agentic_system_create_params.py +++ b/src/llama_stack/types/agentic_system_create_params.py @@ -9,6 +9,9 @@ from .tool_param_definition_param import ToolParamDefinitionParam from .shared_params.sampling_params import SamplingParams from .rest_api_execution_config_param import RestAPIExecutionConfigParam +from .llm_query_generator_config_param import LlmQueryGeneratorConfigParam +from .custom_query_generator_config_param import CustomQueryGeneratorConfigParam +from .default_query_generator_config_param import DefaultQueryGeneratorConfigParam __all__ = [ "AgenticSystemCreateParams", @@ -19,16 +22,13 @@ "AgentConfigToolPhotogenToolDefinition", "AgentConfigToolCodeInterpreterToolDefinition", "AgentConfigToolFunctionCallToolDefinition", - "AgentConfigToolUnionMember5", - "AgentConfigToolUnionMember5MemoryBankConfig", - "AgentConfigToolUnionMember5MemoryBankConfigUnionMember0", - "AgentConfigToolUnionMember5MemoryBankConfigUnionMember1", - "AgentConfigToolUnionMember5MemoryBankConfigUnionMember2", - "AgentConfigToolUnionMember5MemoryBankConfigUnionMember3", - "AgentConfigToolUnionMember5QueryGeneratorConfig", - "AgentConfigToolUnionMember5QueryGeneratorConfigUnionMember0", - "AgentConfigToolUnionMember5QueryGeneratorConfigUnionMember1", - "AgentConfigToolUnionMember5QueryGeneratorConfigType", + "AgentConfigToolShield", + "AgentConfigToolShieldMemoryBankConfig", + "AgentConfigToolShieldMemoryBankConfigVector", + "AgentConfigToolShieldMemoryBankConfigKeyvalue", + "AgentConfigToolShieldMemoryBankConfigKeyword", + "AgentConfigToolShieldMemoryBankConfigGraph", + "AgentConfigToolShieldQueryGeneratorConfig", ] @@ -96,13 +96,13 @@ class AgentConfigToolFunctionCallToolDefinition(TypedDict, total=False): remote_execution: RestAPIExecutionConfigParam -class AgentConfigToolUnionMember5MemoryBankConfigUnionMember0(TypedDict, total=False): +class AgentConfigToolShieldMemoryBankConfigVector(TypedDict, total=False): bank_id: Required[str] type: Required[Literal["vector"]] -class AgentConfigToolUnionMember5MemoryBankConfigUnionMember1(TypedDict, total=False): +class AgentConfigToolShieldMemoryBankConfigKeyvalue(TypedDict, total=False): bank_id: Required[str] keys: Required[List[str]] @@ -110,13 +110,13 @@ class AgentConfigToolUnionMember5MemoryBankConfigUnionMember1(TypedDict, total=F type: Required[Literal["keyvalue"]] -class AgentConfigToolUnionMember5MemoryBankConfigUnionMember2(TypedDict, total=False): +class AgentConfigToolShieldMemoryBankConfigKeyword(TypedDict, total=False): bank_id: Required[str] type: Required[Literal["keyword"]] -class AgentConfigToolUnionMember5MemoryBankConfigUnionMember3(TypedDict, total=False): +class AgentConfigToolShieldMemoryBankConfigGraph(TypedDict, total=False): bank_id: Required[str] entities: Required[List[str]] @@ -124,47 +124,26 @@ class AgentConfigToolUnionMember5MemoryBankConfigUnionMember3(TypedDict, total=F type: Required[Literal["graph"]] -AgentConfigToolUnionMember5MemoryBankConfig: TypeAlias = Union[ - 
AgentConfigToolUnionMember5MemoryBankConfigUnionMember0, - AgentConfigToolUnionMember5MemoryBankConfigUnionMember1, - AgentConfigToolUnionMember5MemoryBankConfigUnionMember2, - AgentConfigToolUnionMember5MemoryBankConfigUnionMember3, +AgentConfigToolShieldMemoryBankConfig: TypeAlias = Union[ + AgentConfigToolShieldMemoryBankConfigVector, + AgentConfigToolShieldMemoryBankConfigKeyvalue, + AgentConfigToolShieldMemoryBankConfigKeyword, + AgentConfigToolShieldMemoryBankConfigGraph, ] - -class AgentConfigToolUnionMember5QueryGeneratorConfigUnionMember0(TypedDict, total=False): - sep: Required[str] - - type: Required[Literal["default"]] - - -class AgentConfigToolUnionMember5QueryGeneratorConfigUnionMember1(TypedDict, total=False): - model: Required[str] - - template: Required[str] - - type: Required[Literal["llm"]] - - -class AgentConfigToolUnionMember5QueryGeneratorConfigType(TypedDict, total=False): - type: Required[Literal["custom"]] - - -AgentConfigToolUnionMember5QueryGeneratorConfig: TypeAlias = Union[ - AgentConfigToolUnionMember5QueryGeneratorConfigUnionMember0, - AgentConfigToolUnionMember5QueryGeneratorConfigUnionMember1, - AgentConfigToolUnionMember5QueryGeneratorConfigType, +AgentConfigToolShieldQueryGeneratorConfig: TypeAlias = Union[ + DefaultQueryGeneratorConfigParam, LlmQueryGeneratorConfigParam, CustomQueryGeneratorConfigParam ] -class AgentConfigToolUnionMember5(TypedDict, total=False): +class AgentConfigToolShield(TypedDict, total=False): max_chunks: Required[int] max_tokens_in_context: Required[int] - memory_bank_configs: Required[Iterable[AgentConfigToolUnionMember5MemoryBankConfig]] + memory_bank_configs: Required[Iterable[AgentConfigToolShieldMemoryBankConfig]] - query_generator_config: Required[AgentConfigToolUnionMember5QueryGeneratorConfig] + query_generator_config: Required[AgentConfigToolShieldQueryGeneratorConfig] type: Required[Literal["memory"]] @@ -179,7 +158,7 @@ class AgentConfigToolUnionMember5(TypedDict, total=False): AgentConfigToolPhotogenToolDefinition, AgentConfigToolCodeInterpreterToolDefinition, AgentConfigToolFunctionCallToolDefinition, - AgentConfigToolUnionMember5, + AgentConfigToolShield, ] diff --git a/src/llama_stack/types/batch_inference_chat_completion_params.py b/src/llama_stack/types/batch_inference_chat_completion_params.py index 9c309c4..28ad798 100644 --- a/src/llama_stack/types/batch_inference_chat_completion_params.py +++ b/src/llama_stack/types/batch_inference_chat_completion_params.py @@ -12,34 +12,15 @@ from .shared_params.completion_message import CompletionMessage from .shared_params.tool_response_message import ToolResponseMessage -__all__ = ["BatchInferenceChatCompletionParams", "Request", "RequestMessagesBatch", "RequestLogprobs", "RequestTool"] +__all__ = ["BatchInferenceChatCompletionParams", "MessagesBatch", "Logprobs", "Tool"] class BatchInferenceChatCompletionParams(TypedDict, total=False): - request: Required[Request] - - -RequestMessagesBatch: TypeAlias = Union[UserMessage, SystemMessage, ToolResponseMessage, CompletionMessage] - - -class RequestLogprobs(TypedDict, total=False): - top_k: int - - -class RequestTool(TypedDict, total=False): - tool_name: Required[Union[Literal["brave_search", "wolfram_alpha", "photogen", "code_interpreter"], str]] - - description: str - - parameters: Dict[str, ToolParamDefinitionParam] - - -class Request(TypedDict, total=False): - messages_batch: Required[Iterable[Iterable[RequestMessagesBatch]]] + messages_batch: Required[Iterable[Iterable[MessagesBatch]]] model: Required[str] - logprobs: 
RequestLogprobs + logprobs: Logprobs sampling_params: SamplingParams @@ -58,4 +39,19 @@ class Request(TypedDict, total=False): The detailed prompts for each of these formats are added to llama cli """ - tools: Iterable[RequestTool] + tools: Iterable[Tool] + + +MessagesBatch: TypeAlias = Union[UserMessage, SystemMessage, ToolResponseMessage, CompletionMessage] + + +class Logprobs(TypedDict, total=False): + top_k: int + + +class Tool(TypedDict, total=False): + tool_name: Required[Union[Literal["brave_search", "wolfram_alpha", "photogen", "code_interpreter"], str]] + + description: str + + parameters: Dict[str, ToolParamDefinitionParam] diff --git a/src/llama_stack/types/batch_inference_completion_params.py b/src/llama_stack/types/batch_inference_completion_params.py index 2dfc485..398531a 100644 --- a/src/llama_stack/types/batch_inference_completion_params.py +++ b/src/llama_stack/types/batch_inference_completion_params.py @@ -7,22 +7,18 @@ from .shared_params.sampling_params import SamplingParams -__all__ = ["BatchInferenceCompletionParams", "Request", "RequestLogprobs"] +__all__ = ["BatchInferenceCompletionParams", "Logprobs"] class BatchInferenceCompletionParams(TypedDict, total=False): - request: Required[Request] - - -class RequestLogprobs(TypedDict, total=False): - top_k: int - - -class Request(TypedDict, total=False): content_batch: Required[List[Union[str, List[str]]]] model: Required[str] - logprobs: RequestLogprobs + logprobs: Logprobs sampling_params: SamplingParams + + +class Logprobs(TypedDict, total=False): + top_k: int diff --git a/src/llama_stack/types/chat_completion_stream_chunk.py b/src/llama_stack/types/chat_completion_stream_chunk.py index a433673..6a1d5c8 100644 --- a/src/llama_stack/types/chat_completion_stream_chunk.py +++ b/src/llama_stack/types/chat_completion_stream_chunk.py @@ -1,9 +1,10 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. -from typing import Dict, List, Union, Optional +from typing import List, Union, Optional from typing_extensions import Literal, TypeAlias from .._models import BaseModel +from .token_log_probs import TokenLogProbs from .shared.tool_call import ToolCall __all__ = [ @@ -12,7 +13,6 @@ "EventDelta", "EventDeltaToolCallDelta", "EventDeltaToolCallDeltaContent", - "EventLogprob", ] EventDeltaToolCallDeltaContent: TypeAlias = Union[str, ToolCall] @@ -27,16 +27,12 @@ class EventDeltaToolCallDelta(BaseModel): EventDelta: TypeAlias = Union[str, EventDeltaToolCallDelta] -class EventLogprob(BaseModel): - logprobs_by_token: Dict[str, float] - - class Event(BaseModel): delta: EventDelta event_type: Literal["start", "complete", "progress"] - logprobs: Optional[List[EventLogprob]] = None + logprobs: Optional[List[TokenLogProbs]] = None stop_reason: Optional[Literal["end_of_turn", "end_of_message", "out_of_tokens"]] = None diff --git a/src/llama_stack/types/completion_stream_chunk.py b/src/llama_stack/types/completion_stream_chunk.py index 9d188af..ff445db 100644 --- a/src/llama_stack/types/completion_stream_chunk.py +++ b/src/llama_stack/types/completion_stream_chunk.py @@ -1,20 +1,17 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
-from typing import Dict, List, Optional +from typing import List, Optional from typing_extensions import Literal from .._models import BaseModel +from .token_log_probs import TokenLogProbs -__all__ = ["CompletionStreamChunk", "Logprob"] - - -class Logprob(BaseModel): - logprobs_by_token: Dict[str, float] +__all__ = ["CompletionStreamChunk"] class CompletionStreamChunk(BaseModel): delta: str - logprobs: Optional[List[Logprob]] = None + logprobs: Optional[List[TokenLogProbs]] = None stop_reason: Optional[Literal["end_of_turn", "end_of_message", "out_of_tokens"]] = None diff --git a/src/llama_stack/types/custom_query_generator_config_param.py b/src/llama_stack/types/custom_query_generator_config_param.py new file mode 100644 index 0000000..432450c --- /dev/null +++ b/src/llama_stack/types/custom_query_generator_config_param.py @@ -0,0 +1,11 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import Literal, Required, TypedDict + +__all__ = ["CustomQueryGeneratorConfigParam"] + + +class CustomQueryGeneratorConfigParam(TypedDict, total=False): + type: Required[Literal["custom"]] diff --git a/src/llama_stack/types/dataset_create_params.py b/src/llama_stack/types/dataset_create_params.py index ab2e62b..971ddb5 100644 --- a/src/llama_stack/types/dataset_create_params.py +++ b/src/llama_stack/types/dataset_create_params.py @@ -6,14 +6,10 @@ from .train_eval_dataset_param import TrainEvalDatasetParam -__all__ = ["DatasetCreateParams", "Request"] +__all__ = ["DatasetCreateParams"] class DatasetCreateParams(TypedDict, total=False): - request: Required[Request] - - -class Request(TypedDict, total=False): dataset: Required[TrainEvalDatasetParam] uuid: Required[str] diff --git a/src/llama_stack/types/default_query_generator_config_param.py b/src/llama_stack/types/default_query_generator_config_param.py new file mode 100644 index 0000000..2aaaa81 --- /dev/null +++ b/src/llama_stack/types/default_query_generator_config_param.py @@ -0,0 +1,13 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
+ +from __future__ import annotations + +from typing_extensions import Literal, Required, TypedDict + +__all__ = ["DefaultQueryGeneratorConfigParam"] + + +class DefaultQueryGeneratorConfigParam(TypedDict, total=False): + sep: Required[str] + + type: Required[Literal["default"]] diff --git a/src/llama_stack/types/evaluate/jobs/__init__.py b/src/llama_stack/types/evaluate/jobs/__init__.py old mode 100755 new mode 100644 diff --git a/src/llama_stack/types/evaluate/jobs/artifact_list_params.py b/src/llama_stack/types/evaluate/jobs/artifact_list_params.py old mode 100755 new mode 100644 diff --git a/src/llama_stack/types/evaluate/jobs/log_list_params.py b/src/llama_stack/types/evaluate/jobs/log_list_params.py old mode 100755 new mode 100644 diff --git a/src/llama_stack/types/evaluate/jobs/status_list_params.py b/src/llama_stack/types/evaluate/jobs/status_list_params.py old mode 100755 new mode 100644 diff --git a/src/llama_stack/types/evaluate/question_answering_create_params.py b/src/llama_stack/types/evaluate/question_answering_create_params.py index 1b2d6ec..8477717 100644 --- a/src/llama_stack/types/evaluate/question_answering_create_params.py +++ b/src/llama_stack/types/evaluate/question_answering_create_params.py @@ -5,24 +5,8 @@ from typing import List from typing_extensions import Literal, Required, TypedDict -from ..train_eval_dataset_param import TrainEvalDatasetParam -from ..shared_params.sampling_params import SamplingParams - -__all__ = ["QuestionAnsweringCreateParams", "Request"] +__all__ = ["QuestionAnsweringCreateParams"] class QuestionAnsweringCreateParams(TypedDict, total=False): - request: Required[Request] - - -class Request(TypedDict, total=False): - checkpoint: Required[object] - """Checkpoint created during training runs""" - - dataset: Required[TrainEvalDatasetParam] - - job_uuid: Required[str] - metrics: Required[List[Literal["em", "f1"]]] - - sampling_params: Required[SamplingParams] diff --git a/src/llama_stack/types/evaluation_summarization_params.py b/src/llama_stack/types/evaluation_summarization_params.py index e08a3c1..34542d6 100644 --- a/src/llama_stack/types/evaluation_summarization_params.py +++ b/src/llama_stack/types/evaluation_summarization_params.py @@ -5,24 +5,8 @@ from typing import List from typing_extensions import Literal, Required, TypedDict -from .train_eval_dataset_param import TrainEvalDatasetParam -from .shared_params.sampling_params import SamplingParams - -__all__ = ["EvaluationSummarizationParams", "Request"] +__all__ = ["EvaluationSummarizationParams"] class EvaluationSummarizationParams(TypedDict, total=False): - request: Required[Request] - - -class Request(TypedDict, total=False): - checkpoint: Required[object] - """Checkpoint created during training runs""" - - dataset: Required[TrainEvalDatasetParam] - - job_uuid: Required[str] - metrics: Required[List[Literal["rouge", "bleu"]]] - - sampling_params: Required[SamplingParams] diff --git a/src/llama_stack/types/evaluation_text_generation_params.py b/src/llama_stack/types/evaluation_text_generation_params.py index 6fff1fd..deaec66 100644 --- a/src/llama_stack/types/evaluation_text_generation_params.py +++ b/src/llama_stack/types/evaluation_text_generation_params.py @@ -5,24 +5,8 @@ from typing import List from typing_extensions import Literal, Required, TypedDict -from .train_eval_dataset_param import TrainEvalDatasetParam -from .shared_params.sampling_params import SamplingParams - -__all__ = ["EvaluationTextGenerationParams", "Request"] +__all__ = ["EvaluationTextGenerationParams"] class 
EvaluationTextGenerationParams(TypedDict, total=False): - request: Required[Request] - - -class Request(TypedDict, total=False): - checkpoint: Required[object] - """Checkpoint created during training runs""" - - dataset: Required[TrainEvalDatasetParam] - - job_uuid: Required[str] - metrics: Required[List[Literal["perplexity", "rouge", "bleu"]]] - - sampling_params: Required[SamplingParams] diff --git a/src/llama_stack/types/experiment.py b/src/llama_stack/types/experiment.py deleted file mode 100644 index b09fb14..0000000 --- a/src/llama_stack/types/experiment.py +++ /dev/null @@ -1,23 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -from typing import Dict, List, Union -from datetime import datetime -from typing_extensions import Literal - -from .._models import BaseModel - -__all__ = ["Experiment"] - - -class Experiment(BaseModel): - id: str - - created_at: datetime - - metadata: Dict[str, Union[bool, float, str, List[object], object, None]] - - name: str - - status: Literal["not_started", "running", "completed", "failed"] - - updated_at: datetime diff --git a/src/llama_stack/types/experiment_create_params.py b/src/llama_stack/types/experiment_create_params.py deleted file mode 100644 index ecf42ad..0000000 --- a/src/llama_stack/types/experiment_create_params.py +++ /dev/null @@ -1,18 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -from __future__ import annotations - -from typing import Dict, Union, Iterable -from typing_extensions import Required, TypedDict - -__all__ = ["ExperimentCreateParams", "Request"] - - -class ExperimentCreateParams(TypedDict, total=False): - request: Required[Request] - - -class Request(TypedDict, total=False): - name: Required[str] - - metadata: Dict[str, Union[bool, float, str, Iterable[object], object, None]] diff --git a/src/llama_stack/types/experiment_create_run_params.py b/src/llama_stack/types/experiment_create_run_params.py deleted file mode 100644 index fb49a61..0000000 --- a/src/llama_stack/types/experiment_create_run_params.py +++ /dev/null @@ -1,18 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -from __future__ import annotations - -from typing import Dict, Union, Iterable -from typing_extensions import Required, TypedDict - -__all__ = ["ExperimentCreateRunParams", "Request"] - - -class ExperimentCreateRunParams(TypedDict, total=False): - request: Required[Request] - - -class Request(TypedDict, total=False): - experiment_id: Required[str] - - metadata: Dict[str, Union[bool, float, str, Iterable[object], object, None]] diff --git a/src/llama_stack/types/experiment_retrieve_params.py b/src/llama_stack/types/experiment_retrieve_params.py deleted file mode 100644 index 0ee30c0..0000000 --- a/src/llama_stack/types/experiment_retrieve_params.py +++ /dev/null @@ -1,11 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -from __future__ import annotations - -from typing_extensions import Required, TypedDict - -__all__ = ["ExperimentRetrieveParams"] - - -class ExperimentRetrieveParams(TypedDict, total=False): - experiment_id: Required[str] diff --git a/src/llama_stack/types/experiment_update_params.py b/src/llama_stack/types/experiment_update_params.py deleted file mode 100644 index 8591773..0000000 --- a/src/llama_stack/types/experiment_update_params.py +++ /dev/null @@ -1,20 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. 
See CONTRIBUTING.md for details. - -from __future__ import annotations - -from typing import Dict, Union, Iterable -from typing_extensions import Literal, Required, TypedDict - -__all__ = ["ExperimentUpdateParams", "Request"] - - -class ExperimentUpdateParams(TypedDict, total=False): - request: Required[Request] - - -class Request(TypedDict, total=False): - experiment_id: Required[str] - - metadata: Dict[str, Union[bool, float, str, Iterable[object], object, None]] - - status: Literal["not_started", "running", "completed", "failed"] diff --git a/src/llama_stack/types/experiments/__init__.py b/src/llama_stack/types/experiments/__init__.py deleted file mode 100644 index 91e06eb..0000000 --- a/src/llama_stack/types/experiments/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -from __future__ import annotations - -from .artifact_upload_params import ArtifactUploadParams as ArtifactUploadParams -from .artifact_retrieve_params import ArtifactRetrieveParams as ArtifactRetrieveParams diff --git a/src/llama_stack/types/experiments/artifact_retrieve_params.py b/src/llama_stack/types/experiments/artifact_retrieve_params.py deleted file mode 100644 index 76b022e..0000000 --- a/src/llama_stack/types/experiments/artifact_retrieve_params.py +++ /dev/null @@ -1,11 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -from __future__ import annotations - -from typing_extensions import Required, TypedDict - -__all__ = ["ArtifactRetrieveParams"] - - -class ArtifactRetrieveParams(TypedDict, total=False): - experiment_id: Required[str] diff --git a/src/llama_stack/types/experiments/artifact_upload_params.py b/src/llama_stack/types/experiments/artifact_upload_params.py deleted file mode 100644 index 590bdc2..0000000 --- a/src/llama_stack/types/experiments/artifact_upload_params.py +++ /dev/null @@ -1,24 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
- -from __future__ import annotations - -from typing import Dict, Union, Iterable -from typing_extensions import Required, TypedDict - -__all__ = ["ArtifactUploadParams", "Request"] - - -class ArtifactUploadParams(TypedDict, total=False): - request: Required[Request] - - -class Request(TypedDict, total=False): - artifact_type: Required[str] - - content: Required[str] - - experiment_id: Required[str] - - name: Required[str] - - metadata: Dict[str, Union[bool, float, str, Iterable[object], object, None]] diff --git a/src/llama_stack/types/inference_chat_completion_params.py b/src/llama_stack/types/inference_chat_completion_params.py index 37fa9c8..af21934 100644 --- a/src/llama_stack/types/inference_chat_completion_params.py +++ b/src/llama_stack/types/inference_chat_completion_params.py @@ -12,39 +12,25 @@ from .shared_params.completion_message import CompletionMessage from .shared_params.tool_response_message import ToolResponseMessage -__all__ = ["InferenceChatCompletionParams", "Request", "RequestMessage", "RequestLogprobs", "RequestTool"] +__all__ = [ + "InferenceChatCompletionParamsBase", + "Message", + "Logprobs", + "Tool", + "InferenceChatCompletionParamsNonStreaming", + "InferenceChatCompletionParamsStreaming", +] -class InferenceChatCompletionParams(TypedDict, total=False): - request: Required[Request] - - -RequestMessage: TypeAlias = Union[UserMessage, SystemMessage, ToolResponseMessage, CompletionMessage] - - -class RequestLogprobs(TypedDict, total=False): - top_k: int - - -class RequestTool(TypedDict, total=False): - tool_name: Required[Union[Literal["brave_search", "wolfram_alpha", "photogen", "code_interpreter"], str]] - - description: str - - parameters: Dict[str, ToolParamDefinitionParam] - - -class Request(TypedDict, total=False): - messages: Required[Iterable[RequestMessage]] +class InferenceChatCompletionParamsBase(TypedDict, total=False): + messages: Required[Iterable[Message]] model: Required[str] - logprobs: RequestLogprobs + logprobs: Logprobs sampling_params: SamplingParams - stream: bool - tool_choice: Literal["auto", "required"] tool_prompt_format: Literal["json", "function_tag"] @@ -60,4 +46,30 @@ class Request(TypedDict, total=False): The detailed prompts for each of these formats are added to llama cli """ - tools: Iterable[RequestTool] + tools: Iterable[Tool] + + +Message: TypeAlias = Union[UserMessage, SystemMessage, ToolResponseMessage, CompletionMessage] + + +class Logprobs(TypedDict, total=False): + top_k: int + + +class Tool(TypedDict, total=False): + tool_name: Required[Union[Literal["brave_search", "wolfram_alpha", "photogen", "code_interpreter"], str]] + + description: str + + parameters: Dict[str, ToolParamDefinitionParam] + + +class InferenceChatCompletionParamsNonStreaming(InferenceChatCompletionParamsBase, total=False): + stream: Literal[False] + + +class InferenceChatCompletionParamsStreaming(InferenceChatCompletionParamsBase): + stream: Required[Literal[True]] + + +InferenceChatCompletionParams = Union[InferenceChatCompletionParamsNonStreaming, InferenceChatCompletionParamsStreaming] diff --git a/src/llama_stack/types/inference_chat_completion_response.py b/src/llama_stack/types/inference_chat_completion_response.py new file mode 100644 index 0000000..2cf4254 --- /dev/null +++ b/src/llama_stack/types/inference_chat_completion_response.py @@ -0,0 +1,20 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
+ +from typing import List, Union, Optional +from typing_extensions import TypeAlias + +from .._models import BaseModel +from .token_log_probs import TokenLogProbs +from .shared.completion_message import CompletionMessage +from .chat_completion_stream_chunk import ChatCompletionStreamChunk + +__all__ = ["InferenceChatCompletionResponse", "ChatCompletionResponse"] + + +class ChatCompletionResponse(BaseModel): + completion_message: CompletionMessage + + logprobs: Optional[List[TokenLogProbs]] = None + + +InferenceChatCompletionResponse: TypeAlias = Union[ChatCompletionResponse, ChatCompletionStreamChunk] diff --git a/src/llama_stack/types/inference_completion_params.py b/src/llama_stack/types/inference_completion_params.py index 83f8eeb..79544b0 100644 --- a/src/llama_stack/types/inference_completion_params.py +++ b/src/llama_stack/types/inference_completion_params.py @@ -7,24 +7,20 @@ from .shared_params.sampling_params import SamplingParams -__all__ = ["InferenceCompletionParams", "Request", "RequestLogprobs"] +__all__ = ["InferenceCompletionParams", "Logprobs"] class InferenceCompletionParams(TypedDict, total=False): - request: Required[Request] - - -class RequestLogprobs(TypedDict, total=False): - top_k: int - - -class Request(TypedDict, total=False): content: Required[Union[str, List[str]]] model: Required[str] - logprobs: RequestLogprobs + logprobs: Logprobs sampling_params: SamplingParams stream: bool + + +class Logprobs(TypedDict, total=False): + top_k: int diff --git a/src/llama_stack/types/inference_completion_response.py b/src/llama_stack/types/inference_completion_response.py new file mode 100644 index 0000000..5fa75ce --- /dev/null +++ b/src/llama_stack/types/inference_completion_response.py @@ -0,0 +1,20 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import List, Union, Optional +from typing_extensions import TypeAlias + +from .._models import BaseModel +from .token_log_probs import TokenLogProbs +from .completion_stream_chunk import CompletionStreamChunk +from .shared.completion_message import CompletionMessage + +__all__ = ["InferenceCompletionResponse", "CompletionResponse"] + + +class CompletionResponse(BaseModel): + completion_message: CompletionMessage + + logprobs: Optional[List[TokenLogProbs]] = None + + +InferenceCompletionResponse: TypeAlias = Union[CompletionResponse, CompletionStreamChunk] diff --git a/src/llama_stack/types/llm_query_generator_config_param.py b/src/llama_stack/types/llm_query_generator_config_param.py new file mode 100644 index 0000000..8d6bd31 --- /dev/null +++ b/src/llama_stack/types/llm_query_generator_config_param.py @@ -0,0 +1,15 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import Literal, Required, TypedDict + +__all__ = ["LlmQueryGeneratorConfigParam"] + + +class LlmQueryGeneratorConfigParam(TypedDict, total=False): + model: Required[str] + + template: Required[str] + + type: Required[Literal["llm"]] diff --git a/src/llama_stack/types/logging_get_logs_params.py b/src/llama_stack/types/logging_get_logs_params.py deleted file mode 100644 index 68a47f3..0000000 --- a/src/llama_stack/types/logging_get_logs_params.py +++ /dev/null @@ -1,18 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
- -from __future__ import annotations - -from typing import Dict, Union, Iterable -from typing_extensions import Required, TypedDict - -__all__ = ["LoggingGetLogsParams", "Request"] - - -class LoggingGetLogsParams(TypedDict, total=False): - request: Required[Request] - - -class Request(TypedDict, total=False): - query: Required[str] - - filters: Dict[str, Union[bool, float, str, Iterable[object], object, None]] diff --git a/src/llama_stack/types/logging_get_logs_response.py b/src/llama_stack/types/logging_get_logs_response.py deleted file mode 100644 index 0d02b17..0000000 --- a/src/llama_stack/types/logging_get_logs_response.py +++ /dev/null @@ -1,18 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -from typing import Dict, List, Union -from datetime import datetime - -from .._models import BaseModel - -__all__ = ["LoggingGetLogsResponse"] - - -class LoggingGetLogsResponse(BaseModel): - additional_info: Dict[str, Union[bool, float, str, List[object], object, None]] - - level: str - - message: str - - timestamp: datetime diff --git a/src/llama_stack/types/logging_log_messages_params.py b/src/llama_stack/types/logging_log_messages_params.py deleted file mode 100644 index fafe9a4..0000000 --- a/src/llama_stack/types/logging_log_messages_params.py +++ /dev/null @@ -1,31 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -from __future__ import annotations - -from typing import Dict, Union, Iterable -from datetime import datetime -from typing_extensions import Required, Annotated, TypedDict - -from .._utils import PropertyInfo - -__all__ = ["LoggingLogMessagesParams", "Request", "RequestLog"] - - -class LoggingLogMessagesParams(TypedDict, total=False): - request: Required[Request] - - -class RequestLog(TypedDict, total=False): - additional_info: Required[Dict[str, Union[bool, float, str, Iterable[object], object, None]]] - - level: Required[str] - - message: Required[str] - - timestamp: Required[Annotated[Union[str, datetime], PropertyInfo(format="iso8601")]] - - -class Request(TypedDict, total=False): - logs: Required[Iterable[RequestLog]] - - run_id: str diff --git a/src/llama_stack/types/post_training_preference_optimize_params.py b/src/llama_stack/types/post_training_preference_optimize_params.py index ff23a3b..9e6f3cc 100644 --- a/src/llama_stack/types/post_training_preference_optimize_params.py +++ b/src/llama_stack/types/post_training_preference_optimize_params.py @@ -7,20 +7,32 @@ from .train_eval_dataset_param import TrainEvalDatasetParam -__all__ = [ - "PostTrainingPreferenceOptimizeParams", - "Request", - "RequestAlgorithmConfig", - "RequestOptimizerConfig", - "RequestTrainingConfig", -] +__all__ = ["PostTrainingPreferenceOptimizeParams", "AlgorithmConfig", "OptimizerConfig", "TrainingConfig"] class PostTrainingPreferenceOptimizeParams(TypedDict, total=False): - request: Required[Request] + algorithm: Required[Literal["dpo"]] + + algorithm_config: Required[AlgorithmConfig] + dataset: Required[TrainEvalDatasetParam] + + finetuned_model: Required[str] + + hyperparam_search_config: Required[Dict[str, Union[bool, float, str, Iterable[object], object, None]]] + + job_uuid: Required[str] -class RequestAlgorithmConfig(TypedDict, total=False): + logger_config: Required[Dict[str, Union[bool, float, str, Iterable[object], object, None]]] + + optimizer_config: Required[OptimizerConfig] + + training_config: Required[TrainingConfig] + + validation_dataset: Required[TrainEvalDatasetParam] + + +class 
AlgorithmConfig(TypedDict, total=False): epsilon: Required[float] gamma: Required[float] @@ -30,7 +42,7 @@ class RequestAlgorithmConfig(TypedDict, total=False): reward_scale: Required[float] -class RequestOptimizerConfig(TypedDict, total=False): +class OptimizerConfig(TypedDict, total=False): lr: Required[float] lr_min: Required[float] @@ -40,7 +52,7 @@ class RequestOptimizerConfig(TypedDict, total=False): weight_decay: Required[float] -class RequestTrainingConfig(TypedDict, total=False): +class TrainingConfig(TypedDict, total=False): batch_size: Required[int] enable_activation_checkpointing: Required[bool] @@ -54,25 +66,3 @@ class RequestTrainingConfig(TypedDict, total=False): n_iters: Required[int] shuffle: Required[bool] - - -class Request(TypedDict, total=False): - algorithm: Required[Literal["dpo"]] - - algorithm_config: Required[RequestAlgorithmConfig] - - dataset: Required[TrainEvalDatasetParam] - - finetuned_model: Required[str] - - hyperparam_search_config: Required[Dict[str, Union[bool, float, str, Iterable[object], object, None]]] - - job_uuid: Required[str] - - logger_config: Required[Dict[str, Union[bool, float, str, Iterable[object], object, None]]] - - optimizer_config: Required[RequestOptimizerConfig] - - training_config: Required[RequestTrainingConfig] - - validation_dataset: Required[TrainEvalDatasetParam] diff --git a/src/llama_stack/types/post_training_supervised_fine_tune_params.py b/src/llama_stack/types/post_training_supervised_fine_tune_params.py index 5d09768..36f776d 100644 --- a/src/llama_stack/types/post_training_supervised_fine_tune_params.py +++ b/src/llama_stack/types/post_training_supervised_fine_tune_params.py @@ -9,21 +9,38 @@ __all__ = [ "PostTrainingSupervisedFineTuneParams", - "Request", - "RequestAlgorithmConfig", - "RequestAlgorithmConfigLoraFinetuningConfig", - "RequestAlgorithmConfigQLoraFinetuningConfig", - "RequestAlgorithmConfigDoraFinetuningConfig", - "RequestOptimizerConfig", - "RequestTrainingConfig", + "AlgorithmConfig", + "AlgorithmConfigLoraFinetuningConfig", + "AlgorithmConfigQLoraFinetuningConfig", + "AlgorithmConfigDoraFinetuningConfig", + "OptimizerConfig", + "TrainingConfig", ] class PostTrainingSupervisedFineTuneParams(TypedDict, total=False): - request: Required[Request] + algorithm: Required[Literal["full", "lora", "qlora", "dora"]] + + algorithm_config: Required[AlgorithmConfig] + dataset: Required[TrainEvalDatasetParam] + + hyperparam_search_config: Required[Dict[str, Union[bool, float, str, Iterable[object], object, None]]] + + job_uuid: Required[str] + + logger_config: Required[Dict[str, Union[bool, float, str, Iterable[object], object, None]]] -class RequestAlgorithmConfigLoraFinetuningConfig(TypedDict, total=False): + model: Required[str] + + optimizer_config: Required[OptimizerConfig] + + training_config: Required[TrainingConfig] + + validation_dataset: Required[TrainEvalDatasetParam] + + +class AlgorithmConfigLoraFinetuningConfig(TypedDict, total=False): alpha: Required[int] apply_lora_to_mlp: Required[bool] @@ -35,7 +52,7 @@ class RequestAlgorithmConfigLoraFinetuningConfig(TypedDict, total=False): rank: Required[int] -class RequestAlgorithmConfigQLoraFinetuningConfig(TypedDict, total=False): +class AlgorithmConfigQLoraFinetuningConfig(TypedDict, total=False): alpha: Required[int] apply_lora_to_mlp: Required[bool] @@ -47,7 +64,7 @@ class RequestAlgorithmConfigQLoraFinetuningConfig(TypedDict, total=False): rank: Required[int] -class RequestAlgorithmConfigDoraFinetuningConfig(TypedDict, total=False): +class 
AlgorithmConfigDoraFinetuningConfig(TypedDict, total=False): alpha: Required[int] apply_lora_to_mlp: Required[bool] @@ -59,14 +76,12 @@ class RequestAlgorithmConfigDoraFinetuningConfig(TypedDict, total=False): rank: Required[int] -RequestAlgorithmConfig: TypeAlias = Union[ - RequestAlgorithmConfigLoraFinetuningConfig, - RequestAlgorithmConfigQLoraFinetuningConfig, - RequestAlgorithmConfigDoraFinetuningConfig, +AlgorithmConfig: TypeAlias = Union[ + AlgorithmConfigLoraFinetuningConfig, AlgorithmConfigQLoraFinetuningConfig, AlgorithmConfigDoraFinetuningConfig ] -class RequestOptimizerConfig(TypedDict, total=False): +class OptimizerConfig(TypedDict, total=False): lr: Required[float] lr_min: Required[float] @@ -76,7 +91,7 @@ class RequestOptimizerConfig(TypedDict, total=False): weight_decay: Required[float] -class RequestTrainingConfig(TypedDict, total=False): +class TrainingConfig(TypedDict, total=False): batch_size: Required[int] enable_activation_checkpointing: Required[bool] @@ -90,25 +105,3 @@ class RequestTrainingConfig(TypedDict, total=False): n_iters: Required[int] shuffle: Required[bool] - - -class Request(TypedDict, total=False): - algorithm: Required[Literal["full", "lora", "qlora", "dora"]] - - algorithm_config: Required[RequestAlgorithmConfig] - - dataset: Required[TrainEvalDatasetParam] - - hyperparam_search_config: Required[Dict[str, Union[bool, float, str, Iterable[object], object, None]]] - - job_uuid: Required[str] - - logger_config: Required[Dict[str, Union[bool, float, str, Iterable[object], object, None]]] - - model: Required[str] - - optimizer_config: Required[RequestOptimizerConfig] - - training_config: Required[RequestTrainingConfig] - - validation_dataset: Required[TrainEvalDatasetParam] diff --git a/src/llama_stack/types/rest_api_execution_config_param.py b/src/llama_stack/types/rest_api_execution_config_param.py old mode 100755 new mode 100644 diff --git a/src/llama_stack/types/reward_scoring_score_params.py b/src/llama_stack/types/reward_scoring_score_params.py index da88a22..b969b75 100644 --- a/src/llama_stack/types/reward_scoring_score_params.py +++ b/src/llama_stack/types/reward_scoring_score_params.py @@ -12,31 +12,24 @@ __all__ = [ "RewardScoringScoreParams", - "Request", - "RequestDialogGeneration", - "RequestDialogGenerationDialog", - "RequestDialogGenerationSampledGeneration", + "DialogGeneration", + "DialogGenerationDialog", + "DialogGenerationSampledGeneration", ] class RewardScoringScoreParams(TypedDict, total=False): - request: Required[Request] + dialog_generations: Required[Iterable[DialogGeneration]] - -RequestDialogGenerationDialog: TypeAlias = Union[UserMessage, SystemMessage, ToolResponseMessage, CompletionMessage] - -RequestDialogGenerationSampledGeneration: TypeAlias = Union[ - UserMessage, SystemMessage, ToolResponseMessage, CompletionMessage -] + model: Required[str] -class RequestDialogGeneration(TypedDict, total=False): - dialog: Required[Iterable[RequestDialogGenerationDialog]] +DialogGenerationDialog: TypeAlias = Union[UserMessage, SystemMessage, ToolResponseMessage, CompletionMessage] - sampled_generations: Required[Iterable[RequestDialogGenerationSampledGeneration]] +DialogGenerationSampledGeneration: TypeAlias = Union[UserMessage, SystemMessage, ToolResponseMessage, CompletionMessage] -class Request(TypedDict, total=False): - dialog_generations: Required[Iterable[RequestDialogGeneration]] +class DialogGeneration(TypedDict, total=False): + dialog: Required[Iterable[DialogGenerationDialog]] - model: Required[str] + sampled_generations: 
Required[Iterable[DialogGenerationSampledGeneration]] diff --git a/src/llama_stack/types/run_log_metrics_params.py b/src/llama_stack/types/run_log_metrics_params.py deleted file mode 100644 index 750dfcf..0000000 --- a/src/llama_stack/types/run_log_metrics_params.py +++ /dev/null @@ -1,31 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -from __future__ import annotations - -from typing import Union, Iterable -from datetime import datetime -from typing_extensions import Required, Annotated, TypedDict - -from .._utils import PropertyInfo - -__all__ = ["RunLogMetricsParams", "Request", "RequestMetric"] - - -class RunLogMetricsParams(TypedDict, total=False): - request: Required[Request] - - -class RequestMetric(TypedDict, total=False): - name: Required[str] - - run_id: Required[str] - - timestamp: Required[Annotated[Union[str, datetime], PropertyInfo(format="iso8601")]] - - value: Required[Union[float, str, bool]] - - -class Request(TypedDict, total=False): - metrics: Required[Iterable[RequestMetric]] - - run_id: Required[str] diff --git a/src/llama_stack/types/run_update_params.py b/src/llama_stack/types/run_update_params.py deleted file mode 100644 index ba723f3..0000000 --- a/src/llama_stack/types/run_update_params.py +++ /dev/null @@ -1,25 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -from __future__ import annotations - -from typing import Dict, Union, Iterable -from datetime import datetime -from typing_extensions import Required, Annotated, TypedDict - -from .._utils import PropertyInfo - -__all__ = ["RunUpdateParams", "Request"] - - -class RunUpdateParams(TypedDict, total=False): - request: Required[Request] - - -class Request(TypedDict, total=False): - run_id: Required[str] - - ended_at: Annotated[Union[str, datetime], PropertyInfo(format="iso8601")] - - metadata: Dict[str, Union[bool, float, str, Iterable[object], object, None]] - - status: str diff --git a/src/llama_stack/types/runs/__init__.py b/src/llama_stack/types/runs/__init__.py deleted file mode 100644 index c1b2d09..0000000 --- a/src/llama_stack/types/runs/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -from __future__ import annotations - -from .metric_list_params import MetricListParams as MetricListParams -from .metric_list_response import MetricListResponse as MetricListResponse diff --git a/src/llama_stack/types/runs/metric_list_params.py b/src/llama_stack/types/runs/metric_list_params.py deleted file mode 100644 index 5287315..0000000 --- a/src/llama_stack/types/runs/metric_list_params.py +++ /dev/null @@ -1,11 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -from __future__ import annotations - -from typing_extensions import Required, TypedDict - -__all__ = ["MetricListParams"] - - -class MetricListParams(TypedDict, total=False): - run_id: Required[str] diff --git a/src/llama_stack/types/runs/metric_list_response.py b/src/llama_stack/types/runs/metric_list_response.py deleted file mode 100644 index 78fcbb8..0000000 --- a/src/llama_stack/types/runs/metric_list_response.py +++ /dev/null @@ -1,18 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
- -from typing import Union -from datetime import datetime - -from ..._models import BaseModel - -__all__ = ["MetricListResponse"] - - -class MetricListResponse(BaseModel): - name: str - - run_id: str - - timestamp: datetime - - value: Union[float, str, bool] diff --git a/src/llama_stack/types/safety_run_shields_params.py b/src/llama_stack/types/safety_run_shields_params.py index 0117d51..59498f6 100644 --- a/src/llama_stack/types/safety_run_shields_params.py +++ b/src/llama_stack/types/safety_run_shields_params.py @@ -11,17 +11,13 @@ from .shared_params.completion_message import CompletionMessage from .shared_params.tool_response_message import ToolResponseMessage -__all__ = ["SafetyRunShieldsParams", "Request", "RequestMessage"] +__all__ = ["SafetyRunShieldsParams", "Message"] class SafetyRunShieldsParams(TypedDict, total=False): - request: Required[Request] - - -RequestMessage: TypeAlias = Union[UserMessage, SystemMessage, ToolResponseMessage, CompletionMessage] + messages: Required[Iterable[Message]] + shields: Required[Iterable[ShieldDefinitionParam]] -class Request(TypedDict, total=False): - messages: Required[Iterable[RequestMessage]] - shields: Required[Iterable[ShieldDefinitionParam]] +Message: TypeAlias = Union[UserMessage, SystemMessage, ToolResponseMessage, CompletionMessage] diff --git a/src/llama_stack/types/shared/__init__.py b/src/llama_stack/types/shared/__init__.py index 244a74a..dcec3b3 100644 --- a/src/llama_stack/types/shared/__init__.py +++ b/src/llama_stack/types/shared/__init__.py @@ -1,7 +1,5 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. -from .run import Run as Run -from .artifact import Artifact as Artifact from .tool_call import ToolCall as ToolCall from .attachment import Attachment as Attachment from .user_message import UserMessage as UserMessage diff --git a/src/llama_stack/types/shared/artifact.py b/src/llama_stack/types/shared/artifact.py deleted file mode 100644 index 11b4fa2..0000000 --- a/src/llama_stack/types/shared/artifact.py +++ /dev/null @@ -1,23 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -from typing import Dict, List, Union -from datetime import datetime -from typing_extensions import Literal - -from ..._models import BaseModel - -__all__ = ["Artifact"] - - -class Artifact(BaseModel): - id: str - - created_at: datetime - - metadata: Dict[str, Union[bool, float, str, List[object], object, None]] - - name: str - - size: int - - type: Literal["model", "dataset", "checkpoint", "plot", "metric", "config", "code", "other"] diff --git a/src/llama_stack/types/shared/attachment.py b/src/llama_stack/types/shared/attachment.py old mode 100755 new mode 100644 diff --git a/src/llama_stack/types/shared/completion_message.py b/src/llama_stack/types/shared/completion_message.py old mode 100755 new mode 100644 diff --git a/src/llama_stack/types/shared/run.py b/src/llama_stack/types/shared/run.py deleted file mode 100644 index 2a94ee6..0000000 --- a/src/llama_stack/types/shared/run.py +++ /dev/null @@ -1,22 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
- -from typing import Dict, List, Union, Optional -from datetime import datetime - -from ..._models import BaseModel - -__all__ = ["Run"] - - -class Run(BaseModel): - id: str - - experiment_id: str - - metadata: Dict[str, Union[bool, float, str, List[object], object, None]] - - started_at: datetime - - status: str - - ended_at: Optional[datetime] = None diff --git a/src/llama_stack/types/shared/sampling_params.py b/src/llama_stack/types/shared/sampling_params.py old mode 100755 new mode 100644 diff --git a/src/llama_stack/types/shared/system_message.py b/src/llama_stack/types/shared/system_message.py old mode 100755 new mode 100644 diff --git a/src/llama_stack/types/shared/tool_call.py b/src/llama_stack/types/shared/tool_call.py old mode 100755 new mode 100644 diff --git a/src/llama_stack/types/shared/tool_response_message.py b/src/llama_stack/types/shared/tool_response_message.py old mode 100755 new mode 100644 diff --git a/src/llama_stack/types/shared/user_message.py b/src/llama_stack/types/shared/user_message.py old mode 100755 new mode 100644 diff --git a/src/llama_stack/types/shared_params/__init__.py b/src/llama_stack/types/shared_params/__init__.py old mode 100755 new mode 100644 diff --git a/src/llama_stack/types/shared_params/attachment.py b/src/llama_stack/types/shared_params/attachment.py old mode 100755 new mode 100644 diff --git a/src/llama_stack/types/shared_params/completion_message.py b/src/llama_stack/types/shared_params/completion_message.py old mode 100755 new mode 100644 diff --git a/src/llama_stack/types/shared_params/sampling_params.py b/src/llama_stack/types/shared_params/sampling_params.py old mode 100755 new mode 100644 diff --git a/src/llama_stack/types/shared_params/system_message.py b/src/llama_stack/types/shared_params/system_message.py old mode 100755 new mode 100644 diff --git a/src/llama_stack/types/shared_params/tool_call.py b/src/llama_stack/types/shared_params/tool_call.py old mode 100755 new mode 100644 diff --git a/src/llama_stack/types/shared_params/tool_response_message.py b/src/llama_stack/types/shared_params/tool_response_message.py old mode 100755 new mode 100644 diff --git a/src/llama_stack/types/shared_params/user_message.py b/src/llama_stack/types/shared_params/user_message.py old mode 100755 new mode 100644 diff --git a/src/llama_stack/types/shield_definition_param.py b/src/llama_stack/types/shield_definition_param.py old mode 100755 new mode 100644 diff --git a/src/llama_stack/types/synthetic_data_generation_generate_params.py b/src/llama_stack/types/synthetic_data_generation_generate_params.py index 31a7233..4238473 100644 --- a/src/llama_stack/types/synthetic_data_generation_generate_params.py +++ b/src/llama_stack/types/synthetic_data_generation_generate_params.py @@ -10,19 +10,15 @@ from .shared_params.completion_message import CompletionMessage from .shared_params.tool_response_message import ToolResponseMessage -__all__ = ["SyntheticDataGenerationGenerateParams", "Request", "RequestDialog"] +__all__ = ["SyntheticDataGenerationGenerateParams", "Dialog"] class SyntheticDataGenerationGenerateParams(TypedDict, total=False): - request: Required[Request] - - -RequestDialog: TypeAlias = Union[UserMessage, SystemMessage, ToolResponseMessage, CompletionMessage] - - -class Request(TypedDict, total=False): - dialogs: Required[Iterable[RequestDialog]] + dialogs: Required[Iterable[Dialog]] filtering_function: Required[Literal["none", "random", "top_k", "top_p", "top_k_top_p", "sigmoid"]] model: str + + +Dialog: TypeAlias = Union[UserMessage, 
SystemMessage, ToolResponseMessage, CompletionMessage] diff --git a/src/llama_stack/types/artifact_get_params.py b/src/llama_stack/types/telemetry_get_trace_params.py similarity index 59% rename from src/llama_stack/types/artifact_get_params.py rename to src/llama_stack/types/telemetry_get_trace_params.py index acebbaa..520724b 100644 --- a/src/llama_stack/types/artifact_get_params.py +++ b/src/llama_stack/types/telemetry_get_trace_params.py @@ -4,8 +4,8 @@ from typing_extensions import Required, TypedDict -__all__ = ["ArtifactGetParams"] +__all__ = ["TelemetryGetTraceParams"] -class ArtifactGetParams(TypedDict, total=False): - artifact_id: Required[str] +class TelemetryGetTraceParams(TypedDict, total=False): + trace_id: Required[str] diff --git a/src/llama_stack/types/telemetry_get_trace_response.py b/src/llama_stack/types/telemetry_get_trace_response.py new file mode 100644 index 0000000..c1fa453 --- /dev/null +++ b/src/llama_stack/types/telemetry_get_trace_response.py @@ -0,0 +1,18 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Optional +from datetime import datetime + +from .._models import BaseModel + +__all__ = ["TelemetryGetTraceResponse"] + + +class TelemetryGetTraceResponse(BaseModel): + root_span_id: str + + start_time: datetime + + trace_id: str + + end_time: Optional[datetime] = None diff --git a/src/llama_stack/types/telemetry_log_params.py b/src/llama_stack/types/telemetry_log_params.py new file mode 100644 index 0000000..6e6eb61 --- /dev/null +++ b/src/llama_stack/types/telemetry_log_params.py @@ -0,0 +1,94 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing import Dict, Union, Iterable +from datetime import datetime +from typing_extensions import Literal, Required, Annotated, TypeAlias, TypedDict + +from .._utils import PropertyInfo + +__all__ = [ + "TelemetryLogParams", + "Event", + "EventUnstructuredLogEvent", + "EventMetricEvent", + "EventStructuredLogEvent", + "EventStructuredLogEventPayload", + "EventStructuredLogEventPayloadSpanStartPayload", + "EventStructuredLogEventPayloadSpanEndPayload", +] + + +class TelemetryLogParams(TypedDict, total=False): + event: Required[Event] + + +class EventUnstructuredLogEvent(TypedDict, total=False): + message: Required[str] + + severity: Required[Literal["verbose", "debug", "info", "warn", "error", "critical"]] + + span_id: Required[str] + + timestamp: Required[Annotated[Union[str, datetime], PropertyInfo(format="iso8601")]] + + trace_id: Required[str] + + type: Required[Literal["unstructured_log"]] + + attributes: Dict[str, Union[bool, float, str, Iterable[object], object, None]] + + +class EventMetricEvent(TypedDict, total=False): + metric: Required[str] + + span_id: Required[str] + + timestamp: Required[Annotated[Union[str, datetime], PropertyInfo(format="iso8601")]] + + trace_id: Required[str] + + type: Required[Literal["metric"]] + + unit: Required[str] + + value: Required[float] + + attributes: Dict[str, Union[bool, float, str, Iterable[object], object, None]] + + +class EventStructuredLogEventPayloadSpanStartPayload(TypedDict, total=False): + name: Required[str] + + type: Required[Literal["span_start"]] + + parent_span_id: str + + +class EventStructuredLogEventPayloadSpanEndPayload(TypedDict, total=False): + status: Required[Literal["ok", "error"]] + + type: Required[Literal["span_end"]] + + +EventStructuredLogEventPayload: TypeAlias = Union[ + 
EventStructuredLogEventPayloadSpanStartPayload, EventStructuredLogEventPayloadSpanEndPayload +] + + +class EventStructuredLogEvent(TypedDict, total=False): + payload: Required[EventStructuredLogEventPayload] + + span_id: Required[str] + + timestamp: Required[Annotated[Union[str, datetime], PropertyInfo(format="iso8601")]] + + trace_id: Required[str] + + type: Required[Literal["structured_log"]] + + attributes: Dict[str, Union[bool, float, str, Iterable[object], object, None]] + + +Event: TypeAlias = Union[EventUnstructuredLogEvent, EventMetricEvent, EventStructuredLogEvent] diff --git a/src/llama_stack/types/token_log_probs.py b/src/llama_stack/types/token_log_probs.py new file mode 100644 index 0000000..45bc634 --- /dev/null +++ b/src/llama_stack/types/token_log_probs.py @@ -0,0 +1,11 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Dict + +from .._models import BaseModel + +__all__ = ["TokenLogProbs"] + + +class TokenLogProbs(BaseModel): + logprobs_by_token: Dict[str, float] diff --git a/src/llama_stack/types/tool_param_definition_param.py b/src/llama_stack/types/tool_param_definition_param.py old mode 100755 new mode 100644 diff --git a/tests/api_resources/agentic_system/test_turns.py b/tests/api_resources/agentic_system/test_turns.py index 4443deb..8d2e3ca 100644 --- a/tests/api_resources/agentic_system/test_turns.py +++ b/tests/api_resources/agentic_system/test_turns.py @@ -21,610 +21,86 @@ class TestTurns: parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize - def test_method_create(self, client: LlamaStack) -> None: + def test_method_create_overload_1(self, client: LlamaStack) -> None: turn = client.agentic_system.turns.create( - request={ - "agent_id": "agent_id", - "messages": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "session_id": "session_id", - }, + agent_id="agent_id", + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + session_id="session_id", ) assert_matches_type(AgenticSystemTurnStreamChunk, turn, path=["response"]) @parametrize - def test_method_create_with_all_params(self, client: LlamaStack) -> None: + def test_method_create_with_all_params_overload_1(self, client: LlamaStack) -> None: turn = client.agentic_system.turns.create( - request={ - "agent_id": "agent_id", - "messages": [ - { - "content": "string", - "role": "user", - "context": "string", - }, - { - "content": "string", - "role": "user", - "context": "string", - }, - { - "content": "string", - "role": "user", - "context": "string", - }, - ], - "session_id": "session_id", - "attachments": [ - { - "content": "string", - "mime_type": "mime_type", - }, - { - "content": "string", - "mime_type": "mime_type", - }, - { - "content": "string", - "mime_type": "mime_type", - }, - ], - "input_shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": 
"description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - ], - "instructions": "instructions", - "output_shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - ], - "sampling_params": { - "strategy": "greedy", - "max_tokens": 0, - "repetition_penalty": 0, - "temperature": 0, - "top_k": 0, - "top_p": 0, + agent_id="agent_id", + messages=[ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + ], + session_id="session_id", + attachments=[ + { + "content": "string", + "mime_type": "mime_type", }, - "stream": True, - "tool_choice": "auto", - "tool_prompt_format": "json", - "tools": [ - { - "engine": "bing", - "type": "brave_search", - "input_shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": 
{"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - ], - "output_shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - ], - "remote_execution": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - }, - { - "engine": "bing", - "type": "brave_search", - "input_shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - ], - "output_shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - 
"description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - ], - "remote_execution": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - }, - { - "engine": "bing", - "type": "brave_search", - "input_shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - ], - "output_shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - ], - "remote_execution": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - }, - ], - }, + { + "content": "string", + "mime_type": "mime_type", + }, + { + "content": "string", + "mime_type": "mime_type", + }, + ], + stream=False, ) assert_matches_type(AgenticSystemTurnStreamChunk, turn, path=["response"]) @parametrize - def test_raw_response_create(self, client: LlamaStack) -> None: + def test_raw_response_create_overload_1(self, client: LlamaStack) -> None: response = client.agentic_system.turns.with_raw_response.create( - request={ - "agent_id": "agent_id", - "messages": [ - { - 
"content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "session_id": "session_id", - }, + agent_id="agent_id", + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + session_id="session_id", ) assert response.is_closed is True @@ -633,26 +109,24 @@ def test_raw_response_create(self, client: LlamaStack) -> None: assert_matches_type(AgenticSystemTurnStreamChunk, turn, path=["response"]) @parametrize - def test_streaming_response_create(self, client: LlamaStack) -> None: + def test_streaming_response_create_overload_1(self, client: LlamaStack) -> None: with client.agentic_system.turns.with_streaming_response.create( - request={ - "agent_id": "agent_id", - "messages": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "session_id": "session_id", - }, + agent_id="agent_id", + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + session_id="session_id", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -662,6 +136,124 @@ def test_streaming_response_create(self, client: LlamaStack) -> None: assert cast(Any, response.is_closed) is True + @parametrize + def test_method_create_overload_2(self, client: LlamaStack) -> None: + turn_stream = client.agentic_system.turns.create( + agent_id="agent_id", + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + session_id="session_id", + stream=True, + ) + turn_stream.response.close() + + @parametrize + def test_method_create_with_all_params_overload_2(self, client: LlamaStack) -> None: + turn_stream = client.agentic_system.turns.create( + agent_id="agent_id", + messages=[ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + ], + session_id="session_id", + stream=True, + attachments=[ + { + "content": "string", + "mime_type": "mime_type", + }, + { + "content": "string", + "mime_type": "mime_type", + }, + { + "content": "string", + "mime_type": "mime_type", + }, + ], + ) + turn_stream.response.close() + + @parametrize + def test_raw_response_create_overload_2(self, client: LlamaStack) -> None: + response = client.agentic_system.turns.with_raw_response.create( + agent_id="agent_id", + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + session_id="session_id", + stream=True, + ) + + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + stream = response.parse() + stream.close() + + @parametrize + def test_streaming_response_create_overload_2(self, client: LlamaStack) -> None: + with client.agentic_system.turns.with_streaming_response.create( + agent_id="agent_id", + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + session_id="session_id", 
+ stream=True, + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + stream = response.parse() + stream.close() + + assert cast(Any, response.is_closed) is True + @parametrize def test_method_retrieve(self, client: LlamaStack) -> None: turn = client.agentic_system.turns.retrieve( @@ -701,610 +293,86 @@ class TestAsyncTurns: parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize - async def test_method_create(self, async_client: AsyncLlamaStack) -> None: + async def test_method_create_overload_1(self, async_client: AsyncLlamaStack) -> None: turn = await async_client.agentic_system.turns.create( - request={ - "agent_id": "agent_id", - "messages": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "session_id": "session_id", - }, + agent_id="agent_id", + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + session_id="session_id", ) assert_matches_type(AgenticSystemTurnStreamChunk, turn, path=["response"]) @parametrize - async def test_method_create_with_all_params(self, async_client: AsyncLlamaStack) -> None: + async def test_method_create_with_all_params_overload_1(self, async_client: AsyncLlamaStack) -> None: turn = await async_client.agentic_system.turns.create( - request={ - "agent_id": "agent_id", - "messages": [ - { - "content": "string", - "role": "user", - "context": "string", - }, - { - "content": "string", - "role": "user", - "context": "string", - }, - { - "content": "string", - "role": "user", - "context": "string", - }, - ], - "session_id": "session_id", - "attachments": [ - { - "content": "string", - "mime_type": "mime_type", - }, - { - "content": "string", - "mime_type": "mime_type", - }, - { - "content": "string", - "mime_type": "mime_type", - }, - ], - "input_shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - ], - "instructions": "instructions", - "output_shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": 
{ - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - ], - "sampling_params": { - "strategy": "greedy", - "max_tokens": 0, - "repetition_penalty": 0, - "temperature": 0, - "top_k": 0, - "top_p": 0, + agent_id="agent_id", + messages=[ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + ], + session_id="session_id", + attachments=[ + { + "content": "string", + "mime_type": "mime_type", }, - "stream": True, - "tool_choice": "auto", - "tool_prompt_format": "json", - "tools": [ - { - "engine": "bing", - "type": "brave_search", - "input_shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - ], - "output_shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": 
"description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - ], - "remote_execution": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - }, - { - "engine": "bing", - "type": "brave_search", - "input_shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - ], - "output_shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - ], - "remote_execution": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - }, - { - "engine": "bing", - "type": "brave_search", - "input_shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - 
"description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - ], - "output_shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - "description": "description", - "execution_config": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - ], - "remote_execution": { - "method": "GET", - "url": "https://example.com", - "body": {"foo": True}, - "headers": {"foo": True}, - "params": {"foo": True}, - }, - }, - ], - }, + { + "content": "string", + "mime_type": "mime_type", + }, + { + "content": "string", + "mime_type": "mime_type", + }, + ], + stream=False, ) assert_matches_type(AgenticSystemTurnStreamChunk, turn, path=["response"]) @parametrize - async def test_raw_response_create(self, async_client: AsyncLlamaStack) -> None: + async def test_raw_response_create_overload_1(self, async_client: AsyncLlamaStack) -> None: response = await async_client.agentic_system.turns.with_raw_response.create( - request={ - "agent_id": "agent_id", - "messages": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "session_id": "session_id", - }, + agent_id="agent_id", + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + session_id="session_id", ) assert response.is_closed is True @@ -1313,26 +381,24 @@ async def test_raw_response_create(self, async_client: AsyncLlamaStack) -> None: assert_matches_type(AgenticSystemTurnStreamChunk, turn, path=["response"]) @parametrize - async def test_streaming_response_create(self, async_client: AsyncLlamaStack) -> None: + async def test_streaming_response_create_overload_1(self, async_client: AsyncLlamaStack) -> None: async with async_client.agentic_system.turns.with_streaming_response.create( - request={ - "agent_id": "agent_id", - "messages": [ - { - 
"content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "session_id": "session_id", - }, + agent_id="agent_id", + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + session_id="session_id", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -1342,6 +408,124 @@ async def test_streaming_response_create(self, async_client: AsyncLlamaStack) -> assert cast(Any, response.is_closed) is True + @parametrize + async def test_method_create_overload_2(self, async_client: AsyncLlamaStack) -> None: + turn_stream = await async_client.agentic_system.turns.create( + agent_id="agent_id", + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + session_id="session_id", + stream=True, + ) + await turn_stream.response.aclose() + + @parametrize + async def test_method_create_with_all_params_overload_2(self, async_client: AsyncLlamaStack) -> None: + turn_stream = await async_client.agentic_system.turns.create( + agent_id="agent_id", + messages=[ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + ], + session_id="session_id", + stream=True, + attachments=[ + { + "content": "string", + "mime_type": "mime_type", + }, + { + "content": "string", + "mime_type": "mime_type", + }, + { + "content": "string", + "mime_type": "mime_type", + }, + ], + ) + await turn_stream.response.aclose() + + @parametrize + async def test_raw_response_create_overload_2(self, async_client: AsyncLlamaStack) -> None: + response = await async_client.agentic_system.turns.with_raw_response.create( + agent_id="agent_id", + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + session_id="session_id", + stream=True, + ) + + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + stream = await response.parse() + await stream.close() + + @parametrize + async def test_streaming_response_create_overload_2(self, async_client: AsyncLlamaStack) -> None: + async with async_client.agentic_system.turns.with_streaming_response.create( + agent_id="agent_id", + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + session_id="session_id", + stream=True, + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + stream = await response.parse() + await stream.close() + + assert cast(Any, response.is_closed) is True + @parametrize async def test_method_retrieve(self, async_client: AsyncLlamaStack) -> None: turn = await async_client.agentic_system.turns.retrieve( diff --git a/tests/api_resources/evaluate/test_question_answering.py b/tests/api_resources/evaluate/test_question_answering.py index d37c960..b229620 100644 --- a/tests/api_resources/evaluate/test_question_answering.py +++ b/tests/api_resources/evaluate/test_question_answering.py @@ -20,32 +20,14 @@ class TestQuestionAnswering: @parametrize 
     def test_method_create(self, client: LlamaStack) -> None:
         question_answering = client.evaluate.question_answering.create(
-            request={
-                "checkpoint": {},
-                "dataset": {
-                    "columns": {"foo": "dialog"},
-                    "content_url": "https://example.com",
-                },
-                "job_uuid": "job_uuid",
-                "metrics": ["em", "f1"],
-                "sampling_params": {"strategy": "greedy"},
-            },
+            metrics=["em", "f1"],
         )
         assert_matches_type(EvaluationJob, question_answering, path=["response"])

     @parametrize
     def test_raw_response_create(self, client: LlamaStack) -> None:
         response = client.evaluate.question_answering.with_raw_response.create(
-            request={
-                "checkpoint": {},
-                "dataset": {
-                    "columns": {"foo": "dialog"},
-                    "content_url": "https://example.com",
-                },
-                "job_uuid": "job_uuid",
-                "metrics": ["em", "f1"],
-                "sampling_params": {"strategy": "greedy"},
-            },
+            metrics=["em", "f1"],
         )

         assert response.is_closed is True
@@ -56,16 +38,7 @@ def test_raw_response_create(self, client: LlamaStack) -> None:
     @parametrize
     def test_streaming_response_create(self, client: LlamaStack) -> None:
         with client.evaluate.question_answering.with_streaming_response.create(
-            request={
-                "checkpoint": {},
-                "dataset": {
-                    "columns": {"foo": "dialog"},
-                    "content_url": "https://example.com",
-                },
-                "job_uuid": "job_uuid",
-                "metrics": ["em", "f1"],
-                "sampling_params": {"strategy": "greedy"},
-            },
+            metrics=["em", "f1"],
         ) as response:
             assert not response.is_closed
             assert response.http_request.headers.get("X-Stainless-Lang") == "python"
@@ -82,32 +55,14 @@ class TestAsyncQuestionAnswering:
     @parametrize
     async def test_method_create(self, async_client: AsyncLlamaStack) -> None:
         question_answering = await async_client.evaluate.question_answering.create(
-            request={
-                "checkpoint": {},
-                "dataset": {
-                    "columns": {"foo": "dialog"},
-                    "content_url": "https://example.com",
-                },
-                "job_uuid": "job_uuid",
-                "metrics": ["em", "f1"],
-                "sampling_params": {"strategy": "greedy"},
-            },
+            metrics=["em", "f1"],
         )
         assert_matches_type(EvaluationJob, question_answering, path=["response"])

     @parametrize
     async def test_raw_response_create(self, async_client: AsyncLlamaStack) -> None:
         response = await async_client.evaluate.question_answering.with_raw_response.create(
-            request={
-                "checkpoint": {},
-                "dataset": {
-                    "columns": {"foo": "dialog"},
-                    "content_url": "https://example.com",
-                },
-                "job_uuid": "job_uuid",
-                "metrics": ["em", "f1"],
-                "sampling_params": {"strategy": "greedy"},
-            },
+            metrics=["em", "f1"],
         )

         assert response.is_closed is True
@@ -118,16 +73,7 @@ async def test_raw_response_create(self, async_client: AsyncLlamaStack) -
     @parametrize
     async def test_streaming_response_create(self, async_client: AsyncLlamaStack) -> None:
         async with async_client.evaluate.question_answering.with_streaming_response.create(
-            request={
-                "checkpoint": {},
-                "dataset": {
-                    "columns": {"foo": "dialog"},
-                    "content_url": "https://example.com",
-                },
-                "job_uuid": "job_uuid",
-                "metrics": ["em", "f1"],
-                "sampling_params": {"strategy": "greedy"},
-            },
+            metrics=["em", "f1"],
         ) as response:
             assert not response.is_closed
             assert response.http_request.headers.get("X-Stainless-Lang") == "python"
diff --git a/tests/api_resources/experiments/__init__.py b/tests/api_resources/experiments/__init__.py
deleted file mode 100644
index fd8019a..0000000
--- a/tests/api_resources/experiments/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
diff --git a/tests/api_resources/experiments/test_artifacts.py b/tests/api_resources/experiments/test_artifacts.py
deleted file mode 100644
index 85100aa..0000000
--- a/tests/api_resources/experiments/test_artifacts.py
+++ /dev/null
@@ -1,202 +0,0 @@
-# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
-
-from __future__ import annotations
-
-import os
-from typing import Any, cast
-
-import pytest
-
-from llama_stack import LlamaStack, AsyncLlamaStack
-from tests.utils import assert_matches_type
-from llama_stack.types.shared import Artifact
-
-base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
-
-
-class TestArtifacts:
-    parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"])
-
-    @parametrize
-    def test_method_retrieve(self, client: LlamaStack) -> None:
-        artifact = client.experiments.artifacts.retrieve(
-            experiment_id="experiment_id",
-        )
-        assert_matches_type(Artifact, artifact, path=["response"])
-
-    @parametrize
-    def test_raw_response_retrieve(self, client: LlamaStack) -> None:
-        response = client.experiments.artifacts.with_raw_response.retrieve(
-            experiment_id="experiment_id",
-        )
-
-        assert response.is_closed is True
-        assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-        artifact = response.parse()
-        assert_matches_type(Artifact, artifact, path=["response"])
-
-    @parametrize
-    def test_streaming_response_retrieve(self, client: LlamaStack) -> None:
-        with client.experiments.artifacts.with_streaming_response.retrieve(
-            experiment_id="experiment_id",
-        ) as response:
-            assert not response.is_closed
-            assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
-            artifact = response.parse()
-            assert_matches_type(Artifact, artifact, path=["response"])
-
-        assert cast(Any, response.is_closed) is True
-
-    @parametrize
-    def test_method_upload(self, client: LlamaStack) -> None:
-        artifact = client.experiments.artifacts.upload(
-            request={
-                "artifact_type": "artifact_type",
-                "content": "content",
-                "experiment_id": "experiment_id",
-                "name": "name",
-            },
-        )
-        assert_matches_type(Artifact, artifact, path=["response"])
-
-    @parametrize
-    def test_method_upload_with_all_params(self, client: LlamaStack) -> None:
-        artifact = client.experiments.artifacts.upload(
-            request={
-                "artifact_type": "artifact_type",
-                "content": "content",
-                "experiment_id": "experiment_id",
-                "name": "name",
-                "metadata": {"foo": True},
-            },
-        )
-        assert_matches_type(Artifact, artifact, path=["response"])
-
-    @parametrize
-    def test_raw_response_upload(self, client: LlamaStack) -> None:
-        response = client.experiments.artifacts.with_raw_response.upload(
-            request={
-                "artifact_type": "artifact_type",
-                "content": "content",
-                "experiment_id": "experiment_id",
-                "name": "name",
-            },
-        )
-
-        assert response.is_closed is True
-        assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-        artifact = response.parse()
-        assert_matches_type(Artifact, artifact, path=["response"])
-
-    @parametrize
-    def test_streaming_response_upload(self, client: LlamaStack) -> None:
-        with client.experiments.artifacts.with_streaming_response.upload(
-            request={
-                "artifact_type": "artifact_type",
-                "content": "content",
-                "experiment_id": "experiment_id",
-                "name": "name",
-            },
-        ) as response:
-            assert not response.is_closed
-            assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
-            artifact = response.parse()
-            assert_matches_type(Artifact, artifact, path=["response"])
-
-        assert cast(Any, response.is_closed) is True
-
-
-class TestAsyncArtifacts:
-    parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
-
-    @parametrize
-    async def test_method_retrieve(self, async_client: AsyncLlamaStack) -> None:
-        artifact = await async_client.experiments.artifacts.retrieve(
-            experiment_id="experiment_id",
-        )
-        assert_matches_type(Artifact, artifact, path=["response"])
-
-    @parametrize
-    async def test_raw_response_retrieve(self, async_client: AsyncLlamaStack) -> None:
-        response = await async_client.experiments.artifacts.with_raw_response.retrieve(
-            experiment_id="experiment_id",
-        )
-
-        assert response.is_closed is True
-        assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-        artifact = await response.parse()
-        assert_matches_type(Artifact, artifact, path=["response"])
-
-    @parametrize
-    async def test_streaming_response_retrieve(self, async_client: AsyncLlamaStack) -> None:
-        async with async_client.experiments.artifacts.with_streaming_response.retrieve(
-            experiment_id="experiment_id",
-        ) as response:
-            assert not response.is_closed
-            assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
-            artifact = await response.parse()
-            assert_matches_type(Artifact, artifact, path=["response"])
-
-        assert cast(Any, response.is_closed) is True
-
-    @parametrize
-    async def test_method_upload(self, async_client: AsyncLlamaStack) -> None:
-        artifact = await async_client.experiments.artifacts.upload(
-            request={
-                "artifact_type": "artifact_type",
-                "content": "content",
-                "experiment_id": "experiment_id",
-                "name": "name",
-            },
-        )
-        assert_matches_type(Artifact, artifact, path=["response"])
-
-    @parametrize
-    async def test_method_upload_with_all_params(self, async_client: AsyncLlamaStack) -> None:
-        artifact = await async_client.experiments.artifacts.upload(
-            request={
-                "artifact_type": "artifact_type",
-                "content": "content",
-                "experiment_id": "experiment_id",
-                "name": "name",
-                "metadata": {"foo": True},
-            },
-        )
-        assert_matches_type(Artifact, artifact, path=["response"])
-
-    @parametrize
-    async def test_raw_response_upload(self, async_client: AsyncLlamaStack) -> None:
-        response = await async_client.experiments.artifacts.with_raw_response.upload(
-            request={
-                "artifact_type": "artifact_type",
-                "content": "content",
-                "experiment_id": "experiment_id",
-                "name": "name",
-            },
-        )
-
-        assert response.is_closed is True
-        assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-        artifact = await response.parse()
-        assert_matches_type(Artifact, artifact, path=["response"])
-
-    @parametrize
-    async def test_streaming_response_upload(self, async_client: AsyncLlamaStack) -> None:
-        async with async_client.experiments.artifacts.with_streaming_response.upload(
-            request={
-                "artifact_type": "artifact_type",
-                "content": "content",
-                "experiment_id": "experiment_id",
-                "name": "name",
-            },
-        ) as response:
-            assert not response.is_closed
-            assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
-            artifact = await response.parse()
-            assert_matches_type(Artifact, artifact, path=["response"])
-
-        assert cast(Any, response.is_closed) is True
diff --git a/tests/api_resources/runs/__init__.py b/tests/api_resources/runs/__init__.py
deleted file mode 100644
index fd8019a..0000000
--- a/tests/api_resources/runs/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
diff --git a/tests/api_resources/runs/test_metrics.py b/tests/api_resources/runs/test_metrics.py
deleted file mode 100644
index 9ca013a..0000000
--- a/tests/api_resources/runs/test_metrics.py
+++ /dev/null
@@ -1,84 +0,0 @@
-# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
-
-from __future__ import annotations
-
-import os
-from typing import Any, cast
-
-import pytest
-
-from llama_stack import LlamaStack, AsyncLlamaStack
-from tests.utils import assert_matches_type
-from llama_stack.types.runs import MetricListResponse
-
-base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
-
-
-class TestMetrics:
-    parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"])
-
-    @parametrize
-    def test_method_list(self, client: LlamaStack) -> None:
-        metric = client.runs.metrics.list(
-            run_id="run_id",
-        )
-        assert_matches_type(MetricListResponse, metric, path=["response"])
-
-    @parametrize
-    def test_raw_response_list(self, client: LlamaStack) -> None:
-        response = client.runs.metrics.with_raw_response.list(
-            run_id="run_id",
-        )
-
-        assert response.is_closed is True
-        assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-        metric = response.parse()
-        assert_matches_type(MetricListResponse, metric, path=["response"])
-
-    @parametrize
-    def test_streaming_response_list(self, client: LlamaStack) -> None:
-        with client.runs.metrics.with_streaming_response.list(
-            run_id="run_id",
-        ) as response:
-            assert not response.is_closed
-            assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
-            metric = response.parse()
-            assert_matches_type(MetricListResponse, metric, path=["response"])
-
-        assert cast(Any, response.is_closed) is True
-
-
-class TestAsyncMetrics:
-    parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
-
-    @parametrize
-    async def test_method_list(self, async_client: AsyncLlamaStack) -> None:
-        metric = await async_client.runs.metrics.list(
-            run_id="run_id",
-        )
-        assert_matches_type(MetricListResponse, metric, path=["response"])
-
-    @parametrize
-    async def test_raw_response_list(self, async_client: AsyncLlamaStack) -> None:
-        response = await async_client.runs.metrics.with_raw_response.list(
-            run_id="run_id",
-        )
-
-        assert response.is_closed is True
-        assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-        metric = await response.parse()
-        assert_matches_type(MetricListResponse, metric, path=["response"])
-
-    @parametrize
-    async def test_streaming_response_list(self, async_client: AsyncLlamaStack) -> None:
-        async with async_client.runs.metrics.with_streaming_response.list(
-            run_id="run_id",
-        ) as response:
-            assert not response.is_closed
-            assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
-            metric = await response.parse()
-            assert_matches_type(MetricListResponse, metric, path=["response"])
-
-        assert cast(Any, response.is_closed) is True
diff --git a/tests/api_resources/test_artifacts.py b/tests/api_resources/test_artifacts.py
deleted file mode 100644
index 9bfed6a..0000000
--- a/tests/api_resources/test_artifacts.py
+++ /dev/null
@@ -1,84 +0,0 @@
-# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
- -from __future__ import annotations - -import os -from typing import Any, cast - -import pytest - -from llama_stack import LlamaStack, AsyncLlamaStack -from tests.utils import assert_matches_type -from llama_stack.types.shared import Artifact - -base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") - - -class TestArtifacts: - parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) - - @parametrize - def test_method_get(self, client: LlamaStack) -> None: - artifact = client.artifacts.get( - artifact_id="artifact_id", - ) - assert_matches_type(Artifact, artifact, path=["response"]) - - @parametrize - def test_raw_response_get(self, client: LlamaStack) -> None: - response = client.artifacts.with_raw_response.get( - artifact_id="artifact_id", - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - artifact = response.parse() - assert_matches_type(Artifact, artifact, path=["response"]) - - @parametrize - def test_streaming_response_get(self, client: LlamaStack) -> None: - with client.artifacts.with_streaming_response.get( - artifact_id="artifact_id", - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - artifact = response.parse() - assert_matches_type(Artifact, artifact, path=["response"]) - - assert cast(Any, response.is_closed) is True - - -class TestAsyncArtifacts: - parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) - - @parametrize - async def test_method_get(self, async_client: AsyncLlamaStack) -> None: - artifact = await async_client.artifacts.get( - artifact_id="artifact_id", - ) - assert_matches_type(Artifact, artifact, path=["response"]) - - @parametrize - async def test_raw_response_get(self, async_client: AsyncLlamaStack) -> None: - response = await async_client.artifacts.with_raw_response.get( - artifact_id="artifact_id", - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - artifact = await response.parse() - assert_matches_type(Artifact, artifact, path=["response"]) - - @parametrize - async def test_streaming_response_get(self, async_client: AsyncLlamaStack) -> None: - async with async_client.artifacts.with_streaming_response.get( - artifact_id="artifact_id", - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - artifact = await response.parse() - assert_matches_type(Artifact, artifact, path=["response"]) - - assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_batch_inference.py b/tests/api_resources/test_batch_inference.py index 3629110..13fd6bb 100644 --- a/tests/api_resources/test_batch_inference.py +++ b/tests/api_resources/test_batch_inference.py @@ -23,214 +23,208 @@ class TestBatchInference: @parametrize def test_method_chat_completion(self, client: LlamaStack) -> None: batch_inference = client.batch_inference.chat_completion( - request={ - "messages_batch": [ - [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - [ - { - "content": "string", - "role": "user", - }, - { - "content": 
"string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], + messages_batch=[ + [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, ], - "model": "model", - }, + [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + ], + model="model", ) assert_matches_type(BatchChatCompletion, batch_inference, path=["response"]) @parametrize def test_method_chat_completion_with_all_params(self, client: LlamaStack) -> None: batch_inference = client.batch_inference.chat_completion( - request={ - "messages_batch": [ - [ - { - "content": "string", - "role": "user", - "context": "string", - }, - { - "content": "string", - "role": "user", - "context": "string", - }, - { - "content": "string", - "role": "user", - "context": "string", - }, - ], - [ - { - "content": "string", - "role": "user", - "context": "string", - }, - { - "content": "string", - "role": "user", - "context": "string", - }, - { - "content": "string", - "role": "user", - "context": "string", - }, - ], - [ - { - "content": "string", - "role": "user", - "context": "string", - }, - { - "content": "string", - "role": "user", - "context": "string", - }, - { - "content": "string", - "role": "user", - "context": "string", - }, - ], + messages_batch=[ + [ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, ], - "model": "model", - "logprobs": {"top_k": 0}, - "sampling_params": { - "strategy": "greedy", - "max_tokens": 0, - "repetition_penalty": 0, - "temperature": 0, - "top_k": 0, - "top_p": 0, - }, - "tool_choice": "auto", - "tool_prompt_format": "json", - "tools": [ - { - "tool_name": "brave_search", - "description": "description", - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "tool_name": "brave_search", - "description": "description", - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "tool_name": "brave_search", - "description": "description", - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, + [ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", }, ], + [ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + ], + ], + model="model", + logprobs={"top_k": 0}, + sampling_params={ + "strategy": "greedy", + "max_tokens": 0, + "repetition_penalty": 0, + "temperature": 0, + "top_k": 0, + "top_p": 0, }, + tool_choice="auto", + tool_prompt_format="json", + tools=[ + { + "tool_name": "brave_search", + "description": "description", + "parameters": { + "foo": { + "param_type": "param_type", + "description": "description", + "required": True, + } + }, + }, + 
{ + "tool_name": "brave_search", + "description": "description", + "parameters": { + "foo": { + "param_type": "param_type", + "description": "description", + "required": True, + } + }, + }, + { + "tool_name": "brave_search", + "description": "description", + "parameters": { + "foo": { + "param_type": "param_type", + "description": "description", + "required": True, + } + }, + }, + ], ) assert_matches_type(BatchChatCompletion, batch_inference, path=["response"]) @parametrize def test_raw_response_chat_completion(self, client: LlamaStack) -> None: response = client.batch_inference.with_raw_response.chat_completion( - request={ - "messages_batch": [ - [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], + messages_batch=[ + [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, ], - "model": "model", - }, + [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + ], + model="model", ) assert response.is_closed is True @@ -241,53 +235,51 @@ def test_raw_response_chat_completion(self, client: LlamaStack) -> None: @parametrize def test_streaming_response_chat_completion(self, client: LlamaStack) -> None: with client.batch_inference.with_streaming_response.chat_completion( - request={ - "messages_batch": [ - [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], + messages_batch=[ + [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, ], - "model": "model", - }, + [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + ], + model="model", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -300,28 +292,24 @@ def test_streaming_response_chat_completion(self, client: LlamaStack) -> None: @parametrize def test_method_completion(self, client: LlamaStack) -> None: batch_inference = client.batch_inference.completion( - request={ - "content_batch": ["string", "string", "string"], - "model": "model", - }, + content_batch=["string", "string", "string"], + model="model", ) 
assert_matches_type(BatchCompletion, batch_inference, path=["response"]) @parametrize def test_method_completion_with_all_params(self, client: LlamaStack) -> None: batch_inference = client.batch_inference.completion( - request={ - "content_batch": ["string", "string", "string"], - "model": "model", - "logprobs": {"top_k": 0}, - "sampling_params": { - "strategy": "greedy", - "max_tokens": 0, - "repetition_penalty": 0, - "temperature": 0, - "top_k": 0, - "top_p": 0, - }, + content_batch=["string", "string", "string"], + model="model", + logprobs={"top_k": 0}, + sampling_params={ + "strategy": "greedy", + "max_tokens": 0, + "repetition_penalty": 0, + "temperature": 0, + "top_k": 0, + "top_p": 0, }, ) assert_matches_type(BatchCompletion, batch_inference, path=["response"]) @@ -329,10 +317,8 @@ def test_method_completion_with_all_params(self, client: LlamaStack) -> None: @parametrize def test_raw_response_completion(self, client: LlamaStack) -> None: response = client.batch_inference.with_raw_response.completion( - request={ - "content_batch": ["string", "string", "string"], - "model": "model", - }, + content_batch=["string", "string", "string"], + model="model", ) assert response.is_closed is True @@ -343,10 +329,8 @@ def test_raw_response_completion(self, client: LlamaStack) -> None: @parametrize def test_streaming_response_completion(self, client: LlamaStack) -> None: with client.batch_inference.with_streaming_response.completion( - request={ - "content_batch": ["string", "string", "string"], - "model": "model", - }, + content_batch=["string", "string", "string"], + model="model", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -363,214 +347,208 @@ class TestAsyncBatchInference: @parametrize async def test_method_chat_completion(self, async_client: AsyncLlamaStack) -> None: batch_inference = await async_client.batch_inference.chat_completion( - request={ - "messages_batch": [ - [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], + messages_batch=[ + [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, ], - "model": "model", - }, + [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + ], + model="model", ) assert_matches_type(BatchChatCompletion, batch_inference, path=["response"]) @parametrize async def test_method_chat_completion_with_all_params(self, async_client: AsyncLlamaStack) -> None: batch_inference = await async_client.batch_inference.chat_completion( - request={ - "messages_batch": [ - [ - { - "content": "string", - "role": "user", - "context": "string", - }, - { - "content": "string", - "role": "user", - "context": "string", - }, - { - "content": "string", - "role": "user", - "context": "string", - }, - ], - [ - { - "content": "string", - 
"role": "user", - "context": "string", - }, - { - "content": "string", - "role": "user", - "context": "string", - }, - { - "content": "string", - "role": "user", - "context": "string", - }, - ], - [ - { - "content": "string", - "role": "user", - "context": "string", - }, - { - "content": "string", - "role": "user", - "context": "string", - }, - { - "content": "string", - "role": "user", - "context": "string", - }, - ], + messages_batch=[ + [ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, ], - "model": "model", - "logprobs": {"top_k": 0}, - "sampling_params": { - "strategy": "greedy", - "max_tokens": 0, - "repetition_penalty": 0, - "temperature": 0, - "top_k": 0, - "top_p": 0, - }, - "tool_choice": "auto", - "tool_prompt_format": "json", - "tools": [ - { - "tool_name": "brave_search", - "description": "description", - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "tool_name": "brave_search", - "description": "description", - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, - }, - { - "tool_name": "brave_search", - "description": "description", - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, + [ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", }, ], + [ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + ], + ], + model="model", + logprobs={"top_k": 0}, + sampling_params={ + "strategy": "greedy", + "max_tokens": 0, + "repetition_penalty": 0, + "temperature": 0, + "top_k": 0, + "top_p": 0, }, + tool_choice="auto", + tool_prompt_format="json", + tools=[ + { + "tool_name": "brave_search", + "description": "description", + "parameters": { + "foo": { + "param_type": "param_type", + "description": "description", + "required": True, + } + }, + }, + { + "tool_name": "brave_search", + "description": "description", + "parameters": { + "foo": { + "param_type": "param_type", + "description": "description", + "required": True, + } + }, + }, + { + "tool_name": "brave_search", + "description": "description", + "parameters": { + "foo": { + "param_type": "param_type", + "description": "description", + "required": True, + } + }, + }, + ], ) assert_matches_type(BatchChatCompletion, batch_inference, path=["response"]) @parametrize async def test_raw_response_chat_completion(self, async_client: AsyncLlamaStack) -> None: response = await async_client.batch_inference.with_raw_response.chat_completion( - request={ - "messages_batch": [ - [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, 
- ], + messages_batch=[ + [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, ], - "model": "model", - }, + [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + ], + model="model", ) assert response.is_closed is True @@ -581,53 +559,51 @@ async def test_raw_response_chat_completion(self, async_client: AsyncLlamaStack) @parametrize async def test_streaming_response_chat_completion(self, async_client: AsyncLlamaStack) -> None: async with async_client.batch_inference.with_streaming_response.chat_completion( - request={ - "messages_batch": [ - [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], + messages_batch=[ + [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, ], - "model": "model", - }, + [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + ], + model="model", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -640,28 +616,24 @@ async def test_streaming_response_chat_completion(self, async_client: AsyncLlama @parametrize async def test_method_completion(self, async_client: AsyncLlamaStack) -> None: batch_inference = await async_client.batch_inference.completion( - request={ - "content_batch": ["string", "string", "string"], - "model": "model", - }, + content_batch=["string", "string", "string"], + model="model", ) assert_matches_type(BatchCompletion, batch_inference, path=["response"]) @parametrize async def test_method_completion_with_all_params(self, async_client: AsyncLlamaStack) -> None: batch_inference = await async_client.batch_inference.completion( - request={ - "content_batch": ["string", "string", "string"], - "model": "model", - "logprobs": {"top_k": 0}, - "sampling_params": { - "strategy": "greedy", - "max_tokens": 0, - "repetition_penalty": 0, - "temperature": 0, - "top_k": 0, - "top_p": 0, - }, + content_batch=["string", "string", "string"], + model="model", + logprobs={"top_k": 0}, + sampling_params={ + "strategy": "greedy", + "max_tokens": 0, + "repetition_penalty": 0, + "temperature": 0, + "top_k": 0, + "top_p": 0, }, ) assert_matches_type(BatchCompletion, batch_inference, path=["response"]) @@ -669,10 +641,8 @@ async def test_method_completion_with_all_params(self, async_client: AsyncLlamaS @parametrize async def test_raw_response_completion(self, async_client: AsyncLlamaStack) -> None: response = await async_client.batch_inference.with_raw_response.completion( - 
request={ - "content_batch": ["string", "string", "string"], - "model": "model", - }, + content_batch=["string", "string", "string"], + model="model", ) assert response.is_closed is True @@ -683,10 +653,8 @@ async def test_raw_response_completion(self, async_client: AsyncLlamaStack) -> N @parametrize async def test_streaming_response_completion(self, async_client: AsyncLlamaStack) -> None: async with async_client.batch_inference.with_streaming_response.completion( - request={ - "content_batch": ["string", "string", "string"], - "model": "model", - }, + content_batch=["string", "string", "string"], + model="model", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" diff --git a/tests/api_resources/test_datasets.py b/tests/api_resources/test_datasets.py index 8c709b5..4aec07d 100644 --- a/tests/api_resources/test_datasets.py +++ b/tests/api_resources/test_datasets.py @@ -20,26 +20,34 @@ class TestDatasets: @parametrize def test_method_create(self, client: LlamaStack) -> None: dataset = client.datasets.create( - request={ - "dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, - "uuid": "uuid", + dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", }, + uuid="uuid", + ) + assert dataset is None + + @parametrize + def test_method_create_with_all_params(self, client: LlamaStack) -> None: + dataset = client.datasets.create( + dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", + "metadata": {"foo": True}, + }, + uuid="uuid", ) assert dataset is None @parametrize def test_raw_response_create(self, client: LlamaStack) -> None: response = client.datasets.with_raw_response.create( - request={ - "dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, - "uuid": "uuid", + dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", }, + uuid="uuid", ) assert response.is_closed is True @@ -50,13 +58,11 @@ def test_raw_response_create(self, client: LlamaStack) -> None: @parametrize def test_streaming_response_create(self, client: LlamaStack) -> None: with client.datasets.with_streaming_response.create( - request={ - "dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, - "uuid": "uuid", + dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", }, + uuid="uuid", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -135,26 +141,34 @@ class TestAsyncDatasets: @parametrize async def test_method_create(self, async_client: AsyncLlamaStack) -> None: dataset = await async_client.datasets.create( - request={ - "dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, - "uuid": "uuid", + dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", + }, + uuid="uuid", + ) + assert dataset is None + + @parametrize + async def test_method_create_with_all_params(self, async_client: AsyncLlamaStack) -> None: + dataset = await async_client.datasets.create( + dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", + "metadata": {"foo": True}, }, + uuid="uuid", ) assert dataset is None @parametrize async def test_raw_response_create(self, async_client: AsyncLlamaStack) -> None: response = await async_client.datasets.with_raw_response.create( - request={ - "dataset": { - "columns": {"foo": "dialog"}, - 
"content_url": "https://example.com", - }, - "uuid": "uuid", + dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", }, + uuid="uuid", ) assert response.is_closed is True @@ -165,13 +179,11 @@ async def test_raw_response_create(self, async_client: AsyncLlamaStack) -> None: @parametrize async def test_streaming_response_create(self, async_client: AsyncLlamaStack) -> None: async with async_client.datasets.with_streaming_response.create( - request={ - "dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, - "uuid": "uuid", + dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", }, + uuid="uuid", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" diff --git a/tests/api_resources/test_evaluations.py b/tests/api_resources/test_evaluations.py index 9ba95df..dbdf834 100644 --- a/tests/api_resources/test_evaluations.py +++ b/tests/api_resources/test_evaluations.py @@ -20,32 +20,14 @@ class TestEvaluations: @parametrize def test_method_summarization(self, client: LlamaStack) -> None: evaluation = client.evaluations.summarization( - request={ - "checkpoint": {}, - "dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, - "job_uuid": "job_uuid", - "metrics": ["rouge", "bleu"], - "sampling_params": {"strategy": "greedy"}, - }, + metrics=["rouge", "bleu"], ) assert_matches_type(EvaluationJob, evaluation, path=["response"]) @parametrize def test_raw_response_summarization(self, client: LlamaStack) -> None: response = client.evaluations.with_raw_response.summarization( - request={ - "checkpoint": {}, - "dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, - "job_uuid": "job_uuid", - "metrics": ["rouge", "bleu"], - "sampling_params": {"strategy": "greedy"}, - }, + metrics=["rouge", "bleu"], ) assert response.is_closed is True @@ -56,16 +38,7 @@ def test_raw_response_summarization(self, client: LlamaStack) -> None: @parametrize def test_streaming_response_summarization(self, client: LlamaStack) -> None: with client.evaluations.with_streaming_response.summarization( - request={ - "checkpoint": {}, - "dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, - "job_uuid": "job_uuid", - "metrics": ["rouge", "bleu"], - "sampling_params": {"strategy": "greedy"}, - }, + metrics=["rouge", "bleu"], ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -78,32 +51,14 @@ def test_streaming_response_summarization(self, client: LlamaStack) -> None: @parametrize def test_method_text_generation(self, client: LlamaStack) -> None: evaluation = client.evaluations.text_generation( - request={ - "checkpoint": {}, - "dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, - "job_uuid": "job_uuid", - "metrics": ["perplexity", "rouge", "bleu"], - "sampling_params": {"strategy": "greedy"}, - }, + metrics=["perplexity", "rouge", "bleu"], ) assert_matches_type(EvaluationJob, evaluation, path=["response"]) @parametrize def test_raw_response_text_generation(self, client: LlamaStack) -> None: response = client.evaluations.with_raw_response.text_generation( - request={ - "checkpoint": {}, - "dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, - "job_uuid": "job_uuid", - "metrics": ["perplexity", "rouge", "bleu"], - "sampling_params": {"strategy": 
"greedy"}, - }, + metrics=["perplexity", "rouge", "bleu"], ) assert response.is_closed is True @@ -114,16 +69,7 @@ def test_raw_response_text_generation(self, client: LlamaStack) -> None: @parametrize def test_streaming_response_text_generation(self, client: LlamaStack) -> None: with client.evaluations.with_streaming_response.text_generation( - request={ - "checkpoint": {}, - "dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, - "job_uuid": "job_uuid", - "metrics": ["perplexity", "rouge", "bleu"], - "sampling_params": {"strategy": "greedy"}, - }, + metrics=["perplexity", "rouge", "bleu"], ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -140,32 +86,14 @@ class TestAsyncEvaluations: @parametrize async def test_method_summarization(self, async_client: AsyncLlamaStack) -> None: evaluation = await async_client.evaluations.summarization( - request={ - "checkpoint": {}, - "dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, - "job_uuid": "job_uuid", - "metrics": ["rouge", "bleu"], - "sampling_params": {"strategy": "greedy"}, - }, + metrics=["rouge", "bleu"], ) assert_matches_type(EvaluationJob, evaluation, path=["response"]) @parametrize async def test_raw_response_summarization(self, async_client: AsyncLlamaStack) -> None: response = await async_client.evaluations.with_raw_response.summarization( - request={ - "checkpoint": {}, - "dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, - "job_uuid": "job_uuid", - "metrics": ["rouge", "bleu"], - "sampling_params": {"strategy": "greedy"}, - }, + metrics=["rouge", "bleu"], ) assert response.is_closed is True @@ -176,16 +104,7 @@ async def test_raw_response_summarization(self, async_client: AsyncLlamaStack) - @parametrize async def test_streaming_response_summarization(self, async_client: AsyncLlamaStack) -> None: async with async_client.evaluations.with_streaming_response.summarization( - request={ - "checkpoint": {}, - "dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, - "job_uuid": "job_uuid", - "metrics": ["rouge", "bleu"], - "sampling_params": {"strategy": "greedy"}, - }, + metrics=["rouge", "bleu"], ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -198,32 +117,14 @@ async def test_streaming_response_summarization(self, async_client: AsyncLlamaSt @parametrize async def test_method_text_generation(self, async_client: AsyncLlamaStack) -> None: evaluation = await async_client.evaluations.text_generation( - request={ - "checkpoint": {}, - "dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, - "job_uuid": "job_uuid", - "metrics": ["perplexity", "rouge", "bleu"], - "sampling_params": {"strategy": "greedy"}, - }, + metrics=["perplexity", "rouge", "bleu"], ) assert_matches_type(EvaluationJob, evaluation, path=["response"]) @parametrize async def test_raw_response_text_generation(self, async_client: AsyncLlamaStack) -> None: response = await async_client.evaluations.with_raw_response.text_generation( - request={ - "checkpoint": {}, - "dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, - "job_uuid": "job_uuid", - "metrics": ["perplexity", "rouge", "bleu"], - "sampling_params": {"strategy": "greedy"}, - }, + metrics=["perplexity", "rouge", "bleu"], ) assert response.is_closed is True @@ -234,16 
+135,7 @@ async def test_raw_response_text_generation(self, async_client: AsyncLlamaStack) @parametrize async def test_streaming_response_text_generation(self, async_client: AsyncLlamaStack) -> None: async with async_client.evaluations.with_streaming_response.text_generation( - request={ - "checkpoint": {}, - "dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, - "job_uuid": "job_uuid", - "metrics": ["perplexity", "rouge", "bleu"], - "sampling_params": {"strategy": "greedy"}, - }, + metrics=["perplexity", "rouge", "bleu"], ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" diff --git a/tests/api_resources/test_experiments.py b/tests/api_resources/test_experiments.py deleted file mode 100644 index 0b988f0..0000000 --- a/tests/api_resources/test_experiments.py +++ /dev/null @@ -1,385 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -from __future__ import annotations - -import os -from typing import Any, cast - -import pytest - -from llama_stack import LlamaStack, AsyncLlamaStack -from tests.utils import assert_matches_type -from llama_stack.types import ( - Experiment, -) -from llama_stack.types.shared import Run - -base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") - - -class TestExperiments: - parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) - - @parametrize - def test_method_create(self, client: LlamaStack) -> None: - experiment = client.experiments.create( - request={"name": "name"}, - ) - assert_matches_type(Experiment, experiment, path=["response"]) - - @parametrize - def test_method_create_with_all_params(self, client: LlamaStack) -> None: - experiment = client.experiments.create( - request={ - "name": "name", - "metadata": {"foo": True}, - }, - ) - assert_matches_type(Experiment, experiment, path=["response"]) - - @parametrize - def test_raw_response_create(self, client: LlamaStack) -> None: - response = client.experiments.with_raw_response.create( - request={"name": "name"}, - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - experiment = response.parse() - assert_matches_type(Experiment, experiment, path=["response"]) - - @parametrize - def test_streaming_response_create(self, client: LlamaStack) -> None: - with client.experiments.with_streaming_response.create( - request={"name": "name"}, - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - experiment = response.parse() - assert_matches_type(Experiment, experiment, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @parametrize - def test_method_retrieve(self, client: LlamaStack) -> None: - experiment = client.experiments.retrieve( - experiment_id="experiment_id", - ) - assert_matches_type(Experiment, experiment, path=["response"]) - - @parametrize - def test_raw_response_retrieve(self, client: LlamaStack) -> None: - response = client.experiments.with_raw_response.retrieve( - experiment_id="experiment_id", - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - experiment = response.parse() - assert_matches_type(Experiment, experiment, path=["response"]) - - @parametrize - def test_streaming_response_retrieve(self, client: LlamaStack) -> None: - with 
client.experiments.with_streaming_response.retrieve( - experiment_id="experiment_id", - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - experiment = response.parse() - assert_matches_type(Experiment, experiment, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @parametrize - def test_method_update(self, client: LlamaStack) -> None: - experiment = client.experiments.update( - request={"experiment_id": "experiment_id"}, - ) - assert_matches_type(Experiment, experiment, path=["response"]) - - @parametrize - def test_method_update_with_all_params(self, client: LlamaStack) -> None: - experiment = client.experiments.update( - request={ - "experiment_id": "experiment_id", - "metadata": {"foo": True}, - "status": "not_started", - }, - ) - assert_matches_type(Experiment, experiment, path=["response"]) - - @parametrize - def test_raw_response_update(self, client: LlamaStack) -> None: - response = client.experiments.with_raw_response.update( - request={"experiment_id": "experiment_id"}, - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - experiment = response.parse() - assert_matches_type(Experiment, experiment, path=["response"]) - - @parametrize - def test_streaming_response_update(self, client: LlamaStack) -> None: - with client.experiments.with_streaming_response.update( - request={"experiment_id": "experiment_id"}, - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - experiment = response.parse() - assert_matches_type(Experiment, experiment, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @parametrize - def test_method_list(self, client: LlamaStack) -> None: - experiment = client.experiments.list() - assert_matches_type(Experiment, experiment, path=["response"]) - - @parametrize - def test_raw_response_list(self, client: LlamaStack) -> None: - response = client.experiments.with_raw_response.list() - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - experiment = response.parse() - assert_matches_type(Experiment, experiment, path=["response"]) - - @parametrize - def test_streaming_response_list(self, client: LlamaStack) -> None: - with client.experiments.with_streaming_response.list() as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - experiment = response.parse() - assert_matches_type(Experiment, experiment, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @parametrize - def test_method_create_run(self, client: LlamaStack) -> None: - experiment = client.experiments.create_run( - request={"experiment_id": "experiment_id"}, - ) - assert_matches_type(Run, experiment, path=["response"]) - - @parametrize - def test_method_create_run_with_all_params(self, client: LlamaStack) -> None: - experiment = client.experiments.create_run( - request={ - "experiment_id": "experiment_id", - "metadata": {"foo": True}, - }, - ) - assert_matches_type(Run, experiment, path=["response"]) - - @parametrize - def test_raw_response_create_run(self, client: LlamaStack) -> None: - response = client.experiments.with_raw_response.create_run( - request={"experiment_id": "experiment_id"}, - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - 
experiment = response.parse() - assert_matches_type(Run, experiment, path=["response"]) - - @parametrize - def test_streaming_response_create_run(self, client: LlamaStack) -> None: - with client.experiments.with_streaming_response.create_run( - request={"experiment_id": "experiment_id"}, - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - experiment = response.parse() - assert_matches_type(Run, experiment, path=["response"]) - - assert cast(Any, response.is_closed) is True - - -class TestAsyncExperiments: - parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) - - @parametrize - async def test_method_create(self, async_client: AsyncLlamaStack) -> None: - experiment = await async_client.experiments.create( - request={"name": "name"}, - ) - assert_matches_type(Experiment, experiment, path=["response"]) - - @parametrize - async def test_method_create_with_all_params(self, async_client: AsyncLlamaStack) -> None: - experiment = await async_client.experiments.create( - request={ - "name": "name", - "metadata": {"foo": True}, - }, - ) - assert_matches_type(Experiment, experiment, path=["response"]) - - @parametrize - async def test_raw_response_create(self, async_client: AsyncLlamaStack) -> None: - response = await async_client.experiments.with_raw_response.create( - request={"name": "name"}, - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - experiment = await response.parse() - assert_matches_type(Experiment, experiment, path=["response"]) - - @parametrize - async def test_streaming_response_create(self, async_client: AsyncLlamaStack) -> None: - async with async_client.experiments.with_streaming_response.create( - request={"name": "name"}, - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - experiment = await response.parse() - assert_matches_type(Experiment, experiment, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @parametrize - async def test_method_retrieve(self, async_client: AsyncLlamaStack) -> None: - experiment = await async_client.experiments.retrieve( - experiment_id="experiment_id", - ) - assert_matches_type(Experiment, experiment, path=["response"]) - - @parametrize - async def test_raw_response_retrieve(self, async_client: AsyncLlamaStack) -> None: - response = await async_client.experiments.with_raw_response.retrieve( - experiment_id="experiment_id", - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - experiment = await response.parse() - assert_matches_type(Experiment, experiment, path=["response"]) - - @parametrize - async def test_streaming_response_retrieve(self, async_client: AsyncLlamaStack) -> None: - async with async_client.experiments.with_streaming_response.retrieve( - experiment_id="experiment_id", - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - experiment = await response.parse() - assert_matches_type(Experiment, experiment, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @parametrize - async def test_method_update(self, async_client: AsyncLlamaStack) -> None: - experiment = await async_client.experiments.update( - request={"experiment_id": "experiment_id"}, - ) - assert_matches_type(Experiment, 
experiment, path=["response"]) - - @parametrize - async def test_method_update_with_all_params(self, async_client: AsyncLlamaStack) -> None: - experiment = await async_client.experiments.update( - request={ - "experiment_id": "experiment_id", - "metadata": {"foo": True}, - "status": "not_started", - }, - ) - assert_matches_type(Experiment, experiment, path=["response"]) - - @parametrize - async def test_raw_response_update(self, async_client: AsyncLlamaStack) -> None: - response = await async_client.experiments.with_raw_response.update( - request={"experiment_id": "experiment_id"}, - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - experiment = await response.parse() - assert_matches_type(Experiment, experiment, path=["response"]) - - @parametrize - async def test_streaming_response_update(self, async_client: AsyncLlamaStack) -> None: - async with async_client.experiments.with_streaming_response.update( - request={"experiment_id": "experiment_id"}, - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - experiment = await response.parse() - assert_matches_type(Experiment, experiment, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @parametrize - async def test_method_list(self, async_client: AsyncLlamaStack) -> None: - experiment = await async_client.experiments.list() - assert_matches_type(Experiment, experiment, path=["response"]) - - @parametrize - async def test_raw_response_list(self, async_client: AsyncLlamaStack) -> None: - response = await async_client.experiments.with_raw_response.list() - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - experiment = await response.parse() - assert_matches_type(Experiment, experiment, path=["response"]) - - @parametrize - async def test_streaming_response_list(self, async_client: AsyncLlamaStack) -> None: - async with async_client.experiments.with_streaming_response.list() as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - experiment = await response.parse() - assert_matches_type(Experiment, experiment, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @parametrize - async def test_method_create_run(self, async_client: AsyncLlamaStack) -> None: - experiment = await async_client.experiments.create_run( - request={"experiment_id": "experiment_id"}, - ) - assert_matches_type(Run, experiment, path=["response"]) - - @parametrize - async def test_method_create_run_with_all_params(self, async_client: AsyncLlamaStack) -> None: - experiment = await async_client.experiments.create_run( - request={ - "experiment_id": "experiment_id", - "metadata": {"foo": True}, - }, - ) - assert_matches_type(Run, experiment, path=["response"]) - - @parametrize - async def test_raw_response_create_run(self, async_client: AsyncLlamaStack) -> None: - response = await async_client.experiments.with_raw_response.create_run( - request={"experiment_id": "experiment_id"}, - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - experiment = await response.parse() - assert_matches_type(Run, experiment, path=["response"]) - - @parametrize - async def test_streaming_response_create_run(self, async_client: AsyncLlamaStack) -> None: - async with async_client.experiments.with_streaming_response.create_run( - 
request={"experiment_id": "experiment_id"}, - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - experiment = await response.parse() - assert_matches_type(Run, experiment, path=["response"]) - - assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_inference.py b/tests/api_resources/test_inference.py index 149d1d7..a7b80b6 100644 --- a/tests/api_resources/test_inference.py +++ b/tests/api_resources/test_inference.py @@ -10,8 +10,8 @@ from llama_stack import LlamaStack, AsyncLlamaStack from tests.utils import assert_matches_type from llama_stack.types import ( - CompletionStreamChunk, - ChatCompletionStreamChunk, + InferenceCompletionResponse, + InferenceChatCompletionResponse, ) base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -21,214 +21,344 @@ class TestInference: parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize - def test_method_chat_completion(self, client: LlamaStack) -> None: + def test_method_chat_completion_overload_1(self, client: LlamaStack) -> None: inference = client.inference.chat_completion( - request={ - "messages": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "model": "model", - }, + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + model="model", ) - assert_matches_type(ChatCompletionStreamChunk, inference, path=["response"]) + assert_matches_type(InferenceChatCompletionResponse, inference, path=["response"]) @parametrize - def test_method_chat_completion_with_all_params(self, client: LlamaStack) -> None: + def test_method_chat_completion_with_all_params_overload_1(self, client: LlamaStack) -> None: inference = client.inference.chat_completion( - request={ - "messages": [ - { - "content": "string", - "role": "user", - "context": "string", - }, - { - "content": "string", - "role": "user", - "context": "string", - }, - { - "content": "string", - "role": "user", - "context": "string", - }, - ], - "model": "model", - "logprobs": {"top_k": 0}, - "sampling_params": { - "strategy": "greedy", - "max_tokens": 0, - "repetition_penalty": 0, - "temperature": 0, - "top_k": 0, - "top_p": 0, - }, - "stream": True, - "tool_choice": "auto", - "tool_prompt_format": "json", - "tools": [ - { - "tool_name": "brave_search", - "description": "description", - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, + messages=[ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + ], + model="model", + logprobs={"top_k": 0}, + sampling_params={ + "strategy": "greedy", + "max_tokens": 0, + "repetition_penalty": 0, + "temperature": 0, + "top_k": 0, + "top_p": 0, + }, + stream=False, + tool_choice="auto", + tool_prompt_format="json", + tools=[ + { + "tool_name": "brave_search", + "description": "description", + "parameters": { + "foo": { + "param_type": "param_type", + "description": "description", + "required": True, + } }, - { - "tool_name": "brave_search", - "description": "description", - "parameters": { - "foo": { - "param_type": "param_type", - 
"description": "description", - "required": True, - } - }, + }, + { + "tool_name": "brave_search", + "description": "description", + "parameters": { + "foo": { + "param_type": "param_type", + "description": "description", + "required": True, + } }, - { - "tool_name": "brave_search", - "description": "description", - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, + }, + { + "tool_name": "brave_search", + "description": "description", + "parameters": { + "foo": { + "param_type": "param_type", + "description": "description", + "required": True, + } }, - ], - }, + }, + ], ) - assert_matches_type(ChatCompletionStreamChunk, inference, path=["response"]) + assert_matches_type(InferenceChatCompletionResponse, inference, path=["response"]) @parametrize - def test_raw_response_chat_completion(self, client: LlamaStack) -> None: + def test_raw_response_chat_completion_overload_1(self, client: LlamaStack) -> None: response = client.inference.with_raw_response.chat_completion( - request={ - "messages": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "model": "model", - }, + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + model="model", ) assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" inference = response.parse() - assert_matches_type(ChatCompletionStreamChunk, inference, path=["response"]) + assert_matches_type(InferenceChatCompletionResponse, inference, path=["response"]) @parametrize - def test_streaming_response_chat_completion(self, client: LlamaStack) -> None: + def test_streaming_response_chat_completion_overload_1(self, client: LlamaStack) -> None: with client.inference.with_streaming_response.chat_completion( - request={ - "messages": [ - { - "content": "string", - "role": "user", + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + model="model", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + inference = response.parse() + assert_matches_type(InferenceChatCompletionResponse, inference, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_method_chat_completion_overload_2(self, client: LlamaStack) -> None: + inference_stream = client.inference.chat_completion( + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + model="model", + stream=True, + ) + inference_stream.response.close() + + @parametrize + def test_method_chat_completion_with_all_params_overload_2(self, client: LlamaStack) -> None: + inference_stream = client.inference.chat_completion( + messages=[ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + ], + model="model", + stream=True, + logprobs={"top_k": 0}, + sampling_params={ + "strategy": "greedy", + "max_tokens": 0, + "repetition_penalty": 0, + "temperature": 0, + "top_k": 0, + "top_p": 0, + }, + 
tool_choice="auto", + tool_prompt_format="json", + tools=[ + { + "tool_name": "brave_search", + "description": "description", + "parameters": { + "foo": { + "param_type": "param_type", + "description": "description", + "required": True, + } }, - { - "content": "string", - "role": "user", + }, + { + "tool_name": "brave_search", + "description": "description", + "parameters": { + "foo": { + "param_type": "param_type", + "description": "description", + "required": True, + } }, - { - "content": "string", - "role": "user", + }, + { + "tool_name": "brave_search", + "description": "description", + "parameters": { + "foo": { + "param_type": "param_type", + "description": "description", + "required": True, + } }, - ], - "model": "model", - }, + }, + ], + ) + inference_stream.response.close() + + @parametrize + def test_raw_response_chat_completion_overload_2(self, client: LlamaStack) -> None: + response = client.inference.with_raw_response.chat_completion( + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + model="model", + stream=True, + ) + + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + stream = response.parse() + stream.close() + + @parametrize + def test_streaming_response_chat_completion_overload_2(self, client: LlamaStack) -> None: + with client.inference.with_streaming_response.chat_completion( + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + model="model", + stream=True, ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" - inference = response.parse() - assert_matches_type(ChatCompletionStreamChunk, inference, path=["response"]) + stream = response.parse() + stream.close() assert cast(Any, response.is_closed) is True @parametrize def test_method_completion(self, client: LlamaStack) -> None: inference = client.inference.completion( - request={ - "content": "string", - "model": "model", - }, + content="string", + model="model", ) - assert_matches_type(CompletionStreamChunk, inference, path=["response"]) + assert_matches_type(InferenceCompletionResponse, inference, path=["response"]) @parametrize def test_method_completion_with_all_params(self, client: LlamaStack) -> None: inference = client.inference.completion( - request={ - "content": "string", - "model": "model", - "logprobs": {"top_k": 0}, - "sampling_params": { - "strategy": "greedy", - "max_tokens": 0, - "repetition_penalty": 0, - "temperature": 0, - "top_k": 0, - "top_p": 0, - }, - "stream": True, + content="string", + model="model", + logprobs={"top_k": 0}, + sampling_params={ + "strategy": "greedy", + "max_tokens": 0, + "repetition_penalty": 0, + "temperature": 0, + "top_k": 0, + "top_p": 0, }, + stream=True, ) - assert_matches_type(CompletionStreamChunk, inference, path=["response"]) + assert_matches_type(InferenceCompletionResponse, inference, path=["response"]) @parametrize def test_raw_response_completion(self, client: LlamaStack) -> None: response = client.inference.with_raw_response.completion( - request={ - "content": "string", - "model": "model", - }, + content="string", + model="model", ) assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" inference = response.parse() - assert_matches_type(CompletionStreamChunk, inference, path=["response"]) 
+ assert_matches_type(InferenceCompletionResponse, inference, path=["response"]) @parametrize def test_streaming_response_completion(self, client: LlamaStack) -> None: with client.inference.with_streaming_response.completion( - request={ - "content": "string", - "model": "model", - }, + content="string", + model="model", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" inference = response.parse() - assert_matches_type(CompletionStreamChunk, inference, path=["response"]) + assert_matches_type(InferenceCompletionResponse, inference, path=["response"]) assert cast(Any, response.is_closed) is True @@ -237,213 +367,343 @@ class TestAsyncInference: parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize - async def test_method_chat_completion(self, async_client: AsyncLlamaStack) -> None: + async def test_method_chat_completion_overload_1(self, async_client: AsyncLlamaStack) -> None: inference = await async_client.inference.chat_completion( - request={ - "messages": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "model": "model", - }, + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + model="model", ) - assert_matches_type(ChatCompletionStreamChunk, inference, path=["response"]) + assert_matches_type(InferenceChatCompletionResponse, inference, path=["response"]) @parametrize - async def test_method_chat_completion_with_all_params(self, async_client: AsyncLlamaStack) -> None: + async def test_method_chat_completion_with_all_params_overload_1(self, async_client: AsyncLlamaStack) -> None: inference = await async_client.inference.chat_completion( - request={ - "messages": [ - { - "content": "string", - "role": "user", - "context": "string", - }, - { - "content": "string", - "role": "user", - "context": "string", - }, - { - "content": "string", - "role": "user", - "context": "string", - }, - ], - "model": "model", - "logprobs": {"top_k": 0}, - "sampling_params": { - "strategy": "greedy", - "max_tokens": 0, - "repetition_penalty": 0, - "temperature": 0, - "top_k": 0, - "top_p": 0, - }, - "stream": True, - "tool_choice": "auto", - "tool_prompt_format": "json", - "tools": [ - { - "tool_name": "brave_search", - "description": "description", - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, + messages=[ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + ], + model="model", + logprobs={"top_k": 0}, + sampling_params={ + "strategy": "greedy", + "max_tokens": 0, + "repetition_penalty": 0, + "temperature": 0, + "top_k": 0, + "top_p": 0, + }, + stream=False, + tool_choice="auto", + tool_prompt_format="json", + tools=[ + { + "tool_name": "brave_search", + "description": "description", + "parameters": { + "foo": { + "param_type": "param_type", + "description": "description", + "required": True, + } }, - { - "tool_name": "brave_search", - "description": "description", - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, + }, + { + "tool_name": 
"brave_search", + "description": "description", + "parameters": { + "foo": { + "param_type": "param_type", + "description": "description", + "required": True, + } }, - { - "tool_name": "brave_search", - "description": "description", - "parameters": { - "foo": { - "param_type": "param_type", - "description": "description", - "required": True, - } - }, + }, + { + "tool_name": "brave_search", + "description": "description", + "parameters": { + "foo": { + "param_type": "param_type", + "description": "description", + "required": True, + } }, - ], - }, + }, + ], ) - assert_matches_type(ChatCompletionStreamChunk, inference, path=["response"]) + assert_matches_type(InferenceChatCompletionResponse, inference, path=["response"]) @parametrize - async def test_raw_response_chat_completion(self, async_client: AsyncLlamaStack) -> None: + async def test_raw_response_chat_completion_overload_1(self, async_client: AsyncLlamaStack) -> None: response = await async_client.inference.with_raw_response.chat_completion( - request={ - "messages": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "model": "model", - }, + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + model="model", ) assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" inference = await response.parse() - assert_matches_type(ChatCompletionStreamChunk, inference, path=["response"]) + assert_matches_type(InferenceChatCompletionResponse, inference, path=["response"]) @parametrize - async def test_streaming_response_chat_completion(self, async_client: AsyncLlamaStack) -> None: + async def test_streaming_response_chat_completion_overload_1(self, async_client: AsyncLlamaStack) -> None: async with async_client.inference.with_streaming_response.chat_completion( - request={ - "messages": [ - { - "content": "string", - "role": "user", + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + model="model", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + inference = await response.parse() + assert_matches_type(InferenceChatCompletionResponse, inference, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_method_chat_completion_overload_2(self, async_client: AsyncLlamaStack) -> None: + inference_stream = await async_client.inference.chat_completion( + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + model="model", + stream=True, + ) + await inference_stream.response.aclose() + + @parametrize + async def test_method_chat_completion_with_all_params_overload_2(self, async_client: AsyncLlamaStack) -> None: + inference_stream = await async_client.inference.chat_completion( + messages=[ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + ], + model="model", + stream=True, + logprobs={"top_k": 0}, + sampling_params={ + "strategy": "greedy", + "max_tokens": 0, + 
"repetition_penalty": 0, + "temperature": 0, + "top_k": 0, + "top_p": 0, + }, + tool_choice="auto", + tool_prompt_format="json", + tools=[ + { + "tool_name": "brave_search", + "description": "description", + "parameters": { + "foo": { + "param_type": "param_type", + "description": "description", + "required": True, + } }, - { - "content": "string", - "role": "user", + }, + { + "tool_name": "brave_search", + "description": "description", + "parameters": { + "foo": { + "param_type": "param_type", + "description": "description", + "required": True, + } }, - { - "content": "string", - "role": "user", + }, + { + "tool_name": "brave_search", + "description": "description", + "parameters": { + "foo": { + "param_type": "param_type", + "description": "description", + "required": True, + } }, - ], - "model": "model", - }, + }, + ], + ) + await inference_stream.response.aclose() + + @parametrize + async def test_raw_response_chat_completion_overload_2(self, async_client: AsyncLlamaStack) -> None: + response = await async_client.inference.with_raw_response.chat_completion( + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + model="model", + stream=True, + ) + + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + stream = await response.parse() + await stream.close() + + @parametrize + async def test_streaming_response_chat_completion_overload_2(self, async_client: AsyncLlamaStack) -> None: + async with async_client.inference.with_streaming_response.chat_completion( + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + model="model", + stream=True, ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" - inference = await response.parse() - assert_matches_type(ChatCompletionStreamChunk, inference, path=["response"]) + stream = await response.parse() + await stream.close() assert cast(Any, response.is_closed) is True @parametrize async def test_method_completion(self, async_client: AsyncLlamaStack) -> None: inference = await async_client.inference.completion( - request={ - "content": "string", - "model": "model", - }, + content="string", + model="model", ) - assert_matches_type(CompletionStreamChunk, inference, path=["response"]) + assert_matches_type(InferenceCompletionResponse, inference, path=["response"]) @parametrize async def test_method_completion_with_all_params(self, async_client: AsyncLlamaStack) -> None: inference = await async_client.inference.completion( - request={ - "content": "string", - "model": "model", - "logprobs": {"top_k": 0}, - "sampling_params": { - "strategy": "greedy", - "max_tokens": 0, - "repetition_penalty": 0, - "temperature": 0, - "top_k": 0, - "top_p": 0, - }, - "stream": True, + content="string", + model="model", + logprobs={"top_k": 0}, + sampling_params={ + "strategy": "greedy", + "max_tokens": 0, + "repetition_penalty": 0, + "temperature": 0, + "top_k": 0, + "top_p": 0, }, + stream=True, ) - assert_matches_type(CompletionStreamChunk, inference, path=["response"]) + assert_matches_type(InferenceCompletionResponse, inference, path=["response"]) @parametrize async def test_raw_response_completion(self, async_client: AsyncLlamaStack) -> None: response = await async_client.inference.with_raw_response.completion( - request={ - "content": "string", - "model": 
"model", - }, + content="string", + model="model", ) assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" inference = await response.parse() - assert_matches_type(CompletionStreamChunk, inference, path=["response"]) + assert_matches_type(InferenceCompletionResponse, inference, path=["response"]) @parametrize async def test_streaming_response_completion(self, async_client: AsyncLlamaStack) -> None: async with async_client.inference.with_streaming_response.completion( - request={ - "content": "string", - "model": "model", - }, + content="string", + model="model", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" inference = await response.parse() - assert_matches_type(CompletionStreamChunk, inference, path=["response"]) + assert_matches_type(InferenceCompletionResponse, inference, path=["response"]) assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_logging.py b/tests/api_resources/test_logging.py deleted file mode 100644 index fa1d9f2..0000000 --- a/tests/api_resources/test_logging.py +++ /dev/null @@ -1,351 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -from __future__ import annotations - -import os -from typing import Any, cast - -import pytest - -from llama_stack import LlamaStack, AsyncLlamaStack -from tests.utils import assert_matches_type -from llama_stack.types import LoggingGetLogsResponse -from llama_stack._utils import parse_datetime - -base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") - - -class TestLogging: - parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) - - @parametrize - def test_method_get_logs(self, client: LlamaStack) -> None: - logging = client.logging.get_logs( - request={"query": "query"}, - ) - assert_matches_type(LoggingGetLogsResponse, logging, path=["response"]) - - @parametrize - def test_method_get_logs_with_all_params(self, client: LlamaStack) -> None: - logging = client.logging.get_logs( - request={ - "query": "query", - "filters": {"foo": True}, - }, - ) - assert_matches_type(LoggingGetLogsResponse, logging, path=["response"]) - - @parametrize - def test_raw_response_get_logs(self, client: LlamaStack) -> None: - response = client.logging.with_raw_response.get_logs( - request={"query": "query"}, - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - logging = response.parse() - assert_matches_type(LoggingGetLogsResponse, logging, path=["response"]) - - @parametrize - def test_streaming_response_get_logs(self, client: LlamaStack) -> None: - with client.logging.with_streaming_response.get_logs( - request={"query": "query"}, - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - logging = response.parse() - assert_matches_type(LoggingGetLogsResponse, logging, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @parametrize - def test_method_log_messages(self, client: LlamaStack) -> None: - logging = client.logging.log_messages( - request={ - "logs": [ - { - "additional_info": {"foo": True}, - "level": "level", - "message": "message", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - }, - { - "additional_info": {"foo": True}, - "level": "level", - "message": "message", - "timestamp": 
parse_datetime("2019-12-27T18:11:19.117Z"), - }, - { - "additional_info": {"foo": True}, - "level": "level", - "message": "message", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - }, - ] - }, - ) - assert logging is None - - @parametrize - def test_method_log_messages_with_all_params(self, client: LlamaStack) -> None: - logging = client.logging.log_messages( - request={ - "logs": [ - { - "additional_info": {"foo": True}, - "level": "level", - "message": "message", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - }, - { - "additional_info": {"foo": True}, - "level": "level", - "message": "message", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - }, - { - "additional_info": {"foo": True}, - "level": "level", - "message": "message", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - }, - ], - "run_id": "run_id", - }, - ) - assert logging is None - - @parametrize - def test_raw_response_log_messages(self, client: LlamaStack) -> None: - response = client.logging.with_raw_response.log_messages( - request={ - "logs": [ - { - "additional_info": {"foo": True}, - "level": "level", - "message": "message", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - }, - { - "additional_info": {"foo": True}, - "level": "level", - "message": "message", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - }, - { - "additional_info": {"foo": True}, - "level": "level", - "message": "message", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - }, - ] - }, - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - logging = response.parse() - assert logging is None - - @parametrize - def test_streaming_response_log_messages(self, client: LlamaStack) -> None: - with client.logging.with_streaming_response.log_messages( - request={ - "logs": [ - { - "additional_info": {"foo": True}, - "level": "level", - "message": "message", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - }, - { - "additional_info": {"foo": True}, - "level": "level", - "message": "message", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - }, - { - "additional_info": {"foo": True}, - "level": "level", - "message": "message", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - }, - ] - }, - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - logging = response.parse() - assert logging is None - - assert cast(Any, response.is_closed) is True - - -class TestAsyncLogging: - parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) - - @parametrize - async def test_method_get_logs(self, async_client: AsyncLlamaStack) -> None: - logging = await async_client.logging.get_logs( - request={"query": "query"}, - ) - assert_matches_type(LoggingGetLogsResponse, logging, path=["response"]) - - @parametrize - async def test_method_get_logs_with_all_params(self, async_client: AsyncLlamaStack) -> None: - logging = await async_client.logging.get_logs( - request={ - "query": "query", - "filters": {"foo": True}, - }, - ) - assert_matches_type(LoggingGetLogsResponse, logging, path=["response"]) - - @parametrize - async def test_raw_response_get_logs(self, async_client: AsyncLlamaStack) -> None: - response = await async_client.logging.with_raw_response.get_logs( - request={"query": "query"}, - ) - - assert response.is_closed is True - assert 
response.http_request.headers.get("X-Stainless-Lang") == "python" - logging = await response.parse() - assert_matches_type(LoggingGetLogsResponse, logging, path=["response"]) - - @parametrize - async def test_streaming_response_get_logs(self, async_client: AsyncLlamaStack) -> None: - async with async_client.logging.with_streaming_response.get_logs( - request={"query": "query"}, - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - logging = await response.parse() - assert_matches_type(LoggingGetLogsResponse, logging, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @parametrize - async def test_method_log_messages(self, async_client: AsyncLlamaStack) -> None: - logging = await async_client.logging.log_messages( - request={ - "logs": [ - { - "additional_info": {"foo": True}, - "level": "level", - "message": "message", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - }, - { - "additional_info": {"foo": True}, - "level": "level", - "message": "message", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - }, - { - "additional_info": {"foo": True}, - "level": "level", - "message": "message", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - }, - ] - }, - ) - assert logging is None - - @parametrize - async def test_method_log_messages_with_all_params(self, async_client: AsyncLlamaStack) -> None: - logging = await async_client.logging.log_messages( - request={ - "logs": [ - { - "additional_info": {"foo": True}, - "level": "level", - "message": "message", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - }, - { - "additional_info": {"foo": True}, - "level": "level", - "message": "message", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - }, - { - "additional_info": {"foo": True}, - "level": "level", - "message": "message", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - }, - ], - "run_id": "run_id", - }, - ) - assert logging is None - - @parametrize - async def test_raw_response_log_messages(self, async_client: AsyncLlamaStack) -> None: - response = await async_client.logging.with_raw_response.log_messages( - request={ - "logs": [ - { - "additional_info": {"foo": True}, - "level": "level", - "message": "message", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - }, - { - "additional_info": {"foo": True}, - "level": "level", - "message": "message", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - }, - { - "additional_info": {"foo": True}, - "level": "level", - "message": "message", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - }, - ] - }, - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - logging = await response.parse() - assert logging is None - - @parametrize - async def test_streaming_response_log_messages(self, async_client: AsyncLlamaStack) -> None: - async with async_client.logging.with_streaming_response.log_messages( - request={ - "logs": [ - { - "additional_info": {"foo": True}, - "level": "level", - "message": "message", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - }, - { - "additional_info": {"foo": True}, - "level": "level", - "message": "message", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - }, - { - "additional_info": {"foo": True}, - "level": "level", - "message": "message", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - }, - ] - }, - ) as response: - assert not 
response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - logging = await response.parse() - assert logging is None - - assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_post_training.py b/tests/api_resources/test_post_training.py index c16c1b7..588f159 100644 --- a/tests/api_resources/test_post_training.py +++ b/tests/api_resources/test_post_training.py @@ -22,41 +22,81 @@ class TestPostTraining: @parametrize def test_method_preference_optimize(self, client: LlamaStack) -> None: post_training = client.post_training.preference_optimize( - request={ - "algorithm": "dpo", - "algorithm_config": { - "epsilon": 0, - "gamma": 0, - "reward_clip": 0, - "reward_scale": 0, - }, - "dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, - "finetuned_model": "https://example.com", - "hyperparam_search_config": {"foo": True}, - "job_uuid": "job_uuid", - "logger_config": {"foo": True}, - "optimizer_config": { - "lr": 0, - "lr_min": 0, - "optimizer_type": "adam", - "weight_decay": 0, - }, - "training_config": { - "batch_size": 0, - "enable_activation_checkpointing": True, - "fsdp_cpu_offload": True, - "memory_efficient_fsdp_wrap": True, - "n_epochs": 0, - "n_iters": 0, - "shuffle": True, - }, - "validation_dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, + algorithm="dpo", + algorithm_config={ + "epsilon": 0, + "gamma": 0, + "reward_clip": 0, + "reward_scale": 0, + }, + dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", + }, + finetuned_model="https://example.com", + hyperparam_search_config={"foo": True}, + job_uuid="job_uuid", + logger_config={"foo": True}, + optimizer_config={ + "lr": 0, + "lr_min": 0, + "optimizer_type": "adam", + "weight_decay": 0, + }, + training_config={ + "batch_size": 0, + "enable_activation_checkpointing": True, + "fsdp_cpu_offload": True, + "memory_efficient_fsdp_wrap": True, + "n_epochs": 0, + "n_iters": 0, + "shuffle": True, + }, + validation_dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", + }, + ) + assert_matches_type(PostTrainingJob, post_training, path=["response"]) + + @parametrize + def test_method_preference_optimize_with_all_params(self, client: LlamaStack) -> None: + post_training = client.post_training.preference_optimize( + algorithm="dpo", + algorithm_config={ + "epsilon": 0, + "gamma": 0, + "reward_clip": 0, + "reward_scale": 0, + }, + dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", + "metadata": {"foo": True}, + }, + finetuned_model="https://example.com", + hyperparam_search_config={"foo": True}, + job_uuid="job_uuid", + logger_config={"foo": True}, + optimizer_config={ + "lr": 0, + "lr_min": 0, + "optimizer_type": "adam", + "weight_decay": 0, + }, + training_config={ + "batch_size": 0, + "enable_activation_checkpointing": True, + "fsdp_cpu_offload": True, + "memory_efficient_fsdp_wrap": True, + "n_epochs": 0, + "n_iters": 0, + "shuffle": True, + }, + validation_dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", + "metadata": {"foo": True}, }, ) assert_matches_type(PostTrainingJob, post_training, path=["response"]) @@ -64,41 +104,39 @@ def test_method_preference_optimize(self, client: LlamaStack) -> None: @parametrize def test_raw_response_preference_optimize(self, client: LlamaStack) -> None: response = client.post_training.with_raw_response.preference_optimize( - request={ - "algorithm": 
"dpo", - "algorithm_config": { - "epsilon": 0, - "gamma": 0, - "reward_clip": 0, - "reward_scale": 0, - }, - "dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, - "finetuned_model": "https://example.com", - "hyperparam_search_config": {"foo": True}, - "job_uuid": "job_uuid", - "logger_config": {"foo": True}, - "optimizer_config": { - "lr": 0, - "lr_min": 0, - "optimizer_type": "adam", - "weight_decay": 0, - }, - "training_config": { - "batch_size": 0, - "enable_activation_checkpointing": True, - "fsdp_cpu_offload": True, - "memory_efficient_fsdp_wrap": True, - "n_epochs": 0, - "n_iters": 0, - "shuffle": True, - }, - "validation_dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, + algorithm="dpo", + algorithm_config={ + "epsilon": 0, + "gamma": 0, + "reward_clip": 0, + "reward_scale": 0, + }, + dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", + }, + finetuned_model="https://example.com", + hyperparam_search_config={"foo": True}, + job_uuid="job_uuid", + logger_config={"foo": True}, + optimizer_config={ + "lr": 0, + "lr_min": 0, + "optimizer_type": "adam", + "weight_decay": 0, + }, + training_config={ + "batch_size": 0, + "enable_activation_checkpointing": True, + "fsdp_cpu_offload": True, + "memory_efficient_fsdp_wrap": True, + "n_epochs": 0, + "n_iters": 0, + "shuffle": True, + }, + validation_dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", }, ) @@ -110,41 +148,39 @@ def test_raw_response_preference_optimize(self, client: LlamaStack) -> None: @parametrize def test_streaming_response_preference_optimize(self, client: LlamaStack) -> None: with client.post_training.with_streaming_response.preference_optimize( - request={ - "algorithm": "dpo", - "algorithm_config": { - "epsilon": 0, - "gamma": 0, - "reward_clip": 0, - "reward_scale": 0, - }, - "dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, - "finetuned_model": "https://example.com", - "hyperparam_search_config": {"foo": True}, - "job_uuid": "job_uuid", - "logger_config": {"foo": True}, - "optimizer_config": { - "lr": 0, - "lr_min": 0, - "optimizer_type": "adam", - "weight_decay": 0, - }, - "training_config": { - "batch_size": 0, - "enable_activation_checkpointing": True, - "fsdp_cpu_offload": True, - "memory_efficient_fsdp_wrap": True, - "n_epochs": 0, - "n_iters": 0, - "shuffle": True, - }, - "validation_dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, + algorithm="dpo", + algorithm_config={ + "epsilon": 0, + "gamma": 0, + "reward_clip": 0, + "reward_scale": 0, + }, + dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", + }, + finetuned_model="https://example.com", + hyperparam_search_config={"foo": True}, + job_uuid="job_uuid", + logger_config={"foo": True}, + optimizer_config={ + "lr": 0, + "lr_min": 0, + "optimizer_type": "adam", + "weight_decay": 0, + }, + training_config={ + "batch_size": 0, + "enable_activation_checkpointing": True, + "fsdp_cpu_offload": True, + "memory_efficient_fsdp_wrap": True, + "n_epochs": 0, + "n_iters": 0, + "shuffle": True, + }, + validation_dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", }, ) as response: assert not response.is_closed @@ -158,42 +194,83 @@ def test_streaming_response_preference_optimize(self, client: LlamaStack) -> Non @parametrize def test_method_supervised_fine_tune(self, client: LlamaStack) -> None: 
post_training = client.post_training.supervised_fine_tune( - request={ - "algorithm": "full", - "algorithm_config": { - "alpha": 0, - "apply_lora_to_mlp": True, - "apply_lora_to_output": True, - "lora_attn_modules": ["string", "string", "string"], - "rank": 0, - }, - "dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, - "hyperparam_search_config": {"foo": True}, - "job_uuid": "job_uuid", - "logger_config": {"foo": True}, - "model": "model", - "optimizer_config": { - "lr": 0, - "lr_min": 0, - "optimizer_type": "adam", - "weight_decay": 0, - }, - "training_config": { - "batch_size": 0, - "enable_activation_checkpointing": True, - "fsdp_cpu_offload": True, - "memory_efficient_fsdp_wrap": True, - "n_epochs": 0, - "n_iters": 0, - "shuffle": True, - }, - "validation_dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, + algorithm="full", + algorithm_config={ + "alpha": 0, + "apply_lora_to_mlp": True, + "apply_lora_to_output": True, + "lora_attn_modules": ["string", "string", "string"], + "rank": 0, + }, + dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", + }, + hyperparam_search_config={"foo": True}, + job_uuid="job_uuid", + logger_config={"foo": True}, + model="model", + optimizer_config={ + "lr": 0, + "lr_min": 0, + "optimizer_type": "adam", + "weight_decay": 0, + }, + training_config={ + "batch_size": 0, + "enable_activation_checkpointing": True, + "fsdp_cpu_offload": True, + "memory_efficient_fsdp_wrap": True, + "n_epochs": 0, + "n_iters": 0, + "shuffle": True, + }, + validation_dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", + }, + ) + assert_matches_type(PostTrainingJob, post_training, path=["response"]) + + @parametrize + def test_method_supervised_fine_tune_with_all_params(self, client: LlamaStack) -> None: + post_training = client.post_training.supervised_fine_tune( + algorithm="full", + algorithm_config={ + "alpha": 0, + "apply_lora_to_mlp": True, + "apply_lora_to_output": True, + "lora_attn_modules": ["string", "string", "string"], + "rank": 0, + }, + dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", + "metadata": {"foo": True}, + }, + hyperparam_search_config={"foo": True}, + job_uuid="job_uuid", + logger_config={"foo": True}, + model="model", + optimizer_config={ + "lr": 0, + "lr_min": 0, + "optimizer_type": "adam", + "weight_decay": 0, + }, + training_config={ + "batch_size": 0, + "enable_activation_checkpointing": True, + "fsdp_cpu_offload": True, + "memory_efficient_fsdp_wrap": True, + "n_epochs": 0, + "n_iters": 0, + "shuffle": True, + }, + validation_dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", + "metadata": {"foo": True}, }, ) assert_matches_type(PostTrainingJob, post_training, path=["response"]) @@ -201,42 +278,40 @@ def test_method_supervised_fine_tune(self, client: LlamaStack) -> None: @parametrize def test_raw_response_supervised_fine_tune(self, client: LlamaStack) -> None: response = client.post_training.with_raw_response.supervised_fine_tune( - request={ - "algorithm": "full", - "algorithm_config": { - "alpha": 0, - "apply_lora_to_mlp": True, - "apply_lora_to_output": True, - "lora_attn_modules": ["string", "string", "string"], - "rank": 0, - }, - "dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, - "hyperparam_search_config": {"foo": True}, - "job_uuid": "job_uuid", - "logger_config": {"foo": True}, - "model": "model", - 
"optimizer_config": { - "lr": 0, - "lr_min": 0, - "optimizer_type": "adam", - "weight_decay": 0, - }, - "training_config": { - "batch_size": 0, - "enable_activation_checkpointing": True, - "fsdp_cpu_offload": True, - "memory_efficient_fsdp_wrap": True, - "n_epochs": 0, - "n_iters": 0, - "shuffle": True, - }, - "validation_dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, + algorithm="full", + algorithm_config={ + "alpha": 0, + "apply_lora_to_mlp": True, + "apply_lora_to_output": True, + "lora_attn_modules": ["string", "string", "string"], + "rank": 0, + }, + dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", + }, + hyperparam_search_config={"foo": True}, + job_uuid="job_uuid", + logger_config={"foo": True}, + model="model", + optimizer_config={ + "lr": 0, + "lr_min": 0, + "optimizer_type": "adam", + "weight_decay": 0, + }, + training_config={ + "batch_size": 0, + "enable_activation_checkpointing": True, + "fsdp_cpu_offload": True, + "memory_efficient_fsdp_wrap": True, + "n_epochs": 0, + "n_iters": 0, + "shuffle": True, + }, + validation_dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", }, ) @@ -248,42 +323,40 @@ def test_raw_response_supervised_fine_tune(self, client: LlamaStack) -> None: @parametrize def test_streaming_response_supervised_fine_tune(self, client: LlamaStack) -> None: with client.post_training.with_streaming_response.supervised_fine_tune( - request={ - "algorithm": "full", - "algorithm_config": { - "alpha": 0, - "apply_lora_to_mlp": True, - "apply_lora_to_output": True, - "lora_attn_modules": ["string", "string", "string"], - "rank": 0, - }, - "dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, - "hyperparam_search_config": {"foo": True}, - "job_uuid": "job_uuid", - "logger_config": {"foo": True}, - "model": "model", - "optimizer_config": { - "lr": 0, - "lr_min": 0, - "optimizer_type": "adam", - "weight_decay": 0, - }, - "training_config": { - "batch_size": 0, - "enable_activation_checkpointing": True, - "fsdp_cpu_offload": True, - "memory_efficient_fsdp_wrap": True, - "n_epochs": 0, - "n_iters": 0, - "shuffle": True, - }, - "validation_dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, + algorithm="full", + algorithm_config={ + "alpha": 0, + "apply_lora_to_mlp": True, + "apply_lora_to_output": True, + "lora_attn_modules": ["string", "string", "string"], + "rank": 0, + }, + dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", + }, + hyperparam_search_config={"foo": True}, + job_uuid="job_uuid", + logger_config={"foo": True}, + model="model", + optimizer_config={ + "lr": 0, + "lr_min": 0, + "optimizer_type": "adam", + "weight_decay": 0, + }, + training_config={ + "batch_size": 0, + "enable_activation_checkpointing": True, + "fsdp_cpu_offload": True, + "memory_efficient_fsdp_wrap": True, + "n_epochs": 0, + "n_iters": 0, + "shuffle": True, + }, + validation_dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", }, ) as response: assert not response.is_closed @@ -301,41 +374,81 @@ class TestAsyncPostTraining: @parametrize async def test_method_preference_optimize(self, async_client: AsyncLlamaStack) -> None: post_training = await async_client.post_training.preference_optimize( - request={ - "algorithm": "dpo", - "algorithm_config": { - "epsilon": 0, - "gamma": 0, - "reward_clip": 0, - "reward_scale": 0, - }, - "dataset": { - "columns": {"foo": 
"dialog"}, - "content_url": "https://example.com", - }, - "finetuned_model": "https://example.com", - "hyperparam_search_config": {"foo": True}, - "job_uuid": "job_uuid", - "logger_config": {"foo": True}, - "optimizer_config": { - "lr": 0, - "lr_min": 0, - "optimizer_type": "adam", - "weight_decay": 0, - }, - "training_config": { - "batch_size": 0, - "enable_activation_checkpointing": True, - "fsdp_cpu_offload": True, - "memory_efficient_fsdp_wrap": True, - "n_epochs": 0, - "n_iters": 0, - "shuffle": True, - }, - "validation_dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, + algorithm="dpo", + algorithm_config={ + "epsilon": 0, + "gamma": 0, + "reward_clip": 0, + "reward_scale": 0, + }, + dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", + }, + finetuned_model="https://example.com", + hyperparam_search_config={"foo": True}, + job_uuid="job_uuid", + logger_config={"foo": True}, + optimizer_config={ + "lr": 0, + "lr_min": 0, + "optimizer_type": "adam", + "weight_decay": 0, + }, + training_config={ + "batch_size": 0, + "enable_activation_checkpointing": True, + "fsdp_cpu_offload": True, + "memory_efficient_fsdp_wrap": True, + "n_epochs": 0, + "n_iters": 0, + "shuffle": True, + }, + validation_dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", + }, + ) + assert_matches_type(PostTrainingJob, post_training, path=["response"]) + + @parametrize + async def test_method_preference_optimize_with_all_params(self, async_client: AsyncLlamaStack) -> None: + post_training = await async_client.post_training.preference_optimize( + algorithm="dpo", + algorithm_config={ + "epsilon": 0, + "gamma": 0, + "reward_clip": 0, + "reward_scale": 0, + }, + dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", + "metadata": {"foo": True}, + }, + finetuned_model="https://example.com", + hyperparam_search_config={"foo": True}, + job_uuid="job_uuid", + logger_config={"foo": True}, + optimizer_config={ + "lr": 0, + "lr_min": 0, + "optimizer_type": "adam", + "weight_decay": 0, + }, + training_config={ + "batch_size": 0, + "enable_activation_checkpointing": True, + "fsdp_cpu_offload": True, + "memory_efficient_fsdp_wrap": True, + "n_epochs": 0, + "n_iters": 0, + "shuffle": True, + }, + validation_dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", + "metadata": {"foo": True}, }, ) assert_matches_type(PostTrainingJob, post_training, path=["response"]) @@ -343,41 +456,39 @@ async def test_method_preference_optimize(self, async_client: AsyncLlamaStack) - @parametrize async def test_raw_response_preference_optimize(self, async_client: AsyncLlamaStack) -> None: response = await async_client.post_training.with_raw_response.preference_optimize( - request={ - "algorithm": "dpo", - "algorithm_config": { - "epsilon": 0, - "gamma": 0, - "reward_clip": 0, - "reward_scale": 0, - }, - "dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, - "finetuned_model": "https://example.com", - "hyperparam_search_config": {"foo": True}, - "job_uuid": "job_uuid", - "logger_config": {"foo": True}, - "optimizer_config": { - "lr": 0, - "lr_min": 0, - "optimizer_type": "adam", - "weight_decay": 0, - }, - "training_config": { - "batch_size": 0, - "enable_activation_checkpointing": True, - "fsdp_cpu_offload": True, - "memory_efficient_fsdp_wrap": True, - "n_epochs": 0, - "n_iters": 0, - "shuffle": True, - }, - "validation_dataset": { - "columns": {"foo": "dialog"}, 
- "content_url": "https://example.com", - }, + algorithm="dpo", + algorithm_config={ + "epsilon": 0, + "gamma": 0, + "reward_clip": 0, + "reward_scale": 0, + }, + dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", + }, + finetuned_model="https://example.com", + hyperparam_search_config={"foo": True}, + job_uuid="job_uuid", + logger_config={"foo": True}, + optimizer_config={ + "lr": 0, + "lr_min": 0, + "optimizer_type": "adam", + "weight_decay": 0, + }, + training_config={ + "batch_size": 0, + "enable_activation_checkpointing": True, + "fsdp_cpu_offload": True, + "memory_efficient_fsdp_wrap": True, + "n_epochs": 0, + "n_iters": 0, + "shuffle": True, + }, + validation_dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", }, ) @@ -389,41 +500,39 @@ async def test_raw_response_preference_optimize(self, async_client: AsyncLlamaSt @parametrize async def test_streaming_response_preference_optimize(self, async_client: AsyncLlamaStack) -> None: async with async_client.post_training.with_streaming_response.preference_optimize( - request={ - "algorithm": "dpo", - "algorithm_config": { - "epsilon": 0, - "gamma": 0, - "reward_clip": 0, - "reward_scale": 0, - }, - "dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, - "finetuned_model": "https://example.com", - "hyperparam_search_config": {"foo": True}, - "job_uuid": "job_uuid", - "logger_config": {"foo": True}, - "optimizer_config": { - "lr": 0, - "lr_min": 0, - "optimizer_type": "adam", - "weight_decay": 0, - }, - "training_config": { - "batch_size": 0, - "enable_activation_checkpointing": True, - "fsdp_cpu_offload": True, - "memory_efficient_fsdp_wrap": True, - "n_epochs": 0, - "n_iters": 0, - "shuffle": True, - }, - "validation_dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, + algorithm="dpo", + algorithm_config={ + "epsilon": 0, + "gamma": 0, + "reward_clip": 0, + "reward_scale": 0, + }, + dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", + }, + finetuned_model="https://example.com", + hyperparam_search_config={"foo": True}, + job_uuid="job_uuid", + logger_config={"foo": True}, + optimizer_config={ + "lr": 0, + "lr_min": 0, + "optimizer_type": "adam", + "weight_decay": 0, + }, + training_config={ + "batch_size": 0, + "enable_activation_checkpointing": True, + "fsdp_cpu_offload": True, + "memory_efficient_fsdp_wrap": True, + "n_epochs": 0, + "n_iters": 0, + "shuffle": True, + }, + validation_dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", }, ) as response: assert not response.is_closed @@ -437,42 +546,83 @@ async def test_streaming_response_preference_optimize(self, async_client: AsyncL @parametrize async def test_method_supervised_fine_tune(self, async_client: AsyncLlamaStack) -> None: post_training = await async_client.post_training.supervised_fine_tune( - request={ - "algorithm": "full", - "algorithm_config": { - "alpha": 0, - "apply_lora_to_mlp": True, - "apply_lora_to_output": True, - "lora_attn_modules": ["string", "string", "string"], - "rank": 0, - }, - "dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, - "hyperparam_search_config": {"foo": True}, - "job_uuid": "job_uuid", - "logger_config": {"foo": True}, - "model": "model", - "optimizer_config": { - "lr": 0, - "lr_min": 0, - "optimizer_type": "adam", - "weight_decay": 0, - }, - "training_config": { - "batch_size": 0, - 
"enable_activation_checkpointing": True, - "fsdp_cpu_offload": True, - "memory_efficient_fsdp_wrap": True, - "n_epochs": 0, - "n_iters": 0, - "shuffle": True, - }, - "validation_dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, + algorithm="full", + algorithm_config={ + "alpha": 0, + "apply_lora_to_mlp": True, + "apply_lora_to_output": True, + "lora_attn_modules": ["string", "string", "string"], + "rank": 0, + }, + dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", + }, + hyperparam_search_config={"foo": True}, + job_uuid="job_uuid", + logger_config={"foo": True}, + model="model", + optimizer_config={ + "lr": 0, + "lr_min": 0, + "optimizer_type": "adam", + "weight_decay": 0, + }, + training_config={ + "batch_size": 0, + "enable_activation_checkpointing": True, + "fsdp_cpu_offload": True, + "memory_efficient_fsdp_wrap": True, + "n_epochs": 0, + "n_iters": 0, + "shuffle": True, + }, + validation_dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", + }, + ) + assert_matches_type(PostTrainingJob, post_training, path=["response"]) + + @parametrize + async def test_method_supervised_fine_tune_with_all_params(self, async_client: AsyncLlamaStack) -> None: + post_training = await async_client.post_training.supervised_fine_tune( + algorithm="full", + algorithm_config={ + "alpha": 0, + "apply_lora_to_mlp": True, + "apply_lora_to_output": True, + "lora_attn_modules": ["string", "string", "string"], + "rank": 0, + }, + dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", + "metadata": {"foo": True}, + }, + hyperparam_search_config={"foo": True}, + job_uuid="job_uuid", + logger_config={"foo": True}, + model="model", + optimizer_config={ + "lr": 0, + "lr_min": 0, + "optimizer_type": "adam", + "weight_decay": 0, + }, + training_config={ + "batch_size": 0, + "enable_activation_checkpointing": True, + "fsdp_cpu_offload": True, + "memory_efficient_fsdp_wrap": True, + "n_epochs": 0, + "n_iters": 0, + "shuffle": True, + }, + validation_dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", + "metadata": {"foo": True}, }, ) assert_matches_type(PostTrainingJob, post_training, path=["response"]) @@ -480,42 +630,40 @@ async def test_method_supervised_fine_tune(self, async_client: AsyncLlamaStack) @parametrize async def test_raw_response_supervised_fine_tune(self, async_client: AsyncLlamaStack) -> None: response = await async_client.post_training.with_raw_response.supervised_fine_tune( - request={ - "algorithm": "full", - "algorithm_config": { - "alpha": 0, - "apply_lora_to_mlp": True, - "apply_lora_to_output": True, - "lora_attn_modules": ["string", "string", "string"], - "rank": 0, - }, - "dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, - "hyperparam_search_config": {"foo": True}, - "job_uuid": "job_uuid", - "logger_config": {"foo": True}, - "model": "model", - "optimizer_config": { - "lr": 0, - "lr_min": 0, - "optimizer_type": "adam", - "weight_decay": 0, - }, - "training_config": { - "batch_size": 0, - "enable_activation_checkpointing": True, - "fsdp_cpu_offload": True, - "memory_efficient_fsdp_wrap": True, - "n_epochs": 0, - "n_iters": 0, - "shuffle": True, - }, - "validation_dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, + algorithm="full", + algorithm_config={ + "alpha": 0, + "apply_lora_to_mlp": True, + "apply_lora_to_output": True, + "lora_attn_modules": ["string", 
"string", "string"], + "rank": 0, + }, + dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", + }, + hyperparam_search_config={"foo": True}, + job_uuid="job_uuid", + logger_config={"foo": True}, + model="model", + optimizer_config={ + "lr": 0, + "lr_min": 0, + "optimizer_type": "adam", + "weight_decay": 0, + }, + training_config={ + "batch_size": 0, + "enable_activation_checkpointing": True, + "fsdp_cpu_offload": True, + "memory_efficient_fsdp_wrap": True, + "n_epochs": 0, + "n_iters": 0, + "shuffle": True, + }, + validation_dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", }, ) @@ -527,42 +675,40 @@ async def test_raw_response_supervised_fine_tune(self, async_client: AsyncLlamaS @parametrize async def test_streaming_response_supervised_fine_tune(self, async_client: AsyncLlamaStack) -> None: async with async_client.post_training.with_streaming_response.supervised_fine_tune( - request={ - "algorithm": "full", - "algorithm_config": { - "alpha": 0, - "apply_lora_to_mlp": True, - "apply_lora_to_output": True, - "lora_attn_modules": ["string", "string", "string"], - "rank": 0, - }, - "dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, - "hyperparam_search_config": {"foo": True}, - "job_uuid": "job_uuid", - "logger_config": {"foo": True}, - "model": "model", - "optimizer_config": { - "lr": 0, - "lr_min": 0, - "optimizer_type": "adam", - "weight_decay": 0, - }, - "training_config": { - "batch_size": 0, - "enable_activation_checkpointing": True, - "fsdp_cpu_offload": True, - "memory_efficient_fsdp_wrap": True, - "n_epochs": 0, - "n_iters": 0, - "shuffle": True, - }, - "validation_dataset": { - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, + algorithm="full", + algorithm_config={ + "alpha": 0, + "apply_lora_to_mlp": True, + "apply_lora_to_output": True, + "lora_attn_modules": ["string", "string", "string"], + "rank": 0, + }, + dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", + }, + hyperparam_search_config={"foo": True}, + job_uuid="job_uuid", + logger_config={"foo": True}, + model="model", + optimizer_config={ + "lr": 0, + "lr_min": 0, + "optimizer_type": "adam", + "weight_decay": 0, + }, + training_config={ + "batch_size": 0, + "enable_activation_checkpointing": True, + "fsdp_cpu_offload": True, + "memory_efficient_fsdp_wrap": True, + "n_epochs": 0, + "n_iters": 0, + "shuffle": True, + }, + validation_dataset={ + "columns": {"foo": "dialog"}, + "content_url": "https://example.com", }, ) as response: assert not response.is_closed diff --git a/tests/api_resources/test_reward_scoring.py b/tests/api_resources/test_reward_scoring.py index 7a62b93..83f6983 100644 --- a/tests/api_resources/test_reward_scoring.py +++ b/tests/api_resources/test_reward_scoring.py @@ -20,202 +20,198 @@ class TestRewardScoring: @parametrize def test_method_score(self, client: LlamaStack) -> None: reward_scoring = client.reward_scoring.score( - request={ - "dialog_generations": [ - { - "dialog": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "sampled_generations": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - }, - { - "dialog": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": 
"string", - "role": "user", - }, - ], - "sampled_generations": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - }, - { - "dialog": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "sampled_generations": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - }, - ], - "model": "model", - }, + dialog_generations=[ + { + "dialog": [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + "sampled_generations": [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + }, + { + "dialog": [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + "sampled_generations": [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + }, + { + "dialog": [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + "sampled_generations": [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + }, + ], + model="model", ) assert_matches_type(RewardScoring, reward_scoring, path=["response"]) @parametrize def test_raw_response_score(self, client: LlamaStack) -> None: response = client.reward_scoring.with_raw_response.score( - request={ - "dialog_generations": [ - { - "dialog": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "sampled_generations": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - }, - { - "dialog": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "sampled_generations": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - }, - { - "dialog": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "sampled_generations": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - }, - ], - "model": "model", - }, + dialog_generations=[ + { + "dialog": [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + "sampled_generations": [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + }, + { + "dialog": [ + { + "content": "string", + "role": "user", + }, + { + 
"content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + "sampled_generations": [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + }, + { + "dialog": [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + "sampled_generations": [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + }, + ], + model="model", ) assert response.is_closed is True @@ -226,101 +222,99 @@ def test_raw_response_score(self, client: LlamaStack) -> None: @parametrize def test_streaming_response_score(self, client: LlamaStack) -> None: with client.reward_scoring.with_streaming_response.score( - request={ - "dialog_generations": [ - { - "dialog": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "sampled_generations": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - }, - { - "dialog": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "sampled_generations": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - }, - { - "dialog": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "sampled_generations": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - }, - ], - "model": "model", - }, + dialog_generations=[ + { + "dialog": [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + "sampled_generations": [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + }, + { + "dialog": [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + "sampled_generations": [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + }, + { + "dialog": [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + "sampled_generations": [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + }, + ], + model="model", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -337,202 +331,198 @@ class TestAsyncRewardScoring: @parametrize async def test_method_score(self, async_client: AsyncLlamaStack) -> None: reward_scoring = await async_client.reward_scoring.score( - request={ - 
"dialog_generations": [ - { - "dialog": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "sampled_generations": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - }, - { - "dialog": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "sampled_generations": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - }, - { - "dialog": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "sampled_generations": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - }, - ], - "model": "model", - }, + dialog_generations=[ + { + "dialog": [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + "sampled_generations": [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + }, + { + "dialog": [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + "sampled_generations": [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + }, + { + "dialog": [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + "sampled_generations": [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + }, + ], + model="model", ) assert_matches_type(RewardScoring, reward_scoring, path=["response"]) @parametrize async def test_raw_response_score(self, async_client: AsyncLlamaStack) -> None: response = await async_client.reward_scoring.with_raw_response.score( - request={ - "dialog_generations": [ - { - "dialog": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "sampled_generations": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - }, - { - "dialog": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "sampled_generations": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - }, - { - "dialog": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "sampled_generations": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - 
{ - "content": "string", - "role": "user", - }, - ], - }, - ], - "model": "model", - }, + dialog_generations=[ + { + "dialog": [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + "sampled_generations": [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + }, + { + "dialog": [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + "sampled_generations": [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + }, + { + "dialog": [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + "sampled_generations": [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + }, + ], + model="model", ) assert response.is_closed is True @@ -543,101 +533,99 @@ async def test_raw_response_score(self, async_client: AsyncLlamaStack) -> None: @parametrize async def test_streaming_response_score(self, async_client: AsyncLlamaStack) -> None: async with async_client.reward_scoring.with_streaming_response.score( - request={ - "dialog_generations": [ - { - "dialog": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "sampled_generations": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - }, - { - "dialog": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "sampled_generations": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - }, - { - "dialog": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "sampled_generations": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - }, - ], - "model": "model", - }, + dialog_generations=[ + { + "dialog": [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + "sampled_generations": [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + }, + { + "dialog": [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + "sampled_generations": [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + }, + { + "dialog": [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": 
"string", + "role": "user", + }, + ], + "sampled_generations": [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + }, + ], + model="model", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" diff --git a/tests/api_resources/test_runs.py b/tests/api_resources/test_runs.py deleted file mode 100644 index beedc59..0000000 --- a/tests/api_resources/test_runs.py +++ /dev/null @@ -1,303 +0,0 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - -from __future__ import annotations - -import os -from typing import Any, cast - -import pytest - -from llama_stack import LlamaStack, AsyncLlamaStack -from tests.utils import assert_matches_type -from llama_stack._utils import parse_datetime -from llama_stack.types.shared import Run - -base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") - - -class TestRuns: - parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) - - @parametrize - def test_method_update(self, client: LlamaStack) -> None: - run = client.runs.update( - request={"run_id": "run_id"}, - ) - assert_matches_type(Run, run, path=["response"]) - - @parametrize - def test_method_update_with_all_params(self, client: LlamaStack) -> None: - run = client.runs.update( - request={ - "run_id": "run_id", - "ended_at": parse_datetime("2019-12-27T18:11:19.117Z"), - "metadata": {"foo": True}, - "status": "status", - }, - ) - assert_matches_type(Run, run, path=["response"]) - - @parametrize - def test_raw_response_update(self, client: LlamaStack) -> None: - response = client.runs.with_raw_response.update( - request={"run_id": "run_id"}, - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - run = response.parse() - assert_matches_type(Run, run, path=["response"]) - - @parametrize - def test_streaming_response_update(self, client: LlamaStack) -> None: - with client.runs.with_streaming_response.update( - request={"run_id": "run_id"}, - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - run = response.parse() - assert_matches_type(Run, run, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @parametrize - def test_method_log_metrics(self, client: LlamaStack) -> None: - run = client.runs.log_metrics( - request={ - "metrics": [ - { - "name": "name", - "run_id": "run_id", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - "value": 0, - }, - { - "name": "name", - "run_id": "run_id", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - "value": 0, - }, - { - "name": "name", - "run_id": "run_id", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - "value": 0, - }, - ], - "run_id": "run_id", - }, - ) - assert run is None - - @parametrize - def test_raw_response_log_metrics(self, client: LlamaStack) -> None: - response = client.runs.with_raw_response.log_metrics( - request={ - "metrics": [ - { - "name": "name", - "run_id": "run_id", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - "value": 0, - }, - { - "name": "name", - "run_id": "run_id", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - "value": 0, - }, - { - "name": "name", - "run_id": "run_id", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - "value": 0, - }, - ], - 
"run_id": "run_id", - }, - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - run = response.parse() - assert run is None - - @parametrize - def test_streaming_response_log_metrics(self, client: LlamaStack) -> None: - with client.runs.with_streaming_response.log_metrics( - request={ - "metrics": [ - { - "name": "name", - "run_id": "run_id", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - "value": 0, - }, - { - "name": "name", - "run_id": "run_id", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - "value": 0, - }, - { - "name": "name", - "run_id": "run_id", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - "value": 0, - }, - ], - "run_id": "run_id", - }, - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - run = response.parse() - assert run is None - - assert cast(Any, response.is_closed) is True - - -class TestAsyncRuns: - parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) - - @parametrize - async def test_method_update(self, async_client: AsyncLlamaStack) -> None: - run = await async_client.runs.update( - request={"run_id": "run_id"}, - ) - assert_matches_type(Run, run, path=["response"]) - - @parametrize - async def test_method_update_with_all_params(self, async_client: AsyncLlamaStack) -> None: - run = await async_client.runs.update( - request={ - "run_id": "run_id", - "ended_at": parse_datetime("2019-12-27T18:11:19.117Z"), - "metadata": {"foo": True}, - "status": "status", - }, - ) - assert_matches_type(Run, run, path=["response"]) - - @parametrize - async def test_raw_response_update(self, async_client: AsyncLlamaStack) -> None: - response = await async_client.runs.with_raw_response.update( - request={"run_id": "run_id"}, - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - run = await response.parse() - assert_matches_type(Run, run, path=["response"]) - - @parametrize - async def test_streaming_response_update(self, async_client: AsyncLlamaStack) -> None: - async with async_client.runs.with_streaming_response.update( - request={"run_id": "run_id"}, - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - run = await response.parse() - assert_matches_type(Run, run, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @parametrize - async def test_method_log_metrics(self, async_client: AsyncLlamaStack) -> None: - run = await async_client.runs.log_metrics( - request={ - "metrics": [ - { - "name": "name", - "run_id": "run_id", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - "value": 0, - }, - { - "name": "name", - "run_id": "run_id", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - "value": 0, - }, - { - "name": "name", - "run_id": "run_id", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - "value": 0, - }, - ], - "run_id": "run_id", - }, - ) - assert run is None - - @parametrize - async def test_raw_response_log_metrics(self, async_client: AsyncLlamaStack) -> None: - response = await async_client.runs.with_raw_response.log_metrics( - request={ - "metrics": [ - { - "name": "name", - "run_id": "run_id", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - "value": 0, - }, - { - "name": "name", - "run_id": "run_id", - "timestamp": 
parse_datetime("2019-12-27T18:11:19.117Z"), - "value": 0, - }, - { - "name": "name", - "run_id": "run_id", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - "value": 0, - }, - ], - "run_id": "run_id", - }, - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - run = await response.parse() - assert run is None - - @parametrize - async def test_streaming_response_log_metrics(self, async_client: AsyncLlamaStack) -> None: - async with async_client.runs.with_streaming_response.log_metrics( - request={ - "metrics": [ - { - "name": "name", - "run_id": "run_id", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - "value": 0, - }, - { - "name": "name", - "run_id": "run_id", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - "value": 0, - }, - { - "name": "name", - "run_id": "run_id", - "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), - "value": 0, - }, - ], - "run_id": "run_id", - }, - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - run = await response.parse() - assert run is None - - assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_safety.py b/tests/api_resources/test_safety.py index 3781341..f4b44dc 100644 --- a/tests/api_resources/test_safety.py +++ b/tests/api_resources/test_safety.py @@ -20,72 +20,68 @@ class TestSafety: @parametrize def test_method_run_shields(self, client: LlamaStack) -> None: safety = client.safety.run_shields( - request={ - "messages": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - }, - ], - }, + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + shields=[ + { + "on_violation_action": 0, + "shield_type": "llama_guard", + }, + { + "on_violation_action": 0, + "shield_type": "llama_guard", + }, + { + "on_violation_action": 0, + "shield_type": "llama_guard", + }, + ], ) assert_matches_type(SafetyRunShieldsResponse, safety, path=["response"]) @parametrize def test_raw_response_run_shields(self, client: LlamaStack) -> None: response = client.safety.with_raw_response.run_shields( - request={ - "messages": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - }, - ], - }, + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + shields=[ + { + "on_violation_action": 0, + "shield_type": "llama_guard", + }, + { + "on_violation_action": 0, + "shield_type": "llama_guard", + }, + { + "on_violation_action": 0, + "shield_type": "llama_guard", + }, + ], ) assert response.is_closed is True @@ -96,36 +92,34 @@ def test_raw_response_run_shields(self, client: LlamaStack) -> None: @parametrize def 
test_streaming_response_run_shields(self, client: LlamaStack) -> None: with client.safety.with_streaming_response.run_shields( - request={ - "messages": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - }, - ], - }, + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + shields=[ + { + "on_violation_action": 0, + "shield_type": "llama_guard", + }, + { + "on_violation_action": 0, + "shield_type": "llama_guard", + }, + { + "on_violation_action": 0, + "shield_type": "llama_guard", + }, + ], ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -142,72 +136,68 @@ class TestAsyncSafety: @parametrize async def test_method_run_shields(self, async_client: AsyncLlamaStack) -> None: safety = await async_client.safety.run_shields( - request={ - "messages": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - }, - ], - }, + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + shields=[ + { + "on_violation_action": 0, + "shield_type": "llama_guard", + }, + { + "on_violation_action": 0, + "shield_type": "llama_guard", + }, + { + "on_violation_action": 0, + "shield_type": "llama_guard", + }, + ], ) assert_matches_type(SafetyRunShieldsResponse, safety, path=["response"]) @parametrize async def test_raw_response_run_shields(self, async_client: AsyncLlamaStack) -> None: response = await async_client.safety.with_raw_response.run_shields( - request={ - "messages": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - }, - ], - }, + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + shields=[ + { + "on_violation_action": 0, + "shield_type": "llama_guard", + }, + { + "on_violation_action": 0, + "shield_type": "llama_guard", + }, + { + "on_violation_action": 0, + "shield_type": "llama_guard", + }, + ], ) assert response.is_closed is True @@ -218,36 +208,34 @@ async def test_raw_response_run_shields(self, async_client: AsyncLlamaStack) -> @parametrize async def test_streaming_response_run_shields(self, async_client: AsyncLlamaStack) -> None: async with async_client.safety.with_streaming_response.run_shields( - request={ - "messages": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": 
"string", - "role": "user", - }, - ], - "shields": [ - { - "on_violation_action": 0, - "shield_type": "llama_guard", - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - }, - { - "on_violation_action": 0, - "shield_type": "llama_guard", - }, - ], - }, + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + shields=[ + { + "on_violation_action": 0, + "shield_type": "llama_guard", + }, + { + "on_violation_action": 0, + "shield_type": "llama_guard", + }, + { + "on_violation_action": 0, + "shield_type": "llama_guard", + }, + ], ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" diff --git a/tests/api_resources/test_synthetic_data_generation.py b/tests/api_resources/test_synthetic_data_generation.py index ee93b41..7f5d73b 100644 --- a/tests/api_resources/test_synthetic_data_generation.py +++ b/tests/api_resources/test_synthetic_data_generation.py @@ -20,73 +20,67 @@ class TestSyntheticDataGeneration: @parametrize def test_method_generate(self, client: LlamaStack) -> None: synthetic_data_generation = client.synthetic_data_generation.generate( - request={ - "dialogs": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "filtering_function": "none", - }, + dialogs=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + filtering_function="none", ) assert_matches_type(SyntheticDataGeneration, synthetic_data_generation, path=["response"]) @parametrize def test_method_generate_with_all_params(self, client: LlamaStack) -> None: synthetic_data_generation = client.synthetic_data_generation.generate( - request={ - "dialogs": [ - { - "content": "string", - "role": "user", - "context": "string", - }, - { - "content": "string", - "role": "user", - "context": "string", - }, - { - "content": "string", - "role": "user", - "context": "string", - }, - ], - "filtering_function": "none", - "model": "model", - }, + dialogs=[ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + ], + filtering_function="none", + model="model", ) assert_matches_type(SyntheticDataGeneration, synthetic_data_generation, path=["response"]) @parametrize def test_raw_response_generate(self, client: LlamaStack) -> None: response = client.synthetic_data_generation.with_raw_response.generate( - request={ - "dialogs": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "filtering_function": "none", - }, + dialogs=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + filtering_function="none", ) assert response.is_closed is True @@ -97,23 +91,21 @@ def test_raw_response_generate(self, client: LlamaStack) -> None: @parametrize def test_streaming_response_generate(self, client: LlamaStack) -> None: with client.synthetic_data_generation.with_streaming_response.generate( - request={ - "dialogs": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - 
"role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "filtering_function": "none", - }, + dialogs=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + filtering_function="none", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -130,73 +122,67 @@ class TestAsyncSyntheticDataGeneration: @parametrize async def test_method_generate(self, async_client: AsyncLlamaStack) -> None: synthetic_data_generation = await async_client.synthetic_data_generation.generate( - request={ - "dialogs": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "filtering_function": "none", - }, + dialogs=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + filtering_function="none", ) assert_matches_type(SyntheticDataGeneration, synthetic_data_generation, path=["response"]) @parametrize async def test_method_generate_with_all_params(self, async_client: AsyncLlamaStack) -> None: synthetic_data_generation = await async_client.synthetic_data_generation.generate( - request={ - "dialogs": [ - { - "content": "string", - "role": "user", - "context": "string", - }, - { - "content": "string", - "role": "user", - "context": "string", - }, - { - "content": "string", - "role": "user", - "context": "string", - }, - ], - "filtering_function": "none", - "model": "model", - }, + dialogs=[ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + ], + filtering_function="none", + model="model", ) assert_matches_type(SyntheticDataGeneration, synthetic_data_generation, path=["response"]) @parametrize async def test_raw_response_generate(self, async_client: AsyncLlamaStack) -> None: response = await async_client.synthetic_data_generation.with_raw_response.generate( - request={ - "dialogs": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "filtering_function": "none", - }, + dialogs=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + filtering_function="none", ) assert response.is_closed is True @@ -207,23 +193,21 @@ async def test_raw_response_generate(self, async_client: AsyncLlamaStack) -> Non @parametrize async def test_streaming_response_generate(self, async_client: AsyncLlamaStack) -> None: async with async_client.synthetic_data_generation.with_streaming_response.generate( - request={ - "dialogs": [ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], - "filtering_function": "none", - }, + dialogs=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + filtering_function="none", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" diff --git a/tests/api_resources/test_telemetry.py 
b/tests/api_resources/test_telemetry.py new file mode 100644 index 0000000..317bb8b --- /dev/null +++ b/tests/api_resources/test_telemetry.py @@ -0,0 +1,219 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import os +from typing import Any, cast + +import pytest + +from llama_stack import LlamaStack, AsyncLlamaStack +from tests.utils import assert_matches_type +from llama_stack.types import TelemetryGetTraceResponse +from llama_stack._utils import parse_datetime + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + + +class TestTelemetry: + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + def test_method_get_trace(self, client: LlamaStack) -> None: + telemetry = client.telemetry.get_trace( + trace_id="trace_id", + ) + assert_matches_type(TelemetryGetTraceResponse, telemetry, path=["response"]) + + @parametrize + def test_raw_response_get_trace(self, client: LlamaStack) -> None: + response = client.telemetry.with_raw_response.get_trace( + trace_id="trace_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + telemetry = response.parse() + assert_matches_type(TelemetryGetTraceResponse, telemetry, path=["response"]) + + @parametrize + def test_streaming_response_get_trace(self, client: LlamaStack) -> None: + with client.telemetry.with_streaming_response.get_trace( + trace_id="trace_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + telemetry = response.parse() + assert_matches_type(TelemetryGetTraceResponse, telemetry, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_method_log(self, client: LlamaStack) -> None: + telemetry = client.telemetry.log( + event={ + "message": "message", + "severity": "verbose", + "span_id": "span_id", + "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), + "trace_id": "trace_id", + "type": "unstructured_log", + }, + ) + assert telemetry is None + + @parametrize + def test_method_log_with_all_params(self, client: LlamaStack) -> None: + telemetry = client.telemetry.log( + event={ + "message": "message", + "severity": "verbose", + "span_id": "span_id", + "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), + "trace_id": "trace_id", + "type": "unstructured_log", + "attributes": {"foo": True}, + }, + ) + assert telemetry is None + + @parametrize + def test_raw_response_log(self, client: LlamaStack) -> None: + response = client.telemetry.with_raw_response.log( + event={ + "message": "message", + "severity": "verbose", + "span_id": "span_id", + "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), + "trace_id": "trace_id", + "type": "unstructured_log", + }, + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + telemetry = response.parse() + assert telemetry is None + + @parametrize + def test_streaming_response_log(self, client: LlamaStack) -> None: + with client.telemetry.with_streaming_response.log( + event={ + "message": "message", + "severity": "verbose", + "span_id": "span_id", + "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), + "trace_id": "trace_id", + "type": "unstructured_log", + }, + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + 
telemetry = response.parse() + assert telemetry is None + + assert cast(Any, response.is_closed) is True + + +class TestAsyncTelemetry: + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + async def test_method_get_trace(self, async_client: AsyncLlamaStack) -> None: + telemetry = await async_client.telemetry.get_trace( + trace_id="trace_id", + ) + assert_matches_type(TelemetryGetTraceResponse, telemetry, path=["response"]) + + @parametrize + async def test_raw_response_get_trace(self, async_client: AsyncLlamaStack) -> None: + response = await async_client.telemetry.with_raw_response.get_trace( + trace_id="trace_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + telemetry = await response.parse() + assert_matches_type(TelemetryGetTraceResponse, telemetry, path=["response"]) + + @parametrize + async def test_streaming_response_get_trace(self, async_client: AsyncLlamaStack) -> None: + async with async_client.telemetry.with_streaming_response.get_trace( + trace_id="trace_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + telemetry = await response.parse() + assert_matches_type(TelemetryGetTraceResponse, telemetry, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_method_log(self, async_client: AsyncLlamaStack) -> None: + telemetry = await async_client.telemetry.log( + event={ + "message": "message", + "severity": "verbose", + "span_id": "span_id", + "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), + "trace_id": "trace_id", + "type": "unstructured_log", + }, + ) + assert telemetry is None + + @parametrize + async def test_method_log_with_all_params(self, async_client: AsyncLlamaStack) -> None: + telemetry = await async_client.telemetry.log( + event={ + "message": "message", + "severity": "verbose", + "span_id": "span_id", + "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), + "trace_id": "trace_id", + "type": "unstructured_log", + "attributes": {"foo": True}, + }, + ) + assert telemetry is None + + @parametrize + async def test_raw_response_log(self, async_client: AsyncLlamaStack) -> None: + response = await async_client.telemetry.with_raw_response.log( + event={ + "message": "message", + "severity": "verbose", + "span_id": "span_id", + "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), + "trace_id": "trace_id", + "type": "unstructured_log", + }, + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + telemetry = await response.parse() + assert telemetry is None + + @parametrize + async def test_streaming_response_log(self, async_client: AsyncLlamaStack) -> None: + async with async_client.telemetry.with_streaming_response.log( + event={ + "message": "message", + "severity": "verbose", + "span_id": "span_id", + "timestamp": parse_datetime("2019-12-27T18:11:19.117Z"), + "trace_id": "trace_id", + "type": "unstructured_log", + }, + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + telemetry = await response.parse() + assert telemetry is None + + assert cast(Any, response.is_closed) is True diff --git a/tests/test_client.py b/tests/test_client.py index 8091889..c531ba5 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -673,20 +673,12 @@ def 
test_parse_retry_after_header(self, remaining_retries: int, retry_after: str @mock.patch("llama_stack._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: - respx_mock.post("/agentic_system/create").mock(side_effect=httpx.TimeoutException("Test timeout error")) + respx_mock.post("/agentic_system/session/create").mock(side_effect=httpx.TimeoutException("Test timeout error")) with pytest.raises(APITimeoutError): self.client.post( - "/agentic_system/create", - body=cast( - object, - dict( - agent_config={ - "instructions": "instructions", - "model": "model", - } - ), - ), + "/agentic_system/session/create", + body=cast(object, dict(agent_id="agent_id", session_name="session_name")), cast_to=httpx.Response, options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, ) @@ -696,20 +688,12 @@ def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> No @mock.patch("llama_stack._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: - respx_mock.post("/agentic_system/create").mock(return_value=httpx.Response(500)) + respx_mock.post("/agentic_system/session/create").mock(return_value=httpx.Response(500)) with pytest.raises(APIStatusError): self.client.post( - "/agentic_system/create", - body=cast( - object, - dict( - agent_config={ - "instructions": "instructions", - "model": "model", - } - ), - ), + "/agentic_system/session/create", + body=cast(object, dict(agent_id="agent_id", session_name="session_name")), cast_to=httpx.Response, options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, ) @@ -731,13 +715,10 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: return httpx.Response(500) return httpx.Response(200) - respx_mock.post("/agentic_system/create").mock(side_effect=retry_handler) + respx_mock.post("/agentic_system/session/create").mock(side_effect=retry_handler) - response = client.agentic_system.with_raw_response.create( - agent_config={ - "instructions": "instructions", - "model": "model", - } + response = client.agentic_system.sessions.with_raw_response.create( + agent_id="agent_id", session_name="session_name" ) assert response.retries_taken == failures_before_success @@ -1374,20 +1355,12 @@ async def test_parse_retry_after_header(self, remaining_retries: int, retry_afte @mock.patch("llama_stack._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: - respx_mock.post("/agentic_system/create").mock(side_effect=httpx.TimeoutException("Test timeout error")) + respx_mock.post("/agentic_system/session/create").mock(side_effect=httpx.TimeoutException("Test timeout error")) with pytest.raises(APITimeoutError): await self.client.post( - "/agentic_system/create", - body=cast( - object, - dict( - agent_config={ - "instructions": "instructions", - "model": "model", - } - ), - ), + "/agentic_system/session/create", + body=cast(object, dict(agent_id="agent_id", session_name="session_name")), cast_to=httpx.Response, options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, ) @@ -1397,20 +1370,12 @@ async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) @mock.patch("llama_stack._base_client.BaseClient._calculate_retry_timeout", 
_low_retry_timeout) @pytest.mark.respx(base_url=base_url) async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: - respx_mock.post("/agentic_system/create").mock(return_value=httpx.Response(500)) + respx_mock.post("/agentic_system/session/create").mock(return_value=httpx.Response(500)) with pytest.raises(APIStatusError): await self.client.post( - "/agentic_system/create", - body=cast( - object, - dict( - agent_config={ - "instructions": "instructions", - "model": "model", - } - ), - ), + "/agentic_system/session/create", + body=cast(object, dict(agent_id="agent_id", session_name="session_name")), cast_to=httpx.Response, options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, ) @@ -1435,13 +1400,10 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: return httpx.Response(500) return httpx.Response(200) - respx_mock.post("/agentic_system/create").mock(side_effect=retry_handler) + respx_mock.post("/agentic_system/session/create").mock(side_effect=retry_handler) - response = await client.agentic_system.with_raw_response.create( - agent_config={ - "instructions": "instructions", - "model": "model", - } + response = await client.agentic_system.sessions.with_raw_response.create( + agent_id="agent_id", session_name="session_name" ) assert response.retries_taken == failures_before_success