Skip to content

Commit

Permalink
Apply separate updaters to each edge in a Graph
Browse files Browse the repository at this point in the history
  • Loading branch information
tlcyr4 committed Jul 1, 2024
1 parent fc5e878 commit 8b5a789
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 128 deletions.
202 changes: 74 additions & 128 deletions manim/mobject/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,6 @@ def __init__(

nx_graph = self._empty_networkx_graph()
nx_graph.add_nodes_from(vertices)
nx_graph.add_edges_from(edges)
self._graph = nx_graph

if isinstance(labels, dict):
Expand Down Expand Up @@ -627,58 +626,46 @@ def __init__(
self.vertices = {v: vertex_type(**self._vertex_config[v]) for v in vertices}
self.vertices.update(vertex_mobjects)

if edge_config is None:
edge_config = {}

default_configs, per_edge_tip_configs = GenericGraph._separate_child_configs(
edge_config, lambda k: isinstance(k, tuple)
)
self.default_edge_config, self.default_tip_config = (
GenericGraph._separate_tip_configs(default_configs)
)

self.edges = {}
for edge in edges:
self._add_edge(edge, edge_type, per_edge_tip_configs.get(edge))

self.change_layout(
layout=layout,
layout_scale=layout_scale,
layout_config=layout_config,
partitions=partitions,
root_vertex=root_vertex,
)

# build edge_config
if edge_config is None:
edge_config = {}
default_tip_config = {}
default_edge_config = {}
if edge_config:
default_tip_config = edge_config.pop("tip_config", {})
default_edge_config = {
k: v
for k, v in edge_config.items()
if not isinstance(
k, tuple
) # everything that is not an edge is an option
}
self._edge_config = {}
self._tip_config = {}
for e in edges:
if e in edge_config:
self._tip_config[e] = edge_config[e].pop(
"tip_config", copy(default_tip_config)
)
self._edge_config[e] = edge_config[e]
else:
self._tip_config[e] = copy(default_tip_config)
self._edge_config[e] = copy(default_edge_config)

self.default_edge_config = default_edge_config
self._populate_edge_dict(edges, edge_type)

self.add(*self.vertices.values())
self.add(*self.edges.values())

self.add_updater(self.update_edges)

@staticmethod
def _empty_networkx_graph() -> nx.classes.graph.Graph:
"""Return an empty networkx graph for the given graph type."""
raise NotImplementedError("To be implemented in concrete subclasses")

def _populate_edge_dict(
self, edges: list[tuple[Hashable, Hashable]], edge_type: type[Mobject]
):
"""Helper method for populating the edges of the graph."""
raise NotImplementedError("To be implemented in concrete subclasses")
@staticmethod
def _separate_tip_configs(config: dict) -> (dict, dict):
edge_config, tip_config_holder = GenericGraph._separate_child_configs(
config, lambda k: k == "tip_config"
)
return edge_config, tip_config_holder.get("tip_config", {})

@staticmethod
def _separate_child_configs(config: dict, is_child_key) -> (dict, dict):
default_config = {k: v for k, v in config.items() if not is_child_key(k)}
per_child_configs = {k: v for k, v in config.items() if is_child_key(k)}
return default_config, per_child_configs

def __getitem__(self: Graph, v: Hashable) -> Mobject:
return self.vertices[v]
Expand Down Expand Up @@ -955,8 +942,6 @@ def _remove_vertex(self, vertex):
self._vertex_config.pop(vertex)

edge_tuples = [e for e in self.edges if vertex in e]
for e in edge_tuples:
self._edge_config.pop(e)
to_remove = [self.edges.pop(e) for e in edge_tuples]
to_remove.append(self.vertices.pop(vertex))

Expand Down Expand Up @@ -1028,28 +1013,26 @@ def _add_edge(
"""
if edge_config is None:
edge_config = self.default_edge_config.copy()
added_mobjects = []
for v in edge:
if v not in self.vertices:
added_mobjects.append(self._add_vertex(v))
edge_config = {}
added_vertices = [self._add_vertex(v) for v in edge if v not in self.vertices]
u, v = edge

self._graph.add_edge(u, v)

base_edge_config = self.default_edge_config.copy()
base_edge_config.update(edge_config)
edge_config = base_edge_config
self._edge_config[(u, v)] = edge_config
edge_mobject = self._create_edge(self[u], self[v], edge_type, edge_config)
self.edges[(u, v)] = edge_mobject
self.add(edge_mobject)
return self.get_group_class()(*added_vertices, edge_mobject)

def _create_edge(self, u, v, edge_type, config):
edge_mobject = edge_type(
self[u].get_center(), self[v].get_center(), z_index=-1, **edge_config
u.get_center(),
v.get_center(),
z_index=-1,
**{**self.default_edge_config, **config},
)
self.edges[(u, v)] = edge_mobject

self.add(edge_mobject)
added_mobjects.append(edge_mobject)
return self.get_group_class()(*added_mobjects)
edge_mobject.add_updater(self._generate_edge_updater(u, v, config))
return edge_mobject

def add_edges(
self,
Expand Down Expand Up @@ -1087,29 +1070,14 @@ def add_edges(
"""
if edge_config is None:
edge_config = {}
non_edge_settings = {k: v for (k, v) in edge_config.items() if k not in edges}
base_edge_config = self.default_edge_config.copy()
base_edge_config.update(non_edge_settings)
base_edge_config = {e: base_edge_config.copy() for e in edges}
for e in edges:
base_edge_config[e].update(edge_config.get(e, {}))
edge_config = base_edge_config

edge_vertices = set(it.chain(*edges))
new_vertices = [v for v in edge_vertices if v not in self.vertices]
added_vertices = self.add_vertices(*new_vertices, **kwargs)

added_mobjects = sum(
(
self._add_edge(
edge,
edge_type=edge_type,
edge_config=edge_config[edge],
).submobjects
for edge in edges
),
added_vertices,
)

added_mobjects = [
mobject
for edge in edges
for mobject in self._add_edge(
edge, edge_type, edge_config.get(edge)
).submobjects
]
return self.get_group_class()(*added_mobjects)

@override_animate(add_edges)
Expand Down Expand Up @@ -1145,7 +1113,6 @@ def _remove_edge(self, edge: tuple[Hashable]):
edge_mobject = self.edges.pop(edge)

self._graph.remove_edge(*edge)
self._edge_config.pop(edge, None)

self.remove(edge_mobject)
return edge_mobject
Expand Down Expand Up @@ -1544,29 +1511,17 @@ def construct(self):
def _empty_networkx_graph() -> nx.Graph:
return nx.Graph()

def _populate_edge_dict(
self, edges: list[tuple[Hashable, Hashable]], edge_type: type[Mobject]
):
self.edges = {
(u, v): edge_type(
self[u].get_center(),
self[v].get_center(),
z_index=-1,
**self._edge_config[(u, v)],
)
for (u, v) in edges
}

def update_edges(self, graph):
for (u, v), edge in graph.edges.items():
# Undirected graph has a Line edge
def _generate_edge_updater(self, u, v, config):
def edge_updater(edge):
edge.set_points_by_ends(
graph[u].get_center(),
graph[v].get_center(),
buff=self._edge_config.get("buff", 0),
path_arc=self._edge_config.get("path_arc", 0),
u.get_center(),
v.get_center(),
buff=config.get("buff", 0),
path_arc=config.get("path_arc", 0),
)

return edge_updater

def __repr__(self: Graph) -> str:
return f"Undirected graph on {len(self.vertices)} vertices and {len(self.edges)} edges"

Expand Down Expand Up @@ -1751,39 +1706,30 @@ def construct(self):
def _empty_networkx_graph() -> nx.DiGraph:
return nx.DiGraph()

def _populate_edge_dict(
self, edges: list[tuple[Hashable, Hashable]], edge_type: type[Mobject]
):
self.edges = {
(u, v): edge_type(
self[u],
self[v],
z_index=-1,
**self._edge_config[(u, v)],
)
for (u, v) in edges
}

for (u, v), edge in self.edges.items():
edge.add_tip(**self._tip_config[(u, v)])

def update_edges(self, graph):
"""Updates the edges to stick at their corresponding vertices.
def _create_edge(self, u, v, edge_type, config):
edge_config, tip_config = GenericGraph._separate_tip_configs(config)
edge_mobject = edge_type(
u.get_center(),
v.get_center(),
z_index=-1,
**{**self.default_edge_config, **edge_config},
)
edge_mobject.add_updater(self._generate_edge_updater(u, v, edge_config))
edge_mobject.add_tip(**{**self.default_tip_config, **tip_config})
return edge_mobject

Arrow tips need to be repositioned since otherwise they can be
deformed.
"""
for (u, v), edge in graph.edges.items():
def _generate_edge_updater(self, u, v, config):
def edge_updater(edge):
tip = edge.pop_tips()[0]
# Passing the Mobject instead of the vertex makes the tip
# stop on the bounding box of the vertex.
edge.set_points_by_ends(
graph[u],
graph[v],
buff=self._edge_config.get("buff", 0),
path_arc=self._edge_config.get("path_arc", 0),
u.get_center(),
v.get_center(),
buff=config.get("buff", 0),
path_arc=config.get("path_arc", 0),
)
edge.add_tip(tip)

return edge_updater

def __repr__(self: DiGraph) -> str:
return f"Directed graph on {len(self.vertices)} vertices and {len(self.edges)} edges"
Binary file not shown.
Binary file not shown.
34 changes: 34 additions & 0 deletions tests/test_graphical_units/test_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from __future__ import annotations

from manim import *
from manim.utils.testing.frames_comparison import frames_comparison

__module_test__ = "graph"


@frames_comparison(last_frame=False)
def test_graph_concurrent_animations(scene):
vertices = [0, 1]
positions = {0: [-1, 0, 0], 1: [1, 0, 0]}
g = Graph(vertices, [], layout=positions)
scene.play(g[1].animate.move_to([1, 1, 0]), g.animate.add_edges((0, 1)))
scene.wait(0.1)


@frames_comparison(last_frame=False)
def test_digraph_add_edge(scene):
vertices = [0, 1]
positions = {0: [-1, 0, 0], 1: [1, 0, 0]}
g = DiGraph(
vertices,
[],
layout=positions,
edge_config={
"tip_config": {
"tip_shape": ArrowSquareTip,
"tip_length": 0.15,
}
},
)
scene.play(g.animate.add_edges((0, 1)))
scene.wait(0.1)

0 comments on commit 8b5a789

Please sign in to comment.