diff --git a/okhttp-coroutines/src/main/kotlin/okhttp3/JvmCallExtensions.kt b/okhttp-coroutines/src/main/kotlin/okhttp3/JvmCallExtensions.kt index fccfce9e9..c301ddbc1 100644 --- a/okhttp-coroutines/src/main/kotlin/okhttp3/JvmCallExtensions.kt +++ b/okhttp-coroutines/src/main/kotlin/okhttp3/JvmCallExtensions.kt @@ -20,10 +20,10 @@ package okhttp3 import kotlin.coroutines.resumeWithException import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.suspendCancellableCoroutine +import okhttp3.internal.closeQuietly import okio.IOException -@OptIn(ExperimentalCoroutinesApi::class) -@ExperimentalOkHttpApi +@ExperimentalCoroutinesApi // resume with a resource cleanup. suspend fun Call.executeAsync(): Response = suspendCancellableCoroutine { continuation -> continuation.invokeOnCancellation { @@ -42,7 +42,9 @@ suspend fun Call.executeAsync(): Response = call: Call, response: Response, ) { - continuation.resume(value = response, onCancellation = { call.cancel() }) + continuation.resume(response) { + response.closeQuietly() + } } }, ) diff --git a/okhttp-coroutines/src/test/kotlin/okhttp3/SuspendCallTest.kt b/okhttp-coroutines/src/test/kotlin/okhttp3/SuspendCallTest.kt index 04a080d03..dcf6080f1 100644 --- a/okhttp-coroutines/src/test/kotlin/okhttp3/SuspendCallTest.kt +++ b/okhttp-coroutines/src/test/kotlin/okhttp3/SuspendCallTest.kt @@ -26,15 +26,23 @@ import java.io.IOException import java.util.concurrent.TimeUnit import kotlin.test.assertFailsWith import kotlin.time.Duration.Companion.seconds +import kotlinx.coroutines.CancellationException import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.TimeoutCancellationException +import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.job +import kotlinx.coroutines.supervisorScope import kotlinx.coroutines.test.runTest import kotlinx.coroutines.withContext import kotlinx.coroutines.withTimeout import mockwebserver3.MockResponse import mockwebserver3.MockWebServer import mockwebserver3.SocketPolicy.DisconnectAfterRequest +import okhttp3.HttpUrl.Companion.toHttpUrl +import okio.Buffer +import okio.ForwardingSource +import okio.buffer import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test import org.junit.jupiter.api.extension.RegisterExtension @@ -148,4 +156,63 @@ class SuspendCallTest { } } } + + @Test + fun responseClosedIfCoroutineCanceled() { + runTest { + val call = ClosableCall() + + supervisorScope { + assertFailsWith { + coroutineScope { + call.afterCallbackOnResponse = { + coroutineContext.job.cancel() + } + call.executeAsync() + } + } + } + + assertThat(call.canceled).isTrue() + assertThat(call.responseClosed).isTrue() + } + } + + /** A call that keeps track of whether its response body is closed. */ + private class ClosableCall : FailingCall() { + private val response = + Response.Builder() + .request(Request("https://example.com/".toHttpUrl())) + .protocol(Protocol.HTTP_1_1) + .message("OK") + .code(200) + .body( + object : ResponseBody() { + override fun contentType() = null + + override fun contentLength() = -1L + + override fun source() = + object : ForwardingSource(Buffer()) { + override fun close() { + responseClosed = true + } + }.buffer() + }, + ) + .build() + + var responseClosed = false + var canceled = false + var afterCallbackOnResponse: () -> Unit = {} + + override fun cancel() { + canceled = true + } + + override fun enqueue(responseCallback: Callback) { + responseCallback.onResponse(this, response) + afterCallbackOnResponse() + } + } }