Skip to content

Commit

Permalink
Merge pull request #12 from NVlabs/nearest
Browse files Browse the repository at this point in the history
Nearest neighbor interpolator
  • Loading branch information
nbren12 authored Sep 16, 2024
2 parents fa2d3ac + 9cb3cd7 commit 6c6bba8
Show file tree
Hide file tree
Showing 9 changed files with 303 additions and 143 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Changelog

## latest

- `earth2grid.latlon.BilinearInterpolator` moved to `earth2grid.BilinearInterpolator`

## 2024.8.1

Expand Down
5 changes: 5 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@ Regridding

.. autofunction:: earth2grid.get_regridder

.. autofunction:: earth2grid.KNNS2Interpolator

.. autofunction:: earth2grid.BilinearInterpolator

Other utilities
---------------

.. autofunction:: earth2grid.healpix.reorder
.. autofunction:: earth2grid.healpix.pad
30 changes: 28 additions & 2 deletions earth2grid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,33 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from earth2grid import base, healpix, latlon
from earth2grid._regrid import get_regridder
from earth2grid._regrid import BilinearInterpolator, Identity, KNNS2Interpolator, Regridder

__all__ = [
"base",
"healpix",
"latlon",
"get_regridder",
"BilinearInterpolator",
"KNNS2Interpolator",
"Regridder",
]


def get_regridder(src: base.Grid, dest: base.Grid) -> torch.nn.Module:
"""Get a regridder from `src` to `dest`"""
if src == dest:
return Identity()
elif isinstance(src, latlon.LatLonGrid) and isinstance(dest, latlon.LatLonGrid):
return src.get_bilinear_regridder_to(dest.lat, dest.lon)
elif isinstance(src, latlon.LatLonGrid) and isinstance(dest, healpix.Grid):
return src.get_bilinear_regridder_to(dest.lat, dest.lon)
elif isinstance(src, healpix.Grid):
return src.get_bilinear_regridder_to(dest.lat, dest.lon)
elif isinstance(dest, healpix.Grid):
return src.get_healpix_regridder(dest) # type: ignore

__all__ = ["base", "healpix", "latlon", "get_regridder"]
raise ValueError(src, dest, "not supported.")
200 changes: 182 additions & 18 deletions earth2grid/_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,48 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Dict, Sequence

import einops
import netCDF4 as nc
import torch
from scipy import spatial

from earth2grid.spatial import ang2vec, haversine_distance


class Regridder(torch.nn.Module):
"""Regridder to n points, with p nonzero maps weights
Forward:
(*, m) -> (*,) + shape
"""

def __init__(self, shape: Sequence[int], p: int):
super().__init__()
self.register_buffer("index", torch.empty(*shape, p, dtype=torch.long))
self.register_buffer("weight", torch.ones(*shape, p))

def forward(self, z):
*shape, x = z.shape
zrs = z.view(-1, x).T

*output_shape, p = self.index.shape
index = self.index.view(-1, p)
weight = self.weight.view(-1, p)

from earth2grid import base, healpix
from earth2grid.latlon import LatLonGrid
# using embedding bag is 2x faster on cpu and 4x on gpu.
output = torch.nn.functional.embedding_bag(index, zrs, per_sample_weights=weight, mode='sum')
output = output.T.view(*shape, -1)
return output.reshape(list(shape) + output_shape)

@staticmethod
def from_state_dict(d: Dict[str, torch.Tensor]) -> "Regridder":
n, p = d["index"].shape
regridder = Regridder((n,), p)
regridder.load_state_dict(d)
return regridder


class TempestRegridder(torch.nn.Module):
Expand Down Expand Up @@ -48,22 +84,150 @@ def forward(self, x):
return y


class BilinearInterpolator(torch.nn.Module):
"""Bilinear interpolation for a non-uniform grid"""

def __init__(
self,
x_coords: torch.Tensor,
y_coords: torch.Tensor,
x_query: torch.Tensor,
y_query: torch.Tensor,
fill_value=math.nan,
) -> None:
"""
Args:
x_coords (Tensor): X-coordinates of the input grid, shape [W]. Must be in increasing sorted order.
y_coords (Tensor): Y-coordinates of the input grid, shape [H]. Must be in increasing sorted order.
x_query (Tensor): X-coordinates for query points, shape [N].
y_query (Tensor): Y-coordinates for query points, shape [N].
"""
super().__init__()
self.fill_value = fill_value

# Ensure input coordinates are float for interpolation
x_coords, y_coords = x_coords.double(), y_coords.double()
x_query = x_query.double()
y_query = y_query.double()

if torch.any(x_coords[1:] < x_coords[:-1]):
raise ValueError("x_coords must be in non-decreasing order.")

if torch.any(y_coords[1:] < y_coords[:-1]):
raise ValueError("y_coords must be in non-decreasing order.")

# Find indices for the closest lower and upper bounds in x and y directions
x_l_idx = torch.searchsorted(x_coords, x_query, right=True) - 1
x_u_idx = x_l_idx + 1
y_l_idx = torch.searchsorted(y_coords, y_query, right=True) - 1
y_u_idx = y_l_idx + 1

# fill in nan outside mask
def isin(x, a, b):
return (x <= b) & (x >= a)

mask = (
isin(x_l_idx, 0, x_coords.size(0) - 2)
& isin(x_u_idx, 1, x_coords.size(0) - 1)
& isin(y_l_idx, 0, y_coords.size(0) - 2)
& isin(y_u_idx, 1, y_coords.size(0) - 1)
)
x_u_idx = x_u_idx[mask]
x_l_idx = x_l_idx[mask]
y_u_idx = y_u_idx[mask]
y_l_idx = y_l_idx[mask]
x_query = x_query[mask]
y_query = y_query[mask]

# Compute weights
x_l_weight = (x_coords[x_u_idx] - x_query) / (x_coords[x_u_idx] - x_coords[x_l_idx])
x_u_weight = (x_query - x_coords[x_l_idx]) / (x_coords[x_u_idx] - x_coords[x_l_idx])
y_l_weight = (y_coords[y_u_idx] - y_query) / (y_coords[y_u_idx] - y_coords[y_l_idx])
y_u_weight = (y_query - y_coords[y_l_idx]) / (y_coords[y_u_idx] - y_coords[y_l_idx])
weights = torch.stack(
[x_l_weight * y_l_weight, x_u_weight * y_l_weight, x_l_weight * y_u_weight, x_u_weight * y_u_weight], dim=-1
)

stride = x_coords.size(-1)
index = torch.stack(
[
x_l_idx + stride * y_l_idx,
x_u_idx + stride * y_l_idx,
x_l_idx + stride * y_u_idx,
x_u_idx + stride * y_u_idx,
],
dim=-1,
)
self.register_buffer("weights", weights)
self.register_buffer("mask", mask)
self.register_buffer("index", index)

def forward(self, z: torch.Tensor):
"""
Interpolate the field
Args:
z: shape [Y, X]
"""
*shape, y, x = z.shape
zrs = z.view(-1, y * x).T
# using embedding bag is 2x faster on cpu and 4x on gpu.
output = torch.nn.functional.embedding_bag(self.index, zrs, per_sample_weights=self.weights, mode='sum')
interpolated = torch.full(
[self.mask.numel(), zrs.shape[1]], fill_value=self.fill_value, dtype=z.dtype, device=z.device
)
interpolated.masked_scatter_(self.mask.unsqueeze(-1), output)
interpolated = interpolated.T.view(*shape, self.mask.numel())
return interpolated


def KNNS2Interpolator(
src_lon: torch.Tensor,
src_lat: torch.Tensor,
dest_lon: torch.Tensor,
dest_lat: torch.Tensor,
k: int = 1,
eps=1e-7,
) -> Regridder:
"""K-nearest neighbor interpolator with inverse distance weighting
Args:
src_lon: (m,) source longitude in degrees E
src_lat: (m,) source latitude in degrees N
dest_lon: (n,) output longitude in degrees E
dest_lat: (n,) output latitude in degrees N
k: number of neighbors, default: 1
eps: regularization factor for inverse distance weighting. Only used if
k > 1.
"""
if (src_lat.ndim != 1) or (src_lon.ndim != 1) or (dest_lat.ndim != 1) or (dest_lon.ndim != 1):
raise ValueError("All input coordinates must be 1 dimensional.")

src_lon = torch.deg2rad(src_lon.cpu())
src_lat = torch.deg2rad(src_lat.cpu())

dest_lon = torch.deg2rad(dest_lon.cpu())
dest_lat = torch.deg2rad(dest_lat.cpu())

vec = torch.stack(ang2vec(src_lon, src_lat), -1)

# havesign distance and euclidean are monotone for points on S2 so can use 3d lookups.
tree = spatial.KDTree(vec)
vec = torch.stack(ang2vec(dest_lon.cpu(), dest_lat.cpu()), -1)
_, neighbors = tree.query(vec, k=k)
regridder = Regridder(dest_lon.shape, k)
regridder.index.copy_(torch.as_tensor(neighbors).view(-1, k))
if k > 1:
d = haversine_distance(dest_lon[:, None], dest_lat[:, None], src_lon[neighbors], src_lat[neighbors])
lam = 1 / (d + eps)
lam = lam / lam.sum(-1, keepdim=True)
regridder.weight.copy_(lam)

return regridder


class Identity(torch.nn.Module):
def forward(self, x):
return x


def get_regridder(src: base.Grid, dest: base.Grid) -> torch.nn.Module:
"""Get a regridder from `src` to `dest`"""
if src == dest:
return Identity()
elif isinstance(src, LatLonGrid) and isinstance(dest, LatLonGrid):
return src.get_bilinear_regridder_to(dest.lat, dest.lon)
elif isinstance(src, LatLonGrid) and isinstance(dest, healpix.Grid):
return src.get_bilinear_regridder_to(dest.lat, dest.lon)
elif isinstance(src, healpix.Grid):
return src.get_bilinear_regridder_to(dest.lat, dest.lon)
elif isinstance(dest, healpix.Grid):
return src.get_healpix_regridder(dest) # type: ignore

raise ValueError(src, dest, "not supported.")
32 changes: 11 additions & 21 deletions earth2grid/healpix.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import torch

from earth2grid import healpix_bare
from earth2grid._regrid import Regridder

try:
import pyvista as pv
Expand Down Expand Up @@ -230,26 +231,6 @@ def _convert_xyindex(nside: int, src: XY, dest: XY, i):
return i


class ApplyWeights(torch.nn.Module):
def __init__(self, pix: torch.Tensor, weight: torch.Tensor):
super().__init__()

# the first dim is the 4 point stencil
n, *self.shape = pix.shape

pix = pix.view(n, -1).T
weight = weight.view(n, -1).T

self.register_buffer("index", pix)
self.register_buffer("weight", weight)

def forward(self, x: torch.Tensor) -> torch.Tensor:
*shape, npix = x.shape
x = x.view(-1, npix).T
interpolated = torch.nn.functional.embedding_bag(self.index, x, per_sample_weights=self.weight, mode="sum").T
return interpolated.view(shape + self.shape)


@dataclass
class Grid(base.Grid):
"""A Healpix Grid
Expand Down Expand Up @@ -345,7 +326,16 @@ def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray):
i_ring, weights = healpix_bare.get_interp_weights(self._nside(), torch.tensor(lon), torch.tensor(lat))
i_nest = healpix_bare.ring2nest(self._nside(), i_ring.ravel())
i_me = self._nest2me(i_nest).reshape(i_ring.shape)
return ApplyWeights(i_me, weights)

# reshape to (*, p)
weights = weights.movedim(0, -1)
index = i_me.movedim(0, -1)

regridder = Regridder(weights.shape[:-1], p=weights.shape[-1])
regridder.to(weights)
regridder.index.copy_(index)
regridder.weight.copy_(weights)
return regridder

def approximate_grid_length_meters(self):
return approx_grid_length_meters(self._nside())
Expand Down
Loading

0 comments on commit 6c6bba8

Please sign in to comment.