Skip to content

Commit

Permalink
feat: prepend base server URL to the relative OAuth token path (#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
hgiasac authored Dec 28, 2024
1 parent 4e7ff56 commit 99e7c51
Show file tree
Hide file tree
Showing 9 changed files with 117 additions and 33 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ jobs:
sleep 5
go test -v -coverpkg=./... -race -timeout 3m -coverprofile=coverage.out.tmp ./...
docker compose down -v
cat coverage.out.tmp | grep -v "main.go" > coverage.out
cat coverage.out.tmp | grep -v "main.go" > coverage.out.tmp2
cat coverage.out.tmp2 | grep -v "version.go" > coverage.out
- name: Run integration tests
run: |
./scripts/test.sh
Expand Down
41 changes: 41 additions & 0 deletions connector/connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -981,6 +981,16 @@ func createMockServer(t *testing.T, apiKey string, bearerToken string) *httptest
}
})

mux.HandleFunc("/oauth2/token", func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet, http.MethodPost:
w.WriteHeader(http.StatusBadRequest)
default:
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
})

return httptest.NewServer(mux)
}

Expand Down Expand Up @@ -1150,6 +1160,37 @@ func TestConnectorOAuth(t *testing.T) {
},
},
})

failureBody := []byte(`{
"collection": "findPetsOAuth",
"query": {
"fields": {
"__value": {
"type": "column",
"column": "__value"
}
}
},
"arguments": {
"httpOptions": {
"type": "literal",
"value": {
"servers": ["1"]
}
}
},
"collection_relationships": {}
}`)

res, err = http.Post(fmt.Sprintf("%s/query", testServer.URL), "application/json", bytes.NewBuffer(failureBody))
assert.NilError(t, err)
defer res.Body.Close()

assert.Equal(t, http.StatusInternalServerError, res.StatusCode)

respBody, err := io.ReadAll(res.Body)
assert.NilError(t, err)
assert.Assert(t, strings.Contains(string(respBody), "oauth2: cannot fetch token: 400 Bad Request"))
}

type mockTLSServer struct {
Expand Down
2 changes: 1 addition & 1 deletion connector/internal/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ func (client *HTTPClient) doRequest(ctx context.Context, request *RetryableReque
ctx, span := tracer.Start(ctx, fmt.Sprintf("%s %s", method, request.RawRequest.URL), trace.WithSpanKind(trace.SpanKindClient))
defer span.End()

urlAttr := cloneURL(&request.URL)
urlAttr := restUtils.CloneURL(&request.URL)
password, hasPassword := urlAttr.User.Password()
if urlAttr.User.String() != "" || hasPassword {
maskedUser := restUtils.MaskString(urlAttr.User.Username())
Expand Down
5 changes: 3 additions & 2 deletions connector/internal/security/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"net/http"
"net/url"

"github.com/hasura/ndc-http/ndc-http-schema/schema"
)
Expand All @@ -19,7 +20,7 @@ type Credential interface {
}

// NewCredential creates a generic credential from the security scheme.
func NewCredential(ctx context.Context, httpClient *http.Client, security schema.SecurityScheme) (Credential, bool, error) {
func NewCredential(ctx context.Context, httpClient *http.Client, baseServerURL *url.URL, security schema.SecurityScheme) (Credential, bool, error) {
if security.SecuritySchemer == nil {
return nil, false, errors.New("empty security scheme")
}
Expand All @@ -44,7 +45,7 @@ func NewCredential(ctx context.Context, httpClient *http.Client, security schema
headerForwardingRequired = true
}

cred, err := NewOAuth2Client(ctx, httpClient, flowType, &flow)
cred, err := NewOAuth2Client(ctx, httpClient, baseServerURL, flowType, &flow)

return cred, headerForwardingRequired || err != nil, err
}
Expand Down
25 changes: 21 additions & 4 deletions connector/internal/security/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ import (
"fmt"
"net/http"
"net/url"
"path"

"github.com/hasura/ndc-http/ndc-http-schema/schema"
"github.com/hasura/ndc-http/ndc-http-schema/utils"
"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"
)
Expand All @@ -20,23 +22,38 @@ type OAuth2Client struct {
var _ Credential = &OAuth2Client{}

// NewOAuth2Client creates an OAuth2 client from the security scheme
func NewOAuth2Client(ctx context.Context, httpClient *http.Client, flowType schema.OAuthFlowType, config *schema.OAuthFlow) (*OAuth2Client, error) {
func NewOAuth2Client(ctx context.Context, httpClient *http.Client, baseServerURL *url.URL, flowType schema.OAuthFlowType, config *schema.OAuthFlow) (*OAuth2Client, error) {
if flowType != schema.ClientCredentialsFlow || config.TokenURL == nil || config.ClientID == nil || config.ClientSecret == nil {
return &OAuth2Client{
client: httpClient,
isEmpty: true,
}, nil
}

tokenURL, err := config.TokenURL.Get()
rawTokenURL, err := config.TokenURL.Get()
if err != nil {
return nil, fmt.Errorf("tokenUrl: %w", err)
}

if _, err := schema.ParseRelativeOrHttpURL(tokenURL); err != nil {
tokenURL, err := schema.ParseRelativeOrHttpURL(rawTokenURL)
if err != nil {
return nil, fmt.Errorf("tokenUrl: %w", err)
}

// if the token URL is a relative path it will be joined with the base server URL
if tokenURL.Host == "" {
tu := utils.CloneURL(baseServerURL)
tu.Path = path.Join(tu.Path, tokenURL.Path)
q := tu.Query()
for k, v := range tokenURL.Query() {
q[k] = v
}
tu.RawQuery = q.Encode()
tu.RawFragment = tokenURL.RawFragment

tokenURL = tu
}

scopes := make([]string, 0, len(config.Scopes))
for scope := range config.Scopes {
scopes = append(scopes, scope)
Expand Down Expand Up @@ -68,7 +85,7 @@ func NewOAuth2Client(ctx context.Context, httpClient *http.Client, flowType sche
ClientID: clientID,
ClientSecret: clientSecret,
Scopes: scopes,
TokenURL: tokenURL,
TokenURL: tokenURL.String(),
EndpointParams: endpointParams,
}

Expand Down
23 changes: 15 additions & 8 deletions connector/internal/upstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"log/slog"
"net/http"
"net/url"
"strconv"

"github.com/hasura/ndc-http/connector/internal/argument"
Expand Down Expand Up @@ -58,12 +59,12 @@ func (um *UpstreamManager) Register(ctx context.Context, runtimeSchema *configur
}
}

var defaultServerURL *url.URL
settings := UpstreamSetting{
servers: make(map[string]Server),
security: runtimeSchema.Settings.Security,
headers: um.getHeadersFromEnv(logger, namespace, runtimeSchema.Settings.Headers),
credentials: um.registerSecurityCredentials(ctx, httpClient, runtimeSchema.Settings.SecuritySchemes, logger.With(slog.String("namespace", namespace))),
httpClient: httpClient,
servers: make(map[string]Server),
security: runtimeSchema.Settings.Security,
headers: um.getHeadersFromEnv(logger, namespace, runtimeSchema.Settings.Headers),
httpClient: httpClient,
}

if len(runtimeSchema.Settings.ArgumentPresets) > 0 {
Expand All @@ -89,6 +90,10 @@ func (um *UpstreamManager) Register(ctx context.Context, runtimeSchema *configur
continue
}

if defaultServerURL == nil {
defaultServerURL = serverURL
}

serverClient := httpClient
if server.TLS != nil {
tlsClient, err := security.NewHTTPClientTLS(um.defaultClient, server.TLS, logger)
Expand All @@ -105,7 +110,7 @@ func (um *UpstreamManager) Register(ctx context.Context, runtimeSchema *configur
URL: serverURL,
Headers: um.getHeadersFromEnv(logger, namespace, server.Headers),
Security: server.Security,
Credentials: um.registerSecurityCredentials(ctx, serverClient, server.SecuritySchemes, logger.With(slog.String("namespace", namespace), slog.String("server_id", serverID))),
Credentials: um.registerSecurityCredentials(ctx, serverClient, serverURL, server.SecuritySchemes, logger.With(slog.String("namespace", namespace), slog.String("server_id", serverID))),
HTTPClient: serverClient,
}

Expand All @@ -120,6 +125,8 @@ func (um *UpstreamManager) Register(ctx context.Context, runtimeSchema *configur
settings.servers[serverID] = newServer
}

settings.credentials = um.registerSecurityCredentials(ctx, httpClient, defaultServerURL, runtimeSchema.Settings.SecuritySchemes, logger.With(slog.String("namespace", namespace)))

um.upstreams[namespace] = settings

return nil
Expand Down Expand Up @@ -292,11 +299,11 @@ func (um *UpstreamManager) getHeadersFromEnv(logger *slog.Logger, namespace stri
return results
}

func (um *UpstreamManager) registerSecurityCredentials(ctx context.Context, httpClient *http.Client, securitySchemes map[string]rest.SecurityScheme, logger *slog.Logger) map[string]security.Credential {
func (um *UpstreamManager) registerSecurityCredentials(ctx context.Context, httpClient *http.Client, baseServerURL *url.URL, securitySchemes map[string]rest.SecurityScheme, logger *slog.Logger) map[string]security.Credential {
credentials := make(map[string]security.Credential)

for key, ss := range securitySchemes {
cred, headerForwardRequired, err := security.NewCredential(ctx, httpClient, ss)
cred, headerForwardRequired, err := security.NewCredential(ctx, httpClient, baseServerURL, ss)
if err != nil {
// Relax the error to allow schema introspection without environment variables setting.
// Moreover, because there are many security schemes the user may use one of them.
Expand Down
17 changes: 0 additions & 17 deletions connector/internal/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package internal

import (
"net/http"
"net/url"
"strings"

"github.com/hasura/ndc-http/ndc-http-schema/utils"
Expand Down Expand Up @@ -52,19 +51,3 @@ func evalForwardedHeaders(req *RetryableRequest, headers map[string]string) erro

return nil
}

func cloneURL(input *url.URL) *url.URL {
return &url.URL{
Scheme: input.Scheme,
Opaque: input.Opaque,
User: input.User,
Host: input.Host,
Path: input.Path,
RawPath: input.RawPath,
OmitHost: input.OmitHost,
ForceQuery: input.ForceQuery,
RawQuery: input.RawQuery,
Fragment: input.Fragment,
RawFragment: input.RawFragment,
}
}
16 changes: 16 additions & 0 deletions connector/testdata/auth/schema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,22 @@ settings:
servers:
- url:
env: PET_STORE_URL
- url:
env: PET_STORE_URL
securitySchemes:
petstore_auth:
type: oauth2
flows:
clientCredentials:
tokenUrl:
value: /oauth2/token
clientId:
env: OAUTH2_CLIENT_ID
clientSecret:
env: OAUTH2_CLIENT_SECRET
scopes:
read:pets: read your pets
write:pets: modify pets in your account
securitySchemes:
api_key:
type: apiKey
Expand Down
18 changes: 18 additions & 0 deletions ndc-http-schema/utils/http.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package utils

import (
"net/url"
"strings"

"github.com/hasura/ndc-http/ndc-http-schema/schema"
Expand Down Expand Up @@ -30,3 +31,20 @@ func IsContentTypeBinary(contentType string) bool {
func IsContentTypeMultipartForm(contentType string) bool {
return strings.HasPrefix(contentType, "multipart/")
}

// CloneURL clones the input URL to a new instance.
func CloneURL(input *url.URL) *url.URL {
return &url.URL{
Scheme: input.Scheme,
Opaque: input.Opaque,
User: input.User,
Host: input.Host,
Path: input.Path,
RawPath: input.RawPath,
OmitHost: input.OmitHost,
ForceQuery: input.ForceQuery,
RawQuery: input.RawQuery,
Fragment: input.Fragment,
RawFragment: input.RawFragment,
}
}

0 comments on commit 99e7c51

Please sign in to comment.