diff --git a/src/ATen/native/xpu/sycl/GroupNormKernels.cpp b/src/ATen/native/xpu/sycl/GroupNormKernels.cpp index 935ab99f7..53de4b00a 100644 --- a/src/ATen/native/xpu/sycl/GroupNormKernels.cpp +++ b/src/ATen/native/xpu/sycl/GroupNormKernels.cpp @@ -63,22 +63,27 @@ template struct GNRowwiseMomentsFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { using T_ACC = acc_type_device; using WelfordType = WelfordData; - using WelfordOp = - WelfordOpsXPU>; + using WelfordOp = WelfordOps>; [[intel::reqd_sub_group_size(SIMD)]] void operator()( sycl::nd_item<1> item) const { const int64_t i = item.get_group(0); - WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false, item}; + WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false}; WelfordType val(0, 0, 0, 0); + WelfordType identity_element(0, 0, 0, 0); for (int64_t j = item.get_local_id(0); j < N_; j += item.get_local_range(0)) { const int64_t index = i * N_ + j; val = welford_op.reduce(val, static_cast(X_[index]), index); } - val = GroupReduceWithoutBroadcast( - item, val, welford_op, shared_); + if (item.get_local_range(0) <= SIMD) { + val = SubgroupReduceWithoutBroadcast( + item, val, welford_op); + } else { + val = GroupReduceWithoutBroadcast( + item, val, welford_op, identity_element, shared_); + } if (item.get_local_id(0) == 0) { T_ACC m1; @@ -110,14 +115,14 @@ struct GNRowwiseMomentsVectorizedFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { using T_ACC = acc_type_device; using WelfordType = WelfordData; - using WelfordOp = - WelfordOpsXPU>; + using WelfordOp = WelfordOps>; using vec_t = memory::aligned_vector; [[intel::reqd_sub_group_size(SIMD)]] void operator()( sycl::nd_item<1> item) const { WelfordType val[VEC_SIZE]; - WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false, item}; + WelfordType identity_element(0, 0, 0, 0); + WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false}; auto group_start = item.get_group(0) * VEC_SIZE; #pragma unroll @@ -138,8 +143,13 @@ struct GNRowwiseMomentsVectorizedFunctor #pragma unroll for (int v = 0; v < VEC_SIZE; ++v) { - val[v] = GroupReduceWithoutBroadcast( - item, val[v], welford_op, shared_); + if (item.get_local_range(0) <= SIMD) { + val[v] = SubgroupReduceWithoutBroadcast( + item, val[v], welford_op); + } else { + val[v] = GroupReduceWithoutBroadcast( + item, val[v], welford_op, identity_element, shared_); + } } if (item.get_local_id(0) == 0) { diff --git a/src/ATen/native/xpu/sycl/GroupReduceUtils.h b/src/ATen/native/xpu/sycl/GroupReduceUtils.h index 5e383669e..20ef59974 100644 --- a/src/ATen/native/xpu/sycl/GroupReduceUtils.h +++ b/src/ATen/native/xpu/sycl/GroupReduceUtils.h @@ -136,8 +136,10 @@ inline T& GroupReduceWithoutBroadcast( sycl::nd_item& item, T& val, const ReduceOp& op, + const T& identity_element, shared_t shared) { auto sg = item.get_sub_group(); + int g_tid = item.get_local_linear_id(); int sg_tid = sg.get_local_linear_id(); int sg_id = sg.get_group_linear_id(); int n_sg = get_local_linear_range(item) / SIMD; @@ -151,10 +153,13 @@ inline T& GroupReduceWithoutBroadcast( shared[sg_id] = val; } item.barrier(sycl_local_fence); + val = identity_element; + if (sg_id == 0) { - for (int i = 1; i < n_sg; i++) { + for (int i = sg_tid; i < n_sg; i += SIMD) { val = op.combine(val, shared[i]); } + val = SubgroupReduceWithoutBroadcast(item, val, op); } return val; } diff --git a/src/ATen/native/xpu/sycl/LayerNormKernels.cpp b/src/ATen/native/xpu/sycl/LayerNormKernels.cpp index 2d62ad058..7365d9741 100644 --- a/src/ATen/native/xpu/sycl/LayerNormKernels.cpp +++ b/src/ATen/native/xpu/sycl/LayerNormKernels.cpp @@ -191,6 +191,7 @@ struct RowwiseMomentsFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { const int64_t i = item_id.get_group(0); WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false}; WelfordType val(0, 0, 0, 0); + WelfordType identity_element(0, 0, 0, 0); for (int64_t j = item_id.get_local_id(0); j < N_; j += item_id.get_local_range(0)) { const int64_t index = i * N_ + j; @@ -198,7 +199,7 @@ struct RowwiseMomentsFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { } val = GroupReduceWithoutBroadcast( - item_id, val, welford_op, shared_); + item_id, val, welford_op, identity_element, shared_); if (item_id.get_local_id(0) == 0) { T_ACC m1; diff --git a/src/ATen/native/xpu/sycl/TensorModeKernel.cpp b/src/ATen/native/xpu/sycl/TensorModeKernel.cpp index 7ae95e36b..1e5a2fc6c 100644 --- a/src/ATen/native/xpu/sycl/TensorModeKernel.cpp +++ b/src/ATen/native/xpu/sycl/TensorModeKernel.cpp @@ -217,6 +217,7 @@ inline T reduceGroupWithNThreadLocalReductions( T init) { int offset = item.get_local_id(2) * N; T local = offset < numVals ? threadVals[0] : init; + T identity_element = init; #pragma unroll for (int i = 1; i < N; ++i) { @@ -226,7 +227,7 @@ inline T reduceGroupWithNThreadLocalReductions( } return GroupReduceWithoutBroadcast( - item, local, reduceOp, smem); + item, local, reduceOp, identity_element, smem); } template