Skip to content

Commit 9fad4c0

Browse files
committed
refactoring
1 parent e8b96f2 commit 9fad4c0

File tree

9 files changed

+104
-114
lines changed

9 files changed

+104
-114
lines changed

benchmark_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,9 +234,8 @@ func BenchmarkInterpolation(b *testing.B) {
234234
},
235235
maxAllowedPacket: maxPacketSize,
236236
maxWriteSize: maxPacketSize - 1,
237-
buf: newBuffer(nil),
237+
buf: newBuffer(),
238238
}
239-
mc.packetRW = &mc.buf
240239

241240
args := []driver.Value{
242241
int64(42424242),

buffer.go

Lines changed: 12 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -10,31 +10,30 @@ package mysql
1010

1111
import (
1212
"io"
13-
"net"
14-
"time"
1513
)
1614

1715
const defaultBufSize = 4096
1816
const maxCachedBufSize = 256 * 1024
1917

18+
// readwriteFunc is a function that compatible with io.Reader and io.Writer.
19+
// We use this function type instead of io.ReadWriter because we want to
20+
// just pass mc.readWithTimeout or mc.writeWithTimeout functions.
21+
type readwriteFunc func([]byte) (int, error)
22+
2023
// A buffer which is used for both reading and writing.
2124
// This is possible since communication on each connection is synchronous.
2225
// In other words, we can't write and read simultaneously on the same connection.
2326
// The buffer is similar to bufio.Reader / Writer but zero-copy-ish
2427
// Also highly optimized for this particular use case.
2528
type buffer struct {
26-
buf []byte // read buffer.
27-
cachedBuf []byte // buffer that will be reused. len(cachedBuf) <= maxCachedBufSize.
28-
nc net.Conn
29-
readTimeout time.Duration
30-
writeTimeout time.Duration
29+
buf []byte // read buffer.
30+
cachedBuf []byte // buffer that will be reused. len(cachedBuf) <= maxCachedBufSize.
3131
}
3232

3333
// newBuffer allocates and returns a new buffer.
34-
func newBuffer(nc net.Conn) buffer {
34+
func newBuffer() buffer {
3535
return buffer{
3636
cachedBuf: make([]byte, defaultBufSize),
37-
nc: nc,
3837
}
3938
}
4039

@@ -44,7 +43,7 @@ func (b *buffer) busy() bool {
4443
}
4544

4645
// fill reads into the read buffer until at least _need_ bytes are in it.
47-
func (b *buffer) fill(need int) error {
46+
func (b *buffer) fill(need int, r readwriteFunc) error {
4847
// we'll move the contents of the current buffer to dest before filling it.
4948
dest := b.cachedBuf
5049

@@ -65,13 +64,7 @@ func (b *buffer) fill(need int) error {
6564
copy(dest[:n], b.buf)
6665

6766
for {
68-
if b.readTimeout > 0 {
69-
if err := b.nc.SetReadDeadline(time.Now().Add(b.readTimeout)); err != nil {
70-
return err
71-
}
72-
}
73-
74-
nn, err := b.nc.Read(dest[n:])
67+
nn, err := r(dest[n:])
7568
n += nn
7669

7770
if err == nil && n < need {
@@ -93,10 +86,10 @@ func (b *buffer) fill(need int) error {
9386

9487
// returns next N bytes from buffer.
9588
// The returned slice is only guaranteed to be valid until the next read
96-
func (b *buffer) readNext(need int) ([]byte, error) {
89+
func (b *buffer) readNext(need int, r readwriteFunc) ([]byte, error) {
9790
if len(b.buf) < need {
9891
// refill
99-
if err := b.fill(need); err != nil {
92+
if err := b.fill(need, r); err != nil {
10093
return nil, err
10194
}
10295
}
@@ -156,14 +149,3 @@ func (b *buffer) store(buf []byte) {
156149
b.cachedBuf = buf[:cap(buf)]
157150
}
158151
}
159-
160-
// writePackets is a proxy function to nc.Write.
161-
// This is used to make the buffer type compatible with compressed I/O.
162-
func (b *buffer) writePackets(packets []byte) (int, error) {
163-
if b.writeTimeout > 0 {
164-
if err := b.nc.SetWriteDeadline(time.Now().Add(b.writeTimeout)); err != nil {
165-
return 0, err
166-
}
167-
}
168-
return b.nc.Write(packets)
169-
}

compress.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ func zDecompress(src, dst []byte) (int, error) {
4747
}
4848
} else {
4949
zr = a.(io.ReadCloser)
50-
if zr.(zlib.Resetter).Reset(br, nil); err != nil {
50+
if err := zr.(zlib.Resetter).Reset(br, nil); err != nil {
5151
return 0, err
5252
}
5353
}
@@ -96,18 +96,18 @@ func newCompIO(mc *mysqlConn) *compIO {
9696
}
9797
}
9898

99-
func (c *compIO) readNext(need int) ([]byte, error) {
99+
func (c *compIO) readNext(need int, r readwriteFunc) ([]byte, error) {
100100
for c.buff.Len() < need {
101-
if err := c.readCompressedPacket(); err != nil {
101+
if err := c.readCompressedPacket(r); err != nil {
102102
return nil, err
103103
}
104104
}
105105
data := c.buff.Next(need)
106106
return data[:need:need], nil // prevent caller writes into c.buff
107107
}
108108

109-
func (c *compIO) readCompressedPacket() error {
110-
header, err := c.mc.buf.readNext(7) // size of compressed header
109+
func (c *compIO) readCompressedPacket(r readwriteFunc) error {
110+
header, err := c.mc.buf.readNext(7, r) // size of compressed header
111111
if err != nil {
112112
return err
113113
}
@@ -134,7 +134,7 @@ func (c *compIO) readCompressedPacket() error {
134134
c.mc.sequence = compressionSequence + 1
135135
c.mc.compressSequence = c.mc.sequence
136136

137-
comprData, err := c.mc.buf.readNext(comprLength)
137+
comprData, err := c.mc.buf.readNext(comprLength, r)
138138
if err != nil {
139139
return err
140140
}
@@ -221,7 +221,7 @@ func (c *compIO) writeCompressedPacket(data []byte, uncompressedLen int) error {
221221
data[3] = mc.compressSequence
222222
putUint24(data[4:7], uncompressedLen)
223223

224-
if _, err := mc.buf.writePackets(data); err != nil {
224+
if _, err := mc.writeWithTimeout(data); err != nil {
225225
mc.log("writing compressed packet:", err)
226226
return err
227227
}

compress_test.go

Lines changed: 22 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ package mysql
1111
import (
1212
"bytes"
1313
"crypto/rand"
14-
"fmt"
1514
"io"
1615
"testing"
1716
)
@@ -26,15 +25,12 @@ func makeRandByteSlice(size int) []byte {
2625
func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []byte {
2726
conn := new(mockConn)
2827
mc.netConn = conn
29-
comp := newCompIO(mc)
3028

31-
n, err := comp.writePackets(uncompressedPacket)
29+
err := mc.writePacket(append(make([]byte, 4), uncompressedPacket...))
3230
if err != nil {
3331
t.Fatal(err)
3432
}
35-
if n != len(uncompressedPacket) {
36-
t.Fatalf("expected to write %d bytes, wrote %d bytes", len(uncompressedPacket), n)
37-
}
33+
3834
return conn.written
3935
}
4036

@@ -43,10 +39,9 @@ func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte, expS
4339
// mocking out buf variable
4440
conn := new(mockConn)
4541
conn.data = compressedPacket
46-
mc.buf.nc = conn
47-
cr := newCompIO(mc)
42+
mc.netConn = conn
4843

49-
uncompressedPacket, err := cr.readNext(expSize)
44+
uncompressedPacket, err := mc.readPacket()
5045
if err != nil {
5146
if err != io.EOF {
5247
t.Fatalf("non-nil/non-EOF error when reading contents: %s", err.Error())
@@ -72,8 +67,6 @@ func TestRoundtrip(t *testing.T) {
7267
}{
7368
{uncompressed: []byte("a"),
7469
desc: "a"},
75-
{uncompressed: []byte{0},
76-
desc: "0 byte"},
7770
{uncompressed: []byte("hello world"),
7871
desc: "hello world"},
7972
{uncompressed: make([]byte, 100),
@@ -82,8 +75,6 @@ func TestRoundtrip(t *testing.T) {
8275
desc: "32768 bytes"},
8376
{uncompressed: make([]byte, 330000),
8477
desc: "33000 bytes"},
85-
{uncompressed: make([]byte, 0),
86-
desc: "nothing"},
8778
{uncompressed: makeRandByteSlice(10),
8879
desc: "10 rand bytes",
8980
},
@@ -100,26 +91,29 @@ func TestRoundtrip(t *testing.T) {
10091

10192
_, cSend := newRWMockConn(0)
10293
cSend.compress = true
94+
cSend.compIO = newCompIO(cSend)
10395
_, cReceive := newRWMockConn(0)
10496
cReceive.compress = true
97+
cReceive.compIO = newCompIO(cReceive)
10598

10699
for _, test := range tests {
107-
s := fmt.Sprintf("Test roundtrip with %s", test.desc)
108-
cSend.resetSequenceNr()
109-
cReceive.resetSequenceNr()
100+
t.Run(test.desc, func(t *testing.T) {
101+
cSend.resetSequenceNr()
102+
cReceive.resetSequenceNr()
110103

111-
uncompressed := roundtripHelper(t, cSend, cReceive, test.uncompressed)
112-
if !bytes.Equal(uncompressed, test.uncompressed) {
113-
t.Fatalf("%s: roundtrip failed", s)
114-
}
104+
uncompressed := roundtripHelper(t, cSend, cReceive, test.uncompressed)
105+
if !bytes.Equal(uncompressed, test.uncompressed) {
106+
t.Fatalf("roundtrip failed")
107+
}
115108

116-
if cSend.sequence != cReceive.sequence {
117-
t.Errorf("inconsistent sequence number: send=%v recv=%v",
118-
cSend.sequence, cReceive.sequence)
119-
}
120-
if cSend.compressSequence != cReceive.compressSequence {
121-
t.Errorf("inconsistent compress sequence number: send=%v recv=%v",
122-
cSend.compressSequence, cReceive.compressSequence)
123-
}
109+
if cSend.sequence != cReceive.sequence {
110+
t.Errorf("inconsistent sequence number: send=%v recv=%v",
111+
cSend.sequence, cReceive.sequence)
112+
}
113+
if cSend.compressSequence != cReceive.compressSequence {
114+
t.Errorf("inconsistent compress sequence number: send=%v recv=%v",
115+
cSend.compressSequence, cReceive.compressSequence)
116+
}
117+
})
124118
}
125119
}

connection.go

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ type mysqlConn struct {
2828
netConn net.Conn
2929
rawConn net.Conn // underlying connection when netConn is TLS connection.
3030
result mysqlResult // managed by clearResult() and handleOkPacket().
31-
packetRW packetIO
31+
compIO *compIO
3232
cfg *Config
3333
connector *connector
3434
maxAllowedPacket int
@@ -64,9 +64,24 @@ func (mc *mysqlConn) log(v ...any) {
6464
mc.cfg.Logger.Print(v...)
6565
}
6666

67-
type packetIO interface {
68-
readNext(need int) ([]byte, error)
69-
writePackets(data []byte) (int, error)
67+
func (mc *mysqlConn) readWithTimeout(b []byte) (int, error) {
68+
to := mc.cfg.ReadTimeout
69+
if to > 0 {
70+
if err := mc.netConn.SetReadDeadline(time.Now().Add(to)); err != nil {
71+
return 0, err
72+
}
73+
}
74+
return mc.netConn.Read(b)
75+
}
76+
77+
func (mc *mysqlConn) writeWithTimeout(b []byte) (int, error) {
78+
to := mc.cfg.WriteTimeout
79+
if to > 0 {
80+
if err := mc.netConn.SetWriteDeadline(time.Now().Add(to)); err != nil {
81+
return 0, err
82+
}
83+
}
84+
return mc.netConn.Write(b)
7085
}
7186

7287
func (mc *mysqlConn) resetSequenceNr() {

connection_test.go

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,12 @@ import (
1919

2020
func TestInterpolateParams(t *testing.T) {
2121
mc := &mysqlConn{
22-
buf: newBuffer(nil),
22+
buf: newBuffer(),
2323
maxAllowedPacket: maxPacketSize,
2424
cfg: &Config{
2525
InterpolateParams: true,
2626
},
2727
}
28-
mc.packetRW = &mc.buf
2928

3029
q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42), "gopher"})
3130
if err != nil {
@@ -40,7 +39,7 @@ func TestInterpolateParams(t *testing.T) {
4039

4140
func TestInterpolateParamsJSONRawMessage(t *testing.T) {
4241
mc := &mysqlConn{
43-
buf: newBuffer(nil),
42+
buf: newBuffer(),
4443
maxAllowedPacket: maxPacketSize,
4544
cfg: &Config{
4645
InterpolateParams: true,
@@ -67,13 +66,12 @@ func TestInterpolateParamsJSONRawMessage(t *testing.T) {
6766

6867
func TestInterpolateParamsTooManyPlaceholders(t *testing.T) {
6968
mc := &mysqlConn{
70-
buf: newBuffer(nil),
69+
buf: newBuffer(),
7170
maxAllowedPacket: maxPacketSize,
7271
cfg: &Config{
7372
InterpolateParams: true,
7473
},
7574
}
76-
mc.packetRW = &mc.buf
7775

7876
q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42)})
7977
if err != driver.ErrSkip {
@@ -85,15 +83,13 @@ func TestInterpolateParamsTooManyPlaceholders(t *testing.T) {
8583
// https://github.com/go-sql-driver/mysql/pull/490
8684
func TestInterpolateParamsPlaceholderInString(t *testing.T) {
8785
mc := &mysqlConn{
88-
buf: newBuffer(nil),
86+
buf: newBuffer(),
8987
maxAllowedPacket: maxPacketSize,
9088
cfg: &Config{
9189
InterpolateParams: true,
9290
},
9391
}
9492

95-
mc.packetRW = &mc.buf
96-
9793
q, err := mc.interpolateParams("SELECT 'abc?xyz',?", []driver.Value{int64(42)})
9894
// When InterpolateParams support string literal, this should return `"SELECT 'abc?xyz', 42`
9995
if err != driver.ErrSkip {
@@ -103,7 +99,7 @@ func TestInterpolateParamsPlaceholderInString(t *testing.T) {
10399

104100
func TestInterpolateParamsUint64(t *testing.T) {
105101
mc := &mysqlConn{
106-
buf: newBuffer(nil),
102+
buf: newBuffer(),
107103
maxAllowedPacket: maxPacketSize,
108104
cfg: &Config{
109105
InterpolateParams: true,
@@ -164,11 +160,10 @@ func TestCleanCancel(t *testing.T) {
164160
func TestPingMarkBadConnection(t *testing.T) {
165161
nc := badConnection{err: errors.New("boom")}
166162

167-
buf := newBuffer(nc)
163+
buf := newBuffer()
168164
mc := &mysqlConn{
169165
netConn: nc,
170166
buf: buf,
171-
packetRW: &buf,
172167
maxAllowedPacket: defaultMaxAllowedPacket,
173168
closech: make(chan struct{}),
174169
cfg: NewConfig(),
@@ -184,11 +179,9 @@ func TestPingMarkBadConnection(t *testing.T) {
184179
func TestPingErrInvalidConn(t *testing.T) {
185180
nc := badConnection{err: errors.New("failed to write"), n: 10}
186181

187-
buf := newBuffer(nc)
188182
mc := &mysqlConn{
189183
netConn: nc,
190-
buf: buf,
191-
packetRW: &buf,
184+
buf: newBuffer(),
192185
maxAllowedPacket: defaultMaxAllowedPacket,
193186
closech: make(chan struct{}),
194187
cfg: NewConfig(),

0 commit comments

Comments
 (0)