From 4660a11c07c36214ef84729c288ae4d673fe2c32 Mon Sep 17 00:00:00 2001 From: lxfeng1997 <824141436@qq.com> Date: Sun, 15 Dec 2024 15:28:36 +0800 Subject: [PATCH 1/8] =?UTF-8?q?=E6=9A=82=E5=AD=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/datasource/sql/exec/at/at_executor.go | 2 + .../sql/exec/at/update_join_executor.go | 280 ++++++++++++++++++ .../exec/at/update_join_executor_test.go.go | 100 +++++++ pkg/datasource/sql/types/sql.go | 1 + .../sql/undo/factor/undo_executor_factory.go | 2 +- 5 files changed, 384 insertions(+), 1 deletion(-) create mode 100644 pkg/datasource/sql/exec/at/update_join_executor.go create mode 100644 pkg/datasource/sql/exec/at/update_join_executor_test.go.go diff --git a/pkg/datasource/sql/exec/at/at_executor.go b/pkg/datasource/sql/exec/at/at_executor.go index e51db2284..09b284b34 100644 --- a/pkg/datasource/sql/exec/at/at_executor.go +++ b/pkg/datasource/sql/exec/at/at_executor.go @@ -68,6 +68,8 @@ func (e *ATExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.Exec executor = NewInsertOnUpdateExecutor(queryParser, execCtx, e.hooks) case types.SQLTypeMulti: executor = NewMultiExecutor(queryParser, execCtx, e.hooks) + case types.SQLTypeUpdateJoin: + executor = NewUpdateJoinExecutor(queryParser, execCtx, e.hooks) default: executor = NewPlainExecutor(queryParser, execCtx) } diff --git a/pkg/datasource/sql/exec/at/update_join_executor.go b/pkg/datasource/sql/exec/at/update_join_executor.go new file mode 100644 index 000000000..971b8aa68 --- /dev/null +++ b/pkg/datasource/sql/exec/at/update_join_executor.go @@ -0,0 +1,280 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. + */ + +package at + +import ( + "context" + "database/sql/driver" + "fmt" + "seata.apache.org/seata-go/pkg/datasource/sql" + "seata.apache.org/seata-go/pkg/protocol/branch" + "strings" + + "github.com/arana-db/parser/ast" + "github.com/arana-db/parser/format" + "github.com/arana-db/parser/model" + + "seata.apache.org/seata-go/pkg/datasource/sql/datasource" + "seata.apache.org/seata-go/pkg/datasource/sql/exec" + "seata.apache.org/seata-go/pkg/datasource/sql/types" + "seata.apache.org/seata-go/pkg/datasource/sql/undo" + "seata.apache.org/seata-go/pkg/datasource/sql/util" + "seata.apache.org/seata-go/pkg/util/bytes" + "seata.apache.org/seata-go/pkg/util/log" +) + +const ( + multi_table_name_seperaror = "#" +) + +var ( + beforeImagesMap map[string]*types.RecordImage + afterImagesMap map[string]*types.RecordImage +) + +// updateJoinExecutor execute update SQL +type updateJoinExecutor struct { + updateExecutor + parserCtx *types.ParseContext + execContext *types.ExecContext + isLowerSupportGroupByPksVersion bool +} + +// NewUpdateJoinExecutor get executor +func NewUpdateJoinExecutor(parserCtx *types.ParseContext, execContent *types.ExecContext, hooks []exec.SQLHook) executor { + // todo 不确定,需要test一下 + val, _ := datasource.GetDataSourceManager(branch.BranchTypeAT).GetCachedResources().Load(execContent.TxCtx.ResourceID) + res := val.(*sql.DBResource) + minimumVersion, _ := util.ConvertDbVersion("5.7.5") + currentVersion, _ := util.ConvertDbVersion(res.GetDbVersion()) + return &updateJoinExecutor{ + parserCtx: parserCtx, + execContext: execContent, + updateExecutor: updateExecutor{parserCtx: parserCtx, execContext: execContent, baseExecutor: baseExecutor{hooks: hooks}}, + isLowerSupportGroupByPksVersion: currentVersion < minimumVersion, + } +} + +// beforeImage build before image +func (u *updateJoinExecutor) beforeImage(ctx context.Context) (*types.RecordImage, error) { + if !u.isAstStmtValid() { + return nil, nil + } + + // update join sql,like update t1 inner join t2 on t1.id = t2.id set t1.name = ?; tableItems = {"update t1 inner join t2","t1","t2"} + tableName, _ := u.parserCtx.GetTableName() + + tableItems := strings.Split(tableName, multi_table_name_seperaror) + + u.buildWhereConditionByPKs(u.execContext.MetaDataMap, u.execContext.DBType) + suffixCommonCondition, paramAppenderList := u.buildBeforeImageSQLCommonConditionSuffix() + for i := range tableItems { + metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tableNames[i]) + if err != nil { + return nil, err + } + + image, err := u.buildRecordImages(rowsi, metaData, types.SQLTypeUpdate) + if err != nil { + return nil, err + } + } + + selectSQL, selectArgs, err := u.buildBeforeImageSQL(ctx, u.execContext.NamedValues) + if err != nil { + return nil, err + } + + var rowsi driver.Rows + queryerCtx, ok := u.execContext.Conn.(driver.QueryerContext) + var queryer driver.Queryer + if !ok { + queryer, ok = u.execContext.Conn.(driver.Queryer) + } + if ok { + rowsi, err = util.CtxDriverQuery(ctx, queryerCtx, queryer, selectSQL, selectArgs) + defer func() { + if rowsi != nil { + rowsi.Close() + } + }() + if err != nil { + log.Errorf("ctx driver query: %+v", err) + return nil, err + } + } else { + log.Errorf("target conn should been driver.QueryerContext or driver.Queryer") + return nil, fmt.Errorf("invalid conn") + } + + lockKey := u.buildLockKey(image, *metaData) + u.execContext.TxCtx.LockKeys[lockKey] = struct{}{} + image.SQLType = u.parserCtx.SQLType + + return image, nil +} + +// afterImage build after image +func (u *updateJoinExecutor) afterImage(ctx context.Context, beforeImage types.RecordImage) (*types.RecordImage, error) { + if !u.isAstStmtValid() { + return nil, nil + } + if len(beforeImage.Rows) == 0 { + return &types.RecordImage{}, nil + } + + tableName, _ := u.parserCtx.GetTableName() + metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tableName) + if err != nil { + return nil, err + } + selectSQL, selectArgs := u.buildAfterImageSQL(beforeImage, metaData) + + var rowsi driver.Rows + queryerCtx, ok := u.execContext.Conn.(driver.QueryerContext) + var queryer driver.Queryer + if !ok { + queryer, ok = u.execContext.Conn.(driver.Queryer) + } + if ok { + rowsi, err = util.CtxDriverQuery(ctx, queryerCtx, queryer, selectSQL, selectArgs) + defer func() { + if rowsi != nil { + rowsi.Close() + } + }() + if err != nil { + log.Errorf("ctx driver query: %+v", err) + return nil, err + } + } else { + log.Errorf("target conn should been driver.QueryerContext or driver.Queryer") + return nil, fmt.Errorf("invalid conn") + } + + afterImage, err := u.buildRecordImages(rowsi, metaData, types.SQLTypeUpdate) + if err != nil { + return nil, err + } + afterImage.SQLType = u.parserCtx.SQLType + + return afterImage, nil +} + +func (u *updateJoinExecutor) isAstStmtValid() bool { + return u.parserCtx != nil && u.parserCtx.UpdateStmt != nil +} + +// buildAfterImageSQL build the SQL to query after image data +func (u *updateJoinExecutor) buildAfterImageSQL(beforeImage types.RecordImage, meta *types.TableMeta) (string, []driver.NamedValue) { + if len(beforeImage.Rows) == 0 { + return "", nil + } + sb := strings.Builder{} + // todo: OnlyCareUpdateColumns should load from config first + var selectFields string + var separator = "," + if undo.UndoConfig.OnlyCareUpdateColumns { + for _, row := range beforeImage.Rows { + for _, column := range row.Columns { + selectFields += column.ColumnName + separator + } + } + selectFields = strings.TrimSuffix(selectFields, separator) + } else { + selectFields = "*" + } + sb.WriteString("SELECT " + selectFields + " FROM " + meta.TableName + " WHERE ") + whereSQL := u.buildWhereConditionByPKs(meta.GetPrimaryKeyOnlyName(), len(beforeImage.Rows), "mysql", maxInSize) + sb.WriteString(" " + whereSQL + " ") + return sb.String(), u.buildPKParams(beforeImage.Rows, meta.GetPrimaryKeyOnlyName()) +} + +// buildAfterImageSQL build the SQL to query before image data +func (u *updateJoinExecutor) buildBeforeImageSQL(ctx context.Context, args []driver.NamedValue) (string, []driver.NamedValue, error) { + if !u.isAstStmtValid() { + log.Errorf("invalid update stmt") + return "", nil, fmt.Errorf("invalid update stmt") + } + + updateStmt := u.parserCtx.UpdateStmt + fields := make([]*ast.SelectField, 0, len(updateStmt.List)) + + if undo.UndoConfig.OnlyCareUpdateColumns { + for _, column := range updateStmt.List { + fields = append(fields, &ast.SelectField{ + Expr: &ast.ColumnNameExpr{ + Name: column.Column, + }, + }) + } + + // select indexes columns + tableName, _ := u.parserCtx.GetTableName() + metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tableName) + if err != nil { + return "", nil, err + } + for _, columnName := range metaData.GetPrimaryKeyOnlyName() { + fields = append(fields, &ast.SelectField{ + Expr: &ast.ColumnNameExpr{ + Name: &ast.ColumnName{ + Name: model.CIStr{ + O: columnName, + L: columnName, + }, + }, + }, + }) + } + } else { + fields = append(fields, &ast.SelectField{ + Expr: &ast.ColumnNameExpr{ + Name: &ast.ColumnName{ + Name: model.CIStr{ + O: "*", + L: "*", + }, + }, + }, + }) + } + + selStmt := ast.SelectStmt{ + SelectStmtOpts: &ast.SelectStmtOpts{}, + From: updateStmt.TableRefs, + Where: updateStmt.Where, + Fields: &ast.FieldList{Fields: fields}, + OrderBy: updateStmt.Order, + Limit: updateStmt.Limit, + TableHints: updateStmt.TableHints, + LockInfo: &ast.SelectLockInfo{ + LockType: ast.SelectLockForUpdate, + }, + } + + b := bytes.NewByteBuffer([]byte{}) + _ = selStmt.Restore(format.NewRestoreCtx(format.RestoreKeyWordUppercase, b)) + sql := string(b.Bytes()) + log.Infof("build select sql by update sourceQuery, sql {%s}", sql) + + return sql, u.buildSelectArgs(&selStmt, args), nil +} + +func (u *updateJoinExecutor) getDbVersion() string { +} diff --git a/pkg/datasource/sql/exec/at/update_join_executor_test.go.go b/pkg/datasource/sql/exec/at/update_join_executor_test.go.go new file mode 100644 index 000000000..770f21688 --- /dev/null +++ b/pkg/datasource/sql/exec/at/update_join_executor_test.go.go @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package at + +import ( + "context" + "database/sql/driver" + "reflect" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/stretchr/testify/assert" + + "seata.apache.org/seata-go/pkg/datasource/sql/datasource" + "seata.apache.org/seata-go/pkg/datasource/sql/datasource/mysql" + "seata.apache.org/seata-go/pkg/datasource/sql/exec" + "seata.apache.org/seata-go/pkg/datasource/sql/parser" + "seata.apache.org/seata-go/pkg/datasource/sql/types" + "seata.apache.org/seata-go/pkg/datasource/sql/undo" + "seata.apache.org/seata-go/pkg/datasource/sql/util" + _ "seata.apache.org/seata-go/pkg/util/log" +) + +func TestBuildSelectSQLByUpdate(t *testing.T) { + undo.InitUndoConfig(undo.Config{OnlyCareUpdateColumns: true}) + datasource.RegisterTableCache(types.DBTypeMySQL, mysql.NewTableMetaInstance(nil)) + stub := gomonkey.ApplyMethod(reflect.TypeOf(datasource.GetTableCache(types.DBTypeMySQL)), "GetTableMeta", + func(_ *mysql.TableMetaCache, ctx context.Context, dbName, tableName string) (*types.TableMeta, error) { + return &types.TableMeta{ + Indexs: map[string]types.IndexMeta{ + "id": { + IType: types.IndexTypePrimaryKey, + Columns: []types.ColumnMeta{ + {ColumnName: "id"}, + }, + }, + }, + }, nil + }) + defer stub.Reset() + + tests := []struct { + name string + sourceQuery string + sourceQueryArgs []driver.Value + expectQuery string + expectQueryArgs []driver.Value + }{ + { + sourceQuery: "update t_user set name = ?, age = ? where id = ?", + sourceQueryArgs: []driver.Value{"Jack", 1, 100}, + expectQuery: "SELECT SQL_NO_CACHE name,age,id FROM t_user WHERE id=? FOR UPDATE", + expectQueryArgs: []driver.Value{100}, + }, + { + sourceQuery: "update t_user set name = ?, age = ? where id = ? and name = 'Jack' and age between ? and ?", + sourceQueryArgs: []driver.Value{"Jack", 1, 100, 18, 28}, + expectQuery: "SELECT SQL_NO_CACHE name,age,id FROM t_user WHERE id=? AND name=_UTF8MB4Jack AND age BETWEEN ? AND ? FOR UPDATE", + expectQueryArgs: []driver.Value{100, 18, 28}, + }, + { + sourceQuery: "update t_user set name = ?, age = ? where id = ? and name = 'Jack' and age in (?,?)", + sourceQueryArgs: []driver.Value{"Jack", 1, 100, 18, 28}, + expectQuery: "SELECT SQL_NO_CACHE name,age,id FROM t_user WHERE id=? AND name=_UTF8MB4Jack AND age IN (?,?) FOR UPDATE", + expectQueryArgs: []driver.Value{100, 18, 28}, + }, + { + sourceQuery: "update t_user set name = ?, age = ? where kk between ? and ? and id = ? and addr in(?,?) and age > ? order by name desc limit ?", + sourceQueryArgs: []driver.Value{"Jack", 1, 10, 20, 17, "Beijing", "Guangzhou", 18, 2}, + expectQuery: "SELECT SQL_NO_CACHE name,age,id FROM t_user WHERE kk BETWEEN ? AND ? AND id=? AND addr IN (?,?) AND age>? ORDER BY name DESC LIMIT ? FOR UPDATE", + expectQueryArgs: []driver.Value{10, 20, 17, "Beijing", "Guangzhou", 18, 2}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, err := parser.DoParser(tt.sourceQuery) + assert.Nil(t, err) + executor := NewUpdateExecutor(c, &types.ExecContext{Values: tt.sourceQueryArgs, NamedValues: util.ValueToNamedValue(tt.sourceQueryArgs)}, []exec.SQLHook{}) + query, args, err := executor.(*updateExecutor).buildBeforeImageSQL(context.Background(), util.ValueToNamedValue(tt.sourceQueryArgs)) + assert.Nil(t, err) + assert.Equal(t, tt.expectQuery, query) + assert.Equal(t, tt.expectQueryArgs, util.NamedValueToValue(args)) + }) + } +} diff --git a/pkg/datasource/sql/types/sql.go b/pkg/datasource/sql/types/sql.go index 58f4612f4..728ce6158 100644 --- a/pkg/datasource/sql/types/sql.go +++ b/pkg/datasource/sql/types/sql.go @@ -68,6 +68,7 @@ const ( SQLTypeSelectFoundRows SQLTypeInsertIgnore = iota + 57 SQLTypeInsertOnDuplicateUpdate + SQLTypeUpdateJoin // SQLTypeMulti and SQLTypeUnknown is different from seata-java SQLTypeMulti = iota + 999 SQLTypeUnknown diff --git a/pkg/datasource/sql/undo/factor/undo_executor_factory.go b/pkg/datasource/sql/undo/factor/undo_executor_factory.go index 2a943eff4..d8ae8ec5d 100644 --- a/pkg/datasource/sql/undo/factor/undo_executor_factory.go +++ b/pkg/datasource/sql/undo/factor/undo_executor_factory.go @@ -37,7 +37,7 @@ func GetUndoExecutor(dbType types.DBType, sqlUndoLog undo.SQLUndoLog) (res undo. res = undoExecutorHolder.GetInsertExecutor(sqlUndoLog) case types.SQLTypeDelete: res = undoExecutorHolder.GetDeleteExecutor(sqlUndoLog) - case types.SQLTypeUpdate: + case types.SQLTypeUpdate, types.SQLTypeUpdateJoin: res = undoExecutorHolder.GetUpdateExecutor(sqlUndoLog) default: return nil, fmt.Errorf("sql type: %d not support", sqlUndoLog.SQLType) From e430bc6dd11702dadf4e9fc31412850614504570 Mon Sep 17 00:00:00 2001 From: lxfeng1997 <824141436@qq.com> Date: Fri, 27 Dec 2024 21:26:26 +0800 Subject: [PATCH 2/8] duplicate image row for update join --- pkg/datasource/sql/conn.go | 4 + pkg/datasource/sql/conn_at.go | 2 + pkg/datasource/sql/exec/at/base_executor.go | 21 + pkg/datasource/sql/exec/at/update_executor.go | 139 +++---- .../sql/exec/at/update_executor_test.go | 178 +++++++-- .../sql/exec/at/update_join_executor.go | 371 ++++++++++-------- .../sql/exec/at/update_join_executor_test.go | 121 ++++++ .../exec/at/update_join_executor_test.go.go | 100 ----- pkg/datasource/sql/types/types.go | 1 + 9 files changed, 589 insertions(+), 348 deletions(-) create mode 100644 pkg/datasource/sql/exec/at/update_join_executor_test.go delete mode 100644 pkg/datasource/sql/exec/at/update_join_executor_test.go.go diff --git a/pkg/datasource/sql/conn.go b/pkg/datasource/sql/conn.go index 7a2b0423d..f316292b4 100644 --- a/pkg/datasource/sql/conn.go +++ b/pkg/datasource/sql/conn.go @@ -244,6 +244,10 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e ) } +func (c *Conn) GetDbVersion() string { + return c.res.GetDbVersion() +} + func (c *Conn) GetAutoCommit() bool { return c.autoCommit } diff --git a/pkg/datasource/sql/conn_at.go b/pkg/datasource/sql/conn_at.go index f1d0f5ed6..abc31f267 100644 --- a/pkg/datasource/sql/conn_at.go +++ b/pkg/datasource/sql/conn_at.go @@ -63,6 +63,7 @@ func (c *ATConn) QueryContext(ctx context.Context, query string, args []driver.N NamedValues: args, Conn: c.targetConn, DBName: c.dbName, + DbVersion: c.GetDbVersion(), IsSupportsSavepoints: true, IsAutoCommit: c.GetAutoCommit(), } @@ -102,6 +103,7 @@ func (c *ATConn) ExecContext(ctx context.Context, query string, args []driver.Na NamedValues: args, Conn: c.targetConn, DBName: c.dbName, + DbVersion: c.GetDbVersion(), IsSupportsSavepoints: true, IsAutoCommit: c.GetAutoCommit(), } diff --git a/pkg/datasource/sql/exec/at/base_executor.go b/pkg/datasource/sql/exec/at/base_executor.go index 75f0cab56..884438317 100644 --- a/pkg/datasource/sql/exec/at/base_executor.go +++ b/pkg/datasource/sql/exec/at/base_executor.go @@ -22,6 +22,7 @@ import ( "context" "database/sql" "database/sql/driver" + "errors" "fmt" "strings" @@ -350,3 +351,23 @@ func (b *baseExecutor) buildLockKey(records *types.RecordImage, meta types.Table return lockKeys.String() } + +func (u *updateExecutor) rowsPrepare(ctx context.Context, selectSQL string, selectArgs []driver.NamedValue) (driver.Rows, error) { + var queryer driver.Queryer + + queryerContext, ok := u.execContext.Conn.(driver.QueryerContext) + if !ok { + queryer, ok = u.execContext.Conn.(driver.Queryer) + } + if ok { + var err error + rows, err = util.CtxDriverQuery(ctx, queryerContext, queryer, selectSQL, selectArgs) + + if err != nil { + return nil, err + } + } else { + return nil, errors.New("target conn should been driver.QueryerContext or driver.Queryer") + } + return rows, nil +} diff --git a/pkg/datasource/sql/exec/at/update_executor.go b/pkg/datasource/sql/exec/at/update_executor.go index 0ac9b0498..3463ba895 100644 --- a/pkg/datasource/sql/exec/at/update_executor.go +++ b/pkg/datasource/sql/exec/at/update_executor.go @@ -20,6 +20,7 @@ package at import ( "context" "database/sql/driver" + "errors" "fmt" "strings" @@ -31,7 +32,6 @@ import ( "seata.apache.org/seata-go/pkg/datasource/sql/exec" "seata.apache.org/seata-go/pkg/datasource/sql/types" "seata.apache.org/seata-go/pkg/datasource/sql/undo" - "seata.apache.org/seata-go/pkg/datasource/sql/util" "seata.apache.org/seata-go/pkg/util/bytes" "seata.apache.org/seata-go/pkg/util/log" ) @@ -90,37 +90,31 @@ func (u *updateExecutor) beforeImage(ctx context.Context) (*types.RecordImage, e return nil, nil } - selectSQL, selectArgs, err := u.buildBeforeImageSQL(ctx, u.execContext.NamedValues) + tableName, _ := u.parserCtx.GetTableName() + metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tableName) if err != nil { return nil, err } - tableName, _ := u.parserCtx.GetTableName() - metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tableName) + selectSQL, selectArgs, err := u.buildBeforeImageSQL(ctx, metaData, u.execContext.NamedValues) if err != nil { return nil, err } - - var rowsi driver.Rows - queryerCtx, ok := u.execContext.Conn.(driver.QueryerContext) - var queryer driver.Queryer - if !ok { - queryer, ok = u.execContext.Conn.(driver.Queryer) + if selectSQL == "" { + return nil, errors.New("build select sql by update sourceQuery fail") } - if ok { - rowsi, err = util.CtxDriverQuery(ctx, queryerCtx, queryer, selectSQL, selectArgs) - defer func() { - if rowsi != nil { - rowsi.Close() + + rowsi, err := u.rowsPrepare(ctx, selectSQL, selectArgs) + defer func() { + if rowsi != nil { + if err := rowsi.Close(); err != nil { + log.Errorf("rows close fail, err:%v", err) + return } - }() - if err != nil { - log.Errorf("ctx driver query: %+v", err) - return nil, err } - } else { - log.Errorf("target conn should been driver.QueryerContext or driver.Queryer") - return nil, fmt.Errorf("invalid conn") + }() + if err != nil { + return nil, err } image, err := u.buildRecordImages(rowsi, metaData, types.SQLTypeUpdate) @@ -151,26 +145,17 @@ func (u *updateExecutor) afterImage(ctx context.Context, beforeImage types.Recor } selectSQL, selectArgs := u.buildAfterImageSQL(beforeImage, metaData) - var rowsi driver.Rows - queryerCtx, ok := u.execContext.Conn.(driver.QueryerContext) - var queryer driver.Queryer - if !ok { - queryer, ok = u.execContext.Conn.(driver.Queryer) - } - if ok { - rowsi, err = util.CtxDriverQuery(ctx, queryerCtx, queryer, selectSQL, selectArgs) - defer func() { - if rowsi != nil { - rowsi.Close() + rowsi, err := u.rowsPrepare(ctx, selectSQL, selectArgs) + defer func() { + if rowsi != nil { + if err := rowsi.Close(); err != nil { + log.Errorf("rows close fail, err:%v", err) + return } - }() - if err != nil { - log.Errorf("ctx driver query: %+v", err) - return nil, err } - } else { - log.Errorf("target conn should been driver.QueryerContext or driver.Queryer") - return nil, fmt.Errorf("invalid conn") + }() + if err != nil { + return nil, err } afterImage, err := u.buildRecordImages(rowsi, metaData, types.SQLTypeUpdate) @@ -212,17 +197,54 @@ func (u *updateExecutor) buildAfterImageSQL(beforeImage types.RecordImage, meta } // buildAfterImageSQL build the SQL to query before image data -func (u *updateExecutor) buildBeforeImageSQL(ctx context.Context, args []driver.NamedValue) (string, []driver.NamedValue, error) { +func (u *updateExecutor) buildBeforeImageSQL(ctx context.Context, tableMeta *types.TableMeta, args []driver.NamedValue) (string, []driver.NamedValue, error) { if !u.isAstStmtValid() { log.Errorf("invalid update stmt") return "", nil, fmt.Errorf("invalid update stmt") } + updateStmt := u.parserCtx.UpdateStmt + fields, err := u.buildSelectFields(ctx, tableMeta) + if err != nil { + return "", nil, err + } + if len(fields) == 0 { + return "", nil, err + } + + selStmt := ast.SelectStmt{ + SelectStmtOpts: &ast.SelectStmtOpts{}, + From: updateStmt.TableRefs, + Where: updateStmt.Where, + Fields: &ast.FieldList{Fields: fields}, + OrderBy: updateStmt.Order, + Limit: updateStmt.Limit, + TableHints: updateStmt.TableHints, + LockInfo: &ast.SelectLockInfo{ + LockType: ast.SelectLockForUpdate, + }, + } + + b := bytes.NewByteBuffer([]byte{}) + _ = selStmt.Restore(format.NewRestoreCtx(format.RestoreKeyWordUppercase, b)) + sql := string(b.Bytes()) + log.Infof("build select sql by update sourceQuery, sql {%s}", sql) + + return sql, u.buildSelectArgs(&selStmt, args), nil +} + +func (u *updateExecutor) buildSelectFields(ctx context.Context, tableMeta *types.TableMeta) ([]*ast.SelectField, error) { updateStmt := u.parserCtx.UpdateStmt fields := make([]*ast.SelectField, 0, len(updateStmt.List)) + lowerTableName := strings.ToLower(tableMeta.TableName) if undo.UndoConfig.OnlyCareUpdateColumns { for _, column := range updateStmt.List { + tableName := column.Column.Table.L + if tableName != "" && lowerTableName != tableName { + continue + } + fields = append(fields, &ast.SelectField{ Expr: &ast.ColumnNameExpr{ Name: column.Column, @@ -230,16 +252,19 @@ func (u *updateExecutor) buildBeforeImageSQL(ctx context.Context, args []driver. }) } - // select indexes columns - tableName, _ := u.parserCtx.GetTableName() - metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tableName) - if err != nil { - return "", nil, err + if len(fields) == 0 { + return fields, nil } - for _, columnName := range metaData.GetPrimaryKeyOnlyName() { + + // select indexes columns + for _, columnName := range tableMeta.GetPrimaryKeyOnlyName() { fields = append(fields, &ast.SelectField{ Expr: &ast.ColumnNameExpr{ Name: &ast.ColumnName{ + Table: model.CIStr{ + O: tableMeta.TableName, + L: lowerTableName, + }, Name: model.CIStr{ O: columnName, L: columnName, @@ -261,23 +286,5 @@ func (u *updateExecutor) buildBeforeImageSQL(ctx context.Context, args []driver. }) } - selStmt := ast.SelectStmt{ - SelectStmtOpts: &ast.SelectStmtOpts{}, - From: updateStmt.TableRefs, - Where: updateStmt.Where, - Fields: &ast.FieldList{Fields: fields}, - OrderBy: updateStmt.Order, - Limit: updateStmt.Limit, - TableHints: updateStmt.TableHints, - LockInfo: &ast.SelectLockInfo{ - LockType: ast.SelectLockForUpdate, - }, - } - - b := bytes.NewByteBuffer([]byte{}) - _ = selStmt.Restore(format.NewRestoreCtx(format.RestoreKeyWordUppercase, b)) - sql := string(b.Bytes()) - log.Infof("build select sql by update sourceQuery, sql {%s}", sql) - - return sql, u.buildSelectArgs(&selStmt, args), nil + return fields, nil } diff --git a/pkg/datasource/sql/exec/at/update_executor_test.go b/pkg/datasource/sql/exec/at/update_executor_test.go index 770f21688..c3efb1eb2 100644 --- a/pkg/datasource/sql/exec/at/update_executor_test.go +++ b/pkg/datasource/sql/exec/at/update_executor_test.go @@ -20,14 +20,11 @@ package at import ( "context" "database/sql/driver" - "reflect" + "os" "testing" - "github.com/agiledragon/gomonkey/v2" "github.com/stretchr/testify/assert" - "seata.apache.org/seata-go/pkg/datasource/sql/datasource" - "seata.apache.org/seata-go/pkg/datasource/sql/datasource/mysql" "seata.apache.org/seata-go/pkg/datasource/sql/exec" "seata.apache.org/seata-go/pkg/datasource/sql/parser" "seata.apache.org/seata-go/pkg/datasource/sql/types" @@ -36,23 +33,156 @@ import ( _ "seata.apache.org/seata-go/pkg/util/log" ) -func TestBuildSelectSQLByUpdate(t *testing.T) { - undo.InitUndoConfig(undo.Config{OnlyCareUpdateColumns: true}) - datasource.RegisterTableCache(types.DBTypeMySQL, mysql.NewTableMetaInstance(nil)) - stub := gomonkey.ApplyMethod(reflect.TypeOf(datasource.GetTableCache(types.DBTypeMySQL)), "GetTableMeta", - func(_ *mysql.TableMetaCache, ctx context.Context, dbName, tableName string) (*types.TableMeta, error) { - return &types.TableMeta{ - Indexs: map[string]types.IndexMeta{ - "id": { - IType: types.IndexTypePrimaryKey, - Columns: []types.ColumnMeta{ - {ColumnName: "id"}, - }, +var ( + MetaDataMap map[string]*types.TableMeta +) + +func initTest() { + MetaDataMap = map[string]*types.TableMeta{ + "t_user": { + TableName: "t_user", + Indexs: map[string]types.IndexMeta{ + "id": { + IType: types.IndexTypePrimaryKey, + Columns: []types.ColumnMeta{ + {ColumnName: "id"}, }, }, - }, nil - }) - defer stub.Reset() + }, + Columns: map[string]types.ColumnMeta{ + "id": { + ColumnDef: nil, + ColumnName: "id", + }, + "name": { + ColumnDef: nil, + ColumnName: "name", + }, + "age": { + ColumnDef: nil, + ColumnName: "age", + }, + }, + ColumnNames: []string{"id", "name", "age"}, + }, + "t1": { + TableName: "t1", + Indexs: map[string]types.IndexMeta{ + "id": { + IType: types.IndexTypePrimaryKey, + Columns: []types.ColumnMeta{ + {ColumnName: "id"}, + }, + }, + }, + Columns: map[string]types.ColumnMeta{ + "id": { + ColumnDef: nil, + ColumnName: "id", + }, + "name": { + ColumnDef: nil, + ColumnName: "name", + }, + "age": { + ColumnDef: nil, + ColumnName: "age", + }, + }, + ColumnNames: []string{"id", "name", "age"}, + }, + "t2": { + TableName: "t2", + Indexs: map[string]types.IndexMeta{ + "id": { + IType: types.IndexTypePrimaryKey, + Columns: []types.ColumnMeta{ + {ColumnName: "id"}, + }, + }, + }, + Columns: map[string]types.ColumnMeta{ + "id": { + ColumnDef: nil, + ColumnName: "id", + }, + "name": { + ColumnDef: nil, + ColumnName: "name", + }, + "age": { + ColumnDef: nil, + ColumnName: "age", + }, + "kk": { + ColumnDef: nil, + ColumnName: "kk", + }, + "addr": { + ColumnDef: nil, + ColumnName: "addr", + }, + }, + ColumnNames: []string{"id", "name", "age", "kk", "addr"}, + }, + "t3": { + TableName: "t3", + Indexs: map[string]types.IndexMeta{ + "id": { + IType: types.IndexTypePrimaryKey, + Columns: []types.ColumnMeta{ + {ColumnName: "id"}, + }, + }, + }, + Columns: map[string]types.ColumnMeta{ + "id": { + ColumnDef: nil, + ColumnName: "id", + }, + "age": { + ColumnDef: nil, + ColumnName: "age", + }, + }, + ColumnNames: []string{"id", "age"}, + }, + "t4": { + TableName: "t4", + Indexs: map[string]types.IndexMeta{ + "id": { + IType: types.IndexTypePrimaryKey, + Columns: []types.ColumnMeta{ + {ColumnName: "id"}, + }, + }, + }, + Columns: map[string]types.ColumnMeta{ + "id": { + ColumnDef: nil, + ColumnName: "id", + }, + "age": { + ColumnDef: nil, + ColumnName: "age", + }, + }, + ColumnNames: []string{"id", "age"}, + }, + } + + undo.InitUndoConfig(undo.Config{OnlyCareUpdateColumns: true}) +} + +func TestMain(m *testing.M) { + // 调用初始化函数 + initTest() + + // 启动测试 + os.Exit(m.Run()) +} + +func TestBuildSelectSQLByUpdate(t *testing.T) { tests := []struct { name string @@ -64,25 +194,25 @@ func TestBuildSelectSQLByUpdate(t *testing.T) { { sourceQuery: "update t_user set name = ?, age = ? where id = ?", sourceQueryArgs: []driver.Value{"Jack", 1, 100}, - expectQuery: "SELECT SQL_NO_CACHE name,age,id FROM t_user WHERE id=? FOR UPDATE", + expectQuery: "SELECT SQL_NO_CACHE name,age,t_user.id FROM t_user WHERE id=? FOR UPDATE", expectQueryArgs: []driver.Value{100}, }, { sourceQuery: "update t_user set name = ?, age = ? where id = ? and name = 'Jack' and age between ? and ?", sourceQueryArgs: []driver.Value{"Jack", 1, 100, 18, 28}, - expectQuery: "SELECT SQL_NO_CACHE name,age,id FROM t_user WHERE id=? AND name=_UTF8MB4Jack AND age BETWEEN ? AND ? FOR UPDATE", + expectQuery: "SELECT SQL_NO_CACHE name,age,t_user.id FROM t_user WHERE id=? AND name=_UTF8MB4Jack AND age BETWEEN ? AND ? FOR UPDATE", expectQueryArgs: []driver.Value{100, 18, 28}, }, { sourceQuery: "update t_user set name = ?, age = ? where id = ? and name = 'Jack' and age in (?,?)", sourceQueryArgs: []driver.Value{"Jack", 1, 100, 18, 28}, - expectQuery: "SELECT SQL_NO_CACHE name,age,id FROM t_user WHERE id=? AND name=_UTF8MB4Jack AND age IN (?,?) FOR UPDATE", + expectQuery: "SELECT SQL_NO_CACHE name,age,t_user.id FROM t_user WHERE id=? AND name=_UTF8MB4Jack AND age IN (?,?) FOR UPDATE", expectQueryArgs: []driver.Value{100, 18, 28}, }, { sourceQuery: "update t_user set name = ?, age = ? where kk between ? and ? and id = ? and addr in(?,?) and age > ? order by name desc limit ?", sourceQueryArgs: []driver.Value{"Jack", 1, 10, 20, 17, "Beijing", "Guangzhou", 18, 2}, - expectQuery: "SELECT SQL_NO_CACHE name,age,id FROM t_user WHERE kk BETWEEN ? AND ? AND id=? AND addr IN (?,?) AND age>? ORDER BY name DESC LIMIT ? FOR UPDATE", + expectQuery: "SELECT SQL_NO_CACHE name,age,t_user.id FROM t_user WHERE kk BETWEEN ? AND ? AND id=? AND addr IN (?,?) AND age>? ORDER BY name DESC LIMIT ? FOR UPDATE", expectQueryArgs: []driver.Value{10, 20, 17, "Beijing", "Guangzhou", 18, 2}, }, } @@ -91,7 +221,7 @@ func TestBuildSelectSQLByUpdate(t *testing.T) { c, err := parser.DoParser(tt.sourceQuery) assert.Nil(t, err) executor := NewUpdateExecutor(c, &types.ExecContext{Values: tt.sourceQueryArgs, NamedValues: util.ValueToNamedValue(tt.sourceQueryArgs)}, []exec.SQLHook{}) - query, args, err := executor.(*updateExecutor).buildBeforeImageSQL(context.Background(), util.ValueToNamedValue(tt.sourceQueryArgs)) + query, args, err := executor.(*updateExecutor).buildBeforeImageSQL(context.Background(), MetaDataMap["t_user"], util.ValueToNamedValue(tt.sourceQueryArgs)) assert.Nil(t, err) assert.Equal(t, tt.expectQuery, query) assert.Equal(t, tt.expectQueryArgs, util.NamedValueToValue(args)) diff --git a/pkg/datasource/sql/exec/at/update_join_executor.go b/pkg/datasource/sql/exec/at/update_join_executor.go index 971b8aa68..7658626a4 100644 --- a/pkg/datasource/sql/exec/at/update_join_executor.go +++ b/pkg/datasource/sql/exec/at/update_join_executor.go @@ -20,9 +20,10 @@ package at import ( "context" "database/sql/driver" + "errors" "fmt" - "seata.apache.org/seata-go/pkg/datasource/sql" - "seata.apache.org/seata-go/pkg/protocol/branch" + "io" + "reflect" "strings" "github.com/arana-db/parser/ast" @@ -32,36 +33,24 @@ import ( "seata.apache.org/seata-go/pkg/datasource/sql/datasource" "seata.apache.org/seata-go/pkg/datasource/sql/exec" "seata.apache.org/seata-go/pkg/datasource/sql/types" - "seata.apache.org/seata-go/pkg/datasource/sql/undo" "seata.apache.org/seata-go/pkg/datasource/sql/util" "seata.apache.org/seata-go/pkg/util/bytes" "seata.apache.org/seata-go/pkg/util/log" ) -const ( - multi_table_name_seperaror = "#" -) - -var ( - beforeImagesMap map[string]*types.RecordImage - afterImagesMap map[string]*types.RecordImage -) - // updateJoinExecutor execute update SQL type updateJoinExecutor struct { updateExecutor parserCtx *types.ParseContext execContext *types.ExecContext isLowerSupportGroupByPksVersion bool + sqlMode string } // NewUpdateJoinExecutor get executor func NewUpdateJoinExecutor(parserCtx *types.ParseContext, execContent *types.ExecContext, hooks []exec.SQLHook) executor { - // todo 不确定,需要test一下 - val, _ := datasource.GetDataSourceManager(branch.BranchTypeAT).GetCachedResources().Load(execContent.TxCtx.ResourceID) - res := val.(*sql.DBResource) minimumVersion, _ := util.ConvertDbVersion("5.7.5") - currentVersion, _ := util.ConvertDbVersion(res.GetDbVersion()) + currentVersion, _ := util.ConvertDbVersion(execContent.DbVersion) return &updateJoinExecutor{ parserCtx: parserCtx, execContext: execContent, @@ -70,189 +59,154 @@ func NewUpdateJoinExecutor(parserCtx *types.ParseContext, execContent *types.Exe } } -// beforeImage build before image -func (u *updateJoinExecutor) beforeImage(ctx context.Context) (*types.RecordImage, error) { +// ExecContext exec SQL, and generate before image and after image +func (u *updateJoinExecutor) ExecContext(ctx context.Context, f exec.CallbackWithNamedValue) (types.ExecResult, error) { + u.beforeHooks(ctx, u.execContext) + defer func() { + u.afterHooks(ctx, u.execContext) + }() + + beforeImages, err := u.beforeImage(ctx) + if err != nil { + return nil, err + } + + res, err := f(ctx, u.execContext.Query, u.execContext.NamedValues) + if err != nil { + return nil, err + } + + afterImages, err := u.afterImage(ctx, beforeImages) + if err != nil { + return nil, err + } + + if len(afterImages) != len(beforeImages) { + return nil, errors.New("Before image size is not equaled to after image size, probably because you updated the primary keys.") + } + + for i, afterImage := range afterImages { + beforeImage := afterImages[i] + if len(beforeImage.Rows) != len(afterImage.Rows) { + return nil, errors.New("Before image size is not equaled to after image size, probably because you updated the primary keys.") + } + + u.execContext.TxCtx.RoundImages.AppendBeofreImage(beforeImage) + u.execContext.TxCtx.RoundImages.AppendAfterImage(afterImage) + } + + return res, nil +} + +func (u *updateJoinExecutor) beforeImage(ctx context.Context) ([]*types.RecordImage, error) { if !u.isAstStmtValid() { return nil, nil } - // update join sql,like update t1 inner join t2 on t1.id = t2.id set t1.name = ?; tableItems = {"update t1 inner join t2","t1","t2"} - tableName, _ := u.parserCtx.GetTableName() + var recordImages []*types.RecordImage - tableItems := strings.Split(tableName, multi_table_name_seperaror) + // Parsing multiple table name + updateStmt := u.parserCtx.UpdateStmt + tableNames := u.parseTableName(updateStmt.TableRefs.TableRefs) - u.buildWhereConditionByPKs(u.execContext.MetaDataMap, u.execContext.DBType) - suffixCommonCondition, paramAppenderList := u.buildBeforeImageSQLCommonConditionSuffix() - for i := range tableItems { - metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tableNames[i]) + for _, tbName := range tableNames { + metaData, err := datasource.GetTableCache(u.execContext.DBType).GetTableMeta(ctx, u.execContext.DBName, tbName) if err != nil { return nil, err } - - image, err := u.buildRecordImages(rowsi, metaData, types.SQLTypeUpdate) + selectSQL, selectArgs, err := u.buildBeforeImageSQL(ctx, metaData, u.execContext.NamedValues) if err != nil { return nil, err } - } - - selectSQL, selectArgs, err := u.buildBeforeImageSQL(ctx, u.execContext.NamedValues) - if err != nil { - return nil, err - } + if selectSQL == "" { + log.Debugf("Skip unused table [{%s}] when build select sql by update sourceQuery", tbName) + continue + } - var rowsi driver.Rows - queryerCtx, ok := u.execContext.Conn.(driver.QueryerContext) - var queryer driver.Queryer - if !ok { - queryer, ok = u.execContext.Conn.(driver.Queryer) - } - if ok { - rowsi, err = util.CtxDriverQuery(ctx, queryerCtx, queryer, selectSQL, selectArgs) - defer func() { - if rowsi != nil { - rowsi.Close() + var image *types.RecordImage + rowsi, err := u.rowsPrepare(ctx, selectSQL, selectArgs) + if err == nil { + image, err = u.buildRecordImages(rowsi, metaData, types.SQLTypeUpdateJoin) + } + if rowsi != nil { + if rowerr := rows.Close(); rowerr != nil { + log.Errorf("rows close fail, err:%v", rowerr) + return nil, rowerr } - }() + } if err != nil { - log.Errorf("ctx driver query: %+v", err) + // If one fail, all fails return nil, err } - } else { - log.Errorf("target conn should been driver.QueryerContext or driver.Queryer") - return nil, fmt.Errorf("invalid conn") - } - lockKey := u.buildLockKey(image, *metaData) - u.execContext.TxCtx.LockKeys[lockKey] = struct{}{} - image.SQLType = u.parserCtx.SQLType + lockKey := u.buildLockKey(image, *metaData) + u.execContext.TxCtx.LockKeys[lockKey] = struct{}{} + image.SQLType = u.parserCtx.SQLType - return image, nil + recordImages = append(recordImages, image) + } + + return recordImages, nil } -// afterImage build after image -func (u *updateJoinExecutor) afterImage(ctx context.Context, beforeImage types.RecordImage) (*types.RecordImage, error) { +func (u *updateJoinExecutor) afterImage(ctx context.Context, beforeImages []*types.RecordImage) ([]*types.RecordImage, error) { if !u.isAstStmtValid() { return nil, nil } - if len(beforeImage.Rows) == 0 { - return &types.RecordImage{}, nil - } - tableName, _ := u.parserCtx.GetTableName() - metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tableName) - if err != nil { - return nil, err + if len(beforeImages) == 0 { + return nil, errors.New("empty beforeImages") } - selectSQL, selectArgs := u.buildAfterImageSQL(beforeImage, metaData) - var rowsi driver.Rows - queryerCtx, ok := u.execContext.Conn.(driver.QueryerContext) - var queryer driver.Queryer - if !ok { - queryer, ok = u.execContext.Conn.(driver.Queryer) - } - if ok { - rowsi, err = util.CtxDriverQuery(ctx, queryerCtx, queryer, selectSQL, selectArgs) - defer func() { - if rowsi != nil { - rowsi.Close() - } - }() + var recordImages []*types.RecordImage + for _, beforeImage := range beforeImages { + metaData, err := datasource.GetTableCache(u.execContext.DBType).GetTableMeta(ctx, u.execContext.DBName, beforeImage.TableName) if err != nil { - log.Errorf("ctx driver query: %+v", err) return nil, err } - } else { - log.Errorf("target conn should been driver.QueryerContext or driver.Queryer") - return nil, fmt.Errorf("invalid conn") - } - - afterImage, err := u.buildRecordImages(rowsi, metaData, types.SQLTypeUpdate) - if err != nil { - return nil, err - } - afterImage.SQLType = u.parserCtx.SQLType - - return afterImage, nil -} -func (u *updateJoinExecutor) isAstStmtValid() bool { - return u.parserCtx != nil && u.parserCtx.UpdateStmt != nil -} + selectSQL, selectArgs, err := u.buildAfterImageSQL(ctx, *beforeImage, metaData) + if err != nil { + return nil, err + } -// buildAfterImageSQL build the SQL to query after image data -func (u *updateJoinExecutor) buildAfterImageSQL(beforeImage types.RecordImage, meta *types.TableMeta) (string, []driver.NamedValue) { - if len(beforeImage.Rows) == 0 { - return "", nil - } - sb := strings.Builder{} - // todo: OnlyCareUpdateColumns should load from config first - var selectFields string - var separator = "," - if undo.UndoConfig.OnlyCareUpdateColumns { - for _, row := range beforeImage.Rows { - for _, column := range row.Columns { - selectFields += column.ColumnName + separator + var image *types.RecordImage + rowsi, err := u.rowsPrepare(ctx, selectSQL, selectArgs) + if err == nil { + image, err = u.buildRecordImages(rowsi, metaData, types.SQLTypeUpdateJoin) + } + if rowsi != nil { + if rowerr := rowsi.Close(); rowerr != nil { + log.Errorf("rows close fail, err:%v", rowerr) + return nil, rowerr } } - selectFields = strings.TrimSuffix(selectFields, separator) - } else { - selectFields = "*" + if err != nil { + // If one fail, all fails + return nil, err + } + + image.SQLType = u.parserCtx.SQLType + recordImages = append(recordImages, image) } - sb.WriteString("SELECT " + selectFields + " FROM " + meta.TableName + " WHERE ") - whereSQL := u.buildWhereConditionByPKs(meta.GetPrimaryKeyOnlyName(), len(beforeImage.Rows), "mysql", maxInSize) - sb.WriteString(" " + whereSQL + " ") - return sb.String(), u.buildPKParams(beforeImage.Rows, meta.GetPrimaryKeyOnlyName()) + + return recordImages, nil } // buildAfterImageSQL build the SQL to query before image data -func (u *updateJoinExecutor) buildBeforeImageSQL(ctx context.Context, args []driver.NamedValue) (string, []driver.NamedValue, error) { +func (u *updateJoinExecutor) buildBeforeImageSQL(ctx context.Context, tableMeta *types.TableMeta, args []driver.NamedValue) (string, []driver.NamedValue, error) { if !u.isAstStmtValid() { - log.Errorf("invalid update stmt") - return "", nil, fmt.Errorf("invalid update stmt") + log.Errorf("invalid update join stmt") + return "", nil, fmt.Errorf("invalid update join stmt") } updateStmt := u.parserCtx.UpdateStmt - fields := make([]*ast.SelectField, 0, len(updateStmt.List)) - - if undo.UndoConfig.OnlyCareUpdateColumns { - for _, column := range updateStmt.List { - fields = append(fields, &ast.SelectField{ - Expr: &ast.ColumnNameExpr{ - Name: column.Column, - }, - }) - } - - // select indexes columns - tableName, _ := u.parserCtx.GetTableName() - metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tableName) - if err != nil { - return "", nil, err - } - for _, columnName := range metaData.GetPrimaryKeyOnlyName() { - fields = append(fields, &ast.SelectField{ - Expr: &ast.ColumnNameExpr{ - Name: &ast.ColumnName{ - Name: model.CIStr{ - O: columnName, - L: columnName, - }, - }, - }, - }) - } - } else { - fields = append(fields, &ast.SelectField{ - Expr: &ast.ColumnNameExpr{ - Name: &ast.ColumnName{ - Name: model.CIStr{ - O: "*", - L: "*", - }, - }, - }, - }) + fields, err := u.buildSelectFields(ctx, tableMeta) + if err != nil { + return "", nil, err + } + if len(fields) == 0 { + return "", nil, err } selStmt := ast.SelectStmt{ @@ -263,6 +217,10 @@ func (u *updateJoinExecutor) buildBeforeImageSQL(ctx context.Context, args []dri OrderBy: updateStmt.Order, Limit: updateStmt.Limit, TableHints: updateStmt.TableHints, + // maybe duplicate row for select join sql.remove duplicate row by 'group by' condition + GroupBy: &ast.GroupByClause{ + Items: u.buildGroupByClause(ctx, tableMeta.TableName, tableMeta.GetPrimaryKeyOnlyName(), fields), + }, LockInfo: &ast.SelectLockInfo{ LockType: ast.SelectLockForUpdate, }, @@ -276,5 +234,102 @@ func (u *updateJoinExecutor) buildBeforeImageSQL(ctx context.Context, args []dri return sql, u.buildSelectArgs(&selStmt, args), nil } -func (u *updateJoinExecutor) getDbVersion() string { +func (u *updateJoinExecutor) buildAfterImageSQL(ctx context.Context, beforeImage types.RecordImage, meta *types.TableMeta) (string, []driver.NamedValue, error) { + selectSQL, selectArgs := u.updateExecutor.buildAfterImageSQL(beforeImage, meta) + + needUpdateColumns, err := u.buildSelectFields(ctx, meta) + if err != nil { + return "", nil, err + } + + // maybe duplicate row for select join sql.remove duplicate row by 'group by' condition + groupByStr := strings.Builder{} + groupByItem := u.buildGroupByClause(ctx, meta.TableName, meta.GetPrimaryKeyOnlyName(), needUpdateColumns) + + groupByStr.WriteString(selectSQL) + groupByStr.WriteString(" GROUP BY ") + for index, item := range groupByItem { + if index != 0 { + groupByStr.WriteString(",") + } + groupByStr.WriteString(item.Expr.(*ast.ColumnNameExpr).Name.String()) + } + + groupByStr.WriteString(" ") + return groupByStr.String(), selectArgs, nil +} + +func (u *updateJoinExecutor) parseTableName(joinMate *ast.Join) []string { + var tableNames []string + if item, ok := joinMate.Left.(*ast.Join); ok { + tableNames = u.parseTableName(item) + } else { + leftName := joinMate.Left.(*ast.TableSource).Source.(*ast.TableName) + tableNames = append(tableNames, leftName.Name.O) + } + + rightName := joinMate.Right.(*ast.TableSource).Source.(*ast.TableName) + tableNames = append(tableNames, rightName.Name.O) + return tableNames +} + +// build group by condition which used for removing duplicate row in select join sql +func (u *updateJoinExecutor) buildGroupByClause(ctx context.Context, tableName string, pkColumns []string, allSelectColumns []*ast.SelectField) []*ast.ByItem { + var groupByPks = true + //only pks group by is valid when db version >= 5.7.5 + if u.isLowerSupportGroupByPksVersion { + if u.sqlMode == "" { + rowsi, err := u.rowsPrepare(ctx, "SELECT @@SQL_MODE", nil) + defer func() { + if rowsi != nil { + if rowerr := rowsi.Close(); rowerr != nil { + log.Errorf("rows close fail, err:%v", rowerr) + } + } + }() + if err != nil { + groupByPks = false + log.Warnf("determine group by pks or all columns error:%s", err) + } else { + // getString("@@SQL_MODE") + mode := make([]driver.Value, 1) + if err = rowsi.Next(mode); err != nil { + if err != io.EOF && len(mode) == 1 { + u.sqlMode = reflect.ValueOf(mode[0]).String() + } + } + } + } + + if strings.Contains(u.sqlMode, "ONLY_FULL_GROUP_BY") { + groupByPks = false + } + } + + groupByColumns := make([]*ast.ByItem, 0) + if groupByPks { + for _, column := range pkColumns { + groupByColumns = append(groupByColumns, &ast.ByItem{ + Expr: &ast.ColumnNameExpr{ + Name: &ast.ColumnName{ + Table: model.CIStr{ + O: tableName, + L: strings.ToLower(tableName), + }, + Name: model.CIStr{ + O: column, + L: strings.ToLower(column), + }, + }, + }, + }) + } + } else { + for _, column := range allSelectColumns { + groupByColumns = append(groupByColumns, &ast.ByItem{ + Expr: column.Expr, + }) + } + } + return groupByColumns } diff --git a/pkg/datasource/sql/exec/at/update_join_executor_test.go b/pkg/datasource/sql/exec/at/update_join_executor_test.go new file mode 100644 index 000000000..80066660d --- /dev/null +++ b/pkg/datasource/sql/exec/at/update_join_executor_test.go @@ -0,0 +1,121 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. + */ + +package at + +import ( + "context" + "database/sql/driver" + "testing" + + "github.com/stretchr/testify/assert" + + "seata.apache.org/seata-go/pkg/datasource/sql/exec" + "seata.apache.org/seata-go/pkg/datasource/sql/parser" + "seata.apache.org/seata-go/pkg/datasource/sql/types" + "seata.apache.org/seata-go/pkg/datasource/sql/util" + _ "seata.apache.org/seata-go/pkg/util/log" +) + +func TestBuildSelectSQLByUpdateJoin(t *testing.T) { + tests := []struct { + name string + sourceQuery string + sourceQueryArgs []driver.Value + expectQuery map[string]string + expectQueryArgs []driver.Value + }{ + { + sourceQuery: "update t1 left join t2 on t1.id = t2.id inner join t3 on t3.id = t2.id right join t4 on t4.id = t2.id set t1.name = ?,t2.name = ? where t1.id=? and t3.age=? and t4.age>30", + sourceQueryArgs: []driver.Value{"Jack", "WILL", 1, 10}, + expectQuery: map[string]string{ + "t1": "SELECT SQL_NO_CACHE t1.name,t1.id FROM ((t1 LEFT JOIN t2 ON t1.id=t2.id) JOIN t3 ON t3.id=t2.id) RIGHT JOIN t4 ON t4.id=t2.id WHERE t1.id=? AND t3.age=? AND t4.age>30 GROUP BY t1.name,t1.id FOR UPDATE", + "t2": "SELECT SQL_NO_CACHE t2.name,t2.id FROM ((t1 LEFT JOIN t2 ON t1.id=t2.id) JOIN t3 ON t3.id=t2.id) RIGHT JOIN t4 ON t4.id=t2.id WHERE t1.id=? AND t3.age=? AND t4.age>30 GROUP BY t2.name,t2.id FOR UPDATE", + }, + expectQueryArgs: []driver.Value{1, 10}, + }, + { + sourceQuery: "update t1 left join t2 on t1.id = t2.id set t1.name = 'WILL',t2.name = 'WILL'", + sourceQueryArgs: []driver.Value{}, + expectQuery: map[string]string{ + "t1": "SELECT SQL_NO_CACHE t1.name,t1.id FROM t1 LEFT JOIN t2 ON t1.id=t2.id GROUP BY t1.name,t1.id FOR UPDATE", + "t2": "SELECT SQL_NO_CACHE t2.name,t2.id FROM t1 LEFT JOIN t2 ON t1.id=t2.id GROUP BY t2.name,t2.id FOR UPDATE", + }, + expectQueryArgs: []driver.Value{}, + }, + { + sourceQuery: "update t1 inner join t2 on t1.id = t2.id set t1.name = 'WILL',t2.name = 'WILL' where t1.id=?", + sourceQueryArgs: []driver.Value{1}, + expectQuery: map[string]string{ + "t1": "SELECT SQL_NO_CACHE t1.name,t1.id FROM t1 JOIN t2 ON t1.id=t2.id WHERE t1.id=? GROUP BY t1.name,t1.id FOR UPDATE", + "t2": "SELECT SQL_NO_CACHE t2.name,t2.id FROM t1 JOIN t2 ON t1.id=t2.id WHERE t1.id=? GROUP BY t2.name,t2.id FOR UPDATE", + }, + expectQueryArgs: []driver.Value{1}, + }, + { + sourceQuery: "update t1 right join t2 on t1.id = t2.id set t1.name = 'WILL',t2.name = 'WILL' where t1.id=?", + sourceQueryArgs: []driver.Value{1}, + expectQuery: map[string]string{ + "t1": "SELECT SQL_NO_CACHE t1.name,t1.id FROM t1 RIGHT JOIN t2 ON t1.id=t2.id WHERE t1.id=? GROUP BY t1.name,t1.id FOR UPDATE", + "t2": "SELECT SQL_NO_CACHE t2.name,t2.id FROM t1 RIGHT JOIN t2 ON t1.id=t2.id WHERE t1.id=? GROUP BY t2.name,t2.id FOR UPDATE", + }, + expectQueryArgs: []driver.Value{1}, + }, + { + sourceQuery: "update t1 inner join t2 on t1.id = t2.id set t1.name = ?, t1.age = ? where t1.id = ? and t1.name = ? and t2.age between ? and ?", + sourceQueryArgs: []driver.Value{"newJack", 38, 1, "Jack", 18, 28}, + expectQuery: map[string]string{ + "t1": "SELECT SQL_NO_CACHE t1.name,t1.age,t1.id FROM t1 JOIN t2 ON t1.id=t2.id WHERE t1.id=? AND t1.name=? AND t2.age BETWEEN ? AND ? GROUP BY t1.name,t1.age,t1.id FOR UPDATE", + }, + expectQueryArgs: []driver.Value{1, "Jack", 18, 28}, + }, + { + sourceQuery: "update t1 left join t2 on t1.id = t2.id set t1.name = ?, t1.age = ? where t1.id=? and t2.id is null and t1.age IN (?,?)", + sourceQueryArgs: []driver.Value{"newJack", 38, 1, 18, 28}, + expectQuery: map[string]string{ + "t1": "SELECT SQL_NO_CACHE t1.name,t1.age,t1.id FROM t1 LEFT JOIN t2 ON t1.id=t2.id WHERE t1.id=? AND t2.id IS NULL AND t1.age IN (?,?) GROUP BY t1.name,t1.age,t1.id FOR UPDATE", + }, + expectQueryArgs: []driver.Value{1, 18, 28}, + }, + { + sourceQuery: "update t1 inner join t2 on t1.id = t2.id set t1.name = ?, t2.age = ? where t2.kk between ? and ? and t2.addr in(?,?) and t2.age > ? order by t1.name desc limit ?", + sourceQueryArgs: []driver.Value{"Jack", 18, 10, 20, "Beijing", "Guangzhou", 18, 2}, + expectQuery: map[string]string{ + "t1": "SELECT SQL_NO_CACHE t1.name,t1.id FROM t1 JOIN t2 ON t1.id=t2.id WHERE t2.kk BETWEEN ? AND ? AND t2.addr IN (?,?) AND t2.age>? GROUP BY t1.name,t1.id ORDER BY t1.name DESC LIMIT ? FOR UPDATE", + "t2": "SELECT SQL_NO_CACHE t2.age,t2.id FROM t1 JOIN t2 ON t1.id=t2.id WHERE t2.kk BETWEEN ? AND ? AND t2.addr IN (?,?) AND t2.age>? GROUP BY t2.age,t2.id ORDER BY t1.name DESC LIMIT ? FOR UPDATE", + }, + expectQueryArgs: []driver.Value{10, 20, "Beijing", "Guangzhou", 18, 2}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, err := parser.DoParser(tt.sourceQuery) + assert.Nil(t, err) + executor := NewUpdateJoinExecutor(c, &types.ExecContext{Values: tt.sourceQueryArgs, NamedValues: util.ValueToNamedValue(tt.sourceQueryArgs)}, []exec.SQLHook{}) + tableNames := executor.(*updateJoinExecutor).parseTableName(c.UpdateStmt.TableRefs.TableRefs) + for _, tbName := range tableNames { + query, args, err := executor.(*updateJoinExecutor).buildBeforeImageSQL(context.Background(), MetaDataMap[tbName], util.ValueToNamedValue(tt.sourceQueryArgs)) + assert.Nil(t, err) + if query == "" { + continue + } + assert.Equal(t, tt.expectQuery[tbName], query) + assert.Equal(t, tt.expectQueryArgs, util.NamedValueToValue(args)) + } + }) + } +} diff --git a/pkg/datasource/sql/exec/at/update_join_executor_test.go.go b/pkg/datasource/sql/exec/at/update_join_executor_test.go.go deleted file mode 100644 index 770f21688..000000000 --- a/pkg/datasource/sql/exec/at/update_join_executor_test.go.go +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package at - -import ( - "context" - "database/sql/driver" - "reflect" - "testing" - - "github.com/agiledragon/gomonkey/v2" - "github.com/stretchr/testify/assert" - - "seata.apache.org/seata-go/pkg/datasource/sql/datasource" - "seata.apache.org/seata-go/pkg/datasource/sql/datasource/mysql" - "seata.apache.org/seata-go/pkg/datasource/sql/exec" - "seata.apache.org/seata-go/pkg/datasource/sql/parser" - "seata.apache.org/seata-go/pkg/datasource/sql/types" - "seata.apache.org/seata-go/pkg/datasource/sql/undo" - "seata.apache.org/seata-go/pkg/datasource/sql/util" - _ "seata.apache.org/seata-go/pkg/util/log" -) - -func TestBuildSelectSQLByUpdate(t *testing.T) { - undo.InitUndoConfig(undo.Config{OnlyCareUpdateColumns: true}) - datasource.RegisterTableCache(types.DBTypeMySQL, mysql.NewTableMetaInstance(nil)) - stub := gomonkey.ApplyMethod(reflect.TypeOf(datasource.GetTableCache(types.DBTypeMySQL)), "GetTableMeta", - func(_ *mysql.TableMetaCache, ctx context.Context, dbName, tableName string) (*types.TableMeta, error) { - return &types.TableMeta{ - Indexs: map[string]types.IndexMeta{ - "id": { - IType: types.IndexTypePrimaryKey, - Columns: []types.ColumnMeta{ - {ColumnName: "id"}, - }, - }, - }, - }, nil - }) - defer stub.Reset() - - tests := []struct { - name string - sourceQuery string - sourceQueryArgs []driver.Value - expectQuery string - expectQueryArgs []driver.Value - }{ - { - sourceQuery: "update t_user set name = ?, age = ? where id = ?", - sourceQueryArgs: []driver.Value{"Jack", 1, 100}, - expectQuery: "SELECT SQL_NO_CACHE name,age,id FROM t_user WHERE id=? FOR UPDATE", - expectQueryArgs: []driver.Value{100}, - }, - { - sourceQuery: "update t_user set name = ?, age = ? where id = ? and name = 'Jack' and age between ? and ?", - sourceQueryArgs: []driver.Value{"Jack", 1, 100, 18, 28}, - expectQuery: "SELECT SQL_NO_CACHE name,age,id FROM t_user WHERE id=? AND name=_UTF8MB4Jack AND age BETWEEN ? AND ? FOR UPDATE", - expectQueryArgs: []driver.Value{100, 18, 28}, - }, - { - sourceQuery: "update t_user set name = ?, age = ? where id = ? and name = 'Jack' and age in (?,?)", - sourceQueryArgs: []driver.Value{"Jack", 1, 100, 18, 28}, - expectQuery: "SELECT SQL_NO_CACHE name,age,id FROM t_user WHERE id=? AND name=_UTF8MB4Jack AND age IN (?,?) FOR UPDATE", - expectQueryArgs: []driver.Value{100, 18, 28}, - }, - { - sourceQuery: "update t_user set name = ?, age = ? where kk between ? and ? and id = ? and addr in(?,?) and age > ? order by name desc limit ?", - sourceQueryArgs: []driver.Value{"Jack", 1, 10, 20, 17, "Beijing", "Guangzhou", 18, 2}, - expectQuery: "SELECT SQL_NO_CACHE name,age,id FROM t_user WHERE kk BETWEEN ? AND ? AND id=? AND addr IN (?,?) AND age>? ORDER BY name DESC LIMIT ? FOR UPDATE", - expectQueryArgs: []driver.Value{10, 20, 17, "Beijing", "Guangzhou", 18, 2}, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - c, err := parser.DoParser(tt.sourceQuery) - assert.Nil(t, err) - executor := NewUpdateExecutor(c, &types.ExecContext{Values: tt.sourceQueryArgs, NamedValues: util.ValueToNamedValue(tt.sourceQueryArgs)}, []exec.SQLHook{}) - query, args, err := executor.(*updateExecutor).buildBeforeImageSQL(context.Background(), util.ValueToNamedValue(tt.sourceQueryArgs)) - assert.Nil(t, err) - assert.Equal(t, tt.expectQuery, query) - assert.Equal(t, tt.expectQueryArgs, util.NamedValueToValue(args)) - }) - } -} diff --git a/pkg/datasource/sql/types/types.go b/pkg/datasource/sql/types/types.go index 36b44a2e8..aa911285a 100644 --- a/pkg/datasource/sql/types/types.go +++ b/pkg/datasource/sql/types/types.go @@ -159,6 +159,7 @@ type ExecContext struct { Conn driver.Conn DBName string DBType DBType + DbVersion string // todo set values for these 4 param IsAutoCommit bool IsSupportsSavepoints bool From 63678ce34ecbb1f6ba8186283e8d2da8acda5f80 Mon Sep 17 00:00:00 2001 From: lxfeng1997 <824141436@qq.com> Date: Fri, 27 Dec 2024 22:28:30 +0800 Subject: [PATCH 3/8] update join condition placeholder param error --- pkg/datasource/sql/exec/at/base_executor.go | 12 ++++++++++++ .../sql/exec/at/update_join_executor_test.go | 10 +++++----- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/pkg/datasource/sql/exec/at/base_executor.go b/pkg/datasource/sql/exec/at/base_executor.go index 884438317..7583f28cf 100644 --- a/pkg/datasource/sql/exec/at/base_executor.go +++ b/pkg/datasource/sql/exec/at/base_executor.go @@ -98,7 +98,13 @@ func (b *baseExecutor) buildSelectArgs(stmt *ast.SelectStmt, args []driver.Named selectArgs = make([]driver.NamedValue, 0) ) + b.traversalArgs(stmt.From.TableRefs, &selectArgsIndexs) b.traversalArgs(stmt.Where, &selectArgsIndexs) + if stmt.GroupBy != nil { + for _, item := range stmt.GroupBy.Items { + b.traversalArgs(item, &selectArgsIndexs) + } + } if stmt.OrderBy != nil { for _, item := range stmt.OrderBy.Items { b.traversalArgs(item, &selectArgsIndexs) @@ -143,6 +149,12 @@ func (b *baseExecutor) traversalArgs(node ast.Node, argsIndex *[]int32) { b.traversalArgs(exprs[i], argsIndex) } break + case *ast.Join: + exprs := node.(*ast.Join) + b.traversalArgs(exprs.Left, argsIndex) + b.traversalArgs(exprs.Right, argsIndex) + b.traversalArgs(exprs.On.Expr, argsIndex) + break case *test_driver.ParamMarkerExpr: *argsIndex = append(*argsIndex, int32(node.(*test_driver.ParamMarkerExpr).Order)) break diff --git a/pkg/datasource/sql/exec/at/update_join_executor_test.go b/pkg/datasource/sql/exec/at/update_join_executor_test.go index 80066660d..810e76970 100644 --- a/pkg/datasource/sql/exec/at/update_join_executor_test.go +++ b/pkg/datasource/sql/exec/at/update_join_executor_test.go @@ -49,13 +49,13 @@ func TestBuildSelectSQLByUpdateJoin(t *testing.T) { expectQueryArgs: []driver.Value{1, 10}, }, { - sourceQuery: "update t1 left join t2 on t1.id = t2.id set t1.name = 'WILL',t2.name = 'WILL'", - sourceQueryArgs: []driver.Value{}, + sourceQuery: "update t1 left join t2 on t1.id = t2.id and t1.age=? set t1.name = 'WILL',t2.name = ?", + sourceQueryArgs: []driver.Value{18, "Jack"}, expectQuery: map[string]string{ - "t1": "SELECT SQL_NO_CACHE t1.name,t1.id FROM t1 LEFT JOIN t2 ON t1.id=t2.id GROUP BY t1.name,t1.id FOR UPDATE", - "t2": "SELECT SQL_NO_CACHE t2.name,t2.id FROM t1 LEFT JOIN t2 ON t1.id=t2.id GROUP BY t2.name,t2.id FOR UPDATE", + "t1": "SELECT SQL_NO_CACHE t1.name,t1.id FROM t1 LEFT JOIN t2 ON t1.id=t2.id AND t1.age=? GROUP BY t1.name,t1.id FOR UPDATE", + "t2": "SELECT SQL_NO_CACHE t2.name,t2.id FROM t1 LEFT JOIN t2 ON t1.id=t2.id AND t1.age=? GROUP BY t2.name,t2.id FOR UPDATE", }, - expectQueryArgs: []driver.Value{}, + expectQueryArgs: []driver.Value{18}, }, { sourceQuery: "update t1 inner join t2 on t1.id = t2.id set t1.name = 'WILL',t2.name = 'WILL' where t1.id=?", From 85de2385b646c9b4c6bf55e6e68a22269bc9dc78 Mon Sep 17 00:00:00 2001 From: lxfeng1997 <824141436@qq.com> Date: Fri, 10 Jan 2025 17:00:41 +0800 Subject: [PATCH 4/8] update join bugfix --- pkg/datasource/sql/exec/at/at_executor.go | 2 - pkg/datasource/sql/exec/at/base_executor.go | 76 ++++++++++- pkg/datasource/sql/exec/at/update_executor.go | 72 ++-------- .../sql/exec/at/update_executor_test.go | 18 +-- .../sql/exec/at/update_join_executor.go | 123 ++++++++++-------- .../sql/exec/at/update_join_executor_test.go | 114 ++++++++-------- pkg/datasource/sql/types/sql.go | 1 - .../sql/undo/factor/undo_executor_factory.go | 2 +- 8 files changed, 213 insertions(+), 195 deletions(-) diff --git a/pkg/datasource/sql/exec/at/at_executor.go b/pkg/datasource/sql/exec/at/at_executor.go index 09b284b34..e51db2284 100644 --- a/pkg/datasource/sql/exec/at/at_executor.go +++ b/pkg/datasource/sql/exec/at/at_executor.go @@ -68,8 +68,6 @@ func (e *ATExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.Exec executor = NewInsertOnUpdateExecutor(queryParser, execCtx, e.hooks) case types.SQLTypeMulti: executor = NewMultiExecutor(queryParser, execCtx, e.hooks) - case types.SQLTypeUpdateJoin: - executor = NewUpdateJoinExecutor(queryParser, execCtx, e.hooks) default: executor = NewPlainExecutor(queryParser, execCtx) } diff --git a/pkg/datasource/sql/exec/at/base_executor.go b/pkg/datasource/sql/exec/at/base_executor.go index 7583f28cf..5e2ac16f0 100644 --- a/pkg/datasource/sql/exec/at/base_executor.go +++ b/pkg/datasource/sql/exec/at/base_executor.go @@ -22,16 +22,18 @@ import ( "context" "database/sql" "database/sql/driver" - "errors" "fmt" "strings" "github.com/arana-db/parser/ast" + "github.com/arana-db/parser/model" "github.com/arana-db/parser/test_driver" gxsort "github.com/dubbogo/gost/sort" + "github.com/pkg/errors" "seata.apache.org/seata-go/pkg/datasource/sql/exec" "seata.apache.org/seata-go/pkg/datasource/sql/types" + "seata.apache.org/seata-go/pkg/datasource/sql/undo" "seata.apache.org/seata-go/pkg/datasource/sql/util" "seata.apache.org/seata-go/pkg/util/reflectx" ) @@ -152,8 +154,12 @@ func (b *baseExecutor) traversalArgs(node ast.Node, argsIndex *[]int32) { case *ast.Join: exprs := node.(*ast.Join) b.traversalArgs(exprs.Left, argsIndex) - b.traversalArgs(exprs.Right, argsIndex) - b.traversalArgs(exprs.On.Expr, argsIndex) + if exprs.Right != nil { + b.traversalArgs(exprs.Right, argsIndex) + } + if exprs.On != nil { + b.traversalArgs(exprs.On.Expr, argsIndex) + } break case *test_driver.ParamMarkerExpr: *argsIndex = append(*argsIndex, int32(node.(*test_driver.ParamMarkerExpr).Order)) @@ -200,6 +206,64 @@ func (b *baseExecutor) buildRecordImages(rowsi driver.Rows, tableMetaData *types return &types.RecordImage{TableName: tableMetaData.TableName, Rows: rowImages, SQLType: sqlType}, nil } +func (u *baseExecutor) buildSelectFields(ctx context.Context, tableMeta *types.TableMeta, tableAliases string, inUseFields []*ast.Assignment) ([]*ast.SelectField, error) { + fields := make([]*ast.SelectField, 0, len(inUseFields)) + + tableName := tableAliases + if tableAliases == "" { + tableName = tableMeta.TableName + } + if undo.UndoConfig.OnlyCareUpdateColumns { + for _, column := range inUseFields { + tn := column.Column.Table.O + if tn != "" && tn != tableName { + continue + } + + fields = append(fields, &ast.SelectField{ + Expr: &ast.ColumnNameExpr{ + Name: column.Column, + }, + }) + } + + if len(fields) == 0 { + return fields, nil + } + + // select indexes columns + for _, columnName := range tableMeta.GetPrimaryKeyOnlyName() { + fields = append(fields, &ast.SelectField{ + Expr: &ast.ColumnNameExpr{ + Name: &ast.ColumnName{ + Table: model.CIStr{ + O: tableName, + L: tableName, + }, + Name: model.CIStr{ + O: columnName, + L: columnName, + }, + }, + }, + }) + } + } else { + fields = append(fields, &ast.SelectField{ + Expr: &ast.ColumnNameExpr{ + Name: &ast.ColumnName{ + Name: model.CIStr{ + O: "*", + L: "*", + }, + }, + }, + }) + } + + return fields, nil +} + func getSqlNullValue(value interface{}) interface{} { if value == nil { return nil @@ -364,12 +428,12 @@ func (b *baseExecutor) buildLockKey(records *types.RecordImage, meta types.Table return lockKeys.String() } -func (u *updateExecutor) rowsPrepare(ctx context.Context, selectSQL string, selectArgs []driver.NamedValue) (driver.Rows, error) { +func (b *baseExecutor) rowsPrepare(ctx context.Context, conn driver.Conn, selectSQL string, selectArgs []driver.NamedValue) (driver.Rows, error) { var queryer driver.Queryer - queryerContext, ok := u.execContext.Conn.(driver.QueryerContext) + queryerContext, ok := conn.(driver.QueryerContext) if !ok { - queryer, ok = u.execContext.Conn.(driver.Queryer) + queryer, ok = conn.(driver.Queryer) } if ok { var err error diff --git a/pkg/datasource/sql/exec/at/update_executor.go b/pkg/datasource/sql/exec/at/update_executor.go index 3463ba895..906fb15a1 100644 --- a/pkg/datasource/sql/exec/at/update_executor.go +++ b/pkg/datasource/sql/exec/at/update_executor.go @@ -26,7 +26,6 @@ import ( "github.com/arana-db/parser/ast" "github.com/arana-db/parser/format" - "github.com/arana-db/parser/model" "seata.apache.org/seata-go/pkg/datasource/sql/datasource" "seata.apache.org/seata-go/pkg/datasource/sql/exec" @@ -49,6 +48,10 @@ type updateExecutor struct { // NewUpdateExecutor get update executor func NewUpdateExecutor(parserCtx *types.ParseContext, execContent *types.ExecContext, hooks []exec.SQLHook) executor { + // Because update join cannot be clearly identified when SQL cannot be parsed + if parserCtx.UpdateStmt.TableRefs.TableRefs.Right != nil { + return NewUpdateJoinExecutor(parserCtx, execContent, hooks) + } return &updateExecutor{parserCtx: parserCtx, execContext: execContent, baseExecutor: baseExecutor{hooks: hooks}} } @@ -96,7 +99,8 @@ func (u *updateExecutor) beforeImage(ctx context.Context) (*types.RecordImage, e return nil, err } - selectSQL, selectArgs, err := u.buildBeforeImageSQL(ctx, metaData, u.execContext.NamedValues) + selectSQL, selectArgs, err := u.buildBeforeImageSQL(ctx, metaData, "", u.execContext.NamedValues) + if err != nil { return nil, err } @@ -104,7 +108,7 @@ func (u *updateExecutor) beforeImage(ctx context.Context) (*types.RecordImage, e return nil, errors.New("build select sql by update sourceQuery fail") } - rowsi, err := u.rowsPrepare(ctx, selectSQL, selectArgs) + rowsi, err := u.rowsPrepare(ctx, u.execContext.Conn, selectSQL, selectArgs) defer func() { if rowsi != nil { if err := rowsi.Close(); err != nil { @@ -145,7 +149,7 @@ func (u *updateExecutor) afterImage(ctx context.Context, beforeImage types.Recor } selectSQL, selectArgs := u.buildAfterImageSQL(beforeImage, metaData) - rowsi, err := u.rowsPrepare(ctx, selectSQL, selectArgs) + rowsi, err := u.rowsPrepare(ctx, u.execContext.Conn, selectSQL, selectArgs) defer func() { if rowsi != nil { if err := rowsi.Close(); err != nil { @@ -197,14 +201,14 @@ func (u *updateExecutor) buildAfterImageSQL(beforeImage types.RecordImage, meta } // buildAfterImageSQL build the SQL to query before image data -func (u *updateExecutor) buildBeforeImageSQL(ctx context.Context, tableMeta *types.TableMeta, args []driver.NamedValue) (string, []driver.NamedValue, error) { +func (u *updateExecutor) buildBeforeImageSQL(ctx context.Context, tableMeta *types.TableMeta, tableAliases string, args []driver.NamedValue) (string, []driver.NamedValue, error) { if !u.isAstStmtValid() { log.Errorf("invalid update stmt") return "", nil, fmt.Errorf("invalid update stmt") } updateStmt := u.parserCtx.UpdateStmt - fields, err := u.buildSelectFields(ctx, tableMeta) + fields, err := u.buildSelectFields(ctx, tableMeta, tableAliases, updateStmt.List) if err != nil { return "", nil, err } @@ -232,59 +236,3 @@ func (u *updateExecutor) buildBeforeImageSQL(ctx context.Context, tableMeta *typ return sql, u.buildSelectArgs(&selStmt, args), nil } - -func (u *updateExecutor) buildSelectFields(ctx context.Context, tableMeta *types.TableMeta) ([]*ast.SelectField, error) { - updateStmt := u.parserCtx.UpdateStmt - fields := make([]*ast.SelectField, 0, len(updateStmt.List)) - - lowerTableName := strings.ToLower(tableMeta.TableName) - if undo.UndoConfig.OnlyCareUpdateColumns { - for _, column := range updateStmt.List { - tableName := column.Column.Table.L - if tableName != "" && lowerTableName != tableName { - continue - } - - fields = append(fields, &ast.SelectField{ - Expr: &ast.ColumnNameExpr{ - Name: column.Column, - }, - }) - } - - if len(fields) == 0 { - return fields, nil - } - - // select indexes columns - for _, columnName := range tableMeta.GetPrimaryKeyOnlyName() { - fields = append(fields, &ast.SelectField{ - Expr: &ast.ColumnNameExpr{ - Name: &ast.ColumnName{ - Table: model.CIStr{ - O: tableMeta.TableName, - L: lowerTableName, - }, - Name: model.CIStr{ - O: columnName, - L: columnName, - }, - }, - }, - }) - } - } else { - fields = append(fields, &ast.SelectField{ - Expr: &ast.ColumnNameExpr{ - Name: &ast.ColumnName{ - Name: model.CIStr{ - O: "*", - L: "*", - }, - }, - }, - }) - } - - return fields, nil -} diff --git a/pkg/datasource/sql/exec/at/update_executor_test.go b/pkg/datasource/sql/exec/at/update_executor_test.go index c3efb1eb2..103358d1c 100644 --- a/pkg/datasource/sql/exec/at/update_executor_test.go +++ b/pkg/datasource/sql/exec/at/update_executor_test.go @@ -65,8 +65,8 @@ func initTest() { }, ColumnNames: []string{"id", "name", "age"}, }, - "t1": { - TableName: "t1", + "table1": { + TableName: "table1", Indexs: map[string]types.IndexMeta{ "id": { IType: types.IndexTypePrimaryKey, @@ -91,8 +91,8 @@ func initTest() { }, ColumnNames: []string{"id", "name", "age"}, }, - "t2": { - TableName: "t2", + "table2": { + TableName: "table2", Indexs: map[string]types.IndexMeta{ "id": { IType: types.IndexTypePrimaryKey, @@ -125,8 +125,8 @@ func initTest() { }, ColumnNames: []string{"id", "name", "age", "kk", "addr"}, }, - "t3": { - TableName: "t3", + "table3": { + TableName: "table3", Indexs: map[string]types.IndexMeta{ "id": { IType: types.IndexTypePrimaryKey, @@ -147,8 +147,8 @@ func initTest() { }, ColumnNames: []string{"id", "age"}, }, - "t4": { - TableName: "t4", + "table4": { + TableName: "table4", Indexs: map[string]types.IndexMeta{ "id": { IType: types.IndexTypePrimaryKey, @@ -221,7 +221,7 @@ func TestBuildSelectSQLByUpdate(t *testing.T) { c, err := parser.DoParser(tt.sourceQuery) assert.Nil(t, err) executor := NewUpdateExecutor(c, &types.ExecContext{Values: tt.sourceQueryArgs, NamedValues: util.ValueToNamedValue(tt.sourceQueryArgs)}, []exec.SQLHook{}) - query, args, err := executor.(*updateExecutor).buildBeforeImageSQL(context.Background(), MetaDataMap["t_user"], util.ValueToNamedValue(tt.sourceQueryArgs)) + query, args, err := executor.(*updateExecutor).buildBeforeImageSQL(context.Background(), MetaDataMap["t_user"], "", util.ValueToNamedValue(tt.sourceQueryArgs)) assert.Nil(t, err) assert.Equal(t, tt.expectQuery, query) assert.Equal(t, tt.expectQueryArgs, util.NamedValueToValue(args)) diff --git a/pkg/datasource/sql/exec/at/update_join_executor.go b/pkg/datasource/sql/exec/at/update_join_executor.go index 7658626a4..55261064f 100644 --- a/pkg/datasource/sql/exec/at/update_join_executor.go +++ b/pkg/datasource/sql/exec/at/update_join_executor.go @@ -21,7 +21,6 @@ import ( "context" "database/sql/driver" "errors" - "fmt" "io" "reflect" "strings" @@ -40,11 +39,12 @@ import ( // updateJoinExecutor execute update SQL type updateJoinExecutor struct { - updateExecutor + baseExecutor parserCtx *types.ParseContext execContext *types.ExecContext isLowerSupportGroupByPksVersion bool sqlMode string + tableAliasesMap map[string]string } // NewUpdateJoinExecutor get executor @@ -54,8 +54,9 @@ func NewUpdateJoinExecutor(parserCtx *types.ParseContext, execContent *types.Exe return &updateJoinExecutor{ parserCtx: parserCtx, execContext: execContent, - updateExecutor: updateExecutor{parserCtx: parserCtx, execContext: execContent, baseExecutor: baseExecutor{hooks: hooks}}, + baseExecutor: baseExecutor{hooks: hooks}, isLowerSupportGroupByPksVersion: currentVersion < minimumVersion, + tableAliasesMap: make(map[string]string, 0), } } @@ -66,6 +67,10 @@ func (u *updateJoinExecutor) ExecContext(ctx context.Context, f exec.CallbackWit u.afterHooks(ctx, u.execContext) }() + if u.isAstStmtValid() { + u.tableAliasesMap = u.parseTableName(u.parserCtx.UpdateStmt.TableRefs.TableRefs) + } + beforeImages, err := u.beforeImage(ctx) if err != nil { return nil, err @@ -85,19 +90,16 @@ func (u *updateJoinExecutor) ExecContext(ctx context.Context, f exec.CallbackWit return nil, errors.New("Before image size is not equaled to after image size, probably because you updated the primary keys.") } - for i, afterImage := range afterImages { - beforeImage := afterImages[i] - if len(beforeImage.Rows) != len(afterImage.Rows) { - return nil, errors.New("Before image size is not equaled to after image size, probably because you updated the primary keys.") - } - - u.execContext.TxCtx.RoundImages.AppendBeofreImage(beforeImage) - u.execContext.TxCtx.RoundImages.AppendAfterImage(afterImage) - } + u.execContext.TxCtx.RoundImages.AppendBeofreImages(beforeImages) + u.execContext.TxCtx.RoundImages.AppendAfterImages(afterImages) return res, nil } +func (u *updateJoinExecutor) isAstStmtValid() bool { + return u.parserCtx != nil && u.parserCtx.UpdateStmt != nil && u.parserCtx.UpdateStmt.TableRefs.TableRefs.Right != nil +} + func (u *updateJoinExecutor) beforeImage(ctx context.Context) ([]*types.RecordImage, error) { if !u.isAstStmtValid() { return nil, nil @@ -105,16 +107,12 @@ func (u *updateJoinExecutor) beforeImage(ctx context.Context) ([]*types.RecordIm var recordImages []*types.RecordImage - // Parsing multiple table name - updateStmt := u.parserCtx.UpdateStmt - tableNames := u.parseTableName(updateStmt.TableRefs.TableRefs) - - for _, tbName := range tableNames { - metaData, err := datasource.GetTableCache(u.execContext.DBType).GetTableMeta(ctx, u.execContext.DBName, tbName) + for tbName, tableAliases := range u.tableAliasesMap { + metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tbName) if err != nil { return nil, err } - selectSQL, selectArgs, err := u.buildBeforeImageSQL(ctx, metaData, u.execContext.NamedValues) + selectSQL, selectArgs, err := u.buildBeforeImageSQL(ctx, metaData, tableAliases, u.execContext.NamedValues) if err != nil { return nil, err } @@ -124,9 +122,9 @@ func (u *updateJoinExecutor) beforeImage(ctx context.Context) ([]*types.RecordIm } var image *types.RecordImage - rowsi, err := u.rowsPrepare(ctx, selectSQL, selectArgs) + rowsi, err := u.rowsPrepare(ctx, u.execContext.Conn, selectSQL, selectArgs) if err == nil { - image, err = u.buildRecordImages(rowsi, metaData, types.SQLTypeUpdateJoin) + image, err = u.buildRecordImages(rowsi, metaData, types.SQLTypeUpdate) } if rowsi != nil { if rowerr := rows.Close(); rowerr != nil { @@ -160,20 +158,20 @@ func (u *updateJoinExecutor) afterImage(ctx context.Context, beforeImages []*typ var recordImages []*types.RecordImage for _, beforeImage := range beforeImages { - metaData, err := datasource.GetTableCache(u.execContext.DBType).GetTableMeta(ctx, u.execContext.DBName, beforeImage.TableName) + metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, beforeImage.TableName) if err != nil { return nil, err } - selectSQL, selectArgs, err := u.buildAfterImageSQL(ctx, *beforeImage, metaData) + selectSQL, selectArgs, err := u.buildAfterImageSQL(ctx, *beforeImage, metaData, u.tableAliasesMap[beforeImage.TableName]) if err != nil { return nil, err } var image *types.RecordImage - rowsi, err := u.rowsPrepare(ctx, selectSQL, selectArgs) + rowsi, err := u.rowsPrepare(ctx, u.execContext.Conn, selectSQL, selectArgs) if err == nil { - image, err = u.buildRecordImages(rowsi, metaData, types.SQLTypeUpdateJoin) + image, err = u.buildRecordImages(rowsi, metaData, types.SQLTypeUpdate) } if rowsi != nil { if rowerr := rowsi.Close(); rowerr != nil { @@ -194,14 +192,9 @@ func (u *updateJoinExecutor) afterImage(ctx context.Context, beforeImages []*typ } // buildAfterImageSQL build the SQL to query before image data -func (u *updateJoinExecutor) buildBeforeImageSQL(ctx context.Context, tableMeta *types.TableMeta, args []driver.NamedValue) (string, []driver.NamedValue, error) { - if !u.isAstStmtValid() { - log.Errorf("invalid update join stmt") - return "", nil, fmt.Errorf("invalid update join stmt") - } - +func (u *updateJoinExecutor) buildBeforeImageSQL(ctx context.Context, tableMeta *types.TableMeta, tableAliases string, args []driver.NamedValue) (string, []driver.NamedValue, error) { updateStmt := u.parserCtx.UpdateStmt - fields, err := u.buildSelectFields(ctx, tableMeta) + fields, err := u.buildSelectFields(ctx, tableMeta, tableAliases, updateStmt.List) if err != nil { return "", nil, err } @@ -219,7 +212,7 @@ func (u *updateJoinExecutor) buildBeforeImageSQL(ctx context.Context, tableMeta TableHints: updateStmt.TableHints, // maybe duplicate row for select join sql.remove duplicate row by 'group by' condition GroupBy: &ast.GroupByClause{ - Items: u.buildGroupByClause(ctx, tableMeta.TableName, tableMeta.GetPrimaryKeyOnlyName(), fields), + Items: u.buildGroupByClause(ctx, tableMeta.TableName, tableAliases, tableMeta.GetPrimaryKeyOnlyName(), fields), }, LockInfo: &ast.SelectLockInfo{ LockType: ast.SelectLockForUpdate, @@ -234,52 +227,68 @@ func (u *updateJoinExecutor) buildBeforeImageSQL(ctx context.Context, tableMeta return sql, u.buildSelectArgs(&selStmt, args), nil } -func (u *updateJoinExecutor) buildAfterImageSQL(ctx context.Context, beforeImage types.RecordImage, meta *types.TableMeta) (string, []driver.NamedValue, error) { - selectSQL, selectArgs := u.updateExecutor.buildAfterImageSQL(beforeImage, meta) +func (u *updateJoinExecutor) buildAfterImageSQL(ctx context.Context, beforeImage types.RecordImage, meta *types.TableMeta, tableAliases string) (string, []driver.NamedValue, error) { + if len(beforeImage.Rows) == 0 { + return "", nil, nil + } - needUpdateColumns, err := u.buildSelectFields(ctx, meta) + fields, err := u.buildSelectFields(ctx, meta, tableAliases, u.parserCtx.UpdateStmt.List) if err != nil { return "", nil, err } + if len(fields) == 0 { + return "", nil, err + } - // maybe duplicate row for select join sql.remove duplicate row by 'group by' condition - groupByStr := strings.Builder{} - groupByItem := u.buildGroupByClause(ctx, meta.TableName, meta.GetPrimaryKeyOnlyName(), needUpdateColumns) - - groupByStr.WriteString(selectSQL) - groupByStr.WriteString(" GROUP BY ") - for index, item := range groupByItem { - if index != 0 { - groupByStr.WriteString(",") - } - groupByStr.WriteString(item.Expr.(*ast.ColumnNameExpr).Name.String()) + updateStmt := u.parserCtx.UpdateStmt + selStmt := ast.SelectStmt{ + SelectStmtOpts: &ast.SelectStmtOpts{}, + From: updateStmt.TableRefs, + Where: updateStmt.Where, + Fields: &ast.FieldList{Fields: fields}, + OrderBy: updateStmt.Order, + Limit: updateStmt.Limit, + TableHints: updateStmt.TableHints, + // maybe duplicate row for select join sql.remove duplicate row by 'group by' condition + GroupBy: &ast.GroupByClause{ + Items: u.buildGroupByClause(ctx, meta.TableName, tableAliases, meta.GetPrimaryKeyOnlyName(), fields), + }, } - groupByStr.WriteString(" ") - return groupByStr.String(), selectArgs, nil + b := bytes.NewByteBuffer([]byte{}) + _ = selStmt.Restore(format.NewRestoreCtx(format.RestoreKeyWordUppercase, b)) + sql := string(b.Bytes()) + log.Infof("build select sql by update sourceQuery, sql {%s}", sql) + + return sql, u.buildPKParams(beforeImage.Rows, meta.GetPrimaryKeyOnlyName()), nil } -func (u *updateJoinExecutor) parseTableName(joinMate *ast.Join) []string { - var tableNames []string +func (u *updateJoinExecutor) parseTableName(joinMate *ast.Join) map[string]string { + tableNames := make(map[string]string, 0) if item, ok := joinMate.Left.(*ast.Join); ok { tableNames = u.parseTableName(item) } else { - leftName := joinMate.Left.(*ast.TableSource).Source.(*ast.TableName) - tableNames = append(tableNames, leftName.Name.O) + leftTableSource := joinMate.Left.(*ast.TableSource) + leftName := leftTableSource.Source.(*ast.TableName) + tableNames[leftName.Name.O] = leftTableSource.AsName.O } - rightName := joinMate.Right.(*ast.TableSource).Source.(*ast.TableName) - tableNames = append(tableNames, rightName.Name.O) + rightTableSource := joinMate.Right.(*ast.TableSource) + rightName := rightTableSource.Source.(*ast.TableName) + tableNames[rightName.Name.O] = rightTableSource.AsName.O return tableNames } // build group by condition which used for removing duplicate row in select join sql -func (u *updateJoinExecutor) buildGroupByClause(ctx context.Context, tableName string, pkColumns []string, allSelectColumns []*ast.SelectField) []*ast.ByItem { +func (u *updateJoinExecutor) buildGroupByClause(ctx context.Context, tableName string, tableAliases string, pkColumns []string, allSelectColumns []*ast.SelectField) []*ast.ByItem { var groupByPks = true + if tableAliases != "" { + tableName = tableAliases + } //only pks group by is valid when db version >= 5.7.5 if u.isLowerSupportGroupByPksVersion { if u.sqlMode == "" { - rowsi, err := u.rowsPrepare(ctx, "SELECT @@SQL_MODE", nil) + rowsi, err := u.rowsPrepare(ctx, u.execContext.Conn, "SELECT @@SQL_MODE", nil) defer func() { if rowsi != nil { if rowerr := rowsi.Close(); rowerr != nil { diff --git a/pkg/datasource/sql/exec/at/update_join_executor_test.go b/pkg/datasource/sql/exec/at/update_join_executor_test.go index 810e76970..367e196d6 100644 --- a/pkg/datasource/sql/exec/at/update_join_executor_test.go +++ b/pkg/datasource/sql/exec/at/update_join_executor_test.go @@ -40,66 +40,66 @@ func TestBuildSelectSQLByUpdateJoin(t *testing.T) { expectQueryArgs []driver.Value }{ { - sourceQuery: "update t1 left join t2 on t1.id = t2.id inner join t3 on t3.id = t2.id right join t4 on t4.id = t2.id set t1.name = ?,t2.name = ? where t1.id=? and t3.age=? and t4.age>30", - sourceQueryArgs: []driver.Value{"Jack", "WILL", 1, 10}, - expectQuery: map[string]string{ - "t1": "SELECT SQL_NO_CACHE t1.name,t1.id FROM ((t1 LEFT JOIN t2 ON t1.id=t2.id) JOIN t3 ON t3.id=t2.id) RIGHT JOIN t4 ON t4.id=t2.id WHERE t1.id=? AND t3.age=? AND t4.age>30 GROUP BY t1.name,t1.id FOR UPDATE", - "t2": "SELECT SQL_NO_CACHE t2.name,t2.id FROM ((t1 LEFT JOIN t2 ON t1.id=t2.id) JOIN t3 ON t3.id=t2.id) RIGHT JOIN t4 ON t4.id=t2.id WHERE t1.id=? AND t3.age=? AND t4.age>30 GROUP BY t2.name,t2.id FOR UPDATE", - }, - expectQueryArgs: []driver.Value{1, 10}, - }, - { - sourceQuery: "update t1 left join t2 on t1.id = t2.id and t1.age=? set t1.name = 'WILL',t2.name = ?", + sourceQuery: "update table1 t1 left join table2 t2 on t1.id = t2.id and t1.age=? set t1.name = 'WILL',t2.name = ?", sourceQueryArgs: []driver.Value{18, "Jack"}, expectQuery: map[string]string{ - "t1": "SELECT SQL_NO_CACHE t1.name,t1.id FROM t1 LEFT JOIN t2 ON t1.id=t2.id AND t1.age=? GROUP BY t1.name,t1.id FOR UPDATE", - "t2": "SELECT SQL_NO_CACHE t2.name,t2.id FROM t1 LEFT JOIN t2 ON t1.id=t2.id AND t1.age=? GROUP BY t2.name,t2.id FOR UPDATE", + "table1": "SELECT SQL_NO_CACHE t1.name,t1.id FROM table1 AS t1 LEFT JOIN table2 AS t2 ON t1.id=t2.id AND t1.age=? GROUP BY t1.name,t1.id FOR UPDATE", + "table2": "SELECT SQL_NO_CACHE t2.name,t2.id FROM table1 AS t1 LEFT JOIN table2 AS t2 ON t1.id=t2.id AND t1.age=? GROUP BY t2.name,t2.id FOR UPDATE", }, expectQueryArgs: []driver.Value{18}, }, - { - sourceQuery: "update t1 inner join t2 on t1.id = t2.id set t1.name = 'WILL',t2.name = 'WILL' where t1.id=?", - sourceQueryArgs: []driver.Value{1}, - expectQuery: map[string]string{ - "t1": "SELECT SQL_NO_CACHE t1.name,t1.id FROM t1 JOIN t2 ON t1.id=t2.id WHERE t1.id=? GROUP BY t1.name,t1.id FOR UPDATE", - "t2": "SELECT SQL_NO_CACHE t2.name,t2.id FROM t1 JOIN t2 ON t1.id=t2.id WHERE t1.id=? GROUP BY t2.name,t2.id FOR UPDATE", - }, - expectQueryArgs: []driver.Value{1}, - }, - { - sourceQuery: "update t1 right join t2 on t1.id = t2.id set t1.name = 'WILL',t2.name = 'WILL' where t1.id=?", - sourceQueryArgs: []driver.Value{1}, - expectQuery: map[string]string{ - "t1": "SELECT SQL_NO_CACHE t1.name,t1.id FROM t1 RIGHT JOIN t2 ON t1.id=t2.id WHERE t1.id=? GROUP BY t1.name,t1.id FOR UPDATE", - "t2": "SELECT SQL_NO_CACHE t2.name,t2.id FROM t1 RIGHT JOIN t2 ON t1.id=t2.id WHERE t1.id=? GROUP BY t2.name,t2.id FOR UPDATE", - }, - expectQueryArgs: []driver.Value{1}, - }, - { - sourceQuery: "update t1 inner join t2 on t1.id = t2.id set t1.name = ?, t1.age = ? where t1.id = ? and t1.name = ? and t2.age between ? and ?", - sourceQueryArgs: []driver.Value{"newJack", 38, 1, "Jack", 18, 28}, - expectQuery: map[string]string{ - "t1": "SELECT SQL_NO_CACHE t1.name,t1.age,t1.id FROM t1 JOIN t2 ON t1.id=t2.id WHERE t1.id=? AND t1.name=? AND t2.age BETWEEN ? AND ? GROUP BY t1.name,t1.age,t1.id FOR UPDATE", - }, - expectQueryArgs: []driver.Value{1, "Jack", 18, 28}, - }, - { - sourceQuery: "update t1 left join t2 on t1.id = t2.id set t1.name = ?, t1.age = ? where t1.id=? and t2.id is null and t1.age IN (?,?)", - sourceQueryArgs: []driver.Value{"newJack", 38, 1, 18, 28}, - expectQuery: map[string]string{ - "t1": "SELECT SQL_NO_CACHE t1.name,t1.age,t1.id FROM t1 LEFT JOIN t2 ON t1.id=t2.id WHERE t1.id=? AND t2.id IS NULL AND t1.age IN (?,?) GROUP BY t1.name,t1.age,t1.id FOR UPDATE", - }, - expectQueryArgs: []driver.Value{1, 18, 28}, - }, - { - sourceQuery: "update t1 inner join t2 on t1.id = t2.id set t1.name = ?, t2.age = ? where t2.kk between ? and ? and t2.addr in(?,?) and t2.age > ? order by t1.name desc limit ?", - sourceQueryArgs: []driver.Value{"Jack", 18, 10, 20, "Beijing", "Guangzhou", 18, 2}, - expectQuery: map[string]string{ - "t1": "SELECT SQL_NO_CACHE t1.name,t1.id FROM t1 JOIN t2 ON t1.id=t2.id WHERE t2.kk BETWEEN ? AND ? AND t2.addr IN (?,?) AND t2.age>? GROUP BY t1.name,t1.id ORDER BY t1.name DESC LIMIT ? FOR UPDATE", - "t2": "SELECT SQL_NO_CACHE t2.age,t2.id FROM t1 JOIN t2 ON t1.id=t2.id WHERE t2.kk BETWEEN ? AND ? AND t2.addr IN (?,?) AND t2.age>? GROUP BY t2.age,t2.id ORDER BY t1.name DESC LIMIT ? FOR UPDATE", - }, - expectQueryArgs: []driver.Value{10, 20, "Beijing", "Guangzhou", 18, 2}, - }, + //{ + // sourceQuery: "update table1 AS t1 inner join table2 AS t2 on t1.id = t2.id set t1.name = 'WILL',t2.name = 'WILL' where t1.id=?", + // sourceQueryArgs: []driver.Value{1}, + // expectQuery: map[string]string{ + // "t1": "SELECT SQL_NO_CACHE t1.name,t1.id FROM table1 AS t1 JOIN table2 AS t2 ON t1.id=t2.id WHERE t1.id=? GROUP BY t1.name,t1.id FOR UPDATE", + // "t2": "SELECT SQL_NO_CACHE t2.name,t2.id FROM table1 AS t1 JOIN table2 AS t2 ON t1.id=t2.id WHERE t1.id=? GROUP BY t2.name,t2.id FOR UPDATE", + // }, + // expectQueryArgs: []driver.Value{1}, + //}, + //{ + // sourceQuery: "update table1 AS t1 right join table2 AS t2 on t1.id = t2.id set t1.name = 'WILL',t2.name = 'WILL' where t1.id=?", + // sourceQueryArgs: []driver.Value{1}, + // expectQuery: map[string]string{ + // "t1": "SELECT SQL_NO_CACHE t1.name,t1.id FROM table1 AS t1 RIGHT JOIN table2 AS t2 ON t1.id=t2.id WHERE t1.id=? GROUP BY t1.name,t1.id FOR UPDATE", + // "t2": "SELECT SQL_NO_CACHE t2.name,t2.id FROM table1 AS t1 RIGHT JOIN table2 AS t2 ON t1.id=t2.id WHERE t1.id=? GROUP BY t2.name,t2.id FOR UPDATE", + // }, + // expectQueryArgs: []driver.Value{1}, + //}, + //{ + // sourceQuery: "update table1 t1 inner join table2 t2 on t1.id = t2.id set t1.name = ?, t1.age = ? where t1.id = ? and t1.name = ? and t2.age between ? and ?", + // sourceQueryArgs: []driver.Value{"newJack", 38, 1, "Jack", 18, 28}, + // expectQuery: map[string]string{ + // "t1": "SELECT SQL_NO_CACHE t1.name,t1.age,t1.id FROM table1 AS t1 JOIN table2 AS t2 ON t1.id=t2.id WHERE t1.id=? AND t1.name=? AND t2.age BETWEEN ? AND ? GROUP BY t1.name,t1.age,t1.id FOR UPDATE", + // }, + // expectQueryArgs: []driver.Value{1, "Jack", 18, 28}, + //}, + //{ + // sourceQuery: "update table1 t1 left join table2 t2 on t1.id = t2.id set t1.name = ?, t1.age = ? where t1.id=? and t2.id is null and t1.age IN (?,?)", + // sourceQueryArgs: []driver.Value{"newJack", 38, 1, 18, 28}, + // expectQuery: map[string]string{ + // "t1": "SELECT SQL_NO_CACHE t1.name,t1.age,t1.id FROM table1 AS t1 LEFT JOIN table2 AS t2 ON t1.id=t2.id WHERE t1.id=? AND t2.id IS NULL AND t1.age IN (?,?) GROUP BY t1.name,t1.age,t1.id FOR UPDATE", + // }, + // expectQueryArgs: []driver.Value{1, 18, 28}, + //}, + //{ + // sourceQuery: "update table1 t1 inner join table2 t2 on t1.id = t2.id set t1.name = ?, t2.age = ? where t2.kk between ? and ? and t2.addr in(?,?) and t2.age > ? order by t1.name desc limit ?", + // sourceQueryArgs: []driver.Value{"Jack", 18, 10, 20, "Beijing", "Guangzhou", 18, 2}, + // expectQuery: map[string]string{ + // "t1": "SELECT SQL_NO_CACHE t1.name,t1.id FROM table1 AS t1 JOIN table2 AS t2 ON t1.id=t2.id WHERE t2.kk BETWEEN ? AND ? AND t2.addr IN (?,?) AND t2.age>? GROUP BY t1.name,t1.id ORDER BY t1.name DESC LIMIT ? FOR UPDATE", + // "t2": "SELECT SQL_NO_CACHE t2.age,t2.id FROM table1 AS t1 JOIN table2 AS t2 ON t1.id=t2.id WHERE t2.kk BETWEEN ? AND ? AND t2.addr IN (?,?) AND t2.age>? GROUP BY t2.age,t2.id ORDER BY t1.name DESC LIMIT ? FOR UPDATE", + // }, + // expectQueryArgs: []driver.Value{10, 20, "Beijing", "Guangzhou", 18, 2}, + //}, + //{ + // sourceQuery: "update table1 t1 left join table2 t2 on t1.id = t2.id inner join t3 on t3.id = t2.id right join t4 on t4.id = t2.id set t1.name = ?,t2.name = ? where t1.id=? and t3.age=? and t4.age>30", + // sourceQueryArgs: []driver.Value{"Jack", "WILL", 1, 10}, + // expectQuery: map[string]string{ + // "t1": "SELECT SQL_NO_CACHE t1.name,t1.id FROM ((table1 AS t1 LEFT JOIN table2 AS t2 ON t1.id=t2.id) JOIN table3 AS t3 ON t3.id=t2.id) RIGHT JOIN table4 AS t4 ON t4.id=t2.id WHERE t1.id=? AND t3.age=? AND t4.age>30 GROUP BY t1.name,t1.id FOR UPDATE", + // "t2": "SELECT SQL_NO_CACHE t2.name,t2.id FROM ((table1 AS t1 LEFT JOIN table2 AS t2 ON t1.id=t2.id) JOIN table3 AS t3 ON t3.id=t2.id) RIGHT JOIN table4 AS t4 ON t4.id=t2.id WHERE t1.id=? AND t3.age=? AND t4.age>30 GROUP BY t2.name,t2.id FOR UPDATE", + // }, + // expectQueryArgs: []driver.Value{1, 10}, + //}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -107,8 +107,8 @@ func TestBuildSelectSQLByUpdateJoin(t *testing.T) { assert.Nil(t, err) executor := NewUpdateJoinExecutor(c, &types.ExecContext{Values: tt.sourceQueryArgs, NamedValues: util.ValueToNamedValue(tt.sourceQueryArgs)}, []exec.SQLHook{}) tableNames := executor.(*updateJoinExecutor).parseTableName(c.UpdateStmt.TableRefs.TableRefs) - for _, tbName := range tableNames { - query, args, err := executor.(*updateJoinExecutor).buildBeforeImageSQL(context.Background(), MetaDataMap[tbName], util.ValueToNamedValue(tt.sourceQueryArgs)) + for tbName, tableAliases := range tableNames { + query, args, err := executor.(*updateJoinExecutor).buildBeforeImageSQL(context.Background(), MetaDataMap[tbName], tableAliases, util.ValueToNamedValue(tt.sourceQueryArgs)) assert.Nil(t, err) if query == "" { continue diff --git a/pkg/datasource/sql/types/sql.go b/pkg/datasource/sql/types/sql.go index 728ce6158..58f4612f4 100644 --- a/pkg/datasource/sql/types/sql.go +++ b/pkg/datasource/sql/types/sql.go @@ -68,7 +68,6 @@ const ( SQLTypeSelectFoundRows SQLTypeInsertIgnore = iota + 57 SQLTypeInsertOnDuplicateUpdate - SQLTypeUpdateJoin // SQLTypeMulti and SQLTypeUnknown is different from seata-java SQLTypeMulti = iota + 999 SQLTypeUnknown diff --git a/pkg/datasource/sql/undo/factor/undo_executor_factory.go b/pkg/datasource/sql/undo/factor/undo_executor_factory.go index d8ae8ec5d..2a943eff4 100644 --- a/pkg/datasource/sql/undo/factor/undo_executor_factory.go +++ b/pkg/datasource/sql/undo/factor/undo_executor_factory.go @@ -37,7 +37,7 @@ func GetUndoExecutor(dbType types.DBType, sqlUndoLog undo.SQLUndoLog) (res undo. res = undoExecutorHolder.GetInsertExecutor(sqlUndoLog) case types.SQLTypeDelete: res = undoExecutorHolder.GetDeleteExecutor(sqlUndoLog) - case types.SQLTypeUpdate, types.SQLTypeUpdateJoin: + case types.SQLTypeUpdate: res = undoExecutorHolder.GetUpdateExecutor(sqlUndoLog) default: return nil, fmt.Errorf("sql type: %d not support", sqlUndoLog.SQLType) From cacae212987dd557a0cf354eb503a590eb15ced1 Mon Sep 17 00:00:00 2001 From: lxfeng1997 <824141436@qq.com> Date: Fri, 10 Jan 2025 17:08:55 +0800 Subject: [PATCH 5/8] Open test annotations --- .../sql/exec/at/update_join_executor_test.go | 104 +++++++++--------- 1 file changed, 52 insertions(+), 52 deletions(-) diff --git a/pkg/datasource/sql/exec/at/update_join_executor_test.go b/pkg/datasource/sql/exec/at/update_join_executor_test.go index 367e196d6..0f396ddb1 100644 --- a/pkg/datasource/sql/exec/at/update_join_executor_test.go +++ b/pkg/datasource/sql/exec/at/update_join_executor_test.go @@ -48,58 +48,58 @@ func TestBuildSelectSQLByUpdateJoin(t *testing.T) { }, expectQueryArgs: []driver.Value{18}, }, - //{ - // sourceQuery: "update table1 AS t1 inner join table2 AS t2 on t1.id = t2.id set t1.name = 'WILL',t2.name = 'WILL' where t1.id=?", - // sourceQueryArgs: []driver.Value{1}, - // expectQuery: map[string]string{ - // "t1": "SELECT SQL_NO_CACHE t1.name,t1.id FROM table1 AS t1 JOIN table2 AS t2 ON t1.id=t2.id WHERE t1.id=? GROUP BY t1.name,t1.id FOR UPDATE", - // "t2": "SELECT SQL_NO_CACHE t2.name,t2.id FROM table1 AS t1 JOIN table2 AS t2 ON t1.id=t2.id WHERE t1.id=? GROUP BY t2.name,t2.id FOR UPDATE", - // }, - // expectQueryArgs: []driver.Value{1}, - //}, - //{ - // sourceQuery: "update table1 AS t1 right join table2 AS t2 on t1.id = t2.id set t1.name = 'WILL',t2.name = 'WILL' where t1.id=?", - // sourceQueryArgs: []driver.Value{1}, - // expectQuery: map[string]string{ - // "t1": "SELECT SQL_NO_CACHE t1.name,t1.id FROM table1 AS t1 RIGHT JOIN table2 AS t2 ON t1.id=t2.id WHERE t1.id=? GROUP BY t1.name,t1.id FOR UPDATE", - // "t2": "SELECT SQL_NO_CACHE t2.name,t2.id FROM table1 AS t1 RIGHT JOIN table2 AS t2 ON t1.id=t2.id WHERE t1.id=? GROUP BY t2.name,t2.id FOR UPDATE", - // }, - // expectQueryArgs: []driver.Value{1}, - //}, - //{ - // sourceQuery: "update table1 t1 inner join table2 t2 on t1.id = t2.id set t1.name = ?, t1.age = ? where t1.id = ? and t1.name = ? and t2.age between ? and ?", - // sourceQueryArgs: []driver.Value{"newJack", 38, 1, "Jack", 18, 28}, - // expectQuery: map[string]string{ - // "t1": "SELECT SQL_NO_CACHE t1.name,t1.age,t1.id FROM table1 AS t1 JOIN table2 AS t2 ON t1.id=t2.id WHERE t1.id=? AND t1.name=? AND t2.age BETWEEN ? AND ? GROUP BY t1.name,t1.age,t1.id FOR UPDATE", - // }, - // expectQueryArgs: []driver.Value{1, "Jack", 18, 28}, - //}, - //{ - // sourceQuery: "update table1 t1 left join table2 t2 on t1.id = t2.id set t1.name = ?, t1.age = ? where t1.id=? and t2.id is null and t1.age IN (?,?)", - // sourceQueryArgs: []driver.Value{"newJack", 38, 1, 18, 28}, - // expectQuery: map[string]string{ - // "t1": "SELECT SQL_NO_CACHE t1.name,t1.age,t1.id FROM table1 AS t1 LEFT JOIN table2 AS t2 ON t1.id=t2.id WHERE t1.id=? AND t2.id IS NULL AND t1.age IN (?,?) GROUP BY t1.name,t1.age,t1.id FOR UPDATE", - // }, - // expectQueryArgs: []driver.Value{1, 18, 28}, - //}, - //{ - // sourceQuery: "update table1 t1 inner join table2 t2 on t1.id = t2.id set t1.name = ?, t2.age = ? where t2.kk between ? and ? and t2.addr in(?,?) and t2.age > ? order by t1.name desc limit ?", - // sourceQueryArgs: []driver.Value{"Jack", 18, 10, 20, "Beijing", "Guangzhou", 18, 2}, - // expectQuery: map[string]string{ - // "t1": "SELECT SQL_NO_CACHE t1.name,t1.id FROM table1 AS t1 JOIN table2 AS t2 ON t1.id=t2.id WHERE t2.kk BETWEEN ? AND ? AND t2.addr IN (?,?) AND t2.age>? GROUP BY t1.name,t1.id ORDER BY t1.name DESC LIMIT ? FOR UPDATE", - // "t2": "SELECT SQL_NO_CACHE t2.age,t2.id FROM table1 AS t1 JOIN table2 AS t2 ON t1.id=t2.id WHERE t2.kk BETWEEN ? AND ? AND t2.addr IN (?,?) AND t2.age>? GROUP BY t2.age,t2.id ORDER BY t1.name DESC LIMIT ? FOR UPDATE", - // }, - // expectQueryArgs: []driver.Value{10, 20, "Beijing", "Guangzhou", 18, 2}, - //}, - //{ - // sourceQuery: "update table1 t1 left join table2 t2 on t1.id = t2.id inner join t3 on t3.id = t2.id right join t4 on t4.id = t2.id set t1.name = ?,t2.name = ? where t1.id=? and t3.age=? and t4.age>30", - // sourceQueryArgs: []driver.Value{"Jack", "WILL", 1, 10}, - // expectQuery: map[string]string{ - // "t1": "SELECT SQL_NO_CACHE t1.name,t1.id FROM ((table1 AS t1 LEFT JOIN table2 AS t2 ON t1.id=t2.id) JOIN table3 AS t3 ON t3.id=t2.id) RIGHT JOIN table4 AS t4 ON t4.id=t2.id WHERE t1.id=? AND t3.age=? AND t4.age>30 GROUP BY t1.name,t1.id FOR UPDATE", - // "t2": "SELECT SQL_NO_CACHE t2.name,t2.id FROM ((table1 AS t1 LEFT JOIN table2 AS t2 ON t1.id=t2.id) JOIN table3 AS t3 ON t3.id=t2.id) RIGHT JOIN table4 AS t4 ON t4.id=t2.id WHERE t1.id=? AND t3.age=? AND t4.age>30 GROUP BY t2.name,t2.id FOR UPDATE", - // }, - // expectQueryArgs: []driver.Value{1, 10}, - //}, + { + sourceQuery: "update table1 AS t1 inner join table2 AS t2 on t1.id = t2.id set t1.name = 'WILL',t2.name = 'WILL' where t1.id=?", + sourceQueryArgs: []driver.Value{1}, + expectQuery: map[string]string{ + "table1": "SELECT SQL_NO_CACHE t1.name,t1.id FROM table1 AS t1 JOIN table2 AS t2 ON t1.id=t2.id WHERE t1.id=? GROUP BY t1.name,t1.id FOR UPDATE", + "table2": "SELECT SQL_NO_CACHE t2.name,t2.id FROM table1 AS t1 JOIN table2 AS t2 ON t1.id=t2.id WHERE t1.id=? GROUP BY t2.name,t2.id FOR UPDATE", + }, + expectQueryArgs: []driver.Value{1}, + }, + { + sourceQuery: "update table1 AS t1 right join table2 AS t2 on t1.id = t2.id set t1.name = 'WILL',t2.name = 'WILL' where t1.id=?", + sourceQueryArgs: []driver.Value{1}, + expectQuery: map[string]string{ + "table1": "SELECT SQL_NO_CACHE t1.name,t1.id FROM table1 AS t1 RIGHT JOIN table2 AS t2 ON t1.id=t2.id WHERE t1.id=? GROUP BY t1.name,t1.id FOR UPDATE", + "table2": "SELECT SQL_NO_CACHE t2.name,t2.id FROM table1 AS t1 RIGHT JOIN table2 AS t2 ON t1.id=t2.id WHERE t1.id=? GROUP BY t2.name,t2.id FOR UPDATE", + }, + expectQueryArgs: []driver.Value{1}, + }, + { + sourceQuery: "update table1 t1 inner join table2 t2 on t1.id = t2.id set t1.name = ?, t1.age = ? where t1.id = ? and t1.name = ? and t2.age between ? and ?", + sourceQueryArgs: []driver.Value{"newJack", 38, 1, "Jack", 18, 28}, + expectQuery: map[string]string{ + "table1": "SELECT SQL_NO_CACHE t1.name,t1.age,t1.id FROM table1 AS t1 JOIN table2 AS t2 ON t1.id=t2.id WHERE t1.id=? AND t1.name=? AND t2.age BETWEEN ? AND ? GROUP BY t1.name,t1.age,t1.id FOR UPDATE", + }, + expectQueryArgs: []driver.Value{1, "Jack", 18, 28}, + }, + { + sourceQuery: "update table1 t1 left join table2 t2 on t1.id = t2.id set t1.name = ?, t1.age = ? where t1.id=? and t2.id is null and t1.age IN (?,?)", + sourceQueryArgs: []driver.Value{"newJack", 38, 1, 18, 28}, + expectQuery: map[string]string{ + "table1": "SELECT SQL_NO_CACHE t1.name,t1.age,t1.id FROM table1 AS t1 LEFT JOIN table2 AS t2 ON t1.id=t2.id WHERE t1.id=? AND t2.id IS NULL AND t1.age IN (?,?) GROUP BY t1.name,t1.age,t1.id FOR UPDATE", + }, + expectQueryArgs: []driver.Value{1, 18, 28}, + }, + { + sourceQuery: "update table1 t1 inner join table2 t2 on t1.id = t2.id set t1.name = ?, t2.age = ? where t2.kk between ? and ? and t2.addr in(?,?) and t2.age > ? order by t1.name desc limit ?", + sourceQueryArgs: []driver.Value{"Jack", 18, 10, 20, "Beijing", "Guangzhou", 18, 2}, + expectQuery: map[string]string{ + "table1": "SELECT SQL_NO_CACHE t1.name,t1.id FROM table1 AS t1 JOIN table2 AS t2 ON t1.id=t2.id WHERE t2.kk BETWEEN ? AND ? AND t2.addr IN (?,?) AND t2.age>? GROUP BY t1.name,t1.id ORDER BY t1.name DESC LIMIT ? FOR UPDATE", + "table2": "SELECT SQL_NO_CACHE t2.age,t2.id FROM table1 AS t1 JOIN table2 AS t2 ON t1.id=t2.id WHERE t2.kk BETWEEN ? AND ? AND t2.addr IN (?,?) AND t2.age>? GROUP BY t2.age,t2.id ORDER BY t1.name DESC LIMIT ? FOR UPDATE", + }, + expectQueryArgs: []driver.Value{10, 20, "Beijing", "Guangzhou", 18, 2}, + }, + { + sourceQuery: "update table1 t1 left join table2 t2 on t1.id = t2.id inner join table3 t3 on t3.id = t2.id right join table4 t4 on t4.id = t2.id set t1.name = ?,t2.name = ? where t1.id=? and t3.age=? and t4.age>30", + sourceQueryArgs: []driver.Value{"Jack", "WILL", 1, 10}, + expectQuery: map[string]string{ + "table1": "SELECT SQL_NO_CACHE t1.name,t1.id FROM ((table1 AS t1 LEFT JOIN table2 AS t2 ON t1.id=t2.id) JOIN table3 AS t3 ON t3.id=t2.id) RIGHT JOIN table4 AS t4 ON t4.id=t2.id WHERE t1.id=? AND t3.age=? AND t4.age>30 GROUP BY t1.name,t1.id FOR UPDATE", + "table2": "SELECT SQL_NO_CACHE t2.name,t2.id FROM ((table1 AS t1 LEFT JOIN table2 AS t2 ON t1.id=t2.id) JOIN table3 AS t3 ON t3.id=t2.id) RIGHT JOIN table4 AS t4 ON t4.id=t2.id WHERE t1.id=? AND t3.age=? AND t4.age>30 GROUP BY t2.name,t2.id FOR UPDATE", + }, + expectQueryArgs: []driver.Value{1, 10}, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { From 0774046a7f3ba4d33df4ebb36ae5f7b511d41217 Mon Sep 17 00:00:00 2001 From: lxfeng1997 <824141436@qq.com> Date: Fri, 10 Jan 2025 19:22:34 +0800 Subject: [PATCH 6/8] recover update executor --- pkg/datasource/sql/exec/at/update_executor.go | 115 +++++++++++----- .../sql/exec/at/update_executor_test.go | 126 +----------------- .../sql/exec/at/update_join_executor_test.go | 110 +++++++++++++++ 3 files changed, 195 insertions(+), 156 deletions(-) diff --git a/pkg/datasource/sql/exec/at/update_executor.go b/pkg/datasource/sql/exec/at/update_executor.go index 906fb15a1..0f14e97b1 100644 --- a/pkg/datasource/sql/exec/at/update_executor.go +++ b/pkg/datasource/sql/exec/at/update_executor.go @@ -20,8 +20,9 @@ package at import ( "context" "database/sql/driver" - "errors" "fmt" + "github.com/arana-db/parser/model" + "seata.apache.org/seata-go/pkg/datasource/sql/util" "strings" "github.com/arana-db/parser/ast" @@ -93,32 +94,37 @@ func (u *updateExecutor) beforeImage(ctx context.Context) (*types.RecordImage, e return nil, nil } - tableName, _ := u.parserCtx.GetTableName() - metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tableName) + selectSQL, selectArgs, err := u.buildBeforeImageSQL(ctx, u.execContext.NamedValues) if err != nil { return nil, err } - selectSQL, selectArgs, err := u.buildBeforeImageSQL(ctx, metaData, "", u.execContext.NamedValues) - + tableName, _ := u.parserCtx.GetTableName() + metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tableName) if err != nil { return nil, err } - if selectSQL == "" { - return nil, errors.New("build select sql by update sourceQuery fail") - } - rowsi, err := u.rowsPrepare(ctx, u.execContext.Conn, selectSQL, selectArgs) - defer func() { - if rowsi != nil { - if err := rowsi.Close(); err != nil { - log.Errorf("rows close fail, err:%v", err) - return + var rowsi driver.Rows + queryerCtx, ok := u.execContext.Conn.(driver.QueryerContext) + var queryer driver.Queryer + if !ok { + queryer, ok = u.execContext.Conn.(driver.Queryer) + } + if ok { + rowsi, err = util.CtxDriverQuery(ctx, queryerCtx, queryer, selectSQL, selectArgs) + defer func() { + if rowsi != nil { + rowsi.Close() } + }() + if err != nil { + log.Errorf("ctx driver query: %+v", err) + return nil, err } - }() - if err != nil { - return nil, err + } else { + log.Errorf("target conn should been driver.QueryerContext or driver.Queryer") + return nil, fmt.Errorf("invalid conn") } image, err := u.buildRecordImages(rowsi, metaData, types.SQLTypeUpdate) @@ -149,17 +155,26 @@ func (u *updateExecutor) afterImage(ctx context.Context, beforeImage types.Recor } selectSQL, selectArgs := u.buildAfterImageSQL(beforeImage, metaData) - rowsi, err := u.rowsPrepare(ctx, u.execContext.Conn, selectSQL, selectArgs) - defer func() { - if rowsi != nil { - if err := rowsi.Close(); err != nil { - log.Errorf("rows close fail, err:%v", err) - return + var rowsi driver.Rows + queryerCtx, ok := u.execContext.Conn.(driver.QueryerContext) + var queryer driver.Queryer + if !ok { + queryer, ok = u.execContext.Conn.(driver.Queryer) + } + if ok { + rowsi, err = util.CtxDriverQuery(ctx, queryerCtx, queryer, selectSQL, selectArgs) + defer func() { + if rowsi != nil { + rowsi.Close() } + }() + if err != nil { + log.Errorf("ctx driver query: %+v", err) + return nil, err } - }() - if err != nil { - return nil, err + } else { + log.Errorf("target conn should been driver.QueryerContext or driver.Queryer") + return nil, fmt.Errorf("invalid conn") } afterImage, err := u.buildRecordImages(rowsi, metaData, types.SQLTypeUpdate) @@ -201,19 +216,53 @@ func (u *updateExecutor) buildAfterImageSQL(beforeImage types.RecordImage, meta } // buildAfterImageSQL build the SQL to query before image data -func (u *updateExecutor) buildBeforeImageSQL(ctx context.Context, tableMeta *types.TableMeta, tableAliases string, args []driver.NamedValue) (string, []driver.NamedValue, error) { +func (u *updateExecutor) buildBeforeImageSQL(ctx context.Context, args []driver.NamedValue) (string, []driver.NamedValue, error) { if !u.isAstStmtValid() { log.Errorf("invalid update stmt") return "", nil, fmt.Errorf("invalid update stmt") } updateStmt := u.parserCtx.UpdateStmt - fields, err := u.buildSelectFields(ctx, tableMeta, tableAliases, updateStmt.List) - if err != nil { - return "", nil, err - } - if len(fields) == 0 { - return "", nil, err + fields := make([]*ast.SelectField, 0, len(updateStmt.List)) + + if undo.UndoConfig.OnlyCareUpdateColumns { + for _, column := range updateStmt.List { + fields = append(fields, &ast.SelectField{ + Expr: &ast.ColumnNameExpr{ + Name: column.Column, + }, + }) + } + + // select indexes columns + tableName, _ := u.parserCtx.GetTableName() + metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, u.execContext.DBName, tableName) + if err != nil { + return "", nil, err + } + for _, columnName := range metaData.GetPrimaryKeyOnlyName() { + fields = append(fields, &ast.SelectField{ + Expr: &ast.ColumnNameExpr{ + Name: &ast.ColumnName{ + Name: model.CIStr{ + O: columnName, + L: columnName, + }, + }, + }, + }) + } + } else { + fields = append(fields, &ast.SelectField{ + Expr: &ast.ColumnNameExpr{ + Name: &ast.ColumnName{ + Name: model.CIStr{ + O: "*", + L: "*", + }, + }, + }, + }) } selStmt := ast.SelectStmt{ diff --git a/pkg/datasource/sql/exec/at/update_executor_test.go b/pkg/datasource/sql/exec/at/update_executor_test.go index 103358d1c..34173aafb 100644 --- a/pkg/datasource/sql/exec/at/update_executor_test.go +++ b/pkg/datasource/sql/exec/at/update_executor_test.go @@ -20,7 +20,7 @@ package at import ( "context" "database/sql/driver" - "os" + "seata.apache.org/seata-go/pkg/datasource/sql/undo" "testing" "github.com/stretchr/testify/assert" @@ -28,17 +28,12 @@ import ( "seata.apache.org/seata-go/pkg/datasource/sql/exec" "seata.apache.org/seata-go/pkg/datasource/sql/parser" "seata.apache.org/seata-go/pkg/datasource/sql/types" - "seata.apache.org/seata-go/pkg/datasource/sql/undo" "seata.apache.org/seata-go/pkg/datasource/sql/util" _ "seata.apache.org/seata-go/pkg/util/log" ) -var ( - MetaDataMap map[string]*types.TableMeta -) - -func initTest() { - MetaDataMap = map[string]*types.TableMeta{ +func TestBuildSelectSQLByUpdate(t *testing.T) { + MetaDataMap := map[string]*types.TableMeta{ "t_user": { TableName: "t_user", Indexs: map[string]types.IndexMeta{ @@ -65,124 +60,9 @@ func initTest() { }, ColumnNames: []string{"id", "name", "age"}, }, - "table1": { - TableName: "table1", - Indexs: map[string]types.IndexMeta{ - "id": { - IType: types.IndexTypePrimaryKey, - Columns: []types.ColumnMeta{ - {ColumnName: "id"}, - }, - }, - }, - Columns: map[string]types.ColumnMeta{ - "id": { - ColumnDef: nil, - ColumnName: "id", - }, - "name": { - ColumnDef: nil, - ColumnName: "name", - }, - "age": { - ColumnDef: nil, - ColumnName: "age", - }, - }, - ColumnNames: []string{"id", "name", "age"}, - }, - "table2": { - TableName: "table2", - Indexs: map[string]types.IndexMeta{ - "id": { - IType: types.IndexTypePrimaryKey, - Columns: []types.ColumnMeta{ - {ColumnName: "id"}, - }, - }, - }, - Columns: map[string]types.ColumnMeta{ - "id": { - ColumnDef: nil, - ColumnName: "id", - }, - "name": { - ColumnDef: nil, - ColumnName: "name", - }, - "age": { - ColumnDef: nil, - ColumnName: "age", - }, - "kk": { - ColumnDef: nil, - ColumnName: "kk", - }, - "addr": { - ColumnDef: nil, - ColumnName: "addr", - }, - }, - ColumnNames: []string{"id", "name", "age", "kk", "addr"}, - }, - "table3": { - TableName: "table3", - Indexs: map[string]types.IndexMeta{ - "id": { - IType: types.IndexTypePrimaryKey, - Columns: []types.ColumnMeta{ - {ColumnName: "id"}, - }, - }, - }, - Columns: map[string]types.ColumnMeta{ - "id": { - ColumnDef: nil, - ColumnName: "id", - }, - "age": { - ColumnDef: nil, - ColumnName: "age", - }, - }, - ColumnNames: []string{"id", "age"}, - }, - "table4": { - TableName: "table4", - Indexs: map[string]types.IndexMeta{ - "id": { - IType: types.IndexTypePrimaryKey, - Columns: []types.ColumnMeta{ - {ColumnName: "id"}, - }, - }, - }, - Columns: map[string]types.ColumnMeta{ - "id": { - ColumnDef: nil, - ColumnName: "id", - }, - "age": { - ColumnDef: nil, - ColumnName: "age", - }, - }, - ColumnNames: []string{"id", "age"}, - }, } undo.InitUndoConfig(undo.Config{OnlyCareUpdateColumns: true}) -} - -func TestMain(m *testing.M) { - // 调用初始化函数 - initTest() - - // 启动测试 - os.Exit(m.Run()) -} - -func TestBuildSelectSQLByUpdate(t *testing.T) { tests := []struct { name string diff --git a/pkg/datasource/sql/exec/at/update_join_executor_test.go b/pkg/datasource/sql/exec/at/update_join_executor_test.go index 0f396ddb1..0ef30da3f 100644 --- a/pkg/datasource/sql/exec/at/update_join_executor_test.go +++ b/pkg/datasource/sql/exec/at/update_join_executor_test.go @@ -20,6 +20,7 @@ package at import ( "context" "database/sql/driver" + "seata.apache.org/seata-go/pkg/datasource/sql/undo" "testing" "github.com/stretchr/testify/assert" @@ -32,6 +33,115 @@ import ( ) func TestBuildSelectSQLByUpdateJoin(t *testing.T) { + MetaDataMap := map[string]*types.TableMeta{ + "table1": { + TableName: "table1", + Indexs: map[string]types.IndexMeta{ + "id": { + IType: types.IndexTypePrimaryKey, + Columns: []types.ColumnMeta{ + {ColumnName: "id"}, + }, + }, + }, + Columns: map[string]types.ColumnMeta{ + "id": { + ColumnDef: nil, + ColumnName: "id", + }, + "name": { + ColumnDef: nil, + ColumnName: "name", + }, + "age": { + ColumnDef: nil, + ColumnName: "age", + }, + }, + ColumnNames: []string{"id", "name", "age"}, + }, + "table2": { + TableName: "table2", + Indexs: map[string]types.IndexMeta{ + "id": { + IType: types.IndexTypePrimaryKey, + Columns: []types.ColumnMeta{ + {ColumnName: "id"}, + }, + }, + }, + Columns: map[string]types.ColumnMeta{ + "id": { + ColumnDef: nil, + ColumnName: "id", + }, + "name": { + ColumnDef: nil, + ColumnName: "name", + }, + "age": { + ColumnDef: nil, + ColumnName: "age", + }, + "kk": { + ColumnDef: nil, + ColumnName: "kk", + }, + "addr": { + ColumnDef: nil, + ColumnName: "addr", + }, + }, + ColumnNames: []string{"id", "name", "age", "kk", "addr"}, + }, + "table3": { + TableName: "table3", + Indexs: map[string]types.IndexMeta{ + "id": { + IType: types.IndexTypePrimaryKey, + Columns: []types.ColumnMeta{ + {ColumnName: "id"}, + }, + }, + }, + Columns: map[string]types.ColumnMeta{ + "id": { + ColumnDef: nil, + ColumnName: "id", + }, + "age": { + ColumnDef: nil, + ColumnName: "age", + }, + }, + ColumnNames: []string{"id", "age"}, + }, + "table4": { + TableName: "table4", + Indexs: map[string]types.IndexMeta{ + "id": { + IType: types.IndexTypePrimaryKey, + Columns: []types.ColumnMeta{ + {ColumnName: "id"}, + }, + }, + }, + Columns: map[string]types.ColumnMeta{ + "id": { + ColumnDef: nil, + ColumnName: "id", + }, + "age": { + ColumnDef: nil, + ColumnName: "age", + }, + }, + ColumnNames: []string{"id", "age"}, + }, + } + + undo.InitUndoConfig(undo.Config{OnlyCareUpdateColumns: true}) + tests := []struct { name string sourceQuery string From a5a1955debffad0fa487514446b441af66104072 Mon Sep 17 00:00:00 2001 From: lxfeng1997 <824141436@qq.com> Date: Fri, 10 Jan 2025 19:35:28 +0800 Subject: [PATCH 7/8] recover update test --- .../sql/exec/at/update_executor_test.go | 58 ++++++++----------- 1 file changed, 24 insertions(+), 34 deletions(-) diff --git a/pkg/datasource/sql/exec/at/update_executor_test.go b/pkg/datasource/sql/exec/at/update_executor_test.go index 34173aafb..55b2e1395 100644 --- a/pkg/datasource/sql/exec/at/update_executor_test.go +++ b/pkg/datasource/sql/exec/at/update_executor_test.go @@ -20,49 +20,39 @@ package at import ( "context" "database/sql/driver" - "seata.apache.org/seata-go/pkg/datasource/sql/undo" + "github.com/agiledragon/gomonkey/v2" + "reflect" "testing" "github.com/stretchr/testify/assert" + "seata.apache.org/seata-go/pkg/datasource/sql/datasource" + "seata.apache.org/seata-go/pkg/datasource/sql/datasource/mysql" "seata.apache.org/seata-go/pkg/datasource/sql/exec" "seata.apache.org/seata-go/pkg/datasource/sql/parser" "seata.apache.org/seata-go/pkg/datasource/sql/types" + "seata.apache.org/seata-go/pkg/datasource/sql/undo" "seata.apache.org/seata-go/pkg/datasource/sql/util" _ "seata.apache.org/seata-go/pkg/util/log" ) func TestBuildSelectSQLByUpdate(t *testing.T) { - MetaDataMap := map[string]*types.TableMeta{ - "t_user": { - TableName: "t_user", - Indexs: map[string]types.IndexMeta{ - "id": { - IType: types.IndexTypePrimaryKey, - Columns: []types.ColumnMeta{ - {ColumnName: "id"}, + undo.InitUndoConfig(undo.Config{OnlyCareUpdateColumns: true}) + datasource.RegisterTableCache(types.DBTypeMySQL, mysql.NewTableMetaInstance(nil, nil)) + stub := gomonkey.ApplyMethod(reflect.TypeOf(datasource.GetTableCache(types.DBTypeMySQL)), "GetTableMeta", + func(_ *mysql.TableMetaCache, ctx context.Context, dbName, tableName string) (*types.TableMeta, error) { + return &types.TableMeta{ + Indexs: map[string]types.IndexMeta{ + "id": { + IType: types.IndexTypePrimaryKey, + Columns: []types.ColumnMeta{ + {ColumnName: "id"}, + }, }, }, - }, - Columns: map[string]types.ColumnMeta{ - "id": { - ColumnDef: nil, - ColumnName: "id", - }, - "name": { - ColumnDef: nil, - ColumnName: "name", - }, - "age": { - ColumnDef: nil, - ColumnName: "age", - }, - }, - ColumnNames: []string{"id", "name", "age"}, - }, - } - - undo.InitUndoConfig(undo.Config{OnlyCareUpdateColumns: true}) + }, nil + }) + defer stub.Reset() tests := []struct { name string @@ -74,25 +64,25 @@ func TestBuildSelectSQLByUpdate(t *testing.T) { { sourceQuery: "update t_user set name = ?, age = ? where id = ?", sourceQueryArgs: []driver.Value{"Jack", 1, 100}, - expectQuery: "SELECT SQL_NO_CACHE name,age,t_user.id FROM t_user WHERE id=? FOR UPDATE", + expectQuery: "SELECT SQL_NO_CACHE name,age,id FROM t_user WHERE id=? FOR UPDATE", expectQueryArgs: []driver.Value{100}, }, { sourceQuery: "update t_user set name = ?, age = ? where id = ? and name = 'Jack' and age between ? and ?", sourceQueryArgs: []driver.Value{"Jack", 1, 100, 18, 28}, - expectQuery: "SELECT SQL_NO_CACHE name,age,t_user.id FROM t_user WHERE id=? AND name=_UTF8MB4Jack AND age BETWEEN ? AND ? FOR UPDATE", + expectQuery: "SELECT SQL_NO_CACHE name,age,id FROM t_user WHERE id=? AND name=_UTF8MB4Jack AND age BETWEEN ? AND ? FOR UPDATE", expectQueryArgs: []driver.Value{100, 18, 28}, }, { sourceQuery: "update t_user set name = ?, age = ? where id = ? and name = 'Jack' and age in (?,?)", sourceQueryArgs: []driver.Value{"Jack", 1, 100, 18, 28}, - expectQuery: "SELECT SQL_NO_CACHE name,age,t_user.id FROM t_user WHERE id=? AND name=_UTF8MB4Jack AND age IN (?,?) FOR UPDATE", + expectQuery: "SELECT SQL_NO_CACHE name,age,id FROM t_user WHERE id=? AND name=_UTF8MB4Jack AND age IN (?,?) FOR UPDATE", expectQueryArgs: []driver.Value{100, 18, 28}, }, { sourceQuery: "update t_user set name = ?, age = ? where kk between ? and ? and id = ? and addr in(?,?) and age > ? order by name desc limit ?", sourceQueryArgs: []driver.Value{"Jack", 1, 10, 20, 17, "Beijing", "Guangzhou", 18, 2}, - expectQuery: "SELECT SQL_NO_CACHE name,age,t_user.id FROM t_user WHERE kk BETWEEN ? AND ? AND id=? AND addr IN (?,?) AND age>? ORDER BY name DESC LIMIT ? FOR UPDATE", + expectQuery: "SELECT SQL_NO_CACHE name,age,id FROM t_user WHERE kk BETWEEN ? AND ? AND id=? AND addr IN (?,?) AND age>? ORDER BY name DESC LIMIT ? FOR UPDATE", expectQueryArgs: []driver.Value{10, 20, 17, "Beijing", "Guangzhou", 18, 2}, }, } @@ -101,7 +91,7 @@ func TestBuildSelectSQLByUpdate(t *testing.T) { c, err := parser.DoParser(tt.sourceQuery) assert.Nil(t, err) executor := NewUpdateExecutor(c, &types.ExecContext{Values: tt.sourceQueryArgs, NamedValues: util.ValueToNamedValue(tt.sourceQueryArgs)}, []exec.SQLHook{}) - query, args, err := executor.(*updateExecutor).buildBeforeImageSQL(context.Background(), MetaDataMap["t_user"], "", util.ValueToNamedValue(tt.sourceQueryArgs)) + query, args, err := executor.(*updateExecutor).buildBeforeImageSQL(context.Background(), util.ValueToNamedValue(tt.sourceQueryArgs)) assert.Nil(t, err) assert.Equal(t, tt.expectQuery, query) assert.Equal(t, tt.expectQueryArgs, util.NamedValueToValue(args)) From adabe14c7b7357e3602f065a98b4cf8b31968e6f Mon Sep 17 00:00:00 2001 From: lxfeng1997 <824141436@qq.com> Date: Fri, 10 Jan 2025 19:36:50 +0800 Subject: [PATCH 8/8] recover update test --- pkg/datasource/sql/exec/at/update_executor_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/datasource/sql/exec/at/update_executor_test.go b/pkg/datasource/sql/exec/at/update_executor_test.go index 55b2e1395..a6ffc9bef 100644 --- a/pkg/datasource/sql/exec/at/update_executor_test.go +++ b/pkg/datasource/sql/exec/at/update_executor_test.go @@ -20,10 +20,10 @@ package at import ( "context" "database/sql/driver" - "github.com/agiledragon/gomonkey/v2" "reflect" "testing" + "github.com/agiledragon/gomonkey/v2" "github.com/stretchr/testify/assert" "seata.apache.org/seata-go/pkg/datasource/sql/datasource"