@@ -85,6 +85,23 @@ type DBTest struct {
8585 db * sql.DB
8686}
8787
88+ type netErrorMock struct {
89+ temporary bool
90+ timeout bool
91+ }
92+
93+ func (e netErrorMock ) Temporary () bool {
94+ return e .temporary
95+ }
96+
97+ func (e netErrorMock ) Timeout () bool {
98+ return e .timeout
99+ }
100+
101+ func (e netErrorMock ) Error () string {
102+ return fmt .Sprintf ("mock net error. Temporary: %v, Timeout %v" , e .temporary , e .timeout )
103+ }
104+
88105func runTestsWithMultiStatement (t * testing.T , dsn string , tests ... func (dbt * DBTest )) {
89106 if ! available {
90107 t .Skipf ("MySQL server not running on %s" , netAddr )
@@ -1801,6 +1818,38 @@ func TestConcurrent(t *testing.T) {
18011818 })
18021819}
18031820
1821+ func testDialError (t * testing.T , dialErr error , expectErr error ) {
1822+ RegisterDial ("mydial" , func (addr string ) (net.Conn , error ) {
1823+ return nil , dialErr
1824+ })
1825+
1826+ db , err := sql .Open ("mysql" , fmt .Sprintf ("%s:%s@mydial(%s)/%s?timeout=30s" , user , pass , addr , dbname ))
1827+ if err != nil {
1828+ t .Fatalf ("error connecting: %s" , err .Error ())
1829+ }
1830+ defer db .Close ()
1831+
1832+ _ , err = db .Exec ("DO 1" )
1833+ if err != expectErr {
1834+ t .Fatalf ("was expecting %s. Got: %s" , dialErr , err )
1835+ }
1836+ }
1837+
1838+ func TestDialUnknownError (t * testing.T ) {
1839+ testErr := fmt .Errorf ("test" )
1840+ testDialError (t , testErr , testErr )
1841+ }
1842+
1843+ func TestDialNonRetryableNetErr (t * testing.T ) {
1844+ testErr := netErrorMock {}
1845+ testDialError (t , testErr , testErr )
1846+ }
1847+
1848+ func TestDialTemporaryNetErr (t * testing.T ) {
1849+ testErr := netErrorMock {temporary : true }
1850+ testDialError (t , testErr , driver .ErrBadConn )
1851+ }
1852+
18041853// Tests custom dial functions
18051854func TestCustomDial (t * testing.T ) {
18061855 if ! available {
0 commit comments