diff --git a/hivemind_core/database.py b/hivemind_core/database.py index dcef3ef..b4495c9 100644 --- a/hivemind_core/database.py +++ b/hivemind_core/database.py @@ -67,6 +67,10 @@ def __post_init__(self): """ Initializes the allowed types for the Client instance if not provided. """ + if not isinstance(self.client_id, int): + raise ValueError("client_id should be an integer") + if not isinstance(self.is_admin, bool): + raise ValueError("is_admin should be a boolean") self.allowed_types = self.allowed_types or ["recognizer_loop:utterance", "recognizer_loop:record_begin", "recognizer_loop:record_end", @@ -185,7 +189,6 @@ def add_item(self, client: Client) -> bool: """ pass - @abc.abstractmethod def delete_item(self, client: Client) -> bool: """ Delete a client from the database. @@ -196,7 +199,9 @@ def delete_item(self, client: Client) -> bool: Returns: True if the deletion was successful, False otherwise. """ - pass + # leave the deleted entry in db, do not allow reuse of client_id ! + client = Client(client_id=client.client_id, api_key="revoked") + return self.update_item(client) def update_item(self, client: Client) -> bool: """ @@ -287,21 +292,6 @@ def add_item(self, client: Client) -> bool: self._db[client.client_id] = client.__dict__ return True - def delete_item(self, client: Client) -> bool: - """ - Delete a client from the JSON database. - - Args: - client: The client to be deleted. - - Returns: - True if the deletion was successful, False otherwise. - """ - if client.client_id in self._db: - self._db.pop(client.client_id) - return True - return False - def search_by_value(self, key: str, val: Union[str, bool, int, float]) -> List[Client]: """ Search for clients by a specific key-value pair in the JSON database. @@ -435,24 +425,6 @@ def add_item(self, client: Client) -> bool: LOG.error(f"Failed to add client to SQLite: {e}") return False - def delete_item(self, client: Client) -> bool: - """ - Delete a client from the SQLite database. - - Args: - client: The client to be deleted. - - Returns: - True if the deletion was successful, False otherwise. - """ - try: - with self.conn: - self.conn.execute("DELETE FROM clients WHERE client_id = ?", (client.client_id,)) - return True - except sqlite3.Error as e: - LOG.error(f"Failed to delete client from SQLite: {e}") - return False - def search_by_value(self, key: str, val: Union[str, bool, int, float]) -> List[Client]: """ Search for clients by a specific key-value pair in the SQLite database. @@ -549,31 +521,16 @@ def add_item(self, client: Client) -> bool: """ item_key = f"client:{client.client_id}" serialized_data: str = client.serialize() - try: # Store data in Redis self.redis.set(item_key, serialized_data) - return True - except Exception as e: - LOG.error(f"Failed to add client to Redis/RediSearch: {e}") - return False - def delete_item(self, client: Client) -> bool: - """ - Delete a client from Redis and RediSearch. - - Args: - client: The client to be deleted. - - Returns: - True if the deletion was successful, False otherwise. - """ - item_key = f"client:{client.client_id}" - try: - self.redis.delete(item_key) + # Maintain indices for common search fields + self.redis.sadd(f"client:index:name:{client.name}", client.client_id) + self.redis.sadd(f"client:index:api_key:{client.api_key}", client.client_id) return True except Exception as e: - LOG.error(f"Failed to delete client from Redis: {e}") + LOG.error(f"Failed to add client to Redis/RediSearch: {e}") return False def search_by_value(self, key: str, val: Union[str, bool, int, float]) -> List[Client]: @@ -587,8 +544,18 @@ def search_by_value(self, key: str, val: Union[str, bool, int, float]) -> List[C Returns: A list of clients that match the search criteria. """ + # Use index if available + if key in ['name', 'api_key']: + client_ids = self.redis.smembers(f"client:index:{key}:{val}") + res = [cast2client(self.redis.get(f"client:{cid}")) + for cid in client_ids] + res = [c for c in res if c.api_key != "revoked"] + return res + res = [] for client_id in self.redis.scan_iter(f"client:*"): + if client_id.startswith("client:index:"): + continue client_data = self.redis.get(client_id) client = cast2client(client_data) if hasattr(client, key) and getattr(client, key) == val: @@ -602,7 +569,7 @@ def __len__(self) -> int: Returns: The number of clients in the database. """ - return len(self.redis.keys("client:*")) + return int(len(self.redis.keys("client:*")) / 3) # because of index entries for name/key fastsearch def __iter__(self) -> Iterable['Client']: """ @@ -612,7 +579,12 @@ def __iter__(self) -> Iterable['Client']: An iterator over the clients in the database. """ for client_id in self.redis.scan_iter(f"client:*"): - yield cast2client(self.redis.get(client_id)) + if client_id.startswith("client:index:"): + continue + try: + yield cast2client(self.redis.get(client_id)) + except Exception as e: + LOG.error(f"Failed to get client '{client_id}' : {e}") class ClientDatabase: diff --git a/test/unittests/test_db.py b/test/unittests/test_db.py index 12ac3e6..df95b41 100644 --- a/test/unittests/test_db.py +++ b/test/unittests/test_db.py @@ -183,5 +183,51 @@ def test_get_clients_by_name(self): self.assertEqual(clients[0].name, "Test Client") +class TestClientNegativeCases(unittest.TestCase): + + def test_missing_required_fields(self): + # Missing the "client_id" field, which is required by the Client dataclass + client_data = { + "api_key": "test_api_key", + "name": "Test Client", + "description": "A test client", + "is_admin": False + } + with self.assertRaises(TypeError): + Client(**client_data) + + def test_invalid_field_type_for_client_id(self): + # Providing a string instead of an integer for "client_id" + client_data = { + "client_id": "invalid_id", + "api_key": "test_api_key", + "name": "Test Client", + "description": "A test client", + "is_admin": False + } + with self.assertRaises(ValueError): + # If needed, adjust logic in your code to raise ValueError instead of TypeError + Client(**client_data) + + def test_invalid_field_type_for_is_admin(self): + # Providing a string instead of a boolean for "is_admin" + client_data = { + "client_id": 1, + "api_key": "test_api_key", + "name": "Test Client", + "description": "A test client", + "is_admin": "not_boolean" + } + with self.assertRaises(ValueError): + # If needed, adjust logic in your code to raise ValueError instead of TypeError + Client(**client_data) + + def test_deserialize_with_incorrect_json_structure(self): + # Passing an invalid JSON string missing required fields + invalid_json_str = '{"client_id": 1}' + with self.assertRaises(TypeError): + # Or another appropriate exception if your parsing logic differs + Client.deserialize(invalid_json_str) + if __name__ == '__main__': unittest.main()