From 54bc5a0fd6cba83fdaf0a6ef83b20d7d305b6359 Mon Sep 17 00:00:00 2001 From: truc0 <22969604+truc0@users.noreply.github.com> Date: Tue, 11 Jun 2024 23:05:40 +0800 Subject: [PATCH] fix: field_index is incorrect in RBAC with domains mode (#345) --- casbin/management_enforcer.py | 9 ++++++--- tests/test_management_api.py | 12 ++++++++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/casbin/management_enforcer.py b/casbin/management_enforcer.py index b374ce10..de46c05b 100644 --- a/casbin/management_enforcer.py +++ b/casbin/management_enforcer.py @@ -27,7 +27,8 @@ def get_all_subjects(self): def get_all_named_subjects(self, ptype): """gets the list of subjects that show up in the current named policy.""" - return self.model.get_values_for_field_in_policy("p", ptype, 0) + field_index = self.model.get_field_index(ptype, "sub") + return self.model.get_values_for_field_in_policy("p", ptype, field_index) def get_all_objects(self): """gets the list of objects that show up in the current policy.""" @@ -35,7 +36,8 @@ def get_all_objects(self): def get_all_named_objects(self, ptype): """gets the list of objects that show up in the current named policy.""" - return self.model.get_values_for_field_in_policy("p", ptype, 1) + field_index = self.model.get_field_index(ptype, "obj") + return self.model.get_values_for_field_in_policy("p", ptype, field_index) def get_all_actions(self): """gets the list of actions that show up in the current policy.""" @@ -43,7 +45,8 @@ def get_all_actions(self): def get_all_named_actions(self, ptype): """gets the list of actions that show up in the current named policy.""" - return self.model.get_values_for_field_in_policy("p", ptype, 2) + field_index = self.model.get_field_index(ptype, "act") + return self.model.get_values_for_field_in_policy("p", ptype, field_index) def get_all_roles(self): """gets the list of roles that show up in the current named policy.""" diff --git a/tests/test_management_api.py b/tests/test_management_api.py index 84f55225..e924daa9 100644 --- a/tests/test_management_api.py +++ b/tests/test_management_api.py @@ -38,6 +38,18 @@ def test_get_list(self): self.assertEqual(e.get_all_actions(), ["read", "write"]) self.assertEqual(e.get_all_roles(), ["data2_admin"]) + def test_get_list_with_domains(self): + e = self.get_enforcer( + get_examples("rbac_with_domains_model.conf"), + get_examples("rbac_with_domains_policy.csv"), + # True, + ) + + self.assertEqual(e.get_all_subjects(), ["admin"]) + self.assertEqual(e.get_all_objects(), ["data1", "data2"]) + self.assertEqual(e.get_all_actions(), ["read", "write"]) + self.assertEqual(e.get_all_roles(), ["admin"]) + def test_get_policy_api(self): e = self.get_enforcer( get_examples("rbac_model.conf"),