Skip to content

Commit 6979c65

Browse files
committed
implemented partitioning for Trino
1 parent aaab7de commit 6979c65

File tree

8 files changed

+119
-23
lines changed

8 files changed

+119
-23
lines changed

Cargo.lock

Lines changed: 5 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

connectorx-python/Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

connectorx-python/connectorx/tests/test_trino.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def test_trino_limit_large_with_partition(trino_url: str) -> None:
154154

155155

156156
def test_trino_with_partition_without_partition_range(trino_url: str) -> None:
157-
query = "SELECT * FROM test.test_table where test_float > 3 order by test_int"
157+
query = "SELECT * FROM test.test_table where test_float > 3"
158158
df = read_sql(
159159
trino_url,
160160
query,
@@ -170,6 +170,7 @@ def test_trino_with_partition_without_partition_range(trino_url: str) -> None:
170170
},
171171
)
172172
df.sort_values(by="test_int", inplace=True, ignore_index=True)
173+
173174
assert_frame_equal(df, expected, check_names=True)
174175

175176

@@ -210,7 +211,7 @@ def test_trino_selection_and_projection(trino_url: str) -> None:
210211

211212

212213
def test_trino_join(trino_url: str) -> None:
213-
query = "SELECT T.test_int, T.test_float, S.test_str FROM test_table T INNER JOIN test_table_extra S ON T.test_int = S.test_int order by T.test_int"
214+
query = "SELECT T.test_int, T.test_float, S.test_str FROM test.test_table T INNER JOIN test.test_table_extra S ON T.test_int = S.test_int order by T.test_int"
214215
df = read_sql(
215216
trino_url,
216217
query,
@@ -262,7 +263,7 @@ def test_trino_types_binary(trino_url: str) -> None:
262263
"test_real": pd.Series([123.456, 123.456, None], dtype="float64"),
263264
"test_double": pd.Series([123.4567890123, 123.4567890123, None], dtype="float64"),
264265
"test_decimal": pd.Series([1234567890.12, 1234567890.12, None], dtype="float64"),
265-
"test_date": pd.Series([None, "2023-01-01", "2023-01-01"], dtype="datetime64[ns]"),
266+
"test_date": pd.Series(["2023-01-01", "2023-01-01", None], dtype="datetime64[ns]"),
266267
"test_time": pd.Series(["12:00:00", "12:00:00", None], dtype="object"),
267268
"test_timestamp": pd.Series(["2023-01-01 12:00:00.123456", "2023-01-01 12:00:00.123456", None], dtype="datetime64[ns]"),
268269
"test_varchar": pd.Series(["Sample text", "Sample text", None], dtype="object"),
@@ -299,7 +300,7 @@ def test_empty_result_on_partition(trino_url: str) -> None:
299300

300301

301302
def test_empty_result_on_some_partition(trino_url: str) -> None:
302-
query = "SELECT * FROM test_table where test_int = 6"
303+
query = "SELECT * FROM test.test_table where test_int = 6"
303304
df = read_sql(trino_url, query, partition_on="test_int", partition_num=3)
304305
expected = pd.DataFrame(
305306
index=range(1),

connectorx-python/src/pandas/get_meta.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use super::{
22
destination::PandasDestination,
33
transports::{
44
BigQueryPandasTransport, MsSQLPandasTransport, MysqlPandasTransport, OraclePandasTransport,
5-
PostgresPandasTransport, SqlitePandasTransport,
5+
PostgresPandasTransport, SqlitePandasTransport, TrinoPandasTransport,
66
},
77
};
88
use crate::errors::ConnectorXPythonError;
@@ -18,6 +18,7 @@ use connectorx::{
1818
PostgresSource, SimpleProtocol,
1919
},
2020
sqlite::SQLiteSource,
21+
trino::TrinoSource,
2122
},
2223
sql::CXQuery,
2324
};
@@ -223,6 +224,17 @@ pub fn get_meta<'a>(py: Python<'a>, conn: &str, protocol: &str, query: String) -
223224
debug!("Running dispatcher");
224225
dispatcher.get_meta()?;
225226
}
227+
SourceType::Trino => {
228+
let rt = Arc::new(tokio::runtime::Runtime::new().expect("Failed to create runtime"));
229+
let source = TrinoSource::new(rt, &source_conn.conn[..])?;
230+
let dispatcher = Dispatcher::<_, _, TrinoPandasTransport>::new(
231+
source,
232+
&mut destination,
233+
queries,
234+
None,
235+
);
236+
dispatcher.run()?;
237+
}
226238
_ => unimplemented!("{:?} not implemented!", source_conn.ty),
227239
}
228240

connectorx/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ uuid = {version = "0.8", optional = true}
5858
j4rs = {version = "0.15", optional = true}
5959
datafusion = {version = "31", optional = true}
6060
prusto = {version = "0.5.1", optional = true}
61+
serde = {optional = true}
6162

6263
[lib]
6364
crate-type = ["cdylib", "rlib"]
@@ -98,7 +99,7 @@ src_postgres = [
9899
"postgres-openssl",
99100
]
100101
src_sqlite = ["rusqlite", "r2d2_sqlite", "fallible-streaming-iterator", "r2d2", "urlencoding"]
101-
src_trino = ["prusto", "uuid", "urlencoding", "rust_decimal", "tokio", "num-traits"]
102+
src_trino = ["prusto", "uuid", "urlencoding", "rust_decimal", "tokio", "num-traits", "serde"]
102103
federation = ["j4rs"]
103104
fed_exec = ["datafusion", "tokio"]
104105
integrated-auth-gssapi = ["tiberius/integrated-auth-gssapi"]

connectorx/src/partition.rs

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use crate::sources::mysql::{MySQLSourceError, MySQLTypeSystem};
1010
use crate::sources::oracle::{connect_oracle, OracleDialect};
1111
#[cfg(feature = "src_postgres")]
1212
use crate::sources::postgres::{rewrite_tls_args, PostgresTypeSystem};
13+
use crate::sources::trino::TrinoDialect;
1314
#[cfg(feature = "src_sqlite")]
1415
use crate::sql::get_partition_range_query_sep;
1516
use crate::sql::{get_partition_range_query, single_col_partition_query, CXQuery};
@@ -35,7 +36,7 @@ use sqlparser::dialect::PostgreSqlDialect;
3536
use sqlparser::dialect::SQLiteDialect;
3637
#[cfg(feature = "src_mssql")]
3738
use tiberius::Client;
38-
#[cfg(any(feature = "src_bigquery", feature = "src_mssql"))]
39+
#[cfg(any(feature = "src_bigquery", feature = "src_mssql", feature = "src_trino"))]
3940
use tokio::{net::TcpStream, runtime::Runtime};
4041
#[cfg(feature = "src_mssql")]
4142
use tokio_util::compat::TokioAsyncWriteCompatExt;
@@ -100,6 +101,8 @@ pub fn get_col_range(source_conn: &SourceConn, query: &str, col: &str) -> OutRes
100101
SourceType::Oracle => oracle_get_partition_range(&source_conn.conn, query, col),
101102
#[cfg(feature = "src_bigquery")]
102103
SourceType::BigQuery => bigquery_get_partition_range(&source_conn.conn, query, col),
104+
#[cfg(feature = "src_trino")]
105+
SourceType::Trino => trino_get_partition_range(&source_conn.conn, query, col),
103106
_ => unimplemented!("{:?} not implemented!", source_conn.ty),
104107
}
105108
}
@@ -137,6 +140,10 @@ pub fn get_part_query(
137140
SourceType::BigQuery => {
138141
single_col_partition_query(query, col, lower, upper, &BigQueryDialect {})?
139142
}
143+
#[cfg(feature = "src_trino")]
144+
SourceType::Trino => {
145+
single_col_partition_query(query, col, lower, upper, &TrinoDialect {})?
146+
}
140147
_ => unimplemented!("{:?} not implemented!", source_conn.ty),
141148
};
142149
CXQuery::Wrapped(query)
@@ -481,3 +488,52 @@ fn bigquery_get_partition_range(conn: &Url, query: &str, col: &str) -> (i64, i64
481488

482489
(min_v, max_v)
483490
}
491+
492+
#[cfg(feature = "src_trino")]
493+
#[throws(ConnectorXOutError)]
494+
fn trino_get_partition_range(conn: &Url, query: &str, col: &str) -> (i64, i64) {
495+
use prusto::{auth::Auth, ClientBuilder};
496+
497+
use crate::sources::trino::{TrinoDialect, TrinoPartitionQueryResult};
498+
499+
let rt = Runtime::new().expect("Failed to create runtime");
500+
501+
let username = match conn.username() {
502+
"" => "connectorx",
503+
username => username,
504+
};
505+
506+
let builder = ClientBuilder::new(username, conn.host().unwrap().to_owned())
507+
.port(conn.port().unwrap_or(8080))
508+
.ssl(prusto::ssl::Ssl { root_cert: None })
509+
.secure(conn.scheme() == "trino+https")
510+
.catalog(conn.path_segments().unwrap().last().unwrap_or("hive"));
511+
512+
let builder = match conn.password() {
513+
None => builder,
514+
Some(password) => builder.auth(Auth::Basic(username.to_owned(), Some(password.to_owned()))),
515+
};
516+
517+
let client = builder
518+
.build()
519+
.map_err(|e| anyhow!("Failed to build client: {}", e))?;
520+
521+
let range_query = get_partition_range_query(query, col, &TrinoDialect {})?;
522+
let query_result = rt.block_on(client.get_all::<TrinoPartitionQueryResult>(range_query));
523+
524+
let query_result = match query_result {
525+
Ok(query_result) => Ok(query_result.into_vec()),
526+
Err(e) => match e {
527+
prusto::error::Error::EmptyData => {
528+
Ok(vec![TrinoPartitionQueryResult { _col0: 0, _col1: 0 }])
529+
}
530+
_ => Err(anyhow!("Failed to get query result: {}", e)),
531+
},
532+
}?;
533+
534+
let result = query_result
535+
.first()
536+
.unwrap_or(&TrinoPartitionQueryResult { _col0: 0, _col1: 0 });
537+
538+
(result._col0, result._col1)
539+
}

connectorx/src/sources/trino/mod.rs

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use chrono::{NaiveDate, NaiveDateTime, NaiveTime};
44
use fehler::{throw, throws};
55
use prusto::{auth::Auth, Client, ClientBuilder, DataSet, Presto, Row};
66
use serde_json::Value;
7-
use sqlparser::dialect::GenericDialect;
7+
use sqlparser::dialect::{Dialect, GenericDialect};
88
use std::convert::TryFrom;
99
use tokio::runtime::Runtime;
1010

@@ -32,6 +32,26 @@ fn get_total_rows(rt: Arc<Runtime>, client: Arc<Client>, query: &CXQuery<String>
3232
.len()
3333
}
3434

35+
#[derive(Presto, Debug)]
36+
pub struct TrinoPartitionQueryResult {
37+
pub _col0: i64,
38+
pub _col1: i64,
39+
}
40+
41+
#[derive(Debug)]
42+
pub struct TrinoDialect {}
43+
44+
// implementation copy from AnsiDialect
45+
impl Dialect for TrinoDialect {
46+
fn is_identifier_start(&self, ch: char) -> bool {
47+
ch.is_ascii_lowercase() || ch.is_ascii_uppercase()
48+
}
49+
50+
fn is_identifier_part(&self, ch: char) -> bool {
51+
ch.is_ascii_lowercase() || ch.is_ascii_uppercase() || ch.is_ascii_digit() || ch == '_'
52+
}
53+
}
54+
3555
pub struct TrinoSource {
3656
client: Arc<Client>,
3757
rt: Arc<Runtime>,
@@ -282,12 +302,12 @@ macro_rules! impl_produce_int {
282302
match value {
283303
Value::Number(x) => {
284304
if (x.is_i64()) {
285-
<$t>::try_from(x.as_i64().unwrap()).map_err(|_| anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx))?
305+
<$t>::try_from(x.as_i64().unwrap()).map_err(|_| anyhow!("Trino cannot parse i64 at position: ({}, {}) {:?}", ridx, cidx, value))?
286306
} else {
287-
throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx))
307+
throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, x))
288308
}
289309
}
290-
_ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx))
310+
_ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, value))
291311
}
292312
}
293313
}
@@ -304,12 +324,12 @@ macro_rules! impl_produce_int {
304324
Value::Null => None,
305325
Value::Number(x) => {
306326
if (x.is_i64()) {
307-
Some(<$t>::try_from(x.as_i64().unwrap()).map_err(|_| anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx))?)
327+
Some(<$t>::try_from(x.as_i64().unwrap()).map_err(|_| anyhow!("Trino cannot parse i64 at position: ({}, {}) {:?}", ridx, cidx, value))?)
308328
} else {
309-
throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx))
329+
throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, x))
310330
}
311331
}
312-
_ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx))
332+
_ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, value))
313333
}
314334
}
315335
}
@@ -333,10 +353,11 @@ macro_rules! impl_produce_float {
333353
if (x.is_f64()) {
334354
x.as_f64().unwrap() as $t
335355
} else {
336-
throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx))
356+
throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, x))
337357
}
338358
}
339-
_ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx))
359+
Value::String(x) => x.parse::<$t>().map_err(|_| anyhow!("Trino cannot parse String at position: ({}, {}) {:?}", ridx, cidx, value))?,
360+
_ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, value))
340361
}
341362
}
342363
}
@@ -355,10 +376,11 @@ macro_rules! impl_produce_float {
355376
if (x.is_f64()) {
356377
Some(x.as_f64().unwrap() as $t)
357378
} else {
358-
throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx))
379+
throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, x))
359380
}
360381
}
361-
_ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {})", ridx, cidx))
382+
Value::String(x) => Some(x.parse::<$t>().map_err(|_| anyhow!("Trino cannot parse String at position: ({}, {}) {:?}", ridx, cidx, value))?),
383+
_ => throw!(anyhow!("Trino cannot parse Number at position: ({}, {}) {:?}", ridx, cidx, value))
362384
}
363385
}
364386
}

connectorx/src/sources/trino/typesystem.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use super::errors::TrinoSourceError;
22
use chrono::{NaiveDate, NaiveDateTime, NaiveTime};
33
use fehler::{throw, throws};
4-
use prusto::{PrestoFloat, PrestoInt, PrestoTy};
4+
use prusto::{Presto, PrestoFloat, PrestoInt, PrestoTy};
55
use std::convert::TryFrom;
66

77
// TODO: implement Tuple, Row, Array and Map as well as UUID
@@ -64,6 +64,7 @@ impl TryFrom<PrestoTy> for TrinoTypeSystem {
6464
PrestoTy::Map(_, _) => Varchar(true),
6565
PrestoTy::Decimal(_, _) => Double(true),
6666
PrestoTy::IpAddress => Varchar(true),
67+
PrestoTy::Uuid => Varchar(true),
6768
_ => throw!(TrinoSourceError::InferTypeFromNull),
6869
}
6970
}
@@ -97,6 +98,7 @@ impl TryFrom<(Option<&str>, PrestoTy)> for TrinoTypeSystem {
9798
"map" => Varchar(true),
9899
"decimal" => Double(true),
99100
"ipaddress" => Varchar(true),
101+
"uuid" => Varchar(true),
100102
_ => TrinoTypeSystem::try_from(ty)?,
101103
}
102104
}

0 commit comments

Comments
 (0)