Skip to content

Commit

Permalink
Improve command handler and test failed commands
Browse files Browse the repository at this point in the history
  • Loading branch information
pjcdawkins committed Dec 28, 2024
1 parent b21f18b commit 4987a07
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 29 deletions.
63 changes: 37 additions & 26 deletions go-tests/mockssh/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"errors"
"fmt"
"net"
"strings"
"sync"
"testing"
"time"
Expand All @@ -21,39 +20,45 @@ func newServerConfig(t *testing.T) *ssh.ServerConfig {
return &ssh.ServerConfig{
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
if cert, ok := key.(*ssh.Certificate); ok {
t.Log("SSH certificate received with key ID:", cert.KeyId)
t.Logf("SSH certificate received from %s with key ID %s", conn.RemoteAddr(), cert.KeyId)
return &ssh.Permissions{CriticalOptions: cert.CriticalOptions, Extensions: cert.Extensions}, nil
}
return nil, fmt.Errorf("not accepting public key type: %s", key.Type())
},
AuthLogCallback: func(conn ssh.ConnMetadata, method string, err error) {
if err == nil {
t.Logf("SSH authenticated with user %s and method %s", conn.User(), method)
t.Logf("SSH auth log: client %s, user %s, method %s",
conn.RemoteAddr(), conn.User(), method)
} else {
t.Logf("SSH auth log: client %s (%s), server %s, user %s, method %s, error: %s",
conn.RemoteAddr(), conn.ClientVersion(), conn.LocalAddr(), conn.User(), method, err)
}
},
}
}

type MockSSHServer struct {
t *testing.T
validCommands map[string]string
listener net.Listener
port int
hostKey ssh.Signer
t *testing.T
commandHandler CommandHandler
listener net.Listener
port int
hostKey ssh.Signer
}

type CommandHandler func(conn ssh.ConnMetadata, command string) (string, string, uint32)

// StartServer creates and starts a local SSH server.
// The server will automatically be stopped when the test completes.
func StartServer(t *testing.T, validCommands map[string]string) (*MockSSHServer, error) {
func StartServer(t *testing.T, handler CommandHandler) (*MockSSHServer, error) {
hk, err := ssh.ParsePrivateKey(hostKey)
if err != nil {
return nil, fmt.Errorf("failed to parse host key: %v", err)
}

s := &MockSSHServer{
t: t,
validCommands: validCommands,
hostKey: hk,
t: t,
commandHandler: handler,
hostKey: hk,
}
if err := s.start(); err != nil {
return nil, err
Expand Down Expand Up @@ -94,7 +99,7 @@ func (s *MockSSHServer) start() error {
s.port = listener.Addr().(*net.TCPAddr).Port
t.Logf("Test SSH server listening at %s", listener.Addr())

go func(l net.Listener, validCommands map[string]string) {
go func(l net.Listener) {
for {
conn, err := l.Accept()
if err != nil {
Expand All @@ -111,7 +116,7 @@ func (s *MockSSHServer) start() error {
return
}

t.Logf("New SSH connection from %s (%s)", sshConn.RemoteAddr(), sshConn.ClientVersion())
t.Logf("Handling SSH connection from %s", sshConn.RemoteAddr())

wg := sync.WaitGroup{}
wg.Add(1)
Expand All @@ -121,17 +126,17 @@ func (s *MockSSHServer) start() error {
}()
wg.Add(1)
go func() {
s.handleChannels(chans)
s.handleChannels(sshConn, chans)
wg.Done()
}()
wg.Wait()
}
}(s.listener, s.validCommands)
}(s.listener)

return nil
}

func (s *MockSSHServer) handleChannels(chans <-chan ssh.NewChannel) {
func (s *MockSSHServer) handleChannels(conn ssh.ConnMetadata, chans <-chan ssh.NewChannel) {
t := s.t

for newChannel := range chans {
Expand Down Expand Up @@ -163,18 +168,24 @@ func (s *MockSSHServer) handleChannels(chans <-chan ssh.NewChannel) {
if err != nil {
t.Errorf("Failed to reply to command: %v", err)
}
cmd := strings.TrimLeft(string(req.Payload), "\x00\x03")
t.Logf("Command received: %s", cmd)
if output := s.validCommands[cmd]; output != "" {
_, err = channel.Write([]byte(output))
// Strip the first four bytes of the payload, the uint32 representing the string length.
// See https://datatracker.ietf.org/doc/html/rfc4251#section-5
cmd := req.Payload[4:]
t.Logf("Handling command: %s", cmd)
stdOut, stdErr, exitCode := s.commandHandler(conn, string(cmd))
if len(stdOut) > 0 {
_, err = channel.Write([]byte(stdOut))
if err != nil {
t.Errorf("Failed to write to stdout channel: %v", err)
}
}
if len(stdErr) > 0 {
_, err = channel.Stderr().Write([]byte(stdErr))
if err != nil {
t.Errorf("Failed to write to channel: %v", err)
t.Errorf("Failed to write to stderr channel: %v", err)
}
exitWithStatus <- 0
} else {
_, _ = channel.Stderr().Write([]byte(fmt.Sprintf("Invalid command: %s", cmd)))
exitWithStatus <- 1
}
exitWithStatus <- exitCode
return
default:
_ = req.Reply(false, nil)
Expand Down
24 changes: 21 additions & 3 deletions go-tests/ssh_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
package tests

import (
"fmt"
"net/http/httptest"
"os/exec"
"strconv"
"testing"

"github.com/platformsh/cli/pkg/mockapi"
"github.com/platformsh/legacy-cli/tests/mockssh"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/ssh"
)

func TestSSH(t *testing.T) {
Expand All @@ -15,8 +19,14 @@ func TestSSH(t *testing.T) {

myUserID := "my-user-id"

sshServer, err := mockssh.StartServer(t, map[string]string{
"pwd": "/mock/path",
sshServer, err := mockssh.StartServer(t, func(conn ssh.ConnMetadata, command string) (string, string, uint32) {
switch command {
case "pwd":
return "/mock/path", "", 0
case "fail-with-code-2":
return "", "Returning exit code 2\n", 2
}
return "", fmt.Sprintf("Unknown command: %s\n", command), 1
})
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -62,5 +72,13 @@ func TestSSH(t *testing.T) {
}

f.Run("cc")
assertTrimmed(t, "/mock/path", f.Run("ssh", "-p", projectID, "-e", ".", "pwd"))
assert.Equal(t, "/mock/path", f.Run("ssh", "-p", projectID, "-e", ".", "pwd"))

_, stdErr, _ := f.RunCombinedOutput("ssh", "-p", projectID, "-e", "main", "--instance", "2", "pwd")
assert.Contains(t, stdErr, "Available instances: 0, 1")

_, _, err = f.RunCombinedOutput("ssh", "-p", projectID, "-e", "main", "--instance", "1", "fail-with-code-2")
var exitErr *exec.ExitError
assert.ErrorAs(t, err, &exitErr)
assert.Equal(t, 2, exitErr.ExitCode())
}

0 comments on commit 4987a07

Please sign in to comment.