
Commit 085503a

autoih authored and WindQAQ committed
add resampler kernel (#662)
* add resampler kernel
* add register op
* namespace and register
* python format
* headers and cleanup
* sanity cleanup
* readme update
* alphabetic order
* gpu test & minor revision
* comment on wrapping part
* cpu test
* miscellaneous fixing
* minor fix
* line removal
1 parent ba5bbe1 commit 085503a

File tree

11 files changed: +1164 -4 lines changed

tensorflow_addons/custom_ops/image/BUILD

Lines changed: 13 additions & 0 deletions
@@ -35,3 +35,16 @@ custom_op_library(
         "cc/kernels/image_projective_transform_op_gpu.cu.cc",
     ],
 )
+
+custom_op_library(
+    name = "_resampler_ops.so",
+    srcs = [
+        "cc/kernels/resampler_ops.cc",
+        "cc/kernels/resampler_ops.h",
+        "cc/ops/resampler_ops.cc",
+    ],
+    cuda_srcs = [
+        "cc/kernels/resampler_ops.h",
+        "cc/kernels/resampler_ops_gpu.cu.cc",
+    ],
+)
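The new custom_op_library rule compiles the op definition (cc/ops/resampler_ops.cc) and the CPU kernel into _resampler_ops.so, folding in the cuda_srcs when CUDA is enabled. The op-definition file itself is not rendered on this page; as a minimal sketch of the registration boilerplate such a file typically contains (the op name "Addons>Resampler", the attribute list, and the shape logic below are assumptions, not code copied from this commit):

// Hypothetical sketch of an op definition like cc/ops/resampler_ops.cc.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {

REGISTER_OP("Addons>Resampler")
    .Input("data: T")     // assumed [batch, height, width, channels]
    .Input("warp: T")     // assumed [batch, num_sampling_points, 2]
    .Output("output: T")  // assumed [batch, num_sampling_points, channels]
    .Attr("T: {float, double}")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle data;
      shape_inference::ShapeHandle warp;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &data));
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &warp));
      // Output carries the batch and sampling-point dims of the inputs
      // plus the channel dim of `data`.
      c->set_output(0, c->MakeShape({c->Dim(data, 0), c->Dim(warp, 1),
                                     c->Dim(data, 3)}));
      return Status::OK();
    });

}  // namespace tensorflow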

tensorflow_addons/custom_ops/image/cc/kernels/resampler_ops.cc

Lines changed: 417 additions & 0 deletions
Large diffs are not rendered by default.
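Since the 417-line CPU file is collapsed here, the following is a hedged sketch of how a wrapper kernel would typically dispatch to the Resampler2DFunctor declared in the header below. The class name, the assumed NHWC layout, and the omitted input validation are illustrative, not the commit's actual code:

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow_addons/custom_ops/image/cc/kernels/resampler_ops.h"

namespace tensorflow {
namespace addons {

// Hypothetical wrapper: fetches inputs, allocates the output, and hands the
// raw buffers to the device-specific functor.
template <typename Device, typename T>
class ResamplerOp : public OpKernel {
 public:
  explicit ResamplerOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& data = ctx->input(0);  // assumed [batch, height, width, channels]
    const Tensor& warp = ctx->input(1);  // assumed [batch, num_sampling_points, 2]
    const int batch_size = static_cast<int>(data.dim_size(0));
    const int data_height = static_cast<int>(data.dim_size(1));
    const int data_width = static_cast<int>(data.dim_size(2));
    const int data_channels = static_cast<int>(data.dim_size(3));
    const int num_sampling_points = static_cast<int>(warp.dim_size(1));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(
                            0,
                            TensorShape({batch_size, num_sampling_points,
                                         data_channels}),
                            &output));

    // The functor is specialized per device; see the GPU specialization in
    // resampler_ops_gpu.cu.cc below.
    functor::Resampler2DFunctor<Device, T>()(
        ctx, ctx->eigen_device<Device>(), data.flat<T>().data(),
        warp.flat<T>().data(), output->flat<T>().data(), batch_size,
        data_height, data_width, data_channels, num_sampling_points);
  }
};

}  // namespace addons
}  // namespace tensorflow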
tensorflow_addons/custom_ops/image/cc/kernels/resampler_ops.h

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef TENSORFLOW_ADDONS_IMAGE_KERNELS_RESAMPLER_OPS_H_
+#define TENSORFLOW_ADDONS_IMAGE_KERNELS_RESAMPLER_OPS_H_
+
+#if PLATFORM_WINDOWS
+#define __restrict__ __restrict
+#endif
+
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace addons {
+namespace functor {
+
+// Helper functor for the Resampler Op in 2D
+template <typename Device, typename T>
+struct Resampler2DFunctor {
+  void operator()(OpKernelContext* ctx, const Device& d,
+                  const T* __restrict__ data, const T* __restrict__ warp,
+                  T* __restrict__ output, const int batch_size,
+                  const int data_height, const int data_width,
+                  const int data_channels, const int num_sampling_points);
+};
+
+// Helper functor for the Resampler Gradient Op in 2D
+template <typename Device, typename T>
+struct ResamplerGrad2DFunctor {
+  void operator()(OpKernelContext* ctx, const Device& d,
+                  const T* __restrict__ data, const T* __restrict__ warp,
+                  const T* __restrict__ grad_output, T* __restrict__ grad_data,
+                  T* __restrict__ grad_warp, const int batch_size,
+                  const int data_height, const int data_width,
+                  const int data_channels, const int num_sampling_points);
+};
+
+}  // namespace functor
+}  // namespace addons
+}  // namespace tensorflow
+#endif  // TENSORFLOW_ADDONS_IMAGE_KERNELS_RESAMPLER_OPS_H_
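The Device template parameter in these declarations is what lets one wrapper class serve both backends: each translation unit specializes the functor for its device (the GPU file below does exactly that for Eigen::GpuDevice), and kernel registration then binds the wrapper per device. A hedged sketch of that wiring, reusing the hypothetical ResamplerOp wrapper sketched earlier (the op name and supported types are assumptions):

// Hypothetical registration wiring, not code from this commit.
using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

#define REGISTER(TYPE)                                    \
  REGISTER_KERNEL_BUILDER(Name("Addons>Resampler")        \
                              .Device(DEVICE_CPU)         \
                              .TypeConstraint<TYPE>("T"), \
                          ResamplerOp<CPUDevice, TYPE>)
REGISTER(float);
REGISTER(double);
#undef REGISTER

#if GOOGLE_CUDA
#define REGISTER_GPU(TYPE)                                \
  REGISTER_KERNEL_BUILDER(Name("Addons>Resampler")        \
                              .Device(DEVICE_GPU)         \
                              .TypeConstraint<TYPE>("T"), \
                          ResamplerOp<GPUDevice, TYPE>)
REGISTER_GPU(float);
REGISTER_GPU(double);
#undef REGISTER_GPU
#endif  // GOOGLE_CUDA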
tensorflow_addons/custom_ops/image/cc/kernels/resampler_ops_gpu.cu.cc

Lines changed: 281 additions & 0 deletions
@@ -0,0 +1,281 @@
+// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include <stdio.h>
+
+#include <cmath>
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/util/gpu_kernel_helper.h"
+#include "tensorflow_addons/custom_ops/image/cc/kernels/resampler_ops.h"
+
+namespace tensorflow {
+namespace addons {
+
+using GPUDevice = Eigen::GpuDevice;
+
+namespace {
+
+#define GET_DATA_POINT(x, y)                                                  \
+  data[batch_id * data_batch_stride + data_channels * (y * data_width + x) +  \
+       chan]
+
+template <typename T>
+__global__ void Resampler2DKernel(const T* __restrict__ data,
+                                  const T* __restrict__ warp,
+                                  T* __restrict__ output, const int batch_size,
+                                  const int data_height, const int data_width,
+                                  const int data_channels,
+                                  const int num_sampling_points) {
+  const int output_data_size = batch_size * num_sampling_points * data_channels;
+  CUDA_1D_KERNEL_LOOP(index, output_data_size) {
+    const int out_index = index;
+
+    // Get (idxSample, channel, point) from the index.
+    // Use this formula
+    //   index = batch_id * num_sampling_points * num_chans +
+    //           sample_id * num_chans + chan_id,
+    // with sample_id = [0, ... ,num_sampling_points)
+    const int data_batch_stride = data_height * data_width * data_channels;
+    const int warp_batch_stride = num_sampling_points * 2;
+    const int output_batch_stride = num_sampling_points * data_channels;
+
+    const int batch_id = index / output_batch_stride;
+    const int index_in_batch = index % output_batch_stride;
+    const int chan = index_in_batch % data_channels;
+    const int sample_id = index_in_batch / data_channels;
+
+    // Get coords of 2D point where data will be resampled
+    const T x = warp[batch_id * warp_batch_stride + sample_id * 2];
+    const T y = warp[batch_id * warp_batch_stride + sample_id * 2 + 1];
+    const T zero = static_cast<T>(0.0);
+    const T one = static_cast<T>(1.0);
+    // The interpolation function:
+    // a) implicitly pads the input data with 0s (hence the unusual checks
+    //    with {x,y} > -1)
+    // b) returns 0 when sampling outside the (padded) image.
+    // The effect is that the sampled signal smoothly goes to 0 outside
+    // the original input domain, rather than presenting a jump
+    // discontinuity at the image boundaries.
+    if (x > static_cast<T>(-1.0) && y > static_cast<T>(-1.0) &&
+        x < static_cast<T>(data_width) && y < static_cast<T>(data_height)) {
+      // Precompute floor (f) and ceil (c) values for x and y.
+      const int fx = std::floor(static_cast<float>(x));
+      const int fy = std::floor(static_cast<float>(y));
+      const int cx = fx + 1;
+      const int cy = fy + 1;
+      const T dx = static_cast<T>(cx) - x;
+      const T dy = static_cast<T>(cy) - y;
+
+      const T img_fxfy =
+          (fx >= 0 && fy >= 0) ? dx * dy * GET_DATA_POINT(fx, fy) : zero;
+
+      const T img_cxcy = (cx <= data_width - 1 && cy <= data_height - 1)
+                             ? (one - dx) * (one - dy) * GET_DATA_POINT(cx, cy)
+                             : zero;
+
+      const T img_fxcy = (fx >= 0 && cy <= data_height - 1)
+                             ? dx * (one - dy) * GET_DATA_POINT(fx, cy)
+                             : zero;
+
+      const T img_cxfy = (cx <= data_width - 1 && fy >= 0)
+                             ? (one - dx) * dy * GET_DATA_POINT(cx, fy)
+                             : zero;
+
+      output[out_index] = img_fxfy + img_cxcy + img_fxcy + img_cxfy;
+    } else {
+      output[out_index] = zero;
+    }
+  }
+}
+
+}  // namespace
+
+namespace functor {
+
+template <typename T>
+struct Resampler2DFunctor<GPUDevice, T> {
+  void operator()(OpKernelContext* ctx, const GPUDevice& d,
+                  const T* __restrict__ data, const T* __restrict__ warp,
+                  T* __restrict__ output, const int batch_size,
+                  const int data_height, const int data_width,
+                  const int data_channels, const int num_sampling_points) {
+    const int output_data_size =
+        batch_size * num_sampling_points * data_channels;
+    GpuLaunchConfig config = GetGpuLaunchConfig(output_data_size, d);
+    TF_CHECK_OK(GpuLaunchKernel(
+        Resampler2DKernel<T>, config.block_count, config.thread_per_block, 0,
+        d.stream(), data, warp, output, batch_size, data_height, data_width,
+        data_channels, num_sampling_points));
+  }
+};
+
+// TODO(fviola): gcudacc fails at compile time with Eigen::half.
+// template struct Resampler2DFunctor<GPUDevice, Eigen::half>;
+template struct Resampler2DFunctor<GPUDevice, float>;
+template struct Resampler2DFunctor<GPUDevice, double>;
+
+}  // namespace functor
+
+namespace {
+
+#define UPDATE_GRAD_DATA_POINT(x, y, v)                                 \
+  atomicAdd(grad_data + (batch_id * data_batch_stride +                 \
+                         data_channels * (y * data_width + x) + chan),  \
+            v)
+
+template <typename T>
+__global__ void ResamplerGrad2DKernel(
+    const T* __restrict__ data, const T* __restrict__ warp,
+    const T* __restrict__ grad_output, T* __restrict__ grad_data,
+    T* __restrict__ grad_warp, const int batch_size, const int data_height,
+    const int data_width, const int data_channels,
+    const int num_sampling_points) {
+  const int resampler_output_size =
+      batch_size * num_sampling_points * data_channels;
+  CUDA_1D_KERNEL_LOOP(index, resampler_output_size) {
+    const int out_index = index;
+
+    // Get (idxSample, channel, point) from the index.
+    // Use this formula
+    //   index = batch_id * num_sampling_points * num_chans +
+    //           sample_id * num_chans + chan_id,
+    // with sample_id = [0, ... ,num_sampling_points)
+    const int data_batch_stride = data_height * data_width * data_channels;
+    const int warp_batch_stride = num_sampling_points * 2;
+    const int output_batch_stride = num_sampling_points * data_channels;
+
+    const int batch_id = index / output_batch_stride;
+    const int index_in_batch = index % output_batch_stride;
+    const int chan = index_in_batch % data_channels;
+    const int sample_id = index_in_batch / data_channels;
+
+    // Get coords of 2D point where data will be resampled
+    const int warp_id_x = batch_id * warp_batch_stride + sample_id * 2;
+    const int warp_id_y = warp_id_x + 1;
+    const T x = warp[warp_id_x];
+    const T y = warp[warp_id_y];
+    const T zero = static_cast<T>(0.0);
+    const T one = static_cast<T>(1.0);
+
+    // Get grad output
+    const T grad_output_value = grad_output[out_index];
+    // The interpolation function whose gradient this kernel implements:
+    // a) implicitly pads the input data with 0s (hence the unusual checks
+    //    with {x,y} > -1)
+    // b) returns 0 when sampling outside the (padded) image.
+    // The effect is that the sampled signal smoothly goes to 0 outside
+    // the original input domain, rather than presenting a jump
+    // discontinuity at the image boundaries.
+    if (x > static_cast<T>(-1.0) && y > static_cast<T>(-1.0) &&
+        x < static_cast<T>(data_width) && y < static_cast<T>(data_height)) {
+      // Precompute floor (f) and ceil (c) values for x and y.
+      const int fx = std::floor(static_cast<float>(x));
+      const int fy = std::floor(static_cast<float>(y));
+      const int cx = fx + 1;
+      const int cy = fy + 1;
+      const T dx = static_cast<T>(cx) - x;
+      const T dy = static_cast<T>(cy) - y;
+
+      const T img_fxfy = (fx >= 0 && fy >= 0) ? GET_DATA_POINT(fx, fy) : zero;
+
+      const T img_cxcy = (cx <= data_width - 1 && cy <= data_height - 1)
+                             ? GET_DATA_POINT(cx, cy)
+                             : zero;
+
+      const T img_fxcy =
+          (fx >= 0 && cy <= data_height - 1) ? GET_DATA_POINT(fx, cy) : zero;
+
+      const T img_cxfy =
+          (cx <= data_width - 1 && fy >= 0) ? GET_DATA_POINT(cx, fy) : zero;
+
+      // Update partial gradients wrt relevant warp field entries
+      atomicAdd(grad_warp + warp_id_x,
+                grad_output_value * ((one - dy) * (img_cxcy - img_fxcy) +
+                                     dy * (img_cxfy - img_fxfy)));
+      atomicAdd(grad_warp + warp_id_y,
+                grad_output_value * ((one - dx) * (img_cxcy - img_cxfy) +
+                                     dx * (img_fxcy - img_fxfy)));
+
+      // Update partial gradients wrt sampled data
+      if (fx >= 0 && fy >= 0) {
+        UPDATE_GRAD_DATA_POINT(fx, fy, grad_output_value * dx * dy);
+      }
+      if (cx <= data_width - 1 && cy <= data_height - 1) {
+        UPDATE_GRAD_DATA_POINT(cx, cy,
+                               grad_output_value * (one - dx) * (one - dy));
+      }
+      if (fx >= 0 && cy <= data_height - 1) {
+        UPDATE_GRAD_DATA_POINT(fx, cy, grad_output_value * dx * (one - dy));
+      }
+      if (cx <= data_width - 1 && fy >= 0) {
+        UPDATE_GRAD_DATA_POINT(cx, fy, grad_output_value * (one - dx) * dy);
+      }
+    }
+  }
+}
+
+#undef GET_DATA_POINT
+#undef UPDATE_GRAD_DATA_POINT
+
+}  // namespace
+
+namespace functor {
+
+template <typename T>
+struct ResamplerGrad2DFunctor<GPUDevice, T> {
+  void operator()(OpKernelContext* ctx, const GPUDevice& d,
+                  const T* __restrict__ data, const T* __restrict__ warp,
+                  const T* __restrict__ grad_output, T* __restrict__ grad_data,
+                  T* __restrict__ grad_warp, const int batch_size,
+                  const int data_height, const int data_width,
+                  const int data_channels, const int num_sampling_points) {
+    // Set gradients to 0, because the kernel incrementally updates the
+    // tensor entries by adding partial contributions.
+    const int grad_warp_size = batch_size * num_sampling_points * 2;
+    const int grad_data_size =
+        batch_size * data_height * data_width * data_channels;
+
+    GpuLaunchConfig config = GetGpuLaunchConfig(grad_warp_size, d);
+    TF_CHECK_OK(GpuLaunchKernel(SetZero<T>, config.block_count,
+                                config.thread_per_block, 0, d.stream(),
+                                grad_warp_size, grad_warp));
+
+    config = GetGpuLaunchConfig(grad_data_size, d);
+    TF_CHECK_OK(GpuLaunchKernel(SetZero<T>, config.block_count,
+                                config.thread_per_block, 0, d.stream(),
+                                grad_data_size, grad_data));
+
+    const int resampler_output_size =
+        batch_size * num_sampling_points * data_channels;
+    config = GetGpuLaunchConfig(resampler_output_size, d);
+    TF_CHECK_OK(GpuLaunchKernel(ResamplerGrad2DKernel<T>, config.block_count,
+                                config.thread_per_block, 0, d.stream(), data,
+                                warp, grad_output, grad_data, grad_warp,
+                                batch_size, data_height, data_width,
+                                data_channels, num_sampling_points));
+  }
+};
+
+template struct ResamplerGrad2DFunctor<GPUDevice, float>;
+
+}  // namespace functor
+}  // namespace addons
+}  // namespace tensorflow
+#endif  // GOOGLE_CUDA
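For reference, the forward kernel computes, per output element, the standard bilinear interpolation, with the image \(I\) read as zero outside its bounds to match the implicit zero padding described in the comments:

\[
\mathrm{out} = d_x d_y\, I(f_x, f_y) + (1-d_x)(1-d_y)\, I(c_x, c_y)
             + d_x (1-d_y)\, I(f_x, c_y) + (1-d_x) d_y\, I(c_x, f_y),
\]

where \(f_x = \lfloor x \rfloor\), \(c_x = f_x + 1\), \(d_x = c_x - x\), and likewise for \(y\). Differentiating with \(\partial d_x / \partial x = -1\) gives exactly the two grad_warp updates the backward kernel accumulates via atomicAdd:

\[
\frac{\partial\,\mathrm{out}}{\partial x}
  = (1-d_y)\big(I(c_x,c_y) - I(f_x,c_y)\big) + d_y\big(I(c_x,f_y) - I(f_x,f_y)\big),
\qquad
\frac{\partial\,\mathrm{out}}{\partial y}
  = (1-d_x)\big(I(c_x,c_y) - I(c_x,f_y)\big) + d_x\big(I(f_x,c_y) - I(f_x,f_y)\big).
\]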
