diff --git a/src/fairseq2/models/llama/builder.py b/src/fairseq2/models/llama/builder.py
index f29363b13..72de64215 100644
--- a/src/fairseq2/models/llama/builder.py
+++ b/src/fairseq2/models/llama/builder.py
@@ -197,6 +197,7 @@ class LLaMABuilder:
     """
 
     config: LLaMAConfig
+    pos_encoder: Optional[RotaryEncoder]
    device: Optional[Device]
    dtype: Optional[DataType]
 
@@ -216,6 +217,7 @@ def __init__(
            The data type of module parameters and buffers.
        """
        self.config = config
+        self.pos_encoder = None
        self.device = device
        self.dtype = dtype
 
@@ -297,19 +299,20 @@ def build_attention(
        """Build a Transformer multi-head attention layer."""
        sdpa = create_default_sdpa(attn_dropout_p=self.config.dropout_p)
 
-        pos_encoder = RotaryEncoder(
-            self.config.model_dim // num_heads,
-            self.config.max_seq_len,
-            device=self.device,
-            dtype=self.dtype,
-        )
+        if self.pos_encoder is None:
+            self.pos_encoder = RotaryEncoder(
+                self.config.model_dim // num_heads,
+                self.config.max_seq_len,
+                device=self.device,
+                dtype=self.dtype,
+            )
 
        return StandardMultiheadAttention(
            self.config.model_dim,
            num_heads,
            num_key_value_heads=num_key_value_heads,
            sdpa=sdpa,
-            pos_encoder=pos_encoder,
+            pos_encoder=self.pos_encoder,
            bias=False,
            device=self.device,
            dtype=self.dtype,
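
The diff memoizes the `RotaryEncoder` on the builder: the first call to `build_attention` creates it lazily, and every subsequent decoder layer reuses the same instance, so the precomputed rotary sin/cos tables are allocated once instead of once per layer. Below is a minimal, self-contained sketch of that lazy-initialization pattern; the `Builder` and `RotaryTable` names are illustrative stand-ins for this example, not fairseq2 APIs.

```python
from typing import Optional


class RotaryTable:
    """Stand-in for fairseq2's RotaryEncoder: holds precomputed tables."""

    def __init__(self, dim: int, max_seq_len: int) -> None:
        self.dim = dim
        self.max_seq_len = max_seq_len


class Builder:
    """Illustrative builder that lazily creates one shared positional encoder."""

    def __init__(self, dim: int, max_seq_len: int) -> None:
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.pos_encoder: Optional[RotaryTable] = None

    def build_attention(self) -> RotaryTable:
        # Create the encoder on first use only; later calls reuse it, so all
        # attention layers share a single set of rotary tables.
        if self.pos_encoder is None:
            self.pos_encoder = RotaryTable(self.dim, self.max_seq_len)
        return self.pos_encoder


builder = Builder(dim=128, max_seq_len=2048)
first = builder.build_attention()
second = builder.build_attention()
assert first is second  # one shared instance across layers
```

Sharing is safe here because the rotary tables depend only on the per-head dimension (`model_dim // num_heads`) and `max_seq_len`, which do not vary across layers in this builder. One caveat of the memoization: the first call's arguments win, so the pattern assumes every layer passes the same `num_heads`, which holds for the LLaMA configurations built by this class.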