Skip to content

Commit a84d84e

Browse files
Implement TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS
Closes #2662
1 parent 2b61dfd commit a84d84e

File tree

10 files changed

+211
-62
lines changed

10 files changed

+211
-62
lines changed

c/tests/test_trees.c

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3323,6 +3323,36 @@ test_simplest_no_node_filter(void)
33233323
tsk_treeseq_free(&ts);
33243324
}
33253325

3326+
static void
3327+
test_simplest_no_update_flags(void)
3328+
{
3329+
const char *nodes = "0 0 0\n"
3330+
"1 0 0\n"
3331+
"0 1 0\n";
3332+
const char *edges = "0 1 2 0,1\n";
3333+
tsk_treeseq_t ts, simplified;
3334+
tsk_id_t sample_ids[] = { 0, 1 };
3335+
int ret;
3336+
3337+
tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL, 0);
3338+
3339+
/* We have a mixture of sample and non-samples in the input tables */
3340+
ret = tsk_treeseq_simplify(
3341+
&ts, sample_ids, 2, TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS, &simplified, NULL);
3342+
CU_ASSERT_EQUAL_FATAL(ret, 0);
3343+
CU_ASSERT_TRUE(tsk_table_collection_equals(ts.tables, simplified.tables, 0));
3344+
tsk_treeseq_free(&simplified);
3345+
3346+
ret = tsk_treeseq_simplify(&ts, sample_ids, 2,
3347+
TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS | TSK_SIMPLIFY_NO_FILTER_NODES, &simplified,
3348+
NULL);
3349+
CU_ASSERT_EQUAL_FATAL(ret, 0);
3350+
CU_ASSERT_TRUE(tsk_table_collection_equals(ts.tables, simplified.tables, 0));
3351+
tsk_treeseq_free(&simplified);
3352+
3353+
tsk_treeseq_free(&ts);
3354+
}
3355+
33263356
static void
33273357
test_simplest_map_mutations(void)
33283358
{
@@ -8093,6 +8123,7 @@ main(int argc, char **argv)
80938123
{ "test_simplest_population_filter", test_simplest_population_filter },
80948124
{ "test_simplest_individual_filter", test_simplest_individual_filter },
80958125
{ "test_simplest_no_node_filter", test_simplest_no_node_filter },
8126+
{ "test_simplest_no_update_flags", test_simplest_no_update_flags },
80968127
{ "test_simplest_map_mutations", test_simplest_map_mutations },
80978128
{ "test_simplest_nonbinary_map_mutations",
80988129
test_simplest_nonbinary_map_mutations },

c/tskit/tables.c

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8808,16 +8808,18 @@ static tsk_id_t TSK_WARN_UNUSED
88088808
simplifier_record_node(simplifier_t *self, tsk_id_t input_id)
88098809
{
88108810
tsk_node_t node;
8811-
tsk_flags_t flags;
8811+
bool update_flags = !(self->options & TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS);
88128812

88138813
tsk_node_table_get_row_unsafe(&self->input_tables.nodes, (tsk_id_t) input_id, &node);
8814-
/* Zero out the sample bit */
8815-
flags = node.flags & (tsk_flags_t) ~TSK_NODE_IS_SAMPLE;
8816-
if (self->is_sample[input_id]) {
8817-
flags |= TSK_NODE_IS_SAMPLE;
8814+
if (update_flags) {
8815+
/* Zero out the sample bit */
8816+
node.flags &= (tsk_flags_t) ~TSK_NODE_IS_SAMPLE;
8817+
if (self->is_sample[input_id]) {
8818+
node.flags |= TSK_NODE_IS_SAMPLE;
8819+
}
88188820
}
88198821
self->node_id_map[input_id] = (tsk_id_t) self->tables->nodes.num_rows;
8820-
return tsk_node_table_add_row(&self->tables->nodes, flags, node.time,
8822+
return tsk_node_table_add_row(&self->tables->nodes, node.flags, node.time,
88218823
node.population, node.individual, node.metadata, node.metadata_length);
88228824
}
88238825

@@ -9108,6 +9110,7 @@ simplifier_init_nodes(simplifier_t *self, const tsk_id_t *samples)
91089110
tsk_size_t j;
91099111
const tsk_size_t num_nodes = self->input_tables.nodes.num_rows;
91109112
bool filter_nodes = !(self->options & TSK_SIMPLIFY_NO_FILTER_NODES);
9113+
bool update_flags = !(self->options & TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS);
91119114
tsk_flags_t *node_flags = self->tables->nodes.flags;
91129115
tsk_id_t *node_id_map = self->node_id_map;
91139116

@@ -9123,13 +9126,17 @@ simplifier_init_nodes(simplifier_t *self, const tsk_id_t *samples)
91239126
}
91249127
} else {
91259128
tsk_bug_assert(self->tables->nodes.num_rows == num_nodes);
9126-
/* The node table has not been changed */
9127-
for (j = 0; j < num_nodes; j++) {
9128-
/* Reset the sample flags */
9129-
node_flags[j] &= (tsk_flags_t) ~TSK_NODE_IS_SAMPLE;
9130-
if (self->is_sample[j]) {
9131-
node_flags[j] |= TSK_NODE_IS_SAMPLE;
9129+
if (update_flags) {
9130+
for (j = 0; j < num_nodes; j++) {
9131+
/* Reset the sample flags */
9132+
node_flags[j] &= (tsk_flags_t) ~TSK_NODE_IS_SAMPLE;
9133+
if (self->is_sample[j]) {
9134+
node_flags[j] |= TSK_NODE_IS_SAMPLE;
9135+
}
91329136
}
9137+
}
9138+
9139+
for (j = 0; j < num_nodes; j++) {
91339140
node_id_map[j] = (tsk_id_t) j;
91349141
}
91359142
}

c/tskit/tables.h

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,10 @@ first.
694694
*/
695695
#define TSK_SIMPLIFY_NO_FILTER_NODES (1 << 7)
696696
/**
697+
Do not update the sample status of nodes as a result of simplification.
698+
*/
699+
#define TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS (1 << 8)
700+
/**
697701
Reduce the topological information in the tables to the minimum necessary to
698702
represent the trees that contain sites. If there are zero sites this will
699703
result in zero output edges. When the number of sites is greater than zero,
@@ -3919,9 +3923,10 @@ or :c:macro:`TSK_NULL` if the node has been removed. Thus, ``node_map`` must be
39193923
of at least ``self->nodes.num_rows`` :c:type:`tsk_id_t` values.
39203924
39213925
If the `TSK_SIMPLIFY_NO_FILTER_NODES` option is specified, the node table will be
3922-
unaltered except for changing the sample status of nodes that were samples in the
3923-
input tables, but not in the specified list of sample IDs (if provided). The
3924-
``node_map`` (if specified) will always be the identity mapping, such that
3926+
unaltered except for changing the sample status of nodes (but see the
3927+
`TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS` option below) and updating references
3928+
to other tables that may have changed as a result of filtering (see below).
3929+
The ``node_map`` (if specified) will always be the identity mapping, such that
39253930
``node_map[u] == u`` for all nodes. Note also that the order of the list of
39263931
samples is not important in this case.
39273932
@@ -3941,6 +3946,17 @@ sample status flag of nodes.
39413946
may be entirely unreferenced entities in the input tables, which
39423947
are not affected by whether we filter nodes or not.
39433948
3949+
By default, the node sample flags are updated by unsetting the
3950+
:c:macro:`TSK_NODE_IS_SAMPLE` flag for all nodes and subsequently
3951+
setting it for the nodes provided as input to this function.
3952+
The `TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS` option will prevent this
3953+
from occurring, making it the responsibility of calling code to
3954+
keep track of the ultimate sample status of nodes. Using
3955+
this option in conjunction with `TSK_SIMPLIFY_NO_FILTER_NODES`
3956+
(and without the `TSK_SIMPLIFY_FILTER_POPULATIONS` and
3957+
`TSK_SIMPLIFY_FILTER_INDIVIDUALS` options) guarantees that the node table
3958+
will not be written to during the lifetime of this function.
3959+
39443960
The table collection will always be unindexed after simplify successfully completes.
39453961
39463962
.. note:: Migrations are currently not supported by simplify, and an error will
@@ -3956,6 +3972,7 @@ Options can be specified by providing one or more of the following bitwise
39563972
- :c:macro:`TSK_SIMPLIFY_FILTER_POPULATIONS`
39573973
- :c:macro:`TSK_SIMPLIFY_FILTER_INDIVIDUALS`
39583974
- :c:macro:`TSK_SIMPLIFY_NO_FILTER_NODES`
3975+
- :c:macro:`TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS`
39593976
- :c:macro:`TSK_SIMPLIFY_REDUCE_TO_SITE_TOPOLOGY`
39603977
- :c:macro:`TSK_SIMPLIFY_KEEP_UNARY`
39613978
- :c:macro:`TSK_SIMPLIFY_KEEP_INPUT_ROOTS`

python/_tskitmodule.c

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6589,21 +6589,23 @@ TableCollection_simplify(TableCollection *self, PyObject *args, PyObject *kwds)
65896589
int filter_individuals = false;
65906590
int filter_populations = false;
65916591
int filter_nodes = true;
6592+
int update_sample_flags = true;
65926593
int keep_unary = false;
65936594
int keep_unary_in_individuals = false;
65946595
int keep_input_roots = false;
65956596
int reduce_to_site_topology = false;
6596-
static char *kwlist[] = { "samples", "filter_sites", "filter_populations",
6597-
"filter_individuals", "filter_nodes", "reduce_to_site_topology", "keep_unary",
6598-
"keep_unary_in_individuals", "keep_input_roots", NULL };
6597+
static char *kwlist[]
6598+
= { "samples", "filter_sites", "filter_populations", "filter_individuals",
6599+
"filter_nodes", "update_sample_flags", "reduce_to_site_topology",
6600+
"keep_unary", "keep_unary_in_individuals", "keep_input_roots", NULL };
65996601

66006602
if (TableCollection_check_state(self) != 0) {
66016603
goto out;
66026604
}
6603-
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|iiiiiiii", kwlist, &samples,
6605+
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|iiiiiiiii", kwlist, &samples,
66046606
&filter_sites, &filter_populations, &filter_individuals, &filter_nodes,
6605-
&reduce_to_site_topology, &keep_unary, &keep_unary_in_individuals,
6606-
&keep_input_roots)) {
6607+
&update_sample_flags, &reduce_to_site_topology, &keep_unary,
6608+
&keep_unary_in_individuals, &keep_input_roots)) {
66076609
goto out;
66086610
}
66096611
samples_array = (PyArrayObject *) PyArray_FROMANY(
@@ -6625,6 +6627,9 @@ TableCollection_simplify(TableCollection *self, PyObject *args, PyObject *kwds)
66256627
if (!filter_nodes) {
66266628
options |= TSK_SIMPLIFY_NO_FILTER_NODES;
66276629
}
6630+
if (!update_sample_flags) {
6631+
options |= TSK_SIMPLIFY_NO_UPDATE_SAMPLE_FLAGS;
6632+
}
66286633
if (reduce_to_site_topology) {
66296634
options |= TSK_SIMPLIFY_REDUCE_TO_SITE_TOPOLOGY;
66306635
}

python/tests/simplify.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ def __init__(
111111
keep_unary=False,
112112
keep_unary_in_individuals=False,
113113
keep_input_roots=False,
114-
filter_nodes=True, # If this is False, the order in `sample` is ignored
114+
filter_nodes=True,
115+
update_sample_flags=True,
115116
):
116117
self.ts = ts
117118
self.n = len(sample)
@@ -121,6 +122,7 @@ def __init__(
121122
self.filter_populations = filter_populations
122123
self.filter_individuals = filter_individuals
123124
self.filter_nodes = filter_nodes
125+
self.update_sample_flags = update_sample_flags
124126
self.keep_unary = keep_unary
125127
self.keep_unary_in_individuals = keep_unary_in_individuals
126128
self.keep_input_roots = keep_input_roots
@@ -152,14 +154,14 @@ def __init__(
152154
# NOTE In the C implementation we would really just not touch the
153155
# original tables.
154156
self.tables.nodes.replace_with(self.ts.tables.nodes)
155-
# TODO make this optional somehow
156-
flags = self.tables.nodes.flags
157-
# Zero out other sample flags
158-
flags = np.bitwise_and(flags, ~tskit.NODE_IS_SAMPLE)
159-
flags[sample] |= tskit.NODE_IS_SAMPLE
160-
self.tables.nodes.flags = flags.astype(np.uint32)
161-
self.node_id_map[:] = np.arange(ts.num_nodes)
157+
if self.update_sample_flags:
158+
flags = self.tables.nodes.flags
159+
# Zero out other sample flags
160+
flags = np.bitwise_and(flags, ~tskit.NODE_IS_SAMPLE)
161+
flags[sample] |= tskit.NODE_IS_SAMPLE
162+
self.tables.nodes.flags = flags.astype(np.uint32)
162163

164+
self.node_id_map[:] = np.arange(ts.num_nodes)
163165
for sample_id in sample:
164166
self.add_ancestry(sample_id, 0, self.sequence_length, sample_id)
165167
else:
@@ -178,10 +180,11 @@ def record_node(self, input_id):
178180
"""
179181
node = self.ts.node(input_id)
180182
flags = node.flags
181-
# Need to zero out the sample flag
182-
flags &= ~tskit.NODE_IS_SAMPLE
183-
if self.is_sample[input_id]:
184-
flags |= tskit.NODE_IS_SAMPLE
183+
if self.update_sample_flags:
184+
# Need to zero out the sample flag
185+
flags &= ~tskit.NODE_IS_SAMPLE
186+
if self.is_sample[input_id]:
187+
flags |= tskit.NODE_IS_SAMPLE
185188
output_id = self.tables.nodes.append(node.replace(flags=flags))
186189
self.node_id_map[input_id] = output_id
187190
return output_id

python/tests/test_highlevel.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -254,19 +254,19 @@ def get_internal_samples_examples():
254254
# Set all nodes to be samples.
255255
flags[:] = tskit.NODE_IS_SAMPLE
256256
nodes.flags = flags
257-
ret.append(("all nodes samples", tables.tree_sequence()))
257+
ret.append(("all_nodes_samples", tables.tree_sequence()))
258258

259259
# Set just internal nodes to be samples.
260260
flags[:] = 0
261261
flags[n:] = tskit.NODE_IS_SAMPLE
262262
nodes.flags = flags
263-
ret.append(("internal nodes samples", tables.tree_sequence()))
263+
ret.append(("internal_nodes_samples", tables.tree_sequence()))
264264

265265
# Set a mixture of internal and leaf samples.
266266
flags[:] = 0
267267
flags[n // 2 : n + n // 2] = tskit.NODE_IS_SAMPLE
268268
nodes.flags = flags
269-
ret.append(("mixture of internal and leaf samples", tables.tree_sequence()))
269+
ret.append(("mixed_internal_leaf_samples", tables.tree_sequence()))
270270
return ret
271271

272272

@@ -281,7 +281,7 @@ def get_decapitated_examples():
281281

282282
ts = msprime.simulate(20, recombination_rate=1, random_seed=1234)
283283
assert ts.num_trees > 2
284-
ret.append(("decapitate recomb", ts.decapitate(ts.tables.nodes.time[-1] / 4)))
284+
ret.append(("decapitate_recomb", ts.decapitate(ts.tables.nodes.time[-1] / 4)))
285285
return ret
286286

287287

@@ -302,7 +302,7 @@ def get_bottleneck_examples():
302302
demographic_events=bottlenecks,
303303
random_seed=n,
304304
)
305-
yield (f"bottleneck n={n}", ts)
305+
yield (f"bottleneck_n={n}", ts)
306306

307307

308308
def get_back_mutation_examples():
@@ -337,13 +337,13 @@ def make_example_tree_sequences():
337337
)
338338
ts = tsutil.insert_random_ploidy_individuals(ts, 4, seed=seed)
339339
yield (
340-
f"n={n} m={m} rho={rho}",
340+
f"n={n}_m={m}_rho={rho}",
341341
tsutil.add_random_metadata(ts, seed=seed),
342342
)
343343
seed += 1
344344
for name, ts in get_bottleneck_examples():
345345
yield (
346-
f"{name} mutated",
346+
f"{name}_mutated",
347347
msprime.mutate(
348348
ts,
349349
rate=0.1,
@@ -352,7 +352,7 @@ def make_example_tree_sequences():
352352
),
353353
)
354354
ts = tskit.Tree.generate_balanced(8).tree_sequence
355-
yield ("rev node order", ts.subset(np.arange(ts.num_nodes - 1, -1, -1)))
355+
yield ("rev_node_order", ts.subset(np.arange(ts.num_nodes - 1, -1, -1)))
356356
ts = msprime.sim_ancestry(
357357
8, sequence_length=40, recombination_rate=0.1, random_seed=seed
358358
)
@@ -361,20 +361,20 @@ def make_example_tree_sequences():
361361
ts = tables.tree_sequence()
362362
assert ts.num_trees > 1
363363
yield (
364-
"back mutations",
364+
"back_mutations",
365365
tsutil.insert_branch_mutations(ts, mutations_per_branch=2),
366366
)
367367
ts = tsutil.insert_multichar_mutations(ts)
368368
yield ("multichar", ts)
369-
yield ("multichar w/ metadata", tsutil.add_random_metadata(ts))
369+
yield ("multichar_no_metadata", tsutil.add_random_metadata(ts))
370370
tables = ts.dump_tables()
371371
tables.nodes.flags = np.zeros_like(tables.nodes.flags)
372-
yield ("no samples", tables.tree_sequence()) # no samples
372+
yield ("no_samples", tables.tree_sequence()) # no samples
373373
tables = ts.dump_tables()
374374
tables.edges.clear()
375-
yield ("empty tree", tables.tree_sequence()) # empty tree
375+
yield ("empty_tree", tables.tree_sequence()) # empty tree
376376
yield (
377-
"empty ts",
377+
"empty_ts",
378378
tskit.TableCollection(sequence_length=1).tree_sequence(),
379379
) # empty tree seq
380380
yield ("all_fields", tsutil.all_fields_ts())
@@ -384,6 +384,8 @@ def make_example_tree_sequences():
384384

385385

386386
def get_example_tree_sequences(pytest_params=True):
387+
# NOTE: pytest names should not contain spaces and should be shell safe so
388+
# that they can be easily specified on the command line.
387389
if pytest_params:
388390
return [pytest.param(ts, id=name) for name, ts in _examples]
389391
else:
@@ -2785,6 +2787,19 @@ def test_simplify_migrations_fails(self):
27852787
with pytest.raises(_tskit.LibraryError):
27862788
ts.simplify()
27872789

2790+
@pytest.mark.parametrize("ts", get_example_tree_sequences())
2791+
def test_no_update_sample_flags_no_filter_nodes(self, ts):
2792+
# Can't simplify edges with metadata
2793+
if ts.tables.edges.metadata_schema == tskit.MetadataSchema(schema=None):
2794+
k = min(ts.num_samples, 3)
2795+
subset = ts.samples()[:k]
2796+
ts1 = ts.simplify(subset)
2797+
ts2 = ts.simplify(subset, update_sample_flags=False, filter_nodes=False)
2798+
assert ts1.num_samples == len(subset)
2799+
assert ts2.num_samples == ts.num_samples
2800+
assert ts1.num_edges == ts2.num_edges
2801+
assert ts2.tables.nodes == ts.tables.nodes
2802+
27882803

27892804
class TestMinMaxTime:
27902805
def get_example_tree_sequence(self, use_unknown_time):

python/tests/test_lowlevel.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,8 @@ def test_simplify_bad_args(self):
349349
tc.simplify([0, 1], filter_populations="x")
350350
with pytest.raises(TypeError):
351351
tc.simplify([0, 1], filter_nodes="x")
352+
with pytest.raises(TypeError):
353+
tc.simplify([0, 1], update_sample_flags="x")
352354
with pytest.raises(_tskit.LibraryError):
353355
tc.simplify([0, -1])
354356

@@ -360,6 +362,7 @@ def test_simplify_bad_args(self):
360362
"filter_populations",
361363
"filter_individuals",
362364
"filter_nodes",
365+
"update_sample_flags",
363366
"reduce_to_site_topology",
364367
"keep_unary",
365368
"keep_unary_in_individuals",

0 commit comments

Comments
 (0)