Skip to content

Commit

Permalink
chore: add context.Context everywhere
Browse files Browse the repository at this point in the history
  • Loading branch information
joschi committed Oct 6, 2024
1 parent c378583 commit 6fe5f52
Show file tree
Hide file tree
Showing 80 changed files with 1,381 additions and 1,183 deletions.
31 changes: 16 additions & 15 deletions database/cassandra/cassandra.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cassandra

import (
"context"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -52,7 +53,7 @@ type Cassandra struct {
config *Config
}

func WithInstance(session *gocql.Session, config *Config) (database.Driver, error) {
func WithInstance(ctx context.Context, session *gocql.Session, config *Config) (database.Driver, error) {
if config == nil {
return nil, ErrNilConfig
} else if len(config.KeyspaceName) == 0 {
Expand All @@ -76,14 +77,14 @@ func WithInstance(session *gocql.Session, config *Config) (database.Driver, erro
config: config,
}

if err := c.ensureVersionTable(); err != nil {
if err := c.ensureVersionTable(ctx); err != nil {
return nil, err
}

return c, nil
}

func (c *Cassandra) Open(url string) (database.Driver, error) {
func (c *Cassandra) Open(ctx context.Context, url string) (database.Driver, error) {
u, err := nurl.Parse(url)
if err != nil {
return nil, err
Expand Down Expand Up @@ -185,34 +186,34 @@ func (c *Cassandra) Open(url string) (database.Driver, error) {
}
}

return WithInstance(session, &Config{
return WithInstance(ctx, session, &Config{
KeyspaceName: strings.TrimPrefix(u.Path, "/"),
MigrationsTable: u.Query().Get("x-migrations-table"),
MultiStatementEnabled: u.Query().Get("x-multi-statement") == "true",
MultiStatementMaxSize: multiStatementMaxSize,
})
}

func (c *Cassandra) Close() error {
func (c *Cassandra) Close(ctx context.Context) error {
c.session.Close()
return nil
}

func (c *Cassandra) Lock() error {
func (c *Cassandra) Lock(ctx context.Context) error {
if !c.isLocked.CAS(false, true) {
return database.ErrLocked
}
return nil
}

func (c *Cassandra) Unlock() error {
func (c *Cassandra) Unlock(ctx context.Context) error {
if !c.isLocked.CAS(true, false) {
return database.ErrNotLocked
}
return nil
}

func (c *Cassandra) Run(migration io.Reader) error {
func (c *Cassandra) Run(ctx context.Context, migration io.Reader) error {
if c.config.MultiStatementEnabled {
var err error
if e := multistmt.Parse(migration, multiStmtDelimiter, c.config.MultiStatementMaxSize, func(m []byte) bool {
Expand Down Expand Up @@ -243,7 +244,7 @@ func (c *Cassandra) Run(migration io.Reader) error {
return nil
}

func (c *Cassandra) SetVersion(version int, dirty bool) error {
func (c *Cassandra) SetVersion(ctx context.Context, version int, dirty bool) error {
// DELETE instead of TRUNCATE because AWS Keyspaces does not support it
// see: https://docs.aws.amazon.com/keyspaces/latest/devguide/cassandra-apis.html
squery := `SELECT version FROM "` + c.config.MigrationsTable + `"`
Expand Down Expand Up @@ -273,7 +274,7 @@ func (c *Cassandra) SetVersion(version int, dirty bool) error {
}

// Return current keyspace version
func (c *Cassandra) Version() (version int, dirty bool, err error) {
func (c *Cassandra) Version(ctx context.Context) (version int, dirty bool, err error) {
query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1`
err = c.session.Query(query).Scan(&version, &dirty)
switch {
Expand All @@ -291,7 +292,7 @@ func (c *Cassandra) Version() (version int, dirty bool, err error) {
}
}

func (c *Cassandra) Drop() error {
func (c *Cassandra) Drop(ctx context.Context) error {
// select all tables in current schema
query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, c.config.KeyspaceName)
iter := c.session.Query(query).Iter()
Expand All @@ -309,13 +310,13 @@ func (c *Cassandra) Drop() error {
// ensureVersionTable checks if versions table exists and, if not, creates it.
// Note that this function locks the database, which deviates from the usual
// convention of "caller locks" in the Cassandra type.
func (c *Cassandra) ensureVersionTable() (err error) {
if err = c.Lock(); err != nil {
func (c *Cassandra) ensureVersionTable(ctx context.Context) (err error) {
if err = c.Lock(ctx); err != nil {
return err
}

defer func() {
if e := c.Unlock(); e != nil {
if e := c.Unlock(ctx); e != nil {
if err == nil {
err = e
} else {
Expand All @@ -328,7 +329,7 @@ func (c *Cassandra) ensureVersionTable() (err error) {
if err != nil {
return err
}
if _, _, err = c.Version(); err != nil {
if _, _, err = c.Version(ctx); err != nil {
return err
}
return nil
Expand Down
12 changes: 7 additions & 5 deletions database/cassandra/cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,19 @@ func Test(t *testing.T) {

func test(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ctx := context.Background()
ip, port, err := c.Port(9042)
if err != nil {
t.Fatal("Unable to get mapped port:", err)
}
addr := fmt.Sprintf("cassandra://%v:%v/testks", ip, port)
p := &Cassandra{}
d, err := p.Open(addr)
d, err := p.Open(ctx, addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
if err := d.Close(ctx); err != nil {
t.Error(err)
}
}()
Expand All @@ -97,23 +98,24 @@ func test(t *testing.T) {

func testMigrate(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ctx := context.Background()
ip, port, err := c.Port(9042)
if err != nil {
t.Fatal("Unable to get mapped port:", err)
}
addr := fmt.Sprintf("cassandra://%v:%v/testks", ip, port)
p := &Cassandra{}
d, err := p.Open(addr)
d, err := p.Open(ctx, addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
if err := d.Close(ctx); err != nil {
t.Error(err)
}
}()

m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "testks", d)
m, err := migrate.NewWithDatabaseInstance(ctx, "file://./examples/migrations", "testks", d)
if err != nil {
t.Fatal(err)
}
Expand Down
33 changes: 17 additions & 16 deletions database/clickhouse/clickhouse.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package clickhouse

import (
"context"
"database/sql"
"fmt"
"io"
Expand Down Expand Up @@ -40,7 +41,7 @@ func init() {
database.Register("clickhouse", &ClickHouse{})
}

func WithInstance(conn *sql.DB, config *Config) (database.Driver, error) {
func WithInstance(ctx context.Context, conn *sql.DB, config *Config) (database.Driver, error) {
if config == nil {
return nil, ErrNilConfig
}
Expand All @@ -54,7 +55,7 @@ func WithInstance(conn *sql.DB, config *Config) (database.Driver, error) {
config: config,
}

if err := ch.init(); err != nil {
if err := ch.init(ctx); err != nil {
return nil, err
}

Expand All @@ -67,7 +68,7 @@ type ClickHouse struct {
isLocked atomic.Bool
}

func (ch *ClickHouse) Open(dsn string) (database.Driver, error) {
func (ch *ClickHouse) Open(ctx context.Context, dsn string) (database.Driver, error) {
purl, err := url.Parse(dsn)
if err != nil {
return nil, err
Expand Down Expand Up @@ -104,14 +105,14 @@ func (ch *ClickHouse) Open(dsn string) (database.Driver, error) {
},
}

if err := ch.init(); err != nil {
if err := ch.init(ctx); err != nil {
return nil, err
}

return ch, nil
}

func (ch *ClickHouse) init() error {
func (ch *ClickHouse) init(ctx context.Context) error {
if len(ch.config.DatabaseName) == 0 {
if err := ch.conn.QueryRow("SELECT currentDatabase()").Scan(&ch.config.DatabaseName); err != nil {
return err
Expand All @@ -130,10 +131,10 @@ func (ch *ClickHouse) init() error {
ch.config.MigrationsTableEngine = DefaultMigrationsTableEngine
}

return ch.ensureVersionTable()
return ch.ensureVersionTable(ctx)
}

func (ch *ClickHouse) Run(r io.Reader) error {
func (ch *ClickHouse) Run(ctx context.Context, r io.Reader) error {
if ch.config.MultiStatementEnabled {
var err error
if e := multistmt.Parse(r, multiStmtDelimiter, ch.config.MultiStatementMaxSize, func(m []byte) bool {
Expand Down Expand Up @@ -163,7 +164,7 @@ func (ch *ClickHouse) Run(r io.Reader) error {

return nil
}
func (ch *ClickHouse) Version() (int, bool, error) {
func (ch *ClickHouse) Version(ctx context.Context) (int, bool, error) {
var (
version int
dirty uint8
Expand All @@ -178,7 +179,7 @@ func (ch *ClickHouse) Version() (int, bool, error) {
return version, dirty == 1, nil
}

func (ch *ClickHouse) SetVersion(version int, dirty bool) error {
func (ch *ClickHouse) SetVersion(ctx context.Context, version int, dirty bool) error {
var (
bool = func(v bool) uint8 {
if v {
Expand All @@ -203,13 +204,13 @@ func (ch *ClickHouse) SetVersion(version int, dirty bool) error {
// ensureVersionTable checks if versions table exists and, if not, creates it.
// Note that this function locks the database, which deviates from the usual
// convention of "caller locks" in the ClickHouse type.
func (ch *ClickHouse) ensureVersionTable() (err error) {
if err = ch.Lock(); err != nil {
func (ch *ClickHouse) ensureVersionTable(ctx context.Context) (err error) {
if err = ch.Lock(ctx); err != nil {
return err
}

defer func() {
if e := ch.Unlock(); e != nil {
if e := ch.Unlock(ctx); e != nil {
if err == nil {
err = e
} else {
Expand Down Expand Up @@ -258,7 +259,7 @@ func (ch *ClickHouse) ensureVersionTable() (err error) {
return nil
}

func (ch *ClickHouse) Drop() (err error) {
func (ch *ClickHouse) Drop(ctx context.Context) (err error) {
query := "SHOW TABLES FROM " + quoteIdentifier(ch.config.DatabaseName)
tables, err := ch.conn.Query(query)

Expand Down Expand Up @@ -290,21 +291,21 @@ func (ch *ClickHouse) Drop() (err error) {
return nil
}

func (ch *ClickHouse) Lock() error {
func (ch *ClickHouse) Lock(ctx context.Context) error {
if !ch.isLocked.CAS(false, true) {
return database.ErrLocked
}

return nil
}
func (ch *ClickHouse) Unlock() error {
func (ch *ClickHouse) Unlock(ctx context.Context) error {
if !ch.isLocked.CAS(true, false) {
return database.ErrNotLocked
}

return nil
}
func (ch *ClickHouse) Close() error { return ch.conn.Close() }
func (ch *ClickHouse) Close(ctx context.Context) error { return ch.conn.Close() }

// Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611
func quoteIdentifier(name string) string {
Expand Down
Loading

0 comments on commit 6fe5f52

Please sign in to comment.