Skip to content

Commit c7bbe20

Browse files
author
Fabian Holler
committed
allow retry on some netErrors
Since the commit "Don't return ErrBadConn on a network error" net.OpError do not return driver.ErrBadConn anymore. This caused that in some situations the sql package does not retry an operation when it should. E.g. when the postgresql server is restarted, a broken pipe error might happen for the query is done after the server finished the startup (lib#870). With this commit driver.ErrBadConn is returned for netErrors when it's ensured that the server did not already executed the operation. This is the case when e.g. a netError occur for the call that tries to send the message to initiate the query to the server.
1 parent 2ff3cb3 commit c7bbe20

File tree

2 files changed

+81
-34
lines changed

2 files changed

+81
-34
lines changed

conn.go

Lines changed: 47 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,9 @@ func (cn *conn) gname() string {
600600
func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) {
601601
b := cn.writeBuf('Q')
602602
b.string(q)
603-
cn.send(b)
603+
if err := cn.send(b); err != nil {
604+
return res, commandTag, cn.errHandleFatalNetErrors(err)
605+
}
604606

605607
for {
606608
t, r := cn.recv1()
@@ -632,7 +634,10 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) {
632634

633635
b := cn.writeBuf('Q')
634636
b.string(q)
635-
cn.send(b)
637+
638+
if err := cn.send(b); err != nil {
639+
return nil, cn.errHandleFatalNetErrors(err)
640+
}
636641

637642
for {
638643
t, r := cn.recv1()
@@ -752,7 +757,7 @@ func decideColumnFormats(colTyps []fieldDesc, forceText bool) (colFmts []format,
752757
}
753758
}
754759

755-
func (cn *conn) prepareTo(q, stmtName string) *stmt {
760+
func (cn *conn) prepareTo(q, stmtName string) (*stmt, error) {
756761
st := &stmt{cn: cn, name: stmtName}
757762

758763
b := cn.writeBuf('P')
@@ -765,13 +770,16 @@ func (cn *conn) prepareTo(q, stmtName string) *stmt {
765770
b.string(st.name)
766771

767772
b.next('S')
768-
cn.send(b)
773+
if err := cn.send(b); err != nil {
774+
return nil, cn.errHandleFatalNetErrors(err)
775+
}
769776

770777
cn.readParseResponse()
771778
st.paramTyps, st.colNames, st.colTyps = cn.readStatementDescribeResponse()
772779
st.colFmts, st.colFmtData = decideColumnFormats(st.colTyps, cn.disablePreparedBinaryResult)
773780
cn.readReadyForQuery()
774-
return st
781+
782+
return st, nil
775783
}
776784

777785
func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) {
@@ -787,7 +795,7 @@ func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) {
787795
}
788796
return s, err
789797
}
790-
return cn.prepareTo(q, cn.gname()), nil
798+
return cn.prepareTo(q, cn.gname())
791799
}
792800

793801
func (cn *conn) Close() (err error) {
@@ -829,7 +837,9 @@ func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) {
829837
}
830838

831839
if cn.binaryParameters {
832-
cn.sendBinaryModeQuery(query, args)
840+
if err := cn.sendBinaryModeQuery(query, args); err != nil {
841+
return nil, cn.errHandleFatalNetErrors(err)
842+
}
833843

834844
cn.readParseResponse()
835845
cn.readBindResponse()
@@ -838,7 +848,11 @@ func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) {
838848
cn.postExecuteWorkaround()
839849
return rows, nil
840850
}
841-
st := cn.prepareTo(query, "")
851+
st, err := cn.prepareTo(query, "")
852+
if err != nil {
853+
return nil, err
854+
}
855+
842856
st.exec(args)
843857
return &rows{
844858
cn: cn,
@@ -862,7 +876,9 @@ func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err
862876
}
863877

864878
if cn.binaryParameters {
865-
cn.sendBinaryModeQuery(query, args)
879+
if err := cn.sendBinaryModeQuery(query, args); err != nil {
880+
return res, cn.errHandleFatalNetErrors(err)
881+
}
866882

867883
cn.readParseResponse()
868884
cn.readBindResponse()
@@ -874,16 +890,25 @@ func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err
874890
// Use the unnamed statement to defer planning until bind
875891
// time, or else value-based selectivity estimates cannot be
876892
// used.
877-
st := cn.prepareTo(query, "")
893+
st, err := cn.prepareTo(query, "")
894+
if err != nil {
895+
return res, nil
896+
}
897+
878898
r, err := st.Exec(args)
879899
if err != nil {
880900
panic(err)
881901
}
882902
return r, err
883903
}
884904

885-
func (cn *conn) send(m *writeBuf) {
905+
func (cn *conn) send(m *writeBuf) error {
886906
_, err := cn.c.Write(m.wrap())
907+
return err
908+
}
909+
910+
func (cn *conn) mustSend(m *writeBuf) {
911+
err := cn.send(m)
887912
if err != nil {
888913
panic(err)
889914
}
@@ -1109,7 +1134,7 @@ func (cn *conn) auth(r *readBuf, o values) {
11091134
case 3:
11101135
w := cn.writeBuf('p')
11111136
w.string(o["password"])
1112-
cn.send(w)
1137+
cn.mustSend(w)
11131138

11141139
t, r := cn.recv()
11151140
if t != 'R' {
@@ -1123,7 +1148,7 @@ func (cn *conn) auth(r *readBuf, o values) {
11231148
s := string(r.next(4))
11241149
w := cn.writeBuf('p')
11251150
w.string("md5" + md5s(md5s(o["password"]+o["user"])+s))
1126-
cn.send(w)
1151+
cn.mustSend(w)
11271152

11281153
t, r := cn.recv()
11291154
if t != 'R' {
@@ -1145,7 +1170,7 @@ func (cn *conn) auth(r *readBuf, o values) {
11451170
w.string("SCRAM-SHA-256")
11461171
w.int32(len(scOut))
11471172
w.bytes(scOut)
1148-
cn.send(w)
1173+
cn.mustSend(w)
11491174

11501175
t, r := cn.recv()
11511176
if t != 'R' {
@@ -1165,7 +1190,7 @@ func (cn *conn) auth(r *readBuf, o values) {
11651190
scOut = sc.Out()
11661191
w = cn.writeBuf('p')
11671192
w.bytes(scOut)
1168-
cn.send(w)
1193+
cn.mustSend(w)
11691194

11701195
t, r = cn.recv()
11711196
if t != 'R' {
@@ -1219,9 +1244,11 @@ func (st *stmt) Close() (err error) {
12191244
w := st.cn.writeBuf('C')
12201245
w.byte('S')
12211246
w.string(st.name)
1222-
st.cn.send(w)
1247+
if err := st.cn.send(w); err != nil {
1248+
return st.cn.errHandleFatalNetErrors(err)
1249+
}
12231250

1224-
st.cn.send(st.cn.writeBuf('S'))
1251+
st.cn.mustSend(st.cn.writeBuf('S'))
12251252

12261253
t, _ := st.cn.recv1()
12271254
if t != '3' {
@@ -1299,7 +1326,7 @@ func (st *stmt) exec(v []driver.Value) {
12991326
w.int32(0)
13001327

13011328
w.next('S')
1302-
cn.send(w)
1329+
cn.mustSend(w)
13031330

13041331
cn.readBindResponse()
13051332
cn.postExecuteWorkaround()
@@ -1577,7 +1604,7 @@ func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) {
15771604
}
15781605
}
15791606

1580-
func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) {
1607+
func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) error {
15811608
if len(args) >= 65536 {
15821609
errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(args))
15831610
}
@@ -1601,7 +1628,7 @@ func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) {
16011628
b.int32(0)
16021629

16031630
b.next('S')
1604-
cn.send(b)
1631+
return cn.send(b)
16051632
}
16061633

16071634
func (cn *conn) processParameterStatus(r *readBuf) {

error.go

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -480,36 +480,56 @@ func errRecoverNoErrBadConn(err *error) {
480480

481481
func (c *conn) errRecover(err *error) {
482482
e := recover()
483-
switch v := e.(type) {
484-
case nil:
485-
// Do nothing
483+
if e == nil {
484+
return
485+
}
486+
487+
if val, ok := e.(error); ok {
488+
*err = c.errHandleFatalNetErrors(val)
489+
return
490+
}
491+
492+
c.bad = true
493+
panic(fmt.Sprintf("unknown error: %#v", e))
494+
}
495+
496+
func (c *conn) errHandleFatalNetErrors(e error) error {
497+
return c.errHandle(e, false)
498+
}
499+
500+
func (c *conn) errHandle(e error, netErrorIsRetryable bool) error {
501+
var result error
502+
503+
switch v := (e).(type) {
486504
case runtime.Error:
487505
c.bad = true
488506
panic(v)
489507
case *Error:
490508
if v.Fatal() {
491-
*err = driver.ErrBadConn
509+
result = driver.ErrBadConn
492510
} else {
493-
*err = v
511+
result = v
494512
}
495513
case *net.OpError:
496514
c.bad = true
497-
*err = v
498-
case error:
499-
if v == io.EOF || v.(error).Error() == "remote error: handshake failure" {
500-
*err = driver.ErrBadConn
515+
if netErrorIsRetryable {
516+
result = driver.ErrBadConn
501517
} else {
502-
*err = v
518+
result = v
503519
}
504-
505520
default:
506-
c.bad = true
507-
panic(fmt.Sprintf("unknown error: %#v", e))
521+
if v == io.EOF || v.(error).Error() == "remote error: handshake failure" {
522+
result = driver.ErrBadConn
523+
} else {
524+
result = v
525+
}
508526
}
509527

510528
// Any time we return ErrBadConn, we need to remember it since *Tx doesn't
511529
// mark the connection bad in database/sql.
512-
if *err == driver.ErrBadConn {
530+
if result == driver.ErrBadConn {
513531
c.bad = true
514532
}
533+
534+
return result
515535
}

0 commit comments

Comments
 (0)