Skip to content

Commit

Permalink
Merge pull request #19 from simonbyrne/sb/lcc
Browse files Browse the repository at this point in the history
add HRRR CONUS grid
  • Loading branch information
nbren12 authored Sep 30, 2024
2 parents 1ad2a2e + 643cedc commit 6197c98
Show file tree
Hide file tree
Showing 4 changed files with 275 additions and 5 deletions.
5 changes: 4 additions & 1 deletion earth2grid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@
# limitations under the License.
import torch

from earth2grid import base, healpix, latlon
from earth2grid import base, healpix, latlon, lcc
from earth2grid._regrid import BilinearInterpolator, Identity, KNNS2Interpolator, Regridder

__all__ = [
"base",
"healpix",
"latlon",
"lcc",
"get_regridder",
"BilinearInterpolator",
"KNNS2Interpolator",
Expand All @@ -36,6 +37,8 @@ def get_regridder(src: base.Grid, dest: base.Grid) -> torch.nn.Module:
return src.get_bilinear_regridder_to(dest.lat, dest.lon)
elif isinstance(src, latlon.LatLonGrid) and isinstance(dest, healpix.Grid):
return src.get_bilinear_regridder_to(dest.lat, dest.lon)
elif isinstance(src, lcc.LambertConformalConicGrid):
return src.get_bilinear_regridder_to(dest.lat, dest.lon)
elif isinstance(src, healpix.Grid):
return src.get_bilinear_regridder_to(dest.lat, dest.lon)
elif isinstance(dest, healpix.Grid):
Expand Down
8 changes: 4 additions & 4 deletions earth2grid/_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def forward(self, z):
weight = self.weight.view(-1, p)

# using embedding bag is 2x faster on cpu and 4x on gpu.
output = torch.nn.functional.embedding_bag(index, zrs, per_sample_weights=weight, mode='sum')
output = torch.nn.functional.embedding_bag(index, zrs, per_sample_weights=weight, mode="sum")
output = output.T.view(*shape, -1)
return output.reshape(list(shape) + output_shape)

Expand Down Expand Up @@ -173,12 +173,12 @@ def forward(self, z: torch.Tensor):
*shape, y, x = z.shape
zrs = z.view(-1, y * x).T
# using embedding bag is 2x faster on cpu and 4x on gpu.
output = torch.nn.functional.embedding_bag(self.index, zrs, per_sample_weights=self.weights, mode='sum')
output = torch.nn.functional.embedding_bag(self.index, zrs, per_sample_weights=self.weights, mode="sum")
interpolated = torch.full(
[self.mask.numel(), zrs.shape[1]], fill_value=self.fill_value, dtype=z.dtype, device=z.device
)
interpolated.masked_scatter_(self.mask.unsqueeze(-1), output)
interpolated = interpolated.T.view(*shape, self.mask.numel())
interpolated.masked_scatter_(self.mask.view(-1, 1), output)
interpolated = interpolated.T.view(*shape, *self.mask.shape)
return interpolated


Expand Down
190 changes: 190 additions & 0 deletions earth2grid/lcc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch

from earth2grid import base
from earth2grid._regrid import BilinearInterpolator

try:
import pyvista as pv
except ImportError:
pv = None

__all__ = [
"LambertConformalConicProjection",
"LambertConformalConicGrid",
"HRRR_CONUS_PROJECTION",
"HRRR_CONUS_GRID",
]


class LambertConformalConicProjection:
def __init__(self, lat0: float, lon0: float, lat1: float, lat2: float, radius: float):
"""
Args:
lat0: latitude of origin (degrees)
lon0: longitude of origin (degrees)
lat1: first standard parallel (degrees)
lat2: second standard parallel (degrees)
radius: radius of sphere (m)
"""

self.lon0 = lon0
self.lat0 = lat0
self.lat1 = lat1
self.lat2 = lat2
self.radius = radius

c1 = np.cos(np.deg2rad(lat1))
c2 = np.cos(np.deg2rad(lat2))
t1 = np.tan(np.pi / 4 + np.deg2rad(lat1) / 2)
t2 = np.tan(np.pi / 4 + np.deg2rad(lat2) / 2)

if np.abs(lat1 - lat2) < 1e-8:
self.n = np.sin(np.deg2rad(lat1))
else:
self.n = np.log(c1 / c2) / np.log(t2 / t1)

self.RF = radius * c1 * np.power(t1, self.n) / self.n
self.rho0 = self._rho(lat0)

def _rho(self, lat):
return self.RF / np.power(np.tan(np.pi / 4 + np.deg2rad(lat) / 2), self.n)

def _theta(self, lon):
"""
Angle of deviation (in radians) of the projected grid from the regular grid,
for a given longitude (in degrees).
To convert to U and V on the projected grid to easterly / northerly components:
UN = cos(theta) * U + sin(theta) * V
VN = - sin(theta) * U + cos(theta) * V
"""
# center about reference longitude
delta_lon = lon - self.lon0
delta_lon = delta_lon - np.round(delta_lon / 360) * 360 # convert to [-180, 180]
return self.n * np.deg2rad(delta_lon)

def project(self, lat, lon):
"""
Compute the projected x,y from lat,lon.
"""
rho = self._rho(lat)
theta = self._theta(lon)

x = rho * np.sin(theta)
y = self.rho0 - rho * np.cos(theta)
return x, y

def inverse_project(self, x, y):
"""
Compute the lat,lon from the projected x,y.
"""
rho = np.hypot(x, self.rho0 - y)
theta = np.arctan2(x, self.rho0 - y)

lat = np.rad2deg(2 * np.arctan(np.power(self.RF / rho, 1 / self.n))) - 90
lon = self.lon0 + np.rad2deg(theta / self.n)
return lat, lon


# Projection used by HRRR CONUS (Continental US) data
# https://rapidrefresh.noaa.gov/hrrr/HRRR_conus.domain.txt
HRRR_CONUS_PROJECTION = LambertConformalConicProjection(lon0=-97.5, lat0=38.5, lat1=38.5, lat2=38.5, radius=6371229.0)


class LambertConformalConicGrid(base.Grid):
# nothing here is specific to the projection, so could be shared by any projected rectilinear grid
def __init__(self, projection: LambertConformalConicProjection, x, y):
"""
Args:
projection: LambertConformalConicProjection object
x: range of x values
y: range of y values
"""
self.projection = projection

self.x = np.array(x)
self.y = np.array(y)

@property
def lat_lon(self):
mesh_x, mesh_y = np.meshgrid(self.x, self.y)
return self.projection.inverse_project(mesh_x, mesh_y)

@property
def lat(self):
return self.lat_lon[0]

@property
def lon(self):
return self.lat_lon[1]

@property
def shape(self):
return (len(self.y), len(self.x))

def __getitem__(self, idxs):
yidxs, xidxs = idxs
return LambertConformalConicGrid(self.projection, x=self.x[xidxs], y=self.y[yidxs])

def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray):
"""Get regridder to the specified lat and lon points"""

x, y = self.projection.project(lat, lon)

return BilinearInterpolator(
x_coords=torch.from_numpy(self.x),
y_coords=torch.from_numpy(self.y),
x_query=torch.from_numpy(x),
y_query=torch.from_numpy(y),
)

def visualize(self, data):
raise NotImplementedError()

def to_pyvista(self):
if pv is None:
raise ImportError("Need to install pyvista")

lat, lon = self.lat_lon
y = np.cos(np.deg2rad(lat)) * np.sin(np.deg2rad(lon))
x = np.cos(np.deg2rad(lat)) * np.cos(np.deg2rad(lon))
z = np.sin(np.deg2rad(lat))
grid = pv.StructuredGrid(x, y, z)
return grid


def hrrr_conus_grid(ix0=0, iy0=0, nx=1799, ny=1059):
# coordinates of point in top-left corner
lat0 = 21.138123
lon0 = 237.280472
# grid length (m)
scale = 3000.0
# coordinates on projected space
x0, y0 = HRRR_CONUS_PROJECTION.project(lat0, lon0)

x = [x0 + i * scale for i in range(ix0, ix0 + nx)]
y = [y0 + i * scale for i in range(iy0, iy0 + ny)]

return LambertConformalConicGrid(HRRR_CONUS_PROJECTION, x, y)


# Grid used by HRRR CONUS (Continental US) data
HRRR_CONUS_GRID = hrrr_conus_grid()
77 changes: 77 additions & 0 deletions tests/test_lcc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# %%
import numpy as np
import pytest
import torch

from earth2grid.lcc import HRRR_CONUS_GRID


def test_grid_shape():
assert HRRR_CONUS_GRID.lat.shape == HRRR_CONUS_GRID.shape
assert HRRR_CONUS_GRID.lon.shape == HRRR_CONUS_GRID.shape


lats = np.array(
[
[21.138123, 21.801926, 22.393631, 22.911015],
[23.636763, 24.328228, 24.944668, 25.48374],
[26.155672, 26.875362, 27.517046, 28.078257],
[28.69017, 29.438608, 30.106009, 30.68978],
]
)

lons = np.array(
[
[-122.71953, -120.03195, -117.304596, -114.54146],
[-123.491356, -120.72898, -117.92319, -115.07828],
[-124.310524, -121.469505, -118.58098, -115.649574],
[-125.181404, -122.25762, -119.28173, -116.25871],
]
)


def test_grid_vals():
assert HRRR_CONUS_GRID.lat[0:400:100, 0:400:100] == pytest.approx(lats)
assert HRRR_CONUS_GRID.lon[0:400:100, 0:400:100] == pytest.approx(lons)


def test_grid_slice():
slice_grid = HRRR_CONUS_GRID[0:400:100, 0:400:100]
assert slice_grid.lat == pytest.approx(lats)
assert slice_grid.lon == pytest.approx(lons)


def test_regrid_1d():
src = HRRR_CONUS_GRID
dest_lat = np.linspace(25.0, 33.0, 10)
dest_lon = np.linspace(-123, -98, 10)
regrid = src.get_bilinear_regridder_to(dest_lat, dest_lon)
src_lat = torch.broadcast_to(torch.tensor(src.lat), src.shape)
out_lat = regrid(src_lat)

assert torch.allclose(out_lat, torch.tensor(dest_lat))


def test_regrid_2d():
src = HRRR_CONUS_GRID
dest_lat, dest_lon = np.meshgrid(np.linspace(25.0, 33.0, 10), np.linspace(-123, -98, 12))
regrid = src.get_bilinear_regridder_to(dest_lat, dest_lon)
src_lat = torch.broadcast_to(torch.tensor(src.lat), src.shape)
out_lat = regrid(src_lat)

assert torch.allclose(out_lat, torch.tensor(dest_lat))

0 comments on commit 6197c98

Please sign in to comment.