diff --git a/tests/models/test_api.py b/tests/models/test_api.py index 6c0bc6790f1..c5a56d1808a 100644 --- a/tests/models/test_api.py +++ b/tests/models/test_api.py @@ -80,3 +80,8 @@ def test_get_weight(enum: WeightsEnum) -> None: def test_list_models() -> None: models = [builder.__name__ for builder in builders] assert set(models) == set(list_models()) + + +def test_invalid_model() -> None: + with pytest.raises(ValueError, match='bad_model is not a valid WeightsEnum'): + get_weight('bad_model') diff --git a/torchgeo/models/api.py b/torchgeo/models/api.py index 6e06db82a71..b5b058726b2 100644 --- a/torchgeo/models/api.py +++ b/torchgeo/models/api.py @@ -46,7 +46,7 @@ 'vit_small_patch16_224': vit_small_patch16_224, } -_model_weights = { +_model_weights: dict[str | Callable[..., nn.Module], WeightsEnum] = { dofa_base_patch16_224: DOFABase16_Weights, dofa_large_patch16_224: DOFALarge16_Weights, resnet18: ResNet18_Weights, @@ -109,8 +109,17 @@ def get_weight(name: str) -> WeightsEnum: Returns: The requested weight enum. + + Raises: + ValueError: If *name* is not a valid WeightsEnum. """ - return eval(name) + for weight_name, weight_enum in _model_weights.items(): + if isinstance(weight_name, str): + for sub_weight_enum in weight_enum: + if name == str(sub_weight_enum): + return sub_weight_enum + + raise ValueError(f'{name} is not a valid WeightsEnum') def list_models() -> list[str]: