18
18
19
19
#include " paddle/phi/backends/gpu/gpu_context.h"
20
20
#include " paddle/phi/common/amp_type_traits.h"
21
+ #include " paddle/phi/common/type_traits.h"
21
22
#include " paddle/phi/core/dense_tensor.h"
22
23
#include " paddle/phi/core/generator.h"
23
24
#include " paddle/phi/core/kernel_registry.h"
25
+ #include " paddle/phi/kernels/complex_kernel.h"
24
26
#include " paddle/phi/kernels/funcs/distribution_helper.h"
25
27
#include " paddle/phi/kernels/funcs/index_impl.cu.h"
26
28
27
29
namespace phi {
28
30
31
+ template <typename T>
32
+ using ComplexType = phi::dtype::complex <T>;
33
+
29
34
template <typename T>
30
35
struct GaussianGenerator {
31
36
T mean_, std_;
@@ -51,8 +56,41 @@ struct GaussianGenerator {
51
56
}
52
57
};
53
58
54
- template <typename T, typename Context>
55
- void GaussianKernel (const Context& dev_ctx,
59
+ template <typename T>
60
+ struct GaussianGenerator <ComplexType<T>> {
61
+ T mean_, std_;
62
+ unsigned int seed_;
63
+ unsigned int offset_ = 0 ;
64
+
65
+ __host__ __device__ GaussianGenerator (T mean, T std, int seed)
66
+ : mean_(mean), std_(std), seed_(seed) {}
67
+
68
+ __host__ __device__ GaussianGenerator (T mean, T std, int seed, int offset)
69
+ : mean_(mean), std_(std), seed_(seed), offset_(offset) {}
70
+
71
+ __host__ __device__ ComplexType<T> operator ()(const unsigned int n) const {
72
+ thrust::minstd_rand rng_real;
73
+ thrust::minstd_rand rng_img;
74
+ rng_real.seed (seed_);
75
+ rng_img.seed (seed_);
76
+ thrust::normal_distribution<T> dist (mean_, std_);
77
+ unsigned int new_n = n + offset_;
78
+ rng_real.discard (new_n);
79
+ rng_img.discard (new_n);
80
+ T real = dist (rng_real);
81
+ T imag = dist (rng_img);
82
+ return ComplexType<T>(real, imag);
83
+ }
84
+ };
85
+
86
+ // If T is not complex
87
+ template <
88
+ typename T,
89
+ typename Context,
90
+ std::enable_if_t <!std::is_same<T, phi::dtype::complex <float >>::value &&
91
+ !std::is_same<T, phi::dtype::complex <double >>::value,
92
+ bool > = true >
93
+ void GaussianRandom (const Context& dev_ctx,
56
94
const IntArray& shape,
57
95
float mean,
58
96
float std,
@@ -76,8 +114,55 @@ void GaussianKernel(const Context& dev_ctx,
76
114
}
77
115
}
78
116
79
- template <typename T, typename Context>
80
- void GaussianInplaceKernel (const Context& dev_ctx,
117
+ // If T is complex
118
+ template <
119
+ typename T,
120
+ typename Context,
121
+ std::enable_if_t <std::is_same<T, phi::dtype::complex <float >>::value ||
122
+ std::is_same<T, phi::dtype::complex <double >>::value,
123
+ bool > = true >
124
+ void GaussianRandom (const Context& dev_ctx,
125
+ const IntArray& shape,
126
+ float mean,
127
+ float std,
128
+ int seed,
129
+ DataType dtype,
130
+ DenseTensor* out) {
131
+ out->Resize (common::make_ddim (shape.GetData ()));
132
+ dev_ctx.template Alloc <T>(out);
133
+ float std_of_real_or_imag = std::sqrt (std::pow (std, 2 ) / 2 );
134
+ if (seed == 0 ) {
135
+ // use global Generator seed
136
+ DenseTensor* out_real = new DenseTensor ();
137
+ DenseTensor* out_imag = new DenseTensor ();
138
+ out_real->Resize (common::make_ddim (shape.GetData ()));
139
+ out_imag->Resize (common::make_ddim (shape.GetData ()));
140
+ dev_ctx.template Alloc <T>(out_real);
141
+ dev_ctx.template Alloc <T>(out_imag);
142
+ funcs::normal_distribution<phi::dtype::Real<T>> dist;
143
+ funcs::normal_distribution<phi::dtype::Real<T>> dist_imag;
144
+ funcs::normal_transform<phi::dtype::Real<T>> trans (mean,
145
+ std_of_real_or_imag);
146
+ funcs::distribution_and_transform<phi::dtype::Real<T>>(
147
+ dev_ctx, out_real, dist, trans);
148
+ funcs::distribution_and_transform<phi::dtype::Real<T>>(
149
+ dev_ctx, out_imag, dist_imag, trans);
150
+ phi::ComplexKernel<phi::dtype::Real<T>>(dev_ctx, *out_real, *out_imag, out);
151
+ } else {
152
+ // use OP seed
153
+ auto func = GaussianGenerator<T>(mean, std_of_real_or_imag, seed);
154
+ IndexKernel<T, GaussianGenerator<T>>(dev_ctx, out, func);
155
+ }
156
+ }
157
+
158
+ // If T is not complex
159
+ template <
160
+ typename T,
161
+ typename Context,
162
+ std::enable_if_t <!std::is_same<T, phi::dtype::complex <float >>::value &&
163
+ !std::is_same<T, phi::dtype::complex <double >>::value,
164
+ bool > = true >
165
+ void GaussianRandomInplace (const Context& dev_ctx,
81
166
const DenseTensor& x,
82
167
float mean,
83
168
float std,
@@ -99,6 +184,66 @@ void GaussianInplaceKernel(const Context& dev_ctx,
99
184
}
100
185
}
101
186
187
+ // If T is complex
188
+ template <
189
+ typename T,
190
+ typename Context,
191
+ std::enable_if_t <std::is_same<T, phi::dtype::complex <float >>::value ||
192
+ std::is_same<T, phi::dtype::complex <double >>::value,
193
+ bool > = true >
194
+ void GaussianRandomInplace (const Context& dev_ctx,
195
+ const DenseTensor& x,
196
+ float mean,
197
+ float std,
198
+ int seed,
199
+ DenseTensor* out) {
200
+ dev_ctx.template Alloc <T>(out);
201
+ float std_of_real_or_imag = std::sqrt (std::pow (std, 2 ) / 2 );
202
+ if (seed == 0 ) {
203
+ // use global Generator seed
204
+ DenseTensor* out_real = new DenseTensor ();
205
+ DenseTensor* out_imag = new DenseTensor ();
206
+ out_real->Resize (x.dims ());
207
+ out_imag->Resize (x.dims ());
208
+ dev_ctx.template Alloc <T>(out_real);
209
+ dev_ctx.template Alloc <T>(out_imag);
210
+ funcs::normal_distribution<phi::dtype::Real<T>> dist;
211
+ funcs::normal_distribution<phi::dtype::Real<T>> dist_imag;
212
+ funcs::normal_transform<phi::dtype::Real<T>> trans (mean,
213
+ std_of_real_or_imag);
214
+ funcs::distribution_and_transform<phi::dtype::Real<T>>(
215
+ dev_ctx, out_real, dist, trans);
216
+ funcs::distribution_and_transform<phi::dtype::Real<T>>(
217
+ dev_ctx, out_imag, dist_imag, trans);
218
+ phi::ComplexKernel<phi::dtype::Real<T>>(dev_ctx, *out_real, *out_imag, out);
219
+ } else {
220
+ // use OP seed
221
+ auto func = GaussianGenerator<T>(mean, std_of_real_or_imag, seed);
222
+ IndexKernel<T, GaussianGenerator<T>>(dev_ctx, out, func);
223
+ }
224
+ }
225
+
226
+ template <typename T, typename Context>
227
+ void GaussianKernel (const Context& dev_ctx,
228
+ const IntArray& shape,
229
+ float mean,
230
+ float std,
231
+ int seed,
232
+ DataType dtype,
233
+ DenseTensor* out) {
234
+ GaussianRandom<T>(dev_ctx, shape, mean, std, seed, dtype, out);
235
+ }
236
+
237
+ template <typename T, typename Context>
238
+ void GaussianInplaceKernel (const Context& dev_ctx,
239
+ const DenseTensor& x,
240
+ float mean,
241
+ float std,
242
+ int seed,
243
+ DenseTensor* out) {
244
+ GaussianRandomInplace<T>(dev_ctx, x, mean, std, seed, out);
245
+ }
246
+
102
247
} // namespace phi
103
248
104
249
PD_REGISTER_KERNEL (gaussian,
@@ -108,7 +253,9 @@ PD_REGISTER_KERNEL(gaussian,
108
253
phi::dtype::float16,
109
254
phi::dtype::bfloat16,
110
255
float ,
111
- double ) {}
256
+ double ,
257
+ phi::dtype::complex <float >,
258
+ phi::dtype::complex <double >) {}
112
259
113
260
PD_REGISTER_KERNEL (gaussian_inplace,
114
261
GPU,
@@ -117,4 +264,6 @@ PD_REGISTER_KERNEL(gaussian_inplace,
117
264
phi::dtype::float16,
118
265
phi::dtype::bfloat16,
119
266
float ,
120
- double ) {}
267
+ double ,
268
+ phi::dtype::complex <float >,
269
+ phi::dtype::complex <double >) {}
0 commit comments