@@ -29,7 +29,10 @@ def __repr__(self) -> str:
29
29
return self .__class__ .__name__ + "()"
30
30
31
31
def describe (self ) -> str :
32
- return "The images are rescaled to ``[0.0, 1.0]``."
32
+ return (
33
+ "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
34
+ "The images are rescaled to ``[0.0, 1.0]``."
35
+ )
33
36
34
37
35
38
class ImageClassification (nn .Module ):
@@ -70,6 +73,7 @@ def __repr__(self) -> str:
70
73
71
74
def describe (self ) -> str :
72
75
return (
76
+ "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
73
77
f"The images are resized to ``resize_size={ self .resize_size } `` using ``interpolation={ self .interpolation } ``, "
74
78
f"followed by a central crop of ``crop_size={ self .crop_size } ``. Finally the values are first rescaled to "
75
79
f"``[0.0, 1.0]`` and then normalized using ``mean={ self .mean } `` and ``std={ self .std } ``."
@@ -99,7 +103,6 @@ def forward(self, vid: Tensor) -> Tensor:
99
103
vid = vid .unsqueeze (dim = 0 )
100
104
need_squeeze = True
101
105
102
- vid = vid .permute (0 , 1 , 4 , 2 , 3 ) # (N, T, H, W, C) => (N, T, C, H, W)
103
106
N , T , C , H , W = vid .shape
104
107
vid = vid .view (- 1 , C , H , W )
105
108
vid = F .resize (vid , self .resize_size , interpolation = self .interpolation )
@@ -126,9 +129,11 @@ def __repr__(self) -> str:
126
129
127
130
def describe (self ) -> str :
128
131
return (
129
- f"The video frames are resized to ``resize_size={ self .resize_size } `` using ``interpolation={ self .interpolation } ``, "
132
+ "Accepts batched ``(B, T, C, H, W)`` and single ``(T, C, H, W)`` video frame ``torch.Tensor`` objects. "
133
+ f"The frames are resized to ``resize_size={ self .resize_size } `` using ``interpolation={ self .interpolation } ``, "
130
134
f"followed by a central crop of ``crop_size={ self .crop_size } ``. Finally the values are first rescaled to "
131
- f"``[0.0, 1.0]`` and then normalized using ``mean={ self .mean } `` and ``std={ self .std } ``."
135
+ f"``[0.0, 1.0]`` and then normalized using ``mean={ self .mean } `` and ``std={ self .std } ``. Finally the output "
136
+ "dimensions are permuted to ``(..., C, T, H, W)`` tensors."
132
137
)
133
138
134
139
@@ -167,6 +172,7 @@ def __repr__(self) -> str:
167
172
168
173
def describe (self ) -> str :
169
174
return (
175
+ "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
170
176
f"The images are resized to ``resize_size={ self .resize_size } `` using ``interpolation={ self .interpolation } ``. "
171
177
f"Finally the values are first rescaled to ``[0.0, 1.0]`` and then normalized using ``mean={ self .mean } `` and "
172
178
f"``std={ self .std } ``."
@@ -196,4 +202,7 @@ def __repr__(self) -> str:
196
202
return self .__class__ .__name__ + "()"
197
203
198
204
def describe (self ) -> str :
199
- return "The images are rescaled to ``[-1.0, 1.0]``."
205
+ return (
206
+ "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
207
+ "The images are rescaled to ``[-1.0, 1.0]``."
208
+ )
0 commit comments