diff --git a/examples/qualcomm/oss_scripts/llama/runner/io_manager.cpp b/examples/qualcomm/oss_scripts/llama/runner/io_manager.cpp index 6a3f4d24030..ce7baefa080 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/io_manager.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/io_manager.cpp @@ -168,13 +168,53 @@ void ShiftPointerIoMgr::init_io() { } } -void ShiftPointerIoMgr::reset_io() { +void ShiftPointerIoMgr::reset_io( + const std::vector>& prefill_methods_meta, + const std::vector< + executorch::runtime::Result>& + kv_methods_meta) { IO* ptr = static_cast(data_ptr_.get()); + std::fill(ptr->prefill_input_pos.begin(), ptr->prefill_input_pos.end(), 0); + ptr->kv_input_pos = 0; std::fill( ptr->prefill_attention_mask.begin(), ptr->prefill_attention_mask.end(), 0); std::fill(ptr->kv_attention_mask.begin(), ptr->kv_attention_mask.end(), 0); + + input_tensors_[kv_forward_name_].clear(); + input_tensors_[kv_forward_name_].resize(modules_.size()); + output_tensors_[kv_forward_name_].clear(); + output_tensors_[kv_forward_name_].resize(modules_.size()); + + k_cache_in_[kv_forward_name_].clear(); + v_cache_in_[kv_forward_name_].clear(); + k_cache_out_[kv_forward_name_].clear(); + v_cache_out_[kv_forward_name_].clear(); + + input_tensors_[prefill_forward_name_].clear(); + input_tensors_[prefill_forward_name_].resize(modules_.size()); + output_tensors_[prefill_forward_name_].clear(); + output_tensors_[prefill_forward_name_].resize(modules_.size()); + + k_cache_in_[prefill_forward_name_].clear(); + v_cache_in_[prefill_forward_name_].clear(); + k_cache_out_[prefill_forward_name_].clear(); + v_cache_out_[prefill_forward_name_].clear(); + + switch (eval_mode_) { + case EvalMode::kKVCached: + prepare_kv_io(kv_methods_meta); + break; + case EvalMode::kHybrid: + prepare_prefill_io(prefill_methods_meta); + prepare_kv_io(kv_methods_meta); + break; + default: + ET_CHECK_MSG(false, "unsupported mode"); + break; + } } void ShiftPointerIoMgr::prepare_kv_io( const std::vector>& methods_meta) { @@ -893,7 +933,12 @@ void SmartMaskIoMgr::init_io() { ptr->init_io_ptrs(shared_ptr, io_bytes_map); } -void SmartMaskIoMgr::reset_io() { +void SmartMaskIoMgr::reset_io( + const std::vector>& prefill_methods_meta, + const std::vector< + executorch::runtime::Result>& + kv_methods_meta) { IO* ptr = static_cast(data_ptr_.get()); int32_t prefill_attn_size = prefill_ar_len_ * context_len_; int32_t kv_attn_size = kv_ar_len_ * context_len_; diff --git a/examples/qualcomm/oss_scripts/llama/runner/io_manager.h b/examples/qualcomm/oss_scripts/llama/runner/io_manager.h index efb4c70acc7..03808ede3bf 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/io_manager.h +++ b/examples/qualcomm/oss_scripts/llama/runner/io_manager.h @@ -33,7 +33,12 @@ class IoMgrBase { std::vector>& modules); virtual ~IoMgrBase(); virtual void init_io() = 0; - virtual void reset_io() = 0; + virtual void reset_io( + const std::vector>& prefill_methods_meta, + const std::vector< + executorch::runtime::Result>& + kv_methods_meta) = 0; virtual void prepare_prefill_io( const std::vector< executorch::runtime::Result>& @@ -98,7 +103,12 @@ class ShiftPointerIoMgr : public IoMgrBase { const bool use_int64_token); void init_io() override; - void reset_io() override; + void reset_io( + const std::vector>& prefill_methods_meta, + const std::vector< + executorch::runtime::Result>& + kv_methods_meta) override; void prepare_prefill_io( const std::vector< executorch::runtime::Result>& @@ -201,7 +211,12 @@ class SmartMaskIoMgr : public IoMgrBase { const bool use_int64_token); void init_io() override; - void reset_io() override; + void reset_io( + const std::vector>& prefill_methods_meta, + const std::vector< + executorch::runtime::Result>& + kv_methods_meta) override; void prepare_prefill_io( const std::vector< executorch::runtime::Result>& diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp index 13742b9b2a0..db7ac468b5e 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp @@ -447,7 +447,9 @@ Error Runner::generate( if (stats_callback) { stats_callback(stats_); } - io_mgr_->reset_io(); + io_mgr_->reset_io( + get_methods_meta(prefill_forward_name_), + get_methods_meta(kv_forward_name_)); prompt_.clear(); return Error::Ok; }