@@ -19,12 +19,12 @@
 def apply_rotary_emb_single(
     x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
 ) -> torch.Tensor:
-    x_r, x_i = x[..., ::2], x[..., 1::2]
-
+    # Change to Hugging Face-style RoPE
+    x_r, x_i = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
     # broadcast for batch_prefill mode input x
     if x.dim() == 4:
-        freqs_cos = freqs_cos[None, :, None, :]
-        freqs_sin = freqs_sin[None, :, None, :]
+        freqs_cos = freqs_cos[None, None, :, :]
+        freqs_sin = freqs_sin[None, None, :, :]
     x_out_r = x_r * freqs_cos - x_i * freqs_sin
     x_out_i = x_r * freqs_sin + x_i * freqs_cos
 
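The hunk above swaps the even/odd interleaved RoPE pairing for the first-half/second-half split that Hugging Face models use; the two layouts compute the same rotation up to a fixed permutation of channels. A minimal standalone sketch of that equivalence (hypothetical helper names, not the function in this diff):

```python
import torch

# Interleaved layout: rotation pairs are (x[0], x[1]), (x[2], x[3]), ...
def rope_interleaved(x, cos, sin):
    x_r, x_i = x[..., ::2], x[..., 1::2]
    return torch.cat([x_r * cos - x_i * sin, x_r * sin + x_i * cos], dim=-1)

# Half-split (Hugging Face) layout: pairs are (x[i], x[i + d/2]).
def rope_half_split(x, cos, sin):
    d = x.shape[-1]
    x_r, x_i = x[..., : d // 2], x[..., d // 2 :]
    return torch.cat([x_r * cos - x_i * sin, x_r * sin + x_i * cos], dim=-1)

# Same result once channels are permuted: even indices first, odd indices second.
x = torch.randn(2, 8)        # (tokens, head_dim)
angle = torch.rand(2, 4)     # one angle per rotation pair
cos, sin = angle.cos(), angle.sin()
perm = torch.cat([torch.arange(0, 8, 2), torch.arange(1, 8, 2)])
assert torch.allclose(rope_interleaved(x, cos, sin),
                      rope_half_split(x[..., perm], cos, sin))
```

Matching the Hugging Face convention presumably lets checkpoints exported with that rotary layout load without re-interleaving the projection weights.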
@@ -108,21 +108,27 @@ def forward_sha(
             hidden_states, (bsz, seq_len, 1, self.dim)
         ).transpose(1, 3)
         q = [
-            wq_sha(hidden_states).reshape(bsz, self.head_dim, seq_len).transpose(1, 2)
+            wq_sha(hidden_states)
+            .permute(0, 2, 3, 1)
+            .reshape(bsz, seq_len, self.head_dim)
             for wq_sha in self.wq_sha
         ]
         k = [
-            wk_sha(hidden_states).reshape(bsz, self.head_dim, seq_len).transpose(1, 2)
+            wk_sha(hidden_states)
+            .permute(0, 2, 3, 1)
+            .reshape(bsz, seq_len, self.head_dim)
             for wk_sha in self.wk_sha
         ]
         v = [
-            wv_sha(hidden_states).reshape(bsz, self.head_dim, seq_len).transpose(1, 2)
+            wv_sha(hidden_states)
+            .permute(0, 2, 3, 1)
+            .reshape(bsz, seq_len, self.head_dim)
             for wv_sha in self.wv_sha
         ]
         for i in range(len(q)):
             q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin)
         for i in range(len(k)):
-            k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin).permute(0, 2, 1)
+            k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin).transpose(1, 2)
 
         output_y = []
         kh, vh = [], []
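Assuming each wq_sha/wk_sha/wv_sha head is a pointwise conv whose output is (bsz, head_dim, 1, seq_len) (the layout implied by the transpose(1, 3) above), the old reshape-then-transpose and the new permute-then-reshape yield numerically identical (bsz, seq_len, head_dim) tensors; the rewrite only changes which layout ops land in the exported graph. A quick check under those assumed shapes:

```python
import torch

bsz, head_dim, seq_len = 2, 64, 16
# Assumed per-head output of a 1x1 conv: (bsz, head_dim, 1, seq_len).
y = torch.randn(bsz, head_dim, 1, seq_len)

old = y.reshape(bsz, head_dim, seq_len).transpose(1, 2)      # before this diff
new = y.permute(0, 2, 3, 1).reshape(bsz, seq_len, head_dim)  # after this diff
assert torch.equal(old, new)  # same values, different op sequence
```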
@@ -249,10 +255,10 @@ def prepare_feedfoward_conv(self):
 
     def forward_feedfoward_conv(self, x):
         bsz, _, _ = x.size()
-        x = torch.reshape(x, (bsz, -1, self.dim, 1))
-        x = x.transpose(1, 2)  # Transpose right before and after Conv
+        x = torch.reshape(x, (bsz, -1, 1, self.dim))
+        x = x.transpose(1, 3)  # Transpose right before and after Conv
         x = self.w2_conv(F.silu(self.w1_conv(x)) * self.w3_conv(x))
-        x = x.transpose(1, 2)
+        x = x.transpose(1, 3)
         x = torch.reshape(x, (bsz, -1, self.dim))
         return x
 
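The feedforward hunk keeps the kernel-size-1 convs but moves the token dimension from the conv's height axis ((bsz, dim, seq, 1) after the old transpose(1, 2)) to its width axis ((bsz, dim, 1, seq) after the new transpose(1, 3)). Either way, a 1x1 conv is just a per-token linear layer; a standalone sketch of that equivalence with hypothetical sizes and plain nn.Linear/nn.Conv2d (not the module's w1_conv/w2_conv/w3_conv):

```python
import torch
import torch.nn as nn

dim, hidden, bsz, seq_len = 32, 96, 2, 5
linear = nn.Linear(dim, hidden, bias=False)
conv = nn.Conv2d(dim, hidden, kernel_size=1, bias=False)
# Share weights so both modules compute the same linear map.
conv.weight.data = linear.weight.data.view(hidden, dim, 1, 1).clone()

x = torch.randn(bsz, seq_len, dim)
ref = linear(x)  # token-wise linear, the reference result

y = x.reshape(bsz, seq_len, 1, dim).transpose(1, 3)  # (bsz, dim, 1, seq_len)
y = conv(y).transpose(1, 3).reshape(bsz, seq_len, hidden)
assert torch.allclose(ref, y, atol=1e-6)
```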