Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow NumPy-like multi-indexing and slicing of Mobjects #4036

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion manim/mobject/geometry/arc.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def get_tip(self) -> VMobject:
if len(tips) == 0:
raise Exception("tip not found")
else:
tip: VMobject = tips[0]
tip: VMobject = tips[0] # type: ignore [assignment]
return tip

def get_default_tip_length(self) -> float:
Expand Down
4 changes: 2 additions & 2 deletions manim/mobject/geometry/line.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,9 +601,9 @@ def scale(self, factor: float, scale_tips: bool = False, **kwargs: Any) -> Self:
self._set_stroke_width_from_length()

if has_tip:
self.add_tip(tip=old_tips[0])
self.add_tip(tip=old_tips[0]) # type: ignore [arg-type]
if has_start_tip:
self.add_tip(tip=old_tips[1], at_start=True)
self.add_tip(tip=old_tips[1], at_start=True) # type: ignore [arg-type]
return self

def get_normal_vector(self) -> Vector3D:
Expand Down
128 changes: 124 additions & 4 deletions manim/mobject/mobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,70 @@ class Mobject:
getting and setting generic attributes with ``get_*``
and ``set_*`` methods. See :meth:`set` for more details.

Mobjects support NumPy-like indexing and slicing. If more than one item is
extracted, the results are collected inside a :class:`Group` or, if this is
a :class:`~.VMobject`, a :class:`~.VGroup`. For example, given this
structure of VMobjects:

.. code::

Mob0
├──VGroup [0]
│ ├──Mob1 [0, 0]
│ └──Mob2 [0, 1]
├──VGroup [1]
│ ├──Mob3 [1, 0]
│ │ ├──VGroup [1, 0, 0]
│ │ │ └──Mob7 [1, 0, 0, 0]
│ │ └──Mob8 [1, 0, 1]
│ │ └──Mob9 [1, 0, 1, 0]
│ ├──Mob4 [1, 1]
│ └──Mob5 [1, 2]
└──Mob6 [2]

this is possible:

.. code:: pycon

>>> mobs = [VMobject(name=f"Mob{i}") for i in range(10)]
>>> vgroups = [VGroup(*mobs[1:3]), VGroup(*mobs[3:6]), VGroup(mobs[7])]
>>>
>>> base_mob = mobs[0]
>>> base_mob.add(vgroups[0], vgroups[1], mobs[6])
>>> mobs[3].add(vgroups[2], mobs[8])
>>> mobs[8].add(mobs[9])
>>>
>>> # Basic indexing
>>> base_mob[2]
Mob6
>>> base_mob[0]
VGroup(Mob1, Mob2)
>>> base_mob[1]
VGroup(Mob3, Mob4, Mob5)
>>> base_mob[1:]
VGroup(VGroup of 3 submobjects, Mob6)
>>>
>>> # Multi-dimensional indexing
>>> base_mob[0, 0]
Mob1
>>> base_mob[1, 0, 0]
VGroup(Mob7)
>>> base_mob[1, 0, 0, 0]
Mob7
>>> base_mob[:2, 0]
VGroup(Mob1, Mob3)
>>> base_mob[:2, 1]
VGroup(Mob2, Mob4)
>>> base_mob[1, 0, :, 0]
VGroup(Mob7, Mob9)
>>>
>>> # Fancy indexing
>>> base_mob[[2, 0]]
VGroup(Mob6, VGroup of 2 submobjects)
>>> base_mob[[True, False, True]]
VGroup(VGroup of 2 submobjects, Mob6)


Attributes
----------
submobjects : List[:class:`Mobject`]
Expand Down Expand Up @@ -2270,12 +2334,68 @@ def align_to(

# Family matters

def __getitem__(self, value):
def __getitem__(self, value: int | slice | tuple | list | np.ndarray) -> Mobject:
"""See the Mobject docstring for more information. This magic method's
docstring is not rendered in the docs.
"""

def get_from_list(
mob_list: list[Mobject],
value: int | slice | tuple | list | np.ndarray,
) -> Mobject | list[Mobject]:
"""Internal function to extract items from a list, even if the
passed index is another sequence.
"""
# Basic indexing, 1 dimension
if isinstance(value, (int, slice)):
return mob_list[value]

# Basic indexing, N dimensions
if isinstance(value, tuple):
submob_or_submob_list = get_from_list(mob_list, value[0])
# Base case: only 1 value
if len(value) == 1:
return submob_or_submob_list
# Recursion: iterate over the rest of values
if isinstance(value[0], int):
submob = submob_or_submob_list
subgroup = submob[value[1:]]
return subgroup
submob_list = submob_or_submob_list
subgroups = [sm[value[1:]] for sm in submob_list]
return subgroups

# Simple fancy indexing
if isinstance(value, (list, np.ndarray)):
items: list[Mobject]
# With array of bools
if len(value) == len(mob_list) and all(
isinstance(index, bool) for index in value
):
items = []
for i, include_mob in enumerate(value):
if include_mob:
items.append(mob_list[i])
return items

if any(not isinstance(index, int) for index in value):
raise ValueError(
"The given array must contain either only bools or "
"only ints."
)

# With array of ints
items = [mob_list[index] for index in value]
return items

raise ValueError(f"Index type {value.__class__.__name__} is not supported.")

self_list = self.split()
if isinstance(value, slice):
mob_or_mobs = get_from_list(self_list, value)
if isinstance(mob_or_mobs, list):
GroupClass = self.get_group_class()
return GroupClass(*self_list.__getitem__(value))
return self_list.__getitem__(value)
return GroupClass(*mob_or_mobs)
return mob_or_mobs

def __iter__(self):
return iter(self.split())
Expand Down
76 changes: 75 additions & 1 deletion tests/module/mobject/mobject/test_mobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
import numpy as np
import pytest

from manim import DL, UR, Circle, Mobject, Rectangle, Square, VGroup
from manim.constants import DL, UR
from manim.mobject.geometry.arc import Circle
from manim.mobject.geometry.polygram import Rectangle, Square
from manim.mobject.mobject import Mobject
from manim.mobject.types.vectorized_mobject import VGroup, VMobject


def test_mobject_add():
Expand Down Expand Up @@ -168,3 +172,73 @@ def test_mobject_dimensions_has_points_and_children():
assert inner_rect.width == 2
assert inner_rect.height == 1
assert inner_rect.depth == 0


def test_mobject_get_item():
mobs = [VMobject(name=f"Mob{i}") for i in range(10)]
vgroups = [VGroup(*mobs[1:3]), VGroup(*mobs[3:6]), VGroup(mobs[7])]

base_mob = mobs[0]
base_mob.add(vgroups[0], vgroups[1], mobs[6])
mobs[3].add(vgroups[2], mobs[8])
mobs[8].add(mobs[9])

"""
Structure:

Mob0
├──VGroup [0]
│ ├──Mob1 [0, 0]
│ └──Mob2 [0, 1]
├──VGroup [1]
│ ├──Mob3 [1, 0]
│ │ ├──VGroup [1, 0, 0]
│ │ │ └──Mob7 [1, 0, 0, 0]
│ │ └──Mob8 [1, 0, 1]
│ │ └──Mob9 [1, 0, 1, 0]
│ ├──Mob4 [1, 1]
│ └──Mob5 [1, 2]
└──Mob6 [2]
"""

# Basic indexing, 1 dimension
assert base_mob[0].__repr__() == "VGroup(Mob1, Mob2)"
assert base_mob[1].__repr__() == "VGroup(Mob3, Mob4, Mob5)"
assert base_mob[2].__repr__() == "Mob6"
assert base_mob[1:].__repr__() == "VGroup(VGroup of 3 submobjects, Mob6)"
assert (
base_mob[:2].__repr__()
== "VGroup(VGroup of 2 submobjects, VGroup of 3 submobjects)"
)
assert base_mob[::2].__repr__() == "VGroup(VGroup of 2 submobjects, Mob6)"

# Basic indexing, N dimensions
assert base_mob[0, 0].__repr__() == "Mob1"
assert base_mob[0, 1].__repr__() == "Mob2"
assert base_mob[1, 0].__repr__() == "Mob3"
assert base_mob[1, 0, 0].__repr__() == "VGroup(Mob7)"
assert base_mob[1, 0, 0, 0].__repr__() == "Mob7"
assert base_mob[1, 0, 1].__repr__() == "Mob8"
assert base_mob[1, 0, 1, 0].__repr__() == "Mob9"
assert base_mob[1, 1].__repr__() == "Mob4"
assert base_mob[1, 2].__repr__() == "Mob5"

assert base_mob[:2, 0].__repr__() == "VGroup(Mob1, Mob3)"
assert base_mob[:2, 1].__repr__() == "VGroup(Mob2, Mob4)"
assert (
base_mob[:2, ::2].__repr__()
== "VGroup(VGroup of 1 submobjects, VGroup of 2 submobjects)"
)
assert base_mob[1, 0, :, 0].__repr__() == "VGroup(Mob7, Mob9)"

# Fancy indexing
assert (
base_mob[[0, 1]].__repr__()
== "VGroup(VGroup of 2 submobjects, VGroup of 3 submobjects)"
)
assert base_mob[[2, 0]].__repr__() == "VGroup(Mob6, VGroup of 2 submobjects)"
assert (
base_mob[[True, False, True]].__repr__()
== "VGroup(VGroup of 2 submobjects, Mob6)"
)
assert base_mob[[0, 1], 0].__repr__() == "VGroup(Mob1, Mob3)"
Loading