diff --git a/agent/go.mod b/agent/go.mod index 3a258d8d3c3..798dad57db1 100644 --- a/agent/go.mod +++ b/agent/go.mod @@ -48,7 +48,10 @@ require ( github.com/pkg/errors v0.9.1 // indirect github.com/pkg/sftp v1.13.5 // indirect github.com/sethvargo/go-envconfig v0.9.0 // indirect + github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 // indirect github.com/spf13/pflag v1.0.5 // indirect + github.com/vishvananda/netlink v1.2.1-beta.2 // indirect + github.com/vishvananda/netns v0.0.0-20210104183010-2eb08e3e575f // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect go.opentelemetry.io/otel v1.26.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.26.0 // indirect diff --git a/agent/go.sum b/agent/go.sum index 341adc2c9ea..c5be11d9733 100644 --- a/agent/go.sum +++ b/agent/go.sum @@ -98,6 +98,8 @@ github.com/shellhub-io/ssh v0.0.0-20230224143412-edd48dfd6eea h1:7tEI9nukSYZViCj github.com/shellhub-io/ssh v0.0.0-20230224143412-edd48dfd6eea/go.mod h1:8XB4KraRrX39qHhT6yxPsHedjA08I/uBVwj4xC+/+z4= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8= +github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E= github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= @@ -118,6 +120,11 @@ github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyC github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo= github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= +github.com/vishvananda/netlink v1.2.1-beta.2 h1:Llsql0lnQEbHj0I1OuKyp8otXp0r3q0mPkuhwHfStVs= +github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho= +github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= +github.com/vishvananda/netns v0.0.0-20210104183010-2eb08e3e575f h1:p4VB7kIXpOQvVn1ZaTIVp+3vuYAXFe3OJEvjbUYJLaA= +github.com/vishvananda/netns v0.0.0-20210104183010-2eb08e3e575f/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 h1:Xs2Ncz0gNihqu9iosIZ5SkBbWo5T8JhhLJFMQL1qmLI= @@ -159,6 +166,8 @@ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200728102440-3e129f6d46b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/agent/main.go b/agent/main.go index 657aff93733..177779cfc98 100644 --- a/agent/main.go +++ b/agent/main.go @@ -109,7 +109,7 @@ func main() { "tenant_id": cfg.TenantID, "server_address": cfg.ServerAddress, "preferred_hostname": cfg.PreferredHostname, - }).Info("Listening for connections") + }).Info("Listening for SSH connections") // Disable check update in development mode if AgentVersion != "latest" { @@ -163,15 +163,39 @@ func main() { }() } - if err := ag.ListenSSH(ctx); err != nil { - log.WithError(err).WithFields(log.Fields{ - "version": AgentVersion, - "mode": mode, - "tenant_id": cfg.TenantID, - "server_address": cfg.ServerAddress, - "preferred_hostname": cfg.PreferredHostname, - }).Fatal("Failed to listen for SSH connections") - } + go func() { + if err := ag.ListenSSH(ctx); err != nil { + log.WithError(err).WithFields(log.Fields{ + "version": AgentVersion, + "mode": mode, + "tenant_id": cfg.TenantID, + "server_address": cfg.ServerAddress, + "preferred_hostname": cfg.PreferredHostname, + }).Fatal("Failed to listen for SSH connections") + } + }() + + go func() { + if !cfg.VPN { + log.Info("VPN is disable") + + return + } + + log.Debug("VPN enabled") + + for { + log.Info("VPN connected") + + if err := ag.ConnectVPN(ctx); err != nil { + log.WithError(err).Error("Connect to VPN lost. Retrying in 10 seconds.") + } + + time.Sleep(10 * time.Second) + } + }() + + <-ctx.Done() log.WithFields(log.Fields{ "version": AgentVersion, @@ -179,7 +203,7 @@ func main() { "tenant_id": cfg.TenantID, "server_address": cfg.ServerAddress, "preferred_hostname": cfg.PreferredHostname, - }).Info("Stopped listening for connections") + }).Info("Agent Stopped") }, } diff --git a/api/services/errors.go b/api/services/errors.go index 8cef919b130..5d65eb7807d 100644 --- a/api/services/errors.go +++ b/api/services/errors.go @@ -130,6 +130,8 @@ var ( ErrAPIKeyDuplicated = errors.New("APIKey duplicated", ErrLayer, ErrCodeDuplicated) ErrAuthForbidden = errors.New("user is authenticated but cannot access this resource", ErrLayer, ErrCodeForbidden) ErrRoleInvalid = errors.New("role is invalid", ErrLayer, ErrCodeForbidden) + ErrNamespaceIPInvalid = errors.New("ip is invalid", ErrLayer, ErrCodeForbidden) + ErrNamespaceIPNotPrivate = errors.New("ip is not a private address", ErrLayer, ErrCodeForbidden) ) func NewErrRoleInvalid() error { @@ -471,3 +473,11 @@ func NewErrDeviceMaxDevicesReached(count int) error { func NewErrAuthForbidden() error { return NewErrForbidden(ErrAuthForbidden, nil) } + +func NewErrNamespaceIPInvalid() error { + return NewErrInvalid(ErrNamespaceIPInvalid, nil, nil) +} + +func NewErrNamespaceIPNotPrivate() error { + return NewErrInvalid(ErrNamespaceIPNotPrivate, nil, nil) +} diff --git a/api/services/namespace.go b/api/services/namespace.go index 8b1b62c810e..dfaaf3f8787 100644 --- a/api/services/namespace.go +++ b/api/services/namespace.go @@ -3,11 +3,13 @@ package services import ( "context" "errors" + "net" "strings" "github.com/shellhub-io/shellhub/api/store" "github.com/shellhub-io/shellhub/api/store/mongo" "github.com/shellhub-io/shellhub/pkg/api/authorizer" + "github.com/shellhub-io/shellhub/pkg/api/internalclient" "github.com/shellhub-io/shellhub/pkg/api/requests" "github.com/shellhub-io/shellhub/pkg/clock" "github.com/shellhub-io/shellhub/pkg/envs" @@ -204,6 +206,27 @@ func (s *service) EditNamespace(ctx context.Context, req *requests.NamespaceEdit ConnectionAnnouncement: req.Settings.ConnectionAnnouncement, } + if envs.IsEnterprise() { + changes.VPNEnable = req.VPN.Enable + + if req.VPN.Address != nil { + address := *req.VPN.Address + ip := net.IPv4(address[0], address[1], address[2], address[3]) + + if ip.IsLoopback() || ip.IsUnspecified() { + return nil, NewErrNamespaceIPInvalid() + } + + if !ip.IsPrivate() { + return nil, NewErrNamespaceIPNotPrivate() + } + + changes.VPNAddress = &address + } + + changes.VPNMask = req.VPN.Mask + } + if err := s.store.NamespaceEdit(ctx, req.Tenant, changes); err != nil { switch { case errors.Is(err, store.ErrNoDocuments): @@ -213,6 +236,14 @@ func (s *service) EditNamespace(ctx context.Context, req *requests.NamespaceEdit } } + if envs.IsEnterprise() { + cli := s.client.(internalclient.Client) + + if err := cli.VPNStopRouter(req.Tenant); err != nil { + return nil, err + } + } + return s.store.NamespaceGet(ctx, req.Tenant, true) } diff --git a/api/services/namespace_test.go b/api/services/namespace_test.go index e0380bf2318..c8d76fd1dc9 100644 --- a/api/services/namespace_test.go +++ b/api/services/namespace_test.go @@ -937,6 +937,7 @@ func TestEditNamespace(t *testing.T) { tenantID: "xxxxx", namespaceName: "newname", requiredMocks: func() { + envMock.On("Get", "SHELLHUB_ENTERPRISE").Return("false").Once() mock.On("NamespaceEdit", ctx, "xxxxx", &models.NamespaceChanges{Name: "newname"}). Return(store.ErrNoDocuments). Once() @@ -951,6 +952,7 @@ func TestEditNamespace(t *testing.T) { tenantID: "xxxxx", namespaceName: "newname", requiredMocks: func() { + envMock.On("Get", "SHELLHUB_ENTERPRISE").Return("false").Once() mock.On("NamespaceEdit", ctx, "xxxxx", &models.NamespaceChanges{Name: "newname"}). Return(errors.New("error")). Once() @@ -965,6 +967,7 @@ func TestEditNamespace(t *testing.T) { namespaceName: "newName", tenantID: "xxxxx", requiredMocks: func() { + envMock.On("Get", "SHELLHUB_ENTERPRISE").Return("false").Once() mock.On("NamespaceEdit", ctx, "xxxxx", &models.NamespaceChanges{Name: "newname"}). Return(nil). Once() @@ -991,6 +994,7 @@ func TestEditNamespace(t *testing.T) { namespaceName: "newname", tenantID: "xxxxx", requiredMocks: func() { + envMock.On("Get", "SHELLHUB_ENTERPRISE").Return("false").Once() mock.On("NamespaceEdit", ctx, "xxxxx", &models.NamespaceChanges{Name: "newname"}). Return(nil). Once() @@ -1000,6 +1004,7 @@ func TestEditNamespace(t *testing.T) { Name: "newname", } + envMock.On("Get", "SHELLHUB_ENTERPRISE").Return("false").Once() mock.On("NamespaceGet", ctx, "xxxxx", true). Return(namespace, nil). Once() diff --git a/api/store/mongo/migrations/main.go b/api/store/mongo/migrations/main.go index a0d8487261e..64a3ed9d347 100644 --- a/api/store/mongo/migrations/main.go +++ b/api/store/mongo/migrations/main.go @@ -86,6 +86,7 @@ func GenerateMigrations() []migrate.Migration { migration74, migration75, migration76, + migration77, } } diff --git a/api/store/mongo/migrations/migration_77.go b/api/store/mongo/migrations/migration_77.go new file mode 100644 index 00000000000..4007b2e40fa --- /dev/null +++ b/api/store/mongo/migrations/migration_77.go @@ -0,0 +1,64 @@ +package migrations + +import ( + "context" + + "github.com/shellhub-io/shellhub/pkg/envs" + "github.com/sirupsen/logrus" + migrate "github.com/xakep666/mongo-migrate" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" +) + +var migration77 = migrate.Migration{ + Version: 77, + Description: "Adding VPN settings to namespace", + Up: migrate.MigrationFunc(func(ctx context.Context, db *mongo.Database) error { + logrus.WithFields(logrus.Fields{ + "component": "migration", + "version": 77, + "action": "Up", + }).Info("Applying migration") + + if envs.IsEnterprise() { + update := bson.M{ + "$set": bson.M{ + "vpn": bson.M{ + "enable": false, + "address": bson.A{10, 0, 0, 0}, + "mask": 16, + }, + }, + } + + _, err := db. + Collection("namespaces"). + UpdateMany(ctx, bson.M{}, update) + + return err + } + + return nil + }), + Down: migrate.MigrationFunc(func(ctx context.Context, db *mongo.Database) error { + logrus.WithFields(logrus.Fields{ + "component": "migration", + "version": 77, + "action": "Down", + }).Info("Reverting migration") + + if envs.IsEnterprise() { + update := bson.M{ + "$unset": bson.M{"vpn": ""}, + } + + _, err := db. + Collection("namespaces"). + UpdateMany(ctx, bson.M{}, update) + + return err + } + + return nil + }), +} diff --git a/api/store/mongo/migrations/migration_77_test.go b/api/store/mongo/migrations/migration_77_test.go new file mode 100644 index 00000000000..32cba1c37c4 --- /dev/null +++ b/api/store/mongo/migrations/migration_77_test.go @@ -0,0 +1,123 @@ +package migrations + +import ( + "context" + "errors" + "testing" + + "github.com/shellhub-io/shellhub/pkg/envs" + env_mocks "github.com/shellhub-io/shellhub/pkg/envs/mocks" + "github.com/shellhub-io/shellhub/pkg/models" + "github.com/stretchr/testify/assert" + migrate "github.com/xakep666/mongo-migrate" + "go.mongodb.org/mongo-driver/bson" +) + +func TestMigration77(t *testing.T) { + ctx := context.Background() + + envMock = &env_mocks.Backend{} + envs.DefaultBackend = envMock + + cases := []struct { + description string + setup func() error + requireMocks func() + test func() error + }{ + { + description: "Success to apply up on migration 77", + setup: func() error { + _, err := c. + Database("test"). + Collection("namespaces"). + InsertOne(ctx, models.Namespace{ + TenantID: "00000000-0000-4000-0000-000000000000", + }) + + return err + }, + requireMocks: func() { + envMock.On("Get", "SHELLHUB_ENTERPRISE").Return("true").Once() + }, + test: func() error { + migrations := GenerateMigrations()[76:77] + migrates := migrate.NewMigrate(c.Database("test"), migrations...) + err := migrates.Up(context.Background(), migrate.AllAvailable) + if err != nil { + return err + } + + query := c. + Database("test"). + Collection("namespaces"). + FindOne(context.TODO(), bson.M{"tenant_id": "00000000-0000-4000-0000-000000000000"}) + + ns := new(models.Namespace) + if err := query.Decode(ns); err != nil { + return errors.New("unable to find the namespace") + } + + if ns.VPN == nil { + return errors.New("unable to apply the migration") + } + + return nil + }, + }, { + description: "Success to unapply the migration 77", + setup: func() error { + _, err := c. + Database("test"). + Collection("namespaces"). + InsertOne(ctx, models.Namespace{ + TenantID: "00000000-0000-4000-0000-000000000000", + Settings: &models.NamespaceSettings{}, + }) + + return err + }, + requireMocks: func() {}, + test: func() error { + migrations := GenerateMigrations()[76:77] + migrates := migrate.NewMigrate(c.Database("test"), migrations...) + err := migrates.Down(context.Background(), migrate.AllAvailable) + if err != nil { + return err + } + + query := c. + Database("test"). + Collection("namespaces"). + FindOne(context.TODO(), bson.M{"tenant_id": "00000000-0000-4000-0000-000000000000"}) + + ns := new(models.Namespace) + if err := query.Decode(ns); err != nil { + return errors.New("unable to find the namespace") + } + + if ns.VPN != nil { + return errors.New("unable to unapply the migration") + } + + return nil + }, + }, + } + + for _, test := range cases { + tc := test + t.Run(tc.description, func(t *testing.T) { + tc.requireMocks() + + t.Cleanup(func() { + assert.NoError(t, srv.Reset()) + }) + + assert.NoError(t, tc.setup()) + assert.NoError(t, tc.test()) + + envMock.AssertExpectations(t) + }) + } +} diff --git a/docker-compose.agent.yml b/docker-compose.agent.yml index c2b5311b791..d9f0fc681ea 100644 --- a/docker-compose.agent.yml +++ b/docker-compose.agent.yml @@ -19,6 +19,7 @@ services: - SHELLHUB_PRIVATE_KEY=/go/src/github.com/shellhub-io/shellhub/agent/shellhub.key - SHELLHUB_TENANT_ID=00000000-0000-4000-0000-000000000000 - SHELLHUB_VERSION=${SHELLHUB_VERSION} + - SHELLHUB_VPN=${SHELLHUB_VPN} - SHELLHUB_LOG_LEVEL=${SHELLHUB_LOG_LEVEL} - SHELLHUB_LOG_FORMAT=${SHELLHUB_LOG_FORMAT} volumes: diff --git a/go.mod b/go.mod index 161e49da273..770b862253f 100644 --- a/go.mod +++ b/go.mod @@ -27,8 +27,10 @@ require ( github.com/pkg/sftp v1.13.5 github.com/sethvargo/go-envconfig v0.9.0 github.com/sirupsen/logrus v1.9.3 + github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 github.com/stretchr/testify v1.9.0 github.com/testcontainers/testcontainers-go/modules/redis v0.32.0 + github.com/vishvananda/netlink v1.2.1-beta.2 golang.org/x/crypto v0.22.0 golang.org/x/sys v0.19.0 ) @@ -96,6 +98,7 @@ require ( github.com/ulikunitz/xz v0.5.11 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasttemplate v1.2.2 // indirect + github.com/vishvananda/netns v0.0.0-20210104183010-2eb08e3e575f // indirect github.com/vmihailenco/go-tinylfu v0.2.2 // indirect github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index 8c4bfc08193..a8928db850d 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -50,6 +50,7 @@ import ( "github.com/shellhub-io/shellhub/pkg/agent/pkg/keygen" "github.com/shellhub-io/shellhub/pkg/agent/pkg/sysinfo" "github.com/shellhub-io/shellhub/pkg/agent/ssh" + "github.com/shellhub-io/shellhub/pkg/agent/vpn" "github.com/shellhub-io/shellhub/pkg/api/client" "github.com/shellhub-io/shellhub/pkg/envs" "github.com/shellhub-io/shellhub/pkg/models" @@ -114,6 +115,9 @@ type Config struct { // MaxRetryConnectionTimeout specifies the maximum time, in seconds, that an agent will wait // before attempting to reconnect to the ShellHub server. Default is 60 seconds. MaxRetryConnectionTimeout int `env:"MAX_RETRY_CONNECTION_TIMEOUT,default=60" validate:"min=10,max=120"` + + // Defines if the device will try to connect to the namespace's VPN. + VPN bool `env:"VPN,default=false"` } func LoadConfigFromEnv() (*Config, map[string]interface{}, error) { @@ -163,6 +167,7 @@ type Agent struct { serverInfo *models.Info cli client.Client ssh *ssh.SSH + vpn *vpn.VPN mode Mode } @@ -357,6 +362,12 @@ func (a *Agent) ListenSSH(ctx context.Context) error { return a.ssh.Listen(ctx) } +func (a *Agent) ConnectVPN(ctx context.Context) error { + a.vpn = vpn.NewVPN(a.cli, a.authData.Token) + + return a.vpn.Connect(ctx) +} + // CheckUpdate gets the ShellHub's server version. func (a *Agent) CheckUpdate() (*semver.Version, error) { info, err := a.cli.GetInfo(AgentVersion) diff --git a/pkg/agent/pkg/tunnel/tunnel.go b/pkg/agent/pkg/tunnel/tunnel.go index a17d0549f8a..dd745f8e23b 100644 --- a/pkg/agent/pkg/tunnel/tunnel.go +++ b/pkg/agent/pkg/tunnel/tunnel.go @@ -83,6 +83,38 @@ func NewTunnel() *Tunnel { return t } +const ContextKeyHTTPConn string = "http-conn" + +// NewCustomTunnel creates a new [Tunnel] with the route to the connect, in a POST, and close, in a DELETE, actions. +func NewCustomTunnel(connPath string, closePath string) *Tunnel { + router := echo.New() + + t := &Tunnel{ + router: router, + srv: &http.Server{ + Handler: router, + ConnContext: func(ctx context.Context, c net.Conn) context.Context { + return context.WithValue(ctx, ContextKeyHTTPConn, c) //nolint:revive + }, + }, + ConnHandler: func(e echo.Context) error { + panic("connHandler can not be nil") + }, + CloseHandler: func(e echo.Context) error { + panic("closeHandler can not be nil") + }, + } + + router.POST(connPath, func(e echo.Context) error { + return t.ConnHandler(e) + }) + router.DELETE(closePath, func(e echo.Context) error { + return t.CloseHandler(e) + }) + + return t +} + // Listen to reverse listener. func (t *Tunnel) Listen(l *revdial.Listener) error { return t.srv.Serve(l) diff --git a/pkg/agent/vpn/handlers.go b/pkg/agent/vpn/handlers.go new file mode 100644 index 00000000000..03a4c183995 --- /dev/null +++ b/pkg/agent/vpn/handlers.go @@ -0,0 +1,49 @@ +package vpn + +import ( + "net" + + "github.com/labstack/echo/v4" + log "github.com/sirupsen/logrus" +) + +func handler(handler func(net.Conn, *Settings) error) func(c echo.Context) error { + return func(c echo.Context) error { + log.Debug("handler started") + defer log.Debug("handler done") + + conn, _, err := c.Response().Hijack() + if err != nil { + log.Error(err) + + return err + } + + defer conn.Close() + + settings, err := ParseSettings(c.Request().Body) + if err != nil { + log.WithError(err).Error("faild to parse the settings") + + return err + } + + // NOTE: the [handler] is called to handler the core logic of the VPN client, while this handler is used to extract + // the connection and the settings data. + if err := handler(conn, settings); err != nil { + log.WithError(err).Error("failed to handler the vpn connection between server and agent") + + return err + } + + return nil + } +} + +func closeHandler(callback func() error) func(c echo.Context) error { + return func(c echo.Context) error { + log.Trace("close handler called") + + return callback() + } +} diff --git a/pkg/agent/vpn/pkg/ifce/interface.go b/pkg/agent/vpn/pkg/ifce/interface.go new file mode 100644 index 00000000000..b0d344548e9 --- /dev/null +++ b/pkg/agent/vpn/pkg/ifce/interface.go @@ -0,0 +1,110 @@ +package ifce + +import ( + "errors" + "fmt" + + "github.com/songgao/water" + "github.com/vishvananda/netlink" +) + +// MaximumTransmissionUnit (MTU) is the size of the largest protocol data unit (PDU) that can be communicated in a +// single network layer transaction. +const MaximumTransmissionUnit = 1000 + +const InterfaceName = "shb" + +var ErrGenerateInterface = errors.New("failed to generate an interface") + +func generateInterfaceName() (string, error) { + const attempts = 1000 + + for i := 0; i < attempts; i++ { + name := fmt.Sprintf("%s%d", InterfaceName, i) + + if _, err := netlink.LinkByName(name); err == nil { + continue + } + + return name, nil + } + + return "", ErrGenerateInterface +} + +type Interface struct { + face *water.Interface + link netlink.Link +} + +var ( + ErrInterfaceCreate = errors.New("failed to create the interface") + ErrInterfaceUp = errors.New("failed to get up the interface") + ErrInterfaceConfiguration = errors.New("failed to configure the interface") + ErrInterfaceMTU = errors.New("failed to configure the MTU") +) + +func NewInterface(addrr string) (*Interface, error) { + name, err := generateInterfaceName() + if err != nil { + return nil, err + } + + iface, err := water.New(water.Config{ + DeviceType: water.TUN, + PlatformSpecificParams: water.PlatformSpecificParams{ + Name: name, + MultiQueue: true, + }, + }) + if err != nil { + return nil, errors.Join(ErrInterfaceCreate, err) + } + + addr, err := netlink.ParseAddr(addrr) + if err != nil { + return nil, errors.Join(ErrInterfaceConfiguration, err) + } + + link, err := netlink.LinkByName(iface.Name()) + if err != nil { + return nil, errors.Join(ErrInterfaceConfiguration, err) + } + + if err := netlink.AddrAdd(link, addr); err != nil { + return nil, errors.Join(ErrInterfaceConfiguration, err) + } + + if err := netlink.LinkSetMTU(link, MaximumTransmissionUnit); err != nil { + return nil, errors.Join(ErrInterfaceMTU, err) + } + + return &Interface{ + face: iface, + link: link, + }, nil +} + +func (i *Interface) Up() error { + if err := netlink.LinkSetUp(i.link); err != nil { + return errors.Join(ErrInterfaceUp, err) + } + + return nil +} + +func (i *Interface) Name() string { + return i.face.Name() +} + +func (i *Interface) Close() error { + return i.face.Close() +} + +func (i *Interface) Read(buffer []byte) (int, error) { + return i.face.Read(buffer) +} + +func (i *Interface) Write(buffer []byte) (int, error) { + return i.face.Write(buffer) +} diff --git a/pkg/agent/vpn/pkg/packets/packets.go b/pkg/agent/vpn/pkg/packets/packets.go new file mode 100644 index 00000000000..1e98d08a8e1 --- /dev/null +++ b/pkg/agent/vpn/pkg/packets/packets.go @@ -0,0 +1,228 @@ +package packets + +import ( + "encoding/binary" + "net" +) + +const ( + Multicast = "multicast" + Unicast = "unicast" +) + +// PacketMaxSize defines the max size of each IP packet received. This value is related to the interface MTU. +const PacketMaxSize = 1024 + +// IPv4 packet. +// +// https://en.wikipedia.org/wiki/IPv4#Header +type Packet struct { + Version uint8 // Version of the IP protocol (typically 4 for IPv4) + IHL uint8 // Internet Header Length (in 32-bit words) + TOS uint8 // Type of Service + TotalLength uint16 // Total length of the packet (header + data) + Identification uint16 // Identification field for packet fragmentation + Flags uint8 // Flags for fragmentation control + FragmentOffset uint16 // Fragment offset + TTL uint8 // Time To Live + Protocol uint8 // Protocol (e.g., TCP, UDP) + HeaderChecksum uint16 // Checksum of the header + Source [4]byte // Source IP address + Destination [4]byte // Destination IP address + Options []byte // Optional fields (if any) + Payload []byte // Payload (data) +} + +// Protocols stores a map of [byte] for string matching the protocol's name. +// +// https://en.wikipedia.org/wiki/List_of_IP_protocol_numbers +var Protocols = map[byte]string{ + 0x00: "HOPOPT", + 0x01: "ICMP", + 0x02: "IGMP", + 0x03: "GGP", + 0x04: "IP-in-IP", + 0x05: "ST", + 0x06: "TCP", + 0x07: "CBT", + 0x08: "EGP", + 0x09: "IGP", + 0x0A: "BBN-RCC-MON", + 0x0B: "NVP-II", + 0x0C: "PUP", + 0x0D: "ARGUS", + 0x0E: "EMCON", + 0x0F: "XNET", + 0x10: "CHAOS", + 0x11: "UDP", + 0x12: "MUX", + 0x13: "DCN-MEAS", + 0x14: "HMP", + 0x15: "PRM", + 0x16: "XNS-IDP", + 0x17: "TRUNK-1", + 0x18: "TRUNK-2", + 0x19: "LEAF-1", + 0x1A: "LEAF-2", + 0x1B: "RDP", + 0x1C: "IRTP", + 0x1D: "ISO-TP4", + 0x1E: "NETBLT", + 0x1F: "MFE-NSP", + 0x20: "MERIT-INP", + 0x21: "DCCP", + 0x22: "3PC", + 0x23: "IDPR", + 0x24: "XTP", + 0x25: "DDP", + 0x26: "IDPR-CMTP", + 0x27: "TP++", + 0x28: "IL", + 0x29: "IPv6", + 0x2A: "SDRP", + 0x2B: "IPv6-Route", + 0x2C: "IPv6-Frag", + 0x2D: "IDRP", + 0x2E: "RSVP", + 0x2F: "GRE", + 0x30: "DSR", + 0x31: "BNA", + 0x32: "ESP", + 0x33: "AH", + 0x34: "I-NLSP", + 0x35: "SwIPe", + 0x36: "NARP", + 0x37: "MOBILE", + 0x38: "TLSP", + 0x39: "SKIP", + 0x3A: "IPv6-ICMP", + 0x3B: "IPv6-NoNxt", + 0x3C: "IPv6-Opts", + 0x3D: "Any host internal protocol", + 0x3E: "CFTP", + 0x3F: "Any local network", + 0x40: "SAT-EXPAK", + 0x41: "KRYPTOLAN", + 0x42: "RVD", + 0x43: "IPPC", + 0x44: "Any distributed file system", + 0x45: "SAT-MON", + 0x46: "VISA", + 0x47: "IPCU", + 0x48: "CPNX", + 0x49: "CPHB", + 0x4A: "WSN", + 0x4B: "PVP", + 0x4C: "BR-SAT-MON", + 0x4D: "SUN-ND", + 0x4E: "WB-MON", + 0x4F: "WB-EXPAK", + 0x50: "ISO-IP", + 0x51: "VMTP", + 0x52: "SECURE-VMTP", + 0x53: "VINES", + 0x54: "TTP/IPTM", + 0x55: "NSFNET-IGP", + 0x56: "DGP", + 0x57: "TCF", + 0x58: "EIGRP", + 0x59: "OSPF", + 0x5A: "Sprite-RPC", + 0x5B: "LARP", + 0x5C: "MTP", + 0x5D: "AX.25", + 0x5E: "OS", + 0x5F: "MICP", + 0x60: "SCC-SP", + 0x61: "ETHERIP", + 0x62: "ENCAP", + 0x63: "Any private encryption scheme", + 0x64: "GMTP", + 0x65: "IFMP", + 0x66: "PNNI", + 0x67: "PIM", + 0x68: "ARIS", + 0x69: "SCPS", + 0x6A: "QNX", + 0x6B: "A/N", + 0x6C: "IPComp", + 0x6D: "SNP", + 0x6E: "Compaq-Peer", + 0x6F: "IPX-in-IP", + 0x70: "VRRP", + 0x71: "PGM", + 0x72: "Any 0-hop protocol", + 0x73: "L2TP", + 0x74: "DDX", + 0x75: "IATP", + 0x76: "STP", + 0x77: "SRP", + 0x78: "UTI", + 0x79: "SMP", + 0x7A: "SM", + 0x7B: "PTP", + 0x7C: "IS-IS over IPv4", + 0x7D: "FIRE", + 0x7E: "CRTP", + 0x7F: "CRUDP", + 0x80: "SSCOPMCE", + 0x81: "IPLT", + 0x82: "SPS", + 0x83: "PIPE", + 0x84: "SCTP", + 0x85: "FC", + 0x86: "RSVP-E2E-IGNORE", + 0x87: "Mobility Header", + 0x88: "UDPLite", + 0x89: "MPLS-in-IP", + 0x8A: "manet", + 0x8B: "HIP", + 0x8C: "Shim6", + 0x8D: "WESP", + 0x8E: "ROHC", + 0x8F: "Ethernet", + 0x90: "AGGFRAG", + 0x91: "NSH", + 0x92: "Unassigned", + 0xFD: "Use for experimentation and testing", + 0xFE: "Use for experimentation and testing", + 0xFF: "Reserved", +} + +func Protocol(buffer []byte) string { + if len(buffer) < 9 { + return "Invalid" + } + + if protocol, ok := Protocols[buffer[9]]; ok { + return protocol + } + + return "Unassigned" +} + +func Source(buffer []byte) [4]byte { + return [4]byte{buffer[12], buffer[13], buffer[14], buffer[15]} +} + +func Destination(buffer []byte) [4]byte { + return [4]byte{buffer[16], buffer[17], buffer[18], buffer[19]} +} + +func Length(buffer []byte) int { + return int(binary.BigEndian.Uint16(buffer[2:4])) +} + +func TimeToLive(buffer []byte) int { + return int(binary.BigEndian.Uint16(buffer[8:9])) +} + +// TODO: try don't use the [net.IPv4] or use it everytime. +func IsMulticast(buffer []byte) bool { + return net.IPv4(buffer[16], buffer[17], buffer[18], buffer[19]).IsMulticast() +} + +// TODO: try don't use the [net.IPv4] or use it everytime. +func IsUnicast(buffer []byte) bool { + return net.IPv4(buffer[16], buffer[17], buffer[18], buffer[19]).IsGlobalUnicast() +} diff --git a/pkg/agent/vpn/settings.go b/pkg/agent/vpn/settings.go new file mode 100644 index 00000000000..6cdf91c2f8c --- /dev/null +++ b/pkg/agent/vpn/settings.go @@ -0,0 +1,35 @@ +package vpn + +import ( + "encoding/json" + "fmt" + "io" + "net" +) + +type Settings struct { + Address [4]byte `json:"address"` + Mask byte `json:"mask"` +} + +// ParseSettings read and parses the [Settings] structure from an [io.Reader]. +func ParseSettings(data io.Reader) (*Settings, error) { + body, err := io.ReadAll(data) + if err != nil { + return nil, err + } + + settings := Settings{} + if err = json.Unmarshal(body, &settings); err != nil { + return nil, err + } + + return &settings, nil +} + +// String converts a [Settings] to a string representation on the format $IP/$Mask. +func (s *Settings) String() string { + ip := net.IPv4(s.Address[0], s.Address[1], s.Address[2], s.Address[3]) + + return fmt.Sprintf("%s/%d", ip.String(), s.Mask) +} diff --git a/pkg/agent/vpn/vpn.go b/pkg/agent/vpn/vpn.go new file mode 100644 index 00000000000..203a1edf9b5 --- /dev/null +++ b/pkg/agent/vpn/vpn.go @@ -0,0 +1,215 @@ +package vpn + +import ( + "context" + "errors" + "io" + "net" + "sync" + + "github.com/shellhub-io/shellhub/pkg/agent/pkg/tunnel" + "github.com/shellhub-io/shellhub/pkg/agent/vpn/pkg/ifce" + "github.com/shellhub-io/shellhub/pkg/agent/vpn/pkg/packets" + "github.com/shellhub-io/shellhub/pkg/api/client" + log "github.com/sirupsen/logrus" +) + +type VPN struct { + // tunnel is the reverse WebSocket connection between Agent and ShellHub's server. + tunnel *tunnel.Tunnel + // httpc is the HTTP client for the ShellHub's server. + httpc client.Client + // token is the JWT token used to operate on ShellHub's server. + // TODO: insert the token into the HTTP client. + token string + // done is a channel used to indicate to the connection handler that the connection was closed. + done chan struct{} +} + +// ConnectEndpoint is used by ShellHub's server to start a new VPN connection with the Agent. +const ConnectEndpoint string = "/vpn/connect" + +// CloseEndpoint is used by ShellHub's server to close a VPN connection. +const CloseEndpoint string = "/vpn/close" + +// NewVPN creates a new instance of VPN client. +func NewVPN(cli client.Client, token string) *VPN { + return &VPN{ + tunnel: tunnel.NewCustomTunnel(ConnectEndpoint, CloseEndpoint), + httpc: cli, + token: token, + done: make(chan struct{}), + } +} + +const ( + MinPacketSize int = 4 + MaxPacketSize = ifce.MaximumTransmissionUnit +) + +// Handler handles the connection established between the ShellHub's server to Agent, starting the packet transmission. +func (s *VPN) Handler(conn net.Conn, settings *Settings) error { + log.Debug("vpn connection accepted") + defer log.Debug("vpn connection closed") + + log.WithFields(log.Fields{ + "addrss": settings.Address, + "mask": settings.Mask, + }).Debug("interface data") + + face, err := ifce.NewInterface(settings.String()) + if err != nil { + log.WithError(err).Error("failed to create or configure the interface") + + return err + } + + defer face.Close() + + log.WithFields(log.Fields{ + "interface": face.Name(), + }).Debug("interface create") + + if err := face.Up(); err != nil { + log.WithError(err).Error("failed to get up the interface") + + return err + } + + log.WithFields(log.Fields{ + "interface": face.Name(), + }).Debug("interface up") + + wg := new(sync.WaitGroup) + + // done closes the connection between the ShellHub's server and the network interface on the Agent. + done := sync.OnceFunc(func() { + log.Trace("conn and ifce connections closed") + + conn.Close() + face.Close() + + s.tunnel.Close() + }) + + go func() { + <-s.done + + log.Trace("message on done channel received") + + done() + }() + + wg.Add(1) + go func() { + defer wg.Done() + defer log.Trace("reading from interface done") + defer done() + + buffer := make([]byte, MaxPacketSize) + + for { + read, err := io.ReadAtLeast(conn, buffer, MinPacketSize) + if err != nil { + log.WithError(err).Debug("failed to read from connection to interface") + + return + } + + if read == 0 { + continue + } + + if read != packets.Length(buffer) { + rest, err := io.ReadAtLeast(conn, buffer[read:], packets.Length(buffer)-read) + if err != nil { + log.WithError(err).Debug("failed to read the rest of data") + + return + } + + read = read + rest + } + + if _, err := face.Write(buffer[:read]); err != nil { + log.WithError(err).Debug("failed to write to interface") + + return + } + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + defer log.Trace("reading from conn done") + defer done() + + buffer := make([]byte, MaxPacketSize) + + for { + read, err := io.ReadAtLeast(face, buffer, MinPacketSize) + if err != nil { + log.WithError(err).Debug("failed to read from interface to connection") + + return + } + + if read == 0 { + continue + } + + if _, err := conn.Write(buffer[:read]); err != nil { + log.WithError(err).Debug("failed to write to connection") + } + } + }() + + log.WithFields(log.Fields{ + "address": settings.String(), + "interface": face.Name(), + }).Info("VPN connection started") + + wg.Wait() + + return nil +} + +// Close closes the ShellHub Agent's listening, stoping it from receive new connection requests. +func (s *VPN) Close() error { + // NOTE: It sends a close message to the handler. + s.done <- struct{}{} + + return s.tunnel.Close() +} + +var ( + ErrConnectListen = errors.New("listen closed on vpn connection") + ErrConnectionReverse = errors.New("reverse connection lost") +) + +func (s *VPN) Connect(ctx context.Context) error { + s.tunnel.ConnHandler = handler(s.Handler) + s.tunnel.CloseHandler = closeHandler(s.Close) + + listener, err := s.httpc.NewReverseListener(ctx, s.token, "/vpn/connection") + if err != nil { + return errors.Join(ErrConnectionReverse, err) + } + + defer listener.Close() + + go func() { + <-ctx.Done() + + log.Trace("message on ctx channel received") + + s.Close() + }() + + if err := s.tunnel.Listen(listener); err != nil { + return errors.Join(ErrConnectListen, err) + } + + return nil +} diff --git a/pkg/api/internalclient/client.go b/pkg/api/internalclient/client.go index 2a89b94dc25..ea99b33099c 100644 --- a/pkg/api/internalclient/client.go +++ b/pkg/api/internalclient/client.go @@ -18,6 +18,7 @@ type Client interface { sessionAPI sshkeyAPI firewallAPI + vpnAPI } type client struct { diff --git a/pkg/api/internalclient/mocks/internalclient.go b/pkg/api/internalclient/mocks/internalclient.go index 0094b640f4c..c3727b33360 100644 --- a/pkg/api/internalclient/mocks/internalclient.go +++ b/pkg/api/internalclient/mocks/internalclient.go @@ -439,6 +439,34 @@ func (_m *Client) UpdateSession(uid string, model *models.SessionUpdate) error { return r0 } +// VPNDeleteNamespaceCache provides a mock function with given fields: tenant +func (_m *Client) VPNDeleteNamespaceCache(tenant string) error { + ret := _m.Called(tenant) + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(tenant) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// VPNStopRouter provides a mock function with given fields: tenant +func (_m *Client) VPNStopRouter(tenant string) error { + ret := _m.Called(tenant) + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(tenant) + } else { + r0 = ret.Error(0) + } + + return r0 +} + type mockConstructorTestingTNewClient interface { mock.TestingT Cleanup(func()) diff --git a/pkg/api/internalclient/vpn.go b/pkg/api/internalclient/vpn.go new file mode 100644 index 00000000000..5406e9cbacd --- /dev/null +++ b/pkg/api/internalclient/vpn.go @@ -0,0 +1,23 @@ +package internalclient + +import "net/http" + +type vpnAPI interface { + // VPNStopRouter sends a rquest to VPN service to stop the namespace router. + VPNStopRouter(tenant string) error +} + +func (c *client) VPNStopRouter(tenant string) error { + res, err := c.http. + R(). + Delete("http://vpn:8080/vpn/router/" + tenant) + if err != nil { + return err + } + + if res.StatusCode() != http.StatusOK { + return err + } + + return nil +} diff --git a/pkg/api/requests/namespace.go b/pkg/api/requests/namespace.go index 0b1fcd13f0a..7662b8863e8 100644 --- a/pkg/api/requests/namespace.go +++ b/pkg/api/requests/namespace.go @@ -51,6 +51,14 @@ type NamespaceEdit struct { SessionRecord *bool `json:"session_record" validate:"omitempty"` ConnectionAnnouncement *string `json:"connection_announcement" validate:"omitempty,min=0,max=4096"` } `json:"settings"` + VPN struct { + // Enable defines if the Virtual Private Network between devices are enabled. + Enable *bool `json:"enable"` + // Address defines the network address. + Address *[4]byte `json:"address"` + // Mask defines the mask of the network. + Mask *byte `json:"mask" validate:"omitempty,min=8,max=24"` + } `json:"vpn"` } type NamespaceAddMember struct { diff --git a/pkg/models/namespace.go b/pkg/models/namespace.go index 2fc3b50964c..04a73f75c1f 100644 --- a/pkg/models/namespace.go +++ b/pkg/models/namespace.go @@ -14,6 +14,16 @@ type Namespace struct { DevicesCount int `json:"devices_count" bson:"devices_count,omitempty"` CreatedAt time.Time `json:"created_at" bson:"created_at"` Billing *Billing `json:"billing" bson:"billing,omitempty"` + VPN *VPN `json:"vpn" bson:"vpn,omitempty"` +} + +type VPN struct { + // Enable defines if the Virtual Private Network between devices are enabled. + Enable bool `json:"enable" bson:"enable"` + // Address defines the network address. + Address [4]byte `json:"address" bson:"address"` + // Mask defines the mask of the network. + Mask byte `json:"mask" bson:"mask"` } // HasMaxDevices checks if the namespace has a maximum number of devices. @@ -53,9 +63,12 @@ type NamespaceSettings struct { } type NamespaceChanges struct { - Name string `bson:"name,omitempty"` - SessionRecord *bool `bson:"settings.session_record,omitempty"` - ConnectionAnnouncement *string `bson:"settings.connection_announcement,omitempty"` + Name string `bson:"name,omitempty"` + SessionRecord *bool `bson:"settings.session_record,omitempty"` + ConnectionAnnouncement *string `bson:"settings.connection_announcement,omitempty"` + VPNEnable *bool `bson:"vpn.enable,omitempty"` + VPNAddress *[4]byte `bson:"vpn.address,omitempty"` + VPNMask *byte `bson:"vpn.mask,omitempty"` } // default Announcement Message for the shellhub namespace