From 4ed415e8f6935ef60d766cf4f9d2ef7d8c897898 Mon Sep 17 00:00:00 2001 From: aleksandarmijat Date: Wed, 5 Jun 2024 12:24:55 +0200 Subject: [PATCH] Domain Payload Optimization to Action server --- rasa_sdk/endpoint.py | 10 +++- rasa_sdk/executor.py | 60 ++++++++++++++++--- rasa_sdk/interfaces.py | 11 ++++ tests/test_endpoint.py | 4 ++ tests/tracing/instrumentation/test_tracing.py | 1 + 5 files changed, 77 insertions(+), 9 deletions(-) diff --git a/rasa_sdk/endpoint.py b/rasa_sdk/endpoint.py index ec8d4fb86..5818026b1 100644 --- a/rasa_sdk/endpoint.py +++ b/rasa_sdk/endpoint.py @@ -36,7 +36,11 @@ DEFAULT_SERVER_PORT, ) from rasa_sdk.executor import ActionExecutor - from rasa_sdk.interfaces import ActionExecutionRejection, ActionNotFoundException + from rasa_sdk.interfaces import ( + ActionExecutionRejection, + ActionNotFoundException, + ActionMissingDomainException, + ) from rasa_sdk.plugin import plugin_manager from rasa_sdk.tracing.utils import ( get_tracer_and_context, @@ -153,6 +157,10 @@ async def webhook(request: Request) -> HTTPResponse: logger.error(e) body = {"error": e.message, "action_name": e.action_name} return response.json(body, status=404) + except ActionMissingDomainException as e: + logger.error(e) + body = {"error": e.message, "action_name": e.action_name} + return response.json(body, status=449) set_span_attributes(span, action_call) diff --git a/rasa_sdk/executor.py b/rasa_sdk/executor.py index 5d3e7a9f4..802d84a7a 100644 --- a/rasa_sdk/executor.py +++ b/rasa_sdk/executor.py @@ -2,7 +2,6 @@ import inspect import logging import pkgutil -import typing import warnings from typing import Text, List, Dict, Any, Type, Union, Callable, Optional, Set, cast from collections import namedtuple @@ -10,13 +9,15 @@ import sys import os -from rasa_sdk.interfaces import Tracker, ActionNotFoundException, Action +from rasa_sdk.interfaces import ( + Tracker, + ActionNotFoundException, + Action, + ActionMissingDomainException, +) from rasa_sdk import utils -if typing.TYPE_CHECKING: # pragma: no cover - from rasa_sdk.types import ActionCall - logger = logging.getLogger(__name__) @@ -24,7 +25,6 @@ class CollectingDispatcher: """Send messages back to user""" def __init__(self) -> None: - self.messages: List[Dict[Text, Any]] = [] def utter_message( @@ -162,6 +162,8 @@ def __init__(self) -> None: self.actions: Dict[Text, Callable] = {} self._modules: Dict[Text, TimestampModule] = {} self._loaded: Set[Type[Action]] = set() + self.domain: Optional[Dict[Text, Any]] = None + self.domain_digest: Optional[Text] = None def register_action(self, action: Union[Type[Action], Action]) -> None: if inspect.isclass(action): @@ -380,7 +382,49 @@ def validate_events(events: List[Dict[Text, Any]], action_name: Text): # we won't append this to validated events -> will be ignored return validated - async def run(self, action_call: "ActionCall") -> Optional[Dict[Text, Any]]: + def is_domain_digest_valid(self, domain_digest: Optional[Text]) -> bool: + """Check if the domain_digest is valid + If the domain_digest is empty or different from the one provided, it is invalid. + Args: + domain_digest: latest value provided to compare the current value with. + Returns: + True if the domain_digest is valid, False otherwise. + """ + return bool(self.domain_digest) and self.domain_digest == domain_digest + + def update_and_return_domain( + self, payload: Dict[Text, Any], action_name: Text + ) -> Optional[Dict[Text, Any]]: + """Validate the digest, store the domain if available, and return the domain. + This method validates the domain digest from the payload. + If the digest is invalid and no domain is provided, an exception is raised. + If domain data is available, it stores the domain and digest. + Finally, it returns the domain. + Args: + payload: Request payload containing the domain data. + action_name: Name of the action that should be executed. + Returns: + The domain dictionary. + Raises: + ActionMissingDomainException: Invalid digest and no domain data available. + """ + payload_domain = payload.get("domain") + payload_domain_digest = payload.get("domain_digest") + + # If digest is invalid and no domain is available - raise the error + if ( + not self.is_domain_digest_valid(payload_domain_digest) + and payload_domain is None + ): + raise ActionMissingDomainException(action_name) + + if payload_domain: + self.domain = payload_domain + self.domain_digest = payload_domain_digest + + return self.domain + + async def run(self, action_call: Dict[Text, Any]) -> Optional[Dict[Text, Any]]: from rasa_sdk.interfaces import Tracker action_name = action_call.get("next_action") @@ -391,7 +435,7 @@ async def run(self, action_call: "ActionCall") -> Optional[Dict[Text, Any]]: raise ActionNotFoundException(action_name) tracker_json = action_call["tracker"] - domain = action_call.get("domain", {}) + domain = self.update_and_return_domain(action_call, action_name) tracker = Tracker.from_dict(tracker_json) dispatcher = CollectingDispatcher() diff --git a/rasa_sdk/interfaces.py b/rasa_sdk/interfaces.py index 79b81c597..454ad24b7 100644 --- a/rasa_sdk/interfaces.py +++ b/rasa_sdk/interfaces.py @@ -384,3 +384,14 @@ def __init__(self, action_name: Text, message: Optional[Text] = None) -> None: def __str__(self) -> Text: return self.message + + +class ActionMissingDomainException(Exception): + """Raising this exception when the domain is missing.""" + + def __init__(self, action_name: Text, message: Optional[Text] = None) -> None: + self.action_name = action_name + self.message = message or "Domain context is missing." + + def __str__(self) -> Text: + return self.message diff --git a/tests/test_endpoint.py b/tests/test_endpoint.py index 315a4f3a9..820652825 100644 --- a/tests/test_endpoint.py +++ b/tests/test_endpoint.py @@ -65,6 +65,7 @@ def test_server_webhook_handles_action_exception(sanic_app: Sanic): data = { "next_action": "custom_action_exception", "tracker": {"sender_id": "1", "conversation_id": "default"}, + "domain": {}, } request, response = sanic_app.test_client.post("/webhook", data=json.dumps(data)) assert response.status == 500 @@ -76,6 +77,7 @@ def test_server_webhook_custom_action_returns_200(sanic_app: Sanic): data = { "next_action": "custom_action", "tracker": {"sender_id": "1", "conversation_id": "default"}, + "domain": {}, } request, response = sanic_app.test_client.post("/webhook", data=json.dumps(data)) events = response.json.get("events") @@ -88,6 +90,7 @@ def test_server_webhook_custom_async_action_returns_200(sanic_app: Sanic): data = { "next_action": "custom_async_action", "tracker": {"sender_id": "1", "conversation_id": "default"}, + "domain": {}, } request, response = sanic_app.test_client.post("/webhook", data=json.dumps(data)) events = response.json.get("events") @@ -148,6 +151,7 @@ def test_server_webhook_custom_action_with_dialogue_stack_returns_200( data = { "next_action": "custom_action_with_dialogue_stack", "tracker": {"sender_id": "1", "conversation_id": "default", **stack_state}, + "domain": {}, } _, response = sanic_app.test_client.post("/webhook", data=json.dumps(data)) events = response.json.get("events") diff --git a/tests/tracing/instrumentation/test_tracing.py b/tests/tracing/instrumentation/test_tracing.py index f1c7a411e..261f76054 100644 --- a/tests/tracing/instrumentation/test_tracing.py +++ b/tests/tracing/instrumentation/test_tracing.py @@ -45,6 +45,7 @@ def test_server_webhook_custom_action_is_instrumented( "rasa_sdk.endpoint.get_tracer_provider", lambda _: tracer_provider ) data["next_action"] = action_name + data["domain"] = {} app = ep.create_app(action_package) app.register_listener(