From 2a3318c8b6fb2dd74aa1b979bb28407a80037d18 Mon Sep 17 00:00:00 2001 From: meiji163 Date: Wed, 23 Oct 2024 18:55:55 -0700 Subject: [PATCH] conn.Raw not working --- go/logic/applier.go | 66 ++++++++++++++++++++++++++++++--------------- 1 file changed, 45 insertions(+), 21 deletions(-) diff --git a/go/logic/applier.go b/go/logic/applier.go index a07faf20f..5b5dc7654 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -14,10 +14,13 @@ import ( "github.com/github/gh-ost/go/base" "github.com/github/gh-ost/go/binlog" - "github.com/github/gh-ost/go/mysql" "github.com/github/gh-ost/go/sql" - "github.com/openark/golib/log" + "context" + "database/sql/driver" + + "github.com/github/gh-ost/go/mysql" + drivermysql "github.com/go-sql-driver/mysql" "github.com/openark/golib/sqlutils" ) @@ -1207,13 +1210,19 @@ func (this *Applier) buildDMLEventQuery(dmlEvent *binlog.BinlogDMLEvent) []*dmlB // ApplyDMLEventQueries applies multiple DML queries onto the _ghost_ table func (this *Applier) ApplyDMLEventQueries(dmlEvents [](*binlog.BinlogDMLEvent)) error { var totalDelta int64 + ctx := context.TODO() err := func() error { - tx, err := this.db.Begin() + conn, err := this.db.Conn(ctx) if err != nil { return err } + defer conn.Close() + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + return err + } rollback := func(err error) error { tx.Rollback() return err @@ -1225,34 +1234,49 @@ func (this *Applier) ApplyDMLEventQueries(dmlEvents [](*binlog.BinlogDMLEvent)) if _, err := tx.Exec(sessionQuery); err != nil { return rollback(err) } - multiArgs := []interface{}{} + rowDeltas := make([]int64, 0, len(dmlEvents)) + multiArgs := []driver.NamedValue{} var multiQueryBuilder strings.Builder for _, dmlEvent := range dmlEvents { for _, buildResult := range this.buildDMLEventQuery(dmlEvent) { if buildResult.err != nil { - return buildResult.err + return rollback(buildResult.err) } - multiArgs = append(multiArgs, buildResult.args...) + for _, arg := range buildResult.args { + multiArgs = append(multiArgs, driver.NamedValue{Value: driver.Value(arg)}) + } + rowDeltas = append(rowDeltas, buildResult.rowsDelta) multiQueryBuilder.WriteString(buildResult.query) multiQueryBuilder.WriteString(";\n") } } - // TODO: get rows affected from each query in multi statement - log.Warningf("error getting rows affected from DML event query: %s. i'm going to assume that the DML affected a single row, but this may result in inaccurate statistics", err) - _, err = tx.Exec(multiQueryBuilder.String(), multiArgs...) - if err != nil { - err = fmt.Errorf("%w; query=%s; args=%+v", err, multiQueryBuilder.String(), multiArgs) - return rollback(err) - } - // rowsAffected, err := result.RowsAffected() - // if err != nil { - // log.Warningf("error getting rows affected from DML event query: %s. i'm going to assume that the DML affected a single row, but this may result in inaccurate statistics", err) - // rowsAffected = 1 - // } - // each DML is either a single insert (delta +1), update (delta +0) or delete (delta -1). - // multiplying by the rows actually affected (either 0 or 1) will give an accurate row delta for this DML event - // totalDelta += buildResult.rowsDelta * rowsAffected + //this.migrationContext.Log.Infof("Executing query: %s, args: %+v", multiQueryBuilder.String(), multiArgs) + execErr := conn.Raw(func(driverConn any) error { + ex, ok := driverConn.(driver.ExecerContext) + if !ok { + return fmt.Errorf("could not cast driverConn to ExecerContext") + } + res, err := ex.ExecContext(ctx, multiQueryBuilder.String(), multiArgs) + if err != nil { + err = fmt.Errorf("%w; query=%s; args=%+v", err, multiQueryBuilder.String(), multiArgs) + this.migrationContext.Log.Errorf("Error exec: %+v", err) + return err + } + mysqlRes, ok := res.(drivermysql.Result) + if !ok { + return fmt.Errorf("Could not cast %+v to mysql.Result", res) + } + // each DML is either a single insert (delta +1), update (delta +0) or delete (delta -1). + // multiplying by the rows actually affected (either 0 or 1) will give an accurate row delta for this DML event + for i, rowsAffected := range mysqlRes.AllRowsAffected() { + totalDelta += rowDeltas[i] * rowsAffected + } + return nil + }) + if execErr != nil { + return rollback(execErr) + } if err := tx.Commit(); err != nil { return err }