From cd7631847411c5733dee679d8a1efe3835a6bc35 Mon Sep 17 00:00:00 2001 From: Marc-Antoine Parent Date: Sat, 23 Sep 2023 10:07:32 -0400 Subject: [PATCH 1/3] declare metadata early, and recall past metadata, to avoid ensure --- agentmemory/check_model.py | 2 +- agentmemory/chroma_client.py | 67 +++++++++++++++++++++++++++++++++++- agentmemory/client.py | 2 +- agentmemory/postgres.py | 66 ++++++++++++++++++++++------------- 4 files changed, 110 insertions(+), 27 deletions(-) diff --git a/agentmemory/check_model.py b/agentmemory/check_model.py index befb4fc..a507689 100644 --- a/agentmemory/check_model.py +++ b/agentmemory/check_model.py @@ -18,7 +18,7 @@ def _download(url: str, fname: Path, chunk_size: int = 1024) -> None: size = file.write(data) bar.update(size) -default_model_path = str(Path.home() / ".cache" / "onnx_models") +default_model_path = Path.home() / ".cache" / "onnx_models" def check_model(model_name = "all-MiniLM-L6-v2", model_path = default_model_path) -> str: DOWNLOAD_PATH = Path(model_path) / model_name diff --git a/agentmemory/chroma_client.py b/agentmemory/chroma_client.py index 722c201..635ba12 100644 --- a/agentmemory/chroma_client.py +++ b/agentmemory/chroma_client.py @@ -2,7 +2,72 @@ import chromadb +from .client import CollectionMemory, AgentMemory + +class ChromaCollectionMemory(CollectionMemory): + def __init__(self, collection, metadata=None) -> None: + self.collection = collection + + def count(self): + return self.collection.count() + + def add(self, ids, documents=None, metadatas=None, embeddings=None): + return self.collection.add(ids, documents, metadatas, embeddings) + + def get( + self, + ids=None, + where=None, + limit=None, + offset=None, + where_document=None, + include=["metadatas", "documents"], + ): + return self.collection.get(ids, where, limit, offset, where_document, include) + + def peek(self, limit=10): + return self.collection.peek(limit) + + def query( + self, + query_embeddings=None, + query_texts=None, + n_results=10, + where=None, + where_document=None, + include=["metadatas", "documents", "distances"], + ): + return self.collection.query(query_embeddings, query_texts, n_results, where, where_document, include) + + def update(self, ids, documents=None, metadatas=None, embeddings=None): + return self.collection.update(ids, embeddings, metadatas, documents) + + def upsert(self, ids, documents=None, metadatas=None, embeddings=None): + return self.collection.upsert(ids, embeddings, metadatas, documents) + + def delete(self, ids=None, where=None, where_document=None): + return self.collection.delete(ids, where, where_document) + + +class ChromaMemory(AgentMemory): + def __init__(self, path) -> None: + self.chroma = chromadb.PersistentClient(path=path) + + def get_or_create_collection(self, category, metadata=None) -> CollectionMemory: + memory = self.chroma.get_or_create_collection(category) + return ChromaCollectionMemory(memory, metadata) + + def get_collection(self, category) -> CollectionMemory: + memory = self.chroma.get_collection(category) + return ChromaCollectionMemory(memory) + + def delete_collection(self, category): + self.chroma.delete_collection(category) + + def list_collections(self): + return self.chroma.list_collections() + def create_client(): STORAGE_PATH = os.environ.get("STORAGE_PATH", "./memory") - return chromadb.PersistentClient(path=STORAGE_PATH) + return ChromaMemory(path=STORAGE_PATH) diff --git a/agentmemory/client.py b/agentmemory/client.py index 28dd0ef..4274891 100644 --- a/agentmemory/client.py +++ b/agentmemory/client.py @@ -105,7 +105,7 @@ class AgentCollection(): class AgentMemory(ABC): @abstractmethod - def get_or_create_collection(self, category) -> CollectionMemory: + def get_or_create_collection(self, category, metadata=None) -> CollectionMemory: raise NotImplementedError() @abstractmethod diff --git a/agentmemory/postgres.py b/agentmemory/postgres.py index 5cae418..c260916 100644 --- a/agentmemory/postgres.py +++ b/agentmemory/postgres.py @@ -1,13 +1,16 @@ +from __future__ import annotations from pathlib import Path import os +from functools import reduce import psycopg2 from .client import AgentMemory, CollectionMemory, AgentCollection -from agentmemory.check_model import check_model, infer_embeddings - +from .check_model import check_model, infer_embeddings +import agentlogger def parse_metadata(where): + where = where or {} metadata = {} for key, value in where.items(): if key[0] != "$": @@ -57,12 +60,21 @@ def get_sql_operator(operator): class PostgresCollection(CollectionMemory): - def __init__(self, category, client): + def __init__(self, category, client: PostgresClient, metadata=None): self.category = category self.client = client + self.metadata = metadata or {} + client.ensure_table_exists(category) + if metadata: + client._ensure_metadata_columns_exist(category, metadata) + + def _validate_metadata(self, metadata): + if new_columns := set(metadata) - set(self.metadata): + agentlogger.log(f"Undeclared metadata {', '.join(new_columns)} for collection {self.category}") + self.client._ensure_metadata_columns_exist(self.category, metadata) + self.metadata.update(metadata) def count(self): - self.client.ensure_table_exists(self.category) table_name = self.client._table_name(self.category) query = f"SELECT COUNT(*) FROM {table_name}" @@ -90,7 +102,6 @@ def get( include=["metadatas", "documents"], ): # TODO: Mirrors Chroma API, but could be optimized a lot - category = self.category table_name = self.client._table_name(category) conditions = [] @@ -117,8 +128,7 @@ def get( else: conditions.append(f"{key}=%s") params.append(str(value)) - - self.client._ensure_metadata_columns_exist(category, parse_metadata(where)) + self._validate_metadata(parse_metadata(where)) if ids: if not all(isinstance(i, str) or isinstance(i, int) for i in ids): @@ -196,7 +206,6 @@ def query( ) def update(self, ids, documents=None, metadatas=None, embeddings=None): - self.client.ensure_table_exists(self.category) # if embeddings is not None if embeddings is None: if documents is None: @@ -277,6 +286,7 @@ def __init__( full_model_path = check_model(model_name=model_name, model_path=model_path) self.model_path = full_model_path self.embedding_width = embedding_width + self.collections = {} def _table_name(self, category): return f"memory_{category}" @@ -324,20 +334,31 @@ def list_collections(self): if row[0].startswith("memory_") ] - def get_collection(self, category): - return PostgresCollection(category, self) + def get_collection(self, category, metadata=None): + if collection := self.collections.get(category): + if metadata: + collection._validate_metadata(metadata) + else: + # Should we check for table existence here? + collection = PostgresCollection(category, self, metadata) + return collection def delete_collection(self, category): table_name = self._table_name(category) self.cur.execute(f"DROP TABLE IF EXISTS {table_name}") self.connection.commit() + self.collections.pop(category, None) - def get_or_create_collection(self, category): - return PostgresCollection(category, self) + def get_or_create_collection(self, category, metadata=None): + if collection := self.collections.get(category): + if metadata: + collection._validate_metadata(metadata) + else: + collection = PostgresCollection(category, self, metadata) + return collection def insert_memory(self, category, document, metadata={}, embedding=None, id=None): - self.ensure_table_exists(category) - self._ensure_metadata_columns_exist(category, parse_metadata(metadata)) + collection = self.get_or_create_collection(category, metadata) table_name = self._table_name(category) if embedding is None: @@ -345,7 +366,7 @@ def insert_memory(self, category, document, metadata={}, embedding=None, id=None # if the id is None, get the length of the table by counting the number of rows in the category if id is None: - id = self.get_or_create_collection(category).count() + id = collection.count() # Extracting the keys and values from metadata to insert them into respective columns columns = ["id", "document", "embedding"] + list(metadata.keys()) @@ -365,12 +386,12 @@ def create_embedding(self, document): return embeddings[0] def add(self, category, documents, metadatas, ids): - self.ensure_table_exists(category) + metadatas = list(metadatas) # in case it's a consumable iterable + meta_keys = reduce(set.union, (set(parse_metadata(m)) for m in metadatas), set()) + collection = self.get_or_create_collection(category, meta_keys) table_name = self._table_name(category) with self.connection.cursor() as cur: for document, metadata, id_ in zip(documents, metadatas, ids): - self._ensure_metadata_columns_exist(category, parse_metadata(metadata)) - columns = ["id", "document", "embedding"] + list(metadata.keys()) placeholders = ["%s"] * len(columns) embedding = self.create_embedding(document) @@ -385,7 +406,7 @@ def add(self, category, documents, metadatas, ids): def query( self, category, query_texts, n_results=5, where=None, where_document=None ): - self.ensure_table_exists(category) + collection = self.get_or_create_collection(category, parse_metadata(where)) table_name = self._table_name(category) conditions = [] params = [] @@ -457,16 +478,13 @@ def query( return results def update(self, category, id_, document=None, metadata=None, embedding=None): - self.ensure_table_exists(category) + collection = self.get_or_create_collection(category, parse_metadata(metadata)) table_name = self._table_name(category) with self.connection.cursor() as cur: if document: if embedding is None: embedding = self.create_embedding(document) if metadata: - self._ensure_metadata_columns_exist( - category, parse_metadata(metadata) - ) columns = ["document=%s", "embedding=%s"] + [ f"{key}=%s" for key in metadata.keys() ] @@ -482,7 +500,7 @@ def update(self, category, id_, document=None, metadata=None, embedding=None): """ cur.execute(query, tuple(values) + (id_,)) elif metadata: - self._ensure_metadata_columns_exist(category, parse_metadata(metadata)) + collection._validate_metadata(metadata) columns = [f"{key}=%s" for key in metadata.keys()] values = list(metadata.values()) query = f""" From cbba54d7630c070bfd28716ae0fc907605ea7c77 Mon Sep 17 00:00:00 2001 From: Marc-Antoine Parent Date: Sat, 23 Sep 2023 11:54:20 -0400 Subject: [PATCH 2/3] refactor: extract common code --- agentmemory/postgres.py | 135 ++++++++++++---------------------------- 1 file changed, 41 insertions(+), 94 deletions(-) diff --git a/agentmemory/postgres.py b/agentmemory/postgres.py index c260916..dbc88d7 100644 --- a/agentmemory/postgres.py +++ b/agentmemory/postgres.py @@ -58,6 +58,42 @@ def get_sql_operator(operator): else: raise ValueError(f"Operator {operator} not supported") +def parse_conditions(where=None, where_document=None, ids=None): + conditions = [] + params = [] + if where_document is not None: + if where_document.get("$contains", None) is not None: + where_document = where_document["$contains"] + conditions.append("document LIKE %s") + params.append(f"%{where_document}%") + + if where: + for key, value in where.items(): + if key == "$and": + new_conditions, new_params = handle_and_condition(value) + conditions.extend(new_conditions) + params.extend(new_params) + elif key == "$or": + or_condition, new_params = handle_or_condition(value) + conditions.append(or_condition) + params.extend(new_params) + elif key == "$contains": + conditions.append(f"document LIKE %s") + params.append(f"%{value}%") + else: + conditions.append(f"{key}=%s") + params.append(str(value)) + + if ids: + if not all(isinstance(i, str) or isinstance(i, int) for i in ids): + raise Exception( + "ids must be a list of integers or strings representing integers" + ) + ids = [int(i) for i in ids] + conditions.append("id=ANY(%s::int[])") # Added explicit type casting + params.append(ids) + + return conditions, params class PostgresCollection(CollectionMemory): def __init__(self, category, client: PostgresClient, metadata=None): @@ -104,40 +140,8 @@ def get( # TODO: Mirrors Chroma API, but could be optimized a lot category = self.category table_name = self.client._table_name(category) - conditions = [] - params = [] - if where_document is not None: - if where_document.get("$contains", None) is not None: - where_document = where_document["$contains"] - conditions.append("document LIKE %s") - params.append(f"%{where_document}%") - - if where: - for key, value in where.items(): - if key == "$and": - new_conditions, new_params = handle_and_condition(value) - conditions.extend(new_conditions) - params.extend(new_params) - elif key == "$or": - or_condition, new_params = handle_or_condition(value) - conditions.append(or_condition) - params.extend(new_params) - elif key == "$contains": - conditions.append(f"document LIKE %s") - params.append(f"%{value}%") - else: - conditions.append(f"{key}=%s") - params.append(str(value)) - self._validate_metadata(parse_metadata(where)) - - if ids: - if not all(isinstance(i, str) or isinstance(i, int) for i in ids): - raise Exception( - "ids must be a list of integers or strings representing integers" - ) - ids = [int(i) for i in ids] - conditions.append("id=ANY(%s)") - params.append(ids) + self._validate_metadata(parse_metadata(where)) + conditions, params = parse_conditions(where, where_document, ids) if limit is None: limit = 100 # or another default value @@ -223,40 +227,7 @@ def upsert(self, ids, documents=None, metadatas=None, embeddings=None): def delete(self, ids=None, where=None, where_document=None): table_name = self.client._table_name(self.category) - conditions = [] - params = [] - - if where_document is not None: - if where_document.get("$contains", None) is not None: - where_document = where_document["$contains"] - conditions.append("document LIKE %s") - params.append(f"%{where_document}%") - - if ids: - if not all(isinstance(i, str) or isinstance(i, int) for i in ids): - raise Exception( - "ids must be a list of integers or strings representing integers" - ) - ids = [int(i) for i in ids] - conditions.append("id=ANY(%s::int[])") # Added explicit type casting - params.append(ids) - - if where: - for key, value in where.items(): - if key == "$and": - new_conditions, new_params = handle_and_condition(value) - conditions.extend(new_conditions) - params.extend(new_params) - elif key == "$or": - or_condition, new_params = handle_or_condition(value) - conditions.append(or_condition) - params.extend(new_params) - elif key == "$contains": - conditions.append(f"document LIKE %s") - params.append(f"%{value}%") - else: - conditions.append(f"{key}=%s") - params.append(str(value)) + conditions, params = parse_conditions(where, where_document, ids) if conditions: query = f"DELETE FROM {table_name} WHERE " + " AND ".join(conditions) @@ -408,32 +379,8 @@ def query( ): collection = self.get_or_create_collection(category, parse_metadata(where)) table_name = self._table_name(category) - conditions = [] - params = [] - - # Check if where_document is given - if where_document: - if where_document.get("$contains", None) is not None: - where_document = where_document["$contains"] - conditions.append("document LIKE %s") - params.append(f"%{where_document}%") - - if where: - for key, value in where.items(): - if key == "$and": - new_conditions, new_params = handle_and_condition(value) - conditions.extend(new_conditions) - params.extend(new_params) - elif key == "$or": - or_condition, new_params = handle_or_condition(value) - conditions.append(or_condition) - params.extend(new_params) - elif key == "$contains": - conditions.append(f"document LIKE %s") - params.append(f"%{value}%") - else: - conditions.append(f"{key}=%s") - params.append(str(value)) + collection._validate_metadata(parse_metadata(where)) + conditions, params = parse_conditions(where, where_document) where_clause = " WHERE " + " AND ".join(conditions) if conditions else "" From 793a1ab0efd3fa77a60c0d56bb97365f0848a5a4 Mon Sep 17 00:00:00 2001 From: Marc-Antoine Parent Date: Sat, 23 Sep 2023 16:30:27 -0400 Subject: [PATCH 3/3] let postgres set the ID --- agentmemory/chroma_client.py | 6 ++++++ agentmemory/main.py | 8 +------- agentmemory/postgres.py | 26 +++++++++++--------------- 3 files changed, 18 insertions(+), 22 deletions(-) diff --git a/agentmemory/chroma_client.py b/agentmemory/chroma_client.py index 635ba12..5409f62 100644 --- a/agentmemory/chroma_client.py +++ b/agentmemory/chroma_client.py @@ -43,6 +43,12 @@ def update(self, ids, documents=None, metadatas=None, embeddings=None): return self.collection.update(ids, embeddings, metadatas, documents) def upsert(self, ids, documents=None, metadatas=None, embeddings=None): + # if no id is provided, generate one based on count of documents in collection + if any(id is None for id in ids): + origin = self.count() + # pad the id with zeros to make it 16 digits long + ids = [str(id_).zfill(16) for id_ in range(origin, origin+len(documents))] + return self.collection.upsert(ids, embeddings, metadatas, documents) def delete(self, ids=None, where=None, where_document=None): diff --git a/agentmemory/main.py b/agentmemory/main.py index 0217c43..0ca60e9 100644 --- a/agentmemory/main.py +++ b/agentmemory/main.py @@ -37,12 +37,6 @@ def create_memory(category, text, metadata={}, embedding=None, id=None): metadata["created_at"] = datetime.datetime.now().timestamp() metadata["updated_at"] = datetime.datetime.now().timestamp() - # if no id is provided, generate one based on count of documents in collection - if id is None: - id = str(memories.count()) - # pad the id with zeros to make it 16 digits long - id = id.zfill(16) - # for each field in metadata... # if the field is a boolean, convert it to a string for key, value in metadata.items(): @@ -52,7 +46,7 @@ def create_memory(category, text, metadata={}, embedding=None, id=None): # insert the document into the collection memories.upsert( - ids=[str(id)], + ids=[id], documents=[text], metadatas=[metadata], embeddings=[embedding] if embedding is not None else None, diff --git a/agentmemory/postgres.py b/agentmemory/postgres.py index dbc88d7..2b944b6 100644 --- a/agentmemory/postgres.py +++ b/agentmemory/postgres.py @@ -2,6 +2,7 @@ from pathlib import Path import os from functools import reduce +from itertools import repeat import psycopg2 @@ -118,15 +119,14 @@ def count(self): return self.client.cur.fetchone()[0] - def add(self, ids, documents=None, metadatas=None, embeddings=None): - if embeddings is None: - for id_, document, metadata in zip(ids, documents, metadatas): - self.client.insert_memory(self.category, document, metadata) - else: - for id_, document, metadata, emb in zip( - ids, documents, metadatas, embeddings - ): - self.client.insert_memory(self.category, document, metadata, emb) + def add(self, ids=None, documents=None, metadatas=None, embeddings=None): + # dropping ids, using database serial + embeddings = embeddings or repeat(None) + metadatas = metadatas or repeat({}) + for document, metadata, emb in zip( + documents, metadatas, embeddings + ): + self.client.insert_memory(self.category, document, metadata, emb) def get( self, @@ -335,14 +335,10 @@ def insert_memory(self, category, document, metadata={}, embedding=None, id=None if embedding is None: embedding = self.create_embedding(document) - # if the id is None, get the length of the table by counting the number of rows in the category - if id is None: - id = collection.count() - # Extracting the keys and values from metadata to insert them into respective columns - columns = ["id", "document", "embedding"] + list(metadata.keys()) + columns = ["document", "embedding"] + list(metadata.keys()) placeholders = ["%s"] * len(columns) - values = [id, document, embedding] + list(metadata.values()) + values = [document, embedding] + list(metadata.values()) query = f""" INSERT INTO {table_name} ({', '.join(columns)}) VALUES ({', '.join(placeholders)})