Skip to content

Commit 19d2b90

Browse files
More tests for seek_from_null
Also minor style tweaks
1 parent 5c78177 commit 19d2b90

File tree

2 files changed

+53
-10
lines changed

2 files changed

+53
-10
lines changed

c/tskit/trees.c

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4581,8 +4581,8 @@ tsk_tree_seek_from_null(tsk_tree_t *self, double x, tsk_flags_t TSK_UNUSED(optio
45814581
const tsk_table_collection_t *tables = treeseq->tables;
45824582
const tsk_id_t *restrict edge_parent = tables->edges.parent;
45834583
const tsk_id_t *restrict edge_child = tables->edges.child;
4584-
const tsk_size_t num_edges = tables->edges.num_rows,
4585-
num_trees = self->tree_sequence->num_trees;
4584+
const tsk_size_t num_edges = tables->edges.num_rows;
4585+
const tsk_size_t num_trees = self->tree_sequence->num_trees;
45864586
const double *restrict edge_left = tables->edges.left;
45874587
const double *restrict edge_right = tables->edges.right;
45884588
const double *restrict breakpoints = treeseq->breakpoints;
@@ -4595,7 +4595,7 @@ tsk_tree_seek_from_null(tsk_tree_t *self, double x, tsk_flags_t TSK_UNUSED(optio
45954595
// of trees.
45964596
j = -1;
45974597
if (x <= L / 2.0) {
4598-
for (edge = 0; edge < num_edges; ++edge) {
4598+
for (edge = 0; edge < num_edges; edge++) {
45994599
e = insertion[edge];
46004600
if (edge_left[e] > x) {
46014601
j = (tsk_id_t) edge;
@@ -4608,12 +4608,12 @@ tsk_tree_seek_from_null(tsk_tree_t *self, double x, tsk_flags_t TSK_UNUSED(optio
46084608
}
46094609
}
46104610
} else {
4611-
for (edge = 0; edge < num_edges; ++edge) {
4611+
for (edge = 0; edge < num_edges; edge++) {
46124612
e = removal[num_edges - edge - 1];
46134613
if (edge_right[e] < x) {
46144614
j = (tsk_id_t)(num_edges - edge - 1);
46154615
while (j < (tsk_id_t) num_edges && edge_left[insertion[j]] <= x) {
4616-
j += 1;
4616+
j++;
46174617
}
46184618
break;
46194619
}
@@ -4628,12 +4628,12 @@ tsk_tree_seek_from_null(tsk_tree_t *self, double x, tsk_flags_t TSK_UNUSED(optio
46284628
if (j == -1) {
46294629
j = 0;
46304630
while (j < (tsk_id_t) num_edges && edge_left[insertion[j]] <= x) {
4631-
j += 1;
4631+
j++;
46324632
}
46334633
}
46344634
k = 0;
46354635
while (k < (tsk_id_t) num_edges && edge_right[removal[k]] <= x) {
4636-
k += 1;
4636+
k--;
46374637
}
46384638

46394639
/* NOTE: tsk_search_sorted finds the first the first
@@ -4683,6 +4683,7 @@ tsk_tree_seek_linear(tsk_tree_t *self, double x, tsk_flags_t TSK_UNUSED(options)
46834683
const double t_r = self->interval.right;
46844684
int ret = 0;
46854685
double distance_left, distance_right;
4686+
46864687
if (x < t_l) {
46874688
/* |-----|-----|========|---------| */
46884689
/* 0 x t_l t_r L */

python/tests/test_highlevel.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4598,10 +4598,13 @@ def test_index_from_different_directions(self, index):
45984598
t2.prev()
45994599
assert_same_tree_different_order(t1, t2)
46004600

4601-
def test_seek_0_from_null(self):
4601+
@pytest.mark.parametrize("position", [0, 1, 2, 3])
4602+
def test_seek_from_null(self, position):
46024603
t1, t2 = self.setup()
4603-
t1.first()
4604-
t2.seek(0)
4604+
t1.clear()
4605+
t1.seek(position)
4606+
t2.first()
4607+
t2.seek(position)
46054608
assert_trees_identical(t1, t2)
46064609

46074610
@pytest.mark.parametrize("index", range(3))
@@ -4654,6 +4657,14 @@ def test_seek_3_from_null(self):
46544657
t2.seek(3)
46554658
assert_trees_identical(t1, t2)
46564659

4660+
def test_seek_3_from_null_prev(self):
4661+
t1, t2 = self.setup()
4662+
t1.last()
4663+
t1.prev()
4664+
t2.seek(3)
4665+
t2.prev()
4666+
assert_trees_identical(t1, t2)
4667+
46574668
def test_seek_3_from_0(self):
46584669
t1, t2 = self.setup()
46594670
t1.last()
@@ -4669,6 +4680,37 @@ def test_seek_0_from_3(self):
46694680
t2.seek(0)
46704681
assert_trees_identical(t1, t2)
46714682

4683+
@pytest.mark.parametrize("ts", get_example_tree_sequences())
4684+
def test_seek_mid_null_and_middle(self, ts):
4685+
breakpoints = ts.breakpoints(as_array=True)
4686+
mid = breakpoints[:-1] + np.diff(breakpoints) / 2
4687+
for index, x in enumerate(mid[:-1]):
4688+
t1 = tskit.Tree(ts)
4689+
t1.seek(x)
4690+
# Also seek to this point manually to make sure we're not
4691+
# reusing the seek from null under the hood.
4692+
t2 = tskit.Tree(ts)
4693+
if index <= ts.num_trees / 2:
4694+
while t2.index != index:
4695+
t2.next()
4696+
else:
4697+
while t2.index != index:
4698+
t2.prev()
4699+
assert t1.index == t2.index
4700+
assert np.all(t1.parent_array == t2.parent_array)
4701+
4702+
@pytest.mark.parametrize("ts", get_example_tree_sequences())
4703+
def test_seek_last_then_prev(self, ts):
4704+
t1 = tskit.Tree(ts)
4705+
t1.seek(ts.sequence_length - 0.00001)
4706+
assert t1.index == ts.num_trees - 1
4707+
t2 = tskit.Tree(ts)
4708+
t2.prev()
4709+
assert_trees_identical(t1, t2)
4710+
t1.prev()
4711+
t2.prev()
4712+
assert_trees_identical(t1, t2)
4713+
46724714

46734715
class TestSeek:
46744716
@pytest.mark.parametrize("ts", get_example_tree_sequences())

0 commit comments

Comments
 (0)