Skip to content

Commit fa22b44

Browse files
committed
feat: add parameter types to query statement
1 parent a006284 commit fa22b44

File tree

5 files changed

+85
-18
lines changed

5 files changed

+85
-18
lines changed

Cargo.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/query/src/datafusion/planner.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use datafusion::sql::planner::{ContextProvider, PlannerContext, SqlToRel};
2424
use datafusion_common::ScalarValue;
2525
use datafusion_expr::TableSource;
2626
use datatypes::arrow::datatypes::DataType;
27+
use datatypes::prelude::DataType as DataTypeTrait;
2728
use session::context::QueryContextRef;
2829
use snafu::ResultExt;
2930
use sql::statements::explain::Explain;
@@ -51,9 +52,16 @@ impl<'a, S: ContextProvider + Send + Sync> DfPlanner<'a, S> {
5152
pub fn query_to_plan(&self, query: Box<Query>) -> Result<LogicalPlan> {
5253
// todo(hl): original SQL should be provided as an argument
5354
let sql = query.inner.to_string();
55+
let mut context = PlannerContext::new_with_prepare_param_data_types(
56+
query
57+
.param_types()
58+
.iter()
59+
.map(|v| v.as_arrow_type())
60+
.collect(),
61+
);
5462
let result = self
5563
.sql_to_rel
56-
.query_to_plan(query.inner, &mut PlannerContext::default())
64+
.query_to_plan(query.inner, &mut context)
5765
.context(error::PlanSqlSnafu { sql })
5866
.map_err(BoxedError::new)
5967
.context(QueryPlanSnafu)?;

src/servers/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ num_cpus = "1.13"
3838
once_cell = "1.16"
3939
openmetrics-parser = "0.4"
4040
opensrv-mysql = { git = "https://github.com/datafuselabs/opensrv", rev = "b44c9d1360da297b305abf33aecfa94888e1554c" }
41-
pgwire = "0.8"
41+
pgwire = "0.9"
4242
pin-project = "1.0"
4343
prost.workspace = true
4444
query = { path = "../query" }

src/servers/src/postgres/handler.rs

Lines changed: 57 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,13 @@
1414

1515
use std::ops::Deref;
1616
use std::sync::Arc;
17+
use std::time::Duration;
1718

1819
use async_trait::async_trait;
1920
use common_query::Output;
2021
use common_recordbatch::error::Result as RecordBatchResult;
2122
use common_recordbatch::RecordBatch;
23+
use common_time::timestamp::TimeUnit;
2224
use datatypes::prelude::{ConcreteDataType, Value};
2325
use datatypes::schema::{Schema, SchemaRef};
2426
use futures::{future, stream, Stream, StreamExt};
@@ -28,7 +30,7 @@ use pgwire::api::results::{
2830
binary_query_response, text_query_response, BinaryDataRowEncoder, FieldInfo, Response, Tag,
2931
TextDataRowEncoder,
3032
};
31-
use pgwire::api::stmt::QueryParser;
33+
use pgwire::api::stmt::{QueryParser, StoredStatement};
3234
use pgwire::api::store::MemPortalStore;
3335
use pgwire::api::{ClientInfo, Type};
3436
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
@@ -148,7 +150,7 @@ fn schema_to_pg(origin: &Schema) -> Result<Vec<FieldInfo>> {
148150
col.name.clone(),
149151
None,
150152
None,
151-
type_translate(&col.data_type)?,
153+
type_gt_to_pg(&col.data_type)?,
152154
))
153155
})
154156
.collect::<Result<Vec<FieldInfo>>>()
@@ -198,9 +200,19 @@ fn encode_binary_value(value: &Value, builder: &mut BinaryDataRowEncoder) -> PgW
198200
Value::Float64(v) => builder.append_field(&v.0),
199201
Value::String(v) => builder.append_field(&v.as_utf8()),
200202
Value::Binary(v) => builder.append_field(&v.deref()),
201-
Value::Date(v) => builder.append_field(&v.to_string()),
202-
Value::DateTime(v) => builder.append_field(&v.to_string()),
203-
Value::Timestamp(v) => builder.append_field(&v.to_iso8601_string()),
203+
Value::Date(v) => builder.append_field(&v.to_string()), // TOOD
204+
Value::DateTime(v) => builder.append_field(&v.to_string()), //TODO
205+
Value::Timestamp(v) => {
206+
// convert timestamp to SystemTime
207+
if let Some(ts) = v.convert_to(TimeUnit::Microsecond) {
208+
let sys_time = std::time::UNIX_EPOCH + Duration::from_micros(ts.value() as u64);
209+
builder.append_field(&sys_time)
210+
} else {
211+
Err(PgWireError::ApiError(Box::new(Error::Internal {
212+
err_msg: format!("Failed to conver timestamp to postgres type {v:?}",),
213+
})))
214+
}
215+
}
204216
Value::List(_) => Err(PgWireError::ApiError(Box::new(Error::Internal {
205217
err_msg: format!(
206218
"cannot write value {:?} in postgres protocol: unimplemented",
@@ -210,7 +222,7 @@ fn encode_binary_value(value: &Value, builder: &mut BinaryDataRowEncoder) -> PgW
210222
}
211223
}
212224

213-
fn type_translate(origin: &ConcreteDataType) -> Result<Type> {
225+
fn type_gt_to_pg(origin: &ConcreteDataType) -> Result<Type> {
214226
match origin {
215227
&ConcreteDataType::Null(_) => Ok(Type::UNKNOWN),
216228
&ConcreteDataType::Boolean(_) => Ok(Type::BOOL),
@@ -232,13 +244,34 @@ fn type_translate(origin: &ConcreteDataType) -> Result<Type> {
232244
}
233245
}
234246

247+
fn type_pg_to_gt(origin: &Type) -> Result<ConcreteDataType> {
248+
// Note that we only support a small amount of pg data types
249+
match origin {
250+
&Type::BOOL => Ok(ConcreteDataType::boolean_datatype()),
251+
&Type::CHAR => Ok(ConcreteDataType::int8_datatype()),
252+
&Type::INT2 => Ok(ConcreteDataType::int16_datatype()),
253+
&Type::INT4 => Ok(ConcreteDataType::int32_datatype()),
254+
&Type::INT8 => Ok(ConcreteDataType::int64_datatype()),
255+
&Type::VARCHAR | &Type::TEXT => Ok(ConcreteDataType::string_datatype()),
256+
&Type::TIMESTAMP => Ok(ConcreteDataType::timestamp_datatype(
257+
common_time::timestamp::TimeUnit::Millisecond,
258+
)),
259+
&Type::DATE => Ok(ConcreteDataType::date_datatype()),
260+
&Type::TIME => Ok(ConcreteDataType::datetime_datatype()),
261+
_ => error::InternalSnafu {
262+
err_msg: format!("unimplemented datatype {origin:?}"),
263+
}
264+
.fail(),
265+
}
266+
}
267+
235268
#[derive(Default)]
236269
pub struct POCQueryParser;
237270

238271
impl QueryParser for POCQueryParser {
239272
type Statement = (Statement, String);
240273

241-
fn parse_sql(&self, sql: &str) -> PgWireResult<Self::Statement> {
274+
fn parse_sql(&self, sql: &str, types: &[Type]) -> PgWireResult<Self::Statement> {
242275
let mut stmts = ParserContext::create_with_dialect(sql, &GenericDialect {})
243276
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
244277
if stmts.len() != 1 {
@@ -248,13 +281,22 @@ impl QueryParser for POCQueryParser {
248281
"invalid_prepared_statement_definition".to_owned(),
249282
))))
250283
} else {
251-
Ok((stmts.remove(0), sql.to_owned()))
284+
let mut stmt = stmts.remove(0);
285+
if let Statement::Query(qs) = &mut stmt {
286+
for t in types {
287+
let gt_type =
288+
type_pg_to_gt(t).map_err(|e| PgWireError::ApiError(Box::new(e)))?;
289+
qs.param_types_mut().push(gt_type);
290+
}
291+
}
292+
293+
Ok((stmt, sql.to_owned()))
252294
}
253295
}
254296
}
255297

256298
fn parameter_to_string(portal: &Portal<(Statement, String)>, idx: usize) -> PgWireResult<String> {
257-
let param_type = portal.parameter_types().get(idx).unwrap();
299+
let param_type = portal.statement().parameter_types().get(idx).unwrap();
258300
match param_type {
259301
&Type::VARCHAR | &Type::TEXT => Ok(format!(
260302
"\"{}\"",
@@ -294,6 +336,9 @@ fn parameter_to_string(portal: &Portal<(Statement, String)>, idx: usize) -> PgWi
294336
//
295337
// - getting schema from
296338
// - setting parameters in
339+
//
340+
// Datafusion's LogicalPlan is a good candidate for SELECT. But we need to
341+
// confirm it's support for other SQL command like INSERT, UPDATE.
297342
#[async_trait]
298343
impl ExtendedQueryHandler for PostgresServerHandler {
299344
type Statement = (Statement, String);
@@ -317,15 +362,14 @@ impl ExtendedQueryHandler for PostgresServerHandler {
317362
where
318363
C: ClientInfo + Unpin + Send + Sync,
319364
{
320-
let (_, sql) = portal.statement();
365+
let (_, sql) = portal.statement().statement();
321366

322367
// manually replace variables in prepared statement
323368
let mut sql = sql.to_owned();
324369
for i in 0..portal.parameter_len() {
325370
sql = sql.replace(&format!("${}", i + 1), &parameter_to_string(portal, i)?);
326371
}
327372

328-
dbg!(&sql);
329373
let output = self
330374
.query_handler
331375
.do_query(&sql, self.query_ctx.clone())
@@ -338,12 +382,12 @@ impl ExtendedQueryHandler for PostgresServerHandler {
338382
async fn do_describe<C>(
339383
&self,
340384
_client: &mut C,
341-
statement: &Self::Statement,
385+
statement: &StoredStatement<Self::Statement>,
342386
) -> PgWireResult<Vec<FieldInfo>>
343387
where
344388
C: ClientInfo + Unpin + Send + Sync,
345389
{
346-
let (stmt, _) = statement;
390+
let (stmt, _) = statement.statement();
347391
if let Some(schema) = self
348392
.query_handler
349393
.do_describe(stmt.clone(), self.query_ctx.clone())

src/sql/src/statements/query.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
use datatypes::prelude::ConcreteDataType;
1516
use sqlparser::ast::Query as SpQuery;
1617

1718
use crate::error::Error;
@@ -20,14 +21,18 @@ use crate::error::Error;
2021
#[derive(Debug, Clone, PartialEq, Eq)]
2122
pub struct Query {
2223
pub inner: SpQuery,
24+
pub param_types: Vec<ConcreteDataType>,
2325
}
2426

2527
/// Automatically converts from sqlparser Query instance to SqlQuery.
2628
impl TryFrom<SpQuery> for Query {
2729
type Error = Error;
2830

2931
fn try_from(q: SpQuery) -> Result<Self, Self::Error> {
30-
Ok(Query { inner: q })
32+
Ok(Query {
33+
inner: q,
34+
param_types: vec![],
35+
})
3136
}
3237
}
3338

@@ -38,3 +43,13 @@ impl TryFrom<Query> for SpQuery {
3843
Ok(value.inner)
3944
}
4045
}
46+
47+
impl Query {
48+
pub fn param_types(&self) -> &Vec<ConcreteDataType> {
49+
&self.param_types
50+
}
51+
52+
pub fn param_types_mut(&mut self) -> &mut Vec<ConcreteDataType> {
53+
&mut self.param_types
54+
}
55+
}

0 commit comments

Comments
 (0)