Skip to content

Commit 2abb4ae

Browse files
author
Evan Shaw
committed
Add context support for reads
1 parent 669fc71 commit 2abb4ae

9 files changed

+88
-77
lines changed

buffer.go

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ func newBuffer(nc net.Conn) buffer {
3838
}
3939

4040
// fill reads into the buffer until at least _need_ bytes are in it
41-
func (b *buffer) fill(need int) error {
41+
func (b *buffer) fill(ctx mysqlContext, need int) error {
4242
n := b.length
4343

4444
// move existing data to the beginning
@@ -59,8 +59,14 @@ func (b *buffer) fill(need int) error {
5959
b.idx = 0
6060

6161
for {
62-
if b.timeout > 0 {
63-
if err := b.nc.SetReadDeadline(time.Now().Add(b.timeout)); err != nil {
62+
var deadline time.Time
63+
if ctxDeadline, ok := ctx.Deadline(); ok {
64+
deadline = ctxDeadline
65+
} else if b.timeout > 0 {
66+
deadline = time.Now().Add(b.timeout)
67+
}
68+
if !deadline.IsZero() {
69+
if err := b.nc.SetReadDeadline(deadline); err != nil {
6470
return err
6571
}
6672
}
@@ -91,10 +97,10 @@ func (b *buffer) fill(need int) error {
9197

9298
// returns next N bytes from buffer.
9399
// The returned slice is only guaranteed to be valid until the next read
94-
func (b *buffer) readNext(need int) ([]byte, error) {
100+
func (b *buffer) readNext(ctx mysqlContext, need int) ([]byte, error) {
95101
if b.length < need {
96102
// refill
97-
if err := b.fill(need); err != nil {
103+
if err := b.fill(ctx, need); err != nil {
98104
return nil, err
99105
}
100106
}

connection.go

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -128,16 +128,16 @@ func (mc *mysqlConn) prepareContext(ctx mysqlContext, query string) (driver.Stmt
128128
}
129129

130130
// Read Result
131-
columnCount, err := stmt.readPrepareResultPacket()
131+
columnCount, err := stmt.readPrepareResultPacket(ctx)
132132
if err == nil {
133133
if stmt.paramCount > 0 {
134-
if err = mc.readUntilEOF(); err != nil {
134+
if err = mc.readUntilEOF(ctx); err != nil {
135135
return nil, err
136136
}
137137
}
138138

139139
if columnCount > 0 {
140-
err = mc.readUntilEOF()
140+
err = mc.readUntilEOF(ctx)
141141
}
142142
}
143143

@@ -307,24 +307,24 @@ func (mc *mysqlConn) exec(ctx mysqlContext, query string) error {
307307
}
308308

309309
// Read Result
310-
resLen, err := mc.readResultSetHeaderPacket()
310+
resLen, err := mc.readResultSetHeaderPacket(ctx)
311311
if err != nil {
312312
return err
313313
}
314314

315315
if resLen > 0 {
316316
// columns
317-
if err := mc.readUntilEOF(); err != nil {
317+
if err := mc.readUntilEOF(ctx); err != nil {
318318
return err
319319
}
320320

321321
// rows
322-
if err := mc.readUntilEOF(); err != nil {
322+
if err := mc.readUntilEOF(ctx); err != nil {
323323
return err
324324
}
325325
}
326326

327-
return mc.discardResults()
327+
return mc.discardResults(ctx)
328328
}
329329

330330
// Query implements driver.Queryer interface
@@ -353,7 +353,7 @@ func (mc *mysqlConn) queryContext(ctx mysqlContext, query string, args []driver.
353353
if err == nil {
354354
// Read Result
355355
var resLen int
356-
resLen, err = mc.readResultSetHeaderPacket()
356+
resLen, err = mc.readResultSetHeaderPacket(ctx)
357357
if err == nil {
358358
rows := new(textRows)
359359
rows.mc = mc
@@ -369,7 +369,7 @@ func (mc *mysqlConn) queryContext(ctx mysqlContext, query string, args []driver.
369369
}
370370
}
371371
// Columns
372-
rows.rs.columns, err = mc.readColumns(resLen)
372+
rows.rs.columns, err = mc.readColumns(ctx, resLen)
373373
return rows, err
374374
}
375375
}
@@ -378,29 +378,29 @@ func (mc *mysqlConn) queryContext(ctx mysqlContext, query string, args []driver.
378378

379379
// Gets the value of the given MySQL System Variable
380380
// The returned byte slice is only valid until the next read
381-
func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
381+
func (mc *mysqlConn) getSystemVar(ctx mysqlContext, name string) ([]byte, error) {
382382
// Send command
383-
if err := mc.writeCommandPacketStr(backgroundCtx(), comQuery, "SELECT @@"+name); err != nil {
383+
if err := mc.writeCommandPacketStr(ctx, comQuery, "SELECT @@"+name); err != nil {
384384
return nil, err
385385
}
386386

387387
// Read Result
388-
resLen, err := mc.readResultSetHeaderPacket()
388+
resLen, err := mc.readResultSetHeaderPacket(ctx)
389389
if err == nil {
390390
rows := new(textRows)
391391
rows.mc = mc
392392
rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
393393

394394
if resLen > 0 {
395395
// Columns
396-
if err := mc.readUntilEOF(); err != nil {
396+
if err := mc.readUntilEOF(ctx); err != nil {
397397
return nil, err
398398
}
399399
}
400400

401401
dest := make([]driver.Value, resLen)
402-
if err = rows.readRow(dest); err == nil {
403-
return dest[0].([]byte), mc.readUntilEOF()
402+
if err = rows.readRow(ctx, dest); err == nil {
403+
return dest[0].([]byte), mc.readUntilEOF(ctx)
404404
}
405405
}
406406
return nil, err

connection_ctx.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ func (mc *mysqlConn) Ping(ctx context.Context) error {
2626
return err
2727
}
2828

29-
if _, err := mc.readResultOK(); err != nil {
29+
if _, err := mc.readResultOK(ctx); err != nil {
3030
errLog.Print(err)
3131
return err
3232
}

driver.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
8989
mc.writeTimeout = mc.cfg.WriteTimeout
9090

9191
// Reading Handshake Initialization Packet
92-
cipher, err := mc.readInitPacket()
92+
cipher, err := mc.readInitPacket(backgroundCtx())
9393
if err != nil {
9494
mc.cleanup()
9595
return nil, err
@@ -114,7 +114,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
114114
mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket
115115
} else {
116116
// Get max allowed packet size
117-
maxap, err := mc.getSystemVar("max_allowed_packet")
117+
maxap, err := mc.getSystemVar(backgroundCtx(), "max_allowed_packet")
118118
if err != nil {
119119
mc.Close()
120120
return nil, err
@@ -137,7 +137,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
137137

138138
func handleAuthResult(mc *mysqlConn, oldCipher []byte) error {
139139
// Read Result Packet
140-
cipher, err := mc.readResultOK()
140+
cipher, err := mc.readResultOK(backgroundCtx())
141141
if err == nil {
142142
return nil // auth successful
143143
}
@@ -161,20 +161,20 @@ func handleAuthResult(mc *mysqlConn, oldCipher []byte) error {
161161
if err = mc.writeOldAuthPacket(backgroundCtx(), cipher); err != nil {
162162
return err
163163
}
164-
_, err = mc.readResultOK()
164+
_, err = mc.readResultOK(backgroundCtx())
165165
} else if mc.cfg.AllowCleartextPasswords && err == ErrCleartextPassword {
166166
// Retry with clear text password for
167167
// http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html
168168
// http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html
169169
if err = mc.writeClearAuthPacket(backgroundCtx()); err != nil {
170170
return err
171171
}
172-
_, err = mc.readResultOK()
172+
_, err = mc.readResultOK(backgroundCtx())
173173
} else if mc.cfg.AllowNativePasswords && err == ErrNativePassword {
174174
if err = mc.writeNativeAuthPacket(backgroundCtx(), cipher); err != nil {
175175
return err
176176
}
177-
_, err = mc.readResultOK()
177+
_, err = mc.readResultOK(backgroundCtx())
178178
}
179179
return err
180180
}

infile.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,10 +173,10 @@ func (mc *mysqlConn) handleInFileRequest(ctx mysqlContext, name string) (err err
173173

174174
// read OK packet
175175
if err == nil {
176-
_, err = mc.readResultOK()
176+
_, err = mc.readResultOK(ctx)
177177
return err
178178
}
179179

180-
mc.readPacket()
180+
mc.readPacket(ctx)
181181
return err
182182
}

packets.go

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,16 @@ import (
2424
// http://dev.mysql.com/doc/internals/en/client-server-protocol.html
2525

2626
// Read packet to buffer 'data'
27-
func (mc *mysqlConn) readPacket() ([]byte, error) {
27+
func (mc *mysqlConn) readPacket(ctx mysqlContext) ([]byte, error) {
28+
ctxDeadline, isCtxDeadlineSet := ctx.Deadline()
29+
if isCtxDeadlineSet && !ctxDeadline.After(time.Now()) {
30+
return nil, deadlineExceeded
31+
}
32+
2833
var prevData []byte
2934
for {
3035
// read packet header
31-
data, err := mc.buf.readNext(4)
36+
data, err := mc.buf.readNext(ctx, 4)
3237
if err != nil {
3338
errLog.Print(err)
3439
mc.Close()
@@ -61,7 +66,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
6166
}
6267

6368
// read packet body [pktLen bytes]
64-
data, err = mc.buf.readNext(pktLen)
69+
data, err = mc.buf.readNext(ctx, pktLen)
6570
if err != nil {
6671
errLog.Print(err)
6772
mc.Close()
@@ -152,8 +157,8 @@ func (mc *mysqlConn) writePacket(ctx mysqlContext, data []byte) error {
152157

153158
// Handshake Initialization Packet
154159
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
155-
func (mc *mysqlConn) readInitPacket() ([]byte, error) {
156-
data, err := mc.readPacket()
160+
func (mc *mysqlConn) readInitPacket(ctx mysqlContext) ([]byte, error) {
161+
data, err := mc.readPacket(ctx)
157162
if err != nil {
158163
return nil, err
159164
}
@@ -484,8 +489,8 @@ func (mc *mysqlConn) writeCommandPacketUint32(ctx mysqlContext, command byte, ar
484489
******************************************************************************/
485490

486491
// Returns error if Packet is not an 'Result OK'-Packet
487-
func (mc *mysqlConn) readResultOK() ([]byte, error) {
488-
data, err := mc.readPacket()
492+
func (mc *mysqlConn) readResultOK(ctx mysqlContext) ([]byte, error) {
493+
data, err := mc.readPacket(ctx)
489494
if err == nil {
490495
// packet indicator
491496
switch data[0] {
@@ -526,8 +531,8 @@ func (mc *mysqlConn) readResultOK() ([]byte, error) {
526531

527532
// Result Set Header Packet
528533
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset
529-
func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
530-
data, err := mc.readPacket()
534+
func (mc *mysqlConn) readResultSetHeaderPacket(ctx mysqlContext) (int, error) {
535+
data, err := mc.readPacket(ctx)
531536
if err == nil {
532537
switch data[0] {
533538

@@ -616,11 +621,11 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {
616621

617622
// Read Packets as Field Packets until EOF-Packet or an Error appears
618623
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41
619-
func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
624+
func (mc *mysqlConn) readColumns(ctx mysqlContext, count int) ([]mysqlField, error) {
620625
columns := make([]mysqlField, count)
621626

622627
for i := 0; ; i++ {
623-
data, err := mc.readPacket()
628+
data, err := mc.readPacket(ctx)
624629
if err != nil {
625630
return nil, err
626631
}
@@ -709,14 +714,14 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
709714

710715
// Read Packets as Field Packets until EOF-Packet or an Error appears
711716
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow
712-
func (rows *textRows) readRow(dest []driver.Value) error {
717+
func (rows *textRows) readRow(ctx mysqlContext, dest []driver.Value) error {
713718
mc := rows.mc
714719

715720
if rows.rs.done {
716721
return io.EOF
717722
}
718723

719-
data, err := mc.readPacket()
724+
data, err := mc.readPacket(ctx)
720725
if err != nil {
721726
return err
722727
}
@@ -777,9 +782,9 @@ func (rows *textRows) readRow(dest []driver.Value) error {
777782
}
778783

779784
// Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read
780-
func (mc *mysqlConn) readUntilEOF() error {
785+
func (mc *mysqlConn) readUntilEOF(ctx mysqlContext) error {
781786
for {
782-
data, err := mc.readPacket()
787+
data, err := mc.readPacket(ctx)
783788
if err != nil {
784789
return err
785790
}
@@ -802,8 +807,8 @@ func (mc *mysqlConn) readUntilEOF() error {
802807

803808
// Prepare Result Packets
804809
// http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html
805-
func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
806-
data, err := stmt.mc.readPacket()
810+
func (stmt *mysqlStmt) readPrepareResultPacket(ctx mysqlContext) (uint16, error) {
811+
data, err := stmt.mc.readPacket(ctx)
807812
if err == nil {
808813
// packet indicator [1 byte]
809814
if data[0] != iOK {
@@ -1096,19 +1101,19 @@ func (stmt *mysqlStmt) writeExecutePacket(ctx mysqlContext, args []driver.Value)
10961101
return mc.writePacket(ctx, data)
10971102
}
10981103

1099-
func (mc *mysqlConn) discardResults() error {
1104+
func (mc *mysqlConn) discardResults(ctx mysqlContext) error {
11001105
for mc.status&statusMoreResultsExists != 0 {
1101-
resLen, err := mc.readResultSetHeaderPacket()
1106+
resLen, err := mc.readResultSetHeaderPacket(ctx)
11021107
if err != nil {
11031108
return err
11041109
}
11051110
if resLen > 0 {
11061111
// columns
1107-
if err := mc.readUntilEOF(); err != nil {
1112+
if err := mc.readUntilEOF(ctx); err != nil {
11081113
return err
11091114
}
11101115
// rows
1111-
if err := mc.readUntilEOF(); err != nil {
1116+
if err := mc.readUntilEOF(ctx); err != nil {
11121117
return err
11131118
}
11141119
}
@@ -1117,8 +1122,8 @@ func (mc *mysqlConn) discardResults() error {
11171122
}
11181123

11191124
// http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html
1120-
func (rows *binaryRows) readRow(dest []driver.Value) error {
1121-
data, err := rows.mc.readPacket()
1125+
func (rows *binaryRows) readRow(ctx mysqlContext, dest []driver.Value) error {
1126+
data, err := rows.mc.readPacket(ctx)
11221127
if err != nil {
11231128
return err
11241129
}

0 commit comments

Comments
 (0)