diff --git a/connectorx-python/connectorx/__init__.py b/connectorx-python/connectorx/__init__.py
index 643a80467..5cbb9d5bc 100644
--- a/connectorx-python/connectorx/__init__.py
+++ b/connectorx-python/connectorx/__init__.py
@@ -2,7 +2,7 @@
 
 import importlib
 import urllib.parse
-
+from collections.abc import Iterator
 from importlib.metadata import version
 from pathlib import Path
 from typing import Literal, TYPE_CHECKING, overload, Generic, TypeVar
@@ -177,6 +177,7 @@ def read_sql(
     partition_num: int | None = None,
     index_col: str | None = None,
     pre_execution_query: list[str] | str | None = None,
+    **kwargs
 ) -> pd.DataFrame: ...
@@ -192,6 +193,7 @@ def read_sql(
     partition_num: int | None = None,
     index_col: str | None = None,
     pre_execution_query: list[str] | str | None = None,
+    **kwargs
 ) -> pd.DataFrame: ...
@@ -207,6 +209,7 @@ def read_sql(
     partition_num: int | None = None,
     index_col: str | None = None,
     pre_execution_query: list[str] | str | None = None,
+    **kwargs
 ) -> pa.Table: ...
@@ -222,6 +225,7 @@ def read_sql(
     partition_num: int | None = None,
     index_col: str | None = None,
     pre_execution_query: list[str] | str | None = None,
+    **kwargs
 ) -> mpd.DataFrame: ...
@@ -237,6 +241,7 @@ def read_sql(
     partition_num: int | None = None,
     index_col: str | None = None,
     pre_execution_query: list[str] | str | None = None,
+    **kwargs
 ) -> dd.DataFrame: ...
@@ -252,6 +257,7 @@ def read_sql(
     partition_num: int | None = None,
     index_col: str | None = None,
     pre_execution_query: list[str] | str | None = None,
+    **kwargs
 ) -> pl.DataFrame: ...
@@ -260,7 +266,7 @@ def read_sql(
     query: list[str] | str,
     *,
     return_type: Literal[
-        "pandas", "polars", "arrow", "modin", "dask"
+        "pandas", "polars", "arrow", "modin", "dask", "arrow_stream"
     ] = "pandas",
     protocol: Protocol | None = None,
     partition_on: str | None = None,
@@ -269,18 +275,20 @@ def read_sql(
     index_col: str | None = None,
     strategy: str | None = None,
     pre_execution_query: list[str] | str | None = None,
-) -> pd.DataFrame | mpd.DataFrame | dd.DataFrame | pl.DataFrame | pa.Table:
+    **kwargs
+
+) -> pd.DataFrame | mpd.DataFrame | dd.DataFrame | pl.DataFrame | pa.Table | pa.RecordBatchReader:
     """
     Run the SQL query, download the data from database into a dataframe.
 
     Parameters
     ==========
     conn
-      the connection string, or dict of connection string mapping for federated query.
+      the connection string, or dict of connection string mapping for a federated query.
     query
       a SQL query or a list of SQL queries.
     return_type
-      the return type of this function; one of "arrow(2)", "pandas", "modin", "dask" or "polars(2)".
+      the return type of this function; one of "arrow", "arrow_stream", "pandas", "modin", "dask" or "polars".
     protocol
       backend-specific transfer protocol directive; defaults to 'binary' (except for redshift
       connection strings, where 'cursor' will be used instead).
@@ -293,10 +301,12 @@ def read_sql(
     index_col
       the index column to set; only applicable for return type "pandas", "modin", "dask".
     strategy
-      strategy of rewriting the federated query for join pushdown
+      strategy of rewriting the federated query for join pushdown.
     pre_execution_query
       SQL query or list of SQL queries executed before main query; can be used to set runtime
       configurations using SET statements; only applicable for source "Postgres" and "MySQL".
+    batch_size
+      the maximum size of each batch when return type is `arrow_stream`.
 
     Examples
     ========
@@ -414,6 +424,7 @@ def read_sql(
             partition_query=partition_query,
             pre_execution_queries=pre_execution_queries,
         )
+
         df = reconstruct_arrow(result)
         if return_type in {"polars"}:
             pl = try_import_module("polars")
@@ -422,12 +433,46 @@ def read_sql(
             except AttributeError:
                 # previous polars api (< 0.8.*) was pl.DataFrame.from_arrow
                 df = pl.DataFrame.from_arrow(df)
+    elif return_type in {"arrow_stream"}:
+        batch_size = int(kwargs.get("batch_size", 10000))
+        result = _read_sql(
+            conn,
+            "arrow_stream",
+            queries=queries,
+            protocol=protocol,
+            partition_query=partition_query,
+            pre_execution_queries=pre_execution_queries,
+            batch_size=batch_size
+        )
+
+        df = reconstruct_arrow_rb(result)
     else:
         raise ValueError(return_type)
 
     return df
 
 
+def reconstruct_arrow_rb(results) -> pa.RecordBatchReader:
+    import pyarrow as pa
+
+    # Get Schema
+    names, chunk_ptrs_list = results.schema_ptr()
+    for chunk_ptrs in chunk_ptrs_list:
+        arrays = [pa.Array._import_from_c(*col_ptr) for col_ptr in chunk_ptrs]
+        empty_rb = pa.RecordBatch.from_arrays(arrays, names)
+
+    schema = empty_rb.schema
+
+    def generate_batches(iterator) -> Iterator[pa.RecordBatch]:
+        for rb_ptrs in iterator:
+            chunk_ptrs = rb_ptrs.to_ptrs()
+            yield pa.RecordBatch.from_arrays(
+                [pa.Array._import_from_c(*col_ptr) for col_ptr in chunk_ptrs], names
+            )
+
+    return pa.RecordBatchReader.from_batches(schema=schema, batches=generate_batches(results))
+
+
 def reconstruct_arrow(result: _ArrowInfos) -> pa.Table:
     import pyarrow as pa
diff --git a/connectorx-python/connectorx/connectorx.pyi b/connectorx-python/connectorx/connectorx.pyi
index f63709c07..21a16f189 100644
--- a/connectorx-python/connectorx/connectorx.pyi
+++ b/connectorx-python/connectorx/connectorx.pyi
@@ -26,15 +26,17 @@ def read_sql(
     queries: list[str] | None,
     partition_query: dict[str, Any] | None,
     pre_execution_queries: list[str] | None,
+    **kwargs
 ) -> _DataframeInfos: ...
 @overload
 def read_sql(
     conn: str,
-    return_type: Literal["arrow"],
+    return_type: Literal["arrow", "arrow_stream"],
     protocol: str | None,
     queries: list[str] | None,
     partition_query: dict[str, Any] | None,
     pre_execution_queries: list[str] | None,
+    **kwargs
 ) -> _ArrowInfos: ...
 def partition_sql(conn: str, partition_query: dict[str, Any]) -> list[str]: ...
 def read_sql2(sql: str, db_map: dict[str, str]) -> _ArrowInfos: ...
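A minimal usage sketch of the new `arrow_stream` path (not part of the patch; the connection string and table name below are placeholders): `read_sql` hands back a `pyarrow.RecordBatchReader` built by `reconstruct_arrow_rb`, and the optional `batch_size` keyword travels through `**kwargs`, falling back to 10000 when omitted.

```python
import connectorx as cx

# Placeholder connection string; any supported source should behave the same way.
conn = "postgresql://user:password@localhost:5432/db"

# With return_type="arrow_stream", read_sql returns a pyarrow.RecordBatchReader;
# batches are produced lazily by the Rust-side iterator instead of being
# materialized into a single Table up front.
reader = cx.read_sql(
    conn,
    "SELECT * FROM test_table",
    return_type="arrow_stream",
    batch_size=2,  # forwarded via **kwargs; defaults to 10000 when omitted
)

for batch in reader:  # each element is a pyarrow.RecordBatch
    print(batch.num_rows, batch.schema.names)

# Alternatively, reader.read_all() collects the remaining batches into a Table.
```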
diff --git a/connectorx-python/connectorx/tests/test_arrow.py b/connectorx-python/connectorx/tests/test_arrow.py
index 735cb85e6..d0ad82e4a 100644
--- a/connectorx-python/connectorx/tests/test_arrow.py
+++ b/connectorx-python/connectorx/tests/test_arrow.py
@@ -44,6 +44,73 @@ def test_arrow(postgres_url: str) -> None:
     df.sort_values(by="test_int", inplace=True, ignore_index=True)
     assert_frame_equal(df, expected, check_names=True)
 
+
+def test_arrow_stream(postgres_url: str) -> None:
+    import pyarrow as pa
+    query = "SELECT * FROM test_table"
+    reader = read_sql(
+        postgres_url,
+        query,
+        return_type="arrow_stream",
+        batch_size=2,
+    )
+    batches = []
+    for batch in reader:
+        batches.append(batch)
+    table = pa.Table.from_batches(batches)
+    df = table.to_pandas()
+    df.sort_values(by="test_int", inplace=True, ignore_index=True)
+
+    expected = pd.DataFrame(
+        index=range(6),
+        data={
+            "test_int": pd.Series([0, 1, 2, 3, 4, 1314], dtype="int64"),
+            "test_nullint": pd.Series([5, 3, None, 7, 9, 2], dtype="float64"),
+            "test_str": pd.Series(
+                ["a", "str1", "str2", "b", "c", None], dtype="object"
+            ),
+            "test_float": pd.Series([3.1, None, 2.2, 3, 7.8, -10], dtype="float64"),
+            "test_bool": pd.Series(
+                [None, True, False, False, None, True], dtype="object"
+            ),
+        },
+    )
+    assert_frame_equal(df, expected, check_names=True)
+
+
+def test_arrow_stream_with_partition(postgres_url: str) -> None:
+    import pyarrow as pa
+    query = "SELECT * FROM test_table"
+    reader = read_sql(
+        postgres_url,
+        query,
+        partition_on="test_int",
+        partition_range=(0, 2000),
+        partition_num=3,
+        return_type="arrow_stream",
+        batch_size=2,
+    )
+    batches = []
+    for batch in reader:
+        batches.append(batch)
+    table = pa.Table.from_batches(batches)
+    df = table.to_pandas()
+    df.sort_values(by="test_int", inplace=True, ignore_index=True)
+
+    expected = pd.DataFrame(
+        index=range(6),
+        data={
+            "test_int": pd.Series([0, 1, 2, 3, 4, 1314], dtype="int64"),
+            "test_nullint": pd.Series([5, 3, None, 7, 9, 2], dtype="float64"),
+            "test_str": pd.Series(
+                ["a", "str1", "str2", "b", "c", None], dtype="object"
+            ),
+            "test_float": pd.Series([3.1, None, 2.2, 3, 7.8, -10], dtype="float64"),
+            "test_bool": pd.Series(
+                [None, True, False, False, None, True], dtype="object"
+            ),
+        },
+    )
+    assert_frame_equal(df, expected, check_names=True)
+
+
 def decimal_s10(val):
     return Decimal(val).quantize(Decimal("0.0000000001"))
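Two properties of the streaming path that the tests above lean on, shown as a hedged sketch with a placeholder connection string (not part of the patch): each emitted batch holds at most `batch_size` rows, and row order across partitions is not deterministic, which is why both tests sort before comparing.

```python
import pyarrow as pa
import connectorx as cx

# Placeholder URL; mirrors the partitioned test above.
reader = cx.read_sql(
    "postgresql://user:password@localhost:5432/db",
    "SELECT * FROM test_table",
    partition_on="test_int",
    partition_range=(0, 2000),
    partition_num=3,
    return_type="arrow_stream",
    batch_size=2,
)

batches = list(reader)
# No batch exceeds the requested batch_size ...
assert all(b.num_rows <= 2 for b in batches)
# ... and the concatenated result still contains every row, just not in a
# deterministic order across partitions, hence the sort_values() in the tests.
table = pa.Table.from_batches(batches)
assert table.num_rows == sum(b.num_rows for b in batches)
```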
diff --git a/connectorx-python/src/arrow.rs b/connectorx-python/src/arrow.rs
index f506f60dd..0c9aca6dd 100644
--- a/connectorx-python/src/arrow.rs
+++ b/connectorx-python/src/arrow.rs
@@ -1,14 +1,83 @@
 use crate::errors::ConnectorXPythonError;
+use anyhow::anyhow;
 use arrow::record_batch::RecordBatch;
 use connectorx::source_router::SourceConn;
 use connectorx::{prelude::*, sql::CXQuery};
 use fehler::throws;
 use libc::uintptr_t;
 use pyo3::prelude::*;
+use pyo3::pyclass;
 use pyo3::{PyAny, Python};
 use std::convert::TryFrom;
 use std::sync::Arc;
 
+/// Python-exposed RecordBatch wrapper
+#[pyclass]
+pub struct PyRecordBatch(Option<RecordBatch>);
+
+/// Python-exposed iterator over RecordBatches
+#[pyclass(module = "connectorx")]
+pub struct PyRecordBatchIterator(Box<dyn RecordBatchIterator>);
+
+#[pymethods]
+impl PyRecordBatch {
+    pub fn num_rows(&self) -> usize {
+        self.0.as_ref().map_or(0, |rb| rb.num_rows())
+    }
+
+    pub fn num_columns(&self) -> usize {
+        self.0.as_ref().map_or(0, |rb| rb.num_columns())
+    }
+
+    #[throws(ConnectorXPythonError)]
+    pub fn to_ptrs<'py>(&mut self, py: Python<'py>) -> Bound<'py, PyAny> {
+        // Convert the RecordBatch to a vector of FFI pointers; once the
+        // RecordBatch is taken here, it cannot be accessed again.
+        let rb = self
+            .0
+            .take()
+            .ok_or_else(|| anyhow!("RecordBatch is None, cannot convert to pointers"))?;
+        let ptrs = py.allow_threads(
+            || -> Result<Vec<(uintptr_t, uintptr_t)>, ConnectorXPythonError> { Ok(to_ptrs_rb(rb)) },
+        )?;
+        let obj: PyObject = ptrs.into_py(py);
+        obj.into_bound(py)
+    }
+}
+
+#[pymethods]
+impl PyRecordBatchIterator {
+    #[throws(ConnectorXPythonError)]
+    fn schema_ptr<'py>(&self, py: Python<'py>) -> Bound<'py, PyAny> {
+        let (rb, _) = self.0.get_schema();
+        let ptrs = py.allow_threads(
+            || -> Result<(Vec<String>, Vec<Vec<(uintptr_t, uintptr_t)>>), ConnectorXPythonError> {
+                let rbs = vec![rb];
+                Ok(to_ptrs(rbs))
+            },
+        )?;
+        let obj: PyObject = ptrs.into_py(py);
+        obj.into_bound(py)
+    }
+
+    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
+        slf
+    }
+
+    fn __next__<'py>(
+        mut slf: PyRefMut<'py, Self>,
+        py: Python<'py>,
+    ) -> PyResult<Option<Py<PyRecordBatch>>> {
+        match slf.0.next_batch() {
+            Some(rb) => {
+                let wrapped = PyRecordBatch(Some(rb));
+                let py_obj = Py::new(py, wrapped)?;
+                Ok(Some(py_obj))
+            }
+
+            None => Ok(None),
+        }
+    }
+}
+
 #[throws(ConnectorXPythonError)]
 pub fn write_arrow<'py>(
     py: Python<'py>,
@@ -28,6 +97,48 @@ pub fn write_arrow<'py>(
     obj.into_bound(py)
 }
 
+#[throws(ConnectorXPythonError)]
+pub fn get_arrow_rb_iter<'py>(
+    py: Python<'py>,
+    source_conn: &SourceConn,
+    origin_query: Option<String>,
+    queries: &[CXQuery<String>],
+    pre_execution_queries: Option<&[String]>,
+    batch_size: usize,
+) -> Bound<'py, PyAny> {
+    let mut arrow_iter: Box<dyn RecordBatchIterator> = new_record_batch_iter(
+        source_conn,
+        origin_query,
+        queries,
+        batch_size,
+        pre_execution_queries,
+    );
+
+    arrow_iter.prepare();
+    let py_rb_iter = PyRecordBatchIterator(arrow_iter);
+
+    let obj: PyObject = py_rb_iter.into_py(py);
+    obj.into_bound(py)
+}
+
+pub fn to_ptrs_rb(rb: RecordBatch) -> Vec<(uintptr_t, uintptr_t)> {
+    let mut cols = vec![];
+
+    for array in rb.columns().into_iter() {
+        let data = array.to_data();
+        let array_ptr = Arc::new(arrow::ffi::FFI_ArrowArray::new(&data));
+        let schema_ptr = Arc::new(
+            arrow::ffi::FFI_ArrowSchema::try_from(data.data_type()).expect("export schema c"),
+        );
+        cols.push((
+            Arc::into_raw(array_ptr) as uintptr_t,
+            Arc::into_raw(schema_ptr) as uintptr_t,
+        ));
+    }
+
+    cols
+}
+
 pub fn to_ptrs(rbs: Vec<RecordBatch>) -> (Vec<String>, Vec<Vec<(uintptr_t, uintptr_t)>>) {
     if rbs.is_empty() {
         return (vec![], vec![]);
     }
@@ -42,21 +153,7 @@
diff --git a/connectorx-python/src/cx_read_sql.rs b/connectorx-python/src/cx_read_sql.rs
     queries: Option<Vec<String>>,
     partition_query: Option<PyPartitionQuery>,
     pre_execution_queries: Option<Vec<String>>,
+    kwargs: Option<&Bound<'py, PyDict>>,
 ) -> PyResult<Bound<'py, PyAny>> {
     let source_conn = parse_source(conn, protocol).map_err(|e| ConnectorXPythonError::from(e))?;
     let (queries, origin_query) = match (queries, partition_query) {
@@ -72,6 +74,22 @@ pub fn read_sql<'py>(
             &queries,
             pre_execution_queries.as_deref(),
         )?),
+        "arrow_stream" => {
+            let batch_size = kwargs
+                .and_then(|dict| dict.get_item("batch_size").ok().flatten())
+                .and_then(|obj| obj.extract::<usize>().ok())
+                .unwrap_or(10000);
+
+            Ok(crate::arrow::get_arrow_rb_iter(
+                py,
+                &source_conn,
+                origin_query,
+                &queries,
+                pre_execution_queries.as_deref(),
+                batch_size,
+            )?)
+        }
+
         _ => Err(PyValueError::new_err(format!(
             "return type should be 'pandas' or 'arrow', got '{}'",
             return_type
diff --git a/connectorx-python/src/lib.rs b/connectorx-python/src/lib.rs
index b4f573b35..d6286dddf 100644
--- a/connectorx-python/src/lib.rs
+++ b/connectorx-python/src/lib.rs
@@ -8,6 +8,7 @@ use crate::constants::J4RS_BASE_PATH;
 use ::connectorx::{fed_dispatcher::run, partition::partition, source_router::parse_source};
 use pyo3::exceptions::PyRuntimeError;
 use pyo3::prelude::*;
+use pyo3::types::PyDict;
 use pyo3::{wrap_pyfunction, PyResult};
 use std::collections::HashMap;
 use std::env;
@@ -35,11 +36,13 @@ fn connectorx(_: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_wrapped(wrap_pyfunction!(partition_sql))?;
     m.add_wrapped(wrap_pyfunction!(get_meta))?;
     m.add_class::<pandas::PandasBlockInfo>()?;
+    m.add_class::<crate::arrow::PyRecordBatch>()?;
+    m.add_class::<crate::arrow::PyRecordBatchIterator>()?;
     Ok(())
 }
 
 #[pyfunction]
-#[pyo3(signature = (conn, return_type, protocol=None, queries=None, partition_query=None, pre_execution_queries=None))]
+#[pyo3(signature = (conn, return_type, protocol=None, queries=None, partition_query=None, pre_execution_queries=None, *, **kwargs))]
 pub fn read_sql<'py>(
     py: Python<'py>,
     conn: &str,
@@ -48,6 +51,7 @@ pub fn read_sql<'py>(
     queries: Option<Vec<String>>,
     partition_query: Option<cx_read_sql::PyPartitionQuery>,
     pre_execution_queries: Option<Vec<String>>,
+    kwargs: Option<&Bound<'py, PyDict>>,
 ) -> PyResult<Bound<'py, PyAny>> {
     cx_read_sql::read_sql(
         py,
@@ -57,6 +61,7 @@ pub fn read_sql<'py>(
         queries,
         partition_query,
         pre_execution_queries,
+        kwargs,
     )
 }
diff --git a/connectorx/src/arrow_batch_iter.rs b/connectorx/src/arrow_batch_iter.rs
index 1794a9616..ed4f3eef2 100644
--- a/connectorx/src/arrow_batch_iter.rs
+++ b/connectorx/src/arrow_batch_iter.rs
@@ -149,11 +149,11 @@ where
     type Item = RecordBatch;
     /// NOTE: not thread safe
     fn next(&mut self) -> Option<Self::Item> {
-        self.dst.record_batch().unwrap()
+        self.dst.record_batch().ok().flatten()
     }
 }
 
-pub trait RecordBatchIterator {
+pub trait RecordBatchIterator: Send {
     fn get_schema(&self) -> (RecordBatch, &[String]);
     fn prepare(&mut self);
     fn next_batch(&mut self) -> Option<RecordBatch>;
@@ -163,11 +163,11 @@ impl<'a, S, TP> RecordBatchIterator for ArrowBatchIter<S, TP>
 where
     S: Source + 'a,
     TP: Transport<
-        TSS = S::TypeSystem,
-        TSD = ArrowStreamTypeSystem,
-        S = S,
-        D = ArrowStreamDestination,
-    >,
+            TSS = S::TypeSystem,
+            TSD = ArrowStreamTypeSystem,
+            S = S,
+            D = ArrowStreamDestination,
+        > + std::marker::Send,
 {
     fn get_schema(&self) -> (RecordBatch, &[String]) {
         (self.dst.empty_batch(), self.dst.names())
diff --git a/connectorx/src/destinations/arrowstream/mod.rs b/connectorx/src/destinations/arrowstream/mod.rs
index d8487a268..089b927bb 100644
--- a/connectorx/src/destinations/arrowstream/mod.rs
+++ b/connectorx/src/destinations/arrowstream/mod.rs
@@ -221,7 +221,7 @@ impl ArrowPartitionWriter {
             .map(|(builder, &dt)| Realize::<FFinishBuilder>::realize(dt)?(builder))
             .collect::<Result<Vec<ArrayRef>, crate::errors::ConnectorXError>>()?;
         let rb = RecordBatch::try_new(Arc::clone(&self.arrow_schema), columns)?;
-        self.sender.as_ref().unwrap().send(rb).unwrap();
+        self.sender.as_ref().and_then(|s| s.send(rb).ok());
 
         self.current_row = 0;
         self.current_col = 0;
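To illustrate why the iterator-based destination matters downstream, here is a sketch (not part of the patch; the connection string and output path are placeholders) that streams batches straight into a Parquet file. The schema is available from the reader before any batch is pulled, because `schema_ptr`/`get_schema` exports an empty batch up front, so only one batch needs to be resident in memory at a time.

```python
import connectorx as cx
import pyarrow as pa
import pyarrow.parquet as pq

# Placeholder connection string and output path.
conn = "postgresql://user:password@localhost:5432/db"
out_path = "result.parquet"

reader = cx.read_sql(
    conn,
    "SELECT * FROM test_table",
    return_type="arrow_stream",
    batch_size=10_000,
)

# The schema is known before iteration starts, so the writer can be opened
# up front and batches appended one at a time.
with pq.ParquetWriter(out_path, reader.schema) as writer:
    for batch in reader:
        writer.write_table(pa.Table.from_batches([batch]))
```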