diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index f224298..b36f5ba 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -22,7 +22,8 @@ jobs: sleep 5 go test -v -coverpkg=./... -race -timeout 3m -coverprofile=coverage.out.tmp ./... docker compose down -v - cat coverage.out.tmp | grep -v "main.go" > coverage.out + cat coverage.out.tmp | grep -v "main.go" > coverage.out.tmp2 + cat coverage.out.tmp2 | grep -v "version.go" > coverage.out - name: Run integration tests run: | ./scripts/test.sh diff --git a/connector/connector_test.go b/connector/connector_test.go index e423617..54f1d07 100644 --- a/connector/connector_test.go +++ b/connector/connector_test.go @@ -981,6 +981,16 @@ func createMockServer(t *testing.T, apiKey string, bearerToken string) *httptest } }) + mux.HandleFunc("/oauth2/token", func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet, http.MethodPost: + w.WriteHeader(http.StatusBadRequest) + default: + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + }) + return httptest.NewServer(mux) } @@ -1150,6 +1160,37 @@ func TestConnectorOAuth(t *testing.T) { }, }, }) + + failureBody := []byte(`{ + "collection": "findPetsOAuth", + "query": { + "fields": { + "__value": { + "type": "column", + "column": "__value" + } + } + }, + "arguments": { + "httpOptions": { + "type": "literal", + "value": { + "servers": ["1"] + } + } + }, + "collection_relationships": {} + }`) + + res, err = http.Post(fmt.Sprintf("%s/query", testServer.URL), "application/json", bytes.NewBuffer(failureBody)) + assert.NilError(t, err) + defer res.Body.Close() + + assert.Equal(t, http.StatusInternalServerError, res.StatusCode) + + respBody, err := io.ReadAll(res.Body) + assert.NilError(t, err) + assert.Assert(t, strings.Contains(string(respBody), "oauth2: cannot fetch token: 400 Bad Request")) } type mockTLSServer struct { diff --git a/connector/internal/client.go b/connector/internal/client.go index a456cc7..9d7703e 100644 --- a/connector/internal/client.go +++ b/connector/internal/client.go @@ -282,7 +282,7 @@ func (client *HTTPClient) doRequest(ctx context.Context, request *RetryableReque ctx, span := tracer.Start(ctx, fmt.Sprintf("%s %s", method, request.RawRequest.URL), trace.WithSpanKind(trace.SpanKindClient)) defer span.End() - urlAttr := cloneURL(&request.URL) + urlAttr := restUtils.CloneURL(&request.URL) password, hasPassword := urlAttr.User.Password() if urlAttr.User.String() != "" || hasPassword { maskedUser := restUtils.MaskString(urlAttr.User.Username()) diff --git a/connector/internal/security/auth.go b/connector/internal/security/auth.go index 55c185a..5bd7a13 100644 --- a/connector/internal/security/auth.go +++ b/connector/internal/security/auth.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net/http" + "net/url" "github.com/hasura/ndc-http/ndc-http-schema/schema" ) @@ -19,7 +20,7 @@ type Credential interface { } // NewCredential creates a generic credential from the security scheme. -func NewCredential(ctx context.Context, httpClient *http.Client, security schema.SecurityScheme) (Credential, bool, error) { +func NewCredential(ctx context.Context, httpClient *http.Client, baseServerURL *url.URL, security schema.SecurityScheme) (Credential, bool, error) { if security.SecuritySchemer == nil { return nil, false, errors.New("empty security scheme") } @@ -44,7 +45,7 @@ func NewCredential(ctx context.Context, httpClient *http.Client, security schema headerForwardingRequired = true } - cred, err := NewOAuth2Client(ctx, httpClient, flowType, &flow) + cred, err := NewOAuth2Client(ctx, httpClient, baseServerURL, flowType, &flow) return cred, headerForwardingRequired || err != nil, err } diff --git a/connector/internal/security/oauth2.go b/connector/internal/security/oauth2.go index 6614fca..95d5023 100644 --- a/connector/internal/security/oauth2.go +++ b/connector/internal/security/oauth2.go @@ -5,8 +5,10 @@ import ( "fmt" "net/http" "net/url" + "path" "github.com/hasura/ndc-http/ndc-http-schema/schema" + "github.com/hasura/ndc-http/ndc-http-schema/utils" "golang.org/x/oauth2" "golang.org/x/oauth2/clientcredentials" ) @@ -20,7 +22,7 @@ type OAuth2Client struct { var _ Credential = &OAuth2Client{} // NewOAuth2Client creates an OAuth2 client from the security scheme -func NewOAuth2Client(ctx context.Context, httpClient *http.Client, flowType schema.OAuthFlowType, config *schema.OAuthFlow) (*OAuth2Client, error) { +func NewOAuth2Client(ctx context.Context, httpClient *http.Client, baseServerURL *url.URL, flowType schema.OAuthFlowType, config *schema.OAuthFlow) (*OAuth2Client, error) { if flowType != schema.ClientCredentialsFlow || config.TokenURL == nil || config.ClientID == nil || config.ClientSecret == nil { return &OAuth2Client{ client: httpClient, @@ -28,15 +30,30 @@ func NewOAuth2Client(ctx context.Context, httpClient *http.Client, flowType sche }, nil } - tokenURL, err := config.TokenURL.Get() + rawTokenURL, err := config.TokenURL.Get() if err != nil { return nil, fmt.Errorf("tokenUrl: %w", err) } - if _, err := schema.ParseRelativeOrHttpURL(tokenURL); err != nil { + tokenURL, err := schema.ParseRelativeOrHttpURL(rawTokenURL) + if err != nil { return nil, fmt.Errorf("tokenUrl: %w", err) } + // if the token URL is a relative path it will be joined with the base server URL + if tokenURL.Host == "" { + tu := utils.CloneURL(baseServerURL) + tu.Path = path.Join(tu.Path, tokenURL.Path) + q := tu.Query() + for k, v := range tokenURL.Query() { + q[k] = v + } + tu.RawQuery = q.Encode() + tu.RawFragment = tokenURL.RawFragment + + tokenURL = tu + } + scopes := make([]string, 0, len(config.Scopes)) for scope := range config.Scopes { scopes = append(scopes, scope) @@ -68,7 +85,7 @@ func NewOAuth2Client(ctx context.Context, httpClient *http.Client, flowType sche ClientID: clientID, ClientSecret: clientSecret, Scopes: scopes, - TokenURL: tokenURL, + TokenURL: tokenURL.String(), EndpointParams: endpointParams, } diff --git a/connector/internal/upstream.go b/connector/internal/upstream.go index 792a465..68d6e6a 100644 --- a/connector/internal/upstream.go +++ b/connector/internal/upstream.go @@ -5,6 +5,7 @@ import ( "fmt" "log/slog" "net/http" + "net/url" "strconv" "github.com/hasura/ndc-http/connector/internal/argument" @@ -58,12 +59,12 @@ func (um *UpstreamManager) Register(ctx context.Context, runtimeSchema *configur } } + var defaultServerURL *url.URL settings := UpstreamSetting{ - servers: make(map[string]Server), - security: runtimeSchema.Settings.Security, - headers: um.getHeadersFromEnv(logger, namespace, runtimeSchema.Settings.Headers), - credentials: um.registerSecurityCredentials(ctx, httpClient, runtimeSchema.Settings.SecuritySchemes, logger.With(slog.String("namespace", namespace))), - httpClient: httpClient, + servers: make(map[string]Server), + security: runtimeSchema.Settings.Security, + headers: um.getHeadersFromEnv(logger, namespace, runtimeSchema.Settings.Headers), + httpClient: httpClient, } if len(runtimeSchema.Settings.ArgumentPresets) > 0 { @@ -89,6 +90,10 @@ func (um *UpstreamManager) Register(ctx context.Context, runtimeSchema *configur continue } + if defaultServerURL == nil { + defaultServerURL = serverURL + } + serverClient := httpClient if server.TLS != nil { tlsClient, err := security.NewHTTPClientTLS(um.defaultClient, server.TLS, logger) @@ -105,7 +110,7 @@ func (um *UpstreamManager) Register(ctx context.Context, runtimeSchema *configur URL: serverURL, Headers: um.getHeadersFromEnv(logger, namespace, server.Headers), Security: server.Security, - Credentials: um.registerSecurityCredentials(ctx, serverClient, server.SecuritySchemes, logger.With(slog.String("namespace", namespace), slog.String("server_id", serverID))), + Credentials: um.registerSecurityCredentials(ctx, serverClient, serverURL, server.SecuritySchemes, logger.With(slog.String("namespace", namespace), slog.String("server_id", serverID))), HTTPClient: serverClient, } @@ -120,6 +125,8 @@ func (um *UpstreamManager) Register(ctx context.Context, runtimeSchema *configur settings.servers[serverID] = newServer } + settings.credentials = um.registerSecurityCredentials(ctx, httpClient, defaultServerURL, runtimeSchema.Settings.SecuritySchemes, logger.With(slog.String("namespace", namespace))) + um.upstreams[namespace] = settings return nil @@ -292,11 +299,11 @@ func (um *UpstreamManager) getHeadersFromEnv(logger *slog.Logger, namespace stri return results } -func (um *UpstreamManager) registerSecurityCredentials(ctx context.Context, httpClient *http.Client, securitySchemes map[string]rest.SecurityScheme, logger *slog.Logger) map[string]security.Credential { +func (um *UpstreamManager) registerSecurityCredentials(ctx context.Context, httpClient *http.Client, baseServerURL *url.URL, securitySchemes map[string]rest.SecurityScheme, logger *slog.Logger) map[string]security.Credential { credentials := make(map[string]security.Credential) for key, ss := range securitySchemes { - cred, headerForwardRequired, err := security.NewCredential(ctx, httpClient, ss) + cred, headerForwardRequired, err := security.NewCredential(ctx, httpClient, baseServerURL, ss) if err != nil { // Relax the error to allow schema introspection without environment variables setting. // Moreover, because there are many security schemes the user may use one of them. diff --git a/connector/internal/utils.go b/connector/internal/utils.go index 9b25818..7bda0ea 100644 --- a/connector/internal/utils.go +++ b/connector/internal/utils.go @@ -2,7 +2,6 @@ package internal import ( "net/http" - "net/url" "strings" "github.com/hasura/ndc-http/ndc-http-schema/utils" @@ -52,19 +51,3 @@ func evalForwardedHeaders(req *RetryableRequest, headers map[string]string) erro return nil } - -func cloneURL(input *url.URL) *url.URL { - return &url.URL{ - Scheme: input.Scheme, - Opaque: input.Opaque, - User: input.User, - Host: input.Host, - Path: input.Path, - RawPath: input.RawPath, - OmitHost: input.OmitHost, - ForceQuery: input.ForceQuery, - RawQuery: input.RawQuery, - Fragment: input.Fragment, - RawFragment: input.RawFragment, - } -} diff --git a/connector/testdata/auth/schema.yaml b/connector/testdata/auth/schema.yaml index 2572590..91ee4a5 100644 --- a/connector/testdata/auth/schema.yaml +++ b/connector/testdata/auth/schema.yaml @@ -3,6 +3,22 @@ settings: servers: - url: env: PET_STORE_URL + - url: + env: PET_STORE_URL + securitySchemes: + petstore_auth: + type: oauth2 + flows: + clientCredentials: + tokenUrl: + value: /oauth2/token + clientId: + env: OAUTH2_CLIENT_ID + clientSecret: + env: OAUTH2_CLIENT_SECRET + scopes: + read:pets: read your pets + write:pets: modify pets in your account securitySchemes: api_key: type: apiKey diff --git a/ndc-http-schema/utils/http.go b/ndc-http-schema/utils/http.go index fef0b33..6b865ce 100644 --- a/ndc-http-schema/utils/http.go +++ b/ndc-http-schema/utils/http.go @@ -1,6 +1,7 @@ package utils import ( + "net/url" "strings" "github.com/hasura/ndc-http/ndc-http-schema/schema" @@ -30,3 +31,20 @@ func IsContentTypeBinary(contentType string) bool { func IsContentTypeMultipartForm(contentType string) bool { return strings.HasPrefix(contentType, "multipart/") } + +// CloneURL clones the input URL to a new instance. +func CloneURL(input *url.URL) *url.URL { + return &url.URL{ + Scheme: input.Scheme, + Opaque: input.Opaque, + User: input.User, + Host: input.Host, + Path: input.Path, + RawPath: input.RawPath, + OmitHost: input.OmitHost, + ForceQuery: input.ForceQuery, + RawQuery: input.RawQuery, + Fragment: input.Fragment, + RawFragment: input.RawFragment, + } +}