123 meteo stac #130
Merged · 5 commits · Jun 19, 2024
36 changes: 17 additions & 19 deletions src/openeo_gfmap/features/feature_extractor.py
@@ -14,7 +14,6 @@
import openeo
import xarray as xr
from openeo.udf import XarrayDataCube
from openeo.udf.run_code import execute_local_udf
from openeo.udf.udf_data import UdfData
from pyproj import Transformer
from pyproj.crs import CRS
@@ -33,12 +32,6 @@ class FeatureExtractor(ABC):
"""

def __init__(self) -> None:
self.logger = None

def _initialize_logger(self) -> None:
"""
Initializes the PrestoFeatureExtractor object, starting a logger.
"""
logging.basicConfig(level=logging.INFO)
self.logger = logging.getLogger(self.__class__.__name__)

@@ -86,7 +79,6 @@ def _common_preparations(
executed. This method should be called by the `_execute` method of the
feature extractor.
"""
self._initialize_logger()
self._epsg = parameters.pop(EPSG_HARMONIZED_NAME)
self._parameters = parameters
return inarr
@@ -371,21 +363,27 @@ def apply_feature_extractor_local(
except for the cube parameter, which expects an `xarray.DataArray` instead of
an `openeo.rest.datacube.DataCube` object.
"""
# Trying to get the local EPSG code
if EPSG_HARMONIZED_NAME not in parameters:
raise ValueError(
f"Please specify an EPSG code in the parameters with key: {EPSG_HARMONIZED_NAME} when "
f"running a Feature Extractor locally."
)

feature_extractor = feature_extractor_class()
feature_extractor._parameters = parameters
output_labels = feature_extractor.output_labels()
dependencies = feature_extractor.dependencies()

udf_code = _generate_udf_code(feature_extractor_class, dependencies)

udf = openeo.UDF(code=udf_code, context=parameters)
if len(dependencies) > 0:
feature_extractor.logger.warning(
"Running UDFs locally with pip dependencies is not supported yet, "
"dependencies will not be installed."
)

cube = XarrayDataCube(cube)

out_udf_data: UdfData = execute_local_udf(udf, cube, fmt="NetCDF")

output_cubes = out_udf_data.datacube_list

assert len(output_cubes) == 1, "UDF should have only a single output cube."

return output_cubes[0].get_array().assign_coords({"bands": output_labels})
return (
feature_extractor._execute(cube, parameters)
.get_array()
.assign_coords({"bands": output_labels})
)
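For reviewers who want to try the updated local helper: it now bypasses the UDF round-trip and calls `_execute` directly. A minimal sketch, assuming a concrete `FeatureExtractor` subclass and adjusting the import paths as needed (the cube dims and EPSG value are illustrative, not part of this diff):

```python
import numpy as np
import xarray as xr

from openeo_gfmap.features.feature_extractor import (  # assumed import path
    EPSG_HARMONIZED_NAME,
    apply_feature_extractor_local,
)
from my_features import MyFeatureExtractor  # hypothetical concrete subclass

# Small stand-in cube; the real dims/coords depend on the extractor.
cube = xr.DataArray(
    np.random.rand(2, 3, 16, 16).astype("float32"),
    dims=["bands", "t", "y", "x"],
)

# The harmonized EPSG key is mandatory for local runs (see the check above).
features = apply_feature_extractor_local(
    MyFeatureExtractor,
    cube,
    parameters={EPSG_HARMONIZED_NAME: 4326},
)
print(features.bands.values)  # labels come from MyFeatureExtractor.output_labels()
```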
62 changes: 49 additions & 13 deletions src/openeo_gfmap/fetching/commons.py
@@ -1,13 +1,16 @@
"""
"""
Common internal operations within collection extraction logic, such as reprojection.
"""

from typing import Optional, Union
from functools import partial
from typing import Dict, Optional, Sequence, Union

import openeo
from geojson import GeoJSON
from rasterio import CRS
from rasterio.errors import CRSError
from openeo.api.process import Parameter
from openeo.rest.connection import InputDate
from pyproj.crs import CRS
from pyproj.exceptions import CRSError

from openeo_gfmap.spatial import BoundingBoxExtent, SpatialContext
from openeo_gfmap.temporal import TemporalContext
@@ -76,19 +79,52 @@ def filter_condition(band_name, _):
)


def _load_collection_hybrid(
connection: openeo.Connection,
is_stac: bool,
collection_id_or_url: str,
bands: list,
spatial_extent: Union[Dict[str, float], Parameter, None] = None,
temporal_extent: Union[Sequence[InputDate], Parameter, str, None] = None,
properties: Optional[dict] = None,
):
"""Wrapper around the load_collection, or load_stac method of the openeo connection."""
if not is_stac:
return connection.load_collection(
collection_id=collection_id_or_url,
spatial_extent=spatial_extent,
temporal_extent=temporal_extent,
bands=bands,
properties=properties,
)
cube = connection.load_stac(
url=collection_id_or_url,
spatial_extent=spatial_extent,
temporal_extent=temporal_extent,
bands=bands,
properties=properties,
)
cube = cube.rename_labels(dimension="bands", target=bands)
return cube


def _load_collection(
connection: openeo.Connection,
bands: list,
collection_name: str,
spatial_extent: SpatialContext,
temporal_extent: Optional[TemporalContext],
fetch_type: FetchType,
is_stac: bool = False,
**params,
):
"""Loads a collection from the openeo backend, acting differently depending
on the fetch type.
"""
load_collection_parameters = params.get("load_collection", {})
load_collection_method = partial(
_load_collection_hybrid, is_stac=is_stac, collection_id_or_url=collection_name
)

if (
temporal_extent is not None
@@ -100,11 +136,11 @@ def _load_collection(
spatial_extent, BoundingBoxExtent
), "Please provide only a bounding box for tile based fetching."
spatial_extent = dict(spatial_extent)
cube = connection.load_collection(
collection_id=collection_name,
cube = load_collection_method(
connection=connection,
bands=bands,
spatial_extent=spatial_extent,
temporal_extent=temporal_extent,
bands=bands,
properties=load_collection_parameters,
)
elif fetch_type == FetchType.POINT:
@@ -114,11 +150,11 @@
assert (
spatial_extent["type"] == "FeatureCollection"
), "Please provide a FeatureCollection type of GeoJSON"
cube = connection.load_collection(
collection_id=collection_name,
cube = load_collection_method(
connection=connection,
bands=bands,
spatial_extent=spatial_extent,
temporal_extent=temporal_extent,
bands=bands,
properties=load_collection_parameters,
)
elif fetch_type == FetchType.POLYGON:
@@ -134,10 +170,10 @@
raise ValueError(
"Please provide a valid URL to a GeoParquet or GeoJSON file."
)
cube = connection.load_collection(
collection_id=collection_name,
temporal_extent=temporal_extent,
cube = load_collection_method(
connection=connection,
bands=bands,
temporal_extent=temporal_extent,
properties=load_collection_parameters,
)

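For context on the new `_load_collection_hybrid` wrapper: it dispatches to `load_collection` for regular backend collections and to `load_stac` (plus a band relabel) for STAC URLs, which is what lets `_load_collection` stay source-agnostic. A rough sketch of both paths, with a hypothetical authenticated connection and example extents:

```python
import openeo

from openeo_gfmap.fetching.commons import _load_collection_hybrid  # assumed path

connection = openeo.connect("openeo.vito.be").authenticate_oidc()

# Regular backend collection: dispatches to connection.load_collection(...)
s2 = _load_collection_hybrid(
    connection,
    is_stac=False,
    collection_id_or_url="SENTINEL2_L2A",
    bands=["B02", "B03", "B04"],
    spatial_extent={"west": 5.0, "south": 51.2, "east": 5.1, "north": 51.3},
    temporal_extent=["2022-01-01", "2022-12-31"],
)

# STAC collection: dispatches to connection.load_stac(...) and relabels the bands.
agera5 = _load_collection_hybrid(
    connection,
    is_stac=True,
    collection_id_or_url="https://stac.openeo.vito.be/collections/agera5_daily",
    bands=["2m_temperature_mean", "total_precipitation"],
    spatial_extent={"west": 5.0, "south": 51.2, "east": 5.1, "north": 51.3},
    temporal_extent=["2022-01-01", "2022-12-31"],
)
```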
126 changes: 126 additions & 0 deletions src/openeo_gfmap/fetching/meteo.py
@@ -0,0 +1,126 @@
"""Meteo data fetchers."""

from functools import partial

import openeo
from geojson import GeoJSON

from openeo_gfmap import (
Backend,
BackendContext,
FetchType,
SpatialContext,
TemporalContext,
)
from openeo_gfmap.fetching import CollectionFetcher
from openeo_gfmap.fetching.commons import convert_band_names
from openeo_gfmap.fetching.generic import (
_get_generic_fetcher,
_get_generic_processor,
_load_collection,
)

WEATHER_MAPPING_TERRASCOPE = {
"dewpoint-temperature": "AGERA5-DEWTEMP",
"precipitation-flux": "AGERA5-PRECIP",
"solar-radiation-flux": "AGERA5-SOLRAD",
"temperature-max": "AGERA5-TMAX",
"temperature-mean": "AGERA5-TMEAN",
"temperature-min": "AGERA5-TMIN",
"vapour-pressure": "AGERA5-VAPOUR",
"wind-speed": "AGERA5-WIND",
}

WEATHER_MAPPING_STAC = {
"dewpoint_temperature_mean": "AGERA5-DEWTEMP",
"total_precipitation": "AGERA5-PRECIP",
"solar_radiataion_flux": "AGERA5-SOLRAD",
"2m_temperature_max": "AGERA5-TMAX",
"2m_temperature_mean": "AGERA5-TMEAN",
"2m_temperature_min": "AGERA5-TMIN",
"vapour_pressure": "AGERA5-VAPOUR",
"wind_speed": "AGERA5-WIND",
}


def stac_fetcher(
connection: openeo.Connection,
spatial_extent: SpatialContext,
temporal_extent: TemporalContext,
bands: list,
fetch_type: FetchType,
**params,
) -> openeo.DataCube:
bands = convert_band_names(bands, WEATHER_MAPPING_STAC)

cube = _load_collection(
connection,
bands,
"https://stac.openeo.vito.be/collections/agera5_daily",
spatial_extent,
temporal_extent,
fetch_type,
is_stac=True,
**params,
)

if isinstance(spatial_extent, GeoJSON) and fetch_type == FetchType.POLYGON:
cube = cube.filter_spatial(spatial_extent)

return cube


METEO_BACKEND_MAP = {
Backend.TERRASCOPE: {
"fetch": partial(
_get_generic_fetcher,
collection_name="AGERA5",
band_mapping=WEATHER_MAPPING_TERRASCOPE,
),
"preprocessor": partial(
_get_generic_processor,
collection_name="AGERA5",
band_mapping=WEATHER_MAPPING_TERRASCOPE,
),
},
Backend.CDSE: {
"fetch": stac_fetcher,
"preprocessor": partial(
_get_generic_processor,
collection_name="AGERA5",
band_mapping=WEATHER_MAPPING_STAC,
),
},
Backend.CDSE_STAGING: {
"fetch": stac_fetcher,
"preprocessor": partial(
_get_generic_processor,
collection_name="AGERA5",
band_mapping=WEATHER_MAPPING_STAC,
),
},
Backend.FED: {
"fetch": stac_fetcher,
"preprocessor": partial(
_get_generic_processor,
collection_name="AGERA5",
band_mapping=WEATHER_MAPPING_STAC,
),
},
}


def build_meteo_extractor(
backend_context: BackendContext,
bands: list,
fetch_type: FetchType,
**params,
) -> CollectionFetcher:
backend_functions = METEO_BACKEND_MAP.get(backend_context.backend)

fetcher, preprocessor = (
partial(backend_functions["fetch"], fetch_type=fetch_type),
backend_functions["preprocessor"](fetch_type=fetch_type),
)

return CollectionFetcher(backend_context, bands, fetcher, preprocessor, **params)
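A hedged end-to-end sketch of the new meteo extractor; the `BoundingBoxExtent`/`TemporalContext` constructor arguments and the `get_cube` entry point are assumptions based on the rest of the codebase, not part of this diff:

```python
import openeo

from openeo_gfmap import Backend, BackendContext, FetchType, TemporalContext
from openeo_gfmap.fetching.meteo import build_meteo_extractor
from openeo_gfmap.spatial import BoundingBoxExtent

# On CDSE the fetcher goes through the AGERA5 STAC collection (stac_fetcher above).
extractor = build_meteo_extractor(
    backend_context=BackendContext(Backend.CDSE),
    bands=["AGERA5-TMEAN", "AGERA5-PRECIP"],
    fetch_type=FetchType.TILE,
)

connection = openeo.connect("openeo.dataspace.copernicus.eu").authenticate_oidc()
spatial = BoundingBoxExtent(west=5.0, south=51.2, east=5.1, north=51.3, epsg=4326)
temporal = TemporalContext("2022-06-01", "2022-08-31")

# Assuming the usual CollectionFetcher entry point; adapt if the method name differs.
meteo_cube = extractor.get_cube(connection, spatial, temporal)
```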
33 changes: 17 additions & 16 deletions src/openeo_gfmap/inference/model_inference.py
@@ -17,7 +17,6 @@
import xarray as xr
from openeo.udf import XarrayDataCube
from openeo.udf import inspect as udf_inspect
from openeo.udf.run_code import execute_local_udf
from openeo.udf.udf_data import UdfData

sys.path.insert(0, "onnx_deps")
@@ -32,9 +31,6 @@ class ModelInference(ABC):
"""

def __init__(self) -> None:
self.logger = None

def _initialize_logger(self) -> None:
"""
Initializes the PrestoFeatureExtractor object, starting a logger.
"""
@@ -115,7 +111,6 @@ def _common_preparations(
"""Common preparations for all inference models. This method will be
executed at the very beginning of the process.
"""
self._initialize_logger()
self._epsg = parameters.pop(EPSG_HARMONIZED_NAME)
self._parameters = parameters
return inarr
@@ -323,21 +318,27 @@ def apply_model_inference_local(
except for the cube parameter, which expects an `xarray.DataArray` instead of
an `openeo.rest.datacube.DataCube` object.
"""
# Trying to get the local EPSG code
if EPSG_HARMONIZED_NAME not in parameters:
raise ValueError(
f"Please specify an EPSG code in the parameters with key: {EPSG_HARMONIZED_NAME} when "
f"running a Model Inference locally."
)

model_inference = model_inference_class()
model_inference._parameters = parameters
output_labels = model_inference.output_labels()
dependencies = model_inference.dependencies()

udf_code = _generate_udf_code(model_inference_class, dependencies)

udf = openeo.UDF(code=udf_code, context=parameters)
if len(dependencies) > 0:
model_inference.logger.warning(
"Running UDFs locally with pip dependencies is not supported yet, "
"dependencies will not be installed."
)

cube = XarrayDataCube(cube)

out_udf_data: UdfData = execute_local_udf(udf, cube, fmt="NetCDF")

output_cubes = out_udf_data.datacube_list

assert len(output_cubes) == 1, "UDF should have only a single output cube."

return output_cubes[0].get_array().assign_coords({"bands": output_labels})
return (
model_inference._execute(cube, parameters)
.get_array()
.assign_coords({"bands": output_labels})
)
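The model-inference local helper mirrors the feature-extractor one above; a short sketch under the same assumptions (the `model_url` parameter name and import paths are hypothetical):

```python
from openeo_gfmap.inference.model_inference import (  # assumed import path
    EPSG_HARMONIZED_NAME,
    apply_model_inference_local,
)
from my_models import MyModelInference  # any concrete ModelInference subclass

# The EPSG key is mandatory here too; pip dependencies declared by the class are
# not installed in local runs, the helper only logs a warning.
predictions = apply_model_inference_local(
    MyModelInference,
    cube,  # an xarray.DataArray, e.g. the `features` result from the sketch above
    parameters={EPSG_HARMONIZED_NAME: 4326, "model_url": "https://example.com/model.onnx"},
)
```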