Skip to content

Commit

Permalink
feature(api): add migration to insert vpn default value
Browse files Browse the repository at this point in the history
  • Loading branch information
henrybarreto committed Aug 30, 2024
1 parent b363f31 commit 6edbde8
Show file tree
Hide file tree
Showing 3 changed files with 188 additions and 0 deletions.
1 change: 1 addition & 0 deletions api/store/mongo/migrations/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ func GenerateMigrations() []migrate.Migration {
migration73,
migration74,
migration75,
migration76,
}
}

Expand Down
64 changes: 64 additions & 0 deletions api/store/mongo/migrations/migration_76.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package migrations

import (
"context"

"github.com/shellhub-io/shellhub/pkg/envs"
"github.com/sirupsen/logrus"
migrate "github.com/xakep666/mongo-migrate"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
)

var migration76 = migrate.Migration{
Version: 76,
Description: "Adding VPN settings to namespace",
Up: migrate.MigrationFunc(func(ctx context.Context, db *mongo.Database) error {
logrus.WithFields(logrus.Fields{
"component": "migration",
"version": 76,
"action": "Up",
}).Info("Applying migration")

if envs.IsEnterprise() {
update := bson.M{
"$set": bson.M{
"vpn": bson.M{
"enable": false,
"address": bson.A{10, 0, 0, 0},
"mask": 24,
},
},
}

_, err := db.
Collection("namespaces").
UpdateMany(ctx, bson.M{}, update)

return err
}

return nil
}),
Down: migrate.MigrationFunc(func(ctx context.Context, db *mongo.Database) error {
logrus.WithFields(logrus.Fields{
"component": "migration",
"version": 76,
"action": "Down",
}).Info("Reverting migration")

if envs.IsEnterprise() {
update := bson.M{
"$unset": bson.M{"vpn": ""},
}

_, err := db.
Collection("namespaces").
UpdateMany(ctx, bson.M{}, update)

return err
}

return nil
}),
}
123 changes: 123 additions & 0 deletions api/store/mongo/migrations/migration_76_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package migrations

import (
"context"
"errors"
"testing"

"github.com/shellhub-io/shellhub/pkg/envs"
env_mocks "github.com/shellhub-io/shellhub/pkg/envs/mocks"
"github.com/shellhub-io/shellhub/pkg/models"
"github.com/stretchr/testify/assert"
migrate "github.com/xakep666/mongo-migrate"
"go.mongodb.org/mongo-driver/bson"
)

func TestMigration76(t *testing.T) {
ctx := context.Background()

envMock = &env_mocks.Backend{}
envs.DefaultBackend = envMock

cases := []struct {
description string
setup func() error
requireMocks func()
test func() error
}{
{
description: "Success to apply up on migration 76",
setup: func() error {
_, err := c.
Database("test").
Collection("namespaces").
InsertOne(ctx, models.Namespace{
TenantID: "00000000-0000-4000-0000-000000000000",
})

return err
},
requireMocks: func() {
envMock.On("Get", "SHELLHUB_ENTERPRISE").Return("true").Once()
},
test: func() error {
migrations := GenerateMigrations()[75:76]
migrates := migrate.NewMigrate(c.Database("test"), migrations...)
err := migrates.Up(context.Background(), migrate.AllAvailable)
if err != nil {
return err
}

query := c.
Database("test").
Collection("namespaces").
FindOne(context.TODO(), bson.M{"tenant_id": "00000000-0000-4000-0000-000000000000"})

ns := new(models.Namespace)
if err := query.Decode(ns); err != nil {
return errors.New("unable to find the namespace")
}

if ns.VPN == nil {
return errors.New("unable to apply the migration")
}

return nil
},
}, {
description: "Success to unapply the migration 76",
setup: func() error {
_, err := c.
Database("test").
Collection("namespaces").
InsertOne(ctx, models.Namespace{
TenantID: "00000000-0000-4000-0000-000000000000",
Settings: &models.NamespaceSettings{},
})

return err
},
requireMocks: func() {},
test: func() error {
migrations := GenerateMigrations()[75:76]
migrates := migrate.NewMigrate(c.Database("test"), migrations...)
err := migrates.Down(context.Background(), migrate.AllAvailable)
if err != nil {
return err
}

query := c.
Database("test").
Collection("namespaces").
FindOne(context.TODO(), bson.M{"tenant_id": "00000000-0000-4000-0000-000000000000"})

ns := new(models.Namespace)
if err := query.Decode(ns); err != nil {
return errors.New("unable to find the namespace")
}

if ns.VPN != nil {
return errors.New("unable to unapply the migration")
}

return nil
},
},
}

for _, test := range cases {
tc := test
t.Run(tc.description, func(t *testing.T) {
tc.requireMocks()

t.Cleanup(func() {
assert.NoError(t, srv.Reset())
})

assert.NoError(t, tc.setup())
assert.NoError(t, tc.test())

envMock.AssertExpectations(t)
})
}
}

0 comments on commit 6edbde8

Please sign in to comment.