Skip to content

Commit 3377027

Browse files
committed
update
1 parent 82f5fab commit 3377027

File tree

2 files changed

+36
-15
lines changed

2 files changed

+36
-15
lines changed

src/ATen/native/xpu/sycl/GroupNormKernels.cpp

Lines changed: 20 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -63,13 +63,12 @@ template <typename T, int SIMD>
6363
struct GNRowwiseMomentsFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
6464
using T_ACC = acc_type_device<T, kXPU>;
6565
using WelfordType = WelfordData<T_ACC, int64_t>;
66-
using WelfordOp =
67-
WelfordOpsXPU<T_ACC, T_ACC, int64_t, std::pair<T_ACC, T_ACC>>;
66+
using WelfordOp = WelfordOps<T_ACC, T_ACC, int64_t, std::pair<T_ACC, T_ACC>>;
6867

6968
[[intel::reqd_sub_group_size(SIMD)]] void operator()(
7069
sycl::nd_item<1> item) const {
7170
const int64_t i = item.get_group(0);
72-
WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false, item};
71+
WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false};
7372
WelfordType val(0, 0, 0, 0);
7473
WelfordType identity_element(0, 0, 0, 0);
7574
for (int64_t j = item.get_local_id(0); j < N_;
@@ -78,8 +77,13 @@ struct GNRowwiseMomentsFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
7877
val = welford_op.reduce(val, static_cast<T_ACC>(X_[index]), index);
7978
}
8079

81-
val = GroupReduceWithoutBroadcast<WelfordType, WelfordOp, SIMD>(
82-
item, val, welford_op, identity_element, shared_);
80+
if (item.get_local_range(0) <= SIMD) {
81+
val = SubgroupReduceWithoutBroadcast<WelfordType, WelfordOp, SIMD>(
82+
item, val, welford_op);
83+
} else {
84+
val = GroupReduceWithoutBroadcast<WelfordType, WelfordOp, SIMD>(
85+
item, val, welford_op, identity_element, shared_);
86+
}
8387

8488
if (item.get_local_id(0) == 0) {
8589
T_ACC m1;
@@ -111,15 +115,14 @@ struct GNRowwiseMomentsVectorizedFunctor
111115
: public __SYCL_KER_CONFIG_CONVENTION__ {
112116
using T_ACC = acc_type_device<T, kXPU>;
113117
using WelfordType = WelfordData<T_ACC, int64_t>;
114-
using WelfordOp =
115-
WelfordOpsXPU<T_ACC, T_ACC, int64_t, std::pair<T_ACC, T_ACC>>;
118+
using WelfordOp = WelfordOps<T_ACC, T_ACC, int64_t, std::pair<T_ACC, T_ACC>>;
116119
using vec_t = memory::aligned_vector<T, VEC_SIZE>;
117120

118121
[[intel::reqd_sub_group_size(SIMD)]] void operator()(
119122
sycl::nd_item<1> item) const {
120123
WelfordType val[VEC_SIZE];
121124
WelfordType identity_element(0, 0, 0, 0);
122-
WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false, item};
125+
WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false};
123126
auto g_start = item.get_group(0) * VEC_SIZE;
124127

125128
#pragma unroll
@@ -140,8 +143,15 @@ struct GNRowwiseMomentsVectorizedFunctor
140143

141144
#pragma unroll
142145
for (int v = 0; v < VEC_SIZE; ++v) {
143-
val[v] = GroupReduceWithoutBroadcast<WelfordType, WelfordOp, SIMD>(
144-
item, val[v], welford_op, identity_element, shared_);
146+
// val[v] = GroupReduceWithoutBroadcast<WelfordType, WelfordOp, SIMD>(
147+
// item, val[v], welford_op, identity_element, shared_);
148+
if (item.get_local_range(0) <= SIMD) {
149+
val[v] = SubgroupReduceWithoutBroadcast<WelfordType, WelfordOp, SIMD>(
150+
item, val[v], welford_op);
151+
} else {
152+
val[v] = GroupReduceWithoutBroadcast<WelfordType, WelfordOp, SIMD>(
153+
item, val[v], welford_op, identity_element, shared_);
154+
}
145155
}
146156

147157
if (item.get_local_id(0) == 0) {

src/ATen/native/xpu/sycl/GroupReduceUtils.h

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,12 @@ inline T& SubgroupReduceWithoutBroadcast(
119119
auto sg = item.get_sub_group();
120120
auto sg_tid = sg.get_local_linear_id();
121121
#pragma unroll
122-
for (int offset = 1; offset < SIMD; offset <<= 1) {
123-
if (sg_tid < SIMD - offset) {
124-
val = op.combine(val, sycl::shift_group_left(sg, val, offset));
125-
}
122+
for (int offset = (SIMD >> 1); offset > 0; offset >>= 1) {
123+
// for (int offset = 1; offset < SIMD; offset <<= 1) {
124+
T temp = sycl::shift_group_left(sg, val, offset);
125+
// if (sg_tid < SIMD - offset) {
126+
val = op.combine(val, temp);
127+
// }
126128
}
127129
return val;
128130
}
@@ -135,6 +137,7 @@ inline T& GroupReduceWithoutBroadcast(
135137
const T& identity_element,
136138
shared_t shared) {
137139
auto sg = item.get_sub_group();
140+
int g_tid = item.get_local_linear_id();
138141
int sg_tid = sg.get_local_linear_id();
139142
int sg_id = sg.get_group_linear_id();
140143
int n_sg = get_local_linear_range<DIM>(item) / SIMD;
@@ -148,9 +151,17 @@ inline T& GroupReduceWithoutBroadcast(
148151
shared[sg_id] = val;
149152
}
150153
item.barrier(sycl_local_fence);
151-
val = (sg_id < n_sg) ? shared[sg_id] : identity_element;
154+
// val = (g_tid < n_sg) ? shared[sg_id] : identity_element;
155+
val = identity_element;
156+
152157
if (sg_id == 0) {
158+
for (int i = sg_tid; i < n_sg; i += SIMD) {
159+
val = op.combine(val, shared[i]);
160+
}
153161
val = SubgroupReduceWithoutBroadcast<T, ReduceOp, SIMD, DIM>(item, val, op);
162+
// for (int i = 1; i < n_sg; i++) {
163+
// val = op.combine(val, shared[i]);
164+
// }
154165
}
155166
return val;
156167
}

0 commit comments

Comments
 (0)