diff --git a/.travis.yml b/.travis.yml index 47dd289a0..7a4113dbc 100644 --- a/.travis.yml +++ b/.travis.yml @@ -15,6 +15,7 @@ before_script: - sudo service mysql restart - .travis/wait_mysql.sh - mysql -e 'create database gotest;' + - mysql -e 'show databases;' matrix: include: diff --git a/AUTHORS b/AUTHORS index 73ff68fbc..7124f8bd4 100644 --- a/AUTHORS +++ b/AUTHORS @@ -72,6 +72,7 @@ Shuode Li Soroush Pour Stan Putrya Stanley Gunawan +Vasily Fedoseyev Xiangyu Hu Xiaobing Jiang Xiuming Chen diff --git a/README.md b/README.md index 2e9b07eeb..9c712f4fe 100644 --- a/README.md +++ b/README.md @@ -204,6 +204,16 @@ SELECT u.id FROM users as u will return `u.id` instead of just `id` if `columnsWithAlias=true`. +##### `connectAttrs` + +``` +Type: map +Valid Values: comma-separated list of attribute:value pairs +Default: empty +``` + +Allows setting of connection attributes, for example `connectAttrs=program_name:YourProgramName` will show `YourProgramName` in `Program` field of connections list of Mysql Workbench, if your server supports it (requires `performance_schema` to be supported and enabled). + ##### `interpolateParams` ``` @@ -487,4 +497,3 @@ Please read the [MPL 2.0 FAQ](https://www.mozilla.org/en-US/MPL/2.0/FAQ/) if you You can read the full terms here: [LICENSE](https://raw.github.com/go-sql-driver/mysql/master/LICENSE). ![Go Gopher and MySQL Dolphin](https://raw.github.com/wiki/go-sql-driver/mysql/go-mysql-driver_m.jpg "Golang Gopher transporting the MySQL Dolphin in a wheelbarrow") - diff --git a/auth_test.go b/auth_test.go index 407363be4..040a016f7 100644 --- a/auth_test.go +++ b/auth_test.go @@ -13,6 +13,7 @@ import ( "crypto/rsa" "crypto/tls" "crypto/x509" + "encoding/hex" "encoding/pem" "fmt" "testing" @@ -363,13 +364,16 @@ func TestAuthFastCleartextPassword(t *testing.T) { } // check written auth response - authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) authRespEnd := authRespStart + 1 + len(authResp) - writtenAuthRespLen := conn.written[authRespStart] writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] - expectedAuthResp := []byte{115, 101, 99, 114, 101, 116} - if writtenAuthRespLen != 6 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { - t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + expectedAuthResp := []byte("secret") + if !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response:\n%s\nExpected:\n%s\n", + hex.Dump(writtenAuthResp), hex.Dump(expectedAuthResp)) + } + if conn.written[authRespEnd] != 0 { + t.Fatalf("Expected null-terminated") } conn.written = nil @@ -683,14 +687,18 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) { } // check written auth response - authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 - authRespEnd := authRespStart + 1 + len(authResp) + 1 - writtenAuthRespLen := conn.written[authRespStart] + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + authRespEnd := authRespStart + 1 + len(authResp) writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] - expectedAuthResp := []byte{115, 101, 99, 114, 101, 116, 0} - if writtenAuthRespLen != 6 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { - t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + expectedAuthResp := []byte("secret") + if !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response:\n%s\nExpected:\n%s\n", + hex.Dump(writtenAuthResp), hex.Dump(expectedAuthResp)) } + if conn.written[authRespEnd] != 0 { + t.Fatalf("Expected null-terminated") + } + conn.written = nil // auth response (OK) diff --git a/const.go b/const.go index b1e6b85ef..6118a5e07 100644 --- a/const.go +++ b/const.go @@ -46,7 +46,7 @@ const ( clientIgnoreSIGPIPE clientTransactions clientReserved - clientSecureConn + clientSecureConn // reserved2 in 8.0 clientMultiStatements clientMultiResults clientPSMultiResults @@ -56,6 +56,8 @@ const ( clientCanHandleExpiredPasswords clientSessionTrack clientDeprecateEOF + clientSslVerifyServerCert clientFlag = 1 << 30 + clientRememberOptions clientFlag = 1 << 31 ) const ( diff --git a/driver_test.go b/driver_test.go index f2bf344e5..e45441efa 100644 --- a/driver_test.go +++ b/driver_test.go @@ -2077,6 +2077,56 @@ func TestEmptyPassword(t *testing.T) { } } +func TestConnectAttrs(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + db, err := sql.Open("mysql", dsn+"&connectAttrs=program_name:GoTest,foo:bar") + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + dbt := &DBTest{t, db} + + rows := dbt.mustQuery("SHOW VARIABLES LIKE 'performance_schema'") + if rows.Next() { + var var_name, value string + rows.Scan(&var_name, &value) + if value != "ON" { + t.Skip("performance_schema is disabled") + } + } else { + t.Skip("no performance_schema variable in mysql") + } + + rows, err = dbt.db.Query("select attr_value from performance_schema.session_connect_attrs where processlist_id=CONNECTION_ID() and attr_name='program_name'") + if err != nil { + dbt.Skipf("server probably does not support performance_schema.session_connect_attrs: %s", err) + } + + if rows.Next() { + var str string + rows.Scan(&str) + if "GoTest" != str { + dbt.Errorf("GoTest != %s", str) + } + } else { + dbt.Error("no data for program_name") + } + + rows = dbt.mustQuery("select attr_value from performance_schema.session_connect_attrs where processlist_id=CONNECTION_ID() and attr_name='foo'") + if rows.Next() { + var str string + rows.Scan(&str) + if "bar" != str { + dbt.Errorf("bar != %s", str) + } + } else { + dbt.Error("no data for custom attribute") + } +} + // static interface implementation checks of mysqlConn var ( _ driver.ConnBeginTx = &mysqlConn{} diff --git a/dsn.go b/dsn.go index be014babe..d0098ff71 100644 --- a/dsn.go +++ b/dsn.go @@ -39,6 +39,7 @@ type Config struct { Addr string // Network address (requires Net) DBName string // Database name Params map[string]string // Connection parameters + ConnectAttrs map[string]string // Connection attributes Collation string // Connection collation Loc *time.Location // Location for time.Time values MaxAllowedPacket int // Max packet size allowed @@ -308,6 +309,30 @@ func (cfg *Config) FormatDSN() string { } + if len(cfg.ConnectAttrs) > 0 { + // connectAttrs=program_name:Login Server,other_name:other + if hasParam { + buf.WriteString("&connectAttrs=") + } else { + hasParam = true + buf.WriteString("?connectAttrs=") + } + + var attr_names []string + for attr_name := range cfg.ConnectAttrs { + attr_names = append(attr_names, attr_name) + } + sort.Strings(attr_names) + for index, attr_name := range attr_names { + if index > 0 { + buf.WriteByte(',') + } + buf.WriteString(attr_name) + buf.WriteByte(':') + buf.WriteString(url.QueryEscape(cfg.ConnectAttrs[attr_name])) + } + } + // other params if cfg.Params != nil { var params []string @@ -588,6 +613,24 @@ func parseDSNParams(cfg *Config, params string) (err error) { if err != nil { return } + case "connectAttrs": + if cfg.ConnectAttrs == nil { + cfg.ConnectAttrs = make(map[string]string) + } + + var ConnectAttrs string + if ConnectAttrs, err = url.QueryUnescape(value); err != nil { + return + } + + // program_name:Name,foo:bar + for _, attr_str := range strings.Split(ConnectAttrs, ",") { + attr := strings.SplitN(attr_str, ":", 2) + if len(attr) != 2 { + continue + } + cfg.ConnectAttrs[attr[0]] = attr[1] + } default: // lazy init if cfg.Params == nil { diff --git a/dsn_test.go b/dsn_test.go index 1cd095496..dae4c24e0 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -71,6 +71,9 @@ var testDSNs = []struct { }, { "tcp(de:ad:be:ef::ca:fe)/dbname", &Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, +}, { + "tcp(127.0.0.1)/dbname?connectAttrs=program_name:SomeService", + &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", ConnectAttrs: map[string]string{"program_name": "SomeService"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, }, } @@ -318,6 +321,20 @@ func TestParamsAreSorted(t *testing.T) { } } +func TestAttributesAreSorted(t *testing.T) { + expected := "/dbname?connectAttrs=p1:v1,p2:v2" + cfg := NewConfig() + cfg.DBName = "dbname" + cfg.ConnectAttrs = map[string]string{ + "p2": "v2", + "p1": "v1", + } + actual := cfg.FormatDSN() + if actual != expected { + t.Errorf("generic Config.ConnectAttrs were not sorted: want %#v, got %#v", expected, actual) + } +} + func BenchmarkParseDSN(b *testing.B) { b.ReportAllocs() diff --git a/packets.go b/packets.go index f99934e73..13f85fe89 100644 --- a/packets.go +++ b/packets.go @@ -202,10 +202,15 @@ func (mc *mysqlConn) readHandshakePacket() ([]byte, string, error) { if len(data) > pos { // character set [1 byte] // status flags [2 bytes] + pos += 1 + 2 + // capability flags (upper 2 bytes) [2 bytes] + mc.flags |= clientFlag(uint32(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16) + pos += 2 + // length of auth-plugin-data [1 byte] // reserved (all [00]) [10 bytes] - pos += 1 + 2 + 2 + 1 + 10 + pos += 1 + 10 // second part of the password cipher [mininum 13 bytes], // where len=MAX(13, length of auth-plugin-data - 8) @@ -246,10 +251,9 @@ func (mc *mysqlConn) readHandshakePacket() ([]byte, string, error) { // Client Authentication Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse -func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, plugin string) error { +func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, insecureAuth bool, plugin string) error { // Adjust client flags based on server support clientFlags := clientProtocol41 | - clientSecureConn | clientLongPassword | clientTransactions | clientLocalFiles | @@ -270,22 +274,44 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, clientFlags |= clientMultiStatements } + if !insecureAuth { + clientFlags |= clientSecureConn + } + // encode length of the auth plugin data var authRespLEIBuf [9]byte authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(authResp))) - if len(authRespLEI) > 1 { + if len(authRespLEI) > 1 && clientFlags&clientSecureConn != 0 { // if the length can not be written in 1 byte, it must be written as a // length encoded integer clientFlags |= clientPluginAuthLenEncClientData } pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1 - if addNUL { + if clientFlags&clientSecureConn == 0 || clientFlags&clientPluginAuthLenEncClientData == 0 { pktLen++ } + connectAttrsBuf := make([]byte, 0, 100) + if mc.flags&clientConnectAttrs != 0 { + clientFlags |= clientConnectAttrs + connectAttrsBuf = appendLengthEncodedString(connectAttrsBuf, []byte("_client_name")) + connectAttrsBuf = appendLengthEncodedString(connectAttrsBuf, []byte("go-mysql-driver")) + + for k, v := range mc.cfg.ConnectAttrs { + if k == "_client_name" { + // do not allow overwriting reserved values + continue + } + connectAttrsBuf = appendLengthEncodedString(connectAttrsBuf, []byte(k)) + connectAttrsBuf = appendLengthEncodedString(connectAttrsBuf, []byte(v)) + } + connectAttrsBuf = appendLengthEncodedString(make([]byte, 0, 100), connectAttrsBuf) + pktLen += len(connectAttrsBuf) + } + // To specify a db name - if n := len(mc.cfg.DBName); n > 0 { + if n := len(mc.cfg.DBName); mc.flags&clientConnectWithDB != 0 && n > 0 { clientFlags |= clientConnectWithDB pktLen += n + 1 } @@ -350,23 +376,39 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, data[pos] = 0x00 pos++ - // Auth Data [length encoded integer] - pos += copy(data[pos:], authRespLEI) + // Auth Data [length encoded integer + data] if clientPluginAuthLenEncClientData + // clientSecureConn => 1 byte len + data + // else null-terminated + if clientFlags&clientPluginAuthLenEncClientData != 0 { + pos += copy(data[pos:], authRespLEI) + } else if clientFlags&clientSecureConn != 0 { + data[pos] = uint8(len(authResp)) + pos++ + } pos += copy(data[pos:], authResp) - if addNUL { + if clientFlags&clientSecureConn == 0 && clientFlags&clientPluginAuthLenEncClientData == 0 { data[pos] = 0x00 pos++ } // Databasename [null terminated string] - if len(mc.cfg.DBName) > 0 { + if clientFlags&clientConnectWithDB != 0 { pos += copy(data[pos:], mc.cfg.DBName) data[pos] = 0x00 pos++ } - pos += copy(data[pos:], plugin) - data[pos] = 0x00 + // auth plugin name [null terminated string] + if clientFlags&clientPluginAuth != 0 { + pos += copy(data[pos:], plugin) + data[pos] = 0x00 + pos++ + } + + // connection attributes [lenenc-int total + lenenc-str key-value pairs] + if clientFlags&clientConnectAttrs != 0 { + pos += copy(data[pos:], connectAttrsBuf) + } // Send Auth packet return mc.writePacket(data) diff --git a/utils.go b/utils.go index cb3650bb9..bb1be6f1a 100644 --- a/utils.go +++ b/utils.go @@ -466,6 +466,12 @@ func skipLengthEncodedString(b []byte) (int, error) { return n, io.EOF } +// encodes a bytes slice with prepended length-encoded size and appends it to the given bytes slice +func appendLengthEncodedString(b []byte, str []byte) []byte { + b = appendLengthEncodedInteger(b, uint64(len(str))) + return append(b, str...) +} + // returns the number read, whether the value is NULL and the number of bytes read func readLengthEncodedInteger(b []byte) (uint64, bool, int) { // See issue #349