
Commit e261ab6

YosuaMichael authored and facebook-github-bot committed
[fbsync] Remove (N, T, H, W, C) => (N, T, C, H, W) from presets (#6058)
Summary:

* Remove the `(N, T, H, W, C) => (N, T, C, H, W)` conversion on presets
* Update docs
* Fix the tests
* Use `output_format` for `read_video()`
* Use `output_format` for `Kinetics()`
* Add input descriptions on presets

Reviewed By: NicolasHug

Differential Revision: D36760943

fbshipit-source-id: 316f98583f39cc29b9a40f9c7c479b565981f088
1 parent d467afa commit e261ab6
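The central change: callers pass `output_format="TCHW"` to `torchvision.io.read_video()` to get channels-first frames directly, instead of permuting the default `(T, H, W, C)` layout themselves. A minimal sketch of the difference (the file path here is a placeholder):

```python
from torchvision.io import read_video

# Default layout: channels-last, (T, H, W, C)
frames_thwc, _, _ = read_video("clip.mp4")  # "clip.mp4" is a hypothetical path

# Requested layout: channels-first, (T, C, H, W), ready for the model presets
frames_tchw, _, _ = read_video("clip.mp4", output_format="TCHW")
```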

File tree

5 files changed: +19 −9 lines changed

docs/source/models.rst
gallery/plot_optical_flow.py
references/video_classification/train.py
test/test_extended_models.py
torchvision/transforms/_presets.py

docs/source/models.rst

Lines changed: 1 addition & 1 deletion
```diff
@@ -471,7 +471,7 @@ Here is an example of how to use the pre-trained video classification models:
     from torchvision.io.video import read_video
     from torchvision.models.video import r3d_18, R3D_18_Weights
 
-    vid, _, _ = read_video("test/assets/videos/v_SoccerJuggling_g23_c01.avi")
+    vid, _, _ = read_video("test/assets/videos/v_SoccerJuggling_g23_c01.avi", output_format="TCHW")
     vid = vid[:32]  # optionally shorten duration
 
     # Step 1: Initialize model with the best available weights
```
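For context, the documentation example continues roughly as below (a sketch reconstructed from the surrounding docs, not part of this diff). Because the clip is already `(T, C, H, W)`, the preset consumes it without any manual permute:

```python
weights = R3D_18_Weights.DEFAULT
model = r3d_18(weights=weights)
model.eval()

# The inference preset bundled with the weights
preprocess = weights.transforms()

# (T, C, H, W) in, (C, T, H, W) out; add the batch dimension for the model
batch = preprocess(vid).unsqueeze(0)

prediction = model(batch).squeeze(0).softmax(0)
label = prediction.argmax().item()
print(weights.meta["categories"][label])
```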

gallery/plot_optical_flow.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -72,8 +72,7 @@ def plot(imgs, **imshow_kwargs):
 # single model input.
 
 from torchvision.io import read_video
-frames, _, _ = read_video(str(video_path))
-frames = frames.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
+frames, _, _ = read_video(str(video_path), output_format="TCHW")
 
 img1_batch = torch.stack([frames[100], frames[150]])
 img2_batch = torch.stack([frames[101], frames[151]])
```
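Further down the same gallery file, the stacked frame pairs feed the RAFT preset. A hedged sketch of that flow using the public weights API (details such as `progress=False` are assumptions; RAFT also requires frame dimensions divisible by 8, which the gallery handles with a resize step):

```python
from torchvision.models.optical_flow import raft_large, Raft_Large_Weights

weights = Raft_Large_Weights.DEFAULT
transforms = weights.transforms()

# The optical-flow preset takes two (N, C, H, W) batches -- the layout that
# read_video(..., output_format="TCHW") now yields with no manual permute.
img1_batch, img2_batch = transforms(img1_batch, img2_batch)

model = raft_large(weights=weights, progress=False).eval()
flow_predictions = model(img1_batch, img2_batch)  # list; the last entry is the most refined flow
```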

references/video_classification/train.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -157,6 +157,7 @@ def main(args):
                 "avi",
                 "mp4",
             ),
+            output_format="TCHW",
         )
         if args.cache_dataset:
             print(f"Saving dataset_train to {cache_path}")
@@ -193,6 +194,7 @@ def main(args):
                 "avi",
                 "mp4",
             ),
+            output_format="TCHW",
         )
         if args.cache_dataset:
             print(f"Saving dataset_test to {cache_path}")
```

test/test_extended_models.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -180,7 +180,7 @@ def test_transforms_jit(model_fn):
             "input_shape": (1, 3, 520, 520),
         },
         "video": {
-            "input_shape": (1, 4, 112, 112, 3),
+            "input_shape": (1, 4, 3, 112, 112),
         },
         "optical_flow": {
             "input_shape": (1, 3, 128, 128),
```

torchvision/transforms/_presets.py

Lines changed: 14 additions & 5 deletions
```diff
@@ -29,7 +29,10 @@ def __repr__(self) -> str:
         return self.__class__.__name__ + "()"
 
     def describe(self) -> str:
-        return "The images are rescaled to ``[0.0, 1.0]``."
+        return (
+            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
+            "The images are rescaled to ``[0.0, 1.0]``."
+        )
 
 
 class ImageClassification(nn.Module):
@@ -70,6 +73,7 @@ def __repr__(self) -> str:
 
     def describe(self) -> str:
         return (
+            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
             f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, "
             f"followed by a central crop of ``crop_size={self.crop_size}``. Finally the values are first rescaled to "
             f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``."
@@ -99,7 +103,6 @@ def forward(self, vid: Tensor) -> Tensor:
             vid = vid.unsqueeze(dim=0)
             need_squeeze = True
 
-        vid = vid.permute(0, 1, 4, 2, 3)  # (N, T, H, W, C) => (N, T, C, H, W)
         N, T, C, H, W = vid.shape
         vid = vid.view(-1, C, H, W)
         vid = F.resize(vid, self.resize_size, interpolation=self.interpolation)
@@ -126,9 +129,11 @@ def __repr__(self) -> str:
 
     def describe(self) -> str:
         return (
-            f"The video frames are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, "
+            "Accepts batched ``(B, T, C, H, W)`` and single ``(T, C, H, W)`` video frame ``torch.Tensor`` objects. "
+            f"The frames are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, "
             f"followed by a central crop of ``crop_size={self.crop_size}``. Finally the values are first rescaled to "
-            f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``."
+            f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``. Finally the output "
+            "dimensions are permuted to ``(..., C, T, H, W)`` tensors."
         )
 
 
@@ -167,6 +172,7 @@ def __repr__(self) -> str:
 
     def describe(self) -> str:
         return (
+            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
             f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``. "
             f"Finally the values are first rescaled to ``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and "
            f"``std={self.std}``."
@@ -196,4 +202,7 @@ def __repr__(self) -> str:
         return self.__class__.__name__ + "()"
 
     def describe(self) -> str:
-        return "The images are rescaled to ``[-1.0, 1.0]``."
+        return (
+            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
+            "The images are rescaled to ``[-1.0, 1.0]``."
+        )
```
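To make the new `describe()` strings concrete, a small sketch of the video preset's shape contract for a single unbatched clip (a hedged example; the exact `describe()` wording is what this diff adds above):

```python
import torch
from torchvision.models.video import MC3_18_Weights

preprocess = MC3_18_Weights.DEFAULT.transforms()
print(preprocess.describe())

clip = torch.rand(8, 3, 128, 171)  # single clip: (T, C, H, W)
out = preprocess(clip)
print(out.shape)  # torch.Size([3, 8, 112, 112]) -- (C, T, H, W); the temporary batch dim is squeezed away
```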
