Skip to content

Commit 8692ae6

Browse files
committed
Add fast RF cross-pairs via Day 1985 ClusterTable batch
RobinsonFoulds(trees1, trees2) previously fell through to per-pair R dispatch via CalculateTreeDistance() -> .mapply(), which was ~27x slower per pair than the all-pairs path. Add robinson_foulds_cross_pairs() C++ function in day_1985.cpp using the same Day 1985 algorithm and ClusterTable encoding as the existing all-pairs function. Wire it into RobinsonFoulds() as a fast path when both inputs are tree lists with matching tip labels. Benchmark: 21x speedup for RobinsonFoulds(trees1, trees2). Per-pair cost is now 1.84x all-pairs (down from 27x). Tests: 5 new tests in test-batch_coverage.R covering distance, similarity, normalization, single-tree fallback, and all-pairs consistency.
1 parent dd6cf73 commit 8692ae6

File tree

5 files changed

+208
-4
lines changed

5 files changed

+208
-4
lines changed

R/RcppExports.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ robinson_foulds_all_pairs <- function(tables) {
1313
.Call(`_TreeDist_robinson_foulds_all_pairs`, tables)
1414
}
1515

16+
robinson_foulds_cross_pairs <- function(tables_a, tables_b) {
17+
.Call(`_TreeDist_robinson_foulds_cross_pairs`, tables_a, tables_b)
18+
}
19+
1620
consensus_info <- function(trees, phylo, p) {
1721
.Call(`_TreeDist_consensus_info`, trees, phylo, p)
1822
}

R/tree_distance_rf.R

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,10 +148,43 @@ RobinsonFoulds <- function(tree1, tree2 = NULL, similarity = FALSE,
148148
class = "dist")
149149
}
150150
} else {
151-
unnormalized <- CalculateTreeDistance(RobinsonFouldsSplits, tree1, tree2,
152-
reportMatching)
153-
if (similarity) {
154-
unnormalized <- .MaxValue(tree1, tree2, NSplits) - unnormalized
151+
# Fast cross-pairs path: batch C++ via ClusterTable (Day 1985).
152+
# Only applicable when both inputs are lists of trees with matching
153+
# tip labels and reportMatching is not requested.
154+
fast <- NULL
155+
if (!reportMatching &&
156+
!inherits(tree1, c("phylo", "Splits")) &&
157+
!inherits(tree2, c("phylo", "Splits")) &&
158+
is.null(getOption("TreeDist-cluster"))) {
159+
lab1 <- TipLabels(tree1)
160+
lab2 <- TipLabels(tree2)
161+
if (is.list(lab1)) lab1 <- lab1[[1]]
162+
if (is.list(lab2)) lab2 <- lab2[[1]]
163+
if (setequal(lab1, lab2)) {
164+
ct1 <- as.ClusterTable(tree1, tipLabels = lab1)
165+
ct2 <- as.ClusterTable(tree2, tipLabels = lab1)
166+
shared <- robinson_foulds_cross_pairs(
167+
if (is.list(ct1)) ct1 else list(ct1),
168+
if (is.list(ct2)) ct2 else list(ct2)
169+
)
170+
splits1 <- NSplits(tree1)
171+
splits2 <- NSplits(tree2)
172+
if (similarity) {
173+
fast <- shared + shared
174+
} else {
175+
fast <- outer(splits1, splits2, "+") - shared - shared
176+
}
177+
dimnames(fast) <- list(names(tree1), names(tree2))
178+
}
179+
}
180+
if (is.null(fast)) {
181+
unnormalized <- CalculateTreeDistance(RobinsonFouldsSplits, tree1, tree2,
182+
reportMatching)
183+
if (similarity) {
184+
unnormalized <- .MaxValue(tree1, tree2, NSplits) - unnormalized
185+
}
186+
} else {
187+
unnormalized <- fast
155188
}
156189
}
157190

src/RcppExports.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,18 @@ BEGIN_RCPP
4444
return rcpp_result_gen;
4545
END_RCPP
4646
}
47+
// robinson_foulds_cross_pairs
48+
IntegerMatrix robinson_foulds_cross_pairs(const List& tables_a, const List& tables_b);
49+
RcppExport SEXP _TreeDist_robinson_foulds_cross_pairs(SEXP tables_aSEXP, SEXP tables_bSEXP) {
50+
BEGIN_RCPP
51+
Rcpp::RObject rcpp_result_gen;
52+
Rcpp::RNGScope rcpp_rngScope_gen;
53+
Rcpp::traits::input_parameter< const List& >::type tables_a(tables_aSEXP);
54+
Rcpp::traits::input_parameter< const List& >::type tables_b(tables_bSEXP);
55+
rcpp_result_gen = Rcpp::wrap(robinson_foulds_cross_pairs(tables_a, tables_b));
56+
return rcpp_result_gen;
57+
END_RCPP
58+
}
4759
// consensus_info
4860
double consensus_info(const List trees, const LogicalVector phylo, const NumericVector p);
4961
RcppExport SEXP _TreeDist_consensus_info(SEXP treesSEXP, SEXP phyloSEXP, SEXP pSEXP) {
@@ -620,6 +632,7 @@ static const R_CallMethodDef CallEntries[] = {
620632
{"_TreeDist_binary_entropy_counts", (DL_FUNC) &_TreeDist_binary_entropy_counts, 2},
621633
{"_TreeDist_COMCLUST", (DL_FUNC) &_TreeDist_COMCLUST, 1},
622634
{"_TreeDist_robinson_foulds_all_pairs", (DL_FUNC) &_TreeDist_robinson_foulds_all_pairs, 1},
635+
{"_TreeDist_robinson_foulds_cross_pairs", (DL_FUNC) &_TreeDist_robinson_foulds_cross_pairs, 2},
623636
{"_TreeDist_consensus_info", (DL_FUNC) &_TreeDist_consensus_info, 3},
624637
{"_TreeDist_HMI_xptr", (DL_FUNC) &_TreeDist_HMI_xptr, 2},
625638
{"_TreeDist_HH_xptr", (DL_FUNC) &_TreeDist_HH_xptr, 1},

src/day_1985.cpp

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,110 @@ IntegerVector robinson_foulds_all_pairs(const List& tables) {
318318
return shared;
319319
}
320320

321+
// Cross-pairs variant: compute RF shared splits between two collections.
322+
// tables_a and tables_b are lists of ClusterTable XPtr objects
323+
// (all trees must share the same tip labels).
324+
// Returns an nA × nB IntegerMatrix of shared-split counts.
325+
// [[Rcpp::export]]
326+
IntegerMatrix robinson_foulds_cross_pairs(const List& tables_a,
327+
const List& tables_b) {
328+
const int nA = static_cast<int>(tables_a.size());
329+
const int nB = static_cast<int>(tables_b.size());
330+
if (nA == 0 || nB == 0) return IntegerMatrix(nA, nB);
331+
332+
std::vector<ClusterTable*> tbl_a, tbl_b;
333+
tbl_a.reserve(nA);
334+
tbl_b.reserve(nB);
335+
for (int i = 0; i < nA; ++i) {
336+
Rcpp::XPtr<ClusterTable> xp = tables_a[i];
337+
tbl_a.push_back(xp.get());
338+
}
339+
for (int j = 0; j < nB; ++j) {
340+
Rcpp::XPtr<ClusterTable> xp = tables_b[j];
341+
tbl_b.push_back(xp.get());
342+
}
343+
344+
IntegerMatrix result(nA, nB);
345+
346+
const int32 n_tip = tbl_a[0]->N();
347+
StackEntry* S_start;
348+
std::array<StackEntry, ct_stack_threshold> S_stack;
349+
std::vector<StackEntry> S_heap;
350+
if (n_tip <= ct_stack_threshold) {
351+
S_start = S_stack.data();
352+
} else {
353+
S_heap.resize(n_tip);
354+
S_start = S_heap.data();
355+
}
356+
357+
for (int i = 0; i < nA; ++i) {
358+
359+
ClusterTable* Xi = tbl_a[i];
360+
361+
for (int j = 0; j < nB; ++j) {
362+
363+
int32 v;
364+
int32 w;
365+
int32 n_shared = 0;
366+
367+
ClusterTable* Tj = tbl_b[j];
368+
369+
StackEntry* S_top = S_start;
370+
371+
Tj->TRESET();
372+
Tj->NVERTEX_short(&v, &w);
373+
374+
while (v) {
375+
if (Tj->is_leaf(v)) {
376+
const auto enc_v = Xi->ENCODE(v);
377+
*S_top++ = {enc_v, enc_v, 1, 1};
378+
} else {
379+
const StackEntry& entry = *--S_top;
380+
int32 L = entry.L;
381+
int32 R = entry.R;
382+
int32 N = entry.N;
383+
const int32 W_i = entry.W;
384+
int32 W = 1 + W_i;
385+
386+
w -= W_i;
387+
388+
if (w) {
389+
const StackEntry& entry = *--S_top;
390+
const int32 W_i = entry.W;
391+
392+
L = std::min(L, entry.L);
393+
R = std::max(R, entry.R);
394+
N += entry.N;
395+
W += W_i;
396+
w -= W_i;
397+
398+
while (w) {
399+
const StackEntry& entry = *--S_top;
400+
const int32 W_i = entry.W;
401+
402+
L = std::min(L, entry.L);
403+
R = std::max(R, entry.R);
404+
N += entry.N;
405+
W += W_i;
406+
w -= W_i;
407+
}
408+
}
409+
410+
*S_top++ = {L, R, N, W};
411+
412+
if (N == R - L + 1) {
413+
if (Xi->ISCLUST(L, R)) ++n_shared;
414+
}
415+
}
416+
Tj->NVERTEX_short(&v, &w);
417+
}
418+
result(i, j) = n_shared;
419+
}
420+
}
421+
422+
return result;
423+
}
424+
321425
// [[Rcpp::export]]
322426
double consensus_info(const List trees, const LogicalVector phylo,
323427
const NumericVector p) {

tests/testthat/test-batch_coverage.R

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,56 @@ test_that("Issue #162: CID by-hand matches function for 33-taxon trees", {
363363
expect_equal(d_fn, d_hand, tolerance = 1e-10)
364364
})
365365

366+
# RobinsonFoulds cross-pairs (Day 1985 ClusterTable batch) ----
367+
368+
test_that("RF cross-pairs fast path agrees with per-pair", {
369+
cross <- RobinsonFoulds(tA, tB)
370+
expect_equal(dim(cross), c(5L, 5L))
371+
expect_equal(cross[1, 1],
372+
RobinsonFoulds(tA[[1]], tB[[1]]),
373+
tolerance = 0)
374+
expect_equal(cross[3, 2],
375+
RobinsonFoulds(tA[[3]], tB[[2]]),
376+
tolerance = 0)
377+
expect_equal(cross[5, 5],
378+
RobinsonFoulds(tA[[5]], tB[[5]]),
379+
tolerance = 0)
380+
})
381+
382+
test_that("RF cross-pairs similarity mode agrees with per-pair", {
383+
cross_sim <- RobinsonFoulds(tA, tB, similarity = TRUE)
384+
expect_equal(cross_sim[2, 3],
385+
RobinsonFoulds(tA[[2]], tB[[3]], similarity = TRUE),
386+
tolerance = 0)
387+
})
388+
389+
test_that("RF cross-pairs normalization agrees with per-pair", {
390+
cross_norm <- RobinsonFoulds(tA, tB, normalize = TRUE)
391+
expect_equal(cross_norm[1, 4],
392+
RobinsonFoulds(tA[[1]], tB[[4]], normalize = TRUE),
393+
tolerance = 1e-10)
394+
})
395+
396+
test_that("RF cross-pairs handles single-tree inputs via fallback", {
397+
# Single tree1 → falls back to CalculateTreeDistance
398+
d <- RobinsonFoulds(tA[[1]], tB)
399+
expect_equal(length(d), 5L)
400+
expect_equal(d[1], RobinsonFoulds(tA[[1]], tB[[1]]), tolerance = 0)
401+
})
402+
403+
test_that("RF cross-pairs all-pairs matches cross-pairs diagonal", {
404+
all_dist <- as.matrix(RobinsonFoulds(trees20))
405+
# Cross-pairs of trees20 with itself should match all-pairs
406+
cross <- RobinsonFoulds(trees20, trees20)
407+
expect_equal(dim(cross), c(10L, 10L))
408+
for (i in 1:10) {
409+
for (j in 1:10) {
410+
expect_equal(cross[i, j], all_dist[i, j], tolerance = 0,
411+
info = paste0("i=", i, " j=", j))
412+
}
413+
}
414+
})
415+
366416
test_that("Issue #162: batch path agrees with per-pair for 33-taxon trees", {
367417
tree1 <- ape::read.tree(text = "(B,A,((AG,AF),((((C,(E,D)),((F,G),H)),((K,I),J)),((Q,R),((((AE,(AC,AD)),(AB,(N,(P,O)))),(((Y,(Z,AA)),(W,X)),(V,(T,U)))),((M,S),L))))));")
368418
tree2 <- ape::read.tree(text = "(B,A,((AG,AF),((((C,(E,D)),((F,G),H)),(((Q,R),((((AE,(AC,AD)),AB),(N,(P,O))),((Y,(Z,AA)),((V,(T,U)),(W,X))))),((M,L),S))),((K,I),J))));")

0 commit comments

Comments
 (0)