Skip to content

Commit

Permalink
Autodetect Model2Vec model paths, closes #822. Refactor vectors packa…
Browse files Browse the repository at this point in the history
…ge, closes #826.
  • Loading branch information
davidmezzetti committed Dec 1, 2024
1 parent 4b5164b commit 5fdbc90
Show file tree
Hide file tree
Showing 13 changed files with 174 additions and 70 deletions.
1 change: 1 addition & 0 deletions src/python/txtai/vectors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@
from .litellm import LiteLLM
from .llama import LlamaCpp
from .m2v import Model2Vec
from .sbert import STVectors
from .words import WordVectors
32 changes: 29 additions & 3 deletions src/python/txtai/vectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import numpy as np

from ..pipeline import Tokenizer


class Vectors:
"""
Expand Down Expand Up @@ -147,7 +149,7 @@ def transform(self, document):
embeddings vector
"""

# Prepare input document for transformers model and build embeddings
# Prepare input document for vectors model and build embeddings
return self.batchtransform([document])[0]

def batchtransform(self, documents, category=None):
Expand All @@ -162,7 +164,7 @@ def batchtransform(self, documents, category=None):
embeddings vectors
"""

# Prepare input documents for transformers model
# Prepare input documents for vectors model
documents = [self.prepare(data, category) for _, data, _ in documents]

# Skip encoding data if it's already an array
Expand All @@ -183,7 +185,7 @@ def batch(self, documents, output):
(ids, dimensions) list of ids and number of dimensions in embeddings
"""

# Extract ids and prepare input documents for transformers model
# Extract ids and prepare input documents for vectors model
ids = [uid for uid, _, _ in documents]
documents = [self.prepare(data, "data") for _, data, _ in documents]
dimensions = None
Expand All @@ -208,6 +210,9 @@ def prepare(self, data, category=None):
data formatted for vector model
"""

# Prepares tokens for the model
data = self.tokens(data)

# Default instruction category
category = category if category else "query"

Expand All @@ -218,6 +223,27 @@ def prepare(self, data, category=None):

return data

def tokens(self, data):
"""
Prepare data as tokens model can accept.
Args:
data: input data
Returns:
tokens formatted for model
"""

# Optional string tokenization
if self.tokenize and isinstance(data, str):
data = Tokenizer.tokenize(data)

# Convert token list to string
if isinstance(data, list):
data = " ".join(data)

return data

def vectorize(self, data):
"""
Runs data vectorization, which consists of the following steps.
Expand Down
1 change: 1 addition & 0 deletions src/python/txtai/vectors/external.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def loadmodel(self, path):

def encode(self, data):
# Call external transform function, if available and data not already an array
# Batching is handed by the external transform function
if self.transform and data and not isinstance(data[0], np.ndarray):
data = self.transform(data)

Expand Down
9 changes: 8 additions & 1 deletion src/python/txtai/vectors/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .litellm import LiteLLM
from .llama import LlamaCpp
from .m2v import Model2Vec
from .sbert import STVectors
from .words import WordVectors


Expand Down Expand Up @@ -50,6 +51,10 @@ def create(config, scoring=None, models=None):
if method == "model2vec":
return Model2Vec(config, scoring, models)

# Sentence Transformers vectors
if method == "sentence-transformers":
return STVectors(config, scoring, models)

# Word vectors
if method == "words":
return WordVectors(config, scoring, models)
Expand All @@ -73,7 +78,7 @@ def method(config):
vector method
"""

# Determine vector method (external, litellm, llama.cpp, transformers or words)
# Determine vector method
method = config.get("method")
path = config.get("path")

Expand All @@ -84,6 +89,8 @@ def method(config):
method = "litellm"
elif LlamaCpp.ismodel(path):
method = "llama.cpp"
elif Model2Vec.ismodel(path):
method = "model2vec"
elif WordVectors.isdatabase(path):
method = "words"
else:
Expand Down
71 changes: 15 additions & 56 deletions src/python/txtai/vectors/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,14 @@
Hugging Face module
"""

# Conditional import
try:
from sentence_transformers import SentenceTransformer

SENTENCE_TRANSFORMERS = True
except ImportError:
SENTENCE_TRANSFORMERS = False
from ..models import Models, PoolingFactory

from .base import Vectors
from ..models import Models, PoolingFactory
from ..pipeline import Tokenizer


class HFVectors(Vectors):
"""
Builds vectors using the Hugging Face transformers library. Also supports the sentence-transformers library.
Builds vectors using the Hugging Face transformers library.
"""

@staticmethod
Expand All @@ -32,54 +24,21 @@ def ismethod(method):
True if this is a local transformers-based model, False otherwise
"""

return method in ("transformers", "sentence-transformers", "pooling", "clspooling", "meanpooling")
return method in ("transformers", "pooling", "clspooling", "meanpooling")

def loadmodel(self, path):
# Flag that determines if transformers or sentence-transformers should be used to build embeddings
method = self.config.get("method")
transformers = method != "sentence-transformers"

# Tensor device id
deviceid = Models.deviceid(self.config.get("gpu", True))

# Additional model arguments
modelargs = self.config.get("vectors", {})

# Build embeddings with transformers (default)
if transformers:
return PoolingFactory.create(
{
"method": method,
"path": path,
"device": deviceid,
"tokenizer": self.config.get("tokenizer"),
"maxlength": self.config.get("maxlength"),
"modelargs": modelargs,
}
)

# Otherwise, use sentence-transformers library
if not SENTENCE_TRANSFORMERS:
raise ImportError('sentence-transformers is not available - install "vectors" extra to enable')

# Build embeddings with sentence-transformers
return SentenceTransformer(path, device=Models.device(deviceid), **modelargs)
# Build embeddings with transformers pooling
return PoolingFactory.create(
{
"method": self.config.get("method"),
"path": path,
"device": Models.deviceid(self.config.get("gpu", True)),
"tokenizer": self.config.get("tokenizer"),
"maxlength": self.config.get("maxlength"),
"modelargs": self.config.get("vectors", {}),
}
)

def encode(self, data):
# Get batch parameter name
param = "batch_size" if self.config.get("method") == "sentence-transformers" else "batch"

# Encode data using vectors model
return self.model.encode(data, **{param: self.encodebatch})

def prepare(self, data, category=None):
# Optional string tokenization
if self.tokenize and isinstance(data, str):
data = Tokenizer.tokenize(data)

# Convert token list to string
if isinstance(data, list):
data = " ".join(data)

# Add parent prepare logic
return super().prepare(data, category)
return self.model.encode(data, batch=self.encodebatch)
1 change: 1 addition & 0 deletions src/python/txtai/vectors/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def loadmodel(self, path):

def encode(self, data):
# Call external embeddings API using LiteLLM
# Batching is handled server-side
response = api.embedding(model=self.config.get("path"), input=data, **self.config.get("vectors", {}))

# Read response into a NumPy array
Expand Down
1 change: 1 addition & 0 deletions src/python/txtai/vectors/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def loadmodel(self, path):

def encode(self, data):
# Generate embeddings and return as a NumPy array
# llama-cpp-python has it's own batching built-in using n_batch parameter
return np.array(self.model.embed(data), dtype=np.float32)

def download(self, path):
Expand Down
32 changes: 31 additions & 1 deletion src/python/txtai/vectors/m2v.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
Model2Vec module
"""

import json

from transformers.utils import cached_file

# Conditional import
try:
from model2vec import StaticModel
Expand All @@ -18,6 +22,32 @@ class Model2Vec(Vectors):
Builds vectors using Model2Vec.
"""

@staticmethod
def ismodel(path):
"""
Checks if path is a Model2Vec model.
Args:
path: input path
Returns:
True if this is a Model2Vec model, False otherwise
"""

try:
# Download file and parse JSON
path = cached_file(path_or_repo_id=path, filename="config.json")
if path:
with open(path, encoding="utf-8") as f:
config = json.load(f)
return config.get("model_type") == "model2vec"

# Ignore this error - invalid repo or directory
except OSError:
pass

return False

def __init__(self, config, scoring, models):
# Check before parent constructor since it calls loadmodel
if not MODEL2VEC:
Expand All @@ -29,4 +59,4 @@ def loadmodel(self, path):
return StaticModel.from_pretrained(path)

def encode(self, data):
return self.model.encode(data)
return self.model.encode(data, batch_size=self.encodebatch)
42 changes: 42 additions & 0 deletions src/python/txtai/vectors/sbert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""
SentenceTransformers module
"""

# Conditional import
try:
from sentence_transformers import SentenceTransformer

SENTENCE_TRANSFORMERS = True
except ImportError:
SENTENCE_TRANSFORMERS = False

from ..models import Models

from .base import Vectors


class STVectors(Vectors):
"""
Builds vectors using sentence-transformers (aka SBERT).
"""

def __init__(self, config, scoring, models):
# Check before parent constructor since it calls loadmodel
if not SENTENCE_TRANSFORMERS:
raise ImportError('sentence-transformers is not available - install "vectors" extra to enable')

super().__init__(config, scoring, models)

def loadmodel(self, path):
# Tensor device id
deviceid = Models.deviceid(self.config.get("gpu", True))

# Additional model arguments
modelargs = self.config.get("vectors", {})

# Build embeddings with sentence-transformers
return SentenceTransformer(path, device=Models.device(deviceid), **modelargs)

def encode(self, data):
# Encode data using vectors model
return self.model.encode(data, batch_size=self.encodebatch)
4 changes: 4 additions & 0 deletions src/python/txtai/vectors/words.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,10 @@ def lookup(self, tokens):

return self.model.query(tokens)

def tokens(self, data):
# Skip tokenization rules
return data

@staticmethod
def isdatabase(path):
"""
Expand Down
8 changes: 0 additions & 8 deletions test/python/testvectors/testhuggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,6 @@ def testIndex(self):
with open(stream, "rb") as queue:
self.assertEqual(np.load(queue).shape, (500, 768))

def testSentenceTransformers(self):
"""
Test creating a model with sentence transformers
"""

model = VectorsFactory.create({"method": "sentence-transformers", "path": "paraphrase-MiniLM-L3-v2"}, None)
self.assertEqual(model.transform((0, "This is a test", None)).shape, (384,))

def testText(self):
"""
Test transformers text conversion
Expand Down
2 changes: 1 addition & 1 deletion test/python/testvectors/testm2v.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def setUpClass(cls):
Create Model2Vec instance.
"""

cls.model = VectorsFactory.create({"method": "model2vec", "path": "minishlab/M2V_base_output"}, None)
cls.model = VectorsFactory.create({"path": "minishlab/potion-base-8M"}, None)

def testIndex(self):
"""
Expand Down
Loading

0 comments on commit 5fdbc90

Please sign in to comment.