From 97f80d43604832d074ad76351b419f1bda181fbe Mon Sep 17 00:00:00 2001 From: Nicholas Geneva <5533524+NickGeneva@users.noreply.github.com> Date: Mon, 12 Aug 2024 14:24:04 -0700 Subject: [PATCH] Adding GEFS forecast data source (#112) * Adding GEFS data source * Adding tests for GEFS * Feedback * Fixes --- CHANGELOG.md | 3 +- docs/modules/datasources.rst | 1 + earth2studio/data/__init__.py | 1 + earth2studio/data/gefs.py | 402 +++++++++++++++++++++++++++++++ earth2studio/lexicon/__init__.py | 1 + earth2studio/lexicon/base.py | 4 +- earth2studio/lexicon/gefs.py | 239 ++++++++++++++++++ test/data/test_gefs.py | 202 ++++++++++++++++ 8 files changed, 850 insertions(+), 3 deletions(-) create mode 100644 earth2studio/data/gefs.py create mode 100644 earth2studio/lexicon/gefs.py create mode 100644 test/data/test_gefs.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 65f4ef2a..a63386f7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Forecast datasource API -- GFS Forecast datasource +- GFS forecast datasource +- GEFS forecast datasource ### Changed diff --git a/docs/modules/datasources.rst b/docs/modules/datasources.rst index 2f5013f6..110b8299 100644 --- a/docs/modules/datasources.rst +++ b/docs/modules/datasources.rst @@ -51,6 +51,7 @@ Typically used in intercomparison workflows. 
:template: datasource.rst data.GFS_FX + data.GEFS_FX Functions ~~~~~~~~~ diff --git a/earth2studio/data/__init__.py b/earth2studio/data/__init__.py index 6c368f89..cba068e2 100644 --- a/earth2studio/data/__init__.py +++ b/earth2studio/data/__init__.py @@ -17,6 +17,7 @@ from .arco import ARCO from .base import DataSource from .cds import CDS +from .gefs import GEFS_FX from .gfs import GFS, GFS_FX from .hrrr import HRRR from .ifs import IFS diff --git a/earth2studio/data/gefs.py b/earth2studio/data/gefs.py new file mode 100644 index 00000000..bbe80164 --- /dev/null +++ b/earth2studio/data/gefs.py @@ -0,0 +1,402 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import hashlib
import os
import pathlib
import shutil
from datetime import datetime, timedelta

import numpy as np
import s3fs
import xarray as xr
from fsspec.implementations.cached import WholeFileCacheFileSystem
from loguru import logger
from modulus.distributed.manager import DistributedManager
from s3fs.core import S3FileSystem
from tqdm import tqdm

from earth2studio.data.utils import (
    datasource_cache_root,
    prep_forecast_inputs,
)
from earth2studio.lexicon import GEFSLexicon
from earth2studio.utils.type import LeadTimeArray, TimeArray, VariableArray

# Route log messages through tqdm so progress bars are not broken by logging
logger.remove()
logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True)


class GEFS_FX:
    """The Global Ensemble Forecast System (GEFS) forecast source is a 30 member
    ensemble forecast provided on an equirectangular grid. GEFS is a weather forecast
    model developed by National Centers for Environmental Prediction (NCEP). This data
    source is on a 0.5 degree lat lon grid at 6-hour intervals spanning from
    Sept 23rd 2020 to present date. Each forecast provides 3-hourly predictions up to
    10 days (240 hours) and 6 hourly predictions for another 6 days (384 hours).

    Parameters
    ----------
    product : str, optional
        GEFS product. Options are: control gec00 (control), gepNN (forecast member NN,
        e.g. gep01, gep02,...), by default "gec00"
    cache : bool, optional
        Cache data source on local memory, by default True
    verbose : bool, optional
        Print download progress, by default True

    Warning
    -------
    This is a remote data source and can potentially download a large amount of data
    to your local machine for large requests.

    Note
    ----
    This forecast source is an ensemble forecast of isobaric variables (secondary
    variables). Additional surface variables are provided by GEFS at 0.25 degree
    resolution but are not integrated.

    Note
    ----
    Additional information on the data repository can be referenced here:

    - https://registry.opendata.aws/noaa-gefs/
    - https://www.ncei.noaa.gov/products/weather-climate-models/global-ensemble-forecast
    - https://www.nco.ncep.noaa.gov/pmb/products/gens/
    """

    GEFS_BUCKET_NAME = "noaa-gefs-pds"
    # Safety threshold on the byte length of a single grib message download
    MAX_BYTE_SIZE = 5000000

    # 0.5 degree equirectangular grid coordinates
    GEFS_LAT = np.linspace(90, -90, 361)
    GEFS_LON = np.linspace(0, 359.5, 720)

    # Control member plus the 30 perturbed ensemble members
    GEFS_PRODUCTS = ["gec00"] + [f"gep{i:02d}" for i in range(1, 31)]

    def __init__(
        self,
        product: str = "gec00",
        cache: bool = True,
        verbose: bool = True,
    ):
        if product not in self.GEFS_PRODUCTS:
            raise ValueError(f"Invalid GEFS product {product}")

        self._cache = cache
        self._verbose = verbose
        self._product = product
        self.s3fs = s3fs.S3FileSystem(
            anon=True,
            default_block_size=2**20,
            client_kwargs={},
        )

        # Whole-file cache layered over s3; used for the small grib index files
        cache_options = {
            "cache_storage": self.cache,
            "expiry_time": 31622400,  # 1 year
        }
        self.fs = WholeFileCacheFileSystem(fs=self.s3fs, **cache_options)

    def __call__(
        self,
        time: datetime | list[datetime] | TimeArray,
        lead_time: timedelta | list[timedelta] | LeadTimeArray,
        variable: str | list[str] | VariableArray,
    ) -> xr.DataArray:
        """Retrieve GEFS ensemble forecast data

        Parameters
        ----------
        time : datetime | list[datetime] | TimeArray
            Timestamps to return data for (UTC).
        lead_time: timedelta | list[timedelta] | LeadTimeArray
            Forecast lead times to fetch.
        variable : str | list[str] | VariableArray
            String, list of strings or array of strings that refer to variables to
            return. Must be in the GEFS lexicon.

        Returns
        -------
        xr.DataArray
            GEFS weather data array
        """
        time, lead_time, variable = prep_forecast_inputs(time, lead_time, variable)

        # Create cache dir if doesnt exist
        pathlib.Path(self.cache).mkdir(parents=True, exist_ok=True)

        # Make sure input time is valid
        self._validate_time(time)
        self._validate_leadtime(lead_time)

        # Create data array of forecast data
        xr_array = self.create_data_array(time, lead_time, variable)

        # Delete cache if needed
        if not self._cache:
            shutil.rmtree(self.cache)

        return xr_array

    def create_data_array(
        self, time: list[datetime], lead_time: list[timedelta], variable: list[str]
    ) -> xr.DataArray:
        """Function that creates and populates an xarray data array with requested
        GEFS data.

        Parameters
        ----------
        time : list[datetime]
            Time list to fetch
        lead_time: list[timedelta]
            Lead time list to fetch
        variable : list[str]
            Variable list to fetch

        Returns
        -------
        xr.DataArray
            Xarray data array
        """
        xr_array = xr.DataArray(
            data=np.empty(
                (
                    len(time),
                    len(lead_time),
                    len(variable),
                    len(self.GEFS_LAT),
                    len(self.GEFS_LON),
                )
            ),
            dims=["time", "lead_time", "variable", "lat", "lon"],
            coords={
                "time": time,
                "lead_time": lead_time,
                "variable": variable,
                "lat": self.GEFS_LAT,
                "lon": self.GEFS_LON,
            },
        )

        # One fetch per (time, lead time, variable) triple
        args = [
            (t, i, l, j, v, k)
            for k, v in enumerate(variable)
            for j, l in enumerate(lead_time)  # noqa
            for i, t in enumerate(time)
        ]

        pbar = tqdm(
            total=len(args), desc="Fetching GEFS data", disable=(not self._verbose)
        )
        for (t, i, l, j, v, k) in args:  # noqa
            data = self.fetch_array(t, l, v)
            xr_array[i, j, k] = data
            pbar.update(1)

        return xr_array

    def fetch_array(
        self, time: datetime, lead_time: timedelta, variable: str
    ) -> np.ndarray:
        """Fetches requested array from remote store

        Parameters
        ----------
        time : datetime
            Time to fetch
        lead_time: timedelta
            Lead time to fetch
        variable : str
            Variable to fetch

        Returns
        -------
        np.ndarray
            Data
        """
        logger.debug(
            f"Fetching GEFS data for variable: {variable} at {time.isoformat()} lead time {lead_time}"
        )
        try:
            gefs_name, modifier = GEFSLexicon[variable]
            gefs_grib, gefs_var, gefs_level = gefs_name.split("::")
        except KeyError as e:
            logger.error(f"variable id {variable} not found in GEFS lexicon")
            raise e

        # Set up s3 file paths
        lead_hour = int(lead_time.total_seconds() // 3600)
        file_name = f"gefs.{time.year}{time.month:0>2}{time.day:0>2}/{time.hour:0>2}"
        index_file_name = os.path.join(
            file_name,
            f"atmos/{gefs_grib}p5/{self._product}.t{time.hour:0>2}z.{gefs_grib}.0p50.f{lead_hour:03d}.idx",
        )
        s3_index_uri = os.path.join(self.GEFS_BUCKET_NAME, index_file_name)
        # Grib file URI is the index URI without the trailing ".idx"
        s3_grib_uri = os.path.join(self.GEFS_BUCKET_NAME, index_file_name[:-4])

        # Download the grib index file and parse
        with self.fs.open(s3_index_uri) as file:
            index_lines = [line.decode("utf-8").rstrip() for line in file]

        index_table = {}
        # Note we actually drop the last variable here (Vertical Speed Shear)
        for i, line in enumerate(index_lines[:-1]):
            lsplit = line.split(":")
            if len(lsplit) < 7:
                continue

            # Byte length of message i is the offset delta to message i + 1
            nlsplit = index_lines[i + 1].split(":")
            byte_length = int(nlsplit[1]) - int(lsplit[1])
            byte_offset = int(lsplit[1])
            key = f"{lsplit[3]}::{lsplit[4]}"
            if byte_length > self.MAX_BYTE_SIZE:
                raise ValueError(
                    f"Byte length, {byte_length}, of variable {key} larger than safe threshold of {self.MAX_BYTE_SIZE}"
                )

            index_table[key] = (byte_offset, byte_length)

        gefs_name = f"{gefs_var}::{gefs_level}"
        if gefs_name not in index_table:
            raise KeyError(f"Could not find variable {gefs_name} in index file")
        byte_offset = index_table[gefs_name][0]
        byte_length = index_table[gefs_name][1]

        # cfgrib requires the grib message to be a local file. Since only a byte
        # range is read, use the raw s3fs store directly with a manual
        # sha256-keyed file cache (WholeFileCacheFileSystem cannot do ranges).
        sha = hashlib.sha256(
            (s3_grib_uri + str(byte_offset) + str(byte_length)).encode()
        )
        filename = sha.hexdigest()
        cache_path = os.path.join(self.cache, filename)

        if not pathlib.Path(cache_path).is_file():
            grib_buffer = self.s3fs.read_block(
                s3_grib_uri, offset=byte_offset, length=byte_length
            )
            with open(cache_path, "wb") as file:
                file.write(grib_buffer)

        da = xr.open_dataarray(
            cache_path,
            engine="cfgrib",
            backend_kwargs={"indexpath": ""},
        )
        return modifier(da.values)

    @classmethod
    def _validate_time(cls, times: list[datetime]) -> None:
        """Verify if date time is valid for GEFS based on offline knowledge

        Parameters
        ----------
        times : list[datetime]
            list of date times to fetch data
        """
        for time in times:
            if not (time - datetime(1900, 1, 1)).total_seconds() % 21600 == 0:
                raise ValueError(
                    f"Requested date time {time} needs to be 6 hour interval for GEFS"
                )
            # To update search "gefs." at https://noaa-gefs-pds.s3.amazonaws.com/index.html
            # They are slowly adding more data
            if time < datetime(year=2020, month=9, day=23):
                raise ValueError(
                    f"Requested date time {time} needs to be after Sept 23rd, 2020 for GEFS"
                )

    @classmethod
    def _validate_leadtime(cls, lead_times: list[timedelta]) -> None:
        """Verify if lead time is valid for GEFS based on offline knowledge

        Parameters
        ----------
        lead_times : list[timedelta]
            list of lead times to fetch data
        """
        for delta in lead_times:
            # To update search "gefs." at https://noaa-gefs-pds.s3.amazonaws.com/index.html
            hours = int(delta.total_seconds() // 3600)
            if hours > 384 or hours < 0:
                raise ValueError(
                    f"Requested lead time {delta} can only be a max of 384 hours for GEFS"
                )

            # 3-hours supported for first 10 days
            if delta.total_seconds() // 3600 <= 240:
                if not delta.total_seconds() % 10800 == 0:
                    raise ValueError(
                        f"Requested lead time {delta} needs to be 3 hour interval for first 10 days in GEFS"
                    )
            # 6 hours for rest
            else:
                if not delta.total_seconds() % 21600 == 0:
                    raise ValueError(
                        f"Requested lead time {delta} needs to be 6 hour interval for last 6 days in GEFS"
                    )

    @property
    def cache(self) -> str:
        """Return appropriate cache location."""
        # Use a GEFS-specific folder so the cache does not share (and rmtree does
        # not clash with) the GFS data source cache directory
        cache_location = os.path.join(datasource_cache_root(), "gefs")
        if not self._cache:
            if not DistributedManager.is_initialized():
                DistributedManager.initialize()
            # Rank-local temp folder so parallel workers do not collide
            cache_location = os.path.join(
                cache_location, f"tmp_{DistributedManager().rank}"
            )
        return cache_location

    @classmethod
    def available(
        cls,
        time: datetime | np.datetime64,
    ) -> bool:
        """Checks if given date time is available in the GEFS object store

        Parameters
        ----------
        time : datetime | np.datetime64
            Date time to access

        Returns
        -------
        bool
            If date time is available
        """
        if isinstance(time, np.datetime64):  # np.datetime64 -> datetime
            _unix = np.datetime64(0, "s")
            _ds = np.timedelta64(1, "s")
            time = datetime.utcfromtimestamp((time - _unix) / _ds)

        # Offline checks
        try:
            cls._validate_time([time])
        except ValueError:
            return False

        fs = S3FileSystem(anon=True)

        # Object store directory for given time
        # Should contain two keys: atmos and wave
        file_name = f"gefs.{time.year}{time.month:0>2}{time.day:0>2}/{time.hour:0>2}/atmos/pgrb2bp5/"
        s3_uri = f"s3://{cls.GEFS_BUCKET_NAME}/{file_name}"
        exists = fs.exists(s3_uri)

        return exists
b0803d5f..fd5ec67a 100644 --- a/earth2studio/lexicon/__init__.py +++ b/earth2studio/lexicon/__init__.py @@ -16,6 +16,7 @@ from .arco import ARCOLexicon from .cds import CDSLexicon +from .gefs import GEFSLexicon from .gfs import GFSLexicon from .hrrr import HRRRLexicon from .ifs import IFSLexicon diff --git a/earth2studio/lexicon/base.py b/earth2studio/lexicon/base.py index 45ba486f..d159d87f 100644 --- a/earth2studio/lexicon/base.py +++ b/earth2studio/lexicon/base.py @@ -33,8 +33,8 @@ def __getitem__(cls, val: str) -> tuple[str, Callable]: "t2m": "temperature at 2m", "sp": "surface pressure", "msl": "mean sea level pressure", - "tcwv": "total column water vapor", - "tp": "total precipitation in meters", + "tcwv": "total column water vapor / precipitable water (kg m-2)", + "tp": "total precipitation (m)", "tpp": "total precipitation probability", "tpi": "total precipitation index", "tp06": "total precipitation accumulated over past 6 hours", diff --git a/earth2studio/lexicon/gefs.py b/earth2studio/lexicon/gefs.py new file mode 100644 index 00000000..18fe6021 --- /dev/null +++ b/earth2studio/lexicon/gefs.py @@ -0,0 +1,239 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from collections.abc import Callable

import numpy as np

from .base import LexiconType


class GEFSLexicon(metaclass=LexiconType):
    """Global Ensemble Forecast System lexicon. Currently only isobaric GEFS
    variables are supported. Vocab entries are specified as:

        "{pgrb2 file}::{GEFS variable id}::{level/layer}"

    Note
    ----
    Additional resources:
    - https://www.nco.ncep.noaa.gov/pmb/products/gens/gec00.t00z.pgrb2a.0p50.f003.shtml
    - https://www.nco.ncep.noaa.gov/pmb/products/gens/gec00.t00z.pgrb2b.0p50.f003.shtml
    """

    VOCAB = {
        "u10m": "pgrb2a::UGRD::10 m above ground",
        "v10m": "pgrb2a::VGRD::10 m above ground",
        "u100m": "pgrb2b::UGRD::100 m above ground",
        "v100m": "pgrb2b::VGRD::100 m above ground",
        "t2m": "pgrb2a::TMP::2 m above ground",
        "t100m": "pgrb2b::TMP::100 m above ground",
        "sp": "pgrb2a::PRES::surface",
        "msl": "pgrb2b::PRES::mean sea level",
        "tcwv": "pgrb2a::PWAT::entire atmosphere (considered as a single layer)",
        "u1": "pgrb2b::UGRD::1 mb",
        "u2": "pgrb2b::UGRD::2 mb",
        "u3": "pgrb2b::UGRD::3 mb",
        "u5": "pgrb2b::UGRD::5 mb",
        "u7": "pgrb2b::UGRD::7 mb",
        "u10": "pgrb2a::UGRD::10 mb",
        "u20": "pgrb2b::UGRD::20 mb",
        "u30": "pgrb2b::UGRD::30 mb",
        "u50": "pgrb2a::UGRD::50 mb",
        "u70": "pgrb2b::UGRD::70 mb",
        "u100": "pgrb2a::UGRD::100 mb",
        "u150": "pgrb2b::UGRD::150 mb",
        "u200": "pgrb2a::UGRD::200 mb",
        "u250": "pgrb2a::UGRD::250 mb",
        "u300": "pgrb2a::UGRD::300 mb",
        "u350": "pgrb2b::UGRD::350 mb",
        "u400": "pgrb2a::UGRD::400 mb",
        "u450": "pgrb2b::UGRD::450 mb",
        "u500": "pgrb2a::UGRD::500 mb",
        "u550": "pgrb2b::UGRD::550 mb",
        "u600": "pgrb2b::UGRD::600 mb",
        "u650": "pgrb2b::UGRD::650 mb",
        "u700": "pgrb2a::UGRD::700 mb",
        "u750": "pgrb2b::UGRD::750 mb",
        "u800": "pgrb2b::UGRD::800 mb",
        "u850": "pgrb2a::UGRD::850 mb",
        "u900": "pgrb2b::UGRD::900 mb",
        "u925": "pgrb2a::UGRD::925 mb",
        "u950": "pgrb2b::UGRD::950 mb",
        "u975": "pgrb2b::UGRD::975 mb",
        "u1000": "pgrb2a::UGRD::1000 mb",
        "v1": "pgrb2b::VGRD::1 mb",
        "v2": "pgrb2b::VGRD::2 mb",
        "v3": "pgrb2b::VGRD::3 mb",
        "v5": "pgrb2b::VGRD::5 mb",
        "v7": "pgrb2b::VGRD::7 mb",
        "v10": "pgrb2a::VGRD::10 mb",
        "v20": "pgrb2b::VGRD::20 mb",
        "v30": "pgrb2b::VGRD::30 mb",
        "v50": "pgrb2a::VGRD::50 mb",
        "v70": "pgrb2b::VGRD::70 mb",
        "v100": "pgrb2a::VGRD::100 mb",
        "v150": "pgrb2b::VGRD::150 mb",
        "v200": "pgrb2a::VGRD::200 mb",
        "v250": "pgrb2a::VGRD::250 mb",
        "v300": "pgrb2a::VGRD::300 mb",
        "v350": "pgrb2b::VGRD::350 mb",
        "v400": "pgrb2a::VGRD::400 mb",
        "v450": "pgrb2b::VGRD::450 mb",
        "v500": "pgrb2a::VGRD::500 mb",
        "v550": "pgrb2b::VGRD::550 mb",
        "v600": "pgrb2b::VGRD::600 mb",
        "v650": "pgrb2b::VGRD::650 mb",
        "v700": "pgrb2a::VGRD::700 mb",
        "v750": "pgrb2b::VGRD::750 mb",
        "v800": "pgrb2b::VGRD::800 mb",
        "v850": "pgrb2a::VGRD::850 mb",
        "v900": "pgrb2b::VGRD::900 mb",
        "v925": "pgrb2a::VGRD::925 mb",
        "v950": "pgrb2b::VGRD::950 mb",
        "v975": "pgrb2b::VGRD::975 mb",
        "v1000": "pgrb2a::VGRD::1000 mb",
        "z1": "pgrb2b::HGT::1 mb",
        "z2": "pgrb2b::HGT::2 mb",
        "z3": "pgrb2b::HGT::3 mb",
        "z5": "pgrb2b::HGT::5 mb",
        "z7": "pgrb2b::HGT::7 mb",
        "z10": "pgrb2a::HGT::10 mb",
        "z20": "pgrb2b::HGT::20 mb",
        "z30": "pgrb2b::HGT::30 mb",
        "z50": "pgrb2a::HGT::50 mb",
        "z70": "pgrb2b::HGT::70 mb",
        "z100": "pgrb2a::HGT::100 mb",
        "z150": "pgrb2b::HGT::150 mb",
        "z200": "pgrb2a::HGT::200 mb",
        "z250": "pgrb2a::HGT::250 mb",
        "z300": "pgrb2a::HGT::300 mb",
        "z350": "pgrb2b::HGT::350 mb",
        "z400": "pgrb2b::HGT::400 mb",
        "z450": "pgrb2b::HGT::450 mb",
        "z500": "pgrb2a::HGT::500 mb",
        "z550": "pgrb2b::HGT::550 mb",
        "z600": "pgrb2b::HGT::600 mb",
        "z650": "pgrb2b::HGT::650 mb",
        "z700": "pgrb2a::HGT::700 mb",
        "z750": "pgrb2b::HGT::750 mb",
        "z800": "pgrb2b::HGT::800 mb",
        "z850": "pgrb2a::HGT::850 mb",
        "z900": "pgrb2b::HGT::900 mb",
        "z925": "pgrb2a::HGT::925 mb",
        "z950": "pgrb2b::HGT::950 mb",
        "z975": "pgrb2b::HGT::975 mb",
        "z1000": "pgrb2a::HGT::1000 mb",
        "t1": "pgrb2b::TMP::1 mb",
        "t2": "pgrb2b::TMP::2 mb",
        "t3": "pgrb2b::TMP::3 mb",
        "t5": "pgrb2b::TMP::5 mb",
        "t7": "pgrb2b::TMP::7 mb",
        "t10": "pgrb2a::TMP::10 mb",
        "t20": "pgrb2b::TMP::20 mb",
        "t30": "pgrb2b::TMP::30 mb",
        "t50": "pgrb2a::TMP::50 mb",
        "t70": "pgrb2b::TMP::70 mb",
        "t100": "pgrb2a::TMP::100 mb",
        "t150": "pgrb2b::TMP::150 mb",
        "t200": "pgrb2a::TMP::200 mb",
        "t250": "pgrb2a::TMP::250 mb",
        "t300": "pgrb2b::TMP::300 mb",
        "t350": "pgrb2b::TMP::350 mb",
        "t400": "pgrb2b::TMP::400 mb",
        "t450": "pgrb2b::TMP::450 mb",
        "t500": "pgrb2a::TMP::500 mb",
        "t550": "pgrb2b::TMP::550 mb",
        "t600": "pgrb2b::TMP::600 mb",
        "t650": "pgrb2b::TMP::650 mb",
        "t700": "pgrb2a::TMP::700 mb",
        "t750": "pgrb2b::TMP::750 mb",
        "t800": "pgrb2b::TMP::800 mb",
        "t850": "pgrb2a::TMP::850 mb",
        "t900": "pgrb2b::TMP::900 mb",
        "t925": "pgrb2a::TMP::925 mb",
        "t950": "pgrb2b::TMP::950 mb",
        "t975": "pgrb2b::TMP::975 mb",
        "t1000": "pgrb2a::TMP::1000 mb",
        "r10": "pgrb2a::RH::10 mb",
        "r20": "pgrb2b::RH::20 mb",
        "r30": "pgrb2b::RH::30 mb",
        "r50": "pgrb2a::RH::50 mb",
        "r70": "pgrb2b::RH::70 mb",
        "r100": "pgrb2a::RH::100 mb",
        "r150": "pgrb2b::RH::150 mb",
        "r200": "pgrb2a::RH::200 mb",
        "r250": "pgrb2a::RH::250 mb",
        "r300": "pgrb2b::RH::300 mb",
        "r350": "pgrb2b::RH::350 mb",
        "r400": "pgrb2b::RH::400 mb",
        "r450": "pgrb2b::RH::450 mb",
        "r500": "pgrb2a::RH::500 mb",
        "r550": "pgrb2b::RH::550 mb",
        "r600": "pgrb2b::RH::600 mb",
        "r650": "pgrb2b::RH::650 mb",
        "r700": "pgrb2a::RH::700 mb",
        "r750": "pgrb2b::RH::750 mb",
        "r800": "pgrb2b::RH::800 mb",
        "r850": "pgrb2a::RH::850 mb",
        "r900": "pgrb2b::RH::900 mb",
        "r925": "pgrb2a::RH::925 mb",
        "r950": "pgrb2b::RH::950 mb",
        "r975": "pgrb2b::RH::975 mb",
        "r1000": "pgrb2a::RH::1000 mb",
        "q10": "pgrb2b::SPFH::10 mb",
        "q20": "pgrb2b::SPFH::20 mb",
        "q30": "pgrb2b::SPFH::30 mb",
        "q50": "pgrb2b::SPFH::50 mb",
        "q70": "pgrb2b::SPFH::70 mb",
        "q100": "pgrb2b::SPFH::100 mb",
        "q150": "pgrb2b::SPFH::150 mb",
        "q200": "pgrb2b::SPFH::200 mb",
        "q250": "pgrb2b::SPFH::250 mb",
        "q300": "pgrb2b::SPFH::300 mb",
        "q350": "pgrb2b::SPFH::350 mb",
        "q400": "pgrb2b::SPFH::400 mb",
        "q450": "pgrb2b::SPFH::450 mb",
        "q500": "pgrb2b::SPFH::500 mb",
        "q550": "pgrb2b::SPFH::550 mb",
        "q600": "pgrb2b::SPFH::600 mb",
        "q650": "pgrb2b::SPFH::650 mb",
        "q700": "pgrb2b::SPFH::700 mb",
        "q750": "pgrb2b::SPFH::750 mb",
        "q800": "pgrb2b::SPFH::800 mb",
        "q850": "pgrb2b::SPFH::850 mb",
        "q900": "pgrb2b::SPFH::900 mb",
        "q925": "pgrb2b::SPFH::925 mb",
        "q950": "pgrb2b::SPFH::950 mb",
        "q975": "pgrb2b::SPFH::975 mb",
        "q1000": "pgrb2b::SPFH::1000 mb",
    }

    @classmethod
    def get_item(cls, val: str) -> tuple[str, Callable]:
        """Get item from GEFS vocabulary.

        Returns the GEFS key string and a modifier callable applied to the raw
        grib values before they are returned to the user.
        """
        gefs_key = cls.VOCAB[val]
        # Key layout is "{pgrb2 file}::{variable}::{level}", so the grib variable
        # name is the middle token. (Checking token 0 — the pgrb2a/pgrb2b file
        # name — would never match "HGT" and would skip the conversion below.)
        if gefs_key.split("::")[1] == "HGT":

            def mod(x: np.array) -> np.array:
                """Convert geopotential height (m) to geopotential (m^2 s^-2)."""
                return x * 9.81

        else:

            def mod(x: np.array) -> np.array:
                """Identity modifier, value used as-is."""
                return x

        return gefs_key, mod
import pathlib
import shutil
from datetime import datetime, timedelta

import numpy as np
import pytest

from earth2studio.data import GEFS_FX


@pytest.mark.slow
@pytest.mark.xfail
@pytest.mark.timeout(30)
@pytest.mark.parametrize(
    "time,lead_time,variable",
    [
        (
            datetime(year=2020, month=11, day=1),
            [timedelta(hours=3), timedelta(hours=384)],
            "t2m",
        ),
        (
            [
                datetime(year=2021, month=8, day=8, hour=6),
                datetime(year=2022, month=4, day=20, hour=12),
            ],
            timedelta(hours=0),
            ["msl"],
        ),
        (
            datetime(year=2020, month=11, day=1),
            [timedelta(hours=240)],
            np.array(["tcwv", "u10m"]),
        ),
    ],
)
def test_gefs_fetch(time, lead_time, variable):
    """Fetch several time/lead-time/variable combinations and check array shape."""
    ds = GEFS_FX(cache=False)
    data = ds(time, lead_time, variable)
    shape = data.shape

    if isinstance(variable, str):
        variable = [variable]

    if isinstance(lead_time, timedelta):
        lead_time = [lead_time]

    if isinstance(time, datetime):
        time = [time]

    assert shape[0] == len(time)
    assert shape[1] == len(lead_time)
    assert shape[2] == len(variable)
    assert shape[3] == 361
    assert shape[4] == 720
    assert not np.isnan(data.values).any()
    assert GEFS_FX.available(time[0])
    assert np.array_equal(data.coords["variable"].values, np.array(variable))


@pytest.mark.slow
@pytest.mark.xfail
@pytest.mark.timeout(30)
@pytest.mark.parametrize("product", ["gec00", "gep01", "gep30"])
def test_gefs_products(product):
    """Control and perturbed ensemble member products should all fetch."""
    time = datetime(year=2022, month=12, day=25)
    lead_time = timedelta(hours=3)
    variable = "u100m"

    ds = GEFS_FX(product, cache=False)
    data = ds(time, lead_time, variable)
    shape = data.shape

    assert shape[0] == 1
    assert shape[1] == 1
    assert shape[2] == 1
    assert shape[3] == 361
    assert shape[4] == 720
    assert not np.isnan(data.values).any()
    assert np.array_equal(data.coords["variable"].values, np.array([variable]))


@pytest.mark.slow
@pytest.mark.xfail
@pytest.mark.timeout(30)
@pytest.mark.parametrize(
    "time",
    [
        np.array([np.datetime64("2024-01-01T00:00")]),
    ],
)
@pytest.mark.parametrize("variable", [["t2m", "msl"]])
@pytest.mark.parametrize("cache", [True, False])
def test_gefs_cache(time, variable, cache):
    """Check local caching behavior for both cache settings."""
    lead_time = np.array([np.timedelta64(3, "h")])

    ds = GEFS_FX(cache=cache)
    data = ds(time, lead_time, variable)
    shape = data.shape

    assert shape[0] == 1
    assert shape[1] == 1
    assert shape[2] == 2
    assert shape[3] == 361
    assert shape[4] == 720
    assert not np.isnan(data.values).any()
    assert GEFS_FX.available(time[0])
    # Cache directory should only be present when caching is enabled
    assert pathlib.Path(ds.cache).is_dir() == cache

    # Load from cache or refetch
    data = ds(time, lead_time, variable[0])
    shape = data.shape

    assert shape[0] == 1
    assert shape[1] == 1
    assert shape[2] == 1
    assert shape[3] == 361
    assert shape[4] == 720
    assert not np.isnan(data.values).any()
    assert GEFS_FX.available(time[0])

    try:
        shutil.rmtree(ds.cache)
    except FileNotFoundError:
        pass


@pytest.mark.timeout(15)
@pytest.mark.parametrize(
    "time",
    [
        datetime(year=2020, month=9, day=22),
        datetime.now(),
    ],
)
def test_gefs_available(time):
    """Out-of-range times are reported unavailable and rejected on fetch."""
    # Variable is never resolved since time validation raises first
    variable = ["msl"]
    lead_time = timedelta(hours=0)
    assert not GEFS_FX.available(time)
    with pytest.raises(ValueError):
        ds = GEFS_FX(cache=False)
        ds(time, lead_time, variable)


@pytest.mark.timeout(5)
@pytest.mark.parametrize(
    "lead_time",
    [
        timedelta(hours=-1),
        [timedelta(hours=2), timedelta(hours=2, minutes=1)],
        np.array([np.timedelta64(243, "h")]),
        np.array([np.timedelta64(390, "h")]),
    ],
)
def test_gefs_invalid_lead(lead_time):
    """Negative, misaligned, and too-long lead times must raise."""
    time = datetime(year=2022, month=12, day=25)
    variable = "t2m"
    with pytest.raises(ValueError):
        ds = GEFS_FX(cache=False)
        ds(time, lead_time, variable)


@pytest.mark.timeout(5)
@pytest.mark.parametrize(
    "variable",
    ["aaa", "t1m"],
)
def test_gefs_invalid_variable(variable):
    """Variables missing from the GEFS lexicon must raise KeyError."""
    time = datetime(year=2022, month=12, day=25)
    lead_time = timedelta(hours=0)
    with pytest.raises(KeyError):
        ds = GEFS_FX(cache=False)
        ds(time, lead_time, variable)


@pytest.mark.timeout(5)
@pytest.mark.parametrize(
    "product",
    ["gec0", "gep31", "gep00"],
)
def test_gefs_invalid_product(product):
    """Unknown product names must raise at construction time."""
    with pytest.raises(ValueError):
        GEFS_FX(product, cache=False)