Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pkg/plan/delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ func (p *DeletePlan) Execute(ctx context.Context, hints ...*ast.TableOptimizerHi
sql := sb.String()
log.Debugf("delete, db name: %s, sql: %s", p.Database, sql)

pp := parser.New()
stmtNode, err := pp.ParseOneStmt(sql, "", "")
_parser := parser.New()
stmtNode, err := _parser.ParseOneStmt(sql, "", "")
if err != nil {
return nil, 0, errors.WithStack(err)
}
Expand Down
70 changes: 63 additions & 7 deletions pkg/plan/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,19 @@ package plan

import (
"context"
"fmt"
"strings"

"github.com/pkg/errors"

"github.com/cectc/dbpack/pkg/constant"
"github.com/cectc/dbpack/pkg/dt"
"github.com/cectc/dbpack/pkg/log"
"github.com/cectc/dbpack/pkg/misc"
"github.com/cectc/dbpack/pkg/mysql"
"github.com/cectc/dbpack/pkg/proto"
"github.com/cectc/dbpack/pkg/visitor"
"github.com/cectc/dbpack/third_party/parser"
"github.com/cectc/dbpack/third_party/parser/ast"
"github.com/cectc/dbpack/third_party/parser/format"
)
Expand Down Expand Up @@ -53,17 +58,27 @@ func (p *UpdatePlan) Execute(ctx context.Context, hints ...*ast.TableOptimizerHi
}
for _, table := range p.Tables {
sb.Reset()
if err = p.generate(&sb, table); err != nil {
if err = p.generate(&sb, table, hints...); err != nil {
return nil, 0, errors.Wrap(err, "failed to generate sql")
}
sql := sb.String()
log.Debugf("update, db name: %s, sql: %s", p.Database, sql)

_parser := parser.New()
stmtNode, err := _parser.ParseOneStmt(sql, "", "")
if err != nil {
return nil, 0, errors.WithStack(err)
}
stmtNode.Accept(&visitor.ParamVisitor{})

commandType := proto.CommandType(ctx)
switch commandType {
case constant.ComQuery:
ctx := proto.WithQueryStmt(ctx, stmtNode)
result, warns, err = tx.Query(ctx, sql)
case constant.ComStmtExecute:
stmt := generateStatement(sql, stmtNode, p.Args)
ctx := proto.WithPrepareStmt(ctx, stmt)
result, warns, err = tx.ExecuteSql(ctx, sql, p.Args...)
default:
continue
Expand All @@ -87,10 +102,24 @@ func (p *UpdatePlan) Execute(ctx context.Context, hints ...*ast.TableOptimizerHi
return mysqlResult, warnings, nil
}

func (p *UpdatePlan) generate(sb *strings.Builder, table string) error {
func (p *UpdatePlan) generate(sb *strings.Builder, table string, hints ...*ast.TableOptimizerHint) error {
ctx := format.NewRestoreCtx(format.DefaultRestoreFlags, sb)
ctx.WriteKeyWord("UPDATE ")
// todo add xid hint for distributed transaction

if len(hints) != 0 {
ctx.WritePlain("/*+ ")
for i, tableHint := range hints {
if i != 0 {
ctx.WritePlain(" ")
}
if err := tableHint.Restore(ctx); err != nil {
return errors.Wrapf(err, "An error occurred while restoring UpdateStmt.TableHints[%d], HintName: %s",
i, tableHint.HintName.String())
}
}
ctx.WritePlain("*/ ")
}

ctx.WritePlain(table)
ctx.WriteKeyWord(" SET ")
for i, assignment := range p.Stmt.List {
Expand Down Expand Up @@ -137,18 +166,45 @@ type MultiUpdatePlan struct {
Plans []*UpdatePlan
}

func (p *MultiUpdatePlan) Execute(ctx context.Context, _ ...*ast.TableOptimizerHint) (proto.Result, uint16, error) {
func (p *MultiUpdatePlan) Execute(ctx context.Context, _ ...*ast.TableOptimizerHint) (result proto.Result, warns uint16, err error) {
var (
affectedRows uint64
warnings uint16
affected uint64
hints []*ast.TableOptimizerHint
)
// todo distributed transaction
if has, _ := misc.HasXIDHint(p.Stmt.TableHints); !has {
tableName := p.Stmt.TableRefs.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName).Name.String()
transactionManager := dt.GetDistributedTransactionManager()
timeoutVariable := proto.Variable(ctx, constant.TransactionTimeout)
timeout, ok := timeoutVariable.(int32)
if !ok {
return nil, 0, errors.New("transaction timeout must be of type int32")
}
var xid string
xid, err = transactionManager.Begin(ctx, fmt.Sprintf("UPDATE_%s", tableName), timeout)
if err != nil {
return nil, 0, err
}
hints = append(hints, misc.NewXIDHint(xid))
defer func() {
if err != nil {
if _, rollbackErr := transactionManager.Rollback(ctx, xid); rollbackErr != nil {
log.Error(err)
}
} else {
if _, commitErr := transactionManager.Commit(ctx, xid); commitErr != nil {
log.Error(err)
}
}
}()
}
for _, pl := range p.Plans {
result, warns, err := pl.Execute(ctx)
result, warns, err = pl.Execute(ctx, hints...)
if err != nil {
return nil, 0, err
}
affected, err := result.RowsAffected()
affected, err = result.RowsAffected()
if err != nil {
return nil, 0, errors.WithStack(err)
}
Expand Down
2 changes: 1 addition & 1 deletion test/shd/sharding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ func (suite *_ShardingSuite) TestDeleteDrugResource() {
suite.Assert().Nil(err)
affectedRows, err := result.RowsAffected()
suite.Assert().Nil(err)
suite.Assert().Equal(int64(1), affectedRows)
suite.Assert().Equal(int64(11), affectedRows)
time.Sleep(10 * time.Second)
}

Expand Down