Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

draft of section group feature #3990

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions manim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
from .mobject.value_tracker import *
from .mobject.vector_field import *
from .renderer.cairo_renderer import *
from .scene.groups import *
from .scene.moving_camera_scene import *
from .scene.scene import *
from .scene.scene_file_writer import *
Expand Down
139 changes: 139 additions & 0 deletions manim/scene/groups.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
from __future__ import annotations

import types
from collections.abc import Callable
from typing import TYPE_CHECKING, ClassVar, Generic, ParamSpec, TypeVar, final, overload

from typing_extensions import Self, TypedDict, Unpack

if TYPE_CHECKING:
from .scene import Scene
Dismissed Show dismissed Hide dismissed

__all__ = ["group"]


P = ParamSpec("P")
T = TypeVar("T")


class SectionGroupData(TypedDict, total=False):
"""(Public) data for a :class:`.SectionGroup` in a :class:`.Scene`."""

skip: bool
order: int


# mark as final because _cls_instance_count doesn't
# work with inheritance
@final
class SectionGroup(Generic[P, T]):
"""A section in a :class:`.Scene`.

It holds data about each subsection, and keeps track of the order
of the sections via :attr:`~SectionGroup.order`.

.. warning::

:attr:`~SectionGroup.func` is effectively a function - it is not
bound to the scene, and thus must be called with the first argument
as an instance of :class:`.Scene`.
"""

_cls_instance_count: ClassVar[int] = 0
"""How many times the class has been instantiated.

This is also used for ordering sections, because of the order
decorators are called in a class.
"""

def __init__(
self, func: Callable[P, T], **kwargs: Unpack[SectionGroupData]
) -> None:
self.func = func

self.skip = kwargs.get("skip", False)

# update the order counter
self.order = self._cls_instance_count
self.__class__._cls_instance_count += 1
if "order" in kwargs:
self.order = kwargs["order"]

def __str__(self) -> str:
skip = self.skip
order = self.order
return f"{self.__class__.__name__}({order=}, {skip=})"

def __repr__(self) -> str:
# return a slightly more verbose repr
s = str(self).removesuffix(")")
func = self.func
return f"{s}, {func=})"

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
return self.func(*args, **kwargs)

def bind(self, instance: Scene) -> Self:
"""Binds :attr:`func` to the scene instance, making :attr:`func` a method.

This allows the section to be called without the scene being passed explicitly.
"""
self.func = types.MethodType(self.func, instance)
return self

def __get__(self, instance: Scene, _owner: type[Scene]) -> Self:
"""Descriptor to bind the section to the scene instance.

This is called implicitly by python when methods are being bound.
"""
return self # HELPME use binding
return self.bind(instance)

Check warning

Code scanning / CodeQL

Unreachable code Warning

This statement is unreachable.


@overload
def group(
func: Callable[P, T],
**kwargs: Unpack[SectionGroupData],
) -> SectionGroup[P, T]: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.


@overload
def group(
func: None = None,
**kwargs: Unpack[SectionGroupData],
) -> Callable[[Callable[P, T]], SectionGroup[P, T]]: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.


def group(
func: Callable[P, T] | None = None, **kwargs: Unpack[SectionGroupData]
) -> SectionGroup[P, T] | Callable[[Callable[P, T]], SectionGroup[P, T]]:
r"""Decorator to create a SectionGroup in the scene.

Example
-------

.. code-block:: python

class MyScene(Scene):
SectionGroups_api = True

@SectionGroup
def first_SectionGroup(self):
pass

@SectionGroup(skip=True)
def second_SectionGroup(self):
pass

Parameters
----------
func : Callable
The subsection.
skip : bool, optional
Whether to skip the section, by default False
"""

def wrapper(func: Callable[P, T]) -> SectionGroup[P, T]:
return SectionGroup(func, **kwargs)

return wrapper(func) if func is not None else wrapper
62 changes: 61 additions & 1 deletion manim/scene/scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from ..utils.family_ops import restructure_list_to_exclude_certain_family_members
from ..utils.file_ops import open_media_file
from ..utils.iterables import list_difference_update, list_update
from .groups import SectionGroup

Check failure

Code scanning / CodeQL

Module-level cyclic import Error

'SectionGroup' may not be defined if module
manim.scene.groups
is imported before module
manim.scene.scene
, as the
definition
of SectionGroup occurs after the cyclic
import
of manim.scene.scene.

if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
Expand Down Expand Up @@ -98,6 +99,11 @@

"""

groups_api = False
section_groups = []
""" Internal attributes to allow group decorator in the class.
TODO Document groups """

def __init__(
self,
renderer=None,
Expand All @@ -110,6 +116,7 @@
self.always_update_mobjects = always_update_mobjects
self.random_seed = random_seed
self.skip_animations = skip_animations
self.group_skip_animations = False # group animation are played by default

self.animations = None
self.stop_condition = None
Expand Down Expand Up @@ -154,6 +161,13 @@
random.seed(self.random_seed)
np.random.seed(self.random_seed)

self.section_groups = self.build_section_groups()
for group in self.section_groups:
if not isinstance(group, SectionGroup):
raise AttributeError(
f"The method {group} doesn't look like it is decorated with the @group decorator."
)

@property
def camera(self):
return self.renderer.camera
Expand Down Expand Up @@ -303,7 +317,18 @@
:meth:`Scene.tear_down`

"""
pass # To be implemented in subclasses
for (
group
) in self.section_groups: # this is empty if section groups are disabled
self.group_skip_animations = group.skip

self.next_section(
group.skip
) # create a default section at the start of each group
group(self) # launch the group # HELPME make a clean call

self.group_skip_animations = False
# To be implemented in subclasses if groups API is disabled

def next_section(
self,
Expand All @@ -315,8 +340,43 @@
``skip_animations`` skips the rendering of all animations in this section.
Refer to :doc:`the documentation</tutorials/output_and_config>` on how to use sections.
"""
# if group is disabled, all sections in it are also disabled
skip_animations = skip_animations or self.group_skip_animations
self.renderer.file_writer.next_section(name, section_type, skip_animations)

def build_section_groups(self) -> List[SectionGroup]:
"""Builds the group list depending on the API used (method list, enabled, disabled)."""
if self.section_groups:
# if a group list is provided we use it by default
def get_group_object(group):
if hasattr(self, group):
return getattr(self, group)
else:
raise AttributeError(
f"Couldn't find method {group} in class {__cls__}. Did you spell it correctly?"
)

return [get_group_object(group) for group in self.section_groups]
elif self.groups_api:
# groups api enabled, but no list provided so we have to look at the decorated groups in order
return self.find_groups()
else:
# groups api disabled
return []

def find_groups(self) -> list[SectionGroup]:
"""Find all groups in a :class:`.Scene` if groups api is turned on."""
groups: list[SectionGroup] = [
bound
for _, bound in inspect.getmembers(
self, predicate=lambda x: isinstance(x, SectionGroup)
)
]
# we can't care about the actual value of the order
# because that would break files with multiple scenes that have sections
groups.sort(key=lambda x: x.order)
return groups

def __str__(self):
return self.__class__.__name__

Expand Down
38 changes: 38 additions & 0 deletions tests/test_scene_rendering/simple_scenes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
"InteractiveStaticScene",
"SceneWithSections",
"ElaborateSceneWithSections",
"SceneWithGroupAPI",
"SceneWithGroupList",
]


Expand Down Expand Up @@ -165,3 +167,39 @@ def construct(self):
self.next_section("fade out")
self.play(FadeOut(square))
self.wait()


class SceneWithGroupAPI(Scene):
groups_api = True

def __init__(self):
super().__init__()

self.square = Square()
self.circle = Circle()

@group
def transform(self):
self.play(TransformFromCopy(self.square, self.circle))

@group
def back_transform(self):
self.play(Transform(self.circle, self.square))


class SceneWithGroupList(Scene):
section_groups = ["transform", "back_transform"]

def __init__(self):
super().__init__()

self.square = Square()
self.circle = Circle()

@group
def back_transform(self):
self.play(Transform(self.circle, self.square))

@group
def transform(self):
self.play(TransformFromCopy(self.square, self.circle))
19 changes: 19 additions & 0 deletions tests/test_scene_rendering/test_sections.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from tests.assert_utils import assert_dir_exists, assert_dir_not_exists

from ..utils.video_tester import video_comparison
from .simple_scenes import SceneWithGroupAPI, SceneWithGroupList, SquareToCircle


@pytest.mark.slow
Expand Down Expand Up @@ -103,3 +104,21 @@ def test_skip_animations(tmp_path, manim_cfg_file, simple_scenes_path):
]
_, err, exit_code = capture(command)
assert exit_code == 0, err


def test_groups_api(tmp_path):
find_api_scene = SceneWithGroupAPI()
list_api_scene = SceneWithGroupList()

assert not SquareToCircle().section_groups
assert len(list_api_scene.section_groups) == len(find_api_scene.section_groups) == 2
assert (
list_api_scene.section_groups[0].func.__name__
== find_api_scene.section_groups[0].func.__name__
== "transform"
)
assert (
list_api_scene.section_groups[1].func.__name__
== find_api_scene.section_groups[1].func.__name__
== "back_transform"
)
Loading