@@ -26,9 +26,7 @@ use common_query::prelude::Signature;
2626use common_query:: Output ;
2727use common_recordbatch:: error:: { ExternalSnafu , Result as RecordBatchResult } ;
2828use common_recordbatch:: { RecordBatch , RecordBatchStream , SendableRecordBatchStream } ;
29- use common_telemetry:: logging;
3029use datafusion_expr:: Volatility ;
31- use datatypes:: prelude:: ConcreteDataType ;
3230use datatypes:: schema:: { ColumnSchema , SchemaRef } ;
3331use datatypes:: vectors:: VectorRef ;
3432use futures:: Stream ;
@@ -38,22 +36,28 @@ use snafu::{ensure, ResultExt};
3836use sql:: statements:: statement:: Statement ;
3937
4038use crate :: engine:: { CompileContext , EvalContext , Script , ScriptEngine } ;
41- use crate :: python:: coprocessor:: { exec_parsed, parse, CoprocessorRef } ;
39+ use crate :: python:: coprocessor:: { exec_parsed, parse, AnnotationInfo , CoprocessorRef } ;
4240use crate :: python:: error:: { self , Result } ;
4341
4442const PY_ENGINE : & str = "python" ;
4543
46- pub struct PyUdf {
44+ #[ derive( Debug ) ]
45+ pub struct PyUDF {
4746 copr : CoprocessorRef ,
4847}
4948
50- impl std:: fmt:: Display for PyUdf {
49+ impl std:: fmt:: Display for PyUDF {
5150 fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
52- write ! ( f, "{}" , & self . copr. name)
51+ write ! (
52+ f,
53+ "{}({})->" ,
54+ & self . copr. name,
55+ & self . copr. deco_args. arg_names. join( "," )
56+ )
5357 }
5458}
5559
56- impl PyUdf {
60+ impl PyUDF {
5761 fn from_copr ( copr : CoprocessorRef ) -> Arc < Self > {
5862 Arc :: new ( Self { copr } )
5963 }
@@ -79,7 +83,7 @@ impl PyUdf {
7983 }
8084}
8185
82- impl Function for PyUdf {
86+ impl Function for PyUDF {
8387 fn name ( & self ) -> & str {
8488 & self . copr . name
8589 }
@@ -88,29 +92,46 @@ impl Function for PyUdf {
8892 & self ,
8993 _input_types : & [ datatypes:: prelude:: ConcreteDataType ] ,
9094 ) -> common_query:: error:: Result < datatypes:: prelude:: ConcreteDataType > {
91- // TODO: use correct return annotation if exist
92- Ok ( self . copr . return_types [ 0 ]
93- . clone ( )
94- . map_or ( ConcreteDataType :: float64_datatype ( ) , |anno| {
95- anno. datatype
96- . unwrap_or ( ConcreteDataType :: float64_datatype ( ) )
97- } ) )
95+ // TODO(discord9): use correct return annotation if exist
96+ match self . copr . return_types . get ( 0 ) {
97+ Some ( Some ( AnnotationInfo {
98+ datatype : Some ( ty) , ..
99+ } ) ) => Ok ( ty. to_owned ( ) ) ,
100+ _ => PyUdfSnafu {
101+ msg : "Can't found return type for python UDF {self}" ,
102+ }
103+ . fail ( ) ,
104+ }
98105 }
99106
100107 fn signature ( & self ) -> common_query:: prelude:: Signature {
101- Signature :: uniform (
102- self . copr . arg_types . len ( ) ,
103- ConcreteDataType :: numerics ( ) ,
104- Volatility :: Immutable ,
105- )
108+ // try our best to get a type signature
109+ let mut arg_types = Vec :: with_capacity ( self . copr . arg_types . len ( ) ) ;
110+ let mut know_all_types = true ;
111+ for ty in self . copr . arg_types . iter ( ) {
112+ match ty {
113+ Some ( AnnotationInfo {
114+ datatype : Some ( ty) , ..
115+ } ) => arg_types. push ( ty. to_owned ( ) ) ,
116+ _ => {
117+ know_all_types = false ;
118+ break ;
119+ }
120+ }
121+ }
122+ if know_all_types {
123+ Signature :: variadic ( arg_types, Volatility :: Immutable )
124+ } else {
125+ Signature :: any ( self . copr . arg_types . len ( ) , Volatility :: Immutable )
126+ }
106127 }
107128
108129 fn eval (
109130 & self ,
110131 _func_ctx : common_function:: scalars:: function:: FunctionContext ,
111132 columns : & [ datatypes:: vectors:: VectorRef ] ,
112133 ) -> common_query:: error:: Result < datatypes:: vectors:: VectorRef > {
113- // FIXME: exec_parsed require a RecordBatch(basically a Vector+Schema), where schema can't pop out from nowhere, right?
134+ // FIXME(discord9) : exec_parsed require a RecordBatch(basically a Vector+Schema), where schema can't pop out from nowhere, right?
114135 let schema = self . fake_schema ( columns) ;
115136 let columns = columns. to_vec ( ) ;
116137 // TODO(discord9): remove unwrap
@@ -122,17 +143,13 @@ impl Function for PyUdf {
122143 . build ( )
123144 } ) ?;
124145 let len = res. columns ( ) . len ( ) ;
125- if len != 1 {
126- logging:: info!( "Python UDF should return exactly one column, found {len} column(s)" ) ;
127- if len == 0 {
128- return PyUdfSnafu {
129- msg : format ! (
130- "Python UDF should return exactly one column, found {len} column(s)"
131- ) ,
132- }
133- . fail ( ) ;
134- } // if more than one columns, just return first one
135- }
146+ if len == 0 {
147+ return PyUdfSnafu {
148+ msg : "Python UDF should return exactly one column, found zero column" . to_string ( ) ,
149+ }
150+ . fail ( ) ;
151+ } // if more than one columns, just return first one
152+
136153 // TODO(discord9): more error handling
137154 let res0 = res. column ( 0 ) ;
138155 Ok ( res0. to_owned ( ) )
@@ -148,9 +165,9 @@ impl PyScript {
148165 /// Register Current Script as UDF, register name is same as script name
149166 /// FIXME(discord9): possible inject attack?
150167 pub fn register_udf ( & self ) {
151- let udf = PyUdf :: from_copr ( self . copr . clone ( ) ) ;
152- PyUdf :: register_as_udf ( udf. clone ( ) ) ;
153- PyUdf :: register_to_query_engine ( udf, self . query_engine . to_owned ( ) ) ;
168+ let udf = PyUDF :: from_copr ( self . copr . clone ( ) ) ;
169+ PyUDF :: register_as_udf ( udf. clone ( ) ) ;
170+ PyUDF :: register_to_query_engine ( udf, self . query_engine . to_owned ( ) ) ;
154171 }
155172}
156173
0 commit comments