Skip to content

Commit 06b8be6

Browse files
hyanwong authored and mergify-bot committed
Add ts.coiterate method
Fixes #1021
1 parent fca033d commit 06b8be6

File tree

5 files changed

+147
-27
lines changed

5 files changed

+147
-27
lines changed

docs/python-api.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,9 @@ directly, but are the return types for the various iterators provided by the
8686
.. autoclass:: Edge()
8787
:members:
8888

89+
.. autoclass:: Interval()
90+
:members:
91+
8992
.. autoclass:: Site()
9093
:members:
9194

python/CHANGELOG.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44

55
**Features**
66

7+
- Expose ``TreeSequence.coiterate()`` method to allow iteration over 2 sequences
8+
simultaneously, aiding comparison of trees from two sequences.
9+
(:user:`jeromekelleher`, :user:`hyanwong`, :issue:`1021`, :pr:`1022`)
10+
711
- tskit is now supported on, and has wheels for, python3.9.
812
(:user:`benjeffery`, :issue:`982`, :pr:`907`)
913

python/tests/test_topology.py

Lines changed: 90 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4758,6 +4758,95 @@ def test_partial_overlap_contradictory_children(self):
47584758
tskit.load_text(nodes=nodes, edges=edges, strict=False)
47594759

47604760

4761+
class TestCoiteration:
    """
    Test ability to iterate over multiple (currently 2) tree sequences simultaneously
    """

    def test_identical_ts(self):
        # Coiterating a tree sequence with itself should walk through exactly
        # the same trees, in the same order, as a plain ts.trees() iteration.
        ts = msprime.simulate(4, recombination_rate=1, random_seed=123)
        assert ts.num_trees > 1
        total_iterations = 0
        for tree, (_, t1, t2) in zip(ts.trees(), ts.coiterate(ts)):
            total_iterations += 1
            assert tree == t1 == t2
        assert ts.num_trees == total_iterations

    def test_intervals(self):
        # The endpoints of the yielded intervals should be the union of the
        # breakpoints of the two tree sequences, and each yielded tree must
        # belong to the tree sequence it was taken from.
        ts1 = msprime.simulate(4, recombination_rate=1, random_seed=1)
        assert ts1.num_trees > 1
        one_tree_ts = msprime.simulate(5, random_seed=2)
        multi_tree_ts = msprime.simulate(5, recombination_rate=1, random_seed=2)
        assert multi_tree_ts.num_trees > 1
        for ts2 in (one_tree_ts, multi_tree_ts):
            bp1 = set(ts1.breakpoints())
            bp2 = set(ts2.breakpoints())
            # The comparison is only interesting if the breakpoints differ.
            assert bp1 != bp2
            breaks = set()
            for interval, t1, t2 in ts1.coiterate(ts2):
                # Both ends of the shared interval must come from one of the
                # two current trees.
                assert set(interval) <= set(t1.interval) | set(t2.interval)
                breaks.add(interval.left)
                breaks.add(interval.right)
                assert t1.tree_sequence == ts1
                assert t2.tree_sequence == ts2
            assert breaks == bp1 | bp2

    def test_simple_ts(self):
        # Hand-built example: ts1 has a breakpoint at 0.2 and ts2 at 0.8, so
        # coiteration must yield the intervals [0, 0.2), [0.2, 0.8), [0.8, 1).
        nodes = """\
        id      is_sample   time
        0       1           0
        1       1           0
        2       1           0
        3       0           1
        4       0           2
        """
        edges1 = """\
        left    right   parent  child
        0       0.2     3       0,1
        0       0.2     4       2,3
        0.2     1       3       2,1
        0.2     1       4       0,3
        """
        edges2 = """\
        left    right   parent  child
        0       0.8     3       2,1
        0       0.8     4       0,3
        0.8     1       3       0,1
        0.8     1       4       2,3
        """
        ts1 = tskit.load_text(io.StringIO(nodes), io.StringIO(edges1), strict=False)
        ts2 = tskit.load_text(io.StringIO(nodes), io.StringIO(edges2), strict=False)
        coiterator = ts1.coiterate(ts2)
        interval, tree1, tree2 = next(coiterator)
        assert interval.left == 0
        assert interval.right == 0.2
        assert tree1 == ts1.at_index(0)
        assert tree2 == ts2.at_index(0)
        interval, tree1, tree2 = next(coiterator)
        assert interval.left == 0.2
        assert interval.right == 0.8
        assert tree1 == ts1.at_index(1)
        assert tree2 == ts2.at_index(0)
        interval, tree1, tree2 = next(coiterator)
        assert interval.left == 0.8
        assert interval.right == 1
        assert tree1 == ts1.at_index(1)
        assert tree2 == ts2.at_index(1)
        # The iterator must stop once the end of the sequence is reached
        # rather than yielding a further (bogus) interval.
        assert next(coiterator, None) is None

    def test_nonequal_lengths(self):
        # Tree sequences of different sequence length cannot be coiterated;
        # the error surfaces when the first item is requested.
        ts1 = msprime.simulate(4, random_seed=1, length=2)
        ts2 = msprime.simulate(4, random_seed=1)
        with pytest.raises(ValueError, match="equal sequence length"):
            next(ts1.coiterate(ts2))

    def test_kwargs(self):
        # Extra keyword arguments must be forwarded to trees() for BOTH tree
        # sequences, so both yielded trees track the requested samples.
        ts = msprime.simulate(4, recombination_rate=1, random_seed=123)
        for _, t1, t2 in ts.coiterate(ts):
            assert t1.num_tracked_samples() == t2.num_tracked_samples() == 0
        for _, t1, t2 in ts.coiterate(ts, tracked_samples=ts.samples()):
            assert t1.num_tracked_samples() == t2.num_tracked_samples() == 4
4849+
47614850
class SimplifyTestBase:
47624851
"""
47634852
Base class for simplify tests.
@@ -5695,9 +5784,7 @@ def verify_keep_input_roots(self, ts, samples):
56955784
new_to_input_map = {
56965785
value: key for key, value in enumerate(node_map) if value != tskit.NULL
56975786
}
5698-
for (left, right), input_tree, tree_with_roots in tsutil.coiterate(
5699-
ts, ts_with_roots
5700-
):
5787+
for (left, right), input_tree, tree_with_roots in ts.coiterate(ts_with_roots):
57015788
input_roots = input_tree.roots
57025789
assert len(tree_with_roots.roots) > 0
57035790
for root in tree_with_roots.roots:

python/tests/tsutil.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1282,27 +1282,3 @@ def genealogical_nearest_neighbours(ts, focal, reference_sets):
12821282
L[L == 0] = 1
12831283
A /= L.reshape((len(focal), 1))
12841284
return A
1285-
1286-
1287-
def coiterate(ts1, ts2, **kwargs):
    """
    Walk two tree sequences in lockstep, yielding one item for every
    distinct genomic interval over which both current trees stay unchanged.

    Each item is a tuple ``((left, right), tree1, tree2)``. Any keyword
    arguments are passed on to the ``trees()`` method of both tree sequences.
    Raises a ValueError if the two sequence lengths differ.
    """
    if ts1.sequence_length != ts2.sequence_length:
        raise ValueError("Tree sequences must be equal length.")
    end_of_sequence = ts1.sequence_length
    iter1, iter2 = ts1.trees(**kwargs), ts2.trees(**kwargs)
    current1, current2 = next(iter1), next(iter2)
    position = 0
    while position != end_of_sequence:
        start = position
        # The shared interval ends wherever the first of the two trees ends.
        position = min(current1.interval[1], current2.interval[1])
        yield (start, position), current1, current2
        # Move on whichever tree (possibly both) finishes at this position.
        if current1.interval[1] == position:
            current1 = next(iter1, None)
        if current2.interval[1] == position:
            current2 = next(iter2, None)

python/tskit/trees.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,19 @@
6262

6363

6464
class Interval(BaseInterval):
65+
"""
66+
A tuple of 2 numbers, ``[left, right)``, defining an interval over the genome.
67+
68+
:ivar left: The left hand end of the interval. By convention this value is included
69+
in the interval.
70+
:vartype left: float
71+
:ivar right: The right hand end of the interval. By convention this value is *not*
72+
included in the interval, i.e. the interval is half-open.
73+
:vartype right: float
74+
:ivar span: The span of the genome covered by this interval, simply ``right-left``.
75+
:vartype span: float
76+
"""
77+
6578
@property
6679
def span(self):
6780
return self.right - self.left
@@ -3933,6 +3946,43 @@ def trees(
39333946
)
39343947
return TreeIterator(tree)
39353948

3949+
def coiterate(self, other, **kwargs):
    """
    Iterate over this tree sequence and ``other`` simultaneously, yielding
    the pair of current trees for each distinct interval along the genome.

    :param TreeSequence other: The tree sequence to iterate over alongside
        this one. Its sequence length must equal that of this tree sequence.
    :param \\**kwargs: Additional named arguments, passed unchanged to the
        :meth:`.trees` method of both tree sequences when creating the
        returned trees.

    :return: An iterator over tuples of the form
        ``(interval, tree_self, tree_other)``. The first item is a tuple of
        the initial interval, the first tree of this tree sequence and the
        first tree of ``other``; the ``.left`` attribute of that interval
        is 0 and its ``.right`` attribute is the smallest non-zero
        breakpoint of the 2 tree sequences.
    :rtype: iter(:class:`Interval`, :class:`Tree`, :class:`Tree`)
    :raises ValueError: If the sequence lengths of the two tree sequences
        differ (raised when the first item is requested).
    """
    if self.sequence_length != other.sequence_length:
        raise ValueError("Tree sequences must be of equal sequence length.")
    sequence_end = self.sequence_length
    self_trees, other_trees = self.trees(**kwargs), other.trees(**kwargs)
    tree_self, tree_other = next(self_trees), next(other_trees)
    boundary = 0
    while boundary != sequence_end:
        previous_boundary = boundary
        # The shared interval ends at the nearer right edge of the two trees.
        boundary = min(tree_self.interval[1], tree_other.interval[1])
        yield Interval(previous_boundary, boundary), tree_self, tree_other
        # Step forward whichever tree (possibly both) ends at this boundary.
        if tree_self.interval[1] == boundary:
            tree_self = next(self_trees, None)
        if tree_other.interval[1] == boundary:
            tree_other = next(other_trees, None)
3985+
39363986
def haplotypes(
39373987
self,
39383988
*,

0 commit comments

Comments
 (0)