Skip to content

Randomly resolve polytomies #809

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
hyanwong opened this issue Aug 27, 2020 · 10 comments · Fixed by #815
Closed

Randomly resolve polytomies #809

hyanwong opened this issue Aug 27, 2020 · 10 comments · Fixed by #815
Labels
C API Issue is about the C API enhancement New feature or request Python API Issue is about the Python API

Comments

@hyanwong
Copy link
Member

hyanwong commented Aug 27, 2020

I've finally got some working code to randomly resolve polytomies in a tree sequence. It seems to work for the admittedly small sample of inferred trees that I have tried. It tries to be clever by not resolving per tree, but per polytomy (hence if an identical polytomy spans several trees, it will only resolve the polytomy once, creating fewer edges than a per-tree approch, which should mean it scales to larger sample sizes better). I'm posting it here so that:

a. I don't lose it (!)
b. We can decide if we want something like this in the base tskit library, as a method on a tree sequence.
c. @hyl317 can use it if he wants, for ARGweaver compatibility (although his current solution may be faster)

Obviously it needs a fair bit of tidying up. It's also slow, but there's also considerable scope for optimization, I think. It requires the PR I made at #787

import collections
import itertools

import numpy as np
import tskit

def resolve_polytomy(parent_node_id, child_ids, new_nodes_by_time, rng):
    """
    For a polytomy and list of child node ids, return a list of (child, parent) tuples,
    describing a bifurcating tree, rooted at parent_node_id, where the new_nodes_by_time
    have been used to break polytomies. All possible topologies should be equiprobable.
    """
    assert len(child_ids) == len(new_nodes_by_time) + 2
    edges = [[child_ids[0], None], ]  # Introduce a single edge that will be deleted later
    edge_choice = rng.integers(0, np.arange(1, len(child_ids) * 2 - 1, 2))
    tmp_new_node_lab = [parent_node_id] + new_nodes_by_time
    assert len(edge_choice) == len(child_ids) - 1
    for node_lab, child_id, target_edge_id in zip(tmp_new_node_lab, child_ids[1:], edge_choice):
        target_edge = edges[target_edge_id]
        # print("target", target_edge)
        # Insert to keep edges in time order of parent
        edges.insert(target_edge_id, [child_id, node_lab])
        edges.insert(target_edge_id, [target_edge[0], node_lab])
        target_edge[0] = node_lab
    # We need to re-map the internal nodes so that they are in time order
    real_node = iter(new_nodes_by_time)
    edges.pop() # remove the unary node at the top
    node_map = {c: c for c in child_ids}
    # print("orig_edges", edges)
    # print("parent IDs to allocate", new_nodes_by_time)
    # last edge should have the highest node
    node_map[edges[-1][1]] = parent_node_id
    for e in reversed(edges):
        # edges should be in time order - the oldest one can be give the parent_node_id
        if e[1] not in node_map:
            node_map[e[1]] = next(real_node)
        if e[0] not in node_map:
            node_map[e[0]] = next(real_node)
        e[0] = node_map[e[0]]
        e[1] = node_map[e[1]]
    # print("mapped edges", edges)
    assert len(node_map) == len(new_nodes_by_time) + len(child_ids) + 1
    return edges


def resolve_polytomies(ts, *, epsilon=1e-10, random_seed=None):
    """
    For a given parent node, an edge in or an edge out signifies a change in children
    Each time such a change happens, we cut all existing edges with that parent,
    and add the previous portion in to the new edge table. If, previously, there were
    3 or more children for this node, we break the polytomy at random
    """
    rng = np.random.default_rng(seed=random_seed)

    tables = ts.dump_tables()
    edges_table = tables.edges
    nodes_table = tables.nodes
    # Store the left of the existing edges, as we will need to change it if the edge is split
    existing_edges_left = edges_table.left
    # Keep these arrays for handy reading later
    existing_edges_right = edges_table.right
    existing_edges_parent = edges_table.parent
    existing_edges_child = edges_table.child
    existing_node_time = nodes_table.time

    edges_table.clear()

    edges_for_node = collections.defaultdict(set)  # The edge ids dangling from each active node
    nodes_changed = set()

    for interval, e_out, e_in in ts.edge_diffs(include_terminal=True):
        for edge in itertools.chain(e_out, e_in):
            if edge.parent != tskit.NULL:
                nodes_changed.add(edge.parent)

        pos = interval[0]
        for parent_node in nodes_changed:
            child_edge_ids = edges_for_node[parent_node]
            if len(child_edge_ids) >= 3:
                # We have a previous polytomy to break
                parent_time = existing_node_time[parent_node]
                new_nodes = []
                child_ids = existing_edges_child[list(child_edge_ids)]
                remaining_edges = child_edge_ids.copy()
                left = None
                max_time = 0
                for edge_id, child_id in zip(child_edge_ids, child_ids):
                    max_time = max(max_time, existing_node_time[child_id])
                    if left is None:
                        left = existing_edges_left[edge_id]
                    else:
                        assert left == existing_edges_left[edge_id]
                    if existing_edges_right[edge_id] > interval[0]:
                        # make sure we carry on the edge after this polytomy
                        existing_edges_left[edge_id] = pos

                # ADD THE PREVIOUS EDGE SEGMENTS
                dt = min((parent_time - max_time)/(len(child_ids)*2), epsilon)
                # Each broken polytomy of degree N introduces N-2 extra nodes, each at a time
                # slighly less than the parent_time. Create new nodes in order of decreasing time
                new_nodes = [nodes_table.add_row(time=parent_time - (i * dt))
                             for i in range(1, len(child_ids) - 1)]
                # print("new_nodes:", new_nodes, [tables.nodes[n].time for n in new_nodes])
                for new_edge in resolve_polytomy(parent_node, child_ids, new_nodes, rng):
                    edges_table.add_row(left=left, right=pos, child=new_edge[0], parent=new_edge[1])
                    # print("new_edge: left={}, right={}, child={}, parent={}".format(
                    #     left, pos, new_edge[0], new_edge[1]))
            else:
                # Previous node was not a polytomy - just add the edges_out, with modified left
                for edge_id in child_edge_ids:
                    if existing_edges_right[edge_id] == pos:  # this edge has just gone out
                        edges_table.add_row(
                            left=existing_edges_left[edge_id],
                            right=pos,
                            parent=parent_node,
                            child=existing_edges_child[edge_id],
                        )

        for edge in e_out: 
            if edge.parent != tskit.NULL:
                # print("REMOVE", edge.id)
                edges_for_node[edge.parent].remove(edge.id)
        for edge in e_in:
            if edge.parent != tskit.NULL:
                # print("ADD", edge.id)
                edges_for_node[edge.parent].add(edge.id)            

        # Chop if we have created a polytomy: the polytomy itself will be resolved
        # at a future iteration, when any of the edges move in or out of the polytomy
        while nodes_changed:
            node = nodes_changed.pop()
            edge_ids = edges_for_node[node]
            # print("Looking at", node)

            if len(edge_ids) == 0:
                del edges_for_node[node]
            # if this node has changed *to* a polytomy, we need to cut all of the
            # child edges that were previously present by adding the previous segment
            # and left-truncating
            elif len(edge_ids) >= 3:
                # print("Polytomy at", node, " breaking edges")
                for edge_id in edge_ids:
                    if existing_edges_left[edge_id] < interval[0]:
                        tables.edges.add_row(
                            left=existing_edges_left[edge_id],
                            right=interval[0],
                            parent=existing_edges_parent[edge_id],
                            child=existing_edges_child[edge_id],
                        )
                    existing_edges_left[edge_id] = interval[0]
    assert len(edges_for_node) == 0

    tables.edges.squash()
    tables.sort() # Shouldn't need to do this: https://github.com/tskit-dev/tskit/issues/808

    return tables.tree_sequence()

The code can be tested using something like this:

import io
import collections
import tqdm
import time

import msprime
import tsinfer

### Check equiprobable

nodes_polytomy_4 = """\
id      is_sample   population      time
0       1       0               0.00000000000000
1       1       0               0.00000000000000
2       1       0               0.00000000000000
3       1       0               0.00000000000000
4       0       0               1.00000000000000
"""
edges_polytomy_4 = """\
id      left            right           parent  child
0       0.00000000      1.00000000      4       0,1,2,3
"""

poly_ts = tskit.load_text(
    nodes=io.StringIO(nodes_polytomy_4),
    edges=io.StringIO(edges_polytomy_4),
    strict=False,
)

trees = collections.Counter()
for seed in tqdm.trange(1, 100000):
    ts2 = resolve_polytomies(poly_ts, random_seed=seed)
    trees.update([ts2.first().rank()])
print(trees)

### Time on 10000 tip inferred TS and check

ts_old = msprime.simulate(10000, recombination_rate=100, mutation_rate=10, random_seed=123)
sd = tsinfer.SampleData.from_tree_sequence(ts_old, use_times=False)
ts = tsinfer.infer(sd)
print(f"{ts.num_samples} tips ; {ts.num_trees} trees")
start = time.time()
ts2 = resolve_polytomies(ts, random_seed=1)
print("Time (s):", time.time()-start)
for tree in ts2.trees():
    for node in tree.nodes():
        assert tree.num_children(node) < 3
@benjeffery
Copy link
Member

Cool, @hyanwong! I assume we will want a method like this in tskit at some point.

One thing I wanted to flag - adding new nodes may invalidate the mutation time requirements if the tree sequence has them, for example if a mutation is just above one of the polytomies to be broken then it would need to be moved above the oldest node in the new set that replaces the polytomy. Clearly an edge-case but thought I should flag!

@hyanwong
Copy link
Member Author

That's a good point @benjeffery, thanks. Actually, I break the polytomy below the focal node, so that the polytomy node does not change time: it's mutations on any nodes below the polytomy which will need checking. I think I can see how to do this: at the moment I set the delta time (dt) value to the smallest difference between the child node times and the parent. I need to also check the mutation times above each child.

What do we do if the time differences are so small that they start to encounter floating point accuracy errors. Or is this so unlikely that we don't care?

@benjeffery
Copy link
Member

What do we do if the time differences are so small that they start to encounter floating point accuracy errors. Or is this so unlikely that we don't care?

I think we detect the situation and error out. Only other option is some horrible recursive shuffling right?

@hyanwong
Copy link
Member Author

hyanwong commented Aug 28, 2020

Yes, I guess so. The question is whether we actively look for this and raise a specific error, or just expect it to bomb out (with e.g. time[parent] <= time[child]) when we try to convert to a tree sequence.

Shall I actually work this up into a PR, if you think it's a useful extra method on a tree_sequence? I guess , following other examples, I could make it an in-place method on a TableCollection, and then create a idempotent (is that the right word?) version for a tree sequence.

@hyl317
Copy link

hyl317 commented Aug 28, 2020 via email

@hyanwong
Copy link
Member Author

Hi @hyl317 - to try it out, until #787 is merged you'll need to install my branch directly, e.g.

python3 -m pip install git+https://github.com/hyanwong@edge_diff_include_terminal#subdirectory=python

Then try out the code.

ISWYM about the specific error, it's just a bit more work to do properly!

@hyanwong
Copy link
Member Author

Hi @hyl317 and @awohns - you can now test this using a single install via the PR I just made. The name has changed (for the time being) to randomly_split_polytomies:

python3 -m pip install git+https://github.com/hyanwong@random-split-polytomy#subdirectory=python

Simply call like

ts_binary = ts.randomly_split_polytomies(random_seed=1)

@brianzhang01
Copy link
Member

I'd be happy to help review code for this.

@hyanwong
Copy link
Member Author

Great, thanks @brianzhang01 - that would be really useful. The PR is at #815 but I don't know if we want to implement out own PRNG, so that we can have an equivalent C function.

@benjeffery benjeffery added enhancement New feature or request C API Issue is about the C API Python API Issue is about the Python API labels Sep 29, 2020
@mergify mergify bot closed this as completed in #815 Nov 21, 2020
@hyanwong
Copy link
Member Author

hyanwong commented Sep 5, 2024

Just to note that the opposite: collapse edges into polytomies, only retaining those supported by mutations, is at #2926

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
C API Issue is about the C API enhancement New feature or request Python API Issue is about the Python API
Projects
None yet
Development

Successfully merging a pull request may close this issue.

4 participants