Skip to content

Commit

Permalink
Merge pull request #814 from linsword13/template-vars
Browse files Browse the repository at this point in the history
Add the ability to define template-specific extra variables
  • Loading branch information
douglasjacobsen authored Jan 10, 2025
2 parents b5f8a7d + a82f624 commit cb82e32
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 12 deletions.
13 changes: 10 additions & 3 deletions lib/ramble/ramble/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -2279,7 +2279,7 @@ def _get_template_config(
break
if not found:
raise ApplicationError(f"Object {obj.name} is missing template file at {src_path}")
return {**tpl_config, "src_path": src_path}
return (obj, {**tpl_config, "src_path": src_path})

for tpl_config in self.templates.values():
yield _get_template_config(self, tpl_config)
Expand All @@ -2298,10 +2298,17 @@ def _get_template_config(

def _render_object_templates(self, extra_vars):
run_dir = self.expander.experiment_run_dir
for tpl_config in self._object_templates():
for obj, tpl_config in self._object_templates():
src_path = tpl_config["src_path"]
with open(src_path) as f_in:
content = f_in.read()
extra_vars_wm = tpl_config.get("extra_vars")
if extra_vars_wm is not None:
extra_vars.update(extra_vars_wm)
extra_vars_func_name = tpl_config.get("extra_vars_func_name")
if extra_vars_func_name is not None:
extra_vars_func = getattr(obj, extra_vars_func_name)
extra_vars.update(extra_vars_func())
rendered = self.expander.expand_var(content, extra_vars=extra_vars)
out_path = os.path.join(run_dir, tpl_config["dest_name"])
perm = tpl_config.get("content_perm", _DEFAULT_CONTENT_PERM)
Expand All @@ -2312,7 +2319,7 @@ def _render_object_templates(self, extra_vars):

def _define_object_template_vars(self):
run_dir = self.expander.experiment_run_dir
for tpl_config in self._object_templates():
for _, tpl_config in self._object_templates():
var_name = tpl_config["var_name"]
if var_name is not None:
path = os.path.join(run_dir, tpl_config["dest_name"])
Expand Down
19 changes: 18 additions & 1 deletion lib/ramble/ramble/language/shared_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
# option. This file may not be copied, modified, or distributed
# except according to those terms.

from typing import Optional

import ramble.language.language_base
import ramble.language.language_helpers
import ramble.success_criteria
Expand Down Expand Up @@ -480,7 +482,13 @@ def _execute_target_shells(obj):

@shared_directive("templates")
def register_template(
name: str, src_name: str, dest_name: str, define_var: bool = True, output_perm=None
name: str,
src_name: str,
dest_name: str,
define_var: bool = True,
extra_vars: Optional[dict] = None,
extra_vars_func: Optional[str] = None,
output_perm=None,
):
"""Directive to define an object-specific template to be rendered into experiment run_dir.
Expand All @@ -498,15 +506,24 @@ def register_template(
dest_name: The leaf name of the rendered output under the experiment
run directory.
define_var: Controls if a variable named `name` should be defined.
extra_vars: If present, the variable dict is used as extra variables to
render the template.
extra_vars_func: If present, the name of the function to call to return
a dict of extra variables used to render the template.
This option is combined together with and takes precedence
over `extra_vars`, if both are present.
output_perm: The chmod mask for the rendered output file.
"""

def _define_template(obj):
var_name = name if define_var else None
extra_vars_func_name = f"_{extra_vars_func}" if extra_vars_func is not None else None
obj.templates[name] = {
"src_name": src_name,
"dest_name": dest_name,
"var_name": var_name,
"extra_vars": extra_vars,
"extra_vars_func_name": extra_vars_func_name,
"output_perm": output_perm,
}

Expand Down
2 changes: 2 additions & 0 deletions lib/ramble/ramble/test/end_to_end/test_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ def test_template():
assert os.path.isfile(script_path)
with open(script_path) as f:
content = f.read()
assert "echo foobar" in content
assert "echo hello santa" in content
assert "echo not_exist" not in content
execute_path = os.path.join(run_dir, "execute_experiment")
with open(execute_path) as f:
content = f.read()
Expand Down
20 changes: 12 additions & 8 deletions var/ramble/repos/builtin.mock/applications/template/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,19 @@ class Template(ExecutableApplication):
workload="test_template",
)

register_phase(
"ingest_dynamic_variables",
pipeline="setup",
run_before=["make_experiments"],
register_template(
name="bar",
src_name="bar.tpl",
dest_name="bar.sh",
# The `dynamic_hello_world` will be overridden by `_bar_vars`
extra_vars={
"dynamic_var1": "foobar",
"dynamic_hello_world": "not_exist",
},
extra_vars_func="bar_vars",
)

def _ingest_dynamic_variables(self, workspace, app_inst):
def _bar_vars(self):
expander = self.expander
val = expander.expand_var('"hello {hello_name}"')
self.define_variable("dynamic_hello_world", val)

register_template("bar", src_name="bar.tpl", dest_name="bar.sh")
return {"dynamic_hello_world": val}
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
#!/bin/bash
echo {dynamic_var1}
echo {dynamic_hello_world}

0 comments on commit cb82e32

Please sign in to comment.