Skip to content

Commit 175b972

Browse files
johtowheatman
authored andcommitted
database/sql: Close per-tx prepared statements when the associated tx ends
LGTM=bradfitz R=golang-codereviews, bradfitz, mattn.jp CC=golang-codereviews https://golang.org/cl/131650043
1 parent 6e369a1 commit 175b972

File tree

2 files changed

+67
-5
lines changed

2 files changed

+67
-5
lines changed

src/database/sql/sql.go

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,6 +1043,13 @@ type Tx struct {
10431043
// or Rollback. once done, all operations fail with
10441044
// ErrTxDone.
10451045
done bool
1046+
1047+
// All Stmts prepared for this transaction. These will be closed after the
1048+
// transaction has been committed or rolled back.
1049+
stmts struct {
1050+
sync.Mutex
1051+
v []*Stmt
1052+
}
10461053
}
10471054

10481055
var ErrTxDone = errors.New("sql: Transaction has already been committed or rolled back")
@@ -1064,15 +1071,28 @@ func (tx *Tx) grabConn() (*driverConn, error) {
10641071
return tx.dc, nil
10651072
}
10661073

1074+
// Closes all Stmts prepared for this transaction.
1075+
func (tx *Tx) closePrepared() {
1076+
tx.stmts.Lock()
1077+
for _, stmt := range tx.stmts.v {
1078+
stmt.Close()
1079+
}
1080+
tx.stmts.Unlock()
1081+
}
1082+
10671083
// Commit commits the transaction.
10681084
func (tx *Tx) Commit() error {
10691085
if tx.done {
10701086
return ErrTxDone
10711087
}
10721088
defer tx.close()
10731089
tx.dc.Lock()
1074-
defer tx.dc.Unlock()
1075-
return tx.txi.Commit()
1090+
err := tx.txi.Commit()
1091+
tx.dc.Unlock()
1092+
if err != driver.ErrBadConn {
1093+
tx.closePrepared()
1094+
}
1095+
return err
10761096
}
10771097

10781098
// Rollback aborts the transaction.
@@ -1082,8 +1102,12 @@ func (tx *Tx) Rollback() error {
10821102
}
10831103
defer tx.close()
10841104
tx.dc.Lock()
1085-
defer tx.dc.Unlock()
1086-
return tx.txi.Rollback()
1105+
err := tx.txi.Rollback()
1106+
tx.dc.Unlock()
1107+
if err != driver.ErrBadConn {
1108+
tx.closePrepared()
1109+
}
1110+
return err
10871111
}
10881112

10891113
// Prepare creates a prepared statement for use within a transaction.
@@ -1127,6 +1151,9 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
11271151
},
11281152
query: query,
11291153
}
1154+
tx.stmts.Lock()
1155+
tx.stmts.v = append(tx.stmts.v, stmt)
1156+
tx.stmts.Unlock()
11301157
return stmt, nil
11311158
}
11321159

@@ -1155,7 +1182,7 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
11551182
dc.Lock()
11561183
si, err := dc.ci.Prepare(stmt.query)
11571184
dc.Unlock()
1158-
return &Stmt{
1185+
txs := &Stmt{
11591186
db: tx.db,
11601187
tx: tx,
11611188
txsi: &driverStmt{
@@ -1165,6 +1192,10 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
11651192
query: stmt.query,
11661193
stickyErr: err,
11671194
}
1195+
tx.stmts.Lock()
1196+
tx.stmts.v = append(tx.stmts.v, txs)
1197+
tx.stmts.Unlock()
1198+
return txs
11681199
}
11691200

11701201
// Exec executes a query that doesn't return rows.

src/database/sql/sql_test.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,33 @@ func TestExec(t *testing.T) {
441441
}
442442
}
443443

444+
func TestTxPrepare(t *testing.T) {
445+
db := newTestDB(t, "")
446+
defer closeDB(t, db)
447+
exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
448+
tx, err := db.Begin()
449+
if err != nil {
450+
t.Fatalf("Begin = %v", err)
451+
}
452+
stmt, err := tx.Prepare("INSERT|t1|name=?,age=?")
453+
if err != nil {
454+
t.Fatalf("Stmt, err = %v, %v", stmt, err)
455+
}
456+
defer stmt.Close()
457+
_, err = stmt.Exec("Bobby", 7)
458+
if err != nil {
459+
t.Fatalf("Exec = %v", err)
460+
}
461+
err = tx.Commit()
462+
if err != nil {
463+
t.Fatalf("Commit = %v", err)
464+
}
465+
// Commit() should have closed the statement
466+
if !stmt.closed {
467+
t.Fatal("Stmt not closed after Commit")
468+
}
469+
}
470+
444471
func TestTxStmt(t *testing.T) {
445472
db := newTestDB(t, "")
446473
defer closeDB(t, db)
@@ -464,6 +491,10 @@ func TestTxStmt(t *testing.T) {
464491
if err != nil {
465492
t.Fatalf("Commit = %v", err)
466493
}
494+
// Commit() should have closed the statement
495+
if !txs.closed {
496+
t.Fatal("Stmt not closed after Commit")
497+
}
467498
}
468499

469500
// Issue: http://golang.org/issue/2784

0 commit comments

Comments
 (0)