Skip to content

Commit 2eab28c

Browse files
committed
chore: PR advices
1 parent 68ffc67 commit 2eab28c

File tree

3 files changed

+54
-36
lines changed

3 files changed

+54
-36
lines changed

src/common/query/src/error.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ pub enum Error {
3131
PyUdf {
3232
// TODO(discord9): find a way that prevent circle depend(query<-script<-query) and can use script's error type
3333
msg: String,
34+
backtrace: Backtrace,
3435
},
3536

3637
#[snafu(display(

src/script/src/python/engine.rs

Lines changed: 52 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,7 @@ use common_query::prelude::Signature;
2626
use common_query::Output;
2727
use common_recordbatch::error::{ExternalSnafu, Result as RecordBatchResult};
2828
use common_recordbatch::{RecordBatch, RecordBatchStream, SendableRecordBatchStream};
29-
use common_telemetry::logging;
3029
use datafusion_expr::Volatility;
31-
use datatypes::prelude::ConcreteDataType;
3230
use datatypes::schema::{ColumnSchema, SchemaRef};
3331
use datatypes::vectors::VectorRef;
3432
use futures::Stream;
@@ -38,22 +36,28 @@ use snafu::{ensure, ResultExt};
3836
use sql::statements::statement::Statement;
3937

4038
use crate::engine::{CompileContext, EvalContext, Script, ScriptEngine};
41-
use crate::python::coprocessor::{exec_parsed, parse, CoprocessorRef};
39+
use crate::python::coprocessor::{exec_parsed, parse, AnnotationInfo, CoprocessorRef};
4240
use crate::python::error::{self, Result};
4341

4442
const PY_ENGINE: &str = "python";
4543

46-
pub struct PyUdf {
44+
#[derive(Debug)]
45+
pub struct PyUDF {
4746
copr: CoprocessorRef,
4847
}
4948

50-
impl std::fmt::Display for PyUdf {
49+
impl std::fmt::Display for PyUDF {
5150
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52-
write!(f, "{}", &self.copr.name)
51+
write!(
52+
f,
53+
"{}({})->",
54+
&self.copr.name,
55+
&self.copr.deco_args.arg_names.join(",")
56+
)
5357
}
5458
}
5559

56-
impl PyUdf {
60+
impl PyUDF {
5761
fn from_copr(copr: CoprocessorRef) -> Arc<Self> {
5862
Arc::new(Self { copr })
5963
}
@@ -79,7 +83,7 @@ impl PyUdf {
7983
}
8084
}
8185

82-
impl Function for PyUdf {
86+
impl Function for PyUDF {
8387
fn name(&self) -> &str {
8488
&self.copr.name
8589
}
@@ -88,29 +92,46 @@ impl Function for PyUdf {
8892
&self,
8993
_input_types: &[datatypes::prelude::ConcreteDataType],
9094
) -> common_query::error::Result<datatypes::prelude::ConcreteDataType> {
91-
// TODO: use correct return annotation if exist
92-
Ok(self.copr.return_types[0]
93-
.clone()
94-
.map_or(ConcreteDataType::float64_datatype(), |anno| {
95-
anno.datatype
96-
.unwrap_or(ConcreteDataType::float64_datatype())
97-
}))
95+
// TODO(discord9): use correct return annotation if exist
96+
match self.copr.return_types.get(0) {
97+
Some(Some(AnnotationInfo {
98+
datatype: Some(ty), ..
99+
})) => Ok(ty.to_owned()),
100+
_ => PyUdfSnafu {
101+
msg: "Can't found return type for python UDF {self}",
102+
}
103+
.fail(),
104+
}
98105
}
99106

100107
fn signature(&self) -> common_query::prelude::Signature {
101-
Signature::uniform(
102-
self.copr.arg_types.len(),
103-
ConcreteDataType::numerics(),
104-
Volatility::Immutable,
105-
)
108+
// try our best to get a type signature
109+
let mut arg_types = Vec::with_capacity(self.copr.arg_types.len());
110+
let mut know_all_types = true;
111+
for ty in self.copr.arg_types.iter() {
112+
match ty {
113+
Some(AnnotationInfo {
114+
datatype: Some(ty), ..
115+
}) => arg_types.push(ty.to_owned()),
116+
_ => {
117+
know_all_types = false;
118+
break;
119+
}
120+
}
121+
}
122+
if know_all_types {
123+
Signature::variadic(arg_types, Volatility::Immutable)
124+
} else {
125+
Signature::any(self.copr.arg_types.len(), Volatility::Immutable)
126+
}
106127
}
107128

108129
fn eval(
109130
&self,
110131
_func_ctx: common_function::scalars::function::FunctionContext,
111132
columns: &[datatypes::vectors::VectorRef],
112133
) -> common_query::error::Result<datatypes::vectors::VectorRef> {
113-
// FIXME: exec_parsed require a RecordBatch(basically a Vector+Schema), where schema can't pop out from nowhere, right?
134+
// FIXME(discord9): exec_parsed require a RecordBatch(basically a Vector+Schema), where schema can't pop out from nowhere, right?
114135
let schema = self.fake_schema(columns);
115136
let columns = columns.to_vec();
116137
// TODO(discord9): remove unwrap
@@ -122,17 +143,13 @@ impl Function for PyUdf {
122143
.build()
123144
})?;
124145
let len = res.columns().len();
125-
if len != 1 {
126-
logging::info!("Python UDF should return exactly one column, found {len} column(s)");
127-
if len == 0 {
128-
return PyUdfSnafu {
129-
msg: format!(
130-
"Python UDF should return exactly one column, found {len} column(s)"
131-
),
132-
}
133-
.fail();
134-
} // if more than one columns, just return first one
135-
}
146+
if len == 0 {
147+
return PyUdfSnafu {
148+
msg: "Python UDF should return exactly one column, found zero column".to_string(),
149+
}
150+
.fail();
151+
} // if more than one columns, just return first one
152+
136153
// TODO(discord9): more error handling
137154
let res0 = res.column(0);
138155
Ok(res0.to_owned())
@@ -148,9 +165,9 @@ impl PyScript {
148165
/// Register Current Script as UDF, register name is same as script name
149166
/// FIXME(discord9): possible inject attack?
150167
pub fn register_udf(&self) {
151-
let udf = PyUdf::from_copr(self.copr.clone());
152-
PyUdf::register_as_udf(udf.clone());
153-
PyUdf::register_to_query_engine(udf, self.query_engine.to_owned());
168+
let udf = PyUDF::from_copr(self.copr.clone());
169+
PyUDF::register_as_udf(udf.clone());
170+
PyUDF::register_to_query_engine(udf, self.query_engine.to_owned());
154171
}
155172
}
156173

src/servers/tests/py_script/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use table::test_util::MemTable;
2222
use crate::create_testing_instance;
2323

2424
#[tokio::test]
25-
async fn test_insert_udf_and_query() -> Result<()> {
25+
async fn test_insert_py_udf_and_query() -> Result<()> {
2626
let query_ctx = Arc::new(QueryContext::new());
2727
let table = MemTable::default_numbers_table();
2828

0 commit comments

Comments
 (0)