Linear Probing #38
base: main
Conversation
@@ -2,6 +2,7 @@
 from mmlearn.tasks.contrastive_pretraining import ContrastivePretraining
 from mmlearn.tasks.ijepa import IJEPA
+from mmlearn.tasks.linear_probing import LinearClassifierModule
Suggested change:
-from mmlearn.tasks.linear_probing import LinearClassifierModule
+from mmlearn.tasks.linear_probing import LinearClassifier
@@ -11,4 +12,5 @@
     "IJEPA",
     "ZeroShotCrossModalRetrieval",
     "ZeroShotClassification",
+    "LinearClassifierModule",
"LinearClassifierModule", | |
"LinearClassifier", |
@store(group="task", provider="mmlearn") | ||
class LinearClassifierModule(L.LightningModule): |
Suggested change:
-class LinearClassifierModule(L.LightningModule):
+class LinearClassifier(L.LightningModule):
def __init__(
    self,
    # encoder: torch.nn.Module,
Please remove this line.
    self,
    # encoder: torch.nn.Module,
    encoder: nn.Module,
    model_checkpoint_path: Optional[str],  # change name
Suggested change:
-    model_checkpoint_path: Optional[str],  # change name
+    model_checkpoint_path: Optional[str],
print("Encoder state dict loaded successfully") | ||
except Exception as e: | ||
print(f"Error loading state dict: {e}") | ||
return model["rgb"] |
Please generalize this as well. That is, allow the user to specify which key, if any, to extract.
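One way to act on this suggestion (the parameter name `extract_key` is hypothetical; the review only asks that the hard-coded `model["rgb"]` lookup become configurable) is an optional key argument. A minimal sketch on plain dicts, so it stands alone without torch:

```python
from typing import Optional


def extract_state_dict(checkpoint: dict, extract_key: Optional[str] = None) -> dict:
    """Return the whole checkpoint dict, or only the sub-dict under extract_key.

    extract_key is a hypothetical parameter; passing "rgb" reproduces the
    old hard-coded behavior, and passing nothing returns the checkpoint as-is.
    """
    if extract_key is None:
        return checkpoint
    return checkpoint[extract_key]


# Hypothetical checkpoint with per-modality sub-dicts.
ckpt = {"rgb": {"layer.weight": [1.0]}, "depth": {"layer.weight": [2.0]}}
print(extract_state_dict(ckpt, "rgb"))  # old behavior: only the "rgb" weights
print(extract_state_dict(ckpt))         # whole checkpoint when no key is given
```

In the real task this would wrap whatever `torch.load` returns before `load_state_dict` is called.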
    # Ignore specific keys
    if keys_to_ignore:
        state_dict = {
            k: v for k, v in state_dict.items() if k not in keys_to_ignore
This is removing the keys from the state_dict
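To make the observed behavior concrete: the comprehension drops only exact-match keys, so a name like `encoder.head.weight` survives even when `head.weight` is ignored. A self-contained sketch of that filter:

```python
def filter_state_dict(state_dict: dict, keys_to_ignore=None) -> dict:
    """Drop exact-match keys from a state dict, mirroring the diff above."""
    if keys_to_ignore:
        state_dict = {
            k: v for k, v in state_dict.items() if k not in keys_to_ignore
        }
    return state_dict


sd = {"head.weight": 1, "head.bias": 2, "encoder.head.weight": 3}
filtered = filter_state_dict(sd, keys_to_ignore=["head.weight", "head.bias"])
print(sorted(filtered))  # only the prefixed key remains
```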
Is the intention here to train a single linear classifier on all the datasets at once or a single linear classifier for each dataset?
task: multiclass
num_classes: 4
num_output_features: 512
hidden_dims: [256]
Why add hidden layers? Are you following a certain recipe?
num_output_features: 512
hidden_dims: [256]
modality: rgb
encoder_checkpoint_path: /path/to/checkpoint
You can make this a required parameter like so:
Suggested change:
-encoder_checkpoint_path: /path/to/checkpoint
+encoder_checkpoint_path: ???
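For context on the suggestion: since the task is registered through hydra-zen's `@store` (visible in the diff above), its config resolves through OmegaConf, where `???` is the mandatory-missing marker. A field left as `???` raises `MissingMandatoryValue` on access, so a run fails fast unless the user overrides it. A sketch of the config fragment with this change applied (values taken from the diff above):

```yaml
# linear probing task config (fragment)
task: multiclass
num_classes: 4
num_output_features: 512
modality: rgb
encoder_checkpoint_path: ???   # mandatory: user must override, e.g. on the CLI
```

An override such as `task.encoder_checkpoint_path=/my/ckpt.pt` would then be required at launch time.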
PR Type
[Feature]
Short Description
Linear probing task