
Commit

Cache rotary encoder in LLaMA builder
cbalioglu committed Sep 21, 2023
1 parent 3f2dc1b commit a721021
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions src/fairseq2/models/llama/builder.py
@@ -197,6 +197,7 @@ class LLaMABuilder:
     """

     config: LLaMAConfig
+    pos_encoder: Optional[RotaryEncoder]
     device: Optional[Device]
     dtype: Optional[DataType]

@@ -216,6 +217,7 @@ def __init__(
             The data type of module parameters and buffers.
         """
         self.config = config
+        self.pos_encoder = None
         self.device = device
         self.dtype = dtype

@@ -297,19 +299,20 @@ def build_attention(
         """Build a Transformer multi-head attention layer."""
         sdpa = create_default_sdpa(attn_dropout_p=self.config.dropout_p)

-        pos_encoder = RotaryEncoder(
-            self.config.model_dim // num_heads,
-            self.config.max_seq_len,
-            device=self.device,
-            dtype=self.dtype,
-        )
+        if self.pos_encoder is None:
+            self.pos_encoder = RotaryEncoder(
+                self.config.model_dim // num_heads,
+                self.config.max_seq_len,
+                device=self.device,
+                dtype=self.dtype,
+            )

         return StandardMultiheadAttention(
             self.config.model_dim,
             num_heads,
             num_key_value_heads=num_key_value_heads,
             sdpa=sdpa,
-            pos_encoder=pos_encoder,
+            pos_encoder=self.pos_encoder,
             bias=False,
             device=self.device,
             dtype=self.dtype,
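
Note: the change replaces per-layer construction of RotaryEncoder with a single instance that is built lazily on the first call to build_attention and then shared by every subsequent attention layer. Below is a minimal, self-contained sketch of that caching pattern; the class and attribute names mirror the diff, but the simplified constructor signatures (no device/dtype handling) are assumptions rather than fairseq2's actual API.

# A minimal sketch of the lazy-caching pattern applied in this commit.
# Names mirror the diff; the simplified signatures are assumptions.
from typing import Optional


class RotaryEncoder:
    """Stand-in for fairseq2's RotaryEncoder (simplified)."""

    def __init__(self, encoding_dim: int, max_seq_len: int) -> None:
        self.encoding_dim = encoding_dim
        self.max_seq_len = max_seq_len


class LLaMABuilder:
    """Builds attention layers that share one cached rotary encoder."""

    def __init__(self, model_dim: int, max_seq_len: int, num_heads: int) -> None:
        self.model_dim = model_dim
        self.max_seq_len = max_seq_len
        self.num_heads = num_heads
        # Created lazily on first use, then reused by every layer.
        self.pos_encoder: Optional[RotaryEncoder] = None

    def build_attention(self) -> RotaryEncoder:
        # Before this commit a new RotaryEncoder was constructed per layer;
        # now the first call builds it and later calls reuse the same instance.
        if self.pos_encoder is None:
            self.pos_encoder = RotaryEncoder(
                self.model_dim // self.num_heads, self.max_seq_len
            )
        return self.pos_encoder


builder = LLaMABuilder(model_dim=4096, max_seq_len=2048, num_heads=32)
assert builder.build_attention() is builder.build_attention()

The final assert holds because every call after the first returns the same cached encoder, which is the point of the change.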
