diff --git a/src/anemoi/graphs/create.py b/src/anemoi/graphs/create.py index 0b09649..bf3cb0c 100644 --- a/src/anemoi/graphs/create.py +++ b/src/anemoi/graphs/create.py @@ -46,11 +46,14 @@ def generate_graph(self) -> HeteroData: HeteroData: The generated graph. """ graph = HeteroData() - for name, nodes_cfg in self.config.nodes.items(): - graph = instantiate(nodes_cfg.node_builder).update_graph(graph, name, nodes_cfg.get("attributes", {})) + + for nodes_cfg in self.config.nodes: + graph = instantiate(nodes_cfg.node_builder, name=nodes_cfg.name).update_graph( + graph, nodes_cfg.get("attributes", {}) + ) for edges_cfg in self.config.edges: - graph = instantiate(edges_cfg.edge_builder, **edges_cfg.nodes).update_graph( + graph = instantiate(edges_cfg.edge_builder, edges_cfg.source_name, edges_cfg.target_name).update_graph( graph, edges_cfg.get("attributes", {}) ) diff --git a/src/anemoi/graphs/edges/attributes.py b/src/anemoi/graphs/edges/attributes.py index 6945867..9a8d6d8 100644 --- a/src/anemoi/graphs/edges/attributes.py +++ b/src/anemoi/graphs/edges/attributes.py @@ -1,7 +1,6 @@ import logging from abc import ABC from abc import abstractmethod -from dataclasses import dataclass from typing import Optional import numpy as np @@ -15,11 +14,11 @@ LOGGER = logging.getLogger(__name__) -@dataclass class BaseEdgeAttribute(ABC, NormalizerMixin): """Base class for edge attributes.""" - norm: Optional[str] = None + def __init__(self, norm: Optional[str] = None) -> None: + self.norm = norm @abstractmethod def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str, *args, **kwargs) -> np.ndarray: ... @@ -29,10 +28,13 @@ def post_process(self, values: np.ndarray) -> torch.Tensor: if values.ndim == 1: values = values[:, np.newaxis] - return torch.tensor(values) + normed_values = self.normalize(values) + + return torch.tensor(normed_values, dtype=torch.float32) - def compute(self, graph: HeteroData, source_name: str, target_name: str, *args, **kwargs) -> torch.Tensor: + def compute(self, graph: HeteroData, edges_name: tuple[str, str, str], *args, **kwargs) -> torch.Tensor: """Compute the edge attributes.""" + source_name, _, target_name = edges_name assert ( source_name in graph.node_types ), f"Node \"{source_name}\" not found in graph. Optional nodes are {', '.join(graph.node_types)}." @@ -41,13 +43,11 @@ def compute(self, graph: HeteroData, source_name: str, target_name: str, *args, ), f"Node \"{target_name}\" not found in graph. Optional nodes are {', '.join(graph.node_types)}." values = self.get_raw_values(graph, source_name, target_name, *args, **kwargs) - normed_values = self.normalize(values) - return self.post_process(normed_values) + return self.post_process(values) -@dataclass class EdgeDirection(BaseEdgeAttribute): - """Compute directional features for edges. + """Edge direction feature. If using the rotated features, the direction of the edge is computed rotating the target nodes to the north pole. If not, it is computed @@ -69,8 +69,9 @@ class EdgeDirection(BaseEdgeAttribute): Compute directional attributes. """ - norm: Optional[str] = None - luse_rotated_features: bool = True + def __init__(self, norm: Optional[str] = None, luse_rotated_features: bool = True) -> None: + super().__init__(norm) + self.luse_rotated_features = luse_rotated_features def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str) -> np.ndarray: """Compute directional features for edges. @@ -96,7 +97,6 @@ def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str) return edge_dirs -@dataclass class EdgeLength(BaseEdgeAttribute): """Edge length feature. @@ -115,8 +115,9 @@ class EdgeLength(BaseEdgeAttribute): Compute edge lengths attributes. """ - norm: str = "l1" - invert: bool = True + def __init__(self, norm: Optional[str] = None, invert: bool = False) -> None: + super().__init__(norm) + self.invert = invert def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str) -> np.ndarray: """Compute haversine distance (in kilometers) between nodes connected by edges. diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index 3c926d6..17ba4fc 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -18,34 +18,61 @@ class BaseEdgeBuilder(ABC): + """Base class for edge builders.""" def __init__(self, source_name: str, target_name: str): - super().__init__() self.source_name = source_name self.target_name = target_name + @property + def name(self) -> tuple[str, str, str]: + """Name of the edge subgraph.""" + return self.source_name, "to", self.target_name + @abstractmethod def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage): ... - def register_edges(self, graph: HeteroData, source_indices: np.ndarray, target_indices: np.ndarray) -> HeteroData: + def prepare_node_data(self, graph: HeteroData) -> tuple[NodeStorage, NodeStorage]: + """Prepare nodes information.""" + return graph[self.source_name], graph[self.target_name] + + def get_edge_index(self, graph: HeteroData) -> torch.Tensor: + """Get the edge indices of source and target nodes. + + Parameters + ---------- + graph : HeteroData + The graph. + + Returns + ------- + torch.Tensor of shape (2, num_edges) + The edge indices. + """ + source_nodes, target_nodes = self.prepare_node_data(graph) + + adjmat = self.get_adjacency_matrix(source_nodes, target_nodes) + + # Get source & target indices of the edges + edge_index = np.stack([adjmat.col, adjmat.row], axis=0) + + return torch.from_numpy(edge_index).to(torch.int32) + + def register_edges(self, graph: HeteroData) -> HeteroData: """Register edges in the graph. Parameters ---------- graph : HeteroData The graph to register the edges. - source_indices : np.ndarray of shape (N, ) - The indices of the source nodes. - target_indices : np.ndarray of shape (N, ) - The indices of the target nodes. Returns ------- HeteroData The graph with the registered edges. """ - edge_index = np.stack([source_indices, target_indices], axis=0).astype(np.int32) - graph[(self.source_name, "to", self.target_name)].edge_index = torch.from_numpy(edge_index) + graph[self.name].edge_index = self.get_edge_index(graph) + graph[self.name].edge_type = type(self).__name__ return graph def register_attributes(self, graph: HeteroData, config: DotDict) -> HeteroData: @@ -64,15 +91,9 @@ def register_attributes(self, graph: HeteroData, config: DotDict) -> HeteroData: The graph with the registered attributes. """ for attr_name, attr_config in config.items(): - graph[self.source_name, "to", self.target_name][attr_name] = instantiate(attr_config).compute( - graph, self.source_name, self.target_name - ) + graph[self.name][attr_name] = instantiate(attr_config).compute(graph, self.name) return graph - def prepare_node_data(self, graph: HeteroData) -> tuple[NodeStorage, NodeStorage]: - """Prepare nodes information.""" - return graph[self.source_name], graph[self.target_name] - def update_graph(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) -> HeteroData: """Update the graph with the edges. @@ -88,11 +109,7 @@ def update_graph(self, graph: HeteroData, attrs_config: Optional[DotDict] = None HeteroData The graph with the edges. """ - source_nodes, target_nodes = self.prepare_node_data(graph) - - adjmat = self.get_adjacency_matrix(source_nodes, target_nodes) - - graph = self.register_edges(graph, adjmat.col, adjmat.row) + graph = self.register_edges(graph) if attrs_config is None: return graph @@ -113,6 +130,17 @@ class KNNEdges(BaseEdgeBuilder): The name of the target nodes. num_nearest_neighbours : int Number of nearest neighbours. + + Methods + ------- + get_adjacency_matrix(source_nodes, target_nodes) + Compute the adjacency matrix for the KNN method. + register_edges(graph) + Register the edges in the graph. + register_attributes(graph, config) + Register attributes in the edges of the graph. + update_graph(graph, attrs_config) + Update the graph with the edges. """ def __init__(self, source_name: str, target_name: str, num_nearest_neighbours: int): @@ -162,6 +190,19 @@ class CutOffEdges(BaseEdgeBuilder): Factor to multiply the grid reference distance to get the cut-off radius. radius : float Cut-off radius. + + Methods + ------- + get_cutoff_radius(graph, mask_attr) + Compute the cut-off radius. + get_adjacency_matrix(source_nodes, target_nodes) + Get the adjacency matrix for the cut-off method. + register_edges(graph) + Register the edges in the graph. + register_attributes(graph, config) + Register attributes in the edges of the graph. + update_graph(graph, attrs_config) + Update the graph with the edges. """ def __init__(self, source_name: str, target_name: str, cutoff_factor: float): diff --git a/src/anemoi/graphs/nodes/attributes.py b/src/anemoi/graphs/nodes/attributes.py index b1942b7..911ce99 100644 --- a/src/anemoi/graphs/nodes/attributes.py +++ b/src/anemoi/graphs/nodes/attributes.py @@ -1,12 +1,12 @@ import logging from abc import ABC from abc import abstractmethod -from dataclasses import dataclass from typing import Optional import numpy as np import torch from scipy.spatial import SphericalVoronoi +from torch_geometric.data import HeteroData from torch_geometric.data.storage import NodeStorage from anemoi.graphs.generate.transforms import latlon_rad_to_cartesian @@ -15,11 +15,11 @@ LOGGER = logging.getLogger(__name__) -@dataclass class BaseWeights(ABC, NormalizerMixin): """Base class for the weights of the nodes.""" - norm: Optional[str] = None + def __init__(self, norm: Optional[str] = None) -> None: + self.norm = norm @abstractmethod def get_raw_values(self, nodes: NodeStorage, *args, **kwargs): ... @@ -29,19 +29,28 @@ def post_process(self, values: np.ndarray) -> torch.Tensor: if values.ndim == 1: values = values[:, np.newaxis] - return torch.tensor(values) + norm_values = self.normalize(values) - def compute(self, nodes: NodeStorage, *args, **kwargs) -> torch.Tensor: + return torch.tensor(norm_values, dtype=torch.float32) + + def compute(self, graph: HeteroData, nodes_name: str, *args, **kwargs) -> torch.Tensor: """Get the node weights. + Parameters + ---------- + graph : HeteroData + Graph. + nodes_name : str + Name of the nodes. + Returns ------- torch.Tensor Weights associated to the nodes. """ + nodes = graph[nodes_name] weights = self.get_raw_values(nodes, *args, **kwargs) - norm_weights = self.normalize(weights) - return self.post_process(norm_weights) + return self.post_process(weights) class UniformWeights(BaseWeights): @@ -63,7 +72,6 @@ def get_raw_values(self, nodes: NodeStorage, *args, **kwargs) -> np.ndarray: return np.ones(nodes.num_nodes) -@dataclass class AreaWeights(BaseWeights): """Implements the area of the nodes as the weights. @@ -84,9 +92,12 @@ class AreaWeights(BaseWeights): Compute the area attributes for each node. """ - norm: Optional[str] = "unit-max" - radius: float = 1.0 - centre: np.ndarray = np.array([0, 0, 0]) + def __init__( + self, norm: Optional[str] = None, radius: float = 1.0, centre: np.ndarray = np.array([0, 0, 0]) + ) -> None: + super().__init__(norm) + self.radius = radius + self.centre = centre def get_raw_values(self, nodes: NodeStorage, *args, **kwargs) -> np.ndarray: """Compute the area associated to each node. diff --git a/src/anemoi/graphs/nodes/builder.py b/src/anemoi/graphs/nodes/builder.py index 11e99f6..6ff37a1 100644 --- a/src/anemoi/graphs/nodes/builder.py +++ b/src/anemoi/graphs/nodes/builder.py @@ -15,31 +15,33 @@ class BaseNodeBuilder(ABC): - """Base class for node builders.""" + """Base class for node builders. - def register_nodes(self, graph: HeteroData, name: str) -> None: + The node coordinates are stored in the `x` attribute of the nodes and they are stored in radians. + """ + + def __init__(self, name: str) -> None: + self.name = name + + def register_nodes(self, graph: HeteroData) -> None: """Register nodes in the graph. Parameters ---------- graph : HeteroData The graph to register the nodes. - name : str - The name of the nodes. """ - graph[name].x = self.get_coordinates() - graph[name].node_type = type(self).__name__ + graph[self.name].x = self.get_coordinates() + graph[self.name].node_type = type(self).__name__ return graph - def register_attributes(self, graph: HeteroData, name: str, config: Optional[DotDict] = None) -> HeteroData: + def register_attributes(self, graph: HeteroData, config: Optional[DotDict] = None) -> HeteroData: """Register attributes in the nodes of the graph specified. Parameters ---------- graph : HeteroData The graph to register the attributes. - name : str - The name of the nodes. config : DotDict The configuration of the attributes. @@ -48,11 +50,8 @@ def register_attributes(self, graph: HeteroData, name: str, config: Optional[Dot HeteroData The graph with the registered attributes. """ - if config is None: - return graph - for attr_name, attr_config in config.items(): - graph[name][attr_name] = instantiate(attr_config).compute(graph[name]) + graph[self.name][attr_name] = instantiate(attr_config).compute(graph, self.name) return graph @abstractmethod @@ -77,15 +76,13 @@ def reshape_coords(self, latitudes: np.ndarray, longitudes: np.ndarray) -> torch coords = np.deg2rad(coords) return torch.tensor(coords, dtype=torch.float32) - def update_graph(self, graph: HeteroData, name: str, attr_config: Optional[DotDict] = None) -> HeteroData: + def update_graph(self, graph: HeteroData, attr_config: Optional[DotDict] = None) -> HeteroData: """Update the graph with new nodes. Parameters ---------- graph : HeteroData Input graph. - name : str - The name of the nodes. attr_config : DotDict The configuration of the attributes. @@ -94,12 +91,13 @@ def update_graph(self, graph: HeteroData, name: str, attr_config: Optional[DotDi HeteroData The graph with new nodes included. """ - graph = self.register_nodes(graph, name) + graph = self.register_nodes(graph) if attr_config is None: return graph - graph = self.register_attributes(graph, name, attr_config) + graph = self.register_attributes(graph, attr_config) + return graph @@ -110,11 +108,23 @@ class ZarrDatasetNodes(BaseNodeBuilder): ---------- ds : zarr.core.Array The dataset. + + Methods + ------- + get_coordinates() + Get the lat-lon coordinates of the nodes. + register_nodes(graph, name) + Register the nodes in the graph. + register_attributes(graph, name, config) + Register the attributes in the nodes of the graph specified. + update_graph(graph, name, attr_config) + Update the graph with new nodes and attributes. """ - def __init__(self, dataset: DotDict) -> None: + def __init__(self, dataset: DotDict, name: str) -> None: LOGGER.info("Reading the dataset from %s.", dataset) self.ds = open_dataset(dataset) + super().__init__(name) def get_coordinates(self) -> torch.Tensor: """Get the coordinates of the nodes. @@ -138,9 +148,20 @@ class NPZFileNodes(BaseNodeBuilder): Path to the folder containing the grid definition files. grid_definition : dict[str, np.ndarray] The grid definition. + + Methods + ------- + get_coordinates() + Get the lat-lon coordinates of the nodes. + register_nodes(graph, name) + Register the nodes in the graph. + register_attributes(graph, name, config) + Register the attributes in the nodes of the graph specified. + update_graph(graph, name, attr_config) + Update the graph with new nodes and attributes. """ - def __init__(self, resolution: str, grid_definition_path: str) -> None: + def __init__(self, resolution: str, grid_definition_path: str, name: str) -> None: """Initialize the NPZFileNodes builder. The builder suppose the grids are stored in files with the name `grid-{resolution}.npz`. @@ -155,6 +176,7 @@ def __init__(self, resolution: str, grid_definition_path: str) -> None: self.resolution = resolution self.grid_definition_path = grid_definition_path self.grid_definition = np.load(Path(self.grid_definition_path) / f"grid-{self.resolution}.npz") + super().__init__(name) def get_coordinates(self) -> torch.Tensor: """Get the coordinates of the nodes. diff --git a/src/anemoi/graphs/normalizer.py b/src/anemoi/graphs/normalizer.py index 9c261e5..c625417 100644 --- a/src/anemoi/graphs/normalizer.py +++ b/src/anemoi/graphs/normalizer.py @@ -16,7 +16,7 @@ def normalize(self, values: np.ndarray) -> np.ndarray: Parameters ---------- - values : np.ndarray + values : np.ndarray of shape (N, M) Values to normalize. Returns @@ -25,7 +25,7 @@ def normalize(self, values: np.ndarray) -> np.ndarray: Normalized values. """ if self.norm is None: - LOGGER.debug("Node weights are not normalized.") + LOGGER.debug(f"{self.__class__.__name__} values are not normalized.") return values if self.norm == "l1": return values / np.sum(values) @@ -36,9 +36,9 @@ def normalize(self, values: np.ndarray) -> np.ndarray: if self.norm == "unit-std": std = np.std(values) if std == 0: - LOGGER.warning(f"Std. dev. of the {self.__class__.__name__} is 0. Cannot normalize.") + LOGGER.warning(f"Std. dev. of the {self.__class__.__name__} values is 0. Normalization is skipped.") return values return values / std raise ValueError( - f"Weight normalization \"{self.norm}\" is not valid. Options are: 'l1', 'l2', 'unit-max' or 'unit-std'." + f"Attribute normalization \"{self.norm}\" is not valid. Options are: 'l1', 'l2', 'unit-max' or 'unit-std'." ) diff --git a/tests/conftest.py b/tests/conftest.py index 290165c..1dc76de 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -57,28 +57,26 @@ def graph_nodes_and_edges() -> HeteroData: def config_file(tmp_path) -> tuple[str, str]: """Mock grid_definition_path with files for 3 resolutions.""" cfg = { - "nodes": { - "test_nodes": { + "nodes": [ + { + "name": "test_nodes", "node_builder": { "_target_": "anemoi.graphs.nodes.NPZFileNodes", "grid_definition_path": str(tmp_path), "resolution": "o16", }, } - }, + ], "edges": [ { - "nodes": {"source_name": "test_nodes", "target_name": "test_nodes"}, + "source_name": "test_nodes", + "target_name": "test_nodes", "edge_builder": { "_target_": "anemoi.graphs.edges.KNNEdges", "num_nearest_neighbours": 3, }, "attributes": { - "dist_norm": { - "_target_": "anemoi.graphs.edges.attributes.EdgeLength", - "norm": "l1", - "invert": True, - }, + "dist_norm": {"_target_": "anemoi.graphs.edges.attributes.EdgeLength"}, "edge_dirs": {"_target_": "anemoi.graphs.edges.attributes.EdgeDirection"}, }, }, diff --git a/tests/edges/test_edge_attributes.py b/tests/edges/test_edge_attributes.py index b0bbede..40cba1c 100644 --- a/tests/edges/test_edge_attributes.py +++ b/tests/edges/test_edge_attributes.py @@ -10,7 +10,7 @@ def test_directional_features(graph_nodes_and_edges, norm, luse_rotated_features: bool): """Test EdgeDirection compute method.""" edge_attr_builder = EdgeDirection(norm=norm, luse_rotated_features=luse_rotated_features) - edge_attr = edge_attr_builder.compute(graph_nodes_and_edges, "test_nodes", "test_nodes") + edge_attr = edge_attr_builder.compute(graph_nodes_and_edges, ("test_nodes", "to", "test_nodes")) assert isinstance(edge_attr, torch.Tensor) @@ -18,12 +18,12 @@ def test_directional_features(graph_nodes_and_edges, norm, luse_rotated_features def test_edge_lengths(graph_nodes_and_edges, norm): """Test EdgeLength compute method.""" edge_attr_builder = EdgeLength(norm=norm) - edge_attr = edge_attr_builder.compute(graph_nodes_and_edges, "test_nodes", "test_nodes") + edge_attr = edge_attr_builder.compute(graph_nodes_and_edges, ("test_nodes", "to", "test_nodes")) assert isinstance(edge_attr, torch.Tensor) @pytest.mark.parametrize("attribute_builder", [EdgeDirection(), EdgeLength()]) def test_fail_edge_features(attribute_builder, graph_nodes_and_edges): - """Test EdgeDirection compute method.""" + """Test edge attribute builder fails with unknown nodes.""" with pytest.raises(AssertionError): - attribute_builder.compute(graph_nodes_and_edges, "test_nodes", "unknown_nodes") + attribute_builder.compute(graph_nodes_and_edges, ("test_nodes", "to", "unknown_nodes")) diff --git a/tests/nodes/test_node_attributes.py b/tests/nodes/test_node_attributes.py index 3d7e5be..7347d88 100644 --- a/tests/nodes/test_node_attributes.py +++ b/tests/nodes/test_node_attributes.py @@ -8,9 +8,9 @@ @pytest.mark.parametrize("norm", [None, "l1", "l2", "unit-max", "unit-std"]) def test_uniform_weights(graph_with_nodes: HeteroData, norm: str): - """Test NPZNodes register correctly the weights.""" + """Test attribute builder for UniformWeights.""" node_attr_builder = UniformWeights(norm=norm) - weights = node_attr_builder.compute(graph_with_nodes["test_nodes"]) + weights = node_attr_builder.compute(graph_with_nodes, "test_nodes") assert weights is not None assert isinstance(weights, torch.Tensor) @@ -19,16 +19,16 @@ def test_uniform_weights(graph_with_nodes: HeteroData, norm: str): @pytest.mark.parametrize("norm", ["l3", "invalide"]) def test_uniform_weights_fail(graph_with_nodes: HeteroData, norm: str): - """Test NPZNodes register correctly the weights.""" + """Test attribute builder for UniformWeights with invalid norm.""" with pytest.raises(ValueError): node_attr_builder = UniformWeights(norm=norm) - node_attr_builder.compute(graph_with_nodes["test_nodes"]) + node_attr_builder.compute(graph_with_nodes, "test_nodes") def test_area_weights(graph_with_nodes: HeteroData): - """Test NPZNodes register correctly the weights.""" + """Test attribute builder for AreaWeights.""" node_attr_builder = AreaWeights() - weights = node_attr_builder.compute(graph_with_nodes["test_nodes"]) + weights = node_attr_builder.compute(graph_with_nodes, "test_nodes") assert weights is not None assert isinstance(weights, torch.Tensor) @@ -37,6 +37,7 @@ def test_area_weights(graph_with_nodes: HeteroData): @pytest.mark.parametrize("radius", [-1.0, "hello", None]) def test_area_weights_fail(graph_with_nodes: HeteroData, radius: float): + """Test attribute builder for AreaWeights with invalid radius.""" with pytest.raises(ValueError): node_attr_builder = AreaWeights(radius=radius) - node_attr_builder.compute(graph_with_nodes["test_nodes"]) + node_attr_builder.compute(graph_with_nodes, "test_nodes") diff --git a/tests/nodes/test_npz.py b/tests/nodes/test_npz.py index fc4cf8c..95d09c0 100644 --- a/tests/nodes/test_npz.py +++ b/tests/nodes/test_npz.py @@ -9,34 +9,34 @@ @pytest.mark.parametrize("resolution", ["o16", "o48", "5km5"]) def test_init(mock_grids_path: tuple[str, int], resolution: str): - """Test NPZNodes initialization.""" + """Test NPZFileNodes initialization.""" grid_definition_path, _ = mock_grids_path - node_builder = NPZFileNodes(resolution, grid_definition_path=grid_definition_path) + node_builder = NPZFileNodes(resolution, grid_definition_path, "test_nodes") assert isinstance(node_builder, NPZFileNodes) @pytest.mark.parametrize("resolution", ["o17", 13, "ajsnb", None]) def test_fail_init_wrong_resolution(mock_grids_path: tuple[str, int], resolution: str): - """Test NPZNodes initialization with invalid resolution.""" + """Test NPZFileNodes initialization with invalid resolution.""" grid_definition_path, _ = mock_grids_path with pytest.raises(FileNotFoundError): - NPZFileNodes(resolution, grid_definition_path=grid_definition_path) + NPZFileNodes(resolution, grid_definition_path, "test_nodes") def test_fail_init_wrong_path(): - """Test NPZNodes initialization with invalid path.""" + """Test NPZFileNodes initialization with invalid path.""" with pytest.raises(FileNotFoundError): - NPZFileNodes("o16", "invalid_path") + NPZFileNodes("o16", "invalid_path", "test_nodes") @pytest.mark.parametrize("resolution", ["o16", "o48", "5km5"]) def test_register_nodes(mock_grids_path: str, resolution: str): - """Test NPZNodes register correctly the nodes.""" + """Test NPZFileNodes register correctly the nodes.""" graph = HeteroData() grid_definition_path, num_nodes = mock_grids_path - node_builder = NPZFileNodes(resolution, grid_definition_path=grid_definition_path) + node_builder = NPZFileNodes(resolution, grid_definition_path, "test_nodes") - graph = node_builder.register_nodes(graph, "test_nodes") + graph = node_builder.register_nodes(graph) assert graph["test_nodes"].x is not None assert isinstance(graph["test_nodes"].x, torch.Tensor) @@ -46,12 +46,12 @@ def test_register_nodes(mock_grids_path: str, resolution: str): @pytest.mark.parametrize("attr_class", [UniformWeights, AreaWeights]) def test_register_attributes(graph_with_nodes: HeteroData, mock_grids_path: tuple[str, int], attr_class): - """Test NPZNodes register correctly the weights.""" + """Test NPZFileNodes register correctly the weights.""" grid_definition_path, _ = mock_grids_path - node_builder = NPZFileNodes("o16", grid_definition_path=grid_definition_path) + node_builder = NPZFileNodes("o16", grid_definition_path, "test_nodes") config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.attributes.{attr_class.__name__}"}} - graph = node_builder.register_attributes(graph_with_nodes, "test_nodes", config) + graph = node_builder.register_attributes(graph_with_nodes, config) assert graph["test_nodes"]["test_attr"] is not None assert isinstance(graph["test_nodes"]["test_attr"], torch.Tensor) diff --git a/tests/nodes/test_zarr.py b/tests/nodes/test_zarr.py index 0e91ece..e3a2687 100644 --- a/tests/nodes/test_zarr.py +++ b/tests/nodes/test_zarr.py @@ -9,26 +9,27 @@ def test_init(mocker, mock_zarr_dataset): - """Test ZarrNodes initialization.""" + """Test ZarrDatasetNodes initialization.""" mocker.patch.object(builder, "open_dataset", return_value=mock_zarr_dataset) - node_builder = builder.ZarrDatasetNodes("dataset.zarr") + node_builder = builder.ZarrDatasetNodes("dataset.zarr", name="test_nodes") + assert isinstance(node_builder, builder.BaseNodeBuilder) assert isinstance(node_builder, builder.ZarrDatasetNodes) def test_fail_init(): - """Test ZarrNodes initialization with invalid resolution.""" + """Test ZarrDatasetNodes initialization with invalid resolution.""" with pytest.raises(zarr.errors.PathNotFoundError): - builder.ZarrDatasetNodes("invalid_path.zarr") + builder.ZarrDatasetNodes("invalid_path.zarr", name="test_nodes") def test_register_nodes(mocker, mock_zarr_dataset): - """Test ZarrNodes register correctly the nodes.""" + """Test ZarrDatasetNodes register correctly the nodes.""" mocker.patch.object(builder, "open_dataset", return_value=mock_zarr_dataset) - node_builder = builder.ZarrDatasetNodes("dataset.zarr") + node_builder = builder.ZarrDatasetNodes("dataset.zarr", name="test_nodes") graph = HeteroData() - graph = node_builder.register_nodes(graph, "test_nodes") + graph = node_builder.register_nodes(graph) assert graph["test_nodes"].x is not None assert isinstance(graph["test_nodes"].x, torch.Tensor) @@ -38,12 +39,12 @@ def test_register_nodes(mocker, mock_zarr_dataset): @pytest.mark.parametrize("attr_class", [UniformWeights, AreaWeights]) def test_register_attributes(mocker, graph_with_nodes: HeteroData, attr_class): - """Test ZarrNodes register correctly the weights.""" + """Test ZarrDatasetNodes register correctly the weights.""" mocker.patch.object(builder, "open_dataset", return_value=None) - node_builder = builder.ZarrDatasetNodes("dataset.zarr") + node_builder = builder.ZarrDatasetNodes("dataset.zarr", name="test_nodes") config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.attributes.{attr_class.__name__}"}} - graph = node_builder.register_attributes(graph_with_nodes, "test_nodes", config) + graph = node_builder.register_attributes(graph_with_nodes, config) assert graph["test_nodes"]["test_attr"] is not None assert isinstance(graph["test_nodes"]["test_attr"], torch.Tensor) diff --git a/tests/test_graphs.py b/tests/test_graphs.py index 50731f7..ba2704f 100644 --- a/tests/test_graphs.py +++ b/tests/test_graphs.py @@ -30,7 +30,9 @@ def test_graphs(config_file: tuple[Path, str], mock_grids_path: tuple[str, int]) for nodes in graph.node_stores: for node_attr in nodes.node_attrs(): assert isinstance(nodes[node_attr], torch.Tensor) + assert nodes[node_attr].dtype in [torch.int32, torch.float32] for edges in graph.edge_stores: for edge_attr in edges.edge_attrs(): assert isinstance(edges[edge_attr], torch.Tensor) + assert edges[edge_attr].dtype in [torch.int32, torch.float32]