Skip to content

Commit

Permalink
Type annotations for grpc/client.py
Browse files Browse the repository at this point in the history
  • Loading branch information
dpeng817 committed Nov 9, 2023
1 parent 92f5396 commit 6d77a1b
Showing 1 changed file with 44 additions and 34 deletions.
78 changes: 44 additions & 34 deletions python_modules/dagster/dagster/_grpc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sys
from contextlib import contextmanager
from threading import Event
from typing import Any, Iterator, Optional, Sequence, Tuple
from typing import Any, Dict, Iterator, Optional, Sequence, Tuple

import grpc
from google.protobuf.reflection import GeneratedProtocolMessageType
Expand Down Expand Up @@ -134,12 +134,14 @@ def _get_response(
method: str,
request: str,
timeout: int = DEFAULT_GRPC_TIMEOUT,
):
) -> Any:
with self._channel() as channel:
stub = DagsterApiStub(channel)
return getattr(stub, method)(request, metadata=self._metadata, timeout=timeout)

def _raise_grpc_exception(self, e: Exception, timeout, custom_timeout_message=None):
def _raise_grpc_exception(
self, e: Exception, timeout: int, custom_timeout_message: Optional[str] = None
) -> None:
if isinstance(e, grpc.RpcError):
if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED: # type: ignore # (bad stubs)
raise DagsterUserCodeUnreachableError(
Expand All @@ -160,7 +162,7 @@ def _query(
timeout=DEFAULT_GRPC_TIMEOUT,
custom_timeout_message=None,
**kwargs,
):
) -> Any:
try:
return self._get_response(method, request=request_type(**kwargs), timeout=timeout)
except Exception as e:
Expand All @@ -180,10 +182,10 @@ def _get_streaming_response(

def _streaming_query(
self,
method,
request_type,
timeout=DEFAULT_GRPC_TIMEOUT,
custom_timeout_message=None,
method: str,
request_type: Any,
timeout: int = DEFAULT_GRPC_TIMEOUT,
custom_timeout_message: Optional[str] = None,
**kwargs,
) -> Iterator[Any]:
try:
Expand All @@ -195,17 +197,17 @@ def _streaming_query(
e, timeout=timeout, custom_timeout_message=custom_timeout_message
)

def ping(self, echo: str):
def ping(self, echo: str) -> str:
check.str_param(echo, "echo")
res = self._query("Ping", api_pb2.PingRequest, echo=echo)
return res.echo

def heartbeat(self, echo: str = ""):
def heartbeat(self, echo: str = "") -> str:
check.str_param(echo, "echo")
res = self._query("Heartbeat", api_pb2.PingRequest, echo=echo)
return res.echo

def streaming_ping(self, sequence_length: int, echo: str):
def streaming_ping(self, sequence_length: int, echo: str) -> Iterator[Dict[str, Any]]:
check.int_param(sequence_length, "sequence_length")
check.str_param(echo, "echo")

Expand All @@ -224,7 +226,9 @@ def get_server_id(self, timeout: int = DEFAULT_GRPC_TIMEOUT) -> str:
res = self._query("GetServerId", api_pb2.Empty, timeout=timeout)
return res.server_id

def execution_plan_snapshot(self, execution_plan_snapshot_args: ExecutionPlanSnapshotArgs):
def execution_plan_snapshot(
self, execution_plan_snapshot_args: ExecutionPlanSnapshotArgs
) -> str:
check.inst_param(
execution_plan_snapshot_args, "execution_plan_snapshot_args", ExecutionPlanSnapshotArgs
)
Expand All @@ -235,11 +239,11 @@ def execution_plan_snapshot(self, execution_plan_snapshot_args: ExecutionPlanSna
)
return res.serialized_execution_plan_snapshot

def list_repositories(self):
def list_repositories(self) -> str:
res = self._query("ListRepositories", api_pb2.ListRepositoriesRequest)
return res.serialized_list_repositories_response_or_error

def external_partition_names(self, partition_names_args):
def external_partition_names(self, partition_names_args: PartitionNamesArgs) -> str:
check.inst_param(partition_names_args, "partition_names_args", PartitionNamesArgs)

res = self._query(
Expand All @@ -250,7 +254,7 @@ def external_partition_names(self, partition_names_args):

return res.serialized_external_partition_names_or_external_partition_execution_error

def external_partition_config(self, partition_args):
def external_partition_config(self, partition_args: PartitionArgs) -> str:
check.inst_param(partition_args, "partition_args", PartitionArgs)

res = self._query(
Expand All @@ -261,7 +265,7 @@ def external_partition_config(self, partition_args):

return res.serialized_external_partition_config_or_external_partition_execution_error

def external_partition_tags(self, partition_args):
def external_partition_tags(self, partition_args: PartitionArgs) -> str:
check.inst_param(partition_args, "partition_args", PartitionArgs)

res = self._query(
Expand All @@ -272,7 +276,9 @@ def external_partition_tags(self, partition_args):

return res.serialized_external_partition_tags_or_external_partition_execution_error

def external_partition_set_execution_params(self, partition_set_execution_param_args):
def external_partition_set_execution_params(
self, partition_set_execution_param_args: PartitionSetExecutionParamArgs
) -> str:
check.inst_param(
partition_set_execution_param_args,
"partition_set_execution_param_args",
Expand All @@ -291,7 +297,7 @@ def external_partition_set_execution_params(self, partition_set_execution_param_

return "".join([chunk.serialized_chunk for chunk in chunks])

def external_pipeline_subset(self, pipeline_subset_snapshot_args):
def external_pipeline_subset(self, pipeline_subset_snapshot_args: JobSubsetSnapshotArgs) -> str:
check.inst_param(
pipeline_subset_snapshot_args,
"pipeline_subset_snapshot_args",
Expand All @@ -306,14 +312,14 @@ def external_pipeline_subset(self, pipeline_subset_snapshot_args):

return res.serialized_external_pipeline_subset_result

def reload_code(self, timeout: int):
def reload_code(self, timeout: int) -> Any:
return self._query("ReloadCode", api_pb2.ReloadCodeRequest, timeout=timeout)

def external_repository(
self,
external_repository_origin: ExternalRepositoryOrigin,
defer_snapshots: bool = False,
):
) -> str:
check.inst_param(
external_repository_origin,
"external_repository_origin",
Expand All @@ -334,7 +340,7 @@ def external_job(
self,
external_repository_origin: ExternalRepositoryOrigin,
job_name: str,
):
) -> str:
check.inst_param(
external_repository_origin,
"external_repository_origin",
Expand All @@ -352,8 +358,8 @@ def streaming_external_repository(
self,
external_repository_origin: ExternalRepositoryOrigin,
defer_snapshots: bool = False,
timeout=DEFAULT_REPOSITORY_GRPC_TIMEOUT,
):
timeout: int = DEFAULT_REPOSITORY_GRPC_TIMEOUT,
) -> Iterator[Dict[str, Any]]:
for res in self._streaming_query(
"StreamingExternalRepository",
api_pb2.ExternalRepositoryRequest,
Expand All @@ -368,8 +374,10 @@ def streaming_external_repository(
}

def external_schedule_execution(
self, external_schedule_execution_args, timeout=DEFAULT_SCHEDULE_GRPC_TIMEOUT
):
self,
external_schedule_execution_args: ExternalScheduleExecutionArgs,
timeout: int = DEFAULT_SCHEDULE_GRPC_TIMEOUT,
) -> str:
check.inst_param(
external_schedule_execution_args,
"external_schedule_execution_args",
Expand Down Expand Up @@ -399,7 +407,9 @@ def external_schedule_execution(

return "".join([chunk.serialized_chunk for chunk in chunks])

def external_sensor_execution(self, sensor_execution_args, timeout=DEFAULT_SENSOR_GRPC_TIMEOUT):
def external_sensor_execution(
self, sensor_execution_args: SensorExecutionArgs, timeout: int = DEFAULT_SENSOR_GRPC_TIMEOUT
) -> str:
check.inst_param(
sensor_execution_args,
"sensor_execution_args",
Expand Down Expand Up @@ -434,7 +444,7 @@ def external_sensor_execution(self, sensor_execution_args, timeout=DEFAULT_SENSO

return "".join([chunk.serialized_chunk for chunk in chunks])

def external_notebook_data(self, notebook_path: str):
def external_notebook_data(self, notebook_path: str) -> str:
check.str_param(notebook_path, "notebook_path")
res = self._query(
"ExternalNotebookData",
Expand All @@ -443,11 +453,11 @@ def external_notebook_data(self, notebook_path: str):
)
return res.content

def shutdown_server(self, timeout=15):
def shutdown_server(self, timeout: int = 15) -> str:
res = self._query("ShutdownServer", api_pb2.Empty, timeout=timeout)
return res.serialized_shutdown_server_result

def cancel_execution(self, cancel_execution_request):
def cancel_execution(self, cancel_execution_request: CancelExecutionRequest) -> str:
check.inst_param(
cancel_execution_request,
"cancel_execution_request",
Expand All @@ -466,7 +476,7 @@ def can_cancel_execution(
self,
can_cancel_execution_request: CanCancelExecutionRequest,
timeout: int = DEFAULT_GRPC_TIMEOUT,
):
) -> str:
check.inst_param(
can_cancel_execution_request,
"can_cancel_execution_request",
Expand All @@ -482,7 +492,7 @@ def can_cancel_execution(

return res.serialized_can_cancel_execution_result

def start_run(self, execute_run_args: ExecuteExternalJobArgs):
def start_run(self, execute_run_args: ExecuteExternalJobArgs) -> str:
check.inst_param(execute_run_args, "execute_run_args", ExecuteExternalJobArgs)

with DagsterInstance.from_ref(execute_run_args.instance_ref) as instance: # type: ignore # (possible none)
Expand All @@ -505,15 +515,15 @@ def start_run(self, execute_run_args: ExecuteExternalJobArgs):
)
raise

def get_current_image(self):
def get_current_image(self) -> str:
res = self._query("GetCurrentImage", api_pb2.Empty)
return res.serialized_current_image

def get_current_runs(self):
def get_current_runs(self) -> str:
res = self._query("GetCurrentRuns", api_pb2.Empty)
return res.serialized_current_runs

def health_check_query(self):
def health_check_query(self) -> Any:
try:
with self._channel() as channel:
response = HealthStub(channel).Check(
Expand Down

0 comments on commit 6d77a1b

Please sign in to comment.