diff --git a/extension/llm/custom_ops/op_sdpa.cpp b/extension/llm/custom_ops/op_sdpa.cpp
index e8a53a41312..e74c111cc5f 100644
--- a/extension/llm/custom_ops/op_sdpa.cpp
+++ b/extension/llm/custom_ops/op_sdpa.cpp
@@ -754,6 +754,152 @@ void update_cache(
   }
 }
 
+/*
+  Input params
+  @param[in] q_projected Projected query with query weights.
+  Format [n_layers, batch size, seq_len, num heads, head dim]
+  @param[in] k_projected Projected query with key weights.
+  Format [n_layers, batch size, seq_len, num heads, head dim]
+  @param[in] v_projected Projected query with value weights.
+  Format [n_layers, batch size, seq_len, num heads, head dim]
+  @param[in] key_cache Cache of previous k_projected.
+  Format [n_layers, batch size, max_seq_len, num heads, head dim]
+  @param[in] value_cache Cache of previous v_projected.
+  Format [n_layers, batch size, max_seq_len, num heads, head dim]
+  ....
+  @param[in] start_pos: sequence position
+  @param[in] seq_len: Seq length. e.g. seq_len dim of q_projected.
+*/
+Tensor& custom_sdpa_out(
+    RuntimeContext& ctx,
+    const Tensor& q,
+    const Tensor& k,
+    const Tensor& v,
+    const int64_t start_pos,
+    const int64_t seq_len,
+    const optional<Tensor>& attn_mask,
+    const double dropout_p,
+    const bool is_causal,
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const optional<double> scale,
+    Tensor& output) {
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      !attn_mask.has_value() || !is_causal,
+      InvalidArgument,
+      output,
+      "attn_mask and is_causal cannot be set at the same time");
+
+  ET_CHECK_MSG(q.dim() == 4, "query must be a 4D tensor");
+
+  auto q_seq_len = q.size(1);
+
+  // Refactor the following into create_view util perhaps using
+  // TensorPtr
+  std::array<exec_aten::DimOrderType, util::kKVDim> sliced_key_dim_order{
+      0, 1, 2, 3};
+  std::array<exec_aten::SizesType, util::kKVDim> sliced_key_sizes;
+  sliced_key_sizes[0] = k.size(0);
+  sliced_key_sizes[1] = start_pos + seq_len; // key_cache.size(2);
+  sliced_key_sizes[2] = k.size(2);
+  sliced_key_sizes[3] = k.size(3);
+  std::array<exec_aten::StridesType, util::kKVDim> sliced_key_strides;
+  dim_order_to_stride_nocheck(
+      sliced_key_sizes.data(),
+      sliced_key_dim_order.data(),
+      util::kKVDim,
+      sliced_key_strides.data());
+  // since the cache is sliced, the batch stride needs to stay the same.
+  sliced_key_strides[0] = k.strides()[0];
+  void* key_cache_data = k.mutable_data_ptr();
+  TensorImpl k_impl = TensorImpl(
+      k.scalar_type(),
+      util::kKVDim,
+      sliced_key_sizes.data(),
+      key_cache_data,
+      sliced_key_dim_order.data(),
+      sliced_key_strides.data(),
+      TensorShapeDynamism::STATIC);
+  Tensor sliced_key_cache(&k_impl);
+
+  std::array<exec_aten::DimOrderType, util::kKVDim> sliced_value_dim_order{
+      0, 1, 2, 3};
+  std::array<exec_aten::SizesType, util::kKVDim> sliced_value_sizes;
+  sliced_value_sizes[0] = v.size(0);
+  sliced_value_sizes[1] = start_pos + seq_len; // value_cache.size(2);
+  sliced_value_sizes[2] = v.size(2);
+  sliced_value_sizes[3] = v.size(3);
+  std::array<exec_aten::StridesType, util::kKVDim> sliced_value_strides;
+  dim_order_to_stride_nocheck(
+      sliced_value_sizes.data(),
+      sliced_value_dim_order.data(),
+      util::kKVDim,
+      sliced_value_strides.data());
+  // since the cache is sliced, the batch stride needs to stay the same.
+  sliced_value_strides[0] = v.strides()[0];
+  void* value_cache_data = v.mutable_data_ptr();
+  TensorImpl value_impl = TensorImpl(
+      v.scalar_type(),
+      util::kKVDim,
+      sliced_value_sizes.data(),
+      value_cache_data,
+      sliced_value_dim_order.data(),
+      sliced_value_strides.data(),
+      TensorShapeDynamism::STATIC);
+  Tensor sliced_value_cache(&value_impl);
+
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_tensor(output, q.sizes()) == Error::Ok,
+      InvalidArgument,
+      output);
+
+  // TODO(task): replace the template param selection logic
+  // with whatever appropriately makes more sense for
+  ET_SWITCH_FLOAT_TYPES(q.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
+    // TODO we need to re-evaluate this for ARM CPUs
+    // And there can be many so instead of templatizing
+    // we might consider another approach
+    if (q_seq_len >= 768) {
+      cpu_flash_attention<CTYPE, 256, 512>(
+          output,
+          q,
+          sliced_key_cache,
+          sliced_value_cache,
+          dropout_p,
+          is_causal,
+          attn_mask,
+          scale,
+          true,
+          start_pos);
+    } else if (q_seq_len >= 192) {
+      cpu_flash_attention<CTYPE, 64, 512>(
+          output,
+          q,
+          sliced_key_cache,
+          sliced_value_cache,
+          dropout_p,
+          is_causal,
+          attn_mask,
+          scale,
+          true,
+          start_pos);
+    } else {
+      cpu_flash_attention<CTYPE, 32, 512>(
+          output,
+          q,
+          sliced_key_cache,
+          sliced_value_cache,
+          dropout_p,
+          is_causal,
+          attn_mask,
+          scale,
+          true,
+          start_pos);
+    }
+  });
+  return output;
+}
 } // anonymous namespace
 
 Tensor& flash_attention_kernel_out(
@@ -860,129 +1006,24 @@ Tensor& sdpa_with_kv_cache_out(
       InvalidArgument,
       output);
 
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      !attn_mask.has_value() || !is_causal,
-      InvalidArgument,
-      output,
-      "attn_mask and is_causal cannot be set at the same time");
-
   ET_CHECK_MSG(q_projected.dim() == 4, "query must be a 4D tensor");
 
   update_cache(k_projected, key_cache, start_pos, seq_len);
   update_cache(v_projected, value_cache, start_pos, seq_len);
 
-  auto q_seq_len = q_projected.size(1);
-
-  std::array<exec_aten::DimOrderType, util::kKVDim> sliced_key_dim_order{
-      0, 1, 2, 3};
-  std::array<exec_aten::SizesType, util::kKVDim> sliced_key_sizes;
-  sliced_key_sizes[0] = key_cache.size(0);
-  sliced_key_sizes[1] = start_pos + seq_len; // key_cache.size(2);
-  sliced_key_sizes[2] = key_cache.size(2);
-  sliced_key_sizes[3] = key_cache.size(3);
-  std::array<exec_aten::StridesType, util::kKVDim> sliced_key_strides;
-  dim_order_to_stride_nocheck(
-      sliced_key_sizes.data(),
-      sliced_key_dim_order.data(),
-      util::kKVDim,
-      sliced_key_strides.data());
-  // since the cache is sliced, the batch stride needs to stay the same.
-  sliced_key_strides[0] = key_cache.strides()[0];
-  void* key_cache_data = key_cache.mutable_data_ptr();
-  TensorImpl k_impl = TensorImpl(
-      key_cache.scalar_type(),
-      util::kKVDim,
-      sliced_key_sizes.data(),
-      key_cache_data,
-      sliced_key_dim_order.data(),
-      sliced_key_strides.data(),
-      TensorShapeDynamism::STATIC);
-  Tensor sliced_key_cache(&k_impl);
-
-  std::array<exec_aten::DimOrderType, util::kKVDim> sliced_value_dim_order{
-      0, 1, 2, 3};
-  std::array<exec_aten::SizesType, util::kKVDim> sliced_value_sizes;
-  sliced_value_sizes[0] = value_cache.size(0);
-  sliced_value_sizes[1] = start_pos + seq_len; // value_cache.size(2);
-  sliced_value_sizes[2] = value_cache.size(2);
-  sliced_value_sizes[3] = value_cache.size(3);
-  std::array<exec_aten::StridesType, util::kKVDim> sliced_value_strides;
-  dim_order_to_stride_nocheck(
-      sliced_value_sizes.data(),
-      sliced_value_dim_order.data(),
-      util::kKVDim,
-      sliced_value_strides.data());
-  // since the cache is sliced, the batch stride needs to stay the same.
-  sliced_value_strides[0] = value_cache.strides()[0];
-  void* value_cache_data = value_cache.mutable_data_ptr();
-  TensorImpl value_impl = TensorImpl(
-      value_cache.scalar_type(),
-      util::kKVDim,
-      sliced_value_sizes.data(),
-      value_cache_data,
-      sliced_value_dim_order.data(),
-      sliced_value_strides.data(),
-      TensorShapeDynamism::STATIC);
-  Tensor sliced_value_cache(&value_impl);
-
-  // Is this true?
-  // Cant do this as is because the expectation of this kernel is
-  // that q, k, v are [B, num heads, seq length, head dim]
-  // and the cache is [B, max seq len, num heads, head dim]
-  // and q, k, v are all [B, seq length, num heads, head dim]
-
-  ET_KERNEL_CHECK(
+  custom_sdpa_out(
       ctx,
-      resize_tensor(output, q_projected.sizes()) == Error::Ok,
-      InvalidArgument,
+      q_projected,
+      key_cache,
+      value_cache,
+      start_pos,
+      seq_len,
+      attn_mask,
+      dropout_p,
+      is_causal,
+      scale,
       output);
 
-  // TODO(task): replace the template param selection logic
-  // with whatever apprpriately makes more sense for
-  ET_SWITCH_FLOAT_TYPES(
-      q_projected.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
-        // TODO we need to re-evaluate this for ARM CPUs
-        // And there can be many so instead of templatizing
-        // we might consider another appraoch
-        if (q_seq_len >= 768) {
-          cpu_flash_attention<CTYPE, 256, 512>(
-              output,
-              q_projected,
-              sliced_key_cache,
-              sliced_value_cache,
-              dropout_p,
-              is_causal,
-              attn_mask,
-              scale,
-              true,
-              start_pos);
-        } else if (q_seq_len >= 192) {
-          cpu_flash_attention<CTYPE, 64, 512>(
-              output,
-              q_projected,
-              sliced_key_cache,
-              sliced_value_cache,
-              dropout_p,
-              is_causal,
-              attn_mask,
-              scale,
-              true,
-              start_pos);
-        } else {
-          cpu_flash_attention<CTYPE, 32, 512>(
-              output,
-              q_projected,
-              sliced_key_cache,
-              sliced_value_cache,
-              dropout_p,
-              is_causal,
-              attn_mask,
-              scale,
-              true,
-              start_pos);
-        }
-      });
   return output;
 }
 } // namespace native
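
Note: custom_sdpa_out expects key_cache and value_cache to already contain the current step's K/V; sdpa_with_kv_cache_out calls update_cache() before delegating, and custom_sdpa_out then attends over the cache slice [0, start_pos + seq_len) and resizes output to q's sizes internally. Below is a minimal caller-side sketch of that contract, assuming it lives in the same translation unit (update_cache and custom_sdpa_out sit in the anonymous namespace here); the helper name run_sdpa_decode_step is hypothetical and not part of this patch.

// Hypothetical illustration only: mirrors what sdpa_with_kv_cache_out now
// reduces to after this refactor.
Tensor& run_sdpa_decode_step(
    RuntimeContext& ctx,
    const Tensor& q_projected, // [B, seq_len, num heads, head dim]
    const Tensor& k_projected, // [B, seq_len, num heads, head dim]
    const Tensor& v_projected, // [B, seq_len, num heads, head dim]
    Tensor& key_cache, // [B, max_seq_len, num heads, head dim]
    Tensor& value_cache, // [B, max_seq_len, num heads, head dim]
    const int64_t start_pos,
    const int64_t seq_len,
    Tensor& output) {
  // Write this step's K/V into the caches at start_pos; custom_sdpa_out
  // assumes the caches are already up to date.
  update_cache(k_projected, key_cache, start_pos, seq_len);
  update_cache(v_projected, value_cache, start_pos, seq_len);
  // Run flash attention against the filled portion of the caches,
  // cache[:, 0:start_pos + seq_len], with causal masking for decode.
  return custom_sdpa_out(
      ctx,
      q_projected,
      key_cache,
      value_cache,
      start_pos,
      seq_len,
      /*attn_mask=*/{},
      /*dropout_p=*/0.0,
      /*is_causal=*/true,
      /*scale=*/{},
      output);
}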