From da7684f4e2bfd704422a943c8c1c03c283ad31ae Mon Sep 17 00:00:00 2001 From: "Noah D. Brenowitz" Date: Thu, 26 Sep 2024 13:14:14 -0700 Subject: [PATCH] add earth2grid.healpix.ang2pix --- earth2grid/csrc/healpix_bare_wrapper.cpp | 15 ++++++++++ earth2grid/healpix.py | 3 +- earth2grid/healpix_bare.py | 36 +++++++++++++++++++++++- tests/test_healpix_bare.py | 18 ++++++++++++ 4 files changed, 70 insertions(+), 2 deletions(-) diff --git a/earth2grid/csrc/healpix_bare_wrapper.cpp b/earth2grid/csrc/healpix_bare_wrapper.cpp index 2cb0537..e05ff8d 100644 --- a/earth2grid/csrc/healpix_bare_wrapper.cpp +++ b/earth2grid/csrc/healpix_bare_wrapper.cpp @@ -93,6 +93,20 @@ torch::Tensor hpd2loc_wrapper(int nside, torch::Tensor input) { return output; } +torch::Tensor ang2ring_wrapper(int nside, torch::Tensor ang) { + auto ang_accessor = ang.accessor(); + auto output_options = torch::TensorOptions().dtype(torch::kInt64); + auto output = torch::empty({ang.size(0)}, output_options); + + auto output_accessor = output.accessor(); + + for (int64_t i = 0; i < ang.size(0); ++i) { + t_ang angi {ang_accessor[i][0], ang_accessor[i][1]}; + output_accessor[i] = ang2ring(nside, angi); + } + return output; +} + torch::Tensor hpc2loc_wrapper(torch::Tensor x, torch::Tensor y, torch::Tensor f) { auto accessor_f = f.accessor(); auto accessor_x = x.accessor(); @@ -223,6 +237,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("ring2hpd", wrap_2hpd(ring2hpd), "hpd is f, y ,x"); m.def("hpd2loc", &hpd2loc_wrapper, "loc is in z, s, phi"); m.def("hpc2loc", &hpc2loc_wrapper, "hpc2loc(x, y, f) -> z, s, phi"); + m.def("ang2ring", &ang2ring_wrapper, "ang2ring(nside, ang) -> pix"); m.def("corners", &corners, ""); m.def("get_interp_weights", &get_interp_weights, ""); }; diff --git a/earth2grid/healpix.py b/earth2grid/healpix.py index 543e823..4633ba7 100644 --- a/earth2grid/healpix.py +++ b/earth2grid/healpix.py @@ -44,6 +44,7 @@ from earth2grid import healpix_bare from earth2grid._regrid import Regridder +from earth2grid.healpix_bare import ang2pix try: import pyvista as pv @@ -67,7 +68,7 @@ except ImportError: cuhpx = None -__all__ = ["pad", "PixelOrder", "XY", "Compass", "Grid", "HEALPIX_PAD_XY", "conv2d", "reorder"] +__all__ = ["pad", "PixelOrder", "XY", "Compass", "Grid", "HEALPIX_PAD_XY", "conv2d", "reorder", "ang2pix"] def pad(x: torch.Tensor, padding: int) -> torch.Tensor: diff --git a/earth2grid/healpix_bare.py b/earth2grid/healpix_bare.py index 7ec3a3b..17588b8 100644 --- a/earth2grid/healpix_bare.py +++ b/earth2grid/healpix_bare.py @@ -15,10 +15,11 @@ import torch from earth2grid import _healpix_bare -from earth2grid._healpix_bare import corners, hpc2loc, hpd2loc, nest2hpd, nest2ring, ring2hpd, ring2nest +from earth2grid._healpix_bare import ang2ring, corners, hpc2loc, hpd2loc, nest2hpd, nest2ring, ring2hpd, ring2nest __all__ = [ "pix2ang", + "ang2pix", "ring2nest", "nest2ring", "hpc2loc", @@ -27,6 +28,12 @@ def pix2ang(nside, i, nest=False, lonlat=False): + """ + Returns: + theta, phi: (lon, lat) in degrees if lonlat=True else (colat, lon) in + radians + + """ if nest: hpd = nest2hpd(nside, i) else: @@ -40,6 +47,29 @@ def pix2ang(nside, i, nest=False, lonlat=False): return lat, lon +def ang2pix(nside, theta, phi, nest=False, lonlat=False): + """Find the pixel containing a given angular coordinate + + Args: + theta, phi: (lon, lat) in degrees if lonlat=True else (colat, lon) in + radians + + """ + if lonlat: + lon = theta + lat = phi + + theta = torch.deg2rad(90 - lat) + phi = torch.deg2rad(lon) + + ang = torch.stack([theta, phi], -1) + pix = ang2ring(nside, ang.double()) + if nest: + pix = ring2nest(nside, pix) + + return pix + + def _loc2ang(loc): """ static t_ang loc2ang(tloc loc) @@ -51,6 +81,10 @@ def _loc2ang(loc): return phi % (2 * torch.pi), torch.atan2(s, z) +def _ang2loc(lat, lon): + pass + + def loc2vec(loc): z = loc[..., 0] s = loc[..., 1] diff --git a/tests/test_healpix_bare.py b/tests/test_healpix_bare.py index c5b9f08..c581674 100644 --- a/tests/test_healpix_bare.py +++ b/tests/test_healpix_bare.py @@ -90,3 +90,21 @@ def test_get_interp_weights_vector_interp_y(): assert torch.all(pix == inpix[:, None]) expected_weights = torch.tensor([ay / 2, ay / 2, (1 - ay) / 2, (1 - ay) / 2]).double()[:, None] assert torch.allclose(weights, expected_weights) + + +@pytest.mark.parametrize("lonlat", [True, False]) +def test_ang2pix(lonlat): + if lonlat: + lon = torch.tensor([32.0]) + lat = torch.tensor([45.0]) + else: + lon = torch.tensor([1.0]) + lat = torch.tensor([2.0]) + + n = 2**16 + + pix = earth2grid.healpix_bare.ang2pix(n, lon, lat, lonlat=lonlat) + lon_out, lat_out = earth2grid.healpix_bare.pix2ang(n, pix, lonlat=lonlat) + + assert lon.item() == pytest.approx(lon_out.item(), rel=1e-4) + assert lat.item() == pytest.approx(lat_out.item(), rel=1e-4)