From e2f2392f6c5d79001a0f5cbbc7cad3c6f374b0c2 Mon Sep 17 00:00:00 2001 From: "Noah D. Brenowitz" Date: Mon, 16 Sep 2024 12:43:28 -0700 Subject: [PATCH] respond to reviews --- earth2grid/__init__.py | 4 ++-- earth2grid/_regrid.py | 5 ++++- tests/test_regrid.py | 13 ++++++++++--- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/earth2grid/__init__.py b/earth2grid/__init__.py index 2747948..6d52f27 100644 --- a/earth2grid/__init__.py +++ b/earth2grid/__init__.py @@ -15,7 +15,7 @@ import torch from earth2grid import base, healpix, latlon -from earth2grid._regrid import BilinearInterpolator, Identity, Regridder, S2NearestNeighborInterpolator +from earth2grid._regrid import BilinearInterpolator, Identity, KNNS2Interpolator, Regridder __all__ = [ "base", @@ -23,7 +23,7 @@ "latlon", "get_regridder", "BilinearInterpolator", - "S2NearestNeighborInterpolator", + "KNNS2Interpolator", "Regridder", ] diff --git a/earth2grid/_regrid.py b/earth2grid/_regrid.py index 8eb23fc..5e9e5e8 100644 --- a/earth2grid/_regrid.py +++ b/earth2grid/_regrid.py @@ -182,7 +182,7 @@ def forward(self, z: torch.Tensor): return interpolated -def S2NearestNeighborInterpolator( +def KNNS2Interpolator( src_lon: torch.Tensor, src_lat: torch.Tensor, dest_lon: torch.Tensor, @@ -202,6 +202,9 @@ def S2NearestNeighborInterpolator( k > 1. """ + if (src_lat.ndim != 1) or (src_lon.ndim != 1) or (dest_lat.ndim != 1) or (dest_lon.ndim != 1): + raise ValueError("All input coordinates must be 1 dimensional.") + src_lon = torch.deg2rad(src_lon.cpu()) src_lat = torch.deg2rad(src_lat.cpu()) diff --git a/tests/test_regrid.py b/tests/test_regrid.py index 6d3004a..73a6b9e 100644 --- a/tests/test_regrid.py +++ b/tests/test_regrid.py @@ -200,14 +200,15 @@ def test_out_of_bounds(): @pytest.mark.parametrize("k", [1, 2, 3]) def test_NearestNeighborInterpolator(k): n = 10000 + m = 887 torch.manual_seed(0) lon = torch.rand(n) * 360 lat = torch.rand(n) * 180 - 90 - lond = torch.rand(n) * 360 - latd = torch.rand(n) * 180 - 90 + lond = torch.rand(m) * 360 + latd = torch.rand(m) * 180 - 90 - interpolate = earth2grid.S2NearestNeighborInterpolator(lon, lat, lond, latd, k=k) + interpolate = earth2grid.KNNS2Interpolator(lon, lat, lond, latd, k=k) out = interpolate(torch.cos(torch.deg2rad(lon))) expected = torch.cos(torch.deg2rad(lond)) mae = torch.mean(torch.abs(out - expected)) @@ -215,3 +216,9 @@ def test_NearestNeighborInterpolator(k): # load-reload earth2grid.Regridder.from_state_dict(interpolate.state_dict()) + + # try batched interpolation + x = torch.cos(torch.deg2rad(lon)) + x = x.unsqueeze(0) + out = interpolate(x) + assert out.shape == (1, m)