diff --git a/c/CHANGELOG.rst b/c/CHANGELOG.rst index c1643e1ff1..c8dcb7143f 100644 --- a/c/CHANGELOG.rst +++ b/c/CHANGELOG.rst @@ -2,6 +2,11 @@ [1.1.2] - 2023-XX-XX -------------------- +**Performance improvements** + +- tsk_tree_seek is now much faster at seeking to arbitrary points along + the sequence from the null tree (:user:`molpopgen`, :pr:`2661`). + **Features** - The struct ``tsk_treeseq_t`` now has the variables ``min_time`` and ``max_time``, @@ -24,6 +29,8 @@ - Add `x_table_keep_rows` methods to provide efficient in-place table subsetting (:user:`jeromekelleher`, :pr:`2700`). +- Add `tsk_tree_seek_index` function + -------------------- [1.1.1] - 2022-07-29 -------------------- diff --git a/c/tests/test_trees.c b/c/tests/test_trees.c index f0ced8585f..94e33ee487 100644 --- a/c/tests/test_trees.c +++ b/c/tests/test_trees.c @@ -6140,10 +6140,16 @@ test_seek_multi_tree(void) ret = tsk_tree_seek(&t, breakpoints[j], 0); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_EQUAL_FATAL(t.index, j); + ret = tsk_tree_seek_index(&t, j, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t.index, j); for (k = 0; k < num_trees; k++) { ret = tsk_tree_seek(&t, breakpoints[k], 0); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_EQUAL_FATAL(t.index, k); + ret = tsk_tree_seek_index(&t, k, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(t.index, k); } } @@ -6205,6 +6211,10 @@ test_seek_errors(void) CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_SEEK_OUT_OF_BOUNDS); ret = tsk_tree_seek(&t, 11, 0); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_SEEK_OUT_OF_BOUNDS); + ret = tsk_tree_seek_index(&t, (tsk_id_t) ts.num_trees, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_SEEK_OUT_OF_BOUNDS); + ret = tsk_tree_seek_index(&t, -1, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_SEEK_OUT_OF_BOUNDS); tsk_tree_free(&t); tsk_treeseq_free(&ts); diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 06cad0e823..4604579e0b 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -4549,19 +4549,138 @@ 
tsk_tree_position_in_interval(const tsk_tree_t *self, double x) return self->interval.left <= x && x < self->interval.right; } -int TSK_WARN_UNUSED -tsk_tree_seek(tsk_tree_t *self, double x, tsk_flags_t TSK_UNUSED(options)) +/* NOTE: + * + * Notes from Kevin Thornton: + * + * This method inserts the edges for an arbitrary tree + * in linear time and requires no additional memory. + * + * During design, the following alternatives were tested + * (in a combination of rust + C): + * 1. Indexing edge insertion/removal locations by tree. + * The indexing can be done in O(n) time, giving O(1) + * access to the first edge in a tree. We can then add + * edges to the tree in O(e) time, where e is the number + * of edges. This approach requires O(n) additional memory + * and is only marginally faster than the implementation below. + * 2. Building an interval tree mapping edge id -> span. + * This approach adds a lot of complexity and wasn't any faster + * than the indexing described above. + */ +static int +tsk_tree_seek_from_null(tsk_tree_t *self, double x, tsk_flags_t TSK_UNUSED(options)) { int ret = 0; + tsk_size_t edge; + tsk_id_t p, c, e, j, k, tree_index; const double L = tsk_treeseq_get_sequence_length(self->tree_sequence); - const double t_l = self->interval.left; - const double t_r = self->interval.right; - double distance_left, distance_right; + const tsk_treeseq_t *treeseq = self->tree_sequence; + const tsk_table_collection_t *tables = treeseq->tables; + const tsk_id_t *restrict edge_parent = tables->edges.parent; + const tsk_id_t *restrict edge_child = tables->edges.child; + const tsk_size_t num_edges = tables->edges.num_rows; + const tsk_size_t num_trees = self->tree_sequence->num_trees; + const double *restrict edge_left = tables->edges.left; + const double *restrict edge_right = tables->edges.right; + const double *restrict breakpoints = treeseq->breakpoints; + const tsk_id_t *restrict insertion = tables->indexes.edge_insertion_order; + const tsk_id_t 
*restrict removal = tables->indexes.edge_removal_order; + + // NOTE: it may be better to get the + // index first and then ask if we are + // searching in the first or last 1/2 + // of trees. + j = -1; + if (x <= L / 2.0) { + for (edge = 0; edge < num_edges; edge++) { + e = insertion[edge]; + if (edge_left[e] > x) { + j = (tsk_id_t) edge; + break; + } + if (x >= edge_left[e] && x < edge_right[e]) { + p = edge_parent[e]; + c = edge_child[e]; + tsk_tree_insert_edge(self, p, c, e); + } + } + } else { + for (edge = 0; edge < num_edges; edge++) { + e = removal[num_edges - edge - 1]; + if (edge_right[e] < x) { + j = (tsk_id_t)(num_edges - edge - 1); + while (j < (tsk_id_t) num_edges && edge_left[insertion[j]] <= x) { + j++; + } + break; + } + if (x >= edge_left[e] && x < edge_right[e]) { + p = edge_parent[e]; + c = edge_child[e]; + tsk_tree_insert_edge(self, p, c, e); + } + } + } - if (x < 0 || x >= L) { + if (j == -1) { + j = 0; + while (j < (tsk_id_t) num_edges && edge_left[insertion[j]] <= x) { + j++; + } + } + k = 0; + while (k < (tsk_id_t) num_edges && edge_right[removal[k]] <= x) { + k++; + } + + /* NOTE: tsk_search_sorted finds the first + * insertion location >= the query point, which + * finds a RIGHT value for queries not at the left edge. 
+ */ + tree_index = (tsk_id_t) tsk_search_sorted(breakpoints, num_trees + 1, x); + if (breakpoints[tree_index] > x) { + tree_index--; + } + self->index = tree_index; + self->interval.left = breakpoints[tree_index]; + self->interval.right = breakpoints[tree_index + 1]; + self->left_index = j; + self->right_index = k; + self->direction = TSK_DIR_FORWARD; + self->num_nodes = tables->nodes.num_rows; + if (tables->sites.num_rows > 0) { + self->sites = treeseq->tree_sites[self->index]; + self->sites_length = treeseq->tree_sites_length[self->index]; + } + + return ret; +} + +int TSK_WARN_UNUSED +tsk_tree_seek_index(tsk_tree_t *self, tsk_id_t tree, tsk_flags_t options) +{ + int ret = 0; + double x; + + if (tree < 0 || tree >= (tsk_id_t) self->tree_sequence->num_trees) { ret = TSK_ERR_SEEK_OUT_OF_BOUNDS; goto out; } + x = self->tree_sequence->breakpoints[tree]; + ret = tsk_tree_seek(self, x, options); +out: + return ret; +} + +static int TSK_WARN_UNUSED +tsk_tree_seek_linear(tsk_tree_t *self, double x, tsk_flags_t TSK_UNUSED(options)) +{ + const double L = tsk_treeseq_get_sequence_length(self->tree_sequence); + const double t_l = self->interval.left; + const double t_r = self->interval.right; + int ret = 0; + double distance_left, distance_right; if (x < t_l) { /* |-----|-----|========|---------| */ @@ -4594,6 +4713,27 @@ tsk_tree_seek(tsk_tree_t *self, double x, tsk_flags_t TSK_UNUSED(options)) return ret; } +int TSK_WARN_UNUSED +tsk_tree_seek(tsk_tree_t *self, double x, tsk_flags_t options) +{ + int ret = 0; + const double L = tsk_treeseq_get_sequence_length(self->tree_sequence); + + if (x < 0 || x >= L) { + ret = TSK_ERR_SEEK_OUT_OF_BOUNDS; + goto out; + } + + if (self->index == -1) { + ret = tsk_tree_seek_from_null(self, x, options); + } else { + ret = tsk_tree_seek_linear(self, x, options); + } + +out: + return ret; +} + int TSK_WARN_UNUSED tsk_tree_clear(tsk_tree_t *self) { diff --git a/c/tskit/trees.h b/c/tskit/trees.h index 4a84bf3446..efe9980077 100644 --- 
a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -1192,12 +1192,6 @@ we will have ``position < tree.interval.right``. Seeking to a position currently covered by the tree is a constant time operation. - -.. warning:: - The current implementation of ``seek`` does **not** provide efficient - random access to arbitrary positions along the genome. However, - sequentially seeking in either direction is as efficient as calling - :c:func:`tsk_tree_next` or :c:func:`tsk_tree_prev` directly. @endrst @param self A pointer to an initialised tsk_tree_t object. @@ -1208,6 +1202,22 @@ a constant time operation. */ int tsk_tree_seek(tsk_tree_t *self, double position, tsk_flags_t options); +/** +@brief Seek to a specific tree in a tree sequence. + +@rst +Set the state of this tree to reflect the tree in the parent +tree sequence at index ``tree``, where ``0 <= tree < num_trees``. +@endrst + +@param self A pointer to an initialised tsk_tree_t object. +@param tree The target tree index. +@param options Seek options. Currently unused. Set to 0 for compatibility + with future versions of tskit. +@return Return 0 on success or a negative value on failure. +*/ +int tsk_tree_seek_index(tsk_tree_t *self, tsk_id_t tree, tsk_flags_t options); + /** @} */ /** diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 3a81b2aa1c..4450bf06c4 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -2,6 +2,11 @@ [0.5.5] - 2023-01-XX -------------------- +**Performance improvements** + +- Methods like ts.at() which seek to a specified position on the sequence from + a new Tree instance are now much faster (:user:`molpopgen`, :pr:`2661`). 
+ **Features** - Add ``__repr__`` for variants to return a string representation of the raw data diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 2b379ff2b1..30c3e7743b 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -10658,6 +10658,29 @@ Tree_seek(Tree *self, PyObject *args) return ret; } +static PyObject * +Tree_seek_index(Tree *self, PyObject *args) +{ + PyObject *ret = NULL; + tsk_id_t index = 0; + int err; + + if (Tree_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "O&", tsk_id_converter, &index)) { + goto out; + } + err = tsk_tree_seek_index(self->tree, index, 0); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + static PyObject * Tree_clear(Tree *self) { @@ -11796,6 +11819,10 @@ static PyMethodDef Tree_methods[] = { .ml_meth = (PyCFunction) Tree_seek, .ml_flags = METH_VARARGS, .ml_doc = "Seeks to the tree at the specified position" }, + { .ml_name = "seek_index", + .ml_meth = (PyCFunction) Tree_seek_index, + .ml_flags = METH_VARARGS, + .ml_doc = "Seeks to the tree at the specified index" }, { .ml_name = "clear", .ml_meth = (PyCFunction) Tree_clear, .ml_flags = METH_NOARGS, diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index c926121aea..ca67e24ce0 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -4598,10 +4598,13 @@ def test_index_from_different_directions(self, index): t2.prev() assert_same_tree_different_order(t1, t2) - def test_seek_0_from_null(self): + @pytest.mark.parametrize("position", [0, 1, 2, 3]) + def test_seek_from_null(self, position): t1, t2 = self.setup() - t1.first() - t2.seek(0) + t1.clear() + t1.seek(position) + t2.first() + t2.seek(position) assert_trees_identical(t1, t2) @pytest.mark.parametrize("index", range(3)) @@ -4654,6 +4657,14 @@ def test_seek_3_from_null(self): t2.seek(3) assert_trees_identical(t1, t2) + def test_seek_3_from_null_prev(self): + 
t1, t2 = self.setup() + t1.last() + t1.prev() + t2.seek(3) + t2.prev() + assert_trees_identical(t1, t2) + def test_seek_3_from_0(self): t1, t2 = self.setup() t1.last() @@ -4669,6 +4680,37 @@ def test_seek_0_from_3(self): t2.seek(0) assert_trees_identical(t1, t2) + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_seek_mid_null_and_middle(self, ts): + breakpoints = ts.breakpoints(as_array=True) + mid = breakpoints[:-1] + np.diff(breakpoints) / 2 + for index, x in enumerate(mid[:-1]): + t1 = tskit.Tree(ts) + t1.seek(x) + # Also seek to this point manually to make sure we're not + # reusing the seek from null under the hood. + t2 = tskit.Tree(ts) + if index <= ts.num_trees / 2: + while t2.index != index: + t2.next() + else: + while t2.index != index: + t2.prev() + assert t1.index == t2.index + assert np.all(t1.parent_array == t2.parent_array) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_seek_last_then_prev(self, ts): + t1 = tskit.Tree(ts) + t1.seek(ts.sequence_length - 0.00001) + assert t1.index == ts.num_trees - 1 + t2 = tskit.Tree(ts) + t2.prev() + assert_trees_identical(t1, t2) + t1.prev() + t2.prev() + assert_trees_identical(t1, t2) + class TestSeek: @pytest.mark.parametrize("ts", get_example_tree_sequences()) diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 8bbc201b85..7ebe6467eb 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -2968,6 +2968,16 @@ def test_seek_errors(self): with pytest.raises(_tskit.LibraryError): tree.seek(bad_pos) + def test_seek_index_errors(self): + ts = self.get_example_tree_sequence() + tree = _tskit.Tree(ts) + for bad_type in ["", "x", {}]: + with pytest.raises(TypeError): + tree.seek_index(bad_type) + for bad_index in [-1, 10**6]: + with pytest.raises(_tskit.LibraryError): + tree.seek_index(bad_index) + def test_root_threshold(self): for ts in self.get_example_tree_sequences(): tree = _tskit.Tree(ts) diff --git 
a/python/tskit/trees.py b/python/tskit/trees.py index 39a1e4c41a..da5a7f9d07 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -820,7 +820,6 @@ def seek_index(self, index): .. include:: substitutions/linear_traversal_warning.rst - :param int index: The tree index to seek to. :raises IndexError: If an index outside the acceptable range is provided. """ @@ -829,12 +828,7 @@ def seek_index(self, index): index += num_trees if index < 0 or index >= num_trees: raise IndexError("Index out of bounds") - # This should be implemented in C efficiently using the indexes. - # No point in complicating the current implementation by trying - # to seek from the correct direction. - self.first() - while self.index != index: - self.next() + self._ll_tree.seek_index(index) def seek(self, position): """