Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve provider name handling #213

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ package auth
import (
"fmt"
"net/http"
"net/url"
"regexp"
"strings"
"time"

Expand Down Expand Up @@ -270,10 +272,42 @@ func (s *Service) addProviderByName(name string, p provider.Params) {
}

func (s *Service) addProvider(prov provider.Provider) {
if !s.isValidProviderName(prov.Name()) {
return
}
s.providers = append(s.providers, provider.NewService(prov))
s.authMiddleware.Providers = s.providers
}

func (s *Service) isValidProviderName(name string) bool {
if strings.TrimSpace(name) == "" {
s.logger.Logf("[ERROR] provider has been ignored because its name is empty")
return false
}

formatForbidden := func(name string) {
s.logger.Logf("[ERROR] provider has been ignored because its name contains forbidden characters: '%s'", name)
}

path, err := url.PathUnescape(name)
if err != nil || path != name {
formatForbidden(name)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

small thing: in debug mode, this function's redirect for logging will lose the source line information. I'd suggest creating an error message instead and logging it directly.

return false
}
if name != url.PathEscape(name) {
formatForbidden(name)
return false
}
// net/url package does not escape everything (https://github.com/golang/go/issues/5684)
// It is better to reject all reserved characters from https://datatracker.ietf.org/doc/html/rfc3986#section-2.2
if regexp.MustCompile(`[:/?#\[\]@!$&'\(\)*+,;=]`).MatchString(name) {
formatForbidden(name)
return false
}

return true
}

// AddProvider adds provider for given name
func (s *Service) AddProvider(name, cid, csecret string) {
p := provider.Params{
Expand Down
36 changes: 32 additions & 4 deletions auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,34 @@ func TestIntegrationList(t *testing.T) {
assert.Equal(t, `["dev","github","custom123"]`+"\n", string(b))
}

func TestIntegrationInvalidProviderNames(t *testing.T) {
invalidNames := []string{
"provider/with/slashes",
"provider with spaces",
" providerWithSpacesAround\t",
"providerWithReserved-$-Char",
"providerWithReserved-&-Char",
"providerWithReserved-+-Char",
"providerWithReserved-,-Char",
"providerWithReserved-:-Char",
"providerWithReserved-;-Char",
"providerWithReserved-=-Char",
"providerWithReserved-?-Char",
"providerWithReserved-@-Char",
"providerWith%2F-EscapedSequence",
"",
}
svc, teardown := prepService(t, func(svc *Service) {
for _, name := range invalidNames {
svc.AddCustomProvider(name, Client{"cid", "csecret"}, provider.CustomHandlerOpt{})
}
})
defer teardown()

require.Equal(t, 1, len(svc.Providers()))
require.Equal(t, "dev", svc.Providers()[0].Name())
}

func TestIntegrationUserInfo(t *testing.T) {
_, teardown := prepService(t)
defer teardown()
Expand Down Expand Up @@ -387,7 +415,7 @@ func TestDirectProvider(t *testing.T) {

func TestDirectProvider_WithCustomUserIDFunc(t *testing.T) {
_, teardown := prepService(t, func(svc *Service) {
svc.AddDirectProviderWithUserIDFunc("directCustom",
svc.AddDirectProviderWithUserIDFunc("direct_custom",
provider.CredCheckerFunc(func(user, password string) (ok bool, err error) {
return user == "dev_direct" && password == "password", nil
}),
Expand All @@ -402,12 +430,12 @@ func TestDirectProvider_WithCustomUserIDFunc(t *testing.T) {
jar, err := cookiejar.New(nil)
require.Nil(t, err)
client := &http.Client{Jar: jar, Timeout: 5 * time.Second}
resp, err := client.Get("http://127.0.0.1:8089/auth/directCustom/login?user=dev_direct&passwd=bad")
resp, err := client.Get("http://127.0.0.1:8089/auth/direct_custom/login?user=dev_direct&passwd=bad")
require.Nil(t, err)
defer resp.Body.Close()
assert.Equal(t, 403, resp.StatusCode)

resp, err = client.Get("http://127.0.0.1:8089/auth/directCustom/login?user=dev_direct&passwd=password")
resp, err = client.Get("http://127.0.0.1:8089/auth/direct_custom/login?user=dev_direct&passwd=password")
require.Nil(t, err)
defer resp.Body.Close()
assert.Equal(t, 200, resp.StatusCode)
Expand All @@ -417,7 +445,7 @@ func TestDirectProvider_WithCustomUserIDFunc(t *testing.T) {
t.Logf("resp %s", string(body))
t.Logf("headers: %+v", resp.Header)

assert.Contains(t, string(body), `"name":"dev_direct","id":"directCustom_5bf1fd927dfb8679496a2e6cf00cbe50c1c87145"`)
assert.Contains(t, string(body), `"name":"dev_direct","id":"direct_custom_5bf1fd927dfb8679496a2e6cf00cbe50c1c87145"`)

require.Equal(t, 2, len(resp.Cookies()))
assert.Equal(t, "JWT", resp.Cookies()[0].Name)
Expand Down
23 changes: 17 additions & 6 deletions middleware/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ func (a *Authenticator) auth(reqAuth bool) func(http.Handler) http.Handler {
}

// check if user provider is allowed
if !a.isProviderAllowed(claims.User.ID) {
if !a.isProviderAllowed(&claims) {
onError(h, w, r, fmt.Errorf("user %s/%s provider is not allowed", claims.User.Name, claims.User.ID))
a.JWTService.Reset(w)
return
Expand All @@ -153,13 +153,24 @@ func (a *Authenticator) auth(reqAuth bool) func(http.Handler) http.Handler {
return f
}

// isProviderAllowed checks if user provider is allowed, user id looks like "provider_1234567890"
// this check is needed to reject users from providers what are used to be allowed but not anymore.
// isProviderAllowed checks if user provider is allowed.
// If provider name is explicitly set in the token claims, then that provider is checked.
//
// If user id looks like "provider_1234567890",
// then there is an attempt to extract provider name from that user ID.
// Note that such read can fail if user id has multiple "_" separator symbols.
//
// This check is needed to reject users from providers what are used to be allowed but not anymore.
// Such users made token before the provider was disabled and should not be allowed to login anymore.
func (a *Authenticator) isProviderAllowed(userID string) bool {
userProvider := strings.Split(userID, "_")[0]
func (a *Authenticator) isProviderAllowed(claims *token.Claims) bool {
// TODO: remove this read when old tokens expire and all new tokens have a provider name in them
userIDProvider := strings.Split(claims.User.ID, "_")[0]
for _, p := range a.Providers {
if p.Name() == userProvider {
name := p.Name()
if claims.AuthProvider != nil && claims.AuthProvider.Name == name {
return true
}
if name == userIDProvider {
return true
}
}
Expand Down
55 changes: 26 additions & 29 deletions middleware/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ var testJwtNoUser = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjI3ODkxOTE4Mj

var testJwtWithRole = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJwcm92aWRlcjFfaWQxIiwicGljdHVyZSI6Imh0dHA6Ly9leGFtcGxlLmNvbS9waWMucG5nIiwiaXAiOiIxMjcuMC4wLjEiLCJlbWFpbCI6Im1lQGV4YW1wbGUuY29tIiwiYXR0cnMiOnsiYm9vbGEiOnRydWUsInN0cmEiOiJzdHJhLXZhbCJ9LCJyb2xlIjoiZW1wbG95ZWUifX0.o95raB0aNl2TWUs43Tu6xyX5Y3Fa5wv6_6RFJuN-d6g"

var testJwtValidWithAuthProvider = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJwcm92aWRlcjFfaWQxIiwicGljdHVyZSI6Imh0dHA6Ly9leGFtcGxlLmNvbS9waWMucG5nIiwiaXAiOiIxMjcuMC4wLjEiLCJlbWFpbCI6Im1lQGV4YW1wbGUuY29tIiwiYXR0cnMiOnsiYm9vbGEiOnRydWUsInN0cmEiOiJzdHJhLXZhbCJ9fSwiYXV0aF9wcm92aWRlciI6eyJuYW1lIjoicHJvdmlkZXIxIn19.iBKM9-lgejJNjcs-crj6gkEejnIJpavmaq8alenf0JA"

func TestAuthJWTCookie(t *testing.T) {
a := makeTestAuth(t)

Expand All @@ -51,56 +53,51 @@ func TestAuthJWTCookie(t *testing.T) {
client := &http.Client{Timeout: 5 * time.Second}
expiration := int(365 * 24 * time.Hour.Seconds()) //nolint

t.Run("valid token", func(t *testing.T) {
makeRequest := func(jwtCookie string, xsrfToken string) *http.Response {
req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody)
require.Nil(t, err)
req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValid, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false})
req.Header.Add("X-XSRF-TOKEN", "random id")
req.AddCookie(&http.Cookie{
Name: "JWT",
Value: jwtCookie,
HttpOnly: true,
Path: "/",
MaxAge: expiration,
Secure: false,
})
req.Header.Add("X-XSRF-TOKEN", xsrfToken)

resp, err := client.Do(req)
require.NoError(t, err)
return resp
}

t.Run("valid token", func(t *testing.T) {
resp := makeRequest(testJwtValid, "random id")
assert.Equal(t, 201, resp.StatusCode, "valid token user")
})

t.Run("valid token, wrong provider", func(t *testing.T) {
req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody)
require.Nil(t, err)
req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValidWrongProvider, HttpOnly: true, Path: "/",
MaxAge: expiration, Secure: false})
req.Header.Add("X-XSRF-TOKEN", "random id")
t.Run("valid token with auth_provider", func(t *testing.T) {
resp := makeRequest(testJwtValidWithAuthProvider, "random id")
assert.Equal(t, 201, resp.StatusCode, "valid token user")
})

resp, err := client.Do(req)
require.NoError(t, err)
t.Run("valid token, wrong provider", func(t *testing.T) {
resp := makeRequest(testJwtValidWrongProvider, "random id")
assert.Equal(t, 401, resp.StatusCode, "user name1/provider3_id1 provider is not allowed")
})

t.Run("xsrf mismatch", func(t *testing.T) {
req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody)
require.Nil(t, err)
req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValid, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false})
req.Header.Add("X-XSRF-TOKEN", "wrong id")
resp, err := client.Do(req)
require.NoError(t, err)
resp := makeRequest(testJwtValid, "wrong id")
assert.Equal(t, 401, resp.StatusCode, "xsrf mismatch")
})

t.Run("token expired and refreshed", func(t *testing.T) {
req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody)
require.Nil(t, err)
req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtExpired, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false})
req.Header.Add("X-XSRF-TOKEN", "random id")
resp, err := client.Do(req)
require.NoError(t, err)
resp := makeRequest(testJwtExpired, "random id")
assert.Equal(t, 201, resp.StatusCode, "token expired and refreshed")
})

t.Run("no user info in the token", func(t *testing.T) {
req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody)
require.Nil(t, err)
req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtNoUser, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false})
req.Header.Add("X-XSRF-TOKEN", "random id")
resp, err := client.Do(req)
require.NoError(t, err)
resp := makeRequest(testJwtNoUser, "random id")
assert.Equal(t, 401, resp.StatusCode, "no user info in the token")
})
}
Expand Down
6 changes: 6 additions & 0 deletions provider/apple.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,9 @@ func (ah *AppleHandler) LoginHandler(w http.ResponseWriter, r *http.Request) {
ExpiresAt: time.Now().Add(30 * time.Minute).Unix(),
NotBefore: time.Now().Add(-1 * time.Minute).Unix(),
},
AuthProvider: &token.AuthProvider{
Name: ah.name,
},
}

if _, err = ah.JwtService.Set(w, claims); err != nil {
Expand Down Expand Up @@ -376,6 +379,9 @@ func (ah AppleHandler) AuthHandler(w http.ResponseWriter, r *http.Request) {
Audience: oauthClaims.Audience,
},
SessionOnly: false,
AuthProvider: &token.AuthProvider{
Name: ah.name,
},
}

if _, err = ah.JwtService.Set(w, claims); err != nil {
Expand Down
3 changes: 3 additions & 0 deletions provider/direct.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ func (p DirectHandler) LoginHandler(w http.ResponseWriter, r *http.Request) {
Audience: creds.Audience,
},
SessionOnly: sessOnly,
AuthProvider: &token.AuthProvider{
Name: p.ProviderName,
},
}

if _, err = p.TokenService.Set(w, claims); err != nil {
Expand Down
6 changes: 6 additions & 0 deletions provider/oauth1.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ func (h Oauth1Handler) LoginHandler(w http.ResponseWriter, r *http.Request) {
ExpiresAt: time.Now().Add(30 * time.Minute).Unix(),
NotBefore: time.Now().Add(-1 * time.Minute).Unix(),
},
AuthProvider: &token.AuthProvider{
Name: h.name,
},
}

if _, err = h.JwtService.Set(w, claims); err != nil {
Expand Down Expand Up @@ -146,6 +149,9 @@ func (h Oauth1Handler) AuthHandler(w http.ResponseWriter, r *http.Request) {
Audience: oauthClaims.Audience,
},
SessionOnly: oauthClaims.SessionOnly,
AuthProvider: &token.AuthProvider{
Name: h.name,
},
}

if _, err = h.JwtService.Set(w, claims); err != nil {
Expand Down
6 changes: 6 additions & 0 deletions provider/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,9 @@ func (p Oauth2Handler) LoginHandler(w http.ResponseWriter, r *http.Request) {
NotBefore: time.Now().Add(-1 * time.Minute).Unix(),
},
NoAva: r.URL.Query().Get("noava") == "1",
AuthProvider: &token.AuthProvider{
Name: p.name,
},
}

if _, err := p.JwtService.Set(w, claims); err != nil {
Expand Down Expand Up @@ -215,6 +218,9 @@ func (p Oauth2Handler) AuthHandler(w http.ResponseWriter, r *http.Request) {
},
SessionOnly: oauthClaims.SessionOnly,
NoAva: oauthClaims.NoAva,
AuthProvider: &token.AuthProvider{
Name: p.name,
},
}

if _, err = p.JwtService.Set(w, claims); err != nil {
Expand Down
3 changes: 3 additions & 0 deletions provider/telegram.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,9 @@ func (th *TelegramHandler) LoginHandler(w http.ResponseWriter, r *http.Request)
NotBefore: time.Now().Add(-1 * time.Minute).Unix(),
},
SessionOnly: false, // TODO review?
AuthProvider: &authtoken.AuthProvider{
Name: th.ProviderName,
},
}

if _, err := th.TokenService.Set(w, claims); err != nil {
Expand Down
6 changes: 6 additions & 0 deletions provider/verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@ func (e VerifyHandler) LoginHandler(w http.ResponseWriter, r *http.Request) {
Audience: confClaims.Audience,
},
SessionOnly: sessOnly,
AuthProvider: &token.AuthProvider{
Name: e.ProviderName,
},
}

if _, err = e.TokenService.Set(w, claims); err != nil {
Expand Down Expand Up @@ -152,6 +155,9 @@ func (e VerifyHandler) sendConfirmation(w http.ResponseWriter, r *http.Request)
NotBefore: time.Now().Add(-1 * time.Minute).Unix(),
Issuer: e.Issuer,
},
AuthProvider: &token.AuthProvider{
Name: e.ProviderName,
},
}

tkn, err := e.TokenService.Token(claims)
Expand Down
14 changes: 10 additions & 4 deletions token/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ type Service struct {
// Claims stores user info for token and state & from from login
type Claims struct {
jwt.StandardClaims
User *User `json:"user,omitempty"` // user info
SessionOnly bool `json:"sess_only,omitempty"`
Handshake *Handshake `json:"handshake,omitempty"` // used for oauth handshake
NoAva bool `json:"no-ava,omitempty"` // disable avatar, always use identicon
User *User `json:"user,omitempty"` // user info
SessionOnly bool `json:"sess_only,omitempty"`
Handshake *Handshake `json:"handshake,omitempty"` // used for oauth handshake
NoAva bool `json:"no-ava,omitempty"` // disable avatar, always use identicon
AuthProvider *AuthProvider `json:"auth_provider,omitempty"` // auth provider info
}

// Handshake used for oauth handshake
Expand All @@ -34,6 +35,11 @@ type Handshake struct {
ID string `json:"id,omitempty"`
}

// AuthProvider stores attributes of provider which has created a JWT token
type AuthProvider struct {
Name string `json:"name,omitempty"`
}

const (
// default names for cookies and headers
defaultJWTCookieName = "JWT"
Expand Down
Loading
Loading