-
Notifications
You must be signed in to change notification settings - Fork 23
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Custom modules #217
base: main
Are you sure you want to change the base?
Custom modules #217
Conversation
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
def model_module(self): | ||
return self._model_module | ||
|
||
# overwrite early stopping |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this comment is in the wrong place?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep.
super().__init__() | ||
|
||
self._model_module = model | ||
self.model = model |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What happens if model is not passed as parameter? it is none, correct? does it have any implications?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just modified it to guarantee that or a model_factory or a model is provided:
if model_factory:
self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory)
self.model_builder = self._build
elif model:
self.model_builder = self._bypass_build
else:
raise Exception("Or a model_factory or a torch.nn.Module object must be provided.")
Which will allow the selection of the most proper builder method.
|
||
def configure_models(self) -> None: | ||
self.model: Model = self.model_builder() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So are you sure this works?
Case 1:
You pass a modelfactory, then self.model_builder = self._build and the model is build from the factory
Case 2:
You don't pass a model factory but a model
self.model_builder = self._bypass_build and the one you get passed in from constructor is used
I think the code is correct but hard to understand, I also might wanna add some argument checking, e.g. if model factory is None a model need to be passed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When no model_factory is passed, it used the object model, so the the method self._bypass_build just return this already instantiated object to be used during the process.
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
This is a prototype of how we could combine two model pipelines in terratorch. I believe it could be the easiest way to include third-party and custom model in our tests.