diff --git a/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/Tags.kt b/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/Tags.kt index c9192fe23..2f88c9fa5 100644 --- a/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/Tags.kt +++ b/okhttp/src/commonJvmAndroid/kotlin/okhttp3/internal/Tags.kt @@ -15,6 +15,7 @@ */ package okhttp3.internal +import java.util.concurrent.atomic.AtomicReference import kotlin.reflect.KClass /** @@ -104,3 +105,28 @@ private class LinkedTags( .reversed() .joinToString(prefix = "{", postfix = "}") { "${it.key}=${it.value}" } } + +internal fun AtomicReference.computeIfAbsent( + type: KClass, + compute: () -> T, +): T { + var computed: T? = null + + while (true) { + val tags = get() + + // If the element is already present. Return it. + val existing = tags[type] + if (existing != null) return existing + + if (computed == null) { + computed = compute() + } + + // If we successfully add the computed element, we're done. + val newTags = tags.plus(type, computed) + if (compareAndSet(tags, newTags)) return computed + + // We lost the race. Possibly to other code that was putting a *different* key. Try again! + } +} diff --git a/okhttp/src/jvmTest/kotlin/okhttp3/internal/TagsTest.kt b/okhttp/src/jvmTest/kotlin/okhttp3/internal/TagsTest.kt index 66562c64a..748d08c4a 100644 --- a/okhttp/src/jvmTest/kotlin/okhttp3/internal/TagsTest.kt +++ b/okhttp/src/jvmTest/kotlin/okhttp3/internal/TagsTest.kt @@ -18,6 +18,7 @@ package okhttp3.internal import assertk.assertThat import assertk.assertions.isEqualTo import assertk.assertions.isNull +import java.util.concurrent.atomic.AtomicReference import org.junit.jupiter.api.Test class TagsTest { @@ -157,4 +158,67 @@ class TagsTest { assertThat(tags[String::class]).isEqualTo("a") assertThat(tags.toString()).isEqualTo("{class kotlin.String=a}") } + + @Test + fun computeIfAbsentWhenEmpty() { + val tags = EmptyTags + val atomicTags = AtomicReference(tags) + assertThat(atomicTags.computeIfAbsent(String::class) { "a" }).isEqualTo("a") + assertThat(atomicTags.get()[String::class]).isEqualTo("a") + } + + @Test + fun computeIfAbsentWhenPresent() { + val tags = EmptyTags.plus(String::class, "a") + val atomicTags = AtomicReference(tags) + assertThat(atomicTags.computeIfAbsent(String::class) { "b" }).isEqualTo("a") + assertThat(atomicTags.get()[String::class]).isEqualTo("a") + } + + @Test + fun computeIfAbsentWhenDifferentKeyRaceLostDuringCompute() { + val tags = EmptyTags + val atomicTags = AtomicReference(tags) + val result = + atomicTags.computeIfAbsent(String::class) { + // 'Race' by making another computeIfAbsent call. In practice this would be another thread. + assertThat(atomicTags.computeIfAbsent(Integer::class) { 5 as Integer }).isEqualTo(5) + "a" + } + assertThat(result).isEqualTo("a") + assertThat(atomicTags.get()[String::class]).isEqualTo("a") + assertThat(atomicTags.get()[Integer::class]).isEqualTo(5) + } + + @Test + fun computeIfAbsentWhenSameKeyRaceLostDuringCompute() { + val tags = EmptyTags + val atomicTags = AtomicReference(tags) + val result = + atomicTags.computeIfAbsent(String::class) { + // 'Race' by making another computeIfAbsent call. In practice this would be another thread. + assertThat(atomicTags.computeIfAbsent(String::class) { "b" }).isEqualTo("b") + "a" + } + assertThat(result).isEqualTo("b") + assertThat(atomicTags.get()[String::class]).isEqualTo("b") + } + + @Test + fun computeIfAbsentOnlyComputesOnceAfterRaceLost() { + var computeCount = 0 + val tags = EmptyTags + val atomicTags = AtomicReference(tags) + val result = + atomicTags.computeIfAbsent(String::class) { + computeCount++ + // 'Race' by making another computeIfAbsent call. In practice this would be another thread. + assertThat(atomicTags.computeIfAbsent(Integer::class) { 5 as Integer }).isEqualTo(5) + "a" + } + assertThat(result).isEqualTo("a") + assertThat(computeCount).isEqualTo(1) + assertThat(atomicTags.get()[Integer::class]).isEqualTo(5) + assertThat(atomicTags.get()[String::class]).isEqualTo("a") + } }