Commit

pixtral
patrickvonplaten committed Sep 13, 2024
1 parent 2e376ac commit 510f7ae
Showing 3 changed files with 32 additions and 30 deletions.
2 changes: 1 addition & 1 deletion src/mistral_inference/transformer.py
@@ -14,7 +14,7 @@
 from mistral_inference.lora import LoRALoaderMixin
 from mistral_inference.model import ModelBase
 from mistral_inference.rope import precompute_freqs_cis
-from mistral_inference.transformer_utils import RMSNorm, TransformerBlock
+from mistral_inference.transformer_layers import RMSNorm, TransformerBlock
 from mistral_inference.vision_encoder import VisionLanguageAdapter, VisionTransformer
 
 
File renamed without changes.
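Judging by the import updates in transformer.py above and vision_encoder.py below, this is presumably src/mistral_inference/transformer_utils.py being renamed to src/mistral_inference/transformer_layers.py. For downstream code that imported these helpers directly, only the module path changes; a minimal before/after sketch (assuming the old module name is gone after the rename):

# Before this commit:
# from mistral_inference.transformer_utils import RMSNorm, TransformerBlock

# After this commit, same names, new module path:
from mistral_inference.transformer_layers import RMSNorm, TransformerBlock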
60 changes: 31 additions & 29 deletions src/mistral_inference/vision_encoder.py
@@ -6,34 +6,7 @@
 
 from mistral_inference.args import VisionEncoderArgs
 from mistral_inference.rope import precompute_freqs_cis_2d
-from mistral_inference.transformer_utils import RMSNorm, TransformerBlock
-
-
-class Transformer(nn.Module):
-    def __init__(self, args: VisionEncoderArgs):
-        super().__init__()
-        self.layers = torch.nn.ModuleList()
-        for _ in range(args.num_hidden_layers):
-            self.layers.append(
-                TransformerBlock(
-                    dim=args.hidden_size,
-                    hidden_dim=args.intermediate_size,
-                    n_heads=args.num_attention_heads,
-                    n_kv_heads=args.num_attention_heads,
-                    head_dim=args.hidden_size // args.num_attention_heads,
-                    norm_eps=1e-5,
-                )
-            )
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        mask: BlockDiagonalMask,
-        freqs_cis: Optional[torch.Tensor],
-    ) -> torch.Tensor:
-        for layer in self.layers:
-            x = layer(x, mask=mask, freqs_cis=freqs_cis)
-        return x
+from mistral_inference.transformer_layers import RMSNorm, TransformerBlock
 
 
 def position_meshgrid(
@@ -67,7 +40,7 @@ def __init__(self, args: VisionEncoderArgs):
             bias=False,
         )
         self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5)
-        self.transformer = Transformer(args)
+        self.transformer = VisionTransformerBlocks(args)
 
         head_dim = self.args.hidden_size // self.args.num_attention_heads
         assert head_dim % 2 == 0, "ROPE requires even head_dim"
@@ -142,3 +115,32 @@ def __init__(self, in_dim: int, out_dim: int):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.w_out(self.gelu(self.w_in(x)))  # type: ignore[no-any-return]
+
+
+class VisionTransformerBlocks(nn.Module):
+    def __init__(self, args: VisionEncoderArgs):
+        super().__init__()
+        self.layers = torch.nn.ModuleList()
+        for _ in range(args.num_hidden_layers):
+            self.layers.append(
+                TransformerBlock(
+                    dim=args.hidden_size,
+                    hidden_dim=args.intermediate_size,
+                    n_heads=args.num_attention_heads,
+                    n_kv_heads=args.num_attention_heads,
+                    head_dim=args.hidden_size // args.num_attention_heads,
+                    norm_eps=1e-5,
+                )
+            )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        mask: BlockDiagonalMask,
+        freqs_cis: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        for layer in self.layers:
+            x = layer(x, mask=mask, freqs_cis=freqs_cis)
+        return x
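
For orientation, here is a minimal usage sketch (not part of this commit) of the renamed class, mirroring how VisionTransformer drives it: patch tokens from several images are packed into one flat sequence and kept separate by a block-diagonal attention mask. Assumptions are flagged in the comments: a SimpleNamespace stand-in for VisionEncoderArgs carrying only the four fields the class reads, identity rotary frequencies instead of real precompute_freqs_cis_2d output, and a CUDA device for xformers' memory-efficient attention.

# Minimal sketch under the assumptions stated above; not the commit's own test code.
from types import SimpleNamespace

import torch
from xformers.ops.fmha.attn_bias import BlockDiagonalMask

from mistral_inference.vision_encoder import VisionTransformerBlocks

device, dtype = "cuda", torch.float16

# Stand-in for VisionEncoderArgs: only the fields VisionTransformerBlocks reads.
args = SimpleNamespace(
    hidden_size=64,          # model width (dim of each TransformerBlock)
    intermediate_size=256,   # feed-forward hidden width
    num_hidden_layers=2,     # number of TransformerBlock layers
    num_attention_heads=4,   # attention heads (n_kv_heads is set to the same value)
)
head_dim = args.hidden_size // args.num_attention_heads

blocks = VisionTransformerBlocks(args).to(device=device, dtype=dtype)

# Two "images" contributing 9 and 16 patch tokens, packed into one sequence.
seqlens = [9, 16]
x = torch.randn(sum(seqlens), args.hidden_size, device=device, dtype=dtype)
mask = BlockDiagonalMask.from_seqlens(seqlens)

# Identity rotations (angle 0), assuming one complex (head_dim // 2) vector per
# token; the real model derives these from precompute_freqs_cis_2d and the
# patch positions.
freqs_cis = torch.polar(
    torch.ones(sum(seqlens), head_dim // 2, device=device),
    torch.zeros(sum(seqlens), head_dim // 2, device=device),
)

out = blocks(x, mask=mask, freqs_cis=freqs_cis)
print(out.shape)  # torch.Size([25, 64])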

