From cafbccdf2c2c83788024674a26760dac972bd4c9 Mon Sep 17 00:00:00 2001 From: faiz-alhadiid Date: Fri, 8 Nov 2024 14:29:03 +0700 Subject: [PATCH] feat: add strict domain checking config attribute and use it in /login --- handlers/login.go | 8 +++++++- pkg/cfg/cfg.go | 31 ++++++++++++++++--------------- pkg/domains/domains.go | 19 +++++++++++++++++++ pkg/domains/domains_test.go | 11 ++++++++++- 4 files changed, 52 insertions(+), 17 deletions(-) diff --git a/handlers/login.go b/handlers/login.go index c16f501c..3d7b04d6 100644 --- a/handlers/login.go +++ b/handlers/login.go @@ -228,7 +228,13 @@ func getValidRequestedURL(r *http.Request) (string, error) { hostname := u.Hostname() if cfg.GenOAuth.Provider != cfg.Providers.IndieAuth { - d := domains.Matches(hostname) + var d string + if cfg.Cfg.StrictDomainCheck { + d = domains.MatchesStrict(hostname) + } else { + d = domains.Matches(hostname) + } + if d == "" { inCookieDomain := (hostname == cfg.Cfg.Cookie.Domain || strings.HasSuffix(hostname, "."+cfg.Cfg.Cookie.Domain)) if cfg.Cfg.Cookie.Domain == "" || !inCookieDomain { diff --git a/pkg/cfg/cfg.go b/pkg/cfg/cfg.go index 9aedba64..de355425 100644 --- a/pkg/cfg/cfg.go +++ b/pkg/cfg/cfg.go @@ -43,21 +43,22 @@ import ( // though most of the time envconfig will use the struct key's name: VOUCH_PORT VOUCH_JWT_MAXAGE // default values should be set in .defaults.yml type Config struct { - LogLevel string `mapstructure:"logLevel"` - Listen string `mapstructure:"listen"` - Port int `mapstructure:"port"` - SocketMode int `mapstructure:"socket_mode"` - SocketGroup string `mapstructure:"socket_group"` - DocumentRoot string `mapstructure:"document_root" envconfig:"document_root"` - WriteTimeout int `mapstructure:"writeTimeout"` - ReadTimeout int `mapstructure:"readTimeout"` - IdleTimeout int `mapstructure:"idleTimeout"` - Domains []string `mapstructure:"domains"` - WhiteList []string `mapstructure:"whitelist"` - TeamWhiteList []string `mapstructure:"teamWhitelist"` - AllowAllUsers bool `mapstructure:"allowAllUsers"` - PublicAccess bool `mapstructure:"publicAccess"` - TLS struct { + LogLevel string `mapstructure:"logLevel"` + Listen string `mapstructure:"listen"` + Port int `mapstructure:"port"` + SocketMode int `mapstructure:"socket_mode"` + SocketGroup string `mapstructure:"socket_group"` + DocumentRoot string `mapstructure:"document_root" envconfig:"document_root"` + WriteTimeout int `mapstructure:"writeTimeout"` + ReadTimeout int `mapstructure:"readTimeout"` + IdleTimeout int `mapstructure:"idleTimeout"` + Domains []string `mapstructure:"domains"` + StrictDomainCheck bool `mapstructure:"strict_domain_check"` + WhiteList []string `mapstructure:"whitelist"` + TeamWhiteList []string `mapstructure:"teamWhitelist"` + AllowAllUsers bool `mapstructure:"allowAllUsers"` + PublicAccess bool `mapstructure:"publicAccess"` + TLS struct { Cert string `mapstructure:"cert"` Key string `mapstructure:"key"` Profile string `mapstructure:"profile"` diff --git a/pkg/domains/domains.go b/pkg/domains/domains.go index 32fbfbd4..db785bd7 100644 --- a/pkg/domains/domains.go +++ b/pkg/domains/domains.go @@ -47,6 +47,25 @@ func Matches(s string) string { return "" } +// MatchesStrict returns one of the domains we're configured for with strict comparison (not matching *.domain.com) +func MatchesStrict(s string) string { + if strings.Contains(s, ":") { + // then we have a port and we just want to check the host + split := strings.Split(s, ":") + log.Debugf("removing port from %s to test domain %s", s, split[0]) + s = split[0] + } + + for i, v := range cfg.Cfg.Domains { + if s == v { + log.Debugf("domain %s matched array value at [%d]=%v", s, i, v) + return v + } + } + log.Warnf("domain %s not found in any domains %v", s, cfg.Cfg.Domains) + return "" +} + // IsUnderManagement check if an email is under vouch-managed domain func IsUnderManagement(email string) bool { split := strings.Split(email, "@") diff --git a/pkg/domains/domains_test.go b/pkg/domains/domains_test.go index c65f5fa8..a86a68f8 100644 --- a/pkg/domains/domains_test.go +++ b/pkg/domains/domains_test.go @@ -11,9 +11,9 @@ OR CONDITIONS OF ANY KIND, either express or implied. package domains import ( - "github.com/stretchr/testify/assert" "testing" + "github.com/stretchr/testify/assert" "github.com/vouch/vouch-proxy/pkg/cfg" ) @@ -50,3 +50,12 @@ func TestMatches(t *testing.T) { assert.Equal(t, "sub.test.mydomain.com", Matches("subsub.sub.test.mydomain.com")) assert.Equal(t, "test.mydomain.com", Matches("other.test.mydomain.com")) } + +func TestMatchesStrict(t *testing.T) { + assert.Equal(t, "vouch.github.io", MatchesStrict("vouch.github.io")) + assert.Equal(t, "", MatchesStrict("sub.vouch.github.io")) + assert.Equal(t, "", MatchesStrict("a-different-vouch.github.io")) + + assert.Equal(t, "", MatchesStrict("mydomain.com")) + +}