Skip to content
This repository has been archived by the owner on Dec 22, 2024. It is now read-only.

Commit

Permalink
feat/redis_db
Browse files Browse the repository at this point in the history
allow several database backends
  • Loading branch information
JarbasAl committed Jun 30, 2024
1 parent 0a1b723 commit 052c70b
Showing 1 changed file with 155 additions and 103 deletions.
258 changes: 155 additions & 103 deletions hivemind_core/database.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import abc
import json
from functools import wraps
from typing import List, Dict, Union, Any, Optional, Iterable
Expand Down Expand Up @@ -50,20 +51,20 @@ def call_function(*args, **kwargs):

class Client:
def __init__(
self,
client_id: int,
api_key: str,
name: str = "",
description: str = "",
is_admin: bool = False,
last_seen: float = -1,
blacklist: Optional[Dict[str, List[str]]] = None,
allowed_types: Optional[List[str]] = None,
crypto_key: Optional[str] = None,
password: Optional[str] = None,
can_broadcast: bool = True,
can_escalate: bool = True,
can_propagate: bool = True,
self,
client_id: int,
api_key: str,
name: str = "",
description: str = "",
is_admin: bool = False,
last_seen: float = -1,
blacklist: Optional[Dict[str, List[str]]] = None,
allowed_types: Optional[List[str]] = None,
crypto_key: Optional[str] = None,
password: Optional[str] = None,
can_broadcast: bool = True,
can_escalate: bool = True,
can_propagate: bool = True,
):
self.client_id = client_id
self.description = description
Expand All @@ -74,10 +75,10 @@ def __init__(
self.crypto_key = crypto_key
self.password = password
self.blacklist = blacklist or {"messages": [], "skills": [], "intents": []}
self.allowed_types = allowed_types or ["recognizer_loop:utterance",
"recognizer_loop:record_begin",
"recognizer_loop:record_end",
"recognizer_loop:audio_output_start",
self.allowed_types = allowed_types or ["recognizer_loop:utterance",
"recognizer_loop:record_begin",
"recognizer_loop:record_end",
"recognizer_loop:audio_output_start",
"recognizer_loop:audio_output_end",
"ovos.common_play.SEI.get.response"]
if "recognizer_loop:utterance" not in self.allowed_types:
Expand Down Expand Up @@ -106,110 +107,161 @@ def __repr__(self) -> str:
return str(self.__dict__)


class ClientDatabase(JsonDatabaseXDG):
class AbstractDB:
@abc.abstractmethod
def get_item_id(self, client: Client) -> str:
pass

@abc.abstractmethod
def add_item(self, client: Client):
pass

@abc.abstractmethod
def update_item(self, item_id: str, client: Client):
pass

def delete_item(self, client: Client):
item_id = self.get_item_id(client)
self.update_item(item_id, Client(-1, api_key="revoked"))

@abc.abstractmethod
def search_by_value(self, key: str, val: str):
pass

@abc.abstractmethod
def __len__(self):
return 0

@abc.abstractmethod
def commit(self):
pass


class JsonDB(AbstractDB):
def __init__(self):
super().__init__("clients", subfolder="hivemind")
self._db = JsonDatabaseXDG(name="clients", subfolder="hivemind")

def update_timestamp(self, key: str, timestamp: float) -> bool:
user = self.get_client_by_api_key(key)
if user is None:
return False
item_id = self.get_item_id(user)
user["last_seen"] = timestamp
self.update_item(item_id, user)
return True
def get_item_id(self, client: Client) -> str:
client = client.__dict__
return self._db.get_item_id(client)

def add_item(self, client: Client):
client = client.__dict__
self._db.add_item(client)

def update_item(self, item_id: str, client: Client):
client = client.__dict__
self._db.update_item(item_id, client)

@cast_to_client_obj()
def search_by_value(self, key: str, val: str) -> List[Client]:
return self._db.search_by_value(key, val)

def __len__(self):
return len(self._db)

def commit(self):
self.commit()


class RedisDB(AbstractDB):
def __init__(self):
try:
import redis
from redis.commands.json.path import Path
from redis.commands.search.query import Query
except ImportError:
LOG.error("'pip install redis[hiredis]'")
raise
self._Path = Path
self._Query = Query
# TODO - host/port from config
self.r = redis.Redis(host="localhost", port=6379)
self.rs = self.r.ft("idx:clients")

def get_item_id(self, client: Client) -> str:
pass # TODO

def add_item(self, client: Client):
client_id = len(self) + 1
self.r.json().set(f"client:{client_id}",
self._Path.root_path(),
client.__dict__)

def update_item(self, item_id: str, client: Client):
self.r.json().set(f"client:{item_id}", self._Path.root_path(), client)

@cast_to_client_obj()
def search_by_value(self, key: str, val: str) -> List[Client]:
search = self.rs.search(self._Query(f"@{key}:{val}"))
return [json.loads(doc.json) for doc in search.docs]

@cast_to_client_obj()
def get_all_clients(
self, sort_by: str = "id", asc: bool = True
) -> Optional[List[Client]]:
clients: List = []
search = self.rs.search(self._Query("@id:[0 +inf]").sort_by(sort_by, asc))
for client in search.docs:
clients.append(json.loads(client.json))
return clients

def __len__(self):
return len(self.get_all_clients())

def commit(self):
pass


class ClientDatabase:
valid_backends = ["json", "redis"]

def __init__(self, backend="json"):
if backend not in self.valid_backends:
raise NotImplementedError(f"{backend} not supported, choose one of {self.valid_backends}")

if backend == "json":
self.db = JsonDB()
else:
self.db = RedisDB()

def get_item_id(self, client: Client):
return self.db.get_item_id(client)

def delete_client(self, key: str) -> bool:
user = self.get_client_by_api_key(key)
if user:
item_id = self.get_item_id(user)
self.update_item(item_id, Client(-1, api_key="revoked"))
self.db.delete_item(user)
return True
return False

def change_key(self, old_key: str, new_key: str) -> bool:
user = self.get_client_by_api_key(old_key)
if user is None:
return False
item_id = self.get_item_id(user)
user["api_key"] = new_key
self.update_item(item_id, user)
return True

def change_crypto_key(self, api_key: str, new_key: str) -> bool:
user = self.get_client_by_api_key(api_key)
if user is None:
return False
item_id = self.get_item_id(user)
user["crypto_key"] = new_key
self.update_item(item_id, user)
return True

def get_crypto_key(self, api_key: str) -> Optional[str]:
user = self.get_client_by_api_key(api_key)
if user is None:
return None
return user["crypto_key"]

def get_password(self, api_key: str) -> Optional[str]:
user = self.get_client_by_api_key(api_key)
if user is None:
return None
return user["password"]

def change_name(self, new_name: str, key: str) -> bool:
user = self.get_client_by_api_key(key)
if user is None:
return False
item_id = self.get_item_id(user)
user["name"] = new_name
self.update_item(item_id, user)
return True

def change_blacklist(self, blacklist: Union[str, Dict[str, Any]], key: str) -> bool:
if isinstance(blacklist, dict):
blacklist = json.dumps(blacklist)
user = self.get_client_by_api_key(key)
if user is None:
return False
item_id = self.get_item_id(user)
user["blacklist"] = blacklist
self.update_item(item_id, user)
return True

def get_blacklist_by_api_key(self, api_key: str):
search = self.search_by_value("api_key", api_key)
if len(search):
return search[0]["blacklist"]
return None

@cast_to_client_obj()
def get_client_by_api_key(self, api_key: str) -> Optional[Client]:
search = self.search_by_value("api_key", api_key)
search = self.db.search_by_value("api_key", api_key)
if len(search):
return search[0]
return None

@cast_to_client_obj()
def get_clients_by_name(self, name: str) -> List[Client]:
return self.search_by_value("name", name)
return self.db.search_by_value("name", name)

@cast_to_client_obj()
def add_client(
self,
name: str,
key: str = "",
admin: bool = False,
blacklist: Optional[Dict[str, Any]] = None,
allowed_types: Optional[List[str]] = None,
crypto_key: Optional[str] = None,
password: Optional[str] = None,
self,
name: str,
key: str = "",
admin: bool = False,
blacklist: Optional[Dict[str, Any]] = None,
allowed_types: Optional[List[str]] = None,
crypto_key: Optional[str] = None,
password: Optional[str] = None,
) -> Client:
user = self.get_client_by_api_key(key)
item_id = self.get_item_id(user)
if crypto_key is not None:
crypto_key = crypto_key[:16]
if item_id >= 0:
user = self.get_client_by_api_key(key)
if user:
item_id = self.db.get_item_id(user)
if name:
user["name"] = name
if blacklist:
Expand All @@ -222,7 +274,7 @@ def add_client(
user["crypto_key"] = crypto_key
if password:
user["password"] = password
self.update_item(item_id, user)
self.db.update_item(item_id, user)
else:
user = Client(
api_key=key,
Expand All @@ -234,11 +286,11 @@ def add_client(
password=password,
allowed_types=allowed_types,
)
self.add_item(user)
self.db.add_item(user)
return user

def total_clients(self) -> int:
return len(self)
return len(self.db)

def __enter__(self):
"""Context handler"""
Expand All @@ -247,6 +299,6 @@ def __enter__(self):
def __exit__(self, _type, value, traceback):
"""Commits changes and Closes the session"""
try:
self.commit()
self.db.commit()
except Exception as e:
LOG.error(e)

0 comments on commit 052c70b

Please sign in to comment.