Skip to content

Commit

Permalink
Extend the model_weights dict with sub weights
Browse files Browse the repository at this point in the history
  • Loading branch information
calebrob6 authored Sep 27, 2024
1 parent 66abac1 commit 109f94c
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions torchgeo/models/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@
'vit_small_patch16_224': ViTSmall16_Weights,
}

for name, weight_enum in _model_weights.items():
if isinstance(name, str):
for sub_weight_enum in weight_enum:
_model_weights[str(sub_weight_enum)] = sub_weight_enum


def get_model(name: str, *args: Any, **kwargs: Any) -> nn.Module:
"""Get an instantiated model from its name.
Expand Down

0 comments on commit 109f94c

Please sign in to comment.