Skip to content

Commit 4155438

Browse files
committed
PBS
1 parent fe45a21 commit 4155438

File tree

4 files changed

+180
-1
lines changed

4 files changed

+180
-1
lines changed

sgkit/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from .stats.hwe import hardy_weinberg_test
1919
from .stats.pc_relate import pc_relate
2020
from .stats.pca import pca
21-
from .stats.popgen import Fst, Tajimas_D, divergence, diversity
21+
from .stats.popgen import Fst, Tajimas_D, divergence, diversity, pbs
2222
from .stats.preprocessing import filter_partial_calls
2323
from .stats.regenie import regenie
2424
from .testing import simulate_genotype_call_dataset
@@ -45,6 +45,7 @@
4545
"divergence",
4646
"Fst",
4747
"Tajimas_D",
48+
"pbs",
4849
"pc_relate",
4950
"simulate_genotype_call_dataset",
5051
"variables",

sgkit/stats/popgen.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,3 +578,106 @@ def Tajimas_D(
578578

579579
new_ds = Dataset({variables.stat_Tajimas_D: D})
580580
return conditional_merge_datasets(ds, variables.validate(new_ds), merge)
581+
582+
583+
# Core dimension "c" in the gufunc layout below is the cohorts axis.
@guvectorize(  # type: ignore
    ["void(float32[:, :], float32[:,:,:])", "void(float64[:, :], float64[:,:,:])"],
    "(c,c)->(c,c,c)",
    nopython=True,
)
def _pbs(t: ArrayLike, out: ArrayLike) -> None:
    """Generalized U-function filling ``out`` with PBS for cohort triples.

    ``t`` is a square (cohorts, cohorts) matrix of pairwise values and
    ``out`` is the (cohorts, cohorts, cohorts) result buffer. Only entries
    with strictly increasing indices ``i < j < k`` are computed; every other
    entry is left as NaN.
    """
    # Start from an all-NaN cube so unvisited triples are marked missing.
    out[:, :, :] = np.nan
    size = t.shape[0]
    for a in range(size):
        for b in range(a + 1, size):
            t_ab = t[a, b]
            for c in range(b + 1, size):
                t_ac = t[a, c]
                t_bc = t[b, c]
                # Branch-length estimate for cohort `a` ...
                branch = (t_ab + t_ac - t_bc) / 2
                # ... normalised by one plus half the sum of the pairwise values.
                scale = 1 + (t_ab + t_ac + t_bc) / 2
                out[a, b, c] = branch / scale
602+
603+
def pbs(
    ds: Dataset,
    *,
    stat_Fst: Hashable = variables.stat_Fst,
    merge: bool = True,
) -> Dataset:
    """Compute the population branching statistic (PBS) between cohort triples.

    By default, values of this statistic are calculated per variant.
    To compute values in windows, call :func:`window` before calling
    this function.

    Parameters
    ----------
    ds
        Genotype call dataset.
    stat_Fst
        Fst variable to use or calculate. Defined by
        :data:`sgkit.variables.stat_Fst_spec`.
        If the variable is not present in ``ds``, it will be computed
        using :func:`Fst`.
    merge
        If True (the default), merge the input dataset and the computed
        output variables into a single dataset, otherwise return only
        the computed output variables.
        See :ref:`dataset_merge` for more details.

    Returns
    -------
    A dataset containing the PBS value between cohort triples, as defined by
    :data:`sgkit.variables.stat_pbs_spec`.
    Shape (variants, cohorts, cohorts, cohorts), or
    (windows, cohorts, cohorts, cohorts) if windowing information is available.

    Notes
    -----
    Fst is clipped to ``[0, 0.99999]`` and transformed to
    ``T = -log(1 - Fst)``. For a cohort triple ``(i, j, k)`` the value is
    ``((T_ij + T_ik - T_jk) / 2) / (1 + (T_ij + T_ik + T_jk) / 2)``.
    Only triples with ``i < j < k`` are filled in; all other entries of the
    output are NaN.

    Warnings
    --------
    This method does not currently support datasets that are chunked along the
    samples dimension.

    Examples
    --------

    >>> import numpy as np
    >>> import sgkit as sg
    >>> import xarray as xr
    >>> ds = sg.simulate_genotype_call_dataset(n_variant=5, n_sample=6)

    >>> # Divide samples into three named cohorts
    >>> n_cohorts = 3
    >>> sample_cohort = np.repeat(range(n_cohorts), ds.sizes["samples"] // n_cohorts)
    >>> ds["sample_cohort"] = xr.DataArray(sample_cohort, dims="samples")
    >>> cohort_names = [f"co_{i}" for i in range(n_cohorts)]
    >>> ds = ds.assign_coords({"cohorts_0": cohort_names, "cohorts_1": cohort_names, "cohorts_2": cohort_names})

    >>> # Divide into two windows of size three (variants)
    >>> ds = sg.window(ds, size=3, step=3)
    >>> sg.pbs(ds)["stat_pbs"].sel(cohorts_0="co_0", cohorts_1="co_1", cohorts_2="co_2").values # doctest: +NORMALIZE_WHITESPACE
    array([ 0.      , -0.160898])
    """

    ds = define_variable_if_absent(ds, variables.stat_Fst, stat_Fst, Fst)
    variables.validate(ds, {stat_Fst: variables.stat_Fst_spec})

    # Look the variable up under the caller-supplied name (the one validated
    # above), not the default name, so a custom ``stat_Fst`` is honoured.
    fst = ds[stat_Fst]
    # Keep Fst strictly below 1 so that log(1 - fst) stays finite, and
    # non-negative so that T is never negative.
    fst = fst.clip(min=0, max=0.99999)

    t = -np.log(1 - fst)
    # ``Dataset.sizes`` is the name -> length mapping; using ``Dataset.dims``
    # as a mapping is deprecated in recent xarray releases.
    n_cohorts = ds.sizes["cohorts"]
    # NOTE(review): "windows" is read unconditionally here and is the leading
    # output dimension below — confirm un-windowed datasets provide it.
    n_windows = ds.sizes["windows"]
    assert_array_shape(t, n_windows, n_cohorts, n_cohorts)

    # calculate PBS triples
    t = da.asarray(t)
    # Preserve the chunking of the leading axis; each block gains one extra
    # trailing cohorts axis (inserted at position 3).
    shape = (t.chunks[0], n_cohorts, n_cohorts, n_cohorts)
    p = da.map_blocks(_pbs, t, chunks=shape, new_axis=3, dtype=np.float64)
    assert_array_shape(p, n_windows, n_cohorts, n_cohorts, n_cohorts)

    new_ds = Dataset(
        {variables.stat_pbs: (["windows", "cohorts_0", "cohorts_1", "cohorts_2"], p)}
    )
    return conditional_merge_datasets(ds, variables.validate(new_ds), merge)

sgkit/tests/test_popgen.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
create_genotype_call_dataset,
1616
divergence,
1717
diversity,
18+
pbs,
1819
variables,
1920
)
2021
from sgkit.window import window
@@ -328,3 +329,73 @@ def test_Tajimas_D(sample_size):
328329
d = ds.stat_Tajimas_D.compute()
329330
ts_d = ts.Tajimas_D()
330331
np.testing.assert_allclose(d, ts_d)
332+
333+
334+
@pytest.mark.parametrize(
    "sample_size, n_cohorts",
    [(10, 3)],
)
def test_pbs(sample_size, n_cohorts):
    """Compare single-window PBS against scikit-allel's ``allel.pbs``."""
    ts = msprime.simulate(sample_size, length=100, mutation_rate=0.05, random_seed=42)
    subsets = np.array_split(ts.samples(), n_cohorts)
    ds = ts_to_dataset(ts)  # type: ignore[no-untyped-call]
    sample_cohorts = np.concatenate(
        [np.full_like(subset, i) for i, subset in enumerate(subsets)]
    )
    ds["sample_cohort"] = xr.DataArray(sample_cohorts, dims="samples")
    cohort_names = [f"co_{i}" for i in range(n_cohorts)]
    ds = ds.assign_coords({"cohorts_0": cohort_names, "cohorts_1": cohort_names})
    n_variants = ds.sizes["variants"]
    ds = window(ds, size=n_variants, step=n_variants)  # single window

    ds = pbs(ds)
    stat_pbs = ds["stat_pbs"]

    # scikit-allel reference: NaN everywhere except the i < j < k triples.
    ac = ds.cohort_allele_count.values
    ska_pbs_value = np.full([1, n_cohorts, n_cohorts, n_cohorts], np.nan)
    for i, j, k in itertools.combinations(range(n_cohorts), 3):
        # Pick the allele counts of the triple being filled rather than always
        # cohorts 0, 1, 2, so the reference stays right for n_cohorts > 3.
        # Slice-assign because allel.pbs returns one value per window.
        ska_pbs_value[:, i, j, k] = allel.pbs(
            ac[:, i, :],
            ac[:, j, :],
            ac[:, k, :],
            window_size=n_variants,
            window_step=n_variants,
        )

    np.testing.assert_allclose(stat_pbs, ska_pbs_value)
366+
367+
368+
@pytest.mark.parametrize(
    "sample_size, n_cohorts",
    [(10, 3)],
)
@pytest.mark.parametrize("chunks", [(-1, -1), (50, -1)])
def test_pbs__windowed(sample_size, n_cohorts, chunks):
    """Compare multi-window PBS against scikit-allel for several chunkings."""
    window_size = 25  # used for both the sgkit windows and the allel reference

    ts = msprime.simulate(sample_size, length=200, mutation_rate=0.05, random_seed=42)
    subsets = np.array_split(ts.samples(), n_cohorts)
    ds = ts_to_dataset(ts, chunks)  # type: ignore[no-untyped-call]
    sample_cohorts = np.concatenate(
        [np.full_like(subset, i) for i, subset in enumerate(subsets)]
    )
    ds["sample_cohort"] = xr.DataArray(sample_cohorts, dims="samples")
    cohort_names = [f"co_{i}" for i in range(n_cohorts)]
    ds = ds.assign_coords({"cohorts_0": cohort_names, "cohorts_1": cohort_names})
    ds = window(ds, size=window_size, step=window_size)

    ds = pbs(ds)
    stat_pbs = ds["stat_pbs"].values

    # scikit-allel reference
    ac = ds.cohort_allele_count.values

    # scikit-allel has final window missing
    n_windows = ds.sizes["windows"] - 1
    ska_pbs_value = np.full([n_windows, n_cohorts, n_cohorts, n_cohorts], np.nan)
    for i, j, k in itertools.combinations(range(n_cohorts), 3):
        # Pick the allele counts of the triple being filled rather than always
        # cohorts 0, 1, 2, so the reference stays right for n_cohorts > 3.
        ska_pbs_value[:, i, j, k] = allel.pbs(
            ac[:, i, :],
            ac[:, j, :],
            ac[:, k, :],
            window_size=window_size,
            window_step=window_size,
        )

    # Drop sgkit's last window to line up with scikit-allel's output.
    np.testing.assert_allclose(stat_pbs[:-1], ska_pbs_value)

sgkit/variables.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,10 @@ def _check_field(
308308
ArrayLikeSpec("stat_diversity", ndim=2, kind="f")
309309
)
310310
"""Genetic diversity (also known as "Tajima’s pi") for cohorts."""
311+
# Registers the PBS output variable: a 4-dimensional floating-point array.
# `pbs` in sgkit/stats/popgen.py emits it with dims
# (windows, cohorts_0, cohorts_1, cohorts_2).
stat_pbs, stat_pbs_spec = SgkitVariables.register_variable(
    ArrayLikeSpec("stat_pbs", ndim=4, kind="f")
)
"""Population branching statistic for cohort triples."""
311315
stat_Tajimas_D, stat_Tajimas_D_spec = SgkitVariables.register_variable(
312316
ArrayLikeSpec("stat_Tajimas_D", ndim={0, 2}, kind="f")
313317
)

0 commit comments

Comments
 (0)