From 66abac168ba7f0044934f51d4b744c6148a6b6b9 Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Thu, 26 Sep 2024 17:53:07 -0700 Subject: [PATCH 1/7] Removing eval --- torchgeo/models/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/models/api.py b/torchgeo/models/api.py index 6e06db82a71..f89c27b51ea 100644 --- a/torchgeo/models/api.py +++ b/torchgeo/models/api.py @@ -110,7 +110,7 @@ def get_weight(name: str) -> WeightsEnum: Returns: The requested weight enum. """ - return eval(name) + return _model_weights[name] def list_models() -> list[str]: From 109f94cdd3f881c4172c40cb9ecbc51cd78331d2 Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Thu, 26 Sep 2024 18:20:43 -0700 Subject: [PATCH 2/7] Extend the model_weights dict with sub weights --- torchgeo/models/api.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torchgeo/models/api.py b/torchgeo/models/api.py index f89c27b51ea..dc8038a680c 100644 --- a/torchgeo/models/api.py +++ b/torchgeo/models/api.py @@ -67,6 +67,11 @@ 'vit_small_patch16_224': ViTSmall16_Weights, } +for name, weight_enum in _model_weights.items(): + if isinstance(name, str): + for sub_weight_enum in weight_enum: + _model_weights[str(sub_weight_enum)] = sub_weight_enum + def get_model(name: str, *args: Any, **kwargs: Any) -> nn.Module: """Get an instantiated model from its name. From 4e268fceee72dd400af4087e1b413ca629df723f Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Thu, 26 Sep 2024 18:26:40 -0700 Subject: [PATCH 3/7] Just search through the sub weights enum --- torchgeo/models/api.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/torchgeo/models/api.py b/torchgeo/models/api.py index dc8038a680c..4fb47f2aadc 100644 --- a/torchgeo/models/api.py +++ b/torchgeo/models/api.py @@ -67,11 +67,6 @@ 'vit_small_patch16_224': ViTSmall16_Weights, } -for name, weight_enum in _model_weights.items(): - if isinstance(name, str): - for sub_weight_enum in weight_enum: - _model_weights[str(sub_weight_enum)] = sub_weight_enum - def get_model(name: str, *args: Any, **kwargs: Any) -> nn.Module: """Get an instantiated model from its name. @@ -114,8 +109,21 @@ def get_weight(name: str) -> WeightsEnum: Returns: The requested weight enum. + + Raises: + ValueError: if `name` doesn't point to a valid WeightsEnum """ - return _model_weights[name] + if name in _model_weights: + return _model_weights[name] + else: + sub_weights = {} + for name, weight_enum in _model_weights.items(): + if isinstance(name, str): + for sub_weight_enum in weight_enum: + if name == str(sub_weight_enum): + return sub_weight_enum + + raise ValueError(f"{name} isn't a valid WeightsEnum") def list_models() -> list[str]: From 63d346bb1850ecf1b8de8d3b4e487d4fa0d02f5b Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Fri, 27 Sep 2024 01:28:16 +0000 Subject: [PATCH 4/7] Ruff --- torchgeo/models/api.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchgeo/models/api.py b/torchgeo/models/api.py index 4fb47f2aadc..c79bf09c52e 100644 --- a/torchgeo/models/api.py +++ b/torchgeo/models/api.py @@ -111,12 +111,11 @@ def get_weight(name: str) -> WeightsEnum: The requested weight enum. Raises: - ValueError: if `name` doesn't point to a valid WeightsEnum + ValueError: if `name` doesn't point to a valid WeightsEnum """ if name in _model_weights: return _model_weights[name] else: - sub_weights = {} for name, weight_enum in _model_weights.items(): if isinstance(name, str): for sub_weight_enum in weight_enum: From cf57e5a2580ae3bb69254aab7c145ace5fb93ef3 Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Fri, 27 Sep 2024 17:03:29 +0000 Subject: [PATCH 5/7] Fix bug and mypy --- torchgeo/models/api.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchgeo/models/api.py b/torchgeo/models/api.py index c79bf09c52e..a33ce5c2970 100644 --- a/torchgeo/models/api.py +++ b/torchgeo/models/api.py @@ -46,7 +46,7 @@ 'vit_small_patch16_224': vit_small_patch16_224, } -_model_weights = { +_model_weights: dict[str | Callable[..., nn.Module], WeightsEnum] = { dofa_base_patch16_224: DOFABase16_Weights, dofa_large_patch16_224: DOFALarge16_Weights, resnet18: ResNet18_Weights, @@ -116,8 +116,8 @@ def get_weight(name: str) -> WeightsEnum: if name in _model_weights: return _model_weights[name] else: - for name, weight_enum in _model_weights.items(): - if isinstance(name, str): + for weight_name, weight_enum in _model_weights.items(): + if isinstance(weight_name, str): for sub_weight_enum in weight_enum: if name == str(sub_weight_enum): return sub_weight_enum From 9c4d54a1465bcf65428073470cac52a022dc2760 Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Fri, 27 Sep 2024 17:36:20 +0000 Subject: [PATCH 6/7] Test coverage --- tests/models/test_api.py | 5 +++++ torchgeo/models/api.py | 17 +++++++---------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/tests/models/test_api.py b/tests/models/test_api.py index 6c0bc6790f1..c5a56d1808a 100644 --- a/tests/models/test_api.py +++ b/tests/models/test_api.py @@ -80,3 +80,8 @@ def test_get_weight(enum: WeightsEnum) -> None: def test_list_models() -> None: models = [builder.__name__ for builder in builders] assert set(models) == set(list_models()) + + +def test_invalid_model() -> None: + with pytest.raises(ValueError, match='bad_model is not a valid WeightsEnum'): + get_weight('bad_model') diff --git a/torchgeo/models/api.py b/torchgeo/models/api.py index a33ce5c2970..a420d397bac 100644 --- a/torchgeo/models/api.py +++ b/torchgeo/models/api.py @@ -113,16 +113,13 @@ def get_weight(name: str) -> WeightsEnum: Raises: ValueError: if `name` doesn't point to a valid WeightsEnum """ - if name in _model_weights: - return _model_weights[name] - else: - for weight_name, weight_enum in _model_weights.items(): - if isinstance(weight_name, str): - for sub_weight_enum in weight_enum: - if name == str(sub_weight_enum): - return sub_weight_enum - - raise ValueError(f"{name} isn't a valid WeightsEnum") + for weight_name, weight_enum in _model_weights.items(): + if isinstance(weight_name, str): + for sub_weight_enum in weight_enum: + if name == str(sub_weight_enum): + return sub_weight_enum + + raise ValueError(f'{name} is not a valid WeightsEnum') def list_models() -> list[str]: From f69b0f5cedf75a6c2826b0f893bbf888ff425991 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sat, 28 Sep 2024 14:43:13 +0200 Subject: [PATCH 7/7] Formatting --- torchgeo/models/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/models/api.py b/torchgeo/models/api.py index a420d397bac..b5b058726b2 100644 --- a/torchgeo/models/api.py +++ b/torchgeo/models/api.py @@ -111,7 +111,7 @@ def get_weight(name: str) -> WeightsEnum: The requested weight enum. Raises: - ValueError: if `name` doesn't point to a valid WeightsEnum + ValueError: If *name* is not a valid WeightsEnum. """ for weight_name, weight_enum in _model_weights.items(): if isinstance(weight_name, str):