Skip to content

Commit 606b870

Browse files
committed
Cohort coordinates
1 parent 364c52e commit 606b870

File tree

2 files changed

+40
-11
lines changed

2 files changed

+40
-11
lines changed

sgkit/stats/popgen.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
import itertools
12
from typing import Hashable
23

34
import dask.array as da
45
import numpy as np
5-
from xarray import Dataset
6+
from xarray import DataArray, Dataset
67

8+
from sgkit.stats.utils import assert_array_shape
79
from sgkit.utils import merge_datasets
810

911
from .aggregation import count_cohort_alleles, count_variant_alleles
@@ -71,11 +73,29 @@ def divergence(
7173
ac = ds[allele_counts]
7274
an = ac.sum(axis=2)
7375

74-
n_pairs = np.prod(an, axis=1).compute()
75-
n_same = np.prod(ac, axis=1).sum(axis=1).compute()
76-
n_diff = n_pairs - n_same
77-
div = n_diff / n_pairs
78-
new_ds = Dataset({"stat_divergence": div.sum()})
76+
n_variants = ds.dims["variants"]
77+
n_alleles = ds.dims["alleles"]
78+
n_cohorts = ds.dims["cohorts"]
79+
result = np.full([n_cohorts, n_cohorts], np.nan)
80+
81+
# Iterate over cohort pairs
82+
for i, j in itertools.combinations(range(n_cohorts), 2):
83+
an_cohort_pair = an[:, [i, j]]
84+
assert_array_shape(an_cohort_pair, n_variants, 2)
85+
ac_cohort_pair = ac[:, [i, j], :]
86+
assert_array_shape(ac_cohort_pair, n_variants, 2, n_alleles)
87+
88+
n_pairs = np.prod(an_cohort_pair, axis=1).compute()
89+
n_same = np.prod(ac_cohort_pair, axis=1).sum(axis=1).compute()
90+
91+
n_diff = n_pairs - n_same
92+
div = n_diff / n_pairs
93+
div_sum = div.sum().compute() # TODO: avoid this compute
94+
95+
result[i, j] = div_sum
96+
97+
arr = DataArray(result, dims=["cohorts_a", "cohorts_b"])
98+
new_ds = Dataset({"stat_divergence": arr})
7999
return merge_datasets(ds, new_ds) if merge else new_ds
80100

81101

sgkit/tests/test_popgen.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,11 @@ def test_diversity(size):
4040
ds = ts_to_dataset(ts) # type: ignore[no-untyped-call]
4141
sample_cohorts = np.full_like(ts.samples(), 0)
4242
ds["sample_cohort"] = xr.DataArray(sample_cohorts, dims="samples")
43-
div = diversity(ds)["stat_diversity"].compute()
43+
ds = ds.assign_coords({"cohorts": ["co_0"]})
44+
ds = diversity(ds)
45+
div = ds["stat_diversity"].sel(cohorts="co_0").values
4446
ts_div = ts.diversity(span_normalise=False)
45-
np.testing.assert_allclose(div[0], ts_div)
47+
np.testing.assert_allclose(div, ts_div)
4648

4749

4850
@pytest.mark.parametrize("size", [2, 3, 10, 100])
@@ -55,7 +57,10 @@ def test_divergence(size):
5557
(np.full_like(subset_1, 0), np.full_like(subset_2, 1))
5658
)
5759
ds["sample_cohort"] = xr.DataArray(sample_cohorts, dims="samples")
58-
div = divergence(ds)["stat_divergence"].compute()
60+
cohort_names = ["co_0", "co_1"]
61+
ds = ds.assign_coords({"cohorts_a": cohort_names, "cohorts_b": cohort_names})
62+
ds = divergence(ds)
63+
div = ds["stat_divergence"].sel(cohorts_a="co_0", cohorts_b="co_1").values
5964
ts_div = ts.divergence([subset_1, subset_2], span_normalise=False)
6065
np.testing.assert_allclose(div, ts_div)
6166

@@ -70,7 +75,10 @@ def test_Fst(size):
7075
(np.full_like(subset_1, 0), np.full_like(subset_2, 1))
7176
)
7277
ds["sample_cohort"] = xr.DataArray(sample_cohorts, dims="samples")
73-
fst = Fst(ds)["stat_Fst"].compute()
78+
cohort_names = ["co_0", "co_1"]
79+
ds = ds.assign_coords({"cohorts_a": cohort_names, "cohorts_b": cohort_names})
80+
ds = Fst(ds)
81+
fst = ds["stat_Fst"].sel(cohorts_a="co_0", cohorts_b="co_1").values
7482
ts_fst = ts.Fst([subset_1, subset_2])
7583
np.testing.assert_allclose(fst, ts_fst)
7684

@@ -81,6 +89,7 @@ def test_Tajimas_D(size):
8189
ds = ts_to_dataset(ts) # type: ignore[no-untyped-call]
8290
sample_cohorts = np.full_like(ts.samples(), 0)
8391
ds["sample_cohort"] = xr.DataArray(sample_cohorts, dims="samples")
92+
ds = Tajimas_D(ds)
93+
d = ds["stat_Tajimas_D"].compute()
8494
ts_d = ts.Tajimas_D()
85-
d = Tajimas_D(ds)["stat_Tajimas_D"].compute()
8695
np.testing.assert_allclose(d, ts_d)

0 commit comments

Comments
 (0)