diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py index 431cc6dc2c4..229c2b37d8d 100644 --- a/examples/models/llama2/llama_transformer.py +++ b/examples/models/llama2/llama_transformer.py @@ -191,10 +191,12 @@ def update( narrowed_v.copy_(v_val) return self.k_cache, self.v_cache else: - k_out = self.k_cache - v_out = self.v_cache - k_out[:, :, input_pos] = k_val - v_out[:, :, input_pos] = v_val + k_out = torch.ops.aten.index_put_( + self.k_cache, [None, None, input_pos], k_val + ) + v_out = torch.ops.aten.index_put_( + self.v_cache, [None, None, input_pos], v_val + ) return k_out, v_out