diff --git a/utils.go b/utils.go index cf57677..2aa233f 100644 --- a/utils.go +++ b/utils.go @@ -2,38 +2,60 @@ package zfs import ( "bytes" + "context" "errors" "fmt" "io" + "os" "os/exec" "regexp" "runtime" "strconv" "strings" + "time" "github.com/google/uuid" ) +const ( + cmdtimeoutEnv = "COMMAND_TIMEOUT" +) + type command struct { Command string Stdin io.Reader Stdout io.Writer + timeout *time.Duration +} + +func getCommandTimeout() *time.Duration { + value := os.Getenv(cmdtimeoutEnv) + if timeout, err := time.ParseDuration(value); value != "" && err == nil { + return &timeout + } + return nil } func (c *command) Run(arg ...string) ([][]string, error) { cmd := exec.Command(c.Command, arg...) + if c.timeout != nil { + ctx, cancel := context.WithTimeout(context.TODO(), *c.timeout) + defer cancel() + cmd = exec.CommandContext(ctx, c.Command, arg...) + } - var stdout, stderr bytes.Buffer + if c.Stdin != nil { + cmd.Stdin = c.Stdin + } + var stdout bytes.Buffer if c.Stdout == nil { cmd.Stdout = &stdout } else { cmd.Stdout = c.Stdout } - if c.Stdin != nil { - cmd.Stdin = c.Stdin - } + var stderr bytes.Buffer cmd.Stderr = &stderr id := uuid.New().String() diff --git a/zfs.go b/zfs.go index 1166bdc..3132be3 100644 --- a/zfs.go +++ b/zfs.go @@ -114,7 +114,7 @@ func zfs(arg ...string) error { // zfs is a helper function to wrap typical calls to zfs. func zfsOutput(arg ...string) ([][]string, error) { - c := command{Command: "zfs"} + c := command{Command: "zfs", timeout: getCommandTimeout()} return c.Run(arg...) } diff --git a/zpool.go b/zpool.go index a0bd647..e4e5392 100644 --- a/zpool.go +++ b/zpool.go @@ -36,7 +36,7 @@ func zpool(arg ...string) error { // zpool is a helper function to wrap typical calls to zpool. func zpoolOutput(arg ...string) ([][]string, error) { - c := command{Command: "zpool"} + c := command{Command: "zpool", timeout: getCommandTimeout()} return c.Run(arg...) }