
Introduce block_softmax_adjustment kernel (#163) #263

Draft pull request; wants to merge 1 commit into base branch deepseek_r1.
6 changes: 6 additions & 0 deletions vllm_hpu_extension/flags.py
@@ -11,6 +11,7 @@

 from vllm_hpu_extension.environment import get_environment
 from vllm_hpu_extension.kernels import fsdpa
+from vllm_hpu_extension.kernels import block_softmax_adjustment


 detected = None
@@ -160,6 +161,11 @@ def enabled_flags():
                                  & ModelType("llama")
                                  & Not(EnvFlag("VLLM_PROMPT_USE_FUSEDSDPA", "false"))
                                  & EnvFlag("VLLM_PROMPT_USE_FLEX_ATTENTION", "false")),
+        "fused_block_softmax_adjustment": (Not(Hardware("cpu"))
+                                           & VersionRange(">=1.22.0.101")
+                                           & Kernel(block_softmax_adjustment)
+                                           & EnvFlag("VLLM_FUSED_BLOCK_SOFTMAX_ADJUSTMENT",
+                                                     Not(ModelType('qwen2')) & Hardware("gaudi3"))),
     }
     environment = get_environment()
     detected = Flags(supported_flags, environment)
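As a quick orientation for reviewers: the new entry only turns on when every condition holds (non-CPU hardware, VersionRange(">=1.22.0.101"), the kernel loader resolving, and the VLLM_FUSED_BLOCK_SOFTMAX_ADJUSTMENT env flag, whose default is Gaudi3-and-not-qwen2). A minimal sketch of how the flag is consumed, assuming enabled_flags() stays importable from vllm_hpu_extension.flags as the hunk header suggests:

# Hedged sketch, not part of this PR's diff: mirrors the membership check
# that pipelined_pa() performs in the ops.py hunk below.
from vllm_hpu_extension.flags import enabled_flags  # assumed import path

def fused_adjustment_enabled() -> bool:
    return "fused_block_softmax_adjustment" in enabled_flags()

Keeping the kernel/version/env logic behind the flag means callers only do this membership check; setting VLLM_FUSED_BLOCK_SOFTMAX_ADJUSTMENT=false should force the unfused fallback path for A/B comparison.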
45 changes: 29 additions & 16 deletions vllm_hpu_extension/kernels.py
@@ -5,24 +5,37 @@
 # LICENSE file in the root directory of this source tree.
 ###############################################################################

-from .utils import logger
 from functools import cache


-@cache
+def _kernel(name):
+    def loader(fn):
+        @cache
+        def loader_impl():
+            try:
+                print("Load", name, fn)
+                return fn()
+            except (ImportError, AttributeError):
+                from .utils import logger
+                logger().warning(f"Could not import HPU {name} kernel. "
+                                 "vLLM will use native implementation")
+        return loader_impl
+    return loader
+
+
+@_kernel("FusedSDPA")
 def fsdpa():
-    try:
-        from habana_frameworks.torch.hpex.kernels import FusedSDPA
-        return FusedSDPA
-    except ImportError:
-        logger().warning("Could not import HPU FusedSDPA kernel. "
-                         "vLLM will use native implementation.")
-
-@cache
+    from habana_frameworks.torch.hpex.kernels import FusedSDPA
+    return FusedSDPA
+
+
+@_kernel("FusedRMSNorm")
 def rms_norm():
-    try:
-        from habana_frameworks.torch.hpex.normalization import FusedRMSNorm
-        return FusedRMSNorm
-    except ImportError:
-        logger().warning("Could not import HPU FusedRMSNorm kernel. "
-                         "vLLM will use forward_native implementation of RMSNorm.")
+    from habana_frameworks.torch.hpex.normalization import FusedRMSNorm
+    return FusedRMSNorm
+
+
+@_kernel("block_softmax_adjustment")
+def block_softmax_adjustment():
+    import torch
+    return torch.ops.hpu.block_softmax_adjustment
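Call sites keep the same contract after this refactor: each decorated loader is a zero-argument, cached callable that returns the kernel object, or None after logging a warning when the import or op lookup fails. A minimal usage sketch (the fallback comment is illustrative, not code from this PR):

from vllm_hpu_extension.kernels import block_softmax_adjustment, fsdpa

kernel = block_softmax_adjustment()   # resolved once, then cached by @cache
if kernel is None:
    # torch.ops.hpu.block_softmax_adjustment is unavailable in this build;
    # callers take the unfused rescale path (the else branch added in ops.py).
    pass

FusedSDPA = fsdpa()                   # unchanged contract for existing callers

Presumably the Kernel(block_softmax_adjustment) condition in flags.py invokes this loader, so the capability flag only enables when the op actually resolves.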
58 changes: 33 additions & 25 deletions vllm_hpu_extension/ops.py
@@ -65,31 +65,39 @@ def pipelined_pa(attn, value, block_groups, block_mapping, block_scales, batch_s
     adjustment_target_shape = block_max.shape
     attn = attn.sub(block_max)
     attn = attn.exp()
-    attn = attn.to(value.dtype)
+    if attn.dtype == torch.float32:
+        attn = attn.to(value.dtype)
     block_sums = attn.sum(dim=-1, keepdim=True)
     attn = matmul_av_op(attn, value)
-    block_max = block_max.squeeze()
-    block_sums = block_sums.squeeze()
-
-    # Calculate maximum of blocks that belong to the same sequences
-    # and cast adjustments to native dtype
-    group_max = grouped_max(block_max, batch_size, block_groups)
-    block_adjustment = (block_max - group_max).exp()
-    block_adjustment = block_adjustment.to(value.dtype)
-    sum_adjusted = block_sums.mul(block_adjustment)
-
-    # Sum blocks' sums that belong to the same sequences
-    group_sum_adjusted = block2batch(sum_adjusted, block_mapping, block2batch_matmul_op)
-    group_sum_adjusted = batch2block(group_sum_adjusted, block_mapping, batch2block_matmul_op)
-    sum_adjusted = sum_adjusted.view(*adjustment_target_shape)
-    group_sum_adjusted = group_sum_adjusted.view(*adjustment_target_shape)
-    block_adjustment = block_adjustment.view(*adjustment_target_shape)
-
-    # For stability in case some of the sums have been zeroed out during block aggregation
-    group_sum_adjusted = torch.maximum(group_sum_adjusted, sum_adjusted)
-
-    # Post processing for the attention scores
-    rescale = block_adjustment.div(group_sum_adjusted)
-
+    if 'fused_block_softmax_adjustment' in enabled_flags() and block_max.dtype != torch.float16:
+        rescale = torch.ops.hpu.block_softmax_adjustment(block_max,
+                                                         block_sums.to(block_max.dtype),
+                                                         block_groups,
+                                                         batch_size).to(attn.dtype)
+    else:
+        block_max = block_max.squeeze()
+        block_sums = block_sums.squeeze()
+
+        # Calculate maximum of blocks that belong to the same sequences
+        # and cast adjustments to native dtype
+        group_max = grouped_max(block_max, batch_size, block_groups)
+        block_adjustment = (block_max - group_max).exp()
+        if block_adjustment.dtype == torch.float32:
+            block_adjustment = block_adjustment.to(value.dtype)
+        sum_adjusted = block_sums.mul(block_adjustment)
+
+        # Sum blocks' sums that belong to the same sequences
+        group_sum_adjusted = block2batch(sum_adjusted, block_mapping, block2batch_matmul_op)
+        group_sum_adjusted = batch2block(group_sum_adjusted, block_mapping, batch2block_matmul_op)
+        sum_adjusted = sum_adjusted.view(*adjustment_target_shape)
+        group_sum_adjusted = group_sum_adjusted.view(*adjustment_target_shape)
+        block_adjustment = block_adjustment.view(*adjustment_target_shape)
+
+        # For stability in case some of the sums have been zeroed out during block aggregation
+        group_sum_adjusted = torch.maximum(group_sum_adjusted, sum_adjusted)
+        # Post processing for the attention scores
+        rescale = block_adjustment.div(group_sum_adjusted)
     attn = attn.mul(rescale)
     return attn
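For reference, the quantity the fused op replaces is the usual paged-attention block-combination factor: for block b of sequence g, rescale[b] = exp(block_max[b] - group_max[g]) / sum over blocks b' of g of block_sums[b'] * exp(block_max[b'] - group_max[g]). Below is a small CPU-only reference sketch of that math, handy for parity checks against the else branch; the 1-D shapes and index-based reductions are assumptions for illustration, the production code uses the block_mapping matmuls, and this is not the kernel's API.

import torch

def reference_rescale(block_max, block_sums, block_groups, batch_size):
    # block_max, block_sums: [num_blocks]; block_groups: [num_blocks] sequence ids.
    group_max = torch.full((batch_size,), float("-inf"), dtype=block_max.dtype)
    group_max = group_max.index_reduce(0, block_groups, block_max, reduce="amax")
    adjustment = (block_max - group_max[block_groups]).exp()
    sum_adjusted = adjustment * block_sums
    group_sum = torch.zeros(batch_size, dtype=block_max.dtype)
    group_sum = group_sum.index_add(0, block_groups, sum_adjusted)
    # Mirror the torch.maximum() stabilization for blocks whose sums were zeroed out.
    group_sum = torch.maximum(group_sum[block_groups], sum_adjusted)
    return adjustment / group_sum

On such inputs the result should match the unfused rescale up to the dtype casts guarded by the new float32 checks.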

@@ -405,8 +413,8 @@ def forward(self, hidden_states, score, topk):
         htorch.core.mark_step()
         routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
         routing_weights, selected_experts = torch.topk(routing_weights,
-                                                       topk,
-                                                       dim=-1)
+                                                       topk,
+                                                       dim=-1)
         routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
         routing_weights = routing_weights.to(hidden_states.dtype)
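The hunk above is a whitespace-only realignment of the torch.topk continuation arguments; for context, the surrounding router code picks the top-k expert probabilities and renormalizes them to sum to one. A tiny illustrative example with made-up numbers:

import torch
import torch.nn.functional as F

score = torch.tensor([[2.0, 1.5, 0.5, 0.1]])                     # one token, four experts
routing_weights = F.softmax(score, dim=1, dtype=torch.float32)    # ~[0.51, 0.31, 0.11, 0.08]
routing_weights, selected_experts = torch.topk(routing_weights, 2, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)      # ~[0.62, 0.38], experts [0, 1]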