From 65747880750e8020b673369e47bdd80e212123e1 Mon Sep 17 00:00:00 2001 From: Chen Lai Date: Thu, 18 Jul 2024 15:07:01 -0700 Subject: [PATCH] use index_put only in kv cache update to reduce number of operators (#3786) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/3786 The decomposition from ``` class IndexPut(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x, input_pos, value): x[:, :, input_pos] = value return x ``` is ``` opcode name target args kwargs ------------- --------------- -------------------------- ----------------------------------------------- -------- placeholder x x () {} placeholder input_pos input_pos () {} placeholder value value () {} call_function slice_1 aten.slice.Tensor (x, 0, 0, 9223372036854775807) {} call_function slice_2 aten.slice.Tensor (slice_1, 1, 0, 9223372036854775807) {} call_function index_put aten.index_put.default (slice_2, [None, None, input_pos], value) {} call_function slice_3 aten.slice.Tensor (x, 0, 0, 9223372036854775807) {} call_function slice_scatter aten.slice_scatter.default (slice_3, index_put, 1, 0, 9223372036854775807) {} call_function slice_scatter_1 aten.slice_scatter.default (x, slice_scatter, 0, 0, 9223372036854775807) {} output output output ((slice_scatter_1, slice_scatter_1),) {} ``` however `x[:, :, input_pos] = value` really is just updating the content inside `x` with value, essentially just `index_put` By replacing `x[:, :, input_pos] = value` with `torch.ops.aten.index_put_(x, [None, None, input_pos], value)`, we reduce the number of operators from 6 to 1. ``` class IndexPut(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x, indices, values): torch.ops.aten.index_put_(x, [None, None, input_pos], value) return x ``` decomposition is ``` opcode name target args kwargs ------------- --------- ---------------------- ----------------------------------- -------- placeholder x x () {} placeholder input_pos input_pos () {} placeholder value value () {} call_function index_put aten.index_put.default (x, [None, None, input_pos], value) {} output output output ((index_put, index_put),) {} ``` A more proper way to address this in long term is via pattern matching to replace the patterns with the simplified pattern Perf: For stories, before the diff ``` I 00:00:03.437290 executorch:runner.cpp:419] Prompt Tokens: 9 Generated Tokens: 118 I 00:00:03.437295 executorch:runner.cpp:425] Model Load Time: 0.763000 (seconds) I 00:00:03.437301 executorch:runner.cpp:435] Total inference time: 2.661000 (seconds) Rate: 44.344231 (tokens/second) I 00:00:03.437305 executorch:runner.cpp:443] Prompt evaluation: 0.185000 (seconds) Rate: 48.648649 (tokens/second) I 00:00:03.437309 executorch:runner.cpp:454] Generated 118 tokens: 2.476000 (seconds) Rate: 47.657512 (tokens/second) I 00:00:03.437313 executorch:runner.cpp:462] Time to first generated token: 0.206000 (seconds) I 00:00:03.437315 executorch:runner.cpp:469] Sampling time over 127 tokens: 0.042000 (seconds) ``` After the diff ``` I 00:00:03.195257 executorch:runner.cpp:419] Prompt Tokens: 9 Generated Tokens: 118 I 00:00:03.195295 executorch:runner.cpp:425] Model Load Time: 0.683000 (seconds) I 00:00:03.195314 executorch:runner.cpp:435] Total inference time: 2.502000 (seconds) Rate: 47.162270 (tokens/second) I 00:00:03.195319 executorch:runner.cpp:443] Prompt evaluation: 0.175000 (seconds) Rate: 51.428571 (tokens/second) I 00:00:03.195323 executorch:runner.cpp:454] Generated 118 tokens: 2.327000 (seconds) Rate: 50.709067 (tokens/second) I 00:00:03.195327 executorch:runner.cpp:462] Time to first generated token: 0.195000 (seconds) I 00:00:03.195330 executorch:runner.cpp:469] Sampling time over 127 tokens: 0.049000 (seconds) ``` Differential Revision: D57949659 --- examples/models/llama2/llama_transformer.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py index 431cc6dc2c4..229c2b37d8d 100644 --- a/examples/models/llama2/llama_transformer.py +++ b/examples/models/llama2/llama_transformer.py @@ -191,10 +191,15 @@ def update( narrowed_v.copy_(v_val) return self.k_cache, self.v_cache else: - k_out = self.k_cache - v_out = self.v_cache - k_out[:, :, input_pos] = k_val - v_out[:, :, input_pos] = v_val + k_out = torch.ops.aten.index_put_( + self.k_cache, [None, None, input_pos], k_val + ) + v_out = torch.ops.aten.index_put_( + self.v_cache, [None, None, input_pos], v_val + ) + v_out = torch.ops.aten.index_put_( + self.v_cache, [None, None, input_pos], v_val + ) return k_out, v_out