diff --git a/oidc/jwks.go b/oidc/jwks.go index 524f4e67..0c268595 100644 --- a/oidc/jwks.go +++ b/oidc/jwks.go @@ -7,7 +7,7 @@ import ( "crypto/rsa" "errors" "fmt" - "io/ioutil" + "io" "net/http" "sync" "time" @@ -230,13 +230,13 @@ func (r *RemoteKeySet) updateKeys() ([]jose.JSONWebKey, error) { } defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(io.LimitReader(resp.Body, maxRespBodySize)) if err != nil { return nil, fmt.Errorf("unable to read response body: %v", err) } if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("oidc: get keys failed: %s %s", resp.Status, body) + return nil, fmt.Errorf("oidc: get keys failed: %s %s", resp.Status, body[:getMaxLogSizeForBody(body)]) } var keySet jose.JSONWebKeySet diff --git a/oidc/oidc.go b/oidc/oidc.go index d86be693..38183740 100644 --- a/oidc/oidc.go +++ b/oidc/oidc.go @@ -10,7 +10,7 @@ import ( "errors" "fmt" "hash" - "io/ioutil" + "io" "mime" "net/http" "strings" @@ -35,6 +35,10 @@ const ( ScopeOfflineAccess = "offline_access" ) +const ( + maxRespBodySize = 262144 +) + var ( errNoAtHash = errors.New("id token did not have an access token hash") errInvalidAtHash = errors.New("access token hash does not match value in ID token") @@ -210,17 +214,13 @@ func NewProvider(ctx context.Context, issuer string) (*Provider, error) { } defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(io.LimitReader(resp.Body, maxRespBodySize)) if err != nil { return nil, fmt.Errorf("unable to read response body: %v", err) } if resp.StatusCode != http.StatusOK { - maxBodySize := len(body) - if maxBodySize > 2048 { - maxBodySize = 2048 - } - return nil, fmt.Errorf("%s: %s", resp.Status, body[:maxBodySize]) + return nil, fmt.Errorf("%s: %s", resp.Status, body[:getMaxLogSizeForBody(body)]) } var p providerJSON @@ -335,12 +335,12 @@ func (p *Provider) UserInfo(ctx context.Context, tokenSource oauth2.TokenSource) return nil, err } defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(io.LimitReader(resp.Body, maxRespBodySize)) if err != nil { return nil, err } if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("%s: %s", resp.Status, body) + return nil, fmt.Errorf("%s: %s", resp.Status, body[:getMaxLogSizeForBody(body)]) } ct := resp.Header.Get("Content-Type") @@ -548,3 +548,10 @@ func unmarshalResp(r *http.Response, body []byte, v interface{}) error { } return fmt.Errorf("expected Content-Type = application/json, got %q: %v", ct, err) } + +func getMaxLogSizeForBody(body []byte) int { + if len(body) > 2048 { + return 2048 + } + return len(body) +} diff --git a/oidc/verify.go b/oidc/verify.go index 3e5ffbc7..94d5cea7 100644 --- a/oidc/verify.go +++ b/oidc/verify.go @@ -7,7 +7,7 @@ import ( "encoding/json" "errors" "fmt" - "io/ioutil" + "io" "net/http" "strings" "time" @@ -182,7 +182,7 @@ func resolveDistributedClaim(ctx context.Context, verifier *IDTokenVerifier, src } defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(io.LimitReader(resp.Body, maxRespBodySize)) if err != nil { return nil, fmt.Errorf("unable to read response body: %v", err) }