Merge pull request #130 from Open-EO/123-meteo-stac
123 meteo stac
GriffinBabe authored Jun 19, 2024
2 parents eaac70e + dd6c557 commit 5ceb21c
Showing 6 changed files with 241 additions and 49 deletions.
36 changes: 17 additions & 19 deletions src/openeo_gfmap/features/feature_extractor.py
@@ -14,7 +14,6 @@
import openeo
import xarray as xr
from openeo.udf import XarrayDataCube
from openeo.udf.run_code import execute_local_udf
from openeo.udf.udf_data import UdfData
from pyproj import Transformer
from pyproj.crs import CRS
@@ -33,12 +32,6 @@ class FeatureExtractor(ABC):
"""

def __init__(self) -> None:
self.logger = None

def _initialize_logger(self) -> None:
"""
Initializes the PrestoFeatureExtractor object, starting a logger.
"""
logging.basicConfig(level=logging.INFO)
self.logger = logging.getLogger(self.__class__.__name__)

@@ -86,7 +79,6 @@ def _common_preparations(
executed. This method should be called by the `_execute` method of the
feature extractor.
"""
self._initialize_logger()
self._epsg = parameters.pop(EPSG_HARMONIZED_NAME)
self._parameters = parameters
return inarr
@@ -371,21 +363,27 @@ def apply_feature_extractor_local(
except for the cube parameter, which expects an `xarray.DataArray` instead of
an `openeo.rest.datacube.DataCube` object.
"""
# Trying to get the local EPSG code
if EPSG_HARMONIZED_NAME not in parameters:
raise ValueError(
f"Please specify an EPSG code in the parameters with key: {EPSG_HARMONIZED_NAME} when "
f"running a Feature Extractor locally."
)

feature_extractor = feature_extractor_class()
feature_extractor._parameters = parameters
output_labels = feature_extractor.output_labels()
dependencies = feature_extractor.dependencies()

udf_code = _generate_udf_code(feature_extractor_class, dependencies)

udf = openeo.UDF(code=udf_code, context=parameters)
if len(dependencies) > 0:
feature_extractor.logger.warning(
"Running UDFs locally with pip dependencies is not supported yet, "
"dependencies will not be installed."
)

cube = XarrayDataCube(cube)

out_udf_data: UdfData = execute_local_udf(udf, cube, fmt="NetCDF")

output_cubes = out_udf_data.datacube_list

assert len(output_cubes) == 1, "UDF should have only a single output cube."

return output_cubes[0].get_array().assign_coords({"bands": output_labels})
return (
feature_extractor._execute(cube, parameters)
.get_array()
.assign_coords({"bands": output_labels})
)
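
For context, the reworked local path above no longer round-trips through execute_local_udf: it instantiates the extractor and calls `_execute` on an in-memory `XarrayDataCube` directly. Below is a minimal sketch of how this could be exercised, assuming the `apply_feature_extractor_local(feature_extractor_class, cube, parameters)` signature implied by the body above; the `MeanFeatureExtractor` class, the band names and the EPSG value are illustrative only, and a real extractor may need to implement further abstract hooks not visible in this diff.

import numpy as np
import xarray as xr
from openeo.udf import XarrayDataCube

from openeo_gfmap.features.feature_extractor import (
    EPSG_HARMONIZED_NAME,
    FeatureExtractor,
    apply_feature_extractor_local,
)


class MeanFeatureExtractor(FeatureExtractor):
    """Illustrative extractor: temporal mean of every input band."""

    def output_labels(self) -> list:
        return ["S2-B04-mean", "S2-B08-mean"]

    def dependencies(self) -> list:
        return []  # no pip dependencies, so the local-run warning is not triggered

    def execute(self, inarr: xr.DataArray) -> xr.DataArray:
        return inarr.mean(dim="t")

    def _execute(self, cube: XarrayDataCube, parameters: dict) -> XarrayDataCube:
        inarr = self._common_preparations(cube.get_array(), parameters)
        return XarrayDataCube(self.execute(inarr))


# A small dummy patch: 2 bands, 4 timesteps, 16x16 pixels.
array = xr.DataArray(
    np.random.rand(2, 4, 16, 16).astype("float32"),
    dims=["bands", "t", "y", "x"],
    coords={"bands": ["S2-B04", "S2-B08"]},
)

features = apply_feature_extractor_local(
    MeanFeatureExtractor,
    array,
    parameters={EPSG_HARMONIZED_NAME: 32631},  # the EPSG key is now mandatory locally
)

As the warning added above notes, any pip dependencies declared by the extractor are simply ignored in this local mode.
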
62 changes: 49 additions & 13 deletions src/openeo_gfmap/fetching/commons.py
@@ -1,13 +1,16 @@
"""
"""
Common internal operations within collection extraction logic, such as reprojection.
"""

from typing import Optional, Union
from functools import partial
from typing import Dict, Optional, Sequence, Union

import openeo
from geojson import GeoJSON
from rasterio import CRS
from rasterio.errors import CRSError
from openeo.api.process import Parameter
from openeo.rest.connection import InputDate
from pyproj.crs import CRS
from pyproj.exceptions import CRSError

from openeo_gfmap.spatial import BoundingBoxExtent, SpatialContext
from openeo_gfmap.temporal import TemporalContext
@@ -76,19 +79,52 @@ def filter_condition(band_name, _):
)


def _load_collection_hybrid(
connection: openeo.Connection,
is_stac: bool,
collection_id_or_url: str,
bands: list,
spatial_extent: Union[Dict[str, float], Parameter, None] = None,
temporal_extent: Union[Sequence[InputDate], Parameter, str, None] = None,
properties: Optional[dict] = None,
):
"""Wrapper around the load_collection, or load_stac method of the openeo connection."""
if not is_stac:
return connection.load_collection(
collection_id=collection_id_or_url,
spatial_extent=spatial_extent,
temporal_extent=temporal_extent,
bands=bands,
properties=properties,
)
cube = connection.load_stac(
url=collection_id_or_url,
spatial_extent=spatial_extent,
temporal_extent=temporal_extent,
bands=bands,
properties=properties,
)
cube = cube.rename_labels(dimension="bands", target=bands)
return cube


def _load_collection(
connection: openeo.Connection,
bands: list,
collection_name: str,
spatial_extent: SpatialContext,
temporal_extent: Optional[TemporalContext],
fetch_type: FetchType,
is_stac: bool = False,
**params,
):
"""Loads a collection from the openeo backend, acting differently depending
on the fetch type.
"""
load_collection_parameters = params.get("load_collection", {})
load_collection_method = partial(
_load_collection_hybrid, is_stac=is_stac, collection_id_or_url=collection_name
)

if (
temporal_extent is not None
@@ -100,11 +136,11 @@ def _load_collection(
spatial_extent, BoundingBoxExtent
), "Please provide only a bounding box for tile based fetching."
spatial_extent = dict(spatial_extent)
cube = connection.load_collection(
collection_id=collection_name,
cube = load_collection_method(
connection=connection,
bands=bands,
spatial_extent=spatial_extent,
temporal_extent=temporal_extent,
bands=bands,
properties=load_collection_parameters,
)
elif fetch_type == FetchType.POINT:
@@ -114,11 +150,11 @@ def _load_collection(
assert (
spatial_extent["type"] == "FeatureCollection"
), "Please provide a FeatureCollection type of GeoJSON"
cube = connection.load_collection(
collection_id=collection_name,
cube = load_collection_method(
connection=connection,
bands=bands,
spatial_extent=spatial_extent,
temporal_extent=temporal_extent,
bands=bands,
properties=load_collection_parameters,
)
elif fetch_type == FetchType.POLYGON:
@@ -134,10 +170,10 @@ def _load_collection(
raise ValueError(
"Please provide a valid URL to a GeoParquet or GeoJSON file."
)
cube = connection.load_collection(
collection_id=collection_name,
temporal_extent=temporal_extent,
cube = load_collection_method(
connection=connection,
bands=bands,
temporal_extent=temporal_extent,
properties=load_collection_parameters,
)

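To make the new switch concrete: with `is_stac=True`, `_load_collection` now routes through `connection.load_stac` followed by `rename_labels`, instead of `connection.load_collection`. A rough sketch of a tile-based call is shown below; `_load_collection` is internal API, and the connection URL, the `BoundingBoxExtent(west, south, east, north, epsg)` and `TemporalContext(start_date, end_date)` constructors, and all extent and date values are assumptions for illustration.

import openeo

from openeo_gfmap import FetchType
from openeo_gfmap.fetching.commons import _load_collection
from openeo_gfmap.spatial import BoundingBoxExtent
from openeo_gfmap.temporal import TemporalContext

connection = openeo.connect("openeo.dataspace.copernicus.eu").authenticate_oidc()

cube = _load_collection(
    connection=connection,
    bands=["dewpoint_temperature_mean", "2m_temperature_mean"],
    collection_name="https://stac.openeo.vito.be/collections/agera5_daily",
    spatial_extent=BoundingBoxExtent(west=5.05, south=51.21, east=5.10, north=51.25, epsg=4326),
    temporal_extent=TemporalContext(start_date="2020-01-01", end_date="2020-03-31"),
    fetch_type=FetchType.TILE,
    is_stac=True,  # dispatches to connection.load_stac + rename_labels
)

The trailing `rename_labels` in `_load_collection_hybrid` re-applies the requested band labels to the STAC-loaded cube.
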
126 changes: 126 additions & 0 deletions src/openeo_gfmap/fetching/meteo.py
@@ -0,0 +1,126 @@
"""Meteo data fetchers."""

from functools import partial

import openeo
from geojson import GeoJSON

from openeo_gfmap import (
    Backend,
    BackendContext,
    FetchType,
    SpatialContext,
    TemporalContext,
)
from openeo_gfmap.fetching import CollectionFetcher
from openeo_gfmap.fetching.commons import convert_band_names
from openeo_gfmap.fetching.generic import (
    _get_generic_fetcher,
    _get_generic_processor,
    _load_collection,
)

WEATHER_MAPPING_TERRASCOPE = {
    "dewpoint-temperature": "AGERA5-DEWTEMP",
    "precipitation-flux": "AGERA5-PRECIP",
    "solar-radiation-flux": "AGERA5-SOLRAD",
    "temperature-max": "AGERA5-TMAX",
    "temperature-mean": "AGERA5-TMEAN",
    "temperature-min": "AGERA5-TMIN",
    "vapour-pressure": "AGERA5-VAPOUR",
    "wind-speed": "AGERA5-WIND",
}

WEATHER_MAPPING_STAC = {
    "dewpoint_temperature_mean": "AGERA5-DEWTEMP",
    "total_precipitation": "AGERA5-PRECIP",
    "solar_radiataion_flux": "AGERA5-SOLRAD",
    "2m_temperature_max": "AGERA5-TMAX",
    "2m_temperature_mean": "AGERA5-TMEAN",
    "2m_temperature_min": "AGERA5-TMIN",
    "vapour_pressure": "AGERA5-VAPOUR",
    "wind_speed": "AGERA5-WIND",
}


def stac_fetcher(
    connection: openeo.Connection,
    spatial_extent: SpatialContext,
    temporal_extent: TemporalContext,
    bands: list,
    fetch_type: FetchType,
    **params,
) -> openeo.DataCube:
    bands = convert_band_names(bands, WEATHER_MAPPING_STAC)

    cube = _load_collection(
        connection,
        bands,
        "https://stac.openeo.vito.be/collections/agera5_daily",
        spatial_extent,
        temporal_extent,
        fetch_type,
        is_stac=True,
        **params,
    )

    if isinstance(spatial_extent, GeoJSON) and fetch_type == FetchType.POLYGON:
        cube = cube.filter_spatial(spatial_extent)

    return cube


METEO_BACKEND_MAP = {
    Backend.TERRASCOPE: {
        "fetch": partial(
            _get_generic_fetcher,
            collection_name="AGERA5",
            band_mapping=WEATHER_MAPPING_TERRASCOPE,
        ),
        "preprocessor": partial(
            _get_generic_processor,
            collection_name="AGERA5",
            band_mapping=WEATHER_MAPPING_TERRASCOPE,
        ),
    },
    Backend.CDSE: {
        "fetch": stac_fetcher,
        "preprocessor": partial(
            _get_generic_processor,
            collection_name="AGERA5",
            band_mapping=WEATHER_MAPPING_STAC,
        ),
    },
    Backend.CDSE_STAGING: {
        "fetch": stac_fetcher,
        "preprocessor": partial(
            _get_generic_processor,
            collection_name="AGERA5",
            band_mapping=WEATHER_MAPPING_STAC,
        ),
    },
    Backend.FED: {
        "fetch": stac_fetcher,
        "preprocessor": partial(
            _get_generic_processor,
            collection_name="AGERA5",
            band_mapping=WEATHER_MAPPING_STAC,
        ),
    },
}


def build_meteo_extractor(
    backend_context: BackendContext,
    bands: list,
    fetch_type: FetchType,
    **params,
) -> CollectionFetcher:
    backend_functions = METEO_BACKEND_MAP.get(backend_context.backend)

    fetcher, preprocessor = (
        partial(backend_functions["fetch"], fetch_type=fetch_type),
        backend_functions["preprocessor"](fetch_type=fetch_type),
    )

    return CollectionFetcher(backend_context, bands, fetcher, preprocessor, **params)
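
A usage sketch for the new builder; the connection URL, the extent and date values, and the `CollectionFetcher.get_cube(connection, spatial_extent, temporal_extent)` call are assumptions for illustration rather than part of this commit.

import openeo

from openeo_gfmap import Backend, BackendContext, FetchType, TemporalContext
from openeo_gfmap.fetching.meteo import build_meteo_extractor
from openeo_gfmap.spatial import BoundingBoxExtent

connection = openeo.connect("openeo.dataspace.copernicus.eu").authenticate_oidc()

# On CDSE (and CDSE_STAGING / FED) the fetcher resolves to stac_fetcher and reads the
# agera5_daily STAC collection; on Terrascope it uses the generic AGERA5 collection.
extractor = build_meteo_extractor(
    backend_context=BackendContext(Backend.CDSE),
    bands=["AGERA5-TMEAN", "AGERA5-PRECIP"],
    fetch_type=FetchType.TILE,
)

cube = extractor.get_cube(
    connection,
    BoundingBoxExtent(west=5.05, south=51.21, east=5.10, north=51.25, epsg=4326),
    TemporalContext(start_date="2020-01-01", end_date="2020-12-31"),
)
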
33 changes: 17 additions & 16 deletions src/openeo_gfmap/inference/model_inference.py
@@ -17,7 +17,6 @@
import xarray as xr
from openeo.udf import XarrayDataCube
from openeo.udf import inspect as udf_inspect
from openeo.udf.run_code import execute_local_udf
from openeo.udf.udf_data import UdfData

sys.path.insert(0, "onnx_deps")
@@ -32,9 +31,6 @@ class ModelInference(ABC):
"""

def __init__(self) -> None:
self.logger = None

def _initialize_logger(self) -> None:
"""
Initializes the PrestoFeatureExtractor object, starting a logger.
"""
@@ -115,7 +111,6 @@ def _common_preparations(
"""Common preparations for all inference models. This method will be
executed at the very beginning of the process.
"""
self._initialize_logger()
self._epsg = parameters.pop(EPSG_HARMONIZED_NAME)
self._parameters = parameters
return inarr
@@ -323,21 +318,27 @@ def apply_model_inference_local(
except for the cube parameter, which expects an `xarray.DataArray` instead of
an `openeo.rest.datacube.DataCube` object.
"""
# Trying to get the local EPSG code
if EPSG_HARMONIZED_NAME not in parameters:
raise ValueError(
f"Please specify an EPSG code in the parameters with key: {EPSG_HARMONIZED_NAME} when "
f"running a Model Inference locally."
)

model_inference = model_inference_class()
model_inference._parameters = parameters
output_labels = model_inference.output_labels()
dependencies = model_inference.dependencies()

udf_code = _generate_udf_code(model_inference_class, dependencies)

udf = openeo.UDF(code=udf_code, context=parameters)
if len(dependencies) > 0:
model_inference.logger.warning(
"Running UDFs locally with pip dependencies is not supported yet, "
"dependencies will not be installed."
)

cube = XarrayDataCube(cube)

out_udf_data: UdfData = execute_local_udf(udf, cube, fmt="NetCDF")

output_cubes = out_udf_data.datacube_list

assert len(output_cubes) == 1, "UDF should have only a single output cube."

return output_cubes[0].get_array().assign_coords({"bands": output_labels})
return (
model_inference._execute(cube, parameters)
.get_array()
.assign_coords({"bands": output_labels})
)
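
The local model-inference path mirrors the feature-extractor change above. A compact sketch under the same caveats: the `ThresholdInference` class is hypothetical, the `apply_model_inference_local(model_inference_class, cube, parameters)` signature is inferred from the body above, and the module's own imports (such as onnxruntime) are assumed to resolve in your environment.

import numpy as np
import xarray as xr
from openeo.udf import XarrayDataCube

from openeo_gfmap.inference.model_inference import (
    EPSG_HARMONIZED_NAME,
    ModelInference,
    apply_model_inference_local,
)


class ThresholdInference(ModelInference):
    """Illustrative model: binary classification by thresholding a probability band."""

    def output_labels(self) -> list:
        return ["classification"]

    def dependencies(self) -> list:
        return []

    def execute(self, inarr: xr.DataArray) -> xr.DataArray:
        classes = (inarr.sel(bands="probability") > 0.5).astype("uint8")
        return classes.expand_dims(dim="bands")

    def _execute(self, cube: XarrayDataCube, parameters: dict) -> XarrayDataCube:
        inarr = self._common_preparations(cube.get_array(), parameters)
        return XarrayDataCube(self.execute(inarr))


array = xr.DataArray(
    np.random.rand(1, 16, 16).astype("float32"),
    dims=["bands", "y", "x"],
    coords={"bands": ["probability"]},
)

result = apply_model_inference_local(
    ThresholdInference,
    array,
    parameters={EPSG_HARMONIZED_NAME: 32631},  # the EPSG key is now mandatory locally
)
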
(The remaining two changed files in this commit are not rendered in this view.)