From 606b2ea133f2d0a787c2c2a756b6fac9a6060ec3 Mon Sep 17 00:00:00 2001 From: josix Date: Fri, 25 Oct 2024 23:57:18 +0800 Subject: [PATCH 1/2] refactor(utils/decorators): rewrite remove task decorator to use ast --- airflow/utils/decorators.py | 67 +++++++++++++++++-------------------- 1 file changed, 31 insertions(+), 36 deletions(-) diff --git a/airflow/utils/decorators.py b/airflow/utils/decorators.py index e299999423e5..e449d238f73c 100644 --- a/airflow/utils/decorators.py +++ b/airflow/utils/decorators.py @@ -17,55 +17,50 @@ # under the License. from __future__ import annotations +import ast import sys -from collections import deque from typing import Callable, TypeVar T = TypeVar("T", bound=Callable) +class _TaskDecoratorRemover(ast.NodeTransformer): + def __init__(self, task_decorator_name): + self.decorators_to_remove = { + "setup", + "teardown", + "task.skip_if", + "task.run_if", + task_decorator_name, + } + + def visit_FunctionDef(self, node): + node.decorator_list = [ + decorator for decorator in node.decorator_list if not self._is_task_decorator(decorator) + ] + return self.generic_visit(node) + + def _is_task_decorator(self, decorator): + if isinstance(decorator, ast.Name): + return decorator.id in self.decorators_to_remove + elif isinstance(decorator, ast.Attribute): + return f"{decorator.value.id}.{decorator.attr}" in self.decorators_to_remove + elif isinstance(decorator, ast.Call): + return self._is_task_decorator(decorator.func) + return False + + def remove_task_decorator(python_source: str, task_decorator_name: str) -> str: """ Remove @task or similar decorators as well as @setup and @teardown. :param python_source: The python source code :param task_decorator_name: the decorator name - - TODO: Python 3.9+: Rewrite this to use ast.parse and ast.unparse """ - - def _remove_task_decorator(py_source, decorator_name): - # if no line starts with @decorator_name, we can early exit - for line in py_source.split("\n"): - if line.startswith(decorator_name): - break - else: - return python_source - split = python_source.split(decorator_name, 1) - before_decorator, after_decorator = split[0], split[1] - if after_decorator[0] == "(": - after_decorator = _balance_parens(after_decorator) - if after_decorator[0] == "\n": - after_decorator = after_decorator[1:] - return before_decorator + after_decorator - - decorators = ["@setup", "@teardown", "@task.skip_if", "@task.run_if", task_decorator_name] - for decorator in decorators: - python_source = _remove_task_decorator(python_source, decorator) - return python_source - - -def _balance_parens(after_decorator): - num_paren = 1 - after_decorator = deque(after_decorator) - after_decorator.popleft() - while num_paren: - current = after_decorator.popleft() - if current == "(": - num_paren = num_paren + 1 - elif current == ")": - num_paren = num_paren - 1 - return "".join(after_decorator) + tree = ast.parse(python_source) + remover = _TaskDecoratorRemover(task_decorator_name.strip("@")) + mutated_tree = remover.visit(tree) + return ast.unparse(mutated_tree) class _autostacklevel_warn: From 28d8f2cba9500af5a9677af3f28953738bf1bdc3 Mon Sep 17 00:00:00 2001 From: josix Date: Sat, 26 Oct 2024 04:54:09 +0800 Subject: [PATCH 2/2] fixup! refactor(utils/decorators): rewrite remove task decorator to use ast --- tests/utils/test_decorators.py | 14 ++++++------- ...preexisting_python_virtualenv_decorator.py | 16 +++++++------- tests/utils/test_python_virtualenv.py | 21 +++++++------------ 3 files changed, 23 insertions(+), 28 deletions(-) diff --git a/tests/utils/test_decorators.py b/tests/utils/test_decorators.py index 19d3ec31d031..a0a8ea263427 100644 --- a/tests/utils/test_decorators.py +++ b/tests/utils/test_decorators.py @@ -49,7 +49,7 @@ def test_task_decorator_using_source(decorator: TaskDecorator): def f(): return ["some_task"] - assert parse_python_source(f, "decorator") == 'def f():\n return ["some_task"]\n' + assert parse_python_source(f, "decorator") == "def f():\n return ['some_task']" @pytest.mark.parametrize("decorator", DECORATORS, indirect=["decorator"]) @@ -59,7 +59,7 @@ def test_skip_if(decorator: TaskDecorator): def f(): return "hello world" - assert parse_python_source(f, "decorator") == 'def f():\n return "hello world"\n' + assert parse_python_source(f, "decorator") == "def f():\n return 'hello world'" @pytest.mark.parametrize("decorator", DECORATORS, indirect=["decorator"]) @@ -69,7 +69,7 @@ def test_run_if(decorator: TaskDecorator): def f(): return "hello world" - assert parse_python_source(f, "decorator") == 'def f():\n return "hello world"\n' + assert parse_python_source(f, "decorator") == "def f():\n return 'hello world'" def test_skip_if_and_run_if(): @@ -79,7 +79,7 @@ def test_skip_if_and_run_if(): def f(): return "hello world" - assert parse_python_source(f) == 'def f():\n return "hello world"\n' + assert parse_python_source(f) == "def f():\n return 'hello world'" def test_run_if_and_skip_if(): @@ -89,7 +89,7 @@ def test_run_if_and_skip_if(): def f(): return "hello world" - assert parse_python_source(f) == 'def f():\n return "hello world"\n' + assert parse_python_source(f) == "def f():\n return 'hello world'" def test_skip_if_allow_decorator(): @@ -102,7 +102,7 @@ def non_task_decorator(func): def f(): return "hello world" - assert parse_python_source(f) == '@non_task_decorator\ndef f():\n return "hello world"\n' + assert parse_python_source(f) == "@non_task_decorator\ndef f():\n return 'hello world'" def test_run_if_allow_decorator(): @@ -115,7 +115,7 @@ def non_task_decorator(func): def f(): return "hello world" - assert parse_python_source(f) == '@non_task_decorator\ndef f():\n return "hello world"\n' + assert parse_python_source(f) == "@non_task_decorator\ndef f():\n return 'hello world'" def parse_python_source(task: Task, custom_operator_name: str | None = None) -> str: diff --git a/tests/utils/test_preexisting_python_virtualenv_decorator.py b/tests/utils/test_preexisting_python_virtualenv_decorator.py index 2e97469958fd..1a8aa0fa8229 100644 --- a/tests/utils/test_preexisting_python_virtualenv_decorator.py +++ b/tests/utils/test_preexisting_python_virtualenv_decorator.py @@ -22,20 +22,20 @@ class TestExternalPythonDecorator: def test_remove_task_decorator(self): - py_source = "@task.external_python(use_dill=True)\ndef f():\nimport funcsigs" + py_source = "@task.external_python(use_dill=True)\ndef f(): ...\nimport funcsigs" res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.external_python") - assert res == "def f():\nimport funcsigs" + assert res == "def f():\n ...\nimport funcsigs" def test_remove_decorator_no_parens(self): - py_source = "@task.external_python\ndef f():\nimport funcsigs" + py_source = "@task.external_python\ndef f(): ...\nimport funcsigs" res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.external_python") - assert res == "def f():\nimport funcsigs" + assert res == "def f():\n ...\nimport funcsigs" def test_remove_decorator_nested(self): - py_source = "@foo\n@task.external_python\n@bar\ndef f():\nimport funcsigs" + py_source = "@foo\n@task.external_python\n@bar\ndef f(): ...\nimport funcsigs" res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.external_python") - assert res == "@foo\n@bar\ndef f():\nimport funcsigs" + assert res == "@foo\n@bar\ndef f():\n ...\nimport funcsigs" - py_source = "@foo\n@task.external_python()\n@bar\ndef f():\nimport funcsigs" + py_source = "@foo\n@task.external_python()\n@bar\ndef f(): ...\nimport funcsigs" res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.external_python") - assert res == "@foo\n@bar\ndef f():\nimport funcsigs" + assert res == "@foo\n@bar\ndef f():\n ...\nimport funcsigs" diff --git a/tests/utils/test_python_virtualenv.py b/tests/utils/test_python_virtualenv.py index 38cda4854baf..4d683b7ddd1f 100644 --- a/tests/utils/test_python_virtualenv.py +++ b/tests/utils/test_python_virtualenv.py @@ -116,25 +116,20 @@ def test_should_create_virtualenv_with_extra_packages(self, mock_execute_in_subp mock_execute_in_subprocess.assert_called_with(["/VENV/bin/pip", "install", "apache-beam[gcp]"]) def test_remove_task_decorator(self): - py_source = "@task.virtualenv(use_dill=True)\ndef f():\nimport funcsigs" + py_source = "@task.virtualenv(use_dill=True)\ndef f(): ...\nimport funcsigs" res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv") - assert res == "def f():\nimport funcsigs" + assert res == "def f():\n ...\nimport funcsigs" def test_remove_decorator_no_parens(self): - py_source = "@task.virtualenv\ndef f():\nimport funcsigs" + py_source = "@task.virtualenv\ndef f(): ...\nimport funcsigs" res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv") - assert res == "def f():\nimport funcsigs" - - def test_remove_decorator_including_comment(self): - py_source = "@task.virtualenv\ndef f():\n# @task.virtualenv\nimport funcsigs" - res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv") - assert res == "def f():\n# @task.virtualenv\nimport funcsigs" + assert res == "def f():\n ...\nimport funcsigs" def test_remove_decorator_nested(self): - py_source = "@foo\n@task.virtualenv\n@bar\ndef f():\nimport funcsigs" + py_source = "@foo\n@task.virtualenv\n@bar\ndef f(): ...\nimport funcsigs" res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv") - assert res == "@foo\n@bar\ndef f():\nimport funcsigs" + assert res == "@foo\n@bar\ndef f():\n ...\nimport funcsigs" - py_source = "@foo\n@task.virtualenv()\n@bar\ndef f():\nimport funcsigs" + py_source = "@foo\n@task.virtualenv()\n@bar\ndef f(): ...\nimport funcsigs" res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv") - assert res == "@foo\n@bar\ndef f():\nimport funcsigs" + assert res == "@foo\n@bar\ndef f():\n ...\nimport funcsigs"