-
-
Notifications
You must be signed in to change notification settings - Fork 16.9k
Expand file tree
/
Copy pathtorch_bindings.cpp
More file actions
711 lines (617 loc) · 28.1 KB
/
torch_bindings.cpp
File metadata and controls
711 lines (617 loc) · 28.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
#include "cache.h"
#include "cuda_utils.h"
#include "ops.h"
#include "core/registration.h"
#include <torch/library.h>
#include <torch/version.h>
// Note on op signatures:
// The X_meta signatures are for the meta functions corresponding to op X.
// They must be kept in sync with the signature for X. Generally, only
// functions that return Tensors require a meta function.
//
// See the following links for detailed docs on op registration and function
// schemas.
// https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations
// Core op registrations for the TORCH_EXTENSION_NAME library.
//
// Every op is declared with `ops.def(<schema string>)`. When the kernel is
// always built into this extension it is bound here with
// `ops.impl(<name>, <dispatch key>, <function>)`; ops whose kernels are only
// compiled for some targets register their impls in their own source files
// (marked "conditionally compiled" below).
//
// Schema annotations (see the ATen README linked above): `Tensor!` marks an
// argument the kernel writes in place, `?` marks an optional argument, and
// `Tensor!?` is an optional tensor written in place when provided. Mutation
// annotations belong on Tensor arguments only; scalars are passed by value.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // vLLM custom ops
  //
  ops.def(
      "persistent_masked_m_silu_mul_quant(Tensor input, Tensor counts, Tensor! "
      "y_q, Tensor! y_s,"
      "bool use_ue8m0) -> ()");
  ops.impl("persistent_masked_m_silu_mul_quant", torch::kCUDA,
           &persistent_masked_m_silu_mul_quant);

  ops.def("weak_ref_tensor(Tensor input) -> Tensor");
  ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);

  // Bound to the CPU dispatch key because its input is a CPU tensor; per the
  // op name it returns a CUDA view of that memory (see the kernel source).
  ops.def("get_cuda_view_from_cpu_tensor(Tensor cpu_tensor) -> Tensor");
  ops.impl("get_cuda_view_from_cpu_tensor", torch::kCPU,
           &get_cuda_view_from_cpu_tensor);

  // Attention ops
  // Compute the attention between an input query and the cached
  // keys/values using PagedAttention.
  ops.def(
      "paged_attention_v1("
      " Tensor! out, Tensor query, Tensor key_cache,"
      " Tensor value_cache, int num_kv_heads, float scale,"
      " Tensor block_tables, Tensor seq_lens, int block_size,"
      " int max_seq_len, Tensor? alibi_slopes,"
      " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
      " int tp_rank, int blocksparse_local_blocks,"
      " int blocksparse_vert_stride, int blocksparse_block_size,"
      " int blocksparse_head_sliding_step) -> ()");
  ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1);

  // PagedAttention V2.
  ops.def(
      "paged_attention_v2("
      " Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
      " Tensor! tmp_out, Tensor query, Tensor key_cache,"
      " Tensor value_cache, int num_kv_heads, float scale,"
      " Tensor block_tables, Tensor seq_lens, int block_size,"
      " int max_seq_len, Tensor? alibi_slopes,"
      " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
      " int tp_rank, int blocksparse_local_blocks,"
      " int blocksparse_vert_stride, int blocksparse_block_size,"
      " int blocksparse_head_sliding_step) -> ()");
  ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);

  // Merge attn states
  // Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
  // can be used to combine partial attention results (in the split-KV case)
  //
  // `prefill_tokens_with_context` is an optional *scalar*, so it is declared
  // `int?` rather than `int!?`: an int is passed by value and cannot be
  // written back, and a mutation annotation on a non-Tensor argument prevents
  // torch.compile from auto-functionalizing this op.
  ops.def(
      "merge_attn_states("
      " Tensor! output,"
      " Tensor!? output_lse,"
      " Tensor prefix_output,"
      " Tensor prefix_lse,"
      " Tensor suffix_output,"
      " Tensor suffix_lse,"
      " int? prefill_tokens_with_context,"
      " Tensor? output_scale=None) -> ()");
  ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);

#ifndef USE_ROCM
  // Schema argument names must be unique. The second sequence-length tensor
  // was previously also named `q_seqlens`; it is now `kv_seqlens` (presumably
  // the key/value lengths - TODO confirm against the kernel signature).
  // Positional callers are unaffected by the rename.
  ops.def(
      "convert_vertical_slash_indexes("
      " Tensor! block_count, Tensor! block_offset, "
      " Tensor! column_count, Tensor! column_index, "
      " Tensor q_seqlens, Tensor kv_seqlens, "
      " Tensor vertical_indexes, Tensor slash_indexes, "
      " int context_size, int block_size_M, int block_size_N, "
      " bool causal) -> ()");
  ops.impl("convert_vertical_slash_indexes", torch::kCUDA,
           &convert_vertical_slash_indexes);

  ops.def(
      "convert_vertical_slash_indexes_mergehead("
      " Tensor! block_count, Tensor! block_offset, "
      " Tensor! column_count, Tensor! column_index, "
      " Tensor q_seqlens, Tensor kv_seqlens, "
      " Tensor vertical_indexes, Tensor slash_indexes, "
      " Tensor vertical_indices_count, Tensor slash_indices_count, "
      " int context_size, int block_size_M, int block_size_N, "
      " bool causal) -> ()");
  ops.impl("convert_vertical_slash_indexes_mergehead", torch::kCUDA,
           &convert_vertical_slash_indexes_mergehead);
#endif

  // Activation ops
  // Activation function used in SwiGLU.
  ops.def("silu_and_mul(Tensor! result, Tensor input) -> ()");
  ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);

  // SwiGLU activation with input clamping.
  ops.def(
      "silu_and_mul_with_clamp(Tensor! result, Tensor input, float limit) "
      "-> ()");
  ops.impl("silu_and_mul_with_clamp", torch::kCUDA, &silu_and_mul_clamp);

  ops.def(
      "silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
  ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);

  // Fused SiLU+Mul + per-block quantization
  ops.def(
      "silu_and_mul_per_block_quant("
      "Tensor! out, "
      "Tensor input, "
      "Tensor! scales, "
      "int group_size, "
      "Tensor? scale_ub=None, "
      "bool is_scale_transposed=False) -> ()");
  ops.impl("silu_and_mul_per_block_quant", torch::kCUDA,
           &silu_and_mul_per_block_quant);

  ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
  ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu);

  // Activation function used in GeGLU with `none` approximation.
  ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);

  // Activation function used in GeGLU with `tanh` approximation.
  ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);

  // FATReLU implementation.
  ops.def("fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()");
  ops.impl("fatrelu_and_mul", torch::kCUDA, &fatrelu_and_mul);

  ops.def(
      "swigluoai_and_mul(Tensor! out, Tensor input, float alpha=1.702, float "
      "limit=7.0) "
      "-> ()");
  ops.impl("swigluoai_and_mul", torch::kCUDA, &swigluoai_and_mul);

  // GELU implementation used in GPT-2.
  ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_new", torch::kCUDA, &gelu_new);

  // Approximate GELU implementation.
  ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_fast", torch::kCUDA, &gelu_fast);

  // Quick GELU implementation.
  ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);

  // Layernorm
  // Apply Root Mean Square (RMS) Normalization to the input tensor.
  ops.def(
      "rms_norm(Tensor! result, Tensor input, Tensor weight, float epsilon) -> "
      "()");
  ops.impl("rms_norm", torch::kCUDA, &rms_norm);

  // In-place fused Add and RMS Normalization.
  ops.def(
      "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, "
      "float epsilon) -> ()");
  ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);

  // Function for fused QK Norm and RoPE
  ops.def(
      "fused_qk_norm_rope(Tensor! qkv, int num_heads_q, "
      "int num_heads_k, int num_heads_v, int head_dim, float eps, "
      "Tensor q_weight, Tensor k_weight, Tensor cos_sin_cache, "
      "bool is_neox, Tensor position_ids, "
      "int forced_token_heads_per_warp=-1) -> ()");
  ops.impl("fused_qk_norm_rope", torch::kCUDA, &fused_qk_norm_rope);

  // Horizontally-fused DeepseekV4-MLA: per-head RMSNorm + GPT-J RoPE for Q, and
  // GPT-J RoPE + UE8M0 FP8 quant + paged cache insert for KV, all in one
  // kernel launch.
  ops.def(
      "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert("
      "Tensor! q, Tensor kv, Tensor! k_cache, "
      "Tensor slot_mapping, Tensor position_ids, Tensor cos_sin_cache, "
      "float eps, int cache_block_size) -> ()");
  ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert", torch::kCUDA,
           &fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert);

  // Apply repetition penalties to logits in-place
  ops.def(
      "apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "
      "Tensor output_mask, Tensor repetition_penalties) -> ()");
  ops.impl("apply_repetition_penalties_", torch::kCUDA,
           &apply_repetition_penalties_);

  // Optimized top-k per row operation
  ops.def(
      "top_k_per_row_prefill(Tensor logits, Tensor rowStarts, Tensor rowEnds, "
      "Tensor! indices, int numRows, int stride0, "
      "int stride1, int topK) -> ()");
  ops.impl("top_k_per_row_prefill", torch::kCUDA, &top_k_per_row_prefill);

  ops.def(
      "top_k_per_row_decode(Tensor logits, int next_n, "
      "Tensor seq_lens, Tensor! indices, "
      "int numRows, int stride0, int stride1, int topK) -> ()");
  ops.impl("top_k_per_row_decode", torch::kCUDA, &top_k_per_row_decode);

  ops.def(
      "persistent_topk(Tensor logits, Tensor lengths, Tensor! output, "
      "Tensor workspace, int k, int max_seq_len) -> ()");
  ops.impl("persistent_topk", torch::kCUDA, &persistent_topk);

  // Layernorm-quant
  // RMS Normalization of the input tensor, with the normalized values
  // quantized using the caller-provided static `scale` and written to
  // `result` (FP8 per the op name).
  ops.def(
      "rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor weight, "
      "Tensor scale, float epsilon) -> "
      "()");
  ops.impl("rms_norm_static_fp8_quant", torch::kCUDA,
           &rms_norm_static_fp8_quant);

  // Fused Add + RMS Normalization followed by static-scale quantization:
  // writes the quantized output to `result` and updates `residual` in place.
  ops.def(
      "fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, "
      "Tensor! residual, Tensor weight, "
      "Tensor scale, float epsilon) -> ()");
  ops.impl("fused_add_rms_norm_static_fp8_quant", torch::kCUDA,
           &fused_add_rms_norm_static_fp8_quant);

  // Fused Layernorm + Quant kernels
  ops.def(
      "rms_norm_dynamic_per_token_quant(Tensor! result, Tensor input, "
      "Tensor weight, Tensor! scale, float epsilon, "
      "Tensor? scale_ub, Tensor!? residual) -> ()");
  ops.impl("rms_norm_dynamic_per_token_quant", torch::kCUDA,
           &rms_norm_dynamic_per_token_quant);

  // Fused Layernorm + Block quant kernels
  ops.def(
      "rms_norm_per_block_quant(Tensor! result, Tensor input, "
      "Tensor weight, Tensor! scale, float epsilon, "
      "Tensor? scale_ub, Tensor!? residual, int group_size, "
      "bool is_scale_transposed) -> ()");
  ops.impl("rms_norm_per_block_quant", torch::kCUDA, &rms_norm_per_block_quant);

  // Rotary embedding
  // Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
  ops.def(
      "rotary_embedding(Tensor positions, Tensor! query,"
      " Tensor!? key, int head_size,"
      " Tensor cos_sin_cache, bool is_neox, int "
      "rope_dim_offset=0, bool inverse=False) -> ()");
  ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);

  // Quantization ops
#ifndef USE_ROCM
  // DeepSeek V3 fused A GEMM (SM 9.0+, bf16 only, 1-16 tokens).
  ops.def(
      "dsv3_fused_a_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
  // conditionally compiled so impl registration is in source file

  // Quantized GEMM for AWQ.
  ops.def(
      "awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, "
      "Tensor _zeros, SymInt split_k_iters) -> Tensor");
  ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);

  // Dequantization for AWQ.
  ops.def(
      "awq_dequantize(Tensor _kernel, Tensor _scaling_factors, "
      "Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor");
  ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);

  // Note about marlin kernel 'workspace' arguments:
  // Technically these should be mutable since they are modified by the kernel.
  // But since they are set back to zero once the kernel is finished we can
  // hand wave and say that they have no net effect.
  //
  // The reason to mark 'workspace' as immutable is so that they don't interfere
  // with using ScalarType arguments in the ops. If they are marked as mutable,
  // pytorch throws an assert in
  // 'torch._higher_order_ops._register_effectful_op' that prevents these
  // kernels from being torch.compile'd.
  // See the following document for more info on custom types and ops that use
  // custom types:
  // https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA

  // Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
  ops.def(
      "machete_supported_schedules("
      " ScalarType a_type,"
      " int b_type,"
      " ScalarType? maybe_group_scales_type,"
      " ScalarType? maybe_group_zeros_type,"
      " ScalarType? maybe_channel_scales_type,"
      " ScalarType? maybe_token_scales_type,"
      " ScalarType? maybe_out_type"
      ") -> str[]");
  ops.def(
      "machete_mm("
      " Tensor A,"
      " Tensor B,"
      " int b_type,"
      " ScalarType? out_type,"
      " Tensor? group_scales,"
      " Tensor? group_zeros,"
      " int? group_size,"
      " Tensor? channel_scales,"
      " Tensor? token_scales,"
      " str? schedule"
      ") -> Tensor");
  ops.def(
      "machete_prepack_B("
      " Tensor B,"
      " ScalarType a_type,"
      " int b_type,"
      " ScalarType? group_scales_type"
      ") -> Tensor");
  // conditionally compiled so impl registration is in source file

  // Marlin Optimized Quantized GEMM (supports GPTQ, AWQ, FP8, NVFP4, MXFP4).
  ops.def(
      "marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, "
      "Tensor? b_bias_or_none,Tensor b_scales, "
      "Tensor? a_scales, Tensor? global_scale, Tensor? b_zeros_or_none, "
      "Tensor? "
      "g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_type_id, "
      "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
      "bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
  // conditionally compiled so impl registration is in source file

  // gptq_marlin repack from GPTQ.
  ops.def(
      "gptq_marlin_repack(Tensor b_q_weight, Tensor perm, "
      "SymInt size_k, SymInt size_n, int num_bits, bool is_a_8bit) -> Tensor");
  // conditionally compiled so impl registrations are in source file

  // awq_marlin repack from AWQ.
  ops.def(
      "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
      "SymInt size_n, int num_bits, bool is_a_8bit) -> Tensor");
  // conditionally compiled so impl registrations are in source file

  // preprocess W-int4A-fp8 weight for marlin kernel
  ops.def(
      "marlin_int4_fp8_preprocess(Tensor qweight, "
      "Tensor? qzeros_or_none, bool inplace) -> Tensor");
  // conditionally compiled so impl registrations are in source file
#endif

  // Dequantization for GGML.
  ops.def(
      "ggml_dequantize(Tensor W, int type, SymInt m, SymInt n, ScalarType? "
      "dtype) -> Tensor");
  ops.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);

  // mmvq kernel for GGML.
  ops.def(
      "ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, SymInt row) "
      "-> Tensor");
  ops.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8);

  // mmq kernel for GGML.
  ops.def(
      "ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor");
  ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);

  // moe kernel for GGML.
  ops.def(
      "ggml_moe_a8(Tensor X, Tensor W, "
      "Tensor sorted_token_ids, Tensor expert_ids, Tensor "
      "num_tokens_post_padded, "
      "int type, SymInt row, SymInt top_k, SymInt tokens) -> Tensor");
  ops.impl("ggml_moe_a8", torch::kCUDA, &ggml_moe_a8);

  ops.def(
      "ggml_moe_a8_vec(Tensor X, Tensor W, "
      "Tensor topk_ids, int top_k, "
      "int type, SymInt row, SymInt tokens) -> Tensor");
  ops.impl("ggml_moe_a8_vec", torch::kCUDA, &ggml_moe_a8_vec);

  // No schema string: `def(name, fn)` infers the schema from the C++
  // signature of ggml_moe_get_block_size.
  ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size);

#ifndef USE_ROCM
  // Expert-specialization mxfp8 blockscaled grouped quantization (SM100+).
  ops.def(
      "mxfp8_experts_quant("
      " Tensor input, Tensor problem_sizes, Tensor expert_offsets,"
      " Tensor blockscale_offsets, Tensor! quant_output, Tensor! scale_factor)"
      " -> ()");
  // conditionally compiled so impl registration is in source file

  // Expert-specialization mxfp8 blockscaled grouped GEMM (SM100+).
  ops.def(
      "cutlass_mxfp8_grouped_mm("
      " Tensor a, Tensor b, Tensor sfa, Tensor sfb, Tensor! out,"
      " Tensor problem_sizes, Tensor expert_offsets, Tensor blockscale_offsets)"
      " -> ()");
  // conditionally compiled so impl registration is in source file

  // SM100 CUTLASS MLA decode
  ops.def(
      "sm100_cutlass_mla_decode(Tensor! out, Tensor! lse, Tensor q_nope,"
      " Tensor q_pe, Tensor kv_c_and_k_pe_cache,"
      " Tensor seq_lens, Tensor page_table,"
      " Tensor workspace, float scale,"
      " int num_kv_splits) -> ()");
  // conditionally compiled so impl in source file

  // SM100 CUTLASS MLA workspace
  ops.def(
      "sm100_cutlass_mla_get_workspace_size(int max_seq_len, int num_batches,"
      " int sm_count, int num_kv_splits) "
      "-> int");
  // conditionally compiled so impl in source file
#endif

  // Quantized GEMM for GPTQ.
  // Note: even though the C++ inferred schema is correct for this op, it seems
  // to prevent the meta function registry.
  ops.def(
      "gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, "
      "Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, bool "
      "use_v2_format, int bit) "
      "-> Tensor");
  ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);

  // Post processing for GPTQ.
  ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()");
  ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);

  // Compute FP8 quantized tensor for given scaling factor.
  // Supports per-tensor, per-channel, per-token, and arbitrary 2D group
  // scaling. Optional group_m/group_n specify the group shape explicitly;
  // required for 1D scales to disambiguate per-channel vs per-token.
  ops.def(
      "static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale, "
      "(int, int)? group_shape=None) -> ()");
  ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);

  // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
  ops.def(
      "dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
      "-> "
      "()");
  ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);

  // Compute dynamic-per-token FP8 quantized tensor and scaling factor.
  ops.def(
      "dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
      "Tensor! scale, Tensor? scale_ub) -> "
      "()");
  ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
           &dynamic_per_token_scaled_fp8_quant);

  // Compute int8 quantized tensor for given scaling factor.
  ops.def(
      "static_scaled_int8_quant(Tensor! result, Tensor input, Tensor scale,"
      "Tensor? azp) -> ()");
  ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant);

  // Compute int8 quantized tensor and scaling factor
  ops.def(
      "dynamic_scaled_int8_quant(Tensor! result, Tensor input, Tensor! scale, "
      "Tensor!? azp) -> ()");
  ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
           &dynamic_scaled_int8_quant);

  // Mamba selective scan kernel
  ops.def(
      "selective_scan_fwd(Tensor! u, Tensor! delta,"
      "Tensor! A, Tensor! B, Tensor! C,"
      "Tensor? D_, Tensor!? z_, Tensor? delta_bias_,"
      "bool delta_softplus,"
      "Tensor? query_start_loc,"
      "Tensor? cache_indices,"
      "Tensor? has_initial_state,"
      "Tensor! ssm_states,"
      "int null_block_id,"
      "int block_size,"
      "Tensor? block_idx_first_scheduled_token,"
      "Tensor? block_idx_last_scheduled_token,"
      "Tensor? initial_state_idx,"
      "Tensor? cu_chunk_seqlen,"
      "Tensor? last_chunk_indices) -> ()");
  ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);

  // Hadamard transforms (no impl is registered for this op in this file).
  ops.def("hadacore_transform(Tensor! x, bool inplace) -> Tensor");

#ifndef USE_ROCM
  // reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
  ops.def(
      "rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, "
      "Tensor? b_zeros, "
      "bool has_zp, Tensor! b_qweight_reorder, Tensor! b_scales_reorder, "
      "Tensor!? b_zeros_reorder, "
      "int K, int N, int N_32align) -> ()");
  // conditionally compiled so impl in source file

  // AllSpark quantization ops
  ops.def(
      "allspark_w8a16_gemm(Tensor a, Tensor b_qweight, Tensor b_scales, "
      "Tensor? b_qzeros, "
      "SymInt n, SymInt group_size, SymInt sm_count, SymInt sm_version, SymInt "
      "CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) -> Tensor");
  // conditionally compiled so impl in source file

  // MiniMax all-reduce + RMSNorm ops (per the op names); unlike the AllSpark
  // ops above, their CUDA impls are bound directly here.
  ops.def(
      "minimax_allreduce_rms("
      "Tensor input,"
      "Tensor norm_weight,"
      "Tensor workspace,"
      "int rank,"
      "int nranks,"
      "float eps) -> Tensor");
  ops.impl("minimax_allreduce_rms", torch::kCUDA, &minimax_allreduce_rms);

  ops.def(
      "minimax_allreduce_rms_qk("
      "Tensor qkv,"
      "Tensor norm_weight_q,"
      "Tensor norm_weight_k,"
      "Tensor workspace,"
      "int q_size,"
      "int kv_size,"
      "int rank,"
      "int nranks,"
      "float eps) -> (Tensor, Tensor)");
  ops.impl("minimax_allreduce_rms_qk", torch::kCUDA, &minimax_allreduce_rms_qk);
#endif
}
// Registers the KV-cache ops under the library named
// CONCAT(TORCH_EXTENSION_NAME, _cache_ops). Each op is declared with a schema
// string (`Tensor!` = written in place, `?` = optional) and bound to a kernel
// for exactly one dispatch key.
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
// Cache ops
// Swap in (out) the cache blocks from src to dst.
cache_ops.def(
"swap_blocks(Tensor src, Tensor! dst,"
" int block_size_in_bytes, Tensor block_mapping) -> ()");
cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);
// Batch swap: submit all block copies in a single driver call.
// Unlike the other ops in this library it is bound to the CPU dispatch key;
// its inputs are pointer/size tensors, presumably host-resident - TODO confirm.
cache_ops.def(
"swap_blocks_batch(Tensor src_ptrs, Tensor dst_ptrs,"
" Tensor sizes,"
" bool is_src_access_order_any=False) -> ()");
cache_ops.impl("swap_blocks_batch", torch::kCPU, &swap_blocks_batch);
// Reshape the key and value tensors and cache them.
cache_ops.def(
"reshape_and_cache(Tensor key, Tensor value,"
" Tensor! key_cache, Tensor! value_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" Tensor k_scale, Tensor v_scale) -> ()");
cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);
// Reshape the key and value tensors and cache them. Same schema as
// reshape_and_cache; the `_flash` variant presumably targets a different
// cache layout - see the kernel source.
cache_ops.def(
"reshape_and_cache_flash(Tensor key, Tensor value,"
" Tensor! key_cache,"
" Tensor! value_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" Tensor k_scale, Tensor v_scale) -> ()");
cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
&reshape_and_cache_flash);
// Concat kv_c and k_pe and cache them.
cache_ops.def(
"concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
" Tensor! kv_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" Tensor scale) -> ()");
cache_ops.impl("concat_and_cache_mla", torch::kCUDA, &concat_and_cache_mla);
// Rotate Q and K, then write to kv cache for MLA.
// q_pe, k_pe and kv_cache are all written in place (`Tensor!`).
cache_ops.def(
"concat_and_cache_mla_rope_fused("
" Tensor positions,"
" Tensor! q_pe,"
" Tensor! k_pe,"
" Tensor kv_c,"
" Tensor cos_sin_cache,"
" bool is_neox,"
" Tensor slot_mapping,"
" Tensor! kv_cache,"
" str kv_cache_dtype,"
" Tensor kv_cache_scale) -> ()");
cache_ops.impl("concat_and_cache_mla_rope_fused", torch::kCUDA,
&concat_and_cache_mla_rope_fused);
// Convert the key and value cache to fp8 data type.
cache_ops.def(
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
"str kv_cache_dtype) -> ()");
cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
// Gather cache blocks from src_cache to dst, dequantizing from
// src_cache's dtype to dst's dtype if necessary.
cache_ops.def(
"gather_and_maybe_dequant_cache(Tensor src_cache, Tensor! dst, "
" Tensor block_table, Tensor cu_seq_lens, "
" Tensor token_to_seq, "
" int num_tokens, "
" str kv_cache_dtype, "
" Tensor scale, Tensor? seq_starts) -> ()");
cache_ops.impl("gather_and_maybe_dequant_cache", torch::kCUDA,
&gather_and_maybe_dequant_cache);
// Gather from src_cache into dst (the only mutable argument), indexed by
// block_table / cu_seq_lens; no dtype or scale argument, unlike the op above.
cache_ops.def(
"cp_gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, "
"Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache);
// Gather from src_cache into dst; per the op name the source is FP8 and is
// converted to dst's wider dtype - TODO confirm dtypes in the kernel.
cache_ops.def(
"cp_gather_and_upconvert_fp8_kv_cache(Tensor src_cache, Tensor! dst, "
"Tensor block_table, Tensor seq_lens, Tensor workspace_starts, int "
"batch_size) -> ()");
cache_ops.impl("cp_gather_and_upconvert_fp8_kv_cache", torch::kCUDA,
&cp_gather_and_upconvert_fp8_kv_cache);
// Writes k into kv_cache (mutable) at slot_mapping positions; per the
// argument names it quantizes in blocks of quant_block_size.
cache_ops.def(
"indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache, Tensor "
"slot_mapping, "
"int quant_block_size, str kv_cache_dtype) -> ()");
cache_ops.impl("indexer_k_quant_and_cache", torch::kCUDA,
&indexer_k_quant_and_cache);
// Writes ql_nope and q_pe into the preallocated q_out (concatenation per the
// op name).
cache_ops.def(
"concat_mla_q(Tensor ql_nope, Tensor q_pe, Tensor! q_out) -> ()");
cache_ops.impl("concat_mla_q", torch::kCUDA, &concat_mla_q);
// Gathers from kv_cache into two outputs, dst_k and dst_scale, indexed by
// block_table / cu_seq_lens.
cache_ops.def(
"cp_gather_indexer_k_quant_cache(Tensor kv_cache, Tensor! dst_k, Tensor! "
"dst_scale, Tensor block_table, Tensor cu_seq_lens) -> ()");
cache_ops.impl("cp_gather_indexer_k_quant_cache", torch::kCUDA,
&cp_gather_indexer_k_quant_cache);
}
// Registers device-query helpers under the library named
// CONCAT(TORCH_EXTENSION_NAME, _cuda_utils). Both ops take and return plain
// ints; neither has a tensor argument.
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
// Cuda utils
// Gets the specified device attribute.
// `impl` is called without a dispatch key here (unlike the CUDA-keyed ops
// elsewhere in this file), so the kernel is registered as a catch-all.
cuda_utils.def("get_device_attribute(int attribute, int device_id) -> int");
cuda_utils.impl("get_device_attribute", &get_device_attribute);
// Gets the maximum shared memory per block device attribute.
cuda_utils.def(
"get_max_shared_memory_per_block_device_attribute(int device_id) -> int");
cuda_utils.impl("get_max_shared_memory_per_block_device_attribute",
&get_max_shared_memory_per_block_device_attribute);
}
// Custom all-reduce ops, registered under the library named
// CONCAT(TORCH_EXTENSION_NAME, _custom_ar).
//
// Ops declared with an explicit schema string are bound to a dispatch key via
// `impl`; the `def(name, &fn)` registrations let PyTorch infer the schema from
// the C++ function signature.
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
  // Custom all-reduce kernels
  custom_ar.def(
      "init_custom_ar(int[] ipc_tensors, Tensor rank_data, "
      "int rank, bool fully_connected) -> int");
  custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);

  custom_ar.def(
      "all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
      "int reg_buffer_sz_bytes) -> ()");
  custom_ar.impl("all_reduce", torch::kCUDA, &all_reduce);

  // Schemas for the following ops are inferred from their C++ signatures.
  custom_ar.def("dispose", &dispose);
  custom_ar.def("meta_size", &meta_size);
  custom_ar.def("register_buffer", &register_buffer);
  custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
  custom_ar.def("register_graph_buffers", &register_graph_buffers);
  custom_ar.def("allocate_shared_buffer_and_handle",
                &allocate_shared_buffer_and_handle);

  // Declared with an explicit schema plus a kernel, and additionally bound to
  // the CPU dispatch key.
  custom_ar.def("open_mem_handle(Tensor mem_handle) -> int", &open_mem_handle);
  custom_ar.impl("open_mem_handle", torch::kCPU, &open_mem_handle);
  custom_ar.def("free_shared_buffer", &free_shared_buffer);

#ifdef USE_ROCM
  // Quick Reduce all-reduce kernels
  // The op returns nothing, so its result lands in `out`; like `all_reduce`
  // above, `out` is therefore annotated `Tensor!` so torch.compile /
  // functionalization sees the in-place write.
  custom_ar.def(
      "qr_all_reduce(int fa, Tensor inp, Tensor! out, int quant_level, bool "
      "cast_bf2half) -> ()");
  custom_ar.impl("qr_all_reduce", torch::kCUDA, &qr_all_reduce);
  custom_ar.def("init_custom_qr", &init_custom_qr);
  custom_ar.def("qr_destroy", &qr_destroy);
  custom_ar.def("qr_get_handle", &qr_get_handle);
  custom_ar.def("qr_open_handles(int _fa, Tensor[](b!) handles) -> ()");
  custom_ar.impl("qr_open_handles", torch::kCPU, &qr_open_handles);
  // Max input size in bytes
  custom_ar.def("qr_max_size", &qr_max_size);
#endif
}
// Extension registration macro from core/registration.h (not visible here);
// presumably makes this shared library importable as a Python module - see
// that header for the exact behavior.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)