diff --git a/connectorx-cpp/src/lib.rs b/connectorx-cpp/src/lib.rs
index 2169f139fe..0e3081d240 100644
--- a/connectorx-cpp/src/lib.rs
+++ b/connectorx-cpp/src/lib.rs
@@ -176,7 +176,7 @@ pub unsafe extern "C" fn connectorx_scan(conn: *const c_char, query: *const c_ch
     let conn_str = unsafe { CStr::from_ptr(conn) }.to_str().unwrap();
     let query_str = unsafe { CStr::from_ptr(query) }.to_str().unwrap();
     let source_conn = SourceConn::try_from(conn_str).unwrap();
-    let record_batches = get_arrow(&source_conn, None, &[CXQuery::from(query_str)])
+    let record_batches = get_arrow(&source_conn, None, &[CXQuery::from(query_str)], None)
         .unwrap()
        .arrow()
        .unwrap();
@@ -281,7 +281,7 @@ pub unsafe extern "C" fn connectorx_scan_iter(
     }

     let arrow_iter: Box<dyn RecordBatchIterator> =
-        new_record_batch_iter(&source_conn, None, query_vec.as_slice(), batch_size);
+        new_record_batch_iter(&source_conn, None, query_vec.as_slice(), batch_size, None);

     Box::into_raw(Box::new(arrow_iter))
 }
diff --git a/connectorx-python/connectorx/__init__.py b/connectorx-python/connectorx/__init__.py
index 737a1de57c..55e95d3ef4 100644
--- a/connectorx-python/connectorx/__init__.py
+++ b/connectorx-python/connectorx/__init__.py
@@ -131,6 +131,7 @@ def read_sql_pandas(
     partition_on: str | None = None,
     partition_range: tuple[int, int] | None = None,
     partition_num: int | None = None,
+    pre_execution_queries: list[str] | str | None = None,
 ) -> pd.DataFrame:
     """
     Run the SQL query, download the data from database into a dataframe.
@@ -160,6 +161,7 @@ def read_sql_pandas(
         partition_range=partition_range,
         partition_num=partition_num,
         index_col=index_col,
+        pre_execution_query=pre_execution_queries,
     )

@@ -174,6 +176,7 @@ def read_sql(
     partition_range: tuple[int, int] | None = None,
     partition_num: int | None = None,
     index_col: str | None = None,
+    pre_execution_query: list[str] | str | None = None,
 ) -> pd.DataFrame: ...

@@ -188,6 +191,7 @@ def read_sql(
     partition_range: tuple[int, int] | None = None,
     partition_num: int | None = None,
     index_col: str | None = None,
+    pre_execution_query: list[str] | str | None = None,
 ) -> pd.DataFrame: ...

@@ -202,6 +206,7 @@ def read_sql(
     partition_range: tuple[int, int] | None = None,
     partition_num: int | None = None,
     index_col: str | None = None,
+    pre_execution_query: list[str] | str | None = None,
 ) -> pa.Table: ...

@@ -216,6 +221,7 @@ def read_sql(
     partition_range: tuple[int, int] | None = None,
     partition_num: int | None = None,
     index_col: str | None = None,
+    pre_execution_query: list[str] | str | None = None,
 ) -> mpd.DataFrame: ...

@@ -230,6 +236,7 @@ def read_sql(
     partition_range: tuple[int, int] | None = None,
     partition_num: int | None = None,
     index_col: str | None = None,
+    pre_execution_query: list[str] | str | None = None,
 ) -> dd.DataFrame: ...

@@ -244,6 +251,7 @@ def read_sql(
     partition_range: tuple[int, int] | None = None,
     partition_num: int | None = None,
     index_col: str | None = None,
+    pre_execution_query: list[str] | str | None = None,
 ) -> pl.DataFrame: ...

@@ -260,6 +268,7 @@ def read_sql(
     partition_num: int | None = None,
     index_col: str | None = None,
     strategy: str | None = None,
+    pre_execution_query: list[str] | str | None = None,
 ) -> pd.DataFrame | mpd.DataFrame | dd.DataFrame | pl.DataFrame | pa.Table:
     """
     Run the SQL query, download the data from database into a dataframe.
@@ -285,6 +294,9 @@ def read_sql(
       the index column to set; only applicable for return type "pandas", "modin", "dask".
     strategy
       strategy of rewriting the federated query for join pushdown
+    pre_execution_query
+      SQL query or list of SQL queries executed before the main query; can be used to set
+      runtime configurations with SET statements; only applicable to the "Postgres" and "MySQL" sources.

     Examples
     ========
@@ -358,6 +370,13 @@ def read_sql(
             raise ValueError("Partition on multiple queries is not supported.")
     else:
         raise ValueError("query must be either str or a list of str")
+
+    if isinstance(pre_execution_query, list):
+        pre_execution_queries = [remove_ending_semicolon(subquery) for subquery in pre_execution_query]
+    elif isinstance(pre_execution_query, str):
+        pre_execution_queries = [remove_ending_semicolon(pre_execution_query)]
+    else:
+        pre_execution_queries = None

     conn, protocol = rewrite_conn(conn, protocol)

@@ -370,6 +389,7 @@ def read_sql(
             queries=queries,
             protocol=protocol,
             partition_query=partition_query,
+            pre_execution_queries=pre_execution_queries,
         )
         df = reconstruct_pandas(result)

@@ -392,6 +412,7 @@ def read_sql(
             queries=queries,
             protocol=protocol,
             partition_query=partition_query,
+            pre_execution_queries=pre_execution_queries,
         )
         df = reconstruct_arrow(result)
         if return_type in {"polars"}:
diff --git a/connectorx-python/connectorx/connectorx.pyi b/connectorx-python/connectorx/connectorx.pyi
index f116a97fa9..f63709c079 100644
--- a/connectorx-python/connectorx/connectorx.pyi
+++ b/connectorx-python/connectorx/connectorx.pyi
@@ -25,6 +25,7 @@ def read_sql(
     protocol: str | None,
     queries: list[str] | None,
     partition_query: dict[str, Any] | None,
+    pre_execution_queries: list[str] | None,
 ) -> _DataframeInfos: ...
 @overload
 def read_sql(
@@ -33,6 +34,7 @@ def read_sql(
     protocol: str | None,
     queries: list[str] | None,
     partition_query: dict[str, Any] | None,
+    pre_execution_queries: list[str] | None,
 ) -> _ArrowInfos: ...
 def partition_sql(conn: str, partition_query: dict[str, Any]) -> list[str]: ...
 def read_sql2(sql: str, db_map: dict[str, str]) -> _ArrowInfos: ...
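
For orientation while reviewing the Python surface above, here is a minimal usage sketch of the new parameter. The DSN and table name are placeholders and it assumes a reachable Postgres instance; the parameter accepts either a single statement or a list that is executed in order on the connection before the main query runs:

```python
import connectorx as cx

# Hypothetical connection string and table -- substitute your own.
conn = "postgresql://user:password@localhost:5432/db"

# A single pre-execution statement, run before the main query.
df = cx.read_sql(
    conn,
    "SELECT * FROM lineitem",
    pre_execution_query="SET SESSION statement_timeout = 10000",
)

# A list of statements is replayed in order; trailing semicolons are stripped.
df = cx.read_sql(
    conn,
    "SELECT * FROM lineitem",
    pre_execution_query=[
        "SET SESSION statement_timeout = 10000;",
        "SET SESSION idle_in_transaction_session_timeout = 10000;",
    ],
)
```
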
diff --git a/connectorx-python/connectorx/tests/test_mysql.py b/connectorx-python/connectorx/tests/test_mysql.py
index d76b56305e..ce3267de48 100644
--- a/connectorx-python/connectorx/tests/test_mysql.py
+++ b/connectorx-python/connectorx/tests/test_mysql.py
@@ -478,3 +478,52 @@ def test_mysql_cte(mysql_url: str) -> None:

 def test_connection_url(mysql_url: str) -> None:
     test_mysql_cte(ConnectionUrl(mysql_url))
+
+
+def test_mysql_single_pre_execution_queries(mysql_url: str) -> None:
+    pre_execution_query = "SET SESSION max_execution_time = 2151"
+    query = "SELECT @@SESSION.max_execution_time AS max_execution_time"
+    df = read_sql(mysql_url, query, pre_execution_query=pre_execution_query)
+    expected = pd.DataFrame(
+        index=range(1),
+        data={"max_execution_time": pd.Series([2151], dtype="float64")},
+    )
+    assert_frame_equal(df, expected, check_names=True)
+
+
+def test_mysql_multiple_pre_execution_queries(mysql_url: str) -> None:
+    pre_execution_query = [
+        "SET SESSION max_execution_time = 2151",
+        "SET SESSION wait_timeout = 2252",
+    ]
+    query = "SELECT @@SESSION.max_execution_time AS max_execution_time, @@SESSION.wait_timeout AS wait_timeout"
+    df = read_sql(mysql_url, query, pre_execution_query=pre_execution_query)
+    expected = pd.DataFrame(
+        index=range(1),
+        data={
+            "max_execution_time": pd.Series([2151], dtype="float64"),
+            "wait_timeout": pd.Series([2252], dtype="float64"),
+        },
+    )
+    assert_frame_equal(df, expected, check_names=True)
+
+
+def test_mysql_partitioned_pre_execution_queries(mysql_url: str) -> None:
+    pre_execution_query = [
+        "SET SESSION max_execution_time = 2151",
+        "SET SESSION wait_timeout = 2252",
+    ]
+    query = [
+        'SELECT "max_execution_time" AS name, @@SESSION.max_execution_time AS setting',
+        'SELECT "wait_timeout" AS name, @@SESSION.wait_timeout AS setting',
+    ]
+    df = read_sql(mysql_url, query, pre_execution_query=pre_execution_query).sort_values(
+        by=["name"], ignore_index=True
+    )
+    expected = pd.DataFrame(
+        index=range(2),
+        data={
+            "name": pd.Series(["max_execution_time", "wait_timeout"], dtype="str"),
+            "setting": pd.Series([2151, 2252], dtype="float64"),
+        },
+    ).sort_values(by=["name"], ignore_index=True)
+
+    assert_frame_equal(df, expected, check_names=True)
diff --git a/connectorx-python/connectorx/tests/test_postgres.py b/connectorx-python/connectorx/tests/test_postgres.py
index b18e133888..ae8801de87 100644
--- a/connectorx-python/connectorx/tests/test_postgres.py
+++ b/connectorx-python/connectorx/tests/test_postgres.py
@@ -1008,4 +1008,53 @@ def test_postgres_partition_with_orderby_limit_desc(postgres_url: str) -> None:
         },
     )
     df.sort_values(by="test_int", inplace=True, ignore_index=True)
-    assert_frame_equal(df, expected, check_names=True)
\ No newline at end of file
+    assert_frame_equal(df, expected, check_names=True)
+
+
+def test_postgres_single_pre_execution_queries(postgres_url: str) -> None:
+    pre_execution_query = "SET SESSION statement_timeout = 2151"
+    query = "SELECT CAST(name AS TEXT) AS name, CAST(setting AS INTEGER) AS setting FROM pg_settings WHERE name = 'statement_timeout'"
+    df = read_sql(postgres_url, query, pre_execution_query=pre_execution_query)
+    expected = pd.DataFrame(
+        index=range(1),
+        data={
+            "name": pd.Series(["statement_timeout"], dtype="str"),
+            "setting": pd.Series([2151], dtype="Int64"),
+        },
+    )
+    assert_frame_equal(df, expected, check_names=True)
+
+
+def test_postgres_multiple_pre_execution_queries(postgres_url: str) -> None:
+    pre_execution_query = [
+        "SET SESSION statement_timeout = 2151",
+        "SET SESSION idle_in_transaction_session_timeout = 2252",
+    ]
+    query = "SELECT CAST(name AS TEXT) AS name, CAST(setting AS INTEGER) AS setting FROM pg_settings WHERE name IN ('statement_timeout', 'idle_in_transaction_session_timeout') ORDER BY name"
+    df = read_sql(postgres_url, query, pre_execution_query=pre_execution_query)
+    expected = pd.DataFrame(
+        index=range(2),
+        data={
+            "name": pd.Series(["idle_in_transaction_session_timeout", "statement_timeout"], dtype="str"),
+            "setting": pd.Series([2252, 2151], dtype="Int64"),
+        },
+    )
+    assert_frame_equal(df, expected, check_names=True)
+
+
+def test_postgres_partitioned_pre_execution_queries(postgres_url: str) -> None:
+    pre_execution_query = [
+        "SET SESSION statement_timeout = 2151",
+        "SET SESSION idle_in_transaction_session_timeout = 2252",
+    ]
+    query = [
+        "SELECT CAST(name AS TEXT) AS name, CAST(setting AS INTEGER) AS setting FROM pg_settings WHERE name = 'statement_timeout'",
+        "SELECT CAST(name AS TEXT) AS name, CAST(setting AS INTEGER) AS setting FROM pg_settings WHERE name = 'idle_in_transaction_session_timeout'",
+    ]
+
+    df = read_sql(postgres_url, query, pre_execution_query=pre_execution_query).sort_values(
+        by=["name"], ignore_index=True
+    )
+    expected = pd.DataFrame(
+        index=range(2),
+        data={
+            "name": pd.Series(["statement_timeout", "idle_in_transaction_session_timeout"], dtype="str"),
+            "setting": pd.Series([2151, 2252], dtype="Int64"),
+        },
+    ).sort_values(by=["name"], ignore_index=True)
+    assert_frame_equal(df, expected, check_names=True)
diff --git a/connectorx-python/src/arrow.rs b/connectorx-python/src/arrow.rs
index d8101be18a..f506f60dd8 100644
--- a/connectorx-python/src/arrow.rs
+++ b/connectorx-python/src/arrow.rs
@@ -15,10 +15,11 @@ pub fn write_arrow<'py>(
     source_conn: &SourceConn,
     origin_query: Option<String>,
     queries: &[CXQuery<String>],
+    pre_execution_queries: Option<&[String]>,
 ) -> Bound<'py, PyAny> {
     let ptrs = py.allow_threads(
         || -> Result<(Vec<String>, Vec<Vec<(uintptr_t, uintptr_t)>>), ConnectorXPythonError> {
-            let destination = get_arrow(source_conn, origin_query, queries)?;
+            let destination = get_arrow(source_conn, origin_query, queries, pre_execution_queries)?;
             let rbs = destination.arrow()?;
             Ok(to_ptrs(rbs))
         },
diff --git a/connectorx-python/src/cx_read_sql.rs b/connectorx-python/src/cx_read_sql.rs
index 2a1a457998..a95981716d 100644
--- a/connectorx-python/src/cx_read_sql.rs
+++ b/connectorx-python/src/cx_read_sql.rs
@@ -38,6 +38,7 @@ pub fn read_sql<'py>(
     protocol: Option<&str>,
     queries: Option<Vec<String>>,
     partition_query: Option<PyPartitionQuery>,
+    pre_execution_queries: Option<Vec<String>>,
 ) -> PyResult<Bound<'py, PyAny>> {
     let source_conn = parse_source(conn, protocol).map_err(|e| ConnectorXPythonError::from(e))?;
     let (queries, origin_query) = match (queries, partition_query) {
@@ -62,12 +63,14 @@ pub fn read_sql<'py>(
             &source_conn,
             origin_query,
             &queries,
+            pre_execution_queries.as_deref(),
         )?),
         "arrow" => Ok(crate::arrow::write_arrow(
             py,
             &source_conn,
             origin_query,
             &queries,
+            pre_execution_queries.as_deref(),
         )?),
         _ => Err(PyValueError::new_err(format!(
             "return type should be 'pandas' or 'arrow', got '{}'",
diff --git a/connectorx-python/src/lib.rs b/connectorx-python/src/lib.rs
index a7c0d962ec..b4f573b35a 100644
--- a/connectorx-python/src/lib.rs
+++ b/connectorx-python/src/lib.rs
@@ -39,7 +39,7 @@ fn connectorx(_: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
 }

 #[pyfunction]
-#[pyo3(signature = (conn, return_type, protocol=None, queries=None, partition_query=None))]
+#[pyo3(signature = (conn, return_type, protocol=None, queries=None, partition_query=None, pre_execution_queries=None))]
 pub fn read_sql<'py>(
     py: Python<'py>,
     conn: &str,
@@ -47,8 +47,17 @@ pub fn read_sql<'py>(
     protocol: Option<&str>,
     queries: Option<Vec<String>>,
     partition_query: Option<PyPartitionQuery>,
+    pre_execution_queries: Option<Vec<String>>,
 ) -> PyResult<Bound<'py, PyAny>> {
-    cx_read_sql::read_sql(py, conn, return_type, protocol, queries, partition_query)
+    cx_read_sql::read_sql(
+        py,
+        conn,
+        return_type,
+        protocol,
+        queries,
+        partition_query,
+        pre_execution_queries,
+    )
 }

 #[pyfunction]
diff --git a/connectorx-python/src/pandas/dispatcher.rs b/connectorx-python/src/pandas/dispatcher.rs
index 4d23a6e3e8..3c634b82a5 100644
--- a/connectorx-python/src/pandas/dispatcher.rs
+++ b/connectorx-python/src/pandas/dispatcher.rs
@@ -41,6 +41,10 @@ where
         }
     }

+    pub fn set_pre_execution_queries(&mut self, pre_execution_queries: Option<&[String]>) {
+        self.src.set_pre_execution_queries(pre_execution_queries);
+    }
+
     /// Start the data loading process.
     pub fn run(mut self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TP::Error> {
         debug!("Run dispatcher");
diff --git a/connectorx-python/src/pandas/mod.rs b/connectorx-python/src/pandas/mod.rs
index 15c91f7783..e335f5d3d2 100644
--- a/connectorx-python/src/pandas/mod.rs
+++ b/connectorx-python/src/pandas/mod.rs
@@ -40,6 +40,7 @@ pub fn write_pandas<'a, 'py: 'a>(
     source_conn: &SourceConn,
     origin_query: Option<String>,
     queries: &[CXQuery<String>],
+    pre_execution_queries: Option<&[String]>,
 ) -> Bound<'py, PyAny> {
     let destination = PandasDestination::new();
     let protocol = source_conn.proto.as_str();
@@ -55,23 +56,25 @@ pub fn write_pandas<'a, 'py: 'a>(
                     tls_conn,
                     queries.len(),
                 )?;
-                let dispatcher = PandasDispatcher::<
+                let mut dispatcher = PandasDispatcher::<
                     _,
                     PostgresPandasTransport<'py, CSVProtocol, MakeTlsConnector>,
                 >::new(
                     sb, destination, queries, origin_query
                 );
+                dispatcher.set_pre_execution_queries(pre_execution_queries);
                 dispatcher.run(py)?
             }
             ("csv", None) => {
                 let sb = PostgresSource::<CSVProtocol, NoTls>::new(config, NoTls, queries.len())?;
-                let dispatcher = PandasDispatcher::<
+                let mut dispatcher = PandasDispatcher::<
                     _,
                     PostgresPandasTransport<'py, CSVProtocol, NoTls>,
                 >::new(
                     sb, destination, queries, origin_query
                 );
+                dispatcher.set_pre_execution_queries(pre_execution_queries);
                 dispatcher.run(py)?
             }
             ("binary", Some(tls_conn)) => {
@@ -80,11 +83,12 @@ pub fn write_pandas<'a, 'py: 'a>(
                     tls_conn,
                     queries.len(),
                 )?;
-                let dispatcher =
+                let mut dispatcher =
                     PandasDispatcher::<
                         _,
                         PostgresPandasTransport<'py, PgBinaryProtocol, MakeTlsConnector>,
                     >::new(sb, destination, queries, origin_query);
+                dispatcher.set_pre_execution_queries(pre_execution_queries);
                 dispatcher.run(py)?
             }
             ("binary", None) => {
@@ -93,12 +97,13 @@ pub fn write_pandas<'a, 'py: 'a>(
                     NoTls,
                     queries.len(),
                 )?;
-                let dispatcher = PandasDispatcher::<
+                let mut dispatcher = PandasDispatcher::<
                     _,
                     PostgresPandasTransport<'py, PgBinaryProtocol, NoTls>,
                 >::new(
                     sb, destination, queries, origin_query
                 );
+                dispatcher.set_pre_execution_queries(pre_execution_queries);
                 dispatcher.run(py)?
             }
             ("cursor", Some(tls_conn)) => {
@@ -107,22 +112,24 @@ pub fn write_pandas<'a, 'py: 'a>(
                     tls_conn,
                     queries.len(),
                 )?;
-                let dispatcher =
+                let mut dispatcher =
                     PandasDispatcher::<
                         _,
                         PostgresPandasTransport<'py, CursorProtocol, MakeTlsConnector>,
                     >::new(sb, destination, queries, origin_query);
+                dispatcher.set_pre_execution_queries(pre_execution_queries);
                 dispatcher.run(py)?
             }
             ("cursor", None) => {
                 let sb =
                     PostgresSource::<CursorProtocol, NoTls>::new(config, NoTls, queries.len())?;
-                let dispatcher = PandasDispatcher::<
+                let mut dispatcher = PandasDispatcher::<
                     _,
                     PostgresPandasTransport<'py, CursorProtocol, NoTls>,
                 >::new(
                     sb, destination, queries, origin_query
                 );
+                dispatcher.set_pre_execution_queries(pre_execution_queries);
                 dispatcher.run(py)?
             }
             ("simple", Some(tls_conn)) => {
@@ -131,22 +138,24 @@ pub fn write_pandas<'a, 'py: 'a>(
                     tls_conn,
                     queries.len(),
                 )?;
-                let dispatcher =
+                let mut dispatcher =
                     PandasDispatcher::<
                         _,
                         PostgresPandasTransport<'py, SimpleProtocol, MakeTlsConnector>,
                     >::new(sb, destination, queries, origin_query);
+                dispatcher.set_pre_execution_queries(pre_execution_queries);
                 dispatcher.run(py)?
             }
             ("simple", None) => {
                 let sb =
                     PostgresSource::<SimpleProtocol, NoTls>::new(config, NoTls, queries.len())?;
-                let dispatcher = PandasDispatcher::<
+                let mut dispatcher = PandasDispatcher::<
                     _,
                     PostgresPandasTransport<'py, SimpleProtocol, NoTls>,
                 >::new(
                     sb, destination, queries, origin_query
                 );
+                dispatcher.set_pre_execution_queries(pre_execution_queries);
                 dispatcher.run(py)?
             }
             _ => unimplemented!("{} protocol not supported", protocol),
@@ -168,24 +177,26 @@ pub fn write_pandas<'a, 'py: 'a>(
             "binary" => {
                 let source =
                     MySQLSource::<MySQLBinaryProtocol>::new(&source_conn.conn[..], queries.len())?;
-                let dispatcher =
+                let mut dispatcher =
                     PandasDispatcher::<_, MysqlPandasTransport<'py, MySQLBinaryProtocol>>::new(
                         source,
                         destination,
                         queries,
                         origin_query,
                     );
+                dispatcher.set_pre_execution_queries(pre_execution_queries);
                 dispatcher.run(py)?
             }
             "text" => {
                 let source =
                     MySQLSource::<TextProtocol>::new(&source_conn.conn[..], queries.len())?;
-                let dispatcher = PandasDispatcher::<_, MysqlPandasTransport<'py, TextProtocol>>::new(
+                let mut dispatcher = PandasDispatcher::<_, MysqlPandasTransport<'py, TextProtocol>>::new(
                     source,
                     destination,
                     queries,
                     origin_query,
                 );
+                dispatcher.set_pre_execution_queries(pre_execution_queries);
                 dispatcher.run(py)?
             }
             _ => unimplemented!("{} protocol not supported", protocol),
diff --git a/connectorx/src/dispatcher.rs b/connectorx/src/dispatcher.rs
index c8cf97cec5..e02dae783a 100644
--- a/connectorx/src/dispatcher.rs
+++ b/connectorx/src/dispatcher.rs
@@ -43,6 +43,10 @@ where
         }
     }

+    pub fn set_pre_execution_queries(&mut self, pre_execution_queries: Option<&[String]>) {
+        self.src.set_pre_execution_queries(pre_execution_queries);
+    }
+
     pub fn prepare(
         mut self,
     ) -> Result<
diff --git a/connectorx/src/fed_dispatcher.rs b/connectorx/src/fed_dispatcher.rs
index 43a184953a..fd7507ea7e 100644
--- a/connectorx/src/fed_dispatcher.rs
+++ b/connectorx/src/fed_dispatcher.rs
@@ -51,7 +51,7 @@ pub fn run(
         .as_ref()
         .unwrap();

-    let destination = get_arrow(source_conn, None, queries.as_slice())?;
+    let destination = get_arrow(source_conn, None, queries.as_slice(), None)?;
     let rbs = destination.arrow()?;

     let provider = MemTable::try_new(rbs[0].schema(), vec![rbs])?;
diff --git a/connectorx/src/get_arrow.rs b/connectorx/src/get_arrow.rs
index a7efa44938..c665c9c7f1 100644
--- a/connectorx/src/get_arrow.rs
+++ b/connectorx/src/get_arrow.rs
@@ -25,6 +25,7 @@ pub fn get_arrow(
     source_conn: &SourceConn,
     origin_query: Option<String>,
     queries: &[CXQuery<String>],
+    pre_execution_queries: Option<&[String]>,
 ) -> ArrowDestination {
     let mut destination = ArrowDestination::new();
     let protocol = source_conn.proto.as_str();
@@ -41,25 +42,27 @@ pub fn get_arrow(
                     tls_conn,
                     queries.len(),
                 )?;
-                let dispatcher = Dispatcher::<
+                let mut dispatcher = Dispatcher::<
                     _,
                     _,
                     PostgresArrowTransport<CSVProtocol, MakeTlsConnector>,
                 >::new(
                     source, &mut destination, queries, origin_query
                 );
+                dispatcher.set_pre_execution_queries(pre_execution_queries);
                 dispatcher.run()?;
             }
             ("csv", None) => {
                 let source = PostgresSource::<CSVProtocol, NoTls>::new(config, NoTls, queries.len())?;
-                let dispatcher =
-                    Dispatcher::<_, _, PostgresArrowTransport<CSVProtocol, NoTls>>::new(
-                        source,
-                        &mut destination,
-                        queries,
-                        origin_query,
-                    );
+                let mut dispatcher = Dispatcher::<
+                    _,
+                    _,
+                    PostgresArrowTransport<CSVProtocol, NoTls>,
+                >::new(
+                    source, &mut destination, queries, origin_query
+                );
+                dispatcher.set_pre_execution_queries(pre_execution_queries);
                 dispatcher.run()?;
             }
             ("binary", Some(tls_conn)) => {
@@ -68,13 +71,14 @@ pub fn get_arrow(
                     tls_conn,
                     queries.len(),
                 )?;
-                let dispatcher = Dispatcher::<
+                let mut dispatcher = Dispatcher::<
                     _,
                     _,
                     PostgresArrowTransport<PgBinaryProtocol, MakeTlsConnector>,
                 >::new(
                     source, &mut destination, queries, origin_query
                 );
+                dispatcher.set_pre_execution_queries(pre_execution_queries);
                 dispatcher.run()?;
             }
             ("binary", None) => {
@@ -83,13 +87,14 @@ pub fn get_arrow(
                     NoTls,
                     queries.len(),
                 )?;
-                let dispatcher = Dispatcher::<
+                let mut dispatcher = Dispatcher::<
                     _,
                     _,
                     PostgresArrowTransport<PgBinaryProtocol, NoTls>,
                 >::new(
                     source, &mut destination, queries, origin_query
                 );
+                dispatcher.set_pre_execution_queries(pre_execution_queries);
                 dispatcher.run()?;
             }
             ("cursor", Some(tls_conn)) => {
@@ -98,25 +103,27 @@ pub fn get_arrow(
                     tls_conn,
                     queries.len(),
                 )?;
-                let dispatcher = Dispatcher::<
+                let mut dispatcher = Dispatcher::<
                     _,
                     _,
                     PostgresArrowTransport<CursorProtocol, MakeTlsConnector>,
                 >::new(
                     source, &mut destination, queries, origin_query
                 );
+                dispatcher.set_pre_execution_queries(pre_execution_queries);
                 dispatcher.run()?;
             }
             ("cursor", None) => {
                 let source =
                     PostgresSource::<CursorProtocol, NoTls>::new(config, NoTls, queries.len())?;
-                let dispatcher = Dispatcher::<
+                let mut dispatcher = Dispatcher::<
                     _,
                     _,
                     PostgresArrowTransport<CursorProtocol, NoTls>,
                 >::new(
                     source, &mut destination, queries, origin_query
                 );
+                dispatcher.set_pre_execution_queries(pre_execution_queries);
                 dispatcher.run()?;
             }
             ("simple", Some(tls_conn)) => {
@@ -125,7 +132,7 @@ pub fn get_arrow(
                     tls_conn,
                     queries.len(),
                 )?;
-                let dispatcher = Dispatcher::<
+                let mut dispatcher = Dispatcher::<
                     _,
                     _,
                     PostgresArrowTransport<SimpleProtocol, MakeTlsConnector>,
@@ -133,12 +140,13 @@ pub fn get_arrow(
                     sb, &mut destination, queries, origin_query
                 );
                 debug!("Running dispatcher");
+                dispatcher.set_pre_execution_queries(pre_execution_queries);
                 dispatcher.run()?;
             }
             ("simple", None) => {
                 let sb =
                     PostgresSource::<SimpleProtocol, NoTls>::new(config, NoTls, queries.len())?;
-                let dispatcher = Dispatcher::<
+                let mut dispatcher = Dispatcher::<
                     _,
                     _,
                     PostgresArrowTransport<SimpleProtocol, NoTls>,
@@ -146,6 +154,7 @@ pub fn get_arrow(
                     sb, &mut destination, queries, origin_query
                 );
                 debug!("Running dispatcher");
+                dispatcher.set_pre_execution_queries(pre_execution_queries);
                 dispatcher.run()?;
             }
             _ => unimplemented!("{} protocol not supported", protocol),
@@ -156,23 +165,26 @@ pub fn get_arrow(
             "binary" => {
                 let source =
                     MySQLSource::<MySQLBinaryProtocol>::new(&source_conn.conn[..], queries.len())?;
-                let dispatcher = Dispatcher::<_, _, MySQLArrowTransport<MySQLBinaryProtocol>>::new(
-                    source,
-                    &mut destination,
-                    queries,
-                    origin_query,
-                );
+                let mut dispatcher =
+                    Dispatcher::<_, _, MySQLArrowTransport<MySQLBinaryProtocol>>::new(
+                        source,
+                        &mut destination,
+                        queries,
+                        origin_query,
+                    );
+                dispatcher.set_pre_execution_queries(pre_execution_queries);
                 dispatcher.run()?;
             }
             "text" => {
                 let source =
                     MySQLSource::<TextProtocol>::new(&source_conn.conn[..], queries.len())?;
-                let dispatcher = Dispatcher::<_, _, MySQLArrowTransport<TextProtocol>>::new(
+                let mut dispatcher = Dispatcher::<_, _, MySQLArrowTransport<TextProtocol>>::new(
                     source,
                     &mut destination,
                     queries,
                     origin_query,
                 );
+                dispatcher.set_pre_execution_queries(pre_execution_queries);
                 dispatcher.run()?;
             }
             _ => unimplemented!("{} protocol not supported", protocol),
@@ -252,6 +264,7 @@ pub fn new_record_batch_iter(
     origin_query: Option<String>,
     queries: &[CXQuery<String>],
     batch_size: usize,
+    pre_execution_queries: Option<&[String]>,
 ) -> Box<dyn RecordBatchIterator> {
     let destination = ArrowStreamDestination::new_with_batch_size(batch_size);
     let protocol = source_conn.proto.as_str();
@@ -263,12 +276,15 @@ pub fn new_record_batch_iter(
         let (config, tls) = rewrite_tls_args(&source_conn.conn).unwrap();
         match (protocol, tls) {
             ("csv", Some(tls_conn)) => {
-                let source = PostgresSource::<CSVProtocol, MakeTlsConnector>::new(
+                let mut source = PostgresSource::<CSVProtocol, MakeTlsConnector>::new(
                     config,
                     tls_conn,
                     queries.len(),
                 )
                 .unwrap();
+
+                source.set_pre_execution_queries(pre_execution_queries);
+
                 let batch_iter = ArrowBatchIter::<
                     _,
@@ -278,9 +294,12 @@ pub fn new_record_batch_iter(
                 return Box::new(batch_iter);
             }
             ("csv", None) => {
-                let source =
+                let mut source =
                     PostgresSource::<CSVProtocol, NoTls>::new(config, NoTls, queries.len())
                         .unwrap();
+
+                source.set_pre_execution_queries(pre_execution_queries);
+
                 let batch_iter = ArrowBatchIter::<
                     _,
                     PostgresArrowStreamTransport<CSVProtocol, NoTls>,
@@ -291,12 +310,15 @@ pub fn new_record_batch_iter(
                 return Box::new(batch_iter);
             }
             ("binary", Some(tls_conn)) => {
-                let source = PostgresSource::<PgBinaryProtocol, MakeTlsConnector>::new(
+                let mut source = PostgresSource::<PgBinaryProtocol, MakeTlsConnector>::new(
                     config,
                     tls_conn,
                     queries.len(),
                 )
                 .unwrap();
+
+                source.set_pre_execution_queries(pre_execution_queries);
+
                 let batch_iter = ArrowBatchIter::<
                     _,
@@ -306,12 +328,15 @@ pub fn new_record_batch_iter(
                 return Box::new(batch_iter);
             }
             ("binary", None) => {
-                let source = PostgresSource::<PgBinaryProtocol, NoTls>::new(
+                let mut source = PostgresSource::<PgBinaryProtocol, NoTls>::new(
                     config,
                     NoTls,
                     queries.len(),
                 )
                 .unwrap();
+
+                source.set_pre_execution_queries(pre_execution_queries);
+
                 let batch_iter = ArrowBatchIter::<
                     _,
                     PostgresArrowStreamTransport<PgBinaryProtocol, NoTls>,
@@ -322,12 +347,15 @@ pub fn new_record_batch_iter(
                 return Box::new(batch_iter);
             }
             ("cursor", Some(tls_conn)) => {
-                let source = PostgresSource::<CursorProtocol, MakeTlsConnector>::new(
+                let mut source = PostgresSource::<CursorProtocol, MakeTlsConnector>::new(
                     config,
                     tls_conn,
                     queries.len(),
                 )
                 .unwrap();
+
+                source.set_pre_execution_queries(pre_execution_queries);
+
                 let batch_iter = ArrowBatchIter::<
                     _,
@@ -337,9 +365,12 @@ pub fn new_record_batch_iter(
                 return Box::new(batch_iter);
             }
             ("cursor", None) => {
-                let source =
+                let mut source =
                     PostgresSource::<CursorProtocol, NoTls>::new(config, NoTls, queries.len())
                         .unwrap();
+
+                source.set_pre_execution_queries(pre_execution_queries);
+
                 let batch_iter = ArrowBatchIter::<
                     _,
                     PostgresArrowStreamTransport<CursorProtocol, NoTls>,
@@ -355,9 +386,12 @@ pub fn new_record_batch_iter(
     #[cfg(feature = "src_mysql")]
     SourceType::MySQL => match protocol {
             "binary" => {
-                let source =
+                let mut source =
                     MySQLSource::<MySQLBinaryProtocol>::new(&source_conn.conn[..], queries.len())
                         .unwrap();
+
+                source.set_pre_execution_queries(pre_execution_queries);
+
                 let batch_iter =
                     ArrowBatchIter::<_, MySQLArrowStreamTransport<MySQLBinaryProtocol>>::new(
                         source,
@@ -369,8 +403,11 @@ pub fn new_record_batch_iter(
                 return Box::new(batch_iter);
             }
             "text" => {
-                let source =
+                let mut source =
                     MySQLSource::<TextProtocol>::new(&source_conn.conn[..], queries.len()).unwrap();
+
+                source.set_pre_execution_queries(pre_execution_queries);
+
                 let batch_iter =
                     ArrowBatchIter::<_, MySQLArrowStreamTransport<TextProtocol>>::new(
                         source,
                         destination,
diff --git a/connectorx/src/sources/mod.rs b/connectorx/src/sources/mod.rs
index dd86b524dc..f1f8a361c1 100644
--- a/connectorx/src/sources/mod.rs
+++ b/connectorx/src/sources/mod.rs
@@ -41,6 +41,10 @@ pub trait Source {

     fn set_origin_query(&mut self, query: Option<String>);

+    fn set_pre_execution_queries(&mut self, _pre_execution_queries: Option<&[String]>) {
+        unimplemented!("pre_execution_queries is not implemented in this source type");
+    }
+
     fn fetch_metadata(&mut self) -> Result<(), Self::Error>;
     /// Get total number of rows if available
     fn result_rows(&mut self) -> Result<Option<usize>, Self::Error>;
diff --git a/connectorx/src/sources/mysql/mod.rs b/connectorx/src/sources/mysql/mod.rs
index 0cf74d858f..bcf155d68d 100644
--- a/connectorx/src/sources/mysql/mod.rs
+++ b/connectorx/src/sources/mysql/mod.rs
@@ -50,6 +50,7 @@ pub struct MySQLSource<P> {
     queries: Vec<CXQuery<String>>,
     names: Vec<String>,
     schema: Vec<MySQLTypeSystem>,
+    pre_execution_queries: Option<Vec<String>>,
     _protocol: PhantomData<P>,
 }

@@ -67,6 +68,7 @@ impl<P> MySQLSource<P> {
             queries: vec![],
             names: vec![],
             schema: vec![],
+            pre_execution_queries: None,
             _protocol: PhantomData,
         }
     }
@@ -98,6 +100,10 @@ where
         self.origin_query = query;
     }

+    fn set_pre_execution_queries(&mut self, pre_execution_queries: Option<&[String]>) {
+        self.pre_execution_queries = pre_execution_queries.map(|s| s.to_vec());
+    }
+
     #[throws(MySQLSourceError)]
     fn fetch_metadata(&mut self) {
         assert!(!self.queries.is_empty());
@@ -219,7 +225,14 @@ where
     fn partition(self) -> Vec<Self::Partition> {
         let mut ret = vec![];
         for query in self.queries {
-            let conn = self.pool.get()?;
+            let mut conn = self.pool.get()?;
+
+            if let Some(pre_queries) = &self.pre_execution_queries {
+                for pre_query in pre_queries {
+                    conn.query_drop(pre_query)?;
+                }
+            }
+
             ret.push(MySQLSourcePartition::new(conn, &query, &self.schema));
         }
         ret
diff --git a/connectorx/src/sources/postgres/mod.rs b/connectorx/src/sources/postgres/mod.rs
index 103e398fdd..df05962b36 100644
--- a/connectorx/src/sources/postgres/mod.rs
+++ b/connectorx/src/sources/postgres/mod.rs
@@ -92,6 +92,7 @@ where
     names: Vec<String>,
     schema: Vec<PostgresTypeSystem>,
     pg_schema: Vec<postgres::types::Type>,
+    pre_execution_queries: Option<Vec<String>>,
     _protocol: PhantomData<P>,
 }

@@ -114,6 +115,7 @@ where
             names: vec![],
             schema: vec![],
             pg_schema: vec![],
+            pre_execution_queries: None,
             _protocol: PhantomData,
         }
     }
@@ -149,6 +151,10 @@ where
         self.origin_query = query;
     }

+    fn set_pre_execution_queries(&mut self, pre_execution_queries: Option<&[String]>) {
+        self.pre_execution_queries = pre_execution_queries.map(|s| s.to_vec());
+    }
+
     #[throws(PostgresSourceError)]
     fn fetch_metadata(&mut self) {
         assert!(!self.queries.is_empty());
@@ -199,7 +205,13 @@ where
     fn partition(self) -> Vec<Self::Partition> {
         let mut ret = vec![];
         for query in self.queries {
-            let conn = self.pool.get()?;
+            let mut conn = self.pool.get()?;
+
+            if let Some(pre_queries) = &self.pre_execution_queries {
+                for pre_query in pre_queries {
+                    conn.query(pre_query, &[])?;
+                }
+            }

             ret.push(PostgresSourcePartition::<P, C>::new(
                 conn,
diff --git a/connectorx/tests/test_mysql.rs b/connectorx/tests/test_mysql.rs
index 4bc21bc284..07abc870e7 100644
--- a/connectorx/tests/test_mysql.rs
+++ b/connectorx/tests/test_mysql.rs
@@ -61,6 +61,106 @@ fn test_mysql_text() {
     verify_arrow_results(result);
 }

+#[test]
+fn test_mysql_pre_execution_queries() {
+    let _ = env_logger::builder().is_test(true).try_init();
+
+    let dburl = env::var("MYSQL_URL").unwrap();
+
+    let queries = [CXQuery::naked(
+        "SELECT @@SESSION.max_execution_time, @@SESSION.wait_timeout",
+    )];
+
+    let pre_execution_queries = [
+        String::from("SET SESSION max_execution_time = 2151"),
+        String::from("SET SESSION wait_timeout = 2252"),
+    ];
+
+    let builder = MySQLSource::<BinaryProtocol>::new(&dburl, 2).unwrap();
+    let mut destination = ArrowDestination::new();
+    let mut dispatcher = Dispatcher::<_, _, MySQLArrowTransport<BinaryProtocol>>::new(
+        builder,
+        &mut destination,
+        &queries,
+        None,
+    );
+    dispatcher.set_pre_execution_queries(Some(&pre_execution_queries));
+    dispatcher.run().unwrap();
+
+    let result = destination.arrow().unwrap();
+
+    assert!(result.len() == 1);
+
+    assert!(result[0]
+        .column(0)
+        .as_any()
+        .downcast_ref::<Float64Array>()
+        .unwrap()
+        .eq(&Float64Array::from(vec![2151.0])));
+
+    assert!(result[0]
+        .column(1)
+        .as_any()
+        .downcast_ref::<Float64Array>()
+        .unwrap()
+        .eq(&Float64Array::from(vec![2252.0])));
+}
+
+#[test]
+fn test_mysql_partitioned_pre_execution_queries() {
+    let _ = env_logger::builder().is_test(true).try_init();
+
+    let dburl = env::var("MYSQL_URL").unwrap();
+
+    let queries = [
+        CXQuery::naked(
+            "SELECT 'max_execution_time' AS name, @@SESSION.max_execution_time AS setting",
+        ),
+        CXQuery::naked("SELECT 'wait_timeout' AS name, @@SESSION.wait_timeout AS setting"),
+    ];
+
+    let pre_execution_queries = [
+        String::from("SET SESSION max_execution_time = 2151"),
+        String::from("SET SESSION wait_timeout = 2252"),
+    ];
+
+    let builder = MySQLSource::<BinaryProtocol>::new(&dburl, 2).unwrap();
+    let mut destination = ArrowDestination::new();
+    let mut dispatcher = Dispatcher::<_, _, MySQLArrowTransport<BinaryProtocol>>::new(
+        builder,
+        &mut destination,
+        &queries,
+        None,
+    );
+    dispatcher.set_pre_execution_queries(Some(&pre_execution_queries));
+    dispatcher.run().unwrap();
+
+    let result = destination.arrow().unwrap();
+
+    assert!(result.len() == 2);
+
+    let mut result_map = std::collections::HashMap::new();
+    for record_batch in result {
+        let name = record_batch
+            .column(0)
+            .as_any()
+            .downcast_ref::<StringArray>()
+            .unwrap()
+            .value(0)
+            .to_string();
+        let setting = record_batch
+            .column(1)
+            .as_any()
+            .downcast_ref::<Float64Array>()
+            .unwrap()
+            .value(0);
+        result_map.insert(name, setting);
+    }
+
+    assert_eq!(result_map.get("max_execution_time"), Some(&2151.0));
+    assert_eq!(result_map.get("wait_timeout"), Some(&2252.0));
+}
+
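
The `partition()` hunks above are the crux of the change: every partition checks out its own pooled connection, so the pre-execution statements are replayed on each connection rather than once globally. The Rust tests pin that behavior down; a rough Python equivalent (hypothetical DSN, Postgres-specific settings) looks like:

```python
import connectorx as cx

conn = "postgresql://user:password@localhost:5432/db"  # placeholder DSN

# Two explicit partition queries -> two pooled connections; both rows should
# reflect the SET statements below, since they run on each connection.
df = cx.read_sql(
    conn,
    [
        "SELECT name, setting FROM pg_settings WHERE name = 'statement_timeout'",
        "SELECT name, setting FROM pg_settings WHERE name = 'idle_in_transaction_session_timeout'",
    ],
    pre_execution_query=[
        "SET SESSION statement_timeout = 2151",
        "SET SESSION idle_in_transaction_session_timeout = 2252",
    ],
)
print(df)
```
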
 pub fn verify_arrow_results(result: Vec<RecordBatch>) {
     assert!(result.len() == 2);
diff --git a/connectorx/tests/test_postgres.rs b/connectorx/tests/test_postgres.rs
index a9db2fa728..dca520a45b 100644
--- a/connectorx/tests/test_postgres.rs
+++ b/connectorx/tests/test_postgres.rs
@@ -1211,3 +1211,105 @@ pub fn verify_arrow_type_results(result: Vec<RecordBatch>, protocol: &str) {
         None,
     ])));
 }
+
+#[test]
+fn test_postgres_pre_execution_queries() {
+    let _ = env_logger::builder().is_test(true).try_init();
+
+    let dburl = env::var("POSTGRES_URL").unwrap();
+
+    let queries = [
+        CXQuery::naked("SELECT CAST(name AS TEXT) AS name, CAST(setting AS INTEGER) AS setting FROM pg_settings WHERE name IN ('statement_timeout', 'idle_in_transaction_session_timeout') ORDER BY name"),
+    ];
+
+    let pre_execution_queries = [
+        String::from("SET SESSION statement_timeout = 2151"),
+        String::from("SET SESSION idle_in_transaction_session_timeout = 2252"),
+    ];
+
+    let url = Url::parse(dburl.as_str()).unwrap();
+    let (config, _tls) = rewrite_tls_args(&url).unwrap();
+    let builder = PostgresSource::<BinaryProtocol, NoTls>::new(config, NoTls, 2).unwrap();
+    let mut destination = ArrowDestination::new();
+    let mut dispatcher = Dispatcher::<_, _, PostgresArrowTransport<BinaryProtocol, NoTls>>::new(
+        builder,
+        &mut destination,
+        &queries,
+        None,
+    );
+
+    dispatcher.set_pre_execution_queries(Some(&pre_execution_queries));
+
+    dispatcher.run().expect("run dispatcher");
+
+    let result = destination.arrow().unwrap();
+
+    assert!(result.len() == 1);
+
+    assert!(result[0]
+        .column(1)
+        .as_any()
+        .downcast_ref::<Int32Array>()
+        .unwrap()
+        .eq(&Int32Array::from(vec![2252, 2151])));
+}
+
+#[test]
+fn test_postgres_partitioned_pre_execution_queries() {
+    let _ = env_logger::builder().is_test(true).try_init();
+
+    let dburl = env::var("POSTGRES_URL").unwrap();
+
+    let queries = [
+        CXQuery::naked("SELECT CAST(name AS TEXT) AS name, CAST(setting AS INTEGER) AS setting FROM pg_settings WHERE name = 'statement_timeout'"),
+        CXQuery::naked("SELECT CAST(name AS TEXT) AS name, CAST(setting AS INTEGER) AS setting FROM pg_settings WHERE name = 'idle_in_transaction_session_timeout'"),
+    ];
+
+    let pre_execution_queries = [
+        String::from("SET SESSION statement_timeout = 2151"),
+        String::from("SET SESSION idle_in_transaction_session_timeout = 2252"),
+    ];
+
+    let url = Url::parse(dburl.as_str()).unwrap();
+    let (config, _tls) = rewrite_tls_args(&url).unwrap();
+    let builder = PostgresSource::<BinaryProtocol, NoTls>::new(config, NoTls, 2).unwrap();
+    let mut destination = ArrowDestination::new();
+    let mut dispatcher = Dispatcher::<_, _, PostgresArrowTransport<BinaryProtocol, NoTls>>::new(
+        builder,
+        &mut destination,
+        &queries,
+        None,
+    );
+
+    dispatcher.set_pre_execution_queries(Some(&pre_execution_queries));
+
+    dispatcher.run().expect("run dispatcher");
+
+    let result = destination.arrow().unwrap();
+
+    assert!(result.len() == 2);
+
+    let mut result_map = std::collections::HashMap::new();
+    for record_batch in result {
+        let name = record_batch
+            .column(0)
+            .as_any()
+            .downcast_ref::<StringArray>()
+            .unwrap()
+            .value(0)
+            .to_string();
+        let setting = record_batch
+            .column(1)
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .unwrap()
+            .value(0);
+        result_map.insert(name, setting);
+    }
+
+    assert_eq!(result_map.get("statement_timeout"), Some(&2151));
+    assert_eq!(
+        result_map.get("idle_in_transaction_session_timeout"),
+        Some(&2252)
+    );
+}
diff --git a/docs/api.md b/docs/api.md
index d93e6f89fe..47c96638f5 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -16,7 +16,7 @@ connectorx.read_sql(conn: Union[str, Dict[str, str]], query: Union[List[str], st
 - `partition_range: Optional[Tuple[int, int]]`: The value range of the partition column.
 - `partition_num: Optional[int]`: The number of partitions to generate.
 - `index_col: Optional[str]`: The index column to set for the result dataframe. Only applicable when `return_type` is `pandas`, `modin` or `dask`.
-
+- `pre_execution_query: Optional[Union[str, List[str]]]`: SQL query or list of SQL queries executed before the main query. Can be used to set runtime configurations with SET statements. Only applicable to the Postgres and MySQL sources.

 ## Examples
 - Read a DataFrame from a SQL using a single thread
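
A matching entry for the docs' example list would round out this hunk; a minimal sketch, following the existing examples' placeholder-DSN style (the connection string and table name below are illustrative):

```python
import connectorx as cx

postgres_url = "postgresql://username:password@server:port/database"
query = "SELECT * FROM lineitem"

# Apply a session setting before the query runs.
cx.read_sql(
    postgres_url,
    query,
    pre_execution_query="SET SESSION statement_timeout = 10000",
)
```
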