Skip to content

Commit e2622d7

Browse files
shapovalovfacebook-github-bot
authored andcommitted
Using the new dataset idx API everywhere.
Summary: Using the API from D35012121 everywhere. Reviewed By: bottler Differential Revision: D35045870 fbshipit-source-id: dab112b5e04160334859bbe8fa2366344b6e0f70
1 parent c0bb49b commit e2622d7

File tree

5 files changed

+17
-15
lines changed

5 files changed

+17
-15
lines changed

projects/implicitron_trainer/visualize_reconstruction.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def render_sequence(
6767
if seed is None:
6868
seed = hash(sequence_name)
6969
print(f"Loading all data of sequence '{sequence_name}'.")
70-
seq_idx = dataset.seq_to_idx[sequence_name]
70+
seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
7171
train_data = _load_whole_dataset(dataset, seq_idx, num_workers=num_workers)
7272
assert all(train_data.sequence_name[0] == sn for sn in train_data.sequence_name)
7373
sequence_set_name = "train" if is_train_frame(train_data.frame_type)[0] else "test"
@@ -345,7 +345,7 @@ def export_scenes(
345345
dataset = dataset_zoo(**config.dataset_args)[split]
346346

347347
# iterate over the sequences in the dataset
348-
for sequence_name in dataset.seq_to_idx.keys():
348+
for sequence_name in dataset.sequence_names():
349349
with torch.no_grad():
350350
render_sequence(
351351
dataset,

pytorch3d/implicitron/dataset/implicitron_dataset.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
210210
# Maps sequence name to the sequence's global frame indices.
211211
# It is used for the default implementations of some functions in this class.
212212
# Implementations which override them are free to ignore this member.
213-
seq_to_idx: Dict[str, List[int]] = field(init=False)
213+
_seq_to_idx: Dict[str, List[int]] = field(init=False)
214214

215215
def __len__(self) -> int:
216216
raise NotImplementedError
@@ -223,7 +223,7 @@ def get_frame_numbers_and_timestamps(
223223
unordered views, then the dataset should override this method to
224224
return the index and timestamp in their videos of the frames whose
225225
indices are given in `idxs`. In addition,
226-
the values in seq_to_idx should be in ascending order.
226+
the values in _seq_to_idx should be in ascending order.
227227
If timestamps are absent, they should be replaced with a constant.
228228
229229
This is used for letting SceneBatchSampler identify consecutive
@@ -244,7 +244,7 @@ def get_eval_batches(self) -> Optional[List[List[int]]]:
244244

245245
def sequence_names(self) -> Iterable[str]:
246246
"""Returns an iterator over sequence names in the dataset."""
247-
return self.seq_to_idx.keys()
247+
return self._seq_to_idx.keys()
248248

249249
def sequence_frames_in_order(
250250
self, seq_name: str
@@ -262,7 +262,7 @@ def sequence_frames_in_order(
262262
`dataset_idx` is the index within the dataset.
263263
`None` timestamps are replaced with 0s.
264264
"""
265-
seq_frame_indices = self.seq_to_idx[seq_name]
265+
seq_frame_indices = self._seq_to_idx[seq_name]
266266
nos_timestamps = self.get_frame_numbers_and_timestamps(seq_frame_indices)
267267

268268
yield from sorted(
@@ -411,7 +411,7 @@ def seq_frame_index_to_dataset_index(
411411
self.frame_annots[idx]["frame_annotation"].frame_number: idx
412412
for idx in seq_idx
413413
}
414-
for seq, seq_idx in self.seq_to_idx.items()
414+
for seq, seq_idx in self._seq_to_idx.items()
415415
}
416416

417417
def _get_batch_idx(seq_name, frame_no, path=None) -> int:
@@ -804,7 +804,7 @@ def positive_mass(frame_annot: types.FrameAnnotation) -> bool:
804804
if self.n_frames_per_sequence > 0:
805805
print(f"Taking max {self.n_frames_per_sequence} per sequence.")
806806
keep_idx = []
807-
for seq, seq_indices in self.seq_to_idx.items():
807+
for seq, seq_indices in self._seq_to_idx.items():
808808
# infer the seed from the sequence name, this is reproducible
809809
# and makes the selection differ for different sequences
810810
seed = _seq_name_to_seed(seq) + self.seed
@@ -826,20 +826,20 @@ def positive_mass(frame_annot: types.FrameAnnotation) -> bool:
826826
self._invalidate_indexes(filter_seq_annots=True)
827827

828828
def _invalidate_indexes(self, filter_seq_annots: bool = False) -> None:
829-
# update seq_to_idx and filter seq_meta according to frame_annots change
830-
# if filter_seq_annots, also uldates seq_annots based on the changed seq_to_idx
829+
# update _seq_to_idx and filter seq_meta according to frame_annots change
830+
# if filter_seq_annots, also uldates seq_annots based on the changed _seq_to_idx
831831
self._invalidate_seq_to_idx()
832832

833833
if filter_seq_annots:
834834
self.seq_annots = {
835-
k: v for k, v in self.seq_annots.items() if k in self.seq_to_idx
835+
k: v for k, v in self.seq_annots.items() if k in self._seq_to_idx
836836
}
837837

838838
def _invalidate_seq_to_idx(self) -> None:
839839
seq_to_idx = defaultdict(list)
840840
for idx, entry in enumerate(self.frame_annots):
841841
seq_to_idx[entry["frame_annotation"].sequence_name].append(idx)
842-
self.seq_to_idx = seq_to_idx
842+
self._seq_to_idx = seq_to_idx
843843

844844
def _resize_image(
845845
self, image, mode="bilinear"

pytorch3d/implicitron/eval_demo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def _get_all_source_cameras(
198198
"""
199199

200200
# load all source cameras of the sequence
201-
seq_idx = dataset.seq_to_idx[sequence_name]
201+
seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
202202
dataset_for_loader = torch.utils.data.Subset(dataset, seq_idx)
203203
(all_frame_data,) = torch.utils.data.DataLoader(
204204
dataset_for_loader,

tests/implicitron/test_batch_sampler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def __init__(self, num_seq, max_frame_gap=1):
2525
Makes a gap of max_frame_gap frame numbers in the middle of each sequence
2626
"""
2727
self.seq_annots = {f"seq_{i}": None for i in range(num_seq)}
28-
self.seq_to_idx = {
28+
self._seq_to_idx = {
2929
f"seq_{i}": list(range(i * 10, i * 10 + 10)) for i in range(num_seq)
3030
}
3131

tests/implicitron/test_evaluation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import contextlib
99
import copy
1010
import dataclasses
11+
import itertools
1112
import math
1213
import os
1314
import unittest
@@ -285,6 +286,7 @@ def _one_sequence_test(
285286

286287
def test_full_eval(self, n_sequences=5):
287288
"""Test evaluation."""
288-
for _, idx in list(self.dataset.seq_to_idx.items())[:n_sequences]:
289+
for seq in itertools.islice(self.dataset.sequence_names(), n_sequences):
290+
idx = list(self.dataset.sequence_indices_in_order(seq))
289291
seq_dataset = torch.utils.data.Subset(self.dataset, idx)
290292
self._one_sequence_test(seq_dataset)

0 commit comments

Comments
 (0)