From 56187a6461ca9475f2954dacf9e7b09b7d00a26d Mon Sep 17 00:00:00 2001 From: JasonGrace2282 Date: Mon, 22 Jul 2024 07:03:38 -0400 Subject: [PATCH 1/6] Improve matmul for PolarPlane --- manim/mobject/graphing/coordinate_systems.py | 18 ++++++++++++++--- .../graphing/test_coordinate_system.py | 20 +++++++++++++++++++ .../test_coordinate_systems.py | 16 ++++++++++++++- 3 files changed, 50 insertions(+), 4 deletions(-) diff --git a/manim/mobject/graphing/coordinate_systems.py b/manim/mobject/graphing/coordinate_systems.py index fa07c7fd53..f7f3456f7d 100644 --- a/manim/mobject/graphing/coordinate_systems.py +++ b/manim/mobject/graphing/coordinate_systems.py @@ -56,6 +56,8 @@ from manim.utils.space_ops import angle_of_vector if TYPE_CHECKING: + import numpy.typing as npt + from manim.mobject.mobject import Mobject from manim.typing import ManimFloat, Point2D, Point3D, Vector3D @@ -1836,13 +1838,20 @@ def construct(self): return T_label_group - def __matmul__(self, coord: Point3D | Mobject): + _matmul_method = "coords_to_point" + _rmatmul_method = "point_to_coords" + + def __matmul__(self, coord: Sequence[float] | Mobject | npt.NDArray[np.float64]): if isinstance(coord, Mobject): coord = coord.get_center() - return self.coords_to_point(*coord) + method = getattr(self, self._matmul_method) + assert callable(method) + return method(*coord) def __rmatmul__(self, point: Point3D): - return self.point_to_coords(point) + method = getattr(self, self._rmatmul_method) + assert callable(method) + return method(point) class Axes(VGroup, CoordinateSystem, metaclass=ConvertToOpenGL): @@ -2984,6 +2993,9 @@ def construct(self): self.add(polarplane_pi) """ + _matmul_method = "polar_to_point" + _rmatmul_method = "point_to_polar" + def __init__( self, radius_max: float = config["frame_y_radius"], diff --git a/tests/module/mobject/graphing/test_coordinate_system.py b/tests/module/mobject/graphing/test_coordinate_system.py index 470d9d0074..6202cf1b11 100644 --- a/tests/module/mobject/graphing/test_coordinate_system.py +++ b/tests/module/mobject/graphing/test_coordinate_system.py @@ -14,6 +14,7 @@ Circle, ComplexPlane, Dot, + NumberLine, NumberPlane, PolarPlane, ThreeDAxes, @@ -192,3 +193,22 @@ def test_input_to_graph_point(): # test the line_graph implementation position = np.around(ax.input_to_graph_point(x=PI, graph=line_graph), decimals=4) np.testing.assert_array_equal(position, (2.6928, 1.2876, 0)) + + +def test_matmul_operations(): + ax = Axes() + assert (ax @ (1, 2) == ax.coords_to_point(1, 2)).all() + # should work with mobjects too, using their center + mob = Dot().move_to((1, 2, 0)) + assert (ax @ mob == ax.coords_to_point(1, 2)).all() + + # other coordinate systems like PolarPlane should override __matmul__ indirectly + polar = PolarPlane() + # radius, azimuthal angle + assert (polar @ (1, 2) == polar.polar_to_point(1, 2)).all() + + # Numberline doesn't inherit from CoordinateSystem, but it should still work + n = NumberLine() + assert (n @ 3 == n.number_to_point(3)).all() + mob = Dot().move_to(n @ 3) + assert mob @ n == n.point_to_number(mob.get_center()) diff --git a/tests/test_graphical_units/test_coordinate_systems.py b/tests/test_graphical_units/test_coordinate_systems.py index 7d6dad67af..2d9b5cd947 100644 --- a/tests/test_graphical_units/test_coordinate_systems.py +++ b/tests/test_graphical_units/test_coordinate_systems.py @@ -1,6 +1,20 @@ from __future__ import annotations -from manim import * +from manim import ( + BLUE, + GREEN, + ORANGE, + RED, + UL, + YELLOW, + Axes, + LogBase, + NumberPlane, + ThreeDAxes, + ThreeDScene, + VGroup, + np, +) from manim.utils.testing.frames_comparison import frames_comparison __module_test__ = "coordinate_system" From 827d9f5d5265c6f6012794c1f2993656fa6f0b2b Mon Sep 17 00:00:00 2001 From: JasonGrace2282 Date: Mon, 22 Jul 2024 07:19:23 -0400 Subject: [PATCH 2/6] Fixes for ComplexPlane --- manim/mobject/graphing/coordinate_systems.py | 53 ++++++++++++++----- .../graphing/test_coordinate_system.py | 6 ++- 2 files changed, 44 insertions(+), 15 deletions(-) diff --git a/manim/mobject/graphing/coordinate_systems.py b/manim/mobject/graphing/coordinate_systems.py index f7f3456f7d..9dc3e19ddf 100644 --- a/manim/mobject/graphing/coordinate_systems.py +++ b/manim/mobject/graphing/coordinate_systems.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Any, Callable, TypeVar, overload import numpy as np -from typing_extensions import Self +from typing_extensions import Self, TypedDict from manim import config from manim.constants import * @@ -56,14 +56,25 @@ from manim.utils.space_ops import angle_of_vector if TYPE_CHECKING: - import numpy.typing as npt - from manim.mobject.mobject import Mobject from manim.typing import ManimFloat, Point2D, Point3D, Vector3D LineType = TypeVar("LineType", bound=Line) +class _MatmulConfig(TypedDict): + """A dictionary for configuring the __matmul__/__rmatmul__ operation. + + Parameters + ---------- + method: The method to call + unpack: whether to unpack the parameter given to __matmul__/__rmatmul__ + """ + + method: str + unpack: bool + + class CoordinateSystem: r"""Abstract base class for Axes and NumberPlane. @@ -1838,20 +1849,29 @@ def construct(self): return T_label_group - _matmul_method = "coords_to_point" - _rmatmul_method = "point_to_coords" + _matmul_config: _MatmulConfig = { + "method": "coords_to_point", + "unpack": True, + } + _rmatmul_config: _MatmulConfig = {"method": "point_to_coords", "unpack": False} - def __matmul__(self, coord: Sequence[float] | Mobject | npt.NDArray[np.float64]): + def __matmul__(self, coord): if isinstance(coord, Mobject): coord = coord.get_center() - method = getattr(self, self._matmul_method) + method = getattr(self, self._matmul_config["method"]) assert callable(method) - return method(*coord) + return ( + method(*coord) if self._matmul_config.get("unpack", True) else method(coord) + ) - def __rmatmul__(self, point: Point3D): - method = getattr(self, self._rmatmul_method) + def __rmatmul__(self, point): + method = getattr(self, self._rmatmul_config["method"]) assert callable(method) - return method(point) + return ( + method(*point) + if self._rmatmul_config.get("unpack", False) + else method(point) + ) class Axes(VGroup, CoordinateSystem, metaclass=ConvertToOpenGL): @@ -2993,8 +3013,11 @@ def construct(self): self.add(polarplane_pi) """ - _matmul_method = "polar_to_point" - _rmatmul_method = "point_to_polar" + _matmul_config = { + "method": "polar_to_point", + "unpack": True, + } + _rmatmul_config = {"method": "point_to_polar", "unpack": False} def __init__( self, @@ -3366,6 +3389,10 @@ def construct(self): """ + _matmul_config = {"method": "number_to_point", "unpack": False} + + _rmatmul_config = {"method": "point_to_number", "unpack": False} + def __init__(self, **kwargs: Any) -> None: super().__init__( **kwargs, diff --git a/tests/module/mobject/graphing/test_coordinate_system.py b/tests/module/mobject/graphing/test_coordinate_system.py index 6202cf1b11..112f3f15ba 100644 --- a/tests/module/mobject/graphing/test_coordinate_system.py +++ b/tests/module/mobject/graphing/test_coordinate_system.py @@ -202,11 +202,13 @@ def test_matmul_operations(): mob = Dot().move_to((1, 2, 0)) assert (ax @ mob == ax.coords_to_point(1, 2)).all() - # other coordinate systems like PolarPlane should override __matmul__ indirectly + # other coordinate systems like PolarPlane and ComplexPlane should override __matmul__ indirectly polar = PolarPlane() - # radius, azimuthal angle assert (polar @ (1, 2) == polar.polar_to_point(1, 2)).all() + complx = ComplexPlane() + assert (complx @ (1 + 2j) == complx.number_to_point(1 + 2j)).all() + # Numberline doesn't inherit from CoordinateSystem, but it should still work n = NumberLine() assert (n @ 3 == n.number_to_point(3)).all() From f952d523cd76f23029cb0697625c86c4599fcb3d Mon Sep 17 00:00:00 2001 From: JasonGrace2282 Date: Mon, 22 Jul 2024 09:42:50 -0400 Subject: [PATCH 3/6] Add test for __rmatmul__ and add warning --- manim/mobject/graphing/coordinate_systems.py | 7 +++++ .../graphing/test_coordinate_system.py | 26 +++++++++++++++++-- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/manim/mobject/graphing/coordinate_systems.py b/manim/mobject/graphing/coordinate_systems.py index 9dc3e19ddf..fb67530900 100644 --- a/manim/mobject/graphing/coordinate_systems.py +++ b/manim/mobject/graphing/coordinate_systems.py @@ -1865,6 +1865,13 @@ def __matmul__(self, coord): ) def __rmatmul__(self, point): + """Perform a point-to-coords action for a coordinate scene. + + .. warning:: + + This will not work with NumPy arrays or other objects that + implement ``__matmul__``. + """ method = getattr(self, self._rmatmul_config["method"]) assert callable(method) return ( diff --git a/tests/module/mobject/graphing/test_coordinate_system.py b/tests/module/mobject/graphing/test_coordinate_system.py index 112f3f15ba..cbdef5a19d 100644 --- a/tests/module/mobject/graphing/test_coordinate_system.py +++ b/tests/module/mobject/graphing/test_coordinate_system.py @@ -212,5 +212,27 @@ def test_matmul_operations(): # Numberline doesn't inherit from CoordinateSystem, but it should still work n = NumberLine() assert (n @ 3 == n.number_to_point(3)).all() - mob = Dot().move_to(n @ 3) - assert mob @ n == n.point_to_number(mob.get_center()) + + +def test_rmatmul_operations(): + point = (1, 2, 0) + + ax = Axes() + assert (point @ ax == ax.point_to_coords(point)).all() + + polar = PolarPlane() + assert point @ polar == polar.point_to_polar(point) + + complx = ComplexPlane() + assert point @ complx == complx.point_to_number(point) + + n = NumberLine() + point = n @ 4 + + assert ( + tuple(point) @ n # ndarray overrides __matmul__ + == n.point_to_number(point) + ).all() + + mob = Dot().move_to(point) + assert (mob @ n == n.point_to_number(mob.get_center())).all() From ad78681e4dbda12eb6923b9e24ac9db2471d7807 Mon Sep 17 00:00:00 2001 From: JasonGrace2282 Date: Thu, 25 Jul 2024 15:11:25 -0400 Subject: [PATCH 4/6] Rewrite using numpy.testing.assert_equal --- .../graphing/test_coordinate_system.py | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/tests/module/mobject/graphing/test_coordinate_system.py b/tests/module/mobject/graphing/test_coordinate_system.py index cbdef5a19d..06ecfd6409 100644 --- a/tests/module/mobject/graphing/test_coordinate_system.py +++ b/tests/module/mobject/graphing/test_coordinate_system.py @@ -3,6 +3,7 @@ import math import numpy as np +import numpy.testing as nt import pytest from manim import ( @@ -197,42 +198,42 @@ def test_input_to_graph_point(): def test_matmul_operations(): ax = Axes() - assert (ax @ (1, 2) == ax.coords_to_point(1, 2)).all() + nt.assert_equal(ax @ (1, 2), ax.coords_to_point(1, 2)) # should work with mobjects too, using their center mob = Dot().move_to((1, 2, 0)) - assert (ax @ mob == ax.coords_to_point(1, 2)).all() + nt.assert_equal(ax @ mob, ax.coords_to_point(1, 2)) # other coordinate systems like PolarPlane and ComplexPlane should override __matmul__ indirectly polar = PolarPlane() - assert (polar @ (1, 2) == polar.polar_to_point(1, 2)).all() + nt.assert_equal(polar @ (1, 2), polar.polar_to_point(1, 2)) complx = ComplexPlane() - assert (complx @ (1 + 2j) == complx.number_to_point(1 + 2j)).all() + nt.assert_equal(complx @ (1 + 2j), complx.number_to_point(1 + 2j)) # Numberline doesn't inherit from CoordinateSystem, but it should still work n = NumberLine() - assert (n @ 3 == n.number_to_point(3)).all() + nt.assert_equal(n @ 3, n.number_to_point(3)) def test_rmatmul_operations(): point = (1, 2, 0) ax = Axes() - assert (point @ ax == ax.point_to_coords(point)).all() + nt.assert_equal(point @ ax, ax.point_to_coords(point)) polar = PolarPlane() assert point @ polar == polar.point_to_polar(point) complx = ComplexPlane() - assert point @ complx == complx.point_to_number(point) + nt.assert_equal(point @ complx, complx.point_to_number(point)) n = NumberLine() point = n @ 4 - assert ( - tuple(point) @ n # ndarray overrides __matmul__ - == n.point_to_number(point) - ).all() + nt.assert_equal( + tuple(point) @ n, # ndarray overrides __matmul__ + n.point_to_number(point), + ) mob = Dot().move_to(point) - assert (mob @ n == n.point_to_number(mob.get_center())).all() + nt.assert_equal(mob @ n, n.point_to_number(mob.get_center())) From 1c68e522c7bb0aa947109cfde445cae990180548 Mon Sep 17 00:00:00 2001 From: JasonGrace2282 Date: Mon, 29 Jul 2024 12:43:16 -0400 Subject: [PATCH 5/6] Use typing_extensions.NotRequired to reduce boilerplate --- manim/mobject/graphing/coordinate_systems.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/manim/mobject/graphing/coordinate_systems.py b/manim/mobject/graphing/coordinate_systems.py index fb67530900..96b3e0c92c 100644 --- a/manim/mobject/graphing/coordinate_systems.py +++ b/manim/mobject/graphing/coordinate_systems.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Any, Callable, TypeVar, overload import numpy as np -from typing_extensions import Self, TypedDict +from typing_extensions import NotRequired, Self, TypedDict from manim import config from manim.constants import * @@ -72,7 +72,7 @@ class _MatmulConfig(TypedDict): """ method: str - unpack: bool + unpack: NotRequired[bool] class CoordinateSystem: @@ -3022,9 +3022,8 @@ def construct(self): _matmul_config = { "method": "polar_to_point", - "unpack": True, } - _rmatmul_config = {"method": "point_to_polar", "unpack": False} + _rmatmul_config = {"method": "point_to_polar"} def __init__( self, @@ -3398,7 +3397,7 @@ def construct(self): _matmul_config = {"method": "number_to_point", "unpack": False} - _rmatmul_config = {"method": "point_to_number", "unpack": False} + _rmatmul_config = {"method": "point_to_number"} def __init__(self, **kwargs: Any) -> None: super().__init__( From ccdd84cfb885db64e42b07a06dd6b20fbc7b7d51 Mon Sep 17 00:00:00 2001 From: JasonGrace2282 Date: Sun, 27 Oct 2024 17:36:30 -0400 Subject: [PATCH 6/6] remove weird implementation --- manim/mobject/geometry/shape_matchers.py | 2 +- manim/mobject/graphing/coordinate_systems.py | 62 ++++++-------------- 2 files changed, 19 insertions(+), 45 deletions(-) diff --git a/manim/mobject/geometry/shape_matchers.py b/manim/mobject/geometry/shape_matchers.py index b546dfb4f3..e72cf2775d 100644 --- a/manim/mobject/geometry/shape_matchers.py +++ b/manim/mobject/geometry/shape_matchers.py @@ -132,7 +132,7 @@ def set_style(self, fill_opacity: float, **kwargs: Any) -> Self: # type: ignore def get_fill_color(self) -> ManimColor: # The type of the color property is set to Any using the property decorator # vectorized_mobject.py#L571 - temp_color: ManimColor = self.color + temp_color: ManimColor = self.color # type: ignore[has-type] return temp_color diff --git a/manim/mobject/graphing/coordinate_systems.py b/manim/mobject/graphing/coordinate_systems.py index 96b3e0c92c..7d2ab7dbde 100644 --- a/manim/mobject/graphing/coordinate_systems.py +++ b/manim/mobject/graphing/coordinate_systems.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Any, Callable, TypeVar, overload import numpy as np -from typing_extensions import NotRequired, Self, TypedDict +from typing_extensions import Self from manim import config from manim.constants import * @@ -62,19 +62,6 @@ LineType = TypeVar("LineType", bound=Line) -class _MatmulConfig(TypedDict): - """A dictionary for configuring the __matmul__/__rmatmul__ operation. - - Parameters - ---------- - method: The method to call - unpack: whether to unpack the parameter given to __matmul__/__rmatmul__ - """ - - method: str - unpack: NotRequired[bool] - - class CoordinateSystem: r"""Abstract base class for Axes and NumberPlane. @@ -160,7 +147,7 @@ def __init__( self.y_length = y_length self.num_sampled_graph_points_per_tick = 10 - def coords_to_point(self, *coords: ManimFloat): + def coords_to_point(self, *coords: float): raise NotImplementedError() def point_to_coords(self, point: Point3D): @@ -1849,22 +1836,12 @@ def construct(self): return T_label_group - _matmul_config: _MatmulConfig = { - "method": "coords_to_point", - "unpack": True, - } - _rmatmul_config: _MatmulConfig = {"method": "point_to_coords", "unpack": False} - - def __matmul__(self, coord): + def __matmul__(self, coord: Iterable[float] | Mobject): if isinstance(coord, Mobject): coord = coord.get_center() - method = getattr(self, self._matmul_config["method"]) - assert callable(method) - return ( - method(*coord) if self._matmul_config.get("unpack", True) else method(coord) - ) + return self.coords_to_point(*coord) - def __rmatmul__(self, point): + def __rmatmul__(self, point: Point3D): """Perform a point-to-coords action for a coordinate scene. .. warning:: @@ -1872,13 +1849,7 @@ def __rmatmul__(self, point): This will not work with NumPy arrays or other objects that implement ``__matmul__``. """ - method = getattr(self, self._rmatmul_config["method"]) - assert callable(method) - return ( - method(*point) - if self._rmatmul_config.get("unpack", False) - else method(point) - ) + return self.point_to_coords(point) class Axes(VGroup, CoordinateSystem, metaclass=ConvertToOpenGL): @@ -3020,11 +2991,6 @@ def construct(self): self.add(polarplane_pi) """ - _matmul_config = { - "method": "polar_to_point", - } - _rmatmul_config = {"method": "point_to_polar"} - def __init__( self, radius_max: float = config["frame_y_radius"], @@ -3368,6 +3334,12 @@ def get_radian_label(self, number, font_size: float = 24, **kwargs: Any) -> Math return MathTex(string, font_size=font_size, **kwargs) + def __matmul__(self, coord: Point2D): + return self.polar_to_point(*coord) + + def __rmatmul__(self, point: Point2D): + return self.point_to_polar(point) + class ComplexPlane(NumberPlane): """A :class:`~.NumberPlane` specialized for use with complex numbers. @@ -3395,10 +3367,6 @@ def construct(self): """ - _matmul_config = {"method": "number_to_point", "unpack": False} - - _rmatmul_config = {"method": "point_to_number"} - def __init__(self, **kwargs: Any) -> None: super().__init__( **kwargs, @@ -3444,6 +3412,12 @@ def p2n(self, point: Point3D) -> complex: """Abbreviation for :meth:`point_to_number`.""" return self.point_to_number(point) + def __matmul__(self, coord: float | complex): + return self.number_to_point(coord) + + def __rmatmul__(self, point: Point3D): + return self.point_to_number(point) + def _get_default_coordinate_values(self) -> list[float | complex]: """Generate a list containing the numerical values of the plane's labels.