|
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"

// cutlass::device_memory::allocation (used for the GEMM workspace below) is
// declared in device_memory.h; include it directly rather than relying on
// transitive includes.
#include "cutlass/util/device_memory.h"
#include "cutlass/util/mixed_dtype_utils.hpp"
#include "cutlass/util/packed_stride.hpp"

#include "cutlass_extensions/include/kernel_mode.h"

namespace fbgemm_gpu {

#if CUDART_VERSION >= 12000

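// Each instantiation of the kernel is parameterized by the CTA tile shape
// (TB_M x TB_N), the cluster shape (TBS_M x TBS_N x TBS_K), and COOP, which
// selects the cooperative (two consumer warp group) schedule. These map
// directly onto TileShape, ClusterShape, and the kernel/epilogue schedules
// chosen below.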
template <int TB_M, int TB_N, int TBS_M, int TBS_N, int TBS_K, bool COOP>
at::Tensor _f8i4bf16_shuffled(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale) {
  // Get shape information from input tensors.
  int M = XQ.size(0);
  int K = XQ.size(1);
  int N = WQ.size(0);
  // Make sure w_scale is in proper format.
  TORCH_CHECK(
      w_scale.size(1) == 8,
      "Weights and scales must be prepacked with preshuffle_i4.");
  int num_groups = w_scale.size(0);
  int group_size = K / num_groups;
  // Allocate output.
  at::Tensor Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16));

  // Define input types.
  using MmaType = cutlass::float_e4m3_t;
  using QuantType = cutlass::int4b_t;
  constexpr int TileShapeK = 128 * 8 / cute::sizeof_bits<MmaType>::value;

  // A Matrix configuration.
  using ElementA = MmaType;
  using LayoutA = cutlass::layout::RowMajor;
  constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

  // B Matrix Configuration.
  using ElementB = QuantType;
  using LayoutB = cutlass::layout::ColumnMajor;
  constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

  // Swap A and B and solve the transposed problem (Y^T = W @ X^T) so that the
  // quantized int4 weights land in the operand slot used by the mixed-dtype
  // mainloop; the output layouts are transposed below to match.
  using LayoutA_Transpose =
      typename cutlass::layout::LayoutTranspose<LayoutA>::type;
  using LayoutB_Transpose =
      typename cutlass::layout::LayoutTranspose<LayoutB>::type;

  using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
  using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;

  // Define layout for shuffled weight tensor.
  using LayoutAtomQuant =
      decltype(cutlass::compute_memory_reordering_atom<MmaType>());
  using LayoutB_Reordered = decltype(cute::tile_to_shape(
      LayoutAtomQuant{}, cute::Layout<cute::Shape<int, int, int>, StrideB>{}));

  using ElementScale = MmaType;

  // Output Matrix configuration.
  using ElementC = cutlass::bfloat16_t;
  using LayoutC = cutlass::layout::RowMajor;
  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;

  // Core kernel configurations
  using ElementAccumulator = float;
  using ElementCompute = float;
  using ArchTag = cutlass::arch::Sm90;
  using OperatorClass = cutlass::arch::OpClassTensorOp;
  // TODO tune these shapes.
  using TileShape =
      cute::Shape<cute::Int<TB_M>, cute::Int<TB_N>, cute::Int<TileShapeK>>;
  using ClusterShape =
      cute::Shape<cute::Int<TBS_M>, cute::Int<TBS_N>, cute::Int<TBS_K>>;
  // TODO Should we use fast accum here?
  using KernelSchedule = cute::conditional_t<
      COOP,
      cutlass::gemm::KernelTmaWarpSpecializedCooperative,
      cutlass::gemm::KernelTmaWarpSpecialized>;
  // Might be the only epilogue schedule that supports swap + transpose.
  using EpilogueSchedule = cute::conditional_t<
      COOP,
      cutlass::epilogue::TmaWarpSpecializedCooperative,
      cutlass::epilogue::TmaWarpSpecialized>;
  using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;

  // Define EVT for rowwise scaling.
  using XScale = cutlass::epilogue::fusion::Sm90RowBroadcast<
      0,
      TileShape,
      ElementAccumulator,
      ElementAccumulator,
      cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;

  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies,
      ElementC, // First stage output type.
      ElementAccumulator, // First stage input types.
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EpilogueEVT =
      cutlass::epilogue::fusion::Sm90EVT<Compute0, XScale, Accum>;
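
  // In the epilogue, the accumulator is multiplied by the per-row activation
  // scale (x_scale); dequantization of the int4 weights with w_scale happens
  // group-wise inside the mainloop (see group_size in the arguments), so no
  // weight scaling is needed here.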

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm90,
          cutlass::arch::OpClassTensorOp,
          TileShape,
          ClusterShape,
          EpilogueTileType,
          ElementAccumulator,
          ElementAccumulator,
          ElementC,
          typename cutlass::layout::LayoutTranspose<LayoutC>::type,
          AlignmentC,
          ElementC,
          typename cutlass::layout::LayoutTranspose<LayoutC>::type,
          AlignmentC,
          EpilogueSchedule,
          EpilogueEVT>::CollectiveOp;

  using CollectiveMainloopShuffled =
      typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag,
          OperatorClass,
          cute::tuple<ElementB, cutlass::Array<ElementScale, 8>>,
          LayoutB_Reordered,
          AlignmentB,
          ElementA,
          LayoutA_Transpose,
          AlignmentA,
          ElementAccumulator,
          TileShape,
          ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          KernelSchedule>::CollectiveOp;

  using GemmKernelShuffled = cutlass::gemm::kernel::GemmUniversal<
      cute::Shape<int, int, int, int>,
      CollectiveMainloopShuffled,
      CollectiveEpilogue>;

  using GemmShuffled =
      cutlass::gemm::device::GemmUniversalAdapter<GemmKernelShuffled>;

  using StrideC = typename GemmKernelShuffled::StrideC;

  /// Initialization
  auto shape_B = cute::make_shape(N, K, 1);
  StrideA stride_A =
      cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, 1));
  StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B);
  StrideC stride_C =
      cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(N, M, 1));
  LayoutB_Reordered layout_B_reordered =
      cute::tile_to_shape(LayoutAtomQuant{}, shape_B);
  using StrideS = typename CollectiveMainloopShuffled::StrideScale;
  StrideS stride_S = cutlass::make_cute_packed_stride(
      StrideS{}, cute::make_shape(N, num_groups, 1));

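  // Note that the problem is solved in swapped/transposed form ({N, M, K}),
  // so stride_C describes an N x M column-major matrix, which aliases the
  // row-major M x N output tensor Y without an extra transpose.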
  // Define Gemm arguments.
  typename GemmShuffled::Arguments arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
      {N, M, K, 1},
      {reinterpret_cast<ElementB*>(WQ.data_ptr()),
       layout_B_reordered,
       reinterpret_cast<ElementA*>(XQ.data_ptr()),
       stride_A,
       reinterpret_cast<cutlass::Array<ElementScale, 8>*>(w_scale.data_ptr()),
       stride_S,
       group_size},
      {{},
       reinterpret_cast<ElementC*>(Y.data_ptr()),
       stride_C,
       reinterpret_cast<ElementC*>(Y.data_ptr()),
       stride_C}};

  arguments.epilogue.thread = {
      {reinterpret_cast<ElementAccumulator*>(x_scale.data_ptr())}, // x_scale
      {}, // Accumulator
      {}, // Multiplies
  };

  // Instantiate the GEMM and launch the workload.
  GemmShuffled gemm;

  // Using the arguments, query for the extra workspace required for the
  // matrix multiplication computation.
  size_t workspace_size = GemmShuffled::get_workspace_size(arguments);

  // Allocate workspace memory.
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  // Check whether the problem size is supported.
  cutlass::Status status = gemm.can_implement(arguments);
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot implement");
  }

  // Initialize CUTLASS kernel with arguments and workspace pointer.
  status = gemm.initialize(arguments, workspace.get());
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot initialize");
  }

  status = gemm(at::cuda::getCurrentCUDAStream());

  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error(
        std::string("cutlass cannot run: ") +
        cutlass::cutlassGetStatusString(status));
  }
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return Y;
}

at::Tensor f8i4bf16_shuffled(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale) {
  int M = XQ.size(0);
  int K = XQ.size(1);
  int N = WQ.size(0);
  // Use shape heuristics to dispatch to an optimized kernel configuration.
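  // Small M favors narrow N tiles with the non-cooperative schedule; large M
  // and N move to 128x256 tiles with the cooperative schedule.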
  if (M <= 16) {
    return _f8i4bf16_shuffled<64, 16, 2, 1, 1, false>(XQ, WQ, x_scale, w_scale);
  } else if (M <= 32) {
    return _f8i4bf16_shuffled<64, 32, 2, 1, 1, false>(XQ, WQ, x_scale, w_scale);
  } else if (M <= 64) {
    return _f8i4bf16_shuffled<64, 64, 2, 1, 1, false>(XQ, WQ, x_scale, w_scale);
  } else if (M <= 128) {
    return _f8i4bf16_shuffled<64, 128, 2, 1, 1, false>(
        XQ, WQ, x_scale, w_scale);
  } else if (M <= 256) {
    if (N <= 4096) {
      return _f8i4bf16_shuffled<64, 128, 2, 1, 1, false>(
          XQ, WQ, x_scale, w_scale);
    } else {
      return _f8i4bf16_shuffled<64, 256, 1, 1, 1, false>(
          XQ, WQ, x_scale, w_scale);
    }
  } else if (M <= 512) {
    if (N <= 4096) {
      return _f8i4bf16_shuffled<64, 256, 2, 1, 1, false>(
          XQ, WQ, x_scale, w_scale);
    } else {
      return _f8i4bf16_shuffled<128, 256, 2, 1, 1, true>(
          XQ, WQ, x_scale, w_scale);
    }
  } else if (M <= 1024) {
    if (N <= 1024) {
      return _f8i4bf16_shuffled<64, 128, 2, 1, 1, false>(
          XQ, WQ, x_scale, w_scale);
    } else if (N <= 2048) {
      return _f8i4bf16_shuffled<64, 256, 2, 1, 1, false>(
          XQ, WQ, x_scale, w_scale);
    } else {
      return _f8i4bf16_shuffled<128, 256, 2, 1, 1, true>(
          XQ, WQ, x_scale, w_scale);
    }
  } else {
    if (N <= 1024) {
      return _f8i4bf16_shuffled<64, 256, 2, 1, 1, false>(
          XQ, WQ, x_scale, w_scale);
    } else {
      return _f8i4bf16_shuffled<128, 256, 2, 1, 1, true>(
          XQ, WQ, x_scale, w_scale);
    }
  }
}

#else

at::Tensor f8i4bf16_shuffled(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale) {
  throw std::runtime_error(
      "CUDA version is older than 12.0"); // requires CUDA>=12
}

#endif

} // namespace fbgemm_gpu