26 changes: 26 additions & 0 deletions examples/commons/utils/clear_tensor_data.py
@@ -0,0 +1,26 @@
import functools
from typing import Optional

import torch


@functools.lru_cache(maxsize=None)
def _empty_tensor() -> torch.Tensor:
"""Get tensor with no entries and no data"""
return torch.Tensor().cuda()


def clear_tensor_data(*tensors: Optional[torch.Tensor]) -> None:
"""
Trick to deallocate tensor memory when the delete operation does not
release the tensor due to PyTorch overrides.

Must be used carefully.
"""
for t in tensors:
if t is not None:
if hasattr(t, "clear"):
t.clear() # type: ignore
else:
t.data = _empty_tensor() # type: ignore
del t
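
A minimal usage sketch of this helper (not part of this PR; it assumes a CUDA device, and names and shapes are illustrative). Once a contiguous copy of the needed slice exists, the original buffer can be released:

import torch

from commons.utils.clear_tensor_data import clear_tensor_data

activations = torch.randn(1024, 1024, device="cuda")
# contiguous() on a strided view materializes a copy, so the original
# storage is no longer needed afterwards
kept = activations[:, :512].contiguous()
clear_tensor_data(activations)  # re-points .data at a cached empty tensor
del activations                 # drop the local reference as well
# `kept` still owns its own (copied) storage; the 1024x1024 buffer is freed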
8 changes: 8 additions & 0 deletions examples/hstu/benchmark/fused_hstu_layer_benchmark.py
@@ -81,6 +81,12 @@ def create_hstu_layer(
default=True,
required=False,
)
@click.option(
"--recompute-input-layernorm",
type=bool,
default=False,
required=False,
)
@click.option(
"--kernel-backend",
type=click.Choice(_backend_str_to_type.keys()),
@@ -120,6 +126,7 @@ def run(
async_wgrad,
dump_memory_snapshot,
num_layers,
recompute_input_layernorm,
):
log_layer_type = layer_type.upper()
layer_type = _layer_type_str_to_type[layer_type]
@@ -137,6 +144,7 @@
hstu_layer_type=layer_type,
learnable_input_layernorm=True,
async_wgrad=async_wgrad,
recompute_input_layernorm=recompute_input_layernorm,
)
hstu_blocks = [
create_hstu_layer(
7 changes: 5 additions & 2 deletions examples/hstu/benchmark_ranking.gin
@@ -19,7 +19,7 @@ BenchmarkDatasetArgs.feature_args = [

item_embedding/DynamicEmbeddingArgs.feature_names = ['item']
item_embedding/DynamicEmbeddingArgs.table_name = 'item'
item_embedding/DynamicEmbeddingArgs.item_vocab_size_or_capacity = 100000000 # gross 100M embedding rows
item_embedding/DynamicEmbeddingArgs.item_vocab_size_or_capacity = 50000000 # gross 50M embedding rows
item_embedding/DynamicEmbeddingArgs.item_vocab_gpu_capacity_ratio = 0.1
item_embedding/DynamicEmbeddingArgs.evict_strategy = 'lru'

@@ -46,8 +46,11 @@ NetworkArgs.kv_channels = 256
NetworkArgs.kernel_backend = 'cutlass'
NetworkArgs.layer_type = 'fused'

# recompute can incur a perf regression, but saves memory
NetworkArgs.recompute_input_layernorm = True

RankingArgs.prediction_head_arch = [
[512, 8],
512, 8,
]
RankingArgs.prediction_head_bias = True
RankingArgs.num_tasks = 8
14 changes: 14 additions & 0 deletions examples/hstu/configs/hstu_config.py
@@ -81,6 +81,10 @@ class HSTUConfig(TransformerConfig):
target_group_size (int): The size of the sub-candidate group where causal attention is applied only within a sub-group (usually in the case of ranking). Defaults to 1.
learnable_input_layernorm (bool): Flag to enable learnable input layernorm. Defaults to True.
residual (bool): Flag to enable residual connection. Defaults to True.
async_wgrad (bool): Flag to enable async wgrad. Defaults to False.
async_wgrad_stream (torch.cuda.Stream): Stream for async wgrad. Defaults to None.
async_wgrad_event (torch.cuda.Event): Event for async wgrad. Defaults to None.
recompute_input_layernorm (bool): Flag to recompute the input layernorm during the backward pass instead of saving its output. Defaults to False.
"""

position_encoding_config: Optional[PositionEncodingConfig] = None
@@ -98,6 +102,8 @@
async_wgrad: bool = False
async_wgrad_stream: Optional[torch.cuda.Stream] = None
async_wgrad_event: Optional[torch.cuda.Event] = None
# whether to recompute the input layernorm in the backward pass
recompute_input_layernorm: bool = False

def __post_init__(self):
super().__post_init__()
@@ -119,6 +125,7 @@ def get_hstu_config(
learnable_input_layernorm: bool = True,
residual: bool = True,
async_wgrad: bool = False,
recompute_input_layernorm: bool = False,
) -> HSTUConfig:
"""
Create the HSTU configuration.
@@ -134,6 +141,12 @@
norm_epsilon (float, optional): Epsilon value for normalization. Defaults to 1e-5.
is_causal (bool, optional): Whether the attention is causal. Defaults to False.
kernel_backend (KernelBackend, optional): Backend for kernel operations. Defaults to KernelBackend.CUTLASS.
target_group_size (int, optional): The size of the sub-candidate group where causal attention is applied only within a sub-group (usually in the case of ranking). Defaults to 1.
hstu_layer_type (HSTULayerType, optional): The type of HSTU layer. Defaults to HSTULayerType.FUSED.
learnable_input_layernorm (bool, optional): Whether to use learnable input layernorm. Defaults to True.
residual (bool, optional): Whether to add residual connection. Defaults to True.
async_wgrad (bool, optional): Whether to use async wgrad. Defaults to False.
recompute_input_layernorm (bool, optional): Whether to recompute the input layernorm during the backward pass. Defaults to False.

Returns:
HSTUConfig: The HSTU configuration object.
@@ -168,4 +181,5 @@
async_wgrad=async_wgrad,
async_wgrad_stream=async_wgrad_stream,
async_wgrad_event=async_wgrad_event,
recompute_input_layernorm=recompute_input_layernorm,
)
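
A sketch of enabling the new flag through this factory. The model-size arguments hidden by the collapsed diff are elided behind a comment, the import path is assumed, and the values shown are illustrative:

from configs import get_hstu_config  # import path assumed

hstu_config = get_hstu_config(
    # ... model-size arguments not shown in this diff ...
    norm_epsilon=1e-5,
    is_causal=True,
    learnable_input_layernorm=True,
    residual=True,
    async_wgrad=False,
    # trade extra backward compute for activation memory
    recompute_input_layernorm=True,
)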
3 changes: 2 additions & 1 deletion examples/hstu/modules/fused_hstu_layer.py
@@ -112,7 +112,7 @@ def __init__(self, config: HSTUConfig):
)
)
)

self._recompute_input_layernorm = config.recompute_input_layernorm
torch.nn.init.xavier_uniform_(self._linear_proj_weight)

@output_nvtx_hook(nvtx_tag="FusedHSTULayer", hook_tensor_attr_name="values")
@@ -148,6 +148,7 @@ def forward(self, jd: JaggedData) -> JaggedData:
residual=self._residual,
wgrad_stream=self._wgrad_stream,
wgrad_event=self._wgrad_event,
recompute_input_layernorm=self._recompute_input_layernorm,
)
return JaggedData(
values=output,
38 changes: 34 additions & 4 deletions examples/hstu/ops/fused_hstu_op.py
@@ -20,6 +20,7 @@
import hstu_hopper_cuda as flash_attn_cuda_hopper
import nvtx
import torch
from commons.utils.clear_tensor_data import clear_tensor_data
from configs import KernelBackend
from ops.pt_ops.torch_addmm import torch_addmm_silu_fwd
from ops.triton_ops.triton_addmm import triton_addmm_silu_bwd, triton_addmm_silu_fwd
@@ -84,6 +85,7 @@ def forward(
residual: bool = True,
wgrad_stream: Optional[torch.cuda.Stream] = None,
wgrad_event: Optional[torch.cuda.Event] = None,
recompute_input_layernorm: bool = False,
) -> torch.Tensor:
"""Forward pass of the fused HSTU layer.
Args:
@@ -130,6 +132,7 @@
ctx.residual = residual
ctx.wgrad_stream = wgrad_stream
ctx.wgrad_event = wgrad_event
ctx.recompute_input_layernorm = recompute_input_layernorm
saved_tensor_map = OrderedDict()
if num_contextuals is None and attn_backend == KernelBackend.TRITON:
num_contextuals = 0
@@ -196,10 +199,15 @@ def _ln_linear_silu_fwd(
# for gemm backward
saved_tensor_map.update(
{
"linear_uvqk_input": normed_input,
"linear_uvqk_input": normed_input
if not recompute_input_layernorm
else None,
"linear_uvqk_weight": linear_weight,
}
)
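# when recomputing, normed_input was not saved above; free its storage
# now and rebuild it during the backward pass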
if recompute_input_layernorm:
clear_tensor_data(normed_input)
del normed_input
saved_tensor_map.update(
{
"silu_input": linear_uvqk,
@@ -409,9 +417,12 @@ def _linear_residual_fwd(
tv = tv.view(-1, num_heads, linear_dim_per_head).contiguous()
tq = tq.view(-1, num_heads, attention_dim_per_head).contiguous()
tk = tk.view(-1, num_heads, attention_dim_per_head).contiguous()
# make a copy so that act_linear_uvqk can be freed below
tu = tu.contiguous()
clear_tensor_data(act_linear_uvqk)
# safe to delete: contiguous() above returned copies of the slices
# in the future, silu could be recomputed in the backward pass from saved_tensor_map['silu_input']
del act_linear_uvqk
with nvtx.annotate("hstu attn fwd", color="BLUE"):
if ctx.attn_backend == KernelBackend.CUTLASS:
# attn_output: [T, num_heads * attention_dim_per_head]
@@ -741,6 +752,7 @@ def _ln_linear_silu_bwd(
)

with nvtx.annotate("hstu attn bwd", color="BLUE"):
# TODO q,k,v can be recomputed via silu_fwd(saved_tensor_map['silu_input'])
if ctx.attn_backend == KernelBackend.CUTLASS:
grad_q, grad_k, grad_v = _hstu_attn_cutlass_bwd(
dout=grad_output.view(
@@ -773,7 +785,6 @@ def _ln_linear_silu_bwd(
causal=ctx.causal,
contextual_seq_len=ctx.contextual_seq_len, # saved_tensor_map["num_contexts"] == None,
)

with nvtx.annotate("ln_linear_silu bwd", color="RED"):
grad_q = grad_q.view(-1, ctx.num_heads * ctx.attention_dim_per_head)
grad_k = grad_k.view(-1, ctx.num_heads * ctx.attention_dim_per_head)
@@ -782,6 +793,22 @@ def _ln_linear_silu_bwd(
grad_output = torch.cat(
[grad_u, grad_v, grad_q, grad_k], dim=-1
).contiguous()
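# rebuild the layernorm output that was dropped in forward; the saved
# mean/rstd are reused so the statistics do not need to be recomputed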
if ctx.recompute_input_layernorm:
(
normed_input,
_,
_,
_,
_,
) = triton_weighted_layer_norm_fwd(
x=saved_tensor_map["input"],
weight=saved_tensor_map["input_ln_weight"],
bias=saved_tensor_map["input_ln_bias"],
eps=ctx.eps,
mean=saved_tensor_map["input_ln_mean"],
rstd=saved_tensor_map["input_ln_rstd"],
)
saved_tensor_map["linear_uvqk_input"] = normed_input

(
grad_input,
@@ -836,6 +863,7 @@ def _ln_linear_silu_bwd(
None,
None,
None,
None,
)


@@ -870,6 +898,7 @@ def fused_hstu_op(
residual: bool = True,
wgrad_stream: Optional[torch.cuda.Stream] = None,
wgrad_event: Optional[torch.cuda.Event] = None,
recompute_input_layernorm: bool = False,
):
out = FusedHSTULayerFunction.apply(
input,
@@ -898,6 +927,7 @@
residual,
wgrad_stream,
wgrad_event,
recompute_input_layernorm,
)

return out
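
The backward recompute above is an instance of a common activation-recomputation pattern: save the layernorm input and its per-row statistics, drop the normalized output in forward, and rebuild it in backward before the weight-gradient GEMM. A self-contained sketch of that pattern in plain PyTorch, assuming 2-D [tokens, features] inputs (class and variable names are hypothetical; the actual op uses the triton kernels above and reuses the saved mean/rstd):

import torch
import torch.nn.functional as F


class LNLinearWithRecompute(torch.autograd.Function):
    """Sketch: drop the layernorm output in forward, rebuild it in backward."""

    @staticmethod
    def forward(ctx, x, ln_weight, ln_bias, linear_weight, eps):
        normed = F.layer_norm(x, x.shape[-1:], ln_weight, ln_bias, eps)
        out = normed @ linear_weight.t()
        # x is needed for the layernorm backward anyway; by not saving
        # `normed` as well, one activation-sized buffer is avoided
        ctx.save_for_backward(x, ln_weight, ln_bias, linear_weight)
        ctx.eps = eps
        return out

    @staticmethod
    def backward(ctx, grad_out):
        x, ln_weight, ln_bias, linear_weight = ctx.saved_tensors
        with torch.enable_grad():
            x_ = x.detach().requires_grad_(True)
            w_ = ln_weight.detach().requires_grad_(True)
            b_ = ln_bias.detach().requires_grad_(True)
            # the recompute toggled by recompute_input_layernorm; the PR
            # additionally reuses saved mean/rstd instead of re-deriving them
            normed = F.layer_norm(x_, x_.shape[-1:], w_, b_, ctx.eps)
        grad_normed = grad_out @ linear_weight
        grad_linear_w = grad_out.t() @ normed.detach()  # wgrad of the linear
        grad_x, grad_ln_w, grad_ln_b = torch.autograd.grad(
            normed, (x_, w_, b_), grad_normed
        )
        return grad_x, grad_ln_w, grad_ln_b, grad_linear_w, None

# usage (shapes illustrative):
# out = LNLinearWithRecompute.apply(x, ln_w, ln_b, lin_w, 1e-5)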
7 changes: 4 additions & 3 deletions examples/hstu/test/test_hstu_op.py
@@ -182,7 +182,6 @@ def generate_or_copy_parameters(
@pytest.mark.parametrize(
"is_causal,kernel_backend",
[
(True, KernelBackend.TRITON),
(True, KernelBackend.CUTLASS),
(False, KernelBackend.CUTLASS),
],
@@ -407,6 +406,7 @@ def generate_input(from_scratch: bool = True):
@pytest.mark.parametrize("upcast_reference", [False])
@pytest.mark.parametrize("residual", [False])
@pytest.mark.parametrize("async_wgrad", [True, False])
@pytest.mark.parametrize("recompute_input_layernorm", [True, False])
def test_fused_hstu_op(
dtype: torch.dtype,
batchsize: int,
@@ -425,6 +425,7 @@
upcast_reference: bool,
residual: bool,
async_wgrad: bool,
recompute_input_layernorm: bool,
):
init.initialize_distributed()
init.set_random_seed(1234)
@@ -453,7 +454,6 @@

hstu_config.kernel_backend = KernelBackend.PYTORCH
hstu_config.dtype = torch.float32
hstu_config.hstu_layer_type = HSTULayerType.NATIVE
fp32_ref_hstu_layer = HSTULayer(hstu_config)
fp32_ref_hstu_layer.load_state_dict(ref_hstu_layer.state_dict())
fp32_ref_hstu_layer.cuda()
@@ -592,6 +592,7 @@ def test_fused_hstu_op(
residual=residual,
wgrad_stream=None,
wgrad_event=None,
recompute_input_layernorm=recompute_input_layernorm,
)
ref_out = ref_out.values
fp32_ref_out = fp32_ref_out.values
@@ -649,7 +650,7 @@
),
}
)
for tensor_name, (grad_ref, grad_fused, fp32_ref_grad) in reversed(
for tensor_name, (grad_fused, grad_ref, fp32_ref_grad) in reversed(
grad_to_compared.items()
):
print(
3 changes: 3 additions & 0 deletions examples/hstu/utils.py
@@ -170,6 +170,8 @@ class NetworkArgs:

num_position_buckets: int = 8192

recompute_input_layernorm: bool = False

def __post_init__(self):
assert self.dtype_str in [
"bfloat16",
@@ -240,6 +242,7 @@ def create_hstu_config(network_args: NetworkArgs):
position_encoding_config=position_encoding_config,
target_group_size=network_args.target_group_size,
hstu_layer_type=layer_type,
recompute_input_layernorm=network_args.recompute_input_layernorm,
)


Expand Down