Skip to content

Commit 496a4f6

Browse files
authored
Merge branch 'main' into fix-lora-validation-and-added-tokens
2 parents 54240c6 + 4e905fe commit 496a4f6

2,292 files changed

Lines changed: 243434 additions & 82925 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.claude/skills/add-jit-kernel/SKILL.md

Lines changed: 607 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 363 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,363 @@
1+
---
2+
name: add-sgl-kernel
3+
description: Step-by-step tutorial for adding a heavyweight AOT CUDA/C++ kernel to sgl-kernel (including tests & benchmarks)
4+
---
5+
6+
# Tutorial: Adding a New Kernel to `sgl-kernel` (AOT / Heavyweight)
7+
8+
This tutorial walks through adding a simple element-wise scale operation as an AOT kernel. We'll implement `scale(x, factor) = x * factor` to demonstrate the complete workflow.
9+
10+
## Goal
11+
12+
Add a new operation that scales each element of a tensor by a scalar factor:
13+
14+
- Input: tensor `x` (CUDA) and scalar `factor` (float)
15+
- Output: `x * factor` (element-wise, in-place or into pre-allocated `out`)
16+
- Supported dtypes: **FP16 (`torch.float16`), BF16 (`torch.bfloat16`), FP32 (`torch.float32`)**
17+
- Dispatched via `DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16` macro (defined in `sgl-kernel/include/utils.h`)
18+
19+
## Two rules of thumb (must follow)
20+
21+
1. **Prefer `python/sglang/jit_kernel` first** when the kernel does **not** depend on CUTLASS or another large C++ project. This is the default path for lightweight kernels that benefit from rapid iteration.
22+
2. **Prefer `sgl-kernel`** when the kernel **does** depend on CUTLASS or another large C++ project, or when it should be part of the AOT wheel / torch op registration flow.
23+
3. **Exception**: if the dependency is `flashinfer`, or CUTLASS that is already provided through `flashinfer`, the kernel can still be implemented as `jit_kernel`.
24+
25+
In addition, every new kernel must ship with:
26+
27+
- **Tests** (pytest)
28+
- **A benchmark script** (triton.testing)
29+
30+
---
31+
32+
## Repository integration map
33+
34+
You will typically touch these files/areas:
35+
36+
- Implementation: `sgl-kernel/csrc/elementwise/scale.cu` (pick the right subdirectory)
37+
- Public declarations: `sgl-kernel/include/sgl_kernel_ops.h`
38+
- Torch extension registration: `sgl-kernel/csrc/common_extension.cc`
39+
- Build: `sgl-kernel/CMakeLists.txt` (`set(SOURCES ...)`)
40+
- Python API: `sgl-kernel/python/sgl_kernel/` and `sgl-kernel/python/sgl_kernel/__init__.py`
41+
- Tests: `sgl-kernel/tests/test_scale.py`
42+
- Benchmarks: `sgl-kernel/benchmark/bench_scale.py`
43+
44+
---
45+
46+
## Step 1: Implement the kernel in `csrc/`
47+
48+
Pick the right subdirectory:
49+
50+
- `csrc/elementwise/` — for element-wise ops (our example)
51+
- `csrc/gemm/`, `csrc/attention/`, `csrc/moe/` — for other categories
52+
53+
Create `sgl-kernel/csrc/elementwise/scale.cu`:
54+
55+
```cpp
56+
#include <ATen/cuda/CUDAContext.h>
57+
#include <c10/cuda/CUDAGuard.h>
58+
#include <torch/all.h>
59+
60+
#include "utils.h" // DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16
61+
62+
// scale_kernel: out[i] = input[i] * factor
63+
// Supports float, half (__half), __nv_bfloat16 via template T
64+
template <typename T>
65+
__global__ void scale_kernel(T* __restrict__ out,
66+
const T* __restrict__ input,
67+
float factor,
68+
int64_t n) {
69+
int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
70+
if (idx < n) {
71+
out[idx] = static_cast<T>(static_cast<float>(input[idx]) * factor);
72+
}
73+
}
74+
75+
void scale(at::Tensor& out, const at::Tensor& input, double factor) {
76+
TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
77+
TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
78+
TORCH_CHECK(out.is_cuda(), "out must be a CUDA tensor");
79+
TORCH_CHECK(out.is_contiguous(), "out must be contiguous");
80+
TORCH_CHECK(out.sizes() == input.sizes(), "out and input must have the same shape");
81+
TORCH_CHECK(out.scalar_type() == input.scalar_type(),
82+
"out and input must have the same dtype");
83+
84+
const int64_t n = input.numel();
85+
const int threads = 256;
86+
const int blocks = (n + threads - 1) / threads;
87+
88+
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
89+
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
90+
91+
// Dispatches over float, float16, bfloat16
92+
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
93+
scale_kernel<c_type><<<blocks, threads, 0, stream>>>(
94+
static_cast<c_type*>(out.data_ptr()),
95+
static_cast<const c_type*>(input.data_ptr()),
96+
static_cast<float>(factor),
97+
n);
98+
cudaError_t status = cudaGetLastError();
99+
TORCH_CHECK(status == cudaSuccess,
100+
"scale_kernel launch failed: ", cudaGetErrorString(status));
101+
return true;
102+
});
103+
}
104+
```
105+
106+
**Key points:**
107+
108+
- Use `at::Tensor` (PyTorch tensors), `TORCH_CHECK` for validation, `at::cuda::getCurrentCUDAStream()` for stream
109+
- Keep Python wrappers thin; do shape/dtype/device validation in C++ right around the launch path
110+
- `DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16` covers `float`, `half` (FP16), `__nv_bfloat16` (BF16)
111+
- Add device error checking after every kernel launch
112+
- If a kernel only works on certain architectures, enforce that with `TORCH_CHECK` and skip logic in tests
113+
114+
---
115+
116+
## Step 2: Add a C++ declaration in `include/sgl_kernel_ops.h`
117+
118+
Edit `sgl-kernel/include/sgl_kernel_ops.h`, add to the elementwise section:
119+
120+
```cpp
121+
void scale(at::Tensor& out, const at::Tensor& input, double factor);
122+
```
123+
124+
---
125+
126+
## Step 3: Register the op in `csrc/common_extension.cc`
127+
128+
Edit `sgl-kernel/csrc/common_extension.cc`, inside `TORCH_LIBRARY_FRAGMENT(sgl_kernel, m)`:
129+
130+
```cpp
131+
// From csrc/elementwise
132+
m.def("scale(Tensor! out, Tensor input, float factor) -> ()");
133+
m.impl("scale", torch::kCUDA, &scale);
134+
```
135+
136+
**Key points:**
137+
138+
- `Tensor!` means in-place / mutable output argument
139+
- The schema is important for `torch.compile` and for consistent call signatures
140+
- Keep the torch schema in PyTorch scalar types (`float` here), but note that the C++ launcher signature still needs `double` for scalar arguments accepted by `torch::Library`
141+
142+
---
143+
144+
## Step 4: Add the new source file to `CMakeLists.txt`
145+
146+
Edit `sgl-kernel/CMakeLists.txt`, add to `set(SOURCES ...)`:
147+
148+
```cmake
149+
csrc/elementwise/scale.cu
150+
```
151+
152+
**Key points:**
153+
154+
- Keep the list **alphabetically sorted** (the file explicitly requires this)
155+
- If the kernel has arch constraints, reflect that in tests/benchmarks via skip logic
156+
157+
---
158+
159+
## Step 5: Expose a Python API under `sgl-kernel/python/sgl_kernel/`
160+
161+
Prefer following the existing module organization first. For elementwise kernels, the usual pattern is:
162+
163+
- implement the Python wrapper in `sgl-kernel/python/sgl_kernel/elementwise.py`
164+
- then re-export it from `sgl-kernel/python/sgl_kernel/__init__.py`
165+
166+
For example, in `sgl-kernel/python/sgl_kernel/elementwise.py`, add:
167+
168+
```python
169+
import torch
170+
171+
def scale(
172+
input: torch.Tensor,
173+
factor: float,
174+
out: torch.Tensor | None = None,
175+
) -> torch.Tensor:
176+
"""
177+
Element-wise scale: out = input * factor.
178+
179+
Supported dtypes: torch.float16, torch.bfloat16, torch.float32.
180+
181+
Parameters
182+
----------
183+
input : CUDA input tensor
184+
factor : scale factor (float)
185+
out : optional pre-allocated CUDA output tensor (same shape/dtype as input)
186+
"""
187+
if out is None:
188+
out = torch.empty_like(input)
189+
torch.ops.sgl_kernel.scale.default(out, input, factor)
190+
return out
191+
```
192+
193+
Then re-export it from `sgl-kernel/python/sgl_kernel/__init__.py` following the existing import style used by other kernels.
194+
195+
---
196+
197+
## Step 6: Write tests (required)
198+
199+
Create `sgl-kernel/tests/test_scale.py`:
200+
```python
201+
import pytest
202+
203+
import torch
204+
import sgl_kernel
205+
206+
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
207+
@pytest.mark.parametrize("size", [128, 1024, 4096, 65536])
208+
@pytest.mark.parametrize("factor", [0.5, 1.0, 2.0])
209+
def test_scale_correctness(dtype, size, factor):
210+
input = torch.randn(size, dtype=dtype, device="cuda")
211+
out = torch.empty_like(input)
212+
213+
result = sgl_kernel.scale(input, factor, out=out)
214+
assert result is out
215+
216+
expected = input * factor
217+
rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-2, 1e-2)
218+
torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
219+
220+
221+
def test_scale_shape_mismatch():
222+
input = torch.randn(128, dtype=torch.float16, device="cuda")
223+
out = torch.empty(256, dtype=torch.float16, device="cuda")
224+
with pytest.raises(RuntimeError, match="same shape"):
225+
sgl_kernel.scale(input, 2.0, out=out)
226+
227+
228+
def test_scale_cpu_input():
229+
input = torch.randn(128, dtype=torch.float16) # CPU
230+
out = torch.empty_like(input)
231+
with pytest.raises(RuntimeError, match="CUDA"):
232+
sgl_kernel.scale(input, 2.0, out=out)
233+
234+
235+
if __name__ == "__main__":
236+
import sys
237+
sys.exit(pytest.main([__file__, "-q"]))
238+
```
239+
240+
---
241+
242+
## Step 7: Add a benchmark (required)
243+
244+
Create `sgl-kernel/benchmark/bench_scale.py`:
245+
246+
```python
247+
import itertools
248+
249+
import torch
250+
import triton
251+
import triton.testing
252+
253+
import sgl_kernel
254+
from sglang.utils import is_in_ci
255+
256+
IS_CI = is_in_ci()
257+
258+
dtypes = [torch.float16] if IS_CI else [torch.float16, torch.bfloat16, torch.float32]
259+
sizes = [4096] if IS_CI else [2**n for n in range(10, 20)] # 1K … 512K
260+
factors = [2.0]
261+
262+
configs = list(itertools.product(dtypes, sizes))
263+
264+
265+
def torch_scale(input: torch.Tensor, factor: float) -> torch.Tensor:
266+
return input * factor
267+
268+
269+
@triton.testing.perf_report(
270+
triton.testing.Benchmark(
271+
x_names=["dtype", "size"],
272+
x_vals=configs,
273+
line_arg="provider",
274+
line_vals=["sglang", "torch"],
275+
line_names=["SGL Kernel", "PyTorch"],
276+
styles=[("green", "-"), ("red", "--")],
277+
ylabel="µs (median)",
278+
plot_name="scale-performance",
279+
args={},
280+
)
281+
)
282+
def benchmark(dtype, size, provider):
283+
input = torch.randn(size, dtype=dtype, device="cuda")
284+
out = torch.empty_like(input)
285+
factor = 2.0
286+
287+
if provider == "sglang":
288+
fn = lambda: sgl_kernel.scale(input, factor, out=out)
289+
else:
290+
fn = lambda: torch_scale(input, factor)
291+
292+
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
293+
fn, quantiles=[0.5, 0.2, 0.8]
294+
)
295+
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
296+
297+
298+
if __name__ == "__main__":
299+
benchmark.run(print_data=True)
300+
```
301+
302+
---
303+
304+
## Step 8: Build
305+
306+
Build:
307+
308+
```bash
309+
cd sgl-kernel
310+
make build -j16
311+
```
312+
313+
If you need to limit host resource usage:
314+
315+
```bash
316+
cd sgl-kernel
317+
make build -j1 MAX_JOBS=2 CMAKE_ARGS="-DSGL_KERNEL_COMPILE_THREADS=1"
318+
```
319+
320+
---
321+
322+
## Step 9: Validate
323+
324+
After building successfully, run the test and benchmark:
325+
326+
```bash
327+
pytest sgl-kernel/tests/test_scale.py -q
328+
python sgl-kernel/benchmark/bench_scale.py
329+
```
330+
331+
---
332+
333+
## Troubleshooting
334+
335+
- **Async CUDA errors**: `CUDA_LAUNCH_BLOCKING=1`
336+
- **Memory errors**: `compute-sanitizer --tool memcheck python ...`
337+
- **Build is too slow / OOM**: reduce `MAX_JOBS` and `SGL_KERNEL_COMPILE_THREADS`
338+
- **Binary bloat**: use `sgl-kernel/analyze_whl_kernel_sizes.py`
339+
- **CMake sources list**: if your `.cu` file is missing from `SOURCES`, the symbol will be undefined at link time
340+
341+
---
342+
343+
## References
344+
345+
- `sgl-kernel/README.md`
346+
- `sgl-kernel/include/sgl_kernel_ops.h`
347+
- `sgl-kernel/csrc/common_extension.cc`
348+
- `sgl-kernel/CMakeLists.txt`
349+
- `sgl-kernel/include/utils.h``DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16` macro and friends
350+
- `sgl-kernel/csrc/elementwise/activation.cu` — reference for the FP16/BF16/FP32 dispatch pattern
351+
352+
## Summary of Files Created/Modified
353+
354+
```
355+
sgl-kernel/csrc/elementwise/scale.cu # NEW: CUDA kernel + launcher
356+
sgl-kernel/include/sgl_kernel_ops.h # MODIFIED: C++ declaration
357+
sgl-kernel/csrc/common_extension.cc # MODIFIED: schema + dispatch registration
358+
sgl-kernel/CMakeLists.txt # MODIFIED: add source file (alphabetical)
359+
sgl-kernel/python/sgl_kernel/elementwise.py # MODIFIED: Python wrapper
360+
sgl-kernel/python/sgl_kernel/__init__.py # MODIFIED: re-export Python API
361+
sgl-kernel/tests/test_scale.py # NEW: tests
362+
sgl-kernel/benchmark/bench_scale.py # NEW: benchmark
363+
```

0 commit comments

Comments
 (0)