Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

2.4x avg speedup on PMG struct to CHGNET graph conversion #40

Merged
merged 18 commits into from
Jun 28, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 99 additions & 5 deletions chgnet/graph/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
from typing import TYPE_CHECKING, Literal

import numpy as np
import torch
from torch import Tensor, nn

Expand All @@ -14,6 +15,11 @@

datatype = torch.float32

# Optional compiled backend for graph construction. The import is allowed to
# fail (e.g. the extension was not built); in that case ``make_graph`` is bound
# to ``None`` so the name always exists and callers can test for availability
# instead of hitting a NameError.
try:
    from chgnet.graph.cygraph import make_graph
except ImportError:
    make_graph = None
    print("Error importing fast graph conversion (cygraph). Reverting to legacy.")


class CrystalGraphConverter(nn.Module):
"""Convert a pymatgen.core.Structure to a CrystalGraph
Expand Down Expand Up @@ -45,6 +51,7 @@ def forward(
graph_id=None,
mp_id=None,
on_isolated_atoms: Literal["ignore", "warn", "error"] = "error",
graph_converter: Literal["legacy", "fast"] = "fast",
) -> CrystalGraph:
"""Convert a structure, return a CrystalGraph.

Expand All @@ -55,8 +62,9 @@ def forward(
mp_id (str): Materials Project id of this structure
Default = None
on_isolated_atoms ('ignore' | 'warn' | 'error'): how to handle Structures
with isolated atoms.
Default = 'error'
with isolated atoms. Default = 'error'
graph_converter ('legacy' | 'fast'): graph converter to use when converting.
Default = 'fast'

Return:
CrystalGraph that is ready to use by CHGNet
Expand All @@ -74,9 +82,22 @@ def forward(
center_index, neighbor_index, image, distance = self.get_neighbors(structure)

# Make Graph
graph = Graph([Node(index=i) for i in range(n_atoms)])
for ii, jj, img, dist in zip(center_index, neighbor_index, image, distance):
graph.add_edge(center_index=ii, neighbor_index=jj, image=img, distance=dist)
if graph_converter == "fast":
try:
graph = self._create_graph_fast(
n_atoms, center_index, neighbor_index, image, distance
)
except Exception:
print("Failed to retrieve fast graph converter. Reverting to legacy.")
graph = self._create_graph_legacy(
n_atoms, center_index, neighbor_index, image, distance
)
elif graph_converter == "legacy":
graph = self._create_graph_legacy(
n_atoms, center_index, neighbor_index, image, distance
)
else:
raise ValueError(f"No graph_converter named {graph_converter}")

# Atom Graph
atom_graph, directed2undirected = graph.adjacency_list()
Expand Down Expand Up @@ -127,6 +148,79 @@ def forward(
bond_graph_cutoff=self.bond_graph_cutoff,
)

def _create_graph_legacy(
    self,
    n_atoms: int,
    center_index: np.ndarray,
    neighbor_index: np.ndarray,
    image: np.ndarray,
    distance: np.ndarray,
) -> Graph:
    """Build a Graph edge by edge in pure Python.

    One Node is created per atom, then every (center, neighbor, image,
    distance) tuple is registered through ``Graph.add_edge``.

    Args:
        n_atoms (int): the number of atoms in the structure
        center_index (np.ndarray): indices of the center atom of each edge.
            Shape: (# of edges, )
        neighbor_index (np.ndarray): indices of the neighbor atom of each edge.
            Shape: (# of edges, )
        image (np.ndarray): image of each edge. Shape: (# of edges, 3)
        distance (np.ndarray): distance of each edge. Shape: (# of edges, )

    Return:
        Graph data structure used to create Crystal_Graph object
    """
    nodes = [Node(index=atom_idx) for atom_idx in range(n_atoms)]
    graph = Graph(nodes)

    # The four arrays are walked in lockstep; each position describes one edge.
    edge_records = zip(center_index, neighbor_index, image, distance)
    for center, neighbor, edge_image, edge_distance in edge_records:
        graph.add_edge(
            center_index=center,
            neighbor_index=neighbor,
            image=edge_image,
            distance=edge_distance,
        )

    return graph

def _create_graph_fast(
    self,
    n_atoms: int,
    center_index: np.ndarray,
    neighbor_index: np.ndarray,
    image: np.ndarray,
    distance: np.ndarray,
) -> Graph:
    """Build a Graph by delegating edge construction to ``make_graph``.

    Counterpart of ``_create_graph_legacy`` that hands the edge arrays to the
    compiled ``make_graph`` routine and attaches its results to a new Graph.

    Args:
        n_atoms (int): the number of atoms in the structure
        center_index (np.ndarray): indices of the center atom of each edge.
            Shape: (# of edges, )
        neighbor_index (np.ndarray): indices of the neighbor atom of each edge.
            Shape: (# of edges, )
        image (np.ndarray): image of each edge. Shape: (# of edges, 3)
        distance (np.ndarray): distance of each edge. Shape: (# of edges, )

    Return:
        Graph data structure used to create Crystal_Graph object
    """
    # make_graph takes C-contiguous buffers, so normalise the layout first.
    centers = np.ascontiguousarray(center_index)
    neighbors = np.ascontiguousarray(neighbor_index)
    images = np.ascontiguousarray(image, dtype=np.int_)
    distances = np.ascontiguousarray(distance)
    n_edges = len(centers)

    converted = make_graph(centers, n_edges, neighbors, images, distances, n_atoms)
    nodes, directed_edges_list, undirected_edges_list, undirected_edges = converted

    # Graph is seeded with the nodes only; the edge containers are attached
    # afterwards as attributes.
    graph = Graph(nodes=nodes)
    graph.directed_edges_list = directed_edges_list
    graph.undirected_edges_list = undirected_edges_list
    graph.undirected_edges = undirected_edges

    return graph

def get_neighbors(
self, structure: Structure
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
Expand Down
169 changes: 169 additions & 0 deletions chgnet/graph/cygraph.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
# cython: language_level=3
# cython: initializedcheck=False
# cython: nonecheck=False
# cython: boundscheck=False
# cython: wraparound=False
# cython: cdivision=False
# cython: profile=False
# distutils: language = c

import chgnet.graph.graph
from cython.operator import dereference
import numpy as np
from libc cimport time
from libc.stdio cimport printf
from libc.stdlib cimport free

# Declarations of the C structs and functions that this module uses from
# create_graph.c. Notes on field meaning below are based on how make_graph
# (in this file) reads them; the C source is the authority.
cdef extern from 'fast_converter_libraries/create_graph.c':
# A graph node. make_graph walks `num_neighbors` entries obtained through
# get_neighbors(); `neighbors` itself is not read directly from Cython.
ctypedef struct Node:
long index
LongToDirectedEdgeList* neighbors
long num_neighbors

# The two node indices an edge connects.
ctypedef struct NodeIndexPair:
long center
long neighbor

# An undirected edge; make_graph reads `num_directed_edges` entries of
# `directed_edge_indices`.
ctypedef struct UndirectedEdge:
NodeIndexPair nodes
long index
long* directed_edge_indices
long num_directed_edges
double distance

# A directed edge. make_graph does not read `image` from here; it looks the
# image up in the input array by `index` instead.
ctypedef struct DirectedEdge:
NodeIndexPair nodes
long index
const long* image
long undirected_edge_index
double distance

# One neighbor group of a node: `key` is used by make_graph as the key of the
# Python `neighbors` dict, with `num_directed_edges_in_group` edges listed.
ctypedef struct LongToDirectedEdgeList:
long key
DirectedEdge** directed_edges_list
int num_directed_edges_in_group

# Presumably the C-side map from node pair to undirected edges; make_graph
# never reads it and rebuilds the mapping in Python — TODO confirm.
ctypedef struct StructToUndirectedEdgeList:
NodeIndexPair key
UndirectedEdge** undirected_edges_list
int num_undirected_edges_in_group


# Aggregate result returned by create_graph: element counts plus the arrays
# they describe.
ctypedef struct ReturnElems2:
long num_nodes
long num_directed_edges
long num_undirected_edges

Node* nodes
UndirectedEdge** undirected_edges_list
DirectedEdge** directed_edges_list
StructToUndirectedEdgeList* undirected_edges

# Entry point of the C converter; make_graph passes raw pointers to the input
# buffers together with the edge count (n_e) and the atom count.
ReturnElems2* create_graph(
long* center_index,
long n_e,
long* neighbor_index,
long* image,
double* distance,
long num_atoms)


# Returns an array of neighbor-group pointers for `node`; make_graph reads
# node.num_neighbors entries from it.
LongToDirectedEdgeList** get_neighbors(Node* node)

def make_graph(
    const long[::1] center_index,
    const long n_e,
    const long[::1] neighbor_index,
    const long[:, ::1] image,
    const double[::1] distance,
    const long num_atoms
):
    """Run the C ``create_graph`` routine and wrap its result in Python objects.

    Args:
        center_index: contiguous buffer of center-atom indices, one per edge.
        n_e: number of edges to read from each of the edge buffers.
        neighbor_index: contiguous buffer of neighbor-atom indices, one per edge.
        image: contiguous 2D buffer of edge images, one row per edge.
        distance: contiguous buffer of edge distances, one per edge.
        num_atoms: number of atoms, forwarded to ``create_graph``.

    Returns:
        Tuple ``(py_nodes, py_directed_edges_list, py_undirected_edges_list,
        py_undirected_edges)`` where the first three are lists of
        ``chgnet.graph.graph`` Node / DirectedEdge / UndirectedEdge objects and
        the last maps ``frozenset(edge.nodes)`` to a list of UndirectedEdge.

    Raises:
        ValueError: if any edge buffer holds fewer than ``n_e`` entries.
        RuntimeError: if ``create_graph`` returns NULL.
    """
    # C declarations are hoisted to the top because Cython does not allow
    # cdef statements inside a try block.
    cdef ReturnElems2* returned
    cdef LongToDirectedEdgeList** node_neighbors
    cdef Node this_node
    cdef LongToDirectedEdgeList this_entry
    cdef DirectedEdge* this_DE
    cdef UndirectedEdge* UDE
    cdef long i, j, k

    # Bounds checking is disabled for this module, so a too-large n_e would
    # make the C code read past the end of the buffers. Reject it up front.
    if (
        center_index.shape[0] < n_e
        or neighbor_index.shape[0] < n_e
        or image.shape[0] < n_e
        or distance.shape[0] < n_e
    ):
        raise ValueError("n_e exceeds the length of at least one edge array")

    returned = create_graph(
        <long*> &center_index[0],
        n_e,
        <long*> &neighbor_index[0],
        <long*> &image[0, 0],
        <double*> &distance[0],
        num_atoms,
    )
    if returned == NULL:
        raise RuntimeError("create_graph returned a NULL graph")

    # Everything below only copies data out of the C structures into Python
    # objects; the finally clause guarantees the C buffers are released even
    # if one of those Python-level operations raises.
    try:
        chg_DirectedEdge = chgnet.graph.graph.DirectedEdge
        chg_Node = chgnet.graph.graph.Node
        chg_UndirectedEdge = chgnet.graph.graph.UndirectedEdge

        image_np = np.asarray(image)

        # Nodes: record, per neighbor key, the indices of the directed edges
        # in that group. The indices are swapped for edge objects further down.
        py_nodes = []
        for i in range(returned.num_nodes):
            this_node = returned.nodes[i]
            this_py_node = chg_Node(index=i)
            this_py_node.neighbors = {}

            node_neighbors = get_neighbors(&this_node)

            for j in range(this_node.num_neighbors):
                this_entry = node_neighbors[j][0]
                directed_edges = []

                for k in range(this_entry.num_directed_edges_in_group):
                    this_DE = this_entry.directed_edges_list[k]
                    directed_edges.append(this_DE.index)

                this_py_node.neighbors[this_entry.key] = directed_edges

            py_nodes.append(this_py_node)

        # Directed edges: the image is taken from the input array by edge
        # index rather than from the C struct.
        py_directed_edges_list = []
        for i in range(returned.num_directed_edges):
            this_DE = returned.directed_edges_list[i]
            py_DE = chg_DirectedEdge(
                nodes=[this_DE.nodes.center, this_DE.nodes.neighbor],
                index=this_DE.index,
                info={
                    "distance": this_DE.distance,
                    "image": image_np[this_DE.index],
                    "undirected_edge_index": this_DE.undirected_edge_index,
                },
            )
            py_directed_edges_list.append(py_DE)

        # Undirected edges, each listing the directed edges that belong to it.
        py_undirected_edges_list = []
        for i in range(returned.num_undirected_edges):
            UDE = returned.undirected_edges_list[i]
            py_undirected_edge = chg_UndirectedEdge(
                [UDE.nodes.center, UDE.nodes.neighbor],
                index=UDE.index,
                info={"distance": UDE.distance, "directed_edge_index": []},
            )

            for j in range(UDE.num_directed_edges):
                py_undirected_edge.info["directed_edge_index"].append(
                    UDE.directed_edge_indices[j]
                )

            py_undirected_edges_list.append(py_undirected_edge)

        # Group undirected edges by the (unordered) pair of nodes they join.
        py_undirected_edges = {}
        for undirected_edge in py_undirected_edges_list:
            py_undirected_edges.setdefault(
                frozenset(undirected_edge.nodes), []
            ).append(undirected_edge)

        # Replace the directed-edge indices stored on each node with the
        # DirectedEdge objects themselves.
        for py_node in py_nodes:
            this_neighbors = py_node.neighbors
            for this_neighbor_index in this_neighbors:
                this_neighbors[this_neighbor_index] = [
                    py_directed_edges_list[edge_index]
                    for edge_index in this_neighbors[this_neighbor_index]
                ]

        return (
            py_nodes,
            py_directed_edges_list,
            py_undirected_edges_list,
            py_undirected_edges,
        )
    finally:
        # Release the C allocations that are no longer needed.
        for i in range(returned.num_directed_edges):
            free(returned.directed_edges_list[i])

        for i in range(returned.num_undirected_edges):
            free(returned.undirected_edges_list[i])

        free(returned.directed_edges_list)
        free(returned.undirected_edges_list)
        free(returned.nodes)
        # NOTE(review): `returned` itself, each edge's `directed_edge_indices`,
        # the per-node neighbor tables (and the arrays get_neighbors hands
        # back) and `returned.undirected_edges` are not released here. Their
        # ownership cannot be determined from this file — confirm against
        # create_graph.c to rule out a per-call leak.
Loading
Loading