diff --git a/mockwebserver/src/main/java/okhttp3/mockwebserver/MockWebServer.kt b/mockwebserver/src/main/java/okhttp3/mockwebserver/MockWebServer.kt index 013d354eb..afe11d150 100644 --- a/mockwebserver/src/main/java/okhttp3/mockwebserver/MockWebServer.kt +++ b/mockwebserver/src/main/java/okhttp3/mockwebserver/MockWebServer.kt @@ -448,7 +448,7 @@ class MockWebServer : ExternalResource(), Closeable { // Await shutdown. for (queue in taskRunner.activeQueues()) { - if (!queue.awaitIdle(TimeUnit.SECONDS.toNanos(5))) { + if (!queue.idleLatch().await(5, TimeUnit.SECONDS)) { throw IOException("Gave up waiting for queue to shut down") } } diff --git a/okhttp-testing-support/src/main/java/okhttp3/OkHttpClientTestRule.kt b/okhttp-testing-support/src/main/java/okhttp3/OkHttpClientTestRule.kt index 0a5d2f52f..408a2baa3 100644 --- a/okhttp-testing-support/src/main/java/okhttp3/OkHttpClientTestRule.kt +++ b/okhttp-testing-support/src/main/java/okhttp3/OkHttpClientTestRule.kt @@ -87,7 +87,7 @@ class OkHttpClientTestRule : TestRule { private fun ensureAllTaskQueuesIdle() { for (queue in TaskRunner.INSTANCE.activeQueues()) { - assertThat(queue.awaitIdle(TimeUnit.MILLISECONDS.toNanos(1000L))) + assertThat(queue.idleLatch().await(1_000L, TimeUnit.MILLISECONDS)) .withFailMessage("Queue still active after 1000 ms") .isTrue() } diff --git a/okhttp/src/main/java/okhttp3/internal/concurrent/TaskQueue.kt b/okhttp/src/main/java/okhttp3/internal/concurrent/TaskQueue.kt index 0e6c555da..4de83b0c8 100644 --- a/okhttp/src/main/java/okhttp3/internal/concurrent/TaskQueue.kt +++ b/okhttp/src/main/java/okhttp3/internal/concurrent/TaskQueue.kt @@ -18,7 +18,6 @@ package okhttp3.internal.concurrent import okhttp3.internal.assertThreadDoesntHoldLock import java.util.concurrent.CountDownLatch import java.util.concurrent.RejectedExecutionException -import java.util.concurrent.TimeUnit /** * A set of tasks that are executed in sequential order. @@ -101,25 +100,42 @@ class TaskQueue internal constructor( }, delayNanos) } - /** Returns true if this queue became idle before the timeout elapsed. */ - fun awaitIdle(delayNanos: Long): Boolean { - val latch = CountDownLatch(1) - - val task = object : Task("OkHttp awaitIdle", cancelable = false) { - override fun runOnce(): Long { - latch.countDown() - return -1L - } - } - - // Don't delegate to schedule because that has to honor shutdown rules. + /** Returns a latch that reaches 0 when the queue is next idle. */ + fun idleLatch(): CountDownLatch { synchronized(taskRunner) { - if (scheduleAndDecide(task, 0L)) { + // If the queue is already idle, that's easy. + if (activeTask == null && futureTasks.isEmpty()) { + return CountDownLatch(0) + } + + // If there's an existing AwaitIdleTask, use it. This is necessary when the executor is + // shutdown but still busy as we can't enqueue in that case. + val existingTask = activeTask + if (existingTask is AwaitIdleTask) { + return existingTask.latch + } + for (futureTask in futureTasks) { + if (futureTask is AwaitIdleTask) { + return futureTask.latch + } + } + + // Don't delegate to schedule() because that enforces shutdown rules. + val newTask = AwaitIdleTask() + if (scheduleAndDecide(newTask, 0L)) { taskRunner.kickCoordinator(this) } + return newTask.latch } + } - return latch.await(delayNanos, TimeUnit.NANOSECONDS) + private class AwaitIdleTask : Task("OkHttp awaitIdle", cancelable = false) { + val latch = CountDownLatch(1) + + override fun runOnce(): Long { + latch.countDown() + return -1L + } } /** Adds [task] to run in [delayNanos]. Returns true if the coordinator is impacted. */ diff --git a/okhttp/src/main/java/okhttp3/internal/ws/RealWebSocket.kt b/okhttp/src/main/java/okhttp3/internal/ws/RealWebSocket.kt index ca907277b..7dedfbad9 100644 --- a/okhttp/src/main/java/okhttp3/internal/ws/RealWebSocket.kt +++ b/okhttp/src/main/java/okhttp3/internal/ws/RealWebSocket.kt @@ -259,14 +259,14 @@ class RealWebSocket( /** For testing: wait until the web socket's executor has terminated. */ @Throws(InterruptedException::class) fun awaitTermination(timeout: Long, timeUnit: TimeUnit) { - taskQueue.awaitIdle(timeUnit.toNanos(timeout)) + taskQueue.idleLatch().await(timeout, timeUnit) } /** For testing: force this web socket to release its threads. */ @Throws(InterruptedException::class) fun tearDown() { taskQueue.shutdown() - taskQueue.awaitIdle(TimeUnit.SECONDS.toNanos(10L)) + taskQueue.idleLatch().await(10, TimeUnit.SECONDS) } @Synchronized fun sentPingCount(): Int = sentPingCount diff --git a/okhttp/src/test/java/okhttp3/internal/concurrent/TaskRunnerRealBackendTest.kt b/okhttp/src/test/java/okhttp3/internal/concurrent/TaskRunnerRealBackendTest.kt index e5e9aae67..a04cf5c45 100644 --- a/okhttp/src/test/java/okhttp3/internal/concurrent/TaskRunnerRealBackendTest.kt +++ b/okhttp/src/test/java/okhttp3/internal/concurrent/TaskRunnerRealBackendTest.kt @@ -82,11 +82,22 @@ class TaskRunnerRealBackendTest { return@schedule -1L } - queue.awaitIdle(TimeUnit.MILLISECONDS.toNanos(500)) + queue.idleLatch().await(500, TimeUnit.MILLISECONDS) assertThat(log.take()).isEqualTo("failing task running") assertThat(log.take()).isEqualTo("uncaught exception: java.lang.RuntimeException: boom!") assertThat(log.take()).isEqualTo("normal task running") assertThat(log).isEmpty() } + + @Test fun idleLatchAfterShutdown() { + queue.schedule("task") { + Thread.sleep(250) + backend.shutdown() + return@schedule -1L + } + + assertThat(queue.idleLatch().await(500L, TimeUnit.MILLISECONDS)).isTrue() + assertThat(queue.idleLatch().count).isEqualTo(0) + } } diff --git a/okhttp/src/test/java/okhttp3/internal/concurrent/TaskRunnerTest.kt b/okhttp/src/test/java/okhttp3/internal/concurrent/TaskRunnerTest.kt index 6013c34a1..898a0531d 100644 --- a/okhttp/src/test/java/okhttp3/internal/concurrent/TaskRunnerTest.kt +++ b/okhttp/src/test/java/okhttp3/internal/concurrent/TaskRunnerTest.kt @@ -623,6 +623,30 @@ class TaskRunnerTest { ) } + @Test fun idleLatch() { + redQueue.execute("task") { + log += "run@${taskFaker.nanoTime}" + } + + val idleLatch = redQueue.idleLatch() + assertThat(idleLatch.count).isEqualTo(1) + + taskFaker.advanceUntil(0.µs) + assertThat(log).containsExactly("run@0") + + assertThat(idleLatch.count).isEqualTo(0) + } + + @Test fun multipleCallsToIdleLatchReturnSameInstance() { + redQueue.execute("task") { + log += "run@${taskFaker.nanoTime}" + } + + val idleLatch1 = redQueue.idleLatch() + val idleLatch2 = redQueue.idleLatch() + assertThat(idleLatch2).isSameAs(idleLatch1) + } + private val Int.µs: Long get() = this * 1_000L }