1
0
mirror of https://github.com/square/okhttp.git synced 2025-11-24 18:41:06 +03:00

Make all TaskRunner tasks cancelable

This probably should have been the case all along. Unfortunately,
ExecutorService Runnables are not cancelable by default, and that's
where we started.

After implementing all of TaskRunner it looks like where we're
cancelable and where we aren't is totally arbitrary. Making everything
cancelable simplifies the implementation and model.

The last remaining non-cancelable tasks:
 * awaitIdle() which we use in our tests only.
 * MockWebServer, where canceling would leak sockets
This commit is contained in:
Jesse Wilson
2019-10-06 17:38:37 -04:00
parent 4096e375c0
commit 0a691dcadb
8 changed files with 58 additions and 71 deletions

View File

@@ -389,7 +389,7 @@ class MockWebServer : ExternalResource(), Closeable {
portField = serverSocket!!.localPort portField = serverSocket!!.localPort
taskRunner.newQueue().execute("MockWebServer $portField") { taskRunner.newQueue().execute("MockWebServer $portField", cancelable = false) {
try { try {
logger.info("${this@MockWebServer} starting to accept connections") logger.info("${this@MockWebServer} starting to accept connections")
acceptConnections() acceptConnections()
@@ -464,7 +464,7 @@ class MockWebServer : ExternalResource(), Closeable {
} }
private fun serveConnection(raw: Socket) { private fun serveConnection(raw: Socket) {
taskRunner.newQueue().execute("MockWebServer ${raw.remoteSocketAddress}") { taskRunner.newQueue().execute("MockWebServer ${raw.remoteSocketAddress}", cancelable = false) {
try { try {
SocketHandler(raw).handle() SocketHandler(raw).handle()
} catch (e: IOException) { } catch (e: IOException) {

View File

@@ -168,7 +168,7 @@ class DiskLruCache internal constructor(
private var nextSequenceNumber: Long = 0 private var nextSequenceNumber: Long = 0
private val cleanupQueue = taskRunner.newQueue() private val cleanupQueue = taskRunner.newQueue()
private val cleanupTask = object : Task("OkHttp Cache", cancelable = false) { private val cleanupTask = object : Task("OkHttp Cache") {
override fun runOnce(): Long { override fun runOnce(): Long {
synchronized(this@DiskLruCache) { synchronized(this@DiskLruCache) {
if (!initialized || closed) { if (!initialized || closed) {

View File

@@ -48,7 +48,7 @@ package okhttp3.internal.concurrent
*/ */
abstract class Task( abstract class Task(
val name: String, val name: String,
val cancelable: Boolean val cancelable: Boolean = true
) { ) {
// Guarded by the TaskRunner. // Guarded by the TaskRunner.
internal var queue: TaskQueue? = null internal var queue: TaskQueue? = null

View File

@@ -54,11 +54,14 @@ class TaskQueue internal constructor(
* is running when that time is reached, that task is allowed to complete before this task is * is running when that time is reached, that task is allowed to complete before this task is
* started. Similarly the task will be delayed if the host lacks compute resources. * started. Similarly the task will be delayed if the host lacks compute resources.
* *
* @throws RejectedExecutionException if the queue is shut down. * @throws RejectedExecutionException if the queue is shut down and the task is not cancelable.
*/ */
fun schedule(task: Task, delayNanos: Long = 0L) { fun schedule(task: Task, delayNanos: Long = 0L) {
synchronized(taskRunner) { synchronized(taskRunner) {
if (shutdown) throw RejectedExecutionException() if (shutdown) {
if (task.cancelable) return
throw RejectedExecutionException()
}
if (scheduleAndDecide(task, delayNanos)) { if (scheduleAndDecide(task, delayNanos)) {
taskRunner.kickCoordinator(this) taskRunner.kickCoordinator(this)
@@ -70,25 +73,13 @@ class TaskQueue internal constructor(
inline fun schedule( inline fun schedule(
name: String, name: String,
delayNanos: Long = 0L, delayNanos: Long = 0L,
cancelable: Boolean = true,
crossinline block: () -> Long crossinline block: () -> Long
) { ) {
schedule(object : Task(name, cancelable) { schedule(object : Task(name) {
override fun runOnce() = block() override fun runOnce() = block()
}, delayNanos) }, delayNanos)
} }
/** Like [schedule], but this silently discard the task if the queue is shut down. */
fun trySchedule(task: Task, delayNanos: Long = 0L) {
synchronized(taskRunner) {
if (shutdown) return
if (scheduleAndDecide(task, delayNanos)) {
taskRunner.kickCoordinator(this)
}
}
}
/** Executes [block] once on a task runner thread. */ /** Executes [block] once on a task runner thread. */
inline fun execute( inline fun execute(
name: String, name: String,
@@ -104,21 +95,6 @@ class TaskQueue internal constructor(
}, delayNanos) }, delayNanos)
} }
/** Like [execute], but this silently discard the task if the queue is shut down. */
inline fun tryExecute(
name: String,
delayNanos: Long = 0L,
cancelable: Boolean = true,
crossinline block: () -> Unit
) {
trySchedule(object : Task(name, cancelable) {
override fun runOnce(): Long {
block()
return -1L
}
}, delayNanos)
}
/** Returns true if this queue became idle before the timeout elapsed. */ /** Returns true if this queue became idle before the timeout elapsed. */
fun awaitIdle(delayNanos: Long): Boolean { fun awaitIdle(delayNanos: Long): Boolean {
val latch = CountDownLatch(1) val latch = CountDownLatch(1)

View File

@@ -40,7 +40,7 @@ class RealConnectionPool(
private val keepAliveDurationNs: Long = timeUnit.toNanos(keepAliveDuration) private val keepAliveDurationNs: Long = timeUnit.toNanos(keepAliveDuration)
private val cleanupQueue: TaskQueue = taskRunner.newQueue() private val cleanupQueue: TaskQueue = taskRunner.newQueue()
private val cleanupTask = object : Task("OkHttp ConnectionPool", cancelable = true) { private val cleanupTask = object : Task("OkHttp ConnectionPool") {
override fun runOnce() = cleanup(System.nanoTime()) override fun runOnce() = cleanup(System.nanoTime())
} }

View File

@@ -315,7 +315,7 @@ class Http2Connection internal constructor(builder: Builder) : Closeable {
streamId: Int, streamId: Int,
errorCode: ErrorCode errorCode: ErrorCode
) { ) {
writerQueue.tryExecute("$connectionName[$streamId] writeSynReset") { writerQueue.execute("$connectionName[$streamId] writeSynReset") {
try { try {
writeSynReset(streamId, errorCode) writeSynReset(streamId, errorCode)
} catch (e: IOException) { } catch (e: IOException) {
@@ -336,7 +336,7 @@ class Http2Connection internal constructor(builder: Builder) : Closeable {
streamId: Int, streamId: Int,
unacknowledgedBytesRead: Long unacknowledgedBytesRead: Long
) { ) {
writerQueue.tryExecute("$connectionName[$streamId] windowUpdate") { writerQueue.execute("$connectionName[$streamId] windowUpdate") {
try { try {
writer.windowUpdate(streamId, unacknowledgedBytesRead) writer.windowUpdate(streamId, unacknowledgedBytesRead)
} catch (e: IOException) { } catch (e: IOException) {
@@ -625,8 +625,7 @@ class Http2Connection internal constructor(builder: Builder) : Closeable {
streams[streamId] = newStream streams[streamId] = newStream
// Use a different task queue for each stream because they should be handled in parallel. // Use a different task queue for each stream because they should be handled in parallel.
val taskName = "$connectionName[$streamId] onStream" taskRunner.newQueue().execute("$connectionName[$streamId] onStream") {
taskRunner.newQueue().execute(taskName, cancelable = false) {
try { try {
listener.onStream(newStream) listener.onStream(newStream)
} catch (e: IOException) { } catch (e: IOException) {
@@ -654,7 +653,7 @@ class Http2Connection internal constructor(builder: Builder) : Closeable {
} }
override fun settings(clearPrevious: Boolean, settings: Settings) { override fun settings(clearPrevious: Boolean, settings: Settings) {
writerQueue.tryExecute("$connectionName applyAndAckSettings") { writerQueue.execute("$connectionName applyAndAckSettings") {
applyAndAckSettings(clearPrevious, settings) applyAndAckSettings(clearPrevious, settings)
} }
} }
@@ -697,7 +696,7 @@ class Http2Connection internal constructor(builder: Builder) : Closeable {
peerSettings = newPeerSettings peerSettings = newPeerSettings
settingsListenerQueue.tryExecute("$connectionName onSettings", cancelable = false) { settingsListenerQueue.execute("$connectionName onSettings") {
listener.onSettings(this@Http2Connection, newPeerSettings) listener.onSettings(this@Http2Connection, newPeerSettings)
} }
} }
@@ -732,7 +731,7 @@ class Http2Connection internal constructor(builder: Builder) : Closeable {
} }
} else { } else {
// Send a reply to a client ping if this is a server and vice versa. // Send a reply to a client ping if this is a server and vice versa.
writerQueue.tryExecute("$connectionName ping") { writerQueue.execute("$connectionName ping") {
writePing(true, payload1, payload2) writePing(true, payload1, payload2)
} }
} }
@@ -819,7 +818,7 @@ class Http2Connection internal constructor(builder: Builder) : Closeable {
} }
currentPushRequests.add(streamId) currentPushRequests.add(streamId)
} }
pushQueue.tryExecute("$connectionName[$streamId] onRequest") { pushQueue.execute("$connectionName[$streamId] onRequest") {
val cancel = pushObserver.onRequest(streamId, requestHeaders) val cancel = pushObserver.onRequest(streamId, requestHeaders)
ignoreIoExceptions { ignoreIoExceptions {
if (cancel) { if (cancel) {
@@ -837,7 +836,7 @@ class Http2Connection internal constructor(builder: Builder) : Closeable {
requestHeaders: List<Header>, requestHeaders: List<Header>,
inFinished: Boolean inFinished: Boolean
) { ) {
pushQueue.tryExecute("$connectionName[$streamId] onHeaders") { pushQueue.execute("$connectionName[$streamId] onHeaders") {
val cancel = pushObserver.onHeaders(streamId, requestHeaders, inFinished) val cancel = pushObserver.onHeaders(streamId, requestHeaders, inFinished)
ignoreIoExceptions { ignoreIoExceptions {
if (cancel) writer.rstStream(streamId, ErrorCode.CANCEL) if (cancel) writer.rstStream(streamId, ErrorCode.CANCEL)
@@ -864,7 +863,7 @@ class Http2Connection internal constructor(builder: Builder) : Closeable {
val buffer = Buffer() val buffer = Buffer()
source.require(byteCount.toLong()) // Eagerly read the frame before firing client thread. source.require(byteCount.toLong()) // Eagerly read the frame before firing client thread.
source.read(buffer, byteCount.toLong()) source.read(buffer, byteCount.toLong())
pushQueue.tryExecute("$connectionName[$streamId] onData") { pushQueue.execute("$connectionName[$streamId] onData") {
ignoreIoExceptions { ignoreIoExceptions {
val cancel = pushObserver.onData(streamId, buffer, byteCount, inFinished) val cancel = pushObserver.onData(streamId, buffer, byteCount, inFinished)
if (cancel) writer.rstStream(streamId, ErrorCode.CANCEL) if (cancel) writer.rstStream(streamId, ErrorCode.CANCEL)
@@ -878,7 +877,7 @@ class Http2Connection internal constructor(builder: Builder) : Closeable {
} }
internal fun pushResetLater(streamId: Int, errorCode: ErrorCode) { internal fun pushResetLater(streamId: Int, errorCode: ErrorCode) {
pushQueue.tryExecute("$connectionName[$streamId] onReset", cancelable = false) { pushQueue.execute("$connectionName[$streamId] onReset") {
pushObserver.onReset(streamId, errorCode) pushObserver.onReset(streamId, errorCode)
synchronized(this@Http2Connection) { synchronized(this@Http2Connection) {
currentPushRequests.remove(streamId) currentPushRequests.remove(streamId)

View File

@@ -393,7 +393,7 @@ class RealWebSocket(
private fun runWriter() { private fun runWriter() {
assert(Thread.holdsLock(this)) assert(Thread.holdsLock(this))
taskQueue.trySchedule(writerTask!!) taskQueue.schedule(writerTask!!)
} }
/** /**
@@ -535,7 +535,7 @@ class RealWebSocket(
val sink: BufferedSink val sink: BufferedSink
) : Closeable ) : Closeable
private inner class WriterTask : Task("$name writer", cancelable = true) { private inner class WriterTask : Task("$name writer") {
override fun runOnce(): Long { override fun runOnce(): Long {
try { try {
if (writeOneFrame()) return 0L if (writeOneFrame()) return 0L

View File

@@ -72,7 +72,7 @@ class TaskRunnerTest {
/** Repeat with a delay of 200 but schedule with a delay of 50. The schedule wins. */ /** Repeat with a delay of 200 but schedule with a delay of 50. The schedule wins. */
@Test fun executeScheduledEarlierReplacesRepeatedLater() { @Test fun executeScheduledEarlierReplacesRepeatedLater() {
val task = object : Task("task", cancelable = true) { val task = object : Task("task") {
val schedules = mutableListOf(50L) val schedules = mutableListOf(50L)
val delays = mutableListOf(200L, -1L) val delays = mutableListOf(200L, -1L)
override fun runOnce(): Long { override fun runOnce(): Long {
@@ -99,7 +99,7 @@ class TaskRunnerTest {
/** Schedule with a delay of 200 but repeat with a delay of 50. The repeat wins. */ /** Schedule with a delay of 200 but repeat with a delay of 50. The repeat wins. */
@Test fun executeRepeatedEarlierReplacesScheduledLater() { @Test fun executeRepeatedEarlierReplacesScheduledLater() {
val task = object : Task("task", cancelable = true) { val task = object : Task("task") {
val schedules = mutableListOf(200L) val schedules = mutableListOf(200L)
val delays = mutableListOf(50L, -1L) val delays = mutableListOf(50L, -1L)
override fun runOnce(): Long { override fun runOnce(): Long {
@@ -141,9 +141,12 @@ class TaskRunnerTest {
} }
@Test fun cancelReturnsFalseDoesNotCancel() { @Test fun cancelReturnsFalseDoesNotCancel() {
redQueue.execute("task", 100L, cancelable = false) { redQueue.schedule(object : Task("task", cancelable = false) {
log += "run@${taskFaker.nanoTime}" override fun runOnce(): Long {
} log += "run@${taskFaker.nanoTime}"
return -1L
}
}, 100L)
taskFaker.advanceUntil(0L) taskFaker.advanceUntil(0L)
assertThat(log).isEmpty() assertThat(log).isEmpty()
@@ -191,12 +194,14 @@ class TaskRunnerTest {
} }
@Test fun cancelWhileExecutingDoesNotStopUncancelableTask() { @Test fun cancelWhileExecutingDoesNotStopUncancelableTask() {
val delays = mutableListOf(50L, -1L) redQueue.schedule(object : Task("task", cancelable = false) {
redQueue.schedule("task", 100L, cancelable = false) { val delays = mutableListOf(50L, -1L)
log += "run@${taskFaker.nanoTime}" override fun runOnce(): Long {
redQueue.cancelAll() log += "run@${taskFaker.nanoTime}"
return@schedule delays.removeAt(0) redQueue.cancelAll()
} return delays.removeAt(0)
}
}, 100L)
taskFaker.advanceUntil(0L) taskFaker.advanceUntil(0L)
assertThat(log).isEmpty() assertThat(log).isEmpty()
@@ -227,9 +232,12 @@ class TaskRunnerTest {
} }
@Test fun interruptingCoordinatorAttemptsToCancelsAndFails() { @Test fun interruptingCoordinatorAttemptsToCancelsAndFails() {
redQueue.execute("task", 100L, cancelable = false) { redQueue.schedule(object : Task("task", cancelable = false) {
log += "run@${taskFaker.nanoTime}" override fun runOnce(): Long {
} log += "run@${taskFaker.nanoTime}"
return -1L
}
}, 100L)
taskFaker.advanceUntil(0L) taskFaker.advanceUntil(0L)
assertThat(log).isEmpty() assertThat(log).isEmpty()
@@ -317,7 +325,7 @@ class TaskRunnerTest {
* cumbersome to implement properly because the active task might be a cancel. * cumbersome to implement properly because the active task might be a cancel.
*/ */
@Test fun scheduledTasksDoesNotIncludeRunningTask() { @Test fun scheduledTasksDoesNotIncludeRunningTask() {
val task = object : Task("task one", cancelable = true) { val task = object : Task("task one") {
val schedules = mutableListOf(200L) val schedules = mutableListOf(200L)
override fun runOnce(): Long { override fun runOnce(): Long {
if (schedules.isNotEmpty()) { if (schedules.isNotEmpty()) {
@@ -408,10 +416,12 @@ class TaskRunnerTest {
} }
@Test fun shutdownFailsToCancelsScheduledTasks() { @Test fun shutdownFailsToCancelsScheduledTasks() {
redQueue.schedule("task", 100L, cancelable = false) { redQueue.schedule(object : Task("task", false) {
log += "run@${taskFaker.nanoTime}" override fun runOnce(): Long {
return@schedule 50L log += "run@${taskFaker.nanoTime}"
} return 50L
}
}, 100L)
taskFaker.advanceUntil(0L) taskFaker.advanceUntil(0L)
assertThat(log).isEmpty() assertThat(log).isEmpty()
@@ -430,7 +440,7 @@ class TaskRunnerTest {
@Test fun scheduleDiscardsTaskWhenShutdown() { @Test fun scheduleDiscardsTaskWhenShutdown() {
redQueue.shutdown() redQueue.shutdown()
redQueue.tryExecute("task", 100L) { redQueue.execute("task", 100L) {
// Do nothing. // Do nothing.
} }
@@ -441,9 +451,11 @@ class TaskRunnerTest {
redQueue.shutdown() redQueue.shutdown()
try { try {
redQueue.execute("task", 100L) { redQueue.schedule(object : Task("task", cancelable = false) {
// Do nothing. override fun runOnce(): Long {
} return -1L
}
}, 100L)
fail() fail()
} catch (_: RejectedExecutionException) { } catch (_: RejectedExecutionException) {
} }