diff --git a/mockwebserver/src/main/kotlin/mockwebserver3/MockWebServer.kt b/mockwebserver/src/main/kotlin/mockwebserver3/MockWebServer.kt index 0a7a27920..12641e5d7 100644 --- a/mockwebserver/src/main/kotlin/mockwebserver3/MockWebServer.kt +++ b/mockwebserver/src/main/kotlin/mockwebserver3/MockWebServer.kt @@ -516,7 +516,7 @@ public class MockWebServer : Closeable { val connection = Http2Connection .Builder(false, taskRunner) - .socket(socket.javaNetSocket) + .socket(socket, socket.javaNetSocket.remoteSocketAddress.toString()) .listener(http2SocketHandler) .build() connection.start() @@ -836,8 +836,6 @@ public class MockWebServer : Closeable { webSocket.initReaderAndWriter( name = name, socket = socket, - socketSource = socket.source, - socketSink = socket.sink, client = false, ) diff --git a/mockwebserver/src/main/kotlin/mockwebserver3/internal/MockWebServerSocket.kt b/mockwebserver/src/main/kotlin/mockwebserver3/internal/MockWebServerSocket.kt index 8bd1d8ab9..0bff298e8 100644 --- a/mockwebserver/src/main/kotlin/mockwebserver3/internal/MockWebServerSocket.kt +++ b/mockwebserver/src/main/kotlin/mockwebserver3/internal/MockWebServerSocket.kt @@ -23,6 +23,7 @@ import java.util.concurrent.CountDownLatch import javax.net.ssl.SSLSocket import okhttp3.Handshake import okhttp3.Handshake.Companion.handshake +import okhttp3.internal.connection.BufferedSocket import okhttp3.internal.platform.Platform import okio.BufferedSink import okio.BufferedSource @@ -40,7 +41,7 @@ import okio.buffer internal class MockWebServerSocket( val javaNetSocket: Socket, ) : Closeable, - okio.Socket { + BufferedSocket { private val delegate = javaNetSocket.asOkioSocket() private val closedLatch = CountDownLatch(2) diff --git a/mockwebserver/src/test/java/mockwebserver3/internal/http2/Http2Server.kt b/mockwebserver/src/test/java/mockwebserver3/internal/http2/Http2Server.kt index bb818a090..bc9a3a12b 100644 --- a/mockwebserver/src/test/java/mockwebserver3/internal/http2/Http2Server.kt +++ b/mockwebserver/src/test/java/mockwebserver3/internal/http2/Http2Server.kt @@ -17,6 +17,7 @@ package mockwebserver3.internal.http2 import java.io.File import java.io.IOException +import java.net.InetSocketAddress import java.net.ProtocolException import java.net.ServerSocket import java.net.Socket @@ -28,6 +29,7 @@ import okhttp3.Protocol import okhttp3.Protocol.Companion.get import okhttp3.internal.closeQuietly import okhttp3.internal.concurrent.TaskRunner +import okhttp3.internal.connection.asBufferedSocket import okhttp3.internal.http2.Header import okhttp3.internal.http2.Http2Connection import okhttp3.internal.http2.Http2Stream @@ -57,7 +59,7 @@ class Http2Server( val connection = Http2Connection .Builder(false, TaskRunner.INSTANCE) - .socket(sslSocket) + .socket(sslSocket.asBufferedSocket(), sslSocket.peerName()) .listener(this) .build() connection.start() @@ -192,6 +194,11 @@ class Http2Server( else -> "text/plain" } + private fun Socket.peerName(): String { + val address = remoteSocketAddress + return if (address is InetSocketAddress) address.hostName else address.toString() + } + companion object { val logger: Logger = Logger.getLogger(Http2Server::class.java.name) diff --git a/okhttp-testing-support/src/main/kotlin/okhttp3/internal/duplex/MockSocketHandler.kt b/okhttp-testing-support/src/main/kotlin/okhttp3/internal/duplex/MockSocketHandler.kt index 7dd531c4e..97ab09b7f 100644 --- a/okhttp-testing-support/src/main/kotlin/okhttp3/internal/duplex/MockSocketHandler.kt +++ b/okhttp-testing-support/src/main/kotlin/okhttp3/internal/duplex/MockSocketHandler.kt @@ -21,25 +21,13 @@ import java.util.concurrent.FutureTask import java.util.concurrent.LinkedBlockingQueue import java.util.concurrent.TimeUnit import mockwebserver3.SocketHandler -import okio.BufferedSink -import okio.BufferedSource +import okhttp3.internal.connection.BufferedSocket +import okhttp3.internal.connection.asBufferedSocket import okio.Socket -import okio.buffer import okio.utf8Size private typealias Action = (BufferedSocket) -> Unit -private class BufferedSocket( - val socket: Socket, -) { - val source: BufferedSource = socket.source.buffer() - val sink: BufferedSink = socket.sink.buffer() - - fun cancel() { - socket.cancel() - } -} - /** * A scriptable request/response conversation. Create the script by calling methods like * [receiveRequest] in the sequence they are run. @@ -104,7 +92,7 @@ class MockSocketHandler : SocketHandler { } override fun handle(socket: Socket) { - val task = serviceSocketTask(BufferedSocket(socket)) + val task = serviceSocketTask(socket.asBufferedSocket()) results.add(task) task.run() } diff --git a/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/InsecureExtendedTrustManager.kt b/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/InsecureExtendedTrustManager.kt index 36ad28977..343eaf9cd 100644 --- a/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/InsecureExtendedTrustManager.kt +++ b/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/InsecureExtendedTrustManager.kt @@ -16,12 +16,12 @@ package okhttp3.tls.internal +import java.net.InetSocketAddress import java.net.Socket import java.security.cert.CertificateException import java.security.cert.X509Certificate import javax.net.ssl.SSLEngine import javax.net.ssl.X509ExtendedTrustManager -import okhttp3.internal.peerName import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement /** @@ -78,4 +78,9 @@ internal class InsecureExtendedTrustManager( authType: String, socket: Socket?, ) = throw CertificateException("Unsupported operation") + + private fun Socket.peerName(): String { + val address = remoteSocketAddress + return if (address is InetSocketAddress) address.hostName else address.toString() + } } diff --git a/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/-UtilJvm.kt b/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/-UtilJvm.kt index 45f2b4167..c79cd2120 100644 --- a/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/-UtilJvm.kt +++ b/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/-UtilJvm.kt @@ -19,7 +19,6 @@ package okhttp3.internal import java.io.IOException import java.io.InterruptedIOException -import java.net.InetSocketAddress import java.net.ServerSocket import java.net.Socket import java.net.SocketTimeoutException @@ -197,11 +196,6 @@ internal fun Source.discard( false } -internal fun Socket.peerName(): String { - val address = remoteSocketAddress - return if (address is InetSocketAddress) address.hostName else address.toString() -} - /** * Returns true if new reads and writes should be attempted on this. * diff --git a/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/connection/BufferedSocket.kt b/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/connection/BufferedSocket.kt new file mode 100644 index 000000000..75fb4f4ab --- /dev/null +++ b/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/connection/BufferedSocket.kt @@ -0,0 +1,41 @@ +/* + * Copyright (C) 2025 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3.internal.connection + +import java.net.Socket as JavaNetSocket +import okio.BufferedSink +import okio.BufferedSource +import okio.Socket as OkioSocket +import okio.asOkioSocket +import okio.buffer + +interface BufferedSocket : OkioSocket { + override val source: BufferedSource + override val sink: BufferedSink +} + +fun JavaNetSocket.asBufferedSocket(): BufferedSocket = asOkioSocket().asBufferedSocket() + +fun OkioSocket.asBufferedSocket(): BufferedSocket = + object : BufferedSocket { + private val delegate = this@asBufferedSocket + override val source = delegate.source.buffer() + override val sink = delegate.sink.buffer() + + override fun cancel() { + delegate.cancel() + } + } diff --git a/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/connection/ConnectPlan.kt b/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/connection/ConnectPlan.kt index a049db130..2ad364175 100644 --- a/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/connection/ConnectPlan.kt +++ b/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/connection/ConnectPlan.kt @@ -42,11 +42,6 @@ import okhttp3.internal.http1.Http1ExchangeCodec import okhttp3.internal.platform.Platform import okhttp3.internal.tls.OkHostnameVerifier import okhttp3.internal.toHostHeader -import okio.BufferedSink -import okio.BufferedSource -import okio.Socket as OkioSocket -import okio.asOkioSocket -import okio.buffer /** * A single attempt to connect to a remote server, including these steps: @@ -94,9 +89,7 @@ class ConnectPlan( internal var javaNetSocket: JavaNetSocket? = null private var handshake: Handshake? = null private var protocol: Protocol? = null - private lateinit var okioSocket: OkioSocket - private lateinit var source: BufferedSource - private lateinit var sink: BufferedSink + private lateinit var socket: BufferedSocket private var connection: RealConnection? = null /** True if this connection is ready for use, including TCP, tunnels, and TLS. */ @@ -185,7 +178,7 @@ class ConnectPlan( // that happens, then we will have buffered bytes that are needed by the SSLSocket! // This check is imperfect: it doesn't tell us whether a handshake will succeed, just // that it will almost certainly fail because the proxy has sent unexpected data. - if (!source.buffer.exhausted() || !sink.buffer.exhausted()) { + if (!socket.source.buffer.exhausted() || !socket.sink.buffer.exhausted()) { throw IOException("TLS tunnel buffered too many bytes!") } @@ -225,12 +218,10 @@ class ConnectPlan( connectionPool = connectionPool, route = route, rawSocket = rawSocket, - socket = javaNetSocket!!, + javaNetSocket = javaNetSocket!!, handshake = handshake, protocol = protocol!!, - okioSocket = okioSocket, - source = source, - sink = sink, + socket = socket, pingIntervalMillis = pingIntervalMillis, connectionListener = connectionPool.connectionListener, ) @@ -291,9 +282,7 @@ class ConnectPlan( // https://github.com/square/okhttp/issues/3245 // https://android-review.googlesource.com/#/c/271775/ try { - okioSocket = rawSocket.asOkioSocket() - source = okioSocket.source.buffer() - sink = okioSocket.sink.buffer() + this.socket = rawSocket.asBufferedSocket() } catch (npe: NullPointerException) { if (npe.message == NPE_THROW_WITH_NULL) { throw IOException(npe) @@ -408,9 +397,7 @@ class ConnectPlan( null } javaNetSocket = sslSocket - okioSocket = sslSocket.asOkioSocket() - source = okioSocket.source.buffer() - sink = okioSocket.sink.buffer() + socket = sslSocket.asBufferedSocket() protocol = if (maybeProtocol != null) Protocol.get(maybeProtocol) else Protocol.HTTP_1_1 success = true } finally { @@ -437,12 +424,10 @@ class ConnectPlan( // No client for CONNECT tunnels: client = null, carrier = this, - socket = okioSocket, - source = source, - sink = sink, + socket = socket, ) - source.timeout().timeout(readTimeoutMillis.toLong(), TimeUnit.MILLISECONDS) - sink.timeout().timeout(writeTimeoutMillis.toLong(), TimeUnit.MILLISECONDS) + socket.source.timeout().timeout(readTimeoutMillis.toLong(), TimeUnit.MILLISECONDS) + socket.sink.timeout().timeout(writeTimeoutMillis.toLong(), TimeUnit.MILLISECONDS) tunnelCodec.writeRequest(nextRequest.headers, requestLine) tunnelCodec.finishRequest() val response = diff --git a/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/connection/RealConnection.kt b/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/connection/RealConnection.kt index 520a617f0..1db62231e 100644 --- a/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/connection/RealConnection.kt +++ b/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/connection/RealConnection.kt @@ -52,12 +52,6 @@ import okhttp3.internal.http2.StreamResetException import okhttp3.internal.isHealthy import okhttp3.internal.tls.OkHostnameVerifier import okio.Buffer -import okio.BufferedSink -import okio.BufferedSource -import okio.Sink -import okio.Source -import okio.Timeout -import okio.buffer /** * A connection to a remote web server capable of carrying 1 or more concurrent streams. @@ -75,12 +69,10 @@ class RealConnection internal constructor( * The application layer socket. Either an [SSLSocket] layered over [rawSocket], or [rawSocket] * itself if this connection does not use SSL. */ - private val socket: JavaNetSocket, + private val javaNetSocket: JavaNetSocket, private val handshake: Handshake?, private val protocol: Protocol, - private val okioSocket: okio.Socket, - private val source: BufferedSource, - private val sink: BufferedSink, + private val socket: BufferedSocket, private val pingIntervalMillis: Int, internal val connectionListener: ConnectionListener, ) : Http2Connection.Listener(), @@ -167,12 +159,12 @@ class RealConnection internal constructor( @Throws(IOException::class) private fun startHttp2() { - socket.soTimeout = 0 // HTTP/2 connection timeouts are set per-stream. + javaNetSocket.soTimeout = 0 // HTTP/2 connection timeouts are set per-stream. val flowControlListener = connectionListener as? FlowControlListener ?: FlowControlListener.None val http2Connection = Http2Connection .Builder(client = true, taskRunner) - .socket(socket, route.address.url.host, source, sink) + .socket(socket, route.address.url.host) .listener(this) .pingIntervalMillis(pingIntervalMillis) .flowControlListener(flowControlListener) @@ -277,23 +269,21 @@ class RealConnection internal constructor( client: OkHttpClient, chain: RealInterceptorChain, ): ExchangeCodec { - val socket = this.socket - val source = this.source - val sink = this.sink + val okHttpSocket = this.socket val http2Connection = this.http2Connection return if (http2Connection != null) { Http2ExchangeCodec(client, this, chain, http2Connection) } else { - socket.soTimeout = chain.readTimeoutMillis() - source.timeout().timeout(chain.readTimeoutMillis.toLong(), MILLISECONDS) - sink.timeout().timeout(chain.writeTimeoutMillis.toLong(), MILLISECONDS) - Http1ExchangeCodec(client, this, okioSocket, source, sink) + javaNetSocket.soTimeout = chain.readTimeoutMillis() + okHttpSocket.source.timeout().timeout(chain.readTimeoutMillis.toLong(), MILLISECONDS) + okHttpSocket.sink.timeout().timeout(chain.writeTimeoutMillis.toLong(), MILLISECONDS) + Http1ExchangeCodec(client, this, okHttpSocket) } } internal fun useAsSocket() { - socket.soTimeout = 0 + javaNetSocket.soTimeout = 0 noNewExchanges() } @@ -304,7 +294,7 @@ class RealConnection internal constructor( rawSocket.closeQuietly() } - override fun socket(): JavaNetSocket = socket + override fun socket(): JavaNetSocket = javaNetSocket /** Returns true if this connection is ready to host new streams. */ fun isHealthy(doExtensiveChecks: Boolean): Boolean { @@ -313,9 +303,9 @@ class RealConnection internal constructor( val nowNs = System.nanoTime() if (rawSocket.isClosed || - socket.isClosed || - socket.isInputShutdown || - socket.isOutputShutdown + javaNetSocket.isClosed || + javaNetSocket.isInputShutdown || + javaNetSocket.isOutputShutdown ) { return false } @@ -327,7 +317,7 @@ class RealConnection internal constructor( val idleDurationNs = withLock { nowNs - idleAtNs } if (idleDurationNs >= IDLE_CONNECTION_HEALTHY_NS && doExtensiveChecks) { - return socket.isHealthy(source) + return javaNetSocket.isHealthy(socket.source) } return true @@ -452,33 +442,10 @@ class RealConnection internal constructor( socket: JavaNetSocket, idleAtNs: Long, ): RealConnection { - val okioSocket = - object : okio.Socket { - override val sink: Sink = - object : Sink { - override fun close() = Unit - - override fun flush() = Unit - - override fun timeout(): Timeout = Timeout.NONE - - override fun write( - source: Buffer, - byteCount: Long, - ): Unit = throw UnsupportedOperationException() - } - - override val source: Source = - object : Source { - override fun close() = Unit - - override fun read( - sink: Buffer, - byteCount: Long, - ): Long = throw UnsupportedOperationException() - - override fun timeout(): Timeout = Timeout.NONE - } + val bufferedSocket = + object : BufferedSocket { + override val sink = Buffer() + override val source = Buffer() override fun cancel() { } @@ -490,12 +457,10 @@ class RealConnection internal constructor( connectionPool = connectionPool, route = route, rawSocket = JavaNetSocket(), - socket = socket, + javaNetSocket = socket, handshake = null, protocol = Protocol.HTTP_2, - okioSocket = okioSocket, - source = okioSocket.source.buffer(), - sink = okioSocket.sink.buffer(), + socket = bufferedSocket, pingIntervalMillis = 0, connectionListener = ConnectionListener.NONE, ) diff --git a/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/http1/Http1ExchangeCodec.kt b/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/http1/Http1ExchangeCodec.kt index 185b0f7f8..36e94d563 100644 --- a/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/http1/Http1ExchangeCodec.kt +++ b/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/http1/Http1ExchangeCodec.kt @@ -26,6 +26,7 @@ import okhttp3.OkHttpClient import okhttp3.Request import okhttp3.Response import okhttp3.internal.checkOffsetAndCount +import okhttp3.internal.connection.BufferedSocket import okhttp3.internal.discard import okhttp3.internal.headersContentLength import okhttp3.internal.http.ExchangeCodec @@ -37,11 +38,8 @@ import okhttp3.internal.http.receiveHeaders import okhttp3.internal.http1.Http1ExchangeCodec.Companion.TRAILERS_RESPONSE_BODY_TRUNCATED import okhttp3.internal.skipAll import okio.Buffer -import okio.BufferedSink -import okio.BufferedSource import okio.ForwardingTimeout import okio.Sink -import okio.Socket import okio.Source import okio.Timeout @@ -66,12 +64,10 @@ class Http1ExchangeCodec( /** The client that configures this stream. May be null for HTTPS proxy tunnels. */ private val client: OkHttpClient?, override val carrier: ExchangeCodec.Carrier, - override val socket: Socket, - private val source: BufferedSource, - private val sink: BufferedSink, + override val socket: BufferedSocket, ) : ExchangeCodec { private var state = STATE_IDLE - private val headersReader = HeadersReader(source) + private val headersReader = HeadersReader(socket.source) private val Response.isChunked: Boolean get() = "chunked".equals(header("Transfer-Encoding"), ignoreCase = true) @@ -161,11 +157,11 @@ class Http1ExchangeCodec( } override fun flushRequest() { - sink.flush() + socket.sink.flush() } override fun finishRequest() { - sink.flush() + socket.sink.flush() } /** Returns bytes of a request header for sending on an HTTP transport. */ @@ -174,15 +170,15 @@ class Http1ExchangeCodec( requestLine: String, ) { check(state == STATE_IDLE) { "state: $state" } - sink.writeUtf8(requestLine).writeUtf8("\r\n") + socket.sink.writeUtf8(requestLine).writeUtf8("\r\n") for (i in 0 until headers.size) { - sink + socket.sink .writeUtf8(headers.name(i)) .writeUtf8(": ") .writeUtf8(headers.value(i)) .writeUtf8("\r\n") } - sink.writeUtf8("\r\n") + socket.sink.writeUtf8("\r\n") state = STATE_OPEN_REQUEST_BODY } @@ -295,7 +291,7 @@ class Http1ExchangeCodec( /** An HTTP request body. */ private inner class KnownLengthSink : Sink { - private val timeout = ForwardingTimeout(sink.timeout()) + private val timeout = ForwardingTimeout(socket.sink.timeout()) private var closed: Boolean = false override fun timeout(): Timeout = timeout @@ -306,12 +302,12 @@ class Http1ExchangeCodec( ) { check(!closed) { "closed" } checkOffsetAndCount(source.size, 0, byteCount) - sink.write(source, byteCount) + socket.sink.write(source, byteCount) } override fun flush() { if (closed) return // Don't throw; this stream might have been closed on the caller's behalf. - sink.flush() + socket.sink.flush() } override fun close() { @@ -327,7 +323,7 @@ class Http1ExchangeCodec( * to buffer chunks; typically by using a buffered sink with this sink. */ private inner class ChunkedSink : Sink { - private val timeout = ForwardingTimeout(sink.timeout()) + private val timeout = ForwardingTimeout(socket.sink.timeout()) private var closed: Boolean = false override fun timeout(): Timeout = timeout @@ -339,23 +335,25 @@ class Http1ExchangeCodec( check(!closed) { "closed" } if (byteCount == 0L) return - sink.writeHexadecimalUnsignedLong(byteCount) - sink.writeUtf8("\r\n") - sink.write(source, byteCount) - sink.writeUtf8("\r\n") + with(socket.sink) { + writeHexadecimalUnsignedLong(byteCount) + writeUtf8("\r\n") + write(source, byteCount) + writeUtf8("\r\n") + } } @Synchronized override fun flush() { if (closed) return // Don't throw; this stream might have been closed on the caller's behalf. - sink.flush() + socket.sink.flush() } @Synchronized override fun close() { if (closed) return closed = true - sink.writeUtf8("0\r\n\r\n") + socket.sink.writeUtf8("0\r\n\r\n") detachTimeout(timeout) state = STATE_READ_RESPONSE_HEADERS } @@ -364,7 +362,7 @@ class Http1ExchangeCodec( private abstract inner class AbstractSource( val url: HttpUrl, ) : Source { - protected val timeout = ForwardingTimeout(source.timeout()) + protected val timeout = ForwardingTimeout(socket.source.timeout()) protected var closed: Boolean = false override fun timeout(): Timeout = timeout @@ -374,7 +372,7 @@ class Http1ExchangeCodec( byteCount: Long, ): Long = try { - source.read(sink, byteCount) + socket.source.read(sink, byteCount) } catch (e: IOException) { carrier.noNewExchanges() responseBodyComplete(TRAILERS_RESPONSE_BODY_TRUNCATED) @@ -481,11 +479,11 @@ class Http1ExchangeCodec( private fun readChunkSize() { // Read the suffix of the previous chunk. if (bytesRemainingInChunk != NO_CHUNK_YET) { - source.readUtf8LineStrict() + socket.source.readUtf8LineStrict() } try { - bytesRemainingInChunk = source.readHexadecimalUnsignedLong() - val extensions = source.readUtf8LineStrict().trim() + bytesRemainingInChunk = socket.source.readHexadecimalUnsignedLong() + val extensions = socket.source.readUtf8LineStrict().trim() if (bytesRemainingInChunk < 0L || extensions.isNotEmpty() && !extensions.startsWith(";")) { throw ProtocolException( "expected chunk size and optional extensions" + diff --git a/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/http2/Http2Connection.kt b/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/http2/Http2Connection.kt index c732bee1d..c293d5740 100644 --- a/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/http2/Http2Connection.kt +++ b/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/http2/Http2Connection.kt @@ -18,7 +18,6 @@ package okhttp3.internal.http2 import java.io.Closeable import java.io.IOException import java.io.InterruptedIOException -import java.net.Socket import java.util.concurrent.TimeUnit import okhttp3.Headers import okhttp3.internal.EMPTY_BYTE_ARRAY @@ -29,22 +28,18 @@ import okhttp3.internal.concurrent.assertLockNotHeld import okhttp3.internal.concurrent.notifyAll import okhttp3.internal.concurrent.wait import okhttp3.internal.concurrent.withLock +import okhttp3.internal.connection.BufferedSocket import okhttp3.internal.http2.ErrorCode.REFUSED_STREAM import okhttp3.internal.http2.Settings.Companion.DEFAULT_INITIAL_WINDOW_SIZE import okhttp3.internal.http2.flowcontrol.WindowCounter import okhttp3.internal.ignoreIoExceptions import okhttp3.internal.okHttpName -import okhttp3.internal.peerName import okhttp3.internal.platform.Platform import okhttp3.internal.platform.Platform.Companion.INFO import okhttp3.internal.toHeaders import okio.Buffer -import okio.BufferedSink import okio.BufferedSource import okio.ByteString -import okio.buffer -import okio.sink -import okio.source /** * A socket connection to a remote peer. A connection hosts streams which can send and receive @@ -140,11 +135,11 @@ class Http2Connection internal constructor( var writeBytesMaximum: Long = peerSettings.initialWindowSize.toLong() private set - internal val socket: Socket = builder.socket - val writer = Http2Writer(builder.sink, client) + internal val socket: BufferedSocket = builder.socket + val writer = Http2Writer(socket.sink, client) // Visible for testing - val readerRunnable = ReaderRunnable(Http2Reader(builder.source, client)) + val readerRunnable = ReaderRunnable(Http2Reader(socket.source, client)) // Guarded by this. private val currentPushRequests = mutableSetOf() @@ -479,9 +474,9 @@ class Http2Connection internal constructor( writer.close() } - // Close the socket to break out the reader thread, which will clean up after itself. + // Cancel the socket to break out the reader thread, which will clean up after itself. ignoreIoExceptions { - socket.close() + socket.cancel() } // Release the threads. @@ -574,22 +569,17 @@ class Http2Connection internal constructor( internal var client: Boolean, internal val taskRunner: TaskRunner, ) { - internal lateinit var socket: Socket + internal lateinit var socket: BufferedSocket internal lateinit var connectionName: String - internal lateinit var source: BufferedSource - internal lateinit var sink: BufferedSink internal var listener = Listener.REFUSE_INCOMING_STREAMS internal var pushObserver = PushObserver.CANCEL internal var pingIntervalMillis: Int = 0 internal var flowControlListener: FlowControlListener = FlowControlListener.None @Throws(IOException::class) - @JvmOverloads fun socket( - socket: Socket, - peerName: String = socket.peerName(), - source: BufferedSource = socket.source().buffer(), - sink: BufferedSink = socket.sink().buffer(), + socket: BufferedSocket, + peerName: String, ) = apply { this.socket = socket this.connectionName = @@ -597,8 +587,6 @@ class Http2Connection internal constructor( client -> "$okHttpName $peerName" else -> "MockWebServer $peerName" } - this.source = source - this.sink = sink } fun listener(listener: Listener) = @@ -800,7 +788,7 @@ class Http2Connection internal constructor( } } if (streamsToNotify != null) { - for (stream in streamsToNotify!!) { + for (stream in streamsToNotify) { stream.withLock { stream.addBytesToWriteWindow(delta) } diff --git a/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/ws/RealWebSocket.kt b/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/ws/RealWebSocket.kt index 1674f8b85..e6e87a10c 100644 --- a/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/ws/RealWebSocket.kt +++ b/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/ws/RealWebSocket.kt @@ -36,7 +36,9 @@ import okhttp3.internal.concurrent.Lockable import okhttp3.internal.concurrent.Task import okhttp3.internal.concurrent.TaskRunner import okhttp3.internal.concurrent.assertLockHeld +import okhttp3.internal.connection.BufferedSocket import okhttp3.internal.connection.RealCall +import okhttp3.internal.connection.asBufferedSocket import okhttp3.internal.okHttpName import okhttp3.internal.ws.WebSocketProtocol.CLOSE_CLIENT_GOING_AWAY import okhttp3.internal.ws.WebSocketProtocol.CLOSE_MESSAGE_MAX @@ -49,7 +51,6 @@ import okio.ByteString import okio.ByteString.Companion.encodeUtf8 import okio.ByteString.Companion.toByteString import okio.Socket -import okio.buffer class RealWebSocket( taskRunner: TaskRunner, @@ -197,9 +198,7 @@ class RealWebSocket( val name = "$okHttpName WebSocket ${request.url.redact()}" initReaderAndWriter( name = name, - socket = socket, - socketSource = socket.source.buffer(), - socketSink = socket.sink.buffer(), + socket = socket.asBufferedSocket(), client = true, ) loopReader(response) @@ -269,9 +268,7 @@ class RealWebSocket( */ fun initReaderAndWriter( name: String, - socket: Socket, - socketSource: BufferedSource, - socketSink: BufferedSink, + socket: BufferedSocket, client: Boolean, ) { val extensions = this.extensions!! @@ -281,7 +278,7 @@ class RealWebSocket( this.writer = WebSocketWriter( isClient = client, - sink = socketSink, + sink = socket.sink, random = random, perMessageDeflate = extensions.perMessageDeflate, noContextTakeover = extensions.noContextTakeover(client), @@ -303,7 +300,7 @@ class RealWebSocket( reader = WebSocketReader( isClient = client, - source = socketSource, + source = socket.source, frameCallback = this, perMessageDeflate = extensions.perMessageDeflate, noContextTakeover = extensions.noContextTakeover(!client), diff --git a/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/ws/WebSocketWriter.kt b/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/ws/WebSocketWriter.kt index f6c594045..3416bc1dd 100644 --- a/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/ws/WebSocketWriter.kt +++ b/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/ws/WebSocketWriter.kt @@ -204,7 +204,7 @@ class WebSocketWriter( } sinkBuffer.write(messageBuffer, dataSize) - sink.emit() + sink.flush() } override fun close() { diff --git a/okhttp/src/jvmTest/kotlin/okhttp3/CallTest.kt b/okhttp/src/jvmTest/kotlin/okhttp3/CallTest.kt index f9a5a5654..8927363bf 100644 --- a/okhttp/src/jvmTest/kotlin/okhttp3/CallTest.kt +++ b/okhttp/src/jvmTest/kotlin/okhttp3/CallTest.kt @@ -2823,7 +2823,9 @@ open class CallTest { call.enqueue(callback) call.cancel() latch.countDown() - callback.await(server.url("/a")).assertFailure("Canceled", "Socket closed", "Socket is closed") + callback + .await(server.url("/a")) + .assertFailure("canceled", "Canceled", "Socket closed", "Socket is closed") } @Test @@ -2831,7 +2833,9 @@ open class CallTest { val call = client.newCall(Request(server.url("/"))) call.enqueue(callback) client.dispatcher.cancelAll() - callback.await(server.url("/")).assertFailure("Canceled", "Socket closed", "Socket is closed") + callback + .await(server.url("/")) + .assertFailure("canceled", "Canceled", "Socket closed", "Socket is closed") } @Test @@ -2942,7 +2946,9 @@ open class CallTest { assertThat(server.takeRequest().url.encodedPath).isEqualTo("/a") callback.await(requestA.url).assertBody("A") // At this point we know the callback is ready, and that it will receive a cancel failure. - callback.await(requestB.url).assertFailure("Canceled", "Socket closed") + callback + .await(requestB.url) + .assertFailure("canceled", "Canceled", "Socket closed", "Socket is closed") } @Test @@ -4715,7 +4721,7 @@ open class CallTest { .build() executeSynchronously("/") .assertFailure(IOException::class.java) - .assertFailure("Socket closed", "Socket is closed") + .assertFailure("canceled", "Canceled", "Socket closed", "Socket is closed") } @Test @@ -4805,7 +4811,8 @@ open class CallTest { .build() }, ).build() - executeSynchronously("/").assertFailure("Canceled") + executeSynchronously("/") + .assertFailure("canceled", "Canceled", "Socket closed", "Socket is closed") assertThat(closed.get()).isTrue() } diff --git a/okhttp/src/jvmTest/kotlin/okhttp3/internal/connection/ConnectionPoolTest.kt b/okhttp/src/jvmTest/kotlin/okhttp3/internal/connection/ConnectionPoolTest.kt index e952dc7f6..09e8f55be 100644 --- a/okhttp/src/jvmTest/kotlin/okhttp3/internal/connection/ConnectionPoolTest.kt +++ b/okhttp/src/jvmTest/kotlin/okhttp3/internal/connection/ConnectionPoolTest.kt @@ -316,7 +316,7 @@ class ConnectionPoolTest { val connection = Http2Connection .Builder(true, TaskRunner.INSTANCE) - .socket(peer.openSocket()) + .socket(peer.openSocket().asBufferedSocket(), "peer") .pushObserver(Http2ConnectionTest.IGNORE) .listener(realConnection) .build() diff --git a/okhttp/src/jvmTest/kotlin/okhttp3/internal/http/HttpUpgradesTest.kt b/okhttp/src/jvmTest/kotlin/okhttp3/internal/http/HttpUpgradesTest.kt index 337d9791f..a9666fc39 100644 --- a/okhttp/src/jvmTest/kotlin/okhttp3/internal/http/HttpUpgradesTest.kt +++ b/okhttp/src/jvmTest/kotlin/okhttp3/internal/http/HttpUpgradesTest.kt @@ -82,12 +82,12 @@ class HttpUpgradesTest { socket.sink.buffer().use { sink -> socket.source.buffer().use { source -> sink.writeUtf8("client says hello\n") - sink.emit() + sink.flush() assertThat(source.readUtf8Line()).isEqualTo("server says hello") sink.writeUtf8("client says goodbye\n") - sink.emit() + sink.flush() assertThat(source.readUtf8Line()).isEqualTo("server says goodbye") diff --git a/okhttp/src/jvmTest/kotlin/okhttp3/internal/http2/Http2ConnectionTest.kt b/okhttp/src/jvmTest/kotlin/okhttp3/internal/http2/Http2ConnectionTest.kt index 5ed6152a9..dd05b1f69 100644 --- a/okhttp/src/jvmTest/kotlin/okhttp3/internal/http2/Http2ConnectionTest.kt +++ b/okhttp/src/jvmTest/kotlin/okhttp3/internal/http2/Http2ConnectionTest.kt @@ -43,6 +43,7 @@ import okhttp3.internal.concurrent.TaskRunner import okhttp3.internal.concurrent.notifyAll import okhttp3.internal.concurrent.wait import okhttp3.internal.concurrent.withLock +import okhttp3.internal.connection.asBufferedSocket import okio.AsyncTimeout import okio.Buffer import okio.BufferedSource @@ -497,7 +498,7 @@ class Http2ConnectionTest { val connection = Http2Connection .Builder(true, TaskRunner.INSTANCE) - .socket(socket) + .socket(socket.asBufferedSocket(), "peer") .pushObserver(IGNORE) .build() connection.start(sendConnectionPreface = false) @@ -1890,7 +1891,7 @@ class Http2ConnectionTest { val connection = Http2Connection .Builder(true, TaskRunner.INSTANCE) - .socket(peer.openSocket()) + .socket(peer.openSocket().asBufferedSocket(), "peer") .build() connection.start(sendConnectionPreface = false) val stream = connection.newStream(headerEntries("b", "banana"), false) @@ -1912,11 +1913,10 @@ class Http2ConnectionTest { peer.acceptFrame() // SYN_STREAM. peer.play() val taskRunner = taskFaker.taskRunner - val socket = peer.openSocket() val connection = Http2Connection .Builder(true, taskRunner) - .socket(socket) + .socket(peer.openSocket().asBufferedSocket(), "peer") .pushObserver(IGNORE) .build() connection.start(sendConnectionPreface = false) @@ -1972,7 +1972,7 @@ class Http2ConnectionTest { val connection = Http2Connection .Builder(true, TaskRunner.INSTANCE) - .socket(peer.openSocket()) + .socket(peer.openSocket().asBufferedSocket(), "peer") .pushObserver(pushObserver) .listener(listener) .build() diff --git a/okhttp/src/jvmTest/kotlin/okhttp3/internal/ws/RealWebSocketTest.kt b/okhttp/src/jvmTest/kotlin/okhttp3/internal/ws/RealWebSocketTest.kt index 155c6b2cf..fd847acab 100644 --- a/okhttp/src/jvmTest/kotlin/okhttp3/internal/ws/RealWebSocketTest.kt +++ b/okhttp/src/jvmTest/kotlin/okhttp3/internal/ws/RealWebSocketTest.kt @@ -35,6 +35,7 @@ import okhttp3.Request import okhttp3.Response import okhttp3.TestUtil.repeat import okhttp3.internal.concurrent.TaskFaker +import okhttp3.internal.connection.BufferedSocket import okhttp3.internal.ws.WebSocketExtensions.Companion.parse import okio.BufferedSink import okio.BufferedSource @@ -466,7 +467,7 @@ class RealWebSocketTest { private val taskFaker: TaskFaker, private val delegate: Socket, private val client: Boolean, - ) : Socket { + ) : BufferedSocket { private val name = if (client) "client" else "server" val listener = WebSocketRecorder(name) var webSocket: RealWebSocket? = null @@ -530,7 +531,7 @@ class RealWebSocketTest { } } } - webSocket!!.initReaderAndWriter(name, this, source, sink, client) + webSocket!!.initReaderAndWriter(name, this, client) } /**