Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add strict_domain_check config attribute to be used in GET /login #4

Open
wants to merge 1 commit into
base: cermati/v0.39.0
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion handlers/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,13 @@ func getValidRequestedURL(r *http.Request) (string, error) {

hostname := u.Hostname()
if cfg.GenOAuth.Provider != cfg.Providers.IndieAuth {
d := domains.Matches(hostname)
var d string
if cfg.Cfg.StrictDomainCheck {
d = domains.MatchesStrict(hostname)
} else {
d = domains.Matches(hostname)
}

if d == "" {
inCookieDomain := (hostname == cfg.Cfg.Cookie.Domain || strings.HasSuffix(hostname, "."+cfg.Cfg.Cookie.Domain))
if cfg.Cfg.Cookie.Domain == "" || !inCookieDomain {
Expand Down
31 changes: 16 additions & 15 deletions pkg/cfg/cfg.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,21 +43,22 @@ import (
// though most of the time envconfig will use the struct key's name: VOUCH_PORT VOUCH_JWT_MAXAGE
// default values should be set in .defaults.yml
type Config struct {
LogLevel string `mapstructure:"logLevel"`
Listen string `mapstructure:"listen"`
Port int `mapstructure:"port"`
SocketMode int `mapstructure:"socket_mode"`
SocketGroup string `mapstructure:"socket_group"`
DocumentRoot string `mapstructure:"document_root" envconfig:"document_root"`
WriteTimeout int `mapstructure:"writeTimeout"`
ReadTimeout int `mapstructure:"readTimeout"`
IdleTimeout int `mapstructure:"idleTimeout"`
Domains []string `mapstructure:"domains"`
WhiteList []string `mapstructure:"whitelist"`
TeamWhiteList []string `mapstructure:"teamWhitelist"`
AllowAllUsers bool `mapstructure:"allowAllUsers"`
PublicAccess bool `mapstructure:"publicAccess"`
TLS struct {
LogLevel string `mapstructure:"logLevel"`
Listen string `mapstructure:"listen"`
Port int `mapstructure:"port"`
SocketMode int `mapstructure:"socket_mode"`
SocketGroup string `mapstructure:"socket_group"`
DocumentRoot string `mapstructure:"document_root" envconfig:"document_root"`
WriteTimeout int `mapstructure:"writeTimeout"`
ReadTimeout int `mapstructure:"readTimeout"`
IdleTimeout int `mapstructure:"idleTimeout"`
Domains []string `mapstructure:"domains"`
StrictDomainCheck bool `mapstructure:"strict_domain_check"`
WhiteList []string `mapstructure:"whitelist"`
TeamWhiteList []string `mapstructure:"teamWhitelist"`
AllowAllUsers bool `mapstructure:"allowAllUsers"`
PublicAccess bool `mapstructure:"publicAccess"`
TLS struct {
Cert string `mapstructure:"cert"`
Key string `mapstructure:"key"`
Profile string `mapstructure:"profile"`
Expand Down
19 changes: 19 additions & 0 deletions pkg/domains/domains.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,25 @@ func Matches(s string) string {
return ""
}

// MatchesStrict returns one of the domains we're configured for with strict comparison (not matching *.domain.com)
func MatchesStrict(s string) string {
if strings.Contains(s, ":") {
// then we have a port and we just want to check the host
split := strings.Split(s, ":")
log.Debugf("removing port from %s to test domain %s", s, split[0])
s = split[0]
}

for i, v := range cfg.Cfg.Domains {
if s == v {
log.Debugf("domain %s matched array value at [%d]=%v", s, i, v)
return v
}
}
log.Warnf("domain %s not found in any domains %v", s, cfg.Cfg.Domains)
return ""
}

// IsUnderManagement check if an email is under vouch-managed domain
func IsUnderManagement(email string) bool {
split := strings.Split(email, "@")
Expand Down
11 changes: 10 additions & 1 deletion pkg/domains/domains_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ OR CONDITIONS OF ANY KIND, either express or implied.
package domains

import (
"github.com/stretchr/testify/assert"
"testing"

"github.com/stretchr/testify/assert"
"github.com/vouch/vouch-proxy/pkg/cfg"
)

Expand Down Expand Up @@ -50,3 +50,12 @@ func TestMatches(t *testing.T) {
assert.Equal(t, "sub.test.mydomain.com", Matches("subsub.sub.test.mydomain.com"))
assert.Equal(t, "test.mydomain.com", Matches("other.test.mydomain.com"))
}

func TestMatchesStrict(t *testing.T) {
assert.Equal(t, "vouch.github.io", MatchesStrict("vouch.github.io"))
assert.Equal(t, "", MatchesStrict("sub.vouch.github.io"))
assert.Equal(t, "", MatchesStrict("a-different-vouch.github.io"))

assert.Equal(t, "", MatchesStrict("mydomain.com"))

}