Skip to content

Commit

Permalink
Vector encoding using multiple-GPUs, closes #541
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmezzetti committed Dec 22, 2024
1 parent 34ffc97 commit 65a8316
Show file tree
Hide file tree
Showing 10 changed files with 94 additions and 15 deletions.
2 changes: 2 additions & 0 deletions docs/embeddings/configuration/vectors.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ gpu: boolean|int|string|device

Set the target device. Supports true/false, device id, device string and torch device instance. This is automatically derived if omitted.

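The `sentence-transformers` method supports encoding with multiple GPUs. Enable this by setting the `gpu` parameter to `all`, which spawns an encoding process per GPU. When only a single device is available, standard single-device encoding is used.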

## batch
```yaml
batch: int
Expand Down
9 changes: 8 additions & 1 deletion src/python/txtai/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,7 +670,7 @@ def close(self):
"""

self.config, self.archive = None, None
self.reducer, self.query, self.model, self.models = None, None, None, None
self.reducer, self.query = None, None
self.ids = None

# Close ANN
Expand Down Expand Up @@ -698,6 +698,13 @@ def close(self):
self.indexes.close()
self.indexes = None

# Close vectors model
if self.model:
self.model.close()
self.model = None

self.models = None

def info(self):
"""
Prints the current embeddings index configuration.
Expand Down
11 changes: 11 additions & 0 deletions src/python/txtai/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,17 @@ def reference(deviceid):
)
)

@staticmethod
def acceleratorcount():
    """
    Counts the accelerator devices that can be used.

    Returns:
        number of accelerators available
    """

    # Number of devices reported by CUDA
    count = torch.cuda.device_count()

    # int() of the accelerator check is the floor: when it evaluates to 1 the result is
    # at least 1 even if CUDA reports zero devices
    floor = int(Models.hasaccelerator())

    return count if count >= floor else floor

@staticmethod
def hasaccelerator():
"""
Expand Down
7 changes: 7 additions & 0 deletions src/python/txtai/vectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,13 @@ def index(self, documents, batchsize=500):

return (ids, dimensions, batches, stream)

def close(self):
    """
    Closes this vectors instance by releasing the reference to the underlying model.
    """

    # Drop the model reference
    self.model = None

def transform(self, document):
"""
Transforms document into an embeddings vector.
Expand Down
2 changes: 1 addition & 1 deletion src/python/txtai/vectors/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def create(config, scoring=None, models=None):

# Sentence Transformers vectors
if method == "sentence-transformers":
return STVectors(config, scoring, models)
return STVectors(config, scoring, models) if config and config.get("path") else None

# Word vectors
if method == "words":
Expand Down
39 changes: 36 additions & 3 deletions src/python/txtai/vectors/sbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,51 @@ def __init__(self, config, scoring, models):
if not SENTENCE_TRANSFORMERS:
raise ImportError('sentence-transformers is not available - install "vectors" extra to enable')

# Pool parameter created here since loadmodel is called from parent constructor
self.pool = None

super().__init__(config, scoring, models)

def loadmodel(self, path):
    """
    Loads a sentence-transformers model. When the gpu setting is "all" and more than one
    accelerator device is available, a multi-process encoding pool is also started and
    stored in self.pool.

    Args:
        path: model path passed to SentenceTransformer

    Returns:
        SentenceTransformer model
    """

    # Get target device
    gpu, pool = self.config.get("gpu", True), False

    # Default mode uses a single GPU. Setting to all spawns a process per GPU.
    if isinstance(gpu, str) and gpu == "all":
        # Get number of accelerator devices available
        devices = Models.acceleratorcount()

        # Enable multiprocessing pooling only when multiple devices are available
        gpu, pool = devices <= 1, devices > 1

    # Tensor device id, derived from the resolved gpu setting rather than the raw config value
    deviceid = Models.deviceid(gpu)

    # Additional model arguments
    modelargs = self.config.get("vectors", {})

    # Build embeddings with sentence-transformers
    model = SentenceTransformer(path, device=Models.device(deviceid), **modelargs)

    # Start process pool for multiple GPUs
    if pool:
        self.pool = model.start_multi_process_pool()

    # Return model
    return model

def encode(self, data):
    """
    Encodes data with the vectors model, using the process pool when one was started.

    Args:
        data: input data to encode

    Returns:
        output of the model encode call
    """

    # Standard encoding, no process pool is active
    if not self.pool:
        return self.model.encode(data, batch_size=self.encodebatch)

    # Multiprocess encoding
    return self.model.encode_multi_process(data, self.pool, batch_size=self.encodebatch)

def close(self):
    """
    Stops the multi-process encoding pool, if one was started, then runs the parent close.
    """

    # The pool is stopped through the model, so this has to run before the parent close
    active = self.pool
    if active:
        self.model.stop_multi_process_pool(active)
        self.pool = None

    super().close()
2 changes: 1 addition & 1 deletion test/python/testvectors/testlitellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def tearDownClass(cls):

def testIndex(self):
"""
Test indexing with LiteLLM vectors.
Test indexing with LiteLLM vectors
"""

# LiteLLM vectors instance
Expand Down
2 changes: 1 addition & 1 deletion test/python/testvectors/testllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def setUpClass(cls):

def testIndex(self):
"""
Test indexing with LlamaCpp vectors.
Test indexing with LlamaCpp vectors
"""

ids, dimension, batches, stream = self.model.index([(0, "test", None)])
Expand Down
2 changes: 1 addition & 1 deletion test/python/testvectors/testm2v.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def setUpClass(cls):

def testIndex(self):
"""
Test indexing with Model2Vec vectors.
Test indexing with Model2Vec vectors
"""

ids, dimension, batches, stream = self.model.index([(0, "test", None)])
Expand Down
33 changes: 26 additions & 7 deletions test/python/testvectors/testsbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import os
import unittest

from unittest.mock import patch

import numpy as np

from txtai.vectors import VectorsFactory
Expand All @@ -15,20 +17,34 @@ class TestSTVectors(unittest.TestCase):
STVectors tests
"""

def testIndex(self):
    """
    Test indexing with sentence-transformers vectors
    """

    model = VectorsFactory.create({"method": "sentence-transformers", "path": "paraphrase-MiniLM-L3-v2"}, None)
    ids, dimension, batches, stream = model.index([(0, "test", None)])

    self.assertEqual(len(ids), 1)
    self.assertEqual(dimension, 384)
    self.assertEqual(batches, 1)

    # os.path.exists returns a bool, so assertIsNotNone would pass even when the file is missing
    self.assertTrue(os.path.exists(stream))

    # Test shape of serialized embeddings
    with open(stream, "rb") as queue:
        self.assertEqual(np.load(queue).shape, (1, 384))

@patch("torch.cuda.device_count")
def testMultiGPU(self, count):
"""
Test indexing with sentence-transformers vectors.
Test multiple gpu encoding
"""

ids, dimension, batches, stream = self.model.index([(0, "test", None)])
# Mock accelerator count
count.return_value = 2

model = VectorsFactory.create({"method": "sentence-transformers", "path": "paraphrase-MiniLM-L3-v2", "gpu": "all"}, None)
ids, dimension, batches, stream = model.index([(0, "test", None)])

self.assertEqual(len(ids), 1)
self.assertEqual(dimension, 384)
Expand All @@ -38,3 +54,6 @@ def testIndex(self):
# Test shape of serialized embeddings
with open(stream, "rb") as queue:
self.assertEqual(np.load(queue).shape, (1, 384))

# Close the multiprocessing pool
model.close()

0 comments on commit 65a8316

Please sign in to comment.