@@ -578,3 +578,106 @@ def Tajimas_D(
     new_ds = Dataset({variables.stat_Tajimas_D: D})
     return conditional_merge_datasets(ds, variables.validate(new_ds), merge)
+
+
+# c = cohorts
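+# The gufunc layout "(c,c)->(c,c,c)" below maps a pairwise (cohorts, cohorts)
+# matrix to a (cohorts, cohorts, cohorts) array, broadcasting over any leading
+# dimensions (e.g. windows) when applied block-wise.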
+@guvectorize(  # type: ignore
+    ["void(float32[:, :], float32[:,:,:])", "void(float64[:, :], float64[:,:,:])"],
+    "(c,c)->(c,c,c)",
+    nopython=True,
+)
+def _pbs(t: ArrayLike, out: ArrayLike) -> None:
+    """Generalized U-function for computing PBS."""
+    out[:, :, :] = np.nan  # (cohorts, cohorts, cohorts)
+    n_cohorts = t.shape[0]
+    # calculate PBS for each cohort triple
+    for i in range(n_cohorts):
+        for j in range(i + 1, n_cohorts):
+            for k in range(j + 1, n_cohorts):
+                ret = (t[i, j] + t[i, k] - t[j, k]) / 2
+                norm = 1 + (t[i, j] + t[i, k] + t[j, k]) / 2
+                ret = ret / norm
+                out[i, j, k] = ret
+
+
+def pbs(
+    ds: Dataset,
+    *,
+    stat_Fst: Hashable = variables.stat_Fst,
+    merge: bool = True,
+) -> Dataset:
+    """Compute the population branching statistic (PBS) between cohort triples.
+
+    By default, values of this statistic are calculated per variant.
+    To compute values in windows, call :func:`window` before calling
+    this function.
+
+    Parameters
+    ----------
+    ds
+        Genotype call dataset.
+    stat_Fst
+        Fst variable to use or calculate. Defined by
+        :data:`sgkit.variables.stat_Fst_spec`.
+        If the variable is not present in ``ds``, it will be computed
+        using :func:`Fst`.
+    merge
+        If True (the default), merge the input dataset and the computed
+        output variables into a single dataset, otherwise return only
+        the computed output variables.
+        See :ref:`dataset_merge` for more details.
+
+    Returns
+    -------
+    A dataset containing the PBS value between cohort triples, as defined by
+    :data:`sgkit.variables.stat_pbs_spec`.
+    Shape (variants, cohorts, cohorts, cohorts), or
+    (windows, cohorts, cohorts, cohorts) if windowing information is available.
+
+    Warnings
+    --------
+    This method does not currently support datasets that are chunked along the
+    samples dimension.
+
+    Examples
+    --------
+
+    >>> import numpy as np
+    >>> import sgkit as sg
+    >>> import xarray as xr
+    >>> ds = sg.simulate_genotype_call_dataset(n_variant=5, n_sample=6)
+
+    >>> # Divide samples into three named cohorts
+    >>> n_cohorts = 3
+    >>> sample_cohort = np.repeat(range(n_cohorts), ds.dims["samples"] // n_cohorts)
+    >>> ds["sample_cohort"] = xr.DataArray(sample_cohort, dims="samples")
+    >>> cohort_names = [f"co_{i}" for i in range(n_cohorts)]
+    >>> ds = ds.assign_coords({"cohorts_0": cohort_names, "cohorts_1": cohort_names, "cohorts_2": cohort_names})
+
+    >>> # Divide into two windows of size three (variants)
+    >>> ds = sg.window(ds, size=3, step=3)
+    >>> sg.pbs(ds)["stat_pbs"].sel(cohorts_0="co_0", cohorts_1="co_1", cohorts_2="co_2").values # doctest: +NORMALIZE_WHITESPACE
+    array([ 0.      , -0.160898])
+    """
+
+    ds = define_variable_if_absent(ds, variables.stat_Fst, stat_Fst, Fst)
+    variables.validate(ds, {stat_Fst: variables.stat_Fst_spec})
+
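+    # Clip Fst into [0, 0.99999] so the -log(1 - Fst) transform below stays
+    # finite and the derived pairwise branch lengths are non-negative.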
+    fst = ds[variables.stat_Fst]
+    fst = fst.clip(min=0, max=0.99999)
+
+    t = -np.log(1 - fst)
+    n_cohorts = ds.dims["cohorts"]
+    n_windows = ds.dims["windows"]
+    assert_array_shape(t, n_windows, n_cohorts, n_cohorts)
+
+    # calculate PBS triples
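+    # _pbs has gufunc layout (c,c)->(c,c,c), so mapping it over the array adds
+    # a third cohorts axis; chunks and new_axis describe the resulting block
+    # structure to Dask.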
+    t = da.asarray(t)
+    shape = (t.chunks[0], n_cohorts, n_cohorts, n_cohorts)
+    p = da.map_blocks(_pbs, t, chunks=shape, new_axis=3, dtype=np.float64)
+    assert_array_shape(p, n_windows, n_cohorts, n_cohorts, n_cohorts)
+
+    new_ds = Dataset(
+        {variables.stat_pbs: (["windows", "cohorts_0", "cohorts_1", "cohorts_2"], p)}
+    )
+    return conditional_merge_datasets(ds, variables.validate(new_ds), merge)