Skip to content

Commit 0b46462

Browse files
authored
[python] Fix partition readers' naïve assumption that nnz is fast (#3915)
* add work-around for single-fragment nnz * add raise_if_slow to nnz, and use in partitioning readers * add raise_if_slow to nnz, and use in partitioning readers * revert likely buggy nnz optimization * improve fallback partitioning for ExperimentAxisQuery * pin numcodecs * pr fb * pr fb * PR fb
1 parent 9d2ab41 commit 0b46462

File tree

5 files changed

+36
-17
lines changed

5 files changed

+36
-17
lines changed

apis/python/src/tiledbsoma/_query.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
if TYPE_CHECKING:
5555
from ._experiment import Experiment
5656
from ._constants import SPATIAL_DISCLAIMER
57+
from ._exception import SOMAError
5758
from ._fastercsx import CompressedMatrix
5859
from ._measurement import Measurement
5960
from ._sparse_nd_array import SparseNDArray
@@ -812,7 +813,10 @@ def _read_as_csr(
812813

813814
d0_joinids = d0_joinids_arr.to_numpy()
814815
d1_joinids = d1_joinids_arr.to_numpy()
815-
nnz = matrix.nnz
816+
try:
817+
nnz: int | None = matrix._handle._handle.nnz(raise_if_slow=True)
818+
except SOMAError:
819+
nnz = None
816820

817821
# if able, downcast from int64 - reduces working memory
818822
index_dtype = (
@@ -852,13 +856,16 @@ def _reindex(batch: pa.RecordBatch) -> pa.RecordBatch:
852856

853857
approx_X_shape = tuple(b - a + 1 for a, b in matrix.non_empty_domain())
854858
# heuristically derived number (benchmarking). Thesis is that this is roughly 80% of a 1 GiB io buffer,
855-
# which is the default for SOMA.
859+
# which is the default for SOMA. If we have fast NNZ, use to partition based upon memory size. If we do not
860+
# have fast NNZ, pick a default partition size which is large, but not so large as to rule out reasonable
861+
# parallelism (in this case, calculated based on typical scRNASeq assay density of 3-6K features).
856862
target_point_count = 96 * 1024**2
863+
fallback_row_count = 32768
857864
# compute partition size from array density and target point count, rounding to nearest 1024.
858865
partition_size = (
859866
max(1024 * round(approx_X_shape[0] * target_point_count / nnz / 1024), 1024)
860-
if nnz > 0
861-
else approx_X_shape[0]
867+
if nnz is not None and nnz > 0
868+
else min(fallback_row_count, approx_X_shape[0])
862869
)
863870
splits = list(
864871
range(
@@ -867,7 +874,7 @@ def _reindex(batch: pa.RecordBatch) -> pa.RecordBatch:
867874
partition_size,
868875
)
869876
)
870-
if len(splits) > 0:
877+
if len(splits) > 1:
871878
d0_joinids_splits = np.array_split(np.partition(d0_joinids, splits), splits)
872879
tp = matrix.context.threadpool
873880
tbl = pa.concat_tables(

apis/python/src/tiledbsoma/io/outgest.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,14 @@ def _read_partitioned_sparse(X: SparseNDArray, d0_size: int) -> pa.Table:
106106
# density of matrix. Magic number determined empirically, as a tradeoff
107107
# between concurrency and fixed query overhead.
108108
tgt_point_count = 96 * 1024**2
109-
nnz = X.nnz
109+
try:
110+
nnz: int | None = X._handle._handle.nnz(raise_if_slow=True)
111+
except SOMAError:
112+
nnz = None
110113
partition_sz = (
111114
max(1024 * round(d0_size * tgt_point_count / nnz / 1024), 1024)
112-
if nnz > 0
113-
else d0_size
115+
if nnz is not None and nnz > 0
116+
else d0_size # i.e, no partitioning
114117
)
115118
partitions = [
116119
slice(st, min(st + partition_sz - 1, d0_size - 1))

apis/python/src/tiledbsoma/soma_array.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,11 @@ void load_soma_array(py::module& m) {
153153
&SOMAArray::config_options_from_schema)
154154
.def("context", &SOMAArray::ctx)
155155

156-
.def("nnz", &SOMAArray::nnz, py::call_guard<py::gil_scoped_release>())
156+
.def(
157+
"nnz",
158+
&SOMAArray::nnz,
159+
py::arg("raise_if_slow") = false,
160+
py::call_guard<py::gil_scoped_release>())
157161

158162
.def_property_readonly("uri", &SOMAArray::uri)
159163

libtiledbsoma/src/soma/soma_array.cc

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -543,7 +543,7 @@ ArrowTable SOMAArray::_get_core_domainish(enum Domainish which_kind) {
543543
return ArrowTable(std::move(arrow_array), std::move(arrow_schema));
544544
}
545545

546-
uint64_t SOMAArray::nnz() {
546+
uint64_t SOMAArray::nnz(bool raise_if_slow) {
547547
// Verify array is sparse
548548
if (schema_->array_type() != TILEDB_SPARSE) {
549549
throw TileDBSOMAError(
@@ -578,7 +578,7 @@ uint64_t SOMAArray::nnz() {
578578
frag_ts.second <= timestamp_->second)) {
579579
// fragment overlaps read timestamp range, but isn't fully
580580
// contained within: fall back to count_cells to sort that out.
581-
return _nnz_slow();
581+
return _nnz_slow(raise_if_slow);
582582
}
583583
}
584584
// fall through: fragment is fully contained within the read timestamp
@@ -591,7 +591,7 @@ uint64_t SOMAArray::nnz() {
591591
// application's job to otherwise ensure uniqueness), then
592592
// sum-over-fragments is the right thing to do.
593593
if (!schema_->allows_dups() && frag_ts.first != frag_ts.second) {
594-
return _nnz_slow();
594+
return _nnz_slow(raise_if_slow);
595595
}
596596
}
597597

@@ -627,7 +627,7 @@ uint64_t SOMAArray::nnz() {
627627
"soma_joinid or int64 soma_dim_0: using _nnz_slow",
628628
tiledb::impl::type_to_str(type_code),
629629
dim_name));
630-
return _nnz_slow();
630+
return _nnz_slow(raise_if_slow);
631631
}
632632

633633
for (uint32_t i = 0; i < fragment_count; i++) {
@@ -668,10 +668,15 @@ uint64_t SOMAArray::nnz() {
668668
return total_cell_num;
669669
}
670670
// Found relevant fragments with overlap, count cells
671-
return _nnz_slow();
671+
return _nnz_slow(raise_if_slow);
672672
}
673673

674-
uint64_t SOMAArray::_nnz_slow() {
674+
uint64_t SOMAArray::_nnz_slow(bool raise_if_slow) {
675+
if (raise_if_slow) {
676+
throw TileDBSOMAError(
677+
"NNZ slow path called with 'raise_if_slow==true'");
678+
}
679+
675680
LOG_DEBUG(
676681
"[SOMAArray] nnz() found consolidated or overlapping fragments, "
677682
"counting cells...");

libtiledbsoma/src/soma/soma_array.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -659,7 +659,7 @@ class SOMAArray : public SOMAObject {
659659
*
660660
* @return uint64_t Total number of unique cells
661661
*/
662-
uint64_t nnz();
662+
uint64_t nnz(bool raise_if_slow = false);
663663

664664
/**
665665
* @brief Get the current capacity of each dimension.
@@ -1164,7 +1164,7 @@ class SOMAArray : public SOMAObject {
11641164
std::shared_ptr<Array> meta_cache_arr_;
11651165

11661166
// Unoptimized method for computing nnz() (issue `count_cells` query)
1167-
uint64_t _nnz_slow();
1167+
uint64_t _nnz_slow(bool raise_if_slow);
11681168
};
11691169

11701170
} // namespace tiledbsoma

0 commit comments

Comments
 (0)