Skip to content

Commit 1a46905

Browse files
committed
Add ts.coiterate method
Fixes #1021
1 parent 6060ce7 commit 1a46905

File tree

4 files changed

+142
-27
lines changed

4 files changed

+142
-27
lines changed

docs/python-api.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,9 @@ directly, but are the return types for the various iterators provided by the
8686
.. autoclass:: Edge()
8787
:members:
8888

89+
.. autoclass:: Interval()
90+
:members:
91+
8992
.. autoclass:: Site()
9093
:members:
9194

python/tests/test_topology.py

Lines changed: 89 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4785,6 +4785,94 @@ def test_partial_overlap_contradictory_children(self):
47854785
tskit.load_text(nodes=nodes, edges=edges, strict=False)
47864786

47874787

4788+
class TestCoiteration:
4789+
"""
4790+
Test ability to iterate over multiple (currently 2) tree sequences simultaneously
4791+
"""
4792+
4793+
def test_identical_ts(self):
4794+
ts = msprime.simulate(4, recombination_rate=1, random_seed=123)
4795+
assert ts.num_trees > 1
4796+
total_iterations = 0
4797+
for tree, (_, t1, t2) in zip(ts.trees(), ts.coiterate(ts)):
4798+
total_iterations += 1
4799+
assert tree == t1 == t2
4800+
assert ts.num_trees == total_iterations
4801+
4802+
def test_intervals(self):
4803+
ts1 = msprime.simulate(4, recombination_rate=1, random_seed=1)
4804+
assert ts1.num_trees > 1
4805+
one_tree_ts = msprime.simulate(5, random_seed=2)
4806+
multi_tree_ts = msprime.simulate(5, recombination_rate=1, random_seed=2)
4807+
assert multi_tree_ts.num_trees > 1
4808+
for ts2 in (one_tree_ts, multi_tree_ts):
4809+
bp1 = set(ts1.breakpoints())
4810+
bp2 = set(ts2.breakpoints())
4811+
assert bp1 != bp2
4812+
breaks = set()
4813+
for interval, t1, t2 in ts1.coiterate(ts2):
4814+
breaks.add(interval.left)
4815+
breaks.add(interval.right)
4816+
assert t1.tree_sequence == ts1
4817+
assert t2.tree_sequence == ts2
4818+
assert breaks == bp1 | bp2
4819+
4820+
def test_simple_ts(self):
4821+
nodes = """\
4822+
id is_sample time
4823+
0 1 0
4824+
1 1 0
4825+
2 1 0
4826+
3 0 1
4827+
4 0 2
4828+
"""
4829+
edges1 = """\
4830+
left right parent child
4831+
0 0.2 3 0,1
4832+
0 0.2 4 2,3
4833+
0.2 1 3 2,1
4834+
0.2 1 4 0,3
4835+
"""
4836+
edges2 = """\
4837+
left right parent child
4838+
0 0.8 3 2,1
4839+
0 0.8 4 0,3
4840+
0.8 1 3 0,1
4841+
0.8 1 4 2,3
4842+
"""
4843+
ts1 = tskit.load_text(io.StringIO(nodes), io.StringIO(edges1), strict=False)
4844+
ts2 = tskit.load_text(io.StringIO(nodes), io.StringIO(edges2), strict=False)
4845+
coiterator = ts1.coiterate(ts2)
4846+
interval, tree1, tree2 = next(coiterator)
4847+
assert interval.left == 0
4848+
assert interval.right == 0.2
4849+
assert tree1 == ts1.at_index(0)
4850+
assert tree2 == ts2.at_index(0)
4851+
interval, tree1, tree2 = next(coiterator)
4852+
assert interval.left == 0.2
4853+
assert interval.right == 0.8
4854+
assert tree1 == ts1.at_index(1)
4855+
assert tree2 == ts2.at_index(0)
4856+
interval, tree1, tree2 = next(coiterator)
4857+
assert interval.left == 0.8
4858+
assert interval.right == 1
4859+
assert tree1 == ts1.at_index(1)
4860+
assert tree2 == ts2.at_index(1)
4861+
4862+
def test_nonequal_lengths(self):
4863+
ts1 = msprime.simulate(4, random_seed=1, length=2)
4864+
ts2 = msprime.simulate(4, random_seed=1)
4865+
with pytest.raises(ValueError, match="equal sequence length"):
4866+
next(ts1.coiterate(ts2))
4867+
4868+
def test_kwargs(self):
4869+
ts = msprime.simulate(4, recombination_rate=1, random_seed=123)
4870+
for _, t1, t2 in ts.coiterate(ts):
4871+
assert t1.num_tracked_samples() == t2.num_tracked_samples() == 0
4872+
for _, t1, t2 in ts.coiterate(ts, tracked_samples=ts.samples()):
4873+
assert t1.num_tracked_samples() == t2.num_tracked_samples() == 4
4874+
4875+
47884876
class SimplifyTestBase:
47894877
"""
47904878
Base class for simplify tests.
@@ -5722,9 +5810,7 @@ def verify_keep_input_roots(self, ts, samples):
57225810
new_to_input_map = {
57235811
value: key for key, value in enumerate(node_map) if value != tskit.NULL
57245812
}
5725-
for (left, right), input_tree, tree_with_roots in tsutil.coiterate(
5726-
ts, ts_with_roots
5727-
):
5813+
for (left, right), input_tree, tree_with_roots in ts.coiterate(ts_with_roots):
57285814
input_roots = input_tree.roots
57295815
assert len(tree_with_roots.roots) > 0
57305816
for root in tree_with_roots.roots:

python/tests/tsutil.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1282,27 +1282,3 @@ def genealogical_nearest_neighbours(ts, focal, reference_sets):
12821282
L[L == 0] = 1
12831283
A /= L.reshape((len(focal), 1))
12841284
return A
1285-
1286-
1287-
def coiterate(ts1, ts2, **kwargs):
1288-
"""
1289-
Returns an iterator over the pairs of trees for each distinct
1290-
interval in the specified pair of tree sequences.
1291-
"""
1292-
if ts1.sequence_length != ts2.sequence_length:
1293-
raise ValueError("Tree sequences must be equal length.")
1294-
L = ts1.sequence_length
1295-
trees1 = ts1.trees(**kwargs)
1296-
trees2 = ts2.trees(**kwargs)
1297-
tree1 = next(trees1)
1298-
tree2 = next(trees2)
1299-
right = 0
1300-
while right != L:
1301-
left = right
1302-
right = min(tree1.interval[1], tree2.interval[1])
1303-
yield (left, right), tree1, tree2
1304-
# Advance
1305-
if tree1.interval[1] == right:
1306-
tree1 = next(trees1, None)
1307-
if tree2.interval[1] == right:
1308-
tree2 = next(trees2, None)

python/tskit/trees.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,19 @@
6262

6363

6464
class Interval(BaseInterval):
65+
"""
66+
A tuple of 2 numbers, ``[left, right)``, defining an interval over the genome.
67+
68+
:ivar left: The left hand end of the interval. By convention this value is included
69+
in the interval.
70+
:vartype left: float
71+
:ivar right: The right hand end of the iterval. By convention this value is *not*
72+
included in the interval, i.e. the interval is half-open.
73+
:vartype right: float
74+
:ivar span: The span of the genome covered by this interval, simply ``right-left``.
75+
:vartype span: float
76+
"""
77+
6578
@property
6679
def span(self):
6780
return self.right - self.left
@@ -3850,6 +3863,43 @@ def trees(
38503863
)
38513864
return TreeIterator(tree)
38523865

3866+
def coiterate(self, other, **kwargs):
3867+
"""
3868+
Returns an iterator over the pairs of trees for each distinct
3869+
interval in the specified pair of tree sequences.
3870+
3871+
:param TreeSequence other: The other tree sequence from which to take trees. The
3872+
sequence length must be the same as the current tree sequence.
3873+
:param \\**kwargs: Further named arguments that will be passed to the
3874+
:meth:`.trees` method when constructing the returned trees.
3875+
3876+
:return: An iterator returning successive tuples of the form
3877+
``(interval, tree_self, tree_other)``. For example, the first item returned
3878+
will consist of an tuple of the initial interval, the first tree of the
3879+
current tree sequence, and the first tree of the ``other`` tree sequence;
3880+
the ``.left`` attribute of the initial interval will be 0 and the ``.right``
3881+
attribute will be the smallest non-zero breakpoint of the 2 tree sequences.
3882+
:rtype: iter(:class:`Interval`, :class:`Tree`, :class:`Tree`)
3883+
3884+
"""
3885+
if self.sequence_length != other.sequence_length:
3886+
raise ValueError("Tree sequences must be of equal sequence length.")
3887+
L = self.sequence_length
3888+
trees1 = self.trees(**kwargs)
3889+
trees2 = other.trees(**kwargs)
3890+
tree1 = next(trees1)
3891+
tree2 = next(trees2)
3892+
right = 0
3893+
while right != L:
3894+
left = right
3895+
right = min(tree1.interval[1], tree2.interval[1])
3896+
yield Interval(left, right), tree1, tree2
3897+
# Advance
3898+
if tree1.interval[1] == right:
3899+
tree1 = next(trees1, None)
3900+
if tree2.interval[1] == right:
3901+
tree2 = next(trees2, None)
3902+
38533903
def haplotypes(
38543904
self,
38553905
*,

0 commit comments

Comments
 (0)