Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update __init__.py #144

Open
wants to merge 5 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 41 additions & 5 deletions axiom/drs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from dask.distributed import progress, wait
import numpy as np
from axiom.supervisor import Supervisor

from axiom.drs.processing.ccam import is_instantaneous

def consume(json_filepath):
"""Consume a json payload (for message passing)
Expand Down Expand Up @@ -205,6 +205,10 @@ def preprocess(ds, *args, **kwargs): return preprocessor(ds, **local_args)
**open_dataset_kwargs
)

# round lat/lon coords
# ds.coords['lon'] = ds.coords['lon'].round(decimals=domain.rounding)
# ds.coords['lat'] = ds.coords['lat'].round(decimals=domain.rounding)

# Subset temporally
if not adu.is_time_invariant(ds):
logger.info(f'Subsetting times to {start_year}')
Expand Down Expand Up @@ -286,6 +290,7 @@ def preprocess(ds, *args, **kwargs): return preprocessor(ds, **local_args)
raise Exception(f'Unable to parse domain {domain}.')

logger.debug('Domain: ' + domain.to_directive())
rounding = int(domain.rounding)

# Subset the geographical domain
logger.debug('Subsetting geographical domain.')
Expand Down Expand Up @@ -348,7 +353,7 @@ def preprocess(ds, *args, **kwargs): return preprocessor(ds, **local_args)
context['start_date'], context['end_date'] = adu.get_start_and_end_dates(year, output_frequency)

# Tracking info
context['creation_date'] = datetime.utcnow()
context['creation_date'] = datetime.utcnow().isoformat(timespec='seconds')+'Z'
context['uuid'] = uuid4()

# Interpolate context
Expand Down Expand Up @@ -395,6 +400,10 @@ def preprocess(ds, *args, **kwargs): return preprocessor(ds, **local_args)
# Apply a blanket variable encoding.
encoding[variable] = config.encoding['variables']

# add lat_bnds and lon_bnds encoding from drs.json to remove _FillValue
encoding['lat_bnds'] = config.encoding['lat_bnds']
encoding['lon_bnds'] = config.encoding['lon_bnds']

# Postprocess data if required
postprocessor = adu.load_postprocessor(postprocessor)

Expand All @@ -411,6 +420,19 @@ def postprocess(_ds, *args, **kwargs):
# Update the cell methods
if resampling_applied:
_ds = update_cell_methods(_ds, variable, dim='time', method='mean')
else:
if is_instantaneous(_ds, variable):
_ds[variable].attrs['cell_methods'] = f'area: mean time: point'
if adu.is_time_invariant(_ds): # if invariant, set to 'area: mean'
_ds[variable].attrs['cell_methods'] = f'area: mean'
if _ds[variable].attrs['cell_methods'] == f'time: maximum':
_ds[variable].attrs['cell_methods'] = f'area: mean time: maximum'
if _ds[variable].attrs['cell_methods'] == f'time: minimum':
_ds[variable].attrs['cell_methods'] = f'area: mean time: minimum'
if _ds[variable].attrs['cell_methods'] == f'time: sum':
_ds[variable].attrs['cell_methods'] = f'area: mean time: sum'
if _ds[variable].attrs['cell_methods'] == f'time: mean':
_ds[variable].attrs['cell_methods'] = f'area: time: mean'

# Get the full output filepath with string interpolation
logger.debug('Working out output paths')
Expand Down Expand Up @@ -467,6 +489,16 @@ def postprocess(_ds, *args, **kwargs):
# Get the output format from config
output_format = config.get('output_format', 'NETCDF4')

# round lat/lon coords and convert to double
_ds.coords['lon'] = _ds.coords['lon'].astype('float64')
_ds.coords['lat'] = _ds.coords['lat'].astype('float64')
_ds.coords['lon'] = _ds.coords['lon'].round(decimals=rounding)
_ds.coords['lat'] = _ds.coords['lat'].round(decimals=rounding)
_ds['lon_bnds'] = _ds['lon_bnds'].astype('float64')
_ds['lat_bnds'] = _ds['lat_bnds'].astype('float64')
_ds['lon_bnds'] = _ds['lon_bnds'].round(decimals=rounding)
_ds['lat_bnds'] = _ds['lat_bnds'].round(decimals=rounding)

logger.debug(f'Writing {output_filepath}')
write = _ds.to_netcdf(
output_filepath,
Expand Down Expand Up @@ -675,15 +707,19 @@ def update_cell_methods(ds, variable, dim='time', method='mean'):

# If there is no cell_methods attribute, add it now.
if 'cell_methods' not in da.attrs.keys():
da.attrs['cell_methods'] = f'{dim}: {method}'
da.attrs['cell_methods'] = f'area: {dim}: {method}'

# If the cell method was point, change it
elif da.attrs['cell_methods'] == f'{dim}: point':
da.attrs['cell_methods'] = f'{dim}: {method}'
da.attrs['cell_methods'] = f'area: {dim}: {method}'

# If another operation was already applied and doesn't match this, append
elif da.attrs['cell_methods'] != f'{dim}: {method}':
da.attrs['cell_methods'] = da.attrs['cell_methods'] + f' {dim}: {method}'
da.attrs['cell_methods'] = f'area: mean ' + da.attrs['cell_methods'] + f' {dim}: {method}'

# If cell method doesn't include area, add it
elif da.attrs['cell_methods'] == f'{dim}: {method}':
da.attrs['cell_methods'] = f'area: ' + da.attrs['cell_methods']

ds[variable] = da
return ds
Expand Down
14 changes: 9 additions & 5 deletions axiom/drs/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,17 @@ class Domain:
lat_max (float): Maximum latitude.
lon_min (float): Minimum longitude.
lon_max (float): Maximum longitude.
rounding (integer): Decimal rounding.
"""

def __init__(self, name, dx, lat_min, lat_max, lon_min, lon_max):
def __init__(self, name, dx, lat_min, lat_max, lon_min, lon_max, rounding):
self.name = name
self.dx = dx
self.lat_min = lat_min
self.lat_max = lat_max
self.lon_min = lon_min
self.lon_max = lon_max
self.rounding = rounding


def from_dict(domain_dict):
Expand All @@ -44,7 +46,8 @@ def to_dict(self):
lat_min=self.lat_min,
lat_max=self.lat_max,
lon_min=self.lon_min,
lon_max=self.lon_max
lon_max=self.lon_max,
rounding=self.rounding
)


Expand All @@ -67,7 +70,8 @@ def from_directive(directive):
lat_min=lat_min,
lat_max=lat_max,
lon_min=lon_min,
lon_max=lon_max
lon_max=lon_max,
rounding=rounding
)


Expand All @@ -77,7 +81,7 @@ def to_directive(self):
Returns:
str : Directive.
"""
return f'{self.name},{self.dx},{self.lat_min},{self.lat_max},{self.lon_min},{self.lon_max}'
return f'{self.name},{self.dx},{self.lat_min},{self.lat_max},{self.lon_min},{self.lon_max},{self.rounding}'


def subset_xarray(self, ds, drop=True):
Expand Down Expand Up @@ -112,4 +116,4 @@ def from_config(key, config):
return Domain(
name=key,
**elements
)
)