Skip to content

Commit 3b5194f

Browse files
committed
Garud H should only support windowed datasets
1 parent 3c31253 commit 3b5194f

File tree

2 files changed

+39
-68
lines changed

2 files changed

+39
-68
lines changed

sgkit/stats/popgen.py

Lines changed: 34 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -813,6 +813,9 @@ def Garud_h(
813813
if ds.dims["ploidy"] != 2:
814814
raise NotImplementedError("Garud H only implemented for diploid genotypes")
815815

816+
if not has_windows(ds):
817+
raise ValueError("Dataset must be windowed for Garud_h")
818+
816819
variables.validate(ds, {call_genotype: variables.call_genotype_spec})
817820

818821
gt = ds[call_genotype]
@@ -822,51 +825,36 @@ def Garud_h(
822825
hsc = np.stack((sc, sc), axis=1).ravel() # TODO: assumes diploid
823826
n_cohorts = sc.max() + 1 # 0-based indexing
824827

825-
if has_windows(ds):
826-
gh = window_statistic(
827-
gt,
828-
lambda gt: _Garud_h_cohorts(gt, hsc, n_cohorts),
829-
ds.window_start.values,
830-
ds.window_stop.values,
831-
dtype=np.float64,
832-
# first chunks dimension is windows, computed in window_statistic
833-
chunks=(-1, n_cohorts, N_GARUD_H_STATS),
834-
)
835-
n_windows = ds.window_start.shape[0]
836-
assert_array_shape(gh, n_windows, n_cohorts, N_GARUD_H_STATS)
837-
new_ds = Dataset(
838-
{
839-
variables.stat_Garud_h1: (
840-
("windows", "cohorts"),
841-
gh[:, :, 0],
842-
),
843-
variables.stat_Garud_h12: (
844-
("windows", "cohorts"),
845-
gh[:, :, 1],
846-
),
847-
variables.stat_Garud_h123: (
848-
("windows", "cohorts"),
849-
gh[:, :, 2],
850-
),
851-
variables.stat_Garud_h2_h1: (
852-
("windows", "cohorts"),
853-
gh[:, :, 3],
854-
),
855-
}
856-
)
857-
else:
858-
# TODO: note this materializes all the data, so windowless should be discouraged/not supported
859-
gt = gt.values
860-
861-
gh = _Garud_h_cohorts(gt, sample_cohort=hsc, n_cohorts=n_cohorts)
862-
assert_array_shape(gh, n_cohorts, N_GARUD_H_STATS)
828+
gh = window_statistic(
829+
gt,
830+
lambda gt: _Garud_h_cohorts(gt, hsc, n_cohorts),
831+
ds.window_start.values,
832+
ds.window_stop.values,
833+
dtype=np.float64,
834+
# first chunks dimension is windows, computed in window_statistic
835+
chunks=(-1, n_cohorts, N_GARUD_H_STATS),
836+
)
837+
n_windows = ds.window_start.shape[0]
838+
assert_array_shape(gh, n_windows, n_cohorts, N_GARUD_H_STATS)
839+
new_ds = Dataset(
840+
{
841+
variables.stat_Garud_h1: (
842+
("windows", "cohorts"),
843+
gh[:, :, 0],
844+
),
845+
variables.stat_Garud_h12: (
846+
("windows", "cohorts"),
847+
gh[:, :, 1],
848+
),
849+
variables.stat_Garud_h123: (
850+
("windows", "cohorts"),
851+
gh[:, :, 2],
852+
),
853+
variables.stat_Garud_h2_h1: (
854+
("windows", "cohorts"),
855+
gh[:, :, 3],
856+
),
857+
}
858+
)
863859

864-
new_ds = Dataset(
865-
{
866-
variables.stat_Garud_h1: gh[:, 0],
867-
variables.stat_Garud_h12: gh[:, 1],
868-
variables.stat_Garud_h123: gh[:, 2],
869-
variables.stat_Garud_h2_h1: gh[:, 3],
870-
}
871-
)
872860
return conditional_merge_datasets(ds, variables.validate(new_ds), merge)

sgkit/tests/test_popgen.py

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -405,46 +405,29 @@ def test_pbs__windowed(sample_size, n_cohorts, chunks):
405405

406406
@pytest.mark.parametrize(
407407
"n_variants, n_samples, n_contigs, n_cohorts",
408-
[(3, 5, 1, 1), (3, 5, 1, 2)],
408+
[(3, 5, 1, 1)],
409409
)
410-
@pytest.mark.parametrize("chunks", [(-1, -1), (2, -1)])
411-
def test_Garud_h(n_variants, n_samples, n_contigs, n_cohorts, chunks):
410+
def test_Garud_h__no_windows(n_variants, n_samples, n_contigs, n_cohorts):
412411
# We can't use msprime since it doesn't generate diploid data, and Garud uses phased data
413412
ds = simulate_genotype_call_dataset(
414413
n_variant=n_variants, n_sample=n_samples, n_contig=n_contigs
415414
)
416-
ds = ds.chunk(dict(zip(["variants", "samples"], chunks)))
417415
subsets = np.array_split(ds.samples.values, n_cohorts)
418416
sample_cohorts = np.concatenate(
419417
[np.full_like(subset, i) for i, subset in enumerate(subsets)]
420418
)
421419
ds["sample_cohort"] = xr.DataArray(sample_cohorts, dims="samples")
422420

423-
gh = Garud_h(ds)
424-
h1 = gh.stat_Garud_h1.values
425-
h12 = gh.stat_Garud_h12.values
426-
h123 = gh.stat_Garud_h123.values
427-
h2_h1 = gh.stat_Garud_h2_h1.values
428-
429-
# scikit-allel
430-
for c in range(n_cohorts):
431-
gt = ds.call_genotype.values[:, sample_cohorts == c, :]
432-
ska_gt = allel.GenotypeArray(gt)
433-
ska_ha = ska_gt.to_haplotypes()
434-
ska_h = allel.garud_h(ska_ha)
435-
436-
np.testing.assert_allclose(h1[c], ska_h[0])
437-
np.testing.assert_allclose(h12[c], ska_h[1])
438-
np.testing.assert_allclose(h123[c], ska_h[2])
439-
np.testing.assert_allclose(h2_h1[c], ska_h[3])
421+
with pytest.raises(ValueError, match="Dataset must be windowed for Garud_h"):
422+
Garud_h(ds)
440423

441424

442425
@pytest.mark.parametrize(
443426
"n_variants, n_samples, n_contigs, n_cohorts",
444427
[(9, 5, 1, 1), (9, 5, 1, 2)],
445428
)
446429
@pytest.mark.parametrize("chunks", [(-1, -1), (5, -1)])
447-
def test_Garud_h__windowed(n_variants, n_samples, n_contigs, n_cohorts, chunks):
430+
def test_Garud_h(n_variants, n_samples, n_contigs, n_cohorts, chunks):
448431
ds = simulate_genotype_call_dataset(
449432
n_variant=n_variants, n_sample=n_samples, n_contig=n_contigs
450433
)

0 commit comments

Comments
 (0)