This repository was archived by the owner on Aug 7, 2024. It is now read-only.

[Perf] Explore more performant Fp8 Casting #83

Closed
drisspg opened this issue Sep 12, 2023 · 7 comments
Labels
Perf Issues related to perf optimizations

Comments

@drisspg
Contributor

drisspg commented Sep 12, 2023

Summary

There are two components to this: non-saturated casting and saturated casting.

Non-Saturated casting

  • We are currently using bit logic to cast from fp32 to fp8, whereas intrinsics exist that perform the same conversion; see Nikita's comment below.
  • Currently, for fp16 -> fp8 casting, we first convert fp16 to fp32 and then cast to fp8.
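
To make the distinction between the two modes concrete, here is a minimal host-side sketch in plain Python. It is not the bit logic used in this repo and not the CUDA intrinsic; `round_to_e5m2` is a hypothetical helper that rounds to the nearest e5m2 value (ties to even) and returns it as an ordinary float. It assumes that saturation clamps infinities as well as finite overflows, and that NaN passes through unchanged.

```python
import math

E5M2_MAX = 57344.0  # largest finite e5m2 value: 1.75 * 2**15


def round_to_e5m2(x: float, saturate: bool) -> float:
    """Round x to the nearest e5m2-representable value (hypothetical helper)."""
    if math.isnan(x):
        return math.nan
    if math.isinf(x):
        # Assumption: a saturating cast clamps infinities to the max finite value.
        return math.copysign(E5M2_MAX, x) if saturate else x
    if x == 0.0:
        return x
    # abs(x) = m * 2**e with 0.5 <= m < 1, so the binade exponent is e - 1.
    _, e = math.frexp(abs(x))
    # Clamp to the smallest normal exponent (-14) so subnormals share one spacing.
    exp = max(e - 1, -14)
    step = 2.0 ** (exp - 2)  # spacing between neighbours with 2 mantissa bits
    q = round(abs(x) / step) * step  # Python's round() is ties-to-even
    if q > E5M2_MAX:
        q = E5M2_MAX if saturate else math.inf
    return math.copysign(q, x)


in_range = round_to_e5m2(50001.0, saturate=False)
overflow_nosat = round_to_e5m2(100001.0, saturate=False)
overflow_sat = round_to_e5m2(100001.0, saturate=True)
```

In-range inputs give the same result in both modes; the modes differ only once the rounded value exceeds the largest finite e5m2 value.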

Saturated Casting

@drisspg drisspg changed the title Check saturated cast isa [Perf] Check saturated cast isa Sep 22, 2023
@drisspg drisspg added the Perf Issues related to perf optimizations label Sep 25, 2023
@malfet

malfet commented Sep 25, 2023

There is also this intrinsic: __nv_cvt_float_to_fp8(value, __NV_NOSAT, __NV_E5M2);

@drisspg drisspg changed the title [Perf] Check saturated cast isa [Perf] Fp8 Casting Sep 25, 2023
@malfet

malfet commented Sep 25, 2023

Another point to consider: right now we only have an fp32->fp8 cast, so we can probably optimize the fp16->fp8 cast with a specialized intrinsic...

@malfet

malfet commented Sep 25, 2023

Fun fact:

#include <cuda_fp8.h>

__device__ unsigned char conv_e5m2_nosat(float value) {
  return __nv_cvt_float_to_fp8(value, __NV_NOSAT, __NV_E5M2);
}

__device__ unsigned char conv_e5m2(float value) {
  return __nv_cvt_float_to_fp8(value, __NV_SATFINITE, __NV_E5M2);
}

Results in 2 very different kernels:

% nvcc -dc -gencode arch=compute_90,code=sm_90 foo.cu -O3; cuobjdump -sass foo.o 
		Function : _Z9conv_e5m2f
	.headerflags	@"EF_CUDA_TEXMODE_UNIFIED EF_CUDA_64BIT_ADDRESS EF_CUDA_SM90 EF_CUDA_VIRTUAL_SM(EF_CUDA_SM90)"
        /*0000*/                   F2FP.SATFINITE.E5M2.F32.PACK_AB_MERGE_C R4, RZ, R4, R4.H1 ;  /* 0x00000004ff04723e */
                                                                                                /* 0x000fc80004806104 */
        /*0010*/                   LOP3.LUT R4, R4, 0xff, RZ, 0xc0, !PT ;                       /* 0x000000ff04047812 */
                                                                                                /* 0x000fe200078ec0ff */
        /*0020*/                   RET.ABS.NODEC R20 0x0 ;                                      /* 0x0000000014007950 */
                                                                                                /* 0x000fec0003e00000 */
        /*0030*/                   BRA 0x30;                                                    /* 0xfffffffc00fc7947 */
		Function : _Z15conv_e5m2_nosatf
	.headerflags	@"EF_CUDA_TEXMODE_UNIFIED EF_CUDA_64BIT_ADDRESS EF_CUDA_SM90 EF_CUDA_VIRTUAL_SM(EF_CUDA_SM90)"
        /*0000*/                   F2F.F64.F32 R6, R4 ;                                         /* 0x0000000400067310 */
                                                                                                /* 0x0000620000201800 */
        /*0010*/                   LOP3.LUT R0, R4, 0x7fffffff, RZ, 0xc0, !PT ;                 /* 0x7fffffff04007812 */
                                                                                                /* 0x000fe200078ec0ff */
        /*0020*/                   BSSY B0, 0x3b0 ;                                             /* 0x0000038000007945 */
                                                                                                /* 0x000fe60003800000 */
        /*0030*/                   ISETP.GT.U32.AND P0, PT, R0, 0x7f800000, PT ;                /* 0x7f8000000000780c */
                                                                                                /* 0x000fe20003f04070 */
        /*0040*/                   IMAD.MOV.U32 R4, RZ, RZ, RZ ;                                /* 0x000000ffff047224 */
                                                                                                /* 0x001fc600078e00ff */
        /*0050*/                   SEL R8, R6, 0xe0000000, !P0 ;                                /* 0xe000000006087807 */
                                                                                                /* 0x002fe40004000000 */
        /*0060*/                   SEL R7, R7, 0x7fffffff, !P0 ;                                /* 0x7fffffff07077807 */
                                                                                                /* 0x000fe40004000000 */
        /*0070*/                   ISETP.GE.U32.AND P0, PT, R8, 0x1, PT ;                       /* 0x000000010800780c */
                                                                                                /* 0x000fe40003f06070 */
        /*0080*/                   LOP3.LUT R5, R7, 0x7fffffff, RZ, 0xc0, !PT ;                 /* 0x7fffffff07057812 */
                                                                                                /* 0x000fc800078ec0ff */
        /*0090*/                   ISETP.GE.U32.AND.EX P0, PT, R5, 0x3ee00000, PT, P0 ;         /* 0x3ee000000500780c */
                                                                                                /* 0x000fda0003f06100 */
        /*00a0*/              @!P0 BRA 0x3a0 ;                                                  /* 0x0000000000bc8947 */
                                                                                                /* 0x000fea0003800000 */
        /*00b0*/                   ISETP.GT.U32.AND P0, PT, R8, RZ, PT ;                        /* 0x000000ff0800720c */
                                                                                                /* 0x000fe40003f04070 */
        /*00c0*/                   SHF.R.U32.HI R0, RZ, 0x12, R7 ;                              /* 0x00000012ff007819 */
                                                                                                /* 0x000fe40000011607 */
        /*00d0*/                   ISETP.GT.U32.AND.EX P1, PT, R5, 0x7ff00000, PT, P0 ;         /* 0x7ff000000500780c */
                                                                                                /* 0x000fe40003f24100 */
        /*00e0*/                   LOP3.LUT R3, R0, 0x3, RZ, 0xc0, !PT ;                        /* 0x0000000300037812 */
                                                                                                /* 0x000fd600078ec0ff */
        /*00f0*/               @P1 BRA 0x390 ;                                                  /* 0x0000000000a41947 */
                                                                                                /* 0x000fea0003800000 */
        /*0100*/                   ISETP.GT.U32.AND P1, PT, R8, -0x1, PT ;                      /* 0xffffffff0800780c */
                                                                                                /* 0x000fe20003f24070 */
        /*0110*/                   IMAD.MOV.U32 R4, RZ, RZ, 0x7c ;                              /* 0x0000007cff047424 */
                                                                                                /* 0x000fc600078e00ff */
        /*0120*/                   ISETP.GT.U32.AND.EX P1, PT, R5, 0x40edffff, PT, P1 ;         /* 0x40edffff0500780c */
                                                                                                /* 0x000fda0003f24110 */
        /*0130*/               @P1 BRA 0x3a0 ;                                                  /* 0x0000000000981947 */
                                                                                                /* 0x000fea0003800000 */
        /*0140*/                   ISETP.GE.U32.AND P1, PT, R8, RZ, PT ;                        /* 0x000000ff0800720c */
                                                                                                /* 0x000fe40003f26070 */
        /*0150*/                   LEA.HI R4, R7, 0x10, RZ, 0xc ;                               /* 0x0000001007047811 */
                                                                                                /* 0x000fe400078f60ff */
        /*0160*/                   ISETP.GE.U32.AND.EX P1, PT, R5, 0x3f100000, PT, P1 ;         /* 0x3f1000000500780c */
                                                                                                /* 0x000fda0003f26110 */
        /*0170*/              @!P1 BRA 0x230 ;                                                  /* 0x00000000002c9947 */
                                                                                                /* 0x000fea0003800000 */
        /*0180*/                   LOP3.LUT R0, R0, 0x1, RZ, 0xc0, !PT ;                        /* 0x0000000100007812 */
                                                                                                /* 0x000fe200078ec0ff */
        /*0190*/                   IMAD.SHL.U32 R4, R4, 0x4, RZ ;                               /* 0x0000000404047824 */
                                                                                                /* 0x000fe200078e00ff */
        /*01a0*/                   ISETP.EQ.U32.AND P2, PT, R8, RZ, PT ;                        /* 0x000000ff0800720c */
                                                                                                /* 0x000fe40003f42070 */
        /*01b0*/                   LOP3.LUT R5, R7, 0x3ffff, RZ, 0xc0, !PT ;                    /* 0x0003ffff07057812 */
                                                                                                /* 0x000fe400078ec0ff */
        /*01c0*/                   ISETP.NE.U32.AND P1, PT, R0, 0x1, PT ;                       /* 0x000000010000780c */
                                                                                                /* 0x000fe40003f25070 */
        /*01d0*/                   LOP3.LUT R4, R3, 0x3fc, R4, 0xf8, !PT ;                      /* 0x000003fc03047812 */
                                                                                                /* 0x000fe400078ef804 */
        /*01e0*/                   ISETP.EQ.AND.EX P1, PT, R5, 0x20000, !P1, P2 ;               /* 0x000200000500780c */
                                                                                                /* 0x000fc40004f22320 */
        /*01f0*/                   LOP3.LUT R0, R4, 0xff, RZ, 0xc0, !PT ;                       /* 0x000000ff04007812 */
                                                                                                /* 0x000fe400078ec0ff */
        /*0200*/                   ISETP.GT.U32.OR.EX P1, PT, R5, 0x20000, P1, P0 ;             /* 0x000200000500780c */
                                                                                                /* 0x000fda0000f24500 */
        /*0210*/               @P1 IADD3 R4, R0, 0x1, RZ ;                                      /* 0x0000000100041810 */
                                                                                                /* 0x000fe20007ffe0ff */
        /*0220*/                   BRA 0x3a0 ;                                                  /* 0x00000000005c7947 */
                                                                                                /* 0x000fec0003800000 */
        /*0230*/                   IMAD.MOV R0, RZ, RZ, -R4 ;                                   /* 0x000000ffff007224 */
                                                                                                /* 0x000fe200078e0a04 */
        /*0240*/                   IADD3 R5, P0, RZ, -0x1, RZ ;                                 /* 0xffffffffff057810 */
                                                                                                /* 0x000fe40007f1e0ff */
        /*0250*/                   LOP3.LUT R3, R3, 0x4, RZ, 0xfc, !PT ;                        /* 0x0000000403037812 */
                                                                                                /* 0x000fe400078efcff */
        /*0260*/                   PRMT R0, R0, 0x7710, RZ ;                                    /* 0x0000771000007816 */
                                                                                                /* 0x000fe400000000ff */
        /*0270*/                   LOP3.LUT R5, R5, R8, RZ, 0xc0, !PT ;                         /* 0x0000000805057212 */
                                                                                                /* 0x000fe400078ec0ff */
        /*0280*/                   IADD3 R0, R0, 0x1, RZ ;                                      /* 0x0000000100007810 */
                                                                                                /* 0x000fe40007ffe0ff */
        /*0290*/                   ISETP.GT.U32.AND P1, PT, R5, RZ, PT ;                        /* 0x000000ff0500720c */
                                                                                                /* 0x000fc40003f24070 */
        /*02a0*/                   LOP3.LUT R0, R0, 0xff, RZ, 0xc0, !PT ;                       /* 0x000000ff00007812 */
                                                                                                /* 0x000fc800078ec0ff */
        /*02b0*/                   SHF.L.U64.HI R6, RZ, R0.reuse, 0x40000 ;                     /* 0x00040000ff067419 */
                                                                                                /* 0x080fe40000010200 */
        /*02c0*/                   SHF.R.U32.HI R4, RZ, R0.reuse, R3 ;                          /* 0x00000000ff047219 */
                                                                                                /* 0x080fe40000011603 */
        /*02d0*/                   IADD3.X R6, R6, -0x1, RZ, P0, !PT ;                          /* 0xffffffff06067810 */
                                                                                                /* 0x000fe400007fe4ff */
        /*02e0*/                   ISETP.NE.U32.AND P0, PT, R5, RZ, PT ;                        /* 0x000000ff0500720c */
                                                                                                /* 0x000fe40003f05070 */
        /*02f0*/                   SHF.L.U64.HI R0, RZ, R0, 0x20000 ;                           /* 0x00020000ff007419 */
                                                                                                /* 0x000fe40000010200 */
        /*0300*/                   LOP3.LUT R3, R6, 0x100000, R7, 0xe0, !PT ;                   /* 0x0010000006037812 */
                                                                                                /* 0x000fc400078ee007 */
        /*0310*/                   LOP3.LUT R5, R4, 0x1, RZ, 0xc0, !PT ;                        /* 0x0000000104057812 */
                                                                                                /* 0x000fe400078ec0ff */
        /*0320*/                   ISETP.NE.AND.EX P0, PT, R3.reuse, R0.reuse, PT, P0 ;         /* 0x000000000300720c */
                                                                                                /* 0x0c0fe40003f05300 */
        /*0330*/                   ISETP.GT.U32.AND.EX P1, PT, R3, R0, PT, P1 ;                 /* 0x000000000300720c */
                                                                                                /* 0x000fe40003f24110 */
        /*0340*/                   ISETP.NE.U32.OR P0, PT, R5, 0x1, P0 ;                        /* 0x000000010500780c */
                                                                                                /* 0x000fda0000705470 */
        /*0350*/               @P0 BRA !P1, 0x3a0 ;                                             /* 0x0000000000100947 */
                                                                                                /* 0x000fea0004800000 */
        /*0360*/                   LOP3.LUT R4, R4, 0xff, RZ, 0xc0, !PT ;                       /* 0x000000ff04047812 */
                                                                                                /* 0x000fc800078ec0ff */
        /*0370*/                   IADD3 R4, R4, 0x1, RZ ;                                      /* 0x0000000104047810 */
                                                                                                /* 0x000fe20007ffe0ff */
        /*0380*/                   BRA 0x3a0 ;                                                  /* 0x0000000000047947 */
                                                                                                /* 0x000fec0003800000 */
        /*0390*/                   LOP3.LUT R4, R3, 0x7e, RZ, 0xfc, !PT ;                       /* 0x0000007e03047812 */
                                                                                                /* 0x000fce00078efcff */
        /*03a0*/                   BSYNC B0 ;                                                   /* 0x0000000000007941 */
                                                                                                /* 0x000fea0003800000 */
        /*03b0*/                   SHF.R.U32.HI R3, RZ, 0x1f, R7 ;                              /* 0x0000001fff037819 */
                                                                                                /* 0x000fca0000011607 */
        /*03c0*/                   IMAD.SHL.U32 R3, R3, 0x80, RZ ;                              /* 0x0000008003037824 */
                                                                                                /* 0x000fca00078e00ff */
        /*03d0*/                   LOP3.LUT R4, R4, 0xff, R3, 0xc8, !PT ;                       /* 0x000000ff04047812 */
                                                                                                /* 0x000fe200078ec803 */
        /*03e0*/                   RET.ABS.NODEC R20 0x0 ;                                      /* 0x0000000014007950 */
                                                                                                /* 0x000fec0003e00000 */
        /*03f0*/                   BRA 0x3f0;                                                   /* 0xfffffffc00fc7947 */
                                                                                                /* 0x000fc0000383ffff */
        /*0400*/                   NOP;                                                         /* 0x0000000000007918 */
                                                                                                /* 0x000fc00000000000 */
        /*0410*/                   NOP;                                                         /* 0x0000000000007918 */
                                                                                                /* 0x000fc00000000000 */
        /*0420*/                   NOP;                                                         /* 0x0000000000007918 */
                                                                                                /* 0x000fc00000000000 */
        /*0430*/                   NOP;                                                         /* 0x0000000000007918 */
                                                                                                /* 0x000fc00000000000 */
        /*0440*/                   NOP;                                                         /* 0x0000000000007918 */
                                                                                                /* 0x000fc00000000000 */
        /*0450*/                   NOP;                                                         /* 0x0000000000007918 */
                                                                                                /* 0x000fc00000000000 */
        /*0460*/                   NOP;                                                         /* 0x0000000000007918 */
                                                                                                /* 0x000fc00000000000 */
        /*0470*/                   NOP;                                                         /* 0x0000000000007918 */

@drisspg drisspg changed the title [Perf] Fp8 Casting [Perf] Explore more performant Fp8 Casting Sep 25, 2023
@malfet

malfet commented Sep 25, 2023

But perf-wise, with all the TensorIterator overhead, they seem to take roughly the same time. I will try to write a targeted kernel right now:

// Run me as: nvcc -gencode arch=compute_90,code=sm_90 foo.cu -O3; ncu ./a.out
#include <cuda_fp8.h>

template<__nv_fp8_interpretation_t DTYPE = __NV_E5M2, __nv_saturation_t SAT = __NV_SATFINITE>
__global__
void do_conv_e5m2(const float* inp, char* out, unsigned size) {
  const auto idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= size) {
    return;
  }
  out[idx] = __nv_cvt_float_to_fp8(inp[idx], SAT, DTYPE);
}

int main() {
  float *fp32_ptr = nullptr;
  char* fp8_ptr = nullptr;
  constexpr unsigned numElem = 1024*1024*64;
  constexpr unsigned blockSize = 512;
  constexpr auto numBlocks = (numElem + blockSize - 1 ) / blockSize;

  cudaMalloc(&fp32_ptr, sizeof(*fp32_ptr)*numElem);
  cudaMalloc(&fp8_ptr, sizeof(*fp8_ptr)*numElem);

  do_conv_e5m2<__NV_E5M2, __NV_SATFINITE><<<numBlocks, blockSize>>>(fp32_ptr, fp8_ptr, numElem);
  do_conv_e5m2<__NV_E5M2, __NV_NOSAT><<<numBlocks, blockSize>>>(fp32_ptr, fp8_ptr, numElem);
  cudaDeviceSynchronize();
  return 0;
}

@malfet

malfet commented Sep 27, 2023

Another thing to consider:

$ python3 -c "import torch;print(torch.arange(1e0, 1e6,step=5e4,device='cuda').to(torch.float16))"
tensor([1.0000e+00, 5.0016e+04,        inf,        inf,        inf,        inf,
               inf,        inf,        inf,        inf,        inf,        inf,
               inf,        inf,        inf,        inf,        inf,        inf,
               inf,        inf], device='cuda:0', dtype=torch.float16)
$ python3 -c "import torch;print(torch.arange(1e0, 1e6,step=5e4,device='cuda').to(torch.float8_e5m2))"
tensor([1.0000e+00, 4.9152e+04,        inf,        inf,        inf,        inf,
               inf,        inf,        inf,        inf,        inf,        inf,
               inf,        inf,        inf,        inf,        inf,        inf,
               inf,        inf], device='cuda:0', dtype=torch.float8_e5m2)
$  python3 -c "import torch;print(torch.arange(1e0, 1e6,step=5e4,device='cuda').to(torch.float8_e4m3fn))"
tensor([1., nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan], device='cuda:0', dtype=torch.float8_e4m3fn)
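
The outputs above are consistent with the largest finite value each format can hold. The sketch below derives those limits from the bit layouts. `max_finite` is a hypothetical helper, and the `ieee_like` flag is an illustrative way of separating formats that reserve the top exponent code for inf/nan (float16, e5m2) from the "fn" layout of e4m3fn, which has no infinity and reserves only the all-ones pattern for nan.

```python
def max_finite(exp_bits: int, man_bits: int, ieee_like: bool) -> float:
    """Largest finite value of a small binary float format (hypothetical helper)."""
    bias = 2 ** (exp_bits - 1) - 1
    if ieee_like:
        # The all-ones exponent is reserved for inf/nan, so the largest usable
        # exponent code is one below it, with a full mantissa.
        max_exp = (2 ** exp_bits - 2) - bias
        max_significand = 2.0 - 2.0 ** (-man_bits)
    else:
        # "fn" layout: the all-ones exponent is usable, but the all-ones
        # mantissa in that exponent is nan, so the mantissa stops one step short.
        max_exp = (2 ** exp_bits - 1) - bias
        max_significand = 2.0 - 2.0 ** (1 - man_bits)
    return max_significand * 2.0 ** max_exp


fp16_max = max_finite(5, 10, ieee_like=True)     # 65504.0
e5m2_max = max_finite(5, 2, ieee_like=True)      # 57344.0
e4m3fn_max = max_finite(4, 3, ieee_like=False)   # 448.0
```

Every element after the first is far above 448, and e4m3fn has no infinity to overflow to, which matches the nan values in the last output.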

@drisspg
Contributor Author

drisspg commented Feb 22, 2024

I explored this further and also wrote a CUDA kernel to do the casting here:
https://github.com/drisspg/driss_torch/blob/1a8c41c84c9521f35a7f9332da08813ea20608b1/driss_torch/__init__.py#L30

Performance comparing against inductor/triton can be found here:
drisspg/driss_torch@67c596e

Note, though, that I still need to add an option to return the matrix in transposed format (which we will need to fuse for the backward pass).

For delayed scaling, we would expect the absmax calculation to be fused into the prior op, and this fused kernel could then be used with the historical scale. We should weigh, though, whether it is worth the added complexity to ship this in this repo.
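
As a rough illustration of the delayed-scaling flow described above, here is a plain-Python sketch. It is one common formulation and not necessarily the recipe used in this repo: the names `compute_scale`, `scaled_saturated`, and `history`, the history length of 16, and the choice of e4m3 as the target are all assumptions. Rounding to the fp8 grid is omitted; only the scale-then-clamp step is shown.

```python
from collections import deque

E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


def compute_scale(amax_history, fp8_max=E4M3_MAX, eps=1e-12):
    """Derive the scale from previously recorded absmax values (hypothetical helper)."""
    amax = max(amax_history, default=0.0)
    return fp8_max / max(amax, eps)  # eps guards against division by zero


def scaled_saturated(values, scale, fp8_max=E4M3_MAX):
    """Scale and clamp to the fp8 range; rounding to the fp8 grid is omitted."""
    return [min(max(v * scale, -fp8_max), fp8_max) for v in values]


# Absmax values recorded on earlier iterations (illustrative numbers).
history = deque([2.0, 3.5], maxlen=16)

x = [0.5, -7.0, 1.25]
scale = compute_scale(history)      # uses only the history: 448 / 3.5 = 128.0
y = scaled_saturated(x, scale)      # [64.0, -448.0, 160.0]

# The current absmax is recorded for later iterations; in the fused setting
# this value would come from the prior op rather than a separate pass over x.
history.append(max(abs(v) for v in x))
```

Because the scale comes from history, the current tensor can exceed the expected range (as `-7.0` does here), which is where the saturated cast matters.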

@vkuzo
Contributor

vkuzo commented Jul 30, 2024

moved to pytorch/ao#559

@vkuzo vkuzo closed this as completed Jul 30, 2024

3 participants