diff --git a/feature_transformer.py b/feature_transformer.py
index c9ee33f0..9da3538f 100644
--- a/feature_transformer.py
+++ b/feature_transformer.py
@@ -1,317 +1,203 @@
+import math
 import torch
+import triton
+import triton.language as tl
+
 from torch import nn
 from torch import autograd
-import cupy as cp
-import math
-
-def _find_nearest_divisor(value, target):
-    divisors = []
-    for i in range(1, value + 1):
-        if value % i == 0:
-            divisors.append((i, abs(target - i)))
-    divisors.sort(key=lambda x: x[1])
-    return divisors[0][0]
-
-_num_threads_forward_cache = dict()
-
-
-def _get_num_threads_for_forward(output_size):
-    optimal_num_threads = 512
-    if output_size not in _num_threads_forward_cache:
-        _num_threads_forward_cache[output_size] = _find_nearest_divisor(
-            output_size, optimal_num_threads
+# Forward kernel. For every batch row b it computes
+#     output[b] = bias + sum_k weight[feature_indices[b, k]] * feature_values[b, k]
+# on a (batch_size, cdiv(output_size, OUTPUT_BLOCK_SIZE)) grid; each program
+# produces one OUTPUT_BLOCK_SIZE-wide block of one output row. Buffers are
+# addressed as contiguous row-major arrays. A feature index of -1 ends a row's
+# active features; the slots after it are ignored. Autotuning needs no
+# restore_value here: every trial launch fully overwrites its output block.
+@triton.autotune(
+    configs=[
+        triton.Config({"OUTPUT_BLOCK_SIZE": 32}),
+        triton.Config({"OUTPUT_BLOCK_SIZE": 64}),
+        triton.Config({"OUTPUT_BLOCK_SIZE": 128}),
+        triton.Config({"OUTPUT_BLOCK_SIZE": 256}),
+        triton.Config({"OUTPUT_BLOCK_SIZE": 512}),
+        triton.Config({"OUTPUT_BLOCK_SIZE": 1024}),
+        triton.Config({"OUTPUT_BLOCK_SIZE": 2048}),
+    ],
+    key=["max_active_features", "output_size"]
+)
+@triton.jit
+def _feature_transformer_slice_forward_kernel(
+    feature_indices,
+    feature_values,
+    weight,
+    bias,
+    output,
+    max_active_features: tl.constexpr,
+    output_size: tl.constexpr,
+    OUTPUT_BLOCK_SIZE: tl.constexpr
+):
+    batch_idx = tl.program_id(0)
+    output_block_idx = tl.program_id(1)
+
+    output_offsets = OUTPUT_BLOCK_SIZE * output_block_idx + tl.arange(0, OUTPUT_BLOCK_SIZE)
+    output_mask = output_offsets < output_size
+
+    feature_indices_slice = feature_indices + batch_idx * max_active_features
+    feature_values_slice = feature_values + batch_idx * max_active_features
+    output_slice = output + batch_idx * output_size
+
+    # Start from the bias and accumulate in float32.
+    acc = tl.load(bias + output_offsets, mask=output_mask, other=0.0)
+    acc = acc.to(tl.float32)
+
+    # The flag stands in for an early exit: once a -1 index has been seen,
+    # the remaining iterations do nothing.
+    past_active_features = False
+    for k in range(max_active_features):
+        if not past_active_features:
+            feature_idx = tl.load(feature_indices_slice + k)
+            if feature_idx == -1:
+                past_active_features = True
+            else:
+                curr_feature_values = tl.load(feature_values_slice + k)
+                curr_weight_values = tl.load(weight + feature_idx * output_size + output_offsets, mask=output_mask, other=0.0)
+                acc += curr_weight_values * curr_feature_values
+
+    tl.store(output_slice + output_offsets, acc, mask=output_mask)
+
+
+def feature_transformer_slice_forward(
+    feature_indices,
+    feature_values,
+    weight,
+    bias,
+    output,
+    batch_size,
+    max_active_features,
+    output_size
+):
+    """Launch the forward kernel and write its result into ``output``.
+
+    The grid is ``(batch_size, cdiv(output_size, OUTPUT_BLOCK_SIZE))``, with
+    ``OUTPUT_BLOCK_SIZE`` chosen by the kernel's autotuner. Returns ``None``.
+    """
+    def grid(meta):
+        return (batch_size, triton.cdiv(output_size, meta["OUTPUT_BLOCK_SIZE"]))
+
+    _feature_transformer_slice_forward_kernel[grid](
+        feature_indices=feature_indices,
+        feature_values=feature_values,
+        weight=weight,
+        bias=bias,
+        output=output,
+        max_active_features=max_active_features,
+        output_size=output_size,
+    )
+
+
+# Backward kernel: scatters output_grad into bias_grad and weight_grad using the
+# same grid and row layout as the forward kernel. Programs for different batch
+# rows target the same gradient addresses, hence the atomic adds. Both buffers
+# are accumulated into, so they must be zeroed before the first launch.
+# Lanes whose incoming gradient is exactly zero issue no atomics.
+@triton.autotune(
+    configs=[
+        triton.Config({"OUTPUT_BLOCK_SIZE": 8}),
+        triton.Config({"OUTPUT_BLOCK_SIZE": 16}),
+        triton.Config({"OUTPUT_BLOCK_SIZE": 32}),
+        triton.Config({"OUTPUT_BLOCK_SIZE": 64}),
+        triton.Config({"OUTPUT_BLOCK_SIZE": 128}),
+        triton.Config({"OUTPUT_BLOCK_SIZE": 256}),
+        triton.Config({"OUTPUT_BLOCK_SIZE": 512}),
+        triton.Config({"OUTPUT_BLOCK_SIZE": 1024}),
+    ],
+    key=["max_active_features", "output_size"],
+    # The kernel accumulates in place, so autotuner trial launches would
+    # otherwise be added on top of the real gradients. Snapshot both buffers
+    # before each trial and restore them afterwards.
+    restore_value=["bias_grad", "weight_grad"],
+)
+@triton.jit
+def _feature_transformer_slice_backward_kernel(
+    feature_indices,
+    feature_values,
+    bias_grad,
+    weight_grad,
+    output_grad,
+    max_active_features: tl.constexpr,
+    output_size: tl.constexpr,
+    OUTPUT_BLOCK_SIZE: tl.constexpr
+):
+    batch_idx = tl.program_id(0)
+    output_block_idx = tl.program_id(1)
+
+    output_offsets = OUTPUT_BLOCK_SIZE * output_block_idx + tl.arange(0, OUTPUT_BLOCK_SIZE)
+    output_mask = output_offsets < output_size
+
+    feature_indices_slice = feature_indices + batch_idx * max_active_features
+    feature_values_slice = feature_values + batch_idx * max_active_features
+
+    output_grad_slice = output_grad + batch_idx * output_size
+    output_grad_values = tl.load(output_grad_slice + output_offsets, mask=output_mask, other=0.0)
+    # Only in-range lanes with a non-zero gradient take part in the atomics.
+    nonzero_grad_mask = output_mask & (output_grad_values != 0)
+
+    tl.atomic_add(
+        bias_grad + output_offsets,
+        output_grad_values,
+        mask=nonzero_grad_mask
+    )
+
+    past_active_features = False
+    k = 0
+    while k < max_active_features and not past_active_features:
+        feature_idx = tl.load(feature_indices_slice + k)
+        if feature_idx == -1:
+            past_active_features = True
+        else:
+            curr_feature_values = tl.load(feature_values_slice + k)
+            curr_weight_grad_values = output_grad_values * curr_feature_values
+            tl.atomic_add(
+                weight_grad + feature_idx * output_size + output_offsets,
+                curr_weight_grad_values,
+                mask=nonzero_grad_mask
+            )
+        k += 1
+
+
+def feature_transformer_slice_backward(
+    feature_indices,
+    feature_values,
+    bias_grad,
+    weight_grad,
+    output_grad,
+    batch_size,
+    max_active_features,
+    output_size
+):
+    """Launch the backward kernel, accumulating into the gradient buffers.
+
+    ``bias_grad`` and ``weight_grad`` are added to in place, so they must
+    already hold the values to accumulate onto (zeros on the first call).
+    Returns ``None``.
+    """
+    def grid(meta):
+        return (
+            batch_size,
+            triton.cdiv(output_size, meta['OUTPUT_BLOCK_SIZE'])
         )
-    return _num_threads_forward_cache[output_size]
-
-
-_num_threads_backward_cache = dict()
-
+    _feature_transformer_slice_backward_kernel[grid](
+        feature_indices=feature_indices,
+        feature_values=feature_values,
+        weight_grad=weight_grad,
+        bias_grad=bias_grad,
+        output_grad=output_grad,
+        max_active_features=max_active_features,
+        output_size=output_size
+    )
-def _get_num_threads_for_backward(output_size):
-    optimal_num_threads = 512
-    if output_size not in _num_threads_backward_cache:
-        _num_threads_backward_cache[output_size] = _find_nearest_divisor(
-            output_size, optimal_num_threads
-        )
-
-    return _num_threads_backward_cache[output_size]
-
-
-def _kernel_with_threads(kernel, threads):
-    def f(grid, args):
-        kernel(grid=grid, block=threads, args=args)
-
-    return f
-
-
-_feature_transformer_slice_forward_kernel_cache = dict()
-
-
-@torch.compiler.disable(recursive=False)
-def
make_feature_transformer_slice_forward_kernel(max_active_features, output_size): - """ - @param: max_active_features - The maximum number of features that are active - (non-zero) for a single position. This value determines - the shape of the inputs. - This value is of type uint32_t. - - @param: output_size - The number of outputs. Must match the shape of weights - and biases. - This value is of type uint32. - """ - num_threads = _get_num_threads_for_forward(output_size) - output_thread_slice_size = output_size // num_threads - key = (max_active_features, output_size, num_threads) - if key not in _feature_transformer_slice_forward_kernel_cache: - kernel = cp.RawKernel( - r""" - -typedef unsigned int uint32_t; -typedef int int32_t; - -extern "C" __global__ - -/* - @assumptions: - The blocks must have dimensionality (BATCH_SIZE,) - The threads must have dimensionality (N,), where - N * output_thread_slice_size == output_size. - - @param: feature_indices - A matrix of shape (BATCH_SIZE, max_active_features) - containing indices of active features for each position - in a batch. Feature index of -1 means that the slot is empty - and the weights will not be accumulated for it. Moreover - no further indices from this block will be considered. - The indices form an implicit matrix of shape - (BATCH_SIZE, NUM_INPUTS), where the first dimension index is - inferred from the memory location (BATCH_SIZE), and the - second dimension index is stored in the feature_indices matrix. - The type for feature indices is int32_t. - - @param: feature_values - A matrix of shape (BATCH_SIZE, max_active_features) - containing the values (arity) of the corresponding - feature index in feature_indices. - The type for the feature value (arity) is float32. - - @param: weight - The weight matrix of shape (NUM_INPUTS, output_size). - Weights must be of type float32. - - @param: bias - The bias vector of shape (output_size,). - Bias values must be of type float32. 
- - @param: output - An output matrix of shape (BATCH_SIZE, output_size). - It may not be initialized, bias is always copied - to the output first. - Output values must have type float32. -*/ -void feature_transformer_slice_forward( - const int32_t* const feature_indices, - const float* const feature_values, - const float* const weight, - const float* const bias, - float* const output -) {{ - __shared__ - float shared_output[{output_size}]; - - const uint32_t block_idx = blockIdx.x; - const uint32_t slice_offset = threadIdx.x * {output_thread_slice_size}; - - float* const output_slice = output + block_idx * {output_size} + slice_offset; - const float* const bias_slice = bias + slice_offset; - float* shared_output_slice = shared_output + slice_offset; - - const int32_t* const feature_index_row = feature_indices + block_idx * {max_active_features}; - const float* const feature_value_row = feature_values + block_idx * {max_active_features}; - - #pragma unroll - for (uint32_t s = 0; s < {output_thread_slice_size}; ++s) - {{ - shared_output_slice[s] = bias_slice[s]; - }} - - for (uint32_t k = 0; k < {max_active_features}; ++k) - {{ - const int32_t feature_index = feature_index_row[k]; - const float feature_value = feature_value_row[k]; - if (feature_index != -1) - {{ - const float* const weight_slice = weight + feature_index * {output_size} + slice_offset; - #pragma unroll - for (uint32_t s = 0; s < {output_thread_slice_size}; ++s) - {{ - shared_output_slice[s] += weight_slice[s] * feature_value; - }} - }} else break; - }} - - #pragma unroll - for (uint32_t s = 0; s < {output_thread_slice_size}; ++s) - {{ - output_slice[s] = shared_output_slice[s]; - }} -}} - -""".format( - max_active_features=max_active_features, - output_thread_slice_size=output_thread_slice_size, - output_size=output_size, - ), - "feature_transformer_slice_forward", - ) - kernel.compile() - _feature_transformer_slice_forward_kernel_cache[key] = _kernel_with_threads( - kernel, (num_threads,) - ) - 
return _feature_transformer_slice_forward_kernel_cache[key] - - -_feature_transformer_slice_backward_kernel_cache = dict() - - -@torch.compiler.disable(recursive=False) -def make_feature_transformer_slice_backward_kernel(max_active_features, output_size): - """' - @param: max_active_features - The maximum number of features that are active - (non-zero) for a single position. This value determines - the shape of the inputs. - This value is of type uint32_t. - - @param: output_size - The number of outputs. Must match the shape of weights - and biases. - This value is of type uint32. - """ - num_threads = _get_num_threads_for_backward(output_size) - output_thread_slice_size = output_size // num_threads - key = (max_active_features, output_size, num_threads) - if key not in _feature_transformer_slice_backward_kernel_cache: - kernel = cp.RawKernel( - r""" - -typedef unsigned int uint32_t; -typedef int int32_t; - -extern "C" __global__ -/* - @assumptions: - The blocks must have dimensionality (BATCH_SIZE,) - The threads must have dimensionality (N,), where - N * output_thread_slice_size == output_size. - - @param: feature_indices - A matrix of shape (BATCH_SIZE, max_active_features) - containing indices of active features for each position - in a batch. Feature index of -1 means that the slot is empty - and the weights will not be accumulated for it. Moreover - no further indices from this block will be considered. - The indices form an implicit matrix of shape - (BATCH_SIZE, NUM_INPUTS), where the first dimension index is - inferred from the memory location (BATCH_SIZE), and the - second dimension index is stored in the feature_indices matrix. - The type for feature indices is int32_t. - - @param: feature_values - A matrix of shape (BATCH_SIZE, max_active_features) - containing the values (arity) of the corresponding - feature index in feature_indices. - The type for the feature value (arity) is float32. 
- - @param: weight_grad - The weight gradient matrix of shape (NUM_INPUTS, output_size). - The gradient is accumulated, i.e. it must be zero initialized - on the first call. - Weights must be of type float32. - - @param: bias_grad - The bias gradient vector of shape (output_size,). - The gradient is accumulated, i.e. it must be zero initialized - on the first call. - Bias values must be of type float32. - - @param: output_grad - An output gradient matrix of shape (BATCH_SIZE, output_size). - Output values must have type float32. -*/ -void feature_transformer_slice_backward( - const int32_t* const feature_indices, - const float* const feature_values, - float* const weight_grad, - float* const bias_grad, - const float* const output_grad -) {{ - __shared__ - float shared_output_grad[{output_size}]; - - const uint32_t block_idx = blockIdx.x; - const uint32_t slice_offset = threadIdx.x * {output_thread_slice_size}; - - const float* const output_grad_slice = output_grad + block_idx * {output_size} + slice_offset; - float* const bias_grad_slice = bias_grad + slice_offset; - float* shared_output_grad_slice = shared_output_grad + slice_offset; - - const int32_t* const feature_index_row = feature_indices + block_idx * {max_active_features}; - const float* const feature_value_row = feature_values + block_idx * {max_active_features}; - - #pragma unroll - for (uint32_t s = 0; s < {output_thread_slice_size}; ++s) - {{ - shared_output_grad_slice[s] = output_grad_slice[s]; - }} - - #pragma unroll - for (uint32_t s = 0; s < {output_thread_slice_size}; ++s) - {{ - const float sog = shared_output_grad_slice[s]; - if (sog != 0.0f) - {{ - atomicAdd(&bias_grad_slice[s], sog); - }} - }} - - for (uint32_t k = 0; k < {max_active_features}; ++k) - {{ - const int32_t feature_index = feature_index_row[k]; - const float feature_value = feature_value_row[k]; - if (feature_index != -1) - {{ - float* const weight_grad_slice = weight_grad + feature_index * {output_size} + slice_offset; - #pragma 
unroll - for (int s = 0; s < {output_thread_slice_size}; ++s) - {{ - const float sog = shared_output_grad_slice[s]; - if (sog != 0.0f) - {{ - atomicAdd(&weight_grad_slice[s], sog * feature_value); - }} - }} - }} else break; - }} -}} - -""".format( - max_active_features=max_active_features, - output_thread_slice_size=output_thread_slice_size, - output_size=output_size, - ), - "feature_transformer_slice_backward", - ) - kernel.compile() - _feature_transformer_slice_backward_kernel_cache[key] = _kernel_with_threads( - kernel, (num_threads,) - ) - return _feature_transformer_slice_backward_kernel_cache[key] class FeatureTransformerSliceFunction(autograd.Function): @@ -352,25 +207,21 @@ def forward(ctx, feature_indices, feature_values, weight, bias): output_size = weight.shape[1] output = torch.empty( - batch_size, - output_size, + (batch_size, output_size), dtype=torch.float32, device=device, requires_grad=True, ) - kernel = make_feature_transformer_slice_forward_kernel( - max_active_features, output_size - ) - kernel( - grid=(batch_size,), - args=( - feature_indices.data_ptr(), - feature_values.data_ptr(), - weight.data_ptr(), - bias.data_ptr(), - output.data_ptr(), - ), + feature_transformer_slice_forward( + feature_indices=feature_indices, + feature_values=feature_values, + weight=weight, + bias=bias, + output=output, + batch_size=batch_size, + max_active_features=max_active_features, + output_size=output_size ) return output @@ -394,18 +245,15 @@ def backward(ctx, grad_output): ) bias_grad = torch.zeros(output_size, dtype=torch.float32, device=device) - kernel = make_feature_transformer_slice_backward_kernel( - max_active_features, output_size - ) - kernel( - grid=(batch_size,), - args=( - feature_indices.data_ptr(), - feature_values.data_ptr(), - weight_grad.data_ptr(), - bias_grad.data_ptr(), - grad_output.data_ptr(), - ), + feature_transformer_slice_backward( + feature_indices=feature_indices, + feature_values=feature_values, + weight_grad=weight_grad, + 
bias_grad=bias_grad, + output_grad=grad_output, + batch_size=batch_size, + max_active_features=max_active_features, + output_size=output_size ) return None, None, weight_grad, bias_grad @@ -476,47 +324,42 @@ def forward( max_active_features = feature_indices_0.shape[1] output_size = weight.shape[1] - output0 = torch.empty( + output_0 = torch.empty( batch_size, output_size, dtype=torch.float32, device=device, - requires_grad=True, ) - output1 = torch.empty( + output_1 = torch.empty( batch_size, output_size, dtype=torch.float32, device=device, - requires_grad=True, ) - kernel = make_feature_transformer_slice_forward_kernel( - max_active_features, output_size - ) - kernel( - grid=(batch_size,), - args=( - feature_indices_0.data_ptr(), - feature_values_0.data_ptr(), - weight.data_ptr(), - bias.data_ptr(), - output0.data_ptr(), - ), + feature_transformer_slice_forward( + feature_indices=feature_indices_0, + feature_values=feature_values_0, + weight=weight, + bias=bias, + output=output_0, + batch_size=batch_size, + max_active_features=max_active_features, + output_size=output_size ) - kernel( - grid=(batch_size,), - args=( - feature_indices_1.data_ptr(), - feature_values_1.data_ptr(), - weight.data_ptr(), - bias.data_ptr(), - output1.data_ptr(), - ), + feature_transformer_slice_forward( + feature_indices=feature_indices_1, + feature_values=feature_values_1, + weight=weight, + bias=bias, + output=output_1, + batch_size=batch_size, + max_active_features=max_active_features, + output_size=output_size ) - return output0, output1 + return output_0, output_1 @staticmethod def backward(ctx, grad_output_0, grad_output_1): @@ -545,29 +388,26 @@ def backward(ctx, grad_output_0, grad_output_1): ) bias_grad = torch.zeros(output_size, dtype=torch.float32, device=device) - kernel = make_feature_transformer_slice_backward_kernel( - max_active_features, output_size - ) - kernel( - grid=(batch_size,), - args=( - feature_indices_0.data_ptr(), - feature_values_0.data_ptr(), - 
weight_grad.data_ptr(), - bias_grad.data_ptr(), - grad_output_0.data_ptr(), - ), + feature_transformer_slice_backward( + feature_indices=feature_indices_0, + feature_values=feature_values_0, + weight_grad=weight_grad, + bias_grad=bias_grad, + output_grad=grad_output_0, + batch_size=batch_size, + max_active_features=max_active_features, + output_size=output_size, ) - kernel( - grid=(batch_size,), - args=( - feature_indices_1.data_ptr(), - feature_values_1.data_ptr(), - weight_grad.data_ptr(), - bias_grad.data_ptr(), - grad_output_1.data_ptr(), - ), + feature_transformer_slice_backward( + feature_indices=feature_indices_1, + feature_values=feature_values_1, + weight_grad=weight_grad, + bias_grad=bias_grad, + output_grad=grad_output_1, + batch_size=batch_size, + max_active_features=max_active_features, + output_size=output_size, ) return None, None, None, None, weight_grad, bias_grad @@ -723,11 +563,9 @@ def bench(): output0, output1 = layer(indices0, values0, indices1, values1) - device = indices0.device - start = time.time() - for i in range(ITERS): + for _ in range(ITERS): output0, output1 = layer(indices0, values0, indices1, values1) output0 = torch.clamp(output0, 0.0, 1.0) output1 = torch.clamp(output1, 0.0, 1.0)