Skip to content

Commit 3c31253

Browse files
committed
Don't use to_haplotype_calls for Garud H
1 parent b372fc2 commit 3c31253

File tree

3 files changed

+47
-45
lines changed

3 files changed

+47
-45
lines changed

sgkit/stats/popgen.py

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,12 @@
66
from numba import guvectorize
77
from xarray import Dataset
88

9-
from sgkit import to_haplotype_calls
109
from sgkit.stats.utils import assert_array_shape
1110
from sgkit.typing import ArrayLike
1211
from sgkit.utils import (
1312
conditional_merge_datasets,
1413
define_variable_if_absent,
15-
hash_columns,
14+
hash_array,
1615
)
1716
from sgkit.window import has_windows, window_statistic
1817

@@ -693,13 +692,13 @@ def pbs(
693692
N_GARUD_H_STATS = 4 # H1, H12, H123, H2/H1
694693

695694

696-
def _Garud_h(k: ArrayLike) -> ArrayLike:
695+
def _Garud_h(haplotypes: ArrayLike) -> ArrayLike:
697696
# find haplotype counts (sorted in descending order)
698-
counts = sorted(collections.Counter(k.tolist()).values(), reverse=True)
697+
counts = sorted(collections.Counter(haplotypes.tolist()).values(), reverse=True)
699698
counts = np.array(counts)
700699

701700
# find haplotype frequencies
702-
n = k.shape[0]
701+
n = haplotypes.shape[0]
703702
f = counts / n
704703

705704
# compute H1
@@ -719,19 +718,20 @@ def _Garud_h(k: ArrayLike) -> ArrayLike:
719718

720719

721720
def _Garud_h_cohorts(
722-
ht: ArrayLike, sample_cohort: ArrayLike, n_cohorts: int
721+
gt: ArrayLike, sample_cohort: ArrayLike, n_cohorts: int
723722
) -> ArrayLike:
724-
k = hash_columns(ht) # hash haplotypes
723+
# transpose to hash columns (haplotypes)
724+
haplotypes = hash_array(gt.transpose()).transpose().flatten()
725725
arr = np.empty((n_cohorts, N_GARUD_H_STATS))
726726
for c in range(n_cohorts):
727-
arr[c, :] = _Garud_h(k[sample_cohort == c])
727+
arr[c, :] = _Garud_h(haplotypes[sample_cohort == c])
728728
return arr
729729

730730

731731
def Garud_h(
732732
ds: Dataset,
733733
*,
734-
call_haplotype: Hashable = variables.call_haplotype,
734+
call_genotype: Hashable = variables.call_genotype,
735735
merge: bool = True,
736736
) -> Dataset:
737737
"""Compute the H1, H12, H123 and H2/H1 statistics for detecting signatures
@@ -745,11 +745,10 @@ def Garud_h(
745745
----------
746746
ds
747747
Genotype call dataset.
748-
call_haplotype
749-
Call haplotype variable to use or calculate. Defined by
750-
:data:`sgkit.variables.call_haplotype_spec`.
751-
If the variable is not present in ``ds``, it will be computed
752-
using :func:`to_haplotype_calls`.
748+
call_genotype
749+
Input variable name holding call_genotype as defined by
750+
:data:`sgkit.variables.call_genotype_spec`.
751+
Must be present in ``ds``.
753752
merge
754753
If True (the default), merge the input dataset and the computed
755754
output variables into a single dataset, otherwise return only
@@ -814,12 +813,9 @@ def Garud_h(
814813
if ds.dims["ploidy"] != 2:
815814
raise NotImplementedError("Garud H only implemented for diploid genotypes")
816815

817-
ds = define_variable_if_absent(
818-
ds, variables.call_haplotype, call_haplotype, to_haplotype_calls
819-
)
820-
variables.validate(ds, {call_haplotype: variables.call_haplotype_spec})
816+
variables.validate(ds, {call_genotype: variables.call_genotype_spec})
821817

822-
ht = ds[call_haplotype]
818+
gt = ds[call_genotype]
823819

824820
# convert sample cohorts to haplotype layout
825821
sc = ds.sample_cohort.values
@@ -828,14 +824,13 @@ def Garud_h(
828824

829825
if has_windows(ds):
830826
gh = window_statistic(
831-
ht,
832-
lambda ht: _Garud_h_cohorts(ht, hsc, n_cohorts),
827+
gt,
828+
lambda gt: _Garud_h_cohorts(gt, hsc, n_cohorts),
833829
ds.window_start.values,
834830
ds.window_stop.values,
835831
dtype=np.float64,
836832
# first chunks dimension is windows, computed in window_statistic
837833
chunks=(-1, n_cohorts, N_GARUD_H_STATS),
838-
new_axis=2, # 2d -> 3d
839834
)
840835
n_windows = ds.window_start.shape[0]
841836
assert_array_shape(gh, n_windows, n_cohorts, N_GARUD_H_STATS)
@@ -861,9 +856,9 @@ def Garud_h(
861856
)
862857
else:
863858
# TODO: note this materializes all the data, so windowless should be discouraged/not supported
864-
ht = ht.values
859+
gt = gt.values
865860

866-
gh = _Garud_h_cohorts(ht, sample_cohort=hsc, n_cohorts=n_cohorts)
861+
gh = _Garud_h_cohorts(gt, sample_cohort=hsc, n_cohorts=n_cohorts)
867862
assert_array_shape(gh, n_cohorts, N_GARUD_H_STATS)
868863

869864
new_ds = Dataset(

sgkit/tests/test_utils.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
check_array_like,
1414
define_variable_if_absent,
1515
encode_array,
16-
hash_columns,
16+
hash_array,
1717
max_str_len,
1818
merge_datasets,
1919
split_array_chunks,
@@ -211,21 +211,21 @@ def test_split_array_chunks__raise_on_n_lte_0():
211211
split_array_chunks(0, 0)
212212

213213

214-
@given(st.integers(1, 50), st.integers(2, 50))
214+
@given(st.integers(2, 50), st.integers(1, 50))
215215
@settings(deadline=None) # avoid problem with numba jit compilation
216-
def test_hash_columns(n_rows, n_cols):
217-
# construct an array with random repeated columns
218-
x = np.random.randint(-2, 10, size=(n_rows, n_cols // 2))
219-
cols = np.random.choice(x.shape[1], n_cols, replace=True)
220-
x = x[:, cols]
216+
def test_hash_array(n_rows, n_cols):
217+
# construct an array with random repeated rows
218+
x = np.random.randint(-2, 10, size=(n_rows // 2, n_cols))
219+
rows = np.random.choice(x.shape[0], n_rows, replace=True)
220+
x = x[rows, :]
221221

222222
# find unique row counts (exact method)
223223
_, expected_inverse, expected_counts = np.unique(
224-
x, axis=1, return_inverse=True, return_counts=True
224+
x, axis=0, return_inverse=True, return_counts=True
225225
)
226226

227227
# hash rows, then find unique row counts using the hash values
228-
h = hash_columns(x)
228+
h = hash_array(x)
229229
_, inverse, counts = np.unique(h, return_inverse=True, return_counts=True)
230230

231231
# counts[inverse] gives the count for each row in x

sgkit/utils.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import warnings
22
from typing import Any, Callable, Hashable, List, Optional, Set, Tuple, Union
33

4-
import numba
54
import numpy as np
5+
from numba import guvectorize
66
from xarray import Dataset
77

88
from .typing import ArrayLike, DType
@@ -274,9 +274,19 @@ def max_str_len(a: ArrayLike) -> ArrayLike:
274274
return lens.max()
275275

276276

277-
@numba.njit(nogil=True, cache=True) # type: ignore
278-
def hash_columns(x: ArrayLike) -> ArrayLike:
279-
"""Hash columns of ``x`` using the DJBX33A hash function.
277+
@guvectorize( # type: ignore
278+
[
279+
"void(int8[:], int64[:])",
280+
"void(int16[:], int64[:])",
281+
"void(int32[:], int64[:])",
282+
"void(int64[:], int64[:])",
283+
],
284+
"(n)->()",
285+
nopython=True,
286+
cache=True,
287+
)
288+
def hash_array(x: ArrayLike, out: ArrayLike) -> None:
289+
"""Hash entries of ``x`` using the DJBX33A hash function.
280290
281291
This is ~5 times faster than calling ``tobytes()`` followed
282292
by ``hash()`` on array values. This function also does not
@@ -286,15 +296,12 @@ def hash_columns(x: ArrayLike) -> ArrayLike:
286296
Parameters
287297
----------
288298
x
289-
Array of shape (m, n) and type integer.
299+
1D array of type integer.
290300
291301
Returns
292302
-------
293-
Array containing hash values of shape (n,) and type int64.
303+
Array containing a single hash value of type int64.
294304
"""
295-
h = np.empty((x.shape[1]), dtype=np.int64)
296-
for j in range(x.shape[1]):
297-
h[j] = 5381
298-
for i in range(x.shape[0]):
299-
h[j] = h[j] * 33 + x[i, j]
300-
return h
305+
out[0] = 5381
306+
for i in range(x.shape[0]):
307+
out[0] = out[0] * 33 + x[i]

0 commit comments

Comments
 (0)