Skip to content

Commit

Permalink
Pass the right contexts to add_member() and edit_member()
Browse files Browse the repository at this point in the history
Problem
-------

`group_membership_api_factory()` is the context factory for all views
that use the `"api.group_member"` route:

* `remove_member()` (`request_method="DELETE"`)
* `add_member()` (`request_method="POST"`)
* `edit_member()` (`request_method="PATCH"`)
* Add in #9197 also `get_member()`
  (`request_method="GET"`)

`group_membership_api_factory()` returns a `GroupMembershipContext`
object but:

* `GroupMembershipContext` isn't the right context for `edit_member()`:
  it lacks the `context.new_roles` attribute for the new roles that the
  request wants to change `context.membership.roles` to.

* `GroupMembershipContext` isn't the right context for `add_member()`:
  it has an inappropriate `context.membership` attribute (when adding a
  new membership there shouldn't be an existing membership in the
  context) and it lacks the `context.new_roles` attribute for the roles
  that the request wants to create a membership with.

The context for the `edit_member()` view should be an
`EditGroupMembershipContext` object, and for `add_member()` it should be
an `AddGroupMembershipContext`.

As a result the `edit_member()` view has to create its own context
object to pass to `request.has_permission()`:

    def edit_member(context: GroupMembershipContext, request):
        appstruct = EditGroupMembershipAPISchema().validate(json_payload(request))
        new_roles = appstruct["roles"]

        if not request.has_permission(
            Permission.Group.MEMBER_EDIT,
            EditGroupMembershipContext(
                context.group, context.user, context.membership, new_roles
            ),
        ):
            raise HTTPNotFound()

When a future PR enables users (not just auth clients) to call the
add-membership API the `add_member()` view will have to do something
similar: constructing its own `AddGroupMembershipContext` object and
passing it to `request.has_permission()`.

This means there are two different context objects in play for the
`edit_member()` and `add_member()` views: the `context` that is passed
to the view is a `GroupMembershipContext`, but the `context` that is
passed to `request.has_permission()` is an `EditGroupMembershipContext`
or `AddGroupMembershipContext` constructed by the view itself.

Solution
-------

This commit changes `group_membership_api_factory()` to return a
`GroupMembershipContext` for `GET` and `DELETE` requests but return an
`EditGroupMembershipContext` for `PATCH` requests and an
`AddGroupMembershipContext` for `POST`s.

It's not possible for `group_membership_api_factory()` to set the
context's `new_roles` attribute: the value for `new_roles` isn't
available until later in the request processing cycle after the view has
parsed and validated the request's JSON body. So the factory returns
`context` objects with `context.new_roles=None` and the `edit_member()`
view has to set `new_roles` before calling `has_permission()`:

    appstruct = EditGroupMembershipAPISchema().validate(json_payload(request))
    context.new_roles = appstruct["roles"]

    if not request.has_permission(Permission.Group.MEMBER_EDIT, context):
        raise HTTPNotFound()

In future the `add_member()` view will have to do the same. So this is
still a little weird, but I think it's better than having two different
context objects for a single request.
  • Loading branch information
seanh committed Jan 7, 2025
1 parent 026e974 commit 544af1f
Show file tree
Hide file tree
Showing 7 changed files with 142 additions and 32 deletions.
4 changes: 4 additions & 0 deletions h/security/predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,10 @@ def get_authenticated_users_membership():
def group_member_edit(
identity, context: EditGroupMembershipContext
): # pylint:disable=too-many-return-statements,too-complex
assert (
context.new_roles is not None
), "new_roles must be set before checking permissions"

old_roles = context.membership.roles
new_roles = context.new_roles

Expand Down
2 changes: 2 additions & 0 deletions h/traversal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
from h.traversal.annotation import AnnotationContext, AnnotationRoot
from h.traversal.group import GroupContext, GroupRequiredRoot, GroupRoot
from h.traversal.group_membership import (
AddGroupMembershipContext,
EditGroupMembershipContext,
GroupMembershipContext,
group_membership_api_factory,
Expand All @@ -82,6 +83,7 @@
"UserByIDRoot",
"UserRoot",
"GroupContext",
"AddGroupMembershipContext",
"EditGroupMembershipContext",
"GroupMembershipContext",
"group_membership_api_factory",
Expand Down
29 changes: 23 additions & 6 deletions h/traversal/group_membership.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,22 @@
class GroupMembershipContext:
group: Group
user: User
membership: GroupMembership | None
membership: GroupMembership


@dataclass
class AddGroupMembershipContext:
group: Group
user: User
new_roles: list[GroupMembershipRoles] | None


@dataclass
class EditGroupMembershipContext:
group: Group
user: User
membership: GroupMembership
new_roles: list[GroupMembershipRoles]
new_roles: list[GroupMembershipRoles] | None


def _get_user(request, userid) -> User | None:
Expand Down Expand Up @@ -46,21 +53,31 @@ def _get_membership(request, group, user) -> GroupMembership | None:
return group_members_service.get_membership(group, user)


def group_membership_api_factory(request) -> GroupMembershipContext:
def group_membership_api_factory(
request,
) -> GroupMembershipContext | AddGroupMembershipContext | EditGroupMembershipContext:
userid = request.matchdict["userid"]
pubid = request.matchdict["pubid"]

user = _get_user(request, userid)
group = _get_group(request, pubid)
membership = _get_membership(request, group, user)

if not user:
raise HTTPNotFound(f"User not found: {userid}")

if not group:
raise HTTPNotFound(f"Group not found: {pubid}")

if not membership and request.method != "POST":
if request.method == "POST":
return AddGroupMembershipContext(group, user, new_roles=None)

membership = _get_membership(request, group, user)

if not membership:
raise HTTPNotFound(f"Membership not found: ({pubid}, {userid})")

return GroupMembershipContext(group=group, user=user, membership=membership)
if request.method in ("GET", "DELETE"):
return GroupMembershipContext(group=group, user=user, membership=membership)

assert request.method == "PATCH"
return EditGroupMembershipContext(group, user, membership, new_roles=None)
28 changes: 14 additions & 14 deletions h/views/api/group_members.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@
from h.schemas.util import validate_query_params
from h.security import Permission
from h.services.group_members import ConflictError
from h.traversal import EditGroupMembershipContext, GroupContext, GroupMembershipContext
from h.traversal import (
AddGroupMembershipContext,
EditGroupMembershipContext,
GroupContext,
GroupMembershipContext,
)
from h.views.api.config import api_config
from h.views.api.helpers.json_payload import json_payload

Expand Down Expand Up @@ -108,7 +113,7 @@ def remove_member(context: GroupMembershipContext, request):
description="Add a user to a group",
permission=Permission.Group.MEMBER_ADD,
)
def add_member(context: GroupMembershipContext, request):
def add_member(context: AddGroupMembershipContext, request):
if context.user.authority != context.group.authority:
raise HTTPNotFound()

Expand Down Expand Up @@ -139,21 +144,16 @@ def add_member(context: GroupMembershipContext, request):
link_name="group.member.edit",
description="Change a user's role in a group",
)
def edit_member(context: GroupMembershipContext, request):
def edit_member(context: EditGroupMembershipContext, request):
appstruct = EditGroupMembershipAPISchema().validate(json_payload(request))
new_roles = appstruct["roles"]

if not request.has_permission(
Permission.Group.MEMBER_EDIT,
EditGroupMembershipContext(
context.group, context.user, context.membership, new_roles
),
):
context.new_roles = appstruct["roles"]

if not request.has_permission(Permission.Group.MEMBER_EDIT, context):
raise HTTPNotFound()

if context.membership.roles != new_roles:
if context.membership.roles != context.new_roles:
old_roles = context.membership.roles
context.membership.roles = new_roles
context.membership.roles = context.new_roles
log.info(
"Changed group membership roles: %r (previous roles were: %r)",
context.membership,
Expand All @@ -166,6 +166,6 @@ def edit_member(context: GroupMembershipContext, request):
# Otherwise permissions checks will be based on the old roles.
for membership in request.identity.user.memberships:
if membership.group.id == context.group.id:
membership.roles = new_roles
membership.roles = context.new_roles

return GroupMembershipJSONPresenter(request, context.membership).asdict()
14 changes: 14 additions & 0 deletions tests/unit/h/security/predicates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1027,6 +1027,20 @@ def test_changing_own_role(

assert predicates.group_member_edit(identity, context) == expected_result

def test_it_crashes_if_new_roles_is_not_set(self, identity):
context = EditGroupMembershipContext(
group=sentinel.group,
user=sentinel.user,
membership=sentinel.membership,
new_roles=None,
)

with pytest.raises(
AssertionError,
match="^new_roles must be set before checking permissions$",
):
predicates.group_member_edit(identity, context)

@pytest.fixture
def authenticated_user(self, db_session, authenticated_user, factories):
# Make the authenticated user a member of a *different* group,
Expand Down
77 changes: 69 additions & 8 deletions tests/unit/h/traversal/group_membership_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,26 @@

from h.exceptions import InvalidUserId
from h.traversal.group_membership import (
AddGroupMembershipContext,
EditGroupMembershipContext,
GroupMembershipContext,
group_membership_api_factory,
)


@pytest.mark.usefixtures("group_service", "user_service", "group_members_service")
class TestGroupMembershipAPIFactory:
def test_it(
self, group_service, user_service, group_members_service, pyramid_request
@pytest.mark.parametrize("request_method", ["GET", "DELETE"])
def test_get_delete(
self,
group_service,
user_service,
group_members_service,
pyramid_request,
request_method,
):
pyramid_request.method = request_method

context = group_membership_api_factory(pyramid_request)

group_service.fetch.assert_called_once_with(sentinel.pubid)
Expand All @@ -28,25 +38,70 @@ def test_it(
assert context.user == user_service.fetch.return_value
assert context.membership == group_members_service.get_membership.return_value

def test_when_no_matching_group(self, group_service, pyramid_request):
def test_post(
self, group_service, user_service, group_members_service, pyramid_request
):
pyramid_request.method = "POST"

context = group_membership_api_factory(pyramid_request)

group_service.fetch.assert_called_once_with(sentinel.pubid)
user_service.fetch.assert_called_once_with(sentinel.userid)
group_members_service.get_membership.assert_not_called()
assert isinstance(context, AddGroupMembershipContext)
assert context.group == group_service.fetch.return_value
assert context.user == user_service.fetch.return_value
assert context.new_roles is None

def test_patch(
self, group_service, user_service, group_members_service, pyramid_request
):
pyramid_request.method = "PATCH"

context = group_membership_api_factory(pyramid_request)

group_service.fetch.assert_called_once_with(sentinel.pubid)
user_service.fetch.assert_called_once_with(sentinel.userid)
group_members_service.get_membership.assert_called_once_with(
group_service.fetch.return_value, user_service.fetch.return_value
)
assert isinstance(context, EditGroupMembershipContext)
assert context.group == group_service.fetch.return_value
assert context.user == user_service.fetch.return_value
assert context.membership == group_members_service.get_membership.return_value
assert context.new_roles is None

@pytest.mark.parametrize("request_method", ["GET", "POST", "PATCH", "DELETE"])
def test_when_no_matching_group(
self, group_service, pyramid_request, request_method
):
pyramid_request.method = request_method
group_service.fetch.return_value = None

with pytest.raises(HTTPNotFound, match="Group not found: sentinel.pubid"):
group_membership_api_factory(pyramid_request)

def test_when_no_matching_user(self, user_service, pyramid_request):
@pytest.mark.parametrize("request_method", ["GET", "POST", "PATCH", "DELETE"])
def test_when_no_matching_user(self, user_service, pyramid_request, request_method):
pyramid_request.method = request_method
user_service.fetch.return_value = None

with pytest.raises(HTTPNotFound, match="User not found: sentinel.userid"):
group_membership_api_factory(pyramid_request)

def test_when_invalid_userid(self, user_service, pyramid_request):
@pytest.mark.parametrize("request_method", ["GET", "POST", "PATCH", "DELETE"])
def test_when_invalid_userid(self, user_service, pyramid_request, request_method):
pyramid_request.method = request_method
user_service.fetch.side_effect = InvalidUserId(sentinel.userid)

with pytest.raises(HTTPNotFound, match="User not found: sentinel.userid"):
group_membership_api_factory(pyramid_request)

def test_when_no_matching_membership(self, group_members_service, pyramid_request):
@pytest.mark.parametrize("request_method", ["GET", "PATCH", "DELETE"])
def test_when_no_matching_membership(
self, group_members_service, pyramid_request, request_method
):
pyramid_request.method = request_method
group_members_service.get_membership.return_value = None

with pytest.raises(
Expand All @@ -55,15 +110,21 @@ def test_when_no_matching_membership(self, group_members_service, pyramid_reques
):
group_membership_api_factory(pyramid_request)

def test_me_alias(self, pyramid_config, pyramid_request, user_service):
@pytest.mark.parametrize("request_method", ["GET", "POST", "PATCH", "DELETE"])
def test_me_alias(
self, pyramid_config, pyramid_request, user_service, request_method
):
pyramid_request.method = request_method
pyramid_config.testing_securitypolicy(userid=sentinel.userid)
pyramid_request.matchdict["userid"] = "me"

group_membership_api_factory(pyramid_request)

user_service.fetch.assert_called_once_with(sentinel.userid)

def test_me_alias_when_not_authenticated(self, pyramid_request):
@pytest.mark.parametrize("request_method", ["GET", "POST", "PATCH", "DELETE"])
def test_me_alias_when_not_authenticated(self, pyramid_request, request_method):
pyramid_request.method = request_method
pyramid_request.matchdict["userid"] = "me"

with pytest.raises(HTTPNotFound, match="User not found: me"):
Expand Down
20 changes: 16 additions & 4 deletions tests/unit/h/views/api/group_members_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,15 @@
from h import presenters
from h.models import GroupMembership
from h.schemas.base import ValidationError
from h.security import Permission
from h.security.identity import Identity, LongLivedGroup, LongLivedMembership
from h.services.group_members import ConflictError
from h.traversal import GroupContext, GroupMembershipContext
from h.traversal import (
AddGroupMembershipContext,
EditGroupMembershipContext,
GroupContext,
GroupMembershipContext,
)
from h.views.api.exceptions import PayloadError


Expand Down Expand Up @@ -234,7 +240,7 @@ def test_it_with_authority_mismatch(self, pyramid_request, context):
def context(self, factories):
group = factories.Group.build()
user = factories.User.build(authority=group.authority)
return GroupMembershipContext(group=group, user=user, membership=None)
return AddGroupMembershipContext(group=group, user=user, new_roles=None)

@pytest.fixture
def pyramid_request(self, pyramid_request):
Expand All @@ -258,12 +264,17 @@ def test_it(
EditGroupMembershipAPISchema,
GroupMembershipJSONPresenter,
caplog,
mocker,
):
has_permission = mocker.spy(pyramid_request, "has_permission")

response = views.edit_member(context, pyramid_request)

EditGroupMembershipAPISchema.return_value.validate.assert_called_once_with(
sentinel.json_body
)
assert context.new_roles == sentinel.new_roles
has_permission.assert_called_once_with(Permission.Group.MEMBER_EDIT, context)
assert context.membership.roles == sentinel.new_roles
GroupMembershipJSONPresenter.assert_called_once_with(
pyramid_request, context.membership
Expand Down Expand Up @@ -342,8 +353,9 @@ def context(self, factories):
group = factories.Group.build()
user = factories.User.build(authority=group.authority)
membership = GroupMembership(group=group, user=user, roles=sentinel.old_roles)

return GroupMembershipContext(group=group, user=user, membership=membership)
return EditGroupMembershipContext(
group=group, user=user, membership=membership, new_roles=None
)

@pytest.fixture
def pyramid_request(self, pyramid_request):
Expand Down

0 comments on commit 544af1f

Please sign in to comment.