Skip to content

Commit

Permalink
Add a new "out" parameter to ModelLoader
Browse files Browse the repository at this point in the history
  • Loading branch information
cbalioglu committed Oct 2, 2023
1 parent 4c39cdd commit 622c32c
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 24 deletions.
1 change: 1 addition & 0 deletions src/fairseq2/models/llama/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from fairseq2.models.llama.builder import llama_archs as llama_archs
from fairseq2.models.llama.loader import LLaMALoader as LLaMALoader
from fairseq2.models.llama.loader import LLaMATokenizerLoader as LLaMATokenizerLoader
from fairseq2.models.llama.loader import load_llama_config as load_llama_config
from fairseq2.models.llama.loader import load_llama_model as load_llama_model
from fairseq2.models.llama.loader import load_llama_tokenizer as load_llama_tokenizer
from fairseq2.models.llama.tokenizer import LLaMATokenizer as LLaMATokenizer
1 change: 1 addition & 0 deletions src/fairseq2/models/nllb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from fairseq2.models.nllb.builder import nllb_arch as nllb_arch
from fairseq2.models.nllb.builder import nllb_archs as nllb_archs
from fairseq2.models.nllb.loader import NllbLoader as NllbLoader
from fairseq2.models.nllb.loader import load_nllb_config as load_nllb_config
from fairseq2.models.nllb.loader import load_nllb_model as load_nllb_model
from fairseq2.models.nllb.loader import load_nllb_tokenizer as load_nllb_tokenizer
from fairseq2.models.nllb.tokenizer import NllbTokenizer as NllbTokenizer
3 changes: 3 additions & 0 deletions src/fairseq2/models/s2t_transformer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
from fairseq2.models.s2t_transformer.frontend import (
S2TTransformerFrontend as S2TTransformerFrontend,
)
from fairseq2.models.s2t_transformer.loader import (
load_s2t_transformer_config as load_s2t_transformer_config,
)
from fairseq2.models.s2t_transformer.loader import (
load_s2t_transformer_model as load_s2t_transformer_model,
)
Expand Down
49 changes: 29 additions & 20 deletions src/fairseq2/models/utils/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
)
from fairseq2.models.utils.arch_registry import ArchitectureRegistry
from fairseq2.models.utils.checkpoint_loader import load_checkpoint
from fairseq2.nn.utils.module import reset_non_persistent_buffers
from fairseq2.nn.utils.module import infer_device, reset_non_persistent_buffers
from fairseq2.typing import DataType, Device
from fairseq2.utils.dataclass import update_dataclass

Expand Down Expand Up @@ -162,22 +162,25 @@ def __call__(
self,
model_name_or_card: Union[str, AssetCard],
*,
force: bool = False,
progress: bool = True,
device: Optional[Device] = None,
dtype: Optional[DataType] = None,
out: Optional[ModelT] = None,
force: bool = False,
progress: bool = True,
) -> ModelT:
"""
:param model_name_or_card:
The name or asset card of the model to load.
:param force:
If ``True``, downloads the checkpoint even if it is already in cache.
:param progress:
If ``True``, displays a progress bar to stderr.
:param device:
The device on which to load the model.
:param dtype:
The data type of the model parameters and buffers.
:param out:
The output model to load.
:param force:
If ``True``, downloads the checkpoint even if it is already in cache.
:param progress:
If ``True``, displays a progress bar to stderr.
:returns:
A model loaded from the checkpoint of ``model_name_or_card``.
Expand All @@ -204,22 +207,28 @@ def __call__(
converter=partial(self._convert_checkpoint, config=config),
)

try:
# Try to construct the model on the meta device.
model = self.model_factory(config, device=Device("meta"), dtype=dtype)
except NotImplementedError:
is_meta = False

logger.warning(
f"One or more operators in {card.name} constructor do not support meta device. Skipping lazy initialization."
)
if out is not None:
model = out

# If we are here, it means the model has at least one operator that
# does not support meta device. Do regular model initialization.
model = self.model_factory(config, device=device, dtype=dtype)
is_meta = infer_device(model).type == "meta"
else:
is_meta = True
try:
# Try to construct the model on the meta device.
model = self.model_factory(config, device=Device("meta"), dtype=dtype)

is_meta = True
except NotImplementedError:
is_meta = False

logger.warning(
f"One or more operators in {card.name} constructor do not support meta device. Skipping lazy initialization."
)

# If we are here, it means the model has at least one operator that
# does not support meta device. Do regular model initialization.
model = self.model_factory(config, device=device, dtype=dtype)

if is_meta:
# Move the model to the actual device without initializing. Its
# state will be overwritten by the checkpoint anyways.
model = model.to_empty(device=device or "cpu")
Expand Down
1 change: 1 addition & 0 deletions src/fairseq2/models/w2vbert/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from fairseq2.models.w2vbert.builder import create_w2vbert_model as create_w2vbert_model
from fairseq2.models.w2vbert.builder import w2vbert_arch as w2vbert_arch
from fairseq2.models.w2vbert.builder import w2vbert_archs as w2vbert_archs
from fairseq2.models.w2vbert.loader import load_w2vbert_config as load_w2vbert_config
from fairseq2.models.w2vbert.loader import load_w2vbert_model as load_w2vbert_model
from fairseq2.models.w2vbert.model import W2VBertLoss as W2VBertLoss
from fairseq2.models.w2vbert.model import W2VBertModel as W2VBertModel
Expand Down
1 change: 1 addition & 0 deletions src/fairseq2/models/wav2vec2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Wav2Vec2FeatureExtractor as Wav2Vec2FeatureExtractor,
)
from fairseq2.models.wav2vec2.frontend import Wav2Vec2Frontend as Wav2Vec2Frontend
from fairseq2.models.wav2vec2.loader import load_wav2vec2_config as load_wav2vec2_config
from fairseq2.models.wav2vec2.loader import load_wav2vec2_model as load_wav2vec2_model
from fairseq2.models.wav2vec2.masker import Wav2Vec2Masker as Wav2Vec2Masker
from fairseq2.models.wav2vec2.model import Wav2Vec2Loss as Wav2Vec2Loss
Expand Down
19 changes: 15 additions & 4 deletions tests/integration/models/test_s2t_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
from fairseq2.generation import SequenceToTextGenerator
from fairseq2.models.s2t_transformer import (
S2TTransformerTokenizer,
create_s2t_transformer_model,
load_s2t_transformer_config,
load_s2t_transformer_model,
load_s2t_transformer_tokenizer,
)
from fairseq2.models.transformer import TransformerModel
from fairseq2.typing import Device
from tests.common import device

TEST_FBANK_PATH: Final = Path(__file__).parent.joinpath("fbank.pt")
Expand All @@ -40,8 +43,12 @@ def test_load_s2t_transformer_mustc_st_jt_m() -> None:


def test_load_s2t_conformer_covost_st_en_de() -> None:
model = load_s2t_transformer_model(
"s2t_conformer_covost_st_en_de", device=device, progress=False
config = load_s2t_transformer_config("s2t_conformer_covost_st_en_de")

model = create_s2t_transformer_model(config, device=Device("meta"))

load_s2t_transformer_model(
"s2t_conformer_covost_st_en_de", device=device, out=model, progress=False
)

tokenizer = load_s2t_transformer_tokenizer(
Expand All @@ -52,8 +59,12 @@ def test_load_s2t_conformer_covost_st_en_de() -> None:


def test_load_s2t_conformer_rel_pos_covost_st_en_de() -> None:
model = load_s2t_transformer_model(
"s2t_conformer_covost_st_en_de_rel_pos", device=device, progress=False
config = load_s2t_transformer_config("s2t_conformer_covost_st_en_de_rel_pos")

model = create_s2t_transformer_model(config, device=device)

load_s2t_transformer_model(
"s2t_conformer_covost_st_en_de_rel_pos", out=model, progress=False
)

tokenizer = load_s2t_transformer_tokenizer(
Expand Down

0 comments on commit 622c32c

Please sign in to comment.