Skip to content

Commit

Permalink
xESMF weights
Browse files Browse the repository at this point in the history
  • Loading branch information
norlandrhagen committed Aug 20, 2024
1 parent b9dfc09 commit 164d36b
Showing 1 changed file with 21 additions and 28 deletions.
49 changes: 21 additions & 28 deletions feedstock/pyramid.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import apache_beam as beam
import xarray as xr
import numpy as np
import xarray as xr
import numpy as np
import xesmf as xe
from pangeo_forge_ndpyramid.transforms import StoreToPyramid
from pangeo_forge_recipes.transforms import OpenWithXarray, ConsolidateMetadata
from pangeo_forge_recipes.transforms import OpenWithXarray
from pangeo_forge_recipes.patterns import FileType, pattern_from_file_sequence

from dataclasses import dataclass
from leap_data_management_utils.data_management_transforms import (
Copy,
get_catalog_store_urls,
)

Expand Down Expand Up @@ -37,19 +35,9 @@ class GenerateWeights(beam.PTransform):
"""Custom PTransform to generate weights for xESMF regridding"""

def _generate_weights(self, ds: xr.Dataset) -> xr.Dataset:
import gcsfs

fs = gcsfs.GCSFileSystem()
weights_ds = xr.open_dataset(fs.open('gs://leap-scratch/norlandrhagen/enatl_weights_256.nc'))
except FileNotFoundError():


if weights_exist_bool:
weights_

ds = ds.rio.write_crs("EPSG:4326")
# grab sample of dataset for weights
nds = ds.isel(time=0)[['vosaline']]
nds = ds.isel(time=0)[["vosaline"]]

lat_min, lat_max = nds.nav_lat.min().values, nds.nav_lat.max().values
lon_min, lon_max = nds.nav_lon.min().values, nds.nav_lon.max().values
Expand All @@ -59,17 +47,22 @@ def _generate_weights(self, ds: xr.Dataset) -> xr.Dataset:
lon = np.linspace(lon_min, lon_max, 4096)

ds_out = xr.Dataset(
coords={
'lat': ('lat', lat),
'lon': ('lon', lon)
},
coords={"lat": ("lat", lat), "lon": ("lon", lon)},
data_vars={
'mask': (['lat', 'lon'], np.ones((len(lat), len(lon)), dtype=bool))
})

regridder = xe.Regridder(ds, ds_out, 'bilinear', weights='enatl_weights_4096.nc')

return ds
"mask": (["lat", "lon"], np.ones((len(lat), len(lon)), dtype=bool))
},
)
weights_local_filename = "enatl_weights_4096.nc"
regridder = xe.Regridder(ds, ds_out, "bilinear", weights=weights_local_filename)
weights_ds = xr.open_dataset("weights_local_filename")
weights_ds.to_netcdf(
f"gs://leap-scratch/data-library/feedstocks/eNATL_regridding/{weights_local_filename}"
)
weights_ds.to_zarr(
"gs://leap-scratch/data-library/feedstocks/eNATL_regridding/enatl_weights_4096.zarr"
)

return weights_ds

def expand(self, pcoll):
return pcoll | "subset" >> beam.MapTuple(lambda k, v: (k, self._subset(v)))
Expand All @@ -78,7 +71,7 @@ def expand(self, pcoll):
pyramid = (
beam.Create(pattern.items())
| OpenWithXarray(file_type=FileType("zarr"), xarray_open_kwargs={"chunks": {}})
# | Subset()
| GenerateWeights()
# | StoreToPyramid(
# store_name="eNATL60_BLBT02_pyramid.zarr",
# epsg_code="4326",
Expand Down

0 comments on commit 164d36b

Please sign in to comment.