
Linear Probing #38

Open · Negiiiin wants to merge 4 commits into main

Conversation

Negiiiin (Collaborator)

PR Type

[Feature]

Short Description

Linear probing task

@@ -2,6 +2,7 @@

from mmlearn.tasks.contrastive_pretraining import ContrastivePretraining
from mmlearn.tasks.ijepa import IJEPA
from mmlearn.tasks.linear_probing import LinearClassifierModule
Collaborator

Suggested change
from mmlearn.tasks.linear_probing import LinearClassifierModule
from mmlearn.tasks.linear_probing import LinearClassifier

@@ -11,4 +12,5 @@
"IJEPA",
"ZeroShotCrossModalRetrieval",
"ZeroShotClassification",
"LinearClassifierModule",
Collaborator

Suggested change
"LinearClassifierModule",
"LinearClassifier",



@store(group="task", provider="mmlearn")
class LinearClassifierModule(L.LightningModule):
Collaborator

Suggested change
class LinearClassifierModule(L.LightningModule):
class LinearClassifier(L.LightningModule):


def __init__(
self,
# encoder: torch.nn.Module,
Collaborator

Please remove this line.

self,
# encoder: torch.nn.Module,
encoder: nn.Module,
model_checkpoint_path: Optional[str], # change name
Collaborator

Suggested change
model_checkpoint_path: Optional[str], # change name
model_checkpoint_path: Optional[str],

print("Encoder state dict loaded successfully")
except Exception as e:
print(f"Error loading state dict: {e}")
return model["rgb"]
Collaborator

Please generalize this as well. That is, allow the user to specify which key, if any, to extract.
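One way this generalization could look, as a minimal sketch rather than the PR's implementation: the parameter name modality_key and the helper name are assumptions, and the encoder is assumed to be an indexable mapping of modality-specific modules (e.g., an nn.ModuleDict).

from typing import Optional

import torch
from torch import nn


def load_encoder(
    encoder: nn.Module,
    checkpoint_path: str,
    modality_key: Optional[str] = None,
) -> nn.Module:
    """Load checkpoint weights and optionally return one modality sub-module."""
    # Assumes the file stores a raw state dict; adjust if it is a Lightning checkpoint.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    try:
        encoder.load_state_dict(state_dict, strict=False)
        print("Encoder state dict loaded successfully")
    except Exception as e:
        print(f"Error loading state dict: {e}")
    # Replaces the hard-coded model["rgb"] lookup with a user-supplied key.
    return encoder if modality_key is None else encoder[modality_key]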

# Ignore specific keys
if keys_to_ignore:
state_dict = {
k: v for k, v in state_dict.items() if k not in keys_to_ignore
Collaborator

This is removing the keys from the state_dict.

Collaborator

Is the intention here to train a single linear classifier on all the datasets at once or a single linear classifier for each dataset?

task: multiclass
num_classes: 4
num_output_features: 512
hidden_dims: [256]
Collaborator

Why add hidden layers? Are you following a certain recipe?
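For context, the conventional linear probe is a single linear layer trained on top of frozen encoder features. A minimal sketch of that baseline, with illustrative names that are not part of this PR:

import torch
from torch import nn


class LinearProbe(nn.Module):
    """Single linear head over frozen encoder features (the standard probe)."""

    def __init__(self, num_output_features: int, num_classes: int) -> None:
        super().__init__()
        self.head = nn.Linear(num_output_features, num_classes)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # Features come from a frozen encoder; only the head is trained.
        return self.head(features)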

num_output_features: 512
hidden_dims: [256]
modality: rgb
encoder_checkpoint_path: /path/to/checkpoint
Collaborator

You can make this a required parameter like so:

Suggested change
encoder_checkpoint_path: /path/to/checkpoint
encoder_checkpoint_path: ???
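For reference, ??? is OmegaConf's missing-value marker: accessing the key raises an error until the user supplies a value, so the config cannot silently run with a placeholder path. A small standalone illustration, independent of the mmlearn config:

from omegaconf import OmegaConf
from omegaconf.errors import MissingMandatoryValue

cfg = OmegaConf.create({"encoder_checkpoint_path": "???"})

try:
    _ = cfg.encoder_checkpoint_path  # raises until a value is provided
except MissingMandatoryValue:
    print("encoder_checkpoint_path must be set, e.g. via a CLI override")

cfg.encoder_checkpoint_path = "/path/to/checkpoint"
print(cfg.encoder_checkpoint_path)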
