43 changes: 43 additions & 0 deletions dlt/common/libs/pyarrow.py
@@ -1399,3 +1399,46 @@ def set_plus0000_timezone_to_utc(tbl: pyarrow.Table) -> pyarrow.Table:

new_schema = pyarrow.schema(fields, metadata=tbl.schema.metadata)
return pyarrow.Table.from_arrays(arrays, schema=new_schema)


def cast_date64_columns_to_timestamp(tbl: pyarrow.Table, tz: Optional[str] = None) -> pyarrow.Table:
"""
Reinterpret any date64 columns as timestamp with microsecond precision, preserving the
underlying 64-bit values. ConnectorX mis-types naive microsecond timestamps as date64,
so the stored integers are viewed as timestamp[us] without rescaling units. Works on
both plain and chunked arrays.

Args:
tbl: Input Arrow table.
tz: Optional timezone to annotate the resulting timestamp with (e.g. "UTC").
If None (default), produces a naive timestamp.

Returns:
A new table with date64 columns cast to timestamp[us] (optionally tz-aware),
or the original table if no date64 columns were found.
"""
arrays, fields = [], []
changed = False

for col, fld in zip(tbl.columns, tbl.schema):
if pyarrow.types.is_date64(fld.type):
changed = True
# the raw 64-bit values already hold microseconds; annotate them as timestamp[us]
unit = "us"
new_type = pyarrow.timestamp(unit, tz)
# reinterpret underlying 64-bit values without rescaling units
if isinstance(col, pyarrow.ChunkedArray):
new_chunks = [c.view(new_type) for c in col.chunks]
new_col = pyarrow.chunked_array(new_chunks)
else:
new_col = col.view(new_type)
arrays.append(new_col)
fields.append(pyarrow.field(fld.name, new_type, fld.nullable, fld.metadata))
else:
arrays.append(col)
fields.append(fld)

if not changed:
return tbl

new_schema = pyarrow.schema(fields, metadata=tbl.schema.metadata)
return pyarrow.Table.from_arrays(arrays, schema=new_schema)
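
For context, a minimal sketch of the reinterpretation this helper performs (values below are illustrative, not taken from the PR): the 64-bit payload of a mis-typed date64 column is viewed as microseconds since epoch, with no rescaling.

```python
import pyarrow as pa

# values are actually microseconds since epoch, but ConnectorX labels the column date64
ts_us = pa.array([1609459200123456, None], type=pa.timestamp("us"))
mistyped = ts_us.view(pa.date64())             # same 64-bit payload, wrong logical type
recovered = mistyped.view(pa.timestamp("us"))  # reinterpret without rescaling units

assert recovered.equals(ts_us)
```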
30 changes: 27 additions & 3 deletions dlt/sources/sql_database/helpers.py
@@ -22,6 +22,7 @@
configspec,
)
from dlt.common.exceptions import DltException, MissingDependencyException
from dlt.common.libs.pyarrow import cast_date64_columns_to_timestamp
from dlt.common.schema import TTableSchemaColumns
from dlt.common.schema.typing import TWriteDispositionDict
from dlt.common.typing import TColumnNames, TDataItem, TSortOrder
@@ -222,17 +223,27 @@ def _load_rows(self, query: SelectClause, backend_kwargs: Dict[str, Any]) -> TDa
def _load_rows_connectorx(
self, query: SelectClause, backend_kwargs: Dict[str, Any]
) -> Iterator[TDataItem]:
import pyarrow as pa

try:
import connectorx as cx
except ImportError:
raise MissingDependencyException("Connector X table backend", ["connectorx"])

# default settings
backend_kwargs = {
"return_type": "arrow",
"protocol": "binary",
**backend_kwargs,
}

is_streaming = False
if "return_type" in backend_kwargs:
if backend_kwargs["return_type"] == "arrow_stream":
is_streaming = True
backend_kwargs["batch_size"] = backend_kwargs.get("batch_size", self.chunk_size)
else:
backend_kwargs["return_type"] = "arrow"

conn = backend_kwargs.pop(
"conn",
self.engine.url._replace(
@@ -248,8 +259,21 @@ def _load_rows_connectorx(
f" literals that cannot be rendered, upgrade to 2.x: `{str(ex)}`"
) from ex
logger.info(f"Executing query on ConnectorX: {query_str}")
df = cx.read_sql(conn, query_str, **backend_kwargs)
yield self._maybe_fix_0000_timezone(df)

if is_streaming:
record_reader = cx.read_sql(conn, query_str, **backend_kwargs)
for record_batch in record_reader:
table = pa.Table.from_batches((record_batch,), schema=record_batch.schema)
yield cast_date64_columns_to_timestamp(self._maybe_fix_0000_timezone(table))
else:
df = cx.read_sql(conn, query_str, **backend_kwargs)
if len(df) > self.chunk_size:
logger.info(
f"The size of the dataset being loaded is more than {self.chunk_size}, consider"
" using streaming mode (see"
" https://dlthub.com/docs/dlt-ecosystem/verified-sources/sql_database/configuration#connectorx)"
)
yield self._maybe_fix_0000_timezone(df)

def _maybe_fix_0000_timezone(self, df: Any) -> Any:
"""Optionally convert +00:00 timezone to UTC"""
1 change: 1 addition & 0 deletions dlt/sources/sql_database/schema_types.py
@@ -243,4 +243,5 @@ def table_to_resource_hints(
}
if resolve_foreign_keys:
result["references"] = get_table_references(table)
# print("RES HINTS", result)
return result
@@ -560,6 +560,7 @@
"from dlt.sources.helpers.rest_client.auth import BearerTokenAuth\n",
"from dlt.common.typing import TDataItems\n",
"\n",
"\n",
"@dlt.source\n",
"def github_source(\n",
" access_token=dlt.secrets.value,\n",
@@ -463,7 +463,7 @@ print(info)
The [`ConnectorX`](https://sfu-db.github.io/connector-x/intro.html) backend completely skips `SQLAlchemy` when reading table rows, in favor of doing that in Rust. This is claimed to be significantly faster than any other method (validated only on PostgreSQL). With the default settings, it will emit `PyArrow` tables, but you can configure this by specifying the `return_type` in `backend_kwargs`. (See the [`ConnectorX` docs](https://sfu-db.github.io/connector-x/api.html) for a full list of configurable parameters.)

There are certain limitations when using this backend:
* It will ignore `chunk_size`. `ConnectorX` cannot yield data in batches.
* Unless `return_type` is set to `arrow_stream` in `backend_kwargs`, it will ignore `chunk_size`. Note that certain data types, such as arrays and high-precision time types, are not supported in streaming mode by `ConnectorX`. We also observed that timestamps are not returned correctly in this mode: tz-aware timestamps are passed without a timezone, and naive timestamps are passed as date64, which dlt internally casts back to naive timestamps (see the example below).
* In many cases, it requires a connection string that differs from the `SQLAlchemy` connection string. Use the `conn` argument in `backend_kwargs` to set this.
* For `connectorx>=0.4.2`, on `reflection_level="minimal"`, `connectorx` can return decimal values. On higher `reflection_level`, dlt will coerce the data type (e.g., modify the decimal `precision` and `scale`, convert to `float`).
* For `connectorx<0.4.2`, dlt will convert decimals to doubles, thus losing numerical precision.
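
For reference, a minimal sketch of enabling streaming mode; the connection string, pipeline name, and `batch_size` below are illustrative placeholders:

```py
import dlt
from dlt.sources.sql_database import sql_database

source = sql_database(
    credentials="postgresql://loader:loader@localhost:5432/dlt_data",  # placeholder
    backend="connectorx",
    # stream Arrow record batches instead of materializing the full result at once;
    # batch_size falls back to the resource chunk_size when not set
    backend_kwargs={"return_type": "arrow_stream", "batch_size": 100_000},
)

pipeline = dlt.pipeline(pipeline_name="sql_to_duckdb", destination="duckdb")
print(pipeline.run(source))
```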
46 changes: 46 additions & 0 deletions tests/libs/pyarrow/test_pyarrow.py
@@ -25,6 +25,7 @@
is_arrow_item,
remove_null_columns_from_schema,
UnsupportedArrowTypeException,
cast_date64_columns_to_timestamp,
)
from dlt.common.destination import DestinationCapabilitiesContext
from tests.cases import table_update_and_row
@@ -493,3 +494,48 @@ def test_fill_empty_source_column_values_with_placeholder() -> None:
]
expected_table = pa.Table.from_arrays(expected_data, names=["A", "B", "C", "D"])
assert new_table.equals(expected_table)


def test_cast_date64_columns_to_timestamp_preserves_ms_bits() -> None:
# Prepare timestamp[us] values whose microseconds are not whole milliseconds, to detect precision loss
us_values = [0, 1001, 1609459200123123, 1609459200456789, None]
ts_us_arr = pa.array(us_values, type=pa.timestamp("us"))
# Mimic connectorx mis-typed output by reinterpreting as date64[ms]
date64_arr = ts_us_arr.view(pa.date64())
tbl = pa.table({"ts_like": date64_arr})

# Reinterpret date64 -> timestamp[us] (naive)
out = cast_date64_columns_to_timestamp(tbl)

# Type is timestamp[us] and values are preserved exactly (no precision loss)
assert pa.types.is_timestamp(out["ts_like"].type)
assert out["ts_like"].type == pa.timestamp("us")
expected_us = ts_us_arr
# Table columns are ChunkedArray; compare against a chunked view of the expected array
assert out["ts_like"].equals(pa.chunked_array([expected_us]))
# Additionally ensure the microsecond value 1609459200456789 survives the round trip to int64
micros = pa.compute.cast(out["ts_like"], pa.int64()).combine_chunks()
assert micros[3].as_py() == 1609459200456789


def test_cast_date64_is_noop_when_absent_and_returns_same_object() -> None:
# Table without date64 columns should be returned unchanged (same object)
tbl = pa.table({"a": pa.array([1, 2, None]), "b": pa.array(["x", "y", "z"])})
out = cast_date64_columns_to_timestamp(tbl)
assert out is tbl


def test_cast_date64_chunked_array_support() -> None:
# Build a chunked date64 column from two timestamp[us] chunks (simulate DB microseconds)
vals1 = pa.array([0, 1001, 2002], type=pa.timestamp("us"))
vals2 = pa.array([1609459200123123, None], type=pa.timestamp("us"))
date64_chunked = pa.chunked_array([vals1.view(pa.date64()), vals2.view(pa.date64())])
tbl = pa.table({"ts_like": date64_chunked})

out = cast_date64_columns_to_timestamp(tbl)

# Should be timestamp[us] chunked array equal to original timestamp chunks (no rescale)
assert pa.types.is_timestamp(out["ts_like"].type)
assert out["ts_like"].type == pa.timestamp("us")
expected = pa.chunked_array([vals1, vals2])
assert out["ts_like"].equals(expected)
19 changes: 15 additions & 4 deletions tests/load/sources/sql_database/test_sql_database_source.py
@@ -1,6 +1,7 @@
import os
from copy import deepcopy
from typing import Any, Callable, cast, List, Optional, Set
from importlib.metadata import version

import pytest
from pytest_mock import MockerFixture
@@ -9,11 +10,11 @@
from dlt.common import logger
from dlt.common import json
from dlt.common.configuration.exceptions import ConfigFieldMissingException
from dlt.common.exceptions import MissingDependencyException
from dlt.common.exceptions import DependencyVersionException, MissingDependencyException

from dlt.common.schema.typing import TColumnSchema, TSortOrder, TTableSchemaColumns
from dlt.common.time import ensure_pendulum_datetime_utc
from dlt.common.utils import uniq_id
from dlt.common.utils import assert_min_pkg_version, uniq_id

from dlt.extract.exceptions import ResourceExtractionError
from dlt.extract.incremental.transform import JsonIncremental, ArrowIncremental
@@ -1443,6 +1444,7 @@ def assert_precision_columns(
elif backend == "connectorx":
# connector x emits 32 precision which gets merged with sql alchemy schema
del actual[0]["precision"]
expected = add_default_decimal_precision(expected, is_connectorx=True)
assert actual == expected


@@ -1525,14 +1527,23 @@ def convert_connectorx_types(columns: List[TColumnSchema]) -> List[TColumnSchema
column["precision"] = 16 # only int and bigint in connectorx
if column["data_type"] == "text" and column.get("precision"):
del column["precision"]
if column["data_type"] == "decimal" and column["name"] == "numeric_default_col":
try:
assert_min_pkg_version(pkg_name="connectorx", version="0.4.4")
add_default_decimal_precision([column], is_connectorx=True)
except DependencyVersionException:
pass
return columns


def add_default_decimal_precision(columns: List[TColumnSchema]) -> List[TColumnSchema]:
def add_default_decimal_precision(
columns: List[TColumnSchema], is_connectorx: bool = False
) -> List[TColumnSchema]:
scale = 9 if not is_connectorx else 10
for column in columns:
if column["data_type"] == "decimal" and not column.get("precision"):
column["precision"] = 38
column["scale"] = 9
column["scale"] = scale
return columns


@@ -124,6 +124,68 @@ def test_load_sql_schema_loads_all_tables_parallel(
assert_row_counts(pipeline, postgres_db)


@pytest.mark.parametrize(
"destination_config",
destinations_configs(default_sql_configs=True),
ids=lambda x: x.name,
)
def test_load_sql_schema_loads_all_tables_parallel_connectorx_arrow_stream(
postgres_db: PostgresSourceDB,
destination_config: DestinationTestConfiguration,
) -> None:
pipeline = destination_config.setup_pipeline(
"test_load_sql_schema_loads_all_tables_parallel_connectorx", dev_mode=True
)
os.environ["SOURCES__SQL_DATABASE__HAS_PRECISION__EXCLUDED_COLUMNS"] = '["array_col"]'
os.environ["SOURCES__SQL_DATABASE__HAS_PRECISION_NULLABLE__EXCLUDED_COLUMNS"] = '["array_col"]'

source = sql_database(
credentials=postgres_db.credentials,
schema=postgres_db.schema,
backend="connectorx",
reflection_level="minimal",
backend_kwargs={"return_type": "arrow_stream"},
type_adapter_callback=default_test_callback(
destination_config.destination_type, "connectorx"
),
).parallelize()

if destination_config.destination_type == "bigquery":
# connectorx generates nanosecond-precision time values which bigquery cannot load
source.has_precision.add_map(convert_time_to_us)
source.has_precision_nullable.add_map(convert_time_to_us)

load_info = pipeline.run(source)
# print(humanize.precisedelta(pipeline.last_trace.finished_at - pipeline.last_trace.started_at))
assert_load_info(load_info)
assert_row_counts(pipeline, postgres_db)

# make sure the naive (ntz) timestamp is correct (streaming returns it as date64; we convert it back)
assert (
pipeline.default_schema.tables["has_precision"]["columns"]["datetime_ntz_col"]["data_type"]
== "timestamp"
)
# fetch more than one row to avoid flakiness when first value lands on exact ms
data_ = pipeline.dataset().table("has_precision").limit(100).arrow()
import pyarrow as pa # local import for assertions

# verify Arrow logical type and timezone
col_field = data_.schema.field("datetime_ntz_col")
assert pa.types.is_timestamp(col_field.type)
assert col_field.type.unit == "us"

# verify at least one value has non-zero microseconds (not only millisecond multiples)
micros = pa.compute.cast(data_["datetime_ntz_col"], pa.int64()).combine_chunks()
has_sub_ms = False
for i in range(len(micros)):
if micros[i].is_valid and (micros[i].as_py() % 1000 != 0):
has_sub_ms = True
break
assert (
has_sub_ms
), "Expected at least one datetime_ntz_col value with non-zero microseconds remainder"


@pytest.mark.parametrize(
"destination_config",
destinations_configs(default_sql_configs=True),