From 63f0c4b44fa06595ee4a0b8a73c07bdacb9f7049 Mon Sep 17 00:00:00 2001 From: Aaron Jheng Date: Mon, 11 Mar 2024 19:42:40 +0800 Subject: [PATCH 1/3] Specify a custom dial function per config --- connector.go | 24 +++++++++++++++++------- dsn.go | 37 +++++++++++++++++++------------------ 2 files changed, 36 insertions(+), 25 deletions(-) diff --git a/connector.go b/connector.go index 62012dba..718613b8 100644 --- a/connector.go +++ b/connector.go @@ -87,20 +87,30 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { mc.parseTime = mc.cfg.ParseTime // Connect to Server - dialsLock.RLock() - dial, ok := dials[mc.cfg.Net] - dialsLock.RUnlock() - if ok { + if c.cfg.DialFunc != nil { dctx := ctx if mc.cfg.Timeout > 0 { var cancel context.CancelFunc dctx, cancel = context.WithTimeout(ctx, c.cfg.Timeout) defer cancel() } - mc.netConn, err = dial(dctx, mc.cfg.Addr) + mc.netConn, err = c.cfg.DialFunc(dctx, mc.cfg.Net, mc.cfg.Addr) } else { - nd := net.Dialer{Timeout: mc.cfg.Timeout} - mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr) + dialsLock.RLock() + dial, ok := dials[mc.cfg.Net] + dialsLock.RUnlock() + if ok { + dctx := ctx + if mc.cfg.Timeout > 0 { + var cancel context.CancelFunc + dctx, cancel = context.WithTimeout(ctx, c.cfg.Timeout) + defer cancel() + } + mc.netConn, err = dial(dctx, mc.cfg.Addr) + } else { + nd := net.Dialer{Timeout: mc.cfg.Timeout} + mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr) + } } if err != nil { return nil, err diff --git a/dsn.go b/dsn.go index 3c7a6e21..b99cff94 100644 --- a/dsn.go +++ b/dsn.go @@ -37,24 +37,25 @@ var ( type Config struct { // non boolean fields - User string // Username - Passwd string // Password (requires User) - Net string // Network (e.g. "tcp", "tcp6", "unix". default: "tcp") - Addr string // Address (default: "127.0.0.1:3306" for "tcp" and "/tmp/mysql.sock" for "unix") - DBName string // Database name - Params map[string]string // Connection parameters - ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs - charsets []string // Connection charset. When set, this will be set in SET NAMES query - Collation string // Connection collation. When set, this will be set in SET NAMES COLLATE query - Loc *time.Location // Location for time.Time values - MaxAllowedPacket int // Max packet size allowed - ServerPubKey string // Server public key name - TLSConfig string // TLS configuration name - TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig - Timeout time.Duration // Dial timeout - ReadTimeout time.Duration // I/O read timeout - WriteTimeout time.Duration // I/O write timeout - Logger Logger // Logger + User string // Username + Passwd string // Password (requires User) + Net string // Network (e.g. "tcp", "tcp6", "unix". default: "tcp") + Addr string // Address (default: "127.0.0.1:3306" for "tcp" and "/tmp/mysql.sock" for "unix") + DBName string // Database name + Params map[string]string // Connection parameters + ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs + charsets []string // Connection charset. When set, this will be set in SET NAMES query + Collation string // Connection collation. When set, this will be set in SET NAMES COLLATE query + Loc *time.Location // Location for time.Time values + MaxAllowedPacket int // Max packet size allowed + ServerPubKey string // Server public key name + TLSConfig string // TLS configuration name + TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig + Timeout time.Duration // Dial timeout + ReadTimeout time.Duration // I/O read timeout + WriteTimeout time.Duration // I/O write timeout + Logger Logger // Logger + DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) // Specifies the dial function for creating connections // boolean fields From b52c4c55eab9044be1ab8f12edcbdf60e0b20935 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Fri, 8 Nov 2024 00:58:08 +0900 Subject: [PATCH 2/3] reformat dsn.go --- dsn.go | 39 ++++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/dsn.go b/dsn.go index b99cff94..f391a8fc 100644 --- a/dsn.go +++ b/dsn.go @@ -37,25 +37,26 @@ var ( type Config struct { // non boolean fields - User string // Username - Passwd string // Password (requires User) - Net string // Network (e.g. "tcp", "tcp6", "unix". default: "tcp") - Addr string // Address (default: "127.0.0.1:3306" for "tcp" and "/tmp/mysql.sock" for "unix") - DBName string // Database name - Params map[string]string // Connection parameters - ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs - charsets []string // Connection charset. When set, this will be set in SET NAMES query - Collation string // Connection collation. When set, this will be set in SET NAMES COLLATE query - Loc *time.Location // Location for time.Time values - MaxAllowedPacket int // Max packet size allowed - ServerPubKey string // Server public key name - TLSConfig string // TLS configuration name - TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig - Timeout time.Duration // Dial timeout - ReadTimeout time.Duration // I/O read timeout - WriteTimeout time.Duration // I/O write timeout - Logger Logger // Logger - DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) // Specifies the dial function for creating connections + User string // Username + Passwd string // Password (requires User) + Net string // Network (e.g. "tcp", "tcp6", "unix". default: "tcp") + Addr string // Address (default: "127.0.0.1:3306" for "tcp" and "/tmp/mysql.sock" for "unix") + DBName string // Database name + Params map[string]string // Connection parameters + ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs + charsets []string // Connection charset. When set, this will be set in SET NAMES query + Collation string // Connection collation. When set, this will be set in SET NAMES COLLATE query + Loc *time.Location // Location for time.Time values + MaxAllowedPacket int // Max packet size allowed + ServerPubKey string // Server public key name + TLSConfig string // TLS configuration name + TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig + Timeout time.Duration // Dial timeout + ReadTimeout time.Duration // I/O read timeout + WriteTimeout time.Duration // I/O write timeout + Logger Logger // Logger + // DialFunc specifies the dial function for creating connections + DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) // boolean fields From 2d6a857426313cb0e5216172aa8448c3a59dff6a Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Fri, 8 Nov 2024 01:08:43 +0900 Subject: [PATCH 3/3] simplify calling dialer --- connector.go | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/connector.go b/connector.go index 718613b8..769b3adc 100644 --- a/connector.go +++ b/connector.go @@ -87,29 +87,24 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { mc.parseTime = mc.cfg.ParseTime // Connect to Server + dctx := ctx + if mc.cfg.Timeout > 0 { + var cancel context.CancelFunc + dctx, cancel = context.WithTimeout(ctx, c.cfg.Timeout) + defer cancel() + } + if c.cfg.DialFunc != nil { - dctx := ctx - if mc.cfg.Timeout > 0 { - var cancel context.CancelFunc - dctx, cancel = context.WithTimeout(ctx, c.cfg.Timeout) - defer cancel() - } mc.netConn, err = c.cfg.DialFunc(dctx, mc.cfg.Net, mc.cfg.Addr) } else { dialsLock.RLock() dial, ok := dials[mc.cfg.Net] dialsLock.RUnlock() if ok { - dctx := ctx - if mc.cfg.Timeout > 0 { - var cancel context.CancelFunc - dctx, cancel = context.WithTimeout(ctx, c.cfg.Timeout) - defer cancel() - } mc.netConn, err = dial(dctx, mc.cfg.Addr) } else { - nd := net.Dialer{Timeout: mc.cfg.Timeout} - mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr) + nd := net.Dialer{} + mc.netConn, err = nd.DialContext(dctx, mc.cfg.Net, mc.cfg.Addr) } } if err != nil {