Skip to content

Commit 6a2c0a1

Browse files
committed
Support for sending connection attributes
1 parent 2307b45 commit 6a2c0a1

File tree

7 files changed

+145
-1
lines changed

7 files changed

+145
-1
lines changed

AUTHORS

+1
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ Shuode Li <elemount at qq.com>
7272
Soroush Pour <me at soroushjp.com>
7373
Stan Putrya <root.vagner at gmail.com>
7474
Stanley Gunawan <gunawan.stanley at gmail.com>
75+
Vasily Fedoseyev <vasilyfedoseyev at gmail.com>
7576
Xiangyu Hu <xiangyu.hu at outlook.com>
7677
Xiaobing Jiang <s7v7nislands at gmail.com>
7778
Xiuming Chen <cc at cxm.cc>

README.md

+10
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,16 @@ SELECT u.id FROM users as u
204204

205205
will return `u.id` instead of just `id` if `columnsWithAlias=true`.
206206

207+
##### `connectionAttributes`
208+
209+
```
210+
Type: map
211+
Valid Values: comma-separated list of attribute:value pairs
212+
Default: empty
213+
```
214+
215+
Allows setting of connection attributes, for example `connectionAttributes=program_name:YourProgramName` will show `YourProgramName` in `Program` field of connections list of Mysql Workbench.
216+
207217
##### `interpolateParams`
208218

209219
```

driver_test.go

+39
Original file line numberDiff line numberDiff line change
@@ -2071,5 +2071,44 @@ func TestEmptyPassword(t *testing.T) {
20712071
if !strings.HasPrefix(err.Error(), "Error 1045") {
20722072
t.Fatal(err.Error())
20732073
}
2074+
}
2075+
}
2076+
2077+
func TestConnectionAttributes(t *testing.T) {
2078+
if !available {
2079+
t.Skipf("MySQL server not running on %s", netAddr)
2080+
}
2081+
2082+
db, err := sql.Open("mysql", dsn+"&connectionAttributes=program_name:GoTest,foo:bar")
2083+
if err != nil {
2084+
t.Fatalf("error connecting: %s", err.Error())
2085+
}
2086+
defer db.Close()
2087+
dbt := &DBTest{t, db}
2088+
2089+
rows, err := dbt.db.Query("SELECT program_name FROM sys.processlist WHERE db=?", dbname)
2090+
if err != nil {
2091+
dbt.Skip("server probably does not support program_name in sys.processlist")
2092+
}
2093+
2094+
if rows.Next() {
2095+
var str string
2096+
rows.Scan(&str)
2097+
if "GoTest" != str {
2098+
dbt.Errorf("GoTest != %s", str)
2099+
}
2100+
} else {
2101+
dbt.Error("no data")
2102+
}
2103+
2104+
rows = dbt.mustQuery("select attr_value from performance_schema.session_account_connect_attrs where processlist_id=CONNECTION_ID() and attr_name='foo'")
2105+
if rows.Next() {
2106+
var str string
2107+
rows.Scan(&str)
2108+
if "bar" != str {
2109+
dbt.Errorf("bar != %s", str)
2110+
}
2111+
} else {
2112+
dbt.Error("no data")
20742113
}
20752114
}

dsn.go

+43
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ type Config struct {
3939
Addr string // Network address (requires Net)
4040
DBName string // Database name
4141
Params map[string]string // Connection parameters
42+
Attributes map[string]string // Connection attributes
4243
Collation string // Connection collation
4344
Loc *time.Location // Location for time.Time values
4445
MaxAllowedPacket int // Max packet size allowed
@@ -308,6 +309,30 @@ func (cfg *Config) FormatDSN() string {
308309

309310
}
310311

312+
if len(cfg.Attributes) > 0 {
313+
// connectionAttributes=program_name:Login Server,other_name:other
314+
if hasParam {
315+
buf.WriteString("&connectionAttributes=")
316+
} else {
317+
hasParam = true
318+
buf.WriteString("?connectionAttributes=")
319+
}
320+
321+
var attr_names []string
322+
for attr_name := range cfg.Attributes {
323+
attr_names = append(attr_names, attr_name)
324+
}
325+
sort.Strings(attr_names)
326+
for index, attr_name := range attr_names {
327+
if index > 0 {
328+
buf.WriteByte(',')
329+
}
330+
buf.WriteString(attr_name)
331+
buf.WriteByte(':')
332+
buf.WriteString(url.QueryEscape(cfg.Attributes[attr_name]))
333+
}
334+
}
335+
311336
// other params
312337
if cfg.Params != nil {
313338
var params []string
@@ -588,6 +613,24 @@ func parseDSNParams(cfg *Config, params string) (err error) {
588613
if err != nil {
589614
return
590615
}
616+
case "connectionAttributes":
617+
if cfg.Attributes == nil {
618+
cfg.Attributes = make(map[string]string)
619+
}
620+
621+
var attributes string
622+
if attributes, err = url.QueryUnescape(value); err != nil {
623+
return
624+
}
625+
626+
// program_name:Name,foo:bar
627+
for _, attr_str := range strings.Split(attributes, ",") {
628+
attr := strings.SplitN(attr_str, ":", 2)
629+
if len(attr) != 2 {
630+
continue
631+
}
632+
cfg.Attributes[attr[0]] = attr[1]
633+
}
591634
default:
592635
// lazy init
593636
if cfg.Params == nil {

dsn_test.go

+17
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ var testDSNs = []struct {
7171
}, {
7272
"tcp(de:ad:be:ef::ca:fe)/dbname",
7373
&Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
74+
}, {
75+
"tcp(127.0.0.1)/dbname?connectionAttributes=program_name:SomeService",
76+
&Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Attributes: map[string]string{"program_name": "SomeService"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
7477
},
7578
}
7679

@@ -318,6 +321,20 @@ func TestParamsAreSorted(t *testing.T) {
318321
}
319322
}
320323

324+
func TestAttributesAreSorted(t *testing.T) {
325+
expected := "/dbname?connectionAttributes=p1:v1,p2:v2"
326+
cfg := NewConfig()
327+
cfg.DBName = "dbname"
328+
cfg.Attributes = map[string]string{
329+
"p2": "v2",
330+
"p1": "v1",
331+
}
332+
actual := cfg.FormatDSN()
333+
if actual != expected {
334+
t.Errorf("generic Config.Attributes were not sorted: want %#v, got %#v", expected, actual)
335+
}
336+
}
337+
321338
func BenchmarkParseDSN(b *testing.B) {
322339
b.ReportAllocs()
323340

packets.go

+29-1
Original file line numberDiff line numberDiff line change
@@ -202,10 +202,15 @@ func (mc *mysqlConn) readHandshakePacket() ([]byte, string, error) {
202202
if len(data) > pos {
203203
// character set [1 byte]
204204
// status flags [2 bytes]
205+
pos += 1 + 2
206+
205207
// capability flags (upper 2 bytes) [2 bytes]
208+
mc.flags |= clientFlag(uint32(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16)
209+
pos += 2
210+
206211
// length of auth-plugin-data [1 byte]
207212
// reserved (all [00]) [10 bytes]
208-
pos += 1 + 2 + 2 + 1 + 10
213+
pos += 1 + 10
209214

210215
// second part of the password cipher [mininum 13 bytes],
211216
// where len=MAX(13, length of auth-plugin-data - 8)
@@ -284,6 +289,24 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool,
284289
pktLen++
285290
}
286291

292+
connectAttrsBuf := make([]byte, 0, 100)
293+
if mc.flags&clientConnectAttrs != 0 {
294+
clientFlags |= clientConnectAttrs
295+
connectAttrsBuf = appendLengthEncodedString(connectAttrsBuf, []byte("_client_name"))
296+
connectAttrsBuf = appendLengthEncodedString(connectAttrsBuf, []byte("go-mysql-driver"))
297+
298+
for k, v := range mc.cfg.Attributes {
299+
if k == "_client_name" {
300+
// do not allow overwriting reserved values
301+
continue
302+
}
303+
connectAttrsBuf = appendLengthEncodedString(connectAttrsBuf, []byte(k))
304+
connectAttrsBuf = appendLengthEncodedString(connectAttrsBuf, []byte(v))
305+
}
306+
connectAttrsBuf = appendLengthEncodedString(make([]byte, 0, 100), connectAttrsBuf)
307+
pktLen += len(connectAttrsBuf)
308+
}
309+
287310
// To specify a db name
288311
if n := len(mc.cfg.DBName); n > 0 {
289312
clientFlags |= clientConnectWithDB
@@ -367,6 +390,11 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool,
367390

368391
pos += copy(data[pos:], plugin)
369392
data[pos] = 0x00
393+
pos++
394+
395+
if clientFlags&clientConnectAttrs != 0 {
396+
pos += copy(data[pos:], connectAttrsBuf)
397+
}
370398

371399
// Send Auth packet
372400
return mc.writePacket(data)

utils.go

+6
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,12 @@ func skipLengthEncodedString(b []byte) (int, error) {
464464
return n, io.EOF
465465
}
466466

467+
// encodes a bytes slice with prepended length-encoded size and appends it to the given bytes slice
468+
func appendLengthEncodedString(b []byte, str []byte) []byte {
469+
b = appendLengthEncodedInteger(b, uint64(len(str)))
470+
return append(b, str...)
471+
}
472+
467473
// returns the number read, whether the value is NULL and the number of bytes read
468474
func readLengthEncodedInteger(b []byte) (uint64, bool, int) {
469475
// See issue #349

0 commit comments

Comments
 (0)