Skip to content

Commit d3e4fe6

Browse files
methanebheni
andauthored
Use PathEscape for dbname in DSN. (#1432)
Support for slashes in database names via url escape codes. On the other hand, '%' in DSN is now treated as percent-encoding. Co-authored-by: Brian Hendriks <[email protected]>
1 parent 924f833 commit d3e4fe6

File tree

4 files changed

+51
-31
lines changed

4 files changed

+51
-31
lines changed

Diff for: AUTHORS

+2
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ Xuehong Chan <chanxuehong at gmail.com>
110110
Zhenye Xie <xiezhenye at gmail.com>
111111
Zhixin Wen <john.wenzhixin at gmail.com>
112112
Ziheng Lyu <zihenglv at gmail.com>
113+
Brian Hendriks <brian at dolthub.com>
113114

114115
# Organizations
115116

@@ -127,3 +128,4 @@ Percona LLC
127128
Pivotal Inc.
128129
Stripe Inc.
129130
Zendesk Inc.
131+
Dolthub Inc.

Diff for: README.md

+6
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,12 @@ This has the same effect as an empty DSN string:
114114
115115
```
116116

117+
`dbname` is escaped by [PathEscape()]()https://pkg.go.dev/net/url#PathEscape) since v1.8.0. If your database name is `dbname/withslash`, it becomes:
118+
119+
```
120+
/dbname%2Fwithslash
121+
```
122+
117123
Alternatively, [Config.FormatDSN](https://godoc.org/github.com/go-sql-driver/mysql#Config.FormatDSN) can be used to create a DSN string by filling a struct.
118124

119125
#### Password

Diff for: dsn.go

+6-2
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ func (cfg *Config) FormatDSN() string {
203203

204204
// /dbname
205205
buf.WriteByte('/')
206-
buf.WriteString(cfg.DBName)
206+
buf.WriteString(url.PathEscape(cfg.DBName))
207207

208208
// [?param1=value1&...&paramN=valueN]
209209
hasParam := false
@@ -365,7 +365,11 @@ func ParseDSN(dsn string) (cfg *Config, err error) {
365365
break
366366
}
367367
}
368-
cfg.DBName = dsn[i+1 : j]
368+
369+
dbname := dsn[i+1 : j]
370+
if cfg.DBName, err = url.PathUnescape(dbname); err != nil {
371+
return nil, fmt.Errorf("invalid dbname %q: %w", dbname, err)
372+
}
369373

370374
break
371375
}

Diff for: dsn_test.go

+37-29
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ var testDSNs = []struct {
5050
}, {
5151
"/dbname",
5252
&Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true},
53+
}, {
54+
"/dbname%2Fwithslash",
55+
&Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname/withslash", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true},
5356
}, {
5457
"@/",
5558
&Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true},
@@ -76,17 +79,20 @@ var testDSNs = []struct {
7679

7780
func TestDSNParser(t *testing.T) {
7881
for i, tst := range testDSNs {
79-
cfg, err := ParseDSN(tst.in)
80-
if err != nil {
81-
t.Error(err.Error())
82-
}
82+
t.Run(tst.in, func(t *testing.T) {
83+
cfg, err := ParseDSN(tst.in)
84+
if err != nil {
85+
t.Error(err.Error())
86+
return
87+
}
8388

84-
// pointer not static
85-
cfg.TLS = nil
89+
// pointer not static
90+
cfg.TLS = nil
8691

87-
if !reflect.DeepEqual(cfg, tst.out) {
88-
t.Errorf("%d. ParseDSN(%q) mismatch:\ngot %+v\nwant %+v", i, tst.in, cfg, tst.out)
89-
}
92+
if !reflect.DeepEqual(cfg, tst.out) {
93+
t.Errorf("%d. ParseDSN(%q) mismatch:\ngot %+v\nwant %+v", i, tst.in, cfg, tst.out)
94+
}
95+
})
9096
}
9197
}
9298

@@ -113,27 +119,29 @@ func TestDSNParserInvalid(t *testing.T) {
113119

114120
func TestDSNReformat(t *testing.T) {
115121
for i, tst := range testDSNs {
116-
dsn1 := tst.in
117-
cfg1, err := ParseDSN(dsn1)
118-
if err != nil {
119-
t.Error(err.Error())
120-
continue
121-
}
122-
cfg1.TLS = nil // pointer not static
123-
res1 := fmt.Sprintf("%+v", cfg1)
124-
125-
dsn2 := cfg1.FormatDSN()
126-
cfg2, err := ParseDSN(dsn2)
127-
if err != nil {
128-
t.Error(err.Error())
129-
continue
130-
}
131-
cfg2.TLS = nil // pointer not static
132-
res2 := fmt.Sprintf("%+v", cfg2)
122+
t.Run(tst.in, func(t *testing.T) {
123+
dsn1 := tst.in
124+
cfg1, err := ParseDSN(dsn1)
125+
if err != nil {
126+
t.Error(err.Error())
127+
return
128+
}
129+
cfg1.TLS = nil // pointer not static
130+
res1 := fmt.Sprintf("%+v", cfg1)
133131

134-
if res1 != res2 {
135-
t.Errorf("%d. %q does not match %q", i, res2, res1)
136-
}
132+
dsn2 := cfg1.FormatDSN()
133+
cfg2, err := ParseDSN(dsn2)
134+
if err != nil {
135+
t.Error(err.Error())
136+
return
137+
}
138+
cfg2.TLS = nil // pointer not static
139+
res2 := fmt.Sprintf("%+v", cfg2)
140+
141+
if res1 != res2 {
142+
t.Errorf("%d. %q does not match %q", i, res2, res1)
143+
}
144+
})
137145
}
138146
}
139147

0 commit comments

Comments
 (0)