From cfb71dca1b0a61b1f726cdc994d7d5a493a1deab Mon Sep 17 00:00:00 2001 From: Can Balioglu Date: Fri, 3 Jan 2025 22:38:55 +0000 Subject: [PATCH] Revise library setup --- src/fairseq2/assets/__init__.py | 10 ++- src/fairseq2/assets/card.py | 24 +++--- src/fairseq2/assets/error.py | 7 +- src/fairseq2/assets/store.py | 46 ++--------- src/fairseq2/chatbots/__init__.py | 8 +- src/fairseq2/chatbots/handler.py | 28 +++++++ src/fairseq2/chatbots/llama.py | 14 ++-- src/fairseq2/chatbots/mistral.py | 12 +-- src/fairseq2/chatbots/register.py | 29 ------- src/fairseq2/chatbots/registry.py | 43 ---------- src/fairseq2/chatbots/static.py | 28 +++++++ src/fairseq2/data/text/__init__.py | 36 ++++----- .../tokenizers/{registry.py => handler.py} | 38 ++++----- src/fairseq2/data/text/tokenizers/llama.py | 4 +- src/fairseq2/data/text/tokenizers/nllb.py | 2 +- src/fairseq2/data/text/tokenizers/register.py | 63 --------------- .../data/text/tokenizers/s2t_transformer.py | 2 +- src/fairseq2/data/text/tokenizers/static.py | 31 ++++---- src/fairseq2/datasets/loader.py | 6 +- src/fairseq2/extensions.py | 34 +++----- src/fairseq2/models/config_loader.py | 4 +- src/fairseq2/models/loader.py | 17 ++-- src/fairseq2/recipes/lm/chatbot.py | 10 +-- src/fairseq2/recipes/utils/asset.py | 4 +- src/fairseq2/setup.py | 74 ------------------ src/fairseq2/setup/__init__.py | 16 ++++ src/fairseq2/setup/assets.py | 68 ++++++++++++++++ src/fairseq2/setup/chatbots.py | 21 +++++ src/fairseq2/setup/root.py | 58 ++++++++++++++ src/fairseq2/setup/text_tokenizers.py | 78 +++++++++++++++++++ tests/integration/models/test_llama.py | 8 +- 31 files changed, 435 insertions(+), 388 deletions(-) create mode 100644 src/fairseq2/chatbots/handler.py delete mode 100644 src/fairseq2/chatbots/register.py delete mode 100644 src/fairseq2/chatbots/registry.py create mode 100644 src/fairseq2/chatbots/static.py rename src/fairseq2/data/text/tokenizers/{registry.py => handler.py} (65%) delete mode 100644 src/fairseq2/data/text/tokenizers/register.py delete mode 100644 src/fairseq2/setup.py create mode 100644 src/fairseq2/setup/__init__.py create mode 100644 src/fairseq2/setup/assets.py create mode 100644 src/fairseq2/setup/chatbots.py create mode 100644 src/fairseq2/setup/root.py create mode 100644 src/fairseq2/setup/text_tokenizers.py diff --git a/src/fairseq2/assets/__init__.py b/src/fairseq2/assets/__init__.py index 63542e37c..b13739375 100644 --- a/src/fairseq2/assets/__init__.py +++ b/src/fairseq2/assets/__init__.py @@ -14,9 +14,6 @@ from fairseq2.assets.download_manager import ( InProcAssetDownloadManager as InProcAssetDownloadManager, ) -from fairseq2.assets.download_manager import ( - default_asset_download_manager as default_asset_download_manager, -) from fairseq2.assets.error import AssetCardError as AssetCardError from fairseq2.assets.error import ( AssetCardFieldNotFoundError as AssetCardFieldNotFoundError, @@ -43,9 +40,14 @@ from fairseq2.assets.metadata_provider import ( PackageAssetMetadataProvider as PackageAssetMetadataProvider, ) +from fairseq2.assets.metadata_provider import PackageFileLister as PackageFileLister +from fairseq2.assets.metadata_provider import ( + WheelPackageFileLister as WheelPackageFileLister, +) from fairseq2.assets.metadata_provider import load_metadata_file as load_metadata_file from fairseq2.assets.store import AssetStore as AssetStore from fairseq2.assets.store import EnvironmentResolver as EnvironmentResolver from fairseq2.assets.store import StandardAssetStore as StandardAssetStore from fairseq2.assets.store import default_asset_store as default_asset_store -from fairseq2.assets.store import setup_asset_store as setup_asset_store +from fairseq2.assets.store import get_asset_dir as get_asset_dir +from fairseq2.assets.store import get_user_asset_dir as get_user_asset_dir diff --git a/src/fairseq2/assets/card.py b/src/fairseq2/assets/card.py index 434ced16f..388c8b8a2 100644 --- a/src/fairseq2/assets/card.py +++ b/src/fairseq2/assets/card.py @@ -15,7 +15,11 @@ from fairseq2.assets.error import AssetCardError, AssetCardFieldNotFoundError from fairseq2.error import InternalError -from fairseq2.utils.structured import StructureError, structure, unstructure +from fairseq2.utils.structured import ( + StructureError, + default_value_converter, + unstructure, +) @final @@ -75,7 +79,7 @@ def _get_field_value(self, leaf_card: AssetCard, path: list[str]) -> object: pathname = ".".join(path) raise AssetCardFieldNotFoundError( - f"The '{leaf_card.name}' asset card does not have a field named '{pathname}'." + leaf_card.name, f"The '{leaf_card.name}' asset card does not have a field named '{pathname}'." # fmt: skip ) try: @@ -92,7 +96,7 @@ def _get_field_value(self, leaf_card: AssetCard, path: list[str]) -> object: pathname = ".".join(path) raise AssetCardFieldNotFoundError( - f"The '{leaf_card.name}' asset card does not have a field named '{pathname}'." + leaf_card.name, f"The '{leaf_card.name}' asset card does not have a field named '{pathname}'." # fmt: skip ) return metadata @@ -118,7 +122,7 @@ def _set_field_value(self, path: list[str], value: object) -> None: pathname = ".".join(path) raise AssetCardError( - f"The '{self._name}' asset card cannot have a field named '{pathname}' due to path conflict at '{conflict_pathname}'." + self._name, f"The '{self._name}' asset card cannot have a field named '{pathname}' due to path conflict at '{conflict_pathname}'." # fmt: skip ) metadata = value_ @@ -211,12 +215,12 @@ def as_(self, type_: object, *, allow_empty: bool = False) -> Any: unstructured_value = self._card._get_field_value(self._card, self._path) try: - value = structure(unstructured_value, type_) + value = default_value_converter.structure(unstructured_value, type_) except StructureError as ex: pathname = ".".join(self._path) raise AssetCardError( - f"The value of the '{pathname}' field of the '{self._card.name}' asset card cannot be parsed as `{type_}`. See the nested exception for details." + self._card.name, f"The value of the '{pathname}' field of the '{self._card.name}' asset card cannot be parsed as `{type_}`. See the nested exception for details." # fmt: skip ) from ex if value is None: @@ -226,7 +230,7 @@ def as_(self, type_: object, *, allow_empty: bool = False) -> Any: pathname = ".".join(self._path) raise AssetCardError( - f"The value of the '{pathname}' field of the '{self._card.name}' asset card is empty." + self._card.name, f"The value of the '{pathname}' field of the '{self._card.name}' asset card is empty." # fmt: skip ) return value @@ -252,7 +256,7 @@ def as_one_of(self, valid_values: Set[str]) -> str: s = ", ".join(values) raise AssetCardError( - f"The value of the '{pathname}' field of the '{self._card.name}' asset card is expected to be one of the following values, but is '{value}' instead: {s}" + self._card.name, f"The value of the '{pathname}' field of the '{self._card.name}' asset card is expected to be one of the following values, but is '{value}' instead: {s}" # fmt: skip ) return value @@ -274,7 +278,7 @@ def as_uri(self) -> str: pathname = ".".join(self._path) raise AssetCardError( - f"The value of the '{pathname}' field of the '{self._card.name}' asset card is expected to be a URI or an absolute pathname, but is '{value}' instead." + self._card.name, f"The value of the '{pathname}' field of the '{self._card.name}' asset card is expected to be a URI or an absolute pathname, but is '{value}' instead." # fmt: skip ) from None def as_filename(self) -> str: @@ -285,7 +289,7 @@ def as_filename(self) -> str: pathname = ".".join(self._path) raise AssetCardError( - f"The value of the '{pathname}' field of the '{self._card.name}' asset card is expected to be a filename, but is '{value}' instead." + self._card.name, f"The value of the '{pathname}' field of the '{self._card.name}' asset card is expected to be a filename, but is '{value}' instead." # fmt: skip ) return value diff --git a/src/fairseq2/assets/error.py b/src/fairseq2/assets/error.py index 9e1a6ba0f..ecebbd96a 100644 --- a/src/fairseq2/assets/error.py +++ b/src/fairseq2/assets/error.py @@ -14,7 +14,12 @@ class AssetError(Exception): class AssetCardError(AssetError): - pass + name: str + + def __init__(self, name: str, message: str) -> None: + super().__init__(message) + + self.name = name class AssetCardNotFoundError(AssetCardError): diff --git a/src/fairseq2/assets/store.py b/src/fairseq2/assets/store.py index e0b4d3d79..225d58edb 100644 --- a/src/fairseq2/assets/store.py +++ b/src/fairseq2/assets/store.py @@ -18,14 +18,11 @@ from fairseq2.assets.metadata_provider import ( AssetMetadataNotFoundError, AssetMetadataProvider, - FileAssetMetadataProvider, PackageAssetMetadataProvider, WheelPackageFileLister, ) from fairseq2.error import ContractError -from fairseq2.extensions import run_extensions from fairseq2.utils.env import get_path_from_env -from fairseq2.utils.file import StandardFileSystem from fairseq2.utils.yaml import load_yaml AssetScope: TypeAlias = Literal["all", "global", "user"] @@ -118,7 +115,7 @@ def _do_retrieve_card( metadata = self._get_metadata(f"{name}@", scope) except AssetMetadataNotFoundError: raise AssetCardNotFoundError( - f"An asset card with name '{name}' is not found." + name, f"An asset card with name '{name}' is not found." ) from None # If we have environment-specific metadata, merge it with `metadata`. @@ -157,7 +154,7 @@ def contract_error( base_card = self._do_retrieve_card(base_name, envs, scope) except AssetCardNotFoundError: raise AssetCardError( - f"A transitive base asset card with name '{name}' is not found." + name, f"A transitive base asset card with name '{base_name}' is not found." # fmt: skip ) from None base_path = metadata.get("__base_path__") @@ -217,20 +214,6 @@ def clear_cache(self) -> None: for provider in self.user_metadata_providers: provider.clear_cache() - def add_file_metadata_provider(self, path: Path, user: bool = False) -> None: - """Add a new :class:`FileAssetMetadataProvider` pointing to ``path``. - - :param path: The directory under which asset metadata is stored. - :param user: If ``True``, adds the metadata provider to the user scope. - """ - file_system = StandardFileSystem() - - provider = FileAssetMetadataProvider(path, file_system, load_yaml) - - providers = self.user_metadata_providers if user else self.metadata_providers - - providers.append(provider) - def add_package_metadata_provider(self, package_name: str) -> None: """Add a new :class:`PackageAssetMetadataProvider` for ``package_name``. @@ -254,30 +237,17 @@ def __call__(self) -> str | None: default_asset_store = StandardAssetStore() -def setup_asset_store(store: StandardAssetStore) -> None: - store.add_package_metadata_provider("fairseq2.assets.cards") - - # /etc/fairseq2/assets - _add_etc_dir_metadata_provider(store) - - # ~/.config/fairseq2/assets - _add_home_config_dir_metadata_provider(store) - - # Extensions - run_extensions("setup_fairseq2_asset_store", store) - - -def _add_etc_dir_metadata_provider(store: StandardAssetStore) -> None: +def get_asset_dir() -> Path | None: asset_dir = get_path_from_env("FAIRSEQ2_ASSET_DIR") if asset_dir is None: asset_dir = Path("/etc/fairseq2/assets").resolve() if not asset_dir.exists(): - return + return None - store.add_file_metadata_provider(asset_dir) + return asset_dir -def _add_home_config_dir_metadata_provider(store: StandardAssetStore) -> None: +def get_user_asset_dir() -> Path | None: asset_dir = get_path_from_env("FAIRSEQ2_USER_ASSET_DIR") if asset_dir is None: asset_dir = get_path_from_env("XDG_CONFIG_HOME") @@ -286,6 +256,6 @@ def _add_home_config_dir_metadata_provider(store: StandardAssetStore) -> None: asset_dir = asset_dir.joinpath("fairseq2/assets").resolve() if not asset_dir.exists(): - return + return None - store.add_file_metadata_provider(asset_dir, user=True) + return asset_dir diff --git a/src/fairseq2/chatbots/__init__.py b/src/fairseq2/chatbots/__init__.py index 9a9b67fcc..b44133ed2 100644 --- a/src/fairseq2/chatbots/__init__.py +++ b/src/fairseq2/chatbots/__init__.py @@ -10,8 +10,6 @@ from fairseq2.chatbots.chatbot import Chatbot as Chatbot from fairseq2.chatbots.chatbot import ChatDialog as ChatDialog from fairseq2.chatbots.chatbot import ChatMessage as ChatMessage -from fairseq2.chatbots.register import register_chatbots as register_chatbots -from fairseq2.chatbots.registry import ChatbotFactory as ChatbotFactory -from fairseq2.chatbots.registry import ChatbotHandler as ChatbotHandler -from fairseq2.chatbots.registry import ChatbotRegistry as ChatbotRegistry -from fairseq2.chatbots.registry import StandardChatbotHandler as StandardChatbotHandler +from fairseq2.chatbots.handler import ChatbotHandler as ChatbotHandler +from fairseq2.chatbots.handler import ChatbotNotFoundError as ChatbotNotFoundError +from fairseq2.chatbots.static import create_chatbot as create_chatbot diff --git a/src/fairseq2/chatbots/handler.py b/src/fairseq2/chatbots/handler.py new file mode 100644 index 000000000..6d425de99 --- /dev/null +++ b/src/fairseq2/chatbots/handler.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from abc import ABC, abstractmethod + +from fairseq2.chatbots.chatbot import Chatbot +from fairseq2.data.text import TextTokenizer +from fairseq2.generation.generator import SequenceGenerator + + +class ChatbotHandler(ABC): + @abstractmethod + def create(self, generator: SequenceGenerator, tokenizer: TextTokenizer) -> Chatbot: + ... + + +class ChatbotNotFoundError(LookupError): + name: str + + def __init__(self, name: str) -> None: + super().__init__(f"'{name}' is not a known chatbot.") + + self.name = name diff --git a/src/fairseq2/chatbots/llama.py b/src/fairseq2/chatbots/llama.py index b2a142b53..7b8452612 100644 --- a/src/fairseq2/chatbots/llama.py +++ b/src/fairseq2/chatbots/llama.py @@ -13,6 +13,7 @@ from typing_extensions import override from fairseq2.chatbots.chatbot import AbstractChatbot, Chatbot, ChatDialog, ChatMessage +from fairseq2.chatbots.handler import ChatbotHandler from fairseq2.data.text import LLaMA3Tokenizer, TextTokenEncoder, TextTokenizer from fairseq2.generation import SequenceGenerator from fairseq2.nn.utils.module import infer_device @@ -206,10 +207,11 @@ def supports_system_prompt(self) -> bool: return True -def make_llama_chatbot( - generator: SequenceGenerator, tokenizer: TextTokenizer -) -> Chatbot: - if isinstance(tokenizer, LLaMA3Tokenizer): - return LLaMA3Chatbot(generator, tokenizer) +@final +class LLaMAChatbotHandler(ChatbotHandler): + @override + def create(self, generator: SequenceGenerator, tokenizer: TextTokenizer) -> Chatbot: + if isinstance(tokenizer, LLaMA3Tokenizer): + return LLaMA3Chatbot(generator, tokenizer) - return LLaMAChatbot(generator, tokenizer) + return LLaMAChatbot(generator, tokenizer) diff --git a/src/fairseq2/chatbots/mistral.py b/src/fairseq2/chatbots/mistral.py index 3e1273cc1..ba86f8d30 100644 --- a/src/fairseq2/chatbots/mistral.py +++ b/src/fairseq2/chatbots/mistral.py @@ -12,7 +12,8 @@ from torch import Tensor from typing_extensions import override -from fairseq2.chatbots.chatbot import AbstractChatbot, ChatDialog +from fairseq2.chatbots.chatbot import AbstractChatbot, Chatbot, ChatDialog +from fairseq2.chatbots.handler import ChatbotHandler from fairseq2.data.text import TextTokenEncoder, TextTokenizer from fairseq2.generation import SequenceGenerator from fairseq2.nn.utils.module import infer_device @@ -91,7 +92,8 @@ def supports_system_prompt(self) -> bool: return False -def make_mistral_chatbot( - generator: SequenceGenerator, tokenizer: TextTokenizer -) -> MistralChatbot: - return MistralChatbot(generator, tokenizer) +@final +class MistralChatbotHandler(ChatbotHandler): + @override + def create(self, generator: SequenceGenerator, tokenizer: TextTokenizer) -> Chatbot: + return MistralChatbot(generator, tokenizer) diff --git a/src/fairseq2/chatbots/register.py b/src/fairseq2/chatbots/register.py deleted file mode 100644 index c9ad8a6bc..000000000 --- a/src/fairseq2/chatbots/register.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from __future__ import annotations - -from fairseq2.chatbots.llama import make_llama_chatbot -from fairseq2.chatbots.mistral import make_mistral_chatbot -from fairseq2.chatbots.registry import ChatbotRegistry, StandardChatbotHandler -from fairseq2.extensions import run_extensions -from fairseq2.models.llama import LLAMA_FAMILY -from fairseq2.models.mistral import MISTRAL_FAMILY - - -def register_chatbots(registry: ChatbotRegistry) -> None: - # LLaMA - handler = StandardChatbotHandler(factory=make_llama_chatbot) - - registry.register(LLAMA_FAMILY, handler) - - # Mistral - handler = StandardChatbotHandler(factory=make_mistral_chatbot) - - registry.register(MISTRAL_FAMILY, handler) - - # Extensions - run_extensions("register_fairseq2_chatbots", registry) diff --git a/src/fairseq2/chatbots/registry.py b/src/fairseq2/chatbots/registry.py deleted file mode 100644 index f5a9e09ed..000000000 --- a/src/fairseq2/chatbots/registry.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import Protocol, TypeAlias, final - -from typing_extensions import override - -from fairseq2.chatbots.chatbot import Chatbot -from fairseq2.context import Registry -from fairseq2.data.text import TextTokenizer -from fairseq2.generation.generator import SequenceGenerator - - -class ChatbotHandler(ABC): - @abstractmethod - def make(self, generator: SequenceGenerator, tokenizer: TextTokenizer) -> Chatbot: - ... - - -ChatbotRegistry: TypeAlias = Registry[ChatbotHandler] - - -class ChatbotFactory(Protocol): - def __call__( - self, generator: SequenceGenerator, tokenizer: TextTokenizer - ) -> Chatbot: - ... - - -@final -class StandardChatbotHandler(ChatbotHandler): - def __init__(self, *, factory: ChatbotFactory) -> None: - self._factory = factory - - @override - def make(self, generator: SequenceGenerator, tokenizer: TextTokenizer) -> Chatbot: - return self._factory(generator, tokenizer) diff --git a/src/fairseq2/chatbots/static.py b/src/fairseq2/chatbots/static.py new file mode 100644 index 000000000..e9db21970 --- /dev/null +++ b/src/fairseq2/chatbots/static.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from fairseq2.chatbots.chatbot import Chatbot +from fairseq2.chatbots.handler import ChatbotHandler, ChatbotNotFoundError +from fairseq2.context import get_runtime_context +from fairseq2.data.text import TextTokenizer +from fairseq2.generation.generator import SequenceGenerator + + +def create_chatbot( + name: str, generator: SequenceGenerator, tokenizer: TextTokenizer +) -> Chatbot: + context = get_runtime_context() + + registry = context.get_registry(ChatbotHandler) + + try: + handler = registry.get(name) + except LookupError: + raise ChatbotNotFoundError(name) from None + + return handler.create(generator, tokenizer) diff --git a/src/fairseq2/data/text/__init__.py b/src/fairseq2/data/text/__init__.py index 0754d2893..815d24a16 100644 --- a/src/fairseq2/data/text/__init__.py +++ b/src/fairseq2/data/text/__init__.py @@ -14,6 +14,21 @@ from fairseq2.data.text.tokenizers.char_tokenizer import ( CHAR_TOKENIZER_FAMILY as CHAR_TOKENIZER_FAMILY, ) +from fairseq2.data.text.tokenizers.handler import ( + StandardTextTokenizerHandler as StandardTextTokenizerHandler, +) +from fairseq2.data.text.tokenizers.handler import ( + TextTokenizerHandler as TextTokenizerHandler, +) +from fairseq2.data.text.tokenizers.handler import ( + TextTokenizerLoader as TextTokenizerLoader, +) +from fairseq2.data.text.tokenizers.handler import ( + TextTokenizerNotFoundError as TextTokenizerNotFoundError, +) +from fairseq2.data.text.tokenizers.handler import ( + get_text_tokenizer_family as get_text_tokenizer_family, +) from fairseq2.data.text.tokenizers.llama import ( LLAMA_TOKENIZER_FAMILY as LLAMA_TOKENIZER_FAMILY, ) @@ -28,24 +43,6 @@ from fairseq2.data.text.tokenizers.ref import ( resolve_text_tokenizer_reference as resolve_text_tokenizer_reference, ) -from fairseq2.data.text.tokenizers.register import ( - register_text_tokenizers as register_text_tokenizers, -) -from fairseq2.data.text.tokenizers.registry import ( - StandardTextTokenizerHandler as StandardTextTokenizerHandler, -) -from fairseq2.data.text.tokenizers.registry import ( - TextTokenizerHandler as TextTokenizerHandler, -) -from fairseq2.data.text.tokenizers.registry import ( - TextTokenizerLoader as TextTokenizerLoader, -) -from fairseq2.data.text.tokenizers.registry import ( - TextTokenizerRegistry as TextTokenizerRegistry, -) -from fairseq2.data.text.tokenizers.registry import ( - get_text_tokenizer_family as get_text_tokenizer_family, -) from fairseq2.data.text.tokenizers.s2t_transformer import ( S2T_TRANSFORMER_TOKENIZER_FAMILY as S2T_TRANSFORMER_TOKENIZER_FAMILY, ) @@ -79,9 +76,6 @@ from fairseq2.data.text.tokenizers.sentencepiece import ( vocab_info_from_sentencepiece as vocab_info_from_sentencepiece, ) -from fairseq2.data.text.tokenizers.static import ( - default_text_tokenizer_registry as default_text_tokenizer_registry, -) from fairseq2.data.text.tokenizers.static import ( load_text_tokenizer as load_text_tokenizer, ) diff --git a/src/fairseq2/data/text/tokenizers/registry.py b/src/fairseq2/data/text/tokenizers/handler.py similarity index 65% rename from src/fairseq2/data/text/tokenizers/registry.py rename to src/fairseq2/data/text/tokenizers/handler.py index 427864e66..d864f238d 100644 --- a/src/fairseq2/data/text/tokenizers/registry.py +++ b/src/fairseq2/data/text/tokenizers/handler.py @@ -8,28 +8,27 @@ from abc import ABC, abstractmethod from pathlib import Path -from typing import Protocol, TypeAlias, final +from typing import Protocol, final from typing_extensions import override from fairseq2.assets import AssetCard, AssetDownloadManager -from fairseq2.context import Registry from fairseq2.data.text.tokenizers.tokenizer import TextTokenizer class TextTokenizerHandler(ABC): @abstractmethod - def load( - self, - card: AssetCard, - asset_download_manager: AssetDownloadManager, - *, - force: bool = False, - ) -> TextTokenizer: + def load(self, card: AssetCard, *, force: bool = False) -> TextTokenizer: ... -TextTokenizerRegistry: TypeAlias = Registry[TextTokenizerHandler] +class TextTokenizerNotFoundError(LookupError): + name: str + + def __init__(self, name: str) -> None: + super().__init__(f"'{name}' is not a known text tokenizer.") + + self.name = name class TextTokenizerLoader(Protocol): @@ -40,21 +39,22 @@ def __call__(self, path: Path, card: AssetCard) -> TextTokenizer: @final class StandardTextTokenizerHandler(TextTokenizerHandler): _loader: TextTokenizerLoader + _asset_download_manager: AssetDownloadManager - def __init__(self, *, loader: TextTokenizerLoader) -> None: + def __init__( + self, + *, + loader: TextTokenizerLoader, + asset_download_manager: AssetDownloadManager, + ) -> None: self._loader = loader + self._asset_download_manager = asset_download_manager @override - def load( - self, - card: AssetCard, - asset_download_manager: AssetDownloadManager, - *, - force: bool = False, - ) -> TextTokenizer: + def load(self, card: AssetCard, *, force: bool = False) -> TextTokenizer: tokenizer_uri = card.field("tokenizer").as_uri() - path = asset_download_manager.download_tokenizer( + path = self._asset_download_manager.download_tokenizer( tokenizer_uri, card.name, force=force ) diff --git a/src/fairseq2/data/text/tokenizers/llama.py b/src/fairseq2/data/text/tokenizers/llama.py index 71e8950bd..cb5330c99 100644 --- a/src/fairseq2/data/text/tokenizers/llama.py +++ b/src/fairseq2/data/text/tokenizers/llama.py @@ -121,12 +121,12 @@ def load_llama_tokenizer(path: Path, card: AssetCard) -> TextTokenizer: return LLaMA3Tokenizer(path, instruct=eos_idx == eot_idx) except ValueError as ex: raise AssetCardError( - f"The '{card.name}' asset card does not have a valid text tokenizer configuration. See the nested exception for details." + card.name, f"The '{card.name}' asset card does not have a valid text tokenizer configuration. See the nested exception for details." # fmt: skip ) from ex else: try: return BasicSentencePieceTokenizer(path) except ValueError as ex: raise AssetCardError( - f"The '{card.name}' asset card does not have a valid text tokenizer configuration. See the nested exception for details." + card.name, f"The '{card.name}' asset card does not have a valid text tokenizer configuration. See the nested exception for details." # fmt: skip ) from ex diff --git a/src/fairseq2/data/text/tokenizers/nllb.py b/src/fairseq2/data/text/tokenizers/nllb.py index 1e85f9f38..0d3f6abee 100644 --- a/src/fairseq2/data/text/tokenizers/nllb.py +++ b/src/fairseq2/data/text/tokenizers/nllb.py @@ -136,5 +136,5 @@ def load_nllb_tokenizer(path: Path, card: AssetCard) -> NllbTokenizer: return NllbTokenizer(path, langs, default_lang) except ValueError as ex: raise AssetCardError( - f"The '{card.name}' asset card does not have a valid text tokenizer configuration. See the nested exception for details." + card.name, f"The '{card.name}' asset card does not have a valid text tokenizer configuration. See the nested exception for details." # fmt: skip ) from ex diff --git a/src/fairseq2/data/text/tokenizers/register.py b/src/fairseq2/data/text/tokenizers/register.py deleted file mode 100644 index c588acd41..000000000 --- a/src/fairseq2/data/text/tokenizers/register.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from __future__ import annotations - -from fairseq2.data.text.tokenizers.char_tokenizer import ( - CHAR_TOKENIZER_FAMILY, - load_char_tokenizer, -) -from fairseq2.data.text.tokenizers.llama import ( - LLAMA_TOKENIZER_FAMILY, - load_llama_tokenizer, -) -from fairseq2.data.text.tokenizers.mistral import ( - MISTRAL_TOKENIZER_FAMILY, - load_mistral_tokenizer, -) -from fairseq2.data.text.tokenizers.nllb import ( - NLLB_TOKENIZER_FAMILY, - load_nllb_tokenizer, -) -from fairseq2.data.text.tokenizers.registry import ( - StandardTextTokenizerHandler, - TextTokenizerRegistry, -) -from fairseq2.data.text.tokenizers.s2t_transformer import ( - S2T_TRANSFORMER_TOKENIZER_FAMILY, - load_s2t_transformer_tokenizer, -) -from fairseq2.extensions import run_extensions - - -def register_text_tokenizers(registry: TextTokenizerRegistry) -> None: - # Char Tokenizer - handler = StandardTextTokenizerHandler(loader=load_char_tokenizer) - - registry.register(CHAR_TOKENIZER_FAMILY, handler) - - # LLaMA - handler = StandardTextTokenizerHandler(loader=load_llama_tokenizer) - - registry.register(LLAMA_TOKENIZER_FAMILY, handler) - - # Mistral - handler = StandardTextTokenizerHandler(loader=load_mistral_tokenizer) - - registry.register(MISTRAL_TOKENIZER_FAMILY, handler) - - # NLLB - handler = StandardTextTokenizerHandler(loader=load_nllb_tokenizer) - - registry.register(NLLB_TOKENIZER_FAMILY, handler) - - # S2T Transformer - handler = StandardTextTokenizerHandler(loader=load_s2t_transformer_tokenizer) - - registry.register(S2T_TRANSFORMER_TOKENIZER_FAMILY, handler) - - # Extensions - run_extensions("register_fairseq2_text_tokenizers", registry) diff --git a/src/fairseq2/data/text/tokenizers/s2t_transformer.py b/src/fairseq2/data/text/tokenizers/s2t_transformer.py index 0c8b17835..aaf9165f8 100644 --- a/src/fairseq2/data/text/tokenizers/s2t_transformer.py +++ b/src/fairseq2/data/text/tokenizers/s2t_transformer.py @@ -126,5 +126,5 @@ def load_s2t_transformer_tokenizer( ) except ValueError as ex: raise AssetCardError( - f"The '{card.name}' asset card does not have a valid text tokenizer configuration. See the nested exception for details." + card.name, f"The '{card.name}' asset card does not have a valid text tokenizer configuration. See the nested exception for details." # fmt: skip ) from ex diff --git a/src/fairseq2/data/text/tokenizers/static.py b/src/fairseq2/data/text/tokenizers/static.py index cd05ba3b2..d157993fe 100644 --- a/src/fairseq2/data/text/tokenizers/static.py +++ b/src/fairseq2/data/text/tokenizers/static.py @@ -6,33 +6,36 @@ from __future__ import annotations -from fairseq2.assets import ( - AssetCard, - default_asset_download_manager, - default_asset_store, -) -from fairseq2.data.text.tokenizers.ref import resolve_text_tokenizer_reference -from fairseq2.data.text.tokenizers.registry import ( - TextTokenizerRegistry, +from fairseq2.assets import AssetCard +from fairseq2.context import get_runtime_context +from fairseq2.data.text.tokenizers.handler import ( + TextTokenizerHandler, + TextTokenizerNotFoundError, get_text_tokenizer_family, ) +from fairseq2.data.text.tokenizers.ref import resolve_text_tokenizer_reference from fairseq2.data.text.tokenizers.tokenizer import TextTokenizer -default_text_tokenizer_registry = TextTokenizerRegistry() - def load_text_tokenizer( name_or_card: str | AssetCard, *, force: bool = False ) -> TextTokenizer: + context = get_runtime_context() + if isinstance(name_or_card, AssetCard): card = name_or_card else: - card = default_asset_store.retrieve_card(name_or_card) + card = context.asset_store.retrieve_card(name_or_card) - card = resolve_text_tokenizer_reference(default_asset_store, card) + card = resolve_text_tokenizer_reference(context.asset_store, card) family = get_text_tokenizer_family(card) - handler = default_text_tokenizer_registry.get(family) + registry = context.get_registry(TextTokenizerHandler) + + try: + handler = registry.get(family) + except LookupError: + raise TextTokenizerNotFoundError(card.name) from None - return handler.load(card, default_asset_download_manager, force=force) + return handler.load(card, force=force) diff --git a/src/fairseq2/datasets/loader.py b/src/fairseq2/datasets/loader.py index 458cf1497..378389314 100644 --- a/src/fairseq2/datasets/loader.py +++ b/src/fairseq2/datasets/loader.py @@ -16,7 +16,7 @@ AssetDownloadManager, AssetError, AssetStore, - default_asset_download_manager, + InProcAssetDownloadManager, default_asset_store, ) @@ -66,7 +66,7 @@ def __init__( be used. """ self._asset_store = asset_store or default_asset_store - self._download_manager = download_manager or default_asset_download_manager + self._download_manager = download_manager or InProcAssetDownloadManager() @final def __call__( @@ -89,7 +89,7 @@ def __call__( ) except ValueError as ex: raise AssetCardError( - f"The value of the field 'data' of the asset card '{card.name}' must be a URI. See nested exception for details." + card.name, f"The value of the field 'data' of the asset card '{card.name}' must be a URI. See nested exception for details." # fmt: skip ) from ex try: diff --git a/src/fairseq2/extensions.py b/src/fairseq2/extensions.py index 759f5301e..9b1683d3b 100644 --- a/src/fairseq2/extensions.py +++ b/src/fairseq2/extensions.py @@ -14,41 +14,31 @@ from fairseq2.logging import log -def run_extensions(name: str, *args: Any, **kwargs: Any) -> None: +def run_extensions(extension_name: str, *args: Any, **kwargs: Any) -> None: should_trace = "FAIRSEQ2_EXTENSION_TRACE" in os.environ - for entry_point in entry_points(group="fairseq2.extension"): + for entry_point in entry_points(group=extension_name): try: - extension_module = entry_point.load() - except Exception as ex: + extension = entry_point.load() + + extension(*args, **kwargs) + except TypeError: if should_trace: raise ExtensionError( - entry_point.value, f"The `{entry_point.value}` extension module has failed to load. See the nested exception for details." # fmt: skip - ) from ex - - log.warning("The `{}` extension module has failed to load. Set `FAIRSEQ2_EXTENSION_TRACE` environment variable to print the stack trace.", entry_point.value) # fmt: skip - - continue + entry_point.value, f"The '{entry_point.value}' entry point is not a valid extension function." # fmt: skip + ) from None - try: - extension = getattr(extension_module, name) - except AttributeError: - continue - - try: - extension(*args, **kwargs) + log.warning("The '{}' entry point is not a valid extension function. Set `FAIRSEQ2_EXTENSION_TRACE` environment variable to print the stack trace.", entry_point.value) # fmt: skip except Exception as ex: if should_trace: raise ExtensionError( - entry_point.value, f"The `{entry_point.value}.{name}` extension function has failed. See the nested exception for details." # fmt: skip + entry_point.value, f"The '{entry_point.value}' extension function has failed. See the nested exception for details." # fmt: skip ) from ex - log.warning("The `{}.{}` extension function has failed. Set `FAIRSEQ2_EXTENSION_TRACE` environment variable to print the stack trace.", entry_point.value, name) # fmt: skip - - continue + log.warning("The '{}' extension function has failed. Set `FAIRSEQ2_EXTENSION_TRACE` environment variable to print the stack trace.", entry_point.value) # fmt: skip if should_trace: - log.info("The `{}.{}` extension function run successfully.", entry_point.value, name) # fmt: skip + log.info("The `{}` extension function run successfully.", entry_point.value) # fmt: skip class ExtensionError(Exception): diff --git a/src/fairseq2/models/config_loader.py b/src/fairseq2/models/config_loader.py index b314a28c6..24d0ee05f 100644 --- a/src/fairseq2/models/config_loader.py +++ b/src/fairseq2/models/config_loader.py @@ -92,7 +92,7 @@ def __call__( model_family = get_model_family(card) if model_family != self._family: raise AssetCardError( - f"The value of the field 'model_family' of the asset card '{card.name}' must be '{self._family}', but is '{model_family}' instead." + card.name, f"The value of the field 'model_family' of the asset card '{card.name}' must be '{self._family}', but is '{model_family}' instead." # fmt: skip ) config_kls = self._config_kls @@ -189,5 +189,5 @@ def get_model_family(card: AssetCard) -> str: return cast(str, card.field("model_type").as_(str)) except AssetCardFieldNotFoundError: raise AssetCardFieldNotFoundError( - f"The asset card '{card.name}' must have a field named 'model_family." + card.name, f"The asset card '{card.name}' must have a field named 'model_family." # fmt: skip ) from None diff --git a/src/fairseq2/models/loader.py b/src/fairseq2/models/loader.py index f4938cdbe..11b172781 100644 --- a/src/fairseq2/models/loader.py +++ b/src/fairseq2/models/loader.py @@ -18,7 +18,7 @@ AssetDownloadManager, AssetError, AssetStore, - default_asset_download_manager, + InProcAssetDownloadManager, default_asset_store, ) from fairseq2.gang import Gang @@ -82,7 +82,6 @@ def __call__( dtype: DataType | None = None, force: bool = False, progress: bool = True, - strict_state_dict: bool = True, ) -> ModelT_co: """ :param model_name_or_card: @@ -99,9 +98,6 @@ def __call__( cache. :param progress: If ``True``, displays a progress bar to stderr. - :param strict_state_dict: - If ``True``, checkpoint' parameters and layers must be identical to - the model state dict) :returns: A model loaded from the checkpoint of ``model_name_or_card``. @@ -186,7 +182,7 @@ def __init__( that do not support PyTorch's ``reset_parameters()`` convention. """ self._asset_store = asset_store or default_asset_store - self._download_manager = download_manager or default_asset_download_manager + self._download_manager = download_manager or InProcAssetDownloadManager() self._tensor_loader = tensor_loader or load_tensors self._config_loader = config_loader self._factory = factory @@ -205,7 +201,6 @@ def __call__( dtype: DataType | None = None, force: bool = False, progress: bool = True, - strict_state_dict: bool = True, ) -> ModelT: if isinstance(model_name_or_card, AssetCard): card = model_name_or_card @@ -229,7 +224,7 @@ def __call__( num_shards = card.field("num_shards").get_as_(int, default=1) if num_shards < 1: raise AssetCardError( - f"The value of the field 'num_shards' of the asset card '{card.name}' must be greater than or equal to 1, but is {num_shards} instead." + card.name, f"The value of the field 'num_shards' of the asset card '{card.name}' must be greater than or equal to 1, but is {num_shards} instead." # fmt: skip ) if num_shards > 1: @@ -290,7 +285,7 @@ def __call__( ) except ValueError as ex: raise AssetCardError( - f"The value of the field 'checkpoint' of the asset card '{card.name}' must be URI. See nested exception for details." + card.name, f"The value of the field 'checkpoint' of the asset card '{card.name}' must be URI. See nested exception for details." # fmt: skip ) from ex try: @@ -360,7 +355,7 @@ def __call__( consume_prefix_in_state_dict_if_present(state_dict, prefix="module.") try: - load_state_dict(model, state_dict, strict=strict_state_dict) + load_state_dict(model, state_dict) except (KeyError, ValueError) as ex: raise AssetError( f"{card.name} cannot be loaded. See nested exception for details." @@ -401,7 +396,6 @@ def __call__( dtype: DataType | None = None, force: bool = False, progress: bool = True, - strict_state_dict: bool = True, ) -> ModelT: if isinstance(model_name_or_card, AssetCard): card = model_name_or_card @@ -425,7 +419,6 @@ def __call__( dtype=dtype, force=force, progress=progress, - strict_state_dict=strict_state_dict, ) def register(self, family: str, loader: ModelLoader[ModelT]) -> None: diff --git a/src/fairseq2/recipes/lm/chatbot.py b/src/fairseq2/recipes/lm/chatbot.py index d62f24a71..4aa97e386 100644 --- a/src/fairseq2/recipes/lm/chatbot.py +++ b/src/fairseq2/recipes/lm/chatbot.py @@ -16,7 +16,7 @@ from torch import Tensor from typing_extensions import override -from fairseq2.chatbots import Chatbot, ChatbotRegistry, ChatMessage, register_chatbots +from fairseq2.chatbots import Chatbot, ChatMessage, create_chatbot from fairseq2.data.text import TextTokenDecoder, TextTokenizer, load_text_tokenizer from fairseq2.error import InternalError from fairseq2.gang import Gang, is_torchrun @@ -166,19 +166,13 @@ def run(self, args: Namespace) -> int: sys.exit(1) - registry = ChatbotRegistry() - - register_chatbots(registry) - try: - handler = registry.get(model.family) + chatbot = create_chatbot(model.family, generator, tokenizer) except LookupError: log.exception("The chatbot cannot be created.") sys.exit(1) - chatbot = handler.make(generator, tokenizer) - rng_bag = RngBag.from_device_defaults(CPU, root_gang.device) # Set the seed for sequence generation. diff --git a/src/fairseq2/recipes/utils/asset.py b/src/fairseq2/recipes/utils/asset.py index 6bc66cd1b..c9ba43a28 100644 --- a/src/fairseq2/recipes/utils/asset.py +++ b/src/fairseq2/recipes/utils/asset.py @@ -34,7 +34,7 @@ def retrieve_asset_card(name_or_card: AssetReference) -> AssetCard: if isinstance(name_or_card, Path): if name_or_card.is_dir(): raise AssetNotFoundError( - f"An asset metadata file cannot be found at {name_or_card}." # fmt: skip + name_or_card.name, f"An asset metadata file cannot be found at {name_or_card}." # fmt: skip ) return _card_from_file(name_or_card) @@ -55,7 +55,7 @@ def retrieve_asset_card(name_or_card: AssetReference) -> AssetCard: if (file.suffix == ".yaml" or file.suffix == ".yml") and file.exists(): return _card_from_file(file) - raise AssetNotFoundError(f"An asset with the name '{name}' cannot be found.") + raise AssetNotFoundError(name, f"An asset with the name '{name}' cannot be found.") def _card_from_file(file: Path) -> AssetCard: diff --git a/src/fairseq2/setup.py b/src/fairseq2/setup.py deleted file mode 100644 index 2f7a1d813..000000000 --- a/src/fairseq2/setup.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from __future__ import annotations - -import os - -from importlib_metadata import entry_points - -from fairseq2.assets import default_asset_store, setup_asset_store -from fairseq2.data.text import default_text_tokenizer_registry, register_text_tokenizers -from fairseq2.extensions import ExtensionError -from fairseq2.logging import log - -_setup_complete = False - - -def setup_fairseq2() -> None: - """ - Sets up fairseq2. - - As part of the initialization, this function also registers external - objects with via setuptools' `entry-point`__ mechanism. See - :doc:`/basics/runtime_extensions` for more information. - - .. important:: - - This function must be called before using any of the fairseq2 APIs. - - .. __: https://setuptools.pypa.io/en/latest/userguide/entry_point.html - """ - global _setup_complete - - if _setup_complete: - return - - # Mark as complete early on to avoid recursive calls. - _setup_complete = True - - setup_asset_store(default_asset_store) - - register_text_tokenizers(default_text_tokenizer_registry) - - _setup_legacy_extensions() - - -def _setup_legacy_extensions() -> None: - should_trace = "FAIRSEQ2_EXTENSION_TRACE" in os.environ - - for entry_point in entry_points(group="fairseq2"): - try: - extension = entry_point.load() - - extension() - except TypeError: - if should_trace: - raise ExtensionError( - entry_point.value, f"The '{entry_point.value}' entry point is not a valid extension function." # fmt: skip - ) from None - - log.warning("The '{}' entry point is not a valid extension function. Set `FAIRSEQ2_EXTENSION_TRACE` environment variable to print the stack trace.", entry_point.value) # fmt: skip - except Exception as ex: - if should_trace: - raise ExtensionError( - entry_point.value, f"The '{entry_point.value}' extension function has failed. See the nested exception for details." # fmt: skip - ) from ex - - log.warning("The '{}' extension function has failed. Set `FAIRSEQ2_EXTENSION_TRACE` environment variable to print the stack trace.", entry_point.value) # fmt: skip - - if should_trace: - log.info("The `{}` extension function run successfully.", entry_point.value) # fmt: skip diff --git a/src/fairseq2/setup/__init__.py b/src/fairseq2/setup/__init__.py new file mode 100644 index 000000000..28fd44ced --- /dev/null +++ b/src/fairseq2/setup/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from fairseq2.setup.assets import ( + register_package_metadata_provider as register_package_metadata_provider, +) +from fairseq2.setup.root import setup_fairseq2 as setup_fairseq2 +from fairseq2.setup.root import setup_runtime_context as setup_runtime_context +from fairseq2.setup.text_tokenizers import ( + register_text_tokenizer as register_text_tokenizer, +) diff --git a/src/fairseq2/setup/assets.py b/src/fairseq2/setup/assets.py new file mode 100644 index 000000000..86cf9dc1b --- /dev/null +++ b/src/fairseq2/setup/assets.py @@ -0,0 +1,68 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from fairseq2.assets import ( + FileAssetMetadataProvider, + PackageAssetMetadataProvider, + StandardAssetStore, + WheelPackageFileLister, + get_asset_dir, + get_user_asset_dir, +) +from fairseq2.context import RuntimeContext +from fairseq2.utils.file import StandardFileSystem +from fairseq2.utils.yaml import load_yaml + + +def _register_assets(context: RuntimeContext) -> None: + asset_store = context.asset_store + + # Package Metadata + register_package_metadata_provider(asset_store, "fairseq2.assets.cards") + + # /etc/fairseq2/assets + _register_asset_dir(asset_store) + + # ~/.config/fairseq2/assets + _register_user_asset_dir(asset_store) + + +def _register_asset_dir(asset_store: StandardAssetStore) -> None: + config_dir = get_asset_dir() + if config_dir is None: + return + + file_system = StandardFileSystem() + + provider = FileAssetMetadataProvider(config_dir, file_system, load_yaml) + + asset_store.metadata_providers.append(provider) + + +def _register_user_asset_dir(asset_store: StandardAssetStore) -> None: + config_dir = get_user_asset_dir() + if config_dir is None: + return + + file_system = StandardFileSystem() + + provider = FileAssetMetadataProvider(config_dir, file_system, load_yaml) + + asset_store.user_metadata_providers.append(provider) + + +def register_package_metadata_provider( + asset_store: StandardAssetStore, package_name: str +) -> None: + package_file_lister = WheelPackageFileLister() + + provider = PackageAssetMetadataProvider( + package_name, package_file_lister, load_yaml + ) + + asset_store.metadata_providers.append(provider) diff --git a/src/fairseq2/setup/chatbots.py b/src/fairseq2/setup/chatbots.py new file mode 100644 index 000000000..fb7f64efb --- /dev/null +++ b/src/fairseq2/setup/chatbots.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from fairseq2.chatbots import ChatbotHandler +from fairseq2.chatbots.llama import LLaMAChatbotHandler +from fairseq2.chatbots.mistral import MistralChatbotHandler +from fairseq2.context import RuntimeContext +from fairseq2.models.llama import LLAMA_FAMILY +from fairseq2.models.mistral import MISTRAL_FAMILY + + +def _register_chatbots(context: RuntimeContext) -> None: + registry = context.get_registry(ChatbotHandler) + + registry.register(LLAMA_FAMILY, LLaMAChatbotHandler()) + registry.register(MISTRAL_FAMILY, MistralChatbotHandler()) diff --git a/src/fairseq2/setup/root.py b/src/fairseq2/setup/root.py new file mode 100644 index 000000000..248eda7b1 --- /dev/null +++ b/src/fairseq2/setup/root.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from fairseq2.assets import InProcAssetDownloadManager, default_asset_store +from fairseq2.context import RuntimeContext, set_runtime_context +from fairseq2.extensions import run_extensions +from fairseq2.setup.assets import _register_assets +from fairseq2.setup.chatbots import _register_chatbots +from fairseq2.setup.text_tokenizers import _register_text_tokenizers + +_setup_called: bool = False + + +def setup_fairseq2() -> None: + """ + Sets up fairseq2. + + As part of the initialization, this function also registers extensions + with via setuptools' `entry-point`__ mechanism. See + :doc:`/basics/runtime_extensions` for more information. + + .. important:: + + This function must be called before using any of the fairseq2 APIs. + + .. __: https://setuptools.pypa.io/en/latest/userguide/entry_point.html + """ + global _setup_called + + if _setup_called: + return + + _setup_called = True # Mark as called to avoid recursive calls. + + context = setup_runtime_context() + + set_runtime_context(context) + + run_extensions("fairseq2") # compat + + +def setup_runtime_context() -> RuntimeContext: + asset_download_manager = InProcAssetDownloadManager() + + context = RuntimeContext(default_asset_store, asset_download_manager) + + _register_assets(context) + _register_chatbots(context) + _register_text_tokenizers(context) + + run_extensions("fairseq2.extension", context) + + return context diff --git a/src/fairseq2/setup/text_tokenizers.py b/src/fairseq2/setup/text_tokenizers.py new file mode 100644 index 000000000..969b63b2b --- /dev/null +++ b/src/fairseq2/setup/text_tokenizers.py @@ -0,0 +1,78 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from fairseq2.context import RuntimeContext +from fairseq2.data.text import ( + StandardTextTokenizerHandler, + TextTokenizerHandler, + TextTokenizerLoader, +) +from fairseq2.data.text.tokenizers.char_tokenizer import ( + CHAR_TOKENIZER_FAMILY, + load_char_tokenizer, +) +from fairseq2.data.text.tokenizers.llama import ( + LLAMA_TOKENIZER_FAMILY, + load_llama_tokenizer, +) +from fairseq2.data.text.tokenizers.mistral import ( + MISTRAL_TOKENIZER_FAMILY, + load_mistral_tokenizer, +) +from fairseq2.data.text.tokenizers.nllb import ( + NLLB_TOKENIZER_FAMILY, + load_nllb_tokenizer, +) +from fairseq2.data.text.tokenizers.s2t_transformer import ( + S2T_TRANSFORMER_TOKENIZER_FAMILY, + load_s2t_transformer_tokenizer, +) + + +def _register_text_tokenizers(context: RuntimeContext) -> None: + register_text_tokenizer( + context, + CHAR_TOKENIZER_FAMILY, + loader=load_char_tokenizer, + ) + + register_text_tokenizer( + context, + LLAMA_TOKENIZER_FAMILY, + loader=load_llama_tokenizer, + ) + + register_text_tokenizer( + context, + MISTRAL_TOKENIZER_FAMILY, + loader=load_mistral_tokenizer, + ) + + register_text_tokenizer( + context, + NLLB_TOKENIZER_FAMILY, + loader=load_nllb_tokenizer, + ) + + register_text_tokenizer( + context, + S2T_TRANSFORMER_TOKENIZER_FAMILY, + loader=load_s2t_transformer_tokenizer, + ) + + +def register_text_tokenizer( + context: RuntimeContext, family: str, *, loader: TextTokenizerLoader +) -> None: + handler = StandardTextTokenizerHandler( + loader=loader, asset_download_manager=context.asset_download_manager + ) + + registry = context.get_registry(TextTokenizerHandler) + + registry.register(family, handler) diff --git a/tests/integration/models/test_llama.py b/tests/integration/models/test_llama.py index adf5e1f7a..633d44589 100644 --- a/tests/integration/models/test_llama.py +++ b/tests/integration/models/test_llama.py @@ -9,7 +9,7 @@ import pytest -from fairseq2.assets import default_asset_download_manager, default_asset_store +from fairseq2.context import get_runtime_context from fairseq2.models.llama import create_llama_model, llama_archs from fairseq2.models.llama.integ import convert_to_reference_checkpoint from fairseq2.models.llama.loader import convert_llama_checkpoint @@ -22,11 +22,13 @@ "FAIR_ENV_CLUSTER" not in os.environ, reason="checkpoints only on faircluster" ) def test_convert_to_reference_checkpoint() -> None: + context = get_runtime_context() + model_config = llama_archs.get("llama2_7b") - card = default_asset_store.retrieve_card("llama2_7b") + card = context.asset_store.retrieve_card("llama2_7b") - path = default_asset_download_manager.download_checkpoint( + path = context.asset_download_manager.download_checkpoint( card.field("checkpoint").as_uri(), model_name="llama2_7b", progress=False )