Skip to content

Commit cd497bd

Browse files
committed
Add ts.coiterate method
Fixes #1021
1 parent 6060ce7 commit cd497bd

File tree

3 files changed

+117
-24
lines changed

3 files changed

+117
-24
lines changed

python/tests/test_topology.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4785,6 +4785,86 @@ def test_partial_overlap_contradictory_children(self):
47854785
tskit.load_text(nodes=nodes, edges=edges, strict=False)
47864786

47874787

4788+
class TestCoiteration:
4789+
"""
4790+
Test ability to iterate over multiple (currently 2) tree sequences simultaneously
4791+
"""
4792+
4793+
def test_identical_ts(self):
4794+
ts = msprime.simulate(4, recombination_rate=1, random_seed=123)
4795+
assert ts.num_trees > 1
4796+
total_iterations = 0
4797+
for _, t1, t2 in ts.coiterate(ts):
4798+
total_iterations += 1
4799+
assert t1 == t2
4800+
assert ts.num_trees == total_iterations
4801+
4802+
def test_intervals(self):
4803+
ts1 = msprime.simulate(4, recombination_rate=1, random_seed=1)
4804+
assert ts1.num_trees > 1
4805+
one_tree_ts = msprime.simulate(5, random_seed=2)
4806+
multi_tree_ts = msprime.simulate(5, recombination_rate=1, random_seed=2)
4807+
assert multi_tree_ts.num_trees > 1
4808+
for ts2 in (one_tree_ts, multi_tree_ts):
4809+
bp1 = set(ts1.breakpoints())
4810+
bp2 = set(ts2.breakpoints())
4811+
assert bp1 != bp2
4812+
breaks = {0}
4813+
for interval, t1, t2 in ts1.coiterate(ts2):
4814+
breaks.add(interval.right)
4815+
assert t1.tree_sequence == ts1
4816+
assert t2.tree_sequence == ts2
4817+
assert breaks == bp1 | bp2
4818+
4819+
def test_simple_ts(self):
4820+
nodes = """\
4821+
id is_sample time
4822+
0 1 0
4823+
1 1 0
4824+
2 1 0
4825+
3 0 1
4826+
4 0 2
4827+
"""
4828+
edges1 = """\
4829+
left right parent child
4830+
0 0.2 3 0,1
4831+
0 0.2 4 2,3
4832+
0.2 1 3 2,1
4833+
0.2 1 4 0,3
4834+
"""
4835+
edges2 = """\
4836+
left right parent child
4837+
0 0.8 3 2,1
4838+
0 0.8 4 0,3
4839+
0.8 1 3 0,1
4840+
0.8 1 4 2,3
4841+
"""
4842+
ts1 = tskit.load_text(io.StringIO(nodes), io.StringIO(edges1), strict=False)
4843+
ts2 = tskit.load_text(io.StringIO(nodes), io.StringIO(edges2), strict=False)
4844+
coiterator = ts1.coiterate(ts2)
4845+
interval, tree1, tree2 = next(coiterator)
4846+
assert interval.left == 0
4847+
assert interval.right == 0.2
4848+
assert tree1 == ts1.at_index(0)
4849+
assert tree2 == ts2.at_index(0)
4850+
interval, tree1, tree2 = next(coiterator)
4851+
assert interval.left == 0.2
4852+
assert interval.right == 0.8
4853+
assert tree1 == ts1.at_index(1)
4854+
assert tree2 == ts2.at_index(0)
4855+
interval, tree1, tree2 = next(coiterator)
4856+
assert interval.left == 0.8
4857+
assert interval.right == 1
4858+
assert tree1 == ts1.at_index(1)
4859+
assert tree2 == ts2.at_index(1)
4860+
4861+
def test_nonequal_lengths(self):
4862+
ts1 = msprime.simulate(4, random_seed=1, length=2)
4863+
ts2 = msprime.simulate(4, random_seed=1)
4864+
with pytest.raises(ValueError):
4865+
next(ts1.coiterate(ts2))
4866+
4867+
47884868
class SimplifyTestBase:
47894869
"""
47904870
Base class for simplify tests.

python/tests/tsutil.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1282,27 +1282,3 @@ def genealogical_nearest_neighbours(ts, focal, reference_sets):
12821282
L[L == 0] = 1
12831283
A /= L.reshape((len(focal), 1))
12841284
return A
1285-
1286-
1287-
def coiterate(ts1, ts2, **kwargs):
1288-
"""
1289-
Returns an iterator over the pairs of trees for each distinct
1290-
interval in the specified pair of tree sequences.
1291-
"""
1292-
if ts1.sequence_length != ts2.sequence_length:
1293-
raise ValueError("Tree sequences must be equal length.")
1294-
L = ts1.sequence_length
1295-
trees1 = ts1.trees(**kwargs)
1296-
trees2 = ts2.trees(**kwargs)
1297-
tree1 = next(trees1)
1298-
tree2 = next(trees2)
1299-
right = 0
1300-
while right != L:
1301-
left = right
1302-
right = min(tree1.interval[1], tree2.interval[1])
1303-
yield (left, right), tree1, tree2
1304-
# Advance
1305-
if tree1.interval[1] == right:
1306-
tree1 = next(trees1, None)
1307-
if tree2.interval[1] == right:
1308-
tree2 = next(trees2, None)

python/tskit/trees.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3850,6 +3850,43 @@ def trees(
38503850
)
38513851
return TreeIterator(tree)
38523852

3853+
def coiterate(self, other, **kwargs):
3854+
"""
3855+
Returns an iterator over the pairs of trees for each distinct
3856+
interval in the specified pair of tree sequences.
3857+
3858+
:param TreeSequence other: The other tree sequence from which to take trees. The
3859+
sequence length should be the same as the current tree sequence.
3860+
:param \\**kwargs: Further named arguments that will be passed to the
3861+
:meth:`.trees` method when constructing the returned trees.
3862+
3863+
:return: An iterator returning successive tuples of the form
3864+
``(interval, tree_self, tree_other)``. For example, the first item returned
3865+
will consist of a tuple of the initial interval, the first tree of the
3866+
current tree sequence, and the first tree of the ``other`` tree sequence;
3867+
the ``.left`` attribute of the initial interval will be 0 and the ``.right``
3868+
attribute will be the smallest non-zero breakpoint of the 2 tree sequences.
3869+
:rtype: iter(:class:`Interval`, :class:`Tree`, :class:`Tree`)
3870+
3871+
"""
3872+
if self.sequence_length != other.sequence_length:
3873+
raise ValueError("Tree sequences must be equal length.")
3874+
L = self.sequence_length
3875+
trees1 = self.trees(**kwargs)
3876+
trees2 = other.trees(**kwargs)
3877+
tree1 = next(trees1)
3878+
tree2 = next(trees2)
3879+
right = 0
3880+
while right != L:
3881+
left = right
3882+
right = min(tree1.interval[1], tree2.interval[1])
3883+
yield Interval(left, right), tree1, tree2
3884+
# Advance
3885+
if tree1.interval[1] == right:
3886+
tree1 = next(trees1, None)
3887+
if tree2.interval[1] == right:
3888+
tree2 = next(trees2, None)
3889+
38533890
def haplotypes(
38543891
self,
38553892
*,

0 commit comments

Comments
 (0)