@@ -10,6 +10,16 @@ def setUp(self):
1010 self .two_table_join_sql = '''SELECT c.`name`, o.`product`, o.`quantity`, o.`territory`
1111FROM customer c
1212JOIN orders o ON c.`id` = o.`customer_id`
13+ LIMIT 100'''
14+
15+ self .two_table_join_sql_with_schema = '''SELECT c.`name`, o.`product`, o.`quantity`, o.`territory`
16+ FROM someschema.customer c
17+ JOIN someschema.orders o ON c.`id` = o.`customer_id`
18+ LIMIT 100'''
19+
20+ self .two_table_join_sql_with_schema_output = '''SELECT c.`name`, o.`product`, o.`quantity`, o.`territory`
21+ FROM someschema__customer c
22+ JOIN someschema__orders o ON c.`id` = o.`customer_id`
1323LIMIT 100'''
1424
1525 self .expected_rls_enabled_sql = (
@@ -18,6 +28,12 @@ def setUp(self):
1828 "/* rls applied */ orders AS (SELECT * FROM orders WHERE territory = 'Asia')\n "
1929 f"{ self .two_table_join_sql } " )
2030
31+ self .expected_rls_enabled_sql_with_schema = (
32+ "WITH\n "
33+ "/* rls applied */ someschema__customer AS (SELECT * FROM someschema.customer WHERE created_by = 'admin'),\n "
34+ "/* rls applied */ someschema__orders AS (SELECT * FROM someschema.orders WHERE territory = 'Asia')\n "
35+ f"{ self .two_table_join_sql_with_schema_output } " )
36+
2137 self .base = MySQLDataSource ()
2238
2339 def test_row_level_security_control (self ):
@@ -34,6 +50,24 @@ def test_row_level_security_control(self):
3450
3551 self .assertEqual (self .expected_rls_enabled_sql , rls_modified_sql )
3652
53+ def test_row_level_security_control_with_schema (self ):
54+ test_yaml = '''tables:
55+ - table_name: someschema.customer
56+ columns:
57+ - column_name: created_by
58+ column_value: $login_user.username
59+ - table_name: someschema.orders
60+ columns:
61+ - column_name: territory
62+ column_value: Asia'''
63+ rls_modified_sql = self .base .row_level_security_control (self .two_table_join_sql_with_schema , test_yaml , LoginUser ('admin' ))
64+
65+ self .assertEqual (self .expected_rls_enabled_sql_with_schema , rls_modified_sql )
66+
67+ # 测试不带schema的表名的兼容性
68+ rls_modified_sql2 = self .base .row_level_security_control (self .two_table_join_sql , test_yaml , LoginUser ('admin' ))
69+ self .assertEqual (self .expected_rls_enabled_sql_with_schema , rls_modified_sql2 )
70+
3771 def test_cte_replace1 (self ):
3872 original_sql = """SELECT
3973 offer_id,
0 commit comments