diff --git a/mockwebserver/src/main/java/okhttp3/mockwebserver/MockWebServer.kt b/mockwebserver/src/main/java/okhttp3/mockwebserver/MockWebServer.kt index 0385c2fc7..a456b8fdf 100644 --- a/mockwebserver/src/main/java/okhttp3/mockwebserver/MockWebServer.kt +++ b/mockwebserver/src/main/java/okhttp3/mockwebserver/MockWebServer.kt @@ -25,6 +25,7 @@ import okhttp3.Request import okhttp3.Response import okhttp3.internal.addHeaderLenient import okhttp3.internal.closeQuietly +import okhttp3.internal.concurrent.TaskRunner import okhttp3.internal.duplex.MwsDuplexAccess import okhttp3.internal.execute import okhttp3.internal.http.HttpMethod @@ -99,6 +100,7 @@ import javax.net.ssl.X509TrustManager * in sequence. */ class MockWebServer : ExternalResource(), Closeable { + private val taskRunner = TaskRunner() private val requestQueue = LinkedBlockingQueue() private val openClientSockets = Collections.newSetFromMap(ConcurrentHashMap()) @@ -454,6 +456,12 @@ class MockWebServer : ExternalResource(), Closeable { } catch (e: InterruptedException) { throw AssertionError() } + + for (queue in taskRunner.activeQueues()) { + if (!queue.awaitIdle(TimeUnit.MILLISECONDS.toNanos(500L))) { + throw IOException("Gave up waiting for ${queue.owner} to shut down") + } + } } @Synchronized override fun after() { @@ -533,7 +541,7 @@ class MockWebServer : ExternalResource(), Closeable { if (protocol === Protocol.HTTP_2 || protocol === Protocol.H2_PRIOR_KNOWLEDGE) { val http2SocketHandler = Http2SocketHandler(socket, protocol) - val connection = Http2Connection.Builder(false) + val connection = Http2Connection.Builder(false, taskRunner) .socket(socket) .listener(http2SocketHandler) .build() diff --git a/mockwebserver/src/test/java/okhttp3/mockwebserver/internal/http2/Http2Server.java b/mockwebserver/src/test/java/okhttp3/mockwebserver/internal/http2/Http2Server.java index 1f3039daa..78a6b7994 100644 --- a/mockwebserver/src/test/java/okhttp3/mockwebserver/internal/http2/Http2Server.java +++ b/mockwebserver/src/test/java/okhttp3/mockwebserver/internal/http2/Http2Server.java @@ -29,6 +29,7 @@ import javax.net.ssl.SSLSocket; import javax.net.ssl.SSLSocketFactory; import okhttp3.Headers; import okhttp3.Protocol; +import okhttp3.internal.concurrent.TaskRunner; import okhttp3.internal.http2.Header; import okhttp3.internal.http2.Http2Connection; import okhttp3.internal.http2.Http2Stream; @@ -69,7 +70,7 @@ public final class Http2Server extends Http2Connection.Listener { if (protocol != Protocol.HTTP_2) { throw new ProtocolException("Protocol " + protocol + " unsupported"); } - Http2Connection connection = new Http2Connection.Builder(false) + Http2Connection connection = new Http2Connection.Builder(false, TaskRunner.INSTANCE) .socket(sslSocket) .listener(this) .build(); diff --git a/okhttp-testing-support/src/main/java/okhttp3/OkHttpClientTestRule.kt b/okhttp-testing-support/src/main/java/okhttp3/OkHttpClientTestRule.kt index a39b3cc9a..1315914a3 100644 --- a/okhttp-testing-support/src/main/java/okhttp3/OkHttpClientTestRule.kt +++ b/okhttp-testing-support/src/main/java/okhttp3/OkHttpClientTestRule.kt @@ -15,8 +15,6 @@ */ package okhttp3 -import okhttp3.internal.concurrent.Task -import okhttp3.internal.concurrent.TaskQueue import okhttp3.internal.concurrent.TaskRunner import okhttp3.testing.Flaky import org.assertj.core.api.Assertions.assertThat @@ -25,7 +23,6 @@ import org.junit.runner.Description import org.junit.runners.model.Statement import java.net.InetAddress import java.util.concurrent.ConcurrentLinkedDeque -import java.util.concurrent.CountDownLatch import java.util.concurrent.TimeUnit /** Apply this rule to tests that need an OkHttpClient instance. */ @@ -66,7 +63,7 @@ class OkHttpClientTestRule : TestRule { private fun ensureAllTaskQueuesIdle() { for (queue in TaskRunner.INSTANCE.activeQueues()) { - assertThat(queue.awaitIdle(500L, TimeUnit.MILLISECONDS)) + assertThat(queue.awaitIdle(TimeUnit.MILLISECONDS.toNanos(500L))) .withFailMessage("Queue ${queue.owner} still active after 500ms") .isTrue() } @@ -133,19 +130,6 @@ class OkHttpClientTestRule : TestRule { } } - /** Returns true if this queue became idle before the timeout elapsed. */ - private fun TaskQueue.awaitIdle(timeout: Long, timeUnit: TimeUnit): Boolean { - val latch = CountDownLatch(1) - schedule(object : Task("awaitIdle") { - override fun runOnce(): Long { - latch.countDown() - return -1L - } - }) - - return latch.await(timeout, timeUnit) - } - companion object { /** * Quick and dirty pool of OkHttpClient instances. Each has its own independent dispatcher and diff --git a/okhttp/src/main/java/okhttp3/internal/Util.kt b/okhttp/src/main/java/okhttp3/internal/Util.kt index 1748d996a..861221b36 100644 --- a/okhttp/src/main/java/okhttp3/internal/Util.kt +++ b/okhttp/src/main/java/okhttp3/internal/Util.kt @@ -49,7 +49,6 @@ import java.util.LinkedHashMap import java.util.Locale import java.util.TimeZone import java.util.concurrent.Executor -import java.util.concurrent.RejectedExecutionException import java.util.concurrent.ThreadFactory import java.util.concurrent.TimeUnit import kotlin.text.Charsets.UTF_32BE @@ -393,14 +392,6 @@ inline fun Executor.execute(name: String, crossinline block: () -> Unit) { } } -/** Executes [block] unless this executor has been shutdown, in which case this does nothing. */ -inline fun Executor.tryExecute(name: String, crossinline block: () -> Unit) { - try { - execute(name, block) - } catch (_: RejectedExecutionException) { - } -} - fun Buffer.skipAll(b: Byte): Int { var count = 0 while (!exhausted() && this[0] == b) { @@ -510,34 +501,9 @@ fun Long.toHexString(): String = java.lang.Long.toHexString(this) fun Int.toHexString(): String = Integer.toHexString(this) -/** - * Lock and wait a duration in nanoseconds. Unlike [java.lang.Object.wait] this interprets 0 as - * "don't wait" instead of "wait forever". - */ -@Throws(InterruptedException::class) -fun Any.lockAndWaitNanos(nanos: Long) { - synchronized(this) { - objectWaitNanos(nanos) - } -} - @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "NOTHING_TO_INLINE") inline fun Any.wait() = (this as Object).wait() -/** - * Wait a duration in nanoseconds. Unlike [java.lang.Object.wait] this interprets 0 as "don't wait" - * instead of "wait forever". - */ -@Throws(InterruptedException::class) -@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN") -fun Any.objectWaitNanos(nanos: Long) { - val ms = nanos / 1_000_000L - val ns = nanos - (ms * 1_000_000L) - if (ms > 0L || nanos > 0) { - (this as Object).wait(ms, ns.toInt()) - } -} - @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "NOTHING_TO_INLINE") inline fun Any.notify() = (this as Object).notify() diff --git a/okhttp/src/main/java/okhttp3/internal/concurrent/TaskQueue.kt b/okhttp/src/main/java/okhttp3/internal/concurrent/TaskQueue.kt index b0ce4dda6..5696e73d4 100644 --- a/okhttp/src/main/java/okhttp3/internal/concurrent/TaskQueue.kt +++ b/okhttp/src/main/java/okhttp3/internal/concurrent/TaskQueue.kt @@ -16,6 +16,9 @@ package okhttp3.internal.concurrent import okhttp3.internal.addIfAbsent +import java.util.concurrent.CountDownLatch +import java.util.concurrent.RejectedExecutionException +import java.util.concurrent.TimeUnit /** * A set of tasks that are executed in sequential order. @@ -32,6 +35,8 @@ class TaskQueue internal constructor( */ val owner: Any ) { + private var shutdown = false + /** This queue's currently-executing task, or null if none is currently executing. */ private var activeTask: Task? = null @@ -61,19 +66,55 @@ class TaskQueue internal constructor( * The target execution time is implemented on a best-effort basis. If another task in this queue * is running when that time is reached, that task is allowed to complete before this task is * started. Similarly the task will be delayed if the host lacks compute resources. + * + * @throws RejectedExecutionException if the queue is shut down. */ fun schedule(task: Task, delayNanos: Long = 0L) { - task.initQueue(this) - synchronized(taskRunner) { + if (shutdown) throw RejectedExecutionException() + if (scheduleAndDecide(task, delayNanos)) { taskRunner.kickCoordinator(this) } } } + /** Like [schedule], but this silently discard the task if the queue is shut down. */ + fun trySchedule(task: Task, delayNanos: Long = 0L) { + synchronized(taskRunner) { + if (shutdown) return + + if (scheduleAndDecide(task, delayNanos)) { + taskRunner.kickCoordinator(this) + } + } + } + + /** Returns true if this queue became idle before the timeout elapsed. */ + fun awaitIdle(delayNanos: Long): Boolean { + val latch = CountDownLatch(1) + + val task = object : Task("awaitIdle") { + override fun runOnce(): Long { + latch.countDown() + return -1L + } + } + + // Don't delegate to schedule because that has to honor shutdown rules. + synchronized(taskRunner) { + if (scheduleAndDecide(task, 0L)) { + taskRunner.kickCoordinator(this) + } + } + + return latch.await(delayNanos, TimeUnit.NANOSECONDS) + } + /** Adds [task] to run in [delayNanos]. Returns true if the coordinator should run. */ private fun scheduleAndDecide(task: Task, delayNanos: Long): Boolean { + task.initQueue(this) + val now = taskRunner.backend.nanoTime() val executeNanoTime = now + delayNanos @@ -100,6 +141,8 @@ class TaskQueue internal constructor( * be removed from the execution schedule. */ fun cancelAll() { + check(!Thread.holdsLock(this)) + synchronized(taskRunner) { if (cancelAllAndDecide()) { taskRunner.kickCoordinator(this) @@ -107,6 +150,17 @@ class TaskQueue internal constructor( } } + fun shutdown() { + check(!Thread.holdsLock(this)) + + synchronized(taskRunner) { + shutdown = true + if (cancelAllAndDecide()) { + taskRunner.kickCoordinator(this) + } + } + } + /** Returns true if the coordinator should run. */ private fun cancelAllAndDecide(): Boolean { val runningTask = activeTask @@ -160,7 +214,7 @@ class TaskQueue internal constructor( synchronized(taskRunner) { check(activeTask === task) - if (delayNanos != -1L) { + if (delayNanos != -1L && !shutdown) { scheduleAndDecide(task, delayNanos) } else if (!futureTasks.contains(task)) { cancelTasks.remove(task) // We don't need to cancel it because it isn't scheduled. diff --git a/okhttp/src/main/java/okhttp3/internal/concurrent/TaskRunner.kt b/okhttp/src/main/java/okhttp3/internal/concurrent/TaskRunner.kt index fb5732287..a7a1eba55 100644 --- a/okhttp/src/main/java/okhttp3/internal/concurrent/TaskRunner.kt +++ b/okhttp/src/main/java/okhttp3/internal/concurrent/TaskRunner.kt @@ -17,7 +17,6 @@ package okhttp3.internal.concurrent import okhttp3.internal.addIfAbsent import okhttp3.internal.notify -import okhttp3.internal.objectWaitNanos import okhttp3.internal.threadFactory import java.util.concurrent.LinkedBlockingQueue import java.util.concurrent.SynchronousQueue @@ -158,8 +157,18 @@ class TaskRunner( taskRunner.notify() } + /** + * Wait a duration in nanoseconds. Unlike [java.lang.Object.wait] this interprets 0 as + * "don't wait" instead of "wait forever". + */ + @Throws(InterruptedException::class) + @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN") override fun coordinatorWait(taskRunner: TaskRunner, nanos: Long) { - taskRunner.objectWaitNanos(nanos) + val ms = nanos / 1_000_000L + val ns = nanos - (ms * 1_000_000L) + if (ms > 0L || nanos > 0) { + (taskRunner as Object).wait(ms, ns.toInt()) + } } fun shutdown() { diff --git a/okhttp/src/main/java/okhttp3/internal/connection/RealConnection.kt b/okhttp/src/main/java/okhttp3/internal/connection/RealConnection.kt index a68eaf25e..24b0231c3 100644 --- a/okhttp/src/main/java/okhttp3/internal/connection/RealConnection.kt +++ b/okhttp/src/main/java/okhttp3/internal/connection/RealConnection.kt @@ -33,6 +33,7 @@ import okhttp3.Response import okhttp3.Route import okhttp3.internal.EMPTY_RESPONSE import okhttp3.internal.closeQuietly +import okhttp3.internal.concurrent.TaskRunner import okhttp3.internal.http.ExchangeCodec import okhttp3.internal.http1.Http1ExchangeCodec import okhttp3.internal.http2.ConnectionShutdownException @@ -321,7 +322,7 @@ class RealConnection( val source = this.source!! val sink = this.sink!! socket.soTimeout = 0 // HTTP/2 connection timeouts are set per-stream. - val http2Connection = Http2Connection.Builder(true) + val http2Connection = Http2Connection.Builder(client = true, taskRunner = TaskRunner.INSTANCE) .socket(socket, route.address.url.host, source, sink) .listener(this) .pingIntervalMillis(pingIntervalMillis) diff --git a/okhttp/src/main/java/okhttp3/internal/http2/Http2Connection.kt b/okhttp/src/main/java/okhttp3/internal/http2/Http2Connection.kt index 048605e85..54cd7fae8 100644 --- a/okhttp/src/main/java/okhttp3/internal/http2/Http2Connection.kt +++ b/okhttp/src/main/java/okhttp3/internal/http2/Http2Connection.kt @@ -18,9 +18,10 @@ package okhttp3.internal.http2 import okhttp3.internal.EMPTY_BYTE_ARRAY import okhttp3.internal.EMPTY_HEADERS import okhttp3.internal.closeQuietly +import okhttp3.internal.concurrent.Task +import okhttp3.internal.concurrent.TaskRunner import okhttp3.internal.connectionName import okhttp3.internal.execute -import okhttp3.internal.format import okhttp3.internal.http2.ErrorCode.REFUSED_STREAM import okhttp3.internal.http2.Settings.Companion.DEFAULT_INITIAL_WINDOW_SIZE import okhttp3.internal.ignoreIoExceptions @@ -28,9 +29,7 @@ import okhttp3.internal.notifyAll import okhttp3.internal.platform.Platform import okhttp3.internal.platform.Platform.Companion.INFO import okhttp3.internal.threadFactory -import okhttp3.internal.threadName import okhttp3.internal.toHeaders -import okhttp3.internal.tryExecute import okhttp3.internal.wait import okio.Buffer import okio.BufferedSink @@ -43,12 +42,9 @@ import java.io.Closeable import java.io.IOException import java.io.InterruptedIOException import java.net.Socket -import java.util.concurrent.LinkedBlockingQueue -import java.util.concurrent.ScheduledThreadPoolExecutor import java.util.concurrent.SynchronousQueue import java.util.concurrent.ThreadPoolExecutor import java.util.concurrent.TimeUnit -import java.util.concurrent.TimeUnit.MILLISECONDS /** * A socket connection to a remote peer. A connection hosts streams which can send and receive @@ -91,13 +87,10 @@ class Http2Connection internal constructor(builder: Builder) : Closeable { internal set /** Asynchronously writes frames to the outgoing socket. */ - private val writerExecutor = ScheduledThreadPoolExecutor(1, - threadFactory(format("OkHttp %s Writer", connectionName), false)) + private val writerQueue = builder.taskRunner.newQueue("$connectionName Writer") /** Ensures push promise callbacks events are sent in order per stream. */ - // Like newSingleThreadExecutor, except lazy creates the thread. - private val pushExecutor = ThreadPoolExecutor(0, 1, 60L, TimeUnit.SECONDS, LinkedBlockingQueue(), - threadFactory(format("OkHttp %s Push Observer", connectionName), true)) + private val pushQueue = builder.taskRunner.newQueue("$connectionName Push") /** User code to run in response to push promise events. */ private val pushObserver: PushObserver = builder.pushObserver @@ -149,11 +142,15 @@ class Http2Connection internal constructor(builder: Builder) : Closeable { init { if (builder.pingIntervalMillis != 0) { - writerExecutor.scheduleAtFixedRate({ - threadName("OkHttp $connectionName ping") { + val pingIntervalNanos = TimeUnit.MILLISECONDS.toNanos(builder.pingIntervalMillis.toLong()) + writerQueue.schedule(object : Task("OkHttp $connectionName ping") { + override fun runOnce(): Long { writePing(false, 0, 0) + return pingIntervalNanos } - }, builder.pingIntervalMillis.toLong(), builder.pingIntervalMillis.toLong(), MILLISECONDS) + + override fun tryCancel() = true + }, pingIntervalNanos) } } @@ -328,13 +325,18 @@ class Http2Connection internal constructor(builder: Builder) : Closeable { streamId: Int, errorCode: ErrorCode ) { - writerExecutor.tryExecute("OkHttp $connectionName stream $streamId") { - try { - writeSynReset(streamId, errorCode) - } catch (e: IOException) { - failConnection(e) + writerQueue.trySchedule(object : Task("OkHttp $connectionName stream $streamId") { + override fun runOnce(): Long { + try { + writeSynReset(streamId, errorCode) + } catch (e: IOException) { + failConnection(e) + } + return -1L } - } + + override fun tryCancel() = true + }) } @Throws(IOException::class) @@ -349,13 +351,18 @@ class Http2Connection internal constructor(builder: Builder) : Closeable { streamId: Int, unacknowledgedBytesRead: Long ) { - writerExecutor.tryExecute("OkHttp Window Update $connectionName stream $streamId") { - try { - writer.windowUpdate(streamId, unacknowledgedBytesRead) - } catch (e: IOException) { - failConnection(e) + writerQueue.trySchedule(object : Task("OkHttp Window Update $connectionName stream $streamId") { + override fun runOnce(): Long { + try { + writer.windowUpdate(streamId, unacknowledgedBytesRead) + } catch (e: IOException) { + failConnection(e) + } + return -1L } - } + + override fun tryCancel() = true + }) } fun writePing( @@ -467,8 +474,8 @@ class Http2Connection internal constructor(builder: Builder) : Closeable { } // Release the threads. - writerExecutor.shutdown() - pushExecutor.shutdown() + writerQueue.shutdown() + pushQueue.shutdown() } private fun failConnection(e: IOException?) { @@ -511,7 +518,8 @@ class Http2Connection internal constructor(builder: Builder) : Closeable { class Builder( /** True if this peer initiated the connection; false if this peer accepted the connection. */ - internal var client: Boolean + internal var client: Boolean, + internal val taskRunner: TaskRunner ) { internal lateinit var socket: Socket internal lateinit var connectionName: String @@ -659,9 +667,14 @@ class Http2Connection internal constructor(builder: Builder) : Closeable { } override fun settings(clearPrevious: Boolean, settings: Settings) { - writerExecutor.tryExecute("OkHttp $connectionName ACK Settings") { - applyAndAckSettings(clearPrevious, settings) - } + writerQueue.trySchedule(object : Task("OkHttp $connectionName ACK Settings") { + override fun runOnce(): Long { + applyAndAckSettings(clearPrevious, settings) + return -1L + } + + override fun tryCancel() = true + }) } /** @@ -725,9 +738,14 @@ class Http2Connection internal constructor(builder: Builder) : Closeable { } } else { // Send a reply to a client ping if this is a server and vice versa. - writerExecutor.tryExecute("OkHttp $connectionName ping") { - writePing(true, payload1, payload2) - } + writerQueue.trySchedule(object : Task("OkHttp $connectionName ping") { + override fun runOnce(): Long { + writePing(true, payload1, payload2) + return -1L + } + + override fun tryCancel() = true + }) } } @@ -812,8 +830,8 @@ class Http2Connection internal constructor(builder: Builder) : Closeable { } currentPushRequests.add(streamId) } - if (!isShutdown) { - pushExecutor.tryExecute("OkHttp $connectionName Push Request[$streamId]") { + pushQueue.trySchedule(object : Task("OkHttp $connectionName Push Request[$streamId]") { + override fun runOnce(): Long { val cancel = pushObserver.onRequest(streamId, requestHeaders) ignoreIoExceptions { if (cancel) { @@ -823,8 +841,11 @@ class Http2Connection internal constructor(builder: Builder) : Closeable { } } } + return -1L } - } + + override fun tryCancel() = true + }) } internal fun pushHeadersLater( @@ -832,8 +853,8 @@ class Http2Connection internal constructor(builder: Builder) : Closeable { requestHeaders: List
, inFinished: Boolean ) { - if (!isShutdown) { - pushExecutor.tryExecute("OkHttp $connectionName Push Headers[$streamId]") { + pushQueue.trySchedule(object : Task("OkHttp $connectionName Push Headers[$streamId]") { + override fun runOnce(): Long { val cancel = pushObserver.onHeaders(streamId, requestHeaders, inFinished) ignoreIoExceptions { if (cancel) writer.rstStream(streamId, ErrorCode.CANCEL) @@ -843,8 +864,11 @@ class Http2Connection internal constructor(builder: Builder) : Closeable { } } } + return -1L } - } + + override fun tryCancel() = true + }) } /** @@ -861,8 +885,8 @@ class Http2Connection internal constructor(builder: Builder) : Closeable { val buffer = Buffer() source.require(byteCount.toLong()) // Eagerly read the frame before firing client thread. source.read(buffer, byteCount.toLong()) - if (!isShutdown) { - pushExecutor.execute("OkHttp $connectionName Push Data[$streamId]") { + pushQueue.trySchedule(object : Task("OkHttp $connectionName Push Data[$streamId]") { + override fun runOnce(): Long { ignoreIoExceptions { val cancel = pushObserver.onData(streamId, buffer, byteCount, inFinished) if (cancel) writer.rstStream(streamId, ErrorCode.CANCEL) @@ -872,19 +896,23 @@ class Http2Connection internal constructor(builder: Builder) : Closeable { } } } + return -1L } - } + + override fun tryCancel() = true + }) } internal fun pushResetLater(streamId: Int, errorCode: ErrorCode) { - if (!isShutdown) { - pushExecutor.execute("OkHttp $connectionName Push Reset[$streamId]") { + pushQueue.trySchedule(object : Task("OkHttp $connectionName Push Reset[$streamId]") { + override fun runOnce(): Long { pushObserver.onReset(streamId, errorCode) synchronized(this@Http2Connection) { currentPushRequests.remove(streamId) } + return -1L } - } + }) } /** Listener of streams and settings initiated by the peer. */ diff --git a/okhttp/src/test/java/okhttp3/internal/concurrent/TaskFaker.kt b/okhttp/src/test/java/okhttp3/internal/concurrent/TaskFaker.kt index bd8cf47cd..29d9840b3 100644 --- a/okhttp/src/test/java/okhttp3/internal/concurrent/TaskFaker.kt +++ b/okhttp/src/test/java/okhttp3/internal/concurrent/TaskFaker.kt @@ -40,7 +40,7 @@ class TaskFaker { /** How many tasks can be executed immediately. */ val tasksSize: Int get() = tasks.size - /** Guarded by taskRunner. */ + /** Guarded by [taskRunner]. */ var nanoTime = 0L private set @@ -143,7 +143,9 @@ class TaskFaker { fun assertNoMoreTasks() { assertThat(coordinatorToRun).isNull() assertThat(tasks).isEmpty() - assertThat(coordinatorWaitingUntilTime).isEqualTo(Long.MAX_VALUE) + assertThat(coordinatorWaitingUntilTime) + .withFailMessage("tasks are scheduled to run at $coordinatorWaitingUntilTime") + .isEqualTo(Long.MAX_VALUE) } fun interruptCoordinatorThread() { diff --git a/okhttp/src/test/java/okhttp3/internal/concurrent/TaskRunnerTest.kt b/okhttp/src/test/java/okhttp3/internal/concurrent/TaskRunnerTest.kt index 3b0e600de..a6eccbd12 100644 --- a/okhttp/src/test/java/okhttp3/internal/concurrent/TaskRunnerTest.kt +++ b/okhttp/src/test/java/okhttp3/internal/concurrent/TaskRunnerTest.kt @@ -16,7 +16,9 @@ package okhttp3.internal.concurrent import org.assertj.core.api.Assertions.assertThat +import org.junit.Assert.fail import org.junit.Test +import java.util.concurrent.RejectedExecutionException class TaskRunnerTest { private val taskFaker = TaskFaker() @@ -470,4 +472,79 @@ class TaskRunnerTest { taskFaker.assertNoMoreTasks() } + + @Test fun shutdownSuccessfullyCancelsScheduledTasks() { + redQueue.schedule(object : Task("task") { + override fun runOnce(): Long { + log += "run@${taskFaker.nanoTime}" + return -1L + } + + override fun tryCancel(): Boolean { + log += "cancel@${taskFaker.nanoTime}" + return true + } + }, 100L) + + taskFaker.advanceUntil(0L) + assertThat(log).isEmpty() + + redQueue.shutdown() + + taskFaker.advanceUntil(99L) + assertThat(log).containsExactly("cancel@99") + + taskFaker.assertNoMoreTasks() + } + + @Test fun shutdownFailsToCancelsScheduledTasks() { + redQueue.schedule(object : Task("task") { + override fun runOnce(): Long { + log += "run@${taskFaker.nanoTime}" + return 50L + } + + override fun tryCancel(): Boolean { + log += "cancel@${taskFaker.nanoTime}" + return false + } + }, 100L) + + taskFaker.advanceUntil(0L) + assertThat(log).isEmpty() + + redQueue.shutdown() + + taskFaker.advanceUntil(99L) + assertThat(log).containsExactly("cancel@99") + + taskFaker.advanceUntil(100L) + assertThat(log).containsExactly("cancel@99", "run@100") + + taskFaker.assertNoMoreTasks() + } + + @Test fun scheduleDiscardsTaskWhenShutdown() { + redQueue.shutdown() + + redQueue.trySchedule(object : Task("task") { + override fun runOnce() = -1L + }, 100L) + + taskFaker.assertNoMoreTasks() + } + + @Test fun scheduleThrowsWhenShutdown() { + redQueue.shutdown() + + try { + redQueue.schedule(object : Task("task") { + override fun runOnce() = -1L + }, 100L) + fail() + } catch (_: RejectedExecutionException) { + } + + taskFaker.assertNoMoreTasks() + } } diff --git a/okhttp/src/test/java/okhttp3/internal/http2/Http2ConnectionTest.java b/okhttp/src/test/java/okhttp3/internal/http2/Http2ConnectionTest.java index e39871d1c..a5186d71c 100644 --- a/okhttp/src/test/java/okhttp3/internal/http2/Http2ConnectionTest.java +++ b/okhttp/src/test/java/okhttp3/internal/http2/Http2ConnectionTest.java @@ -27,6 +27,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import okhttp3.Headers; import okhttp3.internal.Util; +import okhttp3.internal.concurrent.TaskRunner; import okhttp3.internal.http2.MockHttp2Peer.InFrame; import okio.AsyncTimeout; import okio.Buffer; @@ -498,7 +499,7 @@ public final class Http2ConnectionTest { String longString = repeat('a', Http2.INITIAL_MAX_FRAME_SIZE + 1); Socket socket = peer.openSocket(); - Http2Connection connection = new Http2Connection.Builder(true) + Http2Connection connection = new Http2Connection.Builder(true, TaskRunner.INSTANCE) .socket(socket) .pushObserver(IGNORE) .build(); @@ -1786,7 +1787,7 @@ public final class Http2ConnectionTest { peer.acceptFrame(); // GOAWAY peer.play(); - Http2Connection connection = new Http2Connection.Builder(true) + Http2Connection connection = new Http2Connection.Builder(true, TaskRunner.INSTANCE) .socket(peer.openSocket()) .build(); connection.start(false); @@ -1857,7 +1858,7 @@ public final class Http2ConnectionTest { /** Builds a new connection to {@code peer} with settings acked. */ private Http2Connection connect(MockHttp2Peer peer, PushObserver pushObserver, Http2Connection.Listener listener) throws Exception { - Http2Connection connection = new Http2Connection.Builder(true) + Http2Connection connection = new Http2Connection.Builder(true, TaskRunner.INSTANCE) .socket(peer.openSocket()) .pushObserver(pushObserver) .listener(listener)