Skip to content

Commit

Permalink
reduced duplication
Browse files Browse the repository at this point in the history
  • Loading branch information
pcrespov committed Dec 18, 2024
1 parent 82a5348 commit c94a135
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -106,20 +106,25 @@ def _check_group_permissions(
async def _get_group_and_access_rights_or_raise(
conn: AsyncConnection,
*,
user_id: UserID,
caller_id: UserID,
group_id: GroupID,
permission: Literal["read", "write", "delete"] | None,
) -> Row:
result = await conn.execute(
sa.select(
*_GROUP_COLUMNS,
user_to_groups.c.access_rights,
)
.select_from(groups.join(user_to_groups, user_to_groups.c.gid == groups.c.gid))
.where((user_to_groups.c.uid == user_id) & (user_to_groups.c.gid == group_id))
.where((user_to_groups.c.uid == caller_id) & (user_to_groups.c.gid == group_id))
)
row = result.first()
if not row:
raise GroupNotFoundError(gid=group_id)

if permission:
_check_group_permissions(row, caller_id, group_id, permission)

return row


Expand Down Expand Up @@ -270,10 +275,8 @@ async def get_user_group(
"""
async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn:
row = await _get_group_and_access_rights_or_raise(
conn, user_id=user_id, group_id=group_id
conn, caller_id=user_id, group_id=group_id, permission="read"
)
_check_group_permissions(row, user_id, group_id, "read")

group, access_rights = _to_group_info_tuple(row)
return group, access_rights

Expand All @@ -291,7 +294,10 @@ async def get_product_group_for_user(
"""
async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn:
row = await _get_group_and_access_rights_or_raise(
conn, user_id=user_id, group_id=product_gid
conn,
caller_id=user_id,
group_id=product_gid,
permission=None,
)
group, access_rights = _to_group_info_tuple(row)
return group, access_rights
Expand All @@ -310,7 +316,9 @@ async def create_standard_group(

async with transaction_context(get_asyncpg_engine(app), connection) as conn:
user = await conn.scalar(
sa.select(users.c.primary_gid).where(users.c.id == user_id)
sa.select(
users.c.primary_gid,
).where(users.c.id == user_id)
)
if not user:
raise UserNotFoundError(user_id=user_id)
Expand Down Expand Up @@ -356,17 +364,17 @@ async def update_standard_group(

async with transaction_context(get_asyncpg_engine(app), connection) as conn:
row = await _get_group_and_access_rights_or_raise(
conn, user_id=user_id, group_id=group_id
conn, caller_id=user_id, group_id=group_id, permission="write"
)
assert row.gid == group_id # nosec
_check_group_permissions(row, user_id, group_id, "write")
# NOTE: update does not include access-rights
access_rights = AccessRightsDict(**row.access_rights) # type: ignore[typeddict-item]

result = await conn.stream(
# pylint: disable=no-value-for-parameter
groups.update()
.values(**values)
.where((groups.c.gid == row.gid) & (groups.c.type == GroupType.STANDARD))
.where((groups.c.gid == group_id) & (groups.c.type == GroupType.STANDARD))
.returning(*_GROUP_COLUMNS)
)
row = await result.fetchone()
Expand All @@ -384,15 +392,14 @@ async def delete_standard_group(
group_id: GroupID,
) -> None:
async with transaction_context(get_asyncpg_engine(app), connection) as conn:
group = await _get_group_and_access_rights_or_raise(
conn, user_id=user_id, group_id=group_id
await _get_group_and_access_rights_or_raise(
conn, caller_id=user_id, group_id=group_id, permission="delete"
)
_check_group_permissions(group, user_id, group_id, "delete")

await conn.execute(
# pylint: disable=no-value-for-parameter
groups.delete().where(
(groups.c.gid == group.gid) & (groups.c.type == GroupType.STANDARD)
(groups.c.gid == group_id) & (groups.c.type == GroupType.STANDARD)
)
)

Expand All @@ -406,7 +413,7 @@ async def get_user_from_email(
app: web.Application,
connection: AsyncConnection | None = None,
*,
caller_user_id: UserID,
caller_id: UserID,
email: str,
) -> Row:
"""
Expand All @@ -418,7 +425,7 @@ async def get_user_from_email(
result = await conn.stream(
sa.select(users.c.id).where(
(users.c.email == email)
& is_public(users.c.privacy_hide_email, caller_id=caller_user_id)
& is_public(users.c.privacy_hide_email, caller_id=caller_id)
)
)
user = await result.fetchone()
Expand Down Expand Up @@ -463,11 +470,14 @@ def _group_user_cols(caller_id: UserID):


async def _get_user_in_group_or_raise(
conn: AsyncConnection, *, caller_user_id, group_id: GroupID, user_id: int
conn: AsyncConnection, *, caller_id: UserID, group_id: GroupID, user_id: UserID
) -> Row:
# now get the user
# NOTE: that the caller_id might be different that the target user_id
result = await conn.stream(
sa.select(*_group_user_cols(caller_user_id), user_to_groups.c.access_rights)
sa.select(
*_group_user_cols(caller_id),
user_to_groups.c.access_rights,
)
.select_from(
users.join(user_to_groups, users.c.id == user_to_groups.c.uid),
)
Expand All @@ -483,7 +493,7 @@ async def list_users_in_group(
app: web.Application,
connection: AsyncConnection | None = None,
*,
user_id: UserID,
caller_id: UserID,
group_id: GroupID,
) -> list[GroupMember]:
async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn:
Expand All @@ -501,7 +511,7 @@ async def list_users_in_group(
.where(
(user_to_groups.c.gid == group_id)
& (
(user_to_groups.c.uid == user_id)
(user_to_groups.c.uid == caller_id)
| (
(groups.c.type == GroupType.PRIMARY)
& users.c.role.in_([r for r in UserRole if r > UserRole.GUEST])
Expand All @@ -518,14 +528,14 @@ async def list_users_in_group(
# Drop access-rights if primary group
if group_row.type == GroupType.PRIMARY:
query = sa.select(
*_group_user_cols(user_id),
*_group_user_cols(caller_id),
)
else:
_check_group_permissions(
group_row, caller_id=user_id, group_id=group_id, permission="read"
group_row, caller_id=caller_id, group_id=group_id, permission="read"
)
query = sa.select(
*_group_user_cols(user_id),
*_group_user_cols(caller_id),
user_to_groups.c.access_rights,
)

Expand All @@ -545,21 +555,20 @@ async def get_user_in_group(
app: web.Application,
connection: AsyncConnection | None = None,
*,
user_id: UserID,
caller_id: UserID,
group_id: GroupID,
the_user_id_in_group: int,
) -> GroupMember:
async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn:
# first check if the group exists
group = await _get_group_and_access_rights_or_raise(
conn, user_id=user_id, group_id=group_id
await _get_group_and_access_rights_or_raise(
conn, caller_id=caller_id, group_id=group_id, permission="read"
)
_check_group_permissions(group, user_id, group_id, "read")

# get the user with its permissions
the_user = await _get_user_in_group_or_raise(
conn,
caller_user_id=user_id,
caller_id=caller_id,
group_id=group_id,
user_id=the_user_id_in_group,
)
Expand All @@ -570,7 +579,7 @@ async def update_user_in_group(
app: web.Application,
connection: AsyncConnection | None = None,
*,
user_id: UserID,
caller_id: UserID,
group_id: GroupID,
the_user_id_in_group: UserID,
access_rights: AccessRightsDict,
Expand All @@ -582,15 +591,14 @@ async def update_user_in_group(
async with transaction_context(get_asyncpg_engine(app), connection) as conn:

# first check if the group exists
group = await _get_group_and_access_rights_or_raise(
conn, user_id=user_id, group_id=group_id
await _get_group_and_access_rights_or_raise(
conn, caller_id=caller_id, group_id=group_id, permission="write"
)
_check_group_permissions(group, user_id, group_id, "write")

# now check the user exists
the_user = await _get_user_in_group_or_raise(
conn,
caller_user_id=user_id,
caller_id=caller_id,
group_id=group_id,
user_id=the_user_id_in_group,
)
Expand All @@ -617,21 +625,20 @@ async def delete_user_from_group(
app: web.Application,
connection: AsyncConnection | None = None,
*,
user_id: UserID,
caller_id: UserID,
group_id: GroupID,
the_user_id_in_group: UserID,
) -> None:
async with transaction_context(get_asyncpg_engine(app), connection) as conn:
# first check if the group exists
group = await _get_group_and_access_rights_or_raise(
conn, user_id=user_id, group_id=group_id
await _get_group_and_access_rights_or_raise(
conn, caller_id=caller_id, group_id=group_id, permission="write"
)
_check_group_permissions(group, user_id, group_id, "write")

# check the user exists
await _get_user_in_group_or_raise(
conn,
caller_user_id=user_id,
caller_id=caller_id,
group_id=group_id,
user_id=the_user_id_in_group,
)
Expand Down Expand Up @@ -675,7 +682,7 @@ async def add_new_user_in_group(
app: web.Application,
connection: AsyncConnection | None = None,
*,
user_id: UserID,
caller_id: UserID,
group_id: GroupID,
# either user_id or user_name
new_user_id: UserID | None = None,
Expand All @@ -687,10 +694,9 @@ async def add_new_user_in_group(
"""
async with transaction_context(get_asyncpg_engine(app), connection) as conn:
# first check if the group exists
group = await _get_group_and_access_rights_or_raise(
conn, user_id=user_id, group_id=group_id
await _get_group_and_access_rights_or_raise(
conn, caller_id=caller_id, group_id=group_id, permission="write"
)
_check_group_permissions(group, user_id, group_id, "write")

query = sa.select(users.c.id)
if new_user_id is not None:
Expand All @@ -715,20 +721,23 @@ async def add_new_user_in_group(
await conn.execute(
# pylint: disable=no-value-for-parameter
user_to_groups.insert().values(
uid=new_user_id, gid=group.gid, access_rights=user_access_rights
uid=new_user_id, gid=group_id, access_rights=user_access_rights
)
)
except UniqueViolation as exc:
raise UserAlreadyInGroupError(
uid=new_user_id,
gid=group_id,
user_id=user_id,
user_id=caller_id,
access_rights=access_rights,
) from exc


async def auto_add_user_to_groups(
app: web.Application, connection: AsyncConnection | None = None, *, user: dict
app: web.Application,
connection: AsyncConnection | None = None,
*,
user: dict,
) -> None:

user_id: UserID = user["id"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ async def list_group_members(
app: web.Application, user_id: UserID, group_id: GroupID
) -> list[GroupMember]:
return await _groups_repository.list_users_in_group(
app, user_id=user_id, group_id=group_id
app, caller_id=user_id, group_id=group_id
)


Expand All @@ -171,7 +171,7 @@ async def get_group_member(

return await _groups_repository.get_user_in_group(
app,
user_id=user_id,
caller_id=user_id,
group_id=group_id,
the_user_id_in_group=the_user_id_in_group,
)
Expand All @@ -186,7 +186,7 @@ async def update_group_member(
) -> GroupMember:
return await _groups_repository.update_user_in_group(
app,
user_id=user_id,
caller_id=user_id,
group_id=group_id,
the_user_id_in_group=the_user_id_in_group,
access_rights=access_rights,
Expand All @@ -201,7 +201,7 @@ async def delete_group_member(
) -> None:
return await _groups_repository.delete_user_from_group(
app,
user_id=user_id,
caller_id=user_id,
group_id=group_id,
the_user_id_in_group=the_user_id_in_group,
)
Expand Down Expand Up @@ -261,13 +261,13 @@ async def add_user_in_group(

if new_by_user_email:
user = await _groups_repository.get_user_from_email(
app, email=new_by_user_email, caller_user_id=user_id
app, email=new_by_user_email, caller_id=user_id
)
new_by_user_id = user.id

return await _groups_repository.add_new_user_in_group(
app,
user_id=user_id,
caller_id=user_id,
group_id=group_id,
new_user_id=new_by_user_id,
new_user_name=new_by_user_name,
Expand Down

0 comments on commit c94a135

Please sign in to comment.