 import logging
 import tempfile
 from pathlib import Path
-from threading import RLock
 from typing import (
     Any,
     Dict,
 import pandas as pd
 import xarray as xr
 import zarr
-from cachetools import LRUCache, cached
 from cbgen import bgen_file, bgen_metafile
 from rechunker import api as rechunker_api
 from xarray import Dataset
@@ -85,7 +83,8 @@ def __init__(
         self.partition_size = mf.partition_size
 
         self.shape = (self.n_variants, self.n_samples, 3)
-        self.dtype = dtype
+        self.dtype = np.dtype(dtype)
+        self.precision = 64 if self.dtype.itemsize >= 8 else 32
         self.ndim = 3
 
     def __getitem__(self, idx: Any) -> np.ndarray:
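
The new `precision` attribute is derived from the item size of the requested dtype: anything 8 bytes or wider asks the BGEN decoder for 64-bit probabilities, while narrower float types fall back to 32-bit. A minimal standalone sketch of that mapping, using only numpy (the dtype strings are just illustrative inputs):

```python
import numpy as np

# itemsize >= 8 bytes -> decode at 64-bit precision, otherwise 32-bit
for name in ["float16", "float32", "float64"]:
    dt = np.dtype(name)
    precision = 64 if dt.itemsize >= 8 else 32
    print(name, dt.itemsize, precision)
# float16 2 32
# float32 4 32
# float64 8 64
```
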
@@ -135,7 +134,7 @@ def __getitem__(self, idx: Any) -> np.ndarray:
         with bgen_file(self.path) as bgen:
             res = None
             for i, vaddr in enumerate(all_vaddr):
-                probs = bgen.read_probability(vaddr, precision=32)[idx[1]]
+                probs = bgen.read_probability(vaddr, precision=self.precision)[idx[1]]
                 assert len(probs.shape) == 2 and probs.shape[1] == 3
                 if res is None:
                     res = np.zeros((len(all_vaddr), len(probs), 3), dtype=self.dtype)
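
Because the decode precision now follows `self.dtype`, the `gp_dtype` passed to `read_bgen` controls both the storage dtype and the precision requested from cbgen. A hedged usage sketch of that end-to-end behaviour (the file path is illustrative, the import location is assumed, and the variable name comes from the Returns section added below):

```python
import numpy as np
from sgkit.io.bgen import read_bgen  # assumed import location of read_bgen

# Requesting float64 probabilities now also selects 64-bit decoding in cbgen;
# float32 (the default) and narrower types keep the 32-bit decode path.
ds = read_bgen("example.bgen", gp_dtype="float64")
assert ds["call_genotype_probability"].dtype == np.dtype("float64")
```
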
@@ -144,10 +143,6 @@ def __getitem__(self, idx: Any) -> np.ndarray:
         return np.squeeze(res, axis=squeeze_dims)
 
 
-cache = LRUCache(maxsize=3)
-lock = RLock()
-
-
 def _split_alleles(allele_ids: bytes) -> List[bytes]:
     alleles = allele_ids.split(b",")
     if len(alleles) != 2:
@@ -157,7 +152,6 @@ def _split_alleles(allele_ids: bytes) -> List[bytes]:
     return alleles
 
 
-@cached(cache, lock=lock)  # type: ignore[misc]
 def _read_metafile_partition(path: Path, partition: int) -> pd.DataFrame:
     with bgen_metafile(path) as mf:
         part = mf.read_partition(partition)
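
`_split_alleles` above is where the bi-allelic restriction noted in the docstring's Warnings section is enforced: the allele-id byte string is split on commas and anything other than exactly two alleles is rejected (the error branch itself is elided between these hunks). A tiny worked example of the happy path and the failing check:

```python
# Bi-allelic id: split yields exactly two alleles and passes the check.
assert b"A,G".split(b",") == [b"A", b"G"]

# Tri-allelic id: three parts, so len(alleles) != 2 and the variant is rejected.
assert len(b"A,C,T".split(b",")) == 3
```
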
@@ -243,17 +237,42 @@ def read_bgen(
         be read multiple times when False.
     contig_dtype
         Data type for contig names, by default "str".
-        This may be an integer type, but this will fail if any of the contig names cannot be
-        converted to integers.
+        This may also be an integer type (e.g. "int"), but will fail if any of the contig names
+        cannot be converted to integers.
     gp_dtype
         Data type for genotype probabilities, by default "float32".
 
     Warnings
     --------
     Only bi-allelic, diploid BGEN files are currently supported.
+
+    Returns
+    -------
+    A dataset containing the following variables:
+
+    - :data:`sgkit.variables.variant_id` (variants)
+    - :data:`sgkit.variables.variant_contig` (variants)
+    - :data:`sgkit.variables.variant_position` (variants)
+    - :data:`sgkit.variables.variant_allele` (variants)
+    - :data:`sgkit.variables.sample_id` (samples)
+    - :data:`sgkit.variables.call_dosage` (variants, samples)
+    - :data:`sgkit.variables.call_dosage_mask` (variants, samples)
+    - :data:`sgkit.variables.call_genotype_probability` (variants, samples, genotypes)
+    - :data:`sgkit.variables.call_genotype_probability_mask` (variants, samples, genotypes)
+
     """
     if isinstance(chunks, tuple) and len(chunks) != 3:
-        raise ValueError(f"Chunks must be tuple with 3 items, not {chunks}")
+        raise ValueError(f"`chunks` must be tuple with 3 items, not {chunks}")
+    if not np.issubdtype(gp_dtype, np.floating):
+        raise ValueError(
+            f"`gp_dtype` must be a floating point data type, not {gp_dtype}"
+        )
+    if not np.issubdtype(contig_dtype, np.integer) and np.dtype(
+        contig_dtype
+    ).kind not in {"U", "S"}:
+        raise ValueError(
+            f"`contig_dtype` must be of string or int type, not {contig_dtype}"
+        )
 
     path = Path(path)
     sample_path = Path(sample_path) if sample_path else path.with_suffix(".sample")
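
The new validation relies on numpy dtype introspection: `np.issubdtype` accepts dtype-like strings, and `np.dtype(...).kind` is `"U"` for unicode strings or `"S"` for byte strings. A small standalone check of the same predicates (`contig_dtype_ok` is a hypothetical helper mirroring the condition above):

```python
import numpy as np

# gp_dtype must be a floating point type
assert np.issubdtype("float32", np.floating)
assert not np.issubdtype("int32", np.floating)

# contig_dtype must be an integer type or a string type
def contig_dtype_ok(dt):
    return np.issubdtype(dt, np.integer) or np.dtype(dt).kind in {"U", "S"}

assert contig_dtype_ok("int")
assert contig_dtype_ok("str")   # np.dtype("str").kind == "U"
assert contig_dtype_ok("S10")   # fixed-width byte strings also pass
assert not contig_dtype_ok("float64")
```
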