diff --git a/sgkit/__init__.py b/sgkit/__init__.py
index 55736fdd8..dc88e8819 100644
--- a/sgkit/__init__.py
+++ b/sgkit/__init__.py
@@ -18,7 +18,7 @@
 from .stats.hwe import hardy_weinberg_test
 from .stats.pc_relate import pc_relate
 from .stats.pca import pca
-from .stats.popgen import Fst, Tajimas_D, divergence, diversity
+from .stats.popgen import Fst, Tajimas_D, divergence, diversity, pbs
 from .stats.preprocessing import filter_partial_calls
 from .stats.regenie import regenie
 from .testing import simulate_genotype_call_dataset
@@ -45,6 +45,7 @@
     "divergence",
     "Fst",
     "Tajimas_D",
+    "pbs",
     "pc_relate",
     "simulate_genotype_call_dataset",
     "variables",
diff --git a/sgkit/stats/popgen.py b/sgkit/stats/popgen.py
index 245b31e54..5e49b2a30 100644
--- a/sgkit/stats/popgen.py
+++ b/sgkit/stats/popgen.py
@@ -578,3 +578,107 @@ def Tajimas_D(
 
     new_ds = Dataset({variables.stat_Tajimas_D: D})
     return conditional_merge_datasets(ds, variables.validate(new_ds), merge)
+
+
+# c = cohorts
+@guvectorize(  # type: ignore
+    ["void(float32[:, :], float32[:,:,:])", "void(float64[:, :], float64[:,:,:])"],
+    "(c,c)->(c,c,c)",
+    nopython=True,
+    cache=True,
+)
+def _pbs(t: ArrayLike, out: ArrayLike) -> None:
+    """Generalized U-function for computing PBS."""
+    out[:, :, :] = np.nan  # (cohorts, cohorts, cohorts)
+    n_cohorts = t.shape[0]
+    # calculate PBS for each cohort triple
+    for i in range(n_cohorts):
+        for j in range(i + 1, n_cohorts):
+            for k in range(j + 1, n_cohorts):
+                ret = (t[i, j] + t[i, k] - t[j, k]) / 2
+                norm = 1 + (t[i, j] + t[i, k] + t[j, k]) / 2
+                ret = ret / norm
+                out[i, j, k] = ret
+
+
+def pbs(
+    ds: Dataset,
+    *,
+    stat_Fst: Hashable = variables.stat_Fst,
+    merge: bool = True,
+) -> Dataset:
+    """Compute the population branching statistic (PBS) between cohort triples.
+
+    By default, values of this statistic are calculated per variant.
+    To compute values in windows, call :func:`window` before calling
+    this function.
+
+    Parameters
+    ----------
+    ds
+        Genotype call dataset.
+    stat_Fst
+        Fst variable to use or calculate. Defined by
+        :data:`sgkit.variables.stat_Fst_spec`.
+        If the variable is not present in ``ds``, it will be computed
+        using :func:`Fst`.
+    merge
+        If True (the default), merge the input dataset and the computed
+        output variables into a single dataset, otherwise return only
+        the computed output variables.
+        See :ref:`dataset_merge` for more details.
+
+    Returns
+    -------
+    A dataset containing the PBS value between cohort triples, as defined by
+    :data:`sgkit.variables.stat_pbs_spec`.
+    Shape (variants, cohorts, cohorts, cohorts), or
+    (windows, cohorts, cohorts, cohorts) if windowing information is available.
+
+    Warnings
+    --------
+    This method does not currently support datasets that are chunked along the
+    samples dimension.
+
+    Examples
+    --------
+
+    >>> import numpy as np
+    >>> import sgkit as sg
+    >>> import xarray as xr
+    >>> ds = sg.simulate_genotype_call_dataset(n_variant=5, n_sample=6)
+
+    >>> # Divide samples into three named cohorts
+    >>> n_cohorts = 3
+    >>> sample_cohort = np.repeat(range(n_cohorts), ds.dims["samples"] // n_cohorts)
+    >>> ds["sample_cohort"] = xr.DataArray(sample_cohort, dims="samples")
+    >>> cohort_names = [f"co_{i}" for i in range(n_cohorts)]
+    >>> ds = ds.assign_coords({"cohorts_0": cohort_names, "cohorts_1": cohort_names, "cohorts_2": cohort_names})
+
+    >>> # Divide into two windows of size three (variants)
+    >>> ds = sg.window(ds, size=3, step=3)
+    >>> sg.pbs(ds)["stat_pbs"].sel(cohorts_0="co_0", cohorts_1="co_1", cohorts_2="co_2").values # doctest: +NORMALIZE_WHITESPACE
+    array([ 0.      , -0.160898])
+    """
+
+    ds = define_variable_if_absent(ds, variables.stat_Fst, stat_Fst, Fst)
+    variables.validate(ds, {stat_Fst: variables.stat_Fst_spec})
+
+    fst = ds[variables.stat_Fst]
+    fst = fst.clip(min=0, max=(1 - np.finfo(float).epsneg))
+
+    t = -np.log(1 - fst)
+    n_cohorts = ds.dims["cohorts"]
+    n_windows = ds.dims["windows"]
+    assert_array_shape(t, n_windows, n_cohorts, n_cohorts)
+
+    # calculate PBS triples
+    t = da.asarray(t)
+    shape = (t.chunks[0], n_cohorts, n_cohorts, n_cohorts)
+    p = da.map_blocks(_pbs, t, chunks=shape, new_axis=3, dtype=np.float64)
+    assert_array_shape(p, n_windows, n_cohorts, n_cohorts, n_cohorts)
+
+    new_ds = Dataset(
+        {variables.stat_pbs: (["windows", "cohorts_0", "cohorts_1", "cohorts_2"], p)}
+    )
+    return conditional_merge_datasets(ds, variables.validate(new_ds), merge)
diff --git a/sgkit/tests/test_popgen.py b/sgkit/tests/test_popgen.py
index ca70506a8..8170945d5 100644
--- a/sgkit/tests/test_popgen.py
+++ b/sgkit/tests/test_popgen.py
@@ -15,6 +15,7 @@
     create_genotype_call_dataset,
     divergence,
     diversity,
+    pbs,
     variables,
 )
 from sgkit.window import window
@@ -328,3 +329,73 @@ def test_Tajimas_D(sample_size):
     d = ds.stat_Tajimas_D.compute()
     ts_d = ts.Tajimas_D()
     np.testing.assert_allclose(d, ts_d)
+
+
+@pytest.mark.parametrize(
+    "sample_size, n_cohorts",
+    [(10, 3)],
+)
+def test_pbs(sample_size, n_cohorts):
+    ts = msprime.simulate(sample_size, length=100, mutation_rate=0.05, random_seed=42)
+    subsets = np.array_split(ts.samples(), n_cohorts)
+    ds = ts_to_dataset(ts)  # type: ignore[no-untyped-call]
+    sample_cohorts = np.concatenate(
+        [np.full_like(subset, i) for i, subset in enumerate(subsets)]
+    )
+    ds["sample_cohort"] = xr.DataArray(sample_cohorts, dims="samples")
+    cohort_names = [f"co_{i}" for i in range(n_cohorts)]
+    ds = ds.assign_coords({"cohorts_0": cohort_names, "cohorts_1": cohort_names})
+    n_variants = ds.dims["variants"]
+    ds = window(ds, size=n_variants, step=n_variants)  # single window
+
+    ds = pbs(ds)
+    stat_pbs = ds["stat_pbs"]
+
+    # scikit-allel
+    ac1 = ds.cohort_allele_count.values[:, 0, :]
+    ac2 = ds.cohort_allele_count.values[:, 1, :]
+    ac3 = ds.cohort_allele_count.values[:, 2, :]
+
+    ska_pbs_value = np.full([1, n_cohorts, n_cohorts, n_cohorts], np.nan)
+    for i, j, k in itertools.combinations(range(n_cohorts), 3):
+        ska_pbs_value[0, i, j, k] = allel.pbs(
+            ac1, ac2, ac3, window_size=n_variants, window_step=n_variants
+        )
+
+    np.testing.assert_allclose(stat_pbs, ska_pbs_value)
+
+
+@pytest.mark.parametrize(
+    "sample_size, n_cohorts",
+    [(10, 3)],
+)
+@pytest.mark.parametrize("chunks", [(-1, -1), (50, -1)])
+def test_pbs__windowed(sample_size, n_cohorts, chunks):
+    ts = msprime.simulate(sample_size, length=200, mutation_rate=0.05, random_seed=42)
+    subsets = np.array_split(ts.samples(), n_cohorts)
+    ds = ts_to_dataset(ts, chunks)  # type: ignore[no-untyped-call]
+    sample_cohorts = np.concatenate(
+        [np.full_like(subset, i) for i, subset in enumerate(subsets)]
+    )
+    ds["sample_cohort"] = xr.DataArray(sample_cohorts, dims="samples")
+    cohort_names = [f"co_{i}" for i in range(n_cohorts)]
+    ds = ds.assign_coords({"cohorts_0": cohort_names, "cohorts_1": cohort_names})
+    ds = window(ds, size=25, step=25)
+
+    ds = pbs(ds)
+    stat_pbs = ds["stat_pbs"].values
+
+    # scikit-allel
+    ac1 = ds.cohort_allele_count.values[:, 0, :]
+    ac2 = ds.cohort_allele_count.values[:, 1, :]
+    ac3 = ds.cohort_allele_count.values[:, 2, :]
+
+    # scikit-allel has final window missing
+    n_windows = ds.dims["windows"] - 1
+    ska_pbs_value = np.full([n_windows, n_cohorts, n_cohorts, n_cohorts], np.nan)
+    for i, j, k in itertools.combinations(range(n_cohorts), 3):
+        ska_pbs_value[:, i, j, k] = allel.pbs(
+            ac1, ac2, ac3, window_size=25, window_step=25
+        )
+
+    np.testing.assert_allclose(stat_pbs[:-1], ska_pbs_value)
diff --git a/sgkit/variables.py b/sgkit/variables.py
index e07144d86..03e759ad1 100644
--- a/sgkit/variables.py
+++ b/sgkit/variables.py
@@ -308,6 +308,10 @@ def _check_field(
     ArrayLikeSpec("stat_diversity", ndim=2, kind="f")
 )
 """Genetic diversity (also known as "Tajima’s pi") for cohorts."""
+stat_pbs, stat_pbs_spec = SgkitVariables.register_variable(
+    ArrayLikeSpec("stat_pbs", ndim=4, kind="f")
+)
+"""Population branching statistic for cohort triples."""
 stat_Tajimas_D, stat_Tajimas_D_spec = SgkitVariables.register_variable(
     ArrayLikeSpec("stat_Tajimas_D", ndim={0, 2}, kind="f")
 )
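
Note on the arithmetic in _pbs: pbs() first converts pairwise Fst to divergence times t = -log(1 - Fst), and the kernel then computes, for each cohort triple i < j < k, the branch length of cohort i normalised by the overall tree depth. A minimal standalone sketch of that per-triple calculation is below; the function name and inputs are illustrative only and are not part of this patch.

import numpy as np

def pbs_triple(fst_ij: float, fst_ik: float, fst_jk: float) -> float:
    """Illustrative restatement of the per-triple arithmetic in _pbs."""
    # pbs() converts pairwise Fst values to divergence times before
    # handing them to the _pbs kernel.
    t_ij, t_ik, t_jk = (-np.log(1.0 - f) for f in (fst_ij, fst_ik, fst_jk))
    # Branch length of cohort i, with the same normalisation as _pbs
    # (norm = 1 + total divergence / 2).
    return ((t_ij + t_ik - t_jk) / 2) / (1 + (t_ij + t_ik + t_jk) / 2)

# Example: cohort i is more differentiated from j and k than j and k are
# from each other, so its branch length (PBS) is positive.
print(pbs_triple(0.2, 0.15, 0.05))

As in the vectorised kernel, only entries with i < j < k are populated in stat_pbs; the remaining positions of the (windows, cohorts, cohorts, cohorts) array stay NaN, which is why the tests build their scikit-allel reference with np.full(..., np.nan) before comparing.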