Skip to content

Cohort grouping for popgen #260

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Oct 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions docs/getting_started.rst
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ Sgkit functions are compatible with this idiom by default and this example shows
Xarray and Pandas operations in a single pipeline:

.. ipython:: python
:okwarning:

import sgkit as sg
ds = sg.simulate_genotype_call_dataset(n_variant=100, n_sample=50, missing_pct=.1)
Expand All @@ -276,10 +277,9 @@ Xarray and Pandas operations in a single pipeline:
# Assign a "cohort" variable that splits samples into two groups
.assign(sample_cohort=np.repeat([0, 1], ds.dims['samples'] // 2))
# Compute Fst between the groups
# TODO: Refactor based on https://github.com/pystatgen/sgkit/pull/260
.pipe(lambda ds: sg.Fst(*(g[1] for g in ds.groupby('sample_cohort'))))
# Extract the single Fst value from the resulting array
.item(0)
.pipe(sg.Fst)
# Extract the Fst values for cohort pairs
.stat_Fst.values
)

This is possible because sgkit functions nearly always take a ``Dataset`` as the first argument, create new
Expand Down
80 changes: 79 additions & 1 deletion sgkit/stats/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing_extensions import Literal
from xarray import Dataset

from sgkit.stats.utils import assert_array_shape
from sgkit.typing import ArrayLike
from sgkit.utils import conditional_merge_datasets

Expand Down Expand Up @@ -51,6 +52,38 @@ def count_alleles(g: ArrayLike, _: ArrayLike, out: ArrayLike) -> None:
out[a] += 1


# n = samples, c = cohorts, k = alleles
@guvectorize( # type: ignore
[
"void(uint8[:, :], int32[:], uint8[:], int32[:,:])",
"void(uint8[:, :], int64[:], uint8[:], int32[:,:])",
],
"(n, k),(n),(c)->(c,k)",
nopython=True,
)
def _count_cohort_alleles(
ac: ArrayLike, cohorts: ArrayLike, _: ArrayLike, out: ArrayLike
) -> None:
"""Generalized U-function for computing per cohort allele counts.

Parameters
----------
ac
Allele counts of shape (samples, alleles) containing per-sample allele counts.
cohorts
Cohort indexes for samples of shape (samples,).
_
Dummy variable of type `uint8` and shape (cohorts,) used to
define the number of cohorts.
out
Allele counts with shape (cohorts, alleles) and values corresponding to
the number of non-missing occurrences of each allele in each cohort.
"""
out[:, :] = 0 # (cohorts, alleles)
for i in range(ac.shape[0]):
out[cohorts[i]] += ac[i]


def count_call_alleles(ds: Dataset, merge: bool = True) -> Dataset:
"""Compute per sample allele counts from genotype calls.

Expand All @@ -60,7 +93,6 @@ def count_call_alleles(ds: Dataset, merge: bool = True) -> Dataset:
Genotype call dataset such as from
`sgkit.create_genotype_call_dataset`.
merge
(optional)
If True (the default), merge the input dataset and the computed
output variables into a single dataset, otherwise return only
the computed output variables.
Expand Down Expand Up @@ -167,6 +199,52 @@ def count_variant_alleles(ds: Dataset, merge: bool = True) -> Dataset:
return conditional_merge_datasets(ds, new_ds, merge)


def count_cohort_alleles(ds: Dataset, merge: bool = True) -> Dataset:
"""Compute per cohort allele counts from genotype calls.

Parameters
----------
ds
Genotype call dataset such as from
`sgkit.create_genotype_call_dataset`.
merge
If True (the default), merge the input dataset and the computed
output variables into a single dataset, otherwise return only
the computed output variables.
See :ref:`dataset_merge` for more details.

Returns
-------
Dataset containing variable `call_allele_count` of allele counts with
shape (variants, cohorts, alleles) and values corresponding to
the number of non-missing occurrences of each allele.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if this shouldn't also return a (variants, cohorts) array that contains the number of present samples used in each aggregation. It doesn't do it now, but if count_call_alleles returned -1 for all alleles when any (or potentially all) input calls were missing then I can see a count of the samples with non-negative allele counts being a useful downstream denominator.

Is that not important with the stats functions in popgen.py now? It may make sense to put a warning on those functions and log an issue related to supporting it. It seems like only ignoring the missing values during the original allele count must introduce some kind of bias through all those aggregations.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I haven't thought about missingness with respect to the meaning of the functions though. Opened #290

"""

n_variants = ds.dims["variants"]
n_alleles = ds.dims["alleles"]

ds = count_call_alleles(ds)
AC, SC = da.asarray(ds.call_allele_count), da.asarray(ds.sample_cohort)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we standardize on using ds. call_allele_count vs ds['call_allele_count']? I've been trying to use the former in examples only and stick to the latter in code, but not for great reasons other than consistency and hedging against potentially having variable names that aren't legal attributes or conflict with something in the xarray namespace. Presumably not using attributes becomes easier to standardize on with https://github.com/pystatgen/sgkit/pull/276.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure. It's certainly shorter to say ds.call_allele_count, but this form can't be used in assignments:

ds.call_allele_count = ...

I've changed the accessors in this PR, but we should probably discuss a bit more about our code standards, and what we recommend for user code.

n_cohorts = SC.max().compute() + 1 # 0-based indexing
C = da.empty(n_cohorts, dtype=np.uint8)

G = da.asarray(ds.call_genotype)
shape = (G.chunks[0], n_cohorts, n_alleles)

AC = da.map_blocks(_count_cohort_alleles, AC, SC, C, chunks=shape, dtype=np.int32)
assert_array_shape(
AC, n_variants, n_cohorts * AC.numblocks[1], n_alleles * AC.numblocks[2]
)

# Stack the blocks and sum across them
# (which will only work because each chunk is guaranteed to have same size)
AC = da.stack([AC.blocks[:, i] for i in range(AC.numblocks[1])]).sum(axis=0)
assert_array_shape(AC, n_variants, n_cohorts, n_alleles)

new_ds = Dataset({"cohort_allele_count": (("variants", "cohorts", "alleles"), AC)})
return conditional_merge_datasets(ds, new_ds, merge)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tomwhite A quick question, I am trying to understand what this function (count_cohort_alleles) is trying to do mathematically, is there a reference implementation or formula, which I can look at to understand this better? Or is it just a an attempt to do a better rewrite of an old implementation?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@aktech Good question (and this function definitely needs documentation!). It's needed because count_call_alleles and count_variant_alleles return allele counts for each sample and for all samples (respectively), whereas in this PR we need something between the two since we now have the concept of "cohort" (a set of samples): so we need per-cohort counts.

The counts array computed here is of shape (n_variants, n_cohorts, n_alleles), so we have a count for each variant/cohort/allele combination. Have a look at the tests (e.g. test_count_cohort_alleles__multi_variant_multi_sample), and perhaps try calling the allele counting functions in a Python repl to get a feel for what they are doing.



def _swap(dim: Dimension) -> Dimension:
return "samples" if dim == "variants" else "variants"

Expand Down
Loading