You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: applications/scaled_mm/collective/xe_scaled_mm_mma_fp8.hpp
+16-39Lines changed: 16 additions & 39 deletions
Original file line number
Diff line number
Diff line change
@@ -71,8 +71,8 @@ struct CollectiveMma<MainloopIntelScaledMMW8A8<Stages, Schedule>, TileShape_, El
71
71
using ElementAccumulator = typename TiledMma::ValTypeC;
72
72
using GmemTiledCopyA = GmemTiledCopyA_;
73
73
using GmemTiledCopyB = GmemTiledCopyB_;
74
-
using GmemTiledCopyScaleA = XE_2D_U16x32x32_LD_N; //Have to use the same shape size as FP8 used in the kernel
75
-
using GmemTiledCopyScaleB = XE_2D_U16x32x32_LD_N; //Have to use the same shape size as FP8 used in the kernel
74
+
using GmemTiledCopyScaleA = XE_2D_U16x32x32_LD_N; // Shape of the copy atom for scales A must match shape of the copy atom for A in the number of elements
75
+
using GmemTiledCopyScaleB = XE_2D_U16x32x32_LD_N; // Shape of the copy atom for scales A must match shape of the copy atom for A in the number of elements
76
76
77
77
using SmemLayoutAtomA = SmemLayoutAtomA_;
78
78
using SmemLayoutAtomB = SmemLayoutAtomB_;
@@ -169,15 +169,15 @@ struct CollectiveMma<MainloopIntelScaledMMW8A8<Stages, Schedule>, TileShape_, El
0 commit comments