Skip to content

Commit da7e8ac

Browse files
ddchenhao66 and chang-wenbin
authored and committed
[XPU] fix thinking bug where output only contains reasoning_content (PaddlePaddle#4761)
Co-authored-by: ddchenhao66 <dhaochen@163.com>
1 parent 4400571 commit da7e8ac

File tree

8 files changed

+201
-21
lines changed

8 files changed

+201
-21
lines changed

custom_ops/xpu_ops/src/ops/limit_thinking_content_length_v1.cc

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,27 +25,38 @@ void LimitThinkingContentLengthV1(const paddle::Tensor& next_tokens,
2525
const paddle::Tensor& max_think_lens,
2626
const paddle::Tensor& step_idx,
2727
const paddle::Tensor& limit_think_status,
28+
const paddle::Tensor& stop_flags,
29+
const paddle::Tensor& eos_token_ids,
2830
const int64_t think_end_id) {
2931
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
3032
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
3133
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
3234

3335
const int batch_size = next_tokens.shape()[0];
36+
const int eos_token_id_len = eos_token_ids.shape()[0];
3437
int r = baidu::xpu::api::plugin::limit_thinking_content_length_kernel_v1(
3538
xpu_ctx->x_context(),
3639
const_cast<int64_t*>(next_tokens.data<int64_t>()),
3740
max_think_lens.data<int>(),
3841
step_idx.data<int64_t>(),
42+
eos_token_ids.data<int64_t>(),
3943
const_cast<int*>(limit_think_status.data<int>()),
44+
const_cast<bool*>(stop_flags.data<bool>()),
4045
think_end_id,
41-
batch_size);
46+
batch_size,
47+
eos_token_id_len);
4248
PD_CHECK(r == 0,
4349
"baidu::xpu::api::plugin::limit_thinking_content_length_kernel_v1 "
4450
"failed.");
4551
}
4652

4753
PD_BUILD_STATIC_OP(limit_thinking_content_length_v1)
48-
.Inputs({"next_tokens", "max_think_lens", "step_idx", "limit_think_status"})
54+
.Inputs({"next_tokens",
55+
"max_think_lens",
56+
"step_idx",
57+
"limit_think_status",
58+
"stop_flags",
59+
"eos_token_ids"})
4960
.Attrs({"think_end_id: int64_t"})
5061
.Outputs({"next_tokens_out"})
5162
.SetInplaceMap({{"next_tokens", "next_tokens_out"}})

custom_ops/xpu_ops/src/ops/limit_thinking_content_length_v2.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ void LimitThinkingContentLengthV2(const paddle::Tensor& next_tokens,
2525
const paddle::Tensor& max_think_lens,
2626
const paddle::Tensor& step_idx,
2727
const paddle::Tensor& limit_think_status,
28+
const paddle::Tensor& stop_flags,
2829
const int64_t think_end_id,
2930
const int64_t line_break_id) {
3031
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
@@ -38,6 +39,7 @@ void LimitThinkingContentLengthV2(const paddle::Tensor& next_tokens,
3839
max_think_lens.data<int>(),
3940
step_idx.data<int64_t>(),
4041
const_cast<int*>(limit_think_status.data<int>()),
42+
stop_flags.data<bool>(),
4143
think_end_id,
4244
line_break_id,
4345
batch_size);
@@ -47,7 +49,11 @@ void LimitThinkingContentLengthV2(const paddle::Tensor& next_tokens,
4749
}
4850

4951
PD_BUILD_STATIC_OP(limit_thinking_content_length_v2)
50-
.Inputs({"next_tokens", "max_think_lens", "step_idx", "limit_think_status"})
52+
.Inputs({"next_tokens",
53+
"max_think_lens",
54+
"step_idx",
55+
"limit_think_status",
56+
"stop_flags"})
5157
.Attrs({"think_end_id: int64_t", "line_break_id: int64_t"})
5258
.Outputs({"next_tokens_out"})
5359
.SetInplaceMap({{"next_tokens", "next_tokens_out"}})

custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,16 +220,20 @@ DLL_EXPORT int limit_thinking_content_length_kernel_v1(
220220
int64_t* next_tokens,
221221
const int* max_think_lens,
222222
const int64_t* step_idx,
223+
const int64_t* eos_token_ids,
223224
int* limit_think_status,
225+
bool* stop_flags,
224226
const int64_t think_end_id,
225-
const int bs);
227+
const int bs,
228+
const int eos_token_id_len);
226229

227230
DLL_EXPORT int limit_thinking_content_length_kernel_v2(
228231
api::Context* ctx,
229232
int64_t* next_tokens,
230233
const int* max_think_lens,
231234
const int64_t* step_idx,
232235
int* limit_think_status,
236+
const bool* stop_flags,
233237
const int64_t think_end_id,
234238
const int64_t line_break_id,
235239
const int bs);

custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/limit_thinking_content_length_v1.xpu

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,43 +10,68 @@
1010
namespace xpu3 {
1111
namespace plugin {
1212

13+
template <typename T>
14+
static inline __device__ bool is_in_end(const T id,
15+
const T* end_ids,
16+
const int length) {
17+
for (int i = 0; i < length; i++) {
18+
if (id == end_ids[i]) {
19+
return true;
20+
}
21+
}
22+
return false;
23+
}
24+
1325
__global__ void limit_thinking_content_length_kernel_v1(
1426
int64_t* next_tokens,
1527
const int* max_think_lens,
1628
const int64_t* step_idx,
29+
const int64_t* eos_token_ids,
1730
int* limit_think_status,
31+
bool* stop_flags,
1832
const int64_t think_end_id,
19-
const int bs) {
33+
const int bs,
34+
const int eos_token_id_len) {
2035
int cid = core_id();
2136
int ncores = core_num();
2237
int clusterid = cluster_id();
2338
int nclusters = cluster_num();
2439
if (clusterid != 0) return;
40+
__simd__ __local__ int64_t eos_token_ids_lm[256];
2541

2642
for (int i = cid; i < bs; i += ncores) {
2743
int max_think_len_lm;
2844
int limit_think_status_lm;
2945
int64_t next_token_lm;
3046
int64_t step_idx_lm;
47+
bool stop_flags_lm;
3148
GM2LM_ASYNC(next_tokens + i, &next_token_lm, sizeof(int64_t));
3249
GM2LM_ASYNC(step_idx + i, &step_idx_lm, sizeof(int64_t));
3350
GM2LM_ASYNC(max_think_lens + i, &max_think_len_lm, sizeof(int));
51+
GM2LM_ASYNC(stop_flags + i, &stop_flags_lm, sizeof(bool));
52+
GM2LM_ASYNC(
53+
eos_token_ids, eos_token_ids_lm, sizeof(int64_t) * eos_token_id_len);
3454
GM2LM(limit_think_status + i, &limit_think_status_lm, sizeof(int));
3555

3656
// 如果该序列未启用思考功能,则直接返回,默认值为 -1,表示不限制思考长度
3757
if (max_think_len_lm < 0) continue;
3858
// 如果在回复阶段, 且已经触发停止标志, 则直接返回, 无需多余执行.
39-
if (limit_think_status_lm == 2) continue;
59+
if (limit_think_status_lm == 2 && stop_flags_lm) continue;
4060

4161
// ======================= 思考阶段控制 =======================
4262
// 阶段 1: 仍在思考 (status == 0), 检查是否需要强制结束
4363
if (limit_think_status_lm < 1) {
4464
// 当开启思考长度控制时,检查是否超时
45-
if (step_idx_lm >= max_think_len_lm) {
65+
if ((step_idx_lm >= max_think_len_lm) ||
66+
is_in_end(next_token_lm, eos_token_ids_lm, eos_token_id_len)) {
4667
// 强制将当前token替换为结束思考的token
4768
next_token_lm = think_end_id;
4869
// 将状态推进到 1, 表示 "正在结束思考"
4970
limit_think_status_lm = 1;
71+
if (stop_flags_lm) {
72+
stop_flags_lm = false;
73+
LM2GM(&stop_flags_lm, stop_flags + i, sizeof(bool));
74+
}
5075
}
5176
}
5277

custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/limit_thinking_content_length_v2.xpu

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ __global__ void limit_thinking_content_length_kernel_v2(
1515
const int* max_think_lens,
1616
const int64_t* step_idx,
1717
int* limit_think_status,
18+
const bool* stop_flags,
1819
const int64_t think_end_id,
1920
const int64_t line_break_id,
2021
const int bs) {
@@ -29,15 +30,17 @@ __global__ void limit_thinking_content_length_kernel_v2(
2930
int limit_think_status_lm;
3031
int64_t next_token_lm;
3132
int64_t step_idx_lm;
33+
bool stop_flags_lm;
3234
GM2LM_ASYNC(next_tokens + i, &next_token_lm, sizeof(int64_t));
3335
GM2LM_ASYNC(step_idx + i, &step_idx_lm, sizeof(int64_t));
36+
GM2LM_ASYNC(stop_flags + i, &stop_flags_lm, sizeof(bool));
3437
GM2LM_ASYNC(max_think_lens + i, &max_think_len_lm, sizeof(int));
3538
GM2LM(limit_think_status + i, &limit_think_status_lm, sizeof(int));
3639

3740
// 如果该序列未启用思考功能,则直接返回,默认值为 -1,表示不限制思考长度
3841
if (max_think_len_lm < 0) continue;
3942
// 如果在回复阶段, 且已经触发停止标志, 则直接返回, 无需多余执行.
40-
if (limit_think_status_lm == 3) continue;
43+
if (limit_think_status_lm == 3 && stop_flags_lm) continue;
4144

4245
// ======================= 思考阶段控制 =======================
4346
// 阶段 1: 仍在思考 (status == 0), 检查是否需要强制结束

custom_ops/xpu_ops/src/plugin/src/wrapper/limit_thinking_content_length_v1.cpp

Lines changed: 75 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,12 @@ __attribute__((global)) void limit_thinking_content_length_kernel_v1(
2424
int64_t* next_tokens,
2525
const int* max_think_lens,
2626
const int64_t* step_idx,
27+
const int64_t* eos_token_ids,
2728
int* limit_think_status,
29+
bool* stop_flags,
2830
const int64_t think_end_id,
29-
const int bs);
30-
31+
const int bs,
32+
const int eos_token_id_len);
3133
} // namespace plugin
3234
} // namespace xpu3
3335

@@ -36,55 +38,118 @@ namespace xpu {
3638
namespace api {
3739
namespace plugin {
3840

41+
static int cpu_wrapper(Context* ctx,
42+
int64_t* next_tokens,
43+
const int* max_think_lens,
44+
const int64_t* step_idx,
45+
const int64_t* eos_token_ids,
46+
int* limit_think_status,
47+
bool* stop_flags,
48+
const int64_t think_end_id,
49+
const int bs,
50+
const int eos_token_id_len) {
51+
auto is_in_end = [](int64_t token_id, const int64_t* end_ids, int length) {
52+
for (int i = 0; i < length; i++) {
53+
if (token_id == end_ids[i]) {
54+
return true;
55+
}
56+
}
57+
return false;
58+
};
59+
for (int bid = 0; bid < bs; bid++) {
60+
const int max_think_len = max_think_lens[bid];
61+
if (max_think_len < 0) continue;
62+
int current_limit_think_status = limit_think_status[bid];
63+
if (limit_think_status[bid] == 2 && stop_flags[bid]) continue;
64+
int64_t next_token = next_tokens[bid];
65+
const int64_t step = step_idx[bid];
66+
if (current_limit_think_status < 1) {
67+
if (step >= max_think_len ||
68+
is_in_end(next_token, eos_token_ids, eos_token_id_len)) {
69+
next_token = think_end_id;
70+
current_limit_think_status = 1;
71+
}
72+
}
73+
if (current_limit_think_status < 2) {
74+
if (next_token == think_end_id) {
75+
current_limit_think_status = 2;
76+
}
77+
}
78+
next_tokens[bid] = next_token;
79+
limit_think_status[bid] = current_limit_think_status;
80+
}
81+
return api::SUCCESS;
82+
}
3983
static int xpu3_wrapper(Context* ctx,
4084
int64_t* next_tokens,
4185
const int* max_think_lens,
4286
const int64_t* step_idx,
87+
const int64_t* eos_token_ids,
4388
int* limit_think_status,
89+
bool* stop_flags,
4490
const int64_t think_end_id,
45-
const int bs) {
91+
const int bs,
92+
const int eos_token_id_len) {
4693
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
4794
auto limit_thinking_content_length_kernel_v1 =
4895
xpu3::plugin::limit_thinking_content_length_kernel_v1;
4996
limit_thinking_content_length_kernel_v1<<<1, 64, ctx->xpu_stream>>>(
5097
reinterpret_cast<XPU_INT64*>(next_tokens),
5198
max_think_lens,
5299
reinterpret_cast<const XPU_INT64*>(step_idx),
100+
reinterpret_cast<const XPU_INT64*>(eos_token_ids),
53101
limit_think_status,
102+
stop_flags,
54103
think_end_id,
55-
bs);
104+
bs,
105+
eos_token_id_len);
56106
return api::SUCCESS;
57107
}
58108

59109
int limit_thinking_content_length_kernel_v1(Context* ctx,
60110
int64_t* next_tokens,
61111
const int* max_think_lens,
62112
const int64_t* step_idx,
113+
const int64_t* eos_token_ids,
63114
int* limit_think_status,
115+
bool* stop_flags,
64116
const int64_t think_end_id,
65-
const int bs) {
117+
const int bs,
118+
const int eos_token_id_len) {
66119
WRAPPER_CHECK_CTX(ctx);
67120
WRAPPER_DUMP_FUNCTION_T1(ctx, "limit_thinking_content_length_kernel_v1", int);
68121
WRAPPER_DUMP_PARAM5(ctx,
69122
next_tokens,
70123
max_think_lens,
71124
step_idx,
72-
limit_think_status,
73-
think_end_id);
74-
WRAPPER_DUMP_PARAM1(ctx, bs);
125+
eos_token_ids,
126+
limit_think_status);
127+
WRAPPER_DUMP_PARAM4(ctx, stop_flags, think_end_id, bs, eos_token_id_len);
75128

76129
WRAPPER_DUMP(ctx);
77130
if (ctx->dev().type() == api::kCPU) {
78-
assert(false);
131+
return cpu_wrapper(ctx,
132+
next_tokens,
133+
max_think_lens,
134+
step_idx,
135+
eos_token_ids,
136+
limit_think_status,
137+
stop_flags,
138+
think_end_id,
139+
bs,
140+
eos_token_id_len);
79141
}
80142
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
81143
return xpu3_wrapper(ctx,
82144
next_tokens,
83145
max_think_lens,
84146
step_idx,
147+
eos_token_ids,
85148
limit_think_status,
149+
stop_flags,
86150
think_end_id,
87-
bs);
151+
bs,
152+
eos_token_id_len);
88153
}
89154
WRAPPER_UNIMPLEMENTED(ctx);
90155
}

0 commit comments

Comments
 (0)