-
Notifications
You must be signed in to change notification settings - Fork 1.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add typings to tex_mobject.py
and numbers.py
#4015
base: main
Are you sure you want to change the base?
Changes from 4 commits
9b3bf7c
e96beca
00a3599
1b59bc1
6552eb5
139a325
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,10 +14,9 @@ | |
import sys | ||
import types | ||
import warnings | ||
from collections.abc import Iterable | ||
from functools import partialmethod, reduce | ||
from pathlib import Path | ||
from typing import TYPE_CHECKING, Callable, Literal | ||
from typing import TYPE_CHECKING | ||
|
||
import numpy as np | ||
|
||
|
@@ -40,13 +39,15 @@ | |
from ..utils.space_ops import angle_between_vectors, normalize, rotation_matrix | ||
|
||
if TYPE_CHECKING: | ||
from collections.abc import Iterable, Sequence | ||
from typing import Callable, Literal | ||
|
||
from typing_extensions import Self, TypeAlias | ||
|
||
from manim.typing import ( | ||
FunctionOverride, | ||
InternalPoint3D, | ||
ManimFloat, | ||
ManimInt, | ||
MappingFunction, | ||
PathFuncType, | ||
PixelArray, | ||
|
@@ -100,18 +101,18 @@ def __init__( | |
color: ParsableManimColor | list[ParsableManimColor] = WHITE, | ||
name: str | None = None, | ||
dim: int = 3, | ||
target=None, | ||
target: Mobject | None = None, | ||
z_index: float = 0, | ||
) -> None: | ||
self.name = self.__class__.__name__ if name is None else name | ||
self.dim = dim | ||
self.target = target | ||
self.z_index = z_index | ||
self.name: str = self.__class__.__name__ if name is None else name | ||
self.dim: int = dim | ||
self.target: Mobject | None = target | ||
self.z_index: float = z_index | ||
self.point_hash = None | ||
self.submobjects = [] | ||
self.submobjects: Sequence[Mobject] = [] | ||
self.updaters: list[Updater] = [] | ||
self.updating_suspended = False | ||
self.color = ManimColor.parse(color) | ||
self.color: ManimColor | list[ManimColor] = ManimColor.parse(color) | ||
|
||
self.reset_points() | ||
self.generate_points() | ||
|
@@ -2291,16 +2292,16 @@ def get_mobject_type_class() -> type[Mobject]: | |
"""Return the base class of this mobject type.""" | ||
return Mobject | ||
|
||
def split(self) -> list[Self]: | ||
def split(self) -> Sequence[Self]: | ||
result = [self] if len(self.points) > 0 else [] | ||
return result + self.submobjects | ||
|
||
def get_family(self, recurse: bool = True) -> list[Self]: | ||
def get_family(self, recurse: bool = True) -> Sequence[Self]: | ||
sub_families = [x.get_family() for x in self.submobjects] | ||
all_mobjects = [self] + list(it.chain(*sub_families)) | ||
return remove_list_redundancies(all_mobjects) | ||
|
||
def family_members_with_points(self) -> list[Self]: | ||
def family_members_with_points(self) -> Sequence[Self]: | ||
return [m for m in self.get_family() if m.get_num_points() > 0] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Isn't the usual recommendation to have type annotations for function inputs be as general as possible and have type annotations for function outputs be as specific as possible? It's similar to Postel's law: "be conservative in what you send, be liberal in what you accept". So, I think the return types here should be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's generally very true, but there's an issue with returning, say, a Therefore, the rule we apply here is:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure that this logic applies here, because of the usage of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I had to use |
||
|
||
def arrange( | ||
|
@@ -2576,13 +2577,13 @@ def init_sizes(sizes, num, measures, name): | |
|
||
def sort( | ||
self, | ||
point_to_num_func: Callable[[Point3D], ManimInt] = lambda p: p[0], | ||
submob_func: Callable[[Mobject], ManimInt] | None = None, | ||
point_to_num_func: Callable[[Point3D], float] = lambda p: p[0], | ||
submob_func: Callable[[Mobject], float] | None = None, | ||
) -> Self: | ||
"""Sorts the list of :attr:`submobjects` by a function defined by ``submob_func``.""" | ||
if submob_func is None: | ||
|
||
def submob_func(m: Mobject): | ||
def submob_func(m: Mobject) -> float: | ||
return point_to_num_func(m.get_center()) | ||
|
||
self.submobjects.sort(key=submob_func) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,19 +4,30 @@ | |
|
||
__all__ = ["DecimalNumber", "Integer", "Variable"] | ||
|
||
from collections.abc import Sequence | ||
from typing import TYPE_CHECKING | ||
|
||
import numpy as np | ||
|
||
from manim import config | ||
from manim._config import config | ||
from manim.constants import * | ||
from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL | ||
from manim.mobject.text.tex_mobject import MathTex, SingleStringMathTex, Tex | ||
from manim.mobject.text.text_mobject import Text | ||
from manim.mobject.text.tex_mobject import MathTex, SingleStringMathTex | ||
|
||
from manim.mobject.types.vectorized_mobject import VMobject | ||
from manim.mobject.value_tracker import ValueTracker | ||
|
||
string_to_mob_map = {} | ||
if TYPE_CHECKING: | ||
from typing import Any, Union | ||
|
||
from typing_extensions import Self, TypeAlias | ||
|
||
from manim.mobject.text.tex_mobject import Tex | ||
|
||
from manim.mobject.text.text_mobject import MarkupText, Text | ||
|
||
from manim.typing import Vector3D | ||
|
||
TextLike: TypeAlias = Union[SingleStringMathTex, MathTex, Tex, Text, MarkupText] | ||
|
||
|
||
string_to_mob_map: dict[str, TextLike] = {} | ||
|
||
__all__ = ["DecimalNumber", "Integer", "Variable"] | ||
|
||
|
@@ -83,38 +94,38 @@ def construct(self): | |
|
||
def __init__( | ||
self, | ||
number: float = 0, | ||
number: float | complex = 0, | ||
num_decimal_places: int = 2, | ||
mob_class: VMobject = MathTex, | ||
mob_class: type[TextLike] = MathTex, | ||
include_sign: bool = False, | ||
group_with_commas: bool = True, | ||
digit_buff_per_font_unit: float = 0.001, | ||
show_ellipsis: bool = False, | ||
unit: str | None = None, # Aligned to bottom unless it starts with "^" | ||
unit_buff_per_font_unit: float = 0, | ||
include_background_rectangle: bool = False, | ||
edge_to_fix: Sequence[float] = LEFT, | ||
edge_to_fix: Vector3D = LEFT, | ||
font_size: float = DEFAULT_FONT_SIZE, | ||
stroke_width: float = 0, | ||
fill_opacity: float = 1.0, | ||
**kwargs, | ||
): | ||
**kwargs: Any, | ||
) -> None: | ||
super().__init__(**kwargs, stroke_width=stroke_width) | ||
self.number = number | ||
self.num_decimal_places = num_decimal_places | ||
self.include_sign = include_sign | ||
self.mob_class = mob_class | ||
self.group_with_commas = group_with_commas | ||
self.digit_buff_per_font_unit = digit_buff_per_font_unit | ||
self.show_ellipsis = show_ellipsis | ||
self.unit = unit | ||
self.unit_buff_per_font_unit = unit_buff_per_font_unit | ||
self.include_background_rectangle = include_background_rectangle | ||
self.edge_to_fix = edge_to_fix | ||
self._font_size = font_size | ||
self.fill_opacity = fill_opacity | ||
|
||
self.initial_config = kwargs.copy() | ||
self.number: float | complex = number | ||
self.num_decimal_places: int = num_decimal_places | ||
self.include_sign: bool = include_sign | ||
self.mob_class: type[TextLike] = mob_class | ||
self.group_with_commas: bool = group_with_commas | ||
self.digit_buff_per_font_unit: float = digit_buff_per_font_unit | ||
self.show_ellipsis: bool = show_ellipsis | ||
self.unit: str | None = unit | ||
self.unit_buff_per_font_unit: float = unit_buff_per_font_unit | ||
self.include_background_rectangle: bool = include_background_rectangle | ||
self.edge_to_fix: Vector3D = edge_to_fix | ||
self._font_size: float = font_size | ||
self.fill_opacity: float = fill_opacity | ||
Comment on lines
+114
to
+126
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again, this feels a little bit suspicious - mypy should be able to figure out types from assignment? |
||
|
||
self.initial_config: dict[str, Any] = kwargs.copy() | ||
self.initial_config.update( | ||
{ | ||
"num_decimal_places": num_decimal_places, | ||
|
@@ -136,12 +147,12 @@ def __init__( | |
self.init_colors() | ||
|
||
@property | ||
def font_size(self): | ||
def font_size(self) -> float: | ||
"""The font size of the tex mobject.""" | ||
return self.height / self.initial_height * self._font_size | ||
|
||
@font_size.setter | ||
def font_size(self, font_val): | ||
def font_size(self, font_val: float) -> None: | ||
if font_val <= 0: | ||
raise ValueError("font_size must be greater than 0.") | ||
elif self.height > 0: | ||
|
@@ -152,7 +163,7 @@ def font_size(self, font_val): | |
# font_size does not depend on current size. | ||
self.scale(font_val / self.font_size) | ||
|
||
def _set_submobjects_from_number(self, number): | ||
def _set_submobjects_from_number(self, number: float | complex) -> None: | ||
self.number = number | ||
self.submobjects = [] | ||
|
||
|
@@ -161,8 +172,9 @@ def _set_submobjects_from_number(self, number): | |
|
||
# Add non-numerical bits | ||
if self.show_ellipsis: | ||
# TODO: Why MyPy 'cannot determine type of "color"'? | ||
self.add( | ||
self._string_to_mob("\\dots", SingleStringMathTex, color=self.color), | ||
self._string_to_mob(r"\dots", SingleStringMathTex, color=self.color), # type: ignore [has-type] | ||
) | ||
|
||
self.arrange( | ||
|
@@ -196,12 +208,12 @@ def _set_submobjects_from_number(self, number): | |
self.unit_sign.align_to(self, UP) | ||
|
||
# track the initial height to enable scaling via font_size | ||
self.initial_height = self.height | ||
self.initial_height: float = self.height | ||
|
||
if self.include_background_rectangle: | ||
self.add_background_rectangle() | ||
|
||
def _get_num_string(self, number): | ||
def _get_num_string(self, number: float | complex) -> str: | ||
if isinstance(number, complex): | ||
formatter = self._get_complex_formatter() | ||
else: | ||
|
@@ -214,17 +226,21 @@ def _get_num_string(self, number): | |
|
||
return num_string | ||
|
||
def _string_to_mob(self, string: str, mob_class: VMobject | None = None, **kwargs): | ||
def _string_to_mob( | ||
self, string: str, mob_class: type[TextLike] | None = None, **kwargs: Any | ||
) -> TextLike: | ||
chopan050 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if mob_class is None: | ||
mob_class = self.mob_class | ||
|
||
_mob_class = self.mob_class if mob_class is None else mob_class | ||
chopan050 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
if string not in string_to_mob_map: | ||
string_to_mob_map[string] = mob_class(string, **kwargs) | ||
string_to_mob_map[string] = _mob_class(string, **kwargs) | ||
mob = string_to_mob_map[string].copy() | ||
mob.font_size = self._font_size | ||
return mob | ||
|
||
def _get_formatter(self, **kwargs): | ||
def _get_formatter(self, **kwargs: Any) -> str: | ||
""" | ||
Configuration is based first off instance attributes, | ||
but overwritten by any kew word argument. Relevant | ||
|
@@ -257,16 +273,16 @@ def _get_formatter(self, **kwargs): | |
], | ||
) | ||
|
||
def _get_complex_formatter(self): | ||
def _get_complex_formatter(self, **kwargs: Any) -> str: | ||
return "".join( | ||
[ | ||
self._get_formatter(field_name="0.real"), | ||
self._get_formatter(field_name="0.imag", include_sign=True), | ||
self._get_formatter(field_name="0.real", **kwargs), | ||
self._get_formatter(field_name="0.imag", **kwargs, include_sign=True), | ||
"i", | ||
], | ||
) | ||
|
||
def set_value(self, number: float): | ||
def set_value(self, number: float | complex) -> Self: | ||
"""Set the value of the :class:`~.DecimalNumber` to a new number. | ||
|
||
Parameters | ||
|
@@ -303,11 +319,12 @@ def set_value(self, number: float): | |
self.init_colors() | ||
return self | ||
|
||
def get_value(self): | ||
def get_value(self) -> float | complex: | ||
return self.number | ||
|
||
def increment_value(self, delta_t=1): | ||
def increment_value(self, delta_t: float | complex = 1.0) -> Self: | ||
self.set_value(self.get_value() + delta_t) | ||
return self | ||
|
||
|
||
class Integer(DecimalNumber): | ||
|
@@ -327,10 +344,15 @@ def construct(self): | |
self.add(Integer(number=6.28).set_x(-1.5).set_y(-2).set_color(YELLOW).scale(1.4)) | ||
""" | ||
|
||
def __init__(self, number=0, num_decimal_places=0, **kwargs): | ||
def __init__( | ||
self, | ||
number: float | complex = 0, | ||
chopan050 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
num_decimal_places: int = 0, | ||
**kwargs: Any, | ||
) -> None: | ||
super().__init__(number=number, num_decimal_places=num_decimal_places, **kwargs) | ||
|
||
def get_value(self): | ||
def get_value(self) -> int: | ||
return int(np.round(super().get_value())) | ||
|
||
|
||
|
@@ -441,16 +463,19 @@ def __init__( | |
self, | ||
var: float, | ||
label: str | Tex | MathTex | Text | SingleStringMathTex, | ||
var_type: DecimalNumber | Integer = DecimalNumber, | ||
var_type: type[DecimalNumber | Integer] = DecimalNumber, | ||
num_decimal_places: int = 2, | ||
**kwargs, | ||
): | ||
self.label = MathTex(label) if isinstance(label, str) else label | ||
**kwargs: Any, | ||
) -> None: | ||
self.label: Tex | MathTex | Text | SingleStringMathTex = ( | ||
MathTex(label) if isinstance(label, str) else label | ||
) | ||
equals = MathTex("=").next_to(self.label, RIGHT) | ||
self.label.add(equals) | ||
|
||
self.tracker = ValueTracker(var) | ||
self.tracker: ValueTracker = ValueTracker(var) | ||
|
||
self.value: DecimalNumber | Integer | ||
if var_type == DecimalNumber: | ||
self.value = DecimalNumber( | ||
self.tracker.get_value(), | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's kind of weird that mypy doesn't infer these? It should be able to 🤔
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It should be able to infer them. I've just seen a lot of attribute typehints around the source code, so I thought I should do the same here.