From c1b600085c7bd1a627e97d56e40046778ae6c793 Mon Sep 17 00:00:00 2001 From: meiji163 Date: Thu, 24 Oct 2024 12:03:26 -0700 Subject: [PATCH] add arthurscreiber's review suggestions --- go/logic/applier.go | 35 ++++++++++------------------------- go/mysql/connection.go | 1 - go/mysql/connection_test.go | 4 ++-- 3 files changed, 12 insertions(+), 28 deletions(-) diff --git a/go/logic/applier.go b/go/logic/applier.go index 4e0c58dd0..ecb9cc992 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -80,7 +80,8 @@ func NewApplier(migrationContext *base.MigrationContext) *Applier { func (this *Applier) InitDBConnections() (err error) { applierUri := this.connectionConfig.GetDBUri(this.migrationContext.DatabaseName) - if this.db, _, err = mysql.GetDB(this.migrationContext.Uuid, applierUri); err != nil { + uriWithMulti := fmt.Sprintf("%s&multiStatements=true", applierUri) + if this.db, _, err = mysql.GetDB(this.migrationContext.Uuid, uriWithMulti); err != nil { return err } singletonApplierUri := fmt.Sprintf("%s&timeout=0", applierUri) @@ -1210,7 +1211,7 @@ func (this *Applier) buildDMLEventQuery(dmlEvent *binlog.BinlogDMLEvent) []*dmlB // ApplyDMLEventQueries applies multiple DML queries onto the _ghost_ table func (this *Applier) ApplyDMLEventQueries(dmlEvents [](*binlog.BinlogDMLEvent)) error { var totalDelta int64 - ctx := context.TODO() + ctx := context.Background() err := func() error { conn, err := this.db.Conn(ctx) @@ -1236,31 +1237,23 @@ func (this *Applier) ApplyDMLEventQueries(dmlEvents [](*binlog.BinlogDMLEvent)) } buildResults := make([]*dmlBuildResult, 0, len(dmlEvents)) + nArgs := 0 for _, dmlEvent := range dmlEvents { for _, buildResult := range this.buildDMLEventQuery(dmlEvent) { if buildResult.err != nil { return rollback(buildResult.err) } - + nArgs += len(buildResult.args) buildResults = append(buildResults, buildResult) } } execErr := conn.Raw(func(driverConn any) error { - ex, ok := driverConn.(driver.ExecerContext) - if !ok { - return fmt.Errorf("could not cast driverConn to ExecerContext") - } - - nvc, ok := driverConn.(driver.NamedValueChecker) - if !ok { - return fmt.Errorf("could not cast driverConn to NamedValueChecker") - } + ex := driverConn.(driver.ExecerContext) + nvc := driverConn.(driver.NamedValueChecker) - var multiArgs []driver.NamedValue + multiArgs := make([]driver.NamedValue, 0, nArgs) multiQueryBuilder := strings.Builder{} - var rowDeltas []int64 - for _, buildResult := range buildResults { for _, arg := range buildResult.args { nv := driver.NamedValue{Value: driver.Value(arg)} @@ -1270,29 +1263,21 @@ func (this *Applier) ApplyDMLEventQueries(dmlEvents [](*binlog.BinlogDMLEvent)) multiQueryBuilder.WriteString(buildResult.query) multiQueryBuilder.WriteString(";\n") - - rowDeltas = append(rowDeltas, buildResult.rowsDelta) } - // this.migrationContext.Log.Infof("Executing query: %s, args: %+v", multiQueryBuilder.String(), multiArgs) res, err := ex.ExecContext(ctx, multiQueryBuilder.String(), multiArgs) if err != nil { err = fmt.Errorf("%w; query=%s; args=%+v", err, multiQueryBuilder.String(), multiArgs) - this.migrationContext.Log.Errorf("Error exec: %+v", err) return err } - mysqlRes, ok := res.(drivermysql.Result) - if !ok { - return fmt.Errorf("Could not cast %+v to mysql.Result", res) - } + mysqlRes := res.(drivermysql.Result) // each DML is either a single insert (delta +1), update (delta +0) or delete (delta -1). // multiplying by the rows actually affected (either 0 or 1) will give an accurate row delta for this DML event for i, rowsAffected := range mysqlRes.AllRowsAffected() { - totalDelta += rowDeltas[i] * rowsAffected + totalDelta += buildResults[i].rowsDelta * rowsAffected } - return nil }) diff --git a/go/mysql/connection.go b/go/mysql/connection.go index 1766ee917..33bde2b62 100644 --- a/go/mysql/connection.go +++ b/go/mysql/connection.go @@ -132,7 +132,6 @@ func (this *ConnectionConfig) GetDBUri(databaseName string) string { connectionParams := []string{ "autocommit=true", "interpolateParams=true", - "multiStatements=true", fmt.Sprintf("charset=%s", this.Charset), fmt.Sprintf("tls=%s", tlsOption), fmt.Sprintf("transaction_isolation=%q", this.TransactionIsolation), diff --git a/go/mysql/connection_test.go b/go/mysql/connection_test.go index bcd8b3147..7859c9354 100644 --- a/go/mysql/connection_test.go +++ b/go/mysql/connection_test.go @@ -86,7 +86,7 @@ func TestGetDBUri(t *testing.T) { c.Charset = "utf8mb4,utf8,latin1" uri := c.GetDBUri("test") - require.Equal(t, `gromit:penguin@tcp(myhost:3306)/test?autocommit=true&interpolateParams=true&multiStatements=true&charset=utf8mb4,utf8,latin1&tls=false&transaction_isolation="REPEATABLE-READ"&timeout=1.234500s&readTimeout=1.234500s&writeTimeout=1.234500s`, uri) + require.Equal(t, `gromit:penguin@tcp(myhost:3306)/test?autocommit=true&interpolateParams=true&charset=utf8mb4,utf8,latin1&tls=false&transaction_isolation="REPEATABLE-READ"&timeout=1.234500s&readTimeout=1.234500s&writeTimeout=1.234500s`, uri) } func TestGetDBUriWithTLSSetup(t *testing.T) { @@ -100,5 +100,5 @@ func TestGetDBUriWithTLSSetup(t *testing.T) { c.Charset = "utf8mb4_general_ci,utf8_general_ci,latin1" uri := c.GetDBUri("test") - require.Equal(t, `gromit:penguin@tcp(myhost:3306)/test?autocommit=true&interpolateParams=true&multiStatements=true&charset=utf8mb4_general_ci,utf8_general_ci,latin1&tls=ghost&transaction_isolation="REPEATABLE-READ"&timeout=1.234500s&readTimeout=1.234500s&writeTimeout=1.234500s`, uri) + require.Equal(t, `gromit:penguin@tcp(myhost:3306)/test?autocommit=true&interpolateParams=true&charset=utf8mb4_general_ci,utf8_general_ci,latin1&tls=ghost&transaction_isolation="REPEATABLE-READ"&timeout=1.234500s&readTimeout=1.234500s&writeTimeout=1.234500s`, uri) }