Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion pkg/driver/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,16 @@ func (conn *BackendConnection) Connect(ctx context.Context) error {
} else {
typ = conn.conf.Net
}
netConn, err := net.Dial(typ, conn.conf.Addr)

var (
netConn net.Conn
err error
)
if conn.conf.Timeout > 0 {
netConn, err = net.DialTimeout(typ, conn.conf.Addr, conn.conf.Timeout)
} else {
netConn, err = net.Dial(typ, conn.conf.Addr)
}
if err != nil {
return err
}
Expand Down
2 changes: 2 additions & 0 deletions pkg/errors/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,6 @@ var (
// to trigger a resend.
// See https://github.com/go-sql-driver/mysql/pull/302
ErrBadConnNoWrite = errors.New("bad connection")

ErrUnexpectedRead = errors.New("unexpected read from socket")
)
64 changes: 64 additions & 0 deletions pkg/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"net"
"strings"
"sync"
"syscall"
"time"

"github.com/pkg/errors"
Expand Down Expand Up @@ -117,6 +118,9 @@ type Conn struct {
// currentEphemeralBuffer for tracking allocated temporary buffer for writes and reads respectively.
// It can be allocated from bufPool or heap and should be recycled in the same manner.
currentEphemeralBuffer *[]byte

ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout
}

// NewConn is an internal method to create a Conn. Used by client and server
Expand Down Expand Up @@ -474,6 +478,17 @@ func (c *Conn) WritePacket(data []byte) error {
w, unget := c.getWriter()
defer unget()

if c.ReadTimeout != 0 {
err := c.conn.SetReadDeadline(time.Now().Add(c.ReadTimeout))
if err != nil {
return err
}
}
err := connCheck(c.conn)
if err != nil {
return err
}

for {
// Packet length is capped to MaxPacketSize.
packetLength := length
Expand All @@ -487,6 +502,13 @@ func (c *Conn) WritePacket(data []byte) error {
header[1] = byte(packetLength >> 8)
header[2] = byte(packetLength >> 16)
header[3] = c.sequence

if c.WriteTimeout > 0 {
if err := c.conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout)); err != nil {
return err
}
}

if n, err := w.Write(header[:]); err != nil {
return errors.Wrapf(err, "Write(header) failed")
} else if n != 4 {
Expand Down Expand Up @@ -995,6 +1017,14 @@ func (c *Conn) SetUserName(userName string) {
c.userName = userName
}

func (c *Conn) SetReadTimeout(readTimeout time.Duration) {
c.ReadTimeout = readTimeout
}

func (c *Conn) SetWriteTimeout(writeTimeout time.Duration) {
c.WriteTimeout = writeTimeout
}

// RemoteAddr returns the underlying socket RemoteAddr().
func (c *Conn) RemoteAddr() net.Addr {
return c.conn.RemoteAddr()
Expand Down Expand Up @@ -1034,3 +1064,37 @@ func (c *Conn) Close() {
func (c *Conn) IsClosed() bool {
return c.closed.Get()
}

func connCheck(conn net.Conn) error {
var sysErr error

sysConn, ok := conn.(syscall.Conn)
if !ok {
return nil
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

better to return error instead of nil?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only syscall.Conn SyscallConn() func return RawConn, we use RawConn to check connection alive. If a connection is't syscall.Conn does't mean it can't read.

}
rawConn, err := sysConn.SyscallConn()
if err != nil {
return err
}

err = rawConn.Read(func(fd uintptr) bool {
var buf [1]byte
n, err := syscall.Read(int(fd), buf[:])
switch {
case n == 0 && err == nil:
sysErr = io.EOF
case n > 0:
sysErr = err2.ErrUnexpectedRead
case err == syscall.EAGAIN || err == syscall.EWOULDBLOCK:
sysErr = nil
default:
sysErr = err
}
return true
})
if err != nil {
return err
}

return sysErr
}