Skip to content

Commit

Permalink
feat(api,pkg)!: new jwt decoding implementation
Browse files Browse the repository at this point in the history
We’re refactoring the `AuthRequest` handler. The handler will no longer
use the `AuthMiddleware`, and JWT decoding is now the responsibility of
the `jwttoken` package, which already handles token encoding. The main
reason for this change is the deprecation of
`github.com/labstack/echo/v4/middleware.JWTWithConfig`.

The `UserAuthClaims` and `DeviceAuthClaims` have also been moved to the
`authorizer` package. The `jwttoken` package now uses these types to
encode and decode tokens. Consequently, the `EncodeUserClaims` and
`EncodeDeviceClaims` functions replace the old generic `Encode`
function, which is now used only internally. The `ClaimsFromBearerToken`
function is now responsible for decoding a token into either user or
device claims.

The handler’s documentation has been updated to describe its use as a
"proxy-level" middleware, including a sequential diagram. All related
tests have also been improved.
  • Loading branch information
heiytor authored and gustavosbarreto committed Jul 19, 2024
1 parent 55161a8 commit 839be80
Show file tree
Hide file tree
Showing 20 changed files with 361 additions and 262 deletions.
2 changes: 0 additions & 2 deletions agent/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ require (
github.com/go-playground/validator/v10 v10.11.2 // indirect
github.com/go-resty/resty/v2 v2.7.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang-jwt/jwt/v4 v4.5.0 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/gorilla/websocket v1.5.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/kr/fs v0.1.0 // indirect
Expand Down
4 changes: 0 additions & 4 deletions agent/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,8 @@ github.com/go-resty/resty/v2 v2.7.0 h1:me+K9p3uhSmXtrBZ4k9jcEAfJmuC8IivWHwaLZwPr
github.com/go-resty/resty/v2 v2.7.0/go.mod h1:9PWDzw47qPphMRFfhsyk0NnSgvluHcljSMVIq3w7q0I=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg=
github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.1 h1:/c3QmbOGMGTOumP2iT/rCwB7b0QDGLKzqOmktBjT+Is=
Expand Down
3 changes: 1 addition & 2 deletions api/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@ go 1.21
require (
github.com/cnf/structhash v0.0.0-20201127153200-e1b16c1ebc08
github.com/getsentry/sentry-go v0.28.1
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/golang-jwt/jwt/v4 v4.5.0
github.com/hibiken/asynq v0.24.1
github.com/labstack/echo/v4 v4.12.0
github.com/labstack/gommon v0.4.2
github.com/mitchellh/mapstructure v1.5.0
github.com/pkg/errors v0.9.1
github.com/shellhub-io/mongotest v0.0.0-20230928124937-e33b07010742
github.com/shellhub-io/shellhub v0.13.4
Expand Down Expand Up @@ -55,6 +53,7 @@ require (
github.com/go-redis/redis/v8 v8.11.5 // indirect
github.com/go-resty/resty/v2 v2.7.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
github.com/golang/protobuf v1.5.4 // indirect
github.com/golang/snappy v0.0.4 // indirect
github.com/google/uuid v1.6.0 // indirect
Expand Down
2 changes: 0 additions & 2 deletions api/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,6 @@ github.com/mattn/go-runewidth v0.0.13/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh
github.com/mattn/goveralls v0.0.9/go.mod h1:FRbM1PS8oVsOe9JtdzAAXM+DsvDMMHcM1C7drGJD8HY=
github.com/mholt/archiver/v3 v3.5.1 h1:rDjOBX9JSF5BvoJGvjqK479aL70qh9DIpZCl+k7Clwo=
github.com/mholt/archiver/v3 v3.5.1/go.mod h1:e3dqJ7H78uzsRSEACH1joayhuSyhnonssnDhppzS1L4=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
github.com/moby/patternmatcher v0.6.0 h1:GmP9lR19aU5GqSSFko+5pRqHi+Ohk1O69aFiKkVGiPk=
Expand Down
149 changes: 43 additions & 106 deletions api/routes/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,12 @@ import (
"net/http"
"strconv"

jwt "github.com/golang-jwt/jwt"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"github.com/mitchellh/mapstructure"
"github.com/shellhub-io/shellhub/api/pkg/gateway"
errs "github.com/shellhub-io/shellhub/api/routes/errors"
svc "github.com/shellhub-io/shellhub/api/services"
client "github.com/shellhub-io/shellhub/pkg/api/internalclient"
"github.com/shellhub-io/shellhub/pkg/api/authorizer"
"github.com/shellhub-io/shellhub/pkg/api/jwttoken"
"github.com/shellhub-io/shellhub/pkg/api/requests"
"github.com/shellhub-io/shellhub/pkg/models"
)

const (
Expand All @@ -29,18 +25,29 @@ const (
AuthMFAURL = "/auth/mfa"
)

const (
// AuthRequestUserToken is the type of the token used to authenticate a user.
AuthRequestUserToken = "user"
// AuthRequestDeviceToken is the type of the token used to authenticate a device.
AuthRequestDeviceToken = "device"
)

// AuthRequest checks the user and device authentication token.
// AuthRequest is a proxy-level authentication middleware. It decodes a specified
// authentication hash (e.g. JWT tokens and API keys), sets the credentials in
// headers, and redirects to the original endpoint.
//
// This route is a special route and it is called every time a user tries to access a route which requires
// authentication. It gets the JWT token sent, unwraps it and sets the information, like tenant, user, etc., as headers
// of the response to be got in the subsequent through the [gateway.Context].
// The following sequential diagram represents the authentication pipeline:
//
// +------+ +----------------+ +----------+
// | User | | /internal/auth | | /api/... |
// +------+ +----------------+ +----------+
// | | |
// | Send Request | |
// |------------->| |
// | | Extract and decode hash |
// | | Set auth headers |
// | |------------------------>|
// | | | Execute the target endpoint
// | |
// | Send response back to the user |
// |<---------------------------------------|
//
// If the authentication fails for any reason, it must return the failed status
// without redirecting the request. A token can be use to authenticate either a
// device or a user.
func (h *Handler) AuthRequest(c gateway.Context) error {
if key := c.Request().Header.Get("X-API-Key"); key != "" {
apiKey, err := h.service.AuthAPIKey(c.Ctx(), key)
Expand All @@ -55,85 +62,37 @@ func (h *Handler) AuthRequest(c gateway.Context) error {
return c.NoContent(http.StatusOK)
}

token, ok := c.Get(middleware.DefaultJWTConfig.ContextKey).(*jwt.Token)
if !ok {
return svc.ErrTypeAssertion
}

rawClaims, ok := token.Claims.(*jwt.MapClaims)
if !ok {
return svc.ErrTypeAssertion
}

// setHeader sets a reader to the HTTP response to be read in the subsequent request.
setHeader := func(response gateway.Context, key string, value string) {
response.Response().Header().Set(key, value)
}

// decodeMap parses the JWT claims into a struct.
decodeMap := func(input *jwt.MapClaims, output any) error {
config := &mapstructure.DecoderConfig{
TagName: "json",
Metadata: nil,
Result: output,
}

decoder, err := mapstructure.NewDecoder(config)
if err != nil {
return err
}

return decoder.Decode(input)
bearerToken := c.Request().Header.Get("Authorization")
claims, err := jwttoken.ClaimsFromBearerToken(h.service.PublicKey(), bearerToken)
if err != nil {
return c.NoContent(http.StatusUnauthorized)
}

switch (*rawClaims)["claims"] {
case AuthRequestUserToken:
claims := new(models.UserAuthClaims)
if err := decodeMap(rawClaims, claims); err != nil {
return err
}

// The TenantID is optional as the user may not be part of any namespace.
if claims.Tenant != "" {
// The rawClaims contain only the tenant ID of the namespace and not the user's role. This is because the role is a
// dynamic attribute, and a JWT token must be stateless (the role can change, but the token cannot). For this reason,
// we need to retrieve the role every time this middleware is invoked (generally from the cache; see the [method]
// signature for more info).
if err := h.service.FillClaimsRole(c.Ctx(), claims); err != nil {
switch claims := claims.(type) {
case *authorizer.DeviceClaims:
c.Response().Header().Set("X-Device-UID", claims.UID)
c.Response().Header().Set("X-Tenant-ID", claims.TenantID)
case *authorizer.UserClaims:
// As the role is a dynamic attribute, and a JWT token must be stateless, we need to retrieve the role
// every time this middleware is invoked (generally from the cache).
if claims.TenantID != "" {
role, err := h.service.GetUserRole(c.Ctx(), claims.TenantID, claims.ID)
if err != nil {
return err
}
}

args := c.QueryParam("args")
if args != "skip" && claims.Tenant != "" {
// This forces any no cached token to be invalid, even if it not not expired.
if ok, err := h.service.AuthIsCacheToken(c.Ctx(), claims.Tenant, claims.ID); err != nil || !ok {
return svc.NewErrAuthUnathorized(err)
}
claims.Role = authorizer.RoleFromString(role)
}

c.Response().Header().Set("X-ID", claims.ID)
c.Response().Header().Set("X-Username", claims.Username)
c.Response().Header().Set("X-Tenant-ID", claims.Tenant)
c.Response().Header().Set("X-Tenant-ID", claims.TenantID)
c.Response().Header().Set("X-Role", claims.Role.String())

return c.NoContent(http.StatusOK)
case AuthRequestDeviceToken:
var claims models.DeviceAuthClaims

if err := decodeMap(rawClaims, &claims); err != nil {
return err
}

// Extract device UID from JWT and set it into the header.
setHeader(c, client.DeviceUIDHeader, claims.UID)
setHeader(c, "X-Tenant-ID", claims.Tenant)

return c.NoContent(http.StatusOK)
default:

return svc.NewErrAuthUnathorized(nil)
return c.NoContent(http.StatusUnauthorized)
}

return c.NoContent(http.StatusOK)
}

func (h *Handler) AuthDevice(c gateway.Context) error {
Expand Down Expand Up @@ -231,25 +190,3 @@ func (h *Handler) AuthPublicKey(c gateway.Context) error {

return c.JSON(http.StatusOK, res)
}

func AuthMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
ctx, ok := c.Get("ctx").(*gateway.Context)
if !ok {
return svc.ErrTypeAssertion
}

apiKey := c.Request().Header.Get("X-API-KEY")
if apiKey == "" {
jwt := middleware.JWTWithConfig(middleware.JWTConfig{ //nolint:staticcheck
Claims: &jwt.MapClaims{},
SigningKey: ctx.Service().(svc.Service).PublicKey(),
SigningMethod: "RS256",
})

return jwt(next)(c)
}

return next(c)
}
}
117 changes: 83 additions & 34 deletions api/routes/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,12 @@ import (
"net/http/httptest"
"strings"
"testing"
"time"

"github.com/golang-jwt/jwt/v4"
svc "github.com/shellhub-io/shellhub/api/services"
"github.com/shellhub-io/shellhub/api/services/mocks"
"github.com/shellhub-io/shellhub/pkg/api/authorizer"
"github.com/shellhub-io/shellhub/pkg/api/jwttoken"
"github.com/shellhub-io/shellhub/pkg/api/requests"
"github.com/shellhub-io/shellhub/pkg/clock"
"github.com/shellhub-io/shellhub/pkg/models"
"github.com/stretchr/testify/assert"
gomock "github.com/stretchr/testify/mock"
Expand Down Expand Up @@ -538,54 +536,105 @@ func TestAuthPublicKey(t *testing.T) {
}
}

// TODO: refactor this
func TestAuthRequest(t *testing.T) {
mock := new(mocks.Service)
func TestHandler_AuthRequest_with_authorization_header(t *testing.T) {
type Expected struct {
status int
headers map[string]string
}

svcMock := new(mocks.Service)
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)

token := jwt.NewWithClaims(jwt.SigningMethodRS256, models.UserAuthClaims{
Username: "username",
Tenant: "tenant",
Role: authorizer.RoleInvalid,
ID: "id",
AuthClaims: models.AuthClaims{
Claims: "user",
},
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(clock.Now().Add(time.Hour * 72)),
},
})
require.NoError(t, err)

type Expected struct {
expectedStatus int
}
cases := []struct {
title string
description string
token func() (string, error)
requiredMocks func()
expected Expected
}{}
}{
{
description: "failed when token is invalid",
token: func() (string, error) {
return "", nil
},
requiredMocks: func() {
svcMock.On("PublicKey").Return(&privateKey.PublicKey).Once()
},
expected: Expected{
status: 401,
headers: map[string]string{},
},
},
{
description: "succeeds to authenticate a user",
token: func() (string, error) {
claims := authorizer.UserClaims{
ID: "000000000000000000000000",
TenantID: "00000000-0000-4000-0000-000000000000",
Role: authorizer.RoleOwner,
Username: "john_doe",
}

return jwttoken.EncodeUserClaims(claims, privateKey)
},
requiredMocks: func() {
svcMock.On("PublicKey").Return(&privateKey.PublicKey).Once()
svcMock.On("GetUserRole", gomock.Anything, "00000000-0000-4000-0000-000000000000", "000000000000000000000000").Return("owner", nil).Once()
},
expected: Expected{
status: 200,
headers: map[string]string{
"X-ID": "000000000000000000000000",
"X-Tenant-ID": "00000000-0000-4000-0000-000000000000",
"X-Role": authorizer.RoleOwner.String(),
"X-Username": "john_doe",
},
},
},
{
description: "succeeds to authenticate a device",
token: func() (string, error) {
claims := authorizer.DeviceClaims{
UID: "0000000000000000000000000000000000000000000000000000000000000000",
TenantID: "00000000-0000-4000-0000-000000000000",
}

return jwttoken.EncodeDeviceClaims(claims, privateKey)
},
requiredMocks: func() {
svcMock.On("PublicKey").Return(&privateKey.PublicKey).Once()
},
expected: Expected{
status: 200,
headers: map[string]string{
"X-Device-UID": "0000000000000000000000000000000000000000000000000000000000000000",
"X-Tenant-ID": "00000000-0000-4000-0000-000000000000",
},
},
},
}

for _, tc := range cases {
t.Run(tc.title, func(t *testing.T) {
t.Run(tc.description, func(t *testing.T) {
tc.requiredMocks()

req := httptest.NewRequest(http.MethodGet, "/internal/auth", nil)
req.Header.Set("Content-Type", "application/json")

tokenStr, err := token.SignedString(privateKey)
assert.NoError(t, err)

req.Header.Add("Authorization", "Bearer "+tokenStr)
token, err := tc.token()
require.NoError(t, err)

req.Header.Set("X-Role", authorizer.RoleOwner.String())
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", token)

rec := httptest.NewRecorder()

e := NewRouter(mock)
e := NewRouter(svcMock)
e.ServeHTTP(rec, req)

assert.Equal(t, tc.expected.expectedStatus, rec.Result().StatusCode)
require.Equal(t, tc.expected.status, rec.Result().StatusCode)
for k, v := range tc.expected.headers {
require.Equal(t, rec.Result().Header.Get(k), v)
}
})
}
}
Loading

0 comments on commit 839be80

Please sign in to comment.