Skip to content

Commit ef7ff20

Browse files
committed
Addressed the comments.
1 parent 59988b4 commit ef7ff20

File tree

4 files changed

+41
-46
lines changed

4 files changed

+41
-46
lines changed

applications/scaled_mm/collective/xe_scaled_mm_mma_fp8.hpp

Lines changed: 16 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ struct CollectiveMma<MainloopIntelScaledMMW8A8<Stages, Schedule>, TileShape_, El
7171
using ElementAccumulator = typename TiledMma::ValTypeC;
7272
using GmemTiledCopyA = GmemTiledCopyA_;
7373
using GmemTiledCopyB = GmemTiledCopyB_;
74-
using GmemTiledCopyScaleA = XE_2D_U16x32x32_LD_N; //Have to use the same shape size as FP8 used in the kernel
75-
using GmemTiledCopyScaleB = XE_2D_U16x32x32_LD_N; //Have to use the same shape size as FP8 used in the kernel
74+
using GmemTiledCopyScaleA = XE_2D_U16x32x32_LD_N; // Shape of the copy atom for scales A must match shape of the copy atom for A in the number of elements
75+
using GmemTiledCopyScaleB = XE_2D_U16x32x32_LD_N; // Shape of the copy atom for scales B must match shape of the copy atom for B in the number of elements
7676

7777
using SmemLayoutAtomA = SmemLayoutAtomA_;
7878
using SmemLayoutAtomB = SmemLayoutAtomB_;
@@ -169,15 +169,15 @@ struct CollectiveMma<MainloopIntelScaledMMW8A8<Stages, Schedule>, TileShape_, El
169169
make_layout(make_shape(M, K, L), args.dA));
170170
auto mB_nkl = make_tensor(make_gmem_ptr(static_cast<ElementB const*>(args.ptr_B)),
171171
make_layout(make_shape(N, K, L), args.dB));
172-
auto mscaleA_mkl = make_tensor(make_gmem_ptr(static_cast<ElementScaleA const*>(args.ptr_scaleA)),
172+
auto mScaleA_mkl = make_tensor(make_gmem_ptr(static_cast<ElementScaleA const*>(args.ptr_scaleA)),
173173
make_layout(make_shape(M, K, L), args.dscaleA));
174-
auto mscaleB_nkl = make_tensor(make_gmem_ptr(static_cast<ElementScaleB const*>(args.ptr_scaleB)),
174+
auto mScaleB_nkl = make_tensor(make_gmem_ptr(static_cast<ElementScaleB const*>(args.ptr_scaleB)),
175175
make_layout(make_shape(N, K, L), args.dscaleB));
176176

177177
Copy_A tiled_copy_a{Copy_A{}.with(mA_mkl)};
178178
Copy_B tiled_copy_b{Copy_B{}.with(mB_nkl)};
179-
Copy_scaleA tiled_copy_scale_a{Copy_scaleA{}.with(mscaleA_mkl)};
180-
Copy_scaleB tiled_copy_scale_b{Copy_scaleB{}.with(mscaleB_nkl)};
179+
Copy_scaleA tiled_copy_scale_a{Copy_scaleA{}.with(mScaleA_mkl)};
180+
Copy_scaleB tiled_copy_scale_b{Copy_scaleB{}.with(mScaleB_nkl)};
181181

182182
return Params{tiled_copy_a, tiled_copy_b, tiled_copy_scale_a, tiled_copy_scale_b};
183183
}
@@ -204,41 +204,22 @@ struct CollectiveMma<MainloopIntelScaledMMW8A8<Stages, Schedule>, TileShape_, El
204204
using saType = typename EnginescaleA::value_type;
205205
using sbType = typename EnginescaleB::value_type;
206206

207-
auto const& a = a_tensor(_, _, _);
208-
auto const& b = b_tensor(_, _, _);
209-
210-
aType* pA = a.data();
211-
bType* pB = b.data();
207+
aType* pA = a_tensor.data();
208+
bType* pB = b_tensor.data();
212209
saType* psA = scale_a_tensor.data();
213210
sbType* psB = scale_b_tensor.data();
214211

215-
constexpr int a_num_elements = decltype(size(a))::value;
216-
constexpr int b_num_elements = decltype(size(b))::value;
212+
constexpr int a_num_elements = decltype(size(a_tensor))::value;
213+
constexpr int b_num_elements = decltype(size(b_tensor))::value;
217214

218215
for (int i = 0; i < a_num_elements; i++){
219-
reinterpret_cast<aType*>(pA)[i] = reinterpret_cast<aType*>(pA)[i] * reinterpret_cast<saType*>(psA)[i];
216+
pA[i] *= psA[i];
220217
}
221218
for (int i = 0; i < b_num_elements; i++){
222-
reinterpret_cast<bType*>(pB)[i] = reinterpret_cast<bType*>(pB)[i] * reinterpret_cast<sbType*>(psB)[i];
219+
pB[i] *= psB[i];
223220
}
224221
}
225222

226-
template <class EngineIn,
227-
class LayoutIn,
228-
class... Ts>
229-
CUTLASS_DEVICE
230-
void debug(Tensor<EngineIn, LayoutIn> & in)
231-
{
232-
auto const& src = in(_, _, _);
233-
using SrcType = typename EngineIn::value_type;
234-
SrcType* pSrc = src.data();
235-
constexpr int num_elements = decltype(size(src))::value;
236-
for (int i = 0; i < num_elements; i++){
237-
SrcType tmp = reinterpret_cast<SrcType*>(pSrc)[i];
238-
print(static_cast<float>(tmp)); print(' ');
239-
}
240-
}
241-
242223
//
243224
// Methods
244225
//
@@ -264,13 +245,10 @@ struct CollectiveMma<MainloopIntelScaledMMW8A8<Stages, Schedule>, TileShape_, El
264245
static_assert(std::is_same_v<SrcType, uint8_t>, "Expected fp8 (E4M3) input as uint8_t");
265246
static_assert(std::is_same_v<DstType, half_t>, "Expected fp16 output as half_t");
266247

267-
auto const& src = in(_, _, _);
268-
auto const& dst = out(_, _, _);
248+
SrcType const* pSrc = in.data();
249+
DstType* pDst = out.data();
269250

270-
SrcType const* pSrc = src.data();
271-
DstType* pDst = dst.data();
272-
273-
constexpr int num_elements = decltype(size(src))::value;
251+
constexpr int num_elements = decltype(size(in))::value;
274252
// TODO(Codeplay): Move conversion to NumericArrayConverter
275253
if constexpr (std::is_same_v<ElementA, float_e5m2_t>) {
276254
// Using something as simple as the following code surprisingly
@@ -285,7 +263,6 @@ struct CollectiveMma<MainloopIntelScaledMMW8A8<Stages, Schedule>, TileShape_, El
285263
SrcArray const* pSrcArr = reinterpret_cast<SrcArray const*>(pSrc);
286264
DstArray* pDstArr = reinterpret_cast<DstArray*>(pDst);
287265
E5M2_to_FP16<num_elements>(*pSrcArr, *pDstArr);
288-
//E5M2_to_FP16_and_scale_fp16<num_elements>(*pSrcArr, *pDstArr);
289266
} else {
290267
// E4M3 -> FP16 conversion
291268
constexpr int chunk_size = 16;
@@ -399,4 +376,4 @@ struct CollectiveMma<MainloopIntelScaledMMW8A8<Stages, Schedule>, TileShape_, El
399376

400377
} // namespace cutlass::gemm::collective
401378

402-
/////////////////////////////////////////////////////////////////////////////////////////////////
379+
/////////////////////////////////////////////////////////////////////////////////////////////////

applications/scaled_mm/kernel/xe_scaled_mm_fp8.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,4 +296,4 @@ class GemmScaledMM<
296296
297297
///////////////////////////////////////////////////////////////////////////////
298298
299-
} // namespace cutlass::gemm::kernel
299+
} // namespace cutlass::gemm::kernel

examples/sycl/09_bmg_scaled_mm_f8/09_bmg_scaled_mm_f8.cpp

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ struct Options {
8484

8585
bool help;
8686
bool error;
87+
bool e5m2;
8788

8889
int m, n, k, l, iterations;
8990
float alpha, beta;
@@ -104,6 +105,11 @@ struct Options {
104105
return;
105106
}
106107

108+
if (cmd.check_cmd_line_flag("e5m2")) {
109+
e5m2 = true;
110+
return;
111+
}
112+
107113
cmd.get_cmd_line_argument("m", m, 5120);
108114
cmd.get_cmd_line_argument("n", n, 4096);
109115
cmd.get_cmd_line_argument("k", k, 4096);
@@ -125,7 +131,8 @@ struct Options {
125131
<< " --l=<int> Sets the L extent (batch count) of the GEMM\n"
126132
<< " --alpha=<s32> Epilogue scalar alpha\n"
127133
<< " --beta=<s32> Epilogue scalar beta\n\n"
128-
<< " --iterations=<int> Iterations\n\n";
134+
<< " --iterations=<int> Iterations\n\n"
135+
<< " --e5m2 Use e5m2\n\n";
129136

130137
return out;
131138
}
@@ -205,7 +212,10 @@ struct ExampleRunner {
205212
}
206213

207214
syclcompat::memcpy(d_src, h_src_multiplied, size * sizeof(SrcT));
208-
syclcompat::wait();
215+
syclcompat::wait();
216+
217+
delete[] h_src_multiplied;
218+
delete[] h_scale;
209219
}
210220

211221
template <typename SrcT, typename DstT>
@@ -221,6 +231,9 @@ struct ExampleRunner {
221231

222232
syclcompat::memcpy(d_dst, h_dst, size * sizeof(DstT));
223233
syclcompat::wait();
234+
235+
delete[] h_src;
236+
delete[] h_dst;
224237
}
225238

226239
bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) {
@@ -393,10 +406,15 @@ int main(int argc, const char** argv)
393406
bool passed;
394407

395408
using ElementAccumulator = float;
396-
using ElementComputeEpilogue = float;
397-
// TODO: support E5M2
409+
using ElementComputeEpilogue = float;
398410
using ElementInputA = cutlass::float_e4m3_t;
399411
using ElementInputB = cutlass::float_e4m3_t;
412+
413+
if (options.e5m2){
414+
using ElementInputA = cutlass::float_e5m2_t;
415+
using ElementInputB = cutlass::float_e5m2_t;
416+
}
417+
400418
using ElementOutput = float;
401419

402420
using LayoutA = cutlass::layout::RowMajor;
@@ -460,4 +478,4 @@ int main(int argc, const char** argv)
460478
CUTLASS_CHECK(runner.run(options, hw_info));
461479

462480
return 0;
463-
}
481+
}

examples/sycl/09_bmg_scaled_mm_f8/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,4 @@ set(CUTLASS_EXAMPLES_DIR ${CMAKE_SOURCE_DIR}/examples)
3333
cutlass_example_add_executable(
3434
09_bmg_scaled_mm_f8
3535
09_bmg_scaled_mm_f8.cpp
36-
)
36+
)

0 commit comments

Comments
 (0)