Skip to content

Commit

Permalink
Removing eval in model weight API (#2323)
Browse files Browse the repository at this point in the history
* Removing eval

* Extend the model_weights dict with sub weights

* Just search through the sub weights enum

* Ruff

* Fix bug and mypy

* Test coverage

* Formatting

---------

Co-authored-by: Adam J. Stewart <[email protected]>
  • Loading branch information
calebrob6 and adamjstewart authored Sep 28, 2024
1 parent da691c4 commit 1a98078
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
5 changes: 5 additions & 0 deletions tests/models/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,8 @@ def test_get_weight(enum: WeightsEnum) -> None:
def test_list_models() -> None:
models = [builder.__name__ for builder in builders]
assert set(models) == set(list_models())


def test_invalid_model() -> None:
with pytest.raises(ValueError, match='bad_model is not a valid WeightsEnum'):
get_weight('bad_model')
13 changes: 11 additions & 2 deletions torchgeo/models/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
'vit_small_patch16_224': vit_small_patch16_224,
}

_model_weights = {
_model_weights: dict[str | Callable[..., nn.Module], WeightsEnum] = {
dofa_base_patch16_224: DOFABase16_Weights,
dofa_large_patch16_224: DOFALarge16_Weights,
resnet18: ResNet18_Weights,
Expand Down Expand Up @@ -109,8 +109,17 @@ def get_weight(name: str) -> WeightsEnum:
Returns:
The requested weight enum.
Raises:
ValueError: If *name* is not a valid WeightsEnum.
"""
return eval(name)
for weight_name, weight_enum in _model_weights.items():
if isinstance(weight_name, str):
for sub_weight_enum in weight_enum:
if name == str(sub_weight_enum):
return sub_weight_enum

raise ValueError(f'{name} is not a valid WeightsEnum')


def list_models() -> list[str]:
Expand Down

0 comments on commit 1a98078

Please sign in to comment.