diff --git a/.github/manimdependency.json b/.github/manimdependency.json index 32e62e7ecf..93827c1502 100644 --- a/.github/manimdependency.json +++ b/.github/manimdependency.json @@ -4,7 +4,10 @@ "standalone", "preview", "doublestroke", - "ms", + "count1to", + "multitoc", + "prelim2e", + "ragged2e", "everysel", "setspace", "rsfs", @@ -29,7 +32,10 @@ "standalone", "preview", "doublestroke", - "ms", + "count1to", + "multitoc", + "prelim2e", + "ragged2e", "everysel", "setspace", "rsfs", diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 826e238d24..8977518d43 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,35 +12,22 @@ repos: - id: end-of-file-fixer - id: check-toml name: Validate Poetry - - repo: https://github.com/pycqa/isort - rev: 5.13.2 - hooks: - - id: isort - name: isort (python) - - id: isort - name: isort (cython) - types: [cython] - - id: isort - name: isort (pyi) - types: [pyi] - - repo: https://github.com/asottile/pyupgrade - rev: v3.15.2 - hooks: - - id: pyupgrade - name: Update code to new python versions - args: [--py39-plus] - repo: https://github.com/pre-commit/pygrep-hooks rev: v1.10.0 hooks: - id: python-check-blanket-noqa name: Precision flake ignores - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.4.5 + rev: v0.4.10 hooks: + - id: ruff + name: ruff lint + types: [python] + args: [--exit-non-zero-on-fix] - id: ruff-format types: [python] - repo: https://github.com/PyCQA/flake8 - rev: 7.0.0 + rev: 7.1.0 hooks: - id: flake8 additional_dependencies: diff --git a/conftest.py b/conftest.py index dacb730a29..683bd2bc05 100644 --- a/conftest.py +++ b/conftest.py @@ -5,12 +5,6 @@ from __future__ import annotations -try: - # https://github.com/moderngl/moderngl/issues/517 - import readline # required to prevent a segfault on Python 3.10 -except ModuleNotFoundError: # windows - pass - import cairo import moderngl diff --git a/docker/Dockerfile b/docker/Dockerfile index b41647e4e8..34d21fa0a0 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -22,9 +22,9 @@ RUN wget -O /tmp/install-tl-unx.tar.gz http://mirror.ctan.org/systems/texlive/tl tar -xzf /tmp/install-tl-unx.tar.gz -C /tmp/install-tl --strip-components=1 && \ /tmp/install-tl/install-tl --profile=/tmp/texlive-profile.txt \ && tlmgr install \ - amsmath babel-english cbfonts-fd cm-super ctex doublestroke dvisvgm everysel \ + amsmath babel-english cbfonts-fd cm-super count1to ctex doublestroke dvisvgm everysel \ fontspec frcursive fundus-calligra gnu-freefont jknapltx latex-bin \ - mathastext microtype ms physics preview ragged2e relsize rsfs \ + mathastext microtype multitoc physics prelim2e preview ragged2e relsize rsfs \ setspace standalone tipa wasy wasysym xcolor xetex xkeyval # clone and build manim diff --git a/docs/source/contributing/docs.rst b/docs/source/contributing/docs.rst index a8bb61c535..aaac806a83 100644 --- a/docs/source/contributing/docs.rst +++ b/docs/source/contributing/docs.rst @@ -81,3 +81,4 @@ Index docs/examples docs/references docs/typings + docs/types diff --git a/docs/source/contributing/docs/types.rst b/docs/source/contributing/docs/types.rst new file mode 100644 index 0000000000..fb6f06732b --- /dev/null +++ b/docs/source/contributing/docs/types.rst @@ -0,0 +1,134 @@ +=================== +Choosing Type Hints +=================== +In order to provide the best user experience, +it's important that type hints are chosen correctly. +With the large variety of types provided by Manim, choosing +which one to use can be difficult. This guide aims to +aid you in the process of choosing the right type for the scenario. + + +The first step is figuring out which category your type hint fits into. + +Coordinates +----------- +Coordinates encompass two main categories: points, and vectors. + + +Points +~~~~~~ +The purpose of points is pretty straightforward: they represent a point +in space. For example: + +.. code-block:: python + + def status2D(coord: Point2D) -> None: + x, y = coord + print(f"Point at {x=},{y=}") + + + def status3D(coord: Point3D) -> None: + x, y, z = coord + print(f"Point at {x=},{y=},{z=}") + + + def get_statuses(coords: Point2D_Array | Point3D_Array) -> None: + for coord in coords: + if len(coord) == 2: + # it's a Point2D + status2D(coord) + else: + # it's a point3D + status3D(coord) + +It's important to realize that the status functions accepted both +tuples/lists of the correct length, and ``NDArray``'s of the correct shape. +If they only accepted ``NDArray``'s, we would use their ``Internal`` counterparts: +:class:`~.typing.InternalPoint2D`, :class:`~.typing.InternalPoint3D`, :class:`~.typing.InternalPoint2D_Array` and :class:`~.typing.InternalPoint3D_Array`. + +In general, the type aliases prefixed with ``Internal`` should never be used on +user-facing classes and functions, but should be reserved for internal behavior. + +Vectors +~~~~~~~ +Vectors share many similarities to points. However, they have a different +connotation. Vectors should be used to represent direction. For example, +consider this slightly contrived function: + +.. code-block:: python + + def shift_mobject(mob: Mobject, direction: Vector3D, scale_factor: float = 1) -> mob: + return mob.shift(direction * scale_factor) + +Here we see an important example of the difference. ``direction`` can not, and +should not, be typed as a :class:`~.typing.Point3D` because the function does not accept tuples/lists, +like ``direction=(0, 1, 0)``. You could type it as :class:`~.typing.InternalPoint3D` and +the type checker and linter would be happy; however, this makes the code harder +to understand. + +As a general rule, if a parameter is called ``direction`` or ``axis``, +it should be type hinted as some form of :class:`~.VectorND`. + +.. warning:: + + This is not always true. For example, as of Manim 0.18.0, the direction + parameter of the :class:`.Vector` Mobject should be ``Point2D | Point3D``, + as it can also accept ``tuple[float, float]`` and ``tuple[float, float, float]``. + +Colors +------ +The interface Manim provides for working with colors is :class:`.ManimColor`. +The main color types Manim supports are RGB, RGBA, and HSV. You will want +to add type hints to a function depending on which type it uses. If any color will work, +you will need something like: + +.. code-block:: python + + if TYPE_CHECKING: + from manim.utils.color import ParsableManimColor + + # type hint stuff with ParsableManimColor + + + +Béziers +------- +Manim internally represents a :class:`.Mobject` by a collection of points. In the case of :class:`.VMobject`, +the most commonly used subclass of :class:`.Mobject`, these points represent Bézier curves, +which are a way of representing a curve using a sequence of points. + +.. note:: + + To learn more about Béziers, take a look at https://pomax.github.io/bezierinfo/ + + +Manim supports two different renderers, which each have different representations of +Béziers: Cairo uses cubic Bézier curves, while OpenGL uses quadratic Bézier curves. + +Type hints like :class:`~.typing.BezierPoints` represent a single bezier curve, and :class:`~.typing.BezierPath` +represents multiple Bézier curves. A :class:`~.typing.Spline` is when the Bézier curves in a :class:`~.typing.BezierPath` +forms a single connected curve. Manim also provides more specific type aliases when working with +quadratic or cubic curves, and they are prefixed with their respective type (e.g. :class:`~.typing.CubicBezierPoints`, +is a :class:`~.typing.BezierPoints` consisting of exactly 4 points representing a cubic Bézier curve). + + +Functions +--------- +Throughout the codebase, many different types of functions are used. The most obvious example +is a rate function, which takes in a float and outputs a float (``Callable[[float], float]``). +Another example is for overriding animations. One will often need to map a :class:`.Mobject` +to an overridden :class:`.Animation`, and for that we have the :class:`~.typing.FunctionOverride` type hint. + +:class:`~.typing.PathFuncType` and :class:`~.typing.MappingFunction` are more niche, but are related to moving objects +along a path, or applying functions. If you need to use it, you'll know. + + +Images +------ +There are several representations of images in Manim. The most common is +the representation as a NumPy array of floats representing the pixels of an image. +This is especially common when it comes to the OpenGL renderer. + +This is the use case of the :class:`~.typing.Image` type hint. Sometimes, Manim may use ``PIL.Image``, +in which case one should use that type hint instead. +Of course, if a more specific type of image is needed, it can be annotated as such. diff --git a/docs/source/contributing/docs/typings.rst b/docs/source/contributing/docs/typings.rst index befd557a2b..7cc14068c8 100644 --- a/docs/source/contributing/docs/typings.rst +++ b/docs/source/contributing/docs/typings.rst @@ -1,6 +1,6 @@ -============== -Adding Typings -============== +================== +Typing Conventions +================== .. warning:: This section is still a work in progress. diff --git a/manim/animation/indication.py b/manim/animation/indication.py index 8e2c3996c3..89bf506ebe 100644 --- a/manim/animation/indication.py +++ b/manim/animation/indication.py @@ -63,7 +63,6 @@ def construct(self): from ..mobject.types.vectorized_mobject import VGroup, VMobject from ..utils.bezier import interpolate, inverse_interpolate from ..utils.color import GREY, YELLOW, ParsableManimColor -from ..utils.deprecation import deprecated from ..utils.rate_functions import smooth, there_and_back, wiggle from ..utils.space_ops import normalize diff --git a/manim/cli/checkhealth/checks.py b/manim/cli/checkhealth/checks.py index 9859ced29f..dfc5f231a4 100644 --- a/manim/cli/checkhealth/checks.py +++ b/manim/cli/checkhealth/checks.py @@ -5,11 +5,8 @@ import os import shutil -import subprocess from typing import Callable -from ..._config import config - __all__ = ["HEALTH_CHECKS"] HEALTH_CHECKS = [] diff --git a/manim/cli/default_group.py b/manim/cli/default_group.py index 06f1c0520a..f4c8c33dbb 100644 --- a/manim/cli/default_group.py +++ b/manim/cli/default_group.py @@ -10,6 +10,8 @@ of ``click.Group``. """ +from __future__ import annotations + import warnings import cloup diff --git a/manim/mobject/geometry/boolean_ops.py b/manim/mobject/geometry/boolean_ops.py index c17cf86cbb..ef5bd4fa58 100644 --- a/manim/mobject/geometry/boolean_ops.py +++ b/manim/mobject/geometry/boolean_ops.py @@ -142,7 +142,7 @@ def _convert_skia_path_to_vmobject(self, path: SkiaPath) -> VMobject: n1, n2 = self._convert_2d_to_3d_array(points) vmobject.add_quadratic_bezier_curve_to(n1, n2) else: - raise Exception("Unsupported: %s" % path_verb) + raise Exception(f"Unsupported: {path_verb}") return vmobject diff --git a/manim/mobject/geometry/line.py b/manim/mobject/geometry/line.py index 2989dcb44a..50d63c1190 100644 --- a/manim/mobject/geometry/line.py +++ b/manim/mobject/geometry/line.py @@ -41,8 +41,8 @@ class Line(TipableVMobject): def __init__( self, - start: Point3D = LEFT, - end: Point3D = RIGHT, + start: Point3D | Mobject = LEFT, + end: Point3D | Mobject = RIGHT, buff: float = 0, path_arc: float | None = None, **kwargs, @@ -63,16 +63,32 @@ def generate_points(self) -> None: def set_points_by_ends( self, - start: Point3D, - end: Point3D, + start: Point3D | Mobject, + end: Point3D | Mobject, buff: float = 0, path_arc: float = 0, ) -> None: + """Sets the points of the line based on its start and end points. + Unlike :meth:`put_start_and_end_on`, this method respects `self.buff` and + Mobject bounding boxes. + + Parameters + ---------- + start + The start point or Mobject of the line. + end + The end point or Mobject of the line. + buff + The empty space between the start and end of the line, by default 0. + path_arc + The angle of a circle spanned by this arc, by default 0 which is a straight line. + """ + self._set_start_and_end_attrs(start, end) if path_arc: arc = ArcBetweenPoints(self.start, self.end, angle=self.path_arc) self.set_points(arc.points) else: - self.set_points_as_corners([start, end]) + self.set_points_as_corners([self.start, self.end]) self._account_for_buff(buff) @@ -93,7 +109,9 @@ def _account_for_buff(self, buff: float) -> Self: self.pointwise_become_partial(self, buff_proportion, 1 - buff_proportion) return self - def _set_start_and_end_attrs(self, start: Point3D, end: Point3D) -> None: + def _set_start_and_end_attrs( + self, start: Point3D | Mobject, end: Point3D | Mobject + ) -> None: # If either start or end are Mobjects, this # gives their centers rough_start = self._pointify(start) diff --git a/manim/mobject/graph.py b/manim/mobject/graph.py index 53d7ed464a..72a26d27db 100644 --- a/manim/mobject/graph.py +++ b/manim/mobject/graph.py @@ -18,6 +18,7 @@ if TYPE_CHECKING: from typing_extensions import TypeAlias + from manim.scene.scene import Scene from manim.typing import Point3D NxGraph: TypeAlias = nx.classes.graph.Graph | nx.classes.digraph.DiGraph @@ -477,7 +478,7 @@ def _determine_graph_layout( return cast(LayoutFunction, layout)( nx_graph, scale=layout_scale, **layout_config ) - except TypeError as e: + except TypeError: raise ValueError( f"The layout '{layout}' is neither a recognized layout, a layout function," "nor a vertex placement dictionary.", @@ -560,6 +561,7 @@ class GenericGraph(VMobject, metaclass=ConvertToOpenGL): all other configuration options for a vertex. edge_type The mobject class used for displaying edges in the scene. + Must be a subclass of :class:`~.Line` for default updaters to work. edge_config Either a dictionary containing keyword arguments to be passed to the class specified via ``edge_type``, or a dictionary whose @@ -1558,7 +1560,12 @@ def _populate_edge_dict( def update_edges(self, graph): for (u, v), edge in graph.edges.items(): # Undirected graph has a Line edge - edge.put_start_and_end_on(graph[u].get_center(), graph[v].get_center()) + edge.set_points_by_ends( + graph[u].get_center(), + graph[v].get_center(), + buff=self._edge_config.get("buff", 0), + path_arc=self._edge_config.get("path_arc", 0), + ) def __repr__(self: Graph) -> str: return f"Undirected graph on {len(self.vertices)} vertices and {len(self.edges)} edges" @@ -1767,10 +1774,15 @@ def update_edges(self, graph): deformed. """ for (u, v), edge in graph.edges.items(): - edge_type = type(edge) tip = edge.pop_tips()[0] - new_edge = edge_type(self[u], self[v], **self._edge_config[(u, v)]) - edge.become(new_edge) + # Passing the Mobject instead of the vertex makes the tip + # stop on the bounding box of the vertex. + edge.set_points_by_ends( + graph[u], + graph[v], + buff=self._edge_config.get("buff", 0), + path_arc=self._edge_config.get("path_arc", 0), + ) edge.add_tip(tip) def __repr__(self: DiGraph) -> str: diff --git a/manim/mobject/graphing/scale.py b/manim/mobject/graphing/scale.py index b301d1ff15..8ba8a7e63c 100644 --- a/manim/mobject/graphing/scale.py +++ b/manim/mobject/graphing/scale.py @@ -146,7 +146,9 @@ def inverse_function(self, value: float) -> float: """Inverse of ``function``. The value must be greater than 0""" if isinstance(value, np.ndarray): condition = value.any() <= 0 - func = lambda value, base: np.log(value) / np.log(base) + + def func(value, base): + return np.log(value) / np.log(base) else: condition = value <= 0 func = math.log @@ -180,7 +182,7 @@ def get_custom_labels( tex_labels = [ Integer( self.base, - unit="^{%s}" % (f"{self.inverse_function(i):.{unit_decimal_places}f}"), + unit="^{%s}" % (f"{self.inverse_function(i):.{unit_decimal_places}f}"), # noqa: UP031 **base_config, ) for i in val_range diff --git a/manim/mobject/opengl/opengl_mobject.py b/manim/mobject/opengl/opengl_mobject.py index 3bdbd7c0b5..c907c4c2e0 100644 --- a/manim/mobject/opengl/opengl_mobject.py +++ b/manim/mobject/opengl/opengl_mobject.py @@ -5,9 +5,11 @@ import itertools as it import random import sys -from collections.abc import Iterable, Sequence +import types +from collections.abc import Iterable, Iterator, Sequence from functools import partialmethod, wraps from math import ceil +from typing import TYPE_CHECKING, Any, Callable, TypeVar import moderngl import numpy as np @@ -44,10 +46,33 @@ rotation_matrix_transpose, ) +if TYPE_CHECKING: + import numpy.typing as npt + from typing_extensions import Self, TypeAlias -def affects_shader_info_id(func): + from manim.renderer.shader_wrapper import ShaderWrapper + from manim.typing import ( + ManimFloat, + MappingFunction, + MatrixMN, + PathFuncType, + Point3D, + Point3D_Array, + Vector3D, + ) + + TimeBasedUpdater: TypeAlias = Callable[["Mobject", float], object] + NonTimeBasedUpdater: TypeAlias = Callable[["Mobject"], object] + Updater: TypeAlias = NonTimeBasedUpdater | TimeBasedUpdater + + T = TypeVar("T") + + +def affects_shader_info_id( + func: Callable[[OpenGLMobject], OpenGLMobject], +) -> Callable[[OpenGLMobject], OpenGLMobject]: @wraps(func) - def wrapper(self): + def wrapper(self: OpenGLMobject) -> OpenGLMobject: for mob in self.get_family(): func(mob) mob.refresh_shader_wrapper_id() @@ -93,26 +118,26 @@ class OpenGLMobject: def __init__( self, - color=WHITE, - opacity=1, - dim=3, # TODO, get rid of this + color: ParsableManimColor | Iterable[ParsableManimColor] = WHITE, + opacity: float = 1, + dim: int = 3, # TODO, get rid of this # Lighting parameters # Positive gloss up to 1 makes it reflect the light. - gloss=0.0, + gloss: float = 0.0, # Positive shadow up to 1 makes a side opposite the light darker - shadow=0.0, + shadow: float = 0.0, # For shaders - render_primitive=moderngl.TRIANGLES, - texture_paths=None, - depth_test=False, + render_primitive: int = moderngl.TRIANGLES, + texture_paths: dict[str, str] | None = None, + depth_test: bool = False, # If true, the mobject will not get rotated according to camera position - is_fixed_in_frame=False, - is_fixed_orientation=False, + is_fixed_in_frame: bool = False, + is_fixed_orientation: bool = False, # Must match in attributes of vert shader # Event listener - listen_to_events=False, - model_matrix=None, - should_render=True, + listen_to_events: bool = False, + model_matrix: MatrixMN | None = None, + should_render: bool = True, name: str | None = None, **kwargs, ): @@ -199,7 +224,7 @@ def _assert_valid_submobjects(self, submobjects: Iterable[OpenGLMobject]) -> Sel return self._assert_valid_submobjects_internal(submobjects, OpenGLMobject) def _assert_valid_submobjects_internal( - self, submobjects: list[OpenGLMobject], mob_class: type[OpenGLMobject] + self, submobjects: Iterable[OpenGLMobject], mob_class: type[OpenGLMobject] ) -> Self: for i, submob in enumerate(submobjects): if not isinstance(submob, mob_class): @@ -224,14 +249,14 @@ def _assert_valid_submobjects_internal( return self @classmethod - def __init_subclass__(cls, **kwargs): + def __init_subclass__(cls, **kwargs) -> None: super().__init_subclass__(**kwargs) cls._original__init__ = cls.__init__ - def __str__(self): + def __str__(self) -> str: return self.__class__.__name__ - def __repr__(self): + def __repr__(self) -> str: return str(self.name) def __sub__(self, other): @@ -247,7 +272,7 @@ def __iadd__(self, mobject): return NotImplemented @classmethod - def set_default(cls, **kwargs): + def set_default(cls, **kwargs) -> None: """Sets the default values of keyword arguments. If this method is called without any additional keyword @@ -294,14 +319,14 @@ def construct(self): else: cls.__init__ = cls._original__init__ - def init_data(self): + def init_data(self) -> None: """Initializes the ``points``, ``bounding_box`` and ``rgbas`` attributes and groups them into self.data. Subclasses can inherit and overwrite this method to extend `self.data`.""" self.points = np.zeros((0, 3)) self.bounding_box = np.zeros((3, 3)) self.rgbas = np.zeros((1, 4)) - def init_colors(self): + def init_colors(self) -> None: """Initializes the colors. Gets called upon creation""" @@ -315,7 +340,7 @@ def init_points(self): # Typically implemented in subclass, unless purposefully left blank pass - def set(self, **kwargs) -> OpenGLMobject: + def set(self, **kwargs) -> Self: """Sets attributes. Mainly to be used along with :attr:`animate` to @@ -349,18 +374,18 @@ def set(self, **kwargs) -> OpenGLMobject: return self - def set_data(self, data): + def set_data(self, data: dict[str, Any]) -> Self: for key in data: self.data[key] = data[key].copy() return self - def set_uniforms(self, uniforms): + def set_uniforms(self, uniforms: dict[str, Any]) -> Self: for key in uniforms: self.uniforms[key] = uniforms[key] # Copy? return self @property - def animate(self): + def animate(self) -> _AnimationBuilder | Self: """Used to animate the application of a method. .. warning:: @@ -448,7 +473,7 @@ def construct(self): return _AnimationBuilder(self) @property - def width(self): + def width(self) -> float: """The width of the mobject. Returns @@ -482,11 +507,11 @@ def construct(self): # Only these methods should directly affect points @width.setter - def width(self, value): + def width(self, value: float) -> None: self.rescale_to_fit(value, 0, stretch=False) @property - def height(self): + def height(self) -> float: """The height of the mobject. Returns @@ -519,11 +544,11 @@ def construct(self): return self.length_over_dim(1) @height.setter - def height(self, value): + def height(self, value: float) -> None: self.rescale_to_fit(value, 1, stretch=False) @property - def depth(self): + def depth(self) -> float: """The depth of the mobject. Returns @@ -540,7 +565,7 @@ def depth(self): return self.length_over_dim(2) @depth.setter - def depth(self, value): + def depth(self, value: float) -> None: self.rescale_to_fit(value, 2, stretch=False) def resize_points(self, new_length, resize_func=resize_array): @@ -549,7 +574,7 @@ def resize_points(self, new_length, resize_func=resize_array): self.refresh_bounding_box() return self - def set_points(self, points): + def set_points(self, points: Point3D_Array) -> Self: if len(points) == len(self.points): self.points[:] = points elif isinstance(points, np.ndarray): @@ -559,23 +584,26 @@ def set_points(self, points): self.refresh_bounding_box() return self - def apply_over_attr_arrays(self, func): + def apply_over_attr_arrays( + self, func: Callable[[npt.NDArray[T]], npt.NDArray[T]] + ) -> Self: + # TODO: OpenGLMobject.get_array_attrs() doesn't even exist! for attr in self.get_array_attrs(): setattr(self, attr, func(getattr(self, attr))) return self - def append_points(self, new_points): + def append_points(self, new_points: Point3D_Array) -> Self: self.points = np.vstack([self.points, new_points]) self.refresh_bounding_box() return self - def reverse_points(self): + def reverse_points(self) -> Self: for mob in self.get_family(): for key in mob.data: mob.data[key] = mob.data[key][::-1] return self - def get_midpoint(self) -> np.ndarray: + def get_midpoint(self) -> Point3D: """Get coordinates of the middle of the path that forms the :class:`~.OpenGLMobject`. Examples @@ -600,11 +628,11 @@ def construct(self): def apply_points_function( self, - func, - about_point=None, - about_edge=ORIGIN, - works_on_bounding_box=False, - ): + func: MappingFunction, + about_point: Point3D | None = None, + about_edge: Vector3D | None = ORIGIN, + works_on_bounding_box: bool = False, + ) -> Self: if about_point is None and about_edge is not None: about_point = self.get_bounding_box_point(about_edge) @@ -630,7 +658,7 @@ def apply_points_function( # Others related to points - def match_points(self, mobject): + def match_points(self, mobject: OpenGLMobject) -> Self: """Edit points, positions, and submobjects to be identical to another :class:`~.OpenGLMobject`, while keeping the style unchanged. @@ -648,29 +676,31 @@ def construct(self): self.wait(0.5) """ self.set_points(mobject.points) + return self - def clear_points(self): + def clear_points(self) -> Self: self.points = np.empty((0, 3)) + return self - def get_num_points(self): + def get_num_points(self) -> int: return len(self.points) - def get_all_points(self): + def get_all_points(self) -> Point3D_Array: if self.submobjects: return np.vstack([sm.points for sm in self.get_family()]) else: return self.points - def has_points(self): + def has_points(self) -> bool: return self.get_num_points() > 0 - def get_bounding_box(self): + def get_bounding_box(self) -> npt.NDArray[float]: if self.needs_new_bounding_box: self.bounding_box = self.compute_bounding_box() self.needs_new_bounding_box = False return self.bounding_box - def compute_bounding_box(self): + def compute_bounding_box(self) -> npt.NDArray[float]: all_points = np.vstack( [ self.points, @@ -690,7 +720,9 @@ def compute_bounding_box(self): mids = (mins + maxs) / 2 return np.array([mins, mids, maxs]) - def refresh_bounding_box(self, recurse_down=False, recurse_up=True): + def refresh_bounding_box( + self, recurse_down: bool = False, recurse_up: bool = True + ) -> Self: for mob in self.get_family(recurse_down): mob.needs_new_bounding_box = True if recurse_up: @@ -698,7 +730,7 @@ def refresh_bounding_box(self, recurse_down=False, recurse_up=True): parent.refresh_bounding_box() return self - def is_point_touching(self, point, buff=MED_SMALL_BUFF): + def is_point_touching(self, point: Point3D, buff: float = MED_SMALL_BUFF) -> bool: bb = self.get_bounding_box() mins = bb[0] - buff maxs = bb[2] + buff @@ -706,22 +738,22 @@ def is_point_touching(self, point, buff=MED_SMALL_BUFF): # Family matters - def __getitem__(self, value): + def __getitem__(self, value: int | slice) -> OpenGLMobject: if isinstance(value, slice): GroupClass = self.get_group_class() return GroupClass(*self.split().__getitem__(value)) return self.split().__getitem__(value) - def __iter__(self): + def __iter__(self) -> Iterator[OpenGLMobject]: return iter(self.split()) - def __len__(self): + def __len__(self) -> int: return len(self.split()) - def split(self): + def split(self) -> Sequence[OpenGLMobject]: return self.submobjects - def assemble_family(self): + def assemble_family(self) -> Self: sub_families = (sm.get_family() for sm in self.submobjects) self.family = [self, *uniq_chain(*sub_families)] self.refresh_has_updater_status() @@ -730,18 +762,16 @@ def assemble_family(self): parent.assemble_family() return self - def get_family(self, recurse=True): + def get_family(self, recurse: bool = True) -> Sequence[OpenGLMobject]: if recurse and hasattr(self, "family"): return self.family else: return [self] - def family_members_with_points(self): + def family_members_with_points(self) -> Sequence[OpenGLMobject]: return [m for m in self.get_family() if m.has_points()] - def add( - self, *mobjects: OpenGLMobject, update_parent: bool = False - ) -> OpenGLMobject: + def add(self, *mobjects: OpenGLMobject, update_parent: bool = False) -> Self: """Add mobjects as submobjects. The mobjects are added to :attr:`submobjects`. @@ -826,7 +856,9 @@ def add( self.assemble_family() return self - def insert(self, index: int, mobject: OpenGLMobject, update_parent: bool = False): + def insert( + self, index: int, mobject: OpenGLMobject, update_parent: bool = False + ) -> Self: """Inserts a mobject at a specific position into self.submobjects Effectively just calls ``self.submobjects.insert(index, mobject)``, @@ -858,9 +890,7 @@ def insert(self, index: int, mobject: OpenGLMobject, update_parent: bool = False self.assemble_family() return self - def remove( - self, *mobjects: OpenGLMobject, update_parent: bool = False - ) -> OpenGLMobject: + def remove(self, *mobjects: OpenGLMobject, update_parent: bool = False) -> Self: """Remove :attr:`submobjects`. The mobjects are removed from :attr:`submobjects`, if they exist. @@ -894,7 +924,7 @@ def remove( self.assemble_family() return self - def add_to_back(self, *mobjects: OpenGLMobject) -> OpenGLMobject: + def add_to_back(self, *mobjects: OpenGLMobject) -> Self: # NOTE: is the note true OpenGLMobjects? """Add all passed mobjects to the back of the submobjects. @@ -943,7 +973,7 @@ def add_to_back(self, *mobjects: OpenGLMobject) -> OpenGLMobject: self.submobjects = list_update(mobjects, self.submobjects) return self - def replace_submobject(self, index, new_submob): + def replace_submobject(self, index: int, new_submob: OpenGLMobject) -> Self: self._assert_valid_submobjects([new_submob]) old_submob = self.submobjects[index] if self in old_submob.parents: @@ -952,36 +982,11 @@ def replace_submobject(self, index, new_submob): self.assemble_family() return self - def invert(self, recursive=False): - """Inverts the list of :attr:`submobjects`. - - Parameters - ---------- - recursive - If ``True``, all submobject lists of this mobject's family are inverted. - - Examples - -------- - - .. manim:: InvertSumobjectsExample - - class InvertSumobjectsExample(Scene): - def construct(self): - s = VGroup(*[Dot().shift(i*0.1*RIGHT) for i in range(-20,20)]) - s2 = s.copy() - s2.invert() - s2.shift(DOWN) - self.play(Write(s), Write(s2)) - """ - if recursive: - for submob in self.submobjects: - submob.invert(recursive=True) - list.reverse(self.submobjects) - self.assemble_family() - # Submobject organization - def arrange(self, direction=RIGHT, center=True, **kwargs): + def arrange( + self, direction: Vector3D = RIGHT, center: bool = True, **kwargs + ) -> Self: """Sorts :class:`~.OpenGLMobject` next to each other on screen. Examples @@ -1010,14 +1015,14 @@ def arrange_in_grid( rows: int | None = None, cols: int | None = None, buff: float | tuple[float, float] = MED_SMALL_BUFF, - cell_alignment: np.ndarray = ORIGIN, + cell_alignment: Vector3D = ORIGIN, row_alignments: str | None = None, # "ucd" col_alignments: str | None = None, # "lcr" - row_heights: Iterable[float | None] | None = None, - col_widths: Iterable[float | None] | None = None, + row_heights: Sequence[float | None] | None = None, + col_widths: Sequence[float | None] | None = None, flow_order: str = "rd", **kwargs, - ) -> OpenGLMobject: + ) -> Self: """Arrange submobjects in a grid. Parameters @@ -1113,16 +1118,27 @@ def construct(self): start_pos = self.get_center() # get cols / rows values if given (implicitly) - def init_size(num, alignments, sizes): + def init_size( + num: int | None, + alignments: str | None, + sizes: Sequence[float | None] | None, + name: str, + ) -> int: if num is not None: return num if alignments is not None: return len(alignments) if sizes is not None: return len(sizes) + raise ValueError( + f"At least one of the following parameters: '{name}s', " + f"'{name}_alignments' or " + f"'{name}_{'widths' if name == 'col' else 'heights'}', " + "must not be None" + ) - cols = init_size(cols, col_alignments, col_widths) - rows = init_size(rows, row_alignments, row_heights) + cols = init_size(cols, col_alignments, col_widths, "col") + rows = init_size(rows, row_alignments, row_heights, "row") # calculate rows cols if rows is None and cols is None: @@ -1146,16 +1162,19 @@ def init_size(num, alignments, sizes): buff_x = buff_y = buff # Initialize alignments correctly - def init_alignments(alignments, num, mapping, name, dir): - if alignments is None: + def init_alignments( + str_alignments: str | None, + num: int, + mapping: dict[str, Vector3D], + name: str, + direction: Vector3D, + ) -> Sequence[Vector3D]: + if str_alignments is None: # Use cell_alignment as fallback - return [cell_alignment * dir] * num - if len(alignments) != num: + return [cell_alignment * direction] * num + if len(str_alignments) != num: raise ValueError(f"{name}_alignments has a mismatching size.") - alignments = list(alignments) - for i in range(num): - alignments[i] = mapping[alignments[i]] - return alignments + return [mapping[letter] for letter in str_alignments] row_alignments = init_alignments( row_alignments, @@ -1191,11 +1210,12 @@ def init_alignments(alignments, num, mapping, name, dir): # Reverse row_alignments and row_heights. Necessary since the # grid filling is handled bottom up for simplicity reasons. - def reverse(maybe_list): + def reverse(maybe_list: Sequence[Any] | None) -> Sequence[Any] | None: if maybe_list is not None: maybe_list = list(maybe_list) maybe_list.reverse() return maybe_list + return None row_alignments = reverse(row_alignments) row_heights = reverse(row_heights) @@ -1216,7 +1236,12 @@ def reverse(maybe_list): ] # Initialize row_heights / col_widths correctly using measurements as fallback - def init_sizes(sizes, num, measures, name): + def init_sizes( + sizes: Sequence[float | None] | None, + num: int, + measures: Sequence[float], + name: str, + ) -> Sequence[float]: if sizes is None: sizes = [None] * num if len(sizes) != num: @@ -1249,7 +1274,9 @@ def init_sizes(sizes, num, measures, name): self.move_to(start_pos) return self - def get_grid(self, n_rows, n_cols, height=None, **kwargs): + def get_grid( + self, n_rows: int, n_cols: int, height: float | None = None, **kwargs + ) -> OpenGLGroup: """ Returns a new mobject containing multiple copies of this one arranged in a grid @@ -1260,11 +1287,15 @@ def get_grid(self, n_rows, n_cols, height=None, **kwargs): grid.set_height(height) return grid - def duplicate(self, n: int): - """Returns an :class:`~.OpenGLVGroup` containing ``n`` copies of the mobject.""" + def duplicate(self, n: int) -> OpenGLGroup: + """Returns an :class:`~.OpenGLGroup` containing ``n`` copies of the mobject.""" return self.get_group_class()(*[self.copy() for _ in range(n)]) - def sort(self, point_to_num_func=lambda p: p[0], submob_func=None): + def sort( + self, + point_to_num_func: Callable[[Point3D], float] = lambda p: p[0], + submob_func: Callable[[OpenGLMobject], Any] | None = None, + ) -> Self: """Sorts the list of :attr:`submobjects` by a function defined by ``submob_func``.""" if submob_func is not None: self.submobjects.sort(key=submob_func) @@ -1272,7 +1303,7 @@ def sort(self, point_to_num_func=lambda p: p[0], submob_func=None): self.submobjects.sort(key=lambda m: point_to_num_func(m.get_center())) return self - def shuffle(self, recurse=False): + def shuffle(self, recurse: bool = False) -> Self: """Shuffles the order of :attr:`submobjects` Examples @@ -1295,7 +1326,7 @@ def construct(self): self.assemble_family() return self - def invert(self, recursive=False): + def invert(self, recursive: bool = False) -> Self: """Inverts the list of :attr:`submobjects`. Parameters @@ -1319,11 +1350,12 @@ def construct(self): if recursive: for submob in self.submobjects: submob.invert(recursive=True) - list.reverse(self.submobjects) + self.submobjects.reverse() + # Is there supposed to be an assemble_family here? # Copying - def copy(self, shallow: bool = False): + def copy(self, shallow: bool = False) -> OpenGLMobject: """Create and return an identical copy of the :class:`OpenGLMobject` including all :attr:`submobjects`. @@ -1381,14 +1413,14 @@ def copy(self, shallow: bool = False): # setattr(copy_mobject, attr, value.copy()) return copy_mobject - def deepcopy(self): + def deepcopy(self) -> OpenGLMobject: parents = self.parents self.parents = [] result = copy.deepcopy(self) self.parents = parents return result - def generate_target(self, use_deepcopy: bool = False): + def generate_target(self, use_deepcopy: bool = False) -> OpenGLMobject: self.target = None # Prevent exponential explosion if use_deepcopy: self.target = self.deepcopy() @@ -1396,7 +1428,7 @@ def generate_target(self, use_deepcopy: bool = False): self.target = self.copy() return self.target - def save_state(self, use_deepcopy: bool = False): + def save_state(self, use_deepcopy: bool = False) -> Self: """Save the current state (position, color & size). Can be restored with :meth:`~.OpenGLMobject.restore`.""" if hasattr(self, "saved_state"): # Prevent exponential growth of data @@ -1407,7 +1439,7 @@ def save_state(self, use_deepcopy: bool = False): self.saved_state = self.copy() return self - def restore(self): + def restore(self) -> Self: """Restores the state that was previously saved with :meth:`~.OpenGLMobject.save_state`.""" if not hasattr(self, "saved_state") or self.save_state is None: raise Exception("Trying to restore without having saved") @@ -1416,13 +1448,13 @@ def restore(self): # Updating - def init_updaters(self): + def init_updaters(self) -> None: self.time_based_updaters = [] self.non_time_updaters = [] self.has_updaters = False self.updating_suspended = False - def update(self, dt=0, recurse=True): + def update(self, dt: float = 0, recurse: bool = True) -> Self: if not self.has_updaters or self.updating_suspended: return self for updater in self.time_based_updaters: @@ -1434,19 +1466,24 @@ def update(self, dt=0, recurse=True): submob.update(dt, recurse) return self - def get_time_based_updaters(self): + def get_time_based_updaters(self) -> Sequence[TimeBasedUpdater]: return self.time_based_updaters - def has_time_based_updater(self): + def has_time_based_updater(self) -> bool: return len(self.time_based_updaters) > 0 - def get_updaters(self): + def get_updaters(self) -> Sequence[Updater]: return self.time_based_updaters + self.non_time_updaters - def get_family_updaters(self): + def get_family_updaters(self) -> Sequence[Updater]: return list(it.chain(*(sm.get_updaters() for sm in self.get_family()))) - def add_updater(self, update_function, index=None, call_updater=False): + def add_updater( + self, + update_function: Updater, + index: int | None = None, + call_updater: bool = False, + ) -> Self: if "dt" in inspect.signature(update_function).parameters: updater_list = self.time_based_updaters else: @@ -1462,14 +1499,14 @@ def add_updater(self, update_function, index=None, call_updater=False): self.update() return self - def remove_updater(self, update_function): + def remove_updater(self, update_function: Updater) -> Self: for updater_list in [self.time_based_updaters, self.non_time_updaters]: while update_function in updater_list: updater_list.remove(update_function) self.refresh_has_updater_status() return self - def clear_updaters(self, recurse=True): + def clear_updaters(self, recurse: bool = True) -> Self: self.time_based_updaters = [] self.non_time_updaters = [] self.refresh_has_updater_status() @@ -1478,20 +1515,20 @@ def clear_updaters(self, recurse=True): submob.clear_updaters() return self - def match_updaters(self, mobject): + def match_updaters(self, mobject: OpenGLMobject) -> Self: self.clear_updaters() for updater in mobject.get_updaters(): self.add_updater(updater) return self - def suspend_updating(self, recurse=True): + def suspend_updating(self, recurse: bool = True) -> Self: self.updating_suspended = True if recurse: for submob in self.submobjects: submob.suspend_updating(recurse) return self - def resume_updating(self, recurse=True, call_updater=True): + def resume_updating(self, recurse: bool = True, call_updater: bool = True) -> Self: self.updating_suspended = False if recurse: for submob in self.submobjects: @@ -1502,13 +1539,13 @@ def resume_updating(self, recurse=True, call_updater=True): self.update(dt=0, recurse=recurse) return self - def refresh_has_updater_status(self): + def refresh_has_updater_status(self) -> Self: self.has_updaters = any(mob.get_updaters() for mob in self.get_family()) return self # Transforming operations - def shift(self, vector): + def shift(self, vector: Vector3D) -> Self: self.apply_points_function( lambda points: points + vector, about_edge=None, @@ -1522,7 +1559,7 @@ def scale( about_point: Sequence[float] | None = None, about_edge: Sequence[float] = ORIGIN, **kwargs, - ) -> OpenGLMobject: + ) -> Self: r"""Scale the size by a factor. Default behavior is to scale about the center of the mobject. @@ -1578,7 +1615,7 @@ def construct(self): ) return self - def stretch(self, factor, dim, **kwargs): + def stretch(self, factor: float, dim: int, **kwargs) -> Self: def func(points): points[:, dim] *= factor return points @@ -1586,16 +1623,16 @@ def func(points): self.apply_points_function(func, works_on_bounding_box=True, **kwargs) return self - def rotate_about_origin(self, angle, axis=OUT): + def rotate_about_origin(self, angle: float, axis: Vector3D = OUT) -> Self: return self.rotate(angle, axis, about_point=ORIGIN) def rotate( self, - angle, - axis=OUT, + angle: float, + axis: Vector3D = OUT, about_point: Sequence[float] | None = None, **kwargs, - ): + ) -> Self: """Rotates the :class:`~.OpenGLMobject` about a certain point.""" rot_matrix_T = rotation_matrix_transpose(angle, axis) self.apply_points_function( @@ -1605,7 +1642,7 @@ def rotate( ) return self - def flip(self, axis=UP, **kwargs): + def flip(self, axis: Vector3D = UP, **kwargs) -> Self: """Flips/Mirrors an mobject about its center. Examples @@ -1624,7 +1661,7 @@ def construct(self): """ return self.rotate(TAU / 2, axis, **kwargs) - def apply_function(self, function, **kwargs): + def apply_function(self, function: MappingFunction, **kwargs) -> Self: # Default to applying matrix about the origin, not mobjects center if len(kwargs) == 0: kwargs["about_point"] = ORIGIN @@ -1633,16 +1670,16 @@ def apply_function(self, function, **kwargs): ) return self - def apply_function_to_position(self, function): + def apply_function_to_position(self, function: MappingFunction) -> Self: self.move_to(function(self.get_center())) return self - def apply_function_to_submobject_positions(self, function): + def apply_function_to_submobject_positions(self, function: MappingFunction) -> Self: for submob in self.submobjects: submob.apply_function_to_position(function) return self - def apply_matrix(self, matrix, **kwargs): + def apply_matrix(self, matrix: MatrixMN, **kwargs) -> Self: # Default to applying matrix about the origin, not mobjects center if ("about_point" not in kwargs) and ("about_edge" not in kwargs): kwargs["about_point"] = ORIGIN @@ -1654,7 +1691,9 @@ def apply_matrix(self, matrix, **kwargs): ) return self - def apply_complex_function(self, function, **kwargs): + def apply_complex_function( + self, function: Callable[[complex], complex], **kwargs + ) -> Self: """Applies a complex function to a :class:`OpenGLMobject`. The x and y coordinates correspond to the real and imaginary parts respectively. @@ -1688,7 +1727,7 @@ def R3_func(point): return self.apply_function(R3_func) - def hierarchical_model_matrix(self): + def hierarchical_model_matrix(self) -> MatrixMN: if self.parent is None: return self.model_matrix @@ -1699,7 +1738,12 @@ def hierarchical_model_matrix(self): current_object = current_object.parent return np.linalg.multi_dot(list(reversed(model_matrices))) - def wag(self, direction=RIGHT, axis=DOWN, wag_factor=1.0): + def wag( + self, + direction: Vector3D = RIGHT, + axis: Vector3D = DOWN, + wag_factor: float = 1.0, + ) -> Self: for mob in self.family_members_with_points(): alphas = np.dot(mob.points, np.transpose(axis)) alphas -= min(alphas) @@ -1716,12 +1760,16 @@ def wag(self, direction=RIGHT, axis=DOWN, wag_factor=1.0): # Positioning methods - def center(self): + def center(self) -> Self: """Moves the mobject to the center of the Scene.""" self.shift(-self.get_center()) return self - def align_on_border(self, direction, buff=DEFAULT_MOBJECT_TO_EDGE_BUFFER): + def align_on_border( + self, + direction: Vector3D, + buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER, + ) -> Self: """ Direction just needs to be a vector pointing towards side or corner in the 2d plane. @@ -1737,22 +1785,30 @@ def align_on_border(self, direction, buff=DEFAULT_MOBJECT_TO_EDGE_BUFFER): self.shift(shift_val) return self - def to_corner(self, corner=LEFT + DOWN, buff=DEFAULT_MOBJECT_TO_EDGE_BUFFER): + def to_corner( + self, + corner: Vector3D = LEFT + DOWN, + buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER, + ) -> Self: return self.align_on_border(corner, buff) - def to_edge(self, edge=LEFT, buff=DEFAULT_MOBJECT_TO_EDGE_BUFFER): + def to_edge( + self, + edge: Vector3D = LEFT, + buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER, + ) -> Self: return self.align_on_border(edge, buff) def next_to( self, - mobject_or_point, - direction=RIGHT, - buff=DEFAULT_MOBJECT_TO_MOBJECT_BUFFER, - aligned_edge=ORIGIN, - submobject_to_align=None, - index_of_submobject_to_align=None, - coor_mask=np.array([1, 1, 1]), - ): + mobject_or_point: OpenGLMobject | Point3D, + direction: Vector3D = RIGHT, + buff: float = DEFAULT_MOBJECT_TO_MOBJECT_BUFFER, + aligned_edge: Vector3D = ORIGIN, + submobject_to_align: OpenGLMobject | None = None, + index_of_submobject_to_align: int | None = None, + coor_mask: Point3D = np.array([1, 1, 1]), + ) -> Self: """Move this :class:`~.OpenGLMobject` next to another's :class:`~.OpenGLMobject` or coordinate. Examples @@ -1794,7 +1850,7 @@ def construct(self): self.shift((target_point - point_to_align + buff * direction) * coor_mask) return self - def shift_onto_screen(self, **kwargs): + def shift_onto_screen(self, **kwargs) -> Self: space_lengths = [config["frame_x_radius"], config["frame_y_radius"]] for vect in UP, DOWN, LEFT, RIGHT: dim = np.argmax(np.abs(vect)) @@ -1805,7 +1861,7 @@ def shift_onto_screen(self, **kwargs): self.to_edge(vect, **kwargs) return self - def is_off_screen(self): + def is_off_screen(self) -> bool: if self.get_left()[0] > config.frame_x_radius: return True if self.get_right()[0] < config.frame_x_radius: @@ -1816,10 +1872,12 @@ def is_off_screen(self): return True return False - def stretch_about_point(self, factor, dim, point): + def stretch_about_point(self, factor: float, dim: int, point: Point3D) -> Self: return self.stretch(factor, dim, about_point=point) - def rescale_to_fit(self, length, dim, stretch=False, **kwargs): + def rescale_to_fit( + self, length: float, dim: int, stretch: bool = False, **kwargs + ) -> Self: old_length = self.length_over_dim(dim) if old_length == 0: return self @@ -1829,7 +1887,7 @@ def rescale_to_fit(self, length, dim, stretch=False, **kwargs): self.scale(length / old_length, **kwargs) return self - def stretch_to_fit_width(self, width, **kwargs): + def stretch_to_fit_width(self, width: float, **kwargs) -> Self: """Stretches the :class:`~.OpenGLMobject` to fit a width, not keeping height/depth proportional. Returns @@ -1854,15 +1912,15 @@ def stretch_to_fit_width(self, width, **kwargs): """ return self.rescale_to_fit(width, 0, stretch=True, **kwargs) - def stretch_to_fit_height(self, height, **kwargs): + def stretch_to_fit_height(self, height: float, **kwargs) -> Self: """Stretches the :class:`~.OpenGLMobject` to fit a height, not keeping width/height proportional.""" return self.rescale_to_fit(height, 1, stretch=True, **kwargs) - def stretch_to_fit_depth(self, depth, **kwargs): + def stretch_to_fit_depth(self, depth: float, **kwargs) -> Self: """Stretches the :class:`~.OpenGLMobject` to fit a depth, not keeping width/height proportional.""" return self.rescale_to_fit(depth, 1, stretch=True, **kwargs) - def set_width(self, width, stretch=False, **kwargs): + def set_width(self, width: float, stretch: bool = False, **kwargs) -> Self: """Scales the :class:`~.OpenGLMobject` to fit a width while keeping height/depth proportional. Returns @@ -1889,38 +1947,38 @@ def set_width(self, width, stretch=False, **kwargs): scale_to_fit_width = set_width - def set_height(self, height, stretch=False, **kwargs): + def set_height(self, height: float, stretch: bool = False, **kwargs) -> Self: """Scales the :class:`~.OpenGLMobject` to fit a height while keeping width/depth proportional.""" return self.rescale_to_fit(height, 1, stretch=stretch, **kwargs) scale_to_fit_height = set_height - def set_depth(self, depth, stretch=False, **kwargs): + def set_depth(self, depth: float, stretch: bool = False, **kwargs): """Scales the :class:`~.OpenGLMobject` to fit a depth while keeping width/height proportional.""" return self.rescale_to_fit(depth, 2, stretch=stretch, **kwargs) scale_to_fit_depth = set_depth - def set_coord(self, value, dim, direction=ORIGIN): + def set_coord(self, value: float, dim: int, direction: Vector3D = ORIGIN) -> Self: curr = self.get_coord(dim, direction) shift_vect = np.zeros(self.dim) shift_vect[dim] = value - curr self.shift(shift_vect) return self - def set_x(self, x, direction=ORIGIN): + def set_x(self, x: float, direction: Vector3D = ORIGIN) -> Self: """Set x value of the center of the :class:`~.OpenGLMobject` (``int`` or ``float``)""" return self.set_coord(x, 0, direction) - def set_y(self, y, direction=ORIGIN): + def set_y(self, y: float, direction: Vector3D = ORIGIN) -> Self: """Set y value of the center of the :class:`~.OpenGLMobject` (``int`` or ``float``)""" return self.set_coord(y, 1, direction) - def set_z(self, z, direction=ORIGIN): + def set_z(self, z: float, direction: Vector3D = ORIGIN) -> Self: """Set z value of the center of the :class:`~.OpenGLMobject` (``int`` or ``float``)""" return self.set_coord(z, 2, direction) - def space_out_submobjects(self, factor=1.5, **kwargs): + def space_out_submobjects(self, factor: float = 1.5, **kwargs) -> Self: self.scale(factor, **kwargs) for submob in self.submobjects: submob.scale(1.0 / factor) @@ -1928,10 +1986,10 @@ def space_out_submobjects(self, factor=1.5, **kwargs): def move_to( self, - point_or_mobject, - aligned_edge=ORIGIN, - coor_mask=np.array([1, 1, 1]), - ): + point_or_mobject: Point3D | OpenGLMobject, + aligned_edge: Vector3D = ORIGIN, + coor_mask: Point3D = np.array([1, 1, 1]), + ) -> Self: """Move center of the :class:`~.OpenGLMobject` to certain coordinate.""" if isinstance(point_or_mobject, OpenGLMobject): target = point_or_mobject.get_bounding_box_point(aligned_edge) @@ -1941,7 +1999,12 @@ def move_to( self.shift((target - point_to_align) * coor_mask) return self - def replace(self, mobject, dim_to_match=0, stretch=False): + def replace( + self, + mobject: OpenGLMobject, + dim_to_match: int = 0, + stretch: bool = False, + ) -> Self: if not mobject.get_num_points() and not mobject.submobjects: self.scale(0) return self @@ -1963,13 +2026,13 @@ def surround( dim_to_match: int = 0, stretch: bool = False, buff: float = MED_SMALL_BUFF, - ): + ) -> Self: self.replace(mobject, dim_to_match, stretch) length = mobject.length_over_dim(dim_to_match) self.scale((length + buff) / length) return self - def put_start_and_end_on(self, start, end): + def put_start_and_end_on(self, start: Point3D, end: Point3D) -> Self: curr_start, curr_end = self.get_start_and_end() curr_vect = curr_end - curr_start if np.all(curr_vect == 0): @@ -1994,7 +2057,13 @@ def put_start_and_end_on(self, start, end): # Color functions - def set_rgba_array(self, color=None, opacity=None, name="rgbas", recurse=True): + def set_rgba_array( + self, + color: ParsableManimColor | Iterable[ParsableManimColor] | None = None, + opacity: float | Iterable[float] | None = None, + name: str = "rgbas", + recurse: bool = True, + ) -> Self: if color is not None: rgbs = np.array([color_to_rgb(c) for c in listify(color)]) if opacity is not None: @@ -2024,7 +2093,12 @@ def set_rgba_array(self, color=None, opacity=None, name="rgbas", recurse=True): mob.data[name] = rgbas.copy() return self - def set_rgba_array_direct(self, rgbas: np.ndarray, name="rgbas", recurse=True): + def set_rgba_array_direct( + self, + rgbas: npt.NDArray[RGBA_Array_Float], + name: str = "rgbas", + recurse: bool = True, + ) -> Self: """Directly set rgba data from `rgbas` and optionally do the same recursively with submobjects. This can be used if the `rgbas` have already been generated with the correct shape and simply need to be set. @@ -2041,7 +2115,12 @@ def set_rgba_array_direct(self, rgbas: np.ndarray, name="rgbas", recurse=True): for mob in self.get_family(recurse): mob.data[name] = rgbas.copy() - def set_color(self, color: ParsableManimColor | None, opacity=None, recurse=True): + def set_color( + self, + color: ParsableManimColor | Iterable[ParsableManimColor] | None, + opacity: float | Iterable[float] | None = None, + recurse: bool = True, + ) -> Self: self.set_rgba_array(color, opacity, recurse=False) # Recurse to submobjects differently from how set_rgba_array # in case they implement set_color differently @@ -2054,24 +2133,25 @@ def set_color(self, color: ParsableManimColor | None, opacity=None, recurse=True submob.set_color(color, recurse=True) return self - def set_opacity(self, opacity, recurse=True): + def set_opacity( + self, opacity: float | Iterable[float] | None, recurse: bool = True + ) -> Self: self.set_rgba_array(color=None, opacity=opacity, recurse=False) if recurse: for submob in self.submobjects: submob.set_opacity(opacity, recurse=True) return self - def get_color(self): + def get_color(self) -> str: return rgb_to_hex(self.rgbas[0, :3]) - def get_opacity(self): + def get_opacity(self) -> float: return self.rgbas[0, 3] - def set_color_by_gradient(self, *colors): - self.set_submobject_colors_by_gradient(*colors) - return self + def set_color_by_gradient(self, *colors: ParsableManimColor) -> Self: + return self.set_submobject_colors_by_gradient(*colors) - def set_submobject_colors_by_gradient(self, *colors): + def set_submobject_colors_by_gradient(self, *colors: ParsableManimColor) -> Self: if len(colors) == 0: raise Exception("Need at least one color") elif len(colors) == 1: @@ -2085,21 +2165,21 @@ def set_submobject_colors_by_gradient(self, *colors): mob.set_color(color) return self - def fade(self, darkness=0.5, recurse=True): - self.set_opacity(1.0 - darkness, recurse=recurse) + def fade(self, darkness: float = 0.5, recurse: bool = True) -> Self: + return self.set_opacity(1.0 - darkness, recurse=recurse) - def get_gloss(self): + def get_gloss(self) -> float: return self.gloss - def set_gloss(self, gloss, recurse=True): + def set_gloss(self, gloss: float, recurse: bool = True) -> Self: for mob in self.get_family(recurse): mob.gloss = gloss return self - def get_shadow(self): + def get_shadow(self) -> float: return self.shadow - def set_shadow(self, shadow, recurse=True): + def set_shadow(self, shadow: float, recurse: bool = True) -> Self: for mob in self.get_family(recurse): mob.shadow = shadow return self @@ -2107,8 +2187,11 @@ def set_shadow(self, shadow, recurse=True): # Background rectangle def add_background_rectangle( - self, color: ParsableManimColor | None = None, opacity: float = 0.75, **kwargs - ): + self, + color: ParsableManimColor | None = None, + opacity: float = 0.75, + **kwargs, + ) -> Self: # TODO, this does not behave well when the mobject has points, # since it gets displayed on top """Add a BackgroundRectangle as submobject. @@ -2146,39 +2229,39 @@ def add_background_rectangle( self.add_to_back(self.background_rectangle) return self - def add_background_rectangle_to_submobjects(self, **kwargs): + def add_background_rectangle_to_submobjects(self, **kwargs) -> Self: for submobject in self.submobjects: submobject.add_background_rectangle(**kwargs) return self - def add_background_rectangle_to_family_members_with_points(self, **kwargs): + def add_background_rectangle_to_family_members_with_points(self, **kwargs) -> Self: for mob in self.family_members_with_points(): mob.add_background_rectangle(**kwargs) return self # Getters - def get_bounding_box_point(self, direction): + def get_bounding_box_point(self, direction: Vector3D) -> Point3D: bb = self.get_bounding_box() indices = (np.sign(direction) + 1).astype(int) return np.array([bb[indices[i]][i] for i in range(3)]) - def get_edge_center(self, direction) -> np.ndarray: + def get_edge_center(self, direction: Vector3D) -> Point3D: """Get edge coordinates for certain direction.""" return self.get_bounding_box_point(direction) - def get_corner(self, direction) -> np.ndarray: + def get_corner(self, direction: Vector3D) -> Point3D: """Get corner coordinates for certain direction.""" return self.get_bounding_box_point(direction) - def get_center(self) -> np.ndarray: + def get_center(self) -> Point3D: """Get center coordinates.""" return self.get_bounding_box()[1] - def get_center_of_mass(self): + def get_center_of_mass(self) -> Point3D: return self.get_all_points().mean(0) - def get_boundary_point(self, direction): + def get_boundary_point(self, direction: Vector3D) -> Point3D: all_points = self.get_all_points() boundary_directions = all_points - self.get_center() norms = np.linalg.norm(boundary_directions, axis=1) @@ -2186,7 +2269,7 @@ def get_boundary_point(self, direction): index = np.argmax(np.dot(boundary_directions, np.array(direction).T)) return all_points[index] - def get_continuous_bounding_box_point(self, direction): + def get_continuous_bounding_box_point(self, direction: Vector3D) -> Point3D: dl, center, ur = self.get_bounding_box() corner_vect = ur - center return center + direction / np.max( @@ -2200,86 +2283,86 @@ def get_continuous_bounding_box_point(self, direction): ), ) - def get_top(self) -> np.ndarray: + def get_top(self) -> Point3D: """Get top coordinates of a box bounding the :class:`~.OpenGLMobject`""" return self.get_edge_center(UP) - def get_bottom(self) -> np.ndarray: + def get_bottom(self) -> Point3D: """Get bottom coordinates of a box bounding the :class:`~.OpenGLMobject`""" return self.get_edge_center(DOWN) - def get_right(self) -> np.ndarray: + def get_right(self) -> Point3D: """Get right coordinates of a box bounding the :class:`~.OpenGLMobject`""" return self.get_edge_center(RIGHT) - def get_left(self) -> np.ndarray: + def get_left(self) -> Point3D: """Get left coordinates of a box bounding the :class:`~.OpenGLMobject`""" return self.get_edge_center(LEFT) - def get_zenith(self) -> np.ndarray: + def get_zenith(self) -> Point3D: """Get zenith coordinates of a box bounding a 3D :class:`~.OpenGLMobject`.""" return self.get_edge_center(OUT) - def get_nadir(self) -> np.ndarray: + def get_nadir(self) -> Point3D: """Get nadir (opposite the zenith) coordinates of a box bounding a 3D :class:`~.OpenGLMobject`.""" return self.get_edge_center(IN) - def length_over_dim(self, dim): + def length_over_dim(self, dim: int) -> float: bb = self.get_bounding_box() return abs((bb[2] - bb[0])[dim]) - def get_width(self): + def get_width(self) -> float: """Returns the width of the mobject.""" return self.length_over_dim(0) - def get_height(self): + def get_height(self) -> float: """Returns the height of the mobject.""" return self.length_over_dim(1) - def get_depth(self): + def get_depth(self) -> float: """Returns the depth of the mobject.""" return self.length_over_dim(2) - def get_coord(self, dim: int, direction=ORIGIN): + def get_coord(self, dim: int, direction: Vector3D = ORIGIN) -> ManimFloat: """Meant to generalize ``get_x``, ``get_y`` and ``get_z``""" return self.get_bounding_box_point(direction)[dim] - def get_x(self, direction=ORIGIN) -> np.float64: + def get_x(self, direction: Vector3D = ORIGIN) -> ManimFloat: """Returns x coordinate of the center of the :class:`~.OpenGLMobject` as ``float``""" return self.get_coord(0, direction) - def get_y(self, direction=ORIGIN) -> np.float64: + def get_y(self, direction: Vector3D = ORIGIN) -> ManimFloat: """Returns y coordinate of the center of the :class:`~.OpenGLMobject` as ``float``""" return self.get_coord(1, direction) - def get_z(self, direction=ORIGIN) -> np.float64: + def get_z(self, direction: Vector3D = ORIGIN) -> ManimFloat: """Returns z coordinate of the center of the :class:`~.OpenGLMobject` as ``float``""" return self.get_coord(2, direction) - def get_start(self): + def get_start(self) -> Point3D: """Returns the point, where the stroke that surrounds the :class:`~.OpenGLMobject` starts.""" self.throw_error_if_no_points() return np.array(self.points[0]) - def get_end(self): + def get_end(self) -> Point3D: """Returns the point, where the stroke that surrounds the :class:`~.OpenGLMobject` ends.""" self.throw_error_if_no_points() return np.array(self.points[-1]) - def get_start_and_end(self): + def get_start_and_end(self) -> tuple[Point3D, Point3D]: """Returns starting and ending point of a stroke as a ``tuple``.""" return self.get_start(), self.get_end() - def point_from_proportion(self, alpha): + def point_from_proportion(self, alpha: float) -> Point3D: points = self.points i, subalpha = integer_interpolate(0, len(points) - 1, alpha) return interpolate(points[i], points[i + 1], subalpha) - def pfp(self, alpha): + def pfp(self, alpha: float) -> Point3D: """Abbreviation for point_from_proportion""" return self.point_from_proportion(alpha) - def get_pieces(self, n_pieces): + def get_pieces(self, n_pieces: int) -> OpenGLMobject: template = self.copy() template.submobjects = [] alphas = np.linspace(0, 1, n_pieces + 1) @@ -2290,34 +2373,36 @@ def get_pieces(self, n_pieces): ) ) - def get_z_index_reference_point(self): + def get_z_index_reference_point(self) -> Point3D: # TODO, better place to define default z_index_group? z_index_group = getattr(self, "z_index_group", self) return z_index_group.get_center() # Match other mobject properties - def match_color(self, mobject: OpenGLMobject): + def match_color(self, mobject: OpenGLMobject) -> Self: """Match the color with the color of another :class:`~.OpenGLMobject`.""" return self.set_color(mobject.get_color()) - def match_dim_size(self, mobject: OpenGLMobject, dim, **kwargs): + def match_dim_size(self, mobject: OpenGLMobject, dim: int, **kwargs) -> Self: """Match the specified dimension with the dimension of another :class:`~.OpenGLMobject`.""" return self.rescale_to_fit(mobject.length_over_dim(dim), dim, **kwargs) - def match_width(self, mobject: OpenGLMobject, **kwargs): + def match_width(self, mobject: OpenGLMobject, **kwargs) -> Self: """Match the width with the width of another :class:`~.OpenGLMobject`.""" return self.match_dim_size(mobject, 0, **kwargs) - def match_height(self, mobject: OpenGLMobject, **kwargs): + def match_height(self, mobject: OpenGLMobject, **kwargs) -> Self: """Match the height with the height of another :class:`~.OpenGLMobject`.""" return self.match_dim_size(mobject, 1, **kwargs) - def match_depth(self, mobject: OpenGLMobject, **kwargs): + def match_depth(self, mobject: OpenGLMobject, **kwargs) -> Self: """Match the depth with the depth of another :class:`~.OpenGLMobject`.""" return self.match_dim_size(mobject, 2, **kwargs) - def match_coord(self, mobject: OpenGLMobject, dim, direction=ORIGIN): + def match_coord( + self, mobject: OpenGLMobject, dim: int, direction: Vector3D = ORIGIN + ) -> Self: """Match the coordinates with the coordinates of another :class:`~.OpenGLMobject`.""" return self.set_coord( mobject.get_coord(dim, direction), @@ -2325,23 +2410,23 @@ def match_coord(self, mobject: OpenGLMobject, dim, direction=ORIGIN): direction=direction, ) - def match_x(self, mobject, direction=ORIGIN): + def match_x(self, mobject: OpenGLMobject, direction: Vector3D = ORIGIN) -> Self: """Match x coord. to the x coord. of another :class:`~.OpenGLMobject`.""" return self.match_coord(mobject, 0, direction) - def match_y(self, mobject, direction=ORIGIN): + def match_y(self, mobject: OpenGLMobject, direction: Vector3D = ORIGIN) -> Self: """Match y coord. to the x coord. of another :class:`~.OpenGLMobject`.""" return self.match_coord(mobject, 1, direction) - def match_z(self, mobject, direction=ORIGIN): + def match_z(self, mobject: OpenGLMobject, direction: Vector3D = ORIGIN) -> Self: """Match z coord. to the x coord. of another :class:`~.OpenGLMobject`.""" return self.match_coord(mobject, 2, direction) def align_to( self, - mobject_or_point: OpenGLMobject | Sequence[float], - direction=ORIGIN, - ): + mobject_or_point: OpenGLMobject | Point3D, + direction: Vector3D = ORIGIN, + ) -> Self: """ Examples: mob1.align_to(mob2, UP) moves mob1 vertically so that its @@ -2361,21 +2446,22 @@ def align_to( self.set_coord(point[dim], dim, direction) return self - def get_group_class(self): + def get_group_class(self) -> type[OpenGLGroup]: return OpenGLGroup @staticmethod - def get_mobject_type_class(): + def get_mobject_type_class() -> type[OpenGLMobject]: """Return the base class of this mobject type.""" return OpenGLMobject # Alignment - def align_data_and_family(self, mobject): + def align_data_and_family(self, mobject: OpenGLMobject) -> Self: self.align_family(mobject) self.align_data(mobject) + return self - def align_data(self, mobject): + def align_data(self, mobject: OpenGLMobject) -> Self: # In case any data arrays get resized when aligned to shader data # self.refresh_shader_data() for mob1, mob2 in zip(self.get_family(), mobject.get_family()): @@ -2391,14 +2477,15 @@ def align_data(self, mobject): mob1.data[key] = resize_preserving_order(arr1, len(arr2)) elif len(arr1) > len(arr2): mob2.data[key] = resize_preserving_order(arr2, len(arr1)) + return self - def align_points(self, mobject): + def align_points(self, mobject: OpenGLMobject) -> Self: max_len = max(self.get_num_points(), mobject.get_num_points()) for mob in (self, mobject): mob.resize_points(max_len, resize_func=resize_preserving_order) return self - def align_family(self, mobject): + def align_family(self, mobject: OpenGLMobject) -> Self: mob1 = self mob2 = mobject n1 = len(mob1) @@ -2411,14 +2498,14 @@ def align_family(self, mobject): sm1.align_family(sm2) return self - def push_self_into_submobjects(self): + def push_self_into_submobjects(self) -> Self: copy = self.deepcopy() copy.submobjects = [] self.resize_points(0) self.add(copy) return self - def add_n_more_submobjects(self, n): + def add_n_more_submobjects(self, n: int) -> Self: if n == 0: return self @@ -2447,7 +2534,13 @@ def add_n_more_submobjects(self, n): # Interpolate - def interpolate(self, mobject1, mobject2, alpha, path_func=straight_path()): + def interpolate( + self, + mobject1: OpenGLMobject, + mobject2: OpenGLMobject, + alpha: float, + path_func: PathFuncType = straight_path(), + ) -> Self: """Turns this :class:`~.OpenGLMobject` into an interpolation between ``mobject1`` and ``mobject2``. @@ -2500,7 +2593,9 @@ def construct(self): ) return self - def pointwise_become_partial(self, mobject, a, b): + def pointwise_become_partial( + self, mobject: OpenGLMobject, a: float, b: float + ) -> None: """ Set points in such a way as to become only part of mobject. @@ -2517,7 +2612,7 @@ def become( match_depth: bool = False, match_center: bool = False, stretch: bool = False, - ): + ) -> Self: """Edit all data and submobjects to be identical to another :class:`~.OpenGLMobject` @@ -2577,7 +2672,7 @@ def construct(self): # Locking data - def lock_data(self, keys): + def lock_data(self, keys: Iterable[str]) -> None: """ To speed up some animations, particularly transformations, it can be handy to acknowledge which pieces of data @@ -2591,7 +2686,9 @@ def lock_data(self, keys): self.refresh_shader_data() self.locked_data_keys = set(keys) - def lock_matching_data(self, mobject1, mobject2): + def lock_matching_data( + self, mobject1: OpenGLMobject, mobject2: OpenGLMobject + ) -> Self: for sm, sm1, sm2 in zip( self.get_family(), mobject1.get_family(), @@ -2608,57 +2705,57 @@ def lock_matching_data(self, mobject1, mobject2): ) return self - def unlock_data(self): + def unlock_data(self) -> None: for mob in self.get_family(): mob.locked_data_keys = set() # Operations touching shader uniforms @affects_shader_info_id - def fix_in_frame(self): + def fix_in_frame(self) -> Self: self.is_fixed_in_frame = 1.0 return self @affects_shader_info_id - def fix_orientation(self): + def fix_orientation(self) -> Self: self.is_fixed_orientation = 1.0 self.fixed_orientation_center = tuple(self.get_center()) self.depth_test = True return self @affects_shader_info_id - def unfix_from_frame(self): + def unfix_from_frame(self) -> Self: self.is_fixed_in_frame = 0.0 return self @affects_shader_info_id - def unfix_orientation(self): + def unfix_orientation(self) -> Self: self.is_fixed_orientation = 0.0 self.fixed_orientation_center = (0, 0, 0) self.depth_test = False return self @affects_shader_info_id - def apply_depth_test(self): + def apply_depth_test(self) -> Self: self.depth_test = True return self @affects_shader_info_id - def deactivate_depth_test(self): + def deactivate_depth_test(self) -> Self: self.depth_test = False return self # Shader code manipulation - def replace_shader_code(self, old, new): + def replace_shader_code(self, old_code: str, new_code: str) -> Self: # TODO, will this work with VMobject structure, given # that it does not simpler return shader_wrappers of # family? for wrapper in self.get_shader_wrapper_list(): - wrapper.replace_code(old, new) + wrapper.replace_code(old_code, new_code) return self - def set_color_by_code(self, glsl_code): + def set_color_by_code(self, glsl_code: str) -> Self: """ Takes a snippet of code and inserts it into a context which has the following variables: @@ -2670,11 +2767,11 @@ def set_color_by_code(self, glsl_code): def set_color_by_xyz_func( self, - glsl_snippet, - min_value=-5.0, - max_value=5.0, - colormap="viridis", - ): + glsl_snippet: str, + min_value: float = -5.0, + max_value: float = 5.0, + colormap: str = "viridis", + ) -> Self: """ Pass in a glsl expression in terms of x, y and z which returns a float. @@ -2685,22 +2782,17 @@ def set_color_by_xyz_func( glsl_snippet = glsl_snippet.replace(char, "point." + char) rgb_list = get_colormap_list(colormap) self.set_color_by_code( - "color.rgb = float_to_color({}, {}, {}, {});".format( - glsl_snippet, - float(min_value), - float(max_value), - get_colormap_code(rgb_list), - ), + f"color.rgb = float_to_color({glsl_snippet}, {float(min_value)}, {float(max_value)}, {get_colormap_code(rgb_list)});", ) return self # For shader data - def refresh_shader_wrapper_id(self): + def refresh_shader_wrapper_id(self) -> Self: self.get_shader_wrapper().refresh_id() return self - def get_shader_wrapper(self): + def get_shader_wrapper(self) -> ShaderWrapper: from manim.renderer.shader_wrapper import ShaderWrapper # if hasattr(self, "__shader_wrapper"): @@ -2717,7 +2809,7 @@ def get_shader_wrapper(self): ) return self.shader_wrapper - def get_shader_wrapper_list(self): + def get_shader_wrapper_list(self) -> Sequence[ShaderWrapper]: shader_wrappers = it.chain( [self.get_shader_wrapper()], *(sm.get_shader_wrapper_list() for sm in self.submobjects), @@ -2734,7 +2826,7 @@ def get_shader_wrapper_list(self): result.append(shader_wrapper) return result - def check_data_alignment(self, array, data_key): + def check_data_alignment(self, array: npt.NDArray, data_key: str) -> Self: # Makes sure that self.data[key] can be broadcast into # the given array, meaning its length has to be either 1 # or the length of the array @@ -2746,45 +2838,50 @@ def check_data_alignment(self, array, data_key): ) return self - def get_resized_shader_data_array(self, length): + def get_resized_shader_data_array(self, length: float) -> npt.NDArray: # If possible, try to populate an existing array, rather # than recreating it each frame points = self.points shader_data = np.zeros(len(points), dtype=self.shader_dtype) return shader_data - def read_data_to_shader(self, shader_data, shader_data_key, data_key): + def read_data_to_shader( + self, + shader_data: npt.NDArray, # has structured data type, ex. ("point", np.float32, (3,)) + shader_data_key: str, + data_key: str, + ) -> None: if data_key in self.locked_data_keys: return self.check_data_alignment(shader_data, data_key) shader_data[shader_data_key] = self.data[data_key] - def get_shader_data(self): + def get_shader_data(self) -> npt.NDArray: shader_data = self.get_resized_shader_data_array(self.get_num_points()) self.read_data_to_shader(shader_data, "point", "points") return shader_data - def refresh_shader_data(self): + def refresh_shader_data(self) -> None: self.get_shader_data() - def get_shader_uniforms(self): + def get_shader_uniforms(self) -> dict[str, Any]: return self.uniforms - def get_shader_vert_indices(self): + def get_shader_vert_indices(self) -> Sequence[int]: return self.shader_indices @property - def submobjects(self): + def submobjects(self) -> Sequence[OpenGLMobject]: return self._submobjects if hasattr(self, "_submobjects") else [] @submobjects.setter - def submobjects(self, submobject_list): + def submobjects(self, submobject_list: Iterable[OpenGLMobject]) -> None: self.remove(*self.submobjects) self.add(*submobject_list) # Errors - def throw_error_if_no_points(self): + def throw_error_if_no_points(self) -> None: if not self.has_points(): message = ( "Cannot call OpenGLMobject.{} " + "for a OpenGLMobject with no points" @@ -2794,38 +2891,42 @@ def throw_error_if_no_points(self): class OpenGLGroup(OpenGLMobject): - def __init__(self, *mobjects, **kwargs): + def __init__(self, *mobjects: OpenGLMobject, **kwargs): super().__init__(**kwargs) self.add(*mobjects) class OpenGLPoint(OpenGLMobject): def __init__( - self, location=ORIGIN, artificial_width=1e-6, artificial_height=1e-6, **kwargs + self, + location: Point3D = ORIGIN, + artificial_width: float = 1e-6, + artificial_height: float = 1e-6, + **kwargs, ): self.artificial_width = artificial_width self.artificial_height = artificial_height super().__init__(**kwargs) self.set_location(location) - def get_width(self): + def get_width(self) -> float: return self.artificial_width - def get_height(self): + def get_height(self) -> float: return self.artificial_height - def get_location(self): + def get_location(self) -> Point3D: return self.points[0].copy() - def get_bounding_box_point(self, *args, **kwargs): + def get_bounding_box_point(self, *args, **kwargs) -> Point3D: return self.get_location() - def set_location(self, new_loc): + def set_location(self, new_loc: Point3D) -> None: self.set_points(np.array(new_loc, ndmin=2, dtype=float)) class _AnimationBuilder: - def __init__(self, mobject): + def __init__(self, mobject: OpenGLMobject): self.mobject = mobject self.mobject.generate_target() @@ -2837,7 +2938,7 @@ def __init__(self, mobject): self.cannot_pass_args = False self.anim_args = {} - def __call__(self, **kwargs): + def __call__(self, **kwargs) -> Self: if self.cannot_pass_args: raise ValueError( "Animation arguments must be passed before accessing methods and can only be passed once", @@ -2848,7 +2949,7 @@ def __call__(self, **kwargs): return self - def __getattr__(self, method_name): + def __getattr__(self, method_name: str) -> Callable[..., Self]: method = getattr(self.mobject.target, method_name) has_overridden_animation = hasattr(method, "_override_animate") @@ -2876,7 +2977,7 @@ def update_target(*method_args, **method_kwargs): return update_target - def build(self): + def build(self) -> _MethodAnimation: from manim.animation.transform import _MethodAnimation if self.overridden_animation: @@ -2890,7 +2991,7 @@ def build(self): return anim -def override_animate(method): +def override_animate(method: types.FunctionType) -> types.FunctionType: r"""Decorator for overriding method animations. This allows to specify a method (returning an :class:`~.Animation`) diff --git a/manim/mobject/opengl/opengl_surface.py b/manim/mobject/opengl/opengl_surface.py index bcf608909f..565b8c71cf 100644 --- a/manim/mobject/opengl/opengl_surface.py +++ b/manim/mobject/opengl/opengl_surface.py @@ -12,7 +12,6 @@ from manim.utils.bezier import integer_interpolate, interpolate from manim.utils.color import * from manim.utils.config_ops import _Data, _Uniforms -from manim.utils.deprecation import deprecated from manim.utils.images import change_to_rgba_array, get_full_raster_image_path from manim.utils.iterables import listify from manim.utils.space_ops import normalize_along_axis diff --git a/manim/mobject/opengl/opengl_vectorized_mobject.py b/manim/mobject/opengl/opengl_vectorized_mobject.py index b7b5d4ce54..8037760c4a 100644 --- a/manim/mobject/opengl/opengl_vectorized_mobject.py +++ b/manim/mobject/opengl/opengl_vectorized_mobject.py @@ -324,6 +324,7 @@ def match_style(self, vmobject, recurse=True): vmobject_style = vmobject.get_style() if config.renderer == RendererType.OPENGL: vmobject_style["stroke_width"] = vmobject_style["stroke_width"][0][0] + vmobject_style["fill_opacity"] = self.get_fill_opacity() self.set_style(**vmobject_style, recurse=False) if recurse: # Does its best to match up submobject lists, and @@ -405,7 +406,7 @@ def get_stroke_opacity(self): return self.get_stroke_opacities()[0] def get_color(self): - if self.has_stroke(): + if not self.has_fill(): return self.get_stroke_color() return self.get_fill_color() @@ -1223,8 +1224,8 @@ def get_nth_subpath(path_list, n): return path for n in range(n_subpaths): - sp1 = get_nth_subpath(subpaths1, n) - sp2 = get_nth_subpath(subpaths2, n) + sp1 = np.asarray(get_nth_subpath(subpaths1, n)) + sp2 = np.asarray(get_nth_subpath(subpaths2, n)) diff1 = max(0, (len(sp2) - len(sp1)) // nppc) diff2 = max(0, (len(sp1) - len(sp2)) // nppc) sp1 = self.insert_n_curves_to_point_list(diff1, sp1) diff --git a/manim/mobject/text/code_mobject.py b/manim/mobject/text/code_mobject.py index e85f3bf0ba..999ab3c90e 100644 --- a/manim/mobject/text/code_mobject.py +++ b/manim/mobject/text/code_mobject.py @@ -12,12 +12,11 @@ from pathlib import Path import numpy as np -from pygments import highlight +from pygments import highlight, styles from pygments.formatters.html import HtmlFormatter from pygments.lexers import get_lexer_by_name, guess_lexer_for_filename -from pygments.styles import get_all_styles -from manim import logger +# from pygments.styles import get_all_styles from manim.constants import * from manim.mobject.geometry.arc import Dot from manim.mobject.geometry.polygram import RoundedRectangle @@ -26,8 +25,6 @@ from manim.mobject.types.vectorized_mobject import VGroup from manim.utils.color import WHITE -__all__ = ["Code"] - class Code(VGroup): """A highlighted source code listing. @@ -64,7 +61,7 @@ class Code(VGroup): background_stroke_width=1, background_stroke_color=WHITE, insert_line_no=True, - style=Code.styles_list[15], + style="emacs", background="window", language="cpp", ) @@ -128,7 +125,9 @@ def construct(self): line_no_buff Defines the spacing between line numbers and displayed code. Defaults to 0.4. style - Defines the style type of displayed code. You can see possible names of styles in with :attr:`styles_list`. Defaults to ``"vim"``. + Defines the style type of displayed code. To see a list possible + names of styles call :meth:`get_styles_list`. + Defaults to ``"vim"``. language Specifies the programming language the given code was written in. If ``None`` (the default), the language will be automatically detected. For the list of @@ -157,7 +156,7 @@ def construct(self): # For more information about pygments.lexers visit https://pygments.org/docs/lexers/ # from pygments.lexers import get_all_lexers # all_lexers = get_all_lexers() - styles_list = list(get_all_styles()) + _styles_list_cache: list[str] | None = None # For more information about pygments.styles visit https://pygments.org/docs/styles/ def __init__( @@ -289,6 +288,20 @@ def __init__( ) self.move_to(np.array([0, 0, 0])) + @classmethod + def get_styles_list(cls): + """Get list of available code styles. + + Returns + ------- + list[str] + The list of available code styles to use for the ``styles`` + argument. + """ + if cls._styles_list_cache is None: + cls._styles_list_cache = list(styles.get_all_styles()) + return cls._styles_list_cache + def _ensure_valid_file(self): """Function to validate file.""" if self.file_name is None: diff --git a/manim/mobject/text/tex_mobject.py b/manim/mobject/text/tex_mobject.py index 449b21d385..c6810d8f65 100644 --- a/manim/mobject/text/tex_mobject.py +++ b/manim/mobject/text/tex_mobject.py @@ -175,8 +175,8 @@ def _modify_special_strings(self, tex): tex = self._remove_stray_braces(tex) for context in ["array"]: - begin_in = ("\\begin{%s}" % context) in tex - end_in = ("\\end{%s}" % context) in tex + begin_in = ("\\begin{%s}" % context) in tex # noqa: UP031 + end_in = ("\\end{%s}" % context) in tex # noqa: UP031 if begin_in ^ end_in: # Just turn this into a blank string, # which means caller should leave a diff --git a/manim/mobject/text/text_mobject.py b/manim/mobject/text/text_mobject.py index 39a1cdd171..6f7b4d9d9f 100644 --- a/manim/mobject/text/text_mobject.py +++ b/manim/mobject/text/text_mobject.py @@ -56,7 +56,6 @@ def construct(self): import copy import hashlib -import os import re from collections.abc import Iterable, Sequence from contextlib import contextmanager diff --git a/manim/mobject/three_d/three_dimensions.py b/manim/mobject/three_d/three_dimensions.py index 540a99bfe9..bcdde3e188 100644 --- a/manim/mobject/three_d/three_dimensions.py +++ b/manim/mobject/three_d/three_dimensions.py @@ -32,16 +32,10 @@ from manim.mobject.mobject import * from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL from manim.mobject.opengl.opengl_mobject import OpenGLMobject -from manim.mobject.types.vectorized_mobject import VGroup, VMobject +from manim.mobject.types.vectorized_mobject import VectorizedPoint, VGroup, VMobject from manim.utils.color import ( - BLUE, - BLUE_D, - BLUE_E, - LIGHT_GREY, - WHITE, ManimColor, ParsableManimColor, - interpolate_color, ) from manim.utils.iterables import tuplify from manim.utils.space_ops import normalize, perpendicular_bisector, z_to_vector @@ -622,17 +616,18 @@ def __init__( **kwargs, ) # used for rotations + self.new_height = height self._current_theta = 0 self._current_phi = 0 - + self.base_circle = Circle( + radius=base_radius, + color=self.fill_color, + fill_opacity=self.fill_opacity, + stroke_width=0, + ) + self.base_circle.shift(height * IN) + self._set_start_and_end_attributes(direction) if show_base: - self.base_circle = Circle( - radius=base_radius, - color=self.fill_color, - fill_opacity=self.fill_opacity, - stroke_width=0, - ) - self.base_circle.shift(height * IN) self.add(self.base_circle) self._rotate_to_direction() @@ -662,6 +657,12 @@ def func(self, u: float, v: float) -> np.ndarray: ], ) + def get_start(self) -> np.ndarray: + return self.start_point.get_center() + + def get_end(self) -> np.ndarray: + return self.end_point.get_center() + def _rotate_to_direction(self) -> None: x, y, z = self.direction @@ -716,6 +717,15 @@ def get_direction(self) -> np.ndarray: """ return self.direction + def _set_start_and_end_attributes(self, direction): + normalized_direction = direction * np.linalg.norm(direction) + + start = self.base_circle.get_center() + end = start + normalized_direction * self.new_height + self.start_point = VectorizedPoint(start) + self.end_point = VectorizedPoint(end) + self.add(self.start_point, self.end_point) + class Cylinder(Surface): """A cylinder, defined by its height, radius and direction, @@ -1156,14 +1166,20 @@ def __init__( self.end - height * self.direction, **kwargs, ) - self.cone = Cone( - direction=self.direction, base_radius=base_radius, height=height, **kwargs + direction=self.direction, + base_radius=base_radius, + height=height, + **kwargs, ) self.cone.shift(end) - self.add(self.cone) + self.end_point = VectorizedPoint(end) + self.add(self.end_point, self.cone) self.set_color(color) + def get_end(self) -> np.ndarray: + return self.end_point.get_center() + class Torus(Surface): """A torus. diff --git a/manim/mobject/types/vectorized_mobject.py b/manim/mobject/types/vectorized_mobject.py index cc938065d3..50ac24b7a9 100644 --- a/manim/mobject/types/vectorized_mobject.py +++ b/manim/mobject/types/vectorized_mobject.py @@ -31,7 +31,7 @@ from manim.utils.bezier import ( bezier, bezier_remap, - get_smooth_handle_points, + get_smooth_cubic_bezier_handle_points, integer_interpolate, interpolate, partial_bezier_points, @@ -911,7 +911,6 @@ def add_line_to(self, point: Point3D) -> Self: :class:`VMobject` ``self`` """ - nppcc = self.n_points_per_cubic_curve self.add_cubic_bezier_curve_to( *( interpolate(self.get_last_point(), point, t) @@ -1060,7 +1059,6 @@ def construct(self): vmob.set_points_as_corners(corners).scale(2) self.add(vmob) """ - nppcc = self.n_points_per_cubic_curve points = np.array(points) # This will set the handles aligned with the anchors. # Id est, a bezier curve will be the segment from the two anchors such that the handles belongs to this segment. @@ -1095,7 +1093,7 @@ def change_anchor_mode(self, mode: Literal["jagged", "smooth"]) -> Self: # The append is needed as the last element is not reached when slicing with numpy. anchors = np.append(subpath[::nppcc], subpath[-1:], 0) if mode == "smooth": - h1, h2 = get_smooth_handle_points(anchors) + h1, h2 = get_smooth_cubic_bezier_handle_points(anchors) else: # mode == "jagged" # The following will make the handles aligned with the anchors, thus making the bezier curve a segment a1 = anchors[:-1] @@ -2310,7 +2308,7 @@ def remove(self, key: Hashable) -> Self: my_dict.remove("square") """ if key not in self.submob_dict: - raise KeyError("The given key '%s' is not present in the VDict" % str(key)) + raise KeyError(f"The given key '{key!s}' is not present in the VDict") super().remove(self.submob_dict[key]) del self.submob_dict[key] return self diff --git a/manim/renderer/opengl_renderer.py b/manim/renderer/opengl_renderer.py index 5a4d692657..8347f8a49e 100644 --- a/manim/renderer/opengl_renderer.py +++ b/manim/renderer/opengl_renderer.py @@ -1,7 +1,6 @@ from __future__ import annotations import itertools as it -import sys import time from functools import cached_property from typing import Any @@ -569,7 +568,7 @@ def pixel_coords_to_space_coords(self, px, py, relative=False, top_left=False): if pixel_shape is None: return np.array([0, 0, 0]) pw, ph = pixel_shape - fw, fh = config["frame_width"], config["frame_height"] + fh = config["frame_height"] fc = self.camera.get_center() if relative: return 2 * np.array([px / pw, py / ph, 0]) diff --git a/manim/scene/scene.py b/manim/scene/scene.py index 1f5faed4ec..3f80c91864 100644 --- a/manim/scene/scene.py +++ b/manim/scene/scene.py @@ -230,7 +230,7 @@ def render(self, preview: bool = False): self.construct() except EndSceneEarlyException: pass - except RerunSceneException as e: + except RerunSceneException: self.remove(*self.mobjects) self.renderer.clear_screen() self.renderer.num_plays = 0 diff --git a/manim/scene/vector_space_scene.py b/manim/scene/vector_space_scene.py index e9304115ef..d4b632cca8 100644 --- a/manim/scene/vector_space_scene.py +++ b/manim/scene/vector_space_scene.py @@ -297,7 +297,7 @@ def get_vector_label( """ if not isinstance(label, MathTex): if len(label) == 1: - label = "\\vec{\\textbf{%s}}" % label + label = "\\vec{\\textbf{%s}}" % label # noqa: UP031 label = MathTex(label) if color is None: color = vector.get_color() @@ -904,9 +904,8 @@ def add_transformable_label( if new_label: label_mob.target_text = new_label else: - label_mob.target_text = "{}({})".format( - transformation_name, - label_mob.get_tex_string(), + label_mob.target_text = ( + f"{transformation_name}({label_mob.get_tex_string()})" ) label_mob.vector = vector label_mob.kwargs = kwargs diff --git a/manim/utils/bezier.py b/manim/utils/bezier.py index 709f8b3bf4..c0bbcdf912 100644 --- a/manim/utils/bezier.py +++ b/manim/utils/bezier.py @@ -2,17 +2,6 @@ from __future__ import annotations -from manim.typing import ( - BezierPoints, - ColVector, - MatrixMN, - Point3D, - Point3D_Array, - PointDType, - QuadraticBezierPoints, - QuadraticBezierPoints_Array, -) - __all__ = [ "bezier", "partial_bezier_points", @@ -24,9 +13,7 @@ "mid", "inverse_interpolate", "match_interpolate", - "get_smooth_handle_points", "get_smooth_cubic_bezier_handle_points", - "diag_to_matrix", "is_closed", "proportions_along_bezier_curve_for_point", "point_lies_on_bezier", @@ -35,14 +22,27 @@ from collections.abc import Sequence from functools import reduce -from typing import Any, Callable, overload +from typing import TYPE_CHECKING, Any, Callable, overload import numpy as np -import numpy.typing as npt -from scipy import linalg -from ..utils.simple_functions import choose -from ..utils.space_ops import cross2d, find_intersection +from manim.typing import PointDType +from manim.utils.simple_functions import choose +from manim.utils.space_ops import cross2d, find_intersection + +if TYPE_CHECKING: + import numpy.typing as npt + + from manim.typing import ( + BezierPoints, + BezierPoints_Array, + MatrixMN, + Point3D, + Point3D_Array, + ) + +# l is a commonly used name in linear algebra +# ruff: noqa: E741 def bezier( @@ -1118,156 +1118,496 @@ def match_interpolate( ) +# Figuring out which Bézier curves most smoothly connect a sequence of points def get_smooth_cubic_bezier_handle_points( - points: Point3D_Array, -) -> tuple[BezierPoints, BezierPoints]: - points = np.asarray(points) - num_handles = len(points) - 1 - dim = points.shape[1] - if num_handles < 1: - return np.zeros((0, dim)), np.zeros((0, dim)) - # Must solve 2*num_handles equations to get the handles. - # l and u are the number of lower an upper diagonal rows - # in the matrix to solve. - l, u = 2, 1 - # diag is a representation of the matrix in diagonal form - # See https://www.particleincell.com/2012/bezier-splines/ - # for how to arrive at these equations - diag: MatrixMN = np.zeros((l + u + 1, 2 * num_handles)) - diag[0, 1::2] = -1 - diag[0, 2::2] = 1 - diag[1, 0::2] = 2 - diag[1, 1::2] = 1 - diag[2, 1:-2:2] = -2 - diag[3, 0:-3:2] = 1 - # last - diag[2, -2] = -1 - diag[1, -1] = 2 - # This is the b as in Ax = b, where we are solving for x, - # and A is represented using diag. However, think of entries - # to x and b as being points in space, not numbers - b: Point3D_Array = np.zeros((2 * num_handles, dim)) - b[1::2] = 2 * points[1:] - b[0] = points[0] - b[-1] = points[-1] - - def solve_func(b: ColVector) -> ColVector | MatrixMN: - return linalg.solve_banded((l, u), diag, b) # type: ignore - - use_closed_solve_function = is_closed(points) - if use_closed_solve_function: - # Get equations to relate first and last points - matrix = diag_to_matrix((l, u), diag) - # last row handles second derivative - matrix[-1, [0, 1, -2, -1]] = [2, -1, 1, -2] - # first row handles first derivative - matrix[0, :] = np.zeros(matrix.shape[1]) - matrix[0, [0, -1]] = [1, 1] - b[0] = 2 * points[0] - b[-1] = np.zeros(dim) - - def closed_curve_solve_func(b: ColVector) -> ColVector | MatrixMN: - return linalg.solve(matrix, b) # type: ignore - - handle_pairs = np.zeros((2 * num_handles, dim)) - for i in range(dim): - if use_closed_solve_function: - handle_pairs[:, i] = closed_curve_solve_func(b[:, i]) - else: - handle_pairs[:, i] = solve_func(b[:, i]) - return handle_pairs[0::2], handle_pairs[1::2] - - -def get_smooth_handle_points( - points: BezierPoints, -) -> tuple[BezierPoints, BezierPoints]: - """Given some anchors (points), compute handles so the resulting bezier curve is smooth. + anchors: Point3D_Array, +) -> tuple[Point3D_Array, Point3D_Array]: + """Given an array of anchors for a cubic spline (array of connected cubic + Bézier curves), compute the 1st and 2nd handle for every curve, so that + the resulting spline is smooth. Parameters ---------- - points - Anchors. + anchors + Anchors of a cubic spline. Returns ------- - typing.Tuple[np.ndarray, np.ndarray] - Computed handles. + :class:`tuple` [:class:`~.Point3D_Array`, :class:`~.Point3D_Array`] + A tuple of two arrays: one containing the 1st handle for every curve in + the cubic spline, and the other containing the 2nd handles. """ - # NOTE points here are anchors. - points = np.asarray(points) - num_handles = len(points) - 1 - dim = points.shape[1] - if num_handles < 1: + anchors = np.asarray(anchors) + n_anchors = anchors.shape[0] + + # If there's a single anchor, there's no Bézier curve. + # Return empty arrays. + if n_anchors == 1: + dim = anchors.shape[1] return np.zeros((0, dim)), np.zeros((0, dim)) - # Must solve 2*num_handles equations to get the handles. - # l and u are the number of lower an upper diagonal rows - # in the matrix to solve. - l, u = 2, 1 - # diag is a representation of the matrix in diagonal form - # See https://www.particleincell.com/2012/bezier-splines/ - # for how to arrive at these equations - diag: MatrixMN = np.zeros((l + u + 1, 2 * num_handles)) - diag[0, 1::2] = -1 - diag[0, 2::2] = 1 - diag[1, 0::2] = 2 - diag[1, 1::2] = 1 - diag[2, 1:-2:2] = -2 - diag[3, 0:-3:2] = 1 - # last - diag[2, -2] = -1 - diag[1, -1] = 2 - # This is the b as in Ax = b, where we are solving for x, - # and A is represented using diag. However, think of entries - # to x and b as being points in space, not numbers - b = np.zeros((2 * num_handles, dim)) - b[1::2] = 2 * points[1:] - b[0] = points[0] - b[-1] = points[-1] - - def solve_func(b: ColVector) -> ColVector | MatrixMN: - return linalg.solve_banded((l, u), diag, b) # type: ignore - - use_closed_solve_function = is_closed(points) - if use_closed_solve_function: - # Get equations to relate first and last points - matrix = diag_to_matrix((l, u), diag) - # last row handles second derivative - matrix[-1, [0, 1, -2, -1]] = [2, -1, 1, -2] - # first row handles first derivative - matrix[0, :] = np.zeros(matrix.shape[1]) - matrix[0, [0, -1]] = [1, 1] - b[0] = 2 * points[0] - b[-1] = np.zeros(dim) - - def closed_curve_solve_func(b: ColVector) -> ColVector | MatrixMN: - return linalg.solve(matrix, b) # type: ignore - - handle_pairs = np.zeros((2 * num_handles, dim)) - for i in range(dim): - if use_closed_solve_function: - handle_pairs[:, i] = closed_curve_solve_func(b[:, i]) - else: - handle_pairs[:, i] = solve_func(b[:, i]) - return handle_pairs[0::2], handle_pairs[1::2] - - -def diag_to_matrix( - l_and_u: tuple[int, int], diag: npt.NDArray[Any] -) -> npt.NDArray[Any]: + + # If there are only two anchors (thus only one pair of handles), + # they can only be an interpolation of these two anchors with alphas + # 1/3 and 2/3, which will draw a straight line between the anchors. + if n_anchors == 2: + return interpolate(anchors[0], anchors[1], np.array([[1 / 3], [2 / 3]])) + + # Handle different cases depending on whether the points form a closed + # curve or not + curve_is_closed = is_closed(anchors) + if curve_is_closed: + return get_smooth_closed_cubic_bezier_handle_points(anchors) + else: + return get_smooth_open_cubic_bezier_handle_points(anchors) + + +CP_CLOSED_MEMO = np.array([1 / 3]) +UP_CLOSED_MEMO = np.array([1 / 3]) + + +def get_smooth_closed_cubic_bezier_handle_points( + anchors: Point3D_Array, +) -> tuple[Point3D_Array, Point3D_Array]: + r"""Special case of :func:`get_smooth_cubic_bezier_handle_points`, + when the ``anchors`` form a closed loop. + + .. note:: + A system of equations must be solved to get the first handles of + every Bézier curve (referred to as :math:`H_1`). + Then :math:`H_2` (the second handles) can be obtained separately. + + .. seealso:: + The equations were obtained from: + + * `Conditions on control points for continuous curvature. (2016). Jaco Stuifbergen. `_ + + In general, if there are :math:`N+1` anchors, there will be :math:`N` Bézier curves + and thus :math:`N` pairs of handles to find. We must solve the following + system of equations for the 1st handles (example for :math:`N = 5`): + + .. math:: + \begin{pmatrix} + 4 & 1 & 0 & 0 & 1 \\ + 1 & 4 & 1 & 0 & 0 \\ + 0 & 1 & 4 & 1 & 0 \\ + 0 & 0 & 1 & 4 & 1 \\ + 1 & 0 & 0 & 1 & 4 + \end{pmatrix} + \begin{pmatrix} + H_{1,0} \\ + H_{1,1} \\ + H_{1,2} \\ + H_{1,3} \\ + H_{1,4} + \end{pmatrix} + = + \begin{pmatrix} + 4A_0 + 2A_1 \\ + 4A_1 + 2A_2 \\ + 4A_2 + 2A_3 \\ + 4A_3 + 2A_4 \\ + 4A_4 + 2A_5 + \end{pmatrix} + + which will be expressed as :math:`RH_1 = D`. + + :math:`R` is almost a tridiagonal matrix, so we could use Thomas' algorithm. + + .. seealso:: + `Tridiagonal matrix algorithm. Wikipedia. `_ + + However, :math:`R` has ones at the opposite corners. A solution to this is + the first decomposition proposed in the link below, with :math:`\alpha = 1`: + + .. seealso:: + `Tridiagonal matrix algorithm # Variants. Wikipedia. `_ + + .. math:: + R + = + \begin{pmatrix} + 4 & 1 & 0 & 0 & 1 \\ + 1 & 4 & 1 & 0 & 0 \\ + 0 & 1 & 4 & 1 & 0 \\ + 0 & 0 & 1 & 4 & 1 \\ + 1 & 0 & 0 & 1 & 4 + \end{pmatrix} + &= + \begin{pmatrix} + 3 & 1 & 0 & 0 & 0 \\ + 1 & 4 & 1 & 0 & 0 \\ + 0 & 1 & 4 & 1 & 0 \\ + 0 & 0 & 1 & 4 & 1 \\ + 0 & 0 & 0 & 1 & 3 + \end{pmatrix} + + + \begin{pmatrix} + 1 & 0 & 0 & 0 & 1 \\ + 0 & 0 & 0 & 0 & 0 \\ + 0 & 0 & 0 & 0 & 0 \\ + 0 & 0 & 0 & 0 & 0 \\ + 1 & 0 & 0 & 0 & 1 + \end{pmatrix} + \\ + & + \\ + &= + \begin{pmatrix} + 3 & 1 & 0 & 0 & 0 \\ + 1 & 4 & 1 & 0 & 0 \\ + 0 & 1 & 4 & 1 & 0 \\ + 0 & 0 & 1 & 4 & 1 \\ + 0 & 0 & 0 & 1 & 3 + \end{pmatrix} + + + \begin{pmatrix} + 1 \\ + 0 \\ + 0 \\ + 0 \\ + 1 + \end{pmatrix} + \begin{pmatrix} + 1 & 0 & 0 & 0 & 1 + \end{pmatrix} + \\ + & + \\ + &= + T + uv^t + + We decompose :math:`R = T + uv^t`, where :math:`T` is a tridiagonal matrix, and + :math:`u, v` are :math:`N`-D vectors such that :math:`u_0 = u_{N-1} = v_0 = v_{N-1} = 1`, + and :math:`u_i = v_i = 0, \forall i \in \{1, ..., N-2\}`. + + Thus: + + .. math:: + RH_1 &= D \\ + \Rightarrow (T + uv^t)H_1 &= D + + If we find a vector :math:`q` such that :math:`Tq = u`: + + .. math:: + \Rightarrow (T + Tqv^t)H_1 &= D \\ + \Rightarrow T(I + qv^t)H_1 &= D \\ + \Rightarrow H_1 &= (I + qv^t)^{-1} T^{-1} D + + According to Sherman-Morrison's formula: + + .. seealso:: + `Sherman-Morrison's formula. Wikipedia. `_ + + .. math:: + (I + qv^t)^{-1} = I - \frac{1}{1 + v^tq} qv^t + + If we find :math:`Y = T^{-1} D`, or in other words, if we solve for + :math:`Y` in :math:`TY = D`: + + .. math:: + H_1 &= (I + qv^t)^{-1} T^{-1} D \\ + &= (I + qv^t)^{-1} Y \\ + &= (I - \frac{1}{1 + v^tq} qv^t) Y \\ + &= Y - \frac{1}{1 + v^tq} qv^tY + + Therefore, we must solve for :math:`q` and :math:`Y` in :math:`Tq = u` and :math:`TY = D`. + As :math:`T` is now tridiagonal, we shall use Thomas' algorithm. + + Define: + + * :math:`a = [a_0, \ a_1, \ ..., \ a_{N-2}]` as :math:`T`'s lower diagonal of :math:`N-1` elements, + such that :math:`a_0 = a_1 = ... = a_{N-2} = 1`, so this diagonal is filled with ones; + * :math:`b = [b_0, \ b_1, \ ..., \ b_{N-2}, \ b_{N-1}]` as :math:`T`'s main diagonal of :math:`N` elements, + such that :math:`b_0 = b_{N-1} = 3`, and :math:`b_1 = b_2 = ... = b_{N-2} = 4`; + * :math:`c = [c_0, \ c_1, \ ..., \ c_{N-2}]` as :math:`T`'s upper diagonal of :math:`N-1` elements, + such that :math:`c_0 = c_1 = ... = c_{N-2} = 1`: this diagonal is also filled with ones. + + If, according to Thomas' algorithm, we define: + + .. math:: + c'_0 &= \frac{c_0}{b_0} & \\ + c'_i &= \frac{c_i}{b_i - a_{i-1} c'_{i-1}}, & \quad \forall i \in \{1, ..., N-2\} \\ + & & \\ + u'_0 &= \frac{u_0}{b_0} & \\ + u'_i &= \frac{u_i - a_{i-1} u'_{i-1}}{b_i - a_{i-1} c'_{i-1}}, & \quad \forall i \in \{1, ..., N-1\} \\ + & & \\ + D'_0 &= \frac{1}{b_0} D_0 & \\ + D'_i &= \frac{1}{b_i - a_{i-1} c'_{i-1}} (D_i - a_{i-1} D'_{i-1}), & \quad \forall i \in \{1, ..., N-1\} + + Then: + + .. math:: + c'_0 &= \frac{1}{3} & \\ + c'_i &= \frac{1}{4 - c'_{i-1}}, & \quad \forall i \in \{1, ..., N-2\} \\ + & & \\ + u'_0 &= \frac{1}{3} & \\ + u'_i &= \frac{-u'_{i-1}}{4 - c'_{i-1}} = -c'_i u'_{i-1}, & \quad \forall i \in \{1, ..., N-2\} \\ + u'_{N-1} &= \frac{1 - u'_{N-2}}{3 - c'_{N-2}} & \\ + & & \\ + D'_0 &= \frac{1}{3} (4A_0 + 2A_1) & \\ + D'_i &= \frac{1}{4 - c'_{i-1}} (4A_i + 2A_{i+1} - D'_{i-1}) & \\ + &= c_i (4A_i + 2A_{i+1} - D'_{i-1}), & \quad \forall i \in \{1, ..., N-2\} \\ + D'_{N-1} &= \frac{1}{3 - c'_{N-2}} (4A_{N-1} + 2A_N - D'_{N-2}) & + + Finally, we can do Backward Substitution to find :math:`q` and :math:`Y`: + + .. math:: + q_{N-1} &= u'_{N-1} & \\ + q_i &= u'_{i} - c'_i q_{i+1}, & \quad \forall i \in \{0, ..., N-2\} \\ + & & \\ + Y_{N-1} &= D'_{N-1} & \\ + Y_i &= D'_i - c'_i Y_{i+1}, & \quad \forall i \in \{0, ..., N-2\} + + With those values, we can finally calculate :math:`H_1 = Y - \frac{1}{1 + v^tq} qv^tY`. + Given that :math:`v_0 = v_{N-1} = 1`, and :math:`v_1 = v_2 = ... = v_{N-2} = 0`, its dot products + with :math:`q` and :math:`Y` are respectively :math:`v^tq = q_0 + q_{N-1}` and + :math:`v^tY = Y_0 + Y_{N-1}`. Thus: + + .. math:: + H_1 = Y - \frac{1}{1 + q_0 + q_{N-1}} q(Y_0 + Y_{N-1}) + + Once we have :math:`H_1`, we can get :math:`H_2` (the array of second handles) as follows: + + .. math:: + H_{2, i} &= 2A_{i+1} - H_{1, i+1}, & \quad \forall i \in \{0, ..., N-2\} \\ + H_{2, N-1} &= 2A_0 - H_{1, 0} & + + Because the matrix :math:`R` always follows the same pattern (and thus :math:`T, u, v` as well), + we can define a memo list for :math:`c'` and :math:`u'` to avoid recalculation. We cannot + memoize :math:`D` and :math:`Y`, however, because they are always different matrices. We + cannot make a memo for :math:`q` either, but we can calculate it faster because :math:`u'` + can be memoized. + + Parameters + ---------- + anchors + Anchors of a closed cubic spline. + + Returns + ------- + :class:`tuple` [:class:`~.Point3D_Array`, :class:`~.Point3D_Array`] + A tuple of two arrays: one containing the 1st handle for every curve in + the closed cubic spline, and the other containing the 2nd handles. """ - Converts array whose rows represent diagonal - entries of a matrix into the matrix itself. - See scipy.linalg.solve_banded + global CP_CLOSED_MEMO + global UP_CLOSED_MEMO + + A = np.asarray(anchors) + N = A.shape[0] - 1 + dim = A.shape[1] + + # Calculate cp (c prime) and up (u prime) with help from + # CP_CLOSED_MEMO and UP_CLOSED_MEMO. + len_memo = CP_CLOSED_MEMO.size + if len_memo < N - 1: + cp = np.empty(N - 1) + up = np.empty(N - 1) + cp[:len_memo] = CP_CLOSED_MEMO + up[:len_memo] = UP_CLOSED_MEMO + # Forward Substitution 1 + # Calculate up (at the same time we calculate cp). + for i in range(len_memo, N - 1): + cp[i] = 1 / (4 - cp[i - 1]) + up[i] = -cp[i] * up[i - 1] + CP_CLOSED_MEMO = cp + UP_CLOSED_MEMO = up + else: + cp = CP_CLOSED_MEMO[: N - 1] + up = UP_CLOSED_MEMO[: N - 1] + + # The last element of u' is different + cp_last_division = 1 / (3 - cp[N - 2]) + up_last = cp_last_division * (1 - up[N - 2]) + + # Backward Substitution 1 + # Calculate q. + q = np.empty((N, dim)) + q[N - 1] = up_last + for i in range(N - 2, -1, -1): + q[i] = up[i] - cp[i] * q[i + 1] + + # Forward Substitution 2 + # Calculate Dp (D prime). + Dp = np.empty((N, dim)) + AUX = 4 * A[:N] + 2 * A[1:] # Vectorize the sum for efficiency. + Dp[0] = AUX[0] / 3 + for i in range(1, N - 1): + Dp[i] = cp[i] * (AUX[i] - Dp[i - 1]) + Dp[N - 1] = cp_last_division * (AUX[N - 1] - Dp[N - 2]) + + # Backward Substitution + # Calculate Y, which is defined as a view of Dp for efficiency + # and semantic convenience at the same time. + Y = Dp + # Y[N-1] = Dp[N-1] (redundant) + for i in range(N - 2, -1, -1): + Y[i] = Dp[i] - cp[i] * Y[i + 1] + + # Calculate H1. + H1 = Y - 1 / (1 + q[0] + q[N - 1]) * q * (Y[0] + Y[N - 1]) + + # Calculate H2. + H2 = np.empty((N, dim)) + H2[0 : N - 1] = 2 * A[1:N] - H1[1:N] + H2[N - 1] = 2 * A[N] - H1[0] + + return H1, H2 + + +CP_OPEN_MEMO = np.array([0.5]) + + +def get_smooth_open_cubic_bezier_handle_points( + anchors: Point3D_Array, +) -> tuple[Point3D_Array, Point3D_Array]: + r"""Special case of :func:`get_smooth_cubic_bezier_handle_points`, + when the ``anchors`` do not form a closed loop. + + .. note:: + A system of equations must be solved to get the first handles of + every Bèzier curve (referred to as :math:`H_1`). + Then :math:`H_2` (the second handles) can be obtained separately. + + .. seealso:: + The equations were obtained from: + + * `Smooth Bézier Spline Through Prescribed Points. (2012). Particle in Cell Consulting LLC. `_ + * `Conditions on control points for continuous curvature. (2016). Jaco Stuifbergen. `_ + + .. warning:: + The equations in the first webpage have some typos which were corrected in the comments. + + In general, if there are :math:`N+1` anchors, there will be :math:`N` Bézier curves + and thus :math:`N` pairs of handles to find. We must solve the following + system of equations for the 1st handles (example for :math:`N = 5`): + + .. math:: + \begin{pmatrix} + 2 & 1 & 0 & 0 & 0 \\ + 1 & 4 & 1 & 0 & 0 \\ + 0 & 1 & 4 & 1 & 0 \\ + 0 & 0 & 1 & 4 & 1 \\ + 0 & 0 & 0 & 2 & 7 + \end{pmatrix} + \begin{pmatrix} + H_{1,0} \\ + H_{1,1} \\ + H_{1,2} \\ + H_{1,3} \\ + H_{1,4} + \end{pmatrix} + = + \begin{pmatrix} + A_0 + 2A_1 \\ + 4A_1 + 2A_2 \\ + 4A_2 + 2A_3 \\ + 4A_3 + 2A_4 \\ + 8A_4 + A_5 + \end{pmatrix} + + which will be expressed as :math:`TH_1 = D`. + :math:`T` is a tridiagonal matrix, so the system can be solved in :math:`O(N)` + operations. Here we shall use Thomas' algorithm or the tridiagonal matrix + algorithm. + + .. seealso:: + `Tridiagonal matrix algorithm. Wikipedia. `_ + + Define: + + * :math:`a = [a_0, \ a_1, \ ..., \ a_{N-2}]` as :math:`T`'s lower diagonal of :math:`N-1` elements, + such that :math:`a_0 = a_1 = ... = a_{N-3} = 1`, and :math:`a_{N-2} = 2`; + * :math:`b = [b_0, \ b_1, \ ..., \ b_{N-2}, \ b_{N-1}]` as :math:`T`'s main diagonal of :math:`N` elements, + such that :math:`b_0 = 2`, :math:`b_1 = b_2 = ... = b_{N-2} = 4`, and :math:`b_{N-1} = 7`; + * :math:`c = [c_0, \ c_1, \ ..., \ c_{N-2}]` as :math:`T`'s upper diagonal of :math:`{N-1}` elements, + such that :math:`c_0 = c_1 = ... = c_{N-2} = 1`: this diagonal is filled with ones. + + If, according to Thomas' algorithm, we define: + + .. math:: + c'_0 &= \frac{c_0}{b_0} & \\ + c'_i &= \frac{c_i}{b_i - a_{i-1} c'_{i-1}}, & \quad \forall i \in \{1, ..., N-2\} \\ + & & \\ + D'_0 &= \frac{1}{b_0} D_0 & \\ + D'_i &= \frac{1}{b_i - a_{i-1} c'{i-1}} (D_i - a_{i-1} D'_{i-1}), & \quad \forall i \in \{1, ..., N-1\} + + Then: + + .. math:: + c'_0 &= 0.5 & \\ + c'_i &= \frac{1}{4 - c'_{i-1}}, & \quad \forall i \in \{1, ..., N-2\} \\ + & & \\ + D'_0 &= 0.5A_0 + A_1 & \\ + D'_i &= \frac{1}{4 - c'_{i-1}} (4A_i + 2A_{i+1} - D'_{i-1}) & \\ + &= c_i (4A_i + 2A_{i+1} - D'_{i-1}), & \quad \forall i \in \{1, ..., N-2\} \\ + D'_{N-1} &= \frac{1}{7 - 2c'_{N-2}} (8A_{N-1} + A_N - 2D'_{N-2}) & + + Finally, we can do Backward Substitution to find :math:`H_1`: + + .. math:: + H_{1, N-1} &= D'_{N-1} & \\ + H_{1, i} &= D'_i - c'_i H_{1, i+1}, & \quad \forall i \in \{0, ..., N-2\} + + Once we have :math:`H_1`, we can get :math:`H_2` (the array of second handles) as follows: + + .. math:: + H_{2, i} &= 2A_{i+1} - H_{1, i+1}, & \quad \forall i \in \{0, ..., N-2\} \\ + H_{2, N-1} &= 0.5A_N + 0.5H_{1, N-1} & + + As the matrix :math:`T` always follows the same pattern, we can define a memo list + for :math:`c'` to avoid recalculation. We cannot do the same for :math:`D`, however, + because it is always a different matrix. + + Parameters + ---------- + anchors + Anchors of an open cubic spline. + + Returns + ------- + :class:`tuple` [:class:`~.Point3D_Array`, :class:`~.Point3D_Array`] + A tuple of two arrays: one containing the 1st handle for every curve in + the open cubic spline, and the other containing the 2nd handles. """ - l, u = l_and_u - dim = diag.shape[1] - matrix = np.zeros((dim, dim)) - for i in range(l + u + 1): - np.fill_diagonal( - matrix[max(0, i - u) :, max(0, u - i) :], - diag[i, max(0, u - i) :], - ) - return matrix + global CP_OPEN_MEMO + + A = np.asarray(anchors) + N = A.shape[0] - 1 + dim = A.shape[1] + + # Calculate cp (c prime) with help from CP_OPEN_MEMO. + len_memo = CP_OPEN_MEMO.size + if len_memo < N - 1: + cp = np.empty(N - 1) + cp[:len_memo] = CP_OPEN_MEMO + for i in range(len_memo, N - 1): + cp[i] = 1 / (4 - cp[i - 1]) + CP_OPEN_MEMO = cp + else: + cp = CP_OPEN_MEMO[: N - 1] + + # Calculate Dp (D prime). + Dp = np.empty((N, dim)) + Dp[0] = 0.5 * A[0] + A[1] + AUX = 4 * A[1 : N - 1] + 2 * A[2:N] # Vectorize the sum for efficiency. + for i in range(1, N - 1): + Dp[i] = cp[i] * (AUX[i - 1] - Dp[i - 1]) + Dp[N - 1] = (1 / (7 - 2 * cp[N - 2])) * (8 * A[N - 1] + A[N] - 2 * Dp[N - 2]) + + # Backward Substitution. + # H1 (array of the first handles) is defined as a view of Dp for efficiency + # and semantic convenience at the same time. + H1 = Dp + # H1[N-1] = Dp[N-1] (redundant) + for i in range(N - 2, -1, -1): + H1[i] = Dp[i] - cp[i] * H1[i + 1] + + # Calculate H2. + H2 = np.empty((N, dim)) + H2[0 : N - 1] = 2 * A[1:N] - H1[1:N] + H2[N - 1] = 0.5 * (A[N] + H1[N - 1]) + + return H1, H2 # Given 4 control points for a cubic bezier curve (or arrays of such) diff --git a/manim/utils/color/AS2700.py b/manim/utils/color/AS2700.py index 83a3e6abd0..7b6dc6256c 100644 --- a/manim/utils/color/AS2700.py +++ b/manim/utils/color/AS2700.py @@ -24,6 +24,8 @@ """ +from __future__ import annotations + from .core import ManimColor B11_RICH_BLUE = ManimColor("#2B3770") diff --git a/manim/utils/color/BS381.py b/manim/utils/color/BS381.py index 50ae95b96c..60e8567a50 100644 --- a/manim/utils/color/BS381.py +++ b/manim/utils/color/BS381.py @@ -25,6 +25,8 @@ """ +from __future__ import annotations + from .core import ManimColor BS381_101 = ManimColor("#94BFAC") diff --git a/manim/utils/color/X11.py b/manim/utils/color/X11.py index 0379717eac..4338660200 100644 --- a/manim/utils/color/X11.py +++ b/manim/utils/color/X11.py @@ -23,6 +23,8 @@ .. automanimcolormodule:: manim.utils.color.X11 """ +from __future__ import annotations + from .core import ManimColor ALICEBLUE = ManimColor("#F0F8FF") diff --git a/manim/utils/color/XKCD.py b/manim/utils/color/XKCD.py index db9bccaed3..4c38af6862 100644 --- a/manim/utils/color/XKCD.py +++ b/manim/utils/color/XKCD.py @@ -24,6 +24,8 @@ """ +from __future__ import annotations + from .core import ManimColor ACIDGREEN = ManimColor("#8FFE09") diff --git a/manim/utils/color/__init__.py b/manim/utils/color/__init__.py index cfffc4edc9..362a31ac25 100644 --- a/manim/utils/color/__init__.py +++ b/manim/utils/color/__init__.py @@ -47,7 +47,7 @@ """ -from typing import Dict, List +from __future__ import annotations from . import AS2700, BS381, X11, XKCD from .core import * diff --git a/manim/utils/color/manim_colors.py b/manim/utils/color/manim_colors.py index 863cb5a99f..a3eacc83a4 100644 --- a/manim/utils/color/manim_colors.py +++ b/manim/utils/color/manim_colors.py @@ -121,7 +121,7 @@ def named_lines_group(length, colors, names, text_colors, align_to_block): """ -from typing import List +from __future__ import annotations from .core import ManimColor diff --git a/manim/utils/docbuild/autoaliasattr_directive.py b/manim/utils/docbuild/autoaliasattr_directive.py index 6dd645a9fd..ba42bd1ec4 100644 --- a/manim/utils/docbuild/autoaliasattr_directive.py +++ b/manim/utils/docbuild/autoaliasattr_directive.py @@ -12,12 +12,11 @@ if TYPE_CHECKING: from sphinx.application import Sphinx - from typing_extensions import TypeAlias __all__ = ["AliasAttrDocumenter"] -ALIAS_DOCS_DICT, DATA_DICT = parse_module_attributes() +ALIAS_DOCS_DICT, DATA_DICT, TYPEVAR_DICT = parse_module_attributes() ALIAS_LIST = [ alias_name for module_dict in ALIAS_DOCS_DICT.values() @@ -49,7 +48,9 @@ def smart_replace(base: str, alias: str, substitution: str) -> str: occurrences = [] len_alias = len(alias) len_base = len(base) - condition = lambda char: (not char.isalnum()) and char != "_" + + def condition(char: str) -> bool: + return not char.isalnum() and char != "_" start = 0 i = 0 @@ -99,10 +100,11 @@ class AliasAttrDocumenter(Directive): def run(self) -> list[nodes.Element]: module_name = self.arguments[0] - # Slice module_name[6:] to remove the "manim." prefix which is # not present in the keys of the DICTs - module_alias_dict = ALIAS_DOCS_DICT.get(module_name[6:], None) - module_attrs_list = DATA_DICT.get(module_name[6:], None) + module_name = module_name.removeprefix("manim.") + module_alias_dict = ALIAS_DOCS_DICT.get(module_name, None) + module_attrs_list = DATA_DICT.get(module_name, None) + module_typevars = TYPEVAR_DICT.get(module_name, None) content = nodes.container() @@ -160,6 +162,11 @@ def run(self) -> list[nodes.Element]: for A in ALIAS_LIST: alias_doc = alias_doc.replace(f"`{A}`", f":class:`~.{A}`") + # also hyperlink the TypeVars from that module + if module_typevars is not None: + for T in module_typevars: + alias_doc = alias_doc.replace(f"`{T}`", f":class:`{T}`") + # Add all the lines with 4 spaces behind, to consider all the # documentation as a paragraph INSIDE the `.. class::` block doc_lines = alias_doc.split("\n") @@ -171,6 +178,37 @@ def run(self) -> list[nodes.Element]: self.state.nested_parse(unparsed, 0, alias_container) category_alias_container += alias_container + # then add the module TypeVars section + if module_typevars is not None: + module_typevars_section = nodes.section(ids=[f"{module_name}.typevars"]) + content += module_typevars_section + + # Use a rubric (title-like), just like in `module.rst` + module_typevars_section += nodes.rubric(text="TypeVar's") + + # name: str + # definition: TypeVarDict = dict[str, str] + for name, definition in module_typevars.items(): + # Using the `.. class::` directive is CRUCIAL, since + # function/method parameters are always annotated via + # classes - therefore Sphinx expects a class + unparsed = ViewList( + [ + f".. class:: {name}", + "", + " .. parsed-literal::", + "", + f" {definition}", + "", + ] + ) + + # Parse the reST text into a fresh container + # https://www.sphinx-doc.org/en/master/extdev/markupapi.html#parsing-directive-content-as-rest + typevar_container = nodes.container() + self.state.nested_parse(unparsed, 0, typevar_container) + module_typevars_section += typevar_container + # Then, add the traditional "Module Attributes" section if module_attrs_list is not None: module_attrs_section = nodes.section(ids=[f"{module_name}.data"]) diff --git a/manim/utils/docbuild/autocolor_directive.py b/manim/utils/docbuild/autocolor_directive.py index 37c63efb29..574aaedf77 100644 --- a/manim/utils/docbuild/autocolor_directive.py +++ b/manim/utils/docbuild/autocolor_directive.py @@ -38,7 +38,7 @@ def run(self) -> list[nodes.Element]: return [ nodes.error( None, - nodes.paragraph(text="Failed to import module '%s'" % module_name), + nodes.paragraph(text=f"Failed to import module '{module_name}'"), ) ] diff --git a/manim/utils/docbuild/manim_directive.py b/manim/utils/docbuild/manim_directive.py index 111b40fa4d..cd277970e3 100644 --- a/manim/utils/docbuild/manim_directive.py +++ b/manim/utils/docbuild/manim_directive.py @@ -82,7 +82,6 @@ def construct(self): import csv import itertools as it -import os import re import shutil import sys @@ -356,7 +355,7 @@ def _write_rendering_stats(scene_name: str, run_time: str, file_name: str) -> No [ re.sub(r"^(reference\/)|(manim\.)", "", file_name), scene_name, - "%.3f" % run_time, + f"{run_time:.3f}", ], ) diff --git a/manim/utils/docbuild/module_parsing.py b/manim/utils/docbuild/module_parsing.py index 87bfa76c40..57ac9a56aa 100644 --- a/manim/utils/docbuild/module_parsing.py +++ b/manim/utils/docbuild/module_parsing.py @@ -26,6 +26,10 @@ classified by category in different `AliasCategoryDict` objects. """ +ModuleTypeVarDict: TypeAlias = dict[str, str] +"""Dictionary containing every :class:`TypeVar` defined in a module.""" + + AliasDocsDict: TypeAlias = dict[str, ModuleLevelAliasDict] """Dictionary which, for every module in Manim, contains documentation about their module-level attributes which are explicitly defined as @@ -39,17 +43,22 @@ explicitly defined as :class:`TypeAlias`. """ +TypeVarDict: TypeAlias = dict[str, ModuleTypeVarDict] +"""A dictionary mapping module names to dictionaries of :class:`TypeVar` objects.""" + ALIAS_DOCS_DICT: AliasDocsDict = {} DATA_DICT: DataDict = {} +TYPEVAR_DICT: TypeVarDict = {} MANIM_ROOT = Path(__file__).resolve().parent.parent.parent # In the following, we will use ``type(xyz) is xyz_type`` instead of # isinstance checks to make sure no subclasses of the type pass the # check +# ruff: noqa: E721 -def parse_module_attributes() -> tuple[AliasDocsDict, DataDict]: +def parse_module_attributes() -> tuple[AliasDocsDict, DataDict, TypeVarDict]: """Read all files, generate Abstract Syntax Trees from them, and extract useful information about the type aliases defined in the files: the category they belong to, their definition and their @@ -57,19 +66,24 @@ def parse_module_attributes() -> tuple[AliasDocsDict, DataDict]: Returns ------- - ALIAS_DOCS_DICT : `AliasDocsDict` + ALIAS_DOCS_DICT : :class:`AliasDocsDict` A dictionary containing the information from all the type - aliases in Manim. See `AliasDocsDict` for more information. + aliases in Manim. See :class:`AliasDocsDict` for more information. - DATA_DICT : `DataDict` + DATA_DICT : :class:`DataDict` A dictionary containing the names of all DOCUMENTED module-level attributes which are not a :class:`TypeAlias`. + + TYPEVAR_DICT : :class:`TypeVarDict` + A dictionary containing the definitions of :class:`TypeVar` objects, + organized by modules. """ global ALIAS_DOCS_DICT global DATA_DICT + global TYPEVAR_DICT - if ALIAS_DOCS_DICT or DATA_DICT: - return ALIAS_DOCS_DICT, DATA_DICT + if ALIAS_DOCS_DICT or DATA_DICT or TYPEVAR_DICT: + return ALIAS_DOCS_DICT, DATA_DICT, TYPEVAR_DICT for module_path in MANIM_ROOT.rglob("*.py"): module_name = module_path.resolve().relative_to(MANIM_ROOT) @@ -84,6 +98,9 @@ def parse_module_attributes() -> tuple[AliasDocsDict, DataDict]: category_dict: AliasCategoryDict | None = None alias_info: AliasInfo | None = None + # For storing TypeVars + module_typevars: ModuleTypeVarDict = {} + # For storing regular module attributes data_list: list[str] = [] data_name: str | None = None @@ -171,6 +188,19 @@ def parse_module_attributes() -> tuple[AliasDocsDict, DataDict]: alias_info = category_dict[alias_name] continue + # Check if it is a typing.TypeVar + elif ( + type(node) is ast.Assign + and type(node.targets[0]) is ast.Name + and type(node.value) is ast.Call + and type(node.value.func) is ast.Name + and node.value.func.id.endswith("TypeVar") + ): + module_typevars[node.targets[0].id] = ast.unparse( + node.value + ).replace("_", r"\_") + continue + # If here, the node is not a TypeAlias definition alias_info = None @@ -184,7 +214,9 @@ def parse_module_attributes() -> tuple[AliasDocsDict, DataDict]: else: target = None - if type(target) is ast.Name: + if type(target) is ast.Name and not ( + type(node) is ast.Assign and target.id not in module_typevars + ): data_name = target.id else: data_name = None @@ -193,5 +225,7 @@ def parse_module_attributes() -> tuple[AliasDocsDict, DataDict]: ALIAS_DOCS_DICT[module_name] = module_dict if len(data_list) > 0: DATA_DICT[module_name] = data_list + if module_typevars: + TYPEVAR_DICT[module_name] = module_typevars - return ALIAS_DOCS_DICT, DATA_DICT + return ALIAS_DOCS_DICT, DATA_DICT, TYPEVAR_DICT diff --git a/manim/utils/file_ops.py b/manim/utils/file_ops.py index 58c82464a1..7efcee02c5 100644 --- a/manim/utils/file_ops.py +++ b/manim/utils/file_ops.py @@ -32,7 +32,7 @@ from manim import __version__, config, logger -from .. import config, console +from .. import console def is_mp4_format() -> bool: diff --git a/manim/utils/ipython_magic.py b/manim/utils/ipython_magic.py index 7911da9cf3..601d4d6f8e 100644 --- a/manim/utils/ipython_magic.py +++ b/manim/utils/ipython_magic.py @@ -3,13 +3,12 @@ from __future__ import annotations import mimetypes -import os import shutil from datetime import datetime from pathlib import Path from typing import Any -from manim import Group, config, logger, tempconfig +from manim import config, logger, tempconfig from manim.__main__ import main from manim.renderer.shader import shader_program_cache diff --git a/manim/utils/module_ops.py b/manim/utils/module_ops.py index 2ff7291470..03f297030d 100644 --- a/manim/utils/module_ops.py +++ b/manim/utils/module_ops.py @@ -2,7 +2,6 @@ import importlib.util import inspect -import os import re import sys import types diff --git a/manim/utils/simple_functions.py b/manim/utils/simple_functions.py index 72a0f8ddbd..898c1d527b 100644 --- a/manim/utils/simple_functions.py +++ b/manim/utils/simple_functions.py @@ -10,9 +10,7 @@ ] -import inspect from functools import lru_cache -from types import MappingProxyType from typing import Callable import numpy as np diff --git a/manim/utils/space_ops.py b/manim/utils/space_ops.py index ec47d136eb..197621fce6 100644 --- a/manim/utils/space_ops.py +++ b/manim/utils/space_ops.py @@ -10,7 +10,7 @@ from mapbox_earcut import triangulate_float32 as earcut from scipy.spatial.transform import Rotation -from manim.constants import DOWN, OUT, PI, RIGHT, TAU, UP, RendererType +from manim.constants import DOWN, OUT, PI, RIGHT, TAU, UP from manim.utils.iterables import adjacent_pairs if TYPE_CHECKING: diff --git a/poetry.lock b/poetry.lock index 22d920cd9a..8480752e78 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2449,6 +2449,7 @@ description = "A cross platform helper library for ModernGL making window creati optional = false python-versions = ">=3.8" files = [ + {file = "moderngl-window-2.4.6.tar.gz", hash = "sha256:db9b4c27f35faa6f243b6d8cde6ada6da6e79541d62b8e536c0b20da29720c32"}, {file = "moderngl_window-2.4.6-py3-none-any.whl", hash = "sha256:cfa81c2b222536270a077e2901f5f7a18e317f7332026ae443662555ebf7a66d"}, ] @@ -4298,22 +4299,22 @@ files = [ [[package]] name = "tornado" -version = "6.4" +version = "6.4.1" description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed." optional = true -python-versions = ">= 3.8" +python-versions = ">=3.8" files = [ - {file = "tornado-6.4-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:02ccefc7d8211e5a7f9e8bc3f9e5b0ad6262ba2fbb683a6443ecc804e5224ce0"}, - {file = "tornado-6.4-cp38-abi3-macosx_10_9_x86_64.whl", hash = "sha256:27787de946a9cffd63ce5814c33f734c627a87072ec7eed71f7fc4417bb16263"}, - {file = "tornado-6.4-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f7894c581ecdcf91666a0912f18ce5e757213999e183ebfc2c3fdbf4d5bd764e"}, - {file = "tornado-6.4-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e43bc2e5370a6a8e413e1e1cd0c91bedc5bd62a74a532371042a18ef19e10579"}, - {file = "tornado-6.4-cp38-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f0251554cdd50b4b44362f73ad5ba7126fc5b2c2895cc62b14a1c2d7ea32f212"}, - {file = "tornado-6.4-cp38-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:fd03192e287fbd0899dd8f81c6fb9cbbc69194d2074b38f384cb6fa72b80e9c2"}, - {file = "tornado-6.4-cp38-abi3-musllinux_1_1_i686.whl", hash = "sha256:88b84956273fbd73420e6d4b8d5ccbe913c65d31351b4c004ae362eba06e1f78"}, - {file = "tornado-6.4-cp38-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:71ddfc23a0e03ef2df1c1397d859868d158c8276a0603b96cf86892bff58149f"}, - {file = "tornado-6.4-cp38-abi3-win32.whl", hash = "sha256:6f8a6c77900f5ae93d8b4ae1196472d0ccc2775cc1dfdc9e7727889145c45052"}, - {file = "tornado-6.4-cp38-abi3-win_amd64.whl", hash = "sha256:10aeaa8006333433da48dec9fe417877f8bcc21f48dda8d661ae79da357b2a63"}, - {file = "tornado-6.4.tar.gz", hash = "sha256:72291fa6e6bc84e626589f1c29d90a5a6d593ef5ae68052ee2ef000dfd273dee"}, + {file = "tornado-6.4.1-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:163b0aafc8e23d8cdc3c9dfb24c5368af84a81e3364745ccb4427669bf84aec8"}, + {file = "tornado-6.4.1-cp38-abi3-macosx_10_9_x86_64.whl", hash = "sha256:6d5ce3437e18a2b66fbadb183c1d3364fb03f2be71299e7d10dbeeb69f4b2a14"}, + {file = "tornado-6.4.1-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2e20b9113cd7293f164dc46fffb13535266e713cdb87bd2d15ddb336e96cfc4"}, + {file = "tornado-6.4.1-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8ae50a504a740365267b2a8d1a90c9fbc86b780a39170feca9bcc1787ff80842"}, + {file = "tornado-6.4.1-cp38-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:613bf4ddf5c7a95509218b149b555621497a6cc0d46ac341b30bd9ec19eac7f3"}, + {file = "tornado-6.4.1-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:25486eb223babe3eed4b8aecbac33b37e3dd6d776bc730ca14e1bf93888b979f"}, + {file = "tornado-6.4.1-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:454db8a7ecfcf2ff6042dde58404164d969b6f5d58b926da15e6b23817950fc4"}, + {file = "tornado-6.4.1-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a02a08cc7a9314b006f653ce40483b9b3c12cda222d6a46d4ac63bb6c9057698"}, + {file = "tornado-6.4.1-cp38-abi3-win32.whl", hash = "sha256:d9a566c40b89757c9aa8e6f032bcdb8ca8795d7c1a9762910c722b1635c9de4d"}, + {file = "tornado-6.4.1-cp38-abi3-win_amd64.whl", hash = "sha256:b24b8982ed444378d7f21d563f4180a2de31ced9d8d84443907a0a64da2072e7"}, + {file = "tornado-6.4.1.tar.gz", hash = "sha256:92d3ab53183d8c50f8204a51e6f91d18a15d5ef261e84d452800d4ff6fc504e9"}, ] [[package]] @@ -4448,13 +4449,13 @@ dev = ["flake8", "flake8-annotations", "flake8-bandit", "flake8-bugbear", "flake [[package]] name = "urllib3" -version = "2.2.1" +version = "2.2.2" description = "HTTP library with thread-safe connection pooling, file post, and more." optional = false python-versions = ">=3.8" files = [ - {file = "urllib3-2.2.1-py3-none-any.whl", hash = "sha256:450b20ec296a467077128bff42b73080516e71b56ff59a60a02bef2232c4fa9d"}, - {file = "urllib3-2.2.1.tar.gz", hash = "sha256:d0570876c61ab9e520d776c38acbbb5b05a776d3f9ff98a5c8fd5162a444cf19"}, + {file = "urllib3-2.2.2-py3-none-any.whl", hash = "sha256:a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472"}, + {file = "urllib3-2.2.2.tar.gz", hash = "sha256:dd505485549a7a552833da5e6063639d0d177c04f23bc3864e41e5dc5f612168"}, ] [package.extras] diff --git a/pyproject.toml b/pyproject.toml index 7b05d004df..c38f80daef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -134,6 +134,48 @@ extend-exclude = [ ] fix = true +[tool.ruff.lint] +select = [ + "E", + "F", + "I", + "UP", +] + +ignore = [ + # due to the import * used in manim + "F403", + "F405", + # as recommended by https://docs.astral.sh/ruff/formatter/#conflicting-lint-rules + "E111", + "E114", + "E117", + "E501", +] + +# Allow unused variables when underscore-prefixed. +dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" + +[tool.ruff.lint.per-file-ignores] +"tests/*" = [ + # unused variable + "F841", + # from __future__ import annotations + "I002", +] + +"example_scenes/*" = [ + "I002", +] + +"__init__.py" = [ + "F401", + "F403", +] + +[tool.ruff.lint.isort] +required-imports = ["from __future__ import annotations"] + [tool.ruff.format] docstring-code-format = true diff --git a/scripts/dev_changelog.py b/scripts/dev_changelog.py index d226b0195f..35e77dd798 100755 --- a/scripts/dev_changelog.py +++ b/scripts/dev_changelog.py @@ -301,7 +301,6 @@ def main(token, prior, tag, additional, outfile): for PR in pr_by_labels[label]: num = PR.number - url = PR.html_url title = PR.title label = PR.labels f.write(f"* :pr:`{num}`: {title}\n") diff --git a/scripts/extract_frames.py b/scripts/extract_frames.py index d3c7691895..a8900de5c0 100644 --- a/scripts/extract_frames.py +++ b/scripts/extract_frames.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pathlib import sys diff --git a/tests/interface/test_commands.py b/tests/interface/test_commands.py index d902d3b07b..c6223a78cb 100644 --- a/tests/interface/test_commands.py +++ b/tests/interface/test_commands.py @@ -8,7 +8,7 @@ from click.testing import CliRunner -from manim import __version__, capture, tempconfig +from manim import __version__, capture from manim.__main__ import main from manim.cli.checkhealth.checks import HEALTH_CHECKS diff --git a/tests/module/mobject/geometry/test_unit_geometry.py b/tests/module/mobject/geometry/test_unit_geometry.py index 1729752fff..45f4da279f 100644 --- a/tests/module/mobject/geometry/test_unit_geometry.py +++ b/tests/module/mobject/geometry/test_unit_geometry.py @@ -4,10 +4,10 @@ import numpy as np -logger = logging.getLogger(__name__) - from manim import BackgroundRectangle, Circle, Sector +logger = logging.getLogger(__name__) + def test_get_arc_center(): np.testing.assert_array_equal( diff --git a/tests/module/mobject/graphing/test_coordinate_system.py b/tests/module/mobject/graphing/test_coordinate_system.py index 4aa71f8968..727bb0ba68 100644 --- a/tests/module/mobject/graphing/test_coordinate_system.py +++ b/tests/module/mobject/graphing/test_coordinate_system.py @@ -5,9 +5,22 @@ import numpy as np import pytest -from manim import LEFT, ORIGIN, PI, UR, Axes, Circle, ComplexPlane +from manim import ( + LEFT, + ORIGIN, + PI, + UR, + Axes, + Circle, + ComplexPlane, + Dot, + NumberPlane, + PolarPlane, + ThreeDAxes, + config, + tempconfig, +) from manim import CoordinateSystem as CS -from manim import Dot, NumberPlane, PolarPlane, ThreeDAxes, config, tempconfig def test_initial_config(): diff --git a/tests/module/mobject/text/test_texmobject.py b/tests/module/mobject/text/test_texmobject.py index 9f487be4e8..c8d4f51f84 100644 --- a/tests/module/mobject/text/test_texmobject.py +++ b/tests/module/mobject/text/test_texmobject.py @@ -6,8 +6,6 @@ import pytest from manim import MathTex, SingleStringMathTex, Tex, TexTemplate, config, tempconfig -from manim.mobject.types.vectorized_mobject import VMobject -from manim.utils.color import RED def test_MathTex(): diff --git a/tests/module/utils/test_bezier.py b/tests/module/utils/test_bezier.py index 7e1351c961..dca4be193e 100644 --- a/tests/module/utils/test_bezier.py +++ b/tests/module/utils/test_bezier.py @@ -5,8 +5,10 @@ from _split_matrices import SPLIT_MATRICES from _subdivision_matrices import SUBDIVISION_MATRICES +from manim.typing import ManimFloat from manim.utils.bezier import ( _get_subdivision_matrix, + get_smooth_cubic_bezier_handle_points, partial_bezier_points, split_bezier, subdivide_bezier, @@ -95,3 +97,73 @@ def test_subdivide_bezier() -> None: subdivide_bezier(points, n_divisions), subdivision_matrix @ points, ) + + +def test_get_smooth_cubic_bezier_handle_points() -> None: + """Test that :func:`.get_smooth_cubic_bezier_handle_points` returns the + correct handles, both for open and closed Bézier splines. + """ + open_curve_corners = np.array( + [ + [1, 1, 0], + [-1, 1, 1], + [-1, -1, 2], + [1, -1, 1], + ], + dtype=ManimFloat, + ) + h1, h2 = get_smooth_cubic_bezier_handle_points(open_curve_corners) + assert np.allclose( + h1, + np.array( + [ + [1 / 5, 11 / 9, 13 / 45], + [-7 / 5, 5 / 9, 64 / 45], + [-3 / 5, -13 / 9, 91 / 45], + ] + ), + ) + assert np.allclose( + h2, + np.array( + [ + [-3 / 5, 13 / 9, 26 / 45], + [-7 / 5, -5 / 9, 89 / 45], + [1 / 5, -11 / 9, 68 / 45], + ] + ), + ) + + closed_curve_corners = np.array( + [ + [1, 1, 0], + [-1, 1, 1], + [-1, -1, 2], + [1, -1, 1], + [1, 1, 0], + ], + dtype=ManimFloat, + ) + h1, h2 = get_smooth_cubic_bezier_handle_points(closed_curve_corners) + assert np.allclose( + h1, + np.array( + [ + [1 / 2, 3 / 2, 0], + [-3 / 2, 1 / 2, 3 / 2], + [-1 / 2, -3 / 2, 2], + [3 / 2, -1 / 2, 1 / 2], + ] + ), + ) + assert np.allclose( + h2, + np.array( + [ + [-1 / 2, 3 / 2, 1 / 2], + [-3 / 2, -1 / 2, 2], + [1 / 2, -3 / 2, 3 / 2], + [3 / 2, 1 / 2, 0], + ] + ), + ) diff --git a/tests/module/utils/test_space_ops.py b/tests/module/utils/test_space_ops.py index eef307b4b6..13584c797d 100644 --- a/tests/module/utils/test_space_ops.py +++ b/tests/module/utils/test_space_ops.py @@ -4,7 +4,7 @@ import pytest from manim.utils.space_ops import * -from manim.utils.space_ops import shoelace, shoelace_direction +from manim.utils.space_ops import shoelace def test_rotate_vector(): diff --git a/tests/opengl/test_coordinate_system_opengl.py b/tests/opengl/test_coordinate_system_opengl.py index 603f8eba54..ed596d7e9d 100644 --- a/tests/opengl/test_coordinate_system_opengl.py +++ b/tests/opengl/test_coordinate_system_opengl.py @@ -5,9 +5,21 @@ import numpy as np import pytest -from manim import LEFT, ORIGIN, PI, UR, Axes, Circle, ComplexPlane +from manim import ( + LEFT, + ORIGIN, + PI, + UR, + Axes, + Circle, + ComplexPlane, + NumberPlane, + PolarPlane, + ThreeDAxes, + config, + tempconfig, +) from manim import CoordinateSystem as CS -from manim import NumberPlane, PolarPlane, ThreeDAxes, config, tempconfig def test_initial_config(using_opengl_renderer): diff --git a/tests/test_config.py b/tests/test_config.py index ed2f4b59c0..0c60d59b10 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,6 +1,5 @@ from __future__ import annotations -import os import tempfile from pathlib import Path @@ -103,7 +102,6 @@ def test_custom_dirs(tmp_path): "frame_rate": 15, "pixel_height": 854, "pixel_width": 480, - "save_sections": True, "sections_dir": "{media_dir}/test_sections", "video_dir": "{media_dir}/test_video", "partial_movie_dir": "{media_dir}/test_partial_movie_dir", diff --git a/tests/test_graphical_units/test_axes.py b/tests/test_graphical_units/test_axes.py index 8e450511d2..2e06300e46 100644 --- a/tests/test_graphical_units/test_axes.py +++ b/tests/test_graphical_units/test_axes.py @@ -293,7 +293,10 @@ def test_get_z_axis_label(scene): @frames_comparison def test_polar_graph(scene): polar = PolarPlane() - r = lambda theta: 4 * np.sin(theta * 4) + + def r(theta): + return 4 * np.sin(theta * 4) + polar_graph = polar.plot_polar_graph(r) scene.add(polar, polar_graph) diff --git a/tests/test_graphical_units/test_tex_mobject.py b/tests/test_graphical_units/test_tex_mobject.py index d92162b379..b229e1cb0f 100644 --- a/tests/test_graphical_units/test_tex_mobject.py +++ b/tests/test_graphical_units/test_tex_mobject.py @@ -1,5 +1,3 @@ -import pytest - from manim import * from manim.utils.testing.frames_comparison import frames_comparison diff --git a/tests/test_graphical_units/test_text.py b/tests/test_graphical_units/test_text.py index 21649222be..ab5d4fbeb1 100644 --- a/tests/test_graphical_units/test_text.py +++ b/tests/test_graphical_units/test_text.py @@ -1,6 +1,4 @@ -import pytest - -from manim import RED, MarkupText, Text, VGroup, VMobject +from manim import RED, MarkupText, Text, VMobject __module_test__ = "text" diff --git a/tests/test_graphical_units/test_threed.py b/tests/test_graphical_units/test_threed.py index 022201f4c8..b6079e5e4c 100644 --- a/tests/test_graphical_units/test_threed.py +++ b/tests/test_graphical_units/test_threed.py @@ -33,6 +33,18 @@ def test_Cone(scene): scene.add(Cone(resolution=16)) +def test_Cone_get_start_and_get_end(): + cone = Cone().shift(RIGHT).rotate(PI / 4, about_point=ORIGIN, about_edge=OUT) + start = [0.70710678, 0.70710678, -1.0] + end = [0.70710678, 0.70710678, 0.0] + assert np.allclose( + cone.get_start(), start, atol=0.01 + ), "start points of Cone do not match" + assert np.allclose( + cone.get_end(), end, atol=0.01 + ), "end points of Cone do not match" + + @frames_comparison(base_scene=ThreeDScene) def test_Cylinder(scene): scene.add(Cylinder()) @@ -149,3 +161,14 @@ def param_surface(u, v): axes=axes, colorscale=[(RED, -0.4), (YELLOW, 0), (GREEN, 0.4)], axis=1 ) scene.add(axes, surface_plane) + + +def test_get_start_and_end_Arrow3d(): + start, end = ORIGIN, np.array([2, 1, 0]) + arrow = Arrow3D(start, end) + assert np.allclose( + arrow.get_start(), start, atol=0.01 + ), "start points of Arrow3D do not match" + assert np.allclose( + arrow.get_end(), end, atol=0.01 + ), "end points of Arrow3D do not match"