diff --git a/config/config.go b/config/config.go index 2a2ece3..90d98a7 100644 --- a/config/config.go +++ b/config/config.go @@ -2,6 +2,8 @@ package config import ( "fmt" + "os" + "reflect" "strings" "github.com/underdog-tech/vulnbot/logger" @@ -28,38 +30,55 @@ type Config struct { Team []TeamConfig } +func fileExists(fname string) bool { + if _, err := os.Stat(fname); err != nil { + if os.IsNotExist(err) { + return false + } + } + return true +} + func GetUserConfig(configFile string) (Config, error) { log := logger.Get() userCfg := Config{} - // Set up env var overrides - replacer := strings.NewReplacer("-", "_") - viper.SetEnvKeyReplacer(replacer) - viper.SetEnvPrefix("vulnbot") - viper.AutomaticEnv() + // Use reflection to register all config fields in Viper to set up defaults + cfgFields := reflect.ValueOf(userCfg) + cfgType := cfgFields.Type() + + for i := 0; i < cfgFields.NumField(); i++ { + viper.SetDefault(cfgType.Field(i).Name, cfgFields.Field(i).Interface()) + } // Load the main config file + if !fileExists(configFile) { + log.Fatal().Str("config", configFile).Msg("Config file not found.") + } viper.SetConfigFile(configFile) if err := viper.ReadInConfig(); err != nil { - if _, ok := err.(viper.ConfigFileNotFoundError); ok { - log.Fatal().Str("config", configFile).Err(err).Msg("Config file not found.") - } else { - log.Fatal().Err(err).Msg("Error reading config file.") - } + log.Fatal().Err(err).Msg("Error reading config file.") } - viper.Unmarshal(&userCfg) // (Optionally) Load a .env file - viper.SetConfigFile("./.env") - viper.SetConfigType("env") - if err := viper.ReadInConfig(); err != nil { - if _, ok := err.(viper.ConfigFileNotFoundError); ok { - log.Warn().Msg("No .env file found; not loaded.") - } else { + if fileExists("./.env") { + viper.SetConfigFile("./.env") + viper.SetConfigType("env") + if err := viper.ReadInConfig(); err != nil { log.Error().Err(err).Msg("Error loading .env file.") } + } else { + log.Warn().Msg("No .env file found; not loaded.") } + + // Set up env var overrides + replacer := strings.NewReplacer("-", "_") + viper.SetEnvKeyReplacer(replacer) + viper.SetEnvPrefix("vulnbot") + viper.AutomaticEnv() + + // Finally, copy all loaded values into the config object viper.Unmarshal(&userCfg) return userCfg, nil diff --git a/config/config_test.go b/config/config_test.go index 2512f8e..5a86a12 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -69,38 +69,29 @@ func getCurrentDir() (string, error) { return cwd, nil } -func TestLoadConfig(t *testing.T) { - expectedSlackChannel := "testing_slack_channel" +func TestGetUserConfigFromFile(t *testing.T) { currentDir, err := getCurrentDir() - if err != nil { - assert.Error(t, err) - } + assert.Nil(t, err) testDataPath := filepath.Join(currentDir, "/testdata/test_config.toml") - - var cfg config.Config - _ = config.LoadConfig(config.ViperParams{ - Output: &cfg, - ConfigPath: &testDataPath, - }) - - assert.IsType(t, config.Config{}, cfg) - assert.Equal(t, expectedSlackChannel, cfg.Default_slack_channel) + cfg, err := config.GetUserConfig(testDataPath) + assert.Nil(t, err) + assert.Equal(t, "testing_slack_channel", cfg.Default_slack_channel) + assert.Equal(t, []config.EcosystemConfig{{Label: "Go", Slack_emoji: ":golang:"}}, cfg.Ecosystem) } -func TestLoadEnv(t *testing.T) { - expectedSlackAuthToken := "testing" - currentDir, err := getCurrentDir() - if err != nil { - assert.Error(t, err) - } - testDataPath := filepath.Join(currentDir, "/testdata/config.env") +func TestGetUserConfigFromEnv(t *testing.T) { + t.Setenv("VULNBOT_DISABLE_SLACK", "1") + t.Setenv("VULNBOT_GITHUB_ORG", "hitchhikers") + // This should override the config file + t.Setenv("VULNBOT_DEFAULT_SLACK_CHANNEL", "other_slack_channel") - var env config.Env - _ = config.LoadEnv(config.ViperParams{ - EnvFileName: &testDataPath, - Output: &env, - }) + currentDir, err := getCurrentDir() + assert.Nil(t, err) + testDataPath := filepath.Join(currentDir, "/testdata/test_config.toml") + cfg, err := config.GetUserConfig(testDataPath) + assert.Nil(t, err) - assert.IsType(t, config.Env{}, env) - assert.Equal(t, expectedSlackAuthToken, env.SlackAuthToken) + assert.True(t, cfg.Disable_slack) + assert.Equal(t, "hitchhikers", cfg.Github_org) + assert.Equal(t, "other_slack_channel", cfg.Default_slack_channel) }