From 9660735f4352000777beeb2d5ea5931941fb0297 Mon Sep 17 00:00:00 2001
From: Li Tan
Date: Thu, 18 Jan 2024 16:38:31 -0800
Subject: [PATCH] Add ruff auto fix (#657)

* Add ruff

Signed-off-by: Bernd Verst

* Autoformat all files using ruff

Signed-off-by: Bernd Verst

* Run ruff check on CI

Signed-off-by: Bernd Verst

* fix up type checker exemption

Signed-off-by: Bernd Verst

---------

Signed-off-by: Bernd Verst
Co-authored-by: Bernd Verst
---
 .github/workflows/build.yaml | 31 +-
 README.md | 12 +-
 dapr/actor/__init__.py | 16 +-
 dapr/actor/actor_interface.py | 3 +
 dapr/actor/client/proxy.py | 71 +-
 dapr/actor/runtime/_call_type.py | 1 +
 dapr/actor/runtime/_reminder_data.py | 34 +-
 dapr/actor/runtime/_state_provider.py | 34 +-
 dapr/actor/runtime/_timer_data.py | 22 +-
 dapr/actor/runtime/_type_information.py | 19 +-
 dapr/actor/runtime/_type_utils.py | 32 +-
 dapr/actor/runtime/actor.py | 36 +-
 dapr/actor/runtime/config.py | 77 +-
 dapr/actor/runtime/context.py | 19 +-
 dapr/actor/runtime/manager.py | 41 +-
 dapr/actor/runtime/method_dispatcher.py | 11 +-
 dapr/actor/runtime/remindable.py | 11 +-
 dapr/actor/runtime/runtime.py | 37 +-
 dapr/actor/runtime/state_change.py | 3 +-
 dapr/actor/runtime/state_manager.py | 87 +-
 dapr/aio/clients/__init__.py | 74 +-
 dapr/aio/clients/grpc/_asynchelpers.py | 32 +-
 dapr/aio/clients/grpc/client.py | 950 +++++++++--------
 dapr/clients/__init__.py | 108 +-
 dapr/clients/base.py | 30 +-
 dapr/clients/exceptions.py | 15 +-
 dapr/clients/grpc/_helpers.py | 48 +-
 dapr/clients/grpc/_request.py | 57 +-
 dapr/clients/grpc/_response.py | 197 ++--
 dapr/clients/grpc/_state.py | 6 +-
 dapr/clients/grpc/client.py | 954 +++++++++---------
 dapr/clients/http/client.py | 60 +-
 dapr/clients/http/dapr_actor_http_client.py | 71 +-
 .../http/dapr_invocation_http_client.py | 57 +-
 dapr/conf/global_settings.py | 6 +-
 dapr/conf/helpers.py | 34 +-
 dapr/serializers/__init__.py | 5 +-
 dapr/serializers/base.py | 11 +-
 dapr/serializers/json.py | 35 +-
 dapr/serializers/util.py | 27 +-
 dapr/version/__init__.py | 4 +-
 dev-requirements.txt | 2 +
 docs/conf.py | 25 +-
 examples/configuration/configuration.py | 32 +-
 examples/demo_actor/demo_actor/demo_actor.py | 64 +-
 .../demo_actor/demo_actor_client.py | 7 +-
 .../demo_actor/demo_actor/demo_actor_flask.py | 16 +-
 .../demo_actor/demo_actor_service.py | 10 +-
 examples/demo_workflow/app.py | 95 +-
 examples/distributed_lock/lock.py | 24 +-
 .../grpc_proxying/helloworld_service_pb2.py | 60 +-
 .../helloworld_service_pb2_grpc.py | 75 +-
 examples/grpc_proxying/invoke-caller.py | 13 +-
 examples/grpc_proxying/invoke-receiver.py | 18 +-
 .../invoke-binding/invoke-input-binding.py | 6 +-
 .../invoke-binding/invoke-output-binding.py | 7 +-
 examples/invoke-custom-data/invoke-caller.py | 8 +-
 .../invoke-custom-data/invoke-receiver.py | 9 +-
 .../invoke-custom-data/proto/response_pb2.py | 34 +-
 .../proto/response_pb2_grpc.py | 1 -
 examples/invoke-http/invoke-caller.py | 11 +-
 examples/invoke-http/invoke-receiver.py | 10 +-
 examples/invoke-simple/invoke-caller.py | 9 +-
 examples/invoke-simple/invoke-receiver.py | 6 +-
 examples/metadata/app.py | 8 +-
 examples/pubsub-simple/publisher.py | 32 +-
 examples/pubsub-simple/subscriber.py | 50 +-
 examples/secret_store/example.py | 11 +-
 examples/state_store/state_store.py | 48 +-
 .../state_store_query/state_store_query.py | 13 +-
 examples/w3c-tracing/invoke-caller.py | 22 +-
 examples/w3c-tracing/invoke-receiver.py | 20 +-
 examples/workflow/child_workflow.py | 21 +-
 examples/workflow/fan_out_fan_in.py | 24 +-
 examples/workflow/human_approval.py | 25 +-
 examples/workflow/monitor.py | 15 +-
 examples/workflow/task_chaining.py | 21 +-
 .../dapr/ext/fastapi/__init__.py | 5 +-
 .../dapr/ext/fastapi/actor.py | 112 +-
 ext/dapr-ext-fastapi/dapr/ext/fastapi/app.py | 56 +-
 ext/dapr-ext-fastapi/setup.py | 26 +-
 ext/dapr-ext-fastapi/tests/test_app.py | 85 +-
 ext/dapr-ext-fastapi/tests/test_dapractor.py | 42 +-
 ext/dapr-ext-grpc/dapr/ext/grpc/__init__.py | 14 +-
 ext/dapr-ext-grpc/dapr/ext/grpc/_servicier.py | 74 +-
 ext/dapr-ext-grpc/dapr/ext/grpc/app.py | 42 +-
 ext/dapr-ext-grpc/setup.py | 26 +-
 ext/dapr-ext-grpc/tests/test_app.py | 40 +-
 ext/dapr-ext-grpc/tests/test_servicier.py | 145 ++-
 .../tests/test_topic_event_response.py | 2 +-
 .../dapr/ext/workflow/__init__.py | 20 +-
 .../dapr/ext/workflow/dapr_workflow_client.py | 102 +-
 .../ext/workflow/dapr_workflow_context.py | 57 +-
 .../dapr/ext/workflow/logger/__init__.py | 5 +-
 .../dapr/ext/workflow/logger/logger.py | 4 +-
 .../dapr/ext/workflow/logger/options.py | 11 +-
 .../dapr/ext/workflow/retry_policy.py | 31 +-
 .../dapr/ext/workflow/util.py | 5 +-
 .../ext/workflow/workflow_activity_context.py | 8 +-
 .../dapr/ext/workflow/workflow_context.py | 25 +-
 .../dapr/ext/workflow/workflow_runtime.py | 94 +-
 .../dapr/ext/workflow/workflow_state.py | 26 +-
 ext/dapr-ext-workflow/setup.py | 26 +-
 .../tests/test_dapr_workflow_context.py | 14 +-
 .../tests/test_workflow_activity_context.py | 7 +-
 .../tests/test_workflow_client.py | 73 +-
 .../tests/test_workflow_runtime.py | 43 +-
 .../tests/test_workflow_util.py | 1 -
 ext/flask_dapr/flask_dapr/__init__.py | 2 +-
 ext/flask_dapr/flask_dapr/actor.py | 70 +-
 ext/flask_dapr/flask_dapr/app.py | 51 +-
 ext/flask_dapr/setup.py | 26 +-
 ext/flask_dapr/tests/test_app.py | 67 +-
 pyproject.toml | 18 +
 setup.cfg | 2 +-
 setup.py | 26 +-
 tests/actor/fake_actor_classes.py | 48 +-
 tests/actor/fake_client.py | 47 +-
 tests/actor/test_actor.py | 113 ++-
 tests/actor/test_actor_id.py | 14 +-
 tests/actor/test_actor_manager.py | 79 +-
 tests/actor/test_actor_reentrancy.py | 170 ++--
 tests/actor/test_actor_runtime.py | 46 +-
 tests/actor/test_actor_runtime_config.py | 125 ++-
 tests/actor/test_client_proxy.py | 60 +-
 tests/actor/test_method_dispatcher.py | 8 +-
 tests/actor/test_reminder_data.py | 57 +-
 tests/actor/test_state_manager.py | 297 +++---
 tests/actor/test_timer_data.py | 36 +-
 tests/actor/test_type_utils.py | 19 +-
 tests/clients/certs.py | 10 +-
 tests/clients/fake_dapr_server.py | 120 +--
 tests/clients/fake_http_server.py | 31 +-
 tests/clients/test_client_interceptor.py | 17 +-
 tests/clients/test_dapr_async_grpc_client.py | 437 ++++----
 tests/clients/test_dapr_grpc_client.py | 416 ++++----
 tests/clients/test_dapr_grpc_request.py | 28 +-
 tests/clients/test_dapr_grpc_response.py | 67 +-
 .../test_http_service_invocation_client.py | 176 ++--
 .../test_secure_dapr_async_grpc_client.py | 6 +-
 tests/clients/test_secure_dapr_grpc_client.py | 6 +-
 ...t_secure_http_service_invocation_client.py | 36 +-
 tests/conf/helpers_test.py | 401 ++++++--
 .../test_default_json_serializer.py | 53 +-
 tests/serializers/test_util.py | 37 +-
 tox.ini | 8 +
 146 files changed, 4725 insertions(+), 3956 deletions(-)
 create mode 100644 pyproject.toml

diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml
index fb6fc2ada..277810e5f 100644
--- a/.github/workflows/build.yaml
+++ b/.github/workflows/build.yaml
@@ -20,7 +20,35 @@ on:
   merge_group:
 
 jobs:
+  lint:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - name: Set up Python 3.9
+        uses: actions/setup-python@v5
+        with:
+          python-version: 3.9
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install setuptools wheel tox
+      - name: Run Autoformatter
+        run: |
+          tox -e ruff
+          statusResult=$(git status -u --porcelain)
+          if [ -z "$statusResult" ]
+          then
+            exit 0
+          else
+            echo "Source files are not formatted correctly. Run 'tox -e ruff' to autoformat."
+            exit 1
+          fi
+      - name: Run Linter
+        run: |
+          tox -e flake8
+
   build:
+    needs: lint
     runs-on: ubuntu-latest
     strategy:
       fail-fast: false
@@ -36,9 +64,6 @@ jobs:
         run: |
           python -m pip install --upgrade pip
           pip install setuptools wheel tox
-      - name: Run Linter
-        run: |
-          tox -e flake8
       - name: Check Typing
         run: |
           tox -e type
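The `tox -e ruff` environment exercised by the CI step above comes from the new `pyproject.toml` (18 lines) and the `tox.ini` additions, whose contents fall outside the hunks shown in this patch. A plausible sketch, inferred only from the reformatted output in the diffs below (double quotes, lines up to roughly 100 characters); the keys and values are assumptions, not the committed files:

```toml
# pyproject.toml -- hypothetical sketch, not the committed file.
[tool.ruff]
# Inferred from the reflowed code below, which runs up to ~100 columns.
line-length = 100

[tool.ruff.format]
# The diffs below convert single quotes to double quotes, ruff's default.
quote-style = "double"
```

The `ruff` tox environment presumably runs ruff's formatter (e.g. `ruff format .`), so the CI step can fail the build whenever formatting leaves a dirty working tree.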
diff --git a/README.md b/README.md
index 233cfe3c9..faa834910 100644
--- a/README.md
+++ b/README.md
@@ -96,19 +96,25 @@ pip3 install -r dev-requirements.txt
 tox -e flake8
 ```
 
-5. Run unit-test
+5. Run autofix
+
+```bash
+tox -e ruff
+```
+
+6. Run unit-test
 
 ```bash
 tox -e py311
 ```
 
-6. Run type check
+7. Run type check
 
 ```bash
 tox -e type
 ```
 
-7. Run examples
+8. Run examples
 
 ```bash
 tox -e examples

diff --git a/dapr/actor/__init__.py b/dapr/actor/__init__.py
index 4323caae2..5c94a8f1d 100644
--- a/dapr/actor/__init__.py
+++ b/dapr/actor/__init__.py
@@ -22,12 +22,12 @@
 
 __all__ = [
-    'ActorInterface',
-    'ActorProxy',
-    'ActorProxyFactory',
-    'ActorId',
-    'Actor',
-    'ActorRuntime',
-    'Remindable',
-    'actormethod',
+    "ActorInterface",
+    "ActorProxy",
+    "ActorProxyFactory",
+    "ActorId",
+    "Actor",
+    "ActorRuntime",
+    "Remindable",
+    "actormethod",
 ]

diff --git a/dapr/actor/actor_interface.py b/dapr/actor/actor_interface.py
index dbe238c22..2de1c04ad 100644
--- a/dapr/actor/actor_interface.py
+++ b/dapr/actor/actor_interface.py
@@ -33,6 +33,7 @@ async def do_actor_method1(self, param):
     async def do_actor_method2(self, param):
         ...
     """
+
     ...
 
@@ -51,8 +52,10 @@ async def do_actor_call(self, param):
     Args:
         name (str, optional): the name of actor method.
     """
+
     def wrapper(funcobj):
         funcobj.__actormethod__ = name
         funcobj.__isabstractmethod__ = True
         return funcobj
+
     return wrapper

diff --git a/dapr/actor/client/proxy.py b/dapr/actor/client/proxy.py
index 58ed16c52..6585ce5c1 100644
--- a/dapr/actor/client/proxy.py
+++ b/dapr/actor/client/proxy.py
@@ -24,14 +24,17 @@
 from dapr.conf import settings
 
 # Actor factory Callable type hint.
-ACTOR_FACTORY_CALLBACK = Callable[[ActorInterface, str, str], 'ActorProxy']
+ACTOR_FACTORY_CALLBACK = Callable[[ActorInterface, str, str], "ActorProxy"]
 
 
 class ActorFactoryBase(ABC):
     @abstractmethod
     def create(
-            self, actor_type: str, actor_id: ActorId,
-            actor_interface: Optional[Type[ActorInterface]] = None) -> 'ActorProxy':
+        self,
+        actor_type: str,
+        actor_id: ActorId,
+        actor_interface: Optional[Type[ActorInterface]] = None,
+    ) -> "ActorProxy":
         ...
@@ -44,32 +47,36 @@ class ActorProxyFactory(ActorFactoryBase):
     """
 
     def __init__(
-            self,
-            message_serializer=DefaultJSONSerializer(),
-            http_timeout_seconds: int = settings.DAPR_HTTP_TIMEOUT_SECONDS):
+        self,
+        message_serializer=DefaultJSONSerializer(),
+        http_timeout_seconds: int = settings.DAPR_HTTP_TIMEOUT_SECONDS,
+    ):
         # TODO: support serializer for state store later
         self._dapr_client = DaprActorHttpClient(message_serializer, timeout=http_timeout_seconds)
         self._message_serializer = message_serializer
 
     def create(
-            self, actor_type: str, actor_id: ActorId,
-            actor_interface: Optional[Type[ActorInterface]] = None) -> 'ActorProxy':
+        self,
+        actor_type: str,
+        actor_id: ActorId,
+        actor_interface: Optional[Type[ActorInterface]] = None,
+    ) -> "ActorProxy":
         return ActorProxy(
-            self._dapr_client, actor_type, actor_id,
-            actor_interface, self._message_serializer)
+            self._dapr_client, actor_type, actor_id, actor_interface, self._message_serializer
+        )
 
 
 class CallableProxy:
     def __init__(
-            self, proxy: 'ActorProxy', attr_call_type: Dict[str, Any],
-            message_serializer: Serializer):
+        self, proxy: "ActorProxy", attr_call_type: Dict[str, Any], message_serializer: Serializer
+    ):
         self._proxy = proxy
         self._attr_call_type = attr_call_type
         self._message_serializer = message_serializer
 
     async def __call__(self, *args, **kwargs) -> Any:
         if len(args) > 1:
-            raise ValueError('does not support multiple arguments')
+            raise ValueError("does not support multiple arguments")
 
         bytes_data = None
         if len(args) > 0:
@@ -78,9 +85,9 @@ async def __call__(self, *args, **kwargs) -> Any:
             else:
                 bytes_data = self._message_serializer.serialize(args[0])
 
-        rtnval = await self._proxy.invoke_method(self._attr_call_type['actor_method'], bytes_data)
+        rtnval = await self._proxy.invoke_method(self._attr_call_type["actor_method"], bytes_data)
 
-        return self._message_serializer.deserialize(rtnval, self._attr_call_type['return_types'])
+        return self._message_serializer.deserialize(rtnval, self._attr_call_type["return_types"])
 
 
 class ActorProxy:
@@ -94,11 +101,13 @@ class ActorProxy:
     _default_proxy_factory = ActorProxyFactory()
 
     def __init__(
-            self, client: DaprActorClientBase,
-            actor_type: str,
-            actor_id: ActorId,
-            actor_interface: Optional[Type[ActorInterface]],
-            message_serializer: Serializer):
+        self,
+        client: DaprActorClientBase,
+        actor_type: str,
+        actor_id: ActorId,
+        actor_interface: Optional[Type[ActorInterface]],
+        message_serializer: Serializer,
+    ):
         self._dapr_client = client
         self._actor_id = actor_id
         self._actor_type = actor_type
@@ -120,10 +129,12 @@ def actor_type(self) -> str:
 
     @classmethod
     def create(
-            cls,
-            actor_type: str, actor_id: ActorId,
-            actor_interface: Optional[Type[ActorInterface]] = None,
-            actor_proxy_factory: Optional[ActorFactoryBase] = None) -> 'ActorProxy':
+        cls,
+        actor_type: str,
+        actor_id: ActorId,
+        actor_interface: Optional[Type[ActorInterface]] = None,
+        actor_proxy_factory: Optional[ActorFactoryBase] = None,
+    ) -> "ActorProxy":
         """Creates ActorProxy client to call actor.
 
         Args:
@@ -157,10 +168,11 @@ async def invoke_method(self, method: str, raw_body: Optional[bytes] = None) ->
         """
         if raw_body is not None and not isinstance(raw_body, bytes):
-            raise ValueError(f'raw_body {type(raw_body)} is not bytes type')
+            raise ValueError(f"raw_body {type(raw_body)} is not bytes type")
 
         return await self._dapr_client.invoke_method(
-            self._actor_type, str(self._actor_id), method, raw_body)
+            self._actor_type, str(self._actor_id), method, raw_body
+        )
 
     def __getattr__(self, name: str) -> CallableProxy:
         """Enables RPC style actor method invocation.
@@ -177,17 +189,18 @@ def __getattr__(self, name: str) -> CallableProxy:
             AttributeError: method is not defined in Actor interface.
         """
         if not self._actor_interface:
-            raise ValueError('actor_interface is not set. use invoke method.')
+            raise ValueError("actor_interface is not set. use invoke method.")
 
         if name not in self._dispatchable_attr:
             get_dispatchable_attrs_from_interface(self._actor_interface, self._dispatchable_attr)
 
         attr_call_type = self._dispatchable_attr.get(name)
         if attr_call_type is None:
-            raise AttributeError(f'{self._actor_interface.__class__} has no attribute {name}')
+            raise AttributeError(f"{self._actor_interface.__class__} has no attribute {name}")
 
         if name not in self._callable_proxies:
             self._callable_proxies[name] = CallableProxy(
-                self, attr_call_type, self._message_serializer)
+                self, attr_call_type, self._message_serializer
+            )
 
         return self._callable_proxies[name]

diff --git a/dapr/actor/runtime/_call_type.py b/dapr/actor/runtime/_call_type.py
index f47d853e6..a61b3ab8f 100644
--- a/dapr/actor/runtime/_call_type.py
+++ b/dapr/actor/runtime/_call_type.py
@@ -21,6 +21,7 @@ class ActorCallType(Enum):
     :class:`ActorMethodContext` includes :class:`ActorCallType` passing to
     :meth:`Actor._on_pre_actor_method` and :meth:`Actor._on_post_actor_method`
     """
+
     # Specifies that the method invoked is an actor interface method for a given client request.
     actor_interface_method = 0
     # Specifies that the method invoked is a timer callback method.
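As background for the `ActorProxy` reflow above: a proxy built from an `ActorInterface` dispatches attribute access through `CallableProxy`, so interface methods can be called RPC-style. A minimal usage sketch based on the signatures in this file; `DemoActorInterface` and `GetMyData` are illustrative names, not part of this patch:

```python
import asyncio

from dapr.actor import ActorId, ActorInterface, ActorProxy, actormethod


class DemoActorInterface(ActorInterface):
    # Hypothetical interface; @actormethod maps the Python method to the
    # actor method name used on the wire.
    @actormethod(name="GetMyData")
    async def get_my_data(self) -> object:
        ...


async def main():
    # __getattr__ resolves GetMyData via the interface's dispatchable attrs
    # and serializes arguments/results with the proxy's message serializer.
    proxy = ActorProxy.create("DemoActor", ActorId("1"), DemoActorInterface)
    data = await proxy.GetMyData()
    print(data)


if __name__ == "__main__":
    asyncio.run(main())
```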
diff --git a/dapr/actor/runtime/_reminder_data.py b/dapr/actor/runtime/_reminder_data.py
index a080ce590..fb6f7bbd5 100644
--- a/dapr/actor/runtime/_reminder_data.py
+++ b/dapr/actor/runtime/_reminder_data.py
@@ -32,8 +32,13 @@ class ActorReminderData:
     """
 
     def __init__(
-            self, reminder_name: str, state: Optional[bytes],
-            due_time: timedelta, period: timedelta, ttl: Optional[timedelta] = None):
+        self,
+        reminder_name: str,
+        state: Optional[bytes],
+        due_time: timedelta,
+        period: timedelta,
+        ttl: Optional[timedelta] = None,
+    ):
         """Creates new :class:`ActorReminderData` instance.
 
         Args:
@@ -52,7 +57,7 @@ def __init__(
         self._ttl = ttl
 
         if not isinstance(state, bytes):
-            raise ValueError(f'only bytes are allowed for state: {type(state)}')
+            raise ValueError(f"only bytes are allowed for state: {type(state)}")
 
         self._state = state
 
@@ -87,26 +92,27 @@ def as_dict(self) -> Dict[str, Any]:
         if self._state is not None:
             encoded_state = base64.b64encode(self._state)
         reminderDict: Dict[str, Any] = {
-            'reminderName': self._reminder_name,
-            'dueTime': self._due_time,
-            'period': self._period,
-            'data': encoded_state.decode("utf-8")
+            "reminderName": self._reminder_name,
+            "dueTime": self._due_time,
+            "period": self._period,
+            "data": encoded_state.decode("utf-8"),
         }
 
         if self._ttl is not None:
-            reminderDict.update({'ttl': self._ttl})
+            reminderDict.update({"ttl": self._ttl})
 
         return reminderDict
 
     @classmethod
-    def from_dict(cls, reminder_name: str, obj: Dict[str, Any]) -> 'ActorReminderData':
+    def from_dict(cls, reminder_name: str, obj: Dict[str, Any]) -> "ActorReminderData":
         """Creates :class:`ActorReminderData` object from dict object."""
-        b64encoded_state = obj.get('data')
+        b64encoded_state = obj.get("data")
         state_bytes = None
         if b64encoded_state is not None and len(b64encoded_state) > 0:
             state_bytes = base64.b64decode(b64encoded_state)
-        if 'ttl' in obj:
-            return ActorReminderData(reminder_name, state_bytes, obj['dueTime'], obj['period'],
-                                     obj['ttl'])
+        if "ttl" in obj:
+            return ActorReminderData(
+                reminder_name, state_bytes, obj["dueTime"], obj["period"], obj["ttl"]
+            )
         else:
-            return ActorReminderData(reminder_name, state_bytes, obj['dueTime'], obj['period'])
+            return ActorReminderData(reminder_name, state_bytes, obj["dueTime"], obj["period"])

diff --git a/dapr/actor/runtime/_state_provider.py b/dapr/actor/runtime/_state_provider.py
index 6b7f3227a..c5a745b3b 100644
--- a/dapr/actor/runtime/_state_provider.py
+++ b/dapr/actor/runtime/_state_provider.py
@@ -23,9 +23,9 @@
 
 # Mapping StateChangeKind to Dapr State Operation
 _MAP_CHANGE_KIND_TO_OPERATION = {
-    StateChangeKind.remove: b'delete',
-    StateChangeKind.add: b'upsert',
-    StateChangeKind.update: b'upsert',
+    StateChangeKind.remove: b"delete",
+    StateChangeKind.add: b"upsert",
+    StateChangeKind.update: b"upsert",
 }
 
@@ -34,16 +34,18 @@ class StateProvider:
     This provides the decorator methods to load and save states
     and check the existence of states.
     """
+
     def __init__(
-            self,
-            actor_client: DaprActorClientBase,
-            state_serializer: Serializer = DefaultJSONSerializer()):
+        self,
+        actor_client: DaprActorClientBase,
+        state_serializer: Serializer = DefaultJSONSerializer(),
+    ):
         self._state_client = actor_client
         self._state_serializer = state_serializer
 
     async def try_load_state(
-            self, actor_type: str, actor_id: str,
-            state_name: str, state_type: Type[Any] = object) -> Tuple[bool, Any]:
+        self, actor_type: str, actor_id: str, state_name: str, state_type: Type[Any] = object
+    ) -> Tuple[bool, Any]:
         raw_state_value = await self._state_client.get_state(actor_type, actor_id, state_name)
         if (not raw_state_value) or len(raw_state_value) == 0:
             return (False, None)
@@ -55,8 +57,8 @@ async def contains_state(self, actor_type: str, actor_id: str, state_name: str)
         return (raw_state_value is not None) and len(raw_state_value) > 0
 
     async def save_state(
-            self, actor_type: str, actor_id: str,
-            state_changes: List[ActorStateChange]) -> None:
+        self, actor_type: str, actor_id: str, state_changes: List[ActorStateChange]
+    ) -> None:
         """
         Transactional state update request body:
         [
@@ -77,24 +79,24 @@
         """
         json_output = io.BytesIO()
-        json_output.write(b'[')
+        json_output.write(b"[")
         first_state = True
         for state in state_changes:
             if not first_state:
-                json_output.write(b',')
-            operation = _MAP_CHANGE_KIND_TO_OPERATION.get(state.change_kind) or b''
+                json_output.write(b",")
+            operation = _MAP_CHANGE_KIND_TO_OPERATION.get(state.change_kind) or b""
             json_output.write(b'{"operation":"')
             json_output.write(operation)
             json_output.write(b'","request":{"key":"')
-            json_output.write(state.state_name.encode('utf-8'))
+            json_output.write(state.state_name.encode("utf-8"))
             json_output.write(b'"')
             if state.value is not None:
                 serialized = self._state_serializer.serialize(state.value)
                 json_output.write(b',"value":')
                 json_output.write(serialized)
-            json_output.write(b'}}')
+            json_output.write(b"}}")
             first_state = False
-        json_output.write(b']')
+        json_output.write(b"]")
         data = json_output.getvalue()
         json_output.close()
         await self._state_client.save_state_transactionally(actor_type, actor_id, data)

diff --git a/dapr/actor/runtime/_timer_data.py b/dapr/actor/runtime/_timer_data.py
index 2b9518c2a..c274e2d9b 100644
--- a/dapr/actor/runtime/_timer_data.py
+++ b/dapr/actor/runtime/_timer_data.py
@@ -33,10 +33,14 @@ class ActorTimerData:
     """
 
     def __init__(
-            self, timer_name: str,
-            callback: TIMER_CALLBACK, state: Any,
-            due_time: timedelta, period: timedelta,
-            ttl: Optional[timedelta] = None):
+        self,
+        timer_name: str,
+        callback: TIMER_CALLBACK,
+        state: Any,
+        due_time: timedelta,
+        period: timedelta,
+        ttl: Optional[timedelta] = None,
+    ):
         """Create new :class:`ActorTimerData` instance.
 
         Args:
@@ -93,13 +97,13 @@ def as_dict(self) -> Dict[str, Any]:
         """
 
         timerDict: Dict[str, Any] = {
-            'callback': self._callback,
-            'data': self._state,
-            'dueTime': self._due_time,
-            'period': self._period
+            "callback": self._callback,
+            "data": self._state,
+            "dueTime": self._due_time,
+            "period": self._period,
         }
 
         if self._ttl:
-            timerDict.update({'ttl': self._ttl})
+            timerDict.update({"ttl": self._ttl})
 
         return timerDict

diff --git a/dapr/actor/runtime/_type_information.py b/dapr/actor/runtime/_type_information.py
index 9d393d37b..bbb8ec037 100644
--- a/dapr/actor/runtime/_type_information.py
+++ b/dapr/actor/runtime/_type_information.py
@@ -17,6 +17,7 @@
 from dapr.actor.runtime._type_utils import is_dapr_actor, get_actor_interfaces
 
 from typing import List, Type, TYPE_CHECKING
+
 if TYPE_CHECKING:
     from dapr.actor.actor_interface import ActorInterface  # noqa: F401
     from dapr.actor.runtime.actor import Actor  # noqa: F401
@@ -27,8 +28,12 @@ class ActorTypeInformation:
     implementing an actor.
     """
 
-    def __init__(self, name: str, implementation_class: Type['Actor'],
-                 actor_bases: List[Type['ActorInterface']]):
+    def __init__(
+        self,
+        name: str,
+        implementation_class: Type["Actor"],
+        actor_bases: List[Type["ActorInterface"]],
+    ):
         self._name = name
         self._impl_type = implementation_class
         self._actor_bases = actor_bases
@@ -39,12 +44,12 @@ def type_name(self) -> str:
         return self._name
 
     @property
-    def implementation_type(self) -> Type['Actor']:
+    def implementation_type(self) -> Type["Actor"]:
         """Returns Actor implementation type."""
         return self._impl_type
 
     @property
-    def actor_interfaces(self) -> List[Type['ActorInterface']]:
+    def actor_interfaces(self) -> List[Type["ActorInterface"]]:
         """Returns the list of :class:`ActorInterface` of this type."""
         return self._actor_bases
 
@@ -53,7 +58,7 @@ def is_remindable(self) -> bool:
         return Remindable in self._impl_type.__bases__
 
     @classmethod
-    def create(cls, actor_class: Type['Actor']) -> 'ActorTypeInformation':
+    def create(cls, actor_class: Type["Actor"]) -> "ActorTypeInformation":
         """Creates :class:`ActorTypeInformation` for actor_class.
 
         Args:
@@ -64,10 +69,10 @@ def create(cls, actor_class: Type['Actor']) -> 'ActorTypeInformation':
             and actor base class deriving :class:`ActorInterface`
         """
         if not is_dapr_actor(actor_class):
-            raise ValueError(f'{actor_class.__name__} is not actor')
+            raise ValueError(f"{actor_class.__name__} is not actor")
 
         actors = get_actor_interfaces(actor_class)
         if len(actors) == 0:
-            raise ValueError(f'{actor_class.__name__} does not implement ActorInterface')
+            raise ValueError(f"{actor_class.__name__} does not implement ActorInterface")
 
         return ActorTypeInformation(actor_class.__name__, actor_class, actors)

diff --git a/dapr/actor/runtime/_type_utils.py b/dapr/actor/runtime/_type_utils.py
index 6094f5db7..46d97c31a 100644
--- a/dapr/actor/runtime/_type_utils.py
+++ b/dapr/actor/runtime/_type_utils.py
@@ -20,16 +20,16 @@
 
 def get_class_method_args(func: Any) -> List[str]:
-    args = func.__code__.co_varnames[:func.__code__.co_argcount]
+    args = func.__code__.co_varnames[: func.__code__.co_argcount]
 
     # Exclude self, cls arguments
-    if args[0] == 'self' or args[0] == 'cls':
+    if args[0] == "self" or args[0] == "cls":
         args = args[1:]
     return list(args)
 
 
 def get_method_arg_types(func: Any) -> List[Type]:
-    annotations = getattr(func, '__annotations__')
+    annotations = getattr(func, "__annotations__")
     args = get_class_method_args(func)
     arg_types = []
     for arg_name in args:
@@ -39,26 +39,26 @@ def get_method_arg_types(func: Any) -> List[Type]:
 
 def get_method_return_types(func: Any) -> Type:
-    annotations = getattr(func, '__annotations__')
-    if len(annotations) == 0 or not annotations['return']:
+    annotations = getattr(func, "__annotations__")
+    if len(annotations) == 0 or not annotations["return"]:
         return object
-    return annotations['return']
+    return annotations["return"]
 
 
 def get_dispatchable_attrs_from_interface(
-        actor_interface: Type[ActorInterface],
-        dispatch_map: Dict[str, Any]) -> None:
+    actor_interface: Type[ActorInterface], dispatch_map: Dict[str, Any]
+) -> None:
     for attr, v in actor_interface.__dict__.items():
-        if attr.startswith('_') or not callable(v):
+        if attr.startswith("_") or not callable(v):
             continue
-        actor_method_name = getattr(v, '__actormethod__') if hasattr(v, '__actormethod__') else attr
+        actor_method_name = getattr(v, "__actormethod__") if hasattr(v, "__actormethod__") else attr
         dispatch_map[actor_method_name] = {
-            'actor_method': actor_method_name,
-            'method_name': attr,
-            'arg_names': get_class_method_args(v),
-            'arg_types': get_method_arg_types(v),
-            'return_types': get_method_return_types(v)
+            "actor_method": actor_method_name,
+            "method_name": attr,
+            "arg_names": get_class_method_args(v),
+            "arg_types": get_method_arg_types(v),
+            "return_types": get_method_return_types(v),
         }
 
@@ -77,7 +77,7 @@ def get_dispatchable_attrs(actor_class: Type[Actor]) -> Dict[str, Any]:
     # Find all user actor interfaces derived from ActorInterface
     actor_interfaces = get_actor_interfaces(actor_class)
     if len(actor_interfaces) == 0:
-        raise ValueError(f'{actor_class.__name__} has not inherited from ActorInterface')
+        raise ValueError(f"{actor_class.__name__} has not inherited from ActorInterface")
 
     # Find all dispatchable attributes
     dispatch_map: Dict[str, Any] = {}

diff --git a/dapr/actor/runtime/actor.py b/dapr/actor/runtime/actor.py
index e74239247..0e8eb7fcd 100644
--- a/dapr/actor/runtime/actor.py
+++ b/dapr/actor/runtime/actor.py
@@ -64,11 +64,17 @@ def runtime_ctx(self) -> ActorRuntimeContext:
         return self._runtime_ctx
 
     def __get_new_timer_name(self):
-        return f'{self.id}_Timer_{uuid.uuid4()}'
+        return f"{self.id}_Timer_{uuid.uuid4()}"
 
     async def register_timer(
-            self, name: Optional[str], callback: TIMER_CALLBACK, state: Any,
-            due_time: timedelta, period: timedelta, ttl: Optional[timedelta] = None) -> None:
+        self,
+        name: Optional[str],
+        callback: TIMER_CALLBACK,
+        state: Any,
+        due_time: timedelta,
+        period: timedelta,
+        ttl: Optional[timedelta] = None,
+    ) -> None:
         """Registers actor timer.
 
         All timers are stopped when the actor is deactivated as part of garbage collection.
@@ -88,7 +94,8 @@ async def register_timer(
         req_body = self._runtime_ctx.message_serializer.serialize(timer.as_dict())
         await self._runtime_ctx.dapr_client.register_timer(
-            self._runtime_ctx.actor_type_info.type_name, self.id.id, name, req_body)
+            self._runtime_ctx.actor_type_info.type_name, self.id.id, name, req_body
+        )
 
     async def unregister_timer(self, name: str) -> None:
         """Unregisters actor timer.
@@ -97,11 +104,17 @@ async def unregister_timer(self, name: str) -> None:
             name (str): the name of the timer to unregister.
         """
         await self._runtime_ctx.dapr_client.unregister_timer(
-            self._runtime_ctx.actor_type_info.type_name, self.id.id, name)
+            self._runtime_ctx.actor_type_info.type_name, self.id.id, name
+        )
 
     async def register_reminder(
-            self, name: str, state: bytes,
-            due_time: timedelta, period: timedelta, ttl: Optional[timedelta] = None) -> None:
+        self,
+        name: str,
+        state: bytes,
+        due_time: timedelta,
+        period: timedelta,
+        ttl: Optional[timedelta] = None,
+    ) -> None:
         """Registers actor reminder.
 
         Reminders are a mechanism to trigger persistent callbacks on an actor at specified times.
@@ -124,7 +137,8 @@ async def register_reminder(
         reminder = ActorReminderData(name, state, due_time, period, ttl)
         req_body = self._runtime_ctx.message_serializer.serialize(reminder.as_dict())
         await self._runtime_ctx.dapr_client.register_reminder(
-            self._runtime_ctx.actor_type_info.type_name, self.id.id, name, req_body)
+            self._runtime_ctx.actor_type_info.type_name, self.id.id, name, req_body
+        )
 
     async def unregister_reminder(self, name: str) -> None:
         """Unregisters actor reminder.
@@ -133,7 +147,8 @@ async def unregister_reminder(self, name: str) -> None:
             name (str): the name of the reminder to unregister.
         """
         await self._runtime_ctx.dapr_client.unregister_reminder(
-            self._runtime_ctx.actor_type_info.type_name, self.id.id, name)
+            self._runtime_ctx.actor_type_info.type_name, self.id.id, name
+        )
 
     async def _on_activate_internal(self) -> None:
         """Clears all state cache, calls the overridden :meth:`_on_activate`,
@@ -170,8 +185,7 @@ async def _on_post_actor_method_internal(self, method_context: ActorMethodContext
         await self._save_state_internal()
 
     async def _on_invoke_failed_internal(self, exception=None):
-        """Clears states in the cache when actor method invocation is failed.
-        """
+        """Clears states in the cache when actor method invocation is failed."""
         await self._reset_state_internal()
 
     async def _reset_state_internal(self) -> None:
diff --git a/dapr/actor/runtime/config.py b/dapr/actor/runtime/config.py
index 3b2099139..fef7060ff 100644
--- a/dapr/actor/runtime/config.py
+++ b/dapr/actor/runtime/config.py
@@ -18,10 +18,7 @@
 
 class ActorReentrancyConfig:
-    def __init__(
-            self,
-            enabled: bool = False,
-            maxStackDepth: int = 32):
+    def __init__(self, enabled: bool = False, maxStackDepth: int = 32):
         """Inits :class:`ActorReentrancyConfig` to optionally configure actor reentrancy.
 
@@ -37,8 +34,8 @@ def __init__(
     def as_dict(self) -> Dict[str, Any]:
         """Returns ActorReentrancyConfig as a dict."""
         return {
-            'enabled': self._enabled,
-            'maxStackDepth': self._maxStackDepth,
+            "enabled": self._enabled,
+            "maxStackDepth": self._maxStackDepth,
         }
 
@@ -48,14 +45,15 @@ class ActorTypeConfig:
     """
 
     def __init__(
-            self,
-            actor_type: str,
-            actor_idle_timeout: Optional[timedelta] = None,
-            actor_scan_interval: Optional[timedelta] = None,
-            drain_ongoing_call_timeout: Optional[timedelta] = None,
-            drain_rebalanced_actors: Optional[bool] = None,
-            reentrancy: Optional[ActorReentrancyConfig] = None,
-            reminders_storage_partitions: Optional[int] = None):
+        self,
+        actor_type: str,
+        actor_idle_timeout: Optional[timedelta] = None,
+        actor_scan_interval: Optional[timedelta] = None,
+        drain_ongoing_call_timeout: Optional[timedelta] = None,
+        drain_rebalanced_actors: Optional[bool] = None,
+        reentrancy: Optional[ActorReentrancyConfig] = None,
+        reminders_storage_partitions: Optional[int] = None,
+    ):
         """Inits :class:`ActorTypeConfig` to configure the behavior
         of a specific actor type when dapr runtime starts.
 
@@ -87,26 +85,25 @@ def as_dict(self) -> Dict[str, Any]:
         """Returns ActorTypeConfig as a dict."""
         configDict: Dict[str, Any] = dict()
 
-        configDict['entities'] = [self._actor_type]
+        configDict["entities"] = [self._actor_type]
 
         if self._actor_idle_timeout is not None:
-            configDict.update({'actorIdleTimeout': self._actor_idle_timeout})
+            configDict.update({"actorIdleTimeout": self._actor_idle_timeout})
 
         if self._actor_scan_interval is not None:
-            configDict.update({'actorScanInterval': self._actor_scan_interval})
+            configDict.update({"actorScanInterval": self._actor_scan_interval})
 
         if self._drain_ongoing_call_timeout is not None:
-            configDict.update({'drainOngoingCallTimeout': self._drain_ongoing_call_timeout})
+            configDict.update({"drainOngoingCallTimeout": self._drain_ongoing_call_timeout})
 
         if self._drain_rebalanced_actors is not None:
-            configDict.update({'drainRebalancedActors': self._drain_rebalanced_actors})
+            configDict.update({"drainRebalancedActors": self._drain_rebalanced_actors})
 
         if self._reentrancy:
-            configDict.update({'reentrancy': self._reentrancy.as_dict()})
+            configDict.update({"reentrancy": self._reentrancy.as_dict()})
 
         if self._reminders_storage_partitions:
-            configDict.update(
-                {'remindersStoragePartitions': self._reminders_storage_partitions})
+            configDict.update({"remindersStoragePartitions": self._reminders_storage_partitions})
 
         return configDict
 
@@ -117,14 +114,15 @@ class ActorRuntimeConfig:
     """
 
     def __init__(
-            self,
-            actor_idle_timeout: Optional[timedelta] = timedelta(hours=1),
-            actor_scan_interval: Optional[timedelta] = timedelta(seconds=30),
-            drain_ongoing_call_timeout: Optional[timedelta] = timedelta(minutes=1),
-            drain_rebalanced_actors: Optional[bool] = True,
-            reentrancy: Optional[ActorReentrancyConfig] = None,
-            reminders_storage_partitions: Optional[int] = None,
-            actor_type_configs: List[ActorTypeConfig] = []):
+        self,
+        actor_idle_timeout: Optional[timedelta] = timedelta(hours=1),
+        actor_scan_interval: Optional[timedelta] = timedelta(seconds=30),
+        drain_ongoing_call_timeout: Optional[timedelta] = timedelta(minutes=1),
+        drain_rebalanced_actors: Optional[bool] = True,
+        reentrancy: Optional[ActorReentrancyConfig] = None,
+        reminders_storage_partitions: Optional[int] = None,
+        actor_type_configs: List[ActorTypeConfig] = [],
+    ):
         """Inits :class:`ActorRuntimeConfig` to configure actors when dapr runtime starts.
 
         Args:
@@ -175,24 +173,23 @@ def as_dict(self) -> Dict[str, Any]:
         entities: Set[str] = self._entities
 
         configDict: Dict[str, Any] = {
-            'actorIdleTimeout': self._actor_idle_timeout,
-            'actorScanInterval': self._actor_scan_interval,
-            'drainOngoingCallTimeout': self._drain_ongoing_call_timeout,
-            'drainRebalancedActors': self._drain_rebalanced_actors,
+            "actorIdleTimeout": self._actor_idle_timeout,
+            "actorScanInterval": self._actor_scan_interval,
+            "drainOngoingCallTimeout": self._drain_ongoing_call_timeout,
+            "drainRebalancedActors": self._drain_rebalanced_actors,
         }
 
         if self._reentrancy:
-            configDict.update({'reentrancy': self._reentrancy.as_dict()})
+            configDict.update({"reentrancy": self._reentrancy.as_dict()})
 
         if self._reminders_storage_partitions:
-            configDict.update(
-                {'remindersStoragePartitions': self._reminders_storage_partitions})
+            configDict.update({"remindersStoragePartitions": self._reminders_storage_partitions})
 
-        configDict['entitiesConfig'] = []
+        configDict["entitiesConfig"] = []
         for entityConfig in self._entitiesConfig:
-            configDict['entitiesConfig'].append(entityConfig.as_dict())
+            configDict["entitiesConfig"].append(entityConfig.as_dict())
             entities.add(entityConfig._actor_type)
 
-        configDict['entities'] = list(entities)
+        configDict["entities"] = list(entities)
 
         return configDict
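Putting the reflowed constructors above together: per-type settings nest inside the runtime config. A sketch with illustrative values, based only on the signatures and `as_dict` logic shown in this file:

```python
from datetime import timedelta

from dapr.actor.runtime.config import (
    ActorReentrancyConfig,
    ActorRuntimeConfig,
    ActorTypeConfig,
)

config = ActorRuntimeConfig(
    actor_idle_timeout=timedelta(hours=1),
    reentrancy=ActorReentrancyConfig(enabled=True),
    actor_type_configs=[
        ActorTypeConfig(
            actor_type="DemoActor",  # illustrative type name
            actor_scan_interval=timedelta(seconds=10),
            reminders_storage_partitions=7,
        )
    ],
)

# as_dict() appends each per-type config to "entitiesConfig" and folds its
# actor type into the top-level "entities" list.
print(config.as_dict())
```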
diff --git a/dapr/actor/runtime/context.py b/dapr/actor/runtime/context.py
index c90bcfd48..571cdeb3d 100644
--- a/dapr/actor/runtime/context.py
+++ b/dapr/actor/runtime/context.py
@@ -19,6 +19,7 @@
 from dapr.serializers import Serializer
 
 from typing import Callable, Optional, TYPE_CHECKING
+
 if TYPE_CHECKING:
     from dapr.actor.runtime.actor import Actor
     from dapr.actor.runtime._type_information import ActorTypeInformation
@@ -42,10 +43,13 @@ class ActorRuntimeContext:
     """
 
     def __init__(
-            self, actor_type_info: 'ActorTypeInformation',
-            message_serializer: Serializer, state_serializer: Serializer,
-            actor_client: DaprActorClientBase,
-            actor_factory: Optional[Callable[['ActorRuntimeContext', ActorId], 'Actor']] = None):
+        self,
+        actor_type_info: "ActorTypeInformation",
+        message_serializer: Serializer,
+        state_serializer: Serializer,
+        actor_client: DaprActorClientBase,
+        actor_factory: Optional[Callable[["ActorRuntimeContext", ActorId], "Actor"]] = None,
+    ):
         """Creates :class:`ActorRuntimeContext` object.
 
         Args:
@@ -68,7 +72,7 @@ def __init__(
         self._provider: StateProvider = StateProvider(self._dapr_client, state_serializer)
 
     @property
-    def actor_type_info(self) -> 'ActorTypeInformation':
+    def actor_type_info(self) -> "ActorTypeInformation":
         """Return :class:`ActorTypeInformation` in this context."""
         return self._actor_type_info
 
@@ -92,7 +96,7 @@ def dapr_client(self) -> DaprActorClientBase:
         """Return dapr client."""
         return self._dapr_client
 
-    def create_actor(self, actor_id: ActorId) -> 'Actor':
+    def create_actor(self, actor_id: ActorId) -> "Actor":
         """Create the object of :class:`Actor` for :class:`ActorId`.
 
         Args:
@@ -103,8 +107,7 @@ def create_actor(self, actor_id: ActorId) -> 'Actor':
         """
         return self._actor_factory(self, actor_id)
 
-    def _default_actor_factory(
-            self, ctx: 'ActorRuntimeContext', actor_id: ActorId) -> 'Actor':
+    def _default_actor_factory(self, ctx: "ActorRuntimeContext", actor_id: ActorId) -> "Actor":
         """Creates new Actor with actor_id.
 
         Args:

diff --git a/dapr/actor/runtime/manager.py b/dapr/actor/runtime/manager.py
index 60f55ff52..8e9ea4292 100644
--- a/dapr/actor/runtime/manager.py
+++ b/dapr/actor/runtime/manager.py
@@ -27,8 +27,8 @@
 from dapr.actor.runtime._reminder_data import ActorReminderData
 from dapr.actor.runtime.reentrancy_context import reentrancy_ctx
 
-TIMER_METHOD_NAME = 'fire_timer'
-REMINDER_METHOD_NAME = 'receive_reminder'
+TIMER_METHOD_NAME = "fire_timer"
+REMINDER_METHOD_NAME = "receive_reminder"
 
 
 class ActorManager:
@@ -57,15 +57,16 @@ async def deactivate_actor(self, actor_id: ActorId):
         async with self._active_actors_lock:
             deactivated_actor = self._active_actors.pop(actor_id.id, None)
             if not deactivated_actor:
-                raise ValueError(f'{actor_id} is not activated')
+                raise ValueError(f"{actor_id} is not activated")
         await deactivated_actor._on_deactivate_internal()
 
     async def fire_reminder(
-            self, actor_id: ActorId,
-            reminder_name: str, request_body: bytes) -> None:
+        self, actor_id: ActorId, reminder_name: str, request_body: bytes
+    ) -> None:
         if not self._runtime_ctx.actor_type_info.is_remindable():
             raise ValueError(
-                f'{self._runtime_ctx.actor_type_info.type_name} does not implment Remindable.')
+                f"{self._runtime_ctx.actor_type_info.type_name} does not implement Remindable."
+            )
         request_obj = self._message_serializer.deserialize(request_body, object)
         if isinstance(request_obj, dict):
             reminder_data = ActorReminderData.from_dict(reminder_name, request_obj)
@@ -74,26 +75,29 @@
         async def invoke_reminder(actor: Actor) -> Optional[bytes]:
             reminder = getattr(actor, REMINDER_METHOD_NAME)
             if reminder is not None:
-                await reminder(reminder_data.reminder_name, reminder_data.state,
-                               reminder_data.due_time, reminder_data.period, reminder_data.ttl)
+                await reminder(
+                    reminder_data.reminder_name,
+                    reminder_data.state,
+                    reminder_data.due_time,
+                    reminder_data.period,
+                    reminder_data.ttl,
+                )
             return None
 
         await self._dispatch_internal(actor_id, self._reminder_method_context, invoke_reminder)
 
-    async def fire_timer(
-            self, actor_id: ActorId,
-            timer_name: str, request_body: bytes) -> None:
+    async def fire_timer(self, actor_id: ActorId, timer_name: str, request_body: bytes) -> None:
         timer = self._message_serializer.deserialize(request_body, object)
 
         async def invoke_timer(actor: Actor) -> Optional[bytes]:
-            await actor._fire_timer_internal(timer['callback'], timer['data'])
+            await actor._fire_timer_internal(timer["callback"], timer["data"])
             return None
 
         await self._dispatch_internal(actor_id, self._timer_method_context, invoke_timer)
 
     async def dispatch(
-            self, actor_id: ActorId,
-            actor_method_name: str, request_body: bytes) -> bytes:
+        self, actor_id: ActorId, actor_method_name: str, request_body: bytes
+    ) -> bytes:
         method_context = ActorMethodContext.create_for_actor(actor_method_name)
         arg_types = self._dispatcher.get_arg_types(actor_method_name)
 
@@ -113,8 +117,11 @@ async def invoke_method(actor):
         return self._message_serializer.serialize(rtn_obj)
 
     async def _dispatch_internal(
-            self, actor_id: ActorId, method_context: ActorMethodContext,
-            dispatch_action: Callable[[Actor], Coroutine[Any, Any, Optional[bytes]]]) -> object:
+        self,
+        actor_id: ActorId,
+        method_context: ActorMethodContext,
+        dispatch_action: Callable[[Actor], Coroutine[Any, Any, Optional[bytes]]],
+    ) -> object:
         # Activate actor when actor is invoked.
         if actor_id.id not in self._active_actors:
             await self.activate_actor(actor_id)
 
         async with self._active_actors_lock:
             actor = self._active_actors.get(actor_id.id, None)
             if not actor:
-                raise ValueError(f'{actor_id} is not activated')
+                raise ValueError(f"{actor_id} is not activated")
 
         try:
             if reentrancy_ctx.get(None) is not None:

diff --git a/dapr/actor/runtime/method_dispatcher.py b/dapr/actor/runtime/method_dispatcher.py
index 68bf4dbd6..aeff21b44 100644
--- a/dapr/actor/runtime/method_dispatcher.py
+++ b/dapr/actor/runtime/method_dispatcher.py
@@ -25,21 +25,20 @@ def __init__(self, type_info: ActorTypeInformation):
 
     async def dispatch(self, actor: Actor, name: str, *args, **kwargs) -> Any:
         self._check_name_exist(name)
-        return await getattr(actor, self._dispatch_mapping[name]['method_name'])(*args, **kwargs)
+        return await getattr(actor, self._dispatch_mapping[name]["method_name"])(*args, **kwargs)
 
     def get_arg_names(self, name: str) -> List[str]:
         self._check_name_exist(name)
-        return self._dispatch_mapping[name]['arg_names']
+        return self._dispatch_mapping[name]["arg_names"]
 
     def get_arg_types(self, name: str) -> List[Any]:
         self._check_name_exist(name)
-        return self._dispatch_mapping[name]['arg_types']
+        return self._dispatch_mapping[name]["arg_types"]
 
     def get_return_type(self, name: str) -> Dict[str, Any]:
         self._check_name_exist(name)
-        return self._dispatch_mapping[name]['return_types']
+        return self._dispatch_mapping[name]["return_types"]
 
     def _check_name_exist(self, name: str):
         if name not in self._dispatch_mapping:
-            raise AttributeError(
-                f'type object {self.__class__.__name__} has no method {name}')
+            raise AttributeError(f"type object {self.__class__.__name__} has no method {name}")

diff --git a/dapr/actor/runtime/remindable.py b/dapr/actor/runtime/remindable.py
index d06f49117..6d6a2a3f0 100644
--- a/dapr/actor/runtime/remindable.py
+++ b/dapr/actor/runtime/remindable.py
@@ -24,9 +24,14 @@ class Remindable(ABC):
     """
 
     @abstractmethod
-    async def receive_reminder(self, name: str, state: bytes,
-                               due_time: timedelta, period: timedelta,
-                               ttl: Optional[timedelta] = None) -> None:
+    async def receive_reminder(
+        self,
+        name: str,
+        state: bytes,
+        due_time: timedelta,
+        period: timedelta,
+        ttl: Optional[timedelta] = None,
+    ) -> None:
         """A callback which will be called when reminder is triggered.
 
         Args:
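`ActorManager.fire_reminder` above rejects actor types that do not derive `Remindable` (`is_remindable` checks `Remindable in __bases__` per `_type_information.py`). A sketch of the receiving side; the class name, reminder name, and body are illustrative:

```python
from datetime import timedelta
from typing import Optional

from dapr.actor import Actor, Remindable


class ReminderActor(Actor, Remindable):  # hypothetical actor
    async def _on_activate(self) -> None:
        # Reminder state must be bytes; ActorReminderData raises ValueError
        # for anything else.
        await self.register_reminder(
            "refresh",
            b"reminder payload",
            due_time=timedelta(seconds=5),
            period=timedelta(minutes=1),
        )

    async def receive_reminder(
        self,
        name: str,
        state: bytes,
        due_time: timedelta,
        period: timedelta,
        ttl: Optional[timedelta] = None,
    ) -> None:
        # fire_reminder deserializes the payload into ActorReminderData and
        # forwards these fields here.
        print(f"reminder {name}: {state!r}")
```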
diff --git a/dapr/actor/runtime/runtime.py b/dapr/actor/runtime/runtime.py
index 8f89d0834..c08d67883 100644
--- a/dapr/actor/runtime/runtime.py
+++ b/dapr/actor/runtime/runtime.py
@@ -42,10 +42,12 @@ class ActorRuntime:
 
     @classmethod
     async def register_actor(
-            cls, actor: Type[Actor],
-            message_serializer: Serializer = DefaultJSONSerializer(),
-            state_serializer: Serializer = DefaultJSONSerializer(),
-            http_timeout_seconds: int = settings.DAPR_HTTP_TIMEOUT_SECONDS) -> None:
+        cls,
+        actor: Type[Actor],
+        message_serializer: Serializer = DefaultJSONSerializer(),
+        state_serializer: Serializer = DefaultJSONSerializer(),
+        http_timeout_seconds: int = settings.DAPR_HTTP_TIMEOUT_SECONDS,
+    ) -> None:
         """Registers an :class:`Actor` object with the runtime.
 
         Args:
@@ -83,14 +85,18 @@ async def deactivate(cls, actor_type_name: str, actor_id: str) -> None:
         """
         manager = await cls._get_actor_manager(actor_type_name)
         if not manager:
-            raise ValueError(f'{actor_type_name} is not registered.')
+            raise ValueError(f"{actor_type_name} is not registered.")
         await manager.deactivate_actor(ActorId(actor_id))
 
     @classmethod
     async def dispatch(
-            cls, actor_type_name: str, actor_id: str,
-            actor_method_name: str, request_body: bytes,
-            reentrancy_id: Optional[str] = None) -> bytes:
+        cls,
+        actor_type_name: str,
+        actor_id: str,
+        actor_method_name: str,
+        request_body: bytes,
+        reentrancy_id: Optional[str] = None,
+    ) -> bytes:
         """Dispatches actor method defined in actor_type.
 
         Args:
@@ -110,13 +116,13 @@ async def dispatch(
         reentrancy_ctx.set(reentrancy_id)
         manager = await cls._get_actor_manager(actor_type_name)
         if not manager:
-            raise ValueError(f'{actor_type_name} is not registered.')
+            raise ValueError(f"{actor_type_name} is not registered.")
         return await manager.dispatch(ActorId(actor_id), actor_method_name, request_body)
 
     @classmethod
     async def fire_reminder(
-            cls, actor_type_name: str, actor_id: str,
-            name: str, state: bytes) -> None:
+        cls, actor_type_name: str, actor_id: str, name: str, state: bytes
+    ) -> None:
         """Fires a reminder for the Actor.
 
         Args:
@@ -131,14 +137,11 @@ async def fire_reminder(
 
         manager = await cls._get_actor_manager(actor_type_name)
         if not manager:
-            raise ValueError(f'{actor_type_name} is not registered.')
+            raise ValueError(f"{actor_type_name} is not registered.")
         await manager.fire_reminder(ActorId(actor_id), name, state)
 
     @classmethod
-    async def fire_timer(
-            cls, actor_type_name: str,
-            actor_id: str, name: str,
-            state: bytes) -> None:
+    async def fire_timer(cls, actor_type_name: str, actor_id: str, name: str, state: bytes) -> None:
         """Fires a timer for the Actor.
 
         Args:
@@ -152,7 +155,7 @@ async def fire_timer(
         """
         manager = await cls._get_actor_manager(actor_type_name)
         if not manager:
-            raise ValueError(f'{actor_type_name} is not registered.')
+            raise ValueError(f"{actor_type_name} is not registered.")
         await manager.fire_timer(ActorId(actor_id), name, state)
 
     @classmethod

diff --git a/dapr/actor/runtime/state_change.py b/dapr/actor/runtime/state_change.py
index ad3fc3d82..de6580850 100644
--- a/dapr/actor/runtime/state_change.py
+++ b/dapr/actor/runtime/state_change.py
@@ -16,13 +16,14 @@
 from enum import Enum
 from typing import TypeVar, Generic
 
-T = TypeVar('T')
+T = TypeVar("T")
 
 
 class StateChangeKind(Enum):
     """A enumeration that represents the kind of state change for an actor state
     when saves change is called to a set of actor states.
     """
+
     # No change in state
     none = 0
     # The state needs to be added

diff --git a/dapr/actor/runtime/state_manager.py b/dapr/actor/runtime/state_manager.py
index 91fa03790..cf4768009 100644
--- a/dapr/actor/runtime/state_manager.py
+++ b/dapr/actor/runtime/state_manager.py
@@ -20,11 +20,12 @@
 from dapr.actor.runtime.reentrancy_context import reentrancy_ctx
 
 from typing import Any, Callable, Dict, Generic, List, Tuple, TypeVar, Optional, TYPE_CHECKING
+
 if TYPE_CHECKING:
     from dapr.actor.runtime.actor import Actor
 
-T = TypeVar('T')
-CONTEXT: ContextVar[Optional[Dict[str, Any]]] = ContextVar('state_tracker_context')
+T = TypeVar("T")
+CONTEXT: ContextVar[Optional[Dict[str, Any]]] = ContextVar("state_tracker_context")
 
 
 class StateMetadata(Generic[T]):
@@ -50,30 +51,30 @@ def change_kind(self, new_kind: StateChangeKind) -> None:
 
 
 class ActorStateManager(Generic[T]):
-    def __init__(self, actor: 'Actor'):
+    def __init__(self, actor: "Actor"):
         self._actor = actor
         if not actor.runtime_ctx:
-            raise AttributeError('runtime context was not set')
+            raise AttributeError("runtime context was not set")
 
         self._type_name = actor.runtime_ctx.actor_type_info.type_name
         self._default_state_change_tracker: Dict[str, StateMetadata] = {}
 
     async def add_state(self, state_name: str, value: T) -> None:
         if not await self.try_add_state(state_name, value):
-            raise ValueError(f'The actor state name {state_name} already exist.')
+            raise ValueError(f"The actor state name {state_name} already exist.")
 
     async def try_add_state(self, state_name: str, value: T) -> bool:
         state_change_tracker = self._get_contextual_state_tracker()
         if state_name in state_change_tracker:
             state_metadata = state_change_tracker[state_name]
             if state_metadata.change_kind == StateChangeKind.remove:
-                state_change_tracker[state_name] = \
-                    StateMetadata(value, StateChangeKind.update)
+                state_change_tracker[state_name] = StateMetadata(value, StateChangeKind.update)
                 return True
             return False
 
         existed = await self._actor.runtime_ctx.state_provider.contains_state(
-            self._type_name, self._actor.id.id, state_name)
+            self._type_name, self._actor.id.id, state_name
+        )
         if not existed:
             return False
 
@@ -85,7 +86,7 @@ async def get_state(self, state_name: str) -> Optional[T]:
         if has_value:
             return val
         else:
-            raise KeyError(f'Actor State with name {state_name} was not found.')
+            raise KeyError(f"Actor State with name {state_name} was not found.")
 
     async def try_get_state(self, state_name: str) -> Tuple[bool, Optional[T]]:
         state_change_tracker = self._get_contextual_state_tracker()
@@ -95,7 +96,8 @@ async def try_get_state(self, state_name: str) -> Tuple[bool, Optional[T]]:
                 return False, None
             return True, state_metadata.value
         has_value, val = await self._actor.runtime_ctx.state_provider.try_load_state(
-            self._type_name, self._actor.id.id, state_name)
+            self._type_name, self._actor.id.id, state_name
+        )
         if has_value:
             state_change_tracker[state_name] = StateMetadata(val, StateChangeKind.none)
         return has_value, val
@@ -106,14 +108,17 @@ async def set_state(self, state_name: str, value: T) -> None:
             state_metadata = state_change_tracker[state_name]
             state_metadata.value = value
 
-            if state_metadata.change_kind == StateChangeKind.none \
-                    or state_metadata.change_kind == StateChangeKind.remove:
+            if (
+                state_metadata.change_kind == StateChangeKind.none
+                or state_metadata.change_kind == StateChangeKind.remove
+            ):
                 state_metadata.change_kind = StateChangeKind.update
 
             state_change_tracker[state_name] = state_metadata
             return
 
         existed = await self._actor.runtime_ctx.state_provider.contains_state(
-            self._type_name, self._actor.id.id, state_name)
+            self._type_name, self._actor.id.id, state_name
+        )
         if existed:
             state_change_tracker[state_name] = StateMetadata(value, StateChangeKind.update)
         else:
@@ -121,7 +126,7 @@ async def set_state(self, state_name: str, value: T) -> None:
 
     async def remove_state(self, state_name: str) -> None:
         if not await self.try_remove_state(state_name):
-            raise KeyError(f'Actor State with name {state_name} was not found.')
+            raise KeyError(f"Actor State with name {state_name} was not found.")
 
     async def try_remove_state(self, state_name: str) -> bool:
         state_change_tracker = self._get_contextual_state_tracker()
@@ -136,7 +141,8 @@ async def try_remove_state(self, state_name: str) -> bool:
             return True
 
         existed = await self._actor.runtime_ctx.state_provider.contains_state(
-            self._type_name, self._actor.id.id, state_name)
+            self._type_name, self._actor.id.id, state_name
+        )
         if existed:
             state_change_tracker[state_name] = StateMetadata(None, StateChangeKind.remove)
             return True
@@ -148,30 +154,33 @@ async def contains_state(self, state_name: str) -> bool:
             state_metadata = state_change_tracker[state_name]
             return state_metadata.change_kind != StateChangeKind.remove
         return await self._actor.runtime_ctx.state_provider.contains_state(
-            self._type_name, self._actor.id.id, state_name)
+            self._type_name, self._actor.id.id, state_name
+        )
 
     async def get_or_add_state(self, state_name: str, value: T) -> Optional[T]:
         state_change_tracker = self._get_contextual_state_tracker()
         has_value, val = await self.try_get_state(state_name)
         if has_value:
             return val
-        change_kind = StateChangeKind.update if self.is_state_marked_for_remove(state_name) \
+        change_kind = (
+            StateChangeKind.update
+            if self.is_state_marked_for_remove(state_name)
             else StateChangeKind.add
+        )
         state_change_tracker[state_name] = StateMetadata(value, change_kind)
         return value
 
     async def add_or_update_state(
-            self, state_name: str,
-            value: T, update_value_factory: Callable[[str, T], T]) -> T:
+        self, state_name: str, value: T, update_value_factory: Callable[[str, T], T]
+    ) -> T:
         if not callable(update_value_factory):
-            raise AttributeError('update_value_factory is not callable')
+            raise AttributeError("update_value_factory is not callable")
 
         state_change_tracker = self._get_contextual_state_tracker()
         if state_name in state_change_tracker:
             state_metadata = state_change_tracker[state_name]
             if state_metadata.change_kind == StateChangeKind.remove:
-                state_change_tracker[state_name] = \
-                    StateMetadata(value, StateChangeKind.update)
+                state_change_tracker[state_name] = StateMetadata(value, StateChangeKind.update)
                 return value
             new_value = update_value_factory(state_name, state_metadata.value)
             state_metadata.value = new_value
@@ -181,15 +190,14 @@ async def add_or_update_state(
             return new_value
 
         has_value, val = await self._actor.runtime_ctx.state_provider.try_load_state(
-            self._type_name, self._actor.id.id, state_name)
+            self._type_name, self._actor.id.id, state_name
+        )
         if has_value:
             new_value = update_value_factory(state_name, val)
-            state_change_tracker[state_name] = \
-                StateMetadata(new_value, StateChangeKind.update)
+            state_change_tracker[state_name] = StateMetadata(new_value, StateChangeKind.update)
             return new_value
 
-        state_change_tracker[state_name] = \
-            StateMetadata(value, StateChangeKind.add)
+        state_change_tracker[state_name] = StateMetadata(value, StateChangeKind.add)
         return value
 
     async def get_state_names(self) -> List[str]:
@@ -222,36 +230,37 @@ async def save_state(self) -> None:
         for state_name, state_metadata in state_change_tracker.items():
             if state_metadata.change_kind == StateChangeKind.none:
                 continue
-            state_changes.append(ActorStateChange(
-                state_name, state_metadata.value,
-                state_metadata.change_kind))
+            state_changes.append(
+                ActorStateChange(state_name, state_metadata.value, state_metadata.change_kind)
+            )
             if state_metadata.change_kind == StateChangeKind.remove:
                 states_to_remove.append(state_name)
             # Mark the states as unmodified so that tracking for next invocation is done correctly.
             state_metadata.change_kind = StateChangeKind.none
 
         if len(state_changes) > 0:
             await self._actor.runtime_ctx.state_provider.save_state(
-                self._type_name, self._actor.id.id, state_changes)
+                self._type_name, self._actor.id.id, state_changes
+            )
 
         for state_name in states_to_remove:
             state_change_tracker.pop(state_name, None)
 
     def is_state_marked_for_remove(self, state_name: str) -> bool:
         state_change_tracker = self._get_contextual_state_tracker()
-        return state_name in state_change_tracker and \
-            state_change_tracker[state_name].change_kind == StateChangeKind.remove
+        return (
+            state_name in state_change_tracker
+            and state_change_tracker[state_name].change_kind == StateChangeKind.remove
+        )
 
     def _get_contextual_state_tracker(self) -> Dict[str, StateMetadata]:
         context = CONTEXT.get(None)
-        if (context is not None and reentrancy_ctx.get(None) is not None):
-            return context['tracker']
+        if context is not None and reentrancy_ctx.get(None) is not None:
+            return context["tracker"]
         else:
             return self._default_state_change_tracker
 
     def set_state_context(self, contextID: Optional[str]):
-        if (contextID is not None):
-            CONTEXT.set({
-                'id': contextID,
-                'tracker': {}})
+        if contextID is not None:
+            CONTEXT.set({"id": contextID, "tracker": {}})
         else:
             CONTEXT.set(None)
         return
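A sketch of how the tracked-state API above looks from inside an actor: reads and writes go through the contextual tracker, and `save_state` flushes the accumulated `ActorStateChange`s transactionally. The class, method, and state names are illustrative; `self._state_manager` follows the pattern used by the SDK's demo actors:

```python
from dapr.actor import Actor


class CounterActor(Actor):  # hypothetical actor
    async def bump(self) -> int:
        # get_or_add_state records StateChangeKind.add on first use.
        count = await self._state_manager.get_or_add_state("counter", 0)
        # set_state flips the tracked entry to StateChangeKind.update.
        await self._state_manager.set_state("counter", count + 1)
        # save_state converts tracked entries into ActorStateChange items
        # and hands them to StateProvider.save_state in one transaction.
        await self._state_manager.save_state()
        return count + 1
```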
Args: @@ -74,27 +80,30 @@ def __init__( invocation_protocol = settings.DAPR_API_METHOD_INVOCATION_PROTOCOL.upper() - if invocation_protocol == 'HTTP': + if invocation_protocol == "HTTP": if http_timeout_seconds is None: http_timeout_seconds = settings.DAPR_HTTP_TIMEOUT_SECONDS - self.invocation_client = DaprInvocationHttpClient(headers_callback=headers_callback, - timeout=http_timeout_seconds) - elif invocation_protocol == 'GRPC': + self.invocation_client = DaprInvocationHttpClient( + headers_callback=headers_callback, timeout=http_timeout_seconds + ) + elif invocation_protocol == "GRPC": pass else: raise DaprInternalError( - f'Unknown value for DAPR_API_METHOD_INVOCATION_PROTOCOL: {invocation_protocol}') + f"Unknown value for DAPR_API_METHOD_INVOCATION_PROTOCOL: {invocation_protocol}" + ) async def invoke_method( - self, - app_id: str, - method_name: str, - data: Union[bytes, str, GrpcMessage], - content_type: Optional[str] = None, - metadata: Optional[MetadataTuple] = None, - http_verb: Optional[str] = None, - http_querystring: Optional[MetadataTuple] = None, - timeout: Optional[int] = None) -> InvokeMethodResponse: + self, + app_id: str, + method_name: str, + data: Union[bytes, str, GrpcMessage], + content_type: Optional[str] = None, + metadata: Optional[MetadataTuple] = None, + http_verb: Optional[str] = None, + http_querystring: Optional[MetadataTuple] = None, + timeout: Optional[int] = None, + ) -> InvokeMethodResponse: """Invoke a service method over gRPC or HTTP. Args: @@ -119,7 +128,8 @@ async def invoke_method( metadata=metadata, http_verb=http_verb, http_querystring=http_querystring, - timeout=timeout) + timeout=timeout, + ) else: return await super().invoke_method( app_id, @@ -129,5 +139,5 @@ async def invoke_method( metadata=metadata, http_verb=http_verb, http_querystring=http_querystring, - timeout=timeout + timeout=timeout, ) diff --git a/dapr/aio/clients/grpc/_asynchelpers.py b/dapr/aio/clients/grpc/_asynchelpers.py index 484e2cdf3..73e837f3d 100644 --- a/dapr/aio/clients/grpc/_asynchelpers.py +++ b/dapr/aio/clients/grpc/_asynchelpers.py @@ -20,14 +20,16 @@ class _ClientCallDetailsAsync( - namedtuple( - '_ClientCallDetails', - ['method', 'timeout', 'metadata', 'credentials', 'wait_for_ready']), - ClientCallDetails): + namedtuple( + "_ClientCallDetails", ["method", "timeout", "metadata", "credentials", "wait_for_ready"] + ), + ClientCallDetails, +): """This is an implementation of the ClientCallDetails interface needed for interceptors. This class takes five named values and inherits the ClientCallDetails from grpc package. This class encloses the values that describe a RPC to be invoked. """ + pass @@ -46,9 +48,7 @@ class DaprClientInterceptorAsync(UnaryUnaryClientInterceptor): intercepted_channel = grpc.intercept_channel(grpc_channel, interceptor) """ - def __init__( - self, - metadata: List[Tuple[str, str]]): + def __init__(self, metadata: List[Tuple[str, str]]): """Initializes the metadata field for the class. Args: @@ -58,9 +58,7 @@ def __init__( self._metadata = metadata - async def _intercept_call( - self, - client_call_details: ClientCallDetails) -> ClientCallDetails: + async def _intercept_call(self, client_call_details: ClientCallDetails) -> ClientCallDetails: """Internal intercept_call implementation which adds metadata to grpc metadata in the RPC call details. 
@@ -78,15 +76,15 @@ async def _intercept_call( metadata.extend(self._metadata) new_call_details = _ClientCallDetailsAsync( - client_call_details.method, client_call_details.timeout, metadata, - client_call_details.credentials, client_call_details.wait_for_ready) + client_call_details.method, + client_call_details.timeout, + metadata, + client_call_details.credentials, + client_call_details.wait_for_ready, + ) return new_call_details - async def intercept_unary_unary( - self, - continuation, - client_call_details, - request): + async def intercept_unary_unary(self, continuation, client_call_details, request): """This method intercepts a unary-unary gRPC call. This is the implementation of the abstract method defined in UnaryUnaryClientInterceptor defined in grpc. This is invoked automatically by grpc based on the order in which interceptors are added to the channel. diff --git a/dapr/aio/clients/grpc/client.py b/dapr/aio/clients/grpc/client.py index bc3c46b2f..41101985a 100644 --- a/dapr/aio/clients/grpc/client.py +++ b/dapr/aio/clients/grpc/client.py @@ -35,7 +35,7 @@ UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor, StreamUnaryClientInterceptor, - StreamStreamClientInterceptor + StreamStreamClientInterceptor, ) from dapr.clients.exceptions import DaprInternalError @@ -57,7 +57,7 @@ from dapr.clients.grpc._request import ( InvokeMethodRequest, BindingRequest, - TransactionalStateOperation + TransactionalStateOperation, ) from dapr.clients.grpc._response import ( BindingResponse, @@ -105,12 +105,17 @@ class DaprGrpcClientAsync: def __init__( self, address: Optional[str] = None, - interceptors: Optional[List[Union[ - UnaryUnaryClientInterceptor, - UnaryStreamClientInterceptor, - StreamUnaryClientInterceptor, - StreamStreamClientInterceptor]]] = None, - max_grpc_message_length: Optional[int] = None + interceptors: Optional[ + List[ + Union[ + UnaryUnaryClientInterceptor, + UnaryStreamClientInterceptor, + StreamUnaryClientInterceptor, + StreamStreamClientInterceptor, + ] + ] + ] = None, + max_grpc_message_length: Optional[int] = None, ): """Connects to Dapr Runtime and initializes the gRPC client stub. @@ -123,43 +128,48 @@ def __init__( max_grpc_message_length (int, optional): The maximum grpc send and receive message length in bytes.
""" - useragent = f'dapr-sdk-python/{__version__}' + useragent = f"dapr-sdk-python/{__version__}" if not max_grpc_message_length: options = [ - ('grpc.primary_user_agent', useragent), + ("grpc.primary_user_agent", useragent), ] else: options = [ - ('grpc.max_send_message_length', max_grpc_message_length), - ('grpc.max_receive_message_length', max_grpc_message_length), - ('grpc.primary_user_agent', useragent) + ("grpc.max_send_message_length", max_grpc_message_length), + ("grpc.max_receive_message_length", max_grpc_message_length), + ("grpc.primary_user_agent", useragent), ] if not address: - address = settings.DAPR_GRPC_ENDPOINT or (f"{settings.DAPR_RUNTIME_HOST}:" - f"{settings.DAPR_GRPC_PORT}") + address = settings.DAPR_GRPC_ENDPOINT or ( + f"{settings.DAPR_RUNTIME_HOST}:" f"{settings.DAPR_GRPC_PORT}" + ) try: self._uri = GrpcEndpoint(address) except ValueError as error: - raise DaprInternalError(f'{error}') from error + raise DaprInternalError(f"{error}") from error if self._uri.tls: - self._channel = grpc.aio.secure_channel(self._uri.endpoint, - credentials=self.get_credentials(), - options=options) # type: ignore + self._channel = grpc.aio.secure_channel( + self._uri.endpoint, credentials=self.get_credentials(), options=options + ) # type: ignore else: - self._channel = grpc.aio.insecure_channel(self._uri.endpoint, - options) # type: ignore + self._channel = grpc.aio.insecure_channel(self._uri.endpoint, options) # type: ignore if settings.DAPR_API_TOKEN: - api_token_interceptor = DaprClientInterceptorAsync([ - ('dapr-api-token', settings.DAPR_API_TOKEN), ]) + api_token_interceptor = DaprClientInterceptorAsync( + [ + ("dapr-api-token", settings.DAPR_API_TOKEN), + ] + ) self._channel = grpc.aio.insecure_channel( # type: ignore - address, options=options, interceptors=(api_token_interceptor,)) + address, options=options, interceptors=(api_token_interceptor,) + ) if interceptors: self._channel = grpc.aio.insecure_channel( # type: ignore - address, options=options, *interceptors) + address, options=options, *interceptors + ) self._stub = api_service_v1.DaprStub(self._channel) @@ -168,7 +178,7 @@ def get_credentials(self): async def close(self): """Closes Dapr runtime gRPC channel.""" - if hasattr(self, '_channel') and self._channel: + if hasattr(self, "_channel") and self._channel: await self._channel.close() async def __aenter__(self) -> Self: # type: ignore @@ -178,8 +188,7 @@ async def __aexit__(self, exc_type, exc_value, traceback) -> None: await self.close() def _get_http_extension( - self, http_verb: str, - http_querystring: Optional[MetadataTuple] = None + self, http_verb: str, http_querystring: Optional[MetadataTuple] = None ) -> common_v1.HTTPExtension: # type: ignore verb = common_v1.HTTPExtension.Verb.Value(http_verb) # type: ignore http_ext = common_v1.HTTPExtension(verb=verb) @@ -188,15 +197,16 @@ def _get_http_extension( return http_ext async def invoke_method( - self, - app_id: str, - method_name: str, - data: Union[bytes, str, GrpcMessage] = '', - content_type: Optional[str] = None, - metadata: Optional[MetadataTuple] = None, - http_verb: Optional[str] = None, - http_querystring: Optional[MetadataTuple] = None, - timeout: Optional[int] = None) -> InvokeMethodResponse: + self, + app_id: str, + method_name: str, + data: Union[bytes, str, GrpcMessage] = "", + content_type: Optional[str] = None, + metadata: Optional[MetadataTuple] = None, + http_verb: Optional[str] = None, + http_querystring: Optional[MetadataTuple] = None, + timeout: Optional[int] = None, + ) -> 
InvokeMethodResponse: """Invokes the target service to call method. This can invoke the specified target service to call method with bytes array data or @@ -270,11 +280,18 @@ async def invoke_method( Returns: :class:`InvokeMethodResponse` object returned from callee """ - warn('invoke_method with protocol gRPC is deprecated. Use gRPC proxying instead.', - DeprecationWarning, stacklevel=2) + warn( + "invoke_method with protocol gRPC is deprecated. Use gRPC proxying instead.", + DeprecationWarning, + stacklevel=2, + ) if metadata is not None: - warn('metadata argument is deprecated. Dapr already intercepts API token headers ' - 'and this is not needed.', DeprecationWarning, stacklevel=2) + warn( + "metadata argument is deprecated. Dapr already intercepts API token headers " + "and this is not needed.", + DeprecationWarning, + stacklevel=2, + ) req_data = InvokeMethodRequest(data, content_type) http_ext = None @@ -290,7 +307,8 @@ async def invoke_method( method=method_name, data=req_data.proto, content_type=content_type, - http_extension=http_ext) + http_extension=http_ext, + ), ) call = self._stub.InvokeService(req, metadata=metadata, timeout=timeout) @@ -301,12 +319,13 @@ async def invoke_method( return resp_data async def invoke_binding( - self, - binding_name: str, - operation: str, - data: Union[bytes, str] = '', - binding_metadata: Dict[str, str] = {}, - metadata: Optional[MetadataTuple] = None) -> BindingResponse: + self, + binding_name: str, + operation: str, + data: Union[bytes, str] = "", + binding_metadata: Dict[str, str] = {}, + metadata: Optional[MetadataTuple] = None, + ) -> BindingResponse: """Invokes the output binding with the specified operation. The data field takes any JSON serializable value and acts as the @@ -338,8 +357,12 @@ async def invoke_binding( :class:`InvokeBindingResponse` object returned from binding """ if metadata is not None: - warn('metadata argument is deprecated. Dapr already intercepts API token ' - 'headers and this is not needed.', DeprecationWarning, stacklevel=2) + warn( + "metadata argument is deprecated. Dapr already intercepts API token " + "headers and this is not needed.", + DeprecationWarning, + stacklevel=2, + ) req_data = BindingRequest(data, binding_metadata) @@ -347,23 +370,24 @@ async def invoke_binding( name=binding_name, data=req_data.data, metadata=req_data.binding_metadata, - operation=operation + operation=operation, ) call = self._stub.InvokeBinding(req, metadata=metadata) response = await call return BindingResponse( - response.data, dict(response.metadata), - await call.initial_metadata()) + response.data, dict(response.metadata), await call.initial_metadata() + ) async def publish_event( - self, - pubsub_name: str, - topic_name: str, - data: Union[bytes, str], - publish_metadata: Dict[str, str] = {}, - metadata: Optional[MetadataTuple] = None, - data_content_type: Optional[str] = None) -> DaprResponse: + self, + pubsub_name: str, + topic_name: str, + data: Union[bytes, str], + publish_metadata: Dict[str, str] = {}, + metadata: Optional[MetadataTuple] = None, + data_content_type: Optional[str] = None, + ) -> DaprResponse: """Publish to a given topic. This publishes an event with bytes array or str data to a specified topic and specified pubsub component. The str data is encoded into bytes with default @@ -394,18 +418,22 @@ async def publish_event( :class:`DaprResponse` gRPC metadata returned from callee """ if metadata is not None: - warn('metadata argument is deprecated. 
Dapr already intercepts API token headers ' - 'and this is not needed.', DeprecationWarning, stacklevel=2) + warn( + "metadata argument is deprecated. Dapr already intercepts API token headers " + "and this is not needed.", + DeprecationWarning, + stacklevel=2, + ) if not isinstance(data, bytes) and not isinstance(data, str): - raise ValueError(f'invalid type for data {type(data)}') + raise ValueError(f"invalid type for data {type(data)}") req_data: bytes if isinstance(data, bytes): req_data = data else: if isinstance(data, str): - req_data = data.encode('utf-8') + req_data = data.encode("utf-8") content_type = "" if data_content_type: @@ -415,7 +443,8 @@ async def publish_event( topic=topic_name, data=req_data, data_content_type=content_type, - metadata=publish_metadata) + metadata=publish_metadata, + ) call = self._stub.PublishEvent(req, metadata=metadata) # response is google.protobuf.Empty @@ -424,11 +453,12 @@ async def publish_event( return DaprResponse(await call.initial_metadata()) async def get_state( - self, - store_name: str, - key: str, - state_metadata: Optional[Dict[str, str]] = dict(), - metadata: Optional[MetadataTuple] = None) -> StateResponse: + self, + store_name: str, + key: str, + state_metadata: Optional[Dict[str, str]] = dict(), + metadata: Optional[MetadataTuple] = None, + ) -> StateResponse: """Gets value from a statestore with a key The example gets value from a statestore: @@ -452,8 +482,12 @@ async def get_state( and value obtained from the state store """ if metadata is not None: - warn('metadata argument is deprecated. Dapr already intercepts API token headers ' - 'and this is not needed.', DeprecationWarning, stacklevel=2) + warn( + "metadata argument is deprecated. Dapr already intercepts API token headers " + "and this is not needed.", + DeprecationWarning, + stacklevel=2, + ) if not store_name or len(store_name) == 0 or len(store_name.strip()) == 0: raise ValueError("State store name cannot be empty") @@ -461,17 +495,17 @@ async def get_state( call = self._stub.GetState(req, metadata=metadata) response = await call return StateResponse( - data=response.data, - etag=response.etag, - headers=await call.initial_metadata()) + data=response.data, etag=response.etag, headers=await call.initial_metadata() + ) async def get_bulk_state( - self, - store_name: str, - keys: Sequence[str], - parallelism: int = 1, - states_metadata: Optional[Dict[str, str]] = dict(), - metadata: Optional[MetadataTuple] = None) -> BulkStatesResponse: + self, + store_name: str, + keys: Sequence[str], + parallelism: int = 1, + states_metadata: Optional[Dict[str, str]] = dict(), + metadata: Optional[MetadataTuple] = None, + ) -> BulkStatesResponse: """Gets values from a statestore with keys The example gets value from a statestore: @@ -496,36 +530,31 @@ async def get_bulk_state( and value obtained from the state store """ if metadata is not None: - warn('metadata argument is deprecated. Dapr already intercepts API token headers ' - 'and this is not needed.', DeprecationWarning, stacklevel=2) + warn( + "metadata argument is deprecated. 
Dapr already intercepts API token headers " + "and this is not needed.", + DeprecationWarning, + stacklevel=2, + ) if not store_name or len(store_name) == 0 or len(store_name.strip()) == 0: raise ValueError("State store name cannot be empty") req = api_v1.GetBulkStateRequest( - store_name=store_name, - keys=keys, - parallelism=parallelism, - metadata=states_metadata) + store_name=store_name, keys=keys, parallelism=parallelism, metadata=states_metadata + ) call = self._stub.GetBulkState(req, metadata=metadata) response = await call items = [] for item in response.items: items.append( - BulkStateItem( - key=item.key, - data=item.data, - etag=item.etag, - error=item.error)) - return BulkStatesResponse( - items=items, - headers=await call.initial_metadata()) + BulkStateItem(key=item.key, data=item.data, etag=item.etag, error=item.error) + ) + return BulkStatesResponse(items=items, headers=await call.initial_metadata()) async def query_state( - self, - store_name: str, - query: str, - states_metadata: Optional[Dict[str, str]] = dict()) -> QueryResponse: + self, store_name: str, query: str, states_metadata: Optional[Dict[str, str]] = dict() + ) -> QueryResponse: """Queries a statestore with a query For details on supported queries see https://docs.dapr.io/ @@ -563,43 +592,41 @@ async def query_state( :class:`QueryStateResponse` gRPC metadata returned from callee, pagination token and results of the query """ - warn('The State Store Query API is an Alpha version and is subject to change.', - UserWarning, stacklevel=2) + warn( + "The State Store Query API is an Alpha version and is subject to change.", + UserWarning, + stacklevel=2, + ) if not store_name or len(store_name) == 0 or len(store_name.strip()) == 0: raise ValueError("State store name cannot be empty") - req = api_v1.QueryStateRequest( - store_name=store_name, - query=query, - metadata=states_metadata) + req = api_v1.QueryStateRequest(store_name=store_name, query=query, metadata=states_metadata) call = self._stub.QueryStateAlpha1(req) response = await call results = [] for item in response.results: results.append( - QueryResponseItem( - key=item.key, - value=item.data, - etag=item.etag, - error=item.error) + QueryResponseItem(key=item.key, value=item.data, etag=item.etag, error=item.error) ) return QueryResponse( token=response.token, results=results, metadata=response.metadata, - headers=await call.initial_metadata()) + headers=await call.initial_metadata(), + ) async def save_state( - self, - store_name: str, - key: str, - value: Union[bytes, str], - etag: Optional[str] = None, - options: Optional[StateOptions] = None, - state_metadata: Optional[Dict[str, str]] = dict(), - metadata: Optional[MetadataTuple] = None) -> DaprResponse: + self, + store_name: str, + key: str, + value: Union[bytes, str], + etag: Optional[str] = None, + options: Optional[StateOptions] = None, + state_metadata: Optional[Dict[str, str]] = dict(), + metadata: Optional[MetadataTuple] = None, + ) -> DaprResponse: """Saves key-value pairs to a statestore This saves a value to the statestore with a given key and state store name. @@ -635,11 +662,15 @@ async def save_state( ValueError: store_name is empty """ if metadata is not None: - warn('metadata argument is deprecated. Dapr already intercepts API token headers ' - 'and this is not needed.', DeprecationWarning, stacklevel=2) + warn( + "metadata argument is deprecated. 
Dapr already intercepts API token headers " + "and this is not needed.", + DeprecationWarning, + stacklevel=2, + ) if not isinstance(value, (bytes, str)): - raise ValueError(f'invalid type for data {type(value)}') + raise ValueError(f"invalid type for data {type(value)}") req_value = value @@ -656,19 +687,17 @@ async def save_state( value=to_bytes(req_value), etag=common_v1.Etag(value=etag) if etag is not None else None, options=state_options, - metadata=state_metadata) + metadata=state_metadata, + ) req = api_v1.SaveStateRequest(store_name=store_name, states=[state]) call = self._stub.SaveState(req, metadata=metadata) await call - return DaprResponse( - headers=await call.initial_metadata()) + return DaprResponse(headers=await call.initial_metadata()) async def save_bulk_state( - self, - store_name: str, - states: List[StateItem], - metadata: Optional[MetadataTuple] = None) -> DaprResponse: + self, store_name: str, states: List[StateItem], metadata: Optional[MetadataTuple] = None + ) -> DaprResponse: """Saves state items to a statestore This saves a given state item into the statestore specified by store_name. @@ -695,8 +724,12 @@ async def save_bulk_state( ValueError: store_name is empty """ if metadata is not None: - warn('metadata argument is deprecated. Dapr already intercepts API token headers ' - 'and this is not needed.', DeprecationWarning, stacklevel=2) + warn( + "metadata argument is deprecated. Dapr already intercepts API token headers " + "and this is not needed.", + DeprecationWarning, + stacklevel=2, + ) if not states or len(states) == 0: raise ValueError("States to be saved cannot be empty") @@ -704,25 +737,29 @@ async def save_bulk_state( if not store_name or len(store_name) == 0 or len(store_name.strip()) == 0: raise ValueError("State store name cannot be empty") - req_states = [common_v1.StateItem( - key=i.key, - value=to_bytes(i.value), - etag=common_v1.Etag(value=i.etag) if i.etag is not None else None, - options=i.options, - metadata=i.metadata) for i in states] + req_states = [ + common_v1.StateItem( + key=i.key, + value=to_bytes(i.value), + etag=common_v1.Etag(value=i.etag) if i.etag is not None else None, + options=i.options, + metadata=i.metadata, + ) + for i in states + ] req = api_v1.SaveStateRequest(store_name=store_name, states=req_states) call = self._stub.SaveState(req, metadata=metadata) await call - return DaprResponse( - headers=await call.initial_metadata()) + return DaprResponse(headers=await call.initial_metadata()) async def execute_state_transaction( - self, - store_name: str, - operations: Sequence[TransactionalStateOperation], - transactional_metadata: Optional[Dict[str, str]] = dict(), - metadata: Optional[MetadataTuple] = None) -> DaprResponse: + self, + store_name: str, + operations: Sequence[TransactionalStateOperation], + transactional_metadata: Optional[Dict[str, str]] = dict(), + metadata: Optional[MetadataTuple] = None, + ) -> DaprResponse: """Saves or deletes key-value pairs to a statestore as a transaction This saves or deletes key-values to the statestore as part of a single transaction, @@ -754,36 +791,43 @@ async def execute_state_transaction( :class:`DaprResponse` gRPC metadata returned from callee """ if metadata is not None: - warn('metadata argument is deprecated. Dapr already intercepts API token headers ' - 'and this is not needed.', DeprecationWarning, stacklevel=2) + warn( + "metadata argument is deprecated. 
Dapr already intercepts API token headers " + "and this is not needed.", + DeprecationWarning, + stacklevel=2, + ) if not store_name or len(store_name) == 0 or len(store_name.strip()) == 0: raise ValueError("State store name cannot be empty") - req_ops = [api_v1.TransactionalStateOperation( - operationType=o.operation_type.value, - request=common_v1.StateItem( - key=o.key, - value=to_bytes(o.data), - etag=common_v1.Etag(value=o.etag) if o.etag is not None else None)) - for o in operations] + req_ops = [ + api_v1.TransactionalStateOperation( + operationType=o.operation_type.value, + request=common_v1.StateItem( + key=o.key, + value=to_bytes(o.data), + etag=common_v1.Etag(value=o.etag) if o.etag is not None else None, + ), + ) + for o in operations + ] req = api_v1.ExecuteStateTransactionRequest( - storeName=store_name, - operations=req_ops, - metadata=transactional_metadata) + storeName=store_name, operations=req_ops, metadata=transactional_metadata + ) call = self._stub.ExecuteStateTransaction(req, metadata=metadata) await call - return DaprResponse( - headers=await call.initial_metadata()) + return DaprResponse(headers=await call.initial_metadata()) async def delete_state( - self, - store_name: str, - key: str, - etag: Optional[str] = None, - options: Optional[StateOptions] = None, - state_metadata: Optional[Dict[str, str]] = dict(), - metadata: Optional[MetadataTuple] = None) -> DaprResponse: + self, + store_name: str, + key: str, + etag: Optional[str] = None, + options: Optional[StateOptions] = None, + state_metadata: Optional[Dict[str, str]] = dict(), + metadata: Optional[MetadataTuple] = None, + ) -> DaprResponse: """Deletes key-value pairs from a statestore This deletes a value from the statestore with a given key and state store name. @@ -813,8 +857,12 @@ async def delete_state( :class:`DaprResponse` gRPC metadata returned from callee """ if metadata is not None: - warn('metadata argument is deprecated. Dapr already intercepts API token headers ' - 'and this is not needed.', DeprecationWarning, stacklevel=2) + warn( + "metadata argument is deprecated. Dapr already intercepts API token headers " + "and this is not needed.", + DeprecationWarning, + stacklevel=2, + ) if not store_name or len(store_name) == 0 or len(store_name.strip()) == 0: raise ValueError("State store name cannot be empty") @@ -825,20 +873,24 @@ async def delete_state( state_options = options.get_proto() etag_object = common_v1.Etag(value=etag) if etag is not None else None - req = api_v1.DeleteStateRequest(store_name=store_name, key=key, - etag=etag_object, options=state_options, - metadata=state_metadata) + req = api_v1.DeleteStateRequest( + store_name=store_name, + key=key, + etag=etag_object, + options=state_options, + metadata=state_metadata, + ) call = self._stub.DeleteState(req, metadata=metadata) await call - return DaprResponse( - headers=await call.initial_metadata()) + return DaprResponse(headers=await call.initial_metadata()) async def get_secret( - self, - store_name: str, - key: str, - secret_metadata: Optional[Dict[str, str]] = {}, - metadata: Optional[MetadataTuple] = None) -> GetSecretResponse: + self, + store_name: str, + key: str, + secret_metadata: Optional[Dict[str, str]] = {}, + metadata: Optional[MetadataTuple] = None, + ) -> GetSecretResponse: """Get secret with a given key. This gets a secret from secret store with a given key and secret store name. 
@@ -870,26 +922,26 @@ async def get_secret( :class:`GetSecretResponse` object with the secret and metadata returned from callee """ if metadata is not None: - warn('metadata argument is deprecated. Dapr already intercepts API token headers ' - 'and this is not needed.', DeprecationWarning, stacklevel=2) + warn( + "metadata argument is deprecated. Dapr already intercepts API token headers " + "and this is not needed.", + DeprecationWarning, + stacklevel=2, + ) - req = api_v1.GetSecretRequest( - store_name=store_name, - key=key, - metadata=secret_metadata) + req = api_v1.GetSecretRequest(store_name=store_name, key=key, metadata=secret_metadata) call = self._stub.GetSecret(req, metadata=metadata) response = await call - return GetSecretResponse( - secret=response.data, - headers=await call.initial_metadata()) + return GetSecretResponse(secret=response.data, headers=await call.initial_metadata()) async def get_bulk_secret( - self, - store_name: str, - secret_metadata: Optional[Dict[str, str]] = {}, - metadata: Optional[MetadataTuple] = None) -> GetBulkSecretResponse: + self, + store_name: str, + secret_metadata: Optional[Dict[str, str]] = {}, + metadata: Optional[MetadataTuple] = None, + ) -> GetBulkSecretResponse: """Get all granted secrets. This gets all granted secrets from secret store. @@ -918,12 +970,14 @@ async def get_bulk_secret( :class:`GetBulkSecretResponse` object with secrets and metadata returned from callee """ if metadata is not None: - warn('metadata argument is deprecated. Dapr already intercepts API token headers ' - 'and this is not needed.', DeprecationWarning, stacklevel=2) + warn( + "metadata argument is deprecated. Dapr already intercepts API token headers " + "and this is not needed.", + DeprecationWarning, + stacklevel=2, + ) - req = api_v1.GetBulkSecretRequest( - store_name=store_name, - metadata=secret_metadata) + req = api_v1.GetBulkSecretRequest(store_name=store_name, metadata=secret_metadata) call = self._stub.GetBulkSecret(req, metadata=metadata) response = await call @@ -936,15 +990,11 @@ async def get_bulk_secret( secrets_submap[subkey] = secret_response.secrets[subkey] secrets_map[key] = secrets_submap - return GetBulkSecretResponse( - secrets=secrets_map, - headers=await call.initial_metadata()) + return GetBulkSecretResponse(secrets=secrets_map, headers=await call.initial_metadata()) async def get_configuration( - self, - store_name: str, - keys: List[str], - config_metadata: Optional[Dict[str, str]] = dict()) -> ConfigurationResponse: + self, store_name: str, keys: List[str], config_metadata: Optional[Dict[str, str]] = dict() + ) -> ConfigurationResponse: """Gets values from a config store with keys The example gets value from a config store: @@ -969,19 +1019,19 @@ async def get_configuration( raise ValueError("Config store name cannot be empty to get the configuration") req = api_v1.GetConfigurationRequest( - store_name=store_name, keys=keys, metadata=config_metadata) + store_name=store_name, keys=keys, metadata=config_metadata + ) call = self._stub.GetConfiguration(req) response = await call - return ConfigurationResponse( - items=response.items, - headers=await call.initial_metadata()) + return ConfigurationResponse(items=response.items, headers=await call.initial_metadata()) async def subscribe_configuration( - self, - store_name: str, - keys: List[str], - handler: Callable[[Text, ConfigurationResponse], None], - config_metadata: Optional[Dict[str, str]] = dict()) -> Text: + self, + store_name: str, + keys: List[str], + handler: Callable[[Text, 
ConfigurationResponse], None], + config_metadata: Optional[Dict[str, str]] = dict(), + ) -> Text: """Gets changed value from a config store with a key The example gets value from a config store: @@ -1007,71 +1057,69 @@ async def subscribe_configuration( raise ValueError("Config store name cannot be empty to get the configuration") configWatcher = ConfigurationWatcher() - id = configWatcher.watch_configuration(self._stub, store_name, keys, - handler, config_metadata) + id = configWatcher.watch_configuration( + self._stub, store_name, keys, handler, config_metadata + ) return id - async def unsubscribe_configuration( - self, - store_name: str, - id: str) -> bool: + async def unsubscribe_configuration(self, store_name: str, id: str) -> bool: """Unsubscribes from configuration changes. - Args: - store_name (str): the state store name to unsubscribe from - id (str): the subscription id to unsubscribe + Args: + store_name (str): the state store name to unsubscribe from + id (str): the subscription id to unsubscribe - Returns: - bool: True if unsubscribed successfully, False otherwise + Returns: + bool: True if unsubscribed successfully, False otherwise """ req = api_v1.UnsubscribeConfigurationRequest(store_name=store_name, id=id) response: UnsubscribeConfigurationResponse = await self._stub.UnsubscribeConfiguration(req) return response.ok async def try_lock( - self, - store_name: str, - resource_id: str, - lock_owner: str, - expiry_in_seconds: int) -> TryLockResponse: + self, store_name: str, resource_id: str, lock_owner: str, expiry_in_seconds: int + ) -> TryLockResponse: """Tries to get a lock with an expiry. - You can use the result of this operation directly on an `if` statement: + You can use the result of this operation directly on an `if` statement: - if client.try_lock(store_name, resource_id, first_client_id, expiry_s): - # lock acquired successfully... + if client.try_lock(store_name, resource_id, first_client_id, expiry_s): + # lock acquired successfully... - You can also inspect the response's `success` attribute: + You can also inspect the response's `success` attribute: - response = client.try_lock(store_name, resource_id, first_client_id, expiry_s) - if response.success: - # lock acquired successfully... + response = client.try_lock(store_name, resource_id, first_client_id, expiry_s) + if response.success: + # lock acquired successfully... - Finally, you can use this response with a `with` statement, and have the lock - be automatically unlocked after the with-statement scope ends + Finally, you can use this response with a `with` statement, and have the lock + be automatically unlocked after the with-statement scope ends - with client.try_lock(store_name, resource_id, first_client_id, expiry_s) as lock: - if lock: - # lock acquired successfully... - # Lock automatically unlocked at this point, no need to call client->unlock(...) + with client.try_lock(store_name, resource_id, first_client_id, expiry_s) as lock: + if lock: + # lock acquired successfully... + # Lock automatically unlocked at this point, no need to call client->unlock(...) - Args: - store_name (str): the lock store name, e.g. `redis`. - resource_id (str): the lock key. e.g. `order_id_111`. - It stands for "which resource I want to protect". - lock_owner (str): indicates the identifier of lock owner. - expiry_in_seconds (int): The length of time (in seconds) for which this lock - will be held and after which it expires. - - Returns: - :class:`TryLockResponse`: With the result of the try-lock operation. 
+ Args: + store_name (str): the lock store name, e.g. `redis`. + resource_id (str): the lock key. e.g. `order_id_111`. + It stands for "which resource I want to protect". + lock_owner (str): indicates the identifier of lock owner. + expiry_in_seconds (int): The length of time (in seconds) for which this lock + will be held and after which it expires. + + Returns: + :class:`TryLockResponse`: With the result of the try-lock operation. """ # Warnings and input validation - warn('The Distributed Lock API is an Alpha version and is subject to change.', - UserWarning, stacklevel=2) - validateNotBlankString(store_name=store_name, - resource_id=resource_id, - lock_owner=lock_owner) + warn( + "The Distributed Lock API is an Alpha version and is subject to change.", + UserWarning, + stacklevel=2, + ) + validateNotBlankString( + store_name=store_name, resource_id=resource_id, lock_owner=lock_owner + ) if not expiry_in_seconds or expiry_in_seconds < 1: raise ValueError("expiry_in_seconds must be a positive number") # Actual tryLock invocation @@ -1079,7 +1127,8 @@ async def try_lock( store_name=store_name, resource_id=resource_id, lock_owner=lock_owner, - expiry_in_seconds=expiry_in_seconds) + expiry_in_seconds=expiry_in_seconds, + ) call = self._stub.TryLockAlpha1(req) response = await call return TryLockResponse( @@ -1088,78 +1137,84 @@ async def try_lock( store_name=store_name, resource_id=resource_id, lock_owner=lock_owner, - headers=await call.initial_metadata()) + headers=await call.initial_metadata(), + ) - async def unlock( - self, - store_name: str, - resource_id: str, - lock_owner: str) -> UnlockResponse: + async def unlock(self, store_name: str, resource_id: str, lock_owner: str) -> UnlockResponse: """Unlocks a lock. - Args: - store_name (str): the lock store name, e.g. `redis`. - resource_id (str): the lock key. e.g. `order_id_111`. - It stands for "which resource I want to protect". - lock_owner (str): indicates the identifier of lock owner. - metadata (tuple, optional, DEPRECATED): gRPC custom metadata - - Returns: - :class:`UnlockResponseStatus`: Status of the request, - `UnlockResponseStatus.success` if it was successful of some other - status otherwise. + Args: + store_name (str): the lock store name, e.g. `redis`. + resource_id (str): the lock key. e.g. `order_id_111`. + It stands for "which resource I want to protect". + lock_owner (str): indicates the identifier of lock owner. + metadata (tuple, optional, DEPRECATED): gRPC custom metadata + + Returns: + :class:`UnlockResponseStatus`: Status of the request, + `UnlockResponseStatus.success` if it was successful or some other + status otherwise.
""" # Warnings and input validation - warn('The Distributed Lock API is an Alpha version and is subject to change.', - UserWarning, stacklevel=2) - validateNotBlankString(store_name=store_name, - resource_id=resource_id, - lock_owner=lock_owner) + warn( + "The Distributed Lock API is an Alpha version and is subject to change.", + UserWarning, + stacklevel=2, + ) + validateNotBlankString( + store_name=store_name, resource_id=resource_id, lock_owner=lock_owner + ) # Actual unlocking invocation req = api_v1.UnlockRequest( - store_name=store_name, - resource_id=resource_id, - lock_owner=lock_owner) + store_name=store_name, resource_id=resource_id, lock_owner=lock_owner + ) call = self._stub.UnlockAlpha1(req) response = await call - return UnlockResponse(status=UnlockResponseStatus(response.status), - headers=await call.initial_metadata()) + return UnlockResponse( + status=UnlockResponseStatus(response.status), headers=await call.initial_metadata() + ) async def start_workflow( - self, - workflow_component: str, - workflow_name: str, - input: Optional[Union[Any, bytes]] = None, - instance_id: Optional[str] = None, - workflow_options: Optional[Dict[str, str]] = dict(), - send_raw_bytes: bool = False) -> StartWorkflowResponse: + self, + workflow_component: str, + workflow_name: str, + input: Optional[Union[Any, bytes]] = None, + instance_id: Optional[str] = None, + workflow_options: Optional[Dict[str, str]] = dict(), + send_raw_bytes: bool = False, + ) -> StartWorkflowResponse: """Starts a workflow. - Args: - workflow_component (str): the name of the workflow component - that will run the workflow. e.g. `dapr`. - workflow_name (str): the name of the workflow that will be executed. - input (Optional[Union[Any, bytes]]): the input that the workflow will receive. - The input value will be serialized to JSON - by default. Use the send_raw_bytes param - to send unencoded binary input. - instance_id (Optional[str]): the name of the workflow instance, - e.g. `order_processing_workflow-103784`. - workflow_options (Optional[Dict[str, str]]): the key-value options - that the workflow will receive. - send_raw_bytes (bool) if true, no serialization will be performed on the input - bytes + Args: + workflow_component (str): the name of the workflow component + that will run the workflow. e.g. `dapr`. + workflow_name (str): the name of the workflow that will be executed. + input (Optional[Union[Any, bytes]]): the input that the workflow will receive. + The input value will be serialized to JSON + by default. Use the send_raw_bytes param + to send unencoded binary input. + instance_id (Optional[str]): the name of the workflow instance, + e.g. `order_processing_workflow-103784`. + workflow_options (Optional[Dict[str, str]]): the key-value options + that the workflow will receive. 
+ send_raw_bytes (bool) if true, no serialization will be performed on the input + bytes - Returns: - :class:`StartWorkflowResponse`: Instance ID associated with the started workflow + Returns: + :class:`StartWorkflowResponse`: Instance ID associated with the started workflow """ # Warnings and input validation - warn('The Workflow API is a Beta version and is subject to change.', - UserWarning, stacklevel=2) - validateNotBlankString(instance_id=instance_id, - workflow_component=workflow_component, - workflow_name=workflow_name) + warn( + "The Workflow API is a Beta version and is subject to change.", + UserWarning, + stacklevel=2, + ) + validateNotBlankString( + instance_id=instance_id, + workflow_component=workflow_component, + workflow_name=workflow_name, + ) if instance_id is None: instance_id = str(uuid.uuid4()) @@ -1168,8 +1223,7 @@ async def start_workflow( encoded_data = input else: try: - encoded_data = json.dumps(input).encode( - "utf-8") if input is not None else bytes([]) + encoded_data = json.dumps(input).encode("utf-8") if input is not None else bytes([]) except TypeError: raise DaprInternalError("start_workflow: input data must be JSON serializable") except ValueError as e: @@ -1181,7 +1235,8 @@ async def start_workflow( workflow_component=workflow_component, workflow_name=workflow_name, options=workflow_options, - input=encoded_data) + input=encoded_data, + ) try: response = self._stub.StartWorkflowBeta1(req) @@ -1189,30 +1244,29 @@ async def start_workflow( except grpc.aio.AioRpcError as err: raise DaprInternalError(err.details()) - async def get_workflow( - self, - instance_id: str, - workflow_component: str) -> GetWorkflowResponse: + async def get_workflow(self, instance_id: str, workflow_component: str) -> GetWorkflowResponse: """Gets information on a workflow. - Args: - instance_id (str): the ID of the workflow instance, - e.g. `order_processing_workflow-103784`. - workflow_component (str): the name of the workflow component - that will run the workflow. e.g. `dapr`. + Args: + instance_id (str): the ID of the workflow instance, + e.g. `order_processing_workflow-103784`. + workflow_component (str): the name of the workflow component + that will run the workflow. e.g. `dapr`. 
- Returns: - :class:`GetWorkflowResponse`: Instance ID associated with the started workflow + Returns: + :class:`GetWorkflowResponse`: Instance ID associated with the started workflow """ # Warnings and input validation - warn('The Workflow API is a Beta version and is subject to change.', - UserWarning, stacklevel=2) - validateNotBlankString(instance_id=instance_id, - workflow_component=workflow_component) + warn( + "The Workflow API is a Beta version and is subject to change.", + UserWarning, + stacklevel=2, + ) + validateNotBlankString(instance_id=instance_id, workflow_component=workflow_component) # Actual get workflow invocation req = api_v1.GetWorkflowRequest( - instance_id=instance_id, - workflow_component=workflow_component) + instance_id=instance_id, workflow_component=workflow_component + ) try: resp = self._stub.GetWorkflowBeta1(req) @@ -1220,90 +1274,99 @@ async def get_workflow( resp.created_at = datetime.now if resp.last_updated_at is None: resp.last_updated_at = datetime.now - return GetWorkflowResponse(instance_id=instance_id, - workflow_name=resp.workflow_name, - created_at=resp.created_at, - last_updated_at=resp.last_updated_at, - runtime_status=getWorkflowRuntimeStatus(resp.runtime_status), - properties=resp.properties) + return GetWorkflowResponse( + instance_id=instance_id, + workflow_name=resp.workflow_name, + created_at=resp.created_at, + last_updated_at=resp.last_updated_at, + runtime_status=getWorkflowRuntimeStatus(resp.runtime_status), + properties=resp.properties, + ) except grpc.aio.AioRpcError as err: raise DaprInternalError(err.details()) - async def terminate_workflow( - self, - instance_id: str, - workflow_component: str) -> DaprResponse: + async def terminate_workflow(self, instance_id: str, workflow_component: str) -> DaprResponse: """Terminates a workflow. - Args: - instance_id (str): the ID of the workflow instance, e.g. - `order_processing_workflow-103784`. - workflow_component (str): the name of the workflow component - that will run the workflow. e.g. `dapr`. + Args: + instance_id (str): the ID of the workflow instance, e.g. + `order_processing_workflow-103784`. + workflow_component (str): the name of the workflow component + that will run the workflow. e.g. `dapr`. 
- Returns: - :class:`DaprResponse` gRPC metadata returned from callee + Returns: + :class:`DaprResponse` gRPC metadata returned from callee """ # Warnings and input validation - warn('The Workflow API is a Beta version and is subject to change.', - UserWarning, stacklevel=2) - validateNotBlankString(instance_id=instance_id, - workflow_component=workflow_component) + warn( + "The Workflow API is a Beta version and is subject to change.", + UserWarning, + stacklevel=2, + ) + validateNotBlankString(instance_id=instance_id, workflow_component=workflow_component) # Actual terminate workflow invocation req = api_v1.TerminateWorkflowRequest( - instance_id=instance_id, - workflow_component=workflow_component) + instance_id=instance_id, workflow_component=workflow_component + ) try: _, call = self._stub.TerminateWorkflowBeta1.with_call(req) - return DaprResponse( - headers=call.initial_metadata()) + return DaprResponse(headers=call.initial_metadata()) except grpc.aio.AioRpcError as err: raise DaprInternalError(err.details()) async def raise_workflow_event( - self, - instance_id: str, - workflow_component: str, - event_name: str, - event_data: Optional[Union[Any, bytes]] = None, - send_raw_bytes: bool = False) -> DaprResponse: + self, + instance_id: str, + workflow_component: str, + event_name: str, + event_data: Optional[Union[Any, bytes]] = None, + send_raw_bytes: bool = False, + ) -> DaprResponse: """Raises an event on a workflow. - Args: - instance_id (str): the ID of the workflow instance, - e.g. `order_processing_workflow-103784`. - workflow_component (str): the name of the workflow component - that will run the workflow. e.g. `dapr`. - event_name (str): the name of the event to be raised on - the workflow. - event_data (Optional[Union[Any, bytes]]): the input that the workflow will receive. - The input value will be serialized to JSON - by default. Use the send_raw_bytes param - to send unencoded binary input. - send_raw_bytes (bool) if true, no serialization will be performed on the input - bytes - - Returns: - :class:`DaprResponse` gRPC metadata returned from callee + Args: + instance_id (str): the ID of the workflow instance, + e.g. `order_processing_workflow-103784`. + workflow_component (str): the name of the workflow component + that will run the workflow. e.g. `dapr`. + event_name (str): the name of the event to be raised on + the workflow. + event_data (Optional[Union[Any, bytes]]): the input that the workflow will receive. + The input value will be serialized to JSON + by default. Use the send_raw_bytes param + to send unencoded binary input. 
+ send_raw_bytes (bool) if true, no serialization will be performed on the input + bytes + + Returns: + :class:`DaprResponse` gRPC metadata returned from callee """ # Warnings and input validation - warn('The Workflow API is a Beta version and is subject to change.', - UserWarning, stacklevel=2) - validateNotBlankString(instance_id=instance_id, - workflow_component=workflow_component, - event_name=event_name) + warn( + "The Workflow API is a Beta version and is subject to change.", + UserWarning, + stacklevel=2, + ) + validateNotBlankString( + instance_id=instance_id, workflow_component=workflow_component, event_name=event_name + ) if isinstance(event_data, bytes) and send_raw_bytes: encoded_data = event_data else: if event_data is not None: try: - encoded_data = json.dumps(event_data).encode( - "utf-8") if event_data is not None else bytes([]) + encoded_data = ( + json.dumps(event_data).encode("utf-8") + if event_data is not None + else bytes([]) + ) except TypeError: - raise DaprInternalError("raise_workflow_event:\ - event_data must be JSON serializable") + raise DaprInternalError( + "raise_workflow_event:\ + event_data must be JSON serializable" + ) except ValueError as e: raise DaprInternalError(f"raise_workflow_event JSON serialization error: {e}") encoded_data = json.dumps(event_data).encode("utf-8") @@ -1314,19 +1377,16 @@ async def raise_workflow_event( instance_id=instance_id, workflow_component=workflow_component, event_name=event_name, - event_data=encoded_data) + event_data=encoded_data, + ) try: _, call = self._stub.RaiseEventWorkflowBeta1.with_call(req) - return DaprResponse( - headers=call.initial_metadata()) + return DaprResponse(headers=call.initial_metadata()) except grpc.aio.AioRpcError as err: raise DaprInternalError(err.details()) - async def pause_workflow( - self, - instance_id: str, - workflow_component: str) -> DaprResponse: + async def pause_workflow(self, instance_id: str, workflow_component: str) -> DaprResponse: """Pause a workflow. Args: @@ -1340,86 +1400,83 @@ async def pause_workflow( """ # Warnings and input validation - warn('The Workflow API is a Beta version and is subject to change.', - UserWarning, stacklevel=2) - validateNotBlankString(instance_id=instance_id, - workflow_component=workflow_component) + warn( + "The Workflow API is a Beta version and is subject to change.", + UserWarning, + stacklevel=2, + ) + validateNotBlankString(instance_id=instance_id, workflow_component=workflow_component) # Actual pause workflow invocation req = api_v1.PauseWorkflowRequest( - instance_id=instance_id, - workflow_component=workflow_component) + instance_id=instance_id, workflow_component=workflow_component + ) try: _, call = self._stub.PauseWorkflowBeta1.with_call(req) - return DaprResponse( - headers=call.initial_metadata()) + return DaprResponse(headers=call.initial_metadata()) except grpc.aio.AioRpcError as err: raise DaprInternalError(err.details()) - async def resume_workflow( - self, - instance_id: str, - workflow_component: str) -> DaprResponse: + async def resume_workflow(self, instance_id: str, workflow_component: str) -> DaprResponse: """Resumes a workflow. - Args: - instance_id (str): the ID of the workflow instance, - e.g. `order_processing_workflow-103784`. - workflow_component (str): the name of the workflow component - that will run the workflow. e.g. `dapr`. + Args: + instance_id (str): the ID of the workflow instance, + e.g. `order_processing_workflow-103784`. 
+ workflow_component (str): the name of the workflow component + that will run the workflow. e.g. `dapr`. - Returns: - :class:`DaprResponse` gRPC metadata returned from callee + Returns: + :class:`DaprResponse` gRPC metadata returned from callee """ # Warnings and input validation - warn('The Workflow API is a Beta version and is subject to change.', - UserWarning, stacklevel=2) - validateNotBlankString(instance_id=instance_id, - workflow_component=workflow_component) + warn( + "The Workflow API is a Beta version and is subject to change.", + UserWarning, + stacklevel=2, + ) + validateNotBlankString(instance_id=instance_id, workflow_component=workflow_component) # Actual resume workflow invocation req = api_v1.ResumeWorkflowRequest( - instance_id=instance_id, - workflow_component=workflow_component) + instance_id=instance_id, workflow_component=workflow_component + ) try: _, call = self._stub.ResumeWorkflowBeta1.with_call(req) - return DaprResponse( - headers=call.initial_metadata()) + return DaprResponse(headers=call.initial_metadata()) except grpc.aio.AioRpcError as err: raise DaprInternalError(err.details()) - async def purge_workflow( - self, - instance_id: str, - workflow_component: str) -> DaprResponse: + async def purge_workflow(self, instance_id: str, workflow_component: str) -> DaprResponse: """Purges a workflow. - Args: - instance_id (str): the ID of the workflow instance, - e.g. `order_processing_workflow-103784`. - workflow_component (str): the name of the workflow component - that will run the workflow. e.g. `dapr`. + Args: + instance_id (str): the ID of the workflow instance, + e.g. `order_processing_workflow-103784`. + workflow_component (str): the name of the workflow component + that will run the workflow. e.g. `dapr`. - Returns: - :class:`DaprResponse` gRPC metadata returned from callee + Returns: + :class:`DaprResponse` gRPC metadata returned from callee """ # Warnings and input validation - warn('The Workflow API is a Beta version and is subject to change.', - UserWarning, stacklevel=2) - validateNotBlankString(instance_id=instance_id, - workflow_component=workflow_component) + warn( + "The Workflow API is a Beta version and is subject to change.", + UserWarning, + stacklevel=2, + ) + validateNotBlankString(instance_id=instance_id, workflow_component=workflow_component) # Actual purge workflow invocation req = api_v1.PurgeWorkflowRequest( - instance_id=instance_id, - workflow_component=workflow_component) + instance_id=instance_id, workflow_component=workflow_component + ) try: _, call = self._stub.PurgeWorkflowBeta1.with_call(req) - return DaprResponse( - headers=call.initial_metadata()) + return DaprResponse(headers=call.initial_metadata()) except grpc.aio.AioRpcError as err: raise DaprInternalError(err.details()) @@ -1471,14 +1528,12 @@ async def get_metadata(self) -> GetMetadataResponse: response: api_v1.GetMetadataResponse = _resp # type alias # Convert to more pythonic formats active_actors_count = { - type_count.type: type_count.count - for type_count in response.active_actors_count + type_count.type: type_count.count for type_count in response.active_actors_count } registered_components = [ - RegisteredComponents(name=i.name, - type=i.type, - version=i.version, - capabilities=i.capabilities) + RegisteredComponents( + name=i.name, type=i.type, version=i.version, capabilities=i.capabilities + ) for i in response.registered_components ] extended_metadata = dict(response.extended_metadata.items()) @@ -1488,7 +1543,8 @@ async def get_metadata(self) -> 
GetMetadataResponse: active_actors_count=active_actors_count, registered_components=registered_components, extended_metadata=extended_metadata, - headers=await call.initial_metadata()) + headers=await call.initial_metadata(), + ) async def set_metadata(self, attributeName: str, attributeValue: str) -> DaprResponse: """Adds a custom (extended) metadata attribute to the Dapr sidecar diff --git a/dapr/clients/__init__.py b/dapr/clients/__init__.py index da5f373b6..65878e2c4 100644 --- a/dapr/clients/__init__.py +++ b/dapr/clients/__init__.py @@ -25,18 +25,18 @@ from google.protobuf.message import Message as GrpcMessage __all__ = [ - 'DaprClient', - 'DaprActorClientBase', - 'DaprActorHttpClient', - 'DaprInternalError', - 'ERROR_CODE_UNKNOWN', + "DaprClient", + "DaprActorClientBase", + "DaprActorHttpClient", + "DaprInternalError", + "ERROR_CODE_UNKNOWN", ] from grpc import ( # type: ignore UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor, StreamUnaryClientInterceptor, - StreamStreamClientInterceptor + StreamStreamClientInterceptor, ) @@ -47,16 +47,22 @@ class DaprClient(DaprGrpcClient): variable. See: https://github.com/dapr/python-sdk/issues/176 for more details""" def __init__( - self, - address: Optional[str] = None, - headers_callback: Optional[Callable[[], Dict[str, str]]] = None, - interceptors: Optional[List[Union[ - UnaryUnaryClientInterceptor, - UnaryStreamClientInterceptor, - StreamUnaryClientInterceptor, - StreamStreamClientInterceptor]]] = None, - http_timeout_seconds: Optional[int] = None, - max_grpc_message_length: Optional[int] = None): + self, + address: Optional[str] = None, + headers_callback: Optional[Callable[[], Dict[str, str]]] = None, + interceptors: Optional[ + List[ + Union[ + UnaryUnaryClientInterceptor, + UnaryStreamClientInterceptor, + StreamUnaryClientInterceptor, + StreamStreamClientInterceptor, + ] + ] + ] = None, + http_timeout_seconds: Optional[int] = None, + max_grpc_message_length: Optional[int] = None, + ): """Connects to Dapr Runtime via gRPC and HTTP. 
Args: @@ -75,28 +81,30 @@ def __init__( invocation_protocol = settings.DAPR_API_METHOD_INVOCATION_PROTOCOL.upper() - if invocation_protocol == 'HTTP': + if invocation_protocol == "HTTP": if http_timeout_seconds is None: http_timeout_seconds = settings.DAPR_HTTP_TIMEOUT_SECONDS - self.invocation_client = DaprInvocationHttpClient(headers_callback=headers_callback, - timeout=http_timeout_seconds, - address=address) - elif invocation_protocol == 'GRPC': + self.invocation_client = DaprInvocationHttpClient( + headers_callback=headers_callback, timeout=http_timeout_seconds, address=address + ) + elif invocation_protocol == "GRPC": pass else: raise DaprInternalError( - f'Unknown value for DAPR_API_METHOD_INVOCATION_PROTOCOL: {invocation_protocol}') + f"Unknown value for DAPR_API_METHOD_INVOCATION_PROTOCOL: {invocation_protocol}" + ) def invoke_method( - self, - app_id: str, - method_name: str, - data: Union[bytes, str, GrpcMessage] = '', - content_type: Optional[str] = None, - metadata: Optional[MetadataTuple] = None, - http_verb: Optional[str] = None, - http_querystring: Optional[MetadataTuple] = None, - timeout: Optional[int] = None) -> InvokeMethodResponse: + self, + app_id: str, + method_name: str, + data: Union[bytes, str, GrpcMessage] = "", + content_type: Optional[str] = None, + metadata: Optional[MetadataTuple] = None, + http_verb: Optional[str] = None, + http_querystring: Optional[MetadataTuple] = None, + timeout: Optional[int] = None, + ) -> InvokeMethodResponse: """Invoke a service method over gRPC or HTTP. Args: @@ -121,7 +129,8 @@ def invoke_method( metadata=metadata, http_verb=http_verb, http_querystring=http_querystring, - timeout=timeout) + timeout=timeout, + ) else: return super().invoke_method( app_id, @@ -131,18 +140,20 @@ def invoke_method( metadata=metadata, http_verb=http_verb, http_querystring=http_querystring, - timeout=timeout) + timeout=timeout, + ) async def invoke_method_async( - self, - app_id: str, - method_name: str, - data: Union[bytes, str, GrpcMessage], - content_type: Optional[str] = None, - metadata: Optional[MetadataTuple] = None, - http_verb: Optional[str] = None, - http_querystring: Optional[MetadataTuple] = None, - timeout: Optional[int] = None) -> InvokeMethodResponse: + self, + app_id: str, + method_name: str, + data: Union[bytes, str, GrpcMessage], + content_type: Optional[str] = None, + metadata: Optional[MetadataTuple] = None, + http_verb: Optional[str] = None, + http_querystring: Optional[MetadataTuple] = None, + timeout: Optional[int] = None, + ) -> InvokeMethodResponse: """Invoke a service method over gRPC or HTTP. Args: @@ -159,8 +170,11 @@ async def invoke_method_async( InvokeMethodResponse: the method invocation response. """ if self.invocation_client: - warn('Async invocation is deprecated. Please use `dapr.aio.clients.DaprClient`.', - DeprecationWarning, stacklevel=2) + warn( + "Async invocation is deprecated. 
Please use `dapr.aio.clients.DaprClient`.", + DeprecationWarning, + stacklevel=2, + ) return await self.invocation_client.invoke_method_async( app_id, method_name, @@ -169,7 +183,9 @@ async def invoke_method_async( metadata=metadata, http_verb=http_verb, http_querystring=http_querystring, - timeout=timeout) + timeout=timeout, + ) else: raise NotImplementedError( - 'Please use `dapr.aio.clients.DaprClient` for async invocation') + "Please use `dapr.aio.clients.DaprClient` for async invocation" + ) diff --git a/dapr/clients/base.py b/dapr/clients/base.py index 4af6865f5..908063f0d 100644 --- a/dapr/clients/base.py +++ b/dapr/clients/base.py @@ -17,47 +17,41 @@ from typing import Optional -DEFAULT_ENCODING = 'utf-8' -DEFAULT_JSON_CONTENT_TYPE = f'application/json; charset={DEFAULT_ENCODING}' +DEFAULT_ENCODING = "utf-8" +DEFAULT_JSON_CONTENT_TYPE = f"application/json; charset={DEFAULT_ENCODING}" class DaprActorClientBase(ABC): - """A base class that represents Dapr Actor Client. - """ + """A base class that represents Dapr Actor Client.""" @abstractmethod async def invoke_method( - self, actor_type: str, actor_id: str, - method: str, data: Optional[bytes] = None) -> bytes: + self, actor_type: str, actor_id: str, method: str, data: Optional[bytes] = None + ) -> bytes: ... @abstractmethod - async def save_state_transactionally( - self, actor_type: str, actor_id: str, - data: bytes) -> None: + async def save_state_transactionally(self, actor_type: str, actor_id: str, data: bytes) -> None: ... @abstractmethod - async def get_state( - self, actor_type: str, actor_id: str, name: str) -> bytes: + async def get_state(self, actor_type: str, actor_id: str, name: str) -> bytes: ... @abstractmethod async def register_reminder( - self, actor_type: str, actor_id: str, name: str, data: bytes) -> None: + self, actor_type: str, actor_id: str, name: str, data: bytes + ) -> None: ... @abstractmethod - async def unregister_reminder( - self, actor_type: str, actor_id: str, name: str) -> None: + async def unregister_reminder(self, actor_type: str, actor_id: str, name: str) -> None: ... @abstractmethod - async def register_timer( - self, actor_type: str, actor_id: str, name: str, data: bytes) -> None: + async def register_timer(self, actor_type: str, actor_id: str, name: str, data: bytes) -> None: ... @abstractmethod - async def unregister_timer( - self, actor_type: str, actor_id: str, name: str) -> None: + async def unregister_timer(self, actor_type: str, actor_id: str, name: str) -> None: ... 
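To sanity-check the reformatted DaprActorClientBase surface above against a concrete implementation, here is a minimal sketch (not part of this patch) of an in-memory fake that satisfies the same abstract coroutine signatures, e.g. for unit tests. The InMemoryActorClient name, its dict-based storage, and the "__changes__" key are illustrative assumptions, not SDK API.

from typing import Dict, Optional, Tuple

from dapr.clients.base import DaprActorClientBase


class InMemoryActorClient(DaprActorClientBase):
    """Illustrative fake: keeps actor state and reminder/timer payloads in dicts."""

    def __init__(self) -> None:
        self._state: Dict[Tuple[str, str, str], bytes] = {}
        self._reminders: Dict[Tuple[str, str, str], bytes] = {}
        self._timers: Dict[Tuple[str, str, str], bytes] = {}

    async def invoke_method(
        self, actor_type: str, actor_id: str, method: str, data: Optional[bytes] = None
    ) -> bytes:
        # Echo the payload back; a real client would call the Dapr sidecar here.
        return data or b""

    async def save_state_transactionally(
        self, actor_type: str, actor_id: str, data: bytes
    ) -> None:
        # A real client posts the serialized change set to the actor state endpoint.
        self._state[(actor_type, actor_id, "__changes__")] = data

    async def get_state(self, actor_type: str, actor_id: str, name: str) -> bytes:
        return self._state.get((actor_type, actor_id, name), b"")

    async def register_reminder(
        self, actor_type: str, actor_id: str, name: str, data: bytes
    ) -> None:
        self._reminders[(actor_type, actor_id, name)] = data

    async def unregister_reminder(self, actor_type: str, actor_id: str, name: str) -> None:
        self._reminders.pop((actor_type, actor_id, name), None)

    async def register_timer(
        self, actor_type: str, actor_id: str, name: str, data: bytes
    ) -> None:
        self._timers[(actor_type, actor_id, name)] = data

    async def unregister_timer(self, actor_type: str, actor_id: str, name: str) -> None:
        self._timers.pop((actor_type, actor_id, name), None)

Because every override is a plain coroutine, tests can await these methods directly or swap the fake in wherever a DaprActorClientBase is expected.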
diff --git a/dapr/clients/exceptions.py b/dapr/clients/exceptions.py index 6f9842468..fafa84849 100644 --- a/dapr/clients/exceptions.py +++ b/dapr/clients/exceptions.py @@ -21,17 +21,20 @@ class DaprInternalError(Exception): """DaprInternalError encapsulates all Dapr exceptions""" + def __init__( - self, message: Optional[str], - error_code: Optional[str] = ERROR_CODE_UNKNOWN, - raw_response_bytes: Optional[bytes] = None): + self, + message: Optional[str], + error_code: Optional[str] = ERROR_CODE_UNKNOWN, + raw_response_bytes: Optional[bytes] = None, + ): self._message = message self._error_code = error_code self._raw_response_bytes = raw_response_bytes def as_dict(self): return { - 'message': self._message, - 'errorCode': self._error_code, - 'raw_response_bytes': self._raw_response_bytes + "message": self._message, + "errorCode": self._error_code, + "raw_response_bytes": self._raw_response_bytes, } diff --git a/dapr/clients/grpc/_helpers.py b/dapr/clients/grpc/_helpers.py index 6a0c27e59..f92996485 100644 --- a/dapr/clients/grpc/_helpers.py +++ b/dapr/clients/grpc/_helpers.py @@ -52,9 +52,9 @@ def unpack(data: GrpcAny, message: GrpcMessage) -> None: matched with the response data type """ if not isinstance(message, GrpcMessage): - raise ValueError('output message is not protocol buffer message object') + raise ValueError("output message is not protocol buffer message object") if not data.Is(message.DESCRIPTOR): - raise ValueError(f'invalid type. serialized message type: {data.type_url}') + raise ValueError(f"invalid type. serialized message type: {data.type_url}") data.Unpack(message) @@ -63,9 +63,9 @@ def to_bytes(data: Union[str, bytes]) -> bytes: if isinstance(data, bytes): return data elif isinstance(data, str): - return data.encode('utf-8') + return data.encode("utf-8") else: - raise f'invalid data type {type(data)}' + raise ValueError(f"invalid data type {type(data)}") def to_str(data: Union[str, bytes]) -> str: @@ -73,20 +73,23 @@ def to_str(data: Union[str, bytes]) -> str: if isinstance(data, str): return data elif isinstance(data, bytes): - return data.decode('utf-8') + return data.decode("utf-8") else: - raise f'invalid data type {type(data)}' + raise ValueError(f"invalid data type {type(data)}") class _ClientCallDetails( - namedtuple( - '_ClientCallDetails', - ['method', 'timeout', 'metadata', 'credentials', 'wait_for_ready', 'compression']), - ClientCallDetails): + namedtuple( + "_ClientCallDetails", + ["method", "timeout", "metadata", "credentials", "wait_for_ready", "compression"], + ), + ClientCallDetails, +): """This is an implementation of the ClientCallDetails interface needed for interceptors. This class takes six named values and inherits the ClientCallDetails from the grpc package. This class encloses the values that describe an RPC to be invoked. """ + pass @@ -105,9 +108,7 @@ class DaprClientInterceptor(UnaryUnaryClientInterceptor): intercepted_channel = grpc.intercept_channel(grpc_channel, interceptor) """ - def __init__( - self, - metadata: List[Tuple[str, str]]): + def __init__(self, metadata: List[Tuple[str, str]]): """Initializes the metadata field for the class. Args: @@ -117,9 +118,7 @@ def __init__( self._metadata = metadata - def _intercept_call( - self, - client_call_details: ClientCallDetails) -> ClientCallDetails: + def _intercept_call(self, client_call_details: ClientCallDetails) -> ClientCallDetails: """Internal intercept_call implementation which adds metadata to grpc metadata in the RPC call details.
@@ -137,16 +136,16 @@ def _intercept_call( metadata.extend(self._metadata) new_call_details = _ClientCallDetails( - client_call_details.method, client_call_details.timeout, metadata, - client_call_details.credentials, client_call_details.wait_for_ready, - client_call_details.compression) + client_call_details.method, + client_call_details.timeout, + metadata, + client_call_details.credentials, + client_call_details.wait_for_ready, + client_call_details.compression, + ) return new_call_details - def intercept_unary_unary( - self, - continuation, - client_call_details, - request): + def intercept_unary_unary(self, continuation, client_call_details, request): """This method intercepts a unary-unary gRPC call. This is the implementation of the abstract method defined in UnaryUnaryClientInterceptor defined in grpc. This is invoked automatically by grpc based on the order in which interceptors are added to the channel. @@ -168,6 +167,7 @@ def intercept_unary_unary( # Data validation helpers + def validateNotNone(**kwargs: Optional[str]): for field_name, value in kwargs.items(): if value is None: diff --git a/dapr/clients/grpc/_request.py b/dapr/clients/grpc/_request.py index dc3e64d0d..dd146ff23 100644 --- a/dapr/clients/grpc/_request.py +++ b/dapr/clients/grpc/_request.py @@ -26,7 +26,7 @@ tuple_to_dict, to_bytes, to_str, - unpack + unpack, ) @@ -38,6 +38,7 @@ class DaprRequest: Attributes: metadata(dict): A dict to include the headers from Dapr Request. """ + def __init__(self, metadata: MetadataTuple = ()): self.metadata = metadata # type: ignore @@ -50,7 +51,7 @@ def metadata(self) -> MetadataDict: def metadata(self, val) -> None: """Sets metadata.""" if not isinstance(val, tuple): - raise ValueError('val is not tuple') + raise ValueError("val is not tuple") self._metadata = val def get_metadata(self, as_dict: bool = False) -> Union[MetadataDict, MetadataTuple]: @@ -83,21 +84,13 @@ class InvokeMethodRequest(DaprRequest): only for bytes array data. """ - HTTP_METHODS = [ - 'GET', - 'HEAD', - 'POST', - 'PUT', - 'DELETE', - 'CONNECT', - 'OPTIONS', - 'TRACE' - ] + HTTP_METHODS = ["GET", "HEAD", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE"] def __init__( - self, - data: Union[str, bytes, GrpcAny, GrpcMessage, None] = None, - content_type: Optional[str] = None): + self, + data: Union[str, bytes, GrpcAny, GrpcMessage, None] = None, + content_type: Optional[str] = None, + ): """Inits InvokeMethodRequestData with data and content_type. Args: @@ -131,7 +124,7 @@ def http_verb(self) -> Optional[str]: def http_verb(self, val: Optional[str]) -> None: """Sets HTTP method to Dapr invocation request.""" if val not in self.HTTP_METHODS: - raise ValueError(f'{val} is the invalid HTTP verb.') + raise ValueError(f"{val} is the invalid HTTP verb.") self._http_verb = val @property @@ -141,7 +134,7 @@ def http_querystring(self) -> Dict[str, str]: def is_http(self) -> bool: """Return true if this request is http compatible.""" - return hasattr(self, '_http_verb') and not (not self._http_verb) + return hasattr(self, "_http_verb") and not (not self._http_verb) @property def proto(self) -> GrpcAny: @@ -150,7 +143,7 @@ def proto(self) -> GrpcAny: def is_proto(self) -> bool: """Returns true if data is protocol-buffer serialized.""" - return hasattr(self, '_data') and self._data.type_url != '' + return hasattr(self, "_data") and self._data.type_url != "" def pack(self, val: Union[GrpcAny, GrpcMessage]) -> None: """Serializes protocol buffer message. 
@@ -167,7 +160,7 @@ def pack(self, val: Union[GrpcAny, GrpcMessage]) -> None: self._data = GrpcAny() self._data.Pack(val) else: - raise ValueError('invalid data type') + raise ValueError("invalid data type") def unpack(self, message: GrpcMessage) -> None: """Deserializes the serialized protocol buffer message. @@ -186,7 +179,7 @@ def unpack(self, message: GrpcMessage) -> None: def data(self) -> bytes: """Gets request data as bytes.""" if self.is_proto(): - raise ValueError('data is protocol buffer message object.') + raise ValueError("data is protocol buffer message object.") return self._data.value @data.setter @@ -203,7 +196,7 @@ def set_data(self, val: Union[str, bytes, GrpcAny, GrpcMessage, None]) -> None: elif isinstance(val, (GrpcAny, GrpcMessage)): self.pack(val) else: - raise ValueError(f'invalid data type {type(val)}') + raise ValueError(f"invalid data type {type(val)}") def text(self) -> str: """Gets the request data as str.""" @@ -230,10 +223,8 @@ class BindingRequest(DaprRequest): data (bytes): the data which is used for invoke_binding request. metadata (Dict[str, str]): the metadata sent to the binding. """ - def __init__( - self, - data: Union[str, bytes], - binding_metadata: Dict[str, str] = {}): + + def __init__(self, data: Union[str, bytes], binding_metadata: Dict[str, str] = {}): """Inits BindingRequest with data and metadata if given. Args: @@ -244,7 +235,7 @@ def __init__( ValueError: data is not bytes or str. """ super(BindingRequest, self).__init__(()) - self.data = data # type: ignore + self.data = data  # type: ignore self._binding_metadata = binding_metadata @property @@ -269,6 +260,7 @@ def binding_metadata(self): class TransactionOperationType(Enum): """Represents the type of operation for a Dapr Transaction State Api Call""" + upsert = "upsert" delete = "delete" @@ -284,11 +276,12 @@ class TransactionalStateOperation: """ def __init__( - self, - key: str, - data: Union[bytes, str], - etag: Optional[str] = None, - operation_type: TransactionOperationType = TransactionOperationType.upsert): + self, + key: str, + data: Union[bytes, str], + etag: Optional[str] = None, + operation_type: TransactionOperationType = TransactionOperationType.upsert, + ): """Initializes TransactionalStateOperation item from :obj:`runtime_v1.TransactionalStateOperation`. @@ -302,7 +295,7 @@ def __init__( ValueError: data is not bytes or str. """ if not isinstance(data, (bytes, str)): - raise ValueError(f'invalid type for data {type(data)}') + raise ValueError(f"invalid type for data {type(data)}") self._key = key self._data = data # type: ignore diff --git a/dapr/clients/grpc/_response.py b/dapr/clients/grpc/_response.py index a491756d0..79b169ce7 100644 --- a/dapr/clients/grpc/_response.py +++ b/dapr/clients/grpc/_response.py @@ -20,8 +20,16 @@ from datetime import datetime from enum import Enum from typing import ( - Callable, Dict, List, Optional, Text, Union, - Sequence, Mapping, TYPE_CHECKING, NamedTuple + Callable, + Dict, + List, + Optional, + Text, + Union, + Sequence, + Mapping, + TYPE_CHECKING, + NamedTuple, ) from google.protobuf.any_pb2 import Any as GrpcAny @@ -59,9 +67,7 @@ class DaprResponse: headers(dict): A dict to include the headers from Dapr gRPC Response. """ - def __init__( - self, - headers: MetadataTuple = ()): + def __init__(self, headers: MetadataTuple = ()): """Inits DaprResponse with headers and trailers.
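As an aside, a sketch of how the reformatted TransactionalStateOperation constructor reads at a call site, paired with execute_state_transaction (reformatted later in this patch); store and key names are hypothetical:

    from dapr.clients import DaprClient
    from dapr.clients.grpc._request import (
        TransactionalStateOperation,
        TransactionOperationType,
    )

    with DaprClient() as client:
        client.execute_state_transaction(
            store_name="statestore",  # hypothetical store
            operations=[
                # operation_type defaults to upsert
                TransactionalStateOperation(key="key1", data=b"value1"),
                TransactionalStateOperation(
                    key="key2",
                    data=b"",  # data is a required argument even for a delete
                    operation_type=TransactionOperationType.delete,
                ),
            ],
        )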
Args: @@ -111,11 +117,12 @@ class InvokeMethodResponse(DaprResponse): """ def __init__( - self, - data: Union[str, bytes, GrpcAny, GrpcMessage, None] = None, - content_type: Optional[str] = None, - headers: MetadataTuple = (), - status_code: Optional[int] = None): + self, + data: Union[str, bytes, GrpcAny, GrpcMessage, None] = None, + content_type: Optional[str] = None, + headers: MetadataTuple = (), + status_code: Optional[int] = None, + ): """Initializes InvokeMethodResponse from :obj:`common_v1.InvokeResponse`. Args: @@ -146,7 +153,7 @@ def proto(self) -> GrpcAny: def is_proto(self) -> bool: """Returns True if the response data is the serialized protocol buffer message.""" - return hasattr(self, '_data') and self._data.type_url != '' + return hasattr(self, "_data") and self._data.type_url != "" @property def data(self) -> bytes: @@ -157,7 +164,7 @@ def data(self) -> bytes: ValueError: the response data is the serialized protocol buffer message """ if self.is_proto(): - raise ValueError('data is protocol buffer message object.') + raise ValueError("data is protocol buffer message object.") return self._data.value @data.setter @@ -174,7 +181,7 @@ def set_data(self, val: Union[str, bytes, GrpcAny, GrpcMessage, None]) -> None: elif isinstance(val, (GrpcAny, GrpcMessage)): self.pack(val) else: - raise ValueError(f'invalid data type {type(val)}') + raise ValueError(f"invalid data type {type(val)}") def text(self) -> str: """Gets content as str if the response data content is not serialized @@ -219,7 +226,7 @@ def pack(self, val: Union[GrpcAny, GrpcMessage]) -> None: self._data = GrpcAny() self._data.Pack(val) else: - raise ValueError('invalid data type') + raise ValueError("invalid data type") @property def status_code(self) -> Optional[int]: @@ -261,10 +268,11 @@ class BindingResponse(DaprResponse): """ def __init__( - self, - data: Union[bytes, str], - binding_metadata: Dict[str, str] = {}, - headers: MetadataTuple = ()): + self, + data: Union[bytes, str], + binding_metadata: Dict[str, str] = {}, + headers: MetadataTuple = (), + ): """Initializes InvokeBindingResponse from :obj:`runtime_v1.InvokeBindingResponse`. Args: @@ -313,10 +321,7 @@ class GetSecretResponse(DaprResponse): secret (Dict[str, str]): secret received from response """ - def __init__( - self, - secret: Dict[str, str], - headers: MetadataTuple = ()): + def __init__(self, secret: Dict[str, str], headers: MetadataTuple = ()): """Initializes GetSecretResponse from :obj:`dapr_v1.GetSecretResponse`. Args: @@ -341,10 +346,7 @@ class GetBulkSecretResponse(DaprResponse): secrets (Dict[str, Dict[str, str]]): secrets received from response """ - def __init__( - self, - secrets: Dict[str, Dict[str, str]], - headers: MetadataTuple = ()): + def __init__(self, secrets: Dict[str, Dict[str, str]], headers: MetadataTuple = ()): """Initializes GetBulkSecretResponse from :obj:`dapr_v1.GetBulkSecretResponse`. Args: @@ -371,11 +373,7 @@ class StateResponse(DaprResponse): headers (Tuple, optional): the headers from Dapr gRPC response """ - def __init__( - self, - data: Union[bytes, str], - etag: str = '', - headers: MetadataTuple = ()): + def __init__(self, data: Union[bytes, str], etag: str = "", headers: MetadataTuple = ()): """Initializes StateResponse from :obj:`runtime_v1.GetStateResponse`.
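For orientation, a minimal sketch of the response types above in use: get_state (reformatted later in this patch) returns a StateResponse whose data and etag the caller reads directly; the store and key names are hypothetical:

    from dapr.clients import DaprClient

    with DaprClient() as client:
        resp = client.get_state(store_name="statestore", key="order_1")  # hypothetical names
        print(resp.data)  # raw bytes from the store
        print(resp.etag)  # etag for optimistic concurrency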
Args: @@ -425,12 +423,7 @@ class BulkStateItem: error (str): error when state was retrieved """ - def __init__( - self, - key: str, - data: Union[bytes, str], - etag: str = '', - error: str = ''): + def __init__(self, key: str, data: Union[bytes, str], etag: str = "", error: str = ""): """Initializes BulkStateItem item from :obj:`runtime_v1.BulkStateItem`. Args: @@ -482,10 +475,7 @@ class BulkStatesResponse(DaprResponse): data (Union[bytes, str]): state's data. """ - def __init__( - self, - items: Sequence[BulkStateItem], - headers: MetadataTuple = ()): + def __init__(self, items: Sequence[BulkStateItem], headers: MetadataTuple = ()): """Initializes BulkStatesResponse from :obj:`runtime_v1.GetBulkStateResponse`. Args: @@ -511,12 +501,7 @@ class QueryResponseItem: error (str): error when state was retrieved """ - def __init__( - self, - key: str, - value: bytes, - etag: str = '', - error: str = ''): + def __init__(self, key: str, value: bytes, etag: str = "", error: str = ""): """Initializes QueryResponseItem item from :obj:`runtime_v1.QueryStateItem`. Args: @@ -571,11 +556,12 @@ class QueryResponse(DaprResponse): """ def __init__( - self, - results: Sequence[QueryResponseItem], - token: str = '', - metadata: Dict[str, str] = dict(), - headers: MetadataTuple = ()): + self, + results: Sequence[QueryResponseItem], + token: str = "", + metadata: Dict[str, str] = dict(), + headers: MetadataTuple = (), + ): """Initializes QueryResponse from :obj:`runtime_v1.QueryStateResponse`. Args: @@ -614,11 +600,7 @@ class ConfigurationItem: metadata (str): metadata """ - def __init__( - self, - value: str, - version: str, - metadata: Optional[Dict[str, str]] = dict()): + def __init__(self, value: str, version: str, metadata: Optional[Dict[str, str]] = dict()): """Initializes ConfigurationItem item from :obj:`runtime_v1.ConfigurationItem`. Args: @@ -663,10 +645,7 @@ class ConfigurationResponse(DaprResponse): - items (Mapping[Text, ConfigurationItem]): state's data. """ - def __init__( - self, - items: Mapping[Text, ConfigurationItem], - headers: MetadataTuple = ()): + def __init__(self, items: Mapping[Text, ConfigurationItem], headers: MetadataTuple = ()): """Initializes ConfigurationResponse from :obj:`runtime_v1.GetConfigurationResponse`. 
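A short usage sketch for the bulk types above, assuming a hypothetical store and keys; each BulkStateItem carries its own data, etag, and per-key error:

    from dapr.clients import DaprClient

    with DaprClient() as client:
        resp = client.get_bulk_state(
            store_name="statestore",  # hypothetical store
            keys=["key1", "key2"],
            parallelism=2,
        )
        for item in resp.items:
            print(item.key, item.data, item.etag, item.error)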
Args: @@ -682,16 +661,22 @@ def items(self) -> Mapping[Text, ConfigurationItem]: return self._items -class ConfigurationWatcher(): +class ConfigurationWatcher: def __init__(self): self.event: threading.Event = threading.Event() self.id: str = "" - def watch_configuration(self, stub: api_service_v1.DaprStub, store_name: str, - keys: List[str], handler: Callable[[Text, ConfigurationResponse], None], - config_metadata: Optional[Dict[str, str]] = dict()): + def watch_configuration( + self, + stub: api_service_v1.DaprStub, + store_name: str, + keys: List[str], + handler: Callable[[Text, ConfigurationResponse], None], + config_metadata: Optional[Dict[str, str]] = dict(), + ): req = api_v1.SubscribeConfigurationRequest( - store_name=store_name, keys=keys, metadata=config_metadata) + store_name=store_name, keys=keys, metadata=config_metadata + ) thread = threading.Thread(target=self._read_subscribe_config, args=(stub, req, handler)) thread.daemon = True thread.start() @@ -703,9 +688,12 @@ def watch_configuration(self, stub: api_service_v1.DaprStub, store_name: str, return None return self.id - def _read_subscribe_config(self, stub: api_service_v1.DaprStub, - req: api_v1.SubscribeConfigurationRequest, - handler: Callable[[Text, ConfigurationResponse], None]): + def _read_subscribe_config( + self, + stub: api_service_v1.DaprStub, + req: api_v1.SubscribeConfigurationRequest, + handler: Callable[[Text, ConfigurationResponse], None], + ): try: responses = stub.SubscribeConfigurationAlpha1(req) isFirst = True @@ -717,8 +705,7 @@ def _read_subscribe_config(self, stub: api_service_v1.DaprStub, if len(response.items) > 0: handler(response.id, ConfigurationResponse(response.items)) except Exception: - print(f"{self.store_name} configuration watcher for keys " - f"{self.keys} stopped.") + print(f"{self.store_name} configuration watcher for keys {self.keys} stopped.") pass @@ -770,26 +757,26 @@ def status(self) -> TopicEventResponseStatus: class UnlockResponseStatus(Enum): success = api_v1.UnlockResponse.Status.SUCCESS - '''The Unlock operation for the referred lock was successful.''' + """The Unlock operation for the referred lock was successful.""" lock_does_not_exist = api_v1.UnlockResponse.Status.LOCK_DOES_NOT_EXIST - ''''The unlock operation failed: the referred lock does not exist.''' + """The unlock operation failed: the referred lock does not exist.""" lock_belongs_to_others = api_v1.UnlockResponse.Status.LOCK_BELONGS_TO_OTHERS - '''The unlock operation failed: the referred lock belongs to another owner.''' + """The unlock operation failed: the referred lock belongs to another owner.""" internal_error = api_v1.UnlockResponse.Status.INTERNAL_ERROR - '''An internal error happened while handling the Unlock operation''' + """An internal error happened while handling the Unlock operation""" class UnlockResponse(DaprResponse): - '''The response of an unlock operation. + """The response of an unlock operation. This inherits from DaprResponse Attributes: status (UnlockResponseStatus): the status of the unlock operation. - ''' + """ def __init__( self, @@ -812,13 +799,14 @@ def status(self) -> UnlockResponseStatus: class TryLockResponse(contextlib.AbstractContextManager, DaprResponse): - '''The response of a try_lock operation. + """The response of a try_lock operation. This inherits from DaprResponse and AbstractContextManager. Attributes: success (bool): the result of the try_lock operation.
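The watcher above invokes a user handler from a daemon thread. A sketch of a handler matching the Callable[[Text, ConfigurationResponse], None] shape, wired through subscribe_configuration (reformatted later in this patch); store and key names are hypothetical:

    from dapr.clients import DaprClient
    from dapr.clients.grpc._response import ConfigurationResponse

    def on_change(sub_id: str, resp: ConfigurationResponse) -> None:
        # Called from the watcher's daemon thread for each change batch.
        for key, item in resp.items.items():
            print(sub_id, key, item.value, item.version)

    with DaprClient() as client:
        sub_id = client.subscribe_configuration(
            store_name="configstore",  # hypothetical store
            keys=["orderId"],
            handler=on_change,
        )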
- ''' + """ + def __init__( self, success: bool, @@ -854,39 +842,42 @@ def success(self) -> bool: return self._success def __exit__(self, *exc) -> None: - ''''Automatically unlocks the lock if this TryLockResponse was used as + """Automatically unlocks the lock if this TryLockResponse was used as a ContextManager / `with` statement. Notice: we are not checking the result of the unlock operation. If this is something you care about, it might be wiser to create your own ContextManager that logs or otherwise raises exceptions if unlock doesn't return `UnlockResponseStatus.success`. - ''' + """ if self._success: self._client.unlock(self._store_name, self._resource_id, self._lock_owner) # else: there is no point unlocking a lock we did not acquire. async def __aexit__(self, *exc) -> None: - ''''Automatically unlocks the lock if this TryLockResponse was used as + """Automatically unlocks the lock if this TryLockResponse was used as a ContextManager / `with` statement. Notice: we are not checking the result of the unlock operation. If this is something you care about, it might be wiser to create your own ContextManager that logs or otherwise raises exceptions if unlock doesn't return `UnlockResponseStatus.success`. - ''' + """ if self._success: - await self._client.unlock(self._store_name, # type: ignore - self._resource_id, self._lock_owner) + await self._client.unlock( + self._store_name,  # type: ignore + self._resource_id, + self._lock_owner, + ) # else: there is no point unlocking a lock we did not acquire. - async def __aenter__(self) -> 'TryLockResponse': - '''Returns self as the context manager object.''' + async def __aenter__(self) -> "TryLockResponse": + """Returns self as the context manager object.""" return self class GetMetadataResponse(DaprResponse): - '''GetMetadataResponse is a message that is returned on GetMetadata rpc call.''' + """GetMetadataResponse is a message that is returned on GetMetadata rpc call.""" def __init__( self, @@ -896,7 +887,7 @@ def __init__( extended_metadata: Dict[str, str], headers: MetadataTuple = (), ): - '''Initializes GetMetadataResponse. + """Initializes GetMetadataResponse. Args: application_id (str): The Application ID. @@ -907,7 +898,7 @@ def __init__( extended_metadata (Dict[str, str]): mapping of custom (extended) attributes to their respective values. headers (Tuple, optional): the headers from Dapr gRPC response.
- ''' + """ super().__init__(headers) self._application_id = application_id self._active_actors_count = active_actors_count @@ -916,27 +907,27 @@ def __init__( @property def application_id(self) -> str: - '''The Application ID.''' + """The Application ID.""" return self._application_id @property def active_actors_count(self) -> Dict[str, int]: - '''Mapping from the type of registered actors to their number of running instances.''' + """Mapping from the type of registered actors to their number of running instances.""" return self._active_actors_count @property def registered_components(self) -> Sequence[RegisteredComponents]: - '''List of loaded components metadata.''' + """List of loaded components metadata.""" return self._registered_components @property def extended_metadata(self) -> Dict[str, str]: - '''Mapping of custom (extended) attributes to their respective values.''' + """Mapping of custom (extended) attributes to their respective values.""" return self._extended_metadata -class GetWorkflowResponse(): - '''The response of get_workflow operation.''' +class GetWorkflowResponse: + """The response of get_workflow operation.""" def __init__( self, @@ -965,8 +956,8 @@ def __init__( self.properties = properties -class StartWorkflowResponse(): - '''The response of start_workflow operation.''' +class StartWorkflowResponse: + """The response of start_workflow operation.""" def __init__( self, @@ -981,16 +972,16 @@ def __init__( class RegisteredComponents(NamedTuple): - '''Describes a loaded Dapr component.''' + """Describes a loaded Dapr component.""" name: str - '''Name of the component.''' + """Name of the component.""" type: str - '''Component type.''' + """Component type.""" version: str - '''Component version.''' + """Component version.""" capabilities: Sequence[str] - '''Supported capabilities for this component type and version.''' + """Supported capabilities for this component type and version.""" diff --git a/dapr/clients/grpc/_state.py b/dapr/clients/grpc/_state.py index 41eff8a1a..0bd2fd72c 100644 --- a/dapr/clients/grpc/_state.py +++ b/dapr/clients/grpc/_state.py @@ -5,6 +5,7 @@ class Consistency(Enum): """Represents the consistency mode for a Dapr State Api Call""" + unspecified = common_v1.StateOptions.StateConsistency.CONSISTENCY_UNSPECIFIED # type: ignore eventual = common_v1.StateOptions.StateConsistency.CONSISTENCY_EVENTUAL # type: ignore strong = common_v1.StateOptions.StateConsistency.CONSISTENCY_STRONG # type: ignore @@ -12,6 +13,7 @@ class Concurrency(Enum): """Represents the concurrency mode for a Dapr State Api Call""" + unspecified = common_v1.StateOptions.StateConcurrency.CONCURRENCY_UNSPECIFIED # type: ignore first_write = common_v1.StateOptions.StateConcurrency.CONCURRENCY_FIRST_WRITE # type: ignore last_write = common_v1.StateOptions.StateConcurrency.CONCURRENCY_LAST_WRITE # type: ignore @@ -56,7 +58,7 @@ def __init__( value: Union[bytes, str], etag: Optional[str] = None, options: Optional[StateOptions] = None, - metadata: Optional[Dict[str, str]] = dict() + metadata: Optional[Dict[str, str]] = dict(), ): """Inits StateItem with the required parameters.
@@ -71,7 +73,7 @@ def __init__( ValueError: value is not bytes or str """ if not isinstance(value, (bytes, str)): - raise ValueError(f'invalid type for data {type(value)}') + raise ValueError(f"invalid type for data {type(value)}") self._key = key self._value = value diff --git a/dapr/clients/grpc/client.py b/dapr/clients/grpc/client.py index cccf54188..4420a4e97 100644 --- a/dapr/clients/grpc/client.py +++ b/dapr/clients/grpc/client.py @@ -34,7 +34,7 @@ UnaryStreamClientInterceptor, StreamUnaryClientInterceptor, StreamStreamClientInterceptor, - RpcError + RpcError, ) from dapr.clients.exceptions import DaprInternalError @@ -56,7 +56,7 @@ from dapr.clients.grpc._request import ( InvokeMethodRequest, BindingRequest, - TransactionalStateOperation + TransactionalStateOperation, ) from dapr.clients.grpc._response import ( BindingResponse, @@ -102,14 +102,19 @@ class DaprGrpcClient: """ def __init__( - self, - address: Optional[str] = None, - interceptors: Optional[List[Union[ - UnaryUnaryClientInterceptor, - UnaryStreamClientInterceptor, - StreamUnaryClientInterceptor, - StreamStreamClientInterceptor]]] = None, - max_grpc_message_length: Optional[int] = None + self, + address: Optional[str] = None, + interceptors: Optional[ + List[ + Union[ + UnaryUnaryClientInterceptor, + UnaryStreamClientInterceptor, + StreamUnaryClientInterceptor, + StreamStreamClientInterceptor, + ] + ] + ] = None, + max_grpc_message_length: Optional[int] = None, ): """Connects to Dapr Runtime and initializes the gRPC client stub. @@ -122,54 +127,64 @@ def __init__( max_grpc_message_length (int, optional): The maximum grpc send and receive message length in bytes. """ - useragent = f'dapr-sdk-python/{__version__}' + useragent = f"dapr-sdk-python/{__version__}" if not max_grpc_message_length: options = [ - ('grpc.primary_user_agent', useragent), + ("grpc.primary_user_agent", useragent), ] else: options = [ - ('grpc.max_send_message_length', max_grpc_message_length), # type: ignore - ('grpc.max_receive_message_length', max_grpc_message_length), # type: ignore - ('grpc.primary_user_agent', useragent) + ("grpc.max_send_message_length", max_grpc_message_length),  # type: ignore + ("grpc.max_receive_message_length", max_grpc_message_length),  # type: ignore + ("grpc.primary_user_agent", useragent), ] if not address: - address = settings.DAPR_GRPC_ENDPOINT or (f"{settings.DAPR_RUNTIME_HOST}:" - f"{settings.DAPR_GRPC_PORT}") + address = settings.DAPR_GRPC_ENDPOINT or ( + f"{settings.DAPR_RUNTIME_HOST}:{settings.DAPR_GRPC_PORT}" + ) try: self._uri = GrpcEndpoint(address) except ValueError as error: - raise DaprInternalError(f'{error}') from error + raise DaprInternalError(f"{error}") from error if self._uri.tls: - self._channel = grpc.secure_channel(self._uri.endpoint, # type: ignore - self.get_credentials(), - options=options) + self._channel = grpc.secure_channel(  # type: ignore + self._uri.endpoint, + self.get_credentials(), + options=options, + ) else: - self._channel = grpc.insecure_channel(self._uri.endpoint, # type: ignore - options=options) + self._channel = grpc.insecure_channel(  # type: ignore + self._uri.endpoint, + options=options, + ) if settings.DAPR_API_TOKEN: - api_token_interceptor = DaprClientInterceptor([ - ('dapr-api-token', settings.DAPR_API_TOKEN), ]) + api_token_interceptor = DaprClientInterceptor( + [ + ("dapr-api-token", settings.DAPR_API_TOKEN), + ] + ) self._channel = grpc.intercept_channel(  # type: ignore - self._channel, api_token_interceptor) + self._channel, api_token_interceptor + ) if interceptors:
self._channel = grpc.intercept_channel( # type: ignore - self._channel, *interceptors) + self._channel, *interceptors + ) self._stub = api_service_v1.DaprStub(self._channel) def get_credentials(self): # This method is used (overwritten) from tests # to return credentials for self-signed certificates - return grpc.ssl_channel_credentials() # type: ignore + return grpc.ssl_channel_credentials() # type: ignore def close(self): """Closes Dapr runtime gRPC channel.""" - if hasattr(self, '_channel') and self._channel: + if hasattr(self, "_channel") and self._channel: self._channel.close() def __del__(self): @@ -182,8 +197,7 @@ def __exit__(self, exc_type, exc_value, traceback) -> None: self.close() def _get_http_extension( - self, http_verb: str, - http_querystring: Optional[MetadataTuple] = None + self, http_verb: str, http_querystring: Optional[MetadataTuple] = None ) -> common_v1.HTTPExtension: # type: ignore verb = common_v1.HTTPExtension.Verb.Value(http_verb) # type: ignore http_ext = common_v1.HTTPExtension(verb=verb) @@ -192,15 +206,16 @@ def _get_http_extension( return http_ext def invoke_method( - self, - app_id: str, - method_name: str, - data: Union[bytes, str, GrpcMessage] = '', - content_type: Optional[str] = None, - metadata: Optional[MetadataTuple] = None, - http_verb: Optional[str] = None, - http_querystring: Optional[MetadataTuple] = None, - timeout: Optional[int] = None) -> InvokeMethodResponse: + self, + app_id: str, + method_name: str, + data: Union[bytes, str, GrpcMessage] = "", + content_type: Optional[str] = None, + metadata: Optional[MetadataTuple] = None, + http_verb: Optional[str] = None, + http_querystring: Optional[MetadataTuple] = None, + timeout: Optional[int] = None, + ) -> InvokeMethodResponse: """Invokes the target service to call method. This can invoke the specified target service to call method with bytes array data or @@ -274,11 +289,18 @@ def invoke_method( Returns: :class:`InvokeMethodResponse` object returned from callee """ - warn('invoke_method with protocol gRPC is deprecated. Use gRPC proxying instead.', - DeprecationWarning, stacklevel=2) + warn( + "invoke_method with protocol gRPC is deprecated. Use gRPC proxying instead.", + DeprecationWarning, + stacklevel=2, + ) if metadata is not None: - warn('metadata argument is deprecated. Dapr already intercepts API token headers ' - 'and this is not needed.', DeprecationWarning, stacklevel=2) + warn( + "metadata argument is deprecated. Dapr already intercepts API token headers " + "and this is not needed.", + DeprecationWarning, + stacklevel=2, + ) req_data = InvokeMethodRequest(data, content_type) http_ext = None @@ -294,7 +316,8 @@ def invoke_method( method=method_name, data=req_data.proto, content_type=content_type, - http_extension=http_ext) + http_extension=http_ext, + ), ) response, call = self._stub.InvokeService.with_call(req, metadata=metadata, timeout=timeout) @@ -304,12 +327,13 @@ def invoke_method( return resp_data def invoke_binding( - self, - binding_name: str, - operation: str, - data: Union[bytes, str] = '', - binding_metadata: Dict[str, str] = {}, - metadata: Optional[MetadataTuple] = None) -> BindingResponse: + self, + binding_name: str, + operation: str, + data: Union[bytes, str] = "", + binding_metadata: Dict[str, str] = {}, + metadata: Optional[MetadataTuple] = None, + ) -> BindingResponse: """Invokes the output binding with the specified operation. 
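A sketch of the reformatted invoke_method signature at a call site (app id, method, and payload are hypothetical; http_verb applies when the callee is an HTTP app):

    from dapr.clients import DaprClient

    with DaprClient() as client:
        resp = client.invoke_method(
            app_id="order-processor",  # hypothetical callee
            method_name="orders",  # hypothetical method
            data=b'{"id": 1}',
            http_verb="POST",
        )
        print(resp.status_code, resp.text())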
The data field takes any JSON serializable value and acts as the @@ -341,8 +365,12 @@ def invoke_binding( :class:`InvokeBindingResponse` object returned from binding """ if metadata is not None: - warn('metadata argument is deprecated. Dapr already intercepts API token headers ' - 'and this is not needed.', DeprecationWarning, stacklevel=2) + warn( + "metadata argument is deprecated. Dapr already intercepts API token headers " + "and this is not needed.", + DeprecationWarning, + stacklevel=2, + ) req_data = BindingRequest(data, binding_metadata) @@ -350,22 +378,21 @@ def invoke_binding( name=binding_name, data=req_data.data, metadata=req_data.binding_metadata, - operation=operation + operation=operation, ) response, call = self._stub.InvokeBinding.with_call(req, metadata=metadata) - return BindingResponse( - response.data, dict(response.metadata), - call.initial_metadata()) + return BindingResponse(response.data, dict(response.metadata), call.initial_metadata()) def publish_event( - self, - pubsub_name: str, - topic_name: str, - data: Union[bytes, str], - publish_metadata: Dict[str, str] = {}, - metadata: Optional[MetadataTuple] = None, - data_content_type: Optional[str] = None) -> DaprResponse: + self, + pubsub_name: str, + topic_name: str, + data: Union[bytes, str], + publish_metadata: Dict[str, str] = {}, + metadata: Optional[MetadataTuple] = None, + data_content_type: Optional[str] = None, + ) -> DaprResponse: """Publish to a given topic. This publishes an event with bytes array or str data to a specified topic and specified pubsub component. The str data is encoded into bytes with default @@ -396,18 +423,22 @@ def publish_event( :class:`DaprResponse` gRPC metadata returned from callee """ if metadata is not None: - warn('metadata argument is deprecated. Dapr already intercepts API token headers ' - 'and this is not needed.', DeprecationWarning, stacklevel=2) + warn( + "metadata argument is deprecated. Dapr already intercepts API token headers " + "and this is not needed.", + DeprecationWarning, + stacklevel=2, + ) if not isinstance(data, bytes) and not isinstance(data, str): - raise ValueError(f'invalid type for data {type(data)}') + raise ValueError(f"invalid type for data {type(data)}") req_data: bytes if isinstance(data, bytes): req_data = data else: if isinstance(data, str): - req_data = data.encode('utf-8') + req_data = data.encode("utf-8") content_type = "" if data_content_type: @@ -417,7 +448,8 @@ def publish_event( topic=topic_name, data=req_data, data_content_type=content_type, - metadata=publish_metadata) + metadata=publish_metadata, + ) # response is google.protobuf.Empty _, call = self._stub.PublishEvent.with_call(req, metadata=metadata) @@ -425,11 +457,12 @@ def publish_event( return DaprResponse(call.initial_metadata()) def get_state( - self, - store_name: str, - key: str, - state_metadata: Optional[Dict[str, str]] = dict(), - metadata: Optional[MetadataTuple] = None) -> StateResponse: + self, + store_name: str, + key: str, + state_metadata: Optional[Dict[str, str]] = dict(), + metadata: Optional[MetadataTuple] = None, + ) -> StateResponse: """Gets value from a statestore with a key The example gets value from a statestore: @@ -453,25 +486,29 @@ def get_state( and value obtained from the state store """ if metadata is not None: - warn('metadata argument is deprecated. Dapr already intercepts API token headers ' - 'and this is not needed.', DeprecationWarning, stacklevel=2) + warn( + "metadata argument is deprecated. 
Dapr already intercepts API token headers " + "and this is not needed.", + DeprecationWarning, + stacklevel=2, + ) if not store_name or len(store_name) == 0 or len(store_name.strip()) == 0: raise ValueError("State store name cannot be empty") req = api_v1.GetStateRequest(store_name=store_name, key=key, metadata=state_metadata) response, call = self._stub.GetState.with_call(req, metadata=metadata) return StateResponse( - data=response.data, - etag=response.etag, - headers=call.initial_metadata()) + data=response.data, etag=response.etag, headers=call.initial_metadata() + ) def get_bulk_state( - self, - store_name: str, - keys: Sequence[str], - parallelism: int = 1, - states_metadata: Optional[Dict[str, str]] = dict(), - metadata: Optional[MetadataTuple] = None) -> BulkStatesResponse: + self, + store_name: str, + keys: Sequence[str], + parallelism: int = 1, + states_metadata: Optional[Dict[str, str]] = dict(), + metadata: Optional[MetadataTuple] = None, + ) -> BulkStatesResponse: """Gets values from a statestore with keys The example gets value from a statestore: @@ -496,35 +533,30 @@ def get_bulk_state( and value obtained from the state store """ if metadata is not None: - warn('metadata argument is deprecated. Dapr already intercepts API token headers ' - 'and this is not needed.', DeprecationWarning, stacklevel=2) + warn( + "metadata argument is deprecated. Dapr already intercepts API token headers " + "and this is not needed.", + DeprecationWarning, + stacklevel=2, + ) if not store_name or len(store_name) == 0 or len(store_name.strip()) == 0: raise ValueError("State store name cannot be empty") req = api_v1.GetBulkStateRequest( - store_name=store_name, - keys=keys, - parallelism=parallelism, - metadata=states_metadata) + store_name=store_name, keys=keys, parallelism=parallelism, metadata=states_metadata + ) response, call = self._stub.GetBulkState.with_call(req, metadata=metadata) items = [] for item in response.items: items.append( - BulkStateItem( - key=item.key, - data=item.data, - etag=item.etag, - error=item.error)) - return BulkStatesResponse( - items=items, - headers=call.initial_metadata()) + BulkStateItem(key=item.key, data=item.data, etag=item.etag, error=item.error) + ) + return BulkStatesResponse(items=items, headers=call.initial_metadata()) def query_state( - self, - store_name: str, - query: str, - states_metadata: Optional[Dict[str, str]] = dict()) -> QueryResponse: + self, store_name: str, query: str, states_metadata: Optional[Dict[str, str]] = dict() + ) -> QueryResponse: """Queries a statestore with a query For details on supported queries see https://docs.dapr.io/ @@ -562,42 +594,40 @@ def query_state( :class:`QueryStateResponse` gRPC metadata returned from callee, pagination token and results of the query """ - warn('The State Store Query API is an Alpha version and is subject to change.', - UserWarning, stacklevel=2) + warn( + "The State Store Query API is an Alpha version and is subject to change.", + UserWarning, + stacklevel=2, + ) if not store_name or len(store_name) == 0 or len(store_name.strip()) == 0: raise ValueError("State store name cannot be empty") - req = api_v1.QueryStateRequest( - store_name=store_name, - query=query, - metadata=states_metadata) + req = api_v1.QueryStateRequest(store_name=store_name, query=query, metadata=states_metadata) response, call = self._stub.QueryStateAlpha1.with_call(req) results = [] for item in response.results: results.append( - QueryResponseItem( - key=item.key, - value=item.data, - etag=item.etag, - error=item.error) + 
QueryResponseItem(key=item.key, value=item.data, etag=item.etag, error=item.error) ) return QueryResponse( token=response.token, results=results, metadata=response.metadata, - headers=call.initial_metadata()) + headers=call.initial_metadata(), + ) def save_state( - self, - store_name: str, - key: str, - value: Union[bytes, str], - etag: Optional[str] = None, - options: Optional[StateOptions] = None, - state_metadata: Optional[Dict[str, str]] = dict(), - metadata: Optional[MetadataTuple] = None) -> DaprResponse: + self, + store_name: str, + key: str, + value: Union[bytes, str], + etag: Optional[str] = None, + options: Optional[StateOptions] = None, + state_metadata: Optional[Dict[str, str]] = dict(), + metadata: Optional[MetadataTuple] = None, + ) -> DaprResponse: """Saves key-value pairs to a statestore This saves a value to the statestore with a given key and state store name. @@ -633,11 +663,15 @@ def save_state( ValueError: store_name is empty """ if metadata is not None: - warn('metadata argument is deprecated. Dapr already intercepts API token headers ' - 'and this is not needed.', DeprecationWarning, stacklevel=2) + warn( + "metadata argument is deprecated. Dapr already intercepts API token headers " + "and this is not needed.", + DeprecationWarning, + stacklevel=2, + ) if not isinstance(value, (bytes, str)): - raise ValueError(f'invalid type for data {type(value)}') + raise ValueError(f"invalid type for data {type(value)}") req_value = value @@ -654,18 +688,16 @@ def save_state( value=to_bytes(req_value), etag=common_v1.Etag(value=etag) if etag is not None else None, options=state_options, - metadata=state_metadata) + metadata=state_metadata, + ) req = api_v1.SaveStateRequest(store_name=store_name, states=[state]) _, call = self._stub.SaveState.with_call(req, metadata=metadata) - return DaprResponse( - headers=call.initial_metadata()) + return DaprResponse(headers=call.initial_metadata()) def save_bulk_state( - self, - store_name: str, - states: List[StateItem], - metadata: Optional[MetadataTuple] = None) -> DaprResponse: + self, store_name: str, states: List[StateItem], metadata: Optional[MetadataTuple] = None + ) -> DaprResponse: """Saves state items to a statestore This saves a given state item into the statestore specified by store_name. @@ -692,8 +724,12 @@ def save_bulk_state( ValueError: store_name is empty """ if metadata is not None: - warn('metadata argument is deprecated. Dapr already intercepts API token headers ' - 'and this is not needed.', DeprecationWarning, stacklevel=2) + warn( + "metadata argument is deprecated. 
Dapr already intercepts API token headers " + "and this is not needed.", + DeprecationWarning, + stacklevel=2, + ) if not states or len(states) == 0: raise ValueError("States to be saved cannot be empty") @@ -701,24 +737,28 @@ def save_bulk_state( if not store_name or len(store_name) == 0 or len(store_name.strip()) == 0: raise ValueError("State store name cannot be empty") - req_states = [common_v1.StateItem( - key=i.key, - value=to_bytes(i.value), - etag=common_v1.Etag(value=i.etag) if i.etag is not None else None, - options=i.options, - metadata=i.metadata) for i in states] + req_states = [ + common_v1.StateItem( + key=i.key, + value=to_bytes(i.value), + etag=common_v1.Etag(value=i.etag) if i.etag is not None else None, + options=i.options, + metadata=i.metadata, + ) + for i in states + ] req = api_v1.SaveStateRequest(store_name=store_name, states=req_states) _, call = self._stub.SaveState.with_call(req, metadata=metadata) - return DaprResponse( - headers=call.initial_metadata()) + return DaprResponse(headers=call.initial_metadata()) def execute_state_transaction( - self, - store_name: str, - operations: Sequence[TransactionalStateOperation], - transactional_metadata: Optional[Dict[str, str]] = dict(), - metadata: Optional[MetadataTuple] = None) -> DaprResponse: + self, + store_name: str, + operations: Sequence[TransactionalStateOperation], + transactional_metadata: Optional[Dict[str, str]] = dict(), + metadata: Optional[MetadataTuple] = None, + ) -> DaprResponse: """Saves or deletes key-value pairs to a statestore as a transaction This saves or deletes key-values to the statestore as part of a single transaction, @@ -750,35 +790,42 @@ def execute_state_transaction( :class:`DaprResponse` gRPC metadata returned from callee """ if metadata is not None: - warn('metadata argument is deprecated. Dapr already intercepts API token headers ' - 'and this is not needed.', DeprecationWarning, stacklevel=2) + warn( + "metadata argument is deprecated. 
Dapr already intercepts API token headers " + "and this is not needed.", + DeprecationWarning, + stacklevel=2, + ) if not store_name or len(store_name) == 0 or len(store_name.strip()) == 0: raise ValueError("State store name cannot be empty") - req_ops = [api_v1.TransactionalStateOperation( - operationType=o.operation_type.value, - request=common_v1.StateItem( - key=o.key, - value=to_bytes(o.data), - etag=common_v1.Etag(value=o.etag) if o.etag is not None else None)) - for o in operations] + req_ops = [ + api_v1.TransactionalStateOperation( + operationType=o.operation_type.value, + request=common_v1.StateItem( + key=o.key, + value=to_bytes(o.data), + etag=common_v1.Etag(value=o.etag) if o.etag is not None else None, + ), + ) + for o in operations + ] req = api_v1.ExecuteStateTransactionRequest( - storeName=store_name, - operations=req_ops, - metadata=transactional_metadata) + storeName=store_name, operations=req_ops, metadata=transactional_metadata + ) _, call = self._stub.ExecuteStateTransaction.with_call(req, metadata=metadata) - return DaprResponse( - headers=call.initial_metadata()) + return DaprResponse(headers=call.initial_metadata()) def delete_state( - self, - store_name: str, - key: str, - etag: Optional[str] = None, - options: Optional[StateOptions] = None, - state_metadata: Optional[Dict[str, str]] = dict(), - metadata: Optional[MetadataTuple] = None) -> DaprResponse: + self, + store_name: str, + key: str, + etag: Optional[str] = None, + options: Optional[StateOptions] = None, + state_metadata: Optional[Dict[str, str]] = dict(), + metadata: Optional[MetadataTuple] = None, + ) -> DaprResponse: """Deletes key-value pairs from a statestore This deletes a value from the statestore with a given key and state store name. @@ -808,8 +855,12 @@ def delete_state( :class:`DaprResponse` gRPC metadata returned from callee """ if metadata is not None: - warn('metadata argument is deprecated. Dapr already intercepts API token ' - 'headers and this is not needed.', DeprecationWarning, stacklevel=2) + warn( + "metadata argument is deprecated. Dapr already intercepts API token " + "headers and this is not needed.", + DeprecationWarning, + stacklevel=2, + ) if not store_name or len(store_name) == 0 or len(store_name.strip()) == 0: raise ValueError("State store name cannot be empty") @@ -820,19 +871,23 @@ def delete_state( state_options = options.get_proto() etag_object = common_v1.Etag(value=etag) if etag is not None else None - req = api_v1.DeleteStateRequest(store_name=store_name, key=key, - etag=etag_object, options=state_options, - metadata=state_metadata) + req = api_v1.DeleteStateRequest( + store_name=store_name, + key=key, + etag=etag_object, + options=state_options, + metadata=state_metadata, + ) _, call = self._stub.DeleteState.with_call(req, metadata=metadata) - return DaprResponse( - headers=call.initial_metadata()) + return DaprResponse(headers=call.initial_metadata()) def get_secret( - self, - store_name: str, - key: str, - secret_metadata: Optional[Dict[str, str]] = {}, - metadata: Optional[MetadataTuple] = None) -> GetSecretResponse: + self, + store_name: str, + key: str, + secret_metadata: Optional[Dict[str, str]] = {}, + metadata: Optional[MetadataTuple] = None, + ) -> GetSecretResponse: """Get secret with a given key. This gets a secret from secret store with a given key and secret store name. 
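For context, a minimal get_secret sketch against the reformatted signature (store and key names are hypothetical; the response's secret attribute is a Dict[str, str]):

    from dapr.clients import DaprClient

    with DaprClient() as client:
        resp = client.get_secret(
            store_name="localsecretstore",  # hypothetical store
            key="db-password",  # hypothetical key
        )
        print(resp.secret)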
@@ -864,25 +919,25 @@ def get_secret( :class:`GetSecretResponse` object with the secret and metadata returned from callee """ if metadata is not None: - warn('metadata argument is deprecated. Dapr already intercepts API token ' - 'headers and this is not needed.', DeprecationWarning, stacklevel=2) + warn( + "metadata argument is deprecated. Dapr already intercepts API token " + "headers and this is not needed.", + DeprecationWarning, + stacklevel=2, + ) - req = api_v1.GetSecretRequest( - store_name=store_name, - key=key, - metadata=secret_metadata) + req = api_v1.GetSecretRequest(store_name=store_name, key=key, metadata=secret_metadata) response, call = self._stub.GetSecret.with_call(req, metadata=metadata) - return GetSecretResponse( - secret=response.data, - headers=call.initial_metadata()) + return GetSecretResponse(secret=response.data, headers=call.initial_metadata()) def get_bulk_secret( - self, - store_name: str, - secret_metadata: Optional[Dict[str, str]] = {}, - metadata: Optional[MetadataTuple] = None) -> GetBulkSecretResponse: + self, + store_name: str, + secret_metadata: Optional[Dict[str, str]] = {}, + metadata: Optional[MetadataTuple] = None, + ) -> GetBulkSecretResponse: """Get all granted secrets. This gets all granted secrets from secret store. @@ -911,12 +966,14 @@ def get_bulk_secret( :class:`GetBulkSecretResponse` object with secrets and metadata returned from callee """ if metadata is not None: - warn('metadata argument is deprecated. Dapr already intercepts API token ' - 'headers and this is not needed.', DeprecationWarning, stacklevel=2) + warn( + "metadata argument is deprecated. Dapr already intercepts API token " + "headers and this is not needed.", + DeprecationWarning, + stacklevel=2, + ) - req = api_v1.GetBulkSecretRequest( - store_name=store_name, - metadata=secret_metadata) + req = api_v1.GetBulkSecretRequest(store_name=store_name, metadata=secret_metadata) response, call = self._stub.GetBulkSecret.with_call(req, metadata=metadata) @@ -928,15 +985,11 @@ def get_bulk_secret( secrets_submap[subkey] = secret_response.secrets[subkey] secrets_map[key] = secrets_submap - return GetBulkSecretResponse( - secrets=secrets_map, - headers=call.initial_metadata()) + return GetBulkSecretResponse(secrets=secrets_map, headers=call.initial_metadata()) def get_configuration( - self, - store_name: str, - keys: List[str], - config_metadata: Optional[Dict[str, str]] = dict()) -> ConfigurationResponse: + self, store_name: str, keys: List[str], config_metadata: Optional[Dict[str, str]] = dict() + ) -> ConfigurationResponse: """Gets value from a config store with a key The example gets value from a config store: @@ -960,18 +1013,18 @@ def get_configuration( if not store_name or len(store_name) == 0 or len(store_name.strip()) == 0: raise ValueError("Config store name cannot be empty to get the configuration") req = api_v1.GetConfigurationRequest( - store_name=store_name, keys=keys, metadata=config_metadata) + store_name=store_name, keys=keys, metadata=config_metadata + ) response, call = self._stub.GetConfiguration.with_call(req) - return ConfigurationResponse( - items=response.items, - headers=call.initial_metadata()) + return ConfigurationResponse(items=response.items, headers=call.initial_metadata()) def subscribe_configuration( - self, - store_name: str, - keys: List[str], - handler: Callable[[Text, ConfigurationResponse], None], - config_metadata: Optional[Dict[str, str]] = dict()) -> Text: + self, + store_name: str, + keys: List[str], + handler: Callable[[Text, 
ConfigurationResponse], None], + config_metadata: Optional[Dict[str, str]] = dict(), + ) -> Text: """Gets changed value from a config store with a key The example gets value from a config store: @@ -998,71 +1051,69 @@ def subscribe_configuration( raise ValueError("Config store name cannot be empty to get the configuration") configWatcher = ConfigurationWatcher() - id = configWatcher.watch_configuration(self._stub, store_name, keys, - handler, config_metadata) + id = configWatcher.watch_configuration( + self._stub, store_name, keys, handler, config_metadata + ) return id - def unsubscribe_configuration( - self, - store_name: str, - id: str) -> bool: + def unsubscribe_configuration(self, store_name: str, id: str) -> bool: """Unsubscribes from configuration changes. - Args: - store_name (str): the state store name to unsubscribe from - id (str): the subscription id to unsubscribe + Args: + store_name (str): the state store name to unsubscribe from + id (str): the subscription id to unsubscribe - Returns: - bool: True if unsubscribed successfully, False otherwise + Returns: + bool: True if unsubscribed successfully, False otherwise """ req = api_v1.UnsubscribeConfigurationRequest(store_name=store_name, id=id) response: UnsubscribeConfigurationResponse = self._stub.UnsubscribeConfiguration(req) return response.ok def try_lock( - self, - store_name: str, - resource_id: str, - lock_owner: str, - expiry_in_seconds: int) -> TryLockResponse: + self, store_name: str, resource_id: str, lock_owner: str, expiry_in_seconds: int + ) -> TryLockResponse: """Tries to get a lock with an expiry. - You can use the result of this operation directly on an `if` statement: + You can use the result of this operation directly on an `if` statement: - if client.try_lock(store_name, resource_id, first_client_id, expiry_s): - # lock acquired successfully... + if client.try_lock(store_name, resource_id, first_client_id, expiry_s): + # lock acquired successfully... - You can also inspect the response's `success` attribute: + You can also inspect the response's `success` attribute: - response = client.try_lock(store_name, resource_id, first_client_id, expiry_s) - if response.success: - # lock acquired successfully... + response = client.try_lock(store_name, resource_id, first_client_id, expiry_s) + if response.success: + # lock acquired successfully... - Finally, you can use this response with a `with` statement, and have the lock - be automatically unlocked after the with-statement scope ends + Finally, you can use this response with a `with` statement, and have the lock + be automatically unlocked after the with-statement scope ends - with client.try_lock(store_name, resource_id, first_client_id, expiry_s) as lock: - if lock: - # lock acquired successfully... - # Lock automatically unlocked at this point, no need to call client->unlock(...) + with client.try_lock(store_name, resource_id, first_client_id, expiry_s) as lock: + if lock: + # lock acquired successfully... + # Lock automatically unlocked at this point, no need to call client->unlock(...) - Args: - store_name (str): the lock store name, e.g. `redis`. - resource_id (str): the lock key. e.g. `order_id_111`. - It stands for "which resource I want to protect". - lock_owner (str): indicates the identifier of lock owner. - expiry_in_seconds (int): The length of time (in seconds) for which this lock - will be held and after which it expires. - - Returns: - :class:`TryLockResponse`: With the result of the try-lock operation. 
+ Args: + store_name (str): the lock store name, e.g. `redis`. + resource_id (str): the lock key. e.g. `order_id_111`. + It stands for "which resource I want to protect". + lock_owner (str): indicates the identifier of lock owner. + expiry_in_seconds (int): The length of time (in seconds) for which this lock + will be held and after which it expires. + + Returns: + :class:`TryLockResponse`: With the result of the try-lock operation. """ # Warnings and input validation - warn('The Distributed Lock API is an Alpha version and is subject to change.', - UserWarning, stacklevel=2) - validateNotBlankString(store_name=store_name, - resource_id=resource_id, - lock_owner=lock_owner) + warn( + "The Distributed Lock API is an Alpha version and is subject to change.", + UserWarning, + stacklevel=2, + ) + validateNotBlankString( + store_name=store_name, resource_id=resource_id, lock_owner=lock_owner + ) if not expiry_in_seconds or expiry_in_seconds < 1: raise ValueError("expiry_in_seconds must be a positive number") # Actual tryLock invocation @@ -1070,7 +1121,8 @@ def try_lock( store_name=store_name, resource_id=resource_id, lock_owner=lock_owner, - expiry_in_seconds=expiry_in_seconds) + expiry_in_seconds=expiry_in_seconds, + ) response, call = self._stub.TryLockAlpha1.with_call(req) return TryLockResponse( success=response.success, @@ -1078,76 +1130,79 @@ def try_lock( store_name=store_name, resource_id=resource_id, lock_owner=lock_owner, - headers=call.initial_metadata()) + headers=call.initial_metadata(), + ) - def unlock( - self, - store_name: str, - resource_id: str, - lock_owner: str) -> UnlockResponse: + def unlock(self, store_name: str, resource_id: str, lock_owner: str) -> UnlockResponse: """Unlocks a lock. - Args: - store_name (str): the lock store name, e.g. `redis`. - resource_id (str): the lock key. e.g. `order_id_111`. - It stands for "which resource I want to protect". - lock_owner (str): indicates the identifier of lock owner. - metadata (tuple, optional, DEPRECATED): gRPC custom metadata - - Returns: - :class:`UnlockResponseStatus`: Status of the request, - `UnlockResponseStatus.success` if it was successful of some other - status otherwise. + Args: + store_name (str): the lock store name, e.g. `redis`. + resource_id (str): the lock key. e.g. `order_id_111`. + It stands for "which resource I want to protect". + lock_owner (str): indicates the identifier of lock owner. + metadata (tuple, optional, DEPRECATED): gRPC custom metadata + + Returns: + :class:`UnlockResponseStatus`: Status of the request, + `UnlockResponseStatus.success` if it was successful or some other + status otherwise.
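The try_lock docstring above already shows the with-statement pattern; spelled out as a runnable sketch (store, resource, and owner ids are hypothetical):

    from dapr.clients import DaprClient

    with DaprClient() as client:
        # TryLockResponse doubles as a context manager and unlocks on exit.
        with client.try_lock("lockstore", "order_id_111", "client-1", 60) as lock:
            if lock.success:
                pass  # do the protected work here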
""" # Warnings and input validation - warn('The Distributed Lock API is an Alpha version and is subject to change.', - UserWarning, stacklevel=2) - validateNotBlankString(store_name=store_name, - resource_id=resource_id, - lock_owner=lock_owner) + warn( + "The Distributed Lock API is an Alpha version and is subject to change.", + UserWarning, + stacklevel=2, + ) + validateNotBlankString( + store_name=store_name, resource_id=resource_id, lock_owner=lock_owner + ) # Actual unlocking invocation req = api_v1.UnlockRequest( - store_name=store_name, - resource_id=resource_id, - lock_owner=lock_owner) + store_name=store_name, resource_id=resource_id, lock_owner=lock_owner + ) response, call = self._stub.UnlockAlpha1.with_call(req) - return UnlockResponse(status=UnlockResponseStatus(response.status), - headers=call.initial_metadata()) + return UnlockResponse( + status=UnlockResponseStatus(response.status), headers=call.initial_metadata() + ) def start_workflow( - self, - workflow_component: str, - workflow_name: str, - input: Optional[Union[Any, bytes]] = None, - instance_id: Optional[str] = None, - workflow_options: Optional[Dict[str, str]] = dict(), - send_raw_bytes: bool = False) -> StartWorkflowResponse: + self, + workflow_component: str, + workflow_name: str, + input: Optional[Union[Any, bytes]] = None, + instance_id: Optional[str] = None, + workflow_options: Optional[Dict[str, str]] = dict(), + send_raw_bytes: bool = False, + ) -> StartWorkflowResponse: """Starts a workflow. - Args: - workflow_component (str): the name of the workflow component - that will run the workflow. e.g. `dapr`. - workflow_name (str): the name of the workflow that will be executed. - input (Optional[Union[Any, bytes]]): the input that the workflow will receive. - The input value will be serialized to JSON - by default. Use the send_raw_bytes param - to send unencoded binary input. - instance_id (Optional[str]): the name of the workflow instance, - e.g. `order_processing_workflow-103784`. - workflow_options (Optional[Dict[str, str]]): the key-value options - that the workflow will receive. - send_raw_bytes (bool) if true, no serialization will be performed on the input - bytes + Args: + workflow_component (str): the name of the workflow component + that will run the workflow. e.g. `dapr`. + workflow_name (str): the name of the workflow that will be executed. + input (Optional[Union[Any, bytes]]): the input that the workflow will receive. + The input value will be serialized to JSON + by default. Use the send_raw_bytes param + to send unencoded binary input. + instance_id (Optional[str]): the name of the workflow instance, + e.g. `order_processing_workflow-103784`. + workflow_options (Optional[Dict[str, str]]): the key-value options + that the workflow will receive. 
+ send_raw_bytes (bool): if true, no serialization will be performed on the input + bytes - Returns: - :class:`StartWorkflowResponse`: Instance ID associated with the started workflow + Returns: + :class:`StartWorkflowResponse`: Instance ID associated with the started workflow """ # Warnings and input validation - warn('The Workflow API is a Beta version and is subject to change.', - UserWarning, stacklevel=2) - validateNotBlankString(workflow_component=workflow_component, - workflow_name=workflow_name) + warn( + "The Workflow API is a Beta version and is subject to change.", + UserWarning, + stacklevel=2, + ) + validateNotBlankString(workflow_component=workflow_component, workflow_name=workflow_name) if instance_id is None: instance_id = str(uuid.uuid4()) @@ -1156,8 +1211,7 @@ def start_workflow( encoded_data = input else: try: - encoded_data = json.dumps(input).encode( - "utf-8") if input is not None else bytes([]) + encoded_data = json.dumps(input).encode("utf-8") if input is not None else bytes([]) except TypeError: raise DaprInternalError("start_workflow: input data must be JSON serializable") except ValueError as e: @@ -1169,7 +1223,8 @@ def start_workflow( workflow_component=workflow_component, workflow_name=workflow_name, options=workflow_options, - input=encoded_data) + input=encoded_data, + ) try: response = self._stub.StartWorkflowBeta1(req) @@ -1177,30 +1232,29 @@ def start_workflow( except RpcError as err: raise DaprInternalError(err.details()) - def get_workflow( - self, - instance_id: str, - workflow_component: str) -> GetWorkflowResponse: + def get_workflow(self, instance_id: str, workflow_component: str) -> GetWorkflowResponse: """Gets information on a workflow. - Args: - instance_id (str): the ID of the workflow instance, - e.g. `order_processing_workflow-103784`. - workflow_component (str): the name of the workflow component - that will run the workflow. e.g. `dapr`. + Args: + instance_id (str): the ID of the workflow instance, + e.g. `order_processing_workflow-103784`. + workflow_component (str): the name of the workflow component + that will run the workflow. e.g. `dapr`.
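A sketch of start_workflow as reformatted above (workflow name and input are hypothetical; the input is JSON-serialized unless send_raw_bytes is set):

    from dapr.clients import DaprClient

    with DaprClient() as client:
        resp = client.start_workflow(
            workflow_component="dapr",
            workflow_name="order_processing_workflow",  # hypothetical workflow
            input={"order_id": 111},  # hypothetical input, JSON-serialized
        )
        print(resp.instance_id)  # a generated UUID when instance_id is omitted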
- Returns: - :class:`GetWorkflowResponse`: Instance ID associated with the started workflow + Returns: + :class:`GetWorkflowResponse`: Information about the given workflow instance """ # Warnings and input validation - warn('The Workflow API is a Beta version and is subject to change.', - UserWarning, stacklevel=2) - validateNotBlankString(instance_id=instance_id, - workflow_component=workflow_component) + warn( + "The Workflow API is a Beta version and is subject to change.", + UserWarning, + stacklevel=2, + ) + validateNotBlankString(instance_id=instance_id, workflow_component=workflow_component) # Actual get workflow invocation req = api_v1.GetWorkflowRequest( - instance_id=instance_id, - workflow_component=workflow_component) + instance_id=instance_id, workflow_component=workflow_component + ) try: resp = self._stub.GetWorkflowBeta1(req) @@ -1208,90 +1262,99 @@ def get_workflow( resp.created_at = datetime.now() if resp.last_updated_at is None: resp.last_updated_at = datetime.now() - return GetWorkflowResponse(instance_id=instance_id, - workflow_name=resp.workflow_name, - created_at=resp.created_at, - last_updated_at=resp.last_updated_at, - runtime_status=getWorkflowRuntimeStatus(resp.runtime_status), - properties=resp.properties) + return GetWorkflowResponse( + instance_id=instance_id, + workflow_name=resp.workflow_name, + created_at=resp.created_at, + last_updated_at=resp.last_updated_at, + runtime_status=getWorkflowRuntimeStatus(resp.runtime_status), + properties=resp.properties, + ) except RpcError as err: raise DaprInternalError(err.details()) - def terminate_workflow( - self, - instance_id: str, - workflow_component: str) -> DaprResponse: + def terminate_workflow(self, instance_id: str, workflow_component: str) -> DaprResponse: """Terminates a workflow. - Args: - instance_id (str): the ID of the workflow instance, e.g. - `order_processing_workflow-103784`. - workflow_component (str): the name of the workflow component - that will run the workflow. e.g. `dapr`. + Args: + instance_id (str): the ID of the workflow instance, e.g. + `order_processing_workflow-103784`. + workflow_component (str): the name of the workflow component + that will run the workflow. e.g. `dapr`.
- Returns: - :class:`DaprResponse` gRPC metadata returned from callee + Returns: + :class:`DaprResponse` gRPC metadata returned from callee """ # Warnings and input validation - warn('The Workflow API is a Beta version and is subject to change.', - UserWarning, stacklevel=2) - validateNotBlankString(instance_id=instance_id, - workflow_component=workflow_component) + warn( + "The Workflow API is a Beta version and is subject to change.", + UserWarning, + stacklevel=2, + ) + validateNotBlankString(instance_id=instance_id, workflow_component=workflow_component) # Actual terminate workflow invocation req = api_v1.TerminateWorkflowRequest( - instance_id=instance_id, - workflow_component=workflow_component) + instance_id=instance_id, workflow_component=workflow_component + ) try: _, call = self._stub.TerminateWorkflowBeta1.with_call(req) - return DaprResponse( - headers=call.initial_metadata()) + return DaprResponse(headers=call.initial_metadata()) except RpcError as err: raise DaprInternalError(err.details()) def raise_workflow_event( - self, - instance_id: str, - workflow_component: str, - event_name: str, - event_data: Optional[Union[Any, bytes]] = None, - send_raw_bytes: bool = False) -> DaprResponse: + self, + instance_id: str, + workflow_component: str, + event_name: str, + event_data: Optional[Union[Any, bytes]] = None, + send_raw_bytes: bool = False, + ) -> DaprResponse: """Raises an event on a workflow. - Args: - instance_id (str): the ID of the workflow instance, - e.g. `order_processing_workflow-103784`. - workflow_component (str): the name of the workflow component - that will run the workflow. e.g. `dapr`. - event_data (Optional[Union[Any, bytes]]): the input that the workflow will receive. - The input value will be serialized to JSON - by default. Use the send_raw_bytes param - to send unencoded binary input. - event_data (Optional[Union[Any, bytes]]): the input to the event. - send_raw_bytes (bool) if true, no serialization will be performed on the input - bytes - - Returns: - :class:`DaprResponse` gRPC metadata returned from callee + Args: + instance_id (str): the ID of the workflow instance, + e.g. `order_processing_workflow-103784`. + workflow_component (str): the name of the workflow component + that will run the workflow. e.g. `dapr`. + event_name (str): the name of the event to raise. + event_data (Optional[Union[Any, bytes]]): the data for the event. The value + will be serialized to JSON by default. Use the send_raw_bytes + param to send unencoded binary input.
+ send_raw_bytes (bool): if true, no serialization will be performed on the input + bytes + + Returns: + :class:`DaprResponse` gRPC metadata returned from callee """ # Warnings and input validation - warn('The Workflow API is a Beta version and is subject to change.', - UserWarning, stacklevel=2) - validateNotBlankString(instance_id=instance_id, - workflow_component=workflow_component, - event_name=event_name) + warn( + "The Workflow API is a Beta version and is subject to change.", + UserWarning, + stacklevel=2, + ) + validateNotBlankString( + instance_id=instance_id, workflow_component=workflow_component, event_name=event_name + ) if isinstance(event_data, bytes) and send_raw_bytes: encoded_data = event_data else: if event_data is not None: try: - encoded_data = json.dumps(event_data).encode( - "utf-8") if event_data is not None else bytes([]) + encoded_data = ( + json.dumps(event_data).encode("utf-8") + if event_data is not None + else bytes([]) + ) except TypeError: - raise DaprInternalError("raise_workflow_event:\ - event_data must be JSON serializable") + raise DaprInternalError( + "raise_workflow_event:\ + event_data must be JSON serializable" + ) except ValueError as e: raise DaprInternalError(f"raise_workflow_event JSON serialization error: {e}") encoded_data = json.dumps(event_data).encode("utf-8") @@ -1303,19 +1366,16 @@ def raise_workflow_event( instance_id=instance_id, workflow_component=workflow_component, event_name=event_name, - event_data=encoded_data) + event_data=encoded_data, + ) try: _, call = self._stub.RaiseEventWorkflowBeta1.with_call(req) - return DaprResponse( - headers=call.initial_metadata()) + return DaprResponse(headers=call.initial_metadata()) except RpcError as err: raise DaprInternalError(err.details()) - def pause_workflow( - self, - instance_id: str, - workflow_component: str) -> DaprResponse: + def pause_workflow(self, instance_id: str, workflow_component: str) -> DaprResponse: """Pause a workflow. Args: @@ -1329,86 +1389,83 @@ def pause_workflow( """ # Warnings and input validation - warn('The Workflow API is a Beta version and is subject to change.', - UserWarning, stacklevel=2) - validateNotBlankString(instance_id=instance_id, - workflow_component=workflow_component) + warn( + "The Workflow API is a Beta version and is subject to change.", + UserWarning, + stacklevel=2, + ) + validateNotBlankString(instance_id=instance_id, workflow_component=workflow_component) # Actual pause workflow invocation req = api_v1.PauseWorkflowRequest( - instance_id=instance_id, - workflow_component=workflow_component) + instance_id=instance_id, workflow_component=workflow_component + ) try: _, call = self._stub.PauseWorkflowBeta1.with_call(req) - return DaprResponse( - headers=call.initial_metadata()) + return DaprResponse(headers=call.initial_metadata()) except RpcError as err: raise DaprInternalError(err.details()) - def resume_workflow( - self, - instance_id: str, - workflow_component: str) -> DaprResponse: + def resume_workflow(self, instance_id: str, workflow_component: str) -> DaprResponse: """Resumes a workflow. - Args: - instance_id (str): the ID of the workflow instance, - e.g. `order_processing_workflow-103784`. - workflow_component (str): the name of the workflow component - that will run the workflow. e.g. `dapr`. + Args: + instance_id (str): the ID of the workflow instance, + e.g. `order_processing_workflow-103784`. + workflow_component (str): the name of the workflow component + that will run the workflow. e.g. `dapr`.
- Returns: - :class:`DaprResponse` gRPC metadata returned from callee + Returns: + :class:`DaprResponse` gRPC metadata returned from callee """ # Warnings and input validation - warn('The Workflow API is a Beta version and is subject to change.', - UserWarning, stacklevel=2) - validateNotBlankString(instance_id=instance_id, - workflow_component=workflow_component) + warn( + "The Workflow API is a Beta version and is subject to change.", + UserWarning, + stacklevel=2, + ) + validateNotBlankString(instance_id=instance_id, workflow_component=workflow_component) # Actual resume workflow invocation req = api_v1.ResumeWorkflowRequest( - instance_id=instance_id, - workflow_component=workflow_component) + instance_id=instance_id, workflow_component=workflow_component + ) try: _, call = self._stub.ResumeWorkflowBeta1.with_call(req) - return DaprResponse( - headers=call.initial_metadata()) + return DaprResponse(headers=call.initial_metadata()) except RpcError as err: raise DaprInternalError(err.details()) - def purge_workflow( - self, - instance_id: str, - workflow_component: str) -> DaprResponse: + def purge_workflow(self, instance_id: str, workflow_component: str) -> DaprResponse: """Purges a workflow. - Args: - instance_id (str): the ID of the workflow instance, - e.g. `order_processing_workflow-103784`. - workflow_component (str): the name of the workflow component - that will run the workflow. e.g. `dapr`. + Args: + instance_id (str): the ID of the workflow instance, + e.g. `order_processing_workflow-103784`. + workflow_component (str): the name of the workflow component + that will run the workflow. e.g. `dapr`. - Returns: - :class:`DaprResponse` gRPC metadata returned from callee + Returns: + :class:`DaprResponse` gRPC metadata returned from callee """ # Warnings and input validation - warn('The Workflow API is a Beta version and is subject to change.', - UserWarning, stacklevel=2) - validateNotBlankString(instance_id=instance_id, - workflow_component=workflow_component) + warn( + "The Workflow API is a Beta version and is subject to change.", + UserWarning, + stacklevel=2, + ) + validateNotBlankString(instance_id=instance_id, workflow_component=workflow_component) # Actual purge workflow invocation req = api_v1.PurgeWorkflowRequest( - instance_id=instance_id, - workflow_component=workflow_component) + instance_id=instance_id, workflow_component=workflow_component + ) try: _, call = self._stub.PurgeWorkflowBeta1.with_call(req) - return DaprResponse( - headers=call.initial_metadata()) + return DaprResponse(headers=call.initial_metadata()) except RpcError as err: raise DaprInternalError(err.details()) @@ -1459,14 +1516,12 @@ def get_metadata(self) -> GetMetadataResponse: response: api_v1.GetMetadataResponse = _resp # type alias # Convert to more pythonic formats active_actors_count = { - type_count.type: type_count.count - for type_count in response.active_actors_count + type_count.type: type_count.count for type_count in response.active_actors_count } registered_components = [ - RegisteredComponents(name=i.name, - type=i.type, - version=i.version, - capabilities=i.capabilities) + RegisteredComponents( + name=i.name, type=i.type, version=i.version, capabilities=i.capabilities + ) for i in response.registered_components ] extended_metadata = dict(response.extended_metadata.items()) @@ -1476,7 +1531,8 @@ def get_metadata(self) -> GetMetadataResponse: active_actors_count=active_actors_count, registered_components=registered_components, extended_metadata=extended_metadata, - 
headers=call.initial_metadata()) + headers=call.initial_metadata(), + ) def set_metadata(self, attributeName: str, attributeValue: str) -> DaprResponse: """Adds a custom (extended) metadata attribute to the Dapr sidecar diff --git a/dapr/clients/http/client.py b/dapr/clients/http/client.py index 713ca863d..601c060ab 100644 --- a/dapr/clients/http/client.py +++ b/dapr/clients/http/client.py @@ -25,20 +25,22 @@ from dapr.clients.exceptions import DaprInternalError, ERROR_CODE_DOES_NOT_EXIST, ERROR_CODE_UNKNOWN from dapr.version import __version__ -CONTENT_TYPE_HEADER = 'content-type' -DAPR_API_TOKEN_HEADER = 'dapr-api-token' -USER_AGENT_HEADER = 'User-Agent' -DAPR_USER_AGENT = f'dapr-sdk-python/{__version__}' +CONTENT_TYPE_HEADER = "content-type" +DAPR_API_TOKEN_HEADER = "dapr-api-token" +USER_AGENT_HEADER = "User-Agent" +DAPR_USER_AGENT = f"dapr-sdk-python/{__version__}" class DaprHttpClient: """A Dapr Http API client""" - def __init__(self, - message_serializer: 'Serializer', - timeout: Optional[int] = 60, - headers_callback: Optional[Callable[[], Dict[str, str]]] = None, - address: Optional[str] = None): + def __init__( + self, + message_serializer: "Serializer", + timeout: Optional[int] = 60, + headers_callback: Optional[Callable[[], Dict[str, str]]] = None, + address: Optional[str] = None, + ): """Invokes Dapr over HTTP. Args: @@ -53,19 +55,22 @@ def __init__(self, def get_api_url(self) -> str: if self._address: - return '{}/{}'.format(self._address, settings.DAPR_API_VERSION) + return "{}/{}".format(self._address, settings.DAPR_API_VERSION) if settings.DAPR_HTTP_ENDPOINT: - return '{}/{}'.format(settings.DAPR_HTTP_ENDPOINT, settings.DAPR_API_VERSION) + return "{}/{}".format(settings.DAPR_HTTP_ENDPOINT, settings.DAPR_API_VERSION) else: - return 'http://{}:{}/{}'.format(settings.DAPR_RUNTIME_HOST, - settings.DAPR_HTTP_PORT, settings.DAPR_API_VERSION) + return "http://{}:{}/{}".format( + settings.DAPR_RUNTIME_HOST, settings.DAPR_HTTP_PORT, settings.DAPR_API_VERSION + ) async def send_bytes( - self, method: str, url: str, - data: Optional[bytes], - headers: Dict[str, Union[bytes, str]] = {}, - query_params: Optional[Mapping] = None, - timeout: Optional[int] = None + self, + method: str, + url: str, + data: Optional[bytes], + headers: Dict[str, Union[bytes, str]] = {}, + query_params: Optional[Mapping] = None, + timeout: Optional[int] = None, ) -> Tuple[bytes, aiohttp.ClientResponse]: headers_map = headers if not headers_map.get(CONTENT_TYPE_HEADER): @@ -91,7 +96,8 @@ async def send_bytes( data=data, headers=headers_map, ssl=sslcontext, - params=query_params) + params=query_params, + ) if r.status >= 200 and r.status < 300: return await r.read(), r @@ -106,16 +112,20 @@ async def convert_to_error(self, response: aiohttp.ClientResponse) -> DaprIntern return DaprInternalError("Not Found", ERROR_CODE_DOES_NOT_EXIST) error_info = self._serializer.deserialize(error_body) except Exception: - return DaprInternalError(f'Unknown Dapr Error. HTTP status code: {response.status}', - raw_response_bytes=error_body) + return DaprInternalError( + f"Unknown Dapr Error. 
HTTP status code: {response.status}", + raw_response_bytes=error_body, + ) if error_info and isinstance(error_info, dict): - message = error_info.get('message') - error_code = error_info.get('errorCode') or ERROR_CODE_UNKNOWN + message = error_info.get("message") + error_code = error_info.get("errorCode") or ERROR_CODE_UNKNOWN return DaprInternalError(message, error_code, raw_response_bytes=error_body) - return DaprInternalError(f'Unknown Dapr Error. HTTP status code: {response.status}', - raw_response_bytes=error_body) + return DaprInternalError( + f"Unknown Dapr Error. HTTP status code: {response.status}", + raw_response_bytes=error_body, + ) def get_ssl_context(self): # This method is used (overwritten) from tests diff --git a/dapr/clients/http/dapr_actor_http_client.py b/dapr/clients/http/dapr_actor_http_client.py index f575b1913..ab38a16c3 100644 --- a/dapr/clients/http/dapr_actor_http_client.py +++ b/dapr/clients/http/dapr_actor_http_client.py @@ -21,17 +21,18 @@ from dapr.clients.http.client import DaprHttpClient from dapr.clients.base import DaprActorClientBase -DAPR_REENTRANCY_ID_HEADER = 'Dapr-Reentrancy-Id' +DAPR_REENTRANCY_ID_HEADER = "Dapr-Reentrancy-Id" class DaprActorHttpClient(DaprActorClientBase): """A Dapr Actor http client implementing :class:`DaprActorClientBase`""" def __init__( - self, - message_serializer: 'Serializer', - timeout: int = 60, - headers_callback: Optional[Callable[[], Dict[str, str]]] = None): + self, + message_serializer: "Serializer", + timeout: int = 60, + headers_callback: Optional[Callable[[], Dict[str, str]]] = None, + ): """Invokes Dapr Actors over HTTP. Args: @@ -42,8 +43,8 @@ def __init__( self._client = DaprHttpClient(message_serializer, timeout, headers_callback) async def invoke_method( - self, actor_type: str, actor_id: str, - method: str, data: Optional[bytes] = None) -> bytes: + self, actor_type: str, actor_id: str, method: str, data: Optional[bytes] = None + ) -> bytes: """Invoke method defined in :class:`Actor` remotely. Args: @@ -55,21 +56,21 @@ async def invoke_method( Returns: bytes: the response from the actor. """ - url = f'{self._get_base_url(actor_type, actor_id)}/method/{method}' + url = f"{self._get_base_url(actor_type, actor_id)}/method/{method}" # import to avoid circular dependency from dapr.actor.runtime.reentrancy_context import reentrancy_ctx + reentrancy_id = reentrancy_ctx.get() headers: Dict[str, Union[bytes, str]] = ( - {DAPR_REENTRANCY_ID_HEADER: reentrancy_id} if reentrancy_id else {}) + {DAPR_REENTRANCY_ID_HEADER: reentrancy_id} if reentrancy_id else {} + ) - body, _ = await self._client.send_bytes(method='POST', url=url, data=data, headers=headers) + body, _ = await self._client.send_bytes(method="POST", url=url, data=data, headers=headers) return body - async def save_state_transactionally( - self, actor_type: str, actor_id: str, - data: bytes) -> None: + async def save_state_transactionally(self, actor_type: str, actor_id: str, data: bytes) -> None: """Save state transactionally. Args: @@ -77,11 +78,10 @@ async def save_state_transactionally( actor_id (str): Id of Actor type. data (bytes): Json-serialized the transactional state operations. 
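Example (an illustrative sketch only, given a `DaprActorHttpClient` instance
`client`; the operation list mirrors the Dapr actor state transaction
payload, and the key/value names are hypothetical):

    ops = b'[{"operation": "upsert", "request": {"key": "k1", "value": "v1"}}]'
    await client.save_state_transactionally("DemoActor", "1", ops)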
""" - url = f'{self._get_base_url(actor_type, actor_id)}/state' - await self._client.send_bytes(method='PUT', url=url, data=data) + url = f"{self._get_base_url(actor_type, actor_id)}/state" + await self._client.send_bytes(method="PUT", url=url, data=data) - async def get_state( - self, actor_type: str, actor_id: str, name: str) -> bytes: + async def get_state(self, actor_type: str, actor_id: str, name: str) -> bytes: """Get state value for name key. Args: @@ -92,12 +92,13 @@ async def get_state( Returns: bytes: the value of the state. """ - url = f'{self._get_base_url(actor_type, actor_id)}/state/{name}' - body, _ = await self._client.send_bytes(method='GET', url=url, data=None) + url = f"{self._get_base_url(actor_type, actor_id)}/state/{name}" + body, _ = await self._client.send_bytes(method="GET", url=url, data=None) return body async def register_reminder( - self, actor_type: str, actor_id: str, name: str, data: bytes) -> None: + self, actor_type: str, actor_id: str, name: str, data: bytes + ) -> None: """Register actor reminder. Args: @@ -106,11 +107,10 @@ async def register_reminder( name (str): The name of reminder data (bytes): Reminder request json body. """ - url = f'{self._get_base_url(actor_type, actor_id)}/reminders/{name}' - await self._client.send_bytes(method='PUT', url=url, data=data) + url = f"{self._get_base_url(actor_type, actor_id)}/reminders/{name}" + await self._client.send_bytes(method="PUT", url=url, data=data) - async def unregister_reminder( - self, actor_type: str, actor_id: str, name: str) -> None: + async def unregister_reminder(self, actor_type: str, actor_id: str, name: str) -> None: """Unregister actor reminder. Args: @@ -118,11 +118,10 @@ async def unregister_reminder( actor_id (str): Id of Actor type. name (str): the name of reminder. """ - url = f'{self._get_base_url(actor_type, actor_id)}/reminders/{name}' - await self._client.send_bytes(method='DELETE', url=url, data=None) + url = f"{self._get_base_url(actor_type, actor_id)}/reminders/{name}" + await self._client.send_bytes(method="DELETE", url=url, data=None) - async def register_timer( - self, actor_type: str, actor_id: str, name: str, data: bytes) -> None: + async def register_timer(self, actor_type: str, actor_id: str, name: str, data: bytes) -> None: """Register actor timer. Args: @@ -131,11 +130,10 @@ async def register_timer( name (str): The name of reminder. data (bytes): Timer request json body. """ - url = f'{self._get_base_url(actor_type, actor_id)}/timers/{name}' - await self._client.send_bytes(method='PUT', url=url, data=data) + url = f"{self._get_base_url(actor_type, actor_id)}/timers/{name}" + await self._client.send_bytes(method="PUT", url=url, data=data) - async def unregister_timer( - self, actor_type: str, actor_id: str, name: str) -> None: + async def unregister_timer(self, actor_type: str, actor_id: str, name: str) -> None: """Unregister actor timer. Args: @@ -143,11 +141,8 @@ async def unregister_timer( actor_id (str): Id of Actor type. 
name (str): The name of timer """ - url = f'{self._get_base_url(actor_type, actor_id)}/timers/{name}' - await self._client.send_bytes(method='DELETE', url=url, data=None) + url = f"{self._get_base_url(actor_type, actor_id)}/timers/{name}" + await self._client.send_bytes(method="DELETE", url=url, data=None) def _get_base_url(self, actor_type: str, actor_id: str) -> str: - return '{}/actors/{}/{}'.format( - self._client.get_api_url(), - actor_type, - actor_id) + return "{}/actors/{}/{}".format(self._client.get_api_url(), actor_type, actor_id) diff --git a/dapr/clients/http/dapr_invocation_http_client.py b/dapr/clients/http/dapr_invocation_http_client.py index 0010fad50..1cfae62de 100644 --- a/dapr/clients/http/dapr_invocation_http_client.py +++ b/dapr/clients/http/dapr_invocation_http_client.py @@ -24,18 +24,19 @@ from dapr.serializers import DefaultJSONSerializer from dapr.version import __version__ -USER_AGENT_HEADER = 'User-Agent' -DAPR_USER_AGENT = f'dapr-python-sdk/{__version__}' +USER_AGENT_HEADER = "User-Agent" +DAPR_USER_AGENT = f"dapr-python-sdk/{__version__}" class DaprInvocationHttpClient: """Service Invocation HTTP Client""" def __init__( - self, - timeout: int = 60, - headers_callback: Optional[Callable[[], Dict[str, str]]] = None, - address: Optional[str] = None): + self, + timeout: int = 60, + headers_callback: Optional[Callable[[], Dict[str, str]]] = None, + address: Optional[str] = None, + ): """Invokes Dapr's API for method invocation over HTTP. Args: @@ -45,15 +46,16 @@ def __init__( self._client = DaprHttpClient(DefaultJSONSerializer(), timeout, headers_callback, address) async def invoke_method_async( - self, - app_id: str, - method_name: str, - data: Union[bytes, str, GrpcMessage], - content_type: Optional[str] = None, - metadata: Optional[MetadataTuple] = None, - http_verb: Optional[str] = None, - http_querystring: Optional[MetadataTuple] = None, - timeout: Optional[int] = None) -> InvokeMethodResponse: + self, + app_id: str, + method_name: str, + data: Union[bytes, str, GrpcMessage], + content_type: Optional[str] = None, + metadata: Optional[MetadataTuple] = None, + http_verb: Optional[str] = None, + http_querystring: Optional[MetadataTuple] = None, + timeout: Optional[int] = None, + ) -> InvokeMethodResponse: """Invoke a service method over HTTP (async). Args: @@ -70,7 +72,7 @@ async def invoke_method_async( InvokeMethodResponse: the response from the method invocation. 
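Example (an illustrative sketch only; assumes a running sidecar and a
hypothetical app `my-app` exposing a `say` method):

    client = DaprInvocationHttpClient()
    resp = await client.invoke_method_async(
        "my-app", "say", data=b"hello", http_verb="POST"
    )
    print(resp.status_code, resp.data)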
""" - verb = 'GET' + verb = "GET" if http_verb is not None: verb = http_verb @@ -89,12 +91,12 @@ async def invoke_method_async( headers[USER_AGENT_HEADER] = DAPR_USER_AGENT - url = f'{self._client.get_api_url()}/invoke/{app_id}/method/{method_name}' + url = f"{self._client.get_api_url()}/invoke/{app_id}/method/{method_name}" if isinstance(data, GrpcMessage): body = data.SerializeToString() elif isinstance(data, str): - body = data.encode('utf-8') + body = data.encode("utf-8") else: body = data @@ -105,7 +107,8 @@ async def make_request() -> InvokeMethodResponse: url=url, data=body, query_params=query_params, - timeout=timeout) + timeout=timeout, + ) respHeaders: MetadataTuple = tuple(r.headers.items()) @@ -113,8 +116,10 @@ async def make_request() -> InvokeMethodResponse: data=resp_body, content_type=r.content_type, headers=respHeaders, - status_code=r.status) + status_code=r.status, + ) return resp_data + return await make_request() def invoke_method( @@ -126,7 +131,7 @@ def invoke_method( metadata: Optional[MetadataTuple] = None, http_verb: Optional[str] = None, http_querystring: Optional[MetadataTuple] = None, - timeout: Optional[int] = None + timeout: Optional[int] = None, ) -> InvokeMethodResponse: """Invoke a service method over HTTP (async). @@ -151,12 +156,6 @@ def invoke_method( asyncio.set_event_loop(loop) awaitable = self.invoke_method_async( - app_id, - method_name, - data, - content_type, - metadata, - http_verb, - http_querystring, - timeout) + app_id, method_name, data, content_type, metadata, http_verb, http_querystring, timeout + ) return loop.run_until_complete(awaitable) diff --git a/dapr/conf/global_settings.py b/dapr/conf/global_settings.py index 5fe5647fe..0f1342371 100644 --- a/dapr/conf/global_settings.py +++ b/dapr/conf/global_settings.py @@ -21,11 +21,11 @@ DAPR_API_TOKEN = None DAPR_HTTP_ENDPOINT = None DAPR_GRPC_ENDPOINT = None -DAPR_RUNTIME_HOST = '127.0.0.1' +DAPR_RUNTIME_HOST = "127.0.0.1" DAPR_HTTP_PORT = 3500 DAPR_GRPC_PORT = 50001 -DAPR_API_VERSION = 'v1.0' +DAPR_API_VERSION = "v1.0" -DAPR_API_METHOD_INVOCATION_PROTOCOL = 'http' +DAPR_API_METHOD_INVOCATION_PROTOCOL = "http" DAPR_HTTP_TIMEOUT_SECONDS = 60 diff --git a/dapr/conf/helpers.py b/dapr/conf/helpers.py index 0342b7519..ff17839df 100644 --- a/dapr/conf/helpers.py +++ b/dapr/conf/helpers.py @@ -123,9 +123,11 @@ def _preprocess_uri(self, url: str) -> str: if len(url_list) == 3 and "://" not in url: # A URI like dns:mydomain:5000 or vsock:mycid:5000 was used url = url.replace(":", "://", 1) - elif len(url_list) >= 2 and "://" not in url and url_list[ - 0] in URIParseConfig.ACCEPTED_SCHEMES: - + elif ( + len(url_list) >= 2 + and "://" not in url + and url_list[0] in URIParseConfig.ACCEPTED_SCHEMES + ): # A URI like dns:mydomain or dns:[2001:db8:1f70::999:de8:7648:6e8]:mydomain was used # Possibly a URI like dns:[2001:db8:1f70::999:de8:7648:6e8]:mydomain was used url = url.replace(":", "://", 1) @@ -135,7 +137,7 @@ def _preprocess_uri(self, url: str) -> str: # If a scheme was not explicitly specified in the URL # we need to add a default scheme, # because of how urlparse works - url = f'{URIParseConfig.DEFAULT_SCHEME}://{url}' + url = f"{URIParseConfig.DEFAULT_SCHEME}://{url}" else: # If a scheme was explicitly specified in the URL # we need to make sure it is a valid scheme @@ -151,13 +153,13 @@ def _preprocess_uri(self, url: str) -> str: if len(url_list) < 4: raise ValueError(f"invalid dns authority '{url_list[2]}' in URL '{url}'") self._authority = url_list[2] - url = f'dns://{url_list[3]}' + url = 
f"dns://{url_list[3]}" return url def _set_tls(self): query_dict = parse_qs(self._parsed_url.query) - tls_str = query_dict.get('tls', [""])[0] - tls = tls_str.lower() == 'true' + tls_str = query_dict.get("tls", [""])[0] + tls = tls_str.lower() == "true" if self._parsed_url.scheme == "https": tls = True @@ -169,15 +171,19 @@ def tls(self) -> bool: def _validate_path_and_query(self) -> None: if self._parsed_url.path: - raise ValueError(f"paths are not supported for gRPC endpoints:" - f" '{self._parsed_url.path}'") + raise ValueError( + f"paths are not supported for gRPC endpoints:" f" '{self._parsed_url.path}'" + ) if self._parsed_url.query: query_dict = parse_qs(self._parsed_url.query) - if 'tls' in query_dict and self._parsed_url.scheme in ["http", "https"]: + if "tls" in query_dict and self._parsed_url.scheme in ["http", "https"]: raise ValueError( f"the tls query parameter is not supported for http(s) endpoints: " - f"'{self._parsed_url.query}'") - query_dict.pop('tls', None) + f"'{self._parsed_url.query}'" + ) + query_dict.pop("tls", None) if query_dict: - raise ValueError(f"query parameters are not supported for gRPC endpoints:" - f" '{self._parsed_url.query}'") + raise ValueError( + f"query parameters are not supported for gRPC endpoints:" + f" '{self._parsed_url.query}'" + ) diff --git a/dapr/serializers/__init__.py b/dapr/serializers/__init__.py index e51cf27e0..46a53bc1d 100644 --- a/dapr/serializers/__init__.py +++ b/dapr/serializers/__init__.py @@ -16,7 +16,4 @@ from dapr.serializers.base import Serializer from dapr.serializers.json import DefaultJSONSerializer -__all__ = [ - 'Serializer', - 'DefaultJSONSerializer' -] +__all__ = ["Serializer", "DefaultJSONSerializer"] diff --git a/dapr/serializers/base.py b/dapr/serializers/base.py index f10c8149c..5ff1d9e8b 100644 --- a/dapr/serializers/base.py +++ b/dapr/serializers/base.py @@ -22,12 +22,15 @@ class Serializer(ABC): @abstractmethod def serialize( - self, obj: object, - custom_hook: Optional[Callable[[object], bytes]] = None) -> bytes: + self, obj: object, custom_hook: Optional[Callable[[object], bytes]] = None + ) -> bytes: ... @abstractmethod def deserialize( - self, data: bytes, data_type: Optional[Type] = object, - custom_hook: Optional[Callable[[bytes], object]] = None) -> Any: + self, + data: bytes, + data_type: Optional[Type] = object, + custom_hook: Optional[Callable[[bytes], object]] = None, + ) -> Any: ... 
diff --git a/dapr/serializers/json.py b/dapr/serializers/json.py index 594860dda..22f1921b7 100644 --- a/dapr/serializers/json.py +++ b/dapr/serializers/json.py @@ -25,7 +25,7 @@ from dapr.serializers.util import ( convert_from_dapr_duration, convert_to_dapr_duration, - DAPR_DURATION_PARSER + DAPR_DURATION_PARSER, ) @@ -34,35 +34,34 @@ def __init__(self, ensure_ascii: bool = True) -> None: self.ensure_ascii = ensure_ascii def serialize( - self, obj: object, - custom_hook: Optional[Callable[[object], bytes]] = None) -> bytes: - + self, obj: object, custom_hook: Optional[Callable[[object], bytes]] = None + ) -> bytes: dict_obj = obj # importing this from top scope creates a circular import from dapr.actor.runtime.config import ActorRuntimeConfig + if callable(custom_hook): dict_obj = custom_hook(obj) elif isinstance(obj, bytes): - dict_obj = base64.b64encode(obj).decode('utf-8') + dict_obj = base64.b64encode(obj).decode("utf-8") elif isinstance(obj, ActorRuntimeConfig): dict_obj = obj.as_dict() serialized = json.dumps( - dict_obj, - cls=DaprJSONEncoder, - separators=(',', ':'), - ensure_ascii=self.ensure_ascii + dict_obj, cls=DaprJSONEncoder, separators=(",", ":"), ensure_ascii=self.ensure_ascii ) - return serialized.encode('utf-8') + return serialized.encode("utf-8") def deserialize( - self, data: bytes, data_type: Optional[Type] = object, - custom_hook: Optional[Callable[[bytes], object]] = None) -> Any: - + self, + data: bytes, + data_type: Optional[Type] = object, + custom_hook: Optional[Callable[[bytes], object]] = None, + ) -> Any: if not isinstance(data, (str, bytes)): - raise ValueError('data must be str or bytes types') + raise ValueError("data must be str or bytes types") obj = json.loads(data, cls=DaprJSONDecoder) @@ -76,22 +75,22 @@ def default(self, obj): r = obj.isoformat() if obj.microsecond: r = r[:23] + r[26:] - if r.endswith('+00:00'): - r = r[:-6] + 'Z' + if r.endswith("+00:00"): + r = r[:-6] + "Z" return r elif isinstance(obj, datetime.date): return obj.isoformat() elif isinstance(obj, datetime.timedelta): return convert_to_dapr_duration(obj) elif isinstance(obj, bytes): - return base64.b64encode(obj).decode('utf-8') + return base64.b64encode(obj).decode("utf-8") else: return json.JSONEncoder.default(self, obj) class DaprJSONDecoder(json.JSONDecoder): # TODO: improve regex - datetime_regex = re.compile(r'(\d{4}[-/]\d{2}[-/]\d{2})') + datetime_regex = re.compile(r"(\d{4}[-/]\d{2}[-/]\d{2})") def __init__(self, *args, **kwargs): json.JSONDecoder.__init__(self, *args, **kwargs) diff --git a/dapr/serializers/util.py b/dapr/serializers/util.py index d184dc37f..ed101f4cb 100644 --- a/dapr/serializers/util.py +++ b/dapr/serializers/util.py @@ -18,7 +18,8 @@ # Regex to parse Go Duration datatype, e.g. 
4h15m50s123ms345μs DAPR_DURATION_PARSER = re.compile( - r'((?P<hours>\d+)h)?((?P<mins>\d+)m)?((?P<seconds>\d+)s)?((?P<milliseconds>\d+)ms)?((?P<microseconds>\d+)(μs|us))?$') # noqa: E501 + r"((?P<hours>\d+)h)?((?P<mins>\d+)m)?((?P<seconds>\d+)s)?((?P<milliseconds>\d+)ms)?((?P<microseconds>\d+)(μs|us))?$" +) # noqa: E501 def convert_from_dapr_duration(duration: str) -> timedelta: @@ -33,19 +34,21 @@ def convert_from_dapr_duration(duration: str) -> timedelta: matched = DAPR_DURATION_PARSER.match(duration) if not matched or matched.lastindex == 0: - raise ValueError(f'Invalid Dapr Duration format: \'{duration}\'') + raise ValueError(f"Invalid Dapr Duration format: '{duration}'") days = 0.0 hours = 0.0 - if matched.group('hours') is not None: - days, hours = divmod(float(matched.group('hours')), 24) - mins = 0.0 if not matched.group('mins') else float(matched.group('mins')) - seconds = 0.0 if not matched.group('seconds') else float(matched.group('seconds')) - milliseconds = 0.0 if not matched.group( - 'milliseconds') else float(matched.group('milliseconds')) - microseconds = 0.0 if not matched.group( - 'microseconds') else float(matched.group('microseconds')) + if matched.group("hours") is not None: + days, hours = divmod(float(matched.group("hours")), 24) + mins = 0.0 if not matched.group("mins") else float(matched.group("mins")) + seconds = 0.0 if not matched.group("seconds") else float(matched.group("seconds")) + milliseconds = ( + 0.0 if not matched.group("milliseconds") else float(matched.group("milliseconds")) + ) + microseconds = ( + 0.0 if not matched.group("microseconds") else float(matched.group("microseconds")) + ) return timedelta( days=days, @@ -53,7 +56,7 @@ def convert_from_dapr_duration(duration: str) -> timedelta: minutes=mins, seconds=seconds, milliseconds=milliseconds, - microseconds=microseconds + microseconds=microseconds, ) @@ -71,4 +74,4 @@ def convert_to_dapr_duration(td: timedelta) -> str: milliseconds, microseconds = divmod(td.microseconds, 1000.0) hours, mins = divmod(total_minutes, 60.0) - return f'{hours:.0f}h{mins:.0f}m{seconds:.0f}s{milliseconds:.0f}ms{microseconds:.0f}μs' + return f"{hours:.0f}h{mins:.0f}m{seconds:.0f}s{milliseconds:.0f}ms{microseconds:.0f}μs" diff --git a/dapr/version/__init__.py b/dapr/version/__init__.py index 22a4219c6..5d04bf3f3 100644 --- a/dapr/version/__init__.py +++ b/dapr/version/__init__.py @@ -1,5 +1,3 @@ from dapr.version.version import __version__ -__all__ = [ - '__version__' -] +__all__ = ["__version__"] diff --git a/dev-requirements.txt b/dev-requirements.txt index e4bfff131..bf648667c 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -11,3 +11,5 @@ httpx>=0.24 pyOpenSSL>=23.2.0 # needed for type checking Flask>=1.1 +# needed for auto fix +ruff>=0.1.11 \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index 22e76e094..221beb963 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -12,14 +12,15 @@ # import os import sys -sys.path.insert(0, os.path.abspath('../dapr')) + +sys.path.insert(0, os.path.abspath("../dapr")) # -- Project information ----------------------------------------------------- -project = 'dapr-python-sdk' -copyright = '2020, dapr' -author = 'dapr' +project = "dapr-python-sdk" +copyright = "2020, dapr" +author = "dapr" # -- General configuration --------------------------------------------------- # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones.
-extensions = ['sphinx.ext.autodoc', 'sphinx.ext.intersphinx', - 'sphinx.ext.ifconfig', 'sphinx.ext.napoleon', ] +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "sphinx.ext.ifconfig", + "sphinx.ext.napoleon", +] # Napoleon settings napoleon_google_docstring = True @@ -44,12 +49,12 @@ napoleon_use_rtype = True autodoc_mock_imports = ["dapr"] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # -- Options for HTML output ------------------------------------------------- @@ -57,9 +62,9 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'classic' +html_theme = "classic" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] diff --git a/examples/configuration/configuration.py b/examples/configuration/configuration.py index b794c09a8..653ef590b 100644 --- a/examples/configuration/configuration.py +++ b/examples/configuration/configuration.py @@ -9,17 +9,22 @@ configuration: ConfigurationWatcher = ConfigurationWatcher() + def handler(id: str, resp: ConfigurationResponse): for key in resp.items: - print(f"Subscribe key={key} value={resp.items[key].value} " - f"version={resp.items[key].version} " - f"metadata={resp.items[key].metadata}", flush=True) + print( + f"Subscribe key={key} value={resp.items[key].value} " + f"version={resp.items[key].version} " + f"metadata={resp.items[key].metadata}", + flush=True, + ) + async def executeConfiguration(): with DaprClient() as d: - storeName = 'configurationstore' + storeName = "configurationstore" - keys = ['orderId1', 'orderId2'] + keys = ["orderId1", "orderId2"] # Wait for sidecar to be up within 20 seconds. d.wait(20) @@ -29,14 +34,18 @@ async def executeConfiguration(): # Get one configuration by key. configuration = d.get_configuration(store_name=storeName, keys=keys, config_metadata={}) for key in configuration.items: - print(f"Got key={key} " - f"value={configuration.items[key].value} " - f"version={configuration.items[key].version} " - f"metadata={configuration.items[key].metadata}", flush=True) + print( + f"Got key={key} " + f"value={configuration.items[key].value} " + f"version={configuration.items[key].version} " + f"metadata={configuration.items[key].metadata}", + flush=True, + ) # Subscribe to configuration for keys {orderId1,orderId2}. - id = d.subscribe_configuration(store_name=storeName, keys=keys, - handler=handler, config_metadata={}) + id = d.subscribe_configuration( + store_name=storeName, keys=keys, handler=handler, config_metadata={} + ) print("Subscription ID is", id, flush=True) sleep(10) @@ -44,4 +53,5 @@ async def executeConfiguration(): isSuccess = d.unsubscribe_configuration(store_name=storeName, id=id) print(f"Unsubscribed successfully? 
{isSuccess}", flush=True) + asyncio.run(executeConfiguration()) diff --git a/examples/demo_actor/demo_actor/demo_actor.py b/examples/demo_actor/demo_actor/demo_actor.py index 329446cde..e0efd07f0 100644 --- a/examples/demo_actor/demo_actor/demo_actor.py +++ b/examples/demo_actor/demo_actor/demo_actor.py @@ -33,28 +33,28 @@ def __init__(self, ctx, actor_id): async def _on_activate(self) -> None: """An callback which will be called whenever actor is activated.""" - print(f'Activate {self.__class__.__name__} actor!', flush=True) + print(f"Activate {self.__class__.__name__} actor!", flush=True) async def _on_deactivate(self) -> None: """An callback which will be called whenever actor is deactivated.""" - print(f'Deactivate {self.__class__.__name__} actor!', flush=True) + print(f"Deactivate {self.__class__.__name__} actor!", flush=True) async def get_my_data(self) -> object: """An actor method which gets mydata state value.""" - has_value, val = await self._state_manager.try_get_state('mydata') - print(f'has_value: {has_value}', flush=True) + has_value, val = await self._state_manager.try_get_state("mydata") + print(f"has_value: {has_value}", flush=True) return val async def set_my_data(self, data) -> None: """An actor method which set mydata state value.""" - print(f'set_my_data: {data}', flush=True) - data['ts'] = datetime.datetime.now(datetime.timezone.utc) - await self._state_manager.set_state('mydata', data) + print(f"set_my_data: {data}", flush=True) + data["ts"] = datetime.datetime.now(datetime.timezone.utc) + await self._state_manager.set_state("mydata", data) await self._state_manager.save_state() async def clear_my_data(self) -> None: - print('clear_my_data', flush=True) - await self._state_manager.remove_state('mydata') + print("clear_my_data", flush=True) + await self._state_manager.remove_state("mydata") await self._state_manager.save_state() async def set_reminder(self, enabled) -> None: @@ -63,20 +63,21 @@ async def set_reminder(self, enabled) -> None: Args: enabled (bool): the flag to enable and disable demo_reminder. """ - print(f'set reminder to {enabled}', flush=True) + print(f"set reminder to {enabled}", flush=True) if enabled: # Register 'demo_reminder' reminder and call receive_reminder method await self.register_reminder( - 'demo_reminder', # reminder name - b'reminder_state', # user_state (bytes) + "demo_reminder", # reminder name + b"reminder_state", # user_state (bytes) # The amount of time to delay before firing the reminder datetime.timedelta(seconds=5), datetime.timedelta(seconds=5), # The time interval between firing of reminders - datetime.timedelta(seconds=5)) + datetime.timedelta(seconds=5), + ) else: # Unregister 'demo_reminder' - await self.unregister_reminder('demo_reminder') - print('set reminder is done', flush=True) + await self.unregister_reminder("demo_reminder") + print("set reminder is done", flush=True) async def set_timer(self, enabled) -> None: """Enables and disables a timer. @@ -84,21 +85,22 @@ async def set_timer(self, enabled) -> None: Args: enabled (bool): the flag to enable and disable demo_timer. 
""" - print(f'set_timer to {enabled}', flush=True) + print(f"set_timer to {enabled}", flush=True) if enabled: # Register 'demo_timer' timer and call timer_callback method await self.register_timer( - 'demo_timer', # timer name - self.timer_callback, # Callback method - 'timer_state', # Parameter to pass to the callback method + "demo_timer", # timer name + self.timer_callback, # Callback method + "timer_state", # Parameter to pass to the callback method # Amount of time to delay before the callback is invoked datetime.timedelta(seconds=5), datetime.timedelta(seconds=5), # Time interval between invocations - datetime.timedelta(seconds=5)) + datetime.timedelta(seconds=5), + ) else: # Unregister 'demo_timer' - await self.unregister_timer('demo_timer') - print('set_timer is done', flush=True) + await self.unregister_timer("demo_timer") + print("set_timer is done", flush=True) async def timer_callback(self, state) -> None: """A callback which will be called whenever timer is triggered. @@ -106,15 +108,21 @@ async def timer_callback(self, state) -> None: Args: state (object): an object which is defined when timer is registered. """ - print(f'time_callback is called - {state}', flush=True) - - async def receive_reminder(self, name: str, state: bytes, - due_time: datetime.timedelta, period: datetime.timedelta, - ttl: Optional[datetime.timedelta] = None) -> None: + print(f"time_callback is called - {state}", flush=True) + + async def receive_reminder( + self, + name: str, + state: bytes, + due_time: datetime.timedelta, + period: datetime.timedelta, + ttl: Optional[datetime.timedelta] = None, + ) -> None: """A callback which will be called when reminder is triggered.""" - print(f'receive_reminder is called - {name} reminder - {str(state)}', flush=True) + print(f"receive_reminder is called - {name} reminder - {str(state)}", flush=True) async def get_reentrancy_status(self) -> bool: """For Testing Only: An actor method which gets reentrancy status.""" from dapr.actor.runtime.reentrancy_context import reentrancy_ctx + return reentrancy_ctx.get(None) is not None diff --git a/examples/demo_actor/demo_actor/demo_actor_client.py b/examples/demo_actor/demo_actor/demo_actor_client.py index a5ba14e78..3faba7c39 100644 --- a/examples/demo_actor/demo_actor/demo_actor_client.py +++ b/examples/demo_actor/demo_actor/demo_actor_client.py @@ -18,7 +18,7 @@ async def main(): # Create proxy client - proxy = ActorProxy.create('DemoActor', ActorId('1'), DemoActorInterface) + proxy = ActorProxy.create("DemoActor", ActorId("1"), DemoActorInterface) # ----------------------------------------------- # Actor invocation demo @@ -33,14 +33,14 @@ async def main(): print(rtn_obj, flush=True) # Check actor is reentrant is_reentrant = await proxy.invoke_method("GetReentrancyStatus") - print(f'Actor reentrancy enabled: {str(is_reentrant)}', flush=True) + print(f"Actor reentrancy enabled: {str(is_reentrant)}", flush=True) # ----------------------------------------------- # Actor state management demo # ----------------------------------------------- # Invoke SetMyData actor method to save the state print("call SetMyData actor method to save the state", flush=True) - await proxy.SetMyData({'data': 'new_data'}) + await proxy.SetMyData({"data": "new_data"}) # Invoke GetMyData actor method to get the state print("call GetMyData actor method to get the state", flush=True) rtn_obj = await proxy.GetMyData() @@ -75,5 +75,4 @@ async def main(): await proxy.ClearMyData() - asyncio.run(main()) diff --git 
a/examples/demo_actor/demo_actor/demo_actor_flask.py b/examples/demo_actor/demo_actor/demo_actor_flask.py index bc627ea16..00703485f 100644 --- a/examples/demo_actor/demo_actor/demo_actor_flask.py +++ b/examples/demo_actor/demo_actor/demo_actor_flask.py @@ -18,16 +18,14 @@ from dapr.actor.runtime.runtime import ActorRuntime from demo_actor import DemoActor -app = Flask(f'{DemoActor.__name__}Service') +app = Flask(f"{DemoActor.__name__}Service") # This is an optional advanced configuration which enables reentrancy only for the # specified actor type. By default reentrancy is not enabled for all actor types. config = ActorRuntimeConfig() # init with default values -config.update_actor_type_configs([ - ActorTypeConfig( - actor_type=DemoActor.__name__, - reentrancy=ActorReentrancyConfig(enabled=True)) -]) +config.update_actor_type_configs( + [ActorTypeConfig(actor_type=DemoActor.__name__, reentrancy=ActorReentrancyConfig(enabled=True))] +) ActorRuntime.set_actor_config(config) # Enable DaprActor Flask extension @@ -37,10 +35,10 @@ # This route is optional. -@app.route('/') +@app.route("/") def index(): - return jsonify({'status': 'ok'}), 200 + return jsonify({"status": "ok"}), 200 -if __name__ == '__main__': +if __name__ == "__main__": app.run(port=settings.HTTP_APP_PORT) diff --git a/examples/demo_actor/demo_actor/demo_actor_service.py b/examples/demo_actor/demo_actor/demo_actor_service.py index d02647421..2fdecf6d7 100644 --- a/examples/demo_actor/demo_actor/demo_actor_service.py +++ b/examples/demo_actor/demo_actor/demo_actor_service.py @@ -17,16 +17,14 @@ from demo_actor import DemoActor -app = FastAPI(title=f'{DemoActor.__name__}Service') +app = FastAPI(title=f"{DemoActor.__name__}Service") # This is an optional advanced configuration which enables reentrancy only for the # specified actor type. By default reentrancy is not enabled for all actor types. 
config = ActorRuntimeConfig() # init with default values -config.update_actor_type_configs([ - ActorTypeConfig( - actor_type=DemoActor.__name__, - reentrancy=ActorReentrancyConfig(enabled=True)) -]) +config.update_actor_type_configs( + [ActorTypeConfig(actor_type=DemoActor.__name__, reentrancy=ActorReentrancyConfig(enabled=True))] +) ActorRuntime.set_actor_config(config) # Add Dapr Actor Extension diff --git a/examples/demo_workflow/app.py b/examples/demo_workflow/app.py index 8fe3c25a7..ef0d6a558 100644 --- a/examples/demo_workflow/app.py +++ b/examples/demo_workflow/app.py @@ -12,7 +12,12 @@ from datetime import timedelta from time import sleep -from dapr.ext.workflow import WorkflowRuntime, DaprWorkflowContext, WorkflowActivityContext, RetryPolicy +from dapr.ext.workflow import ( + WorkflowRuntime, + DaprWorkflowContext, + WorkflowActivityContext, + RetryPolicy, +) from dapr.conf import Settings from dapr.clients import DaprClient from dapr.clients.exceptions import DaprInternalError @@ -34,16 +39,17 @@ event_data = "eventData" non_existent_id_error = "no such instance exists" -retry_policy=RetryPolicy(first_retry_interval=timedelta(seconds=1), - max_number_of_attempts=3, - backoff_coefficient=2, - max_retry_interval=timedelta(seconds=10), - retry_timeout=timedelta(seconds=100) - ) +retry_policy = RetryPolicy( + first_retry_interval=timedelta(seconds=1), + max_number_of_attempts=3, + backoff_coefficient=2, + max_retry_interval=timedelta(seconds=10), + retry_timeout=timedelta(seconds=100), +) def hello_world_wf(ctx: DaprWorkflowContext, wf_input): - print(f'{wf_input}') + print(f"{wf_input}") yield ctx.call_activity(hello_act, input=1) yield ctx.call_activity(hello_act, input=10) yield ctx.call_activity(hello_retryable_act, retry_policy=retry_policy) @@ -56,16 +62,16 @@ def hello_world_wf(ctx: DaprWorkflowContext, wf_input): def hello_act(ctx: WorkflowActivityContext, wf_input): global counter counter += wf_input - print(f'New counter value is: {counter}!', flush=True) + print(f"New counter value is: {counter}!", flush=True) def hello_retryable_act(ctx: WorkflowActivityContext): global retry_count if (retry_count % 2) == 0: - print(f'Retry count value is: {retry_count}!', flush=True) + print(f"Retry count value is: {retry_count}!", flush=True) retry_count += 1 raise ValueError("Retryable Error") - print(f'Retry count value is: {retry_count}! This print statement verifies retry', flush=True) + print(f"Retry count value is: {retry_count}! 
This print statement verifies retry", flush=True) retry_count += 1 @@ -73,19 +79,21 @@ def child_wf(ctx: DaprWorkflowContext): global child_orchestrator_string, child_orchestrator_count if not ctx.is_replaying: child_orchestrator_count += 1 - print(f'Appending {child_orchestrator_count} to child_orchestrator_string!', flush=True) + print(f"Appending {child_orchestrator_count} to child_orchestrator_string!", flush=True) child_orchestrator_string += str(child_orchestrator_count) - yield ctx.call_activity(act_for_child_wf, input=child_orchestrator_count, retry_policy=retry_policy) - if (child_orchestrator_count < 3): + yield ctx.call_activity( + act_for_child_wf, input=child_orchestrator_count, retry_policy=retry_policy + ) + if child_orchestrator_count < 3: raise ValueError("Retryable Error") def act_for_child_wf(ctx: WorkflowActivityContext, inp): global child_orchestrator_string, child_act_retry_count - inp_char = chr(96+inp) - print(f'Appending {inp_char} to child_orchestrator_string!', flush=True) + inp_char = chr(96 + inp) + print(f"Appending {inp_char} to child_orchestrator_string!", flush=True) child_orchestrator_string += inp_char - if (child_act_retry_count %2 == 0): + if child_act_retry_count % 2 == 0: child_act_retry_count += 1 raise ValueError("Retryable Error") child_act_retry_count += 1 @@ -104,10 +112,13 @@ def main(): sleep(2) print("==========Start Counter Increase as per Input:==========") - start_resp = d.start_workflow(instance_id=instance_id, - workflow_component=workflow_component, - workflow_name=workflow_name, input=input_data, - workflow_options=workflow_options) + start_resp = d.start_workflow( + instance_id=instance_id, + workflow_component=workflow_component, + workflow_name=workflow_name, + input=input_data, + workflow_options=workflow_options, + ) print(f"start_resp {start_resp.instance_id}") # Sleep for a while to let the workflow run @@ -118,20 +129,26 @@ def main(): # Pause Test d.pause_workflow(instance_id=instance_id, workflow_component=workflow_component) - get_response = d.get_workflow(instance_id=instance_id, - workflow_component=workflow_component) + get_response = d.get_workflow( + instance_id=instance_id, workflow_component=workflow_component + ) print(f"Get response from {workflow_name} after pause call: {get_response.runtime_status}") # Resume Test d.resume_workflow(instance_id=instance_id, workflow_component=workflow_component) - get_response = d.get_workflow(instance_id=instance_id, - workflow_component=workflow_component) + get_response = d.get_workflow( + instance_id=instance_id, workflow_component=workflow_component + ) print(f"Get response from {workflow_name} after resume call: {get_response.runtime_status}") sleep(1) # Raise event - d.raise_workflow_event(instance_id=instance_id, workflow_component=workflow_component, - event_name=event_name, event_data=event_data) + d.raise_workflow_event( + instance_id=instance_id, + workflow_component=workflow_component, + event_name=event_name, + event_data=event_data, + ) sleep(5) # Purge Test @@ -145,19 +162,25 @@ def main(): # Kick off another workflow for termination purposes # This will also test using the same instance ID on a new workflow after # the old instance was purged - start_resp = d.start_workflow(instance_id=instance_id, - workflow_component=workflow_component, - workflow_name=workflow_name, input=input_data, - workflow_options=workflow_options) + start_resp = d.start_workflow( + instance_id=instance_id, + workflow_component=workflow_component, + workflow_name=workflow_name, + 
input=input_data, + workflow_options=workflow_options, + ) print(f"start_resp {start_resp.instance_id}") # Terminate Test d.terminate_workflow(instance_id=instance_id, workflow_component=workflow_component) sleep(1) - get_response = d.get_workflow(instance_id=instance_id, - workflow_component=workflow_component) + get_response = d.get_workflow( + instance_id=instance_id, workflow_component=workflow_component + ) - print(f"Get response from {workflow_name} " - f"after terminate call: {get_response.runtime_status}") + print( + f"Get response from {workflow_name} " + f"after terminate call: {get_response.runtime_status}" + ) # Purge Test d.purge_workflow(instance_id=instance_id, workflow_component=workflow_component) @@ -170,5 +193,5 @@ def main(): workflow_runtime.shutdown() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/distributed_lock/lock.py b/examples/distributed_lock/lock.py index 4f7ec848c..932d41e15 100644 --- a/examples/distributed_lock/lock.py +++ b/examples/distributed_lock/lock.py @@ -17,28 +17,28 @@ def main(): # Lock parameters - store_name = 'lockstore' # as defined in components/lockstore.yaml - resource_id = 'example-lock-resource' - client_id = 'example-client-id' + store_name = "lockstore" # as defined in components/lockstore.yaml + resource_id = "example-lock-resource" + client_id = "example-client-id" expiry_in_seconds = 60 with DaprClient() as dapr: - print('Will try to acquire a lock from lock store named [%s]' % store_name) - print('The lock is for a resource named [%s]' % resource_id) - print('The client identifier is [%s]' % client_id) - print('The lock will will expire in %s seconds.' % expiry_in_seconds) + print("Will try to acquire a lock from lock store named [%s]" % store_name) + print("The lock is for a resource named [%s]" % resource_id) + print("The client identifier is [%s]" % client_id) + print("The lock will expire in %s seconds." % expiry_in_seconds) with dapr.try_lock(store_name, resource_id, client_id, expiry_in_seconds) as lock_result: - assert lock_result.success, 'Failed to acquire the lock. Aborting.'
+ print("Lock acquired successfully!!!") # At this point the lock was released - by magic of the `with` clause ;) unlock_result = dapr.unlock(store_name, resource_id, client_id) - print('We already released the lock so unlocking will not work.') - print('We tried to unlock it anyway and got back [%s]' % unlock_result.status) + print("We already released the lock so unlocking will not work.") + print("We tried to unlock it anyway and got back [%s]" % unlock_result.status) -if __name__ == '__main__': +if __name__ == "__main__": # Suppress "The Distributed Lock API is an Alpha" warnings warnings.simplefilter("ignore") main() diff --git a/examples/grpc_proxying/helloworld_service_pb2.py b/examples/grpc_proxying/helloworld_service_pb2.py index 0b47317d7..7a8d16756 100644 --- a/examples/grpc_proxying/helloworld_service_pb2.py +++ b/examples/grpc_proxying/helloworld_service_pb2.py @@ -12,36 +12,42 @@ _sym_db = _symbol_database.Default() - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18helloworld_service.proto\"\x1c\n\x0cHelloRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\"\x1d\n\nHelloReply\x12\x0f\n\x07message\x18\x01 \x01(\t2=\n\x11HelloWorldService\x12(\n\x08SayHello\x12\r.HelloRequest\x1a\x0b.HelloReply\"\x00\x62\x06proto3') - - - -_HELLOREQUEST = DESCRIPTOR.message_types_by_name['HelloRequest'] -_HELLOREPLY = DESCRIPTOR.message_types_by_name['HelloReply'] -HelloRequest = _reflection.GeneratedProtocolMessageType('HelloRequest', (_message.Message,), { - 'DESCRIPTOR' : _HELLOREQUEST, - '__module__' : 'helloworld_service_pb2' - # @@protoc_insertion_point(class_scope:HelloRequest) - }) +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x18helloworld_service.proto"\x1c\n\x0cHelloRequest\x12\x0c\n\x04name\x18\x01 \x01(\t"\x1d\n\nHelloReply\x12\x0f\n\x07message\x18\x01 \x01(\t2=\n\x11HelloWorldService\x12(\n\x08SayHello\x12\r.HelloRequest\x1a\x0b.HelloReply"\x00\x62\x06proto3' +) + + +_HELLOREQUEST = DESCRIPTOR.message_types_by_name["HelloRequest"] +_HELLOREPLY = DESCRIPTOR.message_types_by_name["HelloReply"] +HelloRequest = _reflection.GeneratedProtocolMessageType( + "HelloRequest", + (_message.Message,), + { + "DESCRIPTOR": _HELLOREQUEST, + "__module__": "helloworld_service_pb2", + # @@protoc_insertion_point(class_scope:HelloRequest) + }, +) _sym_db.RegisterMessage(HelloRequest) -HelloReply = _reflection.GeneratedProtocolMessageType('HelloReply', (_message.Message,), { - 'DESCRIPTOR' : _HELLOREPLY, - '__module__' : 'helloworld_service_pb2' - # @@protoc_insertion_point(class_scope:HelloReply) - }) +HelloReply = _reflection.GeneratedProtocolMessageType( + "HelloReply", + (_message.Message,), + { + "DESCRIPTOR": _HELLOREPLY, + "__module__": "helloworld_service_pb2", + # @@protoc_insertion_point(class_scope:HelloReply) + }, +) _sym_db.RegisterMessage(HelloReply) -_HELLOWORLDSERVICE = DESCRIPTOR.services_by_name['HelloWorldService'] +_HELLOWORLDSERVICE = DESCRIPTOR.services_by_name["HelloWorldService"] if _descriptor._USE_C_DESCRIPTORS == False: - - DESCRIPTOR._options = None - _HELLOREQUEST._serialized_start=28 - _HELLOREQUEST._serialized_end=56 - _HELLOREPLY._serialized_start=58 - _HELLOREPLY._serialized_end=87 - _HELLOWORLDSERVICE._serialized_start=89 - _HELLOWORLDSERVICE._serialized_end=150 + DESCRIPTOR._options = None + _HELLOREQUEST._serialized_start = 28 + _HELLOREQUEST._serialized_end = 56 + _HELLOREPLY._serialized_start = 58 + _HELLOREPLY._serialized_end = 87 + _HELLOWORLDSERVICE._serialized_start = 89 + _HELLOWORLDSERVICE._serialized_end = 150 # 
@@protoc_insertion_point(module_scope) diff --git a/examples/grpc_proxying/helloworld_service_pb2_grpc.py b/examples/grpc_proxying/helloworld_service_pb2_grpc.py index c54213420..841124b33 100644 --- a/examples/grpc_proxying/helloworld_service_pb2_grpc.py +++ b/examples/grpc_proxying/helloworld_service_pb2_grpc.py @@ -6,8 +6,7 @@ class HelloWorldServiceStub(object): - """The greeting service definition. - """ + """The greeting service definition.""" def __init__(self, channel): """Constructor. @@ -16,55 +15,63 @@ def __init__(self, channel): channel: A grpc.Channel. """ self.SayHello = channel.unary_unary( - '/HelloWorldService/SayHello', - request_serializer=helloworld__service__pb2.HelloRequest.SerializeToString, - response_deserializer=helloworld__service__pb2.HelloReply.FromString, - ) + "/HelloWorldService/SayHello", + request_serializer=helloworld__service__pb2.HelloRequest.SerializeToString, + response_deserializer=helloworld__service__pb2.HelloReply.FromString, + ) class HelloWorldServiceServicer(object): - """The greeting service definition. - """ + """The greeting service definition.""" def SayHello(self, request, context): - """Sends a greeting - """ + """Sends a greeting""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") def add_HelloWorldServiceServicer_to_server(servicer, server): rpc_method_handlers = { - 'SayHello': grpc.unary_unary_rpc_method_handler( - servicer.SayHello, - request_deserializer=helloworld__service__pb2.HelloRequest.FromString, - response_serializer=helloworld__service__pb2.HelloReply.SerializeToString, - ), + "SayHello": grpc.unary_unary_rpc_method_handler( + servicer.SayHello, + request_deserializer=helloworld__service__pb2.HelloRequest.FromString, + response_serializer=helloworld__service__pb2.HelloReply.SerializeToString, + ), } - generic_handler = grpc.method_handlers_generic_handler( - 'HelloWorldService', rpc_method_handlers) + generic_handler = grpc.method_handlers_generic_handler("HelloWorldService", rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) - # This class is part of an EXPERIMENTAL API. +# This class is part of an EXPERIMENTAL API. class HelloWorldService(object): - """The greeting service definition. 
- """ + """The greeting service definition.""" @staticmethod - def SayHello(request, + def SayHello( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/HelloWorldService/SayHello', + "/HelloWorldService/SayHello", helloworld__service__pb2.HelloRequest.SerializeToString, helloworld__service__pb2.HelloReply.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) diff --git a/examples/grpc_proxying/invoke-caller.py b/examples/grpc_proxying/invoke-caller.py index 11d31a323..aec579d47 100644 --- a/examples/grpc_proxying/invoke-caller.py +++ b/examples/grpc_proxying/invoke-caller.py @@ -6,15 +6,16 @@ from helloworld_service_pb2 import HelloRequest, HelloReply import json, time + async def run() -> None: - async with grpc.aio.insecure_channel('127.0.0.1:50007') as channel: - metadata = (('dapr-app-id', 'invoke-receiver'),) + async with grpc.aio.insecure_channel("127.0.0.1:50007") as channel: + metadata = (("dapr-app-id", "invoke-receiver"),) stub = helloworld_service_pb2_grpc.HelloWorldServiceStub(channel) - response = await stub.SayHello(request=HelloRequest(name='you'), metadata=metadata) + response = await stub.SayHello(request=HelloRequest(name="you"), metadata=metadata) print("Greeter client received: " + response.message) -if __name__ == '__main__': - print('I am in main') + +if __name__ == "__main__": + print("I am in main") logging.basicConfig() asyncio.run(run()) - diff --git a/examples/grpc_proxying/invoke-receiver.py b/examples/grpc_proxying/invoke-receiver.py index 2156ca9a4..65301be22 100644 --- a/examples/grpc_proxying/invoke-receiver.py +++ b/examples/grpc_proxying/invoke-receiver.py @@ -6,17 +6,19 @@ from dapr.ext.grpc import App import json + class HelloWorldService(helloworld_service_pb2_grpc.HelloWorldService): - def SayHello( - self, request: HelloRequest, - context: grpc.aio.ServicerContext) -> HelloReply: + def SayHello(self, request: HelloRequest, context: grpc.aio.ServicerContext) -> HelloReply: logging.info(request) - return HelloReply(message='Hello, %s!' % request.name) + return HelloReply(message="Hello, %s!" 
% request.name) + app = App() -if __name__ == '__main__': - print('starting the HelloWorld Service') +if __name__ == "__main__": + print("starting the HelloWorld Service") logging.basicConfig(level=logging.INFO) - app.add_external_service(helloworld_service_pb2_grpc.add_HelloWorldServiceServicer_to_server, HelloWorldService()) - app.run(50051) \ No newline at end of file + app.add_external_service( + helloworld_service_pb2_grpc.add_HelloWorldServiceServicer_to_server, HelloWorldService() + ) + app.run(50051) diff --git a/examples/invoke-binding/invoke-input-binding.py b/examples/invoke-binding/invoke-input-binding.py index a1f086ca9..dd8162cb5 100644 --- a/examples/invoke-binding/invoke-input-binding.py +++ b/examples/invoke-binding/invoke-input-binding.py @@ -2,8 +2,10 @@ app = App() -@app.binding('kafkaBinding') + +@app.binding("kafkaBinding") def binding(request: BindingRequest): print(request.text(), flush=True) -app.run(50051) \ No newline at end of file + +app.run(50051) diff --git a/examples/invoke-binding/invoke-output-binding.py b/examples/invoke-binding/invoke-output-binding.py index a256c80ca..da5918efe 100644 --- a/examples/invoke-binding/invoke-output-binding.py +++ b/examples/invoke-binding/invoke-output-binding.py @@ -7,14 +7,11 @@ n = 0 while True: n += 1 - req_data = { - 'id': n, - 'message': 'hello world' - } + req_data = {"id": n, "message": "hello world"} print(f'Sending message id: {req_data["id"]}, message "{req_data["message"]}"', flush=True) # Create a typed message with content type and body - resp = d.invoke_binding('kafkaBinding', 'create', json.dumps(req_data)) + resp = d.invoke_binding("kafkaBinding", "create", json.dumps(req_data)) time.sleep(1) diff --git a/examples/invoke-custom-data/invoke-caller.py b/examples/invoke-custom-data/invoke-caller.py index 27dabd4de..8a2a0833b 100644 --- a/examples/invoke-custom-data/invoke-caller.py +++ b/examples/invoke-custom-data/invoke-caller.py @@ -5,10 +5,10 @@ with DaprClient() as d: # Create a typed message with content type and body resp = d.invoke_method( - app_id='invoke-receiver', - method_name='my_method', - data=b'SOME_DATA', - content_type='text/plain; charset=UTF-8', + app_id="invoke-receiver", + method_name="my_method", + data=b"SOME_DATA", + content_type="text/plain; charset=UTF-8", ) res = response_messages.CustomResponse() diff --git a/examples/invoke-custom-data/invoke-receiver.py b/examples/invoke-custom-data/invoke-receiver.py index 0d3dc7a87..f05644372 100644 --- a/examples/invoke-custom-data/invoke-receiver.py +++ b/examples/invoke-custom-data/invoke-receiver.py @@ -4,14 +4,15 @@ app = App() -@app.method('my_method') + +@app.method("my_method") def mymethod(request: InvokeMethodRequest): print(request.metadata, flush=True) print(request.text(), flush=True) return response_messages.CustomResponse( - isSuccess=True, - code=200, - message="Hello World - Success!") + isSuccess=True, code=200, message="Hello World - Success!" 
+ ) + app.run(50051) diff --git a/examples/invoke-custom-data/proto/response_pb2.py b/examples/invoke-custom-data/proto/response_pb2.py index 8b04cb277..fa3fddbb7 100644 --- a/examples/invoke-custom-data/proto/response_pb2.py +++ b/examples/invoke-custom-data/proto/response_pb2.py @@ -12,23 +12,25 @@ _sym_db = _symbol_database.Default() - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0eresponse.proto\"B\n\x0e\x43ustomResponse\x12\x11\n\tisSuccess\x18\x01 \x01(\x08\x12\x0c\n\x04\x63ode\x18\x02 \x01(\x05\x12\x0f\n\x07message\x18\x03 \x01(\tb\x06proto3') - - - -_CUSTOMRESPONSE = DESCRIPTOR.message_types_by_name['CustomResponse'] -CustomResponse = _reflection.GeneratedProtocolMessageType('CustomResponse', (_message.Message,), { - 'DESCRIPTOR' : _CUSTOMRESPONSE, - '__module__' : 'response_pb2' - # @@protoc_insertion_point(class_scope:CustomResponse) - }) +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x0eresponse.proto"B\n\x0e\x43ustomResponse\x12\x11\n\tisSuccess\x18\x01 \x01(\x08\x12\x0c\n\x04\x63ode\x18\x02 \x01(\x05\x12\x0f\n\x07message\x18\x03 \x01(\tb\x06proto3' +) + + +_CUSTOMRESPONSE = DESCRIPTOR.message_types_by_name["CustomResponse"] +CustomResponse = _reflection.GeneratedProtocolMessageType( + "CustomResponse", + (_message.Message,), + { + "DESCRIPTOR": _CUSTOMRESPONSE, + "__module__": "response_pb2", + # @@protoc_insertion_point(class_scope:CustomResponse) + }, +) _sym_db.RegisterMessage(CustomResponse) if _descriptor._USE_C_DESCRIPTORS == False: - - DESCRIPTOR._options = None - _CUSTOMRESPONSE._serialized_start=18 - _CUSTOMRESPONSE._serialized_end=84 + DESCRIPTOR._options = None + _CUSTOMRESPONSE._serialized_start = 18 + _CUSTOMRESPONSE._serialized_end = 84 # @@protoc_insertion_point(module_scope) diff --git a/examples/invoke-custom-data/proto/response_pb2_grpc.py b/examples/invoke-custom-data/proto/response_pb2_grpc.py index 2daafffeb..8a9393943 100644 --- a/examples/invoke-custom-data/proto/response_pb2_grpc.py +++ b/examples/invoke-custom-data/proto/response_pb2_grpc.py @@ -1,4 +1,3 @@ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 
"""Client and server classes corresponding to protobuf-defined services.""" import grpc - diff --git a/examples/invoke-http/invoke-caller.py b/examples/invoke-http/invoke-caller.py index c69ac1618..6fbadb587 100644 --- a/examples/invoke-http/invoke-caller.py +++ b/examples/invoke-http/invoke-caller.py @@ -4,17 +4,14 @@ from dapr.clients import DaprClient with DaprClient() as d: - req_data = { - 'id': 1, - 'message': 'hello world' - } + req_data = {"id": 1, "message": "hello world"} while True: # Create a typed message with content type and body resp = d.invoke_method( - 'invoke-receiver', - 'my-method', - http_verb='POST', + "invoke-receiver", + "my-method", + http_verb="POST", data=json.dumps(req_data), ) diff --git a/examples/invoke-http/invoke-receiver.py b/examples/invoke-http/invoke-receiver.py index 2d984206e..3aea8a7e4 100644 --- a/examples/invoke-http/invoke-receiver.py +++ b/examples/invoke-http/invoke-receiver.py @@ -4,12 +4,12 @@ app = Flask(__name__) -@app.route('/my-method', methods=['POST']) + +@app.route("/my-method", methods=["POST"]) def getOrder(): data = request.json - print('Order received : ' + json.dumps(data), flush=True) - return json.dumps({'success': True}), 200, { - 'ContentType': 'application/json'} + print("Order received : " + json.dumps(data), flush=True) + return json.dumps({"success": True}), 200, {"ContentType": "application/json"} -app.run(port=8088) \ No newline at end of file +app.run(port=8088) diff --git a/examples/invoke-simple/invoke-caller.py b/examples/invoke-simple/invoke-caller.py index f49371e56..911f75160 100644 --- a/examples/invoke-simple/invoke-caller.py +++ b/examples/invoke-simple/invoke-caller.py @@ -4,16 +4,13 @@ from dapr.clients import DaprClient with DaprClient() as d: - req_data = { - 'id': 1, - 'message': 'hello world' - } + req_data = {"id": 1, "message": "hello world"} while True: # Create a typed message with content type and body resp = d.invoke_method( - 'invoke-receiver', - 'my-method', + "invoke-receiver", + "my-method", data=json.dumps(req_data), ) diff --git a/examples/invoke-simple/invoke-receiver.py b/examples/invoke-simple/invoke-receiver.py index 941e4fc5c..debb92676 100644 --- a/examples/invoke-simple/invoke-receiver.py +++ b/examples/invoke-simple/invoke-receiver.py @@ -2,11 +2,13 @@ app = App() -@app.method(name='my-method') + +@app.method(name="my-method") def mymethod(request: InvokeMethodRequest) -> InvokeMethodResponse: print(request.metadata, flush=True) print(request.text(), flush=True) - return InvokeMethodResponse(b'INVOKE_RECEIVED', "text/plain; charset=UTF-8") + return InvokeMethodResponse(b"INVOKE_RECEIVED", "text/plain; charset=UTF-8") + app.run(50051) diff --git a/examples/metadata/app.py b/examples/metadata/app.py index 63b205877..3ff66eee8 100644 --- a/examples/metadata/app.py +++ b/examples/metadata/app.py @@ -15,13 +15,13 @@ def main(): - extended_attribute_name = 'is-this-our-metadata-example' + extended_attribute_name = "is-this-our-metadata-example" with DaprClient() as dapr: print("First, we will assign a new custom label to Dapr sidecar") # We do this so example can be made deterministic across # multiple invocations. 
- original_value = 'yes' + original_value = "yes" dapr.set_metadata(extended_attribute_name, original_value) print("Now, we will fetch the sidecar's metadata") @@ -36,7 +36,7 @@ def main(): print(f" name={name} type={type} version={version} capabilities={sorted(caps)}") print("We will update our custom label value and check it was persisted") - dapr.set_metadata(extended_attribute_name, 'You bet it is!') + dapr.set_metadata(extended_attribute_name, "You bet it is!") metadata = dapr.get_metadata() new_value = metadata.extended_metadata[extended_attribute_name] print("We added a custom label named [%s]" % extended_attribute_name) @@ -45,5 +45,5 @@ def main(): print("And we are done 👋", flush=True) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/pubsub-simple/publisher.py b/examples/pubsub-simple/publisher.py index 10d814b31..8b51b1ea4 100644 --- a/examples/pubsub-simple/publisher.py +++ b/examples/pubsub-simple/publisher.py @@ -20,17 +20,14 @@ id = 0 while id < 3: id += 1 - req_data = { - 'id': id, - 'message': 'hello world' - } + req_data = {"id": id, "message": "hello world"} # Create a typed message with content type and body resp = d.publish_event( - pubsub_name='pubsub', - topic_name='TOPIC_A', + pubsub_name="pubsub", + topic_name="TOPIC_A", data=json.dumps(req_data), - data_content_type='application/json', + data_content_type="application/json", ) # Print the request @@ -44,15 +41,12 @@ id = 3 while id < 6: id += 1 - req_data = { - 'id': id, - 'message': 'hello world' - } + req_data = {"id": id, "message": "hello world"} resp = d.publish_event( - pubsub_name='pubsub', - topic_name=f'topic/{id}', + pubsub_name="pubsub", + topic_name=f"topic/{id}", data=json.dumps(req_data), - data_content_type='application/json', + data_content_type="application/json", ) # Print the request @@ -61,13 +55,13 @@ time.sleep(0.5) # This topic will fail - initiate a retry which gets routed to the dead letter topic - req_data['id'] = 7 + req_data["id"] = 7 resp = d.publish_event( - pubsub_name='pubsub', - topic_name='TOPIC_D', + pubsub_name="pubsub", + topic_name="TOPIC_D", data=json.dumps(req_data), - data_content_type='application/json', - publish_metadata={'custommeta': 'somevalue'} + data_content_type="application/json", + publish_metadata={"custommeta": "somevalue"}, ) # Print the request diff --git a/examples/pubsub-simple/subscriber.py b/examples/pubsub-simple/subscriber.py index 17a03e6be..031adf5bb 100644 --- a/examples/pubsub-simple/subscriber.py +++ b/examples/pubsub-simple/subscriber.py @@ -23,52 +23,64 @@ should_retry = True # To control whether dapr should retry sending a message -@app.subscribe(pubsub_name='pubsub', topic='TOPIC_A') +@app.subscribe(pubsub_name="pubsub", topic="TOPIC_A") def mytopic(event: v1.Event) -> TopicEventResponse: global should_retry data = json.loads(event.Data()) - print(f'Subscriber received: id={data["id"]}, message="{data["message"]}", ' - f'content_type="{event.content_type}"', flush=True) + print( + f'Subscriber received: id={data["id"]}, message="{data["message"]}", ' + f'content_type="{event.content_type}"', + flush=True, + ) # event.Metadata() contains a dictionary of cloud event extensions and publish metadata if should_retry: should_retry = False # we only retry once in this example sleep(0.5) # add some delay to help with ordering of expected logs - return TopicEventResponse('retry') - return TopicEventResponse('success') + return TopicEventResponse("retry") + return TopicEventResponse("success") 
-@app.subscribe(pubsub_name='pubsub', topic='TOPIC_D', dead_letter_topic='TOPIC_D_DEAD') +@app.subscribe(pubsub_name="pubsub", topic="TOPIC_D", dead_letter_topic="TOPIC_D_DEAD") def fail_and_send_to_dead_topic(event: v1.Event) -> TopicEventResponse: - return TopicEventResponse('retry') + return TopicEventResponse("retry") -@app.subscribe(pubsub_name='pubsub', topic='TOPIC_D_DEAD') +@app.subscribe(pubsub_name="pubsub", topic="TOPIC_D_DEAD") def mytopic_dead(event: v1.Event) -> TopicEventResponse: data = json.loads(event.Data()) - print(f'Dead-Letter Subscriber received: id={data["id"]}, message="{data["message"]}", ' - f'content_type="{event.content_type}"', flush=True) + print( + f'Dead-Letter Subscriber received: id={data["id"]}, message="{data["message"]}", ' + f'content_type="{event.content_type}"', + flush=True, + ) print("Dead-Letter Subscriber. Received via deadletter topic: " + event.Subject(), flush=True) - print("Dead-Letter Subscriber. Originally intended topic: " + event.Extensions()['topic'], - flush=True) - return TopicEventResponse('success') + print( + "Dead-Letter Subscriber. Originally intended topic: " + event.Extensions()["topic"], + flush=True, + ) + return TopicEventResponse("success") # == for testing with Redis only == # workaround as redis pubsub does not support wildcards # we manually register the distinct topics for id in range(4, 7): - app._servicer._registered_topics.append(appcallback_v1.TopicSubscription( - pubsub_name='pubsub', topic=f'topic/{id}')) + app._servicer._registered_topics.append( + appcallback_v1.TopicSubscription(pubsub_name="pubsub", topic=f"topic/{id}") + ) # ================================= # this allows subscribing to all events sent to this app - useful for wildcard topics -@app.subscribe(pubsub_name='pubsub', topic='topic/#', disable_topic_validation=True) +@app.subscribe(pubsub_name="pubsub", topic="topic/#", disable_topic_validation=True) def mytopic_wildcard(event: v1.Event) -> TopicEventResponse: data = json.loads(event.Data()) - print(f'Wildcard-Subscriber received: id={data["id"]}, message="{data["message"]}", ' - f'content_type="{event.content_type}"', flush=True) - return TopicEventResponse('success') + print( + f'Wildcard-Subscriber received: id={data["id"]}, message="{data["message"]}", ' + f'content_type="{event.content_type}"', + flush=True, + ) + return TopicEventResponse("success") app.run(50051) diff --git a/examples/secret_store/example.py b/examples/secret_store/example.py index 7b30d4e46..4375fa9ea 100644 --- a/examples/secret_store/example.py +++ b/examples/secret_store/example.py @@ -14,21 +14,20 @@ from dapr.clients import DaprClient with DaprClient() as d: - key = 'secretKey' + key = "secretKey" randomKey = "random" - storeName = 'localsecretstore' + storeName = "localsecretstore" resp = d.get_secret(store_name=storeName, key=key) - print('Got!') + print("Got!") print(resp.secret) resp = d.get_bulk_secret(store_name=storeName) - print('Got!') + print("Got!") # Converts dict into sorted list of tuples for deterministic output. 
print(sorted(resp.secrets.items())) try: resp = d.get_secret(store_name=storeName, key=randomKey) - print('Got!') + print("Got!") print(resp.secret) except: print("Got expected error for accessing random key") - diff --git a/examples/state_store/state_store.py b/examples/state_store/state_store.py index 1054934d2..5bf9f4c60 100644 --- a/examples/state_store/state_store.py +++ b/examples/state_store/state_store.py @@ -1,4 +1,3 @@ - """ dapr run python3 state_store.py """ @@ -11,7 +10,7 @@ from dapr.clients.grpc._state import StateItem with DaprClient() as d: - storeName = 'statestore' + storeName = "statestore" key = "key_1" value = "value_1" @@ -41,20 +40,29 @@ # print(f"Details={err.details()}) # Save multiple states. - d.save_bulk_state(store_name=storeName, states=[StateItem(key=another_key, value=another_value), - StateItem(key=yet_another_key, value=yet_another_value)]) + d.save_bulk_state( + store_name=storeName, + states=[ + StateItem(key=another_key, value=another_value), + StateItem(key=yet_another_key, value=yet_another_value), + ], + ) print(f"State store has successfully saved {another_value} with {another_key} as key") print(f"State store has successfully saved {yet_another_value} with {yet_another_key} as key") # Save bulk with etag that is different from the one stored in the database. try: - d.save_bulk_state(store_name=storeName, states=[ - StateItem(key=another_key, value=another_value, etag="999"), - StateItem(key=yet_another_key, value=yet_another_value, etag="999")]) + d.save_bulk_state( + store_name=storeName, + states=[ + StateItem(key=another_key, value=another_value, etag="999"), + StateItem(key=yet_another_key, value=yet_another_value, etag="999"), + ], + ) except grpc.RpcError as err: # StatusCode should be StatusCode.ABORTED. print(f"Cannot save bulk due to bad etags. ErrorCode={err.code()}") - + # For detailed error messages from the dapr runtime: # print(f"Details={err.details()}) @@ -63,17 +71,23 @@ print(f"Got value={state.data} eTag={state.etag}") # Transaction upsert - d.execute_state_transaction(store_name=storeName, operations=[ - TransactionalStateOperation( - operation_type=TransactionOperationType.upsert, - key=key, - data=updated_value, - etag=state.etag), - TransactionalStateOperation(key=another_key, data=another_value), - ]) + d.execute_state_transaction( + store_name=storeName, + operations=[ + TransactionalStateOperation( + operation_type=TransactionOperationType.upsert, + key=key, + data=updated_value, + etag=state.etag, + ), + TransactionalStateOperation(key=another_key, data=another_value), + ], + ) # Batch get - items = d.get_bulk_state(store_name=storeName, keys=[key, another_key], states_metadata={"metakey": "metavalue"}).items + items = d.get_bulk_state( + store_name=storeName, keys=[key, another_key], states_metadata={"metakey": "metavalue"} + ).items print(f"Got items with etags: {[(i.data, i.etag) for i in items]}") # Delete one state by key. diff --git a/examples/state_store_query/state_store_query.py b/examples/state_store_query/state_store_query.py index eaa5a2fe0..3b273fbc7 100644 --- a/examples/state_store_query/state_store_query.py +++ b/examples/state_store_query/state_store_query.py @@ -1,4 +1,3 @@ - """ dapr run python3 state_store_query.py """ @@ -9,23 +8,23 @@ import json with DaprClient() as d: - storeName = 'statestore' + storeName = "statestore" # Wait for sidecar to be up within 5 seconds. 
d.wait(5) # Query the state store - query = open('query.json', 'r').read() + query = open("query.json", "r").read() res = d.query_state(store_name=storeName, query=query) for r in res.results: - print(r.key, json.dumps(json.loads(str(r.value, 'UTF-8')), sort_keys=True)) + print(r.key, json.dumps(json.loads(str(r.value, "UTF-8")), sort_keys=True)) print("Token:", res.token) # Get more results using a pagination token - query = open('query-token.json', 'r').read() + query = open("query-token.json", "r").read() res = d.query_state(store_name=storeName, query=query) for r in res.results: - print(r.key, json.dumps(json.loads(str(r.value, 'UTF-8')), sort_keys=True)) - print("Token:", res.token) \ No newline at end of file + print(r.key, json.dumps(json.loads(str(r.value, "UTF-8")), sort_keys=True)) + print("Token:", res.token) diff --git a/examples/w3c-tracing/invoke-caller.py b/examples/w3c-tracing/invoke-caller.py index 9817daced..aef558f4d 100644 --- a/examples/w3c-tracing/invoke-caller.py +++ b/examples/w3c-tracing/invoke-caller.py @@ -7,33 +7,29 @@ from opencensus.trace.samplers import AlwaysOnSampler ze = ZipkinExporter( - service_name="python-example", - host_name='localhost', - port=9411, - endpoint='/api/v2/spans') + service_name="python-example", host_name="localhost", port=9411, endpoint="/api/v2/spans" +) tracer = Tracer(exporter=ze, sampler=AlwaysOnSampler()) with tracer.span(name="main") as span: - with DaprClient(headers_callback=lambda: tracer.propagator.to_headers(tracer.span_context)) as d: - + with DaprClient( + headers_callback=lambda: tracer.propagator.to_headers(tracer.span_context) + ) as d: num_messages = 2 for i in range(num_messages): # Create a typed message with content type and body resp = d.invoke_method( - 'invoke-receiver', - 'say', - data=json.dumps({ - 'id': i, - 'message': 'hello world' - }), + "invoke-receiver", + "say", + data=json.dumps({"id": i, "message": "hello world"}), ) # Print the response print(resp.content_type, flush=True) print(resp.text(), flush=True) - resp = d.invoke_method('invoke-receiver', 'sleep', data='') + resp = d.invoke_method("invoke-receiver", "sleep", data="") # Print the response print(resp.content_type, flush=True) print(resp.text(), flush=True) diff --git a/examples/w3c-tracing/invoke-receiver.py b/examples/w3c-tracing/invoke-receiver.py index 2991aa151..3aa0e967b 100644 --- a/examples/w3c-tracing/invoke-receiver.py +++ b/examples/w3c-tracing/invoke-receiver.py @@ -11,29 +11,31 @@ tracer_interceptor = server_interceptor.OpenCensusServerInterceptor(AlwaysOnSampler()) app = App( - thread_pool=futures.ThreadPoolExecutor(max_workers=10), - interceptors=(tracer_interceptor,)) + thread_pool=futures.ThreadPoolExecutor(max_workers=10), interceptors=(tracer_interceptor,) +) -@app.method(name='say') + +@app.method(name="say") def say(request: InvokeMethodRequest) -> InvokeMethodResponse: tracer = Tracer(sampler=AlwaysOnSampler()) - with tracer.span(name='say') as span: + with tracer.span(name="say") as span: data = request.text() - span.add_annotation('Request length', len=len(data)) + span.add_annotation("Request length", len=len(data)) print(request.metadata, flush=True) print(request.text(), flush=True) - return InvokeMethodResponse(b'SAY', "text/plain; charset=UTF-8") + return InvokeMethodResponse(b"SAY", "text/plain; charset=UTF-8") + -@app.method(name='sleep') +@app.method(name="sleep") def sleep(request: InvokeMethodRequest) -> InvokeMethodResponse: tracer = Tracer(sampler=AlwaysOnSampler()) - with tracer.span(name='sleep') as _: + 
with tracer.span(name="sleep") as _: time.sleep(2) print(request.metadata, flush=True) print(request.text(), flush=True) - return InvokeMethodResponse(b'SLEEP', "text/plain; charset=UTF-8") + return InvokeMethodResponse(b"SLEEP", "text/plain; charset=UTF-8") app.run(3001) diff --git a/examples/workflow/child_workflow.py b/examples/workflow/child_workflow.py index b4d2cf274..48dd95045 100644 --- a/examples/workflow/child_workflow.py +++ b/examples/workflow/child_workflow.py @@ -15,26 +15,31 @@ wfr = wf.WorkflowRuntime() + @wfr.workflow def main_workflow(ctx: wf.DaprWorkflowContext): try: instance_id = ctx.instance_id - child_instance_id = instance_id + '-child' - print(f'*** Calling child workflow {child_instance_id}', flush=True) - yield ctx.call_child_workflow(workflow=child_workflow,input=None,instance_id=child_instance_id) + child_instance_id = instance_id + "-child" + print(f"*** Calling child workflow {child_instance_id}", flush=True) + yield ctx.call_child_workflow( + workflow=child_workflow, input=None, instance_id=child_instance_id + ) except Exception as e: - print(f'*** Exception: {e}') + print(f"*** Exception: {e}") return + @wfr.workflow def child_workflow(ctx: wf.DaprWorkflowContext): instance_id = ctx.instance_id - print(f'*** Child workflow {instance_id} called', flush=True) + print(f"*** Child workflow {instance_id} called", flush=True) + -if __name__ == '__main__': +if __name__ == "__main__": wfr.start() - time.sleep(10) # wait for workflow runtime to start + time.sleep(10) # wait for workflow runtime to start wf_client = wf.DaprWorkflowClient() instance_id = wf_client.schedule_new_workflow(workflow=main_workflow) @@ -42,4 +47,4 @@ def child_workflow(ctx: wf.DaprWorkflowContext): # Wait for the workflow to complete time.sleep(5) - wfr.shutdown() \ No newline at end of file + wfr.shutdown() diff --git a/examples/workflow/fan_out_fan_in.py b/examples/workflow/fan_out_fan_in.py index dff17c17f..eb7513d9e 100644 --- a/examples/workflow/fan_out_fan_in.py +++ b/examples/workflow/fan_out_fan_in.py @@ -16,45 +16,49 @@ wfr = wf.WorkflowRuntime() + @wfr.workflow(name="batch_processing") def batch_processing_workflow(ctx: wf.DaprWorkflowContext, wf_input: int): # get a batch of N work items to process in parallel work_batch = yield ctx.call_activity(get_work_batch, input=wf_input) # schedule N parallel tasks to process the work items and wait for all to complete - parallel_tasks = [ctx.call_activity(process_work_item, input=work_item) for work_item in work_batch] + parallel_tasks = [ + ctx.call_activity(process_work_item, input=work_item) for work_item in work_batch + ] outputs = yield wf.when_all(parallel_tasks) # aggregate the results and send them to another activity total = sum(outputs) yield ctx.call_activity(process_results, input=total) + @wfr.activity(name="get_batch") def get_work_batch(ctx, batch_size: int) -> List[int]: return [i + 1 for i in range(batch_size)] + @wfr.activity def process_work_item(ctx, work_item: int) -> int: - print(f'Processing work item: {work_item}.') + print(f"Processing work item: {work_item}.") time.sleep(5) result = work_item * 2 - print(f'Work item {work_item} processed. Result: {result}.') + print(f"Work item {work_item} processed. 
Result: {result}.") return result + @wfr.activity(name="final_process") def process_results(ctx, final_result: int): - print(f'Final result: {final_result}.') + print(f"Final result: {final_result}.") -if __name__ == '__main__': +if __name__ == "__main__": wfr.start() - time.sleep(10) # wait for workflow runtime to start + time.sleep(10) # wait for workflow runtime to start wf_client = wf.DaprWorkflowClient() - instance_id = wf_client.schedule_new_workflow( - workflow=batch_processing_workflow, - input=10) - print(f'Workflow started. Instance ID: {instance_id}') + instance_id = wf_client.schedule_new_workflow(workflow=batch_processing_workflow, input=10) + print(f"Workflow started. Instance ID: {instance_id}") state = wf_client.wait_for_workflow_completion(instance_id) wfr.shutdown() diff --git a/examples/workflow/human_approval.py b/examples/workflow/human_approval.py index f67eca0b1..5cd16bc45 100644 --- a/examples/workflow/human_approval.py +++ b/examples/workflow/human_approval.py @@ -20,6 +20,7 @@ wfr = wf.WorkflowRuntime() + @dataclass class Order: cost: float @@ -27,7 +28,7 @@ class Order: quantity: int def __str__(self): - return f'{self.product} ({self.quantity})' + return f"{self.product} ({self.quantity})" @dataclass @@ -38,6 +39,7 @@ class Approval: def from_dict(dict): return Approval(**dict) + @wfr.workflow(name="purchase_order_wf") def purchase_order_workflow(ctx: wf.DaprWorkflowContext, order: Order): # Orders under $1000 are auto-approved @@ -62,12 +64,12 @@ def purchase_order_workflow(ctx: wf.DaprWorkflowContext, order: Order): @wfr.activity(name="send_approval") def send_approval_request(_, order: Order) -> None: - print(f'*** Requesting approval from user for order: {order}') + print(f"*** Requesting approval from user for order: {order}") @wfr.activity def place_order(_, order: Order) -> None: - print(f'*** Placing order: {order}') + print(f"*** Placing order: {order}") if __name__ == "__main__": @@ -86,9 +88,7 @@ def place_order(_, order: Order) -> None: order = Order(args.cost, "MyProduct", 1) wf_client = wf.DaprWorkflowClient() - instance_id = wf_client.schedule_new_workflow( - workflow=purchase_order_workflow, - input=order) + instance_id = wf_client.schedule_new_workflow(workflow=purchase_order_workflow, input=order) def prompt_for_approval(): # Give the workflow time to start up and notify the user @@ -99,7 +99,8 @@ def prompt_for_approval(): instance_id=instance_id, workflow_component="dapr", event_name="approval_received", - event_data=asdict(Approval(args.approver))) + event_data=asdict(Approval(args.approver)), + ) # Prompt the user for approval on a background thread threading.Thread(target=prompt_for_approval, daemon=True).start() @@ -107,14 +108,14 @@ def prompt_for_approval(): # Wait for the orchestration to complete try: state = wf_client.wait_for_workflow_completion( - instance_id, - timeout_in_seconds=args.timeout + 2) + instance_id, timeout_in_seconds=args.timeout + 2 + ) if not state: print("Workflow not found!") # not expected - elif state.runtime_status.name == 'COMPLETED': - print(f'Workflow completed! Result: {state.serialized_output}') + elif state.runtime_status.name == "COMPLETED": + print(f"Workflow completed! Result: {state.serialized_output}") else: - print(f'Workflow failed! Status: {state.runtime_status.name}') # not expected + print(f"Workflow failed! 
Status: {state.runtime_status.name}") # not expected except TimeoutError: print("*** Workflow timed out!") diff --git a/examples/workflow/monitor.py b/examples/workflow/monitor.py index 606319705..5826988ef 100644 --- a/examples/workflow/monitor.py +++ b/examples/workflow/monitor.py @@ -17,6 +17,8 @@ import dapr.ext.workflow as wf wfr = wf.WorkflowRuntime() + + @dataclass class JobStatus: job_id: str @@ -52,10 +54,10 @@ def check_status(ctx, _) -> str: @wfr.activity def send_alert(ctx, message: str): - print(f'*** Alert: {message}') + print(f"*** Alert: {message}") -if __name__ == '__main__': +if __name__ == "__main__": wfr.start() sleep(10) # wait for workflow runtime to start @@ -66,14 +68,15 @@ def send_alert(ctx, message: str): status = wf_client.get_workflow_state(job_id) except Exception: pass - if not status or status.runtime_status.name != 'RUNNING': + if not status or status.runtime_status.name != "RUNNING": instance_id = wf_client.schedule_new_workflow( workflow=status_monitor_workflow, input=JobStatus(job_id=job_id, is_healthy=True), - instance_id=job_id) - print(f'Workflow started. Instance ID: {instance_id}') + instance_id=job_id, + ) + print(f"Workflow started. Instance ID: {instance_id}") else: - print(f'Workflow already running. Instance ID: {job_id}') + print(f"Workflow already running. Instance ID: {job_id}") input("Press Enter to stop...\n") wfr.shutdown() diff --git a/examples/workflow/task_chaining.py b/examples/workflow/task_chaining.py index aeefd2f07..b04a6bf39 100644 --- a/examples/workflow/task_chaining.py +++ b/examples/workflow/task_chaining.py @@ -17,6 +17,7 @@ wfr = wf.WorkflowRuntime() + @wfr.workflow(name="random_workflow") def task_chain_workflow(ctx: wf.DaprWorkflowContext, wf_input: int): try: @@ -31,41 +32,39 @@ def task_chain_workflow(ctx: wf.DaprWorkflowContext, wf_input: int): @wfr.activity(name="step10") def step1(ctx, activity_input): - print(f'Step 1: Received input: {activity_input}.') + print(f"Step 1: Received input: {activity_input}.") # Do some work return activity_input + 1 @wfr.activity def step2(ctx, activity_input): - print(f'Step 2: Received input: {activity_input}.') + print(f"Step 2: Received input: {activity_input}.") # Do some work return activity_input * 2 @wfr.activity def step3(ctx, activity_input): - print(f'Step 3: Received input: {activity_input}.') + print(f"Step 3: Received input: {activity_input}.") # Do some work return activity_input ^ 2 @wfr.activity def error_handler(ctx, error): - print(f'Executing error handler: {error}.') + print(f"Executing error handler: {error}.") # Do some compensating work -if __name__ == '__main__': +if __name__ == "__main__": wfr.start() - sleep(10) # wait for workflow runtime to start + sleep(10) # wait for workflow runtime to start wf_client = wf.DaprWorkflowClient() - instance_id = wf_client.schedule_new_workflow( - workflow=task_chain_workflow, - input=42) - print(f'Workflow started. Instance ID: {instance_id}') + instance_id = wf_client.schedule_new_workflow(workflow=task_chain_workflow, input=42) + print(f"Workflow started. Instance ID: {instance_id}") state = wf_client.wait_for_workflow_completion(instance_id) - print(f'Workflow completed! Status: {state.runtime_status}') + print(f"Workflow completed! 
Status: {state.runtime_status}") wfr.shutdown() diff --git a/ext/dapr-ext-fastapi/dapr/ext/fastapi/__init__.py b/ext/dapr-ext-fastapi/dapr/ext/fastapi/__init__.py index 026063f4a..8c6168953 100644 --- a/ext/dapr-ext-fastapi/dapr/ext/fastapi/__init__.py +++ b/ext/dapr-ext-fastapi/dapr/ext/fastapi/__init__.py @@ -17,7 +17,4 @@ from .app import DaprApp -__all__ = [ - 'DaprActor', - 'DaprApp' -] +__all__ = ["DaprActor", "DaprApp"] diff --git a/ext/dapr-ext-fastapi/dapr/ext/fastapi/actor.py b/ext/dapr-ext-fastapi/dapr/ext/fastapi/actor.py index 8ce441200..ec18f5d7b 100644 --- a/ext/dapr-ext-fastapi/dapr/ext/fastapi/actor.py +++ b/ext/dapr-ext-fastapi/dapr/ext/fastapi/actor.py @@ -15,7 +15,7 @@ from typing import Any, Optional, Type, List -from fastapi import FastAPI, APIRouter, Request, Response, status # type: ignore +from fastapi import FastAPI, APIRouter, Request, Response, status # type: ignore from fastapi.logger import logger from fastapi.responses import JSONResponse @@ -24,21 +24,22 @@ from dapr.serializers import DefaultJSONSerializer DEFAULT_CONTENT_TYPE = "application/json; utf-8" -DAPR_REENTRANCY_ID_HEADER = 'Dapr-Reentrancy-Id' +DAPR_REENTRANCY_ID_HEADER = "Dapr-Reentrancy-Id" def _wrap_response( - status_code: int, - msg: Any, - error_code: Optional[str] = None, - content_type: Optional[str] = DEFAULT_CONTENT_TYPE): + status_code: int, + msg: Any, + error_code: Optional[str] = None, + content_type: Optional[str] = DEFAULT_CONTENT_TYPE, +): resp = None if isinstance(msg, str): response_obj = { - 'message': msg, + "message": msg, } if not (status_code >= 200 and status_code < 300) and error_code: - response_obj['errorCode'] = error_code + response_obj["errorCode"] = error_code resp = JSONResponse(content=response_obj, status_code=status_code) elif isinstance(msg, bytes): resp = Response(content=msg, media_type=content_type) @@ -48,9 +49,7 @@ def _wrap_response( class DaprActor(object): - - def __init__(self, app: FastAPI, - router_tags: Optional[List[str]] = ['Actor']): + def __init__(self, app: FastAPI, router_tags: Optional[List[str]] = ["Actor"]): # router_tags should be added to all magic Dapr Actor methods implemented here self._router_tags = router_tags self._router = APIRouter() @@ -61,108 +60,95 @@ def __init__(self, app: FastAPI, def init_routes(self, router: APIRouter): @router.get("/healthz", tags=self._router_tags) async def healthz(): - return {'status': 'ok'} + return {"status": "ok"} - @router.get('/dapr/config', tags=self._router_tags) + @router.get("/dapr/config", tags=self._router_tags) async def dapr_config(): serialized = self._dapr_serializer.serialize(ActorRuntime.get_actor_config()) return _wrap_response(status.HTTP_200_OK, serialized) - @router.delete('/actors/{actor_type_name}/{actor_id}', tags=self._router_tags) + @router.delete("/actors/{actor_type_name}/{actor_id}", tags=self._router_tags) async def actor_deactivation(actor_type_name: str, actor_id: str): try: await ActorRuntime.deactivate(actor_type_name, actor_id) except DaprInternalError as ex: - return _wrap_response( - status.HTTP_500_INTERNAL_SERVER_ERROR, - ex.as_dict()) + return _wrap_response(status.HTTP_500_INTERNAL_SERVER_ERROR, ex.as_dict()) except Exception as ex: return _wrap_response( - status.HTTP_500_INTERNAL_SERVER_ERROR, - repr(ex), - ERROR_CODE_UNKNOWN) + status.HTTP_500_INTERNAL_SERVER_ERROR, repr(ex), ERROR_CODE_UNKNOWN + ) - msg = f'deactivated actor: {actor_type_name}.{actor_id}' + msg = f"deactivated actor: {actor_type_name}.{actor_id}" logger.debug(msg) return 
_wrap_response(status.HTTP_200_OK, msg) - @router.put('/actors/{actor_type_name}/{actor_id}/method/{method_name}', - tags=self._router_tags) + @router.put( + "/actors/{actor_type_name}/{actor_id}/method/{method_name}", tags=self._router_tags + ) async def actor_method( - actor_type_name: str, - actor_id: str, - method_name: str, - request: Request): + actor_type_name: str, actor_id: str, method_name: str, request: Request + ): try: # Read raw bytes from request stream req_body = await request.body() reentrancy_id = request.headers.get(DAPR_REENTRANCY_ID_HEADER) result = await ActorRuntime.dispatch( - actor_type_name, actor_id, method_name, req_body, reentrancy_id) + actor_type_name, actor_id, method_name, req_body, reentrancy_id + ) except DaprInternalError as ex: - return _wrap_response( - status.HTTP_500_INTERNAL_SERVER_ERROR, ex.as_dict()) + return _wrap_response(status.HTTP_500_INTERNAL_SERVER_ERROR, ex.as_dict()) except Exception as ex: return _wrap_response( - status.HTTP_500_INTERNAL_SERVER_ERROR, - repr(ex), - ERROR_CODE_UNKNOWN) + status.HTTP_500_INTERNAL_SERVER_ERROR, repr(ex), ERROR_CODE_UNKNOWN + ) - msg = f'called method. actor: {actor_type_name}.{actor_id}, method: {method_name}' + msg = f"called method. actor: {actor_type_name}.{actor_id}, method: {method_name}" logger.debug(msg) return _wrap_response(status.HTTP_200_OK, result) - @router.put('/actors/{actor_type_name}/{actor_id}/method/timer/{timer_name}', - tags=self._router_tags) + @router.put( + "/actors/{actor_type_name}/{actor_id}/method/timer/{timer_name}", tags=self._router_tags + ) async def actor_timer( - actor_type_name: str, - actor_id: str, - timer_name: str, - request: Request): + actor_type_name: str, actor_id: str, timer_name: str, request: Request + ): try: # Read raw bytes from request stream req_body = await request.body() await ActorRuntime.fire_timer(actor_type_name, actor_id, timer_name, req_body) except DaprInternalError as ex: - return _wrap_response( - status.HTTP_500_INTERNAL_SERVER_ERROR, - ex.as_dict()) + return _wrap_response(status.HTTP_500_INTERNAL_SERVER_ERROR, ex.as_dict()) except Exception as ex: return _wrap_response( - status.HTTP_500_INTERNAL_SERVER_ERROR, - repr(ex), - ERROR_CODE_UNKNOWN) + status.HTTP_500_INTERNAL_SERVER_ERROR, repr(ex), ERROR_CODE_UNKNOWN + ) - msg = f'called timer. actor: {actor_type_name}.{actor_id}, timer: {timer_name}' + msg = f"called timer. 
actor: {actor_type_name}.{actor_id}, timer: {timer_name}" logger.debug(msg) return _wrap_response(status.HTTP_200_OK, msg) - @router.put('/actors/{actor_type_name}/{actor_id}/method/remind/{reminder_name}', - tags=self._router_tags) + @router.put( + "/actors/{actor_type_name}/{actor_id}/method/remind/{reminder_name}", + tags=self._router_tags, + ) async def actor_reminder( - actor_type_name: str, - actor_id: str, - reminder_name: str, - request: Request): + actor_type_name: str, actor_id: str, reminder_name: str, request: Request + ): try: # Read raw bytes from request stream req_body = await request.body() - await ActorRuntime.fire_reminder( - actor_type_name, actor_id, reminder_name, req_body) + await ActorRuntime.fire_reminder(actor_type_name, actor_id, reminder_name, req_body) except DaprInternalError as ex: - return _wrap_response( - status.HTTP_500_INTERNAL_SERVER_ERROR, - ex.as_dict()) + return _wrap_response(status.HTTP_500_INTERNAL_SERVER_ERROR, ex.as_dict()) except Exception as ex: return _wrap_response( - status.HTTP_500_INTERNAL_SERVER_ERROR, - repr(ex), - ERROR_CODE_UNKNOWN) + status.HTTP_500_INTERNAL_SERVER_ERROR, repr(ex), ERROR_CODE_UNKNOWN + ) - msg = f'called reminder. actor: {actor_type_name}.{actor_id}, reminder: {reminder_name}' + msg = f"called reminder. actor: {actor_type_name}.{actor_id}, reminder: {reminder_name}" logger.debug(msg) return _wrap_response(status.HTTP_200_OK, msg) async def register_actor(self, actor: Type[Actor]) -> None: await ActorRuntime.register_actor(actor) - logger.debug(f'registered actor: {actor.__class__.__name__}') + logger.debug(f"registered actor: {actor.__class__.__name__}") diff --git a/ext/dapr-ext-fastapi/dapr/ext/fastapi/app.py b/ext/dapr-ext-fastapi/dapr/ext/fastapi/app.py index 8fac13199..5773beb25 100644 --- a/ext/dapr-ext-fastapi/dapr/ext/fastapi/app.py +++ b/ext/dapr-ext-fastapi/dapr/ext/fastapi/app.py @@ -24,24 +24,24 @@ class DaprApp: app_instance: The FastAPI instance to wrap. """ - def __init__(self, app_instance: FastAPI, - router_tags: Optional[List[str]] = ['PubSub']): + def __init__(self, app_instance: FastAPI, router_tags: Optional[List[str]] = ["PubSub"]): # The router_tags should be added to all magic Dapr App PubSub methods implemented here self._router_tags = router_tags self._app = app_instance self._subscriptions: List[Dict[str, object]] = [] - self._app.add_api_route("/dapr/subscribe", - self._get_subscriptions, - methods=["GET"], - tags=self._router_tags) - - def subscribe(self, - pubsub: str, - topic: str, - metadata: Optional[Dict[str, str]] = {}, - route: Optional[str] = None, - dead_letter_topic: Optional[str] = None): + self._app.add_api_route( + "/dapr/subscribe", self._get_subscriptions, methods=["GET"], tags=self._router_tags + ) + + def subscribe( + self, + pubsub: str, + topic: str, + metadata: Optional[Dict[str, str]] = {}, + route: Optional[str] = None, + dead_letter_topic: Optional[str] = None, + ): """ Subscribes to a topic on a pub/sub component. @@ -73,21 +73,27 @@ def subscribe(self, Returns: The decorator for the function. 
""" + def decorator(func): event_handler_route = f"/events/{pubsub}/{topic}" if route is None else route - self._app.add_api_route(event_handler_route, - func, - methods=["POST"], - tags=self._router_tags) - - self._subscriptions.append({ - "pubsubname": pubsub, - "topic": topic, - "route": event_handler_route, - "metadata": metadata, - **({"deadLetterTopic": dead_letter_topic} if dead_letter_topic is not None else {}) - }) + self._app.add_api_route( + event_handler_route, func, methods=["POST"], tags=self._router_tags + ) + + self._subscriptions.append( + { + "pubsubname": pubsub, + "topic": topic, + "route": event_handler_route, + "metadata": metadata, + **( + {"deadLetterTopic": dead_letter_topic} + if dead_letter_topic is not None + else {} + ), + } + ) return decorator diff --git a/ext/dapr-ext-fastapi/setup.py b/ext/dapr-ext-fastapi/setup.py index 26e459b60..5cd85528b 100644 --- a/ext/dapr-ext-fastapi/setup.py +++ b/ext/dapr-ext-fastapi/setup.py @@ -19,19 +19,19 @@ # Load version in dapr package. version_info = {} -with open('dapr/ext/fastapi/version.py') as fp: +with open("dapr/ext/fastapi/version.py") as fp: exec(fp.read(), version_info) -__version__ = version_info['__version__'] +__version__ = version_info["__version__"] def is_release(): - return '.dev' not in __version__ + return ".dev" not in __version__ -name = 'dapr-ext-fastapi' +name = "dapr-ext-fastapi" version = __version__ -description = 'The official release of Dapr FastAPI extension.' -long_description = ''' +description = "The official release of Dapr FastAPI extension." +long_description = """ This is the FastAPI extension for Dapr. Dapr is a portable, serverless, event-driven runtime that makes it easy for developers to @@ -42,18 +42,18 @@ def is_release(): independent, building blocks that enable you to build portable applications with the language and framework of your choice. Each building block is independent and you can use one, some, or all of them in your application. -'''.lstrip() +""".lstrip() # Get build number from GITHUB_RUN_NUMBER environment variable -build_number = os.environ.get('GITHUB_RUN_NUMBER', '0') +build_number = os.environ.get("GITHUB_RUN_NUMBER", "0") if not is_release(): - name += '-dev' - version = f'{__version__}{build_number}' - description = 'The developmental release for Dapr FastAPI extension.' - long_description = 'This is the developmental release for Dapr FastAPI extension.' + name += "-dev" + version = f"{__version__}{build_number}" + description = "The developmental release for Dapr FastAPI extension." + long_description = "This is the developmental release for Dapr FastAPI extension." 
-print(f'package name: {name}, version: {version}', flush=True) +print(f"package name: {name}, version: {version}", flush=True) setup( diff --git a/ext/dapr-ext-fastapi/tests/test_app.py b/ext/dapr-ext-fastapi/tests/test_app.py index 0497723ab..a8d5a74b9 100644 --- a/ext/dapr-ext-fastapi/tests/test_app.py +++ b/ext/dapr-ext-fastapi/tests/test_app.py @@ -28,11 +28,16 @@ def event_handler(event_data: Message): response = self.client.get("/dapr/subscribe") self.assertEqual( - [{'pubsubname': 'pubsub', - 'topic': 'test', - 'route': '/events/pubsub/test', - 'metadata': {} - }], response.json()) + [ + { + "pubsubname": "pubsub", + "topic": "test", + "route": "/events/pubsub/test", + "metadata": {}, + } + ], + response.json(), + ) response = self.client.post("/events/pubsub/test", json={"body": "new message"}) self.assertEqual(response.status_code, 200) @@ -50,11 +55,9 @@ def event_handler(event_data: Message): response = self.client.get("/dapr/subscribe") self.assertEqual( - [{'pubsubname': 'pubsub', - 'topic': 'test', - 'route': '/do-something', - 'metadata': {} - }], response.json()) + [{"pubsubname": "pubsub", "topic": "test", "route": "/do-something", "metadata": {}}], + response.json(), + ) response = self.client.post("/do-something", json={"body": "new message"}) self.assertEqual(response.status_code, 200) @@ -63,9 +66,7 @@ def event_handler(event_data: Message): def test_subscribe_metadata(self): handler_metadata = {"rawPayload": "true"} - @self.dapr_app.subscribe(pubsub="pubsub", - topic="test", - metadata=handler_metadata) + @self.dapr_app.subscribe(pubsub="pubsub", topic="test", metadata=handler_metadata) def event_handler(event_data: Message): return "custom metadata" @@ -73,11 +74,16 @@ def event_handler(event_data: Message): response = self.client.get("/dapr/subscribe") self.assertEqual( - [{'pubsubname': 'pubsub', - 'topic': 'test', - 'route': '/events/pubsub/test', - 'metadata': {"rawPayload": "true"} - }], response.json()) + [ + { + "pubsubname": "pubsub", + "topic": "test", + "route": "/events/pubsub/test", + "metadata": {"rawPayload": "true"}, + } + ], + response.json(), + ) response = self.client.post("/events/pubsub/test", json={"body": "new message"}) self.assertEqual(response.status_code, 200) @@ -87,45 +93,41 @@ def test_router_tag(self): app1 = FastAPI() app2 = FastAPI() app3 = FastAPI() - DaprApp(app_instance=app1, router_tags=['MyTag', 'PubSub']).subscribe( - pubsub="mypubsub", topic="test") + DaprApp(app_instance=app1, router_tags=["MyTag", "PubSub"]).subscribe( + pubsub="mypubsub", topic="test" + ) DaprApp(app_instance=app2).subscribe(pubsub="mypubsub", topic="test") DaprApp(app_instance=app3, router_tags=None).subscribe(pubsub="mypubsub", topic="test") - PATHS_WITH_EXPECTED_TAGS = [ - '/dapr/subscribe', - '/events/mypubsub/test' - ] + PATHS_WITH_EXPECTED_TAGS = ["/dapr/subscribe", "/events/mypubsub/test"] foundTags = False for route in app1.router.routes: if hasattr(route, "tags"): self.assertIn(route.path, PATHS_WITH_EXPECTED_TAGS) - self.assertEqual(['MyTag', 'PubSub'], route.tags) + self.assertEqual(["MyTag", "PubSub"], route.tags) foundTags = True if not foundTags: - self.fail('No tags found') + self.fail("No tags found") foundTags = False for route in app2.router.routes: if hasattr(route, "tags"): self.assertIn(route.path, PATHS_WITH_EXPECTED_TAGS) - self.assertEqual(['PubSub'], route.tags) + self.assertEqual(["PubSub"], route.tags) foundTags = True if not foundTags: - self.fail('No tags found') + self.fail("No tags found") for route in app3.router.routes: if 
hasattr(route, "tags"): if len(route.tags) > 0: - self.fail('Found tags on route that should not have any') + self.fail("Found tags on route that should not have any") def test_subscribe_dead_letter(self): dead_letter_topic = "dead-test" - @self.dapr_app.subscribe(pubsub="pubsub", - topic="test", - dead_letter_topic=dead_letter_topic) + @self.dapr_app.subscribe(pubsub="pubsub", topic="test", dead_letter_topic=dead_letter_topic) def event_handler(event_data: Message): return "dead letter test" @@ -133,17 +135,22 @@ def event_handler(event_data: Message): response = self.client.get("/dapr/subscribe") self.assertEqual( - [{'pubsubname': 'pubsub', - 'topic': 'test', - 'route': '/events/pubsub/test', - 'metadata': {}, - 'deadLetterTopic': dead_letter_topic - }], response.json()) + [ + { + "pubsubname": "pubsub", + "topic": "test", + "route": "/events/pubsub/test", + "metadata": {}, + "deadLetterTopic": dead_letter_topic, + } + ], + response.json(), + ) response = self.client.post("/events/pubsub/test", json={"body": "new message"}) self.assertEqual(response.status_code, 200) self.assertEqual(response.text, '"dead letter test"') -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/ext/dapr-ext-fastapi/tests/test_dapractor.py b/ext/dapr-ext-fastapi/tests/test_dapractor.py index 3f7ca43d0..42b772740 100644 --- a/ext/dapr-ext-fastapi/tests/test_dapractor.py +++ b/ext/dapr-ext-fastapi/tests/test_dapractor.py @@ -23,23 +23,23 @@ class DaprActorTest(unittest.TestCase): def test_wrap_response_str(self): - r = _wrap_response(200, 'fake_message') - self.assertEqual({'message': 'fake_message'}, json.loads(r.body)) + r = _wrap_response(200, "fake_message") + self.assertEqual({"message": "fake_message"}, json.loads(r.body)) self.assertEqual(200, r.status_code) def test_wrap_response_str_err(self): - r = _wrap_response(400, 'fake_message', 'ERR_FAKE') - self.assertEqual({'message': 'fake_message', 'errorCode': 'ERR_FAKE'}, json.loads(r.body)) + r = _wrap_response(400, "fake_message", "ERR_FAKE") + self.assertEqual({"message": "fake_message", "errorCode": "ERR_FAKE"}, json.loads(r.body)) self.assertEqual(400, r.status_code) def test_wrap_response_bytes_text(self): - r = _wrap_response(200, b'fake_bytes_message', content_type='text/plain') - self.assertEqual(b'fake_bytes_message', r.body) + r = _wrap_response(200, b"fake_bytes_message", content_type="text/plain") + self.assertEqual(b"fake_bytes_message", r.body) self.assertEqual(200, r.status_code) - self.assertEqual('text/plain', r.media_type) + self.assertEqual("text/plain", r.media_type) def test_wrap_response_obj(self): - fake_data = {'message': 'ok'} + fake_data = {"message": "ok"} r = _wrap_response(200, fake_data) self.assertEqual(fake_data, json.loads(r.body)) self.assertEqual(200, r.status_code) @@ -48,42 +48,42 @@ def test_router_tag(self): app1 = FastAPI() app2 = FastAPI() app3 = FastAPI() - DaprActor(app=app1, router_tags=['MyTag', 'Actor']) + DaprActor(app=app1, router_tags=["MyTag", "Actor"]) DaprActor(app=app2) DaprActor(app=app3, router_tags=None) PATHS_WITH_EXPECTED_TAGS = [ - '/healthz', - '/dapr/config', - '/actors/{actor_type_name}/{actor_id}', - '/actors/{actor_type_name}/{actor_id}/method/{method_name}', - '/actors/{actor_type_name}/{actor_id}/method/timer/{timer_name}', - '/actors/{actor_type_name}/{actor_id}/method/remind/{reminder_name}' + "/healthz", + "/dapr/config", + "/actors/{actor_type_name}/{actor_id}", + "/actors/{actor_type_name}/{actor_id}/method/{method_name}", + 
"/actors/{actor_type_name}/{actor_id}/method/timer/{timer_name}", + "/actors/{actor_type_name}/{actor_id}/method/remind/{reminder_name}", ] foundTags = False for route in app1.router.routes: if hasattr(route, "tags"): self.assertIn(route.path, PATHS_WITH_EXPECTED_TAGS) - self.assertEqual(['MyTag', 'Actor'], route.tags) + self.assertEqual(["MyTag", "Actor"], route.tags) foundTags = True if not foundTags: - self.fail('No tags found') + self.fail("No tags found") foundTags = False for route in app2.router.routes: if hasattr(route, "tags"): self.assertIn(route.path, PATHS_WITH_EXPECTED_TAGS) - self.assertEqual(['Actor'], route.tags) + self.assertEqual(["Actor"], route.tags) foundTags = True if not foundTags: - self.fail('No tags found') + self.fail("No tags found") for route in app3.router.routes: if hasattr(route, "tags"): if len(route.tags) > 0: - self.fail('Found tags on route that should not have any') + self.fail("Found tags on route that should not have any") -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/ext/dapr-ext-grpc/dapr/ext/grpc/__init__.py b/ext/dapr-ext-grpc/dapr/ext/grpc/__init__.py index 4461eb822..a030859cb 100644 --- a/ext/dapr-ext-grpc/dapr/ext/grpc/__init__.py +++ b/ext/dapr-ext-grpc/dapr/ext/grpc/__init__.py @@ -16,14 +16,14 @@ from dapr.clients.grpc._request import InvokeMethodRequest, BindingRequest from dapr.clients.grpc._response import InvokeMethodResponse, TopicEventResponse -from dapr.ext.grpc.app import App, Rule # type:ignore +from dapr.ext.grpc.app import App, Rule # type:ignore __all__ = [ - 'App', - 'Rule', - 'InvokeMethodRequest', - 'InvokeMethodResponse', - 'BindingRequest', - 'TopicEventResponse', + "App", + "Rule", + "InvokeMethodRequest", + "InvokeMethodResponse", + "BindingRequest", + "TopicEventResponse", ] diff --git a/ext/dapr-ext-grpc/dapr/ext/grpc/_servicier.py b/ext/dapr-ext-grpc/dapr/ext/grpc/_servicier.py index 95f5adf7c..198331130 100644 --- a/ext/dapr-ext-grpc/dapr/ext/grpc/_servicier.py +++ b/ext/dapr-ext-grpc/dapr/ext/grpc/_servicier.py @@ -28,8 +28,7 @@ from dapr.clients.grpc._request import InvokeMethodRequest, BindingRequest from dapr.clients.grpc._response import InvokeMethodResponse, TopicEventResponse -InvokeMethodCallable = Callable[[ - InvokeMethodRequest], Union[str, bytes, InvokeMethodResponse]] +InvokeMethodCallable = Callable[[InvokeMethodRequest], Union[str, bytes, InvokeMethodResponse]] TopicSubscribeCallable = Callable[[v1.Event], Optional[TopicEventResponse]] BindingCallable = Callable[[BindingRequest], None] @@ -43,8 +42,11 @@ def __init__(self, match: str, priority: int) -> None: class _RegisteredSubscription: - def __init__(self, subscription: appcallback_v1.TopicSubscription, - rules: List[Tuple[int, appcallback_v1.TopicRule]]): + def __init__( + self, + subscription: appcallback_v1.TopicSubscription, + rules: List[Tuple[int, appcallback_v1.TopicRule]], + ): self.subscription = subscription self.rules = rules @@ -71,18 +73,19 @@ def __init__(self): def register_method(self, method: str, cb: InvokeMethodCallable) -> None: """Registers method for service invocation.""" if method in self._invoke_method_map: - raise ValueError(f'{method} is already registered') + raise ValueError(f"{method} is already registered") self._invoke_method_map[method] = cb def register_topic( - self, - pubsub_name: str, - topic: str, - cb: TopicSubscribeCallable, - metadata: Optional[Dict[str, str]], - dead_letter_topic: Optional[str] = None, - rule: Optional[Rule] = None, - disable_topic_validation: 
Optional[bool] = False) -> None: + self, + pubsub_name: str, + topic: str, + cb: TopicSubscribeCallable, + metadata: Optional[Dict[str, str]], + dead_letter_topic: Optional[str] = None, + rule: Optional[Rule] = None, + disable_topic_validation: Optional[bool] = False, + ) -> None: """Registers topic subscription for pubsub.""" if not disable_topic_validation: topic_key = pubsub_name + DELIMITER + topic @@ -90,11 +93,10 @@ def register_topic( topic_key = pubsub_name pubsub_topic = topic_key + DELIMITER if rule is not None: - path = getattr(cb, '__name__', rule.match) + path = getattr(cb, "__name__", rule.match) pubsub_topic = pubsub_topic + path if pubsub_topic in self._topic_map: - raise ValueError( - f'{topic} is already registered with {pubsub_name}') + raise ValueError(f"{topic} is already registered with {pubsub_name}") self._topic_map[pubsub_topic] = cb registered_topic = self._registered_topics_map.get(topic_key) @@ -105,7 +107,7 @@ def register_topic( pubsub_name=pubsub_name, topic=topic, metadata=metadata, - routes=appcallback_v1.TopicRoutes() + routes=appcallback_v1.TopicRoutes(), ) if dead_letter_topic: sub.dead_letter_topic = dead_letter_topic @@ -117,19 +119,17 @@ def register_topic( rules = registered_topic.rules if rule: - path = getattr(cb, '__name__', rule.match) - rules.append((rule.priority, appcallback_v1.TopicRule( - match=rule.match, path=path))) + path = getattr(cb, "__name__", rule.match) + rules.append((rule.priority, appcallback_v1.TopicRule(match=rule.match, path=path))) rules.sort(key=lambda x: x[0]) rs = [rule for id, rule in rules] del sub.routes.rules[:] sub.routes.rules.extend(rs) - def register_binding( - self, name: str, cb: BindingCallable) -> None: + def register_binding(self, name: str, cb: BindingCallable) -> None: """Registers input bindings.""" if name in self._binding_map: - raise ValueError(f'{name} is already registered') + raise ValueError(f"{name} is already registered") self._binding_map[name] = cb self._registered_bindings.append(name) @@ -137,8 +137,7 @@ def OnInvoke(self, request: InvokeRequest, context): """Invokes service method with InvokeRequest.""" if request.method not in self._invoke_method_map: context.set_code(grpc.StatusCode.UNIMPLEMENTED) # type: ignore - raise NotImplementedError( - f'{request.method} method not implemented!') + raise NotImplementedError(f"{request.method} method not implemented!") req = InvokeMethodRequest(request.data, request.content_type) req.metadata = context.invocation_metadata() @@ -157,9 +156,8 @@ def OnInvoke(self, request: InvokeRequest, context): resp_data = resp else: context.set_code(grpc.StatusCode.OUT_OF_RANGE) - context.set_details(f'{type(resp)} is the invalid return type.') - raise NotImplementedError( - f'{request.method} method not implemented!') + context.set_details(f"{type(resp)} is the invalid return type.") + raise NotImplementedError(f"{request.method} method not implemented!") if len(resp_data.get_headers()) > 0: context.send_initial_metadata(resp_data.get_headers()) @@ -168,18 +166,15 @@ def OnInvoke(self, request: InvokeRequest, context): if resp_data.content_type: content_type = resp_data.content_type - return common_v1.InvokeResponse( - data=resp_data.proto, content_type=content_type) + return common_v1.InvokeResponse(data=resp_data.proto, content_type=content_type) def ListTopicSubscriptions(self, request, context): """Lists all topics subscribed by this app.""" - return appcallback_v1.ListTopicSubscriptionsResponse( - subscriptions=self._registered_topics) + return 
appcallback_v1.ListTopicSubscriptionsResponse(subscriptions=self._registered_topics) def OnTopicEvent(self, request: TopicEventRequest, context): """Subscribes events from Pubsub.""" - pubsub_topic = request.pubsub_name + DELIMITER + \ - request.topic + DELIMITER + request.path + pubsub_topic = request.pubsub_name + DELIMITER + request.topic + DELIMITER + request.path no_validation_key = request.pubsub_name + DELIMITER + request.path if pubsub_topic not in self._topic_map: @@ -187,8 +182,7 @@ def OnTopicEvent(self, request: TopicEventRequest, context): pubsub_topic = no_validation_key else: context.set_code(grpc.StatusCode.UNIMPLEMENTED) # type: ignore - raise NotImplementedError( - f'topic {request.topic} is not implemented!') + raise NotImplementedError(f"topic {request.topic} is not implemented!") customdata: Struct = request.extensions extensions = dict() @@ -213,8 +207,7 @@ def OnTopicEvent(self, request: TopicEventRequest, context): def ListInputBindings(self, request, context): """Lists all input bindings subscribed by this app.""" - return appcallback_v1.ListInputBindingsResponse( - bindings=self._registered_bindings) + return appcallback_v1.ListInputBindingsResponse(bindings=self._registered_bindings) def OnBindingEvent(self, request: BindingEventRequest, context): """Listens events from the input bindings @@ -222,9 +215,8 @@ def OnBindingEvent(self, request: BindingEventRequest, context): bindings optionally by returning BindingEventResponse. """ if request.name not in self._binding_map: - context.set_code(grpc.StatusCode.UNIMPLEMENTED) # type: ignore - raise NotImplementedError( - f'{request.name} binding not implemented!') + context.set_code(grpc.StatusCode.UNIMPLEMENTED) # type: ignore + raise NotImplementedError(f"{request.name} binding not implemented!") req = BindingRequest(request.data, dict(request.metadata)) req.metadata = context.invocation_metadata() diff --git a/ext/dapr-ext-grpc/dapr/ext/grpc/app.py b/ext/dapr-ext-grpc/dapr/ext/grpc/app.py index a6c01cd41..7eba8a099 100644 --- a/ext/dapr-ext-grpc/dapr/ext/grpc/app.py +++ b/ext/dapr-ext-grpc/dapr/ext/grpc/app.py @@ -19,7 +19,7 @@ from typing import Dict, Optional from dapr.conf import settings -from dapr.ext.grpc._servicier import _CallbackServicer, Rule # type: ignore +from dapr.ext.grpc._servicier import _CallbackServicer, Rule # type: ignore from dapr.proto import appcallback_service_v1 @@ -47,10 +47,12 @@ def __init__(self, max_grpc_message_length: Optional[int] = None, **kwargs): options = [] if max_grpc_message_length is not None: options = [ - ('grpc.max_send_message_length', max_grpc_message_length), - ('grpc.max_receive_message_length', max_grpc_message_length)] + ("grpc.max_send_message_length", max_grpc_message_length), + ("grpc.max_receive_message_length", max_grpc_message_length), + ] self._server = grpc.server( # type: ignore - futures.ThreadPoolExecutor(max_workers=10), options=options) + futures.ThreadPoolExecutor(max_workers=10), options=options + ) else: self._server = grpc.server(**kwargs) # type: ignore appcallback_service_v1.add_AppCallbackServicer_to_server(self._servicer, self._server) @@ -73,8 +75,7 @@ def run(self, app_port: Optional[int] = None, listen_address: Optional[str] = No """ if app_port is None: app_port = settings.GRPC_APP_PORT - self._server.add_insecure_port( - f'{listen_address if listen_address else "[::]"}:{app_port}') + self._server.add_insecure_port(f'{listen_address if listen_address else "[::]"}:{app_port}') self._server.start() self._server.wait_for_termination() @@ 
-119,13 +120,21 @@ def start(request: InvokeMethodRequest): Args: name (str): name of invoked method """ + def decorator(func): self._servicer.register_method(name, func) + return decorator - def subscribe(self, pubsub_name: str, topic: str, metadata: Optional[Dict[str, str]] = {}, - dead_letter_topic: Optional[str] = None, rule: Optional[Rule] = None, - disable_topic_validation: Optional[bool] = False): + def subscribe( + self, + pubsub_name: str, + topic: str, + metadata: Optional[Dict[str, str]] = {}, + dead_letter_topic: Optional[str] = None, + rule: Optional[Rule] = None, + disable_topic_validation: Optional[bool] = False, + ): """A decorator that is used to register the subscribing topic method. The below example registers 'topic' subscription topic and pass custom @@ -144,9 +153,18 @@ def topic(event: v1.Event) -> None: during initialization dead_letter_topic (str, optional): the dead letter topic name for the subscription """ + def decorator(func): - self._servicer.register_topic(pubsub_name, topic, func, metadata, dead_letter_topic, - rule, disable_topic_validation) + self._servicer.register_topic( + pubsub_name, + topic, + func, + metadata, + dead_letter_topic, + rule, + disable_topic_validation, + ) + return decorator def binding(self, name: str): @@ -161,6 +179,8 @@ def input(request: BindingRequest) -> None: Args: name (str): the name of invoked method """ + def decorator(func): self._servicer.register_binding(name, func) + return decorator diff --git a/ext/dapr-ext-grpc/setup.py b/ext/dapr-ext-grpc/setup.py index 154eccf22..4c269596f 100644 --- a/ext/dapr-ext-grpc/setup.py +++ b/ext/dapr-ext-grpc/setup.py @@ -19,19 +19,19 @@ # Load version in dapr package. version_info = {} -with open('dapr/ext/grpc/version.py') as fp: +with open("dapr/ext/grpc/version.py") as fp: exec(fp.read(), version_info) -__version__ = version_info['__version__'] +__version__ = version_info["__version__"] def is_release(): - return '.dev' not in __version__ + return ".dev" not in __version__ -name = 'dapr-ext-grpc' +name = "dapr-ext-grpc" version = __version__ -description = 'The official release of Dapr Python SDK gRPC Extension.' -long_description = ''' +description = "The official release of Dapr Python SDK gRPC Extension." +long_description = """ This is the gRPC extension for Dapr. Dapr is a portable, serverless, event-driven runtime that makes it easy for developers to @@ -42,18 +42,18 @@ def is_release(): independent, building blocks that enable you to build portable applications with the language and framework of your choice. Each building block is independent and you can use one, some, or all of them in your application. -'''.lstrip() +""".lstrip() # Get build number from GITHUB_RUN_NUMBER environment variable -build_number = os.environ.get('GITHUB_RUN_NUMBER', '0') +build_number = os.environ.get("GITHUB_RUN_NUMBER", "0") if not is_release(): - name += '-dev' - version = f'{__version__}{build_number}' - description = 'The developmental release for Dapr gRPC AppCallback.' - long_description = 'This is the developmental release for Dapr gRPC AppCallback.' + name += "-dev" + version = f"{__version__}{build_number}" + description = "The developmental release for Dapr gRPC AppCallback." + long_description = "This is the developmental release for Dapr gRPC AppCallback." 
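# A minimal usage sketch of the App decorators reformatted in app.py above,
# assuming placeholder names for the method, pubsub component, topics, and
# port. App and Rule are the exports shown in the dapr.ext.grpc __init__ diff.
from dapr.clients.grpc._request import InvokeMethodRequest
from dapr.ext.grpc import App, Rule

app = App()

@app.method(name="echo")
def echo(request: InvokeMethodRequest) -> str:
    # Echo the request body back to the caller.
    return request.text()

@app.subscribe(
    pubsub_name="pubsub",
    topic="orders",
    rule=Rule('event.type == "priority"', 1),
    dead_letter_topic="orders-dead",
)
def priority_orders(event) -> None:
    # Only events matching the rule expression are routed here; undeliverable
    # events go to the dead-letter topic.
    print(event.Data(), flush=True)

@app.binding(name="incoming")
def incoming(request) -> None:
    # Handles events arriving on the input binding named "incoming".
    print(request.text(), flush=True)

app.run(50051)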
-print(f'package name: {name}, version: {version}', flush=True) +print(f"package name: {name}, version: {version}", flush=True) setup( diff --git a/ext/dapr-ext-grpc/tests/test_app.py b/ext/dapr-ext-grpc/tests/test_app.py index aeeb80c42..6d62a448b 100644 --- a/ext/dapr-ext-grpc/tests/test_app.py +++ b/ext/dapr-ext-grpc/tests/test_app.py @@ -24,51 +24,53 @@ def setUp(self): self._app = App() def test_method_decorator(self): - @self._app.method('Method1') + @self._app.method("Method1") def method1(request: InvokeMethodRequest): pass - @self._app.method('Method2') + @self._app.method("Method2") def method2(request: InvokeMethodRequest): pass method_map = self._app._servicer._invoke_method_map - self.assertIn('AppTests.test_method_decorator..method1', str( - method_map['Method1'])) - self.assertIn('AppTests.test_method_decorator..method2', str( - method_map['Method2'])) + self.assertIn("AppTests.test_method_decorator..method1", str(method_map["Method1"])) + self.assertIn("AppTests.test_method_decorator..method2", str(method_map["Method2"])) def test_binding_decorator(self): - @self._app.binding('binding1') + @self._app.binding("binding1") def binding1(request: BindingRequest): pass binding_map = self._app._servicer._binding_map self.assertIn( - 'AppTests.test_binding_decorator..binding1', - str(binding_map['binding1'])) + "AppTests.test_binding_decorator..binding1", str(binding_map["binding1"]) + ) def test_subscribe_decorator(self): - @self._app.subscribe(pubsub_name='pubsub', topic='topic') + @self._app.subscribe(pubsub_name="pubsub", topic="topic") def handle_default(event: v1.Event) -> None: pass - @self._app.subscribe(pubsub_name='pubsub', topic='topic', - rule=Rule("event.type == \"test\"", 1)) + @self._app.subscribe( + pubsub_name="pubsub", topic="topic", rule=Rule('event.type == "test"', 1) + ) def handle_test_event(event: v1.Event) -> None: pass - @self._app.subscribe(pubsub_name='pubsub', topic='topic2', dead_letter_topic='topic2_dead') + @self._app.subscribe(pubsub_name="pubsub", topic="topic2", dead_letter_topic="topic2_dead") def handle_dead_letter(event: v1.Event) -> None: pass subscription_map = self._app._servicer._topic_map self.assertIn( - 'AppTests.test_subscribe_decorator..handle_default', - str(subscription_map['pubsub:topic:'])) + "AppTests.test_subscribe_decorator..handle_default", + str(subscription_map["pubsub:topic:"]), + ) self.assertIn( - 'AppTests.test_subscribe_decorator..handle_test_event', - str(subscription_map['pubsub:topic:handle_test_event'])) + "AppTests.test_subscribe_decorator..handle_test_event", + str(subscription_map["pubsub:topic:handle_test_event"]), + ) self.assertIn( - 'AppTests.test_subscribe_decorator..handle_dead_letter', - str(subscription_map['pubsub:topic2:'])) + "AppTests.test_subscribe_decorator..handle_dead_letter", + str(subscription_map["pubsub:topic2:"]), + ) diff --git a/ext/dapr-ext-grpc/tests/test_servicier.py b/ext/dapr-ext-grpc/tests/test_servicier.py index 00c6144e0..5fc623cd0 100644 --- a/ext/dapr-ext-grpc/tests/test_servicier.py +++ b/ext/dapr-ext-grpc/tests/test_servicier.py @@ -35,8 +35,8 @@ def _on_invoke(self, method_name, method_cb): # fake context fake_context = MagicMock() fake_context.invocation_metadata.return_value = ( - ('key1', 'value1'), - ('key2', 'value1'), + ("key1", "value1"), + ("key2", "value1"), ) return self._servicier.OnInvoke( @@ -46,45 +46,49 @@ def _on_invoke(self, method_name, method_cb): def test_on_invoke_return_str(self): def method_cb(request: InvokeMethodRequest): - return 'method_str_cb' - resp = 
self._on_invoke('method_str', method_cb) + return "method_str_cb" - self.assertEqual(b'method_str_cb', resp.data.value) + resp = self._on_invoke("method_str", method_cb) + + self.assertEqual(b"method_str_cb", resp.data.value) def test_on_invoke_return_bytes(self): def method_cb(request: InvokeMethodRequest): - return b'method_str_cb' - resp = self._on_invoke('method_bytes', method_cb) + return b"method_str_cb" + + resp = self._on_invoke("method_bytes", method_cb) - self.assertEqual(b'method_str_cb', resp.data.value) + self.assertEqual(b"method_str_cb", resp.data.value) def test_on_invoke_return_proto(self): def method_cb(request: InvokeMethodRequest): - return common_v1.StateItem(key='fake_key') - resp = self._on_invoke('method_proto', method_cb) + return common_v1.StateItem(key="fake_key") + + resp = self._on_invoke("method_proto", method_cb) state = common_v1.StateItem() resp.data.Unpack(state) - self.assertEqual('fake_key', state.key) + self.assertEqual("fake_key", state.key) def test_on_invoke_return_invoke_method_response(self): def method_cb(request: InvokeMethodRequest): return InvokeMethodResponse( - data='fake_data', - content_type='text/plain', + data="fake_data", + content_type="text/plain", ) - resp = self._on_invoke('method_resp', method_cb) - self.assertEqual(b'fake_data', resp.data.value) - self.assertEqual('text/plain', resp.content_type) + resp = self._on_invoke("method_resp", method_cb) + + self.assertEqual(b"fake_data", resp.data.value) + self.assertEqual("text/plain", resp.content_type) def test_on_invoke_invalid_response(self): def method_cb(request: InvokeMethodRequest): return 1000 with self.assertRaises(NotImplementedError): - self._on_invoke('method_resp', method_cb) + self._on_invoke("method_resp", method_cb) class TopicSubscriptionTests(unittest.TestCase): @@ -96,68 +100,51 @@ def setUp(self): self._topic3_method.return_value = TopicEventResponse("success") self._topic4_method = Mock() + self._servicier.register_topic("pubsub1", "topic1", self._topic1_method, {"session": "key"}) + self._servicier.register_topic("pubsub1", "topic3", self._topic3_method, {"session": "key"}) + self._servicier.register_topic("pubsub2", "topic2", self._topic2_method, {"session": "key"}) + self._servicier.register_topic("pubsub2", "topic3", self._topic3_method, {"session": "key"}) self._servicier.register_topic( - 'pubsub1', - 'topic1', - self._topic1_method, - {'session': 'key'}) - self._servicier.register_topic( - 'pubsub1', - 'topic3', - self._topic3_method, - {'session': 'key'}) - self._servicier.register_topic( - 'pubsub2', - 'topic2', - self._topic2_method, - {'session': 'key'}) - self._servicier.register_topic( - 'pubsub2', - 'topic3', - self._topic3_method, - {'session': 'key'}) - self._servicier.register_topic( - 'pubsub3', - 'topic4', + "pubsub3", + "topic4", self._topic4_method, - {'session': 'key'}, - disable_topic_validation=True) + {"session": "key"}, + disable_topic_validation=True, + ) # fake context self.fake_context = MagicMock() self.fake_context.invocation_metadata.return_value = ( - ('key1', 'value1'), - ('key2', 'value1'), + ("key1", "value1"), + ("key2", "value1"), ) def test_duplicated_topic(self): with self.assertRaises(ValueError): self._servicier.register_topic( - 'pubsub1', - 'topic1', - self._topic1_method, - {'session': 'key'}) + "pubsub1", "topic1", self._topic1_method, {"session": "key"} + ) def test_list_topic_subscription(self): resp = self._servicier.ListTopicSubscriptions(None, None) - self.assertEqual('pubsub1', resp.subscriptions[0].pubsub_name) - 
self.assertEqual('topic1', resp.subscriptions[0].topic) - self.assertEqual({'session': 'key'}, resp.subscriptions[0].metadata) - self.assertEqual('pubsub1', resp.subscriptions[1].pubsub_name) - self.assertEqual('topic3', resp.subscriptions[1].topic) - self.assertEqual({'session': 'key'}, resp.subscriptions[1].metadata) - self.assertEqual('pubsub2', resp.subscriptions[2].pubsub_name) - self.assertEqual('topic2', resp.subscriptions[2].topic) - self.assertEqual({'session': 'key'}, resp.subscriptions[2].metadata) - self.assertEqual('pubsub2', resp.subscriptions[3].pubsub_name) - self.assertEqual('topic3', resp.subscriptions[3].topic) - self.assertEqual({'session': 'key'}, resp.subscriptions[3].metadata) - self.assertEqual('topic4', resp.subscriptions[4].topic) - self.assertEqual({'session': 'key'}, resp.subscriptions[4].metadata) + self.assertEqual("pubsub1", resp.subscriptions[0].pubsub_name) + self.assertEqual("topic1", resp.subscriptions[0].topic) + self.assertEqual({"session": "key"}, resp.subscriptions[0].metadata) + self.assertEqual("pubsub1", resp.subscriptions[1].pubsub_name) + self.assertEqual("topic3", resp.subscriptions[1].topic) + self.assertEqual({"session": "key"}, resp.subscriptions[1].metadata) + self.assertEqual("pubsub2", resp.subscriptions[2].pubsub_name) + self.assertEqual("topic2", resp.subscriptions[2].topic) + self.assertEqual({"session": "key"}, resp.subscriptions[2].metadata) + self.assertEqual("pubsub2", resp.subscriptions[3].pubsub_name) + self.assertEqual("topic3", resp.subscriptions[3].topic) + self.assertEqual({"session": "key"}, resp.subscriptions[3].metadata) + self.assertEqual("topic4", resp.subscriptions[4].topic) + self.assertEqual({"session": "key"}, resp.subscriptions[4].metadata) def test_topic_event(self): self._servicier.OnTopicEvent( - appcallback_v1.TopicEventRequest(pubsub_name='pubsub1', topic='topic1'), + appcallback_v1.TopicEventRequest(pubsub_name="pubsub1", topic="topic1"), self.fake_context, ) @@ -165,7 +152,7 @@ def test_topic_event(self): def test_topic3_event_called_once(self): self._servicier.OnTopicEvent( - appcallback_v1.TopicEventRequest(pubsub_name='pubsub1', topic='topic3'), + appcallback_v1.TopicEventRequest(pubsub_name="pubsub1", topic="topic3"), self.fake_context, ) @@ -173,18 +160,17 @@ def test_topic3_event_called_once(self): def test_topic3_event_response(self): response = self._servicier.OnTopicEvent( - appcallback_v1.TopicEventRequest(pubsub_name='pubsub1', topic='topic3'), + appcallback_v1.TopicEventRequest(pubsub_name="pubsub1", topic="topic3"), self.fake_context, ) self.assertIsInstance(response, appcallback_v1.TopicEventResponse) self.assertEqual( - response.status, - appcallback_v1.TopicEventResponse.TopicEventResponseStatus.SUCCESS + response.status, appcallback_v1.TopicEventResponse.TopicEventResponseStatus.SUCCESS ) def test_disable_topic_validation(self): self._servicier.OnTopicEvent( - appcallback_v1.TopicEventRequest(pubsub_name='pubsub3', topic='should_be_ignored'), + appcallback_v1.TopicEventRequest(pubsub_name="pubsub3", topic="should_be_ignored"), self.fake_context, ) @@ -193,7 +179,7 @@ def test_disable_topic_validation(self): def test_non_registered_topic(self): with self.assertRaises(NotImplementedError): self._servicier.OnTopicEvent( - appcallback_v1.TopicEventRequest(pubsub_name='pubsub1', topic='topic_non_existed'), + appcallback_v1.TopicEventRequest(pubsub_name="pubsub1", topic="topic_non_existed"), self.fake_context, ) @@ -204,31 +190,28 @@ def setUp(self): self._binding1_method = Mock() 
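# The routing keys asserted throughout these servicer tests follow one
# scheme: register_topic joins the pubsub name, topic, and (for Rule-based
# routes) the callback's __name__ with DELIMITER. A sketch, assuming
# DELIMITER == ":", which matches the "pubsub:topic:" keys in test_app.py:
DELIMITER = ":"

def route_key(pubsub_name: str, topic: str, handler_path: str = "") -> str:
    # Default subscriptions end with a trailing delimiter; rule-based
    # subscriptions append the handler path after it.
    return pubsub_name + DELIMITER + topic + DELIMITER + handler_path

assert route_key("pubsub", "topic") == "pubsub:topic:"
assert route_key("pubsub", "topic", "handle_test_event") == "pubsub:topic:handle_test_event"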
self._binding2_method = Mock() - self._servicier.register_binding( - 'binding1', self._binding1_method) - self._servicier.register_binding( - 'binding2', self._binding2_method) + self._servicier.register_binding("binding1", self._binding1_method) + self._servicier.register_binding("binding2", self._binding2_method) # fake context self.fake_context = MagicMock() self.fake_context.invocation_metadata.return_value = ( - ('key1', 'value1'), - ('key2', 'value1'), + ("key1", "value1"), + ("key2", "value1"), ) def test_duplicated_binding(self): with self.assertRaises(ValueError): - self._servicier.register_binding( - 'binding1', self._binding1_method) + self._servicier.register_binding("binding1", self._binding1_method) def test_list_bindings(self): resp = self._servicier.ListInputBindings(None, None) - self.assertEqual('binding1', resp.bindings[0]) - self.assertEqual('binding2', resp.bindings[1]) + self.assertEqual("binding1", resp.bindings[0]) + self.assertEqual("binding2", resp.bindings[1]) def test_binding_event(self): self._servicier.OnBindingEvent( - appcallback_v1.BindingEventRequest(name='binding1'), + appcallback_v1.BindingEventRequest(name="binding1"), self.fake_context, ) @@ -237,10 +220,10 @@ def test_binding_event(self): def test_non_registered_binding(self): with self.assertRaises(NotImplementedError): self._servicier.OnBindingEvent( - appcallback_v1.BindingEventRequest(name='binding3'), + appcallback_v1.BindingEventRequest(name="binding3"), self.fake_context, ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/ext/dapr-ext-grpc/tests/test_topic_event_response.py b/ext/dapr-ext-grpc/tests/test_topic_event_response.py index ad1f7b9bf..b93d4e38e 100644 --- a/ext/dapr-ext-grpc/tests/test_topic_event_response.py +++ b/ext/dapr-ext-grpc/tests/test_topic_event_response.py @@ -38,5 +38,5 @@ def test_topic_event_response_creation_fails_with_object(self): TopicEventResponse(None) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/__init__.py b/ext/dapr-ext-workflow/dapr/ext/workflow/__init__.py index 5a6c144a3..c45c8f2fc 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/__init__.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/__init__.py @@ -22,14 +22,14 @@ from dapr.ext.workflow.retry_policy import RetryPolicy __all__ = [ - 'WorkflowRuntime', - 'DaprWorkflowClient', - 'DaprWorkflowContext', - 'WorkflowActivityContext', - 'WorkflowState', - 'WorkflowStatus', - 'when_all', - 'when_any', - 'alternate_name', - 'RetryPolicy' + "WorkflowRuntime", + "DaprWorkflowClient", + "DaprWorkflowContext", + "WorkflowActivityContext", + "WorkflowState", + "WorkflowStatus", + "when_all", + "when_any", + "alternate_name", + "RetryPolicy", ] diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py b/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py index 865c18138..3f86ebbdf 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py @@ -29,32 +29,33 @@ from dapr.conf.helpers import GrpcEndpoint from dapr.ext.workflow.logger import LoggerOptions, Logger -T = TypeVar('T') -TInput = TypeVar('TInput') -TOutput = TypeVar('TOutput') +T = TypeVar("T") +TInput = TypeVar("TInput") +TOutput = TypeVar("TOutput") class DaprWorkflowClient: """Defines client operations for managing Dapr Workflow instances. - This is an alternative to the general purpose Dapr client. 
It uses a gRPC connection to send - commands directly to the workflow engine, bypassing the Dapr API layer. + This is an alternative to the general purpose Dapr client. It uses a gRPC connection to send + commands directly to the workflow engine, bypassing the Dapr API layer. - This client is intended to be used by workflow application, not by general purpose - application. + This client is intended to be used by workflow applications, not by general purpose + applications. """ def __init__( - self, - host: Optional[str] = None, - port: Optional[str] = None, - logger_options: Optional[LoggerOptions] = None): + self, + host: Optional[str] = None, + port: Optional[str] = None, + logger_options: Optional[LoggerOptions] = None, + ): address = getAddress(host, port) try: uri = GrpcEndpoint(address) except ValueError as error: - raise DaprInternalError(f'{error}') from error + raise DaprInternalError(f"{error}") from error self._logger = Logger("DaprWorkflowClient", logger_options) @@ -62,15 +63,22 @@ def __init__( if settings.DAPR_API_TOKEN: metadata = ((DAPR_API_TOKEN_HEADER, settings.DAPR_API_TOKEN),) options = self._logger.get_options() - self.__obj = client.TaskHubGrpcClient(host_address=uri.endpoint, - metadata=metadata, - secure_channel=uri.tls, - log_handler=options.log_handler, - log_formatter=options.log_formatter) - - def schedule_new_workflow(self, workflow: Workflow, *, input: Optional[TInput] = None, - instance_id: Optional[str] = None, - start_at: Optional[datetime] = None) -> str: + self.__obj = client.TaskHubGrpcClient( + host_address=uri.endpoint, + metadata=metadata, + secure_channel=uri.tls, + log_handler=options.log_handler, + log_formatter=options.log_formatter, + ) + + def schedule_new_workflow( + self, + workflow: Workflow, + *, + input: Optional[TInput] = None, + instance_id: Optional[str] = None, + start_at: Optional[datetime] = None, + ) -> str: """Schedules a new workflow instance for execution. Args: @@ -86,16 +94,20 @@ def schedule_new_workflow(self, workflow: Workflow, *, input: Optional[TInput] = Returns: The ID of the scheduled workflow instance. """ - if hasattr(workflow, '_dapr_alternate_name'): - return self.__obj.schedule_new_orchestration(workflow.__dict__['_dapr_alternate_name'], - input=input, instance_id=instance_id, - start_at=start_at) - return self.__obj.schedule_new_orchestration(workflow.__name__, input=input, - instance_id=instance_id, - start_at=start_at) - - def get_workflow_state(self, instance_id: str, *, - fetch_payloads: bool = True) -> Optional[WorkflowState]: + if hasattr(workflow, "_dapr_alternate_name"): + return self.__obj.schedule_new_orchestration( + workflow.__dict__["_dapr_alternate_name"], + input=input, + instance_id=instance_id, + start_at=start_at, + ) + return self.__obj.schedule_new_orchestration( + workflow.__name__, input=input, instance_id=instance_id, start_at=start_at + ) + + def get_workflow_state( + self, instance_id: str, *, fetch_payloads: bool = True + ) -> Optional[WorkflowState]: """Fetches runtime state for the specified workflow instance. 
Args: @@ -111,8 +123,9 @@ def get_workflow_state(self, instance_id: str, *, state = self.__obj.get_orchestration_state(instance_id, fetch_payloads=fetch_payloads) return WorkflowState(state) if state else None - def wait_for_workflow_start(self, instance_id: str, *, fetch_payloads: bool = False, - timeout_in_seconds: int = 60) -> Optional[WorkflowState]: + def wait_for_workflow_start( + self, instance_id: str, *, fetch_payloads: bool = False, timeout_in_seconds: int = 60 + ) -> Optional[WorkflowState]: """Waits for a workflow to start running and returns a WorkflowState object that contains metadata about the started workflow. @@ -131,12 +144,14 @@ def wait_for_workflow_start(self, instance_id: str, *, fetch_payloads: bool = Fa WorkflowState record that describes the workflow instance and its execution status. If the specified workflow isn't found, the WorkflowState.Exists value will be false. """ - state = self.__obj.wait_for_orchestration_start(instance_id, fetch_payloads=fetch_payloads, - timeout=timeout_in_seconds) + state = self.__obj.wait_for_orchestration_start( + instance_id, fetch_payloads=fetch_payloads, timeout=timeout_in_seconds + ) return WorkflowState(state) if state else None - def wait_for_workflow_completion(self, instance_id: str, *, fetch_payloads: bool = True, - timeout_in_seconds: int = 60) -> Optional[WorkflowState]: + def wait_for_workflow_completion( + self, instance_id: str, *, fetch_payloads: bool = True, timeout_in_seconds: int = 60 + ) -> Optional[WorkflowState]: """Waits for a workflow to complete and returns a WorkflowState object that contains metadata about the started instance. @@ -162,13 +177,14 @@ def wait_for_workflow_completion(self, instance_id: str, *, fetch_payloads: bool Returns: WorkflowState record that describes the workflow instance and its execution status. """ - state = self.__obj.wait_for_orchestration_completion(instance_id, - fetch_payloads=fetch_payloads, - timeout=timeout_in_seconds) + state = self.__obj.wait_for_orchestration_completion( + instance_id, fetch_payloads=fetch_payloads, timeout=timeout_in_seconds + ) return WorkflowState(state) if state else None - def raise_workflow_event(self, instance_id: str, event_name: str, *, - data: Optional[Any] = None): + def raise_workflow_event( + self, instance_id: str, event_name: str, *, data: Optional[Any] = None + ): """Sends an event notification message to a waiting workflow instance. In order to handle the event, the target workflow instance must be waiting for an event named value of "eventName" param using the wait_for_external_event API. @@ -210,7 +226,7 @@ def terminate_workflow(self, instance_id: str, *, output: Optional[Any] = None): Args: instance_id: The ID of the workflow instance to terminate. output: The optional output to set for the terminated workflow instance. 
- """ + """ return self.__obj.terminate_orchestration(instance_id, output=output) def pause_workflow(self, instance_id: str): diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_context.py b/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_context.py index 402f5e74d..11585b91c 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_context.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_context.py @@ -23,18 +23,17 @@ from dapr.ext.workflow.logger import LoggerOptions, Logger from dapr.ext.workflow.retry_policy import RetryPolicy -T = TypeVar('T') -TInput = TypeVar('TInput') -TOutput = TypeVar('TOutput') +T = TypeVar("T") +TInput = TypeVar("TInput") +TOutput = TypeVar("TOutput") class DaprWorkflowContext(WorkflowContext): """DaprWorkflowContext that provides proxy access to internal OrchestrationContext instance.""" def __init__( - self, - ctx: task.OrchestrationContext, - logger_options: Optional[LoggerOptions] = None): + self, ctx: task.OrchestrationContext, logger_options: Optional[LoggerOptions] = None + ): self.__obj = ctx self._logger = Logger("DaprWorkflowContext", logger_options) @@ -55,15 +54,19 @@ def is_replaying(self) -> bool: return self.__obj.is_replaying def create_timer(self, fire_at: Union[datetime, timedelta]) -> task.Task: - self._logger.debug(f'{self.instance_id}: Creating timer to fire at {fire_at} time') + self._logger.debug(f"{self.instance_id}: Creating timer to fire at {fire_at} time") return self.__obj.create_timer(fire_at) - def call_activity(self, activity: Callable[[WorkflowActivityContext, TInput], TOutput], *, - input: TInput = None, - retry_policy: Optional[RetryPolicy] = None) -> task.Task[TOutput]: - self._logger.debug(f'{self.instance_id}: Creating activity {activity.__name__}') - if hasattr(activity, '_dapr_alternate_name'): - act = activity.__dict__['_dapr_alternate_name'] + def call_activity( + self, + activity: Callable[[WorkflowActivityContext, TInput], TOutput], + *, + input: TInput = None, + retry_policy: Optional[RetryPolicy] = None, + ) -> task.Task[TOutput]: + self._logger.debug(f"{self.instance_id}: Creating activity {activity.__name__}") + if hasattr(activity, "_dapr_alternate_name"): + act = activity.__dict__["_dapr_alternate_name"] else: # this case should ideally never happen act = activity.__name__ @@ -71,33 +74,39 @@ def call_activity(self, activity: Callable[[WorkflowActivityContext, TInput], TO return self.__obj.call_activity(activity=act, input=input) return self.__obj.call_activity(activity=act, input=input, retry_policy=retry_policy.obj) - def call_child_workflow(self, workflow: Workflow, *, - input: Optional[TInput] = None, - instance_id: Optional[str] = None, - retry_policy: Optional[RetryPolicy] = None) -> task.Task[TOutput]: - self._logger.debug(f'{self.instance_id}: Creating child workflow {workflow.__name__}') + def call_child_workflow( + self, + workflow: Workflow, + *, + input: Optional[TInput] = None, + instance_id: Optional[str] = None, + retry_policy: Optional[RetryPolicy] = None, + ) -> task.Task[TOutput]: + self._logger.debug(f"{self.instance_id}: Creating child workflow {workflow.__name__}") def wf(ctx: task.OrchestrationContext, inp: TInput): daprWfContext = DaprWorkflowContext(ctx, self._logger.get_options()) return workflow(daprWfContext, inp) + # copy workflow name so durabletask.worker can find the orchestrator in its registry - if hasattr(workflow, '_dapr_alternate_name'): - wf.__name__ = workflow.__dict__['_dapr_alternate_name'] + if hasattr(workflow, 
"_dapr_alternate_name"): + wf.__name__ = workflow.__dict__["_dapr_alternate_name"] else: # this case should ideally never happen wf.__name__ = workflow.__name__ if retry_policy is None: return self.__obj.call_sub_orchestrator(wf, input=input, instance_id=instance_id) - return self.__obj.call_sub_orchestrator(wf, input=input, instance_id=instance_id, - retry_policy=retry_policy.obj) + return self.__obj.call_sub_orchestrator( + wf, input=input, instance_id=instance_id, retry_policy=retry_policy.obj + ) def wait_for_external_event(self, name: str) -> task.Task: - self._logger.debug(f'{self.instance_id}: Waiting for external event {name}') + self._logger.debug(f"{self.instance_id}: Waiting for external event {name}") return self.__obj.wait_for_external_event(name) def continue_as_new(self, new_input: Any, *, save_events: bool = False) -> None: - self._logger.debug(f'{self.instance_id}: Continuing as new') + self._logger.debug(f"{self.instance_id}: Continuing as new") self.__obj.continue_as_new(new_input, save_events=save_events) diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/logger/__init__.py b/ext/dapr-ext-workflow/dapr/ext/workflow/logger/__init__.py index 42284dffa..08226c87a 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/logger/__init__.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/logger/__init__.py @@ -1,7 +1,4 @@ from dapr.ext.workflow.logger.options import LoggerOptions from dapr.ext.workflow.logger.logger import Logger -__all__ = [ - 'LoggerOptions', - 'Logger' -] +__all__ = ["LoggerOptions", "Logger"] diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/logger/logger.py b/ext/dapr-ext-workflow/dapr/ext/workflow/logger/logger.py index ef320bda2..6b0f3fec4 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/logger/logger.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/logger/logger.py @@ -4,9 +4,7 @@ class Logger: - def __init__(self, - name: str, - options: Union[LoggerOptions, None] = None): + def __init__(self, name: str, options: Union[LoggerOptions, None] = None): # If options is None, then create a new LoggerOptions object if options is None: options = LoggerOptions() diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/logger/options.py b/ext/dapr-ext-workflow/dapr/ext/workflow/logger/options.py index 46b499c10..abe88226c 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/logger/options.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/logger/options.py @@ -19,10 +19,10 @@ class LoggerOptions: def __init__( - self, - log_level: Union[str, None] = None, - log_handler: Union[logging.Handler, None] = None, - log_formatter: Union[logging.Formatter, None] = None, + self, + log_level: Union[str, None] = None, + log_handler: Union[logging.Handler, None] = None, + log_formatter: Union[logging.Formatter, None] = None, ): # Set default log level to INFO if none is provided if log_level is None: @@ -34,7 +34,8 @@ def __init__( if log_formatter is None: log_formatter = logging.Formatter( fmt="%(asctime)s.%(msecs)03d %(name)s %(levelname)s: %(message)s", - datefmt='%Y-%m-%d %H:%M:%S') + datefmt="%Y-%m-%d %H:%M:%S", + ) self.log_level = log_level self.log_handler = log_handler self.log_formatter = log_formatter diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/retry_policy.py b/ext/dapr-ext-workflow/dapr/ext/workflow/retry_policy.py index 82da685d4..b5c191a21 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/retry_policy.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/retry_policy.py @@ -18,21 +18,22 @@ from durabletask import task -T = TypeVar('T') -TInput = 
TypeVar('TInput') -TOutput = TypeVar('TOutput') +T = TypeVar("T") +TInput = TypeVar("TInput") +TOutput = TypeVar("TOutput") class RetryPolicy: """Represents the retry policy for a workflow or activity function.""" def __init__( - self, *, - first_retry_interval: timedelta, - max_number_of_attempts: int, - backoff_coefficient: Optional[float] = 1.0, - max_retry_interval: Optional[timedelta] = None, - retry_timeout: Optional[timedelta] = None + self, + *, + first_retry_interval: timedelta, + max_number_of_attempts: int, + backoff_coefficient: Optional[float] = 1.0, + max_retry_interval: Optional[timedelta] = None, + retry_timeout: Optional[timedelta] = None, ): """Creates a new RetryPolicy instance. @@ -48,22 +49,22 @@ def __init__( """ # validate inputs if first_retry_interval < timedelta(seconds=0): - raise ValueError('first_retry_interval must be >= 0') + raise ValueError("first_retry_interval must be >= 0") if max_number_of_attempts < 1: - raise ValueError('max_number_of_attempts must be >= 1') + raise ValueError("max_number_of_attempts must be >= 1") if backoff_coefficient is not None and backoff_coefficient < 1: - raise ValueError('backoff_coefficient must be >= 1') + raise ValueError("backoff_coefficient must be >= 1") if max_retry_interval is not None and max_retry_interval < timedelta(seconds=0): - raise ValueError('max_retry_interval must be >= 0') + raise ValueError("max_retry_interval must be >= 0") if retry_timeout is not None and retry_timeout < timedelta(seconds=0): - raise ValueError('retry_timeout must be >= 0') + raise ValueError("retry_timeout must be >= 0") self._obj = task.RetryPolicy( first_retry_interval=first_retry_interval, max_number_of_attempts=max_number_of_attempts, backoff_coefficient=backoff_coefficient, max_retry_interval=max_retry_interval, - retry_timeout=retry_timeout + retry_timeout=retry_timeout, ) @property diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/util.py b/ext/dapr-ext-workflow/dapr/ext/workflow/util.py index 9a8cad83a..b35e3a573 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/util.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/util.py @@ -20,8 +20,9 @@ def getAddress(host: Optional[str] = None, port: Optional[str] = None) -> str: if not host and not port: - address = settings.DAPR_GRPC_ENDPOINT or (f"{settings.DAPR_RUNTIME_HOST}:" - f"{settings.DAPR_GRPC_PORT}") + address = settings.DAPR_GRPC_ENDPOINT or ( + f"{settings.DAPR_RUNTIME_HOST}:" f"{settings.DAPR_GRPC_PORT}" + ) else: host = host or settings.DAPR_RUNTIME_HOST port = port or settings.DAPR_GRPC_PORT diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_activity_context.py b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_activity_context.py index 8ea28bb10..667ed3451 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_activity_context.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_activity_context.py @@ -18,12 +18,12 @@ from durabletask import task -T = TypeVar('T') -TInput = TypeVar('TInput') -TOutput = TypeVar('TOutput') +T = TypeVar("T") +TInput = TypeVar("TInput") +TOutput = TypeVar("TOutput") -class WorkflowActivityContext(): +class WorkflowActivityContext: """Defines properties and methods for task activity context objects.""" def __init__(self, ctx: task.ActivityContext): diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_context.py b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_context.py index 66ee03a5a..92b6cc857 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_context.py +++ 
b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_context.py @@ -22,15 +22,15 @@ from dapr.ext.workflow.workflow_activity_context import Activity -T = TypeVar('T') -TInput = TypeVar('TInput') -TOutput = TypeVar('TOutput') +T = TypeVar("T") +TInput = TypeVar("TInput") +TOutput = TypeVar("TOutput") class WorkflowContext(ABC): """Context object used by workflow implementations to perform actions such as scheduling - activities, durable timers, waiting for external events, and for getting basic information - about the current workflow instance. + activities, durable timers, waiting for external events, and getting basic information + about the current workflow instance. """ @property @@ -101,8 +101,9 @@ def create_timer(self, fire_at: Union[datetime, timedelta]) -> task.Task: pass @abstractmethod - def call_activity(self, activity: Activity[TOutput], *, - input: Optional[TInput] = None) -> task.Task[TOutput]: + def call_activity( + self, activity: Activity[TOutput], *, input: Optional[TInput] = None + ) -> task.Task[TOutput]: """Schedule an activity for execution. Parameters @@ -122,9 +123,13 @@ def call_activity(self, activity: Activity[TOutput], *, pass @abstractmethod - def call_child_workflow(self, orchestrator: Workflow[TOutput], *, - input: Optional[TInput] = None, - instance_id: Optional[str] = None) -> task.Task[TOutput]: + def call_child_workflow( + self, + orchestrator: Workflow[TOutput], + *, + input: Optional[TInput] = None, + instance_id: Optional[str] = None, + ) -> task.Task[TOutput]: """Schedule child-workflow function for execution. Parameters diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py index 5e2c425d4..cc228c8e4 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py @@ -30,20 +30,20 @@ from dapr.conf.helpers import GrpcEndpoint from dapr.ext.workflow.logger import LoggerOptions, Logger -T = TypeVar('T') -TInput = TypeVar('TInput') -TOutput = TypeVar('TOutput') +T = TypeVar("T") +TInput = TypeVar("TInput") +TOutput = TypeVar("TOutput") class WorkflowRuntime: - """WorkflowRuntime is the entry point for registering workflows and activities. 
- """ + """WorkflowRuntime is the entry point for registering workflows and activities.""" def __init__( - self, - host: Optional[str] = None, - port: Optional[str] = None, - logger_options: Optional[LoggerOptions] = None): + self, + host: Optional[str] = None, + port: Optional[str] = None, + logger_options: Optional[LoggerOptions] = None, + ): self._logger = Logger("WorkflowRuntime", logger_options) metadata = tuple() if settings.DAPR_API_TOKEN: @@ -53,14 +53,16 @@ def __init__( try: uri = GrpcEndpoint(address) except ValueError as error: - raise DaprInternalError(f'{error}') from error + raise DaprInternalError(f"{error}") from error options = self._logger.get_options() - self.__worker = worker.TaskHubGrpcWorker(host_address=uri.endpoint, - metadata=metadata, - secure_channel=uri.tls, - log_handler=options.log_handler, - log_formatter=options.log_formatter) + self.__worker = worker.TaskHubGrpcWorker( + host_address=uri.endpoint, + metadata=metadata, + secure_channel=uri.tls, + log_handler=options.log_handler, + log_formatter=options.log_formatter, + ) def register_workflow(self, fn: Workflow, *, name: Optional[str] = None): self._logger.info(f"Registering workflow '{fn.__name__}' with runtime") @@ -72,25 +74,26 @@ def orchestrationWrapper(ctx: task.OrchestrationContext, inp: Optional[TInput] = return fn(daprWfContext) return fn(daprWfContext, inp) - if hasattr(fn, '_workflow_registered'): + if hasattr(fn, "_workflow_registered"): # whenever a workflow is registered, it has a _dapr_alternate_name attribute - alt_name = fn.__dict__['_dapr_alternate_name'] - raise ValueError(f'Workflow {fn.__name__} already registered as {alt_name}') - if hasattr(fn, '_dapr_alternate_name'): + alt_name = fn.__dict__["_dapr_alternate_name"] + raise ValueError(f"Workflow {fn.__name__} already registered as {alt_name}") + if hasattr(fn, "_dapr_alternate_name"): alt_name = fn._dapr_alternate_name if name is not None: - m = f'Workflow {fn.__name__} already has an alternate name {alt_name}' + m = f"Workflow {fn.__name__} already has an alternate name {alt_name}" raise ValueError(m) else: - fn.__dict__['_dapr_alternate_name'] = name if name else fn.__name__ + fn.__dict__["_dapr_alternate_name"] = name if name else fn.__name__ - self.__worker._registry.add_named_orchestrator(fn.__dict__['_dapr_alternate_name'], - orchestrationWrapper) - fn.__dict__['_workflow_registered'] = True + self.__worker._registry.add_named_orchestrator( + fn.__dict__["_dapr_alternate_name"], orchestrationWrapper + ) + fn.__dict__["_workflow_registered"] = True def register_activity(self, fn: Activity, *, name: Optional[str] = None): """Registers a workflow activity as a function that takes - a specified input type and returns a specified output type. + a specified input type and returns a specified output type. 
""" self._logger.info(f"Registering activity '{fn.__name__}' with runtime") @@ -101,21 +104,22 @@ def activityWrapper(ctx: task.ActivityContext, inp: Optional[TInput] = None): return fn(wfActivityContext) return fn(wfActivityContext, inp) - if hasattr(fn, '_activity_registered'): + if hasattr(fn, "_activity_registered"): # whenever an activity is registered, it has a _dapr_alternate_name attribute - alt_name = fn.__dict__['_dapr_alternate_name'] - raise ValueError(f'Activity {fn.__name__} already registered as {alt_name}') - if hasattr(fn, '_dapr_alternate_name'): + alt_name = fn.__dict__["_dapr_alternate_name"] + raise ValueError(f"Activity {fn.__name__} already registered as {alt_name}") + if hasattr(fn, "_dapr_alternate_name"): alt_name = fn._dapr_alternate_name if name is not None: - m = f'Activity {fn.__name__} already has an alternate name {alt_name}' + m = f"Activity {fn.__name__} already has an alternate name {alt_name}" raise ValueError(m) else: - fn.__dict__['_dapr_alternate_name'] = name if name else fn.__name__ + fn.__dict__["_dapr_alternate_name"] = name if name else fn.__name__ - self.__worker._registry.add_named_activity(fn.__dict__['_dapr_alternate_name'], - activityWrapper) - fn.__dict__['_activity_registered'] = True + self.__worker._registry.add_named_activity( + fn.__dict__["_dapr_alternate_name"], activityWrapper + ) + fn.__dict__["_activity_registered"] = True def start(self): """Starts the listening for work items on a background thread.""" @@ -158,10 +162,11 @@ def wrapper(fn: Workflow): @wraps(fn) def innerfn(): return fn - if hasattr(fn, '_dapr_alternate_name'): - innerfn.__dict__['_dapr_alternate_name'] = fn.__dict__['_dapr_alternate_name'] + + if hasattr(fn, "_dapr_alternate_name"): + innerfn.__dict__["_dapr_alternate_name"] = fn.__dict__["_dapr_alternate_name"] else: - innerfn.__dict__['_dapr_alternate_name'] = name if name else fn.__name__ + innerfn.__dict__["_dapr_alternate_name"] = name if name else fn.__name__ innerfn.__signature__ = inspect.signature(fn) return innerfn @@ -205,10 +210,10 @@ def wrapper(fn: Activity): def innerfn(): return fn - if hasattr(fn, '_dapr_alternate_name'): - innerfn.__dict__['_dapr_alternate_name'] = fn.__dict__['_dapr_alternate_name'] + if hasattr(fn, "_dapr_alternate_name"): + innerfn.__dict__["_dapr_alternate_name"] = fn.__dict__["_dapr_alternate_name"] else: - innerfn.__dict__['_dapr_alternate_name'] = name if name else fn.__name__ + innerfn.__dict__["_dapr_alternate_name"] = name if name else fn.__name__ innerfn.__signature__ = inspect.signature(fn) return innerfn @@ -239,16 +244,17 @@ def add(ctx, x: int, y: int) -> int: """ def wrapper(fn: any): - if hasattr(fn, '_dapr_alternate_name'): + if hasattr(fn, "_dapr_alternate_name"): raise ValueError( - f'Function {fn.__name__} already has an alternate name {fn._dapr_alternate_name}') - fn.__dict__['_dapr_alternate_name'] = name if name else fn.__name__ + f"Function {fn.__name__} already has an alternate name {fn._dapr_alternate_name}" + ) + fn.__dict__["_dapr_alternate_name"] = name if name else fn.__name__ @wraps(fn) def innerfn(*args, **kwargs): return fn(*args, **kwargs) - innerfn.__dict__['_dapr_alternate_name'] = name if name else fn.__name__ + innerfn.__dict__["_dapr_alternate_name"] = name if name else fn.__name__ innerfn.__signature__ = inspect.signature(fn) return innerfn diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_state.py b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_state.py index 760456bb9..0b45c1fba 100644 --- 
a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_state.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_state.py @@ -61,17 +61,17 @@ def __str__(self) -> str: def to_json(self): return { - 'instance_id': self.__obj.instance_id, - 'name': self.__obj.name, - 'runtime_status': self.__obj.runtime_status.name, - 'created_at': self.__obj.created_at, - 'last_updated_at': self.__obj.last_updated_at, - 'serialized_input': self.__obj.serialized_input, - 'serialized_output': self.__obj.serialized_output, - 'serialized_custom_status': self.__obj.serialized_custom_status, - 'failure_details': { - 'message': self.__obj.failure_details.message, - 'error_type': self.__obj.failure_details.error_type, - 'stack_trace': self.__obj.failure_details.stack_trace - } + "instance_id": self.__obj.instance_id, + "name": self.__obj.name, + "runtime_status": self.__obj.runtime_status.name, + "created_at": self.__obj.created_at, + "last_updated_at": self.__obj.last_updated_at, + "serialized_input": self.__obj.serialized_input, + "serialized_output": self.__obj.serialized_output, + "serialized_custom_status": self.__obj.serialized_custom_status, + "failure_details": { + "message": self.__obj.failure_details.message, + "error_type": self.__obj.failure_details.error_type, + "stack_trace": self.__obj.failure_details.stack_trace, + }, } diff --git a/ext/dapr-ext-workflow/setup.py b/ext/dapr-ext-workflow/setup.py index 7565f43ae..26043a7d5 100644 --- a/ext/dapr-ext-workflow/setup.py +++ b/ext/dapr-ext-workflow/setup.py @@ -19,19 +19,19 @@ # Load version in dapr package. version_info = {} -with open('dapr/ext/workflow/version.py') as fp: +with open("dapr/ext/workflow/version.py") as fp: exec(fp.read(), version_info) -__version__ = version_info['__version__'] +__version__ = version_info["__version__"] def is_release(): - return '.dev' not in __version__ + return ".dev" not in __version__ -name = 'dapr-ext-workflow' +name = "dapr-ext-workflow" version = __version__ -description = 'The official release of Dapr Python SDK Workflow Authoring Extension.' -long_description = ''' +description = "The official release of Dapr Python SDK Workflow Authoring Extension." +long_description = """ This is the Workflow authoring extension for Dapr. Dapr is a portable, serverless, event-driven runtime that makes it easy for developers to @@ -42,18 +42,18 @@ def is_release(): independent, building blocks that enable you to build portable applications with the language and framework of your choice. Each building block is independent and you can use one, some, or all of them in your application. -'''.lstrip() +""".lstrip() # Get build number from GITHUB_RUN_NUMBER environment variable -build_number = os.environ.get('GITHUB_RUN_NUMBER', '0') +build_number = os.environ.get("GITHUB_RUN_NUMBER", "0") if not is_release(): - name += '-dev' - version = f'{__version__}{build_number}' - description = 'The developmental release for Dapr Workflow Authoring.' - long_description = 'This is the developmental release for Dapr Workflow Authoring.' + name += "-dev" + version = f"{__version__}{build_number}" + description = "The developmental release for Dapr Workflow Authoring." + long_description = "This is the developmental release for Dapr Workflow Authoring." 
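# A minimal end-to-end sketch of the workflow APIs reformatted above,
# assuming placeholder workflow/activity names, a placeholder input value,
# and a running Dapr sidecar.
from dapr.ext.workflow import (
    DaprWorkflowClient,
    DaprWorkflowContext,
    WorkflowActivityContext,
    WorkflowRuntime,
)

runtime = WorkflowRuntime()

@runtime.workflow(name="order_workflow")
def order_workflow(ctx: DaprWorkflowContext, order_id: str):
    # Workflows are generators: yielding a task suspends the orchestration
    # until the activity completes.
    result = yield ctx.call_activity(notify, input=order_id)
    return result

@runtime.activity(name="notify")
def notify(ctx: WorkflowActivityContext, order_id: str) -> str:
    return f"notified for {order_id}"

runtime.start()

client = DaprWorkflowClient()
instance_id = client.schedule_new_workflow(workflow=order_workflow, input="order-42")
state = client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60)
print(state.runtime_status.name if state else "workflow not found", flush=True)
runtime.shutdown()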
-print(f'package name: {name}, version: {version}', flush=True) +print(f"package name: {name}, version: {version}", flush=True) setup( diff --git a/ext/dapr-ext-workflow/tests/test_dapr_workflow_context.py b/ext/dapr-ext-workflow/tests/test_dapr_workflow_context.py index 9ea802f29..787624737 100644 --- a/ext/dapr-ext-workflow/tests/test_dapr_workflow_context.py +++ b/ext/dapr-ext-workflow/tests/test_dapr_workflow_context.py @@ -42,23 +42,25 @@ def call_sub_orchestrator(self, orchestrator, input, instance_id): class DaprWorkflowContextTest(unittest.TestCase): - def mock_client_activity(ctx: WorkflowActivityContext, input): - print(f'{input}!', flush=True) + print(f"{input}!", flush=True) def mock_client_child_wf(ctx: DaprWorkflowContext, input): - print(f'{input}') + print(f"{input}") def test_workflow_context_functions(self): - with mock.patch('durabletask.worker._RuntimeOrchestrationContext', - return_value=FakeOrchestrationContext()): + with mock.patch( + "durabletask.worker._RuntimeOrchestrationContext", + return_value=FakeOrchestrationContext(), + ): fakeContext = worker._RuntimeOrchestrationContext(mock_instance_id) dapr_wf_ctx = DaprWorkflowContext(fakeContext) call_activity_result = dapr_wf_ctx.call_activity(self.mock_client_activity, input=None) assert call_activity_result == mock_call_activity call_sub_orchestrator_result = dapr_wf_ctx.call_child_workflow( - self.mock_client_child_wf) + self.mock_client_child_wf + ) assert call_sub_orchestrator_result == mock_call_sub_orchestrator create_timer_result = dapr_wf_ctx.create_timer(mock_date_time) diff --git a/ext/dapr-ext-workflow/tests/test_workflow_activity_context.py b/ext/dapr-ext-workflow/tests/test_workflow_activity_context.py index ac0dfc9f2..177f1ea54 100644 --- a/ext/dapr-ext-workflow/tests/test_workflow_activity_context.py +++ b/ext/dapr-ext-workflow/tests/test_workflow_activity_context.py @@ -34,9 +34,10 @@ def task_id(self): class WorkflowActivityContextTest(unittest.TestCase): def test_workflow_activity_context(self): - with mock.patch('durabletask.task.ActivityContext', return_value=FakeActivityContext()): - fake_act_ctx = task.ActivityContext(orchestration_id=mock_orchestration_id, - task_id=mock_task) + with mock.patch("durabletask.task.ActivityContext", return_value=FakeActivityContext()): + fake_act_ctx = task.ActivityContext( + orchestration_id=mock_orchestration_id, task_id=mock_task + ) act_ctx = WorkflowActivityContext(fake_act_ctx) actual_orchestration_id = act_ctx.workflow_id assert actual_orchestration_id == mock_orchestration_id diff --git a/ext/dapr-ext-workflow/tests/test_workflow_client.py b/ext/dapr-ext-workflow/tests/test_workflow_client.py index 587a7df43..86b6f01b4 100644 --- a/ext/dapr-ext-workflow/tests/test_workflow_client.py +++ b/ext/dapr-ext-workflow/tests/test_workflow_client.py @@ -30,7 +30,6 @@ class FakeTaskHubGrpcClient: - def schedule_new_orchestration(self, workflow, input, instance_id, start_at): return mock_schedule_result @@ -38,19 +37,19 @@ def get_orchestration_state(self, instance_id, fetch_payloads): return self._inner_get_orchestration_state(instance_id, client.OrchestrationStatus.PENDING) def wait_for_orchestration_start(self, instance_id, fetch_payloads, timeout): - return self._inner_get_orchestration_state(instance_id, - client.OrchestrationStatus.RUNNING) + return self._inner_get_orchestration_state(instance_id, client.OrchestrationStatus.RUNNING) def wait_for_orchestration_completion(self, instance_id, fetch_payloads, timeout): - return 
self._inner_get_orchestration_state(instance_id, - client.OrchestrationStatus.COMPLETED) + return self._inner_get_orchestration_state( + instance_id, client.OrchestrationStatus.COMPLETED + ) - def raise_orchestration_event(self, instance_id: str, event_name: str, *, - data: Union[Any, None] = None): + def raise_orchestration_event( + self, instance_id: str, event_name: str, *, data: Union[Any, None] = None + ): return mock_raise_event_result - def terminate_orchestration(self, instance_id: str, *, - output: Union[Any, None] = None): + def terminate_orchestration(self, instance_id: str, *, output: Union[Any, None] = None): return mock_terminate_result def suspend_orchestration(self, instance_id: str): @@ -60,51 +59,59 @@ def resume_orchestration(self, instance_id: str): return mock_resume_result def _inner_get_orchestration_state(self, instance_id, state: client.OrchestrationStatus): - return client.OrchestrationState(instance_id=instance_id, name="", - runtime_status=state, - created_at=datetime.now(), - last_updated_at=datetime.now(), - serialized_input=None, - serialized_output=None, - serialized_custom_status=None, - failure_details=None) + return client.OrchestrationState( + instance_id=instance_id, + name="", + runtime_status=state, + created_at=datetime.now(), + last_updated_at=datetime.now(), + serialized_input=None, + serialized_output=None, + serialized_custom_status=None, + failure_details=None, + ) class WorkflowClientTest(unittest.TestCase): - def mock_client_wf(ctx: DaprWorkflowContext, input): - print(f'{input}') + print(f"{input}") def test_client_functions(self): - with mock.patch('durabletask.client.TaskHubGrpcClient', - return_value=FakeTaskHubGrpcClient()): + with mock.patch( + "durabletask.client.TaskHubGrpcClient", return_value=FakeTaskHubGrpcClient() + ): wfClient = DaprWorkflowClient() - actual_schedule_result = wfClient.schedule_new_workflow(workflow=self.mock_client_wf, - input='Hi Chef!') + actual_schedule_result = wfClient.schedule_new_workflow( + workflow=self.mock_client_wf, input="Hi Chef!" 
+ ) assert actual_schedule_result == mock_schedule_result - actual_get_result = wfClient.get_workflow_state(instance_id=mockInstanceId, - fetch_payloads=True) + actual_get_result = wfClient.get_workflow_state( + instance_id=mockInstanceId, fetch_payloads=True + ) assert actual_get_result.runtime_status.name == "PENDING" assert actual_get_result.instance_id == mockInstanceId - actual_wait_start_result = wfClient.wait_for_workflow_start(instance_id=mockInstanceId, - timeout_in_seconds=30) + actual_wait_start_result = wfClient.wait_for_workflow_start( + instance_id=mockInstanceId, timeout_in_seconds=30 + ) assert actual_wait_start_result.runtime_status.name == "RUNNING" assert actual_wait_start_result.instance_id == mockInstanceId actual_wait_completion_result = wfClient.wait_for_workflow_completion( - instance_id=mockInstanceId, timeout_in_seconds=30) + instance_id=mockInstanceId, timeout_in_seconds=30 + ) assert actual_wait_completion_result.runtime_status.name == "COMPLETED" assert actual_wait_completion_result.instance_id == mockInstanceId - actual_raise_event_result = wfClient.raise_workflow_event(instance_id=mockInstanceId, - event_name="test_event", - data="test_data") + actual_raise_event_result = wfClient.raise_workflow_event( + instance_id=mockInstanceId, event_name="test_event", data="test_data" + ) assert actual_raise_event_result == mock_raise_event_result - actual_terminate_result = wfClient.terminate_workflow(instance_id=mockInstanceId, - output="test_output") + actual_terminate_result = wfClient.terminate_workflow( + instance_id=mockInstanceId, output="test_output" + ) assert actual_terminate_result == mock_terminate_result actual_suspend_result = wfClient.pause_workflow(instance_id=mockInstanceId) diff --git a/ext/dapr-ext-workflow/tests/test_workflow_runtime.py b/ext/dapr-ext-workflow/tests/test_workflow_runtime.py index c499a7851..2310af368 100644 --- a/ext/dapr-ext-workflow/tests/test_workflow_runtime.py +++ b/ext/dapr-ext-workflow/tests/test_workflow_runtime.py @@ -33,11 +33,10 @@ def add_named_activity(self, name: str, fn): class WorkflowRuntimeTest(unittest.TestCase): - def setUp(self): listActivities.clear() listOrchestrators.clear() - mock.patch('durabletask.worker._Registry', return_value=FakeTaskHubGrpcWorker()).start() + mock.patch("durabletask.worker._Registry", return_value=FakeTaskHubGrpcWorker()).start() self.runtime_options = WorkflowRuntime() if hasattr(self.mock_client_wf, "_dapr_alternate_name"): del self.mock_client_wf.__dict__["_dapr_alternate_name"] @@ -49,10 +48,10 @@ def setUp(self): del self.mock_client_activity.__dict__["_activity_registered"] def mock_client_wf(ctx: DaprWorkflowContext, input): - print(f'{input}') + print(f"{input}") def mock_client_activity(ctx: WorkflowActivityContext, input): - print(f'{input}!', flush=True) + print(f"{input}!", flush=True) def test_register(self): self.runtime_options.register_workflow(self.mock_client_wf, name="mock_client_wf") @@ -101,8 +100,10 @@ def test_register_wf_act_using_both_decorator_and_method(self): with self.assertRaises(ValueError) as exeception_context: self.runtime_options.register_workflow(self.mock_client_wf) wf_name = self.mock_client_wf.__name__ - self.assertEqual(exeception_context.exception.args[0], - f'Workflow {wf_name} already registered as test_wf') + self.assertEqual( + exeception_context.exception.args[0], + f"Workflow {wf_name} already registered as test_wf", + ) client_act = (self.runtime_options.activity(name="test_act"))(self.mock_client_activity) wanted_activity = ["test_act"] 
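For readers skimming these reformatted registration tests: a minimal usage sketch of the semantics they exercise (names illustrative, assuming the public dapr-ext-workflow API):

from dapr.ext.workflow import WorkflowRuntime, DaprWorkflowContext

runtime = WorkflowRuntime()

@runtime.workflow(name="test_wf")  # registered under the explicit name "test_wf"
def my_workflow(ctx: DaprWorkflowContext, wf_input):
    print(f"{wf_input}")

# Registering the same function a second time is rejected, matching the
# assertions in these tests:
# runtime.register_workflow(my_workflow)
# -> ValueError: Workflow my_workflow already registered as test_wf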
@@ -111,21 +112,27 @@ def test_register_wf_act_using_both_decorator_and_method(self): with self.assertRaises(ValueError) as exeception_context: self.runtime_options.register_activity(self.mock_client_activity) act_name = self.mock_client_activity.__name__ - self.assertEqual(exeception_context.exception.args[0], - f'Activity {act_name} already registered as test_act') + self.assertEqual( + exeception_context.exception.args[0], + f"Activity {act_name} already registered as test_act", + ) def test_duplicate_dapr_alternate_name_registration(self): client_wf = (alternate_name(name="test"))(self.mock_client_wf) with self.assertRaises(ValueError) as exeception_context: (self.runtime_options.workflow(name="random"))(client_wf) - self.assertEqual(exeception_context.exception.args[0], - f'Workflow {client_wf.__name__} already has an alternate name test') + self.assertEqual( + exeception_context.exception.args[0], + f"Workflow {client_wf.__name__} already has an alternate name test", + ) client_act = (alternate_name(name="test"))(self.mock_client_activity) with self.assertRaises(ValueError) as exeception_context: (self.runtime_options.activity(name="random"))(client_act) - self.assertEqual(exeception_context.exception.args[0], - f'Activity {client_act.__name__} already has an alternate name test') + self.assertEqual( + exeception_context.exception.args[0], + f"Activity {client_act.__name__} already has an alternate name test", + ) def test_register_wf_act_using_both_decorator_and_method_without_name(self): client_wf = (self.runtime_options.workflow())(self.mock_client_wf) @@ -136,8 +143,10 @@ def test_register_wf_act_using_both_decorator_and_method_without_name(self): with self.assertRaises(ValueError) as exeception_context: self.runtime_options.register_workflow(self.mock_client_wf, name="test_wf") wf_name = self.mock_client_wf.__name__ - self.assertEqual(exeception_context.exception.args[0], - f'Workflow {wf_name} already registered as mock_client_wf') + self.assertEqual( + exeception_context.exception.args[0], + f"Workflow {wf_name} already registered as mock_client_wf", + ) client_act = (self.runtime_options.activity())(self.mock_client_activity) wanted_activity = ["mock_client_activity"] @@ -146,8 +155,10 @@ def test_register_wf_act_using_both_decorator_and_method_without_name(self): with self.assertRaises(ValueError) as exeception_context: self.runtime_options.register_activity(self.mock_client_activity, name="test_act") act_name = self.mock_client_activity.__name__ - self.assertEqual(exeception_context.exception.args[0], - f'Activity {act_name} already registered as mock_client_activity') + self.assertEqual( + exeception_context.exception.args[0], + f"Activity {act_name} already registered as mock_client_activity", + ) def test_decorator_register_optinal_name(self): client_wf = (self.runtime_options.workflow(name="test_wf"))(self.mock_client_wf) diff --git a/ext/dapr-ext-workflow/tests/test_workflow_util.py b/ext/dapr-ext-workflow/tests/test_workflow_util.py index 2fcab61df..bcad9e813 100644 --- a/ext/dapr-ext-workflow/tests/test_workflow_util.py +++ b/ext/dapr-ext-workflow/tests/test_workflow_util.py @@ -6,7 +6,6 @@ class DaprWorkflowUtilTest(unittest.TestCase): - def test_get_address_default(self): expected = f"{settings.DAPR_RUNTIME_HOST}:{settings.DAPR_GRPC_PORT}" self.assertEqual(expected, getAddress()) diff --git a/ext/flask_dapr/flask_dapr/__init__.py b/ext/flask_dapr/flask_dapr/__init__.py index e43df65c9..4694e3899 100644 --- a/ext/flask_dapr/flask_dapr/__init__.py +++ 
b/ext/flask_dapr/flask_dapr/__init__.py @@ -16,4 +16,4 @@ from .actor import DaprActor from .app import DaprApp -__all__ = ['DaprActor', 'DaprApp'] +__all__ = ["DaprActor", "DaprApp"] diff --git a/ext/flask_dapr/flask_dapr/actor.py b/ext/flask_dapr/flask_dapr/actor.py index 07b803177..effd5205a 100644 --- a/ext/flask_dapr/flask_dapr/actor.py +++ b/ext/flask_dapr/flask_dapr/actor.py @@ -23,7 +23,7 @@ from dapr.serializers import DefaultJSONSerializer DEFAULT_CONTENT_TYPE = "application/json; utf-8" -DAPR_REENTRANCY_ID_HEADER = 'Dapr-Reentrancy-Id' +DAPR_REENTRANCY_ID_HEADER = "Dapr-Reentrancy-Id" class DaprActor(object): @@ -35,46 +35,42 @@ def __init__(self, app=None): self.init_routes(app) def init_routes(self, app): + app.add_url_rule("/healthz", None, self._healthz_handler, methods=["GET"]) + app.add_url_rule("/dapr/config", None, self._config_handler, methods=["GET"]) app.add_url_rule( - '/healthz', None, - self._healthz_handler, - methods=['GET'] - ) - app.add_url_rule( - '/dapr/config', None, - self._config_handler, - methods=['GET'] - ) - app.add_url_rule( - '/actors//', None, + "/actors//", + None, self._deactivation_handler, - methods=['DELETE'] + methods=["DELETE"], ) app.add_url_rule( - '/actors///method/', None, + "/actors///method/", + None, self._method_handler, - methods=['PUT'] + methods=["PUT"], ) app.add_url_rule( - '/actors///method/timer/', None, + "/actors///method/timer/", + None, self._timer_handler, - methods=['PUT'] + methods=["PUT"], ) app.add_url_rule( - '/actors///method/remind/', None, + "/actors///method/remind/", + None, self._reminder_handler, - methods=['PUT'] + methods=["PUT"], ) def teardown(self, exception): - self._app.logger.debug('actor service is shutting down.') + self._app.logger.debug("actor service is shutting down.") def register_actor(self, actor: Type[Actor]) -> None: asyncio.run(ActorRuntime.register_actor(actor)) - self._app.logger.debug(f'registered actor: {actor.__class__.__name__}') + self._app.logger.debug(f"registered actor: {actor.__class__.__name__}") def _healthz_handler(self): - return wrap_response(200, 'ok') + return wrap_response(200, "ok") def _config_handler(self): serialized = self._dapr_serializer.serialize(ActorRuntime.get_actor_config()) @@ -88,7 +84,7 @@ def _deactivation_handler(self, actor_type_name, actor_id): except Exception as ex: return wrap_response(500, repr(ex), ERROR_CODE_UNKNOWN) - msg = f'deactivated actor: {actor_type_name}.{actor_id}' + msg = f"deactivated actor: {actor_type_name}.{actor_id}" self._app.logger.debug(msg) return wrap_response(200, msg) @@ -97,14 +93,17 @@ def _method_handler(self, actor_type_name, actor_id, method_name): # Read raw bytes from request stream req_body = request.stream.read() reentrancy_id = request.headers.get(DAPR_REENTRANCY_ID_HEADER) - result = asyncio.run(ActorRuntime.dispatch( - actor_type_name, actor_id, method_name, req_body, reentrancy_id)) + result = asyncio.run( + ActorRuntime.dispatch( + actor_type_name, actor_id, method_name, req_body, reentrancy_id + ) + ) except DaprInternalError as ex: return wrap_response(500, ex.as_dict()) except Exception as ex: return wrap_response(500, repr(ex), ERROR_CODE_UNKNOWN) - msg = f'called method. actor: {actor_type_name}.{actor_id}, method: {method_name}' + msg = f"called method. 
actor: {actor_type_name}.{actor_id}, method: {method_name}" self._app.logger.debug(msg) return wrap_response(200, result) @@ -118,7 +117,7 @@ def _timer_handler(self, actor_type_name, actor_id, timer_name): except Exception as ex: return wrap_response(500, repr(ex), ERROR_CODE_UNKNOWN) - msg = f'called timer. actor: {actor_type_name}.{actor_id}, timer: {timer_name}' + msg = f"called timer. actor: {actor_type_name}.{actor_id}, timer: {timer_name}" self._app.logger.debug(msg) return wrap_response(200, msg) @@ -126,33 +125,34 @@ def _reminder_handler(self, actor_type_name, actor_id, reminder_name): try: # Read raw bytes from request stream req_body = request.stream.read() - asyncio.run(ActorRuntime.fire_reminder( - actor_type_name, actor_id, reminder_name, req_body)) + asyncio.run( + ActorRuntime.fire_reminder(actor_type_name, actor_id, reminder_name, req_body) + ) except DaprInternalError as ex: return wrap_response(500, ex.as_dict()) except Exception as ex: return wrap_response(500, repr(ex), ERROR_CODE_UNKNOWN) - msg = f'called reminder. actor: {actor_type_name}.{actor_id}, reminder: {reminder_name}' + msg = f"called reminder. actor: {actor_type_name}.{actor_id}, reminder: {reminder_name}" self._app.logger.debug(msg) return wrap_response(200, msg) # wrap_response wraps dapr errors to flask response def wrap_response( - status: int, msg: Any, - error_code: Optional[str] = None, content_type: Optional[str] = None): + status: int, msg: Any, error_code: Optional[str] = None, content_type: Optional[str] = None +): resp = None if isinstance(msg, str): response_obj = { - 'message': msg, + "message": msg, } if not (status >= 200 and status < 300) and error_code: - response_obj['errorCode'] = error_code + response_obj["errorCode"] = error_code resp = make_response(jsonify(response_obj), status) elif isinstance(msg, bytes): resp = make_response(msg, status) else: resp = make_response(jsonify(msg), status) - resp.headers['Content-type'] = content_type or DEFAULT_CONTENT_TYPE + resp.headers["Content-type"] = content_type or DEFAULT_CONTENT_TYPE return resp diff --git a/ext/flask_dapr/flask_dapr/app.py b/ext/flask_dapr/flask_dapr/app.py index 3fbc39ab6..b934d587e 100644 --- a/ext/flask_dapr/flask_dapr/app.py +++ b/ext/flask_dapr/flask_dapr/app.py @@ -29,17 +29,18 @@ def __init__(self, app_instance: Flask): self._app = app_instance self._subscriptions: List[Dict[str, object]] = [] - self._app.add_url_rule('/dapr/subscribe', - '/dapr/subscribe', - self._get_subscriptions, - methods=["GET"]) - - def subscribe(self, - pubsub: str, - topic: str, - metadata: Optional[Dict[str, str]] = {}, - route: Optional[str] = None, - dead_letter_topic: Optional[str] = None): + self._app.add_url_rule( + "/dapr/subscribe", "/dapr/subscribe", self._get_subscriptions, methods=["GET"] + ) + + def subscribe( + self, + pubsub: str, + topic: str, + metadata: Optional[Dict[str, str]] = {}, + route: Optional[str] = None, + dead_letter_topic: Optional[str] = None, + ): """ Subscribes to a topic on a pub/sub component. @@ -71,21 +72,25 @@ def subscribe(self, Returns: The decorator for the function. 
""" + def decorator(func): event_handler_route = f"/events/{pubsub}/{topic}" if route is None else route - self._app.add_url_rule(event_handler_route, - event_handler_route, - func, - methods=["POST"]) - - self._subscriptions.append({ - "pubsubname": pubsub, - "topic": topic, - "route": event_handler_route, - "metadata": metadata, - **({"deadLetterTopic": dead_letter_topic} if dead_letter_topic is not None else {}) - }) + self._app.add_url_rule(event_handler_route, event_handler_route, func, methods=["POST"]) + + self._subscriptions.append( + { + "pubsubname": pubsub, + "topic": topic, + "route": event_handler_route, + "metadata": metadata, + **( + {"deadLetterTopic": dead_letter_topic} + if dead_letter_topic is not None + else {} + ), + } + ) return decorator diff --git a/ext/flask_dapr/setup.py b/ext/flask_dapr/setup.py index 6142926c7..6785fc9c2 100644 --- a/ext/flask_dapr/setup.py +++ b/ext/flask_dapr/setup.py @@ -19,19 +19,19 @@ # Load version in dapr package. version_info = {} -with open('flask_dapr/version.py') as fp: +with open("flask_dapr/version.py") as fp: exec(fp.read(), version_info) -__version__ = version_info['__version__'] +__version__ = version_info["__version__"] def is_release(): - return '.dev' not in __version__ + return ".dev" not in __version__ -name = 'flask-dapr' +name = "flask-dapr" version = __version__ -description = 'The official release of Dapr Python SDK Flask Extension.' -long_description = ''' +description = "The official release of Dapr Python SDK Flask Extension." +long_description = """ This is the Flask extension for Dapr. Dapr is a portable, serverless, event-driven runtime that makes it easy for developers to @@ -42,18 +42,18 @@ def is_release(): independent, building blocks that enable you to build portable applications with the language and framework of your choice. Each building block is independent and you can use one, some, or all of them in your application. -'''.lstrip() +""".lstrip() # Get build number from GITHUB_RUN_NUMBER environment variable -build_number = os.environ.get('GITHUB_RUN_NUMBER', '0') +build_number = os.environ.get("GITHUB_RUN_NUMBER", "0") if not is_release(): - name += '-dev' - version = f'{__version__}{build_number}' - description = 'The developmental release for Dapr Python SDK Flask.' - long_description = 'This is the developmental release for Dapr Python SDK Flask.' + name += "-dev" + version = f"{__version__}{build_number}" + description = "The developmental release for Dapr Python SDK Flask." + long_description = "This is the developmental release for Dapr Python SDK Flask." 
-print(f'package name: {name}, version: {version}', flush=True) +print(f"package name: {name}, version: {version}", flush=True) setup( diff --git a/ext/flask_dapr/tests/test_app.py b/ext/flask_dapr/tests/test_app.py index 44f05c376..7aeb19b72 100644 --- a/ext/flask_dapr/tests/test_app.py +++ b/ext/flask_dapr/tests/test_app.py @@ -25,11 +25,16 @@ def event_handler(): response = self.client.get("/dapr/subscribe") self.assertEqual( - [{'pubsubname': 'pubsub', - 'topic': 'test', - 'route': '/events/pubsub/test', - 'metadata': {} - }], json.loads(response.data)) + [ + { + "pubsubname": "pubsub", + "topic": "test", + "route": "/events/pubsub/test", + "metadata": {}, + } + ], + json.loads(response.data), + ) response = self.client.post("/events/pubsub/test", json={"body": "new message"}) self.assertEqual(response.status_code, 200) @@ -47,22 +52,18 @@ def event_handler(): response = self.client.get("/dapr/subscribe") self.assertEqual( - [{'pubsubname': 'pubsub', - 'topic': 'test', - 'route': '/do-something', - 'metadata': {} - }], json.loads(response.data)) + [{"pubsubname": "pubsub", "topic": "test", "route": "/do-something", "metadata": {}}], + json.loads(response.data), + ) response = self.client.post("/do-something", json={"body": "new message"}) self.assertEqual(response.status_code, 200) - self.assertEqual(response.data.decode("utf-8"), 'custom route') + self.assertEqual(response.data.decode("utf-8"), "custom route") def test_subscribe_metadata(self): handler_metadata = {"rawPayload": "true"} - @self.dapr_app.subscribe(pubsub="pubsub", - topic="test", - metadata=handler_metadata) + @self.dapr_app.subscribe(pubsub="pubsub", topic="test", metadata=handler_metadata) def event_handler(): return "custom metadata" @@ -70,11 +71,16 @@ def event_handler(): response = self.client.get("/dapr/subscribe") self.assertEqual( - [{'pubsubname': 'pubsub', - 'topic': 'test', - 'route': '/events/pubsub/test', - 'metadata': {"rawPayload": "true"} - }], json.loads(response.data)) + [ + { + "pubsubname": "pubsub", + "topic": "test", + "route": "/events/pubsub/test", + "metadata": {"rawPayload": "true"}, + } + ], + json.loads(response.data), + ) response = self.client.post("/events/pubsub/test", json={"body": "new message"}) self.assertEqual(response.status_code, 200) @@ -83,9 +89,7 @@ def event_handler(): def test_subscribe_dead_letter(self): dead_letter_topic = "dead-test" - @self.dapr_app.subscribe(pubsub="pubsub", - topic="test", - dead_letter_topic=dead_letter_topic) + @self.dapr_app.subscribe(pubsub="pubsub", topic="test", dead_letter_topic=dead_letter_topic) def event_handler(): return "dead letter test" @@ -93,17 +97,22 @@ def event_handler(): response = self.client.get("/dapr/subscribe") self.assertEqual( - [{'pubsubname': 'pubsub', - 'topic': 'test', - 'route': '/events/pubsub/test', - 'metadata': {}, - 'deadLetterTopic': dead_letter_topic - }], json.loads(response.data)) + [ + { + "pubsubname": "pubsub", + "topic": "test", + "route": "/events/pubsub/test", + "metadata": {}, + "deadLetterTopic": dead_letter_topic, + } + ], + json.loads(response.data), + ) response = self.client.post("/events/pubsub/test", json={"body": "new message"}) self.assertEqual(response.status_code, 200) self.assertEqual(response.data.decode("utf-8"), "dead letter test") -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..850327ed3 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,18 @@ +[tool.ruff] +target-version = "py38" 
+line-length = 100 +fix = true +extend-exclude = [".github", "dapr/proto"] +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # pyflakes + "I", # isort + "C", # flake8-comprehensions + "B", # flake8-bugbear + "UP", # pyupgrade +] +ignore = [ + # Undefined name {name} + "F821", +] diff --git a/setup.cfg b/setup.cfg index cfe3ff888..96699abec 100644 --- a/setup.cfg +++ b/setup.cfg @@ -63,5 +63,5 @@ exclude = .tox, dapr/proto, examples -ignore = F821 +ignore = F821, E501, W503, E203 max-line-length = 100 diff --git a/setup.py b/setup.py index 56b616a65..e0bc28386 100644 --- a/setup.py +++ b/setup.py @@ -19,19 +19,19 @@ # Load version in dapr package. version_info = {} -with open('dapr/version/version.py') as fp: +with open("dapr/version/version.py") as fp: exec(fp.read(), version_info) -__version__ = version_info['__version__'] +__version__ = version_info["__version__"] def is_release(): - return '.dev' not in __version__ + return ".dev" not in __version__ -name = 'dapr' +name = "dapr" version = __version__ -description = 'The official release of Dapr Python SDK.' -long_description = ''' +description = "The official release of Dapr Python SDK." +long_description = """ Dapr is a portable, serverless, event-driven runtime that makes it easy for developers to build resilient, stateless and stateful microservices that run on the cloud and edge and embraces the diversity of languages and developer frameworks. @@ -40,18 +40,18 @@ def is_release(): independent, building blocks that enable you to build portable applications with the language and framework of your choice. Each building block is independent and you can use one, some, or all of them in your application. -'''.lstrip() +""".lstrip() # Get build number from GITHUB_RUN_NUMBER environment variable -build_number = os.environ.get('GITHUB_RUN_NUMBER', '0') +build_number = os.environ.get("GITHUB_RUN_NUMBER", "0") if not is_release(): - name += '-dev' - version = f'{__version__}{build_number}' - description = 'The developmental release for Dapr Python SDK.' - long_description = 'This is the developmental release for Dapr Python SDK.' + name += "-dev" + version = f"{__version__}{build_number}" + description = "The developmental release for Dapr Python SDK." + long_description = "This is the developmental release for Dapr Python SDK." 
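The net effect of the ruff configuration added above is visible throughout this patch: lines are wrapped at 100 columns and the autoformatter normalizes string literals to double quotes, for example:

# Before:
print(f'package name: {name}, version: {version}', flush=True)
# After ruff autoformat:
print(f"package name: {name}, version: {version}", flush=True)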
-print(f'package name: {name}, version: {version}', flush=True) +print(f"package name: {name}, version: {version}", flush=True) setup( diff --git a/tests/actor/fake_actor_classes.py b/tests/actor/fake_actor_classes.py index 97234cffa..405a755b4 100644 --- a/tests/actor/fake_actor_classes.py +++ b/tests/actor/fake_actor_classes.py @@ -37,7 +37,7 @@ def __init__(self, ctx, actor_id): super(FakeSimpleActor, self).__init__(ctx, actor_id) async def actor_method(self, arg: int) -> dict: - return {'name': 'actor_method'} + return {"name": "actor_method"} async def non_actor_method(self, arg0: int, arg1: str, arg2: float) -> str: pass @@ -48,14 +48,19 @@ def __init__(self, ctx, actor_id): super(FakeSimpleReminderActor, self).__init__(ctx, actor_id) async def actor_method(self, arg: int) -> dict: - return {'name': 'actor_method'} + return {"name": "actor_method"} async def non_actor_method(self, arg0: int, arg1: str, arg2: float) -> str: pass - async def receive_reminder(self, name: str, state: bytes, - due_time: timedelta, period: timedelta, - ttl: Optional[timedelta]) -> None: + async def receive_reminder( + self, + name: str, + state: bytes, + due_time: timedelta, + period: timedelta, + ttl: Optional[timedelta], + ) -> None: pass @@ -65,14 +70,19 @@ def __init__(self, ctx, actor_id): self.timer_called = False async def actor_method(self, arg: int) -> dict: - return {'name': 'actor_method'} + return {"name": "actor_method"} async def timer_callback(self, obj) -> None: self.timer_called = True - async def receive_reminder(self, name: str, state: bytes, - due_time: timedelta, period: timedelta, - ttl: Optional[timedelta]) -> None: + async def receive_reminder( + self, + name: str, + state: bytes, + due_time: timedelta, + period: timedelta, + ttl: Optional[timedelta], + ) -> None: pass @@ -115,8 +125,9 @@ async def reentrant_pass_through_method(self, arg): ... 
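The fake actor classes in this file implement interfaces such as FakeActorCls2Interface. For orientation, a sketch of how actor interfaces are declared in this SDK (assuming dapr.actor's ActorInterface and actormethod; signature illustrative):

from dapr.actor import ActorInterface, actormethod

class FakeActorCls2Interface(ActorInterface):
    # "ActionMethod" is the name remote callers use, matching the
    # invoke_method("ActionMethod", ...) calls in the tests below.
    @actormethod(name="ActionMethod")
    async def action(self, data: object) -> str:
        ...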
-class FakeMultiInterfacesActor(Actor, FakeActorCls1Interface, FakeActorCls2Interface, - ReentrantActorInterface): +class FakeMultiInterfacesActor( + Actor, FakeActorCls1Interface, FakeActorCls2Interface, ReentrantActorInterface +): def __init__(self, ctx, actor_id): super(FakeMultiInterfacesActor, self).__init__(ctx, actor_id) self.activated = False @@ -137,11 +148,11 @@ async def actor_cls2_method(self, arg): async def action(self, data: object) -> str: self.action_data = data - return self.action_data['message'] + return self.action_data["message"] async def action_no_arg(self) -> str: - self.action_data = {'message': 'no_arg'} - return self.action_data['message'] + self.action_data = {"message": "no_arg"} + return self.action_data["message"] async def _on_activate(self): self.activated = True @@ -153,7 +164,7 @@ async def _on_deactivate(self): async def reentrant_method(self, data: object) -> str: self.action_data = data - return self.action_data['message'] + return self.action_data["message"] async def reentrant_pass_through_method(self, arg): pass @@ -168,8 +179,10 @@ async def reentrant_method(self, data: object) -> str: async def reentrant_pass_through_method(self, arg): from dapr.actor.client import proxy + await proxy.DaprActorHttpClient(DefaultJSONSerializer()).invoke_method( - FakeSlowReentrantActor.__name__, 'test-id', 'ReentrantMethod') + FakeSlowReentrantActor.__name__, "test-id", "ReentrantMethod" + ) async def actor_cls1_method(self, arg): pass @@ -193,7 +206,8 @@ async def reentrant_pass_through_method(self, arg): from dapr.actor.client import proxy await proxy.DaprActorHttpClient(DefaultJSONSerializer()).invoke_method( - FakeReentrantActor.__name__, 'test-id', 'ReentrantMethod') + FakeReentrantActor.__name__, "test-id", "ReentrantMethod" + ) async def actor_cls2_method(self, arg): pass diff --git a/tests/actor/fake_client.py b/tests/actor/fake_client.py index 18db0218e..fa5fe1577 100644 --- a/tests/actor/fake_client.py +++ b/tests/actor/fake_client.py @@ -20,64 +20,53 @@ # Fake Dapr Actor Client Base Class for testing class FakeDaprActorClientBase(DaprActorClientBase): async def invoke_method( - self, actor_type: str, actor_id: str, - method: str, data: Optional[bytes] = None) -> bytes: + self, actor_type: str, actor_id: str, method: str, data: Optional[bytes] = None + ) -> bytes: ... - async def save_state_transactionally( - self, actor_type: str, actor_id: str, - data: bytes) -> None: + async def save_state_transactionally(self, actor_type: str, actor_id: str, data: bytes) -> None: ... - async def get_state( - self, actor_type: str, actor_id: str, name: str) -> bytes: + async def get_state(self, actor_type: str, actor_id: str, name: str) -> bytes: ... async def register_reminder( - self, actor_type: str, actor_id: str, name: str, data: bytes) -> None: + self, actor_type: str, actor_id: str, name: str, data: bytes + ) -> None: ... - async def unregister_reminder( - self, actor_type: str, actor_id: str, name: str) -> None: + async def unregister_reminder(self, actor_type: str, actor_id: str, name: str) -> None: ... - async def register_timer( - self, actor_type: str, actor_id: str, name: str, data: bytes) -> None: + async def register_timer(self, actor_type: str, actor_id: str, name: str, data: bytes) -> None: ... - async def unregister_timer( - self, actor_type: str, actor_id: str, name: str) -> None: + async def unregister_timer(self, actor_type: str, actor_id: str, name: str) -> None: ... 
class FakeDaprActorClient(FakeDaprActorClientBase): async def invoke_method( - self, actor_type: str, actor_id: str, - method: str, data: Optional[bytes] = None) -> bytes: + self, actor_type: str, actor_id: str, method: str, data: Optional[bytes] = None + ) -> bytes: return b'"expected_response"' - async def save_state_transactionally( - self, actor_type: str, actor_id: str, - data: bytes) -> None: + async def save_state_transactionally(self, actor_type: str, actor_id: str, data: bytes) -> None: pass - async def get_state( - self, actor_type: str, actor_id: str, name: str) -> bytes: + async def get_state(self, actor_type: str, actor_id: str, name: str) -> bytes: return b'"expected_response"' async def register_reminder( - self, actor_type: str, actor_id: str, - name: str, data: bytes) -> None: + self, actor_type: str, actor_id: str, name: str, data: bytes + ) -> None: pass - async def unregister_reminder( - self, actor_type: str, actor_id: str, name: str) -> None: + async def unregister_reminder(self, actor_type: str, actor_id: str, name: str) -> None: pass - async def register_timer( - self, actor_type: str, actor_id: str, name: str, data: bytes) -> None: + async def register_timer(self, actor_type: str, actor_id: str, name: str, data: bytes) -> None: pass - async def unregister_timer( - self, actor_type: str, actor_id: str, name: str) -> None: + async def unregister_timer(self, actor_type: str, actor_id: str, name: str) -> None: pass diff --git a/tests/actor/test_actor.py b/tests/actor/test_actor.py index cf2c13d72..189f9996b 100644 --- a/tests/actor/test_actor.py +++ b/tests/actor/test_actor.py @@ -34,10 +34,7 @@ from tests.actor.fake_client import FakeDaprActorClient -from tests.actor.utils import ( - _async_mock, - _run -) +from tests.actor.utils import _async_mock, _run class ActorTests(unittest.TestCase): @@ -50,7 +47,7 @@ def setUp(self): def test_get_registered_actor_types(self): actor_types = ActorRuntime.get_registered_actor_types() - self.assertTrue(actor_types.index('FakeSimpleActor') >= 0) + self.assertTrue(actor_types.index("FakeSimpleActor") >= 0) self.assertTrue(actor_types.index(FakeMultiInterfacesActor.__name__) >= 0) def test_actor_config(self): @@ -64,7 +61,8 @@ def test_actor_config(self): # apply new config new_config = ActorRuntimeConfig( - timedelta(hours=3), timedelta(seconds=10), timedelta(minutes=1), False) + timedelta(hours=3), timedelta(seconds=10), timedelta(minutes=1), False + ) ActorRuntime.set_actor_config(new_config) config = ActorRuntime.get_actor_config() @@ -95,82 +93,103 @@ def test_dispatch(self): } test_request_body = self._serializer.serialize(request_body) - response = _run(ActorRuntime.dispatch( - FakeMultiInterfacesActor.__name__, 'test-id', - "ActionMethod", test_request_body)) + response = _run( + ActorRuntime.dispatch( + FakeMultiInterfacesActor.__name__, "test-id", "ActionMethod", test_request_body + ) + ) self.assertEqual(b'"hello dapr"', response) - _run(ActorRuntime.deactivate(FakeMultiInterfacesActor.__name__, 'test-id')) + _run(ActorRuntime.deactivate(FakeMultiInterfacesActor.__name__, "test-id")) # Ensure test-id is deactivated with self.assertRaises(ValueError): - _run(ActorRuntime.deactivate(FakeMultiInterfacesActor.__name__, 'test-id')) + _run(ActorRuntime.deactivate(FakeMultiInterfacesActor.__name__, "test-id")) @mock.patch( - 'tests.actor.fake_client.FakeDaprActorClient.register_reminder', - new=_async_mock(return_value=b'"ok"')) + "tests.actor.fake_client.FakeDaprActorClient.register_reminder", + 
new=_async_mock(return_value=b'"ok"'), + ) @mock.patch( - 'tests.actor.fake_client.FakeDaprActorClient.unregister_reminder', - new=_async_mock(return_value=b'"ok"')) + "tests.actor.fake_client.FakeDaprActorClient.unregister_reminder", + new=_async_mock(return_value=b'"ok"'), + ) def test_register_reminder(self): - - test_actor_id = ActorId('test_id') + test_actor_id = ActorId("test_id") test_type_info = ActorTypeInformation.create(FakeSimpleReminderActor) test_client = FakeDaprActorClient - ctx = ActorRuntimeContext( - test_type_info, self._serializer, - self._serializer, test_client) + ctx = ActorRuntimeContext(test_type_info, self._serializer, self._serializer, test_client) test_actor = FakeSimpleReminderActor(ctx, test_actor_id) # register reminder - _run(test_actor.register_reminder( - 'test_reminder', b'reminder_message', - timedelta(seconds=1), timedelta(seconds=1))) + _run( + test_actor.register_reminder( + "test_reminder", b"reminder_message", timedelta(seconds=1), timedelta(seconds=1) + ) + ) test_client.register_reminder.mock.assert_called_once() test_client.register_reminder.mock.assert_called_with( - 'FakeSimpleReminderActor', 'test_id', - 'test_reminder', - b'{"reminderName":"test_reminder","dueTime":"0h0m1s0ms0\\u03bcs","period":"0h0m1s0ms0\\u03bcs","data":"cmVtaW5kZXJfbWVzc2FnZQ=="}') # noqa E501 + "FakeSimpleReminderActor", + "test_id", + "test_reminder", + b'{"reminderName":"test_reminder","dueTime":"0h0m1s0ms0\\u03bcs","period":"0h0m1s0ms0\\u03bcs","data":"cmVtaW5kZXJfbWVzc2FnZQ=="}', + ) # noqa E501 # unregister reminder - _run(test_actor.unregister_reminder('test_reminder')) + _run(test_actor.unregister_reminder("test_reminder")) test_client.unregister_reminder.mock.assert_called_once() test_client.unregister_reminder.mock.assert_called_with( - 'FakeSimpleReminderActor', 'test_id', 'test_reminder') + "FakeSimpleReminderActor", "test_id", "test_reminder" + ) @mock.patch( - 'tests.actor.fake_client.FakeDaprActorClient.register_timer', - new=_async_mock(return_value=b'"ok"')) + "tests.actor.fake_client.FakeDaprActorClient.register_timer", + new=_async_mock(return_value=b'"ok"'), + ) @mock.patch( - 'tests.actor.fake_client.FakeDaprActorClient.unregister_timer', - new=_async_mock(return_value=b'"ok"')) + "tests.actor.fake_client.FakeDaprActorClient.unregister_timer", + new=_async_mock(return_value=b'"ok"'), + ) def test_register_timer(self): - - test_actor_id = ActorId('test_id') + test_actor_id = ActorId("test_id") test_type_info = ActorTypeInformation.create(FakeSimpleTimerActor) test_client = FakeDaprActorClient - ctx = ActorRuntimeContext( - test_type_info, self._serializer, - self._serializer, test_client) + ctx = ActorRuntimeContext(test_type_info, self._serializer, self._serializer, test_client) test_actor = FakeSimpleTimerActor(ctx, test_actor_id) # register timer - _run(test_actor.register_timer( - 'test_timer', test_actor.timer_callback, - "mydata", timedelta(seconds=1), timedelta(seconds=2))) + _run( + test_actor.register_timer( + "test_timer", + test_actor.timer_callback, + "mydata", + timedelta(seconds=1), + timedelta(seconds=2), + ) + ) test_client.register_timer.mock.assert_called_once() test_client.register_timer.mock.assert_called_with( - 'FakeSimpleTimerActor', 'test_id', 'test_timer', - b'{"callback":"timer_callback","data":"mydata","dueTime":"0h0m1s0ms0\\u03bcs","period":"0h0m2s0ms0\\u03bcs"}') # noqa E501 + "FakeSimpleTimerActor", + "test_id", + "test_timer", + 
b'{"callback":"timer_callback","data":"mydata","dueTime":"0h0m1s0ms0\\u03bcs","period":"0h0m2s0ms0\\u03bcs"}', + ) # noqa E501 # unregister timer - _run(test_actor.unregister_timer('test_timer')) + _run(test_actor.unregister_timer("test_timer")) test_client.unregister_timer.mock.assert_called_once() test_client.unregister_timer.mock.assert_called_with( - 'FakeSimpleTimerActor', 'test_id', 'test_timer') + "FakeSimpleTimerActor", "test_id", "test_timer" + ) # register timer without timer name - _run(test_actor.register_timer( - None, test_actor.timer_callback, - "timer call", timedelta(seconds=1), timedelta(seconds=1))) + _run( + test_actor.register_timer( + None, + test_actor.timer_callback, + "timer call", + timedelta(seconds=1), + timedelta(seconds=1), + ) + ) diff --git a/tests/actor/test_actor_id.py b/tests/actor/test_actor_id.py index 7fa341411..b9de8d192 100644 --- a/tests/actor/test_actor_id.py +++ b/tests/actor/test_actor_id.py @@ -20,21 +20,21 @@ class ActorIdTests(unittest.TestCase): def test_create_actor_id(self): - actor_id_1 = ActorId('1') - self.assertEqual('1', actor_id_1.id) + actor_id_1 = ActorId("1") + self.assertEqual("1", actor_id_1.id) def test_create_random_id(self): actor_id_random = ActorId.create_random_id() - self.assertEqual(len('f56d5aec5b3b11ea9121acde48001122'), len(actor_id_random.id)) + self.assertEqual(len("f56d5aec5b3b11ea9121acde48001122"), len(actor_id_random.id)) def test_get_hash(self): - actor_test_id = ActorId('testId') + actor_test_id = ActorId("testId") self.assertIsNotNone(actor_test_id.__hash__) def test_comparison(self): - actor_id_1 = ActorId('1') - actor_id_1a = ActorId('1') + actor_id_1 = ActorId("1") + actor_id_1a = ActorId("1") self.assertTrue(actor_id_1 == actor_id_1a) - actor_id_2 = ActorId('2') + actor_id_2 = ActorId("2") self.assertFalse(actor_id_1 == actor_id_2) diff --git a/tests/actor/test_actor_manager.py b/tests/actor/test_actor_manager.py index 7d0725647..9d93b4f70 100644 --- a/tests/actor/test_actor_manager.py +++ b/tests/actor/test_actor_manager.py @@ -45,13 +45,13 @@ def setUp(self): self._fake_client = FakeDaprActorClient self._runtime_ctx = ActorRuntimeContext( - self._test_type_info, self._serializer, - self._serializer, self._fake_client) + self._test_type_info, self._serializer, self._serializer, self._fake_client + ) self._manager = ActorManager(self._runtime_ctx) def test_activate_actor(self): """Activate ActorId(1)""" - test_actor_id = ActorId('1') + test_actor_id = ActorId("1") _run(self._manager.activate_actor(test_actor_id)) # assert @@ -61,7 +61,7 @@ def test_activate_actor(self): def test_deactivate_actor(self): """Activate ActorId('2') and deactivate it""" - test_actor_id = ActorId('2') + test_actor_id = ActorId("2") _run(self._manager.activate_actor(test_actor_id)) # assert @@ -74,7 +74,7 @@ def test_deactivate_actor(self): def test_dispatch_success(self): """dispatch ActionMethod""" - test_actor_id = ActorId('dispatch') + test_actor_id = ActorId("dispatch") _run(self._manager.activate_actor(test_actor_id)) request_body = { @@ -91,32 +91,34 @@ def setUp(self): self._serializer = DefaultJSONSerializer() self._fake_client = FakeDaprActorClient - self._test_reminder_req = self._serializer.serialize({ - 'name': 'test_reminder', - 'dueTime': timedelta(seconds=1), - 'period': timedelta(seconds=1), - 'ttl': timedelta(seconds=1), - 'data': 'cmVtaW5kZXJfc3RhdGU=', - }) + self._test_reminder_req = self._serializer.serialize( + { + "name": "test_reminder", + "dueTime": timedelta(seconds=1), + "period": timedelta(seconds=1), 
+ "ttl": timedelta(seconds=1), + "data": "cmVtaW5kZXJfc3RhdGU=", + } + ) def test_fire_reminder_for_non_reminderable(self): test_type_info = ActorTypeInformation.create(FakeSimpleActor) ctx = ActorRuntimeContext( - test_type_info, self._serializer, - self._serializer, self._fake_client) + test_type_info, self._serializer, self._serializer, self._fake_client + ) manager = ActorManager(ctx) with self.assertRaises(ValueError): - _run(manager.fire_reminder(ActorId('testid'), 'test_reminder', self._test_reminder_req)) + _run(manager.fire_reminder(ActorId("testid"), "test_reminder", self._test_reminder_req)) def test_fire_reminder_success(self): - test_actor_id = ActorId('testid') + test_actor_id = ActorId("testid") test_type_info = ActorTypeInformation.create(FakeSimpleReminderActor) ctx = ActorRuntimeContext( - test_type_info, self._serializer, - self._serializer, self._fake_client) + test_type_info, self._serializer, self._serializer, self._fake_client + ) manager = ActorManager(ctx) _run(manager.activate_actor(test_actor_id)) - _run(manager.fire_reminder(test_actor_id, 'test_reminder', self._test_reminder_req)) + _run(manager.fire_reminder(test_actor_id, "test_reminder", self._test_reminder_req)) class ActorManagerTimerTests(unittest.TestCase): @@ -126,31 +128,40 @@ def setUp(self): self._fake_client = FakeDaprActorClient @mock.patch( - 'tests.actor.fake_client.FakeDaprActorClient.invoke_method', - new=_async_mock(return_value=b'"expected_response"')) - @mock.patch( - 'tests.actor.fake_client.FakeDaprActorClient.register_timer', - new=_async_mock()) + "tests.actor.fake_client.FakeDaprActorClient.invoke_method", + new=_async_mock(return_value=b'"expected_response"'), + ) + @mock.patch("tests.actor.fake_client.FakeDaprActorClient.register_timer", new=_async_mock()) def test_fire_timer_success(self): - test_actor_id = ActorId('testid') + test_actor_id = ActorId("testid") test_type_info = ActorTypeInformation.create(FakeSimpleTimerActor) ctx = ActorRuntimeContext( - test_type_info, self._serializer, - self._serializer, self._fake_client) + test_type_info, self._serializer, self._serializer, self._fake_client + ) manager = ActorManager(ctx) _run(manager.activate_actor(test_actor_id)) actor = manager._active_actors.get(test_actor_id.id, None) # Setup timer - _run(actor.register_timer( - 'test_timer', actor.timer_callback, - "timer call", timedelta(seconds=1), timedelta(seconds=1), timedelta(seconds=1))) + _run( + actor.register_timer( + "test_timer", + actor.timer_callback, + "timer call", + timedelta(seconds=1), + timedelta(seconds=1), + timedelta(seconds=1), + ) + ) # Fire timer - _run(manager.fire_timer( - test_actor_id, - 'test_timer', - '{ "callback": "timer_callback", "data": "timer call" }'.encode('UTF8'))) + _run( + manager.fire_timer( + test_actor_id, + "test_timer", + '{ "callback": "timer_callback", "data": "timer call" }'.encode("UTF8"), + ) + ) self.assertTrue(actor.timer_called) diff --git a/tests/actor/test_actor_reentrancy.py b/tests/actor/test_actor_reentrancy.py index eca121f6f..4ba9a4150 100644 --- a/tests/actor/test_actor_reentrancy.py +++ b/tests/actor/test_actor_reentrancy.py @@ -35,7 +35,8 @@ class ActorRuntimeTests(unittest.TestCase): def setUp(self): ActorRuntime._actor_managers = {} ActorRuntime.set_actor_config( - ActorRuntimeConfig(reentrancy=ActorReentrancyConfig(enabled=True))) + ActorRuntimeConfig(reentrancy=ActorReentrancyConfig(enabled=True)) + ) self._serializer = DefaultJSONSerializer() _run(ActorRuntime.register_actor(FakeReentrantActor)) 
_run(ActorRuntime.register_actor(FakeSlowReentrantActor)) @@ -51,38 +52,49 @@ def test_reentrant_dispatch(self): reentrancy_id = "0faa4c8b-f53a-4dff-9a9d-c50205035085" test_request_body = self._serializer.serialize(request_body) - response = _run(ActorRuntime.dispatch( - FakeMultiInterfacesActor.__name__, 'test-id', - "ReentrantMethod", test_request_body, reentrancy_id=reentrancy_id)) + response = _run( + ActorRuntime.dispatch( + FakeMultiInterfacesActor.__name__, + "test-id", + "ReentrantMethod", + test_request_body, + reentrancy_id=reentrancy_id, + ) + ) self.assertEqual(b'"hello dapr"', response) - _run(ActorRuntime.deactivate(FakeMultiInterfacesActor.__name__, 'test-id')) + _run(ActorRuntime.deactivate(FakeMultiInterfacesActor.__name__, "test-id")) # Ensure test-id is deactivated with self.assertRaises(ValueError): - _run(ActorRuntime.deactivate(FakeMultiInterfacesActor.__name__, 'test-id')) + _run(ActorRuntime.deactivate(FakeMultiInterfacesActor.__name__, "test-id")) def test_interleaved_reentrant_actor_dispatch(self): _run(ActorRuntime.register_actor(FakeReentrantActor)) _run(ActorRuntime.register_actor(FakeSlowReentrantActor)) - request_body = self._serializer.serialize({ - "message": "Normal", - }) + request_body = self._serializer.serialize( + { + "message": "Normal", + } + ) normal_reentrancy_id = "f6319f23-dc0a-4880-90d9-87b23c19c20a" slow_reentrancy_id = "b1653a2f-fe54-4514-8197-98b52d156454" async def dispatchReentrantCall(actorName: str, method: str, reentrancy_id: str): return await ActorRuntime.dispatch( - actorName, 'test-id', method, request_body, reentrancy_id=reentrancy_id) + actorName, "test-id", method, request_body, reentrancy_id=reentrancy_id + ) async def run_parallel_actors(): slow = dispatchReentrantCall( - FakeSlowReentrantActor.__name__, "ReentrantMethod", slow_reentrancy_id) + FakeSlowReentrantActor.__name__, "ReentrantMethod", slow_reentrancy_id + ) normal = dispatchReentrantCall( - FakeReentrantActor.__name__, "ReentrantMethod", normal_reentrancy_id) + FakeReentrantActor.__name__, "ReentrantMethod", normal_reentrancy_id + ) res = await asyncio.gather(slow, normal) self.slow_res = res[0] @@ -90,50 +102,58 @@ async def run_parallel_actors(): _run(run_parallel_actors()) - self.assertEqual(self.normal_res, bytes('"' + normal_reentrancy_id + '"', 'utf-8')) - self.assertEqual(self.slow_res, bytes('"' + slow_reentrancy_id + '"', 'utf-8')) + self.assertEqual(self.normal_res, bytes('"' + normal_reentrancy_id + '"', "utf-8")) + self.assertEqual(self.slow_res, bytes('"' + slow_reentrancy_id + '"', "utf-8")) - _run(ActorRuntime.deactivate(FakeSlowReentrantActor.__name__, 'test-id')) - _run(ActorRuntime.deactivate(FakeReentrantActor.__name__, 'test-id')) + _run(ActorRuntime.deactivate(FakeSlowReentrantActor.__name__, "test-id")) + _run(ActorRuntime.deactivate(FakeReentrantActor.__name__, "test-id")) # Ensure test-id is deactivated with self.assertRaises(ValueError): - _run(ActorRuntime.deactivate(FakeSlowReentrantActor.__name__, 'test-id')) - _run(ActorRuntime.deactivate(FakeReentrantActor.__name__, 'test-id')) + _run(ActorRuntime.deactivate(FakeSlowReentrantActor.__name__, "test-id")) + _run(ActorRuntime.deactivate(FakeReentrantActor.__name__, "test-id")) def test_reentrancy_header_passthrough(self): _run(ActorRuntime.register_actor(FakeReentrantActor)) _run(ActorRuntime.register_actor(FakeSlowReentrantActor)) - request_body = self._serializer.serialize({ - "message": "Normal", - }) + request_body = self._serializer.serialize( + { + "message": "Normal", + } + ) async 
def expected_return_value(*args, **kwargs): return ["expected", "None"] reentrancy_id = "f6319f23-dc0a-4880-90d9-87b23c19c20a" actor = FakeSlowReentrantActor.__name__ - method = 'ReentrantMethod' - - with mock.patch('dapr.clients.http.client.DaprHttpClient.send_bytes') as mocked: + method = "ReentrantMethod" + with mock.patch("dapr.clients.http.client.DaprHttpClient.send_bytes") as mocked: mocked.side_effect = expected_return_value - _run(ActorRuntime.dispatch( - FakeReentrantActor.__name__, 'test-id', 'ReentrantMethodWithPassthrough', - request_body, reentrancy_id=reentrancy_id)) + _run( + ActorRuntime.dispatch( + FakeReentrantActor.__name__, + "test-id", + "ReentrantMethodWithPassthrough", + request_body, + reentrancy_id=reentrancy_id, + ) + ) mocked.assert_called_with( method="POST", - url=f'http://127.0.0.1:3500/v1.0/actors/{actor}/test-id/method/{method}', + url=f"http://127.0.0.1:3500/v1.0/actors/{actor}/test-id/method/{method}", data=None, - headers={'Dapr-Reentrancy-Id': reentrancy_id}) + headers={"Dapr-Reentrancy-Id": reentrancy_id}, + ) - _run(ActorRuntime.deactivate(FakeReentrantActor.__name__, 'test-id')) + _run(ActorRuntime.deactivate(FakeReentrantActor.__name__, "test-id")) # Ensure test-id is deactivated with self.assertRaises(ValueError): - _run(ActorRuntime.deactivate(FakeReentrantActor.__name__, 'test-id')) + _run(ActorRuntime.deactivate(FakeReentrantActor.__name__, "test-id")) def test_header_passthrough_reentrancy_disabled(self): config = ActorRuntimeConfig(reentrancy=None) @@ -141,91 +161,105 @@ def test_header_passthrough_reentrancy_disabled(self): _run(ActorRuntime.register_actor(FakeReentrantActor)) _run(ActorRuntime.register_actor(FakeSlowReentrantActor)) - request_body = self._serializer.serialize({ - "message": "Normal", - }) + request_body = self._serializer.serialize( + { + "message": "Normal", + } + ) async def expected_return_value(*args, **kwargs): return ["expected", "None"] reentrancy_id = None # the runtime would not pass this header actor = FakeSlowReentrantActor.__name__ - method = 'ReentrantMethod' - - with mock.patch('dapr.clients.http.client.DaprHttpClient.send_bytes') as mocked: + method = "ReentrantMethod" + with mock.patch("dapr.clients.http.client.DaprHttpClient.send_bytes") as mocked: mocked.side_effect = expected_return_value - _run(ActorRuntime.dispatch( - FakeReentrantActor.__name__, 'test-id', 'ReentrantMethodWithPassthrough', - request_body, reentrancy_id=reentrancy_id)) + _run( + ActorRuntime.dispatch( + FakeReentrantActor.__name__, + "test-id", + "ReentrantMethodWithPassthrough", + request_body, + reentrancy_id=reentrancy_id, + ) + ) mocked.assert_called_with( method="POST", - url=f'http://127.0.0.1:3500/v1.0/actors/{actor}/test-id/method/{method}', + url=f"http://127.0.0.1:3500/v1.0/actors/{actor}/test-id/method/{method}", data=None, - headers={}) + headers={}, + ) - _run(ActorRuntime.deactivate(FakeReentrantActor.__name__, 'test-id')) + _run(ActorRuntime.deactivate(FakeReentrantActor.__name__, "test-id")) # Ensure test-id is deactivated with self.assertRaises(ValueError): - _run(ActorRuntime.deactivate(FakeReentrantActor.__name__, 'test-id')) + _run(ActorRuntime.deactivate(FakeReentrantActor.__name__, "test-id")) def test_parse_incoming_reentrancy_header_flask(self): from ext.flask_dapr import flask_dapr from flask import Flask - app = Flask(f'{FakeReentrantActor.__name__}Service') + app = Flask(f"{FakeReentrantActor.__name__}Service") flask_dapr.DaprActor(app) reentrancy_id = "b1653a2f-fe54-4514-8197-98b52d156454" actor_type_name = 
FakeReentrantActor.__name__ - actor_id = 'test-id' - method_name = 'ReentrantMethod' + actor_id = "test-id" + method_name = "ReentrantMethod" - request_body = self._serializer.serialize({ - "message": "Normal", - }) + request_body = self._serializer.serialize( + { + "message": "Normal", + } + ) - relativeUrl = f'/actors/{actor_type_name}/{actor_id}/method/{method_name}' + relativeUrl = f"/actors/{actor_type_name}/{actor_id}/method/{method_name}" - with mock.patch('dapr.actor.runtime.runtime.ActorRuntime.dispatch') as mocked: + with mock.patch("dapr.actor.runtime.runtime.ActorRuntime.dispatch") as mocked: client = app.test_client() mocked.return_value = None client.put( relativeUrl, - headers={ - flask_dapr.actor.DAPR_REENTRANCY_ID_HEADER: reentrancy_id}, - data=request_body) + headers={flask_dapr.actor.DAPR_REENTRANCY_ID_HEADER: reentrancy_id}, + data=request_body, + ) mocked.assert_called_with( - actor_type_name, actor_id, method_name, request_body, reentrancy_id) + actor_type_name, actor_id, method_name, request_body, reentrancy_id + ) def test_parse_incoming_reentrancy_header_fastapi(self): from fastapi import FastAPI from fastapi.testclient import TestClient from dapr.ext import fastapi - app = FastAPI(title=f'{FakeReentrantActor.__name__}Service') + app = FastAPI(title=f"{FakeReentrantActor.__name__}Service") fastapi.DaprActor(app) reentrancy_id = "b1653a2f-fe54-4514-8197-98b52d156454" actor_type_name = FakeReentrantActor.__name__ - actor_id = 'test-id' - method_name = 'ReentrantMethod' + actor_id = "test-id" + method_name = "ReentrantMethod" - request_body = self._serializer.serialize({ - "message": "Normal", - }) + request_body = self._serializer.serialize( + { + "message": "Normal", + } + ) - relativeUrl = f'/actors/{actor_type_name}/{actor_id}/method/{method_name}' + relativeUrl = f"/actors/{actor_type_name}/{actor_id}/method/{method_name}" - with mock.patch('dapr.actor.runtime.runtime.ActorRuntime.dispatch') as mocked: + with mock.patch("dapr.actor.runtime.runtime.ActorRuntime.dispatch") as mocked: client = TestClient(app) mocked.return_value = None client.put( relativeUrl, - headers={ - fastapi.actor.DAPR_REENTRANCY_ID_HEADER: reentrancy_id}, - data=request_body) + headers={fastapi.actor.DAPR_REENTRANCY_ID_HEADER: reentrancy_id}, + data=request_body, + ) mocked.assert_called_with( - actor_type_name, actor_id, method_name, request_body, reentrancy_id) + actor_type_name, actor_id, method_name, request_body, reentrancy_id + ) diff --git a/tests/actor/test_actor_runtime.py b/tests/actor/test_actor_runtime.py index 17a88f4c5..daac3f171 100644 --- a/tests/actor/test_actor_runtime.py +++ b/tests/actor/test_actor_runtime.py @@ -41,7 +41,7 @@ def setUp(self): def test_get_registered_actor_types(self): actor_types = ActorRuntime.get_registered_actor_types() - self.assertTrue(actor_types.index('FakeSimpleActor') >= 0) + self.assertTrue(actor_types.index("FakeSimpleActor") >= 0) self.assertTrue(actor_types.index(FakeMultiInterfacesActor.__name__) >= 0) self.assertTrue(actor_types.index(FakeSimpleTimerActor.__name__) >= 0) @@ -56,8 +56,8 @@ def test_actor_config(self): # apply new config new_config = ActorRuntimeConfig( - timedelta(hours=3), timedelta(seconds=10), - timedelta(minutes=1), False) + timedelta(hours=3), timedelta(seconds=10), timedelta(minutes=1), False + ) ActorRuntime.set_actor_config(new_config) config = ActorRuntime.get_actor_config() @@ -88,34 +88,42 @@ def test_dispatch(self): } test_request_body = self._serializer.serialize(request_body) - response = 
_run(ActorRuntime.dispatch(
-            FakeMultiInterfacesActor.__name__, 'test-id',
-            "ActionMethod", test_request_body))
+        response = _run(
+            ActorRuntime.dispatch(
+                FakeMultiInterfacesActor.__name__, "test-id", "ActionMethod", test_request_body
+            )
+        )
         self.assertEqual(b'"hello dapr"', response)

-        _run(ActorRuntime.deactivate(FakeMultiInterfacesActor.__name__, 'test-id'))
+        _run(ActorRuntime.deactivate(FakeMultiInterfacesActor.__name__, "test-id"))

         # Ensure test-id is deactivated
         with self.assertRaises(ValueError):
-            _run(ActorRuntime.deactivate(FakeMultiInterfacesActor.__name__, 'test-id'))
+            _run(ActorRuntime.deactivate(FakeMultiInterfacesActor.__name__, "test-id"))

     def test_fire_timer_success(self):
         # Fire timer
-        _run(ActorRuntime.fire_timer(
-            FakeSimpleTimerActor.__name__,
-            'test-id',
-            'test_timer',
-            '{ "callback": "timer_callback", "data": "timer call" }'.encode('UTF8')))
+        _run(
+            ActorRuntime.fire_timer(
+                FakeSimpleTimerActor.__name__,
+                "test-id",
+                "test_timer",
+                '{ "callback": "timer_callback", "data": "timer call" }'.encode("UTF8"),
+            )
+        )

         manager = ActorRuntime._actor_managers[FakeSimpleTimerActor.__name__]
-        actor = manager._active_actors['test-id']
+        actor = manager._active_actors["test-id"]
         self.assertTrue(actor.timer_called)

     def test_fire_timer_unregistered(self):
         with self.assertRaises(ValueError):
-            _run(ActorRuntime.fire_timer(
-                'UnknownType',
-                'test-id',
-                'test_timer',
-                '{ "callback": "timer_callback", "data": "timer call" }'.encode('UTF8')))
+            _run(
+                ActorRuntime.fire_timer(
+                    "UnknownType",
+                    "test-id",
+                    "test_timer",
+                    '{ "callback": "timer_callback", "data": "timer call" }'.encode("UTF8"),
+                )
+            )
diff --git a/tests/actor/test_actor_runtime_config.py b/tests/actor/test_actor_runtime_config.py
index 67aae6a76..4b265d262 100644
--- a/tests/actor/test_actor_runtime_config.py
+++ b/tests/actor/test_actor_runtime_config.py
@@ -21,30 +21,31 @@ class ActorTypeConfigTests(unittest.TestCase):
     def test_default_config(self):
-        config = ActorTypeConfig('testactor')
+        config = ActorTypeConfig("testactor")
         self.assertEqual(config._actor_idle_timeout, None)
         self.assertEqual(config._actor_scan_interval, None)
         self.assertEqual(config._drain_ongoing_call_timeout, None)
         self.assertEqual(config._drain_rebalanced_actors, None)
         self.assertEqual(config._reentrancy, None)
-        self.assertEqual(config.as_dict()['entities'], ['testactor'])
+        self.assertEqual(config.as_dict()["entities"], ["testactor"])
         keys = config.as_dict().keys()
-        self.assertNotIn('reentrancy', keys)
-        self.assertNotIn('remindersStoragePartitions', keys)
-        self.assertNotIn('actorIdleTimeout', keys)
-        self.assertNotIn('actorScanInterval', keys)
-        self.assertNotIn('drainOngoingCallTimeout', keys)
-        self.assertNotIn('drainRebalancedActors', keys)
+        self.assertNotIn("reentrancy", keys)
+        self.assertNotIn("remindersStoragePartitions", keys)
+        self.assertNotIn("actorIdleTimeout", keys)
+        self.assertNotIn("actorScanInterval", keys)
+        self.assertNotIn("drainOngoingCallTimeout", keys)
+        self.assertNotIn("drainRebalancedActors", keys)

     def test_complete_config(self):
         config = ActorTypeConfig(
-            'testactor',
+            "testactor",
             actor_idle_timeout=timedelta(seconds=3600),
             actor_scan_interval=timedelta(seconds=30),
             drain_ongoing_call_timeout=timedelta(seconds=60),
             drain_rebalanced_actors=False,
             reentrancy=ActorReentrancyConfig(enabled=True),
-            reminders_storage_partitions=10)
+            reminders_storage_partitions=10,
+        )
         self.assertEqual(config._actor_idle_timeout, timedelta(seconds=3600))
         self.assertEqual(config._actor_scan_interval, timedelta(seconds=30))
         self.assertEqual(config._drain_ongoing_call_timeout, timedelta(seconds=60))
@@ -52,14 +53,14 @@ def test_complete_config(self):
         self.assertEqual(config._reentrancy._enabled, True)
         self.assertEqual(config._reentrancy._maxStackDepth, 32)
         d = config.as_dict()
-        self.assertEqual(d['entities'], ['testactor'])
-        self.assertEqual(d['reentrancy']['enabled'], True)
-        self.assertEqual(d['reentrancy']['maxStackDepth'], 32)
-        self.assertEqual(d['remindersStoragePartitions'], 10)
-        self.assertEqual(d['actorIdleTimeout'], timedelta(seconds=3600))
-        self.assertEqual(d['actorScanInterval'], timedelta(seconds=30))
-        self.assertEqual(d['drainOngoingCallTimeout'], timedelta(seconds=60))
-        self.assertEqual(d['drainRebalancedActors'], False)
+        self.assertEqual(d["entities"], ["testactor"])
+        self.assertEqual(d["reentrancy"]["enabled"], True)
+        self.assertEqual(d["reentrancy"]["maxStackDepth"], 32)
+        self.assertEqual(d["remindersStoragePartitions"], 10)
+        self.assertEqual(d["actorIdleTimeout"], timedelta(seconds=3600))
+        self.assertEqual(d["actorScanInterval"], timedelta(seconds=30))
+        self.assertEqual(d["drainOngoingCallTimeout"], timedelta(seconds=60))
+        self.assertEqual(d["drainRebalancedActors"], False)


 class ActorRuntimeConfigTests(unittest.TestCase):
@@ -73,9 +74,9 @@ def test_default_config(self):
         self.assertEqual(config._reentrancy, None)
         self.assertEqual(config._entities, set())
         self.assertEqual(config._entitiesConfig, [])
-        self.assertNotIn('reentrancy', config.as_dict().keys())
-        self.assertNotIn('remindersStoragePartitions', config.as_dict().keys())
-        self.assertEqual(config.as_dict()['entitiesConfig'], [])
+        self.assertNotIn("reentrancy", config.as_dict().keys())
+        self.assertNotIn("remindersStoragePartitions", config.as_dict().keys())
+        self.assertEqual(config.as_dict()["entitiesConfig"], [])

     def test_default_config_with_reentrancy(self):
         reentrancyConfig = ActorReentrancyConfig(enabled=True)
@@ -88,78 +89,76 @@ def test_default_config_with_reentrancy(self):
         self.assertEqual(config._reentrancy, reentrancyConfig)
         self.assertEqual(config._entities, set())
         self.assertEqual(config._entitiesConfig, [])
-        self.assertEqual(config.as_dict()['reentrancy'], reentrancyConfig.as_dict())
-        self.assertEqual(config.as_dict()['reentrancy']['enabled'], True)
-        self.assertEqual(config.as_dict()['reentrancy']['maxStackDepth'], 32)
-        self.assertNotIn('remindersStoragePartitions', config.as_dict().keys())
+        self.assertEqual(config.as_dict()["reentrancy"], reentrancyConfig.as_dict())
+        self.assertEqual(config.as_dict()["reentrancy"]["enabled"], True)
+        self.assertEqual(config.as_dict()["reentrancy"]["maxStackDepth"], 32)
+        self.assertNotIn("remindersStoragePartitions", config.as_dict().keys())

     def test_config_with_actor_type_config(self):
         typeConfig1 = ActorTypeConfig(
-            'testactor1',
+            "testactor1",
             actor_scan_interval=timedelta(seconds=10),
-            reentrancy=ActorReentrancyConfig(enabled=True))
+            reentrancy=ActorReentrancyConfig(enabled=True),
+        )
         typeConfig2 = ActorTypeConfig(
-            'testactor2',
+            "testactor2",
             drain_ongoing_call_timeout=timedelta(seconds=60),
-            reminders_storage_partitions=10)
-        config = ActorRuntimeConfig(
-            actor_type_configs=[typeConfig1, typeConfig2])
+            reminders_storage_partitions=10,
+        )
+        config = ActorRuntimeConfig(actor_type_configs=[typeConfig1, typeConfig2])
         self.assertEqual(config._actor_scan_interval, timedelta(seconds=30))
         d = config.as_dict()
         self.assertEqual(config._drain_ongoing_call_timeout, timedelta(seconds=60))
-        self.assertEqual(d['entitiesConfig'][0]['entities'], ['testactor1'])
-        self.assertEqual(d['entitiesConfig'][0]['actorScanInterval'], timedelta(seconds=10))
-        self.assertEqual(d['entitiesConfig'][0]['reentrancy']['enabled'], True)
-        self.assertEqual(d['entitiesConfig'][0]['reentrancy']['maxStackDepth'], 32)
-        self.assertEqual(d['entitiesConfig'][1]['entities'], ['testactor2'])
-        self.assertEqual(d['entitiesConfig'][1]['drainOngoingCallTimeout'], timedelta(seconds=60))
-        self.assertEqual(d['entitiesConfig'][1]['remindersStoragePartitions'], 10)
-        self.assertNotIn('reentrancy', d['entitiesConfig'][1])
-        self.assertNotIn('actorScanInterval', d['entitiesConfig'][1])
-        self.assertNotIn('draingOngoingCallTimeout', d['entitiesConfig'][0])
-        self.assertNotIn('remindersStoragePartitions', d['entitiesConfig'][0])
-        self.assertEqual(sorted(d['entities']), ['testactor1', 'testactor2'])
+        self.assertEqual(d["entitiesConfig"][0]["entities"], ["testactor1"])
+        self.assertEqual(d["entitiesConfig"][0]["actorScanInterval"], timedelta(seconds=10))
+        self.assertEqual(d["entitiesConfig"][0]["reentrancy"]["enabled"], True)
+        self.assertEqual(d["entitiesConfig"][0]["reentrancy"]["maxStackDepth"], 32)
+        self.assertEqual(d["entitiesConfig"][1]["entities"], ["testactor2"])
+        self.assertEqual(d["entitiesConfig"][1]["drainOngoingCallTimeout"], timedelta(seconds=60))
+        self.assertEqual(d["entitiesConfig"][1]["remindersStoragePartitions"], 10)
+        self.assertNotIn("reentrancy", d["entitiesConfig"][1])
+        self.assertNotIn("actorScanInterval", d["entitiesConfig"][1])
+        self.assertNotIn("draingOngoingCallTimeout", d["entitiesConfig"][0])
+        self.assertNotIn("remindersStoragePartitions", d["entitiesConfig"][0])
+        self.assertEqual(sorted(d["entities"]), ["testactor1", "testactor2"])

     def test_update_entities(self):
         config = ActorRuntimeConfig()
-        config.update_entities(['actortype1'])
+        config.update_entities(["actortype1"])
         self.assertEqual(config._actor_idle_timeout, timedelta(seconds=3600))
         self.assertEqual(config._actor_scan_interval, timedelta(seconds=30))
         self.assertEqual(config._drain_ongoing_call_timeout, timedelta(seconds=60))
         self.assertEqual(config._drain_rebalanced_actors, True)
-        self.assertEqual(config._entities, {'actortype1'})
+        self.assertEqual(config._entities, {"actortype1"})
         self.assertEqual(config._entitiesConfig, [])
-        self.assertNotIn('remindersStoragePartitions', config.as_dict().keys())
+        self.assertNotIn("remindersStoragePartitions", config.as_dict().keys())

     def test_update_entities_two_types(self):
         config = ActorRuntimeConfig()
-        config.update_entities(['actortype1', 'actortype1'])
+        config.update_entities(["actortype1", "actortype1"])
         self.assertEqual(config._actor_idle_timeout, timedelta(seconds=3600))
         self.assertEqual(config._actor_scan_interval, timedelta(seconds=30))
         self.assertEqual(config._drain_ongoing_call_timeout, timedelta(seconds=60))
         self.assertEqual(config._drain_rebalanced_actors, True)
-        self.assertEqual(config._entities, {'actortype1', 'actortype1'})
+        self.assertEqual(config._entities, {"actortype1", "actortype1"})
         self.assertEqual(config._entitiesConfig, [])
-        self.assertNotIn('remindersStoragePartitions', config.as_dict().keys())
+        self.assertNotIn("remindersStoragePartitions", config.as_dict().keys())

     def test_update_actor_type_config(self):
         config = ActorRuntimeConfig()
-        config.update_entities(['actortype1'])
-        config.update_actor_type_configs([
-            ActorTypeConfig(
-                'updatetype1',
-                actor_scan_interval=timedelta(seconds=5)
-            )
-        ])
+        config.update_entities(["actortype1"])
+        config.update_actor_type_configs(
+            [ActorTypeConfig("updatetype1", actor_scan_interval=timedelta(seconds=5))]
+        )
         d = config.as_dict()
-        self.assertEqual(sorted(d['entities']), ['actortype1', 'updatetype1'])
-        self.assertEqual(d['entitiesConfig'][0]['actorScanInterval'], timedelta(seconds=5))
-        self.assertEqual(d['entitiesConfig'][0]['entities'], ['updatetype1'])
-        self.assertEqual(d['actorScanInterval'], timedelta(seconds=30))
+        self.assertEqual(sorted(d["entities"]), ["actortype1", "updatetype1"])
+        self.assertEqual(d["entitiesConfig"][0]["actorScanInterval"], timedelta(seconds=5))
+        self.assertEqual(d["entitiesConfig"][0]["entities"], ["updatetype1"])
+        self.assertEqual(d["actorScanInterval"], timedelta(seconds=30))

     def test_set_reminders_storage_partitions(self):
         config = ActorRuntimeConfig(reminders_storage_partitions=12)
@@ -167,10 +166,10 @@ def test_set_reminders_storage_partitions(self):
         self.assertEqual(config._actor_scan_interval, timedelta(seconds=30))
         self.assertEqual(config._drain_ongoing_call_timeout, timedelta(seconds=60))
         self.assertEqual(config._drain_rebalanced_actors, True)
-        self.assertNotIn('reentrancy', config.as_dict().keys())
+        self.assertNotIn("reentrancy", config.as_dict().keys())
         self.assertEqual(config._reminders_storage_partitions, 12)
-        self.assertEqual(config.as_dict()['remindersStoragePartitions'], 12)
+        self.assertEqual(config.as_dict()["remindersStoragePartitions"], 12)


-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
diff --git a/tests/actor/test_client_proxy.py b/tests/actor/test_client_proxy.py
index e1c1120e1..9ad5cda1c 100644
--- a/tests/actor/test_client_proxy.py
+++ b/tests/actor/test_client_proxy.py
@@ -28,10 +28,7 @@
 from tests.actor.fake_client import FakeDaprActorClient

-from tests.actor.utils import (
-    _async_mock,
-    _run
-)
+from tests.actor.utils import _async_mock, _run


 class FakeActoryProxyFactory:
@@ -39,10 +36,10 @@ def __init__(self, fake_client):
         # TODO: support serializer for state store later
         self._dapr_client = fake_client

-    def create(self, actor_interface,
-               actor_type, actor_id) -> ActorProxy:
-        return ActorProxy(self._dapr_client, actor_interface,
-                          actor_type, actor_id, DefaultJSONSerializer())
+    def create(self, actor_interface, actor_type, actor_id) -> ActorProxy:
+        return ActorProxy(
+            self._dapr_client, actor_interface, actor_type, actor_id, DefaultJSONSerializer()
+        )


 class ActorProxyTests(unittest.TestCase):
@@ -52,45 +49,54 @@ def setUp(self):
         self._fake_factory = FakeActoryProxyFactory(self._fake_client)
         self._proxy = ActorProxy.create(
             FakeMultiInterfacesActor.__name__,
-            ActorId('fake-id'),
+            ActorId("fake-id"),
             FakeActorCls2Interface,
-            self._fake_factory)
+            self._fake_factory,
+        )

     @mock.patch(
-        'tests.actor.fake_client.FakeDaprActorClient.invoke_method',
-        new=_async_mock(return_value=b'"expected_response"'))
+        "tests.actor.fake_client.FakeDaprActorClient.invoke_method",
+        new=_async_mock(return_value=b'"expected_response"'),
+    )
     def test_invoke(self):
-        response = _run(self._proxy.invoke_method('ActionMethod', b'arg0'))
+        response = _run(self._proxy.invoke_method("ActionMethod", b"arg0"))
         self.assertEqual(b'"expected_response"', response)
         self._fake_client.invoke_method.mock.assert_called_once_with(
-            FakeMultiInterfacesActor.__name__, 'fake-id', 'ActionMethod', b'arg0')
+            FakeMultiInterfacesActor.__name__, "fake-id", "ActionMethod", b"arg0"
+        )

     @mock.patch(
-        'tests.actor.fake_client.FakeDaprActorClient.invoke_method',
-        new=_async_mock(return_value=b'"expected_response"'))
+        "tests.actor.fake_client.FakeDaprActorClient.invoke_method",
+        new=_async_mock(return_value=b'"expected_response"'),
+    )
     def test_invoke_no_arg(self):
-        response = _run(self._proxy.invoke_method('ActionMethodWithoutArg'))
+        response = _run(self._proxy.invoke_method("ActionMethodWithoutArg"))
         self.assertEqual(b'"expected_response"', response)
         self._fake_client.invoke_method.mock.assert_called_once_with(
-            FakeMultiInterfacesActor.__name__, 'fake-id', 'ActionMethodWithoutArg', None)
+            FakeMultiInterfacesActor.__name__, "fake-id", "ActionMethodWithoutArg", None
+        )

     @mock.patch(
-        'tests.actor.fake_client.FakeDaprActorClient.invoke_method',
-        new=_async_mock(return_value=b'"expected_response"'))
+        "tests.actor.fake_client.FakeDaprActorClient.invoke_method",
+        new=_async_mock(return_value=b'"expected_response"'),
+    )
     def test_invoke_with_static_typing(self):
-        response = _run(self._proxy.ActionMethod(b'arg0'))
-        self.assertEqual('expected_response', response)
+        response = _run(self._proxy.ActionMethod(b"arg0"))
+        self.assertEqual("expected_response", response)
         self._fake_client.invoke_method.mock.assert_called_once_with(
-            FakeMultiInterfacesActor.__name__, 'fake-id', 'ActionMethod', b'arg0')
+            FakeMultiInterfacesActor.__name__, "fake-id", "ActionMethod", b"arg0"
+        )

     @mock.patch(
-        'tests.actor.fake_client.FakeDaprActorClient.invoke_method',
-        new=_async_mock(return_value=b'"expected_response"'))
+        "tests.actor.fake_client.FakeDaprActorClient.invoke_method",
+        new=_async_mock(return_value=b'"expected_response"'),
+    )
     def test_invoke_with_static_typing_no_arg(self):
         response = _run(self._proxy.ActionMethodWithoutArg())
-        self.assertEqual('expected_response', response)
+        self.assertEqual("expected_response", response)
         self._fake_client.invoke_method.mock.assert_called_once_with(
-            FakeMultiInterfacesActor.__name__, 'fake-id', 'ActionMethodWithoutArg', None)
+            FakeMultiInterfacesActor.__name__, "fake-id", "ActionMethodWithoutArg", None
+        )

     def test_raise_exception_non_existing_method(self):
         with self.assertRaises(AttributeError):
diff --git a/tests/actor/test_method_dispatcher.py b/tests/actor/test_method_dispatcher.py
index 892cec438..d6435a41e 100644
--- a/tests/actor/test_method_dispatcher.py
+++ b/tests/actor/test_method_dispatcher.py
@@ -31,13 +31,13 @@ def setUp(self):
         self._serializer = DefaultJSONSerializer()
         self._fake_client = FakeDaprActorClient
         self._fake_runtime_ctx = ActorRuntimeContext(
-            self._testActorTypeInfo, self._serializer,
-            self._serializer, self._fake_client)
+            self._testActorTypeInfo, self._serializer, self._serializer, self._fake_client
+        )

     def test_get_arg_names(self):
         dispatcher = ActorMethodDispatcher(self._testActorTypeInfo)
         arg_names = dispatcher.get_arg_names("ActorMethod")
-        self.assertEqual(['arg'], arg_names)
+        self.assertEqual(["arg"], arg_names)

     def test_get_arg_types(self):
         dispatcher = ActorMethodDispatcher(self._testActorTypeInfo)
@@ -53,4 +53,4 @@ def test_dispatch(self):
         dispatcher = ActorMethodDispatcher(self._testActorTypeInfo)
         actorInstance = FakeSimpleActor(self._fake_runtime_ctx, None)
         result = _run(dispatcher.dispatch(actorInstance, "ActorMethod", 10))
-        self.assertEqual({'name': 'actor_method'}, result)
+        self.assertEqual({"name": "actor_method"}, result)
diff --git a/tests/actor/test_reminder_data.py b/tests/actor/test_reminder_data.py
index f0136a6fe..64dabcc73 100644
--- a/tests/actor/test_reminder_data.py
+++ b/tests/actor/test_reminder_data.py
@@ -23,53 +23,60 @@ class ActorReminderTests(unittest.TestCase):
     def test_invalid_state(self):
         with self.assertRaises(ValueError):
             ActorReminderData(
-                'test_reminder',
+                "test_reminder",
                 123,  # int type
                 timedelta(seconds=1),
                 timedelta(seconds=2),
-                timedelta(seconds=3))
+                timedelta(seconds=3),
+            )
             ActorReminderData(
-                'test_reminder',
-                'reminder_state',  # string type
+                "test_reminder",
+                "reminder_state",  # string type
                 timedelta(seconds=2),
                 timedelta(seconds=1),
-                timedelta(seconds=3))
+                timedelta(seconds=3),
+            )

     def test_valid_state(self):
         # bytes type state data
         reminder = ActorReminderData(
-            'test_reminder',
-            b'reminder_state',
+            "test_reminder",
+            b"reminder_state",
             timedelta(seconds=1),
             timedelta(seconds=2),
-            timedelta(seconds=3))
-        self.assertEqual(b'reminder_state', reminder.state)
+            timedelta(seconds=3),
+        )
+        self.assertEqual(b"reminder_state", reminder.state)

     def test_as_dict(self):
         reminder = ActorReminderData(
-            'test_reminder',
-            b'reminder_state',
+            "test_reminder",
+            b"reminder_state",
             timedelta(seconds=1),
             timedelta(seconds=2),
-            timedelta(seconds=3))
+            timedelta(seconds=3),
+        )
         expected = {
-            'reminderName': 'test_reminder',
-            'dueTime': timedelta(seconds=1),
-            'period': timedelta(seconds=2),
-            'ttl': timedelta(seconds=3),
-            'data': 'cmVtaW5kZXJfc3RhdGU=',
+            "reminderName": "test_reminder",
+            "dueTime": timedelta(seconds=1),
+            "period": timedelta(seconds=2),
+            "ttl": timedelta(seconds=3),
+            "data": "cmVtaW5kZXJfc3RhdGU=",
         }
         self.assertDictEqual(expected, reminder.as_dict())

     def test_from_dict(self):
-        reminder = ActorReminderData.from_dict('test_reminder', {
-            'dueTime': timedelta(seconds=1),
-            'period': timedelta(seconds=2),
-            'ttl': timedelta(seconds=3),
-            'data': 'cmVtaW5kZXJfc3RhdGU=',
-        })
-        self.assertEqual('test_reminder', reminder.reminder_name)
+        reminder = ActorReminderData.from_dict(
+            "test_reminder",
+            {
+                "dueTime": timedelta(seconds=1),
+                "period": timedelta(seconds=2),
+                "ttl": timedelta(seconds=3),
+                "data": "cmVtaW5kZXJfc3RhdGU=",
+            },
+        )
+        self.assertEqual("test_reminder", reminder.reminder_name)
         self.assertEqual(timedelta(seconds=1), reminder.due_time)
         self.assertEqual(timedelta(seconds=2), reminder.period)
         self.assertEqual(timedelta(seconds=3), reminder.ttl)
-        self.assertEqual(b'reminder_state', reminder.state)
+        self.assertEqual(b"reminder_state", reminder.state)
diff --git a/tests/actor/test_state_manager.py b/tests/actor/test_state_manager.py
index a619b1323..e4793df83 100644
--- a/tests/actor/test_state_manager.py
+++ b/tests/actor/test_state_manager.py
@@ -28,10 +28,7 @@
 from tests.actor.fake_actor_classes import FakeSimpleActor
 from tests.actor.fake_client import FakeDaprActorClient

-from tests.actor.utils import (
-    _async_mock,
-    _run
-)
+from tests.actor.utils import _async_mock, _run


 class ActorStateManagerTests(unittest.TestCase):
@@ -39,332 +36,330 @@ def setUp(self):
         # Create mock client
         self._fake_client = FakeDaprActorClient

-        self._test_actor_id = ActorId('1')
+        self._test_actor_id = ActorId("1")
         self._test_type_info = ActorTypeInformation.create(FakeSimpleActor)
         self._serializer = DefaultJSONSerializer()
         self._runtime_ctx = ActorRuntimeContext(
-            self._test_type_info, self._serializer, self._serializer, self._fake_client)
+            self._test_type_info, self._serializer, self._serializer, self._fake_client
+        )
         self._fake_actor = FakeSimpleActor(self._runtime_ctx, self._test_actor_id)

     @mock.patch(
-        'tests.actor.fake_client.FakeDaprActorClient.get_state',
-        new=_async_mock(return_value=base64.b64encode(b'"value1"')))
+        "tests.actor.fake_client.FakeDaprActorClient.get_state",
+        new=_async_mock(return_value=base64.b64encode(b'"value1"')),
+    )
     @mock.patch(
-        'tests.actor.fake_client.FakeDaprActorClient.save_state_transactionally',
-        new=_async_mock())
+        "tests.actor.fake_client.FakeDaprActorClient.save_state_transactionally", new=_async_mock()
+    )
     def test_add_state(self):
-
         state_manager = ActorStateManager(self._fake_actor)
         state_change_tracker = state_manager._get_contextual_state_tracker()
         # Add first 'state1'
-        added = _run(state_manager.try_add_state('state1', 'value1'))
+        added = _run(state_manager.try_add_state("state1", "value1"))
         self.assertTrue(added)

-        state = state_change_tracker['state1']
-        self.assertEqual('value1', state.value)
+        state = state_change_tracker["state1"]
+        self.assertEqual("value1", state.value)
         self.assertEqual(StateChangeKind.add, state.change_kind)

         # Add 'state1' again
-        added = _run(state_manager.try_add_state('state1', 'value1'))
+        added = _run(state_manager.try_add_state("state1", "value1"))
         self.assertFalse(added)

-    @mock.patch('tests.actor.fake_client.FakeDaprActorClient.get_state', new=_async_mock())
+    @mock.patch("tests.actor.fake_client.FakeDaprActorClient.get_state", new=_async_mock())
     def test_get_state_for_no_state(self):
-
         state_manager = ActorStateManager(self._fake_actor)
-        has_value, val = _run(state_manager.try_get_state('state1'))
+        has_value, val = _run(state_manager.try_get_state("state1"))
         self.assertFalse(has_value)
         self.assertIsNone(val)

         # Test if the test value is empty string
-        self._fake_client.get_state.return_value = ''
-        has_value, val = _run(state_manager.try_get_state('state1'))
+        self._fake_client.get_state.return_value = ""
+        has_value, val = _run(state_manager.try_get_state("state1"))
         self.assertFalse(has_value)
         self.assertIsNone(val)

     @mock.patch(
-        'tests.actor.fake_client.FakeDaprActorClient.get_state',
-        new=_async_mock(return_value=b'"value1"'))
+        "tests.actor.fake_client.FakeDaprActorClient.get_state",
+        new=_async_mock(return_value=b'"value1"'),
+    )
     def test_get_state_for_existing_value(self):
-
         state_manager = ActorStateManager(self._fake_actor)
-        has_value, val = _run(state_manager.try_get_state('state1'))
+        has_value, val = _run(state_manager.try_get_state("state1"))
         self.assertTrue(has_value)
         self.assertEqual("value1", val)

     @mock.patch(
-        'tests.actor.fake_client.FakeDaprActorClient.get_state',
-        new=_async_mock(return_value=b'"value1"'))
+        "tests.actor.fake_client.FakeDaprActorClient.get_state",
+        new=_async_mock(return_value=b'"value1"'),
+    )
     def test_get_state_for_removed_value(self):
-
         state_manager = ActorStateManager(self._fake_actor)
         state_change_tracker = state_manager._get_contextual_state_tracker()
-        removed = _run(state_manager.try_remove_state('state1'))
+        removed = _run(state_manager.try_remove_state("state1"))
         self.assertTrue(removed)

-        state = state_change_tracker['state1']
+        state = state_change_tracker["state1"]
         self.assertEqual(StateChangeKind.remove, state.change_kind)

-        has_value, val = _run(state_manager.try_get_state('state1'))
+        has_value, val = _run(state_manager.try_get_state("state1"))
         self.assertFalse(has_value)
         self.assertIsNone(val)

-    @mock.patch('tests.actor.fake_client.FakeDaprActorClient.get_state', new=_async_mock())
+    @mock.patch("tests.actor.fake_client.FakeDaprActorClient.get_state", new=_async_mock())
     def test_set_state_for_new_state(self):
-
         state_manager = ActorStateManager(self._fake_actor)
         state_change_tracker = state_manager._get_contextual_state_tracker()
-        _run(state_manager.set_state('state1', 'value1'))
+        _run(state_manager.set_state("state1", "value1"))

-        state = state_change_tracker['state1']
+        state = state_change_tracker["state1"]
         self.assertEqual(StateChangeKind.add, state.change_kind)
-        self.assertEqual('value1', state.value)
+        self.assertEqual("value1", state.value)

-    @mock.patch('tests.actor.fake_client.FakeDaprActorClient.get_state', new=_async_mock())
+    @mock.patch("tests.actor.fake_client.FakeDaprActorClient.get_state", new=_async_mock())
     def test_set_state_for_existing_state_only_in_mem(self):
-
         state_manager = ActorStateManager(self._fake_actor)
         state_change_tracker = state_manager._get_contextual_state_tracker()
-        _run(state_manager.set_state('state1', 'value1'))
+        _run(state_manager.set_state("state1", "value1"))

-        state = state_change_tracker['state1']
+        state = state_change_tracker["state1"]
         self.assertEqual(StateChangeKind.add, state.change_kind)
-        self.assertEqual('value1', state.value)
+        self.assertEqual("value1", state.value)

-        _run(state_manager.set_state('state1', 'value2'))
-        state = state_change_tracker['state1']
+        _run(state_manager.set_state("state1", "value2"))
+        state = state_change_tracker["state1"]
         self.assertEqual(StateChangeKind.add, state.change_kind)
-        self.assertEqual('value2', state.value)
+        self.assertEqual("value2", state.value)

     @mock.patch(
-        'tests.actor.fake_client.FakeDaprActorClient.get_state',
-        new=_async_mock(return_value=b'"value1"'))
+        "tests.actor.fake_client.FakeDaprActorClient.get_state",
+        new=_async_mock(return_value=b'"value1"'),
+    )
     def test_set_state_for_existing_state(self):
-
         state_manager = ActorStateManager(self._fake_actor)
         state_change_tracker = state_manager._get_contextual_state_tracker()
-        _run(state_manager.set_state('state1', 'value2'))
-        state = state_change_tracker['state1']
+        _run(state_manager.set_state("state1", "value2"))
+        state = state_change_tracker["state1"]
         self.assertEqual(StateChangeKind.update, state.change_kind)
-        self.assertEqual('value2', state.value)
+        self.assertEqual("value2", state.value)

-    @mock.patch('tests.actor.fake_client.FakeDaprActorClient.get_state', new=_async_mock())
+    @mock.patch("tests.actor.fake_client.FakeDaprActorClient.get_state", new=_async_mock())
     def test_remove_state_for_non_existing_state(self):
-
         state_manager = ActorStateManager(self._fake_actor)
-        removed = _run(state_manager.try_remove_state('state1'))
+        removed = _run(state_manager.try_remove_state("state1"))
         self.assertFalse(removed)

     @mock.patch(
-        'tests.actor.fake_client.FakeDaprActorClient.get_state',
-        new=_async_mock(return_value=b'"value1"'))
+        "tests.actor.fake_client.FakeDaprActorClient.get_state",
+        new=_async_mock(return_value=b'"value1"'),
+    )
     def test_remove_state_for_existing_state(self):
-
         state_manager = ActorStateManager(self._fake_actor)
         state_change_tracker = state_manager._get_contextual_state_tracker()
-        removed = _run(state_manager.try_remove_state('state1'))
+        removed = _run(state_manager.try_remove_state("state1"))
         self.assertTrue(removed)

-        state = state_change_tracker['state1']
+        state = state_change_tracker["state1"]
         self.assertEqual(StateChangeKind.remove, state.change_kind)

-    @mock.patch('tests.actor.fake_client.FakeDaprActorClient.get_state', new=_async_mock())
+    @mock.patch("tests.actor.fake_client.FakeDaprActorClient.get_state", new=_async_mock())
     def test_remove_state_for_existing_state_in_mem(self):
-
         state_manager = ActorStateManager(self._fake_actor)
-        _run(state_manager.set_state('state1', 'value1'))
-        removed = _run(state_manager.try_remove_state('state1'))
+        _run(state_manager.set_state("state1", "value1"))
+        removed = _run(state_manager.try_remove_state("state1"))
         self.assertTrue(removed)

-    @mock.patch('tests.actor.fake_client.FakeDaprActorClient.get_state', new=_async_mock())
+    @mock.patch("tests.actor.fake_client.FakeDaprActorClient.get_state", new=_async_mock())
     def test_remove_state_twice_for_existing_state_in_mem(self):
-
         state_manager = ActorStateManager(self._fake_actor)
-        _run(state_manager.set_state('state1', 'value1'))
-        removed = _run(state_manager.try_remove_state('state1'))
+        _run(state_manager.set_state("state1", "value1"))
+        removed = _run(state_manager.try_remove_state("state1"))
         self.assertTrue(removed)
-        removed = _run(state_manager.try_remove_state('state1'))
+        removed = _run(state_manager.try_remove_state("state1"))
         self.assertFalse(removed)

-    @mock.patch('tests.actor.fake_client.FakeDaprActorClient.get_state', new=_async_mock())
+    @mock.patch("tests.actor.fake_client.FakeDaprActorClient.get_state", new=_async_mock())
     def test_contains_state_for_removed_state(self):
-
         state_manager = ActorStateManager(self._fake_actor)
-        _run(state_manager.set_state('state1', 'value1'))
+        _run(state_manager.set_state("state1", "value1"))

-        exist = _run(state_manager.contains_state('state1'))
+        exist = _run(state_manager.contains_state("state1"))
         self.assertTrue(exist)

     @mock.patch(
-        'tests.actor.fake_client.FakeDaprActorClient.get_state',
-        new=_async_mock(return_value=b'"value1"'))
+        "tests.actor.fake_client.FakeDaprActorClient.get_state",
+        new=_async_mock(return_value=b'"value1"'),
+    )
     def test_contains_state_for_existing_state(self):
-
         state_manager = ActorStateManager(self._fake_actor)
-        exist = _run(state_manager.contains_state('state1'))
+        exist = _run(state_manager.contains_state("state1"))
         self.assertTrue(exist)

     @mock.patch(
-        'tests.actor.fake_client.FakeDaprActorClient.get_state',
-        new=_async_mock(return_value=b'"value1"'))
+        "tests.actor.fake_client.FakeDaprActorClient.get_state",
+        new=_async_mock(return_value=b'"value1"'),
+    )
     def test_get_or_add_state_for_existing_state(self):
-
         state_manager = ActorStateManager(self._fake_actor)
-        val = _run(state_manager.get_or_add_state('state1', 'value2'))
-        self.assertEqual('value1', val)
+        val = _run(state_manager.get_or_add_state("state1", "value2"))
+        self.assertEqual("value1", val)

-    @mock.patch(
-        'tests.actor.fake_client.FakeDaprActorClient.get_state',
-        new=_async_mock())
+    @mock.patch("tests.actor.fake_client.FakeDaprActorClient.get_state", new=_async_mock())
     def test_get_or_add_state_for_non_existing_state(self):
         state_manager = ActorStateManager(self._fake_actor)
         state_change_tracker = state_manager._get_contextual_state_tracker()
-        val = _run(state_manager.get_or_add_state('state1', 'value2'))
+        val = _run(state_manager.get_or_add_state("state1", "value2"))

-        state = state_change_tracker['state1']
+        state = state_change_tracker["state1"]
         self.assertEqual(StateChangeKind.add, state.change_kind)
-        self.assertEqual('value2', val)
+        self.assertEqual("value2", val)
         self._fake_client.get_state.mock.assert_called_once_with(
-            self._test_type_info._name,
-            self._test_actor_id.id, 'state1')
+            self._test_type_info._name, self._test_actor_id.id, "state1"
+        )

     @mock.patch(
-        'tests.actor.fake_client.FakeDaprActorClient.get_state',
-        new=_async_mock(return_value=b'"value1"'))
+        "tests.actor.fake_client.FakeDaprActorClient.get_state",
+        new=_async_mock(return_value=b'"value1"'),
+    )
     def test_get_or_add_state_for_removed_state(self):
-
         state_manager = ActorStateManager(self._fake_actor)
         state_change_tracker = state_manager._get_contextual_state_tracker()
-        _run(state_manager.remove_state('state1'))
-        state = state_change_tracker['state1']
+        _run(state_manager.remove_state("state1"))
+        state = state_change_tracker["state1"]
         self.assertEqual(StateChangeKind.remove, state.change_kind)

-        val = _run(state_manager.get_or_add_state('state1', 'value2'))
-        state = state_change_tracker['state1']
+        val = _run(state_manager.get_or_add_state("state1", "value2"))
+        state = state_change_tracker["state1"]
         self.assertEqual(StateChangeKind.update, state.change_kind)
-        self.assertEqual('value2', val)
+        self.assertEqual("value2", val)

-    @mock.patch('tests.actor.fake_client.FakeDaprActorClient.get_state', new=_async_mock())
+    @mock.patch("tests.actor.fake_client.FakeDaprActorClient.get_state", new=_async_mock())
     def test_add_or_update_state_for_new_state(self):
         """adds state if state does not exist."""
+
         def test_update_value(name, value):
-            return f'{name}-{value}'
+            return f"{name}-{value}"

         state_manager = ActorStateManager(self._fake_actor)
         state_change_tracker = state_manager._get_contextual_state_tracker()
-        val = _run(state_manager.add_or_update_state('state1', 'value1', test_update_value))
-        self.assertEqual('value1', val)
-        state = state_change_tracker['state1']
+        val = _run(state_manager.add_or_update_state("state1", "value1", test_update_value))
+        self.assertEqual("value1", val)
+        state = state_change_tracker["state1"]
         self.assertEqual(StateChangeKind.add, state.change_kind)

     @mock.patch(
-        'tests.actor.fake_client.FakeDaprActorClient.get_state',
-        new=_async_mock(return_value=b'"value1"'))
+        "tests.actor.fake_client.FakeDaprActorClient.get_state",
+        new=_async_mock(return_value=b'"value1"'),
+    )
     def test_add_or_update_state_for_state_in_storage(self):
         """updates state value using update_value_factory if state is in the storage."""
+
         def test_update_value(name, value):
-            return f'{name}-{value}'
+            return f"{name}-{value}"

         state_manager = ActorStateManager(self._fake_actor)
         state_change_tracker = state_manager._get_contextual_state_tracker()
-        val = _run(state_manager.add_or_update_state('state1', 'value1', test_update_value))
-        self.assertEqual('state1-value1', val)
-        state = state_change_tracker['state1']
+        val = _run(state_manager.add_or_update_state("state1", "value1", test_update_value))
+        self.assertEqual("state1-value1", val)
+        state = state_change_tracker["state1"]
         self.assertEqual(StateChangeKind.update, state.change_kind)

     @mock.patch(
-        'tests.actor.fake_client.FakeDaprActorClient.get_state',
-        new=_async_mock(return_value=b'"value1"'))
+        "tests.actor.fake_client.FakeDaprActorClient.get_state",
+        new=_async_mock(return_value=b'"value1"'),
+    )
     def test_add_or_update_state_for_removed_state(self):
         """add state value if state was removed."""
+
         def test_update_value(name, value):
-            return f'{name}-{value}'
+            return f"{name}-{value}"

         state_manager = ActorStateManager(self._fake_actor)
-        _run(state_manager.remove_state('state1'))
+        _run(state_manager.remove_state("state1"))

-        val = _run(state_manager.add_or_update_state('state1', 'value1', test_update_value))
-        self.assertEqual('value1', val)
+        val = _run(state_manager.add_or_update_state("state1", "value1", test_update_value))
+        self.assertEqual("value1", val)

     @mock.patch(
-        'tests.actor.fake_client.FakeDaprActorClient.get_state',
-        new=_async_mock(return_value=b'"value1"'))
+        "tests.actor.fake_client.FakeDaprActorClient.get_state",
+        new=_async_mock(return_value=b'"value1"'),
+    )
     def test_add_or_update_state_for_none_state_key(self):
-        """update state value for StateChangeKind.none state """
+        """update state value for StateChangeKind.none state"""
+
         def test_update_value(name, value):
-            return f'{name}-{value}'
+            return f"{name}-{value}"

         state_manager = ActorStateManager(self._fake_actor)
-        has_value, val = _run(state_manager.try_get_state('state1'))
+        has_value, val = _run(state_manager.try_get_state("state1"))
         self.assertTrue(has_value)
-        self.assertEqual('value1', val)
+        self.assertEqual("value1", val)

-        val = _run(state_manager.add_or_update_state('state1', 'value1', test_update_value))
-        self.assertEqual('state1-value1', val)
+        val = _run(state_manager.add_or_update_state("state1", "value1", test_update_value))
+        self.assertEqual("state1-value1", val)

     def test_add_or_update_state_without_update_value_factory(self):
-        """tries to add or update state without update_value_factory """
+        """tries to add or update state without update_value_factory"""
         state_manager = ActorStateManager(self._fake_actor)
         with self.assertRaises(AttributeError):
-            _run(state_manager.add_or_update_state('state1', 'value1', None))
+            _run(state_manager.add_or_update_state("state1", "value1", None))

-    @mock.patch('tests.actor.fake_client.FakeDaprActorClient.get_state', new=_async_mock())
+    @mock.patch("tests.actor.fake_client.FakeDaprActorClient.get_state", new=_async_mock())
     def test_get_state_names(self):
         state_manager = ActorStateManager(self._fake_actor)
-        _run(state_manager.set_state('state1', 'value1'))
-        _run(state_manager.set_state('state2', 'value2'))
-        _run(state_manager.set_state('state3', 'value3'))
+        _run(state_manager.set_state("state1", "value1"))
+        _run(state_manager.set_state("state2", "value2"))
+        _run(state_manager.set_state("state3", "value3"))
         names = _run(state_manager.get_state_names())
-        self.assertEqual(['state1', 'state2', 'state3'], names)
+        self.assertEqual(["state1", "state2", "state3"], names)

         self._fake_client.get_state.mock.assert_any_call(
-            self._test_type_info._name,
-            self._test_actor_id.id,
-            'state1')
+            self._test_type_info._name, self._test_actor_id.id, "state1"
+        )
         self._fake_client.get_state.mock.assert_any_call(
-            self._test_type_info._name,
-            self._test_actor_id.id,
-            'state2')
+            self._test_type_info._name, self._test_actor_id.id, "state2"
+        )
         self._fake_client.get_state.mock.assert_any_call(
-            self._test_type_info._name,
-            self._test_actor_id.id,
-            'state3')
+            self._test_type_info._name, self._test_actor_id.id, "state3"
+        )

     @mock.patch(
-        'tests.actor.fake_client.FakeDaprActorClient.get_state',
-        new=_async_mock(return_value=b'"value0"'))
+        "tests.actor.fake_client.FakeDaprActorClient.get_state",
+        new=_async_mock(return_value=b'"value0"'),
+    )
     def test_clear_cache(self):
         state_manager = ActorStateManager(self._fake_actor)
         state_change_tracker = state_manager._get_contextual_state_tracker()
-        _run(state_manager.set_state('state1', 'value1'))
-        _run(state_manager.set_state('state2', 'value2'))
-        _run(state_manager.set_state('state3', 'value3'))
+        _run(state_manager.set_state("state1", "value1"))
+        _run(state_manager.set_state("state2", "value2"))
+        _run(state_manager.set_state("state3", "value3"))
         _run(state_manager.clear_cache())

         self.assertEqual(0, len(state_change_tracker))

     @mock.patch(
-        'tests.actor.fake_client.FakeDaprActorClient.get_state',
-        new=_async_mock(return_value=b'"value3"'))
+        "tests.actor.fake_client.FakeDaprActorClient.get_state",
+        new=_async_mock(return_value=b'"value3"'),
+    )
     @mock.patch(
-        'tests.actor.fake_client.FakeDaprActorClient.save_state_transactionally',
-        new=_async_mock())
+        "tests.actor.fake_client.FakeDaprActorClient.save_state_transactionally", new=_async_mock()
+    )
     def test_save_state(self):
         state_manager = ActorStateManager(self._fake_actor)

         # set states which are StateChangeKind.add
-        _run(state_manager.set_state('state1', 'value1'))
-        _run(state_manager.set_state('state2', 'value2'))
+        _run(state_manager.set_state("state1", "value1"))
+        _run(state_manager.set_state("state2", "value2"))

-        has_value, val = _run(state_manager.try_get_state('state3'))
+        has_value, val = _run(state_manager.try_get_state("state3"))
         self.assertTrue(has_value)
         self.assertEqual("value3", val)

         # set state which is StateChangeKind.remove
-        _run(state_manager.remove_state('state4'))
+        _run(state_manager.remove_state("state4"))

         # set state which is StateChangeKind.update
-        _run(state_manager.set_state('state5', 'value5'))
+        _run(state_manager.set_state("state5", "value5"))

         expected = b'[{"operation":"upsert","request":{"key":"state1","value":"value1"}},{"operation":"upsert","request":{"key":"state2","value":"value2"}},{"operation":"delete","request":{"key":"state4"}},{"operation":"upsert","request":{"key":"state5","value":"value5"}}]'  # noqa: E501

         # Save the state
@@ -375,5 +370,5 @@ def mock_save_state(actor_type, actor_id, data):
         _run(state_manager.save_state())


-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
diff --git a/tests/actor/test_timer_data.py b/tests/actor/test_timer_data.py
index ec68a4923..eb1823270 100644
--- a/tests/actor/test_timer_data.py
+++ b/tests/actor/test_timer_data.py
@@ -24,12 +24,18 @@ class ActorTimerDataTests(unittest.TestCase):
     def test_timer_data(self):
         def my_callback(input: Any):
             print(input)
+
         timer = ActorTimerData(
-            'timer_name', my_callback, 'called',
-            timedelta(seconds=2), timedelta(seconds=1), timedelta(seconds=3))
-        self.assertEqual('timer_name', timer.timer_name)
-        self.assertEqual('my_callback', timer.callback)
-        self.assertEqual('called', timer.state)
+            "timer_name",
+            my_callback,
+            "called",
+            timedelta(seconds=2),
+            timedelta(seconds=1),
+            timedelta(seconds=3),
+        )
+        self.assertEqual("timer_name", timer.timer_name)
+        self.assertEqual("my_callback", timer.callback)
+        self.assertEqual("called", timer.state)
         self.assertEqual(timedelta(seconds=2), timer.due_time)
         self.assertEqual(timedelta(seconds=1), timer.period)
         self.assertEqual(timedelta(seconds=3), timer.ttl)
@@ -37,14 +43,20 @@ def my_callback(input: Any):
     def test_as_dict(self):
         def my_callback(input: Any):
             print(input)
+
         timer = ActorTimerData(
-            'timer_name', my_callback, 'called',
-            timedelta(seconds=1), timedelta(seconds=1), timedelta(seconds=1))
+            "timer_name",
+            my_callback,
+            "called",
+            timedelta(seconds=1),
+            timedelta(seconds=1),
+            timedelta(seconds=1),
+        )
         expected = {
-            'callback': 'my_callback',
-            'data': 'called',
-            'dueTime': timedelta(seconds=1),
-            'period': timedelta(seconds=1),
-            'ttl': timedelta(seconds=1),
+            "callback": "my_callback",
+            "data": "called",
+            "dueTime": timedelta(seconds=1),
+            "period": timedelta(seconds=1),
+            "ttl": timedelta(seconds=1),
         }
         self.assertDictEqual(expected, timer.as_dict())
diff --git a/tests/actor/test_type_utils.py b/tests/actor/test_type_utils.py
index 379da7360..3e4700088 100644
--- a/tests/actor/test_type_utils.py
+++ b/tests/actor/test_type_utils.py
@@ -17,9 +17,12 @@
 from dapr.actor.actor_interface import ActorInterface
 from dapr.actor.runtime._type_utils import (
-    get_class_method_args, get_method_arg_types,
-    get_method_return_types, is_dapr_actor,
-    get_actor_interfaces, get_dispatchable_attrs
+    get_class_method_args,
+    get_method_arg_types,
+    get_method_return_types,
+    is_dapr_actor,
+    get_actor_interfaces,
+    get_dispatchable_attrs,
 )

 from tests.actor.fake_actor_classes import (
@@ -33,7 +36,7 @@ class TypeUtilsTests(unittest.TestCase):
     def test_get_class_method_args(self):
         args = get_class_method_args(FakeSimpleActor.actor_method)
-        self.assertEqual(args, ['arg'])
+        self.assertEqual(args, ["arg"])

     def test_get_method_arg_types(self):
         arg_types = get_method_arg_types(FakeSimpleActor.non_actor_method)
@@ -66,10 +69,10 @@ def test_get_actor_interface(self):
     def test_get_dispatchable_attrs(self):
         dispatchable_attrs = get_dispatchable_attrs(FakeMultiInterfacesActor)
         expected_dispatchable_attrs = [
-            'ActorCls1Method',
-            'ActorCls1Method1',
-            'ActorCls1Method2',
-            'ActorCls2Method'
+            "ActorCls1Method",
+            "ActorCls1Method1",
+            "ActorCls1Method2",
+            "ActorCls2Method",
         ]

         method_cnt = 0
diff --git a/tests/clients/certs.py b/tests/clients/certs.py
index 4b05048b9..d32d10b3b 100644
--- a/tests/clients/certs.py
+++ b/tests/clients/certs.py
@@ -2,11 +2,11 @@
 from OpenSSL import crypto

-PRIVATE_KEY_PATH = os.path.join(os.path.dirname(__file__), 'private.key')
-CERTIFICATE_CHAIN_PATH = os.path.join(os.path.dirname(__file__), 'selfsigned.pem')
+PRIVATE_KEY_PATH = os.path.join(os.path.dirname(__file__), "private.key")
+CERTIFICATE_CHAIN_PATH = os.path.join(os.path.dirname(__file__), "selfsigned.pem")


-def create_certificates(server_type='grpc'):
+def create_certificates(server_type="grpc"):
     # create a key pair
     k = crypto.PKey()
     k.generate_key(crypto.TYPE_RSA, 4096)
@@ -20,10 +20,10 @@ def create_certificates(server_type='grpc'):
     cert.set_issuer(cert.get_subject())
     cert.set_pubkey(k)

-    if server_type == 'http':
+    if server_type == "http":
         cert.add_extensions([crypto.X509Extension(b"subjectAltName", False, b"DNS:localhost")])

-    cert.sign(k, 'sha512')
+    cert.sign(k, "sha512")

     f_cert = open(CERTIFICATE_CHAIN_PATH, "wt")
     f_cert.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode("utf-8"))
diff --git a/tests/clients/fake_dapr_server.py b/tests/clients/fake_dapr_server.py
index 9341850dd..7e2d87b8e 100644
--- a/tests/clients/fake_dapr_server.py
+++ b/tests/clients/fake_dapr_server.py
@@ -30,12 +30,15 @@
 )
 from typing import Dict

-from tests.clients.certs import create_certificates, delete_certificates, PRIVATE_KEY_PATH, \
-    CERTIFICATE_CHAIN_PATH
+from tests.clients.certs import (
+    create_certificates,
+    delete_certificates,
+    PRIVATE_KEY_PATH,
+    CERTIFICATE_CHAIN_PATH,
+)


 class FakeDaprSidecar(api_service_v1.DaprServicer):
-
     def __init__(self):
         self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
         api_service_v1.add_DaprServicer_to_server(self, self._server)
@@ -47,25 +50,25 @@ def __init__(self):
         self.metadata: Dict[str, str] = {}

     def start(self, port: int = 8080):
-        self._server.add_insecure_port(f'[::]:{port}')
+        self._server.add_insecure_port(f"[::]:{port}")
         self._server.start()

     def start_secure(self, port: int = 4443):
-
         create_certificates()
         private_key_file = open(PRIVATE_KEY_PATH, "rb")
         private_key_content = private_key_file.read()
         private_key_file.close()

-        certificate_chain_file = open(CERTIFICATE_CHAIN_PATH, 'rb')
+        certificate_chain_file = open(CERTIFICATE_CHAIN_PATH, "rb")
         certificate_chain_content = certificate_chain_file.read()
         certificate_chain_file.close()

         credentials = grpc.ssl_server_credentials(
-            [(private_key_content, certificate_chain_content)])
+            [(private_key_content, certificate_chain_content)]
+        )

-        self._server.add_secure_port(f'[::]:{port}', credentials)
+        self._server.add_secure_port(f"[::]:{port}", credentials)
         self._server.start()

     def stop(self):
@@ -80,13 +83,13 @@ def InvokeService(self, request, context) -> common_v1.InvokeResponse:
         trailers = ()

         for k, v in context.invocation_metadata():
-            headers = headers + (('h' + k, v), )
-            trailers = trailers + (('t' + k, v), )
+            headers = headers + (("h" + k, v),)
+            trailers = trailers + (("t" + k, v),)

         resp = GrpcAny()
-        content_type = ''
+        content_type = ""

-        if request.message.method == 'bytes':
+        if request.message.method == "bytes":
             resp.value = request.message.data.value
             content_type = request.message.content_type
         else:
@@ -102,13 +105,13 @@ def InvokeBinding(self, request, context) -> api_v1.InvokeBindingResponse:
         trailers = ()

         for k, v in request.metadata.items():
-            headers = headers + (('h' + k, v), )
-            trailers = trailers + (('t' + k, v), )
+            headers = headers + (("h" + k, v),)
+            trailers = trailers + (("t" + k, v),)

-        resp_data = b'INVALID'
+        resp_data = b"INVALID"
         metadata = {}

-        if request.operation == 'create':
+        if request.operation == "create":
             resp_data = request.data
             metadata = request.metadata
@@ -121,18 +124,18 @@ def PublishEvent(self, request, context):
         headers = ()
         trailers = ()
         if request.topic:
-            headers = headers + (('htopic', request.topic),)
-            trailers = trailers + (('ttopic', request.topic),)
+            headers = headers + (("htopic", request.topic),)
+            trailers = trailers + (("ttopic", request.topic),)
         if request.data:
-            headers = headers + (('hdata', request.data), )
-            trailers = trailers + (('hdata', request.data), )
+            headers = headers + (("hdata", request.data),)
+            trailers = trailers + (("hdata", request.data),)
         if request.data_content_type:
-            headers = headers + (('data_content_type', request.data_content_type), )
-            trailers = trailers + (('data_content_type', request.data_content_type), )
-        if request.metadata['rawPayload']:
-            headers = headers + (('metadata_raw_payload', request.metadata['rawPayload']), )
-        if request.metadata['ttlInSeconds']:
-            headers = headers + (('metadata_ttl_in_seconds', request.metadata['ttlInSeconds']), )
+            headers = headers + (("data_content_type", request.data_content_type),)
+            trailers = trailers + (("data_content_type", request.data_content_type),)
+        if request.metadata["rawPayload"]:
+            headers = headers + (("metadata_raw_payload", request.metadata["rawPayload"]),)
+        if request.metadata["ttlInSeconds"]:
+            headers = headers + (("metadata_ttl_in_seconds", request.metadata["ttlInSeconds"]),)
         context.send_initial_metadata(headers)
         context.set_trailing_metadata(trailers)
@@ -145,10 +148,10 @@ def SaveState(self, request, context):
             data = state.value
             if state.metadata["capitalize"]:
                 data = to_bytes(data.decode("utf-8").capitalize())
-            if state.HasField('etag'):
+            if state.HasField("etag"):
                 self.store[state.key] = (data, state.etag.value)
             else:
-                self.store[state.key] = (data, 'ETAG_WAS_NONE')
+                self.store[state.key] = (data, "ETAG_WAS_NONE")

         context.send_initial_metadata(headers)
         context.set_trailing_metadata(trailers)
@@ -158,10 +161,10 @@ def ExecuteStateTransaction(self, request, context):
         headers = ()
         trailers = ()
         for operation in request.operations:
-            if operation.operationType == 'delete':
+            if operation.operationType == "delete":
                 del self.store[operation.request.key]
             else:
-                etag = 'ETAG_WAS_NONE'
+                etag = "ETAG_WAS_NONE"
                 if operation.request.HasField("etag"):
                     etag = operation.request.etag.value
                 self.store[operation.request.key] = (operation.request.value, etag)
@@ -212,8 +215,8 @@ def GetSecret(self, request, context) -> api_v1.GetSecretResponse:
         key = request.key

-        headers = headers + (('keyh', key), )
-        trailers = trailers + (('keyt', key), )
+        headers = headers + (("keyh", key),)
+        trailers = trailers + (("keyt", key),)

         resp = {key: "val"}
@@ -226,8 +229,8 @@ def GetBulkSecret(self, request, context) -> api_v1.GetBulkSecretResponse:
         headers = ()
         trailers = ()

-        headers = headers + (('keyh', "bulk"), )
-        trailers = trailers + (('keyt', "bulk"), )
+        headers = headers + (("keyh", "bulk"),)
+        trailers = trailers + (("keyt", "bulk"),)

         resp = {"keya": api_v1.SecretResponse(secrets={"keyb": "val"})}
@@ -239,17 +242,15 @@ def GetBulkSecret(self, request, context) -> api_v1.GetBulkSecretResponse:
     def GetConfiguration(self, request, context):
         items = dict()
         for key in request.keys:
-            items[str(key)] = ConfigurationItem(value='value', version='1.5.0')
+            items[str(key)] = ConfigurationItem(value="value", version="1.5.0")
         return api_v1.GetConfigurationResponse(items=items)

     def SubscribeConfiguration(self, request, context):
         items = []
         for key in request.keys:
-            item = {'key': key, 'value': 'value', 'version': '1.5.0', 'metadata': {}}
+            item = {"key": key, "value": "value", "version": "1.5.0", "metadata": {}}
             items.append(item)
-        response = {
-            items: items
-        }
+        response = {items: items}
         responses = []
         responses.append(response)
         return api_v1.SubscribeConfigurationResponse(responses=responses)
@@ -258,18 +259,20 @@ def UnsubscribeConfiguration(self, request, context):
         return api_v1.UnsubscribeConfigurationResponse(ok=True)

     def QueryStateAlpha1(self, request, context):
-        items = [QueryStateItem(
-            key=str(key), data=bytes('value of ' + str(key), 'UTF-8')) for key in range(1, 11)]
+        items = [
+            QueryStateItem(key=str(key), data=bytes("value of " + str(key), "UTF-8"))
+            for key in range(1, 11)
+        ]
         query = json.loads(request.query)
         tokenIndex = 1
-        if 'page' in query:
-            if 'token' in query['page']:
+        if "page" in query:
+            if "token" in query["page"]:
                 # For testing purposes, we return a token that is the same as the key
-                tokenIndex = int(query['page']['token'])
-                items = items[tokenIndex - 1:]
-            if 'limit' in query['page']:
-                limit = int(query['page']['limit'])
+                tokenIndex = int(query["page"]["token"])
+                items = items[tokenIndex - 1 :]
+            if "limit" in query["page"]:
+                limit = int(query["page"]["limit"])
                 if len(items) > limit:
                     items = items[:limit]
                 tokenIndex = tokenIndex + len(items)
@@ -311,13 +314,15 @@ def GetWorkflowBeta1(self, request: GetWorkflowRequest, context):
         instance_id = request.instance_id
         if instance_id in self.workflow_status:
-            status = str(self.workflow_status[instance_id])[len("WorkflowRuntimeStatus."):]
-            return GetWorkflowResponse(instance_id=instance_id,
-                                       workflow_name="example",
-                                       created_at=None,
-                                       last_updated_at=None,
-                                       runtime_status=status,
-                                       properties=self.workflow_options)
+            status = str(self.workflow_status[instance_id])[len("WorkflowRuntimeStatus.") :]
+            return GetWorkflowResponse(
+                instance_id=instance_id,
+                workflow_name="example",
+                created_at=None,
+                last_updated_at=None,
+                runtime_status=status,
+                properties=self.workflow_options,
+            )
         else:
             # workflow non-existent
             raise Exception("Workflow instance does not exist")
@@ -373,7 +378,7 @@ def RaiseEventWorkflowBeta1(self, request: RaiseEventWorkflowRequest, context):

     def GetMetadata(self, request, context):
         return GetMetadataResponse(
-            id='myapp',
+            id="myapp",
             active_actors_count=[
                 ActiveActorsCount(
                     type="Nichelle Nichols",
@@ -388,10 +393,7 @@ def GetMetadata(self, request, context):
                     # Missing capabilities definition,
                 ),
                 RegisteredComponents(
-                    name="pubsub",
-                    type="pubsub.redis",
-                    version="v1",
-                    capabilities=[]
+                    name="pubsub", type="pubsub.redis", version="v1", capabilities=[]
                 ),
                 RegisteredComponents(
                     name="statestore",
diff --git a/tests/clients/fake_http_server.py b/tests/clients/fake_http_server.py
index 4de092368..a84248bbc 100644
--- a/tests/clients/fake_http_server.py
+++ b/tests/clients/fake_http_server.py
@@ -4,12 +4,16 @@
 from threading import Thread
 from http.server import BaseHTTPRequestHandler, HTTPServer

-from tests.clients.certs import CERTIFICATE_CHAIN_PATH, PRIVATE_KEY_PATH, create_certificates, \
-    delete_certificates
+from tests.clients.certs import (
+    CERTIFICATE_CHAIN_PATH,
+    PRIVATE_KEY_PATH,
+    create_certificates,
+    delete_certificates,
+)


 class DaprHandler(BaseHTTPRequestHandler):
-    protocol_version = 'HTTP/1.1'
+    protocol_version = "HTTP/1.1"

     def serve_forever(self):
         while not self.running:
@@ -20,15 +24,15 @@ def do_request(self, verb):
         time.sleep(self.server.sleep_time)
         self.received_verb = verb
         self.server.request_headers = self.headers
-        if 'Content-Length' in self.headers:
-            content_length = int(self.headers['Content-Length'])
+        if "Content-Length" in self.headers:
+            content_length = int(self.headers["Content-Length"])
             self.server.request_body += self.rfile.read(content_length)

         self.send_response(self.server.response_code)
         for key, value in self.server.response_header_list:
             self.send_header(key, value)
-        self.send_header('Content-Length', str(len(self.server.response_body)))
+        self.send_header("Content-Length", str(len(self.server.response_body)))
         self.end_headers()

         self.server.path = self.path
@@ -36,26 +40,25 @@ def do_request(self, verb):
         self.wfile.write(self.server.response_body)

     def do_GET(self):
-        self.do_request('GET')
+        self.do_request("GET")

     def do_POST(self):
-        self.do_request('POST')
+        self.do_request("POST")

     def do_PUT(self):
-        self.do_request('PUT')
+        self.do_request("PUT")

     def do_DELETE(self):
-        self.do_request('DELETE')
+        self.do_request("DELETE")


 class FakeHttpServer(Thread):
-
     def __init__(self, secure=False):
         super().__init__()
         self.secure = secure

         self.port = 4443 if secure else 8080
-        self.server = HTTPServer(('localhost', self.port), DaprHandler)
+        self.server = HTTPServer(("localhost", self.port), DaprHandler)

         if self.secure:
             create_certificates("http")
@@ -63,10 +66,10 @@ def __init__(self, secure=False):
             ssl_context.load_cert_chain(CERTIFICATE_CHAIN_PATH, PRIVATE_KEY_PATH)
             self.server.socket = ssl_context.wrap_socket(self.server.socket, server_side=True)

-        self.server.response_body = b''
+        self.server.response_body = b""
         self.server.response_code = 200
         self.server.response_header_list = []
-        self.server.request_body = b''
+        self.server.request_body = b""
         self.server.sleep_time = None

     def get_port(self):
diff --git a/tests/clients/test_client_interceptor.py b/tests/clients/test_client_interceptor.py
index 1f356a259..8f6945345 100644
--- a/tests/clients/test_client_interceptor.py
+++ b/tests/clients/test_client_interceptor.py
@@ -19,7 +19,6 @@


 class DaprClientInterceptorTests(unittest.TestCase):
-
     def setUp(self):
         self._fake_request = "fake request"

@@ -27,21 +26,23 @@ def fake_continuation(self, call_details, request):
         return call_details

     def test_intercept_unary_unary_single_header(self):
-        interceptor = DaprClientInterceptor([('api-token', 'test-token')])
+        interceptor = DaprClientInterceptor([("api-token", "test-token")])
         call_details = _ClientCallDetails("method1", 10, None, None, None, None)
         response = interceptor.intercept_unary_unary(
-            self.fake_continuation, call_details, self._fake_request)
+            self.fake_continuation, call_details, self._fake_request
+        )
         self.assertIsNotNone(response)
         self.assertEqual(1, len(response.metadata))
-        self.assertEqual([('api-token', 'test-token')], response.metadata)
+        self.assertEqual([("api-token", "test-token")], response.metadata)

     def test_intercept_unary_unary_existing_metadata(self):
-        interceptor = DaprClientInterceptor([('api-token', 'test-token')])
-        call_details = _ClientCallDetails("method1", 10, [('header', 'value')], None, None, None)
+        interceptor = DaprClientInterceptor([("api-token", "test-token")])
+        call_details = _ClientCallDetails("method1", 10, [("header", "value")], None, None, None)
         response = interceptor.intercept_unary_unary(
-            self.fake_continuation, call_details, self._fake_request)
+            self.fake_continuation, call_details, self._fake_request
+        )
         self.assertIsNotNone(response)
         self.assertEqual(2, len(response.metadata))
-        self.assertEqual([('header', 'value'), ('api-token', 'test-token')], response.metadata)
+        self.assertEqual([("header", "value"), ("api-token", "test-token")], response.metadata)
diff --git a/tests/clients/test_dapr_async_grpc_client.py b/tests/clients/test_dapr_async_grpc_client.py
index f81796a50..5374ab6c6 100644
--- a/tests/clients/test_dapr_async_grpc_client.py
+++ b/tests/clients/test_dapr_async_grpc_client.py
@@ -38,7 +38,7 @@
 class DaprGrpcClientAsyncTests(unittest.IsolatedAsyncioTestCase):
     server_port = 8080
-    scheme = ''
+    scheme = ""

     def setUp(self):
         self._fake_dapr_server = FakeDaprSidecar()
@@ -48,227 +48,223 @@ def tearDown(self):
         self._fake_dapr_server.stop()

     async def test_http_extension(self):
-        dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}')
+        dapr = DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}")

         # Test POST verb without querystring
-        ext = dapr._get_http_extension('POST')
+        ext = dapr._get_http_extension("POST")
         self.assertEqual(common_v1.HTTPExtension.Verb.POST, ext.verb)

         # Test Non-supported http verb
         with self.assertRaises(ValueError):
-            ext = dapr._get_http_extension('')
+            ext = dapr._get_http_extension("")

         # Test POST verb with querystring
         qs = (
-            ('query1', 'string1'),
-            ('query2', 'string2'),
-            ('query1', 'string 3'),
+            ("query1", "string1"),
+            ("query2", "string2"),
+            ("query1", "string 3"),
         )
-        ext = dapr._get_http_extension('POST', qs)
+        ext = dapr._get_http_extension("POST", qs)

         self.assertEqual(common_v1.HTTPExtension.Verb.POST, ext.verb)
         self.assertEqual("query1=string1&query2=string2&query1=string+3", ext.querystring)

     async def test_invoke_method_bytes_data(self):
-        dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}')
+        dapr = DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}")
         resp = await dapr.invoke_method(
-            app_id='targetId',
-            method_name='bytes',
-            data=b'haha',
+            app_id="targetId",
+            method_name="bytes",
+            data=b"haha",
             content_type="text/plain",
             metadata=(
-                ('key1', 'value1'),
-                ('key2', 'value2'),
+                ("key1", "value1"),
+                ("key2", "value2"),
             ),
-            http_verb='PUT',
+            http_verb="PUT",
         )

-        self.assertEqual(b'haha', resp.data)
+        self.assertEqual(b"haha", resp.data)
         self.assertEqual("text/plain", resp.content_type)
         self.assertEqual(3, len(resp.headers))
-        self.assertEqual(['value1'], resp.headers['hkey1'])
+        self.assertEqual(["value1"], resp.headers["hkey1"])

     async def test_invoke_method_no_data(self):
-        dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}')
+        dapr = DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}")
         resp = await dapr.invoke_method(
-            app_id='targetId',
-            method_name='bytes',
+            app_id="targetId",
+            method_name="bytes",
             content_type="text/plain",
             metadata=(
-                ('key1', 'value1'),
-                ('key2', 'value2'),
+                ("key1", "value1"),
+                ("key2", "value2"),
            ),
-            http_verb='PUT',
+            http_verb="PUT",
         )

-        self.assertEqual(b'', resp.data)
+        self.assertEqual(b"", resp.data)
         self.assertEqual("text/plain", resp.content_type)
         self.assertEqual(3, len(resp.headers))
-        self.assertEqual(['value1'], resp.headers['hkey1'])
+        self.assertEqual(["value1"], resp.headers["hkey1"])

     async def test_invoke_method_with_dapr_client(self):
-        dapr = DaprClient(f'{self.scheme}localhost:{self.server_port}')
+        dapr = DaprClient(f"{self.scheme}localhost:{self.server_port}")
         dapr.invocation_client = None  # force to use grpc client

         resp = await dapr.invoke_method(
-            app_id='targetId',
-            method_name='bytes',
-            data=b'haha',
+            app_id="targetId",
+            method_name="bytes",
+            data=b"haha",
             content_type="text/plain",
             metadata=(
-                ('key1', 'value1'),
-                ('key2', 'value2'),
+                ("key1", "value1"),
+                ("key2", "value2"),
             ),
-            http_verb='PUT',
+            http_verb="PUT",
         )

-        self.assertEqual(b'haha', resp.data)
+        self.assertEqual(b"haha", resp.data)
         self.assertEqual("text/plain", resp.content_type)
         self.assertEqual(3, len(resp.headers))
-        self.assertEqual(['value1'], resp.headers['hkey1'])
+        self.assertEqual(["value1"], resp.headers["hkey1"])

     async def test_invoke_method_proto_data(self):
-        dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}')
-        req = common_v1.StateItem(key='test')
+        dapr = DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}")
+        req = common_v1.StateItem(key="test")
         resp = await dapr.invoke_method(
-            app_id='targetId',
-            method_name='proto',
+            app_id="targetId",
+            method_name="proto",
             data=req,
             metadata=(
-                ('key1', 'value1'),
-                ('key2', 'value2'),
+                ("key1", "value1"),
+                ("key2", "value2"),
             ),
         )

         self.assertEqual(3, len(resp.headers))
-        self.assertEqual(['value1'], resp.headers['hkey1'])
+        self.assertEqual(["value1"], resp.headers["hkey1"])
         self.assertTrue(resp.is_proto())

         # unpack to new protobuf object
         new_resp = common_v1.StateItem()
         resp.unpack(new_resp)
-        self.assertEqual('test', new_resp.key)
+        self.assertEqual("test", new_resp.key)

     async def test_invoke_binding_bytes_data(self):
-        dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}')
+        dapr = DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}")
         resp = await dapr.invoke_binding(
-            binding_name='binding',
-            operation='create',
-            data=b'haha',
+            binding_name="binding",
+            operation="create",
+            data=b"haha",
             binding_metadata={
-                'key1': 'value1',
-                'key2': 'value2',
+                "key1": "value1",
+                "key2": "value2",
             },
         )

-        self.assertEqual(b'haha', resp.data)
-        self.assertEqual({'key1': 'value1', 'key2': 'value2'}, resp.binding_metadata)
+        self.assertEqual(b"haha", resp.data)
+        self.assertEqual({"key1": "value1", "key2": "value2"}, resp.binding_metadata)
         self.assertEqual(2, len(resp.headers))
-        self.assertEqual(['value1'], resp.headers['hkey1'])
+        self.assertEqual(["value1"], resp.headers["hkey1"])

     async def test_invoke_binding_no_metadata(self):
-        dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}')
+        dapr = DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}")
         resp = await dapr.invoke_binding(
-            binding_name='binding',
-            operation='create',
-            data=b'haha',
+            binding_name="binding",
+            operation="create",
+            data=b"haha",
         )

-        self.assertEqual(b'haha', resp.data)
+        self.assertEqual(b"haha", resp.data)
         self.assertEqual({}, resp.binding_metadata)
         self.assertEqual(0, len(resp.headers))

     async def test_invoke_binding_no_data(self):
-        dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}')
+        dapr = DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}")
         resp = await dapr.invoke_binding(
-            binding_name='binding',
-            operation='create',
+            binding_name="binding",
+            operation="create",
         )

-        self.assertEqual(b'', resp.data)
+        self.assertEqual(b"", resp.data)
         self.assertEqual({}, resp.binding_metadata)
         self.assertEqual(0, len(resp.headers))

     async def test_invoke_binding_no_create(self):
-        dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}')
+        dapr = DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}")
         resp = await dapr.invoke_binding(
-            binding_name='binding',
-            operation='delete',
-            data=b'haha',
+            binding_name="binding",
+            operation="delete",
+            data=b"haha",
         )

-        self.assertEqual(b'INVALID', resp.data)
+        self.assertEqual(b"INVALID", resp.data)
         self.assertEqual({}, resp.binding_metadata)
         self.assertEqual(0, len(resp.headers))

     async def test_publish_event(self):
-        dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}')
-        resp = await dapr.publish_event(
-            pubsub_name='pubsub',
-            topic_name='example',
-            data=b'haha'
-        )
+        dapr = DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}")
+        resp = await dapr.publish_event(pubsub_name="pubsub", topic_name="example", data=b"haha")

         self.assertEqual(2, len(resp.headers))
-        self.assertEqual(['haha'], resp.headers['hdata'])
+        self.assertEqual(["haha"], resp.headers["hdata"])

     async def test_publish_event_with_content_type(self):
-        dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}')
+        dapr = DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}")
         resp = await dapr.publish_event(
-            pubsub_name='pubsub',
-            topic_name='example',
+            pubsub_name="pubsub",
+            topic_name="example",
             data=b'{"foo": "bar"}',
-            data_content_type='application/json'
+            data_content_type="application/json",
         )

         self.assertEqual(3, len(resp.headers))
-        self.assertEqual(['{"foo": "bar"}'], resp.headers['hdata'])
-        self.assertEqual(['application/json'], resp.headers['data_content_type'])
+        self.assertEqual(['{"foo": "bar"}'], resp.headers["hdata"])
+        self.assertEqual(["application/json"], resp.headers["data_content_type"])

     async def test_publish_event_with_metadata(self):
-        dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}')
+        dapr = DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}")
         resp = await dapr.publish_event(
-            pubsub_name='pubsub',
-            topic_name='example',
+            pubsub_name="pubsub",
+            topic_name="example",
             data=b'{"foo": "bar"}',
-            publish_metadata={'ttlInSeconds': '100', 'rawPayload': 'false'}
+            publish_metadata={"ttlInSeconds": "100", "rawPayload": "false"},
         )

         print(resp.headers)
-        self.assertEqual(['{"foo": "bar"}'], resp.headers['hdata'])
-        self.assertEqual(['false'], resp.headers['metadata_raw_payload'])
-        self.assertEqual(['100'], resp.headers['metadata_ttl_in_seconds'])
+        self.assertEqual(['{"foo": "bar"}'], resp.headers["hdata"])
+        self.assertEqual(["false"], resp.headers["metadata_raw_payload"])
+        self.assertEqual(["100"], resp.headers["metadata_ttl_in_seconds"])

     async def test_publish_error(self):
-        dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}')
+        dapr = DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}")
         with self.assertRaisesRegex(ValueError, "invalid type for data "):
             await dapr.publish_event(
-                pubsub_name='pubsub',
-                topic_name='example',
+                pubsub_name="pubsub",
+                topic_name="example",
                 data=111,
             )

-    @patch.object(settings, 'DAPR_API_TOKEN', 'test-token')
+    @patch.object(settings, "DAPR_API_TOKEN", "test-token")
     async def test_dapr_api_token_insertion(self):
-        dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}')
+        dapr = DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}")
         resp = await dapr.invoke_method(
-            app_id='targetId',
-            method_name='bytes',
-            data=b'haha',
+            app_id="targetId",
+            method_name="bytes",
+            data=b"haha",
             content_type="text/plain",
             metadata=(
-                ('key1', 'value1'),
-                ('key2', 'value2'),
+                ("key1", "value1"),
+                ("key2", "value2"),
             ),
         )

-        self.assertEqual(b'haha', resp.data)
+        self.assertEqual(b"haha", resp.data)
         self.assertEqual("text/plain", resp.content_type)
         self.assertEqual(4, len(resp.headers))
-        self.assertEqual(['value1'], resp.headers['hkey1'])
-        self.assertEqual(['test-token'], resp.headers['hdapr-api-token'])
+        self.assertEqual(["value1"], resp.headers["hkey1"])
+        self.assertEqual(["test-token"], resp.headers["hdapr-api-token"])

     async def test_get_save_delete_state(self):
-        dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}')
+        dapr = DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}")
         key = "key_1"
         value = "value_1"
         options = StateOptions(
@@ -279,9 +275,9 @@ async def test_get_save_delete_state(self):
             store_name="statestore",
             key=key,
             value=value,
-            etag='fake_etag',
+            etag="fake_etag",
             options=options,
-            state_metadata={"capitalize": "1"}
+            state_metadata={"capitalize": "1"},
         )

         resp = await dapr.get_state(store_name="statestore", key=key)
@@ -293,43 +289,34 @@ async def test_get_save_delete_state(self):
         self.assertEqual(resp.etag, "fake_etag")

         resp = await dapr.get_state(store_name="statestore", key="NotValidKey")
-        self.assertEqual(resp.data, b'')
-        self.assertEqual(resp.etag, '')
+        self.assertEqual(resp.data, b"")
+        self.assertEqual(resp.etag, "")

-        await dapr.delete_state(
-            store_name="statestore",
-            key=key
-        )
+        await dapr.delete_state(store_name="statestore", key=key)
         resp = await dapr.get_state(store_name="statestore", key=key)
-        self.assertEqual(resp.data, b'')
-        self.assertEqual(resp.etag, '')
+        self.assertEqual(resp.data, b"")
+        self.assertEqual(resp.etag, "")

         with self.assertRaises(Exception) as context:
             await dapr.delete_state(
-                store_name="statestore",
-                key=key,
-                state_metadata={"must_delete": "1"})
+                store_name="statestore", key=key, state_metadata={"must_delete": "1"}
+            )
         print(context.exception)
-        self.assertTrue('delete failed' in str(context.exception))
+        self.assertTrue("delete failed" in str(context.exception))

     async def test_get_save_state_etag_none(self):
-        dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}')
+        dapr = DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}")

-        value = 'test'
-        no_etag_key = 'no_etag'
-        empty_etag_key = 'empty_etag'
+        value = "test"
+        no_etag_key = "no_etag"
+        empty_etag_key = "empty_etag"
         await dapr.save_state(
             store_name="statestore",
             key=no_etag_key,
             value=value,
         )

-        await dapr.save_state(
-            store_name="statestore",
-            key=empty_etag_key,
-            value=value,
-            etag=""
-        )
+        await dapr.save_state(store_name="statestore", key=empty_etag_key, value=value, etag="")

         resp = await dapr.get_state(store_name="statestore", key=no_etag_key)
         self.assertEqual(resp.data, to_bytes(value))
@@ -340,7 +327,7 @@ async def test_get_save_state_etag_none(self):
         self.assertEqual(resp.etag, "")

     async def test_transaction_then_get_states(self):
-        dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}')
+        dapr = DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}")

         key = str(uuid.uuid4())
         value = str(uuid.uuid4())
@@ -353,7 +340,7 @@ async def test_transaction_then_get_states(self):
                 TransactionalStateOperation(key=key, data=value, etag="foo"),
                 TransactionalStateOperation(key=another_key, data=another_value),
             ],
-            transactional_metadata={"metakey": "metavalue"}
+            transactional_metadata={"metakey": "metavalue"},
         )

         resp = await dapr.get_bulk_state(store_name="statestore", keys=[key, another_key])
@@ -365,16 +352,15 @@ async def test_transaction_then_get_states(self):
         self.assertEqual(resp.items[1].etag, "ETAG_WAS_NONE")

         resp = await dapr.get_bulk_state(
-            store_name="statestore",
-            keys=[key, another_key],
-            states_metadata={"upper": "1"})
+            store_name="statestore", keys=[key, another_key], states_metadata={"upper": "1"}
+        )
         self.assertEqual(resp.items[0].key, key)
         self.assertEqual(resp.items[0].data, to_bytes(value.upper()))
         self.assertEqual(resp.items[1].key, another_key)
         self.assertEqual(resp.items[1].data, to_bytes(another_value.upper()))

     async def test_save_then_get_states(self):
-        dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}')
+        dapr = DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}")

         key = str(uuid.uuid4())
         value = str(uuid.uuid4())
@@ -387,7 +373,7 @@ async def test_save_then_get_states(self):
                 StateItem(key=key, value=value, metadata={"capitalize": "1"}),
                 StateItem(key=another_key, value=another_value, etag="1"),
             ],
-            metadata=(("metakey", "metavalue"),)
+            metadata=(("metakey", "metavalue"),),
         )

         resp = await dapr.get_bulk_state(store_name="statestore", keys=[key, another_key])
@@ -399,9 +385,8 @@ async def test_save_then_get_states(self):
         self.assertEqual(resp.items[1].etag, "1")

         resp = await dapr.get_bulk_state(
-            store_name="statestore",
-            keys=[key, another_key],
-            states_metadata={"upper": "1"})
+            store_name="statestore", keys=[key, another_key], states_metadata={"upper": "1"}
+        )
         self.assertEqual(resp.items[0].key, key)
         self.assertEqual(resp.items[0].etag, "ETAG_WAS_NONE")
         self.assertEqual(resp.items[0].data, to_bytes(value.upper()))
@@ -410,57 +395,57 @@ async def test_save_then_get_states(self):
         self.assertEqual(resp.items[1].data, to_bytes(another_value.upper()))

     async def test_get_secret(self):
-        dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}')
-        key1 = 'key_1'
+        dapr = DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}")
+        key1 = "key_1"
         resp = await dapr.get_secret(
-            store_name='store_1',
+            store_name="store_1",
             key=key1,
             metadata=(
-                ('key1', 'value1'),
-                ('key2', 'value2'),
+                ("key1", "value1"),
+                ("key2", "value2"),
            ),
         )

         self.assertEqual(1, len(resp.headers))
-        self.assertEqual([key1], resp.headers['keyh'])
+        self.assertEqual([key1], resp.headers["keyh"])
         self.assertEqual({key1: "val"}, resp._secret)

     async def test_get_secret_metadata_absent(self):
-        dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}')
-        key1 = 'key_1'
+        dapr = DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}")
+        key1 = "key_1"
         resp = await dapr.get_secret(
-            store_name='store_1',
+            store_name="store_1",
             key=key1,
         )

         self.assertEqual(1, len(resp.headers))
-        self.assertEqual([key1], resp.headers['keyh'])
+        self.assertEqual([key1], resp.headers["keyh"])
         self.assertEqual({key1: "val"},
resp._secret) async def test_get_bulk_secret(self): - dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}") resp = await dapr.get_bulk_secret( - store_name='store_1', + store_name="store_1", metadata=( - ('key1', 'value1'), - ('key2', 'value2'), + ("key1", "value1"), + ("key2", "value2"), ), ) self.assertEqual(1, len(resp.headers)) - self.assertEqual(["bulk"], resp.headers['keyh']) + self.assertEqual(["bulk"], resp.headers["keyh"]) self.assertEqual({"keya": {"keyb": "val"}}, resp._secrets) async def test_get_bulk_secret_metadata_absent(self): - dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}') - resp = await dapr.get_bulk_secret(store_name='store_1') + dapr = DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}") + resp = await dapr.get_bulk_secret(store_name="store_1") self.assertEqual(1, len(resp.headers)) - self.assertEqual(["bulk"], resp.headers['keyh']) + self.assertEqual(["bulk"], resp.headers["keyh"]) self.assertEqual({"keya": {"keyb": "val"}}, resp._secrets) async def test_get_configuration(self): - dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}") keys = ["k", "k1"] value = "value" version = "1.5.0" @@ -475,7 +460,8 @@ async def test_get_configuration(self): self.assertEqual(item.metadata, metadata) resp = await dapr.get_configuration( - store_name="configurationstore", keys=keys, config_metadata=metadata) + store_name="configurationstore", keys=keys, config_metadata=metadata + ) self.assertEqual(len(resp.items), len(keys)) self.assertIn(keys[0], resp.items) item = resp.items[keys[0]] @@ -484,31 +470,33 @@ async def test_get_configuration(self): self.assertEqual(item.metadata, metadata) async def test_subscribe_configuration(self): - dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}") def mock_watch(self, stub, store_name, keys, handler, config_metadata): - handler("id", ConfigurationResponse(items={ - "k": ConfigurationItem( - value="test", - version="1.7.0") - })) + handler( + "id", + ConfigurationResponse( + items={"k": ConfigurationItem(value="test", version="1.7.0")} + ), + ) return "id" def handler(id: str, resp: ConfigurationResponse): self.assertEqual(resp.items["k"].value, "test") self.assertEqual(resp.items["k"].version, "1.7.0") - with patch.object(ConfigurationWatcher, 'watch_configuration', mock_watch): + with patch.object(ConfigurationWatcher, "watch_configuration", mock_watch): await dapr.subscribe_configuration( - store_name="configurationstore", keys=["k"], handler=handler) + store_name="configurationstore", keys=["k"], handler=handler + ) async def test_unsubscribe_configuration(self): - dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}") res = await dapr.unsubscribe_configuration(store_name="configurationstore", id="k") self.assertTrue(res) async def test_query_state(self): - dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}") resp = await dapr.query_state( store_name="statestore", @@ -525,29 +513,29 @@ async def test_query_state(self): self.assertEqual(len(resp.results), 3) async def test_shutdown(self): - dapr = 
DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}") await dapr.shutdown() self.assertTrue(self._fake_dapr_server.shutdown_received) async def test_wait_ok(self): - dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}") await dapr.wait(0.1) async def test_wait_timeout(self): # First, pick an unused port port = 0 with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(('', 0)) + s.bind(("", 0)) port = s.getsockname()[1] - dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{port}') + dapr = DaprGrpcClientAsync(f"{self.scheme}localhost:{port}") with self.assertRaises(Exception) as context: await dapr.wait(0.1) - self.assertTrue('Connection refused' in str(context.exception)) + self.assertTrue("Connection refused" in str(context.exception)) async def test_lock_acquire_success(self): - dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}") # Lock parameters - store_name = 'lockstore' + store_name = "lockstore" resource_id = str(uuid.uuid4()) lock_owner = str(uuid.uuid4()) expiry_in_seconds = 60 @@ -558,9 +546,9 @@ async def test_lock_acquire_success(self): self.assertEqual(UnlockResponseStatus.success, unlock_response.status) async def test_lock_release_twice_fails(self): - dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}") # Lock parameters - store_name = 'lockstore' + store_name = "lockstore" resource_id = str(uuid.uuid4()) lock_owner = str(uuid.uuid4()) expiry_in_seconds = 60 @@ -574,9 +562,9 @@ async def test_lock_release_twice_fails(self): self.assertEqual(UnlockResponseStatus.lock_does_not_exist, unlock_response.status) async def test_lock_conflict(self): - dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}") # Lock parameters - store_name = 'lockstore' + store_name = "lockstore" resource_id = str(uuid.uuid4()) first_client_id = str(uuid.uuid4()) second_client_id = str(uuid.uuid4()) @@ -596,28 +584,29 @@ async def test_lock_conflict(self): self.assertEqual(UnlockResponseStatus.success, unlock_response.status) async def test_lock_not_previously_acquired(self): - dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}") unlock_response = await dapr.unlock( - store_name='lockstore', - resource_id=str(uuid.uuid4()), - lock_owner=str(uuid.uuid4())) + store_name="lockstore", resource_id=str(uuid.uuid4()), lock_owner=str(uuid.uuid4()) + ) self.assertEqual(UnlockResponseStatus.lock_does_not_exist, unlock_response.status) async def test_lock_release_twice_fails_with_context_manager(self): - dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}") # Lock parameters - store_name = 'lockstore' + store_name = "lockstore" resource_id = str(uuid.uuid4()) first_client_id = str(uuid.uuid4()) second_client_id = str(uuid.uuid4()) expiry = 60 - async with await dapr.try_lock(store_name, resource_id, first_client_id, expiry - ) as first_lock: + async with await dapr.try_lock( + store_name, resource_id, first_client_id, expiry + ) as first_lock: 
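# The `async with` form above is the behavior these lock tests pin down:
# try_lock returns a response object that doubles as an async context manager
# and releases the lock on exit. A hedged usage sketch reusing the same call
# shape (a connected async client `dapr` is assumed, as in this test class):
async def run_guarded(dapr, resource_id, owner):
    async with await dapr.try_lock("lockstore", resource_id, owner, 60) as lock:
        if lock.success:
            ...  # critical section; the lock is auto-released when the block exits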
self.assertTrue(first_lock.success) # If another client tries to acquire the same lock it will fail - async with await dapr.try_lock(store_name, resource_id, second_client_id, expiry - ) as second_lock: + async with await dapr.try_lock( + store_name, resource_id, second_client_id, expiry + ) as second_lock: self.assertFalse(second_lock.success) # At this point lock was auto-released # If client tries again it will discover the lock is gone @@ -625,60 +614,66 @@ async def test_lock_release_twice_fails_with_context_manager(self): self.assertEqual(UnlockResponseStatus.lock_does_not_exist, unlock_response.status) async def test_lock_are_not_reentrant(self): - dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}") # Lock parameters - store_name = 'lockstore' + store_name = "lockstore" resource_id = str(uuid.uuid4()) client_id = str(uuid.uuid4()) expiry_in_s = 60 - async with await dapr.try_lock(store_name, resource_id, client_id, expiry_in_s - ) as first_attempt: + async with await dapr.try_lock( + store_name, resource_id, client_id, expiry_in_s + ) as first_attempt: self.assertTrue(first_attempt.success) # If the same client tries to acquire the same lock again it will fail. - async with await dapr.try_lock(store_name, resource_id, client_id, expiry_in_s - ) as second_attempt: + async with await dapr.try_lock( + store_name, resource_id, client_id, expiry_in_s + ) as second_attempt: self.assertFalse(second_attempt.success) async def test_lock_input_validation(self): - dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}") # Sane parameters - store_name = 'lockstore' + store_name = "lockstore" resource_id = str(uuid.uuid4()) client_id = str(uuid.uuid4()) expiry_in_s = 60 # Invalid inputs for string arguments - for invalid_input in [None, '', ' ']: + for invalid_input in [None, "", " "]: # store_name with self.assertRaises(ValueError): - async with await dapr.try_lock(invalid_input, resource_id, client_id, expiry_in_s - ) as res: + async with await dapr.try_lock( + invalid_input, resource_id, client_id, expiry_in_s + ) as res: self.assertTrue(res.success) # resource_id with self.assertRaises(ValueError): - async with await dapr.try_lock(store_name, invalid_input, client_id, expiry_in_s - ) as res: + async with await dapr.try_lock( + store_name, invalid_input, client_id, expiry_in_s + ) as res: self.assertTrue(res.success) # client_id with self.assertRaises(ValueError): - async with await dapr.try_lock(store_name, resource_id, invalid_input, expiry_in_s - ) as res: + async with await dapr.try_lock( + store_name, resource_id, invalid_input, expiry_in_s + ) as res: self.assertTrue(res.success) # Invalid inputs for expiry_in_s for invalid_input in [None, -1, 0]: with self.assertRaises(ValueError): - async with await dapr.try_lock(store_name, resource_id, client_id, invalid_input - ) as res: + async with await dapr.try_lock( + store_name, resource_id, client_id, invalid_input + ) as res: self.assertTrue(res.success) async def test_unlock_input_validation(self): - dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}") # Sane parameters - store_name = 'lockstore' + store_name = "lockstore" resource_id = str(uuid.uuid4()) client_id = str(uuid.uuid4()) # Invalid inputs for string arguments - for invalid_input in [None, '', ' ']: + 
for invalid_input in [None, "", " "]: # store_name with self.assertRaises(ValueError): await dapr.unlock(invalid_input, resource_id, client_id) @@ -694,12 +689,12 @@ async def test_unlock_input_validation(self): # async def test_get_metadata(self): - async with DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}') as dapr: + async with DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}") as dapr: response = await dapr.get_metadata() self.assertIsNotNone(response) - self.assertEqual(response.application_id, 'myapp') + self.assertEqual(response.application_id, "myapp") actors = response.active_actors_count self.assertIsNotNone(actors) @@ -718,39 +713,35 @@ async def test_get_metadata(self): self.assertTrue(c.type) self.assertIsNotNone(c.version) self.assertIsNotNone(c.capabilities) - self.assertTrue("ETAG" in components['statestore'].capabilities) + self.assertTrue("ETAG" in components["statestore"].capabilities) self.assertIsNotNone(response.extended_metadata) async def test_set_metadata(self): metadata_key = "test_set_metadata_attempt" - async with DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}') as dapr: + async with DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}") as dapr: for metadata_value in [str(i) for i in range(10)]: - await dapr.set_metadata(attributeName=metadata_key, - attributeValue=metadata_value) + await dapr.set_metadata(attributeName=metadata_key, attributeValue=metadata_value) response = await dapr.get_metadata() self.assertIsNotNone(response) self.assertIsNotNone(response.extended_metadata) - self.assertEqual(response.extended_metadata[metadata_key], - metadata_value) + self.assertEqual(response.extended_metadata[metadata_key], metadata_value) # Empty string and blank strings should be accepted just fine # by this API - for metadata_value in ['', ' ']: - await dapr.set_metadata(attributeName=metadata_key, - attributeValue=metadata_value) + for metadata_value in ["", " "]: + await dapr.set_metadata(attributeName=metadata_key, attributeValue=metadata_value) response = await dapr.get_metadata() self.assertIsNotNone(response) self.assertIsNotNone(response.extended_metadata) - self.assertEqual(response.extended_metadata[metadata_key], - metadata_value) + self.assertEqual(response.extended_metadata[metadata_key], metadata_value) async def test_set_metadata_input_validation(self): - dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}') - valid_attr_name = 'attribute name' - valid_attr_value = 'attribute value' + dapr = DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}") + valid_attr_name = "attribute name" + valid_attr_value = "attribute value" # Invalid inputs for string arguments - async with DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}') as dapr: - for invalid_attr_name in [None, '', ' ']: + async with DaprGrpcClientAsync(f"{self.scheme}localhost:{self.server_port}") as dapr: + for invalid_attr_name in [None, "", " "]: with self.assertRaises(ValueError): await dapr.set_metadata(invalid_attr_name, valid_attr_value) # We are less strict with attribute values - we just cannot accept None @@ -759,5 +750,5 @@ async def test_set_metadata_input_validation(self): await dapr.set_metadata(valid_attr_name, invalid_attr_value) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/clients/test_dapr_grpc_client.py b/tests/clients/test_dapr_grpc_client.py index 54e98801f..43a9ad6e5 100644 --- a/tests/clients/test_dapr_grpc_client.py +++ 
b/tests/clients/test_dapr_grpc_client.py @@ -39,7 +39,7 @@ class DaprGrpcClientTests(unittest.TestCase): server_port = 8080 - scheme = '' + scheme = "" def setUp(self): self._fake_dapr_server = FakeDaprSidecar() @@ -49,227 +49,223 @@ def tearDown(self): self._fake_dapr_server.stop() def test_http_extension(self): - dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") # Test POST verb without querystring - ext = dapr._get_http_extension('POST') + ext = dapr._get_http_extension("POST") self.assertEqual(common_v1.HTTPExtension.Verb.POST, ext.verb) # Test Non-supported http verb with self.assertRaises(ValueError): - ext = dapr._get_http_extension('') + ext = dapr._get_http_extension("") # Test POST verb with querystring qs = ( - ('query1', 'string1'), - ('query2', 'string2'), - ('query1', 'string 3'), + ("query1", "string1"), + ("query2", "string2"), + ("query1", "string 3"), ) - ext = dapr._get_http_extension('POST', qs) + ext = dapr._get_http_extension("POST", qs) self.assertEqual(common_v1.HTTPExtension.Verb.POST, ext.verb) self.assertEqual("query1=string1&query2=string2&query1=string+3", ext.querystring) def test_invoke_method_bytes_data(self): - dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") resp = dapr.invoke_method( - app_id='targetId', - method_name='bytes', - data=b'haha', + app_id="targetId", + method_name="bytes", + data=b"haha", content_type="text/plain", metadata=( - ('key1', 'value1'), - ('key2', 'value2'), + ("key1", "value1"), + ("key2", "value2"), ), - http_verb='PUT', + http_verb="PUT", ) - self.assertEqual(b'haha', resp.data) + self.assertEqual(b"haha", resp.data) self.assertEqual("text/plain", resp.content_type) self.assertEqual(3, len(resp.headers)) - self.assertEqual(['value1'], resp.headers['hkey1']) + self.assertEqual(["value1"], resp.headers["hkey1"]) def test_invoke_method_no_data(self): - dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") resp = dapr.invoke_method( - app_id='targetId', - method_name='bytes', + app_id="targetId", + method_name="bytes", content_type="text/plain", metadata=( - ('key1', 'value1'), - ('key2', 'value2'), + ("key1", "value1"), + ("key2", "value2"), ), - http_verb='PUT', + http_verb="PUT", ) - self.assertEqual(b'', resp.data) + self.assertEqual(b"", resp.data) self.assertEqual("text/plain", resp.content_type) self.assertEqual(3, len(resp.headers)) - self.assertEqual(['value1'], resp.headers['hkey1']) + self.assertEqual(["value1"], resp.headers["hkey1"]) def test_invoke_method_async(self): - dapr = DaprClient(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprClient(f"{self.scheme}localhost:{self.server_port}") dapr.invocation_client = None # force to use grpc client with self.assertRaises(NotImplementedError): loop = asyncio.new_event_loop() loop.run_until_complete( dapr.invoke_method_async( - app_id='targetId', - method_name='bytes', - data=b'haha', + app_id="targetId", + method_name="bytes", + data=b"haha", content_type="text/plain", metadata=( - ('key1', 'value1'), - ('key2', 'value2'), + ("key1", "value1"), + ("key2", "value2"), ), - http_verb='PUT', + http_verb="PUT", ) ) def test_invoke_method_proto_data(self): - dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') - req = common_v1.StateItem(key='test') + dapr = 
DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") + req = common_v1.StateItem(key="test") resp = dapr.invoke_method( - app_id='targetId', - method_name='proto', + app_id="targetId", + method_name="proto", data=req, metadata=( - ('key1', 'value1'), - ('key2', 'value2'), + ("key1", "value1"), + ("key2", "value2"), ), ) self.assertEqual(3, len(resp.headers)) - self.assertEqual(['value1'], resp.headers['hkey1']) + self.assertEqual(["value1"], resp.headers["hkey1"]) self.assertTrue(resp.is_proto()) # unpack to new protobuf object new_resp = common_v1.StateItem() resp.unpack(new_resp) - self.assertEqual('test', new_resp.key) + self.assertEqual("test", new_resp.key) def test_invoke_binding_bytes_data(self): - dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") resp = dapr.invoke_binding( - binding_name='binding', - operation='create', - data=b'haha', + binding_name="binding", + operation="create", + data=b"haha", binding_metadata={ - 'key1': 'value1', - 'key2': 'value2', + "key1": "value1", + "key2": "value2", }, ) - self.assertEqual(b'haha', resp.data) - self.assertEqual({'key1': 'value1', 'key2': 'value2'}, resp.binding_metadata) + self.assertEqual(b"haha", resp.data) + self.assertEqual({"key1": "value1", "key2": "value2"}, resp.binding_metadata) self.assertEqual(2, len(resp.headers)) - self.assertEqual(['value1'], resp.headers['hkey1']) + self.assertEqual(["value1"], resp.headers["hkey1"]) def test_invoke_binding_no_metadata(self): - dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") resp = dapr.invoke_binding( - binding_name='binding', - operation='create', - data=b'haha', + binding_name="binding", + operation="create", + data=b"haha", ) - self.assertEqual(b'haha', resp.data) + self.assertEqual(b"haha", resp.data) self.assertEqual({}, resp.binding_metadata) self.assertEqual(0, len(resp.headers)) def test_invoke_binding_no_data(self): - dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") resp = dapr.invoke_binding( - binding_name='binding', - operation='create', + binding_name="binding", + operation="create", ) - self.assertEqual(b'', resp.data) + self.assertEqual(b"", resp.data) self.assertEqual({}, resp.binding_metadata) self.assertEqual(0, len(resp.headers)) def test_invoke_binding_no_create(self): - dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") resp = dapr.invoke_binding( - binding_name='binding', - operation='delete', - data=b'haha', + binding_name="binding", + operation="delete", + data=b"haha", ) - self.assertEqual(b'INVALID', resp.data) + self.assertEqual(b"INVALID", resp.data) self.assertEqual({}, resp.binding_metadata) self.assertEqual(0, len(resp.headers)) def test_publish_event(self): - dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') - resp = dapr.publish_event( - pubsub_name='pubsub', - topic_name='example', - data=b'haha' - ) + dapr = DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") + resp = dapr.publish_event(pubsub_name="pubsub", topic_name="example", data=b"haha") self.assertEqual(2, len(resp.headers)) - self.assertEqual(['haha'], resp.headers['hdata']) + self.assertEqual(["haha"], resp.headers["hdata"]) def test_publish_event_with_content_type(self): - dapr = 
DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") resp = dapr.publish_event( - pubsub_name='pubsub', - topic_name='example', + pubsub_name="pubsub", + topic_name="example", data=b'{"foo": "bar"}', - data_content_type='application/json' + data_content_type="application/json", ) self.assertEqual(3, len(resp.headers)) - self.assertEqual(['{"foo": "bar"}'], resp.headers['hdata']) - self.assertEqual(['application/json'], resp.headers['data_content_type']) + self.assertEqual(['{"foo": "bar"}'], resp.headers["hdata"]) + self.assertEqual(["application/json"], resp.headers["data_content_type"]) def test_publish_event_with_metadata(self): - dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") resp = dapr.publish_event( - pubsub_name='pubsub', - topic_name='example', + pubsub_name="pubsub", + topic_name="example", data=b'{"foo": "bar"}', - publish_metadata={'ttlInSeconds': '100', 'rawPayload': 'false'} + publish_metadata={"ttlInSeconds": "100", "rawPayload": "false"}, ) print(resp.headers) - self.assertEqual(['{"foo": "bar"}'], resp.headers['hdata']) - self.assertEqual(['false'], resp.headers['metadata_raw_payload']) - self.assertEqual(['100'], resp.headers['metadata_ttl_in_seconds']) + self.assertEqual(['{"foo": "bar"}'], resp.headers["hdata"]) + self.assertEqual(["false"], resp.headers["metadata_raw_payload"]) + self.assertEqual(["100"], resp.headers["metadata_ttl_in_seconds"]) def test_publish_error(self): - dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") with self.assertRaisesRegex(ValueError, "invalid type for data "): dapr.publish_event( - pubsub_name='pubsub', - topic_name='example', + pubsub_name="pubsub", + topic_name="example", data=111, ) - @patch.object(settings, 'DAPR_API_TOKEN', 'test-token') + @patch.object(settings, "DAPR_API_TOKEN", "test-token") def test_dapr_api_token_insertion(self): - dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") resp = dapr.invoke_method( - app_id='targetId', - method_name='bytes', - data=b'haha', + app_id="targetId", + method_name="bytes", + data=b"haha", content_type="text/plain", metadata=( - ('key1', 'value1'), - ('key2', 'value2'), + ("key1", "value1"), + ("key2", "value2"), ), ) - self.assertEqual(b'haha', resp.data) + self.assertEqual(b"haha", resp.data) self.assertEqual("text/plain", resp.content_type) self.assertEqual(4, len(resp.headers)) - self.assertEqual(['value1'], resp.headers['hkey1']) - self.assertEqual(['test-token'], resp.headers['hdapr-api-token']) + self.assertEqual(["value1"], resp.headers["hkey1"]) + self.assertEqual(["test-token"], resp.headers["hdapr-api-token"]) def test_get_save_delete_state(self): - dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") key = "key_1" value = "value_1" options = StateOptions( @@ -280,9 +276,9 @@ def test_get_save_delete_state(self): store_name="statestore", key=key, value=value, - etag='fake_etag', + etag="fake_etag", options=options, - state_metadata={"capitalize": "1"} + state_metadata={"capitalize": "1"}, ) resp = dapr.get_state(store_name="statestore", key=key) @@ -294,43 +290,32 @@ def test_get_save_delete_state(self): self.assertEqual(resp.etag, "fake_etag") resp = 
dapr.get_state(store_name="statestore", key="NotValidKey") - self.assertEqual(resp.data, b'') - self.assertEqual(resp.etag, '') + self.assertEqual(resp.data, b"") + self.assertEqual(resp.etag, "") - dapr.delete_state( - store_name="statestore", - key=key - ) + dapr.delete_state(store_name="statestore", key=key) resp = dapr.get_state(store_name="statestore", key=key) - self.assertEqual(resp.data, b'') - self.assertEqual(resp.etag, '') + self.assertEqual(resp.data, b"") + self.assertEqual(resp.etag, "") with self.assertRaises(Exception) as context: - dapr.delete_state( - store_name="statestore", - key=key, - state_metadata={"must_delete": "1"}) + dapr.delete_state(store_name="statestore", key=key, state_metadata={"must_delete": "1"}) print(context.exception) - self.assertTrue('delete failed' in str(context.exception)) + self.assertTrue("delete failed" in str(context.exception)) def test_get_save_state_etag_none(self): - dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") - value = 'test' - no_etag_key = 'no_etag' - empty_etag_key = 'empty_etag' + value = "test" + no_etag_key = "no_etag" + empty_etag_key = "empty_etag" dapr.save_state( store_name="statestore", key=no_etag_key, value=value, ) - dapr.save_state( - store_name="statestore", - key=empty_etag_key, - value=value, - etag="" - ) + dapr.save_state(store_name="statestore", key=empty_etag_key, value=value, etag="") resp = dapr.get_state(store_name="statestore", key=no_etag_key) self.assertEqual(resp.data, to_bytes(value)) @@ -341,7 +326,7 @@ def test_get_save_state_etag_none(self): self.assertEqual(resp.etag, "") def test_transaction_then_get_states(self): - dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") key = str(uuid.uuid4()) value = str(uuid.uuid4()) @@ -354,7 +339,7 @@ def test_transaction_then_get_states(self): TransactionalStateOperation(key=key, data=value, etag="foo"), TransactionalStateOperation(key=another_key, data=another_value), ], - transactional_metadata={"metakey": "metavalue"} + transactional_metadata={"metakey": "metavalue"}, ) resp = dapr.get_bulk_state(store_name="statestore", keys=[key, another_key]) @@ -366,16 +351,15 @@ def test_transaction_then_get_states(self): self.assertEqual(resp.items[1].etag, "ETAG_WAS_NONE") resp = dapr.get_bulk_state( - store_name="statestore", - keys=[key, another_key], - states_metadata={"upper": "1"}) + store_name="statestore", keys=[key, another_key], states_metadata={"upper": "1"} + ) self.assertEqual(resp.items[0].key, key) self.assertEqual(resp.items[0].data, to_bytes(value.upper())) self.assertEqual(resp.items[1].key, another_key) self.assertEqual(resp.items[1].data, to_bytes(another_value.upper())) def test_save_then_get_states(self): - dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") key = str(uuid.uuid4()) value = str(uuid.uuid4()) @@ -388,7 +372,7 @@ def test_save_then_get_states(self): StateItem(key=key, value=value, metadata={"capitalize": "1"}), StateItem(key=another_key, value=another_value, etag="1"), ], - metadata=(("metakey", "metavalue"),) + metadata=(("metakey", "metavalue"),), ) resp = dapr.get_bulk_state(store_name="statestore", keys=[key, another_key]) @@ -400,9 +384,8 @@ def test_save_then_get_states(self): self.assertEqual(resp.items[1].etag, "1") resp = dapr.get_bulk_state( - store_name="statestore", - 
keys=[key, another_key], - states_metadata={"upper": "1"}) + store_name="statestore", keys=[key, another_key], states_metadata={"upper": "1"} + ) self.assertEqual(resp.items[0].key, key) self.assertEqual(resp.items[0].etag, "ETAG_WAS_NONE") self.assertEqual(resp.items[0].data, to_bytes(value.upper())) @@ -411,57 +394,57 @@ def test_save_then_get_states(self): self.assertEqual(resp.items[1].data, to_bytes(another_value.upper())) def test_get_secret(self): - dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') - key1 = 'key_1' + dapr = DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") + key1 = "key_1" resp = dapr.get_secret( - store_name='store_1', + store_name="store_1", key=key1, metadata=( - ('key1', 'value1'), - ('key2', 'value2'), + ("key1", "value1"), + ("key2", "value2"), ), ) self.assertEqual(1, len(resp.headers)) - self.assertEqual([key1], resp.headers['keyh']) + self.assertEqual([key1], resp.headers["keyh"]) self.assertEqual({key1: "val"}, resp._secret) def test_get_secret_metadata_absent(self): - dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') - key1 = 'key_1' + dapr = DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") + key1 = "key_1" resp = dapr.get_secret( - store_name='store_1', + store_name="store_1", key=key1, ) self.assertEqual(1, len(resp.headers)) - self.assertEqual([key1], resp.headers['keyh']) + self.assertEqual([key1], resp.headers["keyh"]) self.assertEqual({key1: "val"}, resp._secret) def test_get_bulk_secret(self): - dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") resp = dapr.get_bulk_secret( - store_name='store_1', + store_name="store_1", metadata=( - ('key1', 'value1'), - ('key2', 'value2'), + ("key1", "value1"), + ("key2", "value2"), ), ) self.assertEqual(1, len(resp.headers)) - self.assertEqual(["bulk"], resp.headers['keyh']) + self.assertEqual(["bulk"], resp.headers["keyh"]) self.assertEqual({"keya": {"keyb": "val"}}, resp._secrets) def test_get_bulk_secret_metadata_absent(self): - dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') - resp = dapr.get_bulk_secret(store_name='store_1') + dapr = DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") + resp = dapr.get_bulk_secret(store_name="store_1") self.assertEqual(1, len(resp.headers)) - self.assertEqual(["bulk"], resp.headers['keyh']) + self.assertEqual(["bulk"], resp.headers["keyh"]) self.assertEqual({"keya": {"keyb": "val"}}, resp._secrets) def test_get_configuration(self): - dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") keys = ["k", "k1"] value = "value" version = "1.5.0" @@ -476,7 +459,8 @@ def test_get_configuration(self): self.assertEqual(item.metadata, metadata) resp = dapr.get_configuration( - store_name="configurationstore", keys=keys, config_metadata=metadata) + store_name="configurationstore", keys=keys, config_metadata=metadata + ) self.assertEqual(len(resp.items), len(keys)) self.assertIn(keys[0], resp.items) item = resp.items[keys[0]] @@ -485,31 +469,33 @@ def test_get_configuration(self): self.assertEqual(item.metadata, metadata) def test_subscribe_configuration(self): - dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") def mock_watch(self, stub, store_name, keys, handler, config_metadata): - handler("id", ConfigurationResponse(items={ - "k": 
ConfigurationItem( - value="test", - version="1.7.0") - })) + handler( + "id", + ConfigurationResponse( + items={"k": ConfigurationItem(value="test", version="1.7.0")} + ), + ) return "id" def handler(id: str, resp: ConfigurationResponse): self.assertEqual(resp.items["k"].value, "test") self.assertEqual(resp.items["k"].version, "1.7.0") - with patch.object(ConfigurationWatcher, 'watch_configuration', mock_watch): - dapr.subscribe_configuration(store_name="configurationstore", - keys=["k"], handler=handler) + with patch.object(ConfigurationWatcher, "watch_configuration", mock_watch): + dapr.subscribe_configuration( + store_name="configurationstore", keys=["k"], handler=handler + ) def test_unsubscribe_configuration(self): - dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") res = dapr.unsubscribe_configuration(store_name="configurationstore", id="k") self.assertTrue(res) def test_query_state(self): - dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") resp = dapr.query_state( store_name="statestore", @@ -526,29 +512,29 @@ def test_query_state(self): self.assertEqual(len(resp.results), 3) def test_shutdown(self): - dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") dapr.shutdown() self.assertTrue(self._fake_dapr_server.shutdown_received) def test_wait_ok(self): - dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") dapr.wait(0.1) def test_wait_timeout(self): # First, pick an unused port port = 0 with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(('', 0)) + s.bind(("", 0)) port = s.getsockname()[1] - dapr = DaprGrpcClient(f'localhost:{port}') + dapr = DaprGrpcClient(f"localhost:{port}") with self.assertRaises(Exception) as context: dapr.wait(0.1) - self.assertTrue('Connection refused' in str(context.exception)) + self.assertTrue("Connection refused" in str(context.exception)) def test_lock_acquire_success(self): - dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") # Lock parameters - store_name = 'lockstore' + store_name = "lockstore" resource_id = str(uuid.uuid4()) lock_owner = str(uuid.uuid4()) expiry_in_seconds = 60 @@ -559,9 +545,9 @@ def test_lock_acquire_success(self): self.assertEqual(UnlockResponseStatus.success, unlock_response.status) def test_lock_release_twice_fails(self): - dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") # Lock parameters - store_name = 'lockstore' + store_name = "lockstore" resource_id = str(uuid.uuid4()) lock_owner = str(uuid.uuid4()) expiry_in_seconds = 60 @@ -575,9 +561,9 @@ def test_lock_release_twice_fails(self): self.assertEqual(UnlockResponseStatus.lock_does_not_exist, unlock_response.status) def test_lock_conflict(self): - dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") # Lock parameters - store_name = 'lockstore' + store_name = "lockstore" resource_id = str(uuid.uuid4()) first_client_id = str(uuid.uuid4()) second_client_id = str(uuid.uuid4()) @@ -597,17 +583,16 @@ def test_lock_conflict(self): 
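# Taken together, the lock tests in this file establish that a second owner's
# try_lock on a held resource fails, that locks are not reentrant even for the
# same owner, and that unlock reports a status instead of raising. A compact
# sketch (sync client `dapr` as built in setUp; the ids are illustrative):
resp = dapr.unlock("lockstore", resource_id, lock_owner)
# resp.status is UnlockResponseStatus.success after a valid release, or
# UnlockResponseStatus.lock_does_not_exist if the lock was never (or is no
# longer) held, exactly as asserted above and below.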
self.assertEqual(UnlockResponseStatus.success, unlock_response.status) def test_lock_not_previously_acquired(self): - dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") unlock_response = dapr.unlock( - store_name='lockstore', - resource_id=str(uuid.uuid4()), - lock_owner=str(uuid.uuid4())) + store_name="lockstore", resource_id=str(uuid.uuid4()), lock_owner=str(uuid.uuid4()) + ) self.assertEqual(UnlockResponseStatus.lock_does_not_exist, unlock_response.status) def test_lock_release_twice_fails_with_context_manager(self): - dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") # Lock parameters - store_name = 'lockstore' + store_name = "lockstore" resource_id = str(uuid.uuid4()) first_client_id = str(uuid.uuid4()) second_client_id = str(uuid.uuid4()) @@ -624,9 +609,9 @@ def test_lock_release_twice_fails_with_context_manager(self): self.assertEqual(UnlockResponseStatus.lock_does_not_exist, unlock_response.status) def test_lock_are_not_reentrant(self): - dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") # Lock parameters - store_name = 'lockstore' + store_name = "lockstore" resource_id = str(uuid.uuid4()) client_id = str(uuid.uuid4()) expiry_in_s = 60 @@ -638,14 +623,14 @@ def test_lock_are_not_reentrant(self): self.assertFalse(second_attempt.success) def test_lock_input_validation(self): - dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") # Sane parameters - store_name = 'lockstore' + store_name = "lockstore" resource_id = str(uuid.uuid4()) client_id = str(uuid.uuid4()) expiry_in_s = 60 # Invalid inputs for string arguments - for invalid_input in [None, '', ' ']: + for invalid_input in [None, "", " "]: # store_name with self.assertRaises(ValueError): with dapr.try_lock(invalid_input, resource_id, client_id, expiry_in_s) as res: @@ -665,13 +650,13 @@ def test_lock_input_validation(self): self.assertTrue(res.success) def test_unlock_input_validation(self): - dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") # Sane parameters - store_name = 'lockstore' + store_name = "lockstore" resource_id = str(uuid.uuid4()) client_id = str(uuid.uuid4()) # Invalid inputs for string arguments - for invalid_input in [None, '', ' ']: + for invalid_input in [None, "", " "]: # store_name with self.assertRaises(ValueError): dapr.unlock(invalid_input, resource_id, client_id) @@ -687,26 +672,29 @@ def test_unlock_input_validation(self): # def test_workflow(self): - dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') + dapr = DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") # Sane parameters workflow_name = "test_workflow" event_name = "eventName" instance_id = str(uuid.uuid4()) workflow_component = "dapr" input = "paperclips" - event_data = 'cars' + event_data = "cars" # Start the workflow - start_response = dapr.start_workflow(instance_id=instance_id, - workflow_name=workflow_name, - workflow_component=workflow_component, - input=input, - workflow_options=None) + start_response = dapr.start_workflow( + instance_id=instance_id, + workflow_name=workflow_name, + workflow_component=workflow_component, + input=input, + workflow_options=None, + ) 
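# The workflow test walks the full lifecycle; condensed into a sketch, the
# start/query steps above look like this (same keyword arguments as in the
# test; `dapr` is the sync client and uuid is imported at module top):
start = dapr.start_workflow(
    instance_id=str(uuid.uuid4()),
    workflow_name="test_workflow",
    workflow_component="dapr",
    input="paperclips",
    workflow_options=None,
)
state = dapr.get_workflow(instance_id=start.instance_id, workflow_component="dapr")
assert state.runtime_status == WorkflowRuntimeStatus.RUNNING.value  # fresh instances report RUNNING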
self.assertEqual(instance_id, start_response.instance_id) # Get info on the workflow to check that it is running - get_response = dapr.get_workflow(instance_id=instance_id, - workflow_component=workflow_component) + get_response = dapr.get_workflow( + instance_id=instance_id, workflow_component=workflow_component + ) self.assertEqual(WorkflowRuntimeStatus.RUNNING.value, get_response.runtime_status) # Pause the workflow @@ -749,12 +737,12 @@ def test_workflow(self): # def test_get_metadata(self): - with DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') as dapr: + with DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") as dapr: response = dapr.get_metadata() self.assertIsNotNone(response) - self.assertEqual(response.application_id, 'myapp') + self.assertEqual(response.application_id, "myapp") actors = response.active_actors_count self.assertIsNotNone(actors) @@ -773,39 +761,35 @@ def test_get_metadata(self): self.assertTrue(c.type) self.assertIsNotNone(c.version) self.assertIsNotNone(c.capabilities) - self.assertTrue("ETAG" in components['statestore'].capabilities) + self.assertTrue("ETAG" in components["statestore"].capabilities) self.assertIsNotNone(response.extended_metadata) def test_set_metadata(self): metadata_key = "test_set_metadata_attempt" - with DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') as dapr: + with DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") as dapr: for metadata_value in [str(i) for i in range(10)]: - dapr.set_metadata(attributeName=metadata_key, - attributeValue=metadata_value) + dapr.set_metadata(attributeName=metadata_key, attributeValue=metadata_value) response = dapr.get_metadata() self.assertIsNotNone(response) self.assertIsNotNone(response.extended_metadata) - self.assertEqual(response.extended_metadata[metadata_key], - metadata_value) + self.assertEqual(response.extended_metadata[metadata_key], metadata_value) # Empty string and blank strings should be accepted just fine # by this API - for metadata_value in ['', ' ']: - dapr.set_metadata(attributeName=metadata_key, - attributeValue=metadata_value) + for metadata_value in ["", " "]: + dapr.set_metadata(attributeName=metadata_key, attributeValue=metadata_value) response = dapr.get_metadata() self.assertIsNotNone(response) self.assertIsNotNone(response.extended_metadata) - self.assertEqual(response.extended_metadata[metadata_key], - metadata_value) + self.assertEqual(response.extended_metadata[metadata_key], metadata_value) def test_set_metadata_input_validation(self): - dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') - valid_attr_name = 'attribute name' - valid_attr_value = 'attribute value' + dapr = DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") + valid_attr_name = "attribute name" + valid_attr_value = "attribute value" # Invalid inputs for string arguments - with DaprGrpcClient(f'{self.scheme}localhost:{self.server_port}') as dapr: - for invalid_attr_name in [None, '', ' ']: + with DaprGrpcClient(f"{self.scheme}localhost:{self.server_port}") as dapr: + for invalid_attr_name in [None, "", " "]: with self.assertRaises(ValueError): dapr.set_metadata(invalid_attr_name, valid_attr_value) # We are less strict with attribute values - we just cannot accept None @@ -814,5 +798,5 @@ def test_set_metadata_input_validation(self): dapr.set_metadata(valid_attr_name, invalid_attr_value) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/clients/test_dapr_grpc_request.py 
b/tests/clients/test_dapr_grpc_request.py index 7324d90c2..ce689cc19 100644 --- a/tests/clients/test_dapr_grpc_request.py +++ b/tests/clients/test_dapr_grpc_request.py @@ -22,11 +22,11 @@ class InvokeMethodRequestTests(unittest.TestCase): def test_bytes_data(self): # act - req = InvokeMethodRequest(data=b'hello dapr') + req = InvokeMethodRequest(data=b"hello dapr") # arrange - self.assertEqual(b'hello dapr', req.data) - self.assertEqual('application/json; charset=utf-8', req.content_type) + self.assertEqual(b"hello dapr", req.data) + self.assertEqual("application/json; charset=utf-8", req.content_type) def test_proto_message_data(self): # arrange @@ -38,42 +38,42 @@ def test_proto_message_data(self): # assert self.assertIsNotNone(req.proto) self.assertEqual( - 'type.googleapis.com/dapr.proto.common.v1.InvokeRequest', - req.proto.type_url) + "type.googleapis.com/dapr.proto.common.v1.InvokeRequest", req.proto.type_url + ) self.assertIsNotNone(req.proto.value) self.assertIsNone(req.content_type) def test_invalid_data(self): with self.assertRaises(ValueError): data = InvokeMethodRequest(data=123) - self.assertIsNone(data, 'This should not be reached.') + self.assertIsNone(data, "This should not be reached.") class InvokeBindingRequestDataTests(unittest.TestCase): def test_bytes_data(self): # act - data = BindingRequest(data=b'hello dapr') + data = BindingRequest(data=b"hello dapr") # arrange - self.assertEqual(b'hello dapr', data.data) + self.assertEqual(b"hello dapr", data.data) self.assertEqual({}, data.metadata) def test_str_data(self): # act - data = BindingRequest(data='hello dapr') + data = BindingRequest(data="hello dapr") # arrange - self.assertEqual(b'hello dapr', data.data) + self.assertEqual(b"hello dapr", data.data) self.assertEqual({}, data.metadata) def test_non_empty_metadata(self): # act - data = BindingRequest(data='hello dapr', binding_metadata={'ttlInSeconds': '1000'}) + data = BindingRequest(data="hello dapr", binding_metadata={"ttlInSeconds": "1000"}) # arrange - self.assertEqual(b'hello dapr', data.data) - self.assertEqual({'ttlInSeconds': '1000'}, data.binding_metadata) + self.assertEqual(b"hello dapr", data.data) + self.assertEqual({"ttlInSeconds": "1000"}, data.binding_metadata) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/clients/test_dapr_grpc_response.py b/tests/clients/test_dapr_grpc_response.py index 972bb0ad8..01bd64dde 100644 --- a/tests/clients/test_dapr_grpc_response.py +++ b/tests/clients/test_dapr_grpc_response.py @@ -18,8 +18,11 @@ from google.protobuf.any_pb2 import Any as GrpcAny from dapr.clients.grpc._response import ( - DaprResponse, InvokeMethodResponse, BindingResponse, StateResponse, - BulkStateItem + DaprResponse, + InvokeMethodResponse, + BindingResponse, + StateResponse, + BulkStateItem, ) from dapr.proto import common_v1 @@ -27,9 +30,9 @@ class DaprResponseTests(unittest.TestCase): test_headers = ( - ('key1', 'value1'), - ('key2', 'value2'), - ('key3', 'value3'), + ("key1", "value1"), + ("key2", "value2"), + ("key3", "value3"), ) def test_convert_metadata(self): @@ -46,13 +49,11 @@ class InvokeMethodResponseTests(unittest.TestCase): def test_non_protobuf_message(self): with self.assertRaises(ValueError): resp = InvokeMethodResponse(data=123) - self.assertIsNone(resp, 'This should not be reached.') + self.assertIsNone(resp, "This should not be reached.") def test_is_proto_for_non_protobuf(self): - test_data = GrpcAny(value=b'hello dapr') - resp = InvokeMethodResponse( - data=test_data, - 
content_type='application/json') + test_data = GrpcAny(value=b"hello dapr") + resp = InvokeMethodResponse(data=test_data, content_type="application/json") self.assertFalse(resp.is_proto()) def test_is_proto_for_protobuf(self): @@ -68,17 +69,15 @@ def test_proto(self): self.assertIsNotNone(resp.proto) def test_data(self): - test_data = GrpcAny(value=b'hello dapr') - resp = InvokeMethodResponse( - data=test_data, - content_type='application/json') - self.assertEqual(b'hello dapr', resp.data) - self.assertEqual('hello dapr', resp.text()) - self.assertEqual('application/json', resp.content_type) + test_data = GrpcAny(value=b"hello dapr") + resp = InvokeMethodResponse(data=test_data, content_type="application/json") + self.assertEqual(b"hello dapr", resp.data) + self.assertEqual("hello dapr", resp.text()) + self.assertEqual("application/json", resp.content_type) def test_json_data(self): - resp = InvokeMethodResponse(data=b'{ "status": "ok" }', content_type='application/json') - self.assertEqual({'status': 'ok'}, resp.json()) + resp = InvokeMethodResponse(data=b'{ "status": "ok" }', content_type="application/json") + self.assertEqual({"status": "ok"}, resp.json()) def test_unpack(self): # arrange @@ -95,38 +94,38 @@ def test_unpack(self): class InvokeBindingResponseTests(unittest.TestCase): def test_bytes_message(self): - resp = BindingResponse(data=b'data', binding_metadata={}) + resp = BindingResponse(data=b"data", binding_metadata={}) self.assertEqual({}, resp.binding_metadata) - self.assertEqual(b'data', resp.data) - self.assertEqual('data', resp.text()) + self.assertEqual(b"data", resp.data) + self.assertEqual("data", resp.text()) def test_json_data(self): resp = BindingResponse(data=b'{"status": "ok"}', binding_metadata={}) - self.assertEqual({'status': 'ok'}, resp.json()) + self.assertEqual({"status": "ok"}, resp.json()) def test_metadata(self): - resp = BindingResponse(data=b'data', binding_metadata={'status': 'ok'}) - self.assertEqual({'status': 'ok'}, resp.binding_metadata) - self.assertEqual(b'data', resp.data) - self.assertEqual('data', resp.text()) + resp = BindingResponse(data=b"data", binding_metadata={"status": "ok"}) + self.assertEqual({"status": "ok"}, resp.binding_metadata) + self.assertEqual(b"data", resp.data) + self.assertEqual("data", resp.text()) class StateResponseTests(unittest.TestCase): def test_data(self): - resp = StateResponse(data=b'hello dapr') - self.assertEqual('hello dapr', resp.text()) - self.assertEqual(b'hello dapr', resp.data) + resp = StateResponse(data=b"hello dapr") + self.assertEqual("hello dapr", resp.text()) + self.assertEqual(b"hello dapr", resp.data) def test_json_data(self): resp = StateResponse(data=b'{"status": "ok"}') - self.assertEqual({'status': 'ok'}, resp.json()) + self.assertEqual({"status": "ok"}, resp.json()) class BulkStateItemTests(unittest.TestCase): def test_data(self): - item = BulkStateItem(key='item1', data=b'{ "status": "ok" }') - self.assertEqual({'status': 'ok'}, item.json()) + item = BulkStateItem(key="item1", data=b'{ "status": "ok" }') + self.assertEqual({"status": "ok"}, item.json()) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/clients/test_http_service_invocation_client.py b/tests/clients/test_http_service_invocation_client.py index 587c4d8d1..039b55c73 100644 --- a/tests/clients/test_http_service_invocation_client.py +++ b/tests/clients/test_http_service_invocation_client.py @@ -33,34 +33,40 @@ def setUp(self): self.server_port = self.server.get_port() self.server.start() 
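# The get_api_url tests just below pin down how the HTTP invocation client
# resolves its base URL. Sketch of the precedence they assert (settings names
# as imported in this module):
# 1. explicit endpoint argument:   DaprClient("http://localhost:5000")
#    -> "http://localhost:5000/{DAPR_API_VERSION}"
# 2. DAPR_HTTP_ENDPOINT setting:   "https://domain1.com:5000/{DAPR_API_VERSION}"
# 3. default host/port settings:
default_url = "http://{}:{}/{}".format(
    settings.DAPR_RUNTIME_HOST, settings.DAPR_HTTP_PORT, settings.DAPR_API_VERSION
)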
settings.DAPR_HTTP_PORT = self.server_port - settings.DAPR_API_METHOD_INVOCATION_PROTOCOL = 'http' + settings.DAPR_API_METHOD_INVOCATION_PROTOCOL = "http" self.client = DaprClient() - self.app_id = 'fakeapp' - self.method_name = 'fakemethod' - self.invoke_url = f'/v1.0/invoke/{self.app_id}/method/{self.method_name}' + self.app_id = "fakeapp" + self.method_name = "fakemethod" + self.invoke_url = f"/v1.0/invoke/{self.app_id}/method/{self.method_name}" def tearDown(self): self.server.shutdown_server() settings.DAPR_API_TOKEN = None - settings.DAPR_API_METHOD_INVOCATION_PROTOCOL = 'http' + settings.DAPR_API_METHOD_INVOCATION_PROTOCOL = "http" def test_get_api_url_default(self): client = DaprClient() self.assertEqual( - 'http://{}:{}/{}'.format(settings.DAPR_RUNTIME_HOST, settings.DAPR_HTTP_PORT, - settings.DAPR_API_VERSION), - client.invocation_client._client.get_api_url()) + "http://{}:{}/{}".format( + settings.DAPR_RUNTIME_HOST, settings.DAPR_HTTP_PORT, settings.DAPR_API_VERSION + ), + client.invocation_client._client.get_api_url(), + ) def test_get_api_url_endpoint_as_argument(self): client = DaprClient("http://localhost:5000") - self.assertEqual('http://localhost:5000/{}'.format(settings.DAPR_API_VERSION), - client.invocation_client._client.get_api_url()) + self.assertEqual( + "http://localhost:5000/{}".format(settings.DAPR_API_VERSION), + client.invocation_client._client.get_api_url(), + ) @patch.object(settings, "DAPR_HTTP_ENDPOINT", "https://domain1.com:5000") def test_get_api_url_endpoint_as_env_variable(self): client = DaprClient() - self.assertEqual('https://domain1.com:5000/{}'.format(settings.DAPR_API_VERSION), - client.invocation_client._client.get_api_url()) + self.assertEqual( + "https://domain1.com:5000/{}".format(settings.DAPR_API_VERSION), + client.invocation_client._client.get_api_url(), + ) def test_basic_invoke(self): self.server.set_response(b"STRING_BODY") @@ -74,9 +80,11 @@ def test_coroutine_basic_invoke(self): self.server.set_response(b"STRING_BODY") import asyncio + loop = asyncio.new_event_loop() response = loop.run_until_complete( - self.client.invoke_method_async(self.app_id, self.method_name, "")) + self.client.invoke_method_async(self.app_id, self.method_name, "") + ) self.assertEqual(b"STRING_BODY", response.data) self.assertEqual(self.invoke_url, self.server.request_path()) @@ -84,7 +92,7 @@ def test_coroutine_basic_invoke(self): def test_invoke_PUT_with_body(self): self.server.set_response(b"STRING_BODY") - response = self.client.invoke_method(self.app_id, self.method_name, b"FOO", http_verb='PUT') + response = self.client.invoke_method(self.app_id, self.method_name, b"FOO", http_verb="PUT") self.assertEqual(b"STRING_BODY", response.data) self.assertEqual(self.invoke_url, self.server.request_path()) @@ -93,7 +101,7 @@ def test_invoke_PUT_with_body(self): def test_invoke_PUT_with_bytes_body(self): self.server.set_response(b"STRING_BODY") - response = self.client.invoke_method(self.app_id, self.method_name, b"FOO", http_verb='PUT') + response = self.client.invoke_method(self.app_id, self.method_name, b"FOO", http_verb="PUT") self.assertEqual(b"STRING_BODY", response.data) self.assertEqual(self.invoke_url, self.server.request_path()) @@ -101,20 +109,22 @@ def test_invoke_PUT_with_bytes_body(self): def test_invoke_GET_with_query_params(self): self.server.set_response(b"STRING_BODY") - query_params = (('key1', 'value1'), ('key2', 'value2')) + query_params = (("key1", "value1"), ("key2", "value2")) - response = self.client.invoke_method(self.app_id, 
self.method_name, '', - http_querystring=query_params) + response = self.client.invoke_method( + self.app_id, self.method_name, "", http_querystring=query_params + ) self.assertEqual(b"STRING_BODY", response.data) self.assertEqual(f"{self.invoke_url}?key1=value1&key2=value2", self.server.request_path()) def test_invoke_GET_with_duplicate_query_params(self): self.server.set_response(b"STRING_BODY") - query_params = (('key1', 'value1'), ('key1', 'value2')) + query_params = (("key1", "value1"), ("key1", "value2")) - response = self.client.invoke_method(self.app_id, self.method_name, '', - http_querystring=query_params) + response = self.client.invoke_method( + self.app_id, self.method_name, "", http_querystring=query_params + ) self.assertEqual(b"STRING_BODY", response.data) self.assertEqual(f"{self.invoke_url}?key1=value1&key1=value2", self.server.request_path()) @@ -122,69 +132,87 @@ def test_invoke_GET_with_duplicate_query_params(self): def test_invoke_PUT_with_content_type(self): self.server.set_response(b"STRING_BODY") - sample_object = {'foo': ['val1', 'val2']} + sample_object = {"foo": ["val1", "val2"]} - response = self.client.invoke_method(self.app_id, self.method_name, - json.dumps(sample_object), - content_type='application/json') + response = self.client.invoke_method( + self.app_id, + self.method_name, + json.dumps(sample_object), + content_type="application/json", + ) self.assertEqual(b"STRING_BODY", response.data) self.assertEqual(b'{"foo": ["val1", "val2"]}', self.server.get_request_body()) def test_invoke_method_proto_data(self): self.server.set_response(b"\x0a\x04resp") - self.server.reply_header('Content-Type', 'application/x-protobuf') + self.server.reply_header("Content-Type", "application/x-protobuf") - req = common_v1.StateItem(key='test') - resp = self.client.invoke_method(self.app_id, self.method_name, http_verb='PUT', data=req) + req = common_v1.StateItem(key="test") + resp = self.client.invoke_method(self.app_id, self.method_name, http_verb="PUT", data=req) self.assertEqual(b"\x0a\x04test", self.server.get_request_body()) # unpack to new protobuf object new_resp = common_v1.StateItem() - self.assertEqual(resp.headers['Content-Type'], ['application/x-protobuf']) + self.assertEqual(resp.headers["Content-Type"], ["application/x-protobuf"]) resp.unpack(new_resp) - self.assertEqual('resp', new_resp.key) + self.assertEqual("resp", new_resp.key) def test_invoke_method_metadata(self): self.server.set_response(b"FOO") - req = common_v1.StateItem(key='test') - resp = self.client.invoke_method(self.app_id, self.method_name, http_verb='PUT', data=req, - metadata=(('header1', 'value1'), ('header2', 'value2'))) + req = common_v1.StateItem(key="test") + resp = self.client.invoke_method( + self.app_id, + self.method_name, + http_verb="PUT", + data=req, + metadata=(("header1", "value1"), ("header2", "value2")), + ) request_headers = self.server.get_request_headers() - self.assertEqual(b'FOO', resp.data) + self.assertEqual(b"FOO", resp.data) - self.assertEqual('value1', request_headers['header1']) - self.assertEqual('value2', request_headers['header2']) + self.assertEqual("value1", request_headers["header1"]) + self.assertEqual("value2", request_headers["header2"]) def test_invoke_method_protobuf_response_with_suffix(self): self.server.set_response(b"\x0a\x04resp") - self.server.reply_header('Content-Type', 'application/x-protobuf; gzip') - - req = common_v1.StateItem(key='test') - resp = self.client.invoke_method(self.app_id, self.method_name, http_verb='PUT', data=req, - 
metadata=(('header1', 'value1'), ('header2', 'value2')))
+        self.server.reply_header("Content-Type", "application/x-protobuf; gzip")
+
+        req = common_v1.StateItem(key="test")
+        resp = self.client.invoke_method(
+            self.app_id,
+            self.method_name,
+            http_verb="PUT",
+            data=req,
+            metadata=(("header1", "value1"), ("header2", "value2")),
+        )
 
         self.assertEqual(b"\x0a\x04test", self.server.get_request_body())
         # unpack to new protobuf object
         new_resp = common_v1.StateItem()
         resp.unpack(new_resp)
-        self.assertEqual('resp', new_resp.key)
+        self.assertEqual("resp", new_resp.key)
 
     def test_invoke_method_protobuf_response_case_insensitive(self):
         self.server.set_response(b"\x0a\x04resp")
-        self.server.reply_header('Content-Type', 'apPlicaTion/x-protobuf; gzip')
+        self.server.reply_header("Content-Type", "apPlicaTion/x-protobuf; gzip")
 
-        req = common_v1.StateItem(key='test')
-        resp = self.client.invoke_method(self.app_id, self.method_name, http_verb='PUT', data=req,
-                                         metadata=(('header1', 'value1'), ('header2', 'value2')))
+        req = common_v1.StateItem(key="test")
+        resp = self.client.invoke_method(
+            self.app_id,
+            self.method_name,
+            http_verb="PUT",
+            data=req,
+            metadata=(("header1", "value1"), ("header2", "value2")),
+        )
 
         self.assertEqual(b"\x0a\x04test", self.server.get_request_body())
         # unpack to new protobuf object
         new_resp = common_v1.StateItem()
         resp.unpack(new_resp)
-        self.assertEqual('resp', new_resp.key)
+        self.assertEqual("resp", new_resp.key)
 
     def test_invoke_method_error_returned(self):
         error_response = b'{"errorCode":"ERR_DIRECT_INVOKE","message":"Something bad happened"}'
@@ -193,21 +221,31 @@ def test_invoke_method_error_returned(self):
         expected_msg = "('Something bad happened', 'ERR_DIRECT_INVOKE')"
 
         with self.assertRaises(DaprInternalError) as ctx:
-            self.client.invoke_method(self.app_id, self.method_name, http_verb='PUT', data='FOO', )
+            self.client.invoke_method(
+                self.app_id,
+                self.method_name,
+                http_verb="PUT",
+                data="FOO",
+            )
 
         self.assertEqual(expected_msg, str(ctx.exception))
 
     def test_invoke_method_non_dapr_error(self):
-        error_response = b'UNPARSABLE_ERROR'
+        error_response = b"UNPARSABLE_ERROR"
        self.server.set_response(error_response, 500)
 
         expected_msg = "Unknown Dapr Error. 
HTTP status code: 500" with self.assertRaises(DaprInternalError) as ctx: - self.client.invoke_method(self.app_id, self.method_name, http_verb='PUT', data='FOO', ) + self.client.invoke_method( + self.app_id, + self.method_name, + http_verb="PUT", + data="FOO", + ) self.assertEqual(expected_msg, str(ctx.exception)) def test_generic_client_unknown_protocol(self): - settings.DAPR_API_METHOD_INVOCATION_PROTOCOL = 'unknown' + settings.DAPR_API_METHOD_INVOCATION_PROTOCOL = "unknown" expected_msg = "Unknown value for DAPR_API_METHOD_INVOCATION_PROTOCOL: UNKNOWN" @@ -216,44 +254,54 @@ def test_generic_client_unknown_protocol(self): self.assertEqual(expected_msg, str(ctx.exception)) - settings.DAPR_API_METHOD_INVOCATION_PROTOCOL = 'grpc' + settings.DAPR_API_METHOD_INVOCATION_PROTOCOL = "grpc" client = DaprClient() self.assertIsNotNone(client) - settings.DAPR_API_METHOD_INVOCATION_PROTOCOL = 'http' + settings.DAPR_API_METHOD_INVOCATION_PROTOCOL = "http" client = DaprClient() self.assertIsNotNone(client) def test_invoke_method_with_api_token(self): self.server.set_response(b"FOO") - settings.DAPR_API_TOKEN = 'c29saSBkZW8gZ2xvcmlhCg==' + settings.DAPR_API_TOKEN = "c29saSBkZW8gZ2xvcmlhCg==" - req = common_v1.StateItem(key='test') - resp = self.client.invoke_method(self.app_id, self.method_name, http_verb='PUT', data=req, ) + req = common_v1.StateItem(key="test") + resp = self.client.invoke_method( + self.app_id, + self.method_name, + http_verb="PUT", + data=req, + ) request_headers = self.server.get_request_headers() - self.assertEqual('c29saSBkZW8gZ2xvcmlhCg==', request_headers['dapr-api-token']) - self.assertEqual(b'FOO', resp.data) + self.assertEqual("c29saSBkZW8gZ2xvcmlhCg==", request_headers["dapr-api-token"]) + self.assertEqual(b"FOO", resp.data) def test_invoke_method_with_tracer(self): tracer = Tracer(sampler=samplers.AlwaysOnSampler(), exporter=print_exporter.PrintExporter()) self.client = DaprClient( - headers_callback=lambda: tracer.propagator.to_headers(tracer.span_context)) + headers_callback=lambda: tracer.propagator.to_headers(tracer.span_context) + ) self.server.set_response(b"FOO") with tracer.span(name="test"): - req = common_v1.StateItem(key='test') - resp = self.client.invoke_method(self.app_id, self.method_name, http_verb='PUT', - data=req, ) + req = common_v1.StateItem(key="test") + resp = self.client.invoke_method( + self.app_id, + self.method_name, + http_verb="PUT", + data=req, + ) request_headers = self.server.get_request_headers() - self.assertIn('Traceparent', request_headers) - self.assertEqual(b'FOO', resp.data) + self.assertIn("Traceparent", request_headers) + self.assertEqual(b"FOO", resp.data) def test_timeout_exception_thrown_when_timeout_reached(self): new_client = DaprClient(http_timeout_seconds=1) diff --git a/tests/clients/test_secure_dapr_async_grpc_client.py b/tests/clients/test_secure_dapr_async_grpc_client.py index c9780e60c..f054a968c 100644 --- a/tests/clients/test_secure_dapr_async_grpc_client.py +++ b/tests/clients/test_secure_dapr_async_grpc_client.py @@ -29,7 +29,7 @@ # Used temporarily, so we can trust self-signed certificates in unit tests # until they get their own environment variable def replacement_get_credentials_func(a): - f = open(os.path.join(os.path.dirname(__file__), 'selfsigned.pem'), 'rb') + f = open(os.path.join(os.path.dirname(__file__), "selfsigned.pem"), "rb") creds = grpc.ssl_channel_credentials(f.read()) f.close() @@ -41,7 +41,7 @@ def replacement_get_credentials_func(a): class DaprSecureGrpcClientAsyncTests(DaprGrpcClientAsyncTests): 
server_port = 4443 - scheme = 'https://' + scheme = "https://" def setUp(self): self._fake_dapr_server = FakeDaprSidecar() @@ -77,5 +77,5 @@ async def test_dapr_api_token_insertion(self): pass -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/clients/test_secure_dapr_grpc_client.py b/tests/clients/test_secure_dapr_grpc_client.py index e9f227aca..2a9bdee6b 100644 --- a/tests/clients/test_secure_dapr_grpc_client.py +++ b/tests/clients/test_secure_dapr_grpc_client.py @@ -27,7 +27,7 @@ # Used temporarily, so we can trust self-signed certificates in unit tests # until they get their own environment variable def replacement_get_credentials_func(a): - f = open(os.path.join(os.path.dirname(__file__), 'selfsigned.pem'), 'rb') + f = open(os.path.join(os.path.dirname(__file__), "selfsigned.pem"), "rb") creds = grpc.ssl_channel_credentials(f.read()) f.close() @@ -39,7 +39,7 @@ def replacement_get_credentials_func(a): class DaprSecureGrpcClientTests(DaprGrpcClientTests): server_port = 4443 - scheme = 'https://' + scheme = "https://" def setUp(self): self._fake_dapr_server = FakeDaprSidecar() @@ -72,5 +72,5 @@ def test_init_with_argument_and_DAPR_GRPC_ENDPOINT_and_DAPR_RUNTIME_HOST(self): self.assertEqual("dns:domain2.com:5002", dapr._uri.endpoint) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/clients/test_secure_http_service_invocation_client.py b/tests/clients/test_secure_http_service_invocation_client.py index 704eb0b8e..5a5455bdb 100644 --- a/tests/clients/test_secure_http_service_invocation_client.py +++ b/tests/clients/test_secure_http_service_invocation_client.py @@ -44,16 +44,16 @@ def setUp(self): self.server_port = self.server.get_port() self.server.start() settings.DAPR_HTTP_PORT = self.server_port - settings.DAPR_API_METHOD_INVOCATION_PROTOCOL = 'http' + settings.DAPR_API_METHOD_INVOCATION_PROTOCOL = "http" self.client = DaprClient("https://localhost:{}".format(self.server_port)) - self.app_id = 'fakeapp' - self.method_name = 'fakemethod' - self.invoke_url = f'/v1.0/invoke/{self.app_id}/method/{self.method_name}' + self.app_id = "fakeapp" + self.method_name = "fakemethod" + self.invoke_url = f"/v1.0/invoke/{self.app_id}/method/{self.method_name}" def tearDown(self): self.server.shutdown_server() settings.DAPR_API_TOKEN = None - settings.DAPR_API_METHOD_INVOCATION_PROTOCOL = 'http' + settings.DAPR_API_METHOD_INVOCATION_PROTOCOL = "http" def test_global_timeout_setting_is_honored(self): previous_timeout = settings.DAPR_HTTP_TIMEOUT_SECONDS @@ -68,24 +68,30 @@ def test_global_timeout_setting_is_honored(self): def test_invoke_method_with_tracer(self): tracer = Tracer(sampler=samplers.AlwaysOnSampler(), exporter=print_exporter.PrintExporter()) - self.client = DaprClient("https://localhost:{}".format(self.server_port), - headers_callback=lambda: tracer.propagator.to_headers( - tracer.span_context)) + self.client = DaprClient( + "https://localhost:{}".format(self.server_port), + headers_callback=lambda: tracer.propagator.to_headers(tracer.span_context), + ) self.server.set_response(b"FOO") with tracer.span(name="test"): - req = common_v1.StateItem(key='test') - resp = self.client.invoke_method(self.app_id, self.method_name, http_verb='PUT', - data=req, ) + req = common_v1.StateItem(key="test") + resp = self.client.invoke_method( + self.app_id, + self.method_name, + http_verb="PUT", + data=req, + ) request_headers = self.server.get_request_headers() - self.assertIn('Traceparent', request_headers) - 
self.assertEqual(b'FOO', resp.data) + self.assertIn("Traceparent", request_headers) + self.assertEqual(b"FOO", resp.data) def test_timeout_exception_thrown_when_timeout_reached(self): - new_client = DaprClient("https://localhost:{}".format(self.server_port), - http_timeout_seconds=1) + new_client = DaprClient( + "https://localhost:{}".format(self.server_port), http_timeout_seconds=1 + ) self.server.set_server_delay(1.5) with self.assertRaises(TimeoutError): new_client.invoke_method(self.app_id, self.method_name, "") diff --git a/tests/conf/helpers_test.py b/tests/conf/helpers_test.py index 66c937ff7..456f40d75 100644 --- a/tests/conf/helpers_test.py +++ b/tests/conf/helpers_test.py @@ -4,133 +4,346 @@ class DaprClientHelpersTests(unittest.TestCase): - def test_parse_grpc_endpoint(self): testcases = [ # Port only - {"url": ":5000", "error": False, "secure": False, "scheme": "", "host": "localhost", - "port": 5000, "endpoint": "dns:localhost:5000"}, - {"url": ":5000?tls=false", "error": False, "secure": False, "scheme": "", - "host": "localhost", "port": 5000, "endpoint": "dns:localhost:5000"}, - {"url": ":5000?tls=true", "error": False, "secure": True, "scheme": "", - "host": "localhost", "port": 5000, "endpoint": "dns:localhost:5000"}, - + { + "url": ":5000", + "error": False, + "secure": False, + "scheme": "", + "host": "localhost", + "port": 5000, + "endpoint": "dns:localhost:5000", + }, + { + "url": ":5000?tls=false", + "error": False, + "secure": False, + "scheme": "", + "host": "localhost", + "port": 5000, + "endpoint": "dns:localhost:5000", + }, + { + "url": ":5000?tls=true", + "error": False, + "secure": True, + "scheme": "", + "host": "localhost", + "port": 5000, + "endpoint": "dns:localhost:5000", + }, # Host only - {"url": "myhost", "error": False, "secure": False, "scheme": "", "host": "myhost", - "port": 443, "endpoint": "dns:myhost:443"}, - {"url": "myhost?tls=false", "error": False, "secure": False, "scheme": "", - "host": "myhost", "port": 443, "endpoint": "dns:myhost:443"}, - {"url": "myhost?tls=true", "error": False, "secure": True, "scheme": "", - "host": "myhost", "port": 443, "endpoint": "dns:myhost:443"}, - + { + "url": "myhost", + "error": False, + "secure": False, + "scheme": "", + "host": "myhost", + "port": 443, + "endpoint": "dns:myhost:443", + }, + { + "url": "myhost?tls=false", + "error": False, + "secure": False, + "scheme": "", + "host": "myhost", + "port": 443, + "endpoint": "dns:myhost:443", + }, + { + "url": "myhost?tls=true", + "error": False, + "secure": True, + "scheme": "", + "host": "myhost", + "port": 443, + "endpoint": "dns:myhost:443", + }, # Host and port - {"url": "myhost:443", "error": False, "secure": False, "scheme": "", "host": "myhost", - "port": 443, "endpoint": "dns:myhost:443"}, - {"url": "myhost:443?tls=false", "error": False, "secure": False, "scheme": "", - "host": "myhost", "port": 443, "endpoint": "dns:myhost:443"}, - {"url": "myhost:443?tls=true", "error": False, "secure": True, "scheme": "", - "host": "myhost", "port": 443, "endpoint": "dns:myhost:443"}, - + { + "url": "myhost:443", + "error": False, + "secure": False, + "scheme": "", + "host": "myhost", + "port": 443, + "endpoint": "dns:myhost:443", + }, + { + "url": "myhost:443?tls=false", + "error": False, + "secure": False, + "scheme": "", + "host": "myhost", + "port": 443, + "endpoint": "dns:myhost:443", + }, + { + "url": "myhost:443?tls=true", + "error": False, + "secure": True, + "scheme": "", + "host": "myhost", + "port": 443, + "endpoint": "dns:myhost:443", + }, # Scheme, 
host and port - {"url": "http://myhost", "error": False, "secure": False, "scheme": "", - "host": "myhost", "port": 443, "endpoint": "dns:myhost:443"}, + { + "url": "http://myhost", + "error": False, + "secure": False, + "scheme": "", + "host": "myhost", + "port": 443, + "endpoint": "dns:myhost:443", + }, {"url": "http://myhost?tls=false", "error": True}, # We can't have both http/https and the tls query parameter {"url": "http://myhost?tls=true", "error": True}, # We can't have both http/https and the tls query parameter - - {"url": "http://myhost:443", "error": False, "secure": False, "scheme": "", - "host": "myhost", "port": 443, "endpoint": "dns:myhost:443"}, + { + "url": "http://myhost:443", + "error": False, + "secure": False, + "scheme": "", + "host": "myhost", + "port": 443, + "endpoint": "dns:myhost:443", + }, {"url": "http://myhost:443?tls=false", "error": True}, # We can't have both http/https and the tls query parameter {"url": "http://myhost:443?tls=true", "error": True}, # We can't have both http/https and the tls query parameter - - {"url": "http://myhost:5000", "error": False, "secure": False, "scheme": "", - "host": "myhost", "port": 5000, "endpoint": "dns:myhost:5000"}, + { + "url": "http://myhost:5000", + "error": False, + "secure": False, + "scheme": "", + "host": "myhost", + "port": 5000, + "endpoint": "dns:myhost:5000", + }, {"url": "http://myhost:5000?tls=false", "error": True}, # We can't have both http/https and the tls query parameter {"url": "http://myhost:5000?tls=true", "error": True}, # We can't have both http/https and the tls query parameter - - {"url": "https://myhost:443", "error": False, "secure": True, "scheme": "", - "host": "myhost", "port": 443, "endpoint": "dns:myhost:443"}, + { + "url": "https://myhost:443", + "error": False, + "secure": True, + "scheme": "", + "host": "myhost", + "port": 443, + "endpoint": "dns:myhost:443", + }, {"url": "https://myhost:443?tls=false", "error": True}, {"url": "https://myhost:443?tls=true", "error": True}, - # Scheme = dns - {"url": "dns:myhost", "error": False, "secure": False, "scheme": "dns", - "host": "myhost", "port": 443, "endpoint": "dns:myhost:443"}, - {"url": "dns:myhost?tls=false", "error": False, "secure": False, "scheme": "dns", - "host": "myhost", "port": 443, "endpoint": "dns:myhost:443"}, - {"url": "dns:myhost?tls=true", "error": False, "secure": True, "scheme": "dns", - "host": "myhost", "port": 443, "endpoint": "dns:myhost:443"}, - + { + "url": "dns:myhost", + "error": False, + "secure": False, + "scheme": "dns", + "host": "myhost", + "port": 443, + "endpoint": "dns:myhost:443", + }, + { + "url": "dns:myhost?tls=false", + "error": False, + "secure": False, + "scheme": "dns", + "host": "myhost", + "port": 443, + "endpoint": "dns:myhost:443", + }, + { + "url": "dns:myhost?tls=true", + "error": False, + "secure": True, + "scheme": "dns", + "host": "myhost", + "port": 443, + "endpoint": "dns:myhost:443", + }, # Scheme = dns with authority - {"url": "dns://myauthority:53/myhost", "error": False, "secure": False, "scheme": "dns", - "host": "myhost", "port": 443, "endpoint": "dns://myauthority:53/myhost:443"}, - {"url": "dns://myauthority:53/myhost?tls=false", "error": False, "secure": False, - "scheme": "dns", "host": "myhost", "port": 443, - "endpoint": "dns://myauthority:53/myhost:443"}, - {"url": "dns://myauthority:53/myhost?tls=true", "error": False, "secure": True, - "scheme": "dns", "host": "myhost", "port": 443, - "endpoint": "dns://myauthority:53/myhost:443"}, {"url": "dns://myhost", "error": 
True}, - + { + "url": "dns://myauthority:53/myhost", + "error": False, + "secure": False, + "scheme": "dns", + "host": "myhost", + "port": 443, + "endpoint": "dns://myauthority:53/myhost:443", + }, + { + "url": "dns://myauthority:53/myhost?tls=false", + "error": False, + "secure": False, + "scheme": "dns", + "host": "myhost", + "port": 443, + "endpoint": "dns://myauthority:53/myhost:443", + }, + { + "url": "dns://myauthority:53/myhost?tls=true", + "error": False, + "secure": True, + "scheme": "dns", + "host": "myhost", + "port": 443, + "endpoint": "dns://myauthority:53/myhost:443", + }, + {"url": "dns://myhost", "error": True}, # Unix sockets - {"url": "unix:my.sock", "error": False, "secure": False, "scheme": "unix", - "host": "my.sock", "port": "", "endpoint": "unix:my.sock"}, - {"url": "unix:my.sock?tls=true", "error": False, "secure": True, "scheme": "unix", - "host": "my.sock", "port": "", "endpoint": "unix:my.sock"}, - + { + "url": "unix:my.sock", + "error": False, + "secure": False, + "scheme": "unix", + "host": "my.sock", + "port": "", + "endpoint": "unix:my.sock", + }, + { + "url": "unix:my.sock?tls=true", + "error": False, + "secure": True, + "scheme": "unix", + "host": "my.sock", + "port": "", + "endpoint": "unix:my.sock", + }, # Unix sockets with absolute path - {"url": "unix://my.sock", "error": False, "secure": False, "scheme": "unix", - "host": "my.sock", "port": "", "endpoint": "unix://my.sock"}, - {"url": "unix://my.sock?tls=true", "error": False, "secure": True, "scheme": "unix", - "host": "my.sock", "port": "", "endpoint": "unix://my.sock"}, - + { + "url": "unix://my.sock", + "error": False, + "secure": False, + "scheme": "unix", + "host": "my.sock", + "port": "", + "endpoint": "unix://my.sock", + }, + { + "url": "unix://my.sock?tls=true", + "error": False, + "secure": True, + "scheme": "unix", + "host": "my.sock", + "port": "", + "endpoint": "unix://my.sock", + }, # Unix abstract sockets - {"url": "unix-abstract:my.sock", "error": False, "secure": False, "scheme": "unix", - "host": "my.sock", "port": "", "endpoint": "unix-abstract:my.sock"}, - {"url": "unix-abstract:my.sock?tls=true", "error": False, "secure": True, - "scheme": "unix", "host": "my.sock", "port": "", "endpoint": "unix-abstract:my.sock"}, - + { + "url": "unix-abstract:my.sock", + "error": False, + "secure": False, + "scheme": "unix", + "host": "my.sock", + "port": "", + "endpoint": "unix-abstract:my.sock", + }, + { + "url": "unix-abstract:my.sock?tls=true", + "error": False, + "secure": True, + "scheme": "unix", + "host": "my.sock", + "port": "", + "endpoint": "unix-abstract:my.sock", + }, # Vsock - {"url": "vsock:mycid", "error": False, "secure": False, "scheme": "vsock", - "host": "mycid", "port": "443", "endpoint": "vsock:mycid:443"}, - {"url": "vsock:mycid:5000", "error": False, "secure": False, "scheme": "vsock", - "host": "mycid", "port": 5000, "endpoint": "vsock:mycid:5000"}, - {"url": "vsock:mycid:5000?tls=true", "error": False, "secure": True, "scheme": "vsock", - "host": "mycid", "port": 5000, "endpoint": "vsock:mycid:5000"}, - + { + "url": "vsock:mycid", + "error": False, + "secure": False, + "scheme": "vsock", + "host": "mycid", + "port": "443", + "endpoint": "vsock:mycid:443", + }, + { + "url": "vsock:mycid:5000", + "error": False, + "secure": False, + "scheme": "vsock", + "host": "mycid", + "port": 5000, + "endpoint": "vsock:mycid:5000", + }, + { + "url": "vsock:mycid:5000?tls=true", + "error": False, + "secure": True, + "scheme": "vsock", + "host": "mycid", + "port": 5000, + "endpoint": 
"vsock:mycid:5000", + }, # IPv6 addresses with dns scheme - {"url": "[2001:db8:1f70::999:de8:7648:6e8]", "error": False, "secure": False, - "scheme": "", "host": "[2001:db8:1f70::999:de8:7648:6e8]", "port": 443, - "endpoint": "dns:[2001:db8:1f70::999:de8:7648:6e8]:443"}, - {"url": "dns:[2001:db8:1f70::999:de8:7648:6e8]", "error": False, "secure": False, - "scheme": "", "host": "[2001:db8:1f70::999:de8:7648:6e8]", "port": 443, - "endpoint": "dns:[2001:db8:1f70::999:de8:7648:6e8]:443"}, - {"url": "dns:[2001:db8:1f70::999:de8:7648:6e8]:5000", "error": False, "secure": False, - "scheme": "", "host": "[2001:db8:1f70::999:de8:7648:6e8]", "port": 5000, - "endpoint": "dns:[2001:db8:1f70::999:de8:7648:6e8]:5000"}, + { + "url": "[2001:db8:1f70::999:de8:7648:6e8]", + "error": False, + "secure": False, + "scheme": "", + "host": "[2001:db8:1f70::999:de8:7648:6e8]", + "port": 443, + "endpoint": "dns:[2001:db8:1f70::999:de8:7648:6e8]:443", + }, + { + "url": "dns:[2001:db8:1f70::999:de8:7648:6e8]", + "error": False, + "secure": False, + "scheme": "", + "host": "[2001:db8:1f70::999:de8:7648:6e8]", + "port": 443, + "endpoint": "dns:[2001:db8:1f70::999:de8:7648:6e8]:443", + }, + { + "url": "dns:[2001:db8:1f70::999:de8:7648:6e8]:5000", + "error": False, + "secure": False, + "scheme": "", + "host": "[2001:db8:1f70::999:de8:7648:6e8]", + "port": 5000, + "endpoint": "dns:[2001:db8:1f70::999:de8:7648:6e8]:5000", + }, {"url": "dns:[2001:db8:1f70::999:de8:7648:6e8]:5000?abc=[]", "error": True}, - # IPv6 addresses with dns scheme and authority - {"url": "dns://myauthority:53/[2001:db8:1f70::999:de8:7648:6e8]", "error": False, - "secure": False, "scheme": "dns", "host": "[2001:db8:1f70::999:de8:7648:6e8]", - "port": 443, "endpoint": "dns://myauthority:53/[2001:db8:1f70::999:de8:7648:6e8]:443"}, - + { + "url": "dns://myauthority:53/[2001:db8:1f70::999:de8:7648:6e8]", + "error": False, + "secure": False, + "scheme": "dns", + "host": "[2001:db8:1f70::999:de8:7648:6e8]", + "port": 443, + "endpoint": "dns://myauthority:53/[2001:db8:1f70::999:de8:7648:6e8]:443", + }, # IPv6 addresses with https scheme - {"url": "https://[2001:db8:1f70::999:de8:7648:6e8]", "error": False, "secure": True, - "scheme": "", "host": "[2001:db8:1f70::999:de8:7648:6e8]", "port": 443, - "endpoint": "dns:[2001:db8:1f70::999:de8:7648:6e8]:443"}, - {"url": "https://[2001:db8:1f70::999:de8:7648:6e8]:5000", "error": False, - "secure": True, "scheme": "", "host": "[2001:db8:1f70::999:de8:7648:6e8]", - "port": 5000, "endpoint": "dns:[2001:db8:1f70::999:de8:7648:6e8]:5000"}, - + { + "url": "https://[2001:db8:1f70::999:de8:7648:6e8]", + "error": False, + "secure": True, + "scheme": "", + "host": "[2001:db8:1f70::999:de8:7648:6e8]", + "port": 443, + "endpoint": "dns:[2001:db8:1f70::999:de8:7648:6e8]:443", + }, + { + "url": "https://[2001:db8:1f70::999:de8:7648:6e8]:5000", + "error": False, + "secure": True, + "scheme": "", + "host": "[2001:db8:1f70::999:de8:7648:6e8]", + "port": 5000, + "endpoint": "dns:[2001:db8:1f70::999:de8:7648:6e8]:5000", + }, # Invalid addresses (with path and queries) {"url": "host:5000/v1/dapr", "error": True}, # Paths are not allowed in grpc endpoints {"url": "host:5000/?a=1", "error": True}, # Query params not allowed in grpc endpoints - # Invalid scheme {"url": "inv-scheme://myhost", "error": True}, {"url": "inv-scheme:myhost:5000", "error": True}, diff --git a/tests/serializers/test_default_json_serializer.py b/tests/serializers/test_default_json_serializer.py index 1b2734cb1..fce994760 100644 --- 
a/tests/serializers/test_default_json_serializer.py
+++ b/tests/serializers/test_default_json_serializer.py
@@ -23,27 +23,35 @@ class DefaultJSONSerializerTests(unittest.TestCase):
     def test_serialize(self):
         serializer = DefaultJSONSerializer()
         fakeDateTime = datetime.datetime(
-            year=2020, month=1, day=1, hour=1, minute=0,
-            second=0, microsecond=0, tzinfo=datetime.timezone.utc)
+            year=2020,
+            month=1,
+            day=1,
+            hour=1,
+            minute=0,
+            second=0,
+            microsecond=0,
+            tzinfo=datetime.timezone.utc,
+        )
         input_dict_obj = {
-            'propertyDecimal': 10,
-            'propertyStr': 'StrValue',
-            'propertyDateTime': fakeDateTime,
+            "propertyDecimal": 10,
+            "propertyStr": "StrValue",
+            "propertyDateTime": fakeDateTime,
         }
 
         serialized = serializer.serialize(input_dict_obj)
-        self.assertEqual(serialized, b'{"propertyDecimal":10,"propertyStr":"StrValue","propertyDateTime":"2020-01-01T01:00:00Z"}')  # noqa: E501
+        self.assertEqual(
+            serialized,
+            b'{"propertyDecimal":10,"propertyStr":"StrValue","propertyDateTime":"2020-01-01T01:00:00Z"}',  # noqa: E501
+        )
 
     def test_serialize_bytes(self):
         serializer = DefaultJSONSerializer()
 
         # Serialize bytes data
-        serialized = serializer.serialize(b'bytes_data')
+        serialized = serializer.serialize(b"bytes_data")
         self.assertEqual(b'"Ynl0ZXNfZGF0YQ=="', serialized)
 
         # Serialize bytes property
-        input_dict_obj = {
-            'propertyBytes': b'bytes_property'
-        }
+        input_dict_obj = {"propertyBytes": b"bytes_property"}
         serialized = serializer.serialize(input_dict_obj)
         self.assertEqual(serialized, b'{"propertyBytes":"Ynl0ZXNfcHJvcGVydHk="}')
 
@@ -52,14 +60,23 @@ def test_deserialize(self):
         payload = b'{"propertyDecimal":10,"propertyStr":"StrValue","propertyDateTime":"2020-01-01T01:00:00Z"}'  # noqa: E501
 
         obj = serializer.deserialize(payload)
-        self.assertEqual(obj['propertyDecimal'], 10)
-        self.assertEqual(obj['propertyStr'], 'StrValue')
-        self.assertTrue(isinstance(obj['propertyDateTime'], datetime.datetime))
-        self.assertEqual(obj['propertyDateTime'], datetime.datetime(
-            year=2020, month=1, day=1, hour=1, minute=0,
-            second=0, microsecond=0,
-            tzinfo=datetime.timezone.utc))
+        self.assertEqual(obj["propertyDecimal"], 10)
+        self.assertEqual(obj["propertyStr"], "StrValue")
+        self.assertTrue(isinstance(obj["propertyDateTime"], datetime.datetime))
+        self.assertEqual(
+            obj["propertyDateTime"],
+            datetime.datetime(
+                year=2020,
+                month=1,
+                day=1,
+                hour=1,
+                minute=0,
+                second=0,
+                microsecond=0,
+                tzinfo=datetime.timezone.utc,
+            ),
+        )
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
diff --git a/tests/serializers/test_util.py b/tests/serializers/test_util.py
index 9423f8648..aa9a2e963 100644
--- a/tests/serializers/test_util.py
+++ b/tests/serializers/test_util.py
@@ -26,49 +26,54 @@ def setUp(self):
         pass
 
     def test_convert_hour_mins_secs(self):
-        delta = convert_from_dapr_duration('4h15m40s')
+        delta = convert_from_dapr_duration("4h15m40s")
         self.assertEqual(delta.total_seconds(), 15340.0)
 
     def test_convert_mins_secs(self):
-        delta = convert_from_dapr_duration('15m40s')
+        delta = convert_from_dapr_duration("15m40s")
         self.assertEqual(delta.total_seconds(), 940.0)
 
     def test_convert_secs(self):
-        delta = convert_from_dapr_duration('40s')
+        delta = convert_from_dapr_duration("40s")
         self.assertEqual(delta.total_seconds(), 40.0)
 
     def test_convert_millisecs(self):
-        delta = convert_from_dapr_duration('123ms')
+        delta = convert_from_dapr_duration("123ms")
         self.assertEqual(delta.total_seconds(), 0.123)
 
     def test_convert_microsecs_μs(self):
-        delta = convert_from_dapr_duration('123μs')
+        delta = 
convert_from_dapr_duration("123μs") self.assertEqual(delta.microseconds, 123) def test_convert_microsecs_us(self): - delta = convert_from_dapr_duration('345us') + delta = convert_from_dapr_duration("345us") self.assertEqual(delta.microseconds, 345) def test_convert_invalid_duration(self): with self.assertRaises(ValueError) as exeception_context: - convert_from_dapr_duration('invalid') - self.assertEqual(exeception_context.exception.args[0], - "Invalid Dapr Duration format: '{}'".format('invalid')) + convert_from_dapr_duration("invalid") + self.assertEqual( + exeception_context.exception.args[0], + "Invalid Dapr Duration format: '{}'".format("invalid"), + ) def test_convert_timedelta_to_dapr_duration(self): duration = convert_to_dapr_duration( - timedelta(hours=4, minutes=15, seconds=40, milliseconds=123, microseconds=35)) - self.assertEqual(duration, '4h15m40s123ms35μs') + timedelta(hours=4, minutes=15, seconds=40, milliseconds=123, microseconds=35) + ) + self.assertEqual(duration, "4h15m40s123ms35μs") def test_convert_invalid_duration_string(self): - TESTSTRING = '4h15m40s123ms35μshello' + TESTSTRING = "4h15m40s123ms35μshello" with self.assertRaises(ValueError) as exeception_context: convert_from_dapr_duration(TESTSTRING) - self.assertEqual(exeception_context.exception.args[0], - "Invalid Dapr Duration format: '{}'".format(TESTSTRING)) + self.assertEqual( + exeception_context.exception.args[0], + "Invalid Dapr Duration format: '{}'".format(TESTSTRING), + ) decoded = json.loads(json.dumps({"somevar": TESTSTRING}), cls=DaprJSONDecoder) - self.assertEqual(decoded['somevar'], TESTSTRING) + self.assertEqual(decoded["somevar"], TESTSTRING) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tox.ini b/tox.ini index c6cbff430..e33bd3c30 100644 --- a/tox.ini +++ b/tox.ini @@ -4,6 +4,7 @@ minversion = 3.8.0 envlist = py{38,39,310,311,312} flake8, + ruff, mypy, [testenv] @@ -31,6 +32,13 @@ deps = flake8 commands = flake8 . +[testenv:ruff] +basepython = python3 +usedevelop = False +deps = ruff +commands = + ruff format + [testenv:examples] passenv = HOME basepython = python3
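A quick usage sketch for the [testenv:ruff] environment added above (illustrative only, not part of the patch). Note that a bare `ruff format` rewrites files in place; the `--check` flag shown below is an assumption for a CI-style gate that fails on unformatted code instead of modifying it:

    $ pip install tox
    $ tox -e ruff            # runs `ruff format` exactly as configured in the tox.ini hunk above
    $ ruff format --check .  # assumed CI variant: reports files that would be reformatted and exits non-zero
    $ ruff check .           # assumed companion lint pass; the tox env above only formats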