diff --git a/database/pgx/README.md b/database/pgx/README.md index bad669315..bec7c5c75 100644 --- a/database/pgx/README.md +++ b/database/pgx/README.md @@ -11,6 +11,8 @@ This package is for [pgx/v4](https://pkg.go.dev/github.com/jackc/pgx/v4). A back | `x-statement-timeout` | `StatementTimeout` | Abort any statement that takes more than the specified number of milliseconds | | `x-multi-statement` | `MultiStatementEnabled` | Enable multi-statement execution (default: false) | | `x-multi-statement-max-size` | `MultiStatementMaxSize` | Maximum size of single statement in bytes (default: 10MB) | +| `x-lock-strategy` | `LockStrategy` | Strategy used for locking during migration (default: advisory) | +| `x-lock-table` | `LockTable` | Name of the table which maintains the migration lock (default: schema_lock) | | `dbname` | `DatabaseName` | The name of the database to connect to | | `search_path` | | This variable specifies the order in which schemas are searched when an object is referenced by a simple name with no schema specified. | | `user` | | The user to sign in as | diff --git a/database/pgx/pgx.go b/database/pgx/pgx.go index deaca94ea..7e42d29c9 100644 --- a/database/pgx/pgx.go +++ b/database/pgx/pgx.go @@ -23,6 +23,12 @@ import ( "github.com/jackc/pgconn" "github.com/jackc/pgerrcode" _ "github.com/jackc/pgx/v4/stdlib" + "github.com/lib/pq" +) + +const ( + LockStrategyAdvisory = "advisory" + LockStrategyTable = "table" ) func init() { @@ -36,6 +42,8 @@ var ( DefaultMigrationsTable = "schema_migrations" DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB + DefaultLockTable = "schema_lock" + DefaultLockStrategy = LockStrategyAdvisory ) var ( @@ -49,6 +57,8 @@ type Config struct { MigrationsTable string DatabaseName string SchemaName string + LockTable string + LockStrategy string migrationsSchemaName string migrationsTableName string StatementTimeout time.Duration @@ -108,6 +118,14 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { config.MigrationsTable = DefaultMigrationsTable } + if len(config.LockTable) == 0 { + config.LockTable = DefaultLockTable + } + + if len(config.LockStrategy) == 0 { + config.LockStrategy = DefaultLockStrategy + } + config.migrationsSchemaName = config.SchemaName config.migrationsTableName = config.MigrationsTable if config.MigrationsTableQuoted { @@ -133,6 +151,10 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { config: config, } + if err := px.ensureLockTable(); err != nil { + return nil, err + } + if err := px.ensureVersionTable(); err != nil { return nil, err } @@ -196,6 +218,9 @@ func (p *Postgres) Open(url string) (database.Driver, error) { } } + lockStrategy := purl.Query().Get("x-lock-strategy") + lockTable := purl.Query().Get("x-lock-table") + px, err := WithInstance(db, &Config{ DatabaseName: purl.Path, MigrationsTable: migrationsTable, @@ -203,6 +228,8 @@ func (p *Postgres) Open(url string) (database.Driver, error) { StatementTimeout: time.Duration(statementTimeout) * time.Millisecond, MultiStatementEnabled: multiStatementEnabled, MultiStatementMaxSize: multiStatementMaxSize, + LockStrategy: lockStrategy, + LockTable: lockTable, }) if err != nil { @@ -221,36 +248,116 @@ func (p *Postgres) Close() error { return nil } -// https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS func (p *Postgres) Lock() error { return database.CasRestoreOnErr(&p.isLocked, false, true, database.ErrLocked, func() error { - aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) - if err != nil { - return err - } - - // This will wait indefinitely until the lock can be acquired. - query := `SELECT pg_advisory_lock($1)` - if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { - return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} + switch p.config.LockStrategy { + case LockStrategyAdvisory: + return p.applyAdvisoryLock() + case LockStrategyTable: + return p.applyTableLock() + default: + return fmt.Errorf("unknown lock strategy \"%s\"", p.config.LockStrategy) } - return nil }) } func (p *Postgres) Unlock() error { return database.CasRestoreOnErr(&p.isLocked, true, false, database.ErrNotLocked, func() error { - aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) - if err != nil { - return err + switch p.config.LockStrategy { + case LockStrategyAdvisory: + return p.releaseAdvisoryLock() + case LockStrategyTable: + return p.releaseTableLock() + default: + return fmt.Errorf("unknown lock strategy \"%s\"", p.config.LockStrategy) } + }) +} - query := `SELECT pg_advisory_unlock($1)` - if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { - return &database.Error{OrigErr: err, Query: []byte(query)} +// https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS +func (p *Postgres) applyAdvisoryLock() error { + aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) + if err != nil { + return err + } + + // This will wait indefinitely until the lock can be acquired. + query := `SELECT pg_advisory_lock($1)` + if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { + return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} + } + return nil +} + +func (p *Postgres) applyTableLock() error { + tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) + if err != nil { + return &database.Error{OrigErr: err, Err: "transaction start failed"} + } + defer func() { + errRollback := tx.Rollback() + if errRollback != nil { + err = multierror.Append(err, errRollback) } - return nil - }) + }() + + aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName) + if err != nil { + return err + } + + query := "SELECT * FROM " + pq.QuoteIdentifier(p.config.LockTable) + " WHERE lock_id = $1" + rows, err := tx.Query(query, aid) + if err != nil { + return database.Error{OrigErr: err, Err: "failed to fetch migration lock", Query: []byte(query)} + } + + defer func() { + if errClose := rows.Close(); errClose != nil { + err = multierror.Append(err, errClose) + } + }() + + // If row exists at all, lock is present + locked := rows.Next() + if locked { + return database.ErrLocked + } + + query = "INSERT INTO " + pq.QuoteIdentifier(p.config.LockTable) + " (lock_id) VALUES ($1)" + if _, err := tx.Exec(query, aid); err != nil { + return database.Error{OrigErr: err, Err: "failed to set migration lock", Query: []byte(query)} + } + + return tx.Commit() +} + +func (p *Postgres) releaseAdvisoryLock() error { + aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName) + if err != nil { + return err + } + + query := `SELECT pg_advisory_unlock($1)` + if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil { + return &database.Error{OrigErr: err, Query: []byte(query)} + } + + return nil +} + +func (p *Postgres) releaseTableLock() error { + aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName) + if err != nil { + return err + } + + query := "DELETE FROM " + pq.QuoteIdentifier(p.config.LockTable) + " WHERE lock_id = $1" + if _, err := p.db.Exec(query, aid); err != nil { + return database.Error{OrigErr: err, Err: "failed to release migration lock", Query: []byte(query)} + } + + return nil } func (p *Postgres) Run(migration io.Reader) error { @@ -414,6 +521,12 @@ func (p *Postgres) Drop() (err error) { if err := tables.Scan(&tableName); err != nil { return err } + + // do not drop lock table + if tableName == p.config.LockTable && p.config.LockStrategy == LockStrategyTable { + continue + } + if len(tableName) > 0 { tableNames = append(tableNames, tableName) } @@ -478,6 +591,28 @@ func (p *Postgres) ensureVersionTable() (err error) { return nil } +func (p *Postgres) ensureLockTable() error { + if p.config.LockStrategy != LockStrategyTable { + return nil + } + + var count int + query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` + if err := p.db.QueryRow(query, p.config.LockTable).Scan(&count); err != nil { + return &database.Error{OrigErr: err, Query: []byte(query)} + } + if count == 1 { + return nil + } + + query = `CREATE TABLE ` + pq.QuoteIdentifier(p.config.LockTable) + ` (lock_id BIGINT NOT NULL PRIMARY KEY)` + if _, err := p.db.Exec(query); err != nil { + return &database.Error{OrigErr: err, Query: []byte(query)} + } + + return nil +} + // Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611 func quoteIdentifier(name string) string { end := strings.IndexRune(name, 0) diff --git a/database/pgx/pgx_test.go b/database/pgx/pgx_test.go index 5d7a5238e..53e8e1d86 100644 --- a/database/pgx/pgx_test.go +++ b/database/pgx/pgx_test.go @@ -134,6 +134,32 @@ func TestMigrate(t *testing.T) { }) } +func TestMigrateLockTable(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ip, port, err := c.FirstPort() + if err != nil { + t.Fatal(err) + } + + addr := pgConnectionString(ip, port, "x-lock-strategy=table", "x-lock-table=lock_table") + p := &Postgres{} + d, err := p.Open(addr) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := d.Close(); err != nil { + t.Error(err) + } + }() + m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "pgx", d) + if err != nil { + t.Fatal(err) + } + dt.TestMigrate(t, m) + }) +} + func TestMultipleStatements(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { ip, port, err := c.FirstPort()