Skip to content

Commit

Permalink
Add docstring section detailing factory use case
Browse files Browse the repository at this point in the history
  • Loading branch information
schrockn committed Nov 1, 2023
1 parent 926b209 commit ee4a31a
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 40 deletions.
2 changes: 1 addition & 1 deletion docs/content/api/modules.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/content/api/searchindex.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/content/api/sections.json

Large diffs are not rendered by default.

Binary file modified docs/next/public/objects.inv
Binary file not shown.
162 changes: 125 additions & 37 deletions python_modules/dagster/dagster/_config/pythonic_config/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@
DefinitionConfigSchema,
)
from dagster._core.errors import DagsterInvalidConfigError
from dagster._core.execution.context.init import InitResourceContext, build_init_resource_context
from dagster._core.execution.context.init import (
InitResourceContext,
build_init_resource_context,
)
from dagster._utils.cached_method import cached_method

from .attach_other_object_to_context import (
Expand Down Expand Up @@ -94,7 +97,10 @@ def _resolve_required_resource_keys(
for attr_name, resource_def in self._nested_partial_resources.items()
}
check.invariant(
all(pointer_key is not None for pointer_key in nested_partial_resource_keys.values()),
all(
pointer_key is not None
for pointer_key in nested_partial_resource_keys.values()
),
"Any partially configured, nested resources must be provided to Definitions"
f" object: {nested_partial_resource_keys}",
)
Expand Down Expand Up @@ -229,10 +235,14 @@ def attach_resource_id_to_key_mapping(


def is_coercible_to_resource(val: Any) -> TypeGuard[CoercibleToResource]:
return isinstance(val, (ResourceDefinition, ConfigurableResourceFactory, PartialResource))
return isinstance(
val, (ResourceDefinition, ConfigurableResourceFactory, PartialResource)
)


class ConfigurableResourceFactoryResourceDefinition(ResourceDefinition, AllowDelayedDependencies):
class ConfigurableResourceFactoryResourceDefinition(
ResourceDefinition, AllowDelayedDependencies
):
def __init__(
self,
configurable_resource_cls: Type,
Expand Down Expand Up @@ -336,7 +346,9 @@ def asset_that_uses_database(database: ResourceParam[Database]):
"""

def __init__(self, **data: Any):
resource_pointers, data_without_resources = separate_resource_params(self.__class__, data)
resource_pointers, data_without_resources = separate_resource_params(
self.__class__, data
)

schema = infer_schema_from_config_class(
self.__class__, fields_to_omit=set(resource_pointers.keys())
Expand All @@ -352,13 +364,17 @@ def __init__(self, **data: Any):
for k, v in self._convert_to_config_dictionary().items()
if k in data_without_resources
}
resolved_config_dict = config_dictionary_from_values(casted_data_without_resources, schema)
resolved_config_dict = config_dictionary_from_values(
casted_data_without_resources, schema
)

self._state__internal__ = ConfigurableResourceFactoryState(
# We keep track of any resources we depend on which are not fully configured
# so that we can retrieve them at runtime
nested_partial_resources={
k: v for k, v in resource_pointers.items() if (not _is_fully_configured(v))
k: v
for k, v in resource_pointers.items()
if (not _is_fully_configured(v))
},
resolved_config_dict=resolved_config_dict,
# These are unfortunately named very similarly
Expand Down Expand Up @@ -397,15 +413,20 @@ def _is_dagster_maintained(cls) -> bool:
def _is_cm_resource_cls(cls: Type["ConfigurableResourceFactory"]) -> bool:
return (
cls.yield_for_execution != ConfigurableResourceFactory.yield_for_execution
or cls.teardown_after_execution != ConfigurableResourceFactory.teardown_after_execution
or cls.teardown_after_execution
!= ConfigurableResourceFactory.teardown_after_execution
)

@property
def _is_cm_resource(self) -> bool:
return self.__class__._is_cm_resource_cls() # noqa: SLF001

def _get_initialize_and_run_fn(self) -> Callable:
return self._initialize_and_run_cm if self._is_cm_resource else self._initialize_and_run
return (
self._initialize_and_run_cm
if self._is_cm_resource
else self._initialize_and_run
)

@cached_method
def get_resource_definition(self) -> ConfigurableResourceFactoryResourceDefinition:
Expand Down Expand Up @@ -450,8 +471,10 @@ def _with_updated_values(
# Since Resource extends BaseModel and is a dataclass, we know that the
# signature of any __init__ method will always consist of the fields
# of this class. We can therefore safely pass in the values as kwargs.
to_populate = self.__class__._get_non_default_public_field_values_cls( # noqa: SLF001
{**self._get_non_default_public_field_values(), **values}
to_populate = (
self.__class__._get_non_default_public_field_values_cls( # noqa: SLF001
{**self._get_non_default_public_field_values(), **values}
)
)
out = self.__class__(**to_populate)
out._state__internal__ = out._state__internal__._replace( # noqa: SLF001
Expand Down Expand Up @@ -490,7 +513,9 @@ def _resolve_and_update_nested_resources(

# Also evaluate any resources that are not partial
with contextlib.ExitStack() as stack:
resources_to_update, _ = separate_resource_params(self.__class__, self.__dict__)
resources_to_update, _ = separate_resource_params(
self.__class__, self.__dict__
)
resources_to_update = {
attr_name: _call_resource_fn_with_default(
stack, wrap_resource_for_execution(resource), context
Expand All @@ -503,7 +528,8 @@ def _resolve_and_update_nested_resources(
yield self._with_updated_values(to_update)

@deprecated(
breaking_version="2.0", additional_warn_text="Use `with_replaced_resource_context` instead"
breaking_version="2.0",
additional_warn_text="Use `with_replaced_resource_context` instead",
)
def with_resource_context(
self, resource_context: InitResourceContext
Expand All @@ -524,9 +550,11 @@ def with_replaced_resource_context(

def _initialize_and_run(self, context: InitResourceContext) -> TResValue:
with self._resolve_and_update_nested_resources(context) as has_nested_resource:
updated_resource = has_nested_resource.with_replaced_resource_context( # noqa: SLF001
context
)._with_updated_values(context.resource_config)
updated_resource = (
has_nested_resource.with_replaced_resource_context( # noqa: SLF001
context
)._with_updated_values(context.resource_config)
)

updated_resource.setup_for_execution(context)
return updated_resource.create_resource(context)
Expand All @@ -536,9 +564,11 @@ def _initialize_and_run_cm(
self, context: InitResourceContext
) -> Generator[TResValue, None, None]:
with self._resolve_and_update_nested_resources(context) as has_nested_resource:
updated_resource = has_nested_resource.with_replaced_resource_context( # noqa: SLF001
context
)._with_updated_values(context.resource_config)
updated_resource = (
has_nested_resource.with_replaced_resource_context( # noqa: SLF001
context
)._with_updated_values(context.resource_config)
)

with updated_resource.yield_for_execution(context) as value:
yield value
Expand All @@ -559,7 +589,9 @@ def teardown_after_execution(self, context: InitResourceContext) -> None:
pass

@contextlib.contextmanager
def yield_for_execution(self, context: InitResourceContext) -> Generator[TResValue, None, None]:
def yield_for_execution(
self, context: InitResourceContext
) -> Generator[TResValue, None, None]:
"""Optionally override this method to perform any lifecycle steps
before or after the resource is used in execution. By default, calls
setup_for_execution before yielding, and teardown_after_execution after yielding.
Expand Down Expand Up @@ -589,7 +621,8 @@ def process_config_and_initialize(self) -> TResValue:
return self.from_resource_context(
build_init_resource_context(
config=post_process_config(
self._config_schema.config_type, self._convert_to_config_dictionary()
self._config_schema.config_type,
self._convert_to_config_dictionary(),
).value
)
)
Expand Down Expand Up @@ -619,7 +652,9 @@ def my_resource(context: InitResourceContext) -> MyResource:
"Use from_resource_context_cm for resources which have custom teardown behavior,"
" e.g. overriding yield_for_execution or teardown_after_execution",
)
return cls(**context.resource_config or {})._initialize_and_run(context) # noqa: SLF001
return cls(**context.resource_config or {})._initialize_and_run(
context
) # noqa: SLF001

@classmethod
@contextlib.contextmanager
Expand All @@ -643,7 +678,9 @@ def my_resource(context: InitResourceContext) -> Generator[MyResource, None, Non
yield my_resource
"""
with cls(**context.resource_config or {})._initialize_and_run_cm( # noqa: SLF001
with cls(
**context.resource_config or {}
)._initialize_and_run_cm( # noqa: SLF001
context
) as value:
yield value
Expand Down Expand Up @@ -677,6 +714,34 @@ def asset_that_uses_writer(writer: WriterResource):
resources={"writer": WriterResource(prefix="a_prefix")},
)
You can optionally use this class to model configuration only and vend an object
of a different type for use at runtime. This is useful when you want one object
that manages configuration and a separate object that is used at runtime, or
when you want to directly use a third-party class that you do not control.
To do this, override the `create_resource` method to return a different object.
.. code-block:: python
class WriterResource(ConfigurableResource):
prefix: str
def create_resource(self, context: InitResourceContext) -> Writer:
# Writer is a pre-existing class defined elsewhere
return Writer(self.prefix)
Example usage:
.. code-block:: python
@asset
def use_preexisting_writer_as_resource(writer: ResourceParam[Writer]):
writer.output("text")
defs = Definitions(
assets=[use_preexisting_writer_as_resource],
resources={"writer": WriterResource(prefix="a_prefix")},
)
"""

def create_resource(self, context: InitResourceContext) -> TResValue:
Expand Down Expand Up @@ -717,7 +782,9 @@ class PartialResourceState(NamedTuple):
nested_resources: Dict[str, Any]


class PartialResource(Generic[TResValue], AllowDelayedDependencies, MakeConfigCacheable):
class PartialResource(
Generic[TResValue], AllowDelayedDependencies, MakeConfigCacheable
):
data: Dict[str, Any]
resource_cls: Type[Any]

Expand All @@ -726,13 +793,17 @@ def __init__(
resource_cls: Type[ConfigurableResourceFactory[TResValue]],
data: Dict[str, Any],
):
resource_pointers, _data_without_resources = separate_resource_params(resource_cls, data)
resource_pointers, _data_without_resources = separate_resource_params(
resource_cls, data
)

MakeConfigCacheable.__init__(self, data=data, resource_cls=resource_cls) # type: ignore # extends BaseModel, takes kwargs

def resource_fn(context: InitResourceContext):
to_populate = resource_cls._get_non_default_public_field_values_cls( # noqa: SLF001
{**data, **context.resource_config}
to_populate = (
resource_cls._get_non_default_public_field_values_cls( # noqa: SLF001
{**data, **context.resource_config}
)
)
instantiated = resource_cls(
**to_populate
Expand All @@ -743,7 +814,9 @@ def resource_fn(context: InitResourceContext):
# We keep track of any resources we depend on which are not fully configured
# so that we can retrieve them at runtime
nested_partial_resources={
k: v for k, v in resource_pointers.items() if (not _is_fully_configured(v))
k: v
for k, v in resource_pointers.items()
if (not _is_fully_configured(v))
},
config_schema=infer_schema_from_config_class(
resource_cls, fields_to_omit=set(resource_pointers.keys())
Expand Down Expand Up @@ -800,7 +873,9 @@ def __set_name__(self, _owner, name):
def __get__(self, obj: "ConfigurableResourceFactory", __owner: Any) -> V:
return getattr(obj, self._name)

def __set__(self, obj: Optional[object], value: ResourceOrPartialOrValue[V]) -> None:
def __set__(
self, obj: Optional[object], value: ResourceOrPartialOrValue[V]
) -> None:
setattr(obj, self._name, value)


Expand Down Expand Up @@ -864,9 +939,11 @@ def _is_annotated_as_resource_type(annotation: Type, metadata: List[str]) -> boo
if metadata and metadata[0] == "resource_dependency":
return True

is_annotated_as_resource_dependency = get_origin(annotation) == ResourceDependency or getattr(
annotation, "__metadata__", None
) == ("resource_dependency",)
is_annotated_as_resource_dependency = get_origin(
annotation
) == ResourceDependency or getattr(annotation, "__metadata__", None) == (
"resource_dependency",
)

return is_annotated_as_resource_dependency or safe_is_subclass(
annotation, (ResourceDefinition, ConfigurableResourceFactory)
Expand All @@ -880,12 +957,15 @@ class ResourceDataWithAnnotation(NamedTuple):
annotation_metadata: List[str]


def separate_resource_params(cls: Type[BaseModel], data: Dict[str, Any]) -> SeparatedResourceParams:
def separate_resource_params(
cls: Type[BaseModel], data: Dict[str, Any]
) -> SeparatedResourceParams:
"""Separates out the key/value inputs of fields in a structured config Resource class which
are marked as resources (ie, using ResourceDependency) from those which are not.
"""
fields_by_resolved_field_name = {
field.alias if field.alias else key: field for key, field in model_fields(cls).items()
field.alias if field.alias else key: field
for key, field in model_fields(cls).items()
}
data_with_annotation: List[ResourceDataWithAnnotation] = [
# No longer exists in Pydantic 2.x, will need to be updated when we upgrade
Expand Down Expand Up @@ -987,15 +1067,21 @@ def validate_resource_annotated_function(fn) -> None:
malformed_params = [
param
for param in get_function_params(fn)
if safe_is_subclass(param.annotation, (ResourceDefinition, ConfigurableResourceFactory))
if safe_is_subclass(
param.annotation, (ResourceDefinition, ConfigurableResourceFactory)
)
and not safe_is_subclass(param.annotation, ConfigurableResource)
]
if len(malformed_params) > 0:
malformed_param = malformed_params[0]
output_type = None
if safe_is_subclass(malformed_param.annotation, ConfigurableResourceFactory):
orig_bases = getattr(malformed_param.annotation, "__orig_bases__", None)
output_type = get_args(orig_bases[0])[0] if orig_bases and len(orig_bases) > 0 else None
output_type = (
get_args(orig_bases[0])[0]
if orig_bases and len(orig_bases) > 0
else None
)
if output_type == TResValue:
output_type = None

Expand All @@ -1022,5 +1108,7 @@ def _resolve_required_resource_keys_for_resource(
this mapping is used to obtain the top-level resource keys to depend on.
"""
if isinstance(resource, AllowDelayedDependencies):
return resource._resolve_required_resource_keys(resource_id_to_key_mapping) # noqa: SLF001
return resource._resolve_required_resource_keys(
resource_id_to_key_mapping
) # noqa: SLF001
return resource.required_resource_keys

0 comments on commit ee4a31a

Please sign in to comment.