Skip to content

Commit

Permalink
0.4.2 - swap sentence_transformers for onnx so build is smol
Browse files Browse the repository at this point in the history
  • Loading branch information
lalalune committed Aug 10, 2023
1 parent 80baa1a commit 6cc0f19
Show file tree
Hide file tree
Showing 8 changed files with 143 additions and 19 deletions.
4 changes: 4 additions & 0 deletions agentmemory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
cluster,
)

from .check_model import check_model, infer_embeddings

__all__ = [
"create_memory",
"create_unique_memory",
Expand Down Expand Up @@ -70,4 +72,6 @@
"reset_epoch",
"set_epoch",
"cluster",
"check_model",
"infer_embeddings"
]
82 changes: 82 additions & 0 deletions agentmemory/check_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import os
import requests
import tarfile
from pathlib import Path
from tqdm import tqdm

def _download(url: str, fname: Path, chunk_size: int = 1024, timeout: float = 60.0) -> None:
    """Stream ``url`` into ``fname`` while displaying a tqdm progress bar.

    The payload is written to a sibling ``<fname>.part`` file and renamed onto
    ``fname`` only after the transfer finished. ``check_model`` decides whether
    to download by testing that ``fname`` exists, so a truncated file must
    never be left at the final path.

    Args:
        url: Address to fetch.
        fname: Destination path of the downloaded file.
        chunk_size: Number of bytes requested per streamed chunk.
        timeout: Seconds ``requests`` waits on the connection or on the next
            bytes from the server before giving up.

    Raises:
        requests.HTTPError: If the server answers with an error status, so an
            error page is never stored as if it were the requested file.
        requests.RequestException: On connection failures or timeouts.
    """
    partial = Path(f"{fname}.part")
    try:
        # The context manager releases the connection even if writing fails.
        with requests.get(url, stream=True, timeout=timeout) as resp:
            resp.raise_for_status()
            total = int(resp.headers.get("content-length", 0))
            with open(partial, "wb") as file, tqdm(
                desc=str(fname),
                total=total,
                unit="iB",
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for data in resp.iter_content(chunk_size=chunk_size):
                    bar.update(file.write(data))
        # Publish the file only once it is complete.
        os.replace(partial, fname)
    finally:
        # After a successful rename this is a no-op; after a failure it
        # removes the incomplete download.
        if partial.exists():
            partial.unlink()

# Default root directory for downloaded ONNX models: ".cache/onnx_models"
# inside the current user's home directory.
default_model_path = str(Path.home() / ".cache" / "onnx_models")

def _extract_archive(archive_path: Path, destination: Path) -> None:
    """Unpack the gzip tarball ``archive_path`` into ``destination``.

    The archive comes from the network, so members are not allowed to land
    outside ``destination``. Interpreters that ship tarfile extraction filters
    use the ``"data"`` filter; older ones fall back to an explicit check of
    every member path and link target.

    Raises:
        tarfile.TarError: If a member would be written outside ``destination``.
    """
    with tarfile.open(archive_path, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            tar.extractall(destination, filter="data")
            return

        root = os.path.realpath(destination)

        def _inside_root(path: str) -> bool:
            return os.path.commonpath([root, os.path.realpath(path)]) == root

        for member in tar.getmembers():
            target = os.path.join(root, member.name)
            escapes = not _inside_root(target)
            if member.issym():
                # Symlink targets are relative to the link's own directory.
                link = os.path.join(os.path.dirname(target), member.linkname)
                escapes = escapes or not _inside_root(link)
            elif member.islnk():
                # Hard-link targets are relative to the archive root.
                link = os.path.join(root, member.linkname)
                escapes = escapes or not _inside_root(link)
            if escapes:
                raise tarfile.TarError(
                    f"Refusing to extract {member.name!r} outside of {root}"
                )
        tar.extractall(destination)


def check_model(model_name: str = "all-MiniLM-L6-v2", model_path: str = default_model_path) -> str:
    """Make sure the ONNX model ``model_name`` is available locally.

    The archive ``onnx.tar.gz`` is downloaded into ``<model_path>/<model_name>``
    when it is not there yet and unpacked in that directory. If
    ``onnx/model.onnx`` and ``onnx/tokenizer.json`` already exist, nothing is
    downloaded or extracted again; delete the directory to force a refresh.

    Args:
        model_name: Name of the model; it is used both as the remote folder
            name and as the local sub-directory name.
        model_path: Root directory under which models are stored.

    Returns:
        The path of the ``onnx`` directory inside the model directory, as a
        string.
    """
    download_dir = Path(model_path) / model_name
    archive = download_dir / "onnx.tar.gz"
    model_dir = download_dir / "onnx"
    url = f"https://chroma-onnx-models.s3.amazonaws.com/{model_name}/onnx.tar.gz"

    # Fast path: skip the costly re-extraction when the files are in place.
    if all((model_dir / name).is_file() for name in ("model.onnx", "tokenizer.json")):
        return str(model_dir)

    # Check if model is not downloaded yet
    if not archive.exists():
        download_dir.mkdir(parents=True, exist_ok=True)
        _download(url, archive)

    _extract_archive(archive, download_dir)

    return str(model_dir)

import functools
import importlib
from typing import List

import numpy as np
import numpy.typing as npt
import onnxruntime
from tokenizers import Tokenizer

def _normalize(v: npt.NDArray) -> npt.NDArray:
    """Divide every row of ``v`` by its Euclidean length.

    Rows whose length is exactly zero are divided by ``1e-12`` instead, which
    avoids a division by zero and leaves them as all zeros.
    """
    lengths = np.linalg.norm(v, axis=1)
    zero_rows = lengths == 0
    lengths[zero_rows] = 1e-12
    # Add an axis after the first so each length lines up with its own row.
    return v / np.expand_dims(lengths, 1)

@functools.lru_cache(maxsize=4)
def _load_model(model_path: str):
    """Build the tokenizer and ONNX session for the model in ``model_path``.

    The result is memoized per path so repeated ``infer_embeddings`` calls do
    not re-read and re-initialize the model files every time.

    Returns:
        A ``(tokenizer, session)`` tuple. The tokenizer truncates and pads
        every input to 256 tokens.
    """
    tokenizer = Tokenizer.from_file(os.path.join(model_path, "tokenizer.json"))
    tokenizer.enable_truncation(max_length=256)
    tokenizer.enable_padding(pad_id=0, pad_token="[PAD]", length=256)
    session = onnxruntime.InferenceSession(os.path.join(model_path, "model.onnx"))
    return tokenizer, session


def infer_embeddings(documents: List[str], model_path: str, batch_size: int = 32) -> npt.NDArray:
    """Embed ``documents`` with the ONNX model stored in ``model_path``.

    Documents are tokenized, run through the model in batches, mean-pooled
    over the tokens selected by the attention mask, and normalized with
    ``_normalize``.

    Args:
        documents: Texts to embed.
        model_path: Directory containing ``tokenizer.json`` and ``model.onnx``.
        batch_size: Number of documents sent to the model per run.

    Returns:
        A float32 array with one row per document. For empty ``documents`` an
        empty array of shape ``(0, 0)`` is returned without loading the model,
        because the embedding width is only known from the model output.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    if len(documents) == 0:
        # np.concatenate cannot handle an empty list of batches.
        return np.empty((0, 0), dtype=np.float32)

    tokenizer, model = _load_model(model_path)

    all_embeddings = []
    for start in range(0, len(documents), batch_size):
        batch = documents[start : start + batch_size]
        encoded = [tokenizer.encode(d) for d in batch]
        input_ids = np.array([e.ids for e in encoded], dtype=np.int64)
        attention_mask = np.array([e.attention_mask for e in encoded], dtype=np.int64)
        onnx_input = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": np.zeros_like(input_ids),
        }
        last_hidden_state = model.run(None, onnx_input)[0]
        # Perform mean pooling with attention weighting: masked-out tokens
        # contribute nothing to the sum or to the token count.
        mask = np.expand_dims(attention_mask, -1)
        token_counts = np.clip(mask.sum(1), a_min=1e-9, a_max=None)
        embeddings = np.sum(last_hidden_state * mask, 1) / token_counts
        all_embeddings.append(_normalize(embeddings).astype(np.float32))
    return np.concatenate(all_embeddings)
6 changes: 1 addition & 5 deletions agentmemory/client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
import os
import json

import chromadb
import psycopg2
from sentence_transformers import SentenceTransformer
from dotenv import load_dotenv

from agentmemory.postgres import PostgresClient
Expand Down Expand Up @@ -36,4 +32,4 @@ def get_client(client_type=None, *args, **kwargs):
else:
client = chromadb.PersistentClient(path=STORAGE_PATH, *args, **kwargs)

return client
return client
17 changes: 10 additions & 7 deletions agentmemory/postgres.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import json
from pathlib import Path
import psycopg2

from agentmemory.check_model import check_model, infer_embeddings


class PostgresCollection:
def __init__(self, category, client):
Expand Down Expand Up @@ -148,18 +151,17 @@ class PostgresCategory:
def __init__(self, name):
self.name = name

# Default root directory for ONNX models: ".cache/onnx_models" inside the
# current user's home directory (the same expression as the default used in
# agentmemory/check_model.py).
default_model_path = str(Path.home() / ".cache" / "onnx_models")

class PostgresClient:
def __init__(self, connection_string):
def __init__(self, connection_string, model_name = "all-MiniLM-L6-v2", model_path = default_model_path):
self.connection = psycopg2.connect(connection_string)
self.cur = self.connection.cursor()
from pgvector.psycopg2 import register_vector

register_vector(self.cur) # Register PGVector functions

from sentence_transformers import SentenceTransformer

self.model = SentenceTransformer("all-MiniLM-L6-v2")
full_model_path = check_model(model_name=model_name, model_path=model_path)
self.model_path = full_model_path

def _table_name(self, category):
return f"memory_{category}"
Expand Down Expand Up @@ -241,7 +243,8 @@ def insert_memory(self, category, document, metadata={}, embedding=None, id=None
return self.cur.fetchone()[0]

def create_embedding(self, document):
return self.model.encode(document, normalize_embeddings=True)
embeddings = infer_embeddings([document], model_path=self.model_path)
return embeddings[0]

def add(self, category, documents, metadatas, ids):
self.ensure_table_exists(category)
Expand Down
9 changes: 5 additions & 4 deletions agentmemory/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# from .helpers import *
from .helpers import *
from .main import *
# from .persistence import *
# from .events import *
# from .clustering import *
from .persistence import *
from .events import *
# from .clustering import *
from .check_model import *
40 changes: 40 additions & 0 deletions agentmemory/tests/check_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import os
from pathlib import Path
import shutil
import tempfile
from agentmemory.check_model import check_model, infer_embeddings

def test_check_model():
    """check_model fetches and unpacks the model into a fresh cache directory."""
    model_name = "all-MiniLM-L6-v2"

    # TemporaryDirectory removes the downloaded files even when one of the
    # assertions below fails; the directory is new, so no stale model exists.
    with tempfile.TemporaryDirectory() as temp_dir:
        model_path = str(Path(temp_dir) / ".cache" / "onnx_models")

        result_path = check_model(model_name, model_path)

        assert os.path.exists(result_path)
        assert os.path.exists(os.path.join(result_path, "model.onnx"))
        assert os.path.exists(os.path.join(result_path, "tokenizer.json"))

import numpy as np

def test_infer_embeddings():
    """infer_embeddings yields one non-empty vector per input sentence."""
    sentences = ["This is a test sentence.", "Another test sentence."]

    # check_model() with its defaults supplies the directory of the ONNX model.
    onnx_dir = check_model()
    vectors = infer_embeddings(sentences, onnx_dir)

    assert isinstance(vectors, np.ndarray), "Output must be a numpy array"
    row_count = vectors.shape[0]
    assert row_count == len(sentences), "Number of embeddings must match number of input documents"
    width = vectors.shape[1]
    assert width > 0, "Embedding size must be greater than 0"
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
chromadb
agentlogger
psycopg2-binary
sentence_transformers
python-dotenv
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,12 @@
"chromadb",
"agentlogger",
"psycopg2-binary",
"sentence_transformers",
"python-dotenv"
]

setup(
name='agentmemory',
version='0.4.1',
version='0.4.2',
description='Easy-to-use memory for agents, document search, knowledge graphing and more.',
long_description=long_description, # added this line
long_description_content_type="text/markdown", # and this line
Expand Down

0 comments on commit 6cc0f19

Please sign in to comment.