diff --git a/bench_test.go b/bench_test.go index dd0e48f0..0fac403e 100644 --- a/bench_test.go +++ b/bench_test.go @@ -13,8 +13,6 @@ import ( ) func BenchmarkLDBQueryBaseline(b *testing.B) { - b.StopTimer() - ctx := context.TODO() tmpDir, err := ioutil.TempDir("", "") @@ -51,8 +49,8 @@ func BenchmarkLDBQueryBaseline(b *testing.B) { b.Fatalf("Unexpected error preparing query: %v", err) } - b.StartTimer() - + b.ReportAllocs() + b.ResetTimer() for i := 0; i < b.N; i++ { rows, err := prepQ.QueryContext(ctx, "foo") if err != nil { @@ -123,6 +121,8 @@ func BenchmarkGetRowByKey(b *testing.B) { r: r, } + b.ReportAllocs() + b.ResetTimer() for i := 0; i < b.N; i++ { var row benchKVRow found, err := benchSetup.r.GetRowByKey(benchSetup.ctx, &row, "foo", "bar", "foo") diff --git a/docker-compose.yml b/docker-compose.yml index e4a35e85..da71e6e5 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -12,4 +12,3 @@ services: MYSQL_DATABASE: ctldb MYSQL_USER: ctldb MYSQL_PASSWORD: ctldbpw - mem_limit: 536870912 diff --git a/go.mod b/go.mod index 0e1cad6a..6ad8c0b0 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,6 @@ require ( github.com/pkg/errors v0.9.1 github.com/segmentio/cli v0.5.1 github.com/segmentio/conf v1.1.0 - github.com/segmentio/errors-go v1.0.0 github.com/segmentio/events/v2 v2.3.2 github.com/segmentio/go-sqlite3 v1.12.0 github.com/segmentio/stats/v4 v4.6.2 diff --git a/go.sum b/go.sum index 64a25a05..8fe5cbf1 100644 --- a/go.sum +++ b/go.sum @@ -105,8 +105,6 @@ github.com/segmentio/cli v0.5.1 h1:Xhtnmp0LrF+JHQTTV4Q58S79gG8JKXO4MMniyqc+XZs= github.com/segmentio/cli v0.5.1/go.mod h1:qz2M+DqXgYnjKLTrcI80MoGQsI6xT0wXCozfBAtF/iI= github.com/segmentio/conf v1.1.0 h1:3d8AaXnQNLCze/UpZ31pwDpDj+tmb2FIwroOtqCYNBY= github.com/segmentio/conf v1.1.0/go.mod h1:Y3B9O/PqqWqjyxyWWseyj/quPEtMu1zDp/kVbSWWaB0= -github.com/segmentio/errors-go v1.0.0 h1:B4mbo4hP3+XffV1GhwyAcHlvWoZtYdTyc3BOVPxspTQ= -github.com/segmentio/errors-go v1.0.0/go.mod h1:RDVEREUrpa4/jM8rt5KsQpu+JoXPi6i07vG7m4tX0MY= github.com/segmentio/events/v2 v2.3.2 h1:J73yVqYtnLWZD3Oqef82fYPZhfpRfQGiOvBes+OohoY= github.com/segmentio/events/v2 v2.3.2/go.mod h1:9HY7dFOCKoPQx3hUBXYim6I4hqaZWtSGWJ4IYAMxtkM= github.com/segmentio/fasthash v0.0.0-20180216231524-a72b379d632e h1:uO75wNGioszjmIzcY/tvdDYKRLVvzggtAmmJkn9j4GQ= diff --git a/ldb_reader.go b/ldb_reader.go index 1aea253c..bc468d27 100644 --- a/ldb_reader.go +++ b/ldb_reader.go @@ -3,6 +3,7 @@ package ctlstore import ( "context" "database/sql" + "errors" "fmt" "os" "path/filepath" @@ -11,7 +12,6 @@ import ( "sync" "time" - "github.com/segmentio/errors-go" "github.com/segmentio/events/v2" "github.com/segmentio/stats/v4" @@ -64,14 +64,14 @@ func newVersionedLDBReader(dirPath string) (*LDBReader, error) { // To initialize this reader, we must first load an LDB: last, err := lookupLastLDBSync(dirPath) if err != nil { - return nil, errors.Wrap(err, "checking last ldb sync") + return nil, fmt.Errorf("checking last ldb sync: %w", err) } if last == 0 { return nil, fmt.Errorf("no LDB in path (%s)", dirPath) } err = reader.switchLDB(dirPath, last) if err != nil { - return nil, errors.Wrap(err, "switching ldbs") + return nil, fmt.Errorf("switching ldbs: %w", err) } // Then we can defer to the watcher goroutine to swap this @@ -133,7 +133,7 @@ func (reader *LDBReader) GetLedgerLatency(ctx context.Context) (time.Duration, e case err == sql.ErrNoRows: return 0, ErrNoLedgerUpdates case err != nil: - return 0, errors.Wrap(err, "get ledger latency") + return 0, fmt.Errorf("get ledger latency: %w", err) default: return time.Now().Sub(timestamp), nil } @@ -281,7 +281,7 @@ func (reader *LDBReader) GetRowByKey( if err != nil { // See NOTE above about why this cache is getting cleared reader.invalidatePKCache(ldbTable) // assumes RLock is held - err = errors.Wrap(err, "query target row error") + err = fmt.Errorf("query target row error: %w", err) return } defer rows.Close() @@ -306,7 +306,7 @@ func (reader *LDBReader) GetRowByKey( err = scanFunc(rows) if err != nil { - err = errors.Wrap(err, "target row scan error") + err = fmt.Errorf("target row scan error: %w", err) } else { err = rows.Err() } @@ -430,7 +430,7 @@ func (reader *LDBReader) getPrimaryKey(ctx context.Context, ldbTable string) (sc const qs = "SELECT name,type FROM pragma_table_info(?) WHERE pk > 0 ORDER BY pk ASC" rows, err := reader.Db.QueryContext(ctx, qs, ldbTable) if err != nil { - return schema.PrimaryKeyZero, errors.Wrap(err, "query pragma_table_info error") + return schema.PrimaryKeyZero, fmt.Errorf("query pragma_table_info error: %w", err) } defer rows.Close() @@ -441,14 +441,14 @@ func (reader *LDBReader) getPrimaryKey(ctx context.Context, ldbTable string) (sc var ftString string err = rows.Scan(&name, &ftString) if err != nil { - return schema.PrimaryKeyZero, errors.WithStack(err) + return schema.PrimaryKeyZero, fmt.Errorf("scan: %w", err) } rawFieldNames = append(rawFieldNames, name) rawFieldTypes = append(rawFieldTypes, ftString) } err = rows.Err() if err != nil { - return schema.PrimaryKeyZero, errors.WithStack(err) + return schema.PrimaryKeyZero, fmt.Errorf("rows err: %w", err) } pk, err := schema.NewPKFromRawNamesAndTypes(rawFieldNames, rawFieldTypes) @@ -622,14 +622,14 @@ func (reader *LDBReader) switchLDB(dirPath string, timestamp int64) error { db, err := newLDB(fullPath) if err != nil { - return errors.Wrap(err, "new ldb") + return fmt.Errorf("new ldb: %w", err) } reader.mu.Lock() defer reader.mu.Unlock() if err = reader.closeDB(); err != nil { - return errors.Wrap(err, "closing db") + return fmt.Errorf("closing db: %w", err) } reader.Db = db @@ -668,7 +668,7 @@ func lookupLastLDBSync(dirPath string) (int64, error) { // dirPath + ["", ldb.DefaultLDBFilename] localPath, err := filepath.Rel(dirPath, filePath) if err != nil { - return errors.Wrapf(err, "base path (%s)", filePath) + return fmt.Errorf("base path (%s): %w", filePath, err) } fields := strings.Split(localPath, "/") @@ -691,7 +691,7 @@ func lookupLastLDBSync(dirPath string) (int64, error) { return nil }) if err != nil { - return 0, errors.Wrap(err, "filepath walk") + return 0, fmt.Errorf("filepath walk: %w", err) } return lastSync, nil diff --git a/ldb_reader_test.go b/ldb_reader_test.go index 612774b4..47d037e1 100644 --- a/ldb_reader_test.go +++ b/ldb_reader_test.go @@ -3,6 +3,7 @@ package ctlstore import ( "context" "database/sql" + "errors" "fmt" "io/ioutil" "os" @@ -11,7 +12,7 @@ import ( "time" "github.com/google/go-cmp/cmp" - "github.com/pkg/errors" + "github.com/segmentio/ctlstore/pkg/ldb" "github.com/segmentio/ctlstore/pkg/ldbwriter" "github.com/segmentio/ctlstore/pkg/schema" diff --git a/pkg/changelog/changelog_writer.go b/pkg/changelog/changelog_writer.go index e3c2fba4..af181884 100644 --- a/pkg/changelog/changelog_writer.go +++ b/pkg/changelog/changelog_writer.go @@ -2,8 +2,8 @@ package changelog import ( "encoding/json" + "fmt" - "github.com/pkg/errors" "github.com/segmentio/events/v2" ) @@ -42,7 +42,7 @@ func (w *ChangelogWriter) WriteChange(e ChangelogEntry) error { bytes, err := json.Marshal(structure) if err != nil { - return errors.Wrap(err, "error marshalling json") + return fmt.Errorf("error marshalling json: %w", err) } events.Debug("changelogWriter.WriteChange: %{family}s.%{table}s => %{key}v", diff --git a/pkg/cmd/ctlstore-cli/cmd/read_keys.go b/pkg/cmd/ctlstore-cli/cmd/read_keys.go index d72b4823..f795785f 100644 --- a/pkg/cmd/ctlstore-cli/cmd/read_keys.go +++ b/pkg/cmd/ctlstore-cli/cmd/read_keys.go @@ -10,7 +10,6 @@ import ( "text/tabwriter" "time" - "github.com/pkg/errors" "github.com/segmentio/cli" "github.com/segmentio/ctlstore" ) @@ -48,7 +47,7 @@ var cliReadKeys = &cli.CommandFunc{ } reader, err := ctlstore.ReaderForPath(ldbPath) if err != nil { - return errors.Wrap(err, "ldb reader for path") + return fmt.Errorf("ldb reader for path: %w", err) } defer reader.Close() resMap := make(map[string]interface{}) @@ -108,7 +107,7 @@ func parseKey(key string) (interface{}, error) { } hex, err := hex.DecodeString(parts[1]) if err != nil { - return nil, errors.Errorf("could not parse '%s' as hex", parts[1]) + return nil, fmt.Errorf("could not parse '%s' as hex", parts[1]) } return hex, nil } diff --git a/pkg/cmd/ctlstore-cli/main.go b/pkg/cmd/ctlstore-cli/main.go index 5a1c9b61..3c744935 100644 --- a/pkg/cmd/ctlstore-cli/main.go +++ b/pkg/cmd/ctlstore-cli/main.go @@ -1,6 +1,8 @@ package main -import "github.com/segmentio/ctlstore/pkg/cmd/ctlstore-cli/cmd" +import ( + "github.com/segmentio/ctlstore/pkg/cmd/ctlstore-cli/cmd" +) func main() { cmd.Execute() diff --git a/pkg/cmd/ctlstore-mutator/main.go b/pkg/cmd/ctlstore-mutator/main.go index 252a523b..d663a219 100644 --- a/pkg/cmd/ctlstore-mutator/main.go +++ b/pkg/cmd/ctlstore-mutator/main.go @@ -16,7 +16,6 @@ import ( "github.com/segmentio/conf" "github.com/segmentio/ctlstore/pkg/utils" - "github.com/segmentio/errors-go" ) type config struct { @@ -99,7 +98,7 @@ func main() { } b, err := json.Marshal(payload) if err != nil { - return errors.Wrap(err, "marshaling payload") + return fmt.Errorf("marshaling payload: %w", err) } req, err := http.NewRequest("POST", executiveURL+"/families/"+cfg.FamilyName+"/mutations", bytes.NewReader(b)) req.Header.Set("Content-Type", "application/json") @@ -107,7 +106,7 @@ func main() { req.Header.Set("ctlstore-secret", cfg.WriterSecret) resp, err := client.Do(req) if err != nil { - return errors.Wrap(err, "making mutation request") + return fmt.Errorf("making mutation request: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { @@ -145,12 +144,12 @@ func setup(cfg config, url string) error { req, err := http.NewRequest("POST", url+"/families/"+cfg.FamilyName, nil) if err != nil { - return errors.Wrap(err, "create family request") + return fmt.Errorf("create family request: %w", err) } req.Header.Set("Content-Type", "application/json") res, err = client.Do(req) if err != nil { - return errors.Wrap(err, "making faily request") + return fmt.Errorf("making faily request: %w", err) } defer res.Body.Close() if res.StatusCode != http.StatusOK && res.StatusCode != http.StatusConflict { @@ -175,12 +174,12 @@ func setup(cfg config, url string) error { } req, err = http.NewRequest("POST", url+"/families/"+cfg.FamilyName+"/tables/"+cfg.TableName, utils.NewJsonReader(tableDef)) if err != nil { - return errors.Wrap(err, "create family request") + return fmt.Errorf("create family request: %w", err) } req.Header.Set("Content-Type", "application/json") res, err = client.Do(req) if err != nil { - return errors.Wrap(err, "making table request") + return fmt.Errorf("making table request: %w", err) } defer res.Body.Close() if res.StatusCode != http.StatusOK && res.StatusCode != http.StatusConflict { diff --git a/pkg/cmd/ctlstore/main.go b/pkg/cmd/ctlstore/main.go index 8a09a0a2..51c3d7ea 100644 --- a/pkg/cmd/ctlstore/main.go +++ b/pkg/cmd/ctlstore/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "errors" "fmt" "net/http" "os" @@ -10,7 +11,6 @@ import ( "time" "github.com/segmentio/conf" - "github.com/segmentio/errors-go" "github.com/segmentio/events/v2" _ "github.com/segmentio/events/v2/sigevents" "github.com/segmentio/stats/v4" @@ -291,12 +291,12 @@ func supervisor(ctx context.Context, args []string) { }) defer teardown() if err := utils.EnsureDirForFile(cliCfg.ReflectorConfig.LDBPath); err != nil { - return errors.Wrap(err, "ensure ldb dir") + return fmt.Errorf("ensure ldb dir: %w", err) } reflector, err := newReflector(cliCfg.ReflectorConfig, true) if err != nil { - return errors.Wrap(err, "build supervisor reflector") + return fmt.Errorf("build supervisor reflector: %w", err) } supervisor, err := supervisorpkg.SupervisorFromConfig(supervisorpkg.SupervisorConfig{ @@ -306,7 +306,7 @@ func supervisor(ctx context.Context, args []string) { Reflector: reflector, // compose the reflector, since it will start with the supervisor }) if err != nil { - return errors.Wrap(err, "start supervisor") + return fmt.Errorf("start supervisor: %w", err) } defer supervisor.Close() supervisor.Start(ctx) @@ -403,7 +403,7 @@ func executive(ctx context.Context, args []string) { defer executive.Close() if err := executive.Start(ctx, cliCfg.Bind); err != nil { - if errors.Cause(err) != context.Canceled { + if errors.Is(err, context.Canceled) { errs.IncrDefault(stats.T("op", "service shutdown")) } events.Log("executive quit: %v", err) diff --git a/pkg/ctldb/ctldb.go b/pkg/ctldb/ctldb.go index 8e1cbe56..c5d98491 100644 --- a/pkg/ctldb/ctldb.go +++ b/pkg/ctldb/ctldb.go @@ -59,6 +59,33 @@ CREATE TABLE locks ( INSERT INTO locks VALUES('ledger', 0); +` + LimiterDBSchemaUp, + "sqlite": ` + +CREATE TABLE families ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name VARCHAR(191) NOT NULL UNIQUE +); + +CREATE TABLE mutators ( + writer VARCHAR(191) NOT NULL PRIMARY KEY, + secret VARCHAR(255), + cookie BLOB(1024) NOT NULL, + clock INTEGER NOT NULL DEFAULT 0 +); + +CREATE TABLE ctlstore_dml_ledger ( + seq INTEGER PRIMARY KEY AUTOINCREMENT, + leader_ts DATETIME DEFAULT CURRENT_TIMESTAMP, + statement TEXT NOT NULL +); + +CREATE TABLE locks ( + id VARCHAR(191) NOT NULL PRIMARY KEY, + clock INTEGER NOT NULL DEFAULT 0 +); + +INSERT INTO locks VALUES('ledger', 0); ` + LimiterDBSchemaUp, "sqlite3": ` diff --git a/pkg/ctldb/ctldb_test_helpers.go b/pkg/ctldb/ctldb_test_helpers.go index 1c33f1a1..73db84f5 100644 --- a/pkg/ctldb/ctldb_test_helpers.go +++ b/pkg/ctldb/ctldb_test_helpers.go @@ -1,6 +1,8 @@ package ctldb -import "testing" +import ( + "testing" +) // This configuration comes from the docker-compose.yml file const testCtlDBRawDSN = "ctldb:ctldbpw@tcp(localhost:3306)/ctldb" diff --git a/pkg/ctldb/dsn_parameters.go b/pkg/ctldb/dsn_parameters.go index dd43b0e5..31707078 100644 --- a/pkg/ctldb/dsn_parameters.go +++ b/pkg/ctldb/dsn_parameters.go @@ -1,6 +1,8 @@ package ctldb -import "net/url" +import ( + "net/url" +) func SetCtldbDSNParameters(dsn string) (string, error) { var err error diff --git a/pkg/errs/errs.go b/pkg/errs/errs.go index 859b2610..ab253958 100644 --- a/pkg/errs/errs.go +++ b/pkg/errs/errs.go @@ -2,9 +2,10 @@ package errs import ( "context" + "errors" "fmt" + "net/http" - "github.com/segmentio/errors-go" "github.com/segmentio/stats/v4" ) @@ -12,14 +13,30 @@ const ( defaultErrName = "errors" ) -const ( - // these error types are handy when using errors-go - ErrTypeTemporary = "Temporary" - ErrTypePermanent = "Permanent" -) +type ErrTypeTemporary struct{ Err error } + +func (e ErrTypeTemporary) Error() string { + return e.Err.Error() +} + +func (e ErrTypeTemporary) Is(target error) bool { + _, ok := target.(ErrTypeTemporary) + return ok +} + +type ErrTypePermanent struct{ Err error } + +func (e ErrTypePermanent) Error() string { + return e.Err.Error() +} + +func (e ErrTypePermanent) Is(target error) bool { + _, ok := target.(ErrTypePermanent) + return ok +} func IsCanceled(err error) bool { - return err != nil && errors.Cause(err) == context.Canceled + return errors.Is(err, context.Canceled) } // IncrDefault increments the default error metric @@ -42,26 +59,52 @@ func Incr(name string, tags ...stats.Tag) { stats.Incr(defaultErrName, newTags...) } +func statusCode(err error) int { + var coder StatusCoder + if errors.As(err, &coder) { + return coder.StatusCode() + } + return http.StatusInternalServerError +} + // These are here because there's a need for a set of errors that have roughly // REST/HTTP compatibility, but aren't directly coupled to that interface. Lower // layers of the system can generate these errors while still making sense in // any context. + +type StatusCoder interface { + StatusCode() int +} + type baseError struct { + StatusCoder Err string } +func (b baseError) StatusCode() int { + return http.StatusInternalServerError +} + type ConflictError baseError func (e ConflictError) Error() string { return e.Err } +func (e ConflictError) StatusCode() int { + return http.StatusConflict +} + type BadRequestError baseError func (e BadRequestError) Error() string { return e.Err } +func (e BadRequestError) StatusCode() int { + return http.StatusBadRequest +} + func BadRequest(format string, args ...interface{}) error { return &BadRequestError{ Err: fmt.Sprintf(format, args...), @@ -74,6 +117,10 @@ func (e NotFoundError) Error() string { return e.Err } +func (e NotFoundError) StatusCode() int { + return http.StatusNotFound +} + func NotFound(format string, args ...interface{}) error { return &NotFoundError{ Err: fmt.Sprintf(format, args...), @@ -92,8 +139,16 @@ func (e RateLimitExceededErr) Error() string { return e.Err } +func (e RateLimitExceededErr) StatusCode() int { + return http.StatusTooManyRequests +} + type InsufficientStorageErr baseError func (e InsufficientStorageErr) Error() string { return e.Err } + +func (e InsufficientStorageErr) StatusCode() int { + return http.StatusInsufficientStorage +} diff --git a/pkg/errs/errs_test.go b/pkg/errs/errs_test.go new file mode 100644 index 00000000..873ab5c0 --- /dev/null +++ b/pkg/errs/errs_test.go @@ -0,0 +1,56 @@ +package errs + +import ( + "errors" + "fmt" + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestStatusCodes(t *testing.T) { + tests := []struct { + name string + err error + code int + }{ + { + name: "generic", + err: errors.New("foo"), + code: http.StatusInternalServerError, + }, + { + name: "conflict", + err: ConflictError{}, + code: http.StatusConflict, + }, + { + name: "bad request", + err: BadRequestError{}, + code: http.StatusBadRequest, + }, + { + name: "not found", + err: NotFound("foo"), + code: http.StatusNotFound, + }, + { + name: "rate limited", + err: RateLimitExceededErr{}, + code: http.StatusTooManyRequests, + }, + { + name: "insufficient storage", + err: InsufficientStorageErr{}, + code: http.StatusInsufficientStorage, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + err := fmt.Errorf("wrapped: %w", test.err) + code := statusCode(err) + require.Equal(t, test.code, code) + }) + } +} diff --git a/pkg/event/changelog.go b/pkg/event/changelog.go index 09cf8925..3e843ca1 100644 --- a/pkg/event/changelog.go +++ b/pkg/event/changelog.go @@ -5,14 +5,16 @@ import ( "bytes" "context" "encoding/json" + "errors" + "fmt" "io" + "io/fs" "os" "path/filepath" "time" "github.com/fsnotify/fsnotify" "github.com/segmentio/ctlstore/pkg/errs" - "github.com/segmentio/errors-go" "github.com/segmentio/events/v2" "github.com/segmentio/stats/v4" ) @@ -63,7 +65,7 @@ func (c *fileChangelog) next(ctx context.Context) (Event, error) { func (c *fileChangelog) start(ctx context.Context) error { watcher, err := fsnotify.NewWatcher() if err != nil { - return errors.Wrap(err, "create fsnotify watcher") + return fmt.Errorf("create fsnotify watcher: %w", err) } go func() { select { @@ -76,7 +78,7 @@ func (c *fileChangelog) start(ctx context.Context) error { paths := []string{c.path, filepath.Dir(c.path)} for _, w := range paths { if err := watcher.Add(w); err != nil { - return errors.Wrapf(err, "could not watch '%s'", w) + return fmt.Errorf("could not watch '%s': %w", w, err) } } fsNotifyCh := make(chan fsnotify.Event) @@ -136,7 +138,7 @@ func (c *fileChangelog) read(ctx context.Context, fsNotifyCh chan fsnotify.Event // first, open the changelog f, err := os.Open(c.path) if err != nil { - return errors.Wrap(err, "open changelog") + return fmt.Errorf("open changelog: %w", err) } defer func() { if err := f.Close(); err != nil { @@ -155,7 +157,7 @@ func (c *fileChangelog) read(ctx context.Context, fsNotifyCh chan fsnotify.Event } var entry entry if err := json.Unmarshal(b, &entry); err != nil { - c.send(ctx, eventErr{err: errors.Wrapf(err, "parse entry '%s'", b)}) + c.send(ctx, eventErr{err: fmt.Errorf("parse entry '%s': %w", b, err)}) errs.Incr("changelog-errors", stats.T("op", "parse json")) return } @@ -196,7 +198,7 @@ func (c *fileChangelog) read(ctx context.Context, fsNotifyCh chan fsnotify.Event for { err = readEvents() if err != io.EOF { - return errors.Wrap(err, "read bytes") + return fmt.Errorf("read bytes: %w", err) } events.Debug("EOF. Waiting for more content...") select { @@ -207,7 +209,7 @@ func (c *fileChangelog) read(ctx context.Context, fsNotifyCh chan fsnotify.Event if err := readEvents(); err != io.EOF { events.Log("could not consume rest of file: %{error}s", err) } - return errors.Wrap(err, "watcher error") + return fmt.Errorf("watcher error: %w", err) case event := <-fsNotifyCh: switch event.Op { case fsnotify.Write: @@ -217,7 +219,7 @@ func (c *fileChangelog) read(ctx context.Context, fsNotifyCh chan fsnotify.Event events.Debug("New changelog created. Consuming the rest of current one...") err := readEvents() if err != io.EOF { - return errors.Wrap(err, "consume rest of changelog") + return fmt.Errorf("consume rest of changelog: %w", err) } events.Debug("Restarting reader") return nil @@ -233,7 +235,7 @@ func (c *fileChangelog) read(ctx context.Context, fsNotifyCh chan fsnotify.Event // loop again case errs.IsCanceled(err): return - case os.IsNotExist(errors.Cause(err)): + case errors.Is(err, fs.ErrNotExist): events.Log("Changelog file does not exist, rechecking...") select { case <-fsNotifyCh: @@ -271,12 +273,12 @@ func (c *fileChangelog) validate() error { switch { case err == nil: case os.IsNotExist(err): - return errors.Wrap(err, "changelog does not exist") + return fmt.Errorf("changelog does not exist: %w", err) default: - return errors.Wrap(err, "stat changelog") + return fmt.Errorf("stat changelog: %w", err) } default: - return errors.Wrap(err, "stat changelog") + return fmt.Errorf("stat changelog: %w", err) } return nil } diff --git a/pkg/event/fake_changelog.go b/pkg/event/fake_changelog.go index f51da8eb..74b8edb9 100644 --- a/pkg/event/fake_changelog.go +++ b/pkg/event/fake_changelog.go @@ -2,8 +2,7 @@ package event import ( "context" - - "github.com/pkg/errors" + "errors" ) type ( diff --git a/pkg/event/fake_log_writer.go b/pkg/event/fake_log_writer.go index 0f39505b..e0b17d67 100644 --- a/pkg/event/fake_log_writer.go +++ b/pkg/event/fake_log_writer.go @@ -4,11 +4,11 @@ import ( "bufio" "context" "encoding/json" + "fmt" "os" "sync/atomic" "time" - "github.com/segmentio/errors-go" "github.com/segmentio/events/v2" ) @@ -27,7 +27,7 @@ type fakeLogWriter struct { func (w *fakeLogWriter) writeN(ctx context.Context, n int) error { f, err := os.Create(w.path) if err != nil { - return errors.Wrap(err, "create file") + return fmt.Errorf("create file: %w", err) } defer func() { events.Debug("Done writing %{num}d events", n) @@ -56,10 +56,10 @@ func (w *fakeLogWriter) writeN(ctx context.Context, n int) error { atomic.AddInt64(&w.seq, 1) err := json.NewEncoder(bw).Encode(entry) if err != nil { - return errors.Wrap(err, "write event") + return fmt.Errorf("write event: %w", err) } if err := bw.Flush(); err != nil { - return errors.Wrap(err, "flush") + return fmt.Errorf("flush: %w", err) } time.Sleep(w.delay) @@ -67,7 +67,7 @@ func (w *fakeLogWriter) writeN(ctx context.Context, n int) error { if w.rotateAfterBytes > 0 { info, err := os.Stat(w.path) if err != nil { - return errors.Wrap(err, "stat path") + return fmt.Errorf("stat path: %w", err) } // fmt.Println(info.Size(), w.rotateAfterBytes) if info.Size() > int64(w.rotateAfterBytes) { @@ -83,14 +83,14 @@ func (w *fakeLogWriter) writeN(ctx context.Context, n int) error { events.Debug("Rotating log file..") these = 0 if err := f.Close(); err != nil { - return errors.Wrap(err, "close during rotation") + return fmt.Errorf("close during rotation: %w", err) } if err := os.Remove(f.Name()); err != nil { - return errors.Wrap(err, "remove file") + return fmt.Errorf("remove file: %w", err) } f, err = os.Create(w.path) if err != nil { - return errors.Wrap(err, "rotate into new file") + return fmt.Errorf("rotate into new file: %w", err) } bw = bufio.NewWriter(f) atomic.AddInt64(&w.rotations, 1) diff --git a/pkg/event/iterator.go b/pkg/event/iterator.go index d6536ee6..0bcd3fa7 100644 --- a/pkg/event/iterator.go +++ b/pkg/event/iterator.go @@ -2,8 +2,8 @@ package event import ( "context" - - "github.com/segmentio/errors-go" + "errors" + "fmt" ) type ( @@ -44,13 +44,13 @@ func NewIterator(ctx context.Context, changelogPath string, opts ...IteratorOpt) if iter.changelog == nil { cl := newFileChangelog(changelogPath) if err := cl.validate(); err != nil { - return nil, errors.Wrap(err, "validate changelog") + return nil, fmt.Errorf("validate changelog: %w", err) } iter.changelog = cl } ctx, iter.cancelFunc = context.WithCancel(ctx) if err := iter.changelog.start(ctx); err != nil { - return nil, errors.Wrap(err, "start changelog") + return nil, fmt.Errorf("start changelog: %w", err) } return iter, nil } diff --git a/pkg/event/iterator_integration_test.go b/pkg/event/iterator_integration_test.go index 35f00ff4..69147a9d 100644 --- a/pkg/event/iterator_integration_test.go +++ b/pkg/event/iterator_integration_test.go @@ -1,3 +1,6 @@ +//go:build integration +// +build integration + package event import ( diff --git a/pkg/event/iterator_test.go b/pkg/event/iterator_test.go index 3c664cc4..8e66cc5a 100644 --- a/pkg/event/iterator_test.go +++ b/pkg/event/iterator_test.go @@ -2,10 +2,10 @@ package event import ( "context" + "errors" "testing" "time" - "github.com/pkg/errors" "github.com/stretchr/testify/require" ) diff --git a/pkg/executive/db_executive.go b/pkg/executive/db_executive.go index fc2bbf4f..2103a8df 100644 --- a/pkg/executive/db_executive.go +++ b/pkg/executive/db_executive.go @@ -3,12 +3,13 @@ package executive import ( "context" "database/sql" + "errors" "fmt" "strings" "time" "github.com/go-sql-driver/mysql" - "github.com/pkg/errors" + "github.com/segmentio/ctlstore/pkg/errs" "github.com/segmentio/ctlstore/pkg/ldb" "github.com/segmentio/ctlstore/pkg/limits" @@ -34,19 +35,19 @@ var ErrTableDoesNotExist = errors.New("table does not exist") func (e *dbExecutive) FamilySchemas(family string) ([]schema.Table, error) { familyName, err := schema.NewFamilyName(family) if err != nil { - return nil, errors.Wrap(err, "family name") + return nil, fmt.Errorf("family name: %w", err) } dbInfo := getDBInfo(e.DB) tables, err := dbInfo.GetAllTables(context.TODO()) if err != nil { - return nil, errors.Wrap(err, "get table names") + return nil, fmt.Errorf("get table names: %w", err) } var res []schema.Table for _, table := range tables { if table.Family == familyName.String() { ts, err := e.TableSchema(familyName.Name, table.Table) if err != nil { - return nil, errors.Wrap(err, "get table schema") + return nil, fmt.Errorf("get table schema: %w", err) } res = append(res, *ts) } @@ -57,24 +58,24 @@ func (e *dbExecutive) FamilySchemas(family string) ([]schema.Table, error) { func (e *dbExecutive) TableSchema(family, table string) (*schema.Table, error) { familyName, err := schema.NewFamilyName(family) if err != nil { - return nil, errors.Wrap(err, "family name") + return nil, fmt.Errorf("family name: %w", err) } if normalized := familyName.String(); normalized != family { - return nil, errors.Wrapf(err, "passed in family name does not match normalized family name: %q", normalized) + return nil, fmt.Errorf("passed in family name does not match normalized family name: %q: %w", normalized, err) } tableName, err := schema.NewTableName(table) if err != nil { - return nil, errors.Wrap(err, "table name") + return nil, fmt.Errorf("table name: %w", err) } if normalized := tableName.String(); normalized != table { - return nil, errors.Wrapf(err, "passed in table name does not match normalized table name: %q", normalized) + return nil, fmt.Errorf("passed in table name does not match normalized table name: %q: %w", normalized, err) } tbl, ok, err := e.fetchMetaTableByName(familyName, tableName) if err != nil { - return nil, errors.Wrap(err, "fetch meta table") + return nil, fmt.Errorf("fetch meta table: %w", err) } if !ok { - return nil, errors.Wrapf(ErrTableDoesNotExist, "%s___%s", family, table) + return nil, fmt.Errorf("%s___%s: %w", family, table, ErrTableDoesNotExist) } res := &schema.Table{ Family: familyName.Name, @@ -89,7 +90,7 @@ func (e *dbExecutive) TableSchema(family, table string) (*schema.Table, error) { case schema.FTBinary: case schema.FTByteString: default: - return nil, errors.Errorf("unsupported field type: %q", field.FieldType) + return nil, fmt.Errorf("unsupported field type: %q", field.FieldType) } res.Fields = append(res.Fields, []string{ field.Name.Name, field.FieldType.String(), @@ -166,7 +167,7 @@ func (e *dbExecutive) CreateTable(familyName string, tableName string, fieldName err = tbl.Validate() if err != nil { - return &errs.BadRequestError{err.Error()} + return &errs.BadRequestError{Err: err.Error()} } ddl, err := tbl.AsCreateTableDDL() @@ -195,7 +196,7 @@ func (e *dbExecutive) CreateTable(familyName string, tableName string, fieldName err = e.takeLedgerLock(ctx, tx) if err != nil { - return errors.Wrap(err, "take ledger lock") + return fmt.Errorf("take ledger lock: %w", err) } dlw := dmlLedgerWriter{ @@ -206,7 +207,7 @@ func (e *dbExecutive) CreateTable(familyName string, tableName string, fieldName seq, err := dlw.Add(ctx, logDDL) if err != nil { - return errors.Wrap(err, "apply dml") + return fmt.Errorf("apply dml: %w", err) } _, err = e.applyDDL(ctx, tx, ddl) @@ -215,7 +216,7 @@ func (e *dbExecutive) CreateTable(familyName string, tableName string, fieldName strings.Contains(err.Error(), "already exists") { return &errs.ConflictError{Err: "Table already exists"} } - return errors.Wrap(err, "apply ddl") + return fmt.Errorf("apply ddl: %w", err) } err = tx.Commit() @@ -232,11 +233,11 @@ func (e *dbExecutive) CreateTables(tables []schema.Table) error { for _, table := range tables { fieldNames, fieldTypes, err := schema.UnzipFieldsParam(table.Fields) if err != nil { - return errors.Wrap(err, fmt.Sprintf("unzipping fields param for family %q table %q", table.Family, table.Name)) + return fmt.Errorf("unzipping fields param for family %q table %q: %w", table.Family, table.Name, err) } err = e.CreateTable(table.Family, table.Name, fieldNames, fieldTypes, table.KeyFields) if err != nil { - return errors.Wrap(err, fmt.Sprintf("creating table for family %q table %q", table.Family, table.Name)) + return fmt.Errorf("creating table for family %q table %q: %w", table.Family, table.Name, err) } } return nil @@ -258,7 +259,7 @@ func (e *dbExecutive) applyDDL(ctx context.Context, tx *sql.Tx, ddl string) (sql // sqlite supports transactional ddl return tx.ExecContext(ctx, ddl) default: - return nil, errors.Errorf("Unknown driver: %T", e.DB.Driver()) + return nil, fmt.Errorf("Unknown driver: %T", e.DB.Driver()) } } @@ -315,7 +316,7 @@ func (e *dbExecutive) AddFields(familyName string, tableName string, fieldNames err = e.takeLedgerLock(ctx, tx) if err != nil { - return errors.Wrap(err, "take ledger lock") + return fmt.Errorf("take ledger lock: %w", err) } // We first write the column modification to the DML ledger within the transaction. @@ -326,7 +327,7 @@ func (e *dbExecutive) AddFields(familyName string, tableName string, fieldNames defer dlw.Close() seq, err := dlw.Add(ctx, logDDL) if err != nil { - return errors.Wrap(err, "add dml") + return fmt.Errorf("add dml: %w", err) } // Next, apply the DDL to the ctldb. If the DDL fails, return the err, which will @@ -337,13 +338,13 @@ func (e *dbExecutive) AddFields(familyName string, tableName string, fieldNames strings.Contains(err.Error(), "duplicate column name") { // sqlite return &errs.ConflictError{Err: "Column already exists"} } - return errors.Wrap(err, "apply ddl") + return fmt.Errorf("apply ddl: %w", err) } // if the DDL succeeds, commit the transaction err = tx.Commit() if err != nil { - return errors.Wrap(err, "commit tx") + return fmt.Errorf("commit tx: %w", err) } events.Log("Successfully created new field `%{fieldName}s %{fieldType}v` on table %{tableName}s at seq %{seq}v", fieldName, fieldType, tableName, seq) return nil @@ -404,7 +405,7 @@ func (e *dbExecutive) GetWriterCookie(writerName string, writerSecret string) ([ func (e *dbExecutive) takeLedgerLock(ctx context.Context, tx *sql.Tx) error { _, err := tx.ExecContext(ctx, "UPDATE locks SET clock = clock + 1 WHERE id = 'ledger'") if err != nil { - return errors.Wrap(err, "update locks") + return fmt.Errorf("update locks: %w", err) } return nil } @@ -476,12 +477,12 @@ func (e *dbExecutive) Mutate( tblNames := reqset.TableNames() tbls, err := e.fetchMetaTablesByName(famName, tblNames) if err != nil { - return errors.Wrap(err, "fetch meta tables error") + return fmt.Errorf("fetch meta tables error: %w", err) } for _, tblName := range tblNames { if _, ok := tbls[tblName]; !ok { - return errors.Errorf("Table not found: %s", tblName) + return fmt.Errorf("Table not found: %s", tblName) } } @@ -491,7 +492,7 @@ func (e *dbExecutive) Mutate( // dope, y'all. tx, err := e.DB.BeginTx(ctx, nil) if err != nil { - return errors.Wrap(err, "begin tx error") + return fmt.Errorf("begin tx error: %w", err) } defer tx.Rollback() @@ -512,7 +513,7 @@ func (e *dbExecutive) Mutate( // See the method documentation for more information. err = e.takeLedgerLock(ctx, tx) if err != nil { - return errors.Wrap(err, "taking ledger lock") + return fmt.Errorf("taking ledger lock: %w", err) } // Check Cookie @@ -544,7 +545,7 @@ func (e *dbExecutive) Mutate( if len(reqset.Requests) > 1 { _, err := dlw.BeginTx(ctx) if err != nil { - return errors.Wrap(err, "logging tx begin failed") + return fmt.Errorf("logging tx begin failed: %w", err) } } @@ -589,26 +590,26 @@ func (e *dbExecutive) Mutate( _, err = tx.ExecContext(ctx, dmlSQL) if err != nil { events.Log("dml exec error, Request: %{req}+v SQL: %{sql}s", req, dmlSQL) - return errors.Wrap(err, "dml exec error") + return fmt.Errorf("dml exec error: %w", err) } // Now record it in the log table lastSeq, err = dlw.Add(ctx, dmlSQL) if err != nil { - return errors.Wrap(err, "log write error") + return fmt.Errorf("log write error: %w", err) } } if len(reqset.Requests) > 1 { lastSeq, err = dlw.CommitTx(ctx) if err != nil { - return errors.Wrap(err, "logging tx commit failed") + return fmt.Errorf("logging tx commit failed: %w", err) } } err = tx.Commit() if err != nil { - return errors.Wrap(err, "commit failed") + return fmt.Errorf("commit failed: %w", err) } events.Debug( @@ -735,9 +736,6 @@ func (e *dbExecutive) fetchFamilyByName(famName schema.FamilyName) (fam dbFamily } else if err == nil { // No errors, no ErrNoRows, that means found! ok = true - } else { - // An error! - err = errors.WithStack(err) } return @@ -750,10 +748,10 @@ func (e *dbExecutive) RegisterWriter(writerName string, secret string) error { } if len(secret) < limits.LimitWriterSecretMinLength { - return errors.Errorf("Secret should be at least %d characters", limits.LimitWriterSecretMinLength) + return fmt.Errorf("Secret should be at least %d characters", limits.LimitWriterSecretMinLength) } if len(secret) > limits.LimitWriterSecretMaxLength { - return errors.Errorf("Secret can be at most %d characters", limits.LimitWriterSecretMaxLength) + return fmt.Errorf("Secret can be at most %d characters", limits.LimitWriterSecretMaxLength) } ms := mutatorStore{ @@ -859,13 +857,13 @@ func (e *dbExecutive) ReadTableSizeLimits() (res limits.TableSizeLimits, err err "FROM max_table_sizes "+ "ORDER BY family_name, table_name") if err != nil { - return res, errors.Wrap(err, "select table sizes") + return res, fmt.Errorf("select table sizes: %w", err) } defer rows.Close() for rows.Next() { var tsl limits.TableSizeLimit if err := rows.Scan(&tsl.Family, &tsl.Table, &tsl.WarnSize, &tsl.MaxSize); err != nil { - return res, errors.Wrap(err, "scan table sizes") + return res, fmt.Errorf("scan table sizes: %w", err) } res.Tables = append(res.Tables, tsl) } @@ -877,14 +875,14 @@ func (e *dbExecutive) UpdateTableSizeLimit(limit limits.TableSizeLimit) error { defer cancel() tx, err := e.DB.BeginTx(ctx, nil) if err != nil { - return errors.Wrap(err, "start tx") + return fmt.Errorf("start tx: %w", err) } defer tx.Rollback() // check first to see if the table exists ft := schema.FamilyTable{Family: limit.Family, Table: limit.Table} _, err = tx.ExecContext(ctx, "select * from "+ft.String()+" limit 1") if err != nil { - return errors.Errorf("table '%s' not found", ft) + return fmt.Errorf("table '%s' not found", ft) } // then do the upsert res, err := tx.ExecContext(ctx, "replace into max_table_sizes "+ @@ -892,17 +890,17 @@ func (e *dbExecutive) UpdateTableSizeLimit(limit limits.TableSizeLimit) error { "values (?, ?, ?, ?)", limit.Family, limit.Table, limit.WarnSize, limit.MaxSize) if err != nil { - return errors.Wrap(err, "replace into max_table_sizes") + return fmt.Errorf("replace into max_table_sizes: %w", err) } ra, err := res.RowsAffected() if err != nil { - return errors.Wrap(err, "max_table_sizes rows affected") + return fmt.Errorf("max_table_sizes rows affected: %w", err) } if ra <= 0 { return errors.New("unexpected failure -- no rows updated") } if err := tx.Commit(); err != nil { - return errors.Wrap(err, "commit tx") + return fmt.Errorf("commit tx: %w", err) } return nil } @@ -914,14 +912,14 @@ func (e *dbExecutive) DeleteTableSizeLimit(ft schema.FamilyTable) error { resp, err := e.DB.ExecContext(ctx, "delete from max_table_sizes where family_name=? and table_name=?", ft.Family, ft.Table) if err != nil { - return errors.Wrap(err, "delete from max_table_sizes") + return fmt.Errorf("delete from max_table_sizes: %w", err) } rows, err := resp.RowsAffected() if err != nil { - return errors.Wrap(err, "rows affected") + return fmt.Errorf("rows affected: %w", err) } if rows < 1 { - return errors.Errorf("could not find table limit for %s", ft) + return fmt.Errorf("could not find table limit for %s", ft) } return nil } @@ -935,14 +933,14 @@ func (e *dbExecutive) ReadWriterRateLimits() (res limits.WriterRateLimits, err e "FROM max_writer_rates "+ "ORDER BY writer_name") if err != nil { - return res, errors.Wrap(err, "select writer rates") + return res, fmt.Errorf("select writer rates: %w", err) } defer rows.Close() for rows.Next() { var wrl limits.WriterRateLimit wrl.RateLimit.Period = time.Minute if err := rows.Scan(&wrl.Writer, &wrl.RateLimit.Amount); err != nil { - return res, errors.Wrap(err, "scan writer rates") + return res, fmt.Errorf("scan writer rates: %w", err) } res.Writers = append(res.Writers, wrl) } @@ -954,7 +952,7 @@ func (e *dbExecutive) UpdateWriterRateLimit(limit limits.WriterRateLimit) error defer cancel() tx, err := e.DB.BeginTx(ctx, nil) if err != nil { - return errors.Wrap(err, "start tx") + return fmt.Errorf("start tx: %w", err) } defer tx.Rollback() @@ -963,34 +961,34 @@ func (e *dbExecutive) UpdateWriterRateLimit(limit limits.WriterRateLimit) error // computers are fast. writer, err := schema.NewWriterName(limit.Writer) if err != nil { - return errors.Wrap(err, "validate writer") + return fmt.Errorf("validate writer: %w", err) } exists, err := ms.Exists(writer) if err != nil { - return errors.Wrap(err, "check writer exists") + return fmt.Errorf("check writer exists: %w", err) } if !exists { - return errors.Errorf("no writer with the name '%s' exists", limit.Writer) + return fmt.Errorf("no writer with the name '%s' exists", limit.Writer) } adjustedAmount, err := limit.RateLimit.AdjustAmount(time.Minute) if err != nil { - return errors.Wrap(err, "check limit") + return fmt.Errorf("check limit: %w", err) } res, err := tx.ExecContext(ctx, "replace into max_writer_rates "+ "(writer_name, max_rows_per_minute) "+ "values (?, ?)", limit.Writer, adjustedAmount) if err != nil { - return errors.Wrap(err, "replace into max_writer_rates") + return fmt.Errorf("replace into max_writer_rates: %w", err) } ra, err := res.RowsAffected() if err != nil { - return errors.Wrap(err, "max_writer_rates rows affected") + return fmt.Errorf("max_writer_rates rows affected: %w", err) } if ra <= 0 { return errors.New("unexpected failure -- no rows updated") } if err := tx.Commit(); err != nil { - return errors.Wrap(err, "commit tx") + return fmt.Errorf("commit tx: %w", err) } return nil } @@ -1000,14 +998,14 @@ func (e *dbExecutive) DeleteWriterRateLimit(writerName string) error { defer cancel() res, err := e.DB.ExecContext(ctx, "delete from max_writer_rates where writer_name=?", writerName) if err != nil { - return errors.Wrap(err, "delete from max_writer_rates") + return fmt.Errorf("delete from max_writer_rates: %w", err) } ra, err := res.RowsAffected() if err != nil { - return errors.Wrap(err, "get rows affected from max_writer_rates") + return fmt.Errorf("get rows affected from max_writer_rates: %w", err) } if ra <= 0 { - return errors.Errorf("no writer limit for the writer '%s' was found", writerName) + return fmt.Errorf("no writer limit for the writer '%s' was found", writerName) } return nil } @@ -1048,13 +1046,13 @@ func (e *dbExecutive) DropTable(table schema.FamilyTable) error { tx, err := e.DB.BeginTx(ctx, nil) if err != nil { - return errors.Wrap(err, "error beginning transaction") + return fmt.Errorf("error beginning transaction: %w", err) } defer tx.Rollback() err = e.takeLedgerLock(ctx, tx) if err != nil { - return errors.Wrap(err, "take ledger lock") + return fmt.Errorf("take ledger lock: %w", err) } dlw := dmlLedgerWriter{ @@ -1065,17 +1063,17 @@ func (e *dbExecutive) DropTable(table schema.FamilyTable) error { _, err = tx.ExecContext(ctx, ddl) if err != nil { - return errors.Wrap(err, "error running drop command") + return fmt.Errorf("error running drop command: %w", err) } seq, err := dlw.Add(ctx, logDDL) if err != nil { - return errors.Wrap(err, "error inserting drop command into ledger") + return fmt.Errorf("error inserting drop command into ledger: %w", err) } err = tx.Commit() if err != nil { - return errors.Wrap(err, "error committing transaction") + return fmt.Errorf("error committing transaction: %w", err) } events.Log("Successfully dropped `%{tableName}s` at seq %{seq}v", table.String(), seq) @@ -1120,13 +1118,13 @@ func (e *dbExecutive) ClearTable(table schema.FamilyTable) error { tx, err := e.DB.BeginTx(ctx, nil) if err != nil { - return errors.Wrap(err, "error beginning transaction") + return fmt.Errorf("error beginning transaction: %w", err) } defer tx.Rollback() err = e.takeLedgerLock(ctx, tx) if err != nil { - return errors.Wrap(err, "take ledger lock") + return fmt.Errorf("take ledger lock: %w", err) } dlw := dmlLedgerWriter{ @@ -1137,17 +1135,17 @@ func (e *dbExecutive) ClearTable(table schema.FamilyTable) error { _, err = tx.ExecContext(ctx, ddl) if err != nil { - return errors.Wrap(err, "error running delete command") + return fmt.Errorf("error running delete command: %w", err) } seq, err := dlw.Add(ctx, logDDL) if err != nil { - return errors.Wrap(err, "error inserting delete command into ledger") + return fmt.Errorf("error inserting delete command into ledger: %w", err) } err = tx.Commit() if err != nil { - return errors.Wrap(err, "error committing transaction") + return fmt.Errorf("error committing transaction: %w", err) } events.Log("Successfully deleted all rows from `%{tableName}s` at seq %{seq}v", table.String(), seq) @@ -1162,13 +1160,13 @@ func (e *dbExecutive) ReadFamilyTableNames(family schema.FamilyName) (tables []s events.Debug("reading family table names where f=%s", family) rows, err := e.DB.QueryContext(ctx, fmt.Sprintf(`select table_name from information_schema.tables where table_name like '%s___%%'`, family.String())) if err != nil { - return nil, errors.Wrap(err, "error reading family table names") + return nil, fmt.Errorf("error reading family table names: %w", err) } defer rows.Close() for rows.Next() { var fullTableName string if err := rows.Scan(&fullTableName); err != nil { - return nil, errors.Wrap(err, "error reading family table names") + return nil, fmt.Errorf("error reading family table names: %w", err) } prefix := family.String() + "___" table := strings.TrimPrefix(fullTableName, prefix) @@ -1189,11 +1187,11 @@ func (e *dbExecutive) ReadFamilyTableNames(family schema.FamilyName) (tables []s func sanitizeFamilyAndTableNames(family string, table string) (string, string, error) { sanFamily, err := schema.NewFamilyName(family) if err != nil { - return "", "", errors.Wrap(err, "sanitize family") + return "", "", fmt.Errorf("sanitize family: %w", err) } sanTable, err := schema.NewTableName(table) if err != nil { - return "", "", errors.Wrap(err, "sanitize table") + return "", "", fmt.Errorf("sanitize table: %w", err) } return sanFamily.Name, sanTable.Name, nil } diff --git a/pkg/executive/db_executive_test.go b/pkg/executive/db_executive_test.go index 0616267e..2f5df816 100644 --- a/pkg/executive/db_executive_test.go +++ b/pkg/executive/db_executive_test.go @@ -1,18 +1,9 @@ -/* - * - * IMPORTANT: All of the tests for dbExecutive are called from the - * TestAllDBExecutive() function, which runs tests thru both the - * SQLite and the MySQL code paths. Use lowercase t in your test - * function name and add it to the map in TestAllDBExecutive to - * get it to run thru both. - * - */ - package executive import ( "context" "database/sql" + "errors" "fmt" "path/filepath" "sort" @@ -22,7 +13,7 @@ import ( "time" "github.com/google/go-cmp/cmp" - "github.com/pkg/errors" + "github.com/segmentio/ctlstore/pkg/ctldb" "github.com/segmentio/ctlstore/pkg/errs" "github.com/segmentio/ctlstore/pkg/limits" @@ -34,8 +25,6 @@ import ( "github.com/stretchr/testify/require" ) -type dbExecTestFn func(*testing.T, string) - const ( // Schema executed to initialize the test database. testCtlDBSchemaUpForMySQL = ` @@ -113,41 +102,6 @@ var ( testDefaultWriterLimit = limits.RateLimit{Amount: 1000, Period: time.Minute} ) -func TestAllDBExecutive(t *testing.T) { - dbTypes := []string{"mysql", "sqlite3"} - testFns := map[string]dbExecTestFn{ - "testDBExecutiveCreateFamily": testDBExecutiveCreateFamily, - "testDBExecutiveCreateTable": testDBExecutiveCreateTable, - "testDBExecutiveCreateTables": testDBExecutiveCreateTables, - "testDBExecutiveCreateTableLocksLedger": testDBExecutiveCreateTableLocksLedger, - "testDBExecutiveAddFields": testDBExecutiveAddFields, - "testDBExecutiveAddFieldsLocksLedger": testDBExecutiveAddFieldsLocksLedger, - "testDBExecutiveFetchFamilyByName": testDBExecutiveFetchFamilyByName, - "testDBExecutiveMutate": testDBExecutiveMutate, - "testDBExecutiveGetWriterCookie": testDBExecutiveGetWriterCookie, - "testDBExecutiveSetWriterCookie": testDBExecutiveSetWriterCookie, - "testFetchMetaTableByName": testFetchMetaTableByName, - "testDBExecutiveRegisterWriter": testDBExecutiveRegisterWriter, - "testDBExecutiveReadRow": testDBExecutiveReadRow, - "testDBLimiter": testDBLimiter, - "testDBExecutiveWriterRates": testDBExecutiveWriterRates, - "testDBExecutiveTableLimits": testDBExecutiveTableLimits, - "testDBExecutiveClearTable": testDBExecutiveClearTable, - "testDBExecutiveDropTable": testDBExecutiveDropTable, - "testDBExecutiveReadFamilyTableNames": testDBExecutiveReadFamilyTableNames, - "testDBExecutiveTableSchema": testDBExecutiveTableSchema, - "testDBExecutiveFamilySchemas": testDBExecutiveFamilySchemas, - } - - for _, dbType := range dbTypes { - for testName, testFn := range testFns { - t.Run(testName+"_"+dbType, func(t *testing.T) { - testFn(t, dbType) - }) - } - } -} - func newCtlDBTestConnection(t *testing.T, dbType string) (*sql.DB, func()) { var ( teardowns utils.Teardowns @@ -195,7 +149,9 @@ func newCtlDBTestConnection(t *testing.T, dbType string) (*sql.DB, func()) { schemaUp = testCtlDBSchemaUpForSQLite3 tmpDir, td := tests.WithTmpDir(t) teardowns.Add(td) - db, err = sql.Open("sqlite3", filepath.Join(tmpDir, "ctldb.db")) + ldbpath := filepath.Join(tmpDir, "ctldb.db") + t.Logf("LDB path: %s", ldbpath) + db, err = sql.Open("sqlite3", ldbpath) default: t.Fatalf("unknown dbtype %q", dbType) } @@ -253,49 +209,64 @@ func newDbExecTestUtil(t *testing.T, dbType string) *dbExecTestUtil { } } -func testDBExecutiveCreateFamily(t *testing.T, dbType string) { - u := newDbExecTestUtil(t, dbType) - err := u.e.CreateFamily("family2") - if err != nil { - t.Errorf("Unexpected error calling CreateFamily: %+v", err) +func withDBTypes(t *testing.T, fn func(dbType string)) { + dbTypes := []string{} + dbTypes = append(dbTypes, "mysql") + dbTypes = append(dbTypes, "sqlite3") + for _, dbType := range dbTypes { + t.Run(dbType, func(t *testing.T) { + fn(dbType) + }) } +} - row := u.db.QueryRow("SELECT COUNT(*) FROM families WHERE name = 'family2'") - var cnt sql.NullInt64 - err = row.Scan(&cnt) - if err != nil { - t.Fatalf("Unexpected error scanning result: %v", err) - } +func TestDBExecutiveCreateFamily(t *testing.T) { + withDBTypes(t, func(dbType string) { + u := newDbExecTestUtil(t, dbType) + err := u.e.CreateFamily("family2") + if err != nil { + t.Errorf("Unexpected error calling CreateFamily: %+v", err) + } - if want, got := 1, cnt; !got.Valid || int(got.Int64) != want { - t.Errorf("Expected %v rows, got %v", want, got) - } + row := u.db.QueryRow("SELECT COUNT(*) FROM families WHERE name = 'family2'") + var cnt sql.NullInt64 + err = row.Scan(&cnt) + if err != nil { + t.Fatalf("Unexpected error scanning result: %v", err) + } - err = u.e.CreateFamily("family2") - if err != nil && err.Error() != "Family already exists" { - t.Errorf("Unexpected error %v", err) - } + if want, got := 1, cnt; !got.Valid || int(got.Int64) != want { + t.Errorf("Expected %v rows, got %v", want, got) + } + + err = u.e.CreateFamily("family2") + if err != nil && err.Error() != "Family already exists" { + t.Errorf("Unexpected error %v", err) + } + }) } -func testDBExecutiveRegisterWriter(t *testing.T, dbType string) { - u := newDbExecTestUtil(t, dbType) - ms := mutatorStore{DB: u.db, Ctx: u.ctx, TableName: mutatorsTableName} +func TestDBExecutiveRegisterWriter(t *testing.T) { + withDBTypes(t, func(dbType string) { + u := newDbExecTestUtil(t, dbType) + ms := mutatorStore{DB: u.db, Ctx: u.ctx, TableName: mutatorsTableName} - // first ensure that register writer succeeds - err := u.e.RegisterWriter("writerTest", "secret1") - require.NoError(t, err) + // first ensure that register writer succeeds + err := u.e.RegisterWriter("writerTest", "secret1") + require.NoError(t, err) - _, found, err := ms.Get(schema.WriterName{Name: "writerTest"}, "secret1") - require.NoError(t, err) - require.True(t, found) + _, found, err := ms.Get(schema.WriterName{Name: "writerTest"}, "secret1") + require.NoError(t, err) + require.True(t, found) - // try to register again with the same credentials - err = u.e.RegisterWriter("writerTest", "secret1") - require.NoError(t, err) + // try to register again with the same credentials + err = u.e.RegisterWriter("writerTest", "secret1") + require.NoError(t, err) - // register the same writer but with a different credential - err = u.e.RegisterWriter("writerTest", "some new secret") - require.Equal(t, err, ErrWriterAlreadyExists) + // register the same writer but with a different credential + err = u.e.RegisterWriter("writerTest", "some new secret") + require.Equal(t, err, ErrWriterAlreadyExists) + }) } func queryDMLTable(t *testing.T, db *sql.DB, limit int) []string { @@ -316,53 +287,109 @@ func queryDMLTable(t *testing.T, db *sql.DB, limit int) []string { return statements } -func testDBExecutiveFamilySchemas(t *testing.T, dbType string) { - u := newDbExecTestUtil(t, dbType) - defer u.Close() - err := u.e.CreateFamily("schematest2") - require.NoError(t, err) - err = u.e.CreateTable("schematest2", - "table1", - []string{"field1", "field2", "field3", "field4", "field5", "field6"}, - []schema.FieldType{ - schema.FTString, - schema.FTInteger, - schema.FTDecimal, - schema.FTText, - schema.FTBinary, - schema.FTByteString, - }, - []string{"field1", "field2", "field6"}, - ) - require.NoError(t, err) - err = u.e.CreateTable("schematest2", - "table2", - []string{"field1", "field2"}, - []schema.FieldType{ - schema.FTInteger, - schema.FTBinary, - }, - []string{"field1"}, - ) - require.NoError(t, err) +func TestDBExecutiveFamilySchemas(t *testing.T) { + withDBTypes(t, func(dbType string) { + u := newDbExecTestUtil(t, dbType) + defer u.Close() + err := u.e.CreateFamily("schematest2") + require.NoError(t, err) + err = u.e.CreateTable("schematest2", + "table1", + []string{"field1", "field2", "field3", "field4", "field5", "field6"}, + []schema.FieldType{ + schema.FTString, + schema.FTInteger, + schema.FTDecimal, + schema.FTText, + schema.FTBinary, + schema.FTByteString, + }, + []string{"field1", "field2", "field6"}, + ) + require.NoError(t, err) + err = u.e.CreateTable("schematest2", + "table2", + []string{"field1", "field2"}, + []schema.FieldType{ + schema.FTInteger, + schema.FTBinary, + }, + []string{"field1"}, + ) + require.NoError(t, err) - err = u.e.CreateFamily("schematest_other") - require.NoError(t, err) - err = u.e.CreateTable("schematest_other", - "table3", - []string{"field1"}, - []schema.FieldType{ - schema.FTInteger, - }, - []string{"field1"}, - ) - require.NoError(t, err) + err = u.e.CreateFamily("schematest_other") + require.NoError(t, err) + err = u.e.CreateTable("schematest_other", + "table3", + []string{"field1"}, + []schema.FieldType{ + schema.FTInteger, + }, + []string{"field1"}, + ) + require.NoError(t, err) - schemas, err := u.e.FamilySchemas("schematest2") - require.NoError(t, err) - expected := []schema.Table{ - { - Family: "schematest2", + schemas, err := u.e.FamilySchemas("schematest2") + require.NoError(t, err) + expected := []schema.Table{ + { + Family: "schematest2", + Name: "table1", + Fields: [][]string{ + {"field1", "string"}, + {"field2", "integer"}, + {"field3", "decimal"}, + {"field4", "text"}, + {"field5", "binary"}, + {"field6", "bytestring"}, + }, + KeyFields: []string{ + "field1", + "field2", + "field6", + }, + }, + { + Family: "schematest2", + Name: "table2", + Fields: [][]string{ + {"field1", "integer"}, + {"field2", "binary"}, + }, + KeyFields: []string{ + "field1", + }, + }, + } + require.EqualValues(t, expected, schemas) + }) +} + +func TestDBExecutiveTableSchema(t *testing.T) { + withDBTypes(t, func(dbType string) { + u := newDbExecTestUtil(t, dbType) + defer u.Close() + err := u.e.CreateFamily("schematest1") + require.NoError(t, err) + err = u.e.CreateTable("schematest1", + "table1", + []string{"field1", "field2", "field3", "field4", "field5", "field6"}, + []schema.FieldType{ + schema.FTString, + schema.FTInteger, + schema.FTDecimal, + schema.FTText, + schema.FTBinary, + schema.FTByteString, + }, + []string{"field1", "field2", "field6"}, + ) + require.NoError(t, err) + tableSchema, err := u.e.TableSchema("schematest1", "table1") + require.NoError(t, err) + expected := &schema.Table{ + Family: "schematest1", Name: "table1", Fields: [][]string{ {"field1", "string"}, @@ -377,1491 +404,1528 @@ func testDBExecutiveFamilySchemas(t *testing.T, dbType string) { "field2", "field6", }, - }, - { - Family: "schematest2", - Name: "table2", - Fields: [][]string{ - {"field1", "integer"}, - {"field2", "binary"}, - }, - KeyFields: []string{ - "field1", - }, - }, - } - require.EqualValues(t, expected, schemas) + } + require.EqualValues(t, expected, tableSchema) + }) } -func testDBExecutiveTableSchema(t *testing.T, dbType string) { - u := newDbExecTestUtil(t, dbType) - defer u.Close() - err := u.e.CreateFamily("schematest1") - require.NoError(t, err) - err = u.e.CreateTable("schematest1", - "table1", - []string{"field1", "field2", "field3", "field4", "field5", "field6"}, - []schema.FieldType{ - schema.FTString, - schema.FTInteger, - schema.FTDecimal, - schema.FTText, - schema.FTBinary, - schema.FTByteString, - }, - []string{"field1", "field2", "field6"}, - ) - require.NoError(t, err) - tableSchema, err := u.e.TableSchema("schematest1", "table1") - require.NoError(t, err) - expected := &schema.Table{ - Family: "schematest1", - Name: "table1", - Fields: [][]string{ - {"field1", "string"}, - {"field2", "integer"}, - {"field3", "decimal"}, - {"field4", "text"}, - {"field5", "binary"}, - {"field6", "bytestring"}, - }, - KeyFields: []string{ - "field1", - "field2", - "field6", - }, - } - require.EqualValues(t, expected, tableSchema) +func TestSimpleLockLedger(t *testing.T) { + withDBTypes(t, func(dbType string) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + u := newDbExecTestUtil(t, dbType) + defer u.Close() + err := u.e.CreateTable("family1", + "table2", + []string{"field1"}, + []schema.FieldType{schema.FTString}, + []string{"field1"}, + ) + require.NoError(t, err) + + const numGoroutines = 2 + errs := make(chan error, numGoroutines) + for i := 0; i < numGoroutines; i++ { + go func() { + errs <- func() error { + tx, err := u.e.DB.BeginTx(ctx, nil) + if err != nil { + return err + } + defer tx.Rollback() + err = u.e.takeLedgerLock(ctx, tx) + if err != nil { + return err + } + err = tx.Commit() + return err + }() + }() + } + for i := 0; i < numGoroutines; i++ { + err := <-errs + require.NoError(t, err) + } + + }) } // multiple goroutine will attempt to add a number of fields to the same // table concurrently. this test verifies that the ledger sequences do not // skip from the perspective of a reader repeatedly querying the dml ledger // table. -func testDBExecutiveAddFieldsLocksLedger(t *testing.T, dbType string) { - u := newDbExecTestUtil(t, dbType) - defer u.Close() - - err := u.e.CreateTable("family1", - "table2", - []string{"field1"}, - []schema.FieldType{schema.FTString}, - []string{"field1"}, - ) - require.NoError(t, err) - - var numGoroutines = 10 - const numFields = 5 - errs := make(chan error, numGoroutines+1) - for i := 0; i < numGoroutines; i++ { - prefix := fmt.Sprintf("prefix_%d", i) - go func(prefix string) { +func TestDBExecutiveAddFieldsLocksLedger(t *testing.T) { + withDBTypes(t, func(dbType string) { + u := newDbExecTestUtil(t, dbType) + defer u.Close() + err := u.e.CreateTable("family1", + "table2", + []string{"field1"}, + []schema.FieldType{schema.FTString}, + []string{"field1"}, + ) + require.NoError(t, err) + var numGoroutines = 10 + const numFields = 5 + errs := make(chan error, numGoroutines+1) + for i := 0; i < numGoroutines; i++ { + prefix := fmt.Sprintf("prefix_%d", i) + go func(prefix string) { + err := func() error { + var fieldNames []string + var fieldTypes []schema.FieldType + for i := 0; i < numFields; i++ { + fieldNames = append(fieldNames, fmt.Sprintf("%s_field_%d", prefix, i)) + fieldTypes = append(fieldTypes, schema.FTText) + } + return u.e.AddFields("family1", "table2", fieldNames, fieldTypes) + }() + errs <- err + }(prefix) + } + // fetch dml repeatedly, detecting gaps in the ledger. + go func() { err := func() error { - var fieldNames []string - var fieldTypes []schema.FieldType - for i := 0; i < numFields; i++ { - fieldNames = append(fieldNames, fmt.Sprintf("%s_field_%d", prefix, i)) - fieldTypes = append(fieldTypes, schema.FTText) - } - return u.e.AddFields("family1", "table2", fieldNames, fieldTypes) - }() - errs <- err - }(prefix) - } - // fetch dml repeatedly, detecting gaps in the ledger. - go func() { - err := func() error { - lastSeq := int64(-1) - for { - if lastSeq == int64(numGoroutines*numFields+1) { - // yay we're done - return nil - } - sql := "SELECT seq FROM ctlstore_dml_ledger WHERE seq > ? ORDER BY seq LIMIT 10" - rows, err := u.db.QueryContext(u.ctx, sql, lastSeq) - if err != nil { - return errors.Wrap(err, "fetch") - } - for rows.Next() { - var seq int64 - err := rows.Scan(&seq) + lastSeq := int64(-1) + for { + if lastSeq == int64(numGoroutines*numFields+1) { + // yay we're done + return nil + } + sql := "SELECT seq FROM ctlstore_dml_ledger WHERE seq > ? ORDER BY seq LIMIT 10" + rows, err := u.db.QueryContext(u.ctx, sql, lastSeq) if err != nil { - return errors.Wrap(err, "scan") + return fmt.Errorf("fetch: %w", err) } - if lastSeq == -1 { - if seq != 1 { - return fmt.Errorf("first sequence was %d", seq) + for rows.Next() { + var seq int64 + err := rows.Scan(&seq) + if err != nil { + return fmt.Errorf("scan: %w", err) } - } else { - if seq != lastSeq+1 { - return fmt.Errorf("detected gap seq=%d lastSeq=%d", seq, lastSeq) + if lastSeq == -1 { + if seq != 1 { + return fmt.Errorf("first sequence was %d", seq) + } + } else { + if seq != lastSeq+1 { + return fmt.Errorf("detected gap seq=%d lastSeq=%d", seq, lastSeq) + } } + lastSeq = seq } - lastSeq = seq - } - err = rows.Err() - if err != nil { - return err + err = rows.Err() + if err != nil { + return err + } + time.Sleep(10 * time.Millisecond) } - time.Sleep(10 * time.Millisecond) + }() + if err != nil { + err = fmt.Errorf("reader: %w", err) } + errs <- err }() - errs <- errors.Wrap(err, "reader") - }() - for i := 0; i < numGoroutines+1; i++ { - err := <-errs - require.NoError(t, err) - } + for i := 0; i < numGoroutines+1; i++ { + err := <-errs + require.NoError(t, err) + } + }) } -func testDBExecutiveAddFields(t *testing.T, dbType string) { - u := newDbExecTestUtil(t, dbType) - defer u.Close() +func TestDBExecutiveAddFields(t *testing.T) { + withDBTypes(t, func(dbType string) { + u := newDbExecTestUtil(t, dbType) + defer u.Close() + + addFields := func() error { + return u.e.AddFields("family1", + "table2", + []string{"field7", "field8", "field9", "field10", "field11", "field12"}, + []schema.FieldType{schema.FTString, schema.FTInteger, schema.FTByteString, schema.FTDecimal, schema.FTText, schema.FTBinary}, + ) + } + + // first verify that we cannot add to the table if it does not already exist. + err := addFields() + require.Error(t, err) + // also verify that no DML exists + dmls := queryDMLTable(t, u.db, -1) + require.Empty(t, dmls) - addFields := func() error { - return u.e.AddFields("family1", + err = u.e.CreateTable("family1", "table2", - []string{"field7", "field8", "field9", "field10", "field11", "field12"}, + []string{"field1", "field2", "field3", "field4", "field5", "field6"}, []schema.FieldType{schema.FTString, schema.FTInteger, schema.FTByteString, schema.FTDecimal, schema.FTText, schema.FTBinary}, + []string{"field1", "field2", "field3"}, ) - } - - // first verify that we cannot add to the table if it does not already exist. - err := addFields() - require.Error(t, err) - // also verify that no DML exists - dmls := queryDMLTable(t, u.db, -1) - require.Empty(t, dmls) - - err = u.e.CreateTable("family1", - "table2", - []string{"field1", "field2", "field3", "field4", "field5", "field6"}, - []schema.FieldType{schema.FTString, schema.FTInteger, schema.FTByteString, schema.FTDecimal, schema.FTText, schema.FTBinary}, - []string{"field1", "field2", "field3"}, - ) - if err != nil { - t.Fatalf("Unexpected error calling CreateTable: %+v", err) - } + if err != nil { + t.Fatalf("Unexpected error calling CreateTable: %+v", err) + } - err = addFields() - if err != nil { - t.Fatalf("Unexpected error calling UpdateTable: %+v", err) - } + err = addFields() + if err != nil { + t.Fatalf("Unexpected error calling UpdateTable: %+v", err) + } - // ensure that the table was modified correctly in the ctldb + // ensure that the table was modified correctly in the ctldb - res, err := u.db.Exec(`INSERT into family1___table2 + res, err := u.db.Exec(`INSERT into family1___table2 (field1,field2,field3,field4,field5,field6,field7,field8,field9,field10,field11,field12) VALUES ('1',2,'3',4.1,'5',x'6a','7',8,'9',10.1,'11',x'12') `) - if err != nil { - t.Fatal(err) - } - rows, err := res.RowsAffected() - if err != nil { - t.Fatal(err) - } - if rows != int64(1) { - t.Fatal(rows) - } + if err != nil { + t.Fatal(err) + } + rows, err := res.RowsAffected() + if err != nil { + t.Fatal(err) + } + if rows != int64(1) { + t.Fatal(rows) + } - // ensure that the DML was added to the ledger - statements := queryDMLTable(t, u.db, 6) - require.EqualValues(t, []string{ - "ALTER TABLE family1___table2 ADD COLUMN \"field12\" BLOB", - "ALTER TABLE family1___table2 ADD COLUMN \"field11\" TEXT", - "ALTER TABLE family1___table2 ADD COLUMN \"field10\" REAL", - "ALTER TABLE family1___table2 ADD COLUMN \"field9\" BLOB(255)", - "ALTER TABLE family1___table2 ADD COLUMN \"field8\" INTEGER", - "ALTER TABLE family1___table2 ADD COLUMN \"field7\" VARCHAR(191)", - }, statements) - - err = u.e.AddFields("family1", - "table2", - []string{"field7", "field8", "field9", "field10", "field11", "field12"}, - []schema.FieldType{schema.FTString, schema.FTInteger, schema.FTByteString, schema.FTDecimal, schema.FTText, schema.FTBinary}, - ) - if err == nil || !strings.Contains(err.Error(), "Column already exists") { - t.Fatalf("Unexpected error calling UpdateTable: %+v", err) - } + // ensure that the DML was added to the ledger + statements := queryDMLTable(t, u.db, 6) + require.EqualValues(t, []string{ + "ALTER TABLE family1___table2 ADD COLUMN \"field12\" BLOB", + "ALTER TABLE family1___table2 ADD COLUMN \"field11\" TEXT", + "ALTER TABLE family1___table2 ADD COLUMN \"field10\" REAL", + "ALTER TABLE family1___table2 ADD COLUMN \"field9\" BLOB(255)", + "ALTER TABLE family1___table2 ADD COLUMN \"field8\" INTEGER", + "ALTER TABLE family1___table2 ADD COLUMN \"field7\" VARCHAR(191)", + }, statements) + + err = u.e.AddFields("family1", + "table2", + []string{"field7", "field8", "field9", "field10", "field11", "field12"}, + []schema.FieldType{schema.FTString, schema.FTInteger, schema.FTByteString, schema.FTDecimal, schema.FTText, schema.FTBinary}, + ) + if err == nil || !strings.Contains(err.Error(), "Column already exists") { + t.Fatalf("Unexpected error calling UpdateTable: %+v", err) + } + }) } // multiple goroutine will attempt to create a number of tables in the same // DB concurrently. this test verifies that the ledger sequences do not // skip from the perspective of a reader repeatedly querying the dml ledger // table. -func testDBExecutiveCreateTableLocksLedger(t *testing.T, dbType string) { - u := newDbExecTestUtil(t, dbType) - defer u.Close() - - err := u.e.CreateTable("family1", - "table2", - []string{"field1"}, - []schema.FieldType{schema.FTString}, - []string{"field1"}, - ) - require.NoError(t, err) +func TestDBExecutiveCreateTableLocksLedger(t *testing.T) { + withDBTypes(t, func(dbType string) { + u := newDbExecTestUtil(t, dbType) + defer u.Close() - var numGoroutines = 10 - const numTables = 5 - errs := make(chan error, numGoroutines+1) - for i := 0; i < numGoroutines; i++ { - prefix := fmt.Sprintf("prefix_%d", i) - go func(prefix string) { - err := func() error { - for i := 0; i < numTables; i++ { - err := u.e.CreateTable("family1", - fmt.Sprintf("%s_table_%d", prefix, i), - []string{"field1"}, - []schema.FieldType{schema.FTString}, - []string{"field1"}, - ) - if err != nil { - return err + err := u.e.CreateTable("family1", + "table2", + []string{"field1"}, + []schema.FieldType{schema.FTString}, + []string{"field1"}, + ) + require.NoError(t, err) + + var numGoroutines = 10 + const numTables = 5 + errs := make(chan error, numGoroutines+1) + for i := 0; i < numGoroutines; i++ { + prefix := fmt.Sprintf("prefix_%d", i) + go func(prefix string) { + err := func() error { + for i := 0; i < numTables; i++ { + err := u.e.CreateTable("family1", + fmt.Sprintf("%s_table_%d", prefix, i), + []string{"field1"}, + []schema.FieldType{schema.FTString}, + []string{"field1"}, + ) + if err != nil { + return err + } } - } - return nil - }() - errs <- err - }(prefix) - } - // fetch dml repeatedly, detecting gaps in the ledger. - go func() { - err := func() error { - lastSeq := int64(-1) - for { - if lastSeq == int64(numGoroutines*numTables+1) { - // yay we're done return nil - } - sql := "SELECT seq FROM ctlstore_dml_ledger WHERE seq > ? ORDER BY seq LIMIT 10" - rows, err := u.db.QueryContext(u.ctx, sql, lastSeq) - if err != nil { - return errors.Wrap(err, "fetch") - } - for rows.Next() { - var seq int64 - err := rows.Scan(&seq) + }() + errs <- err + }(prefix) + } + // fetch dml repeatedly, detecting gaps in the ledger. + go func() { + err := func() error { + lastSeq := int64(-1) + for { + if lastSeq == int64(numGoroutines*numTables+1) { + // yay we're done + return nil + } + sql := "SELECT seq FROM ctlstore_dml_ledger WHERE seq > ? ORDER BY seq LIMIT 10" + rows, err := u.db.QueryContext(u.ctx, sql, lastSeq) if err != nil { - return errors.Wrap(err, "scan") + return fmt.Errorf("fetch: %w", err) } - if lastSeq == -1 { - if seq != 1 { - return fmt.Errorf("first sequence was %d", seq) + for rows.Next() { + var seq int64 + err := rows.Scan(&seq) + if err != nil { + return fmt.Errorf("scan: %w", err) } - } else { - if seq != lastSeq+1 { - return fmt.Errorf("detected gap seq=%d lastSeq=%d", seq, lastSeq) + if lastSeq == -1 { + if seq != 1 { + return fmt.Errorf("first sequence was %d", seq) + } + } else { + if seq != lastSeq+1 { + return fmt.Errorf("detected gap seq=%d lastSeq=%d", seq, lastSeq) + } } + lastSeq = seq + t.Logf("seq: %d", lastSeq) } - lastSeq = seq - t.Logf("seq: %d", lastSeq) - } - err = rows.Err() - if err != nil { - return err + err = rows.Err() + if err != nil { + return err + } + time.Sleep(10 * time.Millisecond) } - time.Sleep(10 * time.Millisecond) + }() + if err != nil { + err = fmt.Errorf("reader: %w", err) } + errs <- err }() - errs <- errors.Wrap(err, "reader") - }() - for i := 0; i < numGoroutines+1; i++ { - err := <-errs - require.NoError(t, err) - } + for i := 0; i < numGoroutines+1; i++ { + err := <-errs + require.NoError(t, err) + } + }) } -func testDBExecutiveCreateTable(t *testing.T, dbType string) { - u := newDbExecTestUtil(t, dbType) - defer u.Close() +func TestDBExecutiveCreateTable(t *testing.T) { + withDBTypes(t, func(dbType string) { - createTable := func() error { - return u.e.CreateTable("family1", - "table2", - []string{"field1", "field2", "field3", "field4", "field5", "field6"}, - []schema.FieldType{schema.FTString, schema.FTInteger, schema.FTByteString, schema.FTDecimal, schema.FTText, schema.FTBinary}, - []string{"field1", "field2", "field3"}, - ) - } - err := createTable() - require.NoError(t, err) - dmls := queryDMLTable(t, u.db, -1) - require.Len(t, dmls, 1) // one DML should exist to create the table + u := newDbExecTestUtil(t, dbType) + defer u.Close() + + createTable := func() error { + return u.e.CreateTable("family1", + "table2", + []string{"field1", "field2", "field3", "field4", "field5", "field6"}, + []schema.FieldType{schema.FTString, schema.FTInteger, schema.FTByteString, schema.FTDecimal, schema.FTText, schema.FTBinary}, + []string{"field1", "field2", "field3"}, + ) + } + err := createTable() + require.NoError(t, err) + dmls := queryDMLTable(t, u.db, -1) + require.Len(t, dmls, 1) // one DML should exist to create the table - // try to create the table again, verify it fails, and verify that the ledger is correct - err = createTable() - require.Error(t, err) - dmls = queryDMLTable(t, u.db, -1) - require.Len(t, dmls, 1) // there should still only be one DML + // try to create the table again, verify it fails, and verify that the ledger is correct + err = createTable() + require.Error(t, err) + dmls = queryDMLTable(t, u.db, -1) + require.Len(t, dmls, 1) // there should still only be one DML - // Just check that an empty table exists at all, because the field - // creation logic gets checked by sqlgen unit tests - row := u.db.QueryRow("SELECT COUNT(*) FROM family1___table2") + // Just check that an empty table exists at all, because the field + // creation logic gets checked by sqlgen unit tests + row := u.db.QueryRow("SELECT COUNT(*) FROM family1___table2") - var cnt sql.NullInt64 - err = row.Scan(&cnt) - if err != nil { - t.Fatalf("Unexpected error scanning result: %+v", err) - } + var cnt sql.NullInt64 + err = row.Scan(&cnt) + if err != nil { + t.Fatalf("Unexpected error scanning result: %+v", err) + } - if want, got := 0, cnt; !got.Valid || int(got.Int64) != want { - t.Errorf("Expected %+v, got %+v", want, got) - } + if want, got := 0, cnt; !got.Valid || int(got.Int64) != want { + t.Errorf("Expected %+v, got %+v", want, got) + } - logRow := u.db.QueryRow("SELECT statement FROM " + dmlLedgerTableName) - var rowStatement string - err = logRow.Scan(&rowStatement) - if err != nil { - t.Fatalf("Unexpected error: %+v", err) - } + logRow := u.db.QueryRow("SELECT statement FROM " + dmlLedgerTableName) + var rowStatement string + err = logRow.Scan(&rowStatement) + if err != nil { + t.Fatalf("Unexpected error: %+v", err) + } - indexOfCreate := strings.Index(rowStatement, "CREATE TABLE family1___table2") - if want, got := 0, indexOfCreate; want != got { - t.Errorf("Expected %+v, got %+v", want, got) - } + indexOfCreate := strings.Index(rowStatement, "CREATE TABLE family1___table2") + if want, got := 0, indexOfCreate; want != got { + t.Errorf("Expected %+v, got %+v", want, got) + } - err = u.e.CreateTable("family1", - "table2", - []string{"field1", "field2", "field3"}, - []schema.FieldType{schema.FTString, schema.FTInteger, schema.FTDecimal}, - []string{"field1"}, - ) - if err == nil || err.Error() != "Table already exists" { - t.Errorf("Unexpected error calling CreateTable: %+v", err) - } + err = u.e.CreateTable("family1", + "table2", + []string{"field1", "field2", "field3"}, + []schema.FieldType{schema.FTString, schema.FTInteger, schema.FTDecimal}, + []string{"field1"}, + ) + if err == nil || err.Error() != "Table already exists" { + t.Errorf("Unexpected error calling CreateTable: %+v", err) + } - err = u.e.CreateTable("family1", - "table3", - []string{"field1", "field2"}, - []schema.FieldType{schema.FTString, schema.FTInteger}, - []string{"field3"}) - if err == nil || err.Error() != "Primary key field 'field3' not specified as a field" { - t.Errorf("Unexpected error calling CreateTable: %+v", err) - } + err = u.e.CreateTable("family1", + "table3", + []string{"field1", "field2"}, + []schema.FieldType{schema.FTString, schema.FTInteger}, + []string{"field3"}) + if err == nil || err.Error() != "Primary key field 'field3' not specified as a field" { + t.Errorf("Unexpected error calling CreateTable: %+v", err) + } - err = u.e.CreateTable("family1", - "table4", - []string{"field1", "field2"}, - []schema.FieldType{schema.FTString, schema.FTDecimal}, - []string{"field2"}) - if err == nil || err.Error() != "Fields of type 'decimal' cannot be a key field" { - t.Errorf("Unexpected error calling CreateTable: %+v", err) - } + err = u.e.CreateTable("family1", + "table4", + []string{"field1", "field2"}, + []schema.FieldType{schema.FTString, schema.FTDecimal}, + []string{"field2"}) + if err == nil || err.Error() != "Fields of type 'decimal' cannot be a key field" { + t.Errorf("Unexpected error calling CreateTable: %+v", err) + } - err = u.e.CreateTable("family1", - "table4", - []string{"field1", "field2"}, - []schema.FieldType{schema.FTString, schema.FTDecimal}, - []string{}) - if err == nil || err.Error() != "table must have at least one key field" { - t.Errorf("Unexpected error calling CreateTable: %+v", err) - } + err = u.e.CreateTable("family1", + "table4", + []string{"field1", "field2"}, + []schema.FieldType{schema.FTString, schema.FTDecimal}, + []string{}) + if err == nil || err.Error() != "table must have at least one key field" { + t.Errorf("Unexpected error calling CreateTable: %+v", err) + } + }) } -func testDBExecutiveCreateTables(t *testing.T, dbType string) { - u := newDbExecTestUtil(t, dbType) - defer u.Close() +func TestDBExecutiveCreateTables(t *testing.T) { + withDBTypes(t, func(dbType string) { - err := u.e.CreateFamily("foofamily") - if err != nil { - t.Errorf("Unexpected error calling CreateFamily: %+v", err) - } + u := newDbExecTestUtil(t, dbType) + defer u.Close() + + err := u.e.CreateFamily("foofamily") + if err != nil { + t.Errorf("Unexpected error calling CreateFamily: %+v", err) + } - createTables := func() error { - return u.e.CreateTables( - []schema.Table{ - { - Family: "foofamily", - Name: "bartable", - Fields: [][]string{ - {"field1", "string"}, + createTables := func() error { + return u.e.CreateTables( + []schema.Table{ + { + Family: "foofamily", + Name: "bartable", + Fields: [][]string{ + {"field1", "string"}, + }, + KeyFields: []string{"field1"}, }, - KeyFields: []string{"field1"}, - }, - { - Family: "foofamily", - Name: "bartable2", - Fields: [][]string{ - {"field1", "string"}, - {"field2", "integer"}, + { + Family: "foofamily", + Name: "bartable2", + Fields: [][]string{ + {"field1", "string"}, + {"field2", "integer"}, + }, + KeyFields: []string{"field1"}, }, - KeyFields: []string{"field1"}, }, - }, - ) - } + ) + } - err = createTables() - require.NoError(t, err) - dmls := queryDMLTable(t, u.db, -1) - require.Len(t, dmls, 2) // 2 DMLs should exist to create the 2 tables + err = createTables() + require.NoError(t, err) + dmls := queryDMLTable(t, u.db, -1) + require.Len(t, dmls, 2) // 2 DMLs should exist to create the 2 tables - // try to create the tables again, verify it fails, and verify that the ledger is correct - err = createTables() - require.Error(t, err) - dmls = queryDMLTable(t, u.db, -1) - require.Len(t, dmls, 2) // there should still be two DMLs + // try to create the tables again, verify it fails, and verify that the ledger is correct + err = createTables() + require.Error(t, err) + dmls = queryDMLTable(t, u.db, -1) + require.Len(t, dmls, 2) // there should still be two DMLs - // Just check that empty tables exist at all, because the field - // creation logic gets checked by sqlgen unit tests + // Just check that empty tables exist at all, because the field + // creation logic gets checked by sqlgen unit tests - row := u.db.QueryRow("SELECT COUNT(*) FROM foofamily___bartable") - var cnt sql.NullInt64 - err = row.Scan(&cnt) - if err != nil { - t.Fatalf("Unexpected error scanning result: %+v", err) - } - if want, got := 0, cnt; !got.Valid || int(got.Int64) != want { - t.Errorf("Expected %+v, got %+v", want, got) - } + row := u.db.QueryRow("SELECT COUNT(*) FROM foofamily___bartable") + var cnt sql.NullInt64 + err = row.Scan(&cnt) + if err != nil { + t.Fatalf("Unexpected error scanning result: %+v", err) + } + if want, got := 0, cnt; !got.Valid || int(got.Int64) != want { + t.Errorf("Expected %+v, got %+v", want, got) + } - // Next table - row = u.db.QueryRow("SELECT COUNT(*) FROM foofamily___bartable2") - err = row.Scan(&cnt) - if err != nil { - t.Fatalf("Unexpected error scanning result: %+v", err) - } - if want, got := 0, cnt; !got.Valid || int(got.Int64) != want { - t.Errorf("Expected %+v, got %+v", want, got) - } + // Next table + row = u.db.QueryRow("SELECT COUNT(*) FROM foofamily___bartable2") + err = row.Scan(&cnt) + if err != nil { + t.Fatalf("Unexpected error scanning result: %+v", err) + } + if want, got := 0, cnt; !got.Valid || int(got.Int64) != want { + t.Errorf("Expected %+v, got %+v", want, got) + } - rows, err := u.db.Query("SELECT statement FROM " + dmlLedgerTableName) - if err != nil { - t.Fatalf("Unexpected error: %+v", err) - } - defer rows.Close() - i := 0 - for rows.Next() { - var rowStatement string - if err := rows.Scan(&rowStatement); err != nil { + rows, err := u.db.Query("SELECT statement FROM " + dmlLedgerTableName) + if err != nil { t.Fatalf("Unexpected error: %+v", err) } - tableNames := []string{"bartable", "bartable2"} - indexOfCreate := strings.Index(rowStatement, "CREATE TABLE foofamily___"+tableNames[i]) - if want, got := 0, indexOfCreate; want != got { - t.Errorf("Expected %+v, got %+v", want, got) + defer rows.Close() + i := 0 + for rows.Next() { + var rowStatement string + if err := rows.Scan(&rowStatement); err != nil { + t.Fatalf("Unexpected error: %+v", err) + } + tableNames := []string{"bartable", "bartable2"} + indexOfCreate := strings.Index(rowStatement, "CREATE TABLE foofamily___"+tableNames[i]) + if want, got := 0, indexOfCreate; want != got { + t.Errorf("Expected %+v, got %+v", want, got) + } + i++ } - i++ - } + }) } -func testDBExecutiveTableLimits(t *testing.T, dbType string) { - u := newDbExecTestUtil(t, dbType) - defer u.Close() +func TestDBExecutiveTableLimits(t *testing.T) { + withDBTypes(t, func(dbType string) { - ctx, cancel := context.WithCancel(u.ctx) - defer cancel() + u := newDbExecTestUtil(t, dbType) + defer u.Close() - // assert that there are not table limits - tsLimits, err := u.e.ReadTableSizeLimits() - require.NoError(t, err) - require.EqualValues(t, limits.TableSizeLimits{Global: testDefaultTableLimit}, tsLimits) - - tableLimit1 := limits.TableSizeLimit{ - Family: "foo", - Table: "bar", - SizeLimits: limits.SizeLimits{ - MaxSize: 100, - WarnSize: 5, - }, - } - tableLimit2 := limits.TableSizeLimit{ - Family: "foo2", - Table: "baz", - SizeLimits: limits.SizeLimits{ - MaxSize: 1100, - WarnSize: 15, - }, - } + ctx, cancel := context.WithCancel(u.ctx) + defer cancel() - // ensure that you can't set table size limits for tables that do not exist - err = u.e.UpdateTableSizeLimit(tableLimit1) - require.EqualError(t, errors.Cause(err), "table 'foo___bar' not found") - - // createTable creates a table in the ctldb with a generic schema - createTable := func(family, name string) { - require.NoError(t, u.e.CreateFamily(family)) - fieldNames := []string{"name", "data"} - fieldTypes := []schema.FieldType{schema.FTString, schema.FTBinary} - keyFields := []string{"name"} - err = u.e.CreateTable(family, name, fieldNames, fieldTypes, keyFields) + // assert that there are not table limits + tsLimits, err := u.e.ReadTableSizeLimits() require.NoError(t, err) - } + require.EqualValues(t, limits.TableSizeLimits{Global: testDefaultTableLimit}, tsLimits) + + tableLimit1 := limits.TableSizeLimit{ + Family: "foo", + Table: "bar", + SizeLimits: limits.SizeLimits{ + MaxSize: 100, + WarnSize: 5, + }, + } + tableLimit2 := limits.TableSizeLimit{ + Family: "foo2", + Table: "baz", + SizeLimits: limits.SizeLimits{ + MaxSize: 1100, + WarnSize: 15, + }, + } - // create the table - createTable("foo", "bar") + // ensure that you can't set table size limits for tables that do not exist + err = u.e.UpdateTableSizeLimit(tableLimit1) + require.EqualError(t, err, "table 'foo___bar' not found") + + // createTable creates a table in the ctldb with a generic schema + createTable := func(family, name string) { + require.NoError(t, u.e.CreateFamily(family)) + fieldNames := []string{"name", "data"} + fieldTypes := []schema.FieldType{schema.FTString, schema.FTBinary} + keyFields := []string{"name"} + err = u.e.CreateTable(family, name, fieldNames, fieldTypes, keyFields) + require.NoError(t, err) + } - // then the mutation to set a table size limit should work - err = u.e.UpdateTableSizeLimit(tableLimit1) - require.NoError(t, err) + // create the table + createTable("foo", "bar") - // verify that the mutation exists - tsLimits, err = u.e.ReadTableSizeLimits() - require.NoError(t, err) - require.EqualValues(t, testDefaultTableLimit, tsLimits.Global) - require.EqualValues(t, []limits.TableSizeLimit{tableLimit1}, tsLimits.Tables) - - // verify that the mutation exists in the table that the limiter expects - var warnSize, maxSize int64 - row := u.e.DB.QueryRowContext(ctx, "select warn_size_bytes, max_size_bytes "+ - "from max_table_sizes "+ - "where family_name=? and table_name=?", tableLimit1.Family, tableLimit1.Table) - err = row.Scan(&warnSize, &maxSize) - require.NoError(t, err) - require.EqualValues(t, tableLimit1.WarnSize, warnSize) - require.EqualValues(t, tableLimit1.MaxSize, maxSize) + // then the mutation to set a table size limit should work + err = u.e.UpdateTableSizeLimit(tableLimit1) + require.NoError(t, err) - // create another table limit, but we will also create a new table first - createTable(tableLimit2.Family, tableLimit2.Table) - err = u.e.UpdateTableSizeLimit(tableLimit2) - require.NoError(t, err) + // verify that the mutation exists + tsLimits, err = u.e.ReadTableSizeLimits() + require.NoError(t, err) + require.EqualValues(t, testDefaultTableLimit, tsLimits.Global) + require.EqualValues(t, []limits.TableSizeLimit{tableLimit1}, tsLimits.Tables) + + // verify that the mutation exists in the table that the limiter expects + var warnSize, maxSize int64 + row := u.e.DB.QueryRowContext(ctx, "select warn_size_bytes, max_size_bytes "+ + "from max_table_sizes "+ + "where family_name=? and table_name=?", tableLimit1.Family, tableLimit1.Table) + err = row.Scan(&warnSize, &maxSize) + require.NoError(t, err) + require.EqualValues(t, tableLimit1.WarnSize, warnSize) + require.EqualValues(t, tableLimit1.MaxSize, maxSize) - // verify that it shows up in the table limit query - tsLimits, err = u.e.ReadTableSizeLimits() - require.NoError(t, err) - require.EqualValues(t, testDefaultTableLimit, tsLimits.Global) - require.EqualValues(t, []limits.TableSizeLimit{tableLimit1, tableLimit2}, tsLimits.Tables) + // create another table limit, but we will also create a new table first + createTable(tableLimit2.Family, tableLimit2.Table) + err = u.e.UpdateTableSizeLimit(tableLimit2) + require.NoError(t, err) - // delete the first table limit - err = u.e.DeleteTableSizeLimit(schema.FamilyTable{Family: tableLimit1.Family, Table: tableLimit1.Table}) - require.NoError(t, err) + // verify that it shows up in the table limit query + tsLimits, err = u.e.ReadTableSizeLimits() + require.NoError(t, err) + require.EqualValues(t, testDefaultTableLimit, tsLimits.Global) + require.EqualValues(t, []limits.TableSizeLimit{tableLimit1, tableLimit2}, tsLimits.Tables) - // verify it no longer exists - tsLimits, err = u.e.ReadTableSizeLimits() - require.NoError(t, err) - require.EqualValues(t, testDefaultTableLimit, tsLimits.Global) - require.EqualValues(t, []limits.TableSizeLimit{tableLimit2}, tsLimits.Tables) + // delete the first table limit + err = u.e.DeleteTableSizeLimit(schema.FamilyTable{Family: tableLimit1.Family, Table: tableLimit1.Table}) + require.NoError(t, err) - // update the second table limit to a different value - tableLimit2.MaxSize = 5000000 - require.NoError(t, u.e.UpdateTableSizeLimit(tableLimit2)) + // verify it no longer exists + tsLimits, err = u.e.ReadTableSizeLimits() + require.NoError(t, err) + require.EqualValues(t, testDefaultTableLimit, tsLimits.Global) + require.EqualValues(t, []limits.TableSizeLimit{tableLimit2}, tsLimits.Tables) - // verify that the value was updated - tsLimits, err = u.e.ReadTableSizeLimits() - require.NoError(t, err) - require.EqualValues(t, testDefaultTableLimit, tsLimits.Global) - require.EqualValues(t, []limits.TableSizeLimit{tableLimit2}, tsLimits.Tables) + // update the second table limit to a different value + tableLimit2.MaxSize = 5000000 + require.NoError(t, u.e.UpdateTableSizeLimit(tableLimit2)) + + // verify that the value was updated + tsLimits, err = u.e.ReadTableSizeLimits() + require.NoError(t, err) + require.EqualValues(t, testDefaultTableLimit, tsLimits.Global) + require.EqualValues(t, []limits.TableSizeLimit{tableLimit2}, tsLimits.Tables) + }) } -func testDBExecutiveWriterRates(t *testing.T, dbType string) { - u := newDbExecTestUtil(t, dbType) - defer u.Close() +func TestDBExecutiveWriterRates(t *testing.T) { + withDBTypes(t, func(dbType string) { - ctx, cancel := context.WithCancel(u.ctx) - defer cancel() + u := newDbExecTestUtil(t, dbType) + defer u.Close() - // assert that there are no writer limits - wrLimits, err := u.e.ReadWriterRateLimits() - require.NoError(t, err) - require.EqualValues(t, testDefaultWriterLimit, wrLimits.Global) - require.Len(t, wrLimits.Writers, 0) + ctx, cancel := context.WithCancel(u.ctx) + defer cancel() - const ( - writer1 = "my-writer-1" - writer2 = "my-writer-2" - ) - writerLimit1 := limits.WriterRateLimit{ - Writer: writer1, - RateLimit: limits.RateLimit{ - Amount: 2, - Period: time.Second, - }, - } - writerLimit2 := limits.WriterRateLimit{ - Writer: writer2, - RateLimit: limits.RateLimit{ - Amount: 120, - Period: time.Minute, - }, - } - // the limiter converts all rates to the configured period (1m). note - // that this is not needed for writerLimit2 because it's already based - // on the configured period. - expectedWriterLimit1 := limits.WriterRateLimit{ - Writer: writer1, - RateLimit: limits.RateLimit{ - Amount: 120, - Period: time.Minute, - }, - } + // assert that there are no writer limits + wrLimits, err := u.e.ReadWriterRateLimits() + require.NoError(t, err) + require.EqualValues(t, testDefaultWriterLimit, wrLimits.Global) + require.Len(t, wrLimits.Writers, 0) - // verify the writer must first exist - err = u.e.UpdateWriterRateLimit(writerLimit1) - require.EqualError(t, err, "no writer with the name '"+writer1+"' exists") + const ( + writer1 = "my-writer-1" + writer2 = "my-writer-2" + ) + writerLimit1 := limits.WriterRateLimit{ + Writer: writer1, + RateLimit: limits.RateLimit{ + Amount: 2, + Period: time.Second, + }, + } + writerLimit2 := limits.WriterRateLimit{ + Writer: writer2, + RateLimit: limits.RateLimit{ + Amount: 120, + Period: time.Minute, + }, + } + // the limiter converts all rates to the configured period (1m). note + // that this is not needed for writerLimit2 because it's already based + // on the configured period. + expectedWriterLimit1 := limits.WriterRateLimit{ + Writer: writer1, + RateLimit: limits.RateLimit{ + Amount: 120, + Period: time.Minute, + }, + } - require.NoError(t, u.e.RegisterWriter(writer1, "my-writer-secret")) - err = u.e.UpdateWriterRateLimit(writerLimit1) - require.NoError(t, err) + // verify the writer must first exist + err = u.e.UpdateWriterRateLimit(writerLimit1) + require.EqualError(t, err, "no writer with the name '"+writer1+"' exists") - // verify that the writer appears now in a read request - wrLimits, err = u.e.ReadWriterRateLimits() - require.NoError(t, err) - require.EqualValues(t, testDefaultWriterLimit, wrLimits.Global) - require.EqualValues(t, []limits.WriterRateLimit{expectedWriterLimit1}, wrLimits.Writers) - - // verify that the limit exists in the table that the limiter reads as well - row := u.db.QueryRowContext(ctx, "select max_rows_per_minute "+ - "from max_writer_rates "+ - "where writer_name=?", writer1) - var value int64 - require.NoError(t, row.Scan(&value)) - require.EqualValues(t, 120, value) - - // create another writer limit - require.NoError(t, u.e.RegisterWriter(writer2, "my-writer-secret-2")) - err = u.e.UpdateWriterRateLimit(writerLimit2) - require.NoError(t, err) + require.NoError(t, u.e.RegisterWriter(writer1, "my-writer-secret")) + err = u.e.UpdateWriterRateLimit(writerLimit1) + require.NoError(t, err) - // verify it shows up in the rate limit read query - wrLimits, err = u.e.ReadWriterRateLimits() - require.NoError(t, err) - require.EqualValues(t, testDefaultWriterLimit, wrLimits.Global) - require.EqualValues(t, []limits.WriterRateLimit{expectedWriterLimit1, writerLimit2}, wrLimits.Writers) + // verify that the writer appears now in a read request + wrLimits, err = u.e.ReadWriterRateLimits() + require.NoError(t, err) + require.EqualValues(t, testDefaultWriterLimit, wrLimits.Global) + require.EqualValues(t, []limits.WriterRateLimit{expectedWriterLimit1}, wrLimits.Writers) + + // verify that the limit exists in the table that the limiter reads as well + row := u.db.QueryRowContext(ctx, "select max_rows_per_minute "+ + "from max_writer_rates "+ + "where writer_name=?", writer1) + var value int64 + require.NoError(t, row.Scan(&value)) + require.EqualValues(t, 120, value) + + // create another writer limit + require.NoError(t, u.e.RegisterWriter(writer2, "my-writer-secret-2")) + err = u.e.UpdateWriterRateLimit(writerLimit2) + require.NoError(t, err) - // delete the first writer limit - require.NoError(t, u.e.DeleteWriterRateLimit(writer1)) + // verify it shows up in the rate limit read query + wrLimits, err = u.e.ReadWriterRateLimits() + require.NoError(t, err) + require.EqualValues(t, testDefaultWriterLimit, wrLimits.Global) + require.EqualValues(t, []limits.WriterRateLimit{expectedWriterLimit1, writerLimit2}, wrLimits.Writers) - // verify it no longer exists - wrLimits, err = u.e.ReadWriterRateLimits() - require.NoError(t, err) - require.EqualValues(t, testDefaultWriterLimit, wrLimits.Global) - require.EqualValues(t, []limits.WriterRateLimit{writerLimit2}, wrLimits.Writers) + // delete the first writer limit + require.NoError(t, u.e.DeleteWriterRateLimit(writer1)) - // update the second writer limit to a different value - writerLimit2.RateLimit.Amount = 300 - require.NoError(t, u.e.UpdateWriterRateLimit(writerLimit2)) + // verify it no longer exists + wrLimits, err = u.e.ReadWriterRateLimits() + require.NoError(t, err) + require.EqualValues(t, testDefaultWriterLimit, wrLimits.Global) + require.EqualValues(t, []limits.WriterRateLimit{writerLimit2}, wrLimits.Writers) - // verify that the value was updated - wrLimits, err = u.e.ReadWriterRateLimits() - require.NoError(t, err) - require.EqualValues(t, testDefaultWriterLimit, wrLimits.Global) - require.EqualValues(t, []limits.WriterRateLimit{writerLimit2}, wrLimits.Writers) + // update the second writer limit to a different value + writerLimit2.RateLimit.Amount = 300 + require.NoError(t, u.e.UpdateWriterRateLimit(writerLimit2)) + + // verify that the value was updated + wrLimits, err = u.e.ReadWriterRateLimits() + require.NoError(t, err) + require.EqualValues(t, testDefaultWriterLimit, wrLimits.Global) + require.EqualValues(t, []limits.WriterRateLimit{writerLimit2}, wrLimits.Writers) + }) } -func testDBExecutiveFetchFamilyByName(t *testing.T, dbType string) { - // Table testing this is so overkill, I get it. I just can't write - // software without intermediate unit tests. I'm too stupid. - suite := []struct { - desc string - familyName string - wantFam dbFamily - wantOk bool - wantErr string - }{ - {"Found case", "family1", dbFamily{1, "family1"}, true, ""}, - {"Not found", "family2", dbFamily{}, false, ""}, - {"Error", "family1", dbFamily{}, false, "sql: database is closed"}, - } +func TestDBExecutiveFetchFamilyByName(t *testing.T) { + withDBTypes(t, func(dbType string) { + + // Table testing this is so overkill, I get it. I just can't write + // software without intermediate unit tests. I'm too stupid. + suite := []struct { + desc string + familyName string + wantFam dbFamily + wantOk bool + wantErr string + }{ + {"Found case", "family1", dbFamily{1, "family1"}, true, ""}, + {"Not found", "family2", dbFamily{}, false, ""}, + {"Error", "family1", dbFamily{}, false, "sql: database is closed"}, + } - for i, testCase := range suite { - testName := fmt.Sprintf("[%d] %s", i, testCase.desc) - t.Run(testName, func(t *testing.T) { - u := newDbExecTestUtil(t, dbType) - defer u.Close() + for i, testCase := range suite { + testName := fmt.Sprintf("[%d] %s", i, testCase.desc) + t.Run(testName, func(t *testing.T) { + u := newDbExecTestUtil(t, dbType) + defer u.Close() - if strings.Contains(strings.ToLower(testCase.desc), "error") { - // I hate this so much - u.db.Close() - } + if strings.Contains(strings.ToLower(testCase.desc), "error") { + // I hate this so much + u.db.Close() + } - famName, err := schema.NewFamilyName(testCase.familyName) - if err != nil { - t.Fatalf("Family name %s invalid: %+v", testCase.familyName, err) - } - fam, ok, err := u.e.fetchFamilyByName(famName) + famName, err := schema.NewFamilyName(testCase.familyName) + if err != nil { + t.Fatalf("Family name %s invalid: %+v", testCase.familyName, err) + } + fam, ok, err := u.e.fetchFamilyByName(famName) - // Supreme Go-l0rd bmizerany told me to use cmp - if diff := cmp.Diff(testCase.wantFam, fam); diff != "" { - t.Errorf("returned dbFamily differs\n%s", diff) - } - if diff := cmp.Diff(testCase.wantOk, ok); diff != "" { - t.Errorf("returned ok differs\n%s", diff) - } + // Supreme Go-l0rd bmizerany told me to use cmp + if diff := cmp.Diff(testCase.wantFam, fam); diff != "" { + t.Errorf("returned dbFamily differs\n%s", diff) + } + if diff := cmp.Diff(testCase.wantOk, ok); diff != "" { + t.Errorf("returned ok differs\n%s", diff) + } - // error I'm looking for isn't exported, damnit - if want, got := testCase.wantErr, err; true { - if got == nil { - if want != "" { - t.Errorf("Expected no error returned, got %+v", got) - } - } else { - if want != got.Error() { - t.Errorf("Expected: %+v, got %+v\n", want, got) + // error I'm looking for isn't exported, damnit + if want, got := testCase.wantErr, err; true { + if got == nil { + if want != "" { + t.Errorf("Expected no error returned, got %+v", got) + } + } else { + if want != got.Error() { + t.Errorf("Expected: %+v, got %+v\n", want, got) + } } } - } - }) - } + }) + } + }) } -func testFetchMetaTableByName(t *testing.T, dbType string) { - suite := []struct { - desc string - familyName string - tableName string - wantFields []schema.NamedFieldType - wantPK []string - wantOk bool - wantErr error - }{ - {"Found case", - "family1", - "table1", - []schema.NamedFieldType{ - {Name: schema.FieldName{Name: "field1"}, FieldType: schema.FTInteger}, - {Name: schema.FieldName{Name: "field2"}, FieldType: schema.FTString}, - {Name: schema.FieldName{Name: "field3"}, FieldType: schema.FTDecimal}, +func TestFetchMetaTableByName(t *testing.T) { + withDBTypes(t, func(dbType string) { + + suite := []struct { + desc string + familyName string + tableName string + wantFields []schema.NamedFieldType + wantPK []string + wantOk bool + wantErr error + }{ + {"Found case", + "family1", + "table1", + []schema.NamedFieldType{ + {Name: schema.FieldName{Name: "field1"}, FieldType: schema.FTInteger}, + {Name: schema.FieldName{Name: "field2"}, FieldType: schema.FTString}, + {Name: schema.FieldName{Name: "field3"}, FieldType: schema.FTDecimal}, + }, + []string{}, + true, + nil, }, - []string{}, - true, - nil, - }, - } + } - for i, testCase := range suite { - testName := fmt.Sprintf("[%d] %s", i, testCase.desc) - t.Run(testName, func(t *testing.T) { - u := newDbExecTestUtil(t, dbType) - defer u.Close() + for i, testCase := range suite { + testName := fmt.Sprintf("[%d] %s", i, testCase.desc) + t.Run(testName, func(t *testing.T) { + u := newDbExecTestUtil(t, dbType) + defer u.Close() - famName, err := schema.NewFamilyName(testCase.familyName) - if err != nil { - t.Fatalf("Invalid family name %s, error: %+v", famName, err) - } - tblName, err := schema.NewTableName(testCase.tableName) - if err != nil { - t.Fatalf("Invalid table name %s, error: %+v", tblName, err) - } + famName, err := schema.NewFamilyName(testCase.familyName) + if err != nil { + t.Fatalf("Invalid family name %s, error: %+v", famName, err) + } + tblName, err := schema.NewTableName(testCase.tableName) + if err != nil { + t.Fatalf("Invalid table name %s, error: %+v", tblName, err) + } - tbl, gotOk, gotErr := u.e.fetchMetaTableByName(famName, tblName) + tbl, gotOk, gotErr := u.e.fetchMetaTableByName(famName, tblName) - if got, want := tbl.FamilyName.String(), testCase.familyName; got != want { - t.Errorf("Expected %+v, got %+v", want, got) - } + if got, want := tbl.FamilyName.String(), testCase.familyName; got != want { + t.Errorf("Expected %+v, got %+v", want, got) + } - if got, want := tbl.TableName.String(), testCase.tableName; got != want { - t.Errorf("Expected %+v, got %+v", want, got) - } + if got, want := tbl.TableName.String(), testCase.tableName; got != want { + t.Errorf("Expected %+v, got %+v", want, got) + } - if diff := cmp.Diff(testCase.wantFields, tbl.Fields); diff != "" { - t.Errorf("returned dbFamily differs\n%s", diff) - } + if diff := cmp.Diff(testCase.wantFields, tbl.Fields); diff != "" { + t.Errorf("returned dbFamily differs\n%s", diff) + } - if got, want := gotOk, testCase.wantOk; got != want { - t.Errorf("Expected %+v, got %+v", want, got) - } + if got, want := gotOk, testCase.wantOk; got != want { + t.Errorf("Expected %+v, got %+v", want, got) + } - if got, want := gotErr, testCase.wantErr; got != want { - t.Errorf("Expected %+v, got %+v", want, got) - } - }) - } + if got, want := gotErr, testCase.wantErr; got != want { + t.Errorf("Expected %+v, got %+v", want, got) + } + }) + } + }) } -func testDBExecutiveMutate(t *testing.T, dbType string) { - suite := []struct { - desc string - writerName string - cookie []byte - checkCookie []byte - reqs []ExecutiveMutationRequest - expectErr error - expectRows map[string][]map[string]interface{} - expectDML []string - skipDBTypes []string - }{ - { - desc: "MySQL String Column With Null Value", - skipDBTypes: []string{"sqlite3"}, // sqlite3 cannot retrieve this data without truncating so we skip it as a backend - reqs: []ExecutiveMutationRequest{ - { - TableName: "table1", - Delete: false, - Values: map[string]interface{}{ - "field1": 1, - "field2": "a\u0000b", - "field3": 42, +func TestDBExecutiveMutate(t *testing.T) { + withDBTypes(t, func(dbType string) { + + suite := []struct { + desc string + writerName string + cookie []byte + checkCookie []byte + reqs []ExecutiveMutationRequest + expectErr error + expectRows map[string][]map[string]interface{} + expectDML []string + skipDBTypes []string + }{ + { + desc: "MySQL String Column With Null Value", + skipDBTypes: []string{"sqlite3"}, // sqlite3 cannot retrieve this data without truncating so we skip it as a backend + reqs: []ExecutiveMutationRequest{ + { + TableName: "table1", + Delete: false, + Values: map[string]interface{}{ + "field1": 1, + "field2": "a\u0000b", + "field3": 42, + }, }, }, - }, - expectRows: map[string][]map[string]interface{}{ - "table1": { - {"field1": 1}, + expectRows: map[string][]map[string]interface{}{ + "table1": { + {"field1": 1}, + }, + }, + expectDML: []string{ + `REPLACE INTO family1___table1 ("field1","field2","field3") ` + + `VALUES(1,x'610062',42)`, }, }, - expectDML: []string{ - `REPLACE INTO family1___table1 ("field1","field2","field3") ` + - `VALUES(1,x'610062',42)`, - }, - }, - { - desc: "Binary Column Null Value", - reqs: []ExecutiveMutationRequest{ - { - TableName: "binary_table1", - Delete: false, - Values: map[string]interface{}{ - "field1": 1, - "field2": nil, + { + desc: "Binary Column Null Value", + reqs: []ExecutiveMutationRequest{ + { + TableName: "binary_table1", + Delete: false, + Values: map[string]interface{}{ + "field1": 1, + "field2": nil, + }, }, }, - }, - expectRows: map[string][]map[string]interface{}{ - "binary_table1": { - {"field1": 1}, + expectRows: map[string][]map[string]interface{}{ + "binary_table1": { + {"field1": 1}, + }, + }, + expectDML: []string{ + `REPLACE INTO family1___binary_table1 ("field1","field2") ` + + `VALUES(1,NULL)`, }, }, - expectDML: []string{ - `REPLACE INTO family1___binary_table1 ("field1","field2") ` + - `VALUES(1,NULL)`, - }, - }, - { - desc: "Empty Table Insert", - reqs: []ExecutiveMutationRequest{ - { - TableName: "table1", - Delete: false, - Values: map[string]interface{}{ - "field1": 1, - "field2": "bar", - "field3": 10.0, + { + desc: "Empty Table Insert", + reqs: []ExecutiveMutationRequest{ + { + TableName: "table1", + Delete: false, + Values: map[string]interface{}{ + "field1": 1, + "field2": "bar", + "field3": 10.0, + }, }, }, - }, - expectRows: map[string][]map[string]interface{}{ - "table1": { - {"field1": 1, "field2": "bar", "field3": 10.0}, + expectRows: map[string][]map[string]interface{}{ + "table1": { + {"field1": 1, "field2": "bar", "field3": 10.0}, + }, + }, + expectDML: []string{ + `REPLACE INTO family1___table1 ("field1","field2","field3") ` + + `VALUES(1,'bar',10)`, }, }, - expectDML: []string{ - `REPLACE INTO family1___table1 ("field1","field2","field3") ` + - `VALUES(1,'bar',10)`, - }, - }, - { - desc: "Unicode Insert", - reqs: []ExecutiveMutationRequest{ - { - TableName: "table1", - Delete: false, - Values: map[string]interface{}{ - "field1": 1, - "field2": "𨅝", - "field3": 10.0, + { + desc: "Unicode Insert", + reqs: []ExecutiveMutationRequest{ + { + TableName: "table1", + Delete: false, + Values: map[string]interface{}{ + "field1": 1, + "field2": "𨅝", + "field3": 10.0, + }, }, }, - }, - expectRows: map[string][]map[string]interface{}{ - "table1": { - {"field1": 1, "field2": "𨅝", "field3": 10.0}, + expectRows: map[string][]map[string]interface{}{ + "table1": { + {"field1": 1, "field2": "𨅝", "field3": 10.0}, + }, + }, + expectDML: []string{ + `REPLACE INTO family1___table1 ("field1","field2","field3") ` + + `VALUES(1,'𨅝',10)`, }, }, - expectDML: []string{ - `REPLACE INTO family1___table1 ("field1","field2","field3") ` + - `VALUES(1,'𨅝',10)`, - }, - }, - { - desc: "Escaped String Insert", - reqs: []ExecutiveMutationRequest{ - { - TableName: "table1", - Delete: false, - Values: map[string]interface{}{ - "field1": 1, - "field2": `\\d`, - "field3": 10.0, + { + desc: "Escaped String Insert", + reqs: []ExecutiveMutationRequest{ + { + TableName: "table1", + Delete: false, + Values: map[string]interface{}{ + "field1": 1, + "field2": `\\d`, + "field3": 10.0, + }, }, }, - }, - expectRows: map[string][]map[string]interface{}{ - "table1": { - {"field1": 1, "field2": `\\d`, "field3": 10.0}, + expectRows: map[string][]map[string]interface{}{ + "table1": { + {"field1": 1, "field2": `\\d`, "field3": 10.0}, + }, + }, + expectDML: []string{ + `REPLACE INTO family1___table1 ("field1","field2","field3") ` + + `VALUES(1,'\\d',10)`, }, }, - expectDML: []string{ - `REPLACE INTO family1___table1 ("field1","field2","field3") ` + - `VALUES(1,'\\d',10)`, - }, - }, - { - desc: "Multi Insert", - reqs: []ExecutiveMutationRequest{ - { - TableName: "table1", - Delete: false, - Values: map[string]interface{}{ - "field1": 1, - "field2": "bar", - "field3": 10.0, + { + desc: "Multi Insert", + reqs: []ExecutiveMutationRequest{ + { + TableName: "table1", + Delete: false, + Values: map[string]interface{}{ + "field1": 1, + "field2": "bar", + "field3": 10.0, + }, + }, + { + TableName: "table100", + Delete: false, + Values: map[string]interface{}{ + "field1": 2, + "field2": "baz", + "field3": 20.0, + }, }, }, - { - TableName: "table100", - Delete: false, - Values: map[string]interface{}{ - "field1": 2, - "field2": "baz", - "field3": 20.0, + expectRows: map[string][]map[string]interface{}{ + "table1": { + {"field1": 1, "field2": "bar", "field3": 10.0}, }, + "table100": { + {"field1": 2, "field2": "baz", "field3": 20.0}, + }, + }, + expectDML: []string{ + schema.DMLTxBeginKey, + `REPLACE INTO family1___table1 ("field1","field2","field3") ` + + `VALUES(1,'bar',10)`, + `REPLACE INTO family1___table100 ("field1","field2","field3") ` + + `VALUES(2,'baz',20)`, + schema.DMLTxEndKey, }, }, - expectRows: map[string][]map[string]interface{}{ - "table1": { - {"field1": 1, "field2": "bar", "field3": 10.0}, + { + desc: "Delete", + reqs: []ExecutiveMutationRequest{ + { + TableName: "table10", + Delete: true, + Values: map[string]interface{}{ + "field1": 1, + }, + }, + { + TableName: "table11", + Delete: true, + Values: map[string]interface{}{ + "field1": 1, + }, + }, }, - "table100": { - {"field1": 2, "field2": "baz", "field3": 20.0}, + expectRows: map[string][]map[string]interface{}{ + "table10": {}, + "table11": {}, + }, + expectDML: []string{ + schema.DMLTxBeginKey, + `DELETE FROM family1___table10 WHERE "field1" = 1`, + `DELETE FROM family1___table11 WHERE "field1" = 1`, + schema.DMLTxEndKey, }, }, - expectDML: []string{ - schema.DMLTxBeginKey, - `REPLACE INTO family1___table1 ("field1","field2","field3") ` + - `VALUES(1,'bar',10)`, - `REPLACE INTO family1___table100 ("field1","field2","field3") ` + - `VALUES(2,'baz',20)`, - schema.DMLTxEndKey, - }, - }, - { - desc: "Delete", - reqs: []ExecutiveMutationRequest{ - { - TableName: "table10", - Delete: true, - Values: map[string]interface{}{ - "field1": 1, + { + desc: "Null Value Insert", + reqs: []ExecutiveMutationRequest{ + { + TableName: "table1", + Delete: false, + Values: map[string]interface{}{ + "field1": 1, + "field2": nil, + "field3": 10.0, + }, }, }, - { - TableName: "table11", - Delete: true, - Values: map[string]interface{}{ - "field1": 1, - }, + // expectRows doesn't work here for some reason, but the + // statement is fine. + expectDML: []string{ + `REPLACE INTO family1___table1 ("field1","field2","field3") ` + + `VALUES(1,NULL,10)`, }, }, - expectRows: map[string][]map[string]interface{}{ - "table10": {}, - "table11": {}, - }, - expectDML: []string{ - schema.DMLTxBeginKey, - `DELETE FROM family1___table10 WHERE "field1" = 1`, - `DELETE FROM family1___table11 WHERE "field1" = 1`, - schema.DMLTxEndKey, - }, - }, - { - desc: "Null Value Insert", - reqs: []ExecutiveMutationRequest{ - { - TableName: "table1", - Delete: false, - Values: map[string]interface{}{ - "field1": 1, - "field2": nil, - "field3": 10.0, + { + desc: "Replace row in table", + reqs: []ExecutiveMutationRequest{ + { + TableName: "table10", + Delete: false, + Values: map[string]interface{}{ + "field1": 1, + "field2": "bar", + "field3": 0.0, + }, }, }, - }, - // expectRows doesn't work here for some reason, but the - // statement is fine. - expectDML: []string{ - `REPLACE INTO family1___table1 ("field1","field2","field3") ` + - `VALUES(1,NULL,10)`, - }, - }, - { - desc: "Replace row in table", - reqs: []ExecutiveMutationRequest{ - { - TableName: "table10", - Delete: false, - Values: map[string]interface{}{ - "field1": 1, - "field2": "bar", - "field3": 0.0, + expectRows: map[string][]map[string]interface{}{ + "table10": { + {"field1": 1, "field2": "bar", "field3": 0.0}, }, }, - }, - expectRows: map[string][]map[string]interface{}{ - "table10": { - {"field1": 1, "field2": "bar", "field3": 0.0}, + expectDML: []string{ + `REPLACE INTO family1___table10 ("field1","field2","field3") ` + + `VALUES(1,'bar',0)`, }, }, - expectDML: []string{ - `REPLACE INTO family1___table10 ("field1","field2","field3") ` + - `VALUES(1,'bar',0)`, - }, - }, - { - desc: "Error on missing fields", - reqs: []ExecutiveMutationRequest{ - { - TableName: "table1", - Delete: false, - Values: map[string]interface{}{ - "field1": 1, - "field2": "bar", + { + desc: "Error on missing fields", + reqs: []ExecutiveMutationRequest{ + { + TableName: "table1", + Delete: false, + Values: map[string]interface{}{ + "field1": 1, + "field2": "bar", + }, }, }, + expectErr: errors.New("Missing field field3"), + expectRows: map[string][]map[string]interface{}{ + "table1": {}, + }, }, - expectErr: errors.New("Missing field field3"), - expectRows: map[string][]map[string]interface{}{ - "table1": {}, - }, - }, - { - desc: "Check cookie correct", - reqs: []ExecutiveMutationRequest{ - { - TableName: "table1", - Delete: false, - Values: map[string]interface{}{ - "field1": 1, - "field2": "bar", - "field3": 10.0, + { + desc: "Check cookie correct", + reqs: []ExecutiveMutationRequest{ + { + TableName: "table1", + Delete: false, + Values: map[string]interface{}{ + "field1": 1, + "field2": "bar", + "field3": 10.0, + }, }, }, + cookie: []byte{2}, + checkCookie: []byte{1}, }, - cookie: []byte{2}, - checkCookie: []byte{1}, - }, - { - desc: "No check cookie succeeds", - reqs: []ExecutiveMutationRequest{ - { - TableName: "table1", - Delete: false, - Values: map[string]interface{}{ - "field1": 1, - "field2": "bar", - "field3": 10.0, + { + desc: "No check cookie succeeds", + reqs: []ExecutiveMutationRequest{ + { + TableName: "table1", + Delete: false, + Values: map[string]interface{}{ + "field1": 1, + "field2": "bar", + "field3": 10.0, + }, }, }, + cookie: []byte{2}, }, - cookie: []byte{2}, - }, - { - desc: "Max DML Size Exceeded", - reqs: []ExecutiveMutationRequest{ - { - TableName: "table1", - Delete: false, - Values: map[string]interface{}{ - "field1": 1, - "field2": strings.Repeat("b", 769*units.KILOBYTE), - "field3": 10.0, + { + desc: "Max DML Size Exceeded", + reqs: []ExecutiveMutationRequest{ + { + TableName: "table1", + Delete: false, + Values: map[string]interface{}{ + "field1": 1, + "field2": strings.Repeat("b", 769*units.KILOBYTE), + "field3": 10.0, + }, }, }, + expectErr: &errs.BadRequestError{Err: "Request generated too large of a DML statement"}, }, - expectErr: &errs.BadRequestError{Err: "Request generated too large of a DML statement"}, - }, - } - - for caseIdx, testCase := range suite { - var skipTest bool - for _, skipDBType := range testCase.skipDBTypes { - if skipDBType == dbType { - skipTest = true - } } - if skipTest { - continue - } - - testName := fmt.Sprintf("[%d] %s", caseIdx, testCase.desc) - t.Run(testName, func(t *testing.T) { - u := newDbExecTestUtil(t, dbType) - defer u.Close() - writerName := testCase.writerName - if writerName == "" { - writerName = "writer1" + for caseIdx, testCase := range suite { + var skipTest bool + for _, skipDBType := range testCase.skipDBTypes { + if skipDBType == dbType { + skipTest = true + } } - - cookie := testCase.cookie - if cookie == nil { - cookie = []byte{2} + if skipTest { + continue } - err := u.e.Mutate(writerName, "", "family1", cookie, testCase.checkCookie, testCase.reqs) + testName := fmt.Sprintf("[%d] %s", caseIdx, testCase.desc) + t.Run(testName, func(t *testing.T) { + u := newDbExecTestUtil(t, dbType) + defer u.Close() - if err != nil { - if testCase.expectErr != nil { - if want, got := testCase.expectErr, err; want != got && want.Error() != got.Error() { - t.Errorf("Expected error %+v, got %+v", want, got) - } - } else { - t.Errorf("Unexpected error: %+v", err) + writerName := testCase.writerName + if writerName == "" { + writerName = "writer1" } - } else { - if testCase.expectErr != nil { - t.Errorf("Expected error: %+v, got nil", testCase.expectErr) + + cookie := testCase.cookie + if cookie == nil { + cookie = []byte{2} } - } - if testCase.expectDML != nil { - rows, err := u.db.Query( - "SELECT statement FROM " + - dmlLedgerTableName + - " ORDER BY seq ASC") + err := u.e.Mutate(writerName, "", "family1", cookie, testCase.checkCookie, testCase.reqs) if err != nil { - t.Errorf("Unexpected error: %+v", err) - } else { - defer rows.Close() - i := 0 - for rows.Next() { - var rowStatement string - err = rows.Scan(&rowStatement) - if err != nil { - t.Fatalf("Unexpected error scanning: %+v", err) + if testCase.expectErr != nil { + if want, got := testCase.expectErr, err; want != got && want.Error() != got.Error() { + t.Errorf("Expected error %+v, got %+v", want, got) } - if i < len(testCase.expectDML) { - if want, got := testCase.expectDML[i], rowStatement; want != got { - t.Errorf("Expected %+v, got %+v", want, got) - } - } else { - t.Errorf("Extra statement: %v", rowStatement) - } - i++ + } else { + t.Errorf("Unexpected error: %+v", err) } - } - } - - if testCase.expectRows != nil { - for name, rows := range testCase.expectRows { - famName, _ := schema.NewFamilyName("family1") - tblName, err := schema.NewTableName(name) - if err != nil { - t.Fatalf("Invalid table name %s, error: %+v", name, err) + } else { + if testCase.expectErr != nil { + t.Errorf("Expected error: %+v, got nil", testCase.expectErr) } + } - sqlTableName := schema.LDBTableName(famName, tblName) + if testCase.expectDML != nil { + rows, err := u.db.Query( + "SELECT statement FROM " + + dmlLedgerTableName + + " ORDER BY seq ASC") - var rowCnt int - cntRow := u.db.QueryRow("SELECT COUNT(*) FROM " + sqlTableName) - err = cntRow.Scan(&rowCnt) if err != nil { - t.Fatalf("Unexpected error encountered: %+v", err) - } - - if want, got := len(rows), rowCnt; want != got { - t.Errorf("Expected %s to have %d rows, got %d", sqlTableName, want, got) + t.Errorf("Unexpected error: %+v", err) + } else { + defer rows.Close() + i := 0 + for rows.Next() { + var rowStatement string + err = rows.Scan(&rowStatement) + if err != nil { + t.Fatalf("Unexpected error scanning: %+v", err) + } + if i < len(testCase.expectDML) { + if want, got := testCase.expectDML[i], rowStatement; want != got { + t.Errorf("Expected %+v, got %+v", want, got) + } + } else { + t.Errorf("Extra statement: %v", rowStatement) + } + i++ + } } + } - for _, row := range rows { - clauses := []string{} - valz := []interface{}{} - for colName, colVal := range row { - clauses = append(clauses, colName+"=?") - valz = append(valz, colVal) + if testCase.expectRows != nil { + for name, rows := range testCase.expectRows { + famName, _ := schema.NewFamilyName("family1") + tblName, err := schema.NewTableName(name) + if err != nil { + t.Fatalf("Invalid table name %s, error: %+v", name, err) } - whereClause := " WHERE " + strings.Join(clauses, " AND ") - var cnt int - qs := "SELECT COUNT(*) FROM " + sqlTableName + whereClause - t.Logf("Running %s", qs) - resRow := u.db.QueryRow(qs, valz...) - err = resRow.Scan(&cnt) + sqlTableName := schema.LDBTableName(famName, tblName) + + var rowCnt int + cntRow := u.db.QueryRow("SELECT COUNT(*) FROM " + sqlTableName) + err = cntRow.Scan(&rowCnt) if err != nil { t.Fatalf("Unexpected error encountered: %+v", err) } - if cnt != 1 { - t.Errorf("Expected to find 1 row of %+v, got: %+v", row, cnt) + if want, got := len(rows), rowCnt; want != got { + t.Errorf("Expected %s to have %d rows, got %d", sqlTableName, want, got) + } + + for _, row := range rows { + clauses := []string{} + valz := []interface{}{} + for colName, colVal := range row { + clauses = append(clauses, colName+"=?") + valz = append(valz, colVal) + } + whereClause := " WHERE " + strings.Join(clauses, " AND ") + + var cnt int + qs := "SELECT COUNT(*) FROM " + sqlTableName + whereClause + t.Logf("Running %s", qs) + resRow := u.db.QueryRow(qs, valz...) + err = resRow.Scan(&cnt) + if err != nil { + t.Fatalf("Unexpected error encountered: %+v", err) + } + + if cnt != 1 { + t.Errorf("Expected to find 1 row of %+v, got: %+v", row, cnt) + } } } } - } - }) - } + }) + } + }) } -func testDBExecutiveGetWriterCookie(t *testing.T, dbType string) { - suite := []struct { - desc string - writerName string - writerSecret string - expectCookie []byte - expectErr string - }{ - { - desc: "Empty writer returns error", - writerName: "writer-doesnt-exist", - expectErr: "Writer not found", - }, - { - desc: "Bad writer secret returns error", - writerName: "writer1", - writerSecret: "invalid", - expectErr: "Writer not found", - }, - { - desc: "Existing writer", - writerName: "writer1", - expectCookie: []byte{1}, - }, - } +func TestDBExecutiveGetWriterCookie(t *testing.T) { + withDBTypes(t, func(dbType string) { + + suite := []struct { + desc string + writerName string + writerSecret string + expectCookie []byte + expectErr string + }{ + { + desc: "Empty writer returns error", + writerName: "writer-doesnt-exist", + expectErr: "Writer not found", + }, + { + desc: "Bad writer secret returns error", + writerName: "writer1", + writerSecret: "invalid", + expectErr: "Writer not found", + }, + { + desc: "Existing writer", + writerName: "writer1", + expectCookie: []byte{1}, + }, + } - for _, testCase := range suite { - t.Run(testCase.desc, func(t *testing.T) { - u := newDbExecTestUtil(t, dbType) - defer u.Close() + for _, testCase := range suite { + t.Run(testCase.desc, func(t *testing.T) { + u := newDbExecTestUtil(t, dbType) + defer u.Close() - gotCookie, gotErr := u.e.GetWriterCookie(testCase.writerName, testCase.writerSecret) + gotCookie, gotErr := u.e.GetWriterCookie(testCase.writerName, testCase.writerSecret) - if diff := cmp.Diff(testCase.expectCookie, gotCookie); diff != "" { - t.Errorf("Cookie differs\n%s", diff) - } - if want, got := testCase.expectErr, gotErr; (got == nil && want != "") || (got != nil && want != got.Error()) { - t.Errorf("Expected: %v, got %v", want, got) - } - }) - } + if diff := cmp.Diff(testCase.expectCookie, gotCookie); diff != "" { + t.Errorf("Cookie differs\n%s", diff) + } + if want, got := testCase.expectErr, gotErr; (got == nil && want != "") || (got != nil && want != got.Error()) { + t.Errorf("Expected: %v, got %v", want, got) + } + }) + } + }) } -func testDBExecutiveSetWriterCookie(t *testing.T, dbType string) { - suite := []struct { - desc string - writerName string - writerSecret string - cookie []byte - expectErr string - expectCookie []byte - }{ - { - desc: "Empty writer returns error", - writerName: "writer-doesnt-exist", - expectErr: "Writer not found", - }, - { - desc: "Bad writer secret returns error", - writerName: "writer1", - writerSecret: "invalid", - expectErr: "Writer not found", - }, - { - desc: "Existing writer", - writerName: "writer1", - cookie: []byte{1}, - expectCookie: []byte{1}, - }, - } - - for _, testCase := range suite { - t.Run(testCase.desc, func(t *testing.T) { - u := newDbExecTestUtil(t, dbType) - defer u.Close() +func TestDBExecutiveSetWriterCookie(t *testing.T) { + withDBTypes(t, func(dbType string) { + + suite := []struct { + desc string + writerName string + writerSecret string + cookie []byte + expectErr string + expectCookie []byte + }{ + { + desc: "Empty writer returns error", + writerName: "writer-doesnt-exist", + expectErr: "Writer not found", + }, + { + desc: "Bad writer secret returns error", + writerName: "writer1", + writerSecret: "invalid", + expectErr: "Writer not found", + }, + { + desc: "Existing writer", + writerName: "writer1", + cookie: []byte{1}, + expectCookie: []byte{1}, + }, + } - gotErr := u.e.SetWriterCookie(testCase.writerName, testCase.writerSecret, testCase.cookie) + for _, testCase := range suite { + t.Run(testCase.desc, func(t *testing.T) { + u := newDbExecTestUtil(t, dbType) + defer u.Close() - if want, got := testCase.expectErr, gotErr; (got == nil && want != "") || (got != nil && want != got.Error()) { - t.Errorf("Expected: %v, got %v", want, got) - } + gotErr := u.e.SetWriterCookie(testCase.writerName, testCase.writerSecret, testCase.cookie) - if testCase.expectCookie != nil { - gotCookie, err := u.e.GetWriterCookie(testCase.writerName, testCase.writerSecret) - if err != nil { - t.Fatalf("Unexpected error: %+v", err) + if want, got := testCase.expectErr, gotErr; (got == nil && want != "") || (got != nil && want != got.Error()) { + t.Errorf("Expected: %v, got %v", want, got) } - if diff := cmp.Diff(testCase.expectCookie, gotCookie); diff != "" { - t.Errorf("Cookie differs:\n+%v", diff) + + if testCase.expectCookie != nil { + gotCookie, err := u.e.GetWriterCookie(testCase.writerName, testCase.writerSecret) + if err != nil { + t.Fatalf("Unexpected error: %+v", err) + } + if diff := cmp.Diff(testCase.expectCookie, gotCookie); diff != "" { + t.Errorf("Cookie differs:\n+%v", diff) + } } - } - }) - } + }) + } + }) } -func testDBExecutiveReadRow(t *testing.T, dbType string) { - suite := []struct { - desc string - familyName string - tableName string - where map[string]interface{} - expectOut map[string]interface{} - expectErr string - }{ - { - desc: "Table not found", - familyName: "nonExistantFamily", - tableName: "nonExistantTable", - where: nil, - expectOut: nil, - expectErr: "Table not found", - }, - { - desc: "Row not found", - familyName: "family1", - tableName: "table10", - where: map[string]interface{}{"field1": 1234}, - expectOut: map[string]interface{}{}, - expectErr: "", - }, - { - desc: "Row found", - familyName: "family1", - tableName: "table10", - where: map[string]interface{}{"field1": 1}, - expectOut: map[string]interface{}{ - "field1": int64(1), - "field2": "foo", - "field3": float64(1.2), +func TestDBExecutiveReadRow(t *testing.T) { + withDBTypes(t, func(dbType string) { + + suite := []struct { + desc string + familyName string + tableName string + where map[string]interface{} + expectOut map[string]interface{} + expectErr string + }{ + { + desc: "Table not found", + familyName: "nonExistantFamily", + tableName: "nonExistantTable", + where: nil, + expectOut: nil, + expectErr: "Table not found", }, - expectErr: "", - }, - } + { + desc: "Row not found", + familyName: "family1", + tableName: "table10", + where: map[string]interface{}{"field1": 1234}, + expectOut: map[string]interface{}{}, + expectErr: "", + }, + { + desc: "Row found", + familyName: "family1", + tableName: "table10", + where: map[string]interface{}{"field1": 1}, + expectOut: map[string]interface{}{ + "field1": int64(1), + "field2": "foo", + "field3": float64(1.2), + }, + expectErr: "", + }, + } - for _, testCase := range suite { - t.Run(testCase.desc, func(t *testing.T) { - u := newDbExecTestUtil(t, dbType) - defer u.Close() + for _, testCase := range suite { + t.Run(testCase.desc, func(t *testing.T) { + u := newDbExecTestUtil(t, dbType) + defer u.Close() - gotOut, gotErr := u.e.ReadRow(testCase.familyName, testCase.tableName, testCase.where) + gotOut, gotErr := u.e.ReadRow(testCase.familyName, testCase.tableName, testCase.where) - if diff := cmp.Diff(testCase.expectOut, gotOut); diff != "" { - t.Errorf("Out differs\n%s", diff) - } - if want, got := testCase.expectErr, gotErr; (got == nil && want != "") || (got != nil && want != got.Error()) { - t.Errorf("Expected: %v, got %v", want, got) - } - }) - } + if diff := cmp.Diff(testCase.expectOut, gotOut); diff != "" { + t.Errorf("Out differs\n%s", diff) + } + if want, got := testCase.expectErr, gotErr; (got == nil && want != "") || (got != nil && want != got.Error()) { + t.Errorf("Expected: %v, got %v", want, got) + } + }) + } + }) } -func testDBExecutiveDropTable(t *testing.T, dbType string) { - u := newDbExecTestUtil(t, dbType) - defer u.Close() +func TestDBExecutiveDropTable(t *testing.T) { + withDBTypes(t, func(dbType string) { - err := u.e.CreateTable("family1", - "delete_test", - []string{"field1"}, - []schema.FieldType{schema.FTString}, - []string{"field1"}, - ) - require.NoError(t, err) + u := newDbExecTestUtil(t, dbType) + defer u.Close() - // verify the table exists and has no rows - row := u.db.QueryRow("SELECT COUNT(*) FROM family1___delete_test") - var cnt sql.NullInt64 - err = row.Scan(&cnt) - require.NoError(t, err) - require.EqualValues(t, 0, cnt.Int64) + err := u.e.CreateTable("family1", + "delete_test", + []string{"field1"}, + []schema.FieldType{schema.FTString}, + []string{"field1"}, + ) + require.NoError(t, err) - err = u.e.DropTable(schema.FamilyTable{Family: "family1", Table: "delete_test"}) - require.NoError(t, err) + // verify the table exists and has no rows + row := u.db.QueryRow("SELECT COUNT(*) FROM family1___delete_test") + var cnt sql.NullInt64 + err = row.Scan(&cnt) + require.NoError(t, err) + require.EqualValues(t, 0, cnt.Int64) - // assert that we can't query the table anymore - row = u.db.QueryRow("SELECT COUNT(*) FROM family1___delete_test") - err = row.Scan(&cnt) - switch dbType { - case "sqlite3": - require.EqualError(t, err, "no such table: family1___delete_test") - case "mysql": - require.EqualError(t, err, "Error 1146: Table 'ctldb.family1___delete_test' doesn't exist") - default: - require.Fail(t, "unknown db type: "+dbType) - } + err = u.e.DropTable(schema.FamilyTable{Family: "family1", Table: "delete_test"}) + require.NoError(t, err) - // double check the dml - row = u.db.QueryRow("select statement from ctlstore_dml_ledger order by seq desc limit 1") - var statement string - err = row.Scan(&statement) - require.NoError(t, err) - require.EqualValues(t, "DROP TABLE IF EXISTS family1___delete_test", statement) + // assert that we can't query the table anymore + row = u.db.QueryRow("SELECT COUNT(*) FROM family1___delete_test") + err = row.Scan(&cnt) + switch dbType { + case "sqlite3": + require.EqualError(t, err, "no such table: family1___delete_test") + case "mysql": + require.EqualError(t, err, "Error 1146: Table 'ctldb.family1___delete_test' doesn't exist") + default: + require.Fail(t, "unknown db type: "+dbType) + } + + // double check the dml + row = u.db.QueryRow("select statement from ctlstore_dml_ledger order by seq desc limit 1") + var statement string + err = row.Scan(&statement) + require.NoError(t, err) + require.EqualValues(t, "DROP TABLE IF EXISTS family1___delete_test", statement) + }) } -func testDBExecutiveClearTable(t *testing.T, dbType string) { - u := newDbExecTestUtil(t, dbType) - defer u.Close() +func TestDBExecutiveClearTable(t *testing.T) { + withDBTypes(t, func(dbType string) { + u := newDbExecTestUtil(t, dbType) + defer u.Close() - err := u.e.CreateTable("family1", - "table5", - []string{"field1", "field2"}, - []schema.FieldType{schema.FTString, schema.FTInteger}, - []string{"field1", "field2"}, - ) - if err != nil { - t.Fatalf("Unexpected error calling CreateTable: %+v", err) - } + err := u.e.CreateTable("family1", + "table5", + []string{"field1", "field2"}, + []schema.FieldType{schema.FTString, schema.FTInteger}, + []string{"field1", "field2"}, + ) + if err != nil { + t.Fatalf("Unexpected error calling CreateTable: %+v", err) + } - _, err = u.db.Exec(`INSERT into family1___table5 + _, err = u.db.Exec(`INSERT into family1___table5 (field1,field2) VALUES ('1',2) `) - if err != nil { - t.Fatal(err) - } + if err != nil { + t.Fatal(err) + } - err = u.e.ClearTable(schema.FamilyTable{Family: "family1", Table: "table5"}) - if err != nil { - t.Fatalf("Unexpected error calling ClearTable: %+v", err) - } + err = u.e.ClearTable(schema.FamilyTable{Family: "family1", Table: "table5"}) + if err != nil { + t.Fatalf("Unexpected error calling ClearTable: %+v", err) + } - row := u.db.QueryRow("SELECT COUNT(*) FROM family1___table5") - var cnt sql.NullInt64 - err = row.Scan(&cnt) - if err != nil { - t.Fatalf("Unexpected error scanning result: %+v", err) - } + row := u.db.QueryRow("SELECT COUNT(*) FROM family1___table5") + var cnt sql.NullInt64 + err = row.Scan(&cnt) + if err != nil { + t.Fatalf("Unexpected error scanning result: %+v", err) + } - if want, got := 0, cnt; !got.Valid || int(got.Int64) != want { - t.Errorf("Expected %+v, got %+v", want, got) - } + if want, got := 0, cnt; !got.Valid || int(got.Int64) != want { + t.Errorf("Expected %+v, got %+v", want, got) + } + }) } -func testDBExecutiveReadFamilyTableNames(t *testing.T, dbType string) { - if dbType != "mysql" { - t.Skip("skipping test when db is not mysql") - } - - u := newDbExecTestUtil(t, dbType) - defer u.Close() +func TestDBExecutiveReadFamilyTableNames(t *testing.T) { + withDBTypes(t, func(dbType string) { + if dbType != "mysql" { + t.Skip("skipping test when db is not mysql") + } - tables, err := u.e.ReadFamilyTableNames(schema.FamilyName{Name: "family1"}) - if err != nil { - t.Fatalf("Unexpected error calling Reading Family Table Names: %+v", err) - } + u := newDbExecTestUtil(t, dbType) + defer u.Close() - if want, got := 5, len(tables); got != want { - t.Errorf("Expected %+v tables, got %+v", want, got) - } + tables, err := u.e.ReadFamilyTableNames(schema.FamilyName{Name: "family1"}) + if err != nil { + t.Fatalf("Unexpected error calling Reading Family Table Names: %+v", err) + } - sort.Slice(tables, func(i, j int) bool { - return tables[i].Table < tables[j].Table - }) + if want, got := 5, len(tables); got != want { + t.Errorf("Expected %+v tables, got %+v", want, got) + } - expected := []schema.FamilyTable{ - { - Family: "family1", - Table: "binary_table1", - }, - { - Family: "family1", - Table: "table1", - }, - { - Family: "family1", - Table: "table10", - }, - { - Family: "family1", - Table: "table100", - }, - { - Family: "family1", - Table: "table11", - }, - } + sort.Slice(tables, func(i, j int) bool { + return tables[i].Table < tables[j].Table + }) - for i, table := range tables { - if table.Family != "family1" { - t.Errorf("Expected family1, got %+v", table.Family) + expected := []schema.FamilyTable{ + { + Family: "family1", + Table: "binary_table1", + }, + { + Family: "family1", + Table: "table1", + }, + { + Family: "family1", + Table: "table10", + }, + { + Family: "family1", + Table: "table100", + }, + { + Family: "family1", + Table: "table11", + }, } - if table.Table != expected[i].Table { - t.Errorf("Invalid table name. Expected %s, got %s,", expected[i].Table, table.Table) + + for i, table := range tables { + if table.Family != "family1" { + t.Errorf("Expected family1, got %+v", table.Family) + } + if table.Table != expected[i].Table { + t.Errorf("Invalid table name. Expected %s, got %s,", expected[i].Table, table.Table) + } } - } + }) } diff --git a/pkg/executive/db_limiter.go b/pkg/executive/db_limiter.go index 4fcb1cff..bbc7c0b3 100644 --- a/pkg/executive/db_limiter.go +++ b/pkg/executive/db_limiter.go @@ -3,10 +3,11 @@ package executive import ( "context" "database/sql" + "errors" + "fmt" "sync" "time" - "github.com/pkg/errors" "github.com/segmentio/ctlstore/pkg/errs" "github.com/segmentio/ctlstore/pkg/limits" "github.com/segmentio/ctlstore/pkg/schema" @@ -55,11 +56,11 @@ func newDBLimiter(db *sql.DB, dbType string, defaultTableLimit limits.SizeLimits // request that includes some tables that are not over their limits. func (l *dbLimiter) allowed(ctx context.Context, tx *sql.Tx, lr limiterRequest) (bool, error) { if err := l.checkTableSizes(ctx, lr); err != nil { - return false, errors.Wrap(err, "check table sizes") + return false, fmt.Errorf("check table sizes: %w", err) } allowed, err := l.checkWriterRates(ctx, tx, lr) if err != nil { - return false, errors.Wrap(err, "check writer rates") + return false, fmt.Errorf("check writer rates: %w", err) } return allowed, nil } @@ -80,7 +81,7 @@ func (l *dbLimiter) checkWriterRates(ctx context.Context, tx *sql.Tx, lr limiter var amount int64 err := row.Scan(&amount) if err != nil && err != sql.ErrNoRows { - return false, errors.Wrap(err, "select from writer_usage") + return false, fmt.Errorf("select from writer_usage: %w", err) } amount += int64(numMutations) if err == sql.ErrNoRows { @@ -88,11 +89,11 @@ func (l *dbLimiter) checkWriterRates(ctx context.Context, tx *sql.Tx, lr limiter res, err := tx.ExecContext(ctx, "INSERT INTO writer_usage (bucket,writer_name,amount) VALUES (?,?,?)", bucket, lr.writerName, amount) if err != nil { - return false, errors.Wrap(err, "insert into writer_usage") + return false, fmt.Errorf("insert into writer_usage: %w", err) } rowsAffected, err := res.RowsAffected() if err != nil { - return false, errors.Wrap(err, "affected rows from insert into writer_usage") + return false, fmt.Errorf("affected rows from insert into writer_usage: %w", err) } if rowsAffected == 0 { return false, errors.New("insert into writer_usage failed (no rows updated)") @@ -102,11 +103,11 @@ func (l *dbLimiter) checkWriterRates(ctx context.Context, tx *sql.Tx, lr limiter res, err := tx.ExecContext(ctx, "UPDATE writer_usage SET amount=? where bucket=? and writer_name=?", amount, bucket, lr.writerName) if err != nil { - return false, errors.Wrap(err, "update writer_usage") + return false, fmt.Errorf("update writer_usage: %w", err) } rowsAffected, err := res.RowsAffected() if err != nil { - return false, errors.Wrap(err, "affected rows from update writer_usage") + return false, fmt.Errorf("affected rows from update writer_usage: %w", err) } if rowsAffected == 0 { return false, errors.New("updating writer_usage failed (no rows updated)") @@ -138,7 +139,7 @@ func (l *dbLimiter) checkTableSizes(ctx context.Context, lr limiterRequest) erro func (l *dbLimiter) start(ctx context.Context) error { events.Log("Starting the db limiter") if err := l.tableSizer.start(ctx); err != nil { - return errors.Wrap(err, "could not start sizer") + return fmt.Errorf("could not start sizer: %w", err) } instrumentUpdateErr := func(err error) { errs.IncrDefault(stats.Tag{Name: "op", Value: "update-limits"}) @@ -146,7 +147,7 @@ func (l *dbLimiter) start(ctx context.Context) error { // we always require an initial update of limit config from the db if err := l.refreshWriterLimits(ctx); err != nil { instrumentUpdateErr(err) - return errors.Wrap(err, "refresh writer limits") + return fmt.Errorf("refresh writer limits: %w", err) } // after we've done one refreshWriterLimits successfully, we'll do the rest async go utils.CtxLoop(ctx, defaultRefreshPeriod, func() { @@ -171,11 +172,11 @@ func (l *dbLimiter) deleteOldUsageData(ctx context.Context) error { res, err := l.db.ExecContext(ctx, "delete from writer_usage where bucket < ?", deleteEpoch) if err != nil { - return errors.Wrap(err, "could not delete from writer_usage table") + return fmt.Errorf("could not delete from writer_usage table: %w", err) } rows, err := res.RowsAffected() if err != nil { - return errors.Wrap(err, "could not get rows affected after deleting from writer_usage") + return fmt.Errorf("could not get rows affected after deleting from writer_usage: %w", err) } if rows > 0 { events.Log("deleted %{rows}d rows from the writer_usage table", rows) @@ -189,7 +190,7 @@ func (l *dbLimiter) deleteOldUsageData(ctx context.Context) error { func (l *dbLimiter) refreshWriterLimits(ctx context.Context) error { rows, err := l.db.QueryContext(ctx, "select writer_name, max_rows_per_minute FROM max_writer_rates") if err != nil { - return errors.Wrap(err, "could not query max_writer_rates") + return fmt.Errorf("could not query max_writer_rates: %w", err) } defer rows.Close() writerLimits := make(map[string]int64) @@ -197,20 +198,20 @@ func (l *dbLimiter) refreshWriterLimits(ctx context.Context) error { var writerName string var maxRowsPerMinute int64 if err = rows.Scan(&writerName, &maxRowsPerMinute); err != nil { - return errors.Wrap(err, "could not scan max_writer_rates") + return fmt.Errorf("could not scan max_writer_rates: %w", err) } // we need to convert the max rows per minute to the rate for the period which we're checking rateLimit := limits.RateLimit{Amount: maxRowsPerMinute, Period: time.Minute} adjustedRate, err := rateLimit.AdjustAmount(l.defaultWriterLimit.Period) if err != nil { - return errors.Wrap(err, "adjust found rate limit") + return fmt.Errorf("adjust found rate limit: %w", err) } events.Debug("adjusted %v limit from %v/%v to %v/%v", writerName, maxRowsPerMinute, time.Minute, adjustedRate, l.defaultWriterLimit.Period) writerLimits[writerName] = adjustedRate } if err := rows.Err(); err != nil { - return errors.Wrap(err, "rows err after scanning") + return fmt.Errorf("rows err after scanning: %w", err) } // update the shared data while locked l.mut.Lock() diff --git a/pkg/executive/db_limiter_test.go b/pkg/executive/db_limiter_test.go index 2af65075..068b1956 100644 --- a/pkg/executive/db_limiter_test.go +++ b/pkg/executive/db_limiter_test.go @@ -21,133 +21,133 @@ import ( "github.com/stretchr/testify/require" ) -// testDBLimiter is run from TestAllDBExecutive -func testDBLimiter(t *testing.T, dbType string) { - u := newDbExecTestUtil(t, dbType) - ctldb := u.db - - ctx, cancel := context.WithCancel(u.ctx) - defer cancel() - - const ( - familyName = "db_limiter_family" - tableName = "db_limiter_table" - bucketInterval = 5 * time.Second - writerLimit = 5 // per bucket interval - writerName = "db-limiter-writer-name" - writerSecret = "db-limiter-writer-secret" - ) - - // we control the time using a fakeTime with an epoch of 1000s - fakeTime := newFakeTime(1000) - defaultTableLimit := limits.SizeLimits{MaxSize: 30 * units.KILOBYTE, WarnSize: 20 * units.KILOBYTE} - limiter := newDBLimiter(ctldb, dbType, defaultTableLimit, bucketInterval, writerLimit) - limiter.timeFunc = fakeTime.get - require.NoError(t, limiter.tableSizer.refresh(ctx)) - require.NoError(t, u.e.CreateFamily(familyName)) - executive := &executiveService{ctldb: u.db, ctx: ctx, limiter: limiter, serveTimeout: 10 * time.Second} - - fieldNames := []string{"name", "data"} - fieldTypes := []schema.FieldType{schema.FTString, schema.FTBinary} - keyFields := []string{"name"} - require.NoError(t, u.e.CreateTable(familyName, tableName, fieldNames, fieldTypes, keyFields)) - require.NoError(t, u.e.RegisterWriter(writerName, writerSecret)) - - payloadFunc := newMutationPayload(t, tableName, "test", int(10*units.KILOBYTE)) - - // makeMutation is a func that performs a mutation supplied by the payloadFunc. it fails the test if it - // fails, so no need to return a value. - makeMutation := func(expectedCode int) { - req := httptest.NewRequest("POST", "/families/"+familyName+"/mutations", payloadFunc()) - req.Header.Set("ctlstore-writer", writerName) - req.Header.Set("ctlstore-secret", writerSecret) - w := httptest.NewRecorder() - executive.ServeHTTP(w, req) - resp := w.Result() - defer resp.Body.Close() - if expectedCode != resp.StatusCode { - b, _ := ioutil.ReadAll(resp.Body) - require.Failf(t, "request failed", "Expected %d, got %d: %s", expectedCode, resp.StatusCode, b) +func TestDBLimiter(t *testing.T) { + withDBTypes(t, func(dbType string) { + u := newDbExecTestUtil(t, dbType) + ctldb := u.db + + ctx, cancel := context.WithCancel(u.ctx) + defer cancel() + + const ( + familyName = "db_limiter_family" + tableName = "db_limiter_table" + bucketInterval = 5 * time.Second + writerLimit = 5 // per bucket interval + writerName = "db-limiter-writer-name" + writerSecret = "db-limiter-writer-secret" + ) + + // we control the time using a fakeTime with an epoch of 1000s + fakeTime := newFakeTime(1000) + defaultTableLimit := limits.SizeLimits{MaxSize: 30 * units.KILOBYTE, WarnSize: 20 * units.KILOBYTE} + limiter := newDBLimiter(ctldb, dbType, defaultTableLimit, bucketInterval, writerLimit) + limiter.timeFunc = fakeTime.get + require.NoError(t, limiter.tableSizer.refresh(ctx)) + require.NoError(t, u.e.CreateFamily(familyName)) + executive := &executiveService{ctldb: u.db, ctx: ctx, limiter: limiter, serveTimeout: 10 * time.Second} + + fieldNames := []string{"name", "data"} + fieldTypes := []schema.FieldType{schema.FTString, schema.FTBinary} + keyFields := []string{"name"} + require.NoError(t, u.e.CreateTable(familyName, tableName, fieldNames, fieldTypes, keyFields)) + require.NoError(t, u.e.RegisterWriter(writerName, writerSecret)) + + payloadFunc := newMutationPayload(t, tableName, "test", int(10*units.KILOBYTE)) + + // makeMutation is a func that performs a mutation supplied by the payloadFunc. it fails the test if it + // fails, so no need to return a value. + makeMutation := func(expectedCode int) { + req := httptest.NewRequest("POST", "/families/"+familyName+"/mutations", payloadFunc()) + req.Header.Set("ctlstore-writer", writerName) + req.Header.Set("ctlstore-secret", writerSecret) + w := httptest.NewRecorder() + executive.ServeHTTP(w, req) + resp := w.Result() + defer resp.Body.Close() + if expectedCode != resp.StatusCode { + b, _ := ioutil.ReadAll(resp.Body) + require.Failf(t, "request failed", "Expected %d, got %d: %s", expectedCode, resp.StatusCode, b) + } } - } - - // do the first mutation to create the table - makeMutation(http.StatusOK) - // let the table sizer find the table so it stops logging about it - require.NoError(t, limiter.tableSizer.refresh(ctx)) - // at this point, we're at epoch 1000 and we've written 1 row. our limit is 5 per period, so let's - // fill out the rest of our quota - for i := 0; i < 4; i++ { + // do the first mutation to create the table makeMutation(http.StatusOK) - } + // let the table sizer find the table so it stops logging about it + require.NoError(t, limiter.tableSizer.refresh(ctx)) - // if we do another mutation it should fail - makeMutation(http.StatusTooManyRequests) + // at this point, we're at epoch 1000 and we've written 1 row. our limit is 5 per period, so let's + // fill out the rest of our quota + for i := 0; i < 4; i++ { + makeMutation(http.StatusOK) + } - // shift the epoch up by $period - fakeTime.add(int64(bucketInterval / time.Second)) + // if we do another mutation it should fail + makeMutation(http.StatusTooManyRequests) - // should then be able to store 5 more but no more than that - for i := 0; i < 5; i++ { - makeMutation(http.StatusOK) - } - makeMutation(http.StatusTooManyRequests) + // shift the epoch up by $period + fakeTime.add(int64(bucketInterval / time.Second)) - // leaving the epoch where it is, let's add a per-writer override - _, err := u.db.ExecContext(ctx, "insert into max_writer_rates (writer_name, max_rows_per_minute) values(?,?)", - writerName, 120) // gets converted from 120/min -> 10/5s - require.NoError(t, err) - require.NoError(t, limiter.refreshWriterLimits(ctx)) + // should then be able to store 5 more but no more than that + for i := 0; i < 5; i++ { + makeMutation(http.StatusOK) + } + makeMutation(http.StatusTooManyRequests) - // we should be able to make five writes now before it fails - for i := 0; i < 5; i++ { - makeMutation(http.StatusOK) - } - makeMutation(http.StatusTooManyRequests) + // leaving the epoch where it is, let's add a per-writer override + _, err := u.db.ExecContext(ctx, "insert into max_writer_rates (writer_name, max_rows_per_minute) values(?,?)", + writerName, 120) // gets converted from 120/min -> 10/5s + require.NoError(t, err) + require.NoError(t, limiter.refreshWriterLimits(ctx)) - // now with the override still in place, bump the epoch and observe that writes go through - // since we're in a new bucket - fakeTime.add(int64(bucketInterval / time.Second)) - makeMutation(http.StatusOK) + // we should be able to make five writes now before it fails + for i := 0; i < 5; i++ { + makeMutation(http.StatusOK) + } + makeMutation(http.StatusTooManyRequests) - // finally, refresh the table sizer. this will pick up the new limits and should deny the - // request (if on mysql) - require.NoError(t, limiter.tableSizer.refresh(ctx)) - if dbType == "sqlite3" { + // now with the override still in place, bump the epoch and observe that writes go through + // since we're in a new bucket + fakeTime.add(int64(bucketInterval / time.Second)) makeMutation(http.StatusOK) - } else { - makeMutation(http.StatusInsufficientStorage) - } - countRows := func() int64 { - row := u.db.QueryRowContext(ctx, "select count(*) from writer_usage") - var count int64 - err = row.Scan(&count) - require.NoError(t, err) - return count - } + // finally, refresh the table sizer. this will pick up the new limits and should deny the + // request (if on mysql) + require.NoError(t, limiter.tableSizer.refresh(ctx)) + if dbType == "sqlite3" { + makeMutation(http.StatusOK) + } else { + makeMutation(http.StatusInsufficientStorage) + } - // verify we have rows in the writer_usage table - count := countRows() - require.True(t, count > 0, "unexpected count: %d", count) + countRows := func() int64 { + row := u.db.QueryRowContext(ctx, "select count(*) from writer_usage") + var count int64 + err = row.Scan(&count) + require.NoError(t, err) + return count + } - // verify that the cleaner leaves our rows alone since we are still at the same epoch - err = limiter.deleteOldUsageData(ctx) - require.NoError(t, err) - require.EqualValues(t, count, countRows()) + // verify we have rows in the writer_usage table + count := countRows() + require.True(t, count > 0, "unexpected count: %d", count) - // set the time to now so that we can trigger the cleaner to clean up our old rows - fakeTime.set(time.Now().Unix()) + // verify that the cleaner leaves our rows alone since we are still at the same epoch + err = limiter.deleteOldUsageData(ctx) + require.NoError(t, err) + require.EqualValues(t, count, countRows()) - // perform a cleanup - err = limiter.deleteOldUsageData(ctx) - require.NoError(t, err) + // set the time to now so that we can trigger the cleaner to clean up our old rows + fakeTime.set(time.Now().Unix()) - // verify no more - count = countRows() - require.EqualValues(t, 0, count) + // perform a cleanup + err = limiter.deleteOldUsageData(ctx) + require.NoError(t, err) + // verify no more + count = countRows() + require.EqualValues(t, 0, count) + }) } // newMutationPayload is a helper that produces a func that produces a reader that supplies a diff --git a/pkg/executive/executive.go b/pkg/executive/executive.go index db6826fd..0689ac6e 100644 --- a/pkg/executive/executive.go +++ b/pkg/executive/executive.go @@ -1,7 +1,8 @@ package executive import ( - "github.com/pkg/errors" + "fmt" + "github.com/segmentio/ctlstore/pkg/limits" "github.com/segmentio/ctlstore/pkg/schema" ) @@ -85,7 +86,7 @@ func (r *mutationRequest) valuesByOrder(fieldOrder []schema.FieldName) ([]interf if v, ok := r.Values[fn]; ok { values = append(values, v) } else { - return nil, errors.Errorf("Missing field %s", fn) + return nil, fmt.Errorf("Missing field %s", fn) } } return values, nil diff --git a/pkg/executive/executive_endpoint.go b/pkg/executive/executive_endpoint.go index f0090fe0..50c70c11 100644 --- a/pkg/executive/executive_endpoint.go +++ b/pkg/executive/executive_endpoint.go @@ -2,6 +2,7 @@ package executive import ( "encoding/json" + "errors" "io/ioutil" "net/http" "strconv" @@ -13,7 +14,7 @@ import ( "github.com/segmentio/stats/v4" "github.com/gorilla/mux" - "github.com/pkg/errors" + "github.com/segmentio/events/v2" ) @@ -186,7 +187,7 @@ func (ee *ExecutiveEndpoint) handleTableSchemaRoute(w http.ResponseWriter, r *ht switch { case err == nil: // do nothing, no error - case errors.Cause(err) == ErrTableDoesNotExist: + case errors.Is(err, ErrTableDoesNotExist): http.Error(w, err.Error(), http.StatusNotFound) return default: @@ -573,29 +574,16 @@ func writeErrorResponse(e error, w http.ResponseWriter) { status := http.StatusInternalServerError resBody := e.Error() - cause := errors.Cause(e) - // first check for generic error values - switch cause { - case ErrWriterAlreadyExists: + switch { + case errors.Is(e, ErrWriterAlreadyExists): status = http.StatusConflict default: - // if no generic error values matched, check the error types as well - switch cause.(type) { - case *errs.ConflictError: - status = http.StatusConflict - case *errs.BadRequestError: - status = http.StatusBadRequest - case *errs.NotFoundError: - status = http.StatusNotFound - case *errs.RateLimitExceededErr: - status = http.StatusTooManyRequests - case *errs.InsufficientStorageErr: - status = http.StatusInsufficientStorage - default: - status = http.StatusInternalServerError + var coder errs.StatusCoder + if errors.As(e, &coder) { + status = coder.StatusCode() } - } + w.WriteHeader(status) _, _ = w.Write([]byte(resBody)) diff --git a/pkg/executive/executive_endpoint_test.go b/pkg/executive/executive_endpoint_test.go index fb566964..67ce3abc 100644 --- a/pkg/executive/executive_endpoint_test.go +++ b/pkg/executive/executive_endpoint_test.go @@ -3,14 +3,14 @@ package executive_test import ( "bytes" "encoding/json" + "errors" + "fmt" "net/http" "net/http/httptest" "reflect" "testing" "time" - "github.com/pkg/errors" - "github.com/google/go-cmp/cmp" "github.com/segmentio/ctlstore/pkg/errs" "github.com/segmentio/ctlstore/pkg/executive" @@ -923,7 +923,7 @@ func TestExecEndpointHandler(_t *testing.T) { Method: http.MethodGet, ExpectedStatusCode: http.StatusNotFound, PreFunc: func(t *testing.T, atom *testExecEndpointHandlerAtom) { - atom.ei.TableSchemaReturns(nil, errors.Wrap(executive.ErrTableDoesNotExist, "boom")) + atom.ei.TableSchemaReturns(nil, fmt.Errorf("boom: %w", executive.ErrTableDoesNotExist)) }, PostFunc: func(t *testing.T, atom *testExecEndpointHandlerAtom) { require.EqualValues(t, 1, atom.ei.TableSchemaCallCount()) diff --git a/pkg/executive/executive_service.go b/pkg/executive/executive_service.go index e54c4b9f..1ecb98d7 100644 --- a/pkg/executive/executive_service.go +++ b/pkg/executive/executive_service.go @@ -11,7 +11,6 @@ import ( "syscall" "time" - "github.com/pkg/errors" ctldbpkg "github.com/segmentio/ctlstore/pkg/ctldb" "github.com/segmentio/ctlstore/pkg/errs" "github.com/segmentio/ctlstore/pkg/limits" @@ -90,7 +89,7 @@ func (s *executiveService) Start(ctx context.Context, bind string) error { // tell the limiter to start picking up db changes if err := s.limiter.start(ctx); err != nil { - return errors.Wrap(err, "could not start limiter") + return fmt.Errorf("could not start limiter: %w", err) } // perform instrumentation in the background diff --git a/pkg/executive/mutators_store.go b/pkg/executive/mutators_store.go index 51d022dd..e467d435 100644 --- a/pkg/executive/mutators_store.go +++ b/pkg/executive/mutators_store.go @@ -7,9 +7,9 @@ import ( "database/sql" "encoding/base64" "encoding/hex" + "errors" "fmt" - "github.com/pkg/errors" "github.com/segmentio/ctlstore/pkg/limits" "github.com/segmentio/ctlstore/pkg/schema" "github.com/segmentio/ctlstore/pkg/sqlgen" @@ -51,7 +51,7 @@ func (ms *mutatorStore) Register(writerName schema.WriterName, writerSecret stri case err == sql.ErrNoRows: // this is OK, it just means that it wasn't found case err != nil: - return errors.Wrap(err, "select from mutators") + return fmt.Errorf("select from mutators: %w", err) case count == 1: // writer already exists with this secret return nil @@ -70,7 +70,7 @@ func (ms *mutatorStore) Exists(writerName schema.WriterName) (bool, error) { row := ms.DB.QueryRowContext(ms.Ctx, qs, writerName.Name) var count int64 if err := row.Scan(&count); err != nil { - return false, errors.Wrap(err, "scan writer count") + return false, fmt.Errorf("scan writer count: %w", err) } return count > 0, nil } diff --git a/pkg/executive/status_writer.go b/pkg/executive/status_writer.go index a366d90f..d95cf136 100644 --- a/pkg/executive/status_writer.go +++ b/pkg/executive/status_writer.go @@ -1,6 +1,8 @@ package executive -import "net/http" +import ( + "net/http" +) type statusWriter struct { writer http.ResponseWriter diff --git a/pkg/executive/table_sizer.go b/pkg/executive/table_sizer.go index a787ee91..95e9e064 100644 --- a/pkg/executive/table_sizer.go +++ b/pkg/executive/table_sizer.go @@ -7,7 +7,6 @@ import ( "sync" "time" - "github.com/pkg/errors" "github.com/segmentio/ctlstore/pkg/errs" "github.com/segmentio/ctlstore/pkg/limits" "github.com/segmentio/ctlstore/pkg/schema" @@ -126,11 +125,11 @@ func (s *tableSizer) refresh(ctx context.Context) error { defer cancel() sizes, err := s.getSizes(ctx) if err != nil { - return errors.Wrap(err, "get table sizes") + return fmt.Errorf("get table sizes: %w", err) } configuredLimits, err := s.getLimits(ctx) if err != nil { - return errors.Wrap(err, "get configured table limits") + return fmt.Errorf("get configured table limits: %w", err) } s.mut.Lock() defer s.mut.Unlock() @@ -168,7 +167,7 @@ func (s *tableSizer) getSizes(ctx context.Context) (map[schema.FamilyTable]int64 } dbSchema, err := s.getSchema(ctx) if err != nil { - return nil, errors.Wrap(err, "get schema") + return nil, fmt.Errorf("get schema: %w", err) } query := "SELECT table_name, (data_length + index_length) FROM information_schema.tables WHERE table_schema=?" rows, err := s.ctldb.QueryContext(ctx, query, dbSchema) diff --git a/pkg/globalstats/stats.go b/pkg/globalstats/stats.go index 0be6f196..6aa0fa11 100644 --- a/pkg/globalstats/stats.go +++ b/pkg/globalstats/stats.go @@ -4,13 +4,13 @@ package globalstats import ( "context" + "errors" "os" "path/filepath" "sync/atomic" "time" "github.com/segmentio/ctlstore/pkg/version" - "github.com/segmentio/errors-go" "github.com/segmentio/events/v2" "github.com/segmentio/stats/v4" ) diff --git a/pkg/globalstats/stats_test.go b/pkg/globalstats/stats_test.go index bdc2b3cd..7537b156 100644 --- a/pkg/globalstats/stats_test.go +++ b/pkg/globalstats/stats_test.go @@ -26,7 +26,7 @@ func newFakeHandler() *fakeHandler { } } -func (h *fakeHandler) HandleMeasures(t time.Time, measures ...stats.Measure) { +func (h *fakeHandler) HandleMeasures(_ time.Time, measures ...stats.Measure) { h.mut.Lock() defer h.mut.Unlock() diff --git a/pkg/heartbeat/heartbeat.go b/pkg/heartbeat/heartbeat.go index fc5c929e..1d6dcd51 100644 --- a/pkg/heartbeat/heartbeat.go +++ b/pkg/heartbeat/heartbeat.go @@ -2,12 +2,12 @@ package heartbeat import ( "context" + "fmt" "io/ioutil" "net/http" "strings" "time" - "github.com/pkg/errors" "github.com/segmentio/ctlstore/pkg/errs" "github.com/segmentio/ctlstore/pkg/utils" "github.com/segmentio/events/v2" @@ -50,7 +50,7 @@ func HeartbeatFromConfig(config HeartbeatConfig) (*Heartbeat, error) { writerSecret: config.WriterSecret, } if err := heartbeat.init(); err != nil { - return nil, errors.Wrap(err, "init heartbeat") + return nil, fmt.Errorf("init heartbeat: %w", err) } return heartbeat, nil } @@ -87,19 +87,19 @@ func (h *Heartbeat) pulse(ctx context.Context) { }) req, err := http.NewRequest(http.MethodPost, h.executive+"/families/"+h.family+"/mutations", body) if err != nil { - return errors.Wrap(err, "build mutation request") + return fmt.Errorf("build mutation request: %w", err) } req.Header.Set("Content-Type", "application/json") req.Header.Set("ctlstore-writer", h.writerName) req.Header.Set("ctlstore-secret", h.writerSecret) resp, err := client.Do(req) if err != nil { - return errors.Wrap(err, "make mutation request") + return fmt.Errorf("make mutation request: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { b, _ := ioutil.ReadAll(resp.Body) - return errors.Errorf("could not make mutation request: %d: %s", resp.StatusCode, b) + return fmt.Errorf("could not make mutation request: %d: %s", resp.StatusCode, b) } events.Log("Heartbeat: %v", heartbeat) return nil @@ -117,28 +117,28 @@ func (h *Heartbeat) init() error { body := strings.NewReader(h.writerSecret) res, err := http.Post(h.executive+"/writers/"+h.writerName, "text/plain", body) if err != nil { - return errors.Wrap(err, "register writer") + return fmt.Errorf("register writer: %w", err) } defer res.Body.Close() if res.StatusCode != http.StatusOK { b, _ := ioutil.ReadAll(res.Body) - return errors.Errorf("could not register writer: %d: %s", res.StatusCode, b) + return fmt.Errorf("could not register writer: %d: %s", res.StatusCode, b) } // setup the family ------------ req, err := http.NewRequest(http.MethodPost, h.executive+"/families/"+h.family, nil) if err != nil { - return errors.Wrap(err, "create family request") + return fmt.Errorf("create family request: %w", err) } res, err = client.Do(req) if err != nil { - return errors.Wrap(err, "make family request") + return fmt.Errorf("make family request: %w", err) } defer res.Body.Close() if res.StatusCode != http.StatusOK && res.StatusCode != http.StatusConflict { b, _ := ioutil.ReadAll(res.Body) - return errors.Errorf("could not make family request: %v: %s", res.StatusCode, b) + return fmt.Errorf("could not make family request: %v: %s", res.StatusCode, b) } // setup the table ------------- @@ -152,17 +152,17 @@ func (h *Heartbeat) init() error { } req, err = http.NewRequest(http.MethodPost, h.executive+"/families/"+h.family+"/tables/"+h.table, utils.NewJsonReader(tableDef)) if err != nil { - return errors.Wrap(err, "create table request") + return fmt.Errorf("create table request: %w", err) } req.Header.Set("Content-Type", "application/json") res, err = client.Do(req) if err != nil { - return errors.Wrap(err, "make table request") + return fmt.Errorf("make table request: %w", err) } defer res.Body.Close() if res.StatusCode != http.StatusOK && res.StatusCode != http.StatusConflict { b, _ := ioutil.ReadAll(res.Body) - return errors.Errorf("could not make table request: %v: %s", res.StatusCode, b) + return fmt.Errorf("could not make table request: %v: %s", res.StatusCode, b) } return nil diff --git a/pkg/ldbwriter/ldb_writer.go b/pkg/ldbwriter/ldb_writer.go index 44824603..07482467 100644 --- a/pkg/ldbwriter/ldb_writer.go +++ b/pkg/ldbwriter/ldb_writer.go @@ -3,6 +3,7 @@ package ldbwriter import ( "context" "database/sql" + "errors" "fmt" "github.com/pkg/errors" @@ -54,7 +55,7 @@ func (writer *SqlLdbWriter) ApplyDMLStatement(_ context.Context, statement schem tx, err = writer.Db.Begin() if err != nil { errs.Incr("sql_ldb_writer.begin_tx.error") - return errors.Wrap(err, "open tx error") + return fmt.Errorf("open tx error: %w", err) } } else { // Applying a ledger transaction, so bring it into scope @@ -85,7 +86,7 @@ func (writer *SqlLdbWriter) ApplyDMLStatement(_ context.Context, statement schem if err != nil { tx.Rollback() errs.Incr("sql_ldb_writer.upsert_last_update.error") - return errors.Wrap(err, "update last_update") + return fmt.Errorf("update last_update: %w", err) } // Update the sequence tracker row. This SQL will insert the row @@ -104,7 +105,7 @@ func (writer *SqlLdbWriter) ApplyDMLStatement(_ context.Context, statement schem if err != nil { tx.Rollback() errs.Incr("sql_ldb_writer.upsert_seq.error") - return errors.Wrap(err, "update seq tracker error") + return fmt.Errorf("update seq tracker error: %w", err) } // Check for replayed statements @@ -112,7 +113,7 @@ func (writer *SqlLdbWriter) ApplyDMLStatement(_ context.Context, statement schem if err != nil { tx.Rollback() errs.Incr("sql_ldb_writer.upsert_seq.rows_affected_error") - return errors.Wrap(err, "update seq tracker rows affected error") + return fmt.Errorf("update seq tracker rows affected error: %w", err) } if rowsAffected == 0 { tx.Rollback() @@ -146,7 +147,7 @@ func (writer *SqlLdbWriter) ApplyDMLStatement(_ context.Context, statement schem events.Log("Failed to commit Tx at seq %{seq}s: %{error}+v", statement.Sequence, err) - return errors.Wrap(err, "commit multi-statement dml tx error") + return fmt.Errorf("commit multi-statement dml tx error: %w", err) } stats.Incr("sql_ldb_writer.ledgerTx.commit.success") events.Debug("Committed TX at %{sequence}v", statement.Sequence) @@ -159,7 +160,7 @@ func (writer *SqlLdbWriter) ApplyDMLStatement(_ context.Context, statement schem if err != nil { tx.Rollback() errs.Incr("sql_ldb_writer.exec.error") - return errors.Wrap(err, "exec dml statement error") + return fmt.Errorf("exec dml statement error: %w", err) } stats.Incr("sql_ldb_writer.exec.success") @@ -176,7 +177,7 @@ func (writer *SqlLdbWriter) ApplyDMLStatement(_ context.Context, statement schem tx.Rollback() errs.Incr("sql_ldb_writer.single.commit.error") errs.Incr("sql_ldb_writer.commit.error") - return errors.Wrap(err, "commit one-statement dml tx error") + return fmt.Errorf("commit one-statement dml tx error: %w", err) } } diff --git a/pkg/ledger/ecs_client.go b/pkg/ledger/ecs_client.go index 1533a894..253e4ad4 100644 --- a/pkg/ledger/ecs_client.go +++ b/pkg/ledger/ecs_client.go @@ -1,6 +1,8 @@ package ledger -import "github.com/aws/aws-sdk-go/service/ecs/ecsiface" +import ( + "github.com/aws/aws-sdk-go/service/ecs/ecsiface" +) //counterfeiter:generate -o fakes/ecs_client.go . ECSClient type ECSClient interface { diff --git a/pkg/ledger/ecs_metadata.go b/pkg/ledger/ecs_metadata.go index cc3a8ee5..03634eed 100644 --- a/pkg/ledger/ecs_metadata.go +++ b/pkg/ledger/ecs_metadata.go @@ -1,9 +1,8 @@ package ledger import ( + "fmt" "strings" - - "github.com/pkg/errors" ) type EcsMetadata struct { @@ -17,7 +16,7 @@ type EcsMetadata struct { func (m EcsMetadata) accountID() (string, error) { parts := strings.Split(m.ContainerInstanceArn, ":") if len(parts) != 6 { - return "", errors.Errorf("invalid container instance arn: '%s'", m.ContainerInstanceArn) + return "", fmt.Errorf("invalid container instance arn: '%s'", m.ContainerInstanceArn) } return parts[4], nil } diff --git a/pkg/ledger/errors_is_test.go b/pkg/ledger/errors_is_test.go new file mode 100644 index 00000000..9387b677 --- /dev/null +++ b/pkg/ledger/errors_is_test.go @@ -0,0 +1,15 @@ +package ledger + +import ( + "errors" + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestTemporaryAs(t *testing.T) { + var err error = temporaryError{errors.New("boom")} + err = fmt.Errorf("wrapped: %w", err) + require.True(t, errors.Is(err, temporaryError{})) +} diff --git a/pkg/ledger/ledger_monitor.go b/pkg/ledger/ledger_monitor.go index 4354609d..11acfdb0 100644 --- a/pkg/ledger/ledger_monitor.go +++ b/pkg/ledger/ledger_monitor.go @@ -3,6 +3,7 @@ package ledger import ( "context" "encoding/json" + "errors" "fmt" "io/ioutil" "net/http" @@ -12,7 +13,6 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/ecs" - "github.com/segmentio/errors-go" "github.com/segmentio/events/v2" "github.com/segmentio/stats/v4" @@ -77,7 +77,7 @@ func (m *Monitor) Start(ctx context.Context) { err := func() error { latency, err := m.latencyFunc(ctx) if err != nil { - return errors.Wrap(err, "get ledger latency") + return fmt.Errorf("get ledger latency: %w", err) } // always instrument ledger latency even if ECS behavior is disabled. stats.Set("reflector-ledger-latency", latency) @@ -86,13 +86,13 @@ func (m *Monitor) Start(ctx context.Context) { case latency <= m.cfg.MaxHealthyLatency && (health == nil || *health != true): // set a healthy attribute if err := m.setHealthAttribute(ctx, m.cfg.HealthyAttributeValue); err != nil { - return errors.Wrap(err, "set healthy") + return fmt.Errorf("set healthy: %w", err) } health = pointer.ToBool(true) case latency > m.cfg.MaxHealthyLatency && (health == nil || *health != false): // set an unhealthy attribute if err := m.setHealthAttribute(ctx, m.cfg.UnhealthyAttributeValue); err != nil { - return errors.Wrap(err, "set unhealthy") + return fmt.Errorf("set unhealthy: %w", err) } health = pointer.ToBool(false) } @@ -110,8 +110,8 @@ func (m *Monitor) Start(ctx context.Context) { switch { case err == nil: case errs.IsCanceled(err): - // context is done, just let it fall through - case errors.Is("temporary", err) && temporaryErrorLimit > 0: + // context is done, just let it fall through + case errors.Is(err, temporaryError{}): // don't increment error metric for a temporary error temporaryErrorLimit-- events.Log("Temporary monitor ledger latency error: %s", err) @@ -128,11 +128,11 @@ func (m *Monitor) setHealthAttribute(ctx context.Context, attrValue string) erro events.Log("Setting ECS instance attribute: %s=%s", m.cfg.AttributeName, attrValue) ecsMeta, err := m.getECSMetadata(ctx) if err != nil { - return errors.Wrap(err, "get ecs metadata") + return fmt.Errorf("get ecs metadata: %w", err) } clusterARN, err := m.buildClusterARN(ecsMeta) if err != nil { - return errors.Wrap(err, "build cluster ARN") + return fmt.Errorf("build cluster ARN: %w", err) } events.Log("Putting attribute name=%{attName}v value=%{attValue}v targetID=%{targetID}v targetType=%{targetType}v", m.cfg.AttributeName, attrValue, ecsMeta.ContainerInstanceArn, ecsContainerInstanceTargetType) @@ -149,7 +149,7 @@ func (m *Monitor) setHealthAttribute(ctx context.Context, attrValue string) erro Cluster: aws.String(clusterARN), }) if err != nil { - return errors.Wrap(err, "put attributes") + return fmt.Errorf("put attributes: %w", err) } return nil } @@ -171,7 +171,7 @@ func (m *Monitor) buildClusterARN(meta EcsMetadata) (arn string, err error) { } accountID, err := meta.accountID() if err != nil { - return errors.Wrap(err, "get account id") + return fmt.Errorf("get account id: %w", err) } cluster := meta.Cluster arn = fmt.Sprintf("arn:aws:ecs:%s:%s:cluster/%s", region, accountID, cluster) @@ -189,17 +189,24 @@ func (m *Monitor) getECSMetadata(ctx context.Context) (meta EcsMetadata, err err if err != nil { // signal that this is a temporary error and we can retry a number of times before // we start reporting errors. - return errors.WithTypes(errors.Wrap(err, "get ecs metadata"), "temporary") + return temporaryError{fmt.Errorf("get ecs metadata: %w", err)} } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { b, _ := ioutil.ReadAll(resp.Body) - return errors.Errorf("could not get ecs metadata: [%d]: %s", resp.StatusCode, b) + return fmt.Errorf("could not get ecs metadata: [%d]: %s", resp.StatusCode, b) } if err = json.NewDecoder(resp.Body).Decode(&meta); err != nil { - return errors.Wrap(err, "read metadata") + return fmt.Errorf("read metadata: %w", err) } return nil }() return meta, err } + +type temporaryError struct{ error } + +func (te temporaryError) Is(err error) bool { + _, ok := err.(temporaryError) + return ok +} diff --git a/pkg/ledger/ledger_monitor_test.go b/pkg/ledger/ledger_monitor_test.go index 0eb861d9..7a6e3f8b 100644 --- a/pkg/ledger/ledger_monitor_test.go +++ b/pkg/ledger/ledger_monitor_test.go @@ -2,13 +2,14 @@ package ledger_test import ( "context" + "errors" "sync" "testing" "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ecs" - "github.com/pkg/errors" + "github.com/segmentio/ctlstore/pkg/ledger" "github.com/segmentio/ctlstore/pkg/ledger/fakes" _ "github.com/segmentio/events/v2/log" diff --git a/pkg/ledger/opts.go b/pkg/ledger/opts.go index 19467120..aa618d2f 100644 --- a/pkg/ledger/opts.go +++ b/pkg/ledger/opts.go @@ -1,6 +1,8 @@ package ledger -import "time" +import ( + "time" +) func WithCheckCallback(fn func()) MonitorOpt { return func(m *Monitor) { diff --git a/pkg/limits/limits.go b/pkg/limits/limits.go index ef192167..577e6336 100644 --- a/pkg/limits/limits.go +++ b/pkg/limits/limits.go @@ -2,11 +2,11 @@ package limits import ( "encoding/json" + "errors" "fmt" "math" "time" - "github.com/pkg/errors" "github.com/segmentio/ctlstore/pkg/units" ) @@ -70,7 +70,7 @@ func (l *RateLimit) UnmarshalJSON(b []byte) error { case float64: l.Amount = int64(amount) default: - return errors.Errorf("invalid amount: '%v'", amount) + return fmt.Errorf("invalid amount: '%v'", amount) } } if period, ok := val["period"]; ok { @@ -80,11 +80,11 @@ func (l *RateLimit) UnmarshalJSON(b []byte) error { case string: parsed, err := time.ParseDuration(period) if err != nil { - return errors.Errorf("invalid period: '%v'", period) + return fmt.Errorf("invalid period: '%v'", period) } l.Period = parsed default: - return errors.Errorf("invalid period: '%v'", period) + return fmt.Errorf("invalid period: '%v'", period) } } return nil diff --git a/pkg/limits/limits_test.go b/pkg/limits/limits_test.go index 743b898d..110a4130 100644 --- a/pkg/limits/limits_test.go +++ b/pkg/limits/limits_test.go @@ -2,10 +2,10 @@ package limits import ( "encoding/json" + "errors" "testing" "time" - "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/pkg/logwriter/sized_log_writer.go b/pkg/logwriter/sized_log_writer.go index ee83e01f..538a8c8a 100644 --- a/pkg/logwriter/sized_log_writer.go +++ b/pkg/logwriter/sized_log_writer.go @@ -1,10 +1,9 @@ package logwriter import ( + "errors" "os" "strings" - - "github.com/pkg/errors" ) const sizedLogWriterDefaultMode os.FileMode = 0644 diff --git a/pkg/mysql/mysql_info.go b/pkg/mysql/mysql_info.go index df3b6664..39e2b850 100644 --- a/pkg/mysql/mysql_info.go +++ b/pkg/mysql/mysql_info.go @@ -3,8 +3,8 @@ package mysql import ( "context" "database/sql" + "fmt" - "github.com/pkg/errors" "github.com/segmentio/ctlstore/pkg/schema" "github.com/segmentio/ctlstore/pkg/sqlgen" ) @@ -17,13 +17,13 @@ func (m *MySQLDBInfo) GetAllTables(ctx context.Context) ([]schema.FamilyTable, e var res []schema.FamilyTable rows, err := m.Db.QueryContext(ctx, "select distinct table_name from information_schema.tables order by table_name") if err != nil { - return nil, errors.Wrap(err, "query table names") + return nil, fmt.Errorf("query table names: %w", err) } for rows.Next() { var fullName string err = rows.Scan(&fullName) if err != nil { - return nil, errors.Wrap(err, "scan table name") + return nil, fmt.Errorf("scan table name: %w", err) } if ft, ok := schema.ParseFamilyTable(fullName); ok { res = append(res, ft) diff --git a/pkg/reflector/bootstrap_test.go b/pkg/reflector/bootstrap_test.go index 1fb1d376..8ee8efd3 100644 --- a/pkg/reflector/bootstrap_test.go +++ b/pkg/reflector/bootstrap_test.go @@ -1,6 +1,7 @@ package reflector import ( + "errors" "io" "io/ioutil" "strings" @@ -8,7 +9,6 @@ import ( "time" "github.com/segmentio/ctlstore/pkg/errs" - "github.com/segmentio/errors-go" "github.com/stretchr/testify/require" ) @@ -62,7 +62,7 @@ func TestBoostrapLDB(t *testing.T) { name: "temporary failure", dl: &fakeDownloadTo{ res: []readerErr{ - {e: errors.WithTypes(errors.New("failure"), errs.ErrTypeTemporary)}, + {e: errs.ErrTypeTemporary{errors.New("failure")}}, {r: strings.NewReader(ldbContent)}, }, }, @@ -72,10 +72,10 @@ func TestBoostrapLDB(t *testing.T) { name: "max temporary failures", dl: &fakeDownloadTo{ res: []readerErr{ - {e: errors.WithTypes(errors.New("failure"), errs.ErrTypeTemporary)}, - {e: errors.WithTypes(errors.New("failure"), errs.ErrTypeTemporary)}, - {e: errors.WithTypes(errors.New("failure"), errs.ErrTypeTemporary)}, - {e: errors.WithTypes(errors.New("failure"), errs.ErrTypeTemporary)}, + {e: errs.ErrTypeTemporary{errors.New("failure")}}, + {e: errs.ErrTypeTemporary{errors.New("failure")}}, + {e: errs.ErrTypeTemporary{errors.New("failure")}}, + {e: errs.ErrTypeTemporary{errors.New("failure")}}, {r: strings.NewReader(ldbContent)}, }, }, @@ -85,11 +85,11 @@ func TestBoostrapLDB(t *testing.T) { name: "too many temporary failures", // max retries edge case dl: &fakeDownloadTo{ res: []readerErr{ - {e: errors.WithTypes(errors.New("failure"), errs.ErrTypeTemporary)}, - {e: errors.WithTypes(errors.New("failure"), errs.ErrTypeTemporary)}, - {e: errors.WithTypes(errors.New("failure"), errs.ErrTypeTemporary)}, - {e: errors.WithTypes(errors.New("failure"), errs.ErrTypeTemporary)}, - {e: errors.WithTypes(errors.New("failure"), errs.ErrTypeTemporary)}, + {e: errs.ErrTypeTemporary{errors.New("failure")}}, + {e: errs.ErrTypeTemporary{errors.New("failure")}}, + {e: errs.ErrTypeTemporary{errors.New("failure")}}, + {e: errs.ErrTypeTemporary{errors.New("failure")}}, + {e: errs.ErrTypeTemporary{errors.New("failure")}}, {r: strings.NewReader(ldbContent)}, }, }, @@ -100,7 +100,7 @@ func TestBoostrapLDB(t *testing.T) { name: "permanent failure", dl: &fakeDownloadTo{ res: []readerErr{ - {e: errors.WithTypes(errors.New("failure"), errs.ErrTypePermanent)}, + {e: errs.ErrTypePermanent{errors.New("failure")}}, }, }, fc: "", @@ -109,10 +109,10 @@ func TestBoostrapLDB(t *testing.T) { name: "permanent failure with retries before it", dl: &fakeDownloadTo{ res: []readerErr{ - {e: errors.WithTypes(errors.New("failure"), errs.ErrTypeTemporary)}, - {e: errors.WithTypes(errors.New("failure"), errs.ErrTypeTemporary)}, - {e: errors.WithTypes(errors.New("failure"), errs.ErrTypeTemporary)}, - {e: errors.WithTypes(errors.New("failure"), errs.ErrTypePermanent)}, + {e: errs.ErrTypeTemporary{errors.New("failure")}}, + {e: errs.ErrTypeTemporary{errors.New("failure")}}, + {e: errs.ErrTypeTemporary{errors.New("failure")}}, + {e: errs.ErrTypeTemporary{errors.New("failure")}}, }, }, fc: "", @@ -138,17 +138,3 @@ func TestBoostrapLDB(t *testing.T) { }) } } - -// verifies the wrapping behavior of errors-go -func TestErrorsGoTypes(t *testing.T) { - // verify when the outer error is typed - err := errors.New("root cause") - err = errors.WithTypes(err, "Temporary") - require.True(t, errors.Is("Temporary", err)) - - // verify when the inner error is typed - err = errors.New("root cause") - err = errors.WithTypes(err, "Temporary") - err = errors.Wrap(err, "wrapped") - require.True(t, errors.Is("Temporary", err)) -} diff --git a/pkg/reflector/dml_source.go b/pkg/reflector/dml_source.go index fc3879ec..2eda714a 100644 --- a/pkg/reflector/dml_source.go +++ b/pkg/reflector/dml_source.go @@ -3,10 +3,10 @@ package reflector import ( "context" "database/sql" + "errors" "fmt" "time" - "github.com/pkg/errors" "github.com/segmentio/ctlstore/pkg/schema" "github.com/segmentio/ctlstore/pkg/sqlgen" "github.com/segmentio/stats/v4" @@ -54,7 +54,7 @@ func (source *sqlDmlSource) Next(ctx context.Context) (statement schema.DMLState rows, err := source.db.QueryContext(ctx, qs, source.lastSequence) if err != nil { - return statement, errors.Wrap(err, "select row") + return statement, fmt.Errorf("select row: %w", err) } // CR: reconsider naked returns here @@ -78,7 +78,7 @@ func (source *sqlDmlSource) Next(ctx context.Context) (statement schema.DMLState err = rows.Scan(&row.seq, &row.leaderTs, &row.statement) if err != nil { - return statement, errors.Wrap(err, "scan row") + return statement, fmt.Errorf("scan row: %w", err) } if schema.DMLSequence(row.seq) > source.lastSequence+1 { @@ -87,7 +87,7 @@ func (source *sqlDmlSource) Next(ctx context.Context) (statement schema.DMLState timestamp, err := time.Parse(dmlLedgerTimestampFormat, row.leaderTs) if err != nil { - return statement, errors.Wrapf(err, "could not parse time '%s'", row.leaderTs) + return statement, fmt.Errorf("could not parse time '%s': %w", row.leaderTs, err) } dmlst := schema.DMLStatement{ @@ -106,7 +106,7 @@ func (source *sqlDmlSource) Next(ctx context.Context) (statement schema.DMLState err = rows.Err() if err != nil { - return statement, errors.Wrap(err, "rows err") + return statement, fmt.Errorf("rows err: %w", err) } } diff --git a/pkg/reflector/dml_source_test.go b/pkg/reflector/dml_source_test.go index ccf90249..29ebecde 100644 --- a/pkg/reflector/dml_source_test.go +++ b/pkg/reflector/dml_source_test.go @@ -3,11 +3,12 @@ package reflector import ( "context" "database/sql" + "errors" "fmt" + "strings" "testing" "time" - "github.com/pkg/errors" "github.com/segmentio/ctlstore/pkg/limits" "github.com/segmentio/ctlstore/pkg/sqlgen" "github.com/stretchr/testify/require" @@ -94,16 +95,16 @@ func TestSqlDmlSource(t *testing.T) { foundError := false for i := 0; i < 2; i++ { _, err = src.Next(ctx) - cause := errors.Cause(err) + switch { - case cause == nil: - case cause == context.Canceled: + case err == nil: + case errors.Is(err, context.Canceled): foundError = true break // the db driver will at some point return an error with // the value "interrupted" instead of returning // context.Canceled(). Sigh. - case cause.Error() == "interrupted": + case strings.Contains(err.Error(), "interrupted"): foundError = true break } diff --git a/pkg/reflector/download.go b/pkg/reflector/download.go index 7e7f958f..acb17a6a 100644 --- a/pkg/reflector/download.go +++ b/pkg/reflector/download.go @@ -3,6 +3,7 @@ package reflector import ( "bytes" "compress/gzip" + "fmt" "io" "net/http" "strings" @@ -12,7 +13,6 @@ import ( "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/s3" - "github.com/segmentio/errors-go" "github.com/segmentio/events/v2" "github.com/segmentio/stats/v4" @@ -49,11 +49,11 @@ func (d *S3Downloader) DownloadTo(w io.Writer) (n int64, err error) { case awserr.RequestFailure: if d.StartOverOnNotFound && err.StatusCode() == http.StatusNotFound { // don't bother retrying. we'll start with a fresh ldb. - return -1, errors.WithTypes(errors.Wrap(err, "get s3 data"), errs.ErrTypePermanent) + return -1, errs.ErrTypePermanent{fmt.Errorf("get s3 data: %w", err)} } } // retry - return -1, errors.WithTypes(errors.Wrap(err, "get s3 data"), errs.ErrTypeTemporary) + return -1, errs.ErrTypeTemporary{fmt.Errorf("get s3 data: %w", err)} } defer obj.Body.Close() compressedSize := obj.ContentLength @@ -61,12 +61,12 @@ func (d *S3Downloader) DownloadTo(w io.Writer) (n int64, err error) { if strings.HasSuffix(d.Key, ".gz") { reader, err = gzip.NewReader(reader) if err != nil { - return n, errors.Wrap(err, "create gzip reader") + return n, fmt.Errorf("create gzip reader: %w", err) } } n, err = io.Copy(w, reader) if err != nil { - return n, errors.Wrap(err, "copy from s3 to writer") + return n, fmt.Errorf("copy from s3 to writer: %w", err) } if compressedSize != nil { events.Log("LDB inflated %d -> %d bytes", *compressedSize, n) diff --git a/pkg/reflector/download_test.go b/pkg/reflector/download_test.go index 808e211b..c8654e16 100644 --- a/pkg/reflector/download_test.go +++ b/pkg/reflector/download_test.go @@ -3,6 +3,7 @@ package reflector_test import ( "bytes" "compress/gzip" + "errors" "io" "io/ioutil" "math/rand" @@ -12,9 +13,9 @@ import ( "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/service/s3" + "github.com/segmentio/ctlstore/pkg/errs" "github.com/segmentio/ctlstore/pkg/fakes" "github.com/segmentio/ctlstore/pkg/reflector" - "github.com/segmentio/errors-go" "github.com/stretchr/testify/require" ) @@ -30,7 +31,7 @@ func TestS3DownloadErrors(t *testing.T) { s3Client func() reflector.S3Client n int64 err error - errTypes []string + errTypes []error }{ { name: "success", @@ -52,7 +53,7 @@ func TestS3DownloadErrors(t *testing.T) { return f }, err: errors.New("get s3 data: failure"), - errTypes: []string{"Temporary"}, // generic failures get retried + errTypes: []error{errs.ErrTypeTemporary{}}, // generic failures get retried n: -1, }, { @@ -65,7 +66,7 @@ func TestS3DownloadErrors(t *testing.T) { return f }, err: errors.New("failure"), - errTypes: []string{"Permanent"}, + errTypes: []error{errs.ErrTypePermanent{}}, n: -1, }, { @@ -77,7 +78,7 @@ func TestS3DownloadErrors(t *testing.T) { return f }, err: errors.New("failure"), - errTypes: []string{"Temporary"}, + errTypes: []error{errs.ErrTypeTemporary{}}, n: -1, }, { @@ -89,7 +90,7 @@ func TestS3DownloadErrors(t *testing.T) { return f }, err: errors.New("failure"), - errTypes: []string{"Temporary"}, + errTypes: []error{errs.ErrTypeTemporary{}}, n: -1, }, } { @@ -104,11 +105,8 @@ func TestS3DownloadErrors(t *testing.T) { require.NoError(t, err) } else { require.Contains(t, err.Error(), test.err.Error()) - require.EqualValues(t, len(test.errTypes), len(errors.Types(err)), - "got types: %v", errors.Types(err)) - for _, typ := range test.errTypes { - require.True(t, errors.Is(typ, err), - "error did not have the error type '%s', got types: %v", typ, errors.Types(err)) + for _, target := range test.errTypes { + require.True(t, errors.Is(err, target)) } } require.EqualValues(t, test.n, n) diff --git a/pkg/reflector/reflector.go b/pkg/reflector/reflector.go index 324c2d25..5346a593 100644 --- a/pkg/reflector/reflector.go +++ b/pkg/reflector/reflector.go @@ -24,7 +24,6 @@ import ( "github.com/segmentio/ctlstore/pkg/ledger" "github.com/segmentio/ctlstore/pkg/logwriter" "github.com/segmentio/ctlstore/pkg/sqlite" - "github.com/segmentio/errors-go" "github.com/segmentio/events/v2" _ "github.com/segmentio/events/v2/log" // lets events actually log @@ -179,7 +178,7 @@ func ReflectorFromConfig(config ReflectorConfig) (*Reflector, error) { var maxKnownSeq sql.NullInt64 err = row.Scan(&maxKnownSeq) if err != nil { - return nil, errors.Wrap(err, "find max seq from ledger") + return nil, fmt.Errorf("find max seq from ledger: %w", err) } events.Log("Max known ledger sequence: %{seq}d", maxKnownSeq) @@ -267,7 +266,7 @@ func ReflectorFromConfig(config ReflectorConfig) (*Reflector, error) { ledgerLatencyFunc := ctlstore.NewLDBReaderFromDB(ldbDB).GetLedgerLatency ledgerMon, err := ledger.NewLedgerMonitor(config.LedgerHealth, ledgerLatencyFunc) if err != nil { - return nil, errors.Wrap(err, "build ledger latency monitor") + return nil, fmt.Errorf("build ledger latency monitor: %w", err) } var walMon starter @@ -343,21 +342,19 @@ func (r *Reflector) Start(ctx context.Context) error { err := func() error { shovel, err := r.shovel() if err != nil { - return errors.Wrap(err, "build shovel") + return fmt.Errorf("build shovel: %w", err) } defer shovel.Close() events.Log("Shoveling...") stats.Incr("reflector.shovel_start") err = shovel.Start(ctx) - return errors.Wrap(err, "shovel") + return fmt.Errorf("shovel: %w", err) }() switch { case errs.IsCanceled(err): // this is normal - case events.IsTermination(errors.Cause(err)): // this is normal - events.Log("Reflector received termination signal") case err != nil: switch { - case errors.Is("SkippedSequence", err): + case errors.Is(err, errSkippedSequence): // this is instrumented elsewhere and is not an error that we need // to handle normally, so we will skip instrumenting this as a // shovel_error for now. @@ -446,7 +443,7 @@ func bootstrapLDB(cfg ldbBootstrapConfig) error { } dler = &memoryDownloader{Content: decoded} default: - return errors.Errorf("unsupported scheme '%s' for bootstrap URL '%s'", scheme, cfg.url) + return fmt.Errorf("unsupported scheme '%s' for bootstrap URL '%s'", scheme, cfg.url) } // Download to a temp file first to prevent leaving a zero-byte file @@ -483,7 +480,7 @@ func bootstrapLDB(cfg ldbBootstrapConfig) error { } events.Log("Bootstrap: Downloaded %{bytes}d bytes", bytes) return nil - case errors.Is(errs.ErrTypeTemporary, err): + case errors.Is(err, errs.ErrTypeTemporary{}): incrError("temporary") events.Log("Temporary error trying to download snapshot: %{error}s", err) delay := cfg.retryDelay @@ -492,7 +489,7 @@ func bootstrapLDB(cfg ldbBootstrapConfig) error { } events.Log("Retrying in %{delay}s", delay) time.Sleep(delay) - case errors.Is(errs.ErrTypePermanent, err): + case errors.Is(err, errs.ErrTypePermanent{}): incrError("permanent") events.Log("Could not download snapshot: %{error}s", err) events.Log("Starting with a new LDB") @@ -502,7 +499,7 @@ func bootstrapLDB(cfg ldbBootstrapConfig) error { return err } } - return errors.Errorf("download of ldb snapshot failed after max attempts reached: %s", err) + return fmt.Errorf("download of ldb snapshot failed after max attempts reached: %s", err) } type noopStarter struct { diff --git a/pkg/reflector/reflector_ctl.go b/pkg/reflector/reflector_ctl.go index bfb35817..918b1d56 100644 --- a/pkg/reflector/reflector_ctl.go +++ b/pkg/reflector/reflector_ctl.go @@ -2,10 +2,10 @@ package reflector import ( "context" + "fmt" "sync" "time" - "github.com/pkg/errors" "github.com/segmentio/ctlstore/pkg/errs" "github.com/segmentio/ctlstore/pkg/utils" "github.com/segmentio/events/v2" @@ -168,7 +168,7 @@ func (r *ReflectorCtl) lifecycle(appCtx context.Context) { running = false case <-time.After(reflectorCtlTimeout): errs.Incr("reflector-ctl-timeouts", stats.Tag{Name: "op", Value: "stop-reflector"}) - err := errors.Errorf("could not stop reflector after %s", reflectorCtlTimeout) + err := fmt.Errorf("could not stop reflector after %s", reflectorCtlTimeout) msg.sendErr(appCtx, err) continue } @@ -195,6 +195,6 @@ func (m *reflectorCtlMsg) sendErr(ctx context.Context, err error) { // should never happen but we don't want to block indefinitely // if someone did not create a result chan errs.Incr("reflector-ctl-timeouts", stats.Tag{Name: "op", Value: "send-err"}) - panic(errors.Errorf("could not send err on ctl msg after %s", reflectorCtlTimeout)) + panic(fmt.Errorf("could not send err on ctl msg after %s", reflectorCtlTimeout)) } } diff --git a/pkg/reflector/reflector_test.go b/pkg/reflector/reflector_test.go index 1e77086e..f97adebf 100644 --- a/pkg/reflector/reflector_test.go +++ b/pkg/reflector/reflector_test.go @@ -17,7 +17,6 @@ import ( "github.com/segmentio/ctlstore/pkg/ldb" "github.com/segmentio/ctlstore/pkg/ldbwriter" "github.com/segmentio/ctlstore/pkg/ledger" - "github.com/segmentio/errors-go" "github.com/segmentio/events/v2" "github.com/stretchr/testify/require" ) diff --git a/pkg/reflector/shovel.go b/pkg/reflector/shovel.go index 11370635..d95b9e7b 100644 --- a/pkg/reflector/shovel.go +++ b/pkg/reflector/shovel.go @@ -2,13 +2,14 @@ package reflector import ( "context" + "errors" + "fmt" "io" "time" "github.com/segmentio/ctlstore/pkg/errs" "github.com/segmentio/ctlstore/pkg/ldbwriter" "github.com/segmentio/ctlstore/pkg/schema" - "github.com/segmentio/errors-go" "github.com/segmentio/events/v2" "github.com/segmentio/stats/v4" ) @@ -100,9 +101,7 @@ func (s *shovel) Start(ctx context.Context) error { if s.abortOnSeqSkip { // Mitigation for a bug that we haven't found yet stats.Incr("shovel.skipped_sequence_abort") - err = errors.New("shovel skipped sequence") - err = errors.WithTypes(err, "SkippedSequence") - return err + return errSkippedSequence } } } @@ -111,7 +110,7 @@ func (s *shovel) Start(ctx context.Context) error { err = s.writer.ApplyDMLStatement(ctx, st) if err != nil { errs.Incr("shovel.apply_statement.error") - return errors.Wrapf(err, "ledger seq: %d", st.Sequence) + return fmt.Errorf("ledger seq: %d: %w", st.Sequence, err) } lastSeq = st.Sequence @@ -128,6 +127,8 @@ func (s *shovel) Start(ctx context.Context) error { } } +var errSkippedSequence = errors.New("shovel skipped sequence") + func (s *shovel) Close() error { for _, closer := range s.closers { err := closer.Close() diff --git a/pkg/schema/db_column_meta.go b/pkg/schema/db_column_meta.go index 70026329..b05bd926 100644 --- a/pkg/schema/db_column_meta.go +++ b/pkg/schema/db_column_meta.go @@ -1,6 +1,8 @@ package schema -import "database/sql" +import ( + "database/sql" +) type DBColumnMeta struct { Name string diff --git a/pkg/schema/field_type.go b/pkg/schema/field_type.go index 6c596aca..666b6ca5 100644 --- a/pkg/schema/field_type.go +++ b/pkg/schema/field_type.go @@ -1,6 +1,8 @@ package schema -import "strings" +import ( + "strings" +) type FieldType int diff --git a/pkg/schema/primary_key.go b/pkg/schema/primary_key.go index e6529cca..6961d53b 100644 --- a/pkg/schema/primary_key.go +++ b/pkg/schema/primary_key.go @@ -1,6 +1,8 @@ package schema -import errors "github.com/segmentio/errors-go" +import ( + "fmt" +) var PrimaryKeyZero = PrimaryKey{} @@ -50,7 +52,7 @@ func NewPKFromRawNamesAndTypes(names []string, types []string) (PrimaryKey, erro } ft, ok := SqlTypeToFieldType(types[i]) if !ok { - return PrimaryKeyZero, errors.Errorf("no field type found for '%s'", types[i]) + return PrimaryKeyZero, fmt.Errorf("no field type found for '%s'", types[i]) } fns[i] = fn fts[i] = ft diff --git a/pkg/schema/writer_name.go b/pkg/schema/writer_name.go index e9ec1d9b..1682445d 100644 --- a/pkg/schema/writer_name.go +++ b/pkg/schema/writer_name.go @@ -1,6 +1,8 @@ package schema -import "fmt" +import ( + "fmt" +) // use newWriterName to construct a writerName type WriterName struct { diff --git a/pkg/sidecar/sidecar.go b/pkg/sidecar/sidecar.go index b0ba08a3..8d639209 100644 --- a/pkg/sidecar/sidecar.go +++ b/pkg/sidecar/sidecar.go @@ -3,6 +3,8 @@ package sidecar import ( "context" "encoding/json" + "errors" + "fmt" "log" "net/http" "os" @@ -10,7 +12,6 @@ import ( "github.com/gorilla/mux" "github.com/segmentio/ctlstore" - "github.com/segmentio/errors-go" "github.com/segmentio/stats/v4" "github.com/segmentio/stats/v4/httpstats" ) @@ -44,6 +45,8 @@ type ( } ) +var errLimitExceeded = errors.New("limit exceeded") + func (k Key) ToValue() interface{} { switch { case k.Binary != nil: @@ -73,7 +76,7 @@ func New(config Config) (*Sidecar, error) { err := fn(w, r) switch { case err == nil: - case errors.Is("limit-exceeded", err): + case errors.Is(err, errLimitExceeded): w.WriteHeader(http.StatusRequestedRangeNotSatisfiable) default: http.Error(w, err.Error(), http.StatusInternalServerError) @@ -105,7 +108,7 @@ func (s *Sidecar) Start(ctx context.Context) error { } defer srv.Close() err := srv.ListenAndServe() - return errors.Wrap(err, "listen and serve") + return fmt.Errorf("listen and serve: %w", err) } func (s *Sidecar) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -123,7 +126,7 @@ func (s *Sidecar) statsHandler(delegate http.Handler) http.Handler { func (s *Sidecar) getLedgerLatency(w http.ResponseWriter, r *http.Request) error { duration, err := s.reader.GetLedgerLatency(r.Context()) if err != nil { - return errors.Wrap(err, "get ledger latency") + return fmt.Errorf("get ledger latency: %w", err) } res := map[string]interface{}{ "value": duration.Seconds(), @@ -134,7 +137,7 @@ func (s *Sidecar) getLedgerLatency(w http.ResponseWriter, r *http.Request) error func (s *Sidecar) healthcheck(w http.ResponseWriter, r *http.Request) error { _, err := s.reader.GetLedgerLatency(r.Context()) - return errors.Wrap(err, "healthcheck") + return fmt.Errorf("healthcheck: %w", err) } func (s *Sidecar) ping(w http.ResponseWriter, r *http.Request) error { @@ -150,7 +153,7 @@ func (s *Sidecar) getRowsByKeyPrefix(w http.ResponseWriter, r *http.Request) err var rr ReadRequest err := json.NewDecoder(r.Body).Decode(&rr) if err != nil { - return errors.Wrap(err, "decode body") + return fmt.Errorf("decode body: %w", err) } res := make([]interface{}, 0) rows, err := s.reader.GetRowsByKeyPrefix(r.Context(), family, table, keysToInterface(rr.Key)...) @@ -162,12 +165,12 @@ func (s *Sidecar) getRowsByKeyPrefix(w http.ResponseWriter, r *http.Request) err out := make(map[string]interface{}) err = rows.Scan(out) if err != nil { - return errors.Wrap(err, "scan") + return fmt.Errorf("scan: %w", err) } res = append(res, out) if s.maxRows > 0 && len(res) > s.maxRows { - err = errors.Errorf("max row count (%d) exceeded", s.maxRows) - err = errors.WithTypes(err, "limit-exceeded") + err = fmt.Errorf("max row count (%d) exceeded", s.maxRows) + err = fmt.Errorf("%w: %v", errLimitExceeded, err) return err } } @@ -188,7 +191,7 @@ func (s *Sidecar) getRowByKey(w http.ResponseWriter, r *http.Request) error { var rr ReadRequest err := json.NewDecoder(r.Body).Decode(&rr) if err != nil { - return errors.Wrap(err, "decode body") + return fmt.Errorf("decode body: %w", err) } out := make(map[string]interface{}) diff --git a/pkg/sidecar/sidecar_test.go b/pkg/sidecar/sidecar_test.go index 94df2e53..a2ce4089 100644 --- a/pkg/sidecar/sidecar_test.go +++ b/pkg/sidecar/sidecar_test.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/base64" "encoding/json" + "errors" + "fmt" "net/http" "net/http/httptest" "path/filepath" @@ -13,6 +15,13 @@ import ( "github.com/stretchr/testify/require" ) +func TestErrLimitExceeded(t *testing.T) { + cause := errors.New("boom") + err := fmt.Errorf("%w: %v", errLimitExceeded, cause) + require.True(t, errors.Is(err, errLimitExceeded)) + require.Equal(t, "limit exceeded: boom", err.Error()) +} + func TestLedgerLatency(t *testing.T) { tu, teardown := ctlstore.NewLDBTestUtil(t) defer teardown() diff --git a/pkg/sqlgen/sqlgen.go b/pkg/sqlgen/sqlgen.go index 1b0f8d3a..9b4bcede 100644 --- a/pkg/sqlgen/sqlgen.go +++ b/pkg/sqlgen/sqlgen.go @@ -6,6 +6,7 @@ import ( "database/sql/driver" "encoding/base64" "encoding/hex" + "errors" "fmt" "reflect" "regexp" @@ -14,7 +15,6 @@ import ( "sync" "github.com/segmentio/ctlstore/pkg/schema" - "github.com/segmentio/errors-go" ) type MetaTable struct { @@ -127,7 +127,7 @@ func maybeDecodeBase64(val interface{}, should bool) (interface{}, error) { decoded, err := base64.StdEncoding.DecodeString(valstr) if err != nil { - return nil, errors.Wrap(err, "maybeDecodeBase64") + return nil, fmt.Errorf("maybeDecodeBase64: %w", err) } return decoded, nil @@ -250,7 +250,7 @@ func (t *MetaTable) DeleteDML(values []interface{}) (string, error) { ft, found := t.fieldTypeByName(fn) if !found { - return "", errors.Errorf("DeleteDML couldn't find fieldName %s", fn.String()) + return "", fmt.Errorf("DeleteDML couldn't find fieldName %s", fn.String()) } val, err := maybeDecodeBase64(values[i], isBase64EncodedFieldType(ft)) if err != nil { diff --git a/pkg/sqlite/sql_change_buffer.go b/pkg/sqlite/sql_change_buffer.go index 0b6df82e..6eef4306 100644 --- a/pkg/sqlite/sql_change_buffer.go +++ b/pkg/sqlite/sql_change_buffer.go @@ -1,6 +1,8 @@ package sqlite -import "sync" +import ( + "sync" +) // SQLChangeBuffer accumulates sqliteWatchChanges and allows them to be popped // off later when writing the changelog. diff --git a/pkg/sqlite/sqlite_info.go b/pkg/sqlite/sqlite_info.go index 56922536..76361d5b 100644 --- a/pkg/sqlite/sqlite_info.go +++ b/pkg/sqlite/sqlite_info.go @@ -5,7 +5,6 @@ import ( "database/sql" "fmt" - "github.com/pkg/errors" "github.com/segmentio/ctlstore/pkg/schema" "github.com/segmentio/ctlstore/pkg/sqlgen" ) @@ -18,13 +17,13 @@ func (m *SqliteDBInfo) GetAllTables(ctx context.Context) ([]schema.FamilyTable, var res []schema.FamilyTable rows, err := m.Db.QueryContext(ctx, "select distinct name from sqlite_master where type='table' order by name") if err != nil { - return nil, errors.Wrap(err, "query table names") + return nil, fmt.Errorf("query table names: %w", err) } for rows.Next() { var fullName string err = rows.Scan(&fullName) if err != nil { - return nil, errors.Wrap(err, "scan table name") + return nil, fmt.Errorf("scan table name: %w", err) } if ft, ok := schema.ParseFamilyTable(fullName); ok { res = append(res, ft) diff --git a/pkg/sqlite/sqlite_watch.go b/pkg/sqlite/sqlite_watch.go index 9d351ce7..1517011e 100644 --- a/pkg/sqlite/sqlite_watch.go +++ b/pkg/sqlite/sqlite_watch.go @@ -3,8 +3,9 @@ package sqlite import ( "context" "database/sql" + "errors" + "fmt" - "github.com/pkg/errors" "github.com/segmentio/ctlstore/pkg/scanfunc" "github.com/segmentio/ctlstore/pkg/schema" "github.com/segmentio/go-sqlite3" @@ -107,7 +108,7 @@ func (c *SQLiteWatchChange) ExtractKeys(db *sql.DB) ([][]interface{}, error) { }, } if err := ph.Scan(row[colInfo.Index]); err != nil { - return nil, errors.Wrap(err, "scan key value column") + return nil, fmt.Errorf("scan key value column: %w", err) } key = append(key, pkAndMeta{ Name: colInfo.ColumnName, diff --git a/pkg/supervisor/gzip_pipe.go b/pkg/supervisor/gzip_pipe.go index 72030b8f..87711af4 100644 --- a/pkg/supervisor/gzip_pipe.go +++ b/pkg/supervisor/gzip_pipe.go @@ -2,10 +2,10 @@ package supervisor import ( "compress/gzip" + "errors" + "fmt" "io" "sync" - - "github.com/pkg/errors" ) type gzipCompressionReader struct { @@ -39,10 +39,10 @@ func (r *gzipCompressionReader) Read(p []byte) (n int, err error) { pw.CloseWithError(func() error { _, err := io.Copy(gw, r.reader) if err != nil { - return errors.Wrap(err, "copy to gzip writer") + return fmt.Errorf("copy to gzip writer: %w", err) } if err = gw.Close(); err != nil { - return errors.Wrap(err, "close gzip writer") + return fmt.Errorf("close gzip writer: %w", err) } return nil }()) diff --git a/pkg/supervisor/gzip_pipe_test.go b/pkg/supervisor/gzip_pipe_test.go index f1cf8121..27b20598 100644 --- a/pkg/supervisor/gzip_pipe_test.go +++ b/pkg/supervisor/gzip_pipe_test.go @@ -3,12 +3,13 @@ package supervisor import ( "bytes" "compress/gzip" + "errors" + "fmt" "io" "io/ioutil" "strings" "testing" - "github.com/pkg/errors" "github.com/stretchr/testify/require" ) @@ -92,10 +93,10 @@ func TestIOPipes(t *testing.T) { pw.CloseWithError(func() error { _, err := io.Copy(gw, bytes.NewReader(data)) if err != nil { - return errors.Wrap(err, "copy to gw") + return fmt.Errorf("copy to gw: %w", err) } if err = gw.Close(); err != nil { - return errors.Wrap(err, "close gzip writer") + return fmt.Errorf("close gzip writer: %w", err) } return nil }()) diff --git a/pkg/supervisor/supervisor.go b/pkg/supervisor/supervisor.go index 05f1e013..fd9ef098 100644 --- a/pkg/supervisor/supervisor.go +++ b/pkg/supervisor/supervisor.go @@ -3,12 +3,13 @@ package supervisor import ( "context" "database/sql" + "errors" + "fmt" "io" "os" "strings" "time" - "github.com/pkg/errors" "github.com/segmentio/ctlstore/pkg/reflector" "github.com/segmentio/events/v2" "github.com/segmentio/stats/v4" @@ -45,7 +46,7 @@ func SupervisorFromConfig(config SupervisorConfig) (Supervisor, error) { for _, url := range urls { snapshot, err := archivedSnapshotFromURL(url) if err != nil { - return nil, errors.Wrapf(err, "configure snapshot for '%s'", url) + return nil, fmt.Errorf("configure snapshot for '%s': %w", url, err) } snapshots = append(snapshots, snapshot) } @@ -63,18 +64,18 @@ func (s *supervisor) snapshot(ctx context.Context) error { s.reflectorCtl.Stop(ctx) defer s.reflectorCtl.Start(ctx) if err := s.checkpointLDB(); err != nil { - return errors.Wrap(err, "checkpoint ldb") + return fmt.Errorf("checkpoint ldb: %w", err) } info, err := os.Stat(s.LDBPath) if err != nil { - return errors.Wrap(err, "stat ldb path") + return fmt.Errorf("stat ldb path: %w", err) } stats.Set("ldb-size-bytes", info.Size()) errs := make(chan error, len(s.Snapshots)) for _, snapshot := range s.Snapshots { go func(snapshot archivedSnapshot) { err := snapshot.Upload(ctx, s.LDBPath) - errs <- errors.Wrapf(err, "upload snapshot") + errs <- fmt.Errorf("upload snapshot: %w", err) }(snapshot) } for range s.Snapshots { @@ -89,31 +90,31 @@ func (s *supervisor) checkpointLDB() error { ctx := context.Background() // we do not want to interrupt this operation srcDb, err := sql.Open("sqlite3", s.LDBPath+"?_journal_mode=wal") if err != nil { - return errors.Wrap(err, "opening source db") + return fmt.Errorf("opening source db: %w", err) } defer srcDb.Close() conn, err := srcDb.Conn(ctx) if err != nil { - return errors.Wrap(err, "src db connection") + return fmt.Errorf("src db connection: %w", err) } defer conn.Close() _, err = conn.ExecContext(ctx, "PRAGMA wal_checkpoint(PASSIVE);") if err != nil { - return errors.Wrap(err, "checkpointing database") + return fmt.Errorf("checkpointing database: %w", err) } _, err = conn.ExecContext(ctx, "VACUUM") if err != nil { - return errors.Wrap(err, "vacuuming database") + return fmt.Errorf("vacuuming database: %w", err) } // This will prevent any writes while the copy is taking place _, err = conn.ExecContext(ctx, "BEGIN IMMEDIATE TRANSACTION;") if err != nil { - return errors.Wrap(err, "locking database") + return fmt.Errorf("locking database: %w", err) } events.Log("Acquired write lock on %{srcDb}s", s.LDBPath) _, err = conn.ExecContext(ctx, "COMMIT;") if err != nil { - return errors.Wrap(err, "commit") + return fmt.Errorf("commit: %w", err) } events.Log("Released write lock on %{srcDb}s", s.LDBPath) events.Log("Checkpointed WAL on %{srcDb}s", s.LDBPath) @@ -158,7 +159,7 @@ func (s *supervisor) Start(ctx context.Context) { return } err := s.snapshot(ctx) - if err != nil && errors.Cause(err) != context.Canceled { + if errors.Is(err, context.Canceled) { s.incrementSnapshotErrorMetric(1) events.Log("Error taking snapshot: %{error}+v", err) // Use a shorter sleep duration for faster retries diff --git a/pkg/unsafe/unsafe_test.go b/pkg/unsafe/unsafe_test.go index fc2d8ac7..a81bd436 100644 --- a/pkg/unsafe/unsafe_test.go +++ b/pkg/unsafe/unsafe_test.go @@ -1,3 +1,4 @@ +//go:build !race // +build !race package unsafe diff --git a/pkg/utils/atomic_bool.go b/pkg/utils/atomic_bool.go index 2c78873f..73ac290b 100644 --- a/pkg/utils/atomic_bool.go +++ b/pkg/utils/atomic_bool.go @@ -1,6 +1,8 @@ package utils -import "sync/atomic" +import ( + "sync/atomic" +) type AtomicBool int32 diff --git a/pkg/utils/ensure_dir.go b/pkg/utils/ensure_dir.go index 4789e4bf..ecce96ac 100644 --- a/pkg/utils/ensure_dir.go +++ b/pkg/utils/ensure_dir.go @@ -1,10 +1,9 @@ package utils import ( + "fmt" "os" "path/filepath" - - "github.com/pkg/errors" ) // EnsureDirForFile ensures that the specified file's parent directory @@ -17,8 +16,8 @@ func EnsureDirForFile(file string) error { return nil case os.IsNotExist(err): err = os.Mkdir(dir, 0700) - return errors.Wrapf(err, "mkdir %s", dir) + return fmt.Errorf("mkdir %s: %w", dir, err) default: - return errors.Wrapf(err, "stat %s", dir) + return fmt.Errorf("stat %s: %w", dir, err) } } diff --git a/pkg/utils/interface_slice.go b/pkg/utils/interface_slice.go index 8a8c4752..5294f04a 100644 --- a/pkg/utils/interface_slice.go +++ b/pkg/utils/interface_slice.go @@ -1,6 +1,8 @@ package utils -import "reflect" +import ( + "reflect" +) // Converts args to an interface slice using the following rules: // diff --git a/pkg/version/version_go1_12.go b/pkg/version/version_go1_12.go index 329d9156..3058439e 100644 --- a/pkg/version/version_go1_12.go +++ b/pkg/version/version_go1_12.go @@ -1,3 +1,4 @@ +//go:build go1.12 // +build go1.12 package version diff --git a/tools.go b/tools.go index 689ef896..97576b08 100644 --- a/tools.go +++ b/tools.go @@ -1,3 +1,4 @@ +//go:build tools // +build tools package ctlstore diff --git a/version.go b/version.go index 4036ef46..1b88fa57 100644 --- a/version.go +++ b/version.go @@ -1,6 +1,8 @@ package ctlstore -import "github.com/segmentio/ctlstore/pkg/version" +import ( + "github.com/segmentio/ctlstore/pkg/version" +) // Version is the current ctlstore client library version. var Version string