Skip to content

Commit

Permalink
add earth2grid.healpix.ang2pix
Browse files Browse the repository at this point in the history
  • Loading branch information
nbren12 committed Sep 26, 2024
1 parent c45702a commit da7684f
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 2 deletions.
15 changes: 15 additions & 0 deletions earth2grid/csrc/healpix_bare_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,20 @@ torch::Tensor hpd2loc_wrapper(int nside, torch::Tensor input) {
return output;
}

torch::Tensor ang2ring_wrapper(int nside, torch::Tensor ang) {
auto ang_accessor = ang.accessor<double, 2>();
auto output_options = torch::TensorOptions().dtype(torch::kInt64);
auto output = torch::empty({ang.size(0)}, output_options);

auto output_accessor = output.accessor<int64_t, 1>();

for (int64_t i = 0; i < ang.size(0); ++i) {
t_ang angi {ang_accessor[i][0], ang_accessor[i][1]};
output_accessor[i] = ang2ring(nside, angi);
}
return output;
}

torch::Tensor hpc2loc_wrapper(torch::Tensor x, torch::Tensor y, torch::Tensor f) {
auto accessor_f = f.accessor<int64_t, 1>();
auto accessor_x = x.accessor<double, 1>();
Expand Down Expand Up @@ -223,6 +237,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("ring2hpd", wrap_2hpd(ring2hpd), "hpd is f, y ,x");
m.def("hpd2loc", &hpd2loc_wrapper, "loc is in z, s, phi");
m.def("hpc2loc", &hpc2loc_wrapper, "hpc2loc(x, y, f) -> z, s, phi");
m.def("ang2ring", &ang2ring_wrapper, "ang2ring(nside, ang) -> pix");
m.def("corners", &corners, "");
m.def("get_interp_weights", &get_interp_weights, "");
};
3 changes: 2 additions & 1 deletion earth2grid/healpix.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@

from earth2grid import healpix_bare
from earth2grid._regrid import Regridder
from earth2grid.healpix_bare import ang2pix

try:
import pyvista as pv
Expand All @@ -67,7 +68,7 @@
except ImportError:
cuhpx = None

__all__ = ["pad", "PixelOrder", "XY", "Compass", "Grid", "HEALPIX_PAD_XY", "conv2d", "reorder"]
__all__ = ["pad", "PixelOrder", "XY", "Compass", "Grid", "HEALPIX_PAD_XY", "conv2d", "reorder", "ang2pix"]


def pad(x: torch.Tensor, padding: int) -> torch.Tensor:
Expand Down
36 changes: 35 additions & 1 deletion earth2grid/healpix_bare.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
import torch

from earth2grid import _healpix_bare
from earth2grid._healpix_bare import corners, hpc2loc, hpd2loc, nest2hpd, nest2ring, ring2hpd, ring2nest
from earth2grid._healpix_bare import ang2ring, corners, hpc2loc, hpd2loc, nest2hpd, nest2ring, ring2hpd, ring2nest

__all__ = [
"pix2ang",
"ang2pix",
"ring2nest",
"nest2ring",
"hpc2loc",
Expand All @@ -27,6 +28,12 @@


def pix2ang(nside, i, nest=False, lonlat=False):
"""
Returns:
theta, phi: (lon, lat) in degrees if lonlat=True else (colat, lon) in
radians
"""
if nest:
hpd = nest2hpd(nside, i)
else:
Expand All @@ -40,6 +47,29 @@ def pix2ang(nside, i, nest=False, lonlat=False):
return lat, lon


def ang2pix(nside, theta, phi, nest=False, lonlat=False):
"""Find the pixel containing a given angular coordinate
Args:
theta, phi: (lon, lat) in degrees if lonlat=True else (colat, lon) in
radians
"""
if lonlat:
lon = theta
lat = phi

theta = torch.deg2rad(90 - lat)
phi = torch.deg2rad(lon)

ang = torch.stack([theta, phi], -1)
pix = ang2ring(nside, ang.double())
if nest:
pix = ring2nest(nside, pix)

return pix


def _loc2ang(loc):
"""
static t_ang loc2ang(tloc loc)
Expand All @@ -51,6 +81,10 @@ def _loc2ang(loc):
return phi % (2 * torch.pi), torch.atan2(s, z)


def _ang2loc(lat, lon):
pass


def loc2vec(loc):
z = loc[..., 0]
s = loc[..., 1]
Expand Down
18 changes: 18 additions & 0 deletions tests/test_healpix_bare.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,21 @@ def test_get_interp_weights_vector_interp_y():
assert torch.all(pix == inpix[:, None])
expected_weights = torch.tensor([ay / 2, ay / 2, (1 - ay) / 2, (1 - ay) / 2]).double()[:, None]
assert torch.allclose(weights, expected_weights)


@pytest.mark.parametrize("lonlat", [True, False])
def test_ang2pix(lonlat):
if lonlat:
lon = torch.tensor([32.0])
lat = torch.tensor([45.0])
else:
lon = torch.tensor([1.0])
lat = torch.tensor([2.0])

n = 2**16

pix = earth2grid.healpix_bare.ang2pix(n, lon, lat, lonlat=lonlat)
lon_out, lat_out = earth2grid.healpix_bare.pix2ang(n, pix, lonlat=lonlat)

assert lon.item() == pytest.approx(lon_out.item(), rel=1e-4)
assert lat.item() == pytest.approx(lat_out.item(), rel=1e-4)

0 comments on commit da7684f

Please sign in to comment.