Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/Open-EO/openeo-gfmap into s…
Browse files Browse the repository at this point in the history
…1-extraction-fixes-PR2

Conflicts:
	tests/test_openeo_gfmap/test_utils.py
  • Loading branch information
GriffinBabe committed Aug 26, 2024
2 parents 9f06d62 + 8d9a808 commit 11593f2
Show file tree
Hide file tree
Showing 4 changed files with 243 additions and 3 deletions.
2 changes: 2 additions & 0 deletions src/openeo_gfmap/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from openeo_gfmap.utils.build_df import load_json
from openeo_gfmap.utils.intervals import quintad_intervals
from openeo_gfmap.utils.netcdf import update_nc_attributes
from openeo_gfmap.utils.split_stac import split_collection_by_epsg
from openeo_gfmap.utils.tile_processing import (
array_bounds,
arrays_cosine_similarity,
Expand All @@ -19,5 +20,6 @@
"select_sar_bands",
"arrays_cosine_similarity",
"quintad_intervals",
"split_collection_by_epsg",
"update_nc_attributes",
]
10 changes: 9 additions & 1 deletion src/openeo_gfmap/utils/catalogue.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@ def _query_cdse_catalogue(
temporal_extent: TemporalContext,
**additional_parameters: dict,
) -> dict:
"""
Queries the CDSE catalogue for a given collection, spatio-temporal context and additional
parameters.
Params
------
"""
minx, miny, maxx, maxy = bounds

# The date format should be YYYY-MM-DD
Expand Down Expand Up @@ -228,7 +236,7 @@ def s1_area_per_orbitstate_vvvh(
}


def select_S1_orbitstate_vvvh(
def select_s1_orbitstate_vvvh(
backend: BackendContext,
spatial_extent: SpatialContext,
temporal_extent: TemporalContext,
Expand Down
133 changes: 133 additions & 0 deletions src/openeo_gfmap/utils/split_stac.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
"""Utility function to split a STAC collection into multiple STAC collections based on CRS.
Requires the "proj:epsg" property to be present in all the STAC items.
"""

import os
from pathlib import Path
from typing import Union

import pystac


def _extract_epsg_from_stac_item(stac_item: pystac.Item) -> int:
    """
    Extract the EPSG code from a STAC item.

    Parameters:
        stac_item (pystac.Item): The STAC item.

    Returns:
        int: The value stored under the "proj:epsg" property of the item.

    Raises:
        KeyError: If the "proj:epsg" property is missing from the STAC item.
    """
    try:
        return stac_item.properties["proj:epsg"]
    except KeyError as err:
        # Re-raise with a descriptive message while keeping the original
        # lookup failure attached as the explicit cause.
        raise KeyError(
            "The 'proj:epsg' property is missing from the STAC item."
        ) from err


def _create_item_by_epsg_dict(collection: pystac.Collection) -> dict:
    """
    Group the items of a STAC collection by their EPSG code.

    Parameters:
        collection (pystac.Collection): The STAC collection whose items are grouped.

    Returns:
        dict: Mapping of EPSG code to the list of items carrying that code.

    Raises:
        KeyError: If an item lacks the "proj:epsg" property.
    """
    grouped = {}
    for stac_item in collection.get_items():
        code = _extract_epsg_from_stac_item(stac_item)
        # setdefault creates the bucket the first time a code is seen.
        grouped.setdefault(code, []).append(stac_item)
    return grouped


def _create_new_epsg_collection(
    epsg: int, items: list, collection: pystac.Collection
) -> pystac.Collection:
    """
    Build a STAC collection holding only the items of one EPSG code.

    The result is a clone of ``collection`` whose id and description are
    suffixed with the EPSG code, whose items are replaced by ``items`` and
    whose extent is recomputed from those items.

    Parameters:
        epsg (int): The EPSG code shared by ``items``.
        items (list): The items to put in the new collection.
        collection (pystac.Collection): The original STAC collection.

    Returns:
        pystac.Collection: The new, EPSG-specific STAC collection.
    """
    epsg_collection = collection.clone()
    epsg_collection.id = f"{collection.id}_{epsg}"
    description = (
        f"{collection.description} Containing only items with EPSG code {epsg}"
    )
    epsg_collection.description = description

    # Start from an empty item set, then attach only the items of this EPSG.
    epsg_collection.clear_items()
    for stac_item in items:
        epsg_collection.add_item(stac_item)

    # The extent must reflect the reduced item set.
    epsg_collection.update_extent_from_items()
    return epsg_collection


def _create_collection_by_epsg_dict(collection: pystac.Collection) -> dict:
    """
    Build one STAC collection per EPSG code found in ``collection``.

    Parameters:
        collection (pystac.Collection): The STAC collection to split.

    Returns:
        dict: Mapping of EPSG code to the STAC collection built for that code.
    """
    # The grouping is computed in full before any new collection is created.
    return {
        code: _create_new_epsg_collection(code, group, collection)
        for code, group in _create_item_by_epsg_dict(collection).items()
    }


def _write_collection_dict(collection_dict: dict, output_dir: Union[str, Path]):
    """
    Save every collection of the dictionary under the output directory.

    Each collection has its hrefs normalized to the sub-directory
    ``collection-<epsg>`` of ``output_dir`` and is then saved.

    Parameters:
        collection_dict (dict): Mapping of EPSG code to STAC collection.
        output_dir (str or Path): The output directory; created if missing.
    """
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    for code, epsg_collection in collection_dict.items():
        root_href = os.path.join(target_dir, f"collection-{code}")
        epsg_collection.normalize_hrefs(root_href)
        epsg_collection.save()


def split_collection_by_epsg(path: Union[str, Path], output_dir: Union[str, Path]):
    """
    Split a STAC collection into multiple STAC collections based on EPSG code.

    One collection per distinct "proj:epsg" value is written to the
    sub-directory ``collection-<epsg>`` of ``output_dir``.

    Parameters:
        path (str or Path): The path to the STAC collection.
        output_dir (str or Path): The output directory.

    Raises:
        pystac.STACError: If the file cannot be read as a STAC object, or if
            the STAC object it contains is not a collection.
        KeyError: If an item of the collection lacks the "proj:epsg" property.
    """
    path = Path(path)
    try:
        collection = pystac.read_file(path)
    except pystac.STACError as err:
        # Fail here with a clear message instead of printing and falling
        # through to an UnboundLocalError on `collection` below.
        raise pystac.STACError(
            "Please provide a path to a valid STAC collection."
        ) from err

    # read_file returns whichever STAC object the file holds; only a
    # collection can be split.
    if not isinstance(collection, pystac.Collection):
        raise pystac.STACError(
            "Please provide a path to a valid STAC collection. "
            f"Found a {type(collection).__name__} at {path}."
        )

    collection_dict = _create_collection_by_epsg_dict(collection)
    _write_collection_dict(collection_dict, output_dir)
101 changes: 99 additions & 2 deletions tests/test_openeo_gfmap/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import os
from pathlib import Path

import pystac
import pytest
from netCDF4 import Dataset

from openeo_gfmap import Backend, BackendContext, BoundingBoxExtent, TemporalContext
from openeo_gfmap.utils import update_nc_attributes
from openeo_gfmap.utils.catalogue import (
s1_area_per_orbitstate_vvvh,
select_S1_orbitstate_vvvh,
select_s1_orbitstate_vvvh,
)

# Region of Paris, France
Expand Down Expand Up @@ -44,7 +45,7 @@ def test_query_cdse_catalogue():
assert response["DESCENDING"]["full_overlap"] is True

# Testing the decision maker, it should return DESCENDING
decision = select_S1_orbitstate_vvvh(
decision = select_s1_orbitstate_vvvh(
backend=backend_context,
spatial_extent=SPATIAL_CONTEXT,
temporal_extent=TEMPORAL_CONTEXT,
Expand Down Expand Up @@ -79,3 +80,99 @@ def test_update_nc_attributes(temp_nc_file):
assert getattr(nc, attr_name) == attr_value
assert "existing_attribute" in nc.ncattrs()
assert nc.getncattr("existing_attribute") == "existing_value"


def _make_test_stac_item(item_id, epsg=None):
    """Build a minimal STAC item; "proj:epsg" is set only when ``epsg`` is given."""
    properties = {
        "datetime": "2020-05-22T00:00:00Z",
        "eo:bands": [{"name": "SCL"}, {"name": "B08"}],
    }
    if epsg is not None:
        properties["proj:epsg"] = epsg

    return pystac.item.Item.from_dict(
        {
            "type": "Feature",
            "stac_version": "1.0.0",
            "id": item_id,
            "properties": properties,
            "geometry": {
                "coordinates": [[[0, 0], [0, 1], [1, 1], [1, 0], [0, 0]]],
                "type": "Polygon",
            },
            "links": [],
            "assets": {},
            # [min_x, min_y, max_x, max_y] of the unit square above.
            "bbox": [0, 0, 1, 1],
            "stac_extensions": [],
        }
    )


def test_split_collection_by_epsg(tmp_path):
    """Splitting writes one collection per EPSG code and rejects items without one."""
    # Imported locally so the test does not rely on a module-level import of
    # the function under test.
    from openeo_gfmap.utils import split_collection_by_epsg

    collection = pystac.collection.Collection.from_dict(
        {
            "type": "Collection",
            "id": "test-collection",
            "stac_version": "1.0.0",
            "description": "Test collection",
            "links": [],
            "title": "Test Collection",
            "extent": {
                "spatial": {"bbox": [[-180.0, -90.0, 180.0, 90.0]]},
                "temporal": {
                    "interval": [["2020-01-01T00:00:00Z", "2020-01-10T00:00:00Z"]]
                },
            },
            "license": "proprietary",
            "summaries": {"eo:bands": [{"name": "B01"}, {"name": "B02"}]},
        }
    )
    collection.add_items(
        [
            _make_test_stac_item("4326-item", epsg=4326),
            _make_test_stac_item("3857-item", epsg=3857),
        ]
    )
    collection_path = str(tmp_path / "collection.json")
    output_dir = str(tmp_path / "split_collections")

    collection.normalize_and_save(collection_path)
    split_collection_by_epsg(path=collection_path, output_dir=output_dir)

    # Collection contains two different EPSG codes, so 2 collections should be created
    created_dirs = {p.name for p in Path(output_dir).iterdir() if p.is_dir()}
    assert created_dirs == {"collection-4326", "collection-3857"}

    # Add an item without EPSG; it gets its own id so that it does not
    # overwrite an existing item when the collection is saved again.
    collection.add_item(_make_test_stac_item("missing-epsg-item"))
    collection.normalize_and_save(collection_path)

    # Collection contains item without EPSG, so KeyError should be raised.
    # Only the call under test sits inside the raises block so that a
    # KeyError from the setup above cannot make the test pass by accident.
    with pytest.raises(KeyError):
        split_collection_by_epsg(path=collection_path, output_dir=output_dir)

0 comments on commit 11593f2

Please sign in to comment.