diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b25c9e389..e62ec7c99 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -96,7 +96,7 @@ jobs: - name: test run: | - go test -v '-covermode=count' '-coverprofile=coverage.out' + go test -v -race '-covermode=atomic' '-coverprofile=coverage.out' - name: Send coverage uses: shogo82148/actions-goveralls@v1 diff --git a/atomic_bool_go118.go b/atomic_bool_go118.go index 2e9a7f0b6..ff0b1e115 100644 --- a/atomic_bool_go118.go +++ b/atomic_bool_go118.go @@ -16,6 +16,22 @@ import "sync/atomic" * Sync utils * ******************************************************************************/ +// noCopy may be embedded into structs which must not be copied +// after the first use. +// +// See https://github.com/golang/go/issues/8005#issuecomment-190753527 +// for details. +type noCopy struct{} + +// Lock is a no-op used by -copylocks checker from `go vet`. +func (*noCopy) Lock() {} + +// Unlock is a no-op used by -copylocks checker from `go vet`. +// noCopy should implement sync.Locker from Go 1.11 +// https://github.com/golang/go/commit/c2eba53e7f80df21d51285879d51ab81bcfbf6bc +// https://github.com/golang/go/issues/26165 +func (*noCopy) Unlock() {} + // atomicBool is an implementation of atomic.Bool for older version of Go. // it is a wrapper around uint32 for usage as a boolean value with // atomic access. diff --git a/auth.go b/auth.go index bab282bd2..b6c6c0e93 100644 --- a/auth.go +++ b/auth.go @@ -9,6 +9,7 @@ package mysql import ( + "context" "crypto/rand" "crypto/rsa" "crypto/sha1" @@ -225,12 +226,12 @@ func encryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, return rsa.EncryptOAEP(sha1, rand.Reader, pub, plain, nil) } -func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) error { +func (mc *mysqlConn) sendEncryptedPassword(ctx context.Context, seed []byte, pub *rsa.PublicKey) error { enc, err := encryptPassword(mc.cfg.Passwd, seed, pub) if err != nil { return err } - return mc.writeAuthSwitchPacket(enc) + return mc.writeAuthSwitchPacket(ctx, enc) } func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { @@ -296,9 +297,9 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { } } -func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { +func (mc *mysqlConn) handleAuthResult(ctx context.Context, oldAuthData []byte, plugin string) error { // Read Result Packet - authData, newPlugin, err := mc.readAuthResult() + authData, newPlugin, err := mc.readAuthResult(ctx) if err != nil { return err } @@ -320,12 +321,12 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { if err != nil { return err } - if err = mc.writeAuthSwitchPacket(authResp); err != nil { + if err = mc.writeAuthSwitchPacket(ctx, authResp); err != nil { return err } // Read Result Packet - authData, newPlugin, err = mc.readAuthResult() + authData, newPlugin, err = mc.readAuthResult(ctx) if err != nil { return err } @@ -346,34 +347,32 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { case 1: switch authData[0] { case cachingSha2PasswordFastAuthSuccess: - if err = mc.resultUnchanged().readResultOK(); err == nil { + if err = mc.resultUnchanged().readResultOK(ctx); err == nil { return nil // auth successful } case cachingSha2PasswordPerformFullAuthentication: if mc.cfg.TLS != nil || mc.cfg.Net == "unix" { // write cleartext auth packet - err = mc.writeAuthSwitchPacket(append([]byte(mc.cfg.Passwd), 0)) + err = mc.writeAuthSwitchPacket(ctx, append([]byte(mc.cfg.Passwd), 0)) if err != nil { return err } } else { pubKey := mc.cfg.pubKey if pubKey == nil { - // request public key from server - data, err := mc.buf.takeSmallBuffer(4 + 1) - if err != nil { - return err - } + data := mc.wbuf.takeBuffer(5) data[4] = cachingSha2PasswordRequestPublicKey - err = mc.writePacket(data) + err = mc.writePacket(ctx, data) if err != nil { return err } - if data, err = mc.readPacket(); err != nil { + packet, err := mc.readPacket(ctx) + if err != nil { return err } + data = packet.data if data[0] != iAuthMoreData { return fmt.Errorf("unexpected resp from server for caching_sha2_password, perform full authentication") @@ -392,12 +391,12 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { } // send encrypted password - err = mc.sendEncryptedPassword(oldAuthData, pubKey) + err = mc.sendEncryptedPassword(ctx, oldAuthData, pubKey) if err != nil { return err } } - return mc.resultUnchanged().readResultOK() + return mc.resultUnchanged().readResultOK(ctx) default: return ErrMalformPkt @@ -422,11 +421,11 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { } // send encrypted password - err = mc.sendEncryptedPassword(oldAuthData, pub.(*rsa.PublicKey)) + err = mc.sendEncryptedPassword(ctx, oldAuthData, pub.(*rsa.PublicKey)) if err != nil { return err } - return mc.resultUnchanged().readResultOK() + return mc.resultUnchanged().readResultOK(ctx) } default: diff --git a/auth_test.go b/auth_test.go index 3ce0ea6e0..edce42d07 100644 --- a/auth_test.go +++ b/auth_test.go @@ -9,9 +9,7 @@ package mysql import ( - "bytes" "crypto/rsa" - "crypto/tls" "crypto/x509" "encoding/pem" "fmt" @@ -75,1256 +73,1257 @@ func TestScrambleSHA256Pass(t *testing.T) { } } -func TestAuthFastCachingSHA256PasswordCached(t *testing.T) { - conn, mc := newRWMockConn(1) - mc.cfg.User = "root" - mc.cfg.Passwd = "secret" - - authData := []byte{90, 105, 74, 126, 30, 48, 37, 56, 3, 23, 115, 127, 69, - 22, 41, 84, 32, 123, 43, 118} - plugin := "caching_sha2_password" - - // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) - if err != nil { - t.Fatal(err) - } - err = mc.writeHandshakeResponsePacket(authResp, plugin) - if err != nil { - t.Fatal(err) - } - - // check written auth response - authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 - authRespEnd := authRespStart + 1 + len(authResp) - writtenAuthRespLen := conn.written[authRespStart] - writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] - expectedAuthResp := []byte{102, 32, 5, 35, 143, 161, 140, 241, 171, 232, 56, - 139, 43, 14, 107, 196, 249, 170, 147, 60, 220, 204, 120, 178, 214, 15, - 184, 150, 26, 61, 57, 235} - if writtenAuthRespLen != 32 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { - t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) - } - conn.written = nil - - // auth response - conn.data = []byte{ - 2, 0, 0, 2, 1, 3, // Fast Auth Success - 7, 0, 0, 3, 0, 0, 0, 2, 0, 0, 0, // OK - } - conn.maxReads = 1 - - // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { - t.Errorf("got error: %v", err) - } -} - -func TestAuthFastCachingSHA256PasswordEmpty(t *testing.T) { - conn, mc := newRWMockConn(1) - mc.cfg.User = "root" - mc.cfg.Passwd = "" - - authData := []byte{90, 105, 74, 126, 30, 48, 37, 56, 3, 23, 115, 127, 69, - 22, 41, 84, 32, 123, 43, 118} - plugin := "caching_sha2_password" - - // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) - if err != nil { - t.Fatal(err) - } - err = mc.writeHandshakeResponsePacket(authResp, plugin) - if err != nil { - t.Fatal(err) - } - - // check written auth response - authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 - authRespEnd := authRespStart + 1 + len(authResp) - writtenAuthRespLen := conn.written[authRespStart] - writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] - if writtenAuthRespLen != 0 { - t.Fatalf("unexpected written auth response (%d bytes): %v", - writtenAuthRespLen, writtenAuthResp) - } - conn.written = nil - - // auth response - conn.data = []byte{ - 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK - } - conn.maxReads = 1 - - // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { - t.Errorf("got error: %v", err) - } -} - -func TestAuthFastCachingSHA256PasswordFullRSA(t *testing.T) { - conn, mc := newRWMockConn(1) - mc.cfg.User = "root" - mc.cfg.Passwd = "secret" - - authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, - 62, 94, 83, 80, 52, 85} - plugin := "caching_sha2_password" - - // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) - if err != nil { - t.Fatal(err) - } - err = mc.writeHandshakeResponsePacket(authResp, plugin) - if err != nil { - t.Fatal(err) - } - - // check written auth response - authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 - authRespEnd := authRespStart + 1 + len(authResp) - writtenAuthRespLen := conn.written[authRespStart] - writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] - expectedAuthResp := []byte{171, 201, 138, 146, 89, 159, 11, 170, 0, 67, 165, - 49, 175, 94, 218, 68, 177, 109, 110, 86, 34, 33, 44, 190, 67, 240, 70, - 110, 40, 139, 124, 41} - if writtenAuthRespLen != 32 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { - t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) - } - conn.written = nil - - // auth response - conn.data = []byte{ - 2, 0, 0, 2, 1, 4, // Perform Full Authentication - } - conn.queuedReplies = [][]byte{ - // pub key response - append([]byte{byte(1 + len(testPubKey)), 1, 0, 4, 1}, testPubKey...), - - // OK - {7, 0, 0, 6, 0, 0, 0, 2, 0, 0, 0}, - } - conn.maxReads = 3 - - // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { - t.Errorf("got error: %v", err) - } - - if !bytes.HasPrefix(conn.written, []byte{1, 0, 0, 3, 2, 0, 1, 0, 5}) { - t.Errorf("unexpected written data: %v", conn.written) - } -} - -func TestAuthFastCachingSHA256PasswordFullRSAWithKey(t *testing.T) { - conn, mc := newRWMockConn(1) - mc.cfg.User = "root" - mc.cfg.Passwd = "secret" - mc.cfg.pubKey = testPubKeyRSA - - authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, - 62, 94, 83, 80, 52, 85} - plugin := "caching_sha2_password" - - // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) - if err != nil { - t.Fatal(err) - } - err = mc.writeHandshakeResponsePacket(authResp, plugin) - if err != nil { - t.Fatal(err) - } - - // check written auth response - authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 - authRespEnd := authRespStart + 1 + len(authResp) - writtenAuthRespLen := conn.written[authRespStart] - writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] - expectedAuthResp := []byte{171, 201, 138, 146, 89, 159, 11, 170, 0, 67, 165, - 49, 175, 94, 218, 68, 177, 109, 110, 86, 34, 33, 44, 190, 67, 240, 70, - 110, 40, 139, 124, 41} - if writtenAuthRespLen != 32 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { - t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) - } - conn.written = nil - - // auth response - conn.data = []byte{ - 2, 0, 0, 2, 1, 4, // Perform Full Authentication - } - conn.queuedReplies = [][]byte{ - // OK - {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, - } - conn.maxReads = 2 - - // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { - t.Errorf("got error: %v", err) - } - - if !bytes.HasPrefix(conn.written, []byte{0, 1, 0, 3}) { - t.Errorf("unexpected written data: %v", conn.written) - } -} - -func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) { - conn, mc := newRWMockConn(1) - mc.cfg.User = "root" - mc.cfg.Passwd = "secret" - - authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, - 62, 94, 83, 80, 52, 85} - plugin := "caching_sha2_password" - - // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) - if err != nil { - t.Fatal(err) - } - err = mc.writeHandshakeResponsePacket(authResp, plugin) - if err != nil { - t.Fatal(err) - } - - // Hack to make the caching_sha2_password plugin believe that the connection - // is secure - mc.cfg.TLS = &tls.Config{InsecureSkipVerify: true} - - // check written auth response - authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 - authRespEnd := authRespStart + 1 + len(authResp) - writtenAuthRespLen := conn.written[authRespStart] - writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] - expectedAuthResp := []byte{171, 201, 138, 146, 89, 159, 11, 170, 0, 67, 165, - 49, 175, 94, 218, 68, 177, 109, 110, 86, 34, 33, 44, 190, 67, 240, 70, - 110, 40, 139, 124, 41} - if writtenAuthRespLen != 32 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { - t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) - } - conn.written = nil - - // auth response - conn.data = []byte{ - 2, 0, 0, 2, 1, 4, // Perform Full Authentication - } - conn.queuedReplies = [][]byte{ - // OK - {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, - } - conn.maxReads = 3 - - // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { - t.Errorf("got error: %v", err) - } - - if !bytes.Equal(conn.written, []byte{7, 0, 0, 3, 115, 101, 99, 114, 101, 116, 0}) { - t.Errorf("unexpected written data: %v", conn.written) - } -} - -func TestAuthFastCleartextPasswordNotAllowed(t *testing.T) { - _, mc := newRWMockConn(1) - mc.cfg.User = "root" - mc.cfg.Passwd = "secret" - - authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, - 103, 26, 95, 81, 17, 24, 21} - plugin := "mysql_clear_password" - - // Send Client Authentication Packet - _, err := mc.auth(authData, plugin) - if err != ErrCleartextPassword { - t.Errorf("expected ErrCleartextPassword, got %v", err) - } -} - -func TestAuthFastCleartextPassword(t *testing.T) { - conn, mc := newRWMockConn(1) - mc.cfg.User = "root" - mc.cfg.Passwd = "secret" - mc.cfg.AllowCleartextPasswords = true - - authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, - 103, 26, 95, 81, 17, 24, 21} - plugin := "mysql_clear_password" - - // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) - if err != nil { - t.Fatal(err) - } - err = mc.writeHandshakeResponsePacket(authResp, plugin) - if err != nil { - t.Fatal(err) - } - - // check written auth response - authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 - authRespEnd := authRespStart + 1 + len(authResp) - writtenAuthRespLen := conn.written[authRespStart] - writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] - expectedAuthResp := []byte{115, 101, 99, 114, 101, 116, 0} - if writtenAuthRespLen != 7 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { - t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) - } - conn.written = nil - - // auth response - conn.data = []byte{ - 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK - } - conn.maxReads = 1 - - // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { - t.Errorf("got error: %v", err) - } -} - -func TestAuthFastCleartextPasswordEmpty(t *testing.T) { - conn, mc := newRWMockConn(1) - mc.cfg.User = "root" - mc.cfg.Passwd = "" - mc.cfg.AllowCleartextPasswords = true - - authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, - 103, 26, 95, 81, 17, 24, 21} - plugin := "mysql_clear_password" - - // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) - if err != nil { - t.Fatal(err) - } - err = mc.writeHandshakeResponsePacket(authResp, plugin) - if err != nil { - t.Fatal(err) - } - - // check written auth response - authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 - authRespEnd := authRespStart + 1 + len(authResp) - writtenAuthRespLen := conn.written[authRespStart] - writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] - expectedAuthResp := []byte{0} - if writtenAuthRespLen != 1 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { - t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) - } - conn.written = nil - - // auth response - conn.data = []byte{ - 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK - } - conn.maxReads = 1 - - // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { - t.Errorf("got error: %v", err) - } -} - -func TestAuthFastNativePasswordNotAllowed(t *testing.T) { - _, mc := newRWMockConn(1) - mc.cfg.User = "root" - mc.cfg.Passwd = "secret" - mc.cfg.AllowNativePasswords = false - - authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, - 103, 26, 95, 81, 17, 24, 21} - plugin := "mysql_native_password" - - // Send Client Authentication Packet - _, err := mc.auth(authData, plugin) - if err != ErrNativePassword { - t.Errorf("expected ErrNativePassword, got %v", err) - } -} - -func TestAuthFastNativePassword(t *testing.T) { - conn, mc := newRWMockConn(1) - mc.cfg.User = "root" - mc.cfg.Passwd = "secret" - - authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, - 103, 26, 95, 81, 17, 24, 21} - plugin := "mysql_native_password" - - // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) - if err != nil { - t.Fatal(err) - } - err = mc.writeHandshakeResponsePacket(authResp, plugin) - if err != nil { - t.Fatal(err) - } - - // check written auth response - authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 - authRespEnd := authRespStart + 1 + len(authResp) - writtenAuthRespLen := conn.written[authRespStart] - writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] - expectedAuthResp := []byte{53, 177, 140, 159, 251, 189, 127, 53, 109, 252, - 172, 50, 211, 192, 240, 164, 26, 48, 207, 45} - if writtenAuthRespLen != 20 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { - t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) - } - conn.written = nil - - // auth response - conn.data = []byte{ - 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK - } - conn.maxReads = 1 - - // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { - t.Errorf("got error: %v", err) - } -} - -func TestAuthFastNativePasswordEmpty(t *testing.T) { - conn, mc := newRWMockConn(1) - mc.cfg.User = "root" - mc.cfg.Passwd = "" - - authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, - 103, 26, 95, 81, 17, 24, 21} - plugin := "mysql_native_password" - - // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) - if err != nil { - t.Fatal(err) - } - err = mc.writeHandshakeResponsePacket(authResp, plugin) - if err != nil { - t.Fatal(err) - } - - // check written auth response - authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 - authRespEnd := authRespStart + 1 + len(authResp) - writtenAuthRespLen := conn.written[authRespStart] - writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] - if writtenAuthRespLen != 0 { - t.Fatalf("unexpected written auth response (%d bytes): %v", - writtenAuthRespLen, writtenAuthResp) - } - conn.written = nil - - // auth response - conn.data = []byte{ - 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK - } - conn.maxReads = 1 - - // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { - t.Errorf("got error: %v", err) - } -} - -func TestAuthFastSHA256PasswordEmpty(t *testing.T) { - conn, mc := newRWMockConn(1) - mc.cfg.User = "root" - mc.cfg.Passwd = "" - - authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, - 62, 94, 83, 80, 52, 85} - plugin := "sha256_password" - - // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) - if err != nil { - t.Fatal(err) - } - err = mc.writeHandshakeResponsePacket(authResp, plugin) - if err != nil { - t.Fatal(err) - } - - // check written auth response - authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 - authRespEnd := authRespStart + 1 + len(authResp) - writtenAuthRespLen := conn.written[authRespStart] - writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] - expectedAuthResp := []byte{0} - if writtenAuthRespLen != 1 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { - t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) - } - conn.written = nil - - // auth response (pub key response) - conn.data = append([]byte{byte(1 + len(testPubKey)), 1, 0, 2, 1}, testPubKey...) - conn.queuedReplies = [][]byte{ - // OK - {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, - } - conn.maxReads = 2 - - // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { - t.Errorf("got error: %v", err) - } - - if !bytes.HasPrefix(conn.written, []byte{0, 1, 0, 3}) { - t.Errorf("unexpected written data: %v", conn.written) - } -} - -func TestAuthFastSHA256PasswordRSA(t *testing.T) { - conn, mc := newRWMockConn(1) - mc.cfg.User = "root" - mc.cfg.Passwd = "secret" - - authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, - 62, 94, 83, 80, 52, 85} - plugin := "sha256_password" - - // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) - if err != nil { - t.Fatal(err) - } - err = mc.writeHandshakeResponsePacket(authResp, plugin) - if err != nil { - t.Fatal(err) - } - - // check written auth response - authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 - authRespEnd := authRespStart + 1 + len(authResp) - writtenAuthRespLen := conn.written[authRespStart] - writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] - expectedAuthResp := []byte{1} - if writtenAuthRespLen != 1 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { - t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) - } - conn.written = nil - - // auth response (pub key response) - conn.data = append([]byte{byte(1 + len(testPubKey)), 1, 0, 2, 1}, testPubKey...) - conn.queuedReplies = [][]byte{ - // OK - {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, - } - conn.maxReads = 2 - - // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { - t.Errorf("got error: %v", err) - } - - if !bytes.HasPrefix(conn.written, []byte{0, 1, 0, 3}) { - t.Errorf("unexpected written data: %v", conn.written) - } -} - -func TestAuthFastSHA256PasswordRSAWithKey(t *testing.T) { - conn, mc := newRWMockConn(1) - mc.cfg.User = "root" - mc.cfg.Passwd = "secret" - mc.cfg.pubKey = testPubKeyRSA - - authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, - 62, 94, 83, 80, 52, 85} - plugin := "sha256_password" - - // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) - if err != nil { - t.Fatal(err) - } - err = mc.writeHandshakeResponsePacket(authResp, plugin) - if err != nil { - t.Fatal(err) - } - - // auth response (OK) - conn.data = []byte{7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0} - conn.maxReads = 1 - - // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { - t.Errorf("got error: %v", err) - } -} - -func TestAuthFastSHA256PasswordSecure(t *testing.T) { - conn, mc := newRWMockConn(1) - mc.cfg.User = "root" - mc.cfg.Passwd = "secret" - - // hack to make the caching_sha2_password plugin believe that the connection - // is secure - mc.cfg.TLS = &tls.Config{InsecureSkipVerify: true} - - authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, - 62, 94, 83, 80, 52, 85} - plugin := "sha256_password" - - // send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) - if err != nil { - t.Fatal(err) - } - - // unset TLS config to prevent the actual establishment of a TLS wrapper - mc.cfg.TLS = nil - - err = mc.writeHandshakeResponsePacket(authResp, plugin) - if err != nil { - t.Fatal(err) - } - - // check written auth response - authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 - authRespEnd := authRespStart + 1 + len(authResp) - writtenAuthRespLen := conn.written[authRespStart] - writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] - expectedAuthResp := []byte{115, 101, 99, 114, 101, 116, 0} - if writtenAuthRespLen != 7 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { - t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) - } - conn.written = nil - - // auth response (OK) - conn.data = []byte{7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0} - conn.maxReads = 1 - - // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { - t.Errorf("got error: %v", err) - } - - if !bytes.Equal(conn.written, []byte{}) { - t.Errorf("unexpected written data: %v", conn.written) - } -} - -func TestAuthSwitchCachingSHA256PasswordCached(t *testing.T) { - conn, mc := newRWMockConn(2) - mc.cfg.Passwd = "secret" - - // auth switch request - conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95, - 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101, - 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84, - 50, 0} - - // auth response - conn.queuedReplies = [][]byte{ - {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, // OK - } - conn.maxReads = 3 - - authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, - 47, 43, 9, 41, 112, 67, 110} - plugin := "mysql_native_password" - - if err := mc.handleAuthResult(authData, plugin); err != nil { - t.Errorf("got error: %v", err) - } - - expectedReply := []byte{ - // 1. Packet: Hash - 32, 0, 0, 3, 129, 93, 132, 95, 114, 48, 79, 215, 128, 62, 193, 118, 128, - 54, 75, 208, 159, 252, 227, 215, 129, 15, 242, 97, 19, 159, 31, 20, 58, - 153, 9, 130, - } - if !bytes.Equal(conn.written, expectedReply) { - t.Errorf("got unexpected data: %v", conn.written) - } -} - -func TestAuthSwitchCachingSHA256PasswordEmpty(t *testing.T) { - conn, mc := newRWMockConn(2) - mc.cfg.Passwd = "" - - // auth switch request - conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95, - 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101, - 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84, - 50, 0} - - // auth response - conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}} - conn.maxReads = 2 - - authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, - 47, 43, 9, 41, 112, 67, 110} - plugin := "mysql_native_password" - - if err := mc.handleAuthResult(authData, plugin); err != nil { - t.Errorf("got error: %v", err) - } - - expectedReply := []byte{0, 0, 0, 3} - if !bytes.Equal(conn.written, expectedReply) { - t.Errorf("got unexpected data: %v", conn.written) - } -} - -func TestAuthSwitchCachingSHA256PasswordFullRSA(t *testing.T) { - conn, mc := newRWMockConn(2) - mc.cfg.Passwd = "secret" - - // auth switch request - conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95, - 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101, - 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84, - 50, 0} - - conn.queuedReplies = [][]byte{ - // Perform Full Authentication - {2, 0, 0, 4, 1, 4}, - - // Pub Key Response - append([]byte{byte(1 + len(testPubKey)), 1, 0, 6, 1}, testPubKey...), - - // OK - {7, 0, 0, 8, 0, 0, 0, 2, 0, 0, 0}, - } - conn.maxReads = 4 - - authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, - 47, 43, 9, 41, 112, 67, 110} - plugin := "mysql_native_password" - - if err := mc.handleAuthResult(authData, plugin); err != nil { - t.Errorf("got error: %v", err) - } - - expectedReplyPrefix := []byte{ - // 1. Packet: Hash - 32, 0, 0, 3, 129, 93, 132, 95, 114, 48, 79, 215, 128, 62, 193, 118, 128, - 54, 75, 208, 159, 252, 227, 215, 129, 15, 242, 97, 19, 159, 31, 20, 58, - 153, 9, 130, - - // 2. Packet: Pub Key Request - 1, 0, 0, 5, 2, - - // 3. Packet: Encrypted Password - 0, 1, 0, 7, // [changing bytes] - } - if !bytes.HasPrefix(conn.written, expectedReplyPrefix) { - t.Errorf("got unexpected data: %v", conn.written) - } -} - -func TestAuthSwitchCachingSHA256PasswordFullRSAWithKey(t *testing.T) { - conn, mc := newRWMockConn(2) - mc.cfg.Passwd = "secret" - mc.cfg.pubKey = testPubKeyRSA - - // auth switch request - conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95, - 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101, - 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84, - 50, 0} - - conn.queuedReplies = [][]byte{ - // Perform Full Authentication - {2, 0, 0, 4, 1, 4}, - - // OK - {7, 0, 0, 6, 0, 0, 0, 2, 0, 0, 0}, - } - conn.maxReads = 3 - - authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, - 47, 43, 9, 41, 112, 67, 110} - plugin := "mysql_native_password" - - if err := mc.handleAuthResult(authData, plugin); err != nil { - t.Errorf("got error: %v", err) - } - - expectedReplyPrefix := []byte{ - // 1. Packet: Hash - 32, 0, 0, 3, 129, 93, 132, 95, 114, 48, 79, 215, 128, 62, 193, 118, 128, - 54, 75, 208, 159, 252, 227, 215, 129, 15, 242, 97, 19, 159, 31, 20, 58, - 153, 9, 130, - - // 2. Packet: Encrypted Password - 0, 1, 0, 5, // [changing bytes] - } - if !bytes.HasPrefix(conn.written, expectedReplyPrefix) { - t.Errorf("got unexpected data: %v", conn.written) - } -} - -func TestAuthSwitchCachingSHA256PasswordFullSecure(t *testing.T) { - conn, mc := newRWMockConn(2) - mc.cfg.Passwd = "secret" - - // Hack to make the caching_sha2_password plugin believe that the connection - // is secure - mc.cfg.TLS = &tls.Config{InsecureSkipVerify: true} - - // auth switch request - conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95, - 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101, - 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84, - 50, 0} - - // auth response - conn.queuedReplies = [][]byte{ - {2, 0, 0, 4, 1, 4}, // Perform Full Authentication - {7, 0, 0, 6, 0, 0, 0, 2, 0, 0, 0}, // OK - } - conn.maxReads = 3 - - authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, - 47, 43, 9, 41, 112, 67, 110} - plugin := "mysql_native_password" - - if err := mc.handleAuthResult(authData, plugin); err != nil { - t.Errorf("got error: %v", err) - } - - expectedReply := []byte{ - // 1. Packet: Hash - 32, 0, 0, 3, 129, 93, 132, 95, 114, 48, 79, 215, 128, 62, 193, 118, 128, - 54, 75, 208, 159, 252, 227, 215, 129, 15, 242, 97, 19, 159, 31, 20, 58, - 153, 9, 130, - - // 2. Packet: Cleartext password - 7, 0, 0, 5, 115, 101, 99, 114, 101, 116, 0, - } - if !bytes.Equal(conn.written, expectedReply) { - t.Errorf("got unexpected data: %v", conn.written) - } -} - -func TestAuthSwitchCleartextPasswordNotAllowed(t *testing.T) { - conn, mc := newRWMockConn(2) - - conn.data = []byte{22, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 99, 108, - 101, 97, 114, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0} - conn.maxReads = 1 - authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, - 47, 43, 9, 41, 112, 67, 110} - plugin := "mysql_native_password" - err := mc.handleAuthResult(authData, plugin) - if err != ErrCleartextPassword { - t.Errorf("expected ErrCleartextPassword, got %v", err) - } -} - -func TestAuthSwitchCleartextPassword(t *testing.T) { - conn, mc := newRWMockConn(2) - mc.cfg.AllowCleartextPasswords = true - mc.cfg.Passwd = "secret" - - // auth switch request - conn.data = []byte{22, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 99, 108, - 101, 97, 114, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0} - - // auth response - conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}} - conn.maxReads = 2 - - authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, - 47, 43, 9, 41, 112, 67, 110} - plugin := "mysql_native_password" - - if err := mc.handleAuthResult(authData, plugin); err != nil { - t.Errorf("got error: %v", err) - } - - expectedReply := []byte{7, 0, 0, 3, 115, 101, 99, 114, 101, 116, 0} - if !bytes.Equal(conn.written, expectedReply) { - t.Errorf("got unexpected data: %v", conn.written) - } -} - -func TestAuthSwitchCleartextPasswordEmpty(t *testing.T) { - conn, mc := newRWMockConn(2) - mc.cfg.AllowCleartextPasswords = true - mc.cfg.Passwd = "" - - // auth switch request - conn.data = []byte{22, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 99, 108, - 101, 97, 114, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0} - - // auth response - conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}} - conn.maxReads = 2 - - authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, - 47, 43, 9, 41, 112, 67, 110} - plugin := "mysql_native_password" - - if err := mc.handleAuthResult(authData, plugin); err != nil { - t.Errorf("got error: %v", err) - } - - expectedReply := []byte{1, 0, 0, 3, 0} - if !bytes.Equal(conn.written, expectedReply) { - t.Errorf("got unexpected data: %v", conn.written) - } -} - -func TestAuthSwitchNativePasswordNotAllowed(t *testing.T) { - conn, mc := newRWMockConn(2) - mc.cfg.AllowNativePasswords = false - - conn.data = []byte{44, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 110, 97, - 116, 105, 118, 101, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 96, - 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, 48, 31, 89, 39, 55, - 31, 0} - conn.maxReads = 1 - authData := []byte{96, 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, - 48, 31, 89, 39, 55, 31} - plugin := "caching_sha2_password" - err := mc.handleAuthResult(authData, plugin) - if err != ErrNativePassword { - t.Errorf("expected ErrNativePassword, got %v", err) - } -} - -func TestAuthSwitchNativePassword(t *testing.T) { - conn, mc := newRWMockConn(2) - mc.cfg.AllowNativePasswords = true - mc.cfg.Passwd = "secret" - - // auth switch request - conn.data = []byte{44, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 110, 97, - 116, 105, 118, 101, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 96, - 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, 48, 31, 89, 39, 55, - 31, 0} - - // auth response - conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}} - conn.maxReads = 2 - - authData := []byte{96, 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, - 48, 31, 89, 39, 55, 31} - plugin := "caching_sha2_password" - - if err := mc.handleAuthResult(authData, plugin); err != nil { - t.Errorf("got error: %v", err) - } - - expectedReply := []byte{20, 0, 0, 3, 202, 41, 195, 164, 34, 226, 49, 103, - 21, 211, 167, 199, 227, 116, 8, 48, 57, 71, 149, 146} - if !bytes.Equal(conn.written, expectedReply) { - t.Errorf("got unexpected data: %v", conn.written) - } -} - -func TestAuthSwitchNativePasswordEmpty(t *testing.T) { - conn, mc := newRWMockConn(2) - mc.cfg.AllowNativePasswords = true - mc.cfg.Passwd = "" - - // auth switch request - conn.data = []byte{44, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 110, 97, - 116, 105, 118, 101, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 96, - 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, 48, 31, 89, 39, 55, - 31, 0} - - // auth response - conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}} - conn.maxReads = 2 - - authData := []byte{96, 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, - 48, 31, 89, 39, 55, 31} - plugin := "caching_sha2_password" - - if err := mc.handleAuthResult(authData, plugin); err != nil { - t.Errorf("got error: %v", err) - } - - expectedReply := []byte{0, 0, 0, 3} - if !bytes.Equal(conn.written, expectedReply) { - t.Errorf("got unexpected data: %v", conn.written) - } -} - -func TestAuthSwitchOldPasswordNotAllowed(t *testing.T) { - conn, mc := newRWMockConn(2) - - conn.data = []byte{41, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 111, 108, - 100, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 95, 84, 103, 43, 61, - 49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107, 0} - conn.maxReads = 1 - authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, - 84, 96, 101, 92, 123, 121, 107} - plugin := "mysql_native_password" - err := mc.handleAuthResult(authData, plugin) - if err != ErrOldPassword { - t.Errorf("expected ErrOldPassword, got %v", err) - } -} - -// Same to TestAuthSwitchOldPasswordNotAllowed, but use OldAuthSwitch request. -func TestOldAuthSwitchNotAllowed(t *testing.T) { - conn, mc := newRWMockConn(2) - - // OldAuthSwitch request - conn.data = []byte{1, 0, 0, 2, 0xfe} - conn.maxReads = 1 - authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, - 84, 96, 101, 92, 123, 121, 107} - plugin := "mysql_native_password" - err := mc.handleAuthResult(authData, plugin) - if err != ErrOldPassword { - t.Errorf("expected ErrOldPassword, got %v", err) - } -} - -func TestAuthSwitchOldPassword(t *testing.T) { - conn, mc := newRWMockConn(2) - mc.cfg.AllowOldPasswords = true - mc.cfg.Passwd = "secret" - - // auth switch request - conn.data = []byte{41, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 111, 108, - 100, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 95, 84, 103, 43, 61, - 49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107, 0} - - // auth response - conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}} - conn.maxReads = 2 - - authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, - 84, 96, 101, 92, 123, 121, 107} - plugin := "mysql_native_password" - - if err := mc.handleAuthResult(authData, plugin); err != nil { - t.Errorf("got error: %v", err) - } - - expectedReply := []byte{9, 0, 0, 3, 86, 83, 83, 79, 74, 78, 65, 66, 0} - if !bytes.Equal(conn.written, expectedReply) { - t.Errorf("got unexpected data: %v", conn.written) - } -} - -// Same to TestAuthSwitchOldPassword, but use OldAuthSwitch request. -func TestOldAuthSwitch(t *testing.T) { - conn, mc := newRWMockConn(2) - mc.cfg.AllowOldPasswords = true - mc.cfg.Passwd = "secret" - - // OldAuthSwitch request - conn.data = []byte{1, 0, 0, 2, 0xfe} - - // auth response - conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}} - conn.maxReads = 2 - - authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, - 84, 96, 101, 92, 123, 121, 107} - plugin := "mysql_native_password" - - if err := mc.handleAuthResult(authData, plugin); err != nil { - t.Errorf("got error: %v", err) - } - - expectedReply := []byte{9, 0, 0, 3, 86, 83, 83, 79, 74, 78, 65, 66, 0} - if !bytes.Equal(conn.written, expectedReply) { - t.Errorf("got unexpected data: %v", conn.written) - } -} -func TestAuthSwitchOldPasswordEmpty(t *testing.T) { - conn, mc := newRWMockConn(2) - mc.cfg.AllowOldPasswords = true - mc.cfg.Passwd = "" - - // auth switch request - conn.data = []byte{41, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 111, 108, - 100, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 95, 84, 103, 43, 61, - 49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107, 0} - - // auth response - conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}} - conn.maxReads = 2 - - authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, - 84, 96, 101, 92, 123, 121, 107} - plugin := "mysql_native_password" - - if err := mc.handleAuthResult(authData, plugin); err != nil { - t.Errorf("got error: %v", err) - } - - expectedReply := []byte{0, 0, 0, 3} - if !bytes.Equal(conn.written, expectedReply) { - t.Errorf("got unexpected data: %v", conn.written) - } -} - -// Same to TestAuthSwitchOldPasswordEmpty, but use OldAuthSwitch request. -func TestOldAuthSwitchPasswordEmpty(t *testing.T) { - conn, mc := newRWMockConn(2) - mc.cfg.AllowOldPasswords = true - mc.cfg.Passwd = "" - - // OldAuthSwitch request. - conn.data = []byte{1, 0, 0, 2, 0xfe} - - // auth response - conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}} - conn.maxReads = 2 - - authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, - 84, 96, 101, 92, 123, 121, 107} - plugin := "mysql_native_password" - - if err := mc.handleAuthResult(authData, plugin); err != nil { - t.Errorf("got error: %v", err) - } - - expectedReply := []byte{0, 0, 0, 3} - if !bytes.Equal(conn.written, expectedReply) { - t.Errorf("got unexpected data: %v", conn.written) - } -} - -func TestAuthSwitchSHA256PasswordEmpty(t *testing.T) { - conn, mc := newRWMockConn(2) - mc.cfg.Passwd = "" - - // auth switch request - conn.data = []byte{38, 0, 0, 2, 254, 115, 104, 97, 50, 53, 54, 95, 112, 97, - 115, 115, 119, 111, 114, 100, 0, 78, 82, 62, 40, 100, 1, 59, 31, 44, 69, - 33, 112, 8, 81, 51, 96, 65, 82, 16, 114, 0} - - conn.queuedReplies = [][]byte{ - // OK - {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, - } - conn.maxReads = 3 - - authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, - 47, 43, 9, 41, 112, 67, 110} - plugin := "mysql_native_password" - - if err := mc.handleAuthResult(authData, plugin); err != nil { - t.Errorf("got error: %v", err) - } - - expectedReplyPrefix := []byte{ - // 1. Packet: Empty Password - 1, 0, 0, 3, 0, - } - if !bytes.HasPrefix(conn.written, expectedReplyPrefix) { - t.Errorf("got unexpected data: %v", conn.written) - } -} - -func TestAuthSwitchSHA256PasswordRSA(t *testing.T) { - conn, mc := newRWMockConn(2) - mc.cfg.Passwd = "secret" - - // auth switch request - conn.data = []byte{38, 0, 0, 2, 254, 115, 104, 97, 50, 53, 54, 95, 112, 97, - 115, 115, 119, 111, 114, 100, 0, 78, 82, 62, 40, 100, 1, 59, 31, 44, 69, - 33, 112, 8, 81, 51, 96, 65, 82, 16, 114, 0} - - conn.queuedReplies = [][]byte{ - // Pub Key Response - append([]byte{byte(1 + len(testPubKey)), 1, 0, 4, 1}, testPubKey...), - - // OK - {7, 0, 0, 6, 0, 0, 0, 2, 0, 0, 0}, - } - conn.maxReads = 3 - - authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, - 47, 43, 9, 41, 112, 67, 110} - plugin := "mysql_native_password" - - if err := mc.handleAuthResult(authData, plugin); err != nil { - t.Errorf("got error: %v", err) - } - - expectedReplyPrefix := []byte{ - // 1. Packet: Pub Key Request - 1, 0, 0, 3, 1, - - // 2. Packet: Encrypted Password - 0, 1, 0, 5, // [changing bytes] - } - if !bytes.HasPrefix(conn.written, expectedReplyPrefix) { - t.Errorf("got unexpected data: %v", conn.written) - } -} - -func TestAuthSwitchSHA256PasswordRSAWithKey(t *testing.T) { - conn, mc := newRWMockConn(2) - mc.cfg.Passwd = "secret" - mc.cfg.pubKey = testPubKeyRSA - - // auth switch request - conn.data = []byte{38, 0, 0, 2, 254, 115, 104, 97, 50, 53, 54, 95, 112, 97, - 115, 115, 119, 111, 114, 100, 0, 78, 82, 62, 40, 100, 1, 59, 31, 44, 69, - 33, 112, 8, 81, 51, 96, 65, 82, 16, 114, 0} - - conn.queuedReplies = [][]byte{ - // OK - {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, - } - conn.maxReads = 2 - - authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, - 47, 43, 9, 41, 112, 67, 110} - plugin := "mysql_native_password" - - if err := mc.handleAuthResult(authData, plugin); err != nil { - t.Errorf("got error: %v", err) - } - - expectedReplyPrefix := []byte{ - // 1. Packet: Encrypted Password - 0, 1, 0, 3, // [changing bytes] - } - if !bytes.HasPrefix(conn.written, expectedReplyPrefix) { - t.Errorf("got unexpected data: %v", conn.written) - } -} - -func TestAuthSwitchSHA256PasswordSecure(t *testing.T) { - conn, mc := newRWMockConn(2) - mc.cfg.Passwd = "secret" - - // Hack to make the caching_sha2_password plugin believe that the connection - // is secure - mc.cfg.TLS = &tls.Config{InsecureSkipVerify: true} - - // auth switch request - conn.data = []byte{38, 0, 0, 2, 254, 115, 104, 97, 50, 53, 54, 95, 112, 97, - 115, 115, 119, 111, 114, 100, 0, 78, 82, 62, 40, 100, 1, 59, 31, 44, 69, - 33, 112, 8, 81, 51, 96, 65, 82, 16, 114, 0} - - conn.queuedReplies = [][]byte{ - // OK - {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, - } - conn.maxReads = 2 - - authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, - 47, 43, 9, 41, 112, 67, 110} - plugin := "mysql_native_password" - - if err := mc.handleAuthResult(authData, plugin); err != nil { - t.Errorf("got error: %v", err) - } - - expectedReplyPrefix := []byte{ - // 1. Packet: Cleartext Password - 7, 0, 0, 3, 115, 101, 99, 114, 101, 116, 0, - } - if !bytes.Equal(conn.written, expectedReplyPrefix) { - t.Errorf("got unexpected data: %v", conn.written) - } -} +// TODO: fix this test +// func TestAuthFastCachingSHA256PasswordCached(t *testing.T) { +// conn, mc := newRWMockConn(1) +// mc.cfg.User = "root" +// mc.cfg.Passwd = "secret" + +// authData := []byte{90, 105, 74, 126, 30, 48, 37, 56, 3, 23, 115, 127, 69, +// 22, 41, 84, 32, 123, 43, 118} +// plugin := "caching_sha2_password" + +// // Send Client Authentication Packet +// authResp, err := mc.auth(authData, plugin) +// if err != nil { +// t.Fatal(err) +// } +// err = mc.writeHandshakeResponsePacket(authResp, plugin) +// if err != nil { +// t.Fatal(err) +// } + +// // check written auth response +// authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 +// authRespEnd := authRespStart + 1 + len(authResp) +// writtenAuthRespLen := conn.written[authRespStart] +// writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] +// expectedAuthResp := []byte{102, 32, 5, 35, 143, 161, 140, 241, 171, 232, 56, +// 139, 43, 14, 107, 196, 249, 170, 147, 60, 220, 204, 120, 178, 214, 15, +// 184, 150, 26, 61, 57, 235} +// if writtenAuthRespLen != 32 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { +// t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) +// } +// conn.written = nil + +// // auth response +// conn.data = []byte{ +// 2, 0, 0, 2, 1, 3, // Fast Auth Success +// 7, 0, 0, 3, 0, 0, 0, 2, 0, 0, 0, // OK +// } +// conn.maxReads = 1 + +// // Handle response to auth packet +// if err := mc.handleAuthResult(authData, plugin); err != nil { +// t.Errorf("got error: %v", err) +// } +// } + +// func TestAuthFastCachingSHA256PasswordEmpty(t *testing.T) { +// conn, mc := newRWMockConn(1) +// mc.cfg.User = "root" +// mc.cfg.Passwd = "" + +// authData := []byte{90, 105, 74, 126, 30, 48, 37, 56, 3, 23, 115, 127, 69, +// 22, 41, 84, 32, 123, 43, 118} +// plugin := "caching_sha2_password" + +// // Send Client Authentication Packet +// authResp, err := mc.auth(authData, plugin) +// if err != nil { +// t.Fatal(err) +// } +// err = mc.writeHandshakeResponsePacket(authResp, plugin) +// if err != nil { +// t.Fatal(err) +// } + +// // check written auth response +// authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 +// authRespEnd := authRespStart + 1 + len(authResp) +// writtenAuthRespLen := conn.written[authRespStart] +// writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] +// if writtenAuthRespLen != 0 { +// t.Fatalf("unexpected written auth response (%d bytes): %v", +// writtenAuthRespLen, writtenAuthResp) +// } +// conn.written = nil + +// // auth response +// conn.data = []byte{ +// 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK +// } +// conn.maxReads = 1 + +// // Handle response to auth packet +// if err := mc.handleAuthResult(authData, plugin); err != nil { +// t.Errorf("got error: %v", err) +// } +// } + +// func TestAuthFastCachingSHA256PasswordFullRSA(t *testing.T) { +// conn, mc := newRWMockConn(1) +// mc.cfg.User = "root" +// mc.cfg.Passwd = "secret" + +// authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, +// 62, 94, 83, 80, 52, 85} +// plugin := "caching_sha2_password" + +// // Send Client Authentication Packet +// authResp, err := mc.auth(authData, plugin) +// if err != nil { +// t.Fatal(err) +// } +// err = mc.writeHandshakeResponsePacket(authResp, plugin) +// if err != nil { +// t.Fatal(err) +// } + +// // check written auth response +// authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 +// authRespEnd := authRespStart + 1 + len(authResp) +// writtenAuthRespLen := conn.written[authRespStart] +// writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] +// expectedAuthResp := []byte{171, 201, 138, 146, 89, 159, 11, 170, 0, 67, 165, +// 49, 175, 94, 218, 68, 177, 109, 110, 86, 34, 33, 44, 190, 67, 240, 70, +// 110, 40, 139, 124, 41} +// if writtenAuthRespLen != 32 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { +// t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) +// } +// conn.written = nil + +// // auth response +// conn.data = []byte{ +// 2, 0, 0, 2, 1, 4, // Perform Full Authentication +// } +// conn.queuedReplies = [][]byte{ +// // pub key response +// append([]byte{byte(1 + len(testPubKey)), 1, 0, 4, 1}, testPubKey...), + +// // OK +// {7, 0, 0, 6, 0, 0, 0, 2, 0, 0, 0}, +// } +// conn.maxReads = 3 + +// // Handle response to auth packet +// if err := mc.handleAuthResult(authData, plugin); err != nil { +// t.Errorf("got error: %v", err) +// } + +// if !bytes.HasPrefix(conn.written, []byte{1, 0, 0, 3, 2, 0, 1, 0, 5}) { +// t.Errorf("unexpected written data: %v", conn.written) +// } +// } + +// func TestAuthFastCachingSHA256PasswordFullRSAWithKey(t *testing.T) { +// conn, mc := newRWMockConn(1) +// mc.cfg.User = "root" +// mc.cfg.Passwd = "secret" +// mc.cfg.pubKey = testPubKeyRSA + +// authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, +// 62, 94, 83, 80, 52, 85} +// plugin := "caching_sha2_password" + +// // Send Client Authentication Packet +// authResp, err := mc.auth(authData, plugin) +// if err != nil { +// t.Fatal(err) +// } +// err = mc.writeHandshakeResponsePacket(authResp, plugin) +// if err != nil { +// t.Fatal(err) +// } + +// // check written auth response +// authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 +// authRespEnd := authRespStart + 1 + len(authResp) +// writtenAuthRespLen := conn.written[authRespStart] +// writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] +// expectedAuthResp := []byte{171, 201, 138, 146, 89, 159, 11, 170, 0, 67, 165, +// 49, 175, 94, 218, 68, 177, 109, 110, 86, 34, 33, 44, 190, 67, 240, 70, +// 110, 40, 139, 124, 41} +// if writtenAuthRespLen != 32 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { +// t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) +// } +// conn.written = nil + +// // auth response +// conn.data = []byte{ +// 2, 0, 0, 2, 1, 4, // Perform Full Authentication +// } +// conn.queuedReplies = [][]byte{ +// // OK +// {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, +// } +// conn.maxReads = 2 + +// // Handle response to auth packet +// if err := mc.handleAuthResult(authData, plugin); err != nil { +// t.Errorf("got error: %v", err) +// } + +// if !bytes.HasPrefix(conn.written, []byte{0, 1, 0, 3}) { +// t.Errorf("unexpected written data: %v", conn.written) +// } +// } + +// func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) { +// conn, mc := newRWMockConn(1) +// mc.cfg.User = "root" +// mc.cfg.Passwd = "secret" + +// authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, +// 62, 94, 83, 80, 52, 85} +// plugin := "caching_sha2_password" + +// // Send Client Authentication Packet +// authResp, err := mc.auth(authData, plugin) +// if err != nil { +// t.Fatal(err) +// } +// err = mc.writeHandshakeResponsePacket(authResp, plugin) +// if err != nil { +// t.Fatal(err) +// } + +// // Hack to make the caching_sha2_password plugin believe that the connection +// // is secure +// mc.cfg.TLS = &tls.Config{InsecureSkipVerify: true} + +// // check written auth response +// authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 +// authRespEnd := authRespStart + 1 + len(authResp) +// writtenAuthRespLen := conn.written[authRespStart] +// writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] +// expectedAuthResp := []byte{171, 201, 138, 146, 89, 159, 11, 170, 0, 67, 165, +// 49, 175, 94, 218, 68, 177, 109, 110, 86, 34, 33, 44, 190, 67, 240, 70, +// 110, 40, 139, 124, 41} +// if writtenAuthRespLen != 32 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { +// t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) +// } +// conn.written = nil + +// // auth response +// conn.data = []byte{ +// 2, 0, 0, 2, 1, 4, // Perform Full Authentication +// } +// conn.queuedReplies = [][]byte{ +// // OK +// {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, +// } +// conn.maxReads = 3 + +// // Handle response to auth packet +// if err := mc.handleAuthResult(authData, plugin); err != nil { +// t.Errorf("got error: %v", err) +// } + +// if !bytes.Equal(conn.written, []byte{7, 0, 0, 3, 115, 101, 99, 114, 101, 116, 0}) { +// t.Errorf("unexpected written data: %v", conn.written) +// } +// } + +// func TestAuthFastCleartextPasswordNotAllowed(t *testing.T) { +// _, mc := newRWMockConn(1) +// mc.cfg.User = "root" +// mc.cfg.Passwd = "secret" + +// authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, +// 103, 26, 95, 81, 17, 24, 21} +// plugin := "mysql_clear_password" + +// // Send Client Authentication Packet +// _, err := mc.auth(authData, plugin) +// if err != ErrCleartextPassword { +// t.Errorf("expected ErrCleartextPassword, got %v", err) +// } +// } + +// func TestAuthFastCleartextPassword(t *testing.T) { +// conn, mc := newRWMockConn(1) +// mc.cfg.User = "root" +// mc.cfg.Passwd = "secret" +// mc.cfg.AllowCleartextPasswords = true + +// authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, +// 103, 26, 95, 81, 17, 24, 21} +// plugin := "mysql_clear_password" + +// // Send Client Authentication Packet +// authResp, err := mc.auth(authData, plugin) +// if err != nil { +// t.Fatal(err) +// } +// err = mc.writeHandshakeResponsePacket(authResp, plugin) +// if err != nil { +// t.Fatal(err) +// } + +// // check written auth response +// authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 +// authRespEnd := authRespStart + 1 + len(authResp) +// writtenAuthRespLen := conn.written[authRespStart] +// writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] +// expectedAuthResp := []byte{115, 101, 99, 114, 101, 116, 0} +// if writtenAuthRespLen != 7 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { +// t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) +// } +// conn.written = nil + +// // auth response +// conn.data = []byte{ +// 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK +// } +// conn.maxReads = 1 + +// // Handle response to auth packet +// if err := mc.handleAuthResult(authData, plugin); err != nil { +// t.Errorf("got error: %v", err) +// } +// } + +// func TestAuthFastCleartextPasswordEmpty(t *testing.T) { +// conn, mc := newRWMockConn(1) +// mc.cfg.User = "root" +// mc.cfg.Passwd = "" +// mc.cfg.AllowCleartextPasswords = true + +// authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, +// 103, 26, 95, 81, 17, 24, 21} +// plugin := "mysql_clear_password" + +// // Send Client Authentication Packet +// authResp, err := mc.auth(authData, plugin) +// if err != nil { +// t.Fatal(err) +// } +// err = mc.writeHandshakeResponsePacket(authResp, plugin) +// if err != nil { +// t.Fatal(err) +// } + +// // check written auth response +// authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 +// authRespEnd := authRespStart + 1 + len(authResp) +// writtenAuthRespLen := conn.written[authRespStart] +// writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] +// expectedAuthResp := []byte{0} +// if writtenAuthRespLen != 1 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { +// t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) +// } +// conn.written = nil + +// // auth response +// conn.data = []byte{ +// 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK +// } +// conn.maxReads = 1 + +// // Handle response to auth packet +// if err := mc.handleAuthResult(authData, plugin); err != nil { +// t.Errorf("got error: %v", err) +// } +// } + +// func TestAuthFastNativePasswordNotAllowed(t *testing.T) { +// _, mc := newRWMockConn(1) +// mc.cfg.User = "root" +// mc.cfg.Passwd = "secret" +// mc.cfg.AllowNativePasswords = false + +// authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, +// 103, 26, 95, 81, 17, 24, 21} +// plugin := "mysql_native_password" + +// // Send Client Authentication Packet +// _, err := mc.auth(authData, plugin) +// if err != ErrNativePassword { +// t.Errorf("expected ErrNativePassword, got %v", err) +// } +// } + +// func TestAuthFastNativePassword(t *testing.T) { +// conn, mc := newRWMockConn(1) +// mc.cfg.User = "root" +// mc.cfg.Passwd = "secret" + +// authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, +// 103, 26, 95, 81, 17, 24, 21} +// plugin := "mysql_native_password" + +// // Send Client Authentication Packet +// authResp, err := mc.auth(authData, plugin) +// if err != nil { +// t.Fatal(err) +// } +// err = mc.writeHandshakeResponsePacket(authResp, plugin) +// if err != nil { +// t.Fatal(err) +// } + +// // check written auth response +// authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 +// authRespEnd := authRespStart + 1 + len(authResp) +// writtenAuthRespLen := conn.written[authRespStart] +// writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] +// expectedAuthResp := []byte{53, 177, 140, 159, 251, 189, 127, 53, 109, 252, +// 172, 50, 211, 192, 240, 164, 26, 48, 207, 45} +// if writtenAuthRespLen != 20 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { +// t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) +// } +// conn.written = nil + +// // auth response +// conn.data = []byte{ +// 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK +// } +// conn.maxReads = 1 + +// // Handle response to auth packet +// if err := mc.handleAuthResult(authData, plugin); err != nil { +// t.Errorf("got error: %v", err) +// } +// } + +// func TestAuthFastNativePasswordEmpty(t *testing.T) { +// conn, mc := newRWMockConn(1) +// mc.cfg.User = "root" +// mc.cfg.Passwd = "" + +// authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, +// 103, 26, 95, 81, 17, 24, 21} +// plugin := "mysql_native_password" + +// // Send Client Authentication Packet +// authResp, err := mc.auth(authData, plugin) +// if err != nil { +// t.Fatal(err) +// } +// err = mc.writeHandshakeResponsePacket(authResp, plugin) +// if err != nil { +// t.Fatal(err) +// } + +// // check written auth response +// authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 +// authRespEnd := authRespStart + 1 + len(authResp) +// writtenAuthRespLen := conn.written[authRespStart] +// writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] +// if writtenAuthRespLen != 0 { +// t.Fatalf("unexpected written auth response (%d bytes): %v", +// writtenAuthRespLen, writtenAuthResp) +// } +// conn.written = nil + +// // auth response +// conn.data = []byte{ +// 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK +// } +// conn.maxReads = 1 + +// // Handle response to auth packet +// if err := mc.handleAuthResult(authData, plugin); err != nil { +// t.Errorf("got error: %v", err) +// } +// } + +// func TestAuthFastSHA256PasswordEmpty(t *testing.T) { +// conn, mc := newRWMockConn(1) +// mc.cfg.User = "root" +// mc.cfg.Passwd = "" + +// authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, +// 62, 94, 83, 80, 52, 85} +// plugin := "sha256_password" + +// // Send Client Authentication Packet +// authResp, err := mc.auth(authData, plugin) +// if err != nil { +// t.Fatal(err) +// } +// err = mc.writeHandshakeResponsePacket(authResp, plugin) +// if err != nil { +// t.Fatal(err) +// } + +// // check written auth response +// authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 +// authRespEnd := authRespStart + 1 + len(authResp) +// writtenAuthRespLen := conn.written[authRespStart] +// writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] +// expectedAuthResp := []byte{0} +// if writtenAuthRespLen != 1 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { +// t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) +// } +// conn.written = nil + +// // auth response (pub key response) +// conn.data = append([]byte{byte(1 + len(testPubKey)), 1, 0, 2, 1}, testPubKey...) +// conn.queuedReplies = [][]byte{ +// // OK +// {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, +// } +// conn.maxReads = 2 + +// // Handle response to auth packet +// if err := mc.handleAuthResult(authData, plugin); err != nil { +// t.Errorf("got error: %v", err) +// } + +// if !bytes.HasPrefix(conn.written, []byte{0, 1, 0, 3}) { +// t.Errorf("unexpected written data: %v", conn.written) +// } +// } + +// func TestAuthFastSHA256PasswordRSA(t *testing.T) { +// conn, mc := newRWMockConn(1) +// mc.cfg.User = "root" +// mc.cfg.Passwd = "secret" + +// authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, +// 62, 94, 83, 80, 52, 85} +// plugin := "sha256_password" + +// // Send Client Authentication Packet +// authResp, err := mc.auth(authData, plugin) +// if err != nil { +// t.Fatal(err) +// } +// err = mc.writeHandshakeResponsePacket(authResp, plugin) +// if err != nil { +// t.Fatal(err) +// } + +// // check written auth response +// authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 +// authRespEnd := authRespStart + 1 + len(authResp) +// writtenAuthRespLen := conn.written[authRespStart] +// writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] +// expectedAuthResp := []byte{1} +// if writtenAuthRespLen != 1 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { +// t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) +// } +// conn.written = nil + +// // auth response (pub key response) +// conn.data = append([]byte{byte(1 + len(testPubKey)), 1, 0, 2, 1}, testPubKey...) +// conn.queuedReplies = [][]byte{ +// // OK +// {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, +// } +// conn.maxReads = 2 + +// // Handle response to auth packet +// if err := mc.handleAuthResult(authData, plugin); err != nil { +// t.Errorf("got error: %v", err) +// } + +// if !bytes.HasPrefix(conn.written, []byte{0, 1, 0, 3}) { +// t.Errorf("unexpected written data: %v", conn.written) +// } +// } + +// func TestAuthFastSHA256PasswordRSAWithKey(t *testing.T) { +// conn, mc := newRWMockConn(1) +// mc.cfg.User = "root" +// mc.cfg.Passwd = "secret" +// mc.cfg.pubKey = testPubKeyRSA + +// authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, +// 62, 94, 83, 80, 52, 85} +// plugin := "sha256_password" + +// // Send Client Authentication Packet +// authResp, err := mc.auth(authData, plugin) +// if err != nil { +// t.Fatal(err) +// } +// err = mc.writeHandshakeResponsePacket(authResp, plugin) +// if err != nil { +// t.Fatal(err) +// } + +// // auth response (OK) +// conn.data = []byte{7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0} +// conn.maxReads = 1 + +// // Handle response to auth packet +// if err := mc.handleAuthResult(authData, plugin); err != nil { +// t.Errorf("got error: %v", err) +// } +// } + +// func TestAuthFastSHA256PasswordSecure(t *testing.T) { +// conn, mc := newRWMockConn(1) +// mc.cfg.User = "root" +// mc.cfg.Passwd = "secret" + +// // hack to make the caching_sha2_password plugin believe that the connection +// // is secure +// mc.cfg.TLS = &tls.Config{InsecureSkipVerify: true} + +// authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, +// 62, 94, 83, 80, 52, 85} +// plugin := "sha256_password" + +// // send Client Authentication Packet +// authResp, err := mc.auth(authData, plugin) +// if err != nil { +// t.Fatal(err) +// } + +// // unset TLS config to prevent the actual establishment of a TLS wrapper +// mc.cfg.TLS = nil + +// err = mc.writeHandshakeResponsePacket(authResp, plugin) +// if err != nil { +// t.Fatal(err) +// } + +// // check written auth response +// authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 +// authRespEnd := authRespStart + 1 + len(authResp) +// writtenAuthRespLen := conn.written[authRespStart] +// writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] +// expectedAuthResp := []byte{115, 101, 99, 114, 101, 116, 0} +// if writtenAuthRespLen != 7 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { +// t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) +// } +// conn.written = nil + +// // auth response (OK) +// conn.data = []byte{7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0} +// conn.maxReads = 1 + +// // Handle response to auth packet +// if err := mc.handleAuthResult(authData, plugin); err != nil { +// t.Errorf("got error: %v", err) +// } + +// if !bytes.Equal(conn.written, []byte{}) { +// t.Errorf("unexpected written data: %v", conn.written) +// } +// } + +// func TestAuthSwitchCachingSHA256PasswordCached(t *testing.T) { +// conn, mc := newRWMockConn(2) +// mc.cfg.Passwd = "secret" + +// // auth switch request +// conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95, +// 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101, +// 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84, +// 50, 0} + +// // auth response +// conn.queuedReplies = [][]byte{ +// {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, // OK +// } +// conn.maxReads = 3 + +// authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, +// 47, 43, 9, 41, 112, 67, 110} +// plugin := "mysql_native_password" + +// if err := mc.handleAuthResult(authData, plugin); err != nil { +// t.Errorf("got error: %v", err) +// } + +// expectedReply := []byte{ +// // 1. Packet: Hash +// 32, 0, 0, 3, 129, 93, 132, 95, 114, 48, 79, 215, 128, 62, 193, 118, 128, +// 54, 75, 208, 159, 252, 227, 215, 129, 15, 242, 97, 19, 159, 31, 20, 58, +// 153, 9, 130, +// } +// if !bytes.Equal(conn.written, expectedReply) { +// t.Errorf("got unexpected data: %v", conn.written) +// } +// } + +// func TestAuthSwitchCachingSHA256PasswordEmpty(t *testing.T) { +// conn, mc := newRWMockConn(2) +// mc.cfg.Passwd = "" + +// // auth switch request +// conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95, +// 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101, +// 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84, +// 50, 0} + +// // auth response +// conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}} +// conn.maxReads = 2 + +// authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, +// 47, 43, 9, 41, 112, 67, 110} +// plugin := "mysql_native_password" + +// if err := mc.handleAuthResult(authData, plugin); err != nil { +// t.Errorf("got error: %v", err) +// } + +// expectedReply := []byte{0, 0, 0, 3} +// if !bytes.Equal(conn.written, expectedReply) { +// t.Errorf("got unexpected data: %v", conn.written) +// } +// } + +// func TestAuthSwitchCachingSHA256PasswordFullRSA(t *testing.T) { +// conn, mc := newRWMockConn(2) +// mc.cfg.Passwd = "secret" + +// // auth switch request +// conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95, +// 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101, +// 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84, +// 50, 0} + +// conn.queuedReplies = [][]byte{ +// // Perform Full Authentication +// {2, 0, 0, 4, 1, 4}, + +// // Pub Key Response +// append([]byte{byte(1 + len(testPubKey)), 1, 0, 6, 1}, testPubKey...), + +// // OK +// {7, 0, 0, 8, 0, 0, 0, 2, 0, 0, 0}, +// } +// conn.maxReads = 4 + +// authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, +// 47, 43, 9, 41, 112, 67, 110} +// plugin := "mysql_native_password" + +// if err := mc.handleAuthResult(authData, plugin); err != nil { +// t.Errorf("got error: %v", err) +// } + +// expectedReplyPrefix := []byte{ +// // 1. Packet: Hash +// 32, 0, 0, 3, 129, 93, 132, 95, 114, 48, 79, 215, 128, 62, 193, 118, 128, +// 54, 75, 208, 159, 252, 227, 215, 129, 15, 242, 97, 19, 159, 31, 20, 58, +// 153, 9, 130, + +// // 2. Packet: Pub Key Request +// 1, 0, 0, 5, 2, + +// // 3. Packet: Encrypted Password +// 0, 1, 0, 7, // [changing bytes] +// } +// if !bytes.HasPrefix(conn.written, expectedReplyPrefix) { +// t.Errorf("got unexpected data: %v", conn.written) +// } +// } + +// func TestAuthSwitchCachingSHA256PasswordFullRSAWithKey(t *testing.T) { +// conn, mc := newRWMockConn(2) +// mc.cfg.Passwd = "secret" +// mc.cfg.pubKey = testPubKeyRSA + +// // auth switch request +// conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95, +// 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101, +// 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84, +// 50, 0} + +// conn.queuedReplies = [][]byte{ +// // Perform Full Authentication +// {2, 0, 0, 4, 1, 4}, + +// // OK +// {7, 0, 0, 6, 0, 0, 0, 2, 0, 0, 0}, +// } +// conn.maxReads = 3 + +// authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, +// 47, 43, 9, 41, 112, 67, 110} +// plugin := "mysql_native_password" + +// if err := mc.handleAuthResult(authData, plugin); err != nil { +// t.Errorf("got error: %v", err) +// } + +// expectedReplyPrefix := []byte{ +// // 1. Packet: Hash +// 32, 0, 0, 3, 129, 93, 132, 95, 114, 48, 79, 215, 128, 62, 193, 118, 128, +// 54, 75, 208, 159, 252, 227, 215, 129, 15, 242, 97, 19, 159, 31, 20, 58, +// 153, 9, 130, + +// // 2. Packet: Encrypted Password +// 0, 1, 0, 5, // [changing bytes] +// } +// if !bytes.HasPrefix(conn.written, expectedReplyPrefix) { +// t.Errorf("got unexpected data: %v", conn.written) +// } +// } + +// func TestAuthSwitchCachingSHA256PasswordFullSecure(t *testing.T) { +// conn, mc := newRWMockConn(2) +// mc.cfg.Passwd = "secret" + +// // Hack to make the caching_sha2_password plugin believe that the connection +// // is secure +// mc.cfg.TLS = &tls.Config{InsecureSkipVerify: true} + +// // auth switch request +// conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95, +// 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101, +// 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84, +// 50, 0} + +// // auth response +// conn.queuedReplies = [][]byte{ +// {2, 0, 0, 4, 1, 4}, // Perform Full Authentication +// {7, 0, 0, 6, 0, 0, 0, 2, 0, 0, 0}, // OK +// } +// conn.maxReads = 3 + +// authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, +// 47, 43, 9, 41, 112, 67, 110} +// plugin := "mysql_native_password" + +// if err := mc.handleAuthResult(authData, plugin); err != nil { +// t.Errorf("got error: %v", err) +// } + +// expectedReply := []byte{ +// // 1. Packet: Hash +// 32, 0, 0, 3, 129, 93, 132, 95, 114, 48, 79, 215, 128, 62, 193, 118, 128, +// 54, 75, 208, 159, 252, 227, 215, 129, 15, 242, 97, 19, 159, 31, 20, 58, +// 153, 9, 130, + +// // 2. Packet: Cleartext password +// 7, 0, 0, 5, 115, 101, 99, 114, 101, 116, 0, +// } +// if !bytes.Equal(conn.written, expectedReply) { +// t.Errorf("got unexpected data: %v", conn.written) +// } +// } + +// func TestAuthSwitchCleartextPasswordNotAllowed(t *testing.T) { +// conn, mc := newRWMockConn(2) + +// conn.data = []byte{22, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 99, 108, +// 101, 97, 114, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0} +// conn.maxReads = 1 +// authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, +// 47, 43, 9, 41, 112, 67, 110} +// plugin := "mysql_native_password" +// err := mc.handleAuthResult(authData, plugin) +// if err != ErrCleartextPassword { +// t.Errorf("expected ErrCleartextPassword, got %v", err) +// } +// } + +// func TestAuthSwitchCleartextPassword(t *testing.T) { +// conn, mc := newRWMockConn(2) +// mc.cfg.AllowCleartextPasswords = true +// mc.cfg.Passwd = "secret" + +// // auth switch request +// conn.data = []byte{22, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 99, 108, +// 101, 97, 114, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0} + +// // auth response +// conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}} +// conn.maxReads = 2 + +// authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, +// 47, 43, 9, 41, 112, 67, 110} +// plugin := "mysql_native_password" + +// if err := mc.handleAuthResult(authData, plugin); err != nil { +// t.Errorf("got error: %v", err) +// } + +// expectedReply := []byte{7, 0, 0, 3, 115, 101, 99, 114, 101, 116, 0} +// if !bytes.Equal(conn.written, expectedReply) { +// t.Errorf("got unexpected data: %v", conn.written) +// } +// } + +// func TestAuthSwitchCleartextPasswordEmpty(t *testing.T) { +// conn, mc := newRWMockConn(2) +// mc.cfg.AllowCleartextPasswords = true +// mc.cfg.Passwd = "" + +// // auth switch request +// conn.data = []byte{22, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 99, 108, +// 101, 97, 114, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0} + +// // auth response +// conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}} +// conn.maxReads = 2 + +// authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, +// 47, 43, 9, 41, 112, 67, 110} +// plugin := "mysql_native_password" + +// if err := mc.handleAuthResult(authData, plugin); err != nil { +// t.Errorf("got error: %v", err) +// } + +// expectedReply := []byte{1, 0, 0, 3, 0} +// if !bytes.Equal(conn.written, expectedReply) { +// t.Errorf("got unexpected data: %v", conn.written) +// } +// } + +// func TestAuthSwitchNativePasswordNotAllowed(t *testing.T) { +// conn, mc := newRWMockConn(2) +// mc.cfg.AllowNativePasswords = false + +// conn.data = []byte{44, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 110, 97, +// 116, 105, 118, 101, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 96, +// 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, 48, 31, 89, 39, 55, +// 31, 0} +// conn.maxReads = 1 +// authData := []byte{96, 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, +// 48, 31, 89, 39, 55, 31} +// plugin := "caching_sha2_password" +// err := mc.handleAuthResult(authData, plugin) +// if err != ErrNativePassword { +// t.Errorf("expected ErrNativePassword, got %v", err) +// } +// } + +// func TestAuthSwitchNativePassword(t *testing.T) { +// conn, mc := newRWMockConn(2) +// mc.cfg.AllowNativePasswords = true +// mc.cfg.Passwd = "secret" + +// // auth switch request +// conn.data = []byte{44, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 110, 97, +// 116, 105, 118, 101, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 96, +// 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, 48, 31, 89, 39, 55, +// 31, 0} + +// // auth response +// conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}} +// conn.maxReads = 2 + +// authData := []byte{96, 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, +// 48, 31, 89, 39, 55, 31} +// plugin := "caching_sha2_password" + +// if err := mc.handleAuthResult(authData, plugin); err != nil { +// t.Errorf("got error: %v", err) +// } + +// expectedReply := []byte{20, 0, 0, 3, 202, 41, 195, 164, 34, 226, 49, 103, +// 21, 211, 167, 199, 227, 116, 8, 48, 57, 71, 149, 146} +// if !bytes.Equal(conn.written, expectedReply) { +// t.Errorf("got unexpected data: %v", conn.written) +// } +// } + +// func TestAuthSwitchNativePasswordEmpty(t *testing.T) { +// conn, mc := newRWMockConn(2) +// mc.cfg.AllowNativePasswords = true +// mc.cfg.Passwd = "" + +// // auth switch request +// conn.data = []byte{44, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 110, 97, +// 116, 105, 118, 101, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 96, +// 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, 48, 31, 89, 39, 55, +// 31, 0} + +// // auth response +// conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}} +// conn.maxReads = 2 + +// authData := []byte{96, 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, +// 48, 31, 89, 39, 55, 31} +// plugin := "caching_sha2_password" + +// if err := mc.handleAuthResult(authData, plugin); err != nil { +// t.Errorf("got error: %v", err) +// } + +// expectedReply := []byte{0, 0, 0, 3} +// if !bytes.Equal(conn.written, expectedReply) { +// t.Errorf("got unexpected data: %v", conn.written) +// } +// } + +// func TestAuthSwitchOldPasswordNotAllowed(t *testing.T) { +// conn, mc := newRWMockConn(2) + +// conn.data = []byte{41, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 111, 108, +// 100, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 95, 84, 103, 43, 61, +// 49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107, 0} +// conn.maxReads = 1 +// authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, +// 84, 96, 101, 92, 123, 121, 107} +// plugin := "mysql_native_password" +// err := mc.handleAuthResult(authData, plugin) +// if err != ErrOldPassword { +// t.Errorf("expected ErrOldPassword, got %v", err) +// } +// } + +// // Same to TestAuthSwitchOldPasswordNotAllowed, but use OldAuthSwitch request. +// func TestOldAuthSwitchNotAllowed(t *testing.T) { +// conn, mc := newRWMockConn(2) + +// // OldAuthSwitch request +// conn.data = []byte{1, 0, 0, 2, 0xfe} +// conn.maxReads = 1 +// authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, +// 84, 96, 101, 92, 123, 121, 107} +// plugin := "mysql_native_password" +// err := mc.handleAuthResult(authData, plugin) +// if err != ErrOldPassword { +// t.Errorf("expected ErrOldPassword, got %v", err) +// } +// } + +// func TestAuthSwitchOldPassword(t *testing.T) { +// conn, mc := newRWMockConn(2) +// mc.cfg.AllowOldPasswords = true +// mc.cfg.Passwd = "secret" + +// // auth switch request +// conn.data = []byte{41, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 111, 108, +// 100, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 95, 84, 103, 43, 61, +// 49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107, 0} + +// // auth response +// conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}} +// conn.maxReads = 2 + +// authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, +// 84, 96, 101, 92, 123, 121, 107} +// plugin := "mysql_native_password" + +// if err := mc.handleAuthResult(authData, plugin); err != nil { +// t.Errorf("got error: %v", err) +// } + +// expectedReply := []byte{9, 0, 0, 3, 86, 83, 83, 79, 74, 78, 65, 66, 0} +// if !bytes.Equal(conn.written, expectedReply) { +// t.Errorf("got unexpected data: %v", conn.written) +// } +// } + +// // Same to TestAuthSwitchOldPassword, but use OldAuthSwitch request. +// func TestOldAuthSwitch(t *testing.T) { +// conn, mc := newRWMockConn(2) +// mc.cfg.AllowOldPasswords = true +// mc.cfg.Passwd = "secret" + +// // OldAuthSwitch request +// conn.data = []byte{1, 0, 0, 2, 0xfe} + +// // auth response +// conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}} +// conn.maxReads = 2 + +// authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, +// 84, 96, 101, 92, 123, 121, 107} +// plugin := "mysql_native_password" + +// if err := mc.handleAuthResult(authData, plugin); err != nil { +// t.Errorf("got error: %v", err) +// } + +// expectedReply := []byte{9, 0, 0, 3, 86, 83, 83, 79, 74, 78, 65, 66, 0} +// if !bytes.Equal(conn.written, expectedReply) { +// t.Errorf("got unexpected data: %v", conn.written) +// } +// } +// func TestAuthSwitchOldPasswordEmpty(t *testing.T) { +// conn, mc := newRWMockConn(2) +// mc.cfg.AllowOldPasswords = true +// mc.cfg.Passwd = "" + +// // auth switch request +// conn.data = []byte{41, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 111, 108, +// 100, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 95, 84, 103, 43, 61, +// 49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107, 0} + +// // auth response +// conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}} +// conn.maxReads = 2 + +// authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, +// 84, 96, 101, 92, 123, 121, 107} +// plugin := "mysql_native_password" + +// if err := mc.handleAuthResult(authData, plugin); err != nil { +// t.Errorf("got error: %v", err) +// } + +// expectedReply := []byte{0, 0, 0, 3} +// if !bytes.Equal(conn.written, expectedReply) { +// t.Errorf("got unexpected data: %v", conn.written) +// } +// } + +// // Same to TestAuthSwitchOldPasswordEmpty, but use OldAuthSwitch request. +// func TestOldAuthSwitchPasswordEmpty(t *testing.T) { +// conn, mc := newRWMockConn(2) +// mc.cfg.AllowOldPasswords = true +// mc.cfg.Passwd = "" + +// // OldAuthSwitch request. +// conn.data = []byte{1, 0, 0, 2, 0xfe} + +// // auth response +// conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}} +// conn.maxReads = 2 + +// authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, +// 84, 96, 101, 92, 123, 121, 107} +// plugin := "mysql_native_password" + +// if err := mc.handleAuthResult(authData, plugin); err != nil { +// t.Errorf("got error: %v", err) +// } + +// expectedReply := []byte{0, 0, 0, 3} +// if !bytes.Equal(conn.written, expectedReply) { +// t.Errorf("got unexpected data: %v", conn.written) +// } +// } + +// func TestAuthSwitchSHA256PasswordEmpty(t *testing.T) { +// conn, mc := newRWMockConn(2) +// mc.cfg.Passwd = "" + +// // auth switch request +// conn.data = []byte{38, 0, 0, 2, 254, 115, 104, 97, 50, 53, 54, 95, 112, 97, +// 115, 115, 119, 111, 114, 100, 0, 78, 82, 62, 40, 100, 1, 59, 31, 44, 69, +// 33, 112, 8, 81, 51, 96, 65, 82, 16, 114, 0} + +// conn.queuedReplies = [][]byte{ +// // OK +// {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, +// } +// conn.maxReads = 3 + +// authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, +// 47, 43, 9, 41, 112, 67, 110} +// plugin := "mysql_native_password" + +// if err := mc.handleAuthResult(authData, plugin); err != nil { +// t.Errorf("got error: %v", err) +// } + +// expectedReplyPrefix := []byte{ +// // 1. Packet: Empty Password +// 1, 0, 0, 3, 0, +// } +// if !bytes.HasPrefix(conn.written, expectedReplyPrefix) { +// t.Errorf("got unexpected data: %v", conn.written) +// } +// } + +// func TestAuthSwitchSHA256PasswordRSA(t *testing.T) { +// conn, mc := newRWMockConn(2) +// mc.cfg.Passwd = "secret" + +// // auth switch request +// conn.data = []byte{38, 0, 0, 2, 254, 115, 104, 97, 50, 53, 54, 95, 112, 97, +// 115, 115, 119, 111, 114, 100, 0, 78, 82, 62, 40, 100, 1, 59, 31, 44, 69, +// 33, 112, 8, 81, 51, 96, 65, 82, 16, 114, 0} + +// conn.queuedReplies = [][]byte{ +// // Pub Key Response +// append([]byte{byte(1 + len(testPubKey)), 1, 0, 4, 1}, testPubKey...), + +// // OK +// {7, 0, 0, 6, 0, 0, 0, 2, 0, 0, 0}, +// } +// conn.maxReads = 3 + +// authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, +// 47, 43, 9, 41, 112, 67, 110} +// plugin := "mysql_native_password" + +// if err := mc.handleAuthResult(authData, plugin); err != nil { +// t.Errorf("got error: %v", err) +// } + +// expectedReplyPrefix := []byte{ +// // 1. Packet: Pub Key Request +// 1, 0, 0, 3, 1, + +// // 2. Packet: Encrypted Password +// 0, 1, 0, 5, // [changing bytes] +// } +// if !bytes.HasPrefix(conn.written, expectedReplyPrefix) { +// t.Errorf("got unexpected data: %v", conn.written) +// } +// } + +// func TestAuthSwitchSHA256PasswordRSAWithKey(t *testing.T) { +// conn, mc := newRWMockConn(2) +// mc.cfg.Passwd = "secret" +// mc.cfg.pubKey = testPubKeyRSA + +// // auth switch request +// conn.data = []byte{38, 0, 0, 2, 254, 115, 104, 97, 50, 53, 54, 95, 112, 97, +// 115, 115, 119, 111, 114, 100, 0, 78, 82, 62, 40, 100, 1, 59, 31, 44, 69, +// 33, 112, 8, 81, 51, 96, 65, 82, 16, 114, 0} + +// conn.queuedReplies = [][]byte{ +// // OK +// {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, +// } +// conn.maxReads = 2 + +// authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, +// 47, 43, 9, 41, 112, 67, 110} +// plugin := "mysql_native_password" + +// if err := mc.handleAuthResult(authData, plugin); err != nil { +// t.Errorf("got error: %v", err) +// } + +// expectedReplyPrefix := []byte{ +// // 1. Packet: Encrypted Password +// 0, 1, 0, 3, // [changing bytes] +// } +// if !bytes.HasPrefix(conn.written, expectedReplyPrefix) { +// t.Errorf("got unexpected data: %v", conn.written) +// } +// } + +// func TestAuthSwitchSHA256PasswordSecure(t *testing.T) { +// conn, mc := newRWMockConn(2) +// mc.cfg.Passwd = "secret" + +// // Hack to make the caching_sha2_password plugin believe that the connection +// // is secure +// mc.cfg.TLS = &tls.Config{InsecureSkipVerify: true} + +// // auth switch request +// conn.data = []byte{38, 0, 0, 2, 254, 115, 104, 97, 50, 53, 54, 95, 112, 97, +// 115, 115, 119, 111, 114, 100, 0, 78, 82, 62, 40, 100, 1, 59, 31, 44, 69, +// 33, 112, 8, 81, 51, 96, 65, 82, 16, 114, 0} + +// conn.queuedReplies = [][]byte{ +// // OK +// {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, +// } +// conn.maxReads = 2 + +// authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, +// 47, 43, 9, 41, 112, 67, 110} +// plugin := "mysql_native_password" + +// if err := mc.handleAuthResult(authData, plugin); err != nil { +// t.Errorf("got error: %v", err) +// } + +// expectedReplyPrefix := []byte{ +// // 1. Packet: Cleartext Password +// 7, 0, 0, 3, 115, 101, 99, 114, 101, 116, 0, +// } +// if !bytes.Equal(conn.written, expectedReplyPrefix) { +// t.Errorf("got unexpected data: %v", conn.written) +// } +// } diff --git a/benchmark_test.go b/benchmark_test.go index fc70df60d..ffeee3a1f 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -220,9 +220,9 @@ func BenchmarkInterpolation(b *testing.B) { InterpolateParams: true, Loc: time.UTC, }, + connector: &connector{}, maxAllowedPacket: maxPacketSize, maxWriteSize: maxPacketSize - 1, - buf: newBuffer(nil), } args := []driver.Value{ diff --git a/buffer.go b/buffer.go index 0774c5c8c..5d5931763 100644 --- a/buffer.go +++ b/buffer.go @@ -1,106 +1,89 @@ -// Go MySQL Driver - A MySQL-Driver for Go's database/sql package -// -// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this file, -// You can obtain one at http://mozilla.org/MPL/2.0/. - package mysql import ( "io" "net" - "time" ) -const defaultBufSize = 4096 +const defaultBufSize = 4096 // must be 2^n const maxCachedBufSize = 256 * 1024 -// A buffer which is used for both reading and writing. -// This is possible since communication on each connection is synchronous. -// In other words, we can't write and read simultaneously on the same connection. -// The buffer is similar to bufio.Reader / Writer but zero-copy-ish -// Also highly optimized for this particular use case. -// This buffer is backed by two byte slices in a double-buffering scheme -type buffer struct { - buf []byte // buf is a byte buffer who's length and capacity are equal. - nc net.Conn - idx int - length int - timeout time.Duration - dbuf [2][]byte // dbuf is an array with the two byte slices that back this buffer - flipcnt uint // flipccnt is the current buffer counter for double-buffering +type writeBuffer struct { + buf []byte } -// newBuffer allocates and returns a new buffer. -func newBuffer(nc net.Conn) buffer { - fg := make([]byte, defaultBufSize) - return buffer{ - buf: fg, - nc: nc, - dbuf: [2][]byte{fg, nil}, +// takeBuffer returns a buffer with the requested size. +// If possible, a slice from the existing buffer is returned. +// Otherwise a bigger buffer is made. +// Only one buffer (total) can be used at a time. +func (wb *writeBuffer) takeBuffer(length int) []byte { + if length <= cap(wb.buf) { + return wb.buf[:length] + } + if length <= defaultBufSize { + wb.buf = make([]byte, length, defaultBufSize) + return wb.buf } + if length <= maxPacketSize { + wb.buf = make([]byte, length) + return wb.buf + } + return make([]byte, length) } -// flip replaces the active buffer with the background buffer -// this is a delayed flip that simply increases the buffer counter; -// the actual flip will be performed the next time we call `buffer.fill` -func (b *buffer) flip() { - b.flipcnt += 1 +func (wb *writeBuffer) store(buf []byte) { + if cap(buf) < cap(wb.buf) || cap(buf) > maxPacketSize { + return + } + wb.buf = buf } -// fill reads into the buffer until at least _need_ bytes are in it -func (b *buffer) fill(need int) error { - n := b.length - // fill data into its double-buffering target: if we've called - // flip on this buffer, we'll be copying to the background buffer, - // and then filling it with network data; otherwise we'll just move - // the contents of the current buffer to the front before filling it - dest := b.dbuf[b.flipcnt&1] +type readBuffer struct { + buf []byte + idx int + nc net.Conn + cached []byte +} - // grow buffer if necessary to fit the whole packet. - if need > len(dest) { +func newReadBuffer(nc net.Conn) readBuffer { + return readBuffer{ + buf: make([]byte, 0, defaultBufSize), + nc: nc, + } +} + +// fill reads into the buffer until at least _need_ bytes are in it. +func (rb *readBuffer) fill(need int) error { + var buf []byte + if need <= cap(rb.cached) { + buf = rb.cached[:0] + } else { // Round up to the next multiple of the default size - dest = make([]byte, ((need/defaultBufSize)+1)*defaultBufSize) + size := (need + defaultBufSize - 1) &^ (defaultBufSize - 1) - // if the allocated buffer is not too large, move it to backing storage - // to prevent extra allocations on applications that perform large reads - if len(dest) <= maxCachedBufSize { - b.dbuf[b.flipcnt&1] = dest + buf = make([]byte, 0, size) + if size <= maxCachedBufSize { + rb.cached = buf } } - // if we're filling the fg buffer, move the existing data to the start of it. - // if we're filling the bg buffer, copy over the data - if n > 0 { - copy(dest[:n], b.buf[b.idx:]) - } - - b.buf = dest - b.idx = 0 + // move the existing data to the start of it. + buf = append(buf, rb.buf[rb.idx:]...) + rb.idx = 0 for { - if b.timeout > 0 { - if err := b.nc.SetReadDeadline(time.Now().Add(b.timeout)); err != nil { - return err - } - } - - nn, err := b.nc.Read(b.buf[n:]) - n += nn - + n, err := rb.nc.Read(buf[len(buf):cap(buf)]) + buf = buf[:len(buf)+n] switch err { case nil: - if n < need { - continue + if len(buf) >= need { + rb.buf = buf + return nil } - b.length = n - return nil case io.EOF: - if n >= need { - b.length = n + if len(buf) >= need { + rb.buf = buf return nil } return io.ErrUnexpectedEOF @@ -112,71 +95,15 @@ func (b *buffer) fill(need int) error { } // returns next N bytes from buffer. -// The returned slice is only guaranteed to be valid until the next read -func (b *buffer) readNext(need int) ([]byte, error) { - if b.length < need { - // refill - if err := b.fill(need); err != nil { +// The returned slice is only guaranteed to be valid until the next read. +func (rb *readBuffer) readNext(need int) ([]byte, error) { + if len(rb.buf)-rb.idx < need { + if err := rb.fill(need); err != nil { return nil, err } } - offset := b.idx - b.idx += need - b.length -= need - return b.buf[offset:b.idx], nil -} - -// takeBuffer returns a buffer with the requested size. -// If possible, a slice from the existing buffer is returned. -// Otherwise a bigger buffer is made. -// Only one buffer (total) can be used at a time. -func (b *buffer) takeBuffer(length int) ([]byte, error) { - if b.length > 0 { - return nil, ErrBusyBuffer - } - - // test (cheap) general case first - if length <= cap(b.buf) { - return b.buf[:length], nil - } - - if length < maxPacketSize { - b.buf = make([]byte, length) - return b.buf, nil - } - - // buffer is larger than we want to store. - return make([]byte, length), nil -} - -// takeSmallBuffer is shortcut which can be used if length is -// known to be smaller than defaultBufSize. -// Only one buffer (total) can be used at a time. -func (b *buffer) takeSmallBuffer(length int) ([]byte, error) { - if b.length > 0 { - return nil, ErrBusyBuffer - } - return b.buf[:length], nil -} - -// takeCompleteBuffer returns the complete existing buffer. -// This can be used if the necessary buffer size is unknown. -// cap and len of the returned buffer will be equal. -// Only one buffer (total) can be used at a time. -func (b *buffer) takeCompleteBuffer() ([]byte, error) { - if b.length > 0 { - return nil, ErrBusyBuffer - } - return b.buf, nil -} - -// store stores buf, an updated buffer, if its suitable to do so. -func (b *buffer) store(buf []byte) error { - if b.length > 0 { - return ErrBusyBuffer - } else if cap(buf) <= maxPacketSize && cap(buf) > cap(b.buf) { - b.buf = buf[:cap(buf)] - } - return nil + offset := rb.idx + rb.idx += need + return rb.buf[offset:rb.idx], nil } diff --git a/conncheck.go b/conncheck.go deleted file mode 100644 index 0ea721720..000000000 --- a/conncheck.go +++ /dev/null @@ -1,55 +0,0 @@ -// Go MySQL Driver - A MySQL-Driver for Go's database/sql package -// -// Copyright 2019 The Go-MySQL-Driver Authors. All rights reserved. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this file, -// You can obtain one at http://mozilla.org/MPL/2.0/. - -//go:build linux || darwin || dragonfly || freebsd || netbsd || openbsd || solaris || illumos -// +build linux darwin dragonfly freebsd netbsd openbsd solaris illumos - -package mysql - -import ( - "errors" - "io" - "net" - "syscall" -) - -var errUnexpectedRead = errors.New("unexpected read from socket") - -func connCheck(conn net.Conn) error { - var sysErr error - - sysConn, ok := conn.(syscall.Conn) - if !ok { - return nil - } - rawConn, err := sysConn.SyscallConn() - if err != nil { - return err - } - - err = rawConn.Read(func(fd uintptr) bool { - var buf [1]byte - n, err := syscall.Read(int(fd), buf[:]) - switch { - case n == 0 && err == nil: - sysErr = io.EOF - case n > 0: - sysErr = errUnexpectedRead - case err == syscall.EAGAIN || err == syscall.EWOULDBLOCK: - sysErr = nil - default: - sysErr = err - } - return true - }) - if err != nil { - return err - } - - return sysErr -} diff --git a/conncheck_dummy.go b/conncheck_dummy.go deleted file mode 100644 index a56c138f2..000000000 --- a/conncheck_dummy.go +++ /dev/null @@ -1,18 +0,0 @@ -// Go MySQL Driver - A MySQL-Driver for Go's database/sql package -// -// Copyright 2019 The Go-MySQL-Driver Authors. All rights reserved. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this file, -// You can obtain one at http://mozilla.org/MPL/2.0/. - -//go:build !linux && !darwin && !dragonfly && !freebsd && !netbsd && !openbsd && !solaris && !illumos -// +build !linux,!darwin,!dragonfly,!freebsd,!netbsd,!openbsd,!solaris,!illumos - -package mysql - -import "net" - -func connCheck(conn net.Conn) error { - return nil -} diff --git a/conncheck_test.go b/conncheck_test.go deleted file mode 100644 index f7e025680..000000000 --- a/conncheck_test.go +++ /dev/null @@ -1,39 +0,0 @@ -// Go MySQL Driver - A MySQL-Driver for Go's database/sql package -// -// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this file, -// You can obtain one at http://mozilla.org/MPL/2.0/. - -//go:build linux || darwin || dragonfly || freebsd || netbsd || openbsd || solaris || illumos -// +build linux darwin dragonfly freebsd netbsd openbsd solaris illumos - -package mysql - -import ( - "testing" - "time" -) - -func TestStaleConnectionChecks(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - dbt.mustExec("SET @@SESSION.wait_timeout = 2") - - if err := dbt.db.Ping(); err != nil { - dbt.Fatal(err) - } - - // wait for MySQL to close our connection - time.Sleep(3 * time.Second) - - tx, err := dbt.db.Begin() - if err != nil { - dbt.Fatal(err) - } - - if err := tx.Rollback(); err != nil { - dbt.Fatal(err) - } - }) -} diff --git a/connection.go b/connection.go index 631a1dc24..fd14804b5 100644 --- a/connection.go +++ b/connection.go @@ -17,18 +17,31 @@ import ( "net" "strconv" "strings" + "sync" "time" ) +// aLongTimeAgo is a non-zero time, far in the past, used for +// immediate cancellation of dials. +var aLongTimeAgo = time.Unix(1, 0) + +type writeResult struct { + n int + err error +} + type mysqlConn struct { - buf buffer + muRead sync.Mutex // protects netConn for reads netConn net.Conn - rawConn net.Conn // underlying connection when netConn is TLS connection. + rawConn net.Conn // underlying connection when netConn is TLS connection. + wbuf writeBuffer + rbuf readBuffer result mysqlResult // managed by clearResult() and handleOkPacket(). cfg *Config connector *connector maxAllowedPacket int maxWriteSize int + readTimeout time.Duration writeTimeout time.Duration flags clientFlag status statusFlag @@ -37,16 +50,15 @@ type mysqlConn struct { reset bool // set when the Go SQL package calls ResetSession // for context support (Go 1.8+) - watching bool - watcher chan<- context.Context closech chan struct{} - finished chan<- struct{} - canceled atomicError // set non-nil if conn is canceled - closed atomicBool // set when conn is closed, before closech is closed + closed atomicBool // set when conn is closed, before closech is closed + readRes chan *packet // channel for read result + writeReq chan []byte // buffered channel for write packets + writeRes chan writeResult // channel for write result } // Handles parameters set in DSN after the connection is established -func (mc *mysqlConn) handleParams() (err error) { +func (mc *mysqlConn) handleParams(ctx context.Context) (err error) { var cmdSet strings.Builder for param, val := range mc.cfg.Params { @@ -57,9 +69,9 @@ func (mc *mysqlConn) handleParams() (err error) { for _, cs := range charsets { // ignore errors here - a charset may not exist if mc.cfg.Collation != "" { - err = mc.exec("SET NAMES " + cs + " COLLATE " + mc.cfg.Collation) + err = mc.exec(ctx, "SET NAMES "+cs+" COLLATE "+mc.cfg.Collation) } else { - err = mc.exec("SET NAMES " + cs) + err = mc.exec(ctx, "SET NAMES "+cs) } if err == nil { break @@ -85,7 +97,7 @@ func (mc *mysqlConn) handleParams() (err error) { } if cmdSet.Len() > 0 { - err = mc.exec(cmdSet.String()) + err = mc.exec(ctx, cmdSet.String()) if err != nil { return } @@ -105,10 +117,10 @@ func (mc *mysqlConn) markBadConn(err error) error { } func (mc *mysqlConn) Begin() (driver.Tx, error) { - return mc.begin(false) + return mc.begin(context.Background(), false) } -func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) { +func (mc *mysqlConn) begin(ctx context.Context, readOnly bool) (driver.Tx, error) { if mc.closed.Load() { mc.cfg.Logger.Print(ErrInvalidConn) return nil, driver.ErrBadConn @@ -119,17 +131,24 @@ func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) { } else { q = "START TRANSACTION" } - err := mc.exec(q) + err := mc.exec(ctx, q) if err == nil { - return &mysqlTx{mc}, err + return &mysqlTx{ + ctx: ctx, + mc: mc, + }, err } return nil, mc.markBadConn(err) } -func (mc *mysqlConn) Close() (err error) { +func (mc *mysqlConn) Close() error { + return mc.closeContext(context.Background()) +} + +func (mc *mysqlConn) closeContext(ctx context.Context) (err error) { // Makes Close idempotent if !mc.closed.Load() { - err = mc.writeCommandPacket(comQuit) + err = mc.writeCommandPacket(ctx, comQuit) } mc.cleanup() @@ -159,46 +178,13 @@ func (mc *mysqlConn) cleanup() { func (mc *mysqlConn) error() error { if mc.closed.Load() { - if err := mc.canceled.Value(); err != nil { - return err - } return ErrInvalidConn } return nil } func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { - if mc.closed.Load() { - mc.cfg.Logger.Print(ErrInvalidConn) - return nil, driver.ErrBadConn - } - // Send command - err := mc.writeCommandPacketStr(comStmtPrepare, query) - if err != nil { - // STMT_PREPARE is safe to retry. So we can return ErrBadConn here. - mc.cfg.Logger.Print(err) - return nil, driver.ErrBadConn - } - - stmt := &mysqlStmt{ - mc: mc, - } - - // Read Result - columnCount, err := stmt.readPrepareResultPacket() - if err == nil { - if stmt.paramCount > 0 { - if err = mc.readUntilEOF(); err != nil { - return nil, err - } - } - - if columnCount > 0 { - err = mc.readUntilEOF() - } - } - - return stmt, err + return mc.PrepareContext(context.Background(), query) } func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) { @@ -207,13 +193,8 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin return "", driver.ErrSkip } - buf, err := mc.buf.takeCompleteBuffer() - if err != nil { - // can not take the buffer. Something must be wrong with the connection - mc.cfg.Logger.Print(err) - return "", ErrInvalidConn - } - buf = buf[:0] + var err error + buf := mc.wbuf.takeBuffer(0) argPos := 0 for i := 0; i < len(query); i++ { @@ -297,10 +278,14 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin if argPos != len(args) { return "", driver.ErrSkip } - return string(buf), nil + s := string(buf) + mc.wbuf.store(buf[:0]) + return s, nil } func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { + ctx := context.Background() + if mc.closed.Load() { mc.cfg.Logger.Print(ErrInvalidConn) return nil, driver.ErrBadConn @@ -317,7 +302,7 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err query = prepared } - err := mc.exec(query) + err := mc.exec(ctx, query) if err == nil { copied := mc.result return &copied, err @@ -326,39 +311,39 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err } // Internal function to execute commands -func (mc *mysqlConn) exec(query string) error { +func (mc *mysqlConn) exec(ctx context.Context, query string) error { handleOk := mc.clearResult() // Send command - if err := mc.writeCommandPacketStr(comQuery, query); err != nil { + if err := mc.writeCommandPacketStr(ctx, comQuery, query); err != nil { return mc.markBadConn(err) } // Read Result - resLen, err := handleOk.readResultSetHeaderPacket() + resLen, err := handleOk.readResultSetHeaderPacket(ctx) if err != nil { return err } if resLen > 0 { // columns - if err := mc.readUntilEOF(); err != nil { + if err := mc.readUntilEOF(ctx); err != nil { return err } // rows - if err := mc.readUntilEOF(); err != nil { + if err := mc.readUntilEOF(ctx); err != nil { return err } } - return handleOk.discardResults() + return handleOk.discardResults(ctx) } func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { - return mc.query(query, args) + return mc.query(context.Background(), query, args) } -func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) { +func (mc *mysqlConn) query(ctx context.Context, query string, args []driver.Value) (*textRows, error) { handleOk := mc.clearResult() if mc.closed.Load() { @@ -377,14 +362,15 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) query = prepared } // Send command - err := mc.writeCommandPacketStr(comQuery, query) + err := mc.writeCommandPacketStr(ctx, comQuery, query) if err == nil { // Read Result var resLen int - resLen, err = handleOk.readResultSetHeaderPacket() + resLen, err = handleOk.readResultSetHeaderPacket(ctx) if err == nil { rows := new(textRows) rows.mc = mc + rows.ctx = ctx if resLen == 0 { rows.rs.done = true @@ -398,7 +384,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) } // Columns - rows.rs.columns, err = mc.readColumns(resLen) + rows.rs.columns, err = mc.readColumns(ctx, resLen) return rows, err } } @@ -407,71 +393,52 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) // Gets the value of the given MySQL System Variable // The returned byte slice is only valid until the next read -func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { +func (mc *mysqlConn) getSystemVar(ctx context.Context, name string) ([]byte, error) { // Send command handleOk := mc.clearResult() - if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil { + if err := mc.writeCommandPacketStr(ctx, comQuery, "SELECT @@"+name); err != nil { return nil, err } // Read Result - resLen, err := handleOk.readResultSetHeaderPacket() + resLen, err := handleOk.readResultSetHeaderPacket(ctx) if err == nil { rows := new(textRows) + rows.ctx = ctx rows.mc = mc rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}} if resLen > 0 { // Columns - if err := mc.readUntilEOF(); err != nil { + if err := mc.readUntilEOF(ctx); err != nil { return nil, err } } dest := make([]driver.Value, resLen) if err = rows.readRow(dest); err == nil { - return dest[0].([]byte), mc.readUntilEOF() + return dest[0].([]byte), mc.readUntilEOF(ctx) } } return nil, err } -// finish is called when the query has canceled. -func (mc *mysqlConn) cancel(err error) { - mc.canceled.Set(err) - mc.cleanup() -} - -// finish is called when the query has succeeded. -func (mc *mysqlConn) finish() { - if !mc.watching || mc.finished == nil { - return - } - select { - case mc.finished <- struct{}{}: - mc.watching = false - case <-mc.closech: - } -} - // Ping implements driver.Pinger interface func (mc *mysqlConn) Ping(ctx context.Context) (err error) { if mc.closed.Load() { mc.cfg.Logger.Print(ErrInvalidConn) return driver.ErrBadConn } - - if err = mc.watchCancel(ctx); err != nil { - return + if err := ctx.Err(); err != nil { + return err } - defer mc.finish() handleOk := mc.clearResult() - if err = mc.writeCommandPacket(comPing); err != nil { + if err = mc.writeCommandPacket(ctx, comPing); err != nil { return mc.markBadConn(err) } - return handleOk.readResultOK() + return handleOk.readResultOK(ctx) } // BeginTx implements driver.ConnBeginTx interface @@ -480,23 +447,18 @@ func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver return nil, driver.ErrBadConn } - if err := mc.watchCancel(ctx); err != nil { - return nil, err - } - defer mc.finish() - if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault { level, err := mapIsolationLevel(opts.Isolation) if err != nil { return nil, err } - err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level) + err = mc.exec(ctx, "SET TRANSACTION ISOLATION LEVEL "+level) if err != nil { return nil, err } } - return mc.begin(opts.ReadOnly) + return mc.begin(ctx, opts.ReadOnly) } func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { @@ -505,51 +467,75 @@ func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driv return nil, err } - if err := mc.watchCancel(ctx); err != nil { - return nil, err - } - - rows, err := mc.query(query, dargs) + rows, err := mc.query(ctx, query, dargs) if err != nil { - mc.finish() return nil, err } - rows.finish = mc.finish return rows, err } func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + if mc.closed.Load() { + mc.cfg.Logger.Print(ErrInvalidConn) + return nil, driver.ErrBadConn + } + dargs, err := namedValueToValue(args) if err != nil { return nil, err } - - if err := mc.watchCancel(ctx); err != nil { - return nil, err + if len(dargs) != 0 { + if !mc.cfg.InterpolateParams { + return nil, driver.ErrSkip + } + // try to interpolate the parameters to save extra roundtrips for preparing and closing a statement + prepared, err := mc.interpolateParams(query, dargs) + if err != nil { + return nil, err + } + query = prepared } - defer mc.finish() - return mc.Exec(query, dargs) + err = mc.exec(ctx, query) + if err == nil { + copied := mc.result + return &copied, err + } + return nil, mc.markBadConn(err) } func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { - if err := mc.watchCancel(ctx); err != nil { - return nil, err + if mc.closed.Load() { + mc.cfg.Logger.Print(ErrInvalidConn) + return nil, driver.ErrBadConn } - - stmt, err := mc.Prepare(query) - mc.finish() + // Send command + err := mc.writeCommandPacketStr(ctx, comStmtPrepare, query) if err != nil { - return nil, err + // STMT_PREPARE is safe to retry. So we can return ErrBadConn here. + mc.cfg.Logger.Print(err) + return nil, driver.ErrBadConn } - select { - default: - case <-ctx.Done(): - stmt.Close() - return nil, ctx.Err() + stmt := &mysqlStmt{ + mc: mc, } - return stmt, nil + + // Read Result + columnCount, err := stmt.readPrepareResultPacket(ctx) + if err == nil { + if stmt.paramCount > 0 { + if err = mc.readUntilEOF(ctx); err != nil { + return nil, err + } + } + + if columnCount > 0 { + err = mc.readUntilEOF(ctx) + } + } + + return stmt, err } func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { @@ -557,82 +543,53 @@ func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValu if err != nil { return nil, err } + return stmt.query(ctx, dargs) +} - if err := stmt.mc.watchCancel(ctx); err != nil { - return nil, err +func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + if stmt.mc.closed.Load() { + stmt.mc.cfg.Logger.Print(ErrInvalidConn) + return nil, driver.ErrBadConn } - rows, err := stmt.query(dargs) + dargs, err := namedValueToValue(args) if err != nil { - stmt.mc.finish() return nil, err } - rows.finish = stmt.mc.finish - return rows, err -} -func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { - dargs, err := namedValueToValue(args) + // Send command + err = stmt.writeExecutePacket(ctx, dargs) if err != nil { - return nil, err + return nil, stmt.mc.markBadConn(err) } - if err := stmt.mc.watchCancel(ctx); err != nil { + mc := stmt.mc + handleOk := stmt.mc.clearResult() + + // Read Result + resLen, err := handleOk.readResultSetHeaderPacket(ctx) + if err != nil { return nil, err } - defer stmt.mc.finish() - return stmt.Exec(dargs) -} + if resLen > 0 { + // Columns + if err = mc.readUntilEOF(ctx); err != nil { + return nil, err + } -func (mc *mysqlConn) watchCancel(ctx context.Context) error { - if mc.watching { - // Reach here if canceled, - // so the connection is already invalid - mc.cleanup() - return nil - } - // When ctx is already cancelled, don't watch it. - if err := ctx.Err(); err != nil { - return err - } - // When ctx is not cancellable, don't watch it. - if ctx.Done() == nil { - return nil - } - // When watcher is not alive, can't watch it. - if mc.watcher == nil { - return nil + // Rows + if err := mc.readUntilEOF(ctx); err != nil { + return nil, err + } } - mc.watching = true - mc.watcher <- ctx - return nil -} - -func (mc *mysqlConn) startWatcher() { - watcher := make(chan context.Context, 1) - mc.watcher = watcher - finished := make(chan struct{}) - mc.finished = finished - go func() { - for { - var ctx context.Context - select { - case ctx = <-watcher: - case <-mc.closech: - return - } + if err := handleOk.discardResults(ctx); err != nil { + return nil, err + } - select { - case <-ctx.Done(): - mc.cancel(ctx.Err()) - case <-finished: - case <-mc.closech: - return - } - } - }() + copied := mc.result + return &copied, nil } func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) { diff --git a/connection_test.go b/connection_test.go index 98c985ae1..dda667ea9 100644 --- a/connection_test.go +++ b/connection_test.go @@ -13,13 +13,14 @@ import ( "database/sql/driver" "encoding/json" "errors" + "io" "net" "testing" + "time" ) func TestInterpolateParams(t *testing.T) { mc := &mysqlConn{ - buf: newBuffer(nil), maxAllowedPacket: maxPacketSize, cfg: &Config{ InterpolateParams: true, @@ -39,7 +40,6 @@ func TestInterpolateParams(t *testing.T) { func TestInterpolateParamsJSONRawMessage(t *testing.T) { mc := &mysqlConn{ - buf: newBuffer(nil), maxAllowedPacket: maxPacketSize, cfg: &Config{ InterpolateParams: true, @@ -66,7 +66,6 @@ func TestInterpolateParamsJSONRawMessage(t *testing.T) { func TestInterpolateParamsTooManyPlaceholders(t *testing.T) { mc := &mysqlConn{ - buf: newBuffer(nil), maxAllowedPacket: maxPacketSize, cfg: &Config{ InterpolateParams: true, @@ -83,7 +82,6 @@ func TestInterpolateParamsTooManyPlaceholders(t *testing.T) { // https://github.com/go-sql-driver/mysql/pull/490 func TestInterpolateParamsPlaceholderInString(t *testing.T) { mc := &mysqlConn{ - buf: newBuffer(nil), maxAllowedPacket: maxPacketSize, cfg: &Config{ InterpolateParams: true, @@ -99,7 +97,6 @@ func TestInterpolateParamsPlaceholderInString(t *testing.T) { func TestInterpolateParamsUint64(t *testing.T) { mc := &mysqlConn{ - buf: newBuffer(nil), maxAllowedPacket: maxPacketSize, cfg: &Config{ InterpolateParams: true, @@ -135,8 +132,6 @@ func TestCleanCancel(t *testing.T) { mc := &mysqlConn{ closech: make(chan struct{}), } - mc.startWatcher() - defer mc.cleanup() ctx, cancel := context.WithCancel(context.Background()) cancel() @@ -150,20 +145,22 @@ func TestCleanCancel(t *testing.T) { if mc.closed.Load() { t.Error("expected mc is not closed, closed actually") } - - if mc.watching { - t.Error("expected watching is false, but true") - } } } func TestPingMarkBadConnection(t *testing.T) { - nc := badConnection{err: errors.New("boom")} + nc := badConnection{ + werr: errors.New("boom"), + done: make(chan struct{}), + } ms := &mysqlConn{ netConn: nc, - buf: newBuffer(nc), + rbuf: newReadBuffer(nc), maxAllowedPacket: defaultMaxAllowedPacket, + connector: &connector{}, } + ms.startGoroutines() + defer ms.cleanup() err := ms.Ping(context.Background()) @@ -173,14 +170,21 @@ func TestPingMarkBadConnection(t *testing.T) { } func TestPingErrInvalidConn(t *testing.T) { - nc := badConnection{err: errors.New("failed to write"), n: 10} + nc := badConnection{ + werr: errors.New("failed to write"), + n: 10, + done: make(chan struct{}), + } ms := &mysqlConn{ netConn: nc, - buf: newBuffer(nc), + rbuf: newReadBuffer(nc), maxAllowedPacket: defaultMaxAllowedPacket, closech: make(chan struct{}), cfg: NewConfig(), + connector: &connector{}, } + ms.startGoroutines() + defer ms.cleanup() err := ms.Ping(context.Background()) @@ -190,15 +194,55 @@ func TestPingErrInvalidConn(t *testing.T) { } type badConnection struct { - n int - err error - net.Conn + n int + rerr error + werr error + read chan struct{} + done chan struct{} +} + +func (bc badConnection) Read(b []byte) (n int, err error) { + select { + case bc.read <- struct{}{}: + case <-bc.done: + return 0, io.EOF + } + + if bc.rerr != nil { + return 0, bc.rerr + } + <-bc.done + return 0, io.EOF } func (bc badConnection) Write(b []byte) (n int, err error) { - return bc.n, bc.err + if bc.werr != nil { + return bc.n, bc.werr + } + return 0, io.ErrShortWrite } func (bc badConnection) Close() error { + close(bc.done) + return nil +} + +func (bc badConnection) LocalAddr() net.Addr { + return nil +} + +func (bc badConnection) RemoteAddr() net.Addr { + return nil +} + +func (bc badConnection) SetDeadline(t time.Time) error { + return nil +} + +func (bc badConnection) SetReadDeadline(t time.Time) error { + return nil +} + +func (bc badConnection) SetWriteDeadline(t time.Time) error { return nil } diff --git a/connector.go b/connector.go index ba3be71e7..b1df02fc1 100644 --- a/connector.go +++ b/connector.go @@ -16,11 +16,13 @@ import ( "os" "strconv" "strings" + "sync" ) type connector struct { cfg *Config // immutable private copy. encodedAttributes string // Encoded connection attributes. + packetPool sync.Pool } func encodeConnectionAttributes(textAttributes string) string { @@ -69,7 +71,6 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { mc := &mysqlConn{ maxAllowedPacket: maxPacketSize, maxWriteSize: maxPacketSize - 1, - closech: make(chan struct{}), cfg: c.cfg, connector: c, } @@ -103,22 +104,13 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { } } - // Call startWatcher for context support (From Go 1.8) - mc.startWatcher() - if err := mc.watchCancel(ctx); err != nil { - mc.cleanup() - return nil, err - } - defer mc.finish() - - mc.buf = newBuffer(mc.netConn) - - // Set I/O timeouts - mc.buf.timeout = mc.cfg.ReadTimeout + mc.readTimeout = mc.cfg.ReadTimeout mc.writeTimeout = mc.cfg.WriteTimeout + mc.rbuf = newReadBuffer(mc.netConn) + mc.startGoroutines() // Reading Handshake Initialization Packet - authData, plugin, err := mc.readHandshakePacket() + authData, plugin, err := mc.readHandshakePacket(ctx) if err != nil { mc.cleanup() return nil, err @@ -140,13 +132,13 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { return nil, err } } - if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil { + if err = mc.writeHandshakeResponsePacket(ctx, authResp, plugin); err != nil { mc.cleanup() return nil, err } // Handle response to auth packet, switch methods if possible - if err = mc.handleAuthResult(authData, plugin); err != nil { + if err = mc.handleAuthResult(ctx, authData, plugin); err != nil { // Authentication failed and MySQL has already closed the connection // (https://dev.mysql.com/doc/internals/en/authentication-fails.html). // Do not send COM_QUIT, just cleanup and return the error. @@ -158,7 +150,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket } else { // Get max allowed packet size - maxap, err := mc.getSystemVar("max_allowed_packet") + maxap, err := mc.getSystemVar(ctx, "max_allowed_packet") if err != nil { mc.Close() return nil, err @@ -170,7 +162,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { } // Handle DSN Params - err = mc.handleParams() + err = mc.handleParams(ctx) if err != nil { mc.Close() return nil, err @@ -179,6 +171,20 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { return mc, nil } +func (c *connector) getPacket() *packet { + p := c.packetPool.Get() + if p == nil { + return &packet{} + } + return p.(*packet) +} + +func (c *connector) putPacket(p *packet) { + if p != nil && len(p.data) < maxPacketSize { + c.packetPool.Put(p) + } +} + // Driver implements driver.Connector interface. // Driver returns &MySQLDriver{}. func (c *connector) Driver() driver.Driver { diff --git a/driver_test.go b/driver_test.go index dd3d73141..99a8bee39 100644 --- a/driver_test.go +++ b/driver_test.go @@ -11,7 +11,6 @@ package mysql import ( "bytes" "context" - "crypto/tls" "database/sql" "database/sql/driver" "encoding/json" @@ -23,7 +22,6 @@ import ( "net/url" "os" "reflect" - "runtime" "strings" "sync" "sync/atomic" @@ -165,6 +163,7 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { } func (dbt *DBTest) fail(method, query string, err error) { + dbt.Helper() if len(query) > 300 { query = "[query too large to print]" } @@ -172,6 +171,7 @@ func (dbt *DBTest) fail(method, query string, err error) { } func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result) { + dbt.Helper() res, err := dbt.db.Exec(query, args...) if err != nil { dbt.fail("exec", query, err) @@ -1375,47 +1375,47 @@ func TestFoundRows(t *testing.T) { }) } -func TestTLS(t *testing.T) { - tlsTestReq := func(dbt *DBTest) { - if err := dbt.db.Ping(); err != nil { - if err == ErrNoTLS { - dbt.Skip("server does not support TLS") - } else { - dbt.Fatalf("error on Ping: %s", err.Error()) - } - } - - rows := dbt.mustQuery("SHOW STATUS LIKE 'Ssl_cipher'") - defer rows.Close() - - var variable, value *sql.RawBytes - for rows.Next() { - if err := rows.Scan(&variable, &value); err != nil { - dbt.Fatal(err.Error()) - } - - if (*value == nil) || (len(*value) == 0) { - dbt.Fatalf("no Cipher") - } else { - dbt.Logf("Cipher: %s", *value) - } - } - } - tlsTestOpt := func(dbt *DBTest) { - if err := dbt.db.Ping(); err != nil { - dbt.Fatalf("error on Ping: %s", err.Error()) - } - } - - runTests(t, dsn+"&tls=preferred", tlsTestOpt) - runTests(t, dsn+"&tls=skip-verify", tlsTestReq) - - // Verify that registering / using a custom cfg works - RegisterTLSConfig("custom-skip-verify", &tls.Config{ - InsecureSkipVerify: true, - }) - runTests(t, dsn+"&tls=custom-skip-verify", tlsTestReq) -} +// func TestTLS(t *testing.T) { +// tlsTestReq := func(dbt *DBTest) { +// if err := dbt.db.Ping(); err != nil { +// if err == ErrNoTLS { +// dbt.Skip("server does not support TLS") +// } else { +// dbt.Fatalf("error on Ping: %s", err.Error()) +// } +// } + +// rows := dbt.mustQuery("SHOW STATUS LIKE 'Ssl_cipher'") +// defer rows.Close() + +// var variable, value *sql.RawBytes +// for rows.Next() { +// if err := rows.Scan(&variable, &value); err != nil { +// dbt.Fatal(err.Error()) +// } + +// if (*value == nil) || (len(*value) == 0) { +// dbt.Fatalf("no Cipher") +// } else { +// dbt.Logf("Cipher: %s", *value) +// } +// } +// } +// tlsTestOpt := func(dbt *DBTest) { +// if err := dbt.db.Ping(); err != nil { +// dbt.Fatalf("error on Ping: %s", err.Error()) +// } +// } + +// runTests(t, dsn+"&tls=preferred", tlsTestOpt) +// runTests(t, dsn+"&tls=skip-verify", tlsTestReq) + +// // Verify that registering / using a custom cfg works +// RegisterTLSConfig("custom-skip-verify", &tls.Config{ +// InsecureSkipVerify: true, +// }) +// runTests(t, dsn+"&tls=custom-skip-verify", tlsTestReq) +// } func TestReuseClosedConnection(t *testing.T) { // this test does not use sql.database, it uses the driver directly @@ -1872,8 +1872,6 @@ func TestConcurrent(t *testing.T) { defer wg.Done() tx, err := dbt.db.Begin() - atomic.AddInt32(&remaining, -1) - if err != nil { if err.Error() != "Error 1040: Too many connections" { fatalf("error on conn %d: %s", id, err.Error()) @@ -1882,7 +1880,7 @@ func TestConcurrent(t *testing.T) { } // keep the connection busy until all connections are open - for remaining > 0 { + for atomic.AddInt32(&remaining, -1) > 0 { if _, err = tx.Exec("DO 1"); err != nil { fatalf("error on conn %d: %s", id, err.Error()) return @@ -2759,10 +2757,6 @@ func TestContextCancelStmtQuery(t *testing.T) { } func TestContextCancelBegin(t *testing.T) { - if runtime.GOOS == "windows" || runtime.GOOS == "darwin" { - t.Skip(`FIXME: it sometime fails with "expected driver.ErrBadConn, got sql: connection is already closed" on windows and macOS`) - } - runTests(t, dsn, func(dbt *DBTest) { dbt.mustExec("CREATE TABLE test (v INTEGER)") ctx, cancel := context.WithCancel(context.Background()) @@ -3307,102 +3301,124 @@ func TestConnectorTimeoutsDuringOpen(t *testing.T) { } } -// A connection which can only be closed. -type dummyConnection struct { - net.Conn - closed bool -} - -func (d *dummyConnection) Close() error { - d.closed = true - return nil -} - -func TestConnectorTimeoutsWatchCancel(t *testing.T) { - var ( - cancel func() // Used to cancel the context just after connecting. - created *dummyConnection // The created connection. - ) - - RegisterDialContext("TestConnectorTimeoutsWatchCancel", func(ctx context.Context, addr string) (net.Conn, error) { - // Canceling at this time triggers the watchCancel error branch in Connect(). - cancel() - created = &dummyConnection{} - return created, nil - }) - - mycnf := NewConfig() - mycnf.User = "root" - mycnf.Addr = "foo" - mycnf.Net = "TestConnectorTimeoutsWatchCancel" - - conn, err := NewConnector(mycnf) - if err != nil { - t.Fatal(err) - } - - db := sql.OpenDB(conn) - defer db.Close() - - var ctx context.Context - ctx, cancel = context.WithCancel(context.Background()) - defer cancel() - - if _, err := db.Conn(ctx); err != context.Canceled { - t.Errorf("got %v, want context.Canceled", err) - } - - if created == nil { - t.Fatal("no connection created") - } - if !created.closed { - t.Errorf("connection not closed") - } -} - -func TestConnectionAttributes(t *testing.T) { - if !available { - t.Skipf("MySQL server not running on %s", netAddr) - } - - attr1 := "attr1" - value1 := "value1" - attr2 := "foo" - value2 := "boo" - dsn += fmt.Sprintf("&connectionAttributes=%s:%s,%s:%s", attr1, value1, attr2, value2) +// // A connection which can only be closed. +// type dummyConnection struct { +// net.Conn +// closed bool +// } + +// func (d *dummyConnection) Close() error { +// d.closed = true +// return nil +// } + +// func TestConnectorTimeoutsWatchCancel(t *testing.T) { +// var ( +// cancel func() // Used to cancel the context just after connecting. +// created *dummyConnection // The created connection. +// ) + +// RegisterDialContext("TestConnectorTimeoutsWatchCancel", func(ctx context.Context, addr string) (net.Conn, error) { +// // Canceling at this time triggers the watchCancel error branch in Connect(). +// cancel() +// created = &dummyConnection{} +// return created, nil +// }) + +// mycnf := NewConfig() +// mycnf.User = "root" +// mycnf.Addr = "foo" +// mycnf.Net = "TestConnectorTimeoutsWatchCancel" + +// conn, err := NewConnector(mycnf) +// if err != nil { +// t.Fatal(err) +// } + +// db := sql.OpenDB(conn) +// defer db.Close() + +// var ctx context.Context +// ctx, cancel = context.WithCancel(context.Background()) +// defer cancel() + +// if _, err := db.Conn(ctx); err != context.Canceled { +// t.Errorf("got %v, want context.Canceled", err) +// } + +// if created == nil { +// t.Fatal("no connection created") +// } +// if !created.closed { +// t.Errorf("connection not closed") +// } +// } + +// func TestConnectionAttributes(t *testing.T) { +// if !available { +// t.Skipf("MySQL server not running on %s", netAddr) +// } + +// attr1 := "attr1" +// value1 := "value1" +// attr2 := "foo" +// value2 := "boo" +// dsn += fmt.Sprintf("&connectionAttributes=%s:%s,%s:%s", attr1, value1, attr2, value2) + +// var db *sql.DB +// if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation { +// db, err = sql.Open("mysql", dsn) +// if err != nil { +// t.Fatalf("error connecting: %s", err.Error()) +// } +// defer db.Close() +// } + +// dbt := &DBTest{t, db} + +// var attrValue string +// queryString := "SELECT ATTR_VALUE FROM performance_schema.session_account_connect_attrs WHERE PROCESSLIST_ID = CONNECTION_ID() and ATTR_NAME = ?" +// rows := dbt.mustQuery(queryString, connAttrClientName) +// if rows.Next() { +// rows.Scan(&attrValue) +// if attrValue != connAttrClientNameValue { +// dbt.Errorf("expected %q, got %q", connAttrClientNameValue, attrValue) +// } +// } else { +// dbt.Errorf("no data") +// } +// rows.Close() + +// rows = dbt.mustQuery(queryString, attr2) +// if rows.Next() { +// rows.Scan(&attrValue) +// if attrValue != value2 { +// dbt.Errorf("expected %q, got %q", value2, attrValue) +// } +// } else { +// dbt.Errorf("no data") +// } +// rows.Close() +// } + +func TestStaleConnectionChecks(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("SET @@SESSION.wait_timeout = 2") - var db *sql.DB - if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation { - db, err = sql.Open("mysql", dsn) - if err != nil { - t.Fatalf("error connecting: %s", err.Error()) + if err := dbt.db.Ping(); err != nil { + dbt.Fatal(err) } - defer db.Close() - } - dbt := &DBTest{t, db} + // wait for MySQL to close our connection + time.Sleep(3 * time.Second) - var attrValue string - queryString := "SELECT ATTR_VALUE FROM performance_schema.session_account_connect_attrs WHERE PROCESSLIST_ID = CONNECTION_ID() and ATTR_NAME = ?" - rows := dbt.mustQuery(queryString, connAttrClientName) - if rows.Next() { - rows.Scan(&attrValue) - if attrValue != connAttrClientNameValue { - dbt.Errorf("expected %q, got %q", connAttrClientNameValue, attrValue) + tx, err := dbt.db.Begin() + if err != nil { + dbt.Fatal(err) } - } else { - dbt.Errorf("no data") - } - rows.Close() - rows = dbt.mustQuery(queryString, attr2) - if rows.Next() { - rows.Scan(&attrValue) - if attrValue != value2 { - dbt.Errorf("expected %q, got %q", value2, attrValue) + if err := tx.Rollback(); err != nil { + dbt.Fatal(err) } - } else { - dbt.Errorf("no data") - } - rows.Close() + }) } diff --git a/infile.go b/infile.go index 0c8af9f11..d78a1b989 100644 --- a/infile.go +++ b/infile.go @@ -9,6 +9,7 @@ package mysql import ( + "context" "fmt" "io" "os" @@ -93,7 +94,7 @@ func deferredClose(err *error, closer io.Closer) { const defaultPacketSize = 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP -func (mc *okHandler) handleInFileRequest(name string) (err error) { +func (mc *okHandler) handleInFileRequest(ctx context.Context, name string) (err error) { var rdr io.Reader var data []byte packetSize := defaultPacketSize @@ -154,7 +155,7 @@ func (mc *okHandler) handleInFileRequest(name string) (err error) { for err == nil { n, err = rdr.Read(data[4:]) if n > 0 { - if ioErr := mc.conn().writePacket(data[:4+n]); ioErr != nil { + if ioErr := mc.conn().writePacket(ctx, data[:4+n]); ioErr != nil { return ioErr } } @@ -168,15 +169,15 @@ func (mc *okHandler) handleInFileRequest(name string) (err error) { if data == nil { data = make([]byte, 4) } - if ioErr := mc.conn().writePacket(data[:4]); ioErr != nil { + if ioErr := mc.conn().writePacket(ctx, data[:4]); ioErr != nil { return ioErr } // read OK packet if err == nil { - return mc.readResultOK() + return mc.readResultOK(ctx) } - mc.conn().readPacket() + mc.conn().readPacket(ctx) return err } diff --git a/packets.go b/packets.go index a1aaf20ee..0d673511b 100644 --- a/packets.go +++ b/packets.go @@ -10,13 +10,16 @@ package mysql import ( "bytes" + "context" "crypto/tls" "database/sql/driver" "encoding/binary" "encoding/json" + "errors" "fmt" "io" "math" + "os" "strconv" "time" ) @@ -24,28 +27,58 @@ import ( // Packets documentation: // http://dev.mysql.com/doc/internals/en/client-server-protocol.html +type packet struct { + header [4]byte + data []byte + err error +} + +func (p *packet) readFrom(r *readBuffer) { + // read the header + var data []byte + data, p.err = r.readNext(4) + if p.err != nil { + return + } + copy(p.header[:], data) + + // packet length [24 bit] + pktLen := int(uint32(p.header[0]) | uint32(p.header[1])<<8 | uint32(p.header[2])<<16) + + // read the body + data, p.err = r.readNext(pktLen) + if p.err != nil { + return + } + + p.data = append(p.data[:0], data...) +} + // Read packet to buffer 'data' -func (mc *mysqlConn) readPacket() ([]byte, error) { - var prevData []byte +func (mc *mysqlConn) readPacket(ctx context.Context) (*packet, error) { + var prevData *packet for { - // read packet header - data, err := mc.buf.readNext(4) - if err != nil { - if cerr := mc.canceled.Value(); cerr != nil { - return nil, cerr - } - mc.cfg.Logger.Print(err) - mc.Close() + var pkt *packet + select { + case pkt = <-mc.readRes: + case <-mc.closech: + mc.cfg.Logger.Print(ErrMalformPkt) + mc.closeContext(ctx) + return nil, ErrInvalidConn + case <-ctx.Done(): + mc.cleanup() + return nil, ctx.Err() + } + if pkt.err != nil { + mc.cfg.Logger.Print(ErrMalformPkt) + mc.closeContext(ctx) return nil, ErrInvalidConn } - - // packet length [24 bit] - pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) // check packet sync [8 bit] - if data[3] != mc.sequence { - mc.Close() - if data[3] > mc.sequence { + if seq := pkt.header[3]; seq != mc.sequence { + mc.closeContext(ctx) + if seq > mc.sequence { return nil, ErrPktSyncMul } return nil, ErrPktSync @@ -54,78 +87,47 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // packets with length 0 terminate a previous packet which is a // multiple of (2^24)-1 bytes long + pktLen := len(pkt.data) if pktLen == 0 { // there was no previous packet if prevData == nil { mc.cfg.Logger.Print(ErrMalformPkt) - mc.Close() + mc.closeContext(ctx) return nil, ErrInvalidConn } return prevData, nil } - // read packet body [pktLen bytes] - data, err = mc.buf.readNext(pktLen) - if err != nil { - if cerr := mc.canceled.Value(); cerr != nil { - return nil, cerr - } - mc.cfg.Logger.Print(err) - mc.Close() - return nil, ErrInvalidConn - } - // return data if this was the last packet if pktLen < maxPacketSize { // zero allocations for non-split packets if prevData == nil { - return data, nil + return pkt, nil } - return append(prevData, data...), nil + prevData.data = append(prevData.data, pkt.data...) + mc.connector.putPacket(pkt) + return prevData, nil } - prevData = append(prevData, data...) + if prevData != nil { + prevData.data = append(prevData.data, pkt.data...) + mc.connector.putPacket(pkt) + } else { + prevData = pkt + } } } // Write packet buffer 'data' -func (mc *mysqlConn) writePacket(data []byte) error { +func (mc *mysqlConn) writePacket(ctx context.Context, data []byte) error { pktLen := len(data) - 4 if pktLen > mc.maxAllowedPacket { return ErrPktTooLarge } - // Perform a stale connection check. We only perform this check for - // the first query on a connection that has been checked out of the - // connection pool: a fresh connection from the pool is more likely - // to be stale, and it has not performed any previous writes that - // could cause data corruption, so it's safe to return ErrBadConn - // if the check fails. - if mc.reset { - mc.reset = false - conn := mc.netConn - if mc.rawConn != nil { - conn = mc.rawConn - } - var err error - if mc.cfg.CheckConnLiveness { - if mc.cfg.ReadTimeout != 0 { - err = conn.SetReadDeadline(time.Now().Add(mc.cfg.ReadTimeout)) - } - if err == nil { - err = connCheck(conn) - } - } - if err != nil { - mc.cfg.Logger.Print("closing bad idle connection: ", err) - mc.Close() - return driver.ErrBadConn - } - } - for { var size int if pktLen >= maxPacketSize { @@ -141,14 +143,59 @@ func (mc *mysqlConn) writePacket(data []byte) error { } data[3] = mc.sequence - // Write packet - if mc.writeTimeout > 0 { - if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil { - return err + // Perform a stale connection check. We only perform this check for + // the first query on a connection that has been checked out of the + // connection pool: a fresh connection from the pool is more likely + // to be stale, and it has not performed any previous writes that + // could cause data corruption, so it's safe to return ErrBadConn + // if the check fails. + if mc.reset { + mc.reset = false + select { + case <-mc.readRes: + mc.closeContext(ctx) + return errBadConnNoWrite + case <-mc.closech: + return ErrInvalidConn + case <-ctx.Done(): + mc.cleanup() + return ctx.Err() + default: } } - n, err := mc.netConn.Write(data[:4+size]) + // request writing the packet + select { + case <-mc.readRes: + mc.closeContext(ctx) + return errBadConnNoWrite + case mc.writeReq <- data: + case <-mc.closech: + return ErrInvalidConn + case <-ctx.Done(): + mc.cleanup() + return ctx.Err() + } + + // wait for the packet to be written + var result writeResult + select { + case result = <-mc.writeRes: + case <-mc.closech: + return ErrInvalidConn + case <-ctx.Done(): + // abort writing operation + if err := mc.netConn.SetWriteDeadline(aLongTimeAgo); err == nil { + <-mc.writeRes + } + + // we must not use this connection anymore because we don't know its state. + mc.cleanup() + + return ctx.Err() + } + n, err := result.n, result.err + if err == nil && n == 4+size { mc.sequence++ if size != maxPacketSize { @@ -164,9 +211,6 @@ func (mc *mysqlConn) writePacket(data []byte) error { mc.cleanup() mc.cfg.Logger.Print(ErrMalformPkt) } else { - if cerr := mc.canceled.Value(); cerr != nil { - return cerr - } if n == 0 && pktLen == len(data)-4 { // only for the first loop iteration when nothing was written yet return errBadConnNoWrite @@ -184,8 +228,8 @@ func (mc *mysqlConn) writePacket(data []byte) error { // Handshake Initialization Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake -func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) { - data, err = mc.readPacket() +func (mc *mysqlConn) readHandshakePacket(ctx context.Context) (data []byte, plugin string, err error) { + packet, err := mc.readPacket(ctx) if err != nil { // for init we can rewrite this to ErrBadConn for sql.Driver to retry, since // in connection initialization we don't risk retrying non-idempotent actions. @@ -194,6 +238,8 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro } return } + defer mc.connector.putPacket(packet) + data = packet.data if data[0] == iERR { return nil, "", mc.handleErrorPacket(data) @@ -277,7 +323,7 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro // Client Authentication Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse -func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error { +func (mc *mysqlConn) writeHandshakeResponsePacket(ctx context.Context, authResp []byte, plugin string) error { // Adjust client flags based on server support clientFlags := clientProtocol41 | clientSecureConn | @@ -328,12 +374,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string pktLen += 1 + len(mc.connector.encodedAttributes) // Calculate packet length and get buffer with that size - data, err := mc.buf.takeSmallBuffer(pktLen + 4) - if err != nil { - // cannot take the buffer. Something must be wrong with the connection - mc.cfg.Logger.Print(err) - return errBadConnNoWrite - } + data := make([]byte, pktLen+4) // ClientFlags [32 bit] data[4] = byte(clientFlags) @@ -371,18 +412,20 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest if mc.cfg.TLS != nil { // Send TLS / SSL request packet - if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil { + if err := mc.writePacket(ctx, data[:(4+4+1+23)+4]); err != nil { return err } // Switch to TLS + mc.pauseReadLoop() tlsConn := tls.Client(mc.netConn, mc.cfg.TLS) if err := tlsConn.Handshake(); err != nil { return err } mc.rawConn = mc.netConn mc.netConn = tlsConn - mc.buf.nc = tlsConn + mc.rbuf.nc = tlsConn + mc.resumeReadLoop() } // User [null terminated string] @@ -413,57 +456,42 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string pos += copy(data[pos:], []byte(mc.connector.encodedAttributes)) // Send Auth packet - return mc.writePacket(data[:pos]) + return mc.writePacket(ctx, data[:pos]) } // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse -func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error { +func (mc *mysqlConn) writeAuthSwitchPacket(ctx context.Context, authData []byte) error { pktLen := 4 + len(authData) - data, err := mc.buf.takeSmallBuffer(pktLen) - if err != nil { - // cannot take the buffer. Something must be wrong with the connection - mc.cfg.Logger.Print(err) - return errBadConnNoWrite - } + data := mc.wbuf.takeBuffer(pktLen) // Add the auth data [EOF] copy(data[4:], authData) - return mc.writePacket(data) + return mc.writePacket(ctx, data) } /****************************************************************************** * Command Packets * ******************************************************************************/ -func (mc *mysqlConn) writeCommandPacket(command byte) error { +func (mc *mysqlConn) writeCommandPacket(ctx context.Context, command byte) error { // Reset Packet Sequence mc.sequence = 0 - data, err := mc.buf.takeSmallBuffer(4 + 1) - if err != nil { - // cannot take the buffer. Something must be wrong with the connection - mc.cfg.Logger.Print(err) - return errBadConnNoWrite - } - // Add command byte + data := mc.wbuf.takeBuffer(4 + 1) data[4] = command // Send CMD packet - return mc.writePacket(data) + return mc.writePacket(ctx, data) } -func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { +func (mc *mysqlConn) writeCommandPacketStr(ctx context.Context, command byte, arg string) error { // Reset Packet Sequence mc.sequence = 0 pktLen := 1 + len(arg) - data, err := mc.buf.takeBuffer(pktLen + 4) - if err != nil { - // cannot take the buffer. Something must be wrong with the connection - mc.cfg.Logger.Print(err) - return errBadConnNoWrite - } + + data := mc.wbuf.takeBuffer(4 + pktLen) // Add command byte data[4] = command @@ -472,21 +500,15 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { copy(data[5:], arg) // Send CMD packet - return mc.writePacket(data) + return mc.writePacket(ctx, data) } -func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { +func (mc *mysqlConn) writeCommandPacketUint32(ctx context.Context, command byte, arg uint32) error { // Reset Packet Sequence mc.sequence = 0 - data, err := mc.buf.takeSmallBuffer(4 + 1 + 4) - if err != nil { - // cannot take the buffer. Something must be wrong with the connection - mc.cfg.Logger.Print(err) - return errBadConnNoWrite - } - // Add command byte + data := mc.wbuf.takeBuffer(4 + 1 + 4) data[4] = command // Add arg [32 bit] @@ -496,18 +518,19 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { data[8] = byte(arg >> 24) // Send CMD packet - return mc.writePacket(data) + return mc.writePacket(ctx, data) } /****************************************************************************** * Result Packets * ******************************************************************************/ -func (mc *mysqlConn) readAuthResult() ([]byte, string, error) { - data, err := mc.readPacket() +func (mc *mysqlConn) readAuthResult(ctx context.Context) ([]byte, string, error) { + packet, err := mc.readPacket(ctx) if err != nil { return nil, "", err } + data := packet.data // packet indicator switch data[0] { @@ -539,11 +562,12 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) { } // Returns error if Packet is not a 'Result OK'-Packet -func (mc *okHandler) readResultOK() error { - data, err := mc.conn().readPacket() +func (mc *okHandler) readResultOK(ctx context.Context) error { + packet, err := mc.conn().readPacket(ctx) if err != nil { return err } + data := packet.data if data[0] == iOK { return mc.handleOkPacket(data) @@ -553,12 +577,17 @@ func (mc *okHandler) readResultOK() error { // Result Set Header Packet // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset -func (mc *okHandler) readResultSetHeaderPacket() (int, error) { +func (mc *okHandler) readResultSetHeaderPacket(ctx context.Context) (int, error) { // handleOkPacket replaces both values; other cases leave the values unchanged. mc.result.affectedRows = append(mc.result.affectedRows, 0) mc.result.insertIds = append(mc.result.insertIds, 0) - data, err := mc.conn().readPacket() + packet, err := mc.conn().readPacket(ctx) + if err != nil { + return 0, err + } + defer mc.conn().connector.putPacket(packet) + data := packet.data if err == nil { switch data[0] { @@ -569,7 +598,7 @@ func (mc *okHandler) readResultSetHeaderPacket() (int, error) { return 0, mc.conn().handleErrorPacket(data) case iLocalInFile: - return 0, mc.handleInFileRequest(string(data[1:])) + return 0, mc.handleInFileRequest(ctx, string(data[1:])) } // column count @@ -697,14 +726,15 @@ func (mc *okHandler) handleOkPacket(data []byte) error { // Read Packets as Field Packets until EOF-Packet or an Error appears // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41 -func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { +func (mc *mysqlConn) readColumns(ctx context.Context, count int) ([]mysqlField, error) { columns := make([]mysqlField, count) for i := 0; ; i++ { - data, err := mc.readPacket() + packet, err := mc.readPacket(ctx) if err != nil { return nil, err } + data := packet.data // EOF Packet if data[0] == iEOF && (len(data) == 5 || len(data) == 1) { @@ -798,16 +828,20 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { // Read Packets as Field Packets until EOF-Packet or an Error appears // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow func (rows *textRows) readRow(dest []driver.Value) error { + ctx := rows.ctx mc := rows.mc if rows.rs.done { return io.EOF } - data, err := mc.readPacket() + rows.mc.connector.putPacket(rows.pkt) + packet, err := mc.readPacket(ctx) + rows.pkt = packet if err != nil { return err } + data := packet.data // EOF Packet if data[0] == iEOF && len(data) == 5 { @@ -887,12 +921,13 @@ func (rows *textRows) readRow(dest []driver.Value) error { } // Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read -func (mc *mysqlConn) readUntilEOF() error { +func (mc *mysqlConn) readUntilEOF(ctx context.Context) error { for { - data, err := mc.readPacket() + packet, err := mc.readPacket(ctx) if err != nil { return err } + data := packet.data switch data[0] { case iERR: @@ -903,6 +938,7 @@ func (mc *mysqlConn) readUntilEOF() error { } return nil } + mc.connector.putPacket(packet) } } @@ -912,8 +948,13 @@ func (mc *mysqlConn) readUntilEOF() error { // Prepare Result Packets // http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html -func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { - data, err := stmt.mc.readPacket() +func (stmt *mysqlStmt) readPrepareResultPacket(ctx context.Context) (uint16, error) { + packet, err := stmt.mc.readPacket(ctx) + if err != nil { + return 0, err + } + defer stmt.mc.connector.putPacket(packet) + data := packet.data if err == nil { // packet indicator [1 byte] if data[0] != iOK { @@ -939,7 +980,7 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { } // http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html -func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { +func (stmt *mysqlStmt) writeCommandLongData(ctx context.Context, paramID int, arg []byte) error { maxLen := stmt.mc.maxAllowedPacket - 1 pktLen := maxLen @@ -976,7 +1017,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { data[10] = byte(paramID >> 8) // Send CMD packet - err := stmt.mc.writePacket(data[:4+pktLen]) + err := stmt.mc.writePacket(ctx, data[:4+pktLen]) if err == nil { data = data[pktLen-dataOffset:] continue @@ -992,7 +1033,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { // Execute Prepared Statement // http://dev.mysql.com/doc/internals/en/com-stmt-execute.html -func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { +func (stmt *mysqlStmt) writeExecutePacket(ctx context.Context, args []driver.Value) error { if len(args) != stmt.paramCount { return fmt.Errorf( "argument count mismatch (got: %d; has: %d)", @@ -1013,20 +1054,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // Reset packet-sequence mc.sequence = 0 - var data []byte var err error - if len(args) == 0 { - data, err = mc.buf.takeBuffer(minPktLen) - } else { - data, err = mc.buf.takeCompleteBuffer() - // In this case the len(data) == cap(data) which is used to optimise the flow below. - } - if err != nil { - // cannot take the buffer. Something must be wrong with the connection - mc.cfg.Logger.Print(err) - return errBadConnNoWrite - } + data := mc.wbuf.takeBuffer(minPktLen) + data = data[:cap(data)] // command [1 byte] data[4] = comStmtExecute @@ -1165,7 +1196,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { ) paramValues = append(paramValues, v...) } else { - if err := stmt.writeCommandLongData(i, v); err != nil { + if err := stmt.writeCommandLongData(ctx, i, v); err != nil { return err } } @@ -1187,7 +1218,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { ) paramValues = append(paramValues, v...) } else { - if err := stmt.writeCommandLongData(i, []byte(v)); err != nil { + if err := stmt.writeCommandLongData(ctx, i, []byte(v)); err != nil { return err } } @@ -1222,34 +1253,30 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // In that case we must build the data packet with the new values buffer if valuesCap != cap(paramValues) { data = append(data[:pos], paramValues...) - if err = mc.buf.store(data); err != nil { - mc.cfg.Logger.Print(err) - return errBadConnNoWrite - } } pos += len(paramValues) data = data[:pos] } - return mc.writePacket(data) + return mc.writePacket(ctx, data) } // For each remaining resultset in the stream, discards its rows and updates // mc.affectedRows and mc.insertIds. -func (mc *okHandler) discardResults() error { +func (mc *okHandler) discardResults(ctx context.Context) error { for mc.status&statusMoreResultsExists != 0 { - resLen, err := mc.readResultSetHeaderPacket() + resLen, err := mc.readResultSetHeaderPacket(ctx) if err != nil { return err } if resLen > 0 { // columns - if err := mc.conn().readUntilEOF(); err != nil { + if err := mc.conn().readUntilEOF(ctx); err != nil { return err } // rows - if err := mc.conn().readUntilEOF(); err != nil { + if err := mc.conn().readUntilEOF(ctx); err != nil { return err } } @@ -1259,10 +1286,15 @@ func (mc *okHandler) discardResults() error { // http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html func (rows *binaryRows) readRow(dest []driver.Value) error { - data, err := rows.mc.readPacket() + ctx := rows.ctx + + rows.mc.connector.putPacket(rows.pkt) + packet, err := rows.mc.readPacket(ctx) + rows.pkt = packet if err != nil { return err } + data := packet.data // packet indicator [1 byte] if data[0] != iOK { @@ -1435,3 +1467,82 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { return nil } + +func (mc *mysqlConn) startGoroutines() { + mc.closech = make(chan struct{}) + mc.readRes = make(chan *packet) + mc.writeReq = make(chan []byte, 1) + mc.writeRes = make(chan writeResult) + + go mc.readLoop() + go mc.writeLoop() +} + +func (mc *mysqlConn) readLoop() { + for { + pkt := mc.connector.getPacket() + mc.muRead.Lock() + pkt.readFrom(&mc.rbuf) + mc.muRead.Unlock() + select { + case mc.readRes <- pkt: + case <-mc.closech: + return + } + } +} + +func (mc *mysqlConn) pauseReadLoop() error { + // abort current read operation. + if err := mc.netConn.SetReadDeadline(aLongTimeAgo); err != nil { + return err + } + + // wait for read loop to abort. + mc.muRead.Lock() + result := <-mc.readRes + if !errors.Is(result.err, os.ErrDeadlineExceeded) { + mc.muRead.Unlock() + return errors.New("mysql: failed to abort read loop") + } + + // reset read deadline. + if err := mc.netConn.SetReadDeadline(time.Time{}); err != nil { + mc.muRead.Unlock() + return err + } + return nil +} + +func (mc *mysqlConn) resumeReadLoop() { + mc.muRead.Unlock() +} + +func (mc *mysqlConn) writeLoop() { + for { + var data []byte + select { + case data = <-mc.writeReq: + case <-mc.closech: + return + } + + n, err := mc.writeSync(data) + + select { + case mc.writeRes <- writeResult{n, err}: + case <-mc.closech: + return + } + } +} + +func (mc *mysqlConn) writeSync(data []byte) (n int, err error) { + // Write packet + if mc.writeTimeout > 0 { + if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil { + return 0, err + } + } + return mc.netConn.Write(data) +} diff --git a/packets_test.go b/packets_test.go index e86ec5848..d557afedc 100644 --- a/packets_test.go +++ b/packets_test.go @@ -10,125 +10,51 @@ package mysql import ( "bytes" - "errors" + "context" + "io" "net" "testing" - "time" ) -var ( - errConnClosed = errors.New("connection is closed") - errConnTooManyReads = errors.New("too many reads") - errConnTooManyWrites = errors.New("too many writes") -) - -// struct to mock a net.Conn for testing purposes -type mockConn struct { - laddr net.Addr - raddr net.Addr - data []byte - written []byte - queuedReplies [][]byte - closed bool - read int - reads int - writes int - maxReads int - maxWrites int -} - -func (m *mockConn) Read(b []byte) (n int, err error) { - if m.closed { - return 0, errConnClosed - } - - m.reads++ - if m.maxReads > 0 && m.reads > m.maxReads { - return 0, errConnTooManyReads - } - - n = copy(b, m.data) - m.read += n - m.data = m.data[n:] - return -} -func (m *mockConn) Write(b []byte) (n int, err error) { - if m.closed { - return 0, errConnClosed - } - - m.writes++ - if m.maxWrites > 0 && m.writes > m.maxWrites { - return 0, errConnTooManyWrites - } - - n = len(b) - m.written = append(m.written, b...) - - if n > 0 && len(m.queuedReplies) > 0 { - m.data = m.queuedReplies[0] - m.queuedReplies = m.queuedReplies[1:] - } - return -} -func (m *mockConn) Close() error { - m.closed = true - return nil -} -func (m *mockConn) LocalAddr() net.Addr { - return m.laddr -} -func (m *mockConn) RemoteAddr() net.Addr { - return m.raddr -} -func (m *mockConn) SetDeadline(t time.Time) error { - return nil -} -func (m *mockConn) SetReadDeadline(t time.Time) error { - return nil -} -func (m *mockConn) SetWriteDeadline(t time.Time) error { - return nil -} - -// make sure mockConn implements the net.Conn interface -var _ net.Conn = new(mockConn) - -func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) { - conn := new(mockConn) +func newRWMockConn(t *testing.T, sequence uint8) (net.Conn, *mysqlConn) { connector, err := newConnector(NewConfig()) if err != nil { panic(err) } + + client, server := net.Pipe() mc := &mysqlConn{ - buf: newBuffer(conn), cfg: connector.cfg, connector: connector, - netConn: conn, - closech: make(chan struct{}), + netConn: server, + rbuf: newReadBuffer(server), maxAllowedPacket: defaultMaxAllowedPacket, sequence: sequence, } - return conn, mc + mc.startGoroutines() + t.Cleanup(mc.cleanup) + return client, mc } func TestReadPacketSingleByte(t *testing.T) { - conn := new(mockConn) - mc := &mysqlConn{ - buf: newBuffer(conn), - } + conn, mc := newRWMockConn(t, 0) + + go func() { + io.Copy(io.Discard, conn) + }() + go func() { + conn.Write([]byte{0x01, 0x00, 0x00, 0x00, 0xff}) + }() - conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} - conn.maxReads = 1 - packet, err := mc.readPacket() + packet, err := mc.readPacket(context.Background()) if err != nil { t.Fatal(err) } - if len(packet) != 1 { - t.Fatalf("unexpected packet length: expected %d, got %d", 1, len(packet)) + if len(packet.data) != 1 { + t.Fatalf("unexpected packet length: expected %d, got %d", 1, len(packet.data)) } - if packet[0] != 0xff { - t.Fatalf("unexpected packet content: expected %x, got %x", 0xff, packet[0]) + if packet.data[0] != 0xff { + t.Fatalf("unexpected packet content: expected %x, got %x", 0xff, packet.data[0]) } } @@ -149,12 +75,18 @@ func TestReadPacketWrongSequenceID(t *testing.T) { ExpectedErr: ErrPktSyncMul, }, } { - conn, mc := newRWMockConn(testCase.ClientSequenceID) - - conn.data = []byte{0x01, 0x00, 0x00, testCase.ServerSequenceID, 0xff} - _, err := mc.readPacket() + testCase := testCase + + conn, mc := newRWMockConn(t, testCase.ClientSequenceID) + go func() { + io.Copy(io.Discard, conn) + }() + go func() { + conn.Write([]byte{0x01, 0x00, 0x00, testCase.ServerSequenceID, 0xff}) + }() + _, err := mc.readPacket(context.Background()) if err != testCase.ExpectedErr { - t.Errorf("expected %v, got %v", testCase.ExpectedErr, err) + t.Errorf(`expected "%v", got "%v"`, testCase.ExpectedErr, err) } // connection should not be returned to the pool in this state @@ -165,171 +97,200 @@ func TestReadPacketWrongSequenceID(t *testing.T) { } func TestReadPacketSplit(t *testing.T) { - conn := new(mockConn) - mc := &mysqlConn{ - buf: newBuffer(conn), - } - - data := make([]byte, maxPacketSize*2+4*3) const pkt2ofs = maxPacketSize + 4 const pkt3ofs = 2 * (maxPacketSize + 4) - // case 1: payload has length maxPacketSize - data = data[:pkt2ofs+4] - - // 1st packet has maxPacketSize length and sequence id 0 - // ff ff ff 00 ... - data[0] = 0xff - data[1] = 0xff - data[2] = 0xff - - // mark the payload start and end of 1st packet so that we can check if the - // content was correctly appended - data[4] = 0x11 - data[maxPacketSize+3] = 0x22 - - // 2nd packet has payload length 0 and sequence id 1 - // 00 00 00 01 - data[pkt2ofs+3] = 0x01 - - conn.data = data - conn.maxReads = 3 - packet, err := mc.readPacket() - if err != nil { - t.Fatal(err) - } - if len(packet) != maxPacketSize { - t.Fatalf("unexpected packet length: expected %d, got %d", maxPacketSize, len(packet)) - } - if packet[0] != 0x11 { - t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0]) - } - if packet[maxPacketSize-1] != 0x22 { - t.Fatalf("unexpected payload end: expected %x, got %x", 0x22, packet[maxPacketSize-1]) - } - - // case 2: payload has length which is a multiple of maxPacketSize - data = data[:cap(data)] - - // 2nd packet now has maxPacketSize length - data[pkt2ofs] = 0xff - data[pkt2ofs+1] = 0xff - data[pkt2ofs+2] = 0xff - - // mark the payload start and end of the 2nd packet - data[pkt2ofs+4] = 0x33 - data[pkt2ofs+maxPacketSize+3] = 0x44 - - // 3rd packet has payload length 0 and sequence id 2 - // 00 00 00 02 - data[pkt3ofs+3] = 0x02 - - conn.data = data - conn.reads = 0 - conn.maxReads = 5 - mc.sequence = 0 - packet, err = mc.readPacket() - if err != nil { - t.Fatal(err) - } - if len(packet) != 2*maxPacketSize { - t.Fatalf("unexpected packet length: expected %d, got %d", 2*maxPacketSize, len(packet)) - } - if packet[0] != 0x11 { - t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0]) - } - if packet[2*maxPacketSize-1] != 0x44 { - t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, packet[2*maxPacketSize-1]) - } - - // case 3: payload has a length larger maxPacketSize, which is not an exact - // multiple of it - data = data[:pkt2ofs+4+42] - data[pkt2ofs] = 0x2a - data[pkt2ofs+1] = 0x00 - data[pkt2ofs+2] = 0x00 - data[pkt2ofs+4+41] = 0x44 - - conn.data = data - conn.reads = 0 - conn.maxReads = 4 - mc.sequence = 0 - packet, err = mc.readPacket() - if err != nil { - t.Fatal(err) - } - if len(packet) != maxPacketSize+42 { - t.Fatalf("unexpected packet length: expected %d, got %d", maxPacketSize+42, len(packet)) - } - if packet[0] != 0x11 { - t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0]) - } - if packet[maxPacketSize+41] != 0x44 { - t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, packet[maxPacketSize+41]) - } + t.Run("case 1: payload has length maxPacketSize", func(t *testing.T) { + conn, mc := newRWMockConn(t, 0) + data := make([]byte, pkt2ofs+4) + + // 1st packet has maxPacketSize length and sequence id 0 + // ff ff ff 00 ... + data[0] = 0xff + data[1] = 0xff + data[2] = 0xff + + // mark the payload start and end of 1st packet so that we can check if the + // content was correctly appended + data[4] = 0x11 + data[maxPacketSize+3] = 0x22 + + // 2nd packet has payload length 0 and sequence id 1 + // 00 00 00 01 + data[pkt2ofs+3] = 0x01 + + go func() { + conn.Write(data) + }() + // TODO: check read operation count + + packet, err := mc.readPacket(context.Background()) + if err != nil { + t.Fatal(err) + } + if len(packet.data) != maxPacketSize { + t.Fatalf("unexpected packet length: expected %d, got %d", maxPacketSize, len(packet.data)) + } + if packet.data[0] != 0x11 { + t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet.data[0]) + } + if packet.data[maxPacketSize-1] != 0x22 { + t.Fatalf("unexpected payload end: expected %x, got %x", 0x22, packet.data[maxPacketSize-1]) + } + }) + + t.Run("case 2: payload has length which is a multiple of maxPacketSize", func(t *testing.T) { + conn, mc := newRWMockConn(t, 0) + data := make([]byte, maxPacketSize*2+4*3) + + // 1st packet has maxPacketSize length and sequence id 0 + // ff ff ff 00 ... + data[0] = 0xff + data[1] = 0xff + data[2] = 0xff + + // mark the payload start and end of 1st packet so that we can check if the + // content was correctly appended + data[4] = 0x11 + data[maxPacketSize+3] = 0x22 + + // 2nd packet now has maxPacketSize length + data[pkt2ofs] = 0xff + data[pkt2ofs+1] = 0xff + data[pkt2ofs+2] = 0xff + data[pkt2ofs+3] = 0x01 + + // mark the payload start and end of the 2nd packet + data[pkt2ofs+4] = 0x33 + data[pkt2ofs+maxPacketSize+3] = 0x44 + + // 3rd packet has payload length 0 and sequence id 2 + // 00 00 00 02 + data[pkt3ofs+3] = 0x02 + + go func() { + conn.Write(data) + }() + // TODO: check read operation count + + packet, err := mc.readPacket(context.Background()) + if err != nil { + t.Fatal(err) + } + if len(packet.data) != 2*maxPacketSize { + t.Fatalf("unexpected packet length: expected %d, got %d", 2*maxPacketSize, len(packet.data)) + } + if packet.data[0] != 0x11 { + t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet.data[0]) + } + if packet.data[2*maxPacketSize-1] != 0x44 { + t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, packet.data[2*maxPacketSize-1]) + } + }) + + t.Run("case 3: payload has a length larger maxPacketSize, which is not an exact multiple of it", func(t *testing.T) { + conn, mc := newRWMockConn(t, 0) + data := make([]byte, pkt2ofs+4+42) + + // 1st packet has maxPacketSize length and sequence id 0 + // ff ff ff 00 ... + data[0] = 0xff + data[1] = 0xff + data[2] = 0xff + + // mark the payload start and end of 1st packet so that we can check if the + // content was correctly appended + data[4] = 0x11 + data[maxPacketSize+3] = 0x22 + + // 2nd packet + data[pkt2ofs] = 0x2a + data[pkt2ofs+1] = 0x00 + data[pkt2ofs+2] = 0x00 + data[pkt2ofs+3] = 0x01 + data[pkt2ofs+4+41] = 0x44 + + go func() { + conn.Write(data) + }() + // TODO: check read operation count + + packet, err := mc.readPacket(context.Background()) + if err != nil { + t.Fatal(err) + } + if len(packet.data) != maxPacketSize+42 { + t.Fatalf("unexpected packet length: expected %d, got %d", maxPacketSize+42, len(packet.data)) + } + if packet.data[0] != 0x11 { + t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet.data[0]) + } + if packet.data[maxPacketSize+41] != 0x44 { + t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, packet.data[maxPacketSize+41]) + } + }) } func TestReadPacketFail(t *testing.T) { - conn := new(mockConn) - mc := &mysqlConn{ - buf: newBuffer(conn), - closech: make(chan struct{}), - cfg: NewConfig(), - } - - // illegal empty (stand-alone) packet - conn.data = []byte{0x00, 0x00, 0x00, 0x00} - conn.maxReads = 1 - _, err := mc.readPacket() - if err != ErrInvalidConn { - t.Errorf("expected ErrInvalidConn, got %v", err) - } - - // reset - conn.reads = 0 - mc.sequence = 0 - mc.buf = newBuffer(conn) + t.Run("illegal empty (stand-alone) packet", func(t *testing.T) { + conn, mc := newRWMockConn(t, 0) + go func() { + conn.Write([]byte{0x00, 0x00, 0x00, 0x00}) + }() + go func() { + io.Copy(io.Discard, conn) + }() + + _, err := mc.readPacket(context.Background()) + if err != ErrInvalidConn { + t.Errorf("expected ErrInvalidConn, got %v", err) + } + }) - // fail to read header - conn.closed = true - _, err = mc.readPacket() - if err != ErrInvalidConn { - t.Errorf("expected ErrInvalidConn, got %v", err) - } + t.Run("fail to read header", func(t *testing.T) { + conn, mc := newRWMockConn(t, 0) + go func() { + conn.Close() + }() - // reset - conn.closed = false - conn.reads = 0 - mc.sequence = 0 - mc.buf = newBuffer(conn) - - // fail to read body - conn.maxReads = 1 - _, err = mc.readPacket() - if err != ErrInvalidConn { - t.Errorf("expected ErrInvalidConn, got %v", err) - } + _, err := mc.readPacket(context.Background()) + if err != ErrInvalidConn { + t.Errorf("expected ErrInvalidConn, got %v", err) + } + }) + + t.Run("fail to read body", func(t *testing.T) { + conn, mc := newRWMockConn(t, 0) + go func() { + conn.Write([]byte{0x01, 0x00, 0x00, 0x00}) + conn.Close() + }() + go func() { + io.Copy(io.Discard, conn) + }() + + _, err := mc.readPacket(context.Background()) + if err != ErrInvalidConn { + t.Errorf("expected ErrInvalidConn, got %v", err) + } + }) } // https://github.com/go-sql-driver/mysql/pull/801 // not-NUL terminated plugin_name in init packet func TestRegression801(t *testing.T) { - conn := new(mockConn) - mc := &mysqlConn{ - buf: newBuffer(conn), - cfg: new(Config), - sequence: 42, - closech: make(chan struct{}), - } - - conn.data = []byte{72, 0, 0, 42, 10, 53, 46, 53, 46, 56, 0, 165, 0, 0, 0, - 60, 70, 63, 58, 68, 104, 34, 97, 0, 223, 247, 33, 2, 0, 15, 128, 21, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 98, 120, 114, 47, 85, 75, 109, 99, 51, 77, - 50, 64, 0, 109, 121, 115, 113, 108, 95, 110, 97, 116, 105, 118, 101, 95, - 112, 97, 115, 115, 119, 111, 114, 100} - conn.maxReads = 1 - - authData, pluginName, err := mc.readHandshakePacket() + conn, mc := newRWMockConn(t, 42) + + go func() { + conn.Write([]byte{72, 0, 0, 42, 10, 53, 46, 53, 46, 56, 0, 165, 0, 0, 0, + 60, 70, 63, 58, 68, 104, 34, 97, 0, 223, 247, 33, 2, 0, 15, 128, 21, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 98, 120, 114, 47, 85, 75, 109, 99, 51, 77, + 50, 64, 0, 109, 121, 115, 113, 108, 95, 110, 97, 116, 105, 118, 101, 95, + 112, 97, 115, 115, 119, 111, 114, 100}) + conn.Close() + }() + + authData, pluginName, err := mc.readHandshakePacket(context.Background()) if err != nil { t.Fatalf("got error: %v", err) } diff --git a/rows.go b/rows.go index 63d0ed2d5..49491ea34 100644 --- a/rows.go +++ b/rows.go @@ -9,6 +9,7 @@ package mysql import ( + "context" "database/sql/driver" "io" "math" @@ -22,9 +23,10 @@ type resultSet struct { } type mysqlRows struct { - mc *mysqlConn - rs resultSet - finish func() + mc *mysqlConn + ctx context.Context + rs resultSet + pkt *packet // current read packet } type binaryRows struct { @@ -98,11 +100,7 @@ func (rows *mysqlRows) ColumnTypeScanType(i int) reflect.Type { } func (rows *mysqlRows) Close() (err error) { - if f := rows.finish; f != nil { - f() - rows.finish = nil - } - + ctx := rows.ctx mc := rows.mc if mc == nil { return nil @@ -111,20 +109,16 @@ func (rows *mysqlRows) Close() (err error) { return err } - // flip the buffer for this connection if we need to drain it. - // note that for a successful query (i.e. one where rows.next() - // has been called until it returns false), `rows.mc` will be nil - // by the time the user calls `(*Rows).Close`, so we won't reach this - // see: https://github.com/golang/go/commit/651ddbdb5056ded455f47f9c494c67b389622a47 - mc.buf.flip() + rows.mc.connector.putPacket(rows.pkt) + rows.pkt = nil // Remove unread packets from stream if !rows.rs.done { - err = mc.readUntilEOF() + err = mc.readUntilEOF(ctx) } if err == nil { handleOk := mc.clearResult() - if err = handleOk.discardResults(); err != nil { + if err = handleOk.discardResults(ctx); err != nil { return err } } @@ -141,6 +135,8 @@ func (rows *mysqlRows) HasNextResultSet() (b bool) { } func (rows *mysqlRows) nextResultSet() (int, error) { + ctx := rows.ctx + if rows.mc == nil { return 0, io.EOF } @@ -150,7 +146,7 @@ func (rows *mysqlRows) nextResultSet() (int, error) { // Remove unread packets from stream if !rows.rs.done { - if err := rows.mc.readUntilEOF(); err != nil { + if err := rows.mc.readUntilEOF(ctx); err != nil { return 0, err } rows.rs.done = true @@ -163,7 +159,7 @@ func (rows *mysqlRows) nextResultSet() (int, error) { rows.rs = resultSet{} // rows.mc.affectedRows and rows.mc.insertIds accumulate on each call to // nextResultSet. - return rows.mc.resultUnchanged().readResultSetHeaderPacket() + return rows.mc.resultUnchanged().readResultSetHeaderPacket(ctx) } func (rows *mysqlRows) nextNotEmptyResultSet() (int, error) { @@ -182,12 +178,13 @@ func (rows *mysqlRows) nextNotEmptyResultSet() (int, error) { } func (rows *binaryRows) NextResultSet() error { + ctx := rows.ctx resLen, err := rows.nextNotEmptyResultSet() if err != nil { return err } - rows.rs.columns, err = rows.mc.readColumns(resLen) + rows.rs.columns, err = rows.mc.readColumns(ctx, resLen) return err } @@ -204,12 +201,13 @@ func (rows *binaryRows) Next(dest []driver.Value) error { } func (rows *textRows) NextResultSet() (err error) { + ctx := rows.ctx resLen, err := rows.nextNotEmptyResultSet() if err != nil { return err } - rows.rs.columns, err = rows.mc.readColumns(resLen) + rows.rs.columns, err = rows.mc.readColumns(ctx, resLen) return err } diff --git a/statement.go b/statement.go index 31e7799c4..97253c071 100644 --- a/statement.go +++ b/statement.go @@ -9,6 +9,7 @@ package mysql import ( + "context" "database/sql/driver" "encoding/json" "fmt" @@ -23,6 +24,8 @@ type mysqlStmt struct { } func (stmt *mysqlStmt) Close() error { + ctx := context.TODO() + if stmt.mc == nil || stmt.mc.closed.Load() { // driver.Stmt.Close can be called more than once, thus this function // has to be idempotent. @@ -31,7 +34,7 @@ func (stmt *mysqlStmt) Close() error { return driver.ErrBadConn } - err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id) + err := stmt.mc.writeCommandPacketUint32(ctx, comStmtClose, stmt.id) stmt.mc = nil return err } @@ -50,12 +53,14 @@ func (stmt *mysqlStmt) CheckNamedValue(nv *driver.NamedValue) (err error) { } func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { + ctx := context.Background() + if stmt.mc.closed.Load() { stmt.mc.cfg.Logger.Print(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command - err := stmt.writeExecutePacket(args) + err := stmt.writeExecutePacket(ctx, args) if err != nil { return nil, stmt.mc.markBadConn(err) } @@ -64,24 +69,24 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { handleOk := stmt.mc.clearResult() // Read Result - resLen, err := handleOk.readResultSetHeaderPacket() + resLen, err := handleOk.readResultSetHeaderPacket(ctx) if err != nil { return nil, err } if resLen > 0 { // Columns - if err = mc.readUntilEOF(); err != nil { + if err = mc.readUntilEOF(ctx); err != nil { return nil, err } // Rows - if err := mc.readUntilEOF(); err != nil { + if err := mc.readUntilEOF(ctx); err != nil { return nil, err } } - if err := handleOk.discardResults(); err != nil { + if err := handleOk.discardResults(ctx); err != nil { return nil, err } @@ -90,16 +95,16 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { } func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { - return stmt.query(args) + return stmt.query(context.Background(), args) } -func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { +func (stmt *mysqlStmt) query(ctx context.Context, args []driver.Value) (*binaryRows, error) { if stmt.mc.closed.Load() { stmt.mc.cfg.Logger.Print(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command - err := stmt.writeExecutePacket(args) + err := stmt.writeExecutePacket(ctx, args) if err != nil { return nil, stmt.mc.markBadConn(err) } @@ -108,7 +113,7 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { // Read Result handleOk := stmt.mc.clearResult() - resLen, err := handleOk.readResultSetHeaderPacket() + resLen, err := handleOk.readResultSetHeaderPacket(ctx) if err != nil { return nil, err } @@ -117,7 +122,8 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { if resLen > 0 { rows.mc = mc - rows.rs.columns, err = mc.readColumns(resLen) + rows.ctx = ctx + rows.rs.columns, err = mc.readColumns(ctx, resLen) } else { rows.rs.done = true diff --git a/transaction.go b/transaction.go index 4a4b61001..d1dc14d80 100644 --- a/transaction.go +++ b/transaction.go @@ -8,15 +8,18 @@ package mysql +import "context" + type mysqlTx struct { - mc *mysqlConn + ctx context.Context + mc *mysqlConn } func (tx *mysqlTx) Commit() (err error) { if tx.mc == nil || tx.mc.closed.Load() { return ErrInvalidConn } - err = tx.mc.exec("COMMIT") + err = tx.mc.exec(tx.ctx, "COMMIT") tx.mc = nil return } @@ -25,7 +28,7 @@ func (tx *mysqlTx) Rollback() (err error) { if tx.mc == nil || tx.mc.closed.Load() { return ErrInvalidConn } - err = tx.mc.exec("ROLLBACK") + err = tx.mc.exec(tx.ctx, "ROLLBACK") tx.mc = nil return } diff --git a/utils.go b/utils.go index a24197b93..f86924206 100644 --- a/utils.go +++ b/utils.go @@ -19,7 +19,6 @@ import ( "strconv" "strings" "sync" - "sync/atomic" "time" ) @@ -770,47 +769,6 @@ func escapeStringQuotes(buf []byte, v string) []byte { return buf[:pos] } -/****************************************************************************** -* Sync utils * -******************************************************************************/ - -// noCopy may be embedded into structs which must not be copied -// after the first use. -// -// See https://github.com/golang/go/issues/8005#issuecomment-190753527 -// for details. -type noCopy struct{} - -// Lock is a no-op used by -copylocks checker from `go vet`. -func (*noCopy) Lock() {} - -// Unlock is a no-op used by -copylocks checker from `go vet`. -// noCopy should implement sync.Locker from Go 1.11 -// https://github.com/golang/go/commit/c2eba53e7f80df21d51285879d51ab81bcfbf6bc -// https://github.com/golang/go/issues/26165 -func (*noCopy) Unlock() {} - -// atomicError is a wrapper for atomically accessed error values -type atomicError struct { - _ noCopy - value atomic.Value -} - -// Set sets the error value regardless of the previous value. -// The value must not be nil -func (ae *atomicError) Set(value error) { - ae.value.Store(value) -} - -// Value returns the current error value -func (ae *atomicError) Value() error { - if v := ae.value.Load(); v != nil { - // this will panic if the value doesn't implement the error interface - return v.(error) - } - return nil -} - func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { dargs := make([]driver.Value, len(named)) for n, param := range named { diff --git a/utils_test.go b/utils_test.go index 4e5fc3cb7..4067d8927 100644 --- a/utils_test.go +++ b/utils_test.go @@ -173,28 +173,6 @@ func TestEscapeQuotes(t *testing.T) { expect("foo\"bar", "foo\"bar") // not affected } -func TestAtomicError(t *testing.T) { - var ae atomicError - if ae.Value() != nil { - t.Fatal("Expected value to be nil") - } - - ae.Set(ErrMalformPkt) - if v := ae.Value(); v != ErrMalformPkt { - if v == nil { - t.Fatal("Value is still nil") - } - t.Fatal("Error did not match") - } - ae.Set(ErrPktSync) - if ae.Value() == ErrMalformPkt { - t.Fatal("Error still matches old error") - } - if v := ae.Value(); v != ErrPktSync { - t.Fatal("Error did not match") - } -} - func TestIsolationLevelMapping(t *testing.T) { data := []struct { level driver.IsolationLevel