From 8089efce3e66c1339e350bc99398c54838e68f35 Mon Sep 17 00:00:00 2001 From: Darius Couchard Date: Thu, 20 Jun 2024 15:53:54 +0200 Subject: [PATCH] Pushing unit-test fixes --- src/openeo_gfmap/inference/model_inference.py | 3 +++ tests/test_openeo_gfmap/test_feature_extractors.py | 4 +++- tests/test_openeo_gfmap/test_s1_fetchers.py | 2 +- tests/test_openeo_gfmap/test_s2_fetchers.py | 2 +- 4 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/openeo_gfmap/inference/model_inference.py b/src/openeo_gfmap/inference/model_inference.py index 226f8cf..05c1bea 100644 --- a/src/openeo_gfmap/inference/model_inference.py +++ b/src/openeo_gfmap/inference/model_inference.py @@ -169,6 +169,9 @@ class ONNXModelInference(ModelInference): """ + def dependencies(self) -> list: + return [] # Disable dependencies + def output_labels(self) -> list: return self._parameters["output_labels"] diff --git a/tests/test_openeo_gfmap/test_feature_extractors.py b/tests/test_openeo_gfmap/test_feature_extractors.py index 7dacb57..16cf657 100644 --- a/tests/test_openeo_gfmap/test_feature_extractors.py +++ b/tests/test_openeo_gfmap/test_feature_extractors.py @@ -227,7 +227,9 @@ def test_patch_feature_local(): bands=[band for band in inds.bands.to_numpy() if band != "crs"] ).transpose("bands", "t", "y", "x") - features = apply_feature_extractor_local(DummyPatchExtractor, inds, parameters={}) + features = apply_feature_extractor_local( + DummyPatchExtractor, inds, parameters={"GEO-EPSG": 32631} + ) features.to_netcdf(Path(__file__).parent / "results/patch_features_local.nc") diff --git a/tests/test_openeo_gfmap/test_s1_fetchers.py b/tests/test_openeo_gfmap/test_s1_fetchers.py index 364ef8d..c36c099 100644 --- a/tests/test_openeo_gfmap/test_s1_fetchers.py +++ b/tests/test_openeo_gfmap/test_s1_fetchers.py @@ -100,7 +100,7 @@ def compare_sentinel1_tiles(): tile_path = ( Path(__file__).parent / f"results/{backend.value}_sentinel1_grd.nc" ) - loaded_tiles.append(xr.open_dataset(tile_path, engine="h5netcdf")) + loaded_tiles.append(xr.open_dataset(tile_path)) # Compare the variable data type dtype = None diff --git a/tests/test_openeo_gfmap/test_s2_fetchers.py b/tests/test_openeo_gfmap/test_s2_fetchers.py index 0827374..e7318bb 100644 --- a/tests/test_openeo_gfmap/test_s2_fetchers.py +++ b/tests/test_openeo_gfmap/test_s2_fetchers.py @@ -129,7 +129,7 @@ def compare_sentinel2_tiles(): tile_path = ( Path(__file__).parent / f"results/{backend.value}_sentinel2_l2a.nc" ) - loaded_tiles.append(xr.open_dataset(tile_path, engine="h5netcdf")) + loaded_tiles.append(xr.open_dataset(tile_path)) # Compare the tile variable types all togheter dtype = None