1
0
mirror of https://github.com/square/okhttp.git synced 2026-01-14 07:22:20 +03:00

Merge pull request #5502 from square/jwilson.0926.http2

TaskRunner support for shutting down queues
This commit is contained in:
Jesse Wilson
2019-09-27 17:23:08 -04:00
committed by GitHub
11 changed files with 242 additions and 111 deletions

View File

@@ -25,6 +25,7 @@ import okhttp3.Request
import okhttp3.Response import okhttp3.Response
import okhttp3.internal.addHeaderLenient import okhttp3.internal.addHeaderLenient
import okhttp3.internal.closeQuietly import okhttp3.internal.closeQuietly
import okhttp3.internal.concurrent.TaskRunner
import okhttp3.internal.duplex.MwsDuplexAccess import okhttp3.internal.duplex.MwsDuplexAccess
import okhttp3.internal.execute import okhttp3.internal.execute
import okhttp3.internal.http.HttpMethod import okhttp3.internal.http.HttpMethod
@@ -99,6 +100,7 @@ import javax.net.ssl.X509TrustManager
* in sequence. * in sequence.
*/ */
class MockWebServer : ExternalResource(), Closeable { class MockWebServer : ExternalResource(), Closeable {
private val taskRunner = TaskRunner()
private val requestQueue = LinkedBlockingQueue<RecordedRequest>() private val requestQueue = LinkedBlockingQueue<RecordedRequest>()
private val openClientSockets = private val openClientSockets =
Collections.newSetFromMap(ConcurrentHashMap<Socket, Boolean>()) Collections.newSetFromMap(ConcurrentHashMap<Socket, Boolean>())
@@ -454,6 +456,12 @@ class MockWebServer : ExternalResource(), Closeable {
} catch (e: InterruptedException) { } catch (e: InterruptedException) {
throw AssertionError() throw AssertionError()
} }
for (queue in taskRunner.activeQueues()) {
if (!queue.awaitIdle(TimeUnit.MILLISECONDS.toNanos(500L))) {
throw IOException("Gave up waiting for ${queue.owner} to shut down")
}
}
} }
@Synchronized override fun after() { @Synchronized override fun after() {
@@ -533,7 +541,7 @@ class MockWebServer : ExternalResource(), Closeable {
if (protocol === Protocol.HTTP_2 || protocol === Protocol.H2_PRIOR_KNOWLEDGE) { if (protocol === Protocol.HTTP_2 || protocol === Protocol.H2_PRIOR_KNOWLEDGE) {
val http2SocketHandler = Http2SocketHandler(socket, protocol) val http2SocketHandler = Http2SocketHandler(socket, protocol)
val connection = Http2Connection.Builder(false) val connection = Http2Connection.Builder(false, taskRunner)
.socket(socket) .socket(socket)
.listener(http2SocketHandler) .listener(http2SocketHandler)
.build() .build()

View File

@@ -29,6 +29,7 @@ import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory; import javax.net.ssl.SSLSocketFactory;
import okhttp3.Headers; import okhttp3.Headers;
import okhttp3.Protocol; import okhttp3.Protocol;
import okhttp3.internal.concurrent.TaskRunner;
import okhttp3.internal.http2.Header; import okhttp3.internal.http2.Header;
import okhttp3.internal.http2.Http2Connection; import okhttp3.internal.http2.Http2Connection;
import okhttp3.internal.http2.Http2Stream; import okhttp3.internal.http2.Http2Stream;
@@ -69,7 +70,7 @@ public final class Http2Server extends Http2Connection.Listener {
if (protocol != Protocol.HTTP_2) { if (protocol != Protocol.HTTP_2) {
throw new ProtocolException("Protocol " + protocol + " unsupported"); throw new ProtocolException("Protocol " + protocol + " unsupported");
} }
Http2Connection connection = new Http2Connection.Builder(false) Http2Connection connection = new Http2Connection.Builder(false, TaskRunner.INSTANCE)
.socket(sslSocket) .socket(sslSocket)
.listener(this) .listener(this)
.build(); .build();

View File

@@ -15,8 +15,6 @@
*/ */
package okhttp3 package okhttp3
import okhttp3.internal.concurrent.Task
import okhttp3.internal.concurrent.TaskQueue
import okhttp3.internal.concurrent.TaskRunner import okhttp3.internal.concurrent.TaskRunner
import okhttp3.testing.Flaky import okhttp3.testing.Flaky
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
@@ -25,7 +23,6 @@ import org.junit.runner.Description
import org.junit.runners.model.Statement import org.junit.runners.model.Statement
import java.net.InetAddress import java.net.InetAddress
import java.util.concurrent.ConcurrentLinkedDeque import java.util.concurrent.ConcurrentLinkedDeque
import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
/** Apply this rule to tests that need an OkHttpClient instance. */ /** Apply this rule to tests that need an OkHttpClient instance. */
@@ -66,7 +63,7 @@ class OkHttpClientTestRule : TestRule {
private fun ensureAllTaskQueuesIdle() { private fun ensureAllTaskQueuesIdle() {
for (queue in TaskRunner.INSTANCE.activeQueues()) { for (queue in TaskRunner.INSTANCE.activeQueues()) {
assertThat(queue.awaitIdle(500L, TimeUnit.MILLISECONDS)) assertThat(queue.awaitIdle(TimeUnit.MILLISECONDS.toNanos(500L)))
.withFailMessage("Queue ${queue.owner} still active after 500ms") .withFailMessage("Queue ${queue.owner} still active after 500ms")
.isTrue() .isTrue()
} }
@@ -133,19 +130,6 @@ class OkHttpClientTestRule : TestRule {
} }
} }
/** Returns true if this queue became idle before the timeout elapsed. */
private fun TaskQueue.awaitIdle(timeout: Long, timeUnit: TimeUnit): Boolean {
val latch = CountDownLatch(1)
schedule(object : Task("awaitIdle") {
override fun runOnce(): Long {
latch.countDown()
return -1L
}
})
return latch.await(timeout, timeUnit)
}
companion object { companion object {
/** /**
* Quick and dirty pool of OkHttpClient instances. Each has its own independent dispatcher and * Quick and dirty pool of OkHttpClient instances. Each has its own independent dispatcher and

View File

@@ -49,7 +49,6 @@ import java.util.LinkedHashMap
import java.util.Locale import java.util.Locale
import java.util.TimeZone import java.util.TimeZone
import java.util.concurrent.Executor import java.util.concurrent.Executor
import java.util.concurrent.RejectedExecutionException
import java.util.concurrent.ThreadFactory import java.util.concurrent.ThreadFactory
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import kotlin.text.Charsets.UTF_32BE import kotlin.text.Charsets.UTF_32BE
@@ -393,14 +392,6 @@ inline fun Executor.execute(name: String, crossinline block: () -> Unit) {
} }
} }
/** Executes [block] unless this executor has been shutdown, in which case this does nothing. */
inline fun Executor.tryExecute(name: String, crossinline block: () -> Unit) {
try {
execute(name, block)
} catch (_: RejectedExecutionException) {
}
}
fun Buffer.skipAll(b: Byte): Int { fun Buffer.skipAll(b: Byte): Int {
var count = 0 var count = 0
while (!exhausted() && this[0] == b) { while (!exhausted() && this[0] == b) {
@@ -510,34 +501,9 @@ fun Long.toHexString(): String = java.lang.Long.toHexString(this)
fun Int.toHexString(): String = Integer.toHexString(this) fun Int.toHexString(): String = Integer.toHexString(this)
/**
* Lock and wait a duration in nanoseconds. Unlike [java.lang.Object.wait] this interprets 0 as
* "don't wait" instead of "wait forever".
*/
@Throws(InterruptedException::class)
fun Any.lockAndWaitNanos(nanos: Long) {
synchronized(this) {
objectWaitNanos(nanos)
}
}
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "NOTHING_TO_INLINE") @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "NOTHING_TO_INLINE")
inline fun Any.wait() = (this as Object).wait() inline fun Any.wait() = (this as Object).wait()
/**
* Wait a duration in nanoseconds. Unlike [java.lang.Object.wait] this interprets 0 as "don't wait"
* instead of "wait forever".
*/
@Throws(InterruptedException::class)
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
fun Any.objectWaitNanos(nanos: Long) {
val ms = nanos / 1_000_000L
val ns = nanos - (ms * 1_000_000L)
if (ms > 0L || nanos > 0) {
(this as Object).wait(ms, ns.toInt())
}
}
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "NOTHING_TO_INLINE") @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "NOTHING_TO_INLINE")
inline fun Any.notify() = (this as Object).notify() inline fun Any.notify() = (this as Object).notify()

View File

@@ -16,6 +16,9 @@
package okhttp3.internal.concurrent package okhttp3.internal.concurrent
import okhttp3.internal.addIfAbsent import okhttp3.internal.addIfAbsent
import java.util.concurrent.CountDownLatch
import java.util.concurrent.RejectedExecutionException
import java.util.concurrent.TimeUnit
/** /**
* A set of tasks that are executed in sequential order. * A set of tasks that are executed in sequential order.
@@ -32,6 +35,8 @@ class TaskQueue internal constructor(
*/ */
val owner: Any val owner: Any
) { ) {
private var shutdown = false
/** This queue's currently-executing task, or null if none is currently executing. */ /** This queue's currently-executing task, or null if none is currently executing. */
private var activeTask: Task? = null private var activeTask: Task? = null
@@ -61,19 +66,55 @@ class TaskQueue internal constructor(
* The target execution time is implemented on a best-effort basis. If another task in this queue * The target execution time is implemented on a best-effort basis. If another task in this queue
* is running when that time is reached, that task is allowed to complete before this task is * is running when that time is reached, that task is allowed to complete before this task is
* started. Similarly the task will be delayed if the host lacks compute resources. * started. Similarly the task will be delayed if the host lacks compute resources.
*
* @throws RejectedExecutionException if the queue is shut down.
*/ */
fun schedule(task: Task, delayNanos: Long = 0L) { fun schedule(task: Task, delayNanos: Long = 0L) {
task.initQueue(this)
synchronized(taskRunner) { synchronized(taskRunner) {
if (shutdown) throw RejectedExecutionException()
if (scheduleAndDecide(task, delayNanos)) { if (scheduleAndDecide(task, delayNanos)) {
taskRunner.kickCoordinator(this) taskRunner.kickCoordinator(this)
} }
} }
} }
/** Like [schedule], but this silently discard the task if the queue is shut down. */
fun trySchedule(task: Task, delayNanos: Long = 0L) {
synchronized(taskRunner) {
if (shutdown) return
if (scheduleAndDecide(task, delayNanos)) {
taskRunner.kickCoordinator(this)
}
}
}
/** Returns true if this queue became idle before the timeout elapsed. */
fun awaitIdle(delayNanos: Long): Boolean {
val latch = CountDownLatch(1)
val task = object : Task("awaitIdle") {
override fun runOnce(): Long {
latch.countDown()
return -1L
}
}
// Don't delegate to schedule because that has to honor shutdown rules.
synchronized(taskRunner) {
if (scheduleAndDecide(task, 0L)) {
taskRunner.kickCoordinator(this)
}
}
return latch.await(delayNanos, TimeUnit.NANOSECONDS)
}
/** Adds [task] to run in [delayNanos]. Returns true if the coordinator should run. */ /** Adds [task] to run in [delayNanos]. Returns true if the coordinator should run. */
private fun scheduleAndDecide(task: Task, delayNanos: Long): Boolean { private fun scheduleAndDecide(task: Task, delayNanos: Long): Boolean {
task.initQueue(this)
val now = taskRunner.backend.nanoTime() val now = taskRunner.backend.nanoTime()
val executeNanoTime = now + delayNanos val executeNanoTime = now + delayNanos
@@ -100,6 +141,8 @@ class TaskQueue internal constructor(
* be removed from the execution schedule. * be removed from the execution schedule.
*/ */
fun cancelAll() { fun cancelAll() {
check(!Thread.holdsLock(this))
synchronized(taskRunner) { synchronized(taskRunner) {
if (cancelAllAndDecide()) { if (cancelAllAndDecide()) {
taskRunner.kickCoordinator(this) taskRunner.kickCoordinator(this)
@@ -107,6 +150,17 @@ class TaskQueue internal constructor(
} }
} }
fun shutdown() {
check(!Thread.holdsLock(this))
synchronized(taskRunner) {
shutdown = true
if (cancelAllAndDecide()) {
taskRunner.kickCoordinator(this)
}
}
}
/** Returns true if the coordinator should run. */ /** Returns true if the coordinator should run. */
private fun cancelAllAndDecide(): Boolean { private fun cancelAllAndDecide(): Boolean {
val runningTask = activeTask val runningTask = activeTask
@@ -160,7 +214,7 @@ class TaskQueue internal constructor(
synchronized(taskRunner) { synchronized(taskRunner) {
check(activeTask === task) check(activeTask === task)
if (delayNanos != -1L) { if (delayNanos != -1L && !shutdown) {
scheduleAndDecide(task, delayNanos) scheduleAndDecide(task, delayNanos)
} else if (!futureTasks.contains(task)) { } else if (!futureTasks.contains(task)) {
cancelTasks.remove(task) // We don't need to cancel it because it isn't scheduled. cancelTasks.remove(task) // We don't need to cancel it because it isn't scheduled.

View File

@@ -17,7 +17,6 @@ package okhttp3.internal.concurrent
import okhttp3.internal.addIfAbsent import okhttp3.internal.addIfAbsent
import okhttp3.internal.notify import okhttp3.internal.notify
import okhttp3.internal.objectWaitNanos
import okhttp3.internal.threadFactory import okhttp3.internal.threadFactory
import java.util.concurrent.LinkedBlockingQueue import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.SynchronousQueue import java.util.concurrent.SynchronousQueue
@@ -158,8 +157,18 @@ class TaskRunner(
taskRunner.notify() taskRunner.notify()
} }
/**
* Wait a duration in nanoseconds. Unlike [java.lang.Object.wait] this interprets 0 as
* "don't wait" instead of "wait forever".
*/
@Throws(InterruptedException::class)
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
override fun coordinatorWait(taskRunner: TaskRunner, nanos: Long) { override fun coordinatorWait(taskRunner: TaskRunner, nanos: Long) {
taskRunner.objectWaitNanos(nanos) val ms = nanos / 1_000_000L
val ns = nanos - (ms * 1_000_000L)
if (ms > 0L || nanos > 0) {
(taskRunner as Object).wait(ms, ns.toInt())
}
} }
fun shutdown() { fun shutdown() {

View File

@@ -33,6 +33,7 @@ import okhttp3.Response
import okhttp3.Route import okhttp3.Route
import okhttp3.internal.EMPTY_RESPONSE import okhttp3.internal.EMPTY_RESPONSE
import okhttp3.internal.closeQuietly import okhttp3.internal.closeQuietly
import okhttp3.internal.concurrent.TaskRunner
import okhttp3.internal.http.ExchangeCodec import okhttp3.internal.http.ExchangeCodec
import okhttp3.internal.http1.Http1ExchangeCodec import okhttp3.internal.http1.Http1ExchangeCodec
import okhttp3.internal.http2.ConnectionShutdownException import okhttp3.internal.http2.ConnectionShutdownException
@@ -321,7 +322,7 @@ class RealConnection(
val source = this.source!! val source = this.source!!
val sink = this.sink!! val sink = this.sink!!
socket.soTimeout = 0 // HTTP/2 connection timeouts are set per-stream. socket.soTimeout = 0 // HTTP/2 connection timeouts are set per-stream.
val http2Connection = Http2Connection.Builder(true) val http2Connection = Http2Connection.Builder(client = true, taskRunner = TaskRunner.INSTANCE)
.socket(socket, route.address.url.host, source, sink) .socket(socket, route.address.url.host, source, sink)
.listener(this) .listener(this)
.pingIntervalMillis(pingIntervalMillis) .pingIntervalMillis(pingIntervalMillis)

View File

@@ -18,9 +18,10 @@ package okhttp3.internal.http2
import okhttp3.internal.EMPTY_BYTE_ARRAY import okhttp3.internal.EMPTY_BYTE_ARRAY
import okhttp3.internal.EMPTY_HEADERS import okhttp3.internal.EMPTY_HEADERS
import okhttp3.internal.closeQuietly import okhttp3.internal.closeQuietly
import okhttp3.internal.concurrent.Task
import okhttp3.internal.concurrent.TaskRunner
import okhttp3.internal.connectionName import okhttp3.internal.connectionName
import okhttp3.internal.execute import okhttp3.internal.execute
import okhttp3.internal.format
import okhttp3.internal.http2.ErrorCode.REFUSED_STREAM import okhttp3.internal.http2.ErrorCode.REFUSED_STREAM
import okhttp3.internal.http2.Settings.Companion.DEFAULT_INITIAL_WINDOW_SIZE import okhttp3.internal.http2.Settings.Companion.DEFAULT_INITIAL_WINDOW_SIZE
import okhttp3.internal.ignoreIoExceptions import okhttp3.internal.ignoreIoExceptions
@@ -28,9 +29,7 @@ import okhttp3.internal.notifyAll
import okhttp3.internal.platform.Platform import okhttp3.internal.platform.Platform
import okhttp3.internal.platform.Platform.Companion.INFO import okhttp3.internal.platform.Platform.Companion.INFO
import okhttp3.internal.threadFactory import okhttp3.internal.threadFactory
import okhttp3.internal.threadName
import okhttp3.internal.toHeaders import okhttp3.internal.toHeaders
import okhttp3.internal.tryExecute
import okhttp3.internal.wait import okhttp3.internal.wait
import okio.Buffer import okio.Buffer
import okio.BufferedSink import okio.BufferedSink
@@ -43,12 +42,9 @@ import java.io.Closeable
import java.io.IOException import java.io.IOException
import java.io.InterruptedIOException import java.io.InterruptedIOException
import java.net.Socket import java.net.Socket
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.ScheduledThreadPoolExecutor
import java.util.concurrent.SynchronousQueue import java.util.concurrent.SynchronousQueue
import java.util.concurrent.ThreadPoolExecutor import java.util.concurrent.ThreadPoolExecutor
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import java.util.concurrent.TimeUnit.MILLISECONDS
/** /**
* A socket connection to a remote peer. A connection hosts streams which can send and receive * A socket connection to a remote peer. A connection hosts streams which can send and receive
@@ -91,13 +87,10 @@ class Http2Connection internal constructor(builder: Builder) : Closeable {
internal set internal set
/** Asynchronously writes frames to the outgoing socket. */ /** Asynchronously writes frames to the outgoing socket. */
private val writerExecutor = ScheduledThreadPoolExecutor(1, private val writerQueue = builder.taskRunner.newQueue("$connectionName Writer")
threadFactory(format("OkHttp %s Writer", connectionName), false))
/** Ensures push promise callbacks events are sent in order per stream. */ /** Ensures push promise callbacks events are sent in order per stream. */
// Like newSingleThreadExecutor, except lazy creates the thread. private val pushQueue = builder.taskRunner.newQueue("$connectionName Push")
private val pushExecutor = ThreadPoolExecutor(0, 1, 60L, TimeUnit.SECONDS, LinkedBlockingQueue(),
threadFactory(format("OkHttp %s Push Observer", connectionName), true))
/** User code to run in response to push promise events. */ /** User code to run in response to push promise events. */
private val pushObserver: PushObserver = builder.pushObserver private val pushObserver: PushObserver = builder.pushObserver
@@ -149,11 +142,15 @@ class Http2Connection internal constructor(builder: Builder) : Closeable {
init { init {
if (builder.pingIntervalMillis != 0) { if (builder.pingIntervalMillis != 0) {
writerExecutor.scheduleAtFixedRate({ val pingIntervalNanos = TimeUnit.MILLISECONDS.toNanos(builder.pingIntervalMillis.toLong())
threadName("OkHttp $connectionName ping") { writerQueue.schedule(object : Task("OkHttp $connectionName ping") {
override fun runOnce(): Long {
writePing(false, 0, 0) writePing(false, 0, 0)
return pingIntervalNanos
} }
}, builder.pingIntervalMillis.toLong(), builder.pingIntervalMillis.toLong(), MILLISECONDS)
override fun tryCancel() = true
}, pingIntervalNanos)
} }
} }
@@ -328,13 +325,18 @@ class Http2Connection internal constructor(builder: Builder) : Closeable {
streamId: Int, streamId: Int,
errorCode: ErrorCode errorCode: ErrorCode
) { ) {
writerExecutor.tryExecute("OkHttp $connectionName stream $streamId") { writerQueue.trySchedule(object : Task("OkHttp $connectionName stream $streamId") {
try { override fun runOnce(): Long {
writeSynReset(streamId, errorCode) try {
} catch (e: IOException) { writeSynReset(streamId, errorCode)
failConnection(e) } catch (e: IOException) {
failConnection(e)
}
return -1L
} }
}
override fun tryCancel() = true
})
} }
@Throws(IOException::class) @Throws(IOException::class)
@@ -349,13 +351,18 @@ class Http2Connection internal constructor(builder: Builder) : Closeable {
streamId: Int, streamId: Int,
unacknowledgedBytesRead: Long unacknowledgedBytesRead: Long
) { ) {
writerExecutor.tryExecute("OkHttp Window Update $connectionName stream $streamId") { writerQueue.trySchedule(object : Task("OkHttp Window Update $connectionName stream $streamId") {
try { override fun runOnce(): Long {
writer.windowUpdate(streamId, unacknowledgedBytesRead) try {
} catch (e: IOException) { writer.windowUpdate(streamId, unacknowledgedBytesRead)
failConnection(e) } catch (e: IOException) {
failConnection(e)
}
return -1L
} }
}
override fun tryCancel() = true
})
} }
fun writePing( fun writePing(
@@ -467,8 +474,8 @@ class Http2Connection internal constructor(builder: Builder) : Closeable {
} }
// Release the threads. // Release the threads.
writerExecutor.shutdown() writerQueue.shutdown()
pushExecutor.shutdown() pushQueue.shutdown()
} }
private fun failConnection(e: IOException?) { private fun failConnection(e: IOException?) {
@@ -511,7 +518,8 @@ class Http2Connection internal constructor(builder: Builder) : Closeable {
class Builder( class Builder(
/** True if this peer initiated the connection; false if this peer accepted the connection. */ /** True if this peer initiated the connection; false if this peer accepted the connection. */
internal var client: Boolean internal var client: Boolean,
internal val taskRunner: TaskRunner
) { ) {
internal lateinit var socket: Socket internal lateinit var socket: Socket
internal lateinit var connectionName: String internal lateinit var connectionName: String
@@ -659,9 +667,14 @@ class Http2Connection internal constructor(builder: Builder) : Closeable {
} }
override fun settings(clearPrevious: Boolean, settings: Settings) { override fun settings(clearPrevious: Boolean, settings: Settings) {
writerExecutor.tryExecute("OkHttp $connectionName ACK Settings") { writerQueue.trySchedule(object : Task("OkHttp $connectionName ACK Settings") {
applyAndAckSettings(clearPrevious, settings) override fun runOnce(): Long {
} applyAndAckSettings(clearPrevious, settings)
return -1L
}
override fun tryCancel() = true
})
} }
/** /**
@@ -725,9 +738,14 @@ class Http2Connection internal constructor(builder: Builder) : Closeable {
} }
} else { } else {
// Send a reply to a client ping if this is a server and vice versa. // Send a reply to a client ping if this is a server and vice versa.
writerExecutor.tryExecute("OkHttp $connectionName ping") { writerQueue.trySchedule(object : Task("OkHttp $connectionName ping") {
writePing(true, payload1, payload2) override fun runOnce(): Long {
} writePing(true, payload1, payload2)
return -1L
}
override fun tryCancel() = true
})
} }
} }
@@ -812,8 +830,8 @@ class Http2Connection internal constructor(builder: Builder) : Closeable {
} }
currentPushRequests.add(streamId) currentPushRequests.add(streamId)
} }
if (!isShutdown) { pushQueue.trySchedule(object : Task("OkHttp $connectionName Push Request[$streamId]") {
pushExecutor.tryExecute("OkHttp $connectionName Push Request[$streamId]") { override fun runOnce(): Long {
val cancel = pushObserver.onRequest(streamId, requestHeaders) val cancel = pushObserver.onRequest(streamId, requestHeaders)
ignoreIoExceptions { ignoreIoExceptions {
if (cancel) { if (cancel) {
@@ -823,8 +841,11 @@ class Http2Connection internal constructor(builder: Builder) : Closeable {
} }
} }
} }
return -1L
} }
}
override fun tryCancel() = true
})
} }
internal fun pushHeadersLater( internal fun pushHeadersLater(
@@ -832,8 +853,8 @@ class Http2Connection internal constructor(builder: Builder) : Closeable {
requestHeaders: List<Header>, requestHeaders: List<Header>,
inFinished: Boolean inFinished: Boolean
) { ) {
if (!isShutdown) { pushQueue.trySchedule(object : Task("OkHttp $connectionName Push Headers[$streamId]") {
pushExecutor.tryExecute("OkHttp $connectionName Push Headers[$streamId]") { override fun runOnce(): Long {
val cancel = pushObserver.onHeaders(streamId, requestHeaders, inFinished) val cancel = pushObserver.onHeaders(streamId, requestHeaders, inFinished)
ignoreIoExceptions { ignoreIoExceptions {
if (cancel) writer.rstStream(streamId, ErrorCode.CANCEL) if (cancel) writer.rstStream(streamId, ErrorCode.CANCEL)
@@ -843,8 +864,11 @@ class Http2Connection internal constructor(builder: Builder) : Closeable {
} }
} }
} }
return -1L
} }
}
override fun tryCancel() = true
})
} }
/** /**
@@ -861,8 +885,8 @@ class Http2Connection internal constructor(builder: Builder) : Closeable {
val buffer = Buffer() val buffer = Buffer()
source.require(byteCount.toLong()) // Eagerly read the frame before firing client thread. source.require(byteCount.toLong()) // Eagerly read the frame before firing client thread.
source.read(buffer, byteCount.toLong()) source.read(buffer, byteCount.toLong())
if (!isShutdown) { pushQueue.trySchedule(object : Task("OkHttp $connectionName Push Data[$streamId]") {
pushExecutor.execute("OkHttp $connectionName Push Data[$streamId]") { override fun runOnce(): Long {
ignoreIoExceptions { ignoreIoExceptions {
val cancel = pushObserver.onData(streamId, buffer, byteCount, inFinished) val cancel = pushObserver.onData(streamId, buffer, byteCount, inFinished)
if (cancel) writer.rstStream(streamId, ErrorCode.CANCEL) if (cancel) writer.rstStream(streamId, ErrorCode.CANCEL)
@@ -872,19 +896,23 @@ class Http2Connection internal constructor(builder: Builder) : Closeable {
} }
} }
} }
return -1L
} }
}
override fun tryCancel() = true
})
} }
internal fun pushResetLater(streamId: Int, errorCode: ErrorCode) { internal fun pushResetLater(streamId: Int, errorCode: ErrorCode) {
if (!isShutdown) { pushQueue.trySchedule(object : Task("OkHttp $connectionName Push Reset[$streamId]") {
pushExecutor.execute("OkHttp $connectionName Push Reset[$streamId]") { override fun runOnce(): Long {
pushObserver.onReset(streamId, errorCode) pushObserver.onReset(streamId, errorCode)
synchronized(this@Http2Connection) { synchronized(this@Http2Connection) {
currentPushRequests.remove(streamId) currentPushRequests.remove(streamId)
} }
return -1L
} }
} })
} }
/** Listener of streams and settings initiated by the peer. */ /** Listener of streams and settings initiated by the peer. */

View File

@@ -40,7 +40,7 @@ class TaskFaker {
/** How many tasks can be executed immediately. */ /** How many tasks can be executed immediately. */
val tasksSize: Int get() = tasks.size val tasksSize: Int get() = tasks.size
/** Guarded by taskRunner. */ /** Guarded by [taskRunner]. */
var nanoTime = 0L var nanoTime = 0L
private set private set
@@ -143,7 +143,9 @@ class TaskFaker {
fun assertNoMoreTasks() { fun assertNoMoreTasks() {
assertThat(coordinatorToRun).isNull() assertThat(coordinatorToRun).isNull()
assertThat(tasks).isEmpty() assertThat(tasks).isEmpty()
assertThat(coordinatorWaitingUntilTime).isEqualTo(Long.MAX_VALUE) assertThat(coordinatorWaitingUntilTime)
.withFailMessage("tasks are scheduled to run at $coordinatorWaitingUntilTime")
.isEqualTo(Long.MAX_VALUE)
} }
fun interruptCoordinatorThread() { fun interruptCoordinatorThread() {

View File

@@ -16,7 +16,9 @@
package okhttp3.internal.concurrent package okhttp3.internal.concurrent
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.junit.Assert.fail
import org.junit.Test import org.junit.Test
import java.util.concurrent.RejectedExecutionException
class TaskRunnerTest { class TaskRunnerTest {
private val taskFaker = TaskFaker() private val taskFaker = TaskFaker()
@@ -470,4 +472,79 @@ class TaskRunnerTest {
taskFaker.assertNoMoreTasks() taskFaker.assertNoMoreTasks()
} }
@Test fun shutdownSuccessfullyCancelsScheduledTasks() {
redQueue.schedule(object : Task("task") {
override fun runOnce(): Long {
log += "run@${taskFaker.nanoTime}"
return -1L
}
override fun tryCancel(): Boolean {
log += "cancel@${taskFaker.nanoTime}"
return true
}
}, 100L)
taskFaker.advanceUntil(0L)
assertThat(log).isEmpty()
redQueue.shutdown()
taskFaker.advanceUntil(99L)
assertThat(log).containsExactly("cancel@99")
taskFaker.assertNoMoreTasks()
}
@Test fun shutdownFailsToCancelsScheduledTasks() {
redQueue.schedule(object : Task("task") {
override fun runOnce(): Long {
log += "run@${taskFaker.nanoTime}"
return 50L
}
override fun tryCancel(): Boolean {
log += "cancel@${taskFaker.nanoTime}"
return false
}
}, 100L)
taskFaker.advanceUntil(0L)
assertThat(log).isEmpty()
redQueue.shutdown()
taskFaker.advanceUntil(99L)
assertThat(log).containsExactly("cancel@99")
taskFaker.advanceUntil(100L)
assertThat(log).containsExactly("cancel@99", "run@100")
taskFaker.assertNoMoreTasks()
}
@Test fun scheduleDiscardsTaskWhenShutdown() {
redQueue.shutdown()
redQueue.trySchedule(object : Task("task") {
override fun runOnce() = -1L
}, 100L)
taskFaker.assertNoMoreTasks()
}
@Test fun scheduleThrowsWhenShutdown() {
redQueue.shutdown()
try {
redQueue.schedule(object : Task("task") {
override fun runOnce() = -1L
}, 100L)
fail()
} catch (_: RejectedExecutionException) {
}
taskFaker.assertNoMoreTasks()
}
} }

View File

@@ -27,6 +27,7 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import okhttp3.Headers; import okhttp3.Headers;
import okhttp3.internal.Util; import okhttp3.internal.Util;
import okhttp3.internal.concurrent.TaskRunner;
import okhttp3.internal.http2.MockHttp2Peer.InFrame; import okhttp3.internal.http2.MockHttp2Peer.InFrame;
import okio.AsyncTimeout; import okio.AsyncTimeout;
import okio.Buffer; import okio.Buffer;
@@ -498,7 +499,7 @@ public final class Http2ConnectionTest {
String longString = repeat('a', Http2.INITIAL_MAX_FRAME_SIZE + 1); String longString = repeat('a', Http2.INITIAL_MAX_FRAME_SIZE + 1);
Socket socket = peer.openSocket(); Socket socket = peer.openSocket();
Http2Connection connection = new Http2Connection.Builder(true) Http2Connection connection = new Http2Connection.Builder(true, TaskRunner.INSTANCE)
.socket(socket) .socket(socket)
.pushObserver(IGNORE) .pushObserver(IGNORE)
.build(); .build();
@@ -1786,7 +1787,7 @@ public final class Http2ConnectionTest {
peer.acceptFrame(); // GOAWAY peer.acceptFrame(); // GOAWAY
peer.play(); peer.play();
Http2Connection connection = new Http2Connection.Builder(true) Http2Connection connection = new Http2Connection.Builder(true, TaskRunner.INSTANCE)
.socket(peer.openSocket()) .socket(peer.openSocket())
.build(); .build();
connection.start(false); connection.start(false);
@@ -1857,7 +1858,7 @@ public final class Http2ConnectionTest {
/** Builds a new connection to {@code peer} with settings acked. */ /** Builds a new connection to {@code peer} with settings acked. */
private Http2Connection connect(MockHttp2Peer peer, PushObserver pushObserver, private Http2Connection connect(MockHttp2Peer peer, PushObserver pushObserver,
Http2Connection.Listener listener) throws Exception { Http2Connection.Listener listener) throws Exception {
Http2Connection connection = new Http2Connection.Builder(true) Http2Connection connection = new Http2Connection.Builder(true, TaskRunner.INSTANCE)
.socket(peer.openSocket()) .socket(peer.openSocket())
.pushObserver(pushObserver) .pushObserver(pushObserver)
.listener(listener) .listener(listener)