@@ -1747,23 +1747,38 @@ impl Expr {
1747
1747
pub fn infer_placeholder_types ( self , schema : & DFSchema ) -> Result < ( Expr , bool ) > {
1748
1748
let mut has_placeholder = false ;
1749
1749
self . transform ( |mut expr| {
1750
- // Default to assuming the arguments are the same type
1751
- if let Expr :: BinaryExpr ( BinaryExpr { left, op : _, right } ) = & mut expr {
1752
- rewrite_placeholder ( left. as_mut ( ) , right. as_ref ( ) , schema) ?;
1753
- rewrite_placeholder ( right. as_mut ( ) , left. as_ref ( ) , schema) ?;
1754
- } ;
1755
- if let Expr :: Between ( Between {
1756
- expr,
1757
- negated : _,
1758
- low,
1759
- high,
1760
- } ) = & mut expr
1761
- {
1762
- rewrite_placeholder ( low. as_mut ( ) , expr. as_ref ( ) , schema) ?;
1763
- rewrite_placeholder ( high. as_mut ( ) , expr. as_ref ( ) , schema) ?;
1764
- }
1765
- if let Expr :: Placeholder ( _) = & expr {
1766
- has_placeholder = true ;
1750
+ match & mut expr {
1751
+ // Default to assuming the arguments are the same type
1752
+ Expr :: BinaryExpr ( BinaryExpr { left, op : _, right } ) => {
1753
+ rewrite_placeholder ( left. as_mut ( ) , right. as_ref ( ) , schema) ?;
1754
+ rewrite_placeholder ( right. as_mut ( ) , left. as_ref ( ) , schema) ?;
1755
+ }
1756
+ Expr :: Between ( Between {
1757
+ expr,
1758
+ negated : _,
1759
+ low,
1760
+ high,
1761
+ } ) => {
1762
+ rewrite_placeholder ( low. as_mut ( ) , expr. as_ref ( ) , schema) ?;
1763
+ rewrite_placeholder ( high. as_mut ( ) , expr. as_ref ( ) , schema) ?;
1764
+ }
1765
+ Expr :: InList ( InList {
1766
+ expr,
1767
+ list,
1768
+ negated : _,
1769
+ } ) => {
1770
+ for item in list. iter_mut ( ) {
1771
+ rewrite_placeholder ( item, expr. as_ref ( ) , schema) ?;
1772
+ }
1773
+ }
1774
+ Expr :: Like ( Like { expr, pattern, .. } )
1775
+ | Expr :: SimilarTo ( Like { expr, pattern, .. } ) => {
1776
+ rewrite_placeholder ( pattern. as_mut ( ) , expr. as_ref ( ) , schema) ?;
1777
+ }
1778
+ Expr :: Placeholder ( _) => {
1779
+ has_placeholder = true ;
1780
+ }
1781
+ _ => { }
1767
1782
}
1768
1783
Ok ( Transformed :: yes ( expr) )
1769
1784
} )
@@ -3185,10 +3200,117 @@ mod test {
3185
3200
case, lit, qualified_wildcard, wildcard, wildcard_with_options, ColumnarValue ,
3186
3201
ScalarFunctionArgs , ScalarUDF , ScalarUDFImpl , Volatility ,
3187
3202
} ;
3203
+ use arrow:: datatypes:: { Field , Schema } ;
3188
3204
use sqlparser:: ast;
3189
3205
use sqlparser:: ast:: { Ident , IdentWithAlias } ;
3190
3206
use std:: any:: Any ;
3191
3207
3208
+ #[ test]
3209
+ fn infer_placeholder_in_clause ( ) {
3210
+ // SELECT * FROM employees WHERE department_id IN ($1, $2, $3);
3211
+ let column = col ( "department_id" ) ;
3212
+ let param_placeholders = vec ! [
3213
+ Expr :: Placeholder ( Placeholder {
3214
+ id: "$1" . to_string( ) ,
3215
+ data_type: None ,
3216
+ } ) ,
3217
+ Expr :: Placeholder ( Placeholder {
3218
+ id: "$2" . to_string( ) ,
3219
+ data_type: None ,
3220
+ } ) ,
3221
+ Expr :: Placeholder ( Placeholder {
3222
+ id: "$3" . to_string( ) ,
3223
+ data_type: None ,
3224
+ } ) ,
3225
+ ] ;
3226
+ let in_list = Expr :: InList ( InList {
3227
+ expr : Box :: new ( column) ,
3228
+ list : param_placeholders,
3229
+ negated : false ,
3230
+ } ) ;
3231
+
3232
+ let schema = Arc :: new ( Schema :: new ( vec ! [
3233
+ Field :: new( "name" , DataType :: Utf8 , true ) ,
3234
+ Field :: new( "department_id" , DataType :: Int32 , true ) ,
3235
+ ] ) ) ;
3236
+ let df_schema = DFSchema :: try_from ( schema) . unwrap ( ) ;
3237
+
3238
+ let ( inferred_expr, contains_placeholder) =
3239
+ in_list. infer_placeholder_types ( & df_schema) . unwrap ( ) ;
3240
+
3241
+ assert ! ( contains_placeholder) ;
3242
+
3243
+ match inferred_expr {
3244
+ Expr :: InList ( in_list) => {
3245
+ for expr in in_list. list {
3246
+ match expr {
3247
+ Expr :: Placeholder ( placeholder) => {
3248
+ assert_eq ! (
3249
+ placeholder. data_type,
3250
+ Some ( DataType :: Int32 ) ,
3251
+ "Placeholder {} should infer Int32" ,
3252
+ placeholder. id
3253
+ ) ;
3254
+ }
3255
+ _ => panic ! ( "Expected Placeholder expression" ) ,
3256
+ }
3257
+ }
3258
+ }
3259
+ _ => panic ! ( "Expected InList expression" ) ,
3260
+ }
3261
+ }
3262
+
3263
+ #[ test]
3264
+ fn infer_placeholder_like_and_similar_to ( ) {
3265
+ // name LIKE $1
3266
+ let schema =
3267
+ Arc :: new ( Schema :: new ( vec ! [ Field :: new( "name" , DataType :: Utf8 , true ) ] ) ) ;
3268
+ let df_schema = DFSchema :: try_from ( schema) . unwrap ( ) ;
3269
+
3270
+ let like = Like {
3271
+ expr : Box :: new ( col ( "name" ) ) ,
3272
+ pattern : Box :: new ( Expr :: Placeholder ( Placeholder {
3273
+ id : "$1" . to_string ( ) ,
3274
+ data_type : None ,
3275
+ } ) ) ,
3276
+ negated : false ,
3277
+ case_insensitive : false ,
3278
+ escape_char : None ,
3279
+ } ;
3280
+
3281
+ let expr = Expr :: Like ( like. clone ( ) ) ;
3282
+
3283
+ let ( inferred_expr, _) = expr. infer_placeholder_types ( & df_schema) . unwrap ( ) ;
3284
+ match inferred_expr {
3285
+ Expr :: Like ( like) => match * like. pattern {
3286
+ Expr :: Placeholder ( placeholder) => {
3287
+ assert_eq ! ( placeholder. data_type, Some ( DataType :: Utf8 ) ) ;
3288
+ }
3289
+ _ => panic ! ( "Expected Placeholder" ) ,
3290
+ } ,
3291
+ _ => panic ! ( "Expected Like" ) ,
3292
+ }
3293
+
3294
+ // name SIMILAR TO $1
3295
+ let expr = Expr :: SimilarTo ( like) ;
3296
+
3297
+ let ( inferred_expr, _) = expr. infer_placeholder_types ( & df_schema) . unwrap ( ) ;
3298
+ match inferred_expr {
3299
+ Expr :: SimilarTo ( like) => match * like. pattern {
3300
+ Expr :: Placeholder ( placeholder) => {
3301
+ assert_eq ! (
3302
+ placeholder. data_type,
3303
+ Some ( DataType :: Utf8 ) ,
3304
+ "Placeholder {} should infer Utf8" ,
3305
+ placeholder. id
3306
+ ) ;
3307
+ }
3308
+ _ => panic ! ( "Expected Placeholder expression" ) ,
3309
+ } ,
3310
+ _ => panic ! ( "Expected SimilarTo expression" ) ,
3311
+ }
3312
+ }
3313
+
3192
3314
#[ test]
3193
3315
#[ allow( deprecated) ]
3194
3316
fn format_case_when ( ) -> Result < ( ) > {
0 commit comments