Skip to content

Commit 45ee94f

Browse files
authored
【Hackathon 6th No.31】gaussian kernel and normal api support complex -part (#64793)
* update gaussian kernel and normal api * update gaussian kernel and normal api * update kernel * update kernel * update kernel * fix kernel * fix test * fix test * fix test * fix kernel * add error test
1 parent 3dcee14 commit 45ee94f

File tree

8 files changed

+621
-44
lines changed

8 files changed

+621
-44
lines changed

paddle/phi/kernels/cpu/gaussian_inplace_grad_kernel.cc

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,49 @@ limitations under the License. */
1414

1515
#include "paddle/phi/kernels/gaussian_inplace_grad_kernel.h"
1616

17+
#include "paddle/phi/common/type_traits.h"
1718
#include "paddle/phi/core/kernel_registry.h"
1819

1920
namespace phi {
2021

22+
// If T is not complex
23+
template <
24+
typename T,
25+
typename Context,
26+
std::enable_if_t<!std::is_same<T, phi::dtype::complex<float>>::value &&
27+
!std::is_same<T, phi::dtype::complex<double>>::value,
28+
bool> = true>
29+
void GaussianInplaceGrad(const Context& ctx, DenseTensor* x_grad) {
30+
if (x_grad) {
31+
auto* data = ctx.template Alloc<T>(x_grad);
32+
std::fill(data, data + x_grad->numel(), T(0));
33+
}
34+
}
35+
36+
// If T is complex
37+
template <
38+
typename T,
39+
typename Context,
40+
std::enable_if_t<std::is_same<T, phi::dtype::complex<float>>::value ||
41+
std::is_same<T, phi::dtype::complex<double>>::value,
42+
bool> = true>
43+
void GaussianInplaceGrad(const Context& ctx, DenseTensor* x_grad) {
44+
if (x_grad) {
45+
auto* data = ctx.template Alloc<T>(x_grad);
46+
T value = T(static_cast<phi::dtype::Real<T>>(0.0f),
47+
static_cast<phi::dtype::Real<T>>(0.0f));
48+
std::fill(data, data + x_grad->numel(), value);
49+
}
50+
}
51+
2152
// Kernel entry point registered as `gaussian_inplace_grad`.
// out_grad, mean, std and seed are all marked UNUSED; the only work done is
// the GaussianInplaceGrad<T> call on (ctx, x_grad).
// Diff view: the statements after the "28-".."31-" gutter marks are the
// removed inline zero-fill, and the statement after "59+" is its replacement.
template <typename T, typename Context>
2253
void GaussianInplaceGradKernel(const Context& ctx,
2354
const DenseTensor& out_grad UNUSED,
2455
float mean UNUSED,
2556
float std UNUSED,
2657
int seed UNUSED,
2758
DenseTensor* x_grad) {
28-
if (x_grad) {
29-
auto* data = ctx.template Alloc<T>(x_grad);
30-
std::fill(data, data + x_grad->numel(), T(0));
31-
}
59+
GaussianInplaceGrad<T>(ctx, x_grad);
3260
}
3361

3462
} // namespace phi
@@ -38,4 +66,6 @@ PD_REGISTER_KERNEL(gaussian_inplace_grad,
3866
ALL_LAYOUT,
3967
phi::GaussianInplaceGradKernel,
4068
float,
41-
double) {}
69+
double,
70+
phi::dtype::complex<float>,
71+
phi::dtype::complex<double>) {}

paddle/phi/kernels/cpu/gaussian_kernel.cc

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ void GaussianInplaceKernel(const Context& dev_ctx,
4848
int seed,
4949
DenseTensor* out) {
5050
T* data = dev_ctx.template Alloc<T>(out);
51-
std::normal_distribution<T> dist(mean, std);
5251

5352
int64_t size = out->numel();
5453
std::shared_ptr<std::mt19937_64> engine;
@@ -59,9 +58,7 @@ void GaussianInplaceKernel(const Context& dev_ctx,
5958
engine = dev_ctx.GetGenerator()->GetCPUEngine();
6059
}
6160

62-
for (int64_t i = 0; i < size; ++i) {
63-
data[i] = dist(*engine);
64-
}
61+
NormalDistribution<T>(data, size, mean, std, engine);
6562
}
6663

6764
} // namespace phi
@@ -73,11 +70,15 @@ PD_REGISTER_KERNEL(gaussian,
7370
phi::dtype::float16,
7471
phi::dtype::bfloat16,
7572
float,
76-
double) {}
73+
double,
74+
phi::dtype::complex<float>,
75+
phi::dtype::complex<double>) {}
7776

7877
// Registers phi::GaussianInplaceKernel as the CPU `gaussian_inplace` kernel.
// Diff view: the dtype list ending at "double) {}" (after "83-") is replaced
// by one that continues with phi::dtype::complex<float> and
// phi::dtype::complex<double>.
PD_REGISTER_KERNEL(gaussian_inplace,
7978
CPU,
8079
ALL_LAYOUT,
8180
phi::GaussianInplaceKernel,
8281
float,
83-
double) {}
82+
double,
83+
phi::dtype::complex<float>,
84+
phi::dtype::complex<double>) {}

paddle/phi/kernels/funcs/norm_distribution.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,4 +55,34 @@ inline void NormalDistribution(phi::dtype::bfloat16* data,
5555
}
5656
}
5757

58+
template <>
59+
inline void NormalDistribution(phi::dtype::complex<float>* data,
60+
const int64_t& size,
61+
const float& mean,
62+
const float& std,
63+
std::shared_ptr<std::mt19937_64> engine) {
64+
float std_of_real_or_imag = std::sqrt(std::pow(std, 2) / 2);
65+
std::normal_distribution<float> dist(mean, std_of_real_or_imag);
66+
for (int64_t i = 0; i < size; ++i) {
67+
float real = dist(*engine);
68+
float imag = dist(*engine);
69+
data[i] = phi::dtype::complex<float>(real, imag);
70+
}
71+
}
72+
73+
template <>
74+
inline void NormalDistribution(phi::dtype::complex<double>* data,
75+
const int64_t& size,
76+
const float& mean,
77+
const float& std,
78+
std::shared_ptr<std::mt19937_64> engine) {
79+
float std_of_real_or_imag = std::sqrt(std::pow(std, 2) / 2);
80+
std::normal_distribution<double> dist(mean, std_of_real_or_imag);
81+
for (int64_t i = 0; i < size; ++i) {
82+
double real = dist(*engine);
83+
double imag = dist(*engine);
84+
data[i] = phi::dtype::complex<double>(real, imag);
85+
}
86+
}
87+
5888
} // namespace phi

paddle/phi/kernels/gpu/gaussian_inplace_grad_kernel.cu

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,6 @@ PD_REGISTER_KERNEL(gaussian_inplace_grad,
4141
float,
4242
double,
4343
phi::dtype::float16,
44-
phi::dtype::bfloat16) {}
44+
phi::dtype::bfloat16,
45+
phi::dtype::complex<float>,
46+
phi::dtype::complex<double>) {}

paddle/phi/kernels/gpu/gaussian_kernel.cu

Lines changed: 155 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,19 @@
1818

1919
#include "paddle/phi/backends/gpu/gpu_context.h"
2020
#include "paddle/phi/common/amp_type_traits.h"
21+
#include "paddle/phi/common/type_traits.h"
2122
#include "paddle/phi/core/dense_tensor.h"
2223
#include "paddle/phi/core/generator.h"
2324
#include "paddle/phi/core/kernel_registry.h"
25+
#include "paddle/phi/kernels/complex_kernel.h"
2426
#include "paddle/phi/kernels/funcs/distribution_helper.h"
2527
#include "paddle/phi/kernels/funcs/index_impl.cu.h"
2628

2729
namespace phi {
2830

31+
// Shorthand for phi::dtype::complex<T>; used below to partially specialize
// GaussianGenerator for complex element types.
template <typename T>
32+
using ComplexType = phi::dtype::complex<T>;
33+
2934
template <typename T>
3035
struct GaussianGenerator {
3136
T mean_, std_;
@@ -51,8 +56,41 @@ struct GaussianGenerator {
5156
}
5257
};
5358

54-
template <typename T, typename Context>
55-
void GaussianKernel(const Context& dev_ctx,
59+
template <typename T>
60+
struct GaussianGenerator<ComplexType<T>> {
61+
T mean_, std_;
62+
unsigned int seed_;
63+
unsigned int offset_ = 0;
64+
65+
__host__ __device__ GaussianGenerator(T mean, T std, int seed)
66+
: mean_(mean), std_(std), seed_(seed) {}
67+
68+
__host__ __device__ GaussianGenerator(T mean, T std, int seed, int offset)
69+
: mean_(mean), std_(std), seed_(seed), offset_(offset) {}
70+
71+
__host__ __device__ ComplexType<T> operator()(const unsigned int n) const {
72+
thrust::minstd_rand rng_real;
73+
thrust::minstd_rand rng_img;
74+
rng_real.seed(seed_);
75+
rng_img.seed(seed_);
76+
thrust::normal_distribution<T> dist(mean_, std_);
77+
unsigned int new_n = n + offset_;
78+
rng_real.discard(new_n);
79+
rng_img.discard(new_n);
80+
T real = dist(rng_real);
81+
T imag = dist(rng_img);
82+
return ComplexType<T>(real, imag);
83+
}
84+
};
85+
86+
// If T is not complex
87+
template <
88+
typename T,
89+
typename Context,
90+
std::enable_if_t<!std::is_same<T, phi::dtype::complex<float>>::value &&
91+
!std::is_same<T, phi::dtype::complex<double>>::value,
92+
bool> = true>
93+
void GaussianRandom(const Context& dev_ctx,
5694
const IntArray& shape,
5795
float mean,
5896
float std,
@@ -76,8 +114,55 @@ void GaussianKernel(const Context& dev_ctx,
76114
}
77115
}
78116

79-
template <typename T, typename Context>
80-
void GaussianInplaceKernel(const Context& dev_ctx,
117+
// If T is complex
118+
template <
119+
typename T,
120+
typename Context,
121+
std::enable_if_t<std::is_same<T, phi::dtype::complex<float>>::value ||
122+
std::is_same<T, phi::dtype::complex<double>>::value,
123+
bool> = true>
124+
void GaussianRandom(const Context& dev_ctx,
125+
const IntArray& shape,
126+
float mean,
127+
float std,
128+
int seed,
129+
DataType dtype,
130+
DenseTensor* out) {
131+
out->Resize(common::make_ddim(shape.GetData()));
132+
dev_ctx.template Alloc<T>(out);
133+
float std_of_real_or_imag = std::sqrt(std::pow(std, 2) / 2);
134+
if (seed == 0) {
135+
// use global Generator seed
136+
DenseTensor* out_real = new DenseTensor();
137+
DenseTensor* out_imag = new DenseTensor();
138+
out_real->Resize(common::make_ddim(shape.GetData()));
139+
out_imag->Resize(common::make_ddim(shape.GetData()));
140+
dev_ctx.template Alloc<T>(out_real);
141+
dev_ctx.template Alloc<T>(out_imag);
142+
funcs::normal_distribution<phi::dtype::Real<T>> dist;
143+
funcs::normal_distribution<phi::dtype::Real<T>> dist_imag;
144+
funcs::normal_transform<phi::dtype::Real<T>> trans(mean,
145+
std_of_real_or_imag);
146+
funcs::distribution_and_transform<phi::dtype::Real<T>>(
147+
dev_ctx, out_real, dist, trans);
148+
funcs::distribution_and_transform<phi::dtype::Real<T>>(
149+
dev_ctx, out_imag, dist_imag, trans);
150+
phi::ComplexKernel<phi::dtype::Real<T>>(dev_ctx, *out_real, *out_imag, out);
151+
} else {
152+
// use OP seed
153+
auto func = GaussianGenerator<T>(mean, std_of_real_or_imag, seed);
154+
IndexKernel<T, GaussianGenerator<T>>(dev_ctx, out, func);
155+
}
156+
}
157+
158+
// If T is not complex
159+
template <
160+
typename T,
161+
typename Context,
162+
std::enable_if_t<!std::is_same<T, phi::dtype::complex<float>>::value &&
163+
!std::is_same<T, phi::dtype::complex<double>>::value,
164+
bool> = true>
165+
void GaussianRandomInplace(const Context& dev_ctx,
81166
const DenseTensor& x,
82167
float mean,
83168
float std,
@@ -99,6 +184,66 @@ void GaussianInplaceKernel(const Context& dev_ctx,
99184
}
100185
}
101186

187+
// If T is complex
188+
template <
189+
typename T,
190+
typename Context,
191+
std::enable_if_t<std::is_same<T, phi::dtype::complex<float>>::value ||
192+
std::is_same<T, phi::dtype::complex<double>>::value,
193+
bool> = true>
194+
void GaussianRandomInplace(const Context& dev_ctx,
195+
const DenseTensor& x,
196+
float mean,
197+
float std,
198+
int seed,
199+
DenseTensor* out) {
200+
dev_ctx.template Alloc<T>(out);
201+
float std_of_real_or_imag = std::sqrt(std::pow(std, 2) / 2);
202+
if (seed == 0) {
203+
// use global Generator seed
204+
DenseTensor* out_real = new DenseTensor();
205+
DenseTensor* out_imag = new DenseTensor();
206+
out_real->Resize(x.dims());
207+
out_imag->Resize(x.dims());
208+
dev_ctx.template Alloc<T>(out_real);
209+
dev_ctx.template Alloc<T>(out_imag);
210+
funcs::normal_distribution<phi::dtype::Real<T>> dist;
211+
funcs::normal_distribution<phi::dtype::Real<T>> dist_imag;
212+
funcs::normal_transform<phi::dtype::Real<T>> trans(mean,
213+
std_of_real_or_imag);
214+
funcs::distribution_and_transform<phi::dtype::Real<T>>(
215+
dev_ctx, out_real, dist, trans);
216+
funcs::distribution_and_transform<phi::dtype::Real<T>>(
217+
dev_ctx, out_imag, dist_imag, trans);
218+
phi::ComplexKernel<phi::dtype::Real<T>>(dev_ctx, *out_real, *out_imag, out);
219+
} else {
220+
// use OP seed
221+
auto func = GaussianGenerator<T>(mean, std_of_real_or_imag, seed);
222+
IndexKernel<T, GaussianGenerator<T>>(dev_ctx, out, func);
223+
}
224+
}
225+
226+
template <typename T, typename Context>
227+
void GaussianKernel(const Context& dev_ctx,
228+
const IntArray& shape,
229+
float mean,
230+
float std,
231+
int seed,
232+
DataType dtype,
233+
DenseTensor* out) {
234+
GaussianRandom<T>(dev_ctx, shape, mean, std, seed, dtype, out);
235+
}
236+
237+
template <typename T, typename Context>
238+
void GaussianInplaceKernel(const Context& dev_ctx,
239+
const DenseTensor& x,
240+
float mean,
241+
float std,
242+
int seed,
243+
DenseTensor* out) {
244+
GaussianRandomInplace<T>(dev_ctx, x, mean, std, seed, out);
245+
}
246+
102247
} // namespace phi
103248

104249
PD_REGISTER_KERNEL(gaussian,
@@ -108,7 +253,9 @@ PD_REGISTER_KERNEL(gaussian,
108253
phi::dtype::float16,
109254
phi::dtype::bfloat16,
110255
float,
111-
double) {}
256+
double,
257+
phi::dtype::complex<float>,
258+
phi::dtype::complex<double>) {}
112259

113260
PD_REGISTER_KERNEL(gaussian_inplace,
114261
GPU,
@@ -117,4 +264,6 @@ PD_REGISTER_KERNEL(gaussian_inplace,
117264
phi::dtype::float16,
118265
phi::dtype::bfloat16,
119266
float,
120-
double) {}
267+
double,
268+
phi::dtype::complex<float>,
269+
phi::dtype::complex<double>) {}

0 commit comments

Comments
 (0)