Skip to content

Commit

Permalink
Merge pull request #117 from CIROH-UA/plot_forcings
Browse files Browse the repository at this point in the history
fix fp tests
  • Loading branch information
JordanLaserGit authored Sep 3, 2024
2 parents c09d1a4 + 21ed220 commit 30095b2
Show file tree
Hide file tree
Showing 7 changed files with 232 additions and 100 deletions.
35 changes: 35 additions & 0 deletions .github/workflows/forcingprocessor.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# This workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Test Forcing Processor
on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]

permissions:
contents: read

jobs:
build:

runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3
- name: Set up Python 3.10
uses: actions/setup-python@v3
with:
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e ./forcingprocessor
pip install pytest
- name: Test with pytest
run: |
python -m pytest -vv --deselect="forcingprocessor/tests/test_forcingprocessor.py::test_google_cloud_storage" --deselect="forcingprocessor/tests/test_forcingprocessor.py::test_gcs" --deselect="forcingprocessor/tests/test_forcingprocessor.py::test_gs" --deselect="forcingprocessor/tests/test_forcingprocessor.py::test_ciroh_zarr" --deselect="forcingprocessor/tests/test_forcingprocessor.py::test_nomads_post_processed" --deselect="forcingprocessor/tests/test_forcingprocessor.py::test_retro_ciroh_zarr"
python -m pytest -v -k test_google_cloud_storage
python -m pytest -v -k test_gs
python -m pytest -v -k test_gcs
40 changes: 0 additions & 40 deletions .github/workflows/python-app.yml

This file was deleted.

1 change: 1 addition & 0 deletions forcingprocessor/setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ install_requires =
cftime
gcsfs
geopandas
imageio
matplotlib
netCDF4
h5netcdf
Expand Down
104 changes: 63 additions & 41 deletions forcingprocessor/src/forcingprocessor/plot_forcings.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
mpl.use('Agg')

def plot_ngen_forcings(
nwm_data : np.ndarray,
ngen_data : np.ndarray,
geopackage : str,
t_ax : list,
catchment_ids : list,
ngen_vars_plot : list,
output_dir : Path
nwm_data : np.ndarray,
ngen_data : np.ndarray,
geopackage : str,
t_ax : list,
catchment_ids : list,
ngen_vars_plot : list = ngen_variables,
output_dir : Path = './GIFs'
):
"""
Generates side-by-side gif of nwm and ngen forcing data
Expand All @@ -42,12 +42,12 @@ def plot_ngen_forcings(
print(f'creating gif for variables {nwm_variable} -> {ngen_variable}')
images = []
for j, jtime in enumerate(t_ax):
fig, axes = plt.subplots(1, 2, figsize=(8, 8), dpi=200)
nwm_var = nwm_data[j,var_idx,:,:]
_, axes = plt.subplots(1, 2, figsize=(8, 8), dpi=200)
nwm_data_jvar = nwm_data[j,var_idx,:,:]
if j==0:
cmin=np.min(nwm_var)
cmax=np.max(nwm_var)
im = axes[0].imshow(nwm_var, vmin=cmin, vmax=cmax)
cmin=np.min(nwm_data_jvar)
cmax=np.max(nwm_data_jvar)
im = axes[0].imshow(nwm_data_jvar, vmin=cmin, vmax=cmax)
axes[0].axis('off')
axes[0].set_title(f'NWM')
gdf[ngen_variable] = ngen_data[j, var_idx, :]
Expand Down Expand Up @@ -77,7 +77,9 @@ def plot_ngen_forcings(
os.remove(jpng)
imageio.mimsave(os.path.join(output_dir, f'{nwm_variable}_2_{ngen_variable}.gif') , images, loop=0, fps=2)

def nc_to_3darray(forcings_nc,requested_vars):
def nc_to_3darray(forcings_nc : os.PathLike,
requested_vars : list = ngen_variables
) -> np.ndarray:
'''
forcings_nc : path to ngen forcings netcdf
'''
Expand All @@ -94,30 +96,62 @@ def nc_to_3darray(forcings_nc,requested_vars):

return ngen_data, t_ax_dt, catchment_ids

def csvs_to_3darray(forcings_dir, requested_ngen_variables):
def csvs_to_3darray(forcings_dir : os.PathLike,
requested_vars : list = ngen_variables
) -> np.ndarray:
'''
forcings_dir : directory containing ngen forcings csvs
'''
catchment_ids = []
for (path, _, files) in os.walk(forcings_dir):
i = 0
for (_, _, files) in os.walk(forcings_dir):
for j, jfile in enumerate(files):
catchment_id = jfile.split('.')[0]
catchment_ids.append(catchment_id)
ngen_jdf = pd.read_csv(os.path.join(forcings_dir, jfile))
if j == 0:
t_ax = ngen_jdf['time']
ngen_jdf = ngen_jdf.drop(columns='time')
shp = ngen_jdf.shape
ngen_data = np.zeros((len(files),shp[0],shp[1]),dtype=np.float32)
else:
ngen_jdf = ngen_jdf.drop(columns='time')
ngen_data[j,:,:] = np.array(ngen_jdf)

ngen_vars = np.array([x for x in range(len(ngen_variables)) if ngen_variables[x] in requested_ngen_variables])
if jfile[-3:] == "csv":
catchment_id = jfile.split('.')[0]
catchment_ids.append(catchment_id)
ngen_jdf = pd.read_csv(os.path.join(forcings_dir, jfile))
if i == 0:
i += 1
t_ax = ngen_jdf['time']
ngen_jdf = ngen_jdf.drop(columns='time')
shp = ngen_jdf.shape
ngen_data = np.zeros((len(files),shp[0],shp[1]),dtype=np.float32)
else:
ngen_jdf = ngen_jdf.drop(columns='time')

ngen_data[j,:,:] = np.array(ngen_jdf)

ngen_vars = np.array([x for x in range(len(ngen_variables)) if ngen_variables[x] in requested_vars])
ngen_data = np.moveaxis(ngen_data[:,:,ngen_vars],[0,1,2],[2,0,1])

return ngen_data, t_ax, catchment_ids

def get_nwm_data_array(
        nwm_folder : list,
        geopackage : gpd.GeoDataFrame,
        nwm_vars : list = nwm_variables
    ) -> np.ndarray:
    """
    Read a folder of National Water Model (NWM) forcing files and window them
    to the domain implied by the geopackage.

    Parameters
    ----------
    nwm_folder : path to a directory of NWM netcdf forcing files
    geopackage : geopackage (path/GeoDataFrame) used to build the spatial weights
    nwm_vars   : NWM variable names to extract (defaults to all nwm_variables)

    Returns
    -------
    nwm_data : 4d float32 array (time x nwm_forcing_variable x south_north x west_east)
               Zero-length time axis if the folder contains no files.
    """
    weights_json, _ = gpkgs2weightsjson([geopackage])
    x_min, x_max, y_min, y_max = get_window(weights_json)
    ny = y_max - y_min + 1
    nx = x_max - x_min + 1

    # Pre-initialize so an empty/missing folder returns an empty array
    # instead of raising UnboundLocalError at the final return.
    nwm_data = np.zeros((0, len(nwm_vars), ny, nx), dtype=np.float32)

    for path, _, files in os.walk(nwm_folder):
        nwm_data = np.zeros((len(files), len(nwm_vars), ny, nx), dtype=np.float32)
        # Sort so the time axis follows filename order.
        for k, jfile in enumerate(sorted(files)):
            # Context manager ensures each netcdf is closed (original leaked
            # one open handle per file).
            with xr.open_dataset(os.path.join(path, jfile)) as ds:
                for j, jvar in enumerate(nwm_vars):
                    # 3840 is assumed to be the NWM CONUS grid's row count;
                    # the y index is inverted and the slice flipped so the
                    # output is oriented north-up — TODO confirm against grid spec.
                    nwm_data[k, j, :, :] = np.flip(
                        np.squeeze(
                            ds[jvar].isel(
                                x=slice(x_min, x_max + 1),
                                y=slice(3840 - y_max, 3840 - y_min + 1),
                            )
                        ),
                        0,
                    )

    return nwm_data

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--ngen_forcings", help="Path to a folder containing ngen catchment forcings csvs or path to netcdf",default="")
Expand All @@ -127,21 +161,9 @@ def csvs_to_3darray(forcings_dir, requested_ngen_variables):
parser.add_argument("--output_dir", help="Path to write gifs to",default="./GIFs")
args = parser.parse_args()

weights_json, jcatchment_dict = gpkgs2weightsjson([args.geopackage])
x_min, x_max, y_min, y_max = get_window(weights_json)

requested_ngen_variables = args.ngen_variables.split(', ')
nwm_vars = np.array([nwm_variables[x] for x in range(len(ngen_variables)) if ngen_variables[x] in requested_ngen_variables])

for path, _, files in os.walk(args.nwm_folder):
nwm_data = np.zeros((len(files),len(nwm_vars),y_max-y_min+1,x_max - x_min+1),dtype=np.float32)
for k, jfile in enumerate(sorted(files)):
jfile_path = os.path.join(path,jfile)
ds = xr.open_dataset(jfile_path)
nwm_var = np.zeros((len(nwm_vars),y_max-y_min+1,x_max - x_min+1),dtype=np.float32)
for j, jvar in enumerate(nwm_vars):
nwm_var[j,:,:] = np.flip(np.squeeze(ds[jvar].isel(x=slice(x_min, x_max + 1), y=slice(3840 - y_max, 3840 - y_min + 1))),0)
nwm_data[k,:,:,:] = nwm_var
nwm_data = get_nwm_data_array(args.nwm_folder,args.geopackge, nwm_vars)

if args.ngen_forcings.endswith('.nc'):
ngen_data, t_ax, catchment_ids = nc_to_3darray(args.ngen_forcings, requested_ngen_variables)
Expand Down
4 changes: 4 additions & 0 deletions forcingprocessor/src/forcingprocessor/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,10 @@ def prep_ngen_data(conf):
jnwm_files = nwm_forcing_files[start:end]
t0 = time.perf_counter()
if ii_verbose: print(f'Entering data extraction...\n',flush=True)
# data_array, t_ax, nwm_data = forcing_grid2catchment(jnwm_files, fs)
# data_array=data_array[0][None,:]
# t_ax = t_ax
# nwm_data=nwm_data[0][None,:]
data_array, t_ax, nwm_data = multiprocess_data_extract(jnwm_files,nprocs,weights_json,fs)
t_extract = time.perf_counter() - t0
complexity = (nfiles_tot * ncatchments) / 10000
Expand Down
35 changes: 16 additions & 19 deletions forcingprocessor/tests/test_forcingprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
pwd = Path.cwd()
pwd = pwd
data_dir = data_dir
if os.path.exists(data_dir):
os.system(f"rm -rf {data_dir}")
os.system(f"mkdir {data_dir}")
pwd = Path.cwd()
filenamelist = str((pwd/"filenamelist.txt").resolve())
Expand All @@ -34,7 +36,8 @@

"run" : {
"verbose" : False,
"collect_stats" : True
"collect_stats" : True,
"nprocs" : 1
}
}

Expand Down Expand Up @@ -72,9 +75,7 @@ def test_nomads_prod():
os.remove(parquet)

def test_nomads_post_processed():
print(f'test_nomads_post_processed() is BROKEN - https://github.com/CIROH-UA/nwmurl/issues/62')
assert False
return
assert False, f'test_nomads_post_processed() is BROKEN - https://github.com/CIROH-UA/nwmurl/issues/62'
nwmurl_conf['start_date'] = "202408240000"
nwmurl_conf['end_date'] = "202408241700"
nwmurl_conf["urlbaseinput"] = 2
Expand All @@ -92,43 +93,38 @@ def test_nwm_google_apis():
prep_ngen_data(conf)
parquet = (data_dir/"forcings/cat-2586011.parquet").resolve()
assert parquet.exists()
os.remove(parquet)
os.remove(parquet)

def test_google_cloud_storage():
print(f'hangs in pytest, but should work')
return
nwmurl_conf['start_date'] = "202407100100"
nwmurl_conf['end_date'] = "202407100100"
nwmurl_conf["urlbaseinput"] = 4
generate_nwmfiles(nwmurl_conf)
prep_ngen_data(conf)
parquet = (data_dir/"forcings/cat-2586011.parquet").resolve()
assert parquet.exists()
os.remove(parquet)
os.remove(parquet)

def test_gs_nwm():
print(f'hangs in pytest, but should work')
return
def test_gs():
nwmurl_conf['start_date'] = date + hourminute
nwmurl_conf['end_date'] = date + hourminute
nwmurl_conf["urlbaseinput"] = 5
generate_nwmfiles(nwmurl_conf)
prep_ngen_data(conf)
parquet = (data_dir/"forcings/cat-2586011.parquet").resolve()
assert parquet.exists()
os.remove(parquet)
os.remove(parquet)

def test_gcs_nwm():
print(f'hangs in pytest, but should work')
return
nwmurl_conf['start_date'] = date + hourminute
nwmurl_conf['end_date'] = date + hourminute
def test_gcs():
# assert False, f'hangs in pytest, but should work'
nwmurl_conf['start_date'] = "202407100100"
nwmurl_conf['end_date'] = "202407100100"
nwmurl_conf["urlbaseinput"] = 6
generate_nwmfiles(nwmurl_conf)
prep_ngen_data(conf)
parquet = (data_dir/"forcings/cat-2586011.parquet").resolve()
assert parquet.exists()
os.remove(parquet)
os.remove(parquet)

def test_noaa_nwm_pds():
nwmurl_conf['start_date'] = date + hourminute
Expand Down Expand Up @@ -200,14 +196,15 @@ def test_retro_3_0():

def test_plotting():
conf['forcing']['nwm_file'] = retro_filenamelist
conf['plot'] = {}
conf['plot']['nts'] = 1
conf['plot']['ngen_vars'] = [
"TMP_2maboveground"
]
nwmurl_conf_retro["urlbaseinput"] = 4
generate_nwmfiles(nwmurl_conf_retro)
prep_ngen_data(conf)
GIF = (data_dir/"forcings/metadata/T2D_2_TMP_2maboveground.gif").resolve()
GIF = (data_dir/"metadata/GIFs/T2D_2_TMP_2maboveground.gif").resolve()
assert GIF.exists()
os.remove(GIF)

Expand Down
Loading

0 comments on commit 30095b2

Please sign in to comment.