From c0546f741946826ac57ff1af48f29f4742f64b0a Mon Sep 17 00:00:00 2001 From: Jason Lee Date: Wed, 3 Jan 2024 12:17:45 +0800 Subject: [PATCH 1/4] add unit tests to catch a bug/gap where slice values are not overriden by env variables --- viper.go | 78 ++++++++++++++++++++++++++++++++++++++++++++++++++- viper_test.go | 71 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 148 insertions(+), 1 deletion(-) diff --git a/viper.go b/viper.go index 20eb4da17..b658ffda3 100644 --- a/viper.go +++ b/viper.go @@ -1128,7 +1128,83 @@ func (v *Viper) Unmarshal(rawVal any, opts ...DecoderConfigOption) error { } // TODO: struct keys should be enough? - return decode(v.getSettings(keys), defaultDecoderConfig(rawVal, opts...)) + err := decode(v.getSettings(keys), defaultDecoderConfig(rawVal, opts...)) + + // Post processing for slice of maps + // if features.BindStruct { + err = unmarshalPostProcess(rawVal, opts...) + if err != nil { + return err + } + // } + + return err +} + +func unmarshalPostProcess(input any, opts ...DecoderConfigOption) error { + var structKeyMap map[string]any + + err := decode(input, defaultDecoderConfig(&structKeyMap, opts...)) + if err != nil { + return err + } + + v.postProcessingSliceFields(map[string]bool{}, structKeyMap, "") + return nil +} + +// TODO remove shadow +func (v *Viper) postProcessingSliceFields(shadow map[string]bool, m map[string]any, prefix string) map[string]bool { + if shadow != nil && prefix != "" && shadow[prefix] { + // prefix is shadowed => nothing more to flatten + return shadow + } + if shadow == nil { + shadow = make(map[string]bool) + } + + var m2 map[string]any + if prefix != "" { + prefix += v.keyDelim + } + for k, val := range m { + fullKey := prefix + k + valValue := reflect.ValueOf(val) + if valValue.Kind() == reflect.Slice { + for i := 0; i < valValue.Len(); i++ { + item := valValue.Index(i) + if item.Kind() != reflect.Struct || !item.CanSet() { + continue + } + itemType := item.Type() + for j := 0; j < item.NumField(); j++ { + field := itemType.Field(j) + // fmt.Printf("Field %d: Name=%s, Type=%v, Value=%v\n", j, field.Name, field.Type, item.Field(j).Interface()) + + sliceKey := fmt.Sprintf("%s%s%s%d%s%s", prefix, k, v.keyDelim, i, v.keyDelim, field.Name) + shadow[strings.ToLower(sliceKey)] = true + // fmt.Printf("%s is slice\n", sliceKey) + + if val, ok := v.getEnv(v.mergeWithEnvPrefix(sliceKey)); ok { + // fmt.Printf("Val is %v\n", val) + item.Field(j).SetString(val) + } + } + } + } + + switch val := val.(type) { + case map[string]any: + m2 = val + case map[any]any: + m2 = cast.ToStringMap(val) + default: + continue + } + // recursively merge to shadow map + shadow = v.postProcessingSliceFields(shadow, m2, fullKey) + } + return shadow } func (v *Viper) decodeStructKeys(input any, opts ...DecoderConfigOption) ([]string, error) { diff --git a/viper_test.go b/viper_test.go index 0b1f40741..30460b047 100644 --- a/viper_test.go +++ b/viper_test.go @@ -2606,6 +2606,77 @@ func TestSliceIndexAccess(t *testing.T) { assert.Equal(t, "Static", v.GetString("tv.0.episodes.1.2")) } +var yamlSimpleSlice = []byte(` +name: Steve +port: 8080 +auth: + secret: 88888-88888 +clients: + - name: foo + - name: bar +proxy: + clients: + - name: proxy_foo + - name: proxy_bar + - name: proxy_baz +`) + +func TestSliceIndexAutomaticEnv(t *testing.T) { + v.SetConfigType("yaml") + r := strings.NewReader(string(yamlSimpleSlice)) + + type ClientConfig struct { + Name string + } + + type AuthConfig struct { + Secret string + } + + type ProxyConfig struct { + Clients []ClientConfig + } + + type Configuration struct { + Port int + Name string + Auth AuthConfig + Clients []ClientConfig + Proxy ProxyConfig + } + + // Read yaml as default value + err := v.unmarshalReader(r, v.config) + require.NoError(t, err) + + assert.Equal(t, "Steve", v.GetString("name")) + assert.Equal(t, 8080, v.GetInt("port")) + assert.Equal(t, "88888-88888", v.GetString("auth.secret")) + assert.Equal(t, "foo", v.GetString("clients.0.name")) + assert.Equal(t, "bar", v.GetString("clients.1.name")) + assert.Equal(t, "proxy_foo", v.GetString("proxy.clients.0.name")) + + // Override with env variable + t.Setenv("NAME", "Steven") + t.Setenv("AUTH_SECRET", "99999-99999") + t.Setenv("CLIENTS_1_NAME", "baz") + t.Setenv("PROXY_CLIENTS_0_NAME", "ProxyFoo") + + SetEnvKeyReplacer(strings.NewReplacer(".", "_")) + AutomaticEnv() + + // Unmarshal into struct + var config Configuration + v.Unmarshal(&config) + + assert.Equal(t, "Steven", config.Name) + assert.Equal(t, 8080, config.Port) + assert.Equal(t, "99999-99999", config.Auth.Secret) + assert.Equal(t, "foo", config.Clients[0].Name) + assert.Equal(t, "baz", config.Clients[1].Name) + assert.Equal(t, "ProxyFoo", config.Proxy.Clients[0].Name) +} + func TestIsPathShadowedInFlatMap(t *testing.T) { v := New() From 44bde863fa406a5e4bcc003d02fd14a4ba79eb7e Mon Sep 17 00:00:00 2001 From: Jason Lee Date: Wed, 3 Jan 2024 14:10:52 +0800 Subject: [PATCH 2/4] handle int slice as well --- viper.go | 47 +++++++++++++++++++++++------------------------ viper_test.go | 8 ++++++++ 2 files changed, 31 insertions(+), 24 deletions(-) diff --git a/viper.go b/viper.go index b658ffda3..bb43c148f 100644 --- a/viper.go +++ b/viper.go @@ -1149,19 +1149,11 @@ func unmarshalPostProcess(input any, opts ...DecoderConfigOption) error { return err } - v.postProcessingSliceFields(map[string]bool{}, structKeyMap, "") + v.postProcessingSliceFields(structKeyMap, "") return nil } -// TODO remove shadow -func (v *Viper) postProcessingSliceFields(shadow map[string]bool, m map[string]any, prefix string) map[string]bool { - if shadow != nil && prefix != "" && shadow[prefix] { - // prefix is shadowed => nothing more to flatten - return shadow - } - if shadow == nil { - shadow = make(map[string]bool) - } +func (v *Viper) postProcessingSliceFields(m map[string]any, prefix string) { var m2 map[string]any if prefix != "" { @@ -1173,21 +1165,29 @@ func (v *Viper) postProcessingSliceFields(shadow map[string]bool, m map[string]a if valValue.Kind() == reflect.Slice { for i := 0; i < valValue.Len(); i++ { item := valValue.Index(i) - if item.Kind() != reflect.Struct || !item.CanSet() { + iStr := strconv.FormatInt(int64(i), 10) + + fmt.Printf("item %v\n", item) + if !item.CanSet() { continue } - itemType := item.Type() - for j := 0; j < item.NumField(); j++ { - field := itemType.Field(j) - // fmt.Printf("Field %d: Name=%s, Type=%v, Value=%v\n", j, field.Name, field.Type, item.Field(j).Interface()) - - sliceKey := fmt.Sprintf("%s%s%s%d%s%s", prefix, k, v.keyDelim, i, v.keyDelim, field.Name) - shadow[strings.ToLower(sliceKey)] = true - // fmt.Printf("%s is slice\n", sliceKey) - + if item.Kind() == reflect.Struct { + itemType := item.Type() + for j := 0; j < item.NumField(); j++ { + field := itemType.Field(j) + sliceKey := prefix + k + v.keyDelim + iStr + v.keyDelim + field.Name + // fmt.Printf("%s is slice\n", sliceKey) + + if val, ok := v.getEnv(v.mergeWithEnvPrefix(sliceKey)); ok { + // fmt.Printf("Val is %v\n", val) + item.Field(j).SetString(val) + } + } + } else { + sliceKey := prefix + k + v.keyDelim + iStr if val, ok := v.getEnv(v.mergeWithEnvPrefix(sliceKey)); ok { - // fmt.Printf("Val is %v\n", val) - item.Field(j).SetString(val) + intValue, _ := strconv.ParseInt(val, 10, 32) + item.SetInt(intValue) } } } @@ -1202,9 +1202,8 @@ func (v *Viper) postProcessingSliceFields(shadow map[string]bool, m map[string]a continue } // recursively merge to shadow map - shadow = v.postProcessingSliceFields(shadow, m2, fullKey) + v.postProcessingSliceFields(m2, fullKey) } - return shadow } func (v *Viper) decodeStructKeys(input any, opts ...DecoderConfigOption) ([]string, error) { diff --git a/viper_test.go b/viper_test.go index 30460b047..0375300d2 100644 --- a/viper_test.go +++ b/viper_test.go @@ -2611,6 +2611,10 @@ name: Steve port: 8080 auth: secret: 88888-88888 +modes: + - 1 + - 2 + - 3 clients: - name: foo - name: bar @@ -2641,6 +2645,7 @@ func TestSliceIndexAutomaticEnv(t *testing.T) { Port int Name string Auth AuthConfig + Modes []int Clients []ClientConfig Proxy ProxyConfig } @@ -2655,10 +2660,12 @@ func TestSliceIndexAutomaticEnv(t *testing.T) { assert.Equal(t, "foo", v.GetString("clients.0.name")) assert.Equal(t, "bar", v.GetString("clients.1.name")) assert.Equal(t, "proxy_foo", v.GetString("proxy.clients.0.name")) + assert.Equal(t, []int{1, 2, 3}, v.GetIntSlice("modes")) // Override with env variable t.Setenv("NAME", "Steven") t.Setenv("AUTH_SECRET", "99999-99999") + t.Setenv("MODES_2", "300") t.Setenv("CLIENTS_1_NAME", "baz") t.Setenv("PROXY_CLIENTS_0_NAME", "ProxyFoo") @@ -2672,6 +2679,7 @@ func TestSliceIndexAutomaticEnv(t *testing.T) { assert.Equal(t, "Steven", config.Name) assert.Equal(t, 8080, config.Port) assert.Equal(t, "99999-99999", config.Auth.Secret) + assert.Equal(t, []int{1, 2, 300}, config.Modes) assert.Equal(t, "foo", config.Clients[0].Name) assert.Equal(t, "baz", config.Clients[1].Name) assert.Equal(t, "ProxyFoo", config.Proxy.Clients[0].Name) From 51a5afb83c41f07613b0c3bf319ec064b9b8e68e Mon Sep 17 00:00:00 2001 From: Jason Lee Date: Sat, 6 Jan 2024 16:20:45 +0800 Subject: [PATCH 3/4] new cleaner approach --- viper.go | 129 +++++++++++++++++++++++-------------------------------- 1 file changed, 53 insertions(+), 76 deletions(-) diff --git a/viper.go b/viper.go index bb43c148f..77f29aa87 100644 --- a/viper.go +++ b/viper.go @@ -689,6 +689,53 @@ func (v *Viper) searchMap(source map[string]any, path []string) any { return nil } +func (v *Viper) searchAndReplaceSliceValueWithEnv(source any, envKey string) any { + switch v1 := source.(type) { + case []any: + var newSlices []any + for i, value := range v1 { + envKey := envKey + v.keyDelim + strconv.Itoa(i) + switch v2 := value.(type) { + case map[string]any: + val := v.searchAndReplaceSliceValueWithEnv(v2, envKey) + newSlices = append(newSlices, val) + default: + if val, ok := v.getEnv(v.mergeWithEnvPrefix(envKey)); ok { + newSlices = append(newSlices, val) + } else { + newSlices = append(newSlices, v2) + } + } + } + return newSlices + case map[string]any: + var newMapValue map[string]any = make(map[string]any) + + for k, v2 := range v1 { + envKey := envKey + v.keyDelim + k + switch v3 := v2.(type) { + case map[string]any: + val := v.searchAndReplaceSliceValueWithEnv(v3, envKey) + newMapValue[k] = val + + default: + if val, ok := v.getEnv(v.mergeWithEnvPrefix(envKey)); ok { + newMapValue[k] = val + } else { + newMapValue[k] = v3 + } + } + } + return newMapValue + default: + if val, ok := v.getEnv(v.mergeWithEnvPrefix(envKey)); ok { + return val + } else { + return source + } + } +} + // searchIndexableWithPathPrefixes recursively searches for a value for path in source map/slice. // // While searchMap() considers each path element as a single map key or slice index, this @@ -906,6 +953,11 @@ func (v *Viper) Get(key string) any { return nil } + // Check for Env override again, to handle slices + if v.automaticEnvApplied { + val = v.searchAndReplaceSliceValueWithEnv(val, lcaseKey) + } + if v.typeByDefValue { // TODO(bep) this branch isn't covered by a single test. valType := val @@ -1128,82 +1180,7 @@ func (v *Viper) Unmarshal(rawVal any, opts ...DecoderConfigOption) error { } // TODO: struct keys should be enough? - err := decode(v.getSettings(keys), defaultDecoderConfig(rawVal, opts...)) - - // Post processing for slice of maps - // if features.BindStruct { - err = unmarshalPostProcess(rawVal, opts...) - if err != nil { - return err - } - // } - - return err -} - -func unmarshalPostProcess(input any, opts ...DecoderConfigOption) error { - var structKeyMap map[string]any - - err := decode(input, defaultDecoderConfig(&structKeyMap, opts...)) - if err != nil { - return err - } - - v.postProcessingSliceFields(structKeyMap, "") - return nil -} - -func (v *Viper) postProcessingSliceFields(m map[string]any, prefix string) { - - var m2 map[string]any - if prefix != "" { - prefix += v.keyDelim - } - for k, val := range m { - fullKey := prefix + k - valValue := reflect.ValueOf(val) - if valValue.Kind() == reflect.Slice { - for i := 0; i < valValue.Len(); i++ { - item := valValue.Index(i) - iStr := strconv.FormatInt(int64(i), 10) - - fmt.Printf("item %v\n", item) - if !item.CanSet() { - continue - } - if item.Kind() == reflect.Struct { - itemType := item.Type() - for j := 0; j < item.NumField(); j++ { - field := itemType.Field(j) - sliceKey := prefix + k + v.keyDelim + iStr + v.keyDelim + field.Name - // fmt.Printf("%s is slice\n", sliceKey) - - if val, ok := v.getEnv(v.mergeWithEnvPrefix(sliceKey)); ok { - // fmt.Printf("Val is %v\n", val) - item.Field(j).SetString(val) - } - } - } else { - sliceKey := prefix + k + v.keyDelim + iStr - if val, ok := v.getEnv(v.mergeWithEnvPrefix(sliceKey)); ok { - intValue, _ := strconv.ParseInt(val, 10, 32) - item.SetInt(intValue) - } - } - } - } - - switch val := val.(type) { - case map[string]any: - m2 = val - case map[any]any: - m2 = cast.ToStringMap(val) - default: - continue - } - // recursively merge to shadow map - v.postProcessingSliceFields(m2, fullKey) - } + return decode(v.getSettings(keys), defaultDecoderConfig(rawVal, opts...)) } func (v *Viper) decodeStructKeys(input any, opts ...DecoderConfigOption) ([]string, error) { From cb5aa46b9755207a84abc724e4120fea1bde491c Mon Sep 17 00:00:00 2001 From: Jason Lee Date: Sat, 6 Jan 2024 20:41:43 +0800 Subject: [PATCH 4/4] clean up --- viper.go | 52 +++++++++++++++++++++++++++++----------------------- 1 file changed, 29 insertions(+), 23 deletions(-) diff --git a/viper.go b/viper.go index 77f29aa87..c223a90e9 100644 --- a/viper.go +++ b/viper.go @@ -689,47 +689,53 @@ func (v *Viper) searchMap(source map[string]any, path []string) any { return nil } +// searchMapWithAliases recursively searches for slice field in source map and +// replace them with the environment variable value if it exists. +// +// Returns replaced values. func (v *Viper) searchAndReplaceSliceValueWithEnv(source any, envKey string) any { - switch v1 := source.(type) { + switch sourceValue := source.(type) { case []any: - var newSlices []any - for i, value := range v1 { + var newSliceValues []any + for i, sliceValue := range sourceValue { envKey := envKey + v.keyDelim + strconv.Itoa(i) - switch v2 := value.(type) { + switch existingValue := sliceValue.(type) { case map[string]any: - val := v.searchAndReplaceSliceValueWithEnv(v2, envKey) - newSlices = append(newSlices, val) + newVal := v.searchAndReplaceSliceValueWithEnv(existingValue, envKey) + newSliceValues = append(newSliceValues, newVal) + default: - if val, ok := v.getEnv(v.mergeWithEnvPrefix(envKey)); ok { - newSlices = append(newSlices, val) + if newVal, ok := v.getEnv(v.mergeWithEnvPrefix(envKey)); ok { + newSliceValues = append(newSliceValues, newVal) } else { - newSlices = append(newSlices, v2) + newSliceValues = append(newSliceValues, existingValue) } } } - return newSlices - case map[string]any: - var newMapValue map[string]any = make(map[string]any) + return newSliceValues - for k, v2 := range v1 { - envKey := envKey + v.keyDelim + k - switch v3 := v2.(type) { + case map[string]any: + var newMapValues map[string]any = make(map[string]any) + for key, mapValue := range sourceValue { + envKey := envKey + v.keyDelim + key + switch existingValue := mapValue.(type) { case map[string]any: - val := v.searchAndReplaceSliceValueWithEnv(v3, envKey) - newMapValue[k] = val + newVal := v.searchAndReplaceSliceValueWithEnv(existingValue, envKey) + newMapValues[key] = newVal default: - if val, ok := v.getEnv(v.mergeWithEnvPrefix(envKey)); ok { - newMapValue[k] = val + if newVal, ok := v.getEnv(v.mergeWithEnvPrefix(envKey)); ok { + newMapValues[key] = newVal } else { - newMapValue[k] = v3 + newMapValues[key] = existingValue } } } - return newMapValue + return newMapValues + default: - if val, ok := v.getEnv(v.mergeWithEnvPrefix(envKey)); ok { - return val + if newVal, ok := v.getEnv(v.mergeWithEnvPrefix(envKey)); ok { + return newVal } else { return source }