Skip to content

Commit 1022a6c

Browse files
authored
Merge pull request #319 from aws-samples/318-row-level-security-enhancement
improve the logic of replacing table name in the original SQL
2 parents 1cefe37 + feae4b9 commit 1022a6c

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

application/nlq/business/datasource/base.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import yaml
2+
import re
23
from abc import ABC, abstractmethod
34

45
from nlq.business.login_user import LoginUser
@@ -102,6 +103,18 @@ def replace_table_with_cte(sql, table_config: dict):
102103
sql_splits.append(sql)
103104

104105
for table_name, sub_query in table_config.items():
106+
if '.' in table_name:
107+
# 如果表名包含schema name(格式: schema.table), 则将.替换成__
108+
schema_name, table_name_alone = table_name.split('.')
109+
table_name_replaced = table_name.replace('.', '__')
110+
if table_name in sql_splits[1]:
111+
# 替换带schema的表名
112+
sql_splits[1] = re.sub(r'\b{}\b'.format(table_name), table_name_replaced, sql_splits[1])
113+
elif table_name_alone in sql_splits[1]:
114+
# 替换不带schema的表名
115+
sql_splits[1] = re.sub(r'\b{}\b'.format(table_name_alone), table_name_replaced, sql_splits[1])
116+
117+
table_name = table_name_replaced
105118
cte_sql += f"/* rls applied */ {table_name} AS {sub_query},\n"
106119
if origin_sql_has_cte:
107120
cte_sql = cte_sql[:-1]

application/tests/unit_tests/test_row_level_security.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,16 @@ def setUp(self):
1010
self.two_table_join_sql = '''SELECT c.`name`, o.`product`, o.`quantity`, o.`territory`
1111
FROM customer c
1212
JOIN orders o ON c.`id` = o.`customer_id`
13+
LIMIT 100'''
14+
15+
self.two_table_join_sql_with_schema = '''SELECT c.`name`, o.`product`, o.`quantity`, o.`territory`
16+
FROM someschema.customer c
17+
JOIN someschema.orders o ON c.`id` = o.`customer_id`
18+
LIMIT 100'''
19+
20+
self.two_table_join_sql_with_schema_output = '''SELECT c.`name`, o.`product`, o.`quantity`, o.`territory`
21+
FROM someschema__customer c
22+
JOIN someschema__orders o ON c.`id` = o.`customer_id`
1323
LIMIT 100'''
1424

1525
self.expected_rls_enabled_sql = (
@@ -18,6 +28,12 @@ def setUp(self):
1828
"/* rls applied */ orders AS (SELECT * FROM orders WHERE territory = 'Asia')\n"
1929
f"{self.two_table_join_sql}")
2030

31+
self.expected_rls_enabled_sql_with_schema = (
32+
"WITH\n"
33+
"/* rls applied */ someschema__customer AS (SELECT * FROM someschema.customer WHERE created_by = 'admin'),\n"
34+
"/* rls applied */ someschema__orders AS (SELECT * FROM someschema.orders WHERE territory = 'Asia')\n"
35+
f"{self.two_table_join_sql_with_schema_output}")
36+
2137
self.base = MySQLDataSource()
2238

2339
def test_row_level_security_control(self):
@@ -34,6 +50,24 @@ def test_row_level_security_control(self):
3450

3551
self.assertEqual(self.expected_rls_enabled_sql, rls_modified_sql)
3652

53+
def test_row_level_security_control_with_schema(self):
54+
test_yaml = '''tables:
55+
- table_name: someschema.customer
56+
columns:
57+
- column_name: created_by
58+
column_value: $login_user.username
59+
- table_name: someschema.orders
60+
columns:
61+
- column_name: territory
62+
column_value: Asia'''
63+
rls_modified_sql = self.base.row_level_security_control(self.two_table_join_sql_with_schema, test_yaml, LoginUser('admin'))
64+
65+
self.assertEqual(self.expected_rls_enabled_sql_with_schema, rls_modified_sql)
66+
67+
# 测试不带schema的表名的兼容性
68+
rls_modified_sql2 = self.base.row_level_security_control(self.two_table_join_sql, test_yaml, LoginUser('admin'))
69+
self.assertEqual(self.expected_rls_enabled_sql_with_schema, rls_modified_sql2)
70+
3771
def test_cte_replace1(self):
3872
original_sql = """SELECT
3973
offer_id,

0 commit comments

Comments
 (0)