I-JEPA layers (#6)
vahid0001 authored Sep 18, 2024
1 parent ade1256 commit 4704ec4
Showing 6 changed files with 1,216 additions and 1 deletion.
41 changes: 40 additions & 1 deletion mmlearn/datasets/processors/masking.py
@@ -2,7 +2,7 @@

import math
import random
-from typing import Any, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union

import torch
from hydra_zen import store
@@ -225,3 +225,42 @@ def __call__(self) -> torch.Tensor:
mask_count += delta

return mask


def apply_masks(
    x: torch.Tensor, masks: Union[torch.Tensor, List[torch.Tensor]]
) -> torch.Tensor:
    """Apply masks to the input tensor by selecting the patches to keep.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape (B, N, D), where B is the batch size, N is the
        number of patches, and D is the feature dimension.
    masks : Union[torch.Tensor, List[torch.Tensor]]
        A list of tensors containing the indices of the patches to keep for
        each sample. Each mask tensor has shape (B, K), where B is the batch
        size and K is the number of patches to keep.

    Returns
    -------
    torch.Tensor
        The masked tensor containing only the patches selected by the masks.
        The output tensor has shape (len(masks) * B, K, D): the result for
        each mask is computed separately and the results are concatenated
        along the batch dimension.

    Notes
    -----
    - Each mask holds the *indices* of the patches to keep, not a binary
      keep/discard mask.
    - The function uses `torch.gather` to select the patches specified by
      the masks.
    """
    all_x = []
    for m in masks:
        # Expand the mask to match the feature dimension and gather the
        # relevant patches.
        mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1))
        all_x.append(torch.gather(x, dim=1, index=mask_keep))

    # Concatenate the per-mask results along the batch dimension.
    return torch.cat(all_x, dim=0)
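
A minimal usage sketch (the shapes and index values below are illustrative,
not from the repository; the import path assumes the module location shown
in the file header above):

import torch

from mmlearn.datasets.processors.masking import apply_masks

# Batch of 2 samples, 4 patches each, 3-dimensional features.
x = torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)
# Two masks, each keeping 2 of the 4 patches per sample.
masks = [torch.tensor([[0, 2], [1, 3]]), torch.tensor([[1, 2], [0, 3]])]
out = apply_masks(x, masks)
print(out.shape)  # torch.Size([4, 2, 3]) -> (len(masks) * B, K, D)
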
104 changes: 104 additions & 0 deletions mmlearn/datasets/processors/transforms.py
@@ -1,7 +1,9 @@
"""Custom transforms for datasets."""

import math
from typing import List, Union

import torch
from hydra_zen import store


@@ -28,3 +30,105 @@ def __call__(self, sentence: Union[str, List[str]]) -> Union[str, List[str]]:
            sentence[i] = s[: self.trim_size]

        return sentence


def _no_grad_trunc_normal_(
    tensor: torch.Tensor, mean: float, std: float, a: float, b: float
) -> torch.Tensor:
    """Apply truncated normal initialization to a tensor.

    Parameters
    ----------
    tensor : torch.Tensor
        The tensor to be initialized.
    mean : float
        Mean of the normal distribution.
    std : float
        Standard deviation of the normal distribution.
    a : float
        Minimum value of the truncated distribution.
    b : float
        Maximum value of the truncated distribution.

    Returns
    -------
    torch.Tensor
        The initialized tensor, modified in place.
    """

    def norm_cdf(x: float) -> float:
        """Compute the standard normal cumulative distribution function."""
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    with torch.no_grad():
        # Map the truncation bounds into CDF space.
        lower_limit = norm_cdf((a - mean) / std)
        upper_limit = norm_cdf((b - mean) / std)

        # Sample uniformly over the corresponding erf range, then invert the
        # error function to obtain standard normal samples restricted to
        # [(a - mean) / std, (b - mean) / std].
        tensor.uniform_(2 * lower_limit - 1, 2 * upper_limit - 1)
        tensor.erfinv_()

        # Scale and shift to the requested std and mean, then clamp to guard
        # against numerical round-off at the boundaries.
        tensor.mul_(std * math.sqrt(2.0))
        tensor.add_(mean)
        tensor.clamp_(min=a, max=b)

    return tensor
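
A quick sanity check of the inverse-CDF construction (illustrative only;
0.774 is the analytical variance of a standard normal truncated to [-2, 2]):

import torch

t = torch.empty(100_000)
_no_grad_trunc_normal_(t, mean=0.0, std=1.0, a=-2.0, b=2.0)
print(t.mean().item())  # close to 0.0
print(t.var().item())   # close to 0.774
print(t.min().item() >= -2.0 and t.max().item() <= 2.0)  # True
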


def trunc_normal_(
    tensor: torch.Tensor,
    mean: float = 0.0,
    std: float = 1.0,
    a: float = -2.0,
    b: float = 2.0,
) -> torch.Tensor:
    """Initialize a tensor using a truncated normal distribution.

    Parameters
    ----------
    tensor : torch.Tensor
        The tensor to be initialized.
    mean : float, default=0.0
        Mean of the normal distribution.
    std : float, default=1.0
        Standard deviation of the normal distribution.
    a : float, default=-2.0
        Minimum value of the truncated distribution.
    b : float, default=2.0
        Maximum value of the truncated distribution.

    Returns
    -------
    torch.Tensor
        The initialized tensor.
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
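
A brief usage sketch (the layer and hyperparameters here are hypothetical;
a small standard deviation such as 0.02 is a common choice for transformer
weight initialization):

import torch.nn as nn

linear = nn.Linear(768, 768)
trunc_normal_(linear.weight, mean=0.0, std=0.02, a=-2.0, b=2.0)
# The weights are initialized in place and guaranteed to lie in [a, b].
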


def repeat_interleave_batch(x: torch.Tensor, b: int, repeat: int) -> torch.Tensor:
    """Repeat and interleave a tensor across the batch dimension.

    The input is treated as `n = len(x) // b` consecutive batches of size
    `b`; each batch is repeated `repeat` times before moving on to the next.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor to be repeated, with leading dimension `n * b`.
    b : int
        Size of each batch to be repeated.
    repeat : int
        Number of times to repeat each batch.

    Returns
    -------
    torch.Tensor
        The repeated tensor, with leading dimension `n * repeat * b`.
    """
    n = len(x) // b
    return torch.cat(
        [
            # Repeat each contiguous batch of size `b` `repeat` times.
            torch.cat([x[i * b : (i + 1) * b] for _ in range(repeat)], dim=0)
            for i in range(n)
        ],
        dim=0,
    )
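
A small illustrative example of the interleaving behavior (values are made
up for demonstration):

import torch

x = torch.tensor([[0], [1], [2], [3]])  # n = 2 batches of size b = 2
out = repeat_interleave_batch(x, b=2, repeat=3)
print(out.squeeze(-1).tolist())
# [0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3]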