diff --git a/cylc/uiserver/data_store_mgr.py b/cylc/uiserver/data_store_mgr.py index 2e8d1240..ad823402 100644 --- a/cylc/uiserver/data_store_mgr.py +++ b/cylc/uiserver/data_store_mgr.py @@ -37,6 +37,7 @@ from functools import partial from pathlib import Path import time +from typing import Optional from cylc.flow import ID_DELIM from cylc.flow.network.server import PB_METHOD_MAP @@ -342,7 +343,7 @@ def reconcile_update(self, topic, delta, w_id): ), self.loop ) - _, new_delta_msg = future.result(self.RECONCILE_TIMEOUT) + new_delta_msg = future.result(self.RECONCILE_TIMEOUT) new_delta = DELTAS_MAP[topic]() new_delta.ParseFromString(new_delta_msg) self.clear_data_field(w_id, topic) @@ -358,12 +359,11 @@ def reconcile_update(self, topic, delta, w_id): except Exception as exc: self.log.exception(exc) - async def entire_workflow_update(self, ids=None): + async def entire_workflow_update(self, ids: Optional[list] = None) -> None: """Update entire local data-store of workflow(s). Args: - ids (list): List of workflow external IDs. - + ids: List of workflow external IDs. """ if ids is None: @@ -371,37 +371,35 @@ async def entire_workflow_update(self, ids=None): # Request new data req_method = 'pb_entire_workflow' - req_kwargs = ( - {'client': info['req_client'], - 'command': req_method, - 'req_context': w_id} - for w_id, info in self.workflows_mgr.active.items()) - - gathers = [ - workflow_request(**kwargs) - for kwargs in req_kwargs - if not ids or kwargs['req_context'] in ids - ] - items = await asyncio.gather(*gathers, return_exceptions=True) - for item in items: - if isinstance(item, Exception): + + requests = { + w_id: workflow_request( + client=info['req_client'], command=req_method + ) + for w_id, info in self.workflows_mgr.active.items() + if not ids or w_id in ids + } + results = await asyncio.gather( + *requests.values(), return_exceptions=True + ) + # result: + for w_id, result in zip(requests, results): + if isinstance(result, Exception): self.log.exception( 'Failed to update entire local data-store ' - 'of a workflow', exc_info=item + 'of a workflow', exc_info=result ) - else: - w_id, result = item - if result is not None and result != MSG_TIMEOUT: - pb_data = PB_METHOD_MAP[req_method]() - pb_data.ParseFromString(result) - new_data = deepcopy(DATA_TEMPLATE) - for field, value in pb_data.ListFields(): - if field.name == WORKFLOW: - new_data[field.name].CopyFrom(value) - new_data['delta_times'] = { - key: value.last_updated - for key in DATA_TEMPLATE - } - continue - new_data[field.name] = {n.id: n for n in value} - self.data[w_id] = new_data + elif result is not None and result != MSG_TIMEOUT: + pb_data = PB_METHOD_MAP[req_method]() + pb_data.ParseFromString(result) + new_data = deepcopy(DATA_TEMPLATE) + for field, value in pb_data.ListFields(): + if field.name == WORKFLOW: + new_data[field.name].CopyFrom(value) + new_data['delta_times'] = { + key: value.last_updated + for key in DATA_TEMPLATE + } + continue + new_data[field.name] = {n.id: n for n in value} + self.data[w_id] = new_data diff --git a/cylc/uiserver/handlers.py b/cylc/uiserver/handlers.py index 9a3566da..2b384c88 100644 --- a/cylc/uiserver/handlers.py +++ b/cylc/uiserver/handlers.py @@ -18,7 +18,7 @@ import json import getpass import socket -from typing import Callable, Union +from typing import TYPE_CHECKING, Callable, Union from graphene_tornado.tornado_graphql_handler import TornadoGraphQLHandler from graphql import get_default_backend @@ -35,6 +35,9 @@ from cylc.uiserver.authorise import Authorization, AuthorizationMiddleware from cylc.uiserver.websockets import authenticated as websockets_authenticated +if TYPE_CHECKING: + from graphql.execution import ExecutionResult + ME = getpass.getuser() @@ -338,7 +341,7 @@ def prepare(self): super().prepare() @web.authenticated - async def execute(self, *args, **kwargs): + async def execute(self, *args, **kwargs) -> 'ExecutionResult': # Use own backend, and TornadoGraphQLHandler already does validation. return await self.schema.execute( *args, diff --git a/cylc/uiserver/resolvers.py b/cylc/uiserver/resolvers.py index e81eee84..5c0013e8 100644 --- a/cylc/uiserver/resolvers.py +++ b/cylc/uiserver/resolvers.py @@ -17,9 +17,20 @@ import os from subprocess import Popen, PIPE, DEVNULL +from typing import ( + TYPE_CHECKING, Any, Dict, List +) -from cylc.flow.network.resolvers import BaseResolvers from cylc.flow.data_store_mgr import WORKFLOW +from cylc.flow.network.resolvers import BaseResolvers +from cylc.flow.network.schema import GenericResponseTuple + + +if TYPE_CHECKING: + from logging import Logger + from graphql import ResolveInfo + from cylc.flow.data_store_mgr import DataStoreMgr + from cylc.uiserver.workflows_mgr import WorkflowsManager # show traceback from cylc commands @@ -172,11 +183,16 @@ async def play(cls, workflows, args, workflows_mgr, log): class Resolvers(BaseResolvers): """UI Server context GraphQL query and mutation resolvers.""" - workflows_mgr = None - - def __init__(self, data, log, **kwargs): + def __init__( + self, + data: 'DataStoreMgr', + log: 'Logger', + workflows_mgr: 'WorkflowsManager', + **kwargs + ): super().__init__(data) self.log = log + self.workflows_mgr = workflows_mgr # Set extra attributes for key, value in kwargs.items(): @@ -184,24 +200,52 @@ def __init__(self, data, log, **kwargs): setattr(self, key, value) # Mutations - async def mutator(self, info, *m_args): + async def mutator( + self, + info: 'ResolveInfo', + command: str, + w_args: Dict[str, Any], + _kwargs: Dict[str, Any] + ) -> List[GenericResponseTuple]: """Mutate workflow.""" - _, w_args, _ = m_args w_ids = [ flow[WORKFLOW].id for flow in await self.get_workflows_data(w_args)] if not w_ids: - return [{ - 'response': (False, 'No matching workflows')}] + return [ + GenericResponseTuple(None, False, "No matching workflows") + ] # Pass the request to the workflow GraphQL endpoints - req_str, variables, _, _ = info.context.get('graphql_params') + req_str, variables, _, _ = ( + info.context.get('graphql_params') # type: ignore[union-attr] + ) graphql_args = { 'request_string': req_str, 'variables': variables, } - return await self.workflows_mgr.multi_request( + results = await self.workflows_mgr.multi_request( 'graphql', w_ids, graphql_args ) + if not results: + return [ + GenericResponseTuple( + None, False, "No matching workflows running" + ) + ] + ret: List[GenericResponseTuple] = [] + for result in results: + if not isinstance(result, dict): + raise TypeError( + "Expected to receive GraphQL response dict " + f"but received: {result}" + ) + if not result.get('data'): + raise ValueError(f"Unexpected response: {result}") + mutation_result: dict = result['data'][command]['results'][0] + ret.append( + GenericResponseTuple(**mutation_result) + ) + return ret async def service(self, info, *m_args): return await Services.play( diff --git a/cylc/uiserver/schema.py b/cylc/uiserver/schema.py index 8f090bf7..1f77600a 100644 --- a/cylc/uiserver/schema.py +++ b/cylc/uiserver/schema.py @@ -19,14 +19,13 @@ """ -from functools import partial +from typing import TYPE_CHECKING, Any, List, Optional import graphene from graphene.types.generic import GenericScalar from cylc.flow.network.schema import ( CyclePoint, - GenericResponse, Mutations, Queries, Subscriptions, @@ -36,25 +35,9 @@ sstrip, ) - -async def mutator(root, info, command=None, workflows=None, - exworkflows=None, **args): - """Call the resolver method that act on the workflow service - via the internal command queue.""" - if workflows is None: - workflows = [] - if exworkflows is None: - exworkflows = [] - w_args = {} - w_args['workflows'] = [parse_workflow_id(w_id) for w_id in workflows] - w_args['exworkflows'] = [parse_workflow_id(w_id) for w_id in exworkflows] - if args.get('args', False): - args.update(args.get('args', {})) - args.pop('args') - - resolvers = info.context.get('resolvers') - res = await resolvers.service(info, command, w_args, args) - return GenericResponse(result=res) +if TYPE_CHECKING: + from graphql import ResolveInfo + from cylc.uiserver.resolvers import Resolvers class RunMode(graphene.Enum): @@ -89,7 +72,6 @@ class Meta: description = sstrip(''' Start, resume or un-pause a workflow run. ''') - resolver = partial(mutator, command='play') class Arguments: workflows = graphene.List(WorkflowID, required=True) @@ -199,6 +181,30 @@ class Arguments: ''') ) + @staticmethod + async def mutate( + root: Optional[Any], + info: 'ResolveInfo', + *, + workflows: Optional[List[str]] = None, + # _exworkflows: Optional[List[str]] = None, + **kwargs: Any + ): + """Call the resolver method that act on the workflow service + via the internal command queue.""" + if workflows is None: + workflows = [] + parsed_workflows = [parse_workflow_id(w_id) for w_id in workflows] + if kwargs.get('args', False): + kwargs.update(kwargs.get('args', {})) + kwargs.pop('args') + + resolvers: 'Resolvers' = ( + info.context.get('resolvers') # type: ignore[union-attr] + ) + res = await resolvers.service(parsed_workflows, kwargs) + return Play(result=res) + result = GenericScalar() diff --git a/cylc/uiserver/workflows_mgr.py b/cylc/uiserver/workflows_mgr.py index 3aede578..4d4598c0 100644 --- a/cylc/uiserver/workflows_mgr.py +++ b/cylc/uiserver/workflows_mgr.py @@ -26,6 +26,9 @@ from contextlib import suppress from getpass import getuser import sys +from typing import ( + TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union +) import zmq.asyncio @@ -33,7 +36,6 @@ from cylc.flow.exceptions import ClientError, ClientTimeout from cylc.flow.network import API from cylc.flow.network.client import WorkflowRuntimeClient -from cylc.flow.network import MSG_TIMEOUT from cylc.flow.network.scan import ( api_version, contact_info, @@ -42,48 +44,38 @@ ) from cylc.flow.workflow_files import ContactFileFields as CFF +if TYPE_CHECKING: + from logging import Logger + + CLIENT_TIMEOUT = 2.0 async def workflow_request( - client, - command, - args=None, - timeout=None, - req_context=None, + client: WorkflowRuntimeClient, + command: str, + args: Optional[Dict[str, Any]] = None, + timeout: Optional[float] = None, *, - log=None, -): + log: Optional['Logger'] = None, +) -> Union[bytes, object]: """Workflow request command. Args: - client (WorkflowRuntimeClient): Instantiated workflow client. - command (str): Command/Endpoint name. - args (dict): Endpoint arguments. - timeout (float): Client request timeout (secs). - req_context (str): A string to identifier. - - Returns: - tuple: (req_context, result) + client: Instantiated workflow client. + command: Command/Endpoint name. + args: Endpoint arguments. + timeout: Client request timeout (secs). """ - if req_context is None: - req_context = command try: - result = await client.async_request(command, args, timeout) - return (req_context, result) - except ClientTimeout as exc: + return await client.async_request(command, args, timeout) + except (ClientTimeout, ClientError) as exc: if log: log.exception(exc) else: print(exc, file=sys.stderr) - return (req_context, MSG_TIMEOUT) - except ClientError as exc: - if log: - log.exception(exc) - else: - print(exc, file=sys.stderr) - return (req_context, None) + raise exc class WorkflowsManager: @@ -267,44 +259,37 @@ async def update(self): async def multi_request( self, - command, - workflows, - args=None, - multi_args=None, + command: str, + workflows: Iterable[str], + args: Optional[Dict[str, Any]] = None, + multi_args: Optional[Dict[str, Any]] = None, timeout=None - ): + ) -> List[object]: """Send requests to multiple workflows.""" if args is None: args = {} if multi_args is None: multi_args = {} - req_args = { - w_id: ( + gathers = [ + workflow_request( self.active[w_id]['req_client'], command, multi_args.get(w_id, args), timeout, - ) for w_id in workflows + log=self.log + ) + for w_id in workflows if w_id in self.active - } - gathers = [ - workflow_request(req_context=info, *request_args, log=self.log) - for info, request_args in req_args.items() ] - results = await asyncio.gather(*gathers, return_exceptions=True) - res = [] + results: List[ + Union[bytes, object, Exception] + ] = await asyncio.gather(*gathers, return_exceptions=True) + res: List[Union[bytes, object]] = [] for result in results: if isinstance(result, Exception): self.log.exception( 'Failed to send requests to multiple workflows', exc_info=result ) - else: - _, val = result - res.extend([ - msg_core - for msg_core in list(val.values())[0].get('result') - if isinstance(val, dict) - and list(val.values()) - ]) + res.append(result) return res