Skip to content

Commit

Permalink
Add aliases and deprecation warnings for old API
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Nov 30, 2023
1 parent ea4e418 commit bcfb5d7
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 0 deletions.
26 changes: 26 additions & 0 deletions tests/generate/test_integration_transfomers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,32 @@
from outlines.models.transformers import TransformerTokenizer


def test_deprecation():
import outlines

model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
model = models.transformers(model_name, device="cpu")

with pytest.warns(DeprecationWarning):
outlines.text.generate.continuation(model, max_tokens=10)

with pytest.warns(DeprecationWarning):
outlines.text.generate.choice(model, ["A", "B"], max_tokens=10)

with pytest.warns(DeprecationWarning):
outlines.text.generate.regex(model, "[0-9]", max_tokens=10)

with pytest.warns(DeprecationWarning):
outlines.text.generate.format(model, int, max_tokens=10)

with pytest.warns(DeprecationWarning):

def function(a: int):
pass

outlines.text.generate.json(model, function, max_tokens=10)


def test_transformers_integration_text():
rng = torch.Generator()
rng.manual_seed(10000) # Choosen so <EOS> is generated
Expand Down
Empty file added text/__init__.py
Empty file.
1 change: 1 addition & 0 deletions text/generate/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .api import choice, continuation, format, json, regex
69 changes: 69 additions & 0 deletions text/generate/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import warnings
from typing import Callable, List, Optional, Union

import outlines
from outlines.generate.samplers import Sampler, multinomial


def json(
model,
schema_object: Union[str, object, Callable],
max_tokens: Optional[int] = None,
sampler: Sampler = multinomial,
):
warnings.warn(
"`outlines.text.generate.json` is deprecated, please use `outlines.generate.json` instead. "
"The old import path will be removed in Outlines v0.0.15.",
DeprecationWarning,
)
return outlines.generate.json(model, schema_object, max_tokens, sampler=sampler)


def regex(
model,
regex_str: str,
max_tokens: Optional[int] = None,
sampler: Sampler = multinomial,
):
warnings.warn(
"`outlines.text.generate.regex` is deprecated, please use `outlines.generate.regex` instead. "
"The old import path will be removed in Outlines v0.0.15.",
DeprecationWarning,
)
return outlines.generate.regex(model, regex_str, max_tokens, sampler=sampler)


def format(
model, python_type, max_tokens: Optional[int] = None, sampler: Sampler = multinomial
):
warnings.warn(
"`outlines.text.generate.format` is deprecated, please use `outlines.generate.format` instead. "
"The old import path will be removed in Outlines v0.0.15.",
DeprecationWarning,
)
return outlines.generate.format(model, python_type, max_tokens, sampler=sampler)


def continuation(
model, max_tokens: Optional[int] = None, sampler: Sampler = multinomial
):
warnings.warn(
"`outlines.text.generate.continuation` is deprecated, please use `outlines.generate.text` instead. "
"The old import path will be removed in Outlines v0.0.15.",
DeprecationWarning,
)
return outlines.generate.text(model, max_tokens, sampler=sampler)


def choice(
model,
choices: List[str],
max_tokens: Optional[int] = None,
sampler: Sampler = multinomial,
):
warnings.warn(
"`outlines.text.generate.choice` is deprecated, please use `outlines.generate.choice` instead. "
"The old import path will be removed in Outlines v0.0.15.",
DeprecationWarning,
)
return outlines.generate.choice(model, choices, max_tokens, sampler)

0 comments on commit bcfb5d7

Please sign in to comment.