Skip to content

Support for sending connection attributes #737

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ before_script:
- sudo service mysql restart
- .travis/wait_mysql.sh
- mysql -e 'create database gotest;'
- mysql -e 'show databases;'

matrix:
include:
Expand Down
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ Shuode Li <elemount at qq.com>
Soroush Pour <me at soroushjp.com>
Stan Putrya <root.vagner at gmail.com>
Stanley Gunawan <gunawan.stanley at gmail.com>
Vasily Fedoseyev <vasilyfedoseyev at gmail.com>
Xiangyu Hu <xiangyu.hu at outlook.com>
Xiaobing Jiang <s7v7nislands at gmail.com>
Xiuming Chen <cc at cxm.cc>
Expand Down
11 changes: 10 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,16 @@ SELECT u.id FROM users as u

will return `u.id` instead of just `id` if `columnsWithAlias=true`.

##### `connectAttrs`

```
Type: map
Valid Values: comma-separated list of attribute:value pairs
Default: empty
```

Allows setting of connection attributes, for example `connectAttrs=program_name:YourProgramName` will show `YourProgramName` in `Program` field of connections list of Mysql Workbench, if your server supports it (requires `performance_schema` to be supported and enabled).

##### `interpolateParams`

```
Expand Down Expand Up @@ -487,4 +497,3 @@ Please read the [MPL 2.0 FAQ](https://www.mozilla.org/en-US/MPL/2.0/FAQ/) if you
You can read the full terms here: [LICENSE](https://raw.github.com/go-sql-driver/mysql/master/LICENSE).

![Go Gopher and MySQL Dolphin](https://raw.github.com/wiki/go-sql-driver/mysql/go-mysql-driver_m.jpg "Golang Gopher transporting the MySQL Dolphin in a wheelbarrow")

30 changes: 19 additions & 11 deletions auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/hex"
"encoding/pem"
"fmt"
"testing"
Expand Down Expand Up @@ -363,13 +364,16 @@ func TestAuthFastCleartextPassword(t *testing.T) {
}

// check written auth response
authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1
authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User)
authRespEnd := authRespStart + 1 + len(authResp)
writtenAuthRespLen := conn.written[authRespStart]
writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
expectedAuthResp := []byte{115, 101, 99, 114, 101, 116}
if writtenAuthRespLen != 6 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp)
expectedAuthResp := []byte("secret")
if !bytes.Equal(writtenAuthResp, expectedAuthResp) {
t.Fatalf("unexpected written auth response:\n%s\nExpected:\n%s\n",
hex.Dump(writtenAuthResp), hex.Dump(expectedAuthResp))
}
if conn.written[authRespEnd] != 0 {
t.Fatalf("Expected null-terminated")
}
conn.written = nil

Expand Down Expand Up @@ -683,14 +687,18 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) {
}

// check written auth response
authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1
authRespEnd := authRespStart + 1 + len(authResp) + 1
writtenAuthRespLen := conn.written[authRespStart]
authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User)
authRespEnd := authRespStart + 1 + len(authResp)
writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
expectedAuthResp := []byte{115, 101, 99, 114, 101, 116, 0}
if writtenAuthRespLen != 6 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp)
expectedAuthResp := []byte("secret")
if !bytes.Equal(writtenAuthResp, expectedAuthResp) {
t.Fatalf("unexpected written auth response:\n%s\nExpected:\n%s\n",
hex.Dump(writtenAuthResp), hex.Dump(expectedAuthResp))
}
if conn.written[authRespEnd] != 0 {
t.Fatalf("Expected null-terminated")
}

conn.written = nil

// auth response (OK)
Expand Down
4 changes: 3 additions & 1 deletion const.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ const (
clientIgnoreSIGPIPE
clientTransactions
clientReserved
clientSecureConn
clientSecureConn // reserved2 in 8.0
clientMultiStatements
clientMultiResults
clientPSMultiResults
Expand All @@ -56,6 +56,8 @@ const (
clientCanHandleExpiredPasswords
clientSessionTrack
clientDeprecateEOF
clientSslVerifyServerCert clientFlag = 1 << 30
clientRememberOptions clientFlag = 1 << 31
)

const (
Expand Down
50 changes: 50 additions & 0 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2077,6 +2077,56 @@ func TestEmptyPassword(t *testing.T) {
}
}

func TestConnectAttrs(t *testing.T) {
if !available {
t.Skipf("MySQL server not running on %s", netAddr)
}

db, err := sql.Open("mysql", dsn+"&connectAttrs=program_name:GoTest,foo:bar")
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
}
defer db.Close()
dbt := &DBTest{t, db}

rows := dbt.mustQuery("SHOW VARIABLES LIKE 'performance_schema'")
if rows.Next() {
var var_name, value string
rows.Scan(&var_name, &value)
if value != "ON" {
t.Skip("performance_schema is disabled")
}
} else {
t.Skip("no performance_schema variable in mysql")
}

rows, err = dbt.db.Query("select attr_value from performance_schema.session_connect_attrs where processlist_id=CONNECTION_ID() and attr_name='program_name'")
if err != nil {
dbt.Skipf("server probably does not support performance_schema.session_connect_attrs: %s", err)
}

if rows.Next() {
var str string
rows.Scan(&str)
if "GoTest" != str {
dbt.Errorf("GoTest != %s", str)
}
} else {
dbt.Error("no data for program_name")
}

rows = dbt.mustQuery("select attr_value from performance_schema.session_connect_attrs where processlist_id=CONNECTION_ID() and attr_name='foo'")
if rows.Next() {
var str string
rows.Scan(&str)
if "bar" != str {
dbt.Errorf("bar != %s", str)
}
} else {
dbt.Error("no data for custom attribute")
}
}

// static interface implementation checks of mysqlConn
var (
_ driver.ConnBeginTx = &mysqlConn{}
Expand Down
43 changes: 43 additions & 0 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ type Config struct {
Addr string // Network address (requires Net)
DBName string // Database name
Params map[string]string // Connection parameters
ConnectAttrs map[string]string // Connection attributes
Collation string // Connection collation
Loc *time.Location // Location for time.Time values
MaxAllowedPacket int // Max packet size allowed
Expand Down Expand Up @@ -308,6 +309,30 @@ func (cfg *Config) FormatDSN() string {

}

if len(cfg.ConnectAttrs) > 0 {
// connectAttrs=program_name:Login Server,other_name:other
if hasParam {
buf.WriteString("&connectAttrs=")
} else {
hasParam = true
buf.WriteString("?connectAttrs=")
}

var attr_names []string
for attr_name := range cfg.ConnectAttrs {
attr_names = append(attr_names, attr_name)
}
sort.Strings(attr_names)
for index, attr_name := range attr_names {
if index > 0 {
buf.WriteByte(',')
}
buf.WriteString(attr_name)
buf.WriteByte(':')
buf.WriteString(url.QueryEscape(cfg.ConnectAttrs[attr_name]))
}
}

// other params
if cfg.Params != nil {
var params []string
Expand Down Expand Up @@ -588,6 +613,24 @@ func parseDSNParams(cfg *Config, params string) (err error) {
if err != nil {
return
}
case "connectAttrs":
if cfg.ConnectAttrs == nil {
cfg.ConnectAttrs = make(map[string]string)
}

var ConnectAttrs string
if ConnectAttrs, err = url.QueryUnescape(value); err != nil {
return
}

// program_name:Name,foo:bar
for _, attr_str := range strings.Split(ConnectAttrs, ",") {
attr := strings.SplitN(attr_str, ":", 2)
if len(attr) != 2 {
continue
}
cfg.ConnectAttrs[attr[0]] = attr[1]
}
default:
// lazy init
if cfg.Params == nil {
Expand Down
17 changes: 17 additions & 0 deletions dsn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ var testDSNs = []struct {
}, {
"tcp(de:ad:be:ef::ca:fe)/dbname",
&Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
}, {
"tcp(127.0.0.1)/dbname?connectAttrs=program_name:SomeService",
&Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", ConnectAttrs: map[string]string{"program_name": "SomeService"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
},
}

Expand Down Expand Up @@ -318,6 +321,20 @@ func TestParamsAreSorted(t *testing.T) {
}
}

func TestAttributesAreSorted(t *testing.T) {
expected := "/dbname?connectAttrs=p1:v1,p2:v2"
cfg := NewConfig()
cfg.DBName = "dbname"
cfg.ConnectAttrs = map[string]string{
"p2": "v2",
"p1": "v1",
}
actual := cfg.FormatDSN()
if actual != expected {
t.Errorf("generic Config.ConnectAttrs were not sorted: want %#v, got %#v", expected, actual)
}
}

func BenchmarkParseDSN(b *testing.B) {
b.ReportAllocs()

Expand Down
66 changes: 54 additions & 12 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,10 +202,15 @@ func (mc *mysqlConn) readHandshakePacket() ([]byte, string, error) {
if len(data) > pos {
// character set [1 byte]
// status flags [2 bytes]
pos += 1 + 2

// capability flags (upper 2 bytes) [2 bytes]
mc.flags |= clientFlag(uint32(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16)
pos += 2

// length of auth-plugin-data [1 byte]
// reserved (all [00]) [10 bytes]
pos += 1 + 2 + 2 + 1 + 10
pos += 1 + 10

// second part of the password cipher [mininum 13 bytes],
// where len=MAX(13, length of auth-plugin-data - 8)
Expand Down Expand Up @@ -246,10 +251,9 @@ func (mc *mysqlConn) readHandshakePacket() ([]byte, string, error) {

// Client Authentication Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, plugin string) error {
func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, insecureAuth bool, plugin string) error {
// Adjust client flags based on server support
clientFlags := clientProtocol41 |
clientSecureConn |
clientLongPassword |
clientTransactions |
clientLocalFiles |
Expand All @@ -270,22 +274,44 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool,
clientFlags |= clientMultiStatements
}

if !insecureAuth {
clientFlags |= clientSecureConn
}

// encode length of the auth plugin data
var authRespLEIBuf [9]byte
authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(authResp)))
if len(authRespLEI) > 1 {
if len(authRespLEI) > 1 && clientFlags&clientSecureConn != 0 {
// if the length can not be written in 1 byte, it must be written as a
// length encoded integer
clientFlags |= clientPluginAuthLenEncClientData
}

pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1
if addNUL {
if clientFlags&clientSecureConn == 0 || clientFlags&clientPluginAuthLenEncClientData == 0 {
pktLen++
}

connectAttrsBuf := make([]byte, 0, 100)
if mc.flags&clientConnectAttrs != 0 {
clientFlags |= clientConnectAttrs
connectAttrsBuf = appendLengthEncodedString(connectAttrsBuf, []byte("_client_name"))
connectAttrsBuf = appendLengthEncodedString(connectAttrsBuf, []byte("go-mysql-driver"))

for k, v := range mc.cfg.ConnectAttrs {
if k == "_client_name" {
// do not allow overwriting reserved values
continue
}
connectAttrsBuf = appendLengthEncodedString(connectAttrsBuf, []byte(k))
connectAttrsBuf = appendLengthEncodedString(connectAttrsBuf, []byte(v))
}
connectAttrsBuf = appendLengthEncodedString(make([]byte, 0, 100), connectAttrsBuf)
pktLen += len(connectAttrsBuf)
}

// To specify a db name
if n := len(mc.cfg.DBName); n > 0 {
if n := len(mc.cfg.DBName); mc.flags&clientConnectWithDB != 0 && n > 0 {
clientFlags |= clientConnectWithDB
pktLen += n + 1
}
Expand Down Expand Up @@ -350,23 +376,39 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool,
data[pos] = 0x00
pos++

// Auth Data [length encoded integer]
pos += copy(data[pos:], authRespLEI)
// Auth Data [length encoded integer + data] if clientPluginAuthLenEncClientData
// clientSecureConn => 1 byte len + data
// else null-terminated
if clientFlags&clientPluginAuthLenEncClientData != 0 {
pos += copy(data[pos:], authRespLEI)
} else if clientFlags&clientSecureConn != 0 {
data[pos] = uint8(len(authResp))
pos++
}
pos += copy(data[pos:], authResp)
if addNUL {
if clientFlags&clientSecureConn == 0 && clientFlags&clientPluginAuthLenEncClientData == 0 {
data[pos] = 0x00
pos++
}

// Databasename [null terminated string]
if len(mc.cfg.DBName) > 0 {
if clientFlags&clientConnectWithDB != 0 {
pos += copy(data[pos:], mc.cfg.DBName)
data[pos] = 0x00
pos++
}

pos += copy(data[pos:], plugin)
data[pos] = 0x00
// auth plugin name [null terminated string]
if clientFlags&clientPluginAuth != 0 {
pos += copy(data[pos:], plugin)
data[pos] = 0x00
pos++
}

// connection attributes [lenenc-int total + lenenc-str key-value pairs]
if clientFlags&clientConnectAttrs != 0 {
pos += copy(data[pos:], connectAttrsBuf)
}

// Send Auth packet
return mc.writePacket(data)
Expand Down
6 changes: 6 additions & 0 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,12 @@ func skipLengthEncodedString(b []byte) (int, error) {
return n, io.EOF
}

// encodes a bytes slice with prepended length-encoded size and appends it to the given bytes slice
func appendLengthEncodedString(b []byte, str []byte) []byte {
b = appendLengthEncodedInteger(b, uint64(len(str)))
return append(b, str...)
}

// returns the number read, whether the value is NULL and the number of bytes read
func readLengthEncodedInteger(b []byte) (uint64, bool, int) {
// See issue #349
Expand Down