diff --git a/benchmark_test.go b/benchmark_test.go index fb8a2f5f3..7ccb46fcc 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -216,9 +216,9 @@ func BenchmarkRoundtripBin(b *testing.B) { func BenchmarkInterpolation(b *testing.B) { mc := &mysqlConn{ - cfg: &config{ - interpolateParams: true, - loc: time.UTC, + cfg: &Config{ + InterpolateParams: true, + Loc: time.UTC, }, maxPacketAllowed: maxPacketSize, maxWriteSize: maxPacketSize - 1, diff --git a/connection.go b/connection.go index 455fcc1d9..c708796f8 100644 --- a/connection.go +++ b/connection.go @@ -9,9 +9,7 @@ package mysql import ( - "crypto/tls" "database/sql/driver" - "errors" "net" "strconv" "strings" @@ -23,7 +21,7 @@ type mysqlConn struct { netConn net.Conn affectedRows uint64 insertId uint64 - cfg *config + cfg *Config maxPacketAllowed int maxWriteSize int flags clientFlag @@ -33,28 +31,9 @@ type mysqlConn struct { strict bool } -type config struct { - user string - passwd string - net string - addr string - dbname string - params map[string]string - loc *time.Location - tls *tls.Config - timeout time.Duration - collation uint8 - allowAllFiles bool - allowOldPasswords bool - allowCleartextPasswords bool - clientFoundRows bool - columnsWithAlias bool - interpolateParams bool -} - // Handles parameters set in DSN after the connection is established func (mc *mysqlConn) handleParams() (err error) { - for param, val := range mc.cfg.params { + for param, val := range mc.cfg.Params { switch param { // Charset case "charset": @@ -70,27 +49,6 @@ func (mc *mysqlConn) handleParams() (err error) { return } - // time.Time parsing - case "parseTime": - var isBool bool - mc.parseTime, isBool = readBool(val) - if !isBool { - return errors.New("Invalid Bool value: " + val) - } - - // Strict mode - case "strict": - var isBool bool - mc.strict, isBool = readBool(val) - if !isBool { - return errors.New("Invalid Bool value: " + val) - } - - // Compression - case "compress": - err = errors.New("Compression not implemented yet") - return - // System Vars default: err = mc.exec("SET " + param + "=" + val + "") @@ -217,7 +175,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin if v.IsZero() { buf = append(buf, "'0000-00-00'"...) } else { - v := v.In(mc.cfg.loc) + v := v.In(mc.cfg.Loc) v = v.Add(time.Nanosecond * 500) // To round under microsecond year := v.Year() year100 := year / 100 @@ -298,7 +256,7 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err return nil, driver.ErrBadConn } if len(args) != 0 { - if !mc.cfg.interpolateParams { + if !mc.cfg.InterpolateParams { return nil, driver.ErrSkip } // try to interpolate the parameters to save extra roundtrips for preparing and closing a statement @@ -349,7 +307,7 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro return nil, driver.ErrBadConn } if len(args) != 0 { - if !mc.cfg.interpolateParams { + if !mc.cfg.InterpolateParams { return nil, driver.ErrSkip } // try client-side prepare to reduce roundtrip @@ -395,6 +353,7 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { if err == nil { rows := new(textRows) rows.mc = mc + rows.columns = []mysqlField{{fieldType: fieldTypeVarChar}} if resLen > 0 { // Columns diff --git a/driver.go b/driver.go index 7502c57b4..1d7723b82 100644 --- a/driver.go +++ b/driver.go @@ -53,17 +53,19 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { maxPacketAllowed: maxPacketSize, maxWriteSize: maxPacketSize - 1, } - mc.cfg, err = parseDSN(dsn) + mc.cfg, err = ParseDSN(dsn) if err != nil { return nil, err } + mc.parseTime = mc.cfg.ParseTime + mc.strict = mc.cfg.Strict // Connect to Server - if dial, ok := dials[mc.cfg.net]; ok { - mc.netConn, err = dial(mc.cfg.addr) + if dial, ok := dials[mc.cfg.Net]; ok { + mc.netConn, err = dial(mc.cfg.Addr) } else { - nd := net.Dialer{Timeout: mc.cfg.timeout} - mc.netConn, err = nd.Dial(mc.cfg.net, mc.cfg.addr) + nd := net.Dialer{Timeout: mc.cfg.Timeout} + mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr) } if err != nil { return nil, err @@ -136,7 +138,7 @@ func handleAuthResult(mc *mysqlConn, cipher []byte) error { } // Retry auth if configured to do so. - if mc.cfg.allowOldPasswords && err == ErrOldPassword { + if mc.cfg.AllowOldPasswords && err == ErrOldPassword { // Retry with old authentication method. Note: there are edge cases // where this should work but doesn't; this is currently "wontfix": // https://github.com/go-sql-driver/mysql/issues/184 @@ -144,7 +146,7 @@ func handleAuthResult(mc *mysqlConn, cipher []byte) error { return err } err = mc.readResultOK() - } else if mc.cfg.allowCleartextPasswords && err == ErrCleartextPassword { + } else if mc.cfg.AllowCleartextPasswords && err == ErrCleartextPassword { // Retry with clear text password for // http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html // http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html diff --git a/driver_test.go b/driver_test.go index fbbe6d5ec..0e9571a59 100644 --- a/driver_test.go +++ b/driver_test.go @@ -91,7 +91,7 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { dsn2 := dsn + "&interpolateParams=true" var db2 *sql.DB - if _, err := parseDSN(dsn2); err != errInvalidDSNUnsafeCollation { + if _, err := ParseDSN(dsn2); err != errInvalidDSNUnsafeCollation { db2, err = sql.Open("mysql", dsn2) if err != nil { t.Fatalf("Error connecting: %s", err.Error()) diff --git a/dsn.go b/dsn.go new file mode 100644 index 000000000..45b4899d9 --- /dev/null +++ b/dsn.go @@ -0,0 +1,298 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "crypto/tls" + "errors" + "fmt" + "net" + "net/url" + "strings" + "time" +) + +var ( + errInvalidDSNUnescaped = errors.New("Invalid DSN: Did you forget to escape a param value?") + errInvalidDSNAddr = errors.New("Invalid DSN: Network Address not terminated (missing closing brace)") + errInvalidDSNNoSlash = errors.New("Invalid DSN: Missing the slash separating the database name") + errInvalidDSNUnsafeCollation = errors.New("Invalid DSN: interpolateParams can be used with ascii, latin1, utf8 and utf8mb4 charset") +) + +// Config is a configuration parsed from a DSN string +type Config struct { + User string // Username + Passwd string // Password + Net string // Network type + Addr string // Network address + DBName string // Database name + Params map[string]string // Connection parameters + Loc *time.Location // Location for time.Time values + TLS *tls.Config // TLS configuration + Timeout time.Duration // Dial timeout + Collation uint8 // Connection collation + + AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE + AllowCleartextPasswords bool // Allows the cleartext client side plugin + AllowOldPasswords bool // Allows the old insecure password method + ClientFoundRows bool // Return number of matching rows instead of rows changed + ColumnsWithAlias bool // Prepend table alias to column names + InterpolateParams bool // Interpolate placeholders into query string + ParseTime bool // Parse time values to time.Time + Strict bool // Return warnings as errors +} + +// ParseDSN parses the DSN string to a Config +func ParseDSN(dsn string) (cfg *Config, err error) { + // New config with some default values + cfg = &Config{ + Loc: time.UTC, + Collation: defaultCollation, + } + + // [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN] + // Find the last '/' (since the password or the net addr might contain a '/') + foundSlash := false + for i := len(dsn) - 1; i >= 0; i-- { + if dsn[i] == '/' { + foundSlash = true + var j, k int + + // left part is empty if i <= 0 + if i > 0 { + // [username[:password]@][protocol[(address)]] + // Find the last '@' in dsn[:i] + for j = i; j >= 0; j-- { + if dsn[j] == '@' { + // username[:password] + // Find the first ':' in dsn[:j] + for k = 0; k < j; k++ { + if dsn[k] == ':' { + cfg.Passwd = dsn[k+1 : j] + break + } + } + cfg.User = dsn[:k] + + break + } + } + + // [protocol[(address)]] + // Find the first '(' in dsn[j+1:i] + for k = j + 1; k < i; k++ { + if dsn[k] == '(' { + // dsn[i-1] must be == ')' if an address is specified + if dsn[i-1] != ')' { + if strings.ContainsRune(dsn[k+1:i], ')') { + return nil, errInvalidDSNUnescaped + } + return nil, errInvalidDSNAddr + } + cfg.Addr = dsn[k+1 : i-1] + break + } + } + cfg.Net = dsn[j+1 : k] + } + + // dbname[?param1=value1&...¶mN=valueN] + // Find the first '?' in dsn[i+1:] + for j = i + 1; j < len(dsn); j++ { + if dsn[j] == '?' { + if err = parseDSNParams(cfg, dsn[j+1:]); err != nil { + return + } + break + } + } + cfg.DBName = dsn[i+1 : j] + + break + } + } + + if !foundSlash && len(dsn) > 0 { + return nil, errInvalidDSNNoSlash + } + + if cfg.InterpolateParams && unsafeCollations[cfg.Collation] { + return nil, errInvalidDSNUnsafeCollation + } + + // Set default network if empty + if cfg.Net == "" { + cfg.Net = "tcp" + } + + // Set default address if empty + if cfg.Addr == "" { + switch cfg.Net { + case "tcp": + cfg.Addr = "127.0.0.1:3306" + case "unix": + cfg.Addr = "/tmp/mysql.sock" + default: + return nil, errors.New("Default addr for network '" + cfg.Net + "' unknown") + } + + } + + return +} + +// parseDSNParams parses the DSN "query string" +// Values must be url.QueryEscape'ed +func parseDSNParams(cfg *Config, params string) (err error) { + for _, v := range strings.Split(params, "&") { + param := strings.SplitN(v, "=", 2) + if len(param) != 2 { + continue + } + + // cfg params + switch value := param[1]; param[0] { + + // Disable INFILE whitelist / enable all files + case "allowAllFiles": + var isBool bool + cfg.AllowAllFiles, isBool = readBool(value) + if !isBool { + return fmt.Errorf("Invalid Bool value: %s", value) + } + + // Use cleartext authentication mode (MySQL 5.5.10+) + case "allowCleartextPasswords": + var isBool bool + cfg.AllowCleartextPasswords, isBool = readBool(value) + if !isBool { + return fmt.Errorf("Invalid Bool value: %s", value) + } + + // Use old authentication mode (pre MySQL 4.1) + case "allowOldPasswords": + var isBool bool + cfg.AllowOldPasswords, isBool = readBool(value) + if !isBool { + return fmt.Errorf("Invalid Bool value: %s", value) + } + + // Switch "rowsAffected" mode + case "clientFoundRows": + var isBool bool + cfg.ClientFoundRows, isBool = readBool(value) + if !isBool { + return fmt.Errorf("Invalid Bool value: %s", value) + } + + // Collation + case "collation": + collation, ok := collations[value] + if !ok { + // Note possibility for false negatives: + // could be triggered although the collation is valid if the + // collations map does not contain entries the server supports. + err = errors.New("unknown collation") + return + } + cfg.Collation = collation + break + + case "columnsWithAlias": + var isBool bool + cfg.ColumnsWithAlias, isBool = readBool(value) + if !isBool { + return fmt.Errorf("Invalid Bool value: %s", value) + } + + // Compression + case "compress": + return errors.New("Compression not implemented yet") + + // Enable client side placeholder substitution + case "interpolateParams": + var isBool bool + cfg.InterpolateParams, isBool = readBool(value) + if !isBool { + return fmt.Errorf("Invalid Bool value: %s", value) + } + + // Time Location + case "loc": + if value, err = url.QueryUnescape(value); err != nil { + return + } + cfg.Loc, err = time.LoadLocation(value) + if err != nil { + return + } + + // time.Time parsing + case "parseTime": + var isBool bool + cfg.ParseTime, isBool = readBool(value) + if !isBool { + return errors.New("Invalid Bool value: " + value) + } + + // Strict mode + case "strict": + var isBool bool + cfg.Strict, isBool = readBool(value) + if !isBool { + return errors.New("Invalid Bool value: " + value) + } + + // Dial Timeout + case "timeout": + cfg.Timeout, err = time.ParseDuration(value) + if err != nil { + return + } + + // TLS-Encryption + case "tls": + boolValue, isBool := readBool(value) + if isBool { + if boolValue { + cfg.TLS = &tls.Config{} + } + } else if value, err := url.QueryUnescape(value); err != nil { + return fmt.Errorf("Invalid value for tls config name: %v", err) + } else { + if strings.ToLower(value) == "skip-verify" { + cfg.TLS = &tls.Config{InsecureSkipVerify: true} + } else if tlsConfig, ok := tlsConfigRegister[value]; ok { + if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify { + host, _, err := net.SplitHostPort(cfg.Addr) + if err == nil { + tlsConfig.ServerName = host + } + } + + cfg.TLS = tlsConfig + } else { + return fmt.Errorf("Invalid value / unknown config name: %s", value) + } + } + + default: + // lazy init + if cfg.Params == nil { + cfg.Params = make(map[string]string) + } + + if cfg.Params[param[0]], err = url.QueryUnescape(value); err != nil { + return + } + } + } + + return +} diff --git a/dsn_test.go b/dsn_test.go new file mode 100644 index 000000000..4ac1f562c --- /dev/null +++ b/dsn_test.go @@ -0,0 +1,180 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "crypto/tls" + "fmt" + "net/url" + "testing" +) + +var testDSNs = []struct { + in string + out string +}{ + {"username:password@protocol(address)/dbname?param=value", "&{User:username Passwd:password Net:protocol Addr:address DBName:dbname Params:map[param:value] Loc:UTC TLS: Timeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"}, + {"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true", "&{User:username Passwd:password Net:protocol Addr:address DBName:dbname Params:map[param:value] Loc:UTC TLS: Timeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:true InterpolateParams:false ParseTime:false Strict:false}"}, + {"user@unix(/path/to/socket)/dbname?charset=utf8", "&{User:user Passwd: Net:unix Addr:/path/to/socket DBName:dbname Params:map[charset:utf8] Loc:UTC TLS: Timeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"}, + {"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{User:user Passwd:password Net:tcp Addr:localhost:5555 DBName:dbname Params:map[charset:utf8] Loc:UTC TLS: Timeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"}, + {"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{User:user Passwd:password Net:tcp Addr:localhost:5555 DBName:dbname Params:map[charset:utf8mb4,utf8] Loc:UTC TLS: Timeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"}, + {"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci", "&{User:user Passwd:password Net:tcp Addr:127.0.0.1:3306 DBName:dbname Params:map[] Loc:UTC TLS: Timeout:30s Collation:224 AllowAllFiles:true AllowCleartextPasswords:false AllowOldPasswords:true ClientFoundRows:true ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"}, + {"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{User:user Passwd:p@ss(word) Net:tcp Addr:[de:ad:be:ef::ca:fe]:80 DBName:dbname Params:map[] Loc:Local TLS: Timeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"}, + {"/dbname", "&{User: Passwd: Net:tcp Addr:127.0.0.1:3306 DBName:dbname Params:map[] Loc:UTC TLS: Timeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"}, + {"@/", "&{User: Passwd: Net:tcp Addr:127.0.0.1:3306 DBName: Params:map[] Loc:UTC TLS: Timeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"}, + {"/", "&{User: Passwd: Net:tcp Addr:127.0.0.1:3306 DBName: Params:map[] Loc:UTC TLS: Timeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"}, + {"", "&{User: Passwd: Net:tcp Addr:127.0.0.1:3306 DBName: Params:map[] Loc:UTC TLS: Timeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"}, + {"user:p@/ssword@/", "&{User:user Passwd:p@/ssword Net:tcp Addr:127.0.0.1:3306 DBName: Params:map[] Loc:UTC TLS: Timeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"}, + {"unix/?arg=%2Fsome%2Fpath.ext", "&{User: Passwd: Net:unix Addr:/tmp/mysql.sock DBName: Params:map[arg:/some/path.ext] Loc:UTC TLS: Timeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false ParseTime:false Strict:false}"}, +} + +func TestDSNParser(t *testing.T) { + var cfg *Config + var err error + var res string + + for i, tst := range testDSNs { + cfg, err = ParseDSN(tst.in) + if err != nil { + t.Error(err.Error()) + } + + // pointer not static + cfg.TLS = nil + + res = fmt.Sprintf("%+v", cfg) + if res != tst.out { + t.Errorf("%d. ParseDSN(%q) => %q, want %q", i, tst.in, res, tst.out) + } + } +} + +func TestDSNParserInvalid(t *testing.T) { + var invalidDSNs = []string{ + "@net(addr/", // no closing brace + "@tcp(/", // no closing brace + "tcp(/", // no closing brace + "(/", // no closing brace + "net(addr)//", // unescaped + "User:pass@tcp(1.2.3.4:3306)", // no trailing slash + //"/dbname?arg=/some/unescaped/path", + } + + for i, tst := range invalidDSNs { + if _, err := ParseDSN(tst); err == nil { + t.Errorf("invalid DSN #%d. (%s) didn't error!", i, tst) + } + } +} + +func TestDSNWithCustomTLS(t *testing.T) { + baseDSN := "User:password@tcp(localhost:5555)/dbname?tls=" + tlsCfg := tls.Config{} + + RegisterTLSConfig("utils_test", &tlsCfg) + + // Custom TLS is missing + tst := baseDSN + "invalid_tls" + cfg, err := ParseDSN(tst) + if err == nil { + t.Errorf("invalid custom TLS in DSN (%s) but did not error. Got config: %#v", tst, cfg) + } + + tst = baseDSN + "utils_test" + + // Custom TLS with a server name + name := "foohost" + tlsCfg.ServerName = name + cfg, err = ParseDSN(tst) + + if err != nil { + t.Error(err.Error()) + } else if cfg.TLS.ServerName != name { + t.Errorf("did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, tst) + } + + // Custom TLS without a server name + name = "localhost" + tlsCfg.ServerName = "" + cfg, err = ParseDSN(tst) + + if err != nil { + t.Error(err.Error()) + } else if cfg.TLS.ServerName != name { + t.Errorf("did not get the correct ServerName (%s) parsing DSN (%s).", name, tst) + } + + DeregisterTLSConfig("utils_test") +} + +func TestDSNWithCustomTLS_queryEscape(t *testing.T) { + const configKey = "&%!:" + dsn := "User:password@tcp(localhost:5555)/dbname?tls=" + url.QueryEscape(configKey) + name := "foohost" + tlsCfg := tls.Config{ServerName: name} + + RegisterTLSConfig(configKey, &tlsCfg) + + cfg, err := ParseDSN(dsn) + + if err != nil { + t.Error(err.Error()) + } else if cfg.TLS.ServerName != name { + t.Errorf("did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, dsn) + } +} + +func TestDSNUnsafeCollation(t *testing.T) { + _, err := ParseDSN("/dbname?collation=gbk_chinese_ci&interpolateParams=true") + if err != errInvalidDSNUnsafeCollation { + t.Errorf("expected %v, got %v", errInvalidDSNUnsafeCollation, err) + } + + _, err = ParseDSN("/dbname?collation=gbk_chinese_ci&interpolateParams=false") + if err != nil { + t.Errorf("expected %v, got %v", nil, err) + } + + _, err = ParseDSN("/dbname?collation=gbk_chinese_ci") + if err != nil { + t.Errorf("expected %v, got %v", nil, err) + } + + _, err = ParseDSN("/dbname?collation=ascii_bin&interpolateParams=true") + if err != nil { + t.Errorf("expected %v, got %v", nil, err) + } + + _, err = ParseDSN("/dbname?collation=latin1_german1_ci&interpolateParams=true") + if err != nil { + t.Errorf("expected %v, got %v", nil, err) + } + + _, err = ParseDSN("/dbname?collation=utf8_general_ci&interpolateParams=true") + if err != nil { + t.Errorf("expected %v, got %v", nil, err) + } + + _, err = ParseDSN("/dbname?collation=utf8mb4_general_ci&interpolateParams=true") + if err != nil { + t.Errorf("expected %v, got %v", nil, err) + } +} + +func BenchmarkParseDSN(b *testing.B) { + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + for _, tst := range testDSNs { + if _, err := ParseDSN(tst.in); err != nil { + b.Error(err.Error()) + } + } + } +} diff --git a/infile.go b/infile.go index 84c53a99c..9c898b705 100644 --- a/infile.go +++ b/infile.go @@ -124,7 +124,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { fileRegisterLock.RLock() fr := fileRegister[name] fileRegisterLock.RUnlock() - if mc.cfg.allowAllFiles || fr { + if mc.cfg.AllowAllFiles || fr { var file *os.File var fi os.FileInfo diff --git a/packets.go b/packets.go index 76cb7c84e..532c56c96 100644 --- a/packets.go +++ b/packets.go @@ -161,7 +161,7 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { if mc.flags&clientProtocol41 == 0 { return nil, ErrOldProtocol } - if mc.flags&clientSSL == 0 && mc.cfg.tls != nil { + if mc.flags&clientSSL == 0 && mc.cfg.TLS != nil { return nil, ErrNoTLS } pos += 2 @@ -221,22 +221,22 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { clientPluginAuth | mc.flags&clientLongFlag - if mc.cfg.clientFoundRows { + if mc.cfg.ClientFoundRows { clientFlags |= clientFoundRows } // To enable TLS / SSL - if mc.cfg.tls != nil { + if mc.cfg.TLS != nil { clientFlags |= clientSSL } // User Password - scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.passwd)) + scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd)) - pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.user) + 1 + 1 + len(scrambleBuff) + 21 + 1 + pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + 1 + len(scrambleBuff) + 21 + 1 // To specify a db name - if n := len(mc.cfg.dbname); n > 0 { + if n := len(mc.cfg.DBName); n > 0 { clientFlags |= clientConnectWithDB pktLen += n + 1 } @@ -262,18 +262,18 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { data[11] = 0x00 // Charset [1 byte] - data[12] = mc.cfg.collation + data[12] = mc.cfg.Collation // SSL Connection Request Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest - if mc.cfg.tls != nil { + if mc.cfg.TLS != nil { // Send TLS / SSL request packet if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil { return err } // Switch to TLS - tlsConn := tls.Client(mc.netConn, mc.cfg.tls) + tlsConn := tls.Client(mc.netConn, mc.cfg.TLS) if err := tlsConn.Handshake(); err != nil { return err } @@ -288,8 +288,8 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { } // User [null terminated string] - if len(mc.cfg.user) > 0 { - pos += copy(data[pos:], mc.cfg.user) + if len(mc.cfg.User) > 0 { + pos += copy(data[pos:], mc.cfg.User) } data[pos] = 0x00 pos++ @@ -299,8 +299,8 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { pos += 1 + copy(data[pos+1:], scrambleBuff) // Databasename [null terminated string] - if len(mc.cfg.dbname) > 0 { - pos += copy(data[pos:], mc.cfg.dbname) + if len(mc.cfg.DBName) > 0 { + pos += copy(data[pos:], mc.cfg.DBName) data[pos] = 0x00 pos++ } @@ -317,7 +317,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { // User password - scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.passwd)) + scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.Passwd)) // Calculate the packet length and add a tailing 0 pktLen := len(scrambleBuff) + 1 @@ -339,7 +339,7 @@ func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse func (mc *mysqlConn) writeClearAuthPacket() error { // Calculate the packet length and add a tailing 0 - pktLen := len(mc.cfg.passwd) + 1 + pktLen := len(mc.cfg.Passwd) + 1 data := mc.buf.takeSmallBuffer(4 + pktLen) if data == nil { // can not take the buffer. Something must be wrong with the connection @@ -348,7 +348,7 @@ func (mc *mysqlConn) writeClearAuthPacket() error { } // Add the clear password [null terminated string] - copy(data[4:], mc.cfg.passwd) + copy(data[4:], mc.cfg.Passwd) data[4+pktLen-1] = 0x00 return mc.writePacket(data) @@ -575,7 +575,7 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { pos += n // Table [len coded string] - if mc.cfg.columnsWithAlias { + if mc.cfg.ColumnsWithAlias { tableName, _, n, err := readLengthEncodedString(data[pos:]) if err != nil { return nil, err @@ -674,7 +674,7 @@ func (rows *textRows) readRow(dest []driver.Value) error { fieldTypeDate, fieldTypeNewDate: dest[i], err = parseDateTime( string(dest[i].([]byte)), - mc.cfg.loc, + mc.cfg.Loc, ) if err == nil { continue @@ -981,7 +981,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { if v.IsZero() { val = []byte("0000-00-00") } else { - val = []byte(v.In(mc.cfg.loc).Format(timeFormat)) + val = []byte(v.In(mc.cfg.Loc).Format(timeFormat)) } paramValues = appendLengthEncodedInteger(paramValues, @@ -1144,7 +1144,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { } dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true) case rows.mc.parseTime: - dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.loc) + dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc) default: var dstlen uint8 if rows.columns[i].fieldType == fieldTypeDate { diff --git a/rows.go b/rows.go index ba606e146..5d21948ad 100644 --- a/rows.go +++ b/rows.go @@ -38,7 +38,7 @@ type emptyRows struct{} func (rows *mysqlRows) Columns() []string { columns := make([]string, len(rows.columns)) - if rows.mc.cfg.columnsWithAlias { + if rows.mc.cfg.ColumnsWithAlias { for i := range columns { if tableName := rows.columns[i].tableName; len(tableName) > 0 { columns[i] = tableName + "." + rows.columns[i].name diff --git a/utils.go b/utils.go index 4bfc331e3..e267cce4e 100644 --- a/utils.go +++ b/utils.go @@ -13,22 +13,14 @@ import ( "crypto/tls" "database/sql/driver" "encoding/binary" - "errors" "fmt" "io" - "net" - "net/url" "strings" "time" ) var ( tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs - - errInvalidDSNUnescaped = errors.New("Invalid DSN: Did you forget to escape a param value?") - errInvalidDSNAddr = errors.New("Invalid DSN: Network Address not terminated (missing closing brace)") - errInvalidDSNNoSlash = errors.New("Invalid DSN: Missing the slash separating the database name") - errInvalidDSNUnsafeCollation = errors.New("Invalid DSN: interpolateParams can be used with ascii, latin1, utf8 and utf8mb4 charset") ) func init() { @@ -72,235 +64,6 @@ func DeregisterTLSConfig(key string) { delete(tlsConfigRegister, key) } -// parseDSN parses the DSN string to a config -func parseDSN(dsn string) (cfg *config, err error) { - // New config with some default values - cfg = &config{ - loc: time.UTC, - collation: defaultCollation, - } - - // [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN] - // Find the last '/' (since the password or the net addr might contain a '/') - foundSlash := false - for i := len(dsn) - 1; i >= 0; i-- { - if dsn[i] == '/' { - foundSlash = true - var j, k int - - // left part is empty if i <= 0 - if i > 0 { - // [username[:password]@][protocol[(address)]] - // Find the last '@' in dsn[:i] - for j = i; j >= 0; j-- { - if dsn[j] == '@' { - // username[:password] - // Find the first ':' in dsn[:j] - for k = 0; k < j; k++ { - if dsn[k] == ':' { - cfg.passwd = dsn[k+1 : j] - break - } - } - cfg.user = dsn[:k] - - break - } - } - - // [protocol[(address)]] - // Find the first '(' in dsn[j+1:i] - for k = j + 1; k < i; k++ { - if dsn[k] == '(' { - // dsn[i-1] must be == ')' if an address is specified - if dsn[i-1] != ')' { - if strings.ContainsRune(dsn[k+1:i], ')') { - return nil, errInvalidDSNUnescaped - } - return nil, errInvalidDSNAddr - } - cfg.addr = dsn[k+1 : i-1] - break - } - } - cfg.net = dsn[j+1 : k] - } - - // dbname[?param1=value1&...¶mN=valueN] - // Find the first '?' in dsn[i+1:] - for j = i + 1; j < len(dsn); j++ { - if dsn[j] == '?' { - if err = parseDSNParams(cfg, dsn[j+1:]); err != nil { - return - } - break - } - } - cfg.dbname = dsn[i+1 : j] - - break - } - } - - if !foundSlash && len(dsn) > 0 { - return nil, errInvalidDSNNoSlash - } - - if cfg.interpolateParams && unsafeCollations[cfg.collation] { - return nil, errInvalidDSNUnsafeCollation - } - - // Set default network if empty - if cfg.net == "" { - cfg.net = "tcp" - } - - // Set default address if empty - if cfg.addr == "" { - switch cfg.net { - case "tcp": - cfg.addr = "127.0.0.1:3306" - case "unix": - cfg.addr = "/tmp/mysql.sock" - default: - return nil, errors.New("Default addr for network '" + cfg.net + "' unknown") - } - - } - - return -} - -// parseDSNParams parses the DSN "query string" -// Values must be url.QueryEscape'ed -func parseDSNParams(cfg *config, params string) (err error) { - for _, v := range strings.Split(params, "&") { - param := strings.SplitN(v, "=", 2) - if len(param) != 2 { - continue - } - - // cfg params - switch value := param[1]; param[0] { - - // Enable client side placeholder substitution - case "interpolateParams": - var isBool bool - cfg.interpolateParams, isBool = readBool(value) - if !isBool { - return fmt.Errorf("Invalid Bool value: %s", value) - } - - // Disable INFILE whitelist / enable all files - case "allowAllFiles": - var isBool bool - cfg.allowAllFiles, isBool = readBool(value) - if !isBool { - return fmt.Errorf("Invalid Bool value: %s", value) - } - - // Use cleartext authentication mode (MySQL 5.5.10+) - case "allowCleartextPasswords": - var isBool bool - cfg.allowCleartextPasswords, isBool = readBool(value) - if !isBool { - return fmt.Errorf("Invalid Bool value: %s", value) - } - - // Use old authentication mode (pre MySQL 4.1) - case "allowOldPasswords": - var isBool bool - cfg.allowOldPasswords, isBool = readBool(value) - if !isBool { - return fmt.Errorf("Invalid Bool value: %s", value) - } - - // Switch "rowsAffected" mode - case "clientFoundRows": - var isBool bool - cfg.clientFoundRows, isBool = readBool(value) - if !isBool { - return fmt.Errorf("Invalid Bool value: %s", value) - } - - // Collation - case "collation": - collation, ok := collations[value] - if !ok { - // Note possibility for false negatives: - // could be triggered although the collation is valid if the - // collations map does not contain entries the server supports. - err = errors.New("unknown collation") - return - } - cfg.collation = collation - break - - case "columnsWithAlias": - var isBool bool - cfg.columnsWithAlias, isBool = readBool(value) - if !isBool { - return fmt.Errorf("Invalid Bool value: %s", value) - } - - // Time Location - case "loc": - if value, err = url.QueryUnescape(value); err != nil { - return - } - cfg.loc, err = time.LoadLocation(value) - if err != nil { - return - } - - // Dial Timeout - case "timeout": - cfg.timeout, err = time.ParseDuration(value) - if err != nil { - return - } - - // TLS-Encryption - case "tls": - boolValue, isBool := readBool(value) - if isBool { - if boolValue { - cfg.tls = &tls.Config{} - } - } else if value, err := url.QueryUnescape(value); err != nil { - return fmt.Errorf("Invalid value for tls config name: %v", err) - } else { - if strings.ToLower(value) == "skip-verify" { - cfg.tls = &tls.Config{InsecureSkipVerify: true} - } else if tlsConfig, ok := tlsConfigRegister[value]; ok { - if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify { - host, _, err := net.SplitHostPort(cfg.addr) - if err == nil { - tlsConfig.ServerName = host - } - } - - cfg.tls = tlsConfig - } else { - return fmt.Errorf("Invalid value / unknown config name: %s", value) - } - } - - default: - // lazy init - if cfg.params == nil { - cfg.params = make(map[string]string) - } - - if cfg.params[param[0]], err = url.QueryUnescape(value); err != nil { - return - } - } - } - - return -} - // Returns the bool value of the input. // The 2nd return value indicates if the input was a valid bool value func readBool(input string) (value bool, valid bool) { diff --git a/utils_test.go b/utils_test.go index 8ff905069..0d6c6684f 100644 --- a/utils_test.go +++ b/utils_test.go @@ -10,179 +10,12 @@ package mysql import ( "bytes" - "crypto/tls" "encoding/binary" "fmt" - "net/url" "testing" "time" ) -var testDSNs = []struct { - in string - out string - loc *time.Location -}{ - {"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false allowCleartextPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC}, - {"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false allowCleartextPasswords:false clientFoundRows:false columnsWithAlias:true interpolateParams:false}", time.UTC}, - {"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false allowCleartextPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC}, - {"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false allowCleartextPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC}, - {"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false allowCleartextPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC}, - {"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls: timeout:30000000000 collation:224 allowAllFiles:true allowOldPasswords:true allowCleartextPasswords:false clientFoundRows:true columnsWithAlias:false interpolateParams:false}", time.UTC}, - {"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false allowCleartextPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.Local}, - {"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false allowCleartextPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC}, - {"@/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false allowCleartextPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC}, - {"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false allowCleartextPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC}, - {"", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false allowCleartextPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC}, - {"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false allowCleartextPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC}, - {"unix/?arg=%2Fsome%2Fpath.ext", "&{user: passwd: net:unix addr:/tmp/mysql.sock dbname: params:map[arg:/some/path.ext] loc:%p tls: timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false allowCleartextPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC}, -} - -func TestDSNParser(t *testing.T) { - var cfg *config - var err error - var res string - - for i, tst := range testDSNs { - cfg, err = parseDSN(tst.in) - if err != nil { - t.Error(err.Error()) - } - - // pointer not static - cfg.tls = nil - - res = fmt.Sprintf("%+v", cfg) - if res != fmt.Sprintf(tst.out, tst.loc) { - t.Errorf("%d. parseDSN(%q) => %q, want %q", i, tst.in, res, fmt.Sprintf(tst.out, tst.loc)) - } - } -} - -func TestDSNParserInvalid(t *testing.T) { - var invalidDSNs = []string{ - "@net(addr/", // no closing brace - "@tcp(/", // no closing brace - "tcp(/", // no closing brace - "(/", // no closing brace - "net(addr)//", // unescaped - "user:pass@tcp(1.2.3.4:3306)", // no trailing slash - //"/dbname?arg=/some/unescaped/path", - } - - for i, tst := range invalidDSNs { - if _, err := parseDSN(tst); err == nil { - t.Errorf("invalid DSN #%d. (%s) didn't error!", i, tst) - } - } -} - -func TestDSNWithCustomTLS(t *testing.T) { - baseDSN := "user:password@tcp(localhost:5555)/dbname?tls=" - tlsCfg := tls.Config{} - - RegisterTLSConfig("utils_test", &tlsCfg) - - // Custom TLS is missing - tst := baseDSN + "invalid_tls" - cfg, err := parseDSN(tst) - if err == nil { - t.Errorf("Invalid custom TLS in DSN (%s) but did not error. Got config: %#v", tst, cfg) - } - - tst = baseDSN + "utils_test" - - // Custom TLS with a server name - name := "foohost" - tlsCfg.ServerName = name - cfg, err = parseDSN(tst) - - if err != nil { - t.Error(err.Error()) - } else if cfg.tls.ServerName != name { - t.Errorf("Did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, tst) - } - - // Custom TLS without a server name - name = "localhost" - tlsCfg.ServerName = "" - cfg, err = parseDSN(tst) - - if err != nil { - t.Error(err.Error()) - } else if cfg.tls.ServerName != name { - t.Errorf("Did not get the correct ServerName (%s) parsing DSN (%s).", name, tst) - } - - DeregisterTLSConfig("utils_test") -} - -func TestDSNWithCustomTLS_queryEscape(t *testing.T) { - const configKey = "&%!:" - dsn := "user:password@tcp(localhost:5555)/dbname?tls=" + url.QueryEscape(configKey) - name := "foohost" - tlsCfg := tls.Config{ServerName: name} - - RegisterTLSConfig(configKey, &tlsCfg) - - cfg, err := parseDSN(dsn) - - if err != nil { - t.Error(err.Error()) - } else if cfg.tls.ServerName != name { - t.Errorf("Did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, dsn) - } -} - -func TestDSNUnsafeCollation(t *testing.T) { - _, err := parseDSN("/dbname?collation=gbk_chinese_ci&interpolateParams=true") - if err != errInvalidDSNUnsafeCollation { - t.Errorf("Expected %v, Got %v", errInvalidDSNUnsafeCollation, err) - } - - _, err = parseDSN("/dbname?collation=gbk_chinese_ci&interpolateParams=false") - if err != nil { - t.Errorf("Expected %v, Got %v", nil, err) - } - - _, err = parseDSN("/dbname?collation=gbk_chinese_ci") - if err != nil { - t.Errorf("Expected %v, Got %v", nil, err) - } - - _, err = parseDSN("/dbname?collation=ascii_bin&interpolateParams=true") - if err != nil { - t.Errorf("Expected %v, Got %v", nil, err) - } - - _, err = parseDSN("/dbname?collation=latin1_german1_ci&interpolateParams=true") - if err != nil { - t.Errorf("Expected %v, Got %v", nil, err) - } - - _, err = parseDSN("/dbname?collation=utf8_general_ci&interpolateParams=true") - if err != nil { - t.Errorf("Expected %v, Got %v", nil, err) - } - - _, err = parseDSN("/dbname?collation=utf8mb4_general_ci&interpolateParams=true") - if err != nil { - t.Errorf("Expected %v, Got %v", nil, err) - } -} - -func BenchmarkParseDSN(b *testing.B) { - b.ReportAllocs() - - for i := 0; i < b.N; i++ { - for _, tst := range testDSNs { - if _, err := parseDSN(tst.in); err != nil { - b.Error(err.Error()) - } - } - } -} - func TestScanNullTime(t *testing.T) { var scanTests = []struct { in interface{}