From bfad41dae4f8846dc37f931dc26b4d071e582b03 Mon Sep 17 00:00:00 2001
From: Maxime Armstrong
Date: Thu, 22 Feb 2024 13:06:53 -0500
Subject: [PATCH] Update tests

---
 .../dagster_openai/resources.py            |   7 +-
 .../dagster_openai_tests/test_resources.py | 206 +++++++++++++-----
 2 files changed, 150 insertions(+), 63 deletions(-)

diff --git a/python_modules/libraries/dagster-openai/dagster_openai/resources.py b/python_modules/libraries/dagster-openai/dagster_openai/resources.py
index 64eb41f5bf048..2655a76192e02 100644
--- a/python_modules/libraries/dagster-openai/dagster_openai/resources.py
+++ b/python_modules/libraries/dagster-openai/dagster_openai/resources.py
@@ -35,7 +35,7 @@ class ApiEndpointClassesEnum(Enum):
 context_to_counters = WeakKeyDictionary()
 
 
-def add_to_asset_metadata(context: AssetExecutionContext, usage_metadata: dict, output_name: str):
+def add_to_asset_metadata(context: AssetExecutionContext, usage_metadata: dict, output_name: Optional[str]):
     if context not in context_to_counters:
         context_to_counters[context] = defaultdict(lambda: 0)
     counters = context_to_counters[context]
@@ -45,7 +45,7 @@ def add_to_asset_metadata(context: AssetExecutionContext, usage_metadata: dict,
     context.add_output_metadata(dict(counters), output_name)
 
 
-def with_usage_metadata(context: AssetExecutionContext, output_name: str, func):
+def with_usage_metadata(context: AssetExecutionContext, output_name: Optional[str], func):
     """This wrapper can be used on any endpoint of the
     `openai library <https://github.com/openai/openai-python>`_
     to log the OpenAI API usage metadata in the asset metadata.
@@ -83,7 +83,7 @@ def openai_asset(context: AssetExecutionContext, openai: OpenAIResource):
             )
     """
     if not isinstance(context, AssetExecutionContext):
-        raise TypeError(
+        raise DagsterInvariantViolationError(
             "The `with_usage_metadata` can only be used when context is of type AssetExecutionContext."
         )
 
@@ -224,7 +224,6 @@ def openai_asset(context: AssetExecutionContext, openai: OpenAIResource):
                     raise DagsterInvariantViolationError(
                         "The argument `asset_key` must be specified for multi_asset with more than one asset."
                     )
-                asset_key = context.asset_key
             # By default, when the resource is used in an asset context,
             # we wrap the methods of `openai.resources.Completions`,
             # `openai.resources.Embeddings` and `openai.resources.chat.Completions`.
diff --git a/python_modules/libraries/dagster-openai/dagster_openai_tests/test_resources.py b/python_modules/libraries/dagster-openai/dagster_openai_tests/test_resources.py
index 9ef2565a8ca11..ef523b3669baf 100644
--- a/python_modules/libraries/dagster-openai/dagster_openai_tests/test_resources.py
+++ b/python_modules/libraries/dagster-openai/dagster_openai_tests/test_resources.py
@@ -1,5 +1,3 @@
-import json
-
 import pytest
 from dagster import (
     AssetExecutionContext,
@@ -7,6 +5,7 @@
     AssetSelection,
     Definitions,
     OpExecutionContext,
+    StaticPartitionsDefinition,
     asset,
     define_asset_job,
     multi_asset,
@@ -15,7 +14,7 @@
 from dagster._core.errors import DagsterInvariantViolationError
 from dagster._core.execution.context.init import build_init_resource_context
 from dagster._utils.test import wrap_op_in_graph_and_execute
-from dagster_openai import OpenAIResource
+from dagster_openai import OpenAIResource, with_usage_metadata
 from mock import ANY, MagicMock, patch
 
 
@@ -30,20 +29,13 @@ def test_openai_client(mock_client) -> None:
 
 @patch("dagster_openai.resources.OpenAIResource._wrap_with_usage_metadata")
+@patch("dagster.OpExecutionContext", autospec=OpExecutionContext)
 @patch("dagster_openai.resources.Client")
-def test_openai_resource_with_op(mock_client, mock_wrapper):
+def test_openai_resource_with_op(mock_client, mock_context, mock_wrapper):
     @op
     def openai_op(openai_resource: OpenAIResource):
         assert openai_resource
-        body = {"ok": True}
-        mock_client.chat.completions.create.return_value = {
-            "status": 200,
-            "body": json.dumps(body),
-            "headers": "",
-        }
-
-        mock_context = MagicMock()
-        mock_context.__class__ = OpExecutionContext
+
         with openai_resource.get_client(context=mock_context) as client:
             client.chat.completions.create(
                 model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Say this is a test"}]
             )
@@ -62,20 +54,13 @@ def openai_op(openai_resource: OpenAIResource):
 
 @patch("dagster_openai.resources.OpenAIResource._wrap_with_usage_metadata")
+@patch("dagster.AssetExecutionContext", autospec=AssetExecutionContext)
 @patch("dagster_openai.resources.Client")
-def test_openai_resource_with_asset(mock_client, mock_wrapper):
+def test_openai_resource_with_asset(mock_client, mock_context, mock_wrapper):
     @asset
     def openai_asset(openai_resource: OpenAIResource):
         assert openai_resource
-        body = {"ok": True}
-        mock_client.chat.completions.create.return_value = {
-            "status": 200,
-            "body": json.dumps(body),
-            "headers": "",
-        }
-
-        mock_context = MagicMock()
-        mock_context.__class__ = AssetExecutionContext
+
         with openai_resource.get_client(context=mock_context) as client:
             client.chat.completions.create(
                 model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Say this is a test"}]
             )
@@ -104,8 +89,9 @@ def openai_asset(openai_resource: OpenAIResource):
 
 
 @patch("dagster_openai.resources.OpenAIResource._wrap_with_usage_metadata")
+@patch("dagster.AssetExecutionContext", autospec=AssetExecutionContext)
 @patch("dagster_openai.resources.Client")
-def test_openai_resource_with_multi_asset_with_asset_key(mock_client, mock_wrapper):
+def test_openai_resource_with_multi_asset(mock_client, mock_context, mock_wrapper):
     @multi_asset(
         outs={
             "status": AssetOut(),
@@ -114,15 +100,10 @@ def test_openai_resource_with_multi_asset_with_asset_key(mock_client, mock_wrapp
     )
     def openai_multi_asset(openai_resource: OpenAIResource):
         assert openai_resource
-        body = {"ok": True}
-        mock_client.chat.completions.create.return_value = {
-            "status": 200,
-            "body": json.dumps(body),
-            "headers": "",
-        }
-
-        mock_context = MagicMock()
-        mock_context.__class__ = AssetExecutionContext
+
+        mock_context.assets_def.keys_by_output_name.keys.return_value = ["status", "result"]
+
+        # Test success when asset_key is provided
         with openai_resource.get_client(context=mock_context, asset_key="result") as client:
             client.chat.completions.create(
                 model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Say this is a test"}]
             )
@@ -135,17 +116,30 @@ def openai_multi_asset(openai_resource: OpenAIResource):
             context=mock_context,
             output_name="result",
         )
+
+        # Test failure when asset_key is not provided
+        with pytest.raises(DagsterInvariantViolationError):
+            with openai_resource.get_client(context=mock_context) as client:
+                client.chat.completions.create(
+                    model="gpt-3.5-turbo",
+                    messages=[{"role": "user", "content": "Say this is a test"}],
+                )
         return None, None
 
     result = (
         Definitions(
             assets=[openai_multi_asset],
-            jobs=[define_asset_job("openai_asset_job")],
+            jobs=[
+                define_asset_job(
+                    name="openai_multi_asset_job",
+                    selection=AssetSelection.assets(openai_multi_asset),
+                )
+            ],
             resources={
                 "openai_resource": OpenAIResource(api_key="xoxp-1234123412341234-12341234-1234")
             },
         )
-        .get_job_def("openai_asset_job")
+        .get_job_def("openai_multi_asset_job")
         .execute_in_process()
     )
 
@@ -153,38 +147,132 @@ def openai_multi_asset(openai_resource: OpenAIResource):
 
 
 @patch("dagster_openai.resources.OpenAIResource._wrap_with_usage_metadata")
+@patch("dagster.AssetExecutionContext", autospec=AssetExecutionContext)
 @patch("dagster_openai.resources.Client")
-def test_openai_resource_with_multi_asset_without_asset_key(mock_client, mock_wrapper):
-    @multi_asset(
-        outs={
-            "status": AssetOut(),
-            "result": AssetOut(),
-        },
-    )
-    def openai_multi_asset(openai_resource: OpenAIResource):
-        assert openai_resource
-        body = {"ok": True}
-        mock_client.chat.completions.create.return_value = {
-            "status": 200,
-            "body": json.dumps(body),
-            "headers": "",
-        }
-
-        mock_context = MagicMock()
-        mock_context.__class__ = AssetExecutionContext
-        mock_context.assets_def.keys_by_output_name.keys.return_value = ["status", "result"]
-        with pytest.raises(DagsterInvariantViolationError):
-            with openai_resource.get_client(context=mock_context) as client:
+def test_openai_resource_with_partitioned_asset(mock_client, mock_context, mock_wrapper):
+    openai_partitions_def = StaticPartitionsDefinition([str(j) for j in range(5)])
+
+    openai_partitioned_assets = []
+
+    for i in range(5):
+
+        @asset(
+            name=f"openai_partitioned_asset_{i}",
+            group_name="openai_partitioned_assets",
+            partitions_def=openai_partitions_def,
+        )
+        def openai_partitioned_asset(openai_resource: OpenAIResource):
+            assert openai_resource
+
+            with openai_resource.get_client(context=mock_context, asset_key="result") as client:
                 client.chat.completions.create(
                     model="gpt-3.5-turbo",
                     messages=[{"role": "user", "content": "Say this is a test"}],
                 )
-        return None, None
+
+            assert mock_client.called
+            assert mock_wrapper.called
+            mock_wrapper.assert_called_with(
+                api_endpoint_class=ANY,
+                context=mock_context,
+                output_name="result",
+            )
+
+        openai_partitioned_assets.append(openai_partitioned_asset)
+
+    defs = Definitions(
+        assets=openai_partitioned_assets,
+        jobs=[
+            define_asset_job(
+                name="openai_partitioned_asset_job",
+                selection=AssetSelection.groups("openai_partitioned_assets"),
+                partitions_def=openai_partitions_def,
+            )
+        ],
+        resources={
+            "openai_resource": OpenAIResource(api_key="xoxp-1234123412341234-12341234-1234")
+        },
+    )
+
+    for partition_key in openai_partitions_def.get_partition_keys():
+        result = defs.get_job_def("openai_partitioned_asset_job").execute_in_process(
+            partition_key=partition_key
+        )
+        assert result.success
+
+    expected_wrapper_call_counts = (
+        3 * len(openai_partitioned_assets) * len(openai_partitions_def.get_partition_keys())
+    )
+    assert mock_wrapper.call_count == expected_wrapper_call_counts
+
+
+@patch("dagster.OpExecutionContext", autospec=OpExecutionContext)
+@patch("dagster_openai.resources.Client")
+def test_openai_wrapper_with_op(mock_client, mock_context):
+    @op
+    def openai_op(openai_resource: OpenAIResource):
+        assert openai_resource
+
+        with openai_resource.get_client(context=mock_context) as client:
+            with pytest.raises(DagsterInvariantViolationError):
+                client.fine_tuning.jobs.create = with_usage_metadata(
+                    context=mock_context,
+                    output_name="some_output_name",
+                    func=client.fine_tuning.jobs.create,
+                )
+
+    result = wrap_op_in_graph_and_execute(
+        openai_op,
+        resources={
+            "openai_resource": OpenAIResource(api_key="xoxp-1234123412341234-12341234-1234")
+        },
+    )
+    assert result.success
+
+
+@patch("dagster_openai.resources.OpenAIResource._wrap_with_usage_metadata")
+@patch("dagster.AssetExecutionContext", autospec=AssetExecutionContext)
+@patch("dagster_openai.resources.Client")
+def test_openai_wrapper_with_asset(mock_client, mock_context, mock_wrapper):
+    @asset
+    def openai_asset(openai_resource: OpenAIResource):
+        assert openai_resource
+        mock_completion = MagicMock()
+        mock_usage = MagicMock()
+        mock_usage.prompt_tokens = 1
+        mock_usage.total_tokens = 1
+        mock_usage.completion_tokens = 1
+        mock_completion.usage = mock_usage
+        mock_client.return_value.fine_tuning.jobs.create.return_value = mock_completion
+
+        with openai_resource.get_client(context=mock_context) as client:
+            client.fine_tuning.jobs.create = with_usage_metadata(
+                context=mock_context,
+                output_name="openai_asset",
+                func=client.fine_tuning.jobs.create,
+            )
+            client.fine_tuning.jobs.create(
+                model="gpt-3.5-turbo", training_file="some_training_file"
+            )
+
+        mock_context.add_output_metadata.assert_called_with(
+            metadata={
+                "openai.calls": 1,
+                "openai.total_tokens": 1,
+                "openai.prompt_tokens": 1,
+                "openai.completion_tokens": 1,
+            },
+            output_name="openai_asset",
+        )
 
     result = (
         Definitions(
-            assets=[openai_multi_asset],
-            jobs=[define_asset_job("openai_asset_job")],
+            assets=[openai_asset],
+            jobs=[
+                define_asset_job(
+                    name="openai_asset_job", selection=AssetSelection.assets(openai_asset)
+                )
+            ],
             resources={
                 "openai_resource": OpenAIResource(api_key="xoxp-1234123412341234-12341234-1234")
             },