Update tests
maximearmstrong committed Feb 22, 2024
1 parent b6de222 commit bfad41d
Showing 2 changed files with 150 additions and 63 deletions.
@@ -35,7 +35,7 @@ class ApiEndpointClassesEnum(Enum):
context_to_counters = WeakKeyDictionary()


def add_to_asset_metadata(context: AssetExecutionContext, usage_metadata: dict, output_name: str):
def add_to_asset_metadata(context: AssetExecutionContext, usage_metadata: dict, output_name: Optional[str]):
if context not in context_to_counters:
context_to_counters[context] = defaultdict(lambda: 0)
counters = context_to_counters[context]
@@ -45,7 +45,7 @@ def add_to_asset_metadata(context: AssetExecutionContext, usage_metadata: dict,
context.add_output_metadata(dict(counters), output_name)


def with_usage_metadata(context: AssetExecutionContext, output_name: str, func):
def with_usage_metadata(context: AssetExecutionContext, output_name: Optional[str], func):
"""This wrapper can be used on any endpoint of the
`openai library <https://github.com/openai/openai-python>`
to log the OpenAI API usage metadata in the asset metadata.
@@ -83,7 +83,7 @@ def openai_asset(context: AssetExecutionContext, openai: OpenAIResource):
)
"""
if not isinstance(context, AssetExecutionContext):
raise TypeError(
raise DagsterInvariantViolationError(
"The `with_usage_metadata` can only be used when context is of type AssetExecutionContext."
)

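For reference, a minimal sketch of how `with_usage_metadata` can wrap an endpoint that the resource does not wrap by default, mirroring the fine-tuning usage exercised in the tests below (the asset name, model, and training file are illustrative):

from dagster import AssetExecutionContext, asset
from dagster_openai import OpenAIResource, with_usage_metadata


@asset
def openai_asset(context: AssetExecutionContext, openai: OpenAIResource):
    with openai.get_client(context) as client:
        # Wrap a non-default endpoint so its usage is also recorded
        # in this asset's metadata.
        client.fine_tuning.jobs.create = with_usage_metadata(
            context=context,
            output_name="openai_asset",
            func=client.fine_tuning.jobs.create,
        )
        client.fine_tuning.jobs.create(
            model="gpt-3.5-turbo", training_file="some_training_file"
        )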
@@ -224,7 +224,6 @@ def openai_asset(context: AssetExecutionContext, openai: OpenAIResource):
raise DagsterInvariantViolationError(
"The argument `asset_key` must be specified for multi_asset with more than one asset."
)
asset_key = context.asset_key
# By default, when the resource is used in an asset context,
# we wrap the methods of `openai.resources.Completions`,
# `openai.resources.Embeddings` and `openai.resources.chat.Completions`.
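As context for the default behavior described in the comment above, a minimal sketch of the resource used in an asset (the resource key and API-key configuration are illustrative; `chat.completions` is one of the endpoints wrapped by default, per the tests below):

from dagster import AssetExecutionContext, Definitions, EnvVar, asset
from dagster_openai import OpenAIResource


@asset
def openai_asset(context: AssetExecutionContext, openai: OpenAIResource):
    with openai.get_client(context) as client:
        # Calls to a default-wrapped endpoint are logged to the asset metadata
        # (openai.calls, openai.prompt_tokens, openai.completion_tokens,
        # openai.total_tokens).
        client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Say this is a test"}],
        )


defs = Definitions(
    assets=[openai_asset],
    resources={"openai": OpenAIResource(api_key=EnvVar("OPENAI_API_KEY"))},
)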
@@ -1,12 +1,11 @@
import json

import pytest
from dagster import (
AssetExecutionContext,
AssetOut,
AssetSelection,
Definitions,
OpExecutionContext,
StaticPartitionsDefinition,
asset,
define_asset_job,
multi_asset,
@@ -15,7 +14,7 @@
from dagster._core.errors import DagsterInvariantViolationError
from dagster._core.execution.context.init import build_init_resource_context
from dagster._utils.test import wrap_op_in_graph_and_execute
from dagster_openai import OpenAIResource
from dagster_openai import OpenAIResource, with_usage_metadata
from mock import ANY, MagicMock, patch


@@ -30,20 +29,13 @@ def test_openai_client(mock_client) -> None:


@patch("dagster_openai.resources.OpenAIResource._wrap_with_usage_metadata")
@patch("dagster.OpExecutionContext", autospec=OpExecutionContext)
@patch("dagster_openai.resources.Client")
def test_openai_resource_with_op(mock_client, mock_wrapper):
def test_openai_resource_with_op(mock_client, mock_context, mock_wrapper):
@op
def openai_op(openai_resource: OpenAIResource):
assert openai_resource
body = {"ok": True}
mock_client.chat.completions.create.return_value = {
"status": 200,
"body": json.dumps(body),
"headers": "",
}

mock_context = MagicMock()
mock_context.__class__ = OpExecutionContext

with openai_resource.get_client(context=mock_context) as client:
client.chat.completions.create(
model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Say this is a test"}]
@@ -62,20 +54,13 @@ def openai_op(openai_resource: OpenAIResource):


@patch("dagster_openai.resources.OpenAIResource._wrap_with_usage_metadata")
@patch("dagster.AssetExecutionContext", autospec=AssetExecutionContext)
@patch("dagster_openai.resources.Client")
def test_openai_resource_with_asset(mock_client, mock_wrapper):
def test_openai_resource_with_asset(mock_client, mock_context, mock_wrapper):
@asset
def openai_asset(openai_resource: OpenAIResource):
assert openai_resource
body = {"ok": True}
mock_client.chat.completions.create.return_value = {
"status": 200,
"body": json.dumps(body),
"headers": "",
}

mock_context = MagicMock()
mock_context.__class__ = AssetExecutionContext

with openai_resource.get_client(context=mock_context) as client:
client.chat.completions.create(
model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Say this is a test"}]
@@ -104,8 +89,9 @@ def openai_asset(openai_resource: OpenAIResource):


@patch("dagster_openai.resources.OpenAIResource._wrap_with_usage_metadata")
@patch("dagster.AssetExecutionContext", autospec=AssetExecutionContext)
@patch("dagster_openai.resources.Client")
def test_openai_resource_with_multi_asset_with_asset_key(mock_client, mock_wrapper):
def test_openai_resource_with_multi_asset(mock_client, mock_context, mock_wrapper):
@multi_asset(
outs={
"status": AssetOut(),
@@ -114,15 +100,10 @@ def test_openai_resource_with_multi_asset_with_asset_key(mock_client, mock_wrapp
)
def openai_multi_asset(openai_resource: OpenAIResource):
assert openai_resource
body = {"ok": True}
mock_client.chat.completions.create.return_value = {
"status": 200,
"body": json.dumps(body),
"headers": "",
}

mock_context = MagicMock()
mock_context.__class__ = AssetExecutionContext

mock_context.assets_def.keys_by_output_name.keys.return_value = ["status", "result"]

# Test success when asset_key is provided
with openai_resource.get_client(context=mock_context, asset_key="result") as client:
client.chat.completions.create(
model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Say this is a test"}]
@@ -135,56 +116,163 @@ def openai_multi_asset(openai_resource: OpenAIResource):
context=mock_context,
output_name="result",
)

# Test failure when asset_key is not provided
with pytest.raises(DagsterInvariantViolationError):
with openai_resource.get_client(context=mock_context) as client:
client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Say this is a test"}],
)
return None, None

result = (
Definitions(
assets=[openai_multi_asset],
jobs=[define_asset_job("openai_asset_job")],
jobs=[
define_asset_job(
name="openai_multi_asset_job",
selection=AssetSelection.assets(openai_multi_asset),
)
],
resources={
"openai_resource": OpenAIResource(api_key="xoxp-1234123412341234-12341234-1234")
},
)
.get_job_def("openai_asset_job")
.get_job_def("openai_multi_asset_job")
.execute_in_process()
)

assert result.success


@patch("dagster_openai.resources.OpenAIResource._wrap_with_usage_metadata")
@patch("dagster.AssetExecutionContext", autospec=AssetExecutionContext)
@patch("dagster_openai.resources.Client")
def test_openai_resource_with_multi_asset_without_asset_key(mock_client, mock_wrapper):
@multi_asset(
outs={
"status": AssetOut(),
"result": AssetOut(),
},
)
def openai_multi_asset(openai_resource: OpenAIResource):
assert openai_resource
body = {"ok": True}
mock_client.chat.completions.create.return_value = {
"status": 200,
"body": json.dumps(body),
"headers": "",
}

mock_context = MagicMock()
mock_context.__class__ = AssetExecutionContext
mock_context.assets_def.keys_by_output_name.keys.return_value = ["status", "result"]
with pytest.raises(DagsterInvariantViolationError):
with openai_resource.get_client(context=mock_context) as client:
def test_openai_resource_with_partitioned_asset(mock_client, mock_context, mock_wrapper):
openai_partitions_def = StaticPartitionsDefinition([str(j) for j in range(5)])

openai_partitioned_assets = []

for i in range(5):

@asset(
name=f"openai_partitioned_asset_{i}",
group_name="openai_partitioned_assets",
partitions_def=openai_partitions_def,
)
def openai_partitioned_asset(openai_resource: OpenAIResource):
assert openai_resource

with openai_resource.get_client(context=mock_context, asset_key="result") as client:
client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Say this is a test"}],
)
return None, None

assert mock_client.called
assert mock_wrapper.called
mock_wrapper.assert_called_with(
api_endpoint_class=ANY,
context=mock_context,
output_name="result",
)

openai_partitioned_assets.append(openai_partitioned_asset)

defs = Definitions(
assets=openai_partitioned_assets,
jobs=[
define_asset_job(
name="openai_partitioned_asset_job",
selection=AssetSelection.groups("openai_partitioned_assets"),
partitions_def=openai_partitions_def,
)
],
resources={
"openai_resource": OpenAIResource(api_key="xoxp-1234123412341234-12341234-1234")
},
)

for partition_key in openai_partitions_def.get_partition_keys():
result = defs.get_job_def("openai_partitioned_asset_job").execute_in_process(
partition_key=partition_key
)
assert result.success

expected_wrapper_call_counts = (
3 * len(openai_partitioned_assets) * len(openai_partitions_def.get_partition_keys())
)
assert mock_wrapper.call_count == expected_wrapper_call_counts


@patch("dagster.OpExecutionContext", autospec=OpExecutionContext)
@patch("dagster_openai.resources.Client")
def test_openai_wrapper_with_op(mock_client, mock_context):
@op
def openai_op(openai_resource: OpenAIResource):
assert openai_resource

with openai_resource.get_client(context=mock_context) as client:
with pytest.raises(DagsterInvariantViolationError):
client.fine_tuning.jobs.create = with_usage_metadata(
context=mock_context,
output_name="some_output_name",
func=client.fine_tuning.jobs.create,
)

result = wrap_op_in_graph_and_execute(
openai_op,
resources={
"openai_resource": OpenAIResource(api_key="xoxp-1234123412341234-12341234-1234")
},
)
assert result.success


@patch("dagster_openai.resources.OpenAIResource._wrap_with_usage_metadata")
@patch("dagster.AssetExecutionContext", autospec=AssetExecutionContext)
@patch("dagster_openai.resources.Client")
def test_openai_wrapper_with_asset(mock_client, mock_context, mock_wrapper):
@asset
def openai_asset(openai_resource: OpenAIResource):
assert openai_resource
mock_completion = MagicMock()
mock_usage = MagicMock()
mock_usage.prompt_tokens = 1
mock_usage.total_tokens = 1
mock_usage.completion_tokens = 1
mock_completion.usage = mock_usage
mock_client.return_value.fine_tuning.jobs.create.return_value = mock_completion

with openai_resource.get_client(context=mock_context) as client:
client.fine_tuning.jobs.create = with_usage_metadata(
context=mock_context,
output_name="openai_asset",
func=client.fine_tuning.jobs.create,
)
client.fine_tuning.jobs.create(
model="gpt-3.5-turbo", training_file="some_training_file"
)

mock_context.add_output_metadata.assert_called_with(
metadata={
"openai.calls": 1,
"openai.total_tokens": 1,
"openai.prompt_tokens": 1,
"openai.completion_tokens": 1,
},
output_name="openai_asset",
)

result = (
Definitions(
assets=[openai_multi_asset],
jobs=[define_asset_job("openai_asset_job")],
assets=[openai_asset],
jobs=[
define_asset_job(
name="openai_asset_job", selection=AssetSelection.assets(openai_asset)
)
],
resources={
"openai_resource": OpenAIResource(api_key="xoxp-1234123412341234-12341234-1234")
},
