Skip to content

Commit 53c53c2

Browse files
committed
wip
1 parent f4e61e5 commit 53c53c2

File tree

5 files changed

+166
-135
lines changed

5 files changed

+166
-135
lines changed

frame.go

-125
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88
"fmt"
99
"io"
1010
"math"
11-
"math/bits"
1211

1312
"nhooyr.io/websocket/internal/errd"
1413
)
@@ -172,127 +171,3 @@ func writeFrameHeader(h header, w *bufio.Writer, buf []byte) (err error) {
172171

173172
return nil
174173
}
175-
176-
// maskGo applies the WebSocket masking algorithm to b
177-
// with the given key.
178-
// See https://tools.ietf.org/html/rfc6455#section-5.3
179-
//
180-
// The returned value is the correctly rotated key to
181-
// continue to mask/unmask the message.
182-
//
183-
// It is optimized for LittleEndian and expects the key
184-
// to be in little endian.
185-
//
186-
// See https://github.com/golang/go/issues/31586
187-
//
188-
//lint:ignore U1000 mask.go
189-
func maskGo(b []byte, key uint32) uint32 {
190-
if len(b) >= 8 {
191-
key64 := uint64(key)<<32 | uint64(key)
192-
193-
// At some point in the future we can clean these unrolled loops up.
194-
// See https://github.com/golang/go/issues/31586#issuecomment-487436401
195-
196-
// Then we xor until b is less than 128 bytes.
197-
for len(b) >= 128 {
198-
v := binary.LittleEndian.Uint64(b)
199-
binary.LittleEndian.PutUint64(b, v^key64)
200-
v = binary.LittleEndian.Uint64(b[8:16])
201-
binary.LittleEndian.PutUint64(b[8:16], v^key64)
202-
v = binary.LittleEndian.Uint64(b[16:24])
203-
binary.LittleEndian.PutUint64(b[16:24], v^key64)
204-
v = binary.LittleEndian.Uint64(b[24:32])
205-
binary.LittleEndian.PutUint64(b[24:32], v^key64)
206-
v = binary.LittleEndian.Uint64(b[32:40])
207-
binary.LittleEndian.PutUint64(b[32:40], v^key64)
208-
v = binary.LittleEndian.Uint64(b[40:48])
209-
binary.LittleEndian.PutUint64(b[40:48], v^key64)
210-
v = binary.LittleEndian.Uint64(b[48:56])
211-
binary.LittleEndian.PutUint64(b[48:56], v^key64)
212-
v = binary.LittleEndian.Uint64(b[56:64])
213-
binary.LittleEndian.PutUint64(b[56:64], v^key64)
214-
v = binary.LittleEndian.Uint64(b[64:72])
215-
binary.LittleEndian.PutUint64(b[64:72], v^key64)
216-
v = binary.LittleEndian.Uint64(b[72:80])
217-
binary.LittleEndian.PutUint64(b[72:80], v^key64)
218-
v = binary.LittleEndian.Uint64(b[80:88])
219-
binary.LittleEndian.PutUint64(b[80:88], v^key64)
220-
v = binary.LittleEndian.Uint64(b[88:96])
221-
binary.LittleEndian.PutUint64(b[88:96], v^key64)
222-
v = binary.LittleEndian.Uint64(b[96:104])
223-
binary.LittleEndian.PutUint64(b[96:104], v^key64)
224-
v = binary.LittleEndian.Uint64(b[104:112])
225-
binary.LittleEndian.PutUint64(b[104:112], v^key64)
226-
v = binary.LittleEndian.Uint64(b[112:120])
227-
binary.LittleEndian.PutUint64(b[112:120], v^key64)
228-
v = binary.LittleEndian.Uint64(b[120:128])
229-
binary.LittleEndian.PutUint64(b[120:128], v^key64)
230-
b = b[128:]
231-
}
232-
233-
// Then we xor until b is less than 64 bytes.
234-
for len(b) >= 64 {
235-
v := binary.LittleEndian.Uint64(b)
236-
binary.LittleEndian.PutUint64(b, v^key64)
237-
v = binary.LittleEndian.Uint64(b[8:16])
238-
binary.LittleEndian.PutUint64(b[8:16], v^key64)
239-
v = binary.LittleEndian.Uint64(b[16:24])
240-
binary.LittleEndian.PutUint64(b[16:24], v^key64)
241-
v = binary.LittleEndian.Uint64(b[24:32])
242-
binary.LittleEndian.PutUint64(b[24:32], v^key64)
243-
v = binary.LittleEndian.Uint64(b[32:40])
244-
binary.LittleEndian.PutUint64(b[32:40], v^key64)
245-
v = binary.LittleEndian.Uint64(b[40:48])
246-
binary.LittleEndian.PutUint64(b[40:48], v^key64)
247-
v = binary.LittleEndian.Uint64(b[48:56])
248-
binary.LittleEndian.PutUint64(b[48:56], v^key64)
249-
v = binary.LittleEndian.Uint64(b[56:64])
250-
binary.LittleEndian.PutUint64(b[56:64], v^key64)
251-
b = b[64:]
252-
}
253-
254-
// Then we xor until b is less than 32 bytes.
255-
for len(b) >= 32 {
256-
v := binary.LittleEndian.Uint64(b)
257-
binary.LittleEndian.PutUint64(b, v^key64)
258-
v = binary.LittleEndian.Uint64(b[8:16])
259-
binary.LittleEndian.PutUint64(b[8:16], v^key64)
260-
v = binary.LittleEndian.Uint64(b[16:24])
261-
binary.LittleEndian.PutUint64(b[16:24], v^key64)
262-
v = binary.LittleEndian.Uint64(b[24:32])
263-
binary.LittleEndian.PutUint64(b[24:32], v^key64)
264-
b = b[32:]
265-
}
266-
267-
// Then we xor until b is less than 16 bytes.
268-
for len(b) >= 16 {
269-
v := binary.LittleEndian.Uint64(b)
270-
binary.LittleEndian.PutUint64(b, v^key64)
271-
v = binary.LittleEndian.Uint64(b[8:16])
272-
binary.LittleEndian.PutUint64(b[8:16], v^key64)
273-
b = b[16:]
274-
}
275-
276-
// Then we xor until b is less than 8 bytes.
277-
for len(b) >= 8 {
278-
v := binary.LittleEndian.Uint64(b)
279-
binary.LittleEndian.PutUint64(b, v^key64)
280-
b = b[8:]
281-
}
282-
}
283-
284-
// Then we xor until b is less than 4 bytes.
285-
for len(b) >= 4 {
286-
v := binary.LittleEndian.Uint32(b)
287-
binary.LittleEndian.PutUint32(b, v^key)
288-
b = b[4:]
289-
}
290-
291-
// xor remaining bytes.
292-
for i := range b {
293-
b[i] ^= byte(key)
294-
key = bits.RotateLeft32(key, -8)
295-
}
296-
297-
return key
298-
}

mask.go

+127-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,130 @@
1-
//go:build !amd64 && !arm64 && !js
2-
31
package websocket
42

5-
func mask(b []byte, key uint32) uint32 {
6-
return maskGo(b, key)
3+
import (
4+
"encoding/binary"
5+
"math/bits"
6+
)
7+
8+
// maskGo applies the WebSocket masking algorithm to b
9+
// with the given key.
10+
// See https://tools.ietf.org/html/rfc6455#section-5.3
11+
//
12+
// The returned value is the correctly rotated key to
13+
// continue to mask/unmask the message.
14+
//
15+
// It is optimized for LittleEndian and expects the key
16+
// to be in little endian.
17+
//
18+
// See https://github.com/golang/go/issues/31586
19+
//
20+
//lint:ignore U1000 mask.go
21+
func maskGo(b []byte, key uint32) uint32 {
22+
if len(b) >= 8 {
23+
key64 := uint64(key)<<32 | uint64(key)
24+
25+
// At some point in the future we can clean these unrolled loops up.
26+
// See https://github.com/golang/go/issues/31586#issuecomment-487436401
27+
28+
// Then we xor until b is less than 128 bytes.
29+
for len(b) >= 128 {
30+
v := binary.LittleEndian.Uint64(b)
31+
binary.LittleEndian.PutUint64(b, v^key64)
32+
v = binary.LittleEndian.Uint64(b[8:16])
33+
binary.LittleEndian.PutUint64(b[8:16], v^key64)
34+
v = binary.LittleEndian.Uint64(b[16:24])
35+
binary.LittleEndian.PutUint64(b[16:24], v^key64)
36+
v = binary.LittleEndian.Uint64(b[24:32])
37+
binary.LittleEndian.PutUint64(b[24:32], v^key64)
38+
v = binary.LittleEndian.Uint64(b[32:40])
39+
binary.LittleEndian.PutUint64(b[32:40], v^key64)
40+
v = binary.LittleEndian.Uint64(b[40:48])
41+
binary.LittleEndian.PutUint64(b[40:48], v^key64)
42+
v = binary.LittleEndian.Uint64(b[48:56])
43+
binary.LittleEndian.PutUint64(b[48:56], v^key64)
44+
v = binary.LittleEndian.Uint64(b[56:64])
45+
binary.LittleEndian.PutUint64(b[56:64], v^key64)
46+
v = binary.LittleEndian.Uint64(b[64:72])
47+
binary.LittleEndian.PutUint64(b[64:72], v^key64)
48+
v = binary.LittleEndian.Uint64(b[72:80])
49+
binary.LittleEndian.PutUint64(b[72:80], v^key64)
50+
v = binary.LittleEndian.Uint64(b[80:88])
51+
binary.LittleEndian.PutUint64(b[80:88], v^key64)
52+
v = binary.LittleEndian.Uint64(b[88:96])
53+
binary.LittleEndian.PutUint64(b[88:96], v^key64)
54+
v = binary.LittleEndian.Uint64(b[96:104])
55+
binary.LittleEndian.PutUint64(b[96:104], v^key64)
56+
v = binary.LittleEndian.Uint64(b[104:112])
57+
binary.LittleEndian.PutUint64(b[104:112], v^key64)
58+
v = binary.LittleEndian.Uint64(b[112:120])
59+
binary.LittleEndian.PutUint64(b[112:120], v^key64)
60+
v = binary.LittleEndian.Uint64(b[120:128])
61+
binary.LittleEndian.PutUint64(b[120:128], v^key64)
62+
b = b[128:]
63+
}
64+
65+
// Then we xor until b is less than 64 bytes.
66+
for len(b) >= 64 {
67+
v := binary.LittleEndian.Uint64(b)
68+
binary.LittleEndian.PutUint64(b, v^key64)
69+
v = binary.LittleEndian.Uint64(b[8:16])
70+
binary.LittleEndian.PutUint64(b[8:16], v^key64)
71+
v = binary.LittleEndian.Uint64(b[16:24])
72+
binary.LittleEndian.PutUint64(b[16:24], v^key64)
73+
v = binary.LittleEndian.Uint64(b[24:32])
74+
binary.LittleEndian.PutUint64(b[24:32], v^key64)
75+
v = binary.LittleEndian.Uint64(b[32:40])
76+
binary.LittleEndian.PutUint64(b[32:40], v^key64)
77+
v = binary.LittleEndian.Uint64(b[40:48])
78+
binary.LittleEndian.PutUint64(b[40:48], v^key64)
79+
v = binary.LittleEndian.Uint64(b[48:56])
80+
binary.LittleEndian.PutUint64(b[48:56], v^key64)
81+
v = binary.LittleEndian.Uint64(b[56:64])
82+
binary.LittleEndian.PutUint64(b[56:64], v^key64)
83+
b = b[64:]
84+
}
85+
86+
// Then we xor until b is less than 32 bytes.
87+
for len(b) >= 32 {
88+
v := binary.LittleEndian.Uint64(b)
89+
binary.LittleEndian.PutUint64(b, v^key64)
90+
v = binary.LittleEndian.Uint64(b[8:16])
91+
binary.LittleEndian.PutUint64(b[8:16], v^key64)
92+
v = binary.LittleEndian.Uint64(b[16:24])
93+
binary.LittleEndian.PutUint64(b[16:24], v^key64)
94+
v = binary.LittleEndian.Uint64(b[24:32])
95+
binary.LittleEndian.PutUint64(b[24:32], v^key64)
96+
b = b[32:]
97+
}
98+
99+
// Then we xor until b is less than 16 bytes.
100+
for len(b) >= 16 {
101+
v := binary.LittleEndian.Uint64(b)
102+
binary.LittleEndian.PutUint64(b, v^key64)
103+
v = binary.LittleEndian.Uint64(b[8:16])
104+
binary.LittleEndian.PutUint64(b[8:16], v^key64)
105+
b = b[16:]
106+
}
107+
108+
// Then we xor until b is less than 8 bytes.
109+
for len(b) >= 8 {
110+
v := binary.LittleEndian.Uint64(b)
111+
binary.LittleEndian.PutUint64(b, v^key64)
112+
b = b[8:]
113+
}
114+
}
115+
116+
// Then we xor until b is less than 4 bytes.
117+
for len(b) >= 4 {
118+
v := binary.LittleEndian.Uint32(b)
119+
binary.LittleEndian.PutUint32(b, v^key)
120+
b = b[4:]
121+
}
122+
123+
// xor remaining bytes.
124+
for i := range b {
125+
b[i] ^= byte(key)
126+
key = bits.RotateLeft32(key, -8)
127+
}
128+
129+
return key
7130
}

mask_amd64.s

+30-6
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,16 @@ TEXT ·maskAsm(SB), NOSPLIT, $0-28
1919

2020
CMPQ CX, $8
2121
JL less_than_8
22-
CMPQ CX, $512
22+
CMPQ CX, $128
2323
JLE sse
2424
TESTQ $31, AX
2525
JNZ unaligned
2626

27+
aligned:
28+
CMPB ·useAVX2(SB), $1
29+
JE avx2
30+
JMP sse
31+
2732
unaligned_loop_1byte:
2833
XORB SI, (AX)
2934
INCQ AX
@@ -40,7 +45,7 @@ unaligned_loop_1byte:
4045
ORQ DX, DI
4146

4247
TESTQ $31, AX
43-
JZ sse
48+
JZ aligned
4449

4550
unaligned:
4651
// $7 & len, if not zero jump to loop_1b.
@@ -54,17 +59,36 @@ unaligned_loop:
5459
SUBQ $8, CX
5560
TESTQ $31, AX
5661
JNZ unaligned_loop
57-
JMP sse
58-
62+
JMP aligned
63+
64+
avx2:
65+
CMPQ CX, $128
66+
JL sse
67+
VMOVQ DI, X0
68+
VPBROADCASTQ X0, Y0
69+
70+
// TODO: shouldn't these be aligned movs now?
71+
// TODO: should be 256?
72+
avx2_loop:
73+
VMOVDQU (AX), Y1
74+
VPXOR Y0, Y1, Y2
75+
VMOVDQU Y2, (AX)
76+
ADDQ $128, AX
77+
SUBQ $128, CX
78+
CMPQ CX, $128
79+
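// Loop if CX >= 128. NOTE(review): each iteration XORs only one 32-byte YMM register but advances AX/CX by 128, so 96 of every 128 bytes appear to be left unmasked — confirm, then either unroll to four 32-byte loads/stores or use a 32-byte stride.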
80+
JAE avx2_loop
81+
82+
// TODO: should be 128?
5983
sse:
6084
CMPQ CX, $64
6185
JL less_than_64
6286
MOVQ DI, X0
6387
PUNPCKLQDQ X0, X0
6488

6589
sse_loop:
66-
MOVOU 0*16(AX), X1
67-
MOVOU 1*16(AX), X2
90+
MOVOU (AX), X1
91+
MOVOU 16(AX), X2
6892
MOVOU 2*16(AX), X3
6993
MOVOU 3*16(AX), X4
7094
PXOR X0, X1

mask_asm.go

+2
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,7 @@ func mask(b []byte, key uint32) uint32 {
99
return key
1010
}
1111

12+
var useAVX2 = true
13+
1214
//go:noescape
1315
func maskAsm(b *byte, len int, key uint32) uint32

mask_go.go

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
//go:build !amd64 && !arm64 && !js
2+
3+
package websocket
4+
5+
func mask(b []byte, key uint32) uint32 {
6+
return maskGo(b, key)
7+
}

0 commit comments

Comments
 (0)