Skip to content

Commit

Permalink
Issue #16 blackify src/ and tests/
Browse files Browse the repository at this point in the history
  • Loading branch information
soxofaan committed Feb 1, 2024
1 parent ed40736 commit 53675b9
Show file tree
Hide file tree
Showing 21 changed files with 251 additions and 323 deletions.
8 changes: 2 additions & 6 deletions src/openeo_gfmap/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,7 @@ class BackendContext:
backend: Backend


def _create_connection(
url: str, *, env_var_suffix: str, connect_kwargs: Optional[dict] = None
):
def _create_connection(url: str, *, env_var_suffix: str, connect_kwargs: Optional[dict] = None):
"""
Generic helper to create an openEO connection
with support for multiple client credential configurations from environment variables
Expand Down Expand Up @@ -63,9 +61,7 @@ def _create_connection(

# Use a shorter max poll time by default to alleviate the default impression that the test seem to hang
# because of the OIDC device code poll loop.
max_poll_time = int(
os.environ.get("OPENEO_OIDC_DEVICE_CODE_MAX_POLL_TIME") or 30
)
max_poll_time = int(os.environ.get("OPENEO_OIDC_DEVICE_CODE_MAX_POLL_TIME") or 30)
connection.authenticate_oidc(max_poll_time=max_poll_time)
return connection

Expand Down
4 changes: 1 addition & 3 deletions src/openeo_gfmap/extractions/s2.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ def s2_l2a_fetch_default(
), "CRS not defined within GeoJSON collection."
spatial_extent = dict(spatial_extent)

cube = connection.load_collection(
collection_name, spatial_extent, temporal_extent, bands
)
cube = connection.load_collection(collection_name, spatial_extent, temporal_extent, bands)

# Apply if the collection is a GeoJSON Feature collection
if isinstance(spatial_extent, GeoJSON):
Expand Down
6 changes: 4 additions & 2 deletions src/openeo_gfmap/fetching/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from .s2 import build_sentinel2_l2a_extractor

__all__ = [
"build_sentinel2_l2a_extractor", "CollectionFetcher", "FetchType",
"build_sentinel1_grd_extractor"
"build_sentinel2_l2a_extractor",
"CollectionFetcher",
"FetchType",
"build_sentinel1_grd_extractor",
]
3 changes: 1 addition & 2 deletions src/openeo_gfmap/fetching/commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,7 @@ def load_collection(
pre_mask = params.get("pre_mask", None)
if pre_mask is not None:
assert isinstance(pre_mask, openeo.DataCube), (
f"The 'pre_mask' parameter must be an openeo datacube, "
f"got {pre_mask}."
f"The 'pre_mask' parameter must be an openeo datacube, " f"got {pre_mask}."
)
cube = cube.mask(pre_mask.resample_cube_spatial(cube))

Expand Down
32 changes: 9 additions & 23 deletions src/openeo_gfmap/fetching/s1.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,7 @@ def s1_grd_fetch_default(
return s1_grd_fetch_default


def get_s1_grd_default_processor(
collection_name: str, fetch_type: FetchType
) -> Callable:
def get_s1_grd_default_processor(collection_name: str, fetch_type: FetchType) -> Callable:
"""Builds the preprocessing function from the collection name as it is stored
in the target backend.
"""
Expand All @@ -113,9 +111,7 @@ def s1_grd_default_processor(cube: openeo.DataCube, **params):
)

cube = resample_reproject(
cube,
params.get("target_resolution", 10.0),
params.get("target_crs", None)
cube, params.get("target_resolution", 10.0), params.get("target_crs", None)
)

cube = rename_bands(cube, BASE_SENTINEL1_GRD_MAPPING)
Expand All @@ -127,35 +123,25 @@ def s1_grd_default_processor(cube: openeo.DataCube, **params):

SENTINEL1_GRD_BACKEND_MAP = {
Backend.TERRASCOPE: {
"default": partial(
get_s1_grd_default_fetcher, collection_name="SENTINEL1_GRD"
),
"preprocessor": partial(
get_s1_grd_default_processor, collection_name="SENTINEL1_GRD"
)
"default": partial(get_s1_grd_default_fetcher, collection_name="SENTINEL1_GRD"),
"preprocessor": partial(get_s1_grd_default_processor, collection_name="SENTINEL1_GRD"),
},
Backend.CDSE: {
"default": partial(
get_s1_grd_default_fetcher, collection_name="SENTINEL1_GRD"
),
"preprocessor": partial(
get_s1_grd_default_processor, collection_name="SENTINEL1_GRD"
)
}
"default": partial(get_s1_grd_default_fetcher, collection_name="SENTINEL1_GRD"),
"preprocessor": partial(get_s1_grd_default_processor, collection_name="SENTINEL1_GRD"),
},
}


def build_sentinel1_grd_extractor(
backend_context: BackendContext, bands: list, fetch_type: FetchType, **params
) -> CollectionFetcher:
""" Creates a S1 GRD collection extractor for the given backend."""
"""Creates a S1 GRD collection extractor for the given backend."""
backend_functions = SENTINEL1_GRD_BACKEND_MAP.get(backend_context.backend)

fetcher, preprocessor = (
backend_functions["default"](fetch_type=fetch_type),
backend_functions["preprocessor"](fetch_type=fetch_type),
)

return CollectionFetcher(
backend_context, bands, fetcher, preprocessor, **params
)
return CollectionFetcher(backend_context, bands, fetcher, preprocessor, **params)
16 changes: 4 additions & 12 deletions src/openeo_gfmap/fetching/s2.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,7 @@ def s2_l2a_fetch_default(
return s2_l2a_fetch_default


def get_s2_l2a_element84_fetcher(
collection_name: str, fetch_type: FetchType
) -> Callable:
def get_s2_l2a_element84_fetcher(collection_name: str, fetch_type: FetchType) -> Callable:
"""Fetches the collections from the Sentinel-2 Cloud-Optimized GeoTIFFs
bucket provided by Amazon and managed by Element84.
"""
Expand Down Expand Up @@ -157,9 +155,7 @@ def s2_l2a_element84_fetcher(
return s2_l2a_element84_fetcher


def get_s2_l2a_default_processor(
collection_name: str, fetch_type: FetchType
) -> Callable:
def get_s2_l2a_default_processor(collection_name: str, fetch_type: FetchType) -> Callable:
"""Builds the preprocessing function from the collection name as it stored
in the target backend.
"""
Expand Down Expand Up @@ -188,15 +184,11 @@ def s2_l2a_default_processor(cube: openeo.DataCube, **params):
SENTINEL2_L2A_BACKEND_MAP = {
Backend.TERRASCOPE: {
"fetch": partial(get_s2_l2a_default_fetcher, collection_name="SENTINEL2_L2A"),
"preprocessor": partial(
get_s2_l2a_default_processor, collection_name="SENTINEL2_L2A"
),
"preprocessor": partial(get_s2_l2a_default_processor, collection_name="SENTINEL2_L2A"),
},
Backend.CDSE: {
"fetch": partial(get_s2_l2a_default_fetcher, collection_name="SENTINEL2_L2A"),
"preprocessor": partial(
get_s2_l2a_default_processor, collection_name="SENTINEL2_L2A"
),
"preprocessor": partial(get_s2_l2a_default_processor, collection_name="SENTINEL2_L2A"),
},
}

Expand Down
2 changes: 1 addition & 1 deletion src/openeo_gfmap/preprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@
"get_bap_score",
"get_bap_mask",
"bap_masking",
]
]
71 changes: 35 additions & 36 deletions src/openeo_gfmap/preprocessing/cloudmasking.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,16 @@
SCL_HARMONIZED_NAME: str = "S2-SCL"
BAPSCORE_HARMONIZED_NAME: str = "S2-BAPSCORE"


def mask_scl_dilation(cube: openeo.DataCube, **params: dict) -> openeo.DataCube:
"""Creates a mask from the SCL, dilates it and applies the mask to the optical
bands of the datacube. The other bands such as DEM, SAR and METEO will not
be affected by the mask.
"""
# Asserts if the SCL layer exists
assert SCL_HARMONIZED_NAME in cube.metadata.band_names, (
f"The SCL band ({SCL_HARMONIZED_NAME}) is not present in the datacube."
)
assert (
SCL_HARMONIZED_NAME in cube.metadata.band_names
), f"The SCL band ({SCL_HARMONIZED_NAME}) is not present in the datacube."

kernel1_size = params.get("kernel1_size", 17)
kernel2_size = params.get("kernel2_size", 3)
Expand All @@ -28,15 +29,11 @@ def mask_scl_dilation(cube: openeo.DataCube, **params: dict) -> openeo.DataCube:

# Only applies the filtering to the optical part of the cube
optical_cube = cube.filter_bands(
bands=list(
filter(lambda band: band.startswith("S2"), cube.metadata.band_names)
)
bands=list(filter(lambda band: band.startswith("S2"), cube.metadata.band_names))
)

nonoptical_cube = cube.filter_bands(
bands=list(
filter(lambda band: not band.startswith("S2"), cube.metadata.band_names)
)
bands=list(filter(lambda band: not band.startswith("S2"), cube.metadata.band_names))
)

optical_cube = optical_cube.process(
Expand All @@ -47,14 +44,15 @@ def mask_scl_dilation(cube: openeo.DataCube, **params: dict) -> openeo.DataCube:
kernel2_size=kernel2_size,
mask1_values=[2, 4, 5, 6, 7],
mask2_values=[3, 8, 9, 10, 11],
erosion_kernel_size=erosion_kernel_size
erosion_kernel_size=erosion_kernel_size,
)

if len(nonoptical_cube.metadata.band_names) == 0:
return optical_cube

return optical_cube.merge_cubes(nonoptical_cube)


def get_bap_score(cube: openeo.DataCube, **params: dict) -> openeo.DataCube:
"""Calculates the Best Available Pixel (BAP) score for the given datacube,
computed from the SCL layer.
Expand Down Expand Up @@ -113,23 +111,30 @@ def get_bap_score(cube: openeo.DataCube, **params: dict) -> openeo.DataCube:
kernel2_size=kernel2_size,
mask1_values=[2, 4, 5, 6, 7],
mask2_values=[3, 8, 9, 10, 11],
erosion_kernel_size=erosion_kernel_size
erosion_kernel_size=erosion_kernel_size,
)

# Replace NaN to 0 to avoid issues in the UDF
scl_cube = scl_cube.apply(lambda x: if_(is_nan(x), 0, x))

score = scl_cube.apply_neighborhood(
process=openeo.UDF.from_file(str(udf_path)),
size=[{'dimension': 'x', 'unit': 'px', 'value': 256}, {'dimension': 'y', 'unit': 'px', 'value': 256}],
overlap=[{'dimension': 'x', 'unit': 'px', 'value': 16}, {'dimension': 'y', 'unit': 'px', 'value': 16}],
size=[
{"dimension": "x", "unit": "px", "value": 256},
{"dimension": "y", "unit": "px", "value": 256},
],
overlap=[
{"dimension": "x", "unit": "px", "value": 16},
{"dimension": "y", "unit": "px", "value": 16},
],
)

score = score.rename_labels('bands', [BAPSCORE_HARMONIZED_NAME])
score = score.rename_labels("bands", [BAPSCORE_HARMONIZED_NAME])

# Merge the score to the scl cube
return score


def get_bap_mask(cube: openeo.DataCube, period: Union[str, list], **params: dict):
"""Computes the bap score and masks the optical bands of the datacube using
the best scores for each pixel on a given time period. This method both
Expand All @@ -155,13 +160,14 @@ def get_bap_mask(cube: openeo.DataCube, period: Union[str, list], **params: dict
The datacube with the BAP mask applied.
"""
# Checks if the S2-SCL band is present in the datacube
assert SCL_HARMONIZED_NAME in cube.metadata.band_names, (
f"The {SCL_HARMONIZED_NAME} band is not present in the datacube."
)
assert (
SCL_HARMONIZED_NAME in cube.metadata.band_names
), f"The {SCL_HARMONIZED_NAME} band is not present in the datacube."

bap_score = get_bap_score(cube, **params)

if isinstance(period, str):

def max_score_selection(score):
max_score = score.max()
return score.array_apply(lambda x: x != max_score)
Expand All @@ -171,27 +177,26 @@ def max_score_selection(score):
size=[
{"dimension": "x", "unit": "px", "value": 1},
{"dimension": "y", "unit": "px", "value": 1},
{"dimension": "t", "value": period}
{"dimension": "t", "value": period},
],
overlap=[]
overlap=[],
)
elif isinstance(period, list):
udf_path = Path(__file__).parent / "udf_rank.py"
rank_mask = bap_score.apply_neighborhood(
process=openeo.UDF.from_file(
str(udf_path),
context={"intervals": period}
),
process=openeo.UDF.from_file(str(udf_path), context={"intervals": period}),
size=[
{'dimension': 'x', 'unit': 'px', 'value': 256},
{'dimension': 'y', 'unit': 'px', 'value': 256}
{"dimension": "x", "unit": "px", "value": 256},
{"dimension": "y", "unit": "px", "value": 256},
],
overlap=[],
)
else:
raise ValueError(f"'period' must be a string or a list of dates (in YYYY-mm-dd format), got {period}.")
raise ValueError(
f"'period' must be a string or a list of dates (in YYYY-mm-dd format), got {period}."
)

return rank_mask.rename_labels('bands', ['S2-BAPMASK'])
return rank_mask.rename_labels("bands", ["S2-BAPMASK"])


def bap_masking(cube: openeo.DataCube, period: Union[str, list], **params: dict):
Expand All @@ -213,22 +218,16 @@ def bap_masking(cube: openeo.DataCube, period: Union[str, list], **params: dict)
The datacube with the BAP mask applied.
"""
optical_cube = cube.filter_bands(
bands=list(
filter(lambda band: band.startswith("S2"), cube.metadata.band_names)
)
bands=list(filter(lambda band: band.startswith("S2"), cube.metadata.band_names))
)

nonoptical_cube = cube.filter_bands(
bands=list(
filter(lambda band: not band.startswith("S2"), cube.metadata.band_names)
)
bands=list(filter(lambda band: not band.startswith("S2"), cube.metadata.band_names))
)

rank_mask = get_bap_mask(optical_cube, period, **params)

optical_cube = optical_cube.mask(
rank_mask.resample_cube_spatial(cube)
)
optical_cube = optical_cube.mask(rank_mask.resample_cube_spatial(cube))

# Do not merge if bands are empty!
if len(nonoptical_cube.metadata.band_names) == 0:
Expand Down
1 change: 1 addition & 0 deletions src/openeo_gfmap/preprocessing/compositing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def median_compositing(cube: openeo.DataCube, period: Union[str, list]) -> opene
elif isinstance(period, list):
return cube.aggregate_temporal(intervals=period, reducer="median", dimension="t")


def mean_compositing(cube: openeo.DataCube, period: str) -> openeo.DataCube:
"""Perfrom mean compositing on the given datacube."""
if isinstance(period, str):
Expand Down
8 changes: 4 additions & 4 deletions src/openeo_gfmap/preprocessing/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import openeo


def linear_interpolation(cube: openeo.DataCube,) -> openeo.DataCube:
def linear_interpolation(
cube: openeo.DataCube,
) -> openeo.DataCube:
"""Perform linear interpolation on the given datacube."""
return cube.apply_dimension(
dimension="t", process="array_interpolate_linear"
)
return cube.apply_dimension(dimension="t", process="array_interpolate_linear")
8 changes: 2 additions & 6 deletions src/openeo_gfmap/preprocessing/udf_rank.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

import numpy as np
import xarray as xr
from openeo.udf import XarrayDataCube
Expand All @@ -15,15 +14,14 @@ def apply_datacube(cube: XarrayDataCube, context: dict) -> XarrayDataCube:
"""
# First check if the period is defined in the context
intervals = context.get("intervals", None)
array = cube.get_array().transpose('t', 'bands', 'y', 'x')
array = cube.get_array().transpose("t", "bands", "y", "x")

bap_score = array.sel(bands="S2-BAPSCORE")

def select_maximum(score: xr.DataArray):
max_score = score.max(dim="t")
return score == max_score


if isinstance(intervals, str):
raise NotImplementedError(
"Period as string is not implemented yet, please provide a list of interval tuples."
Expand All @@ -32,9 +30,7 @@ def select_maximum(score: xr.DataArray):
# Convert YYYY-mm-dd to datetime64 objects
time_bins = [np.datetime64(interval[0]) for interval in intervals]

rank_mask = bap_score.groupby_bins('t', bins=time_bins).map(
select_maximum
)
rank_mask = bap_score.groupby_bins("t", bins=time_bins).map(select_maximum)
else:
raise ValueError("Period is not defined in the UDF. Cannot run it.")

Expand Down
Loading

0 comments on commit 53675b9

Please sign in to comment.