diff --git a/ohsome_quality_api/api/api.py b/ohsome_quality_api/api/api.py index f19b5c346..2f5494a4d 100644 --- a/ohsome_quality_api/api/api.py +++ b/ohsome_quality_api/api/api.py @@ -275,7 +275,8 @@ async def post_attribute_completeness( if isinstance(parameters, AttributeCompletenessKeyRequest): for attribute in parameters.attribute_keys: validate_attribute_topic_combination( - attribute.value, parameters.topic_key.value + attribute, + parameters.topic, ) return await _post_indicator(request, "attribute-completeness", parameters) @@ -308,23 +309,12 @@ async def post_indicator( async def _post_indicator( - request: Request, key: str, parameters: IndicatorRequest + request: Request, + key: str, + parameters: IndicatorRequest, ) -> Any: - validate_indicator_topic_combination(key, parameters.topic_key.value) - attribute_keys = getattr(parameters, "attribute_keys", None) - if attribute_keys: - attribute_keys = [attribute.value for attribute in attribute_keys] - attribute_filter = getattr(parameters, "attribute_filter", None) - attribute_names = getattr(parameters, "attribute_names", None) - indicators = await oqt.create_indicator( - key=key, - bpolys=parameters.bpolys, - topic=get_topic_preset(parameters.topic_key.value), - include_figure=parameters.include_figure, - attribute_keys=attribute_keys, - attribute_filter=attribute_filter, - attribute_names=attribute_names, - ) + validate_indicator_topic_combination(key, parameters.topic) + indicators = await oqt.create_indicator(key=key, **dict(parameters)) if request.headers["accept"] == MEDIA_TYPE_JSON: return { @@ -345,10 +335,12 @@ async def _post_indicator( } else: detail = "Content-Type needs to be either {0} or {1}".format( - MEDIA_TYPE_JSON, MEDIA_TYPE_GEOJSON + MEDIA_TYPE_JSON, + MEDIA_TYPE_GEOJSON, ) raise HTTPException( - status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE, detail=detail + status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE, + detail=detail, ) diff --git a/ohsome_quality_api/api/request_models.py b/ohsome_quality_api/api/request_models.py index 6126f7c03..c353038ca 100644 --- a/ohsome_quality_api/api/request_models.py +++ b/ohsome_quality_api/api/request_models.py @@ -2,11 +2,16 @@ import geojson from geojson_pydantic import Feature, FeatureCollection, MultiPolygon, Polygon -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + field_validator, +) from ohsome_quality_api.attributes.definitions import AttributeEnum -from ohsome_quality_api.topics.definitions import TopicEnum -from ohsome_quality_api.topics.models import TopicData +from ohsome_quality_api.topics.definitions import TopicEnum, get_topic_preset +from ohsome_quality_api.topics.models import TopicData, TopicDefinition from ohsome_quality_api.utils.helper import snake_to_lower_camel @@ -49,7 +54,7 @@ class BaseBpolys(BaseConfig): @field_validator("bpolys") @classmethod - def transform(cls, value) -> geojson.FeatureCollection: + def transform_bpolys(cls, value): # NOTE: `geojson_pydantic` library is used only for validation and openAPI-spec # generation. To avoid refactoring all code the FeatureCollection object of # the `geojson` library is still used every else. @@ -57,13 +62,18 @@ def transform(cls, value) -> geojson.FeatureCollection: class IndicatorRequest(BaseBpolys): - topic_key: TopicEnum = Field( + topic: TopicEnum = Field( ..., title="Topic Key", alias="topic", ) include_figure: bool = True + @field_validator("topic") + @classmethod + def transform_topic(cls, value) -> TopicDefinition: + return get_topic_preset(value.value) + class AttributeCompletenessKeyRequest(IndicatorRequest): attribute_keys: List[AttributeEnum] = Field( @@ -72,6 +82,11 @@ class AttributeCompletenessKeyRequest(IndicatorRequest): alias="attributes", ) + @field_validator("attribute_keys") + @classmethod + def transform_attributes(cls, value) -> list[str]: + return [attribute.value for attribute in value] + class AttributeCompletenessFilterRequest(IndicatorRequest): attribute_filter: str = Field( diff --git a/ohsome_quality_api/oqt.py b/ohsome_quality_api/oqt.py index 56658bad6..88b62854b 100644 --- a/ohsome_quality_api/oqt.py +++ b/ohsome_quality_api/oqt.py @@ -1,7 +1,7 @@ """Controller for computing Indicators.""" import logging -from typing import Coroutine, List +from typing import Coroutine from geojson import Feature, FeatureCollection @@ -18,9 +18,8 @@ async def create_indicator( bpolys: FeatureCollection, topic: TopicData | TopicDefinition, include_figure: bool = True, - attribute_keys: List[str] | None = None, - attribute_filter: str | None = None, - attribute_names: List[str] | None = None, + *args, + **kwargs, ) -> list[Indicator]: """Create indicator(s) for features of a GeoJSON FeatureCollection. @@ -47,9 +46,8 @@ async def create_indicator( feature, topic, include_figure, - attribute_keys, - attribute_filter, - attribute_names, + *args, + **kwargs, ) ) return await gather_with_semaphore(tasks) @@ -60,9 +58,8 @@ async def _create_indicator( feature: Feature, topic: Topic, include_figure: bool = True, - attribute_keys: List[str] | None = None, - attribute_filter: str | None = None, - attribute_names: List[str] | None = None, + *args, + **kwargs, ) -> Indicator: """Create an indicator from scratch.""" @@ -71,19 +68,16 @@ async def _create_indicator( logging.info("Feature id: {0:4}".format(feature.get("id", "None"))) indicator_class = get_class_from_key(class_type="indicator", key=key) - if key == "attribute-completeness": - indicator = indicator_class( - topic, - feature, - attribute_keys, - attribute_filter, - attribute_names, - ) - else: - indicator = indicator_class(topic, feature) + indicator = indicator_class( + topic, + feature, + *args, + **kwargs, + ) logging.info("Run preprocessing") await indicator.preprocess() + logging.info("Run calculation") indicator.calculate() diff --git a/ohsome_quality_api/utils/validators.py b/ohsome_quality_api/utils/validators.py index 7aeaef9a0..3e41f7b4b 100644 --- a/ohsome_quality_api/utils/validators.py +++ b/ohsome_quality_api/utils/validators.py @@ -7,7 +7,7 @@ ) from ohsome_quality_api.config import get_config_value from ohsome_quality_api.indicators.definitions import get_valid_indicators -from ohsome_quality_api.topics.definitions import TopicEnum +from ohsome_quality_api.topics.models import BaseTopic from ohsome_quality_api.utils.exceptions import ( AttributeTopicCombinationError, GeoJSONError, @@ -20,19 +20,19 @@ from ohsome_quality_api.utils.helper_geo import calculate_area -def validate_attribute_topic_combination(attribute: AttributeEnum, topic: TopicEnum): +def validate_attribute_topic_combination(attribute: AttributeEnum, topic: BaseTopic): """As attributes are only meaningful for a certain topic, we need to check if the given combination is valid.""" - valid_attributes_for_topic = get_attributes()[topic] + valid_attributes_for_topic = get_attributes()[topic.key] valid_attribute_names = [attribute for attribute in valid_attributes_for_topic] if attribute not in valid_attributes_for_topic: raise AttributeTopicCombinationError(attribute, topic, valid_attribute_names) -def validate_indicator_topic_combination(indicator: str, topic: str): - if indicator not in get_valid_indicators(topic): +def validate_indicator_topic_combination(indicator: str, topic: BaseTopic): + if indicator not in get_valid_indicators(topic.key): raise IndicatorTopicCombinationError(indicator, topic) diff --git a/tests/unittests/test_validators.py b/tests/unittests/test_validators.py index 3ea3f9083..6a2a1be28 100644 --- a/tests/unittests/test_validators.py +++ b/tests/unittests/test_validators.py @@ -20,20 +20,17 @@ from tests.unittests.utils import get_geojson_fixture -def test_validate_attribute_topic_combination_with_valid_combination(): - validate_attribute_topic_combination("maxspeed", "roads") - - -def test_validate_attribute_topic_combination_with_invalid_topic(): - """As the method under test requires individually valid arguments - the arguments given lead to a KeyError.""" - with pytest.raises(KeyError): - validate_attribute_topic_combination("maxspeed", "xxxxx") +def test_validate_attribute_topic_combination_with_valid_combination( + topic_major_roads_length, +): + validate_attribute_topic_combination("maxspeed", topic_major_roads_length) -def test_validate_attribute_topic_combination_with_invalid_combination(): +def test_validate_attribute_topic_combination_with_invalid_combination( + topic_building_count, +): with pytest.raises(AttributeTopicCombinationError): - validate_attribute_topic_combination("maxspeed", "building-count") + validate_attribute_topic_combination("maxspeed", topic_building_count) def test_validate_geojson_feature_collection_single( @@ -77,13 +74,13 @@ def test_validate_geojson_unsupported_geometry_type( validate_geojson(feature_collection_unsupported_geometry_type) -def test_validate_indicator_topic_combination(): - validate_indicator_topic_combination("minimal", "minimal") +def test_validate_indicator_topic_combination(topic_minimal): + validate_indicator_topic_combination("minimal", topic_minimal) -def test_validate_indicator_topic_combination_invalid(): +def test_validate_indicator_topic_combination_invalid(topic_building_count): with pytest.raises(IndicatorTopicCombinationError): - validate_indicator_topic_combination("minimal", "building-count") + validate_indicator_topic_combination("minimal", topic_building_count) @mock.patch.dict(