@@ -1043,6 +1043,13 @@ type Tx struct {
10431043 // or Rollback. once done, all operations fail with
10441044 // ErrTxDone.
10451045 done bool
1046+
1047+ // All Stmts prepared for this transaction. These will be closed after the
1048+ // transaction has been committed or rolled back.
1049+ stmts struct {
1050+ sync.Mutex
1051+ v []* Stmt
1052+ }
10461053}
10471054
10481055var ErrTxDone = errors .New ("sql: Transaction has already been committed or rolled back" )
@@ -1064,15 +1071,28 @@ func (tx *Tx) grabConn() (*driverConn, error) {
10641071 return tx .dc , nil
10651072}
10661073
1074+ // Closes all Stmts prepared for this transaction.
1075+ func (tx * Tx ) closePrepared () {
1076+ tx .stmts .Lock ()
1077+ for _ , stmt := range tx .stmts .v {
1078+ stmt .Close ()
1079+ }
1080+ tx .stmts .Unlock ()
1081+ }
1082+
10671083// Commit commits the transaction.
10681084func (tx * Tx ) Commit () error {
10691085 if tx .done {
10701086 return ErrTxDone
10711087 }
10721088 defer tx .close ()
10731089 tx .dc .Lock ()
1074- defer tx .dc .Unlock ()
1075- return tx .txi .Commit ()
1090+ err := tx .txi .Commit ()
1091+ tx .dc .Unlock ()
1092+ if err != driver .ErrBadConn {
1093+ tx .closePrepared ()
1094+ }
1095+ return err
10761096}
10771097
10781098// Rollback aborts the transaction.
@@ -1082,8 +1102,12 @@ func (tx *Tx) Rollback() error {
10821102 }
10831103 defer tx .close ()
10841104 tx .dc .Lock ()
1085- defer tx .dc .Unlock ()
1086- return tx .txi .Rollback ()
1105+ err := tx .txi .Rollback ()
1106+ tx .dc .Unlock ()
1107+ if err != driver .ErrBadConn {
1108+ tx .closePrepared ()
1109+ }
1110+ return err
10871111}
10881112
10891113// Prepare creates a prepared statement for use within a transaction.
@@ -1127,6 +1151,9 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
11271151 },
11281152 query : query ,
11291153 }
1154+ tx .stmts .Lock ()
1155+ tx .stmts .v = append (tx .stmts .v , stmt )
1156+ tx .stmts .Unlock ()
11301157 return stmt , nil
11311158}
11321159
@@ -1155,7 +1182,7 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
11551182 dc .Lock ()
11561183 si , err := dc .ci .Prepare (stmt .query )
11571184 dc .Unlock ()
1158- return & Stmt {
1185+ txs := & Stmt {
11591186 db : tx .db ,
11601187 tx : tx ,
11611188 txsi : & driverStmt {
@@ -1165,6 +1192,10 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
11651192 query : stmt .query ,
11661193 stickyErr : err ,
11671194 }
1195+ tx .stmts .Lock ()
1196+ tx .stmts .v = append (tx .stmts .v , txs )
1197+ tx .stmts .Unlock ()
1198+ return txs
11681199}
11691200
11701201// Exec executes a query that doesn't return rows.
0 commit comments