Skip to content

Commit 68fc887

Browse files
committed
mask.go: Revert my changes
I'm just not good enough at assembly. I added tests to confirm that @wdvxdr's implementation works correctly and matches the output of the basic masking loop.
1 parent fee3739 commit 68fc887

File tree

9 files changed

+126
-34
lines changed

9 files changed

+126
-34
lines changed

go.mod

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@ module nhooyr.io/websocket
22

33
go 1.19
44

5-
require golang.org/x/sys v0.17.0 // indirect
5+
require golang.org/x/sys v0.17.0

internal/examples/go.mod

+2
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,5 @@ require (
88
golang.org/x/time v0.3.0
99
nhooyr.io/websocket v0.0.0-00010101000000-000000000000
1010
)
11+
12+
require golang.org/x/sys v0.17.0 // indirect

internal/examples/go.sum

+2
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
1+
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
2+
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
13
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
24
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=

internal/thirdparty/go.mod

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ require (
3636
golang.org/x/arch v0.3.0 // indirect
3737
golang.org/x/crypto v0.9.0 // indirect
3838
golang.org/x/net v0.10.0 // indirect
39-
golang.org/x/sys v0.13.0 // indirect
39+
golang.org/x/sys v0.17.0 // indirect
4040
golang.org/x/text v0.9.0 // indirect
4141
google.golang.org/protobuf v1.30.0 // indirect
4242
gopkg.in/yaml.v3 v3.0.1 // indirect

internal/thirdparty/go.sum

+2-2
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,8 @@ golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBc
100100
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
101101
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
102102
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
103-
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
104-
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
103+
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
104+
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
105105
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
106106
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
107107
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=

mask_amd64.s

+32-30
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,17 @@ TEXT ·maskAsm(SB), NOSPLIT, $0-28
1010
MOVQ len+8(FP), CX
1111
MOVL key+16(FP), SI
1212

13-
// Calculate the DI aka the uint64 key.
14-
// DI = uint64(SI) | uint64(SI)<<32
13+
// Calculate DI, the 64-bit key: the 32-bit key in SI repeated in both halves.
14+
// DI = SI<<32 | SI
1515
MOVL SI, DI
1616
MOVQ DI, DX
1717
SHLQ $32, DI
1818
ORQ DX, DI
1919

20-
CMPQ CX, $8
21-
JL less_than_8
20+
CMPQ CX, $15
21+
JLE less_than_16
22+
CMPQ CX, $63
23+
JLE less_than_64
2224
CMPQ CX, $128
2325
JLE sse
2426
TESTQ $31, AX
@@ -37,8 +39,8 @@ unaligned_loop_1byte:
3739
TESTQ $7, AX
3840
JNZ unaligned_loop_1byte
3941

40-
// Calculate DI again since SI was modified.
41-
// DI = uint64(SI) | uint64(SI)<<32
42+
// calculate DI again since SI was modified
43+
// DI = SI<<32 | SI
4244
MOVL SI, DI
4345
MOVQ DI, DX
4446
SHLQ $32, DI
@@ -48,12 +50,11 @@ unaligned_loop_1byte:
4850
JZ aligned
4951

5052
unaligned:
51-
// $7 & len, if not zero jump to loop_1b.
52-
TESTQ $7, AX
53+
TESTQ $7, AX // Test the low 3 bits of the address in AX; if it is not 8-byte aligned, jump to unaligned_loop_1byte.
5354
JNZ unaligned_loop_1byte
5455

5556
unaligned_loop:
56-
// We don't need to check the CX since we know it's above 512.
57+
// we don't need to check the CX since we know it's above 128
5758
XORQ DI, (AX)
5859
ADDQ $8, AX
5960
SUBQ $8, CX
@@ -62,33 +63,34 @@ unaligned_loop:
6263
JMP aligned
6364

6465
avx2:
65-
CMPQ CX, $128
66+
CMPQ CX, $0x80
6667
JL sse
6768
VMOVQ DI, X0
6869
VPBROADCASTQ X0, Y0
6970

70-
// TODO: shouldn't these be aligned movs now?
71-
// TODO: should be 256?
7271
avx2_loop:
73-
VMOVDQU (AX), Y1
74-
VPXOR Y0, Y1, Y2
75-
VMOVDQU Y2, (AX)
76-
ADDQ $128, AX
77-
SUBQ $128, CX
78-
CMPQ CX, $128
79-
// Loop if CX >= 128.
80-
JAE avx2_loop
81-
82-
// TODO: should be 128?
72+
VPXOR (AX), Y0, Y1
73+
VPXOR 32(AX), Y0, Y2
74+
VPXOR 64(AX), Y0, Y3
75+
VPXOR 96(AX), Y0, Y4
76+
VMOVDQU Y1, (AX)
77+
VMOVDQU Y2, 32(AX)
78+
VMOVDQU Y3, 64(AX)
79+
VMOVDQU Y4, 96(AX)
80+
ADDQ $0x80, AX
81+
SUBQ $0x80, CX
82+
CMPQ CX, $0x80
83+
JAE avx2_loop // loop if CX >= 0x80
84+
8385
sse:
84-
CMPQ CX, $64
86+
CMPQ CX, $0x40
8587
JL less_than_64
8688
MOVQ DI, X0
8789
PUNPCKLQDQ X0, X0
8890

8991
sse_loop:
90-
MOVOU (AX), X1
91-
MOVOU 16(AX), X2
92+
MOVOU 0*16(AX), X1
93+
MOVOU 1*16(AX), X2
9294
MOVOU 2*16(AX), X3
9395
MOVOU 3*16(AX), X4
9496
PXOR X0, X1
@@ -99,9 +101,9 @@ sse_loop:
99101
MOVOU X2, 1*16(AX)
100102
MOVOU X3, 2*16(AX)
101103
MOVOU X4, 3*16(AX)
102-
ADDQ $64, AX
103-
SUBQ $64, CX
104-
CMPQ CX, $64
104+
ADDQ $0x40, AX
105+
SUBQ $0x40, CX
106+
CMPQ CX, $0x40
105107
JAE sse_loop
106108

107109
less_than_64:
@@ -141,10 +143,10 @@ less_than_4:
141143

142144
less_than_2:
143145
TESTQ $1, CX
144-
JZ end
146+
JZ done
145147
XORB SI, (AX)
146148
ROLL $24, SI
147149

148-
end:
150+
done:
149151
MOVL SI, ret+24(FP)
150152
RET

mask_asm.go

+2
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,14 @@ func mask(b []byte, key uint32) uint32 {
1111
return key
1212
}
1313

14+
//lint:ignore U1000 mask_*.s
1415
var useAVX2 = cpu.X86.HasAVX2
1516

1617
// @nhooyr: I am not confident that the amd64 or the arm64 implementations of this
1718
// function are perfect. There are almost certainly missing optimizations or
1819
// opportunities for simplification. I'm confident there are no bugs though.
1920
// For example, the arm64 implementation doesn't align memory like the amd64.
2021
// Or the amd64 implementation could use AVX512 instead of just AVX2.
22+
//
2123
//go:noescape
2224
func maskAsm(b *byte, len int, key uint32) uint32

mask_asm_test.go

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
//go:build amd64 || arm64
2+
3+
package websocket
4+
5+
import "testing"
6+
7+
// TestMaskASM runs the shared testMask check against mask, reporting it under
// the subtest name "maskASM". The build constraint at the top of this file
// limits the test to amd64 and arm64 builds.
func TestMaskASM(t *testing.T) {
8+
t.Parallel()
9+
10+
testMask(t, "maskASM", mask)
11+
}

mask_test.go

+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
package websocket
2+
3+
import (
4+
"bytes"
5+
"crypto/rand"
6+
"encoding/binary"
7+
"math/big"
8+
"math/bits"
9+
"testing"
10+
11+
"nhooyr.io/websocket/internal/test/assert"
12+
)
13+
14+
// basicMask is the straightforward byte-at-a-time masking loop. It XORs every
// byte of b in place with the low byte of key, rotating key right by 8 bits
// after each byte, and returns the rotated key so a caller can continue
// masking where this call left off.
//
// testMask uses it as the reference that other implementations must match.
func basicMask(b []byte, key uint32) uint32 {
15+
for i := range b {
16+
b[i] ^= byte(key)
17+
// A negative count makes RotateLeft32 rotate right, moving the next key
// byte into the low 8 bits.
key = bits.RotateLeft32(key, -8)
18+
}
19+
return key
20+
}
21+
22+
// basicMask2 is an alternative byte-at-a-time masking loop that indexes into
// the key bytes instead of rotating the key on every iteration. key is
// serialized little-endian into four bytes and b[i] is XORed in place with
// byte i mod 4 of that serialization.
//
// It returns key rotated right by 8 bits for each byte consumed, modulo 4.
func basicMask2(b []byte, key uint32) uint32 {
23+
keyb := binary.LittleEndian.AppendUint32(nil, key)
24+
pos := 0
25+
for i := range b {
26+
b[i] ^= keyb[pos&3]
27+
pos++
28+
}
29+
// Only pos mod 4 matters: a full cycle of four bytes leaves the key unchanged.
return bits.RotateLeft32(key, (pos&3)*-8)
30+
}
31+
32+
// TestMask runs the shared testMask check against basicMask, maskGo and
// basicMask2, each as its own named subtest.
func TestMask(t *testing.T) {
33+
t.Parallel()
34+
35+
testMask(t, "basicMask", basicMask)
36+
testMask(t, "maskGo", maskGo)
37+
testMask(t, "basicMask2", basicMask2)
38+
}
39+
40+
// testMask checks that fn behaves exactly like basicMask. It runs as a
// parallel subtest called name and performs 9999 rounds. Each round draws a
// random 32-bit key and a random buffer of 1 to 65536 bytes, masks one copy
// with basicMask and another with fn, and compares both the returned keys and
// the resulting bytes.
//
// A key mismatch is reported and the rounds continue; a byte mismatch is
// reported and ends the subtest.
func testMask(t *testing.T, name string, fn func(b []byte, key uint32) uint32) {
41+
t.Run(name, func(t *testing.T) {
42+
t.Parallel()
43+
for i := 0; i < 9999; i++ {
44+
keyb := make([]byte, 4)
45+
_, err := rand.Read(keyb)
46+
assert.Success(t, err)
47+
key := binary.LittleEndian.Uint32(keyb)
48+
49+
// n is uniform in [0, 1<<16); the +1 below guarantees a non-empty buffer.
n, err := rand.Int(rand.Reader, big.NewInt(1<<16))
50+
assert.Success(t, err)
51+
52+
b := make([]byte, 1+n.Int64())
53+
_, err = rand.Read(b)
54+
assert.Success(t, err)
55+
56+
// Masking happens in place, so each implementation gets its own copy.
b2 := make([]byte, len(b))
57+
copy(b2, b)
58+
b3 := make([]byte, len(b))
59+
copy(b3, b)
60+
61+
key2 := basicMask(b2, key)
62+
key3 := fn(b3, key)
63+
64+
if key2 != key3 {
65+
t.Errorf("expected key %X but got %X", key2, key3)
66+
}
67+
if !bytes.Equal(b2, b3) {
68+
t.Error("bad bytes")
69+
return
70+
}
71+
}
72+
})
73+
}

0 commit comments

Comments
 (0)