Skip to content

Commit 18e3a52

Browse files
committed
Replace b_complement stack arrays with std::vector for large-tree support
The two `splitbit b_complement[SL_MAX_SPLITS][SL_MAX_BINS]` stack arrays, one in robinson_foulds_distance() and one in robinson_foulds_info(), would overflow the stack when compiled against TreeTools with SL_MAX_TIPS = 32768 (128 GB). Replace each with a flat std::vector<splitbit> sized to the actual dimensions (b.n_splits * n_bins) and indexed as [i * n_bins + bin]. These are serial per-pair paths (reportMatching = TRUE), so the heap allocation cost is negligible. Also upgrade assert() to static_assert() in tree_distances.h for the int16 width checks — these now fire at compile time, rather than being compiled out (NDEBUG) and therefore never checked in release builds.
1 parent 8803015 commit 18e3a52

File tree

2 files changed

+26
-17
lines changed

2 files changed

+26
-17
lines changed

src/tree_distances.cpp

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -51,29 +51,32 @@ inline List robinson_foulds_distance(const RawMatrix &x, const RawMatrix &y,
5151

5252
grf_match matching(a.n_splits, NA_INTEGER);
5353

54-
splitbit b_complement[SL_MAX_SPLITS][SL_MAX_BINS];
54+
const int32 n_bins = a.n_bins;
55+
std::vector<splitbit> b_complement(b.n_splits * n_bins);
5556
for (int32 i = b.n_splits; i--; ) {
57+
splitbit* bc_i = &b_complement[i * n_bins];
5658
for (int32 bin = last_bin; bin--; ) {
57-
b_complement[i][bin] = ~b.state[i][bin];
59+
bc_i[bin] = ~b.state[i][bin];
5860
}
59-
b_complement[i][last_bin] = b.state[i][last_bin] ^ unset_mask;
61+
bc_i[last_bin] = b.state[i][last_bin] ^ unset_mask;
6062
}
6163

6264
for (int32 ai = a.n_splits; ai--; ) {
6365
for (int32 bi = b.n_splits; bi--; ) {
6466

6567
bool all_match = true;
6668
bool all_complement = true;
69+
const splitbit* bc_bi = &b_complement[bi * n_bins];
6770

68-
for (int32 bin = 0; bin < a.n_bins; ++bin) {
71+
for (int32 bin = 0; bin < n_bins; ++bin) {
6972
if ((a.state[ai][bin] != b.state[bi][bin])) {
7073
all_match = false;
7174
break;
7275
}
7376
}
7477
if (!all_match) {
75-
for (int32 bin = 0; bin < a.n_bins; ++bin) {
76-
if (a.state[ai][bin] != b_complement[bi][bin]) {
78+
for (int32 bin = 0; bin < n_bins; ++bin) {
79+
if (a.state[ai][bin] != bc_bi[bin]) {
7780
all_complement = false;
7881
break;
7982
}
@@ -105,29 +108,31 @@ inline List robinson_foulds_info(const RawMatrix &x, const RawMatrix &y,
105108

106109
grf_match matching(a.n_splits, NA_INTEGER);
107110

108-
/* Dynamic allocation 20% faster for 105 tips, but VLA not permitted in C11 */
109-
splitbit b_complement[SL_MAX_SPLITS][SL_MAX_BINS];
111+
const int16 n_bins = a.n_bins;
112+
std::vector<splitbit> b_complement(b.n_splits * n_bins);
110113
for (int16 i = 0; i < b.n_splits; i++) {
114+
splitbit* bc_i = &b_complement[i * n_bins];
111115
for (int16 bin = 0; bin < last_bin; ++bin) {
112-
b_complement[i][bin] = ~b.state[i][bin];
116+
bc_i[bin] = ~b.state[i][bin];
113117
}
114-
b_complement[i][last_bin] = b.state[i][last_bin] ^ unset_mask;
118+
bc_i[last_bin] = b.state[i][last_bin] ^ unset_mask;
115119
}
116120

117121
for (int16 ai = 0; ai < a.n_splits; ++ai) {
118122
for (int16 bi = 0; bi < b.n_splits; ++bi) {
119123

120124
bool all_match = true, all_complement = true;
125+
const splitbit* bc_bi = &b_complement[bi * n_bins];
121126

122-
for (int16 bin = 0; bin < a.n_bins; ++bin) {
127+
for (int16 bin = 0; bin < n_bins; ++bin) {
123128
if ((a.state[ai][bin] != b.state[bi][bin])) {
124129
all_match = false;
125130
break;
126131
}
127132
}
128133
if (!all_match) {
129-
for (int16 bin = 0; bin < a.n_bins; ++bin) {
130-
if ((a.state[ai][bin] != b_complement[bi][bin])) {
134+
for (int16 bin = 0; bin < n_bins; ++bin) {
135+
if ((a.state[ai][bin] != bc_bi[bin])) {
131136
all_complement = false;
132137
break;
133138
}

src/tree_distances.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ namespace TreeDist {
3939

4040
// Returns lg2_unrooted[x] - lg2_trees_matching_split(y, x - y)
4141
[[nodiscard]] inline double mmsi_pair_score(const int16 x, const int16 y) noexcept {
42-
assert(SL_MAX_TIPS + 2 <= std::numeric_limits<int16>::max()); // verify int16 ok
42+
static_assert(SL_MAX_TIPS + 2 <= std::numeric_limits<int16>::max(),
43+
"int16 too narrow for SL_MAX_TIPS");
4344

4445
return lg2_unrooted[x] - (lg2_rooted[y] + lg2_rooted[x - y]);
4546
}
@@ -60,7 +61,8 @@ namespace TreeDist {
6061

6162

6263
[[nodiscard]] inline double one_overlap(const int16 a, const int16 b, const int16 n) noexcept {
63-
assert(SL_MAX_TIPS + 2 <= std::numeric_limits<int16>::max()); // verify int16 ok
64+
static_assert(SL_MAX_TIPS + 2 <= std::numeric_limits<int16>::max(),
65+
"int16 too narrow for SL_MAX_TIPS");
6466
if (a == b) {
6567
return lg2_rooted[a] + lg2_rooted[n - a];
6668
}
@@ -71,7 +73,8 @@ namespace TreeDist {
7173
}
7274

7375
[[nodiscard]] inline double one_overlap_notb(const int16 a, const int16 n_minus_b, const int16 n) noexcept {
74-
assert(SL_MAX_TIPS + 2 <= std::numeric_limits<int16>::max()); // verify int16 ok
76+
static_assert(SL_MAX_TIPS + 2 <= std::numeric_limits<int16>::max(),
77+
"int16 too narrow for SL_MAX_TIPS");
7578
const int16 b = n - n_minus_b;
7679
if (a == b) {
7780
return lg2_rooted[b] + lg2_rooted[n_minus_b];
@@ -90,7 +93,8 @@ namespace TreeDist {
9093
const int16 n_tips, const int16 in_a,
9194
const int16 in_b, const int16 n_bins) noexcept {
9295

93-
assert(SL_MAX_BINS <= INT16_MAX);
96+
static_assert(SL_MAX_BINS <= INT16_MAX,
97+
"int16 too narrow for SL_MAX_BINS");
9498

9599
int16 n_ab = 0;
96100
for (int16 bin = 0; bin < n_bins; ++bin) {

0 commit comments

Comments
 (0)