Skip to content

Commit 6bf417f

Browse files
committed
Adding tests that use the PytorchBigWigDataset object. Also adding custom_position_sampler and custom_track_sampler options to the dataset objects.
1 parent 8301e12 commit 6bf417f

File tree

10 files changed

+290
-26
lines changed

10 files changed

+290
-26
lines changed

bigwig_loader/dataset.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from pathlib import Path
33
from typing import Any
44
from typing import Callable
5+
from typing import Iterable
56
from typing import Iterator
67
from typing import Literal
78
from typing import Optional
@@ -78,6 +79,12 @@ class BigWigDataset:
7879
GPU. More threads means that more IO can take place while the GPU is busy doing
7980
calculations (decompressing or neural network training for example). More threads
8081
also means a higher GPU memory usage. Default: 4
82+
custom_position_sampler: if set, this sampler will be used instead of the default
83+
position sampler (which samples randomly and uniformly from regions of interest).
84+
This should be an iterable of tuples (chromosome, center).
85+
custom_track_sampler: if specified, this sampler will be used to sample tracks. When not
86+
specified, each batch simply contains all tracks, or a randomly selected subset of
87+
tracks in case sub_sample_tracks is set. Should be an iterable of batches of track indices.
8188
return_batch_objects: if True, the batches will be returned as instances of
8289
bigwig_loader.batch.Batch
8390
"""
@@ -107,6 +114,8 @@ def __init__(
107114
repeat_same_positions: bool = False,
108115
sub_sample_tracks: Optional[int] = None,
109116
n_threads: int = 4,
117+
custom_position_sampler: Optional[Iterable[tuple[str, int]]] = None,
118+
custom_track_sampler: Optional[Iterable[list[int]]] = None,
110119
return_batch_objects: bool = False,
111120
):
112121
super().__init__()
@@ -152,32 +161,34 @@ def __init__(
152161
self._sub_sample_tracks = sub_sample_tracks
153162
self._n_threads = n_threads
154163
self._return_batch_objects = return_batch_objects
155-
156-
def _create_dataloader(self) -> StreamedDataloader:
157-
position_sampler = RandomPositionSampler(
164+
self._position_sampler = custom_position_sampler or RandomPositionSampler(
158165
regions_of_interest=self.regions_of_interest,
159166
buffer_size=self._position_sampler_buffer_size,
160167
repeat_same=self._repeat_same_positions,
161168
)
169+
if custom_track_sampler is not None:
170+
self._track_sampler: Optional[Iterable[list[int]]] = custom_track_sampler
171+
elif sub_sample_tracks is not None:
172+
self._track_sampler = TrackSampler(
173+
total_number_of_tracks=len(self.bigwig_collection),
174+
sample_size=sub_sample_tracks,
175+
)
176+
else:
177+
self._track_sampler = None
162178

179+
def _create_dataloader(self) -> StreamedDataloader:
163180
sequence_sampler = GenomicSequenceSampler(
164181
reference_genome_path=self.reference_genome_path,
165182
sequence_length=self.sequence_length,
166-
position_sampler=position_sampler,
183+
position_sampler=self._position_sampler,
167184
maximum_unknown_bases_fraction=self.maximum_unknown_bases_fraction,
168185
)
169-
track_sampler = None
170-
if self._sub_sample_tracks is not None:
171-
track_sampler = TrackSampler(
172-
total_number_of_tracks=len(self.bigwig_collection),
173-
sample_size=self._sub_sample_tracks,
174-
)
175186

176187
query_batch_generator = QueryBatchGenerator(
177188
genomic_location_sampler=sequence_sampler,
178189
center_bin_to_predict=self.center_bin_to_predict,
179190
batch_size=self.super_batch_size,
180-
track_sampler=track_sampler,
191+
track_sampler=self._track_sampler,
181192
)
182193

183194
return StreamedDataloader(

bigwig_loader/pytorch.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from pathlib import Path
22
from typing import Any
33
from typing import Callable
4+
from typing import Iterable
45
from typing import Iterator
56
from typing import Literal
67
from typing import Optional
@@ -165,6 +166,12 @@ class PytorchBigWigDataset(IterableDataset[BATCH_TYPE]):
165166
also means a higher GPU memory usage. Default: 4
166167
return_batch_objects: if True, the batches will be returned as instances of
167168
bigwig_loader.pytorch.PytorchBatch
169+
custom_position_sampler: if set, this sampler will be used instead of the default
170+
position sampler (which samples randomly and uniformly from regions of interest).
171+
This should be an iterable of tuples (chromosome, center).
172+
custom_track_sampler: if specified, this sampler will be used to sample tracks. When not
173+
specified, each batch simply contains all tracks, or a randomly selected subset of
174+
tracks in case sub_sample_tracks is set. Should be an iterable of batches of track indices.
168175
"""
169176

170177
def __init__(
@@ -192,6 +199,8 @@ def __init__(
192199
repeat_same_positions: bool = False,
193200
sub_sample_tracks: Optional[int] = None,
194201
n_threads: int = 4,
202+
custom_position_sampler: Optional[Iterable[tuple[str, int]]] = None,
203+
custom_track_sampler: Optional[Iterable[list[int]]] = None,
195204
return_batch_objects: bool = False,
196205
):
197206
super().__init__()
@@ -217,6 +226,8 @@ def __init__(
217226
repeat_same_positions=repeat_same_positions,
218227
sub_sample_tracks=sub_sample_tracks,
219228
n_threads=n_threads,
229+
custom_position_sampler=custom_position_sampler,
230+
custom_track_sampler=custom_track_sampler,
220231
return_batch_objects=True,
221232
)
222233
self._return_batch_objects = return_batch_objects

bigwig_loader/sampler/genome_sampler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from pathlib import Path
22
from typing import Any
33
from typing import Callable
4+
from typing import Iterable
45
from typing import Iterator
56
from typing import Literal
67
from typing import Optional
@@ -21,7 +22,7 @@ def __init__(
2122
self,
2223
reference_genome_path: Path,
2324
sequence_length: int,
24-
position_sampler: Iterator[tuple[str, int]],
25+
position_sampler: Iterable[tuple[str, int]],
2526
maximum_unknown_bases_fraction: float = 0.1,
2627
):
2728
self.reference_genome_path = reference_genome_path

bigwig_loader/sampler/position_sampler.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@
33
import numpy as np
44
import pandas as pd
55

6-
from bigwig_loader.util import make_cumulative_index_intervals
6+
7+
def make_cumulative_index_intervals(intervals: pd.DataFrame) -> pd.DataFrame:
8+
intervals.reset_index(drop=True, inplace=True)
9+
intervals.index = (
10+
(intervals["end"] - intervals["start"]).cumsum().shift().fillna(0).astype(int) # type: ignore
11+
)
12+
return intervals
713

814

915
class RandomPositionSampler:
@@ -22,6 +28,8 @@ def __init__(
2228
self._repeat_same = repeat_same
2329

2430
def __iter__(self) -> Iterator[tuple[str, int]]:
31+
if self._repeat_same:
32+
self._index = 0
2533
return self
2634

2735
def __next__(self) -> tuple[str, int]:
@@ -36,6 +44,7 @@ def __next__(self) -> tuple[str, int]:
3644
return chromosome, center
3745

3846
def _refresh_buffer(self) -> None:
47+
print("refresh buffer called")
3948
batch_rand = np.random.randint(
4049
low=0, high=self._max_index, size=self.buffer_size
4150
)

bigwig_loader/util.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,6 @@ def sort_intervals(intervals: pd.DataFrame, inplace: bool = False) -> pd.DataFra
2828
)
2929

3030

31-
def make_cumulative_index_intervals(intervals: pd.DataFrame) -> pd.DataFrame:
32-
intervals.reset_index(drop=True, inplace=True)
33-
intervals.index = (
34-
(intervals["end"] - intervals["start"]).cumsum().shift().fillna(0).astype(int) # type: ignore
35-
)
36-
return intervals
37-
38-
3931
_string_to_encoding = {
4032
"A": [1.0, 0.0, 0.0, 0.0],
4133
"C": [0.0, 1.0, 0.0, 0.0],

tests/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
import pytest
66

77
from bigwig_loader import config
8+
from bigwig_loader.download_example_data import get_example_bigwigs_files
9+
from bigwig_loader.download_example_data import get_reference_genome
810

911
try:
1012
from bigwig_loader.collection import BigWigCollection
11-
from bigwig_loader.download_example_data import get_example_bigwigs_files
12-
from bigwig_loader.download_example_data import get_reference_genome
1313
except ImportError:
1414
logging.warning(
1515
"Can not import from bigwig_loader.collection without cupy installed"

tests/test_dataset.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,38 @@ def test_batch_return_type(bigwig_path, reference_genome_path, merged_intervals)
7373
for i, batch in enumerate(dataset):
7474
assert isinstance(batch, Batch)
7575
assert batch.track_indices is not None
76+
77+
78+
def test_positions_are_reproducible(
79+
bigwig_path, reference_genome_path, merged_intervals
80+
):
81+
batch_size = 16
82+
83+
dataset = BigWigDataset(
84+
regions_of_interest=merged_intervals,
85+
collection=bigwig_path,
86+
reference_genome_path=reference_genome_path,
87+
sequence_length=2000,
88+
center_bin_to_predict=1000,
89+
window_size=4,
90+
batch_size=batch_size,
91+
batches_per_epoch=10,
92+
maximum_unknown_bases_fraction=0.1,
93+
first_n_files=2,
94+
repeat_same_positions=True,
95+
n_threads=1,
96+
return_batch_objects=True,
97+
)
98+
99+
starts_a = [
100+
position
101+
for batch in dataset
102+
for position in zip(batch.chromosomes, batch.starts)
103+
]
104+
starts_b = [
105+
position
106+
for batch in dataset
107+
for position in zip(batch.chromosomes, batch.starts)
108+
]
109+
110+
assert starts_a == starts_b

tests/test_position_sampler.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from bigwig_loader.sampler.position_sampler import RandomPositionSampler
2+
3+
4+
def test_repeat_same_positions(merged_intervals):
5+
sampler = RandomPositionSampler(
6+
regions_of_interest=merged_intervals, repeat_same=True
7+
)
8+
9+
first_samples = []
10+
for i, sample in enumerate(sampler):
11+
first_samples.append(sample)
12+
if i == 5:
13+
break
14+
second_samples = []
15+
for i, sample in enumerate(sampler):
16+
second_samples.append(sample)
17+
if i == 5:
18+
break
19+
20+
assert first_samples == second_samples
21+
22+
23+
def test_not_repeat_same_positions(merged_intervals):
24+
sampler = RandomPositionSampler(
25+
regions_of_interest=merged_intervals, repeat_same=False
26+
)
27+
28+
first_samples = []
29+
for i, sample in enumerate(sampler):
30+
first_samples.append(sample)
31+
if i == 5:
32+
break
33+
second_samples = []
34+
for i, sample in enumerate(sampler):
35+
second_samples.append(sample)
36+
if i == 5:
37+
break
38+
39+
assert first_samples != second_samples

tests/test_pytorch_dataset.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1+
from math import isnan
2+
3+
import pandas as pd
14
import pytest
25

6+
from bigwig_loader import config
7+
38
torch = pytest.importorskip("torch")
49

510

@@ -30,3 +35,81 @@ def test_input_and_target_is_torch_tensor(pytorch_dataset):
3035
sequence, target = next(iter(pytorch_dataset))
3136
assert isinstance(sequence, torch.Tensor)
3237
assert isinstance(target, torch.Tensor)
38+
39+
40+
@pytest.mark.parametrize("default_value", [0.0, torch.nan, 4.0, 5.6])
41+
def test_pytorch_dataset_with_window_function(
42+
default_value, bigwig_path, reference_genome_path, merged_intervals
43+
):
44+
from bigwig_loader.pytorch import PytorchBigWigDataset
45+
46+
center_bin_to_predict = 2048
47+
window_size = 128
48+
reduced_dim = center_bin_to_predict // window_size
49+
50+
batch_size = 16
51+
52+
df = pd.read_csv(config.example_positions, sep="\t")
53+
df = df[df["chr"].isin({"chr1", "chr3", "chr5"})]
54+
chromosomes = list(df["chr"])[:batch_size]
55+
centers = list(df["center"])[:batch_size]
56+
57+
position_sampler = [(chrom, center) for chrom, center in zip(chromosomes, centers)]
58+
59+
dataset = PytorchBigWigDataset(
60+
regions_of_interest=merged_intervals,
61+
collection=bigwig_path,
62+
reference_genome_path=reference_genome_path,
63+
sequence_length=center_bin_to_predict * 2,
64+
center_bin_to_predict=center_bin_to_predict,
65+
window_size=1,
66+
batch_size=batch_size,
67+
batches_per_epoch=1,
68+
maximum_unknown_bases_fraction=0.1,
69+
first_n_files=3,
70+
custom_position_sampler=position_sampler,
71+
default_value=default_value,
72+
return_batch_objects=True,
73+
)
74+
75+
dataset_with_window = PytorchBigWigDataset(
76+
regions_of_interest=merged_intervals,
77+
collection=bigwig_path,
78+
reference_genome_path=reference_genome_path,
79+
sequence_length=center_bin_to_predict * 2,
80+
center_bin_to_predict=center_bin_to_predict,
81+
window_size=window_size,
82+
batch_size=batch_size,
83+
batches_per_epoch=1,
84+
maximum_unknown_bases_fraction=0.1,
85+
first_n_files=3,
86+
custom_position_sampler=position_sampler,
87+
default_value=default_value,
88+
return_batch_objects=True,
89+
)
90+
91+
print(dataset_with_window._dataset.bigwig_collection.bigwig_paths)
92+
93+
for batch, batch_with_window in zip(dataset, dataset_with_window):
94+
print(batch)
95+
print(batch_with_window)
96+
print(batch.chromosomes)
97+
print(batch_with_window.chromosomes)
98+
print(batch.starts)
99+
print(batch_with_window.starts)
100+
print(batch.ends)
101+
print(batch_with_window.ends)
102+
expected = batch.values.reshape(
103+
batch.values.shape[0], batch.values.shape[1], reduced_dim, window_size
104+
)
105+
if not isnan(default_value) or default_value == 0:
106+
expected = torch.nan_to_num(expected, nan=default_value)
107+
expected = torch.nanmean(expected, axis=-1)
108+
print("---")
109+
print("expected")
110+
print(expected)
111+
print("batch_with_window")
112+
print(batch_with_window.values)
113+
assert torch.allclose(expected, batch_with_window.values, equal_nan=True)
114+
if isnan(default_value):
115+
assert torch.isnan(batch_with_window.values).any()

0 commit comments

Comments
 (0)