@@ -63,13 +63,12 @@ template <typename T, int SIMD>
63
63
struct GNRowwiseMomentsFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
64
64
using T_ACC = acc_type_device<T, kXPU >;
65
65
using WelfordType = WelfordData<T_ACC, int64_t >;
66
- using WelfordOp =
67
- WelfordOpsXPU<T_ACC, T_ACC, int64_t , std::pair<T_ACC, T_ACC>>;
66
+ using WelfordOp = WelfordOps<T_ACC, T_ACC, int64_t , std::pair<T_ACC, T_ACC>>;
68
67
69
68
[[intel::reqd_sub_group_size(SIMD)]] void operator ()(
70
69
sycl::nd_item<1 > item) const {
71
70
const int64_t i = item.get_group (0 );
72
- WelfordOp welford_op = {/* correction=*/ 0 , /* take_sqrt=*/ false , item };
71
+ WelfordOp welford_op = {/* correction=*/ 0 , /* take_sqrt=*/ false };
73
72
WelfordType val (0 , 0 , 0 , 0 );
74
73
WelfordType identity_element (0 , 0 , 0 , 0 );
75
74
for (int64_t j = item.get_local_id (0 ); j < N_;
@@ -78,8 +77,13 @@ struct GNRowwiseMomentsFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
78
77
val = welford_op.reduce (val, static_cast <T_ACC>(X_[index]), index);
79
78
}
80
79
81
- val = GroupReduceWithoutBroadcast<WelfordType, WelfordOp, SIMD>(
82
- item, val, welford_op, identity_element, shared_);
80
+ if (item.get_local_range (0 ) <= SIMD) {
81
+ val = SubgroupReduceWithoutBroadcast<WelfordType, WelfordOp, SIMD>(
82
+ item, val, welford_op);
83
+ } else {
84
+ val = GroupReduceWithoutBroadcast<WelfordType, WelfordOp, SIMD>(
85
+ item, val, welford_op, identity_element, shared_);
86
+ }
83
87
84
88
if (item.get_local_id (0 ) == 0 ) {
85
89
T_ACC m1;
@@ -111,15 +115,14 @@ struct GNRowwiseMomentsVectorizedFunctor
111
115
: public __SYCL_KER_CONFIG_CONVENTION__ {
112
116
using T_ACC = acc_type_device<T, kXPU >;
113
117
using WelfordType = WelfordData<T_ACC, int64_t >;
114
- using WelfordOp =
115
- WelfordOpsXPU<T_ACC, T_ACC, int64_t , std::pair<T_ACC, T_ACC>>;
118
+ using WelfordOp = WelfordOps<T_ACC, T_ACC, int64_t , std::pair<T_ACC, T_ACC>>;
116
119
using vec_t = memory::aligned_vector<T, VEC_SIZE>;
117
120
118
121
[[intel::reqd_sub_group_size(SIMD)]] void operator ()(
119
122
sycl::nd_item<1 > item) const {
120
123
WelfordType val[VEC_SIZE];
121
124
WelfordType identity_element (0 , 0 , 0 , 0 );
122
- WelfordOp welford_op = {/* correction=*/ 0 , /* take_sqrt=*/ false , item };
125
+ WelfordOp welford_op = {/* correction=*/ 0 , /* take_sqrt=*/ false };
123
126
auto g_start = item.get_group (0 ) * VEC_SIZE;
124
127
125
128
#pragma unroll
@@ -140,8 +143,15 @@ struct GNRowwiseMomentsVectorizedFunctor
140
143
141
144
#pragma unroll
142
145
for (int v = 0 ; v < VEC_SIZE; ++v) {
143
- val[v] = GroupReduceWithoutBroadcast<WelfordType, WelfordOp, SIMD>(
144
- item, val[v], welford_op, identity_element, shared_);
146
+ // val[v] = GroupReduceWithoutBroadcast<WelfordType, WelfordOp, SIMD>(
147
+ // item, val[v], welford_op, identity_element, shared_);
148
+ if (item.get_local_range (0 ) <= SIMD) {
149
+ val[v] = SubgroupReduceWithoutBroadcast<WelfordType, WelfordOp, SIMD>(
150
+ item, val[v], welford_op);
151
+ } else {
152
+ val[v] = GroupReduceWithoutBroadcast<WelfordType, WelfordOp, SIMD>(
153
+ item, val[v], welford_op, identity_element, shared_);
154
+ }
145
155
}
146
156
147
157
if (item.get_local_id (0 ) == 0 ) {
0 commit comments