Add update_quantized_cache op #5527

Closed · wants to merge 12 commits
13 changes: 13 additions & 0 deletions extension/llm/custom_ops/TARGETS
@@ -22,6 +22,19 @@ runtime.python_test(
],
)

runtime.python_test(
name = "test_update_quantized_cache",
srcs = [
"test_update_quantized_cache.py",
],
preload_deps = [
":custom_ops_aot_lib",
],
deps = [
"//caffe2:torch",
],
)

runtime.python_test(
name = "test_preprocess_custom_ops",
srcs = [
41 changes: 39 additions & 2 deletions extension/llm/custom_ops/op_sdpa_aot.cpp
@@ -9,14 +9,14 @@
#include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h>
#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
#include <executorch/extension/llm/custom_ops/op_sdpa.h>
#include <executorch/extension/llm/custom_ops/op_update_quantized_cache.h>

#include <torch/library.h>

namespace torch {
namespace executor {

namespace native {
namespace {
Tensor& sdpa_with_kv_cache_out_no_context(
const Tensor& q_projected,
const Tensor& k_projected,
@@ -81,7 +81,27 @@ at::Tensor sdpa_with_kv_cache_aten(
output);
return output;
}
} // namespace

Tensor& update_quantized_cache_out_no_context(
const Tensor& value,
Tensor& cache,
const int64_t start_pos,
Tensor& output) {
exec_aten::RuntimeContext context{};
return torch::executor::native::update_quantized_cache_out(
context, value, cache, start_pos, output);
}

at::Tensor update_quantized_cache_aten(
const at::Tensor& value,
at::Tensor& cache,
const int64_t start_pos) {
auto output = at::empty({1});
WRAP_TO_ATEN(update_quantized_cache_out_no_context, 3)
(value, cache, start_pos, output);
return output;
}

} // namespace native
} // namespace executor
} // namespace torch
@@ -95,6 +115,12 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
"sdpa_with_kv_cache.out(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
"Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
"float drpout_p=0.0, bool is_causal=False, float? scale=None, *, Tensor(c!) out) -> Tensor(c!)");
m.def(
"update_quantized_cache(Tensor value, Tensor(a!) cache, "
"SymInt start_pos) -> Tensor");
m.def(
"update_quantized_cache.out(Tensor value, Tensor(a!) cache, "
"SymInt start_pos, *, Tensor(b!) out) -> Tensor(b!)");
}

TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
@@ -105,3 +131,14 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
WRAP_TO_ATEN(
torch::executor::native::sdpa_with_kv_cache_out_no_context, 11));
}

// TODO: Rename this file to op_custom_ops_aot.cpp
TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
m.impl(
"update_quantized_cache",
torch::executor::native::update_quantized_cache_aten);
m.impl(
"update_quantized_cache.out",
WRAP_TO_ATEN(
torch::executor::native::update_quantized_cache_out_no_context, 3));
}
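
With the schema and CompositeExplicitAutograd registrations above, the new op becomes callable from eager PyTorch once the AOT library is loaded. Below is a minimal sketch of such a call; the library-loading step, shapes, and dtypes are assumptions for illustration and not part of this PR.

import torch

# Assumed loading step; the actual path to the AOT shared library will differ.
# torch.ops.load_library("path/to/libcustom_ops_aot_lib.so")

# Layout assumed to be [batch, seq_len, num_heads, head_dim]; sizes are arbitrary examples.
value = torch.randint(-128, 127, (1, 4, 8, 16), dtype=torch.int8)
cache = torch.zeros(1, 128, 8, 16, dtype=torch.int8)

# Functional variant: mutates `cache` in place (Tensor(a!)) and returns a placeholder.
torch.ops.llama.update_quantized_cache(value, cache, 0)

# Out variant: same in-place update, with an explicit (unused) out tensor.
out = torch.empty(1)
torch.ops.llama.update_quantized_cache.out(value, cache, 0, out=out)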
143 changes: 143 additions & 0 deletions extension/llm/custom_ops/op_update_quantized_cache.cpp
@@ -0,0 +1,143 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/extension/llm/custom_ops/op_update_quantized_cache.h>

#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
// @lint-ignore CLANGTIDY facebook-unused-include-check
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>

#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>

namespace torch {
namespace executor {

namespace native {

namespace {
bool validate_cache_params(
const Tensor& quantized_value,
const Tensor& quantized_cache,
int64_t start_pos,
int64_t seq_length) {
ET_LOG_MSG_AND_RETURN_IF_FALSE(
quantized_cache.dim() == 4, "quantized cache must be a 4D tensor");

ET_LOG_MSG_AND_RETURN_IF_FALSE(
quantized_value.dim() == 4, "quantized_value must be a 4D tensor");

ET_LOG_MSG_AND_RETURN_IF_FALSE(
start_pos < quantized_cache.size(1),
"start_pos must be less than cache size at dim 1");

ET_LOG_MSG_AND_RETURN_IF_FALSE(
(start_pos + seq_length) <= quantized_cache.size(1),
"start_post + seq_length must be less than max seq length supported by cache."
"start pos: %" PRId64 ", seq_length: %" PRId64
"."
"cache size: %zd",
start_pos,
seq_length,
quantized_cache.size(1));

// Make sure they are in contiguous dim order
ET_LOG_MSG_AND_RETURN_IF_FALSE(
is_contiguous_dim_order(
quantized_cache.dim_order().data(), quantized_cache.dim()),
"quantized cache must be in contiguous dim order");

ET_LOG_MSG_AND_RETURN_IF_FALSE(
is_contiguous_dim_order(
quantized_value.dim_order().data(), quantized_value.dim()),
"quantized value must be in contiguous dim order");

return true;
}
} // anonymous namespace

Tensor& update_quantized_cache_out(
RuntimeContext& ctx,
const Tensor& value,
Tensor& cache,
const int64_t start_pos,
Tensor& output) {
(void)ctx;
int64_t seq_len = value.size(1);
ET_KERNEL_CHECK(
ctx,
validate_cache_params(value, cache, start_pos, seq_len),
InvalidArgument,
output);

ET_CHECK_MSG(
value.size(0) == cache.size(0),
"projected_value batch size should be equal to the cache batch size.");
ET_CHECK_MSG(
value.size(2) == cache.size(2),
"projected_value number of heads should be equal to the cache number of heads.");
ET_CHECK_MSG(
value.size(3) == cache.size(3),
"projected_value embedding dimension should be equal to the cache embedding dimension.");
ET_CHECK_MSG(
value.element_size() == cache.element_size(),
"projected_value data type size should be equal to the cache data type size.");

ET_CHECK_MSG(
is_contiguous_dim_order(value.dim_order().data(), value.dim()),
"projected value must be in contiguous dim order");
ET_CHECK_MSG(
is_contiguous_dim_order(cache.dim_order().data(), cache.dim()),
"projected value must be in contiguous dim order");

const void* value_data = value.const_data_ptr();
void* cache_data = cache.mutable_data_ptr();

ET_CHECK_MSG(value_data, "projected_value data is null");
ET_CHECK_MSG(cache_data, "cache data is null");

auto cache_strides = cache.strides();
exec_aten::StridesType cache_batch_dim_stride = cache_strides[0];
exec_aten::StridesType cache_seq_dim_stride = cache_strides[1];

auto value_strides = value.strides();
exec_aten::StridesType value_batch_dim_stride = value_strides[0];

exec_aten::SizesType num_bytes_to_copy =
(value.numel() / value.size(0)) * value.element_size();

for (int64_t batch_line = 0; batch_line < value.size(0); ++batch_line) {
exec_aten::SizesType cache_pos_offset =
(batch_line * cache_batch_dim_stride +
start_pos * cache_seq_dim_stride) *
cache.element_size();
exec_aten::SizesType value_pos_offset =
(batch_line * value_batch_dim_stride) * cache.element_size();

std::memcpy(
(uint8_t*)cache_data + cache_pos_offset,
(uint8_t*)value_data + value_pos_offset,
num_bytes_to_copy);
}

// No one uses output; it is just a placeholder.
return output;
}
} // namespace native
} // namespace executor
} // namespace torch

// Really this is just an in-place tensor update op
// that makes assumptions about the rank of the tensor
// and its dim order (memory layout).
// It further assumes that the indexing is along the
// sequence dimension (dim 1) of the tensor.
// Later diffs will rename this to update_cache.
EXECUTORCH_LIBRARY(
llama,
"update_quantized_cache.out",
torch::executor::native::update_quantized_cache_out);
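
In eager terms, the copy loop above is roughly a slice assignment along the sequence dimension: for contiguous [batch, seq_len, num_heads, head_dim] tensors, each batch line of `value` is memcpy'd into `cache` starting at `start_pos`. The following is a hedged Python sketch of that semantics, useful as a reference when reading the kernel, not the shipped implementation.

import torch

def update_cache_reference(value: torch.Tensor, cache: torch.Tensor, start_pos: int) -> torch.Tensor:
    # Mirrors the kernel's per-batch copy for contiguous tensors:
    # (value.numel() / value.size(0)) * element_size bytes are copied from
    # value[b, 0, ...] into cache[b, start_pos, ...] for each batch line b.
    seq_len = value.size(1)
    assert start_pos + seq_len <= cache.size(1)
    cache[:, start_pos : start_pos + seq_len] = value
    return cache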
26 changes: 26 additions & 0 deletions extension/llm/custom_ops/op_update_quantized_cache.h
@@ -0,0 +1,26 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {

namespace native {

Tensor& update_quantized_cache_out(
RuntimeContext& ctx,
const Tensor& value,
Tensor& cache,
const int64_t start_pos,
Tensor& output);
} // namespace native
} // namespace executor
} // namespace torch
52 changes: 52 additions & 0 deletions extension/llm/custom_ops/sdpa_with_kv_cache.py
@@ -17,6 +17,7 @@

from torch.library import impl

# TODO rename this file to custom_ops_meta_registration.py
try:
op = torch.ops.llama.sdpa_with_kv_cache.default
assert op is not None
@@ -138,3 +139,54 @@ def fast_hadamard_transform_meta(mat):
# assert(mat.shape[-1] == 128 or mat.shape[-1] == 14336, "unexpected input size for llama3 demo!")
# assert(mat.is_contiguous(), "input matrix must be contiguous currently!")
return torch.empty_like(mat)


def _validate_update_cache_params(
value,
cache,
start_pos,
):
seq_len = value.size(1)
assert (
value.dim() == 4
), f"Expected value to be 4 dimensional but got {value.dim()} dimensions."

assert (
value.dtype == cache.dtype
), f"Expected value and cache to be of the same type but got value type {value.dtype} and cache type {cache.dtype}"

for i in [0, 2, 3]:
assert value.size(i) == cache.size(
i
), f"Expected value and cache to have same size in dimension {i} but got {value.size(i)} and {cache.size(i)}"

torch._check_is_size(start_pos)
# Setting an arbitrary limit of 256 for now since there is no way
# to plumb this information from the model config
torch._check(start_pos < cache.size(1))
assert start_pos < cache.size(
1
), f"Start position {start_pos} must be less than sequence length {cache.size(1)}"

torch._check((start_pos + seq_len) < cache.size(1))
assert (start_pos + seq_len) < cache.size(
1
), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(1)}"


@impl(custom_ops_lib, "update_quantized_cache", "Meta")
def update_quantized_cache_meta(
value,
cache,
start_pos,
):
_validate_update_cache_params(
value,
cache,
start_pos,
)

# Update cache doesn't really return anything, but I don't know a better
# workaround. Should we just return cache instead? I am afraid that
# would result in an extra memory allocation.
return torch.empty((1,), dtype=value.dtype, device="meta")
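
The TARGETS change earlier in this PR adds a test_update_quantized_cache target, whose contents are not shown in this diff. Assuming the AOT library has been preloaded (see preload_deps above), a minimal check of the op could look roughly like the sketch below; the class name, shapes, and dtypes are illustrative assumptions, not the actual test.

import unittest

import torch


class UpdateQuantizedCacheSketch(unittest.TestCase):
    def test_update_writes_expected_slice(self):
        # Layout assumed to be [batch, seq_len, num_heads, head_dim]; sizes are arbitrary.
        value = torch.randint(-128, 127, (1, 4, 8, 16), dtype=torch.int8)
        cache = torch.zeros(1, 32, 8, 16, dtype=torch.int8)
        torch.ops.llama.update_quantized_cache(value, cache, 2)
        # The op mutates `cache` in place; the written slice should equal `value`.
        torch.testing.assert_close(cache[:, 2:6], value)


if __name__ == "__main__":
    unittest.main()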
2 changes: 2 additions & 0 deletions extension/llm/custom_ops/targets.bzl
@@ -13,11 +13,13 @@ def define_common_targets():
"op_fallback.cpp",
"op_fast_hadamard_transform.cpp",
"op_sdpa.cpp",
"op_update_quantized_cache.cpp",
],
exported_headers = [
"op_fallback.h",
"op_fast_hadamard_transform.h",
"op_sdpa.h",
"op_update_quantized_cache.h",
],
exported_deps = [
"//executorch/runtime/kernel:kernel_includes",