Skip to content

Commit a6641ec

Browse files
authored
Merge pull request #109 from lxzan/writeReader
WriteFile
2 parents 3dc044f + bf42d3f commit a6641ec

File tree

16 files changed

+656
-62
lines changed

16 files changed

+656
-62
lines changed

.golangci.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ linters:
33
# Disable specific linter
44
# https://golangci-lint.run/usage/linters/#disabled-by-default
55
disable:
6+
- maintidx
67
- mnd
78
- testpackage
89
- nlreturn

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,9 @@ ok github.com/lxzan/gws 17.231s
102102
- [x] Broadcast
103103
- [x] Dial via Proxy
104104
- [x] Context-Takeover
105-
- [x] Passed Autobahn Test Cases [Server](https://lxzan.github.io/gws/reports/servers/) / [Client](https://lxzan.github.io/gws/reports/clients/)
106105
- [x] Concurrent & Asynchronous Non-Blocking Write
106+
- [x] Segmented Writing of Large Files
107+
- [x] Passed Autobahn Test Cases [Server](https://lxzan.github.io/gws/reports/servers/) / [Client](https://lxzan.github.io/gws/reports/clients/)
107108

108109
### Attention
109110

README_CN.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ ok github.com/lxzan/gws 17.231s
9292
- [x] 广播
9393
- [x] 代理拨号
9494
- [x] 上下文接管
95+
- [x] 大文件分段写入
9596
- [x] 支持并发和异步非阻塞写入
9697
- [x] 通过所有 Autobahn 测试用例 [Server](https://lxzan.github.io/gws/reports/servers/) / [Client](https://lxzan.github.io/gws/reports/clients/)
9798

benchmark_test.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,12 @@ func BenchmarkConn_ReadMessage(b *testing.B) {
6868
conn: &benchConn{},
6969
config: upgrader.option.getConfig(),
7070
}
71-
var buf, _ = conn1.genFrame(OpcodeText, internal.Bytes(githubData), false)
71+
var buf, _ = conn1.genFrame(OpcodeText, internal.Bytes(githubData), frameConfig{
72+
fin: true,
73+
compress: conn1.pd.Enabled,
74+
broadcast: false,
75+
checkEncoding: false,
76+
})
7277

7378
var reader = bytes.NewBuffer(buf.Bytes())
7479
var conn2 = &Conn{
@@ -98,7 +103,12 @@ func BenchmarkConn_ReadMessage(b *testing.B) {
98103
deflater: new(deflater),
99104
}
100105
conn1.deflater.initialize(false, conn1.pd, config.ReadMaxPayloadSize)
101-
var buf, _ = conn1.genFrame(OpcodeText, internal.Bytes(githubData), false)
106+
var buf, _ = conn1.genFrame(OpcodeText, internal.Bytes(githubData), frameConfig{
107+
fin: true,
108+
compress: conn1.pd.Enabled,
109+
broadcast: false,
110+
checkEncoding: false,
111+
})
102112

103113
var reader = bytes.NewBuffer(buf.Bytes())
104114
var conn2 = &Conn{

bigfile.go

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
package gws
2+
3+
import (
4+
"bytes"
5+
"encoding/binary"
6+
"errors"
7+
"io"
8+
"math"
9+
10+
"github.com/klauspost/compress/flate"
11+
"github.com/lxzan/gws/internal"
12+
)
13+
14+
const segmentSize = 128 * 1024
15+
16+
// 获取大文件压缩器
17+
// Get bigDeflater
18+
func (c *Conn) getBigDeflater() *bigDeflater {
19+
if c.isServer {
20+
return c.config.bdPool.Get()
21+
}
22+
return (*bigDeflater)(c.deflater.cpsWriter)
23+
}
24+
25+
// 回收大文件压缩器
26+
// Recycle bigDeflater
27+
func (c *Conn) putBigDeflater(d *bigDeflater) {
28+
if c.isServer {
29+
c.config.bdPool.Put(d)
30+
}
31+
}
32+
33+
// 拆分io.Reader为小切片
34+
// Split io.Reader into small slices
35+
func (c *Conn) splitReader(r io.Reader, f func(index int, eof bool, p []byte) error) error {
36+
var buf = binaryPool.Get(segmentSize)
37+
defer binaryPool.Put(buf)
38+
39+
var p = buf.Bytes()[:segmentSize]
40+
var n, index = 0, 0
41+
var err error
42+
for n, err = r.Read(p); err == nil || errors.Is(err, io.EOF); n, err = r.Read(p) {
43+
eof := errors.Is(err, io.EOF)
44+
if err = f(index, eof, p[:n]); err != nil {
45+
return err
46+
}
47+
index++
48+
if eof {
49+
break
50+
}
51+
}
52+
return err
53+
}
54+
55+
// WriteFile 大文件写入
56+
// 采用分段写入技术, 减少写入过程中的内存占用
57+
// Segmented write technology to reduce memory usage during write process
58+
func (c *Conn) WriteFile(opcode Opcode, payload io.Reader) error {
59+
err := c.doWriteFile(opcode, payload)
60+
c.emitError(err)
61+
return err
62+
}
63+
64+
func (c *Conn) doWriteFile(opcode Opcode, payload io.Reader) error {
65+
c.mu.Lock()
66+
defer c.mu.Unlock()
67+
68+
var cb = func(index int, eof bool, p []byte) error {
69+
if index > 0 {
70+
opcode = OpcodeContinuation
71+
}
72+
frame, err := c.genFrame(opcode, internal.Bytes(p), frameConfig{
73+
fin: eof,
74+
compress: false,
75+
broadcast: false,
76+
checkEncoding: false,
77+
})
78+
if err != nil {
79+
return err
80+
}
81+
if c.pd.Enabled && index == 0 {
82+
frame.Bytes()[0] |= uint8(64)
83+
}
84+
if c.isClosed() {
85+
return ErrConnClosed
86+
}
87+
err = internal.WriteN(c.conn, frame.Bytes())
88+
binaryPool.Put(frame)
89+
return err
90+
}
91+
92+
if c.pd.Enabled {
93+
var deflater = c.getBigDeflater()
94+
var fw = &flateWriter{cb: cb}
95+
err := deflater.Compress(payload, fw, c.getCpsDict(false), &c.cpsWindow)
96+
c.putBigDeflater(deflater)
97+
return err
98+
} else {
99+
return c.splitReader(payload, cb)
100+
}
101+
}
102+
103+
// 大文件压缩器
104+
type bigDeflater flate.Writer
105+
106+
// 创建大文件压缩器
107+
// Create a bigDeflater
108+
func newBigDeflater(isServer bool, options PermessageDeflate) *bigDeflater {
109+
windowBits := internal.SelectValue(isServer, options.ServerMaxWindowBits, options.ClientMaxWindowBits)
110+
if windowBits == 15 {
111+
cpsWriter, _ := flate.NewWriter(nil, options.Level)
112+
return (*bigDeflater)(cpsWriter)
113+
} else {
114+
cpsWriter, _ := flate.NewWriterWindow(nil, internal.BinaryPow(windowBits))
115+
return (*bigDeflater)(cpsWriter)
116+
}
117+
}
118+
119+
func (c *bigDeflater) FlateWriter() *flate.Writer { return (*flate.Writer)(c) }
120+
121+
// Compress 压缩
122+
func (c *bigDeflater) Compress(src io.Reader, dst *flateWriter, dict []byte, sw *slideWindow) error {
123+
if err := compressTo(c.FlateWriter(), &readerWrapper{r: src, sw: sw}, dst, dict); err != nil {
124+
return err
125+
}
126+
return dst.Flush()
127+
}
128+
129+
// 写入代理
130+
// 将切片透传给回调函数, 以实现分段写入功能
131+
// Write proxy
132+
// Passthrough slices to the callback function for segmented writes.
133+
type flateWriter struct {
134+
index int
135+
buffers []*bytes.Buffer
136+
cb func(index int, eof bool, p []byte) error
137+
}
138+
139+
// 是否可以执行回调函数
140+
// Whether the callback function can be executed
141+
func (c *flateWriter) shouldCall() bool {
142+
var n = len(c.buffers)
143+
if n < 2 {
144+
return false
145+
}
146+
var sum = 0
147+
for i := 1; i < n; i++ {
148+
sum += c.buffers[i].Len()
149+
}
150+
return sum >= 4
151+
}
152+
153+
// 聚合写入, 减少syscall.write调用次数
154+
// Aggregate writes, reducing the number of syscall.write calls
155+
func (c *flateWriter) write(p []byte) {
156+
if len(c.buffers) == 0 {
157+
c.buffers = append(c.buffers, binaryPool.Get(segmentSize))
158+
}
159+
var n = len(c.buffers)
160+
var tail = c.buffers[n-1]
161+
if tail.Len()+len(p)+frameHeaderSize > tail.Cap() {
162+
tail = binaryPool.Get(segmentSize)
163+
c.buffers = append(c.buffers, tail)
164+
}
165+
tail.Write(p)
166+
}
167+
168+
func (c *flateWriter) Write(p []byte) (n int, err error) {
169+
c.write(p)
170+
if c.shouldCall() {
171+
err = c.cb(c.index, false, c.buffers[0].Bytes())
172+
binaryPool.Put(c.buffers[0])
173+
c.buffers = c.buffers[1:]
174+
c.index++
175+
}
176+
return n, err
177+
}
178+
179+
func (c *flateWriter) Flush() error {
180+
var buf = c.buffers[0]
181+
for i := 1; i < len(c.buffers); i++ {
182+
buf.Write(c.buffers[i].Bytes())
183+
binaryPool.Put(c.buffers[i])
184+
}
185+
if n := buf.Len(); n >= 4 {
186+
if tail := buf.Bytes()[n-4:]; binary.BigEndian.Uint32(tail) == math.MaxUint16 {
187+
buf.Truncate(n - 4)
188+
}
189+
}
190+
var err = c.cb(c.index, true, buf.Bytes())
191+
c.index++
192+
binaryPool.Put(buf)
193+
return err
194+
}
195+
196+
// 将io.Reader包装为io.WriterTo
197+
// Wrapping io.Reader as io.WriterTo
198+
type readerWrapper struct {
199+
r io.Reader
200+
sw *slideWindow
201+
}
202+
203+
// WriteTo 写入内容, 并更新字典
204+
// Write the contents, and update the dictionary
205+
func (c *readerWrapper) WriteTo(w io.Writer) (int64, error) {
206+
var buf = binaryPool.Get(segmentSize)
207+
defer binaryPool.Put(buf)
208+
209+
var p = buf.Bytes()[:segmentSize]
210+
var sum, n = 0, 0
211+
var err error
212+
for n, err = c.r.Read(p); err == nil || errors.Is(err, io.EOF); n, err = c.r.Read(p) {
213+
eof := errors.Is(err, io.EOF)
214+
if _, err = w.Write(p[:n]); err != nil {
215+
return int64(sum), err
216+
}
217+
sum += n
218+
_, _ = c.sw.Write(p[:n])
219+
if eof {
220+
break
221+
}
222+
}
223+
return int64(sum), err
224+
}
225+
226+
func compressTo(cpsWriter *flate.Writer, r io.WriterTo, w io.Writer, dict []byte) error {
227+
cpsWriter.ResetDict(w, dict)
228+
if _, err := r.WriteTo(cpsWriter); err != nil {
229+
return err
230+
}
231+
return cpsWriter.Flush()
232+
}

client.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,9 @@ func (c *connector) handshake() (*Conn, *http.Response, error) {
190190
writeQueue: workerQueue{maxConcurrency: 1},
191191
readQueue: make(channel, c.option.ParallelGolimit),
192192
}
193+
194+
// 压缩字典和解压字典内存开销比较大, 故使用懒加载
195+
// Compressing and decompressing dictionaries has a large memory overhead, so use lazy loading.
193196
if pd.Enabled {
194197
socket.deflater.initialize(false, pd, c.option.ReadMaxPayloadSize)
195198
if pd.ServerContextTakeover {

compress.go

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -100,17 +100,11 @@ func (c *deflater) Decompress(src *bytes.Buffer, dict []byte) (*bytes.Buffer, er
100100
func (c *deflater) Compress(src internal.Payload, dst *bytes.Buffer, dict []byte) error {
101101
c.cpsLocker.Lock()
102102
defer c.cpsLocker.Unlock()
103-
104-
c.cpsWriter.ResetDict(dst, dict)
105-
if _, err := src.WriteTo(c.cpsWriter); err != nil {
106-
return err
107-
}
108-
if err := c.cpsWriter.Flush(); err != nil {
103+
if err := compressTo(c.cpsWriter, src, dst, dict); err != nil {
109104
return err
110105
}
111106
if n := dst.Len(); n >= 4 {
112-
compressedContent := dst.Bytes()
113-
if tail := compressedContent[n-4:]; binary.BigEndian.Uint32(tail) == math.MaxUint16 {
107+
if tail := dst.Bytes()[n-4:]; binary.BigEndian.Uint32(tail) == math.MaxUint16 {
114108
dst.Truncate(n - 4)
115109
}
116110
}

compress_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,3 +259,7 @@ func (c *writerTo) Len() int {
259259
func (c *writerTo) WriteTo(w io.Writer) (n int64, err error) {
260260
return 0, errors.New("1")
261261
}
262+
263+
func (c *writerTo) Read(p []byte) (n int, err error) {
264+
return 0, errors.New("1")
265+
}

examples/chatroom/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ func main() {
6161

6262
func MustLoad[T any](session gws.SessionStorage, key string) (v T) {
6363
if value, exist := session.Load(key); exist {
64-
v = value.(T)
64+
v, _ = value.(T)
6565
}
6666
return
6767
}

0 commit comments

Comments
 (0)