diff --git a/okhttp/src/commonJvmAndroid/kotlin/okhttp3/Request.kt b/okhttp/src/commonJvmAndroid/kotlin/okhttp3/Request.kt index 34773fbab..8219aeb5a 100644 --- a/okhttp/src/commonJvmAndroid/kotlin/okhttp3/Request.kt +++ b/okhttp/src/commonJvmAndroid/kotlin/okhttp3/Request.kt @@ -83,15 +83,6 @@ class Request internal constructor( ), ) - init { - val connectionHeader = headers["Connection"] - if ("upgrade".equals(connectionHeader, ignoreCase = true)) { - require(body == null || body.contentLength() == 0L) { - "expected a null or empty request body with 'Connection: upgrade'" - } - } - } - fun header(name: String): String? = headers[name] fun headers(name: String): List = headers.values(name) diff --git a/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/connection/Exchange.kt b/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/connection/Exchange.kt index 4484ca66e..64a902a91 100644 --- a/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/connection/Exchange.kt +++ b/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/connection/Exchange.kt @@ -46,10 +46,6 @@ class Exchange( internal var isDuplex: Boolean = false private set - /** True if the request body should not be used, but the socket, instead. */ - internal var isSocket: Boolean = false - private set - /** True if there was an exception on the connection to the peer. */ internal var hasFailure: Boolean = false private set @@ -82,7 +78,11 @@ class Exchange( val contentLength = request.body!!.contentLength() eventListener.requestBodyStart(call) val rawRequestBody = codec.createRequestBody(request, contentLength) - return RequestBodySink(rawRequestBody, contentLength) + return RequestBodySink( + delegate = rawRequestBody, + contentLength = contentLength, + isSocket = false, + ) } @Throws(IOException::class) @@ -134,7 +134,12 @@ class Exchange( val contentType = response.header("Content-Type") val contentLength = codec.reportedContentLength(response) val rawSource = codec.openResponseBodySource(response) - val source = ResponseBodySource(rawSource, contentLength) + val source = + ResponseBodySource( + delegate = rawSource, + contentLength = contentLength, + isSocket = false, + ) return RealResponseBody(contentType, contentLength, source.buffer()) } catch (e: IOException) { eventListener.responseFailed(call, e) @@ -147,8 +152,7 @@ class Exchange( fun peekTrailers(): Headers? = codec.peekTrailers() fun upgradeToSocket(): Socket { - isSocket = true - call.timeoutEarlyExit() + call.upgradeToSocket() (codec.carrier as RealConnection).useAsSocket() return object : Socket { @@ -156,8 +160,18 @@ class Exchange( this@Exchange.cancel() } - override val sink = RequestBodySink(codec.socket.sink, -1L) - override val source = ResponseBodySource(codec.socket.source, -1L) + override val sink = + RequestBodySink( + delegate = codec.socket.sink, + contentLength = -1L, + isSocket = true, + ) + override val source = + ResponseBodySource( + delegate = codec.socket.source, + contentLength = -1L, + isSocket = true, + ) } } @@ -179,6 +193,8 @@ class Exchange( exchange = this, requestDone = true, responseDone = true, + socketSinkDone = true, + socketSourceDone = true, e = null, ) } @@ -191,6 +207,7 @@ class Exchange( /** If [e] is non-null, this will return a non-null value. */ fun bodyComplete( bytesRead: Long = -1L, + isSocket: Boolean, responseDone: Boolean = false, requestDone: Boolean = false, e: IOException?, @@ -214,8 +231,10 @@ class Exchange( } return call.messageDone( exchange = this, - requestDone = requestDone, - responseDone = responseDone, + requestDone = requestDone && !isSocket, + responseDone = responseDone && !isSocket, + socketSinkDone = requestDone && isSocket, + socketSourceDone = responseDone && isSocket, e = e, ) } @@ -233,6 +252,7 @@ class Exchange( delegate: Sink, /** The exact number of bytes to be written, or -1L if that is unknown. */ private val contentLength: Long, + private val isSocket: Boolean, ) : ForwardingSink(delegate) { private var completed = false private var bytesReceived = 0L @@ -292,6 +312,7 @@ class Exchange( completed = true return bodyComplete( bytesRead = bytesReceived, + isSocket = isSocket, requestDone = true, e = e, ) @@ -302,6 +323,7 @@ class Exchange( internal inner class ResponseBodySource( delegate: Source, private val contentLength: Long, + private val isSocket: Boolean, ) : ForwardingSource(delegate) { private var bytesReceived = 0L private var invokeStartEvent = true @@ -372,6 +394,7 @@ class Exchange( } return bodyComplete( bytesRead = bytesReceived, + isSocket = isSocket, responseDone = true, e = e, ) diff --git a/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/connection/RealCall.kt b/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/connection/RealCall.kt index 2c03fceea..ad13414a1 100644 --- a/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/connection/RealCall.kt +++ b/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/connection/RealCall.kt @@ -106,11 +106,10 @@ class RealCall( // These properties are guarded by `this`. They are typically only accessed by the thread executing // the call, but they may be accessed by other threads for duplex requests. - /** True if this call still has a request body open. */ private var requestBodyOpen = false - - /** True if this call still has a response body open. */ private var responseBodyOpen = false + private var socketSinkOpen = false + private var socketSourceOpen = false /** True if there are more exchanges expected for this call. */ private var expectMoreExchanges = true @@ -260,7 +259,7 @@ class RealCall( "cannot make a new request because the previous response is still open: " + "please call response.close()" } - check(!requestBodyOpen) + check(!requestBodyOpen && !socketSourceOpen && !socketSinkOpen) } if (newRoutePlanner) { @@ -292,8 +291,7 @@ class RealCall( internal fun initExchange(chain: RealInterceptorChain): Exchange { withLock { check(expectMoreExchanges) { "released" } - check(!responseBodyOpen) - check(!requestBodyOpen) + check(!responseBodyOpen && !requestBodyOpen && !socketSourceOpen && !socketSinkOpen) } val exchangeFinder = this.exchangeFinder!! @@ -331,22 +329,34 @@ class RealCall( exchange: Exchange, requestDone: Boolean = false, responseDone: Boolean = false, + socketSourceDone: Boolean = false, + socketSinkDone: Boolean = false, e: IOException?, ): IOException? { if (exchange != this.exchange) return e // This exchange was detached violently! - var bothStreamsDone = false + var allStreamsDone = false var callDone = false withLock { - if (requestDone && requestBodyOpen || responseDone && responseBodyOpen) { + if ( + requestDone && requestBodyOpen || + responseDone && responseBodyOpen || + socketSinkDone && socketSinkOpen || + socketSourceDone && socketSourceOpen + ) { if (requestDone) requestBodyOpen = false if (responseDone) responseBodyOpen = false - bothStreamsDone = !requestBodyOpen && !responseBodyOpen - callDone = !requestBodyOpen && !responseBodyOpen && !expectMoreExchanges + if (socketSinkDone) socketSinkOpen = false + if (socketSourceDone) socketSourceOpen = false + allStreamsDone = !requestBodyOpen && + !responseBodyOpen && + !socketSinkOpen && + !socketSourceOpen + callDone = allStreamsDone && !expectMoreExchanges } } - if (bothStreamsDone) { + if (allStreamsDone) { this.exchange = null this.connection?.incrementSuccessCount() } @@ -363,7 +373,7 @@ class RealCall( withLock { if (expectMoreExchanges) { expectMoreExchanges = false - callDone = !requestBodyOpen && !responseBodyOpen + callDone = !requestBodyOpen && !responseBodyOpen && !socketSinkOpen && !socketSourceOpen } } @@ -376,7 +386,8 @@ class RealCall( /** * Complete this call. This should be called once these properties are all false: - * [requestBodyOpen], [responseBodyOpen], and [expectMoreExchanges]. + * [requestBodyOpen], [responseBodyOpen], [socketSinkOpen], [socketSourceOpen], and + * [expectMoreExchanges]. * * This will release the connection if it is still held. * @@ -462,6 +473,20 @@ class RealCall( timeout.exit() } + fun upgradeToSocket() { + timeoutEarlyExit() + + withLock { + check(exchange != null) + check(!socketSinkOpen && !socketSourceOpen) + check(!requestBodyOpen) + check(responseBodyOpen) + responseBodyOpen = false + socketSinkOpen = true + socketSourceOpen = true + } + } + /** * @param closeExchange true if the current exchange should be closed because it will not be used. * This is usually due to either an exception or a retry. diff --git a/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/http/CallServerInterceptor.kt b/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/http/CallServerInterceptor.kt index 8ca73379d..4451f8ca2 100644 --- a/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/http/CallServerInterceptor.kt +++ b/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/http/CallServerInterceptor.kt @@ -44,41 +44,39 @@ object CallServerInterceptor : Interceptor { try { exchange.writeRequestHeaders(request) - if (!isUpgradeRequest) { - if (hasRequestBody) { - // If there's a "Expect: 100-continue" header on the request, wait for a "HTTP/1.1 100 - // Continue" response before transmitting the request body. If we don't get that, return - // what we did get (such as a 4xx response) without ever transmitting the request body. - if ("100-continue".equals(request.header("Expect"), ignoreCase = true)) { + if (hasRequestBody) { + // If there's a "Expect: 100-continue" header on the request, wait for a "HTTP/1.1 100 + // Continue" response before transmitting the request body. If we don't get that, return + // what we did get (such as a 4xx response) without ever transmitting the request body. + if ("100-continue".equals(request.header("Expect"), ignoreCase = true)) { + exchange.flushRequest() + responseBuilder = exchange.readResponseHeaders(expectContinue = true) + exchange.responseHeadersStart() + invokeStartEvent = false + } + if (responseBuilder == null) { + if (requestBody.isDuplex()) { + // Prepare a duplex body so that the application can send a request body later. exchange.flushRequest() - responseBuilder = exchange.readResponseHeaders(expectContinue = true) - exchange.responseHeadersStart() - invokeStartEvent = false - } - if (responseBuilder == null) { - if (requestBody.isDuplex()) { - // Prepare a duplex body so that the application can send a request body later. - exchange.flushRequest() - val bufferedRequestBody = exchange.createRequestBody(request, true).buffer() - requestBody.writeTo(bufferedRequestBody) - } else { - // Write the request body if the "Expect: 100-continue" expectation was met. - val bufferedRequestBody = exchange.createRequestBody(request, false).buffer() - requestBody.writeTo(bufferedRequestBody) - bufferedRequestBody.close() - } + val bufferedRequestBody = exchange.createRequestBody(request, true).buffer() + requestBody.writeTo(bufferedRequestBody) } else { - exchange.noRequestBody() - if (!exchange.connection.isMultiplexed) { - // If the "Expect: 100-continue" expectation wasn't met, prevent the HTTP/1 connection - // from being reused. Otherwise we're still obligated to transmit the request body to - // leave the connection in a consistent state. - exchange.noNewExchangesOnConnection() - } + // Write the request body if the "Expect: 100-continue" expectation was met. + val bufferedRequestBody = exchange.createRequestBody(request, false).buffer() + requestBody.writeTo(bufferedRequestBody) + bufferedRequestBody.close() } } else { exchange.noRequestBody() + if (!exchange.connection.isMultiplexed) { + // If the "Expect: 100-continue" expectation wasn't met, prevent the HTTP/1 connection + // from being reused. Otherwise we're still obligated to transmit the request body to + // leave the connection in a consistent state. + exchange.noNewExchangesOnConnection() + } } + } else { + exchange.noRequestBody() } if (requestBody == null || !requestBody.isDuplex()) { @@ -154,9 +152,6 @@ object CallServerInterceptor : Interceptor { // This is not an upgrade response. else -> { - if (isUpgradeRequest) { - exchange.noRequestBody() // Failed upgrade request has no outbound data. - } val responseBody = exchange.openResponseBody(response) response .newBuilder() diff --git a/okhttp/src/jvmTest/kotlin/okhttp3/internal/http/HttpUpgradesTest.kt b/okhttp/src/jvmTest/kotlin/okhttp3/internal/http/HttpUpgradesTest.kt index 57af599ec..daba5a2a9 100644 --- a/okhttp/src/jvmTest/kotlin/okhttp3/internal/http/HttpUpgradesTest.kt +++ b/okhttp/src/jvmTest/kotlin/okhttp3/internal/http/HttpUpgradesTest.kt @@ -17,8 +17,8 @@ package okhttp3.internal.http import assertk.assertThat import assertk.assertions.containsExactly -import assertk.assertions.hasMessage import assertk.assertions.isEqualTo +import assertk.assertions.isInstanceOf import assertk.assertions.isNull import assertk.assertions.isTrue import kotlin.test.assertFailsWith @@ -44,6 +44,8 @@ import okhttp3.CallEvent.ResponseBodyEnd import okhttp3.CallEvent.ResponseBodyStart import okhttp3.CallEvent.ResponseHeadersEnd import okhttp3.CallEvent.ResponseHeadersStart +import okhttp3.CallEvent.SecureConnectEnd +import okhttp3.CallEvent.SecureConnectStart import okhttp3.Headers.Companion.headersOf import okhttp3.OkHttpClientTestRule import okhttp3.Protocol @@ -122,10 +124,20 @@ class HttpUpgradesTest { } @Test - fun upgradeWithRequestBody() { + fun upgradeWithEmptyRequestBody() { executeAndCheckUpgrade(upgradeRequest().newBuilder().post(RequestBody.EMPTY).build()) } + @Test + fun upgradeWithNonEmptyRequestBody() { + executeAndCheckUpgrade( + upgradeRequest() + .newBuilder() + .post("Hello".toRequestBody()) + .build(), + ) + } + @Test fun upgradeHttps() { enableTls(Protocol.HTTP_1_1) @@ -238,8 +250,8 @@ class HttpUpgradesTest { } @Test - fun upgradeEventsWithRequestBody() { - upgradeWithRequestBody() + fun upgradeEventsWithEmptyRequestBody() { + upgradeWithEmptyRequestBody() assertThat(listener.recordedEventTypes()).containsExactly( CallStart::class, @@ -252,6 +264,8 @@ class HttpUpgradesTest { ConnectionAcquired::class, RequestHeadersStart::class, RequestHeadersEnd::class, + RequestBodyStart::class, + RequestBodyEnd::class, ResponseHeadersStart::class, ResponseHeadersEnd::class, FollowUpDecision::class, @@ -265,17 +279,32 @@ class HttpUpgradesTest { } @Test - fun upgradeRequestMustHaveAnEmptyBody() { - val e = - assertFailsWith { - Request - .Builder() - .url(server.url("/")) - .header("Connection", "upgrade") - .post("Hello".toRequestBody()) - .build() - } - assertThat(e).hasMessage("expected a null or empty request body with 'Connection: upgrade'") + fun upgradeEventsWithNonEmptyRequestBody() { + upgradeWithNonEmptyRequestBody() + + assertThat(listener.recordedEventTypes()).containsExactly( + CallStart::class, + ProxySelectStart::class, + ProxySelectEnd::class, + DnsStart::class, + DnsEnd::class, + ConnectStart::class, + ConnectEnd::class, + ConnectionAcquired::class, + RequestHeadersStart::class, + RequestHeadersEnd::class, + RequestBodyStart::class, + RequestBodyEnd::class, + ResponseHeadersStart::class, + ResponseHeadersEnd::class, + FollowUpDecision::class, + RequestBodyStart::class, + ResponseBodyStart::class, + ResponseBodyEnd::class, + RequestBodyEnd::class, + ConnectionReleased::class, + CallEnd::class, + ) } private fun enableTls(vararg protocols: Protocol) {