@@ -168,13 +168,53 @@ void ShiftPointerIoMgr::init_io() {
168
168
}
169
169
}
170
170
171
- void ShiftPointerIoMgr::reset_io () {
171
+ void ShiftPointerIoMgr::reset_io (
172
+ const std::vector<executorch::runtime::Result<
173
+ executorch::runtime::MethodMeta>>& prefill_methods_meta,
174
+ const std::vector<
175
+ executorch::runtime::Result<executorch::runtime::MethodMeta>>&
176
+ kv_methods_meta) {
172
177
IO* ptr = static_cast <IO*>(data_ptr_.get ());
178
+ std::fill (ptr->prefill_input_pos .begin (), ptr->prefill_input_pos .end (), 0 );
179
+ ptr->kv_input_pos = 0 ;
173
180
std::fill (
174
181
ptr->prefill_attention_mask .begin (),
175
182
ptr->prefill_attention_mask .end (),
176
183
0 );
177
184
std::fill (ptr->kv_attention_mask .begin (), ptr->kv_attention_mask .end (), 0 );
185
+
186
+ input_tensors_[kv_forward_name_].clear ();
187
+ input_tensors_[kv_forward_name_].resize (modules_.size ());
188
+ output_tensors_[kv_forward_name_].clear ();
189
+ output_tensors_[kv_forward_name_].resize (modules_.size ());
190
+
191
+ k_cache_in_[kv_forward_name_].clear ();
192
+ v_cache_in_[kv_forward_name_].clear ();
193
+ k_cache_out_[kv_forward_name_].clear ();
194
+ v_cache_out_[kv_forward_name_].clear ();
195
+
196
+ input_tensors_[prefill_forward_name_].clear ();
197
+ input_tensors_[prefill_forward_name_].resize (modules_.size ());
198
+ output_tensors_[prefill_forward_name_].clear ();
199
+ output_tensors_[prefill_forward_name_].resize (modules_.size ());
200
+
201
+ k_cache_in_[prefill_forward_name_].clear ();
202
+ v_cache_in_[prefill_forward_name_].clear ();
203
+ k_cache_out_[prefill_forward_name_].clear ();
204
+ v_cache_out_[prefill_forward_name_].clear ();
205
+
206
+ switch (eval_mode_) {
207
+ case EvalMode::kKVCached :
208
+ prepare_kv_io (kv_methods_meta);
209
+ break ;
210
+ case EvalMode::kHybrid :
211
+ prepare_prefill_io (prefill_methods_meta);
212
+ prepare_kv_io (kv_methods_meta);
213
+ break ;
214
+ default :
215
+ ET_CHECK_MSG (false , " unsupported mode" );
216
+ break ;
217
+ }
178
218
}
179
219
void ShiftPointerIoMgr::prepare_kv_io (
180
220
const std::vector<Result<MethodMeta>>& methods_meta) {
@@ -893,7 +933,12 @@ void SmartMaskIoMgr::init_io() {
893
933
ptr->init_io_ptrs (shared_ptr, io_bytes_map);
894
934
}
895
935
896
- void SmartMaskIoMgr::reset_io () {
936
+ void SmartMaskIoMgr::reset_io (
937
+ const std::vector<executorch::runtime::Result<
938
+ executorch::runtime::MethodMeta>>& prefill_methods_meta,
939
+ const std::vector<
940
+ executorch::runtime::Result<executorch::runtime::MethodMeta>>&
941
+ kv_methods_meta) {
897
942
IO* ptr = static_cast <IO*>(data_ptr_.get ());
898
943
int32_t prefill_attn_size = prefill_ar_len_ * context_len_;
899
944
int32_t kv_attn_size = kv_ar_len_ * context_len_;
0 commit comments