Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: prepend base server URL to the relative OAuth token path #52

Merged
merged 3 commits into from
Dec 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ jobs:
sleep 5
go test -v -coverpkg=./... -race -timeout 3m -coverprofile=coverage.out.tmp ./...
docker compose down -v
cat coverage.out.tmp | grep -v "main.go" > coverage.out
cat coverage.out.tmp | grep -v "main.go" > coverage.out.tmp2
cat coverage.out.tmp2 | grep -v "version.go" > coverage.out
- name: Run integration tests
run: |
./scripts/test.sh
Expand Down
41 changes: 41 additions & 0 deletions connector/connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -981,6 +981,16 @@ func createMockServer(t *testing.T, apiKey string, bearerToken string) *httptest
}
})

mux.HandleFunc("/oauth2/token", func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet, http.MethodPost:
w.WriteHeader(http.StatusBadRequest)
default:
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
})

return httptest.NewServer(mux)
}

Expand Down Expand Up @@ -1150,6 +1160,37 @@ func TestConnectorOAuth(t *testing.T) {
},
},
})

failureBody := []byte(`{
"collection": "findPetsOAuth",
"query": {
"fields": {
"__value": {
"type": "column",
"column": "__value"
}
}
},
"arguments": {
"httpOptions": {
"type": "literal",
"value": {
"servers": ["1"]
}
}
},
"collection_relationships": {}
}`)

res, err = http.Post(fmt.Sprintf("%s/query", testServer.URL), "application/json", bytes.NewBuffer(failureBody))
assert.NilError(t, err)
defer res.Body.Close()

assert.Equal(t, http.StatusInternalServerError, res.StatusCode)

respBody, err := io.ReadAll(res.Body)
assert.NilError(t, err)
assert.Assert(t, strings.Contains(string(respBody), "oauth2: cannot fetch token: 400 Bad Request"))
}

type mockTLSServer struct {
Expand Down
2 changes: 1 addition & 1 deletion connector/internal/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ func (client *HTTPClient) doRequest(ctx context.Context, request *RetryableReque
ctx, span := tracer.Start(ctx, fmt.Sprintf("%s %s", method, request.RawRequest.URL), trace.WithSpanKind(trace.SpanKindClient))
defer span.End()

urlAttr := cloneURL(&request.URL)
urlAttr := restUtils.CloneURL(&request.URL)
password, hasPassword := urlAttr.User.Password()
if urlAttr.User.String() != "" || hasPassword {
maskedUser := restUtils.MaskString(urlAttr.User.Username())
Expand Down
5 changes: 3 additions & 2 deletions connector/internal/security/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"net/http"
"net/url"

"github.com/hasura/ndc-http/ndc-http-schema/schema"
)
Expand All @@ -19,7 +20,7 @@ type Credential interface {
}

// NewCredential creates a generic credential from the security scheme.
func NewCredential(ctx context.Context, httpClient *http.Client, security schema.SecurityScheme) (Credential, bool, error) {
func NewCredential(ctx context.Context, httpClient *http.Client, baseServerURL *url.URL, security schema.SecurityScheme) (Credential, bool, error) {
if security.SecuritySchemer == nil {
return nil, false, errors.New("empty security scheme")
}
Expand All @@ -44,7 +45,7 @@ func NewCredential(ctx context.Context, httpClient *http.Client, security schema
headerForwardingRequired = true
}

cred, err := NewOAuth2Client(ctx, httpClient, flowType, &flow)
cred, err := NewOAuth2Client(ctx, httpClient, baseServerURL, flowType, &flow)

return cred, headerForwardingRequired || err != nil, err
}
Expand Down
25 changes: 21 additions & 4 deletions connector/internal/security/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ import (
"fmt"
"net/http"
"net/url"
"path"

"github.com/hasura/ndc-http/ndc-http-schema/schema"
"github.com/hasura/ndc-http/ndc-http-schema/utils"
"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"
)
Expand All @@ -20,23 +22,38 @@ type OAuth2Client struct {
var _ Credential = &OAuth2Client{}

// NewOAuth2Client creates an OAuth2 client from the security scheme
func NewOAuth2Client(ctx context.Context, httpClient *http.Client, flowType schema.OAuthFlowType, config *schema.OAuthFlow) (*OAuth2Client, error) {
func NewOAuth2Client(ctx context.Context, httpClient *http.Client, baseServerURL *url.URL, flowType schema.OAuthFlowType, config *schema.OAuthFlow) (*OAuth2Client, error) {
if flowType != schema.ClientCredentialsFlow || config.TokenURL == nil || config.ClientID == nil || config.ClientSecret == nil {
return &OAuth2Client{
client: httpClient,
isEmpty: true,
}, nil
}

tokenURL, err := config.TokenURL.Get()
rawTokenURL, err := config.TokenURL.Get()
if err != nil {
return nil, fmt.Errorf("tokenUrl: %w", err)
}

if _, err := schema.ParseRelativeOrHttpURL(tokenURL); err != nil {
tokenURL, err := schema.ParseRelativeOrHttpURL(rawTokenURL)
if err != nil {
return nil, fmt.Errorf("tokenUrl: %w", err)
}

// if the token URL is a relative path it will be joined with the base server URL
if tokenURL.Host == "" {
tu := utils.CloneURL(baseServerURL)
tu.Path = path.Join(tu.Path, tokenURL.Path)
q := tu.Query()
for k, v := range tokenURL.Query() {
q[k] = v
}
tu.RawQuery = q.Encode()
tu.RawFragment = tokenURL.RawFragment

tokenURL = tu
}

scopes := make([]string, 0, len(config.Scopes))
for scope := range config.Scopes {
scopes = append(scopes, scope)
Expand Down Expand Up @@ -68,7 +85,7 @@ func NewOAuth2Client(ctx context.Context, httpClient *http.Client, flowType sche
ClientID: clientID,
ClientSecret: clientSecret,
Scopes: scopes,
TokenURL: tokenURL,
TokenURL: tokenURL.String(),
EndpointParams: endpointParams,
}

Expand Down
23 changes: 15 additions & 8 deletions connector/internal/upstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"log/slog"
"net/http"
"net/url"
"strconv"

"github.com/hasura/ndc-http/connector/internal/argument"
Expand Down Expand Up @@ -58,12 +59,12 @@ func (um *UpstreamManager) Register(ctx context.Context, runtimeSchema *configur
}
}

var defaultServerURL *url.URL
settings := UpstreamSetting{
servers: make(map[string]Server),
security: runtimeSchema.Settings.Security,
headers: um.getHeadersFromEnv(logger, namespace, runtimeSchema.Settings.Headers),
credentials: um.registerSecurityCredentials(ctx, httpClient, runtimeSchema.Settings.SecuritySchemes, logger.With(slog.String("namespace", namespace))),
httpClient: httpClient,
servers: make(map[string]Server),
security: runtimeSchema.Settings.Security,
headers: um.getHeadersFromEnv(logger, namespace, runtimeSchema.Settings.Headers),
httpClient: httpClient,
}

if len(runtimeSchema.Settings.ArgumentPresets) > 0 {
Expand All @@ -89,6 +90,10 @@ func (um *UpstreamManager) Register(ctx context.Context, runtimeSchema *configur
continue
}

if defaultServerURL == nil {
defaultServerURL = serverURL
}

serverClient := httpClient
if server.TLS != nil {
tlsClient, err := security.NewHTTPClientTLS(um.defaultClient, server.TLS, logger)
Expand All @@ -105,7 +110,7 @@ func (um *UpstreamManager) Register(ctx context.Context, runtimeSchema *configur
URL: serverURL,
Headers: um.getHeadersFromEnv(logger, namespace, server.Headers),
Security: server.Security,
Credentials: um.registerSecurityCredentials(ctx, serverClient, server.SecuritySchemes, logger.With(slog.String("namespace", namespace), slog.String("server_id", serverID))),
Credentials: um.registerSecurityCredentials(ctx, serverClient, serverURL, server.SecuritySchemes, logger.With(slog.String("namespace", namespace), slog.String("server_id", serverID))),
HTTPClient: serverClient,
}

Expand All @@ -120,6 +125,8 @@ func (um *UpstreamManager) Register(ctx context.Context, runtimeSchema *configur
settings.servers[serverID] = newServer
}

settings.credentials = um.registerSecurityCredentials(ctx, httpClient, defaultServerURL, runtimeSchema.Settings.SecuritySchemes, logger.With(slog.String("namespace", namespace)))

um.upstreams[namespace] = settings

return nil
Expand Down Expand Up @@ -292,11 +299,11 @@ func (um *UpstreamManager) getHeadersFromEnv(logger *slog.Logger, namespace stri
return results
}

func (um *UpstreamManager) registerSecurityCredentials(ctx context.Context, httpClient *http.Client, securitySchemes map[string]rest.SecurityScheme, logger *slog.Logger) map[string]security.Credential {
func (um *UpstreamManager) registerSecurityCredentials(ctx context.Context, httpClient *http.Client, baseServerURL *url.URL, securitySchemes map[string]rest.SecurityScheme, logger *slog.Logger) map[string]security.Credential {
credentials := make(map[string]security.Credential)

for key, ss := range securitySchemes {
cred, headerForwardRequired, err := security.NewCredential(ctx, httpClient, ss)
cred, headerForwardRequired, err := security.NewCredential(ctx, httpClient, baseServerURL, ss)
if err != nil {
// Relax the error to allow schema introspection without environment variables setting.
// Moreover, because there are many security schemes the user may use one of them.
Expand Down
17 changes: 0 additions & 17 deletions connector/internal/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package internal

import (
"net/http"
"net/url"
"strings"

"github.com/hasura/ndc-http/ndc-http-schema/utils"
Expand Down Expand Up @@ -52,19 +51,3 @@ func evalForwardedHeaders(req *RetryableRequest, headers map[string]string) erro

return nil
}

func cloneURL(input *url.URL) *url.URL {
return &url.URL{
Scheme: input.Scheme,
Opaque: input.Opaque,
User: input.User,
Host: input.Host,
Path: input.Path,
RawPath: input.RawPath,
OmitHost: input.OmitHost,
ForceQuery: input.ForceQuery,
RawQuery: input.RawQuery,
Fragment: input.Fragment,
RawFragment: input.RawFragment,
}
}
16 changes: 16 additions & 0 deletions connector/testdata/auth/schema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,22 @@ settings:
servers:
- url:
env: PET_STORE_URL
- url:
env: PET_STORE_URL
securitySchemes:
petstore_auth:
type: oauth2
flows:
clientCredentials:
tokenUrl:
value: /oauth2/token
clientId:
env: OAUTH2_CLIENT_ID
clientSecret:
env: OAUTH2_CLIENT_SECRET
scopes:
read:pets: read your pets
write:pets: modify pets in your account
securitySchemes:
api_key:
type: apiKey
Expand Down
18 changes: 18 additions & 0 deletions ndc-http-schema/utils/http.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package utils

import (
"net/url"
"strings"

"github.com/hasura/ndc-http/ndc-http-schema/schema"
Expand Down Expand Up @@ -30,3 +31,20 @@ func IsContentTypeBinary(contentType string) bool {
func IsContentTypeMultipartForm(contentType string) bool {
return strings.HasPrefix(contentType, "multipart/")
}

// CloneURL clones the input URL to a new instance.
func CloneURL(input *url.URL) *url.URL {
return &url.URL{
Scheme: input.Scheme,
Opaque: input.Opaque,
User: input.User,
Host: input.Host,
Path: input.Path,
RawPath: input.RawPath,
OmitHost: input.OmitHost,
ForceQuery: input.ForceQuery,
RawQuery: input.RawQuery,
Fragment: input.Fragment,
RawFragment: input.RawFragment,
}
}
Loading