Skip to content

Commit

Permalink
[docs] fix pyright issues from adding context type hints (#17193)
Browse files Browse the repository at this point in the history
  • Loading branch information
jamiedemaria authored Nov 10, 2023
1 parent a958f6b commit 56598f2
Show file tree
Hide file tree
Showing 18 changed files with 103 additions and 55 deletions.
37 changes: 24 additions & 13 deletions docs/content/concepts/io-management/io-managers.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -623,8 +623,9 @@ class DataframeTableIOManager(ConfigurableIOManager):

def load_input(self, context: InputContext):
# upstream_output.name is the name given to the Out that we're loading for
table_name = context.upstream_output.name
return read_dataframe_from_table(name=table_name)
if context.upstream_output:
table_name = context.upstream_output.name
return read_dataframe_from_table(name=table_name)


@job(resource_defs={"io_manager": DataframeTableIOManager()})
Expand Down Expand Up @@ -725,14 +726,22 @@ In this case, the table names are encoded in the job definition. If, instead, yo
```python file=/concepts/io_management/metadata.py startafter=io_manager_start_marker endbefore=io_manager_end_marker
class MyIOManager(ConfigurableIOManager):
def handle_output(self, context: OutputContext, obj):
table_name = context.metadata["table"]
schema = context.metadata["schema"]
write_dataframe_to_table(name=table_name, schema=schema, dataframe=obj)
if context.metadata:
table_name = context.metadata["table"]
schema = context.metadata["schema"]
write_dataframe_to_table(name=table_name, schema=schema, dataframe=obj)
else:
raise Exception(
f"op {context.op_def.name} doesn't have schema and metadata set"
)

def load_input(self, context: InputContext):
table_name = context.upstream_output.metadata["table"]
schema = context.upstream_output.metadata["schema"]
return read_dataframe_from_table(name=table_name, schema=schema)
if context.upstream_output and context.upstream_output.metadata:
table_name = context.upstream_output.metadata["table"]
schema = context.upstream_output.metadata["schema"]
return read_dataframe_from_table(name=table_name, schema=schema)
else:
raise Exception("Upstream output doesn't have schema and metadata set")
```

### Per-input loading in assets
Expand Down Expand Up @@ -809,9 +818,10 @@ class MyIOManager(IOManager):
self.storage_dict[(context.step_key, context.name)] = obj

def load_input(self, context: InputContext):
return self.storage_dict[
(context.upstream_output.step_key, context.upstream_output.name)
]
if context.upstream_output:
return self.storage_dict[
(context.upstream_output.step_key, context.upstream_output.name)
]


def test_my_io_manager_handle_output():
Expand Down Expand Up @@ -844,8 +854,9 @@ class DataframeTableIOManagerWithMetadata(ConfigurableIOManager):
context.add_output_metadata({"num_rows": len(obj), "table_name": table_name})

def load_input(self, context: InputContext):
table_name = context.upstream_output.name
return read_dataframe_from_table(name=table_name)
if context.upstream_output:
table_name = context.upstream_output.name
return read_dataframe_from_table(name=table_name)
```

Any entries yielded this way will be attached to the `Handled Output` event for this output.
Expand Down
3 changes: 2 additions & 1 deletion docs/content/concepts/io-management/unconnected-inputs.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,8 @@ class MyIOManager(ConfigurableIOManager):
write_dataframe_to_table(name=table_name, dataframe=obj)

def load_input(self, context: InputContext):
return read_dataframe_from_table(name=context.upstream_output.name)
if context.upstream_output:
return read_dataframe_from_table(name=context.upstream_output.name)


@input_manager
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ from dagster import (
)
)
def multi_partitions_asset(context: AssetExecutionContext):
context.log.info(context.partition_key.keys_by_dimension)
if isinstance(context.partition_key, MultiPartitionKey):
context.log.info(context.partition_key.keys_by_dimension)
```

In this example, the asset would contain a partition for each combination of color and date:
Expand Down Expand Up @@ -166,8 +167,8 @@ def image_sensor(context: SensorEvaluationContext):
new_images = [
img_filename
for img_filename in os.listdir(os.getenv("MY_DIRECTORY"))
if not context.instance.has_dynamic_partition(
images_partitions_def.name, img_filename
if not images_partitions_def.has_partition_key(
img_filename, dynamic_partitions_store=context.instance
)
]

Expand Down
20 changes: 14 additions & 6 deletions docs/content/guides/dagster/managing-ml.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -129,15 +129,23 @@ def conditional_machine_learning_model(context: AssetExecutionContext):
AssetKey(["conditional_machine_learning_model"])
)
if materialization is None:
yield Output(reg, metadata={"model_accuracy": reg.score(X_test, y_test)})
yield Output(reg, metadata={"model_accuracy": float(reg.score(X_test, y_test))})

else:
previous_model_accuracy = materialization.asset_materialization.metadata[
"model_accuracy"
]
previous_model_accuracy = None
if materialization.asset_materialization and isinstance(
materialization.asset_materialization.metadata["model_accuracy"].value,
float,
):
previous_model_accuracy = float(
materialization.asset_materialization.metadata["model_accuracy"].value
)
new_model_accuracy = reg.score(X_test, y_test)
if new_model_accuracy > previous_model_accuracy:
yield Output(reg, metadata={"model_accuracy": new_model_accuracy})
if (
previous_model_accuracy is None
or new_model_accuracy > previous_model_accuracy
):
yield Output(reg, metadata={"model_accuracy": float(new_model_accuracy)})
```

A sensor can be set up that triggers if an asset fails to materialize. Alerts can be customized and sent through e-mail or natively through Slack. In this example, a Slack message is sent anytime the `ml_job` fails.
Expand Down
2 changes: 1 addition & 1 deletion docs/content/integrations/bigquery/reference.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ from dagster import (
},
)
def iris_data_partitioned(context: AssetExecutionContext) -> pd.DataFrame:
partition = partition = context.partition_key.keys_by_dimension
partition = context.partition_key.keys_by_dimension
species = partition["species"]
date = partition["date"]

Expand Down
2 changes: 1 addition & 1 deletion docs/content/integrations/duckdb/reference.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ from dagster import (
metadata={"partition_expr": {"date": "TO_TIMESTAMP(TIME)", "species": "SPECIES"}},
)
def iris_dataset_partitioned(context: AssetExecutionContext) -> pd.DataFrame:
partition = partition = context.partition_key.keys_by_dimension
partition = context.partition_key.keys_by_dimension
species = partition["species"]
date = partition["date"]

Expand Down
3 changes: 2 additions & 1 deletion docs/content/integrations/snowflake/reference.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ import pandas as pd
from dagster import (
AssetExecutionContext,
DailyPartitionsDefinition,
MultiPartitionKey,
MultiPartitionsDefinition,
StaticPartitionsDefinition,
asset,
Expand All @@ -252,7 +253,7 @@ from dagster import (
},
)
def iris_dataset_partitioned(context: AssetExecutionContext) -> pd.DataFrame:
partition = partition = context.partition_key.keys_by_dimension
partition = context.partition_key.keys_by_dimension
species = partition["species"]
date = partition["date"]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,9 @@ def handle_output(self, context: OutputContext, obj):

def load_input(self, context: InputContext):
# upstream_output.name is the name given to the Out that we're loading for
table_name = context.upstream_output.name
return read_dataframe_from_table(name=table_name)
if context.upstream_output:
table_name = context.upstream_output.name
return read_dataframe_from_table(name=table_name)


@job(resource_defs={"io_manager": DataframeTableIOManager()})
Expand All @@ -131,8 +132,9 @@ def handle_output(self, context: OutputContext, obj):
context.add_output_metadata({"num_rows": len(obj), "table_name": table_name})

def load_input(self, context: InputContext):
table_name = context.upstream_output.name
return read_dataframe_from_table(name=table_name)
if context.upstream_output:
table_name = context.upstream_output.name
return read_dataframe_from_table(name=table_name)


# end_metadata_marker
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,8 @@ def handle_output(self, context: OutputContext, obj):
write_dataframe_to_table(name=table_name, dataframe=obj)

def load_input(self, context: InputContext):
return read_dataframe_from_table(name=context.upstream_output.name)
if context.upstream_output:
return read_dataframe_from_table(name=context.upstream_output.name)


@input_manager
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,22 @@ def op_2(_input_dataframe):
# io_manager_start_marker
class MyIOManager(ConfigurableIOManager):
def handle_output(self, context: OutputContext, obj):
table_name = context.metadata["table"]
schema = context.metadata["schema"]
write_dataframe_to_table(name=table_name, schema=schema, dataframe=obj)
if context.metadata:
table_name = context.metadata["table"]
schema = context.metadata["schema"]
write_dataframe_to_table(name=table_name, schema=schema, dataframe=obj)
else:
raise Exception(
f"op {context.op_def.name} doesn't have schema and metadata set"
)

def load_input(self, context: InputContext):
table_name = context.upstream_output.metadata["table"]
schema = context.upstream_output.metadata["schema"]
return read_dataframe_from_table(name=table_name, schema=schema)
if context.upstream_output and context.upstream_output.metadata:
table_name = context.upstream_output.metadata["table"]
schema = context.upstream_output.metadata["schema"]
return read_dataframe_from_table(name=table_name, schema=schema)
else:
raise Exception("Upstream output doesn't have schema and metadata set")


# io_manager_end_marker
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ def handle_output(self, context: OutputContext, obj):
write_dataframe_to_table(name=table_name, dataframe=obj)

def load_input(self, context: InputContext):
table_name = context.upstream_output.config["table"]
return read_dataframe_from_table(name=table_name)
if context.upstream_output:
table_name = context.upstream_output.config["table"]
return read_dataframe_from_table(name=table_name)


@io_manager(output_config_schema={"table": str})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ def handle_output(self, context: OutputContext, obj):
self.storage_dict[(context.step_key, context.name)] = obj

def load_input(self, context: InputContext):
return self.storage_dict[
(context.upstream_output.step_key, context.upstream_output.name)
]
if context.upstream_output:
return self.storage_dict[
(context.upstream_output.step_key, context.upstream_output.name)
]


def test_my_io_manager_handle_output():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def image_sensor(context: SensorEvaluationContext):
new_images = [
img_filename
for img_filename in os.listdir(os.getenv("MY_DIRECTORY"))
if not context.instance.has_dynamic_partition(
images_partitions_def.name, img_filename
if not images_partitions_def.has_partition_key(
img_filename, dynamic_partitions_store=context.instance
)
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
)
)
def multi_partitions_asset(context: AssetExecutionContext):
context.log.info(context.partition_key.keys_by_dimension)
if isinstance(context.partition_key, MultiPartitionKey):
context.log.info(context.partition_key.keys_by_dimension)


# end_multi_partitions_marker
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,23 @@ def conditional_machine_learning_model(context: AssetExecutionContext):
AssetKey(["conditional_machine_learning_model"])
)
if materialization is None:
yield Output(reg, metadata={"model_accuracy": reg.score(X_test, y_test)})
yield Output(reg, metadata={"model_accuracy": float(reg.score(X_test, y_test))})

else:
previous_model_accuracy = materialization.asset_materialization.metadata[
"model_accuracy"
]
previous_model_accuracy = None
if materialization.asset_materialization and isinstance(
materialization.asset_materialization.metadata["model_accuracy"].value,
float,
):
previous_model_accuracy = float(
materialization.asset_materialization.metadata["model_accuracy"].value
)
new_model_accuracy = reg.score(X_test, y_test)
if new_model_accuracy > previous_model_accuracy:
yield Output(reg, metadata={"model_accuracy": new_model_accuracy})
if (
previous_model_accuracy is None
or new_model_accuracy > previous_model_accuracy
):
yield Output(reg, metadata={"model_accuracy": float(new_model_accuracy)})


## conditional_monitoring_end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def get_iris_data_for_date(*args, **kwargs):
},
)
def iris_data_partitioned(context: AssetExecutionContext) -> pd.DataFrame:
partition = partition = context.partition_key.keys_by_dimension
partition = context.partition_key.keys_by_dimension # type: ignore
species = partition["species"]
date = partition["date"]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def get_iris_data_for_date(*args, **kwargs):
metadata={"partition_expr": {"date": "TO_TIMESTAMP(TIME)", "species": "SPECIES"}},
)
def iris_dataset_partitioned(context: AssetExecutionContext) -> pd.DataFrame:
partition = partition = context.partition_key.keys_by_dimension
partition = context.partition_key.keys_by_dimension # type: ignore
species = partition["species"]
date = partition["date"]

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import pandas as pd


def get_iris_data_for_date(*args, **kwargs):
pass
return pd.DataFrame()


# start_example
Expand All @@ -9,6 +12,7 @@ def get_iris_data_for_date(*args, **kwargs):
from dagster import (
AssetExecutionContext,
DailyPartitionsDefinition,
MultiPartitionKey,
MultiPartitionsDefinition,
StaticPartitionsDefinition,
asset,
Expand All @@ -29,7 +33,7 @@ def get_iris_data_for_date(*args, **kwargs):
},
)
def iris_dataset_partitioned(context: AssetExecutionContext) -> pd.DataFrame:
partition = partition = context.partition_key.keys_by_dimension
partition = context.partition_key.keys_by_dimension # type: ignore
species = partition["species"]
date = partition["date"]

Expand Down

0 comments on commit 56598f2

Please sign in to comment.