From aef791ac366b66ac7490a79dd2fdff8393fd6f8f Mon Sep 17 00:00:00 2001 From: Jesse Wilson Date: Fri, 5 Apr 2024 07:25:23 -0400 Subject: [PATCH] Don't leak response bodies in executeAsync (#8330) * Don't leak response bodies in executeAsync Also make callers opt in to an unstable coroutines API. If the resource cleanup coroutines API changes, we'll have to change this API. Remove the OkHttp experimental API. This is a good enough API as far as OkHttp is concerned. * Spotless --- .../main/kotlin/okhttp3/JvmCallExtensions.kt | 8 ++- .../test/kotlin/okhttp3/SuspendCallTest.kt | 67 +++++++++++++++++++ 2 files changed, 72 insertions(+), 3 deletions(-) diff --git a/okhttp-coroutines/src/main/kotlin/okhttp3/JvmCallExtensions.kt b/okhttp-coroutines/src/main/kotlin/okhttp3/JvmCallExtensions.kt index fccfce9e9..c301ddbc1 100644 --- a/okhttp-coroutines/src/main/kotlin/okhttp3/JvmCallExtensions.kt +++ b/okhttp-coroutines/src/main/kotlin/okhttp3/JvmCallExtensions.kt @@ -20,10 +20,10 @@ package okhttp3 import kotlin.coroutines.resumeWithException import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.suspendCancellableCoroutine +import okhttp3.internal.closeQuietly import okio.IOException -@OptIn(ExperimentalCoroutinesApi::class) -@ExperimentalOkHttpApi +@ExperimentalCoroutinesApi // resume with a resource cleanup. suspend fun Call.executeAsync(): Response = suspendCancellableCoroutine { continuation -> continuation.invokeOnCancellation { @@ -42,7 +42,9 @@ suspend fun Call.executeAsync(): Response = call: Call, response: Response, ) { - continuation.resume(value = response, onCancellation = { call.cancel() }) + continuation.resume(response) { + response.closeQuietly() + } } }, ) diff --git a/okhttp-coroutines/src/test/kotlin/okhttp3/SuspendCallTest.kt b/okhttp-coroutines/src/test/kotlin/okhttp3/SuspendCallTest.kt index 04a080d03..dcf6080f1 100644 --- a/okhttp-coroutines/src/test/kotlin/okhttp3/SuspendCallTest.kt +++ b/okhttp-coroutines/src/test/kotlin/okhttp3/SuspendCallTest.kt @@ -26,15 +26,23 @@ import java.io.IOException import java.util.concurrent.TimeUnit import kotlin.test.assertFailsWith import kotlin.time.Duration.Companion.seconds +import kotlinx.coroutines.CancellationException import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.TimeoutCancellationException +import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.job +import kotlinx.coroutines.supervisorScope import kotlinx.coroutines.test.runTest import kotlinx.coroutines.withContext import kotlinx.coroutines.withTimeout import mockwebserver3.MockResponse import mockwebserver3.MockWebServer import mockwebserver3.SocketPolicy.DisconnectAfterRequest +import okhttp3.HttpUrl.Companion.toHttpUrl +import okio.Buffer +import okio.ForwardingSource +import okio.buffer import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test import org.junit.jupiter.api.extension.RegisterExtension @@ -148,4 +156,63 @@ class SuspendCallTest { } } } + + @Test + fun responseClosedIfCoroutineCanceled() { + runTest { + val call = ClosableCall() + + supervisorScope { + assertFailsWith { + coroutineScope { + call.afterCallbackOnResponse = { + coroutineContext.job.cancel() + } + call.executeAsync() + } + } + } + + assertThat(call.canceled).isTrue() + assertThat(call.responseClosed).isTrue() + } + } + + /** A call that keeps track of whether its response body is closed. */ + private class ClosableCall : FailingCall() { + private val response = + Response.Builder() + .request(Request("https://example.com/".toHttpUrl())) + .protocol(Protocol.HTTP_1_1) + .message("OK") + .code(200) + .body( + object : ResponseBody() { + override fun contentType() = null + + override fun contentLength() = -1L + + override fun source() = + object : ForwardingSource(Buffer()) { + override fun close() { + responseClosed = true + } + }.buffer() + }, + ) + .build() + + var responseClosed = false + var canceled = false + var afterCallbackOnResponse: () -> Unit = {} + + override fun cancel() { + canceled = true + } + + override fun enqueue(responseCallback: Callback) { + responseCallback.onResponse(this, response) + afterCallbackOnResponse() + } + } }