@@ -40,9 +40,11 @@ def test_diversity(size):
4040 ds = ts_to_dataset (ts ) # type: ignore[no-untyped-call]
4141 sample_cohorts = np .full_like (ts .samples (), 0 )
4242 ds ["sample_cohort" ] = xr .DataArray (sample_cohorts , dims = "samples" )
43- div = diversity (ds )["stat_diversity" ].compute ()
43+ ds = ds .assign_coords ({"cohorts" : ["co_0" ]})
44+ ds = diversity (ds )
45+ div = ds ["stat_diversity" ].sel (cohorts = "co_0" ).values
4446 ts_div = ts .diversity (span_normalise = False )
45- np .testing .assert_allclose (div [ 0 ] , ts_div )
47+ np .testing .assert_allclose (div , ts_div )
4648
4749
4850@pytest .mark .parametrize ("size" , [2 , 3 , 10 , 100 ])
@@ -55,7 +57,10 @@ def test_divergence(size):
5557 (np .full_like (subset_1 , 0 ), np .full_like (subset_2 , 1 ))
5658 )
5759 ds ["sample_cohort" ] = xr .DataArray (sample_cohorts , dims = "samples" )
58- div = divergence (ds )["stat_divergence" ].compute ()
60+ cohort_names = ["co_0" , "co_1" ]
61+ ds = ds .assign_coords ({"cohorts_a" : cohort_names , "cohorts_b" : cohort_names })
62+ ds = divergence (ds )
63+ div = ds ["stat_divergence" ].sel (cohorts_a = "co_0" , cohorts_b = "co_1" ).values
5964 ts_div = ts .divergence ([subset_1 , subset_2 ], span_normalise = False )
6065 np .testing .assert_allclose (div , ts_div )
6166
@@ -70,7 +75,10 @@ def test_Fst(size):
7075 (np .full_like (subset_1 , 0 ), np .full_like (subset_2 , 1 ))
7176 )
7277 ds ["sample_cohort" ] = xr .DataArray (sample_cohorts , dims = "samples" )
73- fst = Fst (ds )["stat_Fst" ].compute ()
78+ cohort_names = ["co_0" , "co_1" ]
79+ ds = ds .assign_coords ({"cohorts_a" : cohort_names , "cohorts_b" : cohort_names })
80+ ds = Fst (ds )
81+ fst = ds ["stat_Fst" ].sel (cohorts_a = "co_0" , cohorts_b = "co_1" ).values
7482 ts_fst = ts .Fst ([subset_1 , subset_2 ])
7583 np .testing .assert_allclose (fst , ts_fst )
7684
@@ -81,6 +89,7 @@ def test_Tajimas_D(size):
8189 ds = ts_to_dataset (ts ) # type: ignore[no-untyped-call]
8290 sample_cohorts = np .full_like (ts .samples (), 0 )
8391 ds ["sample_cohort" ] = xr .DataArray (sample_cohorts , dims = "samples" )
92+ ds = Tajimas_D (ds )
93+ d = ds ["stat_Tajimas_D" ].compute ()
8494 ts_d = ts .Tajimas_D ()
85- d = Tajimas_D (ds )["stat_Tajimas_D" ].compute ()
8695 np .testing .assert_allclose (d , ts_d )
0 commit comments