Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/tilegym/ops/cutile/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# NN operations
from . import activation
from . import attention
from . import dropout
from . import flash_decode
from . import group_gemm
from . import matmul
Expand Down Expand Up @@ -59,6 +60,7 @@
"get_rms_norm_module",
"rms_norm",
"silu_and_mul",
"dropout",
"softmax",
"mla_decoding_split_kv",
"moe",
Expand Down
166 changes: 166 additions & 0 deletions src/tilegym/ops/cutile/dropout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: MIT

import math

import cuda.tile as ct
import torch

from tilegym.backend import register_impl


@ct.kernel
def dropout_kernel_ct(
    x,
    output,
    p: ct.Constant[float],
    seed: ct.Constant[int],
    TILE_SIZE: ct.Constant[int],
    training: ct.Constant[bool],
):
    """
    cuTile kernel for dropout over a flat 1D tensor.

    Each program instance handles one contiguous TILE_SIZE-element chunk.
    The keep/drop mask is generated statelessly from (seed, element index),
    so the same seed always reproduces the same mask.

    Args:
        x: Input tensor (flattened to 1D by the caller)
        output: Output tensor (same flat layout as x)
        p: Dropout probability
        seed: Random seed folded into the per-element hash
        TILE_SIZE: Number of elements processed per program instance
        training: If False the kernel degenerates to a copy
    """
    # Each block handles elements [bid*TILE_SIZE, (bid+1)*TILE_SIZE).
    bid = ct.bid(0)
    tile_start = bid * TILE_SIZE

    # Global element indices covered by this tile.
    offsets = ct.add(ct.arange(TILE_SIZE, dtype=ct.int32), tile_start)
    # Load input data using gather.
    # For 1D arrays, indices are passed directly (not as a tuple).
    # Use padding_value=0 (int) to avoid dtype mismatch with float16
    # (presumably applied to out-of-range lanes of the last, partially
    # filled tile -- confirm against cuTile's gather semantics).
    x_tile = ct.gather(x, offsets, padding_value=0)

    # Dropped elements keep this zero value.
    output_tile = ct.zeros((TILE_SIZE,), dtype=x_tile.dtype)

    # Only apply dropout if training.
    if training:
        # Generate pseudo-random numbers with a deterministic hash of
        # (offset, seed); used because cuTile has no equivalent of tl.rand.
        combined = ct.add(
            ct.mul(offsets, 1103515245),  # glibc rand() LCG multiplier (not prime)
            ct.full((TILE_SIZE,), seed, dtype=ct.int32),
        )

        # Xorshift-style mixing so adjacent offsets decorrelate.
        # NOTE(review): the left shift can overflow int32 -- this relies on
        # wrap-around arithmetic; confirm cuTile int32 ops wrap.
        hash_val = ct.bitwise_xor(combined, ct.bitwise_rshift(combined, 16))
        hash_val = ct.bitwise_xor(hash_val, ct.bitwise_lshift(hash_val, 8))
        hash_val = ct.bitwise_xor(hash_val, ct.bitwise_rshift(hash_val, 4))

        # Mask off the sign bit and normalize to a float in [0, 1).
        hash_float = ct.truediv(
            ct.astype(ct.bitwise_and(hash_val, 0x7FFFFFFF), ct.float32),
            2147483647.0,  # 2^31 - 1
        )

        # Keep an element when its uniform draw exceeds p, i.e. P(keep) ~ 1-p.
        keep_mask = ct.greater(hash_float, p)

        # Inverted dropout: survivors are scaled by 1/(1-p) so the expected
        # value of the output matches the input.
        scale = ct.full((TILE_SIZE,), 1.0 / (1.0 - p), dtype=x_tile.dtype)
        scaled_x = ct.mul(x_tile, scale)
        output_tile = ct.where(keep_mask, scaled_x, output_tile)
    else:
        # In inference mode, just copy input to output.
        output_tile = x_tile

    # Write the tile back to global memory.
    ct.scatter(output, offsets, output_tile)


class Dropout_CT(torch.autograd.Function):
    """Autograd wrapper around the cuTile dropout kernel (forward only)."""

    @staticmethod
    def forward(ctx, x, seed, p=0.5, training=True, inplace=False):
        """
        Forward pass for seed-driven dropout.

        Args:
            ctx: Autograd context
            x: Input tensor; must be contiguous when training
            seed: Random seed used for stateless mask generation
            p: Dropout probability
            training: Whether in training mode
            inplace: Whether to write the result into ``x``

        Returns:
            Output tensor with dropout applied (``x`` itself when not training)
        """
        if not training:
            # Inference is a no-op: return the input untouched.
            # Fix: do NOT call ctx.mark_dirty(x) here -- the tensor is not
            # modified in-place, and marking an unmodified input dirty makes
            # autograd reject leaf tensors that require grad.
            return x

        if inplace:
            # The kernel gathers a full tile before scattering it back, so
            # reading and writing the same buffer is safe.
            ctx.mark_dirty(x)
            output = x
        else:
            output = torch.empty_like(x)

        # The kernel addresses the tensor as a flat contiguous 1D array.
        assert x.is_contiguous()

        n_elements = x.numel()

        # One program instance per TILE_SIZE-element chunk.
        TILE_SIZE = 1024
        grid = (math.ceil(n_elements / TILE_SIZE), 1, 1)

        # Flatten to 1D views for the kernel.
        x_flat = x.view(-1)
        output_flat = output.view(-1)

        # Fold the seed into the non-negative int32 range the kernel expects.
        seed_int32 = int(seed) % 2147483647

        ct.launch(
            torch.cuda.current_stream(),
            grid,
            dropout_kernel_ct,
            (
                x_flat,
                output_flat,
                p,
                seed_int32,
                TILE_SIZE,
                training,
            ),
        )

        # Stashed for a future backward implementation (the mask is
        # recomputable from seed and p).
        ctx.p = p
        ctx.seed = seed
        return output

    @staticmethod
    def backward(ctx, dy):
        raise NotImplementedError("Backward pass for dropout is not implemented")


@register_impl("dropout", backend="cutile")
def dropout(x, seed, p=0.5, training=True, inplace=False, **kwargs):
    """
    cuTile backend for the ``dropout`` op.

    Thin shim that forwards to the ``Dropout_CT`` autograd function; the
    random mask is derived entirely from ``seed``.

    Args:
        x: Input tensor
        seed: Integer seed used to build the dropout mask in the kernel
        p: Drop probability, default is 0.5
        training: If True perform dropout, else return x unchanged
        inplace: If True, modify x directly with dropout
        **kwargs: Accepted for backend-interface compatibility; unused here

    Returns:
        Tensor with dropout applied
    """
    del kwargs  # interface compatibility only
    return Dropout_CT.apply(x, seed, p, training, inplace)
28 changes: 28 additions & 0 deletions src/tilegym/ops/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,34 @@ def silu_and_mul(
raise NotImplementedError(f"silu_and_mul is not implemented for {get_current_backend()}")


@dispatch(
    "dropout",
)
def dropout(
    x: torch.Tensor,
    seed: int,
    p: float = 0.5,
    training: bool = True,
    inplace: bool = False,
    **kwargs: Any,
):
    """
    Apply stateless, seed-driven dropout to ``x``.

    The dropout mask is derived from ``seed`` rather than global RNG state,
    so calls with the same seed reproduce the same mask.

    Args:
        x: Input tensor
        seed: Integer seed used to generate the dropout mask in the kernel
        p: Drop probability, default is 0.5
        training: When True apply dropout; when False return the input as-is
        inplace: When True, write the result into ``x`` directly
        **kwargs: Additional arguments for backend-specific configurations

    Returns:
        Tensor of the same shape as ``x`` with dropout applied
    """
    raise NotImplementedError(f"dropout is not implemented for {get_current_backend()}")


@dispatch(
"softmax",
)
Expand Down
1 change: 1 addition & 0 deletions tests/benchmark/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ python bench_matrix_multiplication.py
```

Available benchmark scripts:
- `bench_dropout.py`
- `bench_fused_attention.py`
- `bench_matrix_multiplication.py`
- `bench_mix_triton_cutile.py`
Expand Down
113 changes: 113 additions & 0 deletions tests/benchmark/bench_dropout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
#!/usr/bin/env python3

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: MIT

import torch
import triton
import triton.testing

import tilegym
from tilegym.backend import is_backend_available
from tilegym.backend import register_impl

# Benchmark on whichever torch device the active triton driver targets.
DEVICE = triton.runtime.driver.active.get_active_torch_device()

torch.manual_seed(0)  # For reproducibility


def reference_dropout(
    x: torch.Tensor,
    seed: int,
    p: float = 0.5,
    training: bool = True,
    inplace: bool = False,
    **kwargs,
):
    """PyTorch reference dropout used as the benchmark baseline.

    ``seed`` is accepted only to mirror the tilegym interface; PyTorch's
    dropout draws its mask from the global RNG instead.
    """
    del seed, kwargs  # interface compatibility only
    return torch.nn.functional.dropout(x, p=p, training=training, inplace=inplace)


# Expose the PyTorch reference through the tilegym dispatcher so the benchmark
# can drive both implementations through the same call site.
register_impl("dropout", "torch")(reference_dropout)


# Available backends for benchmarking.
# Each entry is (backend_id, display_name, (line_color, line_style)); an
# unavailable backend contributes None, filtered out by get_supported_backends.
ALL_BACKENDS = [
    ("cutile", "CuTile", ("orange", "-")) if is_backend_available("cutile") else None,
    ("torch", "PyTorch", ("green", "-")),
]


def get_supported_backends():
    """Return only the ALL_BACKENDS entries whose backend is available."""
    return [entry for entry in ALL_BACKENDS if entry is not None]


def create_benchmark_config(datatype, p: float):
    """Build a triton Benchmark definition for dropout throughput.

    Returns None when no backend is available, so the caller can skip it.
    """
    supported = get_supported_backends()
    if not supported:
        return None

    backend_ids, display_names, line_styles = zip(*supported)
    # 'torch.float16' -> 'float16' for a readable plot name
    dtype_name = str(datatype).rsplit(".", 1)[-1]

    return triton.testing.Benchmark(
        x_names=["M"],
        x_vals=[2**i for i in range(20, 28, 2)],
        line_arg="backend",
        line_vals=list(backend_ids),
        line_names=list(display_names),
        styles=list(line_styles),
        ylabel="GB/s",
        plot_name=f"dropout-p{p}-{dtype_name}-GBps",
        args={
            "p": p,
            "datatype": datatype,
        },
    )


@triton.testing.perf_report(
    [create_benchmark_config(datatype, p) for datatype in [torch.float16, torch.float32] for p in [0.5]]
)
def bench_dropout(
    M,
    backend,
    p,
    datatype,
    device=DEVICE,
):
    """Measure dropout memory bandwidth (GB/s) for one backend/size/dtype."""
    seed = torch.random.initial_seed()

    # Input tensor of M elements; no autograd needed for a forward benchmark.
    x = torch.rand(M, device=device, dtype=datatype, requires_grad=False)

    training = True
    inplace = False

    def run():
        return tilegym.ops.dropout(x, seed, p, training, inplace, backend=backend)

    # Light sanity check: the observed drop ratio should be close to p.
    out = run()
    zero_ratio = 1 - torch.count_nonzero(out) / torch.numel(out)
    threshold = 0.04
    assert p - threshold < zero_ratio < p + threshold, f"Unexpected dropout ratio {zero_ratio} for p={p}"

    # Benchmark the forward call under CUDA graphs.
    ms = triton.testing.do_bench_cudagraph(run)

    # Bandwidth model: forward dropout reads the input once and writes the
    # output once.
    num_elements = x.numel()
    bytes_per_element = x.element_size()
    total_bytes = 2 * num_elements * bytes_per_element
    gb_per_s = total_bytes * 1e-9 / (ms * 1e-3)

    return gb_per_s


if __name__ == "__main__":
bench_dropout.run(print_data=True)
Loading