5
5
import torch
6
6
from torch .nn .functional import conv2d , pad as torch_pad
7
7
from torchvision .prototype import features
8
- from torchvision .transforms import functional_tensor as _FT
9
8
from torchvision .transforms .functional import pil_to_tensor , to_pil_image
10
9
11
10
@@ -68,9 +67,9 @@ def normalize(
68
67
69
68
70
69
def _get_gaussian_kernel1d (kernel_size : int , sigma : float , dtype : torch .dtype , device : torch .device ) -> torch .Tensor :
71
- lim = (kernel_size - 1 ) / (2 * math .sqrt (2 ) * sigma )
70
+ lim = (kernel_size - 1 ) / (2.0 * math .sqrt (2.0 ) * sigma )
72
71
x = torch .linspace (- lim , lim , steps = kernel_size , dtype = dtype , device = device )
73
- kernel1d = torch .softmax (- x .pow_ (2 ), dim = 0 )
72
+ kernel1d = torch .softmax (x .pow_ (2 ). neg_ ( ), dim = 0 )
74
73
return kernel1d
75
74
76
75
@@ -89,54 +88,61 @@ def gaussian_blur_image_tensor(
89
88
# TODO: consider deprecating integers from sigma on the future
90
89
if isinstance (kernel_size , int ):
91
90
kernel_size = [kernel_size , kernel_size ]
92
- if len (kernel_size ) != 2 :
91
+ elif len (kernel_size ) != 2 :
93
92
raise ValueError (f"If kernel_size is a sequence its length should be 2. Got { len (kernel_size )} " )
94
93
for ksize in kernel_size :
95
94
if ksize % 2 == 0 or ksize < 0 :
96
95
raise ValueError (f"kernel_size should have odd and positive integers. Got { kernel_size } " )
97
96
98
97
if sigma is None :
99
98
sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size ]
100
-
101
- if sigma is not None and not isinstance (sigma , (int , float , list , tuple )):
102
- raise TypeError (f"sigma should be either float or sequence of floats. Got { type (sigma )} " )
103
- if isinstance (sigma , (int , float )):
104
- sigma = [float (sigma ), float (sigma )]
105
- if isinstance (sigma , (list , tuple )) and len (sigma ) == 1 :
106
- sigma = [sigma [0 ], sigma [0 ]]
107
- if len (sigma ) != 2 :
108
- raise ValueError (f"If sigma is a sequence, its length should be 2. Got { len (sigma )} " )
99
+ else :
100
+ if isinstance (sigma , (list , tuple )):
101
+ length = len (sigma )
102
+ if length == 1 :
103
+ s = float (sigma [0 ])
104
+ sigma = [s , s ]
105
+ elif length != 2 :
106
+ raise ValueError (f"If sigma is a sequence, its length should be 2. Got { length } " )
107
+ elif isinstance (sigma , (int , float )):
108
+ s = float (sigma )
109
+ sigma = [s , s ]
110
+ else :
111
+ raise TypeError (f"sigma should be either float or sequence of floats. Got { type (sigma )} " )
109
112
for s in sigma :
110
113
if s <= 0.0 :
111
114
raise ValueError (f"sigma should have positive values. Got { sigma } " )
112
115
113
116
if image .numel () == 0 :
114
117
return image
115
118
119
+ dtype = image .dtype
116
120
shape = image .shape
117
-
118
- if image .ndim > 4 :
121
+ ndim = image .ndim
122
+ if ndim == 3 :
123
+ image = image .unsqueeze (dim = 0 )
124
+ elif ndim > 4 :
119
125
image = image .reshape ((- 1 ,) + shape [- 3 :])
120
- needs_unsquash = True
121
- else :
122
- needs_unsquash = False
123
126
124
- dtype = image . dtype if torch .is_floating_point (image ) else torch . float32
125
- kernel = _get_gaussian_kernel2d (kernel_size , sigma , dtype = dtype , device = image .device )
126
- kernel = kernel .expand (image . shape [- 3 ], 1 , kernel .shape [0 ], kernel .shape [1 ])
127
+ fp = torch .is_floating_point (image )
128
+ kernel = _get_gaussian_kernel2d (kernel_size , sigma , dtype = dtype if fp else torch . float32 , device = image .device )
129
+ kernel = kernel .expand (shape [- 3 ], 1 , kernel .shape [0 ], kernel .shape [1 ])
127
130
128
- image , need_cast , need_squeeze , out_dtype = _FT . _cast_squeeze_in ( image , [ kernel . dtype ] )
131
+ output = image if fp else image . to ( dtype = torch . float32 )
129
132
130
133
# padding = (left, right, top, bottom)
131
134
padding = [kernel_size [0 ] // 2 , kernel_size [0 ] // 2 , kernel_size [1 ] // 2 , kernel_size [1 ] // 2 ]
132
- output = torch_pad (image , padding , mode = "reflect" )
133
- output = conv2d (output , kernel , groups = output .shape [- 3 ])
134
-
135
- output = _FT ._cast_squeeze_out (output , need_cast , need_squeeze , out_dtype )
135
+ output = torch_pad (output , padding , mode = "reflect" )
136
+ output = conv2d (output , kernel , groups = shape [- 3 ])
136
137
137
- if needs_unsquash :
138
+ if ndim == 3 :
139
+ output = output .squeeze (dim = 0 )
140
+ elif ndim > 4 :
138
141
output = output .reshape (shape )
139
142
143
+ if not fp :
144
+ output = output .round_ ().to (dtype = dtype )
145
+
140
146
return output
141
147
142
148
0 commit comments