Skip to content

Commit cf6cd53

Browse files
committed
Add skip_empty to Tree class and trees iterator
1 parent ece21ff commit cf6cd53

File tree

2 files changed

+44
-17
lines changed

2 files changed

+44
-17
lines changed

python/tests/test_highlevel.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1722,6 +1722,21 @@ def test_trees_interface(self):
17221722
assert t.get_num_tracked_samples(0) == 0
17231723
assert list(t.samples(0)) == [0]
17241724

1725+
def test_trees_skip_empty(self):
1726+
ts = tskit.Tree.generate_balanced(10).tree_sequence
1727+
ts = ts.delete_intervals([[0, 0.2], [0.7, 0.9]])
1728+
assert ts.num_trees == 4
1729+
for i, tree in enumerate(ts.trees()):
1730+
if i % 2 == 0:
1731+
assert tree.is_empty()
1732+
else:
1733+
assert not tree.is_empty()
1734+
assert i == 3
1735+
1736+
for _i, tree in enumerate(ts.trees(skip_empty=True)):
1737+
assert not tree.is_empty()
1738+
assert _i == 1
1739+
17251740
@pytest.mark.parametrize("ts", get_example_tree_sequences())
17261741
def test_get_pairwise_diversity(self, ts):
17271742
with pytest.raises(ValueError, match="at least one element"):
@@ -4453,7 +4468,7 @@ class TestSeekDirection:
44534468
def ts(self):
44544469
return tsutil.all_trees_ts(3)
44554470

4456-
def setup(self):
4471+
def setup_trees(self):
44574472
ts = self.ts()
44584473
t1 = tskit.Tree(ts)
44594474
t2 = tskit.Tree(ts)
@@ -4466,22 +4481,22 @@ def setup(self):
44664481
def test_index_from_different_directions(self, index):
44674482
# Check that we get different orderings of the children arrays
44684483
# for all trees when we go in different directions.
4469-
t1, t2 = self.setup()
4484+
t1, t2 = self.setup_trees()
44704485
while t1.index != index:
44714486
t1.next()
44724487
while t2.index != index:
44734488
t2.prev()
44744489
assert_same_tree_different_order(t1, t2)
44754490

44764491
def test_seek_0_from_null(self):
4477-
t1, t2 = self.setup()
4492+
t1, t2 = self.setup_trees()
44784493
t1.first()
44794494
t2.seek(0)
44804495
assert_trees_identical(t1, t2)
44814496

44824497
@pytest.mark.parametrize("index", range(3))
44834498
def test_seek_next_tree(self, index):
4484-
t1, t2 = self.setup()
4499+
t1, t2 = self.setup_trees()
44854500
while t1.index != index:
44864501
t1.next()
44874502
t2.next()
@@ -4491,7 +4506,7 @@ def test_seek_next_tree(self, index):
44914506

44924507
@pytest.mark.parametrize("index", [3, 2, 1])
44934508
def test_seek_prev_tree(self, index):
4494-
t1, t2 = self.setup()
4509+
t1, t2 = self.setup_trees()
44954510
while t1.index != index:
44964511
t1.prev()
44974512
t2.prev()
@@ -4500,44 +4515,44 @@ def test_seek_prev_tree(self, index):
45004515
assert_trees_identical(t1, t2)
45014516

45024517
def test_seek_1_from_0(self):
4503-
t1, t2 = self.setup()
4518+
t1, t2 = self.setup_trees()
45044519
t1.first()
45054520
t1.next()
45064521
t2.first()
45074522
t2.seek(1)
45084523
assert_trees_identical(t1, t2)
45094524

45104525
def test_seek_1_5_from_0(self):
4511-
t1, t2 = self.setup()
4526+
t1, t2 = self.setup_trees()
45124527
t1.first()
45134528
t1.next()
45144529
t2.first()
45154530
t2.seek(1.5)
45164531
assert_trees_identical(t1, t2)
45174532

45184533
def test_seek_1_5_from_1(self):
4519-
t1, t2 = self.setup()
4534+
t1, t2 = self.setup_trees()
45204535
for _ in range(2):
45214536
t1.next()
45224537
t2.next()
45234538
t2.seek(1.5)
45244539
assert_trees_identical(t1, t2)
45254540

45264541
def test_seek_3_from_null(self):
4527-
t1, t2 = self.setup()
4542+
t1, t2 = self.setup_trees()
45284543
t1.last()
45294544
t2.seek(3)
45304545
assert_trees_identical(t1, t2)
45314546

45324547
def test_seek_3_from_0(self):
4533-
t1, t2 = self.setup()
4548+
t1, t2 = self.setup_trees()
45344549
t1.last()
45354550
t2.first()
45364551
t2.seek(3)
45374552
assert_trees_identical(t1, t2)
45384553

45394554
def test_seek_0_from_3(self):
4540-
t1, t2 = self.setup()
4555+
t1, t2 = self.setup_trees()
45414556
t1.last()
45424557
t1.first()
45434558
t2.last()

python/tskit/trees.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,7 @@ def __init__(
652652
tracked_samples=None,
653653
*,
654654
sample_lists=False,
655+
skip_empty=None,
655656
root_threshold=1,
656657
sample_counts=None,
657658
):
@@ -671,6 +672,7 @@ def __init__(
671672

672673
self._tree_sequence = tree_sequence
673674
self._ll_tree = _tskit.Tree(tree_sequence.ll_tree_sequence, **kwargs)
675+
self.skip_empty = skip_empty # This is a python-only facility, not in _ll
674676
self._ll_tree.set_root_threshold(root_threshold)
675677
self._make_arrays()
676678

@@ -3851,6 +3853,7 @@ def __init__(self, tree):
38513853
self.tree = tree
38523854
self.more_trees = True
38533855
self.forward = True
3856+
self.no_skip_empty = True if tree.skip_empty is None else not tree.skip_empty
38543857

38553858
def __iter__(self):
38563859
return self
@@ -3860,15 +3863,19 @@ def __reversed__(self):
38603863
return self
38613864

38623865
def __next__(self):
3863-
if self.forward:
3864-
self.more_trees = self.more_trees and self.tree.next()
3865-
else:
3866-
self.more_trees = self.more_trees and self.tree.prev()
3867-
if not self.more_trees:
3868-
raise StopIteration()
3866+
while True:
3867+
if self.forward:
3868+
self.more_trees = self.more_trees and self.tree.next()
3869+
else:
3870+
self.more_trees = self.more_trees and self.tree.prev()
3871+
if not self.more_trees:
3872+
raise StopIteration()
3873+
if self.no_skip_empty or not self.tree.is_empty():
3874+
break
38693875
return self.tree
38703876

38713877
def __len__(self):
3878+
# NB: this counts every tree in the tree sequence, so it can exceed the number of trees the iterator yields when empty trees are skipped
38723879
return self.tree.tree_sequence.num_trees
38733880

38743881

@@ -4945,6 +4952,7 @@ def trees(
49454952
tracked_samples=None,
49464953
*,
49474954
sample_lists=False,
4955+
skip_empty=None,
49484956
root_threshold=1,
49494957
sample_counts=None,
49504958
tracked_leaves=None,
@@ -4972,6 +4980,9 @@ def trees(
49724980
:param bool sample_lists: If True, provide more efficient access
49734981
to the samples beneath a given node using the
49744982
:meth:`Tree.samples` method.
4983+
:param bool skip_empty: If True, skip trees that are
4984+
:meth:`empty<Tree.is_empty>`, commonly found at the start and end of
4985+
a tree sequence. Default: ``None``, which is treated as ``False``.
49754986
:param int root_threshold: The minimum number of samples that a node
49764987
must be ancestral to for it to be in the list of roots. By default
49774988
this is 1, so that isolated samples (representing missing data)
@@ -4996,6 +5007,7 @@ def trees(
49965007
self,
49975008
tracked_samples=tracked_samples,
49985009
sample_lists=sample_lists,
5010+
skip_empty=skip_empty,
49995011
root_threshold=root_threshold,
50005012
sample_counts=sample_counts,
50015013
)

0 commit comments

Comments
 (0)