@@ -40,9 +40,11 @@ def test_diversity(size):
40
40
ds = ts_to_dataset (ts ) # type: ignore[no-untyped-call]
41
41
sample_cohorts = np .full_like (ts .samples (), 0 )
42
42
ds ["sample_cohort" ] = xr .DataArray (sample_cohorts , dims = "samples" )
43
- div = diversity (ds )["stat_diversity" ].compute ()
43
+ ds = ds .assign_coords ({"cohorts" : ["co_0" ]})
44
+ ds = diversity (ds )
45
+ div = ds ["stat_diversity" ].sel (cohorts = "co_0" ).values
44
46
ts_div = ts .diversity (span_normalise = False )
45
- np .testing .assert_allclose (div [ 0 ] , ts_div )
47
+ np .testing .assert_allclose (div , ts_div )
46
48
47
49
48
50
@pytest .mark .parametrize ("size" , [2 , 3 , 10 , 100 ])
@@ -55,7 +57,10 @@ def test_divergence(size):
55
57
(np .full_like (subset_1 , 0 ), np .full_like (subset_2 , 1 ))
56
58
)
57
59
ds ["sample_cohort" ] = xr .DataArray (sample_cohorts , dims = "samples" )
58
- div = divergence (ds )["stat_divergence" ].compute ()
60
+ cohort_names = ["co_0" , "co_1" ]
61
+ ds = ds .assign_coords ({"cohorts_a" : cohort_names , "cohorts_b" : cohort_names })
62
+ ds = divergence (ds )
63
+ div = ds ["stat_divergence" ].sel (cohorts_a = "co_0" , cohorts_b = "co_1" ).values
59
64
ts_div = ts .divergence ([subset_1 , subset_2 ], span_normalise = False )
60
65
np .testing .assert_allclose (div , ts_div )
61
66
@@ -70,7 +75,10 @@ def test_Fst(size):
70
75
(np .full_like (subset_1 , 0 ), np .full_like (subset_2 , 1 ))
71
76
)
72
77
ds ["sample_cohort" ] = xr .DataArray (sample_cohorts , dims = "samples" )
73
- fst = Fst (ds )["stat_Fst" ].compute ()
78
+ cohort_names = ["co_0" , "co_1" ]
79
+ ds = ds .assign_coords ({"cohorts_a" : cohort_names , "cohorts_b" : cohort_names })
80
+ ds = Fst (ds )
81
+ fst = ds ["stat_Fst" ].sel (cohorts_a = "co_0" , cohorts_b = "co_1" ).values
74
82
ts_fst = ts .Fst ([subset_1 , subset_2 ])
75
83
np .testing .assert_allclose (fst , ts_fst )
76
84
@@ -81,6 +89,7 @@ def test_Tajimas_D(size):
81
89
ds = ts_to_dataset (ts ) # type: ignore[no-untyped-call]
82
90
sample_cohorts = np .full_like (ts .samples (), 0 )
83
91
ds ["sample_cohort" ] = xr .DataArray (sample_cohorts , dims = "samples" )
92
+ ds = Tajimas_D (ds )
93
+ d = ds ["stat_Tajimas_D" ].compute ()
84
94
ts_d = ts .Tajimas_D ()
85
- d = Tajimas_D (ds )["stat_Tajimas_D" ].compute ()
86
95
np .testing .assert_allclose (d , ts_d )
0 commit comments