from typing import Hashable, Optional

import dask.array as da
import numpy as np
import xarray as xr
from numba import njit
from numpy import ndarray
from xarray import Dataset


def hardy_weinberg_p_value(obs_hets: int, obs_hom1: int, obs_hom2: int) -> float:
    """Exact test for HWE as described in Wigginton et al. 2005 [1].

    Parameters
    ----------
    obs_hets : int
        Number of heterozygotes with minor variant.
    obs_hom1 : int
        Number of reference/major homozygotes.
    obs_hom2 : int
        Number of alternate/minor homozygotes.

    Returns
    -------
    float
        P value in [0, 1]

    References
    ----------
    - [1] Wigginton, Janis E., David J. Cutler, and Goncalo R. Abecasis. 2005.
        “A Note on Exact Tests of Hardy-Weinberg Equilibrium.” American Journal of
        Human Genetics 76 (5): 887–93.

    Raises
    ------
    ValueError
        If any observed counts are negative.
    """
    if obs_hom1 < 0 or obs_hom2 < 0 or obs_hets < 0:
        raise ValueError("Observed genotype counts must be non-negative")

    obs_homc = obs_hom2 if obs_hom1 < obs_hom2 else obs_hom1
    obs_homr = obs_hom1 if obs_hom1 < obs_hom2 else obs_hom2
    obs_mac = 2 * obs_homr + obs_hets
    obs_n = obs_hets + obs_homc + obs_homr
    het_probs = np.zeros(obs_mac + 1, dtype=np.float64)

    if obs_n == 0:
        return np.nan  # type: ignore[no-any-return]

    # Identify distribution midpoint
    mid = int(obs_mac * (2 * obs_n - obs_mac) / (2 * obs_n))
    if (obs_mac & 1) ^ (mid & 1):
        mid += 1
    het_probs[mid] = 1.0
    prob_sum = het_probs[mid]

    # Integrate downward from distribution midpoint
    curr_hets = mid
    curr_homr = int((obs_mac - mid) / 2)
    curr_homc = obs_n - curr_hets - curr_homr
    while curr_hets > 1:
        het_probs[curr_hets - 2] = (
            het_probs[curr_hets]
            * curr_hets
            * (curr_hets - 1.0)
            / (4.0 * (curr_homr + 1.0) * (curr_homc + 1.0))
        )
        prob_sum += het_probs[curr_hets - 2]
        curr_homr += 1
        curr_homc += 1
        curr_hets -= 2

    # Integrate upward from distribution midpoint
    curr_hets = mid
    curr_homr = int((obs_mac - mid) / 2)
    curr_homc = obs_n - curr_hets - curr_homr
    while curr_hets <= obs_mac - 2:
        het_probs[curr_hets + 2] = (
            het_probs[curr_hets]
            * 4.0
            * curr_homr
            * curr_homc
            / ((curr_hets + 2.0) * (curr_hets + 1.0))
        )
        prob_sum += het_probs[curr_hets + 2]
        curr_homr -= 1
        curr_homc -= 1
        curr_hets += 2

    if prob_sum <= 0:  # pragma: no cover
        return np.nan  # type: ignore[no-any-return]
    het_probs = het_probs / prob_sum
    p = het_probs[het_probs <= het_probs[obs_hets]].sum()
    p = max(min(1.0, p), 0.0)

    return p  # type: ignore[no-any-return]


# Benchmarks show ~25% improvement w/ fastmath on large (~10M) counts
hardy_weinberg_p_value_jit = njit(hardy_weinberg_p_value, fastmath=True)


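# A minimal usage sketch (illustrative only; the counts below are made-up values
# for a single variant, not from any real dataset). The pure-Python function and
# its JIT-compiled counterpart take the same (hets, hom1, hom2) arguments and
# return a p-value in [0, 1].
def _example_scalar_hwe_p_value() -> None:
    # 57 heterozygotes, 298 reference homozygotes, 14 alternate homozygotes
    p = hardy_weinberg_p_value(57, 298, 14)
    # The fastmath-compiled variant may differ in the last few bits
    p_jit = hardy_weinberg_p_value_jit(57, 298, 14)
    print(p, p_jit)

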
def hardy_weinberg_p_value_vec(
    obs_hets: ndarray, obs_hom1: ndarray, obs_hom2: ndarray
) -> ndarray:
    arrs = [obs_hets, obs_hom1, obs_hom2]
    if len(set(map(len, arrs))) != 1:
        raise ValueError("All arrays must have same length")
    if list(set(map(lambda x: x.ndim, arrs))) != [1]:
        raise ValueError("All arrays must be 1D")
    n = len(obs_hets)
    p = np.empty(n, dtype=np.float64)
    for i in range(n):
        p[i] = hardy_weinberg_p_value_jit(obs_hets[i], obs_hom1[i], obs_hom2[i])
    return p


hardy_weinberg_p_value_vec_jit = njit(hardy_weinberg_p_value_vec, fastmath=True)

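
# A small sketch of the vectorized path (the arrays below are made-up values).
# Each argument is a 1D array with one entry per variant; the result is a
# float64 array of p-values of the same length.
def _example_vectorized_hwe_p_values() -> ndarray:
    obs_hets = np.array([57, 48], dtype=np.int64)
    obs_hom1 = np.array([298, 300], dtype=np.int64)
    obs_hom2 = np.array([14, 12], dtype=np.int64)
    # The plain function and the njit-compiled version accept the same arguments
    return hardy_weinberg_p_value_vec(obs_hets, obs_hom1, obs_hom2)
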

def hardy_weinberg_test(
    ds: Dataset, genotype_counts: Optional[Hashable] = None
) -> Dataset:
    """Exact test for HWE as described in Wigginton et al. 2005 [1].

    Parameters
    ----------
    ds : Dataset
        Dataset containing genotype calls or precomputed genotype counts.
    genotype_counts : Optional[Hashable], optional
        Name of variable containing precomputed genotype counts, by default
        None. If not provided, these counts will be computed automatically
        from genotype calls. If present, must correspond to an (`N`, 3) array
        where `N` is equal to the number of variants and the 3 columns contain
        heterozygous, homozygous reference, and homozygous alternate counts
        (in that order) across all samples for a variant.

    Warnings
    --------
    This function is only applicable to diploid, biallelic datasets.

    Returns
    -------
    Dataset
        Dataset containing (N = num variants):
        variant_hwe_p_value : (N,) ArrayLike
            P values from HWE test for each variant as float in [0, 1].

    References
    ----------
    - [1] Wigginton, Janis E., David J. Cutler, and Goncalo R. Abecasis. 2005.
        “A Note on Exact Tests of Hardy-Weinberg Equilibrium.” American Journal of
        Human Genetics 76 (5): 887–93.

    Raises
    ------
    NotImplementedError
        * If ploidy of provided dataset != 2
        * If maximum number of alleles in provided dataset != 2
    """
    if ds.dims["ploidy"] != 2:
        raise NotImplementedError("HWE test only implemented for diploid genotypes")
    if ds.dims["alleles"] != 2:
        raise NotImplementedError("HWE test only implemented for biallelic genotypes")
    # Use precomputed genotype counts if provided
    if genotype_counts is not None:
        obs = list(da.asarray(ds[genotype_counts]).T)
    # Otherwise compute genotype counts from calls
    else:
        # TODO: Use API genotype counting function instead, e.g.
        # https://github.com/pystatgen/sgkit/issues/29#issuecomment-656691069
        M = ds["call_genotype_mask"].any(dim="ploidy")
        AC = xr.where(M, -1, ds["call_genotype"].sum(dim="ploidy"))  # type: ignore[no-untyped-call]
        cts = [1, 0, 2]  # arg order: hets, hom1, hom2
        obs = [da.asarray((AC == ct).sum(dim="samples")) for ct in cts]
    p = da.map_blocks(hardy_weinberg_p_value_vec_jit, *obs)
    return xr.Dataset({"variant_hwe_p_value": ("variants", p)})
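

# An end-to-end sketch on a tiny synthetic dataset (hand-built here purely for
# illustration; in practice the dataset would come from an sgkit I/O function).
# The `variants`/`samples`/`ploidy`/`alleles` dimensions and the
# `call_genotype`/`call_genotype_mask` variables follow the conventions assumed
# by hardy_weinberg_test above; `variant_allele` is a dummy variable added only
# so the biallelic check on ds.dims["alleles"] passes.
def _example_hardy_weinberg_test() -> Dataset:
    gt = np.array(
        [
            [[0, 0], [0, 1], [1, 1], [0, 1]],  # variant 0
            [[0, 0], [0, 0], [0, 1], [1, 1]],  # variant 1
        ],
        dtype=np.int8,
    )
    ds = xr.Dataset(
        {
            "call_genotype": (("variants", "samples", "ploidy"), gt),
            "call_genotype_mask": (
                ("variants", "samples", "ploidy"),
                np.zeros_like(gt, dtype=bool),
            ),
            "variant_allele": (
                ("variants", "alleles"),
                np.array([["A", "T"], ["C", "G"]]),
            ),
        }
    )
    return hardy_weinberg_test(ds)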