-
Notifications
You must be signed in to change notification settings - Fork 286
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
bugfix: error image when use null value as image query condition in insert on duplicate #704 #725
base: master
Are you sure you want to change the base?
Changes from all commits
abdd938
4debbde
42fb93a
56f2645
c52fd82
321906e
b60e020
5675b29
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -97,68 +97,108 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildBeforeImageSQL(insertStmt *a | |
if err := checkDuplicateKeyUpdate(insertStmt, metaData); err != nil { | ||
return "", nil, err | ||
} | ||
var selectArgs []driver.Value | ||
u.BeforeImageSqlPrimaryKeys = make(map[string]bool, len(metaData.Indexs)) | ||
pkIndexMap := u.getPkIndex(insertStmt, metaData) | ||
var pkIndexArray []int | ||
for _, val := range pkIndexMap { | ||
tmpVal := val | ||
pkIndexArray = append(pkIndexArray, tmpVal) | ||
pkIndexArray = append(pkIndexArray, val) | ||
} | ||
insertRows, err := getInsertRows(insertStmt, pkIndexArray) | ||
if err != nil { | ||
return "", nil, err | ||
} | ||
insertNum := len(insertRows) | ||
paramMap, err := u.buildImageParameters(insertStmt, args, insertRows) | ||
if err != nil { | ||
return "", nil, err | ||
} | ||
|
||
sql := strings.Builder{} | ||
sql.WriteString("SELECT * FROM " + metaData.TableName + " ") | ||
if len(paramMap) == 0 || len(metaData.Indexs) == 0 { | ||
return "", nil, nil | ||
} | ||
hasPK := false | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
goto 不太好理解? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
那这里还需要改为goto吗 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 不要使用 goto 吧 |
||
for _, index := range metaData.Indexs { | ||
if strings.EqualFold("PRIMARY", index.Name) { | ||
allPKColumnsHaveValue := true | ||
for _, col := range index.Columns { | ||
if params, ok := paramMap[col.ColumnName]; !ok || len(params) == 0 || params[0] == nil { | ||
allPKColumnsHaveValue = false | ||
break | ||
} | ||
} | ||
hasPK = allPKColumnsHaveValue | ||
break | ||
} | ||
} | ||
if !hasPK { | ||
hasValidUniqueIndex := false | ||
for _, index := range metaData.Indexs { | ||
if !index.NonUnique && !strings.EqualFold("PRIMARY", index.Name) { | ||
if _, _, valid := validateIndexPrefix(index, paramMap, 0); valid { | ||
hasValidUniqueIndex = true | ||
break | ||
} | ||
} | ||
} | ||
if !hasValidUniqueIndex { | ||
return "", nil, nil | ||
} | ||
} | ||
var sql strings.Builder | ||
sql.WriteString("SELECT * FROM " + metaData.TableName + " ") | ||
var selectArgs []driver.Value | ||
isContainWhere := false | ||
for i := 0; i < insertNum; i++ { | ||
finalI := i | ||
paramAppenderTempList := make([]driver.Value, 0) | ||
hasConditions := false | ||
for i := 0; i < len(insertRows); i++ { | ||
var rowConditions []string | ||
var rowArgs []driver.Value | ||
usedParams := make(map[string]bool) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add cap |
||
|
||
// First try unique indexes | ||
for _, index := range metaData.Indexs { | ||
//unique index | ||
if index.NonUnique || isIndexValueNotNull(index, paramMap, finalI) == false { | ||
if index.NonUnique || strings.EqualFold("PRIMARY", index.Name) { | ||
continue | ||
} | ||
columnIsNull := true | ||
uniqueList := make([]string, 0) | ||
for _, columnMeta := range index.Columns { | ||
columnName := columnMeta.ColumnName | ||
imageParameters, ok := paramMap[columnName] | ||
if !ok && columnMeta.ColumnDef != nil { | ||
if strings.EqualFold("PRIMARY", index.Name) { | ||
u.BeforeImageSqlPrimaryKeys[columnName] = true | ||
} | ||
uniqueList = append(uniqueList, columnName+" = DEFAULT("+columnName+") ") | ||
columnIsNull = false | ||
continue | ||
} | ||
if strings.EqualFold("PRIMARY", index.Name) { | ||
u.BeforeImageSqlPrimaryKeys[columnName] = true | ||
if conditions, args, valid := validateIndexPrefix(index, paramMap, i); valid { | ||
rowConditions = append(rowConditions, "("+strings.Join(conditions, " and ")+")") | ||
rowArgs = append(rowArgs, args...) | ||
hasConditions = true | ||
for _, colMeta := range index.Columns { | ||
usedParams[colMeta.ColumnName] = true | ||
} | ||
columnIsNull = false | ||
uniqueList = append(uniqueList, columnName+" = ? ") | ||
paramAppenderTempList = append(paramAppenderTempList, imageParameters[finalI]) | ||
} | ||
} | ||
|
||
if !columnIsNull { | ||
if isContainWhere { | ||
sql.WriteString(" OR (" + strings.Join(uniqueList, " and ") + ") ") | ||
} else { | ||
sql.WriteString(" WHERE (" + strings.Join(uniqueList, " and ") + ") ") | ||
isContainWhere = true | ||
// Then try primary key | ||
for _, index := range metaData.Indexs { | ||
if !strings.EqualFold("PRIMARY", index.Name) { | ||
continue | ||
} | ||
if conditions, args, valid := validateIndexPrefix(index, paramMap, i); valid { | ||
rowConditions = append(rowConditions, "("+strings.Join(conditions, " and ")+")") | ||
rowArgs = append(rowArgs, args...) | ||
hasConditions = true | ||
for _, colMeta := range index.Columns { | ||
usedParams[colMeta.ColumnName] = true | ||
} | ||
} | ||
} | ||
selectArgs = append(selectArgs, paramAppenderTempList...) | ||
|
||
if len(rowConditions) > 0 { | ||
if !isContainWhere { | ||
sql.WriteString("WHERE ") | ||
isContainWhere = true | ||
} else { | ||
sql.WriteString(" OR ") | ||
} | ||
sql.WriteString(strings.Join(rowConditions, " OR ") + " ") | ||
selectArgs = append(selectArgs, rowArgs...) | ||
} | ||
} | ||
if !hasConditions { | ||
return "", nil, nil | ||
} | ||
log.Infof("build select sql by insert on duplicate sourceQuery, sql {}", sql.String()) | ||
return sql.String(), selectArgs, nil | ||
sqlStr := sql.String() | ||
log.Infof("build select sql by insert on duplicate sourceQuery, sql: %s", sqlStr) | ||
return sqlStr, selectArgs, nil | ||
} | ||
|
||
func (u *MySQLInsertOnDuplicateUndoLogBuilder) AfterImage(ctx context.Context, execCtx *types.ExecContext, beforeImages []*types.RecordImage) ([]*types.RecordImage, error) { | ||
|
@@ -168,14 +208,14 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) AfterImage(ctx context.Context, e | |
log.Errorf("build prepare stmt: %+v", err) | ||
return nil, err | ||
} | ||
|
||
defer stmt.Close() | ||
tableName := execCtx.ParseContext.InsertStmt.Table.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName).Name.O | ||
metaData := execCtx.MetaDataMap[tableName] | ||
rows, err := stmt.Query(selectArgs) | ||
if err != nil { | ||
log.Errorf("stmt query: %+v", err) | ||
return nil, err | ||
} | ||
tableName := execCtx.ParseContext.InsertStmt.Table.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName).Name.O | ||
metaData := execCtx.MetaDataMap[tableName] | ||
defer rows.Close() | ||
image, err := u.buildRecordImages(rows, &metaData) | ||
if err != nil { | ||
return nil, err | ||
|
@@ -185,11 +225,13 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) AfterImage(ctx context.Context, e | |
|
||
func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildAfterImageSQL(ctx context.Context, beforeImages []*types.RecordImage) (string, []driver.Value) { | ||
selectSQL, selectArgs := u.BeforeSelectSql, u.Args | ||
|
||
var beforeImage *types.RecordImage | ||
if len(beforeImages) > 0 { | ||
beforeImage = beforeImages[0] | ||
} | ||
if beforeImage == nil || len(beforeImage.Rows) == 0 { | ||
return selectSQL, selectArgs | ||
} | ||
primaryValueMap := make(map[string][]interface{}) | ||
for _, row := range beforeImage.Rows { | ||
for _, col := range row.Columns { | ||
|
@@ -198,25 +240,46 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildAfterImageSQL(ctx context.Co | |
} | ||
} | ||
} | ||
|
||
var afterImageSql strings.Builder | ||
var primaryValues []driver.Value | ||
afterImageSql.WriteString(selectSQL) | ||
for i := 0; i < len(beforeImage.Rows); i++ { | ||
wherePrimaryList := make([]string, 0) | ||
for name, value := range primaryValueMap { | ||
if !u.BeforeImageSqlPrimaryKeys[name] { | ||
wherePrimaryList = append(wherePrimaryList, name+" = ? ") | ||
primaryValues = append(primaryValues, value[i]) | ||
if len(primaryValueMap) == 0 || len(selectArgs) == len(beforeImage.Rows)*len(primaryValueMap) { | ||
return selectSQL, selectArgs | ||
} | ||
var primaryValues []driver.Value | ||
usedPrimaryKeys := make(map[string]bool) | ||
for name := range primaryValueMap { | ||
if !u.BeforeImageSqlPrimaryKeys[name] { | ||
usedPrimaryKeys[name] = true | ||
for i := 0; i < len(beforeImage.Rows); i++ { | ||
if value := primaryValueMap[name][i]; value != nil { | ||
if dv, ok := value.(driver.Value); ok { | ||
primaryValues = append(primaryValues, dv) | ||
} else { | ||
primaryValues = append(primaryValues, value) | ||
} | ||
} | ||
} | ||
} | ||
if len(wherePrimaryList) != 0 { | ||
afterImageSql.WriteString(" OR (" + strings.Join(wherePrimaryList, " and ") + ") ") | ||
} | ||
if len(primaryValues) > 0 { | ||
afterImageSql.WriteString(" OR (" + strings.Join(u.buildPrimaryKeyConditions(primaryValueMap, usedPrimaryKeys), " and ") + ") ") | ||
} | ||
finalArgs := make([]driver.Value, len(selectArgs)+len(primaryValues)) | ||
copy(finalArgs, selectArgs) | ||
copy(finalArgs[len(selectArgs):], primaryValues) | ||
sqlStr := afterImageSql.String() | ||
log.Infof("build after select sql by insert on duplicate sourceQuery, sql %s", sqlStr) | ||
return sqlStr, finalArgs | ||
} | ||
|
||
func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildPrimaryKeyConditions(primaryValueMap map[string][]interface{}, usedPrimaryKeys map[string]bool) []string { | ||
var conditions []string | ||
for name := range primaryValueMap { | ||
if !usedPrimaryKeys[name] { | ||
conditions = append(conditions, name+" = ? ") | ||
} | ||
} | ||
selectArgs = append(selectArgs, primaryValues...) | ||
log.Infof("build after select sql by insert on duplicate sourceQuery, sql {}", afterImageSql.String()) | ||
return afterImageSql.String(), selectArgs | ||
return conditions | ||
} | ||
|
||
func checkDuplicateKeyUpdate(insert *ast.InsertStmt, metaData types.TableMeta) error { | ||
|
@@ -243,11 +306,10 @@ func checkDuplicateKeyUpdate(insert *ast.InsertStmt, metaData types.TableMeta) e | |
|
||
// build sql params | ||
func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildImageParameters(insert *ast.InsertStmt, args []driver.Value, insertRows [][]interface{}) (map[string][]driver.Value, error) { | ||
var ( | ||
parameterMap = make(map[string][]driver.Value) | ||
) | ||
parameterMap := make(map[string][]driver.Value) | ||
insertColumns := getInsertColumns(insert) | ||
var placeHolderIndex = 0 | ||
placeHolderIndex := 0 | ||
|
||
for _, row := range insertRows { | ||
if len(row) != len(insertColumns) { | ||
log.Errorf("insert row's column size not equal to insert column size") | ||
|
@@ -256,13 +318,14 @@ func (u *MySQLInsertOnDuplicateUndoLogBuilder) buildImageParameters(insert *ast. | |
for i, col := range insertColumns { | ||
columnName := executor.DelEscape(col, types.DBTypeMySQL) | ||
val := row[i] | ||
rStr, ok := val.(string) | ||
if ok && strings.EqualFold(rStr, SqlPlaceholder) { | ||
objects := args[placeHolderIndex] | ||
parameterMap[columnName] = append(parameterMap[col], objects) | ||
if str, ok := val.(string); ok && strings.EqualFold(str, SqlPlaceholder) { | ||
if placeHolderIndex >= len(args) { | ||
return nil, fmt.Errorf("not enough parameters for placeholders") | ||
} | ||
parameterMap[columnName] = append(parameterMap[columnName], args[placeHolderIndex]) | ||
placeHolderIndex++ | ||
} else { | ||
parameterMap[columnName] = append(parameterMap[col], val) | ||
parameterMap[columnName] = append(parameterMap[columnName], val) | ||
} | ||
} | ||
} | ||
|
@@ -296,3 +359,28 @@ func isIndexValueNotNull(indexMeta types.IndexMeta, imageParameterMap map[string | |
} | ||
return true | ||
} | ||
|
||
func validateIndexPrefix(index types.IndexMeta, paramMap map[string][]driver.Value, rowIndex int) ([]string, []driver.Value, bool) { | ||
var indexConditions []string | ||
var indexArgs []driver.Value | ||
if len(index.Columns) > 1 { | ||
for _, colMeta := range index.Columns { | ||
params, ok := paramMap[colMeta.ColumnName] | ||
if !ok || len(params) <= rowIndex || params[rowIndex] == nil { | ||
return nil, nil, false | ||
} | ||
} | ||
} | ||
for _, colMeta := range index.Columns { | ||
columnName := colMeta.ColumnName | ||
params, ok := paramMap[columnName] | ||
if ok && len(params) > rowIndex && params[rowIndex] != nil { | ||
indexConditions = append(indexConditions, columnName+" = ? ") | ||
indexArgs = append(indexArgs, params[rowIndex]) | ||
} | ||
} | ||
if len(indexConditions) != len(index.Columns) { | ||
return nil, nil, false | ||
} | ||
return indexConditions, indexArgs, true | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use make and cap