Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
Merge pull request #93 from MeteoSwiss/feature/txtnodes
Browse files Browse the repository at this point in the history
feature: AreaWeights split into SphericalAreaWeights and PlanarAreaWeights
  • Loading branch information
JPXKQX authored Dec 11, 2024
2 parents 67de774 + a40974e commit 3898f6f
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 17 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ Keep it human-readable, your future self will thank you!
- feat: Add `RemoveUnconnectedNodes` post processor to clean unconnected nodes in LAM. (#71)
- feat: Define node sets and edges based on an ICON icosahedral mesh (#53)
- feat: Support for multiple edge builders between two sets of nodes (#70)
- feat: Support for providing lon/lat coordinates from a text file (loaded with numpy loadtxt method) to build the graph `TextNodes` (#93)
- feat: Build 2D graphs with `Voronoi` in case `SphericalVoronoi` does not work well/is an overkill (LAM). Set `flat=true` in the nodes attributes to compute area weight using Voronoi with a qhull options preventing the empty region creation (#93)

# Changed

Expand Down
2 changes: 2 additions & 0 deletions src/anemoi/graphs/nodes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from .builders.from_file import LimitedAreaNPZFileNodes
from .builders.from_file import NPZFileNodes
from .builders.from_file import TextNodes
from .builders.from_file import ZarrDatasetNodes
from .builders.from_healpix import HEALPixNodes
from .builders.from_healpix import LimitedAreaHEALPixNodes
Expand All @@ -35,4 +36,5 @@
"ICONMultimeshNodes",
"ICONCellGridNodes",
"ICONNodes",
"TextNodes",
]
80 changes: 64 additions & 16 deletions src/anemoi/graphs/nodes/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
import numpy as np
import torch
from anemoi.datasets import open_dataset
from scipy.spatial import ConvexHull
from scipy.spatial import SphericalVoronoi
from scipy.spatial import Voronoi
from torch_geometric.data import HeteroData
from torch_geometric.data.storage import NodeStorage

Expand Down Expand Up @@ -101,6 +103,68 @@ def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray:
class AreaWeights(BaseNodeAttribute):
"""Implements the area of the nodes as the weights.
Attributes
----------
flat: bool
If True, the area is computed in 2D, otherwise in 3D.
**other: Any
Additional keyword arguments, see PlanarAreaWeights and SphericalAreaWeights
for details.
Methods
-------
compute(self, graph, nodes_name)
Compute the area attributes for each node.
"""

def __new__(cls, flat: bool = False, **kwargs):
logging.warning(
"Creating %s with flat=%s and kwargs=%s. In a future release, AreaWeights will be deprecated: please use directly PlanarAreaWeights or SphericalAreaWeights.",
cls.__name__,
flat,
kwargs,
)
if flat:
return PlanarAreaWeights(**kwargs)
return SphericalAreaWeights(**kwargs)


class PlanarAreaWeights(BaseNodeAttribute):
"""Implements the 2D area of the nodes as the weights.
Attributes
----------
norm : str
Normalisation of the weights.
Methods
-------
compute(self, graph, nodes_name)
Compute the area attributes for each node.
"""

def __init__(
self,
norm: str | None = None,
dtype: str = "float32",
) -> None:
super().__init__(norm, dtype)

def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray:
latitudes, longitudes = nodes.x[:, 0], nodes.x[:, 1]
points = np.stack([latitudes, longitudes], -1)
v = Voronoi(points, qhull_options="QJ Pp")
areas = []
for r in v.regions:
area = ConvexHull(v.vertices[r, :]).volume
areas.append(area)
result = np.asarray(areas)
return result


class SphericalAreaWeights(BaseNodeAttribute):
"""Implements the 3D area of the nodes as the weights.
Attributes
----------
norm : str
Expand Down Expand Up @@ -132,22 +196,6 @@ def __init__(
self.fill_value = fill_value

def get_raw_values(self, nodes: NodeStorage, **kwargs) -> np.ndarray:
"""Compute the area associated to each node.
It uses Voronoi diagrams to compute the area of each node.
Parameters
----------
nodes : NodeStorage
Nodes of the graph.
kwargs : dict
Additional keyword arguments.
Returns
-------
np.ndarray
Attributes.
"""
latitudes, longitudes = nodes.x[:, 0], nodes.x[:, 1]
points = latlon_rad_to_cartesian((np.asarray(latitudes), np.asarray(longitudes)))
sv = SphericalVoronoi(points, self.radius, self.centre)
Expand Down
36 changes: 35 additions & 1 deletion src/anemoi/graphs/nodes/builders/from_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,37 @@ def get_coordinates(self) -> torch.Tensor:
return self.reshape_coords(dataset.latitudes, dataset.longitudes)


class TextNodes(BaseNodeBuilder):
"""Nodes from text file.
Attributes
----------
dataset : str | DictConfig
The path to txt file containing the coordinates of the nodes.
idx_lon : int
The index of the longitude in the dataset.
idx_lat : int
The index of the latitude in the dataset.
"""

def __init__(self, dataset, name: str, idx_lon: int = 0, idx_lat: int = 1) -> None:
LOGGER.info("Reading the dataset from %s.", dataset)
self.dataset = np.loadtxt(dataset)
self.idx_lon = idx_lon
self.idx_lat = idx_lat
super().__init__(name)

def get_coordinates(self) -> torch.Tensor:
"""Get the coordinates of the nodes.
Returns
-------
torch.Tensor of shape (num_nodes, 2)
A 2D tensor with the coordinates, in radians.
"""
return self.reshape_coords(self.dataset[self.idx_lat, :], self.dataset[self.idx_lon, :])


class NPZFileNodes(BaseNodeBuilder):
"""Nodes from NPZ defined grids.
Expand Down Expand Up @@ -146,7 +177,10 @@ def get_coordinates(self) -> np.ndarray:
)
area_mask = self.area_mask_builder.get_mask(coords)

LOGGER.info("Dropping %d nodes from the processor mesh.", len(area_mask) - area_mask.sum())
LOGGER.info(
"Dropping %d nodes from the processor mesh.",
len(area_mask) - area_mask.sum(),
)
coords = coords[area_mask]

return coords

0 comments on commit 3898f6f

Please sign in to comment.