53 changes: 48 additions & 5 deletions connectorx-python/connectorx/__init__.py
@@ -2,7 +2,7 @@

import importlib
import urllib.parse

from collections.abc import Iterator
from importlib.metadata import version
from pathlib import Path
from typing import Literal, TYPE_CHECKING, overload, Generic, TypeVar
@@ -177,6 +177,7 @@ def read_sql(
partition_num: int | None = None,
index_col: str | None = None,
pre_execution_query: list[str] | str | None = None,
**kwargs
) -> pd.DataFrame: ...


@@ -192,6 +193,7 @@ def read_sql(
partition_num: int | None = None,
index_col: str | None = None,
pre_execution_query: list[str] | str | None = None,
**kwargs
) -> pd.DataFrame: ...


@@ -207,6 +209,7 @@ def read_sql(
partition_num: int | None = None,
index_col: str | None = None,
pre_execution_query: list[str] | str | None = None,
**kwargs
) -> pa.Table: ...


@@ -222,6 +225,7 @@ def read_sql(
partition_num: int | None = None,
index_col: str | None = None,
pre_execution_query: list[str] | str | None = None,
**kwargs
) -> mpd.DataFrame: ...


@@ -237,6 +241,7 @@ def read_sql(
partition_num: int | None = None,
index_col: str | None = None,
pre_execution_query: list[str] | str | None = None,
**kwargs
) -> dd.DataFrame: ...


@@ -252,6 +257,7 @@ def read_sql(
partition_num: int | None = None,
index_col: str | None = None,
pre_execution_query: list[str] | str | None = None,
**kwargs
) -> pl.DataFrame: ...


@@ -260,7 +266,7 @@ def read_sql(
query: list[str] | str,
*,
return_type: Literal[
"pandas", "polars", "arrow", "modin", "dask"
"pandas", "polars", "arrow", "modin", "dask", "arrow_record_batches"
] = "pandas",
protocol: Protocol | None = None,
partition_on: str | None = None,
@@ -269,18 +275,20 @@
index_col: str | None = None,
strategy: str | None = None,
pre_execution_query: list[str] | str | None = None,
) -> pd.DataFrame | mpd.DataFrame | dd.DataFrame | pl.DataFrame | pa.Table:
**kwargs

) -> pd.DataFrame | mpd.DataFrame | dd.DataFrame | pl.DataFrame | pa.Table | pa.RecordBatchReader:
"""
Run the SQL query, download the data from database into a dataframe.

Parameters
==========
conn
the connection string, or dict of connection string mapping for federated query.
the connection string, or dict of connection string mapping for a federated query.
query
a SQL query or a list of SQL queries.
return_type
the return type of this function; one of "arrow(2)", "pandas", "modin", "dask" or "polars(2)".
the return type of this function; one of "arrow(2)", "arrow_record_batches", "pandas", "modin", "dask" or "polars(2)".
protocol
backend-specific transfer protocol directive; defaults to 'binary' (except for redshift
connection strings, where 'cursor' will be used instead).
@@ -414,6 +422,7 @@ def read_sql(
partition_query=partition_query,
pre_execution_queries=pre_execution_queries,
)

df = reconstruct_arrow(result)
if return_type in {"polars"}:
pl = try_import_module("polars")
@@ -422,12 +431,46 @@
except AttributeError:
# previous polars api (< 0.8.*) was pl.DataFrame.from_arrow
df = pl.DataFrame.from_arrow(df)
elif return_type in {"arrow_record_batches"}:
record_batch_size = int(kwargs.get("record_batch_size", 10000))
result = _read_sql(
conn,
"arrow_record_batches",
queries=queries,
protocol=protocol,
partition_query=partition_query,
pre_execution_queries=pre_execution_queries,
record_batch_size=record_batch_size
)

df = reconstruct_arrow_rb(result)
else:
raise ValueError(return_type)

return df


def reconstruct_arrow_rb(results) -> pa.RecordBatchReader:
import pyarrow as pa

# Get Schema
names, chunk_ptrs_list = results.schema_ptr()
for chunk_ptrs in chunk_ptrs_list:
arrays = [pa.Array._import_from_c(*col_ptr) for col_ptr in chunk_ptrs]
empty_rb = pa.RecordBatch.from_arrays(arrays, names)

schema = empty_rb.schema

def generate_batches(iterator) -> Iterator[pa.RecordBatch]:
for rb_ptrs in iterator:
chunk_ptrs = rb_ptrs.to_ptrs()
yield pa.RecordBatch.from_arrays(
[pa.Array._import_from_c(*col_ptr) for col_ptr in chunk_ptrs], names
)

return pa.RecordBatchReader.from_batches(schema=schema, batches=generate_batches(results))


def reconstruct_arrow(result: _ArrowInfos) -> pa.Table:
import pyarrow as pa

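For orientation, here is a minimal usage sketch of the streaming path added above. It assumes a reachable database behind a hypothetical connection string; the only new pieces are return_type="arrow_record_batches" and the record_batch_size keyword, which read_sql forwards through **kwargs to the native reader (defaulting to 10000 when omitted).

import connectorx as cx

# Hypothetical connection string, for illustration only; any supported database URL works.
conn = "postgresql://user:password@localhost:5432/mydb"

# With return_type="arrow_record_batches", read_sql returns a pyarrow.RecordBatchReader
# instead of a fully materialized table; record_batch_size caps the rows per batch.
reader = cx.read_sql(
    conn,
    "SELECT * FROM test_table",
    return_type="arrow_record_batches",
    record_batch_size=1024,
)

for batch in reader:  # each item is a pyarrow.RecordBatch
    print(batch.num_rows, batch.schema.names)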
4 changes: 3 additions & 1 deletion connectorx-python/connectorx/connectorx.pyi
@@ -26,15 +26,17 @@ def read_sql(
queries: list[str] | None,
partition_query: dict[str, Any] | None,
pre_execution_queries: list[str] | None,
**kwargs
) -> _DataframeInfos: ...
@overload
def read_sql(
conn: str,
return_type: Literal["arrow"],
return_type: Literal["arrow", "arrow_record_batches"],
protocol: str | None,
queries: list[str] | None,
partition_query: dict[str, Any] | None,
pre_execution_queries: list[str] | None,
**kwargs
) -> _ArrowInfos: ...
def partition_sql(conn: str, partition_query: dict[str, Any]) -> list[str]: ...
def read_sql2(sql: str, db_map: dict[str, str]) -> _ArrowInfos: ...
67 changes: 67 additions & 0 deletions connectorx-python/connectorx/tests/test_arrow.py
@@ -44,6 +44,73 @@ def test_arrow(postgres_url: str) -> None:
df.sort_values(by="test_int", inplace=True, ignore_index=True)
assert_frame_equal(df, expected, check_names=True)

def test_arrow_stream(postgres_url: str) -> None:
import pyarrow as pa
query = "SELECT * FROM test_table"
reader = read_sql(
postgres_url,
query,
return_type="arrow_record_batches",
record_batch_size=2,
)
batches = []
for batch in reader:
batches.append(batch)
table = pa.Table.from_batches(batches)
df = table.to_pandas()
df.sort_values(by="test_int", inplace=True, ignore_index=True)

expected = pd.DataFrame(
index=range(6),
data={
"test_int": pd.Series([0, 1, 2, 3, 4, 1314], dtype="int64"),
"test_nullint": pd.Series([5, 3, None, 7, 9, 2], dtype="float64"),
"test_str": pd.Series(
["a", "str1", "str2", "b", "c", None], dtype="object"
),
"test_float": pd.Series([3.1, None, 2.2, 3, 7.8, -10], dtype="float64"),
"test_bool": pd.Series(
[None, True, False, False, None, True], dtype="object"
),
},
)
assert_frame_equal(df, expected, check_names=True)

def test_arrow_stream_with_partition(postgres_url: str) -> None:
import pyarrow as pa
query = "SELECT * FROM test_table"
reader = read_sql(
postgres_url,
query,
partition_on="test_int",
partition_range=(0, 2000),
partition_num=3,
return_type="arrow_record_batches",
record_batch_size=2,
)
batches = []
for batch in reader:
batches.append(batch)
table = pa.Table.from_batches(batches)
df = table.to_pandas()
df.sort_values(by="test_int", inplace=True, ignore_index=True)

expected = pd.DataFrame(
index=range(6),
data={
"test_int": pd.Series([0, 1, 2, 3, 4, 1314], dtype="int64"),
"test_nullint": pd.Series([5, 3, None, 7, 9, 2], dtype="float64"),
"test_str": pd.Series(
["a", "str1", "str2", "b", "c", None], dtype="object"
),
"test_float": pd.Series([3.1, None, 2.2, 3, 7.8, -10], dtype="float64"),
"test_bool": pd.Series(
[None, True, False, False, None, True], dtype="object"
),
},
)
assert_frame_equal(df, expected, check_names=True)

def decimal_s10(val):
return Decimal(val).quantize(Decimal("0.0000000001"))

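The two tests above collect every batch and rebuild a full pyarrow.Table, which is convenient for asserting equality. A sketch of the incremental style the reader is meant to enable, reusing the same test_table fixture while keeping only one small batch in memory at a time (aggregation column and batch size chosen for illustration):

import connectorx as cx

def sum_test_int(postgres_url: str) -> int:
    # Stream the table in small batches instead of materializing a pa.Table,
    # so at most record_batch_size rows are decoded and held at once.
    reader = cx.read_sql(
        postgres_url,
        "SELECT * FROM test_table",
        return_type="arrow_record_batches",
        record_batch_size=2,
    )
    total = 0
    for batch in reader:
        total += sum(batch.to_pydict()["test_int"])
    return total

# With the fixture data above this yields 0 + 1 + 2 + 3 + 4 + 1314 = 1324.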
127 changes: 112 additions & 15 deletions connectorx-python/src/arrow.rs
@@ -1,14 +1,83 @@
use crate::errors::ConnectorXPythonError;
use anyhow::anyhow;
use arrow::record_batch::RecordBatch;
use connectorx::source_router::SourceConn;
use connectorx::{prelude::*, sql::CXQuery};
use fehler::throws;
use libc::uintptr_t;
use pyo3::prelude::*;
use pyo3::pyclass;
use pyo3::{PyAny, Python};
use std::convert::TryFrom;
use std::sync::Arc;

/// Python-exposed RecordBatch wrapper
#[pyclass]
pub struct PyRecordBatch(Option<RecordBatch>);

/// Python-exposed iterator over RecordBatches
#[pyclass(module = "connectorx")]
pub struct PyRecordBatchIterator(Box<dyn RecordBatchIterator>);

#[pymethods]
impl PyRecordBatch {
pub fn num_rows(&self) -> usize {
self.0.as_ref().map_or(0, |rb| rb.num_rows())
}

pub fn num_columns(&self) -> usize {
self.0.as_ref().map_or(0, |rb| rb.num_columns())
}

#[throws(ConnectorXPythonError)]
pub fn to_ptrs<'py>(&mut self, py: Python<'py>) -> Bound<'py, PyAny> {
// Convert the RecordBatch to a vector of pointers; once the RecordBatch has been taken, it cannot be accessed again.
let rb = self
.0
.take()
.ok_or_else(|| anyhow!("RecordBatch is None, cannot convert to pointers"))?;
let ptrs = py.allow_threads(
|| -> Result<Vec<(uintptr_t, uintptr_t)>, ConnectorXPythonError> { Ok(to_ptrs_rb(rb)) },
)?;
let obj: PyObject = ptrs.into_py(py);
obj.into_bound(py)
}
}

#[pymethods]
impl PyRecordBatchIterator {
#[throws(ConnectorXPythonError)]
fn schema_ptr<'py>(&self, py: Python<'py>) -> Bound<'py, PyAny> {
let (rb, _) = self.0.get_schema();
let ptrs = py.allow_threads(
|| -> Result<(Vec<String>, Vec<Vec<(uintptr_t, uintptr_t)>>), ConnectorXPythonError> {
let rbs = vec![rb];
Ok(to_ptrs(rbs))
},
)?;
let obj: PyObject = ptrs.into_py(py);
obj.into_bound(py)
}
fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
slf
}

fn __next__<'py>(
mut slf: PyRefMut<'py, Self>,
py: Python<'py>,
) -> PyResult<Option<Py<PyRecordBatch>>> {
match slf.0.next_batch() {
Some(rb) => {
let wrapped = PyRecordBatch(Some(rb));
let py_obj = Py::new(py, wrapped)?;
Ok(Some(py_obj))
}

None => Ok(None),
}
}
}

#[throws(ConnectorXPythonError)]
pub fn write_arrow<'py>(
py: Python<'py>,
@@ -28,6 +97,48 @@ pub fn write_arrow<'py>(
obj.into_bound(py)
}

#[throws(ConnectorXPythonError)]
pub fn get_arrow_rb_iter<'py>(
py: Python<'py>,
source_conn: &SourceConn,
origin_query: Option<String>,
queries: &[CXQuery<String>],
pre_execution_queries: Option<&[String]>,
batch_size: usize,
) -> Bound<'py, PyAny> {
let mut arrow_iter: Box<dyn RecordBatchIterator> = new_record_batch_iter(
source_conn,
origin_query,
queries,
batch_size,
pre_execution_queries,
);

arrow_iter.prepare();
let py_rb_iter = PyRecordBatchIterator(arrow_iter);

let obj: PyObject = py_rb_iter.into_py(py);
obj.into_bound(py)
}

pub fn to_ptrs_rb(rb: RecordBatch) -> Vec<(uintptr_t, uintptr_t)> {
let mut cols = vec![];

for array in rb.columns().into_iter() {
let data = array.to_data();
let array_ptr = Arc::new(arrow::ffi::FFI_ArrowArray::new(&data));
let schema_ptr = Arc::new(
arrow::ffi::FFI_ArrowSchema::try_from(data.data_type()).expect("export schema c"),
);
cols.push((
Arc::into_raw(array_ptr) as uintptr_t,
Arc::into_raw(schema_ptr) as uintptr_t,
));
}

cols
}

pub fn to_ptrs(rbs: Vec<RecordBatch>) -> (Vec<String>, Vec<Vec<(uintptr_t, uintptr_t)>>) {
if rbs.is_empty() {
return (vec![], vec![]);
@@ -42,21 +153,7 @@ pub fn to_ptrs(rbs: Vec<RecordBatch>) -> (Vec<String>, Vec<Vec<(uintptr_t, uintp
.collect();

for rb in rbs.into_iter() {
let mut cols = vec![];

for array in rb.columns().into_iter() {
let data = array.to_data();
let array_ptr = Arc::new(arrow::ffi::FFI_ArrowArray::new(&data));
let schema_ptr = Arc::new(
arrow::ffi::FFI_ArrowSchema::try_from(data.data_type()).expect("export schema c"),
);
cols.push((
Arc::into_raw(array_ptr) as uintptr_t,
Arc::into_raw(schema_ptr) as uintptr_t,
));
}

result.push(cols);
result.push(to_ptrs_rb(rb));
}
(names, result)
}
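Tying the two halves together: to_ptrs_rb exports each column as a pair of raw addresses to FFI_ArrowArray and FFI_ArrowSchema structs, and the Python side (reconstruct_arrow_rb in __init__.py above) re-imports them through the Arrow C Data Interface. A minimal sketch of that consuming end, with hypothetical names, is:

import pyarrow as pa

def batch_from_ptrs(col_ptrs, names):
    # col_ptrs is a list of (array_ptr, schema_ptr) address pairs produced by
    # to_ptrs_rb; _import_from_c takes ownership of the exported structs, so each
    # pair can be imported exactly once (mirroring the take() in PyRecordBatch::to_ptrs).
    arrays = [pa.Array._import_from_c(array_ptr, schema_ptr)
              for array_ptr, schema_ptr in col_ptrs]
    return pa.RecordBatch.from_arrays(arrays, names)

Because the Rust side relinquishes ownership with Arc::into_raw, a batch that has been converted to pointers must not be touched again; this is why PyRecordBatch stores an Option<RecordBatch> and takes it out on the first to_ptrs call.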