Skip to content

Commit

Permalink
Update get_client
Browse files Browse the repository at this point in the history
  • Loading branch information
maximearmstrong committed Feb 27, 2024
1 parent 8a8f02a commit ec36d2c
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,9 @@ def setup_for_execution(self, context: InitResourceContext) -> None:

@contextmanager
def get_client(
self,
context: Union[AssetExecutionContext, OpExecutionContext],
asset_key: Optional[AssetKey] = None,
self, context: Union[AssetExecutionContext, OpExecutionContext]
) -> Generator[Client, None, None]:
"""Returns an ``openai.Client`` for interacting with the OpenAI API.
"""Yields an ``openai.Client`` for interacting with the OpenAI API.
By default, in an asset context, the client comes with wrapped endpoints
for three API resources, Completions, Embeddings and Chat,
Expand All @@ -228,7 +226,6 @@ def get_client(
to automatically capture the API usage metadata in an op context.
:param context: The ``context`` object for computing the op or asset in which ``get_client`` is called.
:param asset_key: the ``asset_key`` of the asset for which a materialization should include the metadata.
Examples:
.. code-block:: python
Expand Down Expand Up @@ -275,6 +272,86 @@ def openai_asset(context: AssetExecutionContext, openai: OpenAIResource):
},
)
"""
yield from self._get_client(context=context, asset_key=None)

@contextmanager
def get_client_for_asset(
self, context: AssetExecutionContext, asset_key: AssetKey
) -> Generator[Client, None, None]:
"""Yields an ``openai.Client`` for interacting with the OpenAI.
When using this method, the OpenAI API usage metadata is automatically
logged in the asset materializations associated with the provided ``asset_key``.
By default, the client comes with wrapped endpoints
for three API resources, Completions, Embeddings and Chat,
allowing to log the API usage metadata in the asset metadata.
This method can only be called when working with assets,
i.e. the provided ``context`` must be of type ``AssetExecutionContext.
:param context: The ``context`` object for computing the asset in which ``get_client`` is called.
:param asset_key: the ``asset_key`` of the asset for which a materialization should include the metadata.
Examples:
.. code-block:: python
from dagster import (
AssetExecutionContext,
AssetKey,
AssetSpec,
Definitions,
EnvVar,
MaterializeResult,
asset,
define_asset_job,
multi_asset,
)
from dagster_openai import OpenAIResource
@asset(compute_kind="OpenAI")
def openai_asset(context: AssetExecutionContext, openai: OpenAIResource):
with openai.get_client_for_asset(context, context.asset_key) as client:
client.chat.completions.create(
model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Say this is a test"}]
)
openai_asset_job = define_asset_job(name="openai_asset_job", selection="openai_asset")
@multi_asset(specs=[AssetSpec("my_asset1"), AssetSpec("my_asset2")], compute_kind="OpenAI")
def openai_multi_asset(context: AssetExecutionContext, openai_resource: OpenAIResource):
with openai_resource.get_client_for_asset(context, asset_key=AssetKey("my_asset1")) as client:
client.chat.completions.create(
model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Say this is a test"}]
)
return (
MaterializeResult(asset_key="my_asset1", metadata={"some_key": "some_value1"}),
MaterializeResult(asset_key="my_asset2", metadata={"some_key": "some_value2"}),
)
openai_multi_asset_job = define_asset_job(
name="openai_multi_asset_job", selection="openai_multi_asset"
)
defs = Definitions(
assets=[openai_asset, openai_multi_asset],
jobs=[openai_asset_job, openai_multi_asset_job],
resources={
"openai": OpenAIResource(api_key=EnvVar("OPENAI_API_KEY")),
},
)
"""
yield from self._get_client(context=context, asset_key=asset_key)

def _get_client(
self,
context: Union[AssetExecutionContext, OpExecutionContext],
asset_key: Optional[AssetKey] = None,
) -> Generator[Client, None, None]:
if isinstance(context, AssetExecutionContext):
if asset_key is None:
if len(context.assets_def.keys_by_output_name.keys()) > 1:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def openai_multi_asset(openai_resource: OpenAIResource):
mock_context.output_for_asset_key.return_value = "result"

# Test success when asset_key is provided
with openai_resource.get_client(
with openai_resource.get_client_for_asset(
context=mock_context, asset_key=AssetKey("result")
) as client:
client.chat.completions.create(
Expand Down Expand Up @@ -375,7 +375,7 @@ def openai_multi_asset(openai_resource: OpenAIResource):
mock_completion.usage = mock_usage
mock_client.return_value.fine_tuning.jobs.create.return_value = mock_completion

with openai_resource.get_client(
with openai_resource.get_client_for_asset(
context=mock_context, asset_key=AssetKey("result")
) as client:
client.fine_tuning.jobs.create = with_usage_metadata(
Expand Down

0 comments on commit ec36d2c

Please sign in to comment.