Skip to content

Commit 8603b98

Browse files
ai-edge-bot and copybara-github
authored and committed
Fix issue of outputting stop token in Decode/DecodeStreaming
LiteRT-LM-PiperOrigin-RevId: 789000868
1 parent 0691ed3 commit 8603b98

3 files changed

Lines changed: 27 additions & 13 deletions

File tree

runtime/core/pipeline.cc

Lines changed: 12 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -341,9 +341,11 @@ absl::StatusOr<Responses> Decode(LlmExecutor& executor, Tokenizer& tokenizer,
341341
}
342342

343343
previous_token_ids.resize(0);
344-
response_texts[0] +=
345-
absl::StrReplaceAll(run_one_step.GetResultTokens()[0], {{"", " "}});
346344
num_decoded_steps++;
345+
if (decode_result == kPartial) {
346+
response_texts[0] +=
347+
absl::StrReplaceAll(run_one_step.GetResultTokens()[0], {{"", " "}});
348+
}
347349

348350
if (ShouldStop(decode_result == kDone, benchmark_decode_token_count,
349351
num_decoded_steps, executor.GetCurrentStep().value(),
@@ -398,10 +400,12 @@ absl::Status DecodeStreaming(LlmExecutor& executor, Tokenizer& tokenizer,
398400
}
399401

400402
previous_token_ids.resize(0);
401-
response_texts[0] +=
402-
absl::StrReplaceAll(run_one_step.GetResultTokens()[0], {{"", " "}});
403403
num_decoded_steps++;
404-
observer->OnNext(responses);
404+
if (decode_result == kPartial) {
405+
response_texts[0] +=
406+
absl::StrReplaceAll(run_one_step.GetResultTokens()[0], {{"", " "}});
407+
observer->OnNext(responses);
408+
}
405409

406410
if (ShouldStop(decode_result == kDone, benchmark_decode_token_count,
407411
num_decoded_steps, executor.GetCurrentStep().value(),
@@ -533,8 +537,10 @@ absl::Status DecodeCustomSamplingStreaming(
533537
scores[j] += run_one_step.GetScores()[j];
534538
}
535539
}
540+
if (*decode_result == kPartial) {
541+
observer->OnNext(responses);
542+
}
536543
num_decode_steps++;
537-
observer->OnNext(responses);
538544
if (ShouldStop(*decode_result == kDone, benchmark_decode_token_count,
539545
num_decode_steps, executor.GetCurrentStep().value(),
540546
max_num_tokens, observer)) {

runtime/core/pipeline_test.cc

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ class PipelineTest : public testing::Test {
6363
.string());
6464
ASSERT_OK(tokenizer);
6565
tokenizer_ = std::move(*tokenizer);
66+
6667
// The prefill tokens are the expected tokens that will be passed in at each
6768
// time the Prefill function is called. The values are the token ids of the
6869
// input prompt "Hello World!" prepended with the bos token id (2).
@@ -111,7 +112,9 @@ TEST_F(PipelineTest, Decode) {
111112
auto responses =
112113
Decode(*executor_, *tokenizer_, stop_token_detector, benchmark_info);
113114
EXPECT_OK(responses);
114-
EXPECT_EQ(*(responses->GetResponseTextAt(0)), " How's it going?!");
115+
// The response is " How's it going?" since "!" is the stop token which is
116+
// not included in the response.
117+
EXPECT_EQ(*(responses->GetResponseTextAt(0)), " How's it going?");
115118
}
116119

117120
TEST_F(PipelineTest, DecodeReachMaxNumTokens) {
@@ -134,7 +137,9 @@ TEST_F(PipelineTest, DecodeStreaming) {
134137
EXPECT_OK(stop_token_detector.AddStopTokenSequence({2294}));
135138
EXPECT_OK(DecodeStreaming(*executor_, *tokenizer_, stop_token_detector,
136139
benchmark_info, &observer));
137-
EXPECT_EQ(observer.GetResponses()[0], " How's it going?!");
140+
// The response is " How's it going?" since "!" is the stop token which is
141+
// not included in the response.
142+
EXPECT_EQ(observer.GetResponses()[0], " How's it going?");
138143
}
139144

140145
TEST_F(PipelineTest, DecodeStreamingReachMaxNumTokens) {
@@ -171,18 +176,19 @@ TEST_F(PipelineTest, DecodeBytePairEncodingTokens) {
171176
EXPECT_CALL(*tokenizer, TokenIdsToText(std::vector<int>{18}))
172177
.WillOnce(testing::Return(" "));
173178
EXPECT_CALL(*tokenizer, TokenIdsToText(std::vector<int>{2295}))
174-
.WillOnce(testing::Return("going"));
179+
.WillOnce(testing::Return("going?"));
175180
EXPECT_CALL(*tokenizer, TokenIdsToText(std::vector<int>{2294}))
176-
.WillOnce(testing::Return("?!"));
181+
.WillOnce(testing::Return("!"));
177182

178183
std::optional<BenchmarkInfo> benchmark_info;
179184
StopTokenDetector stop_token_detector(1);
180185
EXPECT_OK(stop_token_detector.AddStopTokenSequence({2294}));
181186
auto responses =
182187
Decode(*executor_, *tokenizer, stop_token_detector, benchmark_info);
183188
EXPECT_OK(responses);
184-
// The response is truncated at the max number of tokens.
185-
EXPECT_EQ(*(responses->GetResponseTextAt(0)), " How's it going?!");
189+
// The response is " How's it going?" since "!" is the stop token which is
190+
// not included in the response.
191+
EXPECT_EQ(*(responses->GetResponseTextAt(0)), " How's it going?");
186192
}
187193

188194
class PipelineCustomSamplingTest : public testing::Test {

runtime/core/session_basic_test.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,9 @@ TEST_F(SessionBasicTest, RunDecode) {
8888
auto responses = (*session)->RunDecode();
8989
EXPECT_OK(responses);
9090
EXPECT_EQ(responses->GetNumOutputCandidates(), 1);
91-
EXPECT_EQ(*(responses->GetResponseTextAt(0)), " How's it going?!");
91+
// The response is " How's it going?" since "!" is the stop token which is
92+
// not included in the response.
93+
EXPECT_EQ(*(responses->GetResponseTextAt(0)), " How's it going?");
9294
}
9395

9496
class TestObserver : public InferenceObservable {

0 commit comments

Comments
 (0)