Skip to content

Commit

Permalink
Add and remove tips and tip configs with edges
Browse files Browse the repository at this point in the history
  • Loading branch information
tlcyr4 committed Jul 9, 2024
1 parent 8d70b0e commit 987f7cf
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 84 deletions.
138 changes: 54 additions & 84 deletions manim/mobject/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import itertools as it
from collections.abc import Hashable, Iterable
from copy import copy
from copy import copy, deepcopy
from typing import TYPE_CHECKING, Any, Literal, Protocol, cast

import networkx as nx
Expand Down Expand Up @@ -626,6 +626,7 @@ def __init__(

self.vertices = {v: vertex_type(**self._vertex_config[v]) for v in vertices}
self.vertices.update(vertex_mobjects)
self.add(*self.vertices.values())

self.change_layout(
layout=layout,
Expand All @@ -635,37 +636,16 @@ def __init__(
root_vertex=root_vertex,
)

# build edge_config
if edge_config is None:
edge_config = {}
default_tip_config = {}
default_edge_config = {}
if edge_config:
default_tip_config = edge_config.pop("tip_config", {})
default_edge_config = {
k: v
for k, v in edge_config.items()
if not isinstance(
k, tuple
) # everything that is not an edge is an option
}
self._edge_config = {}
self._tip_config = {}
for e in edges:
if e in edge_config:
self._tip_config[e] = edge_config[e].pop(
"tip_config", copy(default_tip_config)
)
self._edge_config[e] = edge_config[e]
else:
self._tip_config[e] = copy(default_tip_config)
self._edge_config[e] = copy(default_edge_config)

self.default_edge_config = default_edge_config
self._populate_edge_dict(edges, edge_type)
self.edges = {}
self._edge_config = {}
self.default_edge_config, _ = GenericGraph._split_out_child_configs(
edge_config, lambda k: isinstance(k, tuple)
)

self.add(*self.vertices.values())
self.add(*self.edges.values())
self.add_edges(*edges, edge_type=edge_type, edge_config=edge_config)

self.add_updater(self.update_edges)

Expand All @@ -674,11 +654,11 @@ def _empty_networkx_graph() -> nx.classes.graph.Graph:
"""Return an empty networkx graph for the given graph type."""
raise NotImplementedError("To be implemented in concrete subclasses")

def _populate_edge_dict(
self, edges: list[tuple[Hashable, Hashable]], edge_type: type[Mobject]
):
"""Helper method for populating the edges of the graph."""
raise NotImplementedError("To be implemented in concrete subclasses")
@staticmethod
def _split_out_child_configs(config: dict, is_child_key) -> (dict, dict):
parent_config = {k: v for k, v in config.items() if not is_child_key(k)}
child_configs = {k: v for k, v in config.items() if is_child_key(k)}
return parent_config, child_configs

def __getitem__(self: Graph, v: Hashable) -> Mobject:
return self.vertices[v]
Expand Down Expand Up @@ -1028,28 +1008,20 @@ def _add_edge(
"""
if edge_config is None:
edge_config = self.default_edge_config.copy()
added_mobjects = []
for v in edge:
if v not in self.vertices:
added_mobjects.append(self._add_vertex(v))
u, v = edge
edge_config = {}

added_vertices = [self._add_vertex(v) for v in edge if v not in self.vertices]

u, v = edge
self._graph.add_edge(u, v)

base_edge_config = self.default_edge_config.copy()
base_edge_config.update(edge_config)
edge_config = base_edge_config
self._edge_config[(u, v)] = edge_config
self._edge_config[edge] = {**self.default_edge_config, **edge_config}
edge_mobject = self._create_edge_mobject(edge, edge_type)

edge_mobject = edge_type(
self[u].get_center(), self[v].get_center(), z_index=-1, **edge_config
)
self.edges[(u, v)] = edge_mobject

self.add(edge_mobject)
added_mobjects.append(edge_mobject)
return self.get_group_class()(*added_mobjects)

return self.get_group_class()(*added_vertices, edge_mobject)

def add_edges(
self,
Expand Down Expand Up @@ -1087,13 +1059,12 @@ def add_edges(
"""
if edge_config is None:
edge_config = {}
non_edge_settings = {k: v for (k, v) in edge_config.items() if k not in edges}
base_edge_config = self.default_edge_config.copy()
base_edge_config.update(non_edge_settings)
base_edge_config = {e: base_edge_config.copy() for e in edges}
for e in edges:
base_edge_config[e].update(edge_config.get(e, {}))
edge_config = base_edge_config
else:
edge_config = deepcopy(edge_config)

batch_default_config, custom_configs = GenericGraph._split_out_child_configs(
edge_config, lambda k: isinstance(k, tuple)
)

edge_vertices = set(it.chain(*edges))
new_vertices = [v for v in edge_vertices if v not in self.vertices]
Expand All @@ -1104,7 +1075,7 @@ def add_edges(
self._add_edge(
edge,
edge_type=edge_type,
edge_config=edge_config[edge],
edge_config={**batch_default_config, **custom_configs.get(edge, {})},
).submobjects
for edge in edges
),
Expand Down Expand Up @@ -1145,7 +1116,7 @@ def _remove_edge(self, edge: tuple[Hashable]):
edge_mobject = self.edges.pop(edge)

self._graph.remove_edge(*edge)
self._edge_config.pop(edge, None)
self._edge_config.pop(edge)

self.remove(edge_mobject)
return edge_mobject
Expand Down Expand Up @@ -1544,18 +1515,14 @@ def construct(self):
def _empty_networkx_graph() -> nx.Graph:
return nx.Graph()

def _populate_edge_dict(
self, edges: list[tuple[Hashable, Hashable]], edge_type: type[Mobject]
):
self.edges = {
(u, v): edge_type(
self[u].get_center(),
self[v].get_center(),
z_index=-1,
**self._edge_config[(u, v)],
)
for (u, v) in edges
}
def _create_edge_mobject(self, edge, edge_type):
u, v = edge
return edge_type(
self[u].get_center(),
self[v].get_center(),
z_index=-1,
**self._edge_config[(u, v)],
)

def update_edges(self, graph):
for (u, v), edge in graph.edges.items():
Expand Down Expand Up @@ -1751,21 +1718,24 @@ def construct(self):
def _empty_networkx_graph() -> nx.DiGraph:
return nx.DiGraph()

def _populate_edge_dict(
self, edges: list[tuple[Hashable, Hashable]], edge_type: type[Mobject]
):
self.edges = {
(u, v): edge_type(
self[u],
self[v],
z_index=-1,
**self._edge_config[(u, v)],
)
for (u, v) in edges
}
@staticmethod
def _split_out_tip_configs(config: dict) -> (dict, dict):
edge_config, tip_config = GenericGraph._split_out_child_configs(
config, lambda k: k == "tip_config"
)
return edge_config, tip_config.get("tip_config", {})

for (u, v), edge in self.edges.items():
edge.add_tip(**self._tip_config[(u, v)])
def _create_edge_mobject(self, edge, edge_type):
edge_config, tip_config = DiGraph._split_out_tip_configs(self._edge_config[edge])
u, v = edge
edge_mobject = edge_type(
self[u],
self[v],
z_index=-1,
**edge_config,
)
edge_mobject.add_tip(**tip_config)
return edge_mobject

def update_edges(self, graph):
"""Updates the edges to stick at their corresponding vertices.
Expand Down
Binary file not shown.
40 changes: 40 additions & 0 deletions tests/test_graphical_units/test_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from __future__ import annotations

from manim import *
from manim.utils.testing.frames_comparison import frames_comparison

__module_test__ = "graph"


@frames_comparison
def test_digraph_add_edges(scene):
vertices = range(5)
edges = [
(0, 1),
(1, 2),
(3, 2),
(3, 4),
]

edge_config = {
"stroke_width": 2,
"tip_config": {
"tip_shape": ArrowSquareTip,
"tip_length": 0.15,
},
(3, 4): {
"color": RED,
"tip_config": {"tip_length": 0.25, "tip_width": 0.25}
},
}

g = DiGraph(
vertices,
[],
labels=True,
layout="circular",
).scale(1.4)

g.add_edges(*edges, edge_config=edge_config)

scene.add(g)

0 comments on commit 987f7cf

Please sign in to comment.