Skip to content
This repository has been archived by the owner on Dec 22, 2024. It is now read-only.

Commit

Permalink
improve redis search
Browse files Browse the repository at this point in the history
  • Loading branch information
JarbasAl committed Dec 20, 2024
1 parent 6c4f534 commit 255a865
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 56 deletions.
84 changes: 28 additions & 56 deletions hivemind_core/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ def __post_init__(self):
"""
Initializes the allowed types for the Client instance if not provided.
"""
if not isinstance(self.client_id, int):
raise ValueError("client_id should be an integer")
if not isinstance(self.is_admin, bool):
raise ValueError("is_admin should be a boolean")
self.allowed_types = self.allowed_types or ["recognizer_loop:utterance",
"recognizer_loop:record_begin",
"recognizer_loop:record_end",
Expand Down Expand Up @@ -185,7 +189,6 @@ def add_item(self, client: Client) -> bool:
"""
pass

@abc.abstractmethod
def delete_item(self, client: Client) -> bool:
"""
Delete a client from the database.
Expand All @@ -196,7 +199,9 @@ def delete_item(self, client: Client) -> bool:
Returns:
True if the deletion was successful, False otherwise.
"""
pass
# leave the deleted entry in db, do not allow reuse of client_id !
client = Client(client_id=client.client_id, api_key="revoked")
return self.update_item(client)

def update_item(self, client: Client) -> bool:
"""
Expand Down Expand Up @@ -287,21 +292,6 @@ def add_item(self, client: Client) -> bool:
self._db[client.client_id] = client.__dict__
return True

def delete_item(self, client: Client) -> bool:
"""
Delete a client from the JSON database.
Args:
client: The client to be deleted.
Returns:
True if the deletion was successful, False otherwise.
"""
if client.client_id in self._db:
self._db.pop(client.client_id)
return True
return False

def search_by_value(self, key: str, val: Union[str, bool, int, float]) -> List[Client]:
"""
Search for clients by a specific key-value pair in the JSON database.
Expand Down Expand Up @@ -435,24 +425,6 @@ def add_item(self, client: Client) -> bool:
LOG.error(f"Failed to add client to SQLite: {e}")
return False

def delete_item(self, client: Client) -> bool:
"""
Delete a client from the SQLite database.
Args:
client: The client to be deleted.
Returns:
True if the deletion was successful, False otherwise.
"""
try:
with self.conn:
self.conn.execute("DELETE FROM clients WHERE client_id = ?", (client.client_id,))
return True
except sqlite3.Error as e:
LOG.error(f"Failed to delete client from SQLite: {e}")
return False

def search_by_value(self, key: str, val: Union[str, bool, int, float]) -> List[Client]:
"""
Search for clients by a specific key-value pair in the SQLite database.
Expand Down Expand Up @@ -549,31 +521,16 @@ def add_item(self, client: Client) -> bool:
"""
item_key = f"client:{client.client_id}"
serialized_data: str = client.serialize()

try:
# Store data in Redis
self.redis.set(item_key, serialized_data)
return True
except Exception as e:
LOG.error(f"Failed to add client to Redis/RediSearch: {e}")
return False

def delete_item(self, client: Client) -> bool:
"""
Delete a client from Redis and RediSearch.
Args:
client: The client to be deleted.
Returns:
True if the deletion was successful, False otherwise.
"""
item_key = f"client:{client.client_id}"
try:
self.redis.delete(item_key)
# Maintain indices for common search fields
self.redis.sadd(f"client:index:name:{client.name}", client.client_id)
self.redis.sadd(f"client:index:api_key:{client.api_key}", client.client_id)
return True
except Exception as e:
LOG.error(f"Failed to delete client from Redis: {e}")
LOG.error(f"Failed to add client to Redis/RediSearch: {e}")
return False

def search_by_value(self, key: str, val: Union[str, bool, int, float]) -> List[Client]:
Expand All @@ -587,8 +544,18 @@ def search_by_value(self, key: str, val: Union[str, bool, int, float]) -> List[C
Returns:
A list of clients that match the search criteria.
"""
# Use index if available
if key in ['name', 'api_key']:
client_ids = self.redis.smembers(f"client:index:{key}:{val}")
res = [cast2client(self.redis.get(f"client:{cid}"))
for cid in client_ids]
res = [c for c in res if c.api_key != "revoked"]
return res

res = []
for client_id in self.redis.scan_iter(f"client:*"):
if client_id.startswith("client:index:"):
continue
client_data = self.redis.get(client_id)
client = cast2client(client_data)
if hasattr(client, key) and getattr(client, key) == val:
Expand All @@ -602,7 +569,7 @@ def __len__(self) -> int:
Returns:
The number of clients in the database.
"""
return len(self.redis.keys("client:*"))
return int(len(self.redis.keys("client:*")) / 3) # because of index entries for name/key fastsearch

def __iter__(self) -> Iterable['Client']:
"""
Expand All @@ -612,7 +579,12 @@ def __iter__(self) -> Iterable['Client']:
An iterator over the clients in the database.
"""
for client_id in self.redis.scan_iter(f"client:*"):
yield cast2client(self.redis.get(client_id))
if client_id.startswith("client:index:"):
continue
try:
yield cast2client(self.redis.get(client_id))
except Exception as e:
LOG.error(f"Failed to get client '{client_id}' : {e}")


class ClientDatabase:
Expand Down
46 changes: 46 additions & 0 deletions test/unittests/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,5 +183,51 @@ def test_get_clients_by_name(self):
self.assertEqual(clients[0].name, "Test Client")


class TestClientNegativeCases(unittest.TestCase):

def test_missing_required_fields(self):
# Missing the "client_id" field, which is required by the Client dataclass
client_data = {
"api_key": "test_api_key",
"name": "Test Client",
"description": "A test client",
"is_admin": False
}
with self.assertRaises(TypeError):
Client(**client_data)

def test_invalid_field_type_for_client_id(self):
# Providing a string instead of an integer for "client_id"
client_data = {
"client_id": "invalid_id",
"api_key": "test_api_key",
"name": "Test Client",
"description": "A test client",
"is_admin": False
}
with self.assertRaises(ValueError):
# If needed, adjust logic in your code to raise ValueError instead of TypeError
Client(**client_data)

def test_invalid_field_type_for_is_admin(self):
# Providing a string instead of a boolean for "is_admin"
client_data = {
"client_id": 1,
"api_key": "test_api_key",
"name": "Test Client",
"description": "A test client",
"is_admin": "not_boolean"
}
with self.assertRaises(ValueError):
# If needed, adjust logic in your code to raise ValueError instead of TypeError
Client(**client_data)

def test_deserialize_with_incorrect_json_structure(self):
# Passing an invalid JSON string missing required fields
invalid_json_str = '{"client_id": 1}'
with self.assertRaises(TypeError):
# Or another appropriate exception if your parsing logic differs
Client.deserialize(invalid_json_str)

if __name__ == '__main__':
unittest.main()

0 comments on commit 255a865

Please sign in to comment.