1
0
mirror of https://github.com/square/okhttp.git synced 2025-08-08 23:42:08 +03:00

New BufferedSocket class (#8977)

* New BufferedSocket class

I'm starting to push Okio's socket throughout OkHttp,
and it turns out what I usually want is a BufferedSocket
instead.

* Flush not emit

* Track exception messages

* Use flush more
This commit is contained in:
Jesse Wilson
2025-07-31 00:18:40 -04:00
committed by GitHub
parent 3bafcf16fb
commit 4e7860212f
18 changed files with 156 additions and 181 deletions

View File

@@ -516,7 +516,7 @@ public class MockWebServer : Closeable {
val connection =
Http2Connection
.Builder(false, taskRunner)
.socket(socket.javaNetSocket)
.socket(socket, socket.javaNetSocket.remoteSocketAddress.toString())
.listener(http2SocketHandler)
.build()
connection.start()
@@ -836,8 +836,6 @@ public class MockWebServer : Closeable {
webSocket.initReaderAndWriter(
name = name,
socket = socket,
socketSource = socket.source,
socketSink = socket.sink,
client = false,
)

View File

@@ -23,6 +23,7 @@ import java.util.concurrent.CountDownLatch
import javax.net.ssl.SSLSocket
import okhttp3.Handshake
import okhttp3.Handshake.Companion.handshake
import okhttp3.internal.connection.BufferedSocket
import okhttp3.internal.platform.Platform
import okio.BufferedSink
import okio.BufferedSource
@@ -40,7 +41,7 @@ import okio.buffer
internal class MockWebServerSocket(
val javaNetSocket: Socket,
) : Closeable,
okio.Socket {
BufferedSocket {
private val delegate = javaNetSocket.asOkioSocket()
private val closedLatch = CountDownLatch(2)

View File

@@ -17,6 +17,7 @@ package mockwebserver3.internal.http2
import java.io.File
import java.io.IOException
import java.net.InetSocketAddress
import java.net.ProtocolException
import java.net.ServerSocket
import java.net.Socket
@@ -28,6 +29,7 @@ import okhttp3.Protocol
import okhttp3.Protocol.Companion.get
import okhttp3.internal.closeQuietly
import okhttp3.internal.concurrent.TaskRunner
import okhttp3.internal.connection.asBufferedSocket
import okhttp3.internal.http2.Header
import okhttp3.internal.http2.Http2Connection
import okhttp3.internal.http2.Http2Stream
@@ -57,7 +59,7 @@ class Http2Server(
val connection =
Http2Connection
.Builder(false, TaskRunner.INSTANCE)
.socket(sslSocket)
.socket(sslSocket.asBufferedSocket(), sslSocket.peerName())
.listener(this)
.build()
connection.start()
@@ -192,6 +194,11 @@ class Http2Server(
else -> "text/plain"
}
private fun Socket.peerName(): String {
val address = remoteSocketAddress
return if (address is InetSocketAddress) address.hostName else address.toString()
}
companion object {
val logger: Logger = Logger.getLogger(Http2Server::class.java.name)

View File

@@ -21,25 +21,13 @@ import java.util.concurrent.FutureTask
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.TimeUnit
import mockwebserver3.SocketHandler
import okio.BufferedSink
import okio.BufferedSource
import okhttp3.internal.connection.BufferedSocket
import okhttp3.internal.connection.asBufferedSocket
import okio.Socket
import okio.buffer
import okio.utf8Size
private typealias Action = (BufferedSocket) -> Unit
private class BufferedSocket(
val socket: Socket,
) {
val source: BufferedSource = socket.source.buffer()
val sink: BufferedSink = socket.sink.buffer()
fun cancel() {
socket.cancel()
}
}
/**
* A scriptable request/response conversation. Create the script by calling methods like
* [receiveRequest] in the sequence they are run.
@@ -104,7 +92,7 @@ class MockSocketHandler : SocketHandler {
}
override fun handle(socket: Socket) {
val task = serviceSocketTask(BufferedSocket(socket))
val task = serviceSocketTask(socket.asBufferedSocket())
results.add(task)
task.run()
}

View File

@@ -16,12 +16,12 @@
package okhttp3.tls.internal
import java.net.InetSocketAddress
import java.net.Socket
import java.security.cert.CertificateException
import java.security.cert.X509Certificate
import javax.net.ssl.SSLEngine
import javax.net.ssl.X509ExtendedTrustManager
import okhttp3.internal.peerName
import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement
/**
@@ -78,4 +78,9 @@ internal class InsecureExtendedTrustManager(
authType: String,
socket: Socket?,
) = throw CertificateException("Unsupported operation")
private fun Socket.peerName(): String {
val address = remoteSocketAddress
return if (address is InetSocketAddress) address.hostName else address.toString()
}
}

View File

@@ -19,7 +19,6 @@ package okhttp3.internal
import java.io.IOException
import java.io.InterruptedIOException
import java.net.InetSocketAddress
import java.net.ServerSocket
import java.net.Socket
import java.net.SocketTimeoutException
@@ -197,11 +196,6 @@ internal fun Source.discard(
false
}
internal fun Socket.peerName(): String {
val address = remoteSocketAddress
return if (address is InetSocketAddress) address.hostName else address.toString()
}
/**
* Returns true if new reads and writes should be attempted on this.
*

View File

@@ -0,0 +1,41 @@
/*
* Copyright (C) 2025 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package okhttp3.internal.connection
import java.net.Socket as JavaNetSocket
import okio.BufferedSink
import okio.BufferedSource
import okio.Socket as OkioSocket
import okio.asOkioSocket
import okio.buffer
interface BufferedSocket : OkioSocket {
override val source: BufferedSource
override val sink: BufferedSink
}
fun JavaNetSocket.asBufferedSocket(): BufferedSocket = asOkioSocket().asBufferedSocket()
fun OkioSocket.asBufferedSocket(): BufferedSocket =
object : BufferedSocket {
private val delegate = this@asBufferedSocket
override val source = delegate.source.buffer()
override val sink = delegate.sink.buffer()
override fun cancel() {
delegate.cancel()
}
}

View File

@@ -42,11 +42,6 @@ import okhttp3.internal.http1.Http1ExchangeCodec
import okhttp3.internal.platform.Platform
import okhttp3.internal.tls.OkHostnameVerifier
import okhttp3.internal.toHostHeader
import okio.BufferedSink
import okio.BufferedSource
import okio.Socket as OkioSocket
import okio.asOkioSocket
import okio.buffer
/**
* A single attempt to connect to a remote server, including these steps:
@@ -94,9 +89,7 @@ class ConnectPlan(
internal var javaNetSocket: JavaNetSocket? = null
private var handshake: Handshake? = null
private var protocol: Protocol? = null
private lateinit var okioSocket: OkioSocket
private lateinit var source: BufferedSource
private lateinit var sink: BufferedSink
private lateinit var socket: BufferedSocket
private var connection: RealConnection? = null
/** True if this connection is ready for use, including TCP, tunnels, and TLS. */
@@ -185,7 +178,7 @@ class ConnectPlan(
// that happens, then we will have buffered bytes that are needed by the SSLSocket!
// This check is imperfect: it doesn't tell us whether a handshake will succeed, just
// that it will almost certainly fail because the proxy has sent unexpected data.
if (!source.buffer.exhausted() || !sink.buffer.exhausted()) {
if (!socket.source.buffer.exhausted() || !socket.sink.buffer.exhausted()) {
throw IOException("TLS tunnel buffered too many bytes!")
}
@@ -225,12 +218,10 @@ class ConnectPlan(
connectionPool = connectionPool,
route = route,
rawSocket = rawSocket,
socket = javaNetSocket!!,
javaNetSocket = javaNetSocket!!,
handshake = handshake,
protocol = protocol!!,
okioSocket = okioSocket,
source = source,
sink = sink,
socket = socket,
pingIntervalMillis = pingIntervalMillis,
connectionListener = connectionPool.connectionListener,
)
@@ -291,9 +282,7 @@ class ConnectPlan(
// https://github.com/square/okhttp/issues/3245
// https://android-review.googlesource.com/#/c/271775/
try {
okioSocket = rawSocket.asOkioSocket()
source = okioSocket.source.buffer()
sink = okioSocket.sink.buffer()
this.socket = rawSocket.asBufferedSocket()
} catch (npe: NullPointerException) {
if (npe.message == NPE_THROW_WITH_NULL) {
throw IOException(npe)
@@ -408,9 +397,7 @@ class ConnectPlan(
null
}
javaNetSocket = sslSocket
okioSocket = sslSocket.asOkioSocket()
source = okioSocket.source.buffer()
sink = okioSocket.sink.buffer()
socket = sslSocket.asBufferedSocket()
protocol = if (maybeProtocol != null) Protocol.get(maybeProtocol) else Protocol.HTTP_1_1
success = true
} finally {
@@ -437,12 +424,10 @@ class ConnectPlan(
// No client for CONNECT tunnels:
client = null,
carrier = this,
socket = okioSocket,
source = source,
sink = sink,
socket = socket,
)
source.timeout().timeout(readTimeoutMillis.toLong(), TimeUnit.MILLISECONDS)
sink.timeout().timeout(writeTimeoutMillis.toLong(), TimeUnit.MILLISECONDS)
socket.source.timeout().timeout(readTimeoutMillis.toLong(), TimeUnit.MILLISECONDS)
socket.sink.timeout().timeout(writeTimeoutMillis.toLong(), TimeUnit.MILLISECONDS)
tunnelCodec.writeRequest(nextRequest.headers, requestLine)
tunnelCodec.finishRequest()
val response =

View File

@@ -52,12 +52,6 @@ import okhttp3.internal.http2.StreamResetException
import okhttp3.internal.isHealthy
import okhttp3.internal.tls.OkHostnameVerifier
import okio.Buffer
import okio.BufferedSink
import okio.BufferedSource
import okio.Sink
import okio.Source
import okio.Timeout
import okio.buffer
/**
* A connection to a remote web server capable of carrying 1 or more concurrent streams.
@@ -75,12 +69,10 @@ class RealConnection internal constructor(
* The application layer socket. Either an [SSLSocket] layered over [rawSocket], or [rawSocket]
* itself if this connection does not use SSL.
*/
private val socket: JavaNetSocket,
private val javaNetSocket: JavaNetSocket,
private val handshake: Handshake?,
private val protocol: Protocol,
private val okioSocket: okio.Socket,
private val source: BufferedSource,
private val sink: BufferedSink,
private val socket: BufferedSocket,
private val pingIntervalMillis: Int,
internal val connectionListener: ConnectionListener,
) : Http2Connection.Listener(),
@@ -167,12 +159,12 @@ class RealConnection internal constructor(
@Throws(IOException::class)
private fun startHttp2() {
socket.soTimeout = 0 // HTTP/2 connection timeouts are set per-stream.
javaNetSocket.soTimeout = 0 // HTTP/2 connection timeouts are set per-stream.
val flowControlListener = connectionListener as? FlowControlListener ?: FlowControlListener.None
val http2Connection =
Http2Connection
.Builder(client = true, taskRunner)
.socket(socket, route.address.url.host, source, sink)
.socket(socket, route.address.url.host)
.listener(this)
.pingIntervalMillis(pingIntervalMillis)
.flowControlListener(flowControlListener)
@@ -277,23 +269,21 @@ class RealConnection internal constructor(
client: OkHttpClient,
chain: RealInterceptorChain,
): ExchangeCodec {
val socket = this.socket
val source = this.source
val sink = this.sink
val okHttpSocket = this.socket
val http2Connection = this.http2Connection
return if (http2Connection != null) {
Http2ExchangeCodec(client, this, chain, http2Connection)
} else {
socket.soTimeout = chain.readTimeoutMillis()
source.timeout().timeout(chain.readTimeoutMillis.toLong(), MILLISECONDS)
sink.timeout().timeout(chain.writeTimeoutMillis.toLong(), MILLISECONDS)
Http1ExchangeCodec(client, this, okioSocket, source, sink)
javaNetSocket.soTimeout = chain.readTimeoutMillis()
okHttpSocket.source.timeout().timeout(chain.readTimeoutMillis.toLong(), MILLISECONDS)
okHttpSocket.sink.timeout().timeout(chain.writeTimeoutMillis.toLong(), MILLISECONDS)
Http1ExchangeCodec(client, this, okHttpSocket)
}
}
internal fun useAsSocket() {
socket.soTimeout = 0
javaNetSocket.soTimeout = 0
noNewExchanges()
}
@@ -304,7 +294,7 @@ class RealConnection internal constructor(
rawSocket.closeQuietly()
}
override fun socket(): JavaNetSocket = socket
override fun socket(): JavaNetSocket = javaNetSocket
/** Returns true if this connection is ready to host new streams. */
fun isHealthy(doExtensiveChecks: Boolean): Boolean {
@@ -313,9 +303,9 @@ class RealConnection internal constructor(
val nowNs = System.nanoTime()
if (rawSocket.isClosed ||
socket.isClosed ||
socket.isInputShutdown ||
socket.isOutputShutdown
javaNetSocket.isClosed ||
javaNetSocket.isInputShutdown ||
javaNetSocket.isOutputShutdown
) {
return false
}
@@ -327,7 +317,7 @@ class RealConnection internal constructor(
val idleDurationNs = withLock { nowNs - idleAtNs }
if (idleDurationNs >= IDLE_CONNECTION_HEALTHY_NS && doExtensiveChecks) {
return socket.isHealthy(source)
return javaNetSocket.isHealthy(socket.source)
}
return true
@@ -452,33 +442,10 @@ class RealConnection internal constructor(
socket: JavaNetSocket,
idleAtNs: Long,
): RealConnection {
val okioSocket =
object : okio.Socket {
override val sink: Sink =
object : Sink {
override fun close() = Unit
override fun flush() = Unit
override fun timeout(): Timeout = Timeout.NONE
override fun write(
source: Buffer,
byteCount: Long,
): Unit = throw UnsupportedOperationException()
}
override val source: Source =
object : Source {
override fun close() = Unit
override fun read(
sink: Buffer,
byteCount: Long,
): Long = throw UnsupportedOperationException()
override fun timeout(): Timeout = Timeout.NONE
}
val bufferedSocket =
object : BufferedSocket {
override val sink = Buffer()
override val source = Buffer()
override fun cancel() {
}
@@ -490,12 +457,10 @@ class RealConnection internal constructor(
connectionPool = connectionPool,
route = route,
rawSocket = JavaNetSocket(),
socket = socket,
javaNetSocket = socket,
handshake = null,
protocol = Protocol.HTTP_2,
okioSocket = okioSocket,
source = okioSocket.source.buffer(),
sink = okioSocket.sink.buffer(),
socket = bufferedSocket,
pingIntervalMillis = 0,
connectionListener = ConnectionListener.NONE,
)

View File

@@ -26,6 +26,7 @@ import okhttp3.OkHttpClient
import okhttp3.Request
import okhttp3.Response
import okhttp3.internal.checkOffsetAndCount
import okhttp3.internal.connection.BufferedSocket
import okhttp3.internal.discard
import okhttp3.internal.headersContentLength
import okhttp3.internal.http.ExchangeCodec
@@ -37,11 +38,8 @@ import okhttp3.internal.http.receiveHeaders
import okhttp3.internal.http1.Http1ExchangeCodec.Companion.TRAILERS_RESPONSE_BODY_TRUNCATED
import okhttp3.internal.skipAll
import okio.Buffer
import okio.BufferedSink
import okio.BufferedSource
import okio.ForwardingTimeout
import okio.Sink
import okio.Socket
import okio.Source
import okio.Timeout
@@ -66,12 +64,10 @@ class Http1ExchangeCodec(
/** The client that configures this stream. May be null for HTTPS proxy tunnels. */
private val client: OkHttpClient?,
override val carrier: ExchangeCodec.Carrier,
override val socket: Socket,
private val source: BufferedSource,
private val sink: BufferedSink,
override val socket: BufferedSocket,
) : ExchangeCodec {
private var state = STATE_IDLE
private val headersReader = HeadersReader(source)
private val headersReader = HeadersReader(socket.source)
private val Response.isChunked: Boolean
get() = "chunked".equals(header("Transfer-Encoding"), ignoreCase = true)
@@ -161,11 +157,11 @@ class Http1ExchangeCodec(
}
override fun flushRequest() {
sink.flush()
socket.sink.flush()
}
override fun finishRequest() {
sink.flush()
socket.sink.flush()
}
/** Returns bytes of a request header for sending on an HTTP transport. */
@@ -174,15 +170,15 @@ class Http1ExchangeCodec(
requestLine: String,
) {
check(state == STATE_IDLE) { "state: $state" }
sink.writeUtf8(requestLine).writeUtf8("\r\n")
socket.sink.writeUtf8(requestLine).writeUtf8("\r\n")
for (i in 0 until headers.size) {
sink
socket.sink
.writeUtf8(headers.name(i))
.writeUtf8(": ")
.writeUtf8(headers.value(i))
.writeUtf8("\r\n")
}
sink.writeUtf8("\r\n")
socket.sink.writeUtf8("\r\n")
state = STATE_OPEN_REQUEST_BODY
}
@@ -295,7 +291,7 @@ class Http1ExchangeCodec(
/** An HTTP request body. */
private inner class KnownLengthSink : Sink {
private val timeout = ForwardingTimeout(sink.timeout())
private val timeout = ForwardingTimeout(socket.sink.timeout())
private var closed: Boolean = false
override fun timeout(): Timeout = timeout
@@ -306,12 +302,12 @@ class Http1ExchangeCodec(
) {
check(!closed) { "closed" }
checkOffsetAndCount(source.size, 0, byteCount)
sink.write(source, byteCount)
socket.sink.write(source, byteCount)
}
override fun flush() {
if (closed) return // Don't throw; this stream might have been closed on the caller's behalf.
sink.flush()
socket.sink.flush()
}
override fun close() {
@@ -327,7 +323,7 @@ class Http1ExchangeCodec(
* to buffer chunks; typically by using a buffered sink with this sink.
*/
private inner class ChunkedSink : Sink {
private val timeout = ForwardingTimeout(sink.timeout())
private val timeout = ForwardingTimeout(socket.sink.timeout())
private var closed: Boolean = false
override fun timeout(): Timeout = timeout
@@ -339,23 +335,25 @@ class Http1ExchangeCodec(
check(!closed) { "closed" }
if (byteCount == 0L) return
sink.writeHexadecimalUnsignedLong(byteCount)
sink.writeUtf8("\r\n")
sink.write(source, byteCount)
sink.writeUtf8("\r\n")
with(socket.sink) {
writeHexadecimalUnsignedLong(byteCount)
writeUtf8("\r\n")
write(source, byteCount)
writeUtf8("\r\n")
}
}
@Synchronized
override fun flush() {
if (closed) return // Don't throw; this stream might have been closed on the caller's behalf.
sink.flush()
socket.sink.flush()
}
@Synchronized
override fun close() {
if (closed) return
closed = true
sink.writeUtf8("0\r\n\r\n")
socket.sink.writeUtf8("0\r\n\r\n")
detachTimeout(timeout)
state = STATE_READ_RESPONSE_HEADERS
}
@@ -364,7 +362,7 @@ class Http1ExchangeCodec(
private abstract inner class AbstractSource(
val url: HttpUrl,
) : Source {
protected val timeout = ForwardingTimeout(source.timeout())
protected val timeout = ForwardingTimeout(socket.source.timeout())
protected var closed: Boolean = false
override fun timeout(): Timeout = timeout
@@ -374,7 +372,7 @@ class Http1ExchangeCodec(
byteCount: Long,
): Long =
try {
source.read(sink, byteCount)
socket.source.read(sink, byteCount)
} catch (e: IOException) {
carrier.noNewExchanges()
responseBodyComplete(TRAILERS_RESPONSE_BODY_TRUNCATED)
@@ -481,11 +479,11 @@ class Http1ExchangeCodec(
private fun readChunkSize() {
// Read the suffix of the previous chunk.
if (bytesRemainingInChunk != NO_CHUNK_YET) {
source.readUtf8LineStrict()
socket.source.readUtf8LineStrict()
}
try {
bytesRemainingInChunk = source.readHexadecimalUnsignedLong()
val extensions = source.readUtf8LineStrict().trim()
bytesRemainingInChunk = socket.source.readHexadecimalUnsignedLong()
val extensions = socket.source.readUtf8LineStrict().trim()
if (bytesRemainingInChunk < 0L || extensions.isNotEmpty() && !extensions.startsWith(";")) {
throw ProtocolException(
"expected chunk size and optional extensions" +

View File

@@ -18,7 +18,6 @@ package okhttp3.internal.http2
import java.io.Closeable
import java.io.IOException
import java.io.InterruptedIOException
import java.net.Socket
import java.util.concurrent.TimeUnit
import okhttp3.Headers
import okhttp3.internal.EMPTY_BYTE_ARRAY
@@ -29,22 +28,18 @@ import okhttp3.internal.concurrent.assertLockNotHeld
import okhttp3.internal.concurrent.notifyAll
import okhttp3.internal.concurrent.wait
import okhttp3.internal.concurrent.withLock
import okhttp3.internal.connection.BufferedSocket
import okhttp3.internal.http2.ErrorCode.REFUSED_STREAM
import okhttp3.internal.http2.Settings.Companion.DEFAULT_INITIAL_WINDOW_SIZE
import okhttp3.internal.http2.flowcontrol.WindowCounter
import okhttp3.internal.ignoreIoExceptions
import okhttp3.internal.okHttpName
import okhttp3.internal.peerName
import okhttp3.internal.platform.Platform
import okhttp3.internal.platform.Platform.Companion.INFO
import okhttp3.internal.toHeaders
import okio.Buffer
import okio.BufferedSink
import okio.BufferedSource
import okio.ByteString
import okio.buffer
import okio.sink
import okio.source
/**
* A socket connection to a remote peer. A connection hosts streams which can send and receive
@@ -140,11 +135,11 @@ class Http2Connection internal constructor(
var writeBytesMaximum: Long = peerSettings.initialWindowSize.toLong()
private set
internal val socket: Socket = builder.socket
val writer = Http2Writer(builder.sink, client)
internal val socket: BufferedSocket = builder.socket
val writer = Http2Writer(socket.sink, client)
// Visible for testing
val readerRunnable = ReaderRunnable(Http2Reader(builder.source, client))
val readerRunnable = ReaderRunnable(Http2Reader(socket.source, client))
// Guarded by this.
private val currentPushRequests = mutableSetOf<Int>()
@@ -479,9 +474,9 @@ class Http2Connection internal constructor(
writer.close()
}
// Close the socket to break out the reader thread, which will clean up after itself.
// Cancel the socket to break out the reader thread, which will clean up after itself.
ignoreIoExceptions {
socket.close()
socket.cancel()
}
// Release the threads.
@@ -574,22 +569,17 @@ class Http2Connection internal constructor(
internal var client: Boolean,
internal val taskRunner: TaskRunner,
) {
internal lateinit var socket: Socket
internal lateinit var socket: BufferedSocket
internal lateinit var connectionName: String
internal lateinit var source: BufferedSource
internal lateinit var sink: BufferedSink
internal var listener = Listener.REFUSE_INCOMING_STREAMS
internal var pushObserver = PushObserver.CANCEL
internal var pingIntervalMillis: Int = 0
internal var flowControlListener: FlowControlListener = FlowControlListener.None
@Throws(IOException::class)
@JvmOverloads
fun socket(
socket: Socket,
peerName: String = socket.peerName(),
source: BufferedSource = socket.source().buffer(),
sink: BufferedSink = socket.sink().buffer(),
socket: BufferedSocket,
peerName: String,
) = apply {
this.socket = socket
this.connectionName =
@@ -597,8 +587,6 @@ class Http2Connection internal constructor(
client -> "$okHttpName $peerName"
else -> "MockWebServer $peerName"
}
this.source = source
this.sink = sink
}
fun listener(listener: Listener) =
@@ -800,7 +788,7 @@ class Http2Connection internal constructor(
}
}
if (streamsToNotify != null) {
for (stream in streamsToNotify!!) {
for (stream in streamsToNotify) {
stream.withLock {
stream.addBytesToWriteWindow(delta)
}

View File

@@ -36,7 +36,9 @@ import okhttp3.internal.concurrent.Lockable
import okhttp3.internal.concurrent.Task
import okhttp3.internal.concurrent.TaskRunner
import okhttp3.internal.concurrent.assertLockHeld
import okhttp3.internal.connection.BufferedSocket
import okhttp3.internal.connection.RealCall
import okhttp3.internal.connection.asBufferedSocket
import okhttp3.internal.okHttpName
import okhttp3.internal.ws.WebSocketProtocol.CLOSE_CLIENT_GOING_AWAY
import okhttp3.internal.ws.WebSocketProtocol.CLOSE_MESSAGE_MAX
@@ -49,7 +51,6 @@ import okio.ByteString
import okio.ByteString.Companion.encodeUtf8
import okio.ByteString.Companion.toByteString
import okio.Socket
import okio.buffer
class RealWebSocket(
taskRunner: TaskRunner,
@@ -197,9 +198,7 @@ class RealWebSocket(
val name = "$okHttpName WebSocket ${request.url.redact()}"
initReaderAndWriter(
name = name,
socket = socket,
socketSource = socket.source.buffer(),
socketSink = socket.sink.buffer(),
socket = socket.asBufferedSocket(),
client = true,
)
loopReader(response)
@@ -269,9 +268,7 @@ class RealWebSocket(
*/
fun initReaderAndWriter(
name: String,
socket: Socket,
socketSource: BufferedSource,
socketSink: BufferedSink,
socket: BufferedSocket,
client: Boolean,
) {
val extensions = this.extensions!!
@@ -281,7 +278,7 @@ class RealWebSocket(
this.writer =
WebSocketWriter(
isClient = client,
sink = socketSink,
sink = socket.sink,
random = random,
perMessageDeflate = extensions.perMessageDeflate,
noContextTakeover = extensions.noContextTakeover(client),
@@ -303,7 +300,7 @@ class RealWebSocket(
reader =
WebSocketReader(
isClient = client,
source = socketSource,
source = socket.source,
frameCallback = this,
perMessageDeflate = extensions.perMessageDeflate,
noContextTakeover = extensions.noContextTakeover(!client),

View File

@@ -204,7 +204,7 @@ class WebSocketWriter(
}
sinkBuffer.write(messageBuffer, dataSize)
sink.emit()
sink.flush()
}
override fun close() {

View File

@@ -2823,7 +2823,9 @@ open class CallTest {
call.enqueue(callback)
call.cancel()
latch.countDown()
callback.await(server.url("/a")).assertFailure("Canceled", "Socket closed", "Socket is closed")
callback
.await(server.url("/a"))
.assertFailure("canceled", "Canceled", "Socket closed", "Socket is closed")
}
@Test
@@ -2831,7 +2833,9 @@ open class CallTest {
val call = client.newCall(Request(server.url("/")))
call.enqueue(callback)
client.dispatcher.cancelAll()
callback.await(server.url("/")).assertFailure("Canceled", "Socket closed", "Socket is closed")
callback
.await(server.url("/"))
.assertFailure("canceled", "Canceled", "Socket closed", "Socket is closed")
}
@Test
@@ -2942,7 +2946,9 @@ open class CallTest {
assertThat(server.takeRequest().url.encodedPath).isEqualTo("/a")
callback.await(requestA.url).assertBody("A")
// At this point we know the callback is ready, and that it will receive a cancel failure.
callback.await(requestB.url).assertFailure("Canceled", "Socket closed")
callback
.await(requestB.url)
.assertFailure("canceled", "Canceled", "Socket closed", "Socket is closed")
}
@Test
@@ -4715,7 +4721,7 @@ open class CallTest {
.build()
executeSynchronously("/")
.assertFailure(IOException::class.java)
.assertFailure("Socket closed", "Socket is closed")
.assertFailure("canceled", "Canceled", "Socket closed", "Socket is closed")
}
@Test
@@ -4805,7 +4811,8 @@ open class CallTest {
.build()
},
).build()
executeSynchronously("/").assertFailure("Canceled")
executeSynchronously("/")
.assertFailure("canceled", "Canceled", "Socket closed", "Socket is closed")
assertThat(closed.get()).isTrue()
}

View File

@@ -316,7 +316,7 @@ class ConnectionPoolTest {
val connection =
Http2Connection
.Builder(true, TaskRunner.INSTANCE)
.socket(peer.openSocket())
.socket(peer.openSocket().asBufferedSocket(), "peer")
.pushObserver(Http2ConnectionTest.IGNORE)
.listener(realConnection)
.build()

View File

@@ -82,12 +82,12 @@ class HttpUpgradesTest {
socket.sink.buffer().use { sink ->
socket.source.buffer().use { source ->
sink.writeUtf8("client says hello\n")
sink.emit()
sink.flush()
assertThat(source.readUtf8Line()).isEqualTo("server says hello")
sink.writeUtf8("client says goodbye\n")
sink.emit()
sink.flush()
assertThat(source.readUtf8Line()).isEqualTo("server says goodbye")

View File

@@ -43,6 +43,7 @@ import okhttp3.internal.concurrent.TaskRunner
import okhttp3.internal.concurrent.notifyAll
import okhttp3.internal.concurrent.wait
import okhttp3.internal.concurrent.withLock
import okhttp3.internal.connection.asBufferedSocket
import okio.AsyncTimeout
import okio.Buffer
import okio.BufferedSource
@@ -497,7 +498,7 @@ class Http2ConnectionTest {
val connection =
Http2Connection
.Builder(true, TaskRunner.INSTANCE)
.socket(socket)
.socket(socket.asBufferedSocket(), "peer")
.pushObserver(IGNORE)
.build()
connection.start(sendConnectionPreface = false)
@@ -1890,7 +1891,7 @@ class Http2ConnectionTest {
val connection =
Http2Connection
.Builder(true, TaskRunner.INSTANCE)
.socket(peer.openSocket())
.socket(peer.openSocket().asBufferedSocket(), "peer")
.build()
connection.start(sendConnectionPreface = false)
val stream = connection.newStream(headerEntries("b", "banana"), false)
@@ -1912,11 +1913,10 @@ class Http2ConnectionTest {
peer.acceptFrame() // SYN_STREAM.
peer.play()
val taskRunner = taskFaker.taskRunner
val socket = peer.openSocket()
val connection =
Http2Connection
.Builder(true, taskRunner)
.socket(socket)
.socket(peer.openSocket().asBufferedSocket(), "peer")
.pushObserver(IGNORE)
.build()
connection.start(sendConnectionPreface = false)
@@ -1972,7 +1972,7 @@ class Http2ConnectionTest {
val connection =
Http2Connection
.Builder(true, TaskRunner.INSTANCE)
.socket(peer.openSocket())
.socket(peer.openSocket().asBufferedSocket(), "peer")
.pushObserver(pushObserver)
.listener(listener)
.build()

View File

@@ -35,6 +35,7 @@ import okhttp3.Request
import okhttp3.Response
import okhttp3.TestUtil.repeat
import okhttp3.internal.concurrent.TaskFaker
import okhttp3.internal.connection.BufferedSocket
import okhttp3.internal.ws.WebSocketExtensions.Companion.parse
import okio.BufferedSink
import okio.BufferedSource
@@ -466,7 +467,7 @@ class RealWebSocketTest {
private val taskFaker: TaskFaker,
private val delegate: Socket,
private val client: Boolean,
) : Socket {
) : BufferedSocket {
private val name = if (client) "client" else "server"
val listener = WebSocketRecorder(name)
var webSocket: RealWebSocket? = null
@@ -530,7 +531,7 @@ class RealWebSocketTest {
}
}
}
webSocket!!.initReaderAndWriter(name, this, source, sink, client)
webSocket!!.initReaderAndWriter(name, this, client)
}
/**