// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include <stdio.h>

#include <cmath>

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow_addons/custom_ops/image/cc/kernels/resampler_ops.h"

namespace tensorflow {
namespace addons {

using GPUDevice = Eigen::GpuDevice;

namespace {

#define GET_DATA_POINT(x, y)                                                  \
  data[batch_id * data_batch_stride + data_channels * (y * data_width + x) +  \
       chan]
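
// Note: as implied by the indexing arithmetic, GET_DATA_POINT reads the value
// of channel `chan` at integer pixel coordinates (x, y) from a flat,
// channels-last buffer of shape [batch_size, data_height, data_width,
// data_channels]. The identifiers `data`, `batch_id`, `chan`,
// `data_batch_stride`, `data_channels` and `data_width` must be in scope at
// the point of use.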

template <typename T>
__global__ void Resampler2DKernel(const T* __restrict__ data,
                                  const T* __restrict__ warp,
                                  T* __restrict__ output, const int batch_size,
                                  const int data_height, const int data_width,
                                  const int data_channels,
                                  const int num_sampling_points) {
  const int output_data_size = batch_size * num_sampling_points * data_channels;
  CUDA_1D_KERNEL_LOOP(index, output_data_size) {
    const int out_index = index;

    // Get (idxSample, channel, point) from the index.
    // Use this formula
    //   index = batch_id * num_sampling_points * num_chans +
    //           sample_id * num_chans + chan_id,
    // with sample_id = [0, ..., num_sampling_points)
    const int data_batch_stride = data_height * data_width * data_channels;
    const int warp_batch_stride = num_sampling_points * 2;
    const int output_batch_stride = num_sampling_points * data_channels;

    const int batch_id = index / output_batch_stride;
    const int index_in_batch = index % output_batch_stride;
    const int chan = index_in_batch % data_channels;
    const int sample_id = index_in_batch / data_channels;
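
    // Worked example (illustrative values only): with num_sampling_points = 4
    // and data_channels = 3, output_batch_stride = 12, so index = 29 gives
    // batch_id = 29 / 12 = 2, index_in_batch = 29 % 12 = 5, chan = 5 % 3 = 2
    // and sample_id = 5 / 3 = 1, i.e. channel 2 of the second sampling point
    // in the third batch element.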

    // Get coords of 2D point where data will be resampled
    const T x = warp[batch_id * warp_batch_stride + sample_id * 2];
    const T y = warp[batch_id * warp_batch_stride + sample_id * 2 + 1];
    const T zero = static_cast<T>(0.0);
    const T one = static_cast<T>(1.0);
    // The interpolation function:
    // a) implicitly pads the input data with 0s (hence the unusual checks
    //    with {x,y} > -1)
    // b) returns 0 when sampling outside the (padded) image.
    // The effect is that the sampled signal smoothly goes to 0 outside
    // the original input domain, rather than presenting a jump
    // discontinuity at the image boundaries.
    if (x > static_cast<T>(-1.0) && y > static_cast<T>(-1.0) &&
        x < static_cast<T>(data_width) && y < static_cast<T>(data_height)) {
      // Precompute floor (f) and ceil (c) values for x and y.
      const int fx = std::floor(static_cast<float>(x));
      const int fy = std::floor(static_cast<float>(y));
      const int cx = fx + 1;
      const int cy = fy + 1;
      const T dx = static_cast<T>(cx) - x;
      const T dy = static_cast<T>(cy) - y;
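
      // dx and dy are the distances from (x, y) to the ceil corner (cx, cy),
      // both in (0, 1]. The four terms below therefore form the standard
      // bilinear interpolation of the data at (x, y): each corner value is
      // weighted by the area of the rectangle opposite to it (e.g. the floor
      // corner (fx, fy) gets weight dx * dy), and corners that fall outside
      // the image contribute zero, which is the implicit zero padding
      // described above.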

      const T img_fxfy =
          (fx >= 0 && fy >= 0) ? dx * dy * GET_DATA_POINT(fx, fy) : zero;

      const T img_cxcy = (cx <= data_width - 1 && cy <= data_height - 1)
                             ? (one - dx) * (one - dy) * GET_DATA_POINT(cx, cy)
                             : zero;

      const T img_fxcy = (fx >= 0 && cy <= data_height - 1)
                             ? dx * (one - dy) * GET_DATA_POINT(fx, cy)
                             : zero;

      const T img_cxfy = (cx <= data_width - 1 && fy >= 0)
                             ? (one - dx) * dy * GET_DATA_POINT(cx, fy)
                             : zero;

      output[out_index] = img_fxfy + img_cxcy + img_fxcy + img_cxfy;
    } else {
      output[out_index] = zero;
    }
  }
}

}  // namespace

namespace functor {

template <typename T>
struct Resampler2DFunctor<GPUDevice, T> {
  void operator()(OpKernelContext* ctx, const GPUDevice& d,
                  const T* __restrict__ data, const T* __restrict__ warp,
                  T* __restrict__ output, const int batch_size,
                  const int data_height, const int data_width,
                  const int data_channels, const int num_sampling_points) {
    const int output_data_size =
        batch_size * num_sampling_points * data_channels;
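    // The kernel body is a grid-stride loop (CUDA_1D_KERNEL_LOOP) over
    // output_data_size elements; GetGpuLaunchConfig picks a block/grid
    // configuration sized for that element count.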
    GpuLaunchConfig config = GetGpuLaunchConfig(output_data_size, d);
    TF_CHECK_OK(GpuLaunchKernel(
        Resampler2DKernel<T>, config.block_count, config.thread_per_block, 0,
        d.stream(), data, warp, output, batch_size, data_height, data_width,
        data_channels, num_sampling_points));
  }
};

// TODO(fviola): gcudacc fails at compile time with Eigen::half.
// template struct Resampler2DFunctor<GPUDevice, Eigen::half>;
template struct Resampler2DFunctor<GPUDevice, float>;
template struct Resampler2DFunctor<GPUDevice, double>;

}  // namespace functor

namespace {

#define UPDATE_GRAD_DATA_POINT(x, y, v)                                  \
  atomicAdd(grad_data + (batch_id * data_batch_stride +                  \
                         data_channels * (y * data_width + x) + chan),   \
            v)

template <typename T>
__global__ void ResamplerGrad2DKernel(
    const T* __restrict__ data, const T* __restrict__ warp,
    const T* __restrict__ grad_output, T* __restrict__ grad_data,
    T* __restrict__ grad_warp, const int batch_size, const int data_height,
    const int data_width, const int data_channels,
    const int num_sampling_points) {
  const int resampler_output_size =
      batch_size * num_sampling_points * data_channels;
  CUDA_1D_KERNEL_LOOP(index, resampler_output_size) {
    const int out_index = index;

    // Get (idxSample, channel, point) from the index.
    // Use this formula
    //   index = batch_id * num_sampling_points * num_chans +
    //           sample_id * num_chans + chan_id,
    // with sample_id = [0, ..., num_sampling_points)
    const int data_batch_stride = data_height * data_width * data_channels;
    const int warp_batch_stride = num_sampling_points * 2;
    const int output_batch_stride = num_sampling_points * data_channels;

    const int batch_id = index / output_batch_stride;
    const int index_in_batch = index % output_batch_stride;
    const int chan = index_in_batch % data_channels;
    const int sample_id = index_in_batch / data_channels;

    // Get coords of 2D point where data will be resampled
    const int warp_id_x = batch_id * warp_batch_stride + sample_id * 2;
    const int warp_id_y = warp_id_x + 1;
    const T x = warp[warp_id_x];
    const T y = warp[warp_id_y];
    const T zero = static_cast<T>(0.0);
    const T one = static_cast<T>(1.0);

    // Get grad output
    const T grad_output_value = grad_output[out_index];
    // The interpolation function whose gradient this kernel implements:
    // a) implicitly pads the input data with 0s (hence the unusual checks
    //    with {x,y} > -1)
    // b) returns 0 when sampling outside the (padded) image.
    // The effect is that the sampled signal smoothly goes to 0 outside
    // the original input domain, rather than presenting a jump
    // discontinuity at the image boundaries.
    if (x > static_cast<T>(-1.0) && y > static_cast<T>(-1.0) &&
        x < static_cast<T>(data_width) && y < static_cast<T>(data_height)) {
      // Precompute floor (f) and ceil (c) values for x and y.
      const int fx = std::floor(static_cast<float>(x));
      const int fy = std::floor(static_cast<float>(y));
      const int cx = fx + 1;
      const int cy = fy + 1;
      const T dx = static_cast<T>(cx) - x;
      const T dy = static_cast<T>(cy) - y;

      const T img_fxfy = (fx >= 0 && fy >= 0) ? GET_DATA_POINT(fx, fy) : zero;

      const T img_cxcy = (cx <= data_width - 1 && cy <= data_height - 1)
                             ? GET_DATA_POINT(cx, cy)
                             : zero;

      const T img_fxcy =
          (fx >= 0 && cy <= data_height - 1) ? GET_DATA_POINT(fx, cy) : zero;

      const T img_cxfy =
          (cx <= data_width - 1 && fy >= 0) ? GET_DATA_POINT(cx, fy) : zero;

      // Update partial gradients wrt relevant warp field entries
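      // With the forward pass computing
      //   output = dx * dy * img_fxfy + (1 - dx) * (1 - dy) * img_cxcy +
      //            dx * (1 - dy) * img_fxcy + (1 - dx) * dy * img_cxfy,
      // where img_* here are the raw corner values read above, and with
      // dx = cx - x, dy = cy - y (so d(dx)/dx = -1 and d(dy)/dy = -1),
      // differentiating wrt x and y gives
      //   d(output)/dx = (1 - dy) * (img_cxcy - img_fxcy) +
      //                  dy * (img_cxfy - img_fxfy)
      //   d(output)/dy = (1 - dx) * (img_cxcy - img_cxfy) +
      //                  dx * (img_fxcy - img_fxfy),
      // which the chain-rule updates below accumulate, scaled by the incoming
      // grad_output_value.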
      atomicAdd(grad_warp + warp_id_x,
                grad_output_value * ((one - dy) * (img_cxcy - img_fxcy) +
                                     dy * (img_cxfy - img_fxfy)));
      atomicAdd(grad_warp + warp_id_y,
                grad_output_value * ((one - dx) * (img_cxcy - img_cxfy) +
                                     dx * (img_fxcy - img_fxfy)));

      // Update partial gradients wrt sampled data
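      // The derivative of the output wrt each corner value is simply that
      // corner's bilinear weight, so each touched pixel receives
      // grad_output_value times its weight. atomicAdd is required because
      // different sampling points can read, and therefore back-propagate
      // into, the same data pixel, so multiple threads may update the same
      // grad_data entry concurrently.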
      if (fx >= 0 && fy >= 0) {
        UPDATE_GRAD_DATA_POINT(fx, fy, grad_output_value * dx * dy);
      }
      if (cx <= data_width - 1 && cy <= data_height - 1) {
        UPDATE_GRAD_DATA_POINT(cx, cy,
                               grad_output_value * (one - dx) * (one - dy));
      }
      if (fx >= 0 && cy <= data_height - 1) {
        UPDATE_GRAD_DATA_POINT(fx, cy, grad_output_value * dx * (one - dy));
      }
      if (cx <= data_width - 1 && fy >= 0) {
        UPDATE_GRAD_DATA_POINT(cx, fy, grad_output_value * (one - dx) * dy);
      }
    }
  }
}

#undef GET_DATA_POINT
#undef UPDATE_GRAD_DATA_POINT

}  // namespace

namespace functor {

template <typename T>
struct ResamplerGrad2DFunctor<GPUDevice, T> {
  void operator()(OpKernelContext* ctx, const GPUDevice& d,
                  const T* __restrict__ data, const T* __restrict__ warp,
                  const T* __restrict__ grad_output, T* __restrict__ grad_data,
                  T* __restrict__ grad_warp, const int batch_size,
                  const int data_height, const int data_width,
                  const int data_channels, const int num_sampling_points) {
    // Set gradients to 0, because the kernel incrementally updates the
    // tensor entries by adding partial contributions.
    const int grad_warp_size = batch_size * num_sampling_points * 2;
    const int grad_data_size =
        batch_size * data_height * data_width * data_channels;

    GpuLaunchConfig config = GetGpuLaunchConfig(grad_warp_size, d);
    TF_CHECK_OK(GpuLaunchKernel(SetZero<T>, config.block_count,
                                config.thread_per_block, 0, d.stream(),
                                grad_warp_size, grad_warp));

    config = GetGpuLaunchConfig(grad_data_size, d);
    TF_CHECK_OK(GpuLaunchKernel(SetZero<T>, config.block_count,
                                config.thread_per_block, 0, d.stream(),
                                grad_data_size, grad_data));

    const int resampler_output_size =
        batch_size * num_sampling_points * data_channels;
    config = GetGpuLaunchConfig(resampler_output_size, d);
    TF_CHECK_OK(GpuLaunchKernel(ResamplerGrad2DKernel<T>, config.block_count,
                                config.thread_per_block, 0, d.stream(), data,
                                warp, grad_output, grad_data, grad_warp,
                                batch_size, data_height, data_width,
                                data_channels, num_sampling_points));
  }
};

template struct ResamplerGrad2DFunctor<GPUDevice, float>;

}  // namespace functor
}  // namespace addons
}  // namespace tensorflow

#endif  // GOOGLE_CUDA