diff --git a/viper.go b/viper.go index 2ff730f31..9a3e0de03 100644 --- a/viper.go +++ b/viper.go @@ -21,6 +21,7 @@ package viper import ( "bytes" + "context" "encoding/csv" "errors" "fmt" @@ -215,7 +216,8 @@ type Viper struct { aliases map[string]string typeByDefValue bool - onConfigChange func(fsnotify.Event) + onConfigChange func(fsnotify.Event) + stopWatchingFunc func() logger Logger @@ -456,6 +458,10 @@ func (v *Viper) WatchConfig() { configDir, _ := filepath.Split(configFile) realConfigFile, _ := filepath.EvalSymlinks(filename) + // init the stopWatchingFunc + watchingCtx, cancel := context.WithCancel(context.Background()) + v.stopWatchingFunc = cancel + eventsWG := sync.WaitGroup{} eventsWG.Add(1) go func() { @@ -492,6 +498,9 @@ func (v *Viper) WatchConfig() { } eventsWG.Done() return + case <-watchingCtx.Done(): // StopWatching function called + eventsWG.Done() + return } } }() @@ -502,6 +511,16 @@ func (v *Viper) WatchConfig() { initWG.Wait() // make sure that the go routine above fully ended before returning } +// StopWatching stop watching a config file for changes. +func StopWatching() { v.StopWatching() } + +// StopWatching stop watching a config file for changes. +func (v *Viper) StopWatching() { + if v.stopWatchingFunc != nil { + v.stopWatchingFunc() + } +} + // SetConfigFile explicitly defines the path, name and extension of the config file. // Viper will use this and not check any of the config paths. func SetConfigFile(in string) { v.SetConfigFile(in) } diff --git a/viper_test.go b/viper_test.go index e0bfc57bd..06a16718e 100644 --- a/viper_test.go +++ b/viper_test.go @@ -2545,6 +2545,46 @@ func TestWatchFile(t *testing.T) { }) } +func TestStopWatching(t *testing.T) { + t.Run( + "file content changed after stop watching", func(t *testing.T) { + // given a `config.yaml` file being watched + v, configFile, cleanup := newViperWithConfigFile(t) + defer cleanup() + _, err := os.Stat(configFile) + require.NoError(t, err) + t.Logf("test config file: %s\n", configFile) + + v.WatchConfig() + v.StopWatching() + + // overwriting the file after StopWatching called + err = ioutil.WriteFile(configFile, []byte("foo: baz\n"), 0o640) + time.Sleep(time.Second) // wait for file changed event + // then the config value should not be changed + require.Nil(t, err) + assert.Equal(t, "bar", v.Get("foo")) + + // watch again + wg := sync.WaitGroup{} + wg.Add(1) + var wgDoneOnce sync.Once // OnConfigChange is called twice on Windows + v.OnConfigChange( + func(in fsnotify.Event) { + t.Logf("config file changed again") + wgDoneOnce.Do(func() { wg.Done() }) + }, + ) + v.WatchConfig() + // overwriting the file after StopWatching and Watch again + err = ioutil.WriteFile(configFile, []byte("foo: qux\n"), 0o640) + wg.Wait() + require.Nil(t, err) + assert.Equal(t, "qux", v.Get("foo")) + }, + ) +} + func TestUnmarshal_DotSeparatorBackwardCompatibility(t *testing.T) { flags := pflag.NewFlagSet("test", pflag.ContinueOnError) flags.String("foo.bar", "cobra_flag", "")