1
0
mirror of https://github.com/square/okhttp.git synced 2025-07-31 05:04:26 +03:00

Reformat with Spotless (#8180)

* Enable spotless

* Run spotlessApply

* Fixup trimMargin

* Re-run spotlessApply
This commit is contained in:
Jesse Wilson
2024-01-07 20:13:22 -05:00
committed by GitHub
parent 0e312d7804
commit a228fd64cc
442 changed files with 24992 additions and 18542 deletions

View File

@ -2,9 +2,10 @@ root = true
[*]
indent_size = 2
ij_continuation_indent_size = 2
charset = utf-8
trim_trailing_whitespace = true
insert_final_newline = true
[*.{kt, kts}]
kotlin_imports_layout = ascii
ij_kotlin_imports_layout = *

View File

@ -24,7 +24,6 @@ import org.junit.Test
* Run with "./gradlew :android-test-app:connectedCheck -PandroidBuild=true" and make sure ANDROID_SDK_ROOT is set.
*/
class PublicSuffixDatabaseTest {
@Test
fun testTopLevelDomain() {
assertThat("https://www.google.com/robots.txt".toHttpUrl().topPrivateDomain()).isEqualTo("google.com")

View File

@ -19,20 +19,13 @@ import android.os.Bundle
import androidx.activity.ComponentActivity
import okhttp3.Call
import okhttp3.Callback
import okhttp3.CookieJar
import okhttp3.HttpUrl.Companion.toHttpUrl
import okhttp3.OkHttpClient
import okhttp3.Request
import okhttp3.Response
import okhttp3.internal.publicsuffix.PublicSuffixDatabase
import okio.FileSystem
import okio.IOException
import okio.Path.Companion.toPath
import java.net.CookieHandler
import java.net.URI
class MainActivity : ComponentActivity() {
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
@ -41,15 +34,23 @@ class MainActivity : ComponentActivity() {
val url = "https://github.com/square/okhttp".toHttpUrl()
println(url.topPrivateDomain())
client.newCall(Request(url)).enqueue(object : Callback {
override fun onFailure(call: Call, e: IOException) {
client.newCall(Request(url)).enqueue(
object : Callback {
override fun onFailure(
call: Call,
e: IOException,
) {
println("failed: $e")
}
override fun onResponse(call: Call, response: Response) {
override fun onResponse(
call: Call,
response: Response,
) {
println("response: ${response.code}")
response.close()
}
})
},
)
}
}

View File

@ -17,5 +17,4 @@ package okhttp.android.testapp
import android.app.Application
class TestApplication: Application() {
}
class TestApplication : Application()

View File

@ -21,6 +21,25 @@ import com.google.android.gms.common.GooglePlayServicesNotAvailableException
import com.google.android.gms.security.ProviderInstaller
import com.squareup.moshi.Moshi
import com.squareup.moshi.kotlin.reflect.KotlinJsonAdapterFactory
import java.io.IOException
import java.net.InetAddress
import java.net.UnknownHostException
import java.security.KeyStore
import java.security.SecureRandom
import java.security.Security
import java.security.cert.Certificate
import java.security.cert.CertificateException
import java.security.cert.X509Certificate
import java.util.concurrent.atomic.AtomicInteger
import java.util.logging.Handler
import java.util.logging.Level
import java.util.logging.LogRecord
import java.util.logging.Logger
import javax.net.ssl.HostnameVerifier
import javax.net.ssl.SSLPeerUnverifiedException
import javax.net.ssl.SSLSocket
import javax.net.ssl.TrustManagerFactory
import javax.net.ssl.X509TrustManager
import mockwebserver3.MockResponse
import mockwebserver3.MockWebServer
import mockwebserver3.junit5.internal.MockWebServerExtension
@ -31,6 +50,7 @@ import okhttp3.Connection
import okhttp3.DelegatingSSLSocket
import okhttp3.DelegatingSSLSocketFactory
import okhttp3.EventListener
import okhttp3.Headers
import okhttp3.HttpUrl.Companion.toHttpUrl
import okhttp3.OkHttpClient
import okhttp3.OkHttpClientTestRule
@ -59,33 +79,13 @@ import org.junit.jupiter.api.Assertions.assertNull
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Assertions.fail
import org.junit.jupiter.api.Assumptions.assumeTrue
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Disabled
import org.junit.jupiter.api.Tag
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.extension.ExtendWith
import org.junit.jupiter.api.extension.RegisterExtension
import org.opentest4j.TestAbortedException
import java.io.IOException
import java.net.InetAddress
import java.net.UnknownHostException
import java.security.KeyStore
import java.security.SecureRandom
import java.security.Security
import java.security.cert.Certificate
import java.security.cert.CertificateException
import java.security.cert.X509Certificate
import java.util.concurrent.atomic.AtomicInteger
import java.util.logging.Handler
import java.util.logging.Level
import java.util.logging.LogRecord
import java.util.logging.Logger
import javax.net.ssl.HostnameVerifier
import javax.net.ssl.SSLPeerUnverifiedException
import javax.net.ssl.SSLSocket
import javax.net.ssl.TrustManagerFactory
import javax.net.ssl.X509TrustManager
import okhttp3.Headers
import org.junit.jupiter.api.BeforeEach
/**
* Run with "./gradlew :android-test:connectedCheck -PandroidBuild=true" and make sure ANDROID_SDK_ROOT is set.
@ -93,20 +93,23 @@ import org.junit.jupiter.api.BeforeEach
@ExtendWith(MockWebServerExtension::class)
@Tag("Slow")
class OkHttpTest {
@Suppress("RedundantVisibilityModifier")
@JvmField
@RegisterExtension
public val platform = PlatformRule()
@Suppress("RedundantVisibilityModifier")
@JvmField
@RegisterExtension public val platform = PlatformRule()
@Suppress("RedundantVisibilityModifier")
@JvmField
@RegisterExtension public val clientTestRule = OkHttpClientTestRule().apply {
@RegisterExtension
public val clientTestRule =
OkHttpClientTestRule().apply {
logger = Logger.getLogger(OkHttpTest::class.java.name)
}
private var client: OkHttpClient = clientTestRule.newClient()
private val moshi = Moshi.Builder()
private val moshi =
Moshi.Builder()
.add(KotlinJsonAdapterFactory())
.build()
@ -136,7 +139,8 @@ class OkHttpTest {
val request = Request.Builder().url("https://api.twitter.com/robots.txt").build()
val clientCertificates = HandshakeCertificates.Builder()
val clientCertificates =
HandshakeCertificates.Builder()
.addPlatformTrustedCertificates()
.apply {
if (Build.VERSION.SDK_INT >= 24) {
@ -145,7 +149,8 @@ class OkHttpTest {
}
.build()
client = client.newBuilder()
client =
client.newBuilder()
.sslSocketFactory(clientCertificates.sslSocketFactory(), clientCertificates.trustManager)
.build()
@ -184,19 +189,26 @@ class OkHttpTest {
var socketClass: String? = null
val clientCertificates = HandshakeCertificates.Builder()
val clientCertificates =
HandshakeCertificates.Builder()
.addPlatformTrustedCertificates()
.addInsecureHost(server.hostName)
.build()
// Need fresh client to reset sslSocketFactoryOrNull
client = OkHttpClient.Builder()
client =
OkHttpClient.Builder()
.eventListenerFactory(
clientTestRule.wrap(object : EventListener() {
override fun connectionAcquired(call: Call, connection: Connection) {
clientTestRule.wrap(
object : EventListener() {
override fun connectionAcquired(
call: Call,
connection: Connection,
) {
socketClass = connection.socket().javaClass.name
}
})
},
),
)
.sslSocketFactory(clientCertificates.sslSocketFactory(), clientCertificates.trustManager)
.build()
@ -240,7 +252,8 @@ class OkHttpTest {
throw TestAbortedException("Google Play Services not available", gpsnae)
}
val clientCertificates = HandshakeCertificates.Builder()
val clientCertificates =
HandshakeCertificates.Builder()
.addPlatformTrustedCertificates()
.addInsecureHost(server.hostName)
.build()
@ -250,13 +263,19 @@ class OkHttpTest {
var socketClass: String? = null
// Need fresh client to reset sslSocketFactoryOrNull
client = OkHttpClient.Builder()
client =
OkHttpClient.Builder()
.eventListenerFactory(
clientTestRule.wrap(object : EventListener() {
override fun connectionAcquired(call: Call, connection: Connection) {
clientTestRule.wrap(
object : EventListener() {
override fun connectionAcquired(
call: Call,
connection: Connection,
) {
socketClass = connection.socket().javaClass.name
}
})
},
),
)
.sslSocketFactory(clientCertificates.sslSocketFactory(), clientCertificates.trustManager)
.build()
@ -298,7 +317,8 @@ class OkHttpTest {
var socketClass: String? = null
val clientCertificates = HandshakeCertificates.Builder()
val clientCertificates =
HandshakeCertificates.Builder()
.addPlatformTrustedCertificates().apply {
if (Build.VERSION.SDK_INT >= 24) {
addInsecureHost(server.hostName)
@ -306,13 +326,19 @@ class OkHttpTest {
}
.build()
client = client.newBuilder()
client =
client.newBuilder()
.eventListenerFactory(
clientTestRule.wrap(object : EventListener() {
override fun connectionAcquired(call: Call, connection: Connection) {
clientTestRule.wrap(
object : EventListener() {
override fun connectionAcquired(
call: Call,
connection: Connection,
) {
socketClass = connection.socket().javaClass.name
}
})
},
),
)
.sslSocketFactory(clientCertificates.sslSocketFactory(), clientCertificates.trustManager)
.build()
@ -372,7 +398,7 @@ class OkHttpTest {
val tls_version: String,
val able_to_detect_n_minus_one_splitting: Boolean,
val insecure_cipher_suites: Map<String, List<String>>,
val given_cipher_suites: List<String>?
val given_cipher_suites: List<String>?,
)
@Test
@ -384,7 +410,8 @@ class OkHttpTest {
val response = client.newCall(request).execute()
val results = response.use {
val results =
response.use {
moshi.adapter(HowsMySslResults::class.java).fromJson(response.body.string())!!
}
@ -417,7 +444,7 @@ class OkHttpTest {
assertTrue(tlsVersion == TlsVersion.TLS_1_2 || tlsVersion == TlsVersion.TLS_1_3)
assertEquals(
"CN=localhost",
(response.handshake!!.peerCertificates.first() as X509Certificate).subjectDN.name
(response.handshake!!.peerCertificates.first() as X509Certificate).subjectDN.name,
)
}
}
@ -426,7 +453,8 @@ class OkHttpTest {
fun testCertificatePinningFailure() {
enableTls()
val certificatePinner = CertificatePinner.Builder()
val certificatePinner =
CertificatePinner.Builder()
.add(server.hostName, "sha256/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=")
.build()
client = client.newBuilder().certificatePinner(certificatePinner).build()
@ -446,10 +474,11 @@ class OkHttpTest {
fun testCertificatePinningSuccess() {
enableTls()
val certificatePinner = CertificatePinner.Builder()
val certificatePinner =
CertificatePinner.Builder()
.add(
server.hostName,
CertificatePinner.pin(handshakeCertificates.trustManager.acceptedIssuers[0])
CertificatePinner.pin(handshakeCertificates.trustManager.acceptedIssuers[0]),
)
.build()
client = client.newBuilder().certificatePinner(certificatePinner).build()
@ -488,9 +517,9 @@ class OkHttpTest {
"ConnectStart", "SecureConnectStart", "SecureConnectEnd", "ConnectEnd",
"ConnectionAcquired", "RequestHeadersStart", "RequestHeadersEnd", "ResponseHeadersStart",
"ResponseHeadersEnd", "ResponseBodyStart", "ResponseBodyEnd", "ConnectionReleased",
"CallEnd"
"CallEnd",
),
eventListener.recordedEventTypes()
eventListener.recordedEventTypes(),
)
eventListener.clearAllEvents()
@ -504,9 +533,9 @@ class OkHttpTest {
"CallStart",
"ConnectionAcquired", "RequestHeadersStart", "RequestHeadersEnd", "ResponseHeadersStart",
"ResponseHeadersEnd", "ResponseBodyStart", "ResponseBodyEnd", "ConnectionReleased",
"CallEnd"
"CallEnd",
),
eventListener.recordedEventTypes()
eventListener.recordedEventTypes(),
)
}
@ -516,14 +545,20 @@ class OkHttpTest {
enableTls()
client = client.newBuilder().eventListenerFactory(
clientTestRule.wrap(object : EventListener() {
override fun connectionAcquired(call: Call, connection: Connection) {
client =
client.newBuilder().eventListenerFactory(
clientTestRule.wrap(
object : EventListener() {
override fun connectionAcquired(
call: Call,
connection: Connection,
) {
val sslSocket = connection.socket() as SSLSocket
sessionIds.add(sslSocket.session.id.toByteString().hex())
}
})
},
),
).build()
server.enqueue(MockResponse(body = "abc1"))
@ -550,7 +585,8 @@ class OkHttpTest {
fun testDnsOverHttps() {
assumeNetwork()
client = client.newBuilder()
client =
client.newBuilder()
.eventListenerFactory(clientTestRule.wrap(LoggingEventListener.Factory()))
.build()
@ -566,22 +602,31 @@ class OkHttpTest {
fun testCustomTrustManager() {
assumeNetwork()
val trustManager = object : X509TrustManager {
override fun checkClientTrusted(chain: Array<out X509Certificate>?, authType: String?) {}
val trustManager =
object : X509TrustManager {
override fun checkClientTrusted(
chain: Array<out X509Certificate>?,
authType: String?,
) {}
override fun checkServerTrusted(chain: Array<out X509Certificate>?, authType: String?) {}
override fun checkServerTrusted(
chain: Array<out X509Certificate>?,
authType: String?,
) {}
override fun getAcceptedIssuers(): Array<X509Certificate> = arrayOf()
}
val sslContext = Platform.get().newSSLContext().apply {
val sslContext =
Platform.get().newSSLContext().apply {
init(null, arrayOf(trustManager), null)
}
val sslSocketFactory = sslContext.socketFactory
val hostnameVerifier = HostnameVerifier { _, _ -> true }
client = client.newBuilder()
client =
client.newBuilder()
.sslSocketFactory(sslSocketFactory, trustManager)
.hostnameVerifier(hostnameVerifier)
.build()
@ -598,7 +643,8 @@ class OkHttpTest {
val sslSocketFactory = client.sslSocketFactory
val trustManager = client.x509TrustManager!!
val delegatingSocketFactory = object : DelegatingSSLSocketFactory(sslSocketFactory) {
val delegatingSocketFactory =
object : DelegatingSSLSocketFactory(sslSocketFactory) {
override fun configureSocket(sslSocket: SSLSocket): SSLSocket {
return object : DelegatingSSLSocket(sslSocket) {
override fun getApplicationProtocol(): String {
@ -608,7 +654,8 @@ class OkHttpTest {
}
}
client = client.newBuilder()
client =
client.newBuilder()
.sslSocketFactory(delegatingSocketFactory, trustManager)
.build()
@ -626,16 +673,27 @@ class OkHttpTest {
var withHostCalled = false
var withoutHostCalled = false
val trustManager = object : X509TrustManager {
override fun checkClientTrusted(chain: Array<out X509Certificate>?, authType: String?) {}
val trustManager =
object : X509TrustManager {
override fun checkClientTrusted(
chain: Array<out X509Certificate>?,
authType: String?,
) {}
override fun checkServerTrusted(chain: Array<out X509Certificate>?, authType: String?) {
override fun checkServerTrusted(
chain: Array<out X509Certificate>?,
authType: String?,
) {
withoutHostCalled = true
}
@Suppress("unused", "UNUSED_PARAMETER")
// called by Android via reflection in X509TrustManagerExtensions
fun checkServerTrusted(chain: Array<out X509Certificate>, authType: String, hostname: String): List<X509Certificate> {
@Suppress("unused", "UNUSED_PARAMETER")
fun checkServerTrusted(
chain: Array<out X509Certificate>,
authType: String,
hostname: String,
): List<X509Certificate> {
withHostCalled = true
return chain.toList()
}
@ -643,14 +701,16 @@ class OkHttpTest {
override fun getAcceptedIssuers(): Array<X509Certificate> = arrayOf()
}
val sslContext = Platform.get().newSSLContext().apply {
val sslContext =
Platform.get().newSSLContext().apply {
init(null, arrayOf(trustManager), null)
}
val sslSocketFactory = sslContext.socketFactory
val hostnameVerifier = HostnameVerifier { _, _ -> true }
client = client.newBuilder()
client =
client.newBuilder()
.sslSocketFactory(sslSocketFactory, trustManager)
.hostnameVerifier(hostnameVerifier)
.build()
@ -702,23 +762,31 @@ class OkHttpTest {
var socketClass: String? = null
val trustManager = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()).apply {
val trustManager =
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()).apply {
init(null as KeyStore?)
}.trustManagers.first() as X509TrustManager
val sslContext = Platform.get().newSSLContext().apply {
val sslContext =
Platform.get().newSSLContext().apply {
// TODO remove most of this code after https://github.com/bcgit/bc-java/issues/686
init(null, arrayOf(trustManager), SecureRandom())
}
client = client.newBuilder()
client =
client.newBuilder()
.sslSocketFactory(sslContext.socketFactory, trustManager)
.eventListenerFactory(
clientTestRule.wrap(object : EventListener() {
override fun connectionAcquired(call: Call, connection: Connection) {
clientTestRule.wrap(
object : EventListener() {
override fun connectionAcquired(
call: Call,
connection: Connection,
) {
socketClass = connection.socket().javaClass.name
}
})
},
),
)
.build()
@ -743,7 +811,8 @@ class OkHttpTest {
fun testLoggingLevels() {
enableTls()
val testHandler = object : Handler() {
val testHandler =
object : Handler() {
val calls = mutableMapOf<String, AtomicInteger>()
override fun publish(record: LogRecord) {
@ -773,11 +842,13 @@ class OkHttpTest {
server.enqueue(MockResponse(body = "abc"))
val request = Request.Builder()
val request =
Request.Builder()
.url(server.url("/"))
.build()
val response = client.newCall(request)
val response =
client.newCall(request)
.execute()
response.use {
@ -802,11 +873,13 @@ class OkHttpTest {
val cache = Cache(ctxt.cacheDir.resolve("testCache"), cacheSize)
try {
client = client.newBuilder()
client =
client.newBuilder()
.cache(cache)
.build()
val request = Request.Builder()
val request =
Request.Builder()
.url(server.url("/"))
.build()
@ -853,9 +926,10 @@ class OkHttpTest {
}
private fun enableTls() {
client = client.newBuilder()
client =
client.newBuilder()
.sslSocketFactory(
handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager
handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager,
)
.build()
server.useHttps(handshakeCertificates.sslSocketFactory())

View File

@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package okhttp.android.test.alpn;
package okhttp.android.test.alpn
import android.os.Build
import android.util.Log
@ -36,9 +36,8 @@ import org.junit.jupiter.api.Test
*/
@Tag("Remote")
class AlpnOverrideTest {
class CustomSSLSocketFactory(
delegate: SSLSocketFactory
delegate: SSLSocketFactory,
) : DelegatingSSLSocketFactory(delegate) {
override fun configureSocket(sslSocket: SSLSocket): SSLSocket {
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) {
@ -56,23 +55,32 @@ class AlpnOverrideTest {
@Test
fun getWithCustomSocketFactory() {
client = client.newBuilder()
client =
client.newBuilder()
.sslSocketFactory(CustomSSLSocketFactory(client.sslSocketFactory), client.x509TrustManager!!)
.connectionSpecs(listOf(
.connectionSpecs(
listOf(
ConnectionSpec.Builder(ConnectionSpec.MODERN_TLS)
.supportsTlsExtensions(false)
.build()
))
.eventListener(object : EventListener() {
override fun connectionAcquired(call: Call, connection: Connection) {
.build(),
),
)
.eventListener(
object : EventListener() {
override fun connectionAcquired(
call: Call,
connection: Connection,
) {
val sslSocket = connection.socket() as SSLSocket
println("Requested " + sslSocket.sslParameters.applicationProtocols.joinToString())
println("Negotiated " + sslSocket.applicationProtocol)
}
})
},
)
.build()
val request = Request.Builder()
val request =
Request.Builder()
.url("https://www.google.com")
.build()
client.newCall(request).execute().use { response ->

View File

@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package okhttp.android.test.letsencrypt;
package okhttp.android.test.letsencrypt
import android.os.Build
import assertk.assertThat
@ -41,7 +41,8 @@ class LetsEncryptClientTest {
val clientBuilder = OkHttpClient.Builder()
if (androidMorEarlier) {
val cert: X509Certificate = """
val cert: X509Certificate =
"""
-----BEGIN CERTIFICATE-----
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
@ -75,20 +76,24 @@ class LetsEncryptClientTest {
-----END CERTIFICATE-----
""".trimIndent().decodeCertificatePem()
val handshakeCertificates = HandshakeCertificates.Builder()
val handshakeCertificates =
HandshakeCertificates.Builder()
// TODO reenable in official answers
// .addPlatformTrustedCertificates()
.addTrustedCertificate(cert)
.build()
clientBuilder
.sslSocketFactory(handshakeCertificates.sslSocketFactory(),
handshakeCertificates.trustManager)
.sslSocketFactory(
handshakeCertificates.sslSocketFactory(),
handshakeCertificates.trustManager,
)
}
val client = clientBuilder.build()
val request = Request.Builder()
val request =
Request.Builder()
.url("https://valid-isrgrootx1.letsencrypt.org/robots.txt")
.build()
client.newCall(request).execute().use { response ->

View File

@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package okhttp.android.test.sni;
package okhttp.android.test.sni
import android.os.Build
import android.util.Log
@ -39,7 +39,8 @@ import org.junit.jupiter.api.Test
*/
@Tag("Remote")
class SniOverrideTest {
var client = OkHttpClient.Builder()
var client =
OkHttpClient.Builder()
.build()
@Test
@ -47,7 +48,7 @@ class SniOverrideTest {
assumeTrue(Build.VERSION.SDK_INT >= 24)
class CustomSSLSocketFactory(
delegate: SSLSocketFactory
delegate: SSLSocketFactory,
) : DelegatingSSLSocketFactory(delegate) {
override fun configureSocket(sslSocket: SSLSocket): SSLSocket {
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) {
@ -62,7 +63,8 @@ class SniOverrideTest {
}
}
client = client.newBuilder()
client =
client.newBuilder()
.sslSocketFactory(CustomSSLSocketFactory(client.sslSocketFactory), client.x509TrustManager!!)
.hostnameVerifier { hostname, session ->
val s = "hostname: $hostname peerHost:${session.peerHost}"
@ -81,7 +83,8 @@ class SniOverrideTest {
}
.build()
val request = Request.Builder()
val request =
Request.Builder()
.url("https://sni.cloudflaressl.com/cdn-cgi/trace")
.header("Host", "cloudflare-dns.com")
.build()
@ -89,20 +92,21 @@ class SniOverrideTest {
assertThat(response.code).isEqualTo(200)
assertThat(response.protocol).isEqualTo(Protocol.HTTP_2)
assertThat(response.body.string()).contains("h=cloudflare-dns.com")
}
}
@Test
fun getWithDns() {
client = client.newBuilder()
client =
client.newBuilder()
.dns {
Dns.SYSTEM.lookup("sni.cloudflaressl.com")
}
.build()
val request = Request.Builder()
val request =
Request.Builder()
.url("https://cloudflare-dns.com/cdn-cgi/trace")
.build()
client.newCall(request).execute().use { response ->

View File

@ -1,14 +1,14 @@
@file:Suppress("UnstableApiUsage")
import com.diffplug.gradle.spotless.SpotlessExtension
import com.vanniktech.maven.publish.MavenPublishBaseExtension
import com.vanniktech.maven.publish.SonatypeHost
import java.net.URL
import java.net.URI
import kotlinx.validation.ApiValidationExtension
import org.gradle.api.tasks.testing.logging.TestExceptionFormat
import org.jetbrains.dokka.gradle.DokkaTaskPartial
import org.jetbrains.kotlin.gradle.tasks.KotlinCompile
import ru.vyarus.gradle.plugin.animalsniffer.AnimalSnifferExtension
import java.net.URI
buildscript {
dependencies {
@ -35,6 +35,14 @@ buildscript {
}
apply(plugin = "org.jetbrains.dokka")
apply(plugin = "com.diffplug.spotless")
configure<SpotlessExtension> {
kotlin {
target("**/*.kt")
ktlint()
}
}
allprojects {
group = "com.squareup.okhttp3"

View File

@ -34,11 +34,12 @@ fun Project.applyOsgi(vararg bndProperties: String) {
private fun Project.applyOsgi(
jarTaskName: String,
osgiApiConfigurationName: String,
bndProperties: Array<out String>
bndProperties: Array<out String>,
) {
val osgi = project.sourceSets.create("osgi")
val osgiApi = project.configurations.getByName(osgiApiConfigurationName)
val kotlinOsgi = extensions.getByType(VersionCatalogsExtension::class.java).named("libs")
val kotlinOsgi =
extensions.getByType(VersionCatalogsExtension::class.java).named("libs")
.findLibrary("kotlin.stdlib.osgi").get().get()
project.dependencies {
@ -46,8 +47,9 @@ private fun Project.applyOsgi(
}
val jarTask = tasks.getByName<Jar>(jarTaskName)
val bundleExtension = jarTask.extensions.findByType() ?: jarTask.extensions.create(
BundleTaskExtension.NAME, BundleTaskExtension::class.java, jarTask
val bundleExtension =
jarTask.extensions.findByType() ?: jarTask.extensions.create(
BundleTaskExtension.NAME, BundleTaskExtension::class.java, jarTask,
)
bundleExtension.run {
setClasspath(osgi.compileClasspath + sourceSets["main"].compileClasspath)

View File

@ -36,9 +36,7 @@ internal fun Dispatcher.wrap(): mockwebserver3.Dispatcher {
val delegate = this
return object : mockwebserver3.Dispatcher() {
override fun dispatch(
request: mockwebserver3.RecordedRequest
): mockwebserver3.MockResponse {
override fun dispatch(request: mockwebserver3.RecordedRequest): mockwebserver3.MockResponse {
return delegate.dispatch(request.unwrap()).wrap()
}
@ -70,7 +68,8 @@ internal fun MockResponse.wrap(): mockwebserver3.MockResponse {
result.status = status
result.headers(headers)
result.trailers(trailers)
result.socketPolicy = when (socketPolicy) {
result.socketPolicy =
when (socketPolicy) {
SocketPolicy.EXPECT_CONTINUE, SocketPolicy.CONTINUE_ALWAYS -> {
result.add100Continue()
KeepOpen
@ -92,7 +91,7 @@ private fun PushPromise.wrap(): mockwebserver3.PushPromise {
method = method,
path = path,
headers = headers,
response = response.wrap()
response = response.wrap(),
)
}
@ -108,7 +107,7 @@ internal fun mockwebserver3.RecordedRequest.unwrap(): RecordedRequest {
method = method,
path = path,
handshake = handshake,
requestUrl = requestUrl
requestUrl = requestUrl,
)
}

View File

@ -14,14 +14,15 @@
* limitations under the License.
*/
@file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")
package okhttp3.mockwebserver
import java.util.concurrent.TimeUnit
import okhttp3.Headers
import okhttp3.WebSocketListener
import okhttp3.internal.addHeaderLenient
import okhttp3.internal.http2.Settings
import okio.Buffer
import java.util.concurrent.TimeUnit
class MockResponse : Cloneable {
@set:JvmName("status")
@ -88,15 +89,18 @@ class MockResponse : Cloneable {
@Deprecated(
message = "moved to var",
replaceWith = ReplaceWith(expression = "status"),
level = DeprecationLevel.ERROR)
level = DeprecationLevel.ERROR,
)
fun getStatus(): String = status
fun setStatus(status: String) = apply {
fun setStatus(status: String) =
apply {
this.status = status
}
fun setResponseCode(code: Int): MockResponse {
val reason = when (code) {
val reason =
when (code) {
in 100..199 -> "Informational"
in 200..299 -> "OK"
in 300..399 -> "Redirection"
@ -107,41 +111,57 @@ class MockResponse : Cloneable {
return apply { status = "HTTP/1.1 $code $reason" }
}
fun clearHeaders() = apply {
fun clearHeaders() =
apply {
headersBuilder = Headers.Builder()
}
fun addHeader(header: String) = apply {
fun addHeader(header: String) =
apply {
headersBuilder.add(header)
}
fun addHeader(name: String, value: Any) = apply {
fun addHeader(
name: String,
value: Any,
) = apply {
headersBuilder.add(name, value.toString())
}
fun addHeaderLenient(name: String, value: Any) = apply {
fun addHeaderLenient(
name: String,
value: Any,
) = apply {
addHeaderLenient(headersBuilder, name, value.toString())
}
fun setHeader(name: String, value: Any) = apply {
fun setHeader(
name: String,
value: Any,
) = apply {
removeHeader(name)
addHeader(name, value)
}
fun removeHeader(name: String) = apply {
fun removeHeader(name: String) =
apply {
headersBuilder.removeAll(name)
}
fun getBody(): Buffer? = body?.clone()
fun setBody(body: Buffer) = apply {
fun setBody(body: Buffer) =
apply {
setHeader("Content-Length", body.size)
this.body = body.clone() // Defensive copy.
}
fun setBody(body: String): MockResponse = setBody(Buffer().writeUtf8(body))
fun setChunkedBody(body: Buffer, maxChunkSize: Int) = apply {
fun setChunkedBody(
body: Buffer,
maxChunkSize: Int,
) = apply {
removeHeader("Content-Length")
headersBuilder.add(CHUNKED_BODY_HEADER)
@ -157,14 +177,17 @@ class MockResponse : Cloneable {
this.body = bytesOut
}
fun setChunkedBody(body: String, maxChunkSize: Int): MockResponse =
setChunkedBody(Buffer().writeUtf8(body), maxChunkSize)
fun setChunkedBody(
body: String,
maxChunkSize: Int,
): MockResponse = setChunkedBody(Buffer().writeUtf8(body), maxChunkSize)
@JvmName("-deprecated_getHeaders")
@Deprecated(
message = "moved to var",
replaceWith = ReplaceWith(expression = "headers"),
level = DeprecationLevel.ERROR)
level = DeprecationLevel.ERROR,
)
fun getHeaders(): Headers = headers
fun setHeaders(headers: Headers) = apply { this.headers = headers }
@ -173,7 +196,8 @@ class MockResponse : Cloneable {
@Deprecated(
message = "moved to var",
replaceWith = ReplaceWith(expression = "trailers"),
level = DeprecationLevel.ERROR)
level = DeprecationLevel.ERROR,
)
fun getTrailers(): Headers = trailers
fun setTrailers(trailers: Headers) = apply { this.trailers = trailers }
@ -182,10 +206,12 @@ class MockResponse : Cloneable {
@Deprecated(
message = "moved to var",
replaceWith = ReplaceWith(expression = "socketPolicy"),
level = DeprecationLevel.ERROR)
level = DeprecationLevel.ERROR,
)
fun getSocketPolicy(): SocketPolicy = socketPolicy
fun setSocketPolicy(socketPolicy: SocketPolicy) = apply {
fun setSocketPolicy(socketPolicy: SocketPolicy) =
apply {
this.socketPolicy = socketPolicy
}
@ -193,47 +219,59 @@ class MockResponse : Cloneable {
@Deprecated(
message = "moved to var",
replaceWith = ReplaceWith(expression = "http2ErrorCode"),
level = DeprecationLevel.ERROR)
level = DeprecationLevel.ERROR,
)
fun getHttp2ErrorCode(): Int = http2ErrorCode
fun setHttp2ErrorCode(http2ErrorCode: Int) = apply {
fun setHttp2ErrorCode(http2ErrorCode: Int) =
apply {
this.http2ErrorCode = http2ErrorCode
}
fun throttleBody(bytesPerPeriod: Long, period: Long, unit: TimeUnit) = apply {
fun throttleBody(
bytesPerPeriod: Long,
period: Long,
unit: TimeUnit,
) = apply {
throttleBytesPerPeriod = bytesPerPeriod
throttlePeriodAmount = period
throttlePeriodUnit = unit
}
fun getThrottlePeriod(unit: TimeUnit): Long =
unit.convert(throttlePeriodAmount, throttlePeriodUnit)
fun getThrottlePeriod(unit: TimeUnit): Long = unit.convert(throttlePeriodAmount, throttlePeriodUnit)
fun setBodyDelay(delay: Long, unit: TimeUnit) = apply {
fun setBodyDelay(
delay: Long,
unit: TimeUnit,
) = apply {
bodyDelayAmount = delay
bodyDelayUnit = unit
}
fun getBodyDelay(unit: TimeUnit): Long =
unit.convert(bodyDelayAmount, bodyDelayUnit)
fun getBodyDelay(unit: TimeUnit): Long = unit.convert(bodyDelayAmount, bodyDelayUnit)
fun setHeadersDelay(delay: Long, unit: TimeUnit) = apply {
fun setHeadersDelay(
delay: Long,
unit: TimeUnit,
) = apply {
headersDelayAmount = delay
headersDelayUnit = unit
}
fun getHeadersDelay(unit: TimeUnit): Long =
unit.convert(headersDelayAmount, headersDelayUnit)
fun getHeadersDelay(unit: TimeUnit): Long = unit.convert(headersDelayAmount, headersDelayUnit)
fun withPush(promise: PushPromise) = apply {
fun withPush(promise: PushPromise) =
apply {
promises.add(promise)
}
fun withSettings(settings: Settings) = apply {
fun withSettings(settings: Settings) =
apply {
this.settings = settings
}
fun withWebSocketUpgrade(listener: WebSocketListener) = apply {
fun withWebSocketUpgrade(listener: WebSocketListener) =
apply {
status = "HTTP/1.1 101 Switching Protocols"
setHeader("Connection", "Upgrade")
setHeader("Upgrade", "websocket")

View File

@ -15,9 +15,6 @@
*/
package okhttp3.mockwebserver
import okhttp3.HttpUrl
import okhttp3.Protocol
import org.junit.rules.ExternalResource
import java.io.Closeable
import java.io.IOException
import java.net.InetAddress
@ -27,6 +24,9 @@ import java.util.logging.Level
import java.util.logging.Logger
import javax.net.ServerSocketFactory
import javax.net.ssl.SSLSocketFactory
import okhttp3.HttpUrl
import okhttp3.Protocol
import org.junit.rules.ExternalResource
class MockWebServer : ExternalResource(), Closeable {
val delegate = mockwebserver3.MockWebServer()
@ -57,7 +57,8 @@ class MockWebServer : ExternalResource(), Closeable {
var protocolNegotiationEnabled: Boolean by delegate::protocolNegotiationEnabled
@get:JvmName("protocols") var protocols: List<Protocol>
@get:JvmName("protocols")
var protocols: List<Protocol>
get() = delegate.protocols
set(value) {
delegate.protocols = value
@ -82,7 +83,8 @@ class MockWebServer : ExternalResource(), Closeable {
@Deprecated(
message = "moved to val",
replaceWith = ReplaceWith(expression = "port"),
level = DeprecationLevel.ERROR)
level = DeprecationLevel.ERROR,
)
fun getPort(): Int = port
fun toProxyAddress(): Proxy {
@ -93,10 +95,12 @@ class MockWebServer : ExternalResource(), Closeable {
@JvmName("-deprecated_serverSocketFactory")
@Deprecated(
message = "moved to var",
replaceWith = ReplaceWith(
expression = "run { this.serverSocketFactory = serverSocketFactory }"
replaceWith =
ReplaceWith(
expression = "run { this.serverSocketFactory = serverSocketFactory }",
),
level = DeprecationLevel.ERROR)
level = DeprecationLevel.ERROR,
)
fun setServerSocketFactory(serverSocketFactory: ServerSocketFactory) {
delegate.serverSocketFactory = serverSocketFactory
}
@ -109,10 +113,12 @@ class MockWebServer : ExternalResource(), Closeable {
@JvmName("-deprecated_bodyLimit")
@Deprecated(
message = "moved to var",
replaceWith = ReplaceWith(
expression = "run { this.bodyLimit = bodyLimit }"
replaceWith =
ReplaceWith(
expression = "run { this.bodyLimit = bodyLimit }",
),
level = DeprecationLevel.ERROR)
level = DeprecationLevel.ERROR,
)
fun setBodyLimit(bodyLimit: Long) {
delegate.bodyLimit = bodyLimit
}
@ -120,10 +126,12 @@ class MockWebServer : ExternalResource(), Closeable {
@JvmName("-deprecated_protocolNegotiationEnabled")
@Deprecated(
message = "moved to var",
replaceWith = ReplaceWith(
expression = "run { this.protocolNegotiationEnabled = protocolNegotiationEnabled }"
replaceWith =
ReplaceWith(
expression = "run { this.protocolNegotiationEnabled = protocolNegotiationEnabled }",
),
level = DeprecationLevel.ERROR)
level = DeprecationLevel.ERROR,
)
fun setProtocolNegotiationEnabled(protocolNegotiationEnabled: Boolean) {
delegate.protocolNegotiationEnabled = protocolNegotiationEnabled
}
@ -132,7 +140,8 @@ class MockWebServer : ExternalResource(), Closeable {
@Deprecated(
message = "moved to var",
replaceWith = ReplaceWith(expression = "run { this.protocols = protocols }"),
level = DeprecationLevel.ERROR)
level = DeprecationLevel.ERROR,
)
fun setProtocols(protocols: List<Protocol>) {
delegate.protocols = protocols
}
@ -141,10 +150,14 @@ class MockWebServer : ExternalResource(), Closeable {
@Deprecated(
message = "moved to var",
replaceWith = ReplaceWith(expression = "protocols"),
level = DeprecationLevel.ERROR)
level = DeprecationLevel.ERROR,
)
fun protocols(): List<Protocol> = delegate.protocols
fun useHttps(sslSocketFactory: SSLSocketFactory, tunnelProxy: Boolean) {
fun useHttps(
sslSocketFactory: SSLSocketFactory,
tunnelProxy: Boolean,
) {
delegate.useHttps(sslSocketFactory)
}
@ -166,7 +179,10 @@ class MockWebServer : ExternalResource(), Closeable {
}
@Throws(InterruptedException::class)
fun takeRequest(timeout: Long, unit: TimeUnit): RecordedRequest? {
fun takeRequest(
timeout: Long,
unit: TimeUnit,
): RecordedRequest? {
return delegate.takeRequest(timeout, unit)?.unwrap()
}
@ -174,7 +190,8 @@ class MockWebServer : ExternalResource(), Closeable {
@Deprecated(
message = "moved to val",
replaceWith = ReplaceWith(expression = "requestCount"),
level = DeprecationLevel.ERROR)
level = DeprecationLevel.ERROR,
)
fun getRequestCount(): Int = delegate.requestCount
fun enqueue(response: MockResponse) {
@ -182,13 +199,17 @@ class MockWebServer : ExternalResource(), Closeable {
}
@Throws(IOException::class)
@JvmOverloads fun start(port: Int = 0) {
@JvmOverloads
fun start(port: Int = 0) {
started = true
delegate.start(port)
}
@Throws(IOException::class)
fun start(inetAddress: InetAddress, port: Int) {
fun start(
inetAddress: InetAddress,
port: Int,
) {
started = true
delegate.start(inetAddress, port)
}

View File

@ -21,34 +21,37 @@ class PushPromise(
@get:JvmName("method") val method: String,
@get:JvmName("path") val path: String,
@get:JvmName("headers") val headers: Headers,
@get:JvmName("response") val response: MockResponse
@get:JvmName("response") val response: MockResponse,
) {
@JvmName("-deprecated_method")
@Deprecated(
message = "moved to val",
replaceWith = ReplaceWith(expression = "method"),
level = DeprecationLevel.ERROR)
level = DeprecationLevel.ERROR,
)
fun method(): String = method
@JvmName("-deprecated_path")
@Deprecated(
message = "moved to val",
replaceWith = ReplaceWith(expression = "path"),
level = DeprecationLevel.ERROR)
level = DeprecationLevel.ERROR,
)
fun path(): String = path
@JvmName("-deprecated_headers")
@Deprecated(
message = "moved to val",
replaceWith = ReplaceWith(expression = "headers"),
level = DeprecationLevel.ERROR)
level = DeprecationLevel.ERROR,
)
fun headers(): Headers = headers
@JvmName("-deprecated_response")
@Deprecated(
message = "moved to val",
replaceWith = ReplaceWith(expression = "response"),
level = DeprecationLevel.ERROR)
level = DeprecationLevel.ERROR,
)
fun response(): MockResponse = response
}

View File

@ -15,6 +15,10 @@
*/
package okhttp3.mockwebserver
import java.io.IOException
import java.net.Inet6Address
import java.net.Socket
import javax.net.ssl.SSLSocket
import okhttp3.Handshake
import okhttp3.Handshake.Companion.handshake
import okhttp3.Headers
@ -22,10 +26,6 @@ import okhttp3.HttpUrl
import okhttp3.HttpUrl.Companion.toHttpUrlOrNull
import okhttp3.TlsVersion
import okio.Buffer
import java.io.IOException
import java.net.Inet6Address
import java.net.Socket
import javax.net.ssl.SSLSocket
class RecordedRequest {
val requestLine: String
@ -44,7 +44,8 @@ class RecordedRequest {
@Deprecated(
message = "Use body.readUtf8()",
replaceWith = ReplaceWith("body.readUtf8()"),
level = DeprecationLevel.ERROR)
level = DeprecationLevel.ERROR,
)
val utf8Body: String
get() = body.readUtf8()
@ -62,7 +63,7 @@ class RecordedRequest {
method: String?,
path: String?,
handshake: Handshake?,
requestUrl: HttpUrl?
requestUrl: HttpUrl?,
) {
this.requestLine = requestLine
this.headers = headers
@ -86,7 +87,7 @@ class RecordedRequest {
body: Buffer,
sequenceNumber: Int,
socket: Socket,
failure: IOException? = null
failure: IOException? = null,
) {
this.requestLine = requestLine
this.headers = headers
@ -141,7 +142,8 @@ class RecordedRequest {
@Deprecated(
message = "Use body.readUtf8()",
replaceWith = ReplaceWith("body.readUtf8()"),
level = DeprecationLevel.WARNING)
level = DeprecationLevel.WARNING,
)
fun getUtf8Body(): String = body.readUtf8()
fun getHeader(name: String): String? = headers.values(name).firstOrNull()

View File

@ -32,5 +32,5 @@ enum class SocketPolicy {
NO_RESPONSE,
RESET_STREAM_AT_START,
EXPECT_CONTINUE,
CONTINUE_ALWAYS
CONTINUE_ALWAYS,
}

View File

@ -15,6 +15,12 @@
*/
package okhttp3.mockwebserver
import java.net.InetAddress
import java.net.Proxy
import java.net.Socket
import java.util.concurrent.TimeUnit
import javax.net.ServerSocketFactory
import javax.net.ssl.SSLSocketFactory
import okhttp3.Handshake
import okhttp3.Headers
import okhttp3.Headers.Companion.headersOf
@ -26,12 +32,6 @@ import okhttp3.internal.http2.Settings
import okio.Buffer
import org.junit.Ignore
import org.junit.Test
import java.net.InetAddress
import java.net.Proxy
import java.net.Socket
import java.util.concurrent.TimeUnit
import javax.net.ServerSocketFactory
import javax.net.ssl.SSLSocketFactory
/**
* Access every type, function, and property from Kotlin to defend against unexpected regressions in
@ -45,14 +45,17 @@ import javax.net.ssl.SSLSocketFactory
"VARIABLE_WITH_REDUNDANT_INITIALIZER",
"RedundantLambdaArrow",
"RedundantExplicitType",
"IMPLICIT_NOTHING_AS_TYPE_PARAMETER"
"IMPLICIT_NOTHING_AS_TYPE_PARAMETER",
)
class KotlinSourceModernTest {
@Test @Ignore
fun dispatcherFromMockWebServer() {
val dispatcher = object : Dispatcher() {
val dispatcher =
object : Dispatcher() {
override fun dispatch(request: RecordedRequest): MockResponse = TODO()
override fun peek(): MockResponse = TODO()
override fun shutdown() = TODO()
}
}
@ -96,8 +99,11 @@ class KotlinSourceModernTest {
mockResponse = mockResponse.withSettings(Settings())
var settings: Settings = mockResponse.settings
settings = mockResponse.settings
mockResponse = mockResponse.withWebSocketUpgrade(object : WebSocketListener() {
})
mockResponse =
mockResponse.withWebSocketUpgrade(
object : WebSocketListener() {
},
)
var webSocketListener: WebSocketListener? = mockResponse.webSocketListener
webSocketListener = mockResponse.webSocketListener
}
@ -146,8 +152,10 @@ class KotlinSourceModernTest {
@Test @Ignore
fun queueDispatcher() {
val queueDispatcher: QueueDispatcher = QueueDispatcher()
var mockResponse: MockResponse = queueDispatcher.dispatch(
RecordedRequest("", headersOf(), listOf(), 0L, Buffer(), 0, Socket()))
var mockResponse: MockResponse =
queueDispatcher.dispatch(
RecordedRequest("", headersOf(), listOf(), 0L, Buffer(), 0, Socket()),
)
mockResponse = queueDispatcher.peek()
queueDispatcher.enqueueResponse(MockResponse())
queueDispatcher.shutdown()
@ -157,8 +165,16 @@ class KotlinSourceModernTest {
@Test @Ignore
fun recordedRequest() {
var recordedRequest: RecordedRequest = RecordedRequest(
"", headersOf(), listOf(), 0L, Buffer(), 0, Socket())
var recordedRequest: RecordedRequest =
RecordedRequest(
"",
headersOf(),
listOf(),
0L,
Buffer(),
0,
Socket(),
)
recordedRequest = RecordedRequest("", headersOf(), listOf(), 0L, Buffer(), 0, Socket())
var requestUrl: HttpUrl? = recordedRequest.requestUrl
var requestLine: String = recordedRequest.requestLine

View File

@ -86,14 +86,15 @@ class MockWebServerTest {
@Test
fun setResponseMockReason() {
val reasons = arrayOf(
val reasons =
arrayOf(
"Mock Response",
"Informational",
"OK",
"Redirection",
"Client Error",
"Server Error",
"Mock Response"
"Mock Response",
)
for (i in 0..599) {
val response = MockResponse().setResponseCode(i)
@ -119,7 +120,8 @@ class MockWebServerTest {
@Test
fun mockResponseAddHeader() {
val response = MockResponse()
val response =
MockResponse()
.clearHeaders()
.addHeader("Cookie: s=square")
.addHeader("Cookie", "a=android")
@ -129,7 +131,8 @@ class MockWebServerTest {
@Test
fun mockResponseSetHeader() {
val response = MockResponse()
val response =
MockResponse()
.clearHeaders()
.addHeader("Cookie: s=square")
.addHeader("Cookie: a=android")
@ -141,7 +144,8 @@ class MockWebServerTest {
@Test
fun mockResponseSetHeaders() {
val response = MockResponse()
val response =
MockResponse()
.clearHeaders()
.addHeader("Cookie: s=square")
.addHeader("Cookies: delicious")
@ -173,7 +177,7 @@ class MockWebServerTest {
MockResponse()
.setResponseCode(HttpURLConnection.HTTP_MOVED_TEMP)
.addHeader("Location: " + server.url("/new-path"))
.setBody("This page has moved!")
.setBody("This page has moved!"),
)
server.enqueue(MockResponse().setBody("This is the new location!"))
val connection = server.url("/").toUrl().openConnection()
@ -211,7 +215,7 @@ class MockWebServerTest {
MockResponse()
.setBody("G\r\nxxxxxxxxxxxxxxxx\r\n0\r\n\r\n")
.clearHeaders()
.addHeader("Transfer-encoding: chunked")
.addHeader("Transfer-encoding: chunked"),
)
val connection = server.url("/").toUrl().openConnection()
val inputStream = connection.getInputStream()
@ -228,7 +232,7 @@ class MockWebServerTest {
MockResponse()
.setBody("ABC")
.clearHeaders()
.addHeader("Content-Length: 4")
.addHeader("Content-Length: 4"),
)
server.enqueue(MockResponse().setBody("DEF"))
val urlConnection = server.url("/").toUrl().openConnection()
@ -257,7 +261,7 @@ class MockWebServerTest {
fun disconnectAtStart() {
server.enqueue(
MockResponse()
.setSocketPolicy(SocketPolicy.DISCONNECT_AT_START)
.setSocketPolicy(SocketPolicy.DISCONNECT_AT_START),
)
server.enqueue(MockResponse()) // The jdk's HttpUrlConnection is a bastard.
server.enqueue(MockResponse())
@ -278,7 +282,7 @@ class MockWebServerTest {
assumeNotWindows()
server.enqueue(
MockResponse()
.throttleBody(3, 500, TimeUnit.MILLISECONDS)
.throttleBody(3, 500, TimeUnit.MILLISECONDS),
)
val startNanos = System.nanoTime()
val connection = server.url("/").toUrl().openConnection()
@ -301,7 +305,7 @@ class MockWebServerTest {
server.enqueue(
MockResponse()
.setBody("ABCDEF")
.throttleBody(3, 500, TimeUnit.MILLISECONDS)
.throttleBody(3, 500, TimeUnit.MILLISECONDS),
)
val startNanos = System.nanoTime()
val connection = server.url("/").toUrl().openConnection()
@ -325,7 +329,7 @@ class MockWebServerTest {
server.enqueue(
MockResponse()
.setBody("ABCDEF")
.setBodyDelay(1, TimeUnit.SECONDS)
.setBodyDelay(1, TimeUnit.SECONDS),
)
val startNanos = System.nanoTime()
val connection = server.url("/").toUrl().openConnection()
@ -373,7 +377,7 @@ class MockWebServerTest {
server.enqueue(
MockResponse()
.setBody("ab")
.setSocketPolicy(SocketPolicy.DISCONNECT_DURING_RESPONSE_BODY)
.setSocketPolicy(SocketPolicy.DISCONNECT_DURING_RESPONSE_BODY),
)
val connection = server.url("/").toUrl().openConnection()
assertThat(connection.getContentLength()).isEqualTo(2)
@ -443,12 +447,16 @@ class MockWebServerTest {
@Test
fun statementStartsAndStops() {
val called = AtomicBoolean()
val statement = server.apply(object : Statement() {
val statement =
server.apply(
object : Statement() {
override fun evaluate() {
called.set(true)
server.url("/").toUrl().openConnection().connect()
}
}, Description.EMPTY)
},
Description.EMPTY,
)
statement.evaluate()
assertThat(called.get()).isTrue()
try {
@ -484,7 +492,7 @@ class MockWebServerTest {
assertThat(reader.readLine()).isEqualTo("hello world")
val request = server.takeRequest()
assertThat(request.requestLine).isEqualTo(
"GET /a/deep/path?key=foo%20bar HTTP/1.1"
"GET /a/deep/path?key=foo%20bar HTTP/1.1",
)
val requestUrl = request.requestUrl
assertThat(requestUrl!!.scheme).isEqualTo("http")
@ -530,8 +538,8 @@ class MockWebServerTest {
fail<Any>()
} catch (expected: IllegalArgumentException) {
assertThat(expected.message).isEqualTo(
"protocols containing h2_prior_knowledge cannot use other protocols: "
+ "[h2_prior_knowledge, http/1.1]"
"protocols containing h2_prior_knowledge cannot use other protocols: " +
"[h2_prior_knowledge, http/1.1]",
)
}
}
@ -544,8 +552,8 @@ class MockWebServerTest {
fail<Any>()
} catch (expected: IllegalArgumentException) {
assertThat(expected.message).isEqualTo(
"protocols containing h2_prior_knowledge cannot use other protocols: "
+ "[h2_prior_knowledge, h2_prior_knowledge]"
"protocols containing h2_prior_knowledge cannot use other protocols: " +
"[h2_prior_knowledge, h2_prior_knowledge]",
)
}
}
@ -584,27 +592,33 @@ class MockWebServerTest {
fun httpsWithClientAuth() {
platform.assumeNotBouncyCastle()
platform.assumeNotConscrypt()
val clientCa = HeldCertificate.Builder()
val clientCa =
HeldCertificate.Builder()
.certificateAuthority(0)
.build()
val serverCa = HeldCertificate.Builder()
val serverCa =
HeldCertificate.Builder()
.certificateAuthority(0)
.build()
val serverCertificate = HeldCertificate.Builder()
val serverCertificate =
HeldCertificate.Builder()
.signedBy(serverCa)
.addSubjectAlternativeName(server.hostName)
.build()
val serverHandshakeCertificates = HandshakeCertificates.Builder()
val serverHandshakeCertificates =
HandshakeCertificates.Builder()
.addTrustedCertificate(clientCa.certificate)
.heldCertificate(serverCertificate)
.build()
server.useHttps(serverHandshakeCertificates.sslSocketFactory(), false)
server.enqueue(MockResponse().setBody("abc"))
server.requestClientAuth()
val clientCertificate = HeldCertificate.Builder()
val clientCertificate =
HeldCertificate.Builder()
.signedBy(clientCa)
.build()
val clientHandshakeCertificates = HandshakeCertificates.Builder()
val clientHandshakeCertificates =
HandshakeCertificates.Builder()
.addTrustedCertificate(serverCa.certificate)
.heldCertificate(clientCertificate)
.build()

View File

@ -15,11 +15,11 @@
*/
package mockwebserver3.junit4
import mockwebserver3.MockWebServer
import org.junit.rules.ExternalResource
import java.io.IOException
import java.util.logging.Level
import java.util.logging.Logger
import mockwebserver3.MockWebServer
import org.junit.rules.ExternalResource
/**
* Runs MockWebServer for the duration of a single test method.

View File

@ -28,12 +28,16 @@ class MockWebServerRuleTest {
@Test fun statementStartsAndStops() {
val rule = MockWebServerRule()
val called = AtomicBoolean()
val statement: Statement = rule.apply(object : Statement() {
val statement: Statement =
rule.apply(
object : Statement() {
override fun evaluate() {
called.set(true)
rule.server.url("/").toUrl().openConnection().connect()
}
}, Description.EMPTY)
},
Description.EMPTY,
)
statement.evaluate()
assertThat(called.get()).isTrue()
try {

View File

@ -38,10 +38,11 @@ import org.junit.jupiter.api.extension.ParameterResolver
* - The test lifecycle default (passed into test method, plus @BeforeEach, @AfterEach)
* - named instances with @MockWebServerInstance.
*/
class MockWebServerExtension
: BeforeEachCallback, AfterEachCallback, ParameterResolver {
class MockWebServerExtension :
BeforeEachCallback, AfterEachCallback, ParameterResolver {
private val ExtensionContext.resource: ServersForTest
get() = getStore(namespace).getOrComputeIfAbsent(this.uniqueId) {
get() =
getStore(namespace).getOrComputeIfAbsent(this.uniqueId) {
ServersForTest()
} as ServersForTest
@ -80,20 +81,21 @@ class MockWebServerExtension
@IgnoreJRERequirement
override fun supportsParameter(
parameterContext: ParameterContext,
extensionContext: ExtensionContext
extensionContext: ExtensionContext,
): Boolean {
// Not supported on constructors, or static contexts
return parameterContext.parameter.type === MockWebServer::class.java
&& extensionContext.testMethod.isPresent
return parameterContext.parameter.type === MockWebServer::class.java &&
extensionContext.testMethod.isPresent
}
@Suppress("NewApi")
override fun resolveParameter(
parameterContext: ParameterContext,
extensionContext: ExtensionContext
extensionContext: ExtensionContext,
): Any {
val nameAnnotation = parameterContext.findAnnotation(MockWebServerInstance::class.java)
val name = if (nameAnnotation.isPresent) {
val name =
if (nameAnnotation.isPresent) {
nameAnnotation.get().name
} else {
defaultName

View File

@ -16,5 +16,5 @@
package mockwebserver3.junit5.internal
annotation class MockWebServerInstance(
val name: String
val name: String,
)

View File

@ -35,7 +35,7 @@ class ExtensionMultipleInstancesTest {
fun setup(
defaultInstance: MockWebServer,
@MockWebServerInstance("A") instanceA: MockWebServer,
@MockWebServerInstance("B") instanceB: MockWebServer
@MockWebServerInstance("B") instanceB: MockWebServer,
) {
defaultInstancePort = defaultInstance.port
instanceAPort = instanceA.port
@ -51,7 +51,7 @@ class ExtensionMultipleInstancesTest {
fun tearDown(
defaultInstance: MockWebServer,
@MockWebServerInstance("A") instanceA: MockWebServer,
@MockWebServerInstance("B") instanceB: MockWebServer
@MockWebServerInstance("B") instanceB: MockWebServer,
) {
assertThat(defaultInstance.port).isEqualTo(defaultInstancePort)
assertThat(instanceA.port).isEqualTo(instanceAPort)
@ -62,7 +62,7 @@ class ExtensionMultipleInstancesTest {
fun testClient(
defaultInstance: MockWebServer,
@MockWebServerInstance("A") instanceA: MockWebServer,
@MockWebServerInstance("B") instanceB: MockWebServer
@MockWebServerInstance("B") instanceB: MockWebServer,
) {
assertThat(defaultInstance.port).isEqualTo(defaultInstancePort)
assertThat(instanceA.port).isEqualTo(instanceAPort)

View File

@ -14,6 +14,7 @@
* limitations under the License.
*/
@file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")
package mockwebserver3
import java.util.concurrent.TimeUnit
@ -28,7 +29,6 @@ import okio.Buffer
/** A scripted response to be replayed by the mock web server. */
class MockResponse {
/** Returns the HTTP response line, such as "HTTP/1.1 200 OK". */
val status: String
@ -76,14 +76,15 @@ class MockResponse {
body: String = "",
inTunnel: Boolean = false,
socketPolicy: SocketPolicy = KeepOpen,
) : this(Builder()
) : this(
Builder()
.apply {
this.code = code
this.headers.addAll(headers)
if (inTunnel) inTunnel()
this.body(body)
this.socketPolicy = socketPolicy
}
},
)
private constructor(builder: Builder) {
@ -101,7 +102,8 @@ class MockResponse {
this.bodyDelayNanos = builder.bodyDelayNanos
this.headersDelayNanos = builder.headersDelayNanos
this.pushPromises = builder.pushPromises.toList()
this.settings = Settings().apply {
this.settings =
Settings().apply {
merge(builder.settings)
}
}
@ -125,7 +127,8 @@ class MockResponse {
return statusParts[1].toInt()
}
set(value) {
val reason = when (value) {
val reason =
when (value) {
in 100..199 -> "Informational"
in 200..299 -> "OK"
in 300..399 -> "Redirection"
@ -189,7 +192,8 @@ class MockResponse {
this.bodyVar = null
this.streamHandlerVar = null
this.webSocketListenerVar = null
this.headers = Headers.Builder()
this.headers =
Headers.Builder()
.add("Content-Length", "0")
this.trailers = Headers.Builder()
this.throttleBytesPerPeriod = Long.MAX_VALUE
@ -216,17 +220,20 @@ class MockResponse {
this.bodyDelayNanos = mockResponse.bodyDelayNanos
this.headersDelayNanos = mockResponse.headersDelayNanos
this.pushPromises = mockResponse.pushPromises.toMutableList()
this.settings = Settings().apply {
this.settings =
Settings().apply {
merge(mockResponse.settings)
}
}
fun code(code: Int) = apply {
fun code(code: Int) =
apply {
this.code = code
}
/** Sets the status and returns this. */
fun status(status: String) = apply {
fun status(status: String) =
apply {
this.status = status
}
@ -234,7 +241,8 @@ class MockResponse {
* Removes all HTTP headers including any "Content-Length" and "Transfer-encoding" headers that
* were added by default.
*/
fun clearHeaders() = apply {
fun clearHeaders() =
apply {
headers = Headers.Builder()
}
@ -242,7 +250,8 @@ class MockResponse {
* Adds [header] as an HTTP header. For well-formed HTTP [header] should contain a name followed
* by a colon and a value.
*/
fun addHeader(header: String) = apply {
fun addHeader(header: String) =
apply {
headers.add(header)
}
@ -250,7 +259,10 @@ class MockResponse {
* Adds a new header with the name and value. This may be used to add multiple headers with the
* same name.
*/
fun addHeader(name: String, value: Any) = apply {
fun addHeader(
name: String,
value: Any,
) = apply {
headers.add(name, value.toString())
}
@ -259,24 +271,32 @@ class MockResponse {
* same name. Unlike [addHeader] this does not validate the name and
* value.
*/
fun addHeaderLenient(name: String, value: Any) = apply {
fun addHeaderLenient(
name: String,
value: Any,
) = apply {
addHeaderLenient(headers, name, value.toString())
}
/** Removes all headers named [name], then adds a new header with the name and value. */
fun setHeader(name: String, value: Any) = apply {
fun setHeader(
name: String,
value: Any,
) = apply {
removeHeader(name)
addHeader(name, value)
}
/** Removes all headers named [name]. */
fun removeHeader(name: String) = apply {
fun removeHeader(name: String) =
apply {
headers.removeAll(name)
}
fun body(body: Buffer) = body(body.toMockResponseBody())
fun body(body: MockResponseBody) = apply {
fun body(body: MockResponseBody) =
apply {
setHeader("Content-Length", body.contentLength)
this.body = body
}
@ -284,14 +304,18 @@ class MockResponse {
/** Sets the response body to the UTF-8 encoded bytes of [body]. */
fun body(body: String): Builder = body(Buffer().writeUtf8(body))
fun streamHandler(streamHandler: StreamHandler) = apply {
fun streamHandler(streamHandler: StreamHandler) =
apply {
this.streamHandler = streamHandler
}
/**
* Sets the response body to [body], chunked every [maxChunkSize] bytes.
*/
fun chunkedBody(body: Buffer, maxChunkSize: Int) = apply {
fun chunkedBody(
body: Buffer,
maxChunkSize: Int,
) = apply {
removeHeader("Content-Length")
headers.add(CHUNKED_BODY_HEADER)
@ -311,21 +335,26 @@ class MockResponse {
* Sets the response body to the UTF-8 encoded bytes of [body],
* chunked every [maxChunkSize] bytes.
*/
fun chunkedBody(body: String, maxChunkSize: Int): Builder =
chunkedBody(Buffer().writeUtf8(body), maxChunkSize)
fun chunkedBody(
body: String,
maxChunkSize: Int,
): Builder = chunkedBody(Buffer().writeUtf8(body), maxChunkSize)
/** Sets the headers and returns this. */
fun headers(headers: Headers) = apply {
fun headers(headers: Headers) =
apply {
this.headers = headers.newBuilder()
}
/** Sets the trailers and returns this. */
fun trailers(trailers: Headers) = apply {
fun trailers(trailers: Headers) =
apply {
this.trailers = trailers.newBuilder()
}
/** Sets the socket policy and returns this. */
fun socketPolicy(socketPolicy: SocketPolicy) = apply {
fun socketPolicy(socketPolicy: SocketPolicy) =
apply {
this.socketPolicy = socketPolicy
}
@ -333,7 +362,11 @@ class MockResponse {
* Throttles the request reader and response writer to sleep for the given period after each
* series of [bytesPerPeriod] bytes are transferred. Use this to simulate network behavior.
*/
fun throttleBody(bytesPerPeriod: Long, period: Long, unit: TimeUnit) = apply {
fun throttleBody(
bytesPerPeriod: Long,
period: Long,
unit: TimeUnit,
) = apply {
throttleBytesPerPeriod = bytesPerPeriod
throttlePeriodNanos = unit.toNanos(period)
}
@ -342,11 +375,17 @@ class MockResponse {
* Set the delayed time of the response body to [delay]. This applies to the response body
* only; response headers are not affected.
*/
fun bodyDelay(delay: Long, unit: TimeUnit) = apply {
fun bodyDelay(
delay: Long,
unit: TimeUnit,
) = apply {
bodyDelayNanos = unit.toNanos(delay)
}
fun headersDelay(delay: Long, unit: TimeUnit) = apply {
fun headersDelay(
delay: Long,
unit: TimeUnit,
) = apply {
headersDelayNanos = unit.toNanos(delay)
}
@ -354,7 +393,8 @@ class MockResponse {
* When [protocols][MockWebServer.protocols] include [HTTP_2][okhttp3.Protocol], this attaches a
* pushed stream to this response.
*/
fun addPush(promise: PushPromise) = apply {
fun addPush(promise: PushPromise) =
apply {
this.pushPromises += promise
}
@ -362,7 +402,8 @@ class MockResponse {
* When [protocols][MockWebServer.protocols] include [HTTP_2][okhttp3.Protocol], this pushes
* [settings] before writing the response.
*/
fun settings(settings: Settings) = apply {
fun settings(settings: Settings) =
apply {
this.settings.clear()
this.settings.merge(settings)
}
@ -371,7 +412,8 @@ class MockResponse {
* Attempts to perform a web socket upgrade on the connection.
* This will overwrite any previously set status, body, or streamHandler.
*/
fun webSocketUpgrade(listener: WebSocketListener) = apply {
fun webSocketUpgrade(listener: WebSocketListener) =
apply {
status = "HTTP/1.1 101 Switching Protocols"
setHeader("Connection", "Upgrade")
setHeader("Upgrade", "websocket")
@ -385,7 +427,8 @@ class MockResponse {
* When a new connection is received, all in-tunnel responses are served before the connection is
* upgraded to HTTPS or HTTP/2.
*/
fun inTunnel() = apply {
fun inTunnel() =
apply {
removeHeader("Content-Length")
inTunnel = true
}
@ -395,11 +438,13 @@ class MockResponse {
* [headers delay][headersDelay] applies after this response is transmitted. Set a
* headers delay on that response to delay its transmission.
*/
fun addInformationalResponse(response: MockResponse) = apply {
fun addInformationalResponse(response: MockResponse) =
apply {
informationalResponses += response
}
fun add100Continue() = apply {
fun add100Continue() =
apply {
addInformationalResponse(MockResponse(code = 100))
}

View File

@ -15,6 +15,7 @@
* limitations under the License.
*/
@file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")
package mockwebserver3
import java.io.Closeable
@ -97,8 +98,9 @@ import okio.source
* in sequence.
*/
class MockWebServer : Closeable {
private val taskRunnerBackend = TaskRunner.RealBackend(
threadFactory("MockWebServer TaskRunner", daemon = false)
private val taskRunnerBackend =
TaskRunner.RealBackend(
threadFactory("MockWebServer TaskRunner", daemon = false),
)
private val taskRunner = TaskRunner(taskRunnerBackend)
private val requestQueue = LinkedBlockingQueue<RecordedRequest>()
@ -126,6 +128,7 @@ class MockWebServer : Closeable {
}
return field
}
@Synchronized set(value) {
check(!started) { "serverSocketFactory must not be set after start()" }
field = value
@ -280,8 +283,10 @@ class MockWebServer : Closeable {
* @return the head of the request queue
*/
@Throws(InterruptedException::class)
fun takeRequest(timeout: Long, unit: TimeUnit): RecordedRequest? =
requestQueue.poll(timeout, unit)
fun takeRequest(
timeout: Long,
unit: TimeUnit,
): RecordedRequest? = requestQueue.poll(timeout, unit)
/**
* Scripts [response] to be returned to a request made in sequence. The first request is
@ -291,8 +296,7 @@ class MockWebServer : Closeable {
* @throws ClassCastException if the default dispatcher has been
* replaced with [setDispatcher][dispatcher].
*/
fun enqueue(response: MockResponse) =
(dispatcher as QueueDispatcher).enqueueResponse(response)
fun enqueue(response: MockResponse) = (dispatcher as QueueDispatcher).enqueueResponse(response)
/**
* Starts the server on the loopback interface for the given port.
@ -301,7 +305,8 @@ class MockWebServer : Closeable {
* use port 0 to avoid flakiness when a specific port is unavailable.
*/
@Throws(IOException::class)
@JvmOverloads fun start(port: Int = 0) = start(InetAddress.getByName("localhost"), port)
@JvmOverloads
fun start(port: Int = 0) = start(InetAddress.getByName("localhost"), port)
/**
* Starts the server on the given address and port.
@ -311,14 +316,18 @@ class MockWebServer : Closeable {
* use port 0 to avoid flakiness when a specific port is unavailable.
*/
@Throws(IOException::class)
fun start(inetAddress: InetAddress, port: Int) = start(InetSocketAddress(inetAddress, port))
fun start(
inetAddress: InetAddress,
port: Int,
) = start(InetSocketAddress(inetAddress, port))
/**
* Starts the server and binds to the given socket address.
*
* @param inetSocketAddress the socket address to bind the server on
*/
@Synchronized @Throws(IOException::class)
@Synchronized
@Throws(IOException::class)
private fun start(inetSocketAddress: InetSocketAddress) {
check(!shutdown) { "shutdown() already called" }
if (started) return
@ -432,8 +441,9 @@ class MockWebServer : Closeable {
processHandshakeFailure(raw)
return
}
socket = sslSocketFactory!!.createSocket(
raw, raw.inetAddress.hostAddress, raw.port, true
socket =
sslSocketFactory!!.createSocket(
raw, raw.inetAddress.hostAddress, raw.port, true,
)
val sslSocket = socket as SSLSocket
sslSocket.useClientMode = false
@ -452,7 +462,8 @@ class MockWebServer : Closeable {
if (protocolNegotiationEnabled) {
val protocolString = Platform.get().getSelectedProtocol(sslSocket)
protocol = when {
protocol =
when {
protocolString != null -> Protocol.get(protocolString)
else -> Protocol.HTTP_1_1
}
@ -463,7 +474,8 @@ class MockWebServer : Closeable {
openClientSockets.remove(raw)
}
else -> {
protocol = when {
protocol =
when {
Protocol.H2_PRIOR_KNOWLEDGE in protocols -> Protocol.H2_PRIOR_KNOWLEDGE
else -> Protocol.HTTP_1_1
}
@ -478,7 +490,8 @@ class MockWebServer : Closeable {
if (protocol === Protocol.HTTP_2 || protocol === Protocol.H2_PRIOR_KNOWLEDGE) {
val http2SocketHandler = Http2SocketHandler(socket, protocol)
val connection = Http2Connection.Builder(false, taskRunner)
val connection =
Http2Connection.Builder(false, taskRunner)
.socket(socket)
.listener(http2SocketHandler)
.build()
@ -498,7 +511,7 @@ class MockWebServer : Closeable {
if (sequenceNumber == 0) {
logger.warning(
"${this@MockWebServer} connection from ${raw.inetAddress} didn't make a request"
"${this@MockWebServer} connection from ${raw.inetAddress} didn't make a request",
)
}
@ -538,7 +551,7 @@ class MockWebServer : Closeable {
private fun processOneRequest(
socket: Socket,
source: BufferedSource,
sink: BufferedSink
sink: BufferedSink,
): Boolean {
if (source.exhausted()) {
return false // No more requests on this socket.
@ -581,7 +594,7 @@ class MockWebServer : Closeable {
if (logger.isLoggable(Level.FINE)) {
logger.fine(
"${this@MockWebServer} received request: $request and responded: $response"
"${this@MockWebServer} received request: $request and responded: $response",
)
}
@ -607,8 +620,12 @@ class MockWebServer : Closeable {
val context = SSLContext.getInstance("TLS")
context.init(null, arrayOf<TrustManager>(UNTRUSTED_TRUST_MANAGER), SecureRandom())
val sslSocketFactory = context.socketFactory
val socket = sslSocketFactory.createSocket(
raw, raw.inetAddress.hostAddress, raw.port, true
val socket =
sslSocketFactory.createSocket(
raw,
raw.inetAddress.hostAddress,
raw.port,
true,
) as SSLSocket
try {
socket.startHandshake() // we're testing a handshake failure
@ -619,9 +636,19 @@ class MockWebServer : Closeable {
}
@Throws(InterruptedException::class)
private fun dispatchBookkeepingRequest(sequenceNumber: Int, socket: Socket) {
val request = RecordedRequest(
"", headersOf(), emptyList(), 0L, Buffer(), sequenceNumber, socket
private fun dispatchBookkeepingRequest(
sequenceNumber: Int,
socket: Socket,
) {
val request =
RecordedRequest(
"",
headersOf(),
emptyList(),
0L,
Buffer(),
sequenceNumber,
socket,
)
atomicRequestCount.incrementAndGet()
requestQueue.add(request)
@ -634,7 +661,7 @@ class MockWebServer : Closeable {
socket: Socket,
source: BufferedSource,
sink: BufferedSink,
sequenceNumber: Int
sequenceNumber: Int,
): RecordedRequest {
var request = ""
val headers = Headers.Builder()
@ -674,7 +701,8 @@ class MockWebServer : Closeable {
var hasBody = false
val policy = dispatcher.peek()
val requestBodySink = requestBody.withThrottlingAndSocketPolicy(
val requestBodySink =
requestBody.withThrottlingAndSocketPolicy(
policy = policy,
disconnectHalfway = policy.socketPolicy == DisconnectDuringRequestBody,
expectedByteCount = contentLength,
@ -725,7 +753,7 @@ class MockWebServer : Closeable {
body = requestBody.buffer,
sequenceNumber = sequenceNumber,
socket = socket,
failure = failure
failure = failure,
)
}
@ -735,10 +763,11 @@ class MockWebServer : Closeable {
source: BufferedSource,
sink: BufferedSink,
request: RecordedRequest,
response: MockResponse
response: MockResponse,
) {
val key = request.headers["Sec-WebSocket-Key"]
val webSocketResponse = response.newBuilder()
val webSocketResponse =
response.newBuilder()
.setHeader("Sec-WebSocket-Accept", WebSocketProtocol.acceptHeader(key!!))
.build()
writeHttpResponse(socket, sink, webSocketResponse)
@ -746,11 +775,13 @@ class MockWebServer : Closeable {
// Adapt the request and response into our Request and Response domain model.
val scheme = if (request.handshake != null) "https" else "http"
val authority = request.headers["Host"] // Has host and port.
val fancyRequest = Request.Builder()
val fancyRequest =
Request.Builder()
.url("$scheme://$authority/")
.headers(request.headers)
.build()
val fancyResponse = Response.Builder()
val fancyResponse =
Response.Builder()
.code(webSocketResponse.code)
.message(webSocketResponse.message)
.headers(webSocketResponse.headers)
@ -759,14 +790,16 @@ class MockWebServer : Closeable {
.build()
val connectionClose = CountDownLatch(1)
val streams = object : RealWebSocket.Streams(false, source, sink) {
val streams =
object : RealWebSocket.Streams(false, source, sink) {
override fun close() = connectionClose.countDown()
override fun cancel() {
socket.closeQuietly()
}
}
val webSocket = RealWebSocket(
val webSocket =
RealWebSocket(
taskRunner = taskRunner,
originalRequest = fancyRequest,
listener = webSocketResponse.webSocketListener!!,
@ -789,7 +822,11 @@ class MockWebServer : Closeable {
}
@Throws(IOException::class)
private fun writeHttpResponse(socket: Socket, sink: BufferedSink, response: MockResponse) {
private fun writeHttpResponse(
socket: Socket,
sink: BufferedSink,
response: MockResponse,
) {
sleepNanos(response.headersDelayNanos)
sink.writeUtf8(response.status)
sink.writeUtf8("\r\n")
@ -798,7 +835,8 @@ class MockWebServer : Closeable {
val body = response.body ?: return
sleepNanos(response.bodyDelayNanos)
val responseBodySink = sink.withThrottlingAndSocketPolicy(
val responseBodySink =
sink.withThrottlingAndSocketPolicy(
policy = response,
disconnectHalfway = response.socketPolicy == DisconnectDuringResponseBody,
expectedByteCount = body.contentLength,
@ -813,7 +851,10 @@ class MockWebServer : Closeable {
}
@Throws(IOException::class)
private fun writeHeaders(sink: BufferedSink, headers: Headers) {
private fun writeHeaders(
sink: BufferedSink,
headers: Headers,
) {
for ((name, value) in headers) {
sink.writeUtf8(name)
sink.writeUtf8(": ")
@ -834,7 +875,8 @@ class MockWebServer : Closeable {
var result: Sink = this
if (policy.throttlePeriodNanos > 0L) {
result = ThrottledSink(
result =
ThrottledSink(
delegate = result,
bytesPerPeriod = policy.throttleBytesPerPeriod,
periodDelayNanos = policy.throttlePeriodNanos,
@ -842,11 +884,13 @@ class MockWebServer : Closeable {
}
if (disconnectHalfway) {
val halfwayByteCount = when {
val halfwayByteCount =
when {
expectedByteCount != -1L -> expectedByteCount / 2
else -> 0L
}
result = TriggerSink(
result =
TriggerSink(
delegate = result,
triggerByteCount = halfwayByteCount,
) {
@ -871,13 +915,16 @@ class MockWebServer : Closeable {
/** A buffer wrapper that drops data after [bodyLimit] bytes. */
private class TruncatingBuffer(
private var remainingByteCount: Long
private var remainingByteCount: Long,
) : Sink {
val buffer = Buffer()
var receivedByteCount = 0L
@Throws(IOException::class)
override fun write(source: Buffer, byteCount: Long) {
override fun write(
source: Buffer,
byteCount: Long,
) {
val toRead = minOf(remainingByteCount, byteCount)
if (toRead > 0L) {
source.read(buffer, toRead)
@ -904,7 +951,7 @@ class MockWebServer : Closeable {
/** Processes HTTP requests layered over HTTP/2. */
private inner class Http2SocketHandler constructor(
private val socket: Socket,
private val protocol: Protocol
private val protocol: Protocol,
) : Http2Connection.Listener() {
private val sequenceNumber = AtomicInteger()
@ -935,7 +982,7 @@ class MockWebServer : Closeable {
if (logger.isLoggable(Level.FINE)) {
logger.fine(
"${this@MockWebServer} received request: $request " +
"and responded: $response protocol is $protocol"
"and responded: $response protocol is $protocol",
)
}
@ -990,7 +1037,8 @@ class MockWebServer : Closeable {
if (readBody && peek.streamHandler == null && peek.socketPolicy !is DoNotReadRequestBody) {
try {
val contentLengthString = headers["content-length"]
val requestBodySink = body.withThrottlingAndSocketPolicy(
val requestBodySink =
body.withThrottlingAndSocketPolicy(
policy = peek,
disconnectHalfway = peek.socketPolicy == DisconnectDuringRequestBody,
expectedByteCount = contentLengthString?.toLong() ?: Long.MAX_VALUE,
@ -1012,7 +1060,7 @@ class MockWebServer : Closeable {
body = body,
sequenceNumber = sequenceNumber.getAndIncrement(),
socket = socket,
failure = exception
failure = exception,
)
}
@ -1029,7 +1077,7 @@ class MockWebServer : Closeable {
private fun writeResponse(
stream: Http2Stream,
request: RecordedRequest,
response: MockResponse
response: MockResponse,
) {
val settings = response.settings
stream.connection.setSettings(settings)
@ -1042,9 +1090,11 @@ class MockWebServer : Closeable {
val trailers = response.trailers
val body = response.body
val streamHandler = response.streamHandler
val outFinished = (body == null &&
val outFinished = (
body == null &&
response.pushPromises.isEmpty() &&
streamHandler == null)
streamHandler == null
)
val flushHeaders = body == null || bodyDelayNanos != 0L
require(!outFinished || trailers.size == 0) {
"unsupported: no body and non-empty trailers $trailers"
@ -1059,11 +1109,12 @@ class MockWebServer : Closeable {
pushPromises(stream, request, response.pushPromises)
if (body != null) {
sleepNanos(bodyDelayNanos)
val responseBodySink = stream.getSink().withThrottlingAndSocketPolicy(
val responseBodySink =
stream.getSink().withThrottlingAndSocketPolicy(
policy = response,
disconnectHalfway = response.socketPolicy == DisconnectDuringResponseBody,
expectedByteCount = body.contentLength,
socket = socket
socket = socket,
).buffer()
responseBodySink.use {
body.writeTo(responseBodySink)
@ -1079,7 +1130,7 @@ class MockWebServer : Closeable {
private fun pushPromises(
stream: Http2Stream,
request: RecordedRequest,
promises: List<PushPromise>
promises: List<PushPromise>,
) {
for (pushPromise in promises) {
val pushedHeaders = mutableListOf<Header>()
@ -1100,8 +1151,8 @@ class MockWebServer : Closeable {
bodySize = 0,
body = Buffer(),
sequenceNumber = sequenceNumber.getAndIncrement(),
socket = socket
)
socket = socket,
),
)
val hasBody = pushPromise.response.body != null
val pushedStream = stream.connection.pushStream(stream.id, pushedHeaders, hasBody)
@ -1115,16 +1166,17 @@ class MockWebServer : Closeable {
private const val CLIENT_AUTH_REQUESTED = 1
private const val CLIENT_AUTH_REQUIRED = 2
private val UNTRUSTED_TRUST_MANAGER = object : X509TrustManager {
private val UNTRUSTED_TRUST_MANAGER =
object : X509TrustManager {
@Throws(CertificateException::class)
override fun checkClientTrusted(
chain: Array<X509Certificate>,
authType: String
authType: String,
) = throw CertificateException()
override fun checkServerTrusted(
chain: Array<X509Certificate>,
authType: String
authType: String,
) = throw AssertionError()
override fun getAcceptedIssuers(): Array<X509Certificate> = throw AssertionError()

View File

@ -68,7 +68,8 @@ open class QueueDispatcher : Dispatcher() {
}
open fun setFailFast(failFast: Boolean) {
val failFastResponse = if (failFast) {
val failFastResponse =
if (failFast) {
MockResponse(code = HttpURLConnection.HTTP_NOT_FOUND)
} else {
null

View File

@ -31,34 +31,28 @@ import okio.Buffer
/** An HTTP request that came into the mock web server. */
class RecordedRequest(
val requestLine: String,
/** All headers. */
val headers: Headers,
/**
* The sizes of the chunks of this request's body, or an empty list if the request's body
* was empty or unchunked.
*/
val chunkSizes: List<Int>,
/** The total size of the body of this POST request (before truncation).*/
val bodySize: Long,
/** The body of this POST request. This may be truncated. */
val body: Buffer,
/**
* The index of this request on its HTTP connection. Since a single HTTP connection may serve
* multiple requests, each request is assigned its own sequence number.
*/
val sequenceNumber: Int,
socket: Socket,
/**
* The failure MockWebServer recorded when attempting to decode this request. If, for example,
* the inbound request was truncated, this exception will be non-null.
*/
val failure: IOException? = null
val failure: IOException? = null,
) {
val method: String?
val path: String?
@ -102,7 +96,8 @@ class RecordedRequest(
val scheme = if (socket is SSLSocket) "https" else "http"
val localPort = socket.localPort
val hostAndPort = headers[":authority"]
val hostAndPort =
headers[":authority"]
?: headers["Host"]
?: when (val inetAddress = socket.localAddress) {
is Inet6Address -> "[${inetAddress.hostAddress}]:$localPort"

View File

@ -30,7 +30,6 @@ package mockwebserver3
* server has closed the socket before follow up requests are made.
*/
sealed interface SocketPolicy {
/**
* Shutdown [MockWebServer] after writing response.
*/

View File

@ -29,7 +29,11 @@ internal class ThrottledSink(
private val periodDelayNanos: Long,
) : Sink by delegate {
private var bytesWrittenSinceLastDelay = 0L
override fun write(source: Buffer, byteCount: Long) {
override fun write(
source: Buffer,
byteCount: Long,
) {
var bytesLeft = byteCount
while (bytesLeft > 0) {

View File

@ -30,7 +30,10 @@ internal class TriggerSink(
) : Sink by delegate {
private var bytesWritten = 0L
override fun write(source: Buffer, byteCount: Long) {
override fun write(
source: Buffer,
byteCount: Long,
) {
if (byteCount == 0L) return // Avoid double-triggering.
if (bytesWritten == triggerByteCount) {

View File

@ -34,24 +34,28 @@ class MockStreamHandler : StreamHandler {
private val actions = LinkedBlockingQueue<Action>()
private val results = LinkedBlockingQueue<FutureTask<Void>>()
fun receiveRequest(expected: String) = apply {
fun receiveRequest(expected: String) =
apply {
actions += { stream ->
val actual = stream.requestBody.readUtf8(expected.utf8Size())
if (actual != expected) throw AssertionError("$actual != $expected")
}
}
fun exhaustRequest() = apply {
fun exhaustRequest() =
apply {
actions += { stream ->
if (!stream.requestBody.exhausted()) throw AssertionError("expected exhausted")
}
}
fun cancelStream() = apply {
fun cancelStream() =
apply {
actions += { stream -> stream.cancel() }
}
fun requestIOException() = apply {
fun requestIOException() =
apply {
actions += { stream ->
try {
stream.requestBody.exhausted()
@ -64,7 +68,7 @@ class MockStreamHandler : StreamHandler {
@JvmOverloads
fun sendResponse(
s: String,
responseSent: CountDownLatch = CountDownLatch(0)
responseSent: CountDownLatch = CountDownLatch(0),
) = apply {
actions += { stream ->
stream.responseBody.writeUtf8(s)
@ -73,11 +77,15 @@ class MockStreamHandler : StreamHandler {
}
}
fun exhaustResponse() = apply {
fun exhaustResponse() =
apply {
actions += { stream -> stream.responseBody.close() }
}
fun sleep(duration: Long, unit: TimeUnit) = apply {
fun sleep(
duration: Long,
unit: TimeUnit,
) = apply {
actions += { Thread.sleep(unit.toMillis(duration)) }
}
@ -88,9 +96,7 @@ class MockStreamHandler : StreamHandler {
}
/** Returns a task that processes both request and response from [stream]. */
private fun serviceStreamTask(
stream: Stream,
): FutureTask<Void> {
private fun serviceStreamTask(stream: Stream): FutureTask<Void> {
return FutureTask<Void> {
stream.requestBody.use {
stream.responseBody.use {
@ -106,7 +112,8 @@ class MockStreamHandler : StreamHandler {
/** Returns once all stream actions complete successfully. */
fun awaitSuccess() {
val futureTask = results.poll(5, TimeUnit.SECONDS)
val futureTask =
results.poll(5, TimeUnit.SECONDS)
?: throw AssertionError("no onRequest call received")
futureTask.get(5, TimeUnit.SECONDS)
}

View File

@ -37,7 +37,8 @@ class CustomDispatcherTest {
@Test
fun simpleDispatch() {
val requestsMade = mutableListOf<RecordedRequest>()
val dispatcher: Dispatcher = object : Dispatcher() {
val dispatcher: Dispatcher =
object : Dispatcher() {
override fun dispatch(request: RecordedRequest): MockResponse {
requestsMade.add(request)
return MockResponse()
@ -59,7 +60,8 @@ class CustomDispatcherTest {
val secondRequest = "/bar"
val firstRequest = "/foo"
val latch = CountDownLatch(1)
val dispatcher: Dispatcher = object : Dispatcher() {
val dispatcher: Dispatcher =
object : Dispatcher() {
override fun dispatch(request: RecordedRequest): MockResponse {
if (request.path == firstRequest) {
latch.await()
@ -85,7 +87,10 @@ class CustomDispatcherTest {
assertThat(secondResponseCode.get()).isEqualTo(200)
}
private fun buildRequestThread(path: String, responseCode: AtomicInteger): Thread {
private fun buildRequestThread(
path: String,
responseCode: AtomicInteger,
): Thread {
return Thread {
val url = mockWebServer.url(path).toUrl()
val conn: HttpURLConnection

View File

@ -55,14 +55,16 @@ class MockResponseSniTest {
val handshakeCertificates = localhost()
server.useHttps(handshakeCertificates.sslSocketFactory())
val dns = Dns {
val dns =
Dns {
Dns.SYSTEM.lookup(server.hostName)
}
val client = clientTestRule.newClientBuilder()
val client =
clientTestRule.newClientBuilder()
.sslSocketFactory(
handshakeCertificates.sslSocketFactory(),
handshakeCertificates.trustManager
handshakeCertificates.trustManager,
)
.dns(dns)
.build()
@ -84,35 +86,40 @@ class MockResponseSniTest {
*/
@Test
fun domainFronting() {
val heldCertificate = HeldCertificate.Builder()
val heldCertificate =
HeldCertificate.Builder()
.commonName("server name")
.addSubjectAlternativeName("url-host.com")
.build()
val handshakeCertificates = HandshakeCertificates.Builder()
val handshakeCertificates =
HandshakeCertificates.Builder()
.heldCertificate(heldCertificate)
.addTrustedCertificate(heldCertificate.certificate)
.build()
server.useHttps(handshakeCertificates.sslSocketFactory())
val dns = Dns {
val dns =
Dns {
Dns.SYSTEM.lookup(server.hostName)
}
val client = clientTestRule.newClientBuilder()
val client =
clientTestRule.newClientBuilder()
.sslSocketFactory(
handshakeCertificates.sslSocketFactory(),
handshakeCertificates.trustManager
handshakeCertificates.trustManager,
)
.dns(dns)
.build()
server.enqueue(MockResponse())
val call = client.newCall(
val call =
client.newCall(
Request(
url = "https://url-host.com:${server.port}/".toHttpUrl(),
headers = headersOf("Host", "header-host"),
)
),
)
val response = call.execute()
assertThat(response.isSuccessful).isTrue()
@ -150,20 +157,23 @@ class MockResponseSniTest {
* tell MockWebServer to act as a proxy.
*/
private fun requestToHostnameViaProxy(hostnameOrIpAddress: String): RecordedRequest {
val heldCertificate = HeldCertificate.Builder()
val heldCertificate =
HeldCertificate.Builder()
.commonName("server name")
.addSubjectAlternativeName(hostnameOrIpAddress)
.build()
val handshakeCertificates = HandshakeCertificates.Builder()
val handshakeCertificates =
HandshakeCertificates.Builder()
.heldCertificate(heldCertificate)
.addTrustedCertificate(heldCertificate.certificate)
.build()
server.useHttps(handshakeCertificates.sslSocketFactory())
val client = clientTestRule.newClientBuilder()
val client =
clientTestRule.newClientBuilder()
.sslSocketFactory(
handshakeCertificates.sslSocketFactory(),
handshakeCertificates.trustManager
handshakeCertificates.trustManager,
)
.proxy(server.toProxyAddress())
.build()
@ -171,12 +181,14 @@ class MockResponseSniTest {
server.enqueue(MockResponse(inTunnel = true))
server.enqueue(MockResponse())
val call = client.newCall(
val call =
client.newCall(
Request(
url = server.url("/").newBuilder()
url =
server.url("/").newBuilder()
.host(hostnameOrIpAddress)
.build()
)
.build(),
),
)
val response = call.execute()
assertThat(response.isSuccessful).isTrue()

View File

@ -92,14 +92,15 @@ class MockWebServerTest {
@Test
fun setResponseMockReason() {
val reasons = arrayOf<String?>(
val reasons =
arrayOf<String?>(
"Mock Response",
"Informational",
"OK",
"Redirection",
"Client Error",
"Server Error",
"Mock Response"
"Mock Response",
)
for (i in 0..599) {
val builder = MockResponse.Builder().code(i)
@ -128,7 +129,8 @@ class MockWebServerTest {
@Test
fun mockResponseAddHeader() {
val builder = MockResponse.Builder()
val builder =
MockResponse.Builder()
.clearHeaders()
.addHeader("Cookie: s=square")
.addHeader("Cookie", "a=android")
@ -137,7 +139,8 @@ class MockWebServerTest {
@Test
fun mockResponseSetHeader() {
val builder = MockResponse.Builder()
val builder =
MockResponse.Builder()
.clearHeaders()
.addHeader("Cookie: s=square")
.addHeader("Cookie: a=android")
@ -148,7 +151,8 @@ class MockWebServerTest {
@Test
fun mockResponseSetHeaders() {
val builder = MockResponse.Builder()
val builder =
MockResponse.Builder()
.clearHeaders()
.addHeader("Cookie: s=square")
.addHeader("Cookies: delicious")
@ -175,14 +179,18 @@ class MockWebServerTest {
@Test
fun redirect() {
server.enqueue(MockResponse.Builder()
server.enqueue(
MockResponse.Builder()
.code(HttpURLConnection.HTTP_MOVED_TEMP)
.addHeader("Location: " + server.url("/new-path"))
.body("This page has moved!")
.build())
server.enqueue(MockResponse.Builder()
.build(),
)
server.enqueue(
MockResponse.Builder()
.body("This is the new location!")
.build())
.build(),
)
val connection = server.url("/").toUrl().openConnection()
val reader = BufferedReader(InputStreamReader(connection!!.getInputStream(), UTF_8))
assertThat(reader.readLine()).isEqualTo("This is the new location!")
@ -203,9 +211,11 @@ class MockWebServerTest {
Thread.sleep(1000)
} catch (ignored: InterruptedException) {
}
server.enqueue(MockResponse.Builder()
server.enqueue(
MockResponse.Builder()
.body("enqueued in the background")
.build())
.build(),
)
}.start()
val connection = server.url("/").toUrl().openConnection()
val reader = BufferedReader(InputStreamReader(connection!!.getInputStream(), UTF_8))
@ -214,11 +224,13 @@ class MockWebServerTest {
@Test
fun nonHexadecimalChunkSize() {
server.enqueue(MockResponse.Builder()
server.enqueue(
MockResponse.Builder()
.body("G\r\nxxxxxxxxxxxxxxxx\r\n0\r\n\r\n")
.clearHeaders()
.addHeader("Transfer-encoding: chunked")
.build())
.build(),
)
val connection = server.url("/").toUrl().openConnection()
try {
connection.getInputStream().read()
@ -230,14 +242,18 @@ class MockWebServerTest {
@Test
fun responseTimeout() {
server.enqueue(MockResponse.Builder()
server.enqueue(
MockResponse.Builder()
.body("ABC")
.clearHeaders()
.addHeader("Content-Length: 4")
.build())
server.enqueue(MockResponse.Builder()
.build(),
)
server.enqueue(
MockResponse.Builder()
.body("DEF")
.build())
.build(),
)
val urlConnection = server.url("/").toUrl().openConnection()
urlConnection!!.readTimeout = 1000
val inputStream = urlConnection.getInputStream()
@ -263,9 +279,11 @@ class MockWebServerTest {
@Disabled("Not actually failing where expected")
@Test
fun disconnectAtStart() {
server.enqueue(MockResponse.Builder()
server.enqueue(
MockResponse.Builder()
.socketPolicy(DisconnectAtStart)
.build())
.build(),
)
server.enqueue(MockResponse()) // The jdk's HttpUrlConnection is a bastard.
server.enqueue(MockResponse())
try {
@ -293,9 +311,11 @@ class MockWebServerTest {
@Test
fun throttleRequest() {
assumeNotWindows()
server.enqueue(MockResponse.Builder()
server.enqueue(
MockResponse.Builder()
.throttleBody(3, 500, TimeUnit.MILLISECONDS)
.build())
.build(),
)
val startNanos = System.nanoTime()
val connection = server.url("/").toUrl().openConnection()
connection.doOutput = true
@ -314,10 +334,12 @@ class MockWebServerTest {
@Test
fun throttleResponse() {
assumeNotWindows()
server.enqueue(MockResponse.Builder()
server.enqueue(
MockResponse.Builder()
.body("ABCDEF")
.throttleBody(3, 500, TimeUnit.MILLISECONDS)
.build())
.build(),
)
val startNanos = System.nanoTime()
val connection = server.url("/").toUrl().openConnection()
val inputStream = connection!!.getInputStream()
@ -337,10 +359,12 @@ class MockWebServerTest {
@Test
fun delayResponse() {
assumeNotWindows()
server.enqueue(MockResponse.Builder()
server.enqueue(
MockResponse.Builder()
.body("ABCDEF")
.bodyDelay(1, TimeUnit.SECONDS)
.build())
.build(),
)
val startNanos = System.nanoTime()
val connection = server.url("/").toUrl().openConnection()
val inputStream = connection!!.getInputStream()
@ -353,9 +377,11 @@ class MockWebServerTest {
@Test
fun disconnectRequestHalfway() {
server.enqueue(MockResponse.Builder()
server.enqueue(
MockResponse.Builder()
.socketPolicy(DisconnectDuringRequestBody)
.build())
.build(),
)
// Limit the size of the request body that the server holds in memory to an arbitrary
// 3.5 MBytes so this test can pass on devices with little memory.
server.bodyLimit = 7 * 512 * 1024
@ -386,10 +412,12 @@ class MockWebServerTest {
@Test
fun disconnectResponseHalfway() {
server.enqueue(MockResponse.Builder()
server.enqueue(
MockResponse.Builder()
.body("ab")
.socketPolicy(DisconnectDuringResponseBody)
.build())
.build(),
)
val connection = server.url("/").toUrl().openConnection()
assertThat(connection!!.contentLength).isEqualTo(2)
val inputStream = connection.getInputStream()
@ -468,9 +496,11 @@ class MockWebServerTest {
@Test
fun requestUrlReconstructed() {
server.enqueue(MockResponse.Builder()
server.enqueue(
MockResponse.Builder()
.body("hello world")
.build())
.build(),
)
val url = server.url("/a/deep/path?key=foo%20bar").toUrl()
val connection = url.openConnection() as HttpURLConnection
val inputStream = connection.inputStream
@ -479,7 +509,8 @@ class MockWebServerTest {
assertThat(reader.readLine()).isEqualTo("hello world")
val request = server.takeRequest()
assertThat(request.requestLine).isEqualTo(
"GET /a/deep/path?key=foo%20bar HTTP/1.1")
"GET /a/deep/path?key=foo%20bar HTTP/1.1",
)
val requestUrl = request.requestUrl
assertThat(requestUrl!!.scheme).isEqualTo("http")
assertThat(requestUrl.host).isEqualTo(server.hostName)
@ -490,9 +521,11 @@ class MockWebServerTest {
@Test
fun shutdownServerAfterRequest() {
server.enqueue(MockResponse.Builder()
server.enqueue(
MockResponse.Builder()
.socketPolicy(ShutdownServerAfterResponse)
.build())
.build(),
)
val url = server.url("/").toUrl()
val connection = url.openConnection() as HttpURLConnection
assertThat(connection.responseCode).isEqualTo(HttpURLConnection.HTTP_OK)
@ -506,9 +539,11 @@ class MockWebServerTest {
@Test
fun http100Continue() {
server.enqueue(MockResponse.Builder()
server.enqueue(
MockResponse.Builder()
.body("response")
.build())
.build(),
)
val url = server.url("/").toUrl()
val connection = url.openConnection() as HttpURLConnection
connection.doOutput = true
@ -523,11 +558,13 @@ class MockWebServerTest {
@Test
fun multiple1xxResponses() {
server.enqueue(MockResponse.Builder()
server.enqueue(
MockResponse.Builder()
.add100Continue()
.add100Continue()
.body("response")
.build())
.build(),
)
val url = server.url("/").toUrl()
val connection = url.openConnection() as HttpURLConnection
connection.doOutput = true
@ -546,8 +583,9 @@ class MockWebServerTest {
fail<Unit>()
} catch (expected: IllegalArgumentException) {
assertThat(expected.message).isEqualTo(
"protocols containing h2_prior_knowledge cannot use other protocols: "
+ "[h2_prior_knowledge, http/1.1]")
"protocols containing h2_prior_knowledge cannot use other protocols: " +
"[h2_prior_knowledge, http/1.1]",
)
}
}
@ -559,8 +597,9 @@ class MockWebServerTest {
fail<Unit>()
} catch (expected: IllegalArgumentException) {
assertThat(expected.message).isEqualTo(
"protocols containing h2_prior_knowledge cannot use other protocols: "
+ "[h2_prior_knowledge, h2_prior_knowledge]")
"protocols containing h2_prior_knowledge cannot use other protocols: " +
"[h2_prior_knowledge, h2_prior_knowledge]",
)
}
}
@ -575,9 +614,11 @@ class MockWebServerTest {
fun https() {
val handshakeCertificates = platform.localhostHandshakeCertificates()
server.useHttps(handshakeCertificates.sslSocketFactory())
server.enqueue(MockResponse.Builder()
server.enqueue(
MockResponse.Builder()
.body("abc")
.build())
.build(),
)
val url = server.url("/")
val connection = url.toUrl().openConnection() as HttpsURLConnection
connection.sslSocketFactory = handshakeCertificates.sslSocketFactory()
@ -601,29 +642,37 @@ class MockWebServerTest {
platform.assumeNotBouncyCastle()
platform.assumeNotConscrypt()
val clientCa = HeldCertificate.Builder()
val clientCa =
HeldCertificate.Builder()
.certificateAuthority(0)
.build()
val serverCa = HeldCertificate.Builder()
val serverCa =
HeldCertificate.Builder()
.certificateAuthority(0)
.build()
val serverCertificate = HeldCertificate.Builder()
val serverCertificate =
HeldCertificate.Builder()
.signedBy(serverCa)
.addSubjectAlternativeName(server.hostName)
.build()
val serverHandshakeCertificates = HandshakeCertificates.Builder()
val serverHandshakeCertificates =
HandshakeCertificates.Builder()
.addTrustedCertificate(clientCa.certificate)
.heldCertificate(serverCertificate)
.build()
server.useHttps(serverHandshakeCertificates.sslSocketFactory())
server.enqueue(MockResponse.Builder()
server.enqueue(
MockResponse.Builder()
.body("abc")
.build())
.build(),
)
server.requestClientAuth()
val clientCertificate = HeldCertificate.Builder()
val clientCertificate =
HeldCertificate.Builder()
.signedBy(clientCa)
.build()
val clientHandshakeCertificates = HandshakeCertificates.Builder()
val clientHandshakeCertificates =
HandshakeCertificates.Builder()
.addTrustedCertificate(serverCa.certificate)
.heldCertificate(clientCertificate)
.build()
@ -647,10 +696,13 @@ class MockWebServerTest {
@Test
fun proxiedRequestGetsCorrectRequestUrl() {
server.enqueue(MockResponse.Builder()
server.enqueue(
MockResponse.Builder()
.body("Result")
.build())
val proxiedClient = OkHttpClient.Builder()
.build(),
)
val proxiedClient =
OkHttpClient.Builder()
.proxy(server.toProxyAddress())
.readTimeout(Duration.ofMillis(100))
.build()

View File

@ -14,6 +14,7 @@
* limitations under the License.
*/
@file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")
package mockwebserver3
import assertk.assertThat
@ -32,30 +33,34 @@ class RecordedRequestTest {
private val headers: Headers = EMPTY_HEADERS
@Test fun testIPv4() {
val socket = FakeSocket(
val socket =
FakeSocket(
localAddress = InetAddress.getByAddress("127.0.0.1", byteArrayOf(127, 0, 0, 1)),
localPort = 80
localPort = 80,
)
val request = RecordedRequest("GET / HTTP/1.1", headers, emptyList(), 0, Buffer(), 0, socket)
assertThat(request.requestUrl.toString()).isEqualTo("http://127.0.0.1/")
}
@Test fun testIpv6() {
val socket = FakeSocket(
localAddress = InetAddress.getByAddress(
val socket =
FakeSocket(
localAddress =
InetAddress.getByAddress(
"::1",
byteArrayOf(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1)
byteArrayOf(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1),
),
localPort = 80
localPort = 80,
)
val request = RecordedRequest("GET / HTTP/1.1", headers, emptyList(), 0, Buffer(), 0, socket)
assertThat(request.requestUrl.toString()).isEqualTo("http://[::1]/")
}
@Test fun testUsesLocal() {
val socket = FakeSocket(
val socket =
FakeSocket(
localAddress = InetAddress.getByAddress("127.0.0.1", byteArrayOf(127, 0, 0, 1)),
localPort = 80
localPort = 80,
)
val request = RecordedRequest("GET / HTTP/1.1", headers, emptyList(), 0, Buffer(), 0, socket)
assertThat(request.requestUrl.toString()).isEqualTo("http://127.0.0.1/")
@ -63,12 +68,14 @@ class RecordedRequestTest {
@Test fun testHostname() {
val headers = headersOf("Host", "host-from-header.com")
val socket = FakeSocket(
localAddress = InetAddress.getByAddress(
val socket =
FakeSocket(
localAddress =
InetAddress.getByAddress(
"host-from-address.com",
byteArrayOf(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1)
byteArrayOf(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1),
),
localPort = 80
localPort = 80,
)
val request = RecordedRequest("GET / HTTP/1.1", headers, emptyList(), 0, Buffer(), 0, socket)
assertThat(request.requestUrl.toString()).isEqualTo("http://host-from-header.com/")
@ -78,11 +85,14 @@ class RecordedRequestTest {
private val localAddress: InetAddress,
private val localPort: Int,
private val remoteAddress: InetAddress = localAddress,
private val remotePort: Int = 1234
private val remotePort: Int = 1234,
) : Socket() {
override fun getInetAddress() = remoteAddress
override fun getLocalAddress() = localAddress
override fun getLocalPort() = localPort
override fun getPort() = remotePort
}
}

View File

@ -39,7 +39,7 @@ import okio.source
/** A basic HTTP/2 server that serves the contents of a local directory. */
class Http2Server(
private val baseDirectory: File,
private val sslSocketFactory: SSLSocketFactory
private val sslSocketFactory: SSLSocketFactory,
) : Http2Connection.Listener() {
private fun run() {
val serverSocket = ServerSocket(8888)
@ -54,7 +54,8 @@ class Http2Server(
if (protocol != Protocol.HTTP_2) {
throw ProtocolException("Protocol $protocol unsupported")
}
val connection = Http2Connection.Builder(false, TaskRunner.INSTANCE)
val connection =
Http2Connection.Builder(false, TaskRunner.INSTANCE)
.socket(sslSocket)
.listener(this)
.build()
@ -70,10 +71,12 @@ class Http2Server(
}
private fun doSsl(socket: Socket): SSLSocket {
val sslSocket = sslSocketFactory.createSocket(
socket, socket.inetAddress.hostAddress,
val sslSocket =
sslSocketFactory.createSocket(
socket,
socket.inetAddress.hostAddress,
socket.port,
true
true,
) as SSLSocket
sslSocket.useClientMode = false
Platform.get().configureTlsExtensions(sslSocket, null, listOf(Protocol.HTTP_2))
@ -111,32 +114,40 @@ class Http2Server(
}
}
private fun send404(stream: Http2Stream, path: String) {
val responseHeaders = listOf(
private fun send404(
stream: Http2Stream,
path: String,
) {
val responseHeaders =
listOf(
Header(":status", "404"),
Header(":version", "HTTP/1.1"),
Header("content-type", "text/plain")
Header("content-type", "text/plain"),
)
stream.writeHeaders(
responseHeaders = responseHeaders,
outFinished = false,
flushHeaders = false
flushHeaders = false,
)
val out = stream.getSink().buffer()
out.writeUtf8("Not found: $path")
out.close()
}
private fun serveDirectory(stream: Http2Stream, files: Array<File>) {
val responseHeaders = listOf(
private fun serveDirectory(
stream: Http2Stream,
files: Array<File>,
) {
val responseHeaders =
listOf(
Header(":status", "200"),
Header(":version", "HTTP/1.1"),
Header("content-type", "text/html; charset=UTF-8")
Header("content-type", "text/html; charset=UTF-8"),
)
stream.writeHeaders(
responseHeaders = responseHeaders,
outFinished = false,
flushHeaders = false
flushHeaders = false,
)
val out = stream.getSink().buffer()
for (file in files) {
@ -146,16 +157,20 @@ class Http2Server(
out.close()
}
private fun serveFile(stream: Http2Stream, file: File) {
val responseHeaders = listOf(
private fun serveFile(
stream: Http2Stream,
file: File,
) {
val responseHeaders =
listOf(
Header(":status", "200"),
Header(":version", "HTTP/1.1"),
Header("content-type", contentType(file))
Header("content-type", contentType(file)),
)
stream.writeHeaders(
responseHeaders = responseHeaders,
outFinished = false,
flushHeaders = false
flushHeaders = false,
)
file.source().use { source ->
stream.getSink().buffer().use { sink ->
@ -186,9 +201,10 @@ class Http2Server(
println("Usage: Http2Server <base directory>")
return
}
val server = Http2Server(
val server =
Http2Server(
File(args[0]),
localhost().sslContext().socketFactory
localhost().sslContext().socketFactory,
)
server.run()
}

View File

@ -15,19 +15,22 @@
*/
package okhttp3
import java.io.OutputStream
import java.io.PrintStream
import org.junit.platform.engine.TestExecutionResult
import org.junit.platform.launcher.TestExecutionListener
import org.junit.platform.launcher.TestIdentifier
import org.junit.platform.launcher.TestPlan
import java.io.OutputStream
import java.io.PrintStream
object DotListener : TestExecutionListener {
private var originalSystemErr: PrintStream? = null
private var originalSystemOut: PrintStream? = null
private var testCount = 0
override fun executionSkipped(testIdentifier: TestIdentifier, reason: String) {
override fun executionSkipped(
testIdentifier: TestIdentifier,
reason: String,
) {
printStatus("-")
}
@ -40,7 +43,7 @@ object DotListener: TestExecutionListener {
override fun executionFinished(
testIdentifier: TestIdentifier,
testExecutionResult: TestExecutionResult
testExecutionResult: TestExecutionResult,
) {
if (!testIdentifier.isContainer) {
when (testExecutionResult.status!!) {

View File

@ -15,12 +15,13 @@
*/
package okhttp3
import java.io.File
import org.junit.jupiter.engine.descriptor.ClassBasedTestDescriptor
import org.junit.platform.engine.discovery.DiscoverySelectors
import java.io.File
// TODO move to junit5 tags
val avoidedTests = setOf(
val avoidedTests =
setOf(
"okhttp3.BouncyCastleTest",
"okhttp3.ConscryptTest",
"okhttp3.CorrettoTest",
@ -44,7 +45,8 @@ val avoidedTests = setOf(
fun main() {
val knownTestFile = File("native-image-tests/src/main/resources/testlist.txt")
val testSelector = DiscoverySelectors.selectPackage("okhttp3")
val testClasses = findTests(listOf(testSelector))
val testClasses =
findTests(listOf(testSelector))
.filter { it.isContainer }
.mapNotNull { (it as? ClassBasedTestDescriptor)?.testClass?.name }
.filterNot { it in avoidedTests }

View File

@ -15,6 +15,9 @@
*/
package okhttp3
import java.io.File
import java.io.PrintWriter
import kotlin.system.exitProcess
import org.junit.jupiter.engine.JupiterTestEngine
import org.junit.platform.console.options.Theme
import org.junit.platform.engine.DiscoverySelector
@ -30,9 +33,6 @@ import org.junit.platform.launcher.core.LauncherConfig
import org.junit.platform.launcher.core.LauncherDiscoveryRequestBuilder
import org.junit.platform.launcher.core.LauncherFactory
import org.junit.platform.launcher.listeners.SummaryGeneratingListener
import java.io.File
import java.io.PrintWriter
import kotlin.system.exitProcess
/**
* Graal main method to run tests with minimal reflection and automatic settings.
@ -49,7 +49,8 @@ fun main(vararg args: String) {
val jupiterTestEngine = buildTestEngine()
val config = LauncherConfig.builder()
val config =
LauncherConfig.builder()
.enableTestExecutionListenerAutoRegistration(false)
.enableTestEngineAutoRegistration(false)
.enablePostDiscoveryFilterAutoRegistration(false)
@ -89,7 +90,8 @@ fun testSelectors(inputFile: File? = null): List<DiscoverySelector> {
val lines =
inputFile?.readLines() ?: sampleTestClass.getResource("/testlist.txt").readText().lines()
val flatClassnameList = lines
val flatClassnameList =
lines
.filter { it.isNotBlank() }
return flatClassnameList
@ -107,7 +109,8 @@ fun testSelectors(inputFile: File? = null): List<DiscoverySelector> {
* Builds a Junit Test Plan request for a fixed set of classes, or potentially a recursive package.
*/
fun buildRequest(selectors: List<DiscoverySelector>): LauncherDiscoveryRequest {
val request: LauncherDiscoveryRequest = LauncherDiscoveryRequestBuilder.request()
val request: LauncherDiscoveryRequest =
LauncherDiscoveryRequestBuilder.request()
// TODO replace junit.jupiter.extensions.autodetection.enabled with API approach.
// .enableImplicitConfigurationParameters(false)
.selectors(selectors)
@ -136,11 +139,13 @@ fun findTests(selectors: List<DiscoverySelector>): List<TestDescriptor> {
* https://github.com/junit-team/junit5/issues/2469
*/
fun treeListener(): TestExecutionListener {
val colorPalette = Class.forName("org.junit.platform.console.tasks.ColorPalette").getField("DEFAULT").apply {
val colorPalette =
Class.forName("org.junit.platform.console.tasks.ColorPalette").getField("DEFAULT").apply {
isAccessible = true
}.get(null)
return Class.forName(
"org.junit.platform.console.tasks.TreePrintingListener").declaredConstructors.first()
"org.junit.platform.console.tasks.TreePrintingListener",
).declaredConstructors.first()
.apply {
isAccessible = true
}

View File

@ -26,7 +26,8 @@ import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.ArgumentsSource
class SampleTest {
@JvmField @RegisterExtension val clientRule = OkHttpClientTestRule()
@JvmField @RegisterExtension
val clientRule = OkHttpClientTestRule()
@Test
fun passingTest() {

View File

@ -16,11 +16,11 @@
package okhttp3
import com.oracle.svm.core.annotate.AutomaticFeature
import java.io.File
import java.lang.IllegalStateException
import org.graalvm.nativeimage.hosted.Feature
import org.graalvm.nativeimage.hosted.RuntimeClassInitialization
import org.graalvm.nativeimage.hosted.RuntimeReflection
import java.io.File
import java.lang.IllegalStateException
@AutomaticFeature
class TestRegistration : Feature {
@ -40,7 +40,10 @@ class TestRegistration : Feature {
registerParamProvider(access, "okhttp3.WebPlatformUrlTest\$TestDataParamProvider")
}
private fun registerParamProvider(access: Feature.BeforeAnalysisAccess, provider: String) {
private fun registerParamProvider(
access: Feature.BeforeAnalysisAccess,
provider: String,
) {
val providerClass = access.findClassByName(provider)
if (providerClass != null) {
registerTest(access, providerClass)
@ -54,7 +57,10 @@ class TestRegistration : Feature {
registerStandardClass(access, "org.junit.platform.console.tasks.TreePrintingListener")
}
private fun registerStandardClass(access: Feature.BeforeAnalysisAccess, name: String) {
private fun registerStandardClass(
access: Feature.BeforeAnalysisAccess,
name: String,
) {
val clazz: Class<*> = access.findClassByName(name) ?: throw IllegalStateException("Missing class $name")
RuntimeReflection.register(clazz)
clazz.declaredConstructors.forEach {
@ -81,7 +87,10 @@ class TestRegistration : Feature {
}
}
private fun registerTest(access: Feature.BeforeAnalysisAccess, java: Class<*>) {
private fun registerTest(
access: Feature.BeforeAnalysisAccess,
java: Class<*>,
) {
access.registerAsUsed(java)
RuntimeReflection.register(java)
java.constructors.forEach {

View File

@ -48,11 +48,17 @@ class Main : CliktCommand(name = NAME, help = "A curl for the next-generation we
val userAgent: String by option("-A", "--user-agent", help = "User-Agent to send to server").default(NAME + "/" + versionString())
val connectTimeout: Int by option("--connect-timeout", help="Maximum time allowed for connection (seconds)").int().default(DEFAULT_TIMEOUT)
val connectTimeout: Int by option(
"--connect-timeout",
help = "Maximum time allowed for connection (seconds)",
).int().default(DEFAULT_TIMEOUT)
val readTimeout: Int by option("--read-timeout", help = "Maximum time allowed for reading data (seconds)").int().default(DEFAULT_TIMEOUT)
val callTimeout: Int by option("--call-timeout", help="Maximum time allowed for the entire call (seconds)").int().default(DEFAULT_TIMEOUT)
val callTimeout: Int by option(
"--call-timeout",
help = "Maximum time allowed for the entire call (seconds)",
).int().default(DEFAULT_TIMEOUT)
val followRedirects: Boolean by option("-L", "--location", help = "Follow redirects").flag()
@ -123,10 +129,17 @@ class Main : CliktCommand(name = NAME, help = "A curl for the next-generation we
return prop.getProperty("version", "dev")
}
private fun createInsecureTrustManager(): X509TrustManager = object : X509TrustManager {
override fun checkClientTrusted(chain: Array<X509Certificate>, authType: String) {}
private fun createInsecureTrustManager(): X509TrustManager =
object : X509TrustManager {
override fun checkClientTrusted(
chain: Array<X509Certificate>,
authType: String,
) {}
override fun checkServerTrusted(chain: Array<X509Certificate>, authType: String) {}
override fun checkServerTrusted(
chain: Array<X509Certificate>,
authType: String,
) {}
override fun getAcceptedIssuers(): Array<X509Certificate> = arrayOf()
}
@ -136,7 +149,6 @@ class Main : CliktCommand(name = NAME, help = "A curl for the next-generation we
init(null, arrayOf(trustManager), null)
}.socketFactory
private fun createInsecureHostnameVerifier(): HostnameVerifier =
HostnameVerifier { _, _ -> true }
private fun createInsecureHostnameVerifier(): HostnameVerifier = HostnameVerifier { _, _ -> true }
}
}

View File

@ -14,6 +14,7 @@
* limitations under the License.
*/
@file:Suppress("ktlint:standard:filename")
package okhttp3.curl.internal
import java.io.IOException
@ -53,7 +54,8 @@ internal fun Main.commonCreateRequest(): Request {
}
private fun Main.mediaType(): MediaType? {
val mimeType = headers?.let {
val mimeType =
headers?.let {
for (header in it) {
val parts = header.split(':', limit = 2)
if ("Content-Type".equals(parts[0], ignoreCase = true)) {

View File

@ -1,23 +1,28 @@
package okhttp3.curl.logging
import okhttp3.internal.http2.Http2
import java.util.logging.ConsoleHandler
import java.util.logging.Level
import java.util.logging.LogManager
import java.util.logging.LogRecord
import java.util.logging.Logger
import okhttp3.internal.http2.Http2
class LoggingUtil {
companion object {
private val activeLoggers = mutableListOf<Logger>()
fun configureLogging(debug: Boolean, showHttp2Frames: Boolean, sslDebug: Boolean) {
fun configureLogging(
debug: Boolean,
showHttp2Frames: Boolean,
sslDebug: Boolean,
) {
if (debug || showHttp2Frames || sslDebug) {
if (sslDebug) {
System.setProperty("javax.net.debug", "")
}
LogManager.getLogManager().reset()
val handler = object : ConsoleHandler() {
val handler =
object : ConsoleHandler() {
override fun publish(record: LogRecord) {
super.publish(record)

View File

@ -19,7 +19,8 @@ import java.util.logging.LogRecord
* Why so much construction?
*/
class OneLineLogFormat : Formatter() {
private val d = DateTimeFormatterBuilder()
private val d =
DateTimeFormatterBuilder()
.appendValue(HOUR_OF_DAY, 2)
.appendLiteral(':')
.appendValue(MINUTE_OF_HOUR, 2)

View File

@ -15,14 +15,14 @@
*/
package okhttp3.curl
import java.io.IOException
import okhttp3.RequestBody
import okio.Buffer
import assertk.assertThat
import assertk.assertions.isEqualTo
import assertk.assertions.isNull
import assertk.assertions.startsWith
import java.io.IOException
import kotlin.test.Test
import okhttp3.RequestBody
import okio.Buffer
class MainTest {
@Test
@ -34,7 +34,8 @@ class MainTest {
}
@Test
@Throws(IOException::class) fun put() {
@Throws(IOException::class)
fun put() {
val request = fromArgs("-X", "PUT", "-d", "foo", "http://example.com").createRequest()
assertThat(request.method).isEqualTo("PUT")
assertThat(request.url.toString()).isEqualTo("http://example.com/")
@ -48,7 +49,7 @@ class MainTest {
assertThat(request.method).isEqualTo("POST")
assertThat(request.url.toString()).isEqualTo("http://example.com/")
assertThat(body!!.contentType().toString()).isEqualTo(
"application/x-www-form-urlencoded; charset=utf-8"
"application/x-www-form-urlencoded; charset=utf-8",
)
assertThat(bodyAsString(body)).isEqualTo("foo")
}
@ -60,16 +61,20 @@ class MainTest {
assertThat(request.method).isEqualTo("PUT")
assertThat(request.url.toString()).isEqualTo("http://example.com/")
assertThat(body!!.contentType().toString()).isEqualTo(
"application/x-www-form-urlencoded; charset=utf-8"
"application/x-www-form-urlencoded; charset=utf-8",
)
assertThat(bodyAsString(body)).isEqualTo("foo")
}
@Test
fun contentTypeHeader() {
val request = fromArgs(
"-d", "foo", "-H", "Content-Type: application/json",
"http://example.com"
val request =
fromArgs(
"-d",
"foo",
"-H",
"Content-Type: application/json",
"http://example.com",
).createRequest()
val body = request.body
assertThat(request.method).isEqualTo("POST")
@ -105,12 +110,14 @@ class MainTest {
@Test
fun headerSplitWithDate() {
val request = fromArgs(
"-H", "If-Modified-Since: Mon, 18 Aug 2014 15:16:06 GMT",
"http://example.com"
val request =
fromArgs(
"-H",
"If-Modified-Since: Mon, 18 Aug 2014 15:16:06 GMT",
"http://example.com",
).createRequest()
assertThat(request.header("If-Modified-Since")).isEqualTo(
"Mon, 18 Aug 2014 15:16:06 GMT"
"Mon, 18 Aug 2014 15:16:06 GMT",
)
}

View File

@ -50,12 +50,14 @@ import org.junit.Test
* Run with "./gradlew :android-test:connectedCheck -PandroidBuild=true" and make sure ANDROID_SDK_ROOT is set.
*/
class AndroidAsyncDnsTest {
@JvmField @Rule val serverRule = MockWebServerRule()
@JvmField @Rule
val serverRule = MockWebServerRule()
private lateinit var client: OkHttpClient
private val localhost: HandshakeCertificates by lazy {
// Generate a self-signed cert for the server to serve and the client to trust.
val heldCertificate = HeldCertificate.Builder()
val heldCertificate =
HeldCertificate.Builder()
.addSubjectAlternativeName("localhost")
.build()
return@lazy HandshakeCertificates.Builder()
@ -69,7 +71,8 @@ class AndroidAsyncDnsTest {
fun init() {
assumeTrue("Supported on API 29+", Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q)
client = OkHttpClient.Builder()
client =
OkHttpClient.Builder()
.dns(AsyncDns.toDns(AndroidAsyncDns.IPv4, AndroidAsyncDns.IPv6))
.sslSocketFactory(localhost.sslSocketFactory(), localhost.trustManager)
.build()
@ -127,17 +130,26 @@ class AndroidAsyncDnsTest {
val latch = CountDownLatch(1)
// assumes an IPv4 address
AndroidAsyncDns.IPv4.query(hostname, object : AsyncDns.Callback {
override fun onResponse(hostname: String, addresses: List<InetAddress>) {
AndroidAsyncDns.IPv4.query(
hostname,
object : AsyncDns.Callback {
override fun onResponse(
hostname: String,
addresses: List<InetAddress>,
) {
allAddresses.addAll(addresses)
latch.countDown()
}
override fun onFailure(hostname: String, e: IOException) {
override fun onFailure(
hostname: String,
e: IOException,
) {
exception = e
latch.countDown()
}
})
},
)
latch.await()
@ -173,7 +185,8 @@ class AndroidAsyncDnsTest {
val network =
connectivityManager.activeNetwork ?: throw AssumptionViolatedException("No active network")
val client = OkHttpClient.Builder()
val client =
OkHttpClient.Builder()
.dns(AsyncDns.toDns(AndroidAsyncDns.IPv4, AndroidAsyncDns.IPv6))
.socketFactory(network.socketFactory)
.build()

View File

@ -43,20 +43,34 @@ class AndroidAsyncDns(
private val resolver = DnsResolver.getInstance()
private val executor = Executors.newSingleThreadExecutor()
override fun query(hostname: String, callback: AsyncDns.Callback) {
override fun query(
hostname: String,
callback: AsyncDns.Callback,
) {
resolver.query(
network, hostname, dnsClass.type, DnsResolver.FLAG_EMPTY, executor, null,
network,
hostname,
dnsClass.type,
DnsResolver.FLAG_EMPTY,
executor,
null,
object : DnsResolver.Callback<List<InetAddress>> {
override fun onAnswer(addresses: List<InetAddress>, rCode: Int) {
override fun onAnswer(
addresses: List<InetAddress>,
rCode: Int,
) {
callback.onResponse(hostname, addresses)
}
override fun onError(e: DnsResolver.DnsException) {
callback.onFailure(hostname, UnknownHostException(e.message).apply {
callback.onFailure(
hostname,
UnknownHostException(e.message).apply {
initCause(e)
})
}
},
)
}
},
)
}

View File

@ -43,14 +43,14 @@ import org.robolectric.annotation.Config
sdk = [30],
)
class RobolectricOkHttpClientTest {
private lateinit var context: Context
private lateinit var client: OkHttpClient
@Before
fun setUp() {
context = ApplicationProvider.getApplicationContext<Application>()
client = OkHttpClient.Builder()
client =
OkHttpClient.Builder()
.cache(Cache("/cache".toPath(), 10_000_000, FakeFileSystem()))
.build()
}
@ -61,7 +61,8 @@ class RobolectricOkHttpClientTest {
val request = Request("https://www.google.com/robots.txt".toHttpUrl())
val networkRequest = request.newBuilder()
val networkRequest =
request.newBuilder()
.build()
val call = client.newCall(networkRequest)

View File

@ -28,7 +28,8 @@ import okhttp3.brotli.internal.uncompress
object BrotliInterceptor : Interceptor {
override fun intercept(chain: Interceptor.Chain): Response {
return if (chain.request().header("Accept-Encoding") == null) {
val request = chain.request().newBuilder()
val request =
chain.request().newBuilder()
.header("Accept-Encoding", "br,gzip")
.build()

View File

@ -30,7 +30,8 @@ fun uncompress(response: Response): Response {
val body = response.body
val encoding = response.header("Content-Encoding") ?: return response
val decompressedSource = when {
val decompressedSource =
when {
encoding.equals("br", ignoreCase = true) ->
BrotliInputStream(body.source().inputStream()).source().buffer()
encoding.equals("gzip", ignoreCase = true) ->

View File

@ -43,7 +43,8 @@ class BrotliInterceptorTest {
"abe82ba64ed250a497162006824684db917963ecebe041b352a3e62d629cc97b95cac24265b175171e" +
"5cb384cd0912aeb5b5dd9555f2dd1a9b20688201"
val response = response("https://httpbin.org/brotli", s.decodeHex()) {
val response =
response("https://httpbin.org/brotli", s.decodeHex()) {
header("Content-Encoding", "br")
}
@ -63,7 +64,8 @@ class BrotliInterceptorTest {
"8abcbd54b7b6b97640c965bbfec238d9f4109ceb6edb01d66ba54d6247296441531e445970f627215b" +
"b22f1017320dd5000000"
val response = response("https://httpbin.org/gzip", s.decodeHex()) {
val response =
response("https://httpbin.org/gzip", s.decodeHex()) {
header("Content-Encoding", "gzip")
}
@ -86,7 +88,8 @@ class BrotliInterceptorTest {
@Test
fun testFailsUncompress() {
val response = response("https://httpbin.org/brotli", "bb919aaa06e8".decodeHex()) {
val response =
response("https://httpbin.org/brotli", "bb919aaa06e8".decodeHex()) {
header("Content-Encoding", "br")
}
@ -101,7 +104,8 @@ class BrotliInterceptorTest {
@Test
fun testSkipUncompressNoContentResponse() {
val response = response("https://httpbin.org/brotli", EMPTY) {
val response =
response("https://httpbin.org/brotli", EMPTY) {
header("Content-Encoding", "br")
code(204)
message("NO CONTENT")
@ -116,7 +120,7 @@ class BrotliInterceptorTest {
private fun response(
url: String,
bodyHex: ByteString,
fn: Response.Builder.() -> Unit = {}
fn: Response.Builder.() -> Unit = {},
): Response {
return Response.Builder()
.body(bodyHex.toResponseBody("text/plain".toMediaType()))

View File

@ -19,7 +19,8 @@ import okhttp3.OkHttpClient
import okhttp3.Request
fun main() {
val client = OkHttpClient.Builder()
val client =
OkHttpClient.Builder()
.addInterceptor(BrotliInterceptor)
.build()
@ -27,7 +28,10 @@ fun main() {
sendRequest("https://httpbin.org/gzip", client)
}
private fun sendRequest(url: String, client: OkHttpClient) {
private fun sendRequest(
url: String,
client: OkHttpClient,
) {
val req = Request.Builder().url(url).build()
client.newCall(req).execute().use {

View File

@ -17,23 +17,32 @@
package okhttp3
import kotlin.coroutines.resumeWithException
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.suspendCancellableCoroutine
import okio.IOException
import kotlin.coroutines.resumeWithException
@OptIn(ExperimentalCoroutinesApi::class)
suspend fun Call.executeAsync(): Response = suspendCancellableCoroutine { continuation ->
suspend fun Call.executeAsync(): Response =
suspendCancellableCoroutine { continuation ->
continuation.invokeOnCancellation {
this.cancel()
}
this.enqueue(object : Callback {
override fun onFailure(call: Call, e: IOException) {
this.enqueue(
object : Callback {
override fun onFailure(
call: Call,
e: IOException,
) {
continuation.resumeWithException(e)
}
override fun onResponse(call: Call, response: Response) {
override fun onResponse(
call: Call,
response: Response,
) {
continuation.resume(value = response, onCancellation = { call.cancel() })
}
})
},
)
}

View File

@ -80,7 +80,7 @@ class SuspendCallTest {
MockResponse.Builder()
.bodyDelay(5, TimeUnit.SECONDS)
.body("abc")
.build()
.build(),
)
val call = client.newCall(request)
@ -109,7 +109,7 @@ class SuspendCallTest {
MockResponse.Builder()
.bodyDelay(5, TimeUnit.SECONDS)
.body("abc")
.build()
.build(),
)
val call = client.newCall(request)
@ -137,7 +137,7 @@ class SuspendCallTest {
MockResponse(
body = "abc",
socketPolicy = DisconnectAfterRequest,
)
),
)
val call = client.newCall(request)

View File

@ -26,13 +26,13 @@ import okhttp3.Dns
*/
internal class BootstrapDns(
private val dnsHostname: String,
private val dnsServers: List<InetAddress>
private val dnsServers: List<InetAddress>,
) : Dns {
@Throws(UnknownHostException::class)
override fun lookup(hostname: String): List<InetAddress> {
if (this.dnsHostname != hostname) {
throw UnknownHostException(
"BootstrapDns called for $hostname instead of $dnsHostname"
"BootstrapDns called for $hostname instead of $dnsHostname",
)
}

View File

@ -51,7 +51,7 @@ class DnsOverHttps internal constructor(
@get:JvmName("includeIPv6") val includeIPv6: Boolean,
@get:JvmName("post") val post: Boolean,
@get:JvmName("resolvePrivateAddresses") val resolvePrivateAddresses: Boolean,
@get:JvmName("resolvePublicAddresses") val resolvePublicAddresses: Boolean
@get:JvmName("resolvePublicAddresses") val resolvePublicAddresses: Boolean,
) : Dns {
@Throws(UnknownHostException::class)
override fun lookup(hostname: String): List<InetAddress> {
@ -94,37 +94,46 @@ class DnsOverHttps internal constructor(
networkRequests: MutableList<Call>,
results: MutableList<InetAddress>,
failures: MutableList<Exception>,
type: Int
type: Int,
) {
val request = buildRequest(hostname, type)
val response = getCacheOnlyResponse(request)
response?.let { processResponse(it, hostname, results, failures) } ?: networkRequests.add(
client.newCall(request))
client.newCall(request),
)
}
private fun executeRequests(
hostname: String,
networkRequests: List<Call>,
responses: MutableList<InetAddress>,
failures: MutableList<Exception>
failures: MutableList<Exception>,
) {
val latch = CountDownLatch(networkRequests.size)
for (call in networkRequests) {
call.enqueue(object : Callback {
override fun onFailure(call: Call, e: IOException) {
call.enqueue(
object : Callback {
override fun onFailure(
call: Call,
e: IOException,
) {
synchronized(failures) {
failures.add(e)
}
latch.countDown()
}
override fun onResponse(call: Call, response: Response) {
override fun onResponse(
call: Call,
response: Response,
) {
processResponse(response, hostname, responses, failures)
latch.countDown()
}
})
},
)
}
try {
@ -138,7 +147,7 @@ class DnsOverHttps internal constructor(
response: Response,
hostname: String,
results: MutableList<InetAddress>,
failures: MutableList<Exception>
failures: MutableList<Exception>,
) {
try {
val addresses = readResponse(hostname, response)
@ -153,7 +162,10 @@ class DnsOverHttps internal constructor(
}
@Throws(UnknownHostException::class)
private fun throwBestFailure(hostname: String, failures: List<Exception>): List<InetAddress> {
private fun throwBestFailure(
hostname: String,
failures: List<Exception>,
): List<InetAddress> {
if (failures.isEmpty()) {
throw UnknownHostException(hostname)
}
@ -179,7 +191,8 @@ class DnsOverHttps internal constructor(
try {
// Use the cache without hitting the network first
// 504 code indicates that the Cache is stale
val preferCache = CacheControl.Builder()
val preferCache =
CacheControl.Builder()
.onlyIfCached()
.build()
val cacheRequest = request.newBuilder().cacheControl(preferCache).build()
@ -199,7 +212,10 @@ class DnsOverHttps internal constructor(
}
@Throws(Exception::class)
private fun readResponse(hostname: String, response: Response): List<InetAddress> {
private fun readResponse(
hostname: String,
response: Response,
): List<InetAddress> {
if (response.cacheResponse == null && response.protocol !== Protocol.HTTP_2) {
Platform.get().log("Incorrect protocol: ${response.protocol}", Platform.WARN)
}
@ -213,7 +229,7 @@ class DnsOverHttps internal constructor(
if (body.contentLength() > MAX_RESPONSE_SIZE) {
throw IOException(
"response size exceeds limit ($MAX_RESPONSE_SIZE bytes): ${body.contentLength()} bytes"
"response size exceeds limit ($MAX_RESPONSE_SIZE bytes): ${body.contentLength()} bytes",
)
}
@ -223,7 +239,10 @@ class DnsOverHttps internal constructor(
}
}
private fun buildRequest(hostname: String, type: Int): Request =
private fun buildRequest(
hostname: String,
type: Int,
): Request =
Request.Builder().header("Accept", DNS_MESSAGE.toString()).apply {
val query = DnsRecordCodec.encodeQuery(hostname, type)
@ -255,42 +274,49 @@ class DnsOverHttps internal constructor(
includeIPv6,
post,
resolvePrivateAddresses,
resolvePublicAddresses
resolvePublicAddresses,
)
}
fun client(client: OkHttpClient) = apply {
fun client(client: OkHttpClient) =
apply {
this.client = client
}
fun url(url: HttpUrl) = apply {
fun url(url: HttpUrl) =
apply {
this.url = url
}
fun includeIPv6(includeIPv6: Boolean) = apply {
fun includeIPv6(includeIPv6: Boolean) =
apply {
this.includeIPv6 = includeIPv6
}
fun post(post: Boolean) = apply {
fun post(post: Boolean) =
apply {
this.post = post
}
fun resolvePrivateAddresses(resolvePrivateAddresses: Boolean) = apply {
fun resolvePrivateAddresses(resolvePrivateAddresses: Boolean) =
apply {
this.resolvePrivateAddresses = resolvePrivateAddresses
}
fun resolvePublicAddresses(resolvePublicAddresses: Boolean) = apply {
fun resolvePublicAddresses(resolvePublicAddresses: Boolean) =
apply {
this.resolvePublicAddresses = resolvePublicAddresses
}
fun bootstrapDnsHosts(bootstrapDnsHosts: List<InetAddress>?) = apply {
fun bootstrapDnsHosts(bootstrapDnsHosts: List<InetAddress>?) =
apply {
this.bootstrapDnsHosts = bootstrapDnsHosts
}
fun bootstrapDnsHosts(vararg bootstrapDnsHosts: InetAddress): Builder =
bootstrapDnsHosts(bootstrapDnsHosts.toList())
fun bootstrapDnsHosts(vararg bootstrapDnsHosts: InetAddress): Builder = bootstrapDnsHosts(bootstrapDnsHosts.toList())
fun systemDns(systemDns: Dns) = apply {
fun systemDns(systemDns: Dns) =
apply {
this.systemDns = systemDns
}
}

View File

@ -33,7 +33,11 @@ internal object DnsRecordCodec {
private const val TYPE_PTR = 0x000c
private val ASCII = Charsets.US_ASCII
fun encodeQuery(host: String, type: Int): ByteString = Buffer().apply {
fun encodeQuery(
host: String,
type: Int,
): ByteString =
Buffer().apply {
writeShort(0) // query id
writeShort(256) // flags with recursion
writeShort(1) // question count
@ -57,7 +61,10 @@ internal object DnsRecordCodec {
}.readByteString()
@Throws(Exception::class)
fun decodeAnswers(hostname: String, byteString: ByteString): List<InetAddress> {
fun decodeAnswers(
hostname: String,
byteString: ByteString,
): List<InetAddress> {
val result = mutableListOf<InetAddress>()
val buf = Buffer()
@ -91,7 +98,8 @@ internal object DnsRecordCodec {
val type = buf.readShort().toInt() and 0xffff
buf.readShort() // class
@Suppress("UNUSED_VARIABLE") val ttl = buf.readInt().toLong() and 0xffffffffL // ttl
@Suppress("UNUSED_VARIABLE")
val ttl = buf.readInt().toLong() and 0xffffffffL // ttl
val length = buf.readShort().toInt() and 0xffff
if (type == TYPE_A || type == TYPE_AAAA) {

View File

@ -55,7 +55,8 @@ class DnsOverHttpsTest {
private lateinit var server: MockWebServer
private lateinit var dns: Dns
private val cacheFs = FakeFileSystem()
private val bootstrapClient = OkHttpClient.Builder()
private val bootstrapClient =
OkHttpClient.Builder()
.protocols(listOf(Protocol.HTTP_2, Protocol.HTTP_1_1))
.build()
@ -72,8 +73,8 @@ class DnsOverHttpsTest {
dnsResponse(
"0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c000500010" +
"0000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010001000" +
"0003b00049df00112"
)
"0003b00049df00112",
),
)
val result = dns.lookup("google.com")
assertThat(result).isEqualTo(listOf(address("157.240.1.18")))
@ -89,15 +90,15 @@ class DnsOverHttpsTest {
dnsResponse(
"0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c0005000" +
"100000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c0420001000" +
"10000003b00049df00112"
)
"10000003b00049df00112",
),
)
server.enqueue(
dnsResponse(
"0000818000010003000000000567726170680866616365626f6f6b03636f6d00001c0001c00c0005000" +
"100000a1b000603617069c012c0300005000100000b1f000c04737461720463313072c012c042001c000" +
"10000003b00102a032880f0290011faceb00c00000002"
)
"10000003b00102a032880f0290011faceb00c00000002",
),
)
dns = buildLocalhost(bootstrapClient, true)
val result = dns.lookup("google.com")
@ -111,7 +112,7 @@ class DnsOverHttpsTest {
assertThat(listOf(request1.path, request2.path))
.containsExactlyInAnyOrder(
"/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AAAEAAQ",
"/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AABwAAQ"
"/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AABwAAQ",
)
}
@ -121,8 +122,8 @@ class DnsOverHttpsTest {
dnsResponse(
"0000818300010000000100000e7364666c6b686673646c6b6a64660265650000010001c01b00060001000" +
"007070038026e7303746c64c01b0a686f73746d61737465720d6565737469696e7465726e6574c01b5adb1" +
"2c100000e10000003840012750000000e10"
)
"2c100000e10000003840012750000000e10",
),
)
try {
dns.lookup("google.com")
@ -179,11 +180,11 @@ class DnsOverHttpsTest {
dnsResponse(
"0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c000500010" +
"0000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010001000" +
"0003b00049df00112"
"0003b00049df00112",
)
.newBuilder()
.setHeader("cache-control", "private, max-age=298")
.build()
.build(),
)
var result = cachedDns.lookup("google.com")
assertThat(result).containsExactly(address("157.240.1.18"))
@ -204,29 +205,29 @@ class DnsOverHttpsTest {
dnsResponse(
"0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c000500010" +
"0000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010001000" +
"0003b00049df00112"
"0003b00049df00112",
)
.newBuilder()
.setHeader("cache-control", "max-age=1")
.build()
.build(),
)
var result = cachedDns.lookup("google.com")
assertThat(result).containsExactly(address("157.240.1.18"))
var recordedRequest = server.takeRequest(0, TimeUnit.SECONDS)
assertThat(recordedRequest!!.method).isEqualTo("GET")
assertThat(recordedRequest.path).isEqualTo(
"/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AAAEAAQ"
"/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AAAEAAQ",
)
Thread.sleep(2000)
server.enqueue(
dnsResponse(
"0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c000500010" +
"0000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010001000" +
"0003b00049df00112"
"0003b00049df00112",
)
.newBuilder()
.setHeader("cache-control", "max-age=1")
.build()
.build(),
)
result = cachedDns.lookup("google.com")
assertThat(result).isEqualTo(listOf(address("157.240.1.18")))
@ -244,7 +245,10 @@ class DnsOverHttpsTest {
.build()
}
private fun buildLocalhost(bootstrapClient: OkHttpClient, includeIPv6: Boolean): DnsOverHttps {
private fun buildLocalhost(
bootstrapClient: OkHttpClient,
includeIPv6: Boolean,
): DnsOverHttps {
val url = server.url("/lookup?ct")
return DnsOverHttps.Builder().client(bootstrapClient)
.includeIPv6(includeIPv6)

View File

@ -14,12 +14,12 @@
* limitations under the License.
*/
@file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")
package okhttp3.dnsoverhttps
import assertk.assertThat
import assertk.assertions.containsExactly
import assertk.assertions.isEqualTo
import assertk.fail
import java.net.InetAddress
import java.net.UnknownHostException
import kotlin.test.assertFailsWith
@ -36,7 +36,10 @@ class DnsRecordCodecTest {
assertThat(encoded).isEqualTo("AAABAAABAAAAAAAABmdvb2dsZQNjb20AAAEAAQ")
}
private fun encodeQuery(host: String, type: Int): String {
private fun encodeQuery(
host: String,
type: Int,
): String {
return DnsRecordCodec.encodeQuery(host, type).base64Url().replace("=", "")
}
@ -48,32 +51,44 @@ class DnsRecordCodecTest {
@Test
fun testGoogleDotComDecodingFromCloudflare() {
val encoded = decodeAnswers(
val encoded =
decodeAnswers(
hostname = "test.com",
byteString = ("00008180000100010000000006676f6f676c6503636f6d0000010001c00c0001000100000043" +
"0004d83ad54e").decodeHex()
byteString =
(
"00008180000100010000000006676f6f676c6503636f6d0000010001c00c0001000100000043" +
"0004d83ad54e"
).decodeHex(),
)
assertThat(encoded).containsExactly(InetAddress.getByName("216.58.213.78"))
}
@Test
fun testGoogleDotComDecodingFromGoogle() {
val decoded = decodeAnswers(
val decoded =
decodeAnswers(
hostname = "test.com",
byteString = ("0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c" +
byteString =
(
"0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c" +
"0005000100000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c0420001" +
"00010000003b00049df00112").decodeHex()
"00010000003b00049df00112"
).decodeHex(),
)
assertThat(decoded).containsExactly(InetAddress.getByName("157.240.1.18"))
}
@Test
fun testGoogleDotComDecodingFromGoogleIPv6() {
val decoded = decodeAnswers(
val decoded =
decodeAnswers(
hostname = "test.com",
byteString = ("0000818000010003000000000567726170680866616365626f6f6b03636f6d00001c0001c00c" +
byteString =
(
"0000818000010003000000000567726170680866616365626f6f6b03636f6d00001c0001c00c" +
"0005000100000a1b000603617069c012c0300005000100000b1f000c04737461720463313072c012c042001c" +
"00010000003b00102a032880f0290011faceb00c00000002").decodeHex()
"00010000003b00102a032880f0290011faceb00c00000002"
).decodeHex(),
)
assertThat(decoded)
.containsExactly(InetAddress.getByName("2a03:2880:f029:11:face:b00c:0:2"))
@ -84,9 +99,12 @@ class DnsRecordCodecTest {
assertFailsWith<UnknownHostException> {
decodeAnswers(
hostname = "sdflkhfsdlkjdf.ee",
byteString = ("0000818300010000000100000e7364666c6b686673646c6b6a64660265650000010001c01b" +
byteString =
(
"0000818300010000000100000e7364666c6b686673646c6b6a64660265650000010001c01b" +
"00060001000007070038026e7303746c64c01b0a686f73746d61737465720d6565737469696e7465726e65" +
"74c01b5adb12c100000e10000003840012750000000e10").decodeHex()
"74c01b5adb12c100000e10000003840012750000000e10"
).decodeHex(),
)
}.also { expected ->
assertThat(expected.message).isEqualTo("sdflkhfsdlkjdf.ee: NXDOMAIN")

View File

@ -24,7 +24,10 @@ import okhttp3.OkHttpClient
import okhttp3.dnsoverhttps.DohProviders.providers
import org.conscrypt.OpenSSLProvider
private fun runBatch(dnsProviders: List<DnsOverHttps>, names: List<String>) {
private fun runBatch(
dnsProviders: List<DnsOverHttps>,
names: List<String>,
) {
var time = System.currentTimeMillis()
for (dns in dnsProviders) {
println("Testing ${dns.url}")
@ -54,33 +57,38 @@ fun main() {
var names = listOf("google.com", "graph.facebook.com", "sdflkhfsdlkjdf.ee")
try {
println("uncached\n********\n")
var dnsProviders = providers(
var dnsProviders =
providers(
client = bootstrapClient,
http2Only = false,
workingOnly = false,
getOnly = false,
)
runBatch(dnsProviders, names)
val dnsCache = Cache(
val dnsCache =
Cache(
directory = File("./target/TestDohMain.cache.${System.currentTimeMillis()}"),
maxSize = 10L * 1024 * 1024
maxSize = 10L * 1024 * 1024,
)
println("Bad targets\n***********\n")
val url = "https://dns.cloudflare.com/.not-so-well-known/run-dmc-query".toHttpUrl()
val badProviders = listOf(
val badProviders =
listOf(
DnsOverHttps.Builder()
.client(bootstrapClient)
.url(url)
.post(true)
.build()
.build(),
)
runBatch(badProviders, names)
println("cached first run\n****************\n")
names = listOf("google.com", "graph.facebook.com")
bootstrapClient = bootstrapClient.newBuilder()
bootstrapClient =
bootstrapClient.newBuilder()
.cache(dnsCache)
.build()
dnsProviders = providers(
dnsProviders =
providers(
client = bootstrapClient,
http2Only = true,
workingOnly = true,
@ -88,7 +96,8 @@ fun main() {
)
runBatch(dnsProviders, names)
println("cached second run\n*****************\n")
dnsProviders = providers(
dnsProviders =
providers(
client = bootstrapClient,
http2Only = true,
workingOnly = true,

View File

@ -28,7 +28,7 @@ class HpackDecodeInteropTest : HpackDecodeTestBase() {
fun testGoodDecoderInterop(story: Story) {
assumeFalse(
story === Story.MISSING,
"Test stories missing, checkout git submodule"
"Test stories missing, checkout git submodule",
)
testDecoder(story)
}

View File

@ -36,7 +36,7 @@ open class HpackDecodeTestBase {
assertSetEquals(
"seqno=$testCase.seqno",
testCase.headersList,
hpackReader.getAndResetHeaderList()
hpackReader.getAndResetHeaderList(),
)
}
}

View File

@ -43,7 +43,7 @@ class HpackRoundTripTest : HpackDecodeTestBase() {
fun testRoundTrip(story: Story) {
assumeFalse(
story === Story.MISSING,
"Test stories missing, checkout git submodule"
"Test stories missing, checkout git submodule",
)
val newCases = mutableListOf<Case>()

View File

@ -35,11 +35,15 @@ import okio.source
*/
object HpackJsonUtil {
@Suppress("unused")
private val MOSHI = Moshi.Builder()
.add(object : Any() {
private val MOSHI =
Moshi.Builder()
.add(
object : Any() {
@ToJson fun byteStringToJson(byteString: ByteString) = byteString.hex()
@FromJson fun byteStringFromJson(json: String) = json.decodeHex()
})
},
)
.add(KotlinJsonAdapterFactory())
.build()
private val STORY_JSON_ADAPTER = MOSHI.adapter(Story::class.java)
@ -58,7 +62,8 @@ object HpackJsonUtil {
/** Iterate through the hpack-test-case resources, only picking stories for the current draft. */
fun storiesForCurrentDraft(): Array<String> {
val resource = HpackJsonUtil::class.java.getResource("/hpack-test-case")
val resource =
HpackJsonUtil::class.java.getResource("/hpack-test-case")
?: return arrayOf()
val testCaseDirectory = File(resource.toURI()).toOkioPath()
@ -83,16 +88,19 @@ object HpackJsonUtil {
val result = mutableListOf<Story>()
var i = 0
while (true) { // break after last test.
val storyResourceName = String.format(
val storyResourceName =
String.format(
"/hpack-test-case/%s/story_%02d.json",
testFolderName,
i,
)
val storyInputStream = HpackJsonUtil::class.java.getResourceAsStream(storyResourceName)
val storyInputStream =
HpackJsonUtil::class.java.getResourceAsStream(storyResourceName)
?: break
try {
storyInputStream.use {
val story = readStory(storyInputStream.source().buffer())
val story =
readStory(storyInputStream.source().buffer())
.copy(fileName = storyResourceName)
result.add(story)
i++

View File

@ -24,7 +24,6 @@ data class Story(
val cases: List<Case>,
val fileName: String? = null,
) {
// Used as the test name.
override fun toString() = fileName ?: "?"

View File

@ -33,7 +33,8 @@ fun main(vararg args: String) {
fun loadIdnaMappingTableData(): IdnaMappingTableData {
val path = "/okhttp3/internal/idna/IdnaMappingTable.txt".toPath()
val table = FileSystem.RESOURCES.read(path) {
val table =
FileSystem.RESOURCES.read(path) {
readPlainTextIdnaMappingTable()
}
return buildIdnaMappingTableData(table)
@ -71,7 +72,7 @@ fun generateMappingTableFile(data: IdnaMappingTableData): FileSpec {
data.ranges.escapeDataString(),
data.mappings.escapeDataString(),
)
.build()
.build(),
)
.build()
}
@ -89,7 +90,8 @@ fun String.escapeDataString(): String {
'$'.code,
'\\'.code,
'·'.code,
127 -> append(String.format("\\u%04x", codePoint))
127,
-> append(String.format("\\u%04x", codePoint))
else -> appendCodePoint(codePoint)
}

View File

@ -23,10 +23,11 @@ internal sealed interface MappedRange {
data class Constant(
override val rangeStart: Int,
val type: Int
val type: Int,
) : MappedRange {
val b1: Int
get() = when (type) {
get() =
when (type) {
TYPE_IGNORED -> 119
TYPE_VALID -> 120
TYPE_DISALLOWED -> 121
@ -36,9 +37,8 @@ internal sealed interface MappedRange {
data class Inline1(
override val rangeStart: Int,
private val mappedTo: ByteString
private val mappedTo: ByteString,
) : MappedRange {
val b1: Int
get() {
val b3bit8 = mappedTo[0] and 0x80 != 0
@ -51,9 +51,8 @@ internal sealed interface MappedRange {
data class Inline2(
override val rangeStart: Int,
private val mappedTo: ByteString
private val mappedTo: ByteString,
) : MappedRange {
val b1: Int
get() {
val b2bit8 = mappedTo[0] and 0x80 != 0
@ -75,13 +74,13 @@ internal sealed interface MappedRange {
data class InlineDelta(
override val rangeStart: Int,
val codepointDelta: Int
val codepointDelta: Int,
) : MappedRange {
private val absoluteDelta = abs(codepointDelta)
val b1: Int
get() = when {
get() =
when {
codepointDelta < 0 -> 0x40 or (absoluteDelta shr 14)
codepointDelta > 0 -> 0x50 or (absoluteDelta shr 14)
else -> error("Unexpected codepointDelta of 0")
@ -100,6 +99,6 @@ internal sealed interface MappedRange {
data class External(
override val rangeStart: Int,
val mappedTo: ByteString
val mappedTo: ByteString,
) : MappedRange
}

View File

@ -135,8 +135,10 @@ internal fun sections(mappings: List<Mapping>): Map<Int, List<MappedRange>> {
val sectionList = result.getOrPut(section) { mutableListOf() }
sectionList += when (mapping.type) {
TYPE_MAPPED -> run {
sectionList +=
when (mapping.type) {
TYPE_MAPPED ->
run {
val deltaMapping = inlineDeltaOrNull(mapping)
if (deltaMapping != null) {
return@run deltaMapping
@ -202,13 +204,15 @@ internal fun withoutSectionSpans(mappings: List<Mapping>): List<Mapping> {
while (true) {
if (current.spansSections) {
result += Mapping(
result +=
Mapping(
current.sourceCodePoint0,
current.section + 0x7f,
current.type,
current.mappedTo,
)
current = Mapping(
current =
Mapping(
current.section + 0x80,
current.sourceCodePoint1,
current.type,
@ -246,7 +250,8 @@ internal fun mergeAdjacentRanges(mappings: List<Mapping>): List<Mapping> {
index++
}
result += Mapping(
result +=
Mapping(
sourceCodePoint0 = mapping.sourceCodePoint0,
sourceCodePoint1 = unionWith.sourceCodePoint1,
type = type,
@ -262,11 +267,13 @@ internal fun canonicalizeType(type: Int): Int {
TYPE_IGNORED -> TYPE_IGNORED
TYPE_MAPPED,
TYPE_DISALLOWED_STD3_MAPPED -> TYPE_MAPPED
TYPE_DISALLOWED_STD3_MAPPED,
-> TYPE_MAPPED
TYPE_DEVIATION,
TYPE_DISALLOWED_STD3_VALID,
TYPE_VALID -> TYPE_VALID
TYPE_VALID,
-> TYPE_VALID
TYPE_DISALLOWED -> TYPE_DISALLOWED
@ -279,4 +286,3 @@ internal infix fun Byte.and(mask: Int): Int = toInt() and mask
internal infix fun Short.and(mask: Int): Int = toInt() and mask
internal infix fun Int.and(mask: Long): Long = toLong() and mask

View File

@ -42,8 +42,12 @@ class SimpleIdnaMappingTable internal constructor(
/**
* Returns true if the [codePoint] was applied successfully. Returns false if it was disallowed.
*/
fun map(codePoint: Int, sink: BufferedSink): Boolean {
val index = mappings.binarySearch {
fun map(
codePoint: Int,
sink: BufferedSink,
): Boolean {
val index =
mappings.binarySearch {
when {
it.sourceCodePoint1 < codePoint -> -1
it.sourceCodePoint0 > codePoint -> 1
@ -77,8 +81,8 @@ class SimpleIdnaMappingTable internal constructor(
}
}
private val optionsDelimiter = Options.of(
private val optionsDelimiter =
Options.of(
// 0.
".".encodeUtf8(),
// 1.
@ -91,7 +95,8 @@ private val optionsDelimiter = Options.of(
"\n".encodeUtf8(),
)
private val optionsDot = Options.of(
private val optionsDot =
Options.of(
// 0.
".".encodeUtf8(),
)
@ -102,7 +107,8 @@ private const val DELIMITER_SEMICOLON = 2
private const val DELIMITER_HASH = 3
private const val DELIMITER_NEWLINE = 4
private val optionsType = Options.of(
private val optionsType =
Options.of(
// 0.
"deviation ".encodeUtf8(),
// 1.
@ -182,7 +188,8 @@ fun BufferedSource.readPlainTextIdnaMappingTable(): SimpleIdnaMappingTable {
// "002F" or "0000..002C"
val sourceCodePoint0 = readHexadecimalUnsignedLong()
val sourceCodePoint1 = when (select(optionsDot)) {
val sourceCodePoint1 =
when (select(optionsDot)) {
DELIMITER_DOT -> {
if (readByte() != '.'.code.toByte()) throw IOException("expected '..'")
readHexadecimalUnsignedLong()
@ -228,7 +235,8 @@ fun BufferedSource.readPlainTextIdnaMappingTable(): SimpleIdnaMappingTable {
skipRestOfLine()
result += Mapping(
result +=
Mapping(
sourceCodePoint0.toInt(),
sourceCodePoint1.toInt(),
type,

View File

@ -34,8 +34,8 @@ class MappingTablesTest {
Mapping(0x0234, 0x0236, TYPE_VALID, ByteString.EMPTY),
Mapping(0x0237, 0x0239, TYPE_VALID, ByteString.EMPTY),
Mapping(0x023a, 0x023a, TYPE_MAPPED, "b".encodeUtf8()),
)
)
),
),
).containsExactly(
Mapping(0x0232, 0x0232, TYPE_MAPPED, "a".encodeUtf8()),
Mapping(0x0233, 0x0239, TYPE_VALID, ByteString.EMPTY),
@ -49,8 +49,8 @@ class MappingTablesTest {
listOf(
Mapping(0x0041, 0x0041, TYPE_MAPPED, "a".encodeUtf8()),
Mapping(0x0042, 0x0042, TYPE_MAPPED, "b".encodeUtf8()),
)
)
),
),
).containsExactly(
Mapping(0x0041, 0x0041, TYPE_MAPPED, "a".encodeUtf8()),
Mapping(0x0042, 0x0042, TYPE_MAPPED, "b".encodeUtf8()),
@ -62,8 +62,8 @@ class MappingTablesTest {
mergeAdjacentRanges(
listOf(
Mapping(0x0000, 0x002c, TYPE_DISALLOWED_STD3_VALID, ByteString.EMPTY),
)
)
),
),
).containsExactly(
Mapping(0x0000, 0x002c, TYPE_VALID, ByteString.EMPTY),
)
@ -74,9 +74,9 @@ class MappingTablesTest {
mergeAdjacentRanges(
listOf(
Mapping(0x0000, 0x002c, TYPE_DISALLOWED_STD3_VALID, ByteString.EMPTY),
Mapping(0x002d, 0x002e, TYPE_VALID, ByteString.EMPTY)
)
)
Mapping(0x002d, 0x002e, TYPE_VALID, ByteString.EMPTY),
),
),
).containsExactly(
Mapping(0x0000, 0x002e, TYPE_VALID, ByteString.EMPTY),
)
@ -87,8 +87,8 @@ class MappingTablesTest {
withoutSectionSpans(
listOf(
Mapping(0x40000, 0x40180, TYPE_DISALLOWED, ByteString.EMPTY),
)
)
),
),
).containsExactly(
Mapping(0x40000, 0x4007f, TYPE_DISALLOWED, ByteString.EMPTY),
Mapping(0x40080, 0x400ff, TYPE_DISALLOWED, ByteString.EMPTY),
@ -103,8 +103,8 @@ class MappingTablesTest {
listOf(
Mapping(0x40000, 0x4007f, TYPE_DISALLOWED, ByteString.EMPTY),
Mapping(0x40080, 0x400ff, TYPE_DISALLOWED, ByteString.EMPTY),
)
)
),
),
).containsExactly(
Mapping(0x40000, 0x4007f, TYPE_DISALLOWED, ByteString.EMPTY),
Mapping(0x40080, 0x400ff, TYPE_DISALLOWED, ByteString.EMPTY),
@ -119,8 +119,8 @@ class MappingTablesTest {
InlineDelta(2, 5),
InlineDelta(3, 5),
MappedRange.External(4, "a".encodeUtf8()),
)
)
),
),
).containsExactly(
InlineDelta(1, 5),
MappedRange.External(4, "a".encodeUtf8()),
@ -134,8 +134,8 @@ class MappingTablesTest {
InlineDelta(1, 5),
InlineDelta(2, 5),
InlineDelta(3, 1),
)
)
),
),
).containsExactly(
InlineDelta(1, 5),
InlineDelta(3, 1),
@ -148,9 +148,9 @@ class MappingTablesTest {
mappingOf(
sourceCodePoint0 = 1,
sourceCodePoint1 = 1,
mappedToCodePoints = listOf(2)
)
)
mappedToCodePoints = listOf(2),
),
),
).isEqualTo(InlineDelta(1, 1))
assertThat(
@ -158,9 +158,9 @@ class MappingTablesTest {
mappingOf(
sourceCodePoint0 = 2,
sourceCodePoint1 = 2,
mappedToCodePoints = listOf(1)
)
)
mappedToCodePoints = listOf(1),
),
),
).isEqualTo(InlineDelta(2, -1))
}
@ -170,9 +170,9 @@ class MappingTablesTest {
mappingOf(
sourceCodePoint0 = 2,
sourceCodePoint1 = 3,
mappedToCodePoints = listOf(2)
)
)
mappedToCodePoints = listOf(2),
),
),
).isEqualTo(null)
}
@ -182,9 +182,9 @@ class MappingTablesTest {
mappingOf(
sourceCodePoint0 = 1,
sourceCodePoint1 = 1,
mappedToCodePoints = listOf(2, 3)
)
)
mappedToCodePoints = listOf(2, 3),
),
),
).isEqualTo(null)
}
@ -194,14 +194,14 @@ class MappingTablesTest {
mappingOf(
sourceCodePoint0 = 0,
sourceCodePoint1 = 0,
mappedToCodePoints = listOf((1 shl 18) - 1)
)
)
mappedToCodePoints = listOf((1 shl 18) - 1),
),
),
).isEqualTo(
InlineDelta(
rangeStart = 0,
codepointDelta = InlineDelta.MAX_VALUE
)
codepointDelta = InlineDelta.MAX_VALUE,
),
)
assertThat(
@ -209,24 +209,26 @@ class MappingTablesTest {
mappingOf(
sourceCodePoint0 = 0,
sourceCodePoint1 = 0,
mappedToCodePoints = listOf(1 shl 18)
)
)
mappedToCodePoints = listOf(1 shl 18),
),
),
).isEqualTo(null)
}
private fun mappingOf(
sourceCodePoint0: Int,
sourceCodePoint1: Int,
mappedToCodePoints: List<Int>
): Mapping = Mapping(
mappedToCodePoints: List<Int>,
): Mapping =
Mapping(
sourceCodePoint0 = sourceCodePoint0,
sourceCodePoint1 = sourceCodePoint1,
type = TYPE_MAPPED,
mappedTo = Buffer().also {
mappedTo =
Buffer().also {
for (cp in mappedToCodePoints) {
it.writeUtf8CodePoint(cp)
}
}.readByteString()
}.readByteString(),
)
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
@file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")
package okhttp3.java.net.cookiejar
package okhttp3.java.net.cookiejar
import java.io.IOException
import java.net.CookieHandler
@ -32,8 +32,10 @@ import okhttp3.internal.trimSubstring
/** A cookie jar that delegates to a [java.net.CookieHandler]. */
class JavaNetCookieJar(private val cookieHandler: CookieHandler) : CookieJar {
override fun saveFromResponse(url: HttpUrl, cookies: List<Cookie>) {
override fun saveFromResponse(
url: HttpUrl,
cookies: List<Cookie>,
) {
val cookieStrings = mutableListOf<String>()
for (cookie in cookies) {
cookieStrings.add(cookieToString(cookie, true))
@ -47,7 +49,8 @@ class JavaNetCookieJar(private val cookieHandler: CookieHandler) : CookieJar {
}
override fun loadForRequest(url: HttpUrl): List<Cookie> {
val cookieHeaders = try {
val cookieHeaders =
try {
// The RI passes all headers. We don't have 'em, so we don't pass 'em!
cookieHandler.get(url.toUri(), emptyMap<String, List<String>>())
} catch (e: IOException) {
@ -58,7 +61,8 @@ class JavaNetCookieJar(private val cookieHandler: CookieHandler) : CookieJar {
var cookies: MutableList<Cookie>? = null
for ((key, value) in cookieHeaders) {
if (("Cookie".equals(key, ignoreCase = true) || "Cookie2".equals(key, ignoreCase = true)) &&
value.isNotEmpty()) {
value.isNotEmpty()
) {
for (header in value) {
if (cookies == null) cookies = mutableListOf()
cookies.addAll(decodeHeaderAsJavaNetCookies(url, header))
@ -77,7 +81,10 @@ class JavaNetCookieJar(private val cookieHandler: CookieHandler) : CookieJar {
* Convert a request header to OkHttp's cookies via [HttpCookie]. That extra step handles
* multiple cookies in a single request header, which [Cookie.parse] doesn't support.
*/
private fun decodeHeaderAsJavaNetCookies(url: HttpUrl, header: String): List<Cookie> {
private fun decodeHeaderAsJavaNetCookies(
url: HttpUrl,
header: String,
): List<Cookie> {
val result = mutableListOf<Cookie>()
var pos = 0
val limit = header.length
@ -92,7 +99,8 @@ class JavaNetCookieJar(private val cookieHandler: CookieHandler) : CookieJar {
}
// We have either name=value or just a name.
var value = if (equalsSign < pairEnd) {
var value =
if (equalsSign < pairEnd) {
header.trimSubstring(equalsSign + 1, pairEnd)
} else {
""
@ -103,11 +111,13 @@ class JavaNetCookieJar(private val cookieHandler: CookieHandler) : CookieJar {
value = value.substring(1, value.length - 1)
}
result.add(Cookie.Builder()
result.add(
Cookie.Builder()
.name(name)
.value(value)
.domain(url.host)
.build())
.build(),
)
pos = pairEnd + 1
}
return result

View File

@ -14,6 +14,7 @@
* limitations under the License.
*/
@file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")
package okhttp3.logging
import java.io.IOException
@ -38,14 +39,16 @@ import okio.GzipSource
* The format of the logs created by this class should not be considered stable and may
* change slightly between releases. If you need a stable logging format, use your own interceptor.
*/
class HttpLoggingInterceptor @JvmOverloads constructor(
private val logger: Logger = Logger.DEFAULT
class HttpLoggingInterceptor
@JvmOverloads
constructor(
private val logger: Logger = Logger.DEFAULT,
) : Interceptor {
@Volatile private var headersToRedact = emptySet<String>()
@set:JvmName("level")
@Volatile var level = Level.NONE
@Volatile
var level = Level.NONE
enum class Level {
/** No logs. */
@ -103,7 +106,7 @@ class HttpLoggingInterceptor @JvmOverloads constructor(
* <-- END HTTP
* ```
*/
BODY
BODY,
}
fun interface Logger {
@ -113,6 +116,7 @@ class HttpLoggingInterceptor @JvmOverloads constructor(
/** A [Logger] defaults output appropriate for the current platform. */
@JvmField
val DEFAULT: Logger = DefaultLogger()
private class DefaultLogger : Logger {
override fun log(message: String) {
Platform.get().log(message)
@ -135,7 +139,8 @@ class HttpLoggingInterceptor @JvmOverloads constructor(
* un-deprecated because Java callers can't chain when assigning Kotlin vals. (The getter remains
* deprecated).
*/
fun setLevel(level: Level) = apply {
fun setLevel(level: Level) =
apply {
this.level = level
}
@ -143,7 +148,7 @@ class HttpLoggingInterceptor @JvmOverloads constructor(
@Deprecated(
message = "moved to var",
replaceWith = ReplaceWith(expression = "level"),
level = DeprecationLevel.ERROR
level = DeprecationLevel.ERROR,
)
fun getLevel(): Level = level
@ -217,7 +222,7 @@ class HttpLoggingInterceptor @JvmOverloads constructor(
logger.log("")
if (!buffer.isProbablyUtf8()) {
logger.log(
"--> END ${request.method} (binary ${requestBody.contentLength()}-byte body omitted)"
"--> END ${request.method} (binary ${requestBody.contentLength()}-byte body omitted)",
)
} else if (gzippedLength != null) {
logger.log("--> END ${request.method} (${buffer.size}-byte, $gzippedLength-gzipped-byte body)")
@ -249,7 +254,7 @@ class HttpLoggingInterceptor @JvmOverloads constructor(
append(" ${response.request.url} (${tookMs}ms")
if (!logHeaders) append(", $bodySize body")
append(")")
}
},
)
if (logHeaders) {
@ -299,7 +304,7 @@ class HttpLoggingInterceptor @JvmOverloads constructor(
append("<-- END HTTP (${totalMs}ms, ${buffer.size}-byte")
if (gzippedLength != null) append(", $gzippedLength-gzipped-byte")
append(" body)")
}
},
)
}
}
@ -307,7 +312,10 @@ class HttpLoggingInterceptor @JvmOverloads constructor(
return response
}
private fun logHeader(headers: Headers, i: Int) {
private fun logHeader(
headers: Headers,
i: Int,
) {
val value = if (headers.name(i) in headersToRedact) "██" else headers.value(i)
logger.log(headers.name(i) + ": " + value)
}

View File

@ -38,7 +38,7 @@ import okhttp3.Response
* slightly between releases. If you need a stable logging format, use your own event listener.
*/
class LoggingEventListener private constructor(
private val logger: HttpLoggingInterceptor.Logger
private val logger: HttpLoggingInterceptor.Logger,
) : EventListener() {
private var startNs: Long = 0
@ -48,23 +48,41 @@ class LoggingEventListener private constructor(
logWithTime("callStart: ${call.request()}")
}
override fun proxySelectStart(call: Call, url: HttpUrl) {
override fun proxySelectStart(
call: Call,
url: HttpUrl,
) {
logWithTime("proxySelectStart: $url")
}
override fun proxySelectEnd(call: Call, url: HttpUrl, proxies: List<Proxy>) {
override fun proxySelectEnd(
call: Call,
url: HttpUrl,
proxies: List<Proxy>,
) {
logWithTime("proxySelectEnd: $proxies")
}
override fun dnsStart(call: Call, domainName: String) {
override fun dnsStart(
call: Call,
domainName: String,
) {
logWithTime("dnsStart: $domainName")
}
override fun dnsEnd(call: Call, domainName: String, inetAddressList: List<InetAddress>) {
override fun dnsEnd(
call: Call,
domainName: String,
inetAddressList: List<InetAddress>,
) {
logWithTime("dnsEnd: $inetAddressList")
}
override fun connectStart(call: Call, inetSocketAddress: InetSocketAddress, proxy: Proxy) {
override fun connectStart(
call: Call,
inetSocketAddress: InetSocketAddress,
proxy: Proxy,
) {
logWithTime("connectStart: $inetSocketAddress $proxy")
}
@ -72,7 +90,10 @@ class LoggingEventListener private constructor(
logWithTime("secureConnectStart")
}
override fun secureConnectEnd(call: Call, handshake: Handshake?) {
override fun secureConnectEnd(
call: Call,
handshake: Handshake?,
) {
logWithTime("secureConnectEnd: $handshake")
}
@ -80,7 +101,7 @@ class LoggingEventListener private constructor(
call: Call,
inetSocketAddress: InetSocketAddress,
proxy: Proxy,
protocol: Protocol?
protocol: Protocol?,
) {
logWithTime("connectEnd: $protocol")
}
@ -90,16 +111,22 @@ class LoggingEventListener private constructor(
inetSocketAddress: InetSocketAddress,
proxy: Proxy,
protocol: Protocol?,
ioe: IOException
ioe: IOException,
) {
logWithTime("connectFailed: $protocol $ioe")
}
override fun connectionAcquired(call: Call, connection: Connection) {
override fun connectionAcquired(
call: Call,
connection: Connection,
) {
logWithTime("connectionAcquired: $connection")
}
override fun connectionReleased(call: Call, connection: Connection) {
override fun connectionReleased(
call: Call,
connection: Connection,
) {
logWithTime("connectionReleased")
}
@ -107,7 +134,10 @@ class LoggingEventListener private constructor(
logWithTime("requestHeadersStart")
}
override fun requestHeadersEnd(call: Call, request: Request) {
override fun requestHeadersEnd(
call: Call,
request: Request,
) {
logWithTime("requestHeadersEnd")
}
@ -115,11 +145,17 @@ class LoggingEventListener private constructor(
logWithTime("requestBodyStart")
}
override fun requestBodyEnd(call: Call, byteCount: Long) {
override fun requestBodyEnd(
call: Call,
byteCount: Long,
) {
logWithTime("requestBodyEnd: byteCount=$byteCount")
}
override fun requestFailed(call: Call, ioe: IOException) {
override fun requestFailed(
call: Call,
ioe: IOException,
) {
logWithTime("requestFailed: $ioe")
}
@ -127,7 +163,10 @@ class LoggingEventListener private constructor(
logWithTime("responseHeadersStart")
}
override fun responseHeadersEnd(call: Call, response: Response) {
override fun responseHeadersEnd(
call: Call,
response: Response,
) {
logWithTime("responseHeadersEnd: $response")
}
@ -135,11 +174,17 @@ class LoggingEventListener private constructor(
logWithTime("responseBodyStart")
}
override fun responseBodyEnd(call: Call, byteCount: Long) {
override fun responseBodyEnd(
call: Call,
byteCount: Long,
) {
logWithTime("responseBodyEnd: byteCount=$byteCount")
}
override fun responseFailed(call: Call, ioe: IOException) {
override fun responseFailed(
call: Call,
ioe: IOException,
) {
logWithTime("responseFailed: $ioe")
}
@ -147,7 +192,10 @@ class LoggingEventListener private constructor(
logWithTime("callEnd")
}
override fun callFailed(call: Call, ioe: IOException) {
override fun callFailed(
call: Call,
ioe: IOException,
) {
logWithTime("callFailed: $ioe")
}
@ -155,11 +203,17 @@ class LoggingEventListener private constructor(
logWithTime("canceled")
}
override fun satisfactionFailure(call: Call, response: Response) {
override fun satisfactionFailure(
call: Call,
response: Response,
) {
logWithTime("satisfactionFailure: $response")
}
override fun cacheHit(call: Call, response: Response) {
override fun cacheHit(
call: Call,
response: Response,
) {
logWithTime("cacheHit: $response")
}
@ -167,7 +221,10 @@ class LoggingEventListener private constructor(
logWithTime("cacheMiss")
}
override fun cacheConditionalHit(call: Call, cachedResponse: Response) {
override fun cacheConditionalHit(
call: Call,
cachedResponse: Response,
) {
logWithTime("cacheConditionalHit: $cachedResponse")
}
@ -176,8 +233,10 @@ class LoggingEventListener private constructor(
logger.log("[$timeMs ms] $message")
}
open class Factory @JvmOverloads constructor(
private val logger: HttpLoggingInterceptor.Logger = HttpLoggingInterceptor.Logger.DEFAULT
open class Factory
@JvmOverloads
constructor(
private val logger: HttpLoggingInterceptor.Logger = HttpLoggingInterceptor.Logger.DEFAULT,
) : EventListener.Factory {
override fun create(call: Call): EventListener = LoggingEventListener(logger)
}

View File

@ -72,13 +72,16 @@ class HttpLoggingInterceptorTest {
@BeforeEach
fun setUp(server: MockWebServer) {
this.server = server
client = OkHttpClient.Builder()
.addNetworkInterceptor(Interceptor { chain ->
client =
OkHttpClient.Builder()
.addNetworkInterceptor(
Interceptor { chain ->
when {
extraNetworkInterceptor != null -> extraNetworkInterceptor!!.intercept(chain)
else -> chain.proceed(chain.request())
}
})
},
)
.addNetworkInterceptor(networkInterceptor)
.addInterceptor(applicationInterceptor)
.sslSocketFactory(
@ -153,7 +156,7 @@ class HttpLoggingInterceptorTest {
MockResponse.Builder()
.body("Hello!")
.setHeader("Content-Type", PLAIN)
.build()
.build(),
)
val response = client.newCall(request().build()).execute()
response.body.close()
@ -174,7 +177,7 @@ class HttpLoggingInterceptorTest {
MockResponse.Builder()
.chunkedBody("Hello!", 2)
.setHeader("Content-Type", PLAIN)
.build()
.build(),
)
val response = client.newCall(request().build()).execute()
response.body.close()
@ -278,7 +281,8 @@ class HttpLoggingInterceptorTest {
fun headersPostNoLength() {
setLevel(Level.HEADERS)
server.enqueue(MockResponse())
val body: RequestBody = object : RequestBody() {
val body: RequestBody =
object : RequestBody() {
override fun contentType() = PLAIN
override fun writeTo(sink: BufferedSink) {
@ -313,20 +317,21 @@ class HttpLoggingInterceptorTest {
@Test
fun headersPostWithHeaderOverrides() {
setLevel(Level.HEADERS)
extraNetworkInterceptor = Interceptor { chain: Interceptor.Chain ->
extraNetworkInterceptor =
Interceptor { chain: Interceptor.Chain ->
chain.proceed(
chain.request()
.newBuilder()
.header("Content-Length", "2")
.header("Content-Type", "text/plain-ish")
.build()
.build(),
)
}
server.enqueue(MockResponse())
client.newCall(
request()
.post("Hi?".toRequestBody(PLAIN))
.build()
.build(),
).execute()
applicationLogs
.assertLogEqual("--> POST $url")
@ -359,7 +364,7 @@ class HttpLoggingInterceptorTest {
MockResponse.Builder()
.body("Hello!")
.setHeader("Content-Type", PLAIN)
.build()
.build(),
)
val response = client.newCall(request().build()).execute()
response.body.close()
@ -427,7 +432,7 @@ class HttpLoggingInterceptorTest {
server.enqueue(
MockResponse.Builder()
.status("HTTP/1.1 $code No Content")
.build()
.build(),
)
val response = client.newCall(request().build()).execute()
response.body.close()
@ -493,7 +498,7 @@ class HttpLoggingInterceptorTest {
MockResponse.Builder()
.body("Hello!")
.setHeader("Content-Type", PLAIN)
.build()
.build(),
)
val response = client.newCall(request().build()).execute()
response.body.close()
@ -530,7 +535,7 @@ class HttpLoggingInterceptorTest {
MockResponse.Builder()
.chunkedBody("Hello!", 2)
.setHeader("Content-Type", PLAIN)
.build()
.build(),
)
val response = client.newCall(request().build()).execute()
response.body.close()
@ -567,13 +572,14 @@ class HttpLoggingInterceptorTest {
MockResponse.Builder()
.setHeader("Content-Type", PLAIN)
.body(Buffer().writeUtf8("Uncompressed"))
.build()
.build(),
)
val response = client.newCall(
val response =
client.newCall(
request()
.addHeader("Content-Encoding", "gzip")
.post("Uncompressed".toRequestBody().gzip())
.build()
.build(),
).execute()
val responseBody = response.body
assertThat(responseBody.string(), "Expected response body to be valid")
@ -606,7 +612,7 @@ class HttpLoggingInterceptorTest {
.setHeader("Content-Encoding", "gzip")
.setHeader("Content-Type", PLAIN)
.body(Buffer().write("H4sIAAAAAAAAAPNIzcnJ11HwQKIAdyO+9hMAAAA=".decodeBase64()!!))
.build()
.build(),
)
val response = client.newCall(request().build()).execute()
val responseBody = response.body
@ -647,7 +653,7 @@ class HttpLoggingInterceptorTest {
.setHeader("Content-Encoding", "br")
.setHeader("Content-Type", PLAIN)
.body(Buffer().write("iwmASGVsbG8sIEhlbGxvLCBIZWxsbwoD".decodeBase64()!!))
.build()
.build(),
)
val response = client.newCall(request().build()).execute()
response.body.close()
@ -693,9 +699,10 @@ class HttpLoggingInterceptorTest {
|data: 113411
|
|
""".trimMargin(), 8
""".trimMargin(),
8,
)
.build()
.build(),
)
val response = client.newCall(request().build()).execute()
response.body.close()
@ -728,7 +735,7 @@ class HttpLoggingInterceptorTest {
MockResponse.Builder()
.setHeader("Content-Type", "text/html; charset=0")
.body("Body with unknown charset")
.build()
.build(),
)
val response = client.newCall(request().build()).execute()
response.body.close()
@ -774,7 +781,7 @@ class HttpLoggingInterceptorTest {
MockResponse.Builder()
.body(buffer)
.setHeader("Content-Type", "image/png; charset=utf-8")
.build()
.build(),
)
val response = client.newCall(request().build()).execute()
response.body.close()
@ -805,7 +812,8 @@ class HttpLoggingInterceptorTest {
@Test
fun connectFail() {
setLevel(Level.BASIC)
client = OkHttpClient.Builder()
client =
OkHttpClient.Builder()
.dns { hostname: String? -> throw UnknownHostException("reason") }
.addInterceptor(applicationInterceptor)
.build()
@ -840,29 +848,33 @@ class HttpLoggingInterceptorTest {
@Test
fun headersAreRedacted() {
val networkInterceptor = HttpLoggingInterceptor(networkLogs).setLevel(
Level.HEADERS
val networkInterceptor =
HttpLoggingInterceptor(networkLogs).setLevel(
Level.HEADERS,
)
networkInterceptor.redactHeader("sEnSiTiVe")
val applicationInterceptor = HttpLoggingInterceptor(applicationLogs).setLevel(
Level.HEADERS
val applicationInterceptor =
HttpLoggingInterceptor(applicationLogs).setLevel(
Level.HEADERS,
)
applicationInterceptor.redactHeader("sEnSiTiVe")
client = OkHttpClient.Builder()
client =
OkHttpClient.Builder()
.addNetworkInterceptor(networkInterceptor)
.addInterceptor(applicationInterceptor)
.build()
server.enqueue(
MockResponse.Builder()
.addHeader("SeNsItIvE", "Value").addHeader("Not-Sensitive", "Value")
.build()
.build(),
)
val response = client
val response =
client
.newCall(
request()
.addHeader("SeNsItIvE", "Value")
.addHeader("Not-Sensitive", "Value")
.build()
.build(),
)
.execute()
response.body.close()
@ -903,9 +915,10 @@ class HttpLoggingInterceptorTest {
server.enqueue(
MockResponse.Builder()
.body("Hello response!")
.build()
.build(),
)
val asyncRequestBody: RequestBody = object : RequestBody() {
val asyncRequestBody: RequestBody =
object : RequestBody() {
override fun contentType(): MediaType? {
return null
}
@ -919,7 +932,8 @@ class HttpLoggingInterceptorTest {
return true
}
}
val request = request()
val request =
request()
.post(asyncRequestBody)
.build()
val response = client.newCall(request).execute()
@ -943,9 +957,10 @@ class HttpLoggingInterceptorTest {
server.enqueue(
MockResponse.Builder()
.body("Hello response!")
.build()
.build(),
)
val asyncRequestBody: RequestBody = object : RequestBody() {
val asyncRequestBody: RequestBody =
object : RequestBody() {
var counter = 0
override fun contentType() = null
@ -959,7 +974,8 @@ class HttpLoggingInterceptorTest {
override fun isOneShot() = true
}
val request = request()
val request =
request()
.post(asyncRequestBody)
.build()
val response = client.newCall(request).execute()
@ -985,14 +1001,16 @@ class HttpLoggingInterceptorTest {
private val logs = mutableListOf<String>()
private var index = 0
fun assertLogEqual(expected: String) = apply {
fun assertLogEqual(expected: String) =
apply {
assertThat(index, "No more messages found")
.isLessThan(logs.size)
assertThat(logs[index++]).isEqualTo(expected)
return this
}
fun assertLogMatch(regex: Regex) = apply {
fun assertLogMatch(regex: Regex) =
apply {
assertThat(index, "No more messages found")
.isLessThan(logs.size)
assertThat(logs[index++])

View File

@ -50,8 +50,9 @@ class LoggingEventListenerTest {
val clientTestRule = OkHttpClientTestRule()
private lateinit var server: MockWebServer
private val handshakeCertificates = platform.localhostHandshakeCertificates()
private val logRecorder = HttpLoggingInterceptorTest.LogRecorder(
prefix = Regex("""\[\d+ ms] """)
private val logRecorder =
HttpLoggingInterceptorTest.LogRecorder(
prefix = Regex("""\[\d+ ms] """),
)
private val loggingEventListenerFactory = LoggingEventListener.Factory(logRecorder)
private lateinit var client: OkHttpClient
@ -60,11 +61,12 @@ class LoggingEventListenerTest {
@BeforeEach
fun setUp(server: MockWebServer) {
this.server = server
client = clientTestRule.newClientBuilder()
client =
clientTestRule.newClientBuilder()
.eventListenerFactory(loggingEventListenerFactory)
.sslSocketFactory(
handshakeCertificates.sslSocketFactory(),
handshakeCertificates.trustManager
handshakeCertificates.trustManager,
)
.retryOnConnectionFailure(false)
.build()
@ -78,7 +80,7 @@ class LoggingEventListenerTest {
MockResponse.Builder()
.body("Hello!")
.setHeader("Content-Type", PLAIN)
.build()
.build(),
)
val response = client.newCall(request().build()).execute()
assertThat(response.body).isNotNull()
@ -91,7 +93,11 @@ class LoggingEventListenerTest {
.assertLogMatch(Regex("""dnsEnd: \[.+]"""))
.assertLogMatch(Regex("""connectStart: ${url.host}/.+ DIRECT"""))
.assertLogMatch(Regex("""connectEnd: http/1.1"""))
.assertLogMatch(Regex("""connectionAcquired: Connection\{${url.host}:\d+, proxy=DIRECT hostAddress=${url.host}/.+ cipherSuite=none protocol=http/1\.1\}"""))
.assertLogMatch(
Regex(
"""connectionAcquired: Connection\{${url.host}:\d+, proxy=DIRECT hostAddress=${url.host}/.+ cipherSuite=none protocol=http/1\.1\}""",
),
)
.assertLogMatch(Regex("""requestHeadersStart"""))
.assertLogMatch(Regex("""requestHeadersEnd"""))
.assertLogMatch(Regex("""responseHeadersStart"""))
@ -116,7 +122,11 @@ class LoggingEventListenerTest {
.assertLogMatch(Regex("""dnsEnd: \[.+]"""))
.assertLogMatch(Regex("""connectStart: ${url.host}/.+ DIRECT"""))
.assertLogMatch(Regex("""connectEnd: http/1.1"""))
.assertLogMatch(Regex("""connectionAcquired: Connection\{${url.host}:\d+, proxy=DIRECT hostAddress=${url.host}/.+ cipherSuite=none protocol=http/1\.1\}"""))
.assertLogMatch(
Regex(
"""connectionAcquired: Connection\{${url.host}:\d+, proxy=DIRECT hostAddress=${url.host}/.+ cipherSuite=none protocol=http/1\.1\}""",
),
)
.assertLogMatch(Regex("""requestHeadersStart"""))
.assertLogMatch(Regex("""requestHeadersEnd"""))
.assertLogMatch(Regex("""requestBodyStart"""))
@ -148,9 +158,15 @@ class LoggingEventListenerTest {
.assertLogMatch(Regex("""dnsEnd: \[.+]"""))
.assertLogMatch(Regex("""connectStart: ${url.host}/.+ DIRECT"""))
.assertLogMatch(Regex("""secureConnectStart"""))
.assertLogMatch(Regex("""secureConnectEnd: Handshake\{tlsVersion=TLS_1_[23] cipherSuite=TLS_.* peerCertificates=\[CN=localhost] localCertificates=\[]\}"""))
.assertLogMatch(
Regex(
"""secureConnectEnd: Handshake\{tlsVersion=TLS_1_[23] cipherSuite=TLS_.* peerCertificates=\[CN=localhost] localCertificates=\[]\}""",
),
)
.assertLogMatch(Regex("""connectEnd: h2"""))
.assertLogMatch(Regex("""connectionAcquired: Connection\{${url.host}:\d+, proxy=DIRECT hostAddress=${url.host}/.+ cipherSuite=.+ protocol=h2\}"""))
.assertLogMatch(
Regex("""connectionAcquired: Connection\{${url.host}:\d+, proxy=DIRECT hostAddress=${url.host}/.+ cipherSuite=.+ protocol=h2\}"""),
)
.assertLogMatch(Regex("""requestHeadersStart"""))
.assertLogMatch(Regex("""requestHeadersEnd"""))
.assertLogMatch(Regex("""responseHeadersStart"""))
@ -164,7 +180,8 @@ class LoggingEventListenerTest {
@Test
fun dnsFail() {
client = OkHttpClient.Builder()
client =
OkHttpClient.Builder()
.dns { _ -> throw UnknownHostException("reason") }
.eventListenerFactory(loggingEventListenerFactory)
.build()
@ -190,7 +207,7 @@ class LoggingEventListenerTest {
server.enqueue(
MockResponse.Builder()
.socketPolicy(FailHandshake)
.build()
.build(),
)
url = server.url("/")
try {
@ -206,8 +223,16 @@ class LoggingEventListenerTest {
.assertLogMatch(Regex("""dnsEnd: \[.+]"""))
.assertLogMatch(Regex("""connectStart: ${url.host}/.+ DIRECT"""))
.assertLogMatch(Regex("""secureConnectStart"""))
.assertLogMatch(Regex("""connectFailed: null \S+(?:SSLProtocolException|SSLHandshakeException|TlsFatalAlert): (?:Unexpected handshake message: client_hello|Handshake message sequence violation, 1|Read error|Handshake failed|unexpected_message\(10\)).*"""))
.assertLogMatch(Regex("""callFailed: \S+(?:SSLProtocolException|SSLHandshakeException|TlsFatalAlert): (?:Unexpected handshake message: client_hello|Handshake message sequence violation, 1|Read error|Handshake failed|unexpected_message\(10\)).*"""))
.assertLogMatch(
Regex(
"""connectFailed: null \S+(?:SSLProtocolException|SSLHandshakeException|TlsFatalAlert): (?:Unexpected handshake message: client_hello|Handshake message sequence violation, 1|Read error|Handshake failed|unexpected_message\(10\)).*""",
),
)
.assertLogMatch(
Regex(
"""callFailed: \S+(?:SSLProtocolException|SSLHandshakeException|TlsFatalAlert): (?:Unexpected handshake message: client_hello|Handshake message sequence violation, 1|Read error|Handshake failed|unexpected_message\(10\)).*""",
),
)
.assertNoMoreLogs()
}

View File

@ -33,6 +33,9 @@ interface EventSource {
* asynchronous process to connect the socket. Once that succeeds or fails, `listener` will be
* notified. The caller must cancel the returned event source when it is no longer in use.
*/
fun newEventSource(request: Request, listener: EventSourceListener): EventSource
fun newEventSource(
request: Request,
listener: EventSourceListener,
): EventSource
}
}

View File

@ -22,13 +22,21 @@ abstract class EventSourceListener {
* Invoked when an event source has been accepted by the remote peer and may begin transmitting
* events.
*/
open fun onOpen(eventSource: EventSource, response: Response) {
open fun onOpen(
eventSource: EventSource,
response: Response,
) {
}
/**
* TODO description.
*/
open fun onEvent(eventSource: EventSource, id: String?, type: String?, data: String) {
open fun onEvent(
eventSource: EventSource,
id: String?,
type: String?,
data: String,
) {
}
/**
@ -43,6 +51,10 @@ abstract class EventSourceListener {
* Invoked when an event source has been closed due to an error reading from or writing to the
* network. Incoming events may have been lost. No further calls to this listener will be made.
*/
open fun onFailure(eventSource: EventSource, t: Throwable?, response: Response?) {
open fun onFailure(
eventSource: EventSource,
t: Throwable?,
response: Response?,
) {
}
}

View File

@ -45,7 +45,10 @@ object EventSources {
}
@JvmStatic
fun processResponse(response: Response, listener: EventSourceListener) {
fun processResponse(
response: Response,
listener: EventSourceListener,
) {
val eventSource = RealEventSource(response.request, listener)
eventSource.processResponse(response)
}

View File

@ -27,18 +27,23 @@ import okhttp3.sse.EventSourceListener
internal class RealEventSource(
private val request: Request,
private val listener: EventSourceListener
private val listener: EventSourceListener,
) : EventSource, ServerSentEventReader.Callback, Callback {
private var call: Call? = null
@Volatile private var canceled = false
fun connect(callFactory: Call.Factory) {
call = callFactory.newCall(request).apply {
call =
callFactory.newCall(request).apply {
enqueue(this@RealEventSource)
}
}
override fun onResponse(call: Call, response: Response) {
override fun onResponse(
call: Call,
response: Response,
) {
processResponse(response)
}
@ -52,8 +57,11 @@ internal class RealEventSource(
val body = response.body
if (!body.isEventStream()) {
listener.onFailure(this,
IllegalStateException("Invalid content-type: ${body.contentType()}"), response)
listener.onFailure(
this,
IllegalStateException("Invalid content-type: ${body.contentType()}"),
response,
)
return
}
@ -71,7 +79,8 @@ internal class RealEventSource(
}
}
} catch (e: Exception) {
val exception = when {
val exception =
when {
canceled -> IOException("canceled", e)
else -> e
}
@ -91,7 +100,10 @@ internal class RealEventSource(
return contentType.type == "text" && contentType.subtype == "event-stream"
}
override fun onFailure(call: Call, e: IOException) {
override fun onFailure(
call: Call,
e: IOException,
) {
listener.onFailure(this, e, null)
}
@ -102,7 +114,11 @@ internal class RealEventSource(
call?.cancel()
}
override fun onEvent(id: String?, type: String?, data: String) {
override fun onEvent(
id: String?,
type: String?,
data: String,
) {
listener.onEvent(this, id, type, data)
}

View File

@ -24,12 +24,17 @@ import okio.Options
class ServerSentEventReader(
private val source: BufferedSource,
private val callback: Callback
private val callback: Callback,
) {
private var lastId: String? = null
interface Callback {
fun onEvent(id: String?, type: String?, data: String)
fun onEvent(
id: String?,
type: String?,
data: String,
)
fun onRetryChange(timeMs: Long)
}
@ -101,7 +106,11 @@ class ServerSentEventReader(
}
@Throws(IOException::class)
private fun completeEvent(id: String?, type: String?, data: Buffer) {
private fun completeEvent(
id: String?,
type: String?,
data: Buffer,
) {
if (data.size != 0L) {
lastId = id
data.skip(1L) // Leading newline.
@ -110,34 +119,48 @@ class ServerSentEventReader(
}
companion object {
val options = Options.of(
/* 0 */ "\r\n".encodeUtf8(),
/* 1 */ "\r".encodeUtf8(),
/* 2 */ "\n".encodeUtf8(),
/* 3 */ "data: ".encodeUtf8(),
/* 4 */ "data:".encodeUtf8(),
/* 5 */ "data\r\n".encodeUtf8(),
/* 6 */ "data\r".encodeUtf8(),
/* 7 */ "data\n".encodeUtf8(),
/* 8 */ "id: ".encodeUtf8(),
/* 9 */ "id:".encodeUtf8(),
/* 10 */ "id\r\n".encodeUtf8(),
/* 11 */ "id\r".encodeUtf8(),
/* 12 */ "id\n".encodeUtf8(),
/* 13 */ "event: ".encodeUtf8(),
/* 14 */ "event:".encodeUtf8(),
/* 15 */ "event\r\n".encodeUtf8(),
/* 16 */ "event\r".encodeUtf8(),
/* 17 */ "event\n".encodeUtf8(),
/* 18 */ "retry: ".encodeUtf8(),
/* 19 */ "retry:".encodeUtf8()
val options =
Options.of(
// 0
"\r\n".encodeUtf8(),
// 1
"\r".encodeUtf8(),
// 2
"\n".encodeUtf8(),
// 3
"data: ".encodeUtf8(),
// 4
"data:".encodeUtf8(),
// 5
"data\r\n".encodeUtf8(),
// 6
"data\r".encodeUtf8(),
// 7
"data\n".encodeUtf8(),
// 8
"id: ".encodeUtf8(),
// 9
"id:".encodeUtf8(),
// 10
"id\r\n".encodeUtf8(),
// 11
"id\r".encodeUtf8(),
// 12
"id\n".encodeUtf8(),
// 13
"event: ".encodeUtf8(),
// 14
"event:".encodeUtf8(),
// 15
"event\r\n".encodeUtf8(),
// 16
"event\r".encodeUtf8(),
// 17
"event\n".encodeUtf8(),
// 18
"retry: ".encodeUtf8(),
// 19
"retry:".encodeUtf8(),
)
private val CRLF = "\r\n".encodeUtf8()

View File

@ -47,7 +47,8 @@ class EventSourceHttpTest {
val clientTestRule = OkHttpClientTestRule()
private val eventListener = RecordingEventListener()
private val listener = EventSourceRecorder()
private var client = clientTestRule.newClientBuilder()
private var client =
clientTestRule.newClientBuilder()
.eventListenerFactory(clientTestRule.wrap(eventListener))
.build()
@ -70,9 +71,9 @@ class EventSourceHttpTest {
|data: hey
|
|
""".trimMargin()
""".trimMargin(),
).setHeader("content-type", "text/event-stream")
.build()
.build(),
)
val source = newEventSource()
assertThat(source.request().url.encodedPath).isEqualTo("/")
@ -90,9 +91,9 @@ class EventSourceHttpTest {
|data: hey
|
|
""".trimMargin()
""".trimMargin(),
).setHeader("content-type", "text/event-stream")
.build()
.build(),
)
listener.enqueueCancel() // Will cancel in onOpen().
newEventSource()
@ -109,9 +110,9 @@ class EventSourceHttpTest {
|data: hey
|
|
""".trimMargin()
""".trimMargin(),
).setHeader("content-type", "text/plain")
.build()
.build(),
)
newEventSource()
listener.assertFailure("Invalid content-type: text/plain")
@ -126,11 +127,11 @@ class EventSourceHttpTest {
|data: hey
|
|
""".trimMargin()
""".trimMargin(),
)
.setHeader("content-type", "text/event-stream")
.code(401)
.build()
.build(),
)
newEventSource()
listener.assertFailure(null)
@ -138,7 +139,8 @@ class EventSourceHttpTest {
@Test
fun fullCallTimeoutDoesNotApplyOnceConnected() {
client = client.newBuilder()
client =
client.newBuilder()
.callTimeout(250, TimeUnit.MILLISECONDS)
.build()
server.enqueue(
@ -146,7 +148,7 @@ class EventSourceHttpTest {
.bodyDelay(500, TimeUnit.MILLISECONDS)
.setHeader("content-type", "text/event-stream")
.body("data: hey\n\n")
.build()
.build(),
)
val source = newEventSource()
assertThat(source.request().url.encodedPath).isEqualTo("/")
@ -157,7 +159,8 @@ class EventSourceHttpTest {
@Test
fun fullCallTimeoutAppliesToSetup() {
client = client.newBuilder()
client =
client.newBuilder()
.callTimeout(250, TimeUnit.MILLISECONDS)
.build()
server.enqueue(
@ -165,7 +168,7 @@ class EventSourceHttpTest {
.headersDelay(500, TimeUnit.MILLISECONDS)
.setHeader("content-type", "text/event-stream")
.body("data: hey\n\n")
.build()
.build(),
)
newEventSource()
listener.assertFailure("timeout")
@ -180,10 +183,10 @@ class EventSourceHttpTest {
|data: hey
|
|
""".trimMargin()
""".trimMargin(),
)
.setHeader("content-type", "text/event-stream")
.build()
.build(),
)
newEventSource("text/plain")
listener.assertOpen()
@ -201,9 +204,9 @@ class EventSourceHttpTest {
|data: hey
|
|
""".trimMargin()
""".trimMargin(),
).setHeader("content-type", "text/event-stream")
.build()
.build(),
)
newEventSource()
listener.assertOpen()
@ -222,9 +225,9 @@ class EventSourceHttpTest {
|data: hey
|
|
""".trimMargin()
""".trimMargin(),
).setHeader("content-type", "text/event-stream")
.build()
.build(),
)
val source = newEventSource()
assertThat(source.request().url.encodedPath).isEqualTo("/")
@ -247,12 +250,13 @@ class EventSourceHttpTest {
"ResponseBodyStart",
"ResponseBodyEnd",
"ConnectionReleased",
"CallEnd"
"CallEnd",
)
}
private fun newEventSource(accept: String? = null): EventSource {
val builder = Request.Builder()
val builder =
Request.Builder()
.url(server.url("/"))
if (accept != null) {
builder.header("Accept", accept)

View File

@ -35,7 +35,10 @@ class EventSourceRecorder : EventSourceListener() {
cancel = true
}
override fun onOpen(eventSource: EventSource, response: Response) {
override fun onOpen(
eventSource: EventSource,
response: Response,
) {
get().log("[ES] onOpen", Platform.INFO, null)
events.add(Open(eventSource, response))
drainCancelQueue(eventSource)
@ -52,9 +55,7 @@ class EventSourceRecorder : EventSourceListener() {
drainCancelQueue(eventSource)
}
override fun onClosed(
eventSource: EventSource,
) {
override fun onClosed(eventSource: EventSource) {
get().log("[ES] onClosed", Platform.INFO, null)
events.add(Closed)
drainCancelQueue(eventSource)
@ -70,9 +71,7 @@ class EventSourceRecorder : EventSourceListener() {
drainCancelQueue(eventSource)
}
private fun drainCancelQueue(
eventSource: EventSource,
) {
private fun drainCancelQueue(eventSource: EventSource) {
if (cancel) {
cancel = false
eventSource.cancel()

View File

@ -61,11 +61,12 @@ class EventSourcesHttpTest {
|data: hey
|
|
""".trimMargin()
""".trimMargin(),
).setHeader("content-type", "text/event-stream")
.build()
.build(),
)
val request = Request.Builder()
val request =
Request.Builder()
.url(server.url("/"))
.build()
val response = client.newCall(request).execute()
@ -84,12 +85,13 @@ class EventSourcesHttpTest {
|data: hey
|
|
""".trimMargin()
""".trimMargin(),
).setHeader("content-type", "text/event-stream")
.build()
.build(),
)
listener.enqueueCancel() // Will cancel in onOpen().
val request = Request.Builder()
val request =
Request.Builder()
.url(server.url("/"))
.build()
val response = client.newCall(request).execute()

View File

@ -42,7 +42,7 @@ class ServerSentEventIteratorTest {
|data: 10
|
|
""".trimMargin()
""".trimMargin(),
)
assertThat(callbacks.remove()).isEqualTo(Event(null, null, "YHOO\n+2\n10"))
}
@ -56,7 +56,7 @@ class ServerSentEventIteratorTest {
|data: 10
|
|
""".trimMargin().replace("\n", "\r")
""".trimMargin().replace("\n", "\r"),
)
assertThat(callbacks.remove()).isEqualTo(Event(null, null, "YHOO\n+2\n10"))
}
@ -70,7 +70,7 @@ class ServerSentEventIteratorTest {
|data: 10
|
|
""".trimMargin().replace("\n", "\r\n")
""".trimMargin().replace("\n", "\r\n"),
)
assertThat(callbacks.remove()).isEqualTo(Event(null, null, "YHOO\n+2\n10"))
}
@ -89,7 +89,7 @@ class ServerSentEventIteratorTest {
|data: 113411
|
|
""".trimMargin()
""".trimMargin(),
)
assertThat(callbacks.remove()).isEqualTo(Event(null, "add", "73857293"))
assertThat(callbacks.remove()).isEqualTo(Event(null, "remove", "2153"))
@ -106,7 +106,7 @@ class ServerSentEventIteratorTest {
|id: 1
|
|
""".trimMargin()
""".trimMargin(),
)
assertThat(callbacks.remove()).isEqualTo(Event("1", null, "first event"))
}
@ -124,7 +124,7 @@ class ServerSentEventIteratorTest {
|data: third event
|
|
""".trimMargin()
""".trimMargin(),
)
assertThat(callbacks.remove()).isEqualTo(Event("1", null, "first event"))
assertThat(callbacks.remove()).isEqualTo(Event(null, null, "second event"))
@ -142,7 +142,7 @@ class ServerSentEventIteratorTest {
|
|data:
|
""".trimMargin()
""".trimMargin(),
)
assertThat(callbacks.remove()).isEqualTo(Event(null, null, ""))
assertThat(callbacks.remove()).isEqualTo(Event(null, null, "\n"))
@ -157,7 +157,7 @@ class ServerSentEventIteratorTest {
|data: test
|
|
""".trimMargin()
""".trimMargin(),
)
assertThat(callbacks.remove()).isEqualTo(Event(null, null, "test"))
assertThat(callbacks.remove()).isEqualTo(Event(null, null, "test"))
@ -170,7 +170,7 @@ class ServerSentEventIteratorTest {
|data: test
|
|
""".trimMargin()
""".trimMargin(),
)
assertThat(callbacks.remove()).isEqualTo(Event(null, null, " test"))
}
@ -188,7 +188,7 @@ class ServerSentEventIteratorTest {
|data: third event
|
|
""".trimMargin()
""".trimMargin(),
)
assertThat(callbacks.remove()).isEqualTo(Event("1", null, "first event"))
assertThat(callbacks.remove()).isEqualTo(Event("1", null, "second event"))
@ -207,7 +207,7 @@ class ServerSentEventIteratorTest {
|data: second event
|
|
""".trimMargin()
""".trimMargin(),
)
assertThat(callbacks.remove()).isEqualTo(Event("1", null, "first event"))
assertThat(callbacks.remove()).isEqualTo(Event("1", null, "second event"))
@ -223,7 +223,7 @@ class ServerSentEventIteratorTest {
|id: 1
|
|
""".trimMargin()
""".trimMargin(),
)
assertThat(callbacks.remove()).isEqualTo(22L)
assertThat(callbacks.remove()).isEqualTo(Event("1", null, "first event"))
@ -237,7 +237,7 @@ class ServerSentEventIteratorTest {
|
|retry: hey
|
""".trimMargin()
""".trimMargin(),
)
assertThat(callbacks.remove()).isEqualTo(22L)
}
@ -253,7 +253,7 @@ class ServerSentEventIteratorTest {
|retrying
|
|
""".trimMargin()
""".trimMargin(),
)
assertThat(callbacks.remove()).isEqualTo(Event(null, null, "a"))
}
@ -270,7 +270,7 @@ class ServerSentEventIteratorTest {
|data
|
|
""".trimMargin()
""".trimMargin(),
)
assertThat(callbacks.remove()).isEqualTo(Event(null, null, "c\n"))
}
@ -281,14 +281,19 @@ class ServerSentEventIteratorTest {
"""
|retry
|
""".trimMargin()
""".trimMargin(),
)
assertThat(callbacks).isEmpty()
}
private fun consumeEvents(source: String) {
val callback: ServerSentEventReader.Callback = object : ServerSentEventReader.Callback {
override fun onEvent(id: String?, type: String?, data: String) {
val callback: ServerSentEventReader.Callback =
object : ServerSentEventReader.Callback {
override fun onEvent(
id: String?,
type: String?,
data: String,
) {
callbacks.add(Event(id, type, data))
}

View File

@ -14,6 +14,7 @@
* limitations under the License.
*/
@file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")
package okhttp3
import java.io.IOException
@ -37,40 +38,38 @@ sealed class CallEvent {
data class ProxySelectStart(
override val timestampNs: Long,
override val call: Call,
val url: HttpUrl
val url: HttpUrl,
) : CallEvent()
data class ProxySelectEnd(
override val timestampNs: Long,
override val call: Call,
val url: HttpUrl,
val proxies: List<Proxy>?
val proxies: List<Proxy>?,
) : CallEvent() {
override fun closes(event: CallEvent): Boolean =
event is ProxySelectStart && call == event.call && url == event.url
override fun closes(event: CallEvent): Boolean = event is ProxySelectStart && call == event.call && url == event.url
}
data class DnsStart(
override val timestampNs: Long,
override val call: Call,
val domainName: String
val domainName: String,
) : CallEvent()
data class DnsEnd(
override val timestampNs: Long,
override val call: Call,
val domainName: String,
val inetAddressList: List<InetAddress>
val inetAddressList: List<InetAddress>,
) : CallEvent() {
override fun closes(event: CallEvent): Boolean =
event is DnsStart && call == event.call && domainName == event.domainName
override fun closes(event: CallEvent): Boolean = event is DnsStart && call == event.call && domainName == event.domainName
}
data class ConnectStart(
override val timestampNs: Long,
override val call: Call,
val inetSocketAddress: InetSocketAddress,
val proxy: Proxy?
val proxy: Proxy?,
) : CallEvent()
data class ConnectEnd(
@ -78,7 +77,7 @@ sealed class CallEvent {
override val call: Call,
val inetSocketAddress: InetSocketAddress,
val proxy: Proxy?,
val protocol: Protocol?
val protocol: Protocol?,
) : CallEvent() {
override fun closes(event: CallEvent): Boolean =
event is ConnectStart && call == event.call && inetSocketAddress == event.inetSocketAddress && proxy == event.proxy
@ -90,7 +89,7 @@ sealed class CallEvent {
val inetSocketAddress: InetSocketAddress,
val proxy: Proxy,
val protocol: Protocol?,
val ioe: IOException
val ioe: IOException,
) : CallEvent() {
override fun closes(event: CallEvent): Boolean =
event is ConnectStart && call == event.call && inetSocketAddress == event.inetSocketAddress && proxy == event.proxy
@ -98,149 +97,139 @@ sealed class CallEvent {
data class SecureConnectStart(
override val timestampNs: Long,
override val call: Call
override val call: Call,
) : CallEvent()
data class SecureConnectEnd(
override val timestampNs: Long,
override val call: Call,
val handshake: Handshake?
val handshake: Handshake?,
) : CallEvent() {
override fun closes(event: CallEvent): Boolean =
event is SecureConnectStart && call == event.call
override fun closes(event: CallEvent): Boolean = event is SecureConnectStart && call == event.call
}
data class ConnectionAcquired(
override val timestampNs: Long,
override val call: Call,
val connection: Connection
val connection: Connection,
) : CallEvent()
data class ConnectionReleased(
override val timestampNs: Long,
override val call: Call,
val connection: Connection
val connection: Connection,
) : CallEvent() {
override fun closes(event: CallEvent): Boolean =
event is ConnectionAcquired && call == event.call && connection == event.connection
override fun closes(event: CallEvent): Boolean = event is ConnectionAcquired && call == event.call && connection == event.connection
}
data class CallStart(
override val timestampNs: Long,
override val call: Call
override val call: Call,
) : CallEvent()
data class CallEnd(
override val timestampNs: Long,
override val call: Call
override val call: Call,
) : CallEvent() {
override fun closes(event: CallEvent): Boolean =
event is CallStart && call == event.call
override fun closes(event: CallEvent): Boolean = event is CallStart && call == event.call
}
data class CallFailed(
override val timestampNs: Long,
override val call: Call,
val ioe: IOException
val ioe: IOException,
) : CallEvent() {
override fun closes(event: CallEvent): Boolean =
event is CallStart && call == event.call
override fun closes(event: CallEvent): Boolean = event is CallStart && call == event.call
}
data class Canceled(
override val timestampNs: Long,
override val call: Call
override val call: Call,
) : CallEvent()
data class RequestHeadersStart(
override val timestampNs: Long,
override val call: Call
override val call: Call,
) : CallEvent()
data class RequestHeadersEnd(
override val timestampNs: Long,
override val call: Call,
val headerLength: Long
val headerLength: Long,
) : CallEvent() {
override fun closes(event: CallEvent): Boolean =
event is RequestHeadersStart && call == event.call
override fun closes(event: CallEvent): Boolean = event is RequestHeadersStart && call == event.call
}
data class RequestBodyStart(
override val timestampNs: Long,
override val call: Call
override val call: Call,
) : CallEvent()
data class RequestBodyEnd(
override val timestampNs: Long,
override val call: Call,
val bytesWritten: Long
val bytesWritten: Long,
) : CallEvent() {
override fun closes(event: CallEvent): Boolean =
event is RequestBodyStart && call == event.call
override fun closes(event: CallEvent): Boolean = event is RequestBodyStart && call == event.call
}
data class RequestFailed(
override val timestampNs: Long,
override val call: Call,
val ioe: IOException
val ioe: IOException,
) : CallEvent() {
override fun closes(event: CallEvent): Boolean =
event is RequestHeadersStart && call == event.call
override fun closes(event: CallEvent): Boolean = event is RequestHeadersStart && call == event.call
}
data class ResponseHeadersStart(
override val timestampNs: Long,
override val call: Call
override val call: Call,
) : CallEvent()
data class ResponseHeadersEnd(
override val timestampNs: Long,
override val call: Call,
val headerLength: Long
val headerLength: Long,
) : CallEvent() {
override fun closes(event: CallEvent): Boolean =
event is ResponseHeadersStart && call == event.call
override fun closes(event: CallEvent): Boolean = event is ResponseHeadersStart && call == event.call
}
data class ResponseBodyStart(
override val timestampNs: Long,
override val call: Call
override val call: Call,
) : CallEvent()
data class ResponseBodyEnd(
override val timestampNs: Long,
override val call: Call,
val bytesRead: Long
val bytesRead: Long,
) : CallEvent() {
override fun closes(event: CallEvent): Boolean =
event is ResponseBodyStart && call == event.call
override fun closes(event: CallEvent): Boolean = event is ResponseBodyStart && call == event.call
}
data class ResponseFailed(
override val timestampNs: Long,
override val call: Call,
val ioe: IOException
val ioe: IOException,
) : CallEvent()
data class SatisfactionFailure(
override val timestampNs: Long,
override val call: Call
override val call: Call,
) : CallEvent()
data class CacheHit(
override val timestampNs: Long,
override val call: Call
override val call: Call,
) : CallEvent()
data class CacheMiss(
override val timestampNs: Long,
override val call: Call
override val call: Call,
) : CallEvent()
data class CacheConditionalHit(
override val timestampNs: Long,
override val call: Call
override val call: Call,
) : CallEvent()
}

View File

@ -23,7 +23,7 @@ import java.util.concurrent.TimeUnit
class ClientRuleEventListener(
val delegate: EventListener = NONE,
var logger: (String) -> Unit
var logger: (String) -> Unit,
) : EventListener(),
EventListener.Factory {
private var startNs: Long? = null
@ -40,7 +40,7 @@ class ClientRuleEventListener(
override fun proxySelectStart(
call: Call,
url: HttpUrl
url: HttpUrl,
) {
logWithTime("proxySelectStart: $url")
@ -50,7 +50,7 @@ class ClientRuleEventListener(
override fun proxySelectEnd(
call: Call,
url: HttpUrl,
proxies: List<Proxy>
proxies: List<Proxy>,
) {
logWithTime("proxySelectEnd: $proxies")
@ -59,7 +59,7 @@ class ClientRuleEventListener(
override fun dnsStart(
call: Call,
domainName: String
domainName: String,
) {
logWithTime("dnsStart: $domainName")
@ -69,7 +69,7 @@ class ClientRuleEventListener(
override fun dnsEnd(
call: Call,
domainName: String,
inetAddressList: List<InetAddress>
inetAddressList: List<InetAddress>,
) {
logWithTime("dnsEnd: $inetAddressList")
@ -79,7 +79,7 @@ class ClientRuleEventListener(
override fun connectStart(
call: Call,
inetSocketAddress: InetSocketAddress,
proxy: Proxy
proxy: Proxy,
) {
logWithTime("connectStart: $inetSocketAddress $proxy")
@ -94,7 +94,7 @@ class ClientRuleEventListener(
override fun secureConnectEnd(
call: Call,
handshake: Handshake?
handshake: Handshake?,
) {
logWithTime("secureConnectEnd: $handshake")
@ -105,7 +105,7 @@ class ClientRuleEventListener(
call: Call,
inetSocketAddress: InetSocketAddress,
proxy: Proxy,
protocol: Protocol?
protocol: Protocol?,
) {
logWithTime("connectEnd: $protocol")
@ -117,7 +117,7 @@ class ClientRuleEventListener(
inetSocketAddress: InetSocketAddress,
proxy: Proxy,
protocol: Protocol?,
ioe: IOException
ioe: IOException,
) {
logWithTime("connectFailed: $protocol $ioe")
@ -126,7 +126,7 @@ class ClientRuleEventListener(
override fun connectionAcquired(
call: Call,
connection: Connection
connection: Connection,
) {
logWithTime("connectionAcquired: $connection")
@ -135,7 +135,7 @@ class ClientRuleEventListener(
override fun connectionReleased(
call: Call,
connection: Connection
connection: Connection,
) {
logWithTime("connectionReleased")
@ -150,7 +150,7 @@ class ClientRuleEventListener(
override fun requestHeadersEnd(
call: Call,
request: Request
request: Request,
) {
logWithTime("requestHeadersEnd")
@ -165,7 +165,7 @@ class ClientRuleEventListener(
override fun requestBodyEnd(
call: Call,
byteCount: Long
byteCount: Long,
) {
logWithTime("requestBodyEnd: byteCount=$byteCount")
@ -174,7 +174,7 @@ class ClientRuleEventListener(
override fun requestFailed(
call: Call,
ioe: IOException
ioe: IOException,
) {
logWithTime("requestFailed: $ioe")
@ -189,7 +189,7 @@ class ClientRuleEventListener(
override fun responseHeadersEnd(
call: Call,
response: Response
response: Response,
) {
logWithTime("responseHeadersEnd: $response")
@ -204,7 +204,7 @@ class ClientRuleEventListener(
override fun responseBodyEnd(
call: Call,
byteCount: Long
byteCount: Long,
) {
logWithTime("responseBodyEnd: byteCount=$byteCount")
@ -213,7 +213,7 @@ class ClientRuleEventListener(
override fun responseFailed(
call: Call,
ioe: IOException
ioe: IOException,
) {
logWithTime("responseFailed: $ioe")
@ -228,7 +228,7 @@ class ClientRuleEventListener(
override fun callFailed(
call: Call,
ioe: IOException
ioe: IOException,
) {
logWithTime("callFailed: $ioe")
@ -241,7 +241,10 @@ class ClientRuleEventListener(
delegate.canceled(call)
}
override fun satisfactionFailure(call: Call, response: Response) {
override fun satisfactionFailure(
call: Call,
response: Response,
) {
logWithTime("satisfactionFailure")
delegate.satisfactionFailure(call, response)
@ -253,13 +256,19 @@ class ClientRuleEventListener(
delegate.cacheMiss(call)
}
override fun cacheHit(call: Call, response: Response) {
override fun cacheHit(
call: Call,
response: Response,
) {
logWithTime("cacheHit")
delegate.cacheHit(call, response)
}
override fun cacheConditionalHit(call: Call, cachedResponse: Response) {
override fun cacheConditionalHit(
call: Call,
cachedResponse: Response,
) {
logWithTime("cacheConditionalHit")
delegate.cacheConditionalHit(call, cachedResponse)
@ -267,7 +276,8 @@ class ClientRuleEventListener(
private fun logWithTime(message: String) {
val startNs = startNs
val timeMs = if (startNs == null) {
val timeMs =
if (startNs == null) {
// Event occurred before start, for an example an early cancel.
0L
} else {

View File

@ -14,6 +14,7 @@
* limitations under the License.
*/
@file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")
package okhttp3
import java.io.IOException
@ -35,17 +36,16 @@ sealed class ConnectionEvent {
data class ConnectStart(
override val timestampNs: Long,
val route: Route,
val call: Call
val call: Call,
) : ConnectionEvent()
data class ConnectFailed(
override val timestampNs: Long,
val route: Route,
val call: Call,
val exception: IOException
val exception: IOException,
) : ConnectionEvent() {
override fun closes(event: ConnectionEvent): Boolean =
event is ConnectStart && call == event.call && route == event.route
override fun closes(event: ConnectionEvent): Boolean = event is ConnectStart && call == event.call && route == event.route
}
data class ConnectEnd(
@ -54,8 +54,7 @@ sealed class ConnectionEvent {
val route: Route,
val call: Call,
) : ConnectionEvent() {
override fun closes(event: ConnectionEvent): Boolean =
event is ConnectStart && call == event.call && route == event.route
override fun closes(event: ConnectionEvent): Boolean = event is ConnectStart && call == event.call && route == event.route
}
data class ConnectionClosed(
@ -66,15 +65,14 @@ sealed class ConnectionEvent {
data class ConnectionAcquired(
override val timestampNs: Long,
override val connection: Connection,
val call: Call
val call: Call,
) : ConnectionEvent()
data class ConnectionReleased(
override val timestampNs: Long,
override val connection: Connection,
val call: Call
val call: Call,
) : ConnectionEvent() {
override fun closes(event: ConnectionEvent): Boolean =
event is ConnectionAcquired && connection == event.connection && call == event.call
}

View File

@ -24,7 +24,6 @@ import javax.security.cert.X509Certificate
/** An [SSLSession] that delegates all calls. */
abstract class DelegatingSSLSession(protected val delegate: SSLSession?) : SSLSession {
override fun getId(): ByteArray {
return delegate!!.id
}
@ -49,7 +48,10 @@ abstract class DelegatingSSLSession(protected val delegate: SSLSession?) : SSLSe
return delegate!!.isValid
}
override fun putValue(s: String, o: Any) {
override fun putValue(
s: String,
o: Any,
) {
delegate!!.putValue(s, o)
}

View File

@ -198,7 +198,10 @@ abstract class DelegatingSSLSocket(protected val delegate: SSLSocket?) : SSLSock
}
@Throws(SocketException::class)
override fun setSoLinger(on: Boolean, timeout: Int) {
override fun setSoLinger(
on: Boolean,
timeout: Int,
) {
delegate!!.setSoLinger(on, timeout)
}
@ -247,7 +250,10 @@ abstract class DelegatingSSLSocket(protected val delegate: SSLSocket?) : SSLSock
}
@Throws(IOException::class)
override fun connect(remoteAddr: SocketAddress, timeout: Int) {
override fun connect(
remoteAddr: SocketAddress,
timeout: Int,
) {
delegate!!.connect(remoteAddr, timeout)
}

View File

@ -33,28 +33,40 @@ open class DelegatingSSLSocketFactory(private val delegate: SSLSocketFactory) :
}
@Throws(IOException::class)
override fun createSocket(host: String, port: Int): SSLSocket {
override fun createSocket(
host: String,
port: Int,
): SSLSocket {
val sslSocket = delegate.createSocket(host, port) as SSLSocket
return configureSocket(sslSocket)
}
@Throws(IOException::class)
override fun createSocket(
host: String, port: Int, localAddress: InetAddress, localPort: Int
host: String,
port: Int,
localAddress: InetAddress,
localPort: Int,
): SSLSocket {
val sslSocket = delegate.createSocket(host, port, localAddress, localPort) as SSLSocket
return configureSocket(sslSocket)
}
@Throws(IOException::class)
override fun createSocket(host: InetAddress, port: Int): SSLSocket {
override fun createSocket(
host: InetAddress,
port: Int,
): SSLSocket {
val sslSocket = delegate.createSocket(host, port) as SSLSocket
return configureSocket(sslSocket)
}
@Throws(IOException::class)
override fun createSocket(
host: InetAddress, port: Int, localAddress: InetAddress, localPort: Int
host: InetAddress,
port: Int,
localAddress: InetAddress,
localPort: Int,
): SSLSocket {
val sslSocket = delegate.createSocket(host, port, localAddress, localPort) as SSLSocket
return configureSocket(sslSocket)
@ -70,7 +82,10 @@ open class DelegatingSSLSocketFactory(private val delegate: SSLSocketFactory) :
@Throws(IOException::class)
override fun createSocket(
socket: Socket, host: String, port: Int, autoClose: Boolean
socket: Socket,
host: String,
port: Int,
autoClose: Boolean,
): SSLSocket {
val sslSocket = delegate.createSocket(socket, host, port, autoClose) as SSLSocket
return configureSocket(sslSocket)

View File

@ -38,13 +38,20 @@ open class DelegatingServerSocketFactory(private val delegate: ServerSocketFacto
}
@Throws(IOException::class)
override fun createServerSocket(port: Int, backlog: Int): ServerSocket {
override fun createServerSocket(
port: Int,
backlog: Int,
): ServerSocket {
val serverSocket = delegate.createServerSocket(port, backlog)
return configureServerSocket(serverSocket)
}
@Throws(IOException::class)
override fun createServerSocket(port: Int, backlog: Int, ifAddress: InetAddress): ServerSocket {
override fun createServerSocket(
port: Int,
backlog: Int,
ifAddress: InetAddress,
): ServerSocket {
val serverSocket = delegate.createServerSocket(port, backlog, ifAddress)
return configureServerSocket(serverSocket)
}

View File

@ -32,30 +32,40 @@ open class DelegatingSocketFactory(private val delegate: SocketFactory) : Socket
}
@Throws(IOException::class)
override fun createSocket(host: String, port: Int): Socket {
override fun createSocket(
host: String,
port: Int,
): Socket {
val socket = delegate.createSocket(host, port)
return configureSocket(socket)
}
@Throws(IOException::class)
override fun createSocket(
host: String, port: Int, localAddress: InetAddress,
localPort: Int
host: String,
port: Int,
localAddress: InetAddress,
localPort: Int,
): Socket {
val socket = delegate.createSocket(host, port, localAddress, localPort)
return configureSocket(socket)
}
@Throws(IOException::class)
override fun createSocket(host: InetAddress, port: Int): Socket {
override fun createSocket(
host: InetAddress,
port: Int,
): Socket {
val socket = delegate.createSocket(host, port)
return configureSocket(socket)
}
@Throws(IOException::class)
override fun createSocket(
host: InetAddress, port: Int, localAddress: InetAddress,
localPort: Int
host: InetAddress,
port: Int,
localAddress: InetAddress,
localPort: Int,
): Socket {
val socket = delegate.createSocket(host, port, localAddress, localPort)
return configureSocket(socket)

View File

@ -29,7 +29,7 @@ class FakeDns : Dns {
/** Sets the results for `hostname`. */
operator fun set(
hostname: String,
addresses: List<InetAddress>
addresses: List<InetAddress>,
): FakeDns {
hostAddresses[hostname] = addresses
return this
@ -41,9 +41,10 @@ class FakeDns : Dns {
return this
}
@Throws(UnknownHostException::class) fun lookup(
@Throws(UnknownHostException::class)
fun lookup(
hostname: String,
index: Int
index: Int,
): InetAddress {
return hostAddresses[hostname]!![index]
}
@ -66,7 +67,7 @@ class FakeDns : Dns {
return (from until nextAddress)
.map {
return@map InetAddress.getByAddress(
Buffer().writeInt(it.toInt()).readByteArray()
Buffer().writeInt(it.toInt()).readByteArray(),
)
}
}
@ -78,7 +79,7 @@ class FakeDns : Dns {
return (from until nextAddress)
.map {
return@map InetAddress.getByAddress(
Buffer().writeLong(0L).writeLong(it).readByteArray()
Buffer().writeLong(0L).writeLong(it).readByteArray(),
)
}
}

View File

@ -37,7 +37,7 @@ class FakeProxySelector : ProxySelector() {
override fun connectFailed(
uri: URI,
sa: SocketAddress,
ioe: IOException
ioe: IOException,
) {
}
}

View File

@ -67,7 +67,7 @@ class FakeSSLSession(vararg val certificates: Certificate) : SSLSession {
}
@Throws(
SSLPeerUnverifiedException::class
SSLPeerUnverifiedException::class,
)
override fun getPeerCertificateChain(): Array<X509Certificate> {
throw UnsupportedOperationException()
@ -96,7 +96,7 @@ class FakeSSLSession(vararg val certificates: Certificate) : SSLSession {
override fun putValue(
s: String,
obj: Any
obj: Any,
) {
throw UnsupportedOperationException()
}

View File

@ -20,6 +20,7 @@ import okio.BufferedSink
open class ForwardingRequestBody(delegate: RequestBody?) : RequestBody() {
private val delegate: RequestBody
fun delegate(): RequestBody {
return delegate
}
@ -28,7 +29,8 @@ open class ForwardingRequestBody(delegate: RequestBody?) : RequestBody() {
return delegate.contentType()
}
@Throws(IOException::class) override fun contentLength(): Long {
@Throws(IOException::class)
override fun contentLength(): Long {
return delegate.contentLength()
}

View File

@ -19,6 +19,7 @@ import okio.BufferedSource
open class ForwardingResponseBody(delegate: ResponseBody?) : ResponseBody() {
private val delegate: ResponseBody
fun delegate(): ResponseBody {
return delegate
}

View File

@ -22,11 +22,16 @@ import java.util.logging.LogRecord
object JsseDebugLogging {
data class JsseDebugMessage(val message: String, val param: String?) {
enum class Type {
Handshake, Plaintext, Encrypted, Setup, Unknown
Handshake,
Plaintext,
Encrypted,
Setup,
Unknown,
}
val type: Type
get() = when {
get() =
when {
message == "adding as trusted certificates" -> Type.Setup
message == "Raw read" || message == "Raw write" -> Type.Encrypted
message == "Plaintext before ENCRYPTION" || message == "Plaintext after DECRYPTION" -> Type.Plaintext
@ -66,7 +71,9 @@ object JsseDebugLogging {
fun enableJsseDebugLogging(debugHandler: (JsseDebugMessage) -> Unit = this::quietDebug): Closeable {
System.setProperty("javax.net.debug", "")
return OkHttpDebugLogging.enable("javax.net.ssl", object : Handler() {
return OkHttpDebugLogging.enable(
"javax.net.ssl",
object : Handler() {
override fun publish(record: LogRecord) {
val param = record.parameters?.firstOrNull() as? String
debugHandler(JsseDebugMessage(record.message, param))
@ -77,6 +84,7 @@ object JsseDebugLogging {
override fun close() {
}
})
},
)
}
}

View File

@ -14,6 +14,7 @@
* limitations under the License.
*/
@file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")
package okhttp3
import android.annotation.SuppressLint
@ -62,7 +63,8 @@ class OkHttpClientTestRule : BeforeEachCallback, AfterEachCallback {
var recordFrames = false
var recordSslDebug = false
private val sslExcludeFilter = Regex(
private val sslExcludeFilter =
Regex(
buildString {
append("^(?:")
append(
@ -75,15 +77,17 @@ class OkHttpClientTestRule : BeforeEachCallback, AfterEachCallback {
"adding as trusted certificates",
"Ignore disabled cipher suite",
"Ignore unsupported cipher suite",
).joinToString(separator = "|")
).joinToString(separator = "|"),
)
append(").*")
}
},
)
private val testLogHandler = object : Handler() {
private val testLogHandler =
object : Handler() {
override fun publish(record: LogRecord) {
val recorded = when (record.loggerName) {
val recorded =
when (record.loggerName) {
TaskRunner::class.java.name -> recordTaskRunner
Http2::class.java.name -> recordFrames
"javax.net.ssl" -> recordSslDebug && !sslExcludeFilter.matches(record.message)
@ -122,8 +126,7 @@ class OkHttpClientTestRule : BeforeEachCallback, AfterEachCallback {
Logger.getLogger("javax.net.ssl").fn()
}
fun wrap(eventListener: EventListener) =
EventListener.Factory { ClientRuleEventListener(eventListener, ::addEvent) }
fun wrap(eventListener: EventListener) = EventListener.Factory { ClientRuleEventListener(eventListener, ::addEvent) }
fun wrap(eventListenerFactory: EventListener.Factory) =
EventListener.Factory { call -> ClientRuleEventListener(eventListenerFactory.create(call), ::addEvent) }
@ -140,7 +143,8 @@ class OkHttpClientTestRule : BeforeEachCallback, AfterEachCallback {
fun newClient(): OkHttpClient {
var client = testClient
if (client == null) {
client = initialClientBuilder()
client =
initialClientBuilder()
.dns(SINGLE_INET_ADDRESS_DNS) // Prevent unexpected fallback addresses.
.eventListenerFactory { ClientRuleEventListener(logger = ::addEvent) }
.build()
@ -151,7 +155,8 @@ class OkHttpClientTestRule : BeforeEachCallback, AfterEachCallback {
return client
}
private fun initialClientBuilder(): OkHttpClient.Builder = if (isLoom()) {
private fun initialClientBuilder(): OkHttpClient.Builder =
if (isLoom()) {
val backend = TaskRunner.RealBackend(loomThreadFactory())
val taskRunner = TaskRunner(backend)
@ -160,7 +165,7 @@ class OkHttpClientTestRule : BeforeEachCallback, AfterEachCallback {
buildConnectionPool(
connectionListener = connectionListener,
taskRunner = taskRunner,
)
),
)
.dispatcher(Dispatcher(backend.executor))
.taskRunnerInternal(taskRunner)
@ -322,7 +327,8 @@ class OkHttpClientTestRule : BeforeEachCallback, AfterEachCallback {
* A network that resolves only one IP address per host. Use this when testing route selection
* fallbacks to prevent the host machine's various IP addresses from interfering.
*/
private val SINGLE_INET_ADDRESS_DNS = Dns { hostname ->
private val SINGLE_INET_ADDRESS_DNS =
Dns { hostname ->
val addresses = Dns.SYSTEM.lookup(hostname)
listOf(addresses[0])
}

Some files were not shown because too many files have changed in this diff Show More