4 changes: 2 additions & 2 deletions connectorx-cpp/src/lib.rs
@@ -176,7 +176,7 @@ pub unsafe extern "C" fn connectorx_scan(conn: *const c_char, query: *const c_ch
    let conn_str = unsafe { CStr::from_ptr(conn) }.to_str().unwrap();
    let query_str = unsafe { CStr::from_ptr(query) }.to_str().unwrap();
    let source_conn = SourceConn::try_from(conn_str).unwrap();
-    let record_batches = get_arrow(&source_conn, None, &[CXQuery::from(query_str)])
+    let record_batches = get_arrow(&source_conn, None, &[CXQuery::from(query_str)], None)
        .unwrap()
        .arrow()
        .unwrap();
@@ -281,7 +281,7 @@ pub unsafe extern "C" fn connectorx_scan_iter(
    }

    let arrow_iter: Box<dyn RecordBatchIterator> =
-        new_record_batch_iter(&source_conn, None, query_vec.as_slice(), batch_size);
+        new_record_batch_iter(&source_conn, None, query_vec.as_slice(), batch_size, None);

    Box::into_raw(Box::new(arrow_iter))
}
21 changes: 21 additions & 0 deletions connectorx-python/connectorx/__init__.py
@@ -131,6 +131,7 @@ def read_sql_pandas(
    partition_on: str | None = None,
    partition_range: tuple[int, int] | None = None,
    partition_num: int | None = None,
+    pre_execution_queries: list[str] | str | None = None,
) -> pd.DataFrame:
    """
    Run the SQL query, download the data from database into a dataframe.
@@ -160,6 +161,7 @@ def read_sql_pandas(
        partition_range=partition_range,
        partition_num=partition_num,
        index_col=index_col,
+        pre_execution_query=pre_execution_queries,
    )


@@ -174,6 +176,7 @@ def read_sql(
    partition_range: tuple[int, int] | None = None,
    partition_num: int | None = None,
    index_col: str | None = None,
+    pre_execution_query: list[str] | str | None = None,
) -> pd.DataFrame: ...


@@ -188,6 +191,7 @@ def read_sql(
    partition_range: tuple[int, int] | None = None,
    partition_num: int | None = None,
    index_col: str | None = None,
+    pre_execution_query: list[str] | str | None = None,
) -> pd.DataFrame: ...


@@ -202,6 +206,7 @@ def read_sql(
    partition_range: tuple[int, int] | None = None,
    partition_num: int | None = None,
    index_col: str | None = None,
+    pre_execution_query: list[str] | str | None = None,
) -> pa.Table: ...


@@ -216,6 +221,7 @@ def read_sql(
    partition_range: tuple[int, int] | None = None,
    partition_num: int | None = None,
    index_col: str | None = None,
+    pre_execution_query: list[str] | str | None = None,
) -> mpd.DataFrame: ...


@@ -230,6 +236,7 @@ def read_sql(
    partition_range: tuple[int, int] | None = None,
    partition_num: int | None = None,
    index_col: str | None = None,
+    pre_execution_query: list[str] | str | None = None,
) -> dd.DataFrame: ...


@@ -244,6 +251,7 @@ def read_sql(
    partition_range: tuple[int, int] | None = None,
    partition_num: int | None = None,
    index_col: str | None = None,
+    pre_execution_query: list[str] | str | None = None,
) -> pl.DataFrame: ...


@@ -260,6 +268,7 @@ def read_sql(
    partition_num: int | None = None,
    index_col: str | None = None,
    strategy: str | None = None,
+    pre_execution_query: list[str] | str | None = None,
) -> pd.DataFrame | mpd.DataFrame | dd.DataFrame | pl.DataFrame | pa.Table:
    """
    Run the SQL query, download the data from database into a dataframe.
@@ -285,6 +294,9 @@ def read_sql(
        the index column to set; only applicable for return type "pandas", "modin", "dask".
    strategy
        strategy of rewriting the federated query for join pushdown
+    pre_execution_query
+        SQL query or list of SQL queries executed before the main query; can be used to set
+        runtime configuration using SET statements; only applicable for the "Postgres" and "MySQL" sources.

    Examples
    ========
@@ -358,6 +370,13 @@ def read_sql(
            raise ValueError("Partition on multiple queries is not supported.")
    else:
        raise ValueError("query must be either str or a list of str")

+    if isinstance(pre_execution_query, list):
+        pre_execution_queries = [remove_ending_semicolon(subquery) for subquery in pre_execution_query]
+    elif isinstance(pre_execution_query, str):
+        pre_execution_queries = [remove_ending_semicolon(pre_execution_query)]
+    else:
+        pre_execution_queries = None

    conn, protocol = rewrite_conn(conn, protocol)

@@ -370,6 +389,7 @@
            queries=queries,
            protocol=protocol,
            partition_query=partition_query,
+            pre_execution_queries=pre_execution_queries,
        )
        df = reconstruct_pandas(result)

@@ -392,6 +412,7 @@
            queries=queries,
            protocol=protocol,
            partition_query=partition_query,
+            pre_execution_queries=pre_execution_queries,
        )
        df = reconstruct_arrow(result)
        if return_type in {"polars"}:
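A minimal usage sketch of the new parameter from the Python side (not part of the diff; the connection string and table name are placeholders, and the SET statement mirrors the Postgres tests below):

import connectorx as cx

# Placeholder connection string and table; substitute your own.
postgres_url = "postgresql://user:password@localhost:5432/db"

# pre_execution_query takes one statement (str) or several (list[str]);
# each runs on the connection before the main query, e.g. to adjust
# session settings such as timeouts.
df = cx.read_sql(
    postgres_url,
    "SELECT * FROM lineitem",
    pre_execution_query="SET SESSION statement_timeout = 10000",
)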
2 changes: 2 additions & 0 deletions connectorx-python/connectorx/connectorx.pyi
@@ -25,6 +25,7 @@ def read_sql(
    protocol: str | None,
    queries: list[str] | None,
    partition_query: dict[str, Any] | None,
+    pre_execution_queries: list[str] | None,
) -> _DataframeInfos: ...
@overload
def read_sql(
@@ -33,6 +34,7 @@
    protocol: str | None,
    queries: list[str] | None,
    partition_query: dict[str, Any] | None,
+    pre_execution_queries: list[str] | None,
) -> _ArrowInfos: ...
def partition_sql(conn: str, partition_query: dict[str, Any]) -> list[str]: ...
def read_sql2(sql: str, db_map: dict[str, str]) -> _ArrowInfos: ...
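The stubs make the public/native split visible: the wrapper's pre_execution_query accepts str | list[str] | None, while the compiled module's pre_execution_queries only ever receives an already-normalized list[str] | None. A standalone sketch of that normalization, mirroring the branch added to read_sql above (the inline semicolon stripping is a stand-in for the module's existing remove_ending_semicolon helper):

from __future__ import annotations


def normalize_pre_execution_query(q: list[str] | str | None) -> list[str] | None:
    # Stand-in for connectorx's remove_ending_semicolon helper.
    def strip_semicolon(s: str) -> str:
        return s.strip().rstrip(";")

    if isinstance(q, list):
        return [strip_semicolon(s) for s in q]
    if isinstance(q, str):
        return [strip_semicolon(q)]
    return None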
49 changes: 49 additions & 0 deletions connectorx-python/connectorx/tests/test_mysql.py
@@ -478,3 +478,52 @@ def test_mysql_cte(mysql_url: str) -> None:

def test_connection_url(mysql_url: str) -> None:
    test_mysql_cte(ConnectionUrl(mysql_url))
+
+def test_mysql_single_pre_execution_queries(mysql_url: str) -> None:
+    pre_execution_query = "SET SESSION max_execution_time = 2151"
+    query = "SELECT @@SESSION.max_execution_time AS max_execution_time"
+    df = read_sql(mysql_url, query, pre_execution_query=pre_execution_query)
+    expected = pd.DataFrame(
+        index=range(1),
+        data={
+            "max_execution_time": pd.Series([2151], dtype="float64")
+        },
+    )
+    assert_frame_equal(df, expected, check_names=True)
+
+
+def test_mysql_multiple_pre_execution_queries(mysql_url: str) -> None:
+    pre_execution_query = [
+        "SET SESSION max_execution_time = 2151",
+        "SET SESSION wait_timeout = 2252",
+    ]
+    query = "SELECT @@SESSION.max_execution_time AS max_execution_time, @@SESSION.wait_timeout AS wait_timeout"
+    df = read_sql(mysql_url, query, pre_execution_query=pre_execution_query)
+    expected = pd.DataFrame(
+        index=range(1),
+        data={
+            "max_execution_time": pd.Series([2151], dtype="float64"),
+            "wait_timeout": pd.Series([2252], dtype="float64")
+        },
+    )
+    assert_frame_equal(df, expected, check_names=True)
+
+def test_mysql_partitioned_pre_execution_queries(mysql_url: str) -> None:
+    pre_execution_query = [
+        "SET SESSION max_execution_time = 2151",
+        "SET SESSION wait_timeout = 2252",
+    ]
+    query = [
+        'SELECT "max_execution_time" AS name, @@SESSION.max_execution_time AS setting',
+        'SELECT "wait_timeout" AS name, @@SESSION.wait_timeout AS setting'
+    ]
+    df = read_sql(mysql_url, query, pre_execution_query=pre_execution_query).sort_values(by=['name'])
+    expected = pd.DataFrame(
+        index=range(2),
+        data={
+            "name": pd.Series(["max_execution_time", "wait_timeout"], dtype="str"),
+            "setting": pd.Series([2151, 2252], dtype="float64"),
+        },
+    ).sort_values(by=['name'])
+
+    assert_frame_equal(df, expected, check_like=False)
51 changes: 50 additions & 1 deletion connectorx-python/connectorx/tests/test_postgres.py
@@ -1008,4 +1008,53 @@ def test_postgres_partition_with_orderby_limit_desc(postgres_url: str) -> None:
        },
    )
    df.sort_values(by="test_int", inplace=True, ignore_index=True)
-    assert_frame_equal(df, expected, check_names=True)
+    assert_frame_equal(df, expected, check_names=True)
+
+def test_postgres_single_pre_execution_queries(postgres_url: str) -> None:
+    pre_execution_query = "SET SESSION statement_timeout = 2151"
+    query = "SELECT CAST(name AS TEXT) AS name, CAST(setting AS INTEGER) FROM pg_settings WHERE name = 'statement_timeout'"
+    df = read_sql(postgres_url, query, pre_execution_query=pre_execution_query)
+    expected = pd.DataFrame(
+        index=range(1),
+        data={
+            "name": pd.Series(["statement_timeout"], dtype="str"),
+            "setting": pd.Series([2151], dtype="Int64"),
+        },
+    )
+    assert_frame_equal(df, expected, check_names=True)
+
+def test_postgres_multiple_pre_execution_queries(postgres_url: str) -> None:
+    pre_execution_query = [
+        "SET SESSION statement_timeout = 2151",
+        "SET SESSION idle_in_transaction_session_timeout = 2252",
+    ]
+    query = "SELECT CAST(name AS TEXT) AS name, CAST(setting AS INTEGER) FROM pg_settings WHERE name IN ('statement_timeout', 'idle_in_transaction_session_timeout') ORDER BY name"
+    df = read_sql(postgres_url, query, pre_execution_query=pre_execution_query)
+    expected = pd.DataFrame(
+        index=range(2),
+        data={
+            "name": pd.Series(["idle_in_transaction_session_timeout", "statement_timeout"], dtype="str"),
+            "setting": pd.Series([2252, 2151], dtype="Int64"),
+        },
+    )
+    assert_frame_equal(df, expected, check_names=True)
+
+def test_postgres_partitioned_pre_execution_queries(postgres_url: str) -> None:
+    pre_execution_query = [
+        "SET SESSION statement_timeout = 2151",
+        "SET SESSION idle_in_transaction_session_timeout = 2252",
+    ]
+    query = [
+        "SELECT CAST(name AS TEXT) AS name, CAST(setting AS INTEGER) AS setting FROM pg_settings WHERE name = 'statement_timeout'",
+        "SELECT CAST(name AS TEXT) AS name, CAST(setting AS INTEGER) AS setting FROM pg_settings WHERE name = 'idle_in_transaction_session_timeout'"
+    ]
+
+    df = read_sql(postgres_url, query, pre_execution_query=pre_execution_query).sort_values(by=['name'])
+    expected = pd.DataFrame(
+        index=range(2),
+        data={
+            "name": pd.Series(["statement_timeout", "idle_in_transaction_session_timeout"], dtype="str"),
+            "setting": pd.Series([2151, 2252], dtype="Int64"),
+        },
+    ).sort_values(by=['name'])
+    assert_frame_equal(df, expected, check_names=True)
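A usage sketch of the partitioned case these tests exercise (connection string, table, and partition column are placeholders): read_sql opens one connection per partition, and the partitioned tests above only pass because the pre-execution statements run on each of those connections before the partition queries do.

import connectorx as cx

postgres_url = "postgresql://user:password@localhost:5432/db"  # placeholder

df = cx.read_sql(
    postgres_url,
    "SELECT * FROM lineitem",   # placeholder table
    partition_on="l_orderkey",  # placeholder numeric column
    partition_num=4,
    pre_execution_query="SET SESSION statement_timeout = 10000",
)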
3 changes: 2 additions & 1 deletion connectorx-python/src/arrow.rs
@@ -15,10 +15,11 @@ pub fn write_arrow<'py>(
    source_conn: &SourceConn,
    origin_query: Option<String>,
    queries: &[CXQuery<String>],
+    pre_execution_queries: Option<&[String]>,
) -> Bound<'py, PyAny> {
    let ptrs = py.allow_threads(
        || -> Result<(Vec<String>, Vec<Vec<(uintptr_t, uintptr_t)>>), ConnectorXPythonError> {
-            let destination = get_arrow(source_conn, origin_query, queries)?;
+            let destination = get_arrow(source_conn, origin_query, queries, pre_execution_queries)?;
            let rbs = destination.arrow()?;
            Ok(to_ptrs(rbs))
        },
3 changes: 3 additions & 0 deletions connectorx-python/src/cx_read_sql.rs
@@ -38,6 +38,7 @@ pub fn read_sql<'py>(
    protocol: Option<&str>,
    queries: Option<Vec<String>>,
    partition_query: Option<PyPartitionQuery>,
+    pre_execution_queries: Option<Vec<String>>,
) -> PyResult<Bound<'py, PyAny>> {
    let source_conn = parse_source(conn, protocol).map_err(|e| ConnectorXPythonError::from(e))?;
    let (queries, origin_query) = match (queries, partition_query) {
@@ -62,12 +63,14 @@ pub fn read_sql<'py>(
            &source_conn,
            origin_query,
            &queries,
+            pre_execution_queries.as_deref(),
        )?),
        "arrow" => Ok(crate::arrow::write_arrow(
            py,
            &source_conn,
            origin_query,
            &queries,
+            pre_execution_queries.as_deref(),
        )?),
        _ => Err(PyValueError::new_err(format!(
            "return type should be 'pandas' or 'arrow', got '{}'",
13 changes: 11 additions & 2 deletions connectorx-python/src/lib.rs
@@ -39,16 +39,25 @@ fn connectorx(_: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
}

#[pyfunction]
-#[pyo3(signature = (conn, return_type, protocol=None, queries=None, partition_query=None))]
+#[pyo3(signature = (conn, return_type, protocol=None, queries=None, partition_query=None, pre_execution_queries=None))]
pub fn read_sql<'py>(
    py: Python<'py>,
    conn: &str,
    return_type: &str,
    protocol: Option<&str>,
    queries: Option<Vec<String>>,
    partition_query: Option<cx_read_sql::PyPartitionQuery>,
+    pre_execution_queries: Option<Vec<String>>,
) -> PyResult<Bound<'py, PyAny>> {
-    cx_read_sql::read_sql(py, conn, return_type, protocol, queries, partition_query)
+    cx_read_sql::read_sql(
+        py,
+        conn,
+        return_type,
+        protocol,
+        queries,
+        partition_query,
+        pre_execution_queries,
+    )
}

#[pyfunction]
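Because the new argument is declared with a default in the #[pyo3(signature = ...)] attribute, existing callers of the compiled module stay valid. A sketch of both call shapes against the raw extension module (hypothetical direct calls; conn and query are placeholders, and normal code should go through connectorx.read_sql instead):

from connectorx.connectorx import read_sql as native_read_sql

conn = "postgresql://user:password@localhost:5432/db"  # placeholder
query = "SELECT 1"                                     # placeholder

# Pre-existing call shape: still accepted after this change.
old_style = native_read_sql(conn, "pandas", protocol=None, queries=[query], partition_query=None)

# New call shape: a normalized list of pre-execution statements, or None.
new_style = native_read_sql(
    conn,
    "pandas",
    protocol=None,
    queries=[query],
    partition_query=None,
    pre_execution_queries=["SET SESSION statement_timeout = 2151"],
)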
4 changes: 4 additions & 0 deletions connectorx-python/src/pandas/dispatcher.rs
@@ -41,6 +41,10 @@ where
        }
    }

+    pub fn set_pre_execution_queries(&mut self, pre_execution_queries: Option<&[String]>) {
+        self.src.set_pre_execution_queries(pre_execution_queries);
+    }
+
    /// Start the data loading process.
    pub fn run(mut self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TP::Error> {
        debug!("Run dispatcher");