Skip to content

Commit 4f37bc7

Browse files
committed
Add chunk_gated_delta_rule triton kernel for CUDA backend
Registers FLA's chunk_gated_delta_rule as a @triton_op, following the same pattern as the existing SDPA triton kernel. Six FLA triton kernels are launched via wrap_triton() so AOTInductor compiles them directly into the generated .so — no C++ shim needed. Key trick: FLA kernels use @triton.heuristics which wrap_triton doesn't support. We unwrap via kernel.fn to get the inner @triton.autotune kernel and pass heuristic values (USE_G, IS_VARLEN, etc.) explicitly. Requires: pip install flash-linear-attention
1 parent e458023 commit 4f37bc7

File tree

6 files changed

+825
-1
lines changed

6 files changed

+825
-1
lines changed

.github/workflows/cuda.yml

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,18 @@ jobs:
119119
cmake --workflow --preset default
120120
popd
121121
122-
# Run CUDA backend Python tests, overrides addopts so that we don't run all tests in pytest.ini
122+
# Install flash-linear-attention for chunk_gated_delta_rule triton kernel tests
123+
pip install "flash-linear-attention==0.2.5"
124+
125+
# Build chunk_gated_delta_rule C++ runner (needed by test_e2e_cpp_runner)
126+
cmake -DCMAKE_BUILD_TYPE=Release \
127+
-DCMAKE_PREFIX_PATH=$PWD/cmake-out \
128+
-DEXECUTORCH_BUILD_CUDA=ON \
129+
-B cmake-out/backends/cuda/tests/chunk_gated_delta_runner \
130+
-S backends/cuda/tests/chunk_gated_delta_runner
131+
cmake --build cmake-out/backends/cuda/tests/chunk_gated_delta_runner
132+
133+
# Run all CUDA backend Python tests (including chunk_gated_delta e2e)
123134
python -m pytest backends/cuda/tests backends/cuda/passes/tests -v -o "addopts="
124135
125136
export-model-cuda-artifact:
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Builds the chunk_gated_delta_runner test executable, which loads a .pte
# model (optionally with a .ptd data file) and runs it through the
# ExecuTorch Module API. Built out-of-tree against an already-built
# executorch cmake-out (see -DCMAKE_PREFIX_PATH in the CI workflow).
cmake_minimum_required(VERSION 3.24)
project(chunk_gated_delta_runner)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Repo root is four directories up from this test directory
# (backends/cuda/tests/chunk_gated_delta_runner).
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../..)

include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)

# Include from one level above the repo root so sources can use
# <executorch/...> style include paths.
set(_common_include_directories ${EXECUTORCH_ROOT}/..)

# gflags is consumed from the main build tree's third-party output,
# relative to this project's own binary dir.
set(gflags_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../../third-party/gflags)
find_package(gflags REQUIRED)

# Locate the installed executorch package from the main cmake-out.
list(APPEND CMAKE_FIND_ROOT_PATH ${CMAKE_CURRENT_BINARY_DIR}/../../../..)
find_package(executorch CONFIG REQUIRED FIND_ROOT_PATH_BOTH)
# Force whole-archive linking so statically-registered kernels/backends
# are not dropped by the linker.
executorch_target_link_options_shared_lib(executorch)

set(link_libraries executorch gflags)

# CPU kernels (needed for any non-delegated ops in the graph).
list(APPEND link_libraries optimized_native_cpu_ops_lib cpublas eigen_blas)
executorch_target_link_options_shared_lib(optimized_native_cpu_ops_lib)

# Extension libraries used by main.cpp: Module loading, data loading,
# tensor helpers, flat-tensor (.ptd) support, and named data maps.
list(
  APPEND
  link_libraries
  extension_module
  extension_data_loader
  extension_tensor
  extension_flat_tensor
  extension_named_data_map
)

# Pull in the CUDA AOTI delegate when the backend is enabled.
if(EXECUTORCH_BUILD_CUDA)
  find_package(CUDAToolkit REQUIRED)
  list(APPEND link_libraries aoti_cuda_backend)
  # Whole-archive linking is not available with the MSVC linker.
  if(NOT MSVC)
    executorch_target_link_options_shared_lib(aoti_cuda_backend)
  endif()
endif()

add_executable(chunk_gated_delta_runner main.cpp)
target_include_directories(
  chunk_gated_delta_runner PUBLIC ${_common_include_directories}
)
target_link_libraries(chunk_gated_delta_runner PUBLIC ${link_libraries})

# Shrink release binaries: garbage-collect unused sections and strip
# symbols (strip flag is not supported by Apple/MSVC linkers).
if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug")
  target_link_options_gc_sections(chunk_gated_delta_runner)
  if(NOT APPLE AND NOT MSVC)
    target_link_options(chunk_gated_delta_runner PRIVATE "LINKER:-s")
  endif()
endif()
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <cstdio>
10+
#include <cstring>
11+
#include <fstream>
12+
#include <vector>
13+
14+
#include <gflags/gflags.h>
15+
16+
#include <executorch/extension/module/module.h>
17+
#include <executorch/extension/tensor/tensor.h>
18+
19+
DEFINE_string(model_path, "", "Path to .pte file");
20+
DEFINE_string(data_path, "", "Path to .ptd directory (for CUDA delegate)");
21+
DEFINE_string(input_dir, "", "Directory with input .bin files");
22+
DEFINE_string(output_dir, "", "Directory to write output .bin files");
23+
24+
using ::executorch::extension::from_blob;
25+
using ::executorch::extension::Module;
26+
using ::executorch::runtime::Error;
27+
using ::executorch::runtime::EValue;
28+
29+
// Reads the entire file at `path` into a byte buffer.
// Exits the process with status 1 on any failure (open or read), since
// callers treat missing/corrupt input files as fatal for this test runner.
static std::vector<char> read_file(const std::string& path) {
  std::ifstream f(path, std::ios::binary | std::ios::ate);
  if (!f) {
    fprintf(stderr, "Cannot open %s\n", path.c_str());
    exit(1);
  }
  std::size_t size = static_cast<std::size_t>(f.tellg());
  f.seekg(0);
  std::vector<char> buf(size);
  f.read(buf.data(), static_cast<std::streamsize>(size));
  // Fix: previously a short/failed read silently returned a partially
  // zero-filled buffer; fail loudly instead, matching the open-error path.
  if (!f) {
    fprintf(stderr, "Failed to read %s\n", path.c_str());
    exit(1);
  }
  return buf;
}
41+
42+
// Writes `len` bytes from `data` to the file at `path`.
// Best-effort: failures are reported to stderr but are not fatal, so a
// bad --output_dir does not abort an otherwise-successful run.
static void write_file(const std::string& path, const void* data, size_t len) {
  std::ofstream f(path, std::ios::binary);
  if (!f) {
    // Fix: previously an unopenable path was silently ignored.
    fprintf(stderr, "Warning: cannot open %s for writing\n", path.c_str());
    return;
  }
  f.write(static_cast<const char*>(data), static_cast<std::streamsize>(len));
  if (!f) {
    // Fix: previously a short/failed write was silently ignored.
    fprintf(stderr, "Warning: failed writing %zu bytes to %s\n", len,
            path.c_str());
  }
}
46+
47+
// Entry point: loads the chunk_gated_delta_rule .pte model, builds six
// bfloat16 input tensors (from --input_dir .bin files, or deterministic
// defaults), runs "forward", prints output shapes/dtypes, and optionally
// dumps outputs to --output_dir as raw .bin files.
int main(int argc, char** argv) {
  gflags::ParseCommandLineFlags(&argc, &argv, true);
  if (FLAGS_model_path.empty()) {
    fprintf(stderr, "Error: --model_path required\n");
    return 1;
  }

  // A separate .ptd data path is only needed for the CUDA delegate flow;
  // otherwise the .pte alone is sufficient.
  std::unique_ptr<Module> module;
  if (!FLAGS_data_path.empty()) {
    module = std::make_unique<Module>(
        FLAGS_model_path,
        FLAGS_data_path,
        Module::LoadMode::MmapUseMlockIgnoreErrors);
  } else {
    module = std::make_unique<Module>(
        FLAGS_model_path, Module::LoadMode::MmapUseMlockIgnoreErrors);
  }

  auto load_err = module->load();
  if (load_err != Error::Ok) {
    fprintf(stderr, "Failed to load model: 0x%x\n", static_cast<int>(load_err));
    return 1;
  }

  // Fixed problem sizes baked into the exported model. Presumably
  // B=batch, T=sequence length, H=heads, K=key dim, V=value dim, matching
  // chunk_gated_delta_rule's signature — TODO confirm against the export
  // script; these must match the .pte or execute() will fail.
  constexpr int B = 1, T = 128, H = 4, K = 64, V = 64;

  std::vector<EValue> inputs;

  if (!FLAGS_input_dir.empty()) {
    // Load inputs from binary files
    // Expected raw bf16 files named <name>.bin inside --input_dir.
    struct TensorSpec {
      const char* name;
      std::vector<exec_aten::SizesType> shape;
      exec_aten::ScalarType dtype;
    };
    TensorSpec specs[] = {
        {"q", {B, T, H, K}, exec_aten::ScalarType::BFloat16},
        {"k", {B, T, H, K}, exec_aten::ScalarType::BFloat16},
        {"v", {B, T, H, V}, exec_aten::ScalarType::BFloat16},
        {"g", {B, T, H}, exec_aten::ScalarType::BFloat16},
        {"beta", {B, T, H}, exec_aten::ScalarType::BFloat16},
        {"initial_state", {B, H, K, V}, exec_aten::ScalarType::BFloat16},
    };

    // Keep data and TensorPtrs alive for the duration of execution
    // (from_blob does not copy, so the buffers must outlive execute();
    // statics guarantee that here).
    static std::vector<std::vector<char>> input_bufs;
    static std::vector<executorch::extension::TensorPtr> input_tensors;
    input_bufs.resize(6);
    input_tensors.clear();

    for (int i = 0; i < 6; i++) {
      std::string path = FLAGS_input_dir + "/" + specs[i].name + ".bin";
      input_bufs[i] = read_file(path);
      input_tensors.push_back(
          from_blob(input_bufs[i].data(), specs[i].shape, specs[i].dtype));
      inputs.push_back(*input_tensors.back());
    }
  } else {
    // Generate deterministic test inputs
    // Truncating float->bfloat16 conversion (drops the low 16 mantissa
    // bits; no rounding). Sufficient for deterministic test data.
    auto to_bf16 = [](float f) -> uint16_t {
      uint32_t bits;
      std::memcpy(&bits, &f, sizeof(float));
      return static_cast<uint16_t>(bits >> 16);
    };

    // q/k share one repeating ramp in [-0.5, 0.49]; v copies it; g and
    // beta are constant; initial_state is all zeros.
    static std::vector<uint16_t> qk_data(B * T * H * K);
    for (size_t i = 0; i < qk_data.size(); i++)
      qk_data[i] = to_bf16(static_cast<float>(i % 100) * 0.01f - 0.5f);
    static auto v_data = std::vector<uint16_t>(qk_data.begin(), qk_data.end());
    static std::vector<uint16_t> g_data(B * T * H, to_bf16(-0.5f));
    static std::vector<uint16_t> beta_data(B * T * H, to_bf16(0.5f));
    static std::vector<uint16_t> state_data(B * H * K * V, to_bf16(0.0f));

    // Statics again so the zero-copy from_blob views stay valid through
    // execute().
    static std::vector<executorch::extension::TensorPtr> default_tensors;
    default_tensors.clear();
    default_tensors.push_back(from_blob(
        qk_data.data(), {B, T, H, K}, exec_aten::ScalarType::BFloat16));
    default_tensors.push_back(from_blob(
        qk_data.data(), {B, T, H, K}, exec_aten::ScalarType::BFloat16));
    default_tensors.push_back(from_blob(
        v_data.data(), {B, T, H, V}, exec_aten::ScalarType::BFloat16));
    default_tensors.push_back(
        from_blob(g_data.data(), {B, T, H}, exec_aten::ScalarType::BFloat16));
    default_tensors.push_back(from_blob(
        beta_data.data(), {B, T, H}, exec_aten::ScalarType::BFloat16));
    default_tensors.push_back(from_blob(
        state_data.data(), {B, H, K, V}, exec_aten::ScalarType::BFloat16));
    for (auto& t : default_tensors)
      inputs.push_back(*t);
  }

  auto result = module->execute("forward", inputs);
  if (!result.ok()) {
    fprintf(stderr, "Forward failed: 0x%x\n", static_cast<int>(result.error()));
    return 1;
  }

  // Print each tensor output's shape and dtype; non-tensor EValues are
  // skipped.
  auto outputs = result.get();
  for (size_t i = 0; i < outputs.size(); i++) {
    if (!outputs[i].isTensor())
      continue;
    const auto& t = outputs[i].toTensor();
    printf("Output %zu: [", i);
    for (int d = 0; d < t.dim(); d++)
      printf("%d%s", static_cast<int>(t.size(d)), d < t.dim() - 1 ? "," : "");
    printf("] dtype=%d\n", static_cast<int>(t.scalar_type()));

    if (!FLAGS_output_dir.empty()) {
      // Output tensors are on host memory (CUDA delegate copies back to CPU)
      std::string path =
          FLAGS_output_dir + "/output_" + std::to_string(i) + ".bin";
      write_file(path, t.const_data_ptr(), t.nbytes());
      printf("  Saved to %s (%zu bytes)\n", path.c_str(), (size_t)t.nbytes());
    }
  }

  // Sentinel for CI scripts that grep stdout.
  printf("SUCCESS\n");
  return 0;
}

0 commit comments

Comments
 (0)