diff --git a/CHANGES.md b/CHANGES.md index 405b71a6c2f..11e491e1c65 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -21,6 +21,7 @@ - Fix crashes involving comments in parenthesised return types or `X | Y` style unions. (#4453) +- Fix skipping Jupyter cells with unknown `%%` magic (#4462) ### Preview style diff --git a/src/black/__init__.py b/src/black/__init__.py index 2640e17c003..a94f7fc29a0 100644 --- a/src/black/__init__.py +++ b/src/black/__init__.py @@ -53,12 +53,12 @@ ) from black.handle_ipynb_magics import ( PYTHON_CELL_MAGICS, - TRANSFORMED_MAGICS, jupyter_dependencies_are_installed, mask_cell, put_trailing_semicolon_back, remove_trailing_semicolon, unmask_cell, + validate_cell, ) from black.linegen import LN, LineGenerator, transform_line from black.lines import EmptyLineTracker, LinesBlock @@ -1084,32 +1084,6 @@ def format_file_contents( return dst_contents -def validate_cell(src: str, mode: Mode) -> None: - """Check that cell does not already contain TransformerManager transformations, - or non-Python cell magics, which might cause tokenizer_rt to break because of - indentations. - - If a cell contains ``!ls``, then it'll be transformed to - ``get_ipython().system('ls')``. However, if the cell originally contained - ``get_ipython().system('ls')``, then it would get transformed in the same way: - - >>> TransformerManager().transform_cell("get_ipython().system('ls')") - "get_ipython().system('ls')\n" - >>> TransformerManager().transform_cell("!ls") - "get_ipython().system('ls')\n" - - Due to the impossibility of safely roundtripping in such situations, cells - containing transformed magics will be ignored. - """ - if any(transformed_magic in src for transformed_magic in TRANSFORMED_MAGICS): - raise NothingChanged - if ( - src[:2] == "%%" - and src.split()[0][2:] not in PYTHON_CELL_MAGICS | mode.python_cell_magics - ): - raise NothingChanged - - def format_cell(src: str, *, fast: bool, mode: Mode) -> str: """Format code in given cell of Jupyter notebook. 
def validate_cell(src: str, mode: Mode) -> None:
    r"""Check that cell does not already contain TransformerManager transformations,
    or non-Python cell magics, which might cause tokenizer_rt to break because of
    indentations.

    If a cell contains ``!ls``, then it'll be transformed to
    ``get_ipython().system('ls')``. However, if the cell originally contained
    ``get_ipython().system('ls')``, then it would get transformed in the same way:

        >>> TransformerManager().transform_cell("get_ipython().system('ls')")
        "get_ipython().system('ls')\n"
        >>> TransformerManager().transform_cell("!ls")
        "get_ipython().system('ls')\n"

    Due to the impossibility of safely roundtripping in such situations, cells
    containing transformed magics will be ignored.

    The cell-magic check looks at the first line that is neither blank nor a
    comment, so a ``%%magic`` preceded by whitespace, empty lines or comment
    lines is still detected.

    Raises:
        NothingChanged: if ``src`` contains any of ``TRANSFORMED_MAGICS``, or if
            its first code line starts with ``%%`` and names a magic that is in
            neither ``PYTHON_CELL_MAGICS`` nor ``mode.python_cell_magics``.
    """
    # NOTE: the docstring is raw on purpose; otherwise the ``\n`` in the example
    # output above would be turned into a real line break.
    if any(transformed_magic in src for transformed_magic in TRANSFORMED_MAGICS):
        raise NothingChanged

    first_code_line = _get_code_start(src)
    if not first_code_line.startswith("%%"):
        return

    # "%%name args # comment" -> "name". The line starts with "%%", so the split
    # always yields at least one token.
    magic_name = first_code_line.split(maxsplit=1)[0][2:]
    if magic_name not in PYTHON_CELL_MAGICS | mode.python_cell_magics:
        raise NothingChanged


def _get_code_start(src: str) -> str:
    """Return the first code line of ``src`` with leading whitespace removed.

    Lines that are empty, contain only whitespace, or whose first
    non-whitespace character is ``#`` are skipped. Only leading whitespace is
    stripped from the returned line; trailing whitespace and any trailing
    comment are kept. Returns an empty string if no such line exists.
    """
    # ``.`` does not match a newline, so every match is one non-empty line.
    for match in re.finditer(r".+", src):
        line = match.group(0).lstrip()
        if line and not line.startswith("#"):
            return line
    return ""