Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue33 iceage #34

Merged
merged 12 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 40 additions & 20 deletions geodataset/custom_geodataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,38 +38,58 @@ class CmemsMetIceChart(CustomDatasetRead):
class Dist2Coast(CustomDatasetRead):
    """Reader for the ``dist2coast_4deg.nc`` file (see ``pattern``).

    NOTE(review): ``pattern`` is presumably matched against the file name
    by the base class - confirm against CustomDatasetRead.
    """

    pattern = re.compile(r'dist2coast_4deg.nc')
    # names of the longitude and latitude variables in the file
    lonlat_names = 'lon', 'lat'

    def get_lonlat_arrays(self):
        """Return lon, lat grids made by meshing the ``lon`` and ``lat`` variables.

        Returns
        -------
        lon, lat
            output of ``np.meshgrid(lon, lat)``
        """
        lon_axis = self['lon'][:]
        lat_axis = self['lat'][:]
        return np.meshgrid(lon_axis, lat_axis)


class Etopo(CustomDatasetRead):
    """Reader for files named like ``ETOPO_Arctic_<1-2 digits>arcmin.nc``.

    NOTE(review): ``pattern`` is presumably matched against the file name
    by the base class - confirm against CustomDatasetRead.
    """

    pattern = re.compile(r'ETOPO_Arctic_\d{1,2}arcmin.nc')

    def get_lonlat_arrays(self):
        """Return lon, lat grids made by meshing the ``lon`` and ``lat`` variables.

        Returns
        -------
        lon, lat
            output of ``np.meshgrid(lon, lat)``
        """
        lon_axis = self['lon'][:]
        lat_axis = self['lat'][:]
        return np.meshgrid(lon_axis, lat_axis)


class JaxaAmsr2IceConc(CustomDatasetRead):
    """Reader for files named like ``Arc_<8 digits>_res3.125_pyres.nc``.

    NOTE(review): ``pattern`` is presumably matched against the file name
    by the base class - confirm against CustomDatasetRead.
    """

    pattern = re.compile(r'Arc_\d{8}_res3.125_pyres.nc')
    # names of the longitude and latitude variables in the file
    lonlat_names = ('longitude', 'latitude')
    # CRS built from EPSG code 3411, paired with the marker string 'absent'
    # NOTE(review): 'absent' presumably tells the base class that the file
    # carries no grid-mapping variable - confirm
    grid_mapping = (pyproj.CRS.from_epsg(3411), 'absent')


class NerscSarProducts(CustomDatasetRead):
class NERSCProductBase(CustomDatasetRead):
lonlat_names = 'absent', 'absent'
def get_lonlat_arrays(self):
x_grd, y_grd = np.meshgrid(self['x'][:], self['y'][:])

def get_lonlat_arrays(self, ij_range=(None,None,None,None), **kwargs):
    """
    Return lon, lat as 2D arrays computed from the ``x`` and ``y`` variables

    The (optionally subsetted) ``x`` and ``y`` variables are meshed into
    2D grids and passed to ``self.projection`` with ``inverse=True``.

    Parameters
    ----------
    ij_range : tuple(int)
        - (i0, i1, j0, j1)
        - pixel indices for subsetting: ``y`` is sliced with [i0:i1]
          and ``x`` with [j0:j1], so the result corresponds to
          lon[i0:i1,j0:j1], lat[i0:i1,j0:j1] instead of the full arrays
        - the default (all None) selects everything
    kwargs : dict
        accepted but not used by this method

    Returns
    -------
    lon : numpy.ndarray
        2D array with longitudes of pixel centers
    lat : numpy.ndarray
        2D array with latitudes of pixel centers
    """
    row_start, row_stop, col_start, col_stop = ij_range
    # read x first, then y
    x_axis = self['x'][col_start:col_stop]
    y_axis = self['y'][row_start:row_stop]
    x_grd, y_grd = np.meshgrid(x_axis, y_axis)
    # inverse=True requests the x/y -> lon/lat direction of the projection
    return self.projection(x_grd, y_grd, inverse=True)


class NerscDeformation(NerscSarProducts):

class NERSCDeformation(NERSCProductBase):
pattern = re.compile(r'arctic_2km_deformation_\d{8}T\d{6}.nc')


class NerscIceType(NerscSarProducts):
class NERSCIceType(NERSCProductBase):
pattern = re.compile(r'arctic_2km_icetype_\d{8}T\d{6}.nc')


class NERSCSeaIceAge(NERSCProductBase):
    """NERSC product reader for files named like
    ``arctic25km_sea_ice_age_v2p1_<8 digits>.nc``."""
    # NOTE: the '.' before 'nc' is unescaped, so it matches any single character
    pattern = re.compile(r'arctic25km_sea_ice_age_v2p1_\d{8}.nc')


class OsisafDriftersNextsim(CustomDatasetRead):
pattern = re.compile(r'OSISAF_Drifters_.*.nc')
grid_mapping = pyproj.CRS.from_proj4(
Expand All @@ -88,22 +108,23 @@ class UniBremenAlbedoMPF(CustomDatasetRead):
grid_mapping = (pyproj.CRS.from_proj4(
'+proj=stere +lat_0=90 +lat_ts=70 +lon_0=-45 +x_0=0 +y_0=0 '
'+ellps=WGS84 +units=m +no_defs'), 'absent')
pattern = re.compile(r'mpd_\d{8}.nc|mpd_\d{8}_NR.nc') # after 2020, filenames have _NR suffix
pattern = re.compile(r'mpd1_\d{8}.nc')

@staticmethod
def get_xy_arrays(ij_range=None):
def get_xy_arrays(ij_range=(None,None,None,None), **kwargs):
"""
Grid info from
https://nsidc.org/data/polar-stereo/ps_grids.html
see table 6

Parameters:
-----------
ij_range : list(int)
ij_range : tuple(int)
- [i0, i1, j0, j1]
- pixel indices for subsetting
- return x[i0:i1+1,j0:j1+1], y[i0:i1+1,j0:j1+1]
- return x[i0:i1,j0:j1], y[i0:i1,j0:j1]
instead of full arrays
dummy kwargs

Returns:
--------
Expand All @@ -128,10 +149,8 @@ def get_xy_arrays(ij_range=None):
.5e3 * (qx[:-1] + qx[1:]),
.5e3 * (qy[:-1] + qy[1:]),
)
if ij_range is not None:
i0, i1, j0, j1 = ij_range
return px[i0:i1+1,j0:j1+1], py[i0:i1+1,j0:j1+1]
return px, py
i0, i1, j0, j1 = ij_range
return px[i0:i1,j0:j1], py[i0:i1,j0:j1]

def get_lonlat_arrays(self, **kwargs):
"""
Expand Down Expand Up @@ -160,4 +179,5 @@ def datetimes(self):
all the time values converted to datetime objects
"""
bname = os.path.basename(self.filepath())
return [dt.datetime.strptime(bname[4:12], '%Y%m%d')]
datestr = bname.split('_')[1][:8]
return [dt.datetime.strptime(datestr, '%Y%m%d') + dt.timedelta(hours=12)]
87 changes: 59 additions & 28 deletions geodataset/geodataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from functools import cached_property

from netCDF4 import Dataset
import netcdftime
from netcdftime import num2date
import numpy as np
import pyproj
from pyproj.exceptions import CRSError
Expand Down Expand Up @@ -57,7 +57,8 @@ def convert_time_data(self, tdata):
atts = vars(self.variables[self.time_name])
cal = atts.get('calendar', 'standard')
units = atts['units']
datetimes = [netcdftime.num2date(t, units, calendar=cal) for t in tdata]
datetimes = [num2date(t, units, calendar=cal)
for t in tdata.flatten()]
return np.array(datetimes).reshape(tdata.shape)

@cached_property
Expand Down Expand Up @@ -426,14 +427,15 @@ def get_lonlat_arrays(self, ij_range=(None, None, None, None), **kwargs):
lat : numpy.ndarray
2D array with latitude
"""
if not self.is_lonlat_dim:
return [
self.get_variable_array(name, ij_range=ij_range)
for name in self.lonlat_names]
lon_name, lat_name = self.lonlat_names
lon = self[lon_name][ij_range[2]:ij_range[3]]
lat = self[lat_name][ij_range[0]:ij_range[1]]
return np.meshgrid(lon, lat)
lon = self.variables[lon_name]
lat = self.variables[lat_name]
i0, i1, j0, j1 = ij_range
slat = slice(i0, i1)
slon = slice(j0, j1)
if lon.ndim == 2:
return [a[slat, slon] for a in (lon, lat)]
return np.meshgrid(lon[slon], lat[slat])

def get_area_euclidean(self, mapping, **kwargs):
"""
Expand Down Expand Up @@ -524,8 +526,7 @@ def get_proj_info_kwargs(self):
)
return kwargs

def get_var_for_nextsim(self, var_name, nbo,
distance=5, on_elements=True, fill_value=np.nan, **kwargs):
def interp_to_points(self, var_name, lon, lat, distance=5, fill_value=np.nan, **kwargs):
""" Interpolate netCDF data onto mesh from NextsimBin object

Parameters
Expand Down Expand Up @@ -560,22 +561,13 @@ def get_var_for_nextsim(self, var_name, nbo,
if len(nc_v.shape) != 2:
raise ValueError('Can interpolate only 2D data from netCDF file')

# get elements coordinates in neXtSIM projection
nb_x = nbo.mesh_info.nodes_x
nb_y = nbo.mesh_info.nodes_y
t = nbo.mesh_info.indices
if on_elements:
nb_x, nb_y = [i[t].mean(axis=1) for i in [nb_x, nb_y]]

# transform nextsim coordinates to lon/lat
nb_x, nb_y = nbo.mesh_info.projection.pyproj(nb_x, nb_y, inverse=True)

# transform to common coordinate system if needed
if not self.is_lonlat_dim:
nc_x, nc_y = self.get_xy_dims_from_lonlat(nc_lon, nc_lat)
nb_x, nb_y = self.projection(nb_x, nb_y)
xout, yout = self.projection(lon, lat)
else:
nc_x, nc_y = nc_lon[0], nc_lat[:,0]
xout, yout = lon, lat

# fill nan gaps to avoid land contamination
nc_v = fill_nan_gaps(nc_v, distance)
Expand All @@ -584,12 +576,51 @@ def get_var_for_nextsim(self, var_name, nbo,
# make interpolator
rgi = RegularGridInterpolator((nc_y[::y_step], nc_x), nc_v[::y_step])
# interpolate only values within self bbox
gpi = ((nb_x > nc_x.min()) *
(nb_x < nc_x.max()) *
(nb_y > nc_y.min()) *
(nb_y < nc_y.max()))
v_pro = np.full_like(nb_x, fill_value, dtype=float)
v_pro[gpi] = rgi((nb_y[gpi], nb_x[gpi]))
gpi = ((xout > nc_x.min()) *
(xout < nc_x.max()) *
(yout > nc_y.min()) *
(yout < nc_y.max()))
v_pro = np.full_like(xout, fill_value, dtype=float)
v_pro[gpi] = rgi((yout[gpi], xout[gpi]))
# replace remaining NaN's (inside the domain, but not filled by fill_nan_gaps)
v_pro[np.isnan(v_pro)] = fill_value
return v_pro

def get_var_for_nextsim(self, var_name, nbo, on_elements=True, **kwargs):
    """ Interpolate netCDF data onto the mesh of a NextsimBin object

    The mesh coordinates are converted to lon/lat with the mesh's own
    projection and handed to ``self.interp_to_points`` together with
    any extra keyword arguments.

    Parameters
    ----------
    var_name : str
        name of variable
    nbo : NextsimBin
        nextsim bin object with mesh_info attribute
    on_elements : bool
        if True, interpolate at the element positions (mean of the node
        coordinates of each element); otherwise at the nodes
    kwargs : dict
        forwarded unchanged to ``self.interp_to_points``, e.g.
        distance : int
            extrapolation distance (in pixels) to avoid land contamination
        fill_value : float
            value for filling out of bound regions
        ij_range : list(int) or tuple(int)
            for subsetting in space
            eg [i0,i1,j0,j1] grabs lon[i0:i1,j0:j1], lat[i0:i1,j0:j1]

    Returns
    -------
    v_pro : 1D numpy.array
        values from netCDF interpolated on nextsim mesh
    """
    # node coordinates in the neXtSIM projection
    coords = (nbo.mesh_info.nodes_x, nbo.mesh_info.nodes_y)
    if on_elements:
        # average the node coordinates of every element
        elem_nodes = nbo.mesh_info.indices
        coords = tuple(c[elem_nodes].mean(axis=1) for c in coords)

    # transform nextsim coordinates to lon/lat
    lon, lat = nbo.mesh_info.projection.pyproj(*coords, inverse=True)
    return self.interp_to_points(var_name, lon, lat, **kwargs)
59 changes: 44 additions & 15 deletions geodataset/tests/test_custom_geodataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
import pyproj
from pyproj.exceptions import CRSError

from geodataset.custom_geodataset import UniBremenAlbedoMPF
from geodataset.custom_geodataset import UniBremenAlbedoMPF, NERSCProductBase

from geodataset.utils import InvalidDatasetError
from geodataset.tests.base_for_tests import BaseForTests


class UniBremenAlbedoMPFBaseTest(BaseForTests):
class UniBremenAlbedoMPFTest(BaseForTests):

def test_get_xy_arrays_1(self):
""" test get_xy_arrays with default options """
Expand All @@ -34,8 +34,8 @@ def test_get_xy_arrays_2(self):
""" test get_xy_arrays with ij_range passed """
x0, y0 = UniBremenAlbedoMPF.get_xy_arrays()
x, y = UniBremenAlbedoMPF.get_xy_arrays(ij_range=[3,10,6,21])
self.assertTrue(np.allclose(x0[3:11,6:22], x))
self.assertTrue(np.allclose(y0[3:11,6:22], y))
self.assertTrue(np.allclose(x0[3:10,6:21], x))
self.assertTrue(np.allclose(y0[3:10,6:21], y))

@patch.multiple(UniBremenAlbedoMPF,
__init__=MagicMock(return_value=None),
Expand All @@ -55,24 +55,53 @@ def test_get_lonlat_arrays(self):
__init__=MagicMock(return_value=None),
filepath=DEFAULT,
)
def test_datetimes_1(self, **kwargs):
""" test for older filename """
dto = dt.datetime(2017,5,1)
def test_datetimes(self, **kwargs):
dto = dt.datetime(2023,5,1,12)
kwargs['filepath'].return_value = dto.strftime('a/b/mpd_%Y%m%d.nc')
obj = UniBremenAlbedoMPF()
self.assertEqual(obj.datetimes, [dto])


@patch.multiple(UniBremenAlbedoMPF,
class NERSCProductBaseTest(BaseForTests):

@property
def x(self):
    """Six evenly spaced x coordinates from 0 to 1 (test fixture)."""
    return np.linspace(0.0, 1.0, num=6)

@property
def y(self):
    """Eight evenly spaced y coordinates from 1 to 2 (test fixture)."""
    return np.linspace(1.0, 2.0, num=8)

@patch.multiple(NERSCProductBase,
__init__=MagicMock(return_value=None),
filepath=DEFAULT,
__getitem__=DEFAULT,
projection=DEFAULT,
)
def test_datetimes_2(self, **kwargs):
""" test for newer filename """
dto = dt.datetime(2021,5,1)
kwargs['filepath'].return_value = dto.strftime('a/b/mpd_%Y%m%d_NR.nc')
obj = UniBremenAlbedoMPF()
self.assertEqual(obj.datetimes, [dto])
def test_get_lonlat_arrays(self, __getitem__, projection):
""" test get_lonlat_arrays with ij_range passed """
def mock_getitem(key):
if key == "x":
return self.x
return self.y

obj = NERSCProductBase()
__getitem__.side_effect = mock_getitem
projection.return_value = ('lon', 'lat')

i0 = 2
i1 = 5
j0 = 1
j1 = 6
x0, y0 = np.meshgrid(self.x[j0:j1], self.y[i0:i1])

lon, lat = obj.get_lonlat_arrays(ij_range=(i0, i1, j0, j1))
self.assertEqual(lon, 'lon')
self.assertEqual(lat, 'lat')
self.assertEqual(__getitem__.mock_calls, [call('x'), call('y')])
x, y = projection.mock_calls[0][1]
self.assertTrue(np.allclose(x, x0))
self.assertTrue(np.allclose(y, y0))
self.assertEqual(projection.mock_calls[0][2], dict(inverse=True))


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion geodataset/tests/test_geodataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def setUp(self):
self.moorings_var = 'sic'
# ECMWF forecast file - lon,lat are dims
self.ec2_file = os.path.join(os.environ['TEST_DATA_DIR'],
"ec2_start20220401.nc")
"ec2_start20240401.nc")


class GeoDatasetBaseTest(GeodatasetTestBase):
Expand Down
1 change: 0 additions & 1 deletion geodataset/tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def test_open_netcdf(self):

def test_get_lonlat_arrays(self):
for nc_file in self.nc_files:
print(nc_file)
with self.subTest(nc_file=nc_file):
with open_netcdf(nc_file) as ds:
if not ds.is_lonlat_2d:
Expand Down
19 changes: 11 additions & 8 deletions geodataset/tools.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,27 @@
from geodataset.geodataset import GeoDatasetRead
from geodataset.utils import InvalidDatasetError
from geodataset.custom_geodataset import (
CmemsMetIceChart,
CmemsMetIceChart,
Dist2Coast,
Etopo,
JaxaAmsr2IceConc,
NerscDeformation,
NerscIceType,
JaxaAmsr2IceConc,
NERSCDeformation,
NERSCIceType,
NERSCSeaIceAge,
OsisafDriftersNextsim,
SmosIceThickness,
UniBremenAlbedoMPF,
)


custom_read_classes = [
CmemsMetIceChart,
CmemsMetIceChart,
Dist2Coast,
Etopo,
JaxaAmsr2IceConc,
NerscDeformation,
NerscIceType,
JaxaAmsr2IceConc,
NERSCDeformation,
NERSCIceType,
NERSCSeaIceAge,
OsisafDriftersNextsim,
SmosIceThickness,
UniBremenAlbedoMPF,
Expand Down
Loading