From 0a691dcadb04b2202a0744fd1e609fd12a922341 Mon Sep 17 00:00:00 2001 From: Jesse Wilson Date: Sun, 6 Oct 2019 17:38:37 -0400 Subject: [PATCH] Make all TaskRunner tasks cancelable This probably should have been the case all along. Unfortunately, ExecutorService Runnables are not cancelable by default, and that's where we started. After implementing all of TaskRunner it looks like where we're cancelable and where we aren't is totally arbitrary. Making everything cancelable simplifies the implementation and model. The last remaining non-cancelable tasks: * awaitIdle() which we use in our tests only. * MockWebServer, where canceling would leak sockets --- .../okhttp3/mockwebserver/MockWebServer.kt | 4 +- .../okhttp3/internal/cache/DiskLruCache.kt | 2 +- .../java/okhttp3/internal/concurrent/Task.kt | 2 +- .../okhttp3/internal/concurrent/TaskQueue.kt | 36 ++---------- .../internal/connection/RealConnectionPool.kt | 2 +- .../okhttp3/internal/http2/Http2Connection.kt | 21 ++++--- .../java/okhttp3/internal/ws/RealWebSocket.kt | 4 +- .../internal/concurrent/TaskRunnerTest.kt | 58 +++++++++++-------- 8 files changed, 58 insertions(+), 71 deletions(-) diff --git a/mockwebserver/src/main/java/okhttp3/mockwebserver/MockWebServer.kt b/mockwebserver/src/main/java/okhttp3/mockwebserver/MockWebServer.kt index e75b15e6d..f6782c11c 100644 --- a/mockwebserver/src/main/java/okhttp3/mockwebserver/MockWebServer.kt +++ b/mockwebserver/src/main/java/okhttp3/mockwebserver/MockWebServer.kt @@ -389,7 +389,7 @@ class MockWebServer : ExternalResource(), Closeable { portField = serverSocket!!.localPort - taskRunner.newQueue().execute("MockWebServer $portField") { + taskRunner.newQueue().execute("MockWebServer $portField", cancelable = false) { try { logger.info("${this@MockWebServer} starting to accept connections") acceptConnections() @@ -464,7 +464,7 @@ class MockWebServer : ExternalResource(), Closeable { } private fun serveConnection(raw: Socket) { - taskRunner.newQueue().execute("MockWebServer ${raw.remoteSocketAddress}") { + taskRunner.newQueue().execute("MockWebServer ${raw.remoteSocketAddress}", cancelable = false) { try { SocketHandler(raw).handle() } catch (e: IOException) { diff --git a/okhttp/src/main/java/okhttp3/internal/cache/DiskLruCache.kt b/okhttp/src/main/java/okhttp3/internal/cache/DiskLruCache.kt index 113d58e48..477b0cfa7 100644 --- a/okhttp/src/main/java/okhttp3/internal/cache/DiskLruCache.kt +++ b/okhttp/src/main/java/okhttp3/internal/cache/DiskLruCache.kt @@ -168,7 +168,7 @@ class DiskLruCache internal constructor( private var nextSequenceNumber: Long = 0 private val cleanupQueue = taskRunner.newQueue() - private val cleanupTask = object : Task("OkHttp Cache", cancelable = false) { + private val cleanupTask = object : Task("OkHttp Cache") { override fun runOnce(): Long { synchronized(this@DiskLruCache) { if (!initialized || closed) { diff --git a/okhttp/src/main/java/okhttp3/internal/concurrent/Task.kt b/okhttp/src/main/java/okhttp3/internal/concurrent/Task.kt index 9bae6e693..61cef201a 100644 --- a/okhttp/src/main/java/okhttp3/internal/concurrent/Task.kt +++ b/okhttp/src/main/java/okhttp3/internal/concurrent/Task.kt @@ -48,7 +48,7 @@ package okhttp3.internal.concurrent */ abstract class Task( val name: String, - val cancelable: Boolean + val cancelable: Boolean = true ) { // Guarded by the TaskRunner. internal var queue: TaskQueue? = null diff --git a/okhttp/src/main/java/okhttp3/internal/concurrent/TaskQueue.kt b/okhttp/src/main/java/okhttp3/internal/concurrent/TaskQueue.kt index b533922f5..13ed3c802 100644 --- a/okhttp/src/main/java/okhttp3/internal/concurrent/TaskQueue.kt +++ b/okhttp/src/main/java/okhttp3/internal/concurrent/TaskQueue.kt @@ -54,11 +54,14 @@ class TaskQueue internal constructor( * is running when that time is reached, that task is allowed to complete before this task is * started. Similarly the task will be delayed if the host lacks compute resources. * - * @throws RejectedExecutionException if the queue is shut down. + * @throws RejectedExecutionException if the queue is shut down and the task is not cancelable. */ fun schedule(task: Task, delayNanos: Long = 0L) { synchronized(taskRunner) { - if (shutdown) throw RejectedExecutionException() + if (shutdown) { + if (task.cancelable) return + throw RejectedExecutionException() + } if (scheduleAndDecide(task, delayNanos)) { taskRunner.kickCoordinator(this) @@ -70,25 +73,13 @@ class TaskQueue internal constructor( inline fun schedule( name: String, delayNanos: Long = 0L, - cancelable: Boolean = true, crossinline block: () -> Long ) { - schedule(object : Task(name, cancelable) { + schedule(object : Task(name) { override fun runOnce() = block() }, delayNanos) } - /** Like [schedule], but this silently discard the task if the queue is shut down. */ - fun trySchedule(task: Task, delayNanos: Long = 0L) { - synchronized(taskRunner) { - if (shutdown) return - - if (scheduleAndDecide(task, delayNanos)) { - taskRunner.kickCoordinator(this) - } - } - } - /** Executes [block] once on a task runner thread. */ inline fun execute( name: String, @@ -104,21 +95,6 @@ class TaskQueue internal constructor( }, delayNanos) } - /** Like [execute], but this silently discard the task if the queue is shut down. */ - inline fun tryExecute( - name: String, - delayNanos: Long = 0L, - cancelable: Boolean = true, - crossinline block: () -> Unit - ) { - trySchedule(object : Task(name, cancelable) { - override fun runOnce(): Long { - block() - return -1L - } - }, delayNanos) - } - /** Returns true if this queue became idle before the timeout elapsed. */ fun awaitIdle(delayNanos: Long): Boolean { val latch = CountDownLatch(1) diff --git a/okhttp/src/main/java/okhttp3/internal/connection/RealConnectionPool.kt b/okhttp/src/main/java/okhttp3/internal/connection/RealConnectionPool.kt index e7477c829..34bda48ab 100644 --- a/okhttp/src/main/java/okhttp3/internal/connection/RealConnectionPool.kt +++ b/okhttp/src/main/java/okhttp3/internal/connection/RealConnectionPool.kt @@ -40,7 +40,7 @@ class RealConnectionPool( private val keepAliveDurationNs: Long = timeUnit.toNanos(keepAliveDuration) private val cleanupQueue: TaskQueue = taskRunner.newQueue() - private val cleanupTask = object : Task("OkHttp ConnectionPool", cancelable = true) { + private val cleanupTask = object : Task("OkHttp ConnectionPool") { override fun runOnce() = cleanup(System.nanoTime()) } diff --git a/okhttp/src/main/java/okhttp3/internal/http2/Http2Connection.kt b/okhttp/src/main/java/okhttp3/internal/http2/Http2Connection.kt index b4918d130..7b6a40272 100644 --- a/okhttp/src/main/java/okhttp3/internal/http2/Http2Connection.kt +++ b/okhttp/src/main/java/okhttp3/internal/http2/Http2Connection.kt @@ -315,7 +315,7 @@ class Http2Connection internal constructor(builder: Builder) : Closeable { streamId: Int, errorCode: ErrorCode ) { - writerQueue.tryExecute("$connectionName[$streamId] writeSynReset") { + writerQueue.execute("$connectionName[$streamId] writeSynReset") { try { writeSynReset(streamId, errorCode) } catch (e: IOException) { @@ -336,7 +336,7 @@ class Http2Connection internal constructor(builder: Builder) : Closeable { streamId: Int, unacknowledgedBytesRead: Long ) { - writerQueue.tryExecute("$connectionName[$streamId] windowUpdate") { + writerQueue.execute("$connectionName[$streamId] windowUpdate") { try { writer.windowUpdate(streamId, unacknowledgedBytesRead) } catch (e: IOException) { @@ -625,8 +625,7 @@ class Http2Connection internal constructor(builder: Builder) : Closeable { streams[streamId] = newStream // Use a different task queue for each stream because they should be handled in parallel. - val taskName = "$connectionName[$streamId] onStream" - taskRunner.newQueue().execute(taskName, cancelable = false) { + taskRunner.newQueue().execute("$connectionName[$streamId] onStream") { try { listener.onStream(newStream) } catch (e: IOException) { @@ -654,7 +653,7 @@ class Http2Connection internal constructor(builder: Builder) : Closeable { } override fun settings(clearPrevious: Boolean, settings: Settings) { - writerQueue.tryExecute("$connectionName applyAndAckSettings") { + writerQueue.execute("$connectionName applyAndAckSettings") { applyAndAckSettings(clearPrevious, settings) } } @@ -697,7 +696,7 @@ class Http2Connection internal constructor(builder: Builder) : Closeable { peerSettings = newPeerSettings - settingsListenerQueue.tryExecute("$connectionName onSettings", cancelable = false) { + settingsListenerQueue.execute("$connectionName onSettings") { listener.onSettings(this@Http2Connection, newPeerSettings) } } @@ -732,7 +731,7 @@ class Http2Connection internal constructor(builder: Builder) : Closeable { } } else { // Send a reply to a client ping if this is a server and vice versa. - writerQueue.tryExecute("$connectionName ping") { + writerQueue.execute("$connectionName ping") { writePing(true, payload1, payload2) } } @@ -819,7 +818,7 @@ class Http2Connection internal constructor(builder: Builder) : Closeable { } currentPushRequests.add(streamId) } - pushQueue.tryExecute("$connectionName[$streamId] onRequest") { + pushQueue.execute("$connectionName[$streamId] onRequest") { val cancel = pushObserver.onRequest(streamId, requestHeaders) ignoreIoExceptions { if (cancel) { @@ -837,7 +836,7 @@ class Http2Connection internal constructor(builder: Builder) : Closeable { requestHeaders: List
, inFinished: Boolean ) { - pushQueue.tryExecute("$connectionName[$streamId] onHeaders") { + pushQueue.execute("$connectionName[$streamId] onHeaders") { val cancel = pushObserver.onHeaders(streamId, requestHeaders, inFinished) ignoreIoExceptions { if (cancel) writer.rstStream(streamId, ErrorCode.CANCEL) @@ -864,7 +863,7 @@ class Http2Connection internal constructor(builder: Builder) : Closeable { val buffer = Buffer() source.require(byteCount.toLong()) // Eagerly read the frame before firing client thread. source.read(buffer, byteCount.toLong()) - pushQueue.tryExecute("$connectionName[$streamId] onData") { + pushQueue.execute("$connectionName[$streamId] onData") { ignoreIoExceptions { val cancel = pushObserver.onData(streamId, buffer, byteCount, inFinished) if (cancel) writer.rstStream(streamId, ErrorCode.CANCEL) @@ -878,7 +877,7 @@ class Http2Connection internal constructor(builder: Builder) : Closeable { } internal fun pushResetLater(streamId: Int, errorCode: ErrorCode) { - pushQueue.tryExecute("$connectionName[$streamId] onReset", cancelable = false) { + pushQueue.execute("$connectionName[$streamId] onReset") { pushObserver.onReset(streamId, errorCode) synchronized(this@Http2Connection) { currentPushRequests.remove(streamId) diff --git a/okhttp/src/main/java/okhttp3/internal/ws/RealWebSocket.kt b/okhttp/src/main/java/okhttp3/internal/ws/RealWebSocket.kt index d8627e194..b93e1c233 100644 --- a/okhttp/src/main/java/okhttp3/internal/ws/RealWebSocket.kt +++ b/okhttp/src/main/java/okhttp3/internal/ws/RealWebSocket.kt @@ -393,7 +393,7 @@ class RealWebSocket( private fun runWriter() { assert(Thread.holdsLock(this)) - taskQueue.trySchedule(writerTask!!) + taskQueue.schedule(writerTask!!) } /** @@ -535,7 +535,7 @@ class RealWebSocket( val sink: BufferedSink ) : Closeable - private inner class WriterTask : Task("$name writer", cancelable = true) { + private inner class WriterTask : Task("$name writer") { override fun runOnce(): Long { try { if (writeOneFrame()) return 0L diff --git a/okhttp/src/test/java/okhttp3/internal/concurrent/TaskRunnerTest.kt b/okhttp/src/test/java/okhttp3/internal/concurrent/TaskRunnerTest.kt index 9fdbdf9d9..f6066b4ca 100644 --- a/okhttp/src/test/java/okhttp3/internal/concurrent/TaskRunnerTest.kt +++ b/okhttp/src/test/java/okhttp3/internal/concurrent/TaskRunnerTest.kt @@ -72,7 +72,7 @@ class TaskRunnerTest { /** Repeat with a delay of 200 but schedule with a delay of 50. The schedule wins. */ @Test fun executeScheduledEarlierReplacesRepeatedLater() { - val task = object : Task("task", cancelable = true) { + val task = object : Task("task") { val schedules = mutableListOf(50L) val delays = mutableListOf(200L, -1L) override fun runOnce(): Long { @@ -99,7 +99,7 @@ class TaskRunnerTest { /** Schedule with a delay of 200 but repeat with a delay of 50. The repeat wins. */ @Test fun executeRepeatedEarlierReplacesScheduledLater() { - val task = object : Task("task", cancelable = true) { + val task = object : Task("task") { val schedules = mutableListOf(200L) val delays = mutableListOf(50L, -1L) override fun runOnce(): Long { @@ -141,9 +141,12 @@ class TaskRunnerTest { } @Test fun cancelReturnsFalseDoesNotCancel() { - redQueue.execute("task", 100L, cancelable = false) { - log += "run@${taskFaker.nanoTime}" - } + redQueue.schedule(object : Task("task", cancelable = false) { + override fun runOnce(): Long { + log += "run@${taskFaker.nanoTime}" + return -1L + } + }, 100L) taskFaker.advanceUntil(0L) assertThat(log).isEmpty() @@ -191,12 +194,14 @@ class TaskRunnerTest { } @Test fun cancelWhileExecutingDoesNotStopUncancelableTask() { - val delays = mutableListOf(50L, -1L) - redQueue.schedule("task", 100L, cancelable = false) { - log += "run@${taskFaker.nanoTime}" - redQueue.cancelAll() - return@schedule delays.removeAt(0) - } + redQueue.schedule(object : Task("task", cancelable = false) { + val delays = mutableListOf(50L, -1L) + override fun runOnce(): Long { + log += "run@${taskFaker.nanoTime}" + redQueue.cancelAll() + return delays.removeAt(0) + } + }, 100L) taskFaker.advanceUntil(0L) assertThat(log).isEmpty() @@ -227,9 +232,12 @@ class TaskRunnerTest { } @Test fun interruptingCoordinatorAttemptsToCancelsAndFails() { - redQueue.execute("task", 100L, cancelable = false) { - log += "run@${taskFaker.nanoTime}" - } + redQueue.schedule(object : Task("task", cancelable = false) { + override fun runOnce(): Long { + log += "run@${taskFaker.nanoTime}" + return -1L + } + }, 100L) taskFaker.advanceUntil(0L) assertThat(log).isEmpty() @@ -317,7 +325,7 @@ class TaskRunnerTest { * cumbersome to implement properly because the active task might be a cancel. */ @Test fun scheduledTasksDoesNotIncludeRunningTask() { - val task = object : Task("task one", cancelable = true) { + val task = object : Task("task one") { val schedules = mutableListOf(200L) override fun runOnce(): Long { if (schedules.isNotEmpty()) { @@ -408,10 +416,12 @@ class TaskRunnerTest { } @Test fun shutdownFailsToCancelsScheduledTasks() { - redQueue.schedule("task", 100L, cancelable = false) { - log += "run@${taskFaker.nanoTime}" - return@schedule 50L - } + redQueue.schedule(object : Task("task", false) { + override fun runOnce(): Long { + log += "run@${taskFaker.nanoTime}" + return 50L + } + }, 100L) taskFaker.advanceUntil(0L) assertThat(log).isEmpty() @@ -430,7 +440,7 @@ class TaskRunnerTest { @Test fun scheduleDiscardsTaskWhenShutdown() { redQueue.shutdown() - redQueue.tryExecute("task", 100L) { + redQueue.execute("task", 100L) { // Do nothing. } @@ -441,9 +451,11 @@ class TaskRunnerTest { redQueue.shutdown() try { - redQueue.execute("task", 100L) { - // Do nothing. - } + redQueue.schedule(object : Task("task", cancelable = false) { + override fun runOnce(): Long { + return -1L + } + }, 100L) fail() } catch (_: RejectedExecutionException) { }