Skip to content

Commit

Permalink
Reformat unit test file
Browse files Browse the repository at this point in the history
Fix check for existing username/user_id
Move user conflict check to internal method in base class
  • Loading branch information
NeonDaniel committed Oct 25, 2024
1 parent 1113ad1 commit 82ae3b7
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 19 deletions.
21 changes: 21 additions & 0 deletions neon_users_service/databases/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from abc import ABC, abstractmethod

from neon_users_service.exceptions import UserExistsError, UserNotExistsError
from neon_users_service.models import User


Expand Down Expand Up @@ -48,6 +50,25 @@ def delete_user(self, user_id: str) -> User:
@return: User object removed from the database
"""

def _check_user_exists(self, user: User) -> bool:
"""
Check if a user already exists with the given `username` or `user_id`.
"""
try:
# If username is defined, raise an exception
if self.read_user_by_username(user.username):
return True
except UserNotExistsError:
pass
try:
# If user ID is defined, it was likely passed to the `User` object
# instead of allowing the Factory to generate a new one.
if self.read_user_by_id(user.user_id):
return True
except UserNotExistsError:
pass
return False

def shutdown(self):
"""
Perform any cleanup when a database is no longer being used
Expand Down
9 changes: 2 additions & 7 deletions neon_users_service/databases/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,8 @@ def __init__(self, db_path: Optional[str] = None):
self.connection.commit()

def create_user(self, user: User) -> User:
try:
if self.read_user_by_id(user.user_id):
raise UserExistsError(user.user_id)
elif self.read_user_by_username(user.username):
raise UserExistsError(user.username)
except UserNotExistsError:
pass
if self._check_user_exists(user):
raise UserExistsError(user)

self.connection.execute(
f'''INSERT INTO users VALUES
Expand Down
24 changes: 12 additions & 12 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,30 +42,30 @@ def test_neon_user_config(self):
NeonUserConfig(units={"time": 13})
with self.assertRaises(ValidationError):
NeonUserConfig(location={"latitude": "test"})

# Valid type casting
config = NeonUserConfig(location={"latitude": "47.6765382",
"longitude": "-122.2070775"})
"longitude": "-122.2070775"})
self.assertIsInstance(config.location.latitude, float)
self.assertIsInstance(config.location.longitude, float)

def test_user_model(self):
user_kwargs=dict(username="test",
password_hash="test",
tokens=[{"description": "test",
"client_id": "test_id",
"expiration_timestamp": 0,
"refresh_token": "",
"last_used_timestamp": 0}])
user_kwargs = dict(username="test",
password_hash="test",
tokens=[{"description": "test",
"client_id": "test_id",
"expiration_timestamp": 0,
"refresh_token": "",
"last_used_timestamp": 0}])
default_user = User(**user_kwargs)
self.assertIsInstance(default_user.tokens[0], TokenConfig)
with self.assertRaises(ValidationError):
User(username="test")

with self.assertRaises(ValidationError):
User(username="test", password_hash="test",
tokens=[{"description": "test"}])

duplicate_user = User(**user_kwargs)
self.assertNotEqual(default_user, duplicate_user)
self.assertEqual(default_user.tokens, duplicate_user.tokens)
self.assertEqual(default_user.tokens, duplicate_user.tokens)

0 comments on commit 82ae3b7

Please sign in to comment.