Skip to content

Commit db86fbd

Browse files
committed
code review
1 parent c2adfa9 commit db86fbd

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ HybridMemory::HybridMemory(
6464
const std::string& kv_forward_name)
6565
: Memory(modules),
6666
shard_layers_({num_layers}),
67-
prefill_cache_len_(prefill_cache_len),
6867
kv_cache_len_(kv_cache_len),
68+
prefill_cache_len_(prefill_cache_len),
6969
vocab_size_(vocab_size),
7070
num_layers_(num_layers),
7171
head_dim_(head_dim),
@@ -332,7 +332,8 @@ void HybridMemory::prepare_prefill_io(
332332
input_tensors_[prefill_forward_name_][0].push_back(prefill_attn_mask_.get());
333333
// [O]: logits
334334
int logit_index = 0;
335-
Result<TensorInfo> logits = methods_meta[0]->output_tensor_meta(0);
335+
Result<TensorInfo> logits =
336+
methods_meta[modules_.size() - 1]->output_tensor_meta(logit_index);
336337
prefill_logits_ = std::make_unique<TensorImpl>(
337338
logits->scalar_type(),
338339
logits->sizes().size(),

0 commit comments

Comments
 (0)