From 0012321ee52a21d0a7f88d65b1e209006accae39 Mon Sep 17 00:00:00 2001
From: davidmezzetti <561939+davidmezzetti@users.noreply.github.com>
Date: Tue, 26 Sep 2023 11:58:16 -0400
Subject: [PATCH] Fix issues related to torch device lookup for non-CUDA GPU
 devices, closes #551

---
 src/python/txtai/pipeline/hfmodel.py     | 2 +-
 src/python/txtai/pipeline/hfpipeline.py  | 7 ++++---
 src/python/txtai/vectors/transformers.py | 2 +-
 3 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/src/python/txtai/pipeline/hfmodel.py b/src/python/txtai/pipeline/hfmodel.py
index cebd485c2..04444e87d 100644
--- a/src/python/txtai/pipeline/hfmodel.py
+++ b/src/python/txtai/pipeline/hfmodel.py
@@ -31,7 +31,7 @@ def __init__(self, path=None, quantize=False, gpu=False, batch=64):
 
         # Get tensor device reference
         self.deviceid = Models.deviceid(gpu)
-        self.device = Models.reference(self.deviceid)
+        self.device = Models.device(self.deviceid)
 
         # Process batch size
         self.batchsize = batch
diff --git a/src/python/txtai/pipeline/hfpipeline.py b/src/python/txtai/pipeline/hfpipeline.py
index b40107baf..dde01d763 100644
--- a/src/python/txtai/pipeline/hfpipeline.py
+++ b/src/python/txtai/pipeline/hfpipeline.py
@@ -36,8 +36,9 @@ def __init__(self, task, path=None, quantize=False, gpu=False, model=None, **kwa
             # Check if input model is a Pipeline or a HF pipeline
             self.pipeline = model.pipeline if isinstance(model, HFPipeline) else model
         else:
-            # Get device id
+            # Get device
             deviceid = Models.deviceid(gpu) if "device_map" not in kwargs else None
+            device = Models.device(deviceid) if deviceid is not None else None
 
             # Split into model args, pipeline args
             modelargs, kwargs = self.parseargs(**kwargs)
@@ -50,9 +51,9 @@ def __init__(self, task, path=None, quantize=False, gpu=False, model=None, **kwa
             # Load model
             model = Models.load(path[0], config, task)
 
-            self.pipeline = pipeline(task, model=model, tokenizer=path[1], device=deviceid, model_kwargs=modelargs, **kwargs)
+            self.pipeline = pipeline(task, model=model, tokenizer=path[1], device=device, model_kwargs=modelargs, **kwargs)
         else:
-            self.pipeline = pipeline(task, model=path, device=deviceid, model_kwargs=modelargs, **kwargs)
+            self.pipeline = pipeline(task, model=path, device=device, model_kwargs=modelargs, **kwargs)
 
         # Model quantization. Compresses model to int8 precision, improves runtime performance. Only supported on CPU.
         if deviceid == -1 and quantize:
diff --git a/src/python/txtai/vectors/transformers.py b/src/python/txtai/vectors/transformers.py
index 518cd2e48..70575df15 100644
--- a/src/python/txtai/vectors/transformers.py
+++ b/src/python/txtai/vectors/transformers.py
@@ -41,7 +41,7 @@ def load(self, path):
             raise ImportError('sentence-transformers is not available - install "similarity" extra to enable')
 
         # Build embeddings with sentence-transformers
-        return SentenceTransformer(path, device=Models.reference(deviceid))
+        return SentenceTransformer(path, device=Models.device(deviceid))
 
     def encode(self, data):
         # Encode data using vectors model