Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add dynamic file structures in loop using yield-tag #1855

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions copier/jinja_ext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
"""Jinja2 extensions built for Copier."""

from __future__ import annotations

from typing import Any, Callable, Sequence

from jinja2 import nodes
from jinja2.exceptions import UndefinedError
from jinja2.ext import Extension
from jinja2.parser import Parser
from jinja2.sandbox import SandboxedEnvironment


class YieldEnvironment(SandboxedEnvironment):
"""Jinja2 environment with a `yield_context` attribute.

This is simple environment class that extends the SandboxedEnvironment
for use with the YieldExtension, mainly for avoiding type errors.

We use the SandboxedEnvironment because we want to minimize the risk of hidden malware
in the templates so we use the SandboxedEnvironment instead of the regular one.
Of course we still have the post-copy tasks to worry about, but at least
they are more visible to the final user.
"""

yield_context: dict[str, Any]

def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self.extend(yield_context=dict())


class YieldExtension(Extension):
"""`Jinja2 extension for the `yield` tag.

If `yield` tag is used in a template, this extension sets the `yield_context` attribute to the
jinja environment. `yield_context` is a dictionary with the following keys:
- `single_var`: The name of the variable that will be yielded.
- `looped_var`: The variable that will be looped over.

Note that this extension just sets the `yield_context` attribute but renders template
as usual. It is caller's responsibility to use the `yield_context` attribute in the
template to generate the desired output.

Example:
template: "{% yield single_var from looped_var %}"
context: {"looped_var": [1, 2, 3], "single_var": "item"}

then,
>>> from copier.jinja_ext import YieldEnvironment, YieldExtension
>>> env = YieldEnvironment(extensions=[YieldExtension])
>>> template = env.from_string("{% yield single_var from looped_var %}{{ single_var }}{% endyield %}")
>>> template.render({"looped_var": [1, 2, 3]})
''
>>> env.yield_context
{'single_var': 'single_var', 'looped_var': [1, 2, 3]}
"""

tags = {"yield"}

environment: YieldEnvironment

def __init__(self, environment: YieldEnvironment):
super().__init__(environment)

def parse(self, parser: Parser) -> nodes.Node:
"""Parse the `yield` tag."""
lineno = next(parser.stream).lineno

single_var: nodes.Name = parser.parse_assign_target(name_only=True)
parser.stream.expect("name:from")
looped_var = parser.parse_expression()
body = parser.parse_statements(("name:endyield",), drop_needle=True)

return nodes.CallBlock(
self.call_method(
"_yield_support",
[looped_var, nodes.Const(single_var.name)],
),
[],
[],
body,
lineno=lineno,
)

def _yield_support(
self, looped_var: Sequence[Any], single_var_name: str, caller: Callable[[], str]
) -> str:
"""Support function for the yield tag.

Sets the yield context in the environment with the given
looped variable and single variable name, then calls the provided caller
function. If an UndefinedError is raised, it returns an empty string.

"""
self.environment.yield_context = {
"single_var": single_var_name,
"looped_var": looped_var,
}

try:
res = caller()

# expression like `dict.attr` will always raise UndefinedError
# so we catch it here and return an empty string
except UndefinedError:
res = ""

return res
152 changes: 111 additions & 41 deletions copier/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from unicodedata import normalize

from jinja2.loaders import FileSystemLoader
from jinja2.sandbox import SandboxedEnvironment
from pathspec import PathSpec
from plumbum import ProcessExecutionError, colors
from plumbum.cli.terminal import ask
Expand All @@ -44,6 +43,7 @@
UnsafeTemplateError,
UserMessageError,
)
from .jinja_ext import YieldEnvironment, YieldExtension
from .subproject import Subproject
from .template import Task, Template
from .tools import (
Expand Down Expand Up @@ -540,7 +540,7 @@ def all_exclusions(self) -> Sequence[str]:
return self.template.exclude + tuple(self.exclude)

@cached_property
def jinja_env(self) -> SandboxedEnvironment:
def jinja_env(self) -> YieldEnvironment:
"""Return a pre-configured Jinja environment.

Respects template settings.
Expand All @@ -549,14 +549,11 @@ def jinja_env(self) -> SandboxedEnvironment:
loader = FileSystemLoader(paths)
default_extensions = [
"jinja2_ansible_filters.AnsibleCoreFiltersExtension",
YieldExtension,
]
extensions = default_extensions + list(self.template.jinja_extensions)
# We want to minimize the risk of hidden malware in the templates
# so we use the SandboxedEnvironment instead of the regular one.
# Of course we still have the post-copy tasks to worry about, but at least
# they are more visible to the final user.
try:
env = SandboxedEnvironment(
env = YieldEnvironment(
loader=loader, extensions=extensions, **self.template.envops
)
except ModuleNotFoundError as error:
Expand Down Expand Up @@ -606,19 +603,25 @@ def _render_template(self) -> None:
for src in scantree(str(self.template_copy_root), follow_symlinks):
src_abspath = Path(src.path)
src_relpath = Path(src_abspath).relative_to(self.template.local_abspath)
dst_relpath = self._render_path(
dst_relpaths_ctxs = self._render_path(
Path(src_abspath).relative_to(self.template_copy_root)
)
if dst_relpath is None or self.match_exclude(dst_relpath):
continue
if src.is_symlink() and self.template.preserve_symlinks:
self._render_symlink(src_relpath, dst_relpath)
elif src.is_dir(follow_symlinks=follow_symlinks):
self._render_folder(dst_relpath)
else:
self._render_file(src_relpath, dst_relpath)
for dst_relpath, ctx in dst_relpaths_ctxs:
if self.match_exclude(dst_relpath):
continue
if src.is_symlink() and self.template.preserve_symlinks:
self._render_symlink(src_relpath, dst_relpath)
elif src.is_dir(follow_symlinks=follow_symlinks):
self._render_folder(dst_relpath)
else:
self._render_file(src_relpath, dst_relpath, extra_context=ctx or {})

def _render_file(self, src_relpath: Path, dst_relpath: Path) -> None:
def _render_file(
self,
src_relpath: Path,
dst_relpath: Path,
extra_context: AnyByStrDict | None = None,
) -> None:
"""Render one file.

Args:
Expand All @@ -628,6 +631,8 @@ def _render_file(self, src_relpath: Path, dst_relpath: Path) -> None:
dst_relpath:
File to be created. It must be a path relative to the subproject
root.
extra_context:
Additional variables to use for rendering the template.
"""
# TODO Get from main.render_file()
assert not src_relpath.is_absolute()
Expand All @@ -643,7 +648,9 @@ def _render_file(self, src_relpath: Path, dst_relpath: Path) -> None:
# suffix is empty, fallback to copy
new_content = src_abspath.read_bytes()
else:
new_content = tpl.render(**self._render_context()).encode()
new_content = tpl.render(
**self._render_context(), **(extra_context or {})
).encode()
else:
new_content = src_abspath.read_bytes()
dst_abspath = self.subproject.local_abspath / dst_relpath
Expand Down Expand Up @@ -715,8 +722,89 @@ def _render_folder(self, dst_relpath: Path) -> None:
dst_abspath = self.subproject.local_abspath / dst_relpath
dst_abspath.mkdir(parents=True, exist_ok=True)

def _render_path(self, relpath: Path) -> Path | None:
"""Render one relative path.
def _adjust_rendered_part(self, rendered_part: str) -> str:
"""Adjust the rendered part if necessary.

If {{ _copier_conf.answers_file }} becomes the full path,
restore part to be just the end leaf.

Args:
rendered_part:
The rendered part of the path to adjust.

"""
if str(self.answers_relpath) == rendered_part:
return Path(rendered_part).name
return rendered_part

def _render_parts(
self,
parts: tuple[str, ...],
rendered_parts: tuple[str, ...] | None = None,
extra_context: AnyByStrDict | None = None,
is_template: bool = False,
) -> Iterable[tuple[Path, AnyByStrDict | None]]:
"""Render a set of parts into path and context pairs.

If a yield tag is found in a part, it will recursively yield multiple path and context pairs.
"""
if rendered_parts is None:
rendered_parts = tuple()

if not parts:
rendered_path = Path(*rendered_parts)

templated_sibling = (
self.template.local_abspath
/ f"{rendered_path}{self.template.templates_suffix}"
)
if is_template or not templated_sibling.exists():
yield rendered_path, extra_context

return

part = parts[0]
parts = parts[1:]

if not extra_context:
extra_context = {}

rendered_part = self._render_string(part, extra_context=extra_context)

yield_context = self.jinja_env.yield_context.copy()
if yield_context:
single_var = yield_context["single_var"]
looped_var = yield_context["looped_var"]

for value in looped_var:
new_context = {**extra_context, **{single_var: value}}
rendered_part = self._render_string(part, extra_context=new_context)
self.jinja_env.yield_context = {}

rendered_part = self._adjust_rendered_part(rendered_part)

# Skip if any part is rendered as an empty string
if not rendered_part:
continue

yield from self._render_parts(
parts, rendered_parts + (rendered_part,), new_context, is_template
)

return

# Skip if any part is rendered as an empty string
if not rendered_part:
return

rendered_part = self._adjust_rendered_part(rendered_part)

yield from self._render_parts(
parts, rendered_parts + (rendered_part,), extra_context, is_template
)

def _render_path(self, relpath: Path) -> Iterable[tuple[Path, AnyByStrDict | None]]:
"""Render one relative path into multiple path and context pairs.

Args:
relpath:
Expand All @@ -728,29 +816,11 @@ def _render_path(self, relpath: Path) -> Path | None:
)
# With an empty suffix, the templated sibling always exists.
if templated_sibling.exists() and self.template.templates_suffix:
return None
return
if self.template.templates_suffix and is_template:
relpath = relpath.with_suffix("")
rendered_parts = []
for part in relpath.parts:
# Skip folder if any part is rendered as an empty string
part = self._render_string(part)
if not part:
return None
# {{ _copier_conf.answers_file }} becomes the full path; in that case,
# restore part to be just the end leaf
if str(self.answers_relpath) == part:
part = Path(part).name
rendered_parts.append(part)
result = Path(*rendered_parts)
if not is_template:
templated_sibling = (
self.template.local_abspath
/ f"{result}{self.template.templates_suffix}"
)
if templated_sibling.exists():
return None
return result

yield from self._render_parts(relpath.parts, is_template=is_template)

def _render_string(
self, string: str, extra_context: AnyByStrDict | None = None
Expand Down
Loading
Loading