Skip to content

Commit a36cf53

Browse files
authored
feat: support sync & async update_policy / update_policies (#7)
1 parent 1868e6d commit a36cf53

4 files changed

Lines changed: 182 additions & 0 deletions

File tree

casbin_pymongo_adapter/adapter.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,3 +175,35 @@ def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
175175
query["ptype"] = ptype
176176
results = self._collection.delete_many(query)
177177
return results.deleted_count > 0
178+
179+
def update_policy(self, sec, ptype, old_rule, new_rule):
180+
"""Update the old_rule with the new_rule in the database (storage).
181+
182+
Args:
183+
sec (str): section type
184+
ptype (str): policy type
185+
old_rule (list[str]): the old rule that needs to be modified
186+
new_rule (list[str]): the new rule to replace the old rule
187+
"""
188+
filter_query = {}
189+
for index, value in enumerate(old_rule):
190+
filter_query[f"v{index}"] = value
191+
192+
self._collection.find_one_and_update(
193+
filter_query,
194+
{"$set": {f"v{index}": value for index, value in enumerate(new_rule)}},
195+
)
196+
197+
return None
198+
199+
def update_policies(self, sec, ptype, old_rules, new_rules):
200+
"""Update the old_rule with the new_rule in the database (storage).
201+
202+
Args:
203+
sec (str): section type
204+
ptype (str): policy type
205+
old_rules (list[list[str]]): the old rules that needs to be modified
206+
new_rules (list[list[str]]): the new rules to replace the old rule
207+
"""
208+
for old_rule, new_rule in zip(old_rules, new_rules):
209+
self.update_policy(sec, ptype, old_rule, new_rule)

casbin_pymongo_adapter/asynchronous/adapter.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,35 @@ async def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
169169
query["ptype"] = ptype
170170
results = await self._collection.delete_many(query)
171171
return results.deleted_count > 0
172+
173+
async def update_policy(self, sec, ptype, old_rule, new_rule):
174+
"""Update the old_rule with the new_rule in the database (storage).
175+
176+
Args:
177+
sec (str): section type
178+
ptype (str): policy type
179+
old_rule (list[str]): the old rule that needs to be modified
180+
new_rule (list[str]): the new rule to replace the old rule
181+
"""
182+
filter_query = {}
183+
for index, value in enumerate(old_rule):
184+
filter_query[f"v{index}"] = value
185+
186+
await self._collection.find_one_and_update(
187+
filter_query,
188+
{"$set": {f"v{index}": value for index, value in enumerate(new_rule)}},
189+
)
190+
191+
return None
192+
193+
async def update_policies(self, sec, ptype, old_rules, new_rules):
194+
"""Update the old_rule with the new_rule in the database (storage).
195+
196+
Args:
197+
sec (str): section type
198+
ptype (str): policy type
199+
old_rules (list[list[str]]): the old rules that needs to be modified
200+
new_rules (list[list[str]]): the new rules to replace the old rule
201+
"""
202+
for old_rule, new_rule in zip(old_rules, new_rules):
203+
await self.update_policy(sec, ptype, old_rule, new_rule)

tests/asynchronous/test_adapter.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,3 +338,62 @@ async def test_filtered_policy_with_raw_query(self):
338338
self.assertFalse(e.enforce("bob", "data1", "write"))
339339
self.assertFalse(e.enforce("bob", "data2", "read"))
340340
self.assertTrue(e.enforce("bob", "data2", "write"))
341+
342+
async def test_update_policy(self):
343+
e = await get_enforcer()
344+
example_p = ["mike", "cookie", "eat"]
345+
346+
self.assertTrue(e.enforce("alice", "data1", "read"))
347+
await e.update_policy(["alice", "data1", "read"], ["alice", "data1", "no_read"])
348+
self.assertFalse(e.enforce("alice", "data1", "read"))
349+
350+
self.assertFalse(e.enforce("bob", "data1", "read"))
351+
await e.add_policy(example_p)
352+
await e.update_policy(example_p, ["bob", "data1", "read"])
353+
self.assertTrue(e.enforce("bob", "data1", "read"))
354+
355+
self.assertFalse(e.enforce("bob", "data1", "write"))
356+
await e.update_policy(["bob", "data1", "read"], ["bob", "data1", "write"])
357+
self.assertTrue(e.enforce("bob", "data1", "write"))
358+
359+
self.assertTrue(e.enforce("bob", "data2", "write"))
360+
await e.update_policy(["bob", "data2", "write"], ["bob", "data2", "read"])
361+
self.assertFalse(e.enforce("bob", "data2", "write"))
362+
363+
self.assertTrue(e.enforce("bob", "data2", "read"))
364+
await e.update_policy(["bob", "data2", "read"], ["carl", "data2", "write"])
365+
self.assertFalse(e.enforce("bob", "data2", "write"))
366+
367+
self.assertTrue(e.enforce("carl", "data2", "write"))
368+
await e.update_policy(["carl", "data2", "write"], ["carl", "data2", "no_write"])
369+
self.assertFalse(e.enforce("bob", "data2", "write"))
370+
371+
async def test_update_policies(self):
372+
e = await get_enforcer()
373+
374+
old_rule_0 = ["alice", "data1", "read"]
375+
old_rule_1 = ["bob", "data2", "write"]
376+
old_rule_2 = ["data2_admin", "data2", "read"]
377+
old_rule_3 = ["data2_admin", "data2", "write"]
378+
379+
new_rule_0 = ["alice", "data_test", "read"]
380+
new_rule_1 = ["bob", "data_test", "write"]
381+
new_rule_2 = ["data2_admin", "data_test", "read"]
382+
new_rule_3 = ["data2_admin", "data_test", "write"]
383+
384+
old_rules = [old_rule_0, old_rule_1, old_rule_2, old_rule_3]
385+
new_rules = [new_rule_0, new_rule_1, new_rule_2, new_rule_3]
386+
387+
await e.update_policies(old_rules, new_rules)
388+
389+
self.assertFalse(e.enforce("alice", "data1", "read"))
390+
self.assertTrue(e.enforce("alice", "data_test", "read"))
391+
392+
self.assertFalse(e.enforce("bob", "data2", "write"))
393+
self.assertTrue(e.enforce("bob", "data_test", "write"))
394+
395+
self.assertFalse(e.enforce("data2_admin", "data2", "read"))
396+
self.assertTrue(e.enforce("data2_admin", "data_test", "read"))
397+
398+
self.assertFalse(e.enforce("data2_admin", "data2", "write"))
399+
self.assertTrue(e.enforce("data2_admin", "data_test", "write"))

tests/test_adapter.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,65 @@ async def test_filtered_policy_with_raw_query(self):
343343
self.assertFalse(e.enforce("bob", "data2", "read"))
344344
self.assertTrue(e.enforce("bob", "data2", "write"))
345345

346+
async def test_update_policy(self):
347+
e = get_enforcer()
348+
example_p = ["mike", "cookie", "eat"]
349+
350+
self.assertTrue(e.enforce("alice", "data1", "read"))
351+
e.update_policy(["alice", "data1", "read"], ["alice", "data1", "no_read"])
352+
self.assertFalse(e.enforce("alice", "data1", "read"))
353+
354+
self.assertFalse(e.enforce("bob", "data1", "read"))
355+
e.add_policy(example_p)
356+
e.update_policy(example_p, ["bob", "data1", "read"])
357+
self.assertTrue(e.enforce("bob", "data1", "read"))
358+
359+
self.assertFalse(e.enforce("bob", "data1", "write"))
360+
e.update_policy(["bob", "data1", "read"], ["bob", "data1", "write"])
361+
self.assertTrue(e.enforce("bob", "data1", "write"))
362+
363+
self.assertTrue(e.enforce("bob", "data2", "write"))
364+
e.update_policy(["bob", "data2", "write"], ["bob", "data2", "read"])
365+
self.assertFalse(e.enforce("bob", "data2", "write"))
366+
367+
self.assertTrue(e.enforce("bob", "data2", "read"))
368+
e.update_policy(["bob", "data2", "read"], ["carl", "data2", "write"])
369+
self.assertFalse(e.enforce("bob", "data2", "write"))
370+
371+
self.assertTrue(e.enforce("carl", "data2", "write"))
372+
e.update_policy(["carl", "data2", "write"], ["carl", "data2", "no_write"])
373+
self.assertFalse(e.enforce("bob", "data2", "write"))
374+
375+
async def test_update_policies(self):
376+
e = get_enforcer()
377+
378+
old_rule_0 = ["alice", "data1", "read"]
379+
old_rule_1 = ["bob", "data2", "write"]
380+
old_rule_2 = ["data2_admin", "data2", "read"]
381+
old_rule_3 = ["data2_admin", "data2", "write"]
382+
383+
new_rule_0 = ["alice", "data_test", "read"]
384+
new_rule_1 = ["bob", "data_test", "write"]
385+
new_rule_2 = ["data2_admin", "data_test", "read"]
386+
new_rule_3 = ["data2_admin", "data_test", "write"]
387+
388+
old_rules = [old_rule_0, old_rule_1, old_rule_2, old_rule_3]
389+
new_rules = [new_rule_0, new_rule_1, new_rule_2, new_rule_3]
390+
391+
e.update_policies(old_rules, new_rules)
392+
393+
self.assertFalse(e.enforce("alice", "data1", "read"))
394+
self.assertTrue(e.enforce("alice", "data_test", "read"))
395+
396+
self.assertFalse(e.enforce("bob", "data2", "write"))
397+
self.assertTrue(e.enforce("bob", "data_test", "write"))
398+
399+
self.assertFalse(e.enforce("data2_admin", "data2", "read"))
400+
self.assertTrue(e.enforce("data2_admin", "data_test", "read"))
401+
402+
self.assertFalse(e.enforce("data2_admin", "data2", "write"))
403+
self.assertTrue(e.enforce("data2_admin", "data_test", "write"))
404+
346405
def test_str(self):
347406
"""
348407
test __str__ function

0 commit comments

Comments
 (0)