From 681f0ac0c24d7b058630c08860d9df5f330c3aae Mon Sep 17 00:00:00 2001 From: Ichinose Shogo Date: Sun, 11 Jun 2017 12:02:10 +0900 Subject: [PATCH 1/2] add test for read timeout --- driver_test.go | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/driver_test.go b/driver_test.go index 206e07cc9..9c021ff8d 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1999,3 +1999,33 @@ func TestPing(t *testing.T) { } }) } + +func TestReadTimeout(t *testing.T) { + runTests(t, dsn+"&readTimeout=500ms", func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + startTime := time.Now() + + // This query will read-timeout. + if _, err := dbt.db.Exec("INSERT INTO test VALUES (SLEEP(1))"); err == nil { + dbt.Error("expected error") + } else if err, ok := err.(net.Error); !ok || !err.Timeout() { + dbt.Error("expected timeout error") + } + + if d := time.Since(startTime); d > time.Second { + dbt.Errorf("too long execution time: %s", d) + } + + // Wait for the query has done. + time.Sleep(time.Second) + + // Check how many times the query is executed. + var v int + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { + dbt.Errorf("expected val to be 1, got %d", v) + } + }) +} From 4dc1d49e2c3d906f42b67bee5dc7ab35e8ddab67 Mon Sep 17 00:00:00 2001 From: Ichinose Shogo Date: Sun, 11 Jun 2017 14:42:30 +0900 Subject: [PATCH 2/2] return original reading error, instead of driver.ErrBadConn. --- packets.go | 6 +++--- packets_test.go | 13 ++++++------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/packets.go b/packets.go index 9715067c4..c476e7b3e 100644 --- a/packets.go +++ b/packets.go @@ -35,7 +35,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { } errLog.Print(err) mc.Close() - return nil, driver.ErrBadConn + return nil, err } // packet length [24 bit] @@ -57,7 +57,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { if prevData == nil { errLog.Print(ErrMalformPkt) mc.Close() - return nil, driver.ErrBadConn + return nil, ErrMalformPkt } return prevData, nil @@ -71,7 +71,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { } errLog.Print(err) mc.Close() - return nil, driver.ErrBadConn + return nil, err } // return data if this was the last packet diff --git a/packets_test.go b/packets_test.go index 31c892d85..ffa38ba30 100644 --- a/packets_test.go +++ b/packets_test.go @@ -9,7 +9,6 @@ package mysql import ( - "database/sql/driver" "errors" "net" "testing" @@ -252,8 +251,8 @@ func TestReadPacketFail(t *testing.T) { conn.data = []byte{0x00, 0x00, 0x00, 0x00} conn.maxReads = 1 _, err := mc.readPacket() - if err != driver.ErrBadConn { - t.Errorf("expected ErrBadConn, got %v", err) + if err != ErrMalformPkt { + t.Errorf("expected %v, got %v", ErrMalformPkt, err) } // reset @@ -264,8 +263,8 @@ func TestReadPacketFail(t *testing.T) { // fail to read header conn.closed = true _, err = mc.readPacket() - if err != driver.ErrBadConn { - t.Errorf("expected ErrBadConn, got %v", err) + if err != errConnClosed { + t.Errorf("expected %v, got %v", errConnClosed, err) } // reset @@ -277,7 +276,7 @@ func TestReadPacketFail(t *testing.T) { // fail to read body conn.maxReads = 1 _, err = mc.readPacket() - if err != driver.ErrBadConn { - t.Errorf("expected ErrBadConn, got %v", err) + if err != errConnTooManyReads { + t.Errorf("expected %v, got %v", errConnTooManyReads, err) } }