Skip to content

Commit 2065a79

Browse files
authored
feat: support BatchAdapter and UpdateAdapter interfaces (#6)
1 parent 328d63b commit 2065a79

File tree

3 files changed

+296
-2
lines changed

3 files changed

+296
-2
lines changed

casbin_async_redis_adapter/adapter.py

Lines changed: 118 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ async def _delete_policy_lines(self, ptype, rule):
101101
await self.client.lrem(self.key, 0, json.dumps(line.dict()))
102102

103103
async def save_policy(self, model) -> bool:
104-
"""Implement add Interface for casbin. Save the policy in mongodb
104+
"""Implement add Interface for casbin. Save the policy in redis
105105
106106
Args:
107107
model (Class Model): Casbin Model which loads from .conf file usually.
@@ -131,6 +131,21 @@ async def add_policy(self, sec, ptype, rule):
131131
await self._save_policy_line(ptype, rule)
132132
return True
133133

134+
async def add_policies(self, sec, ptype, rules):
135+
"""AddPolicies adds policy rules to the storage.
136+
137+
Args:
138+
sec (str): Section name, 'g' or 'p'
139+
ptype (str): Policy type, 'g', 'g2', 'p', etc.
140+
rules: Casbin rules will be added
141+
142+
Returns:
143+
bool: True if succeed else False
144+
"""
145+
for rule in rules:
146+
await self.add_policy(sec, ptype, rule)
147+
return True
148+
134149
async def remove_policy(self, sec, ptype, rule):
135150
"""Remove policy rules in redis(rules duplicate will all be removed)
136151
@@ -145,6 +160,21 @@ async def remove_policy(self, sec, ptype, rule):
145160
await self._delete_policy_lines(ptype, rule)
146161
return True
147162

163+
async def remove_policies(self, sec, ptype, rules):
164+
"""RemovePolicies removes policy rules from the storage.
165+
166+
Args:
167+
sec (str): Section name, 'g' or 'p'
168+
ptype (str): Policy type, 'g', 'g2', 'p', etc.
169+
rules: Casbin rules will be removed
170+
171+
Returns:
172+
bool: True if succeed else False
173+
"""
174+
for rule in rules:
175+
await self.remove_policy(sec, ptype, rule)
176+
return True
177+
148178
async def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
149179
"""Remove policy rules that match the filter from the storage.
150180
This is part of the Auto-Save feature.
@@ -183,3 +213,90 @@ async def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
183213

184214
await self.client.lrem(self.key, 0, "__CASBIN_DELETED__")
185215
return True
216+
217+
async def update_policy(self, sec, ptype, old_rule, new_rule):
218+
"""
219+
update_policy updates a policy rule from storage.
220+
This is part of the Auto-Save feature.
221+
222+
Args:
223+
sec (str): Section name, 'g' or 'p'
224+
ptype (str): Policy type, 'g', 'g2', 'p', etc.
225+
old_rule: Casbin rule if it is exactly same as will be removed.
226+
new_rule: Casbin rule if it is exactly same as will be added.
227+
228+
Returns:
229+
bool: True if succeed else False
230+
"""
231+
old_rule_obj = CasbinRule(ptype=ptype)
232+
new_rule_obj = CasbinRule(ptype=ptype)
233+
for index, value in enumerate(old_rule):
234+
setattr(old_rule_obj, f"v{index}", value)
235+
for index, value in enumerate(new_rule):
236+
setattr(new_rule_obj, f"v{index}", value)
237+
238+
# Convert old_rule_obj and new_rule_obj to json
239+
old_rule_json = json.dumps(old_rule_obj.dict())
240+
new_rule_json = json.dumps(new_rule_obj.dict())
241+
242+
lua_script = """
243+
local old_rule_json = ARGV[1]
244+
local new_rule_json = ARGV[2]
245+
local rules = redis.call('lrange', KEYS[1], 0, -1)
246+
for i, rule_json in ipairs(rules) do
247+
local rule = cjson.decode(rule_json)
248+
if rule.ptype == ARGV[3] and rule_json == old_rule_json then
249+
redis.call('lset', KEYS[1], i-1, new_rule_json)
250+
return 1
251+
end
252+
end
253+
return 0
254+
"""
255+
256+
result = await self.client.eval(
257+
lua_script, 1, self.key, old_rule_json, new_rule_json, ptype
258+
)
259+
260+
return result == 1
261+
262+
async def update_policies(self, sec, ptype, old_rules, new_rules):
263+
"""
264+
UpdatePolicies updates some policy rules to storage, like db, redis.
265+
266+
Args:
267+
sec (str): Section name, 'g' or 'p'
268+
ptype (str): Policy type, 'g', 'g2', 'p', etc.
269+
old_rules: Casbin rule if it is exactly same as will be removed.
270+
new_rules: Casbin rule if it is exactly same as will be added.
271+
272+
Returns:
273+
bool: True if succeed else False
274+
"""
275+
for i in range(len(old_rules)):
276+
await self.update_policy(sec, ptype, old_rules[i], new_rules[i])
277+
return True
278+
279+
async def update_filtered_policies(
280+
self, sec, ptype, new_rules, field_index, *field_values
281+
):
282+
"""
283+
update_filtered_policies deletes old rules and adds new rules.
284+
285+
Args:
286+
sec (str): Section name, 'g' or 'p'
287+
ptype (str): Policy type, 'g', 'g2', 'p', etc.
288+
new_rules: Casbin rule if it is exactly same as will be added.
289+
field_index (int): The policy index at which the filed_values begins filtering. Its range is [0, 5]
290+
field_values(List[str]): A list of rules to filter policy which starts from
291+
292+
Returns:
293+
bool: True if succeed else False
294+
"""
295+
if not (0 <= field_index <= 5):
296+
return False
297+
if not (1 <= field_index + len(field_values) <= 6):
298+
return False
299+
300+
await self.remove_filtered_policy(sec, ptype, field_index, *field_values)
301+
await self.add_policies(sec, ptype, new_rules)
302+
return True

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
casbin>=1.34.0
2-
redis>=5.0.0
2+
redis>=5.0.0

tests/test_adapter.py

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,36 @@ async def test_add_policy(self):
100100
self.assertTrue(e.enforce("alice", "data2", "read"))
101101
self.assertTrue(e.enforce("alice", "data2", "write"))
102102

103+
async def test_add_policies(self):
104+
"""
105+
test add_policies
106+
"""
107+
e = await get_enforcer()
108+
adapter = e.get_adapter()
109+
self.assertTrue(e.enforce("alice", "data1", "read"))
110+
self.assertFalse(e.enforce("alice", "data1", "write"))
111+
self.assertFalse(e.enforce("bob", "data2", "read"))
112+
self.assertTrue(e.enforce("bob", "data2", "write"))
113+
self.assertTrue(e.enforce("alice", "data2", "read"))
114+
self.assertTrue(e.enforce("alice", "data2", "write"))
115+
116+
# test add_policies after insert 2 rules
117+
await adapter.add_policies(
118+
sec="p",
119+
ptype="p",
120+
rules=(("alice", "data1", "write"), ("bob", "data2", "read")),
121+
)
122+
123+
# reload policies from database
124+
await e.load_policy()
125+
126+
self.assertTrue(e.enforce("alice", "data1", "read"))
127+
self.assertTrue(e.enforce("alice", "data1", "write"))
128+
self.assertTrue(e.enforce("bob", "data2", "read"))
129+
self.assertTrue(e.enforce("bob", "data2", "write"))
130+
self.assertTrue(e.enforce("alice", "data2", "read"))
131+
self.assertTrue(e.enforce("alice", "data2", "write"))
132+
103133
async def test_remove_policy(self):
104134
"""
105135
test remove_policy
@@ -129,6 +159,38 @@ async def test_remove_policy(self):
129159
self.assertFalse(e.enforce("alice", "data2", "write"))
130160
self.assertTrue(result)
131161

162+
async def test_remove_policies(self):
163+
"""
164+
test remove_policies
165+
"""
166+
e = await get_enforcer()
167+
adapter = e.get_adapter()
168+
169+
self.assertFalse(e.enforce("alice", "data3", "write"))
170+
self.assertFalse(e.enforce("alice", "data3", "read"))
171+
172+
await adapter.add_policies(
173+
sec="p",
174+
ptype="p",
175+
rules=(("alice", "data3", "write"), ("alice", "data3", "read")),
176+
)
177+
178+
await e.load_policy()
179+
self.assertTrue(e.enforce("alice", "data3", "write"))
180+
self.assertTrue(e.enforce("alice", "data3", "read"))
181+
182+
# test remove_policies after delete delete 2 rules
183+
result = await adapter.remove_policies(
184+
sec="p",
185+
ptype="p",
186+
rules=(("alice", "data3", "read"), ("alice", "data3", "write")),
187+
)
188+
189+
await e.load_policy()
190+
self.assertFalse(e.enforce("alice", "data3", "write"))
191+
self.assertFalse(e.enforce("alice", "data3", "read"))
192+
self.assertTrue(result)
193+
132194
async def test_remove_policy_no_remove_when_rule_is_incomplete(self):
133195
adapter = Adapter("localhost", 6379)
134196
e = casbin.AsyncEnforcer(get_fixture("rbac_with_resources_roles.conf"), adapter)
@@ -213,6 +275,121 @@ async def test_remove_filtered_policy(self):
213275
self.assertFalse(e.enforce("alice", "data2", "read"))
214276
self.assertFalse(e.enforce("alice", "data2", "write"))
215277

278+
async def test_update_policy(self):
279+
"""
280+
test update_policy
281+
"""
282+
e = await get_enforcer()
283+
adapter = e.get_adapter()
284+
self.assertTrue(e.enforce("alice", "data1", "read"))
285+
self.assertFalse(e.enforce("alice", "data1", "write"))
286+
self.assertFalse(e.enforce("bob", "data2", "read"))
287+
self.assertTrue(e.enforce("bob", "data2", "write"))
288+
self.assertTrue(e.enforce("alice", "data2", "read"))
289+
self.assertTrue(e.enforce("alice", "data2", "write"))
290+
291+
# test update_policy after update a rule
292+
result = await adapter.update_policy(
293+
sec="p",
294+
ptype="p",
295+
old_rule=("bob", "data2", "write"),
296+
new_rule=("bob", "data1", "write"),
297+
)
298+
299+
# reload policies from database
300+
await e.load_policy()
301+
302+
self.assertTrue(e.enforce("alice", "data1", "read"))
303+
self.assertFalse(e.enforce("alice", "data1", "write"))
304+
self.assertFalse(e.enforce("bob", "data2", "read"))
305+
self.assertFalse(e.enforce("bob", "data2", "write"))
306+
self.assertTrue(e.enforce("bob", "data1", "write"))
307+
self.assertTrue(e.enforce("alice", "data2", "read"))
308+
self.assertTrue(e.enforce("alice", "data2", "write"))
309+
self.assertTrue(result)
310+
311+
async def test_update_policies(self):
312+
"""
313+
test update_policies
314+
"""
315+
e = await get_enforcer()
316+
adapter = e.get_adapter()
317+
self.assertFalse(e.enforce("alice", "data3", "write"))
318+
self.assertFalse(e.enforce("alice", "data3", "read"))
319+
320+
await adapter.add_policies(
321+
sec="p",
322+
ptype="p",
323+
rules=(("alice", "data3", "write"), ("alice", "data3", "read")),
324+
)
325+
326+
await e.load_policy()
327+
self.assertTrue(e.enforce("alice", "data3", "write"))
328+
self.assertTrue(e.enforce("alice", "data3", "read"))
329+
330+
# test update_policies after update 2 rules
331+
result = await adapter.update_policies(
332+
sec="p",
333+
ptype="p",
334+
old_rules=(("alice", "data3", "write"), ("alice", "data3", "read")),
335+
new_rules=(("alice", "data4", "write"), ("alice", "data4", "read")),
336+
)
337+
338+
await e.load_policy()
339+
self.assertFalse(e.enforce("alice", "data3", "write"))
340+
self.assertFalse(e.enforce("alice", "data3", "read"))
341+
self.assertTrue(e.enforce("alice", "data4", "write"))
342+
self.assertTrue(e.enforce("alice", "data4", "read"))
343+
self.assertTrue(result)
344+
345+
async def test_update_filtered_policies(self):
346+
"""
347+
test update_filtered_policies
348+
"""
349+
e = await get_enforcer()
350+
adapter = e.get_adapter()
351+
self.assertFalse(e.enforce("alice", "data3", "write"))
352+
self.assertFalse(e.enforce("alice", "data3", "read"))
353+
354+
await adapter.add_policies(
355+
sec="p",
356+
ptype="p",
357+
rules=(("alice", "data3", "write"), ("alice", "data3", "read")),
358+
)
359+
360+
await e.load_policy()
361+
self.assertTrue(e.enforce("alice", "data3", "write"))
362+
self.assertTrue(e.enforce("alice", "data3", "read"))
363+
364+
# test update_filtered_policies
365+
result = await adapter.remove_filtered_policy(
366+
"g", "g", 6, "alice", "data2_admin"
367+
)
368+
await e.load_policy()
369+
self.assertFalse(result)
370+
371+
result = await adapter.remove_filtered_policy(
372+
"g", "g", 0, *[f"v{i}" for i in range(7)]
373+
)
374+
await e.load_policy()
375+
self.assertFalse(result)
376+
377+
result = await adapter.update_filtered_policies(
378+
"p",
379+
"p",
380+
(("alice", "data4", "write"), ("alice", "data4", "read")),
381+
0,
382+
"alice",
383+
"data3",
384+
)
385+
386+
await e.load_policy()
387+
self.assertFalse(e.enforce("alice", "data3", "write"))
388+
self.assertFalse(e.enforce("alice", "data3", "read"))
389+
self.assertTrue(e.enforce("alice", "data4", "write"))
390+
self.assertTrue(e.enforce("alice", "data4", "read"))
391+
self.assertTrue(result)
392+
216393
def test_str(self):
217394
"""
218395
test __str__ function

0 commit comments

Comments
 (0)