diff --git a/extension/llm/custom_ops/op_sdpa.cpp b/extension/llm/custom_ops/op_sdpa.cpp
index 7112489b769..1d0d64f12d5 100644
--- a/extension/llm/custom_ops/op_sdpa.cpp
+++ b/extension/llm/custom_ops/op_sdpa.cpp
@@ -273,7 +273,6 @@ Tensor& flash_attention_kernel_out(
   Format [n_layers, batch size, max_seq_len, num heads, head dim]
   ....
   @param[in] start_pos: sequence position
-  @param[in] seq_len: Seq length. e.g. seq_len dim of q_projected.
 */
 Tensor& custom_sdpa_out(
     RuntimeContext& ctx,
@@ -306,63 +305,7 @@ Tensor& custom_sdpa_out(
   const int64_t seq_len = q.size(1);
   auto q_seq_len = q.size(1);
 
-  // Refactor the following into create_view util perhaps using
-  // TensorPtr
-  std::array<::executorch::aten::DimOrderType, sdpa::impl::kKVDim>
-      sliced_key_dim_order{0, 1, 2, 3};
-  std::array<::executorch::aten::SizesType, sdpa::impl::kKVDim>
-      sliced_key_sizes;
-  sliced_key_sizes[0] = k.size(0);
-  sliced_key_sizes[1] = start_pos + seq_len; // key_cache.size(2);
-  sliced_key_sizes[2] = k.size(2);
-  sliced_key_sizes[3] = k.size(3);
-  std::array<::executorch::aten::StridesType, sdpa::impl::kKVDim>
-      sliced_key_strides;
-  dim_order_to_stride_nocheck(
-      sliced_key_sizes.data(),
-      sliced_key_dim_order.data(),
-      sdpa::impl::kKVDim,
-      sliced_key_strides.data());
-  // since the cache is sliced, the batch stride needs to stay the same.
-  sliced_key_strides[0] = k.strides()[0];
-  void* key_cache_data = k.mutable_data_ptr();
-  TensorImpl k_impl = TensorImpl(
-      k.scalar_type(),
-      sdpa::impl::kKVDim,
-      sliced_key_sizes.data(),
-      key_cache_data,
-      sliced_key_dim_order.data(),
-      sliced_key_strides.data(),
-      TensorShapeDynamism::STATIC);
-  Tensor sliced_key_cache(&k_impl);
-
-  std::array<::executorch::aten::DimOrderType, sdpa::impl::kKVDim>
-      sliced_value_dim_order{0, 1, 2, 3};
-  std::array<::executorch::aten::SizesType, sdpa::impl::kKVDim>
-      sliced_value_sizes;
-  sliced_value_sizes[0] = v.size(0);
-  sliced_value_sizes[1] = start_pos + seq_len; // value_cache.size(2);
-  sliced_value_sizes[2] = v.size(2);
-  sliced_value_sizes[3] = v.size(3);
-  std::array<::executorch::aten::StridesType, sdpa::impl::kKVDim>
-      sliced_value_strides;
-  dim_order_to_stride_nocheck(
-      sliced_value_sizes.data(),
-      sliced_value_dim_order.data(),
-      sdpa::impl::kKVDim,
-      sliced_value_strides.data());
-  // since the cache is sliced, the batch stride needs to stay the same.
-  sliced_value_strides[0] = v.strides()[0];
-  void* value_cache_data = v.mutable_data_ptr();
-  TensorImpl value_impl = TensorImpl(
-      v.scalar_type(),
-      sdpa::impl::kKVDim,
-      sliced_value_sizes.data(),
-      value_cache_data,
-      sliced_value_dim_order.data(),
-      sliced_value_strides.data(),
-      TensorShapeDynamism::STATIC);
-  Tensor sliced_value_cache(&value_impl);
+  const int64_t num_keys_for_causal_attention = start_pos + seq_len;
 
   ET_KERNEL_CHECK(
       ctx,
@@ -380,38 +323,41 @@ Tensor& custom_sdpa_out(
       sdpa::impl::cpu_flash_attention(
           output,
           q,
-          sliced_key_cache,
-          sliced_value_cache,
+          k,
+          v,
           dropout_p,
           is_causal,
           attn_mask,
           scale,
           true, /* is_seq_at_dim_1 */
-          start_pos);
+          start_pos,
+          num_keys_for_causal_attention);
     } else if (q_seq_len >= 192) {
       sdpa::impl::cpu_flash_attention(
           output,
           q,
-          sliced_key_cache,
-          sliced_value_cache,
+          k,
+          v,
           dropout_p,
           is_causal,
           attn_mask,
           scale,
           true, /* is_seq_at_dim_1 */
-          start_pos);
+          start_pos,
+          num_keys_for_causal_attention);
    } else {
       sdpa::impl::cpu_flash_attention(
           output,
           q,
-          sliced_key_cache,
-          sliced_value_cache,
+          k,
+          v,
           dropout_p,
           is_causal,
           attn_mask,
           scale,
           true, /* is_seq_at_dim_1 */
-          start_pos);
+          start_pos,
+          num_keys_for_causal_attention);
     }
   });
   return output;
diff --git a/extension/llm/custom_ops/op_sdpa_impl.h b/extension/llm/custom_ops/op_sdpa_impl.h
index 6962d24ad4c..46a1797f67c 100644
--- a/extension/llm/custom_ops/op_sdpa_impl.h
+++ b/extension/llm/custom_ops/op_sdpa_impl.h
@@ -212,7 +212,8 @@ void cpu_flash_attention(
     const optional<Tensor>& attn_mask,
     const optional<double>& scale,
     bool is_seq_at_dim_1 = false,
-    const int64_t start_pos = 0) {
+    const int64_t start_pos = 0,
+    const int64_t num_keys_for_causal_attention = -1) {
   (void)dropout_p;
   // Query (Batch x Num_heads x Q_seq_len x Dim_per_head)
   // Key (Batch x Num_heads x KV_seq_len x Dim_per_head)
@@ -258,6 +259,13 @@ void cpu_flash_attention(
     kvSize = value.size(1);
   }
 
+  if (num_keys_for_causal_attention > 0) {
+    ET_CHECK_MSG(
+        num_keys_for_causal_attention <= kvSize,
+        "num_keys_for_causal_attention must be <= kvSize");
+    kvSize = num_keys_for_causal_attention;
+  }
+
   ET_CHECK_MSG(
       num_heads_kv <= num_head,
       "FlashAttention does not support num kv heads > num query heads.Got num query heads=%" PRId64
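
The net effect of the change: instead of materializing `sliced_key_cache` / `sliced_value_cache` views over the first `start_pos + seq_len` entries of the KV cache, `custom_sdpa_out` now passes the full `k` / `v` tensors plus an explicit `num_keys_for_causal_attention`, and `cpu_flash_attention` clamps its effective `kvSize` to that bound. A minimal standalone sketch of that clamping contract follows; the helper name and the numbers are hypothetical, not the ExecuTorch API.

```cpp
#include <cassert>
#include <cstdint>
#include <iostream>

// Sketch of the new contract: the caller passes the full cache length
// (kv_size) plus an optional cap (num_keys_for_causal_attention).
// A cap <= 0 means "use the whole cache", matching the -1 default in the diff.
int64_t effective_kv_size(int64_t kv_size, int64_t num_keys_for_causal_attention) {
  if (num_keys_for_causal_attention > 0) {
    // Mirrors the ET_CHECK_MSG in the diff: the cap may not exceed the cache.
    assert(num_keys_for_causal_attention <= kv_size);
    return num_keys_for_causal_attention;
  }
  return kv_size;
}

int main() {
  const int64_t max_cache_len = 512;  // hypothetical max_seq_len of the KV cache
  const int64_t start_pos = 17;       // tokens already in the cache
  const int64_t seq_len = 4;          // new tokens in this call

  // Previously a sliced view of length start_pos + seq_len was materialized;
  // now the same bound is passed as a plain integer.
  const int64_t num_keys = start_pos + seq_len;
  std::cout << effective_kv_size(max_cache_len, num_keys) << "\n";  // prints 21
  std::cout << effective_kv_size(max_cache_len, -1) << "\n";        // prints 512
  return 0;
}
```

Because the attention loop then only walks the first `num_keys_for_causal_attention` key/value positions, the result should match the old sliced-view path without constructing temporary `TensorImpl`s, and the default of -1 leaves callers that omit the argument unchanged.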