Skip to content

Commit

Permalink
Update catchstats.py
Browse files Browse the repository at this point in the history
  • Loading branch information
casadoj committed Mar 27, 2024
1 parent 5dc1009 commit 7e83765
Showing 1 changed file with 17 additions and 8 deletions.
25 changes: 17 additions & 8 deletions src/lisfloodutilities/catchstats/catchstats.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,16 @@
import sys
import time
import xarray as xr
from typing import Union, List, Optional
from typing import List, Union, Optional, Literal
from tqdm.auto import tqdm


def catchment_statistics(inputmaps: Union[str, Path],
mask: Union[str, Path],
statistic: List[str],
statistic: List[Literal['mean', 'sum', 'std', 'var', 'min', 'max', 'median']],
output: Union[str, Path],
pixarea: Optional[str] = None
pixarea: Optional[str] = None,
overwrite: bool = False
):
"""
Given a set of input maps and catchment masks, it computes catchment statistics.
Expand All @@ -42,6 +43,8 @@ def catchment_statistics(inputmaps: Union[str, Path],
directory where the resulting NetCDF files will be saved.
pixarea: optional or str
if provided, a NetCDF file with pixel area used to compute weighted statistics. It is specifically meant for geographic projection systems where the area of a pixel varies with latitude
overwrite: boolean
whether to overwrite or skip catchments whose output NetCDF file already exists. By default is False, so the catchment will be skipped
Returns:
--------
Expand Down Expand Up @@ -122,6 +125,11 @@ def catchment_statistics(inputmaps: Union[str, Path],
# compute statistics for each catchemnt
for ID in tqdm(maskpaths.keys(), desc='processing catchments'):

fileout = output / f'{ID:04}.nc'
if fileout.exists() & ~overwrite:
print(f'Output file {fileout} already exists. Moving forward to the next catchment')
continue

# create empty Dataset
coords.update({'id': [ID]})
ds_aoi = xr.Dataset({var: xr.DataArray(coords=coords, dims=coords.keys()) for var in variables})
Expand Down Expand Up @@ -157,7 +165,7 @@ def catchment_statistics(inputmaps: Union[str, Path],
ds_aoi[f'{var}_{stat}'].loc[{'id': ID}] = x

# export
ds_aoi.to_netcdf(output / f'{ID:04}.nc')
ds_aoi.to_netcdf(fileout)

end_time = time.perf_counter()
elapsed_time = end_time - start_time
Expand All @@ -176,15 +184,16 @@ def main(argv=sys.argv):
""",
prog=prog,
)
parser.add_argument("-i", "--input", required=True, help="Input directory with NetCDF files")
parser.add_argument("-m", "--mask", required=True, help="Input NetCDF file that represents the mask")
parser.add_argument("-i", "--input", required=True, help="Directory containint the input NetCDF files")
parser.add_argument("-m", "--mask", required=True, help="Directory containing the mask NetCDF files")
parser.add_argument("-s", "--statistic", nargs='+', required=True, help='List of statistics to be computed')
parser.add_argument("-o", "--output", required=True, help="Directory where the output NetCDF files will be saved")
parser.add_argument("-w", "--weight", required=False, default=None, help="NetCDF file of pixel area used to weight the statistics")
parser.add_argument("-a", "--area", required=False, default=None, help="NetCDF file of pixel area used to weight the statistics")
parser.add_argument("-W", "--overwrite", action="store_true", help="Overwrite existing output files")

args = parser.parse_args()

catchment_statistics(args.input, args.mask, args.statistic, args.output, args.weight)
catchment_statistics(args.input, args.mask, args.statistic, args.output, args.area, args.overwrite)



Expand Down

0 comments on commit 7e83765

Please sign in to comment.