From cc7aa6859ac677bd420e7c91eed5d124a6a02cc2 Mon Sep 17 00:00:00 2001 From: Yar Kravtsov Date: Tue, 3 Dec 2024 09:37:06 +0200 Subject: [PATCH] refactor: Improve image synchronization and container management --- pkg/config/config.go | 25 ++--- pkg/deployment/deployment.go | 155 ++++++++++++++++----------- pkg/imagesync/imagesync.go | 27 ++--- pkg/imagesync/imagesync_test.go | 6 +- pkg/runner/remote/runner.go | 184 +++++++++++++++++++------------- 5 files changed, 233 insertions(+), 164 deletions(-) diff --git a/pkg/config/config.go b/pkg/config/config.go index e2de31c..9fe67cc 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -41,18 +41,19 @@ type Server struct { } type Service struct { - Name string `yaml:"name" validate:"required"` - Image string `yaml:"image"` - Port int `yaml:"port" validate:"required,min=1,max=65535"` - Path string `yaml:"path"` - HealthCheck *HealthCheck `yaml:"health_check"` - Routes []Route `yaml:"routes" validate:"required,dive"` - Volumes []string `yaml:"volumes" validate:"dive,volume_reference"` - Command string `yaml:"command"` - Entrypoint []string `yaml:"entrypoint"` - Env []string `yaml:"env"` - Forwards []string `yaml:"forwards"` - Recreate bool `yaml:"recreate"` + Name string `yaml:"name" validate:"required"` + Image string `yaml:"image"` + ImageUpdated bool + Port int `yaml:"port" validate:"required,min=1,max=65535"` + Path string `yaml:"path"` + HealthCheck *HealthCheck `yaml:"health_check"` + Routes []Route `yaml:"routes" validate:"required,dive"` + Volumes []string `yaml:"volumes" validate:"dive,volume_reference"` + Command string `yaml:"command"` + Entrypoint []string `yaml:"entrypoint"` + Env []string `yaml:"env"` + Forwards []string `yaml:"forwards"` + Recreate bool `yaml:"recreate"` } type HealthCheck struct { diff --git a/pkg/deployment/deployment.go b/pkg/deployment/deployment.go index d02f6f2..e6b8a9b 100644 --- a/pkg/deployment/deployment.go +++ b/pkg/deployment/deployment.go @@ -23,12 +23,13 @@ const ( type Runner interface { CopyFile(ctx context.Context, from, to string) error - GetHost() string + Host() string RunCommand(ctx context.Context, command string, args ...string) (io.ReadCloser, error) } type ImageSyncer interface { - Sync(ctx context.Context, image string) error + Sync(ctx context.Context, image string) (bool, error) + CompareImages(ctx context.Context, image string) (bool, error) } type Deployment struct { @@ -42,7 +43,7 @@ func NewDeployment(runner Runner, syncer ImageSyncer, sm *console.SpinnerManager } func (d *Deployment) Deploy(ctx context.Context, project string, cfg *config.Config) error { - hostname := d.runner.GetHost() + hostname := d.runner.Host() // Create project network spinner := d.sm.AddSpinner("network", fmt.Sprintf("[%s] Creating network...", hostname)) @@ -76,7 +77,7 @@ func (d *Deployment) Deploy(ctx context.Context, project string, cfg *config.Con } func (d *Deployment) createVolumes(ctx context.Context, project string, volumes []string) error { - hostname := d.runner.GetHost() + hostname := d.runner.Host() for _, volume := range volumes { spinner := d.sm.AddSpinner("volume", fmt.Sprintf("[%s] Creating volume %s", hostname, volume)) @@ -93,7 +94,7 @@ func (d *Deployment) createVolumes(ctx context.Context, project string, volumes } func (d *Deployment) deployDependencies(ctx context.Context, project string, dependencies []config.Dependency) error { - hostname := d.runner.GetHost() + hostname := d.runner.Host() var wg sync.WaitGroup errChan := make(chan error, len(dependencies)) @@ -130,7 +131,7 @@ func (d *Deployment) deployDependencies(ctx context.Context, project string, dep } func (d *Deployment) deployServices(ctx context.Context, project string, services []config.Service) error { - hostname := d.runner.GetHost() + hostname := d.runner.Host() var wg sync.WaitGroup errChan := make(chan error, len(services)) @@ -167,7 +168,7 @@ func (d *Deployment) deployServices(ctx context.Context, project string, service } func (d *Deployment) startProxy(ctx context.Context, project string, cfg *config.Config) error { - hostname := d.runner.GetHost() + hostname := d.runner.Host() // Prepare project folder projectPath, err := d.prepareProjectFolder(project) @@ -237,10 +238,6 @@ func (d *Deployment) startProxy(ctx context.Context, project string, cfg *config } func (d *Deployment) startDependency(project string, dependency *config.Dependency) error { - if _, err := d.pullImage(dependency.Image); err != nil { - return fmt.Errorf("failed to pull image for %s: %v", dependency.Image, err) - } - service := &config.Service{ Name: dependency.Name, Image: dependency.Image, @@ -255,16 +252,6 @@ func (d *Deployment) startDependency(project string, dependency *config.Dependen } func (d *Deployment) installService(project string, service *config.Service) error { - if service.Image != "" { - if _, err := d.pullImage(service.Image); err != nil { - return fmt.Errorf("failed to pull image for %s: %v", service.Image, err) - } - } else { - if err := d.syncer.Sync(context.Background(), fmt.Sprintf("%s-%s", project, service.Name)); err != nil { - return fmt.Errorf("failed to sync service %s for %s: %v", service.Name, service.Image, err) - } - } - if err := d.createContainer(project, service, ""); err != nil { return fmt.Errorf("failed to start container for %s: %v", service.Image, err) } @@ -281,16 +268,6 @@ func (d *Deployment) installService(project string, service *config.Service) err func (d *Deployment) updateService(project string, service *config.Service) error { svcName := service.Name - if service.Image != "" { - if _, err := d.pullImage(service.Image); err != nil { - return fmt.Errorf("failed to pull new image for %s: %v", svcName, err) - } - } else { - if err := d.syncer.Sync(context.Background(), fmt.Sprintf("%s-%s", project, service.Name)); err != nil { - return fmt.Errorf("failed to sync service %s for %s: %v", service.Name, service.Image, err) - } - } - if service.Recreate { if err := d.recreateService(project, service); err != nil { return fmt.Errorf("failed to recreate service %s: %w", service.Name, err) @@ -369,7 +346,7 @@ type containerInfo struct { } func (d *Deployment) getContainerID(project, service string) (string, error) { - info, err := d.getContainerInfo(service, project) + info, err := d.getContainerInfo(project, service) if err != nil { return "", err } @@ -377,7 +354,7 @@ func (d *Deployment) getContainerID(project, service string) (string, error) { return info.ID, err } -func (d *Deployment) getContainerInfo(service, network string) (*containerInfo, error) { +func (d *Deployment) getContainerInfo(network, service string) (*containerInfo, error) { output, err := d.runCommand(context.Background(), "docker", "ps", "-aq", "--filter", fmt.Sprintf("network=%s", network)) if err != nil { return nil, fmt.Errorf("failed to get container IDs: %w", err) @@ -637,69 +614,119 @@ func (d *Deployment) prepareNginxConfig(cfg *config.Config, projectPath string) return configPath, d.runner.CopyFile(context.Background(), tmpFile.Name(), filepath.Join(configPath, "default.conf")) } -func (d *Deployment) serviceChanged(project string, service *config.Service) (bool, error) { - containerInfo, err := d.getContainerInfo(service.Name, project) +func (d *Deployment) deployService(project string, service *config.Service) error { + err := d.updateImage(project, service) if err != nil { - return false, fmt.Errorf("failed to get container info: %w", err) + return err } - hash, err := service.Hash() + containerStatus, err := d.getContainerStatus(project, service.Name) if err != nil { - return false, fmt.Errorf("failed to generate config hash: %w", err) + return err } - return containerInfo.Config.Labels["ftl.config-hash"] != hash, nil -} + if containerStatus == ContainerStatusNotFound { + if err := d.installService(project, service); err != nil { + return fmt.Errorf("failed to install service %s: %w", service.Name, err) + } -func (d *Deployment) deployService(project string, service *config.Service) error { - imageName := service.Image - if imageName == "" { - imageName = fmt.Sprintf("%s-%s", project, service.Name) + return nil } - hash, err := d.getImageHash(imageName) + containerShouldBeUpdated, err := d.containerShouldBeUpdated(project, service) if err != nil { - return fmt.Errorf("failed to pull image for %s: %w", service.Name, err) + return err } - containerInfo, err := d.getContainerInfo(service.Name, project) - if err != nil { - if err := d.installService(project, service); err != nil { - return fmt.Errorf("failed to install service %s: %w", service.Name, err) + if containerShouldBeUpdated { + if err := d.updateService(project, service); err != nil { + return fmt.Errorf("failed to update service %s due to image change: %w", service.Name, err) } return nil } - if hash != containerInfo.Image { - if err := d.updateService(project, service); err != nil { - return fmt.Errorf("failed to update service %s due to image change: %w", service.Name, err) + if containerStatus == ContainerStatusStopped { + if err := d.startContainer(service); err != nil { + return fmt.Errorf("failed to start container %s: %w", service.Name, err) } return nil } - changed, err := d.serviceChanged(project, service) - if err != nil { - return fmt.Errorf("failed to check if service %s has changed: %w", service.Name, err) - } + return nil +} - if changed { - if err := d.updateService(project, service); err != nil { - return fmt.Errorf("failed to update service %s due to config change: %w", service.Name, err) +type ContainerStatusType int + +const ( + ContainerStatusRunning ContainerStatusType = iota + ContainerStatusStopped + ContainerStatusNotFound + ContainerStatusError +) + +func (d *Deployment) getContainerStatus(project, service string) (ContainerStatusType, error) { + getContainerInfo, err := d.getContainerInfo(project, service) + if err != nil { + if strings.Contains(err.Error(), "no container found") { + return ContainerStatusNotFound, nil } - return nil + return ContainerStatusError, fmt.Errorf("failed to get container info: %w", err) } - if containerInfo.State.Status != "running" { - if err := d.startContainer(service); err != nil { - return fmt.Errorf("failed to start container %s: %w", service.Name, err) + if getContainerInfo.State.Status != "running" { + return ContainerStatusStopped, nil + } + + return ContainerStatusRunning, nil +} + +func (d *Deployment) updateImage(project string, service *config.Service) error { + if service.Image == "" { + updated, err := d.syncer.Sync(context.Background(), fmt.Sprintf("%s-%s", project, service.Name)) + if err != nil { + return err } + service.ImageUpdated = updated } + + _, err := d.pullImage(service.Image) + if err != nil { + return err + } + return nil } +func (d *Deployment) containerShouldBeUpdated(project string, service *config.Service) (bool, error) { + containerInfo, err := d.getContainerInfo(project, service.Name) + if err != nil { + return false, fmt.Errorf("failed to get container info: %w", err) + } + + imageHash, err := d.getImageHash(service.Image) + if err != nil { + return false, fmt.Errorf("failed to get image hash: %w", err) + } + + if service.Image == "" && service.ImageUpdated { + return true, nil + } + + if service.Image != "" && containerInfo.Image != imageHash { + return true, nil + } + + hash, err := service.Hash() + if err != nil { + return false, fmt.Errorf("failed to generate config hash: %w", err) + } + + return containerInfo.Config.Labels["ftl.config-hash"] != hash, nil +} + func (d *Deployment) networkExists(network string) (bool, error) { output, err := d.runCommand(context.Background(), "docker", "network", "ls", "--format", "{{.Name}}") if err != nil { diff --git a/pkg/imagesync/imagesync.go b/pkg/imagesync/imagesync.go index 84aee54..ff3be60 100644 --- a/pkg/imagesync/imagesync.go +++ b/pkg/imagesync/imagesync.go @@ -45,41 +45,41 @@ func NewImageSync(cfg Config, runner *remote.Runner) *ImageSync { } // Sync performs the Docker image synchronization process. -func (s *ImageSync) Sync(ctx context.Context, image string) error { - needsSync, err := s.compareImages(ctx, image) +func (s *ImageSync) Sync(ctx context.Context, image string) (bool, error) { + needsSync, err := s.CompareImages(ctx, image) if err != nil { - return fmt.Errorf("failed to compare images: %w", err) + return false, fmt.Errorf("failed to compare images: %w", err) } if !needsSync { - return nil // Images are identical + return false, nil // Images are identical } if err := s.prepareDirectories(ctx); err != nil { - return fmt.Errorf("failed to prepare directories: %w", err) + return false, fmt.Errorf("failed to prepare directories: %w", err) } if err := s.exportAndExtractImage(ctx, image); err != nil { - return fmt.Errorf("failed to export and extract image: %w", err) + return false, fmt.Errorf("failed to export and extract image: %w", err) } if err := s.transferMetadata(ctx, image); err != nil { - return fmt.Errorf("failed to transfer metadata: %w", err) + return false, fmt.Errorf("failed to transfer metadata: %w", err) } if err := s.syncBlobs(ctx, image); err != nil { - return fmt.Errorf("failed to sync blobs: %w", err) + return false, fmt.Errorf("failed to sync blobs: %w", err) } if err := s.loadRemoteImage(ctx, image); err != nil { - return fmt.Errorf("failed to load remote image: %w", err) + return false, fmt.Errorf("failed to load remote image: %w", err) } - return nil + return true, nil } -// compareImages checks if the image needs to be synced by comparing local and remote versions. -func (s *ImageSync) compareImages(ctx context.Context, image string) (bool, error) { +// CompareImages checks if the image needs to be synced by comparing local and remote versions. +func (s *ImageSync) CompareImages(ctx context.Context, image string) (bool, error) { localInspect, err := s.inspectLocalImage(image) if err != nil { return false, fmt.Errorf("failed to inspect local image: %w", err) @@ -91,7 +91,8 @@ func (s *ImageSync) compareImages(ctx context.Context, image string) (bool, erro } // Compare normalized JSON data - return !compareImageData(localInspect, remoteInspect), nil + imagesEqual := compareImageData(localInspect, remoteInspect) + return !imagesEqual, nil } // ImageData represents Docker image metadata. diff --git a/pkg/imagesync/imagesync_test.go b/pkg/imagesync/imagesync_test.go index e6ddc09..92cb52a 100644 --- a/pkg/imagesync/imagesync_test.go +++ b/pkg/imagesync/imagesync_test.go @@ -83,7 +83,7 @@ func TestImageSync(t *testing.T) { // Run sync t.Log("Running sync...") ctx := context.Background() - err = sync.Sync(ctx, testImage) + _, err = sync.Sync(ctx, testImage) require.NoError(t, err) // Verify image exists on remote @@ -98,12 +98,12 @@ func TestImageSync(t *testing.T) { // Test image comparison t.Log("Comparing images...") - needsSync, err := sync.compareImages(ctx, testImage) + needsSync, err := sync.CompareImages(ctx, testImage) require.NoError(t, err) require.False(t, needsSync, "Images should be identical after sync") // Test re-sync with no changes t.Log("Re-syncing...") - err = sync.Sync(ctx, testImage) + _, err = sync.Sync(ctx, testImage) require.NoError(t, err) } diff --git a/pkg/runner/remote/runner.go b/pkg/runner/remote/runner.go index fd6bbe7..e9eb703 100644 --- a/pkg/runner/remote/runner.go +++ b/pkg/runner/remote/runner.go @@ -12,136 +12,176 @@ import ( "golang.org/x/crypto/ssh" ) -// Runner provides methods to run commands and copy files on a remote host. +// ErrNoClient is returned when attempting operations on a closed Runner. +var ErrNoClient = errors.New("ssh client is nil") + +// Runner executes commands and transfers files on a remote host via SSH. +// Once closed, a Runner cannot be reused. type Runner struct { - sshClient *ssh.Client + client *ssh.Client // client is unexported as it's an implementation detail } -// NewRunner creates a new Runner instance. -func NewRunner(sshClient *ssh.Client) *Runner { - return &Runner{ - sshClient: sshClient, +// NewRunner creates a new Runner instance using the provided SSH client. +// It returns nil if the client is nil. +func NewRunner(client *ssh.Client) *Runner { + if client == nil { + return nil } + return &Runner{client: client} } -// Close closes the SSH client. -func (c *Runner) Close() error { - if c.sshClient == nil { +// Close releases all resources associated with the Runner. +// After Close, the Runner cannot be reused. +func (r *Runner) Close() error { + if r.client == nil { return nil } - - err := c.sshClient.Close() - c.sshClient = nil + err := r.client.Close() + r.client = nil return err } -// RunCommands runs multiple commands on the remote host. -func (c *Runner) RunCommands(ctx context.Context, commands []string) error { - for _, command := range commands { - rc, err := c.RunCommand(ctx, command) +// RunCommands executes multiple commands sequentially on the remote host. +// It stops at the first command that fails. +func (r *Runner) RunCommands(ctx context.Context, commands []string) error { + if r.client == nil { + return ErrNoClient + } + + for _, cmd := range commands { + output, err := r.RunCommand(ctx, cmd) if err != nil { - return fmt.Errorf("failed to run command '%s': %w", command, err) + return fmt.Errorf("executing command %q: %w", cmd, err) } - _, readErr := io.ReadAll(rc) + // Close the output reader in the same iteration to prevent resource leaks + _, err = io.Copy(io.Discard, output) + closeErr := output.Close() - if readErr != nil { - return fmt.Errorf("failed to read output of command '%s': %w", command, readErr) + if err != nil { + return fmt.Errorf("reading output of %q: %w", cmd, err) + } + if closeErr != nil { + return fmt.Errorf("closing output of %q: %w", cmd, closeErr) } } - return nil } -// RunCommand runs a command on the remote host. -func (c *Runner) RunCommand(ctx context.Context, command string, args ...string) (io.ReadCloser, error) { - session, err := c.sshClient.NewSession() +// RunCommand executes a single command with optional arguments on the remote host. +// The caller must close the returned ReadCloser when done. +func (r *Runner) RunCommand(ctx context.Context, command string, args ...string) (io.ReadCloser, error) { + if r.client == nil { + return nil, ErrNoClient + } + + session, err := r.client.NewSession() if err != nil { - return nil, fmt.Errorf("unable to create session: %v", err) + return nil, fmt.Errorf("creating session: %w", err) } - fullCommand := command + // Build the full command with properly escaped arguments + fullCmd := command if len(args) > 0 { escapedArgs := make([]string, len(args)) for i, arg := range args { - escapedArgs[i] = sshEscapeArg(arg) + escapedArgs[i] = escapeArg(arg) } - fullCommand += " " + strings.Join(escapedArgs, " ") + fullCmd += " " + strings.Join(escapedArgs, " ") } + // Set up command I/O stdout, err := session.StdoutPipe() if err != nil { session.Close() - return nil, fmt.Errorf("failed to get stdout pipe: %w", err) + return nil, fmt.Errorf("creating stdout pipe: %w", err) } stderr, err := session.StderrPipe() if err != nil { session.Close() - return nil, fmt.Errorf("failed to get stderr pipe: %w", err) + return nil, fmt.Errorf("creating stderr pipe: %w", err) } - if err := session.Start(fullCommand); err != nil { + if err := session.Start(fullCmd); err != nil { session.Close() - return nil, fmt.Errorf("failed to start command: %w", err) + return nil, fmt.Errorf("starting command: %w", err) } - outputReader := io.MultiReader(stdout, stderr) - - readCloser := &sessionReadCloser{ - Reader: outputReader, + return &commandOutput{ + reader: io.MultiReader(stdout, stderr), session: session, ctx: ctx, - } - - return readCloser, nil + }, nil } -func (c *Runner) GetHost() string { - fullAddr := c.sshClient.RemoteAddr().String() - addr := strings.Split(fullAddr, ":") - return addr[0] +// Host returns the hostname of the remote server. +func (r *Runner) Host() string { + if r.client == nil { + return "" + } + addr := r.client.RemoteAddr().String() + host, _, _ := strings.Cut(addr, ":") + return host } -// sshEscapeArg properly escapes a command-line argument for SSH -func sshEscapeArg(arg string) string { - return "'" + strings.Replace(arg, "'", "'\\''", -1) + "'" -} +// CopyFile copies a file from src on the local machine to dst on the remote host. +// The destination file will have permissions 0644. +func (r *Runner) CopyFile(ctx context.Context, src, dst string) error { + if r.client == nil { + return ErrNoClient + } -// sessionReadCloser wraps an io.Reader and closes the SSH session when closed -type sessionReadCloser struct { - io.Reader - session *ssh.Session - ctx context.Context -} + client, err := scp.NewClientBySSH(r.client) + if err != nil { + return fmt.Errorf("creating SCP client: %w", err) + } -// Close closes the SSH session and waits for the command to finish. -func (src *sessionReadCloser) Close() error { - if err := src.session.Signal(ssh.SIGTERM); err != nil { - src.session.Close() + f, err := os.Open(src) + if err != nil { + return fmt.Errorf("opening source file: %w", err) } + defer f.Close() - if err := src.session.Wait(); err != nil { - var exitMissingError *ssh.ExitMissingError - if !errors.As(err, &exitMissingError) { - return err - } + if err := client.CopyFile(ctx, f, dst, "0644"); err != nil { + return fmt.Errorf("copying file: %w", err) } + return nil +} - return src.session.Close() +// commandOutput combines the stdout and stderr of a remote command +// and handles proper cleanup of the underlying SSH session. +type commandOutput struct { + reader io.Reader + session *ssh.Session + ctx context.Context } -// CopyFile copies a file from the remote host to the local machine. -func (c *Runner) CopyFile(ctx context.Context, src, dst string) error { - client, err := scp.NewClientBySSH(c.sshClient) - if err != nil { - return fmt.Errorf("failed to create SCP client: %w", err) +func (c *commandOutput) Read(p []byte) (int, error) { + // Check context cancellation before reading + select { + case <-c.ctx.Done(): + return 0, c.ctx.Err() + default: + return c.reader.Read(p) } +} - file, err := os.Open(src) - if err != nil { - return fmt.Errorf("failed to open file: %w", err) +func (c *commandOutput) Close() error { + // Send SIGTERM first for graceful shutdown + _ = c.session.Signal(ssh.SIGTERM) + + var exitErr *ssh.ExitError + err := c.session.Wait() + if err != nil && !errors.As(err, &exitErr) { + c.session.Close() + return fmt.Errorf("waiting for command completion: %w", err) } - return client.CopyFile(ctx, file, dst, "0644") + return c.session.Close() +} + +// escapeArg escapes a command-line argument for safe use in SSH commands. +func escapeArg(arg string) string { + return "'" + strings.ReplaceAll(arg, "'", "'\\''") + "'" }