Skip to content

PBS #368

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 4, 2020
Merged

PBS #368

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion sgkit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .stats.hwe import hardy_weinberg_test
from .stats.pc_relate import pc_relate
from .stats.pca import pca
from .stats.popgen import Fst, Tajimas_D, divergence, diversity
from .stats.popgen import Fst, Tajimas_D, divergence, diversity, pbs
from .stats.preprocessing import filter_partial_calls
from .stats.regenie import regenie
from .testing import simulate_genotype_call_dataset
Expand All @@ -45,6 +45,7 @@
"divergence",
"Fst",
"Tajimas_D",
"pbs",
"pc_relate",
"simulate_genotype_call_dataset",
"variables",
Expand Down
104 changes: 104 additions & 0 deletions sgkit/stats/popgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,3 +578,107 @@ def Tajimas_D(

new_ds = Dataset({variables.stat_Tajimas_D: D})
return conditional_merge_datasets(ds, variables.validate(new_ds), merge)


# c = cohorts
@guvectorize( # type: ignore
["void(float32[:, :], float32[:,:,:])", "void(float64[:, :], float64[:,:,:])"],
"(c,c)->(c,c,c)",
nopython=True,
cache=True,
)
def _pbs(t: ArrayLike, out: ArrayLike) -> None:
"""Generalized U-function for computing PBS."""
out[:, :, :] = np.nan # (cohorts, cohorts, cohorts)
n_cohorts = t.shape[0]
# calculate PBS for each cohort triple
for i in range(n_cohorts):
for j in range(i + 1, n_cohorts):
for k in range(j + 1, n_cohorts):
ret = (t[i, j] + t[i, k] - t[j, k]) / 2
norm = 1 + (t[i, j] + t[i, k] + t[j, k]) / 2
ret = ret / norm
out[i, j, k] = ret


def pbs(
ds: Dataset,
*,
stat_Fst: Hashable = variables.stat_Fst,
merge: bool = True,
) -> Dataset:
"""Compute the population branching statistic (PBS) between cohort triples.

By default, values of this statistic are calculated per variant.
To compute values in windows, call :func:`window` before calling
this function.

Parameters
----------
ds
Genotype call dataset.
stat_Fst
Fst variable to use or calculate. Defined by
:data:`sgkit.variables.stat_Fst_spec`.
If the variable is not present in ``ds``, it will be computed
using :func:`Fst`.
merge
If True (the default), merge the input dataset and the computed
output variables into a single dataset, otherwise return only
the computed output variables.
See :ref:`dataset_merge` for more details.

Returns
-------
A dataset containing the PBS value between cohort triples, as defined by
:data:`sgkit.variables.stat_pbs_spec`.
Shape (variants, cohorts, cohorts, cohorts), or
(windows, cohorts, cohorts, cohorts) if windowing information is available.

Warnings
--------
This method does not currently support datasets that are chunked along the
samples dimension.

Examples
--------

>>> import numpy as np
>>> import sgkit as sg
>>> import xarray as xr
>>> ds = sg.simulate_genotype_call_dataset(n_variant=5, n_sample=6)

>>> # Divide samples into three named cohorts
>>> n_cohorts = 3
>>> sample_cohort = np.repeat(range(n_cohorts), ds.dims["samples"] // n_cohorts)
>>> ds["sample_cohort"] = xr.DataArray(sample_cohort, dims="samples")
>>> cohort_names = [f"co_{i}" for i in range(n_cohorts)]
>>> ds = ds.assign_coords({"cohorts_0": cohort_names, "cohorts_1": cohort_names, "cohorts_2": cohort_names})

>>> # Divide into two windows of size three (variants)
>>> ds = sg.window(ds, size=3, step=3)
>>> sg.pbs(ds)["stat_pbs"].sel(cohorts_0="co_0", cohorts_1="co_1", cohorts_2="co_2").values # doctest: +NORMALIZE_WHITESPACE
array([ 0. , -0.160898])
"""

ds = define_variable_if_absent(ds, variables.stat_Fst, stat_Fst, Fst)
variables.validate(ds, {stat_Fst: variables.stat_Fst_spec})

fst = ds[variables.stat_Fst]
fst = fst.clip(min=0, max=(1 - np.finfo(float).epsneg))

t = -np.log(1 - fst)
n_cohorts = ds.dims["cohorts"]
n_windows = ds.dims["windows"]
assert_array_shape(t, n_windows, n_cohorts, n_cohorts)

# calculate PBS triples
t = da.asarray(t)
shape = (t.chunks[0], n_cohorts, n_cohorts, n_cohorts)
p = da.map_blocks(_pbs, t, chunks=shape, new_axis=3, dtype=np.float64)
assert_array_shape(p, n_windows, n_cohorts, n_cohorts, n_cohorts)

new_ds = Dataset(
{variables.stat_pbs: (["windows", "cohorts_0", "cohorts_1", "cohorts_2"], p)}
)
return conditional_merge_datasets(ds, variables.validate(new_ds), merge)
71 changes: 71 additions & 0 deletions sgkit/tests/test_popgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
create_genotype_call_dataset,
divergence,
diversity,
pbs,
variables,
)
from sgkit.window import window
Expand Down Expand Up @@ -328,3 +329,73 @@ def test_Tajimas_D(sample_size):
d = ds.stat_Tajimas_D.compute()
ts_d = ts.Tajimas_D()
np.testing.assert_allclose(d, ts_d)


@pytest.mark.parametrize(
"sample_size, n_cohorts",
[(10, 3)],
)
def test_pbs(sample_size, n_cohorts):
ts = msprime.simulate(sample_size, length=100, mutation_rate=0.05, random_seed=42)
subsets = np.array_split(ts.samples(), n_cohorts)
ds = ts_to_dataset(ts) # type: ignore[no-untyped-call]
sample_cohorts = np.concatenate(
[np.full_like(subset, i) for i, subset in enumerate(subsets)]
)
ds["sample_cohort"] = xr.DataArray(sample_cohorts, dims="samples")
cohort_names = [f"co_{i}" for i in range(n_cohorts)]
ds = ds.assign_coords({"cohorts_0": cohort_names, "cohorts_1": cohort_names})
n_variants = ds.dims["variants"]
ds = window(ds, size=n_variants, step=n_variants) # single window

ds = pbs(ds)
stat_pbs = ds["stat_pbs"]

# scikit-allel
ac1 = ds.cohort_allele_count.values[:, 0, :]
ac2 = ds.cohort_allele_count.values[:, 1, :]
ac3 = ds.cohort_allele_count.values[:, 2, :]

ska_pbs_value = np.full([1, n_cohorts, n_cohorts, n_cohorts], np.nan)
for i, j, k in itertools.combinations(range(n_cohorts), 3):
ska_pbs_value[0, i, j, k] = allel.pbs(
ac1, ac2, ac3, window_size=n_variants, window_step=n_variants
)

np.testing.assert_allclose(stat_pbs, ska_pbs_value)


@pytest.mark.parametrize(
"sample_size, n_cohorts",
[(10, 3)],
)
@pytest.mark.parametrize("chunks", [(-1, -1), (50, -1)])
def test_pbs__windowed(sample_size, n_cohorts, chunks):
ts = msprime.simulate(sample_size, length=200, mutation_rate=0.05, random_seed=42)
subsets = np.array_split(ts.samples(), n_cohorts)
ds = ts_to_dataset(ts, chunks) # type: ignore[no-untyped-call]
sample_cohorts = np.concatenate(
[np.full_like(subset, i) for i, subset in enumerate(subsets)]
)
ds["sample_cohort"] = xr.DataArray(sample_cohorts, dims="samples")
cohort_names = [f"co_{i}" for i in range(n_cohorts)]
ds = ds.assign_coords({"cohorts_0": cohort_names, "cohorts_1": cohort_names})
ds = window(ds, size=25, step=25)

ds = pbs(ds)
stat_pbs = ds["stat_pbs"].values

# scikit-allel
ac1 = ds.cohort_allele_count.values[:, 0, :]
ac2 = ds.cohort_allele_count.values[:, 1, :]
ac3 = ds.cohort_allele_count.values[:, 2, :]

# scikit-allel has final window missing
n_windows = ds.dims["windows"] - 1
ska_pbs_value = np.full([n_windows, n_cohorts, n_cohorts, n_cohorts], np.nan)
for i, j, k in itertools.combinations(range(n_cohorts), 3):
ska_pbs_value[:, i, j, k] = allel.pbs(
ac1, ac2, ac3, window_size=25, window_step=25
)

np.testing.assert_allclose(stat_pbs[:-1], ska_pbs_value)
4 changes: 4 additions & 0 deletions sgkit/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,10 @@ def _check_field(
ArrayLikeSpec("stat_diversity", ndim=2, kind="f")
)
"""Genetic diversity (also known as "Tajima’s pi") for cohorts."""
stat_pbs, stat_pbs_spec = SgkitVariables.register_variable(
ArrayLikeSpec("stat_pbs", ndim=4, kind="f")
)
"""Population branching statistic for cohort triples."""
stat_Tajimas_D, stat_Tajimas_D_spec = SgkitVariables.register_variable(
ArrayLikeSpec("stat_Tajimas_D", ndim={0, 2}, kind="f")
)
Expand Down