Skip to content

Commit

Permalink
Use device_map value when device is unspecified
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Sep 29, 2023
1 parent 38b0b10 commit 34b4530
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
8 changes: 5 additions & 3 deletions outlines/models/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,9 @@ def transformers(
----------
model_name
The name of the model as listed on Hugging Face's model page.
device_map
device
The device(s) on which the model should be loaded. This overrides
the value passed for `device_map` in `model_kwargs`.
the `device_map` entry in `model_kwargs` when provided.
model_kwargs
A dictionary that contains the keyword arguments to pass to the
`from_pretrained` method when loading the model.
Expand All @@ -181,7 +181,9 @@ def transformers(
"The `transformers` library needs to be installed in order to use `transformers` models."
)

model_kwargs["device_map"] = device
if device is not None:
model_kwargs["device_map"] = device

model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
tokenizer = TransformersTokenizer(model_name, **tokenizer_kwargs)

Expand Down
8 changes: 8 additions & 0 deletions tests/models/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,14 @@ def test_model():
assert isinstance(model.tokenizer, TransformersTokenizer)
assert model.device.type == "cpu"

model = transformers(TEST_MODEL, model_kwargs={"device_map": "cpu"})
assert isinstance(model.tokenizer, TransformersTokenizer)
assert model.device.type == "cpu"

model = transformers(TEST_MODEL, device="cpu", model_kwargs={"device_map": "cuda"})
assert isinstance(model.tokenizer, TransformersTokenizer)
assert model.device.type == "cpu"

input_ids = torch.tensor([[0, 1, 2]])
logits = model(input_ids, torch.ones_like(input_ids))
assert logits.type() == "torch.FloatTensor"
Expand Down

0 comments on commit 34b4530

Please sign in to comment.