Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removing eval in model weight API #2323

Merged
merged 7 commits into from
Sep 28, 2024
Merged

Removing eval in model weight API #2323

merged 7 commits into from
Sep 28, 2024

Conversation

calebrob6
Copy link
Member

@calebrob6 calebrob6 commented Sep 27, 2024

Removing use of eval.

@nilsleh -- do you remember what the use case for eval was? If not, I say we replace all instances of get_weight with get_model_weights and remove get_weight. Nevermind, I understand it now.

get_weight(...) allows a user to pass a string "ResNet18_Weights.LANDSAT_ETM_SR_MOCO" and get back the object ResNet18_Weights.LANDSAT_ETM_SR_MOCO WeightEnum. Previously we did this with eval("ResNet18_Weights.LANDSAT_ETM_SR_MOCO") which is bad. This proposed fix simply iterates through all available WeightEnums to look for matches and has the benefit of raising an error that makes sense if there is no match (vs. the eval(...) which would do whatever).

@github-actions github-actions bot added the models Models and pretrained weights label Sep 27, 2024
@adamjstewart
Copy link
Collaborator

Most of this API was copied from torchvision, let me check what their code looks like these days.

@adamjstewart
Copy link
Collaborator

I wonder if we can use torchvision's @register_model decorator and remove this entire file. Timm also has a way to register models, although I don't know how compatible it is. But it would let us easily use our custom models in our trainers.

@github-actions github-actions bot added the testing Continuous integration testing label Sep 27, 2024
@calebrob6
Copy link
Member Author

I looked at the torchvision version (https://github.com/pytorch/vision/blob/main/torchvision/models/_api.py#L108) and it seems more hacky than what I've implemented here. I updated the PR description with more details.

adamjstewart
adamjstewart previously approved these changes Sep 28, 2024
Copy link
Collaborator

@adamjstewart adamjstewart left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More lines of code, but also more secure. This is also better from a type checking perspective, as eval can return any type. We can think about how to delete this file and properly register our models in a separate PR.

torchgeo/models/api.py Outdated Show resolved Hide resolved
@adamjstewart adamjstewart enabled auto-merge (squash) September 28, 2024 12:43
@adamjstewart adamjstewart merged commit 1a98078 into main Sep 28, 2024
19 checks passed
@adamjstewart adamjstewart deleted the bugfix branch September 28, 2024 12:58
adamjstewart added a commit that referenced this pull request Oct 10, 2024
* Removing eval

* Extend the model_weights dict with sub weights

* Just search through the sub weights enum

* Ruff

* Fix bug and mypy

* Test coverage

* Formatting

---------

Co-authored-by: Adam J. Stewart <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
models Models and pretrained weights testing Continuous integration testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants