1
0
mirror of https://github.com/square/okhttp.git synced 2025-07-31 05:04:26 +03:00

Update dependency com.diffplug.spotless:spotless-plugin-gradle to v7 (#8702)

* Update dependency com.diffplug.spotless:spotless-plugin-gradle to v7

* Reformat

---------

Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com>
Co-authored-by: Jake Wharton <jw@squareup.com>
This commit is contained in:
renovate[bot]
2025-03-19 15:25:20 -04:00
committed by GitHub
parent c4d472cab7
commit a51cfbf841
304 changed files with 6747 additions and 4401 deletions

View File

@ -58,10 +58,12 @@ class AndroidAsyncDnsTest {
private val localhost: HandshakeCertificates by lazy { private val localhost: HandshakeCertificates by lazy {
// Generate a self-signed cert for the server to serve and the client to trust. // Generate a self-signed cert for the server to serve and the client to trust.
val heldCertificate = val heldCertificate =
HeldCertificate.Builder() HeldCertificate
.Builder()
.addSubjectAlternativeName("localhost") .addSubjectAlternativeName("localhost")
.build() .build()
return@lazy HandshakeCertificates.Builder() return@lazy HandshakeCertificates
.Builder()
.addPlatformTrustedCertificates() .addPlatformTrustedCertificates()
.heldCertificate(heldCertificate) .heldCertificate(heldCertificate)
.addTrustedCertificate(heldCertificate.certificate) .addTrustedCertificate(heldCertificate.certificate)
@ -73,7 +75,8 @@ class AndroidAsyncDnsTest {
assumeTrue("Supported on API 29+", Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) assumeTrue("Supported on API 29+", Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q)
client = client =
OkHttpClient.Builder() OkHttpClient
.Builder()
.dns(AsyncDns.toDns(AndroidAsyncDns.IPv4, AndroidAsyncDns.IPv6)) .dns(AsyncDns.toDns(AndroidAsyncDns.IPv4, AndroidAsyncDns.IPv6))
.sslSocketFactory(localhost.sslSocketFactory(), localhost.trustManager) .sslSocketFactory(localhost.sslSocketFactory(), localhost.trustManager)
.build() .build()
@ -187,7 +190,8 @@ class AndroidAsyncDnsTest {
connectivityManager.activeNetwork ?: throw AssumptionViolatedException("No active network") connectivityManager.activeNetwork ?: throw AssumptionViolatedException("No active network")
val client = val client =
OkHttpClient.Builder() OkHttpClient
.Builder()
.dns(AsyncDns.toDns(AndroidAsyncDns.IPv4, AndroidAsyncDns.IPv6)) .dns(AsyncDns.toDns(AndroidAsyncDns.IPv4, AndroidAsyncDns.IPv6))
.socketFactory(network.socketFactory) .socketFactory(network.socketFactory)
.build() .build()

View File

@ -110,7 +110,8 @@ class OkHttpTest {
private var client: OkHttpClient = clientTestRule.newClient() private var client: OkHttpClient = clientTestRule.newClient()
private val moshi = private val moshi =
Moshi.Builder() Moshi
.Builder()
.add(KotlinJsonAdapterFactory()) .add(KotlinJsonAdapterFactory())
.build() .build()
@ -144,17 +145,18 @@ class OkHttpTest {
val request = Request.Builder().url("https://api.twitter.com/robots.txt").build() val request = Request.Builder().url("https://api.twitter.com/robots.txt").build()
val clientCertificates = val clientCertificates =
HandshakeCertificates.Builder() HandshakeCertificates
.Builder()
.addPlatformTrustedCertificates() .addPlatformTrustedCertificates()
.apply { .apply {
if (Build.VERSION.SDK_INT >= 24) { if (Build.VERSION.SDK_INT >= 24) {
addInsecureHost(server.hostName) addInsecureHost(server.hostName)
} }
} }.build()
.build()
client = client =
client.newBuilder() client
.newBuilder()
.sslSocketFactory(clientCertificates.sslSocketFactory(), clientCertificates.trustManager) .sslSocketFactory(clientCertificates.sslSocketFactory(), clientCertificates.trustManager)
.build() .build()
@ -194,14 +196,16 @@ class OkHttpTest {
var socketClass: String? = null var socketClass: String? = null
val clientCertificates = val clientCertificates =
HandshakeCertificates.Builder() HandshakeCertificates
.Builder()
.addPlatformTrustedCertificates() .addPlatformTrustedCertificates()
.addInsecureHost(server.hostName) .addInsecureHost(server.hostName)
.build() .build()
// Need fresh client to reset sslSocketFactoryOrNull // Need fresh client to reset sslSocketFactoryOrNull
client = client =
OkHttpClient.Builder() OkHttpClient
.Builder()
.eventListenerFactory( .eventListenerFactory(
clientTestRule.wrap( clientTestRule.wrap(
object : EventListener() { object : EventListener() {
@ -213,8 +217,7 @@ class OkHttpTest {
} }
}, },
), ),
) ).sslSocketFactory(clientCertificates.sslSocketFactory(), clientCertificates.trustManager)
.sslSocketFactory(clientCertificates.sslSocketFactory(), clientCertificates.trustManager)
.build() .build()
val response = client.newCall(request).execute() val response = client.newCall(request).execute()
@ -257,7 +260,8 @@ class OkHttpTest {
} }
val clientCertificates = val clientCertificates =
HandshakeCertificates.Builder() HandshakeCertificates
.Builder()
.addPlatformTrustedCertificates() .addPlatformTrustedCertificates()
.addInsecureHost(server.hostName) .addInsecureHost(server.hostName)
.build() .build()
@ -268,7 +272,8 @@ class OkHttpTest {
// Need fresh client to reset sslSocketFactoryOrNull // Need fresh client to reset sslSocketFactoryOrNull
client = client =
OkHttpClient.Builder() OkHttpClient
.Builder()
.eventListenerFactory( .eventListenerFactory(
clientTestRule.wrap( clientTestRule.wrap(
object : EventListener() { object : EventListener() {
@ -280,8 +285,7 @@ class OkHttpTest {
} }
}, },
), ),
) ).sslSocketFactory(clientCertificates.sslSocketFactory(), clientCertificates.trustManager)
.sslSocketFactory(clientCertificates.sslSocketFactory(), clientCertificates.trustManager)
.build() .build()
val response = client.newCall(request).execute() val response = client.newCall(request).execute()
@ -322,16 +326,18 @@ class OkHttpTest {
var socketClass: String? = null var socketClass: String? = null
val clientCertificates = val clientCertificates =
HandshakeCertificates.Builder() HandshakeCertificates
.addPlatformTrustedCertificates().apply { .Builder()
.addPlatformTrustedCertificates()
.apply {
if (Build.VERSION.SDK_INT >= 24) { if (Build.VERSION.SDK_INT >= 24) {
addInsecureHost(server.hostName) addInsecureHost(server.hostName)
} }
} }.build()
.build()
client = client =
client.newBuilder() client
.newBuilder()
.eventListenerFactory( .eventListenerFactory(
clientTestRule.wrap( clientTestRule.wrap(
object : EventListener() { object : EventListener() {
@ -343,8 +349,7 @@ class OkHttpTest {
} }
}, },
), ),
) ).sslSocketFactory(clientCertificates.sslSocketFactory(), clientCertificates.trustManager)
.sslSocketFactory(clientCertificates.sslSocketFactory(), clientCertificates.trustManager)
.build() .build()
val response = client.newCall(request).execute() val response = client.newCall(request).execute()
@ -458,7 +463,8 @@ class OkHttpTest {
enableTls() enableTls()
val certificatePinner = val certificatePinner =
CertificatePinner.Builder() CertificatePinner
.Builder()
.add(server.hostName, "sha256/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=") .add(server.hostName, "sha256/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=")
.build() .build()
client = client.newBuilder().certificatePinner(certificatePinner).build() client = client.newBuilder().certificatePinner(certificatePinner).build()
@ -479,12 +485,12 @@ class OkHttpTest {
enableTls() enableTls()
val certificatePinner = val certificatePinner =
CertificatePinner.Builder() CertificatePinner
.Builder()
.add( .add(
server.hostName, server.hostName,
CertificatePinner.pin(handshakeCertificates.trustManager.acceptedIssuers[0]), CertificatePinner.pin(handshakeCertificates.trustManager.acceptedIssuers[0]),
) ).build()
.build()
client = client.newBuilder().certificatePinner(certificatePinner).build() client = client.newBuilder().certificatePinner(certificatePinner).build()
server.enqueue(MockResponse(body = "abc")) server.enqueue(MockResponse(body = "abc"))
@ -517,10 +523,23 @@ class OkHttpTest {
assertEquals( assertEquals(
listOf( listOf(
"CallStart", "ProxySelectStart", "ProxySelectEnd", "DnsStart", "DnsEnd", "CallStart",
"ConnectStart", "SecureConnectStart", "SecureConnectEnd", "ConnectEnd", "ProxySelectStart",
"ConnectionAcquired", "RequestHeadersStart", "RequestHeadersEnd", "ResponseHeadersStart", "ProxySelectEnd",
"ResponseHeadersEnd", "ResponseBodyStart", "ResponseBodyEnd", "ConnectionReleased", "DnsStart",
"DnsEnd",
"ConnectStart",
"SecureConnectStart",
"SecureConnectEnd",
"ConnectEnd",
"ConnectionAcquired",
"RequestHeadersStart",
"RequestHeadersEnd",
"ResponseHeadersStart",
"ResponseHeadersEnd",
"ResponseBodyStart",
"ResponseBodyEnd",
"ConnectionReleased",
"CallEnd", "CallEnd",
), ),
eventListener.recordedEventTypes(), eventListener.recordedEventTypes(),
@ -535,8 +554,14 @@ class OkHttpTest {
assertEquals( assertEquals(
listOf( listOf(
"CallStart", "CallStart",
"ConnectionAcquired", "RequestHeadersStart", "RequestHeadersEnd", "ResponseHeadersStart", "ConnectionAcquired",
"ResponseHeadersEnd", "ResponseBodyStart", "ResponseBodyEnd", "ConnectionReleased", "RequestHeadersStart",
"RequestHeadersEnd",
"ResponseHeadersStart",
"ResponseHeadersEnd",
"ResponseBodyStart",
"ResponseBodyEnd",
"ConnectionReleased",
"CallEnd", "CallEnd",
), ),
eventListener.recordedEventTypes(), eventListener.recordedEventTypes(),
@ -550,20 +575,26 @@ class OkHttpTest {
enableTls() enableTls()
client = client =
client.newBuilder().eventListenerFactory( client
clientTestRule.wrap( .newBuilder()
object : EventListener() { .eventListenerFactory(
override fun connectionAcquired( clientTestRule.wrap(
call: Call, object : EventListener() {
connection: Connection, override fun connectionAcquired(
) { call: Call,
val sslSocket = connection.socket() as SSLSocket connection: Connection,
) {
val sslSocket = connection.socket() as SSLSocket
sessionIds.add(sslSocket.session.id.toByteString().hex()) sessionIds.add(
} sslSocket.session.id
}, .toByteString()
), .hex(),
).build() )
}
},
),
).build()
server.enqueue(MockResponse(body = "abc1")) server.enqueue(MockResponse(body = "abc1"))
server.enqueue(MockResponse(body = "abc2")) server.enqueue(MockResponse(body = "abc2"))
@ -590,13 +621,18 @@ class OkHttpTest {
assumeNetwork() assumeNetwork()
client = client =
client.newBuilder() client
.newBuilder()
.eventListenerFactory(clientTestRule.wrap(LoggingEventListener.Factory())) .eventListenerFactory(clientTestRule.wrap(LoggingEventListener.Factory()))
.build() .build()
val dohDns = buildCloudflareIp(client) val dohDns = buildCloudflareIp(client)
val dohEnabledClient = val dohEnabledClient =
client.newBuilder().eventListener(EventListener.NONE).dns(dohDns).build() client
.newBuilder()
.eventListener(EventListener.NONE)
.dns(dohDns)
.build()
dohEnabledClient.get("https://www.twitter.com/robots.txt") dohEnabledClient.get("https://www.twitter.com/robots.txt")
dohEnabledClient.get("https://www.facebook.com/robots.txt") dohEnabledClient.get("https://www.facebook.com/robots.txt")
@ -630,7 +666,8 @@ class OkHttpTest {
val hostnameVerifier = HostnameVerifier { _, _ -> true } val hostnameVerifier = HostnameVerifier { _, _ -> true }
client = client =
client.newBuilder() client
.newBuilder()
.sslSocketFactory(sslSocketFactory, trustManager) .sslSocketFactory(sslSocketFactory, trustManager)
.hostnameVerifier(hostnameVerifier) .hostnameVerifier(hostnameVerifier)
.build() .build()
@ -649,17 +686,15 @@ class OkHttpTest {
val delegatingSocketFactory = val delegatingSocketFactory =
object : DelegatingSSLSocketFactory(sslSocketFactory) { object : DelegatingSSLSocketFactory(sslSocketFactory) {
override fun configureSocket(sslSocket: SSLSocket): SSLSocket { override fun configureSocket(sslSocket: SSLSocket): SSLSocket =
return object : DelegatingSSLSocket(sslSocket) { object : DelegatingSSLSocket(sslSocket) {
override fun getApplicationProtocol(): String { override fun getApplicationProtocol(): String = throw UnsupportedOperationException()
throw UnsupportedOperationException()
}
} }
}
} }
client = client =
client.newBuilder() client
.newBuilder()
.sslSocketFactory(delegatingSocketFactory, trustManager) .sslSocketFactory(delegatingSocketFactory, trustManager)
.build() .build()
@ -714,7 +749,8 @@ class OkHttpTest {
val hostnameVerifier = HostnameVerifier { _, _ -> true } val hostnameVerifier = HostnameVerifier { _, _ -> true }
client = client =
client.newBuilder() client
.newBuilder()
.sslSocketFactory(sslSocketFactory, trustManager) .sslSocketFactory(sslSocketFactory, trustManager)
.hostnameVerifier(hostnameVerifier) .hostnameVerifier(hostnameVerifier)
.build() .build()
@ -748,7 +784,13 @@ class OkHttpTest {
} }
is CertificateException -> { is CertificateException -> {
assertTrue(ioe.cause?.cause is IllegalArgumentException) assertTrue(ioe.cause?.cause is IllegalArgumentException)
assertEquals(true, ioe.cause?.cause?.message?.startsWith("Invalid input to toASCII")) assertEquals(
true,
ioe.cause
?.cause
?.message
?.startsWith("Invalid input to toASCII"),
)
} }
else -> throw ioe else -> throw ioe
} }
@ -767,9 +809,12 @@ class OkHttpTest {
var socketClass: String? = null var socketClass: String? = null
val trustManager = val trustManager =
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()).apply { TrustManagerFactory
init(null as KeyStore?) .getInstance(TrustManagerFactory.getDefaultAlgorithm())
}.trustManagers.first() as X509TrustManager .apply {
init(null as KeyStore?)
}.trustManagers
.first() as X509TrustManager
val sslContext = val sslContext =
Platform.get().newSSLContext().apply { Platform.get().newSSLContext().apply {
@ -778,7 +823,8 @@ class OkHttpTest {
} }
client = client =
client.newBuilder() client
.newBuilder()
.sslSocketFactory(sslContext.socketFactory, trustManager) .sslSocketFactory(sslContext.socketFactory, trustManager)
.eventListenerFactory( .eventListenerFactory(
clientTestRule.wrap( clientTestRule.wrap(
@ -791,8 +837,7 @@ class OkHttpTest {
} }
}, },
), ),
) ).build()
.build()
val request = Request.Builder().url("https://facebook.com/robots.txt").build() val request = Request.Builder().url("https://facebook.com/robots.txt").build()
@ -820,7 +865,8 @@ class OkHttpTest {
val calls = mutableMapOf<String, AtomicInteger>() val calls = mutableMapOf<String, AtomicInteger>()
override fun publish(record: LogRecord) { override fun publish(record: LogRecord) {
calls.getOrPut(record.loggerName) { AtomicInteger(0) } calls
.getOrPut(record.loggerName) { AtomicInteger(0) }
.incrementAndGet() .incrementAndGet()
} }
@ -833,26 +879,33 @@ class OkHttpTest {
level = Level.FINEST level = Level.FINEST
} }
Logger.getLogger("") Logger
.getLogger("")
.addHandler(testHandler) .addHandler(testHandler)
Logger.getLogger("okhttp3") Logger
.getLogger("okhttp3")
.addHandler(testHandler) .addHandler(testHandler)
Logger.getLogger(Http2::class.java.name) Logger
.getLogger(Http2::class.java.name)
.addHandler(testHandler) .addHandler(testHandler)
Logger.getLogger(TaskRunner::class.java.name) Logger
.getLogger(TaskRunner::class.java.name)
.addHandler(testHandler) .addHandler(testHandler)
Logger.getLogger(OkHttpClient::class.java.name) Logger
.getLogger(OkHttpClient::class.java.name)
.addHandler(testHandler) .addHandler(testHandler)
server.enqueue(MockResponse(body = "abc")) server.enqueue(MockResponse(body = "abc"))
val request = val request =
Request.Builder() Request
.Builder()
.url(server.url("/")) .url(server.url("/"))
.build() .build()
val response = val response =
client.newCall(request) client
.newCall(request)
.execute() .execute()
response.use { response.use {
@ -878,16 +931,19 @@ class OkHttpTest {
try { try {
client = client =
client.newBuilder() client
.newBuilder()
.cache(cache) .cache(cache)
.build() .build()
val request = val request =
Request.Builder() Request
.Builder()
.url(server.url("/")) .url(server.url("/"))
.build() .build()
client.newCall(request) client
.newCall(request)
.execute() .execute()
.use { .use {
assertEquals(200, it.code) assertEquals(200, it.code)
@ -898,7 +954,8 @@ class OkHttpTest {
assertTrue(it.cacheControl.isPublic) assertTrue(it.cacheControl.isPublic)
} }
client.newCall(request) client
.newCall(request)
.execute() .execute()
.use { .use {
assertEquals(200, it.code) assertEquals(200, it.code)
@ -923,20 +980,21 @@ class OkHttpTest {
} }
} }
fun buildCloudflareIp(bootstrapClient: OkHttpClient): DnsOverHttps { fun buildCloudflareIp(bootstrapClient: OkHttpClient): DnsOverHttps =
return DnsOverHttps.Builder().client(bootstrapClient) DnsOverHttps
.Builder()
.client(bootstrapClient)
.url("https://1.1.1.1/dns-query".toHttpUrl()) .url("https://1.1.1.1/dns-query".toHttpUrl())
.build() .build()
}
private fun enableTls() { private fun enableTls() {
client = client =
client.newBuilder() client
.newBuilder()
.sslSocketFactory( .sslSocketFactory(
handshakeCertificates.sslSocketFactory(), handshakeCertificates.sslSocketFactory(),
handshakeCertificates.trustManager, handshakeCertificates.trustManager,
) ).build()
.build()
server.useHttps(handshakeCertificates.sslSocketFactory()) server.useHttps(handshakeCertificates.sslSocketFactory())
} }

View File

@ -24,7 +24,8 @@ class StrictModeTest {
@After @After
fun cleanup() { fun cleanup() {
StrictMode.setThreadPolicy( StrictMode.setThreadPolicy(
ThreadPolicy.Builder() ThreadPolicy
.Builder()
.permitAll() .permitAll()
.build(), .build(),
) )
@ -60,12 +61,12 @@ class StrictModeTest {
private fun applyStrictMode() { private fun applyStrictMode() {
StrictMode.setThreadPolicy( StrictMode.setThreadPolicy(
ThreadPolicy.Builder() ThreadPolicy
.Builder()
.detectCustomSlowCalls() .detectCustomSlowCalls()
.penaltyListener({ it.run() }) { .penaltyListener({ it.run() }) {
violations.add(it) violations.add(it)
} }.build(),
.build(),
) )
} }
} }

View File

@ -56,16 +56,17 @@ class AlpnOverrideTest {
@Test @Test
fun getWithCustomSocketFactory() { fun getWithCustomSocketFactory() {
client = client =
client.newBuilder() client
.newBuilder()
.sslSocketFactory(CustomSSLSocketFactory(client.sslSocketFactory), client.x509TrustManager!!) .sslSocketFactory(CustomSSLSocketFactory(client.sslSocketFactory), client.x509TrustManager!!)
.connectionSpecs( .connectionSpecs(
listOf( listOf(
ConnectionSpec.Builder(ConnectionSpec.MODERN_TLS) ConnectionSpec
.Builder(ConnectionSpec.MODERN_TLS)
.supportsTlsExtensions(false) .supportsTlsExtensions(false)
.build(), .build(),
), ),
) ).eventListener(
.eventListener(
object : EventListener() { object : EventListener() {
override fun connectionAcquired( override fun connectionAcquired(
call: Call, call: Call,
@ -76,11 +77,11 @@ class AlpnOverrideTest {
println("Negotiated " + sslSocket.applicationProtocol) println("Negotiated " + sslSocket.applicationProtocol)
} }
}, },
) ).build()
.build()
val request = val request =
Request.Builder() Request
.Builder()
.url("https://www.google.com") .url("https://www.google.com")
.build() .build()
client.newCall(request).execute().use { response -> client.newCall(request).execute().use { response ->

View File

@ -77,7 +77,8 @@ class LetsEncryptClientTest {
""".trimIndent().decodeCertificatePem() """.trimIndent().decodeCertificatePem()
val handshakeCertificates = val handshakeCertificates =
HandshakeCertificates.Builder() HandshakeCertificates
.Builder()
// TODO reenable in official answers // TODO reenable in official answers
// .addPlatformTrustedCertificates() // .addPlatformTrustedCertificates()
.addTrustedCertificate(cert) .addTrustedCertificate(cert)
@ -93,7 +94,8 @@ class LetsEncryptClientTest {
val client = clientBuilder.build() val client = clientBuilder.build()
val request = val request =
Request.Builder() Request
.Builder()
.url("https://valid-isrgrootx1.letsencrypt.org/robots.txt") .url("https://valid-isrgrootx1.letsencrypt.org/robots.txt")
.build() .build()
client.newCall(request).execute().use { response -> client.newCall(request).execute().use { response ->

View File

@ -40,7 +40,8 @@ import org.junit.jupiter.api.Test
@Tag("Remote") @Tag("Remote")
class SniOverrideTest { class SniOverrideTest {
var client = var client =
OkHttpClient.Builder() OkHttpClient
.Builder()
.build() .build()
@Test @Test
@ -64,7 +65,8 @@ class SniOverrideTest {
} }
client = client =
client.newBuilder() client
.newBuilder()
.sslSocketFactory(CustomSSLSocketFactory(client.sslSocketFactory), client.x509TrustManager!!) .sslSocketFactory(CustomSSLSocketFactory(client.sslSocketFactory), client.x509TrustManager!!)
.hostnameVerifier { hostname, session -> .hostnameVerifier { hostname, session ->
val s = "hostname: $hostname peerHost:${session.peerHost}" val s = "hostname: $hostname peerHost:${session.peerHost}"
@ -80,11 +82,11 @@ class SniOverrideTest {
} catch (e: Exception) { } catch (e: Exception) {
false false
} }
} }.build()
.build()
val request = val request =
Request.Builder() Request
.Builder()
.url("https://sni.cloudflaressl.com/cdn-cgi/trace") .url("https://sni.cloudflaressl.com/cdn-cgi/trace")
.header("Host", "cloudflare-dns.com") .header("Host", "cloudflare-dns.com")
.build() .build()
@ -99,14 +101,15 @@ class SniOverrideTest {
@Test @Test
fun getWithDns() { fun getWithDns() {
client = client =
client.newBuilder() client
.newBuilder()
.dns { .dns {
Dns.SYSTEM.lookup("sni.cloudflaressl.com") Dns.SYSTEM.lookup("sni.cloudflaressl.com")
} }.build()
.build()
val request = val request =
Request.Builder() Request
.Builder()
.url("https://cloudflare-dns.com/cdn-cgi/trace") .url("https://cloudflare-dns.com/cdn-cgi/trace")
.build() .build()
client.newCall(request).execute().use { response -> client.newCall(request).execute().use { response ->

View File

@ -40,7 +40,9 @@ import org.robolectric.ParameterizedRobolectricTestRunner
import org.robolectric.ParameterizedRobolectricTestRunner.Parameters import org.robolectric.ParameterizedRobolectricTestRunner.Parameters
@RunWith(ParameterizedRobolectricTestRunner::class) @RunWith(ParameterizedRobolectricTestRunner::class)
class AndroidSocketAdapterTest(val adapter: SocketAdapter) { class AndroidSocketAdapterTest(
val adapter: SocketAdapter,
) {
val context: SSLContext by lazy { val context: SSLContext by lazy {
val provider: Provider = Conscrypt.newProviderBuilder().provideTrustManager(true).build() val provider: Provider = Conscrypt.newProviderBuilder().provideTrustManager(true).build()
@ -95,12 +97,11 @@ class AndroidSocketAdapterTest(val adapter: SocketAdapter) {
companion object { companion object {
@JvmStatic @JvmStatic
@Parameters(name = "{0}") @Parameters(name = "{0}")
fun data(): Collection<SocketAdapter> { fun data(): Collection<SocketAdapter> =
return listOfNotNull( listOfNotNull(
DeferredSocketAdapter(ConscryptSocketAdapter.factory), DeferredSocketAdapter(ConscryptSocketAdapter.factory),
DeferredSocketAdapter(AndroidSocketAdapter.factory("org.conscrypt")), DeferredSocketAdapter(AndroidSocketAdapter.factory("org.conscrypt")),
StandardAndroidSocketAdapter.buildIfSupported("org.conscrypt"), StandardAndroidSocketAdapter.buildIfSupported("org.conscrypt"),
) )
}
} }
} }

View File

@ -50,7 +50,8 @@ class RobolectricOkHttpClientTest {
fun setUp() { fun setUp() {
context = ApplicationProvider.getApplicationContext<Application>() context = ApplicationProvider.getApplicationContext<Application>()
client = client =
OkHttpClient.Builder() OkHttpClient
.Builder()
.cache(Cache(FakeFileSystem(), "/cache".toPath(), 10_000_000)) .cache(Cache(FakeFileSystem(), "/cache".toPath(), 10_000_000))
.build() .build()
} }
@ -62,7 +63,8 @@ class RobolectricOkHttpClientTest {
val request = Request("https://www.google.com/robots.txt".toHttpUrl()) val request = Request("https://www.google.com/robots.txt".toHttpUrl())
val networkRequest = val networkRequest =
request.newBuilder() request
.newBuilder()
.build() .build()
val call = client.newCall(networkRequest) val call = client.newCall(networkRequest)

View File

@ -15,8 +15,8 @@
*/ */
// https://www.eclipse.org/jetty/documentation/current/alpn-chapter.html#alpn-versions // https://www.eclipse.org/jetty/documentation/current/alpn-chapter.html#alpn-versions
private fun alpnBootVersionForPatchVersion(patchVersion: Int): String? { private fun alpnBootVersionForPatchVersion(patchVersion: Int): String? =
return when (patchVersion) { when (patchVersion) {
in 0..24 -> "8.1.0.v20141016" in 0..24 -> "8.1.0.v20141016"
in 25..30 -> "8.1.2.v20141202" in 25..30 -> "8.1.2.v20141202"
in 31..50 -> "8.1.3.v20150130" in 31..50 -> "8.1.3.v20150130"
@ -32,7 +32,6 @@ private fun alpnBootVersionForPatchVersion(patchVersion: Int): String? {
in 191..242 -> "8.1.13.v20181017" in 191..242 -> "8.1.13.v20181017"
else -> null else -> null
} }
}
/** /**
* Returns the alpn-boot version specific to this OpenJDK 8 JVM, or null if this is not a Java 8 VM. * Returns the alpn-boot version specific to this OpenJDK 8 JVM, or null if this is not a Java 8 VM.

View File

@ -39,8 +39,12 @@ private fun Project.applyOsgi(
val osgi = project.sourceSets.create("osgi") val osgi = project.sourceSets.create("osgi")
val osgiApi = project.configurations.getByName(osgiApiConfigurationName) val osgiApi = project.configurations.getByName(osgiApiConfigurationName)
val kotlinOsgi = val kotlinOsgi =
extensions.getByType(VersionCatalogsExtension::class.java).named("libs") extensions
.findLibrary("kotlin.stdlib.osgi").get().get() .getByType(VersionCatalogsExtension::class.java)
.named("libs")
.findLibrary("kotlin.stdlib.osgi")
.get()
.get()
project.dependencies { project.dependencies {
osgiApi(kotlinOsgi) osgiApi(kotlinOsgi)

View File

@ -66,7 +66,8 @@ class BasicLoomTest {
assertThat(System.getProperty("jdk.tracePinnedThreads")).isNotEmpty() assertThat(System.getProperty("jdk.tracePinnedThreads")).isNotEmpty()
client = client =
OkHttpClient.Builder() OkHttpClient
.Builder()
.trustMockServer() .trustMockServer()
.dispatcher(Dispatcher(newVirtualThreadPerTaskExecutor())) .dispatcher(Dispatcher(newVirtualThreadPerTaskExecutor()))
.build() .build()
@ -83,19 +84,18 @@ class BasicLoomTest {
assertThat(capturedOut.toString()).isEmpty() assertThat(capturedOut.toString()).isEmpty()
} }
private fun newVirtualThreadPerTaskExecutor(): ExecutorService { private fun newVirtualThreadPerTaskExecutor(): ExecutorService =
return Executors::class.java.getMethod("newVirtualThreadPerTaskExecutor").invoke(null) as ExecutorService Executors::class.java.getMethod("newVirtualThreadPerTaskExecutor").invoke(null) as ExecutorService
}
@Test @Test
fun testHttpsRequest() { fun testHttpsRequest() {
MockServerClient(mockServer.host, mockServer.serverPort).use { mockServerClient -> MockServerClient(mockServer.host, mockServer.serverPort).use { mockServerClient ->
mockServerClient mockServerClient
.`when`( .`when`(
request().withPath("/person") request()
.withPath("/person")
.withQueryStringParameter("name", "peter"), .withQueryStringParameter("name", "peter"),
) ).respond(response().withBody("Peter the person!"))
.respond(response().withBody("Peter the person!"))
val results = val results =
(1..20).map { (1..20).map {

View File

@ -40,7 +40,8 @@ class BasicMockServerTest {
val mockServer: MockServerContainer = MockServerContainer(MOCKSERVER_IMAGE) val mockServer: MockServerContainer = MockServerContainer(MOCKSERVER_IMAGE)
val client = val client =
OkHttpClient.Builder() OkHttpClient
.Builder()
.trustMockServer() .trustMockServer()
.build() .build()
@ -49,10 +50,10 @@ class BasicMockServerTest {
MockServerClient(mockServer.host, mockServer.serverPort).use { mockServerClient -> MockServerClient(mockServer.host, mockServer.serverPort).use { mockServerClient ->
mockServerClient mockServerClient
.`when`( .`when`(
request().withPath("/person") request()
.withPath("/person")
.withQueryStringParameter("name", "peter"), .withQueryStringParameter("name", "peter"),
) ).respond(response().withBody("Peter the person!"))
.respond(response().withBody("Peter the person!"))
val response = client.newCall(Request((mockServer.endpoint + "/person?name=peter").toHttpUrl())).execute() val response = client.newCall(Request((mockServer.endpoint + "/person?name=peter").toHttpUrl())).execute()
@ -65,10 +66,10 @@ class BasicMockServerTest {
MockServerClient(mockServer.host, mockServer.serverPort).use { mockServerClient -> MockServerClient(mockServer.host, mockServer.serverPort).use { mockServerClient ->
mockServerClient mockServerClient
.`when`( .`when`(
request().withPath("/person") request()
.withPath("/person")
.withQueryStringParameter("name", "peter"), .withQueryStringParameter("name", "peter"),
) ).respond(response().withBody("Peter the person!"))
.respond(response().withBody("Peter the person!"))
val response = client.newCall(Request((mockServer.secureEndpoint + "/person?name=peter").toHttpUrl())).execute() val response = client.newCall(Request((mockServer.secureEndpoint + "/person?name=peter").toHttpUrl())).execute()

View File

@ -55,9 +55,10 @@ class BasicProxyTest {
val client = OkHttpClient() val client = OkHttpClient()
val response = val response =
client.newCall( client
Request((mockServer.endpoint + "/person?name=peter").toHttpUrl()), .newCall(
).execute() Request((mockServer.endpoint + "/person?name=peter").toHttpUrl()),
).execute()
assertThat(response.body.string()).contains("Peter the person") assertThat(response.body.string()).contains("Peter the person")
assertThat(response.protocol).isEqualTo(Protocol.HTTP_1_1) assertThat(response.protocol).isEqualTo(Protocol.HTTP_1_1)
@ -70,14 +71,16 @@ class BasicProxyTest {
it.withProxyConfiguration(ProxyConfiguration.proxyConfiguration(ProxyConfiguration.Type.HTTP, it.remoteAddress())) it.withProxyConfiguration(ProxyConfiguration.proxyConfiguration(ProxyConfiguration.Type.HTTP, it.remoteAddress()))
val client = val client =
OkHttpClient.Builder() OkHttpClient
.Builder()
.proxy(Proxy(Proxy.Type.HTTP, it.remoteAddress())) .proxy(Proxy(Proxy.Type.HTTP, it.remoteAddress()))
.build() .build()
val response = val response =
client.newCall( client
Request((mockServer.endpoint + "/person?name=peter").toHttpUrl()), .newCall(
).execute() Request((mockServer.endpoint + "/person?name=peter").toHttpUrl()),
).execute()
assertThat(response.body.string()).contains("Peter the person") assertThat(response.body.string()).contains("Peter the person")
} }
@ -87,14 +90,16 @@ class BasicProxyTest {
fun testOkHttpSecureDirect() { fun testOkHttpSecureDirect() {
testRequest { testRequest {
val client = val client =
OkHttpClient.Builder() OkHttpClient
.Builder()
.trustMockServer() .trustMockServer()
.build() .build()
val response = val response =
client.newCall( client
Request((mockServer.secureEndpoint + "/person?name=peter").toHttpUrl()), .newCall(
).execute() Request((mockServer.secureEndpoint + "/person?name=peter").toHttpUrl()),
).execute()
assertThat(response.body.string()).contains("Peter the person") assertThat(response.body.string()).contains("Peter the person")
assertThat(response.protocol).isEqualTo(Protocol.HTTP_2) assertThat(response.protocol).isEqualTo(Protocol.HTTP_2)
@ -105,16 +110,18 @@ class BasicProxyTest {
fun testOkHttpSecureProxiedHttp1() { fun testOkHttpSecureProxiedHttp1() {
testRequest { testRequest {
val client = val client =
OkHttpClient.Builder() OkHttpClient
.Builder()
.trustMockServer() .trustMockServer()
.proxy(Proxy(Proxy.Type.HTTP, it.remoteAddress())) .proxy(Proxy(Proxy.Type.HTTP, it.remoteAddress()))
.protocols(listOf(Protocol.HTTP_1_1)) .protocols(listOf(Protocol.HTTP_1_1))
.build() .build()
val response = val response =
client.newCall( client
Request((mockServer.secureEndpoint + "/person?name=peter").toHttpUrl()), .newCall(
).execute() Request((mockServer.secureEndpoint + "/person?name=peter").toHttpUrl()),
).execute()
assertThat(response.body.string()).contains("Peter the person") assertThat(response.body.string()).contains("Peter the person")
assertThat(response.protocol).isEqualTo(Protocol.HTTP_1_1) assertThat(response.protocol).isEqualTo(Protocol.HTTP_1_1)
@ -128,7 +135,12 @@ class BasicProxyTest {
val connection = url.openConnection() as HttpURLConnection val connection = url.openConnection() as HttpURLConnection
assertThat(connection.inputStream.source().buffer().readUtf8()).contains("Peter the person") assertThat(
connection.inputStream
.source()
.buffer()
.readUtf8(),
).contains("Peter the person")
} }
} }
@ -145,7 +157,12 @@ class BasicProxyTest {
val connection = url.openConnection(proxy) as HttpURLConnection val connection = url.openConnection(proxy) as HttpURLConnection
assertThat(connection.inputStream.source().buffer().readUtf8()).contains("Peter the person") assertThat(
connection.inputStream
.source()
.buffer()
.readUtf8(),
).contains("Peter the person")
} }
} }
@ -159,7 +176,12 @@ class BasicProxyTest {
val connection = url.openConnection() as HttpURLConnection val connection = url.openConnection() as HttpURLConnection
assertThat(connection.inputStream.source().buffer().readUtf8()).contains("Peter the person") assertThat(
connection.inputStream
.source()
.buffer()
.readUtf8(),
).contains("Peter the person")
} }
} }
@ -179,21 +201,26 @@ class BasicProxyTest {
val connection = url.openConnection(proxy) as HttpURLConnection val connection = url.openConnection(proxy) as HttpURLConnection
assertThat(connection.inputStream.source().buffer().readUtf8()).contains("Peter the person") assertThat(
connection.inputStream
.source()
.buffer()
.readUtf8(),
).contains("Peter the person")
} }
} }
private fun testRequest(function: (MockServerClient) -> Unit) { private fun testRequest(function: (MockServerClient) -> Unit) {
MockServerClient(mockServer.host, mockServer.serverPort).use { mockServerClient -> MockServerClient(mockServer.host, mockServer.serverPort).use { mockServerClient ->
val request = val request =
request().withPath("/person") request()
.withPath("/person")
.withQueryStringParameter("name", "peter") .withQueryStringParameter("name", "peter")
mockServerClient mockServerClient
.`when`( .`when`(
request, request,
) ).respond(response().withBody("Peter the person!"))
.respond(response().withBody("Peter the person!"))
function(mockServerClient) function(mockServerClient)
} }

View File

@ -56,20 +56,22 @@ class SocksProxyTest {
MockServerClient(mockServer.host, mockServer.serverPort).use { mockServerClient -> MockServerClient(mockServer.host, mockServer.serverPort).use { mockServerClient ->
mockServerClient mockServerClient
.`when`( .`when`(
request().withPath("/person") request()
.withPath("/person")
.withQueryStringParameter("name", "peter"), .withQueryStringParameter("name", "peter"),
) ).respond(response().withBody("Peter the person!"))
.respond(response().withBody("Peter the person!"))
val client = val client =
OkHttpClient.Builder() OkHttpClient
.Builder()
.proxy(Proxy(SOCKS, InetSocketAddress(socks5Proxy.host, socks5Proxy.firstMappedPort))) .proxy(Proxy(SOCKS, InetSocketAddress(socks5Proxy.host, socks5Proxy.firstMappedPort)))
.build() .build()
val response = val response =
client.newCall( client
Request("http://mockserver:1080/person?name=peter".toHttpUrl()), .newCall(
).execute() Request("http://mockserver:1080/person?name=peter".toHttpUrl()),
).execute()
assertThat(response.body.string()).contains("Peter the person") assertThat(response.body.string()).contains("Peter the person")
} }

View File

@ -57,7 +57,7 @@ gradlePlugin-kotlinSerialization = { module = "org.jetbrains.kotlin:kotlin-seria
gradlePlugin-mavenPublish = "com.vanniktech:gradle-maven-publish-plugin:0.31.0" gradlePlugin-mavenPublish = "com.vanniktech:gradle-maven-publish-plugin:0.31.0"
gradlePlugin-mavenSympathy = "io.github.usefulness.maven-sympathy:io.github.usefulness.maven-sympathy.gradle.plugin:0.3.0" gradlePlugin-mavenSympathy = "io.github.usefulness.maven-sympathy:io.github.usefulness.maven-sympathy.gradle.plugin:0.3.0"
gradlePlugin-shadow = "gradle.plugin.com.github.johnrengelman:shadow:8.0.0" gradlePlugin-shadow = "gradle.plugin.com.github.johnrengelman:shadow:8.0.0"
gradlePlugin-spotless = "com.diffplug.spotless:spotless-plugin-gradle:6.25.0" gradlePlugin-spotless = "com.diffplug.spotless:spotless-plugin-gradle:7.0.2"
guava-jre = "com.google.guava:guava:33.4.0-jre" guava-jre = "com.google.guava:guava:33.4.0-jre"
hamcrestLibrary = "org.hamcrest:hamcrest-library:3.0" hamcrestLibrary = "org.hamcrest:hamcrest-library:3.0"
httpClient5 = "org.apache.httpcomponents.client5:httpclient5:5.4.2" httpClient5 = "org.apache.httpcomponents.client5:httpclient5:5.4.2"

View File

@ -36,13 +36,9 @@ internal fun Dispatcher.wrap(): mockwebserver3.Dispatcher {
val delegate = this val delegate = this
return object : mockwebserver3.Dispatcher() { return object : mockwebserver3.Dispatcher() {
override fun dispatch(request: mockwebserver3.RecordedRequest): mockwebserver3.MockResponse { override fun dispatch(request: mockwebserver3.RecordedRequest): mockwebserver3.MockResponse = delegate.dispatch(request.unwrap()).wrap()
return delegate.dispatch(request.unwrap()).wrap()
}
override fun peek(): mockwebserver3.MockResponse { override fun peek(): mockwebserver3.MockResponse = delegate.peek().wrap()
return delegate.peek().wrap()
}
override fun shutdown() { override fun shutdown() {
delegate.shutdown() delegate.shutdown()
@ -86,17 +82,16 @@ internal fun MockResponse.wrap(): mockwebserver3.MockResponse {
return result.build() return result.build()
} }
private fun PushPromise.wrap(): mockwebserver3.PushPromise { private fun PushPromise.wrap(): mockwebserver3.PushPromise =
return mockwebserver3.PushPromise( mockwebserver3.PushPromise(
method = method, method = method,
path = path, path = path,
headers = headers, headers = headers,
response = response.wrap(), response = response.wrap(),
) )
}
internal fun mockwebserver3.RecordedRequest.unwrap(): RecordedRequest { internal fun mockwebserver3.RecordedRequest.unwrap(): RecordedRequest =
return RecordedRequest( RecordedRequest(
requestLine = requestLine, requestLine = requestLine,
headers = headers, headers = headers,
chunkSizes = chunkSizes, chunkSizes = chunkSizes,
@ -109,10 +104,9 @@ internal fun mockwebserver3.RecordedRequest.unwrap(): RecordedRequest {
handshake = handshake, handshake = handshake,
requestUrl = requestUrl, requestUrl = requestUrl,
) )
}
private fun MockResponse.wrapSocketPolicy(): mockwebserver3.SocketPolicy { private fun MockResponse.wrapSocketPolicy(): mockwebserver3.SocketPolicy =
return when (val socketPolicy = socketPolicy) { when (val socketPolicy = socketPolicy) {
SocketPolicy.SHUTDOWN_SERVER_AFTER_RESPONSE -> ShutdownServerAfterResponse SocketPolicy.SHUTDOWN_SERVER_AFTER_RESPONSE -> ShutdownServerAfterResponse
SocketPolicy.KEEP_OPEN -> KeepOpen SocketPolicy.KEEP_OPEN -> KeepOpen
SocketPolicy.DISCONNECT_AT_END -> DisconnectAtEnd SocketPolicy.DISCONNECT_AT_END -> DisconnectAtEnd
@ -129,4 +123,3 @@ private fun MockResponse.wrapSocketPolicy(): mockwebserver3.SocketPolicy {
SocketPolicy.RESET_STREAM_AT_START -> ResetStreamAtStart(http2ErrorCode) SocketPolicy.RESET_STREAM_AT_START -> ResetStreamAtStart(http2ErrorCode)
else -> error("Unexpected SocketPolicy: $socketPolicy") else -> error("Unexpected SocketPolicy: $socketPolicy")
} }
}

View File

@ -19,9 +19,7 @@ abstract class Dispatcher {
@Throws(InterruptedException::class) @Throws(InterruptedException::class)
abstract fun dispatch(request: RecordedRequest): MockResponse abstract fun dispatch(request: RecordedRequest): MockResponse
open fun peek(): MockResponse { open fun peek(): MockResponse = MockResponse().apply { this.socketPolicy = SocketPolicy.KEEP_OPEN }
return MockResponse().apply { this.socketPolicy = SocketPolicy.KEEP_OPEN }
}
open fun shutdown() {} open fun shutdown() {}
} }

View File

@ -29,7 +29,9 @@ import okhttp3.HttpUrl
import okhttp3.Protocol import okhttp3.Protocol
import org.junit.rules.ExternalResource import org.junit.rules.ExternalResource
class MockWebServer : ExternalResource(), Closeable { class MockWebServer :
ExternalResource(),
Closeable {
@ExperimentalOkHttpApi @ExperimentalOkHttpApi
val delegate = mockwebserver3.MockWebServer() val delegate = mockwebserver3.MockWebServer()
@ -176,17 +178,13 @@ class MockWebServer : ExternalResource(), Closeable {
} }
@Throws(InterruptedException::class) @Throws(InterruptedException::class)
fun takeRequest(): RecordedRequest { fun takeRequest(): RecordedRequest = delegate.takeRequest().unwrap()
return delegate.takeRequest().unwrap()
}
@Throws(InterruptedException::class) @Throws(InterruptedException::class)
fun takeRequest( fun takeRequest(
timeout: Long, timeout: Long,
unit: TimeUnit, unit: TimeUnit,
): RecordedRequest? { ): RecordedRequest? = delegate.takeRequest(timeout, unit)?.unwrap()
return delegate.takeRequest(timeout, unit)?.unwrap()
}
@JvmName("-deprecated_requestCount") @JvmName("-deprecated_requestCount")
@Deprecated( @Deprecated(

View File

@ -21,13 +21,9 @@ class QueueDispatcher : Dispatcher() {
internal val delegate = QueueDispatcher() internal val delegate = QueueDispatcher()
@Throws(InterruptedException::class) @Throws(InterruptedException::class)
override fun dispatch(request: RecordedRequest): MockResponse { override fun dispatch(request: RecordedRequest): MockResponse = throw UnsupportedOperationException("unexpected call")
throw UnsupportedOperationException("unexpected call")
}
override fun peek(): MockResponse { override fun peek(): MockResponse = throw UnsupportedOperationException("unexpected call")
throw UnsupportedOperationException("unexpected call")
}
fun enqueueResponse(response: MockResponse) { fun enqueueResponse(response: MockResponse) {
delegate.enqueueResponse(response.wrap()) delegate.enqueueResponse(response.wrap())

View File

@ -266,11 +266,19 @@ class MockWebServerTest {
server.enqueue(MockResponse()) // The jdk's HttpUrlConnection is a bastard. server.enqueue(MockResponse()) // The jdk's HttpUrlConnection is a bastard.
server.enqueue(MockResponse()) server.enqueue(MockResponse())
try { try {
server.url("/a").toUrl().openConnection().getInputStream() server
.url("/a")
.toUrl()
.openConnection()
.getInputStream()
fail<Any>() fail<Any>()
} catch (expected: IOException) { } catch (expected: IOException) {
} }
server.url("/b").toUrl().openConnection().getInputStream() // Should succeed. server
.url("/b")
.toUrl()
.openConnection()
.getInputStream() // Should succeed.
} }
/** /**
@ -452,7 +460,11 @@ class MockWebServerTest {
object : Statement() { object : Statement() {
override fun evaluate() { override fun evaluate() {
called.set(true) called.set(true)
server.url("/").toUrl().openConnection().connect() server
.url("/")
.toUrl()
.openConnection()
.connect()
} }
}, },
Description.EMPTY, Description.EMPTY,
@ -460,7 +472,11 @@ class MockWebServerTest {
statement.evaluate() statement.evaluate()
assertThat(called.get()).isTrue() assertThat(called.get()).isTrue()
try { try {
server.url("/").toUrl().openConnection().connect() server
.url("/")
.toUrl()
.openConnection()
.connect()
fail<Any>() fail<Any>()
} catch (expected: ConnectException) { } catch (expected: ConnectException) {
} }
@ -593,20 +609,24 @@ class MockWebServerTest {
platform.assumeNotBouncyCastle() platform.assumeNotBouncyCastle()
platform.assumeNotConscrypt() platform.assumeNotConscrypt()
val clientCa = val clientCa =
HeldCertificate.Builder() HeldCertificate
.Builder()
.certificateAuthority(0) .certificateAuthority(0)
.build() .build()
val serverCa = val serverCa =
HeldCertificate.Builder() HeldCertificate
.Builder()
.certificateAuthority(0) .certificateAuthority(0)
.build() .build()
val serverCertificate = val serverCertificate =
HeldCertificate.Builder() HeldCertificate
.Builder()
.signedBy(serverCa) .signedBy(serverCa)
.addSubjectAlternativeName(server.hostName) .addSubjectAlternativeName(server.hostName)
.build() .build()
val serverHandshakeCertificates = val serverHandshakeCertificates =
HandshakeCertificates.Builder() HandshakeCertificates
.Builder()
.addTrustedCertificate(clientCa.certificate) .addTrustedCertificate(clientCa.certificate)
.heldCertificate(serverCertificate) .heldCertificate(serverCertificate)
.build() .build()
@ -614,11 +634,13 @@ class MockWebServerTest {
server.enqueue(MockResponse().setBody("abc")) server.enqueue(MockResponse().setBody("abc"))
server.requestClientAuth() server.requestClientAuth()
val clientCertificate = val clientCertificate =
HeldCertificate.Builder() HeldCertificate
.Builder()
.signedBy(clientCa) .signedBy(clientCa)
.build() .build()
val clientHandshakeCertificates = val clientHandshakeCertificates =
HandshakeCertificates.Builder() HandshakeCertificates
.Builder()
.addTrustedCertificate(serverCa.certificate) .addTrustedCertificate(serverCa.certificate)
.heldCertificate(clientCertificate) .heldCertificate(clientCertificate)
.build() .build()

View File

@ -33,7 +33,11 @@ class MockWebServerRuleTest {
object : Statement() { object : Statement() {
override fun evaluate() { override fun evaluate() {
called.set(true) called.set(true)
rule.server.url("/").toUrl().openConnection().connect() rule.server
.url("/")
.toUrl()
.openConnection()
.connect()
} }
}, },
Description.EMPTY, Description.EMPTY,
@ -41,7 +45,11 @@ class MockWebServerRuleTest {
statement.evaluate() statement.evaluate()
assertThat(called.get()).isTrue() assertThat(called.get()).isTrue()
try { try {
rule.server.url("/").toUrl().openConnection().connect() rule.server
.url("/")
.toUrl()
.openConnection()
.connect()
fail() fail()
} catch (expected: ConnectException) { } catch (expected: ConnectException) {
} }

View File

@ -43,7 +43,11 @@ import org.junit.jupiter.api.extension.ParameterResolver
*/ */
@ExperimentalOkHttpApi @ExperimentalOkHttpApi
class MockWebServerExtension : class MockWebServerExtension :
BeforeEachCallback, AfterEachCallback, ParameterResolver, BeforeAllCallback, AfterAllCallback { BeforeEachCallback,
AfterEachCallback,
ParameterResolver,
BeforeAllCallback,
AfterAllCallback {
private val ExtensionContext.resource: ServersForTest private val ExtensionContext.resource: ServersForTest
get() = get() =
getStore(namespace).getOrComputeIfAbsent(this.uniqueId) { getStore(namespace).getOrComputeIfAbsent(this.uniqueId) {
@ -54,13 +58,12 @@ class MockWebServerExtension :
private val servers = mutableMapOf<String, MockWebServer>() private val servers = mutableMapOf<String, MockWebServer>()
private var started = false private var started = false
fun server(name: String): MockWebServer { fun server(name: String): MockWebServer =
return servers.getOrPut(name) { servers.getOrPut(name) {
MockWebServer().also { MockWebServer().also {
if (started) it.start() if (started) it.start()
} }
} }
}
fun startAll() { fun startAll() {
started = true started = true
@ -87,9 +90,7 @@ class MockWebServerExtension :
override fun supportsParameter( override fun supportsParameter(
parameterContext: ParameterContext, parameterContext: ParameterContext,
extensionContext: ExtensionContext, extensionContext: ExtensionContext,
): Boolean { ): Boolean = parameterContext.parameter.type === MockWebServer::class.java
return parameterContext.parameter.type === MockWebServer::class.java
}
@Suppress("NewApi") @Suppress("NewApi")
override fun resolveParameter( override fun resolveParameter(

View File

@ -34,9 +34,7 @@ abstract class Dispatcher {
* can return other values to test HTTP edge cases, such as unhappy socket policies or throttled * can return other values to test HTTP edge cases, such as unhappy socket policies or throttled
* request bodies. * request bodies.
*/ */
open fun peek(): MockResponse { open fun peek(): MockResponse = MockResponse(socketPolicy = KeepOpen)
return MockResponse(socketPolicy = KeepOpen)
}
/** /**
* Release any resources held by this dispatcher. Any requests that are currently being dispatched * Release any resources held by this dispatcher. Any requests that are currently being dispatched

View File

@ -196,7 +196,8 @@ class MockResponse {
this.streamHandlerVar = null this.streamHandlerVar = null
this.webSocketListenerVar = null this.webSocketListenerVar = null
this.headers = this.headers =
Headers.Builder() Headers
.Builder()
.add("Content-Length", "0") .add("Content-Length", "0")
this.trailers = Headers.Builder() this.trailers = Headers.Builder()
this.throttleBytesPerPeriod = Long.MAX_VALUE this.throttleBytesPerPeriod = Long.MAX_VALUE

View File

@ -219,14 +219,14 @@ class MockWebServer : Closeable {
* *
* @param path the request path, such as "/". * @param path the request path, such as "/".
*/ */
fun url(path: String): HttpUrl { fun url(path: String): HttpUrl =
return HttpUrl.Builder() HttpUrl
.Builder()
.scheme(if (sslSocketFactory != null) "https" else "http") .scheme(if (sslSocketFactory != null) "https" else "http")
.host(hostName) .host(hostName)
.port(port) .port(port)
.build() .build()
.resolve(path)!! .resolve(path)!!
}
/** /**
* Serve requests with HTTPS rather than otherwise. * Serve requests with HTTPS rather than otherwise.
@ -426,7 +426,9 @@ class MockWebServer : Closeable {
} }
} }
internal inner class SocketHandler(private val raw: Socket) { internal inner class SocketHandler(
private val raw: Socket,
) {
private var sequenceNumber = 0 private var sequenceNumber = 0
@Throws(Exception::class) @Throws(Exception::class)
@ -496,7 +498,8 @@ class MockWebServer : Closeable {
if (protocol === Protocol.HTTP_2 || protocol === Protocol.H2_PRIOR_KNOWLEDGE) { if (protocol === Protocol.HTTP_2 || protocol === Protocol.H2_PRIOR_KNOWLEDGE) {
val http2SocketHandler = Http2SocketHandler(socket, protocol) val http2SocketHandler = Http2SocketHandler(socket, protocol)
val connection = val connection =
Http2Connection.Builder(false, taskRunner) Http2Connection
.Builder(false, taskRunner)
.socket(socket) .socket(socket)
.listener(http2SocketHandler) .listener(http2SocketHandler)
.build() .build()
@ -707,12 +710,13 @@ class MockWebServer : Closeable {
var hasBody = false var hasBody = false
val policy = dispatcher.peek() val policy = dispatcher.peek()
val requestBodySink = val requestBodySink =
requestBody.withThrottlingAndSocketPolicy( requestBody
policy = policy, .withThrottlingAndSocketPolicy(
disconnectHalfway = policy.socketPolicy == DisconnectDuringRequestBody, policy = policy,
expectedByteCount = contentLength, disconnectHalfway = policy.socketPolicy == DisconnectDuringRequestBody,
socket = socket, expectedByteCount = contentLength,
).buffer() socket = socket,
).buffer()
requestBodySink.use { requestBodySink.use {
when { when {
policy.socketPolicy is DoNotReadRequestBody -> { policy.socketPolicy is DoNotReadRequestBody -> {
@ -772,7 +776,8 @@ class MockWebServer : Closeable {
) { ) {
val key = request.headers["Sec-WebSocket-Key"] val key = request.headers["Sec-WebSocket-Key"]
val webSocketResponse = val webSocketResponse =
response.newBuilder() response
.newBuilder()
.setHeader("Sec-WebSocket-Accept", WebSocketProtocol.acceptHeader(key!!)) .setHeader("Sec-WebSocket-Accept", WebSocketProtocol.acceptHeader(key!!))
.build() .build()
writeHttpResponse(socket, sink, webSocketResponse) writeHttpResponse(socket, sink, webSocketResponse)
@ -781,12 +786,14 @@ class MockWebServer : Closeable {
val scheme = if (request.handshake != null) "https" else "http" val scheme = if (request.handshake != null) "https" else "http"
val authority = request.headers["Host"] // Has host and port. val authority = request.headers["Host"] // Has host and port.
val fancyRequest = val fancyRequest =
Request.Builder() Request
.Builder()
.url("$scheme://$authority/") .url("$scheme://$authority/")
.headers(request.headers) .headers(request.headers)
.build() .build()
val fancyResponse = val fancyResponse =
Response.Builder() Response
.Builder()
.code(webSocketResponse.code) .code(webSocketResponse.code)
.message(webSocketResponse.message) .message(webSocketResponse.message)
.headers(webSocketResponse.headers) .headers(webSocketResponse.headers)
@ -842,12 +849,13 @@ class MockWebServer : Closeable {
val body = response.body ?: return val body = response.body ?: return
sleepNanos(response.bodyDelayNanos) sleepNanos(response.bodyDelayNanos)
val responseBodySink = val responseBodySink =
sink.withThrottlingAndSocketPolicy( sink
policy = response, .withThrottlingAndSocketPolicy(
disconnectHalfway = response.socketPolicy == DisconnectDuringResponseBody, policy = response,
expectedByteCount = body.contentLength, disconnectHalfway = response.socketPolicy == DisconnectDuringResponseBody,
socket = socket, expectedByteCount = body.contentLength,
).buffer() socket = socket,
).buffer()
body.writeTo(responseBodySink) body.writeTo(responseBodySink)
responseBodySink.emit() responseBodySink.emit()
@ -1044,12 +1052,13 @@ class MockWebServer : Closeable {
try { try {
val contentLengthString = headers["content-length"] val contentLengthString = headers["content-length"]
val requestBodySink = val requestBodySink =
body.withThrottlingAndSocketPolicy( body
policy = peek, .withThrottlingAndSocketPolicy(
disconnectHalfway = peek.socketPolicy == DisconnectDuringRequestBody, policy = peek,
expectedByteCount = contentLengthString?.toLong() ?: Long.MAX_VALUE, disconnectHalfway = peek.socketPolicy == DisconnectDuringRequestBody,
socket = socket, expectedByteCount = contentLengthString?.toLong() ?: Long.MAX_VALUE,
).buffer() socket = socket,
).buffer()
requestBodySink.use { requestBodySink.use {
it.writeAll(stream.getSource()) it.writeAll(stream.getSource())
} }
@ -1116,12 +1125,14 @@ class MockWebServer : Closeable {
if (body != null) { if (body != null) {
sleepNanos(bodyDelayNanos) sleepNanos(bodyDelayNanos)
val responseBodySink = val responseBodySink =
stream.getSink().withThrottlingAndSocketPolicy( stream
policy = response, .getSink()
disconnectHalfway = response.socketPolicy == DisconnectDuringResponseBody, .withThrottlingAndSocketPolicy(
expectedByteCount = body.contentLength, policy = response,
socket = socket, disconnectHalfway = response.socketPolicy == DisconnectDuringResponseBody,
).buffer() expectedByteCount = body.contentLength,
socket = socket,
).buffer()
responseBodySink.use { responseBodySink.use {
body.writeTo(responseBodySink) body.writeTo(responseBodySink)
} }

View File

@ -53,9 +53,7 @@ open class QueueDispatcher : Dispatcher() {
return result return result
} }
override fun peek(): MockResponse { override fun peek(): MockResponse = responseQueue.peek() ?: failFastResponse ?: super.peek()
return responseQueue.peek() ?: failFastResponse ?: super.peek()
}
open fun enqueueResponse(response: MockResponse) { open fun enqueueResponse(response: MockResponse) {
responseQueue.add(response) responseQueue.add(response)

View File

@ -90,8 +90,8 @@ class CustomDispatcherTest {
private fun buildRequestThread( private fun buildRequestThread(
path: String, path: String,
responseCode: AtomicInteger, responseCode: AtomicInteger,
): Thread { ): Thread =
return Thread { Thread {
val url = mockWebServer.url(path).toUrl() val url = mockWebServer.url(path).toUrl()
val conn: HttpURLConnection val conn: HttpURLConnection
try { try {
@ -100,5 +100,4 @@ class CustomDispatcherTest {
} catch (ignored: IOException) { } catch (ignored: IOException) {
} }
} }
}
} }

View File

@ -61,17 +61,22 @@ class MockResponseSniTest {
} }
val client = val client =
clientTestRule.newClientBuilder() clientTestRule
.newClientBuilder()
.sslSocketFactory( .sslSocketFactory(
handshakeCertificates.sslSocketFactory(), handshakeCertificates.sslSocketFactory(),
handshakeCertificates.trustManager, handshakeCertificates.trustManager,
) ).dns(dns)
.dns(dns)
.build() .build()
server.enqueue(MockResponse()) server.enqueue(MockResponse())
val url = server.url("/").newBuilder().host("localhost.localdomain").build() val url =
server
.url("/")
.newBuilder()
.host("localhost.localdomain")
.build()
val call = client.newCall(Request(url = url)) val call = client.newCall(Request(url = url))
val response = call.execute() val response = call.execute()
assertThat(response.isSuccessful).isTrue() assertThat(response.isSuccessful).isTrue()
@ -90,12 +95,14 @@ class MockResponseSniTest {
@Test @Test
fun domainFronting() { fun domainFronting() {
val heldCertificate = val heldCertificate =
HeldCertificate.Builder() HeldCertificate
.Builder()
.commonName("server name") .commonName("server name")
.addSubjectAlternativeName("url-host.com") .addSubjectAlternativeName("url-host.com")
.build() .build()
val handshakeCertificates = val handshakeCertificates =
HandshakeCertificates.Builder() HandshakeCertificates
.Builder()
.heldCertificate(heldCertificate) .heldCertificate(heldCertificate)
.addTrustedCertificate(heldCertificate.certificate) .addTrustedCertificate(heldCertificate.certificate)
.build() .build()
@ -107,12 +114,12 @@ class MockResponseSniTest {
} }
val client = val client =
clientTestRule.newClientBuilder() clientTestRule
.newClientBuilder()
.sslSocketFactory( .sslSocketFactory(
handshakeCertificates.sslSocketFactory(), handshakeCertificates.sslSocketFactory(),
handshakeCertificates.trustManager, handshakeCertificates.trustManager,
) ).dns(dns)
.dns(dns)
.build() .build()
server.enqueue(MockResponse()) server.enqueue(MockResponse())
@ -168,24 +175,26 @@ class MockResponseSniTest {
*/ */
private fun requestToHostnameViaProxy(hostnameOrIpAddress: String): RecordedRequest { private fun requestToHostnameViaProxy(hostnameOrIpAddress: String): RecordedRequest {
val heldCertificate = val heldCertificate =
HeldCertificate.Builder() HeldCertificate
.Builder()
.commonName("server name") .commonName("server name")
.addSubjectAlternativeName(hostnameOrIpAddress) .addSubjectAlternativeName(hostnameOrIpAddress)
.build() .build()
val handshakeCertificates = val handshakeCertificates =
HandshakeCertificates.Builder() HandshakeCertificates
.Builder()
.heldCertificate(heldCertificate) .heldCertificate(heldCertificate)
.addTrustedCertificate(heldCertificate.certificate) .addTrustedCertificate(heldCertificate.certificate)
.build() .build()
server.useHttps(handshakeCertificates.sslSocketFactory()) server.useHttps(handshakeCertificates.sslSocketFactory())
val client = val client =
clientTestRule.newClientBuilder() clientTestRule
.newClientBuilder()
.sslSocketFactory( .sslSocketFactory(
handshakeCertificates.sslSocketFactory(), handshakeCertificates.sslSocketFactory(),
handshakeCertificates.trustManager, handshakeCertificates.trustManager,
) ).proxy(server.toProxyAddress())
.proxy(server.toProxyAddress())
.build() .build()
server.enqueue(MockResponse(inTunnel = true)) server.enqueue(MockResponse(inTunnel = true))
@ -195,7 +204,9 @@ class MockResponseSniTest {
client.newCall( client.newCall(
Request( Request(
url = url =
server.url("/").newBuilder() server
.url("/")
.newBuilder()
.host(hostnameOrIpAddress) .host(hostnameOrIpAddress)
.build(), .build(),
), ),

View File

@ -130,7 +130,8 @@ class MockWebServerTest {
@Test @Test
fun mockResponseAddHeader() { fun mockResponseAddHeader() {
val builder = val builder =
MockResponse.Builder() MockResponse
.Builder()
.clearHeaders() .clearHeaders()
.addHeader("Cookie: s=square") .addHeader("Cookie: s=square")
.addHeader("Cookie", "a=android") .addHeader("Cookie", "a=android")
@ -140,7 +141,8 @@ class MockWebServerTest {
@Test @Test
fun mockResponseSetHeader() { fun mockResponseSetHeader() {
val builder = val builder =
MockResponse.Builder() MockResponse
.Builder()
.clearHeaders() .clearHeaders()
.addHeader("Cookie: s=square") .addHeader("Cookie: s=square")
.addHeader("Cookie: a=android") .addHeader("Cookie: a=android")
@ -152,7 +154,8 @@ class MockWebServerTest {
@Test @Test
fun mockResponseSetHeaders() { fun mockResponseSetHeaders() {
val builder = val builder =
MockResponse.Builder() MockResponse
.Builder()
.clearHeaders() .clearHeaders()
.addHeader("Cookie: s=square") .addHeader("Cookie: s=square")
.addHeader("Cookies: delicious") .addHeader("Cookies: delicious")
@ -180,14 +183,16 @@ class MockWebServerTest {
@Test @Test
fun redirect() { fun redirect() {
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.code(HttpURLConnection.HTTP_MOVED_TEMP) .code(HttpURLConnection.HTTP_MOVED_TEMP)
.addHeader("Location: " + server.url("/new-path")) .addHeader("Location: " + server.url("/new-path"))
.body("This page has moved!") .body("This page has moved!")
.build(), .build(),
) )
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.body("This is the new location!") .body("This is the new location!")
.build(), .build(),
) )
@ -212,7 +217,8 @@ class MockWebServerTest {
} catch (ignored: InterruptedException) { } catch (ignored: InterruptedException) {
} }
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.body("enqueued in the background") .body("enqueued in the background")
.build(), .build(),
) )
@ -225,7 +231,8 @@ class MockWebServerTest {
@Test @Test
fun nonHexadecimalChunkSize() { fun nonHexadecimalChunkSize() {
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.body("G\r\nxxxxxxxxxxxxxxxx\r\n0\r\n\r\n") .body("G\r\nxxxxxxxxxxxxxxxx\r\n0\r\n\r\n")
.clearHeaders() .clearHeaders()
.addHeader("Transfer-encoding: chunked") .addHeader("Transfer-encoding: chunked")
@ -243,14 +250,16 @@ class MockWebServerTest {
@Test @Test
fun responseTimeout() { fun responseTimeout() {
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.body("ABC") .body("ABC")
.clearHeaders() .clearHeaders()
.addHeader("Content-Length: 4") .addHeader("Content-Length: 4")
.build(), .build(),
) )
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.body("DEF") .body("DEF")
.build(), .build(),
) )
@ -280,19 +289,28 @@ class MockWebServerTest {
@Test @Test
fun disconnectAtStart() { fun disconnectAtStart() {
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.socketPolicy(DisconnectAtStart) .socketPolicy(DisconnectAtStart)
.build(), .build(),
) )
server.enqueue(MockResponse()) // The jdk's HttpUrlConnection is a bastard. server.enqueue(MockResponse()) // The jdk's HttpUrlConnection is a bastard.
server.enqueue(MockResponse()) server.enqueue(MockResponse())
try { try {
server.url("/a").toUrl().openConnection().getInputStream() server
.url("/a")
.toUrl()
.openConnection()
.getInputStream()
fail<Unit>() fail<Unit>()
} catch (expected: IOException) { } catch (expected: IOException) {
// Expected. // Expected.
} }
server.url("/b").toUrl().openConnection().getInputStream() // Should succeed. server
.url("/b")
.toUrl()
.openConnection()
.getInputStream() // Should succeed.
} }
@Test @Test
@ -300,7 +318,12 @@ class MockWebServerTest {
server.enqueue(MockResponse(body = "A")) server.enqueue(MockResponse(body = "A"))
(server.dispatcher as QueueDispatcher).clear() (server.dispatcher as QueueDispatcher).clear()
server.enqueue(MockResponse(body = "B")) server.enqueue(MockResponse(body = "B"))
val inputStream = server.url("/a").toUrl().openConnection().getInputStream() val inputStream =
server
.url("/a")
.toUrl()
.openConnection()
.getInputStream()
assertThat(inputStream!!.read()).isEqualTo('B'.code) assertThat(inputStream!!.read()).isEqualTo('B'.code)
} }
@ -312,7 +335,8 @@ class MockWebServerTest {
fun throttleRequest() { fun throttleRequest() {
assumeNotWindows() assumeNotWindows()
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.throttleBody(3, 500, TimeUnit.MILLISECONDS) .throttleBody(3, 500, TimeUnit.MILLISECONDS)
.build(), .build(),
) )
@ -335,7 +359,8 @@ class MockWebServerTest {
fun throttleResponse() { fun throttleResponse() {
assumeNotWindows() assumeNotWindows()
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.body("ABCDEF") .body("ABCDEF")
.throttleBody(3, 500, TimeUnit.MILLISECONDS) .throttleBody(3, 500, TimeUnit.MILLISECONDS)
.build(), .build(),
@ -360,7 +385,8 @@ class MockWebServerTest {
fun delayResponse() { fun delayResponse() {
assumeNotWindows() assumeNotWindows()
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.body("ABCDEF") .body("ABCDEF")
.bodyDelay(1, TimeUnit.SECONDS) .bodyDelay(1, TimeUnit.SECONDS)
.build(), .build(),
@ -378,7 +404,8 @@ class MockWebServerTest {
@Test @Test
fun disconnectRequestHalfway() { fun disconnectRequestHalfway() {
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.socketPolicy(DisconnectDuringRequestBody) .socketPolicy(DisconnectDuringRequestBody)
.build(), .build(),
) )
@ -413,7 +440,8 @@ class MockWebServerTest {
@Test @Test
fun disconnectResponseHalfway() { fun disconnectResponseHalfway() {
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.body("ab") .body("ab")
.socketPolicy(DisconnectDuringResponseBody) .socketPolicy(DisconnectDuringResponseBody)
.build(), .build(),
@ -497,7 +525,8 @@ class MockWebServerTest {
@Test @Test
fun requestUrlReconstructed() { fun requestUrlReconstructed() {
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.body("hello world") .body("hello world")
.build(), .build(),
) )
@ -522,7 +551,8 @@ class MockWebServerTest {
@Test @Test
fun shutdownServerAfterRequest() { fun shutdownServerAfterRequest() {
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.socketPolicy(ShutdownServerAfterResponse) .socketPolicy(ShutdownServerAfterResponse)
.build(), .build(),
) )
@ -540,7 +570,8 @@ class MockWebServerTest {
@Test @Test
fun http100Continue() { fun http100Continue() {
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.body("response") .body("response")
.build(), .build(),
) )
@ -559,7 +590,8 @@ class MockWebServerTest {
@Test @Test
fun multiple1xxResponses() { fun multiple1xxResponses() {
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.add100Continue() .add100Continue()
.add100Continue() .add100Continue()
.body("response") .body("response")
@ -615,7 +647,8 @@ class MockWebServerTest {
val handshakeCertificates = platform.localhostHandshakeCertificates() val handshakeCertificates = platform.localhostHandshakeCertificates()
server.useHttps(handshakeCertificates.sslSocketFactory()) server.useHttps(handshakeCertificates.sslSocketFactory())
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.body("abc") .body("abc")
.build(), .build(),
) )
@ -643,36 +676,43 @@ class MockWebServerTest {
platform.assumeNotConscrypt() platform.assumeNotConscrypt()
val clientCa = val clientCa =
HeldCertificate.Builder() HeldCertificate
.Builder()
.certificateAuthority(0) .certificateAuthority(0)
.build() .build()
val serverCa = val serverCa =
HeldCertificate.Builder() HeldCertificate
.Builder()
.certificateAuthority(0) .certificateAuthority(0)
.build() .build()
val serverCertificate = val serverCertificate =
HeldCertificate.Builder() HeldCertificate
.Builder()
.signedBy(serverCa) .signedBy(serverCa)
.addSubjectAlternativeName(server.hostName) .addSubjectAlternativeName(server.hostName)
.build() .build()
val serverHandshakeCertificates = val serverHandshakeCertificates =
HandshakeCertificates.Builder() HandshakeCertificates
.Builder()
.addTrustedCertificate(clientCa.certificate) .addTrustedCertificate(clientCa.certificate)
.heldCertificate(serverCertificate) .heldCertificate(serverCertificate)
.build() .build()
server.useHttps(serverHandshakeCertificates.sslSocketFactory()) server.useHttps(serverHandshakeCertificates.sslSocketFactory())
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.body("abc") .body("abc")
.build(), .build(),
) )
server.requestClientAuth() server.requestClientAuth()
val clientCertificate = val clientCertificate =
HeldCertificate.Builder() HeldCertificate
.Builder()
.signedBy(clientCa) .signedBy(clientCa)
.build() .build()
val clientHandshakeCertificates = val clientHandshakeCertificates =
HandshakeCertificates.Builder() HandshakeCertificates
.Builder()
.addTrustedCertificate(serverCa.certificate) .addTrustedCertificate(serverCa.certificate)
.heldCertificate(clientCertificate) .heldCertificate(clientCertificate)
.build() .build()
@ -697,12 +737,14 @@ class MockWebServerTest {
@Test @Test
fun proxiedRequestGetsCorrectRequestUrl() { fun proxiedRequestGetsCorrectRequestUrl() {
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.body("Result") .body("Result")
.build(), .build(),
) )
val proxiedClient = val proxiedClient =
OkHttpClient.Builder() OkHttpClient
.Builder()
.proxy(server.toProxyAddress()) .proxy(server.toProxyAddress())
.readTimeout(Duration.ofMillis(100)) .readTimeout(Duration.ofMillis(100))
.build() .build()

View File

@ -55,7 +55,8 @@ class Http2Server(
throw ProtocolException("Protocol $protocol unsupported") throw ProtocolException("Protocol $protocol unsupported")
} }
val connection = val connection =
Http2Connection.Builder(false, TaskRunner.INSTANCE) Http2Connection
.Builder(false, TaskRunner.INSTANCE)
.socket(sslSocket) .socket(sslSocket)
.listener(this) .listener(this)
.build() .build()
@ -179,8 +180,8 @@ class Http2Server(
} }
} }
private fun contentType(file: File): String { private fun contentType(file: File): String =
return when { when {
file.name.endsWith(".css") -> "text/css" file.name.endsWith(".css") -> "text/css"
file.name.endsWith(".gif") -> "image/gif" file.name.endsWith(".gif") -> "image/gif"
file.name.endsWith(".html") -> "text/html" file.name.endsWith(".html") -> "text/html"
@ -190,7 +191,6 @@ class Http2Server(
file.name.endsWith(".png") -> "image/png" file.name.endsWith(".png") -> "image/png"
else -> "text/plain" else -> "text/plain"
} }
}
companion object { companion object {
val logger: Logger = Logger.getLogger(Http2Server::class.java.name) val logger: Logger = Logger.getLogger(Http2Server::class.java.name)

View File

@ -64,16 +64,20 @@ class Main : CliktCommand(name = NAME) {
"--connect-timeout", "--connect-timeout",
).help( ).help(
"Maximum time allowed for connection (seconds)", "Maximum time allowed for connection (seconds)",
).int().default(DEFAULT_TIMEOUT) ).int()
.default(DEFAULT_TIMEOUT)
val readTimeout: Int by option("--read-timeout").help("Maximum time allowed for reading data (seconds)").int() val readTimeout: Int by option("--read-timeout")
.help("Maximum time allowed for reading data (seconds)")
.int()
.default(DEFAULT_TIMEOUT) .default(DEFAULT_TIMEOUT)
val callTimeout: Int by option( val callTimeout: Int by option(
"--call-timeout", "--call-timeout",
).help( ).help(
"Maximum time allowed for the entire call (seconds)", "Maximum time allowed for the entire call (seconds)",
).int().default(DEFAULT_TIMEOUT) ).int()
.default(DEFAULT_TIMEOUT)
val followRedirects: Boolean by option("-L", "--location").help("Follow redirects").flag() val followRedirects: Boolean by option("-L", "--location").help("Follow redirects").flag()
@ -163,9 +167,12 @@ class Main : CliktCommand(name = NAME) {
} }
private fun createInsecureSslSocketFactory(trustManager: TrustManager): SSLSocketFactory = private fun createInsecureSslSocketFactory(trustManager: TrustManager): SSLSocketFactory =
Platform.get().newSSLContext().apply { Platform
init(null, arrayOf(trustManager), null) .get()
}.socketFactory .newSSLContext()
.apply {
init(null, arrayOf(trustManager), null)
}.socketFactory
private fun createInsecureHostnameVerifier(): HostnameVerifier = HostnameVerifier { _, _ -> true } private fun createInsecureHostnameVerifier(): HostnameVerifier = HostnameVerifier { _, _ -> true }
} }

View File

@ -68,9 +68,7 @@ private fun Main.mediaType(): MediaType? {
return mimeType.toMediaTypeOrNull() return mimeType.toMediaTypeOrNull()
} }
private fun isSpecialHeader(s: String): Boolean { private fun isSpecialHeader(s: String): Boolean = s.equals("Content-Type", ignoreCase = true)
return s.equals("Content-Type", ignoreCase = true)
}
fun Main.commonRun() { fun Main.commonRun() {
client = createClient() client = createClient()

View File

@ -19,7 +19,5 @@ import java.util.logging.LogRecord
import java.util.logging.SimpleFormatter import java.util.logging.SimpleFormatter
object MessageFormatter : SimpleFormatter() { object MessageFormatter : SimpleFormatter() {
override fun format(record: LogRecord): String { override fun format(record: LogRecord): String = String.format("%s%n", record.message)
return String.format("%s%n", record.message)
}
} }

View File

@ -123,20 +123,18 @@ class MainTest {
} }
companion object { companion object {
fun fromArgs(vararg args: String): Main { fun fromArgs(vararg args: String): Main =
return Main().apply { Main().apply {
parse(args.toList()) parse(args.toList())
} }
}
private fun bodyAsString(body: RequestBody?): String { private fun bodyAsString(body: RequestBody?): String =
return try { try {
val buffer = Buffer() val buffer = Buffer()
body!!.writeTo(buffer) body!!.writeTo(buffer)
buffer.readString(body.contentType()!!.charset()!!) buffer.readString(body.contentType()!!.charset()!!)
} catch (e: IOException) { } catch (e: IOException) {
throw RuntimeException(e) throw RuntimeException(e)
} }
}
} }
} }

View File

@ -26,10 +26,12 @@ import okhttp3.brotli.internal.uncompress
* responses. n.b. this replaces the transparent gzip compression in BridgeInterceptor. * responses. n.b. this replaces the transparent gzip compression in BridgeInterceptor.
*/ */
object BrotliInterceptor : Interceptor { object BrotliInterceptor : Interceptor {
override fun intercept(chain: Interceptor.Chain): Response { override fun intercept(chain: Interceptor.Chain): Response =
return if (chain.request().header("Accept-Encoding") == null) { if (chain.request().header("Accept-Encoding") == null) {
val request = val request =
chain.request().newBuilder() chain
.request()
.newBuilder()
.header("Accept-Encoding", "br,gzip") .header("Accept-Encoding", "br,gzip")
.build() .build()
@ -39,5 +41,4 @@ object BrotliInterceptor : Interceptor {
} else { } else {
chain.proceed(chain.request()) chain.proceed(chain.request())
} }
}
} }

View File

@ -39,7 +39,8 @@ fun uncompress(response: Response): Response {
else -> return response else -> return response
} }
return response.newBuilder() return response
.newBuilder()
.removeHeader("Content-Encoding") .removeHeader("Content-Encoding")
.removeHeader("Content-Length") .removeHeader("Content-Length")
.body(decompressedSource.asResponseBody(body.contentType(), -1)) .body(decompressedSource.asResponseBody(body.contentType(), -1))

View File

@ -121,8 +121,9 @@ class BrotliInterceptorTest {
url: String, url: String,
bodyHex: ByteString, bodyHex: ByteString,
fn: Response.Builder.() -> Unit = {}, fn: Response.Builder.() -> Unit = {},
): Response { ): Response =
return Response.Builder() Response
.Builder()
.body(bodyHex.toResponseBody("text/plain".toMediaType())) .body(bodyHex.toResponseBody("text/plain".toMediaType()))
.code(200) .code(200)
.message("OK") .message("OK")
@ -130,5 +131,4 @@ class BrotliInterceptorTest {
.protocol(Protocol.HTTP_2) .protocol(Protocol.HTTP_2)
.apply(fn) .apply(fn)
.build() .build()
}
} }

View File

@ -20,7 +20,8 @@ import okhttp3.Request
fun main() { fun main() {
val client = val client =
OkHttpClient.Builder() OkHttpClient
.Builder()
.addInterceptor(BrotliInterceptor) .addInterceptor(BrotliInterceptor)
.build() .build()

View File

@ -88,7 +88,8 @@ class ExecuteAsyncTest {
fun timeoutCall() { fun timeoutCall() {
runTest { runTest {
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.bodyDelay(5, TimeUnit.SECONDS) .bodyDelay(5, TimeUnit.SECONDS)
.body("abc") .body("abc")
.build(), .build(),
@ -117,7 +118,8 @@ class ExecuteAsyncTest {
fun cancelledCall() { fun cancelledCall() {
runTest { runTest {
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.bodyDelay(5, TimeUnit.SECONDS) .bodyDelay(5, TimeUnit.SECONDS)
.body("abc") .body("abc")
.build(), .build(),
@ -187,7 +189,8 @@ class ExecuteAsyncTest {
/** A call that keeps track of whether its response body is closed. */ /** A call that keeps track of whether its response body is closed. */
private class ClosableCall : FailingCall() { private class ClosableCall : FailingCall() {
private val response = private val response =
Response.Builder() Response
.Builder()
.request(Request("https://example.com/".toHttpUrl())) .request(Request("https://example.com/".toHttpUrl()))
.protocol(Protocol.HTTP_1_1) .protocol(Protocol.HTTP_1_1)
.message("OK") .message("OK")
@ -205,8 +208,7 @@ class ExecuteAsyncTest {
} }
}.buffer() }.buffer()
}, },
) ).build()
.build()
var responseClosed = false var responseClosed = false
var canceled = false var canceled = false

View File

@ -202,23 +202,27 @@ class DnsOverHttps internal constructor(
hostname: String, hostname: String,
type: Int, type: Int,
): Request = ): Request =
Request.Builder().header("Accept", DNS_MESSAGE.toString()).apply { Request
val query = DnsRecordCodec.encodeQuery(hostname, type) .Builder()
.header("Accept", DNS_MESSAGE.toString())
.apply {
val query = DnsRecordCodec.encodeQuery(hostname, type)
if (post) { if (post) {
url(url) url(url)
.cacheUrlOverride( .cacheUrlOverride(
url.newBuilder() url
.addQueryParameter("hostname", hostname).build(), .newBuilder()
) .addQueryParameter("hostname", hostname)
.post(query.toRequestBody(DNS_MESSAGE)) .build(),
} else { ).post(query.toRequestBody(DNS_MESSAGE))
val encoded = query.base64Url().replace("=", "") } else {
val requestUrl = url.newBuilder().addQueryParameter("dns", encoded).build() val encoded = query.base64Url().replace("=", "")
val requestUrl = url.newBuilder().addQueryParameter("dns", encoded).build()
url(requestUrl) url(requestUrl)
} }
}.build() }.build()
class Builder { class Builder {
internal var client: OkHttpClient? = null internal var client: OkHttpClient? = null
@ -299,8 +303,6 @@ class DnsOverHttps internal constructor(
} }
} }
internal fun isPrivateHost(host: String): Boolean { internal fun isPrivateHost(host: String): Boolean = PublicSuffixDatabase.get().getEffectiveTldPlusOne(host) == null
return PublicSuffixDatabase.get().getEffectiveTldPlusOne(host) == null
}
} }
} }

View File

@ -37,28 +37,29 @@ internal object DnsRecordCodec {
host: String, host: String,
type: Int, type: Int,
): ByteString = ): ByteString =
Buffer().apply { Buffer()
writeShort(0) // query id .apply {
writeShort(256) // flags with recursion writeShort(0) // query id
writeShort(1) // question count writeShort(256) // flags with recursion
writeShort(0) // answerCount writeShort(1) // question count
writeShort(0) // authorityResourceCount writeShort(0) // answerCount
writeShort(0) // additional writeShort(0) // authorityResourceCount
writeShort(0) // additional
val nameBuf = Buffer() val nameBuf = Buffer()
val labels = host.split('.').dropLastWhile { it.isEmpty() } val labels = host.split('.').dropLastWhile { it.isEmpty() }
for (label in labels) { for (label in labels) {
val utf8ByteCount = label.utf8Size() val utf8ByteCount = label.utf8Size()
require(utf8ByteCount == label.length.toLong()) { "non-ascii hostname: $host" } require(utf8ByteCount == label.length.toLong()) { "non-ascii hostname: $host" }
nameBuf.writeByte(utf8ByteCount.toInt()) nameBuf.writeByte(utf8ByteCount.toInt())
nameBuf.writeUtf8(label) nameBuf.writeUtf8(label)
} }
nameBuf.writeByte(0) // end nameBuf.writeByte(0) // end
nameBuf.copyTo(this, 0, nameBuf.size) nameBuf.copyTo(this, 0, nameBuf.size)
writeShort(type) writeShort(type)
writeShort(1) // CLASS_IN writeShort(1) // CLASS_IN
}.readByteString() }.readByteString()
@Throws(Exception::class) @Throws(Exception::class)
fun decodeAnswers( fun decodeAnswers(

View File

@ -59,7 +59,8 @@ class DnsOverHttpsTest {
private val cacheFs = FakeFileSystem() private val cacheFs = FakeFileSystem()
private val eventListener = RecordingEventListener() private val eventListener = RecordingEventListener()
private val bootstrapClient = private val bootstrapClient =
OkHttpClient.Builder() OkHttpClient
.Builder()
.protocols(listOf(Protocol.HTTP_2, Protocol.HTTP_1_1)) .protocols(listOf(Protocol.HTTP_2, Protocol.HTTP_1_1))
.eventListener(eventListener) .eventListener(eventListener)
.build() .build()
@ -186,8 +187,7 @@ class DnsOverHttpsTest {
"0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c000500010" + "0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c000500010" +
"0000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010001000" + "0000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010001000" +
"0003b00049df00112", "0003b00049df00112",
) ).newBuilder()
.newBuilder()
.setHeader("cache-control", "private, max-age=298") .setHeader("cache-control", "private, max-age=298")
.build(), .build(),
) )
@ -229,8 +229,7 @@ class DnsOverHttpsTest {
"0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c000500010" + "0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c000500010" +
"0000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010001000" + "0000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010001000" +
"0003b00049df00112", "0003b00049df00112",
) ).newBuilder()
.newBuilder()
.setHeader("cache-control", "private, max-age=298") .setHeader("cache-control", "private, max-age=298")
.build(), .build(),
) )
@ -271,8 +270,7 @@ class DnsOverHttpsTest {
"0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c000500010" + "0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c000500010" +
"0000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010001000" + "0000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010001000" +
"0003b00049df00112", "0003b00049df00112",
) ).newBuilder()
.newBuilder()
.setHeader("cache-control", "max-age=1") .setHeader("cache-control", "max-age=1")
.build(), .build(),
) )
@ -292,8 +290,7 @@ class DnsOverHttpsTest {
"0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c000500010" + "0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c000500010" +
"0000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010001000" + "0000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010001000" +
"0003b00049df00112", "0003b00049df00112",
) ).newBuilder()
.newBuilder()
.setHeader("cache-control", "max-age=1") .setHeader("cache-control", "max-age=1")
.build(), .build(),
) )
@ -307,19 +304,18 @@ class DnsOverHttpsTest {
assertThat(cacheEvents()).containsExactly("CacheMiss") assertThat(cacheEvents()).containsExactly("CacheMiss")
} }
private fun cacheEvents(): List<String> { private fun cacheEvents(): List<String> =
return eventListener.recordedEventTypes().filter { it.contains("Cache") }.also { eventListener.recordedEventTypes().filter { it.contains("Cache") }.also {
eventListener.clearAllEvents() eventListener.clearAllEvents()
} }
}
private fun dnsResponse(s: String): MockResponse { private fun dnsResponse(s: String): MockResponse =
return MockResponse.Builder() MockResponse
.Builder()
.body(Buffer().write(s.decodeHex())) .body(Buffer().write(s.decodeHex()))
.addHeader("content-type", "application/dns-message") .addHeader("content-type", "application/dns-message")
.addHeader("content-length", s.length / 2) .addHeader("content-length", s.length / 2)
.build() .build()
}
private fun buildLocalhost( private fun buildLocalhost(
bootstrapClient: OkHttpClient, bootstrapClient: OkHttpClient,
@ -327,7 +323,9 @@ class DnsOverHttpsTest {
post: Boolean = false, post: Boolean = false,
): DnsOverHttps { ): DnsOverHttps {
val url = server.url("/lookup?ct") val url = server.url("/lookup?ct")
return DnsOverHttps.Builder().client(bootstrapClient) return DnsOverHttps
.Builder()
.client(bootstrapClient)
.includeIPv6(includeIPv6) .includeIPv6(includeIPv6)
.resolvePrivateAddresses(true) .resolvePrivateAddresses(true)
.url(url) .url(url)

View File

@ -39,9 +39,7 @@ class DnsRecordCodecTest {
private fun encodeQuery( private fun encodeQuery(
host: String, host: String,
type: Int, type: Int,
): String { ): String = DnsRecordCodec.encodeQuery(host, type).base64Url().replace("=", "")
return DnsRecordCodec.encodeQuery(host, type).base64Url().replace("=", "")
}
@Test @Test
fun testGoogleDotComEncodingWithIPv6() { fun testGoogleDotComEncodingWithIPv6() {

View File

@ -26,73 +26,73 @@ import okhttp3.OkHttpClient
* https://github.com/curl/curl/wiki/DNS-over-HTTPS * https://github.com/curl/curl/wiki/DNS-over-HTTPS
*/ */
object DohProviders { object DohProviders {
private fun buildGoogle(bootstrapClient: OkHttpClient): DnsOverHttps { private fun buildGoogle(bootstrapClient: OkHttpClient): DnsOverHttps =
return DnsOverHttps.Builder() DnsOverHttps
.Builder()
.client(bootstrapClient) .client(bootstrapClient)
.url("https://dns.google/dns-query".toHttpUrl()) .url("https://dns.google/dns-query".toHttpUrl())
.bootstrapDnsHosts(getByIp("8.8.4.4"), getByIp("8.8.8.8")) .bootstrapDnsHosts(getByIp("8.8.4.4"), getByIp("8.8.8.8"))
.build() .build()
}
private fun buildGooglePost(bootstrapClient: OkHttpClient): DnsOverHttps { private fun buildGooglePost(bootstrapClient: OkHttpClient): DnsOverHttps =
return DnsOverHttps.Builder() DnsOverHttps
.Builder()
.client(bootstrapClient) .client(bootstrapClient)
.url("https://dns.google/dns-query".toHttpUrl()) .url("https://dns.google/dns-query".toHttpUrl())
.bootstrapDnsHosts(getByIp("8.8.4.4"), getByIp("8.8.8.8")) .bootstrapDnsHosts(getByIp("8.8.4.4"), getByIp("8.8.8.8"))
.post(true) .post(true)
.build() .build()
}
private fun buildCloudflareIp(bootstrapClient: OkHttpClient): DnsOverHttps { private fun buildCloudflareIp(bootstrapClient: OkHttpClient): DnsOverHttps =
return DnsOverHttps.Builder() DnsOverHttps
.Builder()
.client(bootstrapClient) .client(bootstrapClient)
.url("https://1.1.1.1/dns-query".toHttpUrl()) .url("https://1.1.1.1/dns-query".toHttpUrl())
.includeIPv6(false) .includeIPv6(false)
.build() .build()
}
private fun buildCloudflare(bootstrapClient: OkHttpClient): DnsOverHttps { private fun buildCloudflare(bootstrapClient: OkHttpClient): DnsOverHttps =
return DnsOverHttps.Builder() DnsOverHttps
.Builder()
.client(bootstrapClient) .client(bootstrapClient)
.url("https://1.1.1.1/dns-query".toHttpUrl()) .url("https://1.1.1.1/dns-query".toHttpUrl())
.bootstrapDnsHosts(getByIp("1.1.1.1"), getByIp("1.0.0.1")) .bootstrapDnsHosts(getByIp("1.1.1.1"), getByIp("1.0.0.1"))
.includeIPv6(false) .includeIPv6(false)
.build() .build()
}
private fun buildCloudflarePost(bootstrapClient: OkHttpClient): DnsOverHttps { private fun buildCloudflarePost(bootstrapClient: OkHttpClient): DnsOverHttps =
return DnsOverHttps.Builder() DnsOverHttps
.Builder()
.client(bootstrapClient) .client(bootstrapClient)
.url("https://cloudflare-dns.com/dns-query".toHttpUrl()) .url("https://cloudflare-dns.com/dns-query".toHttpUrl())
.bootstrapDnsHosts(getByIp("1.1.1.1"), getByIp("1.0.0.1")) .bootstrapDnsHosts(getByIp("1.1.1.1"), getByIp("1.0.0.1"))
.includeIPv6(false) .includeIPv6(false)
.post(true) .post(true)
.build() .build()
}
fun buildCleanBrowsing(bootstrapClient: OkHttpClient): DnsOverHttps { fun buildCleanBrowsing(bootstrapClient: OkHttpClient): DnsOverHttps =
return DnsOverHttps.Builder() DnsOverHttps
.Builder()
.client(bootstrapClient) .client(bootstrapClient)
.url("https://doh.cleanbrowsing.org/doh/family-filter/".toHttpUrl()) .url("https://doh.cleanbrowsing.org/doh/family-filter/".toHttpUrl())
.includeIPv6(false) .includeIPv6(false)
.build() .build()
}
private fun buildChantra(bootstrapClient: OkHttpClient): DnsOverHttps { private fun buildChantra(bootstrapClient: OkHttpClient): DnsOverHttps =
return DnsOverHttps.Builder() DnsOverHttps
.Builder()
.client(bootstrapClient) .client(bootstrapClient)
.url("https://dns.dnsoverhttps.net/dns-query".toHttpUrl()) .url("https://dns.dnsoverhttps.net/dns-query".toHttpUrl())
.includeIPv6(false) .includeIPv6(false)
.build() .build()
}
private fun buildCryptoSx(bootstrapClient: OkHttpClient): DnsOverHttps { private fun buildCryptoSx(bootstrapClient: OkHttpClient): DnsOverHttps =
return DnsOverHttps.Builder() DnsOverHttps
.Builder()
.client(bootstrapClient) .client(bootstrapClient)
.url("https://doh.crypto.sx/dns-query".toHttpUrl()) .url("https://doh.crypto.sx/dns-query".toHttpUrl())
.includeIPv6(false) .includeIPv6(false)
.build() .build()
}
@JvmStatic @JvmStatic
fun providers( fun providers(
@ -100,8 +100,8 @@ object DohProviders {
http2Only: Boolean, http2Only: Boolean,
workingOnly: Boolean, workingOnly: Boolean,
getOnly: Boolean, getOnly: Boolean,
): List<DnsOverHttps> { ): List<DnsOverHttps> =
return buildList { buildList {
add(buildGoogle(client)) add(buildGoogle(client))
if (!getOnly) { if (!getOnly) {
add(buildGooglePost(client)) add(buildGooglePost(client))
@ -117,14 +117,12 @@ object DohProviders {
} }
add(buildChantra(client)) add(buildChantra(client))
} }
}
private fun getByIp(host: String): InetAddress { private fun getByIp(host: String): InetAddress =
return try { try {
InetAddress.getByName(host) InetAddress.getByName(host)
} catch (e: UnknownHostException) { } catch (e: UnknownHostException) {
// unlikely // unlikely
throw RuntimeException(e) throw RuntimeException(e)
} }
}
} }

View File

@ -74,7 +74,8 @@ fun main() {
val url = "https://dns.cloudflare.com/.not-so-well-known/run-dmc-query".toHttpUrl() val url = "https://dns.cloudflare.com/.not-so-well-known/run-dmc-query".toHttpUrl()
val badProviders = val badProviders =
listOf( listOf(
DnsOverHttps.Builder() DnsOverHttps
.Builder()
.client(bootstrapClient) .client(bootstrapClient)
.url(url) .url(url)
.post(true) .post(true)
@ -84,7 +85,8 @@ fun main() {
println("cached first run\n****************\n") println("cached first run\n****************\n")
names = listOf("google.com", "graph.facebook.com") names = listOf("google.com", "graph.facebook.com")
bootstrapClient = bootstrapClient =
bootstrapClient.newBuilder() bootstrapClient
.newBuilder()
.cache(dnsCache) .cache(dnsCache)
.build() .build()
dnsProviders = dnsProviders =

View File

@ -36,23 +36,21 @@ import okio.source
object HpackJsonUtil { object HpackJsonUtil {
@Suppress("unused") @Suppress("unused")
private val MOSHI = private val MOSHI =
Moshi.Builder() Moshi
.Builder()
.add( .add(
object : Any() { object : Any() {
@ToJson fun byteStringToJson(byteString: ByteString) = byteString.hex() @ToJson fun byteStringToJson(byteString: ByteString) = byteString.hex()
@FromJson fun byteStringFromJson(json: String) = json.decodeHex() @FromJson fun byteStringFromJson(json: String) = json.decodeHex()
}, },
) ).add(KotlinJsonAdapterFactory())
.add(KotlinJsonAdapterFactory())
.build() .build()
private val STORY_JSON_ADAPTER = MOSHI.adapter(Story::class.java) private val STORY_JSON_ADAPTER = MOSHI.adapter(Story::class.java)
private val fileSystem = FileSystem.SYSTEM private val fileSystem = FileSystem.SYSTEM
private fun readStory(source: BufferedSource): Story { private fun readStory(source: BufferedSource): Story = STORY_JSON_ADAPTER.fromJson(source)!!
return STORY_JSON_ADAPTER.fromJson(source)!!
}
private fun readStory(file: Path): Story { private fun readStory(file: Path): Story {
fileSystem.read(file) { fileSystem.read(file) {

View File

@ -55,9 +55,11 @@ fun generateMappingTableFile(data: IdnaMappingTableData): FileSpec {
val packageName = "okhttp3.internal.idn" val packageName = "okhttp3.internal.idn"
val idnaMappingTable = ClassName(packageName, "IdnaMappingTable") val idnaMappingTable = ClassName(packageName, "IdnaMappingTable")
return FileSpec.builder(packageName, "IdnaMappingTableInstance") return FileSpec
.builder(packageName, "IdnaMappingTableInstance")
.addProperty( .addProperty(
PropertySpec.builder("IDNA_MAPPING_TABLE", idnaMappingTable) PropertySpec
.builder("IDNA_MAPPING_TABLE", idnaMappingTable)
.addModifiers(KModifier.INTERNAL) .addModifiers(KModifier.INTERNAL)
.initializer( .initializer(
""" """
@ -71,18 +73,16 @@ fun generateMappingTableFile(data: IdnaMappingTableData): FileSpec {
data.sections.escapeDataString(), data.sections.escapeDataString(),
data.ranges.escapeDataString(), data.ranges.escapeDataString(),
data.mappings.escapeDataString(), data.mappings.escapeDataString(),
) ).build(),
.build(), ).build()
)
.build()
} }
/** /**
* KotlinPoet doesn't really know what to do with a string containing NUL, BEL, DEL, etc. We also * KotlinPoet doesn't really know what to do with a string containing NUL, BEL, DEL, etc. We also
* don't want to perform `trimMargin()` at runtime. * don't want to perform `trimMargin()` at runtime.
*/ */
fun String.escapeDataString(): String { fun String.escapeDataString(): String =
return buildString { buildString {
for (codePoint in this@escapeDataString.codePoints()) { for (codePoint in this@escapeDataString.codePoints()) {
when (codePoint) { when (codePoint) {
in 0..0x20, in 0..0x20,
@ -97,4 +97,3 @@ fun String.escapeDataString(): String {
} }
} }
} }
}

View File

@ -110,7 +110,11 @@ fun buildIdnaMappingTableData(table: SimpleIdnaMappingTable): IdnaMappingTableDa
internal fun inlineDeltaOrNull(mapping: Mapping): MappedRange.InlineDelta? { internal fun inlineDeltaOrNull(mapping: Mapping): MappedRange.InlineDelta? {
if (mapping.hasSingleSourceCodePoint) { if (mapping.hasSingleSourceCodePoint) {
val sourceCodePoint = mapping.sourceCodePoint0 val sourceCodePoint = mapping.sourceCodePoint0
val mappedCodePoints = mapping.mappedTo.utf8().codePoints().toList() val mappedCodePoints =
mapping.mappedTo
.utf8()
.codePoints()
.toList()
if (mappedCodePoints.size == 1) { if (mappedCodePoints.size == 1) {
val codePointDelta = mappedCodePoints.single() - sourceCodePoint val codePointDelta = mappedCodePoints.single() - sourceCodePoint
if (MappedRange.InlineDelta.MAX_VALUE >= abs(codePointDelta)) { if (MappedRange.InlineDelta.MAX_VALUE >= abs(codePointDelta)) {
@ -262,8 +266,8 @@ internal fun mergeAdjacentRanges(mappings: List<Mapping>): List<Mapping> {
return result return result
} }
internal fun canonicalizeType(type: Int): Int { internal fun canonicalizeType(type: Int): Int =
return when (type) { when (type) {
TYPE_IGNORED -> TYPE_IGNORED TYPE_IGNORED -> TYPE_IGNORED
TYPE_MAPPED, TYPE_MAPPED,
@ -279,7 +283,6 @@ internal fun canonicalizeType(type: Int): Int {
else -> error("unexpected type: $type") else -> error("unexpected type: $type")
} }
}
internal infix fun Byte.and(mask: Int): Int = toInt() and mask internal infix fun Byte.and(mask: Int): Int = toInt() and mask

View File

@ -225,10 +225,11 @@ class MappingTablesTest {
sourceCodePoint1 = sourceCodePoint1, sourceCodePoint1 = sourceCodePoint1,
type = TYPE_MAPPED, type = TYPE_MAPPED,
mappedTo = mappedTo =
Buffer().also { Buffer()
for (cp in mappedToCodePoints) { .also {
it.writeUtf8CodePoint(cp) for (cp in mappedToCodePoints) {
} it.writeUtf8CodePoint(cp)
}.readByteString(), }
}.readByteString(),
) )
} }

View File

@ -31,7 +31,9 @@ import okhttp3.internal.platform.Platform.Companion.WARN
import okhttp3.internal.trimSubstring import okhttp3.internal.trimSubstring
/** A cookie jar that delegates to a [java.net.CookieHandler]. */ /** A cookie jar that delegates to a [java.net.CookieHandler]. */
class JavaNetCookieJar(private val cookieHandler: CookieHandler) : CookieJar { class JavaNetCookieJar(
private val cookieHandler: CookieHandler,
) : CookieJar {
override fun saveFromResponse( override fun saveFromResponse(
url: HttpUrl, url: HttpUrl,
cookies: List<Cookie>, cookies: List<Cookie>,
@ -112,7 +114,8 @@ class JavaNetCookieJar(private val cookieHandler: CookieHandler) : CookieJar {
} }
result.add( result.add(
Cookie.Builder() Cookie
.Builder()
.name(name) .name(name)
.value(value) .value(value)
.domain(url.host) .domain(url.host)

View File

@ -326,14 +326,17 @@ class HttpLoggingInterceptor
if (queryParamsNameToRedact.isEmpty() || url.querySize == 0) { if (queryParamsNameToRedact.isEmpty() || url.querySize == 0) {
return url.toString() return url.toString()
} }
return url.newBuilder().query(null).apply { return url
for (i in 0 until url.querySize) { .newBuilder()
val parameterName = url.queryParameterName(i) .query(null)
val newValue = if (parameterName in queryParamsNameToRedact) "██" else url.queryParameterValue(i) .apply {
for (i in 0 until url.querySize) {
val parameterName = url.queryParameterName(i)
val newValue = if (parameterName in queryParamsNameToRedact) "██" else url.queryParameterValue(i)
addEncodedQueryParameter(parameterName, newValue) addEncodedQueryParameter(parameterName, newValue)
} }
}.toString() }.toString()
} }
private fun logHeader( private fun logHeader(

View File

@ -73,7 +73,8 @@ class HttpLoggingInterceptorTest {
fun setUp(server: MockWebServer) { fun setUp(server: MockWebServer) {
this.server = server this.server = server
client = client =
OkHttpClient.Builder() OkHttpClient
.Builder()
.addNetworkInterceptor( .addNetworkInterceptor(
Interceptor { chain -> Interceptor { chain ->
when { when {
@ -81,14 +82,12 @@ class HttpLoggingInterceptorTest {
else -> chain.proceed(chain.request()) else -> chain.proceed(chain.request())
} }
}, },
) ).addNetworkInterceptor(networkInterceptor)
.addNetworkInterceptor(networkInterceptor)
.addInterceptor(applicationInterceptor) .addInterceptor(applicationInterceptor)
.sslSocketFactory( .sslSocketFactory(
handshakeCertificates.sslSocketFactory(), handshakeCertificates.sslSocketFactory(),
handshakeCertificates.trustManager, handshakeCertificates.trustManager,
) ).hostnameVerifier(hostnameVerifier)
.hostnameVerifier(hostnameVerifier)
.build() .build()
host = "${server.hostName}:${server.port}" host = "${server.hostName}:${server.port}"
url = server.url("/") url = server.url("/")
@ -153,7 +152,8 @@ class HttpLoggingInterceptorTest {
fun basicResponseBody() { fun basicResponseBody() {
setLevel(Level.BASIC) setLevel(Level.BASIC)
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.body("Hello!") .body("Hello!")
.setHeader("Content-Type", PLAIN) .setHeader("Content-Type", PLAIN)
.build(), .build(),
@ -174,7 +174,8 @@ class HttpLoggingInterceptorTest {
fun basicChunkedResponseBody() { fun basicChunkedResponseBody() {
setLevel(Level.BASIC) setLevel(Level.BASIC)
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.chunkedBody("Hello!", 2) .chunkedBody("Hello!", 2)
.setHeader("Content-Type", PLAIN) .setHeader("Content-Type", PLAIN)
.build(), .build(),
@ -320,7 +321,8 @@ class HttpLoggingInterceptorTest {
extraNetworkInterceptor = extraNetworkInterceptor =
Interceptor { chain: Interceptor.Chain -> Interceptor { chain: Interceptor.Chain ->
chain.proceed( chain.proceed(
chain.request() chain
.request()
.newBuilder() .newBuilder()
.header("Content-Length", "2") .header("Content-Length", "2")
.header("Content-Type", "text/plain-ish") .header("Content-Type", "text/plain-ish")
@ -328,11 +330,12 @@ class HttpLoggingInterceptorTest {
) )
} }
server.enqueue(MockResponse()) server.enqueue(MockResponse())
client.newCall( client
request() .newCall(
.post("Hi?".toRequestBody(PLAIN)) request()
.build(), .post("Hi?".toRequestBody(PLAIN))
).execute() .build(),
).execute()
applicationLogs applicationLogs
.assertLogEqual("--> POST $url") .assertLogEqual("--> POST $url")
.assertLogEqual("Content-Type: text/plain; charset=utf-8") .assertLogEqual("Content-Type: text/plain; charset=utf-8")
@ -361,7 +364,8 @@ class HttpLoggingInterceptorTest {
fun headersResponseBody() { fun headersResponseBody() {
setLevel(Level.HEADERS) setLevel(Level.HEADERS)
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.body("Hello!") .body("Hello!")
.setHeader("Content-Type", PLAIN) .setHeader("Content-Type", PLAIN)
.build(), .build(),
@ -430,7 +434,8 @@ class HttpLoggingInterceptorTest {
private fun bodyGetNoBody(code: Int) { private fun bodyGetNoBody(code: Int) {
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.status("HTTP/1.1 $code No Content") .status("HTTP/1.1 $code No Content")
.build(), .build(),
) )
@ -495,7 +500,8 @@ class HttpLoggingInterceptorTest {
fun bodyResponseBody() { fun bodyResponseBody() {
setLevel(Level.BODY) setLevel(Level.BODY)
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.body("Hello!") .body("Hello!")
.setHeader("Content-Type", PLAIN) .setHeader("Content-Type", PLAIN)
.build(), .build(),
@ -532,7 +538,8 @@ class HttpLoggingInterceptorTest {
fun bodyResponseBodyChunked() { fun bodyResponseBodyChunked() {
setLevel(Level.BODY) setLevel(Level.BODY)
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.chunkedBody("Hello!", 2) .chunkedBody("Hello!", 2)
.setHeader("Content-Type", PLAIN) .setHeader("Content-Type", PLAIN)
.build(), .build(),
@ -569,18 +576,20 @@ class HttpLoggingInterceptorTest {
fun bodyRequestGzipEncoded() { fun bodyRequestGzipEncoded() {
setLevel(Level.BODY) setLevel(Level.BODY)
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.setHeader("Content-Type", PLAIN) .setHeader("Content-Type", PLAIN)
.body(Buffer().writeUtf8("Uncompressed")) .body(Buffer().writeUtf8("Uncompressed"))
.build(), .build(),
) )
val response = val response =
client.newCall( client
request() .newCall(
.addHeader("Content-Encoding", "gzip") request()
.post("Uncompressed".toRequestBody().gzip()) .addHeader("Content-Encoding", "gzip")
.build(), .post("Uncompressed".toRequestBody().gzip())
).execute() .build(),
).execute()
val responseBody = response.body val responseBody = response.body
assertThat(responseBody.string(), "Expected response body to be valid") assertThat(responseBody.string(), "Expected response body to be valid")
.isEqualTo("Uncompressed") .isEqualTo("Uncompressed")
@ -608,7 +617,8 @@ class HttpLoggingInterceptorTest {
fun bodyResponseGzipEncoded() { fun bodyResponseGzipEncoded() {
setLevel(Level.BODY) setLevel(Level.BODY)
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.setHeader("Content-Encoding", "gzip") .setHeader("Content-Encoding", "gzip")
.setHeader("Content-Type", PLAIN) .setHeader("Content-Type", PLAIN)
.body(Buffer().write("H4sIAAAAAAAAAPNIzcnJ11HwQKIAdyO+9hMAAAA=".decodeBase64()!!)) .body(Buffer().write("H4sIAAAAAAAAAPNIzcnJ11HwQKIAdyO+9hMAAAA=".decodeBase64()!!))
@ -649,7 +659,8 @@ class HttpLoggingInterceptorTest {
fun bodyResponseUnknownEncoded() { fun bodyResponseUnknownEncoded() {
setLevel(Level.BODY) setLevel(Level.BODY)
server.enqueue( server.enqueue(
MockResponse.Builder() // It's invalid to return this if not requested, but the server might anyway MockResponse
.Builder() // It's invalid to return this if not requested, but the server might anyway
.setHeader("Content-Encoding", "br") .setHeader("Content-Encoding", "br")
.setHeader("Content-Type", PLAIN) .setHeader("Content-Type", PLAIN)
.body(Buffer().write("iwmASGVsbG8sIEhlbGxvLCBIZWxsbwoD".decodeBase64()!!)) .body(Buffer().write("iwmASGVsbG8sIEhlbGxvLCBIZWxsbwoD".decodeBase64()!!))
@ -685,7 +696,8 @@ class HttpLoggingInterceptorTest {
fun bodyResponseIsStreaming() { fun bodyResponseIsStreaming() {
setLevel(Level.BODY) setLevel(Level.BODY)
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.setHeader("Content-Type", "text/event-stream") .setHeader("Content-Type", "text/event-stream")
.chunkedBody( .chunkedBody(
""" """
@ -701,8 +713,7 @@ class HttpLoggingInterceptorTest {
| |
""".trimMargin(), """.trimMargin(),
8, 8,
) ).build(),
.build(),
) )
val response = client.newCall(request().build()).execute() val response = client.newCall(request().build()).execute()
response.body.close() response.body.close()
@ -732,7 +743,8 @@ class HttpLoggingInterceptorTest {
fun bodyGetMalformedCharset() { fun bodyGetMalformedCharset() {
setLevel(Level.BODY) setLevel(Level.BODY)
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.setHeader("Content-Type", "text/html; charset=0") .setHeader("Content-Type", "text/html; charset=0")
.body("Body with unknown charset") .body("Body with unknown charset")
.build(), .build(),
@ -778,7 +790,8 @@ class HttpLoggingInterceptorTest {
buffer.writeUtf8CodePoint(0x1a) buffer.writeUtf8CodePoint(0x1a)
buffer.writeUtf8CodePoint(0x0a) buffer.writeUtf8CodePoint(0x0a)
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.body(buffer) .body(buffer)
.setHeader("Content-Type", "image/png; charset=utf-8") .setHeader("Content-Type", "image/png; charset=utf-8")
.build(), .build(),
@ -813,7 +826,8 @@ class HttpLoggingInterceptorTest {
fun connectFail() { fun connectFail() {
setLevel(Level.BASIC) setLevel(Level.BASIC)
client = client =
OkHttpClient.Builder() OkHttpClient
.Builder()
.dns { hostname: String? -> throw UnknownHostException("reason") } .dns { hostname: String? -> throw UnknownHostException("reason") }
.addInterceptor(applicationInterceptor) .addInterceptor(applicationInterceptor)
.build() .build()
@ -859,13 +873,16 @@ class HttpLoggingInterceptorTest {
) )
applicationInterceptor.redactHeader("sEnSiTiVe") applicationInterceptor.redactHeader("sEnSiTiVe")
client = client =
OkHttpClient.Builder() OkHttpClient
.Builder()
.addNetworkInterceptor(networkInterceptor) .addNetworkInterceptor(networkInterceptor)
.addInterceptor(applicationInterceptor) .addInterceptor(applicationInterceptor)
.build() .build()
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.addHeader("SeNsItIvE", "Value").addHeader("Not-Sensitive", "Value") .Builder()
.addHeader("SeNsItIvE", "Value")
.addHeader("Not-Sensitive", "Value")
.build(), .build(),
) )
val response = val response =
@ -875,8 +892,7 @@ class HttpLoggingInterceptorTest {
.addHeader("SeNsItIvE", "Value") .addHeader("SeNsItIvE", "Value")
.addHeader("Not-Sensitive", "Value") .addHeader("Not-Sensitive", "Value")
.build(), .build(),
) ).execute()
.execute()
response.body.close() response.body.close()
applicationLogs applicationLogs
.assertLogEqual("--> GET $url") .assertLogEqual("--> GET $url")
@ -923,12 +939,14 @@ class HttpLoggingInterceptorTest {
applicationInterceptor.redactQueryParams("user", "PassworD") applicationInterceptor.redactQueryParams("user", "PassworD")
client = client =
OkHttpClient.Builder() OkHttpClient
.Builder()
.addNetworkInterceptor(networkInterceptor) .addNetworkInterceptor(networkInterceptor)
.addInterceptor(applicationInterceptor) .addInterceptor(applicationInterceptor)
.build() .build()
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.build(), .build(),
) )
val response = val response =
@ -936,8 +954,7 @@ class HttpLoggingInterceptorTest {
.newCall( .newCall(
request() request()
.build(), .build(),
) ).execute()
.execute()
response.body.close() response.body.close()
val redactedUrl = networkInterceptor.redactUrl(url) val redactedUrl = networkInterceptor.redactUrl(url)
val redactedUrlPattern = redactedUrl.replace("?", """\?""") val redactedUrlPattern = redactedUrl.replace("?", """\?""")
@ -976,12 +993,14 @@ class HttpLoggingInterceptorTest {
applicationInterceptor.redactQueryParams("user", "PassworD") applicationInterceptor.redactQueryParams("user", "PassworD")
client = client =
OkHttpClient.Builder() OkHttpClient
.Builder()
.addNetworkInterceptor(networkInterceptor) .addNetworkInterceptor(networkInterceptor)
.addInterceptor(applicationInterceptor) .addInterceptor(applicationInterceptor)
.build() .build()
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.build(), .build(),
) )
val response = val response =
@ -989,8 +1008,7 @@ class HttpLoggingInterceptorTest {
.newCall( .newCall(
request() request()
.build(), .build(),
) ).execute()
.execute()
response.body.close() response.body.close()
val redactedUrl = networkInterceptor.redactUrl(url) val redactedUrl = networkInterceptor.redactUrl(url)
val redactedUrlPattern = redactedUrl.replace("?", """\?""") val redactedUrlPattern = redactedUrl.replace("?", """\?""")
@ -1011,24 +1029,21 @@ class HttpLoggingInterceptorTest {
url = server.url("/") url = server.url("/")
setLevel(Level.BODY) setLevel(Level.BODY)
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.body("Hello response!") .body("Hello response!")
.build(), .build(),
) )
val asyncRequestBody: RequestBody = val asyncRequestBody: RequestBody =
object : RequestBody() { object : RequestBody() {
override fun contentType(): MediaType? { override fun contentType(): MediaType? = null
return null
}
override fun writeTo(sink: BufferedSink) { override fun writeTo(sink: BufferedSink) {
sink.writeUtf8("Hello request!") sink.writeUtf8("Hello request!")
sink.close() sink.close()
} }
override fun isDuplex(): Boolean { override fun isDuplex(): Boolean = true
return true
}
} }
val request = val request =
request() request()
@ -1053,7 +1068,8 @@ class HttpLoggingInterceptorTest {
url = server.url("/") url = server.url("/")
setLevel(Level.BODY) setLevel(Level.BODY)
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.body("Hello response!") .body("Hello response!")
.build(), .build(),
) )
@ -1089,9 +1105,7 @@ class HttpLoggingInterceptorTest {
.assertNoMoreLogs() .assertNoMoreLogs()
} }
private fun request(): Request.Builder { private fun request(): Request.Builder = Request.Builder().url(url)
return Request.Builder().url(url)
}
internal class LogRecorder( internal class LogRecorder(
val prefix: Regex = Regex(""), val prefix: Regex = Regex(""),

View File

@ -62,13 +62,13 @@ class LoggingEventListenerTest {
fun setUp(server: MockWebServer) { fun setUp(server: MockWebServer) {
this.server = server this.server = server
client = client =
clientTestRule.newClientBuilder() clientTestRule
.newClientBuilder()
.eventListenerFactory(loggingEventListenerFactory) .eventListenerFactory(loggingEventListenerFactory)
.sslSocketFactory( .sslSocketFactory(
handshakeCertificates.sslSocketFactory(), handshakeCertificates.sslSocketFactory(),
handshakeCertificates.trustManager, handshakeCertificates.trustManager,
) ).retryOnConnectionFailure(false)
.retryOnConnectionFailure(false)
.build() .build()
url = server.url("/") url = server.url("/")
} }
@ -77,7 +77,8 @@ class LoggingEventListenerTest {
fun get() { fun get() {
assumeNotWindows() assumeNotWindows()
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.body("Hello!") .body("Hello!")
.setHeader("Content-Type", PLAIN) .setHeader("Content-Type", PLAIN)
.build(), .build(),
@ -97,8 +98,7 @@ class LoggingEventListenerTest {
Regex( Regex(
"""connectionAcquired: Connection\{${url.host}:\d+, proxy=DIRECT hostAddress=${url.host}/.+ cipherSuite=none protocol=http/1\.1\}""", """connectionAcquired: Connection\{${url.host}:\d+, proxy=DIRECT hostAddress=${url.host}/.+ cipherSuite=none protocol=http/1\.1\}""",
), ),
) ).assertLogMatch(Regex("""requestHeadersStart"""))
.assertLogMatch(Regex("""requestHeadersStart"""))
.assertLogMatch(Regex("""requestHeadersEnd""")) .assertLogMatch(Regex("""requestHeadersEnd"""))
.assertLogMatch(Regex("""responseHeadersStart""")) .assertLogMatch(Regex("""responseHeadersStart"""))
.assertLogMatch(Regex("""responseHeadersEnd: Response\{protocol=http/1\.1, code=200, message=OK, url=$url\}""")) .assertLogMatch(Regex("""responseHeadersEnd: Response\{protocol=http/1\.1, code=200, message=OK, url=$url\}"""))
@ -126,8 +126,7 @@ class LoggingEventListenerTest {
Regex( Regex(
"""connectionAcquired: Connection\{${url.host}:\d+, proxy=DIRECT hostAddress=${url.host}/.+ cipherSuite=none protocol=http/1\.1\}""", """connectionAcquired: Connection\{${url.host}:\d+, proxy=DIRECT hostAddress=${url.host}/.+ cipherSuite=none protocol=http/1\.1\}""",
), ),
) ).assertLogMatch(Regex("""requestHeadersStart"""))
.assertLogMatch(Regex("""requestHeadersStart"""))
.assertLogMatch(Regex("""requestHeadersEnd""")) .assertLogMatch(Regex("""requestHeadersEnd"""))
.assertLogMatch(Regex("""requestBodyStart""")) .assertLogMatch(Regex("""requestBodyStart"""))
.assertLogMatch(Regex("""requestBodyEnd: byteCount=6""")) .assertLogMatch(Regex("""requestBodyEnd: byteCount=6"""))
@ -162,12 +161,10 @@ class LoggingEventListenerTest {
Regex( Regex(
"""secureConnectEnd: Handshake\{tlsVersion=TLS_1_[23] cipherSuite=TLS_.* peerCertificates=\[CN=localhost] localCertificates=\[]\}""", """secureConnectEnd: Handshake\{tlsVersion=TLS_1_[23] cipherSuite=TLS_.* peerCertificates=\[CN=localhost] localCertificates=\[]\}""",
), ),
) ).assertLogMatch(Regex("""connectEnd: h2"""))
.assertLogMatch(Regex("""connectEnd: h2"""))
.assertLogMatch( .assertLogMatch(
Regex("""connectionAcquired: Connection\{${url.host}:\d+, proxy=DIRECT hostAddress=${url.host}/.+ cipherSuite=.+ protocol=h2\}"""), Regex("""connectionAcquired: Connection\{${url.host}:\d+, proxy=DIRECT hostAddress=${url.host}/.+ cipherSuite=.+ protocol=h2\}"""),
) ).assertLogMatch(Regex("""requestHeadersStart"""))
.assertLogMatch(Regex("""requestHeadersStart"""))
.assertLogMatch(Regex("""requestHeadersEnd""")) .assertLogMatch(Regex("""requestHeadersEnd"""))
.assertLogMatch(Regex("""responseHeadersStart""")) .assertLogMatch(Regex("""responseHeadersStart"""))
.assertLogMatch(Regex("""responseHeadersEnd: Response\{protocol=h2, code=200, message=, url=$url\}""")) .assertLogMatch(Regex("""responseHeadersEnd: Response\{protocol=h2, code=200, message=, url=$url\}"""))
@ -181,7 +178,8 @@ class LoggingEventListenerTest {
@Test @Test
fun dnsFail() { fun dnsFail() {
client = client =
OkHttpClient.Builder() OkHttpClient
.Builder()
.dns { _ -> throw UnknownHostException("reason") } .dns { _ -> throw UnknownHostException("reason") }
.eventListenerFactory(loggingEventListenerFactory) .eventListenerFactory(loggingEventListenerFactory)
.build() .build()
@ -205,7 +203,8 @@ class LoggingEventListenerTest {
server.useHttps(handshakeCertificates.sslSocketFactory()) server.useHttps(handshakeCertificates.sslSocketFactory())
server.protocols = Arrays.asList(Protocol.HTTP_2, Protocol.HTTP_1_1) server.protocols = Arrays.asList(Protocol.HTTP_2, Protocol.HTTP_1_1)
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.socketPolicy(FailHandshake) .socketPolicy(FailHandshake)
.build(), .build(),
) )
@ -227,13 +226,11 @@ class LoggingEventListenerTest {
Regex( Regex(
"""connectFailed: null \S+(?:SSLProtocolException|SSLHandshakeException|TlsFatalAlert): (?:Unexpected handshake message: client_hello|Handshake message sequence violation, 1|Read error|Handshake failed|unexpected_message\(10\)).*""", """connectFailed: null \S+(?:SSLProtocolException|SSLHandshakeException|TlsFatalAlert): (?:Unexpected handshake message: client_hello|Handshake message sequence violation, 1|Read error|Handshake failed|unexpected_message\(10\)).*""",
), ),
) ).assertLogMatch(
.assertLogMatch(
Regex( Regex(
"""callFailed: \S+(?:SSLProtocolException|SSLHandshakeException|TlsFatalAlert): (?:Unexpected handshake message: client_hello|Handshake message sequence violation, 1|Read error|Handshake failed|unexpected_message\(10\)).*""", """callFailed: \S+(?:SSLProtocolException|SSLHandshakeException|TlsFatalAlert): (?:Unexpected handshake message: client_hello|Handshake message sequence violation, 1|Read error|Handshake failed|unexpected_message\(10\)).*""",
), ),
) ).assertNoMoreLogs()
.assertNoMoreLogs()
} }
@Test @Test
@ -241,7 +238,12 @@ class LoggingEventListenerTest {
val request = Request.Builder().url(url).build() val request = Request.Builder().url(url).build()
val call = client.newCall(request) val call = client.newCall(request)
val response = val response =
Response.Builder().request(request).code(200).message("").protocol(Protocol.HTTP_2) Response
.Builder()
.request(request)
.code(200)
.message("")
.protocol(Protocol.HTTP_2)
.build() .build()
val listener = loggingEventListenerFactory.create(call) val listener = loggingEventListenerFactory.create(call)
listener.cacheConditionalHit(call, response) listener.cacheConditionalHit(call, response)
@ -256,9 +258,7 @@ class LoggingEventListenerTest {
.assertNoMoreLogs() .assertNoMoreLogs()
} }
private fun request(): Request.Builder { private fun request(): Request.Builder = Request.Builder().url(url)
return Request.Builder().url(url)
}
companion object { companion object {
private val PLAIN = "text/plain".toMediaType() private val PLAIN = "text/plain".toMediaType()

View File

@ -116,7 +116,8 @@ class OsgiTest {
private fun RepositoryPlugin.deployClassPath() { private fun RepositoryPlugin.deployClassPath() {
val classpath = System.getProperty("java.class.path") val classpath = System.getProperty("java.class.path")
val entries = val entries =
classpath.split(File.pathSeparator.toRegex()) classpath
.split(File.pathSeparator.toRegex())
.dropLastWhile { it.isEmpty() } .dropLastWhile { it.isEmpty() }
.toTypedArray() .toTypedArray()
for (classPathEntry in entries) { for (classPathEntry in entries) {

View File

@ -29,8 +29,8 @@ object EventSources {
fun createFactory(client: OkHttpClient) = createFactory(client as Call.Factory) fun createFactory(client: OkHttpClient) = createFactory(client as Call.Factory)
@JvmStatic @JvmStatic
fun createFactory(callFactory: Call.Factory): EventSource.Factory { fun createFactory(callFactory: Call.Factory): EventSource.Factory =
return EventSource.Factory { request, listener -> EventSource.Factory { request, listener ->
val actualRequest = val actualRequest =
if (request.header("Accept") == null) { if (request.header("Accept") == null) {
request.newBuilder().addHeader("Accept", "text/event-stream").build() request.newBuilder().addHeader("Accept", "text/event-stream").build()
@ -42,7 +42,6 @@ object EventSources {
connect(callFactory) connect(callFactory)
} }
} }
}
@JvmStatic @JvmStatic
fun processResponse( fun processResponse(

View File

@ -28,7 +28,9 @@ import okhttp3.sse.EventSourceListener
internal class RealEventSource( internal class RealEventSource(
private val request: Request, private val request: Request,
private val listener: EventSourceListener, private val listener: EventSourceListener,
) : EventSource, ServerSentEventReader.Callback, Callback { ) : EventSource,
ServerSentEventReader.Callback,
Callback {
private var call: Call? = null private var call: Call? = null
@Volatile private var canceled = false @Volatile private var canceled = false

View File

@ -48,7 +48,8 @@ class EventSourceHttpTest {
private val eventListener = RecordingEventListener() private val eventListener = RecordingEventListener()
private val listener = EventSourceRecorder() private val listener = EventSourceRecorder()
private var client = private var client =
clientTestRule.newClientBuilder() clientTestRule
.newClientBuilder()
.eventListenerFactory(clientTestRule.wrap(eventListener)) .eventListenerFactory(clientTestRule.wrap(eventListener))
.build() .build()
@ -65,7 +66,8 @@ class EventSourceHttpTest {
@Test @Test
fun event() { fun event() {
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.body( .body(
""" """
|data: hey |data: hey
@ -85,7 +87,8 @@ class EventSourceHttpTest {
@RetryingTest(5) @RetryingTest(5)
fun cancelInEventShortCircuits() { fun cancelInEventShortCircuits() {
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.body( .body(
""" """
|data: hey |data: hey
@ -104,7 +107,8 @@ class EventSourceHttpTest {
@Test @Test
fun badContentType() { fun badContentType() {
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.body( .body(
""" """
|data: hey |data: hey
@ -121,15 +125,15 @@ class EventSourceHttpTest {
@Test @Test
fun badResponseCode() { fun badResponseCode() {
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.body( .body(
""" """
|data: hey |data: hey
| |
| |
""".trimMargin(), """.trimMargin(),
) ).setHeader("content-type", "text/event-stream")
.setHeader("content-type", "text/event-stream")
.code(401) .code(401)
.build(), .build(),
) )
@ -140,11 +144,13 @@ class EventSourceHttpTest {
@Test @Test
fun fullCallTimeoutDoesNotApplyOnceConnected() { fun fullCallTimeoutDoesNotApplyOnceConnected() {
client = client =
client.newBuilder() client
.newBuilder()
.callTimeout(250, TimeUnit.MILLISECONDS) .callTimeout(250, TimeUnit.MILLISECONDS)
.build() .build()
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.bodyDelay(500, TimeUnit.MILLISECONDS) .bodyDelay(500, TimeUnit.MILLISECONDS)
.setHeader("content-type", "text/event-stream") .setHeader("content-type", "text/event-stream")
.body("data: hey\n\n") .body("data: hey\n\n")
@ -160,11 +166,13 @@ class EventSourceHttpTest {
@Test @Test
fun fullCallTimeoutAppliesToSetup() { fun fullCallTimeoutAppliesToSetup() {
client = client =
client.newBuilder() client
.newBuilder()
.callTimeout(250, TimeUnit.MILLISECONDS) .callTimeout(250, TimeUnit.MILLISECONDS)
.build() .build()
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.headersDelay(500, TimeUnit.MILLISECONDS) .headersDelay(500, TimeUnit.MILLISECONDS)
.setHeader("content-type", "text/event-stream") .setHeader("content-type", "text/event-stream")
.body("data: hey\n\n") .body("data: hey\n\n")
@ -177,15 +185,15 @@ class EventSourceHttpTest {
@Test @Test
fun retainsAccept() { fun retainsAccept() {
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.body( .body(
""" """
|data: hey |data: hey
| |
| |
""".trimMargin(), """.trimMargin(),
) ).setHeader("content-type", "text/event-stream")
.setHeader("content-type", "text/event-stream")
.build(), .build(),
) )
newEventSource("text/plain") newEventSource("text/plain")
@ -198,7 +206,8 @@ class EventSourceHttpTest {
@Test @Test
fun setsMissingAccept() { fun setsMissingAccept() {
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.body( .body(
""" """
|data: hey |data: hey
@ -219,7 +228,8 @@ class EventSourceHttpTest {
@Test @Test
fun eventListenerEvents() { fun eventListenerEvents() {
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.body( .body(
""" """
|data: hey |data: hey
@ -256,7 +266,8 @@ class EventSourceHttpTest {
private fun newEventSource(accept: String? = null): EventSource { private fun newEventSource(accept: String? = null): EventSource {
val builder = val builder =
Request.Builder() Request
.Builder()
.url(server.url("/")) .url(server.url("/"))
if (accept != null) { if (accept != null) {
builder.header("Accept", accept) builder.header("Accept", accept)

View File

@ -78,10 +78,9 @@ class EventSourceRecorder : EventSourceListener() {
} }
} }
private fun nextEvent(): Any { private fun nextEvent(): Any =
return events.poll(10, TimeUnit.SECONDS) events.poll(10, TimeUnit.SECONDS)
?: throw AssertionError("Timed out waiting for event.") ?: throw AssertionError("Timed out waiting for event.")
}
fun assertExhausted() { fun assertExhausted() {
assertThat(events).isEmpty() assertThat(events).isEmpty()

View File

@ -55,7 +55,8 @@ class EventSourcesHttpTest {
@Test @Test
fun processResponse() { fun processResponse() {
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.body( .body(
""" """
|data: hey |data: hey
@ -66,7 +67,8 @@ class EventSourcesHttpTest {
.build(), .build(),
) )
val request = val request =
Request.Builder() Request
.Builder()
.url(server.url("/")) .url(server.url("/"))
.build() .build()
val response = client.newCall(request).execute() val response = client.newCall(request).execute()
@ -79,7 +81,8 @@ class EventSourcesHttpTest {
@Test @Test
fun cancelShortCircuits() { fun cancelShortCircuits() {
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse
.Builder()
.body( .body(
""" """
|data: hey |data: hey
@ -91,7 +94,8 @@ class EventSourcesHttpTest {
) )
listener.enqueueCancel() // Will cancel in onOpen(). listener.enqueueCancel() // Will cancel in onOpen().
val request = val request =
Request.Builder() Request
.Builder()
.url(server.url("/")) .url(server.url("/"))
.build() .build()
val response = client.newCall(request).execute() val response = client.newCall(request).execute()

View File

@ -23,30 +23,22 @@ import javax.net.ssl.SSLSessionContext
import javax.security.cert.X509Certificate import javax.security.cert.X509Certificate
/** An [SSLSession] that delegates all calls. */ /** An [SSLSession] that delegates all calls. */
abstract class DelegatingSSLSession(protected val delegate: SSLSession?) : SSLSession { abstract class DelegatingSSLSession(
override fun getId(): ByteArray { protected val delegate: SSLSession?,
return delegate!!.id ) : SSLSession {
} override fun getId(): ByteArray = delegate!!.id
override fun getSessionContext(): SSLSessionContext { override fun getSessionContext(): SSLSessionContext = delegate!!.sessionContext
return delegate!!.sessionContext
}
override fun getCreationTime(): Long { override fun getCreationTime(): Long = delegate!!.creationTime
return delegate!!.creationTime
}
override fun getLastAccessedTime(): Long { override fun getLastAccessedTime(): Long = delegate!!.lastAccessedTime
return delegate!!.lastAccessedTime
}
override fun invalidate() { override fun invalidate() {
delegate!!.invalidate() delegate!!.invalidate()
} }
override fun isValid(): Boolean { override fun isValid(): Boolean = delegate!!.isValid
return delegate!!.isValid
}
override fun putValue( override fun putValue(
s: String, s: String,
@ -55,62 +47,36 @@ abstract class DelegatingSSLSession(protected val delegate: SSLSession?) : SSLSe
delegate!!.putValue(s, o) delegate!!.putValue(s, o)
} }
override fun getValue(s: String): Any { override fun getValue(s: String): Any = delegate!!.getValue(s)
return delegate!!.getValue(s)
}
override fun removeValue(s: String) { override fun removeValue(s: String) {
delegate!!.removeValue(s) delegate!!.removeValue(s)
} }
override fun getValueNames(): Array<String> { override fun getValueNames(): Array<String> = delegate!!.valueNames
return delegate!!.valueNames
}
@Throws(SSLPeerUnverifiedException::class) @Throws(SSLPeerUnverifiedException::class)
override fun getPeerCertificates(): Array<Certificate>? { override fun getPeerCertificates(): Array<Certificate>? = delegate!!.peerCertificates
return delegate!!.peerCertificates
}
override fun getLocalCertificates(): Array<Certificate>? { override fun getLocalCertificates(): Array<Certificate>? = delegate!!.localCertificates
return delegate!!.localCertificates
}
@Throws(SSLPeerUnverifiedException::class) @Throws(SSLPeerUnverifiedException::class)
override fun getPeerCertificateChain(): Array<X509Certificate> { override fun getPeerCertificateChain(): Array<X509Certificate> = delegate!!.peerCertificateChain
return delegate!!.peerCertificateChain
}
@Throws(SSLPeerUnverifiedException::class) @Throws(SSLPeerUnverifiedException::class)
override fun getPeerPrincipal(): Principal { override fun getPeerPrincipal(): Principal = delegate!!.peerPrincipal
return delegate!!.peerPrincipal
}
override fun getLocalPrincipal(): Principal { override fun getLocalPrincipal(): Principal = delegate!!.localPrincipal
return delegate!!.localPrincipal
}
override fun getCipherSuite(): String { override fun getCipherSuite(): String = delegate!!.cipherSuite
return delegate!!.cipherSuite
}
override fun getProtocol(): String { override fun getProtocol(): String = delegate!!.protocol
return delegate!!.protocol
}
override fun getPeerHost(): String { override fun getPeerHost(): String = delegate!!.peerHost
return delegate!!.peerHost
}
override fun getPeerPort(): Int { override fun getPeerPort(): Int = delegate!!.peerPort
return delegate!!.peerPort
}
override fun getPacketBufferSize(): Int { override fun getPacketBufferSize(): Int = delegate!!.packetBufferSize
return delegate!!.packetBufferSize
}
override fun getApplicationBufferSize(): Int { override fun getApplicationBufferSize(): Int = delegate!!.applicationBufferSize
return delegate!!.applicationBufferSize
}
} }

View File

@ -29,7 +29,9 @@ import javax.net.ssl.SSLSession
import javax.net.ssl.SSLSocket import javax.net.ssl.SSLSocket
/** An [SSLSocket] that delegates all calls. */ /** An [SSLSocket] that delegates all calls. */
abstract class DelegatingSSLSocket(protected val delegate: SSLSocket?) : SSLSocket() { abstract class DelegatingSSLSocket(
protected val delegate: SSLSocket?,
) : SSLSocket() {
@Throws(IOException::class) @Throws(IOException::class)
override fun shutdownInput() { override fun shutdownInput() {
delegate!!.shutdownInput() delegate!!.shutdownInput()
@ -40,33 +42,23 @@ abstract class DelegatingSSLSocket(protected val delegate: SSLSocket?) : SSLSock
delegate!!.shutdownOutput() delegate!!.shutdownOutput()
} }
override fun getSupportedCipherSuites(): Array<String> { override fun getSupportedCipherSuites(): Array<String> = delegate!!.supportedCipherSuites
return delegate!!.supportedCipherSuites
}
override fun getEnabledCipherSuites(): Array<String> { override fun getEnabledCipherSuites(): Array<String> = delegate!!.enabledCipherSuites
return delegate!!.enabledCipherSuites
}
override fun setEnabledCipherSuites(suites: Array<String>) { override fun setEnabledCipherSuites(suites: Array<String>) {
delegate!!.enabledCipherSuites = suites delegate!!.enabledCipherSuites = suites
} }
override fun getSupportedProtocols(): Array<String> { override fun getSupportedProtocols(): Array<String> = delegate!!.supportedProtocols
return delegate!!.supportedProtocols
}
override fun getEnabledProtocols(): Array<String> { override fun getEnabledProtocols(): Array<String> = delegate!!.enabledProtocols
return delegate!!.enabledProtocols
}
override fun setEnabledProtocols(protocols: Array<String>) { override fun setEnabledProtocols(protocols: Array<String>) {
delegate!!.enabledProtocols = protocols delegate!!.enabledProtocols = protocols
} }
override fun getSession(): SSLSession { override fun getSession(): SSLSession = delegate!!.session
return delegate!!.session
}
override fun addHandshakeCompletedListener(listener: HandshakeCompletedListener) { override fun addHandshakeCompletedListener(listener: HandshakeCompletedListener) {
delegate!!.addHandshakeCompletedListener(listener) delegate!!.addHandshakeCompletedListener(listener)
@ -85,9 +77,7 @@ abstract class DelegatingSSLSocket(protected val delegate: SSLSocket?) : SSLSock
delegate!!.useClientMode = mode delegate!!.useClientMode = mode
} }
override fun getUseClientMode(): Boolean { override fun getUseClientMode(): Boolean = delegate!!.useClientMode
return delegate!!.useClientMode
}
override fun setNeedClientAuth(need: Boolean) { override fun setNeedClientAuth(need: Boolean) {
delegate!!.needClientAuth = need delegate!!.needClientAuth = need
@ -97,25 +87,17 @@ abstract class DelegatingSSLSocket(protected val delegate: SSLSocket?) : SSLSock
delegate!!.wantClientAuth = want delegate!!.wantClientAuth = want
} }
override fun getNeedClientAuth(): Boolean { override fun getNeedClientAuth(): Boolean = delegate!!.needClientAuth
return delegate!!.needClientAuth
}
override fun getWantClientAuth(): Boolean { override fun getWantClientAuth(): Boolean = delegate!!.wantClientAuth
return delegate!!.wantClientAuth
}
override fun setEnableSessionCreation(flag: Boolean) { override fun setEnableSessionCreation(flag: Boolean) {
delegate!!.enableSessionCreation = flag delegate!!.enableSessionCreation = flag
} }
override fun getEnableSessionCreation(): Boolean { override fun getEnableSessionCreation(): Boolean = delegate!!.enableSessionCreation
return delegate!!.enableSessionCreation
}
override fun getSSLParameters(): SSLParameters { override fun getSSLParameters(): SSLParameters = delegate!!.sslParameters
return delegate!!.sslParameters
}
override fun setSSLParameters(p: SSLParameters) { override fun setSSLParameters(p: SSLParameters) {
delegate!!.sslParameters = p delegate!!.sslParameters = p
@ -126,61 +108,37 @@ abstract class DelegatingSSLSocket(protected val delegate: SSLSocket?) : SSLSock
delegate!!.close() delegate!!.close()
} }
override fun getInetAddress(): InetAddress { override fun getInetAddress(): InetAddress = delegate!!.inetAddress
return delegate!!.inetAddress
}
@Throws(IOException::class) @Throws(IOException::class)
override fun getInputStream(): InputStream { override fun getInputStream(): InputStream = delegate!!.inputStream
return delegate!!.inputStream
}
@Throws(SocketException::class) @Throws(SocketException::class)
override fun getKeepAlive(): Boolean { override fun getKeepAlive(): Boolean = delegate!!.keepAlive
return delegate!!.keepAlive
}
override fun getLocalAddress(): InetAddress { override fun getLocalAddress(): InetAddress = delegate!!.localAddress
return delegate!!.localAddress
}
override fun getLocalPort(): Int { override fun getLocalPort(): Int = delegate!!.localPort
return delegate!!.localPort
}
@Throws(IOException::class) @Throws(IOException::class)
override fun getOutputStream(): OutputStream { override fun getOutputStream(): OutputStream = delegate!!.outputStream
return delegate!!.outputStream
}
override fun getPort(): Int { override fun getPort(): Int = delegate!!.port
return delegate!!.port
}
@Throws(SocketException::class) @Throws(SocketException::class)
override fun getSoLinger(): Int { override fun getSoLinger(): Int = delegate!!.soLinger
return delegate!!.soLinger
}
@Throws(SocketException::class) @Throws(SocketException::class)
override fun getReceiveBufferSize(): Int { override fun getReceiveBufferSize(): Int = delegate!!.receiveBufferSize
return delegate!!.receiveBufferSize
}
@Throws(SocketException::class) @Throws(SocketException::class)
override fun getSendBufferSize(): Int { override fun getSendBufferSize(): Int = delegate!!.sendBufferSize
return delegate!!.sendBufferSize
}
@Throws(SocketException::class) @Throws(SocketException::class)
override fun getSoTimeout(): Int { override fun getSoTimeout(): Int = delegate!!.soTimeout
return delegate!!.soTimeout
}
@Throws(SocketException::class) @Throws(SocketException::class)
override fun getTcpNoDelay(): Boolean { override fun getTcpNoDelay(): Boolean = delegate!!.tcpNoDelay
return delegate!!.tcpNoDelay
}
@Throws(SocketException::class) @Throws(SocketException::class)
override fun setKeepAlive(keepAlive: Boolean) { override fun setKeepAlive(keepAlive: Boolean) {
@ -215,29 +173,17 @@ abstract class DelegatingSSLSocket(protected val delegate: SSLSocket?) : SSLSock
delegate!!.tcpNoDelay = on delegate!!.tcpNoDelay = on
} }
override fun toString(): String { override fun toString(): String = delegate!!.toString()
return delegate!!.toString()
}
override fun getLocalSocketAddress(): SocketAddress { override fun getLocalSocketAddress(): SocketAddress = delegate!!.localSocketAddress
return delegate!!.localSocketAddress
}
override fun getRemoteSocketAddress(): SocketAddress { override fun getRemoteSocketAddress(): SocketAddress = delegate!!.remoteSocketAddress
return delegate!!.remoteSocketAddress
}
override fun isBound(): Boolean { override fun isBound(): Boolean = delegate!!.isBound
return delegate!!.isBound
}
override fun isConnected(): Boolean { override fun isConnected(): Boolean = delegate!!.isConnected
return delegate!!.isConnected
}
override fun isClosed(): Boolean { override fun isClosed(): Boolean = delegate!!.isClosed
return delegate!!.isClosed
}
@Throws(IOException::class) @Throws(IOException::class)
override fun bind(localAddr: SocketAddress) { override fun bind(localAddr: SocketAddress) {
@ -257,13 +203,9 @@ abstract class DelegatingSSLSocket(protected val delegate: SSLSocket?) : SSLSock
delegate!!.connect(remoteAddr, timeout) delegate!!.connect(remoteAddr, timeout)
} }
override fun isInputShutdown(): Boolean { override fun isInputShutdown(): Boolean = delegate!!.isInputShutdown
return delegate!!.isInputShutdown
}
override fun isOutputShutdown(): Boolean { override fun isOutputShutdown(): Boolean = delegate!!.isOutputShutdown
return delegate!!.isOutputShutdown
}
@Throws(SocketException::class) @Throws(SocketException::class)
override fun setReuseAddress(reuse: Boolean) { override fun setReuseAddress(reuse: Boolean) {
@ -271,9 +213,7 @@ abstract class DelegatingSSLSocket(protected val delegate: SSLSocket?) : SSLSock
} }
@Throws(SocketException::class) @Throws(SocketException::class)
override fun getReuseAddress(): Boolean { override fun getReuseAddress(): Boolean = delegate!!.reuseAddress
return delegate!!.reuseAddress
}
@Throws(SocketException::class) @Throws(SocketException::class)
override fun setOOBInline(oobinline: Boolean) { override fun setOOBInline(oobinline: Boolean) {
@ -281,9 +221,7 @@ abstract class DelegatingSSLSocket(protected val delegate: SSLSocket?) : SSLSock
} }
@Throws(SocketException::class) @Throws(SocketException::class)
override fun getOOBInline(): Boolean { override fun getOOBInline(): Boolean = delegate!!.oobInline
return delegate!!.oobInline
}
@Throws(SocketException::class) @Throws(SocketException::class)
override fun setTrafficClass(value: Int) { override fun setTrafficClass(value: Int) {
@ -291,36 +229,25 @@ abstract class DelegatingSSLSocket(protected val delegate: SSLSocket?) : SSLSock
} }
@Throws(SocketException::class) @Throws(SocketException::class)
override fun getTrafficClass(): Int { override fun getTrafficClass(): Int = delegate!!.trafficClass
return delegate!!.trafficClass
}
@Throws(IOException::class) @Throws(IOException::class)
override fun sendUrgentData(value: Int) { override fun sendUrgentData(value: Int) {
delegate!!.sendUrgentData(value) delegate!!.sendUrgentData(value)
} }
override fun getChannel(): SocketChannel { override fun getChannel(): SocketChannel = delegate!!.channel
return delegate!!.channel
}
override fun getHandshakeSession(): SSLSession { override fun getHandshakeSession(): SSLSession = delegate!!.handshakeSession
return delegate!!.handshakeSession
}
override fun getApplicationProtocol(): String { override fun getApplicationProtocol(): String = delegate!!.applicationProtocol
return delegate!!.applicationProtocol
}
override fun getHandshakeApplicationProtocol(): String { override fun getHandshakeApplicationProtocol(): String = delegate!!.handshakeApplicationProtocol
return delegate!!.handshakeApplicationProtocol
}
override fun setHandshakeApplicationProtocolSelector(selector: BiFunction<SSLSocket, MutableList<String>, String>?) { override fun setHandshakeApplicationProtocolSelector(selector: BiFunction<SSLSocket, MutableList<String>, String>?) {
delegate!!.handshakeApplicationProtocolSelector = selector delegate!!.handshakeApplicationProtocolSelector = selector
} }
override fun getHandshakeApplicationProtocolSelector(): BiFunction<SSLSocket, MutableList<String>, String> { override fun getHandshakeApplicationProtocolSelector(): BiFunction<SSLSocket, MutableList<String>, String> =
return delegate!!.handshakeApplicationProtocolSelector delegate!!.handshakeApplicationProtocolSelector
}
} }

View File

@ -25,7 +25,9 @@ import javax.net.ssl.SSLSocketFactory
* A [SSLSocketFactory] that delegates calls. Sockets can be configured after creation by * A [SSLSocketFactory] that delegates calls. Sockets can be configured after creation by
* overriding [.configureSocket]. * overriding [.configureSocket].
*/ */
open class DelegatingSSLSocketFactory(private val delegate: SSLSocketFactory) : SSLSocketFactory() { open class DelegatingSSLSocketFactory(
private val delegate: SSLSocketFactory,
) : SSLSocketFactory() {
@Throws(IOException::class) @Throws(IOException::class)
override fun createSocket(): SSLSocket { override fun createSocket(): SSLSocket {
val sslSocket = delegate.createSocket() as SSLSocket val sslSocket = delegate.createSocket() as SSLSocket
@ -72,13 +74,9 @@ open class DelegatingSSLSocketFactory(private val delegate: SSLSocketFactory) :
return configureSocket(sslSocket) return configureSocket(sslSocket)
} }
override fun getDefaultCipherSuites(): Array<String> { override fun getDefaultCipherSuites(): Array<String> = delegate.defaultCipherSuites
return delegate.defaultCipherSuites
}
override fun getSupportedCipherSuites(): Array<String> { override fun getSupportedCipherSuites(): Array<String> = delegate.supportedCipherSuites
return delegate.supportedCipherSuites
}
@Throws(IOException::class) @Throws(IOException::class)
override fun createSocket( override fun createSocket(

View File

@ -24,7 +24,9 @@ import javax.net.ServerSocketFactory
* A [ServerSocketFactory] that delegates calls. Sockets can be configured after creation by * A [ServerSocketFactory] that delegates calls. Sockets can be configured after creation by
* overriding [.configureServerSocket]. * overriding [.configureServerSocket].
*/ */
open class DelegatingServerSocketFactory(private val delegate: ServerSocketFactory) : ServerSocketFactory() { open class DelegatingServerSocketFactory(
private val delegate: ServerSocketFactory,
) : ServerSocketFactory() {
@Throws(IOException::class) @Throws(IOException::class)
override fun createServerSocket(): ServerSocket { override fun createServerSocket(): ServerSocket {
val serverSocket = delegate.createServerSocket() val serverSocket = delegate.createServerSocket()

View File

@ -24,7 +24,9 @@ import javax.net.SocketFactory
* A [SocketFactory] that delegates calls. Sockets can be configured after creation by * A [SocketFactory] that delegates calls. Sockets can be configured after creation by
* overriding [.configureSocket]. * overriding [.configureSocket].
*/ */
open class DelegatingSocketFactory(private val delegate: SocketFactory) : SocketFactory() { open class DelegatingSocketFactory(
private val delegate: SocketFactory,
) : SocketFactory() {
@Throws(IOException::class) @Throws(IOException::class)
override fun createSocket(): Socket { override fun createSocket(): Socket {
val socket = delegate.createSocket() val socket = delegate.createSocket()

View File

@ -45,9 +45,7 @@ class FakeDns : Dns {
fun lookup( fun lookup(
hostname: String, hostname: String,
index: Int, index: Int,
): InetAddress { ): InetAddress = hostAddresses[hostname]!![index]
return hostAddresses[hostname]!![index]
}
@Throws(UnknownHostException::class) @Throws(UnknownHostException::class)
override fun lookup(hostname: String): List<InetAddress> { override fun lookup(hostname: String): List<InetAddress> {

View File

@ -23,101 +23,62 @@ import javax.net.ssl.SSLSession
import javax.net.ssl.SSLSessionContext import javax.net.ssl.SSLSessionContext
import javax.security.cert.X509Certificate import javax.security.cert.X509Certificate
class FakeSSLSession(vararg val certificates: Certificate) : SSLSession { class FakeSSLSession(
override fun getApplicationBufferSize(): Int { vararg val certificates: Certificate,
throw UnsupportedOperationException() ) : SSLSession {
} override fun getApplicationBufferSize(): Int = throw UnsupportedOperationException()
override fun getCipherSuite(): String { override fun getCipherSuite(): String = throw UnsupportedOperationException()
throw UnsupportedOperationException()
}
override fun getCreationTime(): Long { override fun getCreationTime(): Long = throw UnsupportedOperationException()
throw UnsupportedOperationException()
}
override fun getId(): ByteArray { override fun getId(): ByteArray = throw UnsupportedOperationException()
throw UnsupportedOperationException()
}
override fun getLastAccessedTime(): Long { override fun getLastAccessedTime(): Long = throw UnsupportedOperationException()
throw UnsupportedOperationException()
}
override fun getLocalCertificates(): Array<Certificate> { override fun getLocalCertificates(): Array<Certificate> = throw UnsupportedOperationException()
throw UnsupportedOperationException()
}
override fun getLocalPrincipal(): Principal { override fun getLocalPrincipal(): Principal = throw UnsupportedOperationException()
throw UnsupportedOperationException()
}
override fun getPacketBufferSize(): Int { override fun getPacketBufferSize(): Int = throw UnsupportedOperationException()
throw UnsupportedOperationException()
}
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
@Throws(SSLPeerUnverifiedException::class) @Throws(SSLPeerUnverifiedException::class)
override fun getPeerCertificates(): Array<Certificate> { override fun getPeerCertificates(): Array<Certificate> =
return if (certificates.isEmpty()) { if (certificates.isEmpty()) {
throw SSLPeerUnverifiedException("peer not authenticated") throw SSLPeerUnverifiedException("peer not authenticated")
} else { } else {
certificates as Array<Certificate> certificates as Array<Certificate>
} }
}
@Throws( @Throws(
SSLPeerUnverifiedException::class, SSLPeerUnverifiedException::class,
) )
override fun getPeerCertificateChain(): Array<X509Certificate> { override fun getPeerCertificateChain(): Array<X509Certificate> = throw UnsupportedOperationException()
throw UnsupportedOperationException()
}
override fun getPeerHost(): String { override fun getPeerHost(): String = throw UnsupportedOperationException()
throw UnsupportedOperationException()
}
override fun getPeerPort(): Int { override fun getPeerPort(): Int = throw UnsupportedOperationException()
throw UnsupportedOperationException()
}
@Throws(SSLPeerUnverifiedException::class) @Throws(SSLPeerUnverifiedException::class)
override fun getPeerPrincipal(): Principal { override fun getPeerPrincipal(): Principal = throw UnsupportedOperationException()
throw UnsupportedOperationException()
}
override fun getProtocol(): String { override fun getProtocol(): String = throw UnsupportedOperationException()
throw UnsupportedOperationException()
}
override fun getSessionContext(): SSLSessionContext { override fun getSessionContext(): SSLSessionContext = throw UnsupportedOperationException()
throw UnsupportedOperationException()
}
override fun putValue( override fun putValue(
s: String, s: String,
obj: Any, obj: Any,
) { ): Unit = throw UnsupportedOperationException()
throw UnsupportedOperationException()
}
override fun removeValue(s: String) { override fun removeValue(s: String): Unit = throw UnsupportedOperationException()
throw UnsupportedOperationException()
}
override fun getValue(s: String): Any { override fun getValue(s: String): Any = throw UnsupportedOperationException()
throw UnsupportedOperationException()
}
override fun getValueNames(): Array<String> { override fun getValueNames(): Array<String> = throw UnsupportedOperationException()
throw UnsupportedOperationException()
}
override fun invalidate() { override fun invalidate(): Unit = throw UnsupportedOperationException()
throw UnsupportedOperationException()
}
override fun isValid(): Boolean { override fun isValid(): Boolean = throw UnsupportedOperationException()
throw UnsupportedOperationException()
}
} }

View File

@ -18,34 +18,26 @@ package okhttp3
import java.io.IOException import java.io.IOException
import okio.BufferedSink import okio.BufferedSink
open class ForwardingRequestBody(delegate: RequestBody?) : RequestBody() { open class ForwardingRequestBody(
delegate: RequestBody?,
) : RequestBody() {
private val delegate: RequestBody private val delegate: RequestBody
fun delegate(): RequestBody { fun delegate(): RequestBody = delegate
return delegate
}
override fun contentType(): MediaType? { override fun contentType(): MediaType? = delegate.contentType()
return delegate.contentType()
}
@Throws(IOException::class) @Throws(IOException::class)
override fun contentLength(): Long { override fun contentLength(): Long = delegate.contentLength()
return delegate.contentLength()
}
@Throws(IOException::class) @Throws(IOException::class)
override fun writeTo(sink: BufferedSink) { override fun writeTo(sink: BufferedSink) {
delegate.writeTo(sink) delegate.writeTo(sink)
} }
override fun isDuplex(): Boolean { override fun isDuplex(): Boolean = delegate.isDuplex()
return delegate.isDuplex()
}
override fun toString(): String { override fun toString(): String = javaClass.simpleName + "(" + delegate.toString() + ")"
return javaClass.simpleName + "(" + delegate.toString() + ")"
}
init { init {
requireNotNull(delegate) { "delegate == null" } requireNotNull(delegate) { "delegate == null" }

View File

@ -17,28 +17,20 @@ package okhttp3
import okio.BufferedSource import okio.BufferedSource
open class ForwardingResponseBody(delegate: ResponseBody?) : ResponseBody() { open class ForwardingResponseBody(
delegate: ResponseBody?,
) : ResponseBody() {
private val delegate: ResponseBody private val delegate: ResponseBody
fun delegate(): ResponseBody { fun delegate(): ResponseBody = delegate
return delegate
}
override fun contentType(): MediaType? { override fun contentType(): MediaType? = delegate.contentType()
return delegate.contentType()
}
override fun contentLength(): Long { override fun contentLength(): Long = delegate.contentLength()
return delegate.contentLength()
}
override fun source(): BufferedSource { override fun source(): BufferedSource = delegate.source()
return delegate.source()
}
override fun toString(): String { override fun toString(): String = javaClass.simpleName + "(" + delegate.toString() + ")"
return javaClass.simpleName + "(" + delegate.toString() + ")"
}
init { init {
requireNotNull(delegate) { "delegate == null" } requireNotNull(delegate) { "delegate == null" }

View File

@ -20,7 +20,10 @@ import java.util.logging.Handler
import java.util.logging.LogRecord import java.util.logging.LogRecord
object JsseDebugLogging { object JsseDebugLogging {
data class JsseDebugMessage(val message: String, val param: String?) { data class JsseDebugMessage(
val message: String,
val param: String?,
) {
enum class Type { enum class Type {
Handshake, Handshake,
Plaintext, Plaintext,
@ -47,13 +50,12 @@ object JsseDebugLogging {
else -> Type.Unknown else -> Type.Unknown
} }
override fun toString(): String { override fun toString(): String =
return if (param != null) { if (param != null) {
message + "\n" + param message + "\n" + param
} else { } else {
message message
} }
}
} }
private fun quietDebug(message: JsseDebugMessage) { private fun quietDebug(message: JsseDebugMessage) {

View File

@ -47,7 +47,9 @@ import org.junit.jupiter.api.extension.ExtensionContext
* Use [newClient] as a factory for a OkHttpClient instances. These instances are specifically * Use [newClient] as a factory for a OkHttpClient instances. These instances are specifically
* configured for testing. * configured for testing.
*/ */
class OkHttpClientTestRule : BeforeEachCallback, AfterEachCallback { class OkHttpClientTestRule :
BeforeEachCallback,
AfterEachCallback {
private val clientEventsList = mutableListOf<String>() private val clientEventsList = mutableListOf<String>()
private var testClient: OkHttpClient? = null private var testClient: OkHttpClient? = null
private var uncaughtException: Throwable? = null private var uncaughtException: Throwable? = null
@ -160,35 +162,33 @@ class OkHttpClientTestRule : BeforeEachCallback, AfterEachCallback {
val backend = TaskRunner.RealBackend(loomThreadFactory()) val backend = TaskRunner.RealBackend(loomThreadFactory())
val taskRunner = TaskRunner(backend) val taskRunner = TaskRunner(backend)
OkHttpClient.Builder() OkHttpClient
.Builder()
.connectionPool( .connectionPool(
buildConnectionPool( buildConnectionPool(
connectionListener = connectionListener, connectionListener = connectionListener,
taskRunner = taskRunner, taskRunner = taskRunner,
), ),
) ).dispatcher(Dispatcher(backend.executor))
.dispatcher(Dispatcher(backend.executor))
.taskRunnerInternal(taskRunner) .taskRunnerInternal(taskRunner)
} else { } else {
OkHttpClient.Builder() OkHttpClient
.Builder()
.connectionPool(ConnectionPool(connectionListener = connectionListener)) .connectionPool(ConnectionPool(connectionListener = connectionListener))
} }
private fun loomThreadFactory(): ThreadFactory { private fun loomThreadFactory(): ThreadFactory {
val ofVirtual = Thread::class.java.getMethod("ofVirtual").invoke(null) val ofVirtual = Thread::class.java.getMethod("ofVirtual").invoke(null)
return Class.forName("java.lang.Thread\$Builder") return Class
.forName("java.lang.Thread\$Builder")
.getMethod("factory") .getMethod("factory")
.invoke(ofVirtual) as ThreadFactory .invoke(ofVirtual) as ThreadFactory
} }
private fun isLoom(): Boolean { private fun isLoom(): Boolean = getPlatformSystemProperty() == LOOM_PROPERTY
return getPlatformSystemProperty() == LOOM_PROPERTY
}
fun newClientBuilder(): OkHttpClient.Builder { fun newClientBuilder(): OkHttpClient.Builder = newClient().newBuilder()
return newClient().newBuilder()
}
@Synchronized private fun addEvent(event: String) { @Synchronized private fun addEvent(event: String) {
if (recordEvents) { if (recordEvents) {
@ -302,10 +302,9 @@ class OkHttpClientTestRule : BeforeEachCallback, AfterEachCallback {
} }
@SuppressLint("NewApi") @SuppressLint("NewApi")
private fun ExtensionContext.isFlaky(): Boolean { private fun ExtensionContext.isFlaky(): Boolean =
return (testMethod.orElseGet { null }?.isAnnotationPresent(Flaky::class.java) == true) || (testMethod.orElseGet { null }?.isAnnotationPresent(Flaky::class.java) == true) ||
(testClass.orElseGet { null }?.isAnnotationPresent(Flaky::class.java) == true) (testClass.orElseGet { null }?.isAnnotationPresent(Flaky::class.java) == true)
}
@Synchronized private fun logEvents() { @Synchronized private fun logEvents() {
// Will be ineffective if test overrides the listener // Will be ineffective if test overrides the listener
@ -318,9 +317,7 @@ class OkHttpClientTestRule : BeforeEachCallback, AfterEachCallback {
} }
} }
fun recordedConnectionEventTypes(): List<String> { fun recordedConnectionEventTypes(): List<String> = connectionListener.recordedEventTypes()
return connectionListener.recordedEventTypes()
}
companion object { companion object {
/** /**

View File

@ -87,10 +87,10 @@ open class RecordingConnectionListener(
if (elapsedMs != -1L) { if (elapsedMs != -1L) {
assertThat( assertThat(
TimeUnit.NANOSECONDS.toMillis(actualElapsedNs) TimeUnit.NANOSECONDS
.toMillis(actualElapsedNs)
.toDouble(), .toDouble(),
) ).isCloseTo(elapsedMs.toDouble(), 100.0)
.isCloseTo(elapsedMs.toDouble(), 100.0)
} }
return result return result

View File

@ -28,9 +28,7 @@ class RecordingCookieJar : CookieJar {
requestCookies.add(cookies.toList()) requestCookies.add(cookies.toList())
} }
fun takeResponseCookies(): List<Cookie> { fun takeResponseCookies(): List<Cookie> = responseCookies.removeFirst()
return responseCookies.removeFirst()
}
fun assertResponseCookies(vararg cookies: String?) { fun assertResponseCookies(vararg cookies: String?) {
assertThat(takeResponseCookies().map(Cookie::toString)).containsExactly(*cookies) assertThat(takeResponseCookies().map(Cookie::toString)).containsExactly(*cookies)
@ -43,7 +41,5 @@ class RecordingCookieJar : CookieJar {
responseCookies.add(cookies) responseCookies.add(cookies)
} }
override fun loadForRequest(url: HttpUrl): List<Cookie> { override fun loadForRequest(url: HttpUrl): List<Cookie> = if (requestCookies.isEmpty()) emptyList() else requestCookies.removeFirst()
return if (requestCookies.isEmpty()) emptyList() else requestCookies.removeFirst()
}
} }

View File

@ -97,9 +97,7 @@ open class RecordingEventListener(
inline fun <reified T : CallEvent> removeUpToEvent(): T = removeUpToEvent(T::class.java) inline fun <reified T : CallEvent> removeUpToEvent(): T = removeUpToEvent(T::class.java)
inline fun <reified T : CallEvent> findEvent(): T { inline fun <reified T : CallEvent> findEvent(): T = eventSequence.first { it is T } as T
return eventSequence.first { it is T } as T
}
/** /**
* Remove and return the next event from the recorded sequence. * Remove and return the next event from the recorded sequence.
@ -123,10 +121,10 @@ open class RecordingEventListener(
if (elapsedMs != -1L) { if (elapsedMs != -1L) {
assertThat( assertThat(
TimeUnit.NANOSECONDS.toMillis(actualElapsedNs) TimeUnit.NANOSECONDS
.toMillis(actualElapsedNs)
.toDouble(), .toDouble(),
) ).isCloseTo(elapsedMs.toDouble(), 100.0)
.isCloseTo(elapsedMs.toDouble(), 100.0)
} }
return result return result

View File

@ -37,8 +37,8 @@ class SpecificHostSocketFactory(
hostMapping[requested] = real hostMapping[requested] = real
} }
override fun createSocket(): Socket { override fun createSocket(): Socket =
return object : Socket() { object : Socket() {
override fun connect( override fun connect(
endpoint: SocketAddress?, endpoint: SocketAddress?,
timeout: Int, timeout: Int,
@ -49,5 +49,4 @@ class SpecificHostSocketFactory(
super.connect(inetSocketAddress, timeout) super.connect(inetSocketAddress, timeout)
} }
} }
}
} }

View File

@ -36,9 +36,7 @@ object TestUtil {
@JvmStatic val isGraalVmImage = System.getProperty("org.graalvm.nativeimage.imagecode") != null @JvmStatic val isGraalVmImage = System.getProperty("org.graalvm.nativeimage.imagecode") != null
@JvmStatic @JvmStatic
fun headerEntries(vararg elements: String?): List<Header> { fun headerEntries(vararg elements: String?): List<Header> = List(elements.size / 2) { Header(elements[it * 2]!!, elements[it * 2 + 1]!!) }
return List(elements.size / 2) { Header(elements[it * 2]!!, elements[it * 2 + 1]!!) }
}
@JvmStatic @JvmStatic
fun repeat( fun repeat(
@ -125,15 +123,12 @@ object TestUtil {
} }
@JvmStatic @JvmStatic
fun threadFactory(name: String): ThreadFactory { fun threadFactory(name: String): ThreadFactory =
return object : ThreadFactory { object : ThreadFactory {
private var nextId = 1 private var nextId = 1
override fun newThread(runnable: Runnable): Thread { override fun newThread(runnable: Runnable): Thread = Thread(runnable, "$name-${nextId++}")
return Thread(runnable, "$name-${nextId++}")
}
} }
}
} }
fun getEnv(name: String) = System.getenv(name) fun getEnv(name: String) = System.getenv(name)

View File

@ -102,8 +102,8 @@ class TestValueFactory : Closeable {
taskRunner: TaskRunner = this.taskRunner, taskRunner: TaskRunner = this.taskRunner,
maxIdleConnections: Int = Int.MAX_VALUE, maxIdleConnections: Int = Int.MAX_VALUE,
routePlanner: RoutePlanner? = null, routePlanner: RoutePlanner? = null,
): RealConnectionPool { ): RealConnectionPool =
return RealConnectionPool( RealConnectionPool(
taskRunner = taskRunner, taskRunner = taskRunner,
maxIdleConnections = maxIdleConnections, maxIdleConnections = maxIdleConnections,
keepAliveDuration = 100L, keepAliveDuration = 100L,
@ -129,7 +129,6 @@ class TestValueFactory : Closeable {
) )
}, },
) )
}
/** Returns an address that's without an SSL socket factory or hostname verifier. */ /** Returns an address that's without an SSL socket factory or hostname verifier. */
fun newAddress( fun newAddress(
@ -137,8 +136,8 @@ class TestValueFactory : Closeable {
uriPort: Int = this.uriPort, uriPort: Int = this.uriPort,
proxy: Proxy? = null, proxy: Proxy? = null,
proxySelector: ProxySelector = this.proxySelector, proxySelector: ProxySelector = this.proxySelector,
): Address { ): Address =
return Address( Address(
uriHost = uriHost, uriHost = uriHost,
uriPort = uriPort, uriPort = uriPort,
dns = dns, dns = dns,
@ -152,7 +151,6 @@ class TestValueFactory : Closeable {
connectionSpecs = connectionSpecs, connectionSpecs = connectionSpecs,
proxySelector = proxySelector, proxySelector = proxySelector,
) )
}
fun newHttpsAddress( fun newHttpsAddress(
uriHost: String = this.uriHost, uriHost: String = this.uriHost,
@ -161,8 +159,8 @@ class TestValueFactory : Closeable {
proxySelector: ProxySelector = this.proxySelector, proxySelector: ProxySelector = this.proxySelector,
sslSocketFactory: SSLSocketFactory? = this.sslSocketFactory, sslSocketFactory: SSLSocketFactory? = this.sslSocketFactory,
hostnameVerifier: HostnameVerifier? = this.hostnameVerifier, hostnameVerifier: HostnameVerifier? = this.hostnameVerifier,
): Address { ): Address =
return Address( Address(
uriHost = uriHost, uriHost = uriHost,
uriPort = uriPort, uriPort = uriPort,
dns = dns, dns = dns,
@ -176,22 +174,20 @@ class TestValueFactory : Closeable {
connectionSpecs = connectionSpecs, connectionSpecs = connectionSpecs,
proxySelector = proxySelector, proxySelector = proxySelector,
) )
}
fun newRoute( fun newRoute(
address: Address = newAddress(), address: Address = newAddress(),
proxy: Proxy = this.proxy, proxy: Proxy = this.proxy,
socketAddress: InetSocketAddress = InetSocketAddress.createUnresolved(uriHost, uriPort), socketAddress: InetSocketAddress = InetSocketAddress.createUnresolved(uriHost, uriPort),
): Route { ): Route =
return Route( Route(
address = address, address = address,
proxy = proxy, proxy = proxy,
socketAddress = socketAddress, socketAddress = socketAddress,
) )
}
fun newChain(call: RealCall): RealInterceptorChain { fun newChain(call: RealCall): RealInterceptorChain =
return RealInterceptorChain( RealInterceptorChain(
call = call, call = call,
interceptors = listOf(), interceptors = listOf(),
index = 0, index = 0,
@ -201,7 +197,6 @@ class TestValueFactory : Closeable {
readTimeoutMillis = 10_000, readTimeoutMillis = 10_000,
writeTimeoutMillis = 10_000, writeTimeoutMillis = 10_000,
) )
}
fun newRoutePlanner( fun newRoutePlanner(
client: OkHttpClient, client: OkHttpClient,

View File

@ -26,9 +26,7 @@ import okio.buffer
/** Rewrites the request body sent to the server to be all uppercase. */ /** Rewrites the request body sent to the server to be all uppercase. */
class UppercaseRequestInterceptor : Interceptor { class UppercaseRequestInterceptor : Interceptor {
@Throws(IOException::class) @Throws(IOException::class)
override fun intercept(chain: Chain): Response { override fun intercept(chain: Chain): Response = chain.proceed(uppercaseRequest(chain.request()))
return chain.proceed(uppercaseRequest(chain.request()))
}
/** Returns a request that transforms `request` to be all uppercase. */ /** Returns a request that transforms `request` to be all uppercase. */
private fun uppercaseRequest(request: Request): Request { private fun uppercaseRequest(request: Request): Request {
@ -39,13 +37,14 @@ class UppercaseRequestInterceptor : Interceptor {
delegate().writeTo(uppercaseSink(sink).buffer()) delegate().writeTo(uppercaseSink(sink).buffer())
} }
} }
return request.newBuilder() return request
.newBuilder()
.method(request.method, uppercaseBody) .method(request.method, uppercaseBody)
.build() .build()
} }
private fun uppercaseSink(sink: Sink): Sink { private fun uppercaseSink(sink: Sink): Sink =
return object : ForwardingSink(sink) { object : ForwardingSink(sink) {
@Throws(IOException::class) @Throws(IOException::class)
override fun write( override fun write(
source: Buffer, source: Buffer,
@ -59,5 +58,4 @@ class UppercaseRequestInterceptor : Interceptor {
) )
} }
} }
}
} }

View File

@ -25,18 +25,15 @@ import okio.buffer
/** Rewrites the response body returned from the server to be all uppercase. */ /** Rewrites the response body returned from the server to be all uppercase. */
class UppercaseResponseInterceptor : Interceptor { class UppercaseResponseInterceptor : Interceptor {
@Throws(IOException::class) @Throws(IOException::class)
override fun intercept(chain: Chain): Response { override fun intercept(chain: Chain): Response = uppercaseResponse(chain.proceed(chain.request()))
return uppercaseResponse(chain.proceed(chain.request()))
}
private fun uppercaseResponse(response: Response): Response { private fun uppercaseResponse(response: Response): Response {
val uppercaseBody: ResponseBody = val uppercaseBody: ResponseBody =
object : ForwardingResponseBody(response.body) { object : ForwardingResponseBody(response.body) {
override fun source(): BufferedSource { override fun source(): BufferedSource = uppercaseSource(delegate().source()).buffer()
return uppercaseSource(delegate().source()).buffer()
}
} }
return response.newBuilder() return response
.newBuilder()
.body(uppercaseBody) .body(uppercaseBody)
.build() .build()
} }

View File

@ -46,7 +46,8 @@ class RecordingOkAuthenticator(
407 -> "Proxy-Authorization" 407 -> "Proxy-Authorization"
else -> "Authorization" else -> "Authorization"
} }
return response.request.newBuilder() return response.request
.newBuilder()
.addHeader(header, credential) .addHeader(header, credential)
.build() .build()
} }

View File

@ -349,7 +349,8 @@ class TaskFaker : Closeable {
*/ */
private inner class TaskFakerBlockingQueue<T>( private inner class TaskFakerBlockingQueue<T>(
val delegate: BlockingQueue<T>, val delegate: BlockingQueue<T>,
) : AbstractQueue<T>(), BlockingQueue<T> { ) : AbstractQueue<T>(),
BlockingQueue<T> {
override val size: Int = delegate.size override val size: Int = delegate.size
private var editCount = 0 private var editCount = 0

View File

@ -36,9 +36,7 @@ class AsyncRequestBody : RequestBody() {
override fun isDuplex(): Boolean = true override fun isDuplex(): Boolean = true
@Throws(InterruptedException::class) @Throws(InterruptedException::class)
fun takeSink(): BufferedSink { fun takeSink(): BufferedSink = requestBodySinks.poll(5, SECONDS) ?: throw AssertionError("no sink to take")
return requestBodySinks.poll(5, SECONDS) ?: throw AssertionError("no sink to take")
}
fun assertNoMoreSinks() { fun assertNoMoreSinks() {
assertTrue(requestBodySinks.isEmpty()) assertTrue(requestBodySinks.isEmpty())

View File

@ -21,7 +21,9 @@ import okhttp3.internal.http2.flowcontrol.WindowCounter
/** /**
* ConnectionListener that outputs CSV for flow control of client receiving streams. * ConnectionListener that outputs CSV for flow control of client receiving streams.
*/ */
class Http2FlowControlConnectionListener : ConnectionListener(), FlowControlListener { class Http2FlowControlConnectionListener :
ConnectionListener(),
FlowControlListener {
val start = System.currentTimeMillis() val start = System.currentTimeMillis()
override fun receivingStreamWindowChanged( override fun receivingStreamWindowChanged(

View File

@ -21,7 +21,9 @@ import okio.Path
import okio.Sink import okio.Sink
import okio.Source import okio.Source
class LoggingFilesystem(fileSystem: FileSystem) : ForwardingFileSystem(fileSystem) { class LoggingFilesystem(
fileSystem: FileSystem,
) : ForwardingFileSystem(fileSystem) {
fun log(line: String) { fun log(line: String) {
println(line) println(line)
} }

View File

@ -62,7 +62,9 @@ open class PlatformRule
constructor( constructor(
val requiredPlatformName: String? = null, val requiredPlatformName: String? = null,
val platform: Platform? = null, val platform: Platform? = null,
) : BeforeEachCallback, AfterEachCallback, InvocationInterceptor { ) : BeforeEachCallback,
AfterEachCallback,
InvocationInterceptor {
private val versionChecks = mutableListOf<Pair<Matcher<out Any>, Matcher<out Any>>>() private val versionChecks = mutableListOf<Pair<Matcher<out Any>, Matcher<out Any>>>()
override fun beforeEach(context: ExtensionContext) { override fun beforeEach(context: ExtensionContext) {
@ -157,34 +159,26 @@ open class PlatformRule
description.appendText(platform) description.appendText(platform)
} }
override fun matches(item: Any?): Boolean { override fun matches(item: Any?): Boolean = getPlatformSystemProperty() == platform
return getPlatformSystemProperty() == platform
}
} }
fun fromMajor(version: Int): Matcher<PlatformVersion> { fun fromMajor(version: Int): Matcher<PlatformVersion> =
return object : TypeSafeMatcher<PlatformVersion>() { object : TypeSafeMatcher<PlatformVersion>() {
override fun describeTo(description: Description) { override fun describeTo(description: Description) {
description.appendText("JDK with version from $version") description.appendText("JDK with version from $version")
} }
override fun matchesSafely(item: PlatformVersion): Boolean { override fun matchesSafely(item: PlatformVersion): Boolean = item.majorVersion >= version
return item.majorVersion >= version
}
} }
}
fun onMajor(version: Int): Matcher<PlatformVersion> { fun onMajor(version: Int): Matcher<PlatformVersion> =
return object : TypeSafeMatcher<PlatformVersion>() { object : TypeSafeMatcher<PlatformVersion>() {
override fun describeTo(description: Description) { override fun describeTo(description: Description) {
description.appendText("JDK with version $version") description.appendText("JDK with version $version")
} }
override fun matchesSafely(item: PlatformVersion): Boolean { override fun matchesSafely(item: PlatformVersion): Boolean = item.majorVersion == version
return item.majorVersion == version
}
} }
}
fun rethrowIfNotExpected(e: Throwable) { fun rethrowIfNotExpected(e: Throwable) {
versionChecks.forEach { (versionMatcher, failureMatcher) -> versionChecks.forEach { (versionMatcher, failureMatcher) ->
@ -336,20 +330,18 @@ open class PlatformRule
assumeTrue(PlatformVersion.majorVersion == majorVersion) assumeTrue(PlatformVersion.majorVersion == majorVersion)
} }
fun androidSdkVersion(): Int? { fun androidSdkVersion(): Int? =
return if (Platform.isAndroid) { if (Platform.isAndroid) {
Build.VERSION.SDK_INT Build.VERSION.SDK_INT
} else { } else {
null null
} }
}
fun localhostHandshakeCertificates(): HandshakeCertificates { fun localhostHandshakeCertificates(): HandshakeCertificates =
return when { when {
isBouncyCastle() -> localhostHandshakeCertificatesWithRsa2048 isBouncyCastle() -> localhostHandshakeCertificatesWithRsa2048
else -> localhost() else -> localhost()
} }
}
val isAndroid: Boolean val isAndroid: Boolean
get() = Platform.Companion.isAndroid get() = Platform.Companion.isAndroid
@ -373,12 +365,14 @@ open class PlatformRule
*/ */
private val localhostHandshakeCertificatesWithRsa2048: HandshakeCertificates by lazy { private val localhostHandshakeCertificatesWithRsa2048: HandshakeCertificates by lazy {
val heldCertificate = val heldCertificate =
HeldCertificate.Builder() HeldCertificate
.Builder()
.commonName("localhost") .commonName("localhost")
.addSubjectAlternativeName("localhost") .addSubjectAlternativeName("localhost")
.rsa2048() .rsa2048()
.build() .build()
return@lazy HandshakeCertificates.Builder() return@lazy HandshakeCertificates
.Builder()
.heldCertificate(heldCertificate) .heldCertificate(heldCertificate)
.addTrustedCertificate(heldCertificate.certificate) .addTrustedCertificate(heldCertificate.certificate)
.build() .build()
@ -398,7 +392,8 @@ open class PlatformRule
} }
val provider = val provider =
Conscrypt.newProviderBuilder() Conscrypt
.newProviderBuilder()
.provideTrustManager(true) .provideTrustManager(true)
.build() .build()
Security.insertProviderAt(provider, 1) Security.insertProviderAt(provider, 1)
@ -499,7 +494,10 @@ open class PlatformRule
} }
val isCorrettoInstalled: Boolean = val isCorrettoInstalled: Boolean =
isCorrettoSupported && Security.getProviders() isCorrettoSupported &&
.first().name == AmazonCorrettoCryptoProvider.PROVIDER_NAME Security
.getProviders()
.first()
.name == AmazonCorrettoCryptoProvider.PROVIDER_NAME
} }
} }

View File

@ -23,7 +23,5 @@ object PlatformVersion {
} }
} }
fun getJvmSpecVersion(): String { fun getJvmSpecVersion(): String = System.getProperty("java.specification.version", "unknown")
return System.getProperty("java.specification.version", "unknown")
}
} }

View File

@ -38,9 +38,7 @@ class OkHttpClientTestRuleTest {
val thread = val thread =
object : Thread() { object : Thread() {
override fun run() { override fun run(): Unit = throw RuntimeException("boom!")
throw RuntimeException("boom!")
}
} }
thread.start() thread.start()
thread.join() thread.join()

View File

@ -66,13 +66,12 @@ fun String.decodeCertificatePem(): X509Certificate {
* *
* [rfc_7468]: https://tools.ietf.org/html/rfc7468 * [rfc_7468]: https://tools.ietf.org/html/rfc7468
*/ */
fun X509Certificate.certificatePem(): String { fun X509Certificate.certificatePem(): String =
return buildString { buildString {
append("-----BEGIN CERTIFICATE-----\n") append("-----BEGIN CERTIFICATE-----\n")
encodeBase64Lines(encoded.toByteString()) encodeBase64Lines(encoded.toByteString())
append("-----END CERTIFICATE-----\n") append("-----END CERTIFICATE-----\n")
} }
}
internal fun StringBuilder.encodeBase64Lines(data: ByteString) { internal fun StringBuilder.encodeBase64Lines(data: ByteString) {
val base64 = data.base64() val base64 = data.base64()

View File

@ -93,11 +93,10 @@ class HandshakeCertificates private constructor(
fun sslSocketFactory(): SSLSocketFactory = sslContext().socketFactory fun sslSocketFactory(): SSLSocketFactory = sslContext().socketFactory
fun sslContext(): SSLContext { fun sslContext(): SSLContext =
return Platform.get().newSSLContext().apply { Platform.get().newSSLContext().apply {
init(arrayOf<KeyManager>(keyManager), arrayOf<TrustManager>(trustManager), SecureRandom()) init(arrayOf<KeyManager>(keyManager), arrayOf<TrustManager>(trustManager), SecureRandom())
} }
}
class Builder { class Builder {
private var heldCertificate: HeldCertificate? = null private var heldCertificate: HeldCertificate? = null

View File

@ -159,13 +159,12 @@ class HeldCertificate(
* [rfc_5208]: https://tools.ietf.org/html/rfc5208 * [rfc_5208]: https://tools.ietf.org/html/rfc5208
* [rfc_7468]: https://tools.ietf.org/html/rfc7468 * [rfc_7468]: https://tools.ietf.org/html/rfc7468
*/ */
fun privateKeyPkcs8Pem(): String { fun privateKeyPkcs8Pem(): String =
return buildString { buildString {
append("-----BEGIN PRIVATE KEY-----\n") append("-----BEGIN PRIVATE KEY-----\n")
encodeBase64Lines(keyPair.private.encoded.toByteString()) encodeBase64Lines(keyPair.private.encoded.toByteString())
append("-----END PRIVATE KEY-----\n") append("-----END PRIVATE KEY-----\n")
} }
}
/** /**
* Returns the RSA private key encoded in [PKCS #1][rfc_8017] [PEM format][rfc_7468]. * Returns the RSA private key encoded in [PKCS #1][rfc_8017] [PEM format][rfc_7468].
@ -358,7 +357,9 @@ class HeldCertificate(
issuerKeyPair = signedBy!!.keyPair issuerKeyPair = signedBy!!.keyPair
issuer = issuer =
CertificateAdapters.rdnSequence.fromDer( CertificateAdapters.rdnSequence.fromDer(
signedBy!!.certificate.subjectX500Principal.encoded.toByteString(), signedBy!!
.certificate.subjectX500Principal.encoded
.toByteString(),
) )
} else { } else {
issuerKeyPair = subjectKeyPair issuerKeyPair = subjectKeyPair
@ -477,8 +478,8 @@ class HeldCertificate(
return result return result
} }
private fun signatureAlgorithm(signedByKeyPair: KeyPair): AlgorithmIdentifier { private fun signatureAlgorithm(signedByKeyPair: KeyPair): AlgorithmIdentifier =
return when (signedByKeyPair.private) { when (signedByKeyPair.private) {
is RSAPrivateKey -> is RSAPrivateKey ->
AlgorithmIdentifier( AlgorithmIdentifier(
algorithm = SHA256_WITH_RSA_ENCRYPTION, algorithm = SHA256_WITH_RSA_ENCRYPTION,
@ -490,14 +491,12 @@ class HeldCertificate(
parameters = ByteString.EMPTY, parameters = ByteString.EMPTY,
) )
} }
}
private fun generateKeyPair(): KeyPair { private fun generateKeyPair(): KeyPair =
return KeyPairGenerator.getInstance(keyAlgorithm).run { KeyPairGenerator.getInstance(keyAlgorithm).run {
initialize(keySize, SecureRandom()) initialize(keySize, SecureRandom())
generateKeyPair() generateKeyPair()
} }
}
companion object { companion object {
private const val DEFAULT_DURATION_MILLIS = 1000L * 60 * 60 * 24 // 24 hours. private const val DEFAULT_DURATION_MILLIS = 1000L * 60 * 60 * 24 // 24 hours.

View File

@ -35,12 +35,14 @@ object TlsUtil {
private val localhost: HandshakeCertificates by lazy { private val localhost: HandshakeCertificates by lazy {
// Generate a self-signed cert for the server to serve and the client to trust. // Generate a self-signed cert for the server to serve and the client to trust.
val heldCertificate = val heldCertificate =
HeldCertificate.Builder() HeldCertificate
.Builder()
.commonName("localhost") .commonName("localhost")
.addSubjectAlternativeName("localhost") .addSubjectAlternativeName("localhost")
.addSubjectAlternativeName("localhost.localdomain") .addSubjectAlternativeName("localhost.localdomain")
.build() .build()
return@lazy HandshakeCertificates.Builder() return@lazy HandshakeCertificates
.Builder()
.heldCertificate(heldCertificate) .heldCertificate(heldCertificate)
.addTrustedCertificate(heldCertificate.certificate) .addTrustedCertificate(heldCertificate.certificate)
.build() .build()
@ -110,10 +112,9 @@ object TlsUtil {
return result[0] as X509KeyManager return result[0] as X509KeyManager
} }
private fun newEmptyKeyStore(keyStoreType: String?): KeyStore { private fun newEmptyKeyStore(keyStoreType: String?): KeyStore =
return KeyStore.getInstance(keyStoreType ?: KeyStore.getDefaultType()).apply { KeyStore.getInstance(keyStoreType ?: KeyStore.getDefaultType()).apply {
val inputStream: InputStream? = null // By convention, 'null' creates an empty key store. val inputStream: InputStream? = null // By convention, 'null' creates an empty key store.
load(inputStream, password) load(inputStream, password)
} }
}
} }

View File

@ -42,9 +42,7 @@ internal object CertificateAdapters {
*/ */
internal val time: DerAdapter<Long> = internal val time: DerAdapter<Long> =
object : DerAdapter<Long> { object : DerAdapter<Long> {
override fun matches(header: DerHeader): Boolean { override fun matches(header: DerHeader): Boolean = Adapters.UTC_TIME.matches(header) || Adapters.GENERALIZED_TIME.matches(header)
return Adapters.UTC_TIME.matches(header) || Adapters.GENERALIZED_TIME.matches(header)
}
override fun fromDer(reader: DerReader): Long { override fun fromDer(reader: DerReader): Long {
val peekHeader = val peekHeader =
@ -215,17 +213,18 @@ internal object CertificateAdapters {
* that follows. * that follows.
*/ */
private val extensionValue: BasicDerAdapter<Any?> = private val extensionValue: BasicDerAdapter<Any?> =
Adapters.usingTypeHint { typeHint -> Adapters
when (typeHint) { .usingTypeHint { typeHint ->
ObjectIdentifiers.SUBJECT_ALTERNATIVE_NAME -> subjectAlternativeName when (typeHint) {
ObjectIdentifiers.BASIC_CONSTRAINTS -> basicConstraints ObjectIdentifiers.SUBJECT_ALTERNATIVE_NAME -> subjectAlternativeName
else -> null ObjectIdentifiers.BASIC_CONSTRAINTS -> basicConstraints
} else -> null
}.withExplicitBox( }
tagClass = Adapters.OCTET_STRING.tagClass, }.withExplicitBox(
tag = Adapters.OCTET_STRING.tag, tagClass = Adapters.OCTET_STRING.tagClass,
forceConstructed = false, tag = Adapters.OCTET_STRING.tag,
) forceConstructed = false,
)
/** /**
* ``` * ```

View File

@ -135,11 +135,10 @@ internal interface DerAdapter<T> {
} }
/** Returns an adapter that returns a set of values of this type. */ /** Returns an adapter that returns a set of values of this type. */
fun asSetOf(): BasicDerAdapter<List<T>> { fun asSetOf(): BasicDerAdapter<List<T>> =
return asSequenceOf( asSequenceOf(
name = "SET OF", name = "SET OF",
tagClass = DerHeader.TAG_CLASS_UNIVERSAL, tagClass = DerHeader.TAG_CLASS_UNIVERSAL,
tag = 17L, tag = 17L,
) )
}
} }

View File

@ -37,7 +37,9 @@ import okio.buffer
* [x690]: https://www.itu.int/rec/T-REC-X.690 * [x690]: https://www.itu.int/rec/T-REC-X.690
* [asn1_and_der]: https://letsencrypt.org/docs/a-warm-welcome-to-asn1-and-der/ * [asn1_and_der]: https://letsencrypt.org/docs/a-warm-welcome-to-asn1-and-der/
*/ */
internal class DerReader(source: Source) { internal class DerReader(
source: Source,
) {
private val countingSource: CountingSource = CountingSource(source) private val countingSource: CountingSource = CountingSource(source)
private val source: BufferedSource = countingSource.buffer() private val source: BufferedSource = countingSource.buffer()
@ -312,9 +314,7 @@ internal class DerReader(source: Source) {
} }
/** Read a value as bytes without interpretation of its contents. */ /** Read a value as bytes without interpretation of its contents. */
fun readUnknown(): ByteString { fun readUnknown(): ByteString = source.readByteString(bytesLeft)
return source.readByteString(bytesLeft)
}
override fun toString(): String = path.joinToString(separator = " / ") override fun toString(): String = path.joinToString(separator = " / ")
@ -334,7 +334,9 @@ internal class DerReader(source: Source) {
} }
/** A source that keeps track of how many bytes it's consumed. */ /** A source that keeps track of how many bytes it's consumed. */
private class CountingSource(source: Source) : ForwardingSource(source) { private class CountingSource(
source: Source,
) : ForwardingSource(source) {
var bytesRead = 0L var bytesRead = 0L
override fun read( override fun read(

View File

@ -20,7 +20,9 @@ import okio.Buffer
import okio.BufferedSink import okio.BufferedSink
import okio.ByteString import okio.ByteString
internal class DerWriter(sink: BufferedSink) { internal class DerWriter(
sink: BufferedSink,
) {
/** A stack of buffers that will be concatenated once we know the length of each. */ /** A stack of buffers that will be concatenated once we know the length of each. */
private val stack = mutableListOf(sink) private val stack = mutableListOf(sink)

View File

@ -61,38 +61,46 @@ class HandshakeCertificatesTest {
platform.assumeNotBouncyCastle() platform.assumeNotBouncyCastle()
val clientRoot = val clientRoot =
HeldCertificate.Builder() HeldCertificate
.Builder()
.certificateAuthority(1) .certificateAuthority(1)
.build() .build()
val clientIntermediate = val clientIntermediate =
HeldCertificate.Builder() HeldCertificate
.Builder()
.certificateAuthority(0) .certificateAuthority(0)
.signedBy(clientRoot) .signedBy(clientRoot)
.build() .build()
val clientCertificate = val clientCertificate =
HeldCertificate.Builder() HeldCertificate
.Builder()
.signedBy(clientIntermediate) .signedBy(clientIntermediate)
.build() .build()
val serverRoot = val serverRoot =
HeldCertificate.Builder() HeldCertificate
.Builder()
.certificateAuthority(1) .certificateAuthority(1)
.build() .build()
val serverIntermediate = val serverIntermediate =
HeldCertificate.Builder() HeldCertificate
.Builder()
.certificateAuthority(0) .certificateAuthority(0)
.signedBy(serverRoot) .signedBy(serverRoot)
.build() .build()
val serverCertificate = val serverCertificate =
HeldCertificate.Builder() HeldCertificate
.Builder()
.signedBy(serverIntermediate) .signedBy(serverIntermediate)
.build() .build()
val server = val server =
HandshakeCertificates.Builder() HandshakeCertificates
.Builder()
.addTrustedCertificate(clientRoot.certificate) .addTrustedCertificate(clientRoot.certificate)
.heldCertificate(serverCertificate, serverIntermediate.certificate) .heldCertificate(serverCertificate, serverIntermediate.certificate)
.build() .build()
val client = val client =
HandshakeCertificates.Builder() HandshakeCertificates
.Builder()
.addTrustedCertificate(serverRoot.certificate) .addTrustedCertificate(serverRoot.certificate)
.heldCertificate(clientCertificate, clientIntermediate.certificate) .heldCertificate(clientCertificate, clientIntermediate.certificate)
.build() .build()
@ -113,20 +121,24 @@ class HandshakeCertificatesTest {
@Test fun keyManager() { @Test fun keyManager() {
val root = val root =
HeldCertificate.Builder() HeldCertificate
.Builder()
.certificateAuthority(1) .certificateAuthority(1)
.build() .build()
val intermediate = val intermediate =
HeldCertificate.Builder() HeldCertificate
.Builder()
.certificateAuthority(0) .certificateAuthority(0)
.signedBy(root) .signedBy(root)
.build() .build()
val certificate = val certificate =
HeldCertificate.Builder() HeldCertificate
.Builder()
.signedBy(intermediate) .signedBy(intermediate)
.build() .build()
val handshakeCertificates = val handshakeCertificates =
HandshakeCertificates.Builder() HandshakeCertificates
.Builder()
.addTrustedCertificate(root.certificate) // BouncyCastle requires at least one .addTrustedCertificate(root.certificate) // BouncyCastle requires at least one
.heldCertificate(certificate, intermediate.certificate) .heldCertificate(certificate, intermediate.certificate)
.build() .build()
@ -140,7 +152,8 @@ class HandshakeCertificatesTest {
@Test fun platformTrustedCertificates() { @Test fun platformTrustedCertificates() {
val handshakeCertificates = val handshakeCertificates =
HandshakeCertificates.Builder() HandshakeCertificates
.Builder()
.addPlatformTrustedCertificates() .addPlatformTrustedCertificates()
.build() .build()
val acceptedIssuers = handshakeCertificates.trustManager.acceptedIssuers val acceptedIssuers = handshakeCertificates.trustManager.acceptedIssuers

View File

@ -60,7 +60,8 @@ class HeldCertificateTest {
fun customInterval() { fun customInterval() {
// 5 seconds starting on 1970-01-01. // 5 seconds starting on 1970-01-01.
val heldCertificate = val heldCertificate =
HeldCertificate.Builder() HeldCertificate
.Builder()
.validityInterval(5000L, 10000L) .validityInterval(5000L, 10000L)
.build() .build()
val certificate = heldCertificate.certificate val certificate = heldCertificate.certificate
@ -72,7 +73,8 @@ class HeldCertificateTest {
fun customDuration() { fun customDuration() {
val now = System.currentTimeMillis() val now = System.currentTimeMillis()
val heldCertificate = val heldCertificate =
HeldCertificate.Builder() HeldCertificate
.Builder()
.duration(5, TimeUnit.SECONDS) .duration(5, TimeUnit.SECONDS)
.build() .build()
val certificate = heldCertificate.certificate val certificate = heldCertificate.certificate
@ -87,7 +89,8 @@ class HeldCertificateTest {
@Test @Test
fun subjectAlternativeNames() { fun subjectAlternativeNames() {
val heldCertificate = val heldCertificate =
HeldCertificate.Builder() HeldCertificate
.Builder()
.addSubjectAlternativeName("1.1.1.1") .addSubjectAlternativeName("1.1.1.1")
.addSubjectAlternativeName("cash.app") .addSubjectAlternativeName("cash.app")
.build() .build()
@ -101,7 +104,8 @@ class HeldCertificateTest {
@Test @Test
fun commonName() { fun commonName() {
val heldCertificate = val heldCertificate =
HeldCertificate.Builder() HeldCertificate
.Builder()
.commonName("cash.app") .commonName("cash.app")
.build() .build()
val certificate = heldCertificate.certificate val certificate = heldCertificate.certificate
@ -111,7 +115,8 @@ class HeldCertificateTest {
@Test @Test
fun organizationalUnit() { fun organizationalUnit() {
val heldCertificate = val heldCertificate =
HeldCertificate.Builder() HeldCertificate
.Builder()
.commonName("cash.app") .commonName("cash.app")
.organizationalUnit("cash") .organizationalUnit("cash")
.build() .build()
@ -130,8 +135,7 @@ class HeldCertificateTest {
"MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCApFHhtrLan28q+oMolZuaTfWBA0V5aM" + "MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCApFHhtrLan28q+oMolZuaTfWBA0V5aM" +
"Ivq32BsloQu6LlvX1wJ4YEoUCjDlPOtpht7XLbUmBnbIzN89XK4UJVM6Sqp3K88Km8z7gMrdrfTom/274wL25fICR+" + "Ivq32BsloQu6LlvX1wJ4YEoUCjDlPOtpht7XLbUmBnbIzN89XK4UJVM6Sqp3K88Km8z7gMrdrfTom/274wL25fICR+" +
"yDEQ5fUVYBmJAKXZF1aoI0mIoEx0xFsQhIJ637v2MxJDupd61wIDAQAB" "yDEQ5fUVYBmJAKXZF1aoI0mIoEx0xFsQhIJ637v2MxJDupd61wIDAQAB"
) ).decodeBase64()!!
.decodeBase64()!!
val publicKey = val publicKey =
keyFactory.generatePublic( keyFactory.generatePublic(
X509EncodedKeySpec(publicKeyBytes.toByteArray()), X509EncodedKeySpec(publicKeyBytes.toByteArray()),
@ -154,7 +158,8 @@ class HeldCertificateTest {
PKCS8EncodedKeySpec(privateKeyBytes.toByteArray()), PKCS8EncodedKeySpec(privateKeyBytes.toByteArray()),
) )
val heldCertificate = val heldCertificate =
HeldCertificate.Builder() HeldCertificate
.Builder()
.keyPair(publicKey, privateKey) .keyPair(publicKey, privateKey)
.commonName("cash.app") .commonName("cash.app")
.validityInterval(0L, 1000L) .validityInterval(0L, 1000L)
@ -222,12 +227,14 @@ class HeldCertificateTest {
@Test @Test
fun ecdsaSignedByRsa() { fun ecdsaSignedByRsa() {
val root = val root =
HeldCertificate.Builder() HeldCertificate
.Builder()
.certificateAuthority(0) .certificateAuthority(0)
.rsa2048() .rsa2048()
.build() .build()
val leaf = val leaf =
HeldCertificate.Builder() HeldCertificate
.Builder()
.certificateAuthority(0) .certificateAuthority(0)
.ecdsa256() .ecdsa256()
.signedBy(root) .signedBy(root)
@ -239,12 +246,14 @@ class HeldCertificateTest {
@Test @Test
fun rsaSignedByEcdsa() { fun rsaSignedByEcdsa() {
val root = val root =
HeldCertificate.Builder() HeldCertificate
.Builder()
.certificateAuthority(0) .certificateAuthority(0)
.ecdsa256() .ecdsa256()
.build() .build()
val leaf = val leaf =
HeldCertificate.Builder() HeldCertificate
.Builder()
.certificateAuthority(0) .certificateAuthority(0)
.rsa2048() .rsa2048()
.signedBy(root) .signedBy(root)

View File

@ -644,8 +644,7 @@ internal class DerCertificatesTest {
"92bc50e78097f2e6a9768997e22f0d70000017173d326b3000004030046304402207" + "92bc50e78097f2e6a9768997e22f0d70000017173d326b3000004030046304402207" +
"e112c029b93e0db15d1de40eddb9aa7a55eeb4b48ce8c94ddf8ed71331e931e02207" + "e112c029b93e0db15d1de40eddb9aa7a55eeb4b48ce8c94ddf8ed71331e931e02207" +
"8281e3c39c8e643b901c2bc6c470aa0ed3ad01bf17f0f207dc0f8a5ab541a70" "8281e3c39c8e643b901c2bc6c470aa0ed3ad01bf17f0f207dc0f8a5ab541a70"
) ).decodeHex(),
.decodeHex(),
), ),
Extension( Extension(
id = keyUsage, id = keyUsage,
@ -724,7 +723,8 @@ internal class DerCertificatesTest {
@Test @Test
fun `certificate attributes`() { fun `certificate attributes`() {
val certificate = val certificate =
HeldCertificate.Builder() HeldCertificate
.Builder()
.certificateAuthority(3) .certificateAuthority(3)
.commonName("Jurassic Park") .commonName("Jurassic Park")
.organizationalUnit("Gene Research") .organizationalUnit("Gene Research")
@ -767,7 +767,8 @@ internal class DerCertificatesTest {
@Test @Test
fun `missing subject alternative names`() { fun `missing subject alternative names`() {
val certificate = val certificate =
HeldCertificate.Builder() HeldCertificate
.Builder()
.certificateAuthority(3) .certificateAuthority(3)
.commonName("Jurassic Park") .commonName("Jurassic Park")
.organizationalUnit("Gene Research") .organizationalUnit("Gene Research")
@ -817,7 +818,8 @@ internal class DerCertificatesTest {
val privateKey = keyFactory.generatePrivate(PKCS8EncodedKeySpec(privateKeyBytes.toByteArray())) val privateKey = keyFactory.generatePrivate(PKCS8EncodedKeySpec(privateKeyBytes.toByteArray()))
val certificate = val certificate =
HeldCertificate.Builder() HeldCertificate
.Builder()
.keyPair(publicKey, privateKey) .keyPair(publicKey, privateKey)
.build() .build()
@ -930,8 +932,7 @@ internal class DerCertificatesTest {
"5f5pk8F3wcXzAeVw06z3k1IB41Tu6MX+CyPU+TeudRlz+wV8b0zDvK+EnRKCCbptVFj1Bkt8lQ4JfcnhAkAk2Y3G" + "5f5pk8F3wcXzAeVw06z3k1IB41Tu6MX+CyPU+TeudRlz+wV8b0zDvK+EnRKCCbptVFj1Bkt8lQ4JfcnhAkAk2Y3G" +
"z+HySrkcT7Cg12M/NkdUQnZe3jr88pt/+IGNwomc6Wt/mJ4fcWONTkGMcfOZff1NQeNXDAZ6941XCsIVAkASOg02" + "z+HySrkcT7Cg12M/NkdUQnZe3jr88pt/+IGNwomc6Wt/mJ4fcWONTkGMcfOZff1NQeNXDAZ6941XCsIVAkASOg02" +
"PlVHLidU7mIE65swMM5/RNhS4aFjez/MwxFNOHaxc9VgCwYPXCLOtdf7AVovdyG0XWgbUXH+NyxKwboE" "PlVHLidU7mIE65swMM5/RNhS4aFjez/MwxFNOHaxc9VgCwYPXCLOtdf7AVovdyG0XWgbUXH+NyxKwboE"
) ).decodeBase64()!!
.decodeBase64()!!
val decoded = CertificateAdapters.privateKeyInfo.fromDer(privateKeyInfoByteString) val decoded = CertificateAdapters.privateKeyInfo.fromDer(privateKeyInfoByteString)
@ -946,12 +947,14 @@ internal class DerCertificatesTest {
@Test @Test
fun `RSA issuer and signature`() { fun `RSA issuer and signature`() {
val root = val root =
HeldCertificate.Builder() HeldCertificate
.Builder()
.certificateAuthority(0) .certificateAuthority(0)
.rsa2048() .rsa2048()
.build() .build()
val certificate = val certificate =
HeldCertificate.Builder() HeldCertificate
.Builder()
.signedBy(root) .signedBy(root)
.rsa2048() .rsa2048()
.build() .build()
@ -982,12 +985,14 @@ internal class DerCertificatesTest {
@Test @Test
fun `EC issuer and signature`() { fun `EC issuer and signature`() {
val root = val root =
HeldCertificate.Builder() HeldCertificate
.Builder()
.certificateAuthority(0) .certificateAuthority(0)
.ecdsa256() .ecdsa256()
.build() .build()
val certificate = val certificate =
HeldCertificate.Builder() HeldCertificate
.Builder()
.signedBy(root) .signedBy(root)
.ecdsa256() .ecdsa256()
.build() .build()
@ -1042,8 +1047,7 @@ internal class DerCertificatesTest {
"LXLJCzNLJEJcgV9TjbVu33eQR23yMuXD+cZsqLMF+L5IIM47W8dlwKJvMy0xs7Jb1S3NOIhcoVu+XPzRsgKv8Yi2" + "LXLJCzNLJEJcgV9TjbVu33eQR23yMuXD+cZsqLMF+L5IIM47W8dlwKJvMy0xs7Jb1S3NOIhcoVu+XPzRsgKv8Yi2" +
"B6l278RfzegiCx4vYJv0pBjFzizEiFH9bWTYIOlIJJSM57hoICgjCTS8BoEgndwWIyc/nEmlYaUwmCo9QynY+UmW" + "B6l278RfzegiCx4vYJv0pBjFzizEiFH9bWTYIOlIJJSM57hoICgjCTS8BoEgndwWIyc/nEmlYaUwmCo9QynY+UmW" +
"1WPWmVITEJPMdMK6AZqvvaWmuHJ6/vURaz+Hoc5D3z0yJDDCkv52bXV04ZoF6cbcWry7JvNA+djvay/4BRR4SZQ==" "1WPWmVITEJPMdMK6AZqvvaWmuHJ6/vURaz+Hoc5D3z0yJDDCkv52bXV04ZoF6cbcWry7JvNA+djvay/4BRR4SZQ=="
) ).decodeBase64()!!
.decodeBase64()!!
val decoded = CertificateAdapters.certificate.fromDer(certificateByteString) val decoded = CertificateAdapters.certificate.fromDer(certificateByteString)
assertThat(decoded.subjectAlternativeNames).isEqualTo( assertThat(decoded.subjectAlternativeNames).isEqualTo(
@ -1079,17 +1083,15 @@ internal class DerCertificatesTest {
} }
/** Returns a byte string that differs from this one by one bit. */ /** Returns a byte string that differs from this one by one bit. */
private fun ByteString.offByOneBit(): ByteString { private fun ByteString.offByOneBit(): ByteString =
return Buffer() Buffer()
.write(this, 0, size - 1) .write(this, 0, size - 1)
.writeByte(this[size - 1].toInt() xor 1) .writeByte(this[size - 1].toInt() xor 1)
.readByteString() .readByteString()
}
private fun date(s: String): Date { private fun date(s: String): Date =
return SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSZ").run { SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSZ").run {
timeZone = TimeZone.getTimeZone("GMT") timeZone = TimeZone.getTimeZone("GMT")
parse(s) parse(s)
} }
}
} }

View File

@ -974,10 +974,9 @@ internal class DerTest {
} }
} }
private fun date(s: String): Date { private fun date(s: String): Date =
return SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSZ").run { SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSZ").run {
timeZone = TimeZone.getTimeZone("GMT") timeZone = TimeZone.getTimeZone("GMT")
parse(s) parse(s)
} }
}
} }

View File

@ -38,7 +38,9 @@ import okhttp3.internal.tls.TrustRootIndex
/** Android 10+ (API 29+). */ /** Android 10+ (API 29+). */
@SuppressSignatureCheck @SuppressSignatureCheck
class Android10Platform : Platform(), ContextAwarePlatform { class Android10Platform :
Platform(),
ContextAwarePlatform {
override var applicationContext: Context? = null override var applicationContext: Context? = null
private val socketAdapters = private val socketAdapters =
@ -51,7 +53,8 @@ class Android10Platform : Platform(), ContextAwarePlatform {
).filter { it.isSupported() } ).filter { it.isSupported() }
override fun trustManager(sslSocketFactory: SSLSocketFactory): X509TrustManager? = override fun trustManager(sslSocketFactory: SSLSocketFactory): X509TrustManager? =
socketAdapters.find { it.matchesSocketFactory(sslSocketFactory) } socketAdapters
.find { it.matchesSocketFactory(sslSocketFactory) }
?.trustManager(sslSocketFactory) ?.trustManager(sslSocketFactory)
override fun newSSLContext(): SSLContext { override fun newSSLContext(): SSLContext {
@ -72,7 +75,8 @@ class Android10Platform : Platform(), ContextAwarePlatform {
protocols: List<Protocol>, protocols: List<Protocol>,
) { ) {
// No TLS extensions if the socket class is custom. // No TLS extensions if the socket class is custom.
socketAdapters.find { it.matchesSocket(sslSocket) } socketAdapters
.find { it.matchesSocket(sslSocket) }
?.configureTlsExtensions(sslSocket, hostname, protocols) ?.configureTlsExtensions(sslSocket, hostname, protocols)
} }
@ -80,13 +84,12 @@ class Android10Platform : Platform(), ContextAwarePlatform {
// No TLS extensions if the socket class is custom. // No TLS extensions if the socket class is custom.
socketAdapters.find { it.matchesSocket(sslSocket) }?.getSelectedProtocol(sslSocket) socketAdapters.find { it.matchesSocket(sslSocket) }?.getSelectedProtocol(sslSocket)
override fun getStackTraceForCloseable(closer: String): Any? { override fun getStackTraceForCloseable(closer: String): Any? =
return if (Build.VERSION.SDK_INT >= 30) { if (Build.VERSION.SDK_INT >= 30) {
CloseGuard().apply { open(closer) } CloseGuard().apply { open(closer) }
} else { } else {
super.getStackTraceForCloseable(closer) super.getStackTraceForCloseable(closer)
} }
}
override fun logCloseableLeak( override fun logCloseableLeak(
message: String, message: String,

View File

@ -45,7 +45,9 @@ import okhttp3.internal.tls.TrustRootIndex
/** Android 5 to 9 (API 21 to 28). */ /** Android 5 to 9 (API 21 to 28). */
@SuppressSignatureCheck @SuppressSignatureCheck
class AndroidPlatform : Platform(), ContextAwarePlatform { class AndroidPlatform :
Platform(),
ContextAwarePlatform {
override var applicationContext: Context? = null override var applicationContext: Context? = null
private val socketAdapters = private val socketAdapters =
@ -83,7 +85,8 @@ class AndroidPlatform : Platform(), ContextAwarePlatform {
} }
override fun trustManager(sslSocketFactory: SSLSocketFactory): X509TrustManager? = override fun trustManager(sslSocketFactory: SSLSocketFactory): X509TrustManager? =
socketAdapters.find { it.matchesSocketFactory(sslSocketFactory) } socketAdapters
.find { it.matchesSocketFactory(sslSocketFactory) }
?.trustManager(sslSocketFactory) ?.trustManager(sslSocketFactory)
override fun configureTlsExtensions( override fun configureTlsExtensions(
@ -92,7 +95,8 @@ class AndroidPlatform : Platform(), ContextAwarePlatform {
protocols: List<@JvmSuppressWildcards Protocol>, protocols: List<@JvmSuppressWildcards Protocol>,
) { ) {
// No TLS extensions if the socket class is custom. // No TLS extensions if the socket class is custom.
socketAdapters.find { it.matchesSocket(sslSocket) } socketAdapters
.find { it.matchesSocket(sslSocket) }
?.configureTlsExtensions(sslSocket, hostname, protocols) ?.configureTlsExtensions(sslSocket, hostname, protocols)
} }
@ -156,8 +160,8 @@ class AndroidPlatform : Platform(), ContextAwarePlatform {
private val trustManager: X509TrustManager, private val trustManager: X509TrustManager,
private val findByIssuerAndSignatureMethod: Method, private val findByIssuerAndSignatureMethod: Method,
) : TrustRootIndex { ) : TrustRootIndex {
override fun findByIssuerAndSignature(cert: X509Certificate): X509Certificate? { override fun findByIssuerAndSignature(cert: X509Certificate): X509Certificate? =
return try { try {
val trustAnchor = val trustAnchor =
findByIssuerAndSignatureMethod.invoke( findByIssuerAndSignatureMethod.invoke(
trustManager, trustManager,
@ -169,7 +173,6 @@ class AndroidPlatform : Platform(), ContextAwarePlatform {
} catch (_: InvocationTargetException) { } catch (_: InvocationTargetException) {
null null
} }
}
} }
companion object { companion object {

Some files were not shown because too many files have changed in this diff Show More