@@ -118,11 +118,11 @@ struct GNRowwiseMomentsVectorizedFunctor
      sycl::nd_item<1> item) const {
    WelfordType val[VEC_SIZE];
    WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false, item};
-    auto g_start = item.get_group(0) * VEC_SIZE;
+    auto group_start = item.get_group(0) * VEC_SIZE;

#pragma unroll
    for (int v = 0; v < VEC_SIZE; ++v) {
-      const int64_t i = g_start + v;
+      const int64_t i = group_start + v;
      for (int64_t j = item.get_local_id(0) * VEC_SIZE; j < N_;
           j += item.get_local_range(0) * VEC_SIZE) {
        const int64_t vec_index = i * N_ + j;
@@ -153,8 +153,8 @@ struct GNRowwiseMomentsVectorizedFunctor
        mean_vec[v] = m1;
        rstd_vec[v] = c10::xpu::compat::rsqrt(m2 + static_cast<T_ACC>(eps_));
      }
-      *(reinterpret_cast<vec_t*>(mean_ + g_start)) = mean_vec;
-      *(reinterpret_cast<vec_t*>(rstd_ + g_start)) = rstd_vec;
+      *(reinterpret_cast<vec_t*>(mean_ + group_start)) = mean_vec;
+      *(reinterpret_cast<vec_t*>(rstd_ + group_start)) = rstd_vec;
    }
  }

@@ -934,6 +934,91 @@ struct ComputeInternalGradientsFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
  sycl_local_acc_t<T_ACC> db_shared_;
};

+template <typename T, int SIMD, int VEC_SIZE>
+struct ComputeInternalGradientsVectorizedFunctor
+    : public __SYCL_KER_CONFIG_CONVENTION__ {
+  using T_ACC = acc_type_device<T, kXPU>;
+  using vec_t = memory::aligned_vector<T, VEC_SIZE>;
+  using acc_vec_t = memory::aligned_vector<T_ACC, VEC_SIZE>;
+
+  [[intel::reqd_sub_group_size(SIMD)]] void operator()(
+      sycl::nd_item<1> item) const {
+    acc_vec_t sum1_vec;
+    acc_vec_t sum2_vec;
+
+#pragma unroll
+    for (int v = 0; v < VEC_SIZE; ++v) {
+      sum1_vec[v] = 0;
+      sum2_vec[v] = 0;
+    }
+
+    auto group_start = item.get_group(0) * VEC_SIZE;
+
+#pragma unroll
+    for (int v = 0; v < VEC_SIZE; ++v) {
+      const int64_t nc = group_start + v;
+      for (int64_t hw = item.get_local_id(0) * VEC_SIZE; hw < HxW_;
+           hw += item.get_local_range(0) * VEC_SIZE) {
+        const int64_t vec_index = nc * HxW_ + hw;
+        vec_t vec_dY_ =
+            *reinterpret_cast<vec_t*>(const_cast<T*>(dY_) + vec_index);
+        vec_t vec_X_ =
+            *reinterpret_cast<vec_t*>(const_cast<T*>(X_) + vec_index);
+
+#pragma unroll
+        for (int iv = 0; iv < VEC_SIZE; ++iv) {
+          sum1_vec[v] += static_cast<T_ACC>(vec_dY_[iv] * vec_X_[iv]);
+          sum2_vec[v] += static_cast<T_ACC>(vec_dY_[iv]);
+        }
+      }
+    }
+
+#pragma unroll
+    for (int v = 0; v < VEC_SIZE; ++v) {
+      sum1_vec[v] = GroupReduceSumWithoutBroadcast<T_ACC, SIMD>(
+          item, sum1_vec[v], ds_shared_);
+      sum2_vec[v] = GroupReduceSumWithoutBroadcast<T_ACC, SIMD>(
+          item, sum2_vec[v], db_shared_);
+    }
+
+    if (item.get_local_id(0) == 0) {
+      acc_vec_t ds_vec;
+      acc_vec_t db_vec;
+#pragma unroll
+      for (int v = 0; v < VEC_SIZE; ++v) {
+        ds_vec[v] = sum1_vec[v];
+        db_vec[v] = sum2_vec[v];
+      }
+      *(reinterpret_cast<acc_vec_t*>(ds_ + group_start)) = ds_vec;
+      *(reinterpret_cast<acc_vec_t*>(db_ + group_start)) = db_vec;
+    }
+  }
+
+  void sycl_ker_config_convention(sycl::handler& cgh) {
+    ds_shared_ =
+        sycl_local_acc_t<T_ACC>(get_group_reduce_group_size(SIMD), cgh);
+    db_shared_ =
+        sycl_local_acc_t<T_ACC>(get_group_reduce_group_size(SIMD), cgh);
+  }
+
+  ComputeInternalGradientsVectorizedFunctor(
+      int64_t HxW,
+      const T* dY,
+      const T* X,
+      T_ACC* ds,
+      T_ACC* db)
+      : HxW_(HxW), dY_(dY), X_(X), ds_(ds), db_(db) {}
+
+ private:
+  int64_t HxW_;
+  const T* dY_;
+  const T* X_;
+  T_ACC* ds_;
+  T_ACC* db_;
+  sycl_local_acc_t<T_ACC> ds_shared_;
+  sycl_local_acc_t<T_ACC> db_shared_;
+};
+
template <typename T, typename T_ACC>
struct GroupNormBackwardC1Functor {
  T_ACC operator()(T rstd, T gamma) const {
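For orientation: the functor added above computes, for each flattened (n, c) row of length HxW, the internal gradients ds = sum(dY * X) and db = sum(dY) in the accumulation type; the vectorization only changes how each row is traversed and stored. A minimal scalar host-side sketch of the same reduction (the function name internal_gradients_reference and the plain-float types are illustrative, not part of the patch):

// Scalar reference for the quantity each work-group produces: one ds/db
// pair per (n, c) row. Uses float throughout purely for illustration.
#include <cstdint>

void internal_gradients_reference(
    int64_t NC,       // N * C flattened rows
    int64_t HxW,      // row length
    const float* dY,  // gradient of the output, shape [NC, HxW]
    const float* X,   // input, shape [NC, HxW]
    float* ds,        // out: sum(dY * X) per row, shape [NC]
    float* db) {      // out: sum(dY) per row, shape [NC]
  for (int64_t nc = 0; nc < NC; ++nc) {
    float sum1 = 0.f; // corresponds to sum1_vec[v] in the kernel
    float sum2 = 0.f; // corresponds to sum2_vec[v] in the kernel
    for (int64_t hw = 0; hw < HxW; ++hw) {
      const int64_t i = nc * HxW + hw;
      sum1 += dY[i] * X[i];
      sum2 += dY[i];
    }
    ds[nc] = sum1;
    db[nc] = sum2;
  }
}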
@@ -1272,23 +1357,50 @@ void group_norm_backward_kernel_impl(
  }

  auto& queue = getCurrentSYCLQueue();
-
  int64_t simd = syclMaxSubGroupSize();
-  int64_t wg_size = HxW < get_group_reduce_group_size(simd)
-      ? simd
-      : get_group_reduce_group_size(simd);
-  group_norm_kernel_simd_choice_and_launch<
-      ComputeInternalGradientsFunctor<T, SIMD16>,
-      ComputeInternalGradientsFunctor<T, SIMD32>>(
-      simd,
-      sycl::range<1>(N * C * wg_size),
-      sycl::range<1>(wg_size),
-      queue,
-      HxW,
-      dY_data,
-      X_data,
-      ds_data,
-      db_data);
+
+  constexpr int VEC_SIZE = PREFERRED_VEC_SIZE;
+  int64_t wg_size = 0;
+
+  if (can_use_vectorization(dY_data, VEC_SIZE) &&
+      can_use_vectorization(X_data, VEC_SIZE) &&
+      can_use_vectorization(ds_data, VEC_SIZE) &&
+      can_use_vectorization(db_data, VEC_SIZE) && HxW % VEC_SIZE == 0 &&
+      (N * C) % VEC_SIZE == 0) {
+    using KernelS16T =
+        ComputeInternalGradientsVectorizedFunctor<T, SIMD16, VEC_SIZE>;
+    using KernelS32T =
+        ComputeInternalGradientsVectorizedFunctor<T, SIMD32, VEC_SIZE>;
+    wg_size = (HxW / VEC_SIZE) < get_group_reduce_group_size(simd)
+        ? simd
+        : get_group_reduce_group_size(simd);
+    group_norm_kernel_simd_choice_and_launch<KernelS16T, KernelS32T>(
+        simd,
+        sycl::range<1>((N * C / VEC_SIZE) * wg_size),
+        sycl::range<1>(wg_size),
+        queue,
+        HxW,
+        dY_data,
+        X_data,
+        ds_data,
+        db_data);
+  } else {
+    using KernelS16T = ComputeInternalGradientsFunctor<T, SIMD16>;
+    using KernelS32T = ComputeInternalGradientsFunctor<T, SIMD32>;
+    wg_size = HxW < get_group_reduce_group_size(simd)
+        ? simd
+        : get_group_reduce_group_size(simd);
+    group_norm_kernel_simd_choice_and_launch<KernelS16T, KernelS32T>(
+        simd,
+        sycl::range<1>(N * C * wg_size),
+        sycl::range<1>(wg_size),
+        queue,
+        HxW,
+        dY_data,
+        X_data,
+        ds_data,
+        db_data);
+  }

  if (dX.defined()) {
    Tensor c1 = at::empty({0}, X.options().dtype(kAccType));
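The vectorized branch above changes the launch geometry: each work-group now handles VEC_SIZE consecutive (n, c) rows, so the global range shrinks from N * C groups to N * C / VEC_SIZE groups of wg_size work-items, and the divisibility checks (HxW % VEC_SIZE == 0, (N * C) % VEC_SIZE == 0) keep every aligned_vector load and the final VEC_SIZE-wide ds/db store in bounds. A plausible sketch of the pointer test, assuming can_use_vectorization is essentially an alignment check against the vector width (the helper name aligned_for_vec below is hypothetical):

// Hypothetical stand-in for can_use_vectorization: the vectorized kernel
// reinterprets T* as aligned_vector<T, VEC_SIZE>*, so the base address must
// be a multiple of VEC_SIZE * sizeof(T) for those loads/stores to be legal.
#include <cstddef>
#include <cstdint>

template <typename T>
bool aligned_for_vec(const T* ptr, int vec_size) {
  return reinterpret_cast<std::uintptr_t>(ptr) % (vec_size * sizeof(T)) == 0;
}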
@@ -1373,8 +1485,8 @@ void group_norm_backward_kernel_impl(
      sycl_kernel_submit(sycl::range<1>(C), queue, caller);
    } else {
      // The algorithm for colwise reduction here is to accumulate each
-      // (subgroup_size) cols to a (subgroup_size^2) tile and write the tile to
-      // shared memory. Then do subgroup reduce for each col in the tile.
+      // (subgroup_size) cols to a (subgroup_size^2) tile and write the tile
+      // to shared memory. Then do subgroup reduce for each col in the tile.
      const int64_t kReduceTileSize = simd;
      const int64_t B = (C + kReduceTileSize - 1) / kReduceTileSize;
      auto global_range =
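The rewrapped comment above describes a two-phase tiled reduction: subgroup_size columns are first accumulated into a subgroup_size x subgroup_size tile in shared memory, and each column of the tile is then reduced by a subgroup. A serial sketch of that idea (plain C++ rather than SYCL; the names S, data, rows, and tiled_colwise_sum are illustrative, not from the kernel):

// Conceptual, single-threaded model of the tiled colwise reduction: phase 1
// folds every S-th row into an S x S tile (the shared-memory tile in the
// kernel); phase 2 reduces each of the S columns of that tile.
#include <array>
#include <cstdint>
#include <vector>

constexpr int S = 16; // stand-in for the subgroup size

std::array<float, S> tiled_colwise_sum(const std::vector<float>& data,
                                       int64_t rows) {
  // data is row-major with S columns: element (i, c) is data[i * S + c].
  std::array<std::array<float, S>, S> tile{};
  for (int r = 0; r < S; ++r) {   // phase 1: accumulate rows r, r+S, ... into tile row r
    for (int c = 0; c < S; ++c) {
      for (int64_t i = r; i < rows; i += S) {
        tile[r][c] += data[i * S + c];
      }
    }
  }
  std::array<float, S> col_sum{};
  for (int c = 0; c < S; ++c) {   // phase 2: reduce each tile column
    for (int r = 0; r < S; ++r) {
      col_sum[c] += tile[r][c];
    }
  }
  return col_sum;
}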