Skip to content

Commit b163d44

Browse files
billmguofacebook-github-bot
authored andcommitted
Fix num_iters > 5 Shiftpointer issue (#9150)
Summary: Pull Request resolved: #9150 in the original implementation when num_iters > 5 it will crash during exit Differential Revision: D70974749
1 parent e763a83 commit b163d44

File tree

3 files changed

+68
-6
lines changed

3 files changed

+68
-6
lines changed

examples/qualcomm/oss_scripts/llama/runner/io_manager.cpp

+47-2
Original file line numberDiff line numberDiff line change
@@ -168,13 +168,53 @@ void ShiftPointerIoMgr::init_io() {
168168
}
169169
}
170170

171-
void ShiftPointerIoMgr::reset_io() {
171+
void ShiftPointerIoMgr::reset_io(
172+
const std::vector<executorch::runtime::Result<
173+
executorch::runtime::MethodMeta>>& prefill_methods_meta,
174+
const std::vector<
175+
executorch::runtime::Result<executorch::runtime::MethodMeta>>&
176+
kv_methods_meta) {
172177
IO* ptr = static_cast<IO*>(data_ptr_.get());
178+
std::fill(ptr->prefill_input_pos.begin(), ptr->prefill_input_pos.end(), 0);
179+
ptr->kv_input_pos = 0;
173180
std::fill(
174181
ptr->prefill_attention_mask.begin(),
175182
ptr->prefill_attention_mask.end(),
176183
0);
177184
std::fill(ptr->kv_attention_mask.begin(), ptr->kv_attention_mask.end(), 0);
185+
186+
input_tensors_[kv_forward_name_].clear();
187+
input_tensors_[kv_forward_name_].resize(modules_.size());
188+
output_tensors_[kv_forward_name_].clear();
189+
output_tensors_[kv_forward_name_].resize(modules_.size());
190+
191+
k_cache_in_[kv_forward_name_].clear();
192+
v_cache_in_[kv_forward_name_].clear();
193+
k_cache_out_[kv_forward_name_].clear();
194+
v_cache_out_[kv_forward_name_].clear();
195+
196+
input_tensors_[prefill_forward_name_].clear();
197+
input_tensors_[prefill_forward_name_].resize(modules_.size());
198+
output_tensors_[prefill_forward_name_].clear();
199+
output_tensors_[prefill_forward_name_].resize(modules_.size());
200+
201+
k_cache_in_[prefill_forward_name_].clear();
202+
v_cache_in_[prefill_forward_name_].clear();
203+
k_cache_out_[prefill_forward_name_].clear();
204+
v_cache_out_[prefill_forward_name_].clear();
205+
206+
switch (eval_mode_) {
207+
case EvalMode::kKVCached:
208+
prepare_kv_io(kv_methods_meta);
209+
break;
210+
case EvalMode::kHybrid:
211+
prepare_prefill_io(prefill_methods_meta);
212+
prepare_kv_io(kv_methods_meta);
213+
break;
214+
default:
215+
ET_CHECK_MSG(false, "unsupported mode");
216+
break;
217+
}
178218
}
179219
void ShiftPointerIoMgr::prepare_kv_io(
180220
const std::vector<Result<MethodMeta>>& methods_meta) {
@@ -893,7 +933,12 @@ void SmartMaskIoMgr::init_io() {
893933
ptr->init_io_ptrs(shared_ptr, io_bytes_map);
894934
}
895935

896-
void SmartMaskIoMgr::reset_io() {
936+
void SmartMaskIoMgr::reset_io(
937+
const std::vector<executorch::runtime::Result<
938+
executorch::runtime::MethodMeta>>& prefill_methods_meta,
939+
const std::vector<
940+
executorch::runtime::Result<executorch::runtime::MethodMeta>>&
941+
kv_methods_meta) {
897942
IO* ptr = static_cast<IO*>(data_ptr_.get());
898943
int32_t prefill_attn_size = prefill_ar_len_ * context_len_;
899944
int32_t kv_attn_size = kv_ar_len_ * context_len_;

examples/qualcomm/oss_scripts/llama/runner/io_manager.h

+18-3
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,12 @@ class IoMgrBase {
3333
std::vector<std::shared_ptr<executorch::extension::Module>>& modules);
3434
virtual ~IoMgrBase();
3535
virtual void init_io() = 0;
36-
virtual void reset_io() = 0;
36+
virtual void reset_io(
37+
const std::vector<executorch::runtime::Result<
38+
executorch::runtime::MethodMeta>>& prefill_methods_meta,
39+
const std::vector<
40+
executorch::runtime::Result<executorch::runtime::MethodMeta>>&
41+
kv_methods_meta) = 0;
3742
virtual void prepare_prefill_io(
3843
const std::vector<
3944
executorch::runtime::Result<executorch::runtime::MethodMeta>>&
@@ -98,7 +103,12 @@ class ShiftPointerIoMgr : public IoMgrBase {
98103
const bool use_int64_token);
99104

100105
void init_io() override;
101-
void reset_io() override;
106+
void reset_io(
107+
const std::vector<executorch::runtime::Result<
108+
executorch::runtime::MethodMeta>>& prefill_methods_meta,
109+
const std::vector<
110+
executorch::runtime::Result<executorch::runtime::MethodMeta>>&
111+
kv_methods_meta) override;
102112
void prepare_prefill_io(
103113
const std::vector<
104114
executorch::runtime::Result<executorch::runtime::MethodMeta>>&
@@ -201,7 +211,12 @@ class SmartMaskIoMgr : public IoMgrBase {
201211
const bool use_int64_token);
202212

203213
void init_io() override;
204-
void reset_io() override;
214+
void reset_io(
215+
const std::vector<executorch::runtime::Result<
216+
executorch::runtime::MethodMeta>>& prefill_methods_meta,
217+
const std::vector<
218+
executorch::runtime::Result<executorch::runtime::MethodMeta>>&
219+
kv_methods_meta) override;
205220
void prepare_prefill_io(
206221
const std::vector<
207222
executorch::runtime::Result<executorch::runtime::MethodMeta>>&

examples/qualcomm/oss_scripts/llama/runner/runner.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,9 @@ Error Runner::generate(
447447
if (stats_callback) {
448448
stats_callback(stats_);
449449
}
450-
io_mgr_->reset_io();
450+
io_mgr_->reset_io(
451+
get_methods_meta(prefill_forward_name_),
452+
get_methods_meta(kv_forward_name_));
451453
prompt_.clear();
452454
return Error::Ok;
453455
}

0 commit comments

Comments
 (0)