73 changes: 19 additions & 54 deletions paddlespeech/s2t/models/wavlm/modules/functional.py
@@ -49,9 +49,6 @@ def _mha_shape_check(query: paddle.Tensor, key: paddle.Tensor, value: paddle.Ten
raise AssertionError(
f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor")

def masked_fill(x, mask, value):
y = paddle.full(x.shape, value)


def scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal):
"""
@@ -61,18 +58,22 @@ def scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal):
d_key = k.shape[-1]
scaled_q = paddle.scale(x=q, scale=d_key ** -0.5)
product = paddle.matmul(x=scaled_q, y=k, transpose_y=True)
weights = paddle.nn.functional.softmax(x=product + attn_mask)
weights = F.softmax(x=product + attn_mask)
if dropout_p:
weights = paddle.fluid.layers.nn.dropout(
weights = F.dropout(
weights,
dropout_prob=dropout_p,
dropout_implementation="upscale_in_train",
is_test=False)
p=dropout_p,
training=True,
mode="upscale_in_train"
)
out = paddle.matmul(x=weights, y=v)
return out
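
For reference, a minimal standalone sketch of the dropout migration in this hunk. This is an illustration assuming Paddle 2.x with `paddle.nn.functional` imported as `F`, not code from the PR: in "upscale_in_train" mode, `F.dropout` reproduces the legacy `paddle.fluid.layers.nn.dropout` behaviour of scaling kept activations by 1/(1-p) during training and acting as the identity at inference.

```python
# Standalone sketch (assumes Paddle 2.x): the new F.dropout call used above.
# mode="upscale_in_train" scales kept values by 1/(1-p) while training and
# returns the input unchanged when training=False, matching the old fluid
# dropout with dropout_implementation="upscale_in_train".
import paddle
import paddle.nn.functional as F

x = paddle.rand([4, 8])
p = 0.1

y_train = F.dropout(x, p=p, training=True, mode="upscale_in_train")
y_eval = F.dropout(x, p=p, training=False, mode="upscale_in_train")

print(y_train.shape)              # [4, 8]
print(bool((y_eval == x).all()))  # True: identity at inference
```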


def addr(input, vec1, vec2, beta=1, alpha=1, out=None):
Collaborator:
What does this compute? Please add a doc-string.

Contributor Author:
Done. This is a helper function that computes alpha * (vec1 * vec2.T) + beta * input; it is used for the QK computation in attention.

"""
A helper function to calculate alpha*(vec1*vec2^T) + beta*input
"""
row = vec1.shape[0]
column = vec2.shape[0]
vec1 = paddle.unsqueeze(vec1, 0)
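
As context for the review thread above: a standalone illustration of the formula the reply describes, alpha * (vec1 * vec2.T) + beta * input, i.e. a rank-1 (outer-product) update. This is only a sketch of the math, not the PR's `addr` body, which continues below the fold.

```python
# Illustration of alpha * (vec1 * vec2.T) + beta * input (not the PR's addr body).
import paddle

inp = paddle.zeros([3, 4])
vec1 = paddle.to_tensor([1.0, 2.0, 3.0])
vec2 = paddle.to_tensor([1.0, 0.0, 2.0, 0.5])
alpha, beta = 1.0, 1.0

# outer product via unsqueeze + matmul: [3, 1] @ [1, 4] -> [3, 4]
outer = paddle.matmul(vec1.unsqueeze(1), vec2.unsqueeze(0))
result = alpha * outer + beta * inp
print(result.shape)  # [3, 4]
```
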
@@ -164,12 +165,11 @@ def _in_projection_packed(
- in output list :math:`[q', k', v']`, each output tensor will have the
same shape as the corresponding input tensor.
"""
# E = q.size(-1)
E = q.shape[-1]
if k is v:
if q is k:
# self-attention
proj = linear(q, w, b)
proj = F.linear(q, w, b)
# reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose([2, 1, 0]).squeeze(-2).contiguous()
return proj[0], proj[1], proj[2]
@@ -180,8 +180,8 @@ def _in_projection_packed(
b_q = b_kv = None
else:
b_q, b_kv = b.split([E, E * 2])
q_proj = linear(q, w_q, b_q)
kv_proj = linear(k, w_kv, b_kv)
q_proj = F.linear(q, w_q, b_q)
kv_proj = F.linear(k, w_kv, b_kv)
# reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk()
kv_proj = kv_proj.unflatten(-1, (2, E)).unsqueeze(0).transpose([2, 1, 0]).squeeze(-2).contiguous()
return (q_proj, kv_proj[0], kv_proj[1])
@@ -191,7 +191,7 @@ def _in_projection_packed(
b_q = b_k = b_v = None
else:
b_q, b_k, b_v = b.chunk(3)
return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
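
To make the packed in-projection concrete, here is a minimal sketch of the self-attention case. The toy shapes and the `paddle.split` unpacking are assumptions for illustration; the PR itself unpacks via `unflatten`/`transpose` as shown above, and Paddle's `F.linear` expects the weight as `[in_features, out_features]`.

```python
# Sketch of packed QKV in-projection for self-attention (illustrative shapes).
import paddle
import paddle.nn.functional as F

L, N, E = 5, 2, 8            # target length, batch size, embed dim
q = paddle.rand([L, N, E])
w = paddle.rand([E, 3 * E])  # packed weight; F.linear computes x @ w + b
b = paddle.rand([3 * E])

proj = F.linear(q, w, b)     # [L, N, 3*E]
q_p, k_p, v_p = paddle.split(proj, 3, axis=-1)
print(q_p.shape, k_p.shape, v_p.shape)  # each [L, N, E]
```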

def _in_projection(
q: paddle.Tensor,
@@ -204,10 +204,8 @@ def _in_projection(
b_k: Optional[paddle.Tensor] = None,
b_v: Optional[paddle.Tensor] = None,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
A, B, C = linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)

A, B, C = F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
return A, B, C
# return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)

def multi_head_attention_forward_paddle(
query: paddle.Tensor,
@@ -299,22 +297,7 @@ def multi_head_attention_forward_paddle(
"""

is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)

# For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
# is batched, run the computation and before returning squeeze the
# batch dimension so that the output doesn't carry this temporary batch dimension.
# if not is_batched:
# # unsqueeze if the input is unbatched
# query = query.unsqueeze(1)
# key = key.unsqueeze(1)
# value = value.unsqueeze(1)
# if key_padding_mask is not None:
# key_padding_mask = key_padding_mask.unsqueeze(0)

# set up shape vars
# import pdb; pdb.set_trace()
tgt_len, bsz, embed_dim = query.shape
# tgt_len, bsz, embed_dim = query.shape
src_len, _, _ = key.shape

if is_causal:
@@ -373,9 +356,7 @@ def multi_head_attention_forward_paddle(
if bias_k is not None and bias_v is not None:
assert static_k is None, "bias cannot be added to static key."
assert static_v is None, "bias cannot be added to static value."
# k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
k = paddle.concat([k, bias_k.repeat(1, bsz, 1)], axis=1)
# v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
v = paddle.concat([v, bias_v.repeat(1, bsz, 1)], axis=1)
if attn_mask is not None:
# attn_mask = pad(attn_mask, (0, 1))
@@ -392,22 +373,18 @@ def multi_head_attention_forward_paddle(
#
# reshape q, k, v for multihead attention and make em batch first
#
# q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
q = q.reshape([tgt_len, bsz * num_heads, head_dim]).transpose([1, 0, 2])


if static_k is None:
# k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
k = k.reshape([k.shape[0], bsz * num_heads, head_dim]).transpose([1, 0, 2])
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
assert static_k.shape[0] == bsz * num_heads, \
f"expecting static_k.shape[0] of {bsz * num_heads}, but got {static_k.shape[0]}"
assert static_k.shape[2] == head_dim, \
f"expecting static_k.shape[2] of {head_dim}, but got {static_k.shape[2]}"
k = static_k
if static_v is None:
# v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
v = v.reshape([v.shape[0], bsz * num_heads, head_dim]).transpose([1, 0, 2])
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
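
The reshape lines above translate the PyTorch idiom `x.view(...).transpose(0, 1)` into Paddle, where `transpose` takes a full permutation. A small standalone shape check with assumed toy sizes (not from the PR):

```python
# Toy shape check for the per-head q/k/v reshaping used above (assumed sizes).
import paddle

tgt_len, bsz, num_heads, head_dim = 3, 2, 4, 5
q = paddle.rand([tgt_len, bsz, num_heads * head_dim])

# torch:  q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
# paddle: same reshape, then transpose with an explicit permutation
q_heads = q.reshape([tgt_len, bsz * num_heads, head_dim]).transpose([1, 0, 2])
print(q_heads.shape)  # [8, 3, 5] == [bsz * num_heads, tgt_len, head_dim]
```
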
@@ -420,9 +397,7 @@ def multi_head_attention_forward_paddle(
# add zero attention along batch dimension (now first)
if add_zero_attn:
zero_attn_shape = (bsz * num_heads, 1, head_dim)
# k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
k = paddle.concat([k, paddle.zeros(zero_attn_shape, dtype=k.dtype)], axis=1)
# v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
v = paddle.concat([v, paddle.zeros(zero_attn_shape, dtype=v.dtype)], axis=1)
if attn_mask is not None:
# attn_mask = pad(attn_mask, (0, 1))
@@ -438,7 +413,6 @@ def multi_head_attention_forward_paddle(
if key_padding_mask is not None:
assert key_padding_mask.shape == (bsz, src_len), \
f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
# key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
key_padding_mask = key_padding_mask.reshape([bsz, 1, 1, src_len]).expand([-1, num_heads, -1, -1]).reshape([bsz * num_heads, 1, src_len])
if attn_mask is None:
attn_mask = key_padding_mask
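
For intuition, a small standalone sketch (assumed toy sizes, not the PR's code path) of how a `[bsz, src_len]` key padding mask is broadcast to one row per attention head, mirroring the reshape/expand/reshape above:

```python
# Toy illustration of expanding a key padding mask to [bsz * num_heads, 1, src_len].
import paddle

bsz, num_heads, src_len = 2, 4, 6
key_padding_mask = paddle.zeros([bsz, src_len])  # 0 = keep, nonzero = padded (toy)

per_head = (
    key_padding_mask.reshape([bsz, 1, 1, src_len])
    .expand([-1, num_heads, -1, -1])
    .reshape([bsz * num_heads, 1, src_len])
)
print(per_head.shape)  # [8, 1, 6]
```
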
@@ -456,25 +430,20 @@ def multi_head_attention_forward_paddle(
B, Nt, E = q.shape
q_scaled = q / math.sqrt(E)
if attn_mask is not None:
# attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
attn_output_weights = attn_mask + paddle.bmm(q_scaled, k.transpose([0, 2, 1]))  # baddbmm with beta=alpha=1
else:
# attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
attn_output_weights = paddle.bmm(q_scaled, k.transpose([0, 2, 1]))
# attn_output_weights = softmax(attn_output_weights, dim=-1)
attn_output_weights = paddle.nn.functional.softmax(attn_output_weights, axis=-1)
attn_output_weights = F.softmax(attn_output_weights, axis=-1)
if dropout_p > 0.0:
# attn_output_weights = dropout(attn_output_weights, p=dropout_p)
attn_output_weights = paddle.nn.functional.dropout(attn_output_weights, p=dropout_p)
attn_output_weights = F.dropout(attn_output_weights, p=dropout_p)

# attn_output = torch.bmm(attn_output_weights, v)
attn_output = paddle.bmm(attn_output_weights, v)
attn_output = attn_output.transpose([1, 0, 2]).reshape([tgt_len * bsz, embed_dim])
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
# attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.reshape([tgt_len, bsz, attn_output.shape[1]])

# optionally average attention weights over heads
# attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
attn_output_weights = attn_output_weights.reshape([bsz, num_heads, tgt_len, src_len])
if average_attn_weights:
attn_output_weights = attn_output_weights.mean(axis=1)
@@ -492,17 +461,13 @@ def multi_head_attention_forward_paddle(
if attn_mask.shape[0] == 1 and attn_mask.dim() == 3:
attn_mask = attn_mask.unsqueeze(0)
else:
# attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)
attn_mask = attn_mask.reshape([bsz, num_heads, -1, src_len])

q = q.reshape([bsz, num_heads, tgt_len, head_dim])
k = k.reshape([bsz, num_heads, src_len, head_dim])
v = v.reshape([bsz, num_heads, src_len, head_dim])
attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
attn_output = attn_output.transpose(perm=[2, 0, 1, 3]).reshape([bsz * tgt_len, embed_dim])
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.reshape([tgt_len, bsz, attn_output.shape[1]])
# if not is_batched:
# # squeeze the output if input was unbatched
# attn_output = attn_output.squeeze(1)
return attn_output, None