diff --git a/manim/mobject/graph.py b/manim/mobject/graph.py index 72a26d27db..b2dfb92761 100644 --- a/manim/mobject/graph.py +++ b/manim/mobject/graph.py @@ -9,7 +9,7 @@ import itertools as it from collections.abc import Hashable, Iterable -from copy import copy +from copy import copy, deepcopy from typing import TYPE_CHECKING, Any, Literal, Protocol, cast import networkx as nx @@ -626,6 +626,7 @@ def __init__( self.vertices = {v: vertex_type(**self._vertex_config[v]) for v in vertices} self.vertices.update(vertex_mobjects) + self.add(*self.vertices.values()) self.change_layout( layout=layout, @@ -635,37 +636,16 @@ def __init__( root_vertex=root_vertex, ) - # build edge_config if edge_config is None: edge_config = {} - default_tip_config = {} - default_edge_config = {} - if edge_config: - default_tip_config = edge_config.pop("tip_config", {}) - default_edge_config = { - k: v - for k, v in edge_config.items() - if not isinstance( - k, tuple - ) # everything that is not an edge is an option - } - self._edge_config = {} - self._tip_config = {} - for e in edges: - if e in edge_config: - self._tip_config[e] = edge_config[e].pop( - "tip_config", copy(default_tip_config) - ) - self._edge_config[e] = edge_config[e] - else: - self._tip_config[e] = copy(default_tip_config) - self._edge_config[e] = copy(default_edge_config) - self.default_edge_config = default_edge_config - self._populate_edge_dict(edges, edge_type) + self.edges = {} + self._edge_config = {} + self.default_edge_config, _ = GenericGraph._split_out_child_configs( + edge_config, lambda k: isinstance(k, tuple) + ) - self.add(*self.vertices.values()) - self.add(*self.edges.values()) + self.add_edges(*edges, edge_type=edge_type, edge_config=edge_config) self.add_updater(self.update_edges) @@ -674,11 +654,11 @@ def _empty_networkx_graph() -> nx.classes.graph.Graph: """Return an empty networkx graph for the given graph type.""" raise NotImplementedError("To be implemented in concrete subclasses") - def _populate_edge_dict( - self, edges: list[tuple[Hashable, Hashable]], edge_type: type[Mobject] - ): - """Helper method for populating the edges of the graph.""" - raise NotImplementedError("To be implemented in concrete subclasses") + @staticmethod + def _split_out_child_configs(config: dict, is_child_key) -> (dict, dict): + parent_config = {k: v for k, v in config.items() if not is_child_key(k)} + child_configs = {k: v for k, v in config.items() if is_child_key(k)} + return parent_config, child_configs def __getitem__(self: Graph, v: Hashable) -> Mobject: return self.vertices[v] @@ -1028,28 +1008,20 @@ def _add_edge( """ if edge_config is None: - edge_config = self.default_edge_config.copy() - added_mobjects = [] - for v in edge: - if v not in self.vertices: - added_mobjects.append(self._add_vertex(v)) - u, v = edge + edge_config = {} + + added_vertices = [self._add_vertex(v) for v in edge if v not in self.vertices] + u, v = edge self._graph.add_edge(u, v) - base_edge_config = self.default_edge_config.copy() - base_edge_config.update(edge_config) - edge_config = base_edge_config - self._edge_config[(u, v)] = edge_config + self._edge_config[edge] = {**self.default_edge_config, **edge_config} + edge_mobject = self._create_edge_mobject(edge, edge_type) - edge_mobject = edge_type( - self[u].get_center(), self[v].get_center(), z_index=-1, **edge_config - ) self.edges[(u, v)] = edge_mobject - self.add(edge_mobject) - added_mobjects.append(edge_mobject) - return self.get_group_class()(*added_mobjects) + + return self.get_group_class()(*added_vertices, edge_mobject) def add_edges( self, @@ -1087,13 +1059,12 @@ def add_edges( """ if edge_config is None: edge_config = {} - non_edge_settings = {k: v for (k, v) in edge_config.items() if k not in edges} - base_edge_config = self.default_edge_config.copy() - base_edge_config.update(non_edge_settings) - base_edge_config = {e: base_edge_config.copy() for e in edges} - for e in edges: - base_edge_config[e].update(edge_config.get(e, {})) - edge_config = base_edge_config + else: + edge_config = deepcopy(edge_config) + + batch_default_config, custom_configs = GenericGraph._split_out_child_configs( + edge_config, lambda k: isinstance(k, tuple) + ) edge_vertices = set(it.chain(*edges)) new_vertices = [v for v in edge_vertices if v not in self.vertices] @@ -1104,7 +1075,7 @@ def add_edges( self._add_edge( edge, edge_type=edge_type, - edge_config=edge_config[edge], + edge_config={**batch_default_config, **custom_configs.get(edge, {})}, ).submobjects for edge in edges ), @@ -1145,7 +1116,7 @@ def _remove_edge(self, edge: tuple[Hashable]): edge_mobject = self.edges.pop(edge) self._graph.remove_edge(*edge) - self._edge_config.pop(edge, None) + self._edge_config.pop(edge) self.remove(edge_mobject) return edge_mobject @@ -1544,18 +1515,14 @@ def construct(self): def _empty_networkx_graph() -> nx.Graph: return nx.Graph() - def _populate_edge_dict( - self, edges: list[tuple[Hashable, Hashable]], edge_type: type[Mobject] - ): - self.edges = { - (u, v): edge_type( - self[u].get_center(), - self[v].get_center(), - z_index=-1, - **self._edge_config[(u, v)], - ) - for (u, v) in edges - } + def _create_edge_mobject(self, edge, edge_type): + u, v = edge + return edge_type( + self[u].get_center(), + self[v].get_center(), + z_index=-1, + **self._edge_config[(u, v)], + ) def update_edges(self, graph): for (u, v), edge in graph.edges.items(): @@ -1751,21 +1718,24 @@ def construct(self): def _empty_networkx_graph() -> nx.DiGraph: return nx.DiGraph() - def _populate_edge_dict( - self, edges: list[tuple[Hashable, Hashable]], edge_type: type[Mobject] - ): - self.edges = { - (u, v): edge_type( - self[u], - self[v], - z_index=-1, - **self._edge_config[(u, v)], - ) - for (u, v) in edges - } + @staticmethod + def _split_out_tip_configs(config: dict) -> (dict, dict): + edge_config, tip_config = GenericGraph._split_out_child_configs( + config, lambda k: k == "tip_config" + ) + return edge_config, tip_config.get("tip_config", {}) - for (u, v), edge in self.edges.items(): - edge.add_tip(**self._tip_config[(u, v)]) + def _create_edge_mobject(self, edge, edge_type): + edge_config, tip_config = DiGraph._split_out_tip_configs(self._edge_config[edge]) + u, v = edge + edge_mobject = edge_type( + self[u], + self[v], + z_index=-1, + **edge_config, + ) + edge_mobject.add_tip(**tip_config) + return edge_mobject def update_edges(self, graph): """Updates the edges to stick at their corresponding vertices. diff --git a/tests/test_graphical_units/control_data/graph/digraph_add_edges.npz b/tests/test_graphical_units/control_data/graph/digraph_add_edges.npz new file mode 100644 index 0000000000..583e741abe Binary files /dev/null and b/tests/test_graphical_units/control_data/graph/digraph_add_edges.npz differ diff --git a/tests/test_graphical_units/test_graph.py b/tests/test_graphical_units/test_graph.py new file mode 100644 index 0000000000..aab033f9b8 --- /dev/null +++ b/tests/test_graphical_units/test_graph.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from manim import * +from manim.utils.testing.frames_comparison import frames_comparison + +__module_test__ = "graph" + + +@frames_comparison +def test_digraph_add_edges(scene): + vertices = range(5) + edges = [ + (0, 1), + (1, 2), + (3, 2), + (3, 4), + ] + + edge_config = { + "stroke_width": 2, + "tip_config": { + "tip_shape": ArrowSquareTip, + "tip_length": 0.15, + }, + (3, 4): { + "color": RED, + "tip_config": {"tip_length": 0.25, "tip_width": 0.25} + }, + } + + g = DiGraph( + vertices, + [], + labels=True, + layout="circular", + ).scale(1.4) + + g.add_edges(*edges, edge_config=edge_config) + + scene.add(g)