diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index bf1b83c2c8..d1cba58a63 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -34,6 +34,7 @@ * Added `snow spcs service execute-job` command, which supports creating and executing a job service in the current schema. * Added `snow app events` command to fetch logs and traces from local and customer app installations. * Added support for external access (api integrations and secrets) in Streamlit. +* Added support for `<% ... %>` syntax in SQL templating. * Support multiple Streamlit application in single snowflake.yml project definition file. ## Fixes and improvements diff --git a/src/snowflake/cli/_plugins/nativeapp/manager.py b/src/snowflake/cli/_plugins/nativeapp/manager.py index 46575240a5..0bc7ccf604 100644 --- a/src/snowflake/cli/_plugins/nativeapp/manager.py +++ b/src/snowflake/cli/_plugins/nativeapp/manager.py @@ -23,7 +23,7 @@ from functools import cached_property from pathlib import Path from textwrap import dedent -from typing import Any, Generator, List, NoReturn, Optional, TypedDict +from typing import Any, Callable, Dict, Generator, List, NoReturn, Optional, TypedDict import jinja2 from click import ClickException @@ -67,7 +67,6 @@ ) from snowflake.cli._plugins.stage.manager import StageManager from snowflake.cli._plugins.stage.utils import print_diff_to_console -from snowflake.cli.api.cli_global_context import get_cli_context from snowflake.cli.api.console import cli_console as cc from snowflake.cli.api.errno import ( DOES_NOT_EXIST_OR_CANNOT_BE_PERFORMED, @@ -84,9 +83,13 @@ identifier_for_url, unquote_identifier, ) +from snowflake.cli.api.rendering.jinja import ( + jinja_render_from_str, +) from snowflake.cli.api.rendering.sql_templates import ( - get_sql_cli_jinja_env, + snowflake_sql_jinja_render, ) +from snowflake.cli.api.secure_path import UNLIMITED, SecurePath from snowflake.cli.api.sql_execution import SqlExecutionMixin from snowflake.connector import DictCursor, ProgrammingError @@ -576,30 +579,36 @@ def create_app_package(self) -> None: ) ) - def _expand_script_templates( - self, env: jinja2.Environment, jinja_context: dict[str, Any], scripts: List[str] + def _render_script_templates( + self, + render_from_str: Callable[[str, Dict[str, Any]], str], + jinja_context: dict[str, Any], + scripts: List[str], ) -> List[str]: """ Input: - - env: Jinja2 environment + - render_from_str: function which renders a jinja template from a string and jinja context - jinja_context: a dictionary with the jinja context - - scripts: list of scripts that need to be expanded with Jinja + - scripts: list of script paths relative to the project root Returns: - - List of expanded scripts content. + - List of rendered scripts content Size of the return list is the same as the size of the input scripts list. """ scripts_contents = [] for relpath in scripts: + script_full_path = SecurePath(self.project_root) / relpath try: - template = env.get_template(relpath) - result = template.render(**jinja_context) + template_content = script_full_path.read_text( + file_size_limit_mb=UNLIMITED + ) + result = render_from_str(template_content, jinja_context) scripts_contents.append(result) - except jinja2.TemplateNotFound as e: - raise MissingScriptError(e.name) from e + except FileNotFoundError as e: + raise MissingScriptError(relpath) from e except jinja2.TemplateSyntaxError as e: - raise InvalidScriptError(e.name, e, e.lineno) from e + raise InvalidScriptError(relpath, e, e.lineno) from e except jinja2.UndefinedError as e: raise InvalidScriptError(relpath, e) from e @@ -617,14 +626,10 @@ def _apply_package_scripts(self) -> None: "WARNING: native_app.package.scripts is deprecated. Please migrate to using native_app.package.post_deploy." ) - env = jinja2.Environment( - loader=jinja2.loaders.FileSystemLoader(self.project_root), - keep_trailing_newline=True, - undefined=jinja2.StrictUndefined, - ) - - queued_queries = self._expand_script_templates( - env, dict(package_name=self.package_name), self.package_scripts + queued_queries = self._render_script_templates( + jinja_render_from_str, + dict(package_name=self.package_name), + self.package_scripts, ) # once we're sure all the templates expanded correctly, execute all of them @@ -678,11 +683,10 @@ def _execute_post_deploy_hooks( f"Unsupported {deployed_object_type} post-deploy hook type: {hook}" ) - env = get_sql_cli_jinja_env( - loader=jinja2.loaders.FileSystemLoader(self.project_root) - ) - scripts_content_list = self._expand_script_templates( - env, get_cli_context().template_context, sql_scripts_paths + scripts_content_list = self._render_script_templates( + snowflake_sql_jinja_render, + {}, + sql_scripts_paths, ) for index, sql_script_path in enumerate(sql_scripts_paths): diff --git a/src/snowflake/cli/api/rendering/jinja.py b/src/snowflake/cli/api/rendering/jinja.py index 299cb8ac8c..e65bf6ceac 100644 --- a/src/snowflake/cli/api/rendering/jinja.py +++ b/src/snowflake/cli/api/rendering/jinja.py @@ -17,7 +17,7 @@ from pathlib import Path from textwrap import dedent -from typing import Dict, Optional +from typing import Any, Dict, Optional import jinja2 from jinja2 import Environment, StrictUndefined, loaders @@ -82,8 +82,32 @@ def getitem(self, obj, argument): return self.undefined(obj=obj, name=argument) +def _get_jinja_env(loader: Optional[loaders.BaseLoader] = None) -> Environment: + return env_bootstrap( + IgnoreAttrEnvironment( + loader=loader or loaders.BaseLoader(), + keep_trailing_newline=True, + undefined=StrictUndefined, + ) + ) + + +def jinja_render_from_str(template_content: str, data: Dict[str, Any]) -> str: + """ + Renders a jinja template and outputs either the rendered contents as string or writes to a file. + + Args: + template_content (str): template contents + data (dict): A dictionary of jinja variables and their actual values + + Returns: + None if file path is provided, else returns the rendered string. + """ + return _get_jinja_env().from_string(template_content).render(data) + + def jinja_render_from_file( - template_path: Path, data: Dict, output_file_path: Optional[Path] = None + template_path: Path, data: Dict[str, Any], output_file_path: Optional[Path] = None ) -> Optional[str]: """ Renders a jinja template and outputs either the rendered contents as string or writes to a file. @@ -96,12 +120,8 @@ def jinja_render_from_file( Returns: None if file path is provided, else returns the rendered string. """ - env = env_bootstrap( - IgnoreAttrEnvironment( - loader=loaders.FileSystemLoader(template_path.parent), - keep_trailing_newline=True, - undefined=StrictUndefined, - ) + env = _get_jinja_env( + loader=loaders.FileSystemLoader(template_path.parent.as_posix()) ) loaded_template = env.get_template(template_path.name) rendered_result = loaded_template.render(**data) diff --git a/src/snowflake/cli/api/rendering/sql_templates.py b/src/snowflake/cli/api/rendering/sql_templates.py index b2eea68e75..f832417670 100644 --- a/src/snowflake/cli/api/rendering/sql_templates.py +++ b/src/snowflake/cli/api/rendering/sql_templates.py @@ -14,11 +14,13 @@ from __future__ import annotations -from typing import Dict, Optional +from typing import Dict from click import ClickException -from jinja2 import StrictUndefined, loaders +from jinja2 import Environment, StrictUndefined, loaders, meta from snowflake.cli.api.cli_global_context import get_cli_context +from snowflake.cli.api.console.console import cli_console +from snowflake.cli.api.exceptions import InvalidTemplate from snowflake.cli.api.rendering.jinja import ( CONTEXT_KEY, FUNCTION_KEY, @@ -26,26 +28,52 @@ env_bootstrap, ) -_SQL_TEMPLATE_START = "&{" -_SQL_TEMPLATE_END = "}" +_SQL_TEMPLATE_START = "<%" +_SQL_TEMPLATE_END = "%>" +_OLD_SQL_TEMPLATE_START = "&{" +_OLD_SQL_TEMPLATE_END = "}" RESERVED_KEYS = [CONTEXT_KEY, FUNCTION_KEY] -def get_sql_cli_jinja_env(*, loader: Optional[loaders.BaseLoader] = None): +def _get_sql_jinja_env(template_start: str, template_end: str) -> Environment: _random_block = "___very___unique___block___to___disable___logic___blocks___" return env_bootstrap( IgnoreAttrEnvironment( - loader=loader or loaders.BaseLoader(), - keep_trailing_newline=True, - variable_start_string=_SQL_TEMPLATE_START, - variable_end_string=_SQL_TEMPLATE_END, + variable_start_string=template_start, + variable_end_string=template_end, + loader=loaders.BaseLoader(), block_start_string=_random_block, block_end_string=_random_block, + keep_trailing_newline=True, undefined=StrictUndefined, ) ) +def _does_template_have_env_syntax(env: Environment, template_content: str) -> bool: + template = env.parse(template_content) + return bool(meta.find_undeclared_variables(template)) + + +def choose_sql_jinja_env_based_on_template_syntax(template_content: str) -> Environment: + old_syntax_env = _get_sql_jinja_env(_OLD_SQL_TEMPLATE_START, _OLD_SQL_TEMPLATE_END) + new_syntax_env = _get_sql_jinja_env(_SQL_TEMPLATE_START, _SQL_TEMPLATE_END) + has_old_syntax = _does_template_have_env_syntax(old_syntax_env, template_content) + has_new_syntax = _does_template_have_env_syntax(new_syntax_env, template_content) + if has_old_syntax and has_new_syntax: + raise InvalidTemplate( + f"The SQL query mixes {_OLD_SQL_TEMPLATE_START} ... {_OLD_SQL_TEMPLATE_END} syntax" + f" and {_SQL_TEMPLATE_START} ... {_SQL_TEMPLATE_END} syntax." + ) + if has_old_syntax: + cli_console.warning( + f"Warning: {_OLD_SQL_TEMPLATE_START} ... {_OLD_SQL_TEMPLATE_END} syntax is deprecated." + f" Use {_SQL_TEMPLATE_START} ... {_SQL_TEMPLATE_END} syntax instead." + ) + return old_syntax_env + return new_syntax_env + + def snowflake_sql_jinja_render(content: str, data: Dict | None = None) -> str: data = data or {} @@ -57,4 +85,5 @@ def snowflake_sql_jinja_render(content: str, data: Dict | None = None) -> str: context_data = get_cli_context().template_context context_data.update(data) - return get_sql_cli_jinja_env().from_string(content).render(**context_data) + env = choose_sql_jinja_env_based_on_template_syntax(content) + return env.from_string(content).render(context_data) diff --git a/tests/test_sql.py b/tests/test_sql.py index a1ba6cd26b..12169b6115 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from pathlib import Path from tempfile import NamedTemporaryFile, TemporaryDirectory from unittest import mock @@ -321,6 +320,7 @@ def test_use_command(mock_execute_query, _object): "select &{ aaa }.&{ bbb }", "select &aaa.&bbb", "select &aaa.&{ bbb }", + "select <% aaa %>.<% bbb %>", ], ) @mock.patch("snowflake.cli._plugins.sql.commands.SqlManager._execute_string") @@ -332,7 +332,29 @@ def test_rendering_of_sql(mock_execute_query, query, runner): ) -@pytest.mark.parametrize("query", ["select &{ foo }", "select &foo"]) +@mock.patch("snowflake.cli._plugins.sql.commands.SqlManager._execute_string") +def test_old_template_syntax_causes_warning(mock_execute_query, runner): + result = runner.invoke(["sql", "-q", "select &{ aaa }", "-D", "aaa=foo"]) + assert result.exit_code == 0 + assert ( + "Warning: &{ ... } syntax is deprecated. Use <% ... %> syntax instead." + in result.output + ) + mock_execute_query.assert_called_once_with("select foo", cursor_class=VerboseCursor) + + +@mock.patch("snowflake.cli._plugins.sql.commands.SqlManager._execute_string") +def test_mixed_template_syntax_error(mock_execute_query, runner): + result = runner.invoke( + ["sql", "-q", "select <% aaa %>.&{ bbb }", "-D", "aaa=foo", "-D", "bbb=bar"] + ) + assert result.exit_code == 1 + assert "The SQL query mixes &{ ... } syntax and <% ... %> syntax." in result.output + + +@pytest.mark.parametrize( + "query", ["select &{ foo }", "select &foo", "select <% foo %>"] +) def test_execution_fails_if_unknown_variable(runner, query): result = runner.invoke(["sql", "-q", query, "-D", "bbb=1"]) assert "SQL template rendering error: 'foo' is undefined" in result.output @@ -356,12 +378,15 @@ def test_snowsql_compatibility(text, expected): assert transpile_snowsql_templates(text) == expected +@pytest.mark.parametrize("template_start,template_end", [("&{", "}"), ("<%", "%>")]) @mock.patch("snowflake.cli._plugins.sql.commands.SqlManager._execute_string") def test_uses_variables_from_snowflake_yml( - mock_execute_query, project_directory, runner + mock_execute_query, project_directory, runner, template_start, template_end ): with project_directory("sql_templating"): - result = runner.invoke(["sql", "-q", "select &{ ctx.env.sf_var }"]) + result = runner.invoke( + ["sql", "-q", f"select {template_start} ctx.env.sf_var {template_end}"] + ) assert result.exit_code == 0 mock_execute_query.assert_called_once_with( @@ -369,12 +394,19 @@ def test_uses_variables_from_snowflake_yml( ) +@pytest.mark.parametrize("template_start,template_end", [("&{", "}"), ("<%", "%>")]) @mock.patch("snowflake.cli._plugins.sql.commands.SqlManager._execute_string") def test_uses_variables_from_snowflake_local_yml( - mock_execute_query, project_directory, runner + mock_execute_query, project_directory, runner, template_start, template_end ): with project_directory("sql_templating"): - result = runner.invoke(["sql", "-q", "select &{ ctx.env.sf_var_override }"]) + result = runner.invoke( + [ + "sql", + "-q", + f"select {template_start} ctx.env.sf_var_override {template_end}", + ] + ) assert result.exit_code == 0 mock_execute_query.assert_called_once_with( @@ -382,16 +414,17 @@ def test_uses_variables_from_snowflake_local_yml( ) +@pytest.mark.parametrize("template_start,template_end", [("&{", "}"), ("<%", "%>")]) @mock.patch("snowflake.cli._plugins.sql.commands.SqlManager._execute_string") def test_uses_variables_from_cli_are_added_outside_context( - mock_execute_query, project_directory, runner + mock_execute_query, project_directory, runner, template_start, template_end ): with project_directory("sql_templating"): result = runner.invoke( [ "sql", "-q", - "select &{ ctx.env.sf_var } &{ other }", + f"select {template_start} ctx.env.sf_var {template_end} {template_start} other {template_end}", "-D", "other=other_value", ] diff --git a/tests_integration/test_sql_templating.py b/tests_integration/test_sql_templating.py index 7d69acdf1f..2d0de60a5d 100644 --- a/tests_integration/test_sql_templating.py +++ b/tests_integration/test_sql_templating.py @@ -18,7 +18,7 @@ @pytest.mark.integration def test_sql_env_value_from_cli_param(runner, snowflake_session): result = runner.invoke_with_connection_json( - ["sql", "-q", "select '&{ctx.env.test}'", "--env", "test=value_from_cli"] + ["sql", "-q", "select '<% ctx.env.test %>'", "--env", "test=value_from_cli"] ) assert result.exit_code == 0 @@ -28,7 +28,7 @@ def test_sql_env_value_from_cli_param(runner, snowflake_session): @pytest.mark.integration def test_sql_env_value_from_cli_param_that_is_blank(runner, snowflake_session): result = runner.invoke_with_connection_json( - ["sql", "-q", "select '&{ctx.env.test}'", "--env", "test="] + ["sql", "-q", "select '<% ctx.env.test %>'", "--env", "test="] ) assert result.exit_code == 0 @@ -38,7 +38,7 @@ def test_sql_env_value_from_cli_param_that_is_blank(runner, snowflake_session): @pytest.mark.integration def test_sql_undefined_env_causing_error(runner, snowflake_session): result = runner.invoke_with_connection_json( - ["sql", "-q", "select '&{ctx.env.test}'"] + ["sql", "-q", "select '<% ctx.env.test %>'"] ) assert result.exit_code == 1 @@ -48,7 +48,7 @@ def test_sql_undefined_env_causing_error(runner, snowflake_session): @pytest.mark.integration def test_sql_env_value_from_os_env(runner, snowflake_session): result = runner.invoke_with_connection_json( - ["sql", "-q", "select '&{ctx.env.test}'"], env={"test": "value_from_os_env"} + ["sql", "-q", "select '<% ctx.env.test %>'"], env={"test": "value_from_os_env"} ) assert result.exit_code == 0 @@ -58,7 +58,7 @@ def test_sql_env_value_from_os_env(runner, snowflake_session): @pytest.mark.integration def test_sql_env_value_from_cli_param_overriding_os_env(runner, snowflake_session): result = runner.invoke_with_connection_json( - ["sql", "-q", "select '&{ctx.env.test}'", "--env", "test=value_from_cli"], + ["sql", "-q", "select '<% ctx.env.test %>'", "--env", "test=value_from_cli"], env={"test": "value_from_os_env"}, ) @@ -72,7 +72,7 @@ def test_sql_env_value_from_cli_duplicate_arg(runner, snowflake_session): [ "sql", "-q", - "select '&{ctx.env.Test}'", + "select '<% ctx.env.Test %>'", "--env", "Test=firstArg", "--env", @@ -84,13 +84,16 @@ def test_sql_env_value_from_cli_duplicate_arg(runner, snowflake_session): assert result.json == [{"'SECONDARG'": "secondArg"}] +@pytest.mark.parametrize("t_start,t_end", [("&{", "}"), ("<%", "%>")]) @pytest.mark.integration -def test_sql_env_value_from_cli_multiple_args(runner, snowflake_session): +def test_sql_env_value_from_cli_multiple_args( + runner, snowflake_session, t_start, t_end +): result = runner.invoke_with_connection_json( [ "sql", "-q", - "select '&{ctx.env.Test1}-&{ctx.env.Test2}'", + f"select '{t_start}ctx.env.Test1{t_end}-{t_start}ctx.env.Test2{t_end}'", "--env", "Test1=test1", "--env",