From 11ec52ade5c05ce3cd15559b136c86e88692b223 Mon Sep 17 00:00:00 2001 From: Kevin Han Date: Mon, 26 Jun 2023 18:25:16 -0700 Subject: [PATCH 01/15] Wrote cygraph. --- chgnet/graph/converter.py | 89 +- chgnet/graph/cygraph.pyx | 169 +++ .../fast_converter_libraries/create_graph.c | 505 ++++++++ .../graph/fast_converter_libraries/uthash.h | 1140 +++++++++++++++++ pyproject.toml | 2 +- setup.py | 8 + tests/test_crystal_graph.py | 201 ++- 7 files changed, 2094 insertions(+), 20 deletions(-) create mode 100644 chgnet/graph/cygraph.pyx create mode 100644 chgnet/graph/fast_converter_libraries/create_graph.c create mode 100644 chgnet/graph/fast_converter_libraries/uthash.h create mode 100644 setup.py diff --git a/chgnet/graph/converter.py b/chgnet/graph/converter.py index e548574c..42e31c2d 100644 --- a/chgnet/graph/converter.py +++ b/chgnet/graph/converter.py @@ -5,6 +5,8 @@ import torch from torch import Tensor, nn +import numpy as np + from chgnet.graph.crystalgraph import CrystalGraph from chgnet.graph.graph import Graph, Node @@ -14,6 +16,10 @@ datatype = torch.float32 +try: + from chgnet.graph.cygraph import make_graph +except: + print("Error importing fast graph conversion (cygraph). Reverting to legacy.") class CrystalGraphConverter(nn.Module): """Convert a pymatgen.core.Structure to a CrystalGraph @@ -45,6 +51,7 @@ def forward( graph_id=None, mp_id=None, on_isolated_atoms: Literal["ignore", "warn", "error"] = "error", + graph_converter: Literal["legacy", "fast"] = "fast" ) -> CrystalGraph: """Convert a structure, return a CrystalGraph. @@ -55,8 +62,9 @@ def forward( mp_id (str): Materials Project id of this structure Default = None on_isolated_atoms ('ignore' | 'warn' | 'error'): how to handle Structures - with isolated atoms. - Default = 'error' + with isolated atoms. Default = 'error' + graph_converter ('legacy' | 'fast'): graph converter to use when converting. 
+ default = 'fast' Return: Crystal_Graph that is ready to use by CHGNet @@ -74,9 +82,16 @@ def forward( center_index, neighbor_index, image, distance = self.get_neighbors(structure) # Make Graph - graph = Graph([Node(index=i) for i in range(n_atoms)]) - for ii, jj, img, dist in zip(center_index, neighbor_index, image, distance): - graph.add_edge(center_index=ii, neighbor_index=jj, image=img, distance=dist) + if graph_converter == "fast": + try: + graph = self._create_graph_fast(n_atoms, center_index, neighbor_index, image, distance) + except: + print("Failed to retrieve fast graph converter. Reverting to legacy.") + graph = self._create_graph_legacy(n_atoms, center_index, neighbor_index, image, distance) + elif graph_converter == "legacy": + graph = self._create_graph_legacy(n_atoms, center_index, neighbor_index, image, distance) + else: + raise ValueError(f"No graph_converter named {graph_converter}") # Atom Graph atom_graph, directed2undirected = graph.adjacency_list() @@ -127,6 +142,70 @@ def forward( bond_graph_cutoff=self.bond_graph_cutoff, ) + def _create_graph_legacy( + self, + n_atoms: int, + center_index: np.ndarray, + neighbor_index: np.ndarray, + image: np.ndarray, + distance: np.ndarray + ) -> Graph: + """Given structure information, create a Graph structure to be used to + create Crystal_Graph + + Args: + n_atoms (int): the number of atoms in the structure + center_index (np.ndarray): np array of indices of center atoms. Shape: (# of edges, ) + neighbor_index (np.ndarray): np array of indices of neighbor atoms. Shape: (# of edges, ) + image (np.ndarray): np array of images for each edge. Shape: (# of edges, 3) + distance (np.ndarray): np array of distances. 
Shape: (# of edges, ) + + Return: + Graph data structure used to create Crystal_Graph object + """ + graph = Graph([Node(index=i) for i in range(n_atoms)]) + for ii, jj, img, dist in zip(center_index, neighbor_index, image, distance): + graph.add_edge(center_index=ii, neighbor_index=jj, image=img, distance=dist) + + return graph + + + def _create_graph_fast( + self, + n_atoms: int, + center_index: np.ndarray, + neighbor_index: np.ndarray, + image: np.ndarray, + distance: np.ndarray + ) -> Graph: + """Given structure information, create a Graph structure to be used to + create Crystal_Graph. NOTE: this is the fast version of _create_graph_legacy optimized + in c (~3x speedup). + + Args: + n_atoms (int): the number of atoms in the structure + center_index (np.ndarray): np array of indices of center atoms. Shape: (# of edges, ) + neighbor_index (np.ndarray): np array of indices of neighbor atoms. Shape: (# of edges, ) + image (np.ndarray): np array of images for each edge. Shape: (# of edges, 3) + distance (np.ndarray): np array of distances. 
Shape: (# of edges, ) + + Return: + Graph data structure used to create Crystal_Graph object + """ + + center_index = np.ascontiguousarray(center_index) + neighbor_index = np.ascontiguousarray(neighbor_index) + image = np.ascontiguousarray(image, dtype=np.int_) + distance = np.ascontiguousarray(distance) + + nodes, directed_edges_list, undirected_edges_list, undirected_edges = make_graph(center_index, len(center_index), neighbor_index, image, distance, n_atoms) + graph = Graph(nodes=nodes) + graph.directed_edges_list = directed_edges_list + graph.undirected_edges_list = undirected_edges_list + graph.undirected_edges = undirected_edges + + return graph + def get_neighbors( self, structure: Structure ) -> tuple[Tensor, Tensor, Tensor, Tensor]: diff --git a/chgnet/graph/cygraph.pyx b/chgnet/graph/cygraph.pyx new file mode 100644 index 00000000..b6d257f5 --- /dev/null +++ b/chgnet/graph/cygraph.pyx @@ -0,0 +1,169 @@ +# cython: language_level=3 +# cython: initializedcheck=False +# cython: nonecheck=False +# cython: boundscheck=False +# cython: wraparound=False +# cython: cdivision=False +# cython: profile=False +# distutils: language = c + +import chgnet.graph.graph +from cython.operator import dereference +import numpy as np +from libc cimport time +from libc.stdio cimport printf +from libc.stdlib cimport free + +cdef extern from 'fast_converter_libraries/create_graph.c': + ctypedef struct Node: + long index + LongToDirectedEdgeList* neighbors + long num_neighbors + + ctypedef struct NodeIndexPair: + long center + long neighbor + + ctypedef struct UndirectedEdge: + NodeIndexPair nodes + long index + long* directed_edge_indices + long num_directed_edges + double distance + + ctypedef struct DirectedEdge: + NodeIndexPair nodes + long index + const long* image + long undirected_edge_index + double distance + + ctypedef struct LongToDirectedEdgeList: + long key + DirectedEdge** directed_edges_list + int num_directed_edges_in_group + + ctypedef struct 
StructToUndirectedEdgeList: + NodeIndexPair key + UndirectedEdge** undirected_edges_list + int num_undirected_edges_in_group + + + ctypedef struct ReturnElems2: + long num_nodes + long num_directed_edges + long num_undirected_edges + + Node* nodes + UndirectedEdge** undirected_edges_list + DirectedEdge** directed_edges_list + StructToUndirectedEdgeList* undirected_edges + + ReturnElems2* create_graph( + long* center_index, + long n_e, + long* neighbor_index, + long* image, + double* distance, + long num_atoms) + + + LongToDirectedEdgeList** get_neighbors(Node* node) + +def make_graph( + const long[::1] center_index, + const long n_e, + const long[::1] neighbor_index, + const long[:, ::1] image, + const double[::1] distance, + const long num_atoms + ): + cdef ReturnElems2* returned + returned = create_graph( &center_index[0], n_e, &neighbor_index[0], &image[0][0], &distance[0], num_atoms) + + chg_DirectedEdge = chgnet.graph.graph.DirectedEdge + chg_Node = chgnet.graph.graph.Node + chg_UndirectedEdge = chgnet.graph.graph.UndirectedEdge + + image_np = np.asarray(image) + + cdef LongToDirectedEdgeList** node_neighbors + cdef Node this_node + cdef LongToDirectedEdgeList this_entry + py_nodes = [] + cdef DirectedEdge* this_DE + + # Handling nodes + directed edges + for i in range(returned[0].num_nodes): + this_node = dereference(returned).nodes[i] + this_py_node = chg_Node(index=i) + this_py_node.neighbors = {} + + node_neighbors = get_neighbors(&this_node) + + # Iterate through all neighbors and populate our py_node.neighbors dict + for j in range(this_node.num_neighbors): + this_entry = dereference(node_neighbors[j]) + directed_edges = [] + + for k in range(this_entry.num_directed_edges_in_group): + this_DE = this_entry.directed_edges_list[k] + directed_edges.append(this_DE[0].index) + + this_py_node.neighbors[this_entry.key] = directed_edges + + py_nodes.append(this_py_node) + + # Handling directed edges + py_directed_edges_list = [] + + for i in 
range(returned[0].num_directed_edges): + this_DE = returned[0].directed_edges_list[i] + py_DE = chg_DirectedEdge(nodes = [this_DE[0].nodes.center, this_DE[0].nodes.neighbor], index=this_DE[0].index, info = {"distance": this_DE[0].distance, "image": image_np[this_DE[0].index], "undirected_edge_index": this_DE[0].undirected_edge_index}) + + py_directed_edges_list.append(py_DE) + + + # Handling undirected edges + py_undirected_edges_list = [] + cdef UndirectedEdge* UDE + + for i in range(returned[0].num_undirected_edges): + UDE = returned[0].undirected_edges_list[i] + py_undirected_edge = chg_UndirectedEdge([UDE[0].nodes.center, UDE[0].nodes.neighbor], index = UDE[0].index, info = {"distance": UDE[0].distance, "directed_edge_index": []}) + + for j in range(UDE[0].num_directed_edges): + py_undirected_edge.info["directed_edge_index"].append(UDE[0].directed_edge_indices[j]) + + py_undirected_edges_list.append(py_undirected_edge) + + + # Create Undirected_Edges hashmap + py_undirected_edges = {} + for undirected_edge in py_undirected_edges_list: + this_set = frozenset(undirected_edge.nodes) + if this_set not in py_undirected_edges: + py_undirected_edges[this_set] = [undirected_edge] + else: + py_undirected_edges[this_set].append(undirected_edge) + + # # Update the nodes list to have pointers to DirectedEdges instead of indices + for node_index in range(returned[0].num_nodes): + this_neighbors = py_nodes[node_index].neighbors + for this_neighbor_index in this_neighbors: + replacement = [py_directed_edges_list[edge_index] for edge_index in this_neighbors[this_neighbor_index]] + this_neighbors[this_neighbor_index] = replacement + + + # Free everything unneeded + for i in range(returned[0].num_directed_edges): + free(returned[0].directed_edges_list[i]) + + for i in range(returned[0].num_undirected_edges): + free(returned[0].undirected_edges_list[i]) + + free(returned[0].directed_edges_list) + free(returned[0].undirected_edges_list) + free(returned[0].nodes) + + return 
py_nodes, py_directed_edges_list, py_undirected_edges_list, py_undirected_edges diff --git a/chgnet/graph/fast_converter_libraries/create_graph.c b/chgnet/graph/fast_converter_libraries/create_graph.c new file mode 100644 index 00000000..df7b1981 --- /dev/null +++ b/chgnet/graph/fast_converter_libraries/create_graph.c @@ -0,0 +1,505 @@ +#include "uthash.h" +#include <stdbool.h> + +typedef struct _UndirectedEdge UndirectedEdge; +typedef struct _DirectedEdge DirectedEdge; +typedef struct _Node Node; +typedef struct _NodeIndexPair NodeIndexPair; +typedef struct _LongToDirectedEdgeList LongToDirectedEdgeList; +typedef struct _ReturnElems ReturnElems; +typedef struct _ReturnElems2 ReturnElems2; + +// NOTE: This code was mainly written to replicate the original add_edges method +// in the graph class in chgnet.graph.graph such that anyone familiar with that code should be able to pick up this +// code pretty easily. + +long MEM_ERR = 100; + +typedef struct _Node { + long index; + LongToDirectedEdgeList* neighbors; // Assuming neighbors can only be directed edge. 
Key is dest_node, value is DirectedEdge struct + long num_neighbors; +} Node; + +typedef struct _NodeIndexPair { + long center; + long neighbor; +} NodeIndexPair; + +typedef struct _UndirectedEdge { + NodeIndexPair nodes; + long index; + long* directed_edge_indices; + long num_directed_edges; + double distance; +} UndirectedEdge; + +typedef struct _DirectedEdge { + NodeIndexPair nodes; + long index; + const long* image; // Only access the first 3, never edit + long undirected_edge_index; + double distance; +} DirectedEdge; + +typedef struct _StructToUndirectedEdgeList { + NodeIndexPair key; + UndirectedEdge** undirected_edges_list; + int num_undirected_edges_in_group; + UT_hash_handle hh; +} StructToUndirectedEdgeList; + +typedef struct _LongToDirectedEdgeList { + long key; + DirectedEdge** directed_edges_list; + int num_directed_edges_in_group; + UT_hash_handle hh; +} LongToDirectedEdgeList; + +typedef struct _ReturnElems { + long num_nodes; + long* node_index_unraveled; + long* node_neighbor_index_unraveled; + long* node_directed_edge_index_unraveled; + + long num_undirected_edges; + long* undirected_center_index_unraveled; + long* undirected_neighbor_index_unraveled; + long* undirected_index_unraveled; + long* undirected_directed_edge_indices_unraveled; + double* undirected_distances_unraveled; + + long num_directed_edges; + long* directed_undirected_edge_index_unraveled; +} ReturnElems; + + +typedef struct _ReturnElems2 { + long num_nodes; + long num_directed_edges; + long num_undirected_edges; + Node* nodes; + UndirectedEdge** undirected_edges_list; + DirectedEdge** directed_edges_list; + StructToUndirectedEdgeList* undirected_edges; +} ReturnElems2; + +bool find_in_undirected(NodeIndexPair* tmp, StructToUndirectedEdgeList** undirected_edges, StructToUndirectedEdgeList** found_entry); +void directed_to_undirected(DirectedEdge* directed, UndirectedEdge* undirected, long undirected_index); +void create_new_undirected_edges_entry(StructToUndirectedEdgeList** 
undirected_edges, NodeIndexPair* tmp, UndirectedEdge* new_undirected_edge); +void append_to_undirected_edges_tmp(UndirectedEdge* undirected, StructToUndirectedEdgeList** undirected_edges, NodeIndexPair* tmp); +void append_to_undirected_edges_list(UndirectedEdge** undirected_edges_list, UndirectedEdge* to_add, long* num_undirected_edges); +void append_to_directed_edges_list(DirectedEdge** directed_edges_list, DirectedEdge* to_add, long* num_directed_edges); +void add_neighbors_to_node(Node* node, long neighbor_index, DirectedEdge* directed_edge); +void print_neighbors(Node* node); +void append_to_directed_edge_indices(UndirectedEdge* undirected_edge, long directed_edge_index); +bool is_reversed_directed_edge(DirectedEdge* directed_edge1, DirectedEdge* directed_edge2); +ReturnElems* get_raw_data(Node* nodes, long num_nodes, long num_undirected_edges, long num_directed_edges, UndirectedEdge** undirected_edges_list, DirectedEdge** directed_edges_list); + + +Node* create_nodes(long num_nodes) { + Node* Nodes = (Node*) malloc(sizeof(Node) * num_nodes); + + if (Nodes == NULL) { + return NULL; + } + + for (long i = 0; i < num_nodes; i++) { + Nodes[i].index = i; + Nodes[i].num_neighbors = 0; + + // Initialize the uthash + Nodes[i].neighbors = NULL; + } + + return Nodes; +} + +ReturnElems2* create_graph( + long* center_indices, + long num_edges, + long* neighbor_indices, + long* images, // contiguous memory (row-major) of image elements (total of n_e * 3 integers) + double* distances, + long num_atoms + ) { + // Initialize pertinent data structures --------------------- + Node* nodes = create_nodes(num_atoms); + + DirectedEdge** directed_edges_list = calloc(num_edges, sizeof(DirectedEdge)); + long num_directed_edges = 0; + + // There will never be more undirected edges than directed edges + UndirectedEdge** undirected_edges_list = calloc(num_edges, sizeof(UndirectedEdge)); + long num_undirected_edges = 0; + StructToUndirectedEdgeList* undirected_edges = NULL; + + // Pointer 
to beginning of list of UndirectedEdges corresponding to tmp of current iteration + StructToUndirectedEdgeList* corr_undirected_edges_item = NULL; + + // Pointer to NodeIndexPair storing tmp + NodeIndexPair* tmp = malloc(sizeof(NodeIndexPair)); + + // Flag for whether or not the value was found + bool found = false; + + // Flag used to show if we've already processed the current undirected edge + bool processed_edge = false; + + // Pointer used to store the previously added directed edge between two nodes + DirectedEdge* added_DE; + + // Add all edges to graph information + for (long i = 0; i < num_edges; i++) { + // Haven't processed the edge yet + processed_edge = false; + // Create the current directed edge ------------------- + DirectedEdge* this_directed_edge = calloc(1, sizeof(DirectedEdge)); + this_directed_edge->nodes.center = center_indices[i]; + this_directed_edge->nodes.neighbor = neighbor_indices[i]; + this_directed_edge->distance = distances[i]; + this_directed_edge->index = num_directed_edges; + this_directed_edge->image = images + (3 * i); + + // Load tmp + memset(tmp, 0, sizeof(NodeIndexPair)); + tmp->center = center_indices[i]; + tmp->neighbor = neighbor_indices[i]; + + // See if tmp is in undirected + corr_undirected_edges_item = NULL; + found = find_in_undirected(tmp, &undirected_edges, &corr_undirected_edges_item); + + if (!found) { + // Never seen this edge combination before + // printf("C: new edge combo: %lu and %lu. Dist: %.15lf. Img: [%ld, %ld, %ld]\n", tmp->center, tmp->neighbor, distances[i], *(this_directed_edge->image), *(this_directed_edge->image + 1), *(this_directed_edge->image + 2)); + + this_directed_edge->undirected_edge_index = num_undirected_edges; + + //TODO: be careful about double-freeing later. 
we're re-using a lot of memory space + + // Create new undirected edge + UndirectedEdge* this_undirected_edge = malloc(sizeof(UndirectedEdge)); + + directed_to_undirected(this_directed_edge, this_undirected_edge, num_undirected_edges); + + // Add this new edge information to various data structures + create_new_undirected_edges_entry(&undirected_edges, tmp, this_undirected_edge); + append_to_undirected_edges_list(undirected_edges_list, this_undirected_edge, &num_undirected_edges); + add_neighbors_to_node(&nodes[center_indices[i]], neighbor_indices[i], this_directed_edge); + append_to_directed_edges_list(directed_edges_list, this_directed_edge, &num_directed_edges); + } else { + // This pair of nodes has been added before. We have to check if it's the other directed edge (but pointed in + // the different direction) OR it's another totally different undirected edge that has different image and distance + + // if found is true, then corr_undirected_edges_item points to self.undirected_edges[tmp] + // iterate through all previously scanned undirected edges that have the same endpoints as this edge + // if there exists an undirected edge with the same inverted image as this_undirected_edge, then add this new directed edge + // and associate it with this undirected edge + for (int j = 0; j < corr_undirected_edges_item->num_undirected_edges_in_group; j++) { + // Grab the 0th directed edge associated with this undirected edge + added_DE = directed_edges_list[((corr_undirected_edges_item->undirected_edges_list)[j]->directed_edge_indices)[0]]; + + if (is_reversed_directed_edge(added_DE, this_directed_edge)) { + this_directed_edge->undirected_edge_index = added_DE->undirected_edge_index; + add_neighbors_to_node(&nodes[center_indices[i]], neighbor_indices[i], this_directed_edge); + append_to_directed_edges_list(directed_edges_list, this_directed_edge, &num_directed_edges); + append_to_directed_edge_indices((corr_undirected_edges_item->undirected_edges_list)[j], 
this_directed_edge->index); + processed_edge = true; + break; + } + } + // There wasn't a pre-existing undirected edge that corresponds to this directed edge + // Create a new undirected edge and process + if (!processed_edge) { + this_directed_edge->undirected_edge_index = num_undirected_edges; + // Create a new undirected edge + UndirectedEdge* this_undirected_edge = malloc(sizeof(UndirectedEdge)); + directed_to_undirected(this_directed_edge, this_undirected_edge, num_undirected_edges); + append_to_undirected_edges_tmp(this_undirected_edge, &undirected_edges, tmp); + append_to_undirected_edges_list(undirected_edges_list, this_undirected_edge, &num_undirected_edges); + add_neighbors_to_node(&nodes[center_indices[i]], neighbor_indices[i], this_directed_edge); + append_to_directed_edges_list(directed_edges_list, this_directed_edge, &num_directed_edges); + } + } + } + + + // ReturnElems* returned; + // returned = get_raw_data(nodes, num_atoms, num_undirected_edges, num_directed_edges, undirected_edges_list, directed_edges_list); + + // printf("From returned struct: %lu\n", returned->num_directed_edges); + // return returned; + + ReturnElems2* returned2 = malloc(sizeof(ReturnElems2)); + returned2->num_nodes = num_atoms; + returned2->num_undirected_edges = num_undirected_edges; + returned2->num_directed_edges = num_directed_edges; + + returned2->nodes = nodes; + returned2->directed_edges_list = directed_edges_list; + returned2->undirected_edges_list = undirected_edges_list; + returned2->undirected_edges = undirected_edges; + return returned2; +} + + +// Converts all data into forms that can be digested in cython and used to create a graph python object +ReturnElems* get_raw_data( + Node* nodes, + long num_nodes, + long num_undirected_edges, + long num_directed_edges, + UndirectedEdge** undirected_edges_list, + DirectedEdge** directed_edges_list + ) { + // NODES --------------------------- + // allocate memory to store node information + long* node_index_unraveled = 
malloc(sizeof(long) * num_directed_edges); + long* node_neighbor_index_unraveled = malloc(sizeof(long) * num_directed_edges); + long* node_directed_edge_index_unraveled = malloc(sizeof(long) * num_directed_edges); + long unravel_index = 0; + + LongToDirectedEdgeList *tmp, *neighbor; + + for (long node_i = 0; node_i < num_nodes; node_i++) { + HASH_ITER(hh, nodes[node_i].neighbors, neighbor, tmp) { + for (long edge_i = 0; edge_i < neighbor->num_directed_edges_in_group; edge_i++) { + node_index_unraveled[unravel_index] = nodes[node_i].index; + node_neighbor_index_unraveled[unravel_index] = neighbor->key; + node_directed_edge_index_unraveled[unravel_index] = neighbor->directed_edges_list[edge_i]->index; + unravel_index += 1; + } + } + } + + // Undirected edges -------------- + long* undirected_center_index_unraveled = malloc(sizeof(long) * num_directed_edges); + long* undirected_neighbor_index_unraveled = malloc(sizeof(long) * num_directed_edges); + long* undirected_index_unraveled = malloc(sizeof(long) * num_directed_edges); + long* undirected_directed_edge_indices_unraveled = malloc(sizeof(long) * num_directed_edges); + double* undirected_distances_unraveled = malloc(sizeof(double) * num_directed_edges); + unravel_index = 0; + + UndirectedEdge* curr_undirected; + + for (long undirected_i = 0; undirected_i < num_undirected_edges; undirected_i++) { + curr_undirected = undirected_edges_list[undirected_i]; + for (long directed_i = 0; directed_i < curr_undirected->num_directed_edges; directed_i++) { + undirected_center_index_unraveled[unravel_index] = curr_undirected->nodes.center; + undirected_neighbor_index_unraveled[unravel_index] = curr_undirected->nodes.neighbor; + undirected_index_unraveled[unravel_index] = curr_undirected->index; + undirected_directed_edge_indices_unraveled[unravel_index] = curr_undirected->directed_edge_indices[directed_i]; + undirected_distances_unraveled[unravel_index] = curr_undirected->distance; + + unravel_index += 1; + } + } + + // Directed 
edges --------------- + // center unraveled, neighbor unraveled, image unraveled, distance unraveled for directed edges are all + // just the inputs to the create graph function + long* directed_undirected_edge_index_unraveled = malloc(sizeof(long) * num_directed_edges); + for (long directed_i = 0; directed_i < num_directed_edges; directed_i++) { + directed_undirected_edge_index_unraveled[directed_i] = directed_edges_list[directed_i]->undirected_edge_index; + } + + ReturnElems* returned = malloc(sizeof(ReturnElems)); + returned->node_index_unraveled = node_index_unraveled; + returned->node_neighbor_index_unraveled = node_neighbor_index_unraveled; + returned->node_directed_edge_index_unraveled = node_directed_edge_index_unraveled; + + returned->undirected_center_index_unraveled = undirected_center_index_unraveled; + returned->undirected_neighbor_index_unraveled = undirected_neighbor_index_unraveled; + returned->undirected_index_unraveled = undirected_index_unraveled; + returned->undirected_directed_edge_indices_unraveled = undirected_directed_edge_indices_unraveled; + returned->undirected_distances_unraveled = undirected_distances_unraveled; + + returned->directed_undirected_edge_index_unraveled = directed_undirected_edge_index_unraveled; + + returned->num_nodes = num_nodes; + returned->num_directed_edges = num_directed_edges; + returned->num_undirected_edges = num_undirected_edges; + + return returned; +} + +// Returns a list of LongToDirectedEdgeList pointers which are entries for the neighbors of the inputted node +LongToDirectedEdgeList** get_neighbors(Node* node) { + long num_neighbors = HASH_COUNT(node->neighbors); + LongToDirectedEdgeList** entries = malloc(sizeof(LongToDirectedEdgeList*) * num_neighbors); + + LongToDirectedEdgeList* entry; + long counter = 0; + for (entry = node->neighbors; entry != NULL; entry = entry->hh.next) { + entries[counter] = entry; + counter += 1; + } + + return entries; +} + +void print_neighbors(Node* node) { + 
LongToDirectedEdgeList *tmp, *neighbor; + HASH_ITER(hh, node->neighbors, neighbor, tmp) { + printf("C:neighboring atom: %lu\n", neighbor->key); + } +} + +// Returns true if the two directed edges have images that are inverted +// NOTE: assumes that directed_edge1->center = directed_edge2->neighbor and directed_edge1->neighbor = directed_edge2->center +bool is_reversed_directed_edge(DirectedEdge* directed_edge1, DirectedEdge* directed_edge2) { + for (int i = 0; i < 3; i++) { + if (directed_edge1->image[i] != -1 * directed_edge2->image[i]) { + return false; + } + } + + // The two directed edges should have opposing center/neighbor nodes (i.e. center-neighbor for DE1 is [0, 1] and for DE2 is [1, 0]) + // We check for that condition here + if (directed_edge1->nodes.center != directed_edge2->nodes.neighbor) { + return false; + } + if (directed_edge1->nodes.neighbor != directed_edge2->nodes.center) { + return false; + } + return true; +} + +// If tmp or the reverse of tmp is found in undirected_edges, True is returned and the corresponding StructToUndirectedEdgeList pointer is placed +// into found_entry. Otherwise, False is returned. 
+// NOTE: does not edit the *tmp +// Assumes *tmp bits have already been 0'd at padding within a struct +bool find_in_undirected(NodeIndexPair* tmp, StructToUndirectedEdgeList** undirected_edges, StructToUndirectedEdgeList** found_entry) { + StructToUndirectedEdgeList* out_list; + // Check tmp + HASH_FIND(hh, *undirected_edges, tmp, sizeof(NodeIndexPair), out_list); + + if (out_list) { + *found_entry = out_list; + return true; + } + + // Check tmp_rev + NodeIndexPair tmp_rev; + tmp_rev.center = tmp->neighbor; + tmp_rev.neighbor = tmp->center; + + HASH_FIND(hh, *undirected_edges, &tmp_rev, sizeof(NodeIndexPair), out_list); + + if (out_list) { + *found_entry = out_list; + return true; + } + + return false; +} + + +// Creates new entry in undirected_edges and initializes necessary arrays +void create_new_undirected_edges_entry(StructToUndirectedEdgeList** undirected_edges, NodeIndexPair* tmp, UndirectedEdge* new_undirected_edge) { + StructToUndirectedEdgeList* new_entry = malloc(sizeof(StructToUndirectedEdgeList)); + memset(new_entry, 0, sizeof(StructToUndirectedEdgeList)); + + // Set up fields within the new entry in the hashmap + new_entry->key.center = tmp->center; + new_entry->key.neighbor = tmp->neighbor; + + new_entry->num_undirected_edges_in_group = 1; + new_entry->undirected_edges_list = malloc(sizeof(UndirectedEdge*)); + new_entry->undirected_edges_list[0] = new_undirected_edge; + + HASH_ADD(hh, *undirected_edges, key, sizeof(NodeIndexPair), new_entry); + +} + +// Appends undirected into the StructToUndirectedEdgeList entry that corresponds to tmp +// This function will first look up tmp +void append_to_undirected_edges_tmp(UndirectedEdge* undirected, StructToUndirectedEdgeList** undirected_edges, NodeIndexPair* tmp) { + + StructToUndirectedEdgeList* this_undirected_edges_item; + find_in_undirected(tmp, undirected_edges, &this_undirected_edges_item); + + long num_undirected_edges = this_undirected_edges_item->num_undirected_edges_in_group; + + // No need to 
worry about originally malloc'ing memory for this_undirected_edges_item->undirected_edges_list + // This is because we first call create_new_undirected_edges_entry for all entries. This function already mallocs for us. + + // Realloc the space to fit a new pointer to an undirected edge + UndirectedEdge** new_list = realloc(this_undirected_edges_item->undirected_edges_list, sizeof(UndirectedEdge*) * (num_undirected_edges + 1)); + this_undirected_edges_item->undirected_edges_list = new_list; + + // Insert the undirected pointer into the newly allocated slot + this_undirected_edges_item->undirected_edges_list[num_undirected_edges] = undirected; + + // Increase the counter for # of undirected edges + this_undirected_edges_item->num_undirected_edges_in_group = num_undirected_edges + 1; +} + + +void directed_to_undirected(DirectedEdge* directed, UndirectedEdge* undirected, long undirected_index) { + // Copy over distance and node indices, and set the undirected index + undirected->distance = directed->distance; + undirected->nodes = directed->nodes; + undirected->index = undirected_index; + + // Add a new directed_edge_index to the directed_edge_indices pointer. 
This should be the first + undirected->num_directed_edges = 1; + undirected->directed_edge_indices = malloc(sizeof(long)); + undirected->directed_edge_indices[0] = directed->index; +} + + +void append_to_undirected_edges_list(UndirectedEdge** undirected_edges_list, UndirectedEdge* to_add, long* num_undirected_edges) { + // No need to realloc for space since our original alloc should cover everything + + // Assign value to next available position + undirected_edges_list[*num_undirected_edges] = to_add; + *num_undirected_edges += 1; +} + +void append_to_directed_edges_list(DirectedEdge** directed_edges_list, DirectedEdge* to_add, long* num_directed_edges) { + // No need to realloc for space since our original alloc should cover everything + + // Assign value to next available position + directed_edges_list[*num_directed_edges] = to_add; + *num_directed_edges += 1; +} + +void append_to_directed_edge_indices(UndirectedEdge* undirected_edge, long directed_edge_index) { + // TODO: don't need to realloc if we always know that there will be 2 directed edges per undirected edge. Update this later for performance boosts. + // TODO: other random performance boost: don't pass longs into function parameters, pass long* instead + undirected_edge->directed_edge_indices = realloc(undirected_edge->directed_edge_indices, sizeof(long) * (undirected_edge->num_directed_edges + 1)); + undirected_edge->directed_edge_indices[undirected_edge->num_directed_edges] = directed_edge_index; + undirected_edge->num_directed_edges += 1; +} + +// If there already exists neighbor_index within the Node node, then adds directed_edge to the list of directed edges. 
+// If there doesn't already exist neighbor_index within the Node node, then create a new entry into the node's neighbors hashmap and add the entry +void add_neighbors_to_node(Node* node, long neighbor_index, DirectedEdge* directed_edge) { + LongToDirectedEdgeList* entry = NULL; + + // Search for the neighbor_index in our hashmap + HASH_FIND(hh, node->neighbors, &neighbor_index, sizeof(long), entry); + + if (entry) { + // We found something, update the list within this pointer + entry->directed_edges_list = realloc(entry->directed_edges_list, sizeof(DirectedEdge*) * (entry->num_directed_edges_in_group + 1)); + entry->directed_edges_list[entry->num_directed_edges_in_group] = directed_edge; + + entry->num_directed_edges_in_group += 1; + } else { + // allocate memory for entry + entry = malloc(sizeof(LongToDirectedEdgeList)); + + // The entry doesn't exist, initialize the entry and enter it into our hashmap + entry->directed_edges_list = malloc(sizeof(DirectedEdge*)); + entry->directed_edges_list[0] = directed_edge; + entry->key = neighbor_index; + + entry->num_directed_edges_in_group = 1; + HASH_ADD(hh, node->neighbors, key, sizeof(long), entry); + + node->num_neighbors += 1; + } +} diff --git a/chgnet/graph/fast_converter_libraries/uthash.h b/chgnet/graph/fast_converter_libraries/uthash.h new file mode 100644 index 00000000..070f7b44 --- /dev/null +++ b/chgnet/graph/fast_converter_libraries/uthash.h @@ -0,0 +1,1140 @@ +/* +Copyright (c) 2003-2022, Troy D. Hanson https://troydhanson.github.io/uthash/ +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. 
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER +OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ + +#ifndef UTHASH_H +#define UTHASH_H + +#define UTHASH_VERSION 2.3.0 + +#include <string.h> /* memcmp, memset, strlen */ +#include <stddef.h> /* ptrdiff_t */ +#include <stdlib.h> /* exit */ + +#if defined(HASH_DEFINE_OWN_STDINT) && HASH_DEFINE_OWN_STDINT +/* This codepath is provided for backward compatibility, but I plan to remove it. */ +#warning "HASH_DEFINE_OWN_STDINT is deprecated; please use HASH_NO_STDINT instead" +typedef unsigned int uint32_t; +typedef unsigned char uint8_t; +#elif defined(HASH_NO_STDINT) && HASH_NO_STDINT +#else +#include <stdint.h> /* uint8_t, uint32_t */ +#endif + +/* These macros use decltype or the earlier __typeof GNU extension. + As decltype is only available in newer compilers (VS2010 or gcc 4.3+ + when compiling c++ source) this code uses whatever method is needed + or, for VS2008 where neither is available, uses casting workarounds.
*/ +#if !defined(DECLTYPE) && !defined(NO_DECLTYPE) +#if defined(_MSC_VER) /* MS compiler */ +#if _MSC_VER >= 1600 && defined(__cplusplus) /* VS2010 or newer in C++ mode */ +#define DECLTYPE(x) (decltype(x)) +#else /* VS2008 or older (or VS2010 in C mode) */ +#define NO_DECLTYPE +#endif +#elif defined(__MCST__) /* Elbrus C Compiler */ +#define DECLTYPE(x) (__typeof(x)) +#elif defined(__BORLANDC__) || defined(__ICCARM__) || defined(__LCC__) || defined(__WATCOMC__) +#define NO_DECLTYPE +#else /* GNU, Sun and other compilers */ +#define DECLTYPE(x) (__typeof(x)) +#endif +#endif + +#ifdef NO_DECLTYPE +#define DECLTYPE(x) +#define DECLTYPE_ASSIGN(dst,src) \ +do { \ + char **_da_dst = (char**)(&(dst)); \ + *_da_dst = (char*)(src); \ +} while (0) +#else +#define DECLTYPE_ASSIGN(dst,src) \ +do { \ + (dst) = DECLTYPE(dst)(src); \ +} while (0) +#endif + +#ifndef uthash_malloc +#define uthash_malloc(sz) malloc(sz) /* malloc fcn */ +#endif +#ifndef uthash_free +#define uthash_free(ptr,sz) free(ptr) /* free fcn */ +#endif +#ifndef uthash_bzero +#define uthash_bzero(a,n) memset(a,'\0',n) +#endif +#ifndef uthash_strlen +#define uthash_strlen(s) strlen(s) +#endif + +#ifndef HASH_FUNCTION +#define HASH_FUNCTION(keyptr,keylen,hashv) HASH_JEN(keyptr, keylen, hashv) +#endif + +#ifndef HASH_KEYCMP +#define HASH_KEYCMP(a,b,n) memcmp(a,b,n) +#endif + +#ifndef uthash_noexpand_fyi +#define uthash_noexpand_fyi(tbl) /* can be defined to log noexpand */ +#endif +#ifndef uthash_expand_fyi +#define uthash_expand_fyi(tbl) /* can be defined to log expands */ +#endif + +#ifndef HASH_NONFATAL_OOM +#define HASH_NONFATAL_OOM 0 +#endif + +#if HASH_NONFATAL_OOM +/* malloc failures can be recovered from */ + +#ifndef uthash_nonfatal_oom +#define uthash_nonfatal_oom(obj) do {} while (0) /* non-fatal OOM error */ +#endif + +#define HASH_RECORD_OOM(oomed) do { (oomed) = 1; } while (0) +#define IF_HASH_NONFATAL_OOM(x) x + +#else +/* malloc failures result in lost memory, hash tables are unusable */ + 
+#ifndef uthash_fatal +#define uthash_fatal(msg) exit(-1) /* fatal OOM error */ +#endif + +#define HASH_RECORD_OOM(oomed) uthash_fatal("out of memory") +#define IF_HASH_NONFATAL_OOM(x) + +#endif + +/* initial number of buckets */ +#define HASH_INITIAL_NUM_BUCKETS 32U /* initial number of buckets */ +#define HASH_INITIAL_NUM_BUCKETS_LOG2 5U /* lg2 of initial number of buckets */ +#define HASH_BKT_CAPACITY_THRESH 10U /* expand when bucket count reaches */ + +/* calculate the element whose hash handle address is hhp */ +#define ELMT_FROM_HH(tbl,hhp) ((void*)(((char*)(hhp)) - ((tbl)->hho))) +/* calculate the hash handle from element address elp */ +#define HH_FROM_ELMT(tbl,elp) ((UT_hash_handle*)(void*)(((char*)(elp)) + ((tbl)->hho))) + +#define HASH_ROLLBACK_BKT(hh, head, itemptrhh) \ +do { \ + struct UT_hash_handle *_hd_hh_item = (itemptrhh); \ + unsigned _hd_bkt; \ + HASH_TO_BKT(_hd_hh_item->hashv, (head)->hh.tbl->num_buckets, _hd_bkt); \ + (head)->hh.tbl->buckets[_hd_bkt].count++; \ + _hd_hh_item->hh_next = NULL; \ + _hd_hh_item->hh_prev = NULL; \ +} while (0) + +#define HASH_VALUE(keyptr,keylen,hashv) \ +do { \ + HASH_FUNCTION(keyptr, keylen, hashv); \ +} while (0) + +#define HASH_FIND_BYHASHVALUE(hh,head,keyptr,keylen,hashval,out) \ +do { \ + (out) = NULL; \ + if (head) { \ + unsigned _hf_bkt; \ + HASH_TO_BKT(hashval, (head)->hh.tbl->num_buckets, _hf_bkt); \ + if (HASH_BLOOM_TEST((head)->hh.tbl, hashval) != 0) { \ + HASH_FIND_IN_BKT((head)->hh.tbl, hh, (head)->hh.tbl->buckets[ _hf_bkt ], keyptr, keylen, hashval, out); \ + } \ + } \ +} while (0) + +#define HASH_FIND(hh,head,keyptr,keylen,out) \ +do { \ + (out) = NULL; \ + if (head) { \ + unsigned _hf_hashv; \ + HASH_VALUE(keyptr, keylen, _hf_hashv); \ + HASH_FIND_BYHASHVALUE(hh, head, keyptr, keylen, _hf_hashv, out); \ + } \ +} while (0) + +#ifdef HASH_BLOOM +#define HASH_BLOOM_BITLEN (1UL << HASH_BLOOM) +#define HASH_BLOOM_BYTELEN (HASH_BLOOM_BITLEN/8UL) + (((HASH_BLOOM_BITLEN%8UL)!=0UL) ? 
1UL : 0UL) +#define HASH_BLOOM_MAKE(tbl,oomed) \ +do { \ + (tbl)->bloom_nbits = HASH_BLOOM; \ + (tbl)->bloom_bv = (uint8_t*)uthash_malloc(HASH_BLOOM_BYTELEN); \ + if (!(tbl)->bloom_bv) { \ + HASH_RECORD_OOM(oomed); \ + } else { \ + uthash_bzero((tbl)->bloom_bv, HASH_BLOOM_BYTELEN); \ + (tbl)->bloom_sig = HASH_BLOOM_SIGNATURE; \ + } \ +} while (0) + +#define HASH_BLOOM_FREE(tbl) \ +do { \ + uthash_free((tbl)->bloom_bv, HASH_BLOOM_BYTELEN); \ +} while (0) + +#define HASH_BLOOM_BITSET(bv,idx) (bv[(idx)/8U] |= (1U << ((idx)%8U))) +#define HASH_BLOOM_BITTEST(bv,idx) (bv[(idx)/8U] & (1U << ((idx)%8U))) + +#define HASH_BLOOM_ADD(tbl,hashv) \ + HASH_BLOOM_BITSET((tbl)->bloom_bv, ((hashv) & (uint32_t)((1UL << (tbl)->bloom_nbits) - 1U))) + +#define HASH_BLOOM_TEST(tbl,hashv) \ + HASH_BLOOM_BITTEST((tbl)->bloom_bv, ((hashv) & (uint32_t)((1UL << (tbl)->bloom_nbits) - 1U))) + +#else +#define HASH_BLOOM_MAKE(tbl,oomed) +#define HASH_BLOOM_FREE(tbl) +#define HASH_BLOOM_ADD(tbl,hashv) +#define HASH_BLOOM_TEST(tbl,hashv) (1) +#define HASH_BLOOM_BYTELEN 0U +#endif + +#define HASH_MAKE_TABLE(hh,head,oomed) \ +do { \ + (head)->hh.tbl = (UT_hash_table*)uthash_malloc(sizeof(UT_hash_table)); \ + if (!(head)->hh.tbl) { \ + HASH_RECORD_OOM(oomed); \ + } else { \ + uthash_bzero((head)->hh.tbl, sizeof(UT_hash_table)); \ + (head)->hh.tbl->tail = &((head)->hh); \ + (head)->hh.tbl->num_buckets = HASH_INITIAL_NUM_BUCKETS; \ + (head)->hh.tbl->log2_num_buckets = HASH_INITIAL_NUM_BUCKETS_LOG2; \ + (head)->hh.tbl->hho = (char*)(&(head)->hh) - (char*)(head); \ + (head)->hh.tbl->buckets = (UT_hash_bucket*)uthash_malloc( \ + HASH_INITIAL_NUM_BUCKETS * sizeof(struct UT_hash_bucket)); \ + (head)->hh.tbl->signature = HASH_SIGNATURE; \ + if (!(head)->hh.tbl->buckets) { \ + HASH_RECORD_OOM(oomed); \ + uthash_free((head)->hh.tbl, sizeof(UT_hash_table)); \ + } else { \ + uthash_bzero((head)->hh.tbl->buckets, \ + HASH_INITIAL_NUM_BUCKETS * sizeof(struct UT_hash_bucket)); \ + HASH_BLOOM_MAKE((head)->hh.tbl, 
oomed); \ + IF_HASH_NONFATAL_OOM( \ + if (oomed) { \ + uthash_free((head)->hh.tbl->buckets, \ + HASH_INITIAL_NUM_BUCKETS*sizeof(struct UT_hash_bucket)); \ + uthash_free((head)->hh.tbl, sizeof(UT_hash_table)); \ + } \ + ) \ + } \ + } \ +} while (0) + +#define HASH_REPLACE_BYHASHVALUE_INORDER(hh,head,fieldname,keylen_in,hashval,add,replaced,cmpfcn) \ +do { \ + (replaced) = NULL; \ + HASH_FIND_BYHASHVALUE(hh, head, &((add)->fieldname), keylen_in, hashval, replaced); \ + if (replaced) { \ + HASH_DELETE(hh, head, replaced); \ + } \ + HASH_ADD_KEYPTR_BYHASHVALUE_INORDER(hh, head, &((add)->fieldname), keylen_in, hashval, add, cmpfcn); \ +} while (0) + +#define HASH_REPLACE_BYHASHVALUE(hh,head,fieldname,keylen_in,hashval,add,replaced) \ +do { \ + (replaced) = NULL; \ + HASH_FIND_BYHASHVALUE(hh, head, &((add)->fieldname), keylen_in, hashval, replaced); \ + if (replaced) { \ + HASH_DELETE(hh, head, replaced); \ + } \ + HASH_ADD_KEYPTR_BYHASHVALUE(hh, head, &((add)->fieldname), keylen_in, hashval, add); \ +} while (0) + +#define HASH_REPLACE(hh,head,fieldname,keylen_in,add,replaced) \ +do { \ + unsigned _hr_hashv; \ + HASH_VALUE(&((add)->fieldname), keylen_in, _hr_hashv); \ + HASH_REPLACE_BYHASHVALUE(hh, head, fieldname, keylen_in, _hr_hashv, add, replaced); \ +} while (0) + +#define HASH_REPLACE_INORDER(hh,head,fieldname,keylen_in,add,replaced,cmpfcn) \ +do { \ + unsigned _hr_hashv; \ + HASH_VALUE(&((add)->fieldname), keylen_in, _hr_hashv); \ + HASH_REPLACE_BYHASHVALUE_INORDER(hh, head, fieldname, keylen_in, _hr_hashv, add, replaced, cmpfcn); \ +} while (0) + +#define HASH_APPEND_LIST(hh, head, add) \ +do { \ + (add)->hh.next = NULL; \ + (add)->hh.prev = ELMT_FROM_HH((head)->hh.tbl, (head)->hh.tbl->tail); \ + (head)->hh.tbl->tail->next = (add); \ + (head)->hh.tbl->tail = &((add)->hh); \ +} while (0) + +#define HASH_AKBI_INNER_LOOP(hh,head,add,cmpfcn) \ +do { \ + do { \ + if (cmpfcn(DECLTYPE(head)(_hs_iter), add) > 0) { \ + break; \ + } \ + } while ((_hs_iter = 
HH_FROM_ELMT((head)->hh.tbl, _hs_iter)->next)); \ +} while (0) + +#ifdef NO_DECLTYPE +#undef HASH_AKBI_INNER_LOOP +#define HASH_AKBI_INNER_LOOP(hh,head,add,cmpfcn) \ +do { \ + char *_hs_saved_head = (char*)(head); \ + do { \ + DECLTYPE_ASSIGN(head, _hs_iter); \ + if (cmpfcn(head, add) > 0) { \ + DECLTYPE_ASSIGN(head, _hs_saved_head); \ + break; \ + } \ + DECLTYPE_ASSIGN(head, _hs_saved_head); \ + } while ((_hs_iter = HH_FROM_ELMT((head)->hh.tbl, _hs_iter)->next)); \ +} while (0) +#endif + +#if HASH_NONFATAL_OOM + +#define HASH_ADD_TO_TABLE(hh,head,keyptr,keylen_in,hashval,add,oomed) \ +do { \ + if (!(oomed)) { \ + unsigned _ha_bkt; \ + (head)->hh.tbl->num_items++; \ + HASH_TO_BKT(hashval, (head)->hh.tbl->num_buckets, _ha_bkt); \ + HASH_ADD_TO_BKT((head)->hh.tbl->buckets[_ha_bkt], hh, &(add)->hh, oomed); \ + if (oomed) { \ + HASH_ROLLBACK_BKT(hh, head, &(add)->hh); \ + HASH_DELETE_HH(hh, head, &(add)->hh); \ + (add)->hh.tbl = NULL; \ + uthash_nonfatal_oom(add); \ + } else { \ + HASH_BLOOM_ADD((head)->hh.tbl, hashval); \ + HASH_EMIT_KEY(hh, head, keyptr, keylen_in); \ + } \ + } else { \ + (add)->hh.tbl = NULL; \ + uthash_nonfatal_oom(add); \ + } \ +} while (0) + +#else + +#define HASH_ADD_TO_TABLE(hh,head,keyptr,keylen_in,hashval,add,oomed) \ +do { \ + unsigned _ha_bkt; \ + (head)->hh.tbl->num_items++; \ + HASH_TO_BKT(hashval, (head)->hh.tbl->num_buckets, _ha_bkt); \ + HASH_ADD_TO_BKT((head)->hh.tbl->buckets[_ha_bkt], hh, &(add)->hh, oomed); \ + HASH_BLOOM_ADD((head)->hh.tbl, hashval); \ + HASH_EMIT_KEY(hh, head, keyptr, keylen_in); \ +} while (0) + +#endif + + +#define HASH_ADD_KEYPTR_BYHASHVALUE_INORDER(hh,head,keyptr,keylen_in,hashval,add,cmpfcn) \ +do { \ + IF_HASH_NONFATAL_OOM( int _ha_oomed = 0; ) \ + (add)->hh.hashv = (hashval); \ + (add)->hh.key = (char*) (keyptr); \ + (add)->hh.keylen = (unsigned) (keylen_in); \ + if (!(head)) { \ + (add)->hh.next = NULL; \ + (add)->hh.prev = NULL; \ + HASH_MAKE_TABLE(hh, add, _ha_oomed); \ + IF_HASH_NONFATAL_OOM( if 
(!_ha_oomed) { ) \ + (head) = (add); \ + IF_HASH_NONFATAL_OOM( } ) \ + } else { \ + void *_hs_iter = (head); \ + (add)->hh.tbl = (head)->hh.tbl; \ + HASH_AKBI_INNER_LOOP(hh, head, add, cmpfcn); \ + if (_hs_iter) { \ + (add)->hh.next = _hs_iter; \ + if (((add)->hh.prev = HH_FROM_ELMT((head)->hh.tbl, _hs_iter)->prev)) { \ + HH_FROM_ELMT((head)->hh.tbl, (add)->hh.prev)->next = (add); \ + } else { \ + (head) = (add); \ + } \ + HH_FROM_ELMT((head)->hh.tbl, _hs_iter)->prev = (add); \ + } else { \ + HASH_APPEND_LIST(hh, head, add); \ + } \ + } \ + HASH_ADD_TO_TABLE(hh, head, keyptr, keylen_in, hashval, add, _ha_oomed); \ + HASH_FSCK(hh, head, "HASH_ADD_KEYPTR_BYHASHVALUE_INORDER"); \ +} while (0) + +#define HASH_ADD_KEYPTR_INORDER(hh,head,keyptr,keylen_in,add,cmpfcn) \ +do { \ + unsigned _hs_hashv; \ + HASH_VALUE(keyptr, keylen_in, _hs_hashv); \ + HASH_ADD_KEYPTR_BYHASHVALUE_INORDER(hh, head, keyptr, keylen_in, _hs_hashv, add, cmpfcn); \ +} while (0) + +#define HASH_ADD_BYHASHVALUE_INORDER(hh,head,fieldname,keylen_in,hashval,add,cmpfcn) \ + HASH_ADD_KEYPTR_BYHASHVALUE_INORDER(hh, head, &((add)->fieldname), keylen_in, hashval, add, cmpfcn) + +#define HASH_ADD_INORDER(hh,head,fieldname,keylen_in,add,cmpfcn) \ + HASH_ADD_KEYPTR_INORDER(hh, head, &((add)->fieldname), keylen_in, add, cmpfcn) + +#define HASH_ADD_KEYPTR_BYHASHVALUE(hh,head,keyptr,keylen_in,hashval,add) \ +do { \ + IF_HASH_NONFATAL_OOM( int _ha_oomed = 0; ) \ + (add)->hh.hashv = (hashval); \ + (add)->hh.key = (const void*) (keyptr); \ + (add)->hh.keylen = (unsigned) (keylen_in); \ + if (!(head)) { \ + (add)->hh.next = NULL; \ + (add)->hh.prev = NULL; \ + HASH_MAKE_TABLE(hh, add, _ha_oomed); \ + IF_HASH_NONFATAL_OOM( if (!_ha_oomed) { ) \ + (head) = (add); \ + IF_HASH_NONFATAL_OOM( } ) \ + } else { \ + (add)->hh.tbl = (head)->hh.tbl; \ + HASH_APPEND_LIST(hh, head, add); \ + } \ + HASH_ADD_TO_TABLE(hh, head, keyptr, keylen_in, hashval, add, _ha_oomed); \ + HASH_FSCK(hh, head, "HASH_ADD_KEYPTR_BYHASHVALUE"); \ +} 
while (0) + +#define HASH_ADD_KEYPTR(hh,head,keyptr,keylen_in,add) \ +do { \ + unsigned _ha_hashv; \ + HASH_VALUE(keyptr, keylen_in, _ha_hashv); \ + HASH_ADD_KEYPTR_BYHASHVALUE(hh, head, keyptr, keylen_in, _ha_hashv, add); \ +} while (0) + +#define HASH_ADD_BYHASHVALUE(hh,head,fieldname,keylen_in,hashval,add) \ + HASH_ADD_KEYPTR_BYHASHVALUE(hh, head, &((add)->fieldname), keylen_in, hashval, add) + +#define HASH_ADD(hh,head,fieldname,keylen_in,add) \ + HASH_ADD_KEYPTR(hh, head, &((add)->fieldname), keylen_in, add) + +#define HASH_TO_BKT(hashv,num_bkts,bkt) \ +do { \ + bkt = ((hashv) & ((num_bkts) - 1U)); \ +} while (0) + +/* delete "delptr" from the hash table. + * "the usual" patch-up process for the app-order doubly-linked-list. + * The use of _hd_hh_del below deserves special explanation. + * These used to be expressed using (delptr) but that led to a bug + * if someone used the same symbol for the head and deletee, like + * HASH_DELETE(hh,users,users); + * We want that to work, but by changing the head (users) below + * we were forfeiting our ability to further refer to the deletee (users) + * in the patch-up process. Solution: use scratch space to + * copy the deletee pointer, then the latter references are via that + * scratch pointer rather than through the repointed (users) symbol. 
+ */ +#define HASH_DELETE(hh,head,delptr) \ + HASH_DELETE_HH(hh, head, &(delptr)->hh) + +#define HASH_DELETE_HH(hh,head,delptrhh) \ +do { \ + const struct UT_hash_handle *_hd_hh_del = (delptrhh); \ + if ((_hd_hh_del->prev == NULL) && (_hd_hh_del->next == NULL)) { \ + HASH_BLOOM_FREE((head)->hh.tbl); \ + uthash_free((head)->hh.tbl->buckets, \ + (head)->hh.tbl->num_buckets * sizeof(struct UT_hash_bucket)); \ + uthash_free((head)->hh.tbl, sizeof(UT_hash_table)); \ + (head) = NULL; \ + } else { \ + unsigned _hd_bkt; \ + if (_hd_hh_del == (head)->hh.tbl->tail) { \ + (head)->hh.tbl->tail = HH_FROM_ELMT((head)->hh.tbl, _hd_hh_del->prev); \ + } \ + if (_hd_hh_del->prev != NULL) { \ + HH_FROM_ELMT((head)->hh.tbl, _hd_hh_del->prev)->next = _hd_hh_del->next; \ + } else { \ + DECLTYPE_ASSIGN(head, _hd_hh_del->next); \ + } \ + if (_hd_hh_del->next != NULL) { \ + HH_FROM_ELMT((head)->hh.tbl, _hd_hh_del->next)->prev = _hd_hh_del->prev; \ + } \ + HASH_TO_BKT(_hd_hh_del->hashv, (head)->hh.tbl->num_buckets, _hd_bkt); \ + HASH_DEL_IN_BKT((head)->hh.tbl->buckets[_hd_bkt], _hd_hh_del); \ + (head)->hh.tbl->num_items--; \ + } \ + HASH_FSCK(hh, head, "HASH_DELETE_HH"); \ +} while (0) + +/* convenience forms of HASH_FIND/HASH_ADD/HASH_DEL */ +#define HASH_FIND_STR(head,findstr,out) \ +do { \ + unsigned _uthash_hfstr_keylen = (unsigned)uthash_strlen(findstr); \ + HASH_FIND(hh, head, findstr, _uthash_hfstr_keylen, out); \ +} while (0) +#define HASH_ADD_STR(head,strfield,add) \ +do { \ + unsigned _uthash_hastr_keylen = (unsigned)uthash_strlen((add)->strfield); \ + HASH_ADD(hh, head, strfield[0], _uthash_hastr_keylen, add); \ +} while (0) +#define HASH_REPLACE_STR(head,strfield,add,replaced) \ +do { \ + unsigned _uthash_hrstr_keylen = (unsigned)uthash_strlen((add)->strfield); \ + HASH_REPLACE(hh, head, strfield[0], _uthash_hrstr_keylen, add, replaced); \ +} while (0) +#define HASH_FIND_INT(head,findint,out) \ + HASH_FIND(hh,head,findint,sizeof(int),out) +#define HASH_ADD_INT(head,intfield,add) 
\ + HASH_ADD(hh,head,intfield,sizeof(int),add) +#define HASH_REPLACE_INT(head,intfield,add,replaced) \ + HASH_REPLACE(hh,head,intfield,sizeof(int),add,replaced) +#define HASH_FIND_PTR(head,findptr,out) \ + HASH_FIND(hh,head,findptr,sizeof(void *),out) +#define HASH_ADD_PTR(head,ptrfield,add) \ + HASH_ADD(hh,head,ptrfield,sizeof(void *),add) +#define HASH_REPLACE_PTR(head,ptrfield,add,replaced) \ + HASH_REPLACE(hh,head,ptrfield,sizeof(void *),add,replaced) +#define HASH_DEL(head,delptr) \ + HASH_DELETE(hh,head,delptr) + +/* HASH_FSCK checks hash integrity on every add/delete when HASH_DEBUG is defined. + * This is for uthash developer only; it compiles away if HASH_DEBUG isn't defined. + */ +#ifdef HASH_DEBUG +#include <stdio.h> /* fprintf, stderr */ +#define HASH_OOPS(...) do { fprintf(stderr, __VA_ARGS__); exit(-1); } while (0) +#define HASH_FSCK(hh,head,where) \ +do { \ + struct UT_hash_handle *_thh; \ + if (head) { \ + unsigned _bkt_i; \ + unsigned _count = 0; \ + char *_prev; \ + for (_bkt_i = 0; _bkt_i < (head)->hh.tbl->num_buckets; ++_bkt_i) { \ + unsigned _bkt_count = 0; \ + _thh = (head)->hh.tbl->buckets[_bkt_i].hh_head; \ + _prev = NULL; \ + while (_thh) { \ + if (_prev != (char*)(_thh->hh_prev)) { \ + HASH_OOPS("%s: invalid hh_prev %p, actual %p\n", \ + (where), (void*)_thh->hh_prev, (void*)_prev); \ + } \ + _bkt_count++; \ + _prev = (char*)(_thh); \ + _thh = _thh->hh_next; \ + } \ + _count += _bkt_count; \ + if ((head)->hh.tbl->buckets[_bkt_i].count != _bkt_count) { \ + HASH_OOPS("%s: invalid bucket count %u, actual %u\n", \ + (where), (head)->hh.tbl->buckets[_bkt_i].count, _bkt_count); \ + } \ + } \ + if (_count != (head)->hh.tbl->num_items) { \ + HASH_OOPS("%s: invalid hh item count %u, actual %u\n", \ + (where), (head)->hh.tbl->num_items, _count); \ + } \ + _count = 0; \ + _prev = NULL; \ + _thh = &(head)->hh; \ + while (_thh) { \ + _count++; \ + if (_prev != (char*)_thh->prev) { \ + HASH_OOPS("%s: invalid prev %p, actual %p\n", \ + (where), (void*)_thh->prev,
(void*)_prev); \ + } \ + _prev = (char*)ELMT_FROM_HH((head)->hh.tbl, _thh); \ + _thh = (_thh->next ? HH_FROM_ELMT((head)->hh.tbl, _thh->next) : NULL); \ + } \ + if (_count != (head)->hh.tbl->num_items) { \ + HASH_OOPS("%s: invalid app item count %u, actual %u\n", \ + (where), (head)->hh.tbl->num_items, _count); \ + } \ + } \ +} while (0) +#else +#define HASH_FSCK(hh,head,where) +#endif + +/* When compiled with -DHASH_EMIT_KEYS, length-prefixed keys are emitted to + * the descriptor to which this macro is defined for tuning the hash function. + * The app can #include <unistd.h> to get the prototype for write(2). */ +#ifdef HASH_EMIT_KEYS +#define HASH_EMIT_KEY(hh,head,keyptr,fieldlen) \ +do { \ + unsigned _klen = fieldlen; \ + write(HASH_EMIT_KEYS, &_klen, sizeof(_klen)); \ + write(HASH_EMIT_KEYS, keyptr, (unsigned long)fieldlen); \ +} while (0) +#else +#define HASH_EMIT_KEY(hh,head,keyptr,fieldlen) +#endif + +/* The Bernstein hash function, used in Perl prior to v5.6. Note (x<<5+x)=x*33. */ +#define HASH_BER(key,keylen,hashv) \ +do { \ + unsigned _hb_keylen = (unsigned)keylen; \ + const unsigned char *_hb_key = (const unsigned char*)(key); \ + (hashv) = 0; \ + while (_hb_keylen-- != 0U) { \ + (hashv) = (((hashv) << 5) + (hashv)) + *_hb_key++; \ + } \ +} while (0) + + +/* SAX/FNV/OAT/JEN hash functions are macro variants of those listed at + * http://eternallyconfuzzled.com/tuts/algorithms/jsw_tut_hashing.aspx + * (archive link: https://archive.is/Ivcan ) + */ +#define HASH_SAX(key,keylen,hashv) \ +do { \ + unsigned _sx_i; \ + const unsigned char *_hs_key = (const unsigned char*)(key); \ + hashv = 0; \ + for (_sx_i=0; _sx_i < keylen; _sx_i++) { \ + hashv ^= (hashv << 5) + (hashv >> 2) + _hs_key[_sx_i]; \ + } \ +} while (0) +/* FNV-1a variation */ +#define HASH_FNV(key,keylen,hashv) \ +do { \ + unsigned _fn_i; \ + const unsigned char *_hf_key = (const unsigned char*)(key); \ + (hashv) = 2166136261U; \ + for (_fn_i=0; _fn_i < keylen; _fn_i++) { \ + hashv = hashv ^
_hf_key[_fn_i]; \ + hashv = hashv * 16777619U; \ + } \ +} while (0) + +#define HASH_OAT(key,keylen,hashv) \ +do { \ + unsigned _ho_i; \ + const unsigned char *_ho_key=(const unsigned char*)(key); \ + hashv = 0; \ + for(_ho_i=0; _ho_i < keylen; _ho_i++) { \ + hashv += _ho_key[_ho_i]; \ + hashv += (hashv << 10); \ + hashv ^= (hashv >> 6); \ + } \ + hashv += (hashv << 3); \ + hashv ^= (hashv >> 11); \ + hashv += (hashv << 15); \ +} while (0) + +#define HASH_JEN_MIX(a,b,c) \ +do { \ + a -= b; a -= c; a ^= ( c >> 13 ); \ + b -= c; b -= a; b ^= ( a << 8 ); \ + c -= a; c -= b; c ^= ( b >> 13 ); \ + a -= b; a -= c; a ^= ( c >> 12 ); \ + b -= c; b -= a; b ^= ( a << 16 ); \ + c -= a; c -= b; c ^= ( b >> 5 ); \ + a -= b; a -= c; a ^= ( c >> 3 ); \ + b -= c; b -= a; b ^= ( a << 10 ); \ + c -= a; c -= b; c ^= ( b >> 15 ); \ +} while (0) + +#define HASH_JEN(key,keylen,hashv) \ +do { \ + unsigned _hj_i,_hj_j,_hj_k; \ + unsigned const char *_hj_key=(unsigned const char*)(key); \ + hashv = 0xfeedbeefu; \ + _hj_i = _hj_j = 0x9e3779b9u; \ + _hj_k = (unsigned)(keylen); \ + while (_hj_k >= 12U) { \ + _hj_i += (_hj_key[0] + ( (unsigned)_hj_key[1] << 8 ) \ + + ( (unsigned)_hj_key[2] << 16 ) \ + + ( (unsigned)_hj_key[3] << 24 ) ); \ + _hj_j += (_hj_key[4] + ( (unsigned)_hj_key[5] << 8 ) \ + + ( (unsigned)_hj_key[6] << 16 ) \ + + ( (unsigned)_hj_key[7] << 24 ) ); \ + hashv += (_hj_key[8] + ( (unsigned)_hj_key[9] << 8 ) \ + + ( (unsigned)_hj_key[10] << 16 ) \ + + ( (unsigned)_hj_key[11] << 24 ) ); \ + \ + HASH_JEN_MIX(_hj_i, _hj_j, hashv); \ + \ + _hj_key += 12; \ + _hj_k -= 12U; \ + } \ + hashv += (unsigned)(keylen); \ + switch ( _hj_k ) { \ + case 11: hashv += ( (unsigned)_hj_key[10] << 24 ); /* FALLTHROUGH */ \ + case 10: hashv += ( (unsigned)_hj_key[9] << 16 ); /* FALLTHROUGH */ \ + case 9: hashv += ( (unsigned)_hj_key[8] << 8 ); /* FALLTHROUGH */ \ + case 8: _hj_j += ( (unsigned)_hj_key[7] << 24 ); /* FALLTHROUGH */ \ + case 7: _hj_j += ( (unsigned)_hj_key[6] << 16 ); /* FALLTHROUGH */ 
\ + case 6: _hj_j += ( (unsigned)_hj_key[5] << 8 ); /* FALLTHROUGH */ \ + case 5: _hj_j += _hj_key[4]; /* FALLTHROUGH */ \ + case 4: _hj_i += ( (unsigned)_hj_key[3] << 24 ); /* FALLTHROUGH */ \ + case 3: _hj_i += ( (unsigned)_hj_key[2] << 16 ); /* FALLTHROUGH */ \ + case 2: _hj_i += ( (unsigned)_hj_key[1] << 8 ); /* FALLTHROUGH */ \ + case 1: _hj_i += _hj_key[0]; /* FALLTHROUGH */ \ + default: ; \ + } \ + HASH_JEN_MIX(_hj_i, _hj_j, hashv); \ +} while (0) + +/* The Paul Hsieh hash function */ +#undef get16bits +#if (defined(__GNUC__) && defined(__i386__)) || defined(__WATCOMC__) \ + || defined(_MSC_VER) || defined (__BORLANDC__) || defined (__TURBOC__) +#define get16bits(d) (*((const uint16_t *) (d))) +#endif + +#if !defined (get16bits) +#define get16bits(d) ((((uint32_t)(((const uint8_t *)(d))[1])) << 8) \ + +(uint32_t)(((const uint8_t *)(d))[0]) ) +#endif +#define HASH_SFH(key,keylen,hashv) \ +do { \ + unsigned const char *_sfh_key=(unsigned const char*)(key); \ + uint32_t _sfh_tmp, _sfh_len = (uint32_t)keylen; \ + \ + unsigned _sfh_rem = _sfh_len & 3U; \ + _sfh_len >>= 2; \ + hashv = 0xcafebabeu; \ + \ + /* Main loop */ \ + for (;_sfh_len > 0U; _sfh_len--) { \ + hashv += get16bits (_sfh_key); \ + _sfh_tmp = ((uint32_t)(get16bits (_sfh_key+2)) << 11) ^ hashv; \ + hashv = (hashv << 16) ^ _sfh_tmp; \ + _sfh_key += 2U*sizeof (uint16_t); \ + hashv += hashv >> 11; \ + } \ + \ + /* Handle end cases */ \ + switch (_sfh_rem) { \ + case 3: hashv += get16bits (_sfh_key); \ + hashv ^= hashv << 16; \ + hashv ^= (uint32_t)(_sfh_key[sizeof (uint16_t)]) << 18; \ + hashv += hashv >> 11; \ + break; \ + case 2: hashv += get16bits (_sfh_key); \ + hashv ^= hashv << 11; \ + hashv += hashv >> 17; \ + break; \ + case 1: hashv += *_sfh_key; \ + hashv ^= hashv << 10; \ + hashv += hashv >> 1; \ + break; \ + default: ; \ + } \ + \ + /* Force "avalanching" of final 127 bits */ \ + hashv ^= hashv << 3; \ + hashv += hashv >> 5; \ + hashv ^= hashv << 4; \ + hashv += hashv >> 17; \ + hashv ^= 
hashv << 25; \ + hashv += hashv >> 6; \ +} while (0) + +/* iterate over items in a known bucket to find desired item */ +#define HASH_FIND_IN_BKT(tbl,hh,head,keyptr,keylen_in,hashval,out) \ +do { \ + if ((head).hh_head != NULL) { \ + DECLTYPE_ASSIGN(out, ELMT_FROM_HH(tbl, (head).hh_head)); \ + } else { \ + (out) = NULL; \ + } \ + while ((out) != NULL) { \ + if ((out)->hh.hashv == (hashval) && (out)->hh.keylen == (keylen_in)) { \ + if (HASH_KEYCMP((out)->hh.key, keyptr, keylen_in) == 0) { \ + break; \ + } \ + } \ + if ((out)->hh.hh_next != NULL) { \ + DECLTYPE_ASSIGN(out, ELMT_FROM_HH(tbl, (out)->hh.hh_next)); \ + } else { \ + (out) = NULL; \ + } \ + } \ +} while (0) + +/* add an item to a bucket */ +#define HASH_ADD_TO_BKT(head,hh,addhh,oomed) \ +do { \ + UT_hash_bucket *_ha_head = &(head); \ + _ha_head->count++; \ + (addhh)->hh_next = _ha_head->hh_head; \ + (addhh)->hh_prev = NULL; \ + if (_ha_head->hh_head != NULL) { \ + _ha_head->hh_head->hh_prev = (addhh); \ + } \ + _ha_head->hh_head = (addhh); \ + if ((_ha_head->count >= ((_ha_head->expand_mult + 1U) * HASH_BKT_CAPACITY_THRESH)) \ + && !(addhh)->tbl->noexpand) { \ + HASH_EXPAND_BUCKETS(addhh,(addhh)->tbl, oomed); \ + IF_HASH_NONFATAL_OOM( \ + if (oomed) { \ + HASH_DEL_IN_BKT(head,addhh); \ + } \ + ) \ + } \ +} while (0) + +/* remove an item from a given bucket */ +#define HASH_DEL_IN_BKT(head,delhh) \ +do { \ + UT_hash_bucket *_hd_head = &(head); \ + _hd_head->count--; \ + if (_hd_head->hh_head == (delhh)) { \ + _hd_head->hh_head = (delhh)->hh_next; \ + } \ + if ((delhh)->hh_prev) { \ + (delhh)->hh_prev->hh_next = (delhh)->hh_next; \ + } \ + if ((delhh)->hh_next) { \ + (delhh)->hh_next->hh_prev = (delhh)->hh_prev; \ + } \ +} while (0) + +/* Bucket expansion has the effect of doubling the number of buckets + * and redistributing the items into the new buckets. 
Ideally the + * items will distribute more or less evenly into the new buckets + * (the extent to which this is true is a measure of the quality of + * the hash function as it applies to the key domain). + * + * With the items distributed into more buckets, the chain length + * (item count) in each bucket is reduced. Thus by expanding buckets + * the hash keeps a bound on the chain length. This bounded chain + * length is the essence of how a hash provides constant time lookup. + * + * The calculation of tbl->ideal_chain_maxlen below deserves some + * explanation. First, keep in mind that we're calculating the ideal + * maximum chain length based on the *new* (doubled) bucket count. + * In fractions this is just n/b (n=number of items,b=new num buckets). + * Since the ideal chain length is an integer, we want to calculate + * ceil(n/b). We don't depend on floating point arithmetic in this + * hash, so to calculate ceil(n/b) with integers we could write + * + * ceil(n/b) = (n/b) + ((n%b)?1:0) + * + * and in fact a previous version of this hash did just that. + * But now we have improved things a bit by recognizing that b is + * always a power of two. We keep its base 2 log handy (call it lb), + * so now we can write this with a bit shift and logical AND: + * + * ceil(n/b) = (n>>lb) + ( (n & (b-1)) ? 1:0) + * + */ +#define HASH_EXPAND_BUCKETS(hh,tbl,oomed) \ +do { \ + unsigned _he_bkt; \ + unsigned _he_bkt_i; \ + struct UT_hash_handle *_he_thh, *_he_hh_nxt; \ + UT_hash_bucket *_he_new_buckets, *_he_newbkt; \ + _he_new_buckets = (UT_hash_bucket*)uthash_malloc( \ + sizeof(struct UT_hash_bucket) * (tbl)->num_buckets * 2U); \ + if (!_he_new_buckets) { \ + HASH_RECORD_OOM(oomed); \ + } else { \ + uthash_bzero(_he_new_buckets, \ + sizeof(struct UT_hash_bucket) * (tbl)->num_buckets * 2U); \ + (tbl)->ideal_chain_maxlen = \ + ((tbl)->num_items >> ((tbl)->log2_num_buckets+1U)) + \ + ((((tbl)->num_items & (((tbl)->num_buckets*2U)-1U)) != 0U) ? 
1U : 0U); \ + (tbl)->nonideal_items = 0; \ + for (_he_bkt_i = 0; _he_bkt_i < (tbl)->num_buckets; _he_bkt_i++) { \ + _he_thh = (tbl)->buckets[ _he_bkt_i ].hh_head; \ + while (_he_thh != NULL) { \ + _he_hh_nxt = _he_thh->hh_next; \ + HASH_TO_BKT(_he_thh->hashv, (tbl)->num_buckets * 2U, _he_bkt); \ + _he_newbkt = &(_he_new_buckets[_he_bkt]); \ + if (++(_he_newbkt->count) > (tbl)->ideal_chain_maxlen) { \ + (tbl)->nonideal_items++; \ + if (_he_newbkt->count > _he_newbkt->expand_mult * (tbl)->ideal_chain_maxlen) { \ + _he_newbkt->expand_mult++; \ + } \ + } \ + _he_thh->hh_prev = NULL; \ + _he_thh->hh_next = _he_newbkt->hh_head; \ + if (_he_newbkt->hh_head != NULL) { \ + _he_newbkt->hh_head->hh_prev = _he_thh; \ + } \ + _he_newbkt->hh_head = _he_thh; \ + _he_thh = _he_hh_nxt; \ + } \ + } \ + uthash_free((tbl)->buckets, (tbl)->num_buckets * sizeof(struct UT_hash_bucket)); \ + (tbl)->num_buckets *= 2U; \ + (tbl)->log2_num_buckets++; \ + (tbl)->buckets = _he_new_buckets; \ + (tbl)->ineff_expands = ((tbl)->nonideal_items > ((tbl)->num_items >> 1)) ? \ + ((tbl)->ineff_expands+1U) : 0U; \ + if ((tbl)->ineff_expands > 1U) { \ + (tbl)->noexpand = 1; \ + uthash_noexpand_fyi(tbl); \ + } \ + uthash_expand_fyi(tbl); \ + } \ +} while (0) + + +/* This is an adaptation of Simon Tatham's O(n log(n)) mergesort */ +/* Note that HASH_SORT assumes the hash handle name to be hh. + * HASH_SRT was added to allow the hash handle name to be passed in. 
*/ +#define HASH_SORT(head,cmpfcn) HASH_SRT(hh,head,cmpfcn) +#define HASH_SRT(hh,head,cmpfcn) \ +do { \ + unsigned _hs_i; \ + unsigned _hs_looping,_hs_nmerges,_hs_insize,_hs_psize,_hs_qsize; \ + struct UT_hash_handle *_hs_p, *_hs_q, *_hs_e, *_hs_list, *_hs_tail; \ + if (head != NULL) { \ + _hs_insize = 1; \ + _hs_looping = 1; \ + _hs_list = &((head)->hh); \ + while (_hs_looping != 0U) { \ + _hs_p = _hs_list; \ + _hs_list = NULL; \ + _hs_tail = NULL; \ + _hs_nmerges = 0; \ + while (_hs_p != NULL) { \ + _hs_nmerges++; \ + _hs_q = _hs_p; \ + _hs_psize = 0; \ + for (_hs_i = 0; _hs_i < _hs_insize; ++_hs_i) { \ + _hs_psize++; \ + _hs_q = ((_hs_q->next != NULL) ? \ + HH_FROM_ELMT((head)->hh.tbl, _hs_q->next) : NULL); \ + if (_hs_q == NULL) { \ + break; \ + } \ + } \ + _hs_qsize = _hs_insize; \ + while ((_hs_psize != 0U) || ((_hs_qsize != 0U) && (_hs_q != NULL))) { \ + if (_hs_psize == 0U) { \ + _hs_e = _hs_q; \ + _hs_q = ((_hs_q->next != NULL) ? \ + HH_FROM_ELMT((head)->hh.tbl, _hs_q->next) : NULL); \ + _hs_qsize--; \ + } else if ((_hs_qsize == 0U) || (_hs_q == NULL)) { \ + _hs_e = _hs_p; \ + if (_hs_p != NULL) { \ + _hs_p = ((_hs_p->next != NULL) ? \ + HH_FROM_ELMT((head)->hh.tbl, _hs_p->next) : NULL); \ + } \ + _hs_psize--; \ + } else if ((cmpfcn( \ + DECLTYPE(head)(ELMT_FROM_HH((head)->hh.tbl, _hs_p)), \ + DECLTYPE(head)(ELMT_FROM_HH((head)->hh.tbl, _hs_q)) \ + )) <= 0) { \ + _hs_e = _hs_p; \ + if (_hs_p != NULL) { \ + _hs_p = ((_hs_p->next != NULL) ? \ + HH_FROM_ELMT((head)->hh.tbl, _hs_p->next) : NULL); \ + } \ + _hs_psize--; \ + } else { \ + _hs_e = _hs_q; \ + _hs_q = ((_hs_q->next != NULL) ? \ + HH_FROM_ELMT((head)->hh.tbl, _hs_q->next) : NULL); \ + _hs_qsize--; \ + } \ + if ( _hs_tail != NULL ) { \ + _hs_tail->next = ((_hs_e != NULL) ? \ + ELMT_FROM_HH((head)->hh.tbl, _hs_e) : NULL); \ + } else { \ + _hs_list = _hs_e; \ + } \ + if (_hs_e != NULL) { \ + _hs_e->prev = ((_hs_tail != NULL) ? 
\ + ELMT_FROM_HH((head)->hh.tbl, _hs_tail) : NULL); \ + } \ + _hs_tail = _hs_e; \ + } \ + _hs_p = _hs_q; \ + } \ + if (_hs_tail != NULL) { \ + _hs_tail->next = NULL; \ + } \ + if (_hs_nmerges <= 1U) { \ + _hs_looping = 0; \ + (head)->hh.tbl->tail = _hs_tail; \ + DECLTYPE_ASSIGN(head, ELMT_FROM_HH((head)->hh.tbl, _hs_list)); \ + } \ + _hs_insize *= 2U; \ + } \ + HASH_FSCK(hh, head, "HASH_SRT"); \ + } \ +} while (0) + +/* This function selects items from one hash into another hash. + * The end result is that the selected items have dual presence + * in both hashes. There is no copy of the items made; rather + * they are added into the new hash through a secondary hash + * hash handle that must be present in the structure. */ +#define HASH_SELECT(hh_dst, dst, hh_src, src, cond) \ +do { \ + unsigned _src_bkt, _dst_bkt; \ + void *_last_elt = NULL, *_elt; \ + UT_hash_handle *_src_hh, *_dst_hh, *_last_elt_hh=NULL; \ + ptrdiff_t _dst_hho = ((char*)(&(dst)->hh_dst) - (char*)(dst)); \ + if ((src) != NULL) { \ + for (_src_bkt=0; _src_bkt < (src)->hh_src.tbl->num_buckets; _src_bkt++) { \ + for (_src_hh = (src)->hh_src.tbl->buckets[_src_bkt].hh_head; \ + _src_hh != NULL; \ + _src_hh = _src_hh->hh_next) { \ + _elt = ELMT_FROM_HH((src)->hh_src.tbl, _src_hh); \ + if (cond(_elt)) { \ + IF_HASH_NONFATAL_OOM( int _hs_oomed = 0; ) \ + _dst_hh = (UT_hash_handle*)(void*)(((char*)_elt) + _dst_hho); \ + _dst_hh->key = _src_hh->key; \ + _dst_hh->keylen = _src_hh->keylen; \ + _dst_hh->hashv = _src_hh->hashv; \ + _dst_hh->prev = _last_elt; \ + _dst_hh->next = NULL; \ + if (_last_elt_hh != NULL) { \ + _last_elt_hh->next = _elt; \ + } \ + if ((dst) == NULL) { \ + DECLTYPE_ASSIGN(dst, _elt); \ + HASH_MAKE_TABLE(hh_dst, dst, _hs_oomed); \ + IF_HASH_NONFATAL_OOM( \ + if (_hs_oomed) { \ + uthash_nonfatal_oom(_elt); \ + (dst) = NULL; \ + continue; \ + } \ + ) \ + } else { \ + _dst_hh->tbl = (dst)->hh_dst.tbl; \ + } \ + HASH_TO_BKT(_dst_hh->hashv, _dst_hh->tbl->num_buckets, _dst_bkt); \ + 
HASH_ADD_TO_BKT(_dst_hh->tbl->buckets[_dst_bkt], hh_dst, _dst_hh, _hs_oomed); \ + (dst)->hh_dst.tbl->num_items++; \ + IF_HASH_NONFATAL_OOM( \ + if (_hs_oomed) { \ + HASH_ROLLBACK_BKT(hh_dst, dst, _dst_hh); \ + HASH_DELETE_HH(hh_dst, dst, _dst_hh); \ + _dst_hh->tbl = NULL; \ + uthash_nonfatal_oom(_elt); \ + continue; \ + } \ + ) \ + HASH_BLOOM_ADD(_dst_hh->tbl, _dst_hh->hashv); \ + _last_elt = _elt; \ + _last_elt_hh = _dst_hh; \ + } \ + } \ + } \ + } \ + HASH_FSCK(hh_dst, dst, "HASH_SELECT"); \ +} while (0) + +#define HASH_CLEAR(hh,head) \ +do { \ + if ((head) != NULL) { \ + HASH_BLOOM_FREE((head)->hh.tbl); \ + uthash_free((head)->hh.tbl->buckets, \ + (head)->hh.tbl->num_buckets*sizeof(struct UT_hash_bucket)); \ + uthash_free((head)->hh.tbl, sizeof(UT_hash_table)); \ + (head) = NULL; \ + } \ +} while (0) + +#define HASH_OVERHEAD(hh,head) \ + (((head) != NULL) ? ( \ + (size_t)(((head)->hh.tbl->num_items * sizeof(UT_hash_handle)) + \ + ((head)->hh.tbl->num_buckets * sizeof(UT_hash_bucket)) + \ + sizeof(UT_hash_table) + \ + (HASH_BLOOM_BYTELEN))) : 0U) + +#ifdef NO_DECLTYPE +#define HASH_ITER(hh,head,el,tmp) \ +for(((el)=(head)), ((*(char**)(&(tmp)))=(char*)((head!=NULL)?(head)->hh.next:NULL)); \ + (el) != NULL; ((el)=(tmp)), ((*(char**)(&(tmp)))=(char*)((tmp!=NULL)?(tmp)->hh.next:NULL))) +#else +#define HASH_ITER(hh,head,el,tmp) \ +for(((el)=(head)), ((tmp)=DECLTYPE(el)((head!=NULL)?(head)->hh.next:NULL)); \ + (el) != NULL; ((el)=(tmp)), ((tmp)=DECLTYPE(el)((tmp!=NULL)?(tmp)->hh.next:NULL))) +#endif + +/* obtain a count of items in the hash */ +#define HASH_COUNT(head) HASH_CNT(hh,head) +#define HASH_CNT(hh,head) ((head != NULL)?((head)->hh.tbl->num_items):0U) + +typedef struct UT_hash_bucket { + struct UT_hash_handle *hh_head; + unsigned count; + + /* expand_mult is normally set to 0. In this situation, the max chain length + * threshold is enforced at its default value, HASH_BKT_CAPACITY_THRESH. 
(If + * the bucket's chain exceeds this length, bucket expansion is triggered). + * However, setting expand_mult to a non-zero value delays bucket expansion + * (that would be triggered by additions to this particular bucket) + * until its chain length reaches a *multiple* of HASH_BKT_CAPACITY_THRESH. + * (The multiplier is simply expand_mult+1). The whole idea of this + * multiplier is to reduce bucket expansions, since they are expensive, in + * situations where we know that a particular bucket tends to be overused. + * It is better to let its chain length grow to a longer yet-still-bounded + * value, than to do an O(n) bucket expansion too often. + */ + unsigned expand_mult; + +} UT_hash_bucket; + +/* random signature used only to find hash tables in external analysis */ +#define HASH_SIGNATURE 0xa0111fe1u +#define HASH_BLOOM_SIGNATURE 0xb12220f2u + +typedef struct UT_hash_table { + UT_hash_bucket *buckets; + unsigned num_buckets, log2_num_buckets; + unsigned num_items; + struct UT_hash_handle *tail; /* tail hh in app order, for fast append */ + ptrdiff_t hho; /* hash handle offset (byte pos of hash handle in element */ + + /* in an ideal situation (all buckets used equally), no bucket would have + * more than ceil(#items/#buckets) items. that's the ideal chain length. */ + unsigned ideal_chain_maxlen; + + /* nonideal_items is the number of items in the hash whose chain position + * exceeds the ideal chain maxlen. these items pay the penalty for an uneven + * hash distribution; reaching them in a chain traversal takes >ideal steps */ + unsigned nonideal_items; + + /* ineffective expands occur when a bucket doubling was performed, but + * afterward, more than half the items in the hash had nonideal chain + * positions. If this happens on two consecutive expansions we inhibit any + * further expansion, as it's not helping; this happens when the hash + * function isn't a good fit for the key domain. 
When expansion is inhibited + * the hash will still work, albeit no longer in constant time. */ + unsigned ineff_expands, noexpand; + + uint32_t signature; /* used only to find hash tables in external analysis */ +#ifdef HASH_BLOOM + uint32_t bloom_sig; /* used only to test bloom exists in external analysis */ + uint8_t *bloom_bv; + uint8_t bloom_nbits; +#endif + +} UT_hash_table; + +typedef struct UT_hash_handle { + struct UT_hash_table *tbl; + void *prev; /* prev element in app order */ + void *next; /* next element in app order */ + struct UT_hash_handle *hh_prev; /* previous hh in bucket order */ + struct UT_hash_handle *hh_next; /* next hh in bucket order */ + const void *key; /* ptr to enclosing struct's key */ + unsigned keylen; /* enclosing struct's key len */ + unsigned hashv; /* result of hash-fcn(key) */ +} UT_hash_handle; + +#endif /* UTHASH_H */ \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 83ea8fd7..d9d68855 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools>=65.0"] +requires = ["setuptools>=65.0", "Cython", "wheel"] build-backend = "setuptools.build_meta" [project] diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..6def2c5d --- /dev/null +++ b/setup.py @@ -0,0 +1,8 @@ +from setuptools import Extension, setup +from Cython.Build import cythonize + +extensions = cythonize([Extension("chgnet.graph.cygraph", ["chgnet/graph/cygraph.pyx"])]) + +setup( + ext_modules = extensions +) \ No newline at end of file diff --git a/tests/test_crystal_graph.py b/tests/test_crystal_graph.py index d0423842..28a24a66 100644 --- a/tests/test_crystal_graph.py +++ b/tests/test_crystal_graph.py @@ -5,13 +5,16 @@ from chgnet import ROOT from chgnet.graph import CrystalGraphConverter +from time import perf_counter structure = Structure.from_file(f"{ROOT}/examples/o-LiMnO2_unit.cif") converter = CrystalGraphConverter(atom_graph_cutoff=5, bond_graph_cutoff=3) -def 
test_crystal_graph(): - graph = converter(structure) +def test_crystal_graph_legacy(): + start=perf_counter() + graph = converter(structure, graph_converter="legacy") + print("Legacy test_crystal_graph time:", perf_counter() - start) assert graph.composition == "Li2 Mn2 O4" assert graph.atomic_number.tolist() == [3, 3, 25, 25, 8, 8, 8, 8] @@ -32,10 +35,38 @@ def test_crystal_graph(): assert list(graph.undirected2directed.shape) == [192] assert list(graph.directed2undirected.shape) == [384] +def test_crystal_graph_fast(): + start=perf_counter() + graph = converter(structure, graph_converter="fast") + print("Fast test_crystal_graph time:", perf_counter() - start) -def test_crystal_graph_different_cutoff(): + + assert graph.composition == "Li2 Mn2 O4" + assert graph.atomic_number.tolist() == [3, 3, 25, 25, 8, 8, 8, 8] + assert list(graph.atom_frac_coord.shape) == [8, 3] + assert list(graph.atom_graph.shape) == [384, 2] + assert (graph.atom_graph[:, 0] == 0).sum().item() == 48 + assert (graph.atom_graph[:, 1] == 0).sum().item() == 48 + assert (graph.atom_graph[:, 0] == 4).sum().item() == 48 + assert (graph.atom_graph[:, 0] == 7).sum().item() == 48 + + assert list(graph.bond_graph.shape) == [744, 5] + assert (graph.bond_graph[:, 0] == 1).sum().item() == 72 + assert (graph.bond_graph[:, 1] == 100).sum().item() == 16 + assert (graph.bond_graph[:, 3] == 100).sum().item() == 16 + assert (graph.bond_graph[:, 2] == 348).sum().item() == 8 + assert (graph.bond_graph[:, 4] == 121).sum().item() == 8 + assert list(graph.lattice.shape) == [3, 3] + assert list(graph.undirected2directed.shape) == [192] + assert list(graph.directed2undirected.shape) == [384] + + +def test_crystal_graph_different_cutoff_legacy(): converter = CrystalGraphConverter(atom_graph_cutoff=5.5, bond_graph_cutoff=3.5) - graph = converter(structure) + + start = perf_counter() + graph = converter(structure, graph_converter="legacy") + print("Legacy test_crystal_graph_different_cutoff time:", perf_counter() - 
start) assert list(graph.atom_frac_coord.shape) == [8, 3] assert list(graph.atom_graph.shape) == [624, 2] @@ -53,12 +84,40 @@ def test_crystal_graph_different_cutoff(): assert list(graph.undirected2directed.shape) == [312] assert list(graph.directed2undirected.shape) == [624] +def test_crystal_graph_different_cutoff_fast(): + converter = CrystalGraphConverter(atom_graph_cutoff=5.5, bond_graph_cutoff=3.5) + + start=perf_counter() + graph = converter(structure, graph_converter="fast") + print("Fast test_crystal_graph_different_cutoff time:", perf_counter() - start) -def test_crystal_graph_perturb(): + + assert list(graph.atom_frac_coord.shape) == [8, 3] + assert list(graph.atom_graph.shape) == [624, 2] + assert (graph.atom_graph[:, 0] == 5).sum().item() == 78 + assert (graph.atom_graph[:, 1] == 5).sum().item() == 78 + assert (graph.atom_graph[:, 1] == 7).sum().item() == 78 + + assert list(graph.bond_graph.shape) == [2448, 5] + assert (graph.bond_graph[:, 0] == 1).sum().item() == 306 + assert (graph.bond_graph[:, 1] == 100).sum().item() == 0 + assert (graph.bond_graph[:, 3] == 100).sum().item() == 0 + assert (graph.bond_graph[:, 2] == 250).sum().item() == 17 + assert (graph.bond_graph[:, 4] == 50).sum().item() == 17 + assert list(graph.lattice.shape) == [3, 3] + assert list(graph.undirected2directed.shape) == [312] + assert list(graph.directed2undirected.shape) == [624] + + +def test_crystal_graph_perturb_legacy(): np.random.seed(0) structure_perturbed = structure.copy() structure_perturbed.perturb(distance=0.1) - graph = converter(structure_perturbed) + + start=perf_counter() + graph = converter(structure_perturbed, graph_converter="legacy") + print("Legacy test_crystal_graph_perturb time:", perf_counter() - start) + assert list(graph.atom_frac_coord.shape) == [8, 3] assert list(graph.atom_graph.shape) == [410, 2] @@ -77,11 +136,40 @@ def test_crystal_graph_perturb(): assert list(graph.undirected2directed.shape) == [205] assert list(graph.directed2undirected.shape) 
== [410] +def test_crystal_graph_perturb_fast(): + np.random.seed(0) + structure_perturbed = structure.copy() + structure_perturbed.perturb(distance=0.1) + + start=perf_counter() + graph = converter(structure_perturbed, graph_converter="fast") + print("Fast test_crystal_graph_perturb time:", perf_counter() - start) -def test_crystal_graph_isotropic_strained(): + assert list(graph.atom_frac_coord.shape) == [8, 3] + assert list(graph.atom_graph.shape) == [410, 2] + assert (graph.atom_graph[:, 0] == 3).sum().item() == 53 + assert (graph.atom_graph[:, 1] == 3).sum().item() == 53 + assert (graph.atom_graph[:, 1] == 6).sum().item() == 50 + + assert list(graph.bond_graph.shape) == [688, 5] + print(graph.bond_graph[120, :]) + assert (graph.bond_graph[:, 0] == 1).sum().item() == 90 + assert (graph.bond_graph[:, 1] == 36).sum().item() == 17 + assert (graph.bond_graph[:, 3] == 36).sum().item() == 17 + assert (graph.bond_graph[:, 2] == 306).sum().item() == 10 + assert (graph.bond_graph[:, 4] == 120).sum().item() == 0 + assert list(graph.lattice.shape) == [3, 3] + assert list(graph.undirected2directed.shape) == [205] + assert list(graph.directed2undirected.shape) == [410] + + +def test_crystal_graph_isotropic_strained_legacy(): structure_strained = structure.copy() structure_strained.apply_strain([0.1, 0.1, 0.1]) - graph = converter(structure_strained) + + start=perf_counter() + graph = converter(structure_strained, graph_converter="legacy") + print("Legacy test_crystal_graph_isotropic_strained time:", perf_counter() - start) assert list(graph.atom_frac_coord.shape) == [8, 3] assert list(graph.atom_graph.shape) == [264, 2] @@ -94,11 +182,54 @@ def test_crystal_graph_isotropic_strained(): assert list(graph.undirected2directed.shape) == [132] assert list(graph.directed2undirected.shape) == [264] +def test_crystal_graph_isotropic_strained_fast(): + structure_strained = structure.copy() + structure_strained.apply_strain([0.1, 0.1, 0.1]) + + start=perf_counter() + graph = 
converter(structure_strained, graph_converter="fast") + print("Fast test_crystal_graph_isotropic_strained time:", perf_counter() - start) + + assert list(graph.atom_frac_coord.shape) == [8, 3] + assert list(graph.atom_graph.shape) == [264, 2] + assert (graph.atom_graph[:, 0] == 3).sum().item() == 34 + assert (graph.atom_graph[:, 1] == 3).sum().item() == 34 + assert (graph.atom_graph[:, 0] == 7).sum().item() == 32 -def test_crystal_graph_anisotropic_strained(): + assert list(graph.bond_graph.shape) == [288, 5] + assert list(graph.lattice.shape) == [3, 3] + assert list(graph.undirected2directed.shape) == [132] + assert list(graph.directed2undirected.shape) == [264] + + + +def test_crystal_graph_anisotropic_strained_legacy(): structure_strained = structure.copy() structure_strained.apply_strain([0.2, -0.3, 0.5]) - graph = converter(structure_strained) + + start=perf_counter() + graph = converter(structure_strained, graph_converter="legacy") + print("Legacy test_crystal_graph_anisotropic_strained time:", perf_counter() - start) + + assert list(graph.atom_frac_coord.shape) == [8, 3] + assert list(graph.atom_graph.shape) == [336, 2] + assert (graph.atom_graph[:, 0] == 3).sum().item() == 42 + assert (graph.atom_graph[:, 1] == 3).sum().item() == 42 + assert (graph.atom_graph[:, 0] == 7).sum().item() == 42 + + assert list(graph.bond_graph.shape) == [256, 5] + assert list(graph.lattice.shape) == [3, 3] + assert list(graph.undirected2directed.shape) == [168] + assert list(graph.directed2undirected.shape) == [336] + +def test_crystal_graph_anisotropic_strained_fast(): + structure_strained = structure.copy() + structure_strained.apply_strain([0.2, -0.3, 0.5]) + + start=perf_counter() + graph = converter(structure_strained, graph_converter="fast") + print("Fast test_crystal_graph_anisotropic_strained time:", perf_counter() - start) + assert list(graph.atom_frac_coord.shape) == [8, 3] assert list(graph.atom_graph.shape) == [336, 2] @@ -112,10 +243,14 @@ def 
test_crystal_graph_anisotropic_strained(): assert list(graph.directed2undirected.shape) == [336] -def test_crystal_graph_supercell(): +def test_crystal_graph_supercell_legacy(): structure_supercell = structure.copy() structure_supercell.make_supercell([2, 3, 4]) - graph = converter(structure_supercell) + + start=perf_counter() + graph = converter(structure_supercell, graph_converter="legacy") + print("Legacy test_crystal_graph_supercell time:", perf_counter() - start) + assert graph.composition == "Li48 Mn48 O96" assert list(graph.atom_frac_coord.shape) == [192, 3] @@ -134,14 +269,52 @@ def test_crystal_graph_supercell(): assert list(graph.undirected2directed.shape) == [4608] assert list(graph.directed2undirected.shape) == [9216] +def test_crystal_graph_supercell_fast(): + structure_supercell = structure.copy() + structure_supercell.make_supercell([2, 3, 4]) + + start=perf_counter() + graph = converter(structure_supercell, graph_converter="fast") + print("Fast test_crystal_graph_supercell time:", perf_counter() - start) + + assert graph.composition == "Li48 Mn48 O96" + assert list(graph.atom_frac_coord.shape) == [192, 3] + assert list(graph.atom_graph.shape) == [9216, 2] + assert (graph.atom_graph[:, 0] == 30).sum().item() == 48 + assert (graph.atom_graph[:, 1] == 30).sum().item() == 48 + assert (graph.atom_graph[:, 0] == 70).sum().item() == 48 + + assert list(graph.bond_graph.shape) == [17856, 5] + assert (graph.bond_graph[:, 0] == 100).sum().item() == 72 + assert (graph.bond_graph[:, 1] == 623).sum().item() == 16 + assert (graph.bond_graph[:, 3] == 623).sum().item() == 16 + assert (graph.bond_graph[:, 2] == 2938).sum().item() == 8 + assert (graph.bond_graph[:, 4] == 121).sum().item() == 8 + assert list(graph.lattice.shape) == [3, 3] + assert list(graph.undirected2directed.shape) == [4608] + assert list(graph.directed2undirected.shape) == [9216] + + +def test_crystal_graph_stability_legacy(): + for _i in range(20): + np.random.seed(0) + structure_perturbed = 
structure.copy() + structure_perturbed.make_supercell([2, 2, 2]) + structure_perturbed.perturb(distance=0.5) + graph = converter(structure_perturbed, graph_converter="legacy") + + assert ( + graph.directed2undirected.shape[0] == 2 * graph.undirected2directed.shape[0] + ) + assert graph.atom_graph.shape[0] == graph.directed2undirected.shape[0] -def test_crystal_graph_stability(): +def test_crystal_graph_stability_fast(): for _i in range(20): np.random.seed(0) structure_perturbed = structure.copy() structure_perturbed.make_supercell([2, 2, 2]) structure_perturbed.perturb(distance=0.5) - graph = converter(structure_perturbed) + graph = converter(structure_perturbed, graph_converter="fast") assert ( graph.directed2undirected.shape[0] == 2 * graph.undirected2directed.shape[0] From b35d86dd95dfce8b17b9f8d4741c1db4df5a52e8 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Tue, 27 Jun 2023 16:27:39 -0700 Subject: [PATCH 02/15] auto fixes --- chgnet/graph/converter.py | 61 +++++++++++-------- chgnet/graph/cygraph.pyx | 8 +-- .../fast_converter_libraries/create_graph.c | 36 +++++------ .../graph/fast_converter_libraries/uthash.h | 2 +- setup.py | 12 ++-- tests/test_crystal_graph.py | 44 +++++++------ 6 files changed, 90 insertions(+), 73 deletions(-) diff --git a/chgnet/graph/converter.py b/chgnet/graph/converter.py index 25a5a31b..6a443930 100644 --- a/chgnet/graph/converter.py +++ b/chgnet/graph/converter.py @@ -3,10 +3,9 @@ import sys from typing import TYPE_CHECKING, Literal +import numpy as np import torch from torch import Tensor, nn -import numpy as np - from chgnet.graph.crystalgraph import CrystalGraph from chgnet.graph.graph import Graph, Node @@ -18,9 +17,10 @@ try: from chgnet.graph.cygraph import make_graph -except: +except ImportError: print("Error importing fast graph conversion (cygraph). 
Reverting to legacy.") + class CrystalGraphConverter(nn.Module): """Convert a pymatgen.core.Structure to a CrystalGraph The CrystalGraph dataclass stores essential field to make sure that @@ -51,7 +51,7 @@ def forward( graph_id=None, mp_id=None, on_isolated_atoms: Literal["ignore", "warn", "error"] = "error", - graph_converter: Literal["legacy", "fast"] = "fast" + graph_converter: Literal["legacy", "fast"] = "fast", ) -> CrystalGraph: """Convert a structure, return a CrystalGraph. @@ -84,12 +84,18 @@ def forward( # Make Graph if graph_converter == "fast": try: - graph = self._create_graph_fast(n_atoms, center_index, neighbor_index, image, distance) - except: + graph = self._create_graph_fast( + n_atoms, center_index, neighbor_index, image, distance + ) + except Exception: print("Failed to retrieve fast graph converter. Reverting to legacy.") - graph = self._create_graph_legacy(n_atoms, center_index, neighbor_index, image, distance) + graph = self._create_graph_legacy( + n_atoms, center_index, neighbor_index, image, distance + ) elif graph_converter == "legacy": - graph = self._create_graph_legacy(n_atoms, center_index, neighbor_index, image, distance) + graph = self._create_graph_legacy( + n_atoms, center_index, neighbor_index, image, distance + ) else: raise ValueError(f"No graph_converter named {graph_converter}") @@ -144,17 +150,17 @@ def forward( def _create_graph_legacy( self, - n_atoms: int, - center_index: np.ndarray, - neighbor_index: np.ndarray, + n_atoms: int, + center_index: np.ndarray, + neighbor_index: np.ndarray, image: np.ndarray, - distance: np.ndarray + distance: np.ndarray, ) -> Graph: - """Given structure information, create a Graph structure to be used to - create Crystal_Graph + """Given structure information, create a Graph structure to be used to + create Crystal_Graph. 
Args: - n_atoms (int): the number of atoms in the structure + n_atoms (int): the number of atoms in the structure center_index (np.ndarray): np array of indices of center atoms. Shape: (# of edges, ) neighbor_index (np.ndarray): np array of indices of neighbor atoms. Shape: (# of edges, ) image (np.ndarray): np array of images for each edge. Shape: (# of edges, 3) @@ -166,24 +172,23 @@ def _create_graph_legacy( graph = Graph([Node(index=i) for i in range(n_atoms)]) for ii, jj, img, dist in zip(center_index, neighbor_index, image, distance): graph.add_edge(center_index=ii, neighbor_index=jj, image=img, distance=dist) - - return graph + return graph def _create_graph_fast( self, - n_atoms: int, - center_index: np.ndarray, - neighbor_index: np.ndarray, + n_atoms: int, + center_index: np.ndarray, + neighbor_index: np.ndarray, image: np.ndarray, - distance: np.ndarray + distance: np.ndarray, ) -> Graph: - """Given structure information, create a Graph structure to be used to + """Given structure information, create a Graph structure to be used to create Crystal_Graph. NOTE: this is the fast version of _create_graph_legacy optimized in c (~3x speedup). Args: - n_atoms (int): the number of atoms in the structure + n_atoms (int): the number of atoms in the structure center_index (np.ndarray): np array of indices of center atoms. Shape: (# of edges, ) neighbor_index (np.ndarray): np array of indices of neighbor atoms. Shape: (# of edges, ) image (np.ndarray): np array of images for each edge. 
Shape: (# of edges, 3) @@ -192,13 +197,19 @@ def _create_graph_fast( Return: Graph data structure used to create Crystal_Graph object """ - center_index = np.ascontiguousarray(center_index) neighbor_index = np.ascontiguousarray(neighbor_index) image = np.ascontiguousarray(image, dtype=np.int_) distance = np.ascontiguousarray(distance) - nodes, directed_edges_list, undirected_edges_list, undirected_edges = make_graph(center_index, len(center_index), neighbor_index, image, distance, n_atoms) + ( + nodes, + directed_edges_list, + undirected_edges_list, + undirected_edges, + ) = make_graph( + center_index, len(center_index), neighbor_index, image, distance, n_atoms + ) graph = Graph(nodes=nodes) graph.directed_edges_list = directed_edges_list graph.undirected_edges_list = undirected_edges_list diff --git a/chgnet/graph/cygraph.pyx b/chgnet/graph/cygraph.pyx index b6d257f5..75d021e1 100644 --- a/chgnet/graph/cygraph.pyx +++ b/chgnet/graph/cygraph.pyx @@ -71,9 +71,9 @@ cdef extern from 'fast_converter_libraries/create_graph.c': LongToDirectedEdgeList** get_neighbors(Node* node) def make_graph( - const long[::1] center_index, + const long[::1] center_index, const long n_e, - const long[::1] neighbor_index, + const long[::1] neighbor_index, const long[:, ::1] image, const double[::1] distance, const long num_atoms @@ -109,9 +109,9 @@ def make_graph( for k in range(this_entry.num_directed_edges_in_group): this_DE = this_entry.directed_edges_list[k] directed_edges.append(this_DE[0].index) - + this_py_node.neighbors[this_entry.key] = directed_edges - + py_nodes.append(this_py_node) # Handling directed edges diff --git a/chgnet/graph/fast_converter_libraries/create_graph.c b/chgnet/graph/fast_converter_libraries/create_graph.c index df7b1981..53e925c0 100644 --- a/chgnet/graph/fast_converter_libraries/create_graph.c +++ b/chgnet/graph/fast_converter_libraries/create_graph.c @@ -9,7 +9,7 @@ typedef struct _LongToDirectedEdgeList LongToDirectedEdgeList; typedef struct 
_ReturnElems ReturnElems; typedef struct _ReturnElems2 ReturnElems2; -// NOTE: This code was mainly written to replicate the original add_edges method +// NOTE: This code was mainly written to replicate the original add_edges method // in the graph class in chgnet.graph.graph such that anyone familiar with that code should be able to pick up this // code pretty easily. @@ -61,7 +61,7 @@ typedef struct _ReturnElems { long* node_index_unraveled; long* node_neighbor_index_unraveled; long* node_directed_edge_index_unraveled; - + long num_undirected_edges; long* undirected_center_index_unraveled; long* undirected_neighbor_index_unraveled; @@ -70,7 +70,7 @@ typedef struct _ReturnElems { double* undirected_distances_unraveled; long num_directed_edges; - long* directed_undirected_edge_index_unraveled; + long* directed_undirected_edge_index_unraveled; } ReturnElems; @@ -135,7 +135,7 @@ ReturnElems2* create_graph( StructToUndirectedEdgeList* undirected_edges = NULL; // Pointer to beginning of list of UndirectedEdges corresponding to tmp of current iteration - StructToUndirectedEdgeList* corr_undirected_edges_item = NULL; + StructToUndirectedEdgeList* corr_undirected_edges_item = NULL; // Pointer to NodeIndexPair storing tmp NodeIndexPair* tmp = malloc(sizeof(NodeIndexPair)); @@ -144,7 +144,7 @@ ReturnElems2* create_graph( bool found = false; // Flag used to show if we've already processed the current undirected edge - bool processed_edge = false; + bool processed_edge = false; // Pointer used to store the previously added directed edge between two nodes DirectedEdge* added_DE; @@ -176,7 +176,7 @@ ReturnElems2* create_graph( this_directed_edge->undirected_edge_index = num_undirected_edges; - //TODO: be careful about double-freeing later. we're re-using a lot of memory space + //TODO: be careful about double-freeing later. 
we're re-using a lot of memory space // Create new undirected edge UndirectedEdge* this_undirected_edge = malloc(sizeof(UndirectedEdge)); @@ -187,9 +187,9 @@ ReturnElems2* create_graph( create_new_undirected_edges_entry(&undirected_edges, tmp, this_undirected_edge); append_to_undirected_edges_list(undirected_edges_list, this_undirected_edge, &num_undirected_edges); add_neighbors_to_node(&nodes[center_indices[i]], neighbor_indices[i], this_directed_edge); - append_to_directed_edges_list(directed_edges_list, this_directed_edge, &num_directed_edges); + append_to_directed_edges_list(directed_edges_list, this_directed_edge, &num_directed_edges); } else { - // This pair of nodes has been added before. We have to check if it's the other directed edge (but pointed in + // This pair of nodes has been added before. We have to check if it's the other directed edge (but pointed in // the different direction) OR it's another totally different undirected edge that has different image and distance // if found is true, then corr_undirected_edges_item points to self.undirected_edges[tmp] @@ -224,7 +224,7 @@ ReturnElems2* create_graph( } } - + // ReturnElems* returned; // returned = get_raw_data(nodes, num_atoms, num_undirected_edges, num_directed_edges, undirected_edges_list, directed_edges_list); @@ -235,7 +235,7 @@ ReturnElems2* create_graph( returned2->num_nodes = num_atoms; returned2->num_undirected_edges = num_undirected_edges; returned2->num_directed_edges = num_directed_edges; - + returned2->nodes = nodes; returned2->directed_edges_list = directed_edges_list; returned2->undirected_edges_list = undirected_edges_list; @@ -246,8 +246,8 @@ ReturnElems2* create_graph( // Converts all data into forms that can be digested in cython and used to create a graph python object ReturnElems* get_raw_data( - Node* nodes, - long num_nodes, + Node* nodes, + long num_nodes, long num_undirected_edges, long num_directed_edges, UndirectedEdge** undirected_edges_list, @@ -294,10 +294,10 @@ 
ReturnElems* get_raw_data( unravel_index += 1; } - } + } // Directed edges --------------- - // center unraveled, neighbor unraveled, image unraveled, distance unraveled for directed edges are all + // center unraveled, neighbor unraveled, image unraveled, distance unraveled for directed edges are all // just the inputs to the create graph function long* directed_undirected_edge_index_unraveled = malloc(sizeof(long) * num_directed_edges); for (long directed_i = 0; directed_i < num_directed_edges; directed_i++) { @@ -367,7 +367,7 @@ bool is_reversed_directed_edge(DirectedEdge* directed_edge1, DirectedEdge* direc } // If tmp or the reverse of tmp is found in undirected_edges, True is returned and the corresponding StructToUndirectedEdgeList pointer is placed -// into found_entry. Otherwise, False is returned. +// into found_entry. Otherwise, False is returned. // NOTE: does not edit the *tmp // Assumes *tmp bits have already been 0'd at padding within a struct bool find_in_undirected(NodeIndexPair* tmp, StructToUndirectedEdgeList** undirected_edges, StructToUndirectedEdgeList** found_entry) { @@ -408,7 +408,7 @@ void create_new_undirected_edges_entry(StructToUndirectedEdgeList** undirected_e new_entry->num_undirected_edges_in_group = 1; new_entry->undirected_edges_list = malloc(sizeof(UndirectedEdge*)); new_entry->undirected_edges_list[0] = new_undirected_edge; - + HASH_ADD(hh, *undirected_edges, key, sizeof(NodeIndexPair), new_entry); } @@ -452,7 +452,7 @@ void directed_to_undirected(DirectedEdge* directed, UndirectedEdge* undirected, void append_to_undirected_edges_list(UndirectedEdge** undirected_edges_list, UndirectedEdge* to_add, long* num_undirected_edges) { // No need to realloc for space since our original alloc should cover everything - + // Assign value to next available position undirected_edges_list[*num_undirected_edges] = to_add; *num_undirected_edges += 1; @@ -496,7 +496,7 @@ void add_neighbors_to_node(Node* node, long neighbor_index, DirectedEdge* 
direct entry->directed_edges_list = malloc(sizeof(DirectedEdge*)); entry->directed_edges_list[0] = directed_edge; entry->key = neighbor_index; - + entry->num_directed_edges_in_group = 1; HASH_ADD(hh, node->neighbors, key, sizeof(long), entry); diff --git a/chgnet/graph/fast_converter_libraries/uthash.h b/chgnet/graph/fast_converter_libraries/uthash.h index 070f7b44..68693bf3 100644 --- a/chgnet/graph/fast_converter_libraries/uthash.h +++ b/chgnet/graph/fast_converter_libraries/uthash.h @@ -1137,4 +1137,4 @@ typedef struct UT_hash_handle { unsigned hashv; /* result of hash-fcn(key) */ } UT_hash_handle; -#endif /* UTHASH_H */ \ No newline at end of file +#endif /* UTHASH_H */ diff --git a/setup.py b/setup.py index 6def2c5d..49d04efa 100644 --- a/setup.py +++ b/setup.py @@ -1,8 +1,10 @@ -from setuptools import Extension, setup +from __future__ import annotations + from Cython.Build import cythonize +from setuptools import Extension, setup -extensions = cythonize([Extension("chgnet.graph.cygraph", ["chgnet/graph/cygraph.pyx"])]) +extensions = cythonize( + [Extension("chgnet.graph.cygraph", ["chgnet/graph/cygraph.pyx"])] +) -setup( - ext_modules = extensions -) \ No newline at end of file +setup(ext_modules=extensions) diff --git a/tests/test_crystal_graph.py b/tests/test_crystal_graph.py index 28a24a66..fd2a1d4f 100644 --- a/tests/test_crystal_graph.py +++ b/tests/test_crystal_graph.py @@ -1,18 +1,19 @@ from __future__ import annotations +from time import perf_counter + import numpy as np from pymatgen.core import Structure from chgnet import ROOT from chgnet.graph import CrystalGraphConverter -from time import perf_counter structure = Structure.from_file(f"{ROOT}/examples/o-LiMnO2_unit.cif") converter = CrystalGraphConverter(atom_graph_cutoff=5, bond_graph_cutoff=3) def test_crystal_graph_legacy(): - start=perf_counter() + start = perf_counter() graph = converter(structure, graph_converter="legacy") print("Legacy test_crystal_graph time:", perf_counter() - start) @@ 
-35,12 +36,12 @@ def test_crystal_graph_legacy(): assert list(graph.undirected2directed.shape) == [192] assert list(graph.directed2undirected.shape) == [384] + def test_crystal_graph_fast(): - start=perf_counter() + start = perf_counter() graph = converter(structure, graph_converter="fast") print("Fast test_crystal_graph time:", perf_counter() - start) - assert graph.composition == "Li2 Mn2 O4" assert graph.atomic_number.tolist() == [3, 3, 25, 25, 8, 8, 8, 8] assert list(graph.atom_frac_coord.shape) == [8, 3] @@ -84,14 +85,14 @@ def test_crystal_graph_different_cutoff_legacy(): assert list(graph.undirected2directed.shape) == [312] assert list(graph.directed2undirected.shape) == [624] + def test_crystal_graph_different_cutoff_fast(): converter = CrystalGraphConverter(atom_graph_cutoff=5.5, bond_graph_cutoff=3.5) - start=perf_counter() + start = perf_counter() graph = converter(structure, graph_converter="fast") print("Fast test_crystal_graph_different_cutoff time:", perf_counter() - start) - assert list(graph.atom_frac_coord.shape) == [8, 3] assert list(graph.atom_graph.shape) == [624, 2] assert (graph.atom_graph[:, 0] == 5).sum().item() == 78 @@ -114,11 +115,10 @@ def test_crystal_graph_perturb_legacy(): structure_perturbed = structure.copy() structure_perturbed.perturb(distance=0.1) - start=perf_counter() + start = perf_counter() graph = converter(structure_perturbed, graph_converter="legacy") print("Legacy test_crystal_graph_perturb time:", perf_counter() - start) - assert list(graph.atom_frac_coord.shape) == [8, 3] assert list(graph.atom_graph.shape) == [410, 2] assert (graph.atom_graph[:, 0] == 3).sum().item() == 53 @@ -136,12 +136,13 @@ def test_crystal_graph_perturb_legacy(): assert list(graph.undirected2directed.shape) == [205] assert list(graph.directed2undirected.shape) == [410] + def test_crystal_graph_perturb_fast(): np.random.seed(0) structure_perturbed = structure.copy() structure_perturbed.perturb(distance=0.1) - start=perf_counter() + start = 
perf_counter() graph = converter(structure_perturbed, graph_converter="fast") print("Fast test_crystal_graph_perturb time:", perf_counter() - start) @@ -167,7 +168,7 @@ def test_crystal_graph_isotropic_strained_legacy(): structure_strained = structure.copy() structure_strained.apply_strain([0.1, 0.1, 0.1]) - start=perf_counter() + start = perf_counter() graph = converter(structure_strained, graph_converter="legacy") print("Legacy test_crystal_graph_isotropic_strained time:", perf_counter() - start) @@ -182,11 +183,12 @@ def test_crystal_graph_isotropic_strained_legacy(): assert list(graph.undirected2directed.shape) == [132] assert list(graph.directed2undirected.shape) == [264] + def test_crystal_graph_isotropic_strained_fast(): structure_strained = structure.copy() structure_strained.apply_strain([0.1, 0.1, 0.1]) - start=perf_counter() + start = perf_counter() graph = converter(structure_strained, graph_converter="fast") print("Fast test_crystal_graph_isotropic_strained time:", perf_counter() - start) @@ -202,14 +204,15 @@ def test_crystal_graph_isotropic_strained_fast(): assert list(graph.directed2undirected.shape) == [264] - def test_crystal_graph_anisotropic_strained_legacy(): structure_strained = structure.copy() structure_strained.apply_strain([0.2, -0.3, 0.5]) - start=perf_counter() + start = perf_counter() graph = converter(structure_strained, graph_converter="legacy") - print("Legacy test_crystal_graph_anisotropic_strained time:", perf_counter() - start) + print( + "Legacy test_crystal_graph_anisotropic_strained time:", perf_counter() - start + ) assert list(graph.atom_frac_coord.shape) == [8, 3] assert list(graph.atom_graph.shape) == [336, 2] @@ -222,15 +225,15 @@ def test_crystal_graph_anisotropic_strained_legacy(): assert list(graph.undirected2directed.shape) == [168] assert list(graph.directed2undirected.shape) == [336] + def test_crystal_graph_anisotropic_strained_fast(): structure_strained = structure.copy() structure_strained.apply_strain([0.2, -0.3, 
0.5]) - start=perf_counter() + start = perf_counter() graph = converter(structure_strained, graph_converter="fast") print("Fast test_crystal_graph_anisotropic_strained time:", perf_counter() - start) - assert list(graph.atom_frac_coord.shape) == [8, 3] assert list(graph.atom_graph.shape) == [336, 2] assert (graph.atom_graph[:, 0] == 3).sum().item() == 42 @@ -246,12 +249,11 @@ def test_crystal_graph_anisotropic_strained_fast(): def test_crystal_graph_supercell_legacy(): structure_supercell = structure.copy() structure_supercell.make_supercell([2, 3, 4]) - - start=perf_counter() + + start = perf_counter() graph = converter(structure_supercell, graph_converter="legacy") print("Legacy test_crystal_graph_supercell time:", perf_counter() - start) - assert graph.composition == "Li48 Mn48 O96" assert list(graph.atom_frac_coord.shape) == [192, 3] assert list(graph.atom_graph.shape) == [9216, 2] @@ -269,11 +271,12 @@ def test_crystal_graph_supercell_legacy(): assert list(graph.undirected2directed.shape) == [4608] assert list(graph.directed2undirected.shape) == [9216] + def test_crystal_graph_supercell_fast(): structure_supercell = structure.copy() structure_supercell.make_supercell([2, 3, 4]) - start=perf_counter() + start = perf_counter() graph = converter(structure_supercell, graph_converter="fast") print("Fast test_crystal_graph_supercell time:", perf_counter() - start) @@ -308,6 +311,7 @@ def test_crystal_graph_stability_legacy(): ) assert graph.atom_graph.shape[0] == graph.directed2undirected.shape[0] + def test_crystal_graph_stability_fast(): for _i in range(20): np.random.seed(0) From e86fb0794bdc329c0803379f8444f0d9f64a6904 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Tue, 27 Jun 2023 16:28:57 -0700 Subject: [PATCH 03/15] fix typos --- chgnet/graph/converter.py | 12 ++++++++---- chgnet/graph/fast_converter_libraries/create_graph.c | 4 ++-- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/chgnet/graph/converter.py b/chgnet/graph/converter.py 
index 6a443930..49b176eb 100644 --- a/chgnet/graph/converter.py +++ b/chgnet/graph/converter.py @@ -161,8 +161,10 @@ def _create_graph_legacy( Args: n_atoms (int): the number of atoms in the structure - center_index (np.ndarray): np array of indices of center atoms. Shape: (# of edges, ) - neighbor_index (np.ndarray): np array of indices of neighbor atoms. Shape: (# of edges, ) + center_index (np.ndarray): np array of indices of center atoms. + Shape: (# of edges, ) + neighbor_index (np.ndarray): np array of indices of neighbor atoms. + Shape: (# of edges, ) image (np.ndarray): np array of images for each edge. Shape: (# of edges, 3) distance (np.ndarray): np array of distances. Shape: (# of edges, ) @@ -189,8 +191,10 @@ def _create_graph_fast( Args: n_atoms (int): the number of atoms in the structure - center_index (np.ndarray): np array of indices of center atoms. Shape: (# of edges, ) - neighbor_index (np.ndarray): np array of indices of neighbor atoms. Shape: (# of edges, ) + center_index (np.ndarray): np array of indices of center atoms. + Shape: (# of edges, ) + neighbor_index (np.ndarray): np array of indices of neighbor atoms. + Shape: (# of edges, ) image (np.ndarray): np array of images for each edge. Shape: (# of edges, 3) distance (np.ndarray): np array of distances. Shape: (# of edges, ) diff --git a/chgnet/graph/fast_converter_libraries/create_graph.c b/chgnet/graph/fast_converter_libraries/create_graph.c index 53e925c0..df150c4f 100644 --- a/chgnet/graph/fast_converter_libraries/create_graph.c +++ b/chgnet/graph/fast_converter_libraries/create_graph.c @@ -423,7 +423,7 @@ void append_to_undirected_edges_tmp(UndirectedEdge* undirected, StructToUndirect long num_undirected_edges = this_undirected_edges_item->num_undirected_edges_in_group; // No need to worry about originally malloc'ing memory for this_undirected_edges_item->undirected_edges_list - // this is because, we first call create_new_undirected_edges_entry for all entires. 
This function already mallocs for us. + // this is because, we first call create_new_undirected_edges_entry for all entries. This function already mallocs for us. // Realloc the space to fit a new pointer to an undirected edge UndirectedEdge** new_list = realloc(this_undirected_edges_item->undirected_edges_list, sizeof(UndirectedEdge*) * (num_undirected_edges + 1)); @@ -461,7 +461,7 @@ void append_to_undirected_edges_list(UndirectedEdge** undirected_edges_list, Und void append_to_directed_edges_list(DirectedEdge** directed_edges_list, DirectedEdge* to_add, long* num_directed_edges) { // No need to realloc for space since our original alloc should cover everything - // Assign value to next availabe position + // Assign value to next available position directed_edges_list[*num_directed_edges] = to_add; *num_directed_edges += 1; } From 4fcb6b055b834b51947d40126820eef6af945602 Mon Sep 17 00:00:00 2001 From: BowenD-UCB <84425382+BowenD-UCB@users.noreply.github.com> Date: Tue, 27 Jun 2023 18:21:31 -0700 Subject: [PATCH 04/15] 1. added test cases for legacy and fast converter 2. rewrite arg names and docstrings 3. 
added arg in model initialization to select converter algorithm --- .gitignore | 1 + chgnet/graph/converter.py | 104 ++++++++++++++++++++++-------------- chgnet/graph/cygraph.pyx | 16 +++--- chgnet/model/model.py | 30 ++++++++--- pyproject.toml | 1 + tests/test_converter.py | 15 +++++- tests/test_crystal_graph.py | 59 +++++++++++++------- tests/test_md.py | 47 +++++++++++++--- tests/test_model.py | 14 +++++ tests/test_relaxation.py | 34 ++++++++++-- 10 files changed, 237 insertions(+), 84 deletions(-) diff --git a/.gitignore b/.gitignore index 5adaf13b..aea32c2b 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ __pycache__/ *.py[cod] # C extensions +cygraph.c *.so # Distribution / packaging diff --git a/chgnet/graph/converter.py b/chgnet/graph/converter.py index 49b176eb..056dab70 100644 --- a/chgnet/graph/converter.py +++ b/chgnet/graph/converter.py @@ -15,10 +15,6 @@ datatype = torch.float32 -try: - from chgnet.graph.cygraph import make_graph -except ImportError: - print("Error importing fast graph conversion (cygraph). Reverting to legacy.") class CrystalGraphConverter(nn.Module): @@ -28,15 +24,29 @@ class CrystalGraphConverter(nn.Module): """ def __init__( - self, atom_graph_cutoff: float = 5, bond_graph_cutoff: float = 3 + self, + atom_graph_cutoff: float = 5, + bond_graph_cutoff: float = 3, + algorithm: Literal["legacy", "fast"] = "fast", + verbose: bool = False, ) -> None: """Initialize the Crystal Graph Converter. Args: atom_graph_cutoff (float): cutoff radius to search for neighboring atom in - atom_graph. Default = 5 - bond_graph_cutoff (float): bond length threshold to include bond in bond_graph + atom_graph. + Default = 5 + bond_graph_cutoff (float): bond length threshold to include bond in + bond_graph Default = 3 + algorithm ('legacy' | 'fast'): algorithm to use for converting graphs. 
+ 'legacy': python implementation of graph creation + 'fast': C implementation of graph creation, this is faster, + but will need the cygraph.c file correctly compiled from pip install + Default = 'fast' + verbose (bool): whether to print the CrystalGraphConverter + initialization message + Default = False """ super().__init__() self.atom_graph_cutoff = atom_graph_cutoff @@ -45,13 +55,33 @@ def __init__( else: self.bond_graph_cutoff = bond_graph_cutoff + # Set graph conversion algorithm + if algorithm == "fast": + try: + from chgnet.graph.cygraph import make_graph + + self._make_graph = make_graph + self.create_graph = self._create_graph_fast + self.algorithm = "fast" + except ImportError: + self.create_graph = self._create_graph_legacy + self.algorithm = "legacy" + elif algorithm == "legacy": + self.create_graph = self._create_graph_legacy + self.algorithm = "legacy" + + if verbose: + print( + f"CrystalGraphConverter-{self.algorithm} initialized with " + f"atom_cutoff={atom_graph_cutoff}, bond_cutoff={bond_graph_cutoff}" + ) + def forward( self, structure: Structure, graph_id=None, mp_id=None, on_isolated_atoms: Literal["ignore", "warn", "error"] = "error", - graph_converter: Literal["legacy", "fast"] = "fast", ) -> CrystalGraph: """Convert a structure, return a CrystalGraph. @@ -62,9 +92,8 @@ def forward( mp_id (str): Materials Project id of this structure Default = None on_isolated_atoms ('ignore' | 'warn' | 'error'): how to handle Structures - with isolated atoms. Default = 'error' - graph_converter ('legacy' | 'fast'): graph converter to use when converting. - default = 'fast' + with isolated atoms. 
+ Default = 'error' Return: CrystalGraph that is ready to use by CHGNet @@ -82,22 +111,9 @@ def forward( center_index, neighbor_index, image, distance = self.get_neighbors(structure) # Make Graph - if graph_converter == "fast": - try: - graph = self._create_graph_fast( - n_atoms, center_index, neighbor_index, image, distance - ) - except Exception: - print("Failed to retrieve fast graph converter. Reverting to legacy.") - graph = self._create_graph_legacy( - n_atoms, center_index, neighbor_index, image, distance - ) - elif graph_converter == "legacy": - graph = self._create_graph_legacy( - n_atoms, center_index, neighbor_index, image, distance - ) - else: - raise ValueError(f"No graph_converter named {graph_converter}") + graph = self.create_graph( + n_atoms, center_index, neighbor_index, image, distance + ) # Atom Graph atom_graph, directed2undirected = graph.adjacency_list() @@ -148,8 +164,8 @@ def forward( bond_graph_cutoff=self.bond_graph_cutoff, ) + @staticmethod def _create_graph_legacy( - self, n_atoms: int, center_index: np.ndarray, neighbor_index: np.ndarray, @@ -157,16 +173,18 @@ def _create_graph_legacy( distance: np.ndarray, ) -> Graph: """Given structure information, create a Graph structure to be used to - create Crystal_Graph. + create Crystal_Graph using pure python implementation. Args: n_atoms (int): the number of atoms in the structure center_index (np.ndarray): np array of indices of center atoms. - Shape: (# of edges, ) + [num_undirected_bonds] neighbor_index (np.ndarray): np array of indices of neighbor atoms. - Shape: (# of edges, ) - image (np.ndarray): np array of images for each edge. Shape: (# of edges, 3) - distance (np.ndarray): np array of distances. Shape: (# of edges, ) + [num_undirected_bonds] + image (np.ndarray): np array of images for each edge. + [num_undirected_bonds, 3] + distance (np.ndarray): np array of distances. 
+ [num_undirected_bonds] Return: Graph data structure used to create Crystal_Graph object @@ -186,17 +204,21 @@ def _create_graph_fast( distance: np.ndarray, ) -> Graph: """Given structure information, create a Graph structure to be used to - create Crystal_Graph. NOTE: this is the fast version of _create_graph_legacy optimized - in c (~3x speedup). + create Crystal_Graph using C implementation. + + NOTE: this is the fast version of _create_graph_legacy optimized + in c (~3x speedup). Args: n_atoms (int): the number of atoms in the structure center_index (np.ndarray): np array of indices of center atoms. - Shape: (# of edges, ) + [num_undirected_bonds] neighbor_index (np.ndarray): np array of indices of neighbor atoms. - Shape: (# of edges, ) - image (np.ndarray): np array of images for each edge. Shape: (# of edges, 3) - distance (np.ndarray): np array of distances. Shape: (# of edges, ) + [num_undirected_bonds] + image (np.ndarray): np array of images for each edge. + [num_undirected_bonds, 3] + distance (np.ndarray): np array of distances. 
+ [num_undirected_bonds] Return: Graph data structure used to create Crystal_Graph object @@ -211,9 +233,10 @@ def _create_graph_fast( directed_edges_list, undirected_edges_list, undirected_edges, - ) = make_graph( + ) = self._make_graph( center_index, len(center_index), neighbor_index, image, distance, n_atoms ) + graph = Graph(nodes=nodes) graph.directed_edges_list = directed_edges_list graph.undirected_edges_list = undirected_edges_list @@ -242,6 +265,7 @@ def as_dict(self) -> dict[str, float]: return { "atom_graph_cutoff": self.atom_graph_cutoff, "bond_graph_cutoff": self.bond_graph_cutoff, + "algorithm": self.algorithm, } @classmethod diff --git a/chgnet/graph/cygraph.pyx b/chgnet/graph/cygraph.pyx index 75d021e1..9b9d8263 100644 --- a/chgnet/graph/cygraph.pyx +++ b/chgnet/graph/cygraph.pyx @@ -3,16 +3,16 @@ # cython: nonecheck=False # cython: boundscheck=False # cython: wraparound=False -# cython: cdivision=False +# cython: cdivision=True # cython: profile=False # distutils: language = c -import chgnet.graph.graph -from cython.operator import dereference import numpy as np -from libc cimport time -from libc.stdio cimport printf -from libc.stdlib cimport free +from libc.stdlib cimport + +free + +import chgnet.graph.graph cdef extern from 'fast_converter_libraries/create_graph.c': ctypedef struct Node: @@ -95,7 +95,7 @@ def make_graph( # Handling nodes + directed edges for i in range(returned[0].num_nodes): - this_node = dereference(returned).nodes[i] + this_node = returned[0].nodes[i] this_py_node = chg_Node(index=i) this_py_node.neighbors = {} @@ -103,7 +103,7 @@ def make_graph( # Iterate through all neighbors and populate our py_node.neighbors dict for j in range(this_node.num_neighbors): - this_entry = dereference(node_neighbors[j]) + this_entry = node_neighbors[j][0] directed_edges = [] for k in range(this_entry.num_directed_edges_in_group): diff --git a/chgnet/model/model.py b/chgnet/model/model.py index 3bdda512..de47b9f7 100644 --- 
a/chgnet/model/model.py +++ b/chgnet/model/model.py @@ -52,8 +52,9 @@ def __init__( mlp_first: bool = True, is_intensive: bool = True, non_linearity: Literal["silu", "relu", "tanh", "gelu"] = "silu", - atom_graph_cutoff: int = 5, - bond_graph_cutoff: int = 3, + atom_graph_cutoff: float = 5, + bond_graph_cutoff: float = 3, + graph_converter_algorithm: Literal["legacy", "fast"] = "fast", cutoff_coeff: int = 5, learnable_rbf: bool = True, **kwargs, @@ -121,6 +122,12 @@ def __init__( bond_graph_cutoff (float): cutoff radius (A) in creating bond_graph, this need to be consistent with value in training dataloader Default = 3 + graph_converter_algorithm ('legacy' | 'fast'): algorithm to use + for converting pymatgen.core.Structure to CrystalGraph. + 'legacy': python implementation of graph creation + 'fast': C implementation of graph creation, this is faster, + but will need the cygraph.c file correctly compiled from pip install + default = 'fast' cutoff_coeff (float): cutoff strength used in graph smooth cutoff function. the smaller this coeff is, the smoother the basis is Default = 5 @@ -159,7 +166,10 @@ def __init__( # Define Crystal Graph Converter self.graph_converter = CrystalGraphConverter( - atom_graph_cutoff=atom_graph_cutoff, bond_graph_cutoff=bond_graph_cutoff + atom_graph_cutoff=atom_graph_cutoff, + bond_graph_cutoff=bond_graph_cutoff, + algorithm=graph_converter_algorithm, + verbose=kwargs.pop("converter_verbose", False), ) # Define embedding layers @@ -512,9 +522,17 @@ def predict_structure( s: stress of structure [3 * batch_size, 3] in GPa m: magnetic moments of sites [num_batch_atoms, 3] in Bohr magneton mu_B """ - if not isinstance(structure, (Structure, Sequence)): - raise ValueError( - f"structure should be a Structure or list of structures, got {type(structure)}" + assert ( + self.graph_converter is not None + ), "self.algorithm needs to be initialized first!" 
+ if type(structure) == Structure: + graph = self.graph_converter(structure) + return self.predict_graph( + graph, + task=task, + return_atom_feas=return_atom_feas, + return_crystal_feas=return_crystal_feas, + batch_size=batch_size, ) if self.graph_converter is None: raise ValueError("graph_converter cannot be None!") diff --git a/pyproject.toml b/pyproject.toml index d9d68855..9b4523f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "numpy>=1.21.6", "pymatgen>=2022.4.19", "torch>=1.11.0", + "cython>=0.29.26" ] classifiers = [ "Intended Audience :: Science/Research", diff --git a/tests/test_converter.py b/tests/test_converter.py index 713d92df..66a5c0c6 100644 --- a/tests/test_converter.py +++ b/tests/test_converter.py @@ -15,7 +15,7 @@ @pytest.mark.parametrize( "atom_graph_cutoff, bond_graph_cutoff", [(5, 3), (5, None), (4, 2)] ) -def test_crystal_graph_converter_init(atom_graph_cutoff, bond_graph_cutoff): +def test_crystal_graph_converter_cutoff(atom_graph_cutoff, bond_graph_cutoff): converter = CrystalGraphConverter( atom_graph_cutoff=atom_graph_cutoff, bond_graph_cutoff=bond_graph_cutoff ) @@ -23,6 +23,14 @@ def test_crystal_graph_converter_init(atom_graph_cutoff, bond_graph_cutoff): assert converter.bond_graph_cutoff == bond_graph_cutoff or atom_graph_cutoff +@pytest.mark.parametrize("algorithm", ["legacy", "fast"]) +def test_crystal_graph_converter_algorithm(algorithm): + converter = CrystalGraphConverter( + atom_graph_cutoff=5, bond_graph_cutoff=3, algorithm=algorithm + ) + assert converter.algorithm == algorithm + + @pytest.mark.parametrize("on_isolated_atoms", ["ignore", "warn", "error"]) def test_crystal_graph_converter_forward( on_isolated_atoms, capsys: CaptureFixture[str] @@ -98,4 +106,7 @@ def test_crystal_graph_converter_get_neighbors(): def test_crystal_graph_converter_as_dict_round_trip(): expected = {"atom_graph_cutoff": 5, "bond_graph_cutoff": 3} converter = CrystalGraphConverter(**expected) - assert 
CrystalGraphConverter.from_dict(converter.as_dict()).as_dict() == expected + converter2 = CrystalGraphConverter.from_dict(converter.as_dict()) + assert converter.atom_graph_cutoff == converter2.atom_graph_cutoff + assert converter.bond_graph_cutoff == converter2.bond_graph_cutoff + assert converter.algorithm == converter2.algorithm diff --git a/tests/test_crystal_graph.py b/tests/test_crystal_graph.py index fd2a1d4f..3f01b885 100644 --- a/tests/test_crystal_graph.py +++ b/tests/test_crystal_graph.py @@ -10,11 +10,18 @@ structure = Structure.from_file(f"{ROOT}/examples/o-LiMnO2_unit.cif") converter = CrystalGraphConverter(atom_graph_cutoff=5, bond_graph_cutoff=3) +converter_legacy = CrystalGraphConverter( + atom_graph_cutoff=5, bond_graph_cutoff=3, algorithm="legacy", verbose=True +) +converter_fast = CrystalGraphConverter( + atom_graph_cutoff=5, bond_graph_cutoff=3, algorithm="fast", verbose=True +) def test_crystal_graph_legacy(): + assert converter_legacy.algorithm == "legacy" start = perf_counter() - graph = converter(structure, graph_converter="legacy") + graph = converter_legacy(structure) print("Legacy test_crystal_graph time:", perf_counter() - start) assert graph.composition == "Li2 Mn2 O4" @@ -38,9 +45,10 @@ def test_crystal_graph_legacy(): def test_crystal_graph_fast(): + assert converter_fast.algorithm == "fast" start = perf_counter() - graph = converter(structure, graph_converter="fast") - print("Fast test_crystal_graph time:", perf_counter() - start) + graph = converter_fast(structure) + print("Fasttest_crystal_graph time:", perf_counter() - start) assert graph.composition == "Li2 Mn2 O4" assert graph.atomic_number.tolist() == [3, 3, 25, 25, 8, 8, 8, 8] @@ -63,10 +71,13 @@ def test_crystal_graph_fast(): def test_crystal_graph_different_cutoff_legacy(): - converter = CrystalGraphConverter(atom_graph_cutoff=5.5, bond_graph_cutoff=3.5) + converter_legacy_2 = CrystalGraphConverter( + atom_graph_cutoff=5.5, bond_graph_cutoff=3.5, algorithm="legacy" + ) + 
assert converter_legacy_2.algorithm == "legacy" start = perf_counter() - graph = converter(structure, graph_converter="legacy") + graph = converter_legacy_2(structure) print("Legacy test_crystal_graph_different_cutoff time:", perf_counter() - start) assert list(graph.atom_frac_coord.shape) == [8, 3] @@ -87,10 +98,13 @@ def test_crystal_graph_different_cutoff_legacy(): def test_crystal_graph_different_cutoff_fast(): - converter = CrystalGraphConverter(atom_graph_cutoff=5.5, bond_graph_cutoff=3.5) + converter_fast_2 = CrystalGraphConverter( + atom_graph_cutoff=5.5, bond_graph_cutoff=3.5, algorithm="fast" + ) + assert converter_fast_2.algorithm == "fast" start = perf_counter() - graph = converter(structure, graph_converter="fast") + graph = converter_fast_2(structure) print("Fast test_crystal_graph_different_cutoff time:", perf_counter() - start) assert list(graph.atom_frac_coord.shape) == [8, 3] @@ -116,7 +130,7 @@ def test_crystal_graph_perturb_legacy(): structure_perturbed.perturb(distance=0.1) start = perf_counter() - graph = converter(structure_perturbed, graph_converter="legacy") + graph = converter_legacy(structure_perturbed) print("Legacy test_crystal_graph_perturb time:", perf_counter() - start) assert list(graph.atom_frac_coord.shape) == [8, 3] @@ -126,7 +140,6 @@ def test_crystal_graph_perturb_legacy(): assert (graph.atom_graph[:, 1] == 6).sum().item() == 50 assert list(graph.bond_graph.shape) == [688, 5] - print(graph.bond_graph[120, :]) assert (graph.bond_graph[:, 0] == 1).sum().item() == 90 assert (graph.bond_graph[:, 1] == 36).sum().item() == 17 assert (graph.bond_graph[:, 3] == 36).sum().item() == 17 @@ -143,7 +156,7 @@ def test_crystal_graph_perturb_fast(): structure_perturbed.perturb(distance=0.1) start = perf_counter() - graph = converter(structure_perturbed, graph_converter="fast") + graph = converter_fast(structure_perturbed) print("Fast test_crystal_graph_perturb time:", perf_counter() - start) assert list(graph.atom_frac_coord.shape) == [8, 3] 
@@ -153,7 +166,6 @@ def test_crystal_graph_perturb_fast(): assert (graph.atom_graph[:, 1] == 6).sum().item() == 50 assert list(graph.bond_graph.shape) == [688, 5] - print(graph.bond_graph[120, :]) assert (graph.bond_graph[:, 0] == 1).sum().item() == 90 assert (graph.bond_graph[:, 1] == 36).sum().item() == 17 assert (graph.bond_graph[:, 3] == 36).sum().item() == 17 @@ -169,7 +181,7 @@ def test_crystal_graph_isotropic_strained_legacy(): structure_strained.apply_strain([0.1, 0.1, 0.1]) start = perf_counter() - graph = converter(structure_strained, graph_converter="legacy") + graph = converter_legacy(structure_strained) print("Legacy test_crystal_graph_isotropic_strained time:", perf_counter() - start) assert list(graph.atom_frac_coord.shape) == [8, 3] @@ -189,7 +201,7 @@ def test_crystal_graph_isotropic_strained_fast(): structure_strained.apply_strain([0.1, 0.1, 0.1]) start = perf_counter() - graph = converter(structure_strained, graph_converter="fast") + graph = converter_fast(structure_strained) print("Fast test_crystal_graph_isotropic_strained time:", perf_counter() - start) assert list(graph.atom_frac_coord.shape) == [8, 3] @@ -209,7 +221,7 @@ def test_crystal_graph_anisotropic_strained_legacy(): structure_strained.apply_strain([0.2, -0.3, 0.5]) start = perf_counter() - graph = converter(structure_strained, graph_converter="legacy") + graph = converter_legacy(structure_strained) print( "Legacy test_crystal_graph_anisotropic_strained time:", perf_counter() - start ) @@ -231,7 +243,7 @@ def test_crystal_graph_anisotropic_strained_fast(): structure_strained.apply_strain([0.2, -0.3, 0.5]) start = perf_counter() - graph = converter(structure_strained, graph_converter="fast") + graph = converter_fast(structure_strained) print("Fast test_crystal_graph_anisotropic_strained time:", perf_counter() - start) assert list(graph.atom_frac_coord.shape) == [8, 3] @@ -251,7 +263,7 @@ def test_crystal_graph_supercell_legacy(): structure_supercell.make_supercell([2, 3, 4]) start = 
perf_counter() - graph = converter(structure_supercell, graph_converter="legacy") + graph = converter_legacy(structure_supercell) print("Legacy test_crystal_graph_supercell time:", perf_counter() - start) assert graph.composition == "Li48 Mn48 O96" @@ -277,7 +289,7 @@ def test_crystal_graph_supercell_fast(): structure_supercell.make_supercell([2, 3, 4]) start = perf_counter() - graph = converter(structure_supercell, graph_converter="fast") + graph = converter_fast(structure_supercell) print("Fast test_crystal_graph_supercell time:", perf_counter() - start) assert graph.composition == "Li48 Mn48 O96" @@ -299,28 +311,37 @@ def test_crystal_graph_supercell_fast(): def test_crystal_graph_stability_legacy(): + total_time = 0 for _i in range(20): np.random.seed(0) structure_perturbed = structure.copy() structure_perturbed.make_supercell([2, 2, 2]) structure_perturbed.perturb(distance=0.5) - graph = converter(structure_perturbed, graph_converter="legacy") + start = perf_counter() + graph = converter_legacy(structure_perturbed) + total_time += perf_counter() - start assert ( graph.directed2undirected.shape[0] == 2 * graph.undirected2directed.shape[0] ) assert graph.atom_graph.shape[0] == graph.directed2undirected.shape[0] + print("Legacy test_crystal_graph_stability time:", total_time) + def test_crystal_graph_stability_fast(): + total_time = 0 for _i in range(20): np.random.seed(0) structure_perturbed = structure.copy() structure_perturbed.make_supercell([2, 2, 2]) structure_perturbed.perturb(distance=0.5) - graph = converter(structure_perturbed, graph_converter="fast") + start = perf_counter() + graph = converter_fast(structure_perturbed) + total_time += perf_counter() - start assert ( graph.directed2undirected.shape[0] == 2 * graph.undirected2directed.shape[0] ) assert graph.atom_graph.shape[0] == graph.directed2undirected.shape[0] + print("Fast test_crystal_graph_stability time:", total_time) diff --git a/tests/test_md.py b/tests/test_md.py index 168b112e..3afdced1 
100644 --- a/tests/test_md.py +++ b/tests/test_md.py @@ -10,6 +10,7 @@ from pytest import MonkeyPatch, approx from chgnet import ROOT +from chgnet.graph import CrystalGraphConverter from chgnet.model import StructOptimizer from chgnet.model.dynamics import CHGNetCalculator, EquationOfState, MolecularDynamics from chgnet.model.model import CHGNet @@ -25,17 +26,13 @@ def test_eos(): eos = EquationOfState() eos.fit(atoms=structure) - print(eos.get_bulk_mudulus()) - print(eos.get_bulk_mudulus(unit="GPa")) - print(eos.get_compressibility()) - print(eos.get_compressibility(unit="GPa^-1")) assert eos.get_bulk_mudulus() == approx(0.6621170816, rel=1e-5) assert eos.get_bulk_mudulus(unit="GPa") == approx(106.08285172, rel=1e-5) assert eos.get_compressibility() == approx(1.510306904, rel=1e-5) assert eos.get_compressibility(unit="GPa^-1") == approx(0.009426594, rel=1e-5) -def test_md_nvt(tmp_path: Path, monkeypatch: MonkeyPatch): +def test_md_nvt_legacy_converter(tmp_path: Path, monkeypatch: MonkeyPatch): # cd into the temporary directory monkeypatch.chdir(tmp_path) @@ -65,6 +62,44 @@ def test_md_nvt(tmp_path: Path, monkeypatch: MonkeyPatch): ) +def test_md_nvt_fast_converter(tmp_path: Path, monkeypatch: MonkeyPatch): + # cd into the temporary directory + monkeypatch.chdir(tmp_path) + + chgnet_fast = CHGNet.load() + converter_fast = CrystalGraphConverter( + atom_graph_cutoff=5, bond_graph_cutoff=3, algorithm="fast" + ) + assert converter_fast.algorithm == "fast" + + chgnet_fast.graph_converter = converter_fast + + md = MolecularDynamics( + atoms=structure, + model=chgnet_fast, + ensemble="nvt", + temperature=1000, # in k + timestep=2, # in fs + trajectory="md_out.traj", + logfile="md_out.log", + loginterval=100, + use_device="cpu", + ) + md.run(10) + + assert isinstance(md.atoms, Atoms) + assert isinstance(md.atoms.calc, CHGNetCalculator) + assert isinstance(md.dyn, NVTBerendsen) + assert os.path.isfile("md_out.traj") + assert os.path.isfile("md_out.log") + with 
open("md_out.log") as log_file: + logs = log_file.read() + assert logs == ( + "Time[ps] Etot[eV] Epot[eV] Ekin[eV] T[K]\n" + "0.0000 -58.9727 -58.9727 0.0000 0.0\n" + ) + + def test_md_npt_inhomogeneous_berendsen(tmp_path: Path, monkeypatch: MonkeyPatch): # cd into the temporary directory monkeypatch.chdir(tmp_path) @@ -75,6 +110,7 @@ def test_md_npt_inhomogeneous_berendsen(tmp_path: Path, monkeypatch: MonkeyPatch ensemble="npt", temperature=1000, # in k timestep=2, # in fs + compressibility_au=1.5103069, trajectory="md_out.traj", logfile="md_out.log", loginterval=100, @@ -85,7 +121,6 @@ def test_md_npt_inhomogeneous_berendsen(tmp_path: Path, monkeypatch: MonkeyPatch assert isinstance(md.atoms.calc, CHGNetCalculator) assert isinstance(md.dyn, Inhomogeneous_NPTBerendsen) assert md.dyn.pressure == approx(6.324209e-07, rel=1e-5) - assert md.dyn.compressibility == approx(1.5103069, rel=1e-5) assert os.path.isfile("md_out.traj") assert os.path.isfile("md_out.log") with open("md_out.log") as log_file: diff --git a/tests/test_model.py b/tests/test_model.py index 2d70379a..f6ae10f0 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -37,6 +37,7 @@ def test_model( num_angular=num_angular, n_conv=n_conv, composition_model=composition_model, + converter_verbose="True", ) out = model([graph]) assert list(out) == ["atoms_per_graph", "e"] @@ -51,6 +52,19 @@ def test_predict_structure() -> None: assert sorted(out) == ["e", "f", "m", "s"] assert out["e"] == pytest.approx(-7.37159, abs=1e-4) assert len(out["f"]) == len(structure) + force = np.array( + [ + [4.4703484e-08, -4.2840838e-08, 2.4071064e-02], + [-4.4703484e-08, -1.4551915e-08, -2.4071217e-02], + [-1.7881393e-07, 1.0244548e-08, 2.5402933e-02], + [5.9604645e-08, -2.3283064e-08, -2.5402665e-02], + [-1.1920929e-07, 6.6356733e-08, -2.1660209e-02], + [2.3543835e-06, -8.0077443e-06, 9.5508099e-03], + [-2.2947788e-06, 7.9898164e-06, -9.5513463e-03], + [-5.9604645e-08, -0.0000000e00, 2.1660626e-02], + ] + ) + assert 
out["f"] == pytest.approx(force, abs=1e-4) assert len(out["m"]) == len(structure) assert out["m"] == pytest.approx( [0.00521, 0.00521, 3.85728, 3.85729, 0.02538, 0.03706, 0.03706, 0.02538], diff --git a/tests/test_relaxation.py b/tests/test_relaxation.py index 4f51407d..55140b7f 100644 --- a/tests/test_relaxation.py +++ b/tests/test_relaxation.py @@ -5,15 +5,15 @@ from pytest import approx, mark, param from chgnet import ROOT -from chgnet.model import StructOptimizer +from chgnet.graph import CrystalGraphConverter +from chgnet.model import CHGNet, StructOptimizer relaxer = StructOptimizer() structure = Structure.from_file(f"{ROOT}/examples/o-LiMnO2_unit.cif") -def test_relaxation(): +def test_relaxation_legacy(): result = relaxer.relax(structure, verbose=True) - assert list(result) == ["final_structure", "trajectory"] traj = result["trajectory"] @@ -31,7 +31,35 @@ def test_relaxation(): # make sure final structure is more relaxed than initial one assert traj.energies[0] > traj.energies[-1] + assert traj.energies[-1] == approx(-58.972927) + + +def test_relaxation_fast_converter(): + chgnet = CHGNet.load() + converter_fast = CrystalGraphConverter( + atom_graph_cutoff=5, bond_graph_cutoff=3, algorithm="fast" + ) + assert converter_fast.algorithm == "fast" + + chgnet.graph_converter = converter_fast + result = relaxer.relax(structure, verbose=True) + assert list(result) == ["final_structure", "trajectory"] + + traj = result["trajectory"] + # make sure trajectory has expected attributes + assert list(traj.__dict__) == [ + "atoms", + "energies", + "forces", + "stresses", + "magmoms", + "atom_positions", + "cells", + ] + assert len(traj) == 4 + # make sure final structure is more relaxed than initial one + assert traj.energies[0] > traj.energies[-1] assert traj.energies[-1] == approx(-58.972927) From 1573a5f6feac0fe92e33143247f34a3fbc26b99d Mon Sep 17 00:00:00 2001 From: BowenD-UCB <84425382+BowenD-UCB@users.noreply.github.com> Date: Wed, 28 Jun 2023 13:10:03 -0700 
Subject: [PATCH 05/15] fixed typo --- chgnet/graph/cygraph.pyx | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/chgnet/graph/cygraph.pyx b/chgnet/graph/cygraph.pyx index 9b9d8263..d24ec104 100644 --- a/chgnet/graph/cygraph.pyx +++ b/chgnet/graph/cygraph.pyx @@ -8,9 +8,7 @@ # distutils: language = c import numpy as np -from libc.stdlib cimport - -free +from libc.stdlib cimport free import chgnet.graph.graph From f4f7619b89c9cdcbf4841210e08128e31dbb6fa4 Mon Sep 17 00:00:00 2001 From: BowenD-UCB <84425382+BowenD-UCB@users.noreply.github.com> Date: Wed, 28 Jun 2023 13:53:48 -0700 Subject: [PATCH 06/15] fixed linting and test_relaxation.py --- chgnet/graph/converter.py | 1 - chgnet/model/model.py | 2 +- tests/test_crystal_graph.py | 1 - tests/test_relaxation.py | 12 ++++++++++-- 4 files changed, 11 insertions(+), 5 deletions(-) diff --git a/chgnet/graph/converter.py b/chgnet/graph/converter.py index 056dab70..0e690fca 100644 --- a/chgnet/graph/converter.py +++ b/chgnet/graph/converter.py @@ -16,7 +16,6 @@ datatype = torch.float32 - class CrystalGraphConverter(nn.Module): """Convert a pymatgen.core.Structure to a CrystalGraph The CrystalGraph dataclass stores essential field to make sure that diff --git a/chgnet/model/model.py b/chgnet/model/model.py index de47b9f7..d837d058 100644 --- a/chgnet/model/model.py +++ b/chgnet/model/model.py @@ -569,7 +569,7 @@ def predict_graph( only available if self.mlp_first is False Default = False batch_size (int): batch_size for predict structures. - Default = 100. 
+ Default = 100 Returns: prediction (dict): containing the fields: diff --git a/tests/test_crystal_graph.py b/tests/test_crystal_graph.py index 3f01b885..d6a29059 100644 --- a/tests/test_crystal_graph.py +++ b/tests/test_crystal_graph.py @@ -328,7 +328,6 @@ def test_crystal_graph_stability_legacy(): print("Legacy test_crystal_graph_stability time:", total_time) - def test_crystal_graph_stability_fast(): total_time = 0 for _i in range(20): diff --git a/tests/test_relaxation.py b/tests/test_relaxation.py index 55140b7f..dd3e1226 100644 --- a/tests/test_relaxation.py +++ b/tests/test_relaxation.py @@ -8,11 +8,18 @@ from chgnet.graph import CrystalGraphConverter from chgnet.model import CHGNet, StructOptimizer -relaxer = StructOptimizer() structure = Structure.from_file(f"{ROOT}/examples/o-LiMnO2_unit.cif") -def test_relaxation_legacy(): +def test_relaxation_legacy_converter(): + chgnet = CHGNet.load() + converter_legacy = CrystalGraphConverter( + atom_graph_cutoff=5, bond_graph_cutoff=3, algorithm="legacy" + ) + assert converter_legacy.algorithm == "legacy" + + chgnet.graph_converter = converter_legacy + relaxer = StructOptimizer(model=chgnet) result = relaxer.relax(structure, verbose=True) assert list(result) == ["final_structure", "trajectory"] @@ -42,6 +49,7 @@ def test_relaxation_fast_converter(): assert converter_fast.algorithm == "fast" chgnet.graph_converter = converter_fast + relaxer = StructOptimizer(model=chgnet) result = relaxer.relax(structure, verbose=True) assert list(result) == ["final_structure", "trajectory"] From fa06afc228013735123cbcc932fa70c6e106a8f1 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 28 Jun 2023 14:11:57 -0700 Subject: [PATCH 07/15] fix tests --- chgnet/model/model.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/chgnet/model/model.py b/chgnet/model/model.py index d837d058..4c4bf5b6 100644 --- a/chgnet/model/model.py +++ b/chgnet/model/model.py @@ -522,32 +522,19 @@ def 
predict_structure( s: stress of structure [3 * batch_size, 3] in GPa m: magnetic moments of sites [num_batch_atoms, 3] in Bohr magneton mu_B """ - assert ( - self.graph_converter is not None - ), "self.algorithm needs to be initialized first!" - if type(structure) == Structure: - graph = self.graph_converter(structure) - return self.predict_graph( - graph, - task=task, - return_atom_feas=return_atom_feas, - return_crystal_feas=return_crystal_feas, - batch_size=batch_size, - ) if self.graph_converter is None: raise ValueError("graph_converter cannot be None!") structures = [structure] if isinstance(structure, Structure) else structure graphs = [self.graph_converter(struct) for struct in structures] - predictions = self.predict_graph( + return self.predict_graph( graphs, task=task, return_atom_feas=return_atom_feas, return_crystal_feas=return_crystal_feas, batch_size=batch_size, ) - return predictions[0] if len(structures) == 1 else predictions def predict_graph( self, From b89fd48d2fe34b8f619694dbf50ae150a619e3c3 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 28 Jun 2023 14:17:11 -0700 Subject: [PATCH 08/15] use pytest.mark.parametrize to keep relaxation test DRY --- tests/test_relaxation.py | 44 ++++++++-------------------------------- 1 file changed, 9 insertions(+), 35 deletions(-) diff --git a/tests/test_relaxation.py b/tests/test_relaxation.py index dd3e1226..2e5a7b14 100644 --- a/tests/test_relaxation.py +++ b/tests/test_relaxation.py @@ -1,5 +1,8 @@ from __future__ import annotations +from typing import Literal + +import pytest import torch from pymatgen.core import Structure from pytest import approx, mark, param @@ -11,44 +14,15 @@ structure = Structure.from_file(f"{ROOT}/examples/o-LiMnO2_unit.cif") -def test_relaxation_legacy_converter(): - chgnet = CHGNet.load() - converter_legacy = CrystalGraphConverter( - atom_graph_cutoff=5, bond_graph_cutoff=3, algorithm="legacy" - ) - assert converter_legacy.algorithm == "legacy" - - 
chgnet.graph_converter = converter_legacy - relaxer = StructOptimizer(model=chgnet) - result = relaxer.relax(structure, verbose=True) - assert list(result) == ["final_structure", "trajectory"] - - traj = result["trajectory"] - # make sure trajectory has expected attributes - assert list(traj.__dict__) == [ - "atoms", - "energies", - "forces", - "stresses", - "magmoms", - "atom_positions", - "cells", - ] - assert len(traj) == 4 - - # make sure final structure is more relaxed than initial one - assert traj.energies[0] > traj.energies[-1] - assert traj.energies[-1] == approx(-58.972927) - - -def test_relaxation_fast_converter(): +@pytest.mark.parametrize("algorithm", ["legacy", "fast"]) +def test_relaxation(algorithm: Literal["legacy", "fast"]): chgnet = CHGNet.load() - converter_fast = CrystalGraphConverter( - atom_graph_cutoff=5, bond_graph_cutoff=3, algorithm="fast" + converter = CrystalGraphConverter( + atom_graph_cutoff=5, bond_graph_cutoff=3, algorithm=algorithm ) - assert converter_fast.algorithm == "fast" + assert converter.algorithm == algorithm - chgnet.graph_converter = converter_fast + chgnet.graph_converter = converter relaxer = StructOptimizer(model=chgnet) result = relaxer.relax(structure, verbose=True) assert list(result) == ["final_structure", "trajectory"] From 79b45b2d935bdb8702f5721bd0275a7cb5e04040 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 28 Jun 2023 14:28:54 -0700 Subject: [PATCH 09/15] add CrystalGraphConverter.__repr__ used to refactor CrystalGraphConverter(verbose=True) --- chgnet/graph/converter.py | 31 ++++++++++++++++--------------- chgnet/model/model.py | 2 +- tests/test_relaxation.py | 4 ++-- 3 files changed, 19 insertions(+), 18 deletions(-) diff --git a/chgnet/graph/converter.py b/chgnet/graph/converter.py index 0e690fca..e492ae29 100644 --- a/chgnet/graph/converter.py +++ b/chgnet/graph/converter.py @@ -33,26 +33,22 @@ def __init__( Args: atom_graph_cutoff (float): cutoff radius to search for neighboring atom in - 
atom_graph. - Default = 5 + atom_graph. Default = 5. bond_graph_cutoff (float): bond length threshold to include bond in - bond_graph - Default = 3 + bond_graph. Default = 3. algorithm ('legacy' | 'fast'): algorithm to use for converting graphs. 'legacy': python implementation of graph creation 'fast': C implementation of graph creation, this is faster, but will need the cygraph.c file correctly compiled from pip install Default = 'fast' verbose (bool): whether to print the CrystalGraphConverter - initialization message - Default = False + initialization message. Default = False. """ super().__init__() self.atom_graph_cutoff = atom_graph_cutoff - if bond_graph_cutoff is None: - self.bond_graph_cutoff = atom_graph_cutoff - else: - self.bond_graph_cutoff = bond_graph_cutoff + self.bond_graph_cutoff = ( + atom_graph_cutoff if bond_graph_cutoff is None else bond_graph_cutoff + ) # Set graph conversion algorithm if algorithm == "fast": @@ -62,7 +58,7 @@ def __init__( self._make_graph = make_graph self.create_graph = self._create_graph_fast self.algorithm = "fast" - except ImportError: + except (ImportError, AttributeError): self.create_graph = self._create_graph_legacy self.algorithm = "legacy" elif algorithm == "legacy": @@ -70,10 +66,15 @@ def __init__( self.algorithm = "legacy" if verbose: - print( - f"CrystalGraphConverter-{self.algorithm} initialized with " - f"atom_cutoff={atom_graph_cutoff}, bond_cutoff={bond_graph_cutoff}" - ) + print(self) + + def __repr__(self) -> str: + """String representation of the CrystalGraphConverter.""" + atom_graph_cutoff = self.atom_graph_cutoff + bond_graph_cutoff = self.bond_graph_cutoff + algorithm = self.algorithm + cls_name = type(self).__name__ + return f"{cls_name}({algorithm=}, {atom_graph_cutoff=}, {bond_graph_cutoff=})" def forward( self, diff --git a/chgnet/model/model.py b/chgnet/model/model.py index 4c4bf5b6..7e734415 100644 --- a/chgnet/model/model.py +++ b/chgnet/model/model.py @@ -169,7 +169,7 @@ def __init__( 
atom_graph_cutoff=atom_graph_cutoff, bond_graph_cutoff=bond_graph_cutoff, algorithm=graph_converter_algorithm, - verbose=kwargs.pop("converter_verbose", False), + verbose=kwargs.pop("converter_verbose"), ) # Define embedding layers diff --git a/tests/test_relaxation.py b/tests/test_relaxation.py index 2e5a7b14..aca01509 100644 --- a/tests/test_relaxation.py +++ b/tests/test_relaxation.py @@ -29,7 +29,7 @@ def test_relaxation(algorithm: Literal["legacy", "fast"]): traj = result["trajectory"] # make sure trajectory has expected attributes - assert list(traj.__dict__) == [ + assert {*traj.__dict__} == { "atoms", "energies", "forces", @@ -37,7 +37,7 @@ def test_relaxation(algorithm: Literal["legacy", "fast"]): "magmoms", "atom_positions", "cells", - ] + } assert len(traj) == 4 # make sure final structure is more relaxed than initial one From e16107dc35a98a444535f1c6573e342dd19ea499 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 28 Jun 2023 14:41:03 -0700 Subject: [PATCH 10/15] check for converter_verbose message in stdout in test_model() --- chgnet/model/model.py | 2 +- tests/test_model.py | 12 +++++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/chgnet/model/model.py b/chgnet/model/model.py index 7e734415..4c4bf5b6 100644 --- a/chgnet/model/model.py +++ b/chgnet/model/model.py @@ -169,7 +169,7 @@ def __init__( atom_graph_cutoff=atom_graph_cutoff, bond_graph_cutoff=bond_graph_cutoff, algorithm=graph_converter_algorithm, - verbose=kwargs.pop("converter_verbose"), + verbose=kwargs.pop("converter_verbose", False), ) # Define embedding layers diff --git a/tests/test_model.py b/tests/test_model.py index f6ae10f0..bbdf11f2 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -28,7 +28,9 @@ def test_model( num_angular: int, n_conv: int, composition_model: str, + capsys: pytest.CaptureFixture[str], ) -> None: + converter_verbose = False model = CHGNet( atom_fea_dim=atom_fea_dim, bond_fea_dim=bond_fea_dim, @@ -37,13 +39,21 @@ def 
test_model( num_angular=num_angular, n_conv=n_conv, composition_model=composition_model, - converter_verbose="True", + converter_verbose=converter_verbose, ) out = model([graph]) assert list(out) == ["atoms_per_graph", "e"] assert out["atoms_per_graph"].shape == (1,) assert out["e"] < 0 + stdout, stderr = capsys.readouterr() + if converter_verbose: + assert repr(model.graph_converter) in stdout + else: + assert "CHGNet initialized with" in stdout + + assert stderr == "" + def test_predict_structure() -> None: model = CHGNet.load() From a7b534bcec9f3a28fdbc8dfed4d5f794912c468c Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 28 Jun 2023 14:46:53 -0700 Subject: [PATCH 11/15] fix pyproject version number --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9b4523f1..bb7f0fef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,10 +1,10 @@ [build-system] -requires = ["setuptools>=65.0", "Cython", "wheel"] +requires = ["Cython", "setuptools>=65.0", "wheel"] build-backend = "setuptools.build_meta" [project] name = "chgnet" -version = "0.1.04" +version = "0.1.4" description = "Pretrained Universal Neural Network Potential for Charge-informed Atomistic Modeling" authors = [{ name = "Bowen Deng", email = "bowendeng@berkeley.edu" }] requires-python = ">=3.8" @@ -12,10 +12,10 @@ readme = "README.md" license = { text = "Modified BSD" } dependencies = [ "ase>=3.22.0", + "cython>=0.29.26", "numpy>=1.21.6", "pymatgen>=2022.4.19", "torch>=1.11.0", - "cython>=0.29.26" ] classifiers = [ "Intended Audience :: Science/Research", From 2339b00d7a5535dda0725e2c329bc92d5d9833b8 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 28 Jun 2023 14:52:10 -0700 Subject: [PATCH 12/15] make sure we check out repo where PR originates (rather than our own) --- .github/workflows/test.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 
123711ba..226a6fbe 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -30,6 +30,8 @@ jobs: steps: - name: Check out repo uses: actions/checkout@v3 + with: + repository: ${{ github.event.pull_request.head.repo.full_name }} - name: Set up Python uses: actions/setup-python@v4 From c189890eb0147cc791737522e99c3d83a8634253 Mon Sep 17 00:00:00 2001 From: BowenD-UCB <84425382+BowenD-UCB@users.noreply.github.com> Date: Wed, 28 Jun 2023 14:57:54 -0700 Subject: [PATCH 13/15] more test cases in test_model --- tests/test_model.py | 127 +++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 120 insertions(+), 7 deletions(-) diff --git a/tests/test_model.py b/tests/test_model.py index bbdf11f2..4ff01206 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -55,13 +55,15 @@ def test_model( assert stderr == "" +model = CHGNet.load() + + def test_predict_structure() -> None: - model = CHGNet.load() out = model.predict_structure(structure) assert sorted(out) == ["e", "f", "m", "s"] assert out["e"] == pytest.approx(-7.37159, abs=1e-4) - assert len(out["f"]) == len(structure) + force = np.array( [ [4.4703484e-08, -4.2840838e-08, 2.4071064e-02], @@ -75,12 +77,79 @@ def test_predict_structure() -> None: ] ) assert out["f"] == pytest.approx(force, abs=1e-4) - assert len(out["m"]) == len(structure) - assert out["m"] == pytest.approx( - [0.00521, 0.00521, 3.85728, 3.85729, 0.02538, 0.03706, 0.03706, 0.02538], - abs=1e-4, + + stress = np.array( + [ + [3.3677614e-01, -1.9665707e-07, -5.6416429e-06], + [4.9939729e-07, 2.4675032e-01, 1.8549043e-05], + [-4.0414070e-06, 1.9096897e-05, 4.0323928e-02], + ] + ) + assert out["s"] == pytest.approx(stress, abs=1e-4) + + magmom = [0.00521, 0.00521, 3.85728, 3.85729, 0.02538, 0.03706, 0.03706, 0.02538] + assert out["m"] == pytest.approx(magmom, abs=1e-4) + + +def test_predict_structure_rotated() -> None: + from pymatgen.transformations.standard_transformations import RotationTransformation + + 
rotation_transformation = RotationTransformation(axis=[0, 0, 1], angle=30) + rotated_structure = rotation_transformation.apply_transformation(structure) + out = model.predict_structure(rotated_structure) + + assert sorted(out) == ["e", "f", "m", "s"] + assert out["e"] == pytest.approx(-7.37159, abs=1e-4) + + # Define a rotation matrix for rotation about Z-axis by 30 degrees + theta = np.radians(30) # Convert angle to radians + c, s = np.cos(theta), np.sin(theta) + + rotation_matrix = np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]]) + + force = np.array( + [ + [4.4703484e-08, -4.2840838e-08, 2.4071064e-02], + [-4.4703484e-08, -1.4551915e-08, -2.4071217e-02], + [-1.7881393e-07, 1.0244548e-08, 2.5402933e-02], + [5.9604645e-08, -2.3283064e-08, -2.5402665e-02], + [-1.1920929e-07, 6.6356733e-08, -2.1660209e-02], + [2.3543835e-06, -8.0077443e-06, 9.5508099e-03], + [-2.2947788e-06, 7.9898164e-06, -9.5513463e-03], + [-5.9604645e-08, -0.0000000e00, 2.1660626e-02], + ] + ) + rotated_force = force @ rotation_matrix + assert out["f"] == pytest.approx(rotated_force, abs=1e-4) + + magmom = [0.00521, 0.00521, 3.85728, 3.85729, 0.02538, 0.03706, 0.03706, 0.02538] + assert out["m"] == pytest.approx(magmom, abs=1e-4) + + +def test_predict_structure_supercell() -> None: + supercell = structure.copy() + supercell.make_supercell([2, 2, 1]) + out = model.predict_structure(supercell) + + assert sorted(out) == ["e", "f", "m", "s"] + assert out["e"] == pytest.approx(-7.37159, abs=1e-4) + + force = np.array( + [ + [4.4703484e-08, -4.2840838e-08, 2.4071064e-02], + [-4.4703484e-08, -1.4551915e-08, -2.4071217e-02], + [-1.7881393e-07, 1.0244548e-08, 2.5402933e-02], + [5.9604645e-08, -2.3283064e-08, -2.5402665e-02], + [-1.1920929e-07, 6.6356733e-08, -2.1660209e-02], + [2.3543835e-06, -8.0077443e-06, 9.5508099e-03], + [-2.2947788e-06, 7.9898164e-06, -9.5513463e-03], + [-5.9604645e-08, -0.0000000e00, 2.1660626e-02], + ] ) - assert len(out["s"]) == 3 + for index, f in enumerate(force): + for cell_number
in range(4): + assert out["f"][index * 4 + cell_number] == pytest.approx(f, abs=1e-4) + stress = np.array( [ [3.3677614e-01, -1.9665707e-07, -5.6416429e-06], @@ -89,3 +158,47 @@ def test_predict_structure() -> None: ] ) assert out["s"] == pytest.approx(stress, abs=1e-4) + + magmom = [0.00521, 0.00521, 3.85728, 3.85729, 0.02538, 0.03706, 0.03706, 0.02538] + for index, m in enumerate(magmom): + for cell_number in range(4): + assert out["m"][index * 4 + cell_number] == pytest.approx(m, abs=1e-4) + + +def test_predict_batched_structures() -> None: + out = model.predict_structure([structure, structure, structure]) + assert out[0]["e"] == pytest.approx(-7.37159, abs=1e-4) + assert out[1]["e"] == pytest.approx(-7.37159, abs=1e-4) + assert out[2]["e"] == pytest.approx(-7.37159, abs=1e-4) + + force = np.array( + [ + [4.4703484e-08, -4.2840838e-08, 2.4071064e-02], + [-4.4703484e-08, -1.4551915e-08, -2.4071217e-02], + [-1.7881393e-07, 1.0244548e-08, 2.5402933e-02], + [5.9604645e-08, -2.3283064e-08, -2.5402665e-02], + [-1.1920929e-07, 6.6356733e-08, -2.1660209e-02], + [2.3543835e-06, -8.0077443e-06, 9.5508099e-03], + [-2.2947788e-06, 7.9898164e-06, -9.5513463e-03], + [-5.9604645e-08, -0.0000000e00, 2.1660626e-02], + ] + ) + assert out[0]["f"] == pytest.approx(force, abs=1e-4) + assert out[1]["f"] == pytest.approx(force, abs=1e-4) + assert out[2]["f"] == pytest.approx(force, abs=1e-4) + + stress = np.array( + [ + [3.3677614e-01, -1.9665707e-07, -5.6416429e-06], + [4.9939729e-07, 2.4675032e-01, 1.8549043e-05], + [-4.0414070e-06, 1.9096897e-05, 4.0323928e-02], + ] + ) + assert out[0]["s"] == pytest.approx(stress, abs=1e-4) + assert out[1]["s"] == pytest.approx(stress, abs=1e-4) + assert out[2]["s"] == pytest.approx(stress, abs=1e-4) + + magmom = [0.00521, 0.00521, 3.85728, 3.85729, 0.02538, 0.03706, 0.03706, 0.02538] + assert out[0]["m"] == pytest.approx(magmom, abs=1e-4) + assert out[1]["m"] == pytest.approx(magmom, abs=1e-4) + assert out[2]["m"] == pytest.approx(magmom, 
abs=1e-4) From 7785e07f9ed9d1d9b8ebe6e380fd798e5c8c9c7e Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 28 Jun 2023 15:07:58 -0700 Subject: [PATCH 14/15] reduce repetition in test_predict_batched_structures() --- .github/workflows/test.yml | 2 -- tests/test_model.py | 59 ++++++++++++++++---------------------- 2 files changed, 25 insertions(+), 36 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 226a6fbe..123711ba 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -30,8 +30,6 @@ jobs: steps: - name: Check out repo uses: actions/checkout@v3 - with: - repository: ${{ github.event.pull_request.head.repo.full_name }} - name: Set up Python uses: actions/setup-python@v4 diff --git a/tests/test_model.py b/tests/test_model.py index 4ff01206..7b8b6ca2 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -166,39 +166,30 @@ def test_predict_structure_supercell() -> None: def test_predict_batched_structures() -> None: - out = model.predict_structure([structure, structure, structure]) - assert out[0]["e"] == pytest.approx(-7.37159, abs=1e-4) - assert out[1]["e"] == pytest.approx(-7.37159, abs=1e-4) - assert out[2]["e"] == pytest.approx(-7.37159, abs=1e-4) - - force = np.array( - [ - [4.4703484e-08, -4.2840838e-08, 2.4071064e-02], - [-4.4703484e-08, -1.4551915e-08, -2.4071217e-02], - [-1.7881393e-07, 1.0244548e-08, 2.5402933e-02], - [5.9604645e-08, -2.3283064e-08, -2.5402665e-02], - [-1.1920929e-07, 6.6356733e-08, -2.1660209e-02], - [2.3543835e-06, -8.0077443e-06, 9.5508099e-03], - [-2.2947788e-06, 7.9898164e-06, -9.5513463e-03], - [-5.9604645e-08, -0.0000000e00, 2.1660626e-02], - ] - ) - assert out[0]["f"] == pytest.approx(force, abs=1e-4) - assert out[1]["f"] == pytest.approx(force, abs=1e-4) - assert out[2]["f"] == pytest.approx(force, abs=1e-4) - - stress = np.array( - [ - [3.3677614e-01, -1.9665707e-07, -5.6416429e-06], - [4.9939729e-07, 2.4675032e-01, 1.8549043e-05], - [-4.0414070e-06, 
1.9096897e-05, 4.0323928e-02], - ] - ) - assert out[0]["s"] == pytest.approx(stress, abs=1e-4) - assert out[1]["s"] == pytest.approx(stress, abs=1e-4) - assert out[2]["s"] == pytest.approx(stress, abs=1e-4) + structs = [structure, structure, structure] + out = model.predict_structure(structs) + assert len(out) == len(structs) + + assert all(preds["e"] == pytest.approx(-7.37159, abs=1e-4) for preds in out) + + force = [ + [4.4703484e-08, -4.2840838e-08, 2.4071064e-02], + [-4.4703484e-08, -1.4551915e-08, -2.4071217e-02], + [-1.7881393e-07, 1.0244548e-08, 2.5402933e-02], + [5.9604645e-08, -2.3283064e-08, -2.5402665e-02], + [-1.1920929e-07, 6.6356733e-08, -2.1660209e-02], + [2.3543835e-06, -8.0077443e-06, 9.5508099e-03], + [-2.2947788e-06, 7.9898164e-06, -9.5513463e-03], + [-5.9604645e-08, -0.0000000e00, 2.1660626e-02], + ] + assert all(np.allclose(preds["f"], force, atol=1e-4) for preds in out) + + stress = [ + [3.3677614e-01, -1.9665707e-07, -5.6416429e-06], + [4.9939729e-07, 2.4675032e-01, 1.8549043e-05], + [-4.0414070e-06, 1.9096897e-05, 4.0323928e-02], + ] + assert all(np.allclose(preds["s"], stress, atol=1e-4) for preds in out) magmom = [0.00521, 0.00521, 3.85728, 3.85729, 0.02538, 0.03706, 0.03706, 0.02538] - assert out[0]["m"] == pytest.approx(magmom, abs=1e-4) - assert out[1]["m"] == pytest.approx(magmom, abs=1e-4) - assert out[2]["m"] == pytest.approx(magmom, abs=1e-4) + assert all(preds["m"] == pytest.approx(magmom, abs=1e-4) for preds in out) From ade7e4d199d793152d999f030d3f9b8fafbe611c Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 28 Jun 2023 15:34:15 -0700 Subject: [PATCH 15/15] tweak test_predict_structure_supercell() using np.allclose --- tests/test_model.py | 54 +++++++++++++++++++++------------------------ 1 file changed, 25 insertions(+), 29 deletions(-) diff --git a/tests/test_model.py b/tests/test_model.py index 7b8b6ca2..1f0f507c 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -134,35 +134,31 @@ def 
test_predict_structure_supercell() -> None: assert sorted(out) == ["e", "f", "m", "s"] assert out["e"] == pytest.approx(-7.37159, abs=1e-4) - force = np.array( - [ - [4.4703484e-08, -4.2840838e-08, 2.4071064e-02], - [-4.4703484e-08, -1.4551915e-08, -2.4071217e-02], - [-1.7881393e-07, 1.0244548e-08, 2.5402933e-02], - [5.9604645e-08, -2.3283064e-08, -2.5402665e-02], - [-1.1920929e-07, 6.6356733e-08, -2.1660209e-02], - [2.3543835e-06, -8.0077443e-06, 9.5508099e-03], - [-2.2947788e-06, 7.9898164e-06, -9.5513463e-03], - [-5.9604645e-08, -0.0000000e00, 2.1660626e-02], - ] - ) - for index, f in enumerate(force): - for cell_number in range(4): - assert out["f"][index * 4 + cell_number] == pytest.approx(f, abs=1e-4) + forces = [ + [4.4703484e-08, -4.2840838e-08, 2.4071064e-02], + [-4.4703484e-08, -1.4551915e-08, -2.4071217e-02], + [-1.7881393e-07, 1.0244548e-08, 2.5402933e-02], + [5.9604645e-08, -2.3283064e-08, -2.5402665e-02], + [-1.1920929e-07, 6.6356733e-08, -2.1660209e-02], + [2.3543835e-06, -8.0077443e-06, 9.5508099e-03], + [-2.2947788e-06, 7.9898164e-06, -9.5513463e-03], + [-5.9604645e-08, -0.0000000e00, 2.1660626e-02], + ] + for idx, force in enumerate(forces): + for cell_idx in range(4): + assert np.allclose(out["f"][idx * 4 + cell_idx], force, atol=1e-4) - stress = np.array( - [ - [3.3677614e-01, -1.9665707e-07, -5.6416429e-06], - [4.9939729e-07, 2.4675032e-01, 1.8549043e-05], - [-4.0414070e-06, 1.9096897e-05, 4.0323928e-02], - ] - ) - assert out["s"] == pytest.approx(stress, abs=1e-4) + stress = [ + [3.3677614e-01, -1.9665707e-07, -5.6416429e-06], + [4.9939729e-07, 2.4675032e-01, 1.8549043e-05], + [-4.0414070e-06, 1.9096897e-05, 4.0323928e-02], + ] + assert np.allclose(out["s"], stress, atol=1e-4) - magmom = [0.00521, 0.00521, 3.85728, 3.85729, 0.02538, 0.03706, 0.03706, 0.02538] - for index, m in enumerate(magmom): - for cell_number in range(4): - assert out["m"][index * 4 + cell_number] == pytest.approx(m, abs=1e-4) + magmoms = [0.00521, 0.00521, 3.85728, 
3.85729, 0.02538, 0.03706, 0.03706, 0.02538] + for idx, magmom in enumerate(magmoms): + for cell_idx in range(4): + assert np.allclose(out["m"][idx * 4 + cell_idx], magmom, atol=1e-4) def test_predict_batched_structures() -> None: @@ -172,7 +168,7 @@ def test_predict_batched_structures() -> None: assert all(preds["e"] == pytest.approx(-7.37159, abs=1e-4) for preds in out) - force = [ + forces = [ [4.4703484e-08, -4.2840838e-08, 2.4071064e-02], [-4.4703484e-08, -1.4551915e-08, -2.4071217e-02], [-1.7881393e-07, 1.0244548e-08, 2.5402933e-02], @@ -182,7 +178,7 @@ def test_predict_batched_structures() -> None: [-2.2947788e-06, 7.9898164e-06, -9.5513463e-03], [-5.9604645e-08, -0.0000000e00, 2.1660626e-02], ] - assert all(np.allclose(preds["f"], force, atol=1e-4) for preds in out) + assert all(np.allclose(preds["f"], forces, atol=1e-4) for preds in out) stress = [ [3.3677614e-01, -1.9665707e-07, -5.6416429e-06],