Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add typings to tex_mobject.py and numbers.py #4015

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 17 additions & 16 deletions manim/mobject/mobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@
import sys
import types
import warnings
from collections.abc import Iterable
from functools import partialmethod, reduce
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Literal
from typing import TYPE_CHECKING

import numpy as np

Expand All @@ -40,13 +39,15 @@
from ..utils.space_ops import angle_between_vectors, normalize, rotation_matrix

if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
from typing import Callable, Literal

from typing_extensions import Self, TypeAlias

from manim.typing import (
FunctionOverride,
InternalPoint3D,
ManimFloat,
ManimInt,
MappingFunction,
PathFuncType,
PixelArray,
Expand Down Expand Up @@ -100,18 +101,18 @@ def __init__(
color: ParsableManimColor | list[ParsableManimColor] = WHITE,
name: str | None = None,
dim: int = 3,
target=None,
target: Mobject | None = None,
z_index: float = 0,
) -> None:
self.name = self.__class__.__name__ if name is None else name
self.dim = dim
self.target = target
self.z_index = z_index
self.name: str = self.__class__.__name__ if name is None else name
self.dim: int = dim
self.target: Mobject | None = target
self.z_index: float = z_index
Comment on lines +107 to +110
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's kind of weird that mypy doesn't infer these? It should be able to 🤔

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be able to infer them. I've just seen a lot of attribute typehints around the source code, so I thought I should do the same here.

self.point_hash = None
self.submobjects = []
self.submobjects: Sequence[Mobject] = []
self.updaters: list[Updater] = []
self.updating_suspended = False
self.color = ManimColor.parse(color)
self.color: ManimColor | list[ManimColor] = ManimColor.parse(color)

self.reset_points()
self.generate_points()
Expand Down Expand Up @@ -2291,16 +2292,16 @@ def get_mobject_type_class() -> type[Mobject]:
"""Return the base class of this mobject type."""
return Mobject

def split(self) -> list[Self]:
def split(self) -> Sequence[Self]:
result = [self] if len(self.points) > 0 else []
return result + self.submobjects

def get_family(self, recurse: bool = True) -> list[Self]:
def get_family(self, recurse: bool = True) -> Sequence[Self]:
sub_families = [x.get_family() for x in self.submobjects]
all_mobjects = [self] + list(it.chain(*sub_families))
return remove_list_redundancies(all_mobjects)

def family_members_with_points(self) -> list[Self]:
def family_members_with_points(self) -> Sequence[Self]:
return [m for m in self.get_family() if m.get_num_points() > 0]
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't the usual recommendation to have type annotations for function inputs be as general as possible and have type annotations for function outputs be as specific as possible? It's similar to Postel's law: "be conservative in what you send, be liberal in what you accept". So, I think the return types here should be list[...].

Copy link
Contributor Author

@chopan050 chopan050 Nov 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's generally very true, but there's an issue with returning, say, a list[Mobject]: subclasses like VMobject cannot be typehinted to return a list[VMobject], because both lists are incompatible, even though VMobject is a subclass of Mobject. This does not happen when we use Sequence, which makes it the most specific type without that issue.

Therefore, the rule we apply here is:

  • If it's a method of a class meant to be subclassed, use Sequence.
  • Otherwise, use list. There are some plain functions, not methods, which are typehinted to return a list.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure that this logic applies here, because of the usage of Self. From a quick test it seems that we should be able to do it with list.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to use list because of an error MyPy threw. Somewhere, I had a list of SingleStringMathTex and MyPy couldn't let me append a MathTex. I'll take a look again to see if there's another way of solving this, though.


def arrange(
Expand Down Expand Up @@ -2576,13 +2577,13 @@ def init_sizes(sizes, num, measures, name):

def sort(
self,
point_to_num_func: Callable[[Point3D], ManimInt] = lambda p: p[0],
submob_func: Callable[[Mobject], ManimInt] | None = None,
point_to_num_func: Callable[[Point3D], float] = lambda p: p[0],
submob_func: Callable[[Mobject], float] | None = None,
) -> Self:
"""Sorts the list of :attr:`submobjects` by a function defined by ``submob_func``."""
if submob_func is None:

def submob_func(m: Mobject):
def submob_func(m: Mobject) -> float:
return point_to_num_func(m.get_center())

self.submobjects.sort(key=submob_func)
Expand Down
10 changes: 7 additions & 3 deletions manim/mobject/svg/svg_mobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import os
from pathlib import Path
from typing import TYPE_CHECKING
from xml.etree import ElementTree as ET

import numpy as np
Expand All @@ -21,6 +22,9 @@
from ..opengl.opengl_compatibility import ConvertToOpenGL
from ..types.vectorized_mobject import VMobject

if TYPE_CHECKING:
from manim.utils.color import ParsableManimColor

__all__ = ["SVGMobject", "VMobjectFromSVGPath"]


Expand Down Expand Up @@ -98,11 +102,11 @@ def __init__(
should_center: bool = True,
height: float | None = 2,
width: float | None = None,
color: str | None = None,
color: ParsableManimColor | None = None,
opacity: float | None = None,
fill_color: str | None = None,
fill_color: ParsableManimColor | None = None,
fill_opacity: float | None = None,
stroke_color: str | None = None,
stroke_color: ParsableManimColor | None = None,
stroke_opacity: float | None = None,
stroke_width: float | None = None,
svg_default: dict | None = None,
Expand Down
119 changes: 72 additions & 47 deletions manim/mobject/text/numbers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,30 @@

__all__ = ["DecimalNumber", "Integer", "Variable"]

from collections.abc import Sequence
from typing import TYPE_CHECKING

import numpy as np

from manim import config
from manim._config import config
from manim.constants import *
from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL
from manim.mobject.text.tex_mobject import MathTex, SingleStringMathTex, Tex
from manim.mobject.text.text_mobject import Text
from manim.mobject.text.tex_mobject import MathTex, SingleStringMathTex
Dismissed Show dismissed Hide dismissed
Dismissed Show dismissed Hide dismissed
from manim.mobject.types.vectorized_mobject import VMobject
from manim.mobject.value_tracker import ValueTracker

string_to_mob_map = {}
if TYPE_CHECKING:
from typing import Any, Union

from typing_extensions import Self, TypeAlias

from manim.mobject.text.tex_mobject import Tex
Dismissed Show dismissed Hide dismissed
from manim.mobject.text.text_mobject import MarkupText, Text
Dismissed Show dismissed Hide dismissed
Dismissed Show dismissed Hide dismissed
from manim.typing import Vector3D

TextLike: TypeAlias = Union[SingleStringMathTex, MathTex, Tex, Text, MarkupText]


string_to_mob_map: dict[str, TextLike] = {}

__all__ = ["DecimalNumber", "Integer", "Variable"]

Expand Down Expand Up @@ -83,38 +94,38 @@ def construct(self):

def __init__(
self,
number: float = 0,
number: float | complex = 0,
num_decimal_places: int = 2,
mob_class: VMobject = MathTex,
mob_class: type[TextLike] = MathTex,
include_sign: bool = False,
group_with_commas: bool = True,
digit_buff_per_font_unit: float = 0.001,
show_ellipsis: bool = False,
unit: str | None = None, # Aligned to bottom unless it starts with "^"
unit_buff_per_font_unit: float = 0,
include_background_rectangle: bool = False,
edge_to_fix: Sequence[float] = LEFT,
edge_to_fix: Vector3D = LEFT,
font_size: float = DEFAULT_FONT_SIZE,
stroke_width: float = 0,
fill_opacity: float = 1.0,
**kwargs,
):
**kwargs: Any,
) -> None:
super().__init__(**kwargs, stroke_width=stroke_width)
self.number = number
self.num_decimal_places = num_decimal_places
self.include_sign = include_sign
self.mob_class = mob_class
self.group_with_commas = group_with_commas
self.digit_buff_per_font_unit = digit_buff_per_font_unit
self.show_ellipsis = show_ellipsis
self.unit = unit
self.unit_buff_per_font_unit = unit_buff_per_font_unit
self.include_background_rectangle = include_background_rectangle
self.edge_to_fix = edge_to_fix
self._font_size = font_size
self.fill_opacity = fill_opacity

self.initial_config = kwargs.copy()
self.number: float | complex = number
self.num_decimal_places: int = num_decimal_places
self.include_sign: bool = include_sign
self.mob_class: type[TextLike] = mob_class
self.group_with_commas: bool = group_with_commas
self.digit_buff_per_font_unit: float = digit_buff_per_font_unit
self.show_ellipsis: bool = show_ellipsis
self.unit: str | None = unit
self.unit_buff_per_font_unit: float = unit_buff_per_font_unit
self.include_background_rectangle: bool = include_background_rectangle
self.edge_to_fix: Vector3D = edge_to_fix
self._font_size: float = font_size
self.fill_opacity: float = fill_opacity
Comment on lines +114 to +126
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, this feels a little bit suspicious - mypy should be able to figure out types from assignment?


self.initial_config: dict[str, Any] = kwargs.copy()
self.initial_config.update(
{
"num_decimal_places": num_decimal_places,
Expand All @@ -136,12 +147,12 @@ def __init__(
self.init_colors()

@property
def font_size(self):
def font_size(self) -> float:
"""The font size of the tex mobject."""
return self.height / self.initial_height * self._font_size

@font_size.setter
def font_size(self, font_val):
def font_size(self, font_val: float) -> None:
if font_val <= 0:
raise ValueError("font_size must be greater than 0.")
elif self.height > 0:
Expand All @@ -152,7 +163,7 @@ def font_size(self, font_val):
# font_size does not depend on current size.
self.scale(font_val / self.font_size)

def _set_submobjects_from_number(self, number):
def _set_submobjects_from_number(self, number: float | complex) -> None:
self.number = number
self.submobjects = []

Expand All @@ -161,8 +172,9 @@ def _set_submobjects_from_number(self, number):

# Add non-numerical bits
if self.show_ellipsis:
# TODO: Why MyPy 'cannot determine type of "color"'?
self.add(
self._string_to_mob("\\dots", SingleStringMathTex, color=self.color),
self._string_to_mob(r"\dots", SingleStringMathTex, color=self.color), # type: ignore [has-type]
)

self.arrange(
Expand Down Expand Up @@ -196,12 +208,12 @@ def _set_submobjects_from_number(self, number):
self.unit_sign.align_to(self, UP)

# track the initial height to enable scaling via font_size
self.initial_height = self.height
self.initial_height: float = self.height

if self.include_background_rectangle:
self.add_background_rectangle()

def _get_num_string(self, number):
def _get_num_string(self, number: float | complex) -> str:
if isinstance(number, complex):
formatter = self._get_complex_formatter()
else:
Expand All @@ -214,17 +226,21 @@ def _get_num_string(self, number):

return num_string

def _string_to_mob(self, string: str, mob_class: VMobject | None = None, **kwargs):
def _string_to_mob(
self, string: str, mob_class: type[TextLike] | None = None, **kwargs: Any
) -> TextLike:
chopan050 marked this conversation as resolved.
Show resolved Hide resolved
if mob_class is None:
mob_class = self.mob_class

_mob_class = self.mob_class if mob_class is None else mob_class
chopan050 marked this conversation as resolved.
Show resolved Hide resolved

if string not in string_to_mob_map:
string_to_mob_map[string] = mob_class(string, **kwargs)
string_to_mob_map[string] = _mob_class(string, **kwargs)
mob = string_to_mob_map[string].copy()
mob.font_size = self._font_size
return mob

def _get_formatter(self, **kwargs):
def _get_formatter(self, **kwargs: Any) -> str:
"""
Configuration is based first off instance attributes,
but overwritten by any kew word argument. Relevant
Expand Down Expand Up @@ -257,16 +273,16 @@ def _get_formatter(self, **kwargs):
],
)

def _get_complex_formatter(self):
def _get_complex_formatter(self, **kwargs: Any) -> str:
return "".join(
[
self._get_formatter(field_name="0.real"),
self._get_formatter(field_name="0.imag", include_sign=True),
self._get_formatter(field_name="0.real", **kwargs),
self._get_formatter(field_name="0.imag", **kwargs, include_sign=True),
"i",
],
)

def set_value(self, number: float):
def set_value(self, number: float | complex) -> Self:
"""Set the value of the :class:`~.DecimalNumber` to a new number.

Parameters
Expand Down Expand Up @@ -303,11 +319,12 @@ def set_value(self, number: float):
self.init_colors()
return self

def get_value(self):
def get_value(self) -> float | complex:
return self.number

def increment_value(self, delta_t=1):
def increment_value(self, delta_t: float | complex = 1.0) -> Self:
self.set_value(self.get_value() + delta_t)
return self


class Integer(DecimalNumber):
Expand All @@ -327,10 +344,15 @@ def construct(self):
self.add(Integer(number=6.28).set_x(-1.5).set_y(-2).set_color(YELLOW).scale(1.4))
"""

def __init__(self, number=0, num_decimal_places=0, **kwargs):
def __init__(
self,
number: float | complex = 0,
chopan050 marked this conversation as resolved.
Show resolved Hide resolved
num_decimal_places: int = 0,
**kwargs: Any,
) -> None:
super().__init__(number=number, num_decimal_places=num_decimal_places, **kwargs)

def get_value(self):
def get_value(self) -> int:
return int(np.round(super().get_value()))


Expand Down Expand Up @@ -441,16 +463,19 @@ def __init__(
self,
var: float,
label: str | Tex | MathTex | Text | SingleStringMathTex,
var_type: DecimalNumber | Integer = DecimalNumber,
var_type: type[DecimalNumber | Integer] = DecimalNumber,
num_decimal_places: int = 2,
**kwargs,
):
self.label = MathTex(label) if isinstance(label, str) else label
**kwargs: Any,
) -> None:
self.label: Tex | MathTex | Text | SingleStringMathTex = (
MathTex(label) if isinstance(label, str) else label
)
equals = MathTex("=").next_to(self.label, RIGHT)
self.label.add(equals)

self.tracker = ValueTracker(var)
self.tracker: ValueTracker = ValueTracker(var)

self.value: DecimalNumber | Integer
if var_type == DecimalNumber:
self.value = DecimalNumber(
self.tracker.get_value(),
Expand Down
Loading
Loading