Skip to content

Commit

Permalink
fix: avoid C++ initializer list in PYX
Browse files Browse the repository at this point in the history
Initializer lists {a, b, c} are interpreted as sets in PYX, which led to
slow and incorrect assignment of array attributes.
  • Loading branch information
funkey committed Oct 8, 2024
1 parent 5c44223 commit 732a77f
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 9 deletions.
46 changes: 39 additions & 7 deletions spatial_graph/graph/wrapper_template.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,13 @@ cdef class Graph:
self,
NodeType node,
$dtype.to_pyxtype(use_memory_view=True) $name):
self._graph.node_prop(node).${name} = $dtype.to_rvalue(name=$name)
%if $dtype.is_array
%for j in range($dtype.size)
self._graph.node_prop(node).${name}[$j] = ${name}[$j]
%end for
%else
self._graph.node_prop(node).${name} = $name
%end if

def set_nodes_data_${name}(
self,
Expand All @@ -453,16 +459,29 @@ cdef class Graph:
cdef Py_ssize_t i = 0

# all nodes requested
%set $rvalue=$dtype.to_rvalue(name=$name, array_index="i")
if nodes is None:
while node_it != node_end:
self._graph.node_prop(node_it).$name = $rvalue
%if $dtype.is_array
node_data = &self._graph.node_prop(node_it)
%for j in range($dtype.size)
node_data.${name}[$j] = ${name}[i, $j]
%end for
%else
self._graph.node_prop(node_it).$name = ${name}[i]
%end if
inc(node_it)
i += 1
else:
assert len(nodes) == len($name)
for i in range(len(nodes)):
self._graph.node_prop(nodes[i]).$name = $rvalue
%if dtype.is_array
node_data = &self._graph.node_prop(nodes[i])
%for j in range($dtype.size)
node_data.${name}[$j] = ${name}[i, $j]
%end for
%else
self._graph.node_prop(nodes[i]).$name = ${name}[i]
%end if
%end for

%for name, dtype in $edge_attr_dtypes.items()
Expand Down Expand Up @@ -541,7 +560,14 @@ cdef class Graph:
self,
NodeType u, NodeType v,
$dtype.to_pyxtype(use_memory_view=True) $name):
self._graph.edge_prop(u, v).${name} = $dtype.to_rvalue(name=$name)
%if dtype.is_array
edge_data = &self._graph.edge_prop(u, v)
%for j in range($dtype.size)
edge_data.${name}[$j] = ${name}[$j]
%end for
%else
self._graph.edge_prop(u, v).${name} = $name
%end if

def set_edges_data_${name}(
self,
Expand All @@ -554,9 +580,15 @@ cdef class Graph:
assert len(us) == len(vs)
num_edges = len(us)

%set $rvalue = $dtype.to_rvalue(name=$name, array_index="i")
for i in range(num_edges):
self._graph.edge_prop(us[i], vs[i]).$name = $rvalue
%if $dtype.is_array
edge_data = &self._graph.edge_prop(us[i], vs[i])
%for j in range($dtype.size)
edge_data.${name}[$j] = ${name}[i, $j]
%end for
%else
self._graph.edge_prop(us[i], vs[i]).$name = ${name}[i]
%end if
%end for

# modify graph
Expand Down
43 changes: 41 additions & 2 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def test_directed_edges():
def test_attribute_modification():
graph = sg.Graph(
"uint64",
{"attr1": "double", "attr2": "int"},
{"attr1": "double", "attr2": "int", "attr3": "float32[3]"},
{"attr1": "int[4]"},
directed=False,
)
Expand All @@ -119,12 +119,18 @@ def test_attribute_modification():
np.array([1, 2, 3, 4, 5], dtype="uint64"),
attr1=np.array([0.1, 0.2, 0.3, 0.4, 0.5], dtype="double"),
attr2=np.array([1, 2, 3, 4, 5], dtype="int"),
attr3=np.array([
[0.1, 0.2, 0.3],
[0.1, 0.2, 0.3],
[0.1, 0.2, 0.3],
[0.1, 0.2, 0.3],
[0.1, 0.2, 0.3]], dtype="float32")
)

graph.add_edges(
np.array([[1, 2], [3, 4], [5, 1]], dtype="uint64"),
attr1=np.array(
[[1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7], [5, 6, 7, 8]],
[[1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6]],
dtype="int",
),
)
Expand All @@ -134,10 +140,12 @@ def test_attribute_modification():
for node, attrs in graph.nodes(data=True):
attrs.attr1 += 10.0
attrs.attr2 *= 2
attrs.attr3 *= np.float32(3.0)

for node, attrs in graph.nodes(data=True):
assert attrs.attr1 == (node / 10.0) + 10.0
assert attrs.attr2 == node * 2
np.testing.assert_array_almost_equal(attrs.attr3, [0.3, 0.6, 0.9])

# modify via attribute views (single item):

Expand All @@ -157,6 +165,37 @@ def test_attribute_modification():
assert graph.node_attrs[3].attr2 == 60
assert graph.node_attrs[4].attr2 == 80

graph.node_attrs[[2, 3, 4]].attr3 += np.array([
[1, 1, 1],
[2, 2, 2],
[3, 3, 3]], dtype="float32")
np.testing.assert_array_almost_equal(graph.node_attrs[2].attr3, [1.3, 1.6, 1.9])
np.testing.assert_array_almost_equal(graph.node_attrs[3].attr3, [2.3, 2.6, 2.9])
np.testing.assert_array_almost_equal(graph.node_attrs[4].attr3, [3.3, 3.6, 3.9])

# modify edge attribute
np.testing.assert_array_equal(
graph.edge_attrs[[[1, 2], [5, 1]]].attr1,
[
[1, 2, 3, 4],
[3, 4, 5, 6]
]
)
graph.edge_attrs[[[1, 2], [5, 1]]].attr1 = np.array(
[
[11, 22, 33, 44],
[30, 40, 50, 60]
],
dtype="int"
)
np.testing.assert_array_equal(
graph.edge_attrs[[[1, 2], [3, 4], [5, 1]]].attr1,
[
[11, 22, 33, 44],
[2, 3, 4, 5],
[30, 40, 50, 60]
]
)

def test_missing_nodes_edges():
graph = sg.Graph(
Expand Down

0 comments on commit 732a77f

Please sign in to comment.