Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,18 @@ func (s transactionStatus) String() string {
panic("not reached")
}

type Dialer interface {
Dial(network, address string) (net.Conn, error)
DialTimeout(network, address string, timeout time.Duration) (net.Conn, error)
}

type defaultDialer struct{}

func (d defaultDialer) Dial(ntw, addr string) (net.Conn, error) { return net.Dial(ntw, addr) }
func (d defaultDialer) DialTimeout(ntw, addr string, timeout time.Duration) (net.Conn, error) {
return net.DialTimeout(ntw, addr, timeout)
}

type conn struct {
c net.Conn
buf *bufio.Reader
Expand All @@ -89,6 +101,10 @@ func (c *conn) writeBuf(b byte) *writeBuf {
}

func Open(name string) (_ driver.Conn, err error) {
return DialOpen(defaultDialer{}, name)
}

func DialOpen(d Dialer, name string) (_ driver.Conn, err error) {
defer func() {
// Handle any panics during connection initialization. Note that we
// specifically do *not* want to use errRecover(), as that would turn
Expand Down Expand Up @@ -174,7 +190,7 @@ func Open(name string) (_ driver.Conn, err error) {
}
}

c, err := dial(o)
c, err := dial(d, o)
if err != nil {
return nil, err
}
Expand All @@ -188,7 +204,7 @@ func Open(name string) (_ driver.Conn, err error) {
return cn, err
}

func dial(o values) (net.Conn, error) {
func dial(d Dialer, o values) (net.Conn, error) {
ntw, addr := network(o)

timeout := o.Get("connect_timeout")
Expand All @@ -207,14 +223,14 @@ func dial(o values) (net.Conn, error) {
// establishment and set a deadline for doing the initial handshake.
// The deadline is then reset after startup() is done.
deadline := time.Now().Add(duration)
conn, err := net.DialTimeout(ntw, addr, duration)
conn, err := d.DialTimeout(ntw, addr, duration)
if err != nil {
return nil, err
}
err = conn.SetDeadline(deadline)
return conn, err
}
return net.Dial(ntw, addr)
return d.Dial(ntw, addr)
}

func network(o values) (string, string) {
Expand Down