[Perf] Explore more performant Fp8 Casting #83
Comments
There is also this intrinsic:
Another point to consider: right now we only have an fp32->fp8 cast, so we could probably optimize the fp16->fp8 cast with a specialized intrinsic...
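cuda_fp8.h also exposes half-precision conversion intrinsics (e.g. __nv_cvt_halfraw_to_fp8), so a direct fp16->fp8 path could look roughly like the sketch below; the exact intrinsic name/signature is an assumption here and should be double-checked against the CUDA math API docs for the toolkit in use:

```cuda
#include <cuda_fp16.h>
#include <cuda_fp8.h>

// Sketch: direct fp16 -> fp8 (e5m2) conversion without a round-trip through fp32.
// Assumes __nv_cvt_halfraw_to_fp8 from cuda_fp8.h; verify the signature against
// the CUDA math API reference for your toolkit version.
__device__ unsigned char conv_half_to_e5m2(__half value) {
    return __nv_cvt_halfraw_to_fp8(static_cast<__half_raw>(value),
                                   __NV_SATFINITE, __NV_E5M2);
}
```

There is also a packed __nv_cvt_halfraw2_to_fp8x2 variant that converts two halves at a time, which may matter more for bandwidth-bound casts.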
Fun fact:

#include <cuda_fp8.h>
__device__ unsigned char conv_e5m2_nosat(float value) {
return __nv_cvt_float_to_fp8(value, __NV_NOSAT, __NV_E5M2);
}
__device__ unsigned char conv_e5m2(float value) {
return __nv_cvt_float_to_fp8(value, __NV_SATFINITE, __NV_E5M2);
}

Results in two very different kernels.
But perf-wise, with all the TensorIterator overhead, they seem to be taking roughly the same time; will try to write a targeted kernel right now:

// Run me as: nvcc -gencode arch=compute_90,code=sm_90 foo.cu -O3; ncu ./a.out
#include <cuda_fp8.h>
template<__nv_fp8_interpretation_t DTYPE = __NV_E5M2, __nv_saturation_t SAT = __NV_SATFINITE>
__global__
void do_conv_e5m2(const float* inp, char* out, unsigned size) {
const auto idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= size) {
return;
}
out[idx] = __nv_cvt_float_to_fp8(inp[idx], SAT, DTYPE);
}
int main() {
float *fp32_ptr = nullptr;
char* fp8_ptr = nullptr;
constexpr unsigned numElem = 1024*1024*64;
constexpr unsigned blockSize = 512;
constexpr auto numBlocks = (numElem + blockSize - 1 ) / blockSize;
cudaMalloc(&fp32_ptr, sizeof(*fp32_ptr)*numElem);
cudaMalloc(&fp8_ptr, sizeof(*fp8_ptr)*numElem);
do_conv_e5m2<__NV_E5M2, __NV_SATFINITE><<<numBlocks, blockSize>>>(fp32_ptr, fp8_ptr, numElem);
do_conv_e5m2<__NV_E5M2, __NV_NOSAT><<<numBlocks, blockSize>>>(fp32_ptr, fp8_ptr, numElem);
cudaDeviceSynchronize();
return 0;
}
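If this targeted kernel turns out to be purely bandwidth-bound, one follow-up worth trying is the packed intrinsic (__nv_cvt_float2_to_fp8x2), which converts two floats per call and halves the number of 1-byte stores. A minimal sketch, assuming __nv_cvt_float2_to_fp8x2 behaves as documented and that the element count is even:

```cuda
#include <cuda_fp8.h>

// Sketch: packed fp32x2 -> fp8x2 conversion, two elements per thread.
// Assumes __nv_cvt_float2_to_fp8x2 from cuda_fp8.h and an even element count;
// a real kernel would also need to handle an odd tail element.
__global__ void do_conv_e5m2_x2(const float2* inp, unsigned short* out, unsigned numPairs) {
    const auto idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= numPairs) {
        return;
    }
    out[idx] = __nv_cvt_float2_to_fp8x2(inp[idx], __NV_SATFINITE, __NV_E5M2);
}
```

This could be launched with the same grid math as above over numElem / 2 threads, reinterpret_cast-ing the fp32 and fp8 pointers to float2* and unsigned short*.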
Another thing to consider:
I explored this further and also created a CUDA kernel to do the casting here: Performance comparisons against inductor/triton can be found here: Note, though, that I also need to add an option to return the matrix in transposed format (which we will need to fuse for the backward). For delayed scaling we would expect the absmax calculation to be fused into the prior op, and then this fused kernel could be used based off of the historical scale. We should weigh, though, whether it is worth the added complexity to ship this in this repo.
moved to pytorch/ao#559
Summary
There are two components to this: non-saturated casting and saturated casting.
Non-Saturated Casting

Saturated Casting
float8_experimental/float8_experimental/float8_utils.py, line 19 in cdcadb5
There do appear to be PTX intrinsics for doing saturated casts; see: https://github.com/openai/triton/blob/10f59d8ce04052521c1bc0cb3a3f8b98918fc7e3/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp#L182
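For reference, the Triton lowering linked above emits cvt.rn.satfinite PTX on Hopper; a rough inline-asm sketch of the same idea (packing two fp32 values into e4m3x2) might look like the following. The operand packing order (which input lands in the low vs. high byte of the result) is an assumption and should be verified against the PTX ISA docs, so treat this as a sketch rather than a drop-in:

```cuda
// Sketch: saturated fp32x2 -> e4m3x2 cast via inline PTX (requires sm_90 and a
// recent PTX ISA). The mapping of a/b to the low/high byte of the packed result
// is assumed here; check the PTX ISA spec before relying on it.
__device__ unsigned short cvt_f32x2_to_e4m3x2_satfinite(float a, float b) {
    unsigned short packed;
    asm volatile("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;"
                 : "=h"(packed)
                 : "f"(a), "f"(b));
    return packed;
}
```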