Skip to content

Commit

Permalink
Fix issues related to torch device lookup for non-CUDA GPU devices, closes #551
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmezzetti committed Sep 26, 2023
1 parent b1b07de commit 0012321
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/python/txtai/pipeline/hfmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self, path=None, quantize=False, gpu=False, batch=64):

# Get tensor device reference
self.deviceid = Models.deviceid(gpu)
self.device = Models.reference(self.deviceid)
self.device = Models.device(self.deviceid)

# Process batch size
self.batchsize = batch
Expand Down
7 changes: 4 additions & 3 deletions src/python/txtai/pipeline/hfpipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@ def __init__(self, task, path=None, quantize=False, gpu=False, model=None, **kwa
# Check if input model is a Pipeline or a HF pipeline
self.pipeline = model.pipeline if isinstance(model, HFPipeline) else model
else:
# Get device id
# Get device
deviceid = Models.deviceid(gpu) if "device_map" not in kwargs else None
device = Models.device(deviceid) if deviceid is not None else None

# Split into model args, pipeline args
modelargs, kwargs = self.parseargs(**kwargs)
Expand All @@ -50,9 +51,9 @@ def __init__(self, task, path=None, quantize=False, gpu=False, model=None, **kwa
# Load model
model = Models.load(path[0], config, task)

self.pipeline = pipeline(task, model=model, tokenizer=path[1], device=deviceid, model_kwargs=modelargs, **kwargs)
self.pipeline = pipeline(task, model=model, tokenizer=path[1], device=device, model_kwargs=modelargs, **kwargs)
else:
self.pipeline = pipeline(task, model=path, device=deviceid, model_kwargs=modelargs, **kwargs)
self.pipeline = pipeline(task, model=path, device=device, model_kwargs=modelargs, **kwargs)

# Model quantization. Compresses model to int8 precision, improves runtime performance. Only supported on CPU.
if deviceid == -1 and quantize:
Expand Down
2 changes: 1 addition & 1 deletion src/python/txtai/vectors/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def load(self, path):
raise ImportError('sentence-transformers is not available - install "similarity" extra to enable')

# Build embeddings with sentence-transformers
return SentenceTransformer(path, device=Models.reference(deviceid))
return SentenceTransformer(path, device=Models.device(deviceid))

def encode(self, data):
# Encode data using vectors model
Expand Down

0 comments on commit 0012321

Please sign in to comment.