-
Notifications
You must be signed in to change notification settings - Fork 2.3k
/
wrapping.py
33 lines (24 loc) · 1.21 KB
/
wrapping.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import functools
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.mllama.modeling_mllama import MllamaSelfAttentionDecoderLayer,MllamaCrossAttentionDecoderLayer,MllamaVisionEncoderLayer
from torch.distributed.fsdp.wrap import (
transformer_auto_wrap_policy,
size_based_auto_wrap_policy,
)
def get_size_policy(min_params: float = 1e8):
    """Build a parameter-count based FSDP auto-wrap policy.

    Args:
        min_params: Threshold forwarded unchanged to torch's
            ``size_based_auto_wrap_policy`` as ``min_num_params``.

    Returns:
        A ``functools.partial`` of ``size_based_auto_wrap_policy`` with only
        ``min_num_params`` bound; FSDP supplies the remaining arguments.
    """
    return functools.partial(size_based_auto_wrap_policy, min_num_params=min_params)
def get_llama_wrapper():
    """Return an FSDP auto-wrap policy keyed on the Llama/Mllama layer classes.

    The policy is ``transformer_auto_wrap_policy`` with ``transformer_layer_cls``
    bound to the Llama text decoder layer plus the Mllama self-attention
    decoder, cross-attention decoder and vision encoder layers, so FSDP units
    map to those transformer layers. Modules outside that set (such as the
    embedding layers) are not wrapped on their own and therefore stay in the
    root FSDP unit, where they remain shared.

    Returns:
        A ``functools.partial`` of ``transformer_auto_wrap_policy``, intended to
        be passed as FSDP's ``auto_wrap_policy``.
    """
    # Set literal instead of set([...]): no throwaway list is built.
    transformer_layer_cls = {
        LlamaDecoderLayer,
        MllamaSelfAttentionDecoderLayer,
        MllamaCrossAttentionDecoderLayer,
        MllamaVisionEncoderLayer,
    }
    return functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=transformer_layer_cls,
    )