Skip to content

Commit

Permalink
Introduce RuntimeContext (#946)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbalioglu authored Jan 3, 2025
1 parent b4b29e5 commit dc96fce
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 49 deletions.
2 changes: 1 addition & 1 deletion src/fairseq2/chatbots/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from typing_extensions import override

from fairseq2.chatbots.chatbot import Chatbot
from fairseq2.context import Registry
from fairseq2.data.text import TextTokenizer
from fairseq2.generation.generator import SequenceGenerator
from fairseq2.utils.registry import Registry


class ChatbotHandler(ABC):
Expand Down
103 changes: 103 additions & 0 deletions src/fairseq2/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Hashable, Iterable
from typing import Any, Generic, Mapping, TypeVar, final

from typing_extensions import override

from fairseq2.assets import AssetDownloadManager, StandardAssetStore
from fairseq2.error import AlreadyExistsError

T = TypeVar("T")


@final
class RuntimeContext:
_asset_store: StandardAssetStore
_asset_download_manager: AssetDownloadManager
_registries: Mapping[type, Registry[Any]]

def __init__(
self,
asset_store: StandardAssetStore,
asset_download_manager: AssetDownloadManager,
) -> None:
self._asset_store = asset_store
self._asset_download_manager = asset_download_manager

self._registries = defaultdict(Registry)

@property
def asset_store(self) -> StandardAssetStore:
return self._asset_store

@property
def asset_download_manager(self) -> AssetDownloadManager:
return self._asset_download_manager

def get_registry(self, kls: type[T]) -> Registry[T]:
return self._registries[kls]


T_co = TypeVar("T_co", covariant=True)


class Provider(ABC, Generic[T_co]):
@abstractmethod
def get(self, key: Hashable) -> T_co:
...

@abstractmethod
def get_all(self) -> Iterable[tuple[Hashable, T_co]]:
...


@final
class Registry(Provider[T]):
_entries: dict[Hashable, T]

def __init__(self) -> None:
self._entries = {}

@override
def get(self, key: Hashable) -> T:
try:
return self._entries[key]
except KeyError:
raise LookupError(f"The registry does not contain a '{key}' key.") from None

@override
def get_all(self) -> Iterable[tuple[Hashable, T]]:
return self._entries.items()

def register(self, key: Hashable, value: T) -> None:
if key in self._entries:
raise AlreadyExistsError(f"The registry already contains a '{key}' key.")

self._entries[key] = value


_default_context: RuntimeContext | None = None


def set_runtime_context(context: RuntimeContext) -> None:
global _default_context

_default_context = context


def get_runtime_context() -> RuntimeContext:
if _default_context is None:
raise RuntimeError(
"fairseq2 is not initialized. Make sure to call `fairseq2.setup_fairseq2()`."
)

return _default_context
3 changes: 2 additions & 1 deletion src/fairseq2/data/parquet/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from collections.abc import Generator
from contextlib import contextmanager
from typing import Any

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -165,7 +166,7 @@ def load_one_fragment(

def apply_filter(
table: pa.Table,
filters: pa.dataset.Expression | None = None,
filters: list[Any] | pa.dataset.Expression | None = None,
drop_null: bool = True,
) -> pa.Table:
if drop_null:
Expand Down
2 changes: 1 addition & 1 deletion src/fairseq2/data/text/tokenizers/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
from typing_extensions import override

from fairseq2.assets import AssetCard, AssetDownloadManager
from fairseq2.context import Registry
from fairseq2.data.text.tokenizers.tokenizer import TextTokenizer
from fairseq2.utils.registry import Registry


class TextTokenizerHandler(ABC):
Expand Down
2 changes: 1 addition & 1 deletion src/fairseq2/factory_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def register(factory: Factory[ConfigT, P, R]) -> Factory[ConfigT, P, R]:
f"The first parameter of the decorated factory `{factory}` must be a dataclass."
)

self.register(name, factory, config_kls)
self.register(name, factory, config_kls) # type: ignore[arg-type]

return factory

Expand Down
43 changes: 0 additions & 43 deletions src/fairseq2/utils/registry.py

This file was deleted.

4 changes: 2 additions & 2 deletions tests/unit/data/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ def test_init_works_when_input_buffer_is_shared_and_is_of_type_float(self) -> No
assert view.shape == (12,)
assert view.strides == (1,)

view = view.cast("f")
float_view = view.cast("f")

assert view.tolist() == pytest.approx([0.2, 0.4, 0.6])
assert float_view.tolist() == pytest.approx([0.2, 0.4, 0.6])

def test_init_works_when_copy_is_true(self) -> None:
arr = array("B", [0, 1, 2, 3])
Expand Down

0 comments on commit dc96fce

Please sign in to comment.