Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom modules #217

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open

Custom modules #217

wants to merge 9 commits into from

Conversation

Joao-L-S-Almeida
Copy link
Member

@Joao-L-S-Almeida Joao-L-S-Almeida commented Nov 1, 2024

This is a prototype of how we could combine two model pipelines in terratorch. I believe it could be the easiest way to include third-party and custom model in our tests.

Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
@Joao-L-S-Almeida Joao-L-S-Almeida marked this pull request as ready for review November 6, 2024 12:09
def model_module(self):
return self._model_module

# overwrite early stopping
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this comment is in the wrong place?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep.

super().__init__()

self._model_module = model
self.model = model
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if model is not passed as parameter? it is none, correct? does it have any implications?

Copy link
Member Author

@Joao-L-S-Almeida Joao-L-S-Almeida Nov 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just modified it to guarantee that or a model_factory or a model is provided:

  if model_factory:  
         self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory)
         self.model_builder = self._build
 elif model:
     self.model_builder = self._bypass_build
 else:
     raise Exception("Or a model_factory or a torch.nn.Module object must be provided.")

Which will allow the selection of the most proper builder method.


def configure_models(self) -> None:
self.model: Model = self.model_builder()

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So are you sure this works?

Case 1:
You pass a modelfactory, then self.model_builder = self._build and the model is build from the factory

Case 2:
You don't pass a model factory but a model
self.model_builder = self._bypass_build and the one you get passed in from constructor is used

I think the code is correct but hard to understand, I also might wanna add some argument checking, e.g. if model factory is None a model need to be passed

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When no model_factory is passed, it used the object model, so the the method self._bypass_build just return this already instantiated object to be used during the process.

Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants