support updating primary/unique key columns in new DML
aunjgr committed Jan 10, 2025
1 parent 1c3a6e1 commit cbe5c18
Showing 32 changed files with 2,641 additions and 1,768 deletions.
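For context — an illustrative sketch inferred from the commit title and the plan.Node_UPDATE handling in the dedup join below, not code taken from this repository — the feature presumably covers upserts whose ON DUPLICATE KEY UPDATE clause rewrites a primary-key or unique-key column itself. A minimal client-side example over the MySQL wire protocol (table, credentials, and port are assumptions):

package main

import (
    "database/sql"
    "log"

    _ "github.com/go-sql-driver/mysql"
)

func main() {
    // Hypothetical DSN; 6001 is assumed to be the server's MySQL-compatible port.
    db, err := sql.Open("mysql", "root:111@tcp(127.0.0.1:6001)/test")
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()

    // Hypothetical table: t(a INT PRIMARY KEY, b INT UNIQUE).
    // The update rewrites the primary-key column a itself, the case this
    // commit's dedup-join changes appear to target.
    if _, err := db.Exec(
        "INSERT INTO t VALUES (1, 10) ON DUPLICATE KEY UPDATE a = a + 100",
    ); err != nil {
        log.Fatal(err)
    }
}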
1,573 changes: 916 additions & 657 deletions pkg/pb/pipeline/pipeline.pb.go

Large diffs are not rendered by default.

1,832 changes: 991 additions & 841 deletions pkg/pb/plan/plan.pb.go

Large diffs are not rendered by default.

137 changes: 100 additions & 37 deletions pkg/sql/colexec/dedupjoin/join.go
@@ -166,7 +166,7 @@ func (dedupJoin *DedupJoin) build(analyzer process.Analyzer, proc *process.Proce
if dedupJoin.OnDuplicateAction != plan.Node_UPDATE {
ctr.matched.InitWithSize(ctr.batchRowCount)
} else {
ctr.matched.InitWithSize(int64(ctr.mp.GetGroupCount()) + 1)
ctr.matched.InitWithSize(int64(ctr.mp.GetGroupCount()))
}
}
return
@@ -196,10 +196,27 @@ func (ctr *container) finalize(ap *DedupJoin, proc *process.Process) error {
}
}

if ap.OnDuplicateAction != plan.Node_UPDATE {
if ap.OnDuplicateAction != plan.Node_UPDATE || ctr.mp.HashOnUnique() {
if ctr.matched.Count() == 0 {
ap.ctr.buf = ctr.batches
ctr.batches = nil
//ap.ctr.buf = ctr.batches
ap.ctr.buf = make([]*batch.Batch, len(ctr.batches))
for i := range ap.ctr.buf {
ap.ctr.buf[i] = batch.NewWithSize(len(ap.Result))
batSize := ctr.batches[i].Vecs[0].Length()
for j, rp := range ap.Result {
if rp.Rel == 1 {
ap.ctr.buf[i].SetVector(int32(j), ctr.batches[i].Vecs[rp.Pos])
ctr.batches[i].Vecs[rp.Pos] = nil
} else {
ap.ctr.buf[i].Vecs[j] = vector.NewVec(ap.LeftTypes[rp.Pos])
if err := vector.AppendMultiFixed(ap.ctr.buf[i].Vecs[j], 0, true, batSize, proc.Mp()); err != nil {
return err
}
}
}

ap.ctr.buf[i].SetRowCount(batSize)
}

return nil
}
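Aside (not part of the diff): in the no-match branch above the build batches are no longer handed through unchanged; each output batch is rebuilt column by column, with Rel == 1 columns taken from the build side and the left-side columns padded with NULLs via AppendMultiFixed. A simplified stand-in, using toy string slices in place of vectors:

package main

import "fmt"

// colRef mirrors the role of ap.Result entries: Rel == 1 selects a build-side
// column at Pos, anything else stands for a probe/left-side column that has
// nothing to join against and is therefore NULL-filled.
type colRef struct {
    Rel int
    Pos int
}

func buildOutput(result []colRef, buildCols [][]string, rows int) [][]string {
    out := make([][]string, len(result))
    for j, rp := range result {
        if rp.Rel == 1 {
            out[j] = buildCols[rp.Pos] // take the build-side column as-is
        } else {
            nulls := make([]string, rows) // stand-in for AppendMultiFixed(..., true, rows, ...)
            for i := range nulls {
                nulls[i] = "NULL"
            }
            out[j] = nulls
        }
    }
    return out
}

func main() {
    buildCols := [][]string{{"k1", "k2"}, {"v1", "v2"}}
    result := []colRef{{Rel: 0, Pos: 0}, {Rel: 1, Pos: 0}, {Rel: 1, Pos: 1}}
    fmt.Println(buildOutput(result, buildCols, 2)) // [[NULL NULL] [k1 k2] [v1 v2]]
}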
@@ -295,7 +312,7 @@ func (ctr *container) finalize(ap *DedupJoin, proc *process.Process) error {
ctr.joinBat1, ctr.cfs1 = colexec.NewJoinBatch(ctr.batches[0], proc.Mp())

bitmapLen := uint64(ctr.matched.Len())
for i := uint64(1); i < bitmapLen; i++ {
for i := uint64(0); i < bitmapLen; i++ {
if ctr.matched.Contains(i) {
continue
}
@@ -311,41 +328,59 @@ func (ctr *container) finalize(ap *DedupJoin, proc *process.Process) error {
}
}

sels = ctr.mp.GetSels(i)
sels = ctr.mp.GetSels(i + 1)
idx1, idx2 := sels[0]/colexec.DefaultBatchSize, sels[0]%colexec.DefaultBatchSize
err := colexec.SetJoinBatchValues(ctr.joinBat1, ctr.batches[idx1], int64(idx2), 1, ctr.cfs1)
if err != nil {
return err
}

for _, sel := range sels[1:] {
idx1, idx2 = sel/colexec.DefaultBatchSize, sel%colexec.DefaultBatchSize
err = colexec.SetJoinBatchValues(ctr.joinBat2, ctr.batches[idx1], int64(idx2), 1, ctr.cfs2)
if len(sels) == 1 {
for j, rp := range ap.Result {
if rp.Rel == 1 {
if err := ap.ctr.buf[batIdx].Vecs[j].UnionOne(ctr.batches[idx1].Vecs[rp.Pos], int64(idx2), proc.Mp()); err != nil {
return err
}
} else {
if err := ap.ctr.buf[batIdx].Vecs[j].UnionNull(proc.Mp()); err != nil {
return err
}
}
}
} else {
err := colexec.SetJoinBatchValues(ctr.joinBat1, ctr.batches[idx1], int64(idx2), 1, ctr.cfs1)
if err != nil {
return err
}

vecs := make([]*vector.Vector, len(ctr.exprExecs))
for j, exprExec := range ctr.exprExecs {
vecs[j], err = exprExec.Eval(proc, []*batch.Batch{ctr.joinBat1, ctr.joinBat2}, nil)
if ctr.joinBat2 == nil {
ctr.joinBat2, ctr.cfs2 = colexec.NewJoinBatch(ctr.batches[0], proc.Mp())
}

for _, sel := range sels[1:] {
idx1, idx2 = sel/colexec.DefaultBatchSize, sel%colexec.DefaultBatchSize
err = colexec.SetJoinBatchValues(ctr.joinBat2, ctr.batches[idx1], int64(idx2), 1, ctr.cfs2)
if err != nil {
return err
}
}

for j, pos := range ap.UpdateColIdxList {
ctr.joinBat1.Vecs[pos] = vecs[j]
}
}
vecs := make([]*vector.Vector, len(ctr.exprExecs))
for j, exprExec := range ctr.exprExecs {
vecs[j], err = exprExec.Eval(proc, []*batch.Batch{ctr.joinBat1, ctr.joinBat2}, nil)
if err != nil {
return err
}
}

for j, rp := range ap.Result {
if rp.Rel == 1 {
if err := ap.ctr.buf[batIdx].Vecs[j].UnionOne(ctr.joinBat1.Vecs[rp.Pos], 0, proc.Mp()); err != nil {
return err
for j, pos := range ap.UpdateColIdxList {
ctr.joinBat1.Vecs[pos] = vecs[j]
}
} else {
if err := ap.ctr.buf[batIdx].Vecs[j].UnionNull(proc.Mp()); err != nil {
return err
}

for j, rp := range ap.Result {
if rp.Rel == 1 {
if err := ap.ctr.buf[batIdx].Vecs[j].UnionOne(ctr.joinBat1.Vecs[rp.Pos], 0, proc.Mp()); err != nil {
return err
}
} else {
if err := ap.ctr.buf[batIdx].Vecs[j].UnionNull(proc.Mp()); err != nil {
return err
}
}
}
}
@@ -369,11 +404,14 @@ func (ctr *container) probe(bat *batch.Batch, ap *DedupJoin, proc *process.Proce
if err != nil {
return err
}
if ctr.joinBat1 == nil {
ctr.joinBat1, ctr.cfs1 = colexec.NewJoinBatch(bat, proc.Mp())
}
if ctr.joinBat2 == nil && ctr.batchRowCount > 0 {
ctr.joinBat2, ctr.cfs2 = colexec.NewJoinBatch(ctr.batches[0], proc.Mp())

if ap.OnDuplicateAction == plan.Node_UPDATE {
if ctr.joinBat1 == nil {
ctr.joinBat1, ctr.cfs1 = colexec.NewJoinBatch(bat, proc.Mp())
}
if ctr.joinBat2 == nil && ctr.batchRowCount > 0 {
ctr.joinBat2, ctr.cfs2 = colexec.NewJoinBatch(ctr.batches[0], proc.Mp())
}
}

rowCntInc := 0
@@ -393,6 +431,10 @@ func (ctr *container) probe(bat *batch.Batch, ap *DedupJoin, proc *process.Proce

switch ap.OnDuplicateAction {
case plan.Node_FAIL:
if ctr.mp.IsDeleted(vals[k] - 1) {
continue
}

// do nothing for txn.mode = Optimistic
if !isPessimistic {
continue
@@ -430,8 +472,8 @@ func (ctr *container) probe(bat *batch.Batch, ap *DedupJoin, proc *process.Proce
return err
}

sels := ctr.mp.GetSels(vals[k])
for _, sel := range sels {
if ctr.mp.HashOnUnique() {
sel := vals[k] - 1
idx1, idx2 := sel/colexec.DefaultBatchSize, sel%colexec.DefaultBatchSize
err = colexec.SetJoinBatchValues(ctr.joinBat2, ctr.batches[idx1], int64(idx2), 1, ctr.cfs2)
if err != nil {
@@ -449,6 +491,27 @@ func (ctr *container) probe(bat *batch.Batch, ap *DedupJoin, proc *process.Proce
for j, pos := range ap.UpdateColIdxList {
ctr.joinBat1.Vecs[pos] = vecs[j]
}
} else {
sels := ctr.mp.GetSels(vals[k])
for _, sel := range sels {
idx1, idx2 := sel/colexec.DefaultBatchSize, sel%colexec.DefaultBatchSize
err = colexec.SetJoinBatchValues(ctr.joinBat2, ctr.batches[idx1], int64(idx2), 1, ctr.cfs2)
if err != nil {
return err
}

vecs := make([]*vector.Vector, len(ctr.exprExecs))
for j, exprExec := range ctr.exprExecs {
vecs[j], err = exprExec.Eval(proc, []*batch.Batch{ctr.joinBat1, ctr.joinBat2}, nil)
if err != nil {
return err
}
}

for j, pos := range ap.UpdateColIdxList {
ctr.joinBat1.Vecs[pos] = vecs[j]
}
}
}

for j, rp := range ap.Result {
@@ -471,7 +534,7 @@ func (ctr *container) probe(bat *batch.Batch, ap *DedupJoin, proc *process.Proce
}
}

ctr.matched.Add(vals[k])
ctr.matched.Add(vals[k] - 1)
rowCntInc++
}
}
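A recurring detail in this file's changes is the indexing convention (reconstructed from the changed lines; the real joinmap API may differ): the hash map's lookup values are 1-based, presumably with 0 meaning "no match", while the matched bitmap, IsDeleted, and the size given by GetGroupCount() are 0-based — hence GetSels(i + 1), matched.Add(vals[k] - 1), and IsDeleted(vals[k] - 1). When HashOnUnique() holds, each group has exactly one build row, so that row can be taken directly as vals[k] - 1 without consulting the sels table. A toy model of that convention:

package main

import "fmt"

// toyJoinMap is a stand-in for the real hashmap builder output, kept only to
// show the 1-based lookup value vs. 0-based group index convention.
type toyJoinMap struct {
    groups  map[string]int64 // key -> 1-based group id (missing key -> 0)
    sels    [][]int32        // sels[id-1] -> build-side row numbers of that group
    deleted []bool           // deleted[id-1] -> group's row erased by this statement
}

func (m *toyJoinMap) Find(key string) int64     { return m.groups[key] }
func (m *toyJoinMap) GetSels(val int64) []int32 { return m.sels[val-1] }  // takes the 1-based id
func (m *toyJoinMap) IsDeleted(idx int64) bool  { return m.deleted[idx] } // takes the 0-based id
func (m *toyJoinMap) GroupCount() int           { return len(m.sels) }

func main() {
    m := &toyJoinMap{
        groups:  map[string]int64{"pk=1": 1, "pk=2": 2},
        sels:    [][]int32{{0}, {1, 3}},
        deleted: []bool{false, true},
    }
    matched := make([]bool, m.GroupCount()) // InitWithSize(GetGroupCount()), 0-based

    for _, key := range []string{"pk=1", "pk=2", "pk=9"} {
        val := m.Find(key)
        if val == 0 {
            fmt.Println(key, "-> no conflict, plain insert")
            continue
        }
        if m.IsDeleted(val - 1) { // mirrors ctr.mp.IsDeleted(vals[k] - 1)
            fmt.Println(key, "-> conflicting row already deleted, no conflict")
            continue
        }
        matched[val-1] = true // mirrors ctr.matched.Add(vals[k] - 1)
        fmt.Println(key, "-> update build rows", m.GetSels(val))
    }
    fmt.Println("matched:", matched)
}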
1 change: 1 addition & 0 deletions pkg/sql/colexec/dedupjoin/join_test.go
@@ -89,6 +89,7 @@ func TestDedupJoin(t *testing.T) {
err := tc.arg.Prepare(tc.proc)
require.NoError(t, err)
tc.barg.IsDedup = true
tc.barg.DelColIdx = -1
err = tc.barg.Prepare(tc.proc)
require.NoError(t, err)

1 change: 1 addition & 0 deletions pkg/sql/colexec/dedupjoin/types.go
@@ -88,6 +88,7 @@ type DedupJoin struct {
OnDuplicateAction plan.Node_OnDuplicateAction
DedupColName string
DedupColTypes []plan.Type
DelColIdx int32
UpdateColIdxList []int32
UpdateColExprList []*plan.Expr

19 changes: 10 additions & 9 deletions pkg/sql/colexec/hashbuild/build.go
@@ -42,14 +42,16 @@ func (hashBuild *HashBuild) Prepare(proc *process.Process) (err error) {
hashBuild.OpAnalyzer.Reset()
}

if hashBuild.NeedHashMap {
hashBuild.ctr.hashmapBuilder.IsDedup = hashBuild.IsDedup
hashBuild.ctr.hashmapBuilder.OnDuplicateAction = hashBuild.OnDuplicateAction
hashBuild.ctr.hashmapBuilder.DedupColName = hashBuild.DedupColName
hashBuild.ctr.hashmapBuilder.DedupColTypes = hashBuild.DedupColTypes
return hashBuild.ctr.hashmapBuilder.Prepare(hashBuild.Conditions, proc)
if !hashBuild.NeedHashMap {
return nil
}
return nil

hashBuild.ctr.hashmapBuilder.IsDedup = hashBuild.IsDedup
hashBuild.ctr.hashmapBuilder.OnDuplicateAction = hashBuild.OnDuplicateAction
hashBuild.ctr.hashmapBuilder.DedupColName = hashBuild.DedupColName
hashBuild.ctr.hashmapBuilder.DedupColTypes = hashBuild.DedupColTypes

return hashBuild.ctr.hashmapBuilder.Prepare(hashBuild.Conditions, hashBuild.DelColIdx, proc)
}

func (hashBuild *HashBuild) Call(proc *process.Process) (vm.CallResult, error) {
@@ -75,9 +77,8 @@ func (hashBuild *HashBuild) Call(proc *process.Process) (vm.CallResult, error) {
case SendJoinMap:
var jm *message.JoinMap
if ctr.hashmapBuilder.InputBatchRowCount > 0 {
jm = message.NewJoinMap(ctr.hashmapBuilder.MultiSels, ctr.hashmapBuilder.IntHashMap, ctr.hashmapBuilder.StrHashMap, ctr.hashmapBuilder.Batches.Buf, proc.Mp())
jm = message.NewJoinMap(ctr.hashmapBuilder.MultiSels, ctr.hashmapBuilder.IntHashMap, ctr.hashmapBuilder.StrHashMap, ctr.hashmapBuilder.DelRows, ctr.hashmapBuilder.Batches.Buf, proc.Mp())
jm.SetPushedRuntimeFilterIn(ctr.runtimeFilterIn)
//jm.SetIgnoreRows(ctr.hashmapBuilder.IgnoreRows)
if ap.NeedBatches {
jm.SetRowCount(int64(ctr.hashmapBuilder.InputBatchRowCount))
}
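The DelColIdx / DelRows plumbing above suggests — an inference from the names, not something the diff states — that the build side can carry a delete-marker column; rows flagged there are collected into DelRows and shipped inside the JoinMap, so the probe side's IsDeleted check can ignore conflicts against rows already deleted by the same statement. A minimal sketch of that collection step under those assumptions, with DelColIdx = -1 presumably meaning "no delete column" as set in the updated dedupjoin test:

package main

import "fmt"

// collectDelRows flattens per-batch delete flags into a set of global build-row
// indices. The [][]bool input is a stand-in for reading column DelColIdx out of
// each build batch; the real builder works on vectors, not bool slices.
func collectDelRows(delColIdx int32, delFlags [][]bool) map[uint64]struct{} {
    if delColIdx < 0 {
        return nil // no delete column configured
    }
    delRows := make(map[uint64]struct{})
    row := uint64(0)
    for _, batchFlags := range delFlags {
        for _, deleted := range batchFlags {
            if deleted {
                delRows[row] = struct{}{}
            }
            row++
        }
    }
    return delRows
}

func main() {
    flags := [][]bool{{false, true}, {false, false, true}}
    fmt.Println(collectDelRows(0, flags))  // map[1:{} 4:{}]
    fmt.Println(collectDelRows(-1, flags)) // map[]: dedup join sees no deletions
}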
1 change: 1 addition & 0 deletions pkg/sql/colexec/hashbuild/types.go
@@ -51,6 +51,7 @@ type HashBuild struct {
RuntimeFilterSpec *plan.RuntimeFilterSpec

IsDedup bool
DelColIdx int32
OnDuplicateAction plan.Node_OnDuplicateAction
DedupColName string
DedupColTypes []plan.Type