diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 19d747e755..ce7866b702 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -40,11 +40,11 @@ jobs: shell: bash run: tar -czf maven-repo.tgz -C ~ .m2/repository - name: Persist Maven Repo - uses: actions/upload-artifact@v1 + uses: actions/upload-artifact@v4 with: name: maven-repo path: maven-repo.tgz - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v4 if: failure() with: name: surefire-reports-build @@ -110,7 +110,7 @@ jobs: hostname || true - uses: actions/checkout@v2 - name: Download Maven Repo - uses: actions/download-artifact@v1 + uses: actions/download-artifact@v4 with: name: maven-repo path: . @@ -130,7 +130,7 @@ jobs: run: mvn -v - name: Run Tests run: mvn -U -B -fae test -Pproxy '-DfailIfNoTests=false' -pl ${{ matrix.module }} - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v4 if: failure() with: name: surefire-reports-${{ matrix.jdk }}-${{ matrix.module }}-${{ matrix.os }} @@ -163,7 +163,7 @@ jobs: hostname || true - uses: actions/checkout@v2 - name: Download Maven Repo - uses: actions/download-artifact@v1 + uses: actions/download-artifact@v4 with: name: maven-repo path: . @@ -182,7 +182,7 @@ jobs: run: mvn -v - name: Run Tests run: mvn -U -B -fae test ${{ matrix.proxy }} '-DfailIfNoTests=false' -pl ${{ matrix.module }} -Dtest.ipv6=true - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v4 if: failure() with: name: surefire-reports-${{ matrix.jdk }}-ipv6-${{ matrix.module }}${{ matrix.proxy }}-${{ matrix.os }} diff --git a/core/src/main/java/io/undertow/UndertowMessages.java b/core/src/main/java/io/undertow/UndertowMessages.java index 4edcc69af0..e13f7a6a5a 100644 --- a/core/src/main/java/io/undertow/UndertowMessages.java +++ b/core/src/main/java/io/undertow/UndertowMessages.java @@ -643,4 +643,10 @@ public interface UndertowMessages { @Message(id = 207, value = "Invalid SNI hostname '%s'") IllegalArgumentException invalidSniHostname(String hostNameValue, @Cause Throwable t); + + // 208 placeholder + + @Message(id = 209, value = "Protocol string was too large for the buffer. Either provide a smaller message or a bigger buffer. Protocol: %s") + IllegalStateException protocolTooLargeForBuffer(String protocolString); + } diff --git a/core/src/main/java/io/undertow/attribute/StoredResponse.java b/core/src/main/java/io/undertow/attribute/StoredResponse.java index 9623c6d0e2..b610a0b438 100644 --- a/core/src/main/java/io/undertow/attribute/StoredResponse.java +++ b/core/src/main/java/io/undertow/attribute/StoredResponse.java @@ -66,6 +66,11 @@ private String extractCharset(HeaderMap headers) { if(contentType.startsWith("text/")) { return StandardCharsets.ISO_8859_1.displayName(); } + // json has no charset param: https://www.iana.org/assignments/media-types/application/json + // the default is UTF-8: https://www.rfc-editor.org/rfc/rfc7158#section-8.1 & https://www.rfc-editor.org/rfc/rfc8259#section-8.1 + if(contentType.equals("application/json")) { + return StandardCharsets.UTF_8.name(); + } return null; } return null; diff --git a/core/src/main/java/io/undertow/protocols/http2/Http2Channel.java b/core/src/main/java/io/undertow/protocols/http2/Http2Channel.java index e95a418860..fac9921b8a 100644 --- a/core/src/main/java/io/undertow/protocols/http2/Http2Channel.java +++ b/core/src/main/java/io/undertow/protocols/http2/Http2Channel.java @@ -1221,6 +1221,7 @@ private StreamHolder handleRstStream(int streamId, boolean receivedRst) { resetStreamTracker.store(streamId, holder); if(streamId % 2 == (isClient() ? 1 : 0)) { sendConcurrentStreamsAtomicUpdater.getAndDecrement(this); + holder.sent = true; } else { receiveConcurrentStreamsAtomicUpdater.getAndDecrement(this); } @@ -1235,13 +1236,14 @@ private StreamHolder handleRstStream(int streamId, boolean receivedRst) { //Server side originated, no input from client other than RST //this can happen on page refresh when push happens, but client //still has valid cache entry + //NOTE: this is specific case when its set. holder.resetByPeer = receivedRst; } else { handleRstWindow(); } } } else if(receivedRst){ - final StreamHolder resetStream = resetStreamTracker.find(streamId); + final StreamHolder resetStream = resetStreamTracker.find(streamId, true); if(resetStream != null && resetStream.resetByPeer) { //This means other side reset stream at some point. //depending on peer or network latency our frames might be late and @@ -1374,6 +1376,10 @@ private static final class StreamHolder { * This flag is set only in case of short lived server push that was reset by remote end. */ boolean resetByPeer = false; + /** + * flag indicate whether our side originated. This is done for caching purposes, handlng differs. + */ + boolean sent = false; Http2StreamSourceChannel sourceChannel; Http2StreamSinkChannel sinkChannel; @@ -1384,6 +1390,13 @@ private static final class StreamHolder { StreamHolder(Http2StreamSinkChannel sinkChannel) { this.sinkChannel = sinkChannel; } + + @Override + public String toString() { + return "StreamHolder [sourceClosed=" + sourceClosed + ", sinkClosed=" + sinkClosed + ", resetByPeer=" + resetByPeer + + ", sent=" + sent + ", sourceChannel=" + sourceChannel + ", sinkChannel=" + sinkChannel + "]"; + } + } // cache that keeps track of streams until they can be evicted @see Http2Channel#RST_STREAM_EVICATION_TIME @@ -1399,12 +1412,27 @@ private void store(int streamId, StreamHolder streamHolder) { streamHolders.put(streamId, streamHolder); entries.add(new StreamCacheEntry(streamId)); } - private StreamHolder find(int streamId) { + + /** + * Method will return only sent + * @param streamId + * @return + */ + private StreamHolder find(final int streamId) { + return find(streamId, false); + } + + private StreamHolder find(final int streamId, final boolean all) { for (Iterator iterator = entries.iterator(); iterator.hasNext();) { StreamCacheEntry entry = iterator.next(); if (entry.shouldEvict()) { iterator.remove(); StreamHolder holder = streamHolders.remove(entry.streamId); + if(!holder.sent || holder.resetByPeer) { + //if its not our end of chain, its just cached, so we only cache for purpose of + // handling eager RST + continue; + } AbstractHttp2StreamSourceChannel receiver = holder.sourceChannel; if(receiver != null) { IoUtils.safeClose(receiver); @@ -1418,7 +1446,12 @@ private StreamHolder find(int streamId) { } } else break; } - return streamHolders.get(streamId); + final StreamHolder holder = streamHolders.get(streamId); + if(holder != null && (!all && !holder.sent)) { + return null; + } else { + return holder; + } } private Map getStreamHolders() { diff --git a/core/src/main/java/io/undertow/server/HttpServerExchange.java b/core/src/main/java/io/undertow/server/HttpServerExchange.java index 6210a8c429..a8b646a791 100644 --- a/core/src/main/java/io/undertow/server/HttpServerExchange.java +++ b/core/src/main/java/io/undertow/server/HttpServerExchange.java @@ -78,6 +78,7 @@ import java.util.TreeMap; import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import static org.xnio.Bits.allAreSet; import static org.xnio.Bits.anyAreClear; @@ -165,7 +166,8 @@ public final class HttpServerExchange extends AbstractAttachable { // mutable state - private int state = 200; + private volatile int state = 200; + private static final AtomicIntegerFieldUpdater stateUpdater = AtomicIntegerFieldUpdater.newUpdater(HttpServerExchange.class, "state"); private HttpString requestMethod = HttpString.EMPTY; private String requestScheme; @@ -480,9 +482,9 @@ public HttpServerExchange setRequestURI(final String requestURI) { public HttpServerExchange setRequestURI(final String requestURI, boolean containsHost) { this.requestURI = requestURI; if (containsHost) { - this.state |= FLAG_URI_CONTAINS_HOST; + setFlags(FLAG_URI_CONTAINS_HOST); } else { - this.state &= ~FLAG_URI_CONTAINS_HOST; + clearFlags(FLAG_URI_CONTAINS_HOST); } return this; } @@ -763,9 +765,9 @@ void updateBytesSent(long bytes) { public HttpServerExchange setPersistent(final boolean persistent) { if (persistent) { - this.state = this.state | FLAG_PERSISTENT; + setFlags(FLAG_PERSISTENT); } else { - this.state = this.state & ~FLAG_PERSISTENT; + clearFlags(FLAG_PERSISTENT); } return this; } @@ -775,7 +777,7 @@ public boolean isDispatched() { } public HttpServerExchange unDispatch() { - state &= ~FLAG_DISPATCHED; + clearFlags(FLAG_DISPATCHED); dispatchTask = null; return this; } @@ -789,7 +791,7 @@ public HttpServerExchange unDispatch() { */ @Deprecated public HttpServerExchange dispatch() { - state |= FLAG_DISPATCHED; + setFlags(FLAG_DISPATCHED); return this; } @@ -825,7 +827,7 @@ public HttpServerExchange dispatch(final Executor executor, final Runnable runna if (executor != null) { this.dispatchExecutor = executor; } - state |= FLAG_DISPATCHED; + setFlags(FLAG_DISPATCHED); if(anyAreSet(state, FLAG_SHOULD_RESUME_READS | FLAG_SHOULD_RESUME_WRITES)) { throw UndertowMessages.MESSAGES.resumedAndDispatched(); } @@ -892,9 +894,9 @@ boolean isInCall() { HttpServerExchange setInCall(boolean value) { if (value) { - state |= FLAG_IN_CALL; + setFlags(FLAG_IN_CALL); } else { - state &= ~FLAG_IN_CALL; + clearFlags(FLAG_IN_CALL); } return this; } @@ -1268,7 +1270,7 @@ public boolean isResponseStarted() { public StreamSourceChannel getRequestChannel() { if (requestChannel != null) { if(anyAreSet(state, FLAG_REQUEST_RESET)) { - state &= ~FLAG_REQUEST_RESET; + clearFlags(FLAG_REQUEST_RESET); return requestChannel; } return null; @@ -1287,7 +1289,7 @@ public StreamSourceChannel getRequestChannel() { } void resetRequestChannel() { - state |= FLAG_REQUEST_RESET; + setFlags(FLAG_REQUEST_RESET); } public boolean isRequestChannelAvailable() { @@ -1328,8 +1330,7 @@ public boolean isResponseComplete() { * the socket or implement a transfer coding. */ void terminateRequest() { - int oldVal = state; - if (allAreSet(oldVal, FLAG_REQUEST_TERMINATED)) { + if (allAreSet(state, FLAG_REQUEST_TERMINATED)) { // idempotent return; } @@ -1337,8 +1338,8 @@ void terminateRequest() { requestChannel.suspendReads(); requestChannel.requestDone(); } - this.state = oldVal | FLAG_REQUEST_TERMINATED; - if (anyAreSet(oldVal, FLAG_RESPONSE_TERMINATED)) { + setFlags(FLAG_REQUEST_TERMINATED); + if (anyAreSet(state, FLAG_RESPONSE_TERMINATED)) { invokeExchangeCompleteListeners(); } } @@ -1476,8 +1477,7 @@ public HttpServerExchange setStatusCode(final int statusCode) { if (statusCode < 0 || statusCode > 999) { throw new IllegalArgumentException("Invalid response code"); } - int oldVal = state; - if (allAreSet(oldVal, FLAG_RESPONSE_SENT)) { + if (allAreSet(state, FLAG_RESPONSE_SENT)) { throw UndertowMessages.MESSAGES.responseAlreadyStarted(); } if(statusCode >= 500) { @@ -1485,7 +1485,8 @@ public HttpServerExchange setStatusCode(final int statusCode) { UndertowLogger.ERROR_RESPONSE.debugf(new RuntimeException(), "Setting error code %s for exchange %s", statusCode, this); } } - this.state = oldVal & ~MASK_RESPONSE_CODE | statusCode & MASK_RESPONSE_CODE; + clearFlags(MASK_RESPONSE_CODE); + setFlags(statusCode & MASK_RESPONSE_CODE); return this; } @@ -1625,8 +1626,7 @@ public OutputStream getOutputStream() { * the socket or implement a transfer coding. */ HttpServerExchange terminateResponse() { - int oldVal = state; - if (allAreSet(oldVal, FLAG_RESPONSE_TERMINATED)) { + if (allAreSet(state, FLAG_RESPONSE_TERMINATED)) { // idempotent return this; } @@ -1634,8 +1634,8 @@ HttpServerExchange terminateResponse() { responseChannel.suspendWrites(); responseChannel.responseDone(); } - this.state = oldVal | FLAG_RESPONSE_TERMINATED; - if (anyAreSet(oldVal, FLAG_REQUEST_TERMINATED)) { + setFlags(FLAG_RESPONSE_TERMINATED); + if (anyAreSet(state, FLAG_REQUEST_TERMINATED)) { invokeExchangeCompleteListeners(); } return this; @@ -1870,11 +1870,10 @@ public void handleException(final Channel channel, final IOException exception) * @throws IllegalStateException if the response headers were already sent */ HttpServerExchange startResponse() throws IllegalStateException { - int oldVal = state; - if (allAreSet(oldVal, FLAG_RESPONSE_SENT)) { + if (allAreSet(state, FLAG_RESPONSE_SENT)) { throw UndertowMessages.MESSAGES.responseAlreadyStarted(); } - this.state = oldVal | FLAG_RESPONSE_SENT; + setFlags(FLAG_RESPONSE_SENT); log.tracef("Starting to write response for %s", this); return this; @@ -2059,7 +2058,7 @@ protected boolean isFinished() { @Override public void resumeWrites() { if (isInCall()) { - state |= FLAG_SHOULD_RESUME_WRITES; + setFlags(FLAG_SHOULD_RESUME_WRITES); if(anyAreSet(state, FLAG_DISPATCHED)) { throw UndertowMessages.MESSAGES.resumedAndDispatched(); } @@ -2070,7 +2069,7 @@ public void resumeWrites() { @Override public void suspendWrites() { - state &= ~FLAG_SHOULD_RESUME_WRITES; + clearFlags(FLAG_SHOULD_RESUME_WRITES); super.suspendWrites(); } @@ -2081,7 +2080,7 @@ public void wakeupWrites() { } if (isInCall()) { wakeup = true; - state |= FLAG_SHOULD_RESUME_WRITES; + setFlags(FLAG_SHOULD_RESUME_WRITES); if(anyAreSet(state, FLAG_DISPATCHED)) { throw UndertowMessages.MESSAGES.resumedAndDispatched(); } @@ -2102,10 +2101,10 @@ public void runResume() { } else { if (wakeup) { wakeup = false; - state &= ~FLAG_SHOULD_RESUME_WRITES; + clearFlags(FLAG_SHOULD_RESUME_WRITES); delegate.wakeupWrites(); } else { - state &= ~FLAG_SHOULD_RESUME_WRITES; + clearFlags(FLAG_SHOULD_RESUME_WRITES); delegate.resumeWrites(); } } @@ -2232,7 +2231,7 @@ protected boolean isFinished() { public void resumeReads() { readsResumed = true; if (isInCall()) { - state |= FLAG_SHOULD_RESUME_READS; + setFlags(FLAG_SHOULD_RESUME_READS); if(anyAreSet(state, FLAG_DISPATCHED)) { throw UndertowMessages.MESSAGES.resumedAndDispatched(); } @@ -2245,7 +2244,7 @@ public void resumeReads() { public void wakeupReads() { if (isInCall()) { wakeup = true; - state |= FLAG_SHOULD_RESUME_READS; + setFlags(FLAG_SHOULD_RESUME_READS); if(anyAreSet(state, FLAG_DISPATCHED)) { throw UndertowMessages.MESSAGES.resumedAndDispatched(); } @@ -2302,7 +2301,7 @@ public void awaitReadable() throws IOException { @Override public void suspendReads() { readsResumed = false; - state &= ~(FLAG_SHOULD_RESUME_READS); + clearFlags(FLAG_SHOULD_RESUME_READS); super.suspendReads(); } @@ -2472,10 +2471,10 @@ public void runResume() { } else { if (wakeup) { wakeup = false; - state &= ~FLAG_SHOULD_RESUME_READS; + clearFlags(FLAG_SHOULD_RESUME_READS); delegate.wakeupReads(); } else { - state &= ~FLAG_SHOULD_RESUME_READS; + clearFlags(FLAG_SHOULD_RESUME_READS); delegate.resumeReads(); } } @@ -2540,4 +2539,18 @@ public T create() { public String toString() { return "HttpServerExchange{ " + getRequestMethod().toString() + " " + getRequestURI() + '}'; } + + private void setFlags(int flags) { + int old; + do { + old = state; + } while (!stateUpdater.compareAndSet(this, old, old | flags)); + } + + private void clearFlags(int flags) { + int old; + do { + old = state; + } while (!stateUpdater.compareAndSet(this, old, old & ~flags)); + } } diff --git a/core/src/main/java/io/undertow/server/protocol/http/HttpResponseConduit.java b/core/src/main/java/io/undertow/server/protocol/http/HttpResponseConduit.java index 0019f26799..f8fdfe916c 100644 --- a/core/src/main/java/io/undertow/server/protocol/http/HttpResponseConduit.java +++ b/core/src/main/java/io/undertow/server/protocol/http/HttpResponseConduit.java @@ -179,8 +179,20 @@ private int processWrite(int state, final Object userData, int pos, int length) // we don't have a dangling flag that won't be cleared at the finally block this.state |= POOLED_BUFFER_IN_USE; assert buffer.remaining() >= 50; - // append response status and headers - Protocols.HTTP_1_1.appendTo(buffer); + // append protocol + HttpString protocol = exchange.getProtocol(); + String protocolString = protocol.toString(); + if (protocolString.isEmpty()) { + protocol = Protocols.HTTP_1_1; + } + if (protocol.length() > buffer.remaining()) { + pooledBuffer.close(); + pooledBuffer = null; + truncateWrites(); + throw UndertowMessages.MESSAGES.protocolTooLargeForBuffer(protocolString); + } + protocol.appendTo(buffer); + // append status code, reason phrase, and headers buffer.put((byte) ' '); int code = exchange.getStatusCode(); assert 999 >= code && code >= 100; diff --git a/core/src/main/java/io/undertow/server/protocol/proxy/ProxyProtocolReadListener.java b/core/src/main/java/io/undertow/server/protocol/proxy/ProxyProtocolReadListener.java index 2d3f4715ff..08cc48de52 100644 --- a/core/src/main/java/io/undertow/server/protocol/proxy/ProxyProtocolReadListener.java +++ b/core/src/main/java/io/undertow/server/protocol/proxy/ProxyProtocolReadListener.java @@ -49,6 +49,7 @@ class ProxyProtocolReadListener implements ChannelListener private final UndertowXnioSsl ssl; private final ByteBufferPool bufferPool; private final OptionMap sslOptionMap; + private final StringBuilder stringBuilder = new StringBuilder(); private int byteCount; private String protocol; @@ -222,7 +223,6 @@ private void parseProxyProtocolV2(PooledByteBuffer buffer, AtomicBoolean freeBuf } private void parseProxyProtocolV1(PooledByteBuffer buffer, AtomicBoolean freeBuffer) throws Exception { - final StringBuilder stringBuilder = new StringBuilder(); while (buffer.getBuffer().hasRemaining()) { char c = (char) buffer.getBuffer().get(); if (byteCount < NAME.length) { @@ -281,31 +281,46 @@ private void parseProxyProtocolV1(PooledByteBuffer buffer, AtomicBoolean freeBuf throw UndertowMessages.MESSAGES.invalidProxyHeader(); } } else if (sourceAddress == null) { - sourceAddress = parseAddress(stringBuilder.toString(), protocol); - stringBuilder.setLength(0); + try { + sourceAddress = parseAddress(stringBuilder.toString(), protocol); + } finally { + stringBuilder.setLength(0); + } } else if (destAddress == null) { - destAddress = parseAddress(stringBuilder.toString(), protocol); - stringBuilder.setLength(0); + try { + destAddress = parseAddress(stringBuilder.toString(), protocol); + } finally { + stringBuilder.setLength(0); + } } else { - sourcePort = Integer.parseInt(stringBuilder.toString()); - stringBuilder.setLength(0); + try { + sourcePort = Integer.parseInt(stringBuilder.toString()); + } finally { + stringBuilder.setLength(0); + } } break; case '\r': if (destPort == -1 && sourcePort != -1 && !carriageReturnSeen && stringBuilder.length() > 0) { - destPort = Integer.parseInt(stringBuilder.toString()); - stringBuilder.setLength(0); + try { + destPort = Integer.parseInt(stringBuilder.toString()); + } finally { + stringBuilder.setLength(0); + } carriageReturnSeen = true; } else if (protocol == null) { if (UNKNOWN.equals(stringBuilder.toString())) { parsingUnknown = true; carriageReturnSeen = true; } + stringBuilder.setLength(0); } else { + stringBuilder.setLength(0); throw UndertowMessages.MESSAGES.invalidProxyHeader(); } break; case '\n': + stringBuilder.setLength(0); throw UndertowMessages.MESSAGES.invalidProxyHeader(); default: stringBuilder.append(c); diff --git a/core/src/test/java/io/undertow/protocols/http2/PushResourceRSTTestCase.java b/core/src/test/java/io/undertow/protocols/http2/PushResourceRSTTestCase.java index 415614bc45..faa2b94fd9 100644 --- a/core/src/test/java/io/undertow/protocols/http2/PushResourceRSTTestCase.java +++ b/core/src/test/java/io/undertow/protocols/http2/PushResourceRSTTestCase.java @@ -17,41 +17,6 @@ */ package io.undertow.protocols.http2; -import static io.undertow.server.protocol.http2.Http2OpenListener.HTTP2; -import static io.undertow.testutils.StopServerWithExternalWorkerUtils.stopWorker; -import static java.security.AccessController.doPrivileged; - -import java.io.IOException; -import java.net.InetSocketAddress; -import java.net.URI; -import java.security.PrivilegedAction; -import java.util.List; -import java.util.ServiceLoader; -import java.util.concurrent.CopyOnWriteArrayList; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicInteger; - -import org.junit.AfterClass; -import org.junit.Assert; -import org.junit.Assume; -import org.junit.BeforeClass; -import org.junit.Test; -import org.junit.experimental.categories.Category; -import org.junit.runner.RunWith; -import org.xnio.ChannelListener; -import org.xnio.ChannelListeners; -import org.xnio.FutureResult; -import org.xnio.IoFuture; -import org.xnio.IoUtils; -import org.xnio.OptionMap; -import org.xnio.Options; -import org.xnio.StreamConnection; -import org.xnio.Xnio; -import org.xnio.XnioWorker; -import org.xnio.channels.StreamSinkChannel; -import org.xnio.ssl.SslConnection; - import io.undertow.Undertow; import io.undertow.UndertowLogger; import io.undertow.UndertowOptions; @@ -70,15 +35,49 @@ import io.undertow.protocols.ssl.UndertowXnioSsl; import io.undertow.server.HttpServerExchange; import io.undertow.server.handlers.PathHandler; -import io.undertow.testutils.AjpIgnore; import io.undertow.testutils.DefaultServer; -import io.undertow.testutils.ProxyIgnore; import io.undertow.testutils.category.UnitTest; import io.undertow.util.AttachmentKey; import io.undertow.util.Headers; import io.undertow.util.Methods; import io.undertow.util.StatusCodes; import io.undertow.util.StringReadChannelListener; +import org.jboss.logging.Logger; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.xnio.ChannelListener; +import org.xnio.ChannelListeners; +import org.xnio.FutureResult; +import org.xnio.IoFuture; +import org.xnio.IoUtils; +import org.xnio.OptionMap; +import org.xnio.Options; +import org.xnio.StreamConnection; +import org.xnio.Xnio; +import org.xnio.XnioWorker; +import org.xnio.channels.StreamSinkChannel; +import org.xnio.ssl.SslConnection; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.URI; +import java.security.PrivilegedAction; +import java.util.List; +import java.util.ServiceLoader; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import static io.undertow.server.protocol.http2.Http2OpenListener.HTTP2; +import static io.undertow.testutils.StopServerWithExternalWorkerUtils.stopWorker; +import static java.security.AccessController.doPrivileged; /** * Test RST frames handling on push. This test mimics rapid refresh on client side, which will result in push requests from @@ -87,9 +86,9 @@ */ @Category(UnitTest.class) @RunWith(DefaultServer.class) -@ProxyIgnore -@AjpIgnore +@Ignore public class PushResourceRSTTestCase { + private static final Logger log = Logger.getLogger(PushResourceRSTTestCase.class); private static final String PUSHER = "/pusher"; private static final String PUSHER_MSG; private static final String TRIGGER = "/trigger"; @@ -240,18 +239,30 @@ private ClientCallback createClientCallback(final List() { @Override - public void completed(ClientExchange result) { + public void completed(final ClientExchange result) { result.setPushHandler(new PushCallback() { @Override public boolean handlePush(ClientExchange originalRequest, ClientExchange pushedRequest) { pushRstCount.incrementAndGet(); + log.debugf("Handling push %d", pushRstCount.get()); latch.countDown(); + setUpResponseListenerAndShutdownWrites(result); return false; } }); + } + @Override + public void failed(IOException e) { + e.printStackTrace(); + exception = e; + latch.countDown(); + } + + private void setUpResponseListenerAndShutdownWrites(ClientExchange result) { result.setResponseListener(new ClientCallback() { @Override public void completed(final ClientExchange result) { + log.debugf("Got result %s", result); responses.add(result.getResponse()); new StringReadChannelListener(result.getConnection().getBufferPool()) { @@ -290,13 +301,6 @@ public void failed(IOException e) { latch.countDown(); } } - - @Override - public void failed(IOException e) { - e.printStackTrace(); - exception = e; - latch.countDown(); - } }; } } diff --git a/core/src/test/java/io/undertow/server/handlers/StatusLineTestCase.java b/core/src/test/java/io/undertow/server/handlers/StatusLineTestCase.java new file mode 100644 index 0000000000..ecfe4721c1 --- /dev/null +++ b/core/src/test/java/io/undertow/server/handlers/StatusLineTestCase.java @@ -0,0 +1,141 @@ +/* + * JBoss, Home of Professional Open Source. + * Copyright 2024 Red Hat, Inc., and individual contributors + * as indicated by the @author tags. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.undertow.server.handlers; + +import io.undertow.io.Sender; +import io.undertow.server.HttpHandler; +import io.undertow.server.HttpServerExchange; +import io.undertow.server.ServerConnection; +import io.undertow.testutils.DefaultServer; +import io.undertow.testutils.ProxyIgnore; +import io.undertow.testutils.TestHttpClient; +import io.undertow.util.Headers; +import io.undertow.util.HttpString; +import io.undertow.util.StatusCodes; +import org.apache.http.HttpResponse; +import org.apache.http.ProtocolVersion; +import org.apache.http.client.methods.HttpGet; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; + +import java.io.IOException; + +/** + * Tests that if the protocol is set to a value, that value is returned on the + * status line. + * + * @author Jeff Okamoto + */ +@RunWith(DefaultServer.class) +@ProxyIgnore +public class StatusLineTestCase { + + /* + * For the purposes of the test, the protocol name has to be "HTTP" because the test + * framework runs through a parser, and it rejects other strings. + */ + private static final String DEFAULT_PROTOCOL_NAME = "HTTP"; + private static final String DEFAULT_PROTOCOL_MAJOR = "1"; + private static final String DEFAULT_PROTOCOL_MINOR = "1"; + private static final String PROTOCOL_NAME = "HTTP"; + private static final String PROTOCOL_MAJOR = "3"; + private static final String PROTOCOL_MINOR = "4"; + private static final String PROTOCOL_STRING = PROTOCOL_NAME + "/" + PROTOCOL_MAJOR + "." + PROTOCOL_MINOR; + private static final String REASON_PHRASE = "Reason-Phrase"; + private static final String MESSAGE = "My HTTP Request!"; + + private static volatile ServerConnection connection; + + @Test + public void verifyStatusLine() throws IOException { + DefaultServer.setRootHandler(new HttpHandler() { + + @Override + public void handleRequest(final HttpServerExchange exchange) throws Exception { + if (connection == null) { + connection = exchange.getConnection(); + } else if (!DefaultServer.isAjp() && !DefaultServer.isProxy() && connection != exchange.getConnection()) { + Sender sender = exchange.getResponseSender(); + sender.send("Connection not persistent"); + return; + } + exchange.setProtocol(new HttpString(PROTOCOL_STRING)); + exchange.setReasonPhrase(REASON_PHRASE); + exchange.getResponseHeaders().put(Headers.CONTENT_LENGTH, MESSAGE.length() + ""); + final Sender sender = exchange.getResponseSender(); + sender.send(MESSAGE); + } + }); + + connection = null; + HttpGet get = new HttpGet(DefaultServer.getDefaultServerURL() + "/path"); + TestHttpClient client = new TestHttpClient(); + try { + HttpResponse result = client.execute(get); + Assert.assertEquals(StatusCodes.OK, result.getStatusLine().getStatusCode()); + + ProtocolVersion protocolVersion = result.getStatusLine().getProtocolVersion(); + Assert.assertEquals(PROTOCOL_NAME, protocolVersion.getProtocol()); + Assert.assertEquals(Integer.parseInt(PROTOCOL_MAJOR), protocolVersion.getMajor()); + Assert.assertEquals(Integer.parseInt(PROTOCOL_MINOR), protocolVersion.getMinor()); + + Assert.assertEquals(REASON_PHRASE, result.getStatusLine().getReasonPhrase()); + } finally { + client.getConnectionManager().shutdown(); + } + } + + @Test + public void verifyDefaultStatusLine() throws IOException { + DefaultServer.setRootHandler(new HttpHandler() { + + @Override + public void handleRequest(final HttpServerExchange exchange) throws Exception { + if (connection == null) { + connection = exchange.getConnection(); + } else if (!DefaultServer.isAjp() && !DefaultServer.isProxy() && connection != exchange.getConnection()) { + Sender sender = exchange.getResponseSender(); + sender.send("Connection not persistent"); + return; + } + exchange.getResponseHeaders().put(Headers.CONTENT_LENGTH, MESSAGE.length() + ""); + final Sender sender = exchange.getResponseSender(); + sender.send(MESSAGE); + } + }); + + connection = null; + HttpGet get = new HttpGet(DefaultServer.getDefaultServerURL() + "/path"); + TestHttpClient client = new TestHttpClient(); + try { + HttpResponse result = client.execute(get); + Assert.assertEquals(StatusCodes.OK, result.getStatusLine().getStatusCode()); + + ProtocolVersion protocolVersion = result.getStatusLine().getProtocolVersion(); + Assert.assertEquals(DEFAULT_PROTOCOL_NAME, protocolVersion.getProtocol()); + Assert.assertEquals(Integer.parseInt(DEFAULT_PROTOCOL_MAJOR), protocolVersion.getMajor()); + Assert.assertEquals(Integer.parseInt(DEFAULT_PROTOCOL_MINOR), protocolVersion.getMinor()); + + } finally { + client.getConnectionManager().shutdown(); + } + } + +} diff --git a/servlet/src/main/java/io/undertow/servlet/spec/HttpServletRequestImpl.java b/servlet/src/main/java/io/undertow/servlet/spec/HttpServletRequestImpl.java index dcd3781f66..d878e1922c 100644 --- a/servlet/src/main/java/io/undertow/servlet/spec/HttpServletRequestImpl.java +++ b/servlet/src/main/java/io/undertow/servlet/spec/HttpServletRequestImpl.java @@ -533,7 +533,6 @@ public void logout() throws ServletException { @Override public Collection getParts() throws IOException, ServletException { - verifyMultipartServlet(); if (parts == null) { loadParts(); } @@ -550,11 +549,7 @@ private void verifyMultipartServlet() { @Override public Part getPart(final String name) throws IOException, ServletException { - verifyMultipartServlet(); - if (parts == null) { - loadParts(); - } - for (Part part : parts) { + for (Part part : getParts()) { if (part.getName().equals(name)) { return part; } @@ -580,6 +575,7 @@ private void loadParts() throws IOException, ServletException { final ServletRequestContext requestContext = exchange.getAttachment(ServletRequestContext.ATTACHMENT_KEY); if (parts == null) { + verifyMultipartServlet(); final List parts = new ArrayList<>(); String mimeType = exchange.getRequestHeaders().getFirst(Headers.CONTENT_TYPE); if (mimeType != null && mimeType.startsWith(MultiPartParserDefinition.MULTIPART_FORM_DATA)) { diff --git a/servlet/src/main/java/io/undertow/servlet/spec/ServletPrintWriter.java b/servlet/src/main/java/io/undertow/servlet/spec/ServletPrintWriter.java index e9d6a13899..eb5cf1a08d 100644 --- a/servlet/src/main/java/io/undertow/servlet/spec/ServletPrintWriter.java +++ b/servlet/src/main/java/io/undertow/servlet/spec/ServletPrintWriter.java @@ -172,8 +172,10 @@ public void write(final CharBuffer input) { remainingContentLength -= writtenLength; outputStream.updateWritten(writtenLength); if (result.isOverflow() || !buffer.hasRemaining()) { + final int remainingBytesBeforeFlush = buffer.remaining(); outputStream.flushInternal(); - if (buffer.remaining() == remaining) { + if (buffer.remaining() == remainingBytesBeforeFlush) { + // no progress has been made, set error to true error = true; return; }