From 6d77a1b666b6ec62be7ad783e3d64c1e5a2bcdaf Mon Sep 17 00:00:00 2001 From: Chris DeCarolis Date: Thu, 9 Nov 2023 10:10:55 -0800 Subject: [PATCH] Type annotations for grpc/client.py --- .../dagster/dagster/_grpc/client.py | 78 +++++++++++-------- 1 file changed, 44 insertions(+), 34 deletions(-) diff --git a/python_modules/dagster/dagster/_grpc/client.py b/python_modules/dagster/dagster/_grpc/client.py index a0fd1d8274156..ad59f72370569 100644 --- a/python_modules/dagster/dagster/_grpc/client.py +++ b/python_modules/dagster/dagster/_grpc/client.py @@ -2,7 +2,7 @@ import sys from contextlib import contextmanager from threading import Event -from typing import Any, Iterator, Optional, Sequence, Tuple +from typing import Any, Dict, Iterator, Optional, Sequence, Tuple import grpc from google.protobuf.reflection import GeneratedProtocolMessageType @@ -134,12 +134,14 @@ def _get_response( method: str, request: str, timeout: int = DEFAULT_GRPC_TIMEOUT, - ): + ) -> Any: with self._channel() as channel: stub = DagsterApiStub(channel) return getattr(stub, method)(request, metadata=self._metadata, timeout=timeout) - def _raise_grpc_exception(self, e: Exception, timeout, custom_timeout_message=None): + def _raise_grpc_exception( + self, e: Exception, timeout: int, custom_timeout_message: Optional[str] = None + ) -> None: if isinstance(e, grpc.RpcError): if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED: # type: ignore # (bad stubs) raise DagsterUserCodeUnreachableError( @@ -160,7 +162,7 @@ def _query( timeout=DEFAULT_GRPC_TIMEOUT, custom_timeout_message=None, **kwargs, - ): + ) -> Any: try: return self._get_response(method, request=request_type(**kwargs), timeout=timeout) except Exception as e: @@ -180,10 +182,10 @@ def _get_streaming_response( def _streaming_query( self, - method, - request_type, - timeout=DEFAULT_GRPC_TIMEOUT, - custom_timeout_message=None, + method: str, + request_type: Any, + timeout: int = DEFAULT_GRPC_TIMEOUT, + custom_timeout_message: Optional[str] = None, **kwargs, ) -> Iterator[Any]: try: @@ -195,17 +197,17 @@ def _streaming_query( e, timeout=timeout, custom_timeout_message=custom_timeout_message ) - def ping(self, echo: str): + def ping(self, echo: str) -> str: check.str_param(echo, "echo") res = self._query("Ping", api_pb2.PingRequest, echo=echo) return res.echo - def heartbeat(self, echo: str = ""): + def heartbeat(self, echo: str = "") -> str: check.str_param(echo, "echo") res = self._query("Heartbeat", api_pb2.PingRequest, echo=echo) return res.echo - def streaming_ping(self, sequence_length: int, echo: str): + def streaming_ping(self, sequence_length: int, echo: str) -> Iterator[Dict[str, Any]]: check.int_param(sequence_length, "sequence_length") check.str_param(echo, "echo") @@ -224,7 +226,9 @@ def get_server_id(self, timeout: int = DEFAULT_GRPC_TIMEOUT) -> str: res = self._query("GetServerId", api_pb2.Empty, timeout=timeout) return res.server_id - def execution_plan_snapshot(self, execution_plan_snapshot_args: ExecutionPlanSnapshotArgs): + def execution_plan_snapshot( + self, execution_plan_snapshot_args: ExecutionPlanSnapshotArgs + ) -> str: check.inst_param( execution_plan_snapshot_args, "execution_plan_snapshot_args", ExecutionPlanSnapshotArgs ) @@ -235,11 +239,11 @@ def execution_plan_snapshot(self, execution_plan_snapshot_args: ExecutionPlanSna ) return res.serialized_execution_plan_snapshot - def list_repositories(self): + def list_repositories(self) -> str: res = self._query("ListRepositories", api_pb2.ListRepositoriesRequest) return res.serialized_list_repositories_response_or_error - def external_partition_names(self, partition_names_args): + def external_partition_names(self, partition_names_args: PartitionNamesArgs) -> str: check.inst_param(partition_names_args, "partition_names_args", PartitionNamesArgs) res = self._query( @@ -250,7 +254,7 @@ def external_partition_names(self, partition_names_args): return res.serialized_external_partition_names_or_external_partition_execution_error - def external_partition_config(self, partition_args): + def external_partition_config(self, partition_args: PartitionArgs) -> str: check.inst_param(partition_args, "partition_args", PartitionArgs) res = self._query( @@ -261,7 +265,7 @@ def external_partition_config(self, partition_args): return res.serialized_external_partition_config_or_external_partition_execution_error - def external_partition_tags(self, partition_args): + def external_partition_tags(self, partition_args: PartitionArgs) -> str: check.inst_param(partition_args, "partition_args", PartitionArgs) res = self._query( @@ -272,7 +276,9 @@ def external_partition_tags(self, partition_args): return res.serialized_external_partition_tags_or_external_partition_execution_error - def external_partition_set_execution_params(self, partition_set_execution_param_args): + def external_partition_set_execution_params( + self, partition_set_execution_param_args: PartitionSetExecutionParamArgs + ) -> str: check.inst_param( partition_set_execution_param_args, "partition_set_execution_param_args", @@ -291,7 +297,7 @@ def external_partition_set_execution_params(self, partition_set_execution_param_ return "".join([chunk.serialized_chunk for chunk in chunks]) - def external_pipeline_subset(self, pipeline_subset_snapshot_args): + def external_pipeline_subset(self, pipeline_subset_snapshot_args: JobSubsetSnapshotArgs) -> str: check.inst_param( pipeline_subset_snapshot_args, "pipeline_subset_snapshot_args", @@ -306,14 +312,14 @@ def external_pipeline_subset(self, pipeline_subset_snapshot_args): return res.serialized_external_pipeline_subset_result - def reload_code(self, timeout: int): + def reload_code(self, timeout: int) -> Any: return self._query("ReloadCode", api_pb2.ReloadCodeRequest, timeout=timeout) def external_repository( self, external_repository_origin: ExternalRepositoryOrigin, defer_snapshots: bool = False, - ): + ) -> str: check.inst_param( external_repository_origin, "external_repository_origin", @@ -334,7 +340,7 @@ def external_job( self, external_repository_origin: ExternalRepositoryOrigin, job_name: str, - ): + ) -> str: check.inst_param( external_repository_origin, "external_repository_origin", @@ -352,8 +358,8 @@ def streaming_external_repository( self, external_repository_origin: ExternalRepositoryOrigin, defer_snapshots: bool = False, - timeout=DEFAULT_REPOSITORY_GRPC_TIMEOUT, - ): + timeout: int = DEFAULT_REPOSITORY_GRPC_TIMEOUT, + ) -> Iterator[Dict[str, Any]]: for res in self._streaming_query( "StreamingExternalRepository", api_pb2.ExternalRepositoryRequest, @@ -368,8 +374,10 @@ def streaming_external_repository( } def external_schedule_execution( - self, external_schedule_execution_args, timeout=DEFAULT_SCHEDULE_GRPC_TIMEOUT - ): + self, + external_schedule_execution_args: ExternalScheduleExecutionArgs, + timeout: int = DEFAULT_SCHEDULE_GRPC_TIMEOUT, + ) -> str: check.inst_param( external_schedule_execution_args, "external_schedule_execution_args", @@ -399,7 +407,9 @@ def external_schedule_execution( return "".join([chunk.serialized_chunk for chunk in chunks]) - def external_sensor_execution(self, sensor_execution_args, timeout=DEFAULT_SENSOR_GRPC_TIMEOUT): + def external_sensor_execution( + self, sensor_execution_args: SensorExecutionArgs, timeout: int = DEFAULT_SENSOR_GRPC_TIMEOUT + ) -> str: check.inst_param( sensor_execution_args, "sensor_execution_args", @@ -434,7 +444,7 @@ def external_sensor_execution(self, sensor_execution_args, timeout=DEFAULT_SENSO return "".join([chunk.serialized_chunk for chunk in chunks]) - def external_notebook_data(self, notebook_path: str): + def external_notebook_data(self, notebook_path: str) -> str: check.str_param(notebook_path, "notebook_path") res = self._query( "ExternalNotebookData", @@ -443,11 +453,11 @@ def external_notebook_data(self, notebook_path: str): ) return res.content - def shutdown_server(self, timeout=15): + def shutdown_server(self, timeout: int = 15) -> str: res = self._query("ShutdownServer", api_pb2.Empty, timeout=timeout) return res.serialized_shutdown_server_result - def cancel_execution(self, cancel_execution_request): + def cancel_execution(self, cancel_execution_request: CancelExecutionRequest) -> str: check.inst_param( cancel_execution_request, "cancel_execution_request", @@ -466,7 +476,7 @@ def can_cancel_execution( self, can_cancel_execution_request: CanCancelExecutionRequest, timeout: int = DEFAULT_GRPC_TIMEOUT, - ): + ) -> str: check.inst_param( can_cancel_execution_request, "can_cancel_execution_request", @@ -482,7 +492,7 @@ def can_cancel_execution( return res.serialized_can_cancel_execution_result - def start_run(self, execute_run_args: ExecuteExternalJobArgs): + def start_run(self, execute_run_args: ExecuteExternalJobArgs) -> str: check.inst_param(execute_run_args, "execute_run_args", ExecuteExternalJobArgs) with DagsterInstance.from_ref(execute_run_args.instance_ref) as instance: # type: ignore # (possible none) @@ -505,15 +515,15 @@ def start_run(self, execute_run_args: ExecuteExternalJobArgs): ) raise - def get_current_image(self): + def get_current_image(self) -> str: res = self._query("GetCurrentImage", api_pb2.Empty) return res.serialized_current_image - def get_current_runs(self): + def get_current_runs(self) -> str: res = self._query("GetCurrentRuns", api_pb2.Empty) return res.serialized_current_runs - def health_check_query(self): + def health_check_query(self) -> Any: try: with self._channel() as channel: response = HealthStub(channel).Check(