Skip to content

Commit 6e8c06c

Browse files
committed
Fix buffer allocation issues
Fix a number or issues with how buffers for connections are processed including: * Fix a number of uses of len(..) instead of cap(..) eliminating unnecessary buffer reallocation and invalid assignment of buffers smaller than defaultBufSize, which could cause a panic. * Eliminate redundant size test in takeBuffer. * Change buffer takeXXX functions to return an error to make it explicit that they can fail and fix missing error check in handleAuthResult. * Always initialise nullMask in writeExecutePacket. Also: * Fix some typo's and unnecessary UTF-8 characters in comments. * Add benchmarks for buffer and connection creation to validate memory allocations. * Add myself / company to AUTHORS.
1 parent fd197cd commit 6e8c06c

File tree

7 files changed

+87
-50
lines changed

7 files changed

+87
-50
lines changed

AUTHORS

+2
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ Shuode Li <elemount at qq.com>
7373
Soroush Pour <me at soroushjp.com>
7474
Stan Putrya <root.vagner at gmail.com>
7575
Stanley Gunawan <gunawan.stanley at gmail.com>
76+
Steven Hartland <steven.hartland at multiplay.co.uk>
7677
Thomas Wodarek <wodarekwebpage at gmail.com>
7778
Tom Jenkinson <tom at tjenkinson.me>
7879
Xiangyu Hu <xiangyu.hu at outlook.com>
@@ -90,3 +91,4 @@ Keybase Inc.
9091
Percona LLC
9192
Pivotal Inc.
9293
Stripe Inc.
94+
Multiplay Ltd.

auth.go

+5-3
Original file line numberDiff line numberDiff line change
@@ -360,13 +360,15 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
360360
pubKey := mc.cfg.pubKey
361361
if pubKey == nil {
362362
// request public key from server
363-
data := mc.buf.takeSmallBuffer(4 + 1)
363+
data, err := mc.buf.takeSmallBuffer(4 + 1)
364+
if err != nil {
365+
return err
366+
}
364367
data[4] = cachingSha2PasswordRequestPublicKey
365368
mc.writePacket(data)
366369

367370
// parse public key
368-
data, err := mc.readPacket()
369-
if err != nil {
371+
if data, err = mc.readPacket(); err != nil {
370372
return err
371373
}
372374

benchmark_test.go

+27
Original file line numberDiff line numberDiff line change
@@ -317,3 +317,30 @@ func BenchmarkExecContext(b *testing.B) {
317317
})
318318
}
319319
}
320+
321+
var buf buffer
322+
323+
func BenchmarkNewBuffer(b *testing.B) {
324+
b.ReportAllocs()
325+
326+
var bu buffer
327+
for i := 0; i < b.N; i++ {
328+
bu = newBuffer(nil)
329+
}
330+
331+
buf = bu
332+
}
333+
334+
var con *mysqlConn
335+
336+
func BenchmarkNewConn(b *testing.B) {
337+
b.ReportAllocs()
338+
339+
var c *mysqlConn
340+
for i := 0; i < b.N; i++ {
341+
c = &mysqlConn{}
342+
c.buf = newBuffer(nil)
343+
}
344+
345+
con = c
346+
}

buffer.go

+16-13
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ type buffer struct {
2929
timeout time.Duration
3030
}
3131

32+
// newBuffer allocates and returns a new buffer.
3233
func newBuffer(nc net.Conn) buffer {
3334
var b [defaultBufSize]byte
3435
return buffer{
@@ -49,11 +50,13 @@ func (b *buffer) fill(need int) error {
4950
// grow buffer if necessary
5051
// TODO: let the buffer shrink again at some point
5152
// Maybe keep the org buf slice and swap back?
52-
if need > len(b.buf) {
53+
if need > cap(b.buf) {
5354
// Round up to the next multiple of the default size
5455
newBuf := make([]byte, ((need/defaultBufSize)+1)*defaultBufSize)
5556
copy(newBuf, b.buf)
5657
b.buf = newBuf
58+
} else if need > len(b.buf) {
59+
b.buf = b.buf[:need]
5760
}
5861

5962
b.idx = 0
@@ -109,39 +112,39 @@ func (b *buffer) readNext(need int) ([]byte, error) {
109112
// If possible, a slice from the existing buffer is returned.
110113
// Otherwise a bigger buffer is made.
111114
// Only one buffer (total) can be used at a time.
112-
func (b *buffer) takeBuffer(length int) []byte {
115+
func (b *buffer) takeBuffer(length int) ([]byte, error) {
113116
if b.length > 0 {
114-
return nil
117+
return nil, ErrBusyBuffer
115118
}
116119

117120
// test (cheap) general case first
118-
if length <= defaultBufSize || length <= cap(b.buf) {
119-
return b.buf[:length]
121+
if length <= cap(b.buf) {
122+
return b.buf[:length], nil
120123
}
121124

122125
if length < maxPacketSize {
123126
b.buf = make([]byte, length)
124-
return b.buf
127+
return b.buf, nil
125128
}
126-
return make([]byte, length)
129+
return make([]byte, length), nil
127130
}
128131

129132
// shortcut which can be used if the requested buffer is guaranteed to be
130133
// smaller than defaultBufSize
131134
// Only one buffer (total) can be used at a time.
132-
func (b *buffer) takeSmallBuffer(length int) []byte {
135+
func (b *buffer) takeSmallBuffer(length int) ([]byte, error) {
133136
if b.length > 0 {
134-
return nil
137+
return nil, ErrBusyBuffer
135138
}
136-
return b.buf[:length]
139+
return b.buf[:length], nil
137140
}
138141

139142
// takeCompleteBuffer returns the complete existing buffer.
140143
// This can be used if the necessary buffer size is unknown.
141144
// Only one buffer (total) can be used at a time.
142-
func (b *buffer) takeCompleteBuffer() []byte {
145+
func (b *buffer) takeCompleteBuffer() ([]byte, error) {
143146
if b.length > 0 {
144-
return nil
147+
return nil, ErrBusyBuffer
145148
}
146-
return b.buf
149+
return b.buf, nil
147150
}

connection.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -182,10 +182,10 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
182182
return "", driver.ErrSkip
183183
}
184184

185-
buf := mc.buf.takeCompleteBuffer()
186-
if buf == nil {
185+
buf, err := mc.buf.takeCompleteBuffer()
186+
if err != nil {
187187
// can not take the buffer. Something must be wrong with the connection
188-
errLog.Print(ErrBusyBuffer)
188+
errLog.Print(err)
189189
return "", ErrInvalidConn
190190
}
191191
buf = buf[:0]

driver.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ func RegisterDial(net string, dial DialFunc) {
5050

5151
// Open new Connection.
5252
// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how
53-
// the DSN string is formated
53+
// the DSN string is formatted
5454
func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
5555
var err error
5656

packets.go

+33-30
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
5151
mc.sequence++
5252

5353
// packets with length 0 terminate a previous packet which is a
54-
// multiple of (2^24)1 bytes long
54+
// multiple of (2^24)-1 bytes long
5555
if pktLen == 0 {
5656
// there was no previous packet
5757
if prevData == nil {
@@ -288,10 +288,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool,
288288
}
289289

290290
// Calculate packet length and get buffer with that size
291-
data := mc.buf.takeSmallBuffer(pktLen + 4)
292-
if data == nil {
291+
data, err := mc.buf.takeSmallBuffer(pktLen + 4)
292+
if err != nil {
293293
// cannot take the buffer. Something must be wrong with the connection
294-
errLog.Print(ErrBusyBuffer)
294+
errLog.Print(err)
295295
return errBadConnNoWrite
296296
}
297297

@@ -375,10 +375,10 @@ func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte, addNUL bool) error {
375375
if addNUL {
376376
pktLen++
377377
}
378-
data := mc.buf.takeSmallBuffer(pktLen)
379-
if data == nil {
378+
data, err := mc.buf.takeSmallBuffer(pktLen)
379+
if err != nil {
380380
// cannot take the buffer. Something must be wrong with the connection
381-
errLog.Print(ErrBusyBuffer)
381+
errLog.Print(err)
382382
return errBadConnNoWrite
383383
}
384384

@@ -399,10 +399,10 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
399399
// Reset Packet Sequence
400400
mc.sequence = 0
401401

402-
data := mc.buf.takeSmallBuffer(4 + 1)
403-
if data == nil {
402+
data, err := mc.buf.takeSmallBuffer(4 + 1)
403+
if err != nil {
404404
// cannot take the buffer. Something must be wrong with the connection
405-
errLog.Print(ErrBusyBuffer)
405+
errLog.Print(err)
406406
return errBadConnNoWrite
407407
}
408408

@@ -418,10 +418,10 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
418418
mc.sequence = 0
419419

420420
pktLen := 1 + len(arg)
421-
data := mc.buf.takeBuffer(pktLen + 4)
422-
if data == nil {
421+
data, err := mc.buf.takeBuffer(pktLen + 4)
422+
if err != nil {
423423
// cannot take the buffer. Something must be wrong with the connection
424-
errLog.Print(ErrBusyBuffer)
424+
errLog.Print(err)
425425
return errBadConnNoWrite
426426
}
427427

@@ -439,10 +439,10 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
439439
// Reset Packet Sequence
440440
mc.sequence = 0
441441

442-
data := mc.buf.takeSmallBuffer(4 + 1 + 4)
443-
if data == nil {
442+
data, err := mc.buf.takeSmallBuffer(4 + 1 + 4)
443+
if err != nil {
444444
// cannot take the buffer. Something must be wrong with the connection
445-
errLog.Print(ErrBusyBuffer)
445+
errLog.Print(err)
446446
return errBadConnNoWrite
447447
}
448448

@@ -895,7 +895,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
895895
const minPktLen = 4 + 1 + 4 + 1 + 4
896896
mc := stmt.mc
897897

898-
// Determine threshould dynamically to avoid packet size shortage.
898+
// Determine threshold dynamically to avoid packet size shortage.
899899
longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1)
900900
if longDataSize < 64 {
901901
longDataSize = 64
@@ -905,15 +905,16 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
905905
mc.sequence = 0
906906

907907
var data []byte
908+
var err error
908909

909910
if len(args) == 0 {
910-
data = mc.buf.takeBuffer(minPktLen)
911+
data, err = mc.buf.takeBuffer(minPktLen)
911912
} else {
912-
data = mc.buf.takeCompleteBuffer()
913+
data, err = mc.buf.takeCompleteBuffer()
913914
}
914-
if data == nil {
915+
if err != nil {
915916
// cannot take the buffer. Something must be wrong with the connection
916-
errLog.Print(ErrBusyBuffer)
917+
errLog.Print(err)
917918
return errBadConnNoWrite
918919
}
919920

@@ -939,23 +940,25 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
939940
pos := minPktLen
940941

941942
var nullMask []byte
942-
if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) {
943+
maskLen, typesLen := (len(args)+7)/8, 1+2*len(args)
944+
l := pos + maskLen + typesLen
945+
if l > cap(data) {
943946
// buffer has to be extended but we don't know by how much so
944947
// we depend on append after all data with known sizes fit.
945948
// We stop at that because we deal with a lot of columns here
946949
// which makes the required allocation size hard to guess.
947950
tmp := make([]byte, pos+maskLen+typesLen)
948951
copy(tmp[:pos], data[:pos])
949952
data = tmp
950-
nullMask = data[pos : pos+maskLen]
951-
pos += maskLen
952-
} else {
953-
nullMask = data[pos : pos+maskLen]
954-
for i := 0; i < maskLen; i++ {
955-
nullMask[i] = 0
956-
}
957-
pos += maskLen
953+
} else if l > len(data) {
954+
data = data[:l]
955+
}
956+
957+
nullMask = data[pos : pos+maskLen]
958+
for i := range nullMask {
959+
nullMask[i] = 0
958960
}
961+
pos += maskLen
959962

960963
// newParameterBoundFlag 1 [1 byte]
961964
data[pos] = 0x01

0 commit comments

Comments
 (0)