Skip to content

Commit 21da9b0

Browse files
authored
Merge pull request #1190 from hyanwong/keep_unary_in_individuals_python
Implement keep_unary_in_individuals in Python
2 parents a23b5b4 + e49fe6d commit 21da9b0

12 files changed

+174
-38
lines changed

c/tskit/tables.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6982,7 +6982,7 @@ simplifier_merge_ancestors(simplifier_t *self, tsk_id_t input_id)
69826982
keep_unary = true;
69836983
}
69846984
if ((self->options & TSK_KEEP_UNARY_IN_INDIVIDUALS)
6985-
&& (self->tables->nodes.individual[input_id] != TSK_NULL)) {
6985+
&& (self->input_tables.nodes.individual[input_id] != TSK_NULL)) {
69866986
keep_unary = true;
69876987
}
69886988

python/_tskitmodule.c

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4960,18 +4960,20 @@ TableCollection_simplify(TableCollection *self, PyObject *args, PyObject *kwds)
49604960
int filter_individuals = false;
49614961
int filter_populations = false;
49624962
int keep_unary = false;
4963+
int keep_unary_in_individuals = false;
49634964
int keep_input_roots = false;
49644965
int reduce_to_site_topology = false;
4965-
static char *kwlist[]
4966-
= { "samples", "filter_sites", "filter_populations", "filter_individuals",
4967-
"reduce_to_site_topology", "keep_unary", "keep_input_roots", NULL };
4966+
static char *kwlist[] = { "samples", "filter_sites", "filter_populations",
4967+
"filter_individuals", "reduce_to_site_topology", "keep_unary",
4968+
"keep_unary_in_individuals", "keep_input_roots", NULL };
49684969

49694970
if (TableCollection_check_state(self) != 0) {
49704971
goto out;
49714972
}
4972-
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|iiiiii", kwlist, &samples,
4973+
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|iiiiiii", kwlist, &samples,
49734974
&filter_sites, &filter_populations, &filter_individuals,
4974-
&reduce_to_site_topology, &keep_unary, &keep_input_roots)) {
4975+
&reduce_to_site_topology, &keep_unary, &keep_unary_in_individuals,
4976+
&keep_input_roots)) {
49754977
goto out;
49764978
}
49774979
samples_array = (PyArrayObject *) PyArray_FROMANY(
@@ -4996,6 +4998,9 @@ TableCollection_simplify(TableCollection *self, PyObject *args, PyObject *kwds)
49964998
if (keep_unary) {
49974999
options |= TSK_KEEP_UNARY;
49985000
}
5001+
if (keep_unary_in_individuals) {
5002+
options |= TSK_KEEP_UNARY_IN_INDIVIDUALS;
5003+
}
49995004
if (keep_input_roots) {
50005005
options |= TSK_KEEP_INPUT_ROOTS;
50015006
}

python/tests/simplify.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ def __init__(
109109
filter_populations=True,
110110
filter_individuals=True,
111111
keep_unary=False,
112+
keep_unary_in_individuals=False,
112113
keep_input_roots=False,
113114
):
114115
self.ts = ts
@@ -119,6 +120,7 @@ def __init__(
119120
self.filter_populations = filter_populations
120121
self.filter_individuals = filter_individuals
121122
self.keep_unary = keep_unary
123+
self.keep_unary_in_individuals = keep_unary_in_individuals
122124
self.keep_input_roots = keep_input_roots
123125
self.num_mutations = ts.num_mutations
124126
self.input_sites = list(ts.sites())
@@ -295,7 +297,10 @@ def merge_labeled_ancestors(self, S, input_id):
295297
if is_sample:
296298
self.record_edge(left, right, output_id, ancestry_node)
297299
ancestry_node = output_id
298-
elif self.keep_unary:
300+
elif self.keep_unary or (
301+
self.keep_unary_in_individuals
302+
and self.ts.node(input_id).individual >= 0
303+
):
299304
if output_id == -1:
300305
output_id = self.record_node(input_id)
301306
self.record_edge(left, right, output_id, ancestry_node)
@@ -308,7 +313,10 @@ def merge_labeled_ancestors(self, S, input_id):
308313
if is_sample and left != prev_right:
309314
# Fill in any gaps in the ancestry for the sample
310315
self.add_ancestry(input_id, prev_right, left, output_id)
311-
if self.keep_unary:
316+
if self.keep_unary or (
317+
self.keep_unary_in_individuals
318+
and self.ts.node(input_id).individual >= 0
319+
):
312320
ancestry_node = output_id
313321
self.add_ancestry(input_id, left, right, ancestry_node)
314322
prev_right = right
@@ -757,7 +765,6 @@ def print_state(self):
757765

758766
samples = list(map(int, sys.argv[3:]))
759767

760-
# When keep_unary = True
761768
print("When keep_unary = True:")
762769
s = Simplifier(ts, samples, keep_unary=True)
763770
# s.print_state()
@@ -768,8 +775,7 @@ def print_state(self):
768775
print(tables.sites)
769776
print(tables.mutations)
770777

771-
# When keep_unary = False
772-
print("\nWhen keep_unary = False:")
778+
print("\nWhen keep_unary = False")
773779
s = Simplifier(ts, samples, keep_unary=False)
774780
# s.print_state()
775781
tss, _ = s.simplify()
@@ -779,6 +785,16 @@ def print_state(self):
779785
print(tables.sites)
780786
print(tables.mutations)
781787

788+
print("\nWhen keep_unary_in_individuals = True")
789+
s = Simplifier(ts, samples, keep_unary_in_individuals=True)
790+
# s.print_state()
791+
tss, _ = s.simplify()
792+
tables = tss.dump_tables()
793+
print(tables.nodes)
794+
print(tables.edges)
795+
print(tables.sites)
796+
print(tables.mutations)
797+
782798
elif class_to_implement == "AncestorMap":
783799

784800
samples = sys.argv[3]

python/tests/test_lowlevel.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,8 @@ def test_simplify_bad_args(self):
263263
tc.simplify("asdf")
264264
with pytest.raises(TypeError):
265265
tc.simplify([0, 1], keep_unary="sdf")
266+
with pytest.raises(TypeError):
267+
tc.simplify([0, 1], keep_unary_in_individuals="abc")
266268
with pytest.raises(TypeError):
267269
tc.simplify([0, 1], keep_input_roots="sdf")
268270
with pytest.raises(TypeError):

python/tests/test_tables.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2368,12 +2368,13 @@ def wf_sim_with_individual_metadata(self):
23682368
9,
23692369
10,
23702370
seed=1,
2371-
deep_history=True,
2371+
deep_history=False,
23722372
initial_generation_samples=False,
23732373
num_loci=5,
23742374
record_individuals=True,
23752375
)
23762376
assert tables.individuals.num_rows > 50
2377+
assert np.all(tables.nodes.individual >= 0)
23772378
individuals_copy = tables.copy().individuals
23782379
tables.individuals.clear()
23792380
tables.individuals.metadata_schema = tskit.MetadataSchema({"codec": "json"})
@@ -2404,6 +2405,26 @@ def test_individual_parent_mapping(self, wf_sim_with_individual_metadata):
24042405
)
24052406
assert set(tables.individuals.parents) != {tskit.NULL}
24062407

2408+
def verify_complete_genetic_pedigree(self, tables):
2409+
ts = tables.tree_sequence()
2410+
for edge in ts.edges():
2411+
child = ts.individual(ts.node(edge.child).individual)
2412+
parent = ts.individual(ts.node(edge.parent).individual)
2413+
assert parent.id in child.parents
2414+
assert parent.metadata["original_id"] in child.metadata["original_parents"]
2415+
2416+
def test_no_complete_genetic_pedigree(self, wf_sim_with_individual_metadata):
2417+
tables = wf_sim_with_individual_metadata.copy()
2418+
tables.simplify() # Will remove intermediate individuals
2419+
with pytest.raises(AssertionError):
2420+
self.verify_complete_genetic_pedigree(tables)
2421+
2422+
def test_complete_genetic_pedigree(self, wf_sim_with_individual_metadata):
2423+
for params in [{"keep_unary": True}, {"keep_unary_in_individuals": True}]:
2424+
tables = wf_sim_with_individual_metadata.copy()
2425+
tables.simplify(**params) # Keep intermediate individuals
2426+
self.verify_complete_genetic_pedigree(tables)
2427+
24072428
def test_shuffled_individual_parent_mapping(self, wf_sim_with_individual_metadata):
24082429
tables = wf_sim_with_individual_metadata.copy()
24092430
tsutil.shuffle_tables(

python/tests/test_topology.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2363,19 +2363,20 @@ def test_ladder_tree(self):
23632363
def verify_unary_tree_sequence(self, ts):
23642364
"""
23652365
Take the specified tree sequence and produce an equivalent in which
2366-
unary records have been interspersed.
2366+
unary records have been interspersed, every other with an associated individual
23672367
"""
23682368
assert ts.num_trees > 2
23692369
assert ts.num_mutations > 2
23702370
tables = ts.dump_tables()
23712371
next_node = ts.num_nodes
23722372
node_times = {j: node.time for j, node in enumerate(ts.nodes())}
23732373
edges = []
2374-
for e in ts.edges():
2374+
for i, e in enumerate(ts.edges()):
23752375
node = ts.node(e.parent)
23762376
t = node.time - 1e-14 # Arbitrary small value.
23772377
next_node = len(tables.nodes)
2378-
tables.nodes.add_row(time=t, population=node.population)
2378+
indiv = tables.individuals.add_row() if i % 2 == 0 else tskit.NULL
2379+
tables.nodes.add_row(time=t, population=node.population, individual=indiv)
23792380
edges.append(
23802381
tskit.Edge(left=e.left, right=e.right, parent=next_node, child=e.child)
23812382
)
@@ -2398,11 +2399,16 @@ def verify_unary_tree_sequence(self, ts):
23982399
self.assert_haplotypes_equal(ts, ts_simplified)
23992400
self.assert_variants_equal(ts, ts_simplified)
24002401
assert len(list(ts.edge_diffs())) == ts.num_trees
2402+
assert 0 < ts_new.num_individuals < ts_new.num_nodes
24012403

2402-
for keep_unary in [True, False]:
2403-
s = tests.Simplifier(ts, ts.samples(), keep_unary=keep_unary)
2404+
for params in [
2405+
{"keep_unary": False, "keep_unary_in_individuals": False},
2406+
{"keep_unary": True, "keep_unary_in_individuals": False},
2407+
{"keep_unary": False, "keep_unary_in_individuals": True},
2408+
]:
2409+
s = tests.Simplifier(ts_new, ts_new.samples(), **params)
24042410
py_ts, py_node_map = s.simplify()
2405-
lib_ts, lib_node_map = ts.simplify(keep_unary=keep_unary, map_nodes=True)
2411+
lib_ts, lib_node_map = ts_new.simplify(map_nodes=True, **params)
24062412
py_tables = py_ts.dump_tables()
24072413
py_tables.provenances.clear()
24082414
lib_tables = lib_ts.dump_tables()

python/tests/test_utilities.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def test_ploidy_2_reversed(self):
138138
ts = msprime.simulate(10, random_seed=1)
139139
assert ts.num_individuals == 0
140140
samples = ts.samples()[::-1]
141-
ts = tsutil.insert_individuals(ts, samples=samples, ploidy=2)
141+
ts = tsutil.insert_individuals(ts, nodes=samples, ploidy=2)
142142
assert ts.num_individuals == 5
143143
for j, ind in enumerate(ts.individuals()):
144144
assert list(ind.nodes) == [samples[2 * j + 1], samples[2 * j]]

python/tests/test_vcf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,14 +212,14 @@ def test_simple_infinite_sites_ploidy_2(self):
212212
def test_simple_infinite_sites_ploidy_2_reversed_samples(self):
213213
ts = msprime.simulate(10, mutation_rate=1, random_seed=2)
214214
samples = ts.samples()[::-1]
215-
ts = tsutil.insert_individuals(ts, samples=samples, ploidy=2)
215+
ts = tsutil.insert_individuals(ts, nodes=samples, ploidy=2)
216216
assert ts.num_sites > 2
217217
self.verify(ts)
218218

219219
def test_simple_infinite_sites_ploidy_2_even_samples(self):
220220
ts = msprime.simulate(20, mutation_rate=1, random_seed=2)
221221
samples = ts.samples()[0::2]
222-
ts = tsutil.insert_individuals(ts, samples=samples, ploidy=2)
222+
ts = tsutil.insert_individuals(ts, nodes=samples, ploidy=2)
223223
assert ts.num_sites > 2
224224
self.verify(ts)
225225

python/tests/test_wright_fisher.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -592,3 +592,58 @@ def test_simplify_tables(self, ts, nsamples):
592592
other_tables.provenances.clear()
593593
assert tables == other_tables
594594
self.verify_simplify(ts, small_ts, sub_samples, node_map)
595+
596+
@pytest.mark.parametrize("ts", wf_sims)
597+
@pytest.mark.parametrize("nsamples", [2, 5])
598+
def test_simplify_keep_unary(self, ts, nsamples):
599+
np.random.seed(123)
600+
ts = tsutil.mark_metadata(ts, "nodes")
601+
sub_samples = random.sample(list(ts.samples()), min(nsamples, ts.num_samples))
602+
random_nodes = np.random.choice(ts.num_nodes, ts.num_nodes // 2)
603+
ts = tsutil.insert_individuals(ts, random_nodes)
604+
ts = tsutil.mark_metadata(ts, "individuals")
605+
606+
for params in [{}, {"keep_unary": True}, {"keep_unary_in_individuals": True}]:
607+
sts = ts.simplify(sub_samples, **params)
608+
# check samples match
609+
assert sts.num_samples == len(sub_samples)
610+
for n, sn in zip(sub_samples, sts.samples()):
611+
assert ts.node(n).metadata == sts.node(sn).metadata
612+
613+
# check that nodes are correctly retained: only nodes ancestral to
614+
# retained samples, and: by default, only coalescent events; if
615+
# keep_unary_in_individuals then also nodes in individuals; if
616+
# keep_unary then all such nodes.
617+
for t in ts.trees(tracked_samples=sub_samples):
618+
st = sts.at(t.interval[0])
619+
visited = [False for _ in sts.nodes()]
620+
for n, sn in zip(sub_samples, sts.samples()):
621+
last_n = t.num_tracked_samples(n)
622+
while n != tskit.NULL:
623+
ind = ts.node(n).individual
624+
keep = False
625+
if t.num_tracked_samples(n) > last_n:
626+
# a coalescent node
627+
keep = True
628+
if "keep_unary_in_individuals" in params and ind != tskit.NULL:
629+
keep = True
630+
if "keep_unary" in params:
631+
keep = True
632+
if (n in sub_samples) or keep:
633+
visited[sn] = True
634+
assert sn != tskit.NULL
635+
assert ts.node(n).metadata == sts.node(sn).metadata
636+
assert t.num_tracked_samples(n) == st.num_samples(sn)
637+
if ind != tskit.NULL:
638+
sind = sts.node(sn).individual
639+
assert sind != tskit.NULL
640+
assert (
641+
ts.individual(ind).metadata
642+
== sts.individual(sind).metadata
643+
)
644+
sn = st.parent(sn)
645+
last_n = t.num_tracked_samples(n)
646+
n = t.parent(n)
647+
st_nodes = list(st.nodes())
648+
for k, v in enumerate(visited):
649+
assert v == (k in st_nodes)

python/tests/tsutil.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -242,30 +242,44 @@ def insert_random_ploidy_individuals(ts, max_ploidy=5, max_dimension=3, seed=1):
242242
return tables.tree_sequence()
243243

244244

245-
def insert_individuals(ts, samples=None, ploidy=1):
245+
def insert_individuals(ts, nodes=None, ploidy=1):
246246
"""
247247
Inserts individuals into the tree sequence using the specified list
248-
of samples (or all samples if None) with the specified ploidy by combining
249-
ploidy-sized chunks of the list.
248+
of node (or use all sample nodes if None) with the specified ploidy by combining
249+
ploidy-sized chunks of the list. Add metadata to the individuals so we can
250+
track them
250251
"""
251-
if samples is None:
252-
samples = ts.samples()
253-
if len(samples) % ploidy != 0:
254-
raise ValueError("number of samples must be divisible by ploidy")
252+
if nodes is None:
253+
nodes = ts.samples()
254+
assert len(nodes) % ploidy == 0 # To allow mixed ploidies we could comment this out
255255
tables = ts.dump_tables()
256256
tables.individuals.clear()
257257
individual = tables.nodes.individual[:]
258258
individual[:] = tskit.NULL
259259
j = 0
260-
while j < len(samples):
261-
nodes = samples[j : j + ploidy]
262-
ind_id = tables.individuals.add_row()
263-
individual[nodes] = ind_id
260+
while j < len(nodes):
261+
nodes_in_individual = nodes[j : min(len(nodes), j + ploidy)]
262+
# should we warn here if nodes[j : j + ploidy] are at different times?
263+
# probably not, as although this is unusual, it is actually allowed
264+
ind_id = tables.individuals.add_row(
265+
metadata=f"orig_id {tables.individuals.num_rows}".encode()
266+
)
267+
individual[nodes_in_individual] = ind_id
264268
j += ploidy
265269
tables.nodes.individual = individual
266270
return tables.tree_sequence()
267271

268272

273+
def mark_metadata(ts, table_name, prefix="orig_id:"):
274+
"""
275+
Add metadata to all rows of the form prefix + row_number
276+
"""
277+
tables = ts.dump_tables()
278+
table = getattr(tables, table_name)
279+
table.packset_metadata([(prefix + str(i)).encode() for i in range(table.num_rows)])
280+
return tables.tree_sequence()
281+
282+
269283
def permute_nodes(ts, node_map):
270284
"""
271285
Returns a copy of the specified tree sequence such that the nodes are

python/tskit/tables.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2493,6 +2493,7 @@ def simplify(
24932493
filter_individuals=True,
24942494
filter_sites=True,
24952495
keep_unary=False,
2496+
keep_unary_in_individuals=None,
24962497
keep_input_roots=False,
24972498
record_provenance=True,
24982499
filter_zero_mutation_sites=None, # Deprecated alias for filter_sites
@@ -2538,9 +2539,14 @@ def simplify(
25382539
not referenced by mutations after simplification; new site IDs are
25392540
allocated sequentially from zero. If False, the site table will not
25402541
be altered in any way. (Default: True)
2541-
:param bool keep_unary: If True, any unary nodes (i.e. nodes with exactly
2542-
one child) that exist on the path from samples to root will be preserved
2543-
in the output. (Default: False)
2542+
:param bool keep_unary: If True, preserve unary nodes (i.e. nodes with
2543+
exactly one child) that exist on the path from samples to root.
2544+
(Default: False)
2545+
:param bool keep_unary_in_individuals: If True, preserve unary nodes
2546+
that exist on the path from samples to root, but only if they are
2547+
associated with an individual in the individuals table. Cannot be
2548+
specified at the same time as ``keep_unary``. (Default: ``None``,
2549+
equivalent to False)
25442550
:param bool keep_input_roots: Whether to retain history ancestral to the
25452551
MRCA of the samples. If ``False``, no topology older than the MRCAs of the
25462552
samples will be included. If ``True`` the roots of all trees in the returned
@@ -2568,13 +2574,17 @@ def simplify(
25682574
].astype(np.int32)
25692575
else:
25702576
samples = util.safe_np_int_cast(samples, np.int32)
2577+
if keep_unary_in_individuals is None:
2578+
keep_unary_in_individuals = False
2579+
25712580
node_map = self._ll_tables.simplify(
25722581
samples,
25732582
filter_sites=filter_sites,
25742583
filter_individuals=filter_individuals,
25752584
filter_populations=filter_populations,
25762585
reduce_to_site_topology=reduce_to_site_topology,
25772586
keep_unary=keep_unary,
2587+
keep_unary_in_individuals=keep_unary_in_individuals,
25782588
keep_input_roots=keep_input_roots,
25792589
)
25802590
if record_provenance:

0 commit comments

Comments
 (0)