diff --git a/doc/command-line-flags.md b/doc/command-line-flags.md index ac491b268..a54cb62e0 100644 --- a/doc/command-line-flags.md +++ b/doc/command-line-flags.md @@ -64,6 +64,9 @@ It is not reliable to parse the `ALTER` statement to determine if it is instant ### binlogsyncer-max-reconnect-attempts `--binlogsyncer-max-reconnect-attempts=0`, the maximum number of attempts to re-establish a broken inspector connection for sync binlog. `0` or `negative number` means infinite retry, default `0` +### chunk-concurrent-size +`--chunk-concurrent-size=1`, The number of goroutines to execute chunks concurrently in each copy time slot, default `1`, allowed range `0`-`100`. + ### conf `--conf=/path/to/my.cnf`: file where credentials are specified. Should be in (or contain) the following format: diff --git a/go.mod b/go.mod index 9be1453fc..4938b0ba5 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( golang.org/x/net v0.17.0 golang.org/x/term v0.13.0 golang.org/x/text v0.13.0 + golang.org/x/sync v0.1.0 ) require ( diff --git a/go.sum b/go.sum index 1a540d1cc..d773989c8 100644 --- a/go.sum +++ b/go.sum @@ -94,6 +94,8 @@ golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/go/base/context.go b/go/base/context.go index 300ec1201..a18194678 100644 --- a/go/base/context.go +++ b/go/base/context.go @@ -72,6 +72,14 @@ func NewThrottleCheckResult(throttle bool, reason string, reasonHint ThrottleRea } } +type IterationRangeValues struct { + Min *sql.ColumnValues + Max *sql.ColumnValues + Size int64 + IsIncludeMinValues bool + HasFurtherRange bool +} + // MigrationContext has the general, global state of migration. It is used by // all components throughout the migration process. type MigrationContext struct { @@ -119,6 +127,7 @@ type MigrationContext struct { HeartbeatIntervalMilliseconds int64 defaultNumRetries int64 ChunkSize int64 + ChunkConcurrentSize int64 niceRatio float64 MaxLagMillisecondsThrottleThreshold int64 throttleControlReplicaKeys *mysql.InstanceKeyMap @@ -210,25 +219,26 @@ type MigrationContext struct { InCutOverCriticalSectionFlag int64 PanicAbort chan error - OriginalTableColumnsOnApplier *sql.ColumnList - OriginalTableColumns *sql.ColumnList - OriginalTableVirtualColumns *sql.ColumnList - OriginalTableUniqueKeys [](*sql.UniqueKey) - OriginalTableAutoIncrement uint64 - GhostTableColumns *sql.ColumnList - GhostTableVirtualColumns *sql.ColumnList - GhostTableUniqueKeys [](*sql.UniqueKey) - UniqueKey *sql.UniqueKey - SharedColumns *sql.ColumnList - ColumnRenameMap map[string]string - DroppedColumnsMap map[string]bool - MappedSharedColumns *sql.ColumnList - MigrationRangeMinValues *sql.ColumnValues - MigrationRangeMaxValues *sql.ColumnValues - Iteration int64 - MigrationIterationRangeMinValues *sql.ColumnValues - MigrationIterationRangeMaxValues *sql.ColumnValues - ForceTmpTableName string + OriginalTableColumnsOnApplier *sql.ColumnList + OriginalTableColumns *sql.ColumnList + OriginalTableVirtualColumns *sql.ColumnList + OriginalTableUniqueKeys [](*sql.UniqueKey) + OriginalTableAutoIncrement uint64 + GhostTableColumns *sql.ColumnList + GhostTableVirtualColumns *sql.ColumnList + GhostTableUniqueKeys [](*sql.UniqueKey) + UniqueKey *sql.UniqueKey + SharedColumns *sql.ColumnList + ColumnRenameMap map[string]string + DroppedColumnsMap map[string]bool + MappedSharedColumns *sql.ColumnList + MigrationRangeMinValues *sql.ColumnValues + MigrationRangeMaxValues *sql.ColumnValues + Iteration int64 + MigrationIterationRangeMinValues *sql.ColumnValues + MigrationIterationRangeMaxValues *sql.ColumnValues + CalculateNextIterationRangeEndValuesLock *sync.Mutex + ForceTmpTableName string recentBinlogCoordinates mysql.BinlogCoordinates @@ -269,26 +279,27 @@ type ContextConfig struct { func NewMigrationContext() *MigrationContext { return &MigrationContext{ - Uuid: uuid.NewString(), - defaultNumRetries: 60, - ChunkSize: 1000, - InspectorConnectionConfig: mysql.NewConnectionConfig(), - ApplierConnectionConfig: mysql.NewConnectionConfig(), - MaxLagMillisecondsThrottleThreshold: 1500, - CutOverLockTimeoutSeconds: 3, - DMLBatchSize: 10, - etaNanoseonds: ETAUnknown, - maxLoad: NewLoadMap(), - criticalLoad: NewLoadMap(), - throttleMutex: &sync.Mutex{}, - throttleHTTPMutex: &sync.Mutex{}, - throttleControlReplicaKeys: mysql.NewInstanceKeyMap(), - configMutex: &sync.Mutex{}, - pointOfInterestTimeMutex: &sync.Mutex{}, - lastHeartbeatOnChangelogMutex: &sync.Mutex{}, - ColumnRenameMap: make(map[string]string), - PanicAbort: make(chan error), - Log: NewDefaultLogger(), + Uuid: uuid.NewString(), + defaultNumRetries: 60, + ChunkSize: 1000, + InspectorConnectionConfig: mysql.NewConnectionConfig(), + ApplierConnectionConfig: mysql.NewConnectionConfig(), + MaxLagMillisecondsThrottleThreshold: 1500, + CutOverLockTimeoutSeconds: 3, + DMLBatchSize: 10, + etaNanoseonds: ETAUnknown, + maxLoad: NewLoadMap(), + criticalLoad: NewLoadMap(), + throttleMutex: &sync.Mutex{}, + throttleHTTPMutex: &sync.Mutex{}, + throttleControlReplicaKeys: mysql.NewInstanceKeyMap(), + configMutex: &sync.Mutex{}, + pointOfInterestTimeMutex: &sync.Mutex{}, + lastHeartbeatOnChangelogMutex: &sync.Mutex{}, + CalculateNextIterationRangeEndValuesLock: &sync.Mutex{}, + ColumnRenameMap: make(map[string]string), + PanicAbort: make(chan error), + Log: NewDefaultLogger(), } } @@ -616,6 +627,16 @@ func (this *MigrationContext) SetChunkSize(chunkSize int64) { atomic.StoreInt64(&this.ChunkSize, chunkSize) } +func (this *MigrationContext) SetChunkConcurrentSize(chunkConcurrentSize int64) { + if chunkConcurrentSize < 1 { + chunkConcurrentSize = 1 + } + if chunkConcurrentSize > 100 { + chunkConcurrentSize = 100 + } + atomic.StoreInt64(&this.ChunkConcurrentSize, chunkConcurrentSize) +} + func (this *MigrationContext) SetDMLBatchSize(batchSize int64) { if batchSize < 1 { batchSize = 1 diff --git a/go/cmd/gh-ost/main.go b/go/cmd/gh-ost/main.go index 139703077..6dc7e27e0 100644 --- a/go/cmd/gh-ost/main.go +++ b/go/cmd/gh-ost/main.go @@ -104,6 +104,7 @@ func main() { flag.BoolVar(&migrationContext.CutOverExponentialBackoff, "cut-over-exponential-backoff", false, "Wait exponentially longer intervals between failed cut-over attempts. Wait intervals obey a maximum configurable with 'exponential-backoff-max-interval').") exponentialBackoffMaxInterval := flag.Int64("exponential-backoff-max-interval", 64, "Maximum number of seconds to wait between attempts when performing various operations with exponential backoff.") chunkSize := flag.Int64("chunk-size", 1000, "amount of rows to handle in each iteration (allowed range: 10-100,000)") + chunkConcurrentSize := flag.Int64("chunk-concurrent-size", 1, "The number of goroutines to execute chunks concurrently in each copy time slot, default 1 (allowed range: 1-100)") dmlBatchSize := flag.Int64("dml-batch-size", 10, "batch size for DML events to apply in a single transaction (range 1-100)") defaultRetries := flag.Int64("default-retries", 60, "Default number of retries for various operations before panicking") cutOverLockTimeoutSeconds := flag.Int64("cut-over-lock-timeout-seconds", 3, "Max number of seconds to hold locks on tables while attempting to cut-over (retry attempted when lock exceeds timeout)") @@ -294,6 +295,7 @@ func main() { migrationContext.SetHeartbeatIntervalMilliseconds(*heartbeatIntervalMillis) migrationContext.SetNiceRatio(*niceRatio) migrationContext.SetChunkSize(*chunkSize) + migrationContext.SetChunkConcurrentSize(*chunkConcurrentSize) migrationContext.SetDMLBatchSize(*dmlBatchSize) migrationContext.SetMaxLagMillisecondsThrottleThreshold(*maxLagMillis) migrationContext.SetThrottleQuery(*throttleQuery) diff --git a/go/logic/applier.go b/go/logic/applier.go index fa374a70f..4ae931c67 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -76,6 +76,10 @@ func (this *Applier) InitDBConnections() (err error) { if this.db, _, err = mysql.GetDB(this.migrationContext.Uuid, applierUri); err != nil { return err } + concurrentSize := atomic.LoadInt64(&this.migrationContext.ChunkConcurrentSize) + if concurrentSize > mysql.MaxDBPoolConnections { + this.db.SetMaxOpenConns(int(concurrentSize)) + } singletonApplierUri := fmt.Sprintf("%s&timeout=0", applierUri) if this.singletonDB, _, err = mysql.GetDB(this.migrationContext.Uuid, singletonApplierUri); err != nil { return err @@ -561,11 +565,18 @@ func (this *Applier) ReadMigrationRangeValues() error { // which will be used for copying the next chunk of rows. Ir returns "false" if there is // no further chunk to work through, i.e. we're past the last chunk and are done with // iterating the range (and this done with copying row chunks) -func (this *Applier) CalculateNextIterationRangeEndValues() (hasFurtherRange bool, err error) { - this.migrationContext.MigrationIterationRangeMinValues = this.migrationContext.MigrationIterationRangeMaxValues - if this.migrationContext.MigrationIterationRangeMinValues == nil { - this.migrationContext.MigrationIterationRangeMinValues = this.migrationContext.MigrationRangeMinValues +func (this *Applier) CalculateNextIterationRangeEndValues() (values *base.IterationRangeValues, err error) { + this.migrationContext.CalculateNextIterationRangeEndValuesLock.Lock() + defer this.migrationContext.CalculateNextIterationRangeEndValuesLock.Unlock() + + iterationRangeValues := &base.IterationRangeValues{} + + iterationRangeValues.Min = this.migrationContext.MigrationIterationRangeMaxValues + if iterationRangeValues.Min == nil { + iterationRangeValues.Min = this.migrationContext.MigrationRangeMinValues + iterationRangeValues.IsIncludeMinValues = true } + for i := 0; i < 2; i++ { buildFunc := sql.BuildUniqueKeyRangeEndPreparedQueryViaOffset if i == 1 { @@ -575,46 +586,48 @@ func (this *Applier) CalculateNextIterationRangeEndValues() (hasFurtherRange boo this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, &this.migrationContext.UniqueKey.Columns, - this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(), + iterationRangeValues.Min.AbstractValues(), this.migrationContext.MigrationRangeMaxValues.AbstractValues(), atomic.LoadInt64(&this.migrationContext.ChunkSize), - this.migrationContext.GetIteration() == 0, + iterationRangeValues.IsIncludeMinValues, fmt.Sprintf("iteration:%d", this.migrationContext.GetIteration()), ) if err != nil { - return hasFurtherRange, err + return iterationRangeValues, err } rows, err := this.db.Query(query, explodedArgs...) if err != nil { - return hasFurtherRange, err + return iterationRangeValues, err } defer rows.Close() iterationRangeMaxValues := sql.NewColumnValues(this.migrationContext.UniqueKey.Len()) for rows.Next() { if err = rows.Scan(iterationRangeMaxValues.ValuesPointers...); err != nil { - return hasFurtherRange, err + return iterationRangeValues, err } - hasFurtherRange = true + iterationRangeValues.HasFurtherRange = true } if err = rows.Err(); err != nil { - return hasFurtherRange, err + return iterationRangeValues, err } - if hasFurtherRange { - this.migrationContext.MigrationIterationRangeMaxValues = iterationRangeMaxValues - return hasFurtherRange, nil + if iterationRangeValues.HasFurtherRange { + iterationRangeValues.Max = iterationRangeMaxValues + this.migrationContext.MigrationIterationRangeMinValues = iterationRangeValues.Min + this.migrationContext.MigrationIterationRangeMaxValues = iterationRangeValues.Max + return iterationRangeValues, nil } } this.migrationContext.Log.Debugf("Iteration complete: no further range to iterate") - return hasFurtherRange, nil + return iterationRangeValues, nil } // ApplyIterationInsertQuery issues a chunk-INSERT query on the ghost table. It is where // data actually gets copied from original table. -func (this *Applier) ApplyIterationInsertQuery() (chunkSize int64, rowsAffected int64, duration time.Duration, err error) { +func (this *Applier) ApplyIterationInsertQuery(iterationRangeValues *base.IterationRangeValues) (chunkSize int64, rowsAffected int64, duration time.Duration, err error) { startTime := time.Now() - chunkSize = atomic.LoadInt64(&this.migrationContext.ChunkSize) + chunkSize = iterationRangeValues.Size query, explodedArgs, err := sql.BuildRangeInsertPreparedQuery( this.migrationContext.DatabaseName, @@ -624,9 +637,9 @@ func (this *Applier) ApplyIterationInsertQuery() (chunkSize int64, rowsAffected this.migrationContext.MappedSharedColumns.Names(), this.migrationContext.UniqueKey.Name, &this.migrationContext.UniqueKey.Columns, - this.migrationContext.MigrationIterationRangeMinValues.AbstractValues(), - this.migrationContext.MigrationIterationRangeMaxValues.AbstractValues(), - this.migrationContext.GetIteration() == 0, + iterationRangeValues.Min.AbstractValues(), + iterationRangeValues.Max.AbstractValues(), + iterationRangeValues.IsIncludeMinValues, this.migrationContext.IsTransactionalTable(), ) if err != nil { @@ -663,8 +676,8 @@ func (this *Applier) ApplyIterationInsertQuery() (chunkSize int64, rowsAffected duration = time.Since(startTime) this.migrationContext.Log.Debugf( "Issued INSERT on range: [%s]..[%s]; iteration: %d; chunk-size: %d", - this.migrationContext.MigrationIterationRangeMinValues, - this.migrationContext.MigrationIterationRangeMaxValues, + iterationRangeValues.Min, + iterationRangeValues.Max, this.migrationContext.GetIteration(), chunkSize) return chunkSize, rowsAffected, duration, nil diff --git a/go/logic/migrator.go b/go/logic/migrator.go index fed7c944b..380282cdb 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -9,6 +9,7 @@ import ( "context" "errors" "fmt" + "golang.org/x/sync/errgroup" "io" "math" "os" @@ -1190,41 +1191,68 @@ func (this *Migrator) iterateChunks() error { // When hasFurtherRange is false, original table might be write locked and CalculateNextIterationRangeEndValues would hangs forever - hasFurtherRange := false - if err := this.retryOperation(func() (e error) { - hasFurtherRange, e = this.applier.CalculateNextIterationRangeEndValues() - return e - }); err != nil { - return terminateRowIteration(err) - } - if !hasFurtherRange { - atomic.StoreInt64(&hasNoFurtherRangeFlag, 1) - return terminateRowIteration(nil) + concurrentSize := atomic.LoadInt64(&this.migrationContext.ChunkConcurrentSize) + if concurrentSize == 0 { + concurrentSize = 1 } - // Copy task: - applyCopyRowsFunc := func() error { - if atomic.LoadInt64(&this.rowCopyCompleteFlag) == 1 { - // No need for more writes. - // This is the de-facto place where we avoid writing in the event of completed cut-over. - // There could _still_ be a race condition, but that's as close as we can get. - // What about the race condition? Well, there's actually no data integrity issue. - // when rowCopyCompleteFlag==1 that means **guaranteed** all necessary rows have been copied. - // But some are still then collected at the binary log, and these are the ones we're trying to - // not apply here. If the race condition wins over us, then we just attempt to apply onto the - // _ghost_ table, which no longer exists. So, bothering error messages and all, but no damage. + + g, _ := errgroup.WithContext(context.Background()) + g.SetLimit(int(concurrentSize)) + + for i := 0; i < int(concurrentSize); i++ { + g.Go(func() error { + var iterationRangeValues *base.IterationRangeValues + + if err := this.retryOperation(func() (e error) { + iterationRangeValues, e = this.applier.CalculateNextIterationRangeEndValues() + return e + }); err != nil { + return err + } + + if !iterationRangeValues.HasFurtherRange { + atomic.StoreInt64(&hasNoFurtherRangeFlag, 1) + return nil + } + + // Copy task: + applyCopyRowsFunc := func() error { + if atomic.LoadInt64(&this.rowCopyCompleteFlag) == 1 { + // No need for more writes. + // This is the de-facto place where we avoid writing in the event of completed cut-over. + // There could _still_ be a race condition, but that's as close as we can get. + // What about the race condition? Well, there's actually no data integrity issue. + // when rowCopyCompleteFlag==1 that means **guaranteed** all necessary rows have been copied. + // But some are still then collected at the binary log, and these are the ones we're trying to + // not apply here. If the race condition wins over us, then we just attempt to apply onto the + // _ghost_ table, which no longer exists. So, bothering error messages and all, but no damage. + return nil + } + _, rowsAffected, _, err := this.applier.ApplyIterationInsertQuery(iterationRangeValues) + if err != nil { + return err // wrapping call will retry + } + atomic.AddInt64(&this.migrationContext.TotalRowsCopied, rowsAffected) + atomic.AddInt64(&this.migrationContext.Iteration, 1) + return nil + } + + if err := this.retryOperation(applyCopyRowsFunc); err != nil { + return err + } + return nil - } - _, rowsAffected, _, err := this.applier.ApplyIterationInsertQuery() - if err != nil { - return err // wrapping call will retry - } - atomic.AddInt64(&this.migrationContext.TotalRowsCopied, rowsAffected) - atomic.AddInt64(&this.migrationContext.Iteration, 1) - return nil + }) } - if err := this.retryOperation(applyCopyRowsFunc); err != nil { + + if err := g.Wait(); err != nil { return terminateRowIteration(err) } + + if atomic.LoadInt64(&hasNoFurtherRangeFlag) == 1 { + return terminateRowIteration(nil) + } + return nil } // Enqueue copy operation; to be executed by executeWriteFuncs() diff --git a/vendor/golang.org/x/sync/LICENSE b/vendor/golang.org/x/sync/LICENSE new file mode 100644 index 000000000..6a66aea5e --- /dev/null +++ b/vendor/golang.org/x/sync/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/golang.org/x/sync/PATENTS b/vendor/golang.org/x/sync/PATENTS new file mode 100644 index 000000000..733099041 --- /dev/null +++ b/vendor/golang.org/x/sync/PATENTS @@ -0,0 +1,22 @@ +Additional IP Rights Grant (Patents) + +"This implementation" means the copyrightable works distributed by +Google as part of the Go project. + +Google hereby grants to You a perpetual, worldwide, non-exclusive, +no-charge, royalty-free, irrevocable (except as stated in this section) +patent license to make, have made, use, offer to sell, sell, import, +transfer and otherwise run, modify and propagate the contents of this +implementation of Go, where such license applies only to those patent +claims, both currently owned or controlled by Google and acquired in +the future, licensable by Google that are necessarily infringed by this +implementation of Go. This grant does not include claims that would be +infringed only as a consequence of further modification of this +implementation. If you or your agent or exclusive licensee institute or +order or agree to the institution of patent litigation against any +entity (including a cross-claim or counterclaim in a lawsuit) alleging +that this implementation of Go or any code incorporated within this +implementation of Go constitutes direct or contributory patent +infringement, or inducement of patent infringement, then any patent +rights granted to you under this License for this implementation of Go +shall terminate as of the date such litigation is filed. diff --git a/vendor/golang.org/x/sync/errgroup/errgroup.go b/vendor/golang.org/x/sync/errgroup/errgroup.go new file mode 100644 index 000000000..cbee7a4e2 --- /dev/null +++ b/vendor/golang.org/x/sync/errgroup/errgroup.go @@ -0,0 +1,132 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package errgroup provides synchronization, error propagation, and Context +// cancelation for groups of goroutines working on subtasks of a common task. +package errgroup + +import ( + "context" + "fmt" + "sync" +) + +type token struct{} + +// A Group is a collection of goroutines working on subtasks that are part of +// the same overall task. +// +// A zero Group is valid, has no limit on the number of active goroutines, +// and does not cancel on error. +type Group struct { + cancel func() + + wg sync.WaitGroup + + sem chan token + + errOnce sync.Once + err error +} + +func (g *Group) done() { + if g.sem != nil { + <-g.sem + } + g.wg.Done() +} + +// WithContext returns a new Group and an associated Context derived from ctx. +// +// The derived Context is canceled the first time a function passed to Go +// returns a non-nil error or the first time Wait returns, whichever occurs +// first. +func WithContext(ctx context.Context) (*Group, context.Context) { + ctx, cancel := context.WithCancel(ctx) + return &Group{cancel: cancel}, ctx +} + +// Wait blocks until all function calls from the Go method have returned, then +// returns the first non-nil error (if any) from them. +func (g *Group) Wait() error { + g.wg.Wait() + if g.cancel != nil { + g.cancel() + } + return g.err +} + +// Go calls the given function in a new goroutine. +// It blocks until the new goroutine can be added without the number of +// active goroutines in the group exceeding the configured limit. +// +// The first call to return a non-nil error cancels the group's context, if the +// group was created by calling WithContext. The error will be returned by Wait. +func (g *Group) Go(f func() error) { + if g.sem != nil { + g.sem <- token{} + } + + g.wg.Add(1) + go func() { + defer g.done() + + if err := f(); err != nil { + g.errOnce.Do(func() { + g.err = err + if g.cancel != nil { + g.cancel() + } + }) + } + }() +} + +// TryGo calls the given function in a new goroutine only if the number of +// active goroutines in the group is currently below the configured limit. +// +// The return value reports whether the goroutine was started. +func (g *Group) TryGo(f func() error) bool { + if g.sem != nil { + select { + case g.sem <- token{}: + // Note: this allows barging iff channels in general allow barging. + default: + return false + } + } + + g.wg.Add(1) + go func() { + defer g.done() + + if err := f(); err != nil { + g.errOnce.Do(func() { + g.err = err + if g.cancel != nil { + g.cancel() + } + }) + } + }() + return true +} + +// SetLimit limits the number of active goroutines in this group to at most n. +// A negative value indicates no limit. +// +// Any subsequent call to the Go method will block until it can add an active +// goroutine without exceeding the configured limit. +// +// The limit must not be modified while any goroutines in the group are active. +func (g *Group) SetLimit(n int) { + if n < 0 { + g.sem = nil + return + } + if len(g.sem) != 0 { + panic(fmt.Errorf("errgroup: modify limit while %v goroutines in the group are still active", len(g.sem))) + } + g.sem = make(chan token, n) +} diff --git a/vendor/modules.txt b/vendor/modules.txt index 19cadde30..09ef4f692 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -43,6 +43,9 @@ go.uber.org/atomic # golang.org/x/net v0.17.0 ## explicit; go 1.17 golang.org/x/net/context +# golang.org/x/sync v0.1.0 +## explicit +golang.org/x/sync/errgroup # golang.org/x/sys v0.13.0 ## explicit; go 1.17 golang.org/x/sys/plan9