diff --git a/br/pkg/lightning/backend/local/duplicate.go b/br/pkg/lightning/backend/local/duplicate.go index 6b0c9f9d66978..daeacdcf19408 100644 --- a/br/pkg/lightning/backend/local/duplicate.go +++ b/br/pkg/lightning/backend/local/duplicate.go @@ -29,11 +29,12 @@ import ( "github.com/pingcap/kvproto/pkg/errorpb" "github.com/pingcap/kvproto/pkg/import_sstpb" "github.com/pingcap/kvproto/pkg/kvrpcpb" + berrors "github.com/pingcap/tidb/br/pkg/errors" pkgkv "github.com/pingcap/tidb/br/pkg/kv" "github.com/pingcap/tidb/br/pkg/lightning/backend/kv" "github.com/pingcap/tidb/br/pkg/lightning/common" "github.com/pingcap/tidb/br/pkg/lightning/errormanager" "github.com/pingcap/tidb/br/pkg/lightning/log" "github.com/pingcap/tidb/br/pkg/logutil" "github.com/pingcap/tidb/br/pkg/restore" "github.com/pingcap/tidb/br/pkg/utils" @@ -335,7 +339,8 @@ func getDupDetectClient( ) (import_sstpb.ImportSST_DuplicateDetectClient, error) { leader := region.Leader if leader == nil { - leader = region.Region.GetPeers()[0] + return nil, errors.Annotatef(berrors.ErrPDLeaderNotFound, + "region id %d has no leader", region.Region.Id) } importClient, err := importClientFactory.Create(ctx, leader.GetStoreId()) if err != nil { diff --git a/br/pkg/restore/import.go b/br/pkg/restore/import.go index 556596bd4fdd0..c6f0aa358f75e 100644 --- a/br/pkg/restore/import.go +++ b/br/pkg/restore/import.go @@ -605,7 +605,8 @@ func (importer *FileImporter) ingestSSTs( ) (*import_sstpb.IngestResponse, error) { leader := regionInfo.Leader if leader == nil { - leader = regionInfo.Region.GetPeers()[0] + return nil, errors.Annotatef(berrors.ErrPDLeaderNotFound, + "region id %d has no leader", regionInfo.Region.Id) } reqCtx := &kvrpcpb.Context{ RegionId: regionInfo.Region.GetId(), diff --git a/br/pkg/restore/import_retry_test.go 
b/br/pkg/restore/import_retry_test.go new file mode 100644 index 0000000000000..8e2a386b0e5f5 --- /dev/null +++ b/br/pkg/restore/import_retry_test.go @@ -0,0 +1,638 @@ +// Copyright 2021 PingCAP, Inc. Licensed under Apache-2.0. + +package restore + +import ( + "context" + "encoding/hex" + "fmt" + "os" + "strconv" + "testing" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/kvproto/pkg/errorpb" + "github.com/pingcap/kvproto/pkg/import_sstpb" + "github.com/pingcap/kvproto/pkg/metapb" + berrors "github.com/pingcap/tidb/br/pkg/errors" + "github.com/pingcap/tidb/br/pkg/restore/split" + "github.com/pingcap/tidb/br/pkg/utils" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/store/pdtypes" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func assertDecode(t *testing.T, key []byte) []byte { + if len(key) == 0 { + return []byte{} + } + _, decoded, err := codec.DecodeBytes(key, nil) + require.NoError(t, err) + return decoded +} + +func assertRegions(t *testing.T, regions []*split.RegionInfo, keys ...string) { + require.Equal(t, len(regions)+1, len(keys), "%+v\nvs\n%+v", regions, keys) + last := keys[0] + for i, r := range regions { + start := assertDecode(t, r.Region.StartKey) + end := assertDecode(t, r.Region.EndKey) + + require.Equal(t, start, []byte(last), "not match for region: %+v", *r) + last = keys[i+1] + require.Equal(t, end, []byte(last), "not match for region: %+v", *r) + } +} + +// region: [, aay), [aay, bba), [bba, bbh), [bbh, cca), [cca, ) +func initTestClient(isRawKv bool) *TestClient { + peers := make([]*metapb.Peer, 1) + peers[0] = &metapb.Peer{ + Id: 1, + StoreId: 1, + } + keys := [6]string{"", "aay", "bba", "bbh", "cca", ""} + regions := make(map[uint64]*split.RegionInfo) + for i := uint64(1); i < 6; i++ { + startKey := []byte(keys[i-1]) + if 
len(startKey) != 0 { + startKey = codec.EncodeBytesExt([]byte{}, startKey, isRawKv) + } + endKey := []byte(keys[i]) + if len(endKey) != 0 { + endKey = codec.EncodeBytesExt([]byte{}, endKey, isRawKv) + } + regions[i] = &split.RegionInfo{ + Leader: &metapb.Peer{ + Id: i, + StoreId: 1, + }, + Region: &metapb.Region{ + Id: i, + Peers: peers, + StartKey: startKey, + EndKey: endKey, + }, + } + } + stores := make(map[uint64]*metapb.Store) + stores[1] = &metapb.Store{ + Id: 1, + } + return NewTestClient(stores, regions, 6) +} + +func TestScanSuccess(t *testing.T) { + // region: [, aay), [aay, bba), [bba, bbh), [bbh, cca), [cca, ) + cli := initTestClient(false) + rs := utils.InitialRetryState(1, 0, 0) + ctx := context.Background() + + // make exclusive to inclusive. + ctl := OverRegionsInRange([]byte("aa"), []byte("aay"), cli, &rs) + collectedRegions := []*split.RegionInfo{} + ctl.Run(ctx, func(ctx context.Context, r *split.RegionInfo) RPCResult { + collectedRegions = append(collectedRegions, r) + return RPCResultOK() + }) + assertRegions(t, collectedRegions, "", "aay", "bba") + + ctl = OverRegionsInRange([]byte("aaz"), []byte("bb"), cli, &rs) + collectedRegions = []*split.RegionInfo{} + ctl.Run(ctx, func(ctx context.Context, r *split.RegionInfo) RPCResult { + collectedRegions = append(collectedRegions, r) + return RPCResultOK() + }) + assertRegions(t, collectedRegions, "aay", "bba", "bbh", "cca") + + ctl = OverRegionsInRange([]byte("aa"), []byte("cc"), cli, &rs) + collectedRegions = []*split.RegionInfo{} + ctl.Run(ctx, func(ctx context.Context, r *split.RegionInfo) RPCResult { + collectedRegions = append(collectedRegions, r) + return RPCResultOK() + }) + assertRegions(t, collectedRegions, "", "aay", "bba", "bbh", "cca", "") + + ctl = OverRegionsInRange([]byte("aa"), []byte(""), cli, &rs) + collectedRegions = []*split.RegionInfo{} + ctl.Run(ctx, func(ctx context.Context, r *split.RegionInfo) RPCResult { + collectedRegions = append(collectedRegions, r) + return RPCResultOK() 
+ }) + assertRegions(t, collectedRegions, "", "aay", "bba", "bbh", "cca", "") +} + +func TestNotLeader(t *testing.T) { + // region: [, aay), [aay, bba), [bba, bbh), [bbh, cca), [cca, ) + cli := initTestClient(false) + rs := utils.InitialRetryState(1, 0, 0) + ctl := OverRegionsInRange([]byte(""), []byte(""), cli, &rs) + ctx := context.Background() + + notLeader := errorpb.Error{ + NotLeader: &errorpb.NotLeader{ + Leader: &metapb.Peer{ + Id: 42, + }, + }, + } + // record the regions we didn't touch. + meetRegions := []*split.RegionInfo{} + // record all regions we meet with id == 2. + idEqualsTo2Regions := []*split.RegionInfo{} + err := ctl.Run(ctx, func(ctx context.Context, r *split.RegionInfo) RPCResult { + if r.Region.Id == 2 { + idEqualsTo2Regions = append(idEqualsTo2Regions, r) + } + if r.Region.Id == 2 && (r.Leader == nil || r.Leader.Id != 42) { + return RPCResult{ + StoreError: ¬Leader, + } + } + meetRegions = append(meetRegions, r) + return RPCResultOK() + }) + + require.NoError(t, err) + require.Len(t, idEqualsTo2Regions, 2) + if idEqualsTo2Regions[1].Leader != nil { + require.NotEqual(t, 42, idEqualsTo2Regions[0].Leader.Id) + } + require.EqualValues(t, 42, idEqualsTo2Regions[1].Leader.Id) + assertRegions(t, meetRegions, "", "aay", "bba", "bbh", "cca", "") +} + +func TestServerIsBusy(t *testing.T) { + // region: [, aay), [aay, bba), [bba, bbh), [bbh, cca), [cca, ) + cli := initTestClient(false) + rs := utils.InitialRetryState(2, 0, 0) + ctl := OverRegionsInRange([]byte(""), []byte(""), cli, &rs) + ctx := context.Background() + + serverIsBusy := errorpb.Error{ + Message: "server is busy", + ServerIsBusy: &errorpb.ServerIsBusy{ + Reason: "memory is out", + }, + } + // record the regions we didn't touch. + meetRegions := []*split.RegionInfo{} + // record all regions we meet with id == 2. 
+ idEqualsTo2Regions := []*split.RegionInfo{} + theFirstRun := true + err := ctl.Run(ctx, func(ctx context.Context, r *split.RegionInfo) RPCResult { + if theFirstRun && r.Region.Id == 2 { + idEqualsTo2Regions = append(idEqualsTo2Regions, r) + theFirstRun = false + return RPCResult{ + StoreError: &serverIsBusy, + } + } + meetRegions = append(meetRegions, r) + return RPCResultOK() + }) + + require.NoError(t, err) + assertRegions(t, idEqualsTo2Regions, "aay", "bba") + assertRegions(t, meetRegions, "", "aay", "bba", "bbh", "cca", "") + require.Equal(t, rs.Attempt(), 1) +} + +func TestServerIsBusyWithMemoryIsLimited(t *testing.T) { + _ = failpoint.Enable("github.com/pingcap/tidb/br/pkg/restore/hint-memory-is-limited", "return(true)") + defer func() { + _ = failpoint.Disable("github.com/pingcap/tidb/br/pkg/restore/hint-memory-is-limited") + }() + + // region: [, aay), [aay, bba), [bba, bbh), [bbh, cca), [cca, ) + cli := initTestClient(false) + rs := utils.InitialRetryState(2, 0, 0) + ctl := OverRegionsInRange([]byte(""), []byte(""), cli, &rs) + ctx := context.Background() + + serverIsBusy := errorpb.Error{ + Message: "memory is limited", + ServerIsBusy: &errorpb.ServerIsBusy{ + Reason: "", + }, + } + // record the regions we didn't touch. + meetRegions := []*split.RegionInfo{} + // record all regions we meet with id == 2. 
+ idEqualsTo2Regions := []*split.RegionInfo{} + theFirstRun := true + err := ctl.Run(ctx, func(ctx context.Context, r *split.RegionInfo) RPCResult { + if theFirstRun && r.Region.Id == 2 { + idEqualsTo2Regions = append(idEqualsTo2Regions, r) + theFirstRun = false + return RPCResult{ + StoreError: &serverIsBusy, + } + } + meetRegions = append(meetRegions, r) + return RPCResultOK() + }) + + require.NoError(t, err) + assertRegions(t, idEqualsTo2Regions, "aay", "bba") + assertRegions(t, meetRegions, "", "aay", "bba", "bbh", "cca", "") + require.Equal(t, rs.Attempt(), 2) +} + +func printRegion(name string, infos []*split.RegionInfo) { + fmt.Printf(">>>>> %s <<<<<\n", name) + for _, info := range infos { + fmt.Printf("[%04d] %s ~ %s\n", info.Region.Id, hex.EncodeToString(info.Region.StartKey), hex.EncodeToString(info.Region.EndKey)) + } + fmt.Printf("<<<<< %s >>>>>\n", name) +} + +func printPDRegion(name string, infos []*pdtypes.Region) { + fmt.Printf(">>>>> %s <<<<<\n", name) + for _, info := range infos { + fmt.Printf("[%04d] %s ~ %s\n", info.Meta.Id, hex.EncodeToString(info.Meta.StartKey), hex.EncodeToString(info.Meta.EndKey)) + } + fmt.Printf("<<<<< %s >>>>>\n", name) +} + +func TestEpochNotMatch(t *testing.T) { + // region: [, aay), [aay, bba), [bba, bbh), [bbh, cca), [cca, ) + cli := initTestClient(false) + rs := utils.InitialRetryState(2, 0, 0) + ctl := OverRegionsInRange([]byte(""), []byte(""), cli, &rs) + ctx := context.Background() + + printPDRegion("cli", cli.regionsInfo.Regions) + regions, err := split.PaginateScanRegion(ctx, cli, []byte("aaz"), []byte("bbb"), 2) + require.NoError(t, err) + require.Len(t, regions, 2) + left, right := regions[0], regions[1] + info := split.RegionInfo{ + Region: &metapb.Region{ + StartKey: left.Region.StartKey, + EndKey: right.Region.EndKey, + Id: 42, + Peers: []*metapb.Peer{ + {Id: 43}, + }, + }, + Leader: &metapb.Peer{Id: 43, StoreId: 1}, + } + newRegion := pdtypes.NewRegionInfo(info.Region, info.Leader) + mergeRegion := 
func() { + cli.regionsInfo.SetRegion(newRegion) + cli.regions[42] = &info + } + epochNotMatch := &import_sstpb.Error{ + Message: "Epoch not match", + StoreError: &errorpb.Error{ + EpochNotMatch: &errorpb.EpochNotMatch{ + CurrentRegions: []*metapb.Region{info.Region}, + }, + }} + firstRunRegions := []*split.RegionInfo{} + secondRunRegions := []*split.RegionInfo{} + isSecondRun := false + err = ctl.Run(ctx, func(ctx context.Context, r *split.RegionInfo) RPCResult { + if !isSecondRun && r.Region.Id == left.Region.Id { + mergeRegion() + isSecondRun = true + return RPCResultFromPBError(epochNotMatch) + } + if isSecondRun { + secondRunRegions = append(secondRunRegions, r) + } else { + firstRunRegions = append(firstRunRegions, r) + } + return RPCResultOK() + }) + printRegion("first", firstRunRegions) + printRegion("second", secondRunRegions) + printPDRegion("cli", cli.regionsInfo.Regions) + assertRegions(t, firstRunRegions, "", "aay") + assertRegions(t, secondRunRegions, "", "aay", "bbh", "cca", "") + require.NoError(t, err) +} + +func TestRegionSplit(t *testing.T) { + // region: [, aay), [aay, bba), [bba, bbh), [bbh, cca), [cca, ) + cli := initTestClient(false) + rs := utils.InitialRetryState(2, 0, 0) + ctl := OverRegionsInRange([]byte(""), []byte(""), cli, &rs) + ctx := context.Background() + + printPDRegion("cli", cli.regionsInfo.Regions) + regions, err := split.PaginateScanRegion(ctx, cli, []byte("aaz"), []byte("aazz"), 1) + require.NoError(t, err) + require.Len(t, regions, 1) + target := regions[0] + + newRegions := []*split.RegionInfo{ + { + Region: &metapb.Region{ + Id: 42, + StartKey: target.Region.StartKey, + EndKey: codec.EncodeBytes(nil, []byte("aayy")), + }, + Leader: &metapb.Peer{ + Id: 43, + StoreId: 1, + }, + }, + { + Region: &metapb.Region{ + Id: 44, + StartKey: codec.EncodeBytes(nil, []byte("aayy")), + EndKey: target.Region.EndKey, + }, + Leader: &metapb.Peer{ + Id: 45, + StoreId: 1, + }, + }, + } + splitRegion := func() { + for _, r := range newRegions { 
+ newRegion := pdtypes.NewRegionInfo(r.Region, r.Leader) + cli.regionsInfo.SetRegion(newRegion) + cli.regions[r.Region.Id] = r + } + } + epochNotMatch := &import_sstpb.Error{ + Message: "Epoch not match", + StoreError: &errorpb.Error{ + EpochNotMatch: &errorpb.EpochNotMatch{ + CurrentRegions: []*metapb.Region{ + newRegions[0].Region, + newRegions[1].Region, + }, + }, + }} + firstRunRegions := []*split.RegionInfo{} + secondRunRegions := []*split.RegionInfo{} + isSecondRun := false + err = ctl.Run(ctx, func(ctx context.Context, r *split.RegionInfo) RPCResult { + if !isSecondRun && r.Region.Id == target.Region.Id { + splitRegion() + isSecondRun = true + return RPCResultFromPBError(epochNotMatch) + } + if isSecondRun { + secondRunRegions = append(secondRunRegions, r) + } else { + firstRunRegions = append(firstRunRegions, r) + } + return RPCResultOK() + }) + printRegion("first", firstRunRegions) + printRegion("second", secondRunRegions) + printPDRegion("cli", cli.regionsInfo.Regions) + assertRegions(t, firstRunRegions, "", "aay") + assertRegions(t, secondRunRegions, "", "aay", "aayy", "bba", "bbh", "cca", "") + require.NoError(t, err) +} + +func TestRetryBackoff(t *testing.T) { + // region: [, aay), [aay, bba), [bba, bbh), [bbh, cca), [cca, ) + cli := initTestClient(false) + rs := utils.InitialRetryState(2, time.Millisecond, 10*time.Millisecond) + ctl := OverRegionsInRange([]byte(""), []byte(""), cli, &rs) + ctx := context.Background() + + printPDRegion("cli", cli.regionsInfo.Regions) + regions, err := split.PaginateScanRegion(ctx, cli, []byte("aaz"), []byte("bbb"), 2) + require.NoError(t, err) + require.Len(t, regions, 2) + left := regions[0] + + epochNotLeader := &import_sstpb.Error{ + Message: "leader not found", + StoreError: &errorpb.Error{ + NotLeader: &errorpb.NotLeader{ + RegionId: 2, + }, + }} + isSecondRun := false + err = ctl.Run(ctx, func(ctx context.Context, r *split.RegionInfo) RPCResult { + if !isSecondRun && r.Region.Id == left.Region.Id { + isSecondRun 
= true + return RPCResultFromPBError(epochNotLeader) + } + return RPCResultOK() + }) + printPDRegion("cli", cli.regionsInfo.Regions) + require.Equal(t, 1, rs.Attempt()) + // we retried leader not found error. so the next backoff should be 2 * initical backoff. + require.Equal(t, 2*time.Millisecond, rs.ExponentialBackoff()) + require.NoError(t, err) +} + +func TestWrappedError(t *testing.T) { + result := RPCResultFromError(errors.Trace(status.Error(codes.Unavailable, "the server is slacking. ><=ยท>"))) + require.Equal(t, result.StrategyForRetry(), StrategyFromThisRegion) + result = RPCResultFromError(errors.Trace(status.Error(codes.Unknown, "the server said something hard to understand"))) + require.Equal(t, result.StrategyForRetry(), StrategyGiveUp) +} + +func envInt(name string, def int) int { + lit := os.Getenv(name) + r, err := strconv.Atoi(lit) + if err != nil { + return def + } + return r +} + +func TestPaginateScanLeader(t *testing.T) { + // region: [, aay), [aay, bba), [bba, bbh), [bbh, cca), [cca, ) + cli := initTestClient(false) + rs := utils.InitialRetryState(2, time.Millisecond, 10*time.Millisecond) + ctl := OverRegionsInRange([]byte("aa"), []byte("aaz"), cli, &rs) + ctx := context.Background() + + cli.InjectErr = true + cli.InjectTimes = int32(envInt("PAGINATE_SCAN_LEADER_FAILURE_COUNT", 2)) + collectedRegions := []*split.RegionInfo{} + ctl.Run(ctx, func(ctx context.Context, r *split.RegionInfo) RPCResult { + collectedRegions = append(collectedRegions, r) + return RPCResultOK() + }) + assertRegions(t, collectedRegions, "", "aay", "bba") +} + +func TestImportKVFiles(t *testing.T) { + var ( + importer = FileImporter{} + ctx = context.Background() + shiftStartTS uint64 = 100 + startTS uint64 = 200 + restoreTS uint64 = 300 + ) + + err := importer.ImportKVFiles( + ctx, + []*LogDataFileInfo{ + { + DataFileInfo: &backuppb.DataFileInfo{ + Path: "log3", + }, + }, + { + DataFileInfo: &backuppb.DataFileInfo{ + Path: "log1", + }, + }, + }, + nil, + shiftStartTS, + 
startTS, + restoreTS, + false, + ) + require.True(t, berrors.ErrInvalidArgument.Equal(err)) +} + +func TestFilterFilesByRegion(t *testing.T) { + files := []*LogDataFileInfo{ + { + DataFileInfo: &backuppb.DataFileInfo{ + Path: "log3", + }, + }, + { + DataFileInfo: &backuppb.DataFileInfo{ + Path: "log1", + }, + }, + } + ranges := []kv.KeyRange{ + { + StartKey: []byte("1111"), + EndKey: []byte("2222"), + }, { + StartKey: []byte("3333"), + EndKey: []byte("4444"), + }, + } + + testCases := []struct { + r split.RegionInfo + subfiles []*LogDataFileInfo + err error + }{ + { + r: split.RegionInfo{ + Region: &metapb.Region{ + StartKey: []byte("0000"), + EndKey: []byte("1110"), + }, + }, + subfiles: []*LogDataFileInfo{}, + err: nil, + }, + { + r: split.RegionInfo{ + Region: &metapb.Region{ + StartKey: []byte("0000"), + EndKey: []byte("1111"), + }, + }, + subfiles: []*LogDataFileInfo{ + files[0], + }, + err: nil, + }, + { + r: split.RegionInfo{ + Region: &metapb.Region{ + StartKey: []byte("0000"), + EndKey: []byte("2222"), + }, + }, + subfiles: []*LogDataFileInfo{ + files[0], + }, + err: nil, + }, + { + r: split.RegionInfo{ + Region: &metapb.Region{ + StartKey: []byte("2222"), + EndKey: []byte("3332"), + }, + }, + subfiles: []*LogDataFileInfo{ + files[0], + }, + err: nil, + }, + { + r: split.RegionInfo{ + Region: &metapb.Region{ + StartKey: []byte("2223"), + EndKey: []byte("3332"), + }, + }, + subfiles: []*LogDataFileInfo{}, + err: nil, + }, + { + r: split.RegionInfo{ + Region: &metapb.Region{ + StartKey: []byte("3332"), + EndKey: []byte("3333"), + }, + }, + subfiles: []*LogDataFileInfo{ + files[1], + }, + err: nil, + }, + { + r: split.RegionInfo{ + Region: &metapb.Region{ + StartKey: []byte("4444"), + EndKey: []byte("5555"), + }, + }, + subfiles: []*LogDataFileInfo{ + files[1], + }, + err: nil, + }, + { + r: split.RegionInfo{ + Region: &metapb.Region{ + StartKey: []byte("4444"), + EndKey: nil, + }, + }, + subfiles: []*LogDataFileInfo{ + files[1], + }, + err: nil, + }, + { + 
r: split.RegionInfo{ + Region: &metapb.Region{ + StartKey: []byte("0000"), + EndKey: nil, + }, + }, + subfiles: files, + err: nil, + }, + } + + for _, c := range testCases { + subfile, err := FilterFilesByRegion(files, ranges, &c.r) + require.Equal(t, err, c.err) + require.Equal(t, subfile, c.subfiles) + } +} diff --git a/br/pkg/restore/split/mock_pd_client.go b/br/pkg/restore/split/mock_pd_client.go new file mode 100644 index 0000000000000..4bd709260e90a --- /dev/null +++ b/br/pkg/restore/split/mock_pd_client.go @@ -0,0 +1,195 @@ +// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0. + +package split + +import ( + "context" + "sync" + + "github.com/pingcap/errors" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/pingcap/tidb/pkg/store/pdtypes" + "github.com/pingcap/tidb/pkg/util/codec" + pd "github.com/tikv/pd/client" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// MockPDClientForSplit is a mock PD client for testing split and scatter. +type MockPDClientForSplit struct { + pd.Client + + mu sync.Mutex + + Regions *pdtypes.RegionTree + lastRegionID uint64 + scanRegions struct { + errors []error + beforeHook func() + } + splitRegions struct { + count int + hijacked func() (bool, *kvrpcpb.SplitRegionResponse, error) + } + scatterRegion struct { + eachRegionFailBefore int + count map[uint64]int + } + scatterRegions struct { + notImplemented bool + regionCount int + } + getOperator struct { + responses map[uint64][]*pdpb.GetOperatorResponse + } +} + +// NewMockPDClientForSplit creates a new MockPDClientForSplit. 
+func NewMockPDClientForSplit() *MockPDClientForSplit { + ret := &MockPDClientForSplit{} + ret.Regions = &pdtypes.RegionTree{} + ret.scatterRegion.count = make(map[uint64]int) + return ret +} + +func newRegionNotFullyReplicatedErr(regionID uint64) error { + return status.Errorf(codes.Unknown, "region %d is not fully replicated", regionID) +} + +func (c *MockPDClientForSplit) SetRegions(boundaries [][]byte) []*metapb.Region { + c.mu.Lock() + defer c.mu.Unlock() + + return c.setRegions(boundaries) +} + +func (c *MockPDClientForSplit) setRegions(boundaries [][]byte) []*metapb.Region { + ret := make([]*metapb.Region, 0, len(boundaries)-1) + for i := 1; i < len(boundaries); i++ { + c.lastRegionID++ + r := &metapb.Region{ + Id: c.lastRegionID, + StartKey: boundaries[i-1], + EndKey: boundaries[i], + } + p := &metapb.Peer{ + Id: c.lastRegionID, + StoreId: 1, + } + c.Regions.SetRegion(&pdtypes.Region{ + Meta: r, + Leader: p, + }) + ret = append(ret, r) + } + return ret +} + +func (c *MockPDClientForSplit) ScanRegions( + _ context.Context, + key, endKey []byte, + limit int, + _ ...pd.GetRegionOption, +) ([]*pd.Region, error) { + c.mu.Lock() + defer c.mu.Unlock() + + if len(c.scanRegions.errors) > 0 { + err := c.scanRegions.errors[0] + c.scanRegions.errors = c.scanRegions.errors[1:] + return nil, err + } + + if c.scanRegions.beforeHook != nil { + c.scanRegions.beforeHook() + } + + regions := c.Regions.ScanRange(key, endKey, limit) + ret := make([]*pd.Region, 0, len(regions)) + for _, r := range regions { + ret = append(ret, &pd.Region{ + Meta: r.Meta, + Leader: r.Leader, + }) + } + return ret, nil +} + +func (c *MockPDClientForSplit) GetRegionByID(_ context.Context, regionID uint64, _ ...pd.GetRegionOption) (*pd.Region, error) { + c.mu.Lock() + defer c.mu.Unlock() + + for _, r := range c.Regions.Regions { + if r.Meta.Id == regionID { + return &pd.Region{ + Meta: r.Meta, + Leader: r.Leader, + }, nil + } + } + return nil, errors.New("region not found") +} + +func (c 
*MockPDClientForSplit) SplitRegion( + region *RegionInfo, + keys [][]byte, + isRawKV bool, +) (bool, *kvrpcpb.SplitRegionResponse, error) { + c.mu.Lock() + defer c.mu.Unlock() + + c.splitRegions.count++ + if c.splitRegions.hijacked != nil { + return c.splitRegions.hijacked() + } + + if !isRawKV { + for i := range keys { + keys[i] = codec.EncodeBytes(nil, keys[i]) + } + } + + newRegionBoundaries := make([][]byte, 0, len(keys)+2) + newRegionBoundaries = append(newRegionBoundaries, region.Region.StartKey) + newRegionBoundaries = append(newRegionBoundaries, keys...) + newRegionBoundaries = append(newRegionBoundaries, region.Region.EndKey) + newRegions := c.setRegions(newRegionBoundaries) + newRegions[0].Id = region.Region.Id + return false, &kvrpcpb.SplitRegionResponse{Regions: newRegions}, nil +} + +func (c *MockPDClientForSplit) ScatterRegion(_ context.Context, regionID uint64) error { + c.mu.Lock() + defer c.mu.Unlock() + + c.scatterRegion.count[regionID]++ + if c.scatterRegion.count[regionID] > c.scatterRegion.eachRegionFailBefore { + return nil + } + return newRegionNotFullyReplicatedErr(regionID) +} + +func (c *MockPDClientForSplit) ScatterRegions(_ context.Context, regionIDs []uint64, _ ...pd.RegionsOption) (*pdpb.ScatterRegionResponse, error) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.scatterRegions.notImplemented { + return nil, status.Error(codes.Unimplemented, "Ah, yep") + } + c.scatterRegions.regionCount += len(regionIDs) + return &pdpb.ScatterRegionResponse{}, nil +} + +func (c *MockPDClientForSplit) GetOperator(_ context.Context, regionID uint64) (*pdpb.GetOperatorResponse, error) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.getOperator.responses == nil { + return &pdpb.GetOperatorResponse{Desc: []byte("scatter-region"), Status: pdpb.OperatorStatus_SUCCESS}, nil + } + ret := c.getOperator.responses[regionID][0] + c.getOperator.responses[regionID] = c.getOperator.responses[regionID][1:] + return ret, nil +} diff --git a/br/pkg/restore/split/split.go 
b/br/pkg/restore/split/split.go new file mode 100644 index 0000000000000..c69e5959f9812 --- /dev/null +++ b/br/pkg/restore/split/split.go @@ -0,0 +1,350 @@ +// Copyright 2022 PingCAP, Inc. Licensed under Apache-2.0. + +package split + +import ( + "bytes" + "context" + "encoding/hex" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/log" + berrors "github.com/pingcap/tidb/br/pkg/errors" + "github.com/pingcap/tidb/br/pkg/logutil" + "github.com/pingcap/tidb/br/pkg/utils" + "github.com/pingcap/tidb/pkg/lightning/config" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/pingcap/tidb/pkg/util/redact" + "go.uber.org/zap" +) + +var ( + WaitRegionOnlineAttemptTimes = config.DefaultRegionCheckBackoffLimit + SplitRetryTimes = 150 +) + +// Constants for split retry machinery. +const ( + SplitRetryInterval = 50 * time.Millisecond + SplitMaxRetryInterval = 4 * time.Second + + // it takes 30 minutes to scatter regions when each TiKV has 400k regions + ScatterWaitUpperInterval = 30 * time.Minute + + ScanRegionPaginationLimit = 128 +) + +func checkRegionConsistency(startKey, endKey []byte, regions []*RegionInfo) error { + // current pd can't guarantee the consistency of returned regions + if len(regions) == 0 { + return errors.Annotatef(berrors.ErrPDBatchScanRegion, "scan region return empty result, startKey: %s, endKey: %s", + redact.Key(startKey), redact.Key(endKey)) + } + + if bytes.Compare(regions[0].Region.StartKey, startKey) > 0 { + return errors.Annotatef(berrors.ErrPDBatchScanRegion, + "first region %d's startKey(%s) > startKey(%s), region epoch: %s", + regions[0].Region.Id, + redact.Key(regions[0].Region.StartKey), redact.Key(startKey), + regions[0].Region.RegionEpoch.String()) + } else if len(regions[len(regions)-1].Region.EndKey) != 0 && + bytes.Compare(regions[len(regions)-1].Region.EndKey, endKey) < 0 { + return errors.Annotatef(berrors.ErrPDBatchScanRegion, + "last region %d's endKey(%s) < endKey(%s), region epoch: 
%s", + regions[len(regions)-1].Region.Id, + redact.Key(regions[len(regions)-1].Region.EndKey), redact.Key(endKey), + regions[len(regions)-1].Region.RegionEpoch.String()) + } + + cur := regions[0] + if cur.Leader == nil { + return errors.Annotatef(berrors.ErrPDBatchScanRegion, + "region %d's leader is nil", cur.Region.Id) + } + if cur.Leader.StoreId == 0 { + return errors.Annotatef(berrors.ErrPDBatchScanRegion, + "region %d's leader's store id is 0", cur.Region.Id) + } + for _, r := range regions[1:] { + if r.Leader == nil { + return errors.Annotatef(berrors.ErrPDBatchScanRegion, + "region %d's leader is nil", r.Region.Id) + } + if r.Leader.StoreId == 0 { + return errors.Annotatef(berrors.ErrPDBatchScanRegion, + "region %d's leader's store id is 0", r.Region.Id) + } + if !bytes.Equal(cur.Region.EndKey, r.Region.StartKey) { + return errors.Annotatef(berrors.ErrPDBatchScanRegion, + "region %d's endKey not equal to next region %d's startKey, endKey: %s, startKey: %s, region epoch: %s %s", + cur.Region.Id, r.Region.Id, + redact.Key(cur.Region.EndKey), redact.Key(r.Region.StartKey), + cur.Region.RegionEpoch.String(), r.Region.RegionEpoch.String()) + } + cur = r + } + + return nil +} + +// PaginateScanRegion scan regions with a limit pagination and return all regions +// at once. The returned regions are continuous and cover the key range. If not, +// or meet errors, it will retry internally. 
+func PaginateScanRegion( + ctx context.Context, client SplitClient, startKey, endKey []byte, limit int, +) ([]*RegionInfo, error) { + if len(endKey) != 0 && bytes.Compare(startKey, endKey) > 0 { + return nil, errors.Annotatef(berrors.ErrInvalidRange, "startKey > endKey, startKey: %s, endkey: %s", + hex.EncodeToString(startKey), hex.EncodeToString(endKey)) + } + + var ( + lastRegions []*RegionInfo + err error + backoffer = NewWaitRegionOnlineBackoffer() + ) + _ = utils.WithRetry(ctx, func() error { + regions := make([]*RegionInfo, 0, 16) + scanStartKey := startKey + for { + var batch []*RegionInfo + batch, err = client.ScanRegions(ctx, scanStartKey, endKey, limit) + if err != nil { + err = errors.Annotatef(berrors.ErrPDBatchScanRegion, "scan regions from start-key:%s, err: %s", + redact.Key(scanStartKey), err.Error()) + return err + } + regions = append(regions, batch...) + if len(batch) < limit { + // No more region + break + } + scanStartKey = batch[len(batch)-1].Region.GetEndKey() + if len(scanStartKey) == 0 || + (len(endKey) > 0 && bytes.Compare(scanStartKey, endKey) >= 0) { + // All key space have scanned + break + } + } + // if the number of regions changed, we can infer TiKV side really + // made some progress so don't increase the retry times. + if len(regions) != len(lastRegions) { + backoffer.Stat.ReduceRetry() + } + lastRegions = regions + + if err = checkRegionConsistency(startKey, endKey, regions); err != nil { + log.Warn("failed to scan region, retrying", + logutil.ShortError(err), + zap.Int("regionLength", len(regions))) + return err + } + return nil + }, backoffer) + + return lastRegions, err +} + +// CheckPartRegionConsistency only checks the continuity of regions and the first region consistency. 
+func CheckPartRegionConsistency(startKey, endKey []byte, regions []*RegionInfo) error { + // current pd can't guarantee the consistency of returned regions + if len(regions) == 0 { + return errors.Annotatef(berrors.ErrPDBatchScanRegion, + "scan region return empty result, startKey: %s, endKey: %s", + redact.Key(startKey), redact.Key(endKey)) + } + + if bytes.Compare(regions[0].Region.StartKey, startKey) > 0 { + return errors.Annotatef(berrors.ErrPDBatchScanRegion, + "first region's startKey > startKey, startKey: %s, regionStartKey: %s", + redact.Key(startKey), redact.Key(regions[0].Region.StartKey)) + } + + cur := regions[0] + for _, r := range regions[1:] { + if !bytes.Equal(cur.Region.EndKey, r.Region.StartKey) { + return errors.Annotatef(berrors.ErrPDBatchScanRegion, + "region endKey not equal to next region startKey, endKey: %s, startKey: %s", + redact.Key(cur.Region.EndKey), redact.Key(r.Region.StartKey)) + } + cur = r + } + + return nil +} + +func ScanRegionsWithRetry( + ctx context.Context, client SplitClient, startKey, endKey []byte, limit int, +) ([]*RegionInfo, error) { + if len(endKey) != 0 && bytes.Compare(startKey, endKey) > 0 { + return nil, errors.Annotatef(berrors.ErrInvalidRange, "startKey > endKey, startKey: %s, endkey: %s", + hex.EncodeToString(startKey), hex.EncodeToString(endKey)) + } + + var regions []*RegionInfo + var err error + // we don't need to return multierr. since there only 3 times retry. + // in most case 3 times retry have the same error. so we just return the last error. + // actually we'd better remove all multierr in br/lightning. + // because it's not easy to check multierr equals normal error. + // see https://github.com/pingcap/tidb/issues/33419. 
+ _ = utils.WithRetry(ctx, func() error { + regions, err = client.ScanRegions(ctx, startKey, endKey, limit) + if err != nil { + err = errors.Annotatef(berrors.ErrPDBatchScanRegion, "scan regions from start-key:%s, err: %s", + redact.Key(startKey), err.Error()) + return err + } + + if err = CheckPartRegionConsistency(startKey, endKey, regions); err != nil { + log.Warn("failed to scan region, retrying", logutil.ShortError(err)) + return err + } + + return nil + }, NewWaitRegionOnlineBackoffer()) + + return regions, err +} + +type WaitRegionOnlineBackoffer struct { + Stat utils.RetryState +} + +// NewWaitRegionOnlineBackoffer create a backoff to wait region online. +func NewWaitRegionOnlineBackoffer() *WaitRegionOnlineBackoffer { + return &WaitRegionOnlineBackoffer{ + Stat: utils.InitialRetryState( + WaitRegionOnlineAttemptTimes, + time.Millisecond*10, + time.Second*2, + ), + } +} + +// NextBackoff returns a duration to wait before retrying again +func (b *WaitRegionOnlineBackoffer) NextBackoff(err error) time.Duration { + // TODO(lance6716): why we only backoff when the error is ErrPDBatchScanRegion? + if berrors.ErrPDBatchScanRegion.Equal(err) { + // it needs more time to wait splitting the regions that contains data in PITR. + // 2s * 150 + delayTime := b.Stat.ExponentialBackoff() + failpoint.Inject("hint-scan-region-backoff", func(val failpoint.Value) { + if val.(bool) { + delayTime = time.Microsecond + } + }) + return delayTime + } + b.Stat.GiveUp() + return 0 +} + +// Attempt returns the remain attempt times +func (b *WaitRegionOnlineBackoffer) Attempt() int { + return b.Stat.Attempt() +} + +// BackoffMayNotCountBackoffer is a backoffer but it may not increase the retry +// counter. It should be used with ErrBackoff or ErrBackoffAndDontCount. 
+type BackoffMayNotCountBackoffer struct { + state utils.RetryState +} + +var ( + ErrBackoff = errors.New("found backoff error") + ErrBackoffAndDontCount = errors.New("found backoff error but don't count") +) + +// NewBackoffMayNotCountBackoffer creates a new backoffer that may backoff or retry. +// +// TODO: currently it has the same usage as NewWaitRegionOnlineBackoffer so we +// don't expose its inner settings. +func NewBackoffMayNotCountBackoffer() *BackoffMayNotCountBackoffer { + return &BackoffMayNotCountBackoffer{ + state: utils.InitialRetryState( + WaitRegionOnlineAttemptTimes, + time.Millisecond*10, + time.Second*2, + ), + } +} + +// NextBackoff implements utils.Backoffer. For BackoffMayNotCountBackoffer, only +// ErrBackoff and ErrBackoffAndDontCount is meaningful. +func (b *BackoffMayNotCountBackoffer) NextBackoff(err error) time.Duration { + if errors.ErrorEqual(err, ErrBackoff) { + return b.state.ExponentialBackoff() + } + if errors.ErrorEqual(err, ErrBackoffAndDontCount) { + delay := b.state.ExponentialBackoff() + b.state.ReduceRetry() + return delay + } + b.state.GiveUp() + return 0 +} + +// Attempt implements utils.Backoffer. +func (b *BackoffMayNotCountBackoffer) Attempt() int { + return b.state.Attempt() +} + +// getSplitKeysOfRegions checks every input key is necessary to split region on +// it. Returns a map from region to split keys belongs to it. +// +// The key will be skipped if it's the region boundary. +// +// prerequisite: +// - sortedKeys are sorted in ascending order. +// - sortedRegions are continuous and sorted in ascending order by start key. +// - sortedRegions can cover all keys in sortedKeys. +// PaginateScanRegion should satisfy the above prerequisites. 
+func getSplitKeysOfRegions( + sortedKeys [][]byte, + sortedRegions []*RegionInfo, + isRawKV bool, +) map[*RegionInfo][][]byte { + splitKeyMap := make(map[*RegionInfo][][]byte, len(sortedRegions)) + curKeyIndex := 0 + splitKey := codec.EncodeBytesExt(nil, sortedKeys[curKeyIndex], isRawKV) + + for _, region := range sortedRegions { + for { + if len(sortedKeys[curKeyIndex]) == 0 { + // should not happen? + goto nextKey + } + // If splitKey is the boundary of the region, don't need to split on it. + if bytes.Equal(splitKey, region.Region.GetStartKey()) { + goto nextKey + } + // If splitKey is not in this region, we should move to the next region. + if !region.ContainsInterior(splitKey) { + break + } + + splitKeyMap[region] = append(splitKeyMap[region], sortedKeys[curKeyIndex]) + + nextKey: + curKeyIndex++ + if curKeyIndex >= len(sortedKeys) { + return splitKeyMap + } + splitKey = codec.EncodeBytesExt(nil, sortedKeys[curKeyIndex], isRawKV) + } + } + lastKey := sortedKeys[len(sortedKeys)-1] + endOfLastRegion := sortedRegions[len(sortedRegions)-1].Region.GetEndKey() + if !bytes.Equal(lastKey, endOfLastRegion) { + log.Error("in getSplitKeysOfRegions, regions don't cover all keys", + zap.String("firstKey", hex.EncodeToString(sortedKeys[0])), + zap.String("lastKey", hex.EncodeToString(lastKey)), + zap.String("firstRegionStartKey", hex.EncodeToString(sortedRegions[0].Region.GetStartKey())), + zap.String("lastRegionEndKey", hex.EncodeToString(endOfLastRegion)), + ) + } + return splitKeyMap +} diff --git a/br/pkg/restore/split/split_test.go b/br/pkg/restore/split/split_test.go new file mode 100644 index 0000000000000..9ca523fe214f4 --- /dev/null +++ b/br/pkg/restore/split/split_test.go @@ -0,0 +1,691 @@ +// Copyright 2022 PingCAP, Inc. Licensed under Apache-2.0. 
+package split + +import ( + "bytes" + "context" + goerrors "errors" + "slices" + "testing" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/kvproto/pkg/pdpb" + berrors "github.com/pingcap/tidb/br/pkg/errors" + "github.com/pingcap/tidb/br/pkg/utils" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/store/pdtypes" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestScanRegionBackOfferWithSuccess(t *testing.T) { + var counter int + bo := NewWaitRegionOnlineBackoffer() + + err := utils.WithRetry(context.Background(), func() error { + defer func() { + counter++ + }() + + if counter == 3 { + return nil + } + return berrors.ErrPDBatchScanRegion + }, bo) + require.NoError(t, err) + require.Equal(t, counter, 4) +} + +func TestScanRegionBackOfferWithFail(t *testing.T) { + _ = failpoint.Enable("github.com/pingcap/tidb/br/pkg/restore/split/hint-scan-region-backoff", "return(true)") + defer func() { + _ = failpoint.Disable("github.com/pingcap/tidb/br/pkg/restore/split/hint-scan-region-backoff") + }() + + var counter int + bo := NewWaitRegionOnlineBackoffer() + + err := utils.WithRetry(context.Background(), func() error { + defer func() { + counter++ + }() + return berrors.ErrPDBatchScanRegion + }, bo) + require.Error(t, err) + require.Equal(t, counter, WaitRegionOnlineAttemptTimes) +} + +func TestScanRegionBackOfferWithStopRetry(t *testing.T) { + _ = failpoint.Enable("github.com/pingcap/tidb/br/pkg/restore/split/hint-scan-region-backoff", "return(true)") + defer func() { + _ = failpoint.Disable("github.com/pingcap/tidb/br/pkg/restore/split/hint-scan-region-backoff") + }() + + var counter int + bo := NewWaitRegionOnlineBackoffer() + + err := utils.WithRetry(context.Background(), 
func() error { + defer func() { + counter++ + }() + + if counter < 5 { + return berrors.ErrPDBatchScanRegion + } + return berrors.ErrKVUnknown + }, bo) + require.Error(t, err) + require.Equal(t, counter, 6) +} + +type recordCntBackoffer struct { + already int +} + +func (b *recordCntBackoffer) NextBackoff(error) time.Duration { + b.already++ + return 0 +} + +func (b *recordCntBackoffer) Attempt() int { + return 100 +} + +func TestScatterSequentiallyRetryCnt(t *testing.T) { + mockClient := NewMockPDClientForSplit() + mockClient.scatterRegion.eachRegionFailBefore = 7 + client := pdClient{ + needScatterVal: true, + client: mockClient, + } + client.needScatterInit.Do(func() {}) + + ctx := context.Background() + regions := []*RegionInfo{ + { + Region: &metapb.Region{ + Id: 1, + }, + }, + { + Region: &metapb.Region{ + Id: 2, + }, + }, + } + backoffer := &recordCntBackoffer{} + client.scatterRegionsSequentially( + ctx, + regions, + backoffer, + ) + require.Equal(t, 7, backoffer.already) +} + +func TestScatterBackwardCompatibility(t *testing.T) { + mockClient := NewMockPDClientForSplit() + mockClient.scatterRegions.notImplemented = true + client := pdClient{ + needScatterVal: true, + client: mockClient, + } + client.needScatterInit.Do(func() {}) + + ctx := context.Background() + regions := []*RegionInfo{ + { + Region: &metapb.Region{ + Id: 1, + }, + }, + { + Region: &metapb.Region{ + Id: 2, + }, + }, + } + err := client.scatterRegions(ctx, regions) + require.NoError(t, err) + require.Equal(t, map[uint64]int{1: 1, 2: 1}, client.client.(*MockPDClientForSplit).scatterRegion.count) +} + +func TestWaitForScatterRegions(t *testing.T) { + mockPDCli := NewMockPDClientForSplit() + mockPDCli.scatterRegions.notImplemented = true + client := pdClient{ + needScatterVal: true, + client: mockPDCli, + } + client.needScatterInit.Do(func() {}) + regionCnt := 6 + checkGetOperatorRespsDrained := func() { + for i := 1; i <= regionCnt; i++ { + require.Len(t, 
mockPDCli.getOperator.responses[uint64(i)], 0) + } + } + checkNoRetry := func() { + for i := 1; i <= regionCnt; i++ { + require.Equal(t, 0, mockPDCli.scatterRegion.count[uint64(i)]) + } + } + + ctx := context.Background() + regions := make([]*RegionInfo, 0, regionCnt) + for i := 1; i <= regionCnt; i++ { + regions = append(regions, &RegionInfo{ + Region: &metapb.Region{ + Id: uint64(i), + }, + }) + } + + mockPDCli.getOperator.responses = make(map[uint64][]*pdpb.GetOperatorResponse) + mockPDCli.getOperator.responses[1] = []*pdpb.GetOperatorResponse{ + {Header: &pdpb.ResponseHeader{Error: &pdpb.Error{Type: pdpb.ErrorType_REGION_NOT_FOUND}}}, + } + mockPDCli.getOperator.responses[2] = []*pdpb.GetOperatorResponse{ + {Desc: []byte("not-scatter-region")}, + } + mockPDCli.getOperator.responses[3] = []*pdpb.GetOperatorResponse{ + {Desc: []byte("scatter-region"), Status: pdpb.OperatorStatus_SUCCESS}, + } + mockPDCli.getOperator.responses[4] = []*pdpb.GetOperatorResponse{ + {Desc: []byte("scatter-region"), Status: pdpb.OperatorStatus_RUNNING}, + {Desc: []byte("scatter-region"), Status: pdpb.OperatorStatus_TIMEOUT}, + {Desc: []byte("scatter-region"), Status: pdpb.OperatorStatus_SUCCESS}, + } + mockPDCli.getOperator.responses[5] = []*pdpb.GetOperatorResponse{ + {Desc: []byte("scatter-region"), Status: pdpb.OperatorStatus_CANCEL}, + {Desc: []byte("scatter-region"), Status: pdpb.OperatorStatus_CANCEL}, + {Desc: []byte("scatter-region"), Status: pdpb.OperatorStatus_CANCEL}, + {Desc: []byte("scatter-region"), Status: pdpb.OperatorStatus_RUNNING}, + {Desc: []byte("scatter-region"), Status: pdpb.OperatorStatus_RUNNING}, + {Desc: []byte("not-scatter-region")}, + } + // should trigger a retry + mockPDCli.getOperator.responses[6] = []*pdpb.GetOperatorResponse{ + {Desc: []byte("scatter-region"), Status: pdpb.OperatorStatus_REPLACE}, + {Desc: []byte("scatter-region"), Status: pdpb.OperatorStatus_SUCCESS}, + } + + left, err := client.WaitRegionsScattered(ctx, regions) + require.NoError(t, 
err) + require.Equal(t, 0, left) + for i := 1; i <= 3; i++ { + require.Equal(t, 0, mockPDCli.scatterRegion.count[uint64(i)]) + } + // OperatorStatus_TIMEOUT should trigger rescatter once + require.Equal(t, 1, mockPDCli.scatterRegion.count[uint64(4)]) + // 3 * OperatorStatus_CANCEL should trigger 3 * rescatter + require.Equal(t, 3, mockPDCli.scatterRegion.count[uint64(5)]) + // OperatorStatus_REPLACE should trigger rescatter once + require.Equal(t, 1, mockPDCli.scatterRegion.count[uint64(6)]) + checkGetOperatorRespsDrained() + + // test non-retryable error + + mockPDCli.scatterRegion.count = make(map[uint64]int) + mockPDCli.getOperator.responses = make(map[uint64][]*pdpb.GetOperatorResponse) + mockPDCli.getOperator.responses[1] = []*pdpb.GetOperatorResponse{ + {Header: &pdpb.ResponseHeader{Error: &pdpb.Error{Type: pdpb.ErrorType_REGION_NOT_FOUND}}}, + } + mockPDCli.getOperator.responses[2] = []*pdpb.GetOperatorResponse{ + {Desc: []byte("not-scatter-region")}, + } + // mimic non-retryable error + mockPDCli.getOperator.responses[3] = []*pdpb.GetOperatorResponse{ + {Header: &pdpb.ResponseHeader{Error: &pdpb.Error{Type: pdpb.ErrorType_DATA_COMPACTED}}}, + } + left, err = client.WaitRegionsScattered(ctx, regions) + require.ErrorContains(t, err, "get operator error: DATA_COMPACTED") + require.Equal(t, 4, left) // region 3,4,5,6 is not scattered + checkGetOperatorRespsDrained() + checkNoRetry() + + // test backoff is timed-out + + backup := WaitRegionOnlineAttemptTimes + WaitRegionOnlineAttemptTimes = 2 + t.Cleanup(func() { + WaitRegionOnlineAttemptTimes = backup + }) + + mockPDCli.scatterRegion.count = make(map[uint64]int) + mockPDCli.getOperator.responses = make(map[uint64][]*pdpb.GetOperatorResponse) + mockPDCli.getOperator.responses[1] = []*pdpb.GetOperatorResponse{ + {Header: &pdpb.ResponseHeader{Error: &pdpb.Error{Type: pdpb.ErrorType_REGION_NOT_FOUND}}}, + } + mockPDCli.getOperator.responses[2] = []*pdpb.GetOperatorResponse{ + {Desc: []byte("not-scatter-region")}, + 
} + mockPDCli.getOperator.responses[3] = []*pdpb.GetOperatorResponse{ + {Desc: []byte("scatter-region"), Status: pdpb.OperatorStatus_SUCCESS}, + } + mockPDCli.getOperator.responses[4] = []*pdpb.GetOperatorResponse{ + {Desc: []byte("scatter-region"), Status: pdpb.OperatorStatus_RUNNING}, + {Desc: []byte("scatter-region"), Status: pdpb.OperatorStatus_RUNNING}, // first retry + {Desc: []byte("scatter-region"), Status: pdpb.OperatorStatus_RUNNING}, // second retry + } + mockPDCli.getOperator.responses[5] = []*pdpb.GetOperatorResponse{ + {Desc: []byte("not-scatter-region")}, + } + mockPDCli.getOperator.responses[6] = []*pdpb.GetOperatorResponse{ + {Desc: []byte("scatter-region"), Status: pdpb.OperatorStatus_SUCCESS}, + } + left, err = client.WaitRegionsScattered(ctx, regions) + require.ErrorContains(t, err, "the first unfinished region: id:4") + require.Equal(t, 1, left) + checkGetOperatorRespsDrained() + checkNoRetry() +} + +func TestBackoffMayNotCountBackoffer(t *testing.T) { + b := NewBackoffMayNotCountBackoffer() + initVal := b.Attempt() + + b.NextBackoff(ErrBackoffAndDontCount) + require.Equal(t, initVal, b.Attempt()) + // test Annotate, which is the real usage in caller + b.NextBackoff(errors.Annotate(ErrBackoffAndDontCount, "caller message")) + require.Equal(t, initVal, b.Attempt()) + + b.NextBackoff(ErrBackoff) + require.Equal(t, initVal-1, b.Attempt()) + + b.NextBackoff(goerrors.New("test")) + require.Equal(t, 0, b.Attempt()) +} + +func TestSplitCtxCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + mockCli := NewMockPDClientForSplit() + mockCli.splitRegions.hijacked = func() (bool, *kvrpcpb.SplitRegionResponse, error) { + cancel() + resp := &kvrpcpb.SplitRegionResponse{ + Regions: []*metapb.Region{ + {Id: 1}, + {Id: 2}, + }, + } + return false, resp, nil + } + client := pdClient{ + client: mockCli, + } + + _, err := client.SplitWaitAndScatter(ctx, &RegionInfo{}, [][]byte{{1}}) + require.ErrorIs(t, err, 
context.Canceled) +} + +func TestGetSplitKeyPerRegion(t *testing.T) { + // test case moved from BR + sortedKeys := [][]byte{ + []byte("b"), + []byte("d"), + []byte("g"), + []byte("j"), + []byte("l"), + } + sortedRegions := []*RegionInfo{ + { + Region: &metapb.Region{ + Id: 1, + StartKey: []byte("a"), + EndKey: []byte("g"), + }, + }, + { + Region: &metapb.Region{ + Id: 2, + StartKey: []byte("g"), + EndKey: []byte("k"), + }, + }, + { + Region: &metapb.Region{ + Id: 3, + StartKey: []byte("k"), + EndKey: []byte("m"), + }, + }, + } + result := getSplitKeysOfRegions(sortedKeys, sortedRegions, false) + require.Equal(t, 3, len(result)) + require.Equal(t, [][]byte{[]byte("b"), []byte("d")}, result[sortedRegions[0]]) + require.Equal(t, [][]byte{[]byte("g"), []byte("j")}, result[sortedRegions[1]]) + require.Equal(t, [][]byte{[]byte("l")}, result[sortedRegions[2]]) + + // test case moved from lightning + tableID := int64(1) + keys := []int64{1, 10, 100, 1000, 10000, -1} + sortedRegions = make([]*RegionInfo, 0, len(keys)) + start := tablecodec.EncodeRowKeyWithHandle(tableID, kv.IntHandle(0)) + regionStart := codec.EncodeBytes([]byte{}, start) + for i, end := range keys { + var regionEndKey []byte + if end >= 0 { + endKey := tablecodec.EncodeRowKeyWithHandle(tableID, kv.IntHandle(end)) + regionEndKey = codec.EncodeBytes([]byte{}, endKey) + } + region := &RegionInfo{ + Region: &metapb.Region{ + Id: uint64(i), + StartKey: regionStart, + EndKey: regionEndKey, + }, + } + sortedRegions = append(sortedRegions, region) + regionStart = regionEndKey + } + + checkKeys := map[int64]int{ + 0: -1, + 5: 1, + 6: 1, + 7: 1, + 50: 2, + 60: 2, + 70: 2, + 100: -1, + 50000: 5, + } + expected := map[uint64][][]byte{} + sortedKeys = make([][]byte, 0, len(checkKeys)) + + for hdl, idx := range checkKeys { + key := tablecodec.EncodeRowKeyWithHandle(tableID, kv.IntHandle(hdl)) + sortedKeys = append(sortedKeys, key) + if idx < 0 { + continue + } + expected[uint64(idx)] = append(expected[uint64(idx)], key) 
+ } + + slices.SortFunc(sortedKeys, bytes.Compare) + for i := range expected { + slices.SortFunc(expected[i], bytes.Compare) + } + + got := getSplitKeysOfRegions(sortedKeys, sortedRegions, false) + require.Equal(t, len(expected), len(got)) + for region, gotKeys := range got { + require.Equal(t, expected[region.Region.GetId()], gotKeys) + } +} + +func checkRegionsBoundaries(t *testing.T, regions []*RegionInfo, expected [][]byte) { + require.Len( + t, regions, len(expected)-1, + "first region start key: %v, last region end key: %v, first expected key: %v, last expected key: %v", + regions[0].Region.StartKey, regions[len(regions)-1].Region.EndKey, + expected[0], expected[len(expected)-1], + ) + for i := 1; i < len(expected); i++ { + require.Equal(t, expected[i-1], regions[i-1].Region.StartKey) + require.Equal(t, expected[i], regions[i-1].Region.EndKey) + } +} + +func TestPaginateScanRegion(t *testing.T) { + ctx := context.Background() + mockPDClient := NewMockPDClientForSplit() + mockClient := &pdClient{ + client: mockPDClient, + } + + backup := WaitRegionOnlineAttemptTimes + WaitRegionOnlineAttemptTimes = 3 + t.Cleanup(func() { + WaitRegionOnlineAttemptTimes = backup + }) + + // no region + _, err := PaginateScanRegion(ctx, mockClient, []byte{}, []byte{}, 3) + require.Error(t, err) + require.True(t, berrors.ErrPDBatchScanRegion.Equal(err)) + require.ErrorContains(t, err, "scan region return empty result") + + // retry on error + mockPDClient.scanRegions.errors = []error{ + status.Error(codes.Unavailable, "not leader"), + } + mockPDClient.SetRegions([][]byte{{}, {}}) + got, err := PaginateScanRegion(ctx, mockClient, []byte{}, []byte{}, 3) + require.NoError(t, err) + checkRegionsBoundaries(t, got, [][]byte{{}, {}}) + + // test paginate + boundaries := [][]byte{{}, {1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}, {}} + mockPDClient.SetRegions(boundaries) + got, err = PaginateScanRegion(ctx, mockClient, []byte{}, []byte{}, 3) + require.NoError(t, err) + checkRegionsBoundaries(t, 
got, boundaries) + got, err = PaginateScanRegion(ctx, mockClient, []byte{1}, []byte{}, 3) + require.NoError(t, err) + checkRegionsBoundaries(t, got, boundaries[1:]) + got, err = PaginateScanRegion(ctx, mockClient, []byte{}, []byte{2}, 8) + require.NoError(t, err) + checkRegionsBoundaries(t, got, boundaries[:3]) // [, 1), [1, 2) + got, err = PaginateScanRegion(ctx, mockClient, []byte{4}, []byte{5}, 1) + require.NoError(t, err) + checkRegionsBoundaries(t, got, [][]byte{{4}, {5}}) + + // test start == end + _, err = PaginateScanRegion(ctx, mockClient, []byte{4}, []byte{4}, 1) + require.ErrorContains(t, err, "scan region return empty result") + + // test start > end + _, err = PaginateScanRegion(ctx, mockClient, []byte{5}, []byte{4}, 5) + require.True(t, berrors.ErrInvalidRange.Equal(err)) + require.ErrorContains(t, err, "startKey > endKey") + + // test retry exhausted + mockPDClient.scanRegions.errors = []error{ + status.Error(codes.Unavailable, "not leader"), + status.Error(codes.Unavailable, "not leader"), + status.Error(codes.Unavailable, "not leader"), + } + _, err = PaginateScanRegion(ctx, mockClient, []byte{4}, []byte{5}, 1) + require.ErrorContains(t, err, "not leader") + + // test region not continuous + mockPDClient.Regions = &pdtypes.RegionTree{} + mockPDClient.Regions.SetRegion(&pdtypes.Region{ + Meta: &metapb.Region{ + Id: 1, + StartKey: []byte{1}, + EndKey: []byte{2}, + }, + Leader: &metapb.Peer{ + Id: 1, + StoreId: 1, + }, + }) + mockPDClient.Regions.SetRegion(&pdtypes.Region{ + Meta: &metapb.Region{ + Id: 4, + StartKey: []byte{4}, + EndKey: []byte{5}, + }, + Leader: &metapb.Peer{ + Id: 4, + StoreId: 1, + }, + }) + + _, err = PaginateScanRegion(ctx, mockClient, []byte{1}, []byte{5}, 3) + require.True(t, berrors.ErrPDBatchScanRegion.Equal(err)) + require.ErrorContains(t, err, "region 1's endKey not equal to next region 4's startKey") + + // test region becomes continuous slowly + toAdd := []*pdtypes.Region{ + { + Meta: &metapb.Region{ + Id: 2, + StartKey: 
[]byte{2}, + EndKey: []byte{3}, + }, + Leader: &metapb.Peer{ + Id: 2, + StoreId: 1, + }, + }, + { + Meta: &metapb.Region{ + Id: 3, + StartKey: []byte{3}, + EndKey: []byte{4}, + }, + Leader: &metapb.Peer{ + Id: 3, + StoreId: 1, + }, + }, + } + mockPDClient.scanRegions.beforeHook = func() { + mockPDClient.Regions.SetRegion(toAdd[0]) + toAdd = toAdd[1:] + } + got, err = PaginateScanRegion(ctx, mockClient, []byte{1}, []byte{5}, 100) + require.NoError(t, err) + checkRegionsBoundaries(t, got, [][]byte{{1}, {2}, {3}, {4}, {5}}) +} + +func TestRegionConsistency(t *testing.T) { + cases := []struct { + startKey []byte + endKey []byte + err string + regions []*RegionInfo + }{ + { + codec.EncodeBytes([]byte{}, []byte("a")), + codec.EncodeBytes([]byte{}, []byte("a")), + "scan region return empty result, startKey: (.*?), endKey: (.*?)", + []*RegionInfo{}, + }, + { + codec.EncodeBytes([]byte{}, []byte("a")), + codec.EncodeBytes([]byte{}, []byte("a")), + "first region 1's startKey(.*?) > startKey(.*?)", + []*RegionInfo{ + { + Region: &metapb.Region{ + Id: 1, + StartKey: codec.EncodeBytes([]byte{}, []byte("b")), + EndKey: codec.EncodeBytes([]byte{}, []byte("d")), + }, + }, + }, + }, + { + codec.EncodeBytes([]byte{}, []byte("b")), + codec.EncodeBytes([]byte{}, []byte("e")), + "last region 100's endKey(.*?) 
< endKey(.*?)", + []*RegionInfo{ + { + Region: &metapb.Region{ + Id: 100, + StartKey: codec.EncodeBytes([]byte{}, []byte("b")), + EndKey: codec.EncodeBytes([]byte{}, []byte("d")), + }, + }, + }, + }, + { + codec.EncodeBytes([]byte{}, []byte("c")), + codec.EncodeBytes([]byte{}, []byte("e")), + "region 6's endKey not equal to next region 8's startKey(.*?)", + []*RegionInfo{ + { + Leader: &metapb.Peer{ + Id: 6, + StoreId: 1, + }, + Region: &metapb.Region{ + Id: 6, + StartKey: codec.EncodeBytes([]byte{}, []byte("b")), + EndKey: codec.EncodeBytes([]byte{}, []byte("d")), + RegionEpoch: nil, + }, + }, + { + Leader: &metapb.Peer{ + Id: 8, + StoreId: 1, + }, + Region: &metapb.Region{ + Id: 8, + StartKey: codec.EncodeBytes([]byte{}, []byte("e")), + EndKey: codec.EncodeBytes([]byte{}, []byte("f")), + }, + }, + }, + }, + { + codec.EncodeBytes([]byte{}, []byte("c")), + codec.EncodeBytes([]byte{}, []byte("e")), + "region 6's leader is nil(.*?)", + []*RegionInfo{ + { + Region: &metapb.Region{ + Id: 6, + StartKey: codec.EncodeBytes([]byte{}, []byte("c")), + EndKey: codec.EncodeBytes([]byte{}, []byte("d")), + RegionEpoch: nil, + }, + }, + { + Region: &metapb.Region{ + Id: 8, + StartKey: codec.EncodeBytes([]byte{}, []byte("d")), + EndKey: codec.EncodeBytes([]byte{}, []byte("e")), + }, + }, + }, + }, + { + codec.EncodeBytes([]byte{}, []byte("c")), + codec.EncodeBytes([]byte{}, []byte("e")), + "region 6's leader's store id is 0(.*?)", + []*RegionInfo{ + { + Leader: &metapb.Peer{ + Id: 6, + StoreId: 0, + }, + Region: &metapb.Region{ + Id: 6, + StartKey: codec.EncodeBytes([]byte{}, []byte("c")), + EndKey: codec.EncodeBytes([]byte{}, []byte("d")), + RegionEpoch: nil, + }, + }, + { + Leader: &metapb.Peer{ + Id: 6, + StoreId: 0, + }, + Region: &metapb.Region{ + Id: 8, + StartKey: codec.EncodeBytes([]byte{}, []byte("d")), + EndKey: codec.EncodeBytes([]byte{}, []byte("e")), + }, + }, + }, + }, + } + for _, ca := range cases { + err := checkRegionConsistency(ca.startKey, ca.endKey, 
ca.regions) + require.Error(t, err) + require.Regexp(t, ca.err, err.Error()) + } +} diff --git a/pkg/lightning/backend/local/BUILD.bazel b/pkg/lightning/backend/local/BUILD.bazel new file mode 100644 index 0000000000000..c297e333d2d7d --- /dev/null +++ b/pkg/lightning/backend/local/BUILD.bazel @@ -0,0 +1,188 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "local", + srcs = [ + "checksum.go", + "compress.go", + "disk_quota.go", + "duplicate.go", + "engine.go", + "engine_mgr.go", + "iterator.go", + "local.go", + "local_freebsd.go", + "local_unix.go", + "local_unix_generic.go", + "local_windows.go", + "localhelper.go", + "region_job.go", + "tikv_mode.go", + ], + importpath = "github.com/pingcap/tidb/pkg/lightning/backend/local", + visibility = ["//visibility:public"], + deps = [ + "//br/pkg/checksum", + "//br/pkg/errors", + "//br/pkg/logutil", + "//br/pkg/membuf", + "//br/pkg/pdutil", + "//br/pkg/restore/split", + "//br/pkg/storage", + "//br/pkg/version", + "//pkg/distsql", + "//pkg/infoschema", + "//pkg/kv", + "//pkg/lightning/backend", + "//pkg/lightning/backend/encode", + "//pkg/lightning/backend/external", + "//pkg/lightning/backend/kv", + "//pkg/lightning/checkpoints", + "//pkg/lightning/common", + "//pkg/lightning/config", + "//pkg/lightning/errormanager", + "//pkg/lightning/log", + "//pkg/lightning/manual", + "//pkg/lightning/metric", + "//pkg/lightning/mydump", + "//pkg/lightning/tikv", + "//pkg/lightning/verification", + "//pkg/metrics", + "//pkg/parser/model", + "//pkg/parser/mysql", + "//pkg/parser/terror", + "//pkg/sessionctx/variable", + "//pkg/table", + "//pkg/table/tables", + "//pkg/tablecodec", + "//pkg/util", + "//pkg/util/codec", + "//pkg/util/compress", + "//pkg/util/engine", + "//pkg/util/hack", + "//pkg/util/logutil", + "//pkg/util/mathutil", + "//pkg/util/ranger", + "@com_github_cockroachdb_pebble//:pebble", + "@com_github_cockroachdb_pebble//objstorage/objstorageprovider", + 
"@com_github_cockroachdb_pebble//sstable", + "@com_github_cockroachdb_pebble//vfs", + "@com_github_coreos_go_semver//semver", + "@com_github_docker_go_units//:go-units", + "@com_github_google_btree//:btree", + "@com_github_google_uuid//:uuid", + "@com_github_klauspost_compress//gzip", + "@com_github_pingcap_errors//:errors", + "@com_github_pingcap_failpoint//:failpoint", + "@com_github_pingcap_kvproto//pkg/errorpb", + "@com_github_pingcap_kvproto//pkg/import_sstpb", + "@com_github_pingcap_kvproto//pkg/kvrpcpb", + "@com_github_pingcap_kvproto//pkg/metapb", + "@com_github_pingcap_tipb//go-tipb", + "@com_github_tikv_client_go_v2//kv", + "@com_github_tikv_client_go_v2//oracle", + "@com_github_tikv_client_go_v2//tikv", + "@com_github_tikv_client_go_v2//util", + "@com_github_tikv_pd_client//:client", + "@com_github_tikv_pd_client//http", + "@com_github_tikv_pd_client//retry", + "@org_golang_google_grpc//:grpc", + "@org_golang_google_grpc//backoff", + "@org_golang_google_grpc//codes", + "@org_golang_google_grpc//credentials", + "@org_golang_google_grpc//credentials/insecure", + "@org_golang_google_grpc//keepalive", + "@org_golang_google_grpc//status", + "@org_golang_x_sync//errgroup", + "@org_golang_x_time//rate", + "@org_uber_go_atomic//:atomic", + "@org_uber_go_multierr//:multierr", + "@org_uber_go_zap//:zap", + ], +) + +go_test( + name = "local_test", + timeout = "short", + srcs = [ + "checksum_test.go", + "compress_test.go", + "disk_quota_test.go", + "duplicate_test.go", + "engine_mgr_test.go", + "engine_test.go", + "iterator_test.go", + "local_check_test.go", + "local_test.go", + "localhelper_test.go", + "main_test.go", + "region_job_test.go", + ], + embed = [":local"], + flaky = True, + race = "on", + shard_count = 50, + deps = [ + "//br/pkg/membuf", + "//br/pkg/mock/mocklocal", + "//br/pkg/restore/split", + "//br/pkg/storage", + "//br/pkg/utils", + "//pkg/ddl", + "//pkg/errno", + "//pkg/keyspace", + "//pkg/kv", + "//pkg/lightning/backend", + 
"//pkg/lightning/backend/encode", + "//pkg/lightning/backend/external", + "//pkg/lightning/backend/kv", + "//pkg/lightning/checkpoints", + "//pkg/lightning/common", + "//pkg/lightning/config", + "//pkg/lightning/log", + "//pkg/lightning/mydump", + "//pkg/parser", + "//pkg/parser/ast", + "//pkg/parser/model", + "//pkg/parser/mysql", + "//pkg/sessionctx/stmtctx", + "//pkg/store/pdtypes", + "//pkg/table", + "//pkg/table/tables", + "//pkg/tablecodec", + "//pkg/testkit/testsetup", + "//pkg/types", + "//pkg/util", + "//pkg/util/codec", + "//pkg/util/engine", + "//pkg/util/hack", + "//pkg/util/mock", + "@com_github_cockroachdb_pebble//:pebble", + "@com_github_cockroachdb_pebble//objstorage/objstorageprovider", + "@com_github_cockroachdb_pebble//sstable", + "@com_github_cockroachdb_pebble//vfs", + "@com_github_coreos_go_semver//semver", + "@com_github_data_dog_go_sqlmock//:go-sqlmock", + "@com_github_docker_go_units//:go-units", + "@com_github_go_sql_driver_mysql//:mysql", + "@com_github_google_uuid//:uuid", + "@com_github_pingcap_errors//:errors", + "@com_github_pingcap_failpoint//:failpoint", + "@com_github_pingcap_kvproto//pkg/errorpb", + "@com_github_pingcap_kvproto//pkg/import_sstpb", + "@com_github_pingcap_kvproto//pkg/metapb", + "@com_github_pingcap_tipb//go-tipb", + "@com_github_stretchr_testify//require", + "@com_github_tikv_client_go_v2//oracle", + "@com_github_tikv_pd_client//:client", + "@com_github_tikv_pd_client//errs", + "@com_github_tikv_pd_client//http", + "@org_golang_google_grpc//:grpc", + "@org_golang_google_grpc//codes", + "@org_golang_google_grpc//encoding", + "@org_golang_google_grpc//status", + "@org_uber_go_atomic//:atomic", + "@org_uber_go_goleak//:goleak", + "@org_uber_go_mock//gomock", + ], +) diff --git a/pkg/lightning/backend/local/region_job.go b/pkg/lightning/backend/local/region_job.go new file mode 100644 index 0000000000000..7bc812e4b9bb6 --- /dev/null +++ b/pkg/lightning/backend/local/region_job.go @@ -0,0 +1,907 @@ +// Copyright 2023 
PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package local + +import ( + "container/heap" + "context" + "fmt" + "io" + "strings" + "sync" + "time" + + "github.com/google/uuid" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/errorpb" + sst "github.com/pingcap/kvproto/pkg/import_sstpb" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/kvproto/pkg/metapb" + berrors "github.com/pingcap/tidb/br/pkg/errors" + "github.com/pingcap/tidb/br/pkg/logutil" + "github.com/pingcap/tidb/br/pkg/restore/split" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/lightning/common" + "github.com/pingcap/tidb/pkg/lightning/config" + "github.com/pingcap/tidb/pkg/lightning/log" + "github.com/pingcap/tidb/pkg/lightning/metric" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/tikv/client-go/v2/util" + "go.uber.org/zap" + "google.golang.org/grpc" +) + +type jobStageTp string + +/* + + + v + +------+------+ + +->+regionScanned+<------+ + | +------+------+ | + | | | + | | | + | v | + | +--+--+ +-----+----+ + | |wrote+---->+needRescan| + | +--+--+ +-----+----+ + | | ^ + | | | + | v | + | +---+----+ | + +-----+ingested+---------+ + +---+----+ + | + v + +above diagram shows the state transition of a region job, here are some special +cases: + - regionScanned can directly jump to ingested if the keyRange has no data + - regionScanned can only transit to wrote. 
  TODO: check if it should be transited
    to needRescan
  - if a job only partially writes the data, after it becomes ingested, it will
    update its keyRange and transits to regionScanned to continue the remaining
    data
  - needRescan may output multiple regionScanned jobs when the old region is split
*/

// Stage names of a regionJob. See the state-transition diagram above.
const (
	regionScanned jobStageTp = "regionScanned"
	wrote         jobStageTp = "wrote"
	ingested      jobStageTp = "ingested"
	needRescan    jobStageTp = "needRescan"

	// suppose each KV is about 32 bytes, 16 * units.KiB / 32 = 512
	defaultKVBatchCount = 512
)

// String implements fmt.Stringer.
func (j jobStageTp) String() string {
	return string(j)
}

// regionJob is dedicated to import the data in [keyRange.start, keyRange.end)
// to a region. The keyRange may be changed when processing because of writing
// partial data to TiKV or region split.
type regionJob struct {
	// keyRange is the half-open key range this job is responsible for.
	keyRange common.Range
	// TODO: check the keyRange so that it's always included in region
	region *split.RegionInfo
	// stage should be updated only by convertStageTo
	stage jobStageTp
	// writeResult is available only in wrote and ingested stage
	writeResult *tikvWriteResult

	// ingestData supplies the KV pairs to write; it is reference-counted via
	// ref/done below.
	ingestData      common.IngestData
	regionSplitSize int64
	regionSplitKeys int64
	metrics         *metric.Common

	// retryCount and waitUntil drive the retry queue (regionJobRetryer).
	retryCount       int
	waitUntil        time.Time
	lastRetryableErr error

	// injected is used in test to set the behaviour
	injected []injectedBehaviour
}

// tikvWriteResult is the outcome of a successful write phase: the SST metas
// returned by the region leader plus the amount of data written.
type tikvWriteResult struct {
	sstMeta    []*sst.SSTMeta
	count      int64
	totalBytes int64
	// remainingStartKey is non-nil when the write stopped early (region size /
	// key-count threshold reached); it is the start of the unwritten remainder.
	remainingStartKey []byte
}

// injectedBehaviour describes one scripted write+ingest round for tests using
// the "fakeRegionJobs" failpoint.
type injectedBehaviour struct {
	write  injectedWriteBehaviour
	ingest injectedIngestBehaviour
}

// injectedWriteBehaviour is the scripted result of the write phase.
type injectedWriteBehaviour struct {
	result *tikvWriteResult
	err    error
}

// injectedIngestBehaviour is the scripted result of the ingest phase.
type injectedIngestBehaviour struct {
	nextStage jobStageTp
	err       error
}

// convertStageTo updates the job stage and performs the bookkeeping attached
// to each transition: entering regionScanned drops the stale write result,
// entering ingested finalizes the ingestData accounting and metrics, and
// entering needRescan invalidates the cached region info.
func (j *regionJob) convertStageTo(stage jobStageTp) {
	j.stage = stage
	switch stage {
	case regionScanned:
		j.writeResult = nil
	case ingested:
		// when writing is skipped because key range is empty
		if j.writeResult == nil {
			return
		}

		j.ingestData.Finish(j.writeResult.totalBytes, j.writeResult.count)
		if j.metrics != nil {
			j.metrics.BytesCounter.WithLabelValues(metric.StateImported).
				Add(float64(j.writeResult.totalBytes))
		}
	case needRescan:
		j.region = nil
	}
}

// ref means that the ingestData of job will be accessed soon.
func (j *regionJob) ref(wg *sync.WaitGroup) {
	if wg != nil {
		wg.Add(1)
	}
	if j.ingestData != nil {
		j.ingestData.IncRef()
	}
}

// done promises that the ingestData of job will not be accessed. Same amount of
// done should be called to release the ingestData.
func (j *regionJob) done(wg *sync.WaitGroup) {
	if j.ingestData != nil {
		j.ingestData.DecRef()
	}
	if wg != nil {
		wg.Done()
	}
}

// writeToTiKV writes the data to TiKV and mark this job as wrote stage.
// if any write logic has error, writeToTiKV will set job to a proper stage and return nil.
// if any underlying logic has error, writeToTiKV will return an error.
// we don't need to do cleanup for the pairs written to tikv if encounters an error,
// tikv will take the responsibility to do so.
// TODO: let client-go provide a high-level write interface.
func (local *Backend) writeToTiKV(ctx context.Context, j *regionJob) error {
	err := local.doWrite(ctx, j)
	if err == nil {
		return nil
	}
	if !common.IsRetryableError(err) {
		return err
	}
	// currently only one case will restart write: TiKV reports the request is
	// too new ("RequestTooNew"), in which case the same region can be retried
	// from regionScanned. All other retryable errors force a region rescan.
	if strings.Contains(err.Error(), "RequestTooNew") {
		j.convertStageTo(regionScanned)
		return err
	}
	j.convertStageTo(needRescan)
	return err
}

// doWrite streams the job's KV pairs to every peer of the target region and,
// on success, records the leader's SST metas in j.writeResult and moves the
// job to the wrote stage. It may stop early when the written size/count
// reaches the region-split thresholds, in which case
// writeResult.remainingStartKey marks the unwritten remainder.
func (local *Backend) doWrite(ctx context.Context, j *regionJob) error {
	if j.stage != regionScanned {
		return nil
	}

	// test hook: replay a scripted write result instead of talking to TiKV.
	failpoint.Inject("fakeRegionJobs", func() {
		front := j.injected[0]
		j.injected = j.injected[1:]
		j.writeResult = front.write.result
		err := front.write.err
		if err == nil {
			j.convertStageTo(wrote)
		}
		failpoint.Return(err)
	})

	var cancel context.CancelFunc
	// bound the whole write phase; a stuck stream surfaces as ErrWriteTooSlow.
	ctx, cancel = context.WithTimeoutCause(ctx, 15*time.Minute, common.ErrWriteTooSlow)
	defer cancel()

	apiVersion := local.tikvCodec.GetAPIVersion()
	clientFactory := local.importClientFactory
	kvBatchSize := local.KVWriteBatchSize
	bufferPool := local.engineMgr.getBufferPool()
	writeLimiter := local.writeLimiter

	begin := time.Now()
	region := j.region.Region

	firstKey, lastKey, err := j.ingestData.GetFirstAndLastKey(j.keyRange.Start, j.keyRange.End)
	if err != nil {
		return errors.Trace(err)
	}
	// no data in the key range: nothing to write or ingest for this job.
	if firstKey == nil {
		j.convertStageTo(ingested)
		log.FromContext(ctx).Debug("keys within region is empty, skip doIngest",
			logutil.Key("start", j.keyRange.Start),
			logutil.Key("regionStart", region.StartKey),
			logutil.Key("end", j.keyRange.End),
			logutil.Key("regionEnd", region.EndKey))
		return nil
	}

	firstKey = codec.EncodeBytes([]byte{}, firstKey)
	lastKey = codec.EncodeBytes([]byte{}, lastKey)

	u := uuid.New()
	meta := &sst.SSTMeta{
		Uuid:        u[:],
		RegionId:    region.GetId(),
		RegionEpoch: region.GetRegionEpoch(),
		Range: &sst.Range{
			Start: firstKey,
			End:   lastKey,
		},
		ApiVersion: apiVersion,
	}

	// test hook: skew the epoch to exercise EpochNotMatch handling. The epoch
	// is cloned first so the shared region info is not mutated.
	failpoint.Inject("changeEpochVersion", func(val failpoint.Value) {
		cloned := *meta.RegionEpoch
		meta.RegionEpoch = &cloned
		i := val.(int)
		if i >= 0 {
			meta.RegionEpoch.Version += uint64(i)
		} else {
			meta.RegionEpoch.ConfVer -= uint64(-i)
		}
	})

	annotateErr := func(in error, peer *metapb.Peer, msg string) error {
		// annotate the error with peer/store/region info to help debug.
		return errors.Annotatef(
			in,
			"peer %d, store %d, region %d, epoch %s, %s",
			peer.Id, peer.StoreId, region.Id, region.RegionEpoch.String(),
			msg,
		)
	}

	leaderID := j.region.Leader.GetId()
	clients := make([]sst.ImportSST_WriteClient, 0, len(region.GetPeers()))
	allPeers := make([]*metapb.Peer, 0, len(region.GetPeers()))
	req := &sst.WriteRequest{
		Chunk: &sst.WriteRequest_Meta{
			Meta: meta,
		},
		Context: &kvrpcpb.Context{
			ResourceControlContext: &kvrpcpb.ResourceControlContext{
				ResourceGroupName: local.ResourceGroupName,
			},
			RequestSource: util.BuildRequestSource(true, kv.InternalTxnLightning, local.TaskType),
		},
	}
	// open one write stream per peer and send the SST meta as the first
	// message on each stream.
	for _, peer := range region.GetPeers() {
		cli, err := clientFactory.Create(ctx, peer.StoreId)
		if err != nil {
			return annotateErr(err, peer, "when create client")
		}

		wstream, err := cli.Write(ctx)
		if err != nil {
			return annotateErr(err, peer, "when open write stream")
		}

		failpoint.Inject("mockWritePeerErr", func() {
			err = errors.Errorf("mock write peer error")
			failpoint.Return(annotateErr(err, peer, "when open write stream"))
		})

		// Bind uuid for this write request
		if err = wstream.Send(req); err != nil {
			return annotateErr(err, peer, "when send meta")
		}
		clients = append(clients, wstream)
		allPeers = append(allPeers, peer)
	}
	dataCommitTS := j.ingestData.GetTS()
	// reuse req for the data chunks; from here on Chunk carries KV batches.
	req.Chunk = &sst.WriteRequest_Batch{
		Batch: &sst.WriteBatch{
			CommitTs: dataCommitTS,
		},
	}

	pairs := make([]*sst.Pair, 0, defaultKVBatchCount)
	count := 0
	size := int64(0)
	totalSize := int64(0)
	totalCount := int64(0)
	// if region-split-size <= 96MiB, we bump the threshold a bit to avoid too many retry split
	// because the range-properties is not 100% accurate
	regionMaxSize := j.regionSplitSize
	if j.regionSplitSize <= int64(config.SplitRegionSize) {
		regionMaxSize = j.regionSplitSize * 4 / 3
	}

	// flushKVs encodes the current batch once and fans it out to all peer
	// streams, honoring the per-store write limiter.
	flushKVs := func() error {
		req.Chunk.(*sst.WriteRequest_Batch).Batch.Pairs = pairs[:count]
		preparedMsg := &grpc.PreparedMsg{}
		// by reading the source code, Encode need to find codec and compression from the stream
		// because all stream has the same codec and compression, we can use any one of them
		if err := preparedMsg.Encode(clients[0], req); err != nil {
			return err
		}

		for i := range clients {
			if err := writeLimiter.WaitN(ctx, allPeers[i].StoreId, int(size)); err != nil {
				return errors.Trace(err)
			}
			if err := clients[i].SendMsg(preparedMsg); err != nil {
				if err == io.EOF {
					// if it's EOF, need RecvMsg to get the error
					dummy := &sst.WriteResponse{}
					err = clients[i].RecvMsg(dummy)
				}
				return annotateErr(err, allPeers[i], "when send data")
			}
		}
		failpoint.Inject("afterFlushKVs", func() {
			log.FromContext(ctx).Info(fmt.Sprintf("afterFlushKVs count=%d,size=%d", count, size))
		})
		return nil
	}

	iter := j.ingestData.NewIter(ctx, j.keyRange.Start, j.keyRange.End, bufferPool)
	//nolint: errcheck
	defer iter.Close()

	var remainingStartKey []byte
	for iter.First(); iter.Valid(); iter.Next() {
		k, v := iter.Key(), iter.Value()
		kvSize := int64(len(k) + len(v))
		// here we reuse the `*sst.Pair`s to optimize object allocation
		if count < len(pairs) {
			pairs[count].Key = k
			pairs[count].Value = v
		} else {
			pair := &sst.Pair{
				Key:   k,
				Value: v,
			}
			pairs = append(pairs, pair)
		}
		count++
		totalCount++
		size += kvSize
		totalSize += kvSize

		if size >= kvBatchSize {
			if err := flushKVs(); err != nil {
				return errors.Trace(err)
			}
			count = 0
			size = 0
			iter.ReleaseBuf()
		}
		if totalSize >= regionMaxSize || totalCount >= j.regionSplitKeys {
			// we will shrink the key range of this job to real written range
			// (the extra iter.Next here peeks whether anything remains).
			if iter.Next() {
				remainingStartKey = append([]byte{}, iter.Key()...)
				log.FromContext(ctx).Info("write to tikv partial finish",
					zap.Int64("count", totalCount),
					zap.Int64("size", totalSize),
					logutil.Key("startKey", j.keyRange.Start),
					logutil.Key("endKey", j.keyRange.End),
					logutil.Key("remainStart", remainingStartKey),
					logutil.Region(region),
					logutil.Leader(j.region.Leader),
					zap.Uint64("commitTS", dataCommitTS))
			}
			break
		}
	}

	if iter.Error() != nil {
		return errors.Trace(iter.Error())
	}

	// flush the final partial batch, if any.
	if count > 0 {
		if err := flushKVs(); err != nil {
			return errors.Trace(err)
		}
		count = 0
		size = 0
		iter.ReleaseBuf()
	}

	// close all streams; only the leader's response carries the SST metas we
	// need for the later ingest phase.
	var leaderPeerMetas []*sst.SSTMeta
	for i, wStream := range clients {
		resp, closeErr := wStream.CloseAndRecv()
		if closeErr != nil {
			return annotateErr(closeErr, allPeers[i], "when close write stream")
		}
		if resp.Error != nil {
			return annotateErr(errors.New("resp error: "+resp.Error.Message), allPeers[i], "when close write stream")
		}
		if leaderID == region.Peers[i].GetId() {
			leaderPeerMetas = resp.Metas
			log.FromContext(ctx).Debug("get metas after write kv stream to tikv", zap.Reflect("metas", leaderPeerMetas))
		}
	}

	failpoint.Inject("NoLeader", func() {
		log.FromContext(ctx).Warn("enter failpoint NoLeader")
		leaderPeerMetas = nil
	})

	// if there is not leader currently, we don't forward the stage to wrote and let caller
	// handle the retry.
	if len(leaderPeerMetas) == 0 {
		log.FromContext(ctx).Warn("write to tikv no leader",
			logutil.Region(region), logutil.Leader(j.region.Leader),
			zap.Uint64("leader_id", leaderID), logutil.SSTMeta(meta),
			zap.Int64("kv_pairs", totalCount), zap.Int64("total_bytes", totalSize))
		return common.ErrNoLeader.GenWithStackByArgs(region.Id, leaderID)
	}

	takeTime := time.Since(begin)
	log.FromContext(ctx).Debug("write to kv", zap.Reflect("region", j.region), zap.Uint64("leader", leaderID),
		zap.Reflect("meta", meta), zap.Reflect("return metas", leaderPeerMetas),
		zap.Int64("kv_pairs", totalCount), zap.Int64("total_bytes", totalSize),
		zap.Stringer("takeTime", takeTime))
	if m, ok := metric.FromContext(ctx); ok {
		m.SSTSecondsHistogram.WithLabelValues(metric.SSTProcessWrite).Observe(takeTime.Seconds())
	}

	j.writeResult = &tikvWriteResult{
		sstMeta:           leaderPeerMetas,
		count:             totalCount,
		totalBytes:        totalSize,
		remainingStartKey: remainingStartKey,
	}
	j.convertStageTo(wrote)
	return nil
}

// ingest tries to finish the regionJob.
// if any ingest logic has error, ingest may retry sometimes to resolve it and finally
// set job to a proper stage with nil error returned.
// if any underlying logic has error, ingest will return an error to let caller
// handle it.
func (local *Backend) ingest(ctx context.Context, j *regionJob) (err error) {
	if j.stage != wrote {
		return nil
	}

	// test hook: replay a scripted ingest outcome instead of talking to TiKV.
	failpoint.Inject("fakeRegionJobs", func() {
		front := j.injected[0]
		j.injected = j.injected[1:]
		j.convertStageTo(front.ingest.nextStage)
		failpoint.Return(front.ingest.err)
	})

	// nothing was written (empty key range): the job is already complete.
	if len(j.writeResult.sstMeta) == 0 {
		j.convertStageTo(ingested)
		return nil
	}

	if m, ok := metric.FromContext(ctx); ok {
		begin := time.Now()
		defer func() {
			if err == nil {
				m.SSTSecondsHistogram.WithLabelValues(metric.SSTProcessIngest).Observe(time.Since(begin).Seconds())
			}
		}()
	}

	for retry := 0; retry < maxRetryTimes; retry++ {
		resp, err := local.doIngest(ctx, j)
		if err == nil && resp.GetError() == nil {
			j.convertStageTo(ingested)
			return nil
		}
		if err != nil {
			if common.IsContextCanceledError(err) {
				return err
			}
			// transport-level failure: retry the ingest as-is.
			log.FromContext(ctx).Warn("meet underlying error, will retry ingest",
				log.ShortError(err), logutil.SSTMetas(j.writeResult.sstMeta),
				logutil.Region(j.region.Region), logutil.Leader(j.region.Leader))
			continue
		}
		// TiKV returned an error in the response: let convertStageOnIngestError
		// decide whether to retry immediately or hand the job back to the queue.
		canContinue, err := j.convertStageOnIngestError(resp)
		if common.IsContextCanceledError(err) {
			return err
		}
		if !canContinue {
			log.FromContext(ctx).Warn("meet error and handle the job later",
				zap.Stringer("job stage", j.stage),
				logutil.ShortError(j.lastRetryableErr),
				j.region.ToZapFields(),
				logutil.Key("start", j.keyRange.Start),
				logutil.Key("end", j.keyRange.End))
			return nil
		}
		log.FromContext(ctx).Warn("meet error and will doIngest region again",
			logutil.ShortError(j.lastRetryableErr),
			j.region.ToZapFields(),
			logutil.Key("start", j.keyRange.Start),
			logutil.Key("end", j.keyRange.End))
	}
	// retries exhausted: leave the job in its current stage for the caller.
	return nil
}

// checkWriteStall probes every peer store of the region with an empty
// MultiIngest request and reports whether any store answers ServerIsBusy.
func (local *Backend) checkWriteStall(
	ctx context.Context,
	region *split.RegionInfo,
) (bool, *sst.IngestResponse, error) {
	clientFactory := local.importClientFactory
	for _, peer := range region.Region.GetPeers() {
		cli, err := clientFactory.Create(ctx, peer.StoreId)
		if err != nil {
			return false, nil, errors.Trace(err)
		}
		// currently we use empty MultiIngestRequest to check if TiKV is busy.
		// If in future the rate limit feature contains more metrics we can switch to use it.
		resp, err := cli.MultiIngest(ctx, &sst.MultiIngestRequest{})
		if err != nil {
			return false, nil, errors.Trace(err)
		}
		if resp.Error != nil && resp.Error.ServerIsBusy != nil {
			return true, resp, nil
		}
	}
	return false, nil, nil
}

// doIngest send ingest commands to TiKV based on regionJob.writeResult.sstMeta.
// When meet error, it will remove finished sstMetas before return.
func (local *Backend) doIngest(ctx context.Context, j *regionJob) (*sst.IngestResponse, error) {
	clientFactory := local.importClientFactory
	supportMultiIngest := local.supportMultiIngest
	shouldCheckWriteStall := local.ShouldCheckWriteStall
	if shouldCheckWriteStall {
		writeStall, resp, err := local.checkWriteStall(ctx, j.region)
		if err != nil {
			return nil, errors.Trace(err)
		}
		if writeStall {
			// surface the busy response to the caller instead of ingesting.
			return resp, nil
		}
	}

	// without MultiIngest support, ingest one SST per request.
	batch := 1
	if supportMultiIngest {
		batch = len(j.writeResult.sstMeta)
	}

	var resp *sst.IngestResponse
	for start := 0; start < len(j.writeResult.sstMeta); start += batch {
		end := min(start+batch, len(j.writeResult.sstMeta))
		ingestMetas := j.writeResult.sstMeta[start:end]

		log.FromContext(ctx).Debug("ingest meta", zap.Reflect("meta", ingestMetas))

		// test hook: fabricate NotLeader / EpochNotMatch responses.
		failpoint.Inject("FailIngestMeta", func(val failpoint.Value) {
			// only inject the error once
			var resp *sst.IngestResponse

			switch val.(string) {
			case "notleader":
				resp = &sst.IngestResponse{
					Error: &errorpb.Error{
						NotLeader: &errorpb.NotLeader{
							RegionId: j.region.Region.Id,
							Leader:   j.region.Leader,
						},
					},
				}
			case "epochnotmatch":
				resp = &sst.IngestResponse{
					Error: &errorpb.Error{
						EpochNotMatch: &errorpb.EpochNotMatch{
							CurrentRegions: []*metapb.Region{j.region.Region},
						},
					},
				}
			}
			failpoint.Return(resp, nil)
		})

		// ingest must go to the leader; without one, fail fast and let the
		// caller rescan the region instead of guessing a peer.
		leader := j.region.Leader
		if leader == nil {
			return nil, errors.Annotatef(berrors.ErrPDLeaderNotFound,
				"region id %d has no leader", j.region.Region.Id)
		}

		cli, err := clientFactory.Create(ctx, leader.StoreId)
		if err != nil {
			return nil, errors.Trace(err)
		}
		reqCtx := &kvrpcpb.Context{
			RegionId:    j.region.Region.GetId(),
			RegionEpoch: j.region.Region.GetRegionEpoch(),
			Peer:        leader,
			ResourceControlContext: &kvrpcpb.ResourceControlContext{
				ResourceGroupName: local.ResourceGroupName,
			},
			RequestSource: util.BuildRequestSource(true, kv.InternalTxnLightning, local.TaskType),
		}

		if supportMultiIngest {
			req := &sst.MultiIngestRequest{
				Context: reqCtx,
				Ssts:    ingestMetas,
			}
			resp, err = cli.MultiIngest(ctx, req)
		} else {
			req := &sst.IngestRequest{
				Context: reqCtx,
				Sst:     ingestMetas[0],
			}
			resp, err = cli.Ingest(ctx, req)
		}
		if resp.GetError() != nil || err != nil {
			// remove finished sstMetas so a retry only re-ingests the rest.
			j.writeResult.sstMeta = j.writeResult.sstMeta[start:]
			return resp, errors.Trace(err)
		}
	}
	return resp, nil
}

// convertStageOnIngestError will try to fix the error contained in ingest response.
// Return (_, error) when another error occurred.
// Return (true, nil) when the job can retry ingesting immediately.
// Return (false, nil) when the job should be put back to queue.
func (j *regionJob) convertStageOnIngestError(
	resp *sst.IngestResponse,
) (bool, error) {
	if resp.GetError() == nil {
		return true, nil
	}

	var newRegion *split.RegionInfo
	switch errPb := resp.GetError(); {
	case errPb.NotLeader != nil:
		j.lastRetryableErr = common.ErrKVNotLeader.GenWithStack(errPb.GetMessage())

		// meet a problem that the region leader+peer are all updated but the return
		// error is only "NotLeader", we should update the whole region info.
		j.convertStageTo(needRescan)
		return false, nil
	case errPb.EpochNotMatch != nil:
		j.lastRetryableErr = common.ErrKVEpochNotMatch.GenWithStack(errPb.GetMessage())

		// if TiKV reports the current regions, try to find the one that still
		// fully contains our SSTs and whose leader is on the same store; then
		// the job can restart from regionScanned without a full rescan.
		if currentRegions := errPb.GetEpochNotMatch().GetCurrentRegions(); currentRegions != nil {
			var currentRegion *metapb.Region
			for _, r := range currentRegions {
				if insideRegion(r, j.writeResult.sstMeta) {
					currentRegion = r
					break
				}
			}
			if currentRegion != nil {
				var newLeader *metapb.Peer
				for _, p := range currentRegion.Peers {
					if p.GetStoreId() == j.region.Leader.GetStoreId() {
						newLeader = p
						break
					}
				}
				if newLeader != nil {
					newRegion = &split.RegionInfo{
						Leader: newLeader,
						Region: currentRegion,
					}
				}
			}
		}
		if newRegion != nil {
			j.region = newRegion
			j.convertStageTo(regionScanned)
			return false, nil
		}
		j.convertStageTo(needRescan)
		return false, nil
	case strings.Contains(errPb.Message, "raft: proposal dropped"):
		j.lastRetryableErr = common.ErrKVRaftProposalDropped.GenWithStack(errPb.GetMessage())

		j.convertStageTo(needRescan)
		return false, nil
	case errPb.ServerIsBusy != nil:
		j.lastRetryableErr = common.ErrKVServerIsBusy.GenWithStack(errPb.GetMessage())

		// keep the stage: the same write result can be ingested again later.
		return false, nil
	case errPb.RegionNotFound != nil:
		j.lastRetryableErr = common.ErrKVRegionNotFound.GenWithStack(errPb.GetMessage())

		j.convertStageTo(needRescan)
		return false, nil
	case errPb.ReadIndexNotReady != nil:
		j.lastRetryableErr = common.ErrKVReadIndexNotReady.GenWithStack(errPb.GetMessage())

		// this error happens when this region is splitting, the error might be:
		//   read index not ready, reason can not read index due to split, region 64037
		// we have paused schedule, but it's temporary,
		// if next request takes a long time, there's chance schedule is enabled again
		// or on key range border, another engine sharing this region tries to split this
		// region may cause this error too.
		j.convertStageTo(needRescan)
		return false, nil
	case errPb.DiskFull != nil:
		j.lastRetryableErr = common.ErrKVIngestFailed.GenWithStack(errPb.GetMessage())

		return false, errors.Errorf("non-retryable error: %s", resp.GetError().GetMessage())
	}
	// all others doIngest error, such as stale command, etc. we'll retry it again from writeAndIngestByRange
	j.lastRetryableErr = common.ErrKVIngestFailed.GenWithStack(resp.GetError().GetMessage())
	j.convertStageTo(regionScanned)
	return false, nil
}

// regionJobRetryHeap is a min-heap of jobs ordered by waitUntil, used by
// regionJobRetryer to pop the job whose retry time arrives first.
type regionJobRetryHeap []*regionJob

var _ heap.Interface = (*regionJobRetryHeap)(nil)

// Len implements heap.Interface.
func (h *regionJobRetryHeap) Len() int {
	return len(*h)
}

// Less implements heap.Interface: earlier waitUntil sorts first.
func (h *regionJobRetryHeap) Less(i, j int) bool {
	v := *h
	return v[i].waitUntil.Before(v[j].waitUntil)
}

// Swap implements heap.Interface.
func (h *regionJobRetryHeap) Swap(i, j int) {
	v := *h
	v[i], v[j] = v[j], v[i]
}

// Push implements heap.Interface.
func (h *regionJobRetryHeap) Push(x any) {
	*h = append(*h, x.(*regionJob))
}

// Pop implements heap.Interface.
func (h *regionJobRetryHeap) Pop() any {
	old := *h
	n := len(old)
	x := old[n-1]
	*h = old[0 : n-1]
	return x
}

// regionJobRetryer is a concurrent-safe queue holding jobs that need to put
// back later, and put back when the regionJob.waitUntil is reached. It maintains
// a heap of jobs internally based on the regionJob.waitUntil field.
type regionJobRetryer struct {
	// lock acquiring order: protectedClosed > protectedQueue > protectedToPutBack
	protectedClosed struct {
		mu     sync.Mutex
		closed bool
	}
	protectedQueue struct {
		mu sync.Mutex
		q  regionJobRetryHeap
	}
	// protectedToPutBack holds the single job popped from the heap but not yet
	// delivered to putBackCh, so close can release it if delivery never happens.
	protectedToPutBack struct {
		mu        sync.Mutex
		toPutBack *regionJob
	}
	putBackCh chan<- *regionJob
	reload    chan struct{}
	jobWg     *sync.WaitGroup
}

// startRegionJobRetryer starts a new regionJobRetryer and it will run in
// background to put the job back to `putBackCh` when job's waitUntil is reached.
// Cancelling the `ctx` will stop the retryer, and `jobWg.Done` will be triggered
// for jobs that are not put back yet.
func startRegionJobRetryer(
	ctx context.Context,
	putBackCh chan<- *regionJob,
	jobWg *sync.WaitGroup,
) *regionJobRetryer {
	ret := &regionJobRetryer{
		putBackCh: putBackCh,
		reload:    make(chan struct{}, 1),
		jobWg:     jobWg,
	}
	ret.protectedQueue.q = make(regionJobRetryHeap, 0, 16)
	go ret.run(ctx)
	return ret
}

// run is only internally used, caller should not use it.
// It waits until the earliest waitUntil in the heap, pops that job and sends
// it to putBackCh; the reload channel wakes it whenever push may have changed
// the heap front.
func (q *regionJobRetryer) run(ctx context.Context) {
	defer q.close()

	for {
		var front *regionJob
		q.protectedQueue.mu.Lock()
		if len(q.protectedQueue.q) > 0 {
			front = q.protectedQueue.q[0]
		}
		q.protectedQueue.mu.Unlock()

		switch {
		case front != nil:
			select {
			case <-ctx.Done():
				return
			case <-q.reload:
				// heap front may have changed; recompute the wait.
			case <-time.After(time.Until(front.waitUntil)):
				q.protectedQueue.mu.Lock()
				q.protectedToPutBack.mu.Lock()
				q.protectedToPutBack.toPutBack = heap.Pop(&q.protectedQueue.q).(*regionJob)
				// release the lock of queue to avoid blocking regionJobRetryer.push
				q.protectedQueue.mu.Unlock()

				// hold the lock of toPutBack to make sending to putBackCh and
				// resetting toPutBack atomic w.r.t. regionJobRetryer.close
				select {
				case <-ctx.Done():
					q.protectedToPutBack.mu.Unlock()
					return
				case q.putBackCh <- q.protectedToPutBack.toPutBack:
					q.protectedToPutBack.toPutBack = nil
					q.protectedToPutBack.mu.Unlock()
				}
			}
		default:
			// len(q.q) == 0: sleep until a push arrives or the context ends.
			select {
			case <-ctx.Done():
				return
			case <-q.reload:
			}
		}
	}
}

// close is only internally used, caller should not use it.
// It marks the retryer closed and releases (done) every job still held, so
// their WaitGroup counts are balanced even though they were never put back.
func (q *regionJobRetryer) close() {
	q.protectedClosed.mu.Lock()
	defer q.protectedClosed.mu.Unlock()
	q.protectedClosed.closed = true

	if q.protectedToPutBack.toPutBack != nil {
		q.protectedToPutBack.toPutBack.done(q.jobWg)
	}
	for _, job := range q.protectedQueue.q {
		job.done(q.jobWg)
	}
}

// push should not be blocked for long time in any cases.
// It returns false when the retryer is already closed and the job was not
// enqueued; the caller then still owns the job.
func (q *regionJobRetryer) push(job *regionJob) bool {
	q.protectedClosed.mu.Lock()
	defer q.protectedClosed.mu.Unlock()
	if q.protectedClosed.closed {
		return false
	}

	q.protectedQueue.mu.Lock()
	heap.Push(&q.protectedQueue.q, job)
	q.protectedQueue.mu.Unlock()

	// non-blocking nudge so run re-reads the heap front.
	select {
	case q.reload <- struct{}{}:
	default:
	}
	return true
}