diff --git a/copier/jinja_ext.py b/copier/jinja_ext.py new file mode 100644 index 000000000..687dacf48 --- /dev/null +++ b/copier/jinja_ext.py @@ -0,0 +1,109 @@ +"""Jinja2 extensions built for Copier.""" + +from __future__ import annotations + +from typing import Any, Callable, Sequence + +from jinja2 import nodes +from jinja2.exceptions import UndefinedError +from jinja2.ext import Extension +from jinja2.parser import Parser +from jinja2.sandbox import SandboxedEnvironment + + +class YieldEnvironment(SandboxedEnvironment): + """Jinja2 environment with a `yield_context` attribute. + + This is simple environment class that extends the SandboxedEnvironment + for use with the YieldExtension, mainly for avoiding type errors. + + We use the SandboxedEnvironment because we want to minimize the risk of hidden malware + in the templates so we use the SandboxedEnvironment instead of the regular one. + Of course we still have the post-copy tasks to worry about, but at least + they are more visible to the final user. + """ + + yield_context: dict[str, Any] + + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + self.extend(yield_context=dict()) + + +class YieldExtension(Extension): + """`Jinja2 extension for the `yield` tag. + + If `yield` tag is used in a template, this extension sets the `yield_context` attribute to the + jinja environment. `yield_context` is a dictionary with the following keys: + - `single_var`: The name of the variable that will be yielded. + - `looped_var`: The variable that will be looped over. + + Note that this extension just sets the `yield_context` attribute but renders template + as usual. It is caller's responsibility to use the `yield_context` attribute in the + template to generate the desired output. + + Example: + template: "{% yield single_var from looped_var %}" + context: {"looped_var": [1, 2, 3], "single_var": "item"} + + then, + >>> from copier.jinja_ext import YieldEnvironment, YieldExtension + >>> env = YieldEnvironment(extensions=[YieldExtension]) + >>> template = env.from_string("{% yield single_var from looped_var %}{{ single_var }}{% endyield %}") + >>> template.render({"looped_var": [1, 2, 3]}) + '' + >>> env.yield_context + {'single_var': 'single_var', 'looped_var': [1, 2, 3]} + """ + + tags = {"yield"} + + environment: YieldEnvironment + + def __init__(self, environment: YieldEnvironment): + super().__init__(environment) + + def parse(self, parser: Parser) -> nodes.Node: + """Parse the `yield` tag.""" + lineno = next(parser.stream).lineno + + single_var: nodes.Name = parser.parse_assign_target(name_only=True) + parser.stream.expect("name:from") + looped_var = parser.parse_expression() + body = parser.parse_statements(("name:endyield",), drop_needle=True) + + return nodes.CallBlock( + self.call_method( + "_yield_support", + [looped_var, nodes.Const(single_var.name)], + ), + [], + [], + body, + lineno=lineno, + ) + + def _yield_support( + self, looped_var: Sequence[Any], single_var_name: str, caller: Callable[[], str] + ) -> str: + """Support function for the yield tag. + + Sets the yield context in the environment with the given + looped variable and single variable name, then calls the provided caller + function. If an UndefinedError is raised, it returns an empty string. + + """ + self.environment.yield_context = { + "single_var": single_var_name, + "looped_var": looped_var, + } + + try: + res = caller() + + # expression like `dict.attr` will always raise UndefinedError + # so we catch it here and return an empty string + except UndefinedError: + res = "" + + return res diff --git a/copier/main.py b/copier/main.py index bb514fba0..ebf4cdd74 100644 --- a/copier/main.py +++ b/copier/main.py @@ -28,7 +28,6 @@ from unicodedata import normalize from jinja2.loaders import FileSystemLoader -from jinja2.sandbox import SandboxedEnvironment from pathspec import PathSpec from plumbum import ProcessExecutionError, colors from plumbum.cli.terminal import ask @@ -44,6 +43,7 @@ UnsafeTemplateError, UserMessageError, ) +from .jinja_ext import YieldEnvironment, YieldExtension from .subproject import Subproject from .template import Task, Template from .tools import ( @@ -540,7 +540,7 @@ def all_exclusions(self) -> Sequence[str]: return self.template.exclude + tuple(self.exclude) @cached_property - def jinja_env(self) -> SandboxedEnvironment: + def jinja_env(self) -> YieldEnvironment: """Return a pre-configured Jinja environment. Respects template settings. @@ -549,14 +549,11 @@ def jinja_env(self) -> SandboxedEnvironment: loader = FileSystemLoader(paths) default_extensions = [ "jinja2_ansible_filters.AnsibleCoreFiltersExtension", + YieldExtension, ] extensions = default_extensions + list(self.template.jinja_extensions) - # We want to minimize the risk of hidden malware in the templates - # so we use the SandboxedEnvironment instead of the regular one. - # Of course we still have the post-copy tasks to worry about, but at least - # they are more visible to the final user. try: - env = SandboxedEnvironment( + env = YieldEnvironment( loader=loader, extensions=extensions, **self.template.envops ) except ModuleNotFoundError as error: @@ -606,19 +603,25 @@ def _render_template(self) -> None: for src in scantree(str(self.template_copy_root), follow_symlinks): src_abspath = Path(src.path) src_relpath = Path(src_abspath).relative_to(self.template.local_abspath) - dst_relpath = self._render_path( + dst_relpaths_ctxs = self._render_path( Path(src_abspath).relative_to(self.template_copy_root) ) - if dst_relpath is None or self.match_exclude(dst_relpath): - continue - if src.is_symlink() and self.template.preserve_symlinks: - self._render_symlink(src_relpath, dst_relpath) - elif src.is_dir(follow_symlinks=follow_symlinks): - self._render_folder(dst_relpath) - else: - self._render_file(src_relpath, dst_relpath) + for dst_relpath, ctx in dst_relpaths_ctxs: + if self.match_exclude(dst_relpath): + continue + if src.is_symlink() and self.template.preserve_symlinks: + self._render_symlink(src_relpath, dst_relpath) + elif src.is_dir(follow_symlinks=follow_symlinks): + self._render_folder(dst_relpath) + else: + self._render_file(src_relpath, dst_relpath, extra_context=ctx or {}) - def _render_file(self, src_relpath: Path, dst_relpath: Path) -> None: + def _render_file( + self, + src_relpath: Path, + dst_relpath: Path, + extra_context: AnyByStrDict | None = None, + ) -> None: """Render one file. Args: @@ -628,6 +631,8 @@ def _render_file(self, src_relpath: Path, dst_relpath: Path) -> None: dst_relpath: File to be created. It must be a path relative to the subproject root. + extra_context: + Additional variables to use for rendering the template. """ # TODO Get from main.render_file() assert not src_relpath.is_absolute() @@ -643,7 +648,9 @@ def _render_file(self, src_relpath: Path, dst_relpath: Path) -> None: # suffix is empty, fallback to copy new_content = src_abspath.read_bytes() else: - new_content = tpl.render(**self._render_context()).encode() + new_content = tpl.render( + **self._render_context(), **(extra_context or {}) + ).encode() else: new_content = src_abspath.read_bytes() dst_abspath = self.subproject.local_abspath / dst_relpath @@ -715,8 +722,89 @@ def _render_folder(self, dst_relpath: Path) -> None: dst_abspath = self.subproject.local_abspath / dst_relpath dst_abspath.mkdir(parents=True, exist_ok=True) - def _render_path(self, relpath: Path) -> Path | None: - """Render one relative path. + def _adjust_rendered_part(self, rendered_part: str) -> str: + """Adjust the rendered part if necessary. + + If {{ _copier_conf.answers_file }} becomes the full path, + restore part to be just the end leaf. + + Args: + rendered_part: + The rendered part of the path to adjust. + + """ + if str(self.answers_relpath) == rendered_part: + return Path(rendered_part).name + return rendered_part + + def _render_parts( + self, + parts: tuple[str, ...], + rendered_parts: tuple[str, ...] | None = None, + extra_context: AnyByStrDict | None = None, + is_template: bool = False, + ) -> Iterable[tuple[Path, AnyByStrDict | None]]: + """Render a set of parts into path and context pairs. + + If a yield tag is found in a part, it will recursively yield multiple path and context pairs. + """ + if rendered_parts is None: + rendered_parts = tuple() + + if not parts: + rendered_path = Path(*rendered_parts) + + templated_sibling = ( + self.template.local_abspath + / f"{rendered_path}{self.template.templates_suffix}" + ) + if is_template or not templated_sibling.exists(): + yield rendered_path, extra_context + + return + + part = parts[0] + parts = parts[1:] + + if not extra_context: + extra_context = {} + + rendered_part = self._render_string(part, extra_context=extra_context) + + yield_context = self.jinja_env.yield_context.copy() + if yield_context: + single_var = yield_context["single_var"] + looped_var = yield_context["looped_var"] + + for value in looped_var: + new_context = {**extra_context, **{single_var: value}} + rendered_part = self._render_string(part, extra_context=new_context) + self.jinja_env.yield_context = {} + + rendered_part = self._adjust_rendered_part(rendered_part) + + # Skip if any part is rendered as an empty string + if not rendered_part: + continue + + yield from self._render_parts( + parts, rendered_parts + (rendered_part,), new_context, is_template + ) + + return + + # Skip if any part is rendered as an empty string + if not rendered_part: + return + + rendered_part = self._adjust_rendered_part(rendered_part) + + yield from self._render_parts( + parts, rendered_parts + (rendered_part,), extra_context, is_template + ) + + def _render_path(self, relpath: Path) -> Iterable[tuple[Path, AnyByStrDict | None]]: + """Render one relative path into multiple path and context pairs. Args: relpath: @@ -728,29 +816,11 @@ def _render_path(self, relpath: Path) -> Path | None: ) # With an empty suffix, the templated sibling always exists. if templated_sibling.exists() and self.template.templates_suffix: - return None + return if self.template.templates_suffix and is_template: relpath = relpath.with_suffix("") - rendered_parts = [] - for part in relpath.parts: - # Skip folder if any part is rendered as an empty string - part = self._render_string(part) - if not part: - return None - # {{ _copier_conf.answers_file }} becomes the full path; in that case, - # restore part to be just the end leaf - if str(self.answers_relpath) == part: - part = Path(part).name - rendered_parts.append(part) - result = Path(*rendered_parts) - if not is_template: - templated_sibling = ( - self.template.local_abspath - / f"{result}{self.template.templates_suffix}" - ) - if templated_sibling.exists(): - return None - return result + + yield from self._render_parts(relpath.parts, is_template=is_template) def _render_string( self, string: str, extra_context: AnyByStrDict | None = None diff --git a/tests/test_dynamic_file_structures.py b/tests/test_dynamic_file_structures.py new file mode 100644 index 000000000..69c2f3dde --- /dev/null +++ b/tests/test_dynamic_file_structures.py @@ -0,0 +1,165 @@ +import warnings + +import pytest + +import copier +from tests.helpers import build_file_tree + + +def test_folder_loop(tmp_path_factory: pytest.TempPathFactory) -> None: + src, dst = map(tmp_path_factory.mktemp, ("src", "dst")) + build_file_tree( + { + src / "copier.yml": "", + src + / "folder_loop" + / "{% yield item from strings %}{{ item }}{% endyield %}" + / "{{ item }}.txt.jinja": "Hello {{ item }}", + } + ) + with warnings.catch_warnings(): + warnings.simplefilter("error") + copier.run_copy( + str(src), + dst, + data={ + "strings": ["a", "b", "c"], + }, + defaults=True, + overwrite=True, + ) + + expected_files = [dst / f"folder_loop/{i}/{i}.txt" for i in ["a", "b", "c"]] + + for f in expected_files: + assert f.exists() + assert f.read_text() == f"Hello {f.parent.name}" + + all_files = [p for p in dst.rglob("*") if p.is_file()] + unexpected_files = set(all_files) - set(expected_files) + + assert not unexpected_files, f"Unexpected files found: {unexpected_files}" + + +def test_nested_folder_loop(tmp_path_factory: pytest.TempPathFactory) -> None: + src, dst = map(tmp_path_factory.mktemp, ("src", "dst")) + build_file_tree( + { + src / "copier.yml": "", + src + / "nested_folder_loop" + / "{% yield string_item from strings %}{{ string_item }}{% endyield %}" + / "{% yield integer_item from integers %}{{ integer_item }}{% endyield %}" + / "{{ string_item }}_{{ integer_item }}.txt.jinja": "Hello {{ string_item }} {{ integer_item }}", + } + ) + with warnings.catch_warnings(): + warnings.simplefilter("error") + copier.run_copy( + str(src), + dst, + data={ + "strings": ["a", "b"], + "integers": [1, 2, 3], + }, + defaults=True, + overwrite=True, + ) + + expected_files = [ + dst / f"nested_folder_loop/{s}/{i}/{s}_{i}.txt" + for s in ["a", "b"] + for i in [1, 2, 3] + ] + + for f in expected_files: + assert f.exists() + assert f.read_text() == f"Hello {f.parent.parent.name} {f.parent.name}" + + all_files = [p for p in dst.rglob("*") if p.is_file()] + unexpected_files = set(all_files) - set(expected_files) + + assert not unexpected_files, f"Unexpected files found: {unexpected_files}" + + +def test_file_loop(tmp_path_factory: pytest.TempPathFactory) -> None: + src, dst = map(tmp_path_factory.mktemp, ("src", "dst")) + build_file_tree( + { + src / "copier.yml": "", + src + / "file_loop" + / "{% yield string_item from strings %}{{ string_item }}{% endyield %}.jinja": "Hello {{ string_item }}", + } + ) + with warnings.catch_warnings(): + warnings.simplefilter("error") + copier.run_copy( + str(src), + dst, + data={ + "strings": ["a.txt", "b.txt", "c.txt", ""], # if rendred as '.jinja', it will not be created + }, + defaults=True, + overwrite=True, + ) + + expected_files = [dst / f"file_loop/{i}.txt" for i in ["a", "b", "c"]] + for f in expected_files: + assert f.exists() + assert f.read_text() == f"Hello {f.stem}.txt" + + all_files = [p for p in dst.rglob("*") if p.is_file()] + unexpected_files = set(all_files) - set(expected_files) + + assert not unexpected_files, f"Unexpected files found: {unexpected_files}" + + +def test_folder_loop_dict_items(tmp_path_factory: pytest.TempPathFactory) -> None: + src, dst = map(tmp_path_factory.mktemp, ("src", "dst")) + build_file_tree( + { + src / "copier.yml": "", + src + / "folder_loop_dict_items" + / "{% yield dict_item from dicts %}{{ dict_item.folder_name }}{% endyield %}" + / "{{ dict_item.file_name }}.txt.jinja": "Hello {{ '-'.join(dict_item.content) }}", + } + ) + + dicts = [ + {"folder_name": "folder_a", "file_name": "file_a", "content": ['folder_a', 'file_a']}, + {"folder_name": "folder_b", "file_name": "file_b", "content": ['folder_b', 'file_b']}, + {"folder_name": "folder_c", "file_name": "file_c", "content": ['folder_c', 'file_c']}, + ] + + with warnings.catch_warnings(): + warnings.simplefilter("error") + + copier.run_copy( + str(src), + dst, + data={ + "dicts": dicts + }, + defaults=True, + overwrite=True, + ) + + expected_files = [ + dst / f"folder_loop_dict_items/{d['folder_name']}/{d['file_name']}.txt" + for d in [ + {"folder_name": "folder_a", "file_name": "file_a"}, + {"folder_name": "folder_b", "file_name": "file_b"}, + {"folder_name": "folder_c", "file_name": "file_c"}, + ] + ] + + for f in expected_files: + assert f.exists() + assert f.read_text() == f"Hello {'-'.join([f.parts[-2], f.stem])}" + + all_files = [p for p in dst.rglob("*") if p.is_file()] + unexpected_files = set(all_files) - set(expected_files) + + assert not unexpected_files, f"Unexpected files found: {unexpected_files}"