Skip to content

Improve buffer handling #890

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 16, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ Shuode Li <elemount at qq.com>
Soroush Pour <me at soroushjp.com>
Stan Putrya <root.vagner at gmail.com>
Stanley Gunawan <gunawan.stanley at gmail.com>
Steven Hartland <steven.hartland at multiplay.co.uk>
Thomas Wodarek <wodarekwebpage at gmail.com>
Tom Jenkinson <tom at tjenkinson.me>
Xiangyu Hu <xiangyu.hu at outlook.com>
Expand All @@ -90,3 +91,4 @@ Keybase Inc.
Percona LLC
Pivotal Inc.
Stripe Inc.
Multiplay Ltd.
8 changes: 5 additions & 3 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,13 +360,15 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
pubKey := mc.cfg.pubKey
if pubKey == nil {
// request public key from server
data := mc.buf.takeSmallBuffer(4 + 1)
data, err := mc.buf.takeSmallBuffer(4 + 1)
if err != nil {
return err
}
data[4] = cachingSha2PasswordRequestPublicKey
mc.writePacket(data)

// parse public key
data, err := mc.readPacket()
if err != nil {
if data, err = mc.readPacket(); err != nil {
return err
}

Expand Down
49 changes: 31 additions & 18 deletions buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,17 @@ const defaultBufSize = 4096
// The buffer is similar to bufio.Reader / Writer but zero-copy-ish
// Also highly optimized for this particular use case.
type buffer struct {
buf []byte
buf []byte // buf is a byte buffer who's length and capacity are equal.
nc net.Conn
idx int
length int
timeout time.Duration
}

// newBuffer allocates and returns a new buffer.
func newBuffer(nc net.Conn) buffer {
var b [defaultBufSize]byte
return buffer{
buf: b[:],
buf: make([]byte, defaultBufSize),
nc: nc,
}
}
Expand Down Expand Up @@ -105,43 +105,56 @@ func (b *buffer) readNext(need int) ([]byte, error) {
return b.buf[offset:b.idx], nil
}

// returns a buffer with the requested size.
// takeBuffer returns a buffer with the requested size.
// If possible, a slice from the existing buffer is returned.
// Otherwise a bigger buffer is made.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeBuffer(length int) []byte {
func (b *buffer) takeBuffer(length int) ([]byte, error) {
if b.length > 0 {
return nil
return nil, ErrBusyBuffer
}

// test (cheap) general case first
if length <= defaultBufSize || length <= cap(b.buf) {
return b.buf[:length]
if length <= cap(b.buf) {
return b.buf[:length], nil
}

if length < maxPacketSize {
b.buf = make([]byte, length)
return b.buf
return b.buf, nil
}
return make([]byte, length)

// buffer is larger than we want to store.
return make([]byte, length), nil
}

// shortcut which can be used if the requested buffer is guaranteed to be
// smaller than defaultBufSize
// takeSmallBuffer is shortcut which can be used if length is
// known to be smaller than defaultBufSize.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeSmallBuffer(length int) []byte {
func (b *buffer) takeSmallBuffer(length int) ([]byte, error) {
if b.length > 0 {
return nil
return nil, ErrBusyBuffer
}
return b.buf[:length]
return b.buf[:length], nil
}

// takeCompleteBuffer returns the complete existing buffer.
// This can be used if the necessary buffer size is unknown.
// cap and len of the returned buffer will be equal.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeCompleteBuffer() []byte {
func (b *buffer) takeCompleteBuffer() ([]byte, error) {
if b.length > 0 {
return nil, ErrBusyBuffer
}
return b.buf, nil
}

// store stores buf, an updated buffer, if its suitable to do so.
func (b *buffer) store(buf []byte) error {
if b.length > 0 {
return nil
return ErrBusyBuffer
} else if cap(buf) <= maxPacketSize && cap(buf) > cap(b.buf) {
b.buf = buf[:cap(buf)]
}
return b.buf
return nil
}
6 changes: 3 additions & 3 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,10 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
return "", driver.ErrSkip
}

buf := mc.buf.takeCompleteBuffer()
if buf == nil {
buf, err := mc.buf.takeCompleteBuffer()
if err != nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
errLog.Print(err)
return "", ErrInvalidConn
}
buf = buf[:0]
Expand Down
2 changes: 1 addition & 1 deletion driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func RegisterDial(net string, dial DialFunc) {

// Open new Connection.
// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how
// the DSN string is formated
// the DSN string is formatted
func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
var err error

Expand Down
54 changes: 30 additions & 24 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
mc.sequence++

// packets with length 0 terminate a previous packet which is a
// multiple of (2^24)1 bytes long
// multiple of (2^24)-1 bytes long
if pktLen == 0 {
// there was no previous packet
if prevData == nil {
Expand Down Expand Up @@ -286,10 +286,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
}

// Calculate packet length and get buffer with that size
data := mc.buf.takeSmallBuffer(pktLen + 4)
if data == nil {
data, err := mc.buf.takeSmallBuffer(pktLen + 4)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
errLog.Print(err)
return errBadConnNoWrite
}

Expand Down Expand Up @@ -367,10 +367,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
pktLen := 4 + len(authData)
data := mc.buf.takeSmallBuffer(pktLen)
if data == nil {
data, err := mc.buf.takeSmallBuffer(pktLen)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
errLog.Print(err)
return errBadConnNoWrite
}

Expand All @@ -387,10 +387,10 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
// Reset Packet Sequence
mc.sequence = 0

data := mc.buf.takeSmallBuffer(4 + 1)
if data == nil {
data, err := mc.buf.takeSmallBuffer(4 + 1)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
errLog.Print(err)
return errBadConnNoWrite
}

Expand All @@ -406,10 +406,10 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
mc.sequence = 0

pktLen := 1 + len(arg)
data := mc.buf.takeBuffer(pktLen + 4)
if data == nil {
data, err := mc.buf.takeBuffer(pktLen + 4)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
errLog.Print(err)
return errBadConnNoWrite
}

Expand All @@ -427,10 +427,10 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
// Reset Packet Sequence
mc.sequence = 0

data := mc.buf.takeSmallBuffer(4 + 1 + 4)
if data == nil {
data, err := mc.buf.takeSmallBuffer(4 + 1 + 4)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
errLog.Print(err)
return errBadConnNoWrite
}

Expand Down Expand Up @@ -883,7 +883,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
const minPktLen = 4 + 1 + 4 + 1 + 4
mc := stmt.mc

// Determine threshould dynamically to avoid packet size shortage.
// Determine threshold dynamically to avoid packet size shortage.
longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1)
if longDataSize < 64 {
longDataSize = 64
Expand All @@ -893,15 +893,17 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
mc.sequence = 0

var data []byte
var err error

if len(args) == 0 {
data = mc.buf.takeBuffer(minPktLen)
data, err = mc.buf.takeBuffer(minPktLen)
} else {
data = mc.buf.takeCompleteBuffer()
data, err = mc.buf.takeCompleteBuffer()
// In this case the len(data) == cap(data) which is used to optimise the flow below.
}
if data == nil {
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
errLog.Print(err)
return errBadConnNoWrite
}

Expand All @@ -927,7 +929,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
pos := minPktLen

var nullMask []byte
if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) {
if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= cap(data) {
// buffer has to be extended but we don't know by how much so
// we depend on append after all data with known sizes fit.
// We stop at that because we deal with a lot of columns here
Expand All @@ -936,10 +938,11 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
copy(tmp[:pos], data[:pos])
data = tmp
nullMask = data[pos : pos+maskLen]
// No need to clean nullMask as make ensures that.
pos += maskLen
} else {
nullMask = data[pos : pos+maskLen]
for i := 0; i < maskLen; i++ {
for i := range nullMask {
nullMask[i] = 0
}
pos += maskLen
Expand Down Expand Up @@ -1076,7 +1079,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
// In that case we must build the data packet with the new values buffer
if valuesCap != cap(paramValues) {
data = append(data[:pos], paramValues...)
mc.buf.buf = data
if err = mc.buf.store(data); err != nil {
errLog.Print(err)
return errBadConnNoWrite
}
}

pos += len(paramValues)
Expand Down