Skip to content

Commit

Permalink
Merge pull request #150 from Open-EO/hv_unit_test_fetcher
Browse files Browse the repository at this point in the history
Hv unit test fetcher
  • Loading branch information
HansVRP authored Sep 11, 2024
2 parents 4e568b0 + 849bc55 commit 4602a46
Show file tree
Hide file tree
Showing 6 changed files with 627 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/openeo_gfmap/fetching/commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def _load_collection_hybrid(
return cube


# TODO; deprecated?
def _load_collection(
connection: openeo.Connection,
bands: list,
Expand Down
1 change: 1 addition & 0 deletions src/openeo_gfmap/fetching/s2.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def s2_l2a_fetch_default(
return s2_l2a_fetch_default


# TODO deprecated?
def _get_s2_l2a_element84_fetcher(
collection_name: str, fetch_type: FetchType
) -> Callable:
Expand Down
226 changes: 226 additions & 0 deletions tests/test_openeo_gfmap/test_commons.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
from unittest.mock import MagicMock

import openeo
import pytest

from openeo_gfmap.fetching.commons import (
convert_band_names,
rename_bands,
resample_reproject,
)
from openeo_gfmap.fetching.s2 import (
BASE_SENTINEL2_L2A_MAPPING,
ELEMENT84_SENTINEL2_L2A_MAPPING,
)

from .utils import create_test_datacube

# band names


def test_convert_band_names_base_mapping():
"""Test conversion with BASE_SENTINEL2_L2A_MAPPING."""
desired_bands = ["S2-L2A-B01", "S2-L2A-B03", "S2-L2A-B12"]
result = convert_band_names(desired_bands, BASE_SENTINEL2_L2A_MAPPING)
assert result == ["B01", "B03", "B12"]


def test_convert_band_names_element84_mapping():
"""Test conversion with ELEMENT84_SENTINEL2_L2A_MAPPING."""
desired_bands = ["S2-L2A-B01", "S2-L2A-B08", "S2-L2A-B12"]
result = convert_band_names(desired_bands, ELEMENT84_SENTINEL2_L2A_MAPPING)
assert result == ["coastal", "nir", "swir22"]


def test_convert_band_names_mixed_case():
"""Test conversion with a mix of known and unknown bands in BASE_SENTINEL2_L2A_MAPPING."""
desired_bands = ["S2-L2A-B01", "S2-L2A-B99"] # S2-L2A-B99 does not exist
with pytest.raises(KeyError):
convert_band_names(desired_bands, BASE_SENTINEL2_L2A_MAPPING)


def test_convert_band_names_empty_base_mapping():
"""Test conversion with an empty desired_bands list in BASE_SENTINEL2_L2A_MAPPING."""
desired_bands = []
result = convert_band_names(desired_bands, BASE_SENTINEL2_L2A_MAPPING)
assert result == []


def test_convert_band_names_empty_element84_mapping():
"""Test conversion with an empty desired_bands list in ELEMENT84_SENTINEL2_L2A_MAPPING."""
desired_bands = []
result = convert_band_names(desired_bands, ELEMENT84_SENTINEL2_L2A_MAPPING)
assert result == []


def test_convert_band_names_with_nonexistent_band_element84():
"""Test conversion where a band is not in ELEMENT84_SENTINEL2_L2A_MAPPING."""
desired_bands = ["S2-L2A-B01", "S2-L2A-B99"] # S2-L2A-B99 does not exist
with pytest.raises(KeyError):
convert_band_names(desired_bands, ELEMENT84_SENTINEL2_L2A_MAPPING)


# resampling
def test_resample_reproject_valid_epsg():
"""Test resample_reproject with a valid EPSG code."""
# Create a mock DataCube object
mock_datacube = MagicMock(spec=openeo.DataCube)

# Mock the resample_spatial method to return a new mock DataCube
mock_resampled_datacube = MagicMock(spec=openeo.DataCube)
mock_datacube.resample_spatial.return_value = mock_resampled_datacube

# Call the resample_reproject function
resample_reproject(
mock_datacube, resolution=10.0, epsg_code="4326", method="bilinear"
)

# Ensure resample_spatial was called correctly
mock_datacube.resample_spatial.assert_called_once_with(
resolution=10.0, projection="4326", method="bilinear"
)


# invalid espg
def test_resample_reproject_invalid_epsg():
"""Test resample_reproject with an invalid EPSG code."""
# Create a mock DataCube object
mock_datacube = MagicMock(spec=openeo.DataCube)

# Attempt to call the resample_reproject function with an invalid EPSG code
with pytest.raises(ValueError, match="is not a valid EPSG code"):
resample_reproject(
mock_datacube, resolution=10.0, epsg_code="invalid_epsg", method="bilinear"
)

# Ensure resample_spatial was not called
mock_datacube.resample_spatial.assert_not_called()


# valid resolution
def test_resample_reproject_only_resolution():
"""Test resample_reproject with only resolution provided."""
# Create a mock DataCube object
mock_datacube = MagicMock(spec=openeo.DataCube)

# Mock the resample_spatial method to return a new mock DataCube
mock_resampled_datacube = MagicMock(spec=openeo.DataCube)
mock_datacube.resample_spatial.return_value = mock_resampled_datacube

# Call the resample_reproject function with only resolution provided
resample_reproject(mock_datacube, resolution=20.0)

# Ensure resample_spatial was called correctly with the resolution and default method
mock_datacube.resample_spatial.assert_called_once_with(
resolution=20.0, method="near"
)


# default espg
def test_resample_reproject_no_epsg():
"""Test resample_reproject with no EPSG code provided."""
# Create a mock DataCube object
mock_datacube = MagicMock(spec=openeo.DataCube)

# Mock the resample_spatial method to return a new mock DataCube
mock_resampled_datacube = MagicMock(spec=openeo.DataCube)
mock_datacube.resample_spatial.return_value = mock_resampled_datacube

# Call the resample_reproject function without specifying an EPSG code
resample_reproject(
mock_datacube, resolution=10.0, epsg_code=None, method="bilinear"
)

# Ensure resample_spatial was called correctly without the projection argument
mock_datacube.resample_spatial.assert_called_once_with(
resolution=10.0, method="bilinear"
)


# default resampling
def test_resample_reproject_default_method():
"""Test resample_reproject with a valid EPSG code and default resampling method."""
# Create a mock DataCube object
mock_datacube = MagicMock(spec=openeo.DataCube)

# Mock the resample_spatial method to return a new mock DataCube
mock_resampled_datacube = MagicMock(spec=openeo.DataCube)
mock_datacube.resample_spatial.return_value = mock_resampled_datacube

# Call the resample_reproject function with default method ("near")
resample_reproject(mock_datacube, resolution=10.0, epsg_code="4326")

# Ensure resample_spatial was called correctly with the default method
mock_datacube.resample_spatial.assert_called_once_with(
resolution=10.0, projection="4326", method="near"
)


def test_rename_bands_all_present():
"""Test rename_bands when all bands in the mapping are present in the datacube."""

datacube = create_test_datacube()

mapping = {"B01": "coastal", "B02": "blue", "B03": "green"}

# Call the rename_bands function
result = rename_bands(datacube, mapping)

# Extract the band names from the result metadata
result_band_names = [band.name for band in result.metadata._dimensions[0].bands]

# Check that only the available bands have been renamed
expected_renamed_bands = [
"coastal",
"blue",
"green",
] + datacube.metadata.band_names[
3:
] # Assuming B04-B12 remain unchanged

assert result_band_names == expected_renamed_bands


def test_rename_bands_some_missing():
"""Test rename_bands when some bands are not present in the datacube."""

# Use the fixture with specific bands
datacube = create_test_datacube(bands=["B01", "B02"]) # Only include B01 and B02

mapping = {
"B01": "coastal",
"B02": "blue",
"B03": "green", # B03 is not in the datacube
}

# Call the rename_bands function
result = rename_bands(datacube, mapping)

# Extract the band names from the result metadata
result_band_names = [band.name for band in result.metadata._dimensions[0].bands]

# Check that only the available bands have been renamed
expected_renamed_bands = ["coastal", "blue"]
assert result_band_names == expected_renamed_bands


def test_rename_bands_no_bands_present():
"""Test rename_bands when no bands in the mapping are present in the datacube."""

# Use the fixture with specific bands
datacube = create_test_datacube(bands=["B04", "B05"]) # Only include B04 and B05

mapping = {
"B01": "coastal",
"B02": "blue",
"B03": "green", # None of these bands are present in the datacube
}

# Call the rename_bands function
result = rename_bands(datacube, mapping)

# Extract the band names from the result metadata
result_band_names = [band.name for band in result.metadata._dimensions[0].bands]

# Check that only the available bands have been renamed
assert result_band_names == []
142 changes: 142 additions & 0 deletions tests/test_openeo_gfmap/test_unit_collection_fetcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
from unittest.mock import MagicMock

import pytest
import xarray as xr

from openeo_gfmap import BoundingBoxExtent, TemporalContext
from openeo_gfmap.backend import BackendContext
from openeo_gfmap.fetching import CollectionFetcher


@pytest.fixture
def mock_backend_context():
"""Fixture to create a mock backend context."""
return MagicMock(spec=BackendContext)


@pytest.fixture
def mock_connection():
"""Fixture to create a mock connection."""
return MagicMock()


@pytest.fixture
def mock_spatial_extent():
"""Fixture for spatial extent."""
return BoundingBoxExtent(
west=5.0515130512706845,
south=51.215806593713,
east=5.060320484557499,
north=51.22149744530769,
epsg=4326,
)


@pytest.fixture
def mock_temporal_context():
"""Fixture for temporal context."""
return TemporalContext(start_date="2023-04-01", end_date="2023-05-01")


@pytest.fixture
def mock_collection_fetch():
"""Fixture for the collection fetch function."""
return MagicMock(return_value=xr.DataArray([1, 2, 3], dims=["bands"]))


@pytest.fixture
def mock_collection_preprocessing():
"""Fixture for the collection preprocessing function."""
return MagicMock(return_value=xr.DataArray([1, 2, 3], dims=["bands"]))


def test_collection_fetcher(
mock_connection,
mock_spatial_extent,
mock_temporal_context,
mock_backend_context,
mock_collection_fetch,
mock_collection_preprocessing,
):
"""Test CollectionFetcher with basic data fetching."""

# Create the CollectionFetcher with the mock functions
fetcher = CollectionFetcher(
backend_context=mock_backend_context,
bands=["B01", "B02", "B03"],
collection_fetch=mock_collection_fetch, # Use the mock collection fetch function
collection_preprocessing=mock_collection_preprocessing, # Use the mock preprocessing function
)

# Call the method you're testing
result = fetcher.get_cube(
mock_connection, mock_spatial_extent, mock_temporal_context
)

# Assertions to check if everything works as expected
assert isinstance(
result, xr.DataArray
) # Check if the result is an xarray DataArray
assert result.dims == ("bands",) # Ensure the dimensions are as expected
assert result.values.tolist() == [
1,
2,
3,
] # Ensure the values match the expected output


def test_collection_fetcher_get_cube(
mock_connection,
mock_spatial_extent,
mock_temporal_context,
mock_backend_context,
mock_collection_fetch,
mock_collection_preprocessing,
):
"""Test that CollectionFetcher.get_cube is called correctly."""

bands = ["S2-L2A-B01", "S2-L2A-B02"]

# Create the CollectionFetcher with the mock functions
fetcher = CollectionFetcher(
backend_context=mock_backend_context,
bands=bands,
collection_fetch=mock_collection_fetch, # Use the mock collection fetch function
collection_preprocessing=mock_collection_preprocessing, # Use the mock preprocessing function
)

# Call the method you're testing
result = fetcher.get_cube(
mock_connection, mock_spatial_extent, mock_temporal_context
)

# Assert the fetch method was called with the correct arguments
mock_collection_fetch.assert_called_once_with(
mock_connection,
mock_spatial_extent,
mock_temporal_context,
bands,
**fetcher.params, # Check for additional parameters if any
)

# Assert the preprocessing method was called once
mock_collection_preprocessing.assert_called_once()

# Assert that the result is an instance of xarray.DataArray
assert isinstance(result, xr.DataArray)


def test_collection_fetcher_with_empty_bands(
mock_backend_context, mock_connection, mock_spatial_extent, mock_temporal_context
):
"""Test that CollectionFetcher raises an error with empty bands."""
bands = []
fetcher = CollectionFetcher(
backend_context=mock_backend_context,
bands=bands,
collection_fetch=MagicMock(), # No need to mock fetch here
collection_preprocessing=MagicMock(),
)

# with pytest.raises(ValueError, match="Bands cannot be empty"):
fetcher.get_cube(mock_connection, mock_spatial_extent, mock_temporal_context)
Loading

0 comments on commit 4602a46

Please sign in to comment.