Skip to content

Commit 810988d

Browse files
Implement TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS
Closes tskit-dev#2662
1 parent 2b61dfd commit 810988d

File tree

9 files changed

+167
-47
lines changed

9 files changed

+167
-47
lines changed

c/tests/test_trees.c

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3323,6 +3323,36 @@ test_simplest_no_node_filter(void)
33233323
tsk_treeseq_free(&ts);
33243324
}
33253325

3326+
static void
3327+
test_simplest_no_update_flags(void)
3328+
{
3329+
const char *nodes = "0 0 0\n"
3330+
"1 0 0\n"
3331+
"0 1 0\n";
3332+
const char *edges = "0 1 2 0,1\n";
3333+
tsk_treeseq_t ts, simplified;
3334+
tsk_id_t sample_ids[] = { 0, 1 };
3335+
int ret;
3336+
3337+
tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL, 0);
3338+
3339+
/* We have a mixture of sample and non-samples in the input tables */
3340+
ret = tsk_treeseq_simplify(
3341+
&ts, sample_ids, 2, TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS, &simplified, NULL);
3342+
CU_ASSERT_EQUAL_FATAL(ret, 0);
3343+
CU_ASSERT_TRUE(tsk_table_collection_equals(ts.tables, simplified.tables, 0));
3344+
tsk_treeseq_free(&simplified);
3345+
3346+
ret = tsk_treeseq_simplify(&ts, sample_ids, 2,
3347+
TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS | TSK_SIMPLIFY_NO_FILTER_NODES, &simplified,
3348+
NULL);
3349+
CU_ASSERT_EQUAL_FATAL(ret, 0);
3350+
CU_ASSERT_TRUE(tsk_table_collection_equals(ts.tables, simplified.tables, 0));
3351+
tsk_treeseq_free(&simplified);
3352+
3353+
tsk_treeseq_free(&ts);
3354+
}
3355+
33263356
static void
33273357
test_simplest_map_mutations(void)
33283358
{
@@ -8093,6 +8123,7 @@ main(int argc, char **argv)
80938123
{ "test_simplest_population_filter", test_simplest_population_filter },
80948124
{ "test_simplest_individual_filter", test_simplest_individual_filter },
80958125
{ "test_simplest_no_node_filter", test_simplest_no_node_filter },
8126+
{ "test_simplest_no_update_flags", test_simplest_no_update_flags },
80968127
{ "test_simplest_map_mutations", test_simplest_map_mutations },
80978128
{ "test_simplest_nonbinary_map_mutations",
80988129
test_simplest_nonbinary_map_mutations },

c/tskit/tables.c

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8808,16 +8808,18 @@ static tsk_id_t TSK_WARN_UNUSED
88088808
simplifier_record_node(simplifier_t *self, tsk_id_t input_id)
88098809
{
88108810
tsk_node_t node;
8811-
tsk_flags_t flags;
8811+
bool update_flags = !(self->options & TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS);
88128812

88138813
tsk_node_table_get_row_unsafe(&self->input_tables.nodes, (tsk_id_t) input_id, &node);
8814-
/* Zero out the sample bit */
8815-
flags = node.flags & (tsk_flags_t) ~TSK_NODE_IS_SAMPLE;
8816-
if (self->is_sample[input_id]) {
8817-
flags |= TSK_NODE_IS_SAMPLE;
8814+
if (update_flags) {
8815+
/* Zero out the sample bit */
8816+
node.flags &= (tsk_flags_t) ~TSK_NODE_IS_SAMPLE;
8817+
if (self->is_sample[input_id]) {
8818+
node.flags |= TSK_NODE_IS_SAMPLE;
8819+
}
88188820
}
88198821
self->node_id_map[input_id] = (tsk_id_t) self->tables->nodes.num_rows;
8820-
return tsk_node_table_add_row(&self->tables->nodes, flags, node.time,
8822+
return tsk_node_table_add_row(&self->tables->nodes, node.flags, node.time,
88218823
node.population, node.individual, node.metadata, node.metadata_length);
88228824
}
88238825

@@ -9108,6 +9110,7 @@ simplifier_init_nodes(simplifier_t *self, const tsk_id_t *samples)
91089110
tsk_size_t j;
91099111
const tsk_size_t num_nodes = self->input_tables.nodes.num_rows;
91109112
bool filter_nodes = !(self->options & TSK_SIMPLIFY_NO_FILTER_NODES);
9113+
bool update_flags = !(self->options & TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS);
91119114
tsk_flags_t *node_flags = self->tables->nodes.flags;
91129115
tsk_id_t *node_id_map = self->node_id_map;
91139116

@@ -9123,13 +9126,17 @@ simplifier_init_nodes(simplifier_t *self, const tsk_id_t *samples)
91239126
}
91249127
} else {
91259128
tsk_bug_assert(self->tables->nodes.num_rows == num_nodes);
9126-
/* The node table has not been changed */
9127-
for (j = 0; j < num_nodes; j++) {
9128-
/* Reset the sample flags */
9129-
node_flags[j] &= (tsk_flags_t) ~TSK_NODE_IS_SAMPLE;
9130-
if (self->is_sample[j]) {
9131-
node_flags[j] |= TSK_NODE_IS_SAMPLE;
9129+
if (update_flags) {
9130+
for (j = 0; j < num_nodes; j++) {
9131+
/* Reset the sample flags */
9132+
node_flags[j] &= (tsk_flags_t) ~TSK_NODE_IS_SAMPLE;
9133+
if (self->is_sample[j]) {
9134+
node_flags[j] |= TSK_NODE_IS_SAMPLE;
9135+
}
91329136
}
9137+
}
9138+
9139+
for (j = 0; j < num_nodes; j++) {
91339140
node_id_map[j] = (tsk_id_t) j;
91349141
}
91359142
}

c/tskit/tables.h

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,10 @@ first.
694694
*/
695695
#define TSK_SIMPLIFY_NO_FILTER_NODES (1 << 7)
696696
/**
697+
Do not update the sample status of nodes as a result of simplification.
698+
*/
699+
#define TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS (1 << 8)
700+
/**
697701
Reduce the topological information in the tables to the minimum necessary to
698702
represent the trees that contain sites. If there are zero sites this will
699703
result in an zero output edges. When the number of sites is greater than zero,
@@ -3919,9 +3923,10 @@ or :c:macro:`TSK_NULL` if the node has been removed. Thus, ``node_map`` must be
39193923
of at least ``self->nodes.num_rows`` :c:type:`tsk_id_t` values.
39203924
39213925
If the `TSK_SIMPLIFY_NO_FILTER_NODES` option is specified, the node table will be
3922-
unaltered except for changing the sample status of nodes that were samples in the
3923-
input tables, but not in the specified list of sample IDs (if provided). The
3924-
``node_map`` (if specified) will always be the identity mapping, such that
3926+
unaltered except for changing the sample status of nodes (but see the
3927+
`TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS` option below) and to update references
3928+
to other tables that may have changed as a result of filtering (see below).
3929+
The ``node_map`` (if specified) will always be the identity mapping, such that
39253930
``node_map[u] == u`` for all nodes. Note also that the order of the list of
39263931
samples is not important in this case.
39273932
@@ -3941,6 +3946,17 @@ sample status flag of nodes.
39413946
may be entirely unreferenced entities in the input tables, which
39423947
are not affected by whether we filter nodes or not.
39433948
3949+
By default, the node sample flags are updated by unsetting
3950+
:c:macro:`TSK_NODE_IS_SAMPLE` flag for all nodes and subsequently
3951+
setting it for the nodes provided as input to this function.
3952+
The `TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS` option will prevent this
3953+
from occuring, making it the responsibility of calling code to
3954+
keep track of the ultimate sample status of nodes. Using
3955+
this option in conjunction with `TSK_SIMPLIFY_NO_FILTER_NODES`
3956+
(and without the `TSK_SIMPLIFY_FILTER_POPULATIONS` and
3957+
`TSK_SIMPLIFY_FILTER_INDIVIDUALS`) guarantees that the node table
3958+
will not be written to during the lifetime of this function.
3959+
39443960
The table collection will always be unindexed after simplify successfully completes.
39453961
39463962
.. note:: Migrations are currently not supported by simplify, and an error will
@@ -3956,6 +3972,7 @@ Options can be specified by providing one or more of the following bitwise
39563972
- :c:macro:`TSK_SIMPLIFY_FILTER_POPULATIONS`
39573973
- :c:macro:`TSK_SIMPLIFY_FILTER_INDIVIDUALS`
39583974
- :c:macro:`TSK_SIMPLIFY_NO_FILTER_NODES`
3975+
- :c:macro:`TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS`
39593976
- :c:macro:`TSK_SIMPLIFY_REDUCE_TO_SITE_TOPOLOGY`
39603977
- :c:macro:`TSK_SIMPLIFY_KEEP_UNARY`
39613978
- :c:macro:`TSK_SIMPLIFY_KEEP_INPUT_ROOTS`

python/_tskitmodule.c

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6589,21 +6589,23 @@ TableCollection_simplify(TableCollection *self, PyObject *args, PyObject *kwds)
65896589
int filter_individuals = false;
65906590
int filter_populations = false;
65916591
int filter_nodes = true;
6592+
int update_sample_flags = true;
65926593
int keep_unary = false;
65936594
int keep_unary_in_individuals = false;
65946595
int keep_input_roots = false;
65956596
int reduce_to_site_topology = false;
6596-
static char *kwlist[] = { "samples", "filter_sites", "filter_populations",
6597-
"filter_individuals", "filter_nodes", "reduce_to_site_topology", "keep_unary",
6598-
"keep_unary_in_individuals", "keep_input_roots", NULL };
6597+
static char *kwlist[]
6598+
= { "samples", "filter_sites", "filter_populations", "filter_individuals",
6599+
"filter_nodes", "update_sample_flags", "reduce_to_site_topology",
6600+
"keep_unary", "keep_unary_in_individuals", "keep_input_roots", NULL };
65996601

66006602
if (TableCollection_check_state(self) != 0) {
66016603
goto out;
66026604
}
6603-
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|iiiiiiii", kwlist, &samples,
6605+
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|iiiiiiiii", kwlist, &samples,
66046606
&filter_sites, &filter_populations, &filter_individuals, &filter_nodes,
6605-
&reduce_to_site_topology, &keep_unary, &keep_unary_in_individuals,
6606-
&keep_input_roots)) {
6607+
&update_sample_flags, &reduce_to_site_topology, &keep_unary,
6608+
&keep_unary_in_individuals, &keep_input_roots)) {
66076609
goto out;
66086610
}
66096611
samples_array = (PyArrayObject *) PyArray_FROMANY(
@@ -6625,6 +6627,9 @@ TableCollection_simplify(TableCollection *self, PyObject *args, PyObject *kwds)
66256627
if (!filter_nodes) {
66266628
options |= TSK_SIMPLIFY_NO_FILTER_NODES;
66276629
}
6630+
if (!update_sample_flags) {
6631+
options |= TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS;
6632+
}
66286633
if (reduce_to_site_topology) {
66296634
options |= TSK_SIMPLIFY_REDUCE_TO_SITE_TOPOLOGY;
66306635
}

python/tests/simplify.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ def __init__(
111111
keep_unary=False,
112112
keep_unary_in_individuals=False,
113113
keep_input_roots=False,
114-
filter_nodes=True, # If this is False, the order in `sample` is ignored
114+
filter_nodes=True,
115+
update_sample_flags=True,
115116
):
116117
self.ts = ts
117118
self.n = len(sample)
@@ -121,6 +122,7 @@ def __init__(
121122
self.filter_populations = filter_populations
122123
self.filter_individuals = filter_individuals
123124
self.filter_nodes = filter_nodes
125+
self.update_sample_flags = update_sample_flags
124126
self.keep_unary = keep_unary
125127
self.keep_unary_in_individuals = keep_unary_in_individuals
126128
self.keep_input_roots = keep_input_roots
@@ -152,14 +154,14 @@ def __init__(
152154
# NOTE In the C implementation we would really just not touch the
153155
# original tables.
154156
self.tables.nodes.replace_with(self.ts.tables.nodes)
155-
# TODO make this optional somehow
156-
flags = self.tables.nodes.flags
157-
# Zero out other sample flags
158-
flags = np.bitwise_and(flags, ~tskit.NODE_IS_SAMPLE)
159-
flags[sample] |= tskit.NODE_IS_SAMPLE
160-
self.tables.nodes.flags = flags.astype(np.uint32)
161-
self.node_id_map[:] = np.arange(ts.num_nodes)
157+
if self.update_sample_flags:
158+
flags = self.tables.nodes.flags
159+
# Zero out other sample flags
160+
flags = np.bitwise_and(flags, ~tskit.NODE_IS_SAMPLE)
161+
flags[sample] |= tskit.NODE_IS_SAMPLE
162+
self.tables.nodes.flags = flags.astype(np.uint32)
162163

164+
self.node_id_map[:] = np.arange(ts.num_nodes)
163165
for sample_id in sample:
164166
self.add_ancestry(sample_id, 0, self.sequence_length, sample_id)
165167
else:
@@ -178,10 +180,11 @@ def record_node(self, input_id):
178180
"""
179181
node = self.ts.node(input_id)
180182
flags = node.flags
181-
# Need to zero out the sample flag
182-
flags &= ~tskit.NODE_IS_SAMPLE
183-
if self.is_sample[input_id]:
184-
flags |= tskit.NODE_IS_SAMPLE
183+
if self.update_sample_flags:
184+
# Need to zero out the sample flag
185+
flags &= ~tskit.NODE_IS_SAMPLE
186+
if self.is_sample[input_id]:
187+
flags |= tskit.NODE_IS_SAMPLE
185188
output_id = self.tables.nodes.append(node.replace(flags=flags))
186189
self.node_id_map[input_id] = output_id
187190
return output_id

python/tests/test_lowlevel.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,8 @@ def test_simplify_bad_args(self):
349349
tc.simplify([0, 1], filter_populations="x")
350350
with pytest.raises(TypeError):
351351
tc.simplify([0, 1], filter_nodes="x")
352+
with pytest.raises(TypeError):
353+
tc.simplify([0, 1], update_sample_flags="x")
352354
with pytest.raises(_tskit.LibraryError):
353355
tc.simplify([0, -1])
354356

@@ -360,6 +362,7 @@ def test_simplify_bad_args(self):
360362
"filter_populations",
361363
"filter_individuals",
362364
"filter_nodes",
365+
"update_sample_flags",
363366
"reduce_to_site_topology",
364367
"keep_unary",
365368
"keep_unary_in_individuals",

python/tests/test_topology.py

Lines changed: 62 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2686,7 +2686,7 @@ def verify_simplify(
26862686
filter_sites=filter_sites,
26872687
keep_input_roots=keep_input_roots,
26882688
filter_nodes=filter_nodes,
2689-
compare_lib=True, # TMP
2689+
compare_lib=True,
26902690
)
26912691
if debug:
26922692
print("before")
@@ -4757,6 +4757,7 @@ def do_simplify(
47574757
filter_nodes=True,
47584758
keep_unary=False,
47594759
keep_input_roots=False,
4760+
update_sample_flags=True,
47604761
):
47614762
"""
47624763
Runs the Python test implementation of simplify.
@@ -4772,6 +4773,7 @@ def do_simplify(
47724773
filter_nodes=filter_nodes,
47734774
keep_unary=keep_unary,
47744775
keep_input_roots=keep_input_roots,
4776+
update_sample_flags=update_sample_flags,
47754777
)
47764778
new_ts, node_map = s.simplify()
47774779
if compare_lib:
@@ -4781,28 +4783,16 @@ def do_simplify(
47814783
filter_individuals=filter_individuals,
47824784
filter_populations=filter_populations,
47834785
filter_nodes=filter_nodes,
4786+
update_sample_flags=update_sample_flags,
47844787
keep_unary=keep_unary,
47854788
keep_input_roots=keep_input_roots,
47864789
map_nodes=True,
47874790
)
47884791
lib_tables1 = sts.dump_tables()
47894792

4790-
lib_tables2 = ts.dump_tables()
4791-
lib_node_map2 = lib_tables2.simplify(
4792-
samples,
4793-
filter_sites=filter_sites,
4794-
keep_unary=keep_unary,
4795-
keep_input_roots=keep_input_roots,
4796-
filter_individuals=filter_individuals,
4797-
filter_populations=filter_populations,
4798-
filter_nodes=filter_nodes,
4799-
)
4800-
48014793
py_tables = new_ts.dump_tables()
48024794
py_tables.assert_equals(lib_tables1, ignore_provenance=True)
4803-
py_tables.assert_equals(lib_tables2, ignore_provenance=True)
48044795
assert all(node_map == lib_node_map1)
4805-
assert all(node_map == lib_node_map2)
48064796
return new_ts, node_map
48074797

48084798

@@ -6091,6 +6081,64 @@ def test_mutations_on_removed_branches(self):
60916081
assert ts2.num_mutations == 0
60926082

60936083

6084+
class TestSimplifyNoUpdateSampleFlags:
6085+
"""
6086+
Tests for simplify when we don't update the sample flags.
6087+
"""
6088+
6089+
def test_simple_case_filter_nodes(self):
6090+
# 2.00┊ 6 ┊
6091+
# ┊ ┏━┻━┓ ┊
6092+
# 1.00┊ 4 5 ┊
6093+
# ┊ ┏┻┓ ┏┻┓ ┊
6094+
# 0.00┊ 0 1 2 3 ┊
6095+
# 0 1
6096+
ts1 = tskit.Tree.generate_balanced(4).tree_sequence
6097+
ts2, node_map = do_simplify(
6098+
ts1,
6099+
[0, 1, 6],
6100+
update_sample_flags=False,
6101+
)
6102+
# Because we don't retain 2 and 3 here, they don't stay as
6103+
# samples. But, we specified 6 as a sample, so it's coming
6104+
# through where it would ordinarily be dropped.
6105+
6106+
# 2.00┊ 2 ┊
6107+
# ┊ ┃ ┊
6108+
# 1.00┊ 3 ┊
6109+
# ┊ ┏┻┓ ┊
6110+
# 0.00┊ 0 1 ┊
6111+
# 0 1
6112+
assert list(ts2.nodes_flags) == [1, 1, 0, 0]
6113+
tree = ts2.first()
6114+
assert list(tree.parent_array) == [3, 3, -1, 2, -1]
6115+
6116+
def test_simple_case_no_filter_nodes(self):
6117+
# 2.00┊ 6 ┊
6118+
# ┊ ┏━┻━┓ ┊
6119+
# 1.00┊ 4 5 ┊
6120+
# ┊ ┏┻┓ ┏┻┓ ┊
6121+
# 0.00┊ 0 1 2 3 ┊
6122+
# 0 1
6123+
ts1 = tskit.Tree.generate_balanced(4).tree_sequence
6124+
ts2, node_map = do_simplify(
6125+
ts1,
6126+
[0, 1, 6],
6127+
update_sample_flags=False,
6128+
filter_nodes=False,
6129+
)
6130+
6131+
# 2.00┊ 6 ┊
6132+
# ┊ ┃ ┊
6133+
# 1.00┊ 4 ┊
6134+
# ┊ ┏┻┓ ┊
6135+
# 0.00┊ 0 1 2 3 ┊
6136+
# 0 1
6137+
assert list(ts2.nodes_flags) == list(ts1.nodes_flags)
6138+
tree = ts2.first()
6139+
assert list(tree.parent_array) == [4, 4, -1, -1, 6, -1, -1, -1]
6140+
6141+
60946142
class TestMapToAncestors:
60956143
"""
60966144
Tests the AncestorMap class.

0 commit comments

Comments
 (0)