11use std:: mem;
22
3+ use readyset_errors:: { unsupported, ReadySetError , ReadySetResult } ;
34use readyset_sql:: analysis:: visit_mut:: { self , VisitorMut } ;
45use readyset_sql:: ast:: { BinaryOperator , Expr , InValue , ItemPlaceholder , Literal , SelectStatement } ;
56
@@ -22,7 +23,7 @@ impl AutoParameterizeVisitor {
2223}
2324
2425impl < ' ast > VisitorMut < ' ast > for AutoParameterizeVisitor {
25- type Error = std :: convert :: Infallible ;
26+ type Error = ReadySetError ;
2627
2728 fn visit_literal ( & mut self , literal : & ' ast mut Literal ) -> Result < ( ) , Self :: Error > {
2829 if matches ! ( literal, Literal :: Placeholder ( _) ) {
@@ -55,11 +56,7 @@ impl<'ast> VisitorMut<'ast> for AutoParameterizeVisitor {
5556 if was_supported {
5657 match expression {
5758 Expr :: BinaryOp { lhs, op, rhs } => match ( lhs. as_mut ( ) , op, rhs. as_mut ( ) ) {
58- (
59- Expr :: Column ( _) ,
60- BinaryOperator :: Equal ,
61- Expr :: Literal ( Literal :: Placeholder ( _) ) ,
62- ) => { }
59+ ( Expr :: Column ( _) , BinaryOperator :: Equal , Expr :: Literal ( Literal :: Placeholder ( _) ) ) => { }
6360 ( Expr :: Row { .. } , BinaryOperator :: Equal , Expr :: Row { exprs, .. } ) => {
6461 for expr in exprs {
6562 if let Expr :: Literal ( lit) = expr {
@@ -71,8 +68,7 @@ impl<'ast> VisitorMut<'ast> for AutoParameterizeVisitor {
7168 }
7269 return Ok ( ( ) ) ;
7370 }
74- ( Expr :: Column ( _) , op, Expr :: Literal ( Literal :: Placeholder ( _) ) )
75- if op. is_ordering_comparison ( ) => { }
71+ ( Expr :: Column ( _) , op, Expr :: Literal ( Literal :: Placeholder ( _) ) ) if op. is_ordering_comparison ( ) => { }
7672 ( Expr :: Column ( _) , BinaryOperator :: Equal , Expr :: Literal ( lit) ) => {
7773 if self . autoparameterize_equals {
7874 self . replace_literal ( lit) ;
@@ -85,11 +81,7 @@ impl<'ast> VisitorMut<'ast> for AutoParameterizeVisitor {
8581 }
8682 return Ok ( ( ) ) ;
8783 }
88- (
89- Expr :: Literal ( _) ,
90- BinaryOperator :: Equal | BinaryOperator :: NotEqual ,
91- Expr :: Column ( _) ,
92- ) => {
84+ ( Expr :: Literal ( _) , BinaryOperator :: Equal | BinaryOperator :: NotEqual , Expr :: Column ( _) ) => {
9385 // for lit = col and lit != col, swap the equality first then revisit
9486 mem:: swap ( lhs, rhs) ;
9587 return self . visit_expr ( expression) ;
@@ -116,40 +108,123 @@ impl<'ast> VisitorMut<'ast> for AutoParameterizeVisitor {
116108 rhs : InValue :: List ( exprs) ,
117109 negated : false ,
118110 } => match lhs. as_ref ( ) {
111+ // Case 1: Single-column IN (a IN (1,2,3))
119112 Expr :: Column ( _)
120- if exprs. iter ( ) . all ( |e| {
121- matches ! (
122- e,
123- Expr :: Literal ( lit) if !matches!( lit, Literal :: Placeholder ( _) )
124- )
125- } ) =>
113+ if exprs
114+ . iter ( )
115+ . all ( |e| matches ! ( e, Expr :: Literal ( lit) if !matches!( lit, Literal :: Placeholder ( _) ) ) ) =>
126116 {
127117 if self . autoparameterize_equals {
128118 let exprs = mem:: replace (
129119 exprs,
130120 std:: iter:: repeat_n (
131- Expr :: Literal ( Literal :: Placeholder (
132- ItemPlaceholder :: QuestionMark ,
133- ) ) ,
121+ Expr :: Literal ( Literal :: Placeholder ( ItemPlaceholder :: QuestionMark ) ) ,
134122 exprs. len ( ) ,
135123 )
136124 . collect ( ) ,
137125 ) ;
138126 let num_exprs = exprs. len ( ) ;
139127 let start_index = self . param_index ;
140- self . out . extend ( exprs . into_iter ( ) . enumerate ( ) . filter_map (
141- move |( i, expr) | match expr {
128+ self . out
129+ . extend ( exprs . into_iter ( ) . enumerate ( ) . filter_map ( move |( i, expr) | match expr {
142130 Expr :: Literal ( lit) => Some ( ( i + start_index, lit) ) ,
143131 // unreachable since we checked everything in the list is a
144132 // literal above, but best
145133 // not to panic regardless
146134 _ => None ,
147- } ,
148- ) ) ;
135+ } ) ) ;
136+ self . param_index += num_exprs;
137+ }
138+ return Ok ( ( ) ) ;
139+ }
140+
141+ // Case 2: Tuple IN ((a, b) IN ((1,2), (3,4)))
142+ Expr :: Row { .. }
143+ if exprs. iter ( ) . all ( |e| {
144+ match e {
145+ Expr :: Row { exprs, .. } => exprs. iter ( ) . all (
146+ |e| matches ! ( e, Expr :: Literal ( lit) if !matches!( lit, Literal :: Placeholder ( _) ) ) ,
147+ ) ,
148+ // FIXME(sqlparser): This is a special case because nom parses `(a, b) IN ((1,2))` as
149+ // `(a, b) IN (1,2)` instead of `((a, b)) IN ((1,2))`.
150+ // This case should be removed once migration to sqlparser is
151+ // finalized.
152+ // To fix this, we readd the removed parens before proceeding.
153+ Expr :: Literal ( lit) if !matches ! ( lit, Literal :: Placeholder ( _) ) => true ,
154+ _ => false ,
155+ }
156+ } ) =>
157+ {
158+ if self . autoparameterize_equals {
159+ // FIXME(sqlparser): this handles the special case mentioned in the comment
160+ // just before this
161+ if !exprs. is_empty ( ) && matches ! ( exprs[ 0 ] , Expr :: Literal ( _) ) {
162+ let _ = mem:: replace (
163+ exprs,
164+ vec ! [ Expr :: Row {
165+ exprs: exprs. clone( ) ,
166+ explicit: false ,
167+ } ] ,
168+ ) ;
169+ } ;
170+
171+ let exprs = mem:: replace (
172+ exprs,
173+ exprs
174+ . iter ( )
175+ . map ( |e| -> Result < Expr , ReadySetError > {
176+ match e {
177+ Expr :: Row { exprs, .. } => Ok ( Expr :: Row {
178+ exprs : std:: iter:: repeat_n (
179+ Expr :: Literal ( Literal :: Placeholder ( ItemPlaceholder :: QuestionMark ) ) ,
180+ exprs. len ( ) ,
181+ )
182+ . collect ( ) ,
183+ explicit : false ,
184+ } ) ,
185+ // ideally, this should be fully checked by the guard
186+ // above, unfortunately, it's not because of the workaround
187+ // mentioned above
188+ _ => unsupported ! ( "Expected a ROW of placeholders" ) ,
189+ }
190+ } )
191+ . collect :: < ReadySetResult < Vec < _ > > > ( ) ?,
192+ ) ;
193+
194+ // same as the error above
195+ let num_exprs: usize = exprs
196+ . iter ( )
197+ . map ( |e| match e {
198+ Expr :: Row { exprs, .. } => Ok ( exprs. len ( ) ) ,
199+ _ => unsupported ! ( "Expected a ROW of placeholders" ) ,
200+ } )
201+ . collect :: < ReadySetResult < Vec < _ > > > ( ) ?
202+ . into_iter ( )
203+ . sum ( ) ;
204+
205+ let start_index = self . param_index ;
206+ let param_offset = 0 ;
207+
208+ self . out . extend (
209+ exprs
210+ . into_iter ( )
211+ . flat_map ( |e| match e {
212+ Expr :: Row { exprs, .. } => exprs,
213+ _ => unreachable ! ( ) , // checked above
214+ } )
215+ . enumerate ( )
216+ . map ( |( i, e) | match e {
217+ Expr :: Literal ( lit) => Ok ( ( start_index + param_offset + i, lit) ) ,
218+ _ => unsupported ! ( "Expected ROWs to only contain Literals" ) ,
219+ } )
220+ . collect :: < ReadySetResult < Vec < _ > > > ( ) ?,
221+ ) ;
222+
149223 self . param_index += num_exprs;
150224 }
151225 return Ok ( ( ) ) ;
152226 }
227+
153228 _ => self . in_supported_position = false ,
154229 } ,
155230 _ => self . in_supported_position = false ,
@@ -218,7 +293,6 @@ impl<'ast> VisitorMut<'ast> for AnalyzeLiteralsVisitor {
218293 }
219294 return Ok ( ( ) ) ;
220295 }
221- // We don't parametrize `(a,b) IN ((w,x),(y,z))`
222296 ( Expr :: Row { .. } , BinaryOperator :: Equal , Expr :: Row { exprs, .. } ) => {
223297 self . contains_equal = true ;
224298 for expr in exprs {
@@ -235,11 +309,7 @@ impl<'ast> VisitorMut<'ast> for AnalyzeLiteralsVisitor {
235309 }
236310 return Ok ( ( ) ) ;
237311 }
238- (
239- Expr :: Literal ( _) ,
240- BinaryOperator :: Equal | BinaryOperator :: NotEqual ,
241- Expr :: Column ( _) ,
242- ) => {
312+ ( Expr :: Literal ( _) , BinaryOperator :: Equal | BinaryOperator :: NotEqual , Expr :: Column ( _) ) => {
243313 // for lit = col and lit != col, swap the equality first then revisit
244314 mem:: swap ( lhs, rhs) ;
245315 return self . visit_expr ( expression) ;
@@ -276,14 +346,23 @@ impl<'ast> VisitorMut<'ast> for AnalyzeLiteralsVisitor {
276346 rhs : InValue :: List ( exprs) ,
277347 negated : false ,
278348 } if exprs. iter ( ) . all ( |e| {
279- matches ! (
280- e,
281- Expr :: Literal ( lit) if !matches!( lit, Literal :: Placeholder ( _) )
282- )
349+ match e {
350+ // Case 1: Single-column IN (a IN (1,2,3))
351+ Expr :: Literal ( lit) if !matches ! ( lit, Literal :: Placeholder ( _) ) => true ,
352+ // Case 2: Multi-column IN ((a,b) IN ((1,2), (3,4)))
353+ Expr :: Row { exprs, .. }
354+ if exprs. iter ( ) . all (
355+ |inner| matches ! ( inner, Expr :: Literal ( lit) if !matches!( lit, Literal :: Placeholder ( _) ) ) ,
356+ ) =>
357+ {
358+ true
359+ }
360+ _ => false ,
361+ }
283362 } ) && !self . has_aggregates =>
284363 {
285364 match lhs. as_ref ( ) {
286- Expr :: Column ( _) => {
365+ Expr :: Column ( _) | Expr :: Row { .. } => {
287366 self . contains_equal = true ;
288367 return Ok ( ( ) ) ;
289368 }
@@ -496,10 +575,10 @@ mod tests {
496575 #[ test]
497576 fn literal_in_subquery_where ( ) {
498577 test_auto_parameterize_mysql (
499- "SELECT id FROM users JOIN (SELECT id FROM users WHERE id = 1) s ON users.id = s.id WHERE id = 1" ,
500- "SELECT id FROM users JOIN (SELECT id FROM users WHERE id = 1) s ON users.id = s.id WHERE id = ?" ,
501- vec ! [ ( 0 , 1 . into( ) ) ] ,
502- )
578+ "SELECT id FROM users JOIN (SELECT id FROM users WHERE id = 1) s ON users.id = s.id WHERE id = 1" ,
579+ "SELECT id FROM users JOIN (SELECT id FROM users WHERE id = 1) s ON users.id = s.id WHERE id = ?" ,
580+ vec ! [ ( 0 , 1 . into( ) ) ] ,
581+ )
503582 }
504583
505584 #[ test]
@@ -511,6 +590,29 @@ mod tests {
511590 )
512591 }
513592
593+ #[ test]
594+ fn row_in_predicate ( ) {
595+ // FIXME(sqlparser): Read the FIXME above, the expected query gets parsed incorrectly
596+ // because of nom, but the actual query itself works as expected because of the hardcoded
597+ // check above
598+ // test_auto_parameterize_mysql(
599+ // "SELECT * FROM t WHERE (a, b) IN ((1, 10))",
600+ // "SELECT * FROM t WHERE (a, b) IN ((?, ?))",
601+ // vec![(0, 1.into()), (1, 10.into())],
602+ // );
603+
604+ test_auto_parameterize_mysql (
605+ "SELECT * FROM t WHERE (a, b) IN ((1, 'str'),(2, 'string'))" ,
606+ "SELECT * FROM t WHERE (a, b) IN ((?, ?), (?, ?))" ,
607+ vec ! [
608+ ( 0 , 1 . into( ) ) ,
609+ ( 1 , "str" . into( ) ) ,
610+ ( 2 , 2 . into( ) ) ,
611+ ( 3 , "string" . into( ) ) ,
612+ ] ,
613+ ) ;
614+ }
615+
514616 #[ test]
515617 fn literal_in_in_rhs ( ) {
516618 test_auto_parameterize_mysql (
0 commit comments