Skip to content

Commit 22e7cb9

Browse files
committed
Merge from main branch
Signed-off-by: Cheng, Penghui <[email protected]>
2 parents c71e099 + 5907931 commit 22e7cb9

File tree

10 files changed

+457
-186
lines changed

10 files changed

+457
-186
lines changed

.github/scripts/apply_torch_pr.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -56,17 +56,12 @@ def check_merged(pr_info):
5656
merged = False
5757
return merged
5858

59-
def appyly_pr(pull_number, re_apply_msg):
60-
# get the diff
61-
os.system(f"\
62-
git fetch origin pull/{pull_number}/head:{pull_number} && \
63-
git checkout -f {pull_number} && \
64-
git merge ci-tmp-$(hostname) --no-edit --no-ff > /dev/null && \
65-
git diff ci-tmp-$(hostname) {pull_number} > {pull_number}.diff \
66-
")
59+
def appyly_pr(pr_info, repo_info, re_apply_msg):
60+
# get pr diff
61+
pr_file = pr_info["diff_url"].split("/")[-1]
62+
os.system(f"gh --repo {repo_info[-4]}/{repo_info[-3]} pr diff {repo_info[-1]} > {pr_file}")
6763
# apply diff
68-
os.system("git checkout ci-test-$(hostname)")
69-
apply_cmd = f"git reset --hard && git apply --3way {pull_number}.diff"
64+
apply_cmd = "git apply --3way " + pr_file
7065
apply_info = subprocess.Popen(apply_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True)
7166
apply_message = apply_info.communicate()[0].decode("utf-8")
7267
apply_status = apply_info.returncode
@@ -88,9 +83,6 @@ def appyly_pr(pull_number, re_apply_msg):
8883
pr_list = args.pr_list + args.extra_pr_list
8984
pr_list = set(pr_list)
9085
pr_list = sorted(pr_list)
91-
# checkout a base branch
92-
os.system("git checkout -b ci-tmp-$(hostname) && git checkout -b ci-test-$(hostname) && rm -f *.diff")
93-
os.system("git config --global user.email intel.com && git config --global user.name intel")
9486
for pr_link in pr_list:
9587
repo_info = pr_link.split("/")
9688
pr_info = requests.get('https://api.' + repo_info[-5] + '/repos/' + repo_info[-4] + '/' + \
@@ -107,7 +99,7 @@ def appyly_pr(pull_number, re_apply_msg):
10799
continue
108100
else:
109101
re_apply_msg = "is re-opened and reverted,"
110-
appyly_pr(repo_info[-1], re_apply_msg)
102+
appyly_pr(pr_info, repo_info, re_apply_msg)
111103
elif pr_info["state"].lower() == "closed":
112104
merged_id = next((item["id"] for item in pr_info["labels"] if item["name"] == "Merged"), -1)
113105
re_apply_msg = "is closed but not merged"
@@ -116,8 +108,7 @@ def appyly_pr(pull_number, re_apply_msg):
116108
if merged:
117109
print("{} is closed and merged, no need to apply".format(pr_info["diff_url"]))
118110
continue
119-
appyly_pr(repo_info[-1], re_apply_msg)
111+
appyly_pr(pr_info, repo_info, re_apply_msg)
120112
else:
121113
print("{} is {}, no need to apply".format(pr_info["diff_url"], pr_info["state"]))
122114
sys.exit(1)
123-
os.system("git checkout ci-test-$(hostname) && git reset --hard && git apply --3way *.diff && git status")

.github/workflows/_linux_ut.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ jobs:
119119
cd third_party/torch-xpu-ops
120120
git checkout ${TORCH_XPU_OPS_COMMIT}
121121
cd ../..
122-
python third_party/torch-xpu-ops/.github/scripts/apply_torch_pr.py
122+
python ../torch-xpu-ops/.github/scripts/apply_torch_pr.py
123123
fi
124124
pip install -r .ci/docker/requirements-ci.txt
125125
- name: Torch Config
@@ -412,7 +412,7 @@ jobs:
412412
cd third_party/torch-xpu-ops
413413
git checkout ${TORCH_XPU_OPS_COMMIT}
414414
cd ../..
415-
python third_party/torch-xpu-ops/.github/scripts/apply_torch_pr.py
415+
python ../torch-xpu-ops/.github/scripts/apply_torch_pr.py
416416
fi
417417
pip install -r .ci/docker/requirements-ci.txt
418418
- name: Torch Config

.github/workflows/nightly_ondemand_whl.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,6 @@ jobs:
116116
cd pytorch && git checkout ${TORCH_COMMIT_ID}
117117
# apply PRs for stock pytorch
118118
pip install requests
119-
# python ../torch-xpu-ops/.github/scripts/apply_torch_pr.py
120119
git status && git show -s
121120
pip install -r requirements.txt
122121
TORCH_XPU_OPS_COMMIT=$(<third_party/xpu.txt)
@@ -126,7 +125,7 @@ jobs:
126125
cd third_party/torch-xpu-ops
127126
git checkout ${TORCH_XPU_OPS_COMMIT}
128127
cd ../../
129-
python third_party/torch-xpu-ops/.github/scripts/apply_torch_pr.py
128+
python ../torch-xpu-ops/.github/scripts/apply_torch_pr.py
130129
- name: Identify pinned versions
131130
id: pinned
132131
run: |

src/ATen/native/xpu/SpectralOps.cpp

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1+
#include <ATen/native/Resize.h>
2+
#include <ATen/ops/_fft_r2c_native.h>
13
#if defined(USE_ONEMKL_XPU)
24
#include <ATen/native/xpu/mkl/SpectralOps.h>
35
#else
4-
#include <ATen/native/Resize.h>
56
#include <ATen/ops/_fft_c2c_native.h>
67
#include <ATen/ops/_fft_c2r_native.h>
7-
#include <ATen/ops/_fft_r2c_native.h>
88
#endif // USE_ONEMKL_XPU
99

1010
namespace at::native {
@@ -87,13 +87,9 @@ Tensor _fft_r2c_xpu(
8787
bool onesided) {
8888
TORCH_CHECK(self.is_floating_point());
8989

90-
#if defined(USE_ONEMKL_XPU)
91-
return native::xpu::_fft_r2c_mkl(self, dim, normalization, onesided);
92-
#else
9390
Tensor out_cpu = native::_fft_r2c_mkl(
9491
self.to(Device(at::kCPU)), dim, normalization, onesided);
9592
return out_cpu.to(Device(at::kXPU));
96-
#endif // USE_ONEMKL_XPU
9793
}
9894

9995
Tensor& _fft_r2c_xpu_out(
@@ -104,15 +100,11 @@ Tensor& _fft_r2c_xpu_out(
104100
Tensor& out) {
105101
TORCH_CHECK(self.is_floating_point());
106102

107-
#if defined(USE_ONEMKL_XPU)
108-
return native::xpu::_fft_r2c_mkl_out(self, dim, normalization, onesided, out);
109-
#else
110103
Tensor out_cpu = native::_fft_r2c_mkl(
111104
self.to(Device(at::kCPU)), dim, normalization, onesided);
112105
at::native::resize_output(out, out_cpu.sizes());
113106
out.copy_(out_cpu);
114107
return out;
115-
#endif // USE_ONEMKL_XPU
116108
}
117109

118110
} // namespace at::native

src/ATen/native/xpu/XPUFallback.template

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
193193
"_flash_attention_forward",
194194
"geqrf",
195195
"linalg_cholesky_ex.L",
196-
"_linalg_det.result",
197196
"linalg_eig",
198197
"_linalg_eigvals",
199198
"linalg_eigvals.out",
@@ -206,8 +205,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
206205
"linalg_lu.out",
207206
"linalg_matrix_exp",
208207
"linalg_qr.out",
209-
"_linalg_slogdet.sign",
210-
"_linalg_solve_ex.result",
211208
"linalg_solve_triangular",
212209
"_linalg_svd.U",
213210
"lu_unpack.out",

src/ATen/native/xpu/sycl/GroupNormKernels.cpp

Lines changed: 134 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,11 @@ struct GNRowwiseMomentsVectorizedFunctor
118118
sycl::nd_item<1> item) const {
119119
WelfordType val[VEC_SIZE];
120120
WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false, item};
121-
auto g_start = item.get_group(0) * VEC_SIZE;
121+
auto group_start = item.get_group(0) * VEC_SIZE;
122122

123123
#pragma unroll
124124
for (int v = 0; v < VEC_SIZE; ++v) {
125-
const int64_t i = g_start + v;
125+
const int64_t i = group_start + v;
126126
for (int64_t j = item.get_local_id(0) * VEC_SIZE; j < N_;
127127
j += item.get_local_range(0) * VEC_SIZE) {
128128
const int64_t vec_index = i * N_ + j;
@@ -153,8 +153,8 @@ struct GNRowwiseMomentsVectorizedFunctor
153153
mean_vec[v] = m1;
154154
rstd_vec[v] = c10::xpu::compat::rsqrt(m2 + static_cast<T_ACC>(eps_));
155155
}
156-
*(reinterpret_cast<vec_t*>(mean_ + g_start)) = mean_vec;
157-
*(reinterpret_cast<vec_t*>(rstd_ + g_start)) = rstd_vec;
156+
*(reinterpret_cast<vec_t*>(mean_ + group_start)) = mean_vec;
157+
*(reinterpret_cast<vec_t*>(rstd_ + group_start)) = rstd_vec;
158158
}
159159
}
160160

@@ -934,6 +934,91 @@ struct ComputeInternalGradientsFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
934934
sycl_local_acc_t<T_ACC> db_shared_;
935935
};
936936

937+
template <typename T, int SIMD, int VEC_SIZE>
938+
struct ComputeInternalGradientsVectorizedFunctor
939+
: public __SYCL_KER_CONFIG_CONVENTION__ {
940+
using T_ACC = acc_type_device<T, kXPU>;
941+
using vec_t = memory::aligned_vector<T, VEC_SIZE>;
942+
using acc_vec_t = memory::aligned_vector<T_ACC, VEC_SIZE>;
943+
944+
[[intel::reqd_sub_group_size(SIMD)]] void operator()(
945+
sycl::nd_item<1> item) const {
946+
acc_vec_t sum1_vec;
947+
acc_vec_t sum2_vec;
948+
949+
#pragma unroll
950+
for (int v = 0; v < VEC_SIZE; ++v) {
951+
sum1_vec[v] = 0;
952+
sum2_vec[v] = 0;
953+
}
954+
955+
auto group_start = item.get_group(0) * VEC_SIZE;
956+
957+
#pragma unroll
958+
for (int v = 0; v < VEC_SIZE; ++v) {
959+
const int64_t nc = group_start + v;
960+
for (int64_t hw = item.get_local_id(0) * VEC_SIZE; hw < HxW_;
961+
hw += item.get_local_range(0) * VEC_SIZE) {
962+
const int64_t vec_index = nc * HxW_ + hw;
963+
vec_t vec_dY_ =
964+
*reinterpret_cast<vec_t*>(const_cast<T*>(dY_) + vec_index);
965+
vec_t vec_X_ =
966+
*reinterpret_cast<vec_t*>(const_cast<T*>(X_) + vec_index);
967+
968+
#pragma unroll
969+
for (int iv = 0; iv < VEC_SIZE; ++iv) {
970+
sum1_vec[v] += static_cast<T_ACC>(vec_dY_[iv] * vec_X_[iv]);
971+
sum2_vec[v] += static_cast<T_ACC>(vec_dY_[iv]);
972+
}
973+
}
974+
}
975+
976+
#pragma unroll
977+
for (int v = 0; v < VEC_SIZE; ++v) {
978+
sum1_vec[v] = GroupReduceSumWithoutBroadcast<T_ACC, SIMD>(
979+
item, sum1_vec[v], ds_shared_);
980+
sum2_vec[v] = GroupReduceSumWithoutBroadcast<T_ACC, SIMD>(
981+
item, sum2_vec[v], db_shared_);
982+
}
983+
984+
if (item.get_local_id(0) == 0) {
985+
acc_vec_t ds_vec;
986+
acc_vec_t db_vec;
987+
#pragma unroll
988+
for (int v = 0; v < VEC_SIZE; ++v) {
989+
ds_vec[v] = sum1_vec[v];
990+
db_vec[v] = sum2_vec[v];
991+
}
992+
*(reinterpret_cast<acc_vec_t*>(ds_ + group_start)) = ds_vec;
993+
*(reinterpret_cast<acc_vec_t*>(db_ + group_start)) = db_vec;
994+
}
995+
}
996+
997+
void sycl_ker_config_convention(sycl::handler& cgh) {
998+
ds_shared_ =
999+
sycl_local_acc_t<T_ACC>(get_group_reduce_group_size(SIMD), cgh);
1000+
db_shared_ =
1001+
sycl_local_acc_t<T_ACC>(get_group_reduce_group_size(SIMD), cgh);
1002+
}
1003+
1004+
ComputeInternalGradientsVectorizedFunctor(
1005+
int64_t HxW,
1006+
const T* dY,
1007+
const T* X,
1008+
T_ACC* ds,
1009+
T_ACC* db)
1010+
: HxW_(HxW), dY_(dY), X_(X), ds_(ds), db_(db) {}
1011+
1012+
private:
1013+
int64_t HxW_;
1014+
const T* dY_;
1015+
const T* X_;
1016+
T_ACC* ds_;
1017+
T_ACC* db_;
1018+
sycl_local_acc_t<T_ACC> ds_shared_;
1019+
sycl_local_acc_t<T_ACC> db_shared_;
1020+
};
1021+
9371022
template <typename T, typename T_ACC>
9381023
struct GroupNormBackwardC1Functor {
9391024
T_ACC operator()(T rstd, T gamma) const {
@@ -1272,23 +1357,50 @@ void group_norm_backward_kernel_impl(
12721357
}
12731358

12741359
auto& queue = getCurrentSYCLQueue();
1275-
12761360
int64_t simd = syclMaxSubGroupSize();
1277-
int64_t wg_size = HxW < get_group_reduce_group_size(simd)
1278-
? simd
1279-
: get_group_reduce_group_size(simd);
1280-
group_norm_kernel_simd_choice_and_launch<
1281-
ComputeInternalGradientsFunctor<T, SIMD16>,
1282-
ComputeInternalGradientsFunctor<T, SIMD32>>(
1283-
simd,
1284-
sycl::range<1>(N * C * wg_size),
1285-
sycl::range<1>(wg_size),
1286-
queue,
1287-
HxW,
1288-
dY_data,
1289-
X_data,
1290-
ds_data,
1291-
db_data);
1361+
1362+
constexpr int VEC_SIZE = PREFERRED_VEC_SIZE;
1363+
int64_t wg_size = 0;
1364+
1365+
if (can_use_vectorization(dY_data, VEC_SIZE) &&
1366+
can_use_vectorization(X_data, VEC_SIZE) &&
1367+
can_use_vectorization(ds_data, VEC_SIZE) &&
1368+
can_use_vectorization(db_data, VEC_SIZE) && HxW % VEC_SIZE == 0 &&
1369+
(N * C) % VEC_SIZE == 0) {
1370+
using KernelS16T =
1371+
ComputeInternalGradientsVectorizedFunctor<T, SIMD16, VEC_SIZE>;
1372+
using KernelS32T =
1373+
ComputeInternalGradientsVectorizedFunctor<T, SIMD32, VEC_SIZE>;
1374+
wg_size = (HxW / VEC_SIZE) < get_group_reduce_group_size(simd)
1375+
? simd
1376+
: get_group_reduce_group_size(simd);
1377+
group_norm_kernel_simd_choice_and_launch<KernelS16T, KernelS32T>(
1378+
simd,
1379+
sycl::range<1>((N * C / VEC_SIZE) * wg_size),
1380+
sycl::range<1>(wg_size),
1381+
queue,
1382+
HxW,
1383+
dY_data,
1384+
X_data,
1385+
ds_data,
1386+
db_data);
1387+
} else {
1388+
using KernelS16T = ComputeInternalGradientsFunctor<T, SIMD16>;
1389+
using KernelS32T = ComputeInternalGradientsFunctor<T, SIMD32>;
1390+
wg_size = HxW < get_group_reduce_group_size(simd)
1391+
? simd
1392+
: get_group_reduce_group_size(simd);
1393+
group_norm_kernel_simd_choice_and_launch<KernelS16T, KernelS32T>(
1394+
simd,
1395+
sycl::range<1>(N * C * wg_size),
1396+
sycl::range<1>(wg_size),
1397+
queue,
1398+
HxW,
1399+
dY_data,
1400+
X_data,
1401+
ds_data,
1402+
db_data);
1403+
}
12921404

12931405
if (dX.defined()) {
12941406
Tensor c1 = at::empty({0}, X.options().dtype(kAccType));
@@ -1373,8 +1485,8 @@ void group_norm_backward_kernel_impl(
13731485
sycl_kernel_submit(sycl::range<1>(C), queue, caller);
13741486
} else {
13751487
// The algorithm for colwise reduction here is to accumulate each
1376-
// (subgroup_size) cols to a (subgroup_size^2) tile and write the tile to
1377-
// shared memory. Then do subgroup reduce for each col in the tile.
1488+
// (subgroup_size) cols to a (subgroup_size^2) tile and write the tile
1489+
// to shared memory. Then do subgroup reduce for each col in the tile.
13781490
const int64_t kReduceTileSize = simd;
13791491
const int64_t B = (C + kReduceTileSize - 1) / kReduceTileSize;
13801492
auto global_range =

src/ATen/native/xpu/sycl/GroupReduceUtils.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ inline T& GroupReduceSumWithoutBroadcast(
4949
T& val,
5050
shared_t shared) {
5151
auto sg = item.get_sub_group();
52+
int g_tid = item.get_local_linear_id();
5253
int sg_tid = sg.get_local_linear_id();
5354
int sg_id = sg.get_group_linear_id();
5455
int n_sg = get_local_linear_range<DIM>(item) / SIMD;
@@ -62,10 +63,12 @@ inline T& GroupReduceSumWithoutBroadcast(
6263
shared[sg_id] = val;
6364
}
6465
item.barrier(sycl_local_fence);
66+
val = 0;
6567
if (sg_id == 0) {
66-
for (int i = 1; i < n_sg; i++) {
68+
for (int i = sg_tid; i < n_sg; i += SIMD) {
6769
val += shared[i];
6870
}
71+
val = SubgroupReduceSumWithoutBroadcast<T, SIMD, DIM>(item, val);
6972
}
7073
return val;
7174
}

0 commit comments

Comments
 (0)