
Commit

Merge pull request #155 from AstraZeneca/argo_fail_fast
fix: bug with env secrets, improved capture of stdout
vijayvammi authored May 19, 2024
2 parents 5e4215f + 80047bd commit c37f168
Showing 9 changed files with 100 additions and 41 deletions.
4 changes: 2 additions & 2 deletions examples/02-sequential/default_fail.py
@@ -12,12 +12,12 @@
python examples/02-sequential/default_fail.py
"""

from examples.common.functions import raise_ex
from examples.common.functions import hello, raise_ex
from runnable import Pipeline, PythonTask, Stub


def main():
step1 = Stub(name="step 1")
step1 = PythonTask(name="step 1", function=hello)

step2 = PythonTask(name="step 2", function=raise_ex) # This will fail

7 changes: 0 additions & 7 deletions examples/configs/argo-config.yaml
@@ -6,10 +6,6 @@ executor:
persistent_volumes: # (3)
- name: magnus-volume
mount_path: /mnt
secrets_from_k8s:
- environment_variable: AZURE_CLIENT_ID
secret_name: ms-graph
secret_key: AZURE_CLIENT_ID

run_log_store: # (4)
type: chunked-fs
@@ -20,6 +16,3 @@
type: file-system
config:
catalog_location: /mnt/catalog

# secrets:
# type: do-nothing
1 change: 1 addition & 0 deletions pyproject.toml
@@ -97,6 +97,7 @@ runnable = 'runnable.cli:cli'
[tool.poetry.plugins."secrets"]
"do-nothing" = "runnable.secrets:DoNothingSecretManager"
"dotenv" = "runnable.extensions.secrets.dotenv.implementation:DotEnvSecrets"
"env-secrets" = "runnable.secrets:EnvSecretsManager"

# Plugins for Run Log store
[tool.poetry.plugins."run_log_store"]
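
The new "env-secrets" entry point registers runnable.secrets:EnvSecretsManager as a secrets plugin, so secrets can be resolved straight from environment variables. Below is a minimal sketch of what such a manager can look like; the get(name) interface and the standalone class are assumptions for illustration, not the actual runnable API.

# Sketch of an environment-variable-backed secrets manager.
# The get(name) interface is an assumption; the real
# runnable.secrets.EnvSecretsManager may differ.
import os


class EnvSecretsManager:
    """Resolve secrets directly from the process environment."""

    def get(self, name: str) -> str:
        try:
            return os.environ[name]
        except KeyError as e:
            raise KeyError(f"Secret '{name}' is not set in the environment") from e


# Usage: export MY_TOKEN=abc123 before the run, then:
# EnvSecretsManager().get("MY_TOKEN") -> "abc123"
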
2 changes: 2 additions & 0 deletions runnable/__init__.py
@@ -15,6 +15,8 @@
console = Console(record=True)
console.print(":runner: Lets go!!")

task_console = Console(record=True)

from runnable.sdk import ( # noqa
Catalog,
Fail,
2 changes: 1 addition & 1 deletion runnable/defaults.py
@@ -77,7 +77,7 @@ class RunnableConfig(TypedDict, total=False):
DEFAULT_EXECUTOR = ServiceConfig(type="local", config={})
DEFAULT_RUN_LOG_STORE = ServiceConfig(type="file-system", config={})
DEFAULT_CATALOG = ServiceConfig(type="file-system", config={})
DEFAULT_SECRETS = ServiceConfig(type="do-nothing", config={})
DEFAULT_SECRETS = ServiceConfig(type="env-secrets", config={})
DEFAULT_EXPERIMENT_TRACKER = ServiceConfig(type="do-nothing", config={})
DEFAULT_PICKLER = ServiceConfig(type="pickle", config={})

13 changes: 7 additions & 6 deletions runnable/entrypoints.py
@@ -9,7 +9,7 @@
from rich.table import Column

import runnable.context as context
from runnable import console, defaults, graph, utils
from runnable import console, defaults, graph, task_console, utils
from runnable.defaults import RunnableConfig, ServiceConfig

logger = logging.getLogger(defaults.LOGGER_NAME)
@@ -165,6 +165,7 @@ def execute(
tag=tag,
parameters_file=parameters_file,
)

console.print("Working with context:")
console.print(run_context)
console.rule(style="[dark orange]")
@@ -239,7 +240,7 @@ def execute_single_node(
"""
from runnable import nodes

console.print(f"Executing the single node: {step_name} with map variable: {map_variable}")
task_console.print(f"Executing the single node: {step_name} with map variable: {map_variable}")

configuration_file = os.environ.get("RUNNABLE_CONFIGURATION_FILE", configuration_file)

@@ -250,9 +251,9 @@
tag=tag,
parameters_file=parameters_file,
)
console.print("Working with context:")
console.print(run_context)
console.rule(style="[dark orange]")
task_console.print("Working with context:")
task_console.print(run_context)
task_console.rule(style="[dark orange]")

executor = run_context.executor
run_context.execution_plan = defaults.EXECUTION_PLAN.CHAINED.value
@@ -281,7 +282,7 @@ def execute_single_node(
node=node_to_execute,
map_variable=map_variable_dict,
)
console.save_text(log_file_name)
task_console.save_text(log_file_name)

# Put the log file in the catalog
run_context.catalog_handler.put(name=log_file_name, run_id=run_context.run_id)
10 changes: 10 additions & 0 deletions runnable/extensions/executor/__init__.py
@@ -11,6 +11,7 @@
exceptions,
integration,
parameters,
task_console,
utils,
)
from runnable.datastore import DataCatalog, JsonParameter, RunLog, StepLog
@@ -340,10 +341,18 @@ def execute_from_graph(self, node: BaseNode, map_variable: TypeMapVariable = Non
node.execute_as_graph(map_variable=map_variable, **kwargs)
return

task_console.export_text(clear=True)

task_name = node._resolve_map_placeholders(node.internal_name, map_variable)
console.print(f":runner: Executing the node {task_name} ... ", style="bold color(208)")
self.trigger_job(node=node, map_variable=map_variable, **kwargs)

log_file_name = utils.make_log_file_name(node=node, map_variable=map_variable)
task_console.save_text(log_file_name, clear=True)

self._context.catalog_handler.put(name=log_file_name, run_id=self._context.run_id)
os.remove(log_file_name)

def trigger_job(self, node: BaseNode, map_variable: TypeMapVariable = None, **kwargs):
"""
Call this method only if we are responsible for traversing the graph via
@@ -493,6 +502,7 @@ def execute_graph(self, dag: Graph, map_variable: TypeMapVariable = None, **kwar

logger.info(f"Finished execution of the {branch} with status {run_log.status}")

# We are in the root dag
if dag == self._context.dag:
run_log = cast(RunLog, run_log)
console.print("Completed Execution, Summary:", style="bold color(208)")
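
The capture pattern added here clears the recording console before a node runs, writes its transcript to a log file afterwards, pushes that file to the catalog, and removes the local copy. The sketch below reproduces that lifecycle with plain Rich, outside of runnable; the log file name and the printed message are illustrative, and the catalog hand-off is only indicated by a comment.

# Sketch of the per-node output capture lifecycle using Rich's record mode.
import os

from rich.console import Console

task_console = Console(record=True)

task_console.export_text(clear=True)       # drop anything recorded for a previous node

task_console.print("hello from the task")  # recorded, not just printed

log_file_name = "step_1.execution.log"     # illustrative name
task_console.save_text(log_file_name, clear=True)  # write the transcript, reset the buffer

# ... hand log_file_name to the catalog / object store here ...
os.remove(log_file_name)                   # the local copy is no longer needed
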
55 changes: 54 additions & 1 deletion runnable/extensions/executor/local_container/implementation.py
@@ -5,7 +5,7 @@
from pydantic import Field
from rich import print

from runnable import defaults, utils
from runnable import console, defaults, task_console, utils
from runnable.datastore import StepLog
from runnable.defaults import TypeMapVariable
from runnable.extensions.executor import GenericExecutor
@@ -96,6 +96,59 @@ def execute_node(self, node: BaseNode, map_variable: TypeMapVariable = None, **k
"""
return self._execute_node(node, map_variable, **kwargs)

def execute_from_graph(self, node: BaseNode, map_variable: TypeMapVariable = None, **kwargs):
"""
This is the entry point from the graph execution.
While the self.execute_graph is responsible for traversing the graph, this function is responsible for
actual execution of the node.
If the node type is:
* task : We can delegate to _execute_node after checking the eligibility for re-run in cases of a re-run
* success: We can delegate to _execute_node
* fail: We can delegate to _execute_node
For nodes that are internally graphs:
* parallel: Delegate the responsibility of execution to the node.execute_as_graph()
* dag: Delegate the responsibility of execution to the node.execute_as_graph()
* map: Delegate the responsibility of execution to the node.execute_as_graph()
Transpilers will NEVER use this method and will NEVER call this method.
This method should only be used by interactive executors.
Args:
node (Node): The node to execute
map_variable (dict, optional): If the node is of a map state, this corresponds to the value of iterable.
Defaults to None.
"""
step_log = self._context.run_log_store.create_step_log(node.name, node._get_step_log_name(map_variable))

self.add_code_identities(node=node, step_log=step_log)

step_log.step_type = node.node_type
step_log.status = defaults.PROCESSING

self._context.run_log_store.add_step_log(step_log, self._context.run_id)

logger.info(f"Executing node: {node.get_summary()}")

# Add the step log to the database as per the situation.
# If its a terminal node, complete it now
if node.node_type in ["success", "fail"]:
self._execute_node(node, map_variable=map_variable, **kwargs)
return

# We call an internal function to iterate the sub graphs and execute them
if node.is_composite:
node.execute_as_graph(map_variable=map_variable, **kwargs)
return

task_console.export_text(clear=True)

task_name = node._resolve_map_placeholders(node.internal_name, map_variable)
console.print(f":runner: Executing the node {task_name} ... ", style="bold color(208)")
self.trigger_job(node=node, map_variable=map_variable, **kwargs)

def execute_job(self, node: TaskNode):
"""
Set up the step log and call the execute node
47 changes: 23 additions & 24 deletions runnable/tasks.py
@@ -17,7 +17,7 @@
from stevedore import driver

import runnable.context as context
from runnable import console, defaults, exceptions, parameters, utils
from runnable import console, defaults, exceptions, parameters, task_console, utils
from runnable.datastore import (
JsonParameter,
MetricParameter,
@@ -144,27 +144,21 @@ def execution_context(self, map_variable: TypeMapVariable = None, allow_complex:
if context_param in params:
params[param_name].value = params[context_param].value

console.log("Parameters available for the execution:")
console.log(params)
task_console.log("Parameters available for the execution:")
task_console.log(params)

logger.debug(f"Resolved parameters: {params}")

if not allow_complex:
params = {key: value for key, value in params.items() if isinstance(value, JsonParameter)}

parameters_in = copy.deepcopy(params)
f = io.StringIO()
try:
with contextlib.redirect_stdout(f):
# with contextlib.nullcontext():
yield params
yield params
except Exception as e: # pylint: disable=broad-except
console.log(e, style=defaults.error_style)
logger.exception(e)
finally:
print(f.getvalue()) # print to console
f.close()

# Update parameters
# This should only update the parameters that are changed at the root level.
diff_parameters = self._diff_parameters(parameters_in=parameters_in, context_params=params)
@@ -226,9 +220,11 @@ def execute_command(
filtered_parameters = parameters.filter_arguments_for_func(f, params.copy(), map_variable)
logger.info(f"Calling {func} from {module} with {filtered_parameters}")

user_set_parameters = f(**filtered_parameters) # This is a tuple or single value
out_file = io.StringIO()
with contextlib.redirect_stdout(out_file):
user_set_parameters = f(**filtered_parameters) # This is a tuple or single value
task_console.print(out_file.getvalue())
except Exception as e:
console.log(e, style=defaults.error_style, markup=False)
raise exceptions.CommandCallError(f"Function call: {self.command} did not succeed.\n") from e

attempt_log.input_parameters = params.copy()
@@ -272,8 +268,8 @@ def execute_command(
except Exception as _e:
msg = f"Call to the function {self.command} did not succeed.\n"
attempt_log.message = msg
console.print_exception(show_locals=False)
console.log(_e, style=defaults.error_style)
task_console.print_exception(show_locals=False)
task_console.log(_e, style=defaults.error_style)

attempt_log.end_time = str(datetime.now())

@@ -359,7 +355,11 @@ def execute_command(
}
kwds.update(ploomber_optional_args)

pm.execute_notebook(**kwds)
out_file = io.StringIO()
with contextlib.redirect_stdout(out_file):
pm.execute_notebook(**kwds)
task_console.print(out_file.getvalue())

context.run_context.catalog_handler.put(name=notebook_output_path, run_id=context.run_context.run_id)

client = PloomberClient.from_path(path=notebook_output_path)
@@ -380,8 +380,8 @@
)
except PicklingError as e:
logger.exception("Notebooks cannot return objects")
console.log("Notebooks cannot return objects", style=defaults.error_style)
console.log(e, style=defaults.error_style)
# task_console.log("Notebooks cannot return objects", style=defaults.error_style)
# task_console.log(e, style=defaults.error_style)

logger.exception(e)
raise
@@ -400,8 +400,7 @@ def execute_command(
logger.exception(msg)
logger.exception(e)

console.log(msg, style=defaults.error_style)

# task_console.log(msg, style=defaults.error_style)
attempt_log.status = defaults.FAIL

attempt_log.end_time = str(datetime.now())
@@ -488,14 +487,14 @@ def execute_command(

if proc.returncode != 0:
msg = ",".join(result[1].split("\n"))
console.print(msg, style=defaults.error_style)
task_console.print(msg, style=defaults.error_style)
raise exceptions.CommandCallError(msg)

# for stderr
for line in result[1].split("\n"):
if line.strip() == "":
continue
console.print(line, style=defaults.warning_style)
task_console.print(line, style=defaults.warning_style)

output_parameters: Dict[str, Parameter] = {}
metrics: Dict[str, Parameter] = {}
@@ -506,7 +505,7 @@
continue

logger.info(line)
console.print(line)
task_console.print(line)

if line.strip() == collect_delimiter:
# The lines from now on should be captured
@@ -548,8 +547,8 @@ def execute_command(
logger.exception(msg)
logger.exception(e)

console.log(msg, style=defaults.error_style)
console.log(e, style=defaults.error_style)
task_console.log(msg, style=defaults.error_style)
task_console.log(e, style=defaults.error_style)

attempt_log.status = defaults.FAIL

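
The recurring change in runnable/tasks.py is that output printed by user code is captured with contextlib.redirect_stdout and replayed onto the recording task console instead of going straight to the terminal. A self-contained sketch of that capture follows; user_function stands in for the task's callable and is not part of runnable.

# Sketch of capturing a task function's stdout and replaying it on a
# recording console. user_function is a stand-in for user task code.
import contextlib
import io

from rich.console import Console

task_console = Console(record=True)


def user_function() -> int:
    print("computing...")  # would normally go straight to the terminal
    return 42


out_file = io.StringIO()
with contextlib.redirect_stdout(out_file):
    result = user_function()  # any print() inside is written to out_file

task_console.print(out_file.getvalue())  # replayed into the recorded transcript
print("captured result:", result)
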
