1
0
mirror of https://github.com/square/okhttp.git synced 2025-08-08 23:42:08 +03:00

New BufferedSocket class (#8977)

* New BufferedSocket class

I'm starting to push Okio's socket throughout OkHttp,
and it turns out what I usually want is a BufferedSocket
instead.

* Flush not emit

* Track exception messages

* Use flush more
This commit is contained in:
Jesse Wilson
2025-07-31 00:18:40 -04:00
committed by GitHub
parent 3bafcf16fb
commit 4e7860212f
18 changed files with 156 additions and 181 deletions

View File

@@ -516,7 +516,7 @@ public class MockWebServer : Closeable {
val connection = val connection =
Http2Connection Http2Connection
.Builder(false, taskRunner) .Builder(false, taskRunner)
.socket(socket.javaNetSocket) .socket(socket, socket.javaNetSocket.remoteSocketAddress.toString())
.listener(http2SocketHandler) .listener(http2SocketHandler)
.build() .build()
connection.start() connection.start()
@@ -836,8 +836,6 @@ public class MockWebServer : Closeable {
webSocket.initReaderAndWriter( webSocket.initReaderAndWriter(
name = name, name = name,
socket = socket, socket = socket,
socketSource = socket.source,
socketSink = socket.sink,
client = false, client = false,
) )

View File

@@ -23,6 +23,7 @@ import java.util.concurrent.CountDownLatch
import javax.net.ssl.SSLSocket import javax.net.ssl.SSLSocket
import okhttp3.Handshake import okhttp3.Handshake
import okhttp3.Handshake.Companion.handshake import okhttp3.Handshake.Companion.handshake
import okhttp3.internal.connection.BufferedSocket
import okhttp3.internal.platform.Platform import okhttp3.internal.platform.Platform
import okio.BufferedSink import okio.BufferedSink
import okio.BufferedSource import okio.BufferedSource
@@ -40,7 +41,7 @@ import okio.buffer
internal class MockWebServerSocket( internal class MockWebServerSocket(
val javaNetSocket: Socket, val javaNetSocket: Socket,
) : Closeable, ) : Closeable,
okio.Socket { BufferedSocket {
private val delegate = javaNetSocket.asOkioSocket() private val delegate = javaNetSocket.asOkioSocket()
private val closedLatch = CountDownLatch(2) private val closedLatch = CountDownLatch(2)

View File

@@ -17,6 +17,7 @@ package mockwebserver3.internal.http2
import java.io.File import java.io.File
import java.io.IOException import java.io.IOException
import java.net.InetSocketAddress
import java.net.ProtocolException import java.net.ProtocolException
import java.net.ServerSocket import java.net.ServerSocket
import java.net.Socket import java.net.Socket
@@ -28,6 +29,7 @@ import okhttp3.Protocol
import okhttp3.Protocol.Companion.get import okhttp3.Protocol.Companion.get
import okhttp3.internal.closeQuietly import okhttp3.internal.closeQuietly
import okhttp3.internal.concurrent.TaskRunner import okhttp3.internal.concurrent.TaskRunner
import okhttp3.internal.connection.asBufferedSocket
import okhttp3.internal.http2.Header import okhttp3.internal.http2.Header
import okhttp3.internal.http2.Http2Connection import okhttp3.internal.http2.Http2Connection
import okhttp3.internal.http2.Http2Stream import okhttp3.internal.http2.Http2Stream
@@ -57,7 +59,7 @@ class Http2Server(
val connection = val connection =
Http2Connection Http2Connection
.Builder(false, TaskRunner.INSTANCE) .Builder(false, TaskRunner.INSTANCE)
.socket(sslSocket) .socket(sslSocket.asBufferedSocket(), sslSocket.peerName())
.listener(this) .listener(this)
.build() .build()
connection.start() connection.start()
@@ -192,6 +194,11 @@ class Http2Server(
else -> "text/plain" else -> "text/plain"
} }
private fun Socket.peerName(): String {
val address = remoteSocketAddress
return if (address is InetSocketAddress) address.hostName else address.toString()
}
companion object { companion object {
val logger: Logger = Logger.getLogger(Http2Server::class.java.name) val logger: Logger = Logger.getLogger(Http2Server::class.java.name)

View File

@@ -21,25 +21,13 @@ import java.util.concurrent.FutureTask
import java.util.concurrent.LinkedBlockingQueue import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import mockwebserver3.SocketHandler import mockwebserver3.SocketHandler
import okio.BufferedSink import okhttp3.internal.connection.BufferedSocket
import okio.BufferedSource import okhttp3.internal.connection.asBufferedSocket
import okio.Socket import okio.Socket
import okio.buffer
import okio.utf8Size import okio.utf8Size
private typealias Action = (BufferedSocket) -> Unit private typealias Action = (BufferedSocket) -> Unit
private class BufferedSocket(
val socket: Socket,
) {
val source: BufferedSource = socket.source.buffer()
val sink: BufferedSink = socket.sink.buffer()
fun cancel() {
socket.cancel()
}
}
/** /**
* A scriptable request/response conversation. Create the script by calling methods like * A scriptable request/response conversation. Create the script by calling methods like
* [receiveRequest] in the sequence they are run. * [receiveRequest] in the sequence they are run.
@@ -104,7 +92,7 @@ class MockSocketHandler : SocketHandler {
} }
override fun handle(socket: Socket) { override fun handle(socket: Socket) {
val task = serviceSocketTask(BufferedSocket(socket)) val task = serviceSocketTask(socket.asBufferedSocket())
results.add(task) results.add(task)
task.run() task.run()
} }

View File

@@ -16,12 +16,12 @@
package okhttp3.tls.internal package okhttp3.tls.internal
import java.net.InetSocketAddress
import java.net.Socket import java.net.Socket
import java.security.cert.CertificateException import java.security.cert.CertificateException
import java.security.cert.X509Certificate import java.security.cert.X509Certificate
import javax.net.ssl.SSLEngine import javax.net.ssl.SSLEngine
import javax.net.ssl.X509ExtendedTrustManager import javax.net.ssl.X509ExtendedTrustManager
import okhttp3.internal.peerName
import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement
/** /**
@@ -78,4 +78,9 @@ internal class InsecureExtendedTrustManager(
authType: String, authType: String,
socket: Socket?, socket: Socket?,
) = throw CertificateException("Unsupported operation") ) = throw CertificateException("Unsupported operation")
private fun Socket.peerName(): String {
val address = remoteSocketAddress
return if (address is InetSocketAddress) address.hostName else address.toString()
}
} }

View File

@@ -19,7 +19,6 @@ package okhttp3.internal
import java.io.IOException import java.io.IOException
import java.io.InterruptedIOException import java.io.InterruptedIOException
import java.net.InetSocketAddress
import java.net.ServerSocket import java.net.ServerSocket
import java.net.Socket import java.net.Socket
import java.net.SocketTimeoutException import java.net.SocketTimeoutException
@@ -197,11 +196,6 @@ internal fun Source.discard(
false false
} }
internal fun Socket.peerName(): String {
val address = remoteSocketAddress
return if (address is InetSocketAddress) address.hostName else address.toString()
}
/** /**
* Returns true if new reads and writes should be attempted on this. * Returns true if new reads and writes should be attempted on this.
* *

View File

@@ -0,0 +1,41 @@
/*
* Copyright (C) 2025 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package okhttp3.internal.connection
import java.net.Socket as JavaNetSocket
import okio.BufferedSink
import okio.BufferedSource
import okio.Socket as OkioSocket
import okio.asOkioSocket
import okio.buffer
interface BufferedSocket : OkioSocket {
override val source: BufferedSource
override val sink: BufferedSink
}
fun JavaNetSocket.asBufferedSocket(): BufferedSocket = asOkioSocket().asBufferedSocket()
fun OkioSocket.asBufferedSocket(): BufferedSocket =
object : BufferedSocket {
private val delegate = this@asBufferedSocket
override val source = delegate.source.buffer()
override val sink = delegate.sink.buffer()
override fun cancel() {
delegate.cancel()
}
}

View File

@@ -42,11 +42,6 @@ import okhttp3.internal.http1.Http1ExchangeCodec
import okhttp3.internal.platform.Platform import okhttp3.internal.platform.Platform
import okhttp3.internal.tls.OkHostnameVerifier import okhttp3.internal.tls.OkHostnameVerifier
import okhttp3.internal.toHostHeader import okhttp3.internal.toHostHeader
import okio.BufferedSink
import okio.BufferedSource
import okio.Socket as OkioSocket
import okio.asOkioSocket
import okio.buffer
/** /**
* A single attempt to connect to a remote server, including these steps: * A single attempt to connect to a remote server, including these steps:
@@ -94,9 +89,7 @@ class ConnectPlan(
internal var javaNetSocket: JavaNetSocket? = null internal var javaNetSocket: JavaNetSocket? = null
private var handshake: Handshake? = null private var handshake: Handshake? = null
private var protocol: Protocol? = null private var protocol: Protocol? = null
private lateinit var okioSocket: OkioSocket private lateinit var socket: BufferedSocket
private lateinit var source: BufferedSource
private lateinit var sink: BufferedSink
private var connection: RealConnection? = null private var connection: RealConnection? = null
/** True if this connection is ready for use, including TCP, tunnels, and TLS. */ /** True if this connection is ready for use, including TCP, tunnels, and TLS. */
@@ -185,7 +178,7 @@ class ConnectPlan(
// that happens, then we will have buffered bytes that are needed by the SSLSocket! // that happens, then we will have buffered bytes that are needed by the SSLSocket!
// This check is imperfect: it doesn't tell us whether a handshake will succeed, just // This check is imperfect: it doesn't tell us whether a handshake will succeed, just
// that it will almost certainly fail because the proxy has sent unexpected data. // that it will almost certainly fail because the proxy has sent unexpected data.
if (!source.buffer.exhausted() || !sink.buffer.exhausted()) { if (!socket.source.buffer.exhausted() || !socket.sink.buffer.exhausted()) {
throw IOException("TLS tunnel buffered too many bytes!") throw IOException("TLS tunnel buffered too many bytes!")
} }
@@ -225,12 +218,10 @@ class ConnectPlan(
connectionPool = connectionPool, connectionPool = connectionPool,
route = route, route = route,
rawSocket = rawSocket, rawSocket = rawSocket,
socket = javaNetSocket!!, javaNetSocket = javaNetSocket!!,
handshake = handshake, handshake = handshake,
protocol = protocol!!, protocol = protocol!!,
okioSocket = okioSocket, socket = socket,
source = source,
sink = sink,
pingIntervalMillis = pingIntervalMillis, pingIntervalMillis = pingIntervalMillis,
connectionListener = connectionPool.connectionListener, connectionListener = connectionPool.connectionListener,
) )
@@ -291,9 +282,7 @@ class ConnectPlan(
// https://github.com/square/okhttp/issues/3245 // https://github.com/square/okhttp/issues/3245
// https://android-review.googlesource.com/#/c/271775/ // https://android-review.googlesource.com/#/c/271775/
try { try {
okioSocket = rawSocket.asOkioSocket() this.socket = rawSocket.asBufferedSocket()
source = okioSocket.source.buffer()
sink = okioSocket.sink.buffer()
} catch (npe: NullPointerException) { } catch (npe: NullPointerException) {
if (npe.message == NPE_THROW_WITH_NULL) { if (npe.message == NPE_THROW_WITH_NULL) {
throw IOException(npe) throw IOException(npe)
@@ -408,9 +397,7 @@ class ConnectPlan(
null null
} }
javaNetSocket = sslSocket javaNetSocket = sslSocket
okioSocket = sslSocket.asOkioSocket() socket = sslSocket.asBufferedSocket()
source = okioSocket.source.buffer()
sink = okioSocket.sink.buffer()
protocol = if (maybeProtocol != null) Protocol.get(maybeProtocol) else Protocol.HTTP_1_1 protocol = if (maybeProtocol != null) Protocol.get(maybeProtocol) else Protocol.HTTP_1_1
success = true success = true
} finally { } finally {
@@ -437,12 +424,10 @@ class ConnectPlan(
// No client for CONNECT tunnels: // No client for CONNECT tunnels:
client = null, client = null,
carrier = this, carrier = this,
socket = okioSocket, socket = socket,
source = source,
sink = sink,
) )
source.timeout().timeout(readTimeoutMillis.toLong(), TimeUnit.MILLISECONDS) socket.source.timeout().timeout(readTimeoutMillis.toLong(), TimeUnit.MILLISECONDS)
sink.timeout().timeout(writeTimeoutMillis.toLong(), TimeUnit.MILLISECONDS) socket.sink.timeout().timeout(writeTimeoutMillis.toLong(), TimeUnit.MILLISECONDS)
tunnelCodec.writeRequest(nextRequest.headers, requestLine) tunnelCodec.writeRequest(nextRequest.headers, requestLine)
tunnelCodec.finishRequest() tunnelCodec.finishRequest()
val response = val response =

View File

@@ -52,12 +52,6 @@ import okhttp3.internal.http2.StreamResetException
import okhttp3.internal.isHealthy import okhttp3.internal.isHealthy
import okhttp3.internal.tls.OkHostnameVerifier import okhttp3.internal.tls.OkHostnameVerifier
import okio.Buffer import okio.Buffer
import okio.BufferedSink
import okio.BufferedSource
import okio.Sink
import okio.Source
import okio.Timeout
import okio.buffer
/** /**
* A connection to a remote web server capable of carrying 1 or more concurrent streams. * A connection to a remote web server capable of carrying 1 or more concurrent streams.
@@ -75,12 +69,10 @@ class RealConnection internal constructor(
* The application layer socket. Either an [SSLSocket] layered over [rawSocket], or [rawSocket] * The application layer socket. Either an [SSLSocket] layered over [rawSocket], or [rawSocket]
* itself if this connection does not use SSL. * itself if this connection does not use SSL.
*/ */
private val socket: JavaNetSocket, private val javaNetSocket: JavaNetSocket,
private val handshake: Handshake?, private val handshake: Handshake?,
private val protocol: Protocol, private val protocol: Protocol,
private val okioSocket: okio.Socket, private val socket: BufferedSocket,
private val source: BufferedSource,
private val sink: BufferedSink,
private val pingIntervalMillis: Int, private val pingIntervalMillis: Int,
internal val connectionListener: ConnectionListener, internal val connectionListener: ConnectionListener,
) : Http2Connection.Listener(), ) : Http2Connection.Listener(),
@@ -167,12 +159,12 @@ class RealConnection internal constructor(
@Throws(IOException::class) @Throws(IOException::class)
private fun startHttp2() { private fun startHttp2() {
socket.soTimeout = 0 // HTTP/2 connection timeouts are set per-stream. javaNetSocket.soTimeout = 0 // HTTP/2 connection timeouts are set per-stream.
val flowControlListener = connectionListener as? FlowControlListener ?: FlowControlListener.None val flowControlListener = connectionListener as? FlowControlListener ?: FlowControlListener.None
val http2Connection = val http2Connection =
Http2Connection Http2Connection
.Builder(client = true, taskRunner) .Builder(client = true, taskRunner)
.socket(socket, route.address.url.host, source, sink) .socket(socket, route.address.url.host)
.listener(this) .listener(this)
.pingIntervalMillis(pingIntervalMillis) .pingIntervalMillis(pingIntervalMillis)
.flowControlListener(flowControlListener) .flowControlListener(flowControlListener)
@@ -277,23 +269,21 @@ class RealConnection internal constructor(
client: OkHttpClient, client: OkHttpClient,
chain: RealInterceptorChain, chain: RealInterceptorChain,
): ExchangeCodec { ): ExchangeCodec {
val socket = this.socket val okHttpSocket = this.socket
val source = this.source
val sink = this.sink
val http2Connection = this.http2Connection val http2Connection = this.http2Connection
return if (http2Connection != null) { return if (http2Connection != null) {
Http2ExchangeCodec(client, this, chain, http2Connection) Http2ExchangeCodec(client, this, chain, http2Connection)
} else { } else {
socket.soTimeout = chain.readTimeoutMillis() javaNetSocket.soTimeout = chain.readTimeoutMillis()
source.timeout().timeout(chain.readTimeoutMillis.toLong(), MILLISECONDS) okHttpSocket.source.timeout().timeout(chain.readTimeoutMillis.toLong(), MILLISECONDS)
sink.timeout().timeout(chain.writeTimeoutMillis.toLong(), MILLISECONDS) okHttpSocket.sink.timeout().timeout(chain.writeTimeoutMillis.toLong(), MILLISECONDS)
Http1ExchangeCodec(client, this, okioSocket, source, sink) Http1ExchangeCodec(client, this, okHttpSocket)
} }
} }
internal fun useAsSocket() { internal fun useAsSocket() {
socket.soTimeout = 0 javaNetSocket.soTimeout = 0
noNewExchanges() noNewExchanges()
} }
@@ -304,7 +294,7 @@ class RealConnection internal constructor(
rawSocket.closeQuietly() rawSocket.closeQuietly()
} }
override fun socket(): JavaNetSocket = socket override fun socket(): JavaNetSocket = javaNetSocket
/** Returns true if this connection is ready to host new streams. */ /** Returns true if this connection is ready to host new streams. */
fun isHealthy(doExtensiveChecks: Boolean): Boolean { fun isHealthy(doExtensiveChecks: Boolean): Boolean {
@@ -313,9 +303,9 @@ class RealConnection internal constructor(
val nowNs = System.nanoTime() val nowNs = System.nanoTime()
if (rawSocket.isClosed || if (rawSocket.isClosed ||
socket.isClosed || javaNetSocket.isClosed ||
socket.isInputShutdown || javaNetSocket.isInputShutdown ||
socket.isOutputShutdown javaNetSocket.isOutputShutdown
) { ) {
return false return false
} }
@@ -327,7 +317,7 @@ class RealConnection internal constructor(
val idleDurationNs = withLock { nowNs - idleAtNs } val idleDurationNs = withLock { nowNs - idleAtNs }
if (idleDurationNs >= IDLE_CONNECTION_HEALTHY_NS && doExtensiveChecks) { if (idleDurationNs >= IDLE_CONNECTION_HEALTHY_NS && doExtensiveChecks) {
return socket.isHealthy(source) return javaNetSocket.isHealthy(socket.source)
} }
return true return true
@@ -452,33 +442,10 @@ class RealConnection internal constructor(
socket: JavaNetSocket, socket: JavaNetSocket,
idleAtNs: Long, idleAtNs: Long,
): RealConnection { ): RealConnection {
val okioSocket = val bufferedSocket =
object : okio.Socket { object : BufferedSocket {
override val sink: Sink = override val sink = Buffer()
object : Sink { override val source = Buffer()
override fun close() = Unit
override fun flush() = Unit
override fun timeout(): Timeout = Timeout.NONE
override fun write(
source: Buffer,
byteCount: Long,
): Unit = throw UnsupportedOperationException()
}
override val source: Source =
object : Source {
override fun close() = Unit
override fun read(
sink: Buffer,
byteCount: Long,
): Long = throw UnsupportedOperationException()
override fun timeout(): Timeout = Timeout.NONE
}
override fun cancel() { override fun cancel() {
} }
@@ -490,12 +457,10 @@ class RealConnection internal constructor(
connectionPool = connectionPool, connectionPool = connectionPool,
route = route, route = route,
rawSocket = JavaNetSocket(), rawSocket = JavaNetSocket(),
socket = socket, javaNetSocket = socket,
handshake = null, handshake = null,
protocol = Protocol.HTTP_2, protocol = Protocol.HTTP_2,
okioSocket = okioSocket, socket = bufferedSocket,
source = okioSocket.source.buffer(),
sink = okioSocket.sink.buffer(),
pingIntervalMillis = 0, pingIntervalMillis = 0,
connectionListener = ConnectionListener.NONE, connectionListener = ConnectionListener.NONE,
) )

View File

@@ -26,6 +26,7 @@ import okhttp3.OkHttpClient
import okhttp3.Request import okhttp3.Request
import okhttp3.Response import okhttp3.Response
import okhttp3.internal.checkOffsetAndCount import okhttp3.internal.checkOffsetAndCount
import okhttp3.internal.connection.BufferedSocket
import okhttp3.internal.discard import okhttp3.internal.discard
import okhttp3.internal.headersContentLength import okhttp3.internal.headersContentLength
import okhttp3.internal.http.ExchangeCodec import okhttp3.internal.http.ExchangeCodec
@@ -37,11 +38,8 @@ import okhttp3.internal.http.receiveHeaders
import okhttp3.internal.http1.Http1ExchangeCodec.Companion.TRAILERS_RESPONSE_BODY_TRUNCATED import okhttp3.internal.http1.Http1ExchangeCodec.Companion.TRAILERS_RESPONSE_BODY_TRUNCATED
import okhttp3.internal.skipAll import okhttp3.internal.skipAll
import okio.Buffer import okio.Buffer
import okio.BufferedSink
import okio.BufferedSource
import okio.ForwardingTimeout import okio.ForwardingTimeout
import okio.Sink import okio.Sink
import okio.Socket
import okio.Source import okio.Source
import okio.Timeout import okio.Timeout
@@ -66,12 +64,10 @@ class Http1ExchangeCodec(
/** The client that configures this stream. May be null for HTTPS proxy tunnels. */ /** The client that configures this stream. May be null for HTTPS proxy tunnels. */
private val client: OkHttpClient?, private val client: OkHttpClient?,
override val carrier: ExchangeCodec.Carrier, override val carrier: ExchangeCodec.Carrier,
override val socket: Socket, override val socket: BufferedSocket,
private val source: BufferedSource,
private val sink: BufferedSink,
) : ExchangeCodec { ) : ExchangeCodec {
private var state = STATE_IDLE private var state = STATE_IDLE
private val headersReader = HeadersReader(source) private val headersReader = HeadersReader(socket.source)
private val Response.isChunked: Boolean private val Response.isChunked: Boolean
get() = "chunked".equals(header("Transfer-Encoding"), ignoreCase = true) get() = "chunked".equals(header("Transfer-Encoding"), ignoreCase = true)
@@ -161,11 +157,11 @@ class Http1ExchangeCodec(
} }
override fun flushRequest() { override fun flushRequest() {
sink.flush() socket.sink.flush()
} }
override fun finishRequest() { override fun finishRequest() {
sink.flush() socket.sink.flush()
} }
/** Returns bytes of a request header for sending on an HTTP transport. */ /** Returns bytes of a request header for sending on an HTTP transport. */
@@ -174,15 +170,15 @@ class Http1ExchangeCodec(
requestLine: String, requestLine: String,
) { ) {
check(state == STATE_IDLE) { "state: $state" } check(state == STATE_IDLE) { "state: $state" }
sink.writeUtf8(requestLine).writeUtf8("\r\n") socket.sink.writeUtf8(requestLine).writeUtf8("\r\n")
for (i in 0 until headers.size) { for (i in 0 until headers.size) {
sink socket.sink
.writeUtf8(headers.name(i)) .writeUtf8(headers.name(i))
.writeUtf8(": ") .writeUtf8(": ")
.writeUtf8(headers.value(i)) .writeUtf8(headers.value(i))
.writeUtf8("\r\n") .writeUtf8("\r\n")
} }
sink.writeUtf8("\r\n") socket.sink.writeUtf8("\r\n")
state = STATE_OPEN_REQUEST_BODY state = STATE_OPEN_REQUEST_BODY
} }
@@ -295,7 +291,7 @@ class Http1ExchangeCodec(
/** An HTTP request body. */ /** An HTTP request body. */
private inner class KnownLengthSink : Sink { private inner class KnownLengthSink : Sink {
private val timeout = ForwardingTimeout(sink.timeout()) private val timeout = ForwardingTimeout(socket.sink.timeout())
private var closed: Boolean = false private var closed: Boolean = false
override fun timeout(): Timeout = timeout override fun timeout(): Timeout = timeout
@@ -306,12 +302,12 @@ class Http1ExchangeCodec(
) { ) {
check(!closed) { "closed" } check(!closed) { "closed" }
checkOffsetAndCount(source.size, 0, byteCount) checkOffsetAndCount(source.size, 0, byteCount)
sink.write(source, byteCount) socket.sink.write(source, byteCount)
} }
override fun flush() { override fun flush() {
if (closed) return // Don't throw; this stream might have been closed on the caller's behalf. if (closed) return // Don't throw; this stream might have been closed on the caller's behalf.
sink.flush() socket.sink.flush()
} }
override fun close() { override fun close() {
@@ -327,7 +323,7 @@ class Http1ExchangeCodec(
* to buffer chunks; typically by using a buffered sink with this sink. * to buffer chunks; typically by using a buffered sink with this sink.
*/ */
private inner class ChunkedSink : Sink { private inner class ChunkedSink : Sink {
private val timeout = ForwardingTimeout(sink.timeout()) private val timeout = ForwardingTimeout(socket.sink.timeout())
private var closed: Boolean = false private var closed: Boolean = false
override fun timeout(): Timeout = timeout override fun timeout(): Timeout = timeout
@@ -339,23 +335,25 @@ class Http1ExchangeCodec(
check(!closed) { "closed" } check(!closed) { "closed" }
if (byteCount == 0L) return if (byteCount == 0L) return
sink.writeHexadecimalUnsignedLong(byteCount) with(socket.sink) {
sink.writeUtf8("\r\n") writeHexadecimalUnsignedLong(byteCount)
sink.write(source, byteCount) writeUtf8("\r\n")
sink.writeUtf8("\r\n") write(source, byteCount)
writeUtf8("\r\n")
}
} }
@Synchronized @Synchronized
override fun flush() { override fun flush() {
if (closed) return // Don't throw; this stream might have been closed on the caller's behalf. if (closed) return // Don't throw; this stream might have been closed on the caller's behalf.
sink.flush() socket.sink.flush()
} }
@Synchronized @Synchronized
override fun close() { override fun close() {
if (closed) return if (closed) return
closed = true closed = true
sink.writeUtf8("0\r\n\r\n") socket.sink.writeUtf8("0\r\n\r\n")
detachTimeout(timeout) detachTimeout(timeout)
state = STATE_READ_RESPONSE_HEADERS state = STATE_READ_RESPONSE_HEADERS
} }
@@ -364,7 +362,7 @@ class Http1ExchangeCodec(
private abstract inner class AbstractSource( private abstract inner class AbstractSource(
val url: HttpUrl, val url: HttpUrl,
) : Source { ) : Source {
protected val timeout = ForwardingTimeout(source.timeout()) protected val timeout = ForwardingTimeout(socket.source.timeout())
protected var closed: Boolean = false protected var closed: Boolean = false
override fun timeout(): Timeout = timeout override fun timeout(): Timeout = timeout
@@ -374,7 +372,7 @@ class Http1ExchangeCodec(
byteCount: Long, byteCount: Long,
): Long = ): Long =
try { try {
source.read(sink, byteCount) socket.source.read(sink, byteCount)
} catch (e: IOException) { } catch (e: IOException) {
carrier.noNewExchanges() carrier.noNewExchanges()
responseBodyComplete(TRAILERS_RESPONSE_BODY_TRUNCATED) responseBodyComplete(TRAILERS_RESPONSE_BODY_TRUNCATED)
@@ -481,11 +479,11 @@ class Http1ExchangeCodec(
private fun readChunkSize() { private fun readChunkSize() {
// Read the suffix of the previous chunk. // Read the suffix of the previous chunk.
if (bytesRemainingInChunk != NO_CHUNK_YET) { if (bytesRemainingInChunk != NO_CHUNK_YET) {
source.readUtf8LineStrict() socket.source.readUtf8LineStrict()
} }
try { try {
bytesRemainingInChunk = source.readHexadecimalUnsignedLong() bytesRemainingInChunk = socket.source.readHexadecimalUnsignedLong()
val extensions = source.readUtf8LineStrict().trim() val extensions = socket.source.readUtf8LineStrict().trim()
if (bytesRemainingInChunk < 0L || extensions.isNotEmpty() && !extensions.startsWith(";")) { if (bytesRemainingInChunk < 0L || extensions.isNotEmpty() && !extensions.startsWith(";")) {
throw ProtocolException( throw ProtocolException(
"expected chunk size and optional extensions" + "expected chunk size and optional extensions" +

View File

@@ -18,7 +18,6 @@ package okhttp3.internal.http2
import java.io.Closeable import java.io.Closeable
import java.io.IOException import java.io.IOException
import java.io.InterruptedIOException import java.io.InterruptedIOException
import java.net.Socket
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import okhttp3.Headers import okhttp3.Headers
import okhttp3.internal.EMPTY_BYTE_ARRAY import okhttp3.internal.EMPTY_BYTE_ARRAY
@@ -29,22 +28,18 @@ import okhttp3.internal.concurrent.assertLockNotHeld
import okhttp3.internal.concurrent.notifyAll import okhttp3.internal.concurrent.notifyAll
import okhttp3.internal.concurrent.wait import okhttp3.internal.concurrent.wait
import okhttp3.internal.concurrent.withLock import okhttp3.internal.concurrent.withLock
import okhttp3.internal.connection.BufferedSocket
import okhttp3.internal.http2.ErrorCode.REFUSED_STREAM import okhttp3.internal.http2.ErrorCode.REFUSED_STREAM
import okhttp3.internal.http2.Settings.Companion.DEFAULT_INITIAL_WINDOW_SIZE import okhttp3.internal.http2.Settings.Companion.DEFAULT_INITIAL_WINDOW_SIZE
import okhttp3.internal.http2.flowcontrol.WindowCounter import okhttp3.internal.http2.flowcontrol.WindowCounter
import okhttp3.internal.ignoreIoExceptions import okhttp3.internal.ignoreIoExceptions
import okhttp3.internal.okHttpName import okhttp3.internal.okHttpName
import okhttp3.internal.peerName
import okhttp3.internal.platform.Platform import okhttp3.internal.platform.Platform
import okhttp3.internal.platform.Platform.Companion.INFO import okhttp3.internal.platform.Platform.Companion.INFO
import okhttp3.internal.toHeaders import okhttp3.internal.toHeaders
import okio.Buffer import okio.Buffer
import okio.BufferedSink
import okio.BufferedSource import okio.BufferedSource
import okio.ByteString import okio.ByteString
import okio.buffer
import okio.sink
import okio.source
/** /**
* A socket connection to a remote peer. A connection hosts streams which can send and receive * A socket connection to a remote peer. A connection hosts streams which can send and receive
@@ -140,11 +135,11 @@ class Http2Connection internal constructor(
var writeBytesMaximum: Long = peerSettings.initialWindowSize.toLong() var writeBytesMaximum: Long = peerSettings.initialWindowSize.toLong()
private set private set
internal val socket: Socket = builder.socket internal val socket: BufferedSocket = builder.socket
val writer = Http2Writer(builder.sink, client) val writer = Http2Writer(socket.sink, client)
// Visible for testing // Visible for testing
val readerRunnable = ReaderRunnable(Http2Reader(builder.source, client)) val readerRunnable = ReaderRunnable(Http2Reader(socket.source, client))
// Guarded by this. // Guarded by this.
private val currentPushRequests = mutableSetOf<Int>() private val currentPushRequests = mutableSetOf<Int>()
@@ -479,9 +474,9 @@ class Http2Connection internal constructor(
writer.close() writer.close()
} }
// Close the socket to break out the reader thread, which will clean up after itself. // Cancel the socket to break out the reader thread, which will clean up after itself.
ignoreIoExceptions { ignoreIoExceptions {
socket.close() socket.cancel()
} }
// Release the threads. // Release the threads.
@@ -574,22 +569,17 @@ class Http2Connection internal constructor(
internal var client: Boolean, internal var client: Boolean,
internal val taskRunner: TaskRunner, internal val taskRunner: TaskRunner,
) { ) {
internal lateinit var socket: Socket internal lateinit var socket: BufferedSocket
internal lateinit var connectionName: String internal lateinit var connectionName: String
internal lateinit var source: BufferedSource
internal lateinit var sink: BufferedSink
internal var listener = Listener.REFUSE_INCOMING_STREAMS internal var listener = Listener.REFUSE_INCOMING_STREAMS
internal var pushObserver = PushObserver.CANCEL internal var pushObserver = PushObserver.CANCEL
internal var pingIntervalMillis: Int = 0 internal var pingIntervalMillis: Int = 0
internal var flowControlListener: FlowControlListener = FlowControlListener.None internal var flowControlListener: FlowControlListener = FlowControlListener.None
@Throws(IOException::class) @Throws(IOException::class)
@JvmOverloads
fun socket( fun socket(
socket: Socket, socket: BufferedSocket,
peerName: String = socket.peerName(), peerName: String,
source: BufferedSource = socket.source().buffer(),
sink: BufferedSink = socket.sink().buffer(),
) = apply { ) = apply {
this.socket = socket this.socket = socket
this.connectionName = this.connectionName =
@@ -597,8 +587,6 @@ class Http2Connection internal constructor(
client -> "$okHttpName $peerName" client -> "$okHttpName $peerName"
else -> "MockWebServer $peerName" else -> "MockWebServer $peerName"
} }
this.source = source
this.sink = sink
} }
fun listener(listener: Listener) = fun listener(listener: Listener) =
@@ -800,7 +788,7 @@ class Http2Connection internal constructor(
} }
} }
if (streamsToNotify != null) { if (streamsToNotify != null) {
for (stream in streamsToNotify!!) { for (stream in streamsToNotify) {
stream.withLock { stream.withLock {
stream.addBytesToWriteWindow(delta) stream.addBytesToWriteWindow(delta)
} }

View File

@@ -36,7 +36,9 @@ import okhttp3.internal.concurrent.Lockable
import okhttp3.internal.concurrent.Task import okhttp3.internal.concurrent.Task
import okhttp3.internal.concurrent.TaskRunner import okhttp3.internal.concurrent.TaskRunner
import okhttp3.internal.concurrent.assertLockHeld import okhttp3.internal.concurrent.assertLockHeld
import okhttp3.internal.connection.BufferedSocket
import okhttp3.internal.connection.RealCall import okhttp3.internal.connection.RealCall
import okhttp3.internal.connection.asBufferedSocket
import okhttp3.internal.okHttpName import okhttp3.internal.okHttpName
import okhttp3.internal.ws.WebSocketProtocol.CLOSE_CLIENT_GOING_AWAY import okhttp3.internal.ws.WebSocketProtocol.CLOSE_CLIENT_GOING_AWAY
import okhttp3.internal.ws.WebSocketProtocol.CLOSE_MESSAGE_MAX import okhttp3.internal.ws.WebSocketProtocol.CLOSE_MESSAGE_MAX
@@ -49,7 +51,6 @@ import okio.ByteString
import okio.ByteString.Companion.encodeUtf8 import okio.ByteString.Companion.encodeUtf8
import okio.ByteString.Companion.toByteString import okio.ByteString.Companion.toByteString
import okio.Socket import okio.Socket
import okio.buffer
class RealWebSocket( class RealWebSocket(
taskRunner: TaskRunner, taskRunner: TaskRunner,
@@ -197,9 +198,7 @@ class RealWebSocket(
val name = "$okHttpName WebSocket ${request.url.redact()}" val name = "$okHttpName WebSocket ${request.url.redact()}"
initReaderAndWriter( initReaderAndWriter(
name = name, name = name,
socket = socket, socket = socket.asBufferedSocket(),
socketSource = socket.source.buffer(),
socketSink = socket.sink.buffer(),
client = true, client = true,
) )
loopReader(response) loopReader(response)
@@ -269,9 +268,7 @@ class RealWebSocket(
*/ */
fun initReaderAndWriter( fun initReaderAndWriter(
name: String, name: String,
socket: Socket, socket: BufferedSocket,
socketSource: BufferedSource,
socketSink: BufferedSink,
client: Boolean, client: Boolean,
) { ) {
val extensions = this.extensions!! val extensions = this.extensions!!
@@ -281,7 +278,7 @@ class RealWebSocket(
this.writer = this.writer =
WebSocketWriter( WebSocketWriter(
isClient = client, isClient = client,
sink = socketSink, sink = socket.sink,
random = random, random = random,
perMessageDeflate = extensions.perMessageDeflate, perMessageDeflate = extensions.perMessageDeflate,
noContextTakeover = extensions.noContextTakeover(client), noContextTakeover = extensions.noContextTakeover(client),
@@ -303,7 +300,7 @@ class RealWebSocket(
reader = reader =
WebSocketReader( WebSocketReader(
isClient = client, isClient = client,
source = socketSource, source = socket.source,
frameCallback = this, frameCallback = this,
perMessageDeflate = extensions.perMessageDeflate, perMessageDeflate = extensions.perMessageDeflate,
noContextTakeover = extensions.noContextTakeover(!client), noContextTakeover = extensions.noContextTakeover(!client),

View File

@@ -204,7 +204,7 @@ class WebSocketWriter(
} }
sinkBuffer.write(messageBuffer, dataSize) sinkBuffer.write(messageBuffer, dataSize)
sink.emit() sink.flush()
} }
override fun close() { override fun close() {

View File

@@ -2823,7 +2823,9 @@ open class CallTest {
call.enqueue(callback) call.enqueue(callback)
call.cancel() call.cancel()
latch.countDown() latch.countDown()
callback.await(server.url("/a")).assertFailure("Canceled", "Socket closed", "Socket is closed") callback
.await(server.url("/a"))
.assertFailure("canceled", "Canceled", "Socket closed", "Socket is closed")
} }
@Test @Test
@@ -2831,7 +2833,9 @@ open class CallTest {
val call = client.newCall(Request(server.url("/"))) val call = client.newCall(Request(server.url("/")))
call.enqueue(callback) call.enqueue(callback)
client.dispatcher.cancelAll() client.dispatcher.cancelAll()
callback.await(server.url("/")).assertFailure("Canceled", "Socket closed", "Socket is closed") callback
.await(server.url("/"))
.assertFailure("canceled", "Canceled", "Socket closed", "Socket is closed")
} }
@Test @Test
@@ -2942,7 +2946,9 @@ open class CallTest {
assertThat(server.takeRequest().url.encodedPath).isEqualTo("/a") assertThat(server.takeRequest().url.encodedPath).isEqualTo("/a")
callback.await(requestA.url).assertBody("A") callback.await(requestA.url).assertBody("A")
// At this point we know the callback is ready, and that it will receive a cancel failure. // At this point we know the callback is ready, and that it will receive a cancel failure.
callback.await(requestB.url).assertFailure("Canceled", "Socket closed") callback
.await(requestB.url)
.assertFailure("canceled", "Canceled", "Socket closed", "Socket is closed")
} }
@Test @Test
@@ -4715,7 +4721,7 @@ open class CallTest {
.build() .build()
executeSynchronously("/") executeSynchronously("/")
.assertFailure(IOException::class.java) .assertFailure(IOException::class.java)
.assertFailure("Socket closed", "Socket is closed") .assertFailure("canceled", "Canceled", "Socket closed", "Socket is closed")
} }
@Test @Test
@@ -4805,7 +4811,8 @@ open class CallTest {
.build() .build()
}, },
).build() ).build()
executeSynchronously("/").assertFailure("Canceled") executeSynchronously("/")
.assertFailure("canceled", "Canceled", "Socket closed", "Socket is closed")
assertThat(closed.get()).isTrue() assertThat(closed.get()).isTrue()
} }

View File

@@ -316,7 +316,7 @@ class ConnectionPoolTest {
val connection = val connection =
Http2Connection Http2Connection
.Builder(true, TaskRunner.INSTANCE) .Builder(true, TaskRunner.INSTANCE)
.socket(peer.openSocket()) .socket(peer.openSocket().asBufferedSocket(), "peer")
.pushObserver(Http2ConnectionTest.IGNORE) .pushObserver(Http2ConnectionTest.IGNORE)
.listener(realConnection) .listener(realConnection)
.build() .build()

View File

@@ -82,12 +82,12 @@ class HttpUpgradesTest {
socket.sink.buffer().use { sink -> socket.sink.buffer().use { sink ->
socket.source.buffer().use { source -> socket.source.buffer().use { source ->
sink.writeUtf8("client says hello\n") sink.writeUtf8("client says hello\n")
sink.emit() sink.flush()
assertThat(source.readUtf8Line()).isEqualTo("server says hello") assertThat(source.readUtf8Line()).isEqualTo("server says hello")
sink.writeUtf8("client says goodbye\n") sink.writeUtf8("client says goodbye\n")
sink.emit() sink.flush()
assertThat(source.readUtf8Line()).isEqualTo("server says goodbye") assertThat(source.readUtf8Line()).isEqualTo("server says goodbye")

View File

@@ -43,6 +43,7 @@ import okhttp3.internal.concurrent.TaskRunner
import okhttp3.internal.concurrent.notifyAll import okhttp3.internal.concurrent.notifyAll
import okhttp3.internal.concurrent.wait import okhttp3.internal.concurrent.wait
import okhttp3.internal.concurrent.withLock import okhttp3.internal.concurrent.withLock
import okhttp3.internal.connection.asBufferedSocket
import okio.AsyncTimeout import okio.AsyncTimeout
import okio.Buffer import okio.Buffer
import okio.BufferedSource import okio.BufferedSource
@@ -497,7 +498,7 @@ class Http2ConnectionTest {
val connection = val connection =
Http2Connection Http2Connection
.Builder(true, TaskRunner.INSTANCE) .Builder(true, TaskRunner.INSTANCE)
.socket(socket) .socket(socket.asBufferedSocket(), "peer")
.pushObserver(IGNORE) .pushObserver(IGNORE)
.build() .build()
connection.start(sendConnectionPreface = false) connection.start(sendConnectionPreface = false)
@@ -1890,7 +1891,7 @@ class Http2ConnectionTest {
val connection = val connection =
Http2Connection Http2Connection
.Builder(true, TaskRunner.INSTANCE) .Builder(true, TaskRunner.INSTANCE)
.socket(peer.openSocket()) .socket(peer.openSocket().asBufferedSocket(), "peer")
.build() .build()
connection.start(sendConnectionPreface = false) connection.start(sendConnectionPreface = false)
val stream = connection.newStream(headerEntries("b", "banana"), false) val stream = connection.newStream(headerEntries("b", "banana"), false)
@@ -1912,11 +1913,10 @@ class Http2ConnectionTest {
peer.acceptFrame() // SYN_STREAM. peer.acceptFrame() // SYN_STREAM.
peer.play() peer.play()
val taskRunner = taskFaker.taskRunner val taskRunner = taskFaker.taskRunner
val socket = peer.openSocket()
val connection = val connection =
Http2Connection Http2Connection
.Builder(true, taskRunner) .Builder(true, taskRunner)
.socket(socket) .socket(peer.openSocket().asBufferedSocket(), "peer")
.pushObserver(IGNORE) .pushObserver(IGNORE)
.build() .build()
connection.start(sendConnectionPreface = false) connection.start(sendConnectionPreface = false)
@@ -1972,7 +1972,7 @@ class Http2ConnectionTest {
val connection = val connection =
Http2Connection Http2Connection
.Builder(true, TaskRunner.INSTANCE) .Builder(true, TaskRunner.INSTANCE)
.socket(peer.openSocket()) .socket(peer.openSocket().asBufferedSocket(), "peer")
.pushObserver(pushObserver) .pushObserver(pushObserver)
.listener(listener) .listener(listener)
.build() .build()

View File

@@ -35,6 +35,7 @@ import okhttp3.Request
import okhttp3.Response import okhttp3.Response
import okhttp3.TestUtil.repeat import okhttp3.TestUtil.repeat
import okhttp3.internal.concurrent.TaskFaker import okhttp3.internal.concurrent.TaskFaker
import okhttp3.internal.connection.BufferedSocket
import okhttp3.internal.ws.WebSocketExtensions.Companion.parse import okhttp3.internal.ws.WebSocketExtensions.Companion.parse
import okio.BufferedSink import okio.BufferedSink
import okio.BufferedSource import okio.BufferedSource
@@ -466,7 +467,7 @@ class RealWebSocketTest {
private val taskFaker: TaskFaker, private val taskFaker: TaskFaker,
private val delegate: Socket, private val delegate: Socket,
private val client: Boolean, private val client: Boolean,
) : Socket { ) : BufferedSocket {
private val name = if (client) "client" else "server" private val name = if (client) "client" else "server"
val listener = WebSocketRecorder(name) val listener = WebSocketRecorder(name)
var webSocket: RealWebSocket? = null var webSocket: RealWebSocket? = null
@@ -530,7 +531,7 @@ class RealWebSocketTest {
} }
} }
} }
webSocket!!.initReaderAndWriter(name, this, source, sink, client) webSocket!!.initReaderAndWriter(name, this, client)
} }
/** /**