add gemm with rmsnorm #321

Open · wants to merge 7 commits into base: sycl-develop

Conversation

yuankuns

Add an example with a post-op RMSNorm after GEMM.

@aacostadiaz (Collaborator) left a comment

Looks good. I left a few minor comments.

gemm_op.run();
}
syclcompat::wait();
double io =

Suggested change
double io =
double io =

#pragma once

#include "cutlass/cutlass.h"
#include <sycl/sycl.hpp>

<sycl/sycl.hpp> is already included by cutlass/cutlass.h (via gpu_generics.h), so this include is redundant.

Suggested change
#include <sycl/sycl.hpp>

@joeatodd (Collaborator)

Hello @yuankuns. We've made some extensive naming changes since you submitted this PR, so I thought I'd help you out and provide the required changes to this branch. It's the last commit on my gemmrmsnorm-updates branch. That should fix the CI failures 👍

@joeatodd (Collaborator) left a comment

Looks good, but I think the implementation could be made more EVT-friendly.

N * L * sizeof(ElementW));
syclcompat::wait();

constexpr float eps = 1e-5;

I think we should use options.eps here.
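
For context, this is where eps enters the computation. A minimal, self-contained reference sketch of RMSNorm (the helper name `rmsnorm_ref` is hypothetical, not from the PR), with eps passed in as a parameter the way the reviewer suggests taking it from options.eps:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical reference helper (not from the PR): RMSNorm over one row.
//   y[i] = x[i] / sqrt(mean(x^2) + eps) * weight[i]
// `eps` is the value the reviewer suggests sourcing from options.eps
// rather than hard-coding a constexpr.
std::vector<float> rmsnorm_ref(const std::vector<float>& x,
                               const std::vector<float>& weight,
                               float eps) {
  float sum_sq = 0.f;
  for (float v : x) sum_sq += v * v;
  const float inv_rms =
      1.f / std::sqrt(sum_sq / static_cast<float>(x.size()) + eps);
  std::vector<float> y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i)
    y[i] = x[i] * inv_rms * weight[i];
  return y;
}
```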

Comment on lines +310 to +315
gemm_op.can_implement(arguments);

gemm_op.initialize(arguments, workspace.get());

// Run the GEMM
gemm_op.run();

Please follow e.g. 00_pvc_gemm.cpp and ensure that the example returns an early failure if these steps fail.
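
A sketch of the early-exit pattern being asked for, assuming the CUTLASS convention that can_implement, initialize and run each return a cutlass::Status (gemm_op, arguments and workspace come from the surrounding example code):

```cpp
// Sketch only: bail out early on each step, as in 00_pvc_gemm.cpp.
cutlass::Status status = gemm_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
  std::cerr << "GEMM cannot be implemented with the given arguments\n";
  return -1;
}

status = gemm_op.initialize(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
  std::cerr << "Failed to initialize the GEMM\n";
  return -1;
}

// Run the GEMM
status = gemm_op.run();
if (status != cutlass::Status::kSuccess) {
  std::cerr << "GEMM execution failed\n";
  return -1;
}
```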

using StrideWeight = Stride<_1, _0, int64_t>;
ElementWeight const* weight_ptr = nullptr;
float eps = 1e-5;
StrideWeight dWeight = {};

Is this unused?

auto loop_t = res(_, loop, _);
auto pow2_t = pow2_buff(_, loop, _);
Tensor group_sum = make_tensor<float>(make_shape(Int<vec_size>{}));
float rev_dim = 1 / (float)params.inner_dim;

Suggested change
float rev_dim = 1 / (float)params.inner_dim;
const float rev_dim = 1 / static_cast<float>(params.inner_dim);

This could also be hoisted out of the loop.
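
Purely as illustration of the hoisting being suggested (the function and names below are hypothetical; `inner_dim` stands in for params.inner_dim):

```cpp
#include <cstddef>
#include <vector>

// Illustrative only: compute the loop-invariant reciprocal once, before the
// loop, then multiply per element instead of dividing per iteration.
std::vector<float> scale_rows(const std::vector<float>& sums, int inner_dim) {
  const float rev_dim = 1.f / static_cast<float>(inner_dim);  // hoisted invariant
  std::vector<float> means(sums.size());
  for (std::size_t i = 0; i < sums.size(); ++i)
    means[i] = sums[i] * rev_dim;
  return means;
}
```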

Comment on lines +213 to +215
int gx = syclcompat::global_id::x() % 256;
int gy = syclcompat::global_id::y();
auto gid = gx / 16 * 32 + gx % 16;

A comment to explain why these calculations are being performed would be useful. Why 256, 32, 16?
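
For reference, the arithmetic in question remaps each run of 16 consecutive ids onto a 32-element stride; whether 16, 32 and 256 correspond to subgroup size, tile width and work-group size is an assumption here, which is exactly what the reviewer asks to be documented:

```cpp
// Illustrative only: gid = gx / 16 * 32 + gx % 16 places each block of 16
// consecutive ids at a 32-element stride, leaving a 16-wide gap between
// blocks (ids 0..15 map to 0..15, ids 16..31 map to 32..47, and so on).
// The meaning of the constants is an assumption and should be documented
// in the PR itself.
int remap(int gx) { return gx / 16 * 32 + gx % 16; }
```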


Also I think gy is unused.

}
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Epi_N / IntelPVCEpilogue::SubgroupSize; i++) {
const float wgt_per_col = (float)wgt_ptr[gid + i * IntelPVCEpilogue::SubgroupSize];

Loading this weight data here is kind of an anti-pattern in the context of EVT epilogues. There is a specific EVT operation for this: XeRowBroadcast, which will load data once and broadcast it as required. For example, the Linear Combination With Per Column Bias is defined as:

// template args...
using XeLinCombPerColBias =
  Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + (alpha * acc + bias)
    Sm90ScalarBroadcast<ElementScalar, Stride<_0,_0,int64_t>>, // beta
    Sm90SrcFetch<ElementSource>, // C
    Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc + bias
      Sm90ScalarBroadcast<ElementScalar, Stride<_0,_0,int64_t>>, // alpha
      Sm90AccFetch, // acc
      XeRowBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_0,_1,int64_t>, AlignmentBias> // bias
    >
  >;

Since for RMSNorm the multiplication by the weight is effectively an independent calculation, this approach could be accomplished by:

  1. Removing all references to the weight from XeRMSNormRowReduction.
  2. Defining an outer layer in your Sm90EVT definition which does an Sm90Compute<multiplies,...>, taking XeRMSNormRowReduction and XeRowBroadcast as inputs.

Taking this approach has the advantages that:

  • It will be generally correct, regardless of thread_idx layout, etc.
  • We have one fewer 'load' operation in the library to optimize
  • The RMSNorm operation has fewer responsibilities, and could more easily be generalized (e.g. to reduce over more/different dimensions) in future
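
A sketch of that composition, modeled on the XeLinCombPerColBias example above; the element types, alignment and the exact XeRMSNormRowReduction signature are placeholders, not the PR's actual definitions:

```cpp
// Sketch only: outer multiplies node combining the (weight-free) RMSNorm
// reduction with a broadcast load of the weight row, per the suggestion above.
using XeRMSNormWithWeight =
  Sm90EVT<Sm90Compute<multiplies, ElementOutput, ElementCompute, RoundStyle>, // rmsnorm(acc) * weight
    XeRMSNormRowReduction</* ...with all weight references removed... */>,    // normalization only
    XeRowBroadcast<0, CtaTileShapeMNK, ElementWeight, ElementCompute,
                   Stride<_0,_1,int64_t>, AlignmentWeight>                    // weight, loaded once
  >;
```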

4 participants