2 changes: 1 addition & 1 deletion vllm/attention/backends/abstract.py
@@ -233,6 +233,7 @@ def __init__(
kv_cache_dtype: str = "auto",
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
raise NotImplementedError

@@ -246,7 +247,6 @@ def forward(
attn_metadata: T,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
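
The change in this abstract interface is the heart of the PR: attn_type moves from a per-call forward() keyword argument to a constructor argument. A minimal runnable sketch of the before/after calling convention, using toy names rather than the real vLLM classes (illustrative only):

    class AttentionType:
        DECODER = "decoder"
        ENCODER = "encoder"

    class ToyAttentionImpl:
        def __init__(self, attn_type: str = AttentionType.DECODER) -> None:
            # The attention type is now bound once, at construction time.
            self.attn_type = attn_type

        def forward(self, query):
            # forward() no longer accepts attn_type; it reads the stored value.
            attn_type = self.attn_type
            return query  # attention math elided in this sketch

    # Old convention (removed above):
    #   impl.forward(query, key, value, kv_cache, attn_metadata,
    #                attn_type=AttentionType.DECODER)
    # New convention:
    impl = ToyAttentionImpl(attn_type=AttentionType.DECODER)
    out = impl.forward([1.0, 2.0])
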
14 changes: 7 additions & 7 deletions vllm/attention/backends/blocksparse_attn.py
@@ -300,6 +300,7 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
assert blocksparse_params is not None
assert alibi_slopes is None, ValueError(
@@ -350,6 +351,12 @@ def __init__(
active_head_range=self.blocksparse_params.active_head_range,
)

if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"BlocksparseFlashAttentionImpl")

def forward(
self,
query: torch.Tensor,
@@ -359,7 +366,6 @@ def forward(
attn_metadata: BlocksparseFlashAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
@@ -375,12 +381,6 @@ def forward(
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"BlocksparseFlashAttentionImpl")

num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
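For a decoder-only backend like this one, the attn_type check now runs once in __init__ instead of on every forward() call, so an unsupported configuration is rejected when the impl is built rather than at inference time; the same block moves into __init__ for FlashInfer, HPU, IPEX, Pallas, and ROCm below. A self-contained sketch of that fail-fast behavior (toy class, not the vLLM API):

    class AttentionType:
        DECODER = "decoder"
        ENCODER = "encoder"

    class DecoderOnlyImpl:
        def __init__(self, attn_type: str = AttentionType.DECODER) -> None:
            # Validation has moved out of the hot path: unsupported attention
            # types are rejected once, at construction.
            if attn_type != AttentionType.DECODER:
                raise NotImplementedError(
                    "Encoder self-attention and encoder/decoder "
                    "cross-attention are not implemented")
            self.attn_type = attn_type

    try:
        DecoderOnlyImpl(attn_type=AttentionType.ENCODER)
    except NotImplementedError as exc:
        print("rejected at construction:", exc)
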
4 changes: 3 additions & 1 deletion vllm/attention/backends/flash_attn.py
@@ -600,6 +600,7 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
if blocksparse_params is not None:
raise ValueError(
@@ -627,6 +628,7 @@ def __init__(
raise ValueError(
f"Head size {head_size} is not supported by FlashAttention. "
f"Supported head sizes are: {support_head_sizes}.")
self.attn_type = attn_type

def forward(
self,
@@ -637,7 +639,6 @@ def forward(
attn_metadata: FlashAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
@@ -659,6 +660,7 @@ def forward(

assert output is not None, "Output tensor must be provided."

attn_type = self.attn_type
if (attn_type == AttentionType.ENCODER
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
raise AttributeError("Encoder attention requires setting "
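Backends that support more than one attention type (FlashAttention here, Torch SDPA and xFormers below) instead store the constructor value as self.attn_type and rebind it as a local at the top of forward(), so the existing per-type dispatch in the body is untouched by the refactor. A rough sketch of that pattern (toy code, not the vLLM API):

    class AttentionType:
        DECODER = "decoder"
        ENCODER = "encoder"

    class MultiTypeImpl:
        def __init__(self, attn_type: str = AttentionType.DECODER) -> None:
            self.attn_type = attn_type  # stored once at construction

        def forward(self, query):
            # Same first step as the real backends: rebind the stored value
            # so the per-type branches below need no edits.
            attn_type = self.attn_type
            if attn_type == AttentionType.ENCODER:
                return self._encoder_attn(query)
            return self._decoder_attn(query)

        def _encoder_attn(self, query):
            return query  # attention math elided

        def _decoder_attn(self, query):
            return query  # attention math elided

    out = MultiTypeImpl(attn_type=AttentionType.ENCODER).forward([1.0])
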
15 changes: 7 additions & 8 deletions vllm/attention/backends/flashinfer.py
@@ -748,6 +748,7 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
@@ -764,6 +765,12 @@ def __init__(
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashInferImpl")

def forward(
self,
query: torch.Tensor,
@@ -773,18 +780,10 @@ def forward(
attn_metadata: FlashInferMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:

# TODO: directly write to output tensor

if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashInferImpl")

num_heads: int = self.num_heads
head_size: int = self.head_size
num_kv_heads: int = self.num_kv_heads
13 changes: 7 additions & 6 deletions vllm/attention/backends/hpu_attn.py
@@ -102,6 +102,7 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
max_seq_len: int = 4096,
attn_type: str = AttentionType.DECODER,
) -> None:
super(AttentionImpl, self).__init__()
self.kv_cache_dtype = kv_cache_dtype
@@ -143,6 +144,12 @@ def __init__(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {suppored_head_sizes}.")

if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"HPUAttentionImpl")

def forward(
self,
query: torch.Tensor,
@@ -152,7 +159,6 @@ def forward(
attn_metadata: HPUAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
@@ -166,11 +172,6 @@ def forward(
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"HPUAttentionImpl")
batch_size, seq_len, hidden_size = query.shape
_, seq_len_kv, _ = key.shape

12 changes: 6 additions & 6 deletions vllm/attention/backends/ipex_attn.py
@@ -115,6 +115,7 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
if blocksparse_params is not None:
raise ValueError(
@@ -146,6 +147,11 @@ def __init__(
raise NotImplementedError(
"IPEX backend does not support FP8 KV cache. "
"Please use xFormers backend instead.")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"IpexAttnBackendImpl")

def split_kv_cache(
self,
@@ -172,7 +178,6 @@ def forward(
attn_metadata: IpexAttnMetadata, # type: ignore
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with IPEX varlen_attention and PagedAttention.
@@ -189,11 +194,6 @@ def forward(
shape = [num_tokens, num_heads * head_size]
"""
assert k_scale == 1.0 and v_scale == 1.0
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"IpexAttnBackendImpl")
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
13 changes: 7 additions & 6 deletions vllm/attention/backends/pallas.py
@@ -100,6 +100,7 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
@@ -141,6 +142,12 @@ def __init__(
# megacore mode will be None.
self.megacore_mode = "batch"

if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"PallasAttentionBackendImpl")

def forward(
self,
query: torch.Tensor,
@@ -150,7 +157,6 @@ def forward(
attn_metadata: PallasMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with Pallas attention.
@@ -168,11 +174,6 @@ def forward(
shape = [batch_size, seq_len, num_heads * head_size]
"""
assert k_scale == 1.0 and v_scale == 1.0
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"PallasAttentionBackendImpl")
batch_size, seq_len, hidden_size = query.shape
query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
14 changes: 7 additions & 7 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -338,6 +338,7 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
if blocksparse_params is not None:
raise ValueError(
@@ -397,6 +398,12 @@ def __init__(
self.attn_func = _sdpa_attention
logger.debug("Using naive attention in ROCmBackend")

if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"ROCmFlashAttentionImpl")

def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
tokens, n_kv_heads, head_dim = x.shape
@@ -414,7 +421,6 @@ def forward(
attn_metadata: ROCmFlashAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
@@ -432,12 +438,6 @@ def forward(
"""
# Reminder: Please update docs/source/usage/compatibility_matrix.md
# If the feature combo become valid
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"ROCmFlashAttentionImpl")

num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
4 changes: 3 additions & 1 deletion vllm/attention/backends/torch_sdpa.py
@@ -390,6 +390,7 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
if blocksparse_params is not None:
raise ValueError(
@@ -421,6 +422,7 @@ def __init__(
raise NotImplementedError(
"Torch SDPA backend does not support FP8 KV cache. "
"Please use xFormers backend instead.")
self.attn_type = attn_type

def forward(
self,
@@ -431,7 +433,6 @@ def forward(
attn_metadata: TorchSDPAMetadata, # type: ignore
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention.
@@ -448,6 +449,7 @@ def forward(
shape = [num_tokens, num_heads * head_size]
"""
assert k_scale == 1.0 and v_scale == 1.0
attn_type = self.attn_type
if (attn_type == AttentionType.ENCODER
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
raise AttributeError("Encoder attention requires setting "
6 changes: 4 additions & 2 deletions vllm/attention/backends/xformers.py
@@ -379,6 +379,7 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
if blocksparse_params is not None:
raise ValueError(
@@ -405,6 +406,8 @@ def __init__(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {suppored_head_sizes}.")

self.attn_type = attn_type

def forward(
self,
query: torch.Tensor,
@@ -414,7 +417,6 @@ def forward(
attn_metadata: "XFormersMetadata",
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
@@ -468,7 +470,7 @@ def forward(
Returns:
shape = [num_tokens, num_heads * head_size]
"""

attn_type = self.attn_type
# Check that appropriate attention metadata attributes are
# selected for the desired attention type
if (attn_type == AttentionType.ENCODER