Skip to content

Commit

Permalink
Revise library setup
Browse files Browse the repository at this point in the history
  • Loading branch information
cbalioglu committed Jan 3, 2025
1 parent dc96fce commit cfb71dc
Show file tree
Hide file tree
Showing 31 changed files with 435 additions and 388 deletions.
10 changes: 6 additions & 4 deletions src/fairseq2/assets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@
from fairseq2.assets.download_manager import (
InProcAssetDownloadManager as InProcAssetDownloadManager,
)
from fairseq2.assets.download_manager import (
default_asset_download_manager as default_asset_download_manager,
)
from fairseq2.assets.error import AssetCardError as AssetCardError
from fairseq2.assets.error import (
AssetCardFieldNotFoundError as AssetCardFieldNotFoundError,
Expand All @@ -43,9 +40,14 @@
from fairseq2.assets.metadata_provider import (
PackageAssetMetadataProvider as PackageAssetMetadataProvider,
)
from fairseq2.assets.metadata_provider import PackageFileLister as PackageFileLister
from fairseq2.assets.metadata_provider import (
WheelPackageFileLister as WheelPackageFileLister,
)
from fairseq2.assets.metadata_provider import load_metadata_file as load_metadata_file
from fairseq2.assets.store import AssetStore as AssetStore
from fairseq2.assets.store import EnvironmentResolver as EnvironmentResolver
from fairseq2.assets.store import StandardAssetStore as StandardAssetStore
from fairseq2.assets.store import default_asset_store as default_asset_store
from fairseq2.assets.store import setup_asset_store as setup_asset_store
from fairseq2.assets.store import get_asset_dir as get_asset_dir
from fairseq2.assets.store import get_user_asset_dir as get_user_asset_dir
24 changes: 14 additions & 10 deletions src/fairseq2/assets/card.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@

from fairseq2.assets.error import AssetCardError, AssetCardFieldNotFoundError
from fairseq2.error import InternalError
from fairseq2.utils.structured import StructureError, structure, unstructure
from fairseq2.utils.structured import (
StructureError,
default_value_converter,
unstructure,
)


@final
Expand Down Expand Up @@ -75,7 +79,7 @@ def _get_field_value(self, leaf_card: AssetCard, path: list[str]) -> object:
pathname = ".".join(path)

raise AssetCardFieldNotFoundError(
f"The '{leaf_card.name}' asset card does not have a field named '{pathname}'."
leaf_card.name, f"The '{leaf_card.name}' asset card does not have a field named '{pathname}'." # fmt: skip
)

try:
Expand All @@ -92,7 +96,7 @@ def _get_field_value(self, leaf_card: AssetCard, path: list[str]) -> object:
pathname = ".".join(path)

raise AssetCardFieldNotFoundError(
f"The '{leaf_card.name}' asset card does not have a field named '{pathname}'."
leaf_card.name, f"The '{leaf_card.name}' asset card does not have a field named '{pathname}'." # fmt: skip
)

return metadata
Expand All @@ -118,7 +122,7 @@ def _set_field_value(self, path: list[str], value: object) -> None:
pathname = ".".join(path)

raise AssetCardError(
f"The '{self._name}' asset card cannot have a field named '{pathname}' due to path conflict at '{conflict_pathname}'."
self._name, f"The '{self._name}' asset card cannot have a field named '{pathname}' due to path conflict at '{conflict_pathname}'." # fmt: skip
)

metadata = value_
Expand Down Expand Up @@ -211,12 +215,12 @@ def as_(self, type_: object, *, allow_empty: bool = False) -> Any:
unstructured_value = self._card._get_field_value(self._card, self._path)

try:
value = structure(unstructured_value, type_)
value = default_value_converter.structure(unstructured_value, type_)
except StructureError as ex:
pathname = ".".join(self._path)

raise AssetCardError(
f"The value of the '{pathname}' field of the '{self._card.name}' asset card cannot be parsed as `{type_}`. See the nested exception for details."
self._card.name, f"The value of the '{pathname}' field of the '{self._card.name}' asset card cannot be parsed as `{type_}`. See the nested exception for details." # fmt: skip
) from ex

if value is None:
Expand All @@ -226,7 +230,7 @@ def as_(self, type_: object, *, allow_empty: bool = False) -> Any:
pathname = ".".join(self._path)

raise AssetCardError(
f"The value of the '{pathname}' field of the '{self._card.name}' asset card is empty."
self._card.name, f"The value of the '{pathname}' field of the '{self._card.name}' asset card is empty." # fmt: skip
)

return value
Expand All @@ -252,7 +256,7 @@ def as_one_of(self, valid_values: Set[str]) -> str:
s = ", ".join(values)

raise AssetCardError(
f"The value of the '{pathname}' field of the '{self._card.name}' asset card is expected to be one of the following values, but is '{value}' instead: {s}"
self._card.name, f"The value of the '{pathname}' field of the '{self._card.name}' asset card is expected to be one of the following values, but is '{value}' instead: {s}" # fmt: skip
)

return value
Expand All @@ -274,7 +278,7 @@ def as_uri(self) -> str:
pathname = ".".join(self._path)

raise AssetCardError(
f"The value of the '{pathname}' field of the '{self._card.name}' asset card is expected to be a URI or an absolute pathname, but is '{value}' instead."
self._card.name, f"The value of the '{pathname}' field of the '{self._card.name}' asset card is expected to be a URI or an absolute pathname, but is '{value}' instead." # fmt: skip
) from None

def as_filename(self) -> str:
Expand All @@ -285,7 +289,7 @@ def as_filename(self) -> str:
pathname = ".".join(self._path)

raise AssetCardError(
f"The value of the '{pathname}' field of the '{self._card.name}' asset card is expected to be a filename, but is '{value}' instead."
self._card.name, f"The value of the '{pathname}' field of the '{self._card.name}' asset card is expected to be a filename, but is '{value}' instead." # fmt: skip
)

return value
Expand Down
7 changes: 6 additions & 1 deletion src/fairseq2/assets/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@ class AssetError(Exception):


class AssetCardError(AssetError):
pass
name: str

def __init__(self, name: str, message: str) -> None:
super().__init__(message)

self.name = name


class AssetCardNotFoundError(AssetCardError):
Expand Down
46 changes: 8 additions & 38 deletions src/fairseq2/assets/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,11 @@
from fairseq2.assets.metadata_provider import (
AssetMetadataNotFoundError,
AssetMetadataProvider,
FileAssetMetadataProvider,
PackageAssetMetadataProvider,
WheelPackageFileLister,
)
from fairseq2.error import ContractError
from fairseq2.extensions import run_extensions
from fairseq2.utils.env import get_path_from_env
from fairseq2.utils.file import StandardFileSystem
from fairseq2.utils.yaml import load_yaml

AssetScope: TypeAlias = Literal["all", "global", "user"]
Expand Down Expand Up @@ -118,7 +115,7 @@ def _do_retrieve_card(
metadata = self._get_metadata(f"{name}@", scope)
except AssetMetadataNotFoundError:
raise AssetCardNotFoundError(
f"An asset card with name '{name}' is not found."
name, f"An asset card with name '{name}' is not found."
) from None

# If we have environment-specific metadata, merge it with `metadata`.
Expand Down Expand Up @@ -157,7 +154,7 @@ def contract_error(
base_card = self._do_retrieve_card(base_name, envs, scope)
except AssetCardNotFoundError:
raise AssetCardError(
f"A transitive base asset card with name '{name}' is not found."
name, f"A transitive base asset card with name '{base_name}' is not found." # fmt: skip
) from None

base_path = metadata.get("__base_path__")
Expand Down Expand Up @@ -217,20 +214,6 @@ def clear_cache(self) -> None:
for provider in self.user_metadata_providers:
provider.clear_cache()

def add_file_metadata_provider(self, path: Path, user: bool = False) -> None:
"""Add a new :class:`FileAssetMetadataProvider` pointing to ``path``.
:param path: The directory under which asset metadata is stored.
:param user: If ``True``, adds the metadata provider to the user scope.
"""
file_system = StandardFileSystem()

provider = FileAssetMetadataProvider(path, file_system, load_yaml)

providers = self.user_metadata_providers if user else self.metadata_providers

providers.append(provider)

def add_package_metadata_provider(self, package_name: str) -> None:
"""Add a new :class:`PackageAssetMetadataProvider` for ``package_name``.
Expand All @@ -254,30 +237,17 @@ def __call__(self) -> str | None:
default_asset_store = StandardAssetStore()


def setup_asset_store(store: StandardAssetStore) -> None:
store.add_package_metadata_provider("fairseq2.assets.cards")

# /etc/fairseq2/assets
_add_etc_dir_metadata_provider(store)

# ~/.config/fairseq2/assets
_add_home_config_dir_metadata_provider(store)

# Extensions
run_extensions("setup_fairseq2_asset_store", store)


def _add_etc_dir_metadata_provider(store: StandardAssetStore) -> None:
def get_asset_dir() -> Path | None:
asset_dir = get_path_from_env("FAIRSEQ2_ASSET_DIR")
if asset_dir is None:
asset_dir = Path("/etc/fairseq2/assets").resolve()
if not asset_dir.exists():
return
return None

store.add_file_metadata_provider(asset_dir)
return asset_dir


def _add_home_config_dir_metadata_provider(store: StandardAssetStore) -> None:
def get_user_asset_dir() -> Path | None:
asset_dir = get_path_from_env("FAIRSEQ2_USER_ASSET_DIR")
if asset_dir is None:
asset_dir = get_path_from_env("XDG_CONFIG_HOME")
Expand All @@ -286,6 +256,6 @@ def _add_home_config_dir_metadata_provider(store: StandardAssetStore) -> None:

asset_dir = asset_dir.joinpath("fairseq2/assets").resolve()
if not asset_dir.exists():
return
return None

store.add_file_metadata_provider(asset_dir, user=True)
return asset_dir
8 changes: 3 additions & 5 deletions src/fairseq2/chatbots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
from fairseq2.chatbots.chatbot import Chatbot as Chatbot
from fairseq2.chatbots.chatbot import ChatDialog as ChatDialog
from fairseq2.chatbots.chatbot import ChatMessage as ChatMessage
from fairseq2.chatbots.register import register_chatbots as register_chatbots
from fairseq2.chatbots.registry import ChatbotFactory as ChatbotFactory
from fairseq2.chatbots.registry import ChatbotHandler as ChatbotHandler
from fairseq2.chatbots.registry import ChatbotRegistry as ChatbotRegistry
from fairseq2.chatbots.registry import StandardChatbotHandler as StandardChatbotHandler
from fairseq2.chatbots.handler import ChatbotHandler as ChatbotHandler
from fairseq2.chatbots.handler import ChatbotNotFoundError as ChatbotNotFoundError
from fairseq2.chatbots.static import create_chatbot as create_chatbot
28 changes: 28 additions & 0 deletions src/fairseq2/chatbots/handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from abc import ABC, abstractmethod

from fairseq2.chatbots.chatbot import Chatbot
from fairseq2.data.text import TextTokenizer
from fairseq2.generation.generator import SequenceGenerator


class ChatbotHandler(ABC):
    """Abstract interface for objects that create :class:`Chatbot` instances."""

    @abstractmethod
    def create(
        self, generator: SequenceGenerator, tokenizer: TextTokenizer
    ) -> Chatbot:
        """Create a chatbot from ``generator`` and ``tokenizer``.

        :param generator: The sequence generator to create the chatbot with.
        :param tokenizer: The text tokenizer to create the chatbot with.
        """


class ChatbotNotFoundError(LookupError):
    """Indicates that a chatbot name is not known."""

    # The chatbot name that could not be resolved.
    name: str

    def __init__(self, name: str) -> None:
        """
        :param name: The name of the chatbot that is not known.
        """
        super().__init__(f"'{name}' is not a known chatbot.")

        self.name = name

    def __reduce__(self) -> tuple[object, ...]:
        # ``BaseException`` reconstructs an instance by calling the class with
        # ``self.args``, which holds the formatted message rather than the
        # name. Without this override, copying or unpickling would feed the
        # message back into ``__init__`` and wrap it in the template twice.
        return (type(self), (self.name,))
14 changes: 8 additions & 6 deletions src/fairseq2/chatbots/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing_extensions import override

from fairseq2.chatbots.chatbot import AbstractChatbot, Chatbot, ChatDialog, ChatMessage
from fairseq2.chatbots.handler import ChatbotHandler
from fairseq2.data.text import LLaMA3Tokenizer, TextTokenEncoder, TextTokenizer
from fairseq2.generation import SequenceGenerator
from fairseq2.nn.utils.module import infer_device
Expand Down Expand Up @@ -206,10 +207,11 @@ def supports_system_prompt(self) -> bool:
return True


def make_llama_chatbot(
generator: SequenceGenerator, tokenizer: TextTokenizer
) -> Chatbot:
if isinstance(tokenizer, LLaMA3Tokenizer):
return LLaMA3Chatbot(generator, tokenizer)
@final
class LLaMAChatbotHandler(ChatbotHandler):
@override
def create(self, generator: SequenceGenerator, tokenizer: TextTokenizer) -> Chatbot:
if isinstance(tokenizer, LLaMA3Tokenizer):
return LLaMA3Chatbot(generator, tokenizer)

return LLaMAChatbot(generator, tokenizer)
return LLaMAChatbot(generator, tokenizer)
12 changes: 7 additions & 5 deletions src/fairseq2/chatbots/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from torch import Tensor
from typing_extensions import override

from fairseq2.chatbots.chatbot import AbstractChatbot, ChatDialog
from fairseq2.chatbots.chatbot import AbstractChatbot, Chatbot, ChatDialog
from fairseq2.chatbots.handler import ChatbotHandler
from fairseq2.data.text import TextTokenEncoder, TextTokenizer
from fairseq2.generation import SequenceGenerator
from fairseq2.nn.utils.module import infer_device
Expand Down Expand Up @@ -91,7 +92,8 @@ def supports_system_prompt(self) -> bool:
return False


def make_mistral_chatbot(
generator: SequenceGenerator, tokenizer: TextTokenizer
) -> MistralChatbot:
return MistralChatbot(generator, tokenizer)
@final
class MistralChatbotHandler(ChatbotHandler):
@override
def create(self, generator: SequenceGenerator, tokenizer: TextTokenizer) -> Chatbot:
return MistralChatbot(generator, tokenizer)
29 changes: 0 additions & 29 deletions src/fairseq2/chatbots/register.py

This file was deleted.

43 changes: 0 additions & 43 deletions src/fairseq2/chatbots/registry.py

This file was deleted.

Loading

0 comments on commit cfb71dc

Please sign in to comment.