Skip to content

Commit

Permalink
feat(pkg,agent): allow agent stop to listening
Browse files Browse the repository at this point in the history
  • Loading branch information
henrybarreto committed Sep 21, 2023
1 parent 367c240 commit fbb157c
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 36 deletions.
41 changes: 30 additions & 11 deletions agent/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,33 +119,38 @@ func main() {
}).Fatal("Failed to initialize agent")
}

listing := make(chan bool)
ctx := cmd.Context()

listening := make(chan bool)
go func() {
<-listing
<-listening
// NOTICE: We only start to ping the server when the agent is ready to accept connections.
// It will make the agent ping to server after the ticker time set on ping function, what is 10 minutes by
// default.

ping := make(chan agent.Ping)
go ag.Ping(nil, ping)

for range ping {
log.WithFields(log.Fields{
if err := ag.Ping(ctx, nil); err != nil {
log.WithError(err).WithFields(log.Fields{
"version": AgentVersion,
"mode": mode,
"tenant_id": cfg.TenantID,
"server_address": cfg.ServerAddress,
"timestamp": time.Now(),
}).Info("ping")
}).Fatal("Failed to ping server")
}

log.WithFields(log.Fields{
"version": AgentVersion,
"mode": mode,
"tenant_id": cfg.TenantID,
"server_address": cfg.ServerAddress,
}).Info("Stopped pinging server")
}()

log.WithFields(log.Fields{
"version": AgentVersion,
"mode": mode,
"tenant_id": cfg.TenantID,
"server_address": cfg.ServerAddress,
}).Info("listening for connections")
}).Info("Listening for connections")

// Disable check update in development mode
if AgentVersion != "latest" {
Expand Down Expand Up @@ -178,7 +183,21 @@ func main() {
}()
}

ag.Listen(listing) //nolint:errcheck
if err := ag.Listen(ctx, listening); err != nil {
log.WithError(err).WithFields(log.Fields{
"version": AgentVersion,
"mode": mode,
"tenant_id": cfg.TenantID,
"server_address": cfg.ServerAddress,
}).Fatal("Failed to listen for connections")
}

log.WithFields(log.Fields{
"version": AgentVersion,
"mode": mode,
"tenant_id": cfg.TenantID,
"server_address": cfg.ServerAddress,
}).Info("Stopped listening for connections")
},
}

Expand Down
95 changes: 70 additions & 25 deletions pkg/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
package agent

import (
"context"
"crypto/rsa"
"io"
"net"
Expand All @@ -53,6 +54,7 @@ import (
"os"
"runtime"
"strings"
"sync"
"time"

"github.com/Masterminds/semver"
Expand All @@ -68,12 +70,6 @@ import (
log "github.com/sirupsen/logrus"
)

// Ping is a message sent by the agent to the server to keep the connection alive.
type Ping struct {
// Timestamp is the time the ping was sent or received.
Timestamp time.Time
}

// throw sends a value on a channel, but does not block the goroutine.
func throw[V any, T chan V](ch T, v V) {
ch <- v
Expand Down Expand Up @@ -140,6 +136,9 @@ type Agent struct {
serverAddress *url.URL
sessions []string
server *server.Server
tunnel *tunnel.Tunnel
mux sync.RWMutex
done bool
}

// NewAgent creates a new agent instance.
Expand Down Expand Up @@ -184,6 +183,7 @@ func NewAgentWithConfig(config *Config) (*Agent, error) {
config: config,
serverAddress: serverAddress,
cli: client.NewClient(client.WithURL(serverAddress)),
tunnel: tunnel.NewTunnel(),
}

return a, nil
Expand Down Expand Up @@ -218,6 +218,10 @@ func (a *Agent) Initialize() error {
return errors.Wrap(err, "failed to authorize device")
}

a.mux.Lock()
a.done = true
a.mux.Unlock()

return nil
}

Expand Down Expand Up @@ -312,17 +316,24 @@ func (a *Agent) NewReverseListener() (*revdial.Listener, error) {
return a.cli.NewReverseListener(a.authData.Token)
}

func (a *Agent) Close() error {
a.mux.Lock()
a.done = true
a.mux.Unlock()

return a.tunnel.Close()
}

// Listen creates a new SSH server, tunnel to ShellHub and listen for incoming connections.
//
// listening parameter is a channel that is notified when the agent is listing for connections. It can be used to
// start to ping the server, synchronizing device information or other tasks.
func (a *Agent) Listen(listining chan bool) error {
func (a *Agent) Listen(ctx context.Context, listining chan bool) error {
a.server = server.NewServer(a.cli, a.authData, a.config.PrivateKey, a.config.KeepAliveInterval, a.config.SingleUserPassword)

serv := a.server

tun := tunnel.NewTunnel()
tun.ConnHandler = func(c echo.Context) error {
a.tunnel.ConnHandler = func(c echo.Context) error {
hj, ok := c.Response().Writer.(http.Hijacker)
if !ok {
return c.String(http.StatusInternalServerError, "webserver doesn't support hijacking")
Expand All @@ -343,7 +354,7 @@ func (a *Agent) Listen(listining chan bool) error {
return nil
}

tun.HTTPHandler = func(c echo.Context) error {
a.tunnel.HTTPHandler = func(c echo.Context) error {
replyError := func(err error, msg string, code int) error {
log.WithError(err).WithFields(log.Fields{
"remote": c.Request().RemoteAddr,
Expand Down Expand Up @@ -388,7 +399,7 @@ func (a *Agent) Listen(listining chan bool) error {
return nil
}

tun.CloseHandler = func(c echo.Context) error {
a.tunnel.CloseHandler = func(c echo.Context) error {
id := c.Param("id")
serv.CloseSession(id)

Expand All @@ -397,7 +408,25 @@ func (a *Agent) Listen(listining chan bool) error {

serv.SetDeviceName(a.authData.Name)

// NOTICE(r): when context is canceled, the agent will close the tunnel and stop listening for connections.
go func() {
<-ctx.Done()
a.Close() //nolint:errcheck
}()

for {
a.mux.RLock()
if a.done {
log.WithFields(log.Fields{
"version": AgentVersion,
"tenant_id": a.authData.Namespace,
"server_address": a.config.ServerAddress,
}).Debug("stopped listening for connections")

return nil
}
a.mux.RUnlock()

listener, err := a.NewReverseListener()
if err != nil {
time.Sleep(time.Second * 10)
Expand All @@ -424,7 +453,7 @@ func (a *Agent) Listen(listining chan bool) error {
}).Info("Server connection established")

throw(listining, true)
if err := tun.Listen(listener); err != nil {
if err := a.tunnel.Listen(listener); err != nil {
continue
}
throw(listining, false)
Expand All @@ -436,24 +465,40 @@ func (a *Agent) Listen(listining chan bool) error {
// If the ticker is nil, it will be set to 10 minutes.
//
// ping parameter is a channel that is notified when the agent pings the server.
func (a *Agent) Ping(ticker *time.Ticker, ping chan Ping) {
func (a *Agent) Ping(ctx context.Context, ticker *time.Ticker) error {
if ticker == nil {
ticker = time.NewTicker(10 * time.Minute)
}

for range ticker.C {
sessions := make([]string, 0, len(a.server.Sessions))
for key := range a.server.Sessions {
sessions = append(sessions, key)
}

a.sessions = sessions

if err := a.authorize(); err != nil {
a.server.SetDeviceName(a.authData.Name)
for {
select {
case <-ctx.Done():
log.WithFields(log.Fields{
"version": AgentVersion,
"tenant_id": a.authData.Namespace,
"server_address": a.config.ServerAddress,
}).Debug("stopped pinging server due to context cancellation")

return nil
case <-ticker.C:
sessions := make([]string, 0, len(a.server.Sessions))
for key := range a.server.Sessions {
sessions = append(sessions, key)
}

a.sessions = sessions

if err := a.authorize(); err != nil {
a.server.SetDeviceName(a.authData.Name)
}

log.WithFields(log.Fields{
"version": AgentVersion,
"tenant_id": a.authData.Namespace,
"server_address": a.config.ServerAddress,
"timestamp": time.Now(),
}).Info("Ping")
}

throw(ping, Ping{Timestamp: time.Now()})
}
}

Expand Down
9 changes: 9 additions & 0 deletions pkg/agent/pkg/tunnel/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,12 @@ func NewTunnel() *Tunnel {
func (t *Tunnel) Listen(l *revdial.Listener) error {
return t.srv.Serve(l)
}

// Close closes the tunnel.
func (t *Tunnel) Close() error {
if err := t.router.Close(); err != nil {
return err
}

return t.srv.Close()
}

0 comments on commit fbb157c

Please sign in to comment.