From 814769373f6ea86c027e6551122880282d8e8128 Mon Sep 17 00:00:00 2001 From: Alex Kucksdorf Date: Wed, 26 Oct 2022 12:14:04 +0200 Subject: [PATCH 1/2] [sqlserver] Ensure version table in provided schema Closes #839 --- database/sqlserver/sqlserver.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/database/sqlserver/sqlserver.go b/database/sqlserver/sqlserver.go index 90e3926df..c1220b555 100644 --- a/database/sqlserver/sqlserver.go +++ b/database/sqlserver/sqlserver.go @@ -365,10 +365,10 @@ func (ss *SQLServer) ensureVersionTable() (err error) { query := `IF NOT EXISTS (SELECT * FROM sysobjects - WHERE id = object_id(N'[dbo].[` + ss.config.MigrationsTable + `]') + WHERE id = object_id(N'[` + ss.config.SchemaName + `].[` + ss.config.MigrationsTable + `]') AND OBJECTPROPERTY(id, N'IsUserTable') = 1 ) - CREATE TABLE ` + ss.config.MigrationsTable + ` ( version BIGINT PRIMARY KEY NOT NULL, dirty BIT NOT NULL );` + CREATE TABLE [` + ss.config.SchemaName + `].[` + ss.config.MigrationsTable + `] ( version BIGINT PRIMARY KEY NOT NULL, dirty BIT NOT NULL );` if _, err = ss.conn.ExecContext(context.Background(), query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} From 0350a00606ffb68b0036adc93561fa3dc2b67115 Mon Sep 17 00:00:00 2001 From: Alex Kucksdorf Date: Wed, 26 Oct 2022 12:51:46 +0200 Subject: [PATCH 2/2] [sqlserver] Always access version table with explicit schema --- database/sqlserver/sqlserver.go | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/database/sqlserver/sqlserver.go b/database/sqlserver/sqlserver.go index c1220b555..3cfa48bf9 100644 --- a/database/sqlserver/sqlserver.go +++ b/database/sqlserver/sqlserver.go @@ -263,7 +263,7 @@ func (ss *SQLServer) SetVersion(version int, dirty bool) error { return &database.Error{OrigErr: err, Err: "transaction start failed"} } - query := `TRUNCATE TABLE "` + ss.config.MigrationsTable + `"` + query := `TRUNCATE TABLE ` + ss.getMigrationTable() if _, err := tx.Exec(query); err != nil { if errRollback := tx.Rollback(); errRollback != nil { err = multierror.Append(err, errRollback) @@ -279,7 +279,7 @@ func (ss *SQLServer) SetVersion(version int, dirty bool) error { if dirty { dirtyBit = 1 } - query = `INSERT INTO "` + ss.config.MigrationsTable + `" (version, dirty) VALUES (@p1, @p2)` + query = `INSERT INTO ` + ss.getMigrationTable() + ` (version, dirty) VALUES (@p1, @p2)` if _, err := tx.Exec(query, version, dirtyBit); err != nil { if errRollback := tx.Rollback(); errRollback != nil { err = multierror.Append(err, errRollback) @@ -297,7 +297,7 @@ func (ss *SQLServer) SetVersion(version int, dirty bool) error { // Version of the current database state func (ss *SQLServer) Version() (version int, dirty bool, err error) { - query := `SELECT TOP 1 version, dirty FROM "` + ss.config.MigrationsTable + `"` + query := `SELECT TOP 1 version, dirty FROM ` + ss.getMigrationTable() err = ss.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) switch { case err == sql.ErrNoRows: @@ -365,10 +365,10 @@ func (ss *SQLServer) ensureVersionTable() (err error) { query := `IF NOT EXISTS (SELECT * FROM sysobjects - WHERE id = object_id(N'[` + ss.config.SchemaName + `].[` + ss.config.MigrationsTable + `]') + WHERE id = object_id(N'` + ss.getMigrationTable() + `') AND OBJECTPROPERTY(id, N'IsUserTable') = 1 ) - CREATE TABLE [` + ss.config.SchemaName + `].[` + ss.config.MigrationsTable + `] ( version BIGINT PRIMARY KEY NOT NULL, dirty BIT NOT NULL );` + CREATE TABLE ` + ss.getMigrationTable() + ` ( version BIGINT PRIMARY KEY NOT NULL, dirty BIT NOT NULL );` if _, err = ss.conn.ExecContext(context.Background(), query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} @@ -377,6 +377,10 @@ func (ss *SQLServer) ensureVersionTable() (err error) { return nil } +func (ss *SQLServer) getMigrationTable() string { + return fmt.Sprintf("[%s].[%s]", ss.config.SchemaName, ss.config.MigrationsTable) +} + func getMSITokenProvider(resource string) (func() (string, error), error) { msi, err := adal.NewServicePrincipalTokenFromManagedIdentity(resource, nil) if err != nil {