Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

provide list for filter_by_keys #384

Merged
merged 4 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions cfgrib/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,9 +332,9 @@ def get_values_in_order(message, shape):
class OnDiskArray:
index: abc.Index[T.Any, abc.Field]
shape: T.Tuple[int, ...]
field_id_index: T.Dict[
T.Tuple[T.Any, ...], T.List[T.Union[int, T.Tuple[int, int]]]
] = attr.attrib(repr=False)
field_id_index: T.Dict[T.Tuple[T.Any, ...], T.List[T.Union[int, T.Tuple[int, int]]]] = (
attr.attrib(repr=False)
)
missing_value: float
geo_ndim: int = attr.attrib(default=1, repr=False)
dtype = np.dtype("float32")
Expand Down
5 changes: 4 additions & 1 deletion cfgrib/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,10 @@ def subindex(self, filter_by_keys={}, **query):
field_ids_index = []
for header_values, field_ids_values in self.field_ids_index:
for idx, val in raw_query:
if header_values[idx] != val:
# Ensure that the values to be tested is a list or tuple
if not isinstance(val, (list, tuple)):
val = [val]
if header_values[idx] not in val:
break
else:
field_ids_index.append((header_values, field_ids_values))
Expand Down
Binary file added tests/sample-data/era5-levels-members.nc
Binary file not shown.
20 changes: 20 additions & 0 deletions tests/test_30_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
TEST_DATA_SCALAR_TIME = os.path.join(SAMPLE_DATA_FOLDER, "era5-single-level-scalar-time.grib")
TEST_DATA_ALTERNATE_ROWS = os.path.join(SAMPLE_DATA_FOLDER, "alternate-scanning.grib")
TEST_DATA_MISSING_VALS = os.path.join(SAMPLE_DATA_FOLDER, "fields_with_missing_values.grib")
TEST_DATA_MULTI_PARAMS = os.path.join(SAMPLE_DATA_FOLDER, "multi_param_on_multi_dims.grib")


def test_enforce_unique_attributes() -> None:
Expand Down Expand Up @@ -340,11 +341,30 @@ def test_open_fieldset_ignore_keys() -> None:
assert "GRIB_subCentre" not in res.attributes

def test_open_file() -> None:
res = dataset.open_file(TEST_DATA)

assert "t" in res.variables
assert "z" in res.variables


def test_open_file_filter_by_keys() -> None:
res = dataset.open_file(TEST_DATA, filter_by_keys={"shortName": "t"})

assert "t" in res.variables
assert "z" not in res.variables

res = dataset.open_file(TEST_DATA_MULTI_PARAMS)

assert "t" in res.variables
assert "z" in res.variables
assert "u" in res.variables

res = dataset.open_file(TEST_DATA_MULTI_PARAMS, filter_by_keys={"shortName": ["t", "z"]})

assert "t" in res.variables
assert "z" in res.variables
assert "u" not in res.variables


def test_alternating_rows() -> None:
res = dataset.open_file(TEST_DATA_ALTERNATE_ROWS)
Expand Down
25 changes: 25 additions & 0 deletions tests/test_50_xarray_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
SAMPLE_DATA_FOLDER = os.path.join(os.path.dirname(__file__), "sample-data")
TEST_DATA = os.path.join(SAMPLE_DATA_FOLDER, "regular_ll_sfc.grib")
TEST_DATA_MISSING_VALS = os.path.join(SAMPLE_DATA_FOLDER, "fields_with_missing_values.grib")
TEST_DATA_MULTI_PARAMS = os.path.join(SAMPLE_DATA_FOLDER, "multi_param_on_multi_dims.grib")


def test_plugin() -> None:
Expand All @@ -29,6 +30,30 @@ def test_xr_open_dataset_file() -> None:
assert list(ds.data_vars) == ["skt"]


def test_xr_open_dataset_file_filter_by_keys() -> None:
ds = xr.open_dataset(TEST_DATA_MULTI_PARAMS, engine="cfgrib")

assert "t" in ds.data_vars
assert "z" in ds.data_vars
assert "u" in ds.data_vars

ds = xr.open_dataset(
TEST_DATA_MULTI_PARAMS, engine="cfgrib", filter_by_keys={"shortName": "t"}
)

assert "t" in ds.data_vars
assert "z" not in ds.data_vars
assert "u" not in ds.data_vars

ds = xr.open_dataset(
TEST_DATA_MULTI_PARAMS, engine="cfgrib", filter_by_keys={"shortName": ["t", "z"]}
)

assert "t" in ds.data_vars
assert "z" in ds.data_vars
assert "u" not in ds.data_vars


def test_xr_open_dataset_file_ignore_keys() -> None:
ds = xr.open_dataset(TEST_DATA, engine="cfgrib")
assert "GRIB_typeOfLevel" in ds["skt"].attrs
Expand Down
Loading