diff --git a/viper.go b/viper.go index 20eb4da17..79c987a3c 100644 --- a/viper.go +++ b/viper.go @@ -28,6 +28,7 @@ import ( "os" "path/filepath" "reflect" + "slices" "strconv" "strings" "sync" @@ -689,6 +690,92 @@ func (v *Viper) searchMap(source map[string]any, path []string) any { return nil } +// searchMapWithAliases recursively searches for slice field in source map and +// replace them with the environment variable value if it exists. +// +// Returns replaced values, and a boolean if the value was found in +// environment varaible. +func (v *Viper) searchAndReplaceSliceValueWithEnv(source any, envKey string) (any, bool) { + + switch sourceValue := source.(type) { + case []any: + var newSliceValues []any + if len(sourceValue) <= 0 { + return newSliceValues, false + + } + + var exists []bool + for i := 0; ; i++ { + envKey := envKey + v.keyDelim + strconv.Itoa(i) + var value any + var existDefault = true + if len(sourceValue) < i+1 { + value = sourceValue[0] + existDefault = false + } else { + value = sourceValue[i] + } + switch existingValue := value.(type) { + case map[string]any: + newVal, found := v.searchAndReplaceSliceValueWithEnv(existingValue, envKey) + if !found && !existDefault { + return newSliceValues, slices.Contains(exists, true) + } + newSliceValues = append(newSliceValues, newVal) + exists = append(exists, found || existDefault) + + default: + if newVal, ok := v.getEnv(v.mergeWithEnvPrefix(envKey)); ok { + newSliceValues = append(newSliceValues, newVal) + exists = append(exists, true) + } else { + exists = append(exists, false || existDefault) + if existDefault { + newSliceValues = append(newSliceValues, existingValue) + } else { + return newSliceValues, slices.Contains(exists, true) + } + } + } + } + return newSliceValues, slices.Contains(exists, true) + + case map[string]any: + var newMapValues map[string]any = make(map[string]any) + var exists []bool + for key, mapValue := range sourceValue { + envKey := envKey + v.keyDelim + key + switch existingValue := mapValue.(type) { + case map[string]any: + newVal, found := v.searchAndReplaceSliceValueWithEnv(existingValue, envKey) + if !found { + return newMapValues, false + } + newMapValues[key] = newVal + exists = append(exists, found) + + default: + if newVal, ok := v.getEnv(v.mergeWithEnvPrefix(envKey)); ok { + newMapValues[key] = newVal + exists = append(exists, true) + } else { + exists = append(exists, false) + newMapValues[key] = existingValue + } + } + } + return newMapValues, slices.Contains(exists, true) + + default: + if newVal, ok := v.getEnv(v.mergeWithEnvPrefix(envKey)); ok { + return newVal, true + } else { + return source, false + } + } +} + // searchIndexableWithPathPrefixes recursively searches for a value for path in source map/slice. // // While searchMap() considers each path element as a single map key or slice index, this @@ -906,6 +993,11 @@ func (v *Viper) Get(key string) any { return nil } + // Check for Env override again, to handle slices + if v.automaticEnvApplied { + val, _ = v.searchAndReplaceSliceValueWithEnv(val, lcaseKey) + } + if v.typeByDefValue { // TODO(bep) this branch isn't covered by a single test. valType := val diff --git a/viper_test.go b/viper_test.go index 08533699c..814967a3e 100644 --- a/viper_test.go +++ b/viper_test.go @@ -2606,6 +2606,87 @@ func TestSliceIndexAccess(t *testing.T) { assert.Equal(t, "Static", v.GetString("tv.0.episodes.1.2")) } +var yamlSimpleSlice = []byte(` +name: Steve +port: 8080 +auth: + secret: 88888-88888 +modes: + - 1 + - 2 + - 3 +clients: + - name: foo + - name: bar +proxy: + clients: + - name: proxy_foo + - name: proxy_bar + - name: proxy_baz +`) + +func TestSliceIndexAutomaticEnv(t *testing.T) { + v.SetConfigType("yaml") + r := strings.NewReader(string(yamlSimpleSlice)) + + type ClientConfig struct { + Name string + } + + type AuthConfig struct { + Secret string + } + + type ProxyConfig struct { + Clients []ClientConfig + } + + type Configuration struct { + Port int + Name string + Auth AuthConfig + Modes []int + Clients []ClientConfig + Proxy ProxyConfig + } + + // Read yaml as default value + err := v.unmarshalReader(r, v.config) + require.NoError(t, err) + + assert.Equal(t, "Steve", v.GetString("name")) + assert.Equal(t, 8080, v.GetInt("port")) + assert.Equal(t, "88888-88888", v.GetString("auth.secret")) + assert.Equal(t, "foo", v.GetString("clients.0.name")) + assert.Equal(t, "bar", v.GetString("clients.1.name")) + assert.Equal(t, "proxy_foo", v.GetString("proxy.clients.0.name")) + assert.Equal(t, []int{1, 2, 3}, v.GetIntSlice("modes")) + + // Override with env variable + t.Setenv("NAME", "Steven") + t.Setenv("AUTH_SECRET", "99999-99999") + t.Setenv("MODES_2", "300") + t.Setenv("CLIENTS_1_NAME", "baz") + t.Setenv("PROXY_CLIENTS_0_NAME", "ProxyFoo") + t.Setenv("PROXY_CLIENTS_3_NAME", "ProxyNew") + + SetEnvKeyReplacer(strings.NewReplacer(".", "_")) + AutomaticEnv() + + // Unmarshal into struct + var config Configuration + v.Unmarshal(&config) + + assert.Equal(t, "Steven", config.Name) + assert.Equal(t, 8080, config.Port) + assert.Equal(t, "99999-99999", config.Auth.Secret) + assert.Equal(t, []int{1, 2, 300}, config.Modes) + assert.Equal(t, "foo", config.Clients[0].Name) + assert.Equal(t, "baz", config.Clients[1].Name) + assert.Equal(t, "ProxyFoo", config.Proxy.Clients[0].Name) + assert.Equal(t, "ProxyNew", config.Proxy.Clients[3].Name) +} + func TestIsPathShadowedInFlatMap(t *testing.T) { v := New()