148 changes: 148 additions & 0 deletions paddle/phi/kernels/funcs/transpose_function.cu.h
@@ -557,12 +557,160 @@ __global__ void TransposeSimpleKernel(IndexType nthreads,
}
}

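// Eight consecutive fp8 values packed so that they can be moved with a
// single 8-byte (uint2) load/store instead of eight 1-byte accesses.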
typedef struct alignas(8) fp8x8_t {
union data_t {
phi::float8_e4m3fn scalar[8];
uint2 vector;
};
data_t data;

__device__ __forceinline__ void load(const void* ptr) {
data = *reinterpret_cast<const data_t*>(ptr);
}

__device__ __forceinline__ void store(void* ptr) const {
*reinterpret_cast<data_t*>(ptr) = data;
}
} fp8x8_t;

constexpr int kVecSize = 8;
constexpr int BLOCK_DIM = 16;
constexpr int BLOCK_TILE_SIZE = 128;
constexpr int BLOCK_TILE_WIDTH = BLOCK_TILE_SIZE;
constexpr int BLOCK_TILE_HEIGHT = BLOCK_TILE_SIZE;
constexpr int THREAD_TILE_DIM = BLOCK_TILE_SIZE / BLOCK_DIM;
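// Tile decomposition: each 16x16-thread block transposes one 128x128 tile,
// so every thread owns an 8x8 sub-tile (THREAD_TILE_DIM = 128 / 16 = 8),
// i.e. one fp8x8_t per sub-tile row.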

__global__ void
__launch_bounds__(BLOCK_DIM * BLOCK_DIM) inline fp8_fast_transpose_kernel(
const phi::float8_e4m3fn* __restrict__ src, // Source matrix (M x N)
phi::float8_e4m3fn* __restrict__ dst, // Destination matrix (N x M)
    int B,                  // Batch size
    int M,                  // Rows of the source matrix
    int N,                  // Columns of the source matrix
    size_t batch_stride) {  // Stride between batches (M * N elements)
  // Shared-memory tile, padded by one extra vector per row to avoid bank
  // conflicts; padding is used instead of swizzling for better performance.
__shared__ __align__(1024)
fp8x8_t smem[BLOCK_TILE_HEIGHT][BLOCK_TILE_WIDTH / kVecSize + 1];
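  // Row stride is (128 / 8 + 1) = 17 uint2 = 136 bytes, so consecutive rows
  // start at different banks.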

  // Thread-local storage: 8 fp8x8_t units, i.e. an 8x8 block of fp8 values.
fp8x8_t local_tile[kVecSize];
fp8x8_t local_tile_transposed[kVecSize];

// Thread indices within the block (0-15 for x and y, since 16x16 = 256
// threads)
const uint32_t tid_x = threadIdx.x; // Column-wise thread index (0-15)
const uint32_t tid_y = threadIdx.y; // Row-wise thread index (0-15)

// Block indices within the grid
const uint32_t block_x = blockIdx.x; // Tile index along N-dimension
const uint32_t block_y = blockIdx.y; // Tile index along M-dimension
const uint32_t block_z = blockIdx.z; // Batch index

// Calculate global offsets for the current block's tile in the M x N source
// matrix
const uint32_t global_m_offset =
block_y * BLOCK_TILE_HEIGHT; // Starting M index for this block
const uint32_t global_n_offset =
block_x * BLOCK_TILE_WIDTH; // Starting N index for this block

const size_t current_batch_offset =
static_cast<size_t>(batch_stride) * block_z;

  // 1. Load src into registers using vectorized (uint2) accesses.
#pragma unroll
  // Iterate over the 8 rows of this thread's sub-tile.
  for (uint32_t k = 0; k < THREAD_TILE_DIM; ++k) {
const uint32_t src_global_row =
global_m_offset + tid_y * THREAD_TILE_DIM + k;
const uint32_t src_global_col_start =
global_n_offset + tid_x * THREAD_TILE_DIM;

    // No bounds check is needed here: the dispatcher only takes this path
    // when M and N are multiples of the tile size. THREAD_TILE_DIM (8) is
    // the width of one fp8x8_t block.
const phi::float8_e4m3fn* src_ptr =
src + current_batch_offset + static_cast<size_t>(src_global_row) * N +
src_global_col_start;
local_tile[k].load(src_ptr);
}

  // 2. Transpose the thread's 8x8 tile at the register level.
#pragma unroll
for (uint32_t k_row = 0; k_row < THREAD_TILE_DIM; ++k_row) {
#pragma unroll
for (uint32_t k_col = 0; k_col < THREAD_TILE_DIM; ++k_col) {
local_tile_transposed[k_col].data.scalar[k_row] =
local_tile[k_row].data.scalar[k_col];
}
}

// 3. Store transposed data to shared memory
#pragma unroll
for (uint32_t k = 0; k < THREAD_TILE_DIM; ++k) {
const uint32_t smem_row = tid_x * THREAD_TILE_DIM + k;
    const uint32_t smem_col_start = tid_y * THREAD_TILE_DIM / kVecSize;  // = tid_y
smem[smem_row][smem_col_start] = local_tile_transposed[k];
}

__syncthreads();

  // 4. Store from shared memory to dst using vectorized (uint2) accesses.
#pragma unroll
for (uint32_t k = 0; k < THREAD_TILE_DIM; ++k) {
const uint32_t dst_global_row =
global_n_offset + tid_y * THREAD_TILE_DIM + k;
const uint32_t dst_global_col_start =
global_m_offset + tid_x * THREAD_TILE_DIM;

size_t offset = current_batch_offset +
static_cast<size_t>(dst_global_row) * M +
dst_global_col_start;
phi::float8_e4m3fn* dst_ptr = dst + offset;

fp8x8_t output_block;
const uint32_t smem_row = tid_y * THREAD_TILE_DIM + k;
const uint32_t smem_col = tid_x * THREAD_TILE_DIM / kVecSize; // = tid_x
output_block = smem[smem_row][smem_col];
output_block.store(dst_ptr);
}
}
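
// Net effect per batch: dst[n][m] = src[m][n]. Data-flow sketch: thread
// (x, y) loads the 8x8 sub-tile at (row 8*y, col 8*x) of its block's src
// tile, transposes it in registers, and the shared-memory exchange hands it
// to thread (y, x), which stores it, vectorized, at (row 8*x, col 8*y) of
// the block's dst tile.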

template <typename T, typename IndexType = int>
void dispatch_fp8_fast_transpose_kernel(const phi::GPUContext& d,
const T* input,
const uint32_t B,
const uint32_t M,
const uint32_t N,
T* output) {
dim3 grid, block;
  block.x = BLOCK_DIM;
  block.y = BLOCK_DIM;  // 16 x 16 = 256 threads per block

grid.z = B;
  grid.y = M / BLOCK_TILE_SIZE;  // assumes M % BLOCK_TILE_SIZE == 0
  grid.x = N / BLOCK_TILE_SIZE;  // assumes N % BLOCK_TILE_SIZE == 0

fp8_fast_transpose_kernel<<<grid, block, 0, d.stream()>>>(
input, output, B, M, N, static_cast<size_t>(M) * static_cast<size_t>(N));
}
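
// Launch sketch for a shape from the unit test below, (B, M, N) =
// (8, 7168, 4096): grid = (4096/128, 7168/128, 8) = (32, 56, 8) blocks of
// 256 threads, each transposing one 128x128 tile.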

// All tensors are assumed to have been collapsed to 3-D by this point, so
// only dims 1 and 2 need to be swapped.
template <typename T, typename IndexType = int>
void SendSwapDim1And2InTranspose(const phi::GPUContext& d,
const T* input,
const Dim3<IndexType>& input_dims,
T* output) {
// FP8 fast path
if constexpr (std::is_same<T, phi::float8_e4m3fn>::value) {
    if (input_dims[1] >= BLOCK_TILE_SIZE && input_dims[2] >= BLOCK_TILE_SIZE &&
        input_dims[1] % BLOCK_TILE_SIZE == 0 &&
        input_dims[2] % BLOCK_TILE_SIZE == 0) {
dispatch_fp8_fast_transpose_kernel<T, IndexType>(
d, input, input_dims[0], input_dims[1], input_dims[2], output);
return;
}
}
  // Assumes the tile size is > 16.
static const int kMinTileSize = 16;
static const int kMinNarrowTileSize = 96;
34 changes: 34 additions & 0 deletions test/legacy_test/test_transpose_op.py
@@ -224,6 +224,40 @@ def test_check_grad(self):
)


@unittest.skipIf(
not paddle.base.core.is_compiled_with_cuda()
or paddle.device.cuda.get_device_capability()[0] < 9.0,
"core is not compiled with CUDA or not support native fp8",
)
class TestFP8FastTranspose(unittest.TestCase):
def setUp(self):
self.dtype = paddle.float8_e4m3fn
self.test_cases = [
{"shape": (7168, 16384), "perm": [1, 0], "name": "2D(7168,16384)"},
{
"shape": (8, 7168, 4096),
"perm": [0, 2, 1],
"name": "3D(8,7168,4096)",
},
{
"shape": (8, 2048, 7168),
"perm": [0, 2, 1],
"name": "3D(8,2048,7168)",
},
]
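        # All transposed dims are multiples of 128, so each case exercises
        # the FP8 fast-transpose path.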

def test_verify_transpose(self):
paddle.disable_static()
with paddle.no_grad():
for case in self.test_cases:
x = paddle.randn(case["shape"]).cast(self.dtype)
np_data = x.numpy()
gold = np.transpose(np_data, case["perm"])
out = paddle.transpose(x, case["perm"]).contiguous()
np.testing.assert_equal(out.numpy(), gold)
paddle.enable_static()


class TestAutoTuneTransposeFP16Op(OpTest):
def setUp(self):
self.init_op_type()