diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 3814f7ee..fc12031d 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -16,7 +16,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v4 with: - go-version: "1.20.14" + go-version: "1.21.8" - name: Build run: go build -v ./... diff --git a/README.md b/README.md index 3d18ebfa..2e790225 100644 --- a/README.md +++ b/README.md @@ -225,20 +225,21 @@ minor changes (e.g., updating docs) directly on the `main` branch. - [ ] update [UPSTREAM](UPSTREAM), commit the change, and then run the `./tools/merge.bash` script to merge from upstream; -- [ ] make sure you synch [./internal/safefilepath](./internal/safefilepath) with the -`./src/internal/safefilepath` of the Go release you're merging from; - - [ ] solve the very-likely merge conflicts and ensure [the original spirit of the patches](#patches) still hold; +- [ ] make sure you synch [./internal/safefilepath](./internal/safefilepath) with the +`./src/internal/safefilepath` of the Go release you're merging from; + - [ ] make sure the codebase does not assume `*tls.Conn` *anywhere* (`git grep -n '\*tls\.Conn'`) and otherwise replace `*tls.Conn` with `TLSConn`; - [ ] make sure the codebase does not call `tls.Client` *anywhere* except for `tlsconn.go` (`git grep -n 'tls\.Client'`) and otherwise replace `tls.Client` with `TLSClientFactory`; -- [ ] diff with upstream (`diff --color=never -ru .../golang/go/src/net/http .`) and -make sure what you see makes sense in terms of the original patches; +- [ ] diff with upstream (`./tools/compare.bash`) and make sure what you see +makes sense in terms of the original patches, save the diff, and include it into +the PR to document the actual changes between us and upstream. - [ ] ensure `go build -v ./...` still works; diff --git a/UPSTREAM b/UPSTREAM index e36050d0..3e9ef2da 100644 --- a/UPSTREAM +++ b/UPSTREAM @@ -1 +1 @@ -go1.20.14 +go1.21.8 diff --git a/cgi/child_test.go b/cgi/child_test.go index aa9b0010..b949bfec 100644 --- a/cgi/child_test.go +++ b/cgi/child_test.go @@ -12,8 +12,8 @@ import ( "strings" "testing" - "github.com/ooni/oohttp" - "github.com/ooni/oohttp/httptest" + http "github.com/ooni/oohttp" + httptest "github.com/ooni/oohttp/httptest" ) func TestRequest(t *testing.T) { diff --git a/cgi/host.go b/cgi/host.go index bf61f10f..6014e914 100644 --- a/cgi/host.go +++ b/cgi/host.go @@ -39,7 +39,7 @@ var osDefaultInheritEnv = func() []string { switch runtime.GOOS { case "darwin", "ios": return []string{"DYLD_LIBRARY_PATH"} - case "linux", "freebsd", "netbsd", "openbsd": + case "android", "linux", "freebsd", "netbsd", "openbsd": return []string{"LD_LIBRARY_PATH"} case "hpux": return []string{"LD_LIBRARY_PATH", "SHLIB_PATH"} diff --git a/client.go b/client.go index ed536969..e0ad7553 100644 --- a/client.go +++ b/client.go @@ -145,7 +145,8 @@ type RoundTripper interface { // refererForURL returns a referer without any authentication info or // an empty string if lastReq scheme is https and newReq scheme is http. -func refererForURL(lastReq, newReq *url.URL) string { +// If the referer was explicitly set, then it will continue to be used. 
+func refererForURL(lastReq, newReq *url.URL, explicitRef string) string { // https://tools.ietf.org/html/rfc7231#section-5.5.2 // "Clients SHOULD NOT include a Referer header field in a // (non-secure) HTTP request if the referring page was @@ -153,6 +154,10 @@ func refererForURL(lastReq, newReq *url.URL) string { if lastReq.Scheme == "https" && newReq.Scheme == "http" { return "" } + if explicitRef != "" { + return explicitRef + } + referer := lastReq.String() if lastReq.User != nil { // This is not very efficient, but is the best we can @@ -200,6 +205,9 @@ func (c *Client) transport() RoundTripper { return DefaultTransport } +// ErrSchemeMismatch is returned when a server returns an HTTP response to an HTTPS client. +var ErrSchemeMismatch = errors.New("http: server gave HTTP response to HTTPS client") + // send issues an HTTP request. // Caller should close resp.Body when done reading from it. func send(ireq *Request, rt RoundTripper, deadline time.Time) (resp *Response, didTimeout func() bool, err error) { @@ -261,7 +269,7 @@ func send(ireq *Request, rt RoundTripper, deadline time.Time) (resp *Response, d // response looks like HTTP and give a more helpful error. // See golang.org/issue/11111. if string(tlsErr.RecordHeader[:]) == "HTTP/" { - err = errors.New("http: server gave HTTP response to HTTPS client") + err = ErrSchemeMismatch } } return nil, didTimeout, err @@ -677,7 +685,7 @@ func (c *Client) do(req *Request) (retres *Response, reterr error) { // Add the Referer header from the most recent // request URL to the new one, if it's not https->http: - if ref := refererForURL(reqs[len(reqs)-1].URL, req.URL); ref != "" { + if ref := refererForURL(reqs[len(reqs)-1].URL, req.URL, req.Header.Get("Referer")); ref != "" { req.Header.Set("Referer", ref) } err = c.checkRedirect(req, reqs) @@ -991,8 +999,8 @@ func shouldCopyHeaderOnRedirect(headerKey string, initial, dest *url.URL) bool { // directly, we don't know their scope, so we assume // it's for *.domain.com. - ihost := canonicalAddr(initial) - dhost := canonicalAddr(dest) + ihost := idnaASCIIFromURL(initial) + dhost := idnaASCIIFromURL(dest) return isDomainOrSubdomain(dhost, ihost) } // All other headers are copied: @@ -1007,6 +1015,12 @@ func isDomainOrSubdomain(sub, parent string) bool { if sub == parent { return true } + // If sub contains a :, it's probably an IPv6 address (and is definitely not a hostname). + // Don't check the suffix in this case, to avoid matching the contents of a IPv6 zone. + // For example, "::1%.www.example.com" is not a subdomain of "www.example.com". + if strings.ContainsAny(sub, ":%") { + return false + } // If sub is "foo.example.com" and parent is "example.com", // that means sub must end in "."+parent. // Do it without allocating. diff --git a/client_test.go b/client_test.go index c95fa156..c5e5e5f0 100644 --- a/client_test.go +++ b/client_test.go @@ -1208,7 +1208,7 @@ func testClientTimeout(t *testing.T, mode testMode) { })) // Try to trigger a timeout after reading part of the response body. - // The initial timeout is emprically usually long enough on a decently fast + // The initial timeout is empirically usually long enough on a decently fast // machine, but if we undershoot we'll retry with exponentially longer // timeouts until the test either passes or times out completely. 
// This keeps the test reasonably fast in the typical case but allows it to @@ -1412,24 +1412,32 @@ func (f eofReaderFunc) Read(p []byte) (n int, err error) { func TestReferer(t *testing.T) { tests := []struct { - lastReq, newReq string // from -> to URLs - want string + lastReq, newReq, explicitRef string // from -> to URLs, explicitly set Referer value + want string }{ // don't send user: - {"http://gopher@test.com", "http://link.com", "http://test.com"}, - {"https://gopher@test.com", "https://link.com", "https://test.com"}, + {lastReq: "http://gopher@test.com", newReq: "http://link.com", want: "http://test.com"}, + {lastReq: "https://gopher@test.com", newReq: "https://link.com", want: "https://test.com"}, // don't send a user and password: - {"http://gopher:go@test.com", "http://link.com", "http://test.com"}, - {"https://gopher:go@test.com", "https://link.com", "https://test.com"}, + {lastReq: "http://gopher:go@test.com", newReq: "http://link.com", want: "http://test.com"}, + {lastReq: "https://gopher:go@test.com", newReq: "https://link.com", want: "https://test.com"}, // nothing to do: - {"http://test.com", "http://link.com", "http://test.com"}, - {"https://test.com", "https://link.com", "https://test.com"}, + {lastReq: "http://test.com", newReq: "http://link.com", want: "http://test.com"}, + {lastReq: "https://test.com", newReq: "https://link.com", want: "https://test.com"}, // https to http doesn't send a referer: - {"https://test.com", "http://link.com", ""}, - {"https://gopher:go@test.com", "http://link.com", ""}, + {lastReq: "https://test.com", newReq: "http://link.com", want: ""}, + {lastReq: "https://gopher:go@test.com", newReq: "http://link.com", want: ""}, + + // https to http should remove an existing referer: + {lastReq: "https://test.com", newReq: "http://link.com", explicitRef: "https://foo.com", want: ""}, + {lastReq: "https://gopher:go@test.com", newReq: "http://link.com", explicitRef: "https://foo.com", want: ""}, + + // don't override an existing referer: + {lastReq: "https://test.com", newReq: "https://link.com", explicitRef: "https://foo.com", want: "https://foo.com"}, + {lastReq: "https://gopher:go@test.com", newReq: "https://link.com", explicitRef: "https://foo.com", want: "https://foo.com"}, } for _, tt := range tests { l, err := url.Parse(tt.lastReq) @@ -1440,7 +1448,7 @@ func TestReferer(t *testing.T) { if err != nil { t.Fatal(err) } - r := ExportRefererForURL(l, n) + r := ExportRefererForURL(l, n, tt.explicitRef) if r != tt.want { t.Errorf("refererForURL(%q, %q) = %q; want %q", tt.lastReq, tt.newReq, r, tt.want) } @@ -1471,6 +1479,9 @@ func TestClientRedirectResponseWithoutRequest(t *testing.T) { } // Issue 4800: copy (some) headers when Client follows a redirect. +// Issue 35104: Since both URLs have the same host (localhost) +// but different ports, sensitive headers like Cookie and Authorization +// are preserved. 
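As exercised by the TestReferer table above, the refererForURL change means a Referer header that the caller set explicitly is now kept across redirects, except when the redirect downgrades from https to http, in which case no referer is sent at all. A rough, self-contained sketch of that rule (illustrative only; the real helper in client.go also strips userinfo without the extra allocation):

package main

import (
	"fmt"
	"net/url"
)

// refererForURL mirrors the patched rule: never send a referer on an
// https->http downgrade, otherwise prefer the caller-supplied value,
// otherwise derive one from the previous request URL minus userinfo.
func refererForURL(lastReq, newReq *url.URL, explicitRef string) string {
	if lastReq.Scheme == "https" && newReq.Scheme == "http" {
		return ""
	}
	if explicitRef != "" {
		return explicitRef
	}
	ref := *lastReq
	ref.User = nil // drop user:password before reusing the URL
	return ref.String()
}

func main() {
	last, _ := url.Parse("https://gopher:go@test.com")
	next, _ := url.Parse("https://link.com")
	fmt.Println(refererForURL(last, next, ""))                // https://test.com
	fmt.Println(refererForURL(last, next, "https://foo.com")) // https://foo.com
}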
func TestClientCopyHeadersOnRedirect(t *testing.T) { run(t, testClientCopyHeadersOnRedirect) } func testClientCopyHeadersOnRedirect(t *testing.T, mode testMode) { const ( @@ -1484,6 +1495,8 @@ func testClientCopyHeadersOnRedirect(t *testing.T, mode testMode) { "X-Foo": []string{xfoo}, "Referer": []string{ts2URL}, "Accept-Encoding": []string{"gzip"}, + "Cookie": []string{"foo=bar"}, + "Authorization": []string{"secretpassword"}, } if !reflect.DeepEqual(r.Header, want) { t.Errorf("Request.Header = %#v; want %#v", r.Header, want) @@ -1502,9 +1515,11 @@ func testClientCopyHeadersOnRedirect(t *testing.T, mode testMode) { c := ts1.Client() c.CheckRedirect = func(r *Request, via []*Request) error { want := Header{ - "User-Agent": []string{ua}, - "X-Foo": []string{xfoo}, - "Referer": []string{ts2URL}, + "User-Agent": []string{ua}, + "X-Foo": []string{xfoo}, + "Referer": []string{ts2URL}, + "Cookie": []string{"foo=bar"}, + "Authorization": []string{"secretpassword"}, } if !reflect.DeepEqual(r.Header, want) { t.Errorf("CheckRedirect Request.Header = %#v; want %#v", r.Header, want) @@ -1708,18 +1723,31 @@ func TestShouldCopyHeaderOnRedirect(t *testing.T) { {"cookie", "http://foo.com/", "http://bar.com/", false}, {"cookie2", "http://foo.com/", "http://bar.com/", false}, {"authorization", "http://foo.com/", "http://bar.com/", false}, + {"authorization", "http://foo.com/", "https://foo.com/", true}, + {"authorization", "http://foo.com:1234/", "http://foo.com:4321/", true}, {"www-authenticate", "http://foo.com/", "http://bar.com/", false}, + {"authorization", "http://foo.com/", "http://[::1%25.foo.com]/", false}, // But subdomains should work: {"www-authenticate", "http://foo.com/", "http://foo.com/", true}, {"www-authenticate", "http://foo.com/", "http://sub.foo.com/", true}, {"www-authenticate", "http://foo.com/", "http://notfoo.com/", false}, - {"www-authenticate", "http://foo.com/", "https://foo.com/", false}, + {"www-authenticate", "http://foo.com/", "https://foo.com/", true}, {"www-authenticate", "http://foo.com:80/", "http://foo.com/", true}, {"www-authenticate", "http://foo.com:80/", "http://sub.foo.com/", true}, {"www-authenticate", "http://foo.com:443/", "https://foo.com/", true}, {"www-authenticate", "http://foo.com:443/", "https://sub.foo.com/", true}, - {"www-authenticate", "http://foo.com:1234/", "http://foo.com/", false}, + {"www-authenticate", "http://foo.com:1234/", "http://foo.com/", true}, + + {"authorization", "http://foo.com/", "http://foo.com/", true}, + {"authorization", "http://foo.com/", "http://sub.foo.com/", true}, + {"authorization", "http://foo.com/", "http://notfoo.com/", false}, + {"authorization", "http://foo.com/", "https://foo.com/", true}, + {"authorization", "http://foo.com:80/", "http://foo.com/", true}, + {"authorization", "http://foo.com:80/", "http://sub.foo.com/", true}, + {"authorization", "http://foo.com:443/", "https://foo.com/", true}, + {"authorization", "http://foo.com:443/", "https://sub.foo.com/", true}, + {"authorization", "http://foo.com:1234/", "http://foo.com/", true}, } for i, tt := range tests { u0, err := url.Parse(tt.initialURL) diff --git a/clientserver_test.go b/clientserver_test.go index 3ddd7fb9..b162c323 100644 --- a/clientserver_test.go +++ b/clientserver_test.go @@ -1240,9 +1240,9 @@ func testTransportRejectsInvalidHeaders(t *testing.T, mode testMode) { func TestInterruptWithPanic(t *testing.T) { run(t, func(t *testing.T, mode testMode) { t.Run("boom", func(t *testing.T) { testInterruptWithPanic(t, mode, "boom") }) - t.Run("nil", func(t 
*testing.T) { testInterruptWithPanic(t, mode, nil) }) + t.Run("nil", func(t *testing.T) { t.Setenv("GODEBUG", "panicnil=1"); testInterruptWithPanic(t, mode, nil) }) t.Run("ErrAbortHandler", func(t *testing.T) { testInterruptWithPanic(t, mode, ErrAbortHandler) }) - }) + }, testNotParallel) } func testInterruptWithPanic(t *testing.T, mode testMode, panicValue any) { const msg = "hello" @@ -1284,24 +1284,28 @@ func testInterruptWithPanic(t *testing.T, mode testMode, panicValue any) { } wantStackLogged := panicValue != nil && panicValue != ErrAbortHandler - if err := waitErrCondition(5*time.Second, 10*time.Millisecond, func() error { + waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool { gotLog := logOutput() if !wantStackLogged { if gotLog == "" { - return nil + return true } - return fmt.Errorf("want no log output; got: %s", gotLog) + t.Fatalf("want no log output; got: %s", gotLog) } if gotLog == "" { - return fmt.Errorf("wanted a stack trace logged; got nothing") + if d > 0 { + t.Logf("wanted a stack trace logged; got nothing after %v", d) + } + return false } if !strings.Contains(gotLog, "created by ") && strings.Count(gotLog, "\n") < 6 { - return fmt.Errorf("output doesn't look like a panic stack trace. Got: %s", gotLog) + if d > 0 { + t.Logf("output doesn't look like a panic stack trace after %v. Got: %s", d, gotLog) + } + return false } - return nil - }); err != nil { - t.Fatal(err) - } + return true + }) } type lockedBytesBuffer struct { diff --git a/cookiejar/jar.go b/cookiejar/jar.go index 320d7e2e..db5575ff 100644 --- a/cookiejar/jar.go +++ b/cookiejar/jar.go @@ -363,10 +363,17 @@ func jarKey(host string, psl PublicSuffixList) string { // isIP reports whether host is an IP address. func isIP(host string) bool { + if strings.ContainsAny(host, ":%") { + // Probable IPv6 address. + // Hostnames can't contain : or %, so this is definitely not a valid host. + // Treating it as an IP is the more conservative option, and avoids the risk + // of interpeting ::1%.www.example.com as a subtomain of www.example.com. + return true + } return net.ParseIP(host) != nil } -// defaultPath returns the directory part of an URL's path according to +// defaultPath returns the directory part of a URL's path according to // RFC 6265 section 5.1.4. func defaultPath(path string) string { if len(path) == 0 || path[0] != '/' { @@ -380,7 +387,7 @@ func defaultPath(path string) string { return path[:i] // Path is either of form "/abc/xyz" or "/abc/xyz/". } -// newEntry creates an entry from a http.Cookie c. now is the current time and +// newEntry creates an entry from an http.Cookie c. now is the current time and // is compared to c.Expires to determine deletion of c. defPath and host are the // default-path and the canonical host name of the URL c was received from. // @@ -466,14 +473,14 @@ func (j *Jar) domainAndType(host, domain string) (string, bool, error) { // dot in the domain-attribute before processing the cookie. // // Most browsers don't do that for IP addresses, only curl - // version 7.54) and IE (version 11) do not reject a + // (version 7.54) and IE (version 11) do not reject a // Set-Cookie: a=1; domain=.127.0.0.1 // This leading dot is optional and serves only as hint for // humans to indicate that a cookie with "domain=.bbc.co.uk" // would be sent to every subdomain of bbc.co.uk. // It just doesn't make sense on IP addresses. 
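The isIP change just above, together with the isDomainOrSubdomain change in client.go earlier in this diff, refuses to treat anything containing ':' or '%' as a hostname, so an IPv6 zone such as ::1%.www.example.com can never match www.example.com when deciding whether to forward cookies or sensitive headers. A standalone sketch of the two checks (simplified for illustration; the real isDomainOrSubdomain does the suffix comparison without allocating):

package main

import (
	"fmt"
	"net"
	"strings"
)

// isIP mirrors cookiejar's check: ':' or '%' cannot appear in a hostname,
// so such a string is conservatively treated as an IP address.
func isIP(host string) bool {
	if strings.ContainsAny(host, ":%") {
		return true
	}
	return net.ParseIP(host) != nil
}

// isDomainOrSubdomain mirrors client.go's check: skip the suffix match
// entirely when sub looks like an IPv6 address or zone.
func isDomainOrSubdomain(sub, parent string) bool {
	if sub == parent {
		return true
	}
	if strings.ContainsAny(sub, ":%") {
		return false
	}
	return strings.HasSuffix(sub, "."+parent)
}

func main() {
	fmt.Println(isIP("::1%zone"))                                               // true
	fmt.Println(isDomainOrSubdomain("::1%.www.example.com", "www.example.com")) // false
	fmt.Println(isDomainOrSubdomain("sub.foo.com", "foo.com"))                  // true
}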
// The other processing and validation steps in RFC 6265 just - // collaps to: + // collapse to: if host != domain { return "", false, errIllegalDomain } diff --git a/cookiejar/jar_test.go b/cookiejar/jar_test.go index b26cb84e..22b7759e 100644 --- a/cookiejar/jar_test.go +++ b/cookiejar/jar_test.go @@ -253,6 +253,7 @@ var isIPTests = map[string]bool{ "127.0.0.1": true, "1.2.3.4": true, "2001:4860:0:2001::68": true, + "::1%zone": true, "example.com": false, "1.1.1.300": false, "www.foo.bar.net": false, @@ -350,7 +351,7 @@ func expiresIn(delta int) string { return "expires=" + t.Format(time.RFC1123) } -// mustParseURL parses s to an URL and panics on error. +// mustParseURL parses s to a URL and panics on error. func mustParseURL(s string) *url.URL { u, err := url.Parse(s) if err != nil || u.Scheme == "" || u.Host == "" { @@ -630,6 +631,15 @@ var basicsTests = [...]jarTest{ {"http://www.host.test:1234/", "a=1"}, }, }, + { + "IPv6 zone is not treated as a host.", + "https://example.com/", + []string{"a=1"}, + "a=1", + []query{ + {"https://[::1%25.example.com]:80/", ""}, + }, + }, } func TestBasics(t *testing.T) { @@ -671,7 +681,7 @@ var updateAndDeleteTests = [...]jarTest{ }, }, { - "Clear Secure flag from a http.", + "Clear Secure flag from an http.", "http://www.host.test/", []string{ "b=xx", diff --git a/doc.go b/doc.go index 50b6b841..45a8a77b 100644 --- a/doc.go +++ b/doc.go @@ -19,7 +19,7 @@ Get, Head, Post, and PostForm make HTTP (or HTTPS) requests: resp, err := http.PostForm("http://example.com/form", url.Values{"key": {"Value"}, "id": {"123"}}) -The client must close the response body when finished with it: +The caller must close the response body when finished with it: resp, err := http.Get("http://example.com/") if err != nil { @@ -29,6 +29,8 @@ The client must close the response body when finished with it: body, err := io.ReadAll(resp.Body) // ... +# Clients and Transports + For control over HTTP client headers, redirect policy, and other settings, create a Client: @@ -59,6 +61,8 @@ compression, and other settings, create a Transport: Clients and Transports are safe for concurrent use by multiple goroutines and for efficiency should only be created once and re-used. +# Servers + ListenAndServe starts an HTTP server with a given address and handler. The handler is usually nil, which means to use DefaultServeMux. Handle and HandleFunc add handlers to DefaultServeMux: @@ -83,11 +87,13 @@ custom Server: } log.Fatal(s.ListenAndServe()) +# HTTP/2 + Starting with Go 1.6, the http package has transparent support for the HTTP/2 protocol when using HTTPS. Programs that must disable HTTP/2 can do so by setting Transport.TLSNextProto (for clients) or Server.TLSNextProto (for servers) to a non-nil, empty -map. Alternatively, the following GODEBUG environment variables are +map. Alternatively, the following GODEBUG settings are currently supported: GODEBUG=http2client=0 # disable HTTP/2 client support @@ -95,9 +101,7 @@ currently supported: GODEBUG=http2debug=1 # enable verbose HTTP/2 debug logs GODEBUG=http2debug=2 # ... even more verbose, with frame dumps -The GODEBUG variables are not covered by Go's API compatibility -promise. Please report any issues before disabling HTTP/2 -support: https://golang.org/s/http2bug +Please report any issues before disabling HTTP/2 support: https://golang.org/s/http2bug The http package's Transport and Server both automatically enable HTTP/2 support for simple configurations. 
To enable HTTP/2 for more diff --git a/example/go.mod b/example/go.mod index a7a11b9b..9b8f5b77 100644 --- a/example/go.mod +++ b/example/go.mod @@ -1,6 +1,8 @@ module github.com/ooni/oohttp/example -go 1.20 +go 1.21 + +toolchain go1.21.8 require ( github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 diff --git a/example/go.sum b/example/go.sum index 47022550..a72e40b7 100644 --- a/example/go.sum +++ b/example/go.sum @@ -19,7 +19,9 @@ github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.m github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= +github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= +github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -51,11 +53,14 @@ github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo github.com/google/martian/v3 v3.3.2 h1:IqNFLAmvJOgVlpdEBiQbDc2EwKW77amAycfTuWKdfvw= github.com/google/martian/v3 v3.3.2/go.mod h1:oBOf6HBosgwRXnUGWUB05QECsc6uvmMiJ3+6W4l/CUk= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE= +github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/klauspost/compress v1.17.7 h1:ehO88t2UGzQK66LMdE8tibEd1ErmzZjNEqWkjLAKQQg= github.com/klauspost/compress v1.17.7/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q= +github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k= github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE= +github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg= github.com/ooni/oohttp v0.6.8-0.20240322100813-2a1cdc95a941 h1:es2QWDycIahrNuEkXClWSNIs9ISSdNGVNGYSkMJv5Tc= github.com/ooni/oohttp v0.6.8-0.20240322100813-2a1cdc95a941/go.mod h1:Vipww76rE6i/Lyd+M8gec/ixPrsyPti1J8xTyqzFIHA= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -108,6 +113,7 @@ golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3 golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw= +golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod 
h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= @@ -138,5 +144,6 @@ google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/export_test.go b/export_test.go index fb5ab939..5d198f3f 100644 --- a/export_test.go +++ b/export_test.go @@ -36,7 +36,7 @@ var ( Export_is408Message = is408Message ) -const MaxWriteWaitBeforeConnReuse = maxWriteWaitBeforeConnReuse +var MaxWriteWaitBeforeConnReuse = &maxWriteWaitBeforeConnReuse func init() { // We only want to pay for this cost during testing. @@ -142,9 +142,11 @@ func (t *Transport) IdleConnStrsForTesting_h2() []string { pool.mu.Lock() defer pool.mu.Unlock() - for k, cc := range pool.conns { - for range cc { - ret = append(ret, k) + for k, ccs := range pool.conns { + for _, cc := range ccs { + if cc.idleState().canTakeNewRequest { + ret = append(ret, k) + } } } diff --git a/fcgi/fcgi.go b/fcgi/fcgi.go index fb822f8a..56f7d407 100644 --- a/fcgi/fcgi.go +++ b/fcgi/fcgi.go @@ -99,8 +99,10 @@ func (h *header) init(recType recType, reqId uint16, contentLength int) { // conn sends records over rwc type conn struct { - mutex sync.Mutex - rwc io.ReadWriteCloser + mutex sync.Mutex + rwc io.ReadWriteCloser + closeErr error + closed bool // to avoid allocations buf bytes.Buffer @@ -111,10 +113,15 @@ func newConn(rwc io.ReadWriteCloser) *conn { return &conn{rwc: rwc} } +// Close closes the conn if it is not already closed. 
func (c *conn) Close() error { c.mutex.Lock() defer c.mutex.Unlock() - return c.rwc.Close() + if !c.closed { + c.closeErr = c.rwc.Close() + c.closed = true + } + return c.closeErr } type record struct { diff --git a/fcgi/fcgi_test.go b/fcgi/fcgi_test.go index 20b825f1..3140f9e4 100644 --- a/fcgi/fcgi_test.go +++ b/fcgi/fcgi_test.go @@ -242,7 +242,7 @@ func TestChildServeCleansUp(t *testing.T) { input := make([]byte, len(tt.input)) copy(input, tt.input) rc := nopWriteCloser{bytes.NewReader(input)} - done := make(chan bool) + done := make(chan struct{}) c := newChild(rc, http.HandlerFunc(func( w http.ResponseWriter, r *http.Request, @@ -253,9 +253,9 @@ func TestChildServeCleansUp(t *testing.T) { t.Errorf("Expected %#v, got %#v", tt.err, err) } // not reached if body of request isn't closed - done <- true + close(done) })) - go c.serve() + c.serve() // wait for body of request to be closed or all goroutines to block <-done } @@ -332,7 +332,7 @@ func TestChildServeReadsEnvVars(t *testing.T) { input := make([]byte, len(tt.input)) copy(input, tt.input) rc := nopWriteCloser{bytes.NewReader(input)} - done := make(chan bool) + done := make(chan struct{}) c := newChild(rc, http.HandlerFunc(func( w http.ResponseWriter, r *http.Request, @@ -344,9 +344,9 @@ func TestChildServeReadsEnvVars(t *testing.T) { } else if env[tt.envVar] != tt.expectedVal { t.Errorf("Expected %s, got %s", tt.expectedVal, env[tt.envVar]) } - done <- true + close(done) })) - go c.serve() + c.serve() <-done } } @@ -382,7 +382,7 @@ func TestResponseWriterSniffsContentType(t *testing.T) { input := make([]byte, len(streamFullRequestStdin)) copy(input, streamFullRequestStdin) rc := nopWriteCloser{bytes.NewReader(input)} - done := make(chan bool) + done := make(chan struct{}) var resp *response c := newChild(rc, http.HandlerFunc(func( w http.ResponseWriter, @@ -390,10 +390,9 @@ func TestResponseWriterSniffsContentType(t *testing.T) { ) { io.WriteString(w, tt.body) resp = w.(*response) - done <- true + close(done) })) - defer c.cleanUp() - go c.serve() + c.serve() <-done if got := resp.Header().Get("Content-Type"); got != tt.wantCT { t.Errorf("got a Content-Type of %q; expected it to start with %q", got, tt.wantCT) @@ -402,25 +401,27 @@ func TestResponseWriterSniffsContentType(t *testing.T) { } } -type signalingNopCloser struct { - io.Reader +type signalingNopWriteCloser struct { + io.ReadCloser closed chan bool } -func (*signalingNopCloser) Write(buf []byte) (int, error) { +func (*signalingNopWriteCloser) Write(buf []byte) (int, error) { return len(buf), nil } -func (rc *signalingNopCloser) Close() error { +func (rc *signalingNopWriteCloser) Close() error { close(rc.closed) - return nil + return rc.ReadCloser.Close() } // Test whether server properly closes connection when processing slow // requests func TestSlowRequest(t *testing.T) { pr, pw := io.Pipe() - go func(w io.Writer) { + + writerDone := make(chan struct{}) + go func() { for _, buf := range [][]byte{ streamBeginTypeStdin, makeRecord(typeStdin, 1, nil), @@ -428,9 +429,14 @@ func TestSlowRequest(t *testing.T) { pw.Write(buf) time.Sleep(100 * time.Millisecond) } - }(pw) - - rc := &signalingNopCloser{pr, make(chan bool)} + close(writerDone) + }() + defer func() { + <-writerDone + pw.Close() + }() + + rc := &signalingNopWriteCloser{pr, make(chan bool)} handlerDone := make(chan bool) c := newChild(rc, http.HandlerFunc(func( @@ -440,16 +446,9 @@ func TestSlowRequest(t *testing.T) { w.WriteHeader(200) close(handlerDone) })) - go c.serve() - defer c.cleanUp() - - timeout := 
time.After(2 * time.Second) + c.serve() <-handlerDone - select { - case <-rc.closed: - t.Log("FastCGI child closed connection") - case <-timeout: - t.Error("FastCGI child did not close socket after handling request") - } + <-rc.closed + t.Log("FastCGI child closed connection") } diff --git a/fs_test.go b/fs_test.go index 4736a7d2..a3553559 100644 --- a/fs_test.go +++ b/fs_test.go @@ -88,15 +88,39 @@ func testServeFile(t *testing.T, mode testMode) { if req.URL, err = url.Parse(ts.URL); err != nil { t.Fatal("ParseURL:", err) } - req.Method = "GET" - // straight GET - _, body := getBody(t, "straight get", req, c) - if !bytes.Equal(body, file) { - t.Fatalf("body mismatch: got %q, want %q", body, file) + // Get contents via various methods. + // + // See https://go.dev/issue/59471 for a proposal to limit the set of methods handled. + // For now, test the historical behavior. + for _, method := range []string{ + MethodGet, + MethodPost, + MethodPut, + MethodPatch, + MethodDelete, + MethodOptions, + MethodTrace, + } { + req.Method = method + _, body := getBody(t, method, req, c) + if !bytes.Equal(body, file) { + t.Fatalf("body mismatch for %v request: got %q, want %q", method, body, file) + } + } + + // HEAD request. + req.Method = MethodHead + resp, body := getBody(t, "HEAD", req, c) + if len(body) != 0 { + t.Fatalf("body mismatch for HEAD request: got %q, want empty", body) + } + if got, want := resp.Header.Get("Content-Length"), fmt.Sprint(len(file)); got != want { + t.Fatalf("Content-Length mismatch for HEAD request: got %v, want %v", got, want) } // Range tests + req.Method = MethodGet Cases: for _, rt := range ServeFileRangeTests { if rt.r != "" { @@ -745,6 +769,10 @@ func (f *fakeFileInfo) Mode() fs.FileMode { return 0644 } +func (f *fakeFileInfo) String() string { + return fs.FormatFileInfo(f) +} + type fakeFile struct { io.ReadSeeker fi *fakeFileInfo @@ -1483,3 +1511,52 @@ func testServeFileRejectsInvalidSuffixLengths(t *testing.T, mode testMode) { }) } } + +func TestFileServerMethods(t *testing.T) { + run(t, testFileServerMethods) +} +func testFileServerMethods(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, FileServer(Dir("testdata"))).ts + + file, err := os.ReadFile(testFile) + if err != nil { + t.Fatal("reading file:", err) + } + + // Get contents via various methods. + // + // See https://go.dev/issue/59471 for a proposal to limit the set of methods handled. + // For now, test the historical behavior. 
+ for _, method := range []string{ + MethodGet, + MethodHead, + MethodPost, + MethodPut, + MethodPatch, + MethodDelete, + MethodOptions, + MethodTrace, + } { + req, _ := NewRequest(method, ts.URL+"/file", nil) + t.Log(req.URL) + res, err := ts.Client().Do(req) + if err != nil { + t.Fatal(err) + } + body, err := io.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Fatal(err) + } + wantBody := file + if method == MethodHead { + wantBody = nil + } + if !bytes.Equal(body, wantBody) { + t.Fatalf("%v: got body %q, want %q", method, body, wantBody) + } + if got, want := res.Header.Get("Content-Length"), fmt.Sprint(len(file)); got != want { + t.Fatalf("%v: got Content-Length %q, want %q", method, got, want) + } + } +} diff --git a/go.mod b/go.mod index 44a11f69..13987b60 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/ooni/oohttp -go 1.20 +go 1.21 require golang.org/x/net v0.22.0 diff --git a/h2_bundle.go b/h2_bundle.go index 1d9720f2..2946bc92 100644 --- a/h2_bundle.go +++ b/h2_bundle.go @@ -1303,23 +1303,91 @@ var ( http2errPseudoAfterRegular = errors.New("pseudo header field after regular") ) -// flow is the flow control window's size. -type http2flow struct { +// inflowMinRefresh is the minimum number of bytes we'll send for a +// flow control window update. +const http2inflowMinRefresh = 4 << 10 + +// inflow accounts for an inbound flow control window. +// It tracks both the latest window sent to the peer (used for enforcement) +// and the accumulated unsent window. +type http2inflow struct { + avail int32 + unsent int32 +} + +// init sets the initial window. +func (f *http2inflow) init(n int32) { + f.avail = n +} + +// add adds n bytes to the window, with a maximum window size of max, +// indicating that the peer can now send us more data. +// For example, the user read from a {Request,Response} body and consumed +// some of the buffered data, so the peer can now send more. +// It returns the number of bytes to send in a WINDOW_UPDATE frame to the peer. +// Window updates are accumulated and sent when the unsent capacity +// is at least inflowMinRefresh or will at least double the peer's available window. +func (f *http2inflow) add(n int) (connAdd int32) { + if n < 0 { + panic("negative update") + } + unsent := int64(f.unsent) + int64(n) + // "A sender MUST NOT allow a flow-control window to exceed 2^31-1 octets." + // RFC 7540 Section 6.9.1. + const maxWindow = 1<<31 - 1 + if unsent+int64(f.avail) > maxWindow { + panic("flow control update exceeds maximum window size") + } + f.unsent = int32(unsent) + if f.unsent < http2inflowMinRefresh && f.unsent < f.avail { + // If there aren't at least inflowMinRefresh bytes of window to send, + // and this update won't at least double the window, buffer the update for later. + return 0 + } + f.avail += f.unsent + f.unsent = 0 + return int32(unsent) +} + +// take attempts to take n bytes from the peer's flow control window. +// It reports whether the window has available capacity. +func (f *http2inflow) take(n uint32) bool { + if n > uint32(f.avail) { + return false + } + f.avail -= int32(n) + return true +} + +// takeInflows attempts to take n bytes from two inflows, +// typically connection-level and stream-level flows. +// It reports whether both windows have available capacity. 
+func http2takeInflows(f1, f2 *http2inflow, n uint32) bool { + if n > uint32(f1.avail) || n > uint32(f2.avail) { + return false + } + f1.avail -= int32(n) + f2.avail -= int32(n) + return true +} + +// outflow is the outbound flow control window's size. +type http2outflow struct { _ http2incomparable // n is the number of DATA bytes we're allowed to send. - // A flow is kept both on a conn and a per-stream. + // An outflow is kept both on a conn and a per-stream. n int32 - // conn points to the shared connection-level flow that is - // shared by all streams on that conn. It is nil for the flow + // conn points to the shared connection-level outflow that is + // shared by all streams on that conn. It is nil for the outflow // that's on the conn directly. - conn *http2flow + conn *http2outflow } -func (f *http2flow) setConnFlow(cf *http2flow) { f.conn = cf } +func (f *http2outflow) setConnFlow(cf *http2outflow) { f.conn = cf } -func (f *http2flow) available() int32 { +func (f *http2outflow) available() int32 { n := f.n if f.conn != nil && f.conn.n < n { n = f.conn.n @@ -1327,7 +1395,7 @@ func (f *http2flow) available() int32 { return n } -func (f *http2flow) take(n int32) { +func (f *http2outflow) take(n int32) { if n > f.available() { panic("internal error: took too much") } @@ -1339,7 +1407,7 @@ func (f *http2flow) take(n int32) { // add adds n bytes (positive or negative) to the flow control window. // It returns false if the sum would exceed 2^31-1. -func (f *http2flow) add(n int32) bool { +func (f *http2outflow) add(n int32) bool { sum := f.n + n if (sum > n) == (f.n > 0) { f.n = sum @@ -1995,6 +2063,15 @@ func (f *http2Framer) WriteData(streamID uint32, endStream bool, data []byte) er // It is the caller's responsibility not to violate the maximum frame size // and to not call other Write methods concurrently. func (f *http2Framer) WriteDataPadded(streamID uint32, endStream bool, data, pad []byte) error { + if err := f.startWriteDataPadded(streamID, endStream, data, pad); err != nil { + return err + } + return f.endWrite() +} + +// startWriteDataPadded is WriteDataPadded, but only writes the frame to the Framer's internal buffer. +// The caller should call endWrite to flush the frame to the underlying writer. +func (f *http2Framer) startWriteDataPadded(streamID uint32, endStream bool, data, pad []byte) error { if !http2validStreamID(streamID) && !f.AllowIllegalWrites { return http2errStreamID } @@ -2024,7 +2101,7 @@ func (f *http2Framer) WriteDataPadded(streamID uint32, endStream bool, data, pad } f.wbuf = append(f.wbuf, data...) f.wbuf = append(f.wbuf, pad...) - return f.endWrite() + return nil } // A SettingsFrame conveys configuration parameters that affect how @@ -3700,13 +3777,9 @@ func (p *http2pipe) Write(d []byte) (n int, err error) { p.c.L = &p.mu } defer p.c.Signal() - if p.err != nil { + if p.err != nil || p.breakErr != nil { return 0, http2errClosedPipeWrite } - if p.breakErr != nil { - p.unread += len(d) - return len(d), nil // discard when there is no reader - } return p.b.Write(d) } @@ -4180,14 +4253,14 @@ func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) { if s.NewWriteScheduler != nil { sc.writeSched = s.NewWriteScheduler() } else { - sc.writeSched = http2NewPriorityWriteScheduler(nil) + sc.writeSched = http2newRoundRobinWriteScheduler() } // These start at the RFC-specified defaults. If there is a higher // configured value for inflow, that will be updated when we send a // WINDOW_UPDATE shortly after sending SETTINGS. 
sc.flow.add(http2initialWindowSize) - sc.inflow.add(http2initialWindowSize) + sc.inflow.init(http2initialWindowSize) sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf) sc.hpackEncoder.SetMaxDynamicTableSizeLimit(s.maxEncoderHeaderTableSize()) @@ -4302,8 +4375,8 @@ type http2serverConn struct { wroteFrameCh chan http2frameWriteResult // from writeFrameAsync -> serve, tickles more frame writes bodyReadCh chan http2bodyReadMsg // from handlers -> serve serveMsgCh chan interface{} // misc messages & code to send to / run on the serve loop - flow http2flow // conn-wide (not stream-specific) outbound flow control - inflow http2flow // conn-wide inbound flow control + flow http2outflow // conn-wide (not stream-specific) outbound flow control + inflow http2inflow // conn-wide inbound flow control tlsState *tls.ConnectionState // shared by all handlers, like net/http remoteAddrStr string writeSched http2WriteScheduler @@ -4382,10 +4455,10 @@ type http2stream struct { cancelCtx func() // owned by serverConn's serve loop: - bodyBytes int64 // body bytes seen so far - declBodyBytes int64 // or -1 if undeclared - flow http2flow // limits writing from Handler to client - inflow http2flow // what the client is allowed to POST/etc to us + bodyBytes int64 // body bytes seen so far + declBodyBytes int64 // or -1 if undeclared + flow http2outflow // limits writing from Handler to client + inflow http2inflow // what the client is allowed to POST/etc to us state http2streamState resetQueued bool // RST_STREAM queued for write; set by sc.resetStream gotTrailerHeader bool // HEADER frame for trailers was seen @@ -4587,8 +4660,13 @@ type http2frameWriteResult struct { // and then reports when it's done. // At most one goroutine can be running writeFrameAsync at a time per // serverConn. -func (sc *http2serverConn) writeFrameAsync(wr http2FrameWriteRequest) { - err := wr.write.writeFrame(sc) +func (sc *http2serverConn) writeFrameAsync(wr http2FrameWriteRequest, wd *http2writeData) { + var err error + if wd == nil { + err = wr.write.writeFrame(sc) + } else { + err = sc.framer.endWrite() + } sc.wroteFrameCh <- http2frameWriteResult{wr: wr, err: err} } @@ -5000,9 +5078,16 @@ func (sc *http2serverConn) startFrameWrite(wr http2FrameWriteRequest) { sc.writingFrameAsync = false err := wr.write.writeFrame(sc) sc.wroteFrame(http2frameWriteResult{wr: wr, err: err}) + } else if wd, ok := wr.write.(*http2writeData); ok { + // Encode the frame in the serve goroutine, to ensure we don't have + // any lingering asynchronous references to data passed to Write. + // See https://go.dev/issue/58446. + sc.framer.startWriteDataPadded(wd.streamID, wd.endStream, wd.p, nil) + sc.writingFrameAsync = true + go sc.writeFrameAsync(wr, wd) } else { sc.writingFrameAsync = true - go sc.writeFrameAsync(wr) + go sc.writeFrameAsync(wr, nil) } } @@ -5252,7 +5337,7 @@ func (sc *http2serverConn) processFrame(f http2Frame) error { if sc.inGoAway && (sc.goAwayCode != http2ErrCodeNo || f.Header().StreamID > sc.maxClientStreamID) { if f, ok := f.(*http2DataFrame); ok { - if sc.inflow.available() < int32(f.Length) { + if !sc.inflow.take(f.Length) { return sc.countError("data_flow", http2streamError(f.Header().StreamID, http2ErrCodeFlowControl)) } sc.sendWindowUpdate(nil, int(f.Length)) // conn-level @@ -5524,14 +5609,9 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error { // But still enforce their connection-level flow control, // and return any flow control bytes since we're not going // to consume them. 
- if sc.inflow.available() < int32(f.Length) { + if !sc.inflow.take(f.Length) { return sc.countError("data_flow", http2streamError(id, http2ErrCodeFlowControl)) } - // Deduct the flow control from inflow, since we're - // going to immediately add it back in - // sendWindowUpdate, which also schedules sending the - // frames. - sc.inflow.take(int32(f.Length)) sc.sendWindowUpdate(nil, int(f.Length)) // conn-level if st != nil && st.resetQueued { @@ -5546,10 +5626,9 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error { // Sender sending more than they'd declared? if st.declBodyBytes != -1 && st.bodyBytes+int64(len(data)) > st.declBodyBytes { - if sc.inflow.available() < int32(f.Length) { + if !sc.inflow.take(f.Length) { return sc.countError("data_flow", http2streamError(id, http2ErrCodeFlowControl)) } - sc.inflow.take(int32(f.Length)) sc.sendWindowUpdate(nil, int(f.Length)) // conn-level st.body.CloseWithError(fmt.Errorf("sender tried to send more than declared Content-Length of %d bytes", st.declBodyBytes)) @@ -5560,29 +5639,33 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error { } if f.Length > 0 { // Check whether the client has flow control quota. - if st.inflow.available() < int32(f.Length) { + if !http2takeInflows(&sc.inflow, &st.inflow, f.Length) { return sc.countError("flow_on_data_length", http2streamError(id, http2ErrCodeFlowControl)) } - st.inflow.take(int32(f.Length)) if len(data) > 0 { + st.bodyBytes += int64(len(data)) wrote, err := st.body.Write(data) if err != nil { + // The handler has closed the request body. + // Return the connection-level flow control for the discarded data, + // but not the stream-level flow control. sc.sendWindowUpdate(nil, int(f.Length)-wrote) - return sc.countError("body_write_err", http2streamError(id, http2ErrCodeStreamClosed)) + return nil } if wrote != len(data) { panic("internal error: bad Writer") } - st.bodyBytes += int64(len(data)) } // Return any padded flow control now, since we won't // refund it later on body reads. - if pad := int32(f.Length) - int32(len(data)); pad > 0 { - sc.sendWindowUpdate32(nil, pad) - sc.sendWindowUpdate32(st, pad) - } + // Call sendWindowUpdate even if there is no padding, + // to return buffered flow control credit if the sent + // window has shrunk. 
+ pad := int32(f.Length) - int32(len(data)) + sc.sendWindowUpdate32(nil, pad) + sc.sendWindowUpdate32(st, pad) } if f.StreamEnded() { st.endStream() @@ -5857,8 +5940,7 @@ func (sc *http2serverConn) newStream(id, pusherID uint32, state http2streamState st.cw.Init() st.flow.conn = &sc.flow // link to conn-level counter st.flow.add(sc.initialStreamSendWindowSize) - st.inflow.conn = &sc.inflow // link to conn-level counter - st.inflow.add(sc.srv.initialStreamRecvWindowSize()) + st.inflow.init(sc.srv.initialStreamRecvWindowSize()) if sc.hs.WriteTimeout != 0 { st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout) } @@ -5950,7 +6032,7 @@ func (sc *http2serverConn) newWriterAndRequestNoBody(st *http2stream, rp http2re tlsState = sc.tlsState } - needsContinue := rp.header.Get("Expect") == "100-continue" + needsContinue := httpguts.HeaderValuesContainsToken(rp.header["Expect"], "100-continue") if needsContinue { rp.header.Del("Expect") } @@ -6194,47 +6276,28 @@ func (sc *http2serverConn) noteBodyRead(st *http2stream, n int) { } // st may be nil for conn-level -func (sc *http2serverConn) sendWindowUpdate(st *http2stream, n int) { - sc.serveG.check() - // "The legal range for the increment to the flow control - // window is 1 to 2^31-1 (2,147,483,647) octets." - // A Go Read call on 64-bit machines could in theory read - // a larger Read than this. Very unlikely, but we handle it here - // rather than elsewhere for now. - const maxUint31 = 1<<31 - 1 - for n > maxUint31 { - sc.sendWindowUpdate32(st, maxUint31) - n -= maxUint31 - } - sc.sendWindowUpdate32(st, int32(n)) +func (sc *http2serverConn) sendWindowUpdate32(st *http2stream, n int32) { + sc.sendWindowUpdate(st, int(n)) } // st may be nil for conn-level -func (sc *http2serverConn) sendWindowUpdate32(st *http2stream, n int32) { +func (sc *http2serverConn) sendWindowUpdate(st *http2stream, n int) { sc.serveG.check() - if n == 0 { - return - } - if n < 0 { - panic("negative update") - } var streamID uint32 - if st != nil { + var send int32 + if st == nil { + send = sc.inflow.add(n) + } else { streamID = st.id + send = st.inflow.add(n) + } + if send == 0 { + return } sc.writeFrame(http2FrameWriteRequest{ - write: http2writeWindowUpdate{streamID: streamID, n: uint32(n)}, + write: http2writeWindowUpdate{streamID: streamID, n: uint32(send)}, stream: st, }) - var ok bool - if st == nil { - ok = sc.inflow.add(n) - } else { - ok = st.inflow.add(n) - } - if !ok { - panic("internal error; sent too many window updates without decrements?") - } } // requestBody is the Handler's Request.Body type. 
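The inflow/outflow split above replaces the old bidirectional flow type: outflow still tracks how much we may send, while inflow tracks the receive window we advertised to the peer and batches WINDOW_UPDATE credit until it is worth sending, that is, at least inflowMinRefresh (4 KiB) or enough to double the peer's remaining window. A standalone sketch of that accounting (the overflow and negative-update panics are omitted here):

package main

import "fmt"

const inflowMinRefresh = 4 << 10

type inflow struct {
	avail  int32 // window currently advertised to the peer
	unsent int32 // credit accumulated but not yet sent as WINDOW_UPDATE
}

func (f *inflow) init(n int32) { f.avail = n }

// take enforces the advertised window against incoming DATA.
func (f *inflow) take(n uint32) bool {
	if n > uint32(f.avail) {
		return false
	}
	f.avail -= int32(n)
	return true
}

// add records n consumed bytes and returns the WINDOW_UPDATE increment
// to send now, or 0 while the credit is still being buffered.
func (f *inflow) add(n int) int32 {
	unsent := int64(f.unsent) + int64(n)
	f.unsent = int32(unsent)
	if f.unsent < inflowMinRefresh && f.unsent < f.avail {
		return 0
	}
	f.avail += f.unsent
	f.unsent = 0
	return int32(unsent)
}

func main() {
	var f inflow
	f.init(64 << 10)
	f.take(8 << 10)          // peer sent 8 KiB of DATA
	fmt.Println(f.add(1024)) // 0: credit is buffered
	fmt.Println(f.add(4096)) // 5120: threshold reached, send a WINDOW_UPDATE
}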
@@ -6245,7 +6308,7 @@ type http2requestBody struct { conn *http2serverConn closeOnce sync.Once // for use by Close only sawEOF bool // for use by Read only - pipe *http2pipe // non-nil if we have a HTTP entity message body + pipe *http2pipe // non-nil if we have an HTTP entity message body needsContinue bool // need to send a 100-continue } @@ -6385,7 +6448,8 @@ func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) { clen = "" } } - if clen == "" && rws.handlerDone && http2bodyAllowedForStatus(rws.status) && (len(p) > 0 || !isHeadResp) { + _, hasContentLength := rws.snapHeader["Content-Length"] + if !hasContentLength && clen == "" && rws.handlerDone && http2bodyAllowedForStatus(rws.status) && (len(p) > 0 || !isHeadResp) { clen = strconv.Itoa(len(p)) } _, hasContentType := rws.snapHeader["Content-Type"] @@ -6590,7 +6654,7 @@ func (w *http2responseWriter) FlushError() error { err = rws.bw.Flush() } else { // The bufio.Writer won't call chunkWriter.Write - // (writeChunk with zero bytes, so we have to do it + // (writeChunk with zero bytes), so we have to do it // ourselves to force the HTTP response header and/or // final DATA frame (with END_STREAM) to be sent. _, err = http2chunkWriter{rws}.Write(nil) @@ -7067,10 +7131,6 @@ const ( // we buffer per stream. http2transportDefaultStreamFlow = 4 << 20 - // transportDefaultStreamMinRefresh is the minimum number of bytes we'll send - // a stream-level WINDOW_UPDATE for at a time. - http2transportDefaultStreamMinRefresh = 4 << 10 - http2defaultUserAgent = "Go-http-client/2.0" // initialMaxConcurrentStreams is a connections maxConcurrentStreams until @@ -7328,11 +7388,11 @@ type http2ClientConn struct { idleTimeout time.Duration // or 0 for never idleTimer *time.Timer - mu sync.Mutex // guards following - cond *sync.Cond // hold mu; broadcast on flow/closed changes - flow http2flow // our conn-level flow control quota (cs.flow is per stream) - inflow http2flow // peer's conn-level flow control - doNotReuse bool // whether conn is marked to not be reused for any future requests + mu sync.Mutex // guards following + cond *sync.Cond // hold mu; broadcast on flow/closed changes + flow http2outflow // our conn-level flow control quota (cs.outflow is per stream) + inflow http2inflow // peer's conn-level flow control + doNotReuse bool // whether conn is marked to not be reused for any future requests closing bool closed bool seenSettings bool // true if we've seen a settings frame, false otherwise @@ -7396,10 +7456,10 @@ type http2clientStream struct { respHeaderRecv chan struct{} // closed when headers are received res *Response // set if respHeaderRecv is closed - flow http2flow // guarded by cc.mu - inflow http2flow // guarded by cc.mu - bytesRemain int64 // -1 means unknown; owned by transportResponseBody.Read - readErr error // sticky read error; owned by transportResponseBody.Read + flow http2outflow // guarded by cc.mu + inflow http2inflow // guarded by cc.mu + bytesRemain int64 // -1 means unknown; owned by transportResponseBody.Read + readErr error // sticky read error; owned by transportResponseBody.Read reqBody io.ReadCloser reqBodyContentLength int64 // -1 means unknown @@ -7543,11 +7603,14 @@ func (t *http2Transport) RoundTrip(req *Request) (*Response, error) { func http2authorityAddr(scheme string, authority string) (addr string) { host, port, err := net.SplitHostPort(authority) if err != nil { // authority didn't have a port + host = authority + port = "" + } + if port == "" { // authority's port was empty port = 
"443" if scheme == "http" { port = "80" } - host = authority } if a, err := idna.ToASCII(host); err == nil { host = a @@ -7585,10 +7648,11 @@ func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Res http2traceGotConn(req, cc, reused) res, err := cc.RoundTrip(req) if err != nil && retry <= 6 { + roundTripErr := err if req, err = http2shouldRetryRequest(req, err); err == nil { // After the first retry, do exponential backoff with 10% jitter. if retry == 0 { - t.vlogf("RoundTrip retrying after failure: %v", err) + t.vlogf("RoundTrip retrying after failure: %v", roundTripErr) continue } backoff := float64(uint(1) << (uint(retry) - 1)) @@ -7597,7 +7661,7 @@ func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Res timer := http2backoffNewTimer(d) select { case <-timer.C: - t.vlogf("RoundTrip retrying after failure: %v", err) + t.vlogf("RoundTrip retrying after failure: %v", roundTripErr) continue case <-req.Context().Done(): timer.Stop() @@ -7832,7 +7896,7 @@ func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2Client cc.bw.Write(http2clientPreface) cc.fr.WriteSettings(initialSettings...) cc.fr.WriteWindowUpdate(0, http2transportDefaultConnFlow) - cc.inflow.add(http2transportDefaultConnFlow + http2initialWindowSize) + cc.inflow.init(http2transportDefaultConnFlow + http2initialWindowSize) cc.bw.Flush() if cc.werr != nil { cc.Close() @@ -8290,6 +8354,29 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { return res, nil } + cancelRequest := func(cs *http2clientStream, err error) error { + cs.cc.mu.Lock() + bodyClosed := cs.reqBodyClosed + cs.cc.mu.Unlock() + // Wait for the request body to be closed. + // + // If nothing closed the body before now, abortStreamLocked + // will have started a goroutine to close it. + // + // Closing the body before returning avoids a race condition + // with net/http checking its readTrackingBody to see if the + // body was read from or closed. See golang/go#60041. + // + // The body is closed in a separate goroutine without the + // connection mutex held, but dropping the mutex before waiting + // will keep us from holding it indefinitely if the body + // close is slow for some reason. + if bodyClosed != nil { + <-bodyClosed + } + return err + } + for { select { case <-cs.respHeaderRecv: @@ -8309,10 +8396,10 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { case <-ctx.Done(): err := ctx.Err() cs.abortStream(err) - return nil, err + return nil, cancelRequest(cs, err) case <-cs.reqCancel: cs.abortStream(http2errRequestCanceled) - return nil, http2errRequestCanceled + return nil, cancelRequest(cs, http2errRequestCanceled) } } } @@ -8594,7 +8681,7 @@ func (cs *http2clientStream) cleanupWriteRequest(err error) { close(cs.donec) } -// awaitOpenSlotForStream waits until len(streams) < maxConcurrentStreams. +// awaitOpenSlotForStreamLocked waits until len(streams) < maxConcurrentStreams. // Must hold cc.mu. 
func (cc *http2ClientConn) awaitOpenSlotForStreamLocked(cs *http2clientStream) error { for { @@ -8869,6 +8956,9 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail if err != nil { return nil, err } + if !httpguts.ValidHostHeader(host) { + return nil, errors.New("http2: invalid Host header") + } var path string if req.Method != "CONNECT" { @@ -8905,7 +8995,7 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail // 8.1.2.3 Request Pseudo-Header Fields // The :path pseudo-header field includes the path and query parts of the // target URI (the path-absolute production and optionally a '?' character - // followed by the query production (see Sections 3.3 and 3.4 of + // followed by the query production, see Sections 3.3 and 3.4 of // [RFC3986]). f(":authority", host) m := req.Method @@ -9094,8 +9184,7 @@ type http2resAndError struct { func (cc *http2ClientConn) addStreamLocked(cs *http2clientStream) { cs.flow.add(int32(cc.initialWindowSize)) cs.flow.setConnFlow(&cc.flow) - cs.inflow.add(http2transportDefaultStreamFlow) - cs.inflow.setConnFlow(&cc.inflow) + cs.inflow.init(http2transportDefaultStreamFlow) cs.ID = cc.nextStreamID cc.nextStreamID += 2 cc.streams[cs.ID] = cs @@ -9554,21 +9643,10 @@ func (b http2transportResponseBody) Read(p []byte) (n int, err error) { } cc.mu.Lock() - var connAdd, streamAdd int32 - // Check the conn-level first, before the stream-level. - if v := cc.inflow.available(); v < http2transportDefaultConnFlow/2 { - connAdd = http2transportDefaultConnFlow - v - cc.inflow.add(connAdd) - } + connAdd := cc.inflow.add(n) + var streamAdd int32 if err == nil { // No need to refresh if the stream is over or failed. - // Consider any buffered body data (read from the conn but not - // consumed by the client) when computing flow control for this - // stream. - v := int(cs.inflow.available()) + cs.bufPipe.Len() - if v < http2transportDefaultStreamFlow-http2transportDefaultStreamMinRefresh { - streamAdd = int32(http2transportDefaultStreamFlow - v) - cs.inflow.add(streamAdd) - } + streamAdd = cs.inflow.add(n) } cc.mu.Unlock() @@ -9592,29 +9670,27 @@ func (b http2transportResponseBody) Close() error { cs := b.cs cc := cs.cc + cs.bufPipe.BreakWithError(http2errClosedResponseBody) + cs.abortStream(http2errClosedResponseBody) + unread := cs.bufPipe.Len() if unread > 0 { cc.mu.Lock() // Return connection-level flow control. - if unread > 0 { - cc.inflow.add(int32(unread)) - } + connAdd := cc.inflow.add(unread) cc.mu.Unlock() // TODO(dneil): Acquiring this mutex can block indefinitely. // Move flow control return to a goroutine? cc.wmu.Lock() // Return connection-level flow control. 
- if unread > 0 { - cc.fr.WriteWindowUpdate(0, uint32(unread)) + if connAdd > 0 { + cc.fr.WriteWindowUpdate(0, uint32(connAdd)) } cc.bw.Flush() cc.wmu.Unlock() } - cs.bufPipe.BreakWithError(http2errClosedResponseBody) - cs.abortStream(http2errClosedResponseBody) - select { case <-cs.donec: case <-cs.ctx.Done(): @@ -9649,13 +9725,18 @@ func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { // But at least return their flow control: if f.Length > 0 { cc.mu.Lock() - cc.inflow.add(int32(f.Length)) + ok := cc.inflow.take(f.Length) + connAdd := cc.inflow.add(int(f.Length)) cc.mu.Unlock() - - cc.wmu.Lock() - cc.fr.WriteWindowUpdate(0, uint32(f.Length)) - cc.bw.Flush() - cc.wmu.Unlock() + if !ok { + return http2ConnectionError(http2ErrCodeFlowControl) + } + if connAdd > 0 { + cc.wmu.Lock() + cc.fr.WriteWindowUpdate(0, uint32(connAdd)) + cc.bw.Flush() + cc.wmu.Unlock() + } } return nil } @@ -9686,9 +9767,7 @@ func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { } // Check connection-level flow control. cc.mu.Lock() - if cs.inflow.available() >= int32(f.Length) { - cs.inflow.take(int32(f.Length)) - } else { + if !http2takeInflows(&cc.inflow, &cs.inflow, f.Length) { cc.mu.Unlock() return http2ConnectionError(http2ErrCodeFlowControl) } @@ -9710,19 +9789,20 @@ func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { } } - if refund > 0 { - cc.inflow.add(int32(refund)) - if !didReset { - cs.inflow.add(int32(refund)) - } + sendConn := cc.inflow.add(refund) + var sendStream int32 + if !didReset { + sendStream = cs.inflow.add(refund) } cc.mu.Unlock() - if refund > 0 { + if sendConn > 0 || sendStream > 0 { cc.wmu.Lock() - cc.fr.WriteWindowUpdate(0, uint32(refund)) - if !didReset { - cc.fr.WriteWindowUpdate(cs.ID, uint32(refund)) + if sendConn > 0 { + cc.fr.WriteWindowUpdate(0, uint32(sendConn)) + } + if sendStream > 0 { + cc.fr.WriteWindowUpdate(cs.ID, uint32(sendStream)) } cc.bw.Flush() cc.wmu.Unlock() @@ -10723,7 +10803,8 @@ func (wr *http2FrameWriteRequest) replyToWriter(err error) { // writeQueue is used by implementations of WriteScheduler. type http2writeQueue struct { - s []http2FrameWriteRequest + s []http2FrameWriteRequest + prev, next *http2writeQueue } func (q *http2writeQueue) empty() bool { return len(q.s) == 0 } @@ -11301,3 +11382,112 @@ func (ws *http2randomWriteScheduler) Pop() (http2FrameWriteRequest, bool) { } return http2FrameWriteRequest{}, false } + +type http2roundRobinWriteScheduler struct { + // control contains control frames (SETTINGS, PING, etc.). + control http2writeQueue + + // streams maps stream ID to a queue. + streams map[uint32]*http2writeQueue + + // stream queues are stored in a circular linked list. + // head is the next stream to write, or nil if there are no streams open. + head *http2writeQueue + + // pool of empty queues for reuse. + queuePool http2writeQueuePool +} + +// newRoundRobinWriteScheduler constructs a new write scheduler. +// The round robin scheduler priorizes control frames +// like SETTINGS and PING over DATA frames. +// When there are no control frames to send, it performs a round-robin +// selection from the ready streams. 
+func http2newRoundRobinWriteScheduler() http2WriteScheduler { + ws := &http2roundRobinWriteScheduler{ + streams: make(map[uint32]*http2writeQueue), + } + return ws +} + +func (ws *http2roundRobinWriteScheduler) OpenStream(streamID uint32, options http2OpenStreamOptions) { + if ws.streams[streamID] != nil { + panic(fmt.Errorf("stream %d already opened", streamID)) + } + q := ws.queuePool.get() + ws.streams[streamID] = q + if ws.head == nil { + ws.head = q + q.next = q + q.prev = q + } else { + // Queues are stored in a ring. + // Insert the new stream before ws.head, putting it at the end of the list. + q.prev = ws.head.prev + q.next = ws.head + q.prev.next = q + q.next.prev = q + } +} + +func (ws *http2roundRobinWriteScheduler) CloseStream(streamID uint32) { + q := ws.streams[streamID] + if q == nil { + return + } + if q.next == q { + // This was the only open stream. + ws.head = nil + } else { + q.prev.next = q.next + q.next.prev = q.prev + if ws.head == q { + ws.head = q.next + } + } + delete(ws.streams, streamID) + ws.queuePool.put(q) +} + +func (ws *http2roundRobinWriteScheduler) AdjustStream(streamID uint32, priority http2PriorityParam) {} + +func (ws *http2roundRobinWriteScheduler) Push(wr http2FrameWriteRequest) { + if wr.isControl() { + ws.control.push(wr) + return + } + q := ws.streams[wr.StreamID()] + if q == nil { + // This is a closed stream. + // wr should not be a HEADERS or DATA frame. + // We push the request onto the control queue. + if wr.DataSize() > 0 { + panic("add DATA on non-open stream") + } + ws.control.push(wr) + return + } + q.push(wr) +} + +func (ws *http2roundRobinWriteScheduler) Pop() (http2FrameWriteRequest, bool) { + // Control and RST_STREAM frames first. + if !ws.control.empty() { + return ws.control.shift(), true + } + if ws.head == nil { + return http2FrameWriteRequest{}, false + } + q := ws.head + for { + if wr, ok := q.consume(math.MaxInt32); ok { + ws.head = q.next + return wr, true + } + q = q.next + if q == ws.head { + break + } + } + return http2FrameWriteRequest{}, false +} diff --git a/http.go b/http.go index 101799f5..9b81654f 100644 --- a/http.go +++ b/http.go @@ -86,14 +86,20 @@ func hexEscapeNonASCII(s string) string { return s } b := make([]byte, 0, newLen) + var pos int for i := 0; i < len(s); i++ { if s[i] >= utf8.RuneSelf { + if pos < i { + b = append(b, s[pos:i]...) + } b = append(b, '%') b = strconv.AppendInt(b, int64(s[i]), 16) - } else { - b = append(b, s[i]) + pos = i + 1 } } + if pos < len(s) { + b = append(b, s[pos:]...) + } return string(b) } diff --git a/http_test.go b/http_test.go index 9f4b50e6..b2c4bd0e 100644 --- a/http_test.go +++ b/http_test.go @@ -193,3 +193,13 @@ func TestNoUnicodeStrings(t *testing.T) { t.Fatal(err) } } + +const redirectURL = "/thisaredirect细雪withasciilettersのけぶabcdefghijk.html" + +func BenchmarkHexEscapeNonASCII(b *testing.B) { + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + hexEscapeNonASCII(redirectURL) + } +} diff --git a/httputil/reverseproxy.go b/httputil/reverseproxy.go index 3e239f60..f5b31397 100644 --- a/httputil/reverseproxy.go +++ b/httputil/reverseproxy.go @@ -257,7 +257,7 @@ func joinURLPath(a, b *url.URL) (path, rawpath string) { // Rewrite: func(r *ProxyRequest) { // r.SetURL(target) // r.Out.Host = r.In.Host // if desired -// } +// }, // } func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy { director := func(req *http.Request) { @@ -524,9 +524,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { // Force chunking if we saw a response trailer. 
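// The hexEscapeNonASCII change above batches appends: instead of copying one byte
// at a time, it tracks the start of the current unescaped run and flushes the run
// in a single append when it hits a non-ASCII byte. A minimal standalone version
// of that pattern follows; the function name and sample string are illustrative.
package main

import (
	"fmt"
	"strconv"
	"unicode/utf8"
)

// escapeNonASCII percent-escapes bytes >= 0x80, copying runs of plain ASCII
// with one append per run rather than one per byte.
func escapeNonASCII(s string) string {
	b := make([]byte, 0, len(s))
	pos := 0 // start of the pending unescaped run
	for i := 0; i < len(s); i++ {
		if s[i] >= utf8.RuneSelf {
			b = append(b, s[pos:i]...) // flush the run before this byte
			b = append(b, '%')
			b = strconv.AppendInt(b, int64(s[i]), 16)
			pos = i + 1
		}
	}
	b = append(b, s[pos:]...) // flush the trailing run
	return string(b)
}

func main() {
	fmt.Println(escapeNonASCII("/redirect-细.html"))
}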
// This prevents net/http from calculating the length for short // bodies and adding a Content-Length. - if fl, ok := rw.(http.Flusher); ok { - fl.Flush() - } + http.NewResponseController(rw).Flush() } if len(res.Trailer) == announcedTrailers { @@ -601,21 +599,22 @@ func (p *ReverseProxy) flushInterval(res *http.Response) time.Duration { return p.FlushInterval } -func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error { +func (p *ReverseProxy) copyResponse(dst http.ResponseWriter, src io.Reader, flushInterval time.Duration) error { + var w io.Writer = dst + if flushInterval != 0 { - if wf, ok := dst.(writeFlusher); ok { - mlw := &maxLatencyWriter{ - dst: wf, - latency: flushInterval, - } - defer mlw.stop() + mlw := &maxLatencyWriter{ + dst: dst, + flush: http.NewResponseController(dst).Flush, + latency: flushInterval, + } + defer mlw.stop() - // set up initial timer so headers get flushed even if body writes are delayed - mlw.flushPending = true - mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush) + // set up initial timer so headers get flushed even if body writes are delayed + mlw.flushPending = true + mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush) - dst = mlw - } + w = mlw } var buf []byte @@ -623,7 +622,7 @@ func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval buf = p.BufferPool.Get() defer p.BufferPool.Put(buf) } - _, err := p.copyBuffer(dst, src, buf) + _, err := p.copyBuffer(w, src, buf) return err } @@ -668,13 +667,9 @@ func (p *ReverseProxy) logf(format string, args ...any) { } } -type writeFlusher interface { - io.Writer - http.Flusher -} - type maxLatencyWriter struct { - dst writeFlusher + dst io.Writer + flush func() error latency time.Duration // non-zero; negative means to flush immediately mu sync.Mutex // protects t, flushPending, and dst.Flush @@ -687,7 +682,7 @@ func (m *maxLatencyWriter) Write(p []byte) (n int, err error) { defer m.mu.Unlock() n, err = m.dst.Write(p) if m.latency < 0 { - m.dst.Flush() + m.flush() return } if m.flushPending { @@ -708,7 +703,7 @@ func (m *maxLatencyWriter) delayedFlush() { if !m.flushPending { // if stop was called but AfterFunc already started this goroutine return } - m.dst.Flush() + m.flush() m.flushPending = false } @@ -739,17 +734,19 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R return } - hj, ok := rw.(http.Hijacker) - if !ok { - p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)) - return - } backConn, ok := res.Body.(io.ReadWriteCloser) if !ok { p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body")) return } + rc := http.NewResponseController(rw) + conn, brw, hijackErr := rc.Hijack() + if errors.Is(hijackErr, http.ErrNotSupported) { + p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)) + return + } + backConnCloseCh := make(chan bool) go func() { // Ensure that the cancellation of a request closes the backend. 
@@ -760,12 +757,10 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R } backConn.Close() }() - defer close(backConnCloseCh) - conn, brw, err := hj.Hijack() - if err != nil { - p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err)) + if hijackErr != nil { + p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", hijackErr)) return } defer conn.Close() diff --git a/httputil/reverseproxy_test.go b/httputil/reverseproxy_test.go index 29e0e8c6..b3951bf1 100644 --- a/httputil/reverseproxy_test.go +++ b/httputil/reverseproxy_test.go @@ -479,6 +479,62 @@ func TestReverseProxyFlushInterval(t *testing.T) { } } +type mockFlusher struct { + http.ResponseWriter + flushed bool +} + +func (m *mockFlusher) Flush() { + m.flushed = true +} + +type wrappedRW struct { + http.ResponseWriter +} + +func (w *wrappedRW) Unwrap() http.ResponseWriter { + return w.ResponseWriter +} + +func TestReverseProxyResponseControllerFlushInterval(t *testing.T) { + const expected = "hi" + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(expected)) + })) + defer backend.Close() + + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + mf := &mockFlusher{} + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.FlushInterval = -1 // flush immediately + proxyWithMiddleware := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mf.ResponseWriter = w + w = &wrappedRW{mf} + proxyHandler.ServeHTTP(w, r) + }) + + frontend := httptest.NewServer(proxyWithMiddleware) + defer frontend.Close() + + req, _ := http.NewRequest("GET", frontend.URL, nil) + req.Close = true + res, err := frontend.Client().Do(req) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected { + t.Errorf("got body %q; expected %q", bodyBytes, expected) + } + if !mf.flushed { + t.Errorf("response writer was not flushed") + } +} + func TestReverseProxyFlushIntervalHeaders(t *testing.T) { const expected = "hi" stopCh := make(chan struct{}) diff --git a/internal/chunked.go b/internal/chunked.go index 8b6e94b5..aad8e5aa 100644 --- a/internal/chunked.go +++ b/internal/chunked.go @@ -76,9 +76,7 @@ func (cr *chunkedReader) beginChunk() { // Currently, we say that we're willing to accept 16 bytes of overhead per chunk, // plus twice the amount of real data in the chunk. cr.excess -= 16 + (2 * int64(cr.n)) - if cr.excess < 0 { - cr.excess = 0 - } + cr.excess = max(cr.excess, 0) if cr.excess > 16*1024 { cr.err = errors.New("chunked encoding contains too much non-data") } diff --git a/internal/testenv/testenv.go b/internal/testenv/testenv.go index 32df8d8a..1b66ef5f 100644 --- a/internal/testenv/testenv.go +++ b/internal/testenv/testenv.go @@ -2,7 +2,11 @@ // the amount of deleted line with respect to upstream. package testenv -import "testing" +import ( + "context" + "os/exec" + "testing" +) // MustHaveExec always skips the current test. 
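// The reverse-proxy changes above stop type-asserting for http.Flusher and
// http.Hijacker and go through http.ResponseController instead, which can see
// through middleware wrappers that expose an Unwrap method (as wrappedRW does in
// the test). A small sketch of that convention, with hypothetical names:
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

// loggingWriter wraps a ResponseWriter without re-implementing Flush or Hijack.
// Exposing Unwrap lets http.ResponseController reach those methods on the
// underlying writer through the wrapper.
type loggingWriter struct {
	http.ResponseWriter
	wrote int
}

func (w *loggingWriter) Write(p []byte) (int, error) {
	n, err := w.ResponseWriter.Write(p)
	w.wrote += n
	return n, err
}

func (w *loggingWriter) Unwrap() http.ResponseWriter { return w.ResponseWriter }

func handler(w http.ResponseWriter, r *http.Request) {
	lw := &loggingWriter{ResponseWriter: w}
	fmt.Fprint(lw, "hello")
	// Flush is found via Unwrap; it returns an error matching
	// http.ErrNotSupported only if nothing underneath can flush.
	if err := http.NewResponseController(lw).Flush(); err != nil {
		fmt.Println("flush:", err)
	}
}

func main() {
	rec := httptest.NewRecorder()
	handler(rec, httptest.NewRequest("GET", "/", nil))
	fmt.Println(rec.Flushed, rec.Body.String()) // true hello
}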
func MustHaveExec(t testing.TB) { @@ -28,3 +32,8 @@ func GoToolPath(t testing.TB) string { func Builder() string { return "" } + +func CommandContext(t testing.TB, ctx context.Context, name string, args ...string) *exec.Cmd { + t.Skip("testenv.CommandContext is not enabled in this fork") + return &exec.Cmd{} +} diff --git a/main_test.go b/main_test.go index fe4e05a9..ff5a0d6d 100644 --- a/main_test.go +++ b/main_test.go @@ -21,6 +21,7 @@ import ( var quietLog = log.New(io.Discard, "", 0) func TestMain(m *testing.M) { + *http.MaxWriteWaitBeforeConnReuse = 60 * time.Minute v := m.Run() if v == 0 && goroutineLeaked() { os.Exit(1) @@ -108,11 +109,30 @@ func runningBenchmarks() bool { return false } +var leakReported bool + func afterTest(t testing.TB) { http.DefaultTransport.(*http.Transport).CloseIdleConnections() if testing.Short() { return } + if leakReported { + // To avoid confusion, only report the first leak of each test run. + // After the first leak has been reported, we can't tell whether the leaked + // goroutines are a new leak from a subsequent test or just the same + // goroutines from the first leak still hanging around, and we may add a lot + // of latency waiting for them to exit at the end of each test. + return + } + + // We shouldn't be running the leak check for parallel tests, because we might + // report the goroutines from a test that is still running as a leak from a + // completely separate test that has just finished. So we use non-atomic loads + // and stores for the leakReported variable, and store every time we start a + // leak check so that the race detector will flag concurrent leak checks as a + // race even if we don't detect any leaks. + leakReported = true + var bad string badSubstring := map[string]string{ ").readLoop(": "a Transport", @@ -132,6 +152,7 @@ func afterTest(t testing.TB) { } } if bad == "" { + leakReported = false return } // Bad stuff found, but goroutines might just still be @@ -141,29 +162,15 @@ func afterTest(t testing.TB) { t.Errorf("Test appears to have leaked %s:\n%s", bad, stacks) } -// waitCondition reports whether fn eventually returned true, -// checking immediately and then every checkEvery amount, -// until waitFor has elapsed, at which point it returns false. -func waitCondition(waitFor, checkEvery time.Duration, fn func() bool) bool { - deadline := time.Now().Add(waitFor) - for time.Now().Before(deadline) { - if fn() { - return true - } - time.Sleep(checkEvery) - } - return false -} - -// waitErrCondition is like waitCondition but with errors instead of bools. -func waitErrCondition(waitFor, checkEvery time.Duration, fn func() error) error { - deadline := time.Now().Add(waitFor) - var err error - for time.Now().Before(deadline) { - if err = fn(); err == nil { - return nil - } - time.Sleep(checkEvery) +// waitCondition waits for fn to return true, +// checking immediately and then at exponentially increasing intervals. 
+func waitCondition(t testing.TB, delay time.Duration, fn func(time.Duration) bool) { + t.Helper() + start := time.Now() + var since time.Duration + for !fn(since) { + time.Sleep(delay) + delay = 2*delay - (delay / 2) // 1.5x, rounded up + since = time.Since(start) } - return err } diff --git a/omithttp2.go b/omithttp2.go index 3316f55c..ca08ddfa 100644 --- a/omithttp2.go +++ b/omithttp2.go @@ -42,9 +42,17 @@ type http2noDialClientConnPool struct { type http2clientConnPool struct { mu *sync.Mutex - conns map[string][]struct{} + conns map[string][]*http2clientConn } +type http2clientConn struct{} + +type http2clientConnIdleState struct { + canTakeNewRequest bool +} + +func (cc *http2clientConn) idleState() http2clientConnIdleState { return http2clientConnIdleState{} } + func http2configureTransports(*Transport) (*http2Transport, error) { panic(noHTTP2) } func http2isNoCachedConnError(err error) bool { diff --git a/request.go b/request.go index 167e0163..bf56df33 100644 --- a/request.go +++ b/request.go @@ -48,6 +48,11 @@ type ProtocolError struct { func (pe *ProtocolError) Error() string { return pe.ErrorString } +// Is lets http.ErrNotSupported match errors.ErrUnsupported. +func (pe *ProtocolError) Is(err error) bool { + return pe == ErrNotSupported && err == errors.ErrUnsupported +} + var ( // ErrNotSupported indicates that a feature is not supported. // diff --git a/request_test.go b/request_test.go index 39c8c951..b53f0477 100644 --- a/request_test.go +++ b/request_test.go @@ -10,6 +10,7 @@ import ( "context" "crypto/rand" "encoding/base64" + "errors" "fmt" "io" "math" @@ -32,7 +33,7 @@ func TestQuery(t *testing.T) { } } -// Issue #25192: Test that ParseForm fails but still parses the form when an URL +// Issue #25192: Test that ParseForm fails but still parses the form when a URL // containing a semicolon is provided. func TestParseFormSemicolonSeparator(t *testing.T) { for _, method := range []string{"POST", "PATCH", "PUT", "GET"} { @@ -380,7 +381,7 @@ func TestMultipartRequest(t *testing.T) { } // Issue #25192: Test that ParseMultipartForm fails but still parses the -// multi-part form when an URL containing a semicolon is provided. +// multi-part form when a URL containing a semicolon is provided. func TestParseMultipartFormSemicolonSeparator(t *testing.T) { req := newTestMultipartRequest(t) req.URL = &url.URL{RawQuery: "q=foo;q=bar"} @@ -1389,3 +1390,9 @@ func runFileAndServerBenchmarks(b *testing.B, mode testMode, f *os.File, n int64 b.SetBytes(n) } } + +func TestErrNotSupported(t *testing.T) { + if !errors.Is(ErrNotSupported, errors.ErrUnsupported) { + t.Error("errors.Is(ErrNotSupported, errors.ErrUnsupported) failed") + } +} diff --git a/responsecontroller.go b/responsecontroller.go index 018bdc00..92276ffa 100644 --- a/responsecontroller.go +++ b/responsecontroller.go @@ -31,6 +31,7 @@ type ResponseController struct { // Hijack() (net.Conn, *bufio.ReadWriter, error) // SetReadDeadline(deadline time.Time) error // SetWriteDeadline(deadline time.Time) error +// EnableFullDuplex() error // // If the ResponseWriter does not support a method, ResponseController returns // an error matching ErrNotSupported. @@ -115,6 +116,30 @@ func (c *ResponseController) SetWriteDeadline(deadline time.Time) error { } } +// EnableFullDuplex indicates that the request handler will interleave reads from Request.Body +// with writes to the ResponseWriter. 
+// +// For HTTP/1 requests, the Go HTTP server by default consumes any unread portion of +// the request body before beginning to write the response, preventing handlers from +// concurrently reading from the request and writing the response. +// Calling EnableFullDuplex disables this behavior and permits handlers to continue to read +// from the request while concurrently writing the response. +// +// For HTTP/2 requests, the Go HTTP server always permits concurrent reads and responses. +func (c *ResponseController) EnableFullDuplex() error { + rw := c.rw + for { + switch t := rw.(type) { + case interface{ EnableFullDuplex() error }: + return t.EnableFullDuplex() + case rwUnwrapper: + rw = t.Unwrap() + default: + return errNotSupported() + } + } +} + // errNotSupported returns an error that Is ErrNotSupported, // but is not == to it. func errNotSupported() error { diff --git a/responsecontroller_test.go b/responsecontroller_test.go index 014737e7..1428fd70 100644 --- a/responsecontroller_test.go +++ b/responsecontroller_test.go @@ -157,7 +157,9 @@ func TestResponseControllerSetPastReadDeadline(t *testing.T) { } func testResponseControllerSetPastReadDeadline(t *testing.T, mode testMode) { readc := make(chan struct{}) + donec := make(chan struct{}) cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + defer close(donec) ctl := NewResponseController(w) b := make([]byte, 3) n, err := io.ReadFull(r.Body, b) @@ -193,10 +195,19 @@ func testResponseControllerSetPastReadDeadline(t *testing.T, mode testMode) { wg.Add(1) go func() { defer wg.Done() + defer pw.Close() pw.Write([]byte("one")) - <-readc + select { + case <-readc: + case <-donec: + select { + case <-readc: + default: + t.Errorf("server handler unexpectedly exited without closing readc") + return + } + } pw.Write([]byte("two")) - pw.Close() }() defer wg.Wait() res, err := cst.c.Post(cst.ts.URL, "text/foo", pr) @@ -264,3 +275,51 @@ func testWrappedResponseController(t *testing.T, mode testMode) { io.Copy(io.Discard, res.Body) defer res.Body.Close() } + +func TestResponseControllerEnableFullDuplex(t *testing.T) { + run(t, testResponseControllerEnableFullDuplex) +} +func testResponseControllerEnableFullDuplex(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) { + ctl := NewResponseController(w) + if err := ctl.EnableFullDuplex(); err != nil { + // TODO: Drop test for HTTP/2 when x/net is updated to support + // EnableFullDuplex. Since HTTP/2 supports full duplex by default, + // the rest of the test is fine; it's just the EnableFullDuplex call + // that fails. 
+ if mode != http2Mode { + t.Errorf("ctl.EnableFullDuplex() = %v, want nil", err) + } + } + w.WriteHeader(200) + ctl.Flush() + for { + var buf [1]byte + n, err := req.Body.Read(buf[:]) + if n != 1 || err != nil { + break + } + w.Write(buf[:]) + ctl.Flush() + } + })) + pr, pw := io.Pipe() + res, err := cst.c.Post(cst.ts.URL, "text/apocryphal", pr) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + for i := byte(0); i < 10; i++ { + if _, err := pw.Write([]byte{i}); err != nil { + t.Fatalf("Write: %v", err) + } + var buf [1]byte + if n, err := res.Body.Read(buf[:]); n != 1 || err != nil { + t.Fatalf("Read: %v, %v", n, err) + } + if buf[0] != i { + t.Fatalf("read byte %v, want %v", buf[0], i) + } + } + pw.Close() +} diff --git a/roundtrip.go b/roundtrip.go index c4c5d3b6..49ea1a71 100644 --- a/roundtrip.go +++ b/roundtrip.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build !js || !wasm +//go:build !js package http diff --git a/roundtrip_js.go b/roundtrip_js.go index 01c0600b..9f9f0cb6 100644 --- a/roundtrip_js.go +++ b/roundtrip_js.go @@ -11,6 +11,7 @@ import ( "fmt" "io" "strconv" + "strings" "syscall/js" ) @@ -44,6 +45,16 @@ const jsFetchRedirect = "js.fetch:redirect" // the browser globals. var jsFetchMissing = js.Global().Get("fetch").IsUndefined() +// jsFetchDisabled controls whether the use of Fetch API is disabled. +// It's set to true when we detect we're running in Node.js, so that +// RoundTrip ends up talking over the same fake network the HTTP servers +// currently use in various tests and examples. See go.dev/issue/57613. +// +// TODO(go.dev/issue/60810): See if it's viable to test the Fetch API +// code path. +var jsFetchDisabled = js.Global().Get("process").Type() == js.TypeObject && + strings.HasPrefix(js.Global().Get("process").Get("argv0").String(), "node") + // RoundTrip implements the RoundTripper interface using the WHATWG Fetch API. func (t *Transport) RoundTrip(req *Request) (*Response, error) { // The Transport has a documented contract that states that if the DialContext or @@ -52,7 +63,7 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { // though they are deprecated. Therefore, if any of these are set, we should obey // the contract and dial using the regular round-trip instead. Otherwise, we'll try // to fall back on the Fetch API, unless it's not available. - if t.Dial != nil || t.DialContext != nil || t.DialTLS != nil || t.DialTLSContext != nil || jsFetchMissing { + if t.Dial != nil || t.DialContext != nil || t.DialTLS != nil || t.DialTLSContext != nil || jsFetchMissing || jsFetchDisabled { return t.roundTrip(req) } @@ -99,6 +110,8 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { // See https://github.com/web-platform-tests/wpt/issues/7693 for WHATWG tests issue. // See https://developer.mozilla.org/en-US/docs/Web/API/Streams_API for more details on the Streams API // and browser support. + // NOTE(haruyama480): Ensure HTTP/1 fallback exists. + // See https://go.dev/issue/61889 for discussion. body, err := io.ReadAll(req.Body) if err != nil { req.Body.Close() // RoundTrip must always close the body, including on errors. 
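// The EnableFullDuplex additions above let an HTTP/1 handler keep reading the
// request body while it writes the response, instead of the default behavior of
// draining the body before the first write. A hypothetical echo handler using
// the new method (assuming Go 1.21's ResponseController.EnableFullDuplex):
package main

import (
	"log"
	"net/http"
)

// echo streams the request body back to the client as it arrives. Without
// EnableFullDuplex, a client that writes and reads concurrently over HTTP/1
// could deadlock against the server's pre-write body drain.
func echo(w http.ResponseWriter, r *http.Request) {
	ctl := http.NewResponseController(w)
	if err := ctl.EnableFullDuplex(); err != nil {
		// For example, an HTTP/2 request, where concurrent reads and
		// writes are already permitted.
		log.Printf("EnableFullDuplex: %v", err)
	}
	buf := make([]byte, 4096)
	for {
		n, err := r.Body.Read(buf)
		if n > 0 {
			w.Write(buf[:n])
			ctl.Flush() // push each chunk to the client immediately
		}
		if err != nil {
			return
		}
	}
}

func main() {
	log.Fatal(http.ListenAndServe("127.0.0.1:8080", http.HandlerFunc(echo)))
}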
@@ -185,7 +198,23 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { failure = js.FuncOf(func(this js.Value, args []js.Value) any { success.Release() failure.Release() - errCh <- fmt.Errorf("net/http: fetch() failed: %s", args[0].Get("message").String()) + + err := args[0] + // The error is a JS Error type + // https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Error + // We can use the toString() method to get a string representation of the error. + errMsg := err.Call("toString").String() + // Errors can optionally contain a cause. + if cause := err.Get("cause"); !cause.IsUndefined() { + // The exact type of the cause is not defined, + // but if it's another error, we can call toString() on it too. + if !cause.Get("toString").IsUndefined() { + errMsg += ": " + cause.Call("toString").String() + } else if cause.Type() == js.TypeString { + errMsg += ": " + cause.String() + } + } + errCh <- fmt.Errorf("net/http: fetch() failed: %s", errMsg) return nil }) diff --git a/serve_test.go b/serve_test.go index 622def57..8b98db53 100644 --- a/serve_test.go +++ b/serve_test.go @@ -27,7 +27,6 @@ import ( "reflect" "regexp" "runtime" - "runtime/debug" "strconv" "strings" "sync" @@ -827,15 +826,7 @@ func testWriteDeadlineExtendedOnNewRequest(t *testing.T, mode testMode) { t.Fatal(err) } - // fail test if no response after 1 second - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - defer cancel() - req = req.WithContext(ctx) - r, err := c.Do(req) - if ctx.Err() == context.DeadlineExceeded { - t.Fatalf("http2 Get #%d response timed out", i) - } if err != nil { t.Fatalf("http2 Get #%d: %v", i, err) } @@ -988,25 +979,19 @@ func testOnlyWriteTimeout(t *testing.T, mode testMode) { c := ts.Client() - errc := make(chan error, 1) - go func() { + err := func() error { res, err := c.Get(ts.URL) if err != nil { - errc <- err - return + return err } _, err = io.Copy(io.Discard, res.Body) res.Body.Close() - errc <- err + return err }() - select { - case err := <-errc: - if err == nil { - t.Errorf("expected an error from Get request") - } - case <-time.After(10 * time.Second): - t.Fatal("timeout waiting for Get error") + if err == nil { + t.Errorf("expected an error copying body from Get request") } + if err := <-afterTimeoutErrc; err == nil { t.Error("expected write error after timeout") } @@ -1133,21 +1118,10 @@ func testTCPConnectionCloses(t *testing.T, req string, h Handler) { t.Fatal("ReadResponse error:", err) } - didReadAll := make(chan bool, 1) - go func() { - select { - case <-time.After(5 * time.Second): - t.Error("body not closed after 5s") - return - case <-didReadAll: - } - }() - _, err = io.ReadAll(r) if err != nil { t.Fatal("read error:", err) } - didReadAll <- true if !res.Close { t.Errorf("Response.Close = false; want true") @@ -1323,7 +1297,6 @@ func testServerAllowsBlockingRemoteAddr(t *testing.T, mode testMode) { }).ts c := ts.Client() - c.Timeout = time.Second // Force separate connection for each: c.Transport.(*Transport).DisableKeepAlives = true @@ -1354,13 +1327,7 @@ func testServerAllowsBlockingRemoteAddr(t *testing.T, mode testMode) { // Start another request and grab its connection response2c := make(chan string, 1) go fetch(2, response2c) - var conn2 net.Conn - - select { - case conn2 = <-conns: - case <-time.After(time.Second): - t.Fatal("Second Accept didn't happen") - } + conn2 := <-conns // Send a response on connection 2. 
conn2.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{ @@ -1444,13 +1411,9 @@ func testTLSHandshakeTimeout(t *testing.T, mode testMode) { t.Errorf("Read = %d, %v; want an error and no bytes", n, err) } - select { - case v := <-errc: - if !strings.Contains(v, "timeout") && !strings.Contains(v, "TLS handshake") { - t.Errorf("expected a TLS handshake timeout error; got %q", v) - } - case <-time.After(5 * time.Second): - t.Errorf("timeout waiting for logged error") + v := <-errc + if !strings.Contains(v, "timeout") && !strings.Contains(v, "TLS handshake") { + t.Errorf("expected a TLS handshake timeout error; got %q", v) } } @@ -1536,8 +1499,6 @@ func TestServeTLS(t *testing.T) { case err := <-errc: t.Fatalf("ServeTLS: %v", err) case <-serving: - case <-time.After(5 * time.Second): - t.Fatal("timeout") } c, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ @@ -1783,7 +1744,13 @@ func testServerExpect(t *testing.T, mode testMode) { // that doesn't send 100-continue expectations. writeBody := test.contentLength != 0 && strings.ToLower(test.expectation) != "100-continue" + wg := sync.WaitGroup{} + wg.Add(1) + defer wg.Wait() + go func() { + defer wg.Done() + contentLen := fmt.Sprintf("Content-Length: %d", test.contentLength) if test.chunked { contentLen = "Transfer-Encoding: chunked" @@ -2832,18 +2799,7 @@ func testHandlerPanic(t *testing.T, withHijack bool, mode testMode, wrapper func return } - var delay time.Duration - if deadline, ok := t.Deadline(); ok { - delay = time.Until(deadline) - } else { - delay = 5 * time.Second - } - select { - case <-done: - return - case <-time.After(delay): - t.Fatal("expected server handler to log an error") - } + <-done } type terrorWriter struct{ t *testing.T } @@ -2883,11 +2839,7 @@ func testServerWriteHijackZeroBytes(t *testing.T, mode testMode) { t.Fatal(err) } res.Body.Close() - select { - case <-done: - case <-time.After(5 * time.Second): - t.Fatal("timeout") - } + <-done } func TestServerNoDate(t *testing.T) { @@ -3062,7 +3014,10 @@ func testRequestBodyLimit(t *testing.T, mode testMode) { // // But that's okay, since what we're really testing is that // the remote side hung up on us before we wrote too much. 
- _, _ = cst.c.Do(req) + resp, err := cst.c.Do(req) + if err == nil { + resp.Body.Close() + } if atomic.LoadInt64(nWritten) > limit*100 { t.Errorf("handler restricted the request body to %d bytes, but client managed to write %d", @@ -3250,8 +3205,6 @@ For: diec <- true case <-sawClose: break For - case <-time.After(5 * time.Second): - t.Fatal("timeout") } } ts.Close() @@ -3307,9 +3260,6 @@ func testCloseNotifierPipelined(t *testing.T, mode testMode) { if closes > 1 { return } - case <-time.After(5 * time.Second): - ts.CloseClientConnections() - t.Fatal("timeout") } } } @@ -3420,12 +3370,8 @@ func testHijackBeforeRequestBodyRead(t *testing.T, mode testMode) { return } bodyOkay <- true - select { - case <-gone: - gotCloseNotify <- true - case <-time.After(5 * time.Second): - gotCloseNotify <- false - } + <-gone + gotCloseNotify <- true })).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) @@ -3441,9 +3387,7 @@ func testHijackBeforeRequestBodyRead(t *testing.T, mode testMode) { return } conn.Close() - if !<-gotCloseNotify { - t.Error("timeout waiting for CloseNotify") - } + <-gotCloseNotify } func TestOptions(t *testing.T) { run(t, testOptions, []testMode{http1Mode}) } @@ -3519,13 +3463,8 @@ func testOptionsHandler(t *testing.T, mode testMode) { t.Fatal(err) } - select { - case got := <-rc: - if got.Method != "OPTIONS" || got.RequestURI != "*" { - t.Errorf("Expected OPTIONS * request, got %v", got) - } - case <-time.After(5 * time.Second): - t.Error("timeout") + if got := <-rc; got.Method != "OPTIONS" || got.RequestURI != "*" { + t.Errorf("Expected OPTIONS * request, got %v", got) } } @@ -3997,8 +3936,6 @@ func testTransportAndServerSharedBodyRace(t *testing.T, mode testMode) { } <-unblockBackend })) - var quitTimer *time.Timer - defer func() { quitTimer.Stop() }() defer backend.close() backendRespc := make(chan *Response, 1) @@ -4031,20 +3968,6 @@ func testTransportAndServerSharedBodyRace(t *testing.T, mode testMode) { rw.Write([]byte("OK")) })) defer proxy.close() - defer func() { - // Before we shut down our two httptest.Servers, start a timer. - // We choose 7 seconds because httptest.Server starts logging - // warnings to stderr at 5 seconds. If we don't disarm this bomb - // in 7 seconds (after the two httptest.Server.Close calls above), - // then we explode with stacks. - quitTimer = time.AfterFunc(7*time.Second, func() { - debug.SetTraceback("ALL") - stacks := make([]byte, 1<<20) - stacks = stacks[:runtime.Stack(stacks, true)] - fmt.Fprintf(os.Stderr, "%s", stacks) - log.Fatalf("Timeout.") - }) - }() defer close(unblockBackend) req, _ := NewRequest("POST", proxy.ts.URL, io.LimitReader(neverEnding('a'), bodySize)) @@ -4110,8 +4033,6 @@ func testRequestBodyCloseDoesntBlock(t *testing.T, mode testMode) { } case err := <-errCh: t.Error(err) - case <-time.After(5 * time.Second): - t.Error("timeout") } } @@ -4188,22 +4109,7 @@ func testServerConnState(t *testing.T, mode testMode) { doRequests() - stateDelay := 5 * time.Second - if deadline, ok := t.Deadline(); ok { - // Allow an arbitrarily long delay. - // This test was observed to be flaky on the darwin-arm64-corellium builder, - // so we're increasing the deadline to see if it starts passing. - // See https://golang.org/issue/37322. 
- const arbitraryCleanupMargin = 1 * time.Second - stateDelay = time.Until(deadline) - arbitraryCleanupMargin - } - timer := time.NewTimer(stateDelay) - select { - case <-timer.C: - t.Errorf("Timed out after %v waiting for connection to change state.", stateDelay) - case <-complete: - timer.Stop() - } + <-complete sl := <-activeLog if !reflect.DeepEqual(sl.got, sl.want) { t.Errorf("Request(s) produced unexpected state sequence.\nGot: %v\nWant: %v", sl.got, sl.want) @@ -4492,8 +4398,6 @@ func testServerKeepAliveAfterWriteError(t *testing.T, mode testMode) { } }() - timeout := time.NewTimer(numReq * 2 * time.Second) // 4x overkill - defer timeout.Stop() addrSeen := map[string]bool{} numOkay := 0 for { @@ -4513,8 +4417,6 @@ func testServerKeepAliveAfterWriteError(t *testing.T, mode testMode) { if err == nil { numOkay++ } - case <-timeout.C: - t.Fatal("timeout waiting for requests to complete") } } } @@ -4948,15 +4850,11 @@ func testServerContext_LocalAddrContextKey(t *testing.T, mode testMode) { } host := cst.ts.Listener.Addr().String() - select { - case got := <-ch: - if addr, ok := got.(net.Addr); !ok { - t.Errorf("local addr value = %T; want net.Addr", got) - } else if fmt.Sprint(addr) != host { - t.Errorf("local addr = %v; want %v", addr, host) - } - case <-time.After(5 * time.Second): - t.Error("timed out") + got := <-ch + if addr, ok := got.(net.Addr); !ok { + t.Errorf("local addr value = %T; want net.Addr", got) + } else if fmt.Sprint(addr) != host { + t.Errorf("local addr = %v; want %v", addr, host) } } @@ -5163,8 +5061,9 @@ func BenchmarkClient(b *testing.B) { } // Start server process. - cmd := exec.Command(os.Args[0], "-test.run=XXXX", "-test.bench=BenchmarkClient$") - cmd.Env = append(os.Environ(), "TEST_BENCH_SERVER=yes") + ctx, cancel := context.WithCancel(context.Background()) + cmd := testenv.CommandContext(b, ctx, os.Args[0], "-test.run=XXXX", "-test.bench=BenchmarkClient$") + cmd.Env = append(cmd.Environ(), "TEST_BENCH_SERVER=yes") cmd.Stderr = os.Stderr stdout, err := cmd.StdoutPipe() if err != nil { @@ -5173,35 +5072,28 @@ func BenchmarkClient(b *testing.B) { if err := cmd.Start(); err != nil { b.Fatalf("subprocess failed to start: %v", err) } - defer cmd.Process.Kill() + + done := make(chan error, 1) + go func() { + done <- cmd.Wait() + close(done) + }() + defer func() { + cancel() + <-done + }() // Wait for the server in the child process to respond and tell us // its listening address, once it's started listening: - timer := time.AfterFunc(10*time.Second, func() { - cmd.Process.Kill() - }) - defer timer.Stop() bs := bufio.NewScanner(stdout) if !bs.Scan() { b.Fatalf("failed to read listening URL from child: %v", bs.Err()) } url := "http://" + strings.TrimSpace(bs.Text()) + "/" - timer.Stop() if _, err := getNoBody(url); err != nil { b.Fatalf("initial probe of child process failed: %v", err) } - done := make(chan error) - stop := make(chan struct{}) - defer close(stop) - go func() { - select { - case <-stop: - return - case done <- cmd.Wait(): - } - }() - // Do b.N requests to the server. b.StartTimer() for i := 0; i < b.N; i++ { @@ -5222,13 +5114,8 @@ func BenchmarkClient(b *testing.B) { // Instruct server process to stop. 
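// BenchmarkClient above now manages its child process with a cancellable context
// and a done channel instead of a kill timer. The standalone sketch below shows
// that shutdown pattern with plain os/exec; it assumes a Unix shell and uses a
// sleeping child as a stand-in for the real benchmark server.
package main

import (
	"bufio"
	"context"
	"fmt"
	"os/exec"
)

// runChild starts a helper process, reads the one line it prints (for example a
// listening address), and guarantees the child is stopped and reaped before
// returning: cancel() kills it and <-done waits for Wait to finish.
func runChild(ctx context.Context) error {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()
	cmd := exec.CommandContext(ctx, "sh", "-c", "echo 127.0.0.1:8080; sleep 60")
	stdout, err := cmd.StdoutPipe()
	if err != nil {
		return err
	}
	if err := cmd.Start(); err != nil {
		return err
	}
	done := make(chan error, 1)
	go func() { done <- cmd.Wait() }()
	defer func() {
		cancel() // ask the child to exit now rather than after the sleep...
		<-done   // ...and wait for it to be reaped
	}()

	bs := bufio.NewScanner(stdout)
	if !bs.Scan() {
		return fmt.Errorf("no address from child: %v", bs.Err())
	}
	fmt.Println("child listening on", bs.Text())
	return nil
}

func main() {
	if err := runChild(context.Background()); err != nil {
		fmt.Println(err)
	}
}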
getNoBody(url + "?stop=yes") - select { - case err := <-done: - if err != nil { - b.Fatalf("subprocess failed: %v", err) - } - case <-time.After(5 * time.Second): - b.Fatalf("subprocess did not stop") + if err := <-done; err != nil { + b.Fatalf("subprocess failed: %v", err) } } @@ -5438,8 +5325,6 @@ func benchmarkCloseNotifier(b *testing.B, mode testMode) { <-rw.(CloseNotifier).CloseNotify() sawClose <- true })).ts - tot := time.NewTimer(5 * time.Second) - defer tot.Stop() b.StartTimer() for i := 0; i < b.N; i++ { conn, err := net.Dial("tcp", ts.Listener.Addr().String()) @@ -5451,12 +5336,7 @@ func benchmarkCloseNotifier(b *testing.B, mode testMode) { b.Fatal(err) } conn.Close() - tot.Reset(5 * time.Second) - select { - case <-sawClose: - case <-tot.C: - b.Fatal("timeout") - } + <-sawClose } b.StopTimer() } @@ -5552,84 +5432,97 @@ func testServerSetKeepAlivesEnabledClosesConns(t *testing.T, mode testMode) { get := func() string { return get(t, c, ts.URL) } a1, a2 := get(), get() - if a1 != a2 { - t.Fatal("expected first two requests on same connection") + if a1 == a2 { + t.Logf("made two requests from a single conn %q (as expected)", a1) + } else { + t.Errorf("server reported requests from %q and %q; expected same connection", a1, a2) } - addr := strings.TrimPrefix(ts.URL, "http://") // The two requests should have used the same connection, // and there should not have been a second connection that // was created by racing dial against reuse. // (The first get was completed when the second get started.) - n := tr.IdleConnCountForTesting("http", addr) - if n != 1 { - t.Fatalf("idle count for %q after 2 gets = %d, want 1", addr, n) + if conns := tr.IdleConnStrsForTesting(); len(conns) != 1 { + t.Errorf("found %d idle conns (%q); want 1", len(conns), conns) } // SetKeepAlivesEnabled should discard idle conns. ts.Config.SetKeepAlivesEnabled(false) - var idle1 int - if !waitCondition(2*time.Second, 10*time.Millisecond, func() bool { - idle1 = tr.IdleConnCountForTesting("http", addr) - return idle1 == 0 - }) { - t.Fatalf("idle count after SetKeepAlivesEnabled called = %v; want 0", idle1) - } + waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool { + if conns := tr.IdleConnStrsForTesting(); len(conns) > 0 { + if d > 0 { + t.Logf("idle conns %v after SetKeepAlivesEnabled called = %q; waiting for empty", d, conns) + } + return false + } + return true + }) - a3 := get() - if a3 == a2 { - t.Fatal("expected third request on new connection") - } + // If we make a third request it should use a new connection, but in general + // we have no way to verify that: the new connection could happen to reuse the + // exact same ports from the previous connection. } func TestServerShutdown(t *testing.T) { run(t, testServerShutdown) } func testServerShutdown(t *testing.T, mode testMode) { - var doShutdown func() // set later - var doStateCount func() - var shutdownRes = make(chan error, 1) - var statesRes = make(chan map[ConnState]int, 1) - var gotOnShutdown = make(chan struct{}, 1) + var cst *clientServerTest + + var once sync.Once + statesRes := make(chan map[ConnState]int, 1) + shutdownRes := make(chan error, 1) + gotOnShutdown := make(chan struct{}) handler := HandlerFunc(func(w ResponseWriter, r *Request) { - doStateCount() - go doShutdown() - // Shutdown is graceful, so it should not interrupt - // this in-flight response. Add a tiny sleep here to - // increase the odds of a failure if shutdown has - // bugs. 
- time.Sleep(20 * time.Millisecond) + first := false + once.Do(func() { + statesRes <- cst.ts.Config.ExportAllConnsByState() + go func() { + shutdownRes <- cst.ts.Config.Shutdown(context.Background()) + }() + first = true + }) + + if first { + // Shutdown is graceful, so it should not interrupt this in-flight response + // but should reject new requests. (Since this request is still in flight, + // the server's port should not be reused for another server yet.) + <-gotOnShutdown + // TODO(#59038): The HTTP/2 server empirically does not always reject new + // requests. As a workaround, loop until we see a failure. + for !t.Failed() { + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + break + } + out, _ := io.ReadAll(res.Body) + res.Body.Close() + if mode == http2Mode { + t.Logf("%v: unexpected success (%q). Listener should be closed before OnShutdown is called.", cst.ts.URL, out) + t.Logf("Retrying to work around https://go.dev/issue/59038.") + continue + } + t.Errorf("%v: unexpected success (%q). Listener should be closed before OnShutdown is called.", cst.ts.URL, out) + } + } + io.WriteString(w, r.RemoteAddr) }) - cst := newClientServerTest(t, mode, handler, func(srv *httptest.Server) { - srv.Config.RegisterOnShutdown(func() { gotOnShutdown <- struct{}{} }) + + cst = newClientServerTest(t, mode, handler, func(srv *httptest.Server) { + srv.Config.RegisterOnShutdown(func() { close(gotOnShutdown) }) }) - doShutdown = func() { - shutdownRes <- cst.ts.Config.Shutdown(context.Background()) - } - doStateCount = func() { - statesRes <- cst.ts.Config.ExportAllConnsByState() - } - get(t, cst.c, cst.ts.URL) // calls t.Fail on failure + out := get(t, cst.c, cst.ts.URL) // calls t.Fail on failure + t.Logf("%v: %q", cst.ts.URL, out) if err := <-shutdownRes; err != nil { t.Fatalf("Shutdown: %v", err) } - select { - case <-gotOnShutdown: - case <-time.After(5 * time.Second): - t.Errorf("onShutdown callback not called, RegisterOnShutdown broken?") - } + <-gotOnShutdown // Will hang if RegisterOnShutdown is broken. if states := <-statesRes; states[StateActive] != 1 { t.Errorf("connection in wrong state, %v", states) } - - res, err := cst.c.Get(cst.ts.URL) - if err == nil { - res.Body.Close() - t.Fatal("second request should fail. server should be shut down") - } } func TestServerShutdownStateNew(t *testing.T) { run(t, testServerShutdownStateNew) } @@ -5673,7 +5566,11 @@ func testServerShutdownStateNew(t *testing.T, mode testMode) { readRes <- err }() + // TODO(#59037): This timeout is hard-coded in closeIdleConnections. + // It is undocumented, and some users may find it surprising. + // Either document it, or switch to a less surprising behavior. const expectTimeout = 5 * time.Second + t0 := time.Now() select { case got := <-shutdownRes: @@ -5690,13 +5587,8 @@ func testServerShutdownStateNew(t *testing.T, mode testMode) { // Wait for c.Read to unblock; should be already done at this point, // or within a few milliseconds. 
- select { - case err := <-readRes: - if err == nil { - t.Error("expected error from Read") - } - case <-time.After(2 * time.Second): - t.Errorf("timeout waiting for Read to unblock") + if err := <-readRes; err == nil { + t.Error("expected error from Read") } } @@ -5721,9 +5613,15 @@ func testServerKeepAlivesEnabled(t *testing.T, mode testMode) { srv := cst.ts.Config srv.SetKeepAlivesEnabled(false) for try := 0; try < 2; try++ { - if !waitCondition(2*time.Second, 10*time.Millisecond, srv.ExportAllConnsIdle) { - t.Fatalf("request %v: test server has active conns", try) - } + waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool { + if !srv.ExportAllConnsIdle() { + if d > 0 { + t.Logf("test server still has active conns after %v", d) + } + return false + } + return true + }) conns := 0 var info httptrace.GotConnInfo ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ @@ -5966,11 +5864,7 @@ func testServerHijackGetsBackgroundByte(t *testing.T, mode testMode) { if err := cn.(*net.TCPConn).CloseWrite(); err != nil { t.Fatal(err) } - select { - case <-done: - case <-time.After(2 * time.Second): - t.Error("timeout") - } + <-done } // Like TestServerHijackGetsBackgroundByte above but sending a @@ -6535,6 +6429,75 @@ func testDisableKeepAliveUpgrade(t *testing.T, mode testMode) { } } +type tlogWriter struct{ t *testing.T } + +func (w tlogWriter) Write(p []byte) (int, error) { + w.t.Log(string(p)) + return len(p), nil +} + +func TestWriteHeaderSwitchingProtocols(t *testing.T) { + run(t, testWriteHeaderSwitchingProtocols, []testMode{http1Mode}) +} +func testWriteHeaderSwitchingProtocols(t *testing.T, mode testMode) { + const wantBody = "want" + const wantUpgrade = "someProto" + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", wantUpgrade) + w.WriteHeader(StatusSwitchingProtocols) + NewResponseController(w).Flush() + + // Writing headers or the body after sending a 101 header should fail. + w.WriteHeader(200) + if _, err := w.Write([]byte("x")); err == nil { + t.Errorf("Write to body after 101 Switching Protocols unexpectedly succeeded") + } + + c, _, err := NewResponseController(w).Hijack() + if err != nil { + t.Errorf("Hijack: %v", err) + return + } + defer c.Close() + if _, err := c.Write([]byte(wantBody)); err != nil { + t.Errorf("Write to hijacked body: %v", err) + } + }), func(ts *httptest.Server) { + // Don't spam log with warning about superfluous WriteHeader call. 
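// TestWriteHeaderSwitchingProtocols above exercises the new rule that a 101
// response is final: after WriteHeader(StatusSwitchingProtocols) the handler
// must switch to the hijacked connection rather than keep using the
// ResponseWriter. A hypothetical upgrade handler following that flow
// ("someProto" is a placeholder protocol name):
package main

import (
	"log"
	"net/http"
)

// upgrade answers 101 Switching Protocols and then takes over the raw
// connection; all further traffic goes through the hijacked net.Conn.
func upgrade(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Connection", "Upgrade")
	w.Header().Set("Upgrade", "someProto")
	w.WriteHeader(http.StatusSwitchingProtocols)

	rc := http.NewResponseController(w)
	rc.Flush() // put the 101 and its headers on the wire

	conn, brw, err := rc.Hijack()
	if err != nil {
		log.Printf("hijack: %v", err) // e.g. http.ErrNotSupported on HTTP/2
		return
	}
	defer conn.Close()
	brw.WriteString("hello from the upgraded protocol\n")
	brw.Flush()
}

func main() {
	log.Fatal(http.ListenAndServe("127.0.0.1:8080", http.HandlerFunc(upgrade)))
}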
+ ts.Config.ErrorLog = log.New(tlogWriter{t}, "log: ", 0) + }).ts + + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("net.Dial: %v", err) + } + _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n")) + if err != nil { + t.Fatalf("conn.Write: %v", err) + } + defer conn.Close() + + r := bufio.NewReader(conn) + res, err := ReadResponse(r, &Request{Method: "GET"}) + if err != nil { + t.Fatal("ReadResponse error:", err) + } + if res.StatusCode != StatusSwitchingProtocols { + t.Errorf("Response StatusCode=%v, want 101", res.StatusCode) + } + if got := res.Header.Get("Upgrade"); got != wantUpgrade { + t.Errorf("Response Upgrade header = %q, want %q", got, wantUpgrade) + } + body, err := io.ReadAll(r) + if err != nil { + t.Error(err) + } + if string(body) != wantBody { + t.Errorf("Response body = %q, want %q", string(body), wantBody) + } +} + func TestMuxRedirectRelative(t *testing.T) { setParallel(t) req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET http://example.com HTTP/1.1\r\nHost: test\r\n\r\n"))) @@ -6557,10 +6520,10 @@ func TestQuerySemicolon(t *testing.T) { t.Cleanup(func() { afterTest(t) }) tests := []struct { - query string - xNoSemicolons string - xWithSemicolons string - warning bool + query string + xNoSemicolons string + xWithSemicolons string + expectParseFormErr bool }{ {"?a=1;x=bad&x=good", "good", "bad", true}, {"?a=1;b=bad&x=good", "good", "good", true}, @@ -6572,20 +6535,20 @@ func TestQuerySemicolon(t *testing.T) { for _, tt := range tests { t.Run(tt.query+"/allow=false", func(t *testing.T) { allowSemicolons := false - testQuerySemicolon(t, mode, tt.query, tt.xNoSemicolons, allowSemicolons, tt.warning) + testQuerySemicolon(t, mode, tt.query, tt.xNoSemicolons, allowSemicolons, tt.expectParseFormErr) }) t.Run(tt.query+"/allow=true", func(t *testing.T) { - allowSemicolons, expectWarning := true, false - testQuerySemicolon(t, mode, tt.query, tt.xWithSemicolons, allowSemicolons, expectWarning) + allowSemicolons, expectParseFormErr := true, false + testQuerySemicolon(t, mode, tt.query, tt.xWithSemicolons, allowSemicolons, expectParseFormErr) }) } }) } -func testQuerySemicolon(t *testing.T, mode testMode, query string, wantX string, allowSemicolons, expectWarning bool) { +func testQuerySemicolon(t *testing.T, mode testMode, query string, wantX string, allowSemicolons, expectParseFormErr bool) { writeBackX := func(w ResponseWriter, r *Request) { x := r.URL.Query().Get("x") - if expectWarning { + if expectParseFormErr { if err := r.ParseForm(); err == nil || !strings.Contains(err.Error(), "semicolon") { t.Errorf("expected error mentioning semicolons from ParseForm, got %v", err) } @@ -6623,16 +6586,6 @@ func testQuerySemicolon(t *testing.T, mode testMode, query string, wantX string, if got, want := string(slurp), wantX; got != want { t.Errorf("Body = %q; want = %q", got, want) } - - if expectWarning { - if !strings.Contains(logBuf.String(), "semicolon") { - t.Errorf("got %q from ErrorLog, expected a mention of semicolons", logBuf.String()) - } - } else { - if strings.Contains(logBuf.String(), "semicolon") { - t.Errorf("got %q from ErrorLog, expected no mention of semicolons", logBuf.String()) - } - } } func TestMaxBytesHandler(t *testing.T) { @@ -6666,9 +6619,30 @@ func testMaxBytesHandler(t *testing.T, mode testMode, maxSize, requestSize int64 defer ts.Close() c := ts.Client() + + body := strings.Repeat("a", int(requestSize)) + var wg sync.WaitGroup + defer wg.Wait() + getBody := func() (io.ReadCloser, error) { + 
wg.Add(1) + body := &wgReadCloser{ + Reader: strings.NewReader(body), + wg: &wg, + } + return body, nil + } + reqBody, _ := getBody() + req, err := NewRequest("POST", ts.URL, reqBody) + if err != nil { + reqBody.Close() + t.Fatal(err) + } + req.ContentLength = int64(len(body)) + req.GetBody = getBody + req.Header.Set("Content-Type", "text/plain") + var buf strings.Builder - body := strings.NewReader(strings.Repeat("a", int(requestSize))) - res, err := c.Post(ts.URL, "text/plain", body) + res, err := c.Do(req) if err != nil { t.Errorf("unexpected connection error: %v", err) } else { @@ -6849,3 +6823,43 @@ func testHeadBody(t *testing.T, mode testMode, chunked bool, method string) { } } } + +// TestDisableContentLength verifies that the Content-Length is set by default +// or disabled when the header is set to nil. +func TestDisableContentLength(t *testing.T) { run(t, testDisableContentLength) } +func testDisableContentLength(t *testing.T, mode testMode) { + if mode == http2Mode { + t.Skip("skipping until h2_bundle.go is updated; see https://go-review.googlesource.com/c/net/+/471535") + } + + noCL := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header()["Content-Length"] = nil // disable the default Content-Length response + fmt.Fprintf(w, "OK") + })) + + res, err := noCL.c.Get(noCL.ts.URL) + if err != nil { + t.Fatal(err) + } + if got, haveCL := res.Header["Content-Length"]; haveCL { + t.Errorf("Unexpected Content-Length: %q", got) + } + if err := res.Body.Close(); err != nil { + t.Fatal(err) + } + + withCL := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "OK") + })) + + res, err = withCL.c.Get(withCL.ts.URL) + if err != nil { + t.Fatal(err) + } + if got := res.Header.Get("Content-Length"); got != "2" { + t.Errorf("Content-Length: %q; want 2", got) + } + if err := res.Body.Close(); err != nil { + t.Fatal(err) + } +} diff --git a/server.go b/server.go index a0d6b57e..b6fdfe51 100644 --- a/server.go +++ b/server.go @@ -459,6 +459,10 @@ type response struct { // Content-Length. closeAfterReply bool + // When fullDuplex is false (the default), we consume any remaining + // request body before starting to write a response. + fullDuplex bool + // requestBodyLimitHit is set by requestTooLarge when // maxBytesReader hits its max size. It is checked in // WriteHeader, to make sure we don't consume the @@ -496,6 +500,11 @@ func (c *response) SetWriteDeadline(deadline time.Time) error { return c.conn.rwc.SetWriteDeadline(deadline) } +func (c *response) EnableFullDuplex() error { + c.fullDuplex = true + return nil +} + // TrailerPrefix is a magic prefix for ResponseWriter.Header map keys // that, if present, signals that the map entry is actually for // the response trailers, and not the response headers. The prefix @@ -1144,8 +1153,11 @@ func (w *response) WriteHeader(code int) { } checkWriteHeaderCode(code) - // Handle informational headers - if code >= 100 && code <= 199 { + // Handle informational headers. + // + // We shouldn't send any further headers after 101 Switching Protocols, + // so it takes the non-informational path. + if code >= 100 && code <= 199 && code != StatusSwitchingProtocols { // Prevent a potential race with an automatically-sent 100 Continue triggered by Request.Body.Read() if code == 100 && w.canWriteContinue.Load() { w.writeContinueMu.Lock() @@ -1307,7 +1319,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { // send a Content-Length header. 
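// The MaxBytesHandler test above switches to a request with GetBody set, so the
// transport can replay the body when the request has to be resent or redirected.
// A small client-side sketch of that setup; the URL and payload are placeholders.
package main

import (
	"fmt"
	"io"
	"net/http"
	"strings"
)

// newRetryablePost builds a POST whose body can be re-created on demand:
// GetBody returns a fresh reader each time the transport needs to resend it.
func newRetryablePost(url, body string) (*http.Request, error) {
	getBody := func() (io.ReadCloser, error) {
		return io.NopCloser(strings.NewReader(body)), nil
	}
	rc, _ := getBody()
	req, err := http.NewRequest("POST", url, rc)
	if err != nil {
		return nil, err
	}
	req.ContentLength = int64(len(body))
	req.GetBody = getBody
	req.Header.Set("Content-Type", "text/plain")
	return req, nil
}

func main() {
	req, err := newRetryablePost("http://example.invalid/upload", "payload")
	if err != nil {
		panic(err)
	}
	fmt.Println(req.Method, req.URL, req.ContentLength, req.GetBody != nil)
}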
// Further, we don't send an automatic Content-Length if they // set a Transfer-Encoding, because they're generally incompatible. - if w.handlerDone.Load() && !trailers && !hasTE && bodyAllowedForStatus(w.status) && header.get("Content-Length") == "" && (!isHEAD || len(p) > 0) { + if w.handlerDone.Load() && !trailers && !hasTE && bodyAllowedForStatus(w.status) && !header.has("Content-Length") && (!isHEAD || len(p) > 0) { w.contentLength = int64(len(p)) setHeader.contentLength = strconv.AppendInt(cw.res.clenBuf[:0], int64(len(p)), 10) } @@ -1353,14 +1365,14 @@ func (cw *chunkWriter) writeHeader(p []byte) { w.closeAfterReply = true } - // Per RFC 2616, we should consume the request body before - // replying, if the handler hasn't already done so. But we - // don't want to do an unbounded amount of reading here for - // DoS reasons, so we only try up to a threshold. - // TODO(bradfitz): where does RFC 2616 say that? See Issue 15527 - // about HTTP/1.x Handlers concurrently reading and writing, like - // HTTP/2 handlers can do. Maybe this code should be relaxed? - if w.req.ContentLength != 0 && !w.closeAfterReply { + // We do this by default because there are a number of clients that + // send a full request before starting to read the response, and they + // can deadlock if we start writing the response with unconsumed body + // remaining. See Issue 15527 for some history. + // + // If full duplex mode has been enabled with ResponseController.EnableFullDuplex, + // then leave the request body alone. + if w.req.ContentLength != 0 && !w.closeAfterReply && !w.fullDuplex { var discard, tooBig bool switch bdy := w.req.Body.(type) { @@ -1748,7 +1760,7 @@ type closeWriter interface { var _ closeWriter = (*net.TCPConn)(nil) -// closeWrite flushes any outstanding data and sends a FIN packet (if +// closeWriteAndWait flushes any outstanding data and sends a FIN packet (if // client is connected via TCP), signaling that we're done. We then // pause for a bit, hoping the client processes it before any // subsequent RST. @@ -1843,7 +1855,9 @@ func isCommonNetReadError(err error) bool { // Serve a new connection. func (c *conn) serve(ctx context.Context) { - c.remoteAddr = c.rwc.RemoteAddr().String() + if ra := c.rwc.RemoteAddr(); ra != nil { + c.remoteAddr = ra.String() + } ctx = context.WithValue(ctx, LocalAddrContextKey, c.rwc.LocalAddr()) var inFlightResponse *response defer func() { @@ -2277,7 +2291,7 @@ func RedirectHandler(url string, code int) Handler { // Longer patterns take precedence over shorter ones, so that // if there are handlers registered for both "/images/" // and "/images/thumbnails/", the latter handler will be -// called for paths beginning "/images/thumbnails/" and the +// called for paths beginning with "/images/thumbnails/" and the // former will receive requests for any other paths in the // "/images/" subtree. 
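// The doc-comment fix above describes ServeMux's longest-prefix matching: with
// both "/images/" and "/images/thumbnails/" registered, the longer pattern wins
// for paths under its subtree. A minimal sketch (handlers and address are
// illustrative):
package main

import (
	"fmt"
	"log"
	"net/http"
)

func main() {
	mux := http.NewServeMux()
	mux.HandleFunc("/images/", func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintf(w, "generic image handler: %s\n", r.URL.Path)
	})
	mux.HandleFunc("/images/thumbnails/", func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintf(w, "thumbnail handler: %s\n", r.URL.Path)
	})
	// /images/thumbnails/a.png hits the thumbnail handler;
	// /images/cats/b.png falls back to the generic one.
	log.Fatal(http.ListenAndServe("127.0.0.1:8080", mux))
}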
// @@ -2920,23 +2934,9 @@ func (sh serverHandler) ServeHTTP(rw ResponseWriter, req *Request) { handler = globalOptionsHandler{} } - if req.URL != nil && strings.Contains(req.URL.RawQuery, ";") { - var allowQuerySemicolonsInUse atomic.Bool - req = req.WithContext(context.WithValue(req.Context(), silenceSemWarnContextKey, func() { - allowQuerySemicolonsInUse.Store(true) - })) - defer func() { - if !allowQuerySemicolonsInUse.Load() { - sh.srv.logf("http: URL query contains semicolon, which is no longer a supported separator; parts of the query may be stripped when parsed; see golang.org/issue/25192") - } - }() - } - handler.ServeHTTP(rw, req) } -var silenceSemWarnContextKey = &contextKey{"silence-semicolons"} - // AllowQuerySemicolons returns a handler that serves requests by converting any // unescaped semicolons in the URL query to ampersands, and invoking the handler h. // @@ -2948,9 +2948,6 @@ var silenceSemWarnContextKey = &contextKey{"silence-semicolons"} // AllowQuerySemicolons should be invoked before Request.ParseForm is called. func AllowQuerySemicolons(h Handler) Handler { return HandlerFunc(func(w ResponseWriter, r *Request) { - if silenceSemicolonsWarning, ok := r.Context().Value(silenceSemWarnContextKey).(func()); ok { - silenceSemicolonsWarning() - } if strings.Contains(r.URL.RawQuery, ";") { r2 := new(Request) *r2 = *r @@ -2989,7 +2986,7 @@ func (srv *Server) ListenAndServe() error { var testHookServerServe func(*Server, net.Listener) // used if non-nil -// shouldDoServeHTTP2 reports whether Server.Serve should configure +// shouldConfigureHTTP2ForServe reports whether Server.Serve should configure // automatic HTTP/2. (which sets up the srv.TLSNextProto map) func (srv *Server) shouldConfigureHTTP2ForServe() bool { if srv.TLSConfig == nil { diff --git a/socks_bundle.go b/socks_bundle.go index e4466695..776b03d9 100644 --- a/socks_bundle.go +++ b/socks_bundle.go @@ -445,7 +445,7 @@ func (up *socksUsernamePassword) Authenticate(ctx context.Context, rw io.ReadWri case socksAuthMethodNotRequired: return nil case socksAuthMethodUsernamePassword: - if len(up.Username) == 0 || len(up.Username) > 255 || len(up.Password) == 0 || len(up.Password) > 255 { + if len(up.Username) == 0 || len(up.Username) > 255 || len(up.Password) > 255 { return errors.New("invalid username/password") } b := []byte{socksauthUsernamePasswordVersion} diff --git a/tools/compare.bash b/tools/compare.bash index 17d6796c..cf48fac0 100755 --- a/tools/compare.bash +++ b/tools/compare.bash @@ -9,4 +9,24 @@ test -d $upstreamrepo || git clone git@github.com:golang/go.git $upstreamrepo git pull git checkout $TAG ) -diff -ur $upstreamrepo/src/net/http . + +# Classification of ./internal packages +# +# 1. packages that map directly and are captured by the following diff command +# +# - src/net/http/internal/ascii => ./internal/ascii +# +# - src/net/http/internal/testcert => ./internal/testcert + +diff -ur $upstreamrepo/src/net/http . || true + +# 2. packages that we need to diff for explicitly +# +# - src/internal/safefilepath => ./internal/safefilepath + +diff -ur $upstreamrepo/src/internal/safefilepath ./internal/safefilepath || true + +# 3. 
replacement packages +# +# - ./internal/fakerace fakes out src/internal/race +# - ./internal/testenv fakes out src/internal/testenv diff --git a/transfer.go b/transfer.go index 036839ef..31bd3038 100644 --- a/transfer.go +++ b/transfer.go @@ -416,7 +416,7 @@ func (t *transferWriter) doBodyCopy(dst io.Writer, src io.Reader) (n int64, err return } -// unwrapBodyReader unwraps the body's inner reader if it's a +// unwrapBody unwraps the body's inner reader if it's a // nopCloser. This is to ensure that body writes sourced from local // files (*os.File types) are properly optimized. // diff --git a/transport.go b/transport.go index 559f85cf..47253fd1 100644 --- a/transport.go +++ b/transport.go @@ -84,13 +84,13 @@ const DefaultMaxIdleConnsPerHost = 2 // ClientTrace.Got1xxResponse. // // Transport only retries a request upon encountering a network error -// if the request is idempotent and either has no body or has its -// Request.GetBody defined. HTTP requests are considered idempotent if -// they have HTTP methods GET, HEAD, OPTIONS, or TRACE; or if their -// Header map contains an "Idempotency-Key" or "X-Idempotency-Key" -// entry. If the idempotency key value is a zero-length slice, the -// request is treated as idempotent but the header is not sent on the -// wire. +// if the connection has been already been used successfully and if the +// request is idempotent and either has no body or has its Request.GetBody +// defined. HTTP requests are considered idempotent if they have HTTP methods +// GET, HEAD, OPTIONS, or TRACE; or if their Header map contains an +// "Idempotency-Key" or "X-Idempotency-Key" entry. If the idempotency key +// value is a zero-length slice, the request is treated as idempotent but the +// header is not sent on the wire. type Transport struct { idleMu sync.Mutex closeIdle bool // user has requested to close all idle conns @@ -172,7 +172,7 @@ type Transport struct { // If non-nil, HTTP/2 support may not be enabled by default. TLSClientConfig *tls.Config - // TLSHandshakeTimeout specifies the maximum amount of time waiting to + // TLSHandshakeTimeout specifies the maximum amount of time to // wait for a TLS handshake. Zero means no timeout. TLSHandshakeTimeout time.Duration @@ -623,6 +623,12 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) { if e, ok := err.(transportReadFromServerError); ok { err = e.err } + if b, ok := req.Body.(*readTrackingBody); ok && !b.didClose { + // Issue 49621: Close the request body if pconn.roundTrip + // didn't do so already. This can happen if the pconn + // write loop exits without reading the write request. + req.closeBody() + } return nil, err } testHookRoundTripRetried() @@ -1175,7 +1181,11 @@ var zeroDialer net.Dialer func (t *Transport) dial(ctx context.Context, network, addr string) (net.Conn, error) { if t.DialContext != nil { - return t.DialContext(ctx, network, addr) + c, err := t.DialContext(ctx, network, addr) + if c == nil && err == nil { + err = errors.New("net/http: Transport.DialContext hook returned (nil, nil)") + } + return c, err } if t.Dial != nil { c, err := t.Dial(network, addr) @@ -2447,7 +2457,10 @@ func (pc *persistConn) writeLoop() { // maxWriteWaitBeforeConnReuse is how long the a Transport RoundTrip // will wait to see the Request's Body.Write result after getting a // response from the server. See comments in (*persistConn).wroteRequest. 
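// The transport change above turns a DialContext hook that returns (nil, nil)
// into an explicit error instead of a confusing failure later on. A hypothetical
// custom dialer that pins every request to one address and guards against that
// case (fixedAddr and the URLs are placeholders):
package main

import (
	"context"
	"errors"
	"fmt"
	"net"
	"net/http"
	"time"
)

// newPinnedTransport dials fixedAddr regardless of the requested host. The hook
// always returns either a non-nil conn or a non-nil error.
func newPinnedTransport(fixedAddr string) *http.Transport {
	d := &net.Dialer{Timeout: 5 * time.Second}
	return &http.Transport{
		DialContext: func(ctx context.Context, network, _ string) (net.Conn, error) {
			c, err := d.DialContext(ctx, network, fixedAddr)
			if c == nil && err == nil {
				return nil, errors.New("dial returned no connection and no error")
			}
			return c, err
		},
	}
}

func main() {
	client := &http.Client{Transport: newPinnedTransport("127.0.0.1:8080")}
	resp, err := client.Get("http://service.internal/health")
	if err == nil {
		resp.Body.Close()
	}
	fmt.Println("request error (expected unless something listens on 127.0.0.1:8080):", err)
}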
-const maxWriteWaitBeforeConnReuse = 50 * time.Millisecond +// +// In tests, we set this to a large value to avoid flakiness from inconsistent +// recycling of connections. +var maxWriteWaitBeforeConnReuse = 50 * time.Millisecond // wroteRequest is a check before recycling a connection that the previous write // (from writeLoop above) happened and was successful. @@ -2745,17 +2758,21 @@ var portMap = map[string]string{ "socks5": "1080", } -// canonicalAddr returns url.Host but always with a ":port" suffix. -func canonicalAddr(url *url.URL) string { +func idnaASCIIFromURL(url *url.URL) string { addr := url.Hostname() if v, err := idnaASCII(addr); err == nil { addr = v } + return addr +} + +// canonicalAddr returns url.Host but always with a ":port" suffix. +func canonicalAddr(url *url.URL) string { port := url.Port() if port == "" { port = portMap[url.Scheme] } - return net.JoinHostPort(addr, port) + return net.JoinHostPort(idnaASCIIFromURL(url), port) } // bodyEOFSignal is used by the HTTP/1 transport when reading response diff --git a/transport_default_other.go b/transport_default_other.go index 8a2f1cc4..4f6c5c12 100644 --- a/transport_default_other.go +++ b/transport_default_other.go @@ -2,8 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build !(js && wasm) -// +build !js !wasm +//go:build !wasm package http diff --git a/transport_default_js.go b/transport_default_wasm.go similarity index 89% rename from transport_default_js.go rename to transport_default_wasm.go index c07d35ef..3946812d 100644 --- a/transport_default_js.go +++ b/transport_default_wasm.go @@ -2,8 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build js && wasm -// +build js,wasm +//go:build (js && wasm) || wasip1 package http diff --git a/transport_test.go b/transport_test.go index 94082b0b..5e0b7420 100644 --- a/transport_test.go +++ b/transport_test.go @@ -124,6 +124,8 @@ func (tcs *testConnSet) check(t *testing.T) { continue } if i != 0 { + // TODO(bcmills): What is the Sleep here doing, and why is this + // Unlock/Sleep/Lock cycle needed at all? tcs.mu.Unlock() time.Sleep(50 * time.Millisecond) tcs.mu.Lock() @@ -248,7 +250,7 @@ func testTransportConnectionCloseOnResponse(t *testing.T, mode testMode) { // an underlying TCP connection after making an http.Request with Request.Close set. // // It tests the behavior by making an HTTP request to a server which -// describes the source source connection it got (remote port number + +// describes the source connection it got (remote port number + // address of its net.Conn). func TestTransportConnectionCloseOnRequest(t *testing.T) { run(t, testTransportConnectionCloseOnRequest, []testMode{http1Mode}) @@ -738,7 +740,7 @@ func testTransportRemovesDeadIdleConnections(t *testing.T, mode testMode) { c := ts.Client() tr := c.Transport.(*Transport) - doReq := func(name string) string { + doReq := func(name string) { // Do a POST instead of a GET to prevent the Transport's // idempotent request retry logic from kicking in... 
res, err := c.Post(ts.URL, "", nil) @@ -753,26 +755,27 @@ func testTransportRemovesDeadIdleConnections(t *testing.T, mode testMode) { if err != nil { t.Fatalf("%s: %v", name, err) } - return string(slurp) + t.Logf("%s: ok (%q)", name, slurp) } - first := doReq("first") + doReq("first") keys1 := tr.IdleConnKeysForTesting() ts.CloseClientConnections() var keys2 []string - if !waitCondition(3*time.Second, 50*time.Millisecond, func() bool { + waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool { keys2 = tr.IdleConnKeysForTesting() - return len(keys2) == 0 - }) { - t.Fatalf("Transport didn't notice idle connection's death.\nbefore: %q\n after: %q\n", keys1, keys2) - } + if len(keys2) != 0 { + if d > 0 { + t.Logf("Transport hasn't noticed idle connection's death in %v.\nbefore: %q\n after: %q\n", d, keys1, keys2) + } + return false + } + return true + }) - second := doReq("second") - if first == second { - t.Errorf("expected a different connection between requests. got %q both times", first) - } + doReq("second") } // Test that the Transport notices when a server hangs up on its @@ -862,7 +865,8 @@ func testStressSurpriseServerCloses(t *testing.T, mode testMode) { numClients = 20 reqsPerClient = 25 ) - activityc := make(chan bool) + var wg sync.WaitGroup + wg.Add(numClients * reqsPerClient) for i := 0; i < numClients; i++ { go func() { for i := 0; i < reqsPerClient; i++ { @@ -876,22 +880,13 @@ func testStressSurpriseServerCloses(t *testing.T, mode testMode) { // where we won the race. res.Body.Close() } - if !<-activityc { // Receives false when close(activityc) is executed - return - } + wg.Done() } }() } // Make sure all the request come back, one way or another. - for i := 0; i < numClients*reqsPerClient; i++ { - select { - case activityc <- true: - case <-time.After(5 * time.Second): - close(activityc) - t.Fatalf("presumed deadlock; no HTTP client activity seen in awhile") - } - } + wg.Wait() } // TestTransportHeadResponses verifies that we deal with Content-Lengths @@ -1323,12 +1318,7 @@ func testSOCKS5Proxy(t *testing.T, mode testMode) { if r.Header.Get(sentinelHeader) != sentinelValue { t.Errorf("Failed to retrieve sentinel value") } - var got string - select { - case got = <-ch: - case <-time.After(5 * time.Second): - t.Fatal("timeout connecting to socks5 proxy") - } + got := <-ch ts.Close() tsu, err := url.Parse(ts.URL) if err != nil { @@ -1419,12 +1409,7 @@ func TestTransportProxy(t *testing.T) { if _, err := c.Head(ts.URL); err != nil { t.Error(err) } - var got *Request - select { - case got = <-proxyCh: - case <-time.After(5 * time.Second): - t.Fatal("timeout connecting to http proxy") - } + got := <-proxyCh c.Transport.(*Transport).CloseIdleConnections() ts.Close() proxy.Close() @@ -2328,67 +2313,81 @@ func testTransportResponseHeaderTimeout(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping timeout test in -short mode") } - inHandler := make(chan bool, 1) - mux := NewServeMux() - mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) { - inHandler <- true - }) - mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) { - inHandler <- true - time.Sleep(2 * time.Second) - }) - ts := newClientServerTest(t, mode, mux).ts - - c := ts.Client() - c.Transport.(*Transport).ResponseHeaderTimeout = 500 * time.Millisecond - tests := []struct { - path string - want int - wantErr string - }{ - {path: "/fast", want: 200}, - {path: "/slow", wantErr: "timeout awaiting response headers"}, - {path: "/fast", want: 200}, - } - for i, tt := range tests { - req, _ := 
NewRequest("GET", ts.URL+tt.path, nil) - req = req.WithT(t) - res, err := c.Do(req) - select { - case <-inHandler: - case <-time.After(5 * time.Second): - t.Errorf("never entered handler for test index %d, %s", i, tt.path) - continue - } - if err != nil { - uerr, ok := err.(*url.Error) - if !ok { - t.Errorf("error is not an url.Error; got: %#v", err) - continue - } - nerr, ok := uerr.Err.(net.Error) - if !ok { - t.Errorf("error does not satisfy net.Error interface; got: %#v", err) + timeout := 2 * time.Millisecond + retry := true + for retry && !t.Failed() { + var srvWG sync.WaitGroup + inHandler := make(chan bool, 1) + mux := NewServeMux() + mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) { + inHandler <- true + srvWG.Done() + }) + mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) { + inHandler <- true + <-r.Context().Done() + srvWG.Done() + }) + ts := newClientServerTest(t, mode, mux).ts + + c := ts.Client() + c.Transport.(*Transport).ResponseHeaderTimeout = timeout + + retry = false + srvWG.Add(3) + tests := []struct { + path string + wantTimeout bool + }{ + {path: "/fast"}, + {path: "/slow", wantTimeout: true}, + {path: "/fast"}, + } + for i, tt := range tests { + req, _ := NewRequest("GET", ts.URL+tt.path, nil) + req = req.WithT(t) + res, err := c.Do(req) + <-inHandler + if err != nil { + uerr, ok := err.(*url.Error) + if !ok { + t.Errorf("error is not a url.Error; got: %#v", err) + continue + } + nerr, ok := uerr.Err.(net.Error) + if !ok { + t.Errorf("error does not satisfy net.Error interface; got: %#v", err) + continue + } + if !nerr.Timeout() { + t.Errorf("want timeout error; got: %q", nerr) + continue + } + if !tt.wantTimeout { + if !retry { + // The timeout may be set too short. Retry with a longer one. + t.Logf("unexpected timeout for path %q after %v; retrying with longer timeout", tt.path, timeout) + timeout *= 2 + retry = true + } + } + if !strings.Contains(err.Error(), "timeout awaiting response headers") { + t.Errorf("%d. unexpected error: %v", i, err) + } continue } - if !nerr.Timeout() { - t.Errorf("want timeout error; got: %q", nerr) + if tt.wantTimeout { + t.Errorf(`no error for path %q; expected "timeout awaiting response headers"`, tt.path) continue } - if strings.Contains(err.Error(), tt.wantErr) { - continue + if res.StatusCode != 200 { + t.Errorf("%d for path %q status = %d; want 200", i, tt.path, res.StatusCode) } - t.Errorf("%d. unexpected error: %v", i, err) - continue - } - if tt.wantErr != "" { - t.Errorf("%d. no error. 
expected error: %v", i, tt.wantErr) - continue - } - if res.StatusCode != tt.want { - t.Errorf("%d for path %q status = %d; want %d", i, tt.path, res.StatusCode, tt.want) } + + srvWG.Wait() + ts.Close() } } @@ -2399,9 +2398,11 @@ func testTransportCancelRequest(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping test in -short mode") } + + const msg = "Hello" unblockc := make(chan bool) ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { - fmt.Fprintf(w, "Hello") + io.WriteString(w, msg) w.(Flusher).Flush() // send headers and some body <-unblockc })).ts @@ -2415,35 +2416,32 @@ func testTransportCancelRequest(t *testing.T, mode testMode) { if err != nil { t.Fatal(err) } - go func() { - time.Sleep(1 * time.Second) - tr.CancelRequest(req) - }() - t0 := time.Now() - body, err := io.ReadAll(res.Body) - d := time.Since(t0) + body := make([]byte, len(msg)) + n, _ := io.ReadFull(res.Body, body) + if n != len(body) || !bytes.Equal(body, []byte(msg)) { + t.Errorf("Body = %q; want %q", body[:n], msg) + } + tr.CancelRequest(req) + tail, err := io.ReadAll(res.Body) + res.Body.Close() if err != ExportErrRequestCanceled { t.Errorf("Body.Read error = %v; want errRequestCanceled", err) + } else if len(tail) > 0 { + t.Errorf("Spurious bytes from Body.Read: %q", tail) } - if string(body) != "Hello" { - t.Errorf("Body = %q; want Hello", body) - } - if d < 500*time.Millisecond { - t.Errorf("expected ~1 second delay; got %v", d) - } + // Verify no outstanding requests after readLoop/writeLoop // goroutines shut down. - for tries := 5; tries > 0; tries-- { + waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool { n := tr.NumPendingRequestsForTesting() - if n == 0 { - break - } - time.Sleep(100 * time.Millisecond) - if tries == 1 { - t.Errorf("pending requests = %d; want 0", n) + if n > 0 { + if d > 0 { + t.Logf("pending requests = %d after %v (want 0)", n, d) + } } - } + return true + }) } func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader) { @@ -2465,18 +2463,20 @@ func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader) defer close(donec) c.Do(req) }() - start := time.Now() - timeout := 10 * time.Second - for time.Since(start) < timeout { - time.Sleep(100 * time.Millisecond) + + unblockc <- true + waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool { tr.CancelRequest(req) select { case <-donec: - return + return true default: + if d > 0 { + t.Logf("Do of canceled request has not returned after %v", d) + } + return false } - } - t.Errorf("Do of canceled request has not returned after %v", timeout) + }) } func TestTransportCancelRequestInDo(t *testing.T) { @@ -2522,22 +2522,22 @@ func TestTransportCancelRequestInDial(t *testing.T) { gotres <- true }() - select { - case inDial <- true: - case <-time.After(5 * time.Second): - close(inDial) - t.Fatal("timeout; never saw blocking dial") - } + inDial <- true eventLog.Printf("canceling") tr.CancelRequest(req) tr.CancelRequest(req) // used to panic on second call - select { - case <-gotres: - case <-time.After(5 * time.Second): - panic("hang. events are: " + logbuf.String()) + if d, ok := t.Deadline(); ok { + // When the test's deadline is about to expire, log the pending events for + // better debugging. + timeout := time.Until(d) * 19 / 20 // Allow 5% for cleanup. + timer := time.AfterFunc(timeout, func() { + panic(fmt.Sprintf("hang in %s. 
events are: %s", t.Name(), logbuf.String())) + }) + defer timer.Stop() } + <-gotres got := logbuf.String() want := `dial: blocking @@ -2554,9 +2554,11 @@ func testCancelRequestWithChannel(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping test in -short mode") } - unblockc := make(chan bool) + + const msg = "Hello" + unblockc := make(chan struct{}) ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { - fmt.Fprintf(w, "Hello") + io.WriteString(w, msg) w.(Flusher).Flush() // send headers and some body <-unblockc })).ts @@ -2566,42 +2568,39 @@ func testCancelRequestWithChannel(t *testing.T, mode testMode) { tr := c.Transport.(*Transport) req, _ := NewRequest("GET", ts.URL, nil) - ch := make(chan struct{}) - req.Cancel = ch + cancel := make(chan struct{}) + req.Cancel = cancel res, err := c.Do(req) if err != nil { t.Fatal(err) } - go func() { - time.Sleep(1 * time.Second) - close(ch) - }() - t0 := time.Now() - body, err := io.ReadAll(res.Body) - d := time.Since(t0) + body := make([]byte, len(msg)) + n, _ := io.ReadFull(res.Body, body) + if n != len(body) || !bytes.Equal(body, []byte(msg)) { + t.Errorf("Body = %q; want %q", body[:n], msg) + } + close(cancel) + tail, err := io.ReadAll(res.Body) + res.Body.Close() if err != ExportErrRequestCanceled { t.Errorf("Body.Read error = %v; want errRequestCanceled", err) + } else if len(tail) > 0 { + t.Errorf("Spurious bytes from Body.Read: %q", tail) } - if string(body) != "Hello" { - t.Errorf("Body = %q; want Hello", body) - } - if d < 500*time.Millisecond { - t.Errorf("expected ~1 second delay; got %v", d) - } + // Verify no outstanding requests after readLoop/writeLoop // goroutines shut down. - for tries := 5; tries > 0; tries-- { + waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool { n := tr.NumPendingRequestsForTesting() - if n == 0 { - break - } - time.Sleep(100 * time.Millisecond) - if tries == 1 { - t.Errorf("pending requests = %d; want 0", n) + if n > 0 { + if d > 0 { + t.Logf("pending requests = %d after %v (want 0)", n, d) + } } - } + return true + }) } func TestCancelRequestWithChannelBeforeDo_Cancel(t *testing.T) { @@ -2730,25 +2729,13 @@ func testTransportCloseResponseBody(t *testing.T, mode testMode) { if !bytes.Equal(buf, want) { t.Fatalf("read %q; want %q", buf, want) } - didClose := make(chan error, 1) - go func() { - didClose <- res.Body.Close() - }() - select { - case err := <-didClose: - if err != nil { - t.Errorf("Close = %v", err) - } - case <-time.After(10 * time.Second): - t.Fatal("too long waiting for close") + + if err := res.Body.Close(); err != nil { + t.Errorf("Close = %v", err) } - select { - case err := <-writeErr: - if err == nil { - t.Errorf("expected non-nil write error") - } - case <-time.After(10 * time.Second): - t.Fatal("too long waiting for write error") + + if err := <-writeErr; err == nil { + t.Errorf("expected non-nil write error") } } @@ -2826,13 +2813,21 @@ func testTransportSocketLateBinding(t *testing.T, mode testMode) { ts := newClientServerTest(t, mode, mux).ts dialGate := make(chan bool, 1) + dialing := make(chan bool) c := ts.Client() c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) { - if <-dialGate { - return net.Dial(n, addr) + for { + select { + case ok := <-dialGate: + if !ok { + return nil, errors.New("manually closed") + } + return net.Dial(n, addr) + case dialing <- true: + } } - return nil, errors.New("manually closed") } + defer close(dialGate) dialGate <- true // only allow one dial fooRes, err := 
c.Get(ts.URL + "/foo") @@ -2843,13 +2838,34 @@ func testTransportSocketLateBinding(t *testing.T, mode testMode) { if fooAddr == "" { t.Fatal("No addr on /foo request") } - time.AfterFunc(200*time.Millisecond, func() { - // let the foo response finish so we can use its - // connection for /bar + + fooDone := make(chan struct{}) + go func() { + // We know that the foo Dial completed and reached the handler because we + // read its header. Wait for the bar request to block in Dial, then + // let the foo response finish so we can use its connection for /bar. + + if mode == http2Mode { + // In HTTP/2 mode, the second Dial won't happen because the protocol + // multiplexes the streams by default. Just sleep for an arbitrary time; + // the test should pass regardless of how far the bar request gets by this + // point. + select { + case <-dialing: + t.Errorf("unexpected second Dial in HTTP/2 mode") + case <-time.After(10 * time.Millisecond): + } + } else { + <-dialing + } fooGate <- true io.Copy(io.Discard, fooRes.Body) fooRes.Body.Close() - }) + close(fooDone) + }() + defer func() { + <-fooDone + }() barRes, err := c.Get(ts.URL + "/bar") if err != nil { @@ -2860,7 +2876,6 @@ func testTransportSocketLateBinding(t *testing.T, mode testMode) { t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr) } barRes.Body.Close() - dialGate <- false } // Issue 2184 @@ -3270,42 +3285,33 @@ func TestTransportTLSHandshakeTimeout(t *testing.T) { c.Close() }() - getdonec := make(chan struct{}) - go func() { - defer close(getdonec) - tr := &Transport{ - Dial: func(_, _ string) (net.Conn, error) { - return net.Dial("tcp", ln.Addr().String()) - }, - TLSHandshakeTimeout: 250 * time.Millisecond, - } - cl := &Client{Transport: tr} - _, err := cl.Get("https://dummy.tld/") - if err == nil { - t.Error("expected error") - return - } - ue, ok := err.(*url.Error) - if !ok { - t.Errorf("expected url.Error; got %#v", err) - return - } - ne, ok := ue.Err.(net.Error) - if !ok { - t.Errorf("expected net.Error; got %#v", err) - return - } - if !ne.Timeout() { - t.Errorf("expected timeout error; got %v", err) - } - if !strings.Contains(err.Error(), "handshake timeout") { - t.Errorf("expected 'handshake timeout' in error; got %v", err) - } - }() - select { - case <-getdonec: - case <-time.After(5 * time.Second): - t.Error("test timeout; TLS handshake hung?") + tr := &Transport{ + Dial: func(_, _ string) (net.Conn, error) { + return net.Dial("tcp", ln.Addr().String()) + }, + TLSHandshakeTimeout: 250 * time.Millisecond, + } + cl := &Client{Transport: tr} + _, err := cl.Get("https://dummy.tld/") + if err == nil { + t.Error("expected error") + return + } + ue, ok := err.(*url.Error) + if !ok { + t.Errorf("expected url.Error; got %#v", err) + return + } + ne, ok := ue.Err.(net.Error) + if !ok { + t.Errorf("expected net.Error; got %#v", err) + return + } + if !ne.Timeout() { + t.Errorf("expected timeout error; got %v", err) + } + if !strings.Contains(err.Error(), "handshake timeout") { + t.Errorf("expected 'handshake timeout' in error; got %v", err) } } @@ -3395,9 +3401,13 @@ func (c byteFromChanReader) Read(p []byte) (n int, err error) { // questionable state. 
// golang.org/issue/7569 func TestTransportNoReuseAfterEarlyResponse(t *testing.T) { - run(t, testTransportNoReuseAfterEarlyResponse, []testMode{http1Mode}) + run(t, testTransportNoReuseAfterEarlyResponse, []testMode{http1Mode}, testNotParallel) } func testTransportNoReuseAfterEarlyResponse(t *testing.T, mode testMode) { + defer func(d time.Duration) { + *MaxWriteWaitBeforeConnReuse = d + }(*MaxWriteWaitBeforeConnReuse) + *MaxWriteWaitBeforeConnReuse = 10 * time.Millisecond var sconn struct { sync.Mutex c net.Conn @@ -3438,24 +3448,15 @@ func testTransportNoReuseAfterEarlyResponse(t *testing.T, mode testMode) { if err := wantBody(res, err, "foo"); err != nil { t.Errorf("POST response: %v", err) } - donec := make(chan bool) - go func() { - defer close(donec) - res, err = c.Get(ts.URL) - if err := wantBody(res, err, "bar"); err != nil { - t.Errorf("GET response: %v", err) - return - } - getOkay = true // suppress test noise - }() - time.AfterFunc(5*time.Second, closeConn) - select { - case <-donec: - finalBit <- 'x' // unblock the writeloop of the first Post - close(finalBit) - case <-time.After(7 * time.Second): - t.Fatal("timeout waiting for GET request to finish") + + res, err = c.Get(ts.URL) + if err := wantBody(res, err, "bar"); err != nil { + t.Errorf("GET response: %v", err) + return } + getOkay = true // suppress test noise + finalBit <- 'x' // unblock the writeloop of the first Post + close(finalBit) } // Tests that we don't leak Transport persistConn.readLoop goroutines @@ -3633,13 +3634,13 @@ func testRetryRequestsOnError(t *testing.T, mode testMode) { req := tc.req() res, err := c.Do(req) if err != nil { - if time.Since(t0) < MaxWriteWaitBeforeConnReuse/2 { + if time.Since(t0) < *MaxWriteWaitBeforeConnReuse/2 { mu.Lock() got := logbuf.String() mu.Unlock() t.Fatalf("i=%d: Do = %v; log:\n%s", i, err, got) } - t.Skipf("connection likely wasn't recycled within %d, interfering with actual test; skipping", MaxWriteWaitBeforeConnReuse) + t.Skipf("connection likely wasn't recycled within %d, interfering with actual test; skipping", *MaxWriteWaitBeforeConnReuse) } res.Body.Close() if res.Request != req { @@ -3699,13 +3700,8 @@ func testTransportClosesBodyOnError(t *testing.T, mode testMode) { if err == nil || !strings.Contains(err.Error(), fakeErr.Error()) { t.Fatalf("Do error = %v; want something containing %q", err, fakeErr.Error()) } - select { - case err := <-readBody: - if err == nil { - t.Errorf("Unexpected success reading request body from handler; want 'unexpected EOF reading trailer'") - } - case <-time.After(5 * time.Second): - t.Error("timeout waiting for server handler to complete") + if err := <-readBody; err == nil { + t.Errorf("Unexpected success reading request body from handler; want 'unexpected EOF reading trailer'") } select { case <-didClose: @@ -3892,7 +3888,7 @@ func TestTransportCloseIdleConnsThenReturn(t *testing.T) { } // Test for issue 34282 -// Ensure that getConn doesn't call the GotConn trace hook on a HTTP/2 idle conn +// Ensure that getConn doesn't call the GotConn trace hook on an HTTP/2 idle conn func TestTransportTraceGotConnH2IdleConns(t *testing.T) { tr := &Transport{} wantIdle := func(when string, n int) bool { @@ -3931,35 +3927,45 @@ func testTransportRemovesH2ConnsAfterIdle(t *testing.T, mode testMode) { t.Skip("skipping in short mode") } - trFunc := func(tr *Transport) { - tr.MaxConnsPerHost = 1 - tr.MaxIdleConnsPerHost = 1 - tr.IdleConnTimeout = 10 * time.Millisecond - } - cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r 
*Request) {}), trFunc) + timeout := 1 * time.Millisecond + retry := true + for retry { + trFunc := func(tr *Transport) { + tr.MaxConnsPerHost = 1 + tr.MaxIdleConnsPerHost = 1 + tr.IdleConnTimeout = timeout + } + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), trFunc) - if _, err := cst.c.Get(cst.ts.URL); err != nil { - t.Fatalf("got error: %s", err) - } + retry = false + tooShort := func(err error) bool { + if err == nil || !strings.Contains(err.Error(), "use of closed network connection") { + return false + } + if !retry { + t.Helper() + t.Logf("idle conn timeout %v may be too short; retrying with longer", timeout) + timeout *= 2 + retry = true + cst.close() + } + return true + } - time.Sleep(100 * time.Millisecond) - got := make(chan error) - go func() { if _, err := cst.c.Get(cst.ts.URL); err != nil { - got <- err + if tooShort(err) { + continue + } + t.Fatalf("got error: %s", err) } - close(got) - }() - timeout := time.NewTimer(5 * time.Second) - defer timeout.Stop() - select { - case err := <-got: - if err != nil { + time.Sleep(10 * timeout) + if _, err := cst.c.Get(cst.ts.URL); err != nil { + if tooShort(err) { + continue + } t.Fatalf("got error: %s", err) } - case <-timeout.C: - t.Fatal("request never completed") } } @@ -3969,9 +3975,13 @@ func testTransportRemovesH2ConnsAfterIdle(t *testing.T, mode testMode) { // golang.org/issue/8923 func TestTransportRangeAndGzip(t *testing.T) { run(t, testTransportRangeAndGzip) } func testTransportRangeAndGzip(t *testing.T, mode testMode) { - reqc := make(chan *Request, 1) ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { - reqc <- r + if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { + t.Error("Transport advertised gzip support in the Accept header") + } + if r.Header.Get("Range") == "" { + t.Error("no Range in request") + } })).ts c := ts.Client() @@ -3981,18 +3991,6 @@ func testTransportRangeAndGzip(t *testing.T, mode testMode) { if err != nil { t.Fatal(err) } - - select { - case r := <-reqc: - if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { - t.Error("Transport advertised gzip support in the Accept header") - } - if r.Header.Get("Range") == "" { - t.Error("no Range in request") - } - case <-time.After(10 * time.Second): - t.Fatal("timeout") - } res.Body.Close() } @@ -4091,6 +4089,45 @@ func testTransportDialCancelRace(t *testing.T, mode testMode) { } } +// https://go.dev/issue/49621 +func TestConnClosedBeforeRequestIsWritten(t *testing.T) { + run(t, testConnClosedBeforeRequestIsWritten, testNotParallel, []testMode{http1Mode}) +} +func testConnClosedBeforeRequestIsWritten(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), + func(tr *Transport) { + tr.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) { + // Connection immediately returns errors. + return &funcConn{ + read: func([]byte) (int, error) { + return 0, errors.New("error") + }, + write: func([]byte) (int, error) { + return 0, errors.New("error") + }, + }, nil + } + }, + ).ts + // Set a short delay in RoundTrip to give the persistConn time to notice + // the connection is broken. We want to exercise the path where writeLoop exits + // before it reads the request to send. If this delay is too short, we may instead + // exercise the path where writeLoop accepts the request and then fails to write it. + // That's fine, so long as we get the desired path often enough. 
+ SetEnterRoundTripHook(func() { + time.Sleep(1 * time.Millisecond) + }) + defer SetEnterRoundTripHook(nil) + var closes int + _, err := ts.Client().Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")}) + if err == nil { + t.Fatalf("expected request to fail, but it did not") + } + if closes != 1 { + t.Errorf("after RoundTrip, request body was closed %v times; want 1", closes) + } +} + // logWritesConn is a net.Conn that logs each Write call to writes // and then proxies to w. // It proxies Read calls to a reader it receives from rch. @@ -4207,15 +4244,26 @@ func testTransportFlushesRequestHeader(t *testing.T, mode testMode) { res.Body.Close() }() - select { - case <-gotReq: - pw.Close() - case <-time.After(5 * time.Second): - t.Fatal("timeout waiting for handler to get request") - } + <-gotReq + pw.Close() <-gotRes } +type wgReadCloser struct { + io.Reader + wg *sync.WaitGroup + closed bool +} + +func (c *wgReadCloser) Close() error { + if c.closed { + return net.ErrClosed + } + c.closed = true + c.wg.Done() + return nil +} + // Issue 11745. func TestTransportPrefersResponseOverWriteError(t *testing.T) { run(t, testTransportPrefersResponseOverWriteError) @@ -4237,12 +4285,29 @@ func testTransportPrefersResponseOverWriteError(t *testing.T, mode testMode) { fail := 0 count := 100 + bigBody := strings.Repeat("a", contentLengthLimit*2) + var wg sync.WaitGroup + defer wg.Wait() + getBody := func() (io.ReadCloser, error) { + wg.Add(1) + body := &wgReadCloser{ + Reader: strings.NewReader(bigBody), + wg: &wg, + } + return body, nil + } + for i := 0; i < count; i++ { - req, err := NewRequest("PUT", ts.URL, strings.NewReader(bigBody)) + reqBody, _ := getBody() + req, err := NewRequest("PUT", ts.URL, reqBody) if err != nil { + reqBody.Close() t.Fatal(err) } + req.ContentLength = int64(len(bigBody)) + req.GetBody = getBody + resp, err := c.Do(req) if err != nil { fail++ @@ -4642,7 +4707,7 @@ func TestTransportRejectsAlphaPort(t *testing.T) { } } -// Test the httptrace.TLSHandshake{Start,Done} hooks with a https http1 +// Test the httptrace.TLSHandshake{Start,Done} hooks with an https http1 // connections. The http2 test is done in TestTransportEventTrace_h2 func TestTLSHandshakeTrace(t *testing.T) { run(t, testTLSHandshakeTrace, []testMode{https1Mode, http2Mode}) @@ -4686,7 +4751,7 @@ func testTLSHandshakeTrace(t *testing.T, mode testMode) { t.Fatal("Expected TLSHandshakeStart to be called, but wasn't") } if !done { - t.Fatal("Expected TLSHandshakeDone to be called, but wasnt't") + t.Fatal("Expected TLSHandshakeDone to be called, but wasn't") } } @@ -4696,57 +4761,80 @@ func testTransportIdleConnTimeout(t *testing.T, mode testMode) { t.Skip("skipping in short mode") } - const timeout = 1 * time.Second - - cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { - // No body for convenience. - })) - tr := cst.tr - tr.IdleConnTimeout = timeout - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} + timeout := 1 * time.Millisecond +timeoutLoop: + for { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + // No body for convenience. 
+ })) + tr := cst.tr + tr.IdleConnTimeout = timeout + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} - idleConns := func() []string { - if mode == http2Mode { - return tr.IdleConnStrsForTesting_h2() - } else { - return tr.IdleConnStrsForTesting() + idleConns := func() []string { + if mode == http2Mode { + return tr.IdleConnStrsForTesting_h2() + } else { + return tr.IdleConnStrsForTesting() + } } - } - var conn string - doReq := func(n int) { - req, _ := NewRequest("GET", cst.ts.URL, nil) - req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ - PutIdleConn: func(err error) { - if err != nil { - t.Errorf("failed to keep idle conn: %v", err) + var conn string + doReq := func(n int) (timeoutOk bool) { + req, _ := NewRequest("GET", cst.ts.URL, nil) + req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ + PutIdleConn: func(err error) { + if err != nil { + t.Errorf("failed to keep idle conn: %v", err) + } + }, + })) + res, err := c.Do(req) + if err != nil { + if strings.Contains(err.Error(), "use of closed network connection") { + t.Logf("req %v: connection closed prematurely", n) + return false } - }, - })) - res, err := c.Do(req) - if err != nil { - t.Fatal(err) - } - res.Body.Close() - conns := idleConns() - if len(conns) != 1 { - t.Fatalf("req %v: unexpected number of idle conns: %q", n, conns) - } - if conn == "" { - conn = conns[0] + } + res.Body.Close() + conns := idleConns() + if len(conns) != 1 { + if len(conns) == 0 { + t.Logf("req %v: no idle conns", n) + return false + } + t.Fatalf("req %v: unexpected number of idle conns: %q", n, conns) + } + if conn == "" { + conn = conns[0] + } + if conn != conns[0] { + t.Logf("req %v: cached connection changed; expected the same one throughout the test", n) + return false + } + return true } - if conn != conns[0] { - t.Fatalf("req %v: cached connection changed; expected the same one throughout the test", n) + for i := 0; i < 3; i++ { + if !doReq(i) { + t.Logf("idle conn timeout %v appears to be too short; retrying with longer", timeout) + timeout *= 2 + cst.close() + continue timeoutLoop + } + time.Sleep(timeout / 2) } - } - for i := 0; i < 3; i++ { - doReq(i) - time.Sleep(timeout / 2) - } - time.Sleep(timeout * 3 / 2) - if got := idleConns(); len(got) != 0 { - t.Errorf("idle conns = %q; want none", got) + + waitCondition(t, timeout/2, func(d time.Duration) bool { + if got := idleConns(); len(got) != 0 { + if d >= timeout*3/2 { + t.Logf("after %v, idle conns = %q", d, got) + } + return false + } + return true + }) + break } } @@ -4792,13 +4880,9 @@ func testIdleConnH2Crash(t *testing.T, mode testMode) { cancel() - failTimer := time.NewTimer(5 * time.Second) - defer failTimer.Stop() select { case <-sawDoErr: case <-testDone: - case <-failTimer.C: - t.Error("timeout in DialTLS, waiting too long for cst.c.Do to fail") } return c, nil } @@ -4888,16 +4972,13 @@ func testTransportProxyConnectHeader(t *testing.T, mode testMode) { res.Body.Close() t.Errorf("unexpected success") } - select { - case <-time.After(3 * time.Second): - t.Fatal("timeout") - case r := <-reqc: - if got, want := r.Header.Get("User-Agent"), "foo"; got != want { - t.Errorf("CONNECT request User-Agent = %q; want %q", got, want) - } - if got, want := r.Header.Get("Other"), "bar"; got != want { - t.Errorf("CONNECT request Other = %q; want %q", got, want) - } + + r := <-reqc + if got, want := r.Header.Get("User-Agent"), "foo"; got != want { + t.Errorf("CONNECT request User-Agent = %q; want 
%q", got, want) + } + if got, want := r.Header.Get("Other"), "bar"; got != want { + t.Errorf("CONNECT request Other = %q; want %q", got, want) } } @@ -4940,16 +5021,13 @@ func testTransportProxyGetConnectHeader(t *testing.T, mode testMode) { res.Body.Close() t.Errorf("unexpected success") } - select { - case <-time.After(3 * time.Second): - t.Fatal("timeout") - case r := <-reqc: - if got, want := r.Header.Get("User-Agent"), "foo2"; got != want { - t.Errorf("CONNECT request User-Agent = %q; want %q", got, want) - } - if got, want := r.Header.Get("Other"), "bar2"; got != want { - t.Errorf("CONNECT request Other = %q; want %q", got, want) - } + + r := <-reqc + if got, want := r.Header.Get("User-Agent"), "foo2"; got != want { + t.Errorf("CONNECT request User-Agent = %q; want %q", got, want) + } + if got, want := r.Header.Get("Other"), "bar2"; got != want { + t.Errorf("CONNECT request Other = %q; want %q", got, want) } } @@ -5168,46 +5246,58 @@ func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) { run(t, testClientTimeoutKillsConn_BeforeHeaders, []testMode{http1Mode}) } func testClientTimeoutKillsConn_BeforeHeaders(t *testing.T, mode testMode) { - inHandler := make(chan net.Conn, 1) - handlerReadReturned := make(chan bool, 1) - cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { - conn, _, err := w.(Hijacker).Hijack() - if err != nil { - t.Error(err) - return - } - inHandler <- conn - n, err := conn.Read([]byte{0}) - if n != 0 || err != io.EOF { - t.Errorf("unexpected Read result: %v, %v", n, err) - } - handlerReadReturned <- true - })) + timeout := 1 * time.Millisecond + for { + inHandler := make(chan bool) + cancelHandler := make(chan struct{}) + handlerDone := make(chan bool) + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + <-r.Context().Done() - const timeout = 50 * time.Millisecond - cst.c.Timeout = timeout + select { + case <-cancelHandler: + return + case inHandler <- true: + } + defer func() { handlerDone <- true }() - _, err := cst.c.Get(cst.ts.URL) - if err == nil { - t.Fatal("unexpected Get succeess") - } + // Read from the conn until EOF to verify that it was correctly closed. + conn, _, err := w.(Hijacker).Hijack() + if err != nil { + t.Error(err) + return + } + n, err := conn.Read([]byte{0}) + if n != 0 || err != io.EOF { + t.Errorf("unexpected Read result: %v, %v", n, err) + } + conn.Close() + })) - select { - case c := <-inHandler: + cst.c.Timeout = timeout + + _, err := cst.c.Get(cst.ts.URL) + if err == nil { + close(cancelHandler) + t.Fatal("unexpected Get success") + } + + tooSlow := time.NewTimer(timeout * 10) select { - case <-handlerReadReturned: - // Success. - return - case <-time.After(5 * time.Second): - t.Error("Handler's conn.Read seems to be stuck in Read") - c.Close() // close it to unblock Handler + case <-tooSlow.C: + // If we didn't get into the Handler, that probably means the builder was + // just slow and the Get failed in that time but never made it to the + // server. That's fine; we'll try again with a longer timeout. + t.Logf("no handler seen in %v; retrying with longer timeout", timeout) + close(cancelHandler) + cst.close() + timeout *= 2 + continue + case <-inHandler: + tooSlow.Stop() + <-handlerDone } - case <-time.After(timeout * 10): - // If we didn't get into the Handler in 50ms, that probably means - // the builder was just slow and the Get failed in that time - // but never made it to the server. That's fine. We'll usually - // test the part above on faster machines. 
- t.Skip("skipping test on slow builder") + break } } @@ -5220,18 +5310,27 @@ func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) { run(t, testClientTimeoutKillsConn_AfterHeaders, []testMode{http1Mode}) } func testClientTimeoutKillsConn_AfterHeaders(t *testing.T, mode testMode) { - inHandler := make(chan net.Conn, 1) - handlerResult := make(chan error, 1) + inHandler := make(chan bool) + cancelHandler := make(chan struct{}) + handlerDone := make(chan bool) cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Length", "100") w.(Flusher).Flush() + + select { + case <-cancelHandler: + return + case inHandler <- true: + } + defer func() { handlerDone <- true }() + conn, _, err := w.(Hijacker).Hijack() if err != nil { t.Error(err) return } conn.Write([]byte("foo")) - inHandler <- conn + n, err := conn.Read([]byte{0}) // The error should be io.EOF or "read tcp // 127.0.0.1:35827->127.0.0.1:40290: read: connection @@ -5239,53 +5338,38 @@ func testClientTimeoutKillsConn_AfterHeaders(t *testing.T, mode testMode) { // care that it returns at all. But if it returns with // data, that's weird. if n != 0 || err == nil { - handlerResult <- fmt.Errorf("unexpected Read result: %v, %v", n, err) - return + t.Errorf("unexpected Read result: %v, %v", n, err) } - handlerResult <- nil + conn.Close() })) // Set Timeout to something very long but non-zero to exercise // the codepaths that check for it. But rather than wait for it to fire // (which would make the test slow), we send on the req.Cancel channel instead, // which happens to exercise the same code paths. - cst.c.Timeout = time.Minute // just to be non-zero, not to hit it. + cst.c.Timeout = 24 * time.Hour // just to be non-zero, not to hit it. req, _ := NewRequest("GET", cst.ts.URL, nil) - cancel := make(chan struct{}) - req.Cancel = cancel + cancelReq := make(chan struct{}) + req.Cancel = cancelReq res, err := cst.c.Do(req) if err != nil { - select { - case <-inHandler: - t.Fatalf("Get error: %v", err) - default: - // Failed before entering handler. Ignore result. - t.Skip("skipping test on slow builder") - } + close(cancelHandler) + t.Fatalf("Get error: %v", err) } - close(cancel) + // Cancel the request while the handler is still blocked on sending to the + // inHandler channel. Then read it until it fails, to verify that the + // connection is broken before the handler itself closes it. + close(cancelReq) got, err := io.ReadAll(res.Body) if err == nil { - t.Fatalf("unexpected success; read %q, nil", got) + t.Errorf("unexpected success; read %q, nil", got) } - select { - case c := <-inHandler: - select { - case err := <-handlerResult: - if err != nil { - t.Errorf("handler: %v", err) - } - return - case <-time.After(5 * time.Second): - t.Error("Handler's conn.Read seems to be stuck in Read") - c.Close() // close it to unblock Handler - } - case <-time.After(5 * time.Second): - t.Fatal("timeout") - } + // Now unblock the handler and wait for it to complete. 
+ <-inHandler + <-handlerDone } func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) { @@ -5721,17 +5805,18 @@ func testTransportIgnores408(t *testing.T, mode testMode) { t.Fatalf("got %q; want ok", slurp) } - t0 := time.Now() - for i := 0; i < 50; i++ { - time.Sleep(time.Duration(i) * 5 * time.Millisecond) - if cst.tr.IdleConnKeyCountForTesting() == 0 { - if got := logout.String(); got != "" { - t.Fatalf("expected no log output; got: %s", got) + waitCondition(t, 1*time.Millisecond, func(d time.Duration) bool { + if n := cst.tr.IdleConnKeyCountForTesting(); n != 0 { + if d > 0 { + t.Logf("%v idle conns still present after %v", n, d) } - return + return false } + return true + }) + if got := logout.String(); got != "" { + t.Fatalf("expected no log output; got: %s", got) } - t.Fatalf("timeout after %v waiting for Transport connections to die off", time.Since(t0)) } func TestInvalidHeaderResponse(t *testing.T) { @@ -6004,26 +6089,22 @@ func TestAltProtoCancellation(t *testing.T) { Transport: tr, Timeout: time.Millisecond, } - tr.RegisterProtocol("timeout", timeoutProto{}) - _, err := c.Get("timeout://bar.com/path") + tr.RegisterProtocol("cancel", cancelProto{}) + _, err := c.Get("cancel://bar.com/path") if err == nil { t.Error("request unexpectedly succeeded") - } else if !strings.Contains(err.Error(), timeoutProtoErr.Error()) { - t.Errorf("got error %q, does not contain expected string %q", err, timeoutProtoErr) + } else if !strings.Contains(err.Error(), errCancelProto.Error()) { + t.Errorf("got error %q, does not contain expected string %q", err, errCancelProto) } } -var timeoutProtoErr = errors.New("canceled as expected") +var errCancelProto = errors.New("canceled as expected") -type timeoutProto struct{} +type cancelProto struct{} -func (timeoutProto) RoundTrip(req *Request) (*Response, error) { - select { - case <-req.Cancel: - return nil, timeoutProtoErr - case <-time.After(5 * time.Second): - return nil, errors.New("request was not canceled") - } +func (cancelProto) RoundTrip(req *Request) (*Response, error) { + <-req.Cancel + return nil, errCancelProto } type roundTripFunc func(r *Request) (*Response, error) diff --git a/triv.go b/triv.go index fa6a249d..3b88227d 100644 --- a/triv.go +++ b/triv.go @@ -40,7 +40,7 @@ type Counter struct { func (ctr *Counter) String() string { ctr.mu.Lock() defer ctr.mu.Unlock() - return fmt.Sprintf("%d", ctr.n) + return strconv.Itoa(ctr.n) } func (ctr *Counter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
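
The transport.go hunk that wraps t.DialContext adds a small defensive check: a custom dial hook that returns (nil, nil) would otherwise hand the transport a nil connection that fails much later and far from the cause, so the hook's result is validated at the call site and converted into an explicit error. A minimal standalone sketch of that check follows; guardedDial and dialFunc are illustrative names for this note, not identifiers from this package.

    package main

    import (
        "context"
        "errors"
        "fmt"
        "net"
    )

    // dialFunc mirrors the shape of the Transport.DialContext hook.
    type dialFunc func(ctx context.Context, network, addr string) (net.Conn, error)

    // guardedDial calls a user-supplied dial hook and refuses to accept a
    // (nil, nil) result, turning it into an explicit error instead of letting
    // a nil connection propagate further into the transport.
    func guardedDial(ctx context.Context, dial dialFunc, network, addr string) (net.Conn, error) {
        c, err := dial(ctx, network, addr)
        if c == nil && err == nil {
            err = errors.New("dial hook returned (nil, nil)")
        }
        return c, err
    }

    func main() {
        buggy := func(ctx context.Context, network, addr string) (net.Conn, error) {
            return nil, nil // a broken hook: neither a connection nor an error
        }
        _, err := guardedDial(context.Background(), buggy, "tcp", "example.com:80")
        fmt.Println(err) // prints: dial hook returned (nil, nil)
    }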
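
The transport_test.go hunks above repeatedly replace fixed sleeps and five-second select timeouts with two related patterns: an elapsed-aware polling helper (invoked as waitCondition in the diff) and a retry loop that starts from a very short, machine-dependent timeout and doubles it whenever a run shows it was too short. The sketch below illustrates both under ordinary `go test` assumptions; pollUntil, TestTimingSketch, and the durations are invented for this note and are not helpers from this repository.

    package example

    import (
        "testing"
        "time"
    )

    // pollUntil re-checks cond every interval, passing the elapsed time so the
    // condition can log progress only after it has been waiting for a while.
    // It returns as soon as the condition holds.
    func pollUntil(t *testing.T, interval time.Duration, cond func(elapsed time.Duration) bool) {
        t.Helper()
        start := time.Now()
        for !cond(time.Since(start)) {
            time.Sleep(interval)
        }
    }

    // TestTimingSketch doubles a deliberately tiny timeout until it is long
    // enough for the (simulated) slow operation on this machine, then polls
    // for a condition instead of sleeping a fixed amount.
    func TestTimingSketch(t *testing.T) {
        const slowOp = 10 * time.Millisecond // stands in for slow server-side work
        timeout := 1 * time.Millisecond
        for {
            completed := false
            select {
            case <-time.After(slowOp):
                completed = true
            case <-time.After(timeout):
            }
            if completed {
                break // the timeout was long enough; real assertions would go here
            }
            t.Logf("timeout %v too short for this machine; retrying with a longer one", timeout)
            timeout *= 2
        }

        deadline := time.Now().Add(5 * slowOp)
        pollUntil(t, time.Millisecond, func(d time.Duration) bool {
            done := time.Now().After(deadline) // real tests check e.g. that no idle conns remain
            if !done && d > slowOp {
                t.Logf("condition not yet true after %v", d)
            }
            return done
        })
    }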