Skip to content

Commit

Permalink
Merge pull request #51 from MetRonnie/workflow-state
Browse files Browse the repository at this point in the history
Refactor `workflow_state` xtrig pre-8.3.0-back-compat
  • Loading branch information
hjoliver authored Jun 5, 2024
2 parents 960fb61 + 2a6b0ba commit 344759f
Show file tree
Hide file tree
Showing 20 changed files with 508 additions and 248 deletions.
93 changes: 58 additions & 35 deletions cylc/flow/dbstatecheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
import os
import sqlite3
import sys
from typing import Optional, List
from textwrap import dedent
from typing import Dict, Iterable, Optional, List, Union

from cylc.flow import LOG
from cylc.flow.exceptions import InputError
from cylc.flow.cycling.util import add_offset
from cylc.flow.cycling.integer import (
Expand All @@ -33,13 +33,20 @@
from cylc.flow.rundb import CylcWorkflowDAO
from cylc.flow.task_outputs import (
TASK_OUTPUT_SUCCEEDED,
TASK_OUTPUT_FAILED
TASK_OUTPUT_FAILED,
TASK_OUTPUT_FINISHED,
)
from cylc.flow.util import deserialise_set
from metomi.isodatetime.parsers import TimePointParser
from metomi.isodatetime.exceptions import ISO8601SyntaxError


output_fallback_msg = (
"Unable to filter by task output label for tasks run in Cylc versions "
"between 8.0.0-8.3.0. Falling back to filtering by task message instead."
)


class CylcWorkflowDBChecker:
"""Object for querying task status or outputs from a workflow database.
Expand Down Expand Up @@ -70,12 +77,12 @@ def __init__(self, rund, workflow, db_path=None):
# Get workflow point format.
try:
self.db_point_fmt = self._get_db_point_format()
self.back_compat_mode = False
self.c7_back_compat_mode = False
except sqlite3.OperationalError as exc:
# BACK COMPAT: Cylc 7 DB (see method below).
try:
self.db_point_fmt = self._get_db_point_format_compat()
self.back_compat_mode = True
self.c7_back_compat_mode = True
except sqlite3.OperationalError:
raise exc # original error

Expand Down Expand Up @@ -194,7 +201,7 @@ def workflow_state_query(
]
For an output query:
[
[name, cycle, "[out1: msg1, out2: msg2, ...]"],
[name, cycle, "{out1: msg1, out2: msg2, ...}"],
...
]
"""
Expand All @@ -208,16 +215,16 @@ def workflow_state_query(
target_table = CylcWorkflowDAO.TABLE_TASK_STATES
mask = "name, cycle, status"

if not self.back_compat_mode:
if not self.c7_back_compat_mode:
# Cylc 8 DBs only
mask += ", flow_nums"

stmt = dedent(rf'''
stmt = rf'''
SELECT
{mask}
FROM
{target_table}
''') # nosec
''' # nosec
# * mask is hardcoded
# * target_table is a code constant

Expand All @@ -241,20 +248,20 @@ def workflow_state_query(
stmt_wheres.append("cycle==?")
stmt_args.append(cycle)

if selector is not None and not (is_output or is_message):
if (
selector is not None
and target_table == CylcWorkflowDAO.TABLE_TASK_STATES
):
# Can select by status in the DB but not outputs.
stmt_wheres.append("status==?")
stmt_args.append(selector)

if stmt_wheres:
stmt += "WHERE\n " + (" AND ").join(stmt_wheres)

if not (is_output or is_message):
if target_table == CylcWorkflowDAO.TABLE_TASK_STATES:
# (outputs table doesn't record submit number)
stmt += dedent("""
ORDER BY
submit_num
""")
stmt += r"ORDER BY submit_num"

# Query the DB and drop incompatible rows.
db_res = []
Expand All @@ -264,7 +271,7 @@ def workflow_state_query(
if row[2] is None:
# status can be None in Cylc 7 DBs
continue
if not self.back_compat_mode:
if not self.c7_back_compat_mode:
flow_nums = deserialise_set(row[3])
if flow_num is not None and flow_num not in flow_nums:
# skip result, wrong flow
Expand All @@ -274,34 +281,50 @@ def workflow_state_query(
res.append(fstr)
db_res.append(res)

if not (is_output or is_message):
if target_table == CylcWorkflowDAO.TABLE_TASK_STATES:
return db_res

warn_output_fallback = is_output
results = []
for row in db_res:
outputs_map = json.loads(row[2])
if is_message:
# task message
try:
outputs = list(outputs_map.values())
except AttributeError:
# Cylc 8 pre 8.3.0 back-compat: list of output messages
outputs = list(outputs_map)
outputs: Union[Dict[str, str], List[str]] = json.loads(row[2])
if isinstance(outputs, dict):
messages: Iterable[str] = outputs.values()
else:
# task output
outputs = list(outputs_map)
# Cylc 8 pre 8.3.0 back-compat: list of output messages
messages = outputs
if warn_output_fallback:
LOG.warning(output_fallback_msg)
warn_output_fallback = False

if (
selector is None or
selector in outputs or
(
selector in ("finished", "finish")
and (
TASK_OUTPUT_SUCCEEDED in outputs
or TASK_OUTPUT_FAILED in outputs
)
)
(is_message and selector in messages) or
(is_output and self._selector_in_outputs(selector, outputs))
):
results.append(row[:2] + [str(outputs)] + row[3:])

return results

@staticmethod
def _selector_in_outputs(selector: str, outputs: Iterable[str]) -> bool:
"""Check if a selector, including "finished", is in the outputs.
Examples:
>>> this = CylcWorkflowDBChecker._selector_in_outputs
>>> this('moop', ['started', 'moop'])
True
>>> this('moop', ['started'])
False
>>> this('finished', ['succeeded'])
True
>>> this('finish', ['failed'])
True
"""
return selector in outputs or (
selector in (TASK_OUTPUT_FINISHED, "finish")
and (
TASK_OUTPUT_SUCCEEDED in outputs
or TASK_OUTPUT_FAILED in outputs
)
)
9 changes: 5 additions & 4 deletions cylc/flow/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,12 +241,13 @@ class XtriggerConfigError(WorkflowConfigError):
"""

def __init__(self, label: str, message: str):
self.label: str = label
self.message: str = message
def __init__(self, label: str, func: str, message: Union[str, Exception]):
self.label = label
self.func = func
self.message = message

def __str__(self) -> str:
return f'[@{self.label}] {self.message}'
return f'[@{self.label}] {self.func}\n{self.message}'


class ClientError(CylcError):
Expand Down
18 changes: 12 additions & 6 deletions cylc/flow/rundb.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import traceback
from typing import (
TYPE_CHECKING,
Dict,
Iterable,
List,
Set,
Expand All @@ -38,6 +39,7 @@

if TYPE_CHECKING:
from pathlib import Path
from cylc.flow.flow_mgr import FlowNums


@dataclass
Expand Down Expand Up @@ -806,10 +808,12 @@ def select_latest_flow_nums(self):
flow_nums_str = list(self.connect().execute(stmt))[0][0]
return deserialise_set(flow_nums_str)

def select_task_outputs(self, name, point):
def select_task_outputs(
self, name: str, point: str
) -> 'Dict[str, FlowNums]':
"""Select task outputs for each flow.
Return: {outputs_list: flow_nums_set}
Return: {outputs_dict_str: flow_nums_set}
"""
stmt = rf'''
Expand All @@ -820,10 +824,12 @@ def select_task_outputs(self, name, point):
WHERE
name==? AND cycle==?
''' # nosec (table name is code constant)
ret = {}
for flow_nums, outputs in self.connect().execute(stmt, (name, point,)):
ret[outputs] = deserialise_set(flow_nums)
return ret
return {
outputs: deserialise_set(flow_nums)
for flow_nums, outputs in self.connect().execute(
stmt, (name, point,)
)
}

def select_xtriggers_for_restart(self, callback):
stmt = rf'''
Expand Down
Loading

0 comments on commit 344759f

Please sign in to comment.