1
0
mirror of https://github.com/square/okhttp.git synced 2025-07-31 05:04:26 +03:00

Reformat with Spotless (#8180)

* Enable spotless

* Run spotlessApply

* Fixup trimMargin

* Re-run spotlessApply
This commit is contained in:
Jesse Wilson
2024-01-07 20:13:22 -05:00
committed by GitHub
parent 0e312d7804
commit a228fd64cc
442 changed files with 24992 additions and 18542 deletions

View File

@ -2,9 +2,10 @@ root = true
[*] [*]
indent_size = 2 indent_size = 2
ij_continuation_indent_size = 2
charset = utf-8 charset = utf-8
trim_trailing_whitespace = true trim_trailing_whitespace = true
insert_final_newline = true insert_final_newline = true
[*.{kt, kts}] [*.{kt, kts}]
kotlin_imports_layout = ascii ij_kotlin_imports_layout = *

View File

@ -24,7 +24,6 @@ import org.junit.Test
* Run with "./gradlew :android-test-app:connectedCheck -PandroidBuild=true" and make sure ANDROID_SDK_ROOT is set. * Run with "./gradlew :android-test-app:connectedCheck -PandroidBuild=true" and make sure ANDROID_SDK_ROOT is set.
*/ */
class PublicSuffixDatabaseTest { class PublicSuffixDatabaseTest {
@Test @Test
fun testTopLevelDomain() { fun testTopLevelDomain() {
assertThat("https://www.google.com/robots.txt".toHttpUrl().topPrivateDomain()).isEqualTo("google.com") assertThat("https://www.google.com/robots.txt".toHttpUrl().topPrivateDomain()).isEqualTo("google.com")

View File

@ -19,20 +19,13 @@ import android.os.Bundle
import androidx.activity.ComponentActivity import androidx.activity.ComponentActivity
import okhttp3.Call import okhttp3.Call
import okhttp3.Callback import okhttp3.Callback
import okhttp3.CookieJar
import okhttp3.HttpUrl.Companion.toHttpUrl import okhttp3.HttpUrl.Companion.toHttpUrl
import okhttp3.OkHttpClient import okhttp3.OkHttpClient
import okhttp3.Request import okhttp3.Request
import okhttp3.Response import okhttp3.Response
import okhttp3.internal.publicsuffix.PublicSuffixDatabase
import okio.FileSystem
import okio.IOException import okio.IOException
import okio.Path.Companion.toPath
import java.net.CookieHandler
import java.net.URI
class MainActivity : ComponentActivity() { class MainActivity : ComponentActivity() {
override fun onCreate(savedInstanceState: Bundle?) { override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState) super.onCreate(savedInstanceState)
@ -41,15 +34,23 @@ class MainActivity : ComponentActivity() {
val url = "https://github.com/square/okhttp".toHttpUrl() val url = "https://github.com/square/okhttp".toHttpUrl()
println(url.topPrivateDomain()) println(url.topPrivateDomain())
client.newCall(Request(url)).enqueue(object : Callback { client.newCall(Request(url)).enqueue(
override fun onFailure(call: Call, e: IOException) { object : Callback {
println("failed: $e") override fun onFailure(
} call: Call,
e: IOException,
) {
println("failed: $e")
}
override fun onResponse(call: Call, response: Response) { override fun onResponse(
println("response: ${response.code}") call: Call,
response.close() response: Response,
} ) {
}) println("response: ${response.code}")
response.close()
}
},
)
} }
} }

View File

@ -17,5 +17,4 @@ package okhttp.android.testapp
import android.app.Application import android.app.Application
class TestApplication: Application() { class TestApplication : Application()
}

View File

@ -21,6 +21,25 @@ import com.google.android.gms.common.GooglePlayServicesNotAvailableException
import com.google.android.gms.security.ProviderInstaller import com.google.android.gms.security.ProviderInstaller
import com.squareup.moshi.Moshi import com.squareup.moshi.Moshi
import com.squareup.moshi.kotlin.reflect.KotlinJsonAdapterFactory import com.squareup.moshi.kotlin.reflect.KotlinJsonAdapterFactory
import java.io.IOException
import java.net.InetAddress
import java.net.UnknownHostException
import java.security.KeyStore
import java.security.SecureRandom
import java.security.Security
import java.security.cert.Certificate
import java.security.cert.CertificateException
import java.security.cert.X509Certificate
import java.util.concurrent.atomic.AtomicInteger
import java.util.logging.Handler
import java.util.logging.Level
import java.util.logging.LogRecord
import java.util.logging.Logger
import javax.net.ssl.HostnameVerifier
import javax.net.ssl.SSLPeerUnverifiedException
import javax.net.ssl.SSLSocket
import javax.net.ssl.TrustManagerFactory
import javax.net.ssl.X509TrustManager
import mockwebserver3.MockResponse import mockwebserver3.MockResponse
import mockwebserver3.MockWebServer import mockwebserver3.MockWebServer
import mockwebserver3.junit5.internal.MockWebServerExtension import mockwebserver3.junit5.internal.MockWebServerExtension
@ -31,6 +50,7 @@ import okhttp3.Connection
import okhttp3.DelegatingSSLSocket import okhttp3.DelegatingSSLSocket
import okhttp3.DelegatingSSLSocketFactory import okhttp3.DelegatingSSLSocketFactory
import okhttp3.EventListener import okhttp3.EventListener
import okhttp3.Headers
import okhttp3.HttpUrl.Companion.toHttpUrl import okhttp3.HttpUrl.Companion.toHttpUrl
import okhttp3.OkHttpClient import okhttp3.OkHttpClient
import okhttp3.OkHttpClientTestRule import okhttp3.OkHttpClientTestRule
@ -59,33 +79,13 @@ import org.junit.jupiter.api.Assertions.assertNull
import org.junit.jupiter.api.Assertions.assertTrue import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Assertions.fail import org.junit.jupiter.api.Assertions.fail
import org.junit.jupiter.api.Assumptions.assumeTrue import org.junit.jupiter.api.Assumptions.assumeTrue
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Disabled import org.junit.jupiter.api.Disabled
import org.junit.jupiter.api.Tag import org.junit.jupiter.api.Tag
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
import org.junit.jupiter.api.extension.ExtendWith import org.junit.jupiter.api.extension.ExtendWith
import org.junit.jupiter.api.extension.RegisterExtension import org.junit.jupiter.api.extension.RegisterExtension
import org.opentest4j.TestAbortedException import org.opentest4j.TestAbortedException
import java.io.IOException
import java.net.InetAddress
import java.net.UnknownHostException
import java.security.KeyStore
import java.security.SecureRandom
import java.security.Security
import java.security.cert.Certificate
import java.security.cert.CertificateException
import java.security.cert.X509Certificate
import java.util.concurrent.atomic.AtomicInteger
import java.util.logging.Handler
import java.util.logging.Level
import java.util.logging.LogRecord
import java.util.logging.Logger
import javax.net.ssl.HostnameVerifier
import javax.net.ssl.SSLPeerUnverifiedException
import javax.net.ssl.SSLSocket
import javax.net.ssl.TrustManagerFactory
import javax.net.ssl.X509TrustManager
import okhttp3.Headers
import org.junit.jupiter.api.BeforeEach
/** /**
* Run with "./gradlew :android-test:connectedCheck -PandroidBuild=true" and make sure ANDROID_SDK_ROOT is set. * Run with "./gradlew :android-test:connectedCheck -PandroidBuild=true" and make sure ANDROID_SDK_ROOT is set.
@ -93,22 +93,25 @@ import org.junit.jupiter.api.BeforeEach
@ExtendWith(MockWebServerExtension::class) @ExtendWith(MockWebServerExtension::class)
@Tag("Slow") @Tag("Slow")
class OkHttpTest { class OkHttpTest {
@Suppress("RedundantVisibilityModifier")
@JvmField
@RegisterExtension
public val platform = PlatformRule()
@Suppress("RedundantVisibilityModifier") @Suppress("RedundantVisibilityModifier")
@JvmField @JvmField
@RegisterExtension public val platform = PlatformRule() @RegisterExtension
public val clientTestRule =
@Suppress("RedundantVisibilityModifier") OkHttpClientTestRule().apply {
@JvmField logger = Logger.getLogger(OkHttpTest::class.java.name)
@RegisterExtension public val clientTestRule = OkHttpClientTestRule().apply { }
logger = Logger.getLogger(OkHttpTest::class.java.name)
}
private var client: OkHttpClient = clientTestRule.newClient() private var client: OkHttpClient = clientTestRule.newClient()
private val moshi = Moshi.Builder() private val moshi =
.add(KotlinJsonAdapterFactory()) Moshi.Builder()
.build() .add(KotlinJsonAdapterFactory())
.build()
private val handshakeCertificates = localhost() private val handshakeCertificates = localhost()
@ -136,18 +139,20 @@ class OkHttpTest {
val request = Request.Builder().url("https://api.twitter.com/robots.txt").build() val request = Request.Builder().url("https://api.twitter.com/robots.txt").build()
val clientCertificates = HandshakeCertificates.Builder() val clientCertificates =
.addPlatformTrustedCertificates() HandshakeCertificates.Builder()
.apply { .addPlatformTrustedCertificates()
if (Build.VERSION.SDK_INT >= 24) { .apply {
addInsecureHost(server.hostName) if (Build.VERSION.SDK_INT >= 24) {
addInsecureHost(server.hostName)
}
} }
} .build()
.build()
client = client.newBuilder() client =
.sslSocketFactory(clientCertificates.sslSocketFactory(), clientCertificates.trustManager) client.newBuilder()
.build() .sslSocketFactory(clientCertificates.sslSocketFactory(), clientCertificates.trustManager)
.build()
val response = client.newCall(request).execute() val response = client.newCall(request).execute()
@ -184,22 +189,29 @@ class OkHttpTest {
var socketClass: String? = null var socketClass: String? = null
val clientCertificates = HandshakeCertificates.Builder() val clientCertificates =
.addPlatformTrustedCertificates() HandshakeCertificates.Builder()
.addInsecureHost(server.hostName) .addPlatformTrustedCertificates()
.build() .addInsecureHost(server.hostName)
.build()
// Need fresh client to reset sslSocketFactoryOrNull // Need fresh client to reset sslSocketFactoryOrNull
client = OkHttpClient.Builder() client =
.eventListenerFactory( OkHttpClient.Builder()
clientTestRule.wrap(object : EventListener() { .eventListenerFactory(
override fun connectionAcquired(call: Call, connection: Connection) { clientTestRule.wrap(
socketClass = connection.socket().javaClass.name object : EventListener() {
} override fun connectionAcquired(
}) call: Call,
) connection: Connection,
.sslSocketFactory(clientCertificates.sslSocketFactory(), clientCertificates.trustManager) ) {
.build() socketClass = connection.socket().javaClass.name
}
},
),
)
.sslSocketFactory(clientCertificates.sslSocketFactory(), clientCertificates.trustManager)
.build()
val response = client.newCall(request).execute() val response = client.newCall(request).execute()
@ -240,26 +252,33 @@ class OkHttpTest {
throw TestAbortedException("Google Play Services not available", gpsnae) throw TestAbortedException("Google Play Services not available", gpsnae)
} }
val clientCertificates = HandshakeCertificates.Builder() val clientCertificates =
.addPlatformTrustedCertificates() HandshakeCertificates.Builder()
.addInsecureHost(server.hostName) .addPlatformTrustedCertificates()
.build() .addInsecureHost(server.hostName)
.build()
val request = Request.Builder().url("https://facebook.com/robots.txt").build() val request = Request.Builder().url("https://facebook.com/robots.txt").build()
var socketClass: String? = null var socketClass: String? = null
// Need fresh client to reset sslSocketFactoryOrNull // Need fresh client to reset sslSocketFactoryOrNull
client = OkHttpClient.Builder() client =
.eventListenerFactory( OkHttpClient.Builder()
clientTestRule.wrap(object : EventListener() { .eventListenerFactory(
override fun connectionAcquired(call: Call, connection: Connection) { clientTestRule.wrap(
socketClass = connection.socket().javaClass.name object : EventListener() {
} override fun connectionAcquired(
}) call: Call,
) connection: Connection,
.sslSocketFactory(clientCertificates.sslSocketFactory(), clientCertificates.trustManager) ) {
.build() socketClass = connection.socket().javaClass.name
}
},
),
)
.sslSocketFactory(clientCertificates.sslSocketFactory(), clientCertificates.trustManager)
.build()
val response = client.newCall(request).execute() val response = client.newCall(request).execute()
@ -298,24 +317,31 @@ class OkHttpTest {
var socketClass: String? = null var socketClass: String? = null
val clientCertificates = HandshakeCertificates.Builder() val clientCertificates =
.addPlatformTrustedCertificates().apply { HandshakeCertificates.Builder()
if (Build.VERSION.SDK_INT >= 24) { .addPlatformTrustedCertificates().apply {
addInsecureHost(server.hostName) if (Build.VERSION.SDK_INT >= 24) {
} addInsecureHost(server.hostName)
}
.build()
client = client.newBuilder()
.eventListenerFactory(
clientTestRule.wrap(object : EventListener() {
override fun connectionAcquired(call: Call, connection: Connection) {
socketClass = connection.socket().javaClass.name
} }
}) }
) .build()
.sslSocketFactory(clientCertificates.sslSocketFactory(), clientCertificates.trustManager)
.build() client =
client.newBuilder()
.eventListenerFactory(
clientTestRule.wrap(
object : EventListener() {
override fun connectionAcquired(
call: Call,
connection: Connection,
) {
socketClass = connection.socket().javaClass.name
}
},
),
)
.sslSocketFactory(clientCertificates.sslSocketFactory(), clientCertificates.trustManager)
.build()
val response = client.newCall(request).execute() val response = client.newCall(request).execute()
@ -372,7 +398,7 @@ class OkHttpTest {
val tls_version: String, val tls_version: String,
val able_to_detect_n_minus_one_splitting: Boolean, val able_to_detect_n_minus_one_splitting: Boolean,
val insecure_cipher_suites: Map<String, List<String>>, val insecure_cipher_suites: Map<String, List<String>>,
val given_cipher_suites: List<String>? val given_cipher_suites: List<String>?,
) )
@Test @Test
@ -384,9 +410,10 @@ class OkHttpTest {
val response = client.newCall(request).execute() val response = client.newCall(request).execute()
val results = response.use { val results =
moshi.adapter(HowsMySslResults::class.java).fromJson(response.body.string())!! response.use {
} moshi.adapter(HowsMySslResults::class.java).fromJson(response.body.string())!!
}
Platform.get().log("results $results", Platform.WARN) Platform.get().log("results $results", Platform.WARN)
@ -417,7 +444,7 @@ class OkHttpTest {
assertTrue(tlsVersion == TlsVersion.TLS_1_2 || tlsVersion == TlsVersion.TLS_1_3) assertTrue(tlsVersion == TlsVersion.TLS_1_2 || tlsVersion == TlsVersion.TLS_1_3)
assertEquals( assertEquals(
"CN=localhost", "CN=localhost",
(response.handshake!!.peerCertificates.first() as X509Certificate).subjectDN.name (response.handshake!!.peerCertificates.first() as X509Certificate).subjectDN.name,
) )
} }
} }
@ -426,9 +453,10 @@ class OkHttpTest {
fun testCertificatePinningFailure() { fun testCertificatePinningFailure() {
enableTls() enableTls()
val certificatePinner = CertificatePinner.Builder() val certificatePinner =
.add(server.hostName, "sha256/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=") CertificatePinner.Builder()
.build() .add(server.hostName, "sha256/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=")
.build()
client = client.newBuilder().certificatePinner(certificatePinner).build() client = client.newBuilder().certificatePinner(certificatePinner).build()
server.enqueue(MockResponse(body = "abc")) server.enqueue(MockResponse(body = "abc"))
@ -446,12 +474,13 @@ class OkHttpTest {
fun testCertificatePinningSuccess() { fun testCertificatePinningSuccess() {
enableTls() enableTls()
val certificatePinner = CertificatePinner.Builder() val certificatePinner =
.add( CertificatePinner.Builder()
server.hostName, .add(
CertificatePinner.pin(handshakeCertificates.trustManager.acceptedIssuers[0]) server.hostName,
) CertificatePinner.pin(handshakeCertificates.trustManager.acceptedIssuers[0]),
.build() )
.build()
client = client.newBuilder().certificatePinner(certificatePinner).build() client = client.newBuilder().certificatePinner(certificatePinner).build()
server.enqueue(MockResponse(body = "abc")) server.enqueue(MockResponse(body = "abc"))
@ -488,9 +517,9 @@ class OkHttpTest {
"ConnectStart", "SecureConnectStart", "SecureConnectEnd", "ConnectEnd", "ConnectStart", "SecureConnectStart", "SecureConnectEnd", "ConnectEnd",
"ConnectionAcquired", "RequestHeadersStart", "RequestHeadersEnd", "ResponseHeadersStart", "ConnectionAcquired", "RequestHeadersStart", "RequestHeadersEnd", "ResponseHeadersStart",
"ResponseHeadersEnd", "ResponseBodyStart", "ResponseBodyEnd", "ConnectionReleased", "ResponseHeadersEnd", "ResponseBodyStart", "ResponseBodyEnd", "ConnectionReleased",
"CallEnd" "CallEnd",
), ),
eventListener.recordedEventTypes() eventListener.recordedEventTypes(),
) )
eventListener.clearAllEvents() eventListener.clearAllEvents()
@ -504,9 +533,9 @@ class OkHttpTest {
"CallStart", "CallStart",
"ConnectionAcquired", "RequestHeadersStart", "RequestHeadersEnd", "ResponseHeadersStart", "ConnectionAcquired", "RequestHeadersStart", "RequestHeadersEnd", "ResponseHeadersStart",
"ResponseHeadersEnd", "ResponseBodyStart", "ResponseBodyEnd", "ConnectionReleased", "ResponseHeadersEnd", "ResponseBodyStart", "ResponseBodyEnd", "ConnectionReleased",
"CallEnd" "CallEnd",
), ),
eventListener.recordedEventTypes() eventListener.recordedEventTypes(),
) )
} }
@ -516,15 +545,21 @@ class OkHttpTest {
enableTls() enableTls()
client = client.newBuilder().eventListenerFactory( client =
clientTestRule.wrap(object : EventListener() { client.newBuilder().eventListenerFactory(
override fun connectionAcquired(call: Call, connection: Connection) { clientTestRule.wrap(
val sslSocket = connection.socket() as SSLSocket object : EventListener() {
override fun connectionAcquired(
call: Call,
connection: Connection,
) {
val sslSocket = connection.socket() as SSLSocket
sessionIds.add(sslSocket.session.id.toByteString().hex()) sessionIds.add(sslSocket.session.id.toByteString().hex())
} }
}) },
).build() ),
).build()
server.enqueue(MockResponse(body = "abc1")) server.enqueue(MockResponse(body = "abc1"))
server.enqueue(MockResponse(body = "abc2")) server.enqueue(MockResponse(body = "abc2"))
@ -550,9 +585,10 @@ class OkHttpTest {
fun testDnsOverHttps() { fun testDnsOverHttps() {
assumeNetwork() assumeNetwork()
client = client.newBuilder() client =
.eventListenerFactory(clientTestRule.wrap(LoggingEventListener.Factory())) client.newBuilder()
.build() .eventListenerFactory(clientTestRule.wrap(LoggingEventListener.Factory()))
.build()
val dohDns = buildCloudflareIp(client) val dohDns = buildCloudflareIp(client)
val dohEnabledClient = val dohEnabledClient =
@ -566,25 +602,34 @@ class OkHttpTest {
fun testCustomTrustManager() { fun testCustomTrustManager() {
assumeNetwork() assumeNetwork()
val trustManager = object : X509TrustManager { val trustManager =
override fun checkClientTrusted(chain: Array<out X509Certificate>?, authType: String?) {} object : X509TrustManager {
override fun checkClientTrusted(
chain: Array<out X509Certificate>?,
authType: String?,
) {}
override fun checkServerTrusted(chain: Array<out X509Certificate>?, authType: String?) {} override fun checkServerTrusted(
chain: Array<out X509Certificate>?,
authType: String?,
) {}
override fun getAcceptedIssuers(): Array<X509Certificate> = arrayOf() override fun getAcceptedIssuers(): Array<X509Certificate> = arrayOf()
} }
val sslContext = Platform.get().newSSLContext().apply { val sslContext =
init(null, arrayOf(trustManager), null) Platform.get().newSSLContext().apply {
} init(null, arrayOf(trustManager), null)
}
val sslSocketFactory = sslContext.socketFactory val sslSocketFactory = sslContext.socketFactory
val hostnameVerifier = HostnameVerifier { _, _ -> true } val hostnameVerifier = HostnameVerifier { _, _ -> true }
client = client.newBuilder() client =
.sslSocketFactory(sslSocketFactory, trustManager) client.newBuilder()
.hostnameVerifier(hostnameVerifier) .sslSocketFactory(sslSocketFactory, trustManager)
.build() .hostnameVerifier(hostnameVerifier)
.build()
client.get("https://www.facebook.com/robots.txt") client.get("https://www.facebook.com/robots.txt")
} }
@ -598,19 +643,21 @@ class OkHttpTest {
val sslSocketFactory = client.sslSocketFactory val sslSocketFactory = client.sslSocketFactory
val trustManager = client.x509TrustManager!! val trustManager = client.x509TrustManager!!
val delegatingSocketFactory = object : DelegatingSSLSocketFactory(sslSocketFactory) { val delegatingSocketFactory =
override fun configureSocket(sslSocket: SSLSocket): SSLSocket { object : DelegatingSSLSocketFactory(sslSocketFactory) {
return object : DelegatingSSLSocket(sslSocket) { override fun configureSocket(sslSocket: SSLSocket): SSLSocket {
override fun getApplicationProtocol(): String { return object : DelegatingSSLSocket(sslSocket) {
throw UnsupportedOperationException() override fun getApplicationProtocol(): String {
throw UnsupportedOperationException()
}
} }
} }
} }
}
client = client.newBuilder() client =
.sslSocketFactory(delegatingSocketFactory, trustManager) client.newBuilder()
.build() .sslSocketFactory(delegatingSocketFactory, trustManager)
.build()
val request = Request.Builder().url(server.url("/").toString()).build() val request = Request.Builder().url(server.url("/").toString()).build()
val response = client.newCall(request).execute() val response = client.newCall(request).execute()
@ -626,34 +673,47 @@ class OkHttpTest {
var withHostCalled = false var withHostCalled = false
var withoutHostCalled = false var withoutHostCalled = false
val trustManager = object : X509TrustManager { val trustManager =
override fun checkClientTrusted(chain: Array<out X509Certificate>?, authType: String?) {} object : X509TrustManager {
override fun checkClientTrusted(
chain: Array<out X509Certificate>?,
authType: String?,
) {}
override fun checkServerTrusted(chain: Array<out X509Certificate>?, authType: String?) { override fun checkServerTrusted(
withoutHostCalled = true chain: Array<out X509Certificate>?,
authType: String?,
) {
withoutHostCalled = true
}
// called by Android via reflection in X509TrustManagerExtensions
@Suppress("unused", "UNUSED_PARAMETER")
fun checkServerTrusted(
chain: Array<out X509Certificate>,
authType: String,
hostname: String,
): List<X509Certificate> {
withHostCalled = true
return chain.toList()
}
override fun getAcceptedIssuers(): Array<X509Certificate> = arrayOf()
} }
@Suppress("unused", "UNUSED_PARAMETER") val sslContext =
// called by Android via reflection in X509TrustManagerExtensions Platform.get().newSSLContext().apply {
fun checkServerTrusted(chain: Array<out X509Certificate>, authType: String, hostname: String): List<X509Certificate> { init(null, arrayOf(trustManager), null)
withHostCalled = true
return chain.toList()
} }
override fun getAcceptedIssuers(): Array<X509Certificate> = arrayOf()
}
val sslContext = Platform.get().newSSLContext().apply {
init(null, arrayOf(trustManager), null)
}
val sslSocketFactory = sslContext.socketFactory val sslSocketFactory = sslContext.socketFactory
val hostnameVerifier = HostnameVerifier { _, _ -> true } val hostnameVerifier = HostnameVerifier { _, _ -> true }
client = client.newBuilder() client =
.sslSocketFactory(sslSocketFactory, trustManager) client.newBuilder()
.hostnameVerifier(hostnameVerifier) .sslSocketFactory(sslSocketFactory, trustManager)
.build() .hostnameVerifier(hostnameVerifier)
.build()
client.get("https://www.facebook.com/robots.txt") client.get("https://www.facebook.com/robots.txt")
@ -702,25 +762,33 @@ class OkHttpTest {
var socketClass: String? = null var socketClass: String? = null
val trustManager = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()).apply { val trustManager =
init(null as KeyStore?) TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()).apply {
}.trustManagers.first() as X509TrustManager init(null as KeyStore?)
}.trustManagers.first() as X509TrustManager
val sslContext = Platform.get().newSSLContext().apply { val sslContext =
// TODO remove most of this code after https://github.com/bcgit/bc-java/issues/686 Platform.get().newSSLContext().apply {
init(null, arrayOf(trustManager), SecureRandom()) // TODO remove most of this code after https://github.com/bcgit/bc-java/issues/686
} init(null, arrayOf(trustManager), SecureRandom())
}
client = client.newBuilder() client =
.sslSocketFactory(sslContext.socketFactory, trustManager) client.newBuilder()
.eventListenerFactory( .sslSocketFactory(sslContext.socketFactory, trustManager)
clientTestRule.wrap(object : EventListener() { .eventListenerFactory(
override fun connectionAcquired(call: Call, connection: Connection) { clientTestRule.wrap(
socketClass = connection.socket().javaClass.name object : EventListener() {
} override fun connectionAcquired(
}) call: Call,
) connection: Connection,
.build() ) {
socketClass = connection.socket().javaClass.name
}
},
),
)
.build()
val request = Request.Builder().url("https://facebook.com/robots.txt").build() val request = Request.Builder().url("https://facebook.com/robots.txt").build()
@ -743,22 +811,23 @@ class OkHttpTest {
fun testLoggingLevels() { fun testLoggingLevels() {
enableTls() enableTls()
val testHandler = object : Handler() { val testHandler =
val calls = mutableMapOf<String, AtomicInteger>() object : Handler() {
val calls = mutableMapOf<String, AtomicInteger>()
override fun publish(record: LogRecord) { override fun publish(record: LogRecord) {
calls.getOrPut(record.loggerName) { AtomicInteger(0) } calls.getOrPut(record.loggerName) { AtomicInteger(0) }
.incrementAndGet() .incrementAndGet()
} }
override fun flush() { override fun flush() {
} }
override fun close() { override fun close() {
}
}.apply {
level = Level.FINEST
} }
}.apply {
level = Level.FINEST
}
Logger.getLogger("") Logger.getLogger("")
.addHandler(testHandler) .addHandler(testHandler)
@ -773,12 +842,14 @@ class OkHttpTest {
server.enqueue(MockResponse(body = "abc")) server.enqueue(MockResponse(body = "abc"))
val request = Request.Builder() val request =
.url(server.url("/")) Request.Builder()
.build() .url(server.url("/"))
.build()
val response = client.newCall(request) val response =
.execute() client.newCall(request)
.execute()
response.use { response.use {
assertEquals(200, response.code) assertEquals(200, response.code)
@ -802,13 +873,15 @@ class OkHttpTest {
val cache = Cache(ctxt.cacheDir.resolve("testCache"), cacheSize) val cache = Cache(ctxt.cacheDir.resolve("testCache"), cacheSize)
try { try {
client = client.newBuilder() client =
.cache(cache) client.newBuilder()
.build() .cache(cache)
.build()
val request = Request.Builder() val request =
.url(server.url("/")) Request.Builder()
.build() .url(server.url("/"))
.build()
client.newCall(request) client.newCall(request)
.execute() .execute()
@ -853,11 +926,12 @@ class OkHttpTest {
} }
private fun enableTls() { private fun enableTls() {
client = client.newBuilder() client =
.sslSocketFactory( client.newBuilder()
handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager .sslSocketFactory(
) handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager,
.build() )
.build()
server.useHttps(handshakeCertificates.sslSocketFactory()) server.useHttps(handshakeCertificates.sslSocketFactory())
} }

View File

@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
package okhttp.android.test.alpn; package okhttp.android.test.alpn
import android.os.Build import android.os.Build
import android.util.Log import android.util.Log
@ -36,9 +36,8 @@ import org.junit.jupiter.api.Test
*/ */
@Tag("Remote") @Tag("Remote")
class AlpnOverrideTest { class AlpnOverrideTest {
class CustomSSLSocketFactory( class CustomSSLSocketFactory(
delegate: SSLSocketFactory delegate: SSLSocketFactory,
) : DelegatingSSLSocketFactory(delegate) { ) : DelegatingSSLSocketFactory(delegate) {
override fun configureSocket(sslSocket: SSLSocket): SSLSocket { override fun configureSocket(sslSocket: SSLSocket): SSLSocket {
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) { if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) {
@ -56,25 +55,34 @@ class AlpnOverrideTest {
@Test @Test
fun getWithCustomSocketFactory() { fun getWithCustomSocketFactory() {
client = client.newBuilder() client =
.sslSocketFactory(CustomSSLSocketFactory(client.sslSocketFactory), client.x509TrustManager!!) client.newBuilder()
.connectionSpecs(listOf( .sslSocketFactory(CustomSSLSocketFactory(client.sslSocketFactory), client.x509TrustManager!!)
ConnectionSpec.Builder(ConnectionSpec.MODERN_TLS) .connectionSpecs(
.supportsTlsExtensions(false) listOf(
.build() ConnectionSpec.Builder(ConnectionSpec.MODERN_TLS)
)) .supportsTlsExtensions(false)
.eventListener(object : EventListener() { .build(),
override fun connectionAcquired(call: Call, connection: Connection) { ),
val sslSocket = connection.socket() as SSLSocket )
println("Requested " + sslSocket.sslParameters.applicationProtocols.joinToString()) .eventListener(
println("Negotiated " + sslSocket.applicationProtocol) object : EventListener() {
} override fun connectionAcquired(
}) call: Call,
.build() connection: Connection,
) {
val sslSocket = connection.socket() as SSLSocket
println("Requested " + sslSocket.sslParameters.applicationProtocols.joinToString())
println("Negotiated " + sslSocket.applicationProtocol)
}
},
)
.build()
val request = Request.Builder() val request =
.url("https://www.google.com") Request.Builder()
.build() .url("https://www.google.com")
.build()
client.newCall(request).execute().use { response -> client.newCall(request).execute().use { response ->
assertThat(response.code).isEqualTo(200) assertThat(response.code).isEqualTo(200)
} }

View File

@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
package okhttp.android.test.letsencrypt; package okhttp.android.test.letsencrypt
import android.os.Build import android.os.Build
import assertk.assertThat import assertk.assertThat
@ -41,56 +41,61 @@ class LetsEncryptClientTest {
val clientBuilder = OkHttpClient.Builder() val clientBuilder = OkHttpClient.Builder()
if (androidMorEarlier) { if (androidMorEarlier) {
val cert: X509Certificate = """ val cert: X509Certificate =
-----BEGIN CERTIFICATE----- """
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw -----BEGIN CERTIFICATE-----
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4 TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc= mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
-----END CERTIFICATE----- emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
""".trimIndent().decodeCertificatePem() -----END CERTIFICATE-----
""".trimIndent().decodeCertificatePem()
val handshakeCertificates = HandshakeCertificates.Builder() val handshakeCertificates =
// TODO reenable in official answers HandshakeCertificates.Builder()
// TODO reenable in official answers
// .addPlatformTrustedCertificates() // .addPlatformTrustedCertificates()
.addTrustedCertificate(cert) .addTrustedCertificate(cert)
.build() .build()
clientBuilder clientBuilder
.sslSocketFactory(handshakeCertificates.sslSocketFactory(), .sslSocketFactory(
handshakeCertificates.trustManager) handshakeCertificates.sslSocketFactory(),
handshakeCertificates.trustManager,
)
} }
val client = clientBuilder.build() val client = clientBuilder.build()
val request = Request.Builder() val request =
.url("https://valid-isrgrootx1.letsencrypt.org/robots.txt") Request.Builder()
.build() .url("https://valid-isrgrootx1.letsencrypt.org/robots.txt")
.build()
client.newCall(request).execute().use { response -> client.newCall(request).execute().use { response ->
assertThat(response.code).isEqualTo(404) assertThat(response.code).isEqualTo(404)
assertThat(response.protocol).isEqualTo(Protocol.HTTP_2) assertThat(response.protocol).isEqualTo(Protocol.HTTP_2)

View File

@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
package okhttp.android.test.sni; package okhttp.android.test.sni
import android.os.Build import android.os.Build
import android.util.Log import android.util.Log
@ -39,15 +39,16 @@ import org.junit.jupiter.api.Test
*/ */
@Tag("Remote") @Tag("Remote")
class SniOverrideTest { class SniOverrideTest {
var client = OkHttpClient.Builder() var client =
.build() OkHttpClient.Builder()
.build()
@Test @Test
fun getWithCustomSocketFactory() { fun getWithCustomSocketFactory() {
assumeTrue(Build.VERSION.SDK_INT >= 24) assumeTrue(Build.VERSION.SDK_INT >= 24)
class CustomSSLSocketFactory( class CustomSSLSocketFactory(
delegate: SSLSocketFactory delegate: SSLSocketFactory,
) : DelegatingSSLSocketFactory(delegate) { ) : DelegatingSSLSocketFactory(delegate) {
override fun configureSocket(sslSocket: SSLSocket): SSLSocket { override fun configureSocket(sslSocket: SSLSocket): SSLSocket {
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) { if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) {
@ -62,49 +63,52 @@ class SniOverrideTest {
} }
} }
client = client.newBuilder() client =
.sslSocketFactory(CustomSSLSocketFactory(client.sslSocketFactory), client.x509TrustManager!!) client.newBuilder()
.hostnameVerifier { hostname, session -> .sslSocketFactory(CustomSSLSocketFactory(client.sslSocketFactory), client.x509TrustManager!!)
val s = "hostname: $hostname peerHost:${session.peerHost}" .hostnameVerifier { hostname, session ->
Log.d("SniOverrideTest", s) val s = "hostname: $hostname peerHost:${session.peerHost}"
try { Log.d("SniOverrideTest", s)
val cert = session.peerCertificates[0] as X509Certificate try {
for (name in cert.subjectAlternativeNames) { val cert = session.peerCertificates[0] as X509Certificate
if (name[0] as Int == 2) { for (name in cert.subjectAlternativeNames) {
Log.d("SniOverrideTest", "cert: " + name[1]) if (name[0] as Int == 2) {
Log.d("SniOverrideTest", "cert: " + name[1])
}
} }
true
} catch (e: Exception) {
false
} }
true
} catch (e: Exception) {
false
} }
} .build()
.build()
val request = Request.Builder() val request =
.url("https://sni.cloudflaressl.com/cdn-cgi/trace") Request.Builder()
.header("Host", "cloudflare-dns.com") .url("https://sni.cloudflaressl.com/cdn-cgi/trace")
.build() .header("Host", "cloudflare-dns.com")
.build()
client.newCall(request).execute().use { response -> client.newCall(request).execute().use { response ->
assertThat(response.code).isEqualTo(200) assertThat(response.code).isEqualTo(200)
assertThat(response.protocol).isEqualTo(Protocol.HTTP_2) assertThat(response.protocol).isEqualTo(Protocol.HTTP_2)
assertThat(response.body.string()).contains("h=cloudflare-dns.com") assertThat(response.body.string()).contains("h=cloudflare-dns.com")
} }
} }
@Test @Test
fun getWithDns() { fun getWithDns() {
client = client.newBuilder() client =
.dns { client.newBuilder()
Dns.SYSTEM.lookup("sni.cloudflaressl.com") .dns {
} Dns.SYSTEM.lookup("sni.cloudflaressl.com")
.build() }
.build()
val request = Request.Builder() val request =
.url("https://cloudflare-dns.com/cdn-cgi/trace") Request.Builder()
.build() .url("https://cloudflare-dns.com/cdn-cgi/trace")
.build()
client.newCall(request).execute().use { response -> client.newCall(request).execute().use { response ->
assertThat(response.code).isEqualTo(200) assertThat(response.code).isEqualTo(200)
assertThat(response.protocol).isEqualTo(Protocol.HTTP_2) assertThat(response.protocol).isEqualTo(Protocol.HTTP_2)

View File

@ -1,14 +1,14 @@
@file:Suppress("UnstableApiUsage") @file:Suppress("UnstableApiUsage")
import com.diffplug.gradle.spotless.SpotlessExtension
import com.vanniktech.maven.publish.MavenPublishBaseExtension import com.vanniktech.maven.publish.MavenPublishBaseExtension
import com.vanniktech.maven.publish.SonatypeHost import com.vanniktech.maven.publish.SonatypeHost
import java.net.URL import java.net.URI
import kotlinx.validation.ApiValidationExtension import kotlinx.validation.ApiValidationExtension
import org.gradle.api.tasks.testing.logging.TestExceptionFormat import org.gradle.api.tasks.testing.logging.TestExceptionFormat
import org.jetbrains.dokka.gradle.DokkaTaskPartial import org.jetbrains.dokka.gradle.DokkaTaskPartial
import org.jetbrains.kotlin.gradle.tasks.KotlinCompile import org.jetbrains.kotlin.gradle.tasks.KotlinCompile
import ru.vyarus.gradle.plugin.animalsniffer.AnimalSnifferExtension import ru.vyarus.gradle.plugin.animalsniffer.AnimalSnifferExtension
import java.net.URI
buildscript { buildscript {
dependencies { dependencies {
@ -35,6 +35,14 @@ buildscript {
} }
apply(plugin = "org.jetbrains.dokka") apply(plugin = "org.jetbrains.dokka")
apply(plugin = "com.diffplug.spotless")
configure<SpotlessExtension> {
kotlin {
target("**/*.kt")
ktlint()
}
}
allprojects { allprojects {
group = "com.squareup.okhttp3" group = "com.squareup.okhttp3"

View File

@ -34,21 +34,23 @@ fun Project.applyOsgi(vararg bndProperties: String) {
private fun Project.applyOsgi( private fun Project.applyOsgi(
jarTaskName: String, jarTaskName: String,
osgiApiConfigurationName: String, osgiApiConfigurationName: String,
bndProperties: Array<out String> bndProperties: Array<out String>,
) { ) {
val osgi = project.sourceSets.create("osgi") val osgi = project.sourceSets.create("osgi")
val osgiApi = project.configurations.getByName(osgiApiConfigurationName) val osgiApi = project.configurations.getByName(osgiApiConfigurationName)
val kotlinOsgi = extensions.getByType(VersionCatalogsExtension::class.java).named("libs") val kotlinOsgi =
.findLibrary("kotlin.stdlib.osgi").get().get() extensions.getByType(VersionCatalogsExtension::class.java).named("libs")
.findLibrary("kotlin.stdlib.osgi").get().get()
project.dependencies { project.dependencies {
osgiApi(kotlinOsgi) osgiApi(kotlinOsgi)
} }
val jarTask = tasks.getByName<Jar>(jarTaskName) val jarTask = tasks.getByName<Jar>(jarTaskName)
val bundleExtension = jarTask.extensions.findByType() ?: jarTask.extensions.create( val bundleExtension =
BundleTaskExtension.NAME, BundleTaskExtension::class.java, jarTask jarTask.extensions.findByType() ?: jarTask.extensions.create(
) BundleTaskExtension.NAME, BundleTaskExtension::class.java, jarTask,
)
bundleExtension.run { bundleExtension.run {
setClasspath(osgi.compileClasspath + sourceSets["main"].compileClasspath) setClasspath(osgi.compileClasspath + sourceSets["main"].compileClasspath)
bnd(*bndProperties) bnd(*bndProperties)

View File

@ -36,9 +36,7 @@ internal fun Dispatcher.wrap(): mockwebserver3.Dispatcher {
val delegate = this val delegate = this
return object : mockwebserver3.Dispatcher() { return object : mockwebserver3.Dispatcher() {
override fun dispatch( override fun dispatch(request: mockwebserver3.RecordedRequest): mockwebserver3.MockResponse {
request: mockwebserver3.RecordedRequest
): mockwebserver3.MockResponse {
return delegate.dispatch(request.unwrap()).wrap() return delegate.dispatch(request.unwrap()).wrap()
} }
@ -70,17 +68,18 @@ internal fun MockResponse.wrap(): mockwebserver3.MockResponse {
result.status = status result.status = status
result.headers(headers) result.headers(headers)
result.trailers(trailers) result.trailers(trailers)
result.socketPolicy = when (socketPolicy) { result.socketPolicy =
SocketPolicy.EXPECT_CONTINUE, SocketPolicy.CONTINUE_ALWAYS -> { when (socketPolicy) {
result.add100Continue() SocketPolicy.EXPECT_CONTINUE, SocketPolicy.CONTINUE_ALWAYS -> {
KeepOpen result.add100Continue()
KeepOpen
}
SocketPolicy.UPGRADE_TO_SSL_AT_END -> {
result.inTunnel()
KeepOpen
}
else -> wrapSocketPolicy()
} }
SocketPolicy.UPGRADE_TO_SSL_AT_END -> {
result.inTunnel()
KeepOpen
}
else -> wrapSocketPolicy()
}
result.throttleBody(throttleBytesPerPeriod, getThrottlePeriod(MILLISECONDS), MILLISECONDS) result.throttleBody(throttleBytesPerPeriod, getThrottlePeriod(MILLISECONDS), MILLISECONDS)
result.bodyDelay(getBodyDelay(MILLISECONDS), MILLISECONDS) result.bodyDelay(getBodyDelay(MILLISECONDS), MILLISECONDS)
result.headersDelay(getHeadersDelay(MILLISECONDS), MILLISECONDS) result.headersDelay(getHeadersDelay(MILLISECONDS), MILLISECONDS)
@ -92,7 +91,7 @@ private fun PushPromise.wrap(): mockwebserver3.PushPromise {
method = method, method = method,
path = path, path = path,
headers = headers, headers = headers,
response = response.wrap() response = response.wrap(),
) )
} }
@ -108,7 +107,7 @@ internal fun mockwebserver3.RecordedRequest.unwrap(): RecordedRequest {
method = method, method = method,
path = path, path = path,
handshake = handshake, handshake = handshake,
requestUrl = requestUrl requestUrl = requestUrl,
) )
} }

View File

@ -14,14 +14,15 @@
* limitations under the License. * limitations under the License.
*/ */
@file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE") @file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")
package okhttp3.mockwebserver package okhttp3.mockwebserver
import java.util.concurrent.TimeUnit
import okhttp3.Headers import okhttp3.Headers
import okhttp3.WebSocketListener import okhttp3.WebSocketListener
import okhttp3.internal.addHeaderLenient import okhttp3.internal.addHeaderLenient
import okhttp3.internal.http2.Settings import okhttp3.internal.http2.Settings
import okio.Buffer import okio.Buffer
import java.util.concurrent.TimeUnit
class MockResponse : Cloneable { class MockResponse : Cloneable {
@set:JvmName("status") @set:JvmName("status")
@ -86,62 +87,81 @@ class MockResponse : Cloneable {
@JvmName("-deprecated_getStatus") @JvmName("-deprecated_getStatus")
@Deprecated( @Deprecated(
message = "moved to var", message = "moved to var",
replaceWith = ReplaceWith(expression = "status"), replaceWith = ReplaceWith(expression = "status"),
level = DeprecationLevel.ERROR) level = DeprecationLevel.ERROR,
)
fun getStatus(): String = status fun getStatus(): String = status
fun setStatus(status: String) = apply { fun setStatus(status: String) =
this.status = status apply {
} this.status = status
}
fun setResponseCode(code: Int): MockResponse { fun setResponseCode(code: Int): MockResponse {
val reason = when (code) { val reason =
in 100..199 -> "Informational" when (code) {
in 200..299 -> "OK" in 100..199 -> "Informational"
in 300..399 -> "Redirection" in 200..299 -> "OK"
in 400..499 -> "Client Error" in 300..399 -> "Redirection"
in 500..599 -> "Server Error" in 400..499 -> "Client Error"
else -> "Mock Response" in 500..599 -> "Server Error"
} else -> "Mock Response"
}
return apply { status = "HTTP/1.1 $code $reason" } return apply { status = "HTTP/1.1 $code $reason" }
} }
fun clearHeaders() = apply { fun clearHeaders() =
headersBuilder = Headers.Builder() apply {
} headersBuilder = Headers.Builder()
}
fun addHeader(header: String) = apply { fun addHeader(header: String) =
headersBuilder.add(header) apply {
} headersBuilder.add(header)
}
fun addHeader(name: String, value: Any) = apply { fun addHeader(
name: String,
value: Any,
) = apply {
headersBuilder.add(name, value.toString()) headersBuilder.add(name, value.toString())
} }
fun addHeaderLenient(name: String, value: Any) = apply { fun addHeaderLenient(
name: String,
value: Any,
) = apply {
addHeaderLenient(headersBuilder, name, value.toString()) addHeaderLenient(headersBuilder, name, value.toString())
} }
fun setHeader(name: String, value: Any) = apply { fun setHeader(
name: String,
value: Any,
) = apply {
removeHeader(name) removeHeader(name)
addHeader(name, value) addHeader(name, value)
} }
fun removeHeader(name: String) = apply { fun removeHeader(name: String) =
headersBuilder.removeAll(name) apply {
} headersBuilder.removeAll(name)
}
fun getBody(): Buffer? = body?.clone() fun getBody(): Buffer? = body?.clone()
fun setBody(body: Buffer) = apply { fun setBody(body: Buffer) =
setHeader("Content-Length", body.size) apply {
this.body = body.clone() // Defensive copy. setHeader("Content-Length", body.size)
} this.body = body.clone() // Defensive copy.
}
fun setBody(body: String): MockResponse = setBody(Buffer().writeUtf8(body)) fun setBody(body: String): MockResponse = setBody(Buffer().writeUtf8(body))
fun setChunkedBody(body: Buffer, maxChunkSize: Int) = apply { fun setChunkedBody(
body: Buffer,
maxChunkSize: Int,
) = apply {
removeHeader("Content-Length") removeHeader("Content-Length")
headersBuilder.add(CHUNKED_BODY_HEADER) headersBuilder.add(CHUNKED_BODY_HEADER)
@ -157,89 +177,107 @@ class MockResponse : Cloneable {
this.body = bytesOut this.body = bytesOut
} }
fun setChunkedBody(body: String, maxChunkSize: Int): MockResponse = fun setChunkedBody(
setChunkedBody(Buffer().writeUtf8(body), maxChunkSize) body: String,
maxChunkSize: Int,
): MockResponse = setChunkedBody(Buffer().writeUtf8(body), maxChunkSize)
@JvmName("-deprecated_getHeaders") @JvmName("-deprecated_getHeaders")
@Deprecated( @Deprecated(
message = "moved to var", message = "moved to var",
replaceWith = ReplaceWith(expression = "headers"), replaceWith = ReplaceWith(expression = "headers"),
level = DeprecationLevel.ERROR) level = DeprecationLevel.ERROR,
)
fun getHeaders(): Headers = headers fun getHeaders(): Headers = headers
fun setHeaders(headers: Headers) = apply { this.headers = headers } fun setHeaders(headers: Headers) = apply { this.headers = headers }
@JvmName("-deprecated_getTrailers") @JvmName("-deprecated_getTrailers")
@Deprecated( @Deprecated(
message = "moved to var", message = "moved to var",
replaceWith = ReplaceWith(expression = "trailers"), replaceWith = ReplaceWith(expression = "trailers"),
level = DeprecationLevel.ERROR) level = DeprecationLevel.ERROR,
)
fun getTrailers(): Headers = trailers fun getTrailers(): Headers = trailers
fun setTrailers(trailers: Headers) = apply { this.trailers = trailers } fun setTrailers(trailers: Headers) = apply { this.trailers = trailers }
@JvmName("-deprecated_getSocketPolicy") @JvmName("-deprecated_getSocketPolicy")
@Deprecated( @Deprecated(
message = "moved to var", message = "moved to var",
replaceWith = ReplaceWith(expression = "socketPolicy"), replaceWith = ReplaceWith(expression = "socketPolicy"),
level = DeprecationLevel.ERROR) level = DeprecationLevel.ERROR,
)
fun getSocketPolicy(): SocketPolicy = socketPolicy fun getSocketPolicy(): SocketPolicy = socketPolicy
fun setSocketPolicy(socketPolicy: SocketPolicy) = apply { fun setSocketPolicy(socketPolicy: SocketPolicy) =
this.socketPolicy = socketPolicy apply {
} this.socketPolicy = socketPolicy
}
@JvmName("-deprecated_getHttp2ErrorCode") @JvmName("-deprecated_getHttp2ErrorCode")
@Deprecated( @Deprecated(
message = "moved to var", message = "moved to var",
replaceWith = ReplaceWith(expression = "http2ErrorCode"), replaceWith = ReplaceWith(expression = "http2ErrorCode"),
level = DeprecationLevel.ERROR) level = DeprecationLevel.ERROR,
)
fun getHttp2ErrorCode(): Int = http2ErrorCode fun getHttp2ErrorCode(): Int = http2ErrorCode
fun setHttp2ErrorCode(http2ErrorCode: Int) = apply { fun setHttp2ErrorCode(http2ErrorCode: Int) =
this.http2ErrorCode = http2ErrorCode apply {
} this.http2ErrorCode = http2ErrorCode
}
fun throttleBody(bytesPerPeriod: Long, period: Long, unit: TimeUnit) = apply { fun throttleBody(
bytesPerPeriod: Long,
period: Long,
unit: TimeUnit,
) = apply {
throttleBytesPerPeriod = bytesPerPeriod throttleBytesPerPeriod = bytesPerPeriod
throttlePeriodAmount = period throttlePeriodAmount = period
throttlePeriodUnit = unit throttlePeriodUnit = unit
} }
fun getThrottlePeriod(unit: TimeUnit): Long = fun getThrottlePeriod(unit: TimeUnit): Long = unit.convert(throttlePeriodAmount, throttlePeriodUnit)
unit.convert(throttlePeriodAmount, throttlePeriodUnit)
fun setBodyDelay(delay: Long, unit: TimeUnit) = apply { fun setBodyDelay(
delay: Long,
unit: TimeUnit,
) = apply {
bodyDelayAmount = delay bodyDelayAmount = delay
bodyDelayUnit = unit bodyDelayUnit = unit
} }
fun getBodyDelay(unit: TimeUnit): Long = fun getBodyDelay(unit: TimeUnit): Long = unit.convert(bodyDelayAmount, bodyDelayUnit)
unit.convert(bodyDelayAmount, bodyDelayUnit)
fun setHeadersDelay(delay: Long, unit: TimeUnit) = apply { fun setHeadersDelay(
delay: Long,
unit: TimeUnit,
) = apply {
headersDelayAmount = delay headersDelayAmount = delay
headersDelayUnit = unit headersDelayUnit = unit
} }
fun getHeadersDelay(unit: TimeUnit): Long = fun getHeadersDelay(unit: TimeUnit): Long = unit.convert(headersDelayAmount, headersDelayUnit)
unit.convert(headersDelayAmount, headersDelayUnit)
fun withPush(promise: PushPromise) = apply { fun withPush(promise: PushPromise) =
promises.add(promise) apply {
} promises.add(promise)
}
fun withSettings(settings: Settings) = apply { fun withSettings(settings: Settings) =
this.settings = settings apply {
} this.settings = settings
}
fun withWebSocketUpgrade(listener: WebSocketListener) = apply { fun withWebSocketUpgrade(listener: WebSocketListener) =
status = "HTTP/1.1 101 Switching Protocols" apply {
setHeader("Connection", "Upgrade") status = "HTTP/1.1 101 Switching Protocols"
setHeader("Upgrade", "websocket") setHeader("Connection", "Upgrade")
body = null setHeader("Upgrade", "websocket")
webSocketListener = listener body = null
} webSocketListener = listener
}
override fun toString(): String = status override fun toString(): String = status

View File

@ -15,9 +15,6 @@
*/ */
package okhttp3.mockwebserver package okhttp3.mockwebserver
import okhttp3.HttpUrl
import okhttp3.Protocol
import org.junit.rules.ExternalResource
import java.io.Closeable import java.io.Closeable
import java.io.IOException import java.io.IOException
import java.net.InetAddress import java.net.InetAddress
@ -27,6 +24,9 @@ import java.util.logging.Level
import java.util.logging.Logger import java.util.logging.Logger
import javax.net.ServerSocketFactory import javax.net.ServerSocketFactory
import javax.net.ssl.SSLSocketFactory import javax.net.ssl.SSLSocketFactory
import okhttp3.HttpUrl
import okhttp3.Protocol
import org.junit.rules.ExternalResource
class MockWebServer : ExternalResource(), Closeable { class MockWebServer : ExternalResource(), Closeable {
val delegate = mockwebserver3.MockWebServer() val delegate = mockwebserver3.MockWebServer()
@ -57,7 +57,8 @@ class MockWebServer : ExternalResource(), Closeable {
var protocolNegotiationEnabled: Boolean by delegate::protocolNegotiationEnabled var protocolNegotiationEnabled: Boolean by delegate::protocolNegotiationEnabled
@get:JvmName("protocols") var protocols: List<Protocol> @get:JvmName("protocols")
var protocols: List<Protocol>
get() = delegate.protocols get() = delegate.protocols
set(value) { set(value) {
delegate.protocols = value delegate.protocols = value
@ -80,9 +81,10 @@ class MockWebServer : ExternalResource(), Closeable {
@JvmName("-deprecated_port") @JvmName("-deprecated_port")
@Deprecated( @Deprecated(
message = "moved to val", message = "moved to val",
replaceWith = ReplaceWith(expression = "port"), replaceWith = ReplaceWith(expression = "port"),
level = DeprecationLevel.ERROR) level = DeprecationLevel.ERROR,
)
fun getPort(): Int = port fun getPort(): Int = port
fun toProxyAddress(): Proxy { fun toProxyAddress(): Proxy {
@ -92,11 +94,13 @@ class MockWebServer : ExternalResource(), Closeable {
@JvmName("-deprecated_serverSocketFactory") @JvmName("-deprecated_serverSocketFactory")
@Deprecated( @Deprecated(
message = "moved to var", message = "moved to var",
replaceWith = ReplaceWith( replaceWith =
expression = "run { this.serverSocketFactory = serverSocketFactory }" ReplaceWith(
expression = "run { this.serverSocketFactory = serverSocketFactory }",
), ),
level = DeprecationLevel.ERROR) level = DeprecationLevel.ERROR,
)
fun setServerSocketFactory(serverSocketFactory: ServerSocketFactory) { fun setServerSocketFactory(serverSocketFactory: ServerSocketFactory) {
delegate.serverSocketFactory = serverSocketFactory delegate.serverSocketFactory = serverSocketFactory
} }
@ -108,43 +112,52 @@ class MockWebServer : ExternalResource(), Closeable {
@JvmName("-deprecated_bodyLimit") @JvmName("-deprecated_bodyLimit")
@Deprecated( @Deprecated(
message = "moved to var", message = "moved to var",
replaceWith = ReplaceWith( replaceWith =
expression = "run { this.bodyLimit = bodyLimit }" ReplaceWith(
expression = "run { this.bodyLimit = bodyLimit }",
), ),
level = DeprecationLevel.ERROR) level = DeprecationLevel.ERROR,
)
fun setBodyLimit(bodyLimit: Long) { fun setBodyLimit(bodyLimit: Long) {
delegate.bodyLimit = bodyLimit delegate.bodyLimit = bodyLimit
} }
@JvmName("-deprecated_protocolNegotiationEnabled") @JvmName("-deprecated_protocolNegotiationEnabled")
@Deprecated( @Deprecated(
message = "moved to var", message = "moved to var",
replaceWith = ReplaceWith( replaceWith =
expression = "run { this.protocolNegotiationEnabled = protocolNegotiationEnabled }" ReplaceWith(
expression = "run { this.protocolNegotiationEnabled = protocolNegotiationEnabled }",
), ),
level = DeprecationLevel.ERROR) level = DeprecationLevel.ERROR,
)
fun setProtocolNegotiationEnabled(protocolNegotiationEnabled: Boolean) { fun setProtocolNegotiationEnabled(protocolNegotiationEnabled: Boolean) {
delegate.protocolNegotiationEnabled = protocolNegotiationEnabled delegate.protocolNegotiationEnabled = protocolNegotiationEnabled
} }
@JvmName("-deprecated_protocols") @JvmName("-deprecated_protocols")
@Deprecated( @Deprecated(
message = "moved to var", message = "moved to var",
replaceWith = ReplaceWith(expression = "run { this.protocols = protocols }"), replaceWith = ReplaceWith(expression = "run { this.protocols = protocols }"),
level = DeprecationLevel.ERROR) level = DeprecationLevel.ERROR,
)
fun setProtocols(protocols: List<Protocol>) { fun setProtocols(protocols: List<Protocol>) {
delegate.protocols = protocols delegate.protocols = protocols
} }
@JvmName("-deprecated_protocols") @JvmName("-deprecated_protocols")
@Deprecated( @Deprecated(
message = "moved to var", message = "moved to var",
replaceWith = ReplaceWith(expression = "protocols"), replaceWith = ReplaceWith(expression = "protocols"),
level = DeprecationLevel.ERROR) level = DeprecationLevel.ERROR,
)
fun protocols(): List<Protocol> = delegate.protocols fun protocols(): List<Protocol> = delegate.protocols
fun useHttps(sslSocketFactory: SSLSocketFactory, tunnelProxy: Boolean) { fun useHttps(
sslSocketFactory: SSLSocketFactory,
tunnelProxy: Boolean,
) {
delegate.useHttps(sslSocketFactory) delegate.useHttps(sslSocketFactory)
} }
@ -166,15 +179,19 @@ class MockWebServer : ExternalResource(), Closeable {
} }
@Throws(InterruptedException::class) @Throws(InterruptedException::class)
fun takeRequest(timeout: Long, unit: TimeUnit): RecordedRequest? { fun takeRequest(
timeout: Long,
unit: TimeUnit,
): RecordedRequest? {
return delegate.takeRequest(timeout, unit)?.unwrap() return delegate.takeRequest(timeout, unit)?.unwrap()
} }
@JvmName("-deprecated_requestCount") @JvmName("-deprecated_requestCount")
@Deprecated( @Deprecated(
message = "moved to val", message = "moved to val",
replaceWith = ReplaceWith(expression = "requestCount"), replaceWith = ReplaceWith(expression = "requestCount"),
level = DeprecationLevel.ERROR) level = DeprecationLevel.ERROR,
)
fun getRequestCount(): Int = delegate.requestCount fun getRequestCount(): Int = delegate.requestCount
fun enqueue(response: MockResponse) { fun enqueue(response: MockResponse) {
@ -182,13 +199,17 @@ class MockWebServer : ExternalResource(), Closeable {
} }
@Throws(IOException::class) @Throws(IOException::class)
@JvmOverloads fun start(port: Int = 0) { @JvmOverloads
fun start(port: Int = 0) {
started = true started = true
delegate.start(port) delegate.start(port)
} }
@Throws(IOException::class) @Throws(IOException::class)
fun start(inetAddress: InetAddress, port: Int) { fun start(
inetAddress: InetAddress,
port: Int,
) {
started = true started = true
delegate.start(inetAddress, port) delegate.start(inetAddress, port)
} }

View File

@ -21,34 +21,37 @@ class PushPromise(
@get:JvmName("method") val method: String, @get:JvmName("method") val method: String,
@get:JvmName("path") val path: String, @get:JvmName("path") val path: String,
@get:JvmName("headers") val headers: Headers, @get:JvmName("headers") val headers: Headers,
@get:JvmName("response") val response: MockResponse @get:JvmName("response") val response: MockResponse,
) { ) {
@JvmName("-deprecated_method") @JvmName("-deprecated_method")
@Deprecated( @Deprecated(
message = "moved to val", message = "moved to val",
replaceWith = ReplaceWith(expression = "method"), replaceWith = ReplaceWith(expression = "method"),
level = DeprecationLevel.ERROR) level = DeprecationLevel.ERROR,
)
fun method(): String = method fun method(): String = method
@JvmName("-deprecated_path") @JvmName("-deprecated_path")
@Deprecated( @Deprecated(
message = "moved to val", message = "moved to val",
replaceWith = ReplaceWith(expression = "path"), replaceWith = ReplaceWith(expression = "path"),
level = DeprecationLevel.ERROR) level = DeprecationLevel.ERROR,
)
fun path(): String = path fun path(): String = path
@JvmName("-deprecated_headers") @JvmName("-deprecated_headers")
@Deprecated( @Deprecated(
message = "moved to val", message = "moved to val",
replaceWith = ReplaceWith(expression = "headers"), replaceWith = ReplaceWith(expression = "headers"),
level = DeprecationLevel.ERROR) level = DeprecationLevel.ERROR,
)
fun headers(): Headers = headers fun headers(): Headers = headers
@JvmName("-deprecated_response") @JvmName("-deprecated_response")
@Deprecated( @Deprecated(
message = "moved to val", message = "moved to val",
replaceWith = ReplaceWith(expression = "response"), replaceWith = ReplaceWith(expression = "response"),
level = DeprecationLevel.ERROR) level = DeprecationLevel.ERROR,
)
fun response(): MockResponse = response fun response(): MockResponse = response
} }

View File

@ -15,6 +15,10 @@
*/ */
package okhttp3.mockwebserver package okhttp3.mockwebserver
import java.io.IOException
import java.net.Inet6Address
import java.net.Socket
import javax.net.ssl.SSLSocket
import okhttp3.Handshake import okhttp3.Handshake
import okhttp3.Handshake.Companion.handshake import okhttp3.Handshake.Companion.handshake
import okhttp3.Headers import okhttp3.Headers
@ -22,10 +26,6 @@ import okhttp3.HttpUrl
import okhttp3.HttpUrl.Companion.toHttpUrlOrNull import okhttp3.HttpUrl.Companion.toHttpUrlOrNull
import okhttp3.TlsVersion import okhttp3.TlsVersion
import okio.Buffer import okio.Buffer
import java.io.IOException
import java.net.Inet6Address
import java.net.Socket
import javax.net.ssl.SSLSocket
class RecordedRequest { class RecordedRequest {
val requestLine: String val requestLine: String
@ -42,9 +42,10 @@ class RecordedRequest {
@get:JvmName("-deprecated_utf8Body") @get:JvmName("-deprecated_utf8Body")
@Deprecated( @Deprecated(
message = "Use body.readUtf8()", message = "Use body.readUtf8()",
replaceWith = ReplaceWith("body.readUtf8()"), replaceWith = ReplaceWith("body.readUtf8()"),
level = DeprecationLevel.ERROR) level = DeprecationLevel.ERROR,
)
val utf8Body: String val utf8Body: String
get() = body.readUtf8() get() = body.readUtf8()
@ -62,7 +63,7 @@ class RecordedRequest {
method: String?, method: String?,
path: String?, path: String?,
handshake: Handshake?, handshake: Handshake?,
requestUrl: HttpUrl? requestUrl: HttpUrl?,
) { ) {
this.requestLine = requestLine this.requestLine = requestLine
this.headers = headers this.headers = headers
@ -86,7 +87,7 @@ class RecordedRequest {
body: Buffer, body: Buffer,
sequenceNumber: Int, sequenceNumber: Int,
socket: Socket, socket: Socket,
failure: IOException? = null failure: IOException? = null,
) { ) {
this.requestLine = requestLine this.requestLine = requestLine
this.headers = headers this.headers = headers
@ -139,9 +140,10 @@ class RecordedRequest {
} }
@Deprecated( @Deprecated(
message = "Use body.readUtf8()", message = "Use body.readUtf8()",
replaceWith = ReplaceWith("body.readUtf8()"), replaceWith = ReplaceWith("body.readUtf8()"),
level = DeprecationLevel.WARNING) level = DeprecationLevel.WARNING,
)
fun getUtf8Body(): String = body.readUtf8() fun getUtf8Body(): String = body.readUtf8()
fun getHeader(name: String): String? = headers.values(name).firstOrNull() fun getHeader(name: String): String? = headers.values(name).firstOrNull()

View File

@ -32,5 +32,5 @@ enum class SocketPolicy {
NO_RESPONSE, NO_RESPONSE,
RESET_STREAM_AT_START, RESET_STREAM_AT_START,
EXPECT_CONTINUE, EXPECT_CONTINUE,
CONTINUE_ALWAYS CONTINUE_ALWAYS,
} }

View File

@ -15,6 +15,12 @@
*/ */
package okhttp3.mockwebserver package okhttp3.mockwebserver
import java.net.InetAddress
import java.net.Proxy
import java.net.Socket
import java.util.concurrent.TimeUnit
import javax.net.ServerSocketFactory
import javax.net.ssl.SSLSocketFactory
import okhttp3.Handshake import okhttp3.Handshake
import okhttp3.Headers import okhttp3.Headers
import okhttp3.Headers.Companion.headersOf import okhttp3.Headers.Companion.headersOf
@ -26,35 +32,32 @@ import okhttp3.internal.http2.Settings
import okio.Buffer import okio.Buffer
import org.junit.Ignore import org.junit.Ignore
import org.junit.Test import org.junit.Test
import java.net.InetAddress
import java.net.Proxy
import java.net.Socket
import java.util.concurrent.TimeUnit
import javax.net.ServerSocketFactory
import javax.net.ssl.SSLSocketFactory
/** /**
* Access every type, function, and property from Kotlin to defend against unexpected regressions in * Access every type, function, and property from Kotlin to defend against unexpected regressions in
* modern 4.0.x kotlin source-compatibility. * modern 4.0.x kotlin source-compatibility.
*/ */
@Suppress( @Suppress(
"ASSIGNED_BUT_NEVER_ACCESSED_VARIABLE", "ASSIGNED_BUT_NEVER_ACCESSED_VARIABLE",
"UNUSED_ANONYMOUS_PARAMETER", "UNUSED_ANONYMOUS_PARAMETER",
"UNUSED_VALUE", "UNUSED_VALUE",
"UNUSED_VARIABLE", "UNUSED_VARIABLE",
"VARIABLE_WITH_REDUNDANT_INITIALIZER", "VARIABLE_WITH_REDUNDANT_INITIALIZER",
"RedundantLambdaArrow", "RedundantLambdaArrow",
"RedundantExplicitType", "RedundantExplicitType",
"IMPLICIT_NOTHING_AS_TYPE_PARAMETER" "IMPLICIT_NOTHING_AS_TYPE_PARAMETER",
) )
class KotlinSourceModernTest { class KotlinSourceModernTest {
@Test @Ignore @Test @Ignore
fun dispatcherFromMockWebServer() { fun dispatcherFromMockWebServer() {
val dispatcher = object : Dispatcher() { val dispatcher =
override fun dispatch(request: RecordedRequest): MockResponse = TODO() object : Dispatcher() {
override fun peek(): MockResponse = TODO() override fun dispatch(request: RecordedRequest): MockResponse = TODO()
override fun shutdown() = TODO()
} override fun peek(): MockResponse = TODO()
override fun shutdown() = TODO()
}
} }
@Test @Ignore @Test @Ignore
@ -96,8 +99,11 @@ class KotlinSourceModernTest {
mockResponse = mockResponse.withSettings(Settings()) mockResponse = mockResponse.withSettings(Settings())
var settings: Settings = mockResponse.settings var settings: Settings = mockResponse.settings
settings = mockResponse.settings settings = mockResponse.settings
mockResponse = mockResponse.withWebSocketUpgrade(object : WebSocketListener() { mockResponse =
}) mockResponse.withWebSocketUpgrade(
object : WebSocketListener() {
},
)
var webSocketListener: WebSocketListener? = mockResponse.webSocketListener var webSocketListener: WebSocketListener? = mockResponse.webSocketListener
webSocketListener = mockResponse.webSocketListener webSocketListener = mockResponse.webSocketListener
} }
@ -146,8 +152,10 @@ class KotlinSourceModernTest {
@Test @Ignore @Test @Ignore
fun queueDispatcher() { fun queueDispatcher() {
val queueDispatcher: QueueDispatcher = QueueDispatcher() val queueDispatcher: QueueDispatcher = QueueDispatcher()
var mockResponse: MockResponse = queueDispatcher.dispatch( var mockResponse: MockResponse =
RecordedRequest("", headersOf(), listOf(), 0L, Buffer(), 0, Socket())) queueDispatcher.dispatch(
RecordedRequest("", headersOf(), listOf(), 0L, Buffer(), 0, Socket()),
)
mockResponse = queueDispatcher.peek() mockResponse = queueDispatcher.peek()
queueDispatcher.enqueueResponse(MockResponse()) queueDispatcher.enqueueResponse(MockResponse())
queueDispatcher.shutdown() queueDispatcher.shutdown()
@ -157,8 +165,16 @@ class KotlinSourceModernTest {
@Test @Ignore @Test @Ignore
fun recordedRequest() { fun recordedRequest() {
var recordedRequest: RecordedRequest = RecordedRequest( var recordedRequest: RecordedRequest =
"", headersOf(), listOf(), 0L, Buffer(), 0, Socket()) RecordedRequest(
"",
headersOf(),
listOf(),
0L,
Buffer(),
0,
Socket(),
)
recordedRequest = RecordedRequest("", headersOf(), listOf(), 0L, Buffer(), 0, Socket()) recordedRequest = RecordedRequest("", headersOf(), listOf(), 0L, Buffer(), 0, Socket())
var requestUrl: HttpUrl? = recordedRequest.requestUrl var requestUrl: HttpUrl? = recordedRequest.requestUrl
var requestLine: String = recordedRequest.requestLine var requestLine: String = recordedRequest.requestLine

View File

@ -86,15 +86,16 @@ class MockWebServerTest {
@Test @Test
fun setResponseMockReason() { fun setResponseMockReason() {
val reasons = arrayOf( val reasons =
"Mock Response", arrayOf(
"Informational", "Mock Response",
"OK", "Informational",
"Redirection", "OK",
"Client Error", "Redirection",
"Server Error", "Client Error",
"Mock Response" "Server Error",
) "Mock Response",
)
for (i in 0..599) { for (i in 0..599) {
val response = MockResponse().setResponseCode(i) val response = MockResponse().setResponseCode(i)
val expectedReason = reasons[i / 100] val expectedReason = reasons[i / 100]
@ -119,21 +120,23 @@ class MockWebServerTest {
@Test @Test
fun mockResponseAddHeader() { fun mockResponseAddHeader() {
val response = MockResponse() val response =
.clearHeaders() MockResponse()
.addHeader("Cookie: s=square") .clearHeaders()
.addHeader("Cookie", "a=android") .addHeader("Cookie: s=square")
.addHeader("Cookie", "a=android")
assertThat(headersToList(response)) assertThat(headersToList(response))
.containsExactly("Cookie: s=square", "Cookie: a=android") .containsExactly("Cookie: s=square", "Cookie: a=android")
} }
@Test @Test
fun mockResponseSetHeader() { fun mockResponseSetHeader() {
val response = MockResponse() val response =
.clearHeaders() MockResponse()
.addHeader("Cookie: s=square") .clearHeaders()
.addHeader("Cookie: a=android") .addHeader("Cookie: s=square")
.addHeader("Cookies: delicious") .addHeader("Cookie: a=android")
.addHeader("Cookies: delicious")
response.setHeader("cookie", "r=robot") response.setHeader("cookie", "r=robot")
assertThat(headersToList(response)) assertThat(headersToList(response))
.containsExactly("Cookies: delicious", "cookie: r=robot") .containsExactly("Cookies: delicious", "cookie: r=robot")
@ -141,10 +144,11 @@ class MockWebServerTest {
@Test @Test
fun mockResponseSetHeaders() { fun mockResponseSetHeaders() {
val response = MockResponse() val response =
.clearHeaders() MockResponse()
.addHeader("Cookie: s=square") .clearHeaders()
.addHeader("Cookies: delicious") .addHeader("Cookie: s=square")
.addHeader("Cookies: delicious")
response.setHeaders(Headers.Builder().add("Cookie", "a=android").build()) response.setHeaders(Headers.Builder().add("Cookie", "a=android").build())
assertThat(headersToList(response)).containsExactly("Cookie: a=android") assertThat(headersToList(response)).containsExactly("Cookie: a=android")
} }
@ -173,7 +177,7 @@ class MockWebServerTest {
MockResponse() MockResponse()
.setResponseCode(HttpURLConnection.HTTP_MOVED_TEMP) .setResponseCode(HttpURLConnection.HTTP_MOVED_TEMP)
.addHeader("Location: " + server.url("/new-path")) .addHeader("Location: " + server.url("/new-path"))
.setBody("This page has moved!") .setBody("This page has moved!"),
) )
server.enqueue(MockResponse().setBody("This is the new location!")) server.enqueue(MockResponse().setBody("This is the new location!"))
val connection = server.url("/").toUrl().openConnection() val connection = server.url("/").toUrl().openConnection()
@ -211,7 +215,7 @@ class MockWebServerTest {
MockResponse() MockResponse()
.setBody("G\r\nxxxxxxxxxxxxxxxx\r\n0\r\n\r\n") .setBody("G\r\nxxxxxxxxxxxxxxxx\r\n0\r\n\r\n")
.clearHeaders() .clearHeaders()
.addHeader("Transfer-encoding: chunked") .addHeader("Transfer-encoding: chunked"),
) )
val connection = server.url("/").toUrl().openConnection() val connection = server.url("/").toUrl().openConnection()
val inputStream = connection.getInputStream() val inputStream = connection.getInputStream()
@ -228,7 +232,7 @@ class MockWebServerTest {
MockResponse() MockResponse()
.setBody("ABC") .setBody("ABC")
.clearHeaders() .clearHeaders()
.addHeader("Content-Length: 4") .addHeader("Content-Length: 4"),
) )
server.enqueue(MockResponse().setBody("DEF")) server.enqueue(MockResponse().setBody("DEF"))
val urlConnection = server.url("/").toUrl().openConnection() val urlConnection = server.url("/").toUrl().openConnection()
@ -257,7 +261,7 @@ class MockWebServerTest {
fun disconnectAtStart() { fun disconnectAtStart() {
server.enqueue( server.enqueue(
MockResponse() MockResponse()
.setSocketPolicy(SocketPolicy.DISCONNECT_AT_START) .setSocketPolicy(SocketPolicy.DISCONNECT_AT_START),
) )
server.enqueue(MockResponse()) // The jdk's HttpUrlConnection is a bastard. server.enqueue(MockResponse()) // The jdk's HttpUrlConnection is a bastard.
server.enqueue(MockResponse()) server.enqueue(MockResponse())
@ -278,7 +282,7 @@ class MockWebServerTest {
assumeNotWindows() assumeNotWindows()
server.enqueue( server.enqueue(
MockResponse() MockResponse()
.throttleBody(3, 500, TimeUnit.MILLISECONDS) .throttleBody(3, 500, TimeUnit.MILLISECONDS),
) )
val startNanos = System.nanoTime() val startNanos = System.nanoTime()
val connection = server.url("/").toUrl().openConnection() val connection = server.url("/").toUrl().openConnection()
@ -301,7 +305,7 @@ class MockWebServerTest {
server.enqueue( server.enqueue(
MockResponse() MockResponse()
.setBody("ABCDEF") .setBody("ABCDEF")
.throttleBody(3, 500, TimeUnit.MILLISECONDS) .throttleBody(3, 500, TimeUnit.MILLISECONDS),
) )
val startNanos = System.nanoTime() val startNanos = System.nanoTime()
val connection = server.url("/").toUrl().openConnection() val connection = server.url("/").toUrl().openConnection()
@ -325,7 +329,7 @@ class MockWebServerTest {
server.enqueue( server.enqueue(
MockResponse() MockResponse()
.setBody("ABCDEF") .setBody("ABCDEF")
.setBodyDelay(1, TimeUnit.SECONDS) .setBodyDelay(1, TimeUnit.SECONDS),
) )
val startNanos = System.nanoTime() val startNanos = System.nanoTime()
val connection = server.url("/").toUrl().openConnection() val connection = server.url("/").toUrl().openConnection()
@ -373,7 +377,7 @@ class MockWebServerTest {
server.enqueue( server.enqueue(
MockResponse() MockResponse()
.setBody("ab") .setBody("ab")
.setSocketPolicy(SocketPolicy.DISCONNECT_DURING_RESPONSE_BODY) .setSocketPolicy(SocketPolicy.DISCONNECT_DURING_RESPONSE_BODY),
) )
val connection = server.url("/").toUrl().openConnection() val connection = server.url("/").toUrl().openConnection()
assertThat(connection.getContentLength()).isEqualTo(2) assertThat(connection.getContentLength()).isEqualTo(2)
@ -443,12 +447,16 @@ class MockWebServerTest {
@Test @Test
fun statementStartsAndStops() { fun statementStartsAndStops() {
val called = AtomicBoolean() val called = AtomicBoolean()
val statement = server.apply(object : Statement() { val statement =
override fun evaluate() { server.apply(
called.set(true) object : Statement() {
server.url("/").toUrl().openConnection().connect() override fun evaluate() {
} called.set(true)
}, Description.EMPTY) server.url("/").toUrl().openConnection().connect()
}
},
Description.EMPTY,
)
statement.evaluate() statement.evaluate()
assertThat(called.get()).isTrue() assertThat(called.get()).isTrue()
try { try {
@ -484,7 +492,7 @@ class MockWebServerTest {
assertThat(reader.readLine()).isEqualTo("hello world") assertThat(reader.readLine()).isEqualTo("hello world")
val request = server.takeRequest() val request = server.takeRequest()
assertThat(request.requestLine).isEqualTo( assertThat(request.requestLine).isEqualTo(
"GET /a/deep/path?key=foo%20bar HTTP/1.1" "GET /a/deep/path?key=foo%20bar HTTP/1.1",
) )
val requestUrl = request.requestUrl val requestUrl = request.requestUrl
assertThat(requestUrl!!.scheme).isEqualTo("http") assertThat(requestUrl!!.scheme).isEqualTo("http")
@ -530,8 +538,8 @@ class MockWebServerTest {
fail<Any>() fail<Any>()
} catch (expected: IllegalArgumentException) { } catch (expected: IllegalArgumentException) {
assertThat(expected.message).isEqualTo( assertThat(expected.message).isEqualTo(
"protocols containing h2_prior_knowledge cannot use other protocols: " "protocols containing h2_prior_knowledge cannot use other protocols: " +
+ "[h2_prior_knowledge, http/1.1]" "[h2_prior_knowledge, http/1.1]",
) )
} }
} }
@ -544,8 +552,8 @@ class MockWebServerTest {
fail<Any>() fail<Any>()
} catch (expected: IllegalArgumentException) { } catch (expected: IllegalArgumentException) {
assertThat(expected.message).isEqualTo( assertThat(expected.message).isEqualTo(
"protocols containing h2_prior_knowledge cannot use other protocols: " "protocols containing h2_prior_knowledge cannot use other protocols: " +
+ "[h2_prior_knowledge, h2_prior_knowledge]" "[h2_prior_knowledge, h2_prior_knowledge]",
) )
} }
} }
@ -584,30 +592,36 @@ class MockWebServerTest {
fun httpsWithClientAuth() { fun httpsWithClientAuth() {
platform.assumeNotBouncyCastle() platform.assumeNotBouncyCastle()
platform.assumeNotConscrypt() platform.assumeNotConscrypt()
val clientCa = HeldCertificate.Builder() val clientCa =
.certificateAuthority(0) HeldCertificate.Builder()
.build() .certificateAuthority(0)
val serverCa = HeldCertificate.Builder() .build()
.certificateAuthority(0) val serverCa =
.build() HeldCertificate.Builder()
val serverCertificate = HeldCertificate.Builder() .certificateAuthority(0)
.signedBy(serverCa) .build()
.addSubjectAlternativeName(server.hostName) val serverCertificate =
.build() HeldCertificate.Builder()
val serverHandshakeCertificates = HandshakeCertificates.Builder() .signedBy(serverCa)
.addTrustedCertificate(clientCa.certificate) .addSubjectAlternativeName(server.hostName)
.heldCertificate(serverCertificate) .build()
.build() val serverHandshakeCertificates =
HandshakeCertificates.Builder()
.addTrustedCertificate(clientCa.certificate)
.heldCertificate(serverCertificate)
.build()
server.useHttps(serverHandshakeCertificates.sslSocketFactory(), false) server.useHttps(serverHandshakeCertificates.sslSocketFactory(), false)
server.enqueue(MockResponse().setBody("abc")) server.enqueue(MockResponse().setBody("abc"))
server.requestClientAuth() server.requestClientAuth()
val clientCertificate = HeldCertificate.Builder() val clientCertificate =
.signedBy(clientCa) HeldCertificate.Builder()
.build() .signedBy(clientCa)
val clientHandshakeCertificates = HandshakeCertificates.Builder() .build()
.addTrustedCertificate(serverCa.certificate) val clientHandshakeCertificates =
.heldCertificate(clientCertificate) HandshakeCertificates.Builder()
.build() .addTrustedCertificate(serverCa.certificate)
.heldCertificate(clientCertificate)
.build()
val url = server.url("/") val url = server.url("/")
val connection = url.toUrl().openConnection() as HttpsURLConnection val connection = url.toUrl().openConnection() as HttpsURLConnection
connection.setSSLSocketFactory(clientHandshakeCertificates.sslSocketFactory()) connection.setSSLSocketFactory(clientHandshakeCertificates.sslSocketFactory())

View File

@ -15,11 +15,11 @@
*/ */
package mockwebserver3.junit4 package mockwebserver3.junit4
import mockwebserver3.MockWebServer
import org.junit.rules.ExternalResource
import java.io.IOException import java.io.IOException
import java.util.logging.Level import java.util.logging.Level
import java.util.logging.Logger import java.util.logging.Logger
import mockwebserver3.MockWebServer
import org.junit.rules.ExternalResource
/** /**
* Runs MockWebServer for the duration of a single test method. * Runs MockWebServer for the duration of a single test method.

View File

@ -28,12 +28,16 @@ class MockWebServerRuleTest {
@Test fun statementStartsAndStops() { @Test fun statementStartsAndStops() {
val rule = MockWebServerRule() val rule = MockWebServerRule()
val called = AtomicBoolean() val called = AtomicBoolean()
val statement: Statement = rule.apply(object : Statement() { val statement: Statement =
override fun evaluate() { rule.apply(
called.set(true) object : Statement() {
rule.server.url("/").toUrl().openConnection().connect() override fun evaluate() {
} called.set(true)
}, Description.EMPTY) rule.server.url("/").toUrl().openConnection().connect()
}
},
Description.EMPTY,
)
statement.evaluate() statement.evaluate()
assertThat(called.get()).isTrue() assertThat(called.get()).isTrue()
try { try {

View File

@ -38,12 +38,13 @@ import org.junit.jupiter.api.extension.ParameterResolver
* - The test lifecycle default (passed into test method, plus @BeforeEach, @AfterEach) * - The test lifecycle default (passed into test method, plus @BeforeEach, @AfterEach)
* - named instances with @MockWebServerInstance. * - named instances with @MockWebServerInstance.
*/ */
class MockWebServerExtension class MockWebServerExtension :
: BeforeEachCallback, AfterEachCallback, ParameterResolver { BeforeEachCallback, AfterEachCallback, ParameterResolver {
private val ExtensionContext.resource: ServersForTest private val ExtensionContext.resource: ServersForTest
get() = getStore(namespace).getOrComputeIfAbsent(this.uniqueId) { get() =
ServersForTest() getStore(namespace).getOrComputeIfAbsent(this.uniqueId) {
} as ServersForTest ServersForTest()
} as ServersForTest
private class ServersForTest { private class ServersForTest {
private val servers = mutableMapOf<String, MockWebServer>() private val servers = mutableMapOf<String, MockWebServer>()
@ -80,24 +81,25 @@ class MockWebServerExtension
@IgnoreJRERequirement @IgnoreJRERequirement
override fun supportsParameter( override fun supportsParameter(
parameterContext: ParameterContext, parameterContext: ParameterContext,
extensionContext: ExtensionContext extensionContext: ExtensionContext,
): Boolean { ): Boolean {
// Not supported on constructors, or static contexts // Not supported on constructors, or static contexts
return parameterContext.parameter.type === MockWebServer::class.java return parameterContext.parameter.type === MockWebServer::class.java &&
&& extensionContext.testMethod.isPresent extensionContext.testMethod.isPresent
} }
@Suppress("NewApi") @Suppress("NewApi")
override fun resolveParameter( override fun resolveParameter(
parameterContext: ParameterContext, parameterContext: ParameterContext,
extensionContext: ExtensionContext extensionContext: ExtensionContext,
): Any { ): Any {
val nameAnnotation = parameterContext.findAnnotation(MockWebServerInstance::class.java) val nameAnnotation = parameterContext.findAnnotation(MockWebServerInstance::class.java)
val name = if (nameAnnotation.isPresent) { val name =
nameAnnotation.get().name if (nameAnnotation.isPresent) {
} else { nameAnnotation.get().name
defaultName } else {
} defaultName
}
return extensionContext.resource.server(name) return extensionContext.resource.server(name)
} }

View File

@ -16,5 +16,5 @@
package mockwebserver3.junit5.internal package mockwebserver3.junit5.internal
annotation class MockWebServerInstance( annotation class MockWebServerInstance(
val name: String val name: String,
) )

View File

@ -35,7 +35,7 @@ class ExtensionMultipleInstancesTest {
fun setup( fun setup(
defaultInstance: MockWebServer, defaultInstance: MockWebServer,
@MockWebServerInstance("A") instanceA: MockWebServer, @MockWebServerInstance("A") instanceA: MockWebServer,
@MockWebServerInstance("B") instanceB: MockWebServer @MockWebServerInstance("B") instanceB: MockWebServer,
) { ) {
defaultInstancePort = defaultInstance.port defaultInstancePort = defaultInstance.port
instanceAPort = instanceA.port instanceAPort = instanceA.port
@ -51,7 +51,7 @@ class ExtensionMultipleInstancesTest {
fun tearDown( fun tearDown(
defaultInstance: MockWebServer, defaultInstance: MockWebServer,
@MockWebServerInstance("A") instanceA: MockWebServer, @MockWebServerInstance("A") instanceA: MockWebServer,
@MockWebServerInstance("B") instanceB: MockWebServer @MockWebServerInstance("B") instanceB: MockWebServer,
) { ) {
assertThat(defaultInstance.port).isEqualTo(defaultInstancePort) assertThat(defaultInstance.port).isEqualTo(defaultInstancePort)
assertThat(instanceA.port).isEqualTo(instanceAPort) assertThat(instanceA.port).isEqualTo(instanceAPort)
@ -62,7 +62,7 @@ class ExtensionMultipleInstancesTest {
fun testClient( fun testClient(
defaultInstance: MockWebServer, defaultInstance: MockWebServer,
@MockWebServerInstance("A") instanceA: MockWebServer, @MockWebServerInstance("A") instanceA: MockWebServer,
@MockWebServerInstance("B") instanceB: MockWebServer @MockWebServerInstance("B") instanceB: MockWebServer,
) { ) {
assertThat(defaultInstance.port).isEqualTo(defaultInstancePort) assertThat(defaultInstance.port).isEqualTo(defaultInstancePort)
assertThat(instanceA.port).isEqualTo(instanceAPort) assertThat(instanceA.port).isEqualTo(instanceAPort)

View File

@ -14,6 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
@file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE") @file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")
package mockwebserver3 package mockwebserver3
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
@ -28,7 +29,6 @@ import okio.Buffer
/** A scripted response to be replayed by the mock web server. */ /** A scripted response to be replayed by the mock web server. */
class MockResponse { class MockResponse {
/** Returns the HTTP response line, such as "HTTP/1.1 200 OK". */ /** Returns the HTTP response line, such as "HTTP/1.1 200 OK". */
val status: String val status: String
@ -76,14 +76,15 @@ class MockResponse {
body: String = "", body: String = "",
inTunnel: Boolean = false, inTunnel: Boolean = false,
socketPolicy: SocketPolicy = KeepOpen, socketPolicy: SocketPolicy = KeepOpen,
) : this(Builder() ) : this(
.apply { Builder()
this.code = code .apply {
this.headers.addAll(headers) this.code = code
if (inTunnel) inTunnel() this.headers.addAll(headers)
this.body(body) if (inTunnel) inTunnel()
this.socketPolicy = socketPolicy this.body(body)
} this.socketPolicy = socketPolicy
},
) )
private constructor(builder: Builder) { private constructor(builder: Builder) {
@ -101,9 +102,10 @@ class MockResponse {
this.bodyDelayNanos = builder.bodyDelayNanos this.bodyDelayNanos = builder.bodyDelayNanos
this.headersDelayNanos = builder.headersDelayNanos this.headersDelayNanos = builder.headersDelayNanos
this.pushPromises = builder.pushPromises.toList() this.pushPromises = builder.pushPromises.toList()
this.settings = Settings().apply { this.settings =
merge(builder.settings) Settings().apply {
} merge(builder.settings)
}
} }
fun newBuilder(): Builder = Builder(this) fun newBuilder(): Builder = Builder(this)
@ -125,14 +127,15 @@ class MockResponse {
return statusParts[1].toInt() return statusParts[1].toInt()
} }
set(value) { set(value) {
val reason = when (value) { val reason =
in 100..199 -> "Informational" when (value) {
in 200..299 -> "OK" in 100..199 -> "Informational"
in 300..399 -> "Redirection" in 200..299 -> "OK"
in 400..499 -> "Client Error" in 300..399 -> "Redirection"
in 500..599 -> "Server Error" in 400..499 -> "Client Error"
else -> "Mock Response" in 500..599 -> "Server Error"
} else -> "Mock Response"
}
status = "HTTP/1.1 $value $reason" status = "HTTP/1.1 $value $reason"
} }
@ -189,8 +192,9 @@ class MockResponse {
this.bodyVar = null this.bodyVar = null
this.streamHandlerVar = null this.streamHandlerVar = null
this.webSocketListenerVar = null this.webSocketListenerVar = null
this.headers = Headers.Builder() this.headers =
.add("Content-Length", "0") Headers.Builder()
.add("Content-Length", "0")
this.trailers = Headers.Builder() this.trailers = Headers.Builder()
this.throttleBytesPerPeriod = Long.MAX_VALUE this.throttleBytesPerPeriod = Long.MAX_VALUE
this.throttlePeriodNanos = 0L this.throttlePeriodNanos = 0L
@ -216,41 +220,49 @@ class MockResponse {
this.bodyDelayNanos = mockResponse.bodyDelayNanos this.bodyDelayNanos = mockResponse.bodyDelayNanos
this.headersDelayNanos = mockResponse.headersDelayNanos this.headersDelayNanos = mockResponse.headersDelayNanos
this.pushPromises = mockResponse.pushPromises.toMutableList() this.pushPromises = mockResponse.pushPromises.toMutableList()
this.settings = Settings().apply { this.settings =
merge(mockResponse.settings) Settings().apply {
} merge(mockResponse.settings)
}
} }
fun code(code: Int) = apply { fun code(code: Int) =
this.code = code apply {
} this.code = code
}
/** Sets the status and returns this. */ /** Sets the status and returns this. */
fun status(status: String) = apply { fun status(status: String) =
this.status = status apply {
} this.status = status
}
/** /**
* Removes all HTTP headers including any "Content-Length" and "Transfer-encoding" headers that * Removes all HTTP headers including any "Content-Length" and "Transfer-encoding" headers that
* were added by default. * were added by default.
*/ */
fun clearHeaders() = apply { fun clearHeaders() =
headers = Headers.Builder() apply {
} headers = Headers.Builder()
}
/** /**
* Adds [header] as an HTTP header. For well-formed HTTP [header] should contain a name followed * Adds [header] as an HTTP header. For well-formed HTTP [header] should contain a name followed
* by a colon and a value. * by a colon and a value.
*/ */
fun addHeader(header: String) = apply { fun addHeader(header: String) =
headers.add(header) apply {
} headers.add(header)
}
/** /**
* Adds a new header with the name and value. This may be used to add multiple headers with the * Adds a new header with the name and value. This may be used to add multiple headers with the
* same name. * same name.
*/ */
fun addHeader(name: String, value: Any) = apply { fun addHeader(
name: String,
value: Any,
) = apply {
headers.add(name, value.toString()) headers.add(name, value.toString())
} }
@ -259,39 +271,51 @@ class MockResponse {
* same name. Unlike [addHeader] this does not validate the name and * same name. Unlike [addHeader] this does not validate the name and
* value. * value.
*/ */
fun addHeaderLenient(name: String, value: Any) = apply { fun addHeaderLenient(
name: String,
value: Any,
) = apply {
addHeaderLenient(headers, name, value.toString()) addHeaderLenient(headers, name, value.toString())
} }
/** Removes all headers named [name], then adds a new header with the name and value. */ /** Removes all headers named [name], then adds a new header with the name and value. */
fun setHeader(name: String, value: Any) = apply { fun setHeader(
name: String,
value: Any,
) = apply {
removeHeader(name) removeHeader(name)
addHeader(name, value) addHeader(name, value)
} }
/** Removes all headers named [name]. */ /** Removes all headers named [name]. */
fun removeHeader(name: String) = apply { fun removeHeader(name: String) =
headers.removeAll(name) apply {
} headers.removeAll(name)
}
fun body(body: Buffer) = body(body.toMockResponseBody()) fun body(body: Buffer) = body(body.toMockResponseBody())
fun body(body: MockResponseBody) = apply { fun body(body: MockResponseBody) =
setHeader("Content-Length", body.contentLength) apply {
this.body = body setHeader("Content-Length", body.contentLength)
} this.body = body
}
/** Sets the response body to the UTF-8 encoded bytes of [body]. */ /** Sets the response body to the UTF-8 encoded bytes of [body]. */
fun body(body: String): Builder = body(Buffer().writeUtf8(body)) fun body(body: String): Builder = body(Buffer().writeUtf8(body))
fun streamHandler(streamHandler: StreamHandler) = apply { fun streamHandler(streamHandler: StreamHandler) =
this.streamHandler = streamHandler apply {
} this.streamHandler = streamHandler
}
/** /**
* Sets the response body to [body], chunked every [maxChunkSize] bytes. * Sets the response body to [body], chunked every [maxChunkSize] bytes.
*/ */
fun chunkedBody(body: Buffer, maxChunkSize: Int) = apply { fun chunkedBody(
body: Buffer,
maxChunkSize: Int,
) = apply {
removeHeader("Content-Length") removeHeader("Content-Length")
headers.add(CHUNKED_BODY_HEADER) headers.add(CHUNKED_BODY_HEADER)
@ -311,29 +335,38 @@ class MockResponse {
* Sets the response body to the UTF-8 encoded bytes of [body], * Sets the response body to the UTF-8 encoded bytes of [body],
* chunked every [maxChunkSize] bytes. * chunked every [maxChunkSize] bytes.
*/ */
fun chunkedBody(body: String, maxChunkSize: Int): Builder = fun chunkedBody(
chunkedBody(Buffer().writeUtf8(body), maxChunkSize) body: String,
maxChunkSize: Int,
): Builder = chunkedBody(Buffer().writeUtf8(body), maxChunkSize)
/** Sets the headers and returns this. */ /** Sets the headers and returns this. */
fun headers(headers: Headers) = apply { fun headers(headers: Headers) =
this.headers = headers.newBuilder() apply {
} this.headers = headers.newBuilder()
}
/** Sets the trailers and returns this. */ /** Sets the trailers and returns this. */
fun trailers(trailers: Headers) = apply { fun trailers(trailers: Headers) =
this.trailers = trailers.newBuilder() apply {
} this.trailers = trailers.newBuilder()
}
/** Sets the socket policy and returns this. */ /** Sets the socket policy and returns this. */
fun socketPolicy(socketPolicy: SocketPolicy) = apply { fun socketPolicy(socketPolicy: SocketPolicy) =
this.socketPolicy = socketPolicy apply {
} this.socketPolicy = socketPolicy
}
/** /**
* Throttles the request reader and response writer to sleep for the given period after each * Throttles the request reader and response writer to sleep for the given period after each
* series of [bytesPerPeriod] bytes are transferred. Use this to simulate network behavior. * series of [bytesPerPeriod] bytes are transferred. Use this to simulate network behavior.
*/ */
fun throttleBody(bytesPerPeriod: Long, period: Long, unit: TimeUnit) = apply { fun throttleBody(
bytesPerPeriod: Long,
period: Long,
unit: TimeUnit,
) = apply {
throttleBytesPerPeriod = bytesPerPeriod throttleBytesPerPeriod = bytesPerPeriod
throttlePeriodNanos = unit.toNanos(period) throttlePeriodNanos = unit.toNanos(period)
} }
@ -342,11 +375,17 @@ class MockResponse {
* Set the delayed time of the response body to [delay]. This applies to the response body * Set the delayed time of the response body to [delay]. This applies to the response body
* only; response headers are not affected. * only; response headers are not affected.
*/ */
fun bodyDelay(delay: Long, unit: TimeUnit) = apply { fun bodyDelay(
delay: Long,
unit: TimeUnit,
) = apply {
bodyDelayNanos = unit.toNanos(delay) bodyDelayNanos = unit.toNanos(delay)
} }
fun headersDelay(delay: Long, unit: TimeUnit) = apply { fun headersDelay(
delay: Long,
unit: TimeUnit,
) = apply {
headersDelayNanos = unit.toNanos(delay) headersDelayNanos = unit.toNanos(delay)
} }
@ -354,29 +393,32 @@ class MockResponse {
* When [protocols][MockWebServer.protocols] include [HTTP_2][okhttp3.Protocol], this attaches a * When [protocols][MockWebServer.protocols] include [HTTP_2][okhttp3.Protocol], this attaches a
* pushed stream to this response. * pushed stream to this response.
*/ */
fun addPush(promise: PushPromise) = apply { fun addPush(promise: PushPromise) =
this.pushPromises += promise apply {
} this.pushPromises += promise
}
/** /**
* When [protocols][MockWebServer.protocols] include [HTTP_2][okhttp3.Protocol], this pushes * When [protocols][MockWebServer.protocols] include [HTTP_2][okhttp3.Protocol], this pushes
* [settings] before writing the response. * [settings] before writing the response.
*/ */
fun settings(settings: Settings) = apply { fun settings(settings: Settings) =
this.settings.clear() apply {
this.settings.merge(settings) this.settings.clear()
} this.settings.merge(settings)
}
/** /**
* Attempts to perform a web socket upgrade on the connection. * Attempts to perform a web socket upgrade on the connection.
* This will overwrite any previously set status, body, or streamHandler. * This will overwrite any previously set status, body, or streamHandler.
*/ */
fun webSocketUpgrade(listener: WebSocketListener) = apply { fun webSocketUpgrade(listener: WebSocketListener) =
status = "HTTP/1.1 101 Switching Protocols" apply {
setHeader("Connection", "Upgrade") status = "HTTP/1.1 101 Switching Protocols"
setHeader("Upgrade", "websocket") setHeader("Connection", "Upgrade")
webSocketListener = listener setHeader("Upgrade", "websocket")
} webSocketListener = listener
}
/** /**
* Configures this response to be served as a response to an HTTP CONNECT request, either for * Configures this response to be served as a response to an HTTP CONNECT request, either for
@ -385,23 +427,26 @@ class MockResponse {
* When a new connection is received, all in-tunnel responses are served before the connection is * When a new connection is received, all in-tunnel responses are served before the connection is
* upgraded to HTTPS or HTTP/2. * upgraded to HTTPS or HTTP/2.
*/ */
fun inTunnel() = apply { fun inTunnel() =
removeHeader("Content-Length") apply {
inTunnel = true removeHeader("Content-Length")
} inTunnel = true
}
/** /**
* Adds an HTTP 1xx response to precede this response. Note that this response's * Adds an HTTP 1xx response to precede this response. Note that this response's
* [headers delay][headersDelay] applies after this response is transmitted. Set a * [headers delay][headersDelay] applies after this response is transmitted. Set a
* headers delay on that response to delay its transmission. * headers delay on that response to delay its transmission.
*/ */
fun addInformationalResponse(response: MockResponse) = apply { fun addInformationalResponse(response: MockResponse) =
informationalResponses += response apply {
} informationalResponses += response
}
fun add100Continue() = apply { fun add100Continue() =
addInformationalResponse(MockResponse(code = 100)) apply {
} addInformationalResponse(MockResponse(code = 100))
}
public override fun clone(): Builder = build().newBuilder() public override fun clone(): Builder = build().newBuilder()

View File

@ -15,6 +15,7 @@
* limitations under the License. * limitations under the License.
*/ */
@file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE") @file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")
package mockwebserver3 package mockwebserver3
import java.io.Closeable import java.io.Closeable
@ -97,9 +98,10 @@ import okio.source
* in sequence. * in sequence.
*/ */
class MockWebServer : Closeable { class MockWebServer : Closeable {
private val taskRunnerBackend = TaskRunner.RealBackend( private val taskRunnerBackend =
threadFactory("MockWebServer TaskRunner", daemon = false) TaskRunner.RealBackend(
) threadFactory("MockWebServer TaskRunner", daemon = false),
)
private val taskRunner = TaskRunner(taskRunnerBackend) private val taskRunner = TaskRunner(taskRunnerBackend)
private val requestQueue = LinkedBlockingQueue<RecordedRequest>() private val requestQueue = LinkedBlockingQueue<RecordedRequest>()
private val openClientSockets = private val openClientSockets =
@ -126,6 +128,7 @@ class MockWebServer : Closeable {
} }
return field return field
} }
@Synchronized set(value) { @Synchronized set(value) {
check(!started) { "serverSocketFactory must not be set after start()" } check(!started) { "serverSocketFactory must not be set after start()" }
field = value field = value
@ -280,8 +283,10 @@ class MockWebServer : Closeable {
* @return the head of the request queue * @return the head of the request queue
*/ */
@Throws(InterruptedException::class) @Throws(InterruptedException::class)
fun takeRequest(timeout: Long, unit: TimeUnit): RecordedRequest? = fun takeRequest(
requestQueue.poll(timeout, unit) timeout: Long,
unit: TimeUnit,
): RecordedRequest? = requestQueue.poll(timeout, unit)
/** /**
* Scripts [response] to be returned to a request made in sequence. The first request is * Scripts [response] to be returned to a request made in sequence. The first request is
@ -291,8 +296,7 @@ class MockWebServer : Closeable {
* @throws ClassCastException if the default dispatcher has been * @throws ClassCastException if the default dispatcher has been
* replaced with [setDispatcher][dispatcher]. * replaced with [setDispatcher][dispatcher].
*/ */
fun enqueue(response: MockResponse) = fun enqueue(response: MockResponse) = (dispatcher as QueueDispatcher).enqueueResponse(response)
(dispatcher as QueueDispatcher).enqueueResponse(response)
/** /**
* Starts the server on the loopback interface for the given port. * Starts the server on the loopback interface for the given port.
@ -301,7 +305,8 @@ class MockWebServer : Closeable {
* use port 0 to avoid flakiness when a specific port is unavailable. * use port 0 to avoid flakiness when a specific port is unavailable.
*/ */
@Throws(IOException::class) @Throws(IOException::class)
@JvmOverloads fun start(port: Int = 0) = start(InetAddress.getByName("localhost"), port) @JvmOverloads
fun start(port: Int = 0) = start(InetAddress.getByName("localhost"), port)
/** /**
* Starts the server on the given address and port. * Starts the server on the given address and port.
@ -311,14 +316,18 @@ class MockWebServer : Closeable {
* use port 0 to avoid flakiness when a specific port is unavailable. * use port 0 to avoid flakiness when a specific port is unavailable.
*/ */
@Throws(IOException::class) @Throws(IOException::class)
fun start(inetAddress: InetAddress, port: Int) = start(InetSocketAddress(inetAddress, port)) fun start(
inetAddress: InetAddress,
port: Int,
) = start(InetSocketAddress(inetAddress, port))
/** /**
* Starts the server and binds to the given socket address. * Starts the server and binds to the given socket address.
* *
* @param inetSocketAddress the socket address to bind the server on * @param inetSocketAddress the socket address to bind the server on
*/ */
@Synchronized @Throws(IOException::class) @Synchronized
@Throws(IOException::class)
private fun start(inetSocketAddress: InetSocketAddress) { private fun start(inetSocketAddress: InetSocketAddress) {
check(!shutdown) { "shutdown() already called" } check(!shutdown) { "shutdown() already called" }
if (started) return if (started) return
@ -432,9 +441,10 @@ class MockWebServer : Closeable {
processHandshakeFailure(raw) processHandshakeFailure(raw)
return return
} }
socket = sslSocketFactory!!.createSocket( socket =
raw, raw.inetAddress.hostAddress, raw.port, true sslSocketFactory!!.createSocket(
) raw, raw.inetAddress.hostAddress, raw.port, true,
)
val sslSocket = socket as SSLSocket val sslSocket = socket as SSLSocket
sslSocket.useClientMode = false sslSocket.useClientMode = false
if (clientAuth == CLIENT_AUTH_REQUIRED) { if (clientAuth == CLIENT_AUTH_REQUIRED) {
@ -452,10 +462,11 @@ class MockWebServer : Closeable {
if (protocolNegotiationEnabled) { if (protocolNegotiationEnabled) {
val protocolString = Platform.get().getSelectedProtocol(sslSocket) val protocolString = Platform.get().getSelectedProtocol(sslSocket)
protocol = when { protocol =
protocolString != null -> Protocol.get(protocolString) when {
else -> Protocol.HTTP_1_1 protocolString != null -> Protocol.get(protocolString)
} else -> Protocol.HTTP_1_1
}
Platform.get().afterHandshake(sslSocket) Platform.get().afterHandshake(sslSocket)
} else { } else {
protocol = Protocol.HTTP_1_1 protocol = Protocol.HTTP_1_1
@ -463,10 +474,11 @@ class MockWebServer : Closeable {
openClientSockets.remove(raw) openClientSockets.remove(raw)
} }
else -> { else -> {
protocol = when { protocol =
Protocol.H2_PRIOR_KNOWLEDGE in protocols -> Protocol.H2_PRIOR_KNOWLEDGE when {
else -> Protocol.HTTP_1_1 Protocol.H2_PRIOR_KNOWLEDGE in protocols -> Protocol.H2_PRIOR_KNOWLEDGE
} else -> Protocol.HTTP_1_1
}
socket = raw socket = raw
} }
} }
@ -478,10 +490,11 @@ class MockWebServer : Closeable {
if (protocol === Protocol.HTTP_2 || protocol === Protocol.H2_PRIOR_KNOWLEDGE) { if (protocol === Protocol.HTTP_2 || protocol === Protocol.H2_PRIOR_KNOWLEDGE) {
val http2SocketHandler = Http2SocketHandler(socket, protocol) val http2SocketHandler = Http2SocketHandler(socket, protocol)
val connection = Http2Connection.Builder(false, taskRunner) val connection =
.socket(socket) Http2Connection.Builder(false, taskRunner)
.listener(http2SocketHandler) .socket(socket)
.build() .listener(http2SocketHandler)
.build()
connection.start() connection.start()
openConnections.add(connection) openConnections.add(connection)
openClientSockets.remove(socket) openClientSockets.remove(socket)
@ -498,7 +511,7 @@ class MockWebServer : Closeable {
if (sequenceNumber == 0) { if (sequenceNumber == 0) {
logger.warning( logger.warning(
"${this@MockWebServer} connection from ${raw.inetAddress} didn't make a request" "${this@MockWebServer} connection from ${raw.inetAddress} didn't make a request",
) )
} }
@ -538,7 +551,7 @@ class MockWebServer : Closeable {
private fun processOneRequest( private fun processOneRequest(
socket: Socket, socket: Socket,
source: BufferedSource, source: BufferedSource,
sink: BufferedSink sink: BufferedSink,
): Boolean { ): Boolean {
if (source.exhausted()) { if (source.exhausted()) {
return false // No more requests on this socket. return false // No more requests on this socket.
@ -581,7 +594,7 @@ class MockWebServer : Closeable {
if (logger.isLoggable(Level.FINE)) { if (logger.isLoggable(Level.FINE)) {
logger.fine( logger.fine(
"${this@MockWebServer} received request: $request and responded: $response" "${this@MockWebServer} received request: $request and responded: $response",
) )
} }
@ -607,9 +620,13 @@ class MockWebServer : Closeable {
val context = SSLContext.getInstance("TLS") val context = SSLContext.getInstance("TLS")
context.init(null, arrayOf<TrustManager>(UNTRUSTED_TRUST_MANAGER), SecureRandom()) context.init(null, arrayOf<TrustManager>(UNTRUSTED_TRUST_MANAGER), SecureRandom())
val sslSocketFactory = context.socketFactory val sslSocketFactory = context.socketFactory
val socket = sslSocketFactory.createSocket( val socket =
raw, raw.inetAddress.hostAddress, raw.port, true sslSocketFactory.createSocket(
) as SSLSocket raw,
raw.inetAddress.hostAddress,
raw.port,
true,
) as SSLSocket
try { try {
socket.startHandshake() // we're testing a handshake failure socket.startHandshake() // we're testing a handshake failure
throw AssertionError() throw AssertionError()
@ -619,10 +636,20 @@ class MockWebServer : Closeable {
} }
@Throws(InterruptedException::class) @Throws(InterruptedException::class)
private fun dispatchBookkeepingRequest(sequenceNumber: Int, socket: Socket) { private fun dispatchBookkeepingRequest(
val request = RecordedRequest( sequenceNumber: Int,
"", headersOf(), emptyList(), 0L, Buffer(), sequenceNumber, socket socket: Socket,
) ) {
val request =
RecordedRequest(
"",
headersOf(),
emptyList(),
0L,
Buffer(),
sequenceNumber,
socket,
)
atomicRequestCount.incrementAndGet() atomicRequestCount.incrementAndGet()
requestQueue.add(request) requestQueue.add(request)
dispatcher.dispatch(request) dispatcher.dispatch(request)
@ -634,7 +661,7 @@ class MockWebServer : Closeable {
socket: Socket, socket: Socket,
source: BufferedSource, source: BufferedSource,
sink: BufferedSink, sink: BufferedSink,
sequenceNumber: Int sequenceNumber: Int,
): RecordedRequest { ): RecordedRequest {
var request = "" var request = ""
val headers = Headers.Builder() val headers = Headers.Builder()
@ -674,12 +701,13 @@ class MockWebServer : Closeable {
var hasBody = false var hasBody = false
val policy = dispatcher.peek() val policy = dispatcher.peek()
val requestBodySink = requestBody.withThrottlingAndSocketPolicy( val requestBodySink =
policy = policy, requestBody.withThrottlingAndSocketPolicy(
disconnectHalfway = policy.socketPolicy == DisconnectDuringRequestBody, policy = policy,
expectedByteCount = contentLength, disconnectHalfway = policy.socketPolicy == DisconnectDuringRequestBody,
socket = socket, expectedByteCount = contentLength,
).buffer() socket = socket,
).buffer()
requestBodySink.use { requestBodySink.use {
when { when {
policy.socketPolicy is DoNotReadRequestBody -> { policy.socketPolicy is DoNotReadRequestBody -> {
@ -725,7 +753,7 @@ class MockWebServer : Closeable {
body = requestBody.buffer, body = requestBody.buffer,
sequenceNumber = sequenceNumber, sequenceNumber = sequenceNumber,
socket = socket, socket = socket,
failure = failure failure = failure,
) )
} }
@ -735,47 +763,52 @@ class MockWebServer : Closeable {
source: BufferedSource, source: BufferedSource,
sink: BufferedSink, sink: BufferedSink,
request: RecordedRequest, request: RecordedRequest,
response: MockResponse response: MockResponse,
) { ) {
val key = request.headers["Sec-WebSocket-Key"] val key = request.headers["Sec-WebSocket-Key"]
val webSocketResponse = response.newBuilder() val webSocketResponse =
.setHeader("Sec-WebSocket-Accept", WebSocketProtocol.acceptHeader(key!!)) response.newBuilder()
.build() .setHeader("Sec-WebSocket-Accept", WebSocketProtocol.acceptHeader(key!!))
.build()
writeHttpResponse(socket, sink, webSocketResponse) writeHttpResponse(socket, sink, webSocketResponse)
// Adapt the request and response into our Request and Response domain model. // Adapt the request and response into our Request and Response domain model.
val scheme = if (request.handshake != null) "https" else "http" val scheme = if (request.handshake != null) "https" else "http"
val authority = request.headers["Host"] // Has host and port. val authority = request.headers["Host"] // Has host and port.
val fancyRequest = Request.Builder() val fancyRequest =
.url("$scheme://$authority/") Request.Builder()
.headers(request.headers) .url("$scheme://$authority/")
.build() .headers(request.headers)
val fancyResponse = Response.Builder() .build()
.code(webSocketResponse.code) val fancyResponse =
.message(webSocketResponse.message) Response.Builder()
.headers(webSocketResponse.headers) .code(webSocketResponse.code)
.request(fancyRequest) .message(webSocketResponse.message)
.protocol(Protocol.HTTP_1_1) .headers(webSocketResponse.headers)
.build() .request(fancyRequest)
.protocol(Protocol.HTTP_1_1)
.build()
val connectionClose = CountDownLatch(1) val connectionClose = CountDownLatch(1)
val streams = object : RealWebSocket.Streams(false, source, sink) { val streams =
override fun close() = connectionClose.countDown() object : RealWebSocket.Streams(false, source, sink) {
override fun close() = connectionClose.countDown()
override fun cancel() { override fun cancel() {
socket.closeQuietly() socket.closeQuietly()
}
} }
} val webSocket =
val webSocket = RealWebSocket( RealWebSocket(
taskRunner = taskRunner, taskRunner = taskRunner,
originalRequest = fancyRequest, originalRequest = fancyRequest,
listener = webSocketResponse.webSocketListener!!, listener = webSocketResponse.webSocketListener!!,
random = SecureRandom(), random = SecureRandom(),
pingIntervalMillis = 0, pingIntervalMillis = 0,
extensions = WebSocketExtensions.parse(webSocketResponse.headers), extensions = WebSocketExtensions.parse(webSocketResponse.headers),
// Compress all messages if compression is enabled. // Compress all messages if compression is enabled.
minimumDeflateSize = 0L, minimumDeflateSize = 0L,
) )
val name = "MockWebServer WebSocket ${request.path!!}" val name = "MockWebServer WebSocket ${request.path!!}"
webSocket.initReaderAndWriter(name, streams) webSocket.initReaderAndWriter(name, streams)
try { try {
@ -789,7 +822,11 @@ class MockWebServer : Closeable {
} }
@Throws(IOException::class) @Throws(IOException::class)
private fun writeHttpResponse(socket: Socket, sink: BufferedSink, response: MockResponse) { private fun writeHttpResponse(
socket: Socket,
sink: BufferedSink,
response: MockResponse,
) {
sleepNanos(response.headersDelayNanos) sleepNanos(response.headersDelayNanos)
sink.writeUtf8(response.status) sink.writeUtf8(response.status)
sink.writeUtf8("\r\n") sink.writeUtf8("\r\n")
@ -798,12 +835,13 @@ class MockWebServer : Closeable {
val body = response.body ?: return val body = response.body ?: return
sleepNanos(response.bodyDelayNanos) sleepNanos(response.bodyDelayNanos)
val responseBodySink = sink.withThrottlingAndSocketPolicy( val responseBodySink =
policy = response, sink.withThrottlingAndSocketPolicy(
disconnectHalfway = response.socketPolicy == DisconnectDuringResponseBody, policy = response,
expectedByteCount = body.contentLength, disconnectHalfway = response.socketPolicy == DisconnectDuringResponseBody,
socket = socket, expectedByteCount = body.contentLength,
).buffer() socket = socket,
).buffer()
body.writeTo(responseBodySink) body.writeTo(responseBodySink)
responseBodySink.emit() responseBodySink.emit()
@ -813,7 +851,10 @@ class MockWebServer : Closeable {
} }
@Throws(IOException::class) @Throws(IOException::class)
private fun writeHeaders(sink: BufferedSink, headers: Headers) { private fun writeHeaders(
sink: BufferedSink,
headers: Headers,
) {
for ((name, value) in headers) { for ((name, value) in headers) {
sink.writeUtf8(name) sink.writeUtf8(name)
sink.writeUtf8(": ") sink.writeUtf8(": ")
@ -834,25 +875,28 @@ class MockWebServer : Closeable {
var result: Sink = this var result: Sink = this
if (policy.throttlePeriodNanos > 0L) { if (policy.throttlePeriodNanos > 0L) {
result = ThrottledSink( result =
delegate = result, ThrottledSink(
bytesPerPeriod = policy.throttleBytesPerPeriod, delegate = result,
periodDelayNanos = policy.throttlePeriodNanos, bytesPerPeriod = policy.throttleBytesPerPeriod,
) periodDelayNanos = policy.throttlePeriodNanos,
)
} }
if (disconnectHalfway) { if (disconnectHalfway) {
val halfwayByteCount = when { val halfwayByteCount =
expectedByteCount != -1L -> expectedByteCount / 2 when {
else -> 0L expectedByteCount != -1L -> expectedByteCount / 2
} else -> 0L
result = TriggerSink( }
delegate = result, result =
triggerByteCount = halfwayByteCount, TriggerSink(
) { delegate = result,
result.flush() triggerByteCount = halfwayByteCount,
socket.close() ) {
} result.flush()
socket.close()
}
} }
return result return result
@ -871,13 +915,16 @@ class MockWebServer : Closeable {
/** A buffer wrapper that drops data after [bodyLimit] bytes. */ /** A buffer wrapper that drops data after [bodyLimit] bytes. */
private class TruncatingBuffer( private class TruncatingBuffer(
private var remainingByteCount: Long private var remainingByteCount: Long,
) : Sink { ) : Sink {
val buffer = Buffer() val buffer = Buffer()
var receivedByteCount = 0L var receivedByteCount = 0L
@Throws(IOException::class) @Throws(IOException::class)
override fun write(source: Buffer, byteCount: Long) { override fun write(
source: Buffer,
byteCount: Long,
) {
val toRead = minOf(remainingByteCount, byteCount) val toRead = minOf(remainingByteCount, byteCount)
if (toRead > 0L) { if (toRead > 0L) {
source.read(buffer, toRead) source.read(buffer, toRead)
@ -904,7 +951,7 @@ class MockWebServer : Closeable {
/** Processes HTTP requests layered over HTTP/2. */ /** Processes HTTP requests layered over HTTP/2. */
private inner class Http2SocketHandler constructor( private inner class Http2SocketHandler constructor(
private val socket: Socket, private val socket: Socket,
private val protocol: Protocol private val protocol: Protocol,
) : Http2Connection.Listener() { ) : Http2Connection.Listener() {
private val sequenceNumber = AtomicInteger() private val sequenceNumber = AtomicInteger()
@ -935,7 +982,7 @@ class MockWebServer : Closeable {
if (logger.isLoggable(Level.FINE)) { if (logger.isLoggable(Level.FINE)) {
logger.fine( logger.fine(
"${this@MockWebServer} received request: $request " + "${this@MockWebServer} received request: $request " +
"and responded: $response protocol is $protocol" "and responded: $response protocol is $protocol",
) )
} }
@ -990,12 +1037,13 @@ class MockWebServer : Closeable {
if (readBody && peek.streamHandler == null && peek.socketPolicy !is DoNotReadRequestBody) { if (readBody && peek.streamHandler == null && peek.socketPolicy !is DoNotReadRequestBody) {
try { try {
val contentLengthString = headers["content-length"] val contentLengthString = headers["content-length"]
val requestBodySink = body.withThrottlingAndSocketPolicy( val requestBodySink =
policy = peek, body.withThrottlingAndSocketPolicy(
disconnectHalfway = peek.socketPolicy == DisconnectDuringRequestBody, policy = peek,
expectedByteCount = contentLengthString?.toLong() ?: Long.MAX_VALUE, disconnectHalfway = peek.socketPolicy == DisconnectDuringRequestBody,
socket = socket, expectedByteCount = contentLengthString?.toLong() ?: Long.MAX_VALUE,
).buffer() socket = socket,
).buffer()
requestBodySink.use { requestBodySink.use {
it.writeAll(stream.getSource()) it.writeAll(stream.getSource())
} }
@ -1012,7 +1060,7 @@ class MockWebServer : Closeable {
body = body, body = body,
sequenceNumber = sequenceNumber.getAndIncrement(), sequenceNumber = sequenceNumber.getAndIncrement(),
socket = socket, socket = socket,
failure = exception failure = exception,
) )
} }
@ -1029,7 +1077,7 @@ class MockWebServer : Closeable {
private fun writeResponse( private fun writeResponse(
stream: Http2Stream, stream: Http2Stream,
request: RecordedRequest, request: RecordedRequest,
response: MockResponse response: MockResponse,
) { ) {
val settings = response.settings val settings = response.settings
stream.connection.setSettings(settings) stream.connection.setSettings(settings)
@ -1042,9 +1090,11 @@ class MockWebServer : Closeable {
val trailers = response.trailers val trailers = response.trailers
val body = response.body val body = response.body
val streamHandler = response.streamHandler val streamHandler = response.streamHandler
val outFinished = (body == null && val outFinished = (
response.pushPromises.isEmpty() && body == null &&
streamHandler == null) response.pushPromises.isEmpty() &&
streamHandler == null
)
val flushHeaders = body == null || bodyDelayNanos != 0L val flushHeaders = body == null || bodyDelayNanos != 0L
require(!outFinished || trailers.size == 0) { require(!outFinished || trailers.size == 0) {
"unsupported: no body and non-empty trailers $trailers" "unsupported: no body and non-empty trailers $trailers"
@ -1059,12 +1109,13 @@ class MockWebServer : Closeable {
pushPromises(stream, request, response.pushPromises) pushPromises(stream, request, response.pushPromises)
if (body != null) { if (body != null) {
sleepNanos(bodyDelayNanos) sleepNanos(bodyDelayNanos)
val responseBodySink = stream.getSink().withThrottlingAndSocketPolicy( val responseBodySink =
policy = response, stream.getSink().withThrottlingAndSocketPolicy(
disconnectHalfway = response.socketPolicy == DisconnectDuringResponseBody, policy = response,
expectedByteCount = body.contentLength, disconnectHalfway = response.socketPolicy == DisconnectDuringResponseBody,
socket = socket expectedByteCount = body.contentLength,
).buffer() socket = socket,
).buffer()
responseBodySink.use { responseBodySink.use {
body.writeTo(responseBodySink) body.writeTo(responseBodySink)
} }
@ -1079,7 +1130,7 @@ class MockWebServer : Closeable {
private fun pushPromises( private fun pushPromises(
stream: Http2Stream, stream: Http2Stream,
request: RecordedRequest, request: RecordedRequest,
promises: List<PushPromise> promises: List<PushPromise>,
) { ) {
for (pushPromise in promises) { for (pushPromise in promises) {
val pushedHeaders = mutableListOf<Header>() val pushedHeaders = mutableListOf<Header>()
@ -1100,8 +1151,8 @@ class MockWebServer : Closeable {
bodySize = 0, bodySize = 0,
body = Buffer(), body = Buffer(),
sequenceNumber = sequenceNumber.getAndIncrement(), sequenceNumber = sequenceNumber.getAndIncrement(),
socket = socket socket = socket,
) ),
) )
val hasBody = pushPromise.response.body != null val hasBody = pushPromise.response.body != null
val pushedStream = stream.connection.pushStream(stream.id, pushedHeaders, hasBody) val pushedStream = stream.connection.pushStream(stream.id, pushedHeaders, hasBody)
@ -1115,20 +1166,21 @@ class MockWebServer : Closeable {
private const val CLIENT_AUTH_REQUESTED = 1 private const val CLIENT_AUTH_REQUESTED = 1
private const val CLIENT_AUTH_REQUIRED = 2 private const val CLIENT_AUTH_REQUIRED = 2
private val UNTRUSTED_TRUST_MANAGER = object : X509TrustManager { private val UNTRUSTED_TRUST_MANAGER =
@Throws(CertificateException::class) object : X509TrustManager {
override fun checkClientTrusted( @Throws(CertificateException::class)
chain: Array<X509Certificate>, override fun checkClientTrusted(
authType: String chain: Array<X509Certificate>,
) = throw CertificateException() authType: String,
) = throw CertificateException()
override fun checkServerTrusted( override fun checkServerTrusted(
chain: Array<X509Certificate>, chain: Array<X509Certificate>,
authType: String authType: String,
) = throw AssertionError() ) = throw AssertionError()
override fun getAcceptedIssuers(): Array<X509Certificate> = throw AssertionError() override fun getAcceptedIssuers(): Array<X509Certificate> = throw AssertionError()
} }
private val logger = Logger.getLogger(MockWebServer::class.java.name) private val logger = Logger.getLogger(MockWebServer::class.java.name)
} }

View File

@ -68,11 +68,12 @@ open class QueueDispatcher : Dispatcher() {
} }
open fun setFailFast(failFast: Boolean) { open fun setFailFast(failFast: Boolean) {
val failFastResponse = if (failFast) { val failFastResponse =
MockResponse(code = HttpURLConnection.HTTP_NOT_FOUND) if (failFast) {
} else { MockResponse(code = HttpURLConnection.HTTP_NOT_FOUND)
null } else {
} null
}
setFailFast(failFastResponse) setFailFast(failFastResponse)
} }

View File

@ -31,34 +31,28 @@ import okio.Buffer
/** An HTTP request that came into the mock web server. */ /** An HTTP request that came into the mock web server. */
class RecordedRequest( class RecordedRequest(
val requestLine: String, val requestLine: String,
/** All headers. */ /** All headers. */
val headers: Headers, val headers: Headers,
/** /**
* The sizes of the chunks of this request's body, or an empty list if the request's body * The sizes of the chunks of this request's body, or an empty list if the request's body
* was empty or unchunked. * was empty or unchunked.
*/ */
val chunkSizes: List<Int>, val chunkSizes: List<Int>,
/** The total size of the body of this POST request (before truncation).*/ /** The total size of the body of this POST request (before truncation).*/
val bodySize: Long, val bodySize: Long,
/** The body of this POST request. This may be truncated. */ /** The body of this POST request. This may be truncated. */
val body: Buffer, val body: Buffer,
/** /**
* The index of this request on its HTTP connection. Since a single HTTP connection may serve * The index of this request on its HTTP connection. Since a single HTTP connection may serve
* multiple requests, each request is assigned its own sequence number. * multiple requests, each request is assigned its own sequence number.
*/ */
val sequenceNumber: Int, val sequenceNumber: Int,
socket: Socket, socket: Socket,
/** /**
* The failure MockWebServer recorded when attempting to decode this request. If, for example, * The failure MockWebServer recorded when attempting to decode this request. If, for example,
* the inbound request was truncated, this exception will be non-null. * the inbound request was truncated, this exception will be non-null.
*/ */
val failure: IOException? = null val failure: IOException? = null,
) { ) {
val method: String? val method: String?
val path: String? val path: String?
@ -102,12 +96,13 @@ class RecordedRequest(
val scheme = if (socket is SSLSocket) "https" else "http" val scheme = if (socket is SSLSocket) "https" else "http"
val localPort = socket.localPort val localPort = socket.localPort
val hostAndPort = headers[":authority"] val hostAndPort =
?: headers["Host"] headers[":authority"]
?: when (val inetAddress = socket.localAddress) { ?: headers["Host"]
is Inet6Address -> "[${inetAddress.hostAddress}]:$localPort" ?: when (val inetAddress = socket.localAddress) {
else -> "${inetAddress.hostAddress}:$localPort" is Inet6Address -> "[${inetAddress.hostAddress}]:$localPort"
} else -> "${inetAddress.hostAddress}:$localPort"
}
// Allow null in failure case to allow for testing bad requests // Allow null in failure case to allow for testing bad requests
this.requestUrl = "$scheme://$hostAndPort$path".toHttpUrlOrNull() this.requestUrl = "$scheme://$hostAndPort$path".toHttpUrlOrNull()

View File

@ -30,7 +30,6 @@ package mockwebserver3
* server has closed the socket before follow up requests are made. * server has closed the socket before follow up requests are made.
*/ */
sealed interface SocketPolicy { sealed interface SocketPolicy {
/** /**
* Shutdown [MockWebServer] after writing response. * Shutdown [MockWebServer] after writing response.
*/ */

View File

@ -29,7 +29,11 @@ internal class ThrottledSink(
private val periodDelayNanos: Long, private val periodDelayNanos: Long,
) : Sink by delegate { ) : Sink by delegate {
private var bytesWrittenSinceLastDelay = 0L private var bytesWrittenSinceLastDelay = 0L
override fun write(source: Buffer, byteCount: Long) {
override fun write(
source: Buffer,
byteCount: Long,
) {
var bytesLeft = byteCount var bytesLeft = byteCount
while (bytesLeft > 0) { while (bytesLeft > 0) {

View File

@ -30,7 +30,10 @@ internal class TriggerSink(
) : Sink by delegate { ) : Sink by delegate {
private var bytesWritten = 0L private var bytesWritten = 0L
override fun write(source: Buffer, byteCount: Long) { override fun write(
source: Buffer,
byteCount: Long,
) {
if (byteCount == 0L) return // Avoid double-triggering. if (byteCount == 0L) return // Avoid double-triggering.
if (bytesWritten == triggerByteCount) { if (bytesWritten == triggerByteCount) {

View File

@ -34,37 +34,41 @@ class MockStreamHandler : StreamHandler {
private val actions = LinkedBlockingQueue<Action>() private val actions = LinkedBlockingQueue<Action>()
private val results = LinkedBlockingQueue<FutureTask<Void>>() private val results = LinkedBlockingQueue<FutureTask<Void>>()
fun receiveRequest(expected: String) = apply { fun receiveRequest(expected: String) =
actions += { stream -> apply {
val actual = stream.requestBody.readUtf8(expected.utf8Size()) actions += { stream ->
if (actual != expected) throw AssertionError("$actual != $expected") val actual = stream.requestBody.readUtf8(expected.utf8Size())
} if (actual != expected) throw AssertionError("$actual != $expected")
} }
}
fun exhaustRequest() = apply {
actions += { stream -> fun exhaustRequest() =
if (!stream.requestBody.exhausted()) throw AssertionError("expected exhausted") apply {
} actions += { stream ->
} if (!stream.requestBody.exhausted()) throw AssertionError("expected exhausted")
}
fun cancelStream() = apply { }
actions += { stream -> stream.cancel() }
} fun cancelStream() =
apply {
fun requestIOException() = apply { actions += { stream -> stream.cancel() }
actions += { stream -> }
try {
stream.requestBody.exhausted() fun requestIOException() =
throw AssertionError("expected IOException") apply {
} catch (expected: IOException) { actions += { stream ->
try {
stream.requestBody.exhausted()
throw AssertionError("expected IOException")
} catch (expected: IOException) {
}
} }
} }
}
@JvmOverloads @JvmOverloads
fun sendResponse( fun sendResponse(
s: String, s: String,
responseSent: CountDownLatch = CountDownLatch(0) responseSent: CountDownLatch = CountDownLatch(0),
) = apply { ) = apply {
actions += { stream -> actions += { stream ->
stream.responseBody.writeUtf8(s) stream.responseBody.writeUtf8(s)
@ -73,11 +77,15 @@ class MockStreamHandler : StreamHandler {
} }
} }
fun exhaustResponse() = apply { fun exhaustResponse() =
actions += { stream -> stream.responseBody.close() } apply {
} actions += { stream -> stream.responseBody.close() }
}
fun sleep(duration: Long, unit: TimeUnit) = apply { fun sleep(
duration: Long,
unit: TimeUnit,
) = apply {
actions += { Thread.sleep(unit.toMillis(duration)) } actions += { Thread.sleep(unit.toMillis(duration)) }
} }
@ -88,9 +96,7 @@ class MockStreamHandler : StreamHandler {
} }
/** Returns a task that processes both request and response from [stream]. */ /** Returns a task that processes both request and response from [stream]. */
private fun serviceStreamTask( private fun serviceStreamTask(stream: Stream): FutureTask<Void> {
stream: Stream,
): FutureTask<Void> {
return FutureTask<Void> { return FutureTask<Void> {
stream.requestBody.use { stream.requestBody.use {
stream.responseBody.use { stream.responseBody.use {
@ -106,8 +112,9 @@ class MockStreamHandler : StreamHandler {
/** Returns once all stream actions complete successfully. */ /** Returns once all stream actions complete successfully. */
fun awaitSuccess() { fun awaitSuccess() {
val futureTask = results.poll(5, TimeUnit.SECONDS) val futureTask =
?: throw AssertionError("no onRequest call received") results.poll(5, TimeUnit.SECONDS)
?: throw AssertionError("no onRequest call received")
futureTask.get(5, TimeUnit.SECONDS) futureTask.get(5, TimeUnit.SECONDS)
} }
} }

View File

@ -37,12 +37,13 @@ class CustomDispatcherTest {
@Test @Test
fun simpleDispatch() { fun simpleDispatch() {
val requestsMade = mutableListOf<RecordedRequest>() val requestsMade = mutableListOf<RecordedRequest>()
val dispatcher: Dispatcher = object : Dispatcher() { val dispatcher: Dispatcher =
override fun dispatch(request: RecordedRequest): MockResponse { object : Dispatcher() {
requestsMade.add(request) override fun dispatch(request: RecordedRequest): MockResponse {
return MockResponse() requestsMade.add(request)
return MockResponse()
}
} }
}
assertThat(requestsMade.size).isEqualTo(0) assertThat(requestsMade.size).isEqualTo(0)
mockWebServer.dispatcher = dispatcher mockWebServer.dispatcher = dispatcher
val url = mockWebServer.url("/").toUrl() val url = mockWebServer.url("/").toUrl()
@ -59,14 +60,15 @@ class CustomDispatcherTest {
val secondRequest = "/bar" val secondRequest = "/bar"
val firstRequest = "/foo" val firstRequest = "/foo"
val latch = CountDownLatch(1) val latch = CountDownLatch(1)
val dispatcher: Dispatcher = object : Dispatcher() { val dispatcher: Dispatcher =
override fun dispatch(request: RecordedRequest): MockResponse { object : Dispatcher() {
if (request.path == firstRequest) { override fun dispatch(request: RecordedRequest): MockResponse {
latch.await() if (request.path == firstRequest) {
latch.await()
}
return MockResponse()
} }
return MockResponse()
} }
}
mockWebServer.dispatcher = dispatcher mockWebServer.dispatcher = dispatcher
val startsFirst = buildRequestThread(firstRequest, firstResponseCode) val startsFirst = buildRequestThread(firstRequest, firstResponseCode)
startsFirst.start() startsFirst.start()
@ -85,7 +87,10 @@ class CustomDispatcherTest {
assertThat(secondResponseCode.get()).isEqualTo(200) assertThat(secondResponseCode.get()).isEqualTo(200)
} }
private fun buildRequestThread(path: String, responseCode: AtomicInteger): Thread { private fun buildRequestThread(
path: String,
responseCode: AtomicInteger,
): Thread {
return Thread { return Thread {
val url = mockWebServer.url(path).toUrl() val url = mockWebServer.url(path).toUrl()
val conn: HttpURLConnection val conn: HttpURLConnection

View File

@ -55,17 +55,19 @@ class MockResponseSniTest {
val handshakeCertificates = localhost() val handshakeCertificates = localhost()
server.useHttps(handshakeCertificates.sslSocketFactory()) server.useHttps(handshakeCertificates.sslSocketFactory())
val dns = Dns { val dns =
Dns.SYSTEM.lookup(server.hostName) Dns {
} Dns.SYSTEM.lookup(server.hostName)
}
val client = clientTestRule.newClientBuilder() val client =
.sslSocketFactory( clientTestRule.newClientBuilder()
handshakeCertificates.sslSocketFactory(), .sslSocketFactory(
handshakeCertificates.trustManager handshakeCertificates.sslSocketFactory(),
) handshakeCertificates.trustManager,
.dns(dns) )
.build() .dns(dns)
.build()
server.enqueue(MockResponse()) server.enqueue(MockResponse())
@ -84,36 +86,41 @@ class MockResponseSniTest {
*/ */
@Test @Test
fun domainFronting() { fun domainFronting() {
val heldCertificate = HeldCertificate.Builder() val heldCertificate =
.commonName("server name") HeldCertificate.Builder()
.addSubjectAlternativeName("url-host.com") .commonName("server name")
.build() .addSubjectAlternativeName("url-host.com")
val handshakeCertificates = HandshakeCertificates.Builder() .build()
.heldCertificate(heldCertificate) val handshakeCertificates =
.addTrustedCertificate(heldCertificate.certificate) HandshakeCertificates.Builder()
.build() .heldCertificate(heldCertificate)
.addTrustedCertificate(heldCertificate.certificate)
.build()
server.useHttps(handshakeCertificates.sslSocketFactory()) server.useHttps(handshakeCertificates.sslSocketFactory())
val dns = Dns { val dns =
Dns.SYSTEM.lookup(server.hostName) Dns {
} Dns.SYSTEM.lookup(server.hostName)
}
val client = clientTestRule.newClientBuilder() val client =
.sslSocketFactory( clientTestRule.newClientBuilder()
handshakeCertificates.sslSocketFactory(), .sslSocketFactory(
handshakeCertificates.trustManager handshakeCertificates.sslSocketFactory(),
) handshakeCertificates.trustManager,
.dns(dns) )
.build() .dns(dns)
.build()
server.enqueue(MockResponse()) server.enqueue(MockResponse())
val call = client.newCall( val call =
Request( client.newCall(
url = "https://url-host.com:${server.port}/".toHttpUrl(), Request(
headers = headersOf("Host", "header-host"), url = "https://url-host.com:${server.port}/".toHttpUrl(),
headers = headersOf("Host", "header-host"),
),
) )
)
val response = call.execute() val response = call.execute()
assertThat(response.isSuccessful).isTrue() assertThat(response.isSuccessful).isTrue()
@ -150,34 +157,39 @@ class MockResponseSniTest {
* tell MockWebServer to act as a proxy. * tell MockWebServer to act as a proxy.
*/ */
private fun requestToHostnameViaProxy(hostnameOrIpAddress: String): RecordedRequest { private fun requestToHostnameViaProxy(hostnameOrIpAddress: String): RecordedRequest {
val heldCertificate = HeldCertificate.Builder() val heldCertificate =
.commonName("server name") HeldCertificate.Builder()
.addSubjectAlternativeName(hostnameOrIpAddress) .commonName("server name")
.build() .addSubjectAlternativeName(hostnameOrIpAddress)
val handshakeCertificates = HandshakeCertificates.Builder() .build()
.heldCertificate(heldCertificate) val handshakeCertificates =
.addTrustedCertificate(heldCertificate.certificate) HandshakeCertificates.Builder()
.build() .heldCertificate(heldCertificate)
.addTrustedCertificate(heldCertificate.certificate)
.build()
server.useHttps(handshakeCertificates.sslSocketFactory()) server.useHttps(handshakeCertificates.sslSocketFactory())
val client = clientTestRule.newClientBuilder() val client =
.sslSocketFactory( clientTestRule.newClientBuilder()
handshakeCertificates.sslSocketFactory(), .sslSocketFactory(
handshakeCertificates.trustManager handshakeCertificates.sslSocketFactory(),
) handshakeCertificates.trustManager,
.proxy(server.toProxyAddress()) )
.build() .proxy(server.toProxyAddress())
.build()
server.enqueue(MockResponse(inTunnel = true)) server.enqueue(MockResponse(inTunnel = true))
server.enqueue(MockResponse()) server.enqueue(MockResponse())
val call = client.newCall( val call =
Request( client.newCall(
url = server.url("/").newBuilder() Request(
.host(hostnameOrIpAddress) url =
.build() server.url("/").newBuilder()
.host(hostnameOrIpAddress)
.build(),
),
) )
)
val response = call.execute() val response = call.execute()
assertThat(response.isSuccessful).isTrue() assertThat(response.isSuccessful).isTrue()

View File

@ -92,15 +92,16 @@ class MockWebServerTest {
@Test @Test
fun setResponseMockReason() { fun setResponseMockReason() {
val reasons = arrayOf<String?>( val reasons =
"Mock Response", arrayOf<String?>(
"Informational", "Mock Response",
"OK", "Informational",
"Redirection", "OK",
"Client Error", "Redirection",
"Server Error", "Client Error",
"Mock Response" "Server Error",
) "Mock Response",
)
for (i in 0..599) { for (i in 0..599) {
val builder = MockResponse.Builder().code(i) val builder = MockResponse.Builder().code(i)
val expectedReason = reasons[i / 100] val expectedReason = reasons[i / 100]
@ -128,30 +129,33 @@ class MockWebServerTest {
@Test @Test
fun mockResponseAddHeader() { fun mockResponseAddHeader() {
val builder = MockResponse.Builder() val builder =
.clearHeaders() MockResponse.Builder()
.addHeader("Cookie: s=square") .clearHeaders()
.addHeader("Cookie", "a=android") .addHeader("Cookie: s=square")
.addHeader("Cookie", "a=android")
assertThat(headersToList(builder)).containsExactly("Cookie: s=square", "Cookie: a=android") assertThat(headersToList(builder)).containsExactly("Cookie: s=square", "Cookie: a=android")
} }
@Test @Test
fun mockResponseSetHeader() { fun mockResponseSetHeader() {
val builder = MockResponse.Builder() val builder =
.clearHeaders() MockResponse.Builder()
.addHeader("Cookie: s=square") .clearHeaders()
.addHeader("Cookie: a=android") .addHeader("Cookie: s=square")
.addHeader("Cookies: delicious") .addHeader("Cookie: a=android")
.addHeader("Cookies: delicious")
builder.setHeader("cookie", "r=robot") builder.setHeader("cookie", "r=robot")
assertThat(headersToList(builder)).containsExactly("Cookies: delicious", "cookie: r=robot") assertThat(headersToList(builder)).containsExactly("Cookies: delicious", "cookie: r=robot")
} }
@Test @Test
fun mockResponseSetHeaders() { fun mockResponseSetHeaders() {
val builder = MockResponse.Builder() val builder =
.clearHeaders() MockResponse.Builder()
.addHeader("Cookie: s=square") .clearHeaders()
.addHeader("Cookies: delicious") .addHeader("Cookie: s=square")
.addHeader("Cookies: delicious")
builder.headers(Headers.Builder().add("Cookie", "a=android").build()) builder.headers(Headers.Builder().add("Cookie", "a=android").build())
assertThat(headersToList(builder)).containsExactly("Cookie: a=android") assertThat(headersToList(builder)).containsExactly("Cookie: a=android")
} }
@ -175,14 +179,18 @@ class MockWebServerTest {
@Test @Test
fun redirect() { fun redirect() {
server.enqueue(MockResponse.Builder() server.enqueue(
.code(HttpURLConnection.HTTP_MOVED_TEMP) MockResponse.Builder()
.addHeader("Location: " + server.url("/new-path")) .code(HttpURLConnection.HTTP_MOVED_TEMP)
.body("This page has moved!") .addHeader("Location: " + server.url("/new-path"))
.build()) .body("This page has moved!")
server.enqueue(MockResponse.Builder() .build(),
.body("This is the new location!") )
.build()) server.enqueue(
MockResponse.Builder()
.body("This is the new location!")
.build(),
)
val connection = server.url("/").toUrl().openConnection() val connection = server.url("/").toUrl().openConnection()
val reader = BufferedReader(InputStreamReader(connection!!.getInputStream(), UTF_8)) val reader = BufferedReader(InputStreamReader(connection!!.getInputStream(), UTF_8))
assertThat(reader.readLine()).isEqualTo("This is the new location!") assertThat(reader.readLine()).isEqualTo("This is the new location!")
@ -203,9 +211,11 @@ class MockWebServerTest {
Thread.sleep(1000) Thread.sleep(1000)
} catch (ignored: InterruptedException) { } catch (ignored: InterruptedException) {
} }
server.enqueue(MockResponse.Builder() server.enqueue(
.body("enqueued in the background") MockResponse.Builder()
.build()) .body("enqueued in the background")
.build(),
)
}.start() }.start()
val connection = server.url("/").toUrl().openConnection() val connection = server.url("/").toUrl().openConnection()
val reader = BufferedReader(InputStreamReader(connection!!.getInputStream(), UTF_8)) val reader = BufferedReader(InputStreamReader(connection!!.getInputStream(), UTF_8))
@ -214,11 +224,13 @@ class MockWebServerTest {
@Test @Test
fun nonHexadecimalChunkSize() { fun nonHexadecimalChunkSize() {
server.enqueue(MockResponse.Builder() server.enqueue(
.body("G\r\nxxxxxxxxxxxxxxxx\r\n0\r\n\r\n") MockResponse.Builder()
.clearHeaders() .body("G\r\nxxxxxxxxxxxxxxxx\r\n0\r\n\r\n")
.addHeader("Transfer-encoding: chunked") .clearHeaders()
.build()) .addHeader("Transfer-encoding: chunked")
.build(),
)
val connection = server.url("/").toUrl().openConnection() val connection = server.url("/").toUrl().openConnection()
try { try {
connection.getInputStream().read() connection.getInputStream().read()
@ -230,14 +242,18 @@ class MockWebServerTest {
@Test @Test
fun responseTimeout() { fun responseTimeout() {
server.enqueue(MockResponse.Builder() server.enqueue(
.body("ABC") MockResponse.Builder()
.clearHeaders() .body("ABC")
.addHeader("Content-Length: 4") .clearHeaders()
.build()) .addHeader("Content-Length: 4")
server.enqueue(MockResponse.Builder() .build(),
.body("DEF") )
.build()) server.enqueue(
MockResponse.Builder()
.body("DEF")
.build(),
)
val urlConnection = server.url("/").toUrl().openConnection() val urlConnection = server.url("/").toUrl().openConnection()
urlConnection!!.readTimeout = 1000 urlConnection!!.readTimeout = 1000
val inputStream = urlConnection.getInputStream() val inputStream = urlConnection.getInputStream()
@ -263,9 +279,11 @@ class MockWebServerTest {
@Disabled("Not actually failing where expected") @Disabled("Not actually failing where expected")
@Test @Test
fun disconnectAtStart() { fun disconnectAtStart() {
server.enqueue(MockResponse.Builder() server.enqueue(
.socketPolicy(DisconnectAtStart) MockResponse.Builder()
.build()) .socketPolicy(DisconnectAtStart)
.build(),
)
server.enqueue(MockResponse()) // The jdk's HttpUrlConnection is a bastard. server.enqueue(MockResponse()) // The jdk's HttpUrlConnection is a bastard.
server.enqueue(MockResponse()) server.enqueue(MockResponse())
try { try {
@ -293,9 +311,11 @@ class MockWebServerTest {
@Test @Test
fun throttleRequest() { fun throttleRequest() {
assumeNotWindows() assumeNotWindows()
server.enqueue(MockResponse.Builder() server.enqueue(
.throttleBody(3, 500, TimeUnit.MILLISECONDS) MockResponse.Builder()
.build()) .throttleBody(3, 500, TimeUnit.MILLISECONDS)
.build(),
)
val startNanos = System.nanoTime() val startNanos = System.nanoTime()
val connection = server.url("/").toUrl().openConnection() val connection = server.url("/").toUrl().openConnection()
connection.doOutput = true connection.doOutput = true
@ -314,10 +334,12 @@ class MockWebServerTest {
@Test @Test
fun throttleResponse() { fun throttleResponse() {
assumeNotWindows() assumeNotWindows()
server.enqueue(MockResponse.Builder() server.enqueue(
.body("ABCDEF") MockResponse.Builder()
.throttleBody(3, 500, TimeUnit.MILLISECONDS) .body("ABCDEF")
.build()) .throttleBody(3, 500, TimeUnit.MILLISECONDS)
.build(),
)
val startNanos = System.nanoTime() val startNanos = System.nanoTime()
val connection = server.url("/").toUrl().openConnection() val connection = server.url("/").toUrl().openConnection()
val inputStream = connection!!.getInputStream() val inputStream = connection!!.getInputStream()
@ -337,10 +359,12 @@ class MockWebServerTest {
@Test @Test
fun delayResponse() { fun delayResponse() {
assumeNotWindows() assumeNotWindows()
server.enqueue(MockResponse.Builder() server.enqueue(
.body("ABCDEF") MockResponse.Builder()
.bodyDelay(1, TimeUnit.SECONDS) .body("ABCDEF")
.build()) .bodyDelay(1, TimeUnit.SECONDS)
.build(),
)
val startNanos = System.nanoTime() val startNanos = System.nanoTime()
val connection = server.url("/").toUrl().openConnection() val connection = server.url("/").toUrl().openConnection()
val inputStream = connection!!.getInputStream() val inputStream = connection!!.getInputStream()
@ -353,9 +377,11 @@ class MockWebServerTest {
@Test @Test
fun disconnectRequestHalfway() { fun disconnectRequestHalfway() {
server.enqueue(MockResponse.Builder() server.enqueue(
.socketPolicy(DisconnectDuringRequestBody) MockResponse.Builder()
.build()) .socketPolicy(DisconnectDuringRequestBody)
.build(),
)
// Limit the size of the request body that the server holds in memory to an arbitrary // Limit the size of the request body that the server holds in memory to an arbitrary
// 3.5 MBytes so this test can pass on devices with little memory. // 3.5 MBytes so this test can pass on devices with little memory.
server.bodyLimit = 7 * 512 * 1024 server.bodyLimit = 7 * 512 * 1024
@ -386,10 +412,12 @@ class MockWebServerTest {
@Test @Test
fun disconnectResponseHalfway() { fun disconnectResponseHalfway() {
server.enqueue(MockResponse.Builder() server.enqueue(
.body("ab") MockResponse.Builder()
.socketPolicy(DisconnectDuringResponseBody) .body("ab")
.build()) .socketPolicy(DisconnectDuringResponseBody)
.build(),
)
val connection = server.url("/").toUrl().openConnection() val connection = server.url("/").toUrl().openConnection()
assertThat(connection!!.contentLength).isEqualTo(2) assertThat(connection!!.contentLength).isEqualTo(2)
val inputStream = connection.getInputStream() val inputStream = connection.getInputStream()
@ -468,9 +496,11 @@ class MockWebServerTest {
@Test @Test
fun requestUrlReconstructed() { fun requestUrlReconstructed() {
server.enqueue(MockResponse.Builder() server.enqueue(
.body("hello world") MockResponse.Builder()
.build()) .body("hello world")
.build(),
)
val url = server.url("/a/deep/path?key=foo%20bar").toUrl() val url = server.url("/a/deep/path?key=foo%20bar").toUrl()
val connection = url.openConnection() as HttpURLConnection val connection = url.openConnection() as HttpURLConnection
val inputStream = connection.inputStream val inputStream = connection.inputStream
@ -479,7 +509,8 @@ class MockWebServerTest {
assertThat(reader.readLine()).isEqualTo("hello world") assertThat(reader.readLine()).isEqualTo("hello world")
val request = server.takeRequest() val request = server.takeRequest()
assertThat(request.requestLine).isEqualTo( assertThat(request.requestLine).isEqualTo(
"GET /a/deep/path?key=foo%20bar HTTP/1.1") "GET /a/deep/path?key=foo%20bar HTTP/1.1",
)
val requestUrl = request.requestUrl val requestUrl = request.requestUrl
assertThat(requestUrl!!.scheme).isEqualTo("http") assertThat(requestUrl!!.scheme).isEqualTo("http")
assertThat(requestUrl.host).isEqualTo(server.hostName) assertThat(requestUrl.host).isEqualTo(server.hostName)
@ -490,9 +521,11 @@ class MockWebServerTest {
@Test @Test
fun shutdownServerAfterRequest() { fun shutdownServerAfterRequest() {
server.enqueue(MockResponse.Builder() server.enqueue(
.socketPolicy(ShutdownServerAfterResponse) MockResponse.Builder()
.build()) .socketPolicy(ShutdownServerAfterResponse)
.build(),
)
val url = server.url("/").toUrl() val url = server.url("/").toUrl()
val connection = url.openConnection() as HttpURLConnection val connection = url.openConnection() as HttpURLConnection
assertThat(connection.responseCode).isEqualTo(HttpURLConnection.HTTP_OK) assertThat(connection.responseCode).isEqualTo(HttpURLConnection.HTTP_OK)
@ -506,9 +539,11 @@ class MockWebServerTest {
@Test @Test
fun http100Continue() { fun http100Continue() {
server.enqueue(MockResponse.Builder() server.enqueue(
.body("response") MockResponse.Builder()
.build()) .body("response")
.build(),
)
val url = server.url("/").toUrl() val url = server.url("/").toUrl()
val connection = url.openConnection() as HttpURLConnection val connection = url.openConnection() as HttpURLConnection
connection.doOutput = true connection.doOutput = true
@ -523,11 +558,13 @@ class MockWebServerTest {
@Test @Test
fun multiple1xxResponses() { fun multiple1xxResponses() {
server.enqueue(MockResponse.Builder() server.enqueue(
.add100Continue() MockResponse.Builder()
.add100Continue() .add100Continue()
.body("response") .add100Continue()
.build()) .body("response")
.build(),
)
val url = server.url("/").toUrl() val url = server.url("/").toUrl()
val connection = url.openConnection() as HttpURLConnection val connection = url.openConnection() as HttpURLConnection
connection.doOutput = true connection.doOutput = true
@ -546,8 +583,9 @@ class MockWebServerTest {
fail<Unit>() fail<Unit>()
} catch (expected: IllegalArgumentException) { } catch (expected: IllegalArgumentException) {
assertThat(expected.message).isEqualTo( assertThat(expected.message).isEqualTo(
"protocols containing h2_prior_knowledge cannot use other protocols: " "protocols containing h2_prior_knowledge cannot use other protocols: " +
+ "[h2_prior_knowledge, http/1.1]") "[h2_prior_knowledge, http/1.1]",
)
} }
} }
@ -559,8 +597,9 @@ class MockWebServerTest {
fail<Unit>() fail<Unit>()
} catch (expected: IllegalArgumentException) { } catch (expected: IllegalArgumentException) {
assertThat(expected.message).isEqualTo( assertThat(expected.message).isEqualTo(
"protocols containing h2_prior_knowledge cannot use other protocols: " "protocols containing h2_prior_knowledge cannot use other protocols: " +
+ "[h2_prior_knowledge, h2_prior_knowledge]") "[h2_prior_knowledge, h2_prior_knowledge]",
)
} }
} }
@ -575,9 +614,11 @@ class MockWebServerTest {
fun https() { fun https() {
val handshakeCertificates = platform.localhostHandshakeCertificates() val handshakeCertificates = platform.localhostHandshakeCertificates()
server.useHttps(handshakeCertificates.sslSocketFactory()) server.useHttps(handshakeCertificates.sslSocketFactory())
server.enqueue(MockResponse.Builder() server.enqueue(
.body("abc") MockResponse.Builder()
.build()) .body("abc")
.build(),
)
val url = server.url("/") val url = server.url("/")
val connection = url.toUrl().openConnection() as HttpsURLConnection val connection = url.toUrl().openConnection() as HttpsURLConnection
connection.sslSocketFactory = handshakeCertificates.sslSocketFactory() connection.sslSocketFactory = handshakeCertificates.sslSocketFactory()
@ -601,32 +642,40 @@ class MockWebServerTest {
platform.assumeNotBouncyCastle() platform.assumeNotBouncyCastle()
platform.assumeNotConscrypt() platform.assumeNotConscrypt()
val clientCa = HeldCertificate.Builder() val clientCa =
.certificateAuthority(0) HeldCertificate.Builder()
.build() .certificateAuthority(0)
val serverCa = HeldCertificate.Builder() .build()
.certificateAuthority(0) val serverCa =
.build() HeldCertificate.Builder()
val serverCertificate = HeldCertificate.Builder() .certificateAuthority(0)
.signedBy(serverCa) .build()
.addSubjectAlternativeName(server.hostName) val serverCertificate =
.build() HeldCertificate.Builder()
val serverHandshakeCertificates = HandshakeCertificates.Builder() .signedBy(serverCa)
.addTrustedCertificate(clientCa.certificate) .addSubjectAlternativeName(server.hostName)
.heldCertificate(serverCertificate) .build()
.build() val serverHandshakeCertificates =
HandshakeCertificates.Builder()
.addTrustedCertificate(clientCa.certificate)
.heldCertificate(serverCertificate)
.build()
server.useHttps(serverHandshakeCertificates.sslSocketFactory()) server.useHttps(serverHandshakeCertificates.sslSocketFactory())
server.enqueue(MockResponse.Builder() server.enqueue(
.body("abc") MockResponse.Builder()
.build()) .body("abc")
.build(),
)
server.requestClientAuth() server.requestClientAuth()
val clientCertificate = HeldCertificate.Builder() val clientCertificate =
.signedBy(clientCa) HeldCertificate.Builder()
.build() .signedBy(clientCa)
val clientHandshakeCertificates = HandshakeCertificates.Builder() .build()
.addTrustedCertificate(serverCa.certificate) val clientHandshakeCertificates =
.heldCertificate(clientCertificate) HandshakeCertificates.Builder()
.build() .addTrustedCertificate(serverCa.certificate)
.heldCertificate(clientCertificate)
.build()
val url = server.url("/") val url = server.url("/")
val connection = url.toUrl().openConnection() as HttpsURLConnection val connection = url.toUrl().openConnection() as HttpsURLConnection
connection.sslSocketFactory = clientHandshakeCertificates.sslSocketFactory() connection.sslSocketFactory = clientHandshakeCertificates.sslSocketFactory()
@ -647,13 +696,16 @@ class MockWebServerTest {
@Test @Test
fun proxiedRequestGetsCorrectRequestUrl() { fun proxiedRequestGetsCorrectRequestUrl() {
server.enqueue(MockResponse.Builder() server.enqueue(
.body("Result") MockResponse.Builder()
.build()) .body("Result")
val proxiedClient = OkHttpClient.Builder() .build(),
.proxy(server.toProxyAddress()) )
.readTimeout(Duration.ofMillis(100)) val proxiedClient =
.build() OkHttpClient.Builder()
.proxy(server.toProxyAddress())
.readTimeout(Duration.ofMillis(100))
.build()
val request = Request.Builder().url("http://android.com/").build() val request = Request.Builder().url("http://android.com/").build()
proxiedClient.newCall(request).execute().use { response -> proxiedClient.newCall(request).execute().use { response ->
assertThat(response.body.string()).isEqualTo("Result") assertThat(response.body.string()).isEqualTo("Result")

View File

@ -14,6 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
@file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE") @file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")
package mockwebserver3 package mockwebserver3
import assertk.assertThat import assertk.assertThat
@ -32,44 +33,50 @@ class RecordedRequestTest {
private val headers: Headers = EMPTY_HEADERS private val headers: Headers = EMPTY_HEADERS
@Test fun testIPv4() { @Test fun testIPv4() {
val socket = FakeSocket( val socket =
localAddress = InetAddress.getByAddress("127.0.0.1", byteArrayOf(127, 0, 0, 1)), FakeSocket(
localPort = 80 localAddress = InetAddress.getByAddress("127.0.0.1", byteArrayOf(127, 0, 0, 1)),
) localPort = 80,
)
val request = RecordedRequest("GET / HTTP/1.1", headers, emptyList(), 0, Buffer(), 0, socket) val request = RecordedRequest("GET / HTTP/1.1", headers, emptyList(), 0, Buffer(), 0, socket)
assertThat(request.requestUrl.toString()).isEqualTo("http://127.0.0.1/") assertThat(request.requestUrl.toString()).isEqualTo("http://127.0.0.1/")
} }
@Test fun testIpv6() { @Test fun testIpv6() {
val socket = FakeSocket( val socket =
localAddress = InetAddress.getByAddress( FakeSocket(
"::1", localAddress =
byteArrayOf(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1) InetAddress.getByAddress(
), "::1",
localPort = 80 byteArrayOf(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1),
) ),
localPort = 80,
)
val request = RecordedRequest("GET / HTTP/1.1", headers, emptyList(), 0, Buffer(), 0, socket) val request = RecordedRequest("GET / HTTP/1.1", headers, emptyList(), 0, Buffer(), 0, socket)
assertThat(request.requestUrl.toString()).isEqualTo("http://[::1]/") assertThat(request.requestUrl.toString()).isEqualTo("http://[::1]/")
} }
@Test fun testUsesLocal() { @Test fun testUsesLocal() {
val socket = FakeSocket( val socket =
localAddress = InetAddress.getByAddress("127.0.0.1", byteArrayOf(127, 0, 0, 1)), FakeSocket(
localPort = 80 localAddress = InetAddress.getByAddress("127.0.0.1", byteArrayOf(127, 0, 0, 1)),
) localPort = 80,
)
val request = RecordedRequest("GET / HTTP/1.1", headers, emptyList(), 0, Buffer(), 0, socket) val request = RecordedRequest("GET / HTTP/1.1", headers, emptyList(), 0, Buffer(), 0, socket)
assertThat(request.requestUrl.toString()).isEqualTo("http://127.0.0.1/") assertThat(request.requestUrl.toString()).isEqualTo("http://127.0.0.1/")
} }
@Test fun testHostname() { @Test fun testHostname() {
val headers = headersOf("Host", "host-from-header.com") val headers = headersOf("Host", "host-from-header.com")
val socket = FakeSocket( val socket =
localAddress = InetAddress.getByAddress( FakeSocket(
"host-from-address.com", localAddress =
byteArrayOf(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1) InetAddress.getByAddress(
), "host-from-address.com",
localPort = 80 byteArrayOf(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1),
) ),
localPort = 80,
)
val request = RecordedRequest("GET / HTTP/1.1", headers, emptyList(), 0, Buffer(), 0, socket) val request = RecordedRequest("GET / HTTP/1.1", headers, emptyList(), 0, Buffer(), 0, socket)
assertThat(request.requestUrl.toString()).isEqualTo("http://host-from-header.com/") assertThat(request.requestUrl.toString()).isEqualTo("http://host-from-header.com/")
} }
@ -78,11 +85,14 @@ class RecordedRequestTest {
private val localAddress: InetAddress, private val localAddress: InetAddress,
private val localPort: Int, private val localPort: Int,
private val remoteAddress: InetAddress = localAddress, private val remoteAddress: InetAddress = localAddress,
private val remotePort: Int = 1234 private val remotePort: Int = 1234,
) : Socket() { ) : Socket() {
override fun getInetAddress() = remoteAddress override fun getInetAddress() = remoteAddress
override fun getLocalAddress() = localAddress override fun getLocalAddress() = localAddress
override fun getLocalPort() = localPort override fun getLocalPort() = localPort
override fun getPort() = remotePort override fun getPort() = remotePort
} }
} }

View File

@ -39,7 +39,7 @@ import okio.source
/** A basic HTTP/2 server that serves the contents of a local directory. */ /** A basic HTTP/2 server that serves the contents of a local directory. */
class Http2Server( class Http2Server(
private val baseDirectory: File, private val baseDirectory: File,
private val sslSocketFactory: SSLSocketFactory private val sslSocketFactory: SSLSocketFactory,
) : Http2Connection.Listener() { ) : Http2Connection.Listener() {
private fun run() { private fun run() {
val serverSocket = ServerSocket(8888) val serverSocket = ServerSocket(8888)
@ -54,10 +54,11 @@ class Http2Server(
if (protocol != Protocol.HTTP_2) { if (protocol != Protocol.HTTP_2) {
throw ProtocolException("Protocol $protocol unsupported") throw ProtocolException("Protocol $protocol unsupported")
} }
val connection = Http2Connection.Builder(false, TaskRunner.INSTANCE) val connection =
.socket(sslSocket) Http2Connection.Builder(false, TaskRunner.INSTANCE)
.listener(this) .socket(sslSocket)
.build() .listener(this)
.build()
connection.start() connection.start()
} catch (e: IOException) { } catch (e: IOException) {
logger.log(Level.INFO, "Http2Server connection failure: $e") logger.log(Level.INFO, "Http2Server connection failure: $e")
@ -70,11 +71,13 @@ class Http2Server(
} }
private fun doSsl(socket: Socket): SSLSocket { private fun doSsl(socket: Socket): SSLSocket {
val sslSocket = sslSocketFactory.createSocket( val sslSocket =
socket, socket.inetAddress.hostAddress, sslSocketFactory.createSocket(
socket.port, socket,
true socket.inetAddress.hostAddress,
) as SSLSocket socket.port,
true,
) as SSLSocket
sslSocket.useClientMode = false sslSocket.useClientMode = false
Platform.get().configureTlsExtensions(sslSocket, null, listOf(Protocol.HTTP_2)) Platform.get().configureTlsExtensions(sslSocket, null, listOf(Protocol.HTTP_2))
sslSocket.startHandshake() sslSocket.startHandshake()
@ -111,32 +114,40 @@ class Http2Server(
} }
} }
private fun send404(stream: Http2Stream, path: String) { private fun send404(
val responseHeaders = listOf( stream: Http2Stream,
Header(":status", "404"), path: String,
Header(":version", "HTTP/1.1"), ) {
Header("content-type", "text/plain") val responseHeaders =
) listOf(
Header(":status", "404"),
Header(":version", "HTTP/1.1"),
Header("content-type", "text/plain"),
)
stream.writeHeaders( stream.writeHeaders(
responseHeaders = responseHeaders, responseHeaders = responseHeaders,
outFinished = false, outFinished = false,
flushHeaders = false flushHeaders = false,
) )
val out = stream.getSink().buffer() val out = stream.getSink().buffer()
out.writeUtf8("Not found: $path") out.writeUtf8("Not found: $path")
out.close() out.close()
} }
private fun serveDirectory(stream: Http2Stream, files: Array<File>) { private fun serveDirectory(
val responseHeaders = listOf( stream: Http2Stream,
Header(":status", "200"), files: Array<File>,
Header(":version", "HTTP/1.1"), ) {
Header("content-type", "text/html; charset=UTF-8") val responseHeaders =
) listOf(
Header(":status", "200"),
Header(":version", "HTTP/1.1"),
Header("content-type", "text/html; charset=UTF-8"),
)
stream.writeHeaders( stream.writeHeaders(
responseHeaders = responseHeaders, responseHeaders = responseHeaders,
outFinished = false, outFinished = false,
flushHeaders = false flushHeaders = false,
) )
val out = stream.getSink().buffer() val out = stream.getSink().buffer()
for (file in files) { for (file in files) {
@ -146,16 +157,20 @@ class Http2Server(
out.close() out.close()
} }
private fun serveFile(stream: Http2Stream, file: File) { private fun serveFile(
val responseHeaders = listOf( stream: Http2Stream,
Header(":status", "200"), file: File,
Header(":version", "HTTP/1.1"), ) {
Header("content-type", contentType(file)) val responseHeaders =
) listOf(
Header(":status", "200"),
Header(":version", "HTTP/1.1"),
Header("content-type", contentType(file)),
)
stream.writeHeaders( stream.writeHeaders(
responseHeaders = responseHeaders, responseHeaders = responseHeaders,
outFinished = false, outFinished = false,
flushHeaders = false flushHeaders = false,
) )
file.source().use { source -> file.source().use { source ->
stream.getSink().buffer().use { sink -> stream.getSink().buffer().use { sink ->
@ -186,10 +201,11 @@ class Http2Server(
println("Usage: Http2Server <base directory>") println("Usage: Http2Server <base directory>")
return return
} }
val server = Http2Server( val server =
File(args[0]), Http2Server(
localhost().sslContext().socketFactory File(args[0]),
) localhost().sslContext().socketFactory,
)
server.run() server.run()
} }
} }

View File

@ -15,19 +15,22 @@
*/ */
package okhttp3 package okhttp3
import java.io.OutputStream
import java.io.PrintStream
import org.junit.platform.engine.TestExecutionResult import org.junit.platform.engine.TestExecutionResult
import org.junit.platform.launcher.TestExecutionListener import org.junit.platform.launcher.TestExecutionListener
import org.junit.platform.launcher.TestIdentifier import org.junit.platform.launcher.TestIdentifier
import org.junit.platform.launcher.TestPlan import org.junit.platform.launcher.TestPlan
import java.io.OutputStream
import java.io.PrintStream
object DotListener: TestExecutionListener { object DotListener : TestExecutionListener {
private var originalSystemErr: PrintStream? = null private var originalSystemErr: PrintStream? = null
private var originalSystemOut: PrintStream? = null private var originalSystemOut: PrintStream? = null
private var testCount = 0 private var testCount = 0
override fun executionSkipped(testIdentifier: TestIdentifier, reason: String) { override fun executionSkipped(
testIdentifier: TestIdentifier,
reason: String,
) {
printStatus("-") printStatus("-")
} }
@ -40,7 +43,7 @@ object DotListener: TestExecutionListener {
override fun executionFinished( override fun executionFinished(
testIdentifier: TestIdentifier, testIdentifier: TestIdentifier,
testExecutionResult: TestExecutionResult testExecutionResult: TestExecutionResult,
) { ) {
if (!testIdentifier.isContainer) { if (!testIdentifier.isContainer) {
when (testExecutionResult.status!!) { when (testExecutionResult.status!!) {
@ -59,8 +62,8 @@ object DotListener: TestExecutionListener {
originalSystemOut = System.out originalSystemOut = System.out
originalSystemErr = System.err originalSystemErr = System.err
System.setOut(object: PrintStream(OutputStream.nullOutputStream()) {}) System.setOut(object : PrintStream(OutputStream.nullOutputStream()) {})
System.setErr(object: PrintStream(OutputStream.nullOutputStream()) {}) System.setErr(object : PrintStream(OutputStream.nullOutputStream()) {})
} }
fun uninstall() { fun uninstall() {

View File

@ -15,26 +15,27 @@
*/ */
package okhttp3 package okhttp3
import java.io.File
import org.junit.jupiter.engine.descriptor.ClassBasedTestDescriptor import org.junit.jupiter.engine.descriptor.ClassBasedTestDescriptor
import org.junit.platform.engine.discovery.DiscoverySelectors import org.junit.platform.engine.discovery.DiscoverySelectors
import java.io.File
// TODO move to junit5 tags // TODO move to junit5 tags
val avoidedTests = setOf( val avoidedTests =
"okhttp3.BouncyCastleTest", setOf(
"okhttp3.ConscryptTest", "okhttp3.BouncyCastleTest",
"okhttp3.CorrettoTest", "okhttp3.ConscryptTest",
"okhttp3.OpenJSSETest", "okhttp3.CorrettoTest",
"okhttp3.internal.platform.Jdk8WithJettyBootPlatformTest", "okhttp3.OpenJSSETest",
"okhttp3.internal.platform.Jdk9PlatformTest", "okhttp3.internal.platform.Jdk8WithJettyBootPlatformTest",
"okhttp3.internal.platform.PlatformTest", "okhttp3.internal.platform.Jdk9PlatformTest",
"okhttp3.internal.platform.android.AndroidSocketAdapterTest", "okhttp3.internal.platform.PlatformTest",
"okhttp3.osgi.OsgiTest", "okhttp3.internal.platform.android.AndroidSocketAdapterTest",
// Hanging. "okhttp3.osgi.OsgiTest",
"okhttp3.CookiesTest", // Hanging.
// Hanging. "okhttp3.CookiesTest",
"okhttp3.WholeOperationTimeoutTest", // Hanging.
) "okhttp3.WholeOperationTimeoutTest",
)
/** /**
* Run periodically to refresh the known set of working tests. * Run periodically to refresh the known set of working tests.
@ -44,11 +45,12 @@ val avoidedTests = setOf(
fun main() { fun main() {
val knownTestFile = File("native-image-tests/src/main/resources/testlist.txt") val knownTestFile = File("native-image-tests/src/main/resources/testlist.txt")
val testSelector = DiscoverySelectors.selectPackage("okhttp3") val testSelector = DiscoverySelectors.selectPackage("okhttp3")
val testClasses = findTests(listOf(testSelector)) val testClasses =
.filter { it.isContainer } findTests(listOf(testSelector))
.mapNotNull { (it as? ClassBasedTestDescriptor)?.testClass?.name } .filter { it.isContainer }
.filterNot { it in avoidedTests } .mapNotNull { (it as? ClassBasedTestDescriptor)?.testClass?.name }
.sorted() .filterNot { it in avoidedTests }
.distinct() .sorted()
.distinct()
knownTestFile.writeText(testClasses.joinToString("\n")) knownTestFile.writeText(testClasses.joinToString("\n"))
} }

View File

@ -15,6 +15,9 @@
*/ */
package okhttp3 package okhttp3
import java.io.File
import java.io.PrintWriter
import kotlin.system.exitProcess
import org.junit.jupiter.engine.JupiterTestEngine import org.junit.jupiter.engine.JupiterTestEngine
import org.junit.platform.console.options.Theme import org.junit.platform.console.options.Theme
import org.junit.platform.engine.DiscoverySelector import org.junit.platform.engine.DiscoverySelector
@ -30,9 +33,6 @@ import org.junit.platform.launcher.core.LauncherConfig
import org.junit.platform.launcher.core.LauncherDiscoveryRequestBuilder import org.junit.platform.launcher.core.LauncherDiscoveryRequestBuilder
import org.junit.platform.launcher.core.LauncherFactory import org.junit.platform.launcher.core.LauncherFactory
import org.junit.platform.launcher.listeners.SummaryGeneratingListener import org.junit.platform.launcher.listeners.SummaryGeneratingListener
import java.io.File
import java.io.PrintWriter
import kotlin.system.exitProcess
/** /**
* Graal main method to run tests with minimal reflection and automatic settings. * Graal main method to run tests with minimal reflection and automatic settings.
@ -49,13 +49,14 @@ fun main(vararg args: String) {
val jupiterTestEngine = buildTestEngine() val jupiterTestEngine = buildTestEngine()
val config = LauncherConfig.builder() val config =
.enableTestExecutionListenerAutoRegistration(false) LauncherConfig.builder()
.enableTestEngineAutoRegistration(false) .enableTestExecutionListenerAutoRegistration(false)
.enablePostDiscoveryFilterAutoRegistration(false) .enableTestEngineAutoRegistration(false)
.addTestEngines(jupiterTestEngine) .enablePostDiscoveryFilterAutoRegistration(false)
.addTestExecutionListeners(DotListener, summaryListener, treeListener) .addTestEngines(jupiterTestEngine)
.build() .addTestExecutionListeners(DotListener, summaryListener, treeListener)
.build()
val launcher: Launcher = LauncherFactory.create(config) val launcher: Launcher = LauncherFactory.create(config)
val request: LauncherDiscoveryRequest = buildRequest(selectors) val request: LauncherDiscoveryRequest = buildRequest(selectors)
@ -89,8 +90,9 @@ fun testSelectors(inputFile: File? = null): List<DiscoverySelector> {
val lines = val lines =
inputFile?.readLines() ?: sampleTestClass.getResource("/testlist.txt").readText().lines() inputFile?.readLines() ?: sampleTestClass.getResource("/testlist.txt").readText().lines()
val flatClassnameList = lines val flatClassnameList =
.filter { it.isNotBlank() } lines
.filter { it.isNotBlank() }
return flatClassnameList return flatClassnameList
.mapNotNull { .mapNotNull {
@ -107,11 +109,12 @@ fun testSelectors(inputFile: File? = null): List<DiscoverySelector> {
* Builds a Junit Test Plan request for a fixed set of classes, or potentially a recursive package. * Builds a Junit Test Plan request for a fixed set of classes, or potentially a recursive package.
*/ */
fun buildRequest(selectors: List<DiscoverySelector>): LauncherDiscoveryRequest { fun buildRequest(selectors: List<DiscoverySelector>): LauncherDiscoveryRequest {
val request: LauncherDiscoveryRequest = LauncherDiscoveryRequestBuilder.request() val request: LauncherDiscoveryRequest =
// TODO replace junit.jupiter.extensions.autodetection.enabled with API approach. LauncherDiscoveryRequestBuilder.request()
// TODO replace junit.jupiter.extensions.autodetection.enabled with API approach.
// .enableImplicitConfigurationParameters(false) // .enableImplicitConfigurationParameters(false)
.selectors(selectors) .selectors(selectors)
.build() .build()
return request return request
} }
@ -136,11 +139,13 @@ fun findTests(selectors: List<DiscoverySelector>): List<TestDescriptor> {
* https://github.com/junit-team/junit5/issues/2469 * https://github.com/junit-team/junit5/issues/2469
*/ */
fun treeListener(): TestExecutionListener { fun treeListener(): TestExecutionListener {
val colorPalette = Class.forName("org.junit.platform.console.tasks.ColorPalette").getField("DEFAULT").apply { val colorPalette =
isAccessible = true Class.forName("org.junit.platform.console.tasks.ColorPalette").getField("DEFAULT").apply {
}.get(null) isAccessible = true
}.get(null)
return Class.forName( return Class.forName(
"org.junit.platform.console.tasks.TreePrintingListener").declaredConstructors.first() "org.junit.platform.console.tasks.TreePrintingListener",
).declaredConstructors.first()
.apply { .apply {
isAccessible = true isAccessible = true
} }

View File

@ -26,7 +26,8 @@ import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.ArgumentsSource import org.junit.jupiter.params.provider.ArgumentsSource
class SampleTest { class SampleTest {
@JvmField @RegisterExtension val clientRule = OkHttpClientTestRule() @JvmField @RegisterExtension
val clientRule = OkHttpClientTestRule()
@Test @Test
fun passingTest() { fun passingTest() {
@ -59,6 +60,6 @@ class SampleTest {
} }
} }
class SampleTestProvider: SimpleProvider() { class SampleTestProvider : SimpleProvider() {
override fun arguments() = listOf("A", "B") override fun arguments() = listOf("A", "B")
} }

View File

@ -16,11 +16,11 @@
package okhttp3 package okhttp3
import com.oracle.svm.core.annotate.AutomaticFeature import com.oracle.svm.core.annotate.AutomaticFeature
import java.io.File
import java.lang.IllegalStateException
import org.graalvm.nativeimage.hosted.Feature import org.graalvm.nativeimage.hosted.Feature
import org.graalvm.nativeimage.hosted.RuntimeClassInitialization import org.graalvm.nativeimage.hosted.RuntimeClassInitialization
import org.graalvm.nativeimage.hosted.RuntimeReflection import org.graalvm.nativeimage.hosted.RuntimeReflection
import java.io.File
import java.lang.IllegalStateException
@AutomaticFeature @AutomaticFeature
class TestRegistration : Feature { class TestRegistration : Feature {
@ -40,7 +40,10 @@ class TestRegistration : Feature {
registerParamProvider(access, "okhttp3.WebPlatformUrlTest\$TestDataParamProvider") registerParamProvider(access, "okhttp3.WebPlatformUrlTest\$TestDataParamProvider")
} }
private fun registerParamProvider(access: Feature.BeforeAnalysisAccess, provider: String) { private fun registerParamProvider(
access: Feature.BeforeAnalysisAccess,
provider: String,
) {
val providerClass = access.findClassByName(provider) val providerClass = access.findClassByName(provider)
if (providerClass != null) { if (providerClass != null) {
registerTest(access, providerClass) registerTest(access, providerClass)
@ -54,7 +57,10 @@ class TestRegistration : Feature {
registerStandardClass(access, "org.junit.platform.console.tasks.TreePrintingListener") registerStandardClass(access, "org.junit.platform.console.tasks.TreePrintingListener")
} }
private fun registerStandardClass(access: Feature.BeforeAnalysisAccess, name: String) { private fun registerStandardClass(
access: Feature.BeforeAnalysisAccess,
name: String,
) {
val clazz: Class<*> = access.findClassByName(name) ?: throw IllegalStateException("Missing class $name") val clazz: Class<*> = access.findClassByName(name) ?: throw IllegalStateException("Missing class $name")
RuntimeReflection.register(clazz) RuntimeReflection.register(clazz)
clazz.declaredConstructors.forEach { clazz.declaredConstructors.forEach {
@ -81,7 +87,10 @@ class TestRegistration : Feature {
} }
} }
private fun registerTest(access: Feature.BeforeAnalysisAccess, java: Class<*>) { private fun registerTest(
access: Feature.BeforeAnalysisAccess,
java: Class<*>,
) {
access.registerAsUsed(java) access.registerAsUsed(java)
RuntimeReflection.register(java) RuntimeReflection.register(java)
java.constructors.forEach { java.constructors.forEach {

View File

@ -40,35 +40,41 @@ import okhttp3.logging.HttpLoggingInterceptor
import okhttp3.logging.LoggingEventListener import okhttp3.logging.LoggingEventListener
class Main : CliktCommand(name = NAME, help = "A curl for the next-generation web.") { class Main : CliktCommand(name = NAME, help = "A curl for the next-generation web.") {
val method: String? by option("-X", "--request", help="Specify request command to use") val method: String? by option("-X", "--request", help = "Specify request command to use")
val data: String? by option("-d", "--data", help="HTTP POST data") val data: String? by option("-d", "--data", help = "HTTP POST data")
val headers: List<String>? by option("-H", "--header", help="Custom header to pass to server").multiple() val headers: List<String>? by option("-H", "--header", help = "Custom header to pass to server").multiple()
val userAgent: String by option("-A", "--user-agent", help="User-Agent to send to server").default(NAME + "/" + versionString()) val userAgent: String by option("-A", "--user-agent", help = "User-Agent to send to server").default(NAME + "/" + versionString())
val connectTimeout: Int by option("--connect-timeout", help="Maximum time allowed for connection (seconds)").int().default(DEFAULT_TIMEOUT) val connectTimeout: Int by option(
"--connect-timeout",
help = "Maximum time allowed for connection (seconds)",
).int().default(DEFAULT_TIMEOUT)
val readTimeout: Int by option("--read-timeout", help="Maximum time allowed for reading data (seconds)").int().default(DEFAULT_TIMEOUT) val readTimeout: Int by option("--read-timeout", help = "Maximum time allowed for reading data (seconds)").int().default(DEFAULT_TIMEOUT)
val callTimeout: Int by option("--call-timeout", help="Maximum time allowed for the entire call (seconds)").int().default(DEFAULT_TIMEOUT) val callTimeout: Int by option(
"--call-timeout",
help = "Maximum time allowed for the entire call (seconds)",
).int().default(DEFAULT_TIMEOUT)
val followRedirects: Boolean by option("-L", "--location", help="Follow redirects").flag() val followRedirects: Boolean by option("-L", "--location", help = "Follow redirects").flag()
val allowInsecure: Boolean by option("-k", "--insecure", help="Allow connections to SSL sites without certs").flag() val allowInsecure: Boolean by option("-k", "--insecure", help = "Allow connections to SSL sites without certs").flag()
val showHeaders: Boolean by option("-i", "--include", help="Include protocol headers in the output").flag() val showHeaders: Boolean by option("-i", "--include", help = "Include protocol headers in the output").flag()
val showHttp2Frames: Boolean by option("--frames", help="Log HTTP/2 frames to STDERR").flag() val showHttp2Frames: Boolean by option("--frames", help = "Log HTTP/2 frames to STDERR").flag()
val referer: String? by option("-e", "--referer", help="Referer URL") val referer: String? by option("-e", "--referer", help = "Referer URL")
val verbose: Boolean by option("-v", "--verbose", help="Makes $NAME verbose during the operation").flag() val verbose: Boolean by option("-v", "--verbose", help = "Makes $NAME verbose during the operation").flag()
val sslDebug: Boolean by option(help="Output SSL Debug").flag() val sslDebug: Boolean by option(help = "Output SSL Debug").flag()
val url: String? by argument(name = "url", help="Remote resource URL") val url: String? by argument(name = "url", help = "Remote resource URL")
var client: Call.Factory? = null var client: Call.Factory? = null
@ -123,20 +129,26 @@ class Main : CliktCommand(name = NAME, help = "A curl for the next-generation we
return prop.getProperty("version", "dev") return prop.getProperty("version", "dev")
} }
private fun createInsecureTrustManager(): X509TrustManager = object : X509TrustManager { private fun createInsecureTrustManager(): X509TrustManager =
override fun checkClientTrusted(chain: Array<X509Certificate>, authType: String) {} object : X509TrustManager {
override fun checkClientTrusted(
chain: Array<X509Certificate>,
authType: String,
) {}
override fun checkServerTrusted(chain: Array<X509Certificate>, authType: String) {} override fun checkServerTrusted(
chain: Array<X509Certificate>,
authType: String,
) {}
override fun getAcceptedIssuers(): Array<X509Certificate> = arrayOf() override fun getAcceptedIssuers(): Array<X509Certificate> = arrayOf()
} }
private fun createInsecureSslSocketFactory(trustManager: TrustManager): SSLSocketFactory = private fun createInsecureSslSocketFactory(trustManager: TrustManager): SSLSocketFactory =
Platform.get().newSSLContext().apply { Platform.get().newSSLContext().apply {
init(null, arrayOf(trustManager), null) init(null, arrayOf(trustManager), null)
}.socketFactory }.socketFactory
private fun createInsecureHostnameVerifier(): HostnameVerifier = private fun createInsecureHostnameVerifier(): HostnameVerifier = HostnameVerifier { _, _ -> true }
HostnameVerifier { _, _ -> true }
} }
} }

View File

@ -14,6 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
@file:Suppress("ktlint:standard:filename") @file:Suppress("ktlint:standard:filename")
package okhttp3.curl.internal package okhttp3.curl.internal
import java.io.IOException import java.io.IOException
@ -53,15 +54,16 @@ internal fun Main.commonCreateRequest(): Request {
} }
private fun Main.mediaType(): MediaType? { private fun Main.mediaType(): MediaType? {
val mimeType = headers?.let { val mimeType =
for (header in it) { headers?.let {
val parts = header.split(':', limit = 2) for (header in it) {
if ("Content-Type".equals(parts[0], ignoreCase = true)) { val parts = header.split(':', limit = 2)
return@let parts[1].trim() if ("Content-Type".equals(parts[0], ignoreCase = true)) {
return@let parts[1].trim()
}
} }
} return@let null
return@let null } ?: "application/x-www-form-urlencoded"
} ?: "application/x-www-form-urlencoded"
return mimeType.toMediaTypeOrNull() return mimeType.toMediaTypeOrNull()
} }

View File

@ -1,32 +1,37 @@
package okhttp3.curl.logging package okhttp3.curl.logging
import okhttp3.internal.http2.Http2
import java.util.logging.ConsoleHandler import java.util.logging.ConsoleHandler
import java.util.logging.Level import java.util.logging.Level
import java.util.logging.LogManager import java.util.logging.LogManager
import java.util.logging.LogRecord import java.util.logging.LogRecord
import java.util.logging.Logger import java.util.logging.Logger
import okhttp3.internal.http2.Http2
class LoggingUtil { class LoggingUtil {
companion object { companion object {
private val activeLoggers = mutableListOf<Logger>() private val activeLoggers = mutableListOf<Logger>()
fun configureLogging(debug: Boolean, showHttp2Frames: Boolean, sslDebug: Boolean) { fun configureLogging(
debug: Boolean,
showHttp2Frames: Boolean,
sslDebug: Boolean,
) {
if (debug || showHttp2Frames || sslDebug) { if (debug || showHttp2Frames || sslDebug) {
if (sslDebug) { if (sslDebug) {
System.setProperty("javax.net.debug", "") System.setProperty("javax.net.debug", "")
} }
LogManager.getLogManager().reset() LogManager.getLogManager().reset()
val handler = object : ConsoleHandler() { val handler =
override fun publish(record: LogRecord) { object : ConsoleHandler() {
super.publish(record) override fun publish(record: LogRecord) {
super.publish(record)
val parameters = record.parameters val parameters = record.parameters
if (sslDebug && record.loggerName == "javax.net.ssl" && parameters != null) { if (sslDebug && record.loggerName == "javax.net.ssl" && parameters != null) {
System.err.println(parameters[0]) System.err.println(parameters[0])
}
} }
} }
}
if (debug) { if (debug) {
handler.level = Level.ALL handler.level = Level.ALL

View File

@ -19,16 +19,17 @@ import java.util.logging.LogRecord
* Why so much construction? * Why so much construction?
*/ */
class OneLineLogFormat : Formatter() { class OneLineLogFormat : Formatter() {
private val d = DateTimeFormatterBuilder() private val d =
.appendValue(HOUR_OF_DAY, 2) DateTimeFormatterBuilder()
.appendLiteral(':') .appendValue(HOUR_OF_DAY, 2)
.appendValue(MINUTE_OF_HOUR, 2) .appendLiteral(':')
.optionalStart() .appendValue(MINUTE_OF_HOUR, 2)
.appendLiteral(':') .optionalStart()
.appendValue(SECOND_OF_MINUTE, 2) .appendLiteral(':')
.optionalStart() .appendValue(SECOND_OF_MINUTE, 2)
.appendFraction(NANO_OF_SECOND, 3, 3, true) .optionalStart()
.toFormatter() .appendFraction(NANO_OF_SECOND, 3, 3, true)
.toFormatter()
private val offset = ZoneOffset.systemDefault() private val offset = ZoneOffset.systemDefault()

View File

@ -15,14 +15,14 @@
*/ */
package okhttp3.curl package okhttp3.curl
import java.io.IOException
import okhttp3.RequestBody
import okio.Buffer
import assertk.assertThat import assertk.assertThat
import assertk.assertions.isEqualTo import assertk.assertions.isEqualTo
import assertk.assertions.isNull import assertk.assertions.isNull
import assertk.assertions.startsWith import assertk.assertions.startsWith
import java.io.IOException
import kotlin.test.Test import kotlin.test.Test
import okhttp3.RequestBody
import okio.Buffer
class MainTest { class MainTest {
@Test @Test
@ -34,7 +34,8 @@ class MainTest {
} }
@Test @Test
@Throws(IOException::class) fun put() { @Throws(IOException::class)
fun put() {
val request = fromArgs("-X", "PUT", "-d", "foo", "http://example.com").createRequest() val request = fromArgs("-X", "PUT", "-d", "foo", "http://example.com").createRequest()
assertThat(request.method).isEqualTo("PUT") assertThat(request.method).isEqualTo("PUT")
assertThat(request.url.toString()).isEqualTo("http://example.com/") assertThat(request.url.toString()).isEqualTo("http://example.com/")
@ -48,7 +49,7 @@ class MainTest {
assertThat(request.method).isEqualTo("POST") assertThat(request.method).isEqualTo("POST")
assertThat(request.url.toString()).isEqualTo("http://example.com/") assertThat(request.url.toString()).isEqualTo("http://example.com/")
assertThat(body!!.contentType().toString()).isEqualTo( assertThat(body!!.contentType().toString()).isEqualTo(
"application/x-www-form-urlencoded; charset=utf-8" "application/x-www-form-urlencoded; charset=utf-8",
) )
assertThat(bodyAsString(body)).isEqualTo("foo") assertThat(bodyAsString(body)).isEqualTo("foo")
} }
@ -60,17 +61,21 @@ class MainTest {
assertThat(request.method).isEqualTo("PUT") assertThat(request.method).isEqualTo("PUT")
assertThat(request.url.toString()).isEqualTo("http://example.com/") assertThat(request.url.toString()).isEqualTo("http://example.com/")
assertThat(body!!.contentType().toString()).isEqualTo( assertThat(body!!.contentType().toString()).isEqualTo(
"application/x-www-form-urlencoded; charset=utf-8" "application/x-www-form-urlencoded; charset=utf-8",
) )
assertThat(bodyAsString(body)).isEqualTo("foo") assertThat(bodyAsString(body)).isEqualTo("foo")
} }
@Test @Test
fun contentTypeHeader() { fun contentTypeHeader() {
val request = fromArgs( val request =
"-d", "foo", "-H", "Content-Type: application/json", fromArgs(
"http://example.com" "-d",
).createRequest() "foo",
"-H",
"Content-Type: application/json",
"http://example.com",
).createRequest()
val body = request.body val body = request.body
assertThat(request.method).isEqualTo("POST") assertThat(request.method).isEqualTo("POST")
assertThat(request.url.toString()).isEqualTo("http://example.com/") assertThat(request.url.toString()).isEqualTo("http://example.com/")
@ -105,12 +110,14 @@ class MainTest {
@Test @Test
fun headerSplitWithDate() { fun headerSplitWithDate() {
val request = fromArgs( val request =
"-H", "If-Modified-Since: Mon, 18 Aug 2014 15:16:06 GMT", fromArgs(
"http://example.com" "-H",
).createRequest() "If-Modified-Since: Mon, 18 Aug 2014 15:16:06 GMT",
"http://example.com",
).createRequest()
assertThat(request.header("If-Modified-Since")).isEqualTo( assertThat(request.header("If-Modified-Since")).isEqualTo(
"Mon, 18 Aug 2014 15:16:06 GMT" "Mon, 18 Aug 2014 15:16:06 GMT",
) )
} }

View File

@ -50,14 +50,16 @@ import org.junit.Test
* Run with "./gradlew :android-test:connectedCheck -PandroidBuild=true" and make sure ANDROID_SDK_ROOT is set. * Run with "./gradlew :android-test:connectedCheck -PandroidBuild=true" and make sure ANDROID_SDK_ROOT is set.
*/ */
class AndroidAsyncDnsTest { class AndroidAsyncDnsTest {
@JvmField @Rule val serverRule = MockWebServerRule() @JvmField @Rule
val serverRule = MockWebServerRule()
private lateinit var client: OkHttpClient private lateinit var client: OkHttpClient
private val localhost: HandshakeCertificates by lazy { private val localhost: HandshakeCertificates by lazy {
// Generate a self-signed cert for the server to serve and the client to trust. // Generate a self-signed cert for the server to serve and the client to trust.
val heldCertificate = HeldCertificate.Builder() val heldCertificate =
.addSubjectAlternativeName("localhost") HeldCertificate.Builder()
.build() .addSubjectAlternativeName("localhost")
.build()
return@lazy HandshakeCertificates.Builder() return@lazy HandshakeCertificates.Builder()
.addPlatformTrustedCertificates() .addPlatformTrustedCertificates()
.heldCertificate(heldCertificate) .heldCertificate(heldCertificate)
@ -69,10 +71,11 @@ class AndroidAsyncDnsTest {
fun init() { fun init() {
assumeTrue("Supported on API 29+", Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) assumeTrue("Supported on API 29+", Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q)
client = OkHttpClient.Builder() client =
.dns(AsyncDns.toDns(AndroidAsyncDns.IPv4, AndroidAsyncDns.IPv6)) OkHttpClient.Builder()
.sslSocketFactory(localhost.sslSocketFactory(), localhost.trustManager) .dns(AsyncDns.toDns(AndroidAsyncDns.IPv4, AndroidAsyncDns.IPv6))
.build() .sslSocketFactory(localhost.sslSocketFactory(), localhost.trustManager)
.build()
serverRule.server.useHttps(localhost.sslSocketFactory()) serverRule.server.useHttps(localhost.sslSocketFactory())
} }
@ -127,17 +130,26 @@ class AndroidAsyncDnsTest {
val latch = CountDownLatch(1) val latch = CountDownLatch(1)
// assumes an IPv4 address // assumes an IPv4 address
AndroidAsyncDns.IPv4.query(hostname, object : AsyncDns.Callback { AndroidAsyncDns.IPv4.query(
override fun onResponse(hostname: String, addresses: List<InetAddress>) { hostname,
allAddresses.addAll(addresses) object : AsyncDns.Callback {
latch.countDown() override fun onResponse(
} hostname: String,
addresses: List<InetAddress>,
) {
allAddresses.addAll(addresses)
latch.countDown()
}
override fun onFailure(hostname: String, e: IOException) { override fun onFailure(
exception = e hostname: String,
latch.countDown() e: IOException,
} ) {
}) exception = e
latch.countDown()
}
},
)
latch.await() latch.await()
@ -173,10 +185,11 @@ class AndroidAsyncDnsTest {
val network = val network =
connectivityManager.activeNetwork ?: throw AssumptionViolatedException("No active network") connectivityManager.activeNetwork ?: throw AssumptionViolatedException("No active network")
val client = OkHttpClient.Builder() val client =
.dns(AsyncDns.toDns(AndroidAsyncDns.IPv4, AndroidAsyncDns.IPv6)) OkHttpClient.Builder()
.socketFactory(network.socketFactory) .dns(AsyncDns.toDns(AndroidAsyncDns.IPv4, AndroidAsyncDns.IPv6))
.build() .socketFactory(network.socketFactory)
.build()
val call = val call =
client.newCall(Request("https://google.com/robots.txt".toHttpUrl())) client.newCall(Request("https://google.com/robots.txt".toHttpUrl()))

View File

@ -43,20 +43,34 @@ class AndroidAsyncDns(
private val resolver = DnsResolver.getInstance() private val resolver = DnsResolver.getInstance()
private val executor = Executors.newSingleThreadExecutor() private val executor = Executors.newSingleThreadExecutor()
override fun query(hostname: String, callback: AsyncDns.Callback) { override fun query(
hostname: String,
callback: AsyncDns.Callback,
) {
resolver.query( resolver.query(
network, hostname, dnsClass.type, DnsResolver.FLAG_EMPTY, executor, null, network,
hostname,
dnsClass.type,
DnsResolver.FLAG_EMPTY,
executor,
null,
object : DnsResolver.Callback<List<InetAddress>> { object : DnsResolver.Callback<List<InetAddress>> {
override fun onAnswer(addresses: List<InetAddress>, rCode: Int) { override fun onAnswer(
addresses: List<InetAddress>,
rCode: Int,
) {
callback.onResponse(hostname, addresses) callback.onResponse(hostname, addresses)
} }
override fun onError(e: DnsResolver.DnsException) { override fun onError(e: DnsResolver.DnsException) {
callback.onFailure(hostname, UnknownHostException(e.message).apply { callback.onFailure(
initCause(e) hostname,
}) UnknownHostException(e.message).apply {
initCause(e)
},
)
} }
} },
) )
} }

View File

@ -43,16 +43,16 @@ import org.robolectric.annotation.Config
sdk = [30], sdk = [30],
) )
class RobolectricOkHttpClientTest { class RobolectricOkHttpClientTest {
private lateinit var context: Context private lateinit var context: Context
private lateinit var client: OkHttpClient private lateinit var client: OkHttpClient
@Before @Before
fun setUp() { fun setUp() {
context = ApplicationProvider.getApplicationContext<Application>() context = ApplicationProvider.getApplicationContext<Application>()
client = OkHttpClient.Builder() client =
.cache(Cache("/cache".toPath(), 10_000_000, FakeFileSystem())) OkHttpClient.Builder()
.build() .cache(Cache("/cache".toPath(), 10_000_000, FakeFileSystem()))
.build()
} }
@Test @Test
@ -61,8 +61,9 @@ class RobolectricOkHttpClientTest {
val request = Request("https://www.google.com/robots.txt".toHttpUrl()) val request = Request("https://www.google.com/robots.txt".toHttpUrl())
val networkRequest = request.newBuilder() val networkRequest =
.build() request.newBuilder()
.build()
val call = client.newCall(networkRequest) val call = client.newCall(networkRequest)

View File

@ -28,9 +28,10 @@ import okhttp3.brotli.internal.uncompress
object BrotliInterceptor : Interceptor { object BrotliInterceptor : Interceptor {
override fun intercept(chain: Interceptor.Chain): Response { override fun intercept(chain: Interceptor.Chain): Response {
return if (chain.request().header("Accept-Encoding") == null) { return if (chain.request().header("Accept-Encoding") == null) {
val request = chain.request().newBuilder() val request =
.header("Accept-Encoding", "br,gzip") chain.request().newBuilder()
.build() .header("Accept-Encoding", "br,gzip")
.build()
val response = chain.proceed(request) val response = chain.proceed(request)

View File

@ -30,13 +30,14 @@ fun uncompress(response: Response): Response {
val body = response.body val body = response.body
val encoding = response.header("Content-Encoding") ?: return response val encoding = response.header("Content-Encoding") ?: return response
val decompressedSource = when { val decompressedSource =
encoding.equals("br", ignoreCase = true) -> when {
BrotliInputStream(body.source().inputStream()).source().buffer() encoding.equals("br", ignoreCase = true) ->
encoding.equals("gzip", ignoreCase = true) -> BrotliInputStream(body.source().inputStream()).source().buffer()
GzipSource(body.source()).buffer() encoding.equals("gzip", ignoreCase = true) ->
else -> return response GzipSource(body.source()).buffer()
} else -> return response
}
return response.newBuilder() return response.newBuilder()
.removeHeader("Content-Encoding") .removeHeader("Content-Encoding")

View File

@ -38,14 +38,15 @@ class BrotliInterceptorTest {
@Test @Test
fun testUncompressBrotli() { fun testUncompressBrotli() {
val s = val s =
"1bce00009c05ceb9f028d14e416230f718960a537b0922d2f7b6adef56532c08dff44551516690131494db" + "1bce00009c05ceb9f028d14e416230f718960a537b0922d2f7b6adef56532c08dff44551516690131494db" +
"6021c7e3616c82c1bc2416abb919aaa06e8d30d82cc2981c2f5c900bfb8ee29d5c03deb1c0dacff80e" + "6021c7e3616c82c1bc2416abb919aaa06e8d30d82cc2981c2f5c900bfb8ee29d5c03deb1c0dacff80e" +
"abe82ba64ed250a497162006824684db917963ecebe041b352a3e62d629cc97b95cac24265b175171e" + "abe82ba64ed250a497162006824684db917963ecebe041b352a3e62d629cc97b95cac24265b175171e" +
"5cb384cd0912aeb5b5dd9555f2dd1a9b20688201" "5cb384cd0912aeb5b5dd9555f2dd1a9b20688201"
val response = response("https://httpbin.org/brotli", s.decodeHex()) { val response =
header("Content-Encoding", "br") response("https://httpbin.org/brotli", s.decodeHex()) {
} header("Content-Encoding", "br")
}
val uncompressed = uncompress(response) val uncompressed = uncompress(response)
@ -57,15 +58,16 @@ class BrotliInterceptorTest {
@Test @Test
fun testUncompressGzip() { fun testUncompressGzip() {
val s = val s =
"1f8b0800968f215d02ff558ec10e82301044ef7c45b3e75269d0c478e340e4a426e007086c4a636c9bb65e" + "1f8b0800968f215d02ff558ec10e82301044ef7c45b3e75269d0c478e340e4a426e007086c4a636c9bb65e" +
"24fcbb5b484c3cec61deccecee9c3106eaa39dc3114e2cfa377296d8848f117d20369324500d03ba98" + "24fcbb5b484c3cec61deccecee9c3106eaa39dc3114e2cfa377296d8848f117d20369324500d03ba98" +
"d766b0a3368a0ce83d4f55581b14696c88894f31ba5e1b61bdfa79f7803eaf149a35619f29b3db0b29" + "d766b0a3368a0ce83d4f55581b14696c88894f31ba5e1b61bdfa79f7803eaf149a35619f29b3db0b29" +
"8abcbd54b7b6b97640c965bbfec238d9f4109ceb6edb01d66ba54d6247296441531e445970f627215b" + "8abcbd54b7b6b97640c965bbfec238d9f4109ceb6edb01d66ba54d6247296441531e445970f627215b" +
"b22f1017320dd5000000" "b22f1017320dd5000000"
val response = response("https://httpbin.org/gzip", s.decodeHex()) { val response =
header("Content-Encoding", "gzip") response("https://httpbin.org/gzip", s.decodeHex()) {
} header("Content-Encoding", "gzip")
}
val uncompressed = uncompress(response) val uncompressed = uncompress(response)
@ -86,9 +88,10 @@ class BrotliInterceptorTest {
@Test @Test
fun testFailsUncompress() { fun testFailsUncompress() {
val response = response("https://httpbin.org/brotli", "bb919aaa06e8".decodeHex()) { val response =
header("Content-Encoding", "br") response("https://httpbin.org/brotli", "bb919aaa06e8".decodeHex()) {
} header("Content-Encoding", "br")
}
assertFailsWith<IOException> { assertFailsWith<IOException> {
val failingResponse = uncompress(response) val failingResponse = uncompress(response)
@ -101,11 +104,12 @@ class BrotliInterceptorTest {
@Test @Test
fun testSkipUncompressNoContentResponse() { fun testSkipUncompressNoContentResponse() {
val response = response("https://httpbin.org/brotli", EMPTY) { val response =
header("Content-Encoding", "br") response("https://httpbin.org/brotli", EMPTY) {
code(204) header("Content-Encoding", "br")
message("NO CONTENT") code(204)
} message("NO CONTENT")
}
val same = uncompress(response) val same = uncompress(response)
@ -116,15 +120,15 @@ class BrotliInterceptorTest {
private fun response( private fun response(
url: String, url: String,
bodyHex: ByteString, bodyHex: ByteString,
fn: Response.Builder.() -> Unit = {} fn: Response.Builder.() -> Unit = {},
): Response { ): Response {
return Response.Builder() return Response.Builder()
.body(bodyHex.toResponseBody("text/plain".toMediaType())) .body(bodyHex.toResponseBody("text/plain".toMediaType()))
.code(200) .code(200)
.message("OK") .message("OK")
.request(Request.Builder().url(url).build()) .request(Request.Builder().url(url).build())
.protocol(Protocol.HTTP_2) .protocol(Protocol.HTTP_2)
.apply(fn) .apply(fn)
.build() .build()
} }
} }

View File

@ -19,7 +19,8 @@ import okhttp3.OkHttpClient
import okhttp3.Request import okhttp3.Request
fun main() { fun main() {
val client = OkHttpClient.Builder() val client =
OkHttpClient.Builder()
.addInterceptor(BrotliInterceptor) .addInterceptor(BrotliInterceptor)
.build() .build()
@ -27,7 +28,10 @@ fun main() {
sendRequest("https://httpbin.org/gzip", client) sendRequest("https://httpbin.org/gzip", client)
} }
private fun sendRequest(url: String, client: OkHttpClient) { private fun sendRequest(
url: String,
client: OkHttpClient,
) {
val req = Request.Builder().url(url).build() val req = Request.Builder().url(url).build()
client.newCall(req).execute().use { client.newCall(req).execute().use {

View File

@ -17,23 +17,32 @@
package okhttp3 package okhttp3
import kotlin.coroutines.resumeWithException
import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.suspendCancellableCoroutine import kotlinx.coroutines.suspendCancellableCoroutine
import okio.IOException import okio.IOException
import kotlin.coroutines.resumeWithException
@OptIn(ExperimentalCoroutinesApi::class) @OptIn(ExperimentalCoroutinesApi::class)
suspend fun Call.executeAsync(): Response = suspendCancellableCoroutine { continuation -> suspend fun Call.executeAsync(): Response =
continuation.invokeOnCancellation { suspendCancellableCoroutine { continuation ->
this.cancel() continuation.invokeOnCancellation {
} this.cancel()
this.enqueue(object : Callback {
override fun onFailure(call: Call, e: IOException) {
continuation.resumeWithException(e)
} }
this.enqueue(
object : Callback {
override fun onFailure(
call: Call,
e: IOException,
) {
continuation.resumeWithException(e)
}
override fun onResponse(call: Call, response: Response) { override fun onResponse(
continuation.resume(value = response, onCancellation = { call.cancel() }) call: Call,
} response: Response,
}) ) {
} continuation.resume(value = response, onCancellation = { call.cancel() })
}
},
)
}

View File

@ -80,7 +80,7 @@ class SuspendCallTest {
MockResponse.Builder() MockResponse.Builder()
.bodyDelay(5, TimeUnit.SECONDS) .bodyDelay(5, TimeUnit.SECONDS)
.body("abc") .body("abc")
.build() .build(),
) )
val call = client.newCall(request) val call = client.newCall(request)
@ -109,7 +109,7 @@ class SuspendCallTest {
MockResponse.Builder() MockResponse.Builder()
.bodyDelay(5, TimeUnit.SECONDS) .bodyDelay(5, TimeUnit.SECONDS)
.body("abc") .body("abc")
.build() .build(),
) )
val call = client.newCall(request) val call = client.newCall(request)
@ -137,7 +137,7 @@ class SuspendCallTest {
MockResponse( MockResponse(
body = "abc", body = "abc",
socketPolicy = DisconnectAfterRequest, socketPolicy = DisconnectAfterRequest,
) ),
) )
val call = client.newCall(request) val call = client.newCall(request)

View File

@ -26,13 +26,13 @@ import okhttp3.Dns
*/ */
internal class BootstrapDns( internal class BootstrapDns(
private val dnsHostname: String, private val dnsHostname: String,
private val dnsServers: List<InetAddress> private val dnsServers: List<InetAddress>,
) : Dns { ) : Dns {
@Throws(UnknownHostException::class) @Throws(UnknownHostException::class)
override fun lookup(hostname: String): List<InetAddress> { override fun lookup(hostname: String): List<InetAddress> {
if (this.dnsHostname != hostname) { if (this.dnsHostname != hostname) {
throw UnknownHostException( throw UnknownHostException(
"BootstrapDns called for $hostname instead of $dnsHostname" "BootstrapDns called for $hostname instead of $dnsHostname",
) )
} }

View File

@ -51,7 +51,7 @@ class DnsOverHttps internal constructor(
@get:JvmName("includeIPv6") val includeIPv6: Boolean, @get:JvmName("includeIPv6") val includeIPv6: Boolean,
@get:JvmName("post") val post: Boolean, @get:JvmName("post") val post: Boolean,
@get:JvmName("resolvePrivateAddresses") val resolvePrivateAddresses: Boolean, @get:JvmName("resolvePrivateAddresses") val resolvePrivateAddresses: Boolean,
@get:JvmName("resolvePublicAddresses") val resolvePublicAddresses: Boolean @get:JvmName("resolvePublicAddresses") val resolvePublicAddresses: Boolean,
) : Dns { ) : Dns {
@Throws(UnknownHostException::class) @Throws(UnknownHostException::class)
override fun lookup(hostname: String): List<InetAddress> { override fun lookup(hostname: String): List<InetAddress> {
@ -94,37 +94,46 @@ class DnsOverHttps internal constructor(
networkRequests: MutableList<Call>, networkRequests: MutableList<Call>,
results: MutableList<InetAddress>, results: MutableList<InetAddress>,
failures: MutableList<Exception>, failures: MutableList<Exception>,
type: Int type: Int,
) { ) {
val request = buildRequest(hostname, type) val request = buildRequest(hostname, type)
val response = getCacheOnlyResponse(request) val response = getCacheOnlyResponse(request)
response?.let { processResponse(it, hostname, results, failures) } ?: networkRequests.add( response?.let { processResponse(it, hostname, results, failures) } ?: networkRequests.add(
client.newCall(request)) client.newCall(request),
)
} }
private fun executeRequests( private fun executeRequests(
hostname: String, hostname: String,
networkRequests: List<Call>, networkRequests: List<Call>,
responses: MutableList<InetAddress>, responses: MutableList<InetAddress>,
failures: MutableList<Exception> failures: MutableList<Exception>,
) { ) {
val latch = CountDownLatch(networkRequests.size) val latch = CountDownLatch(networkRequests.size)
for (call in networkRequests) { for (call in networkRequests) {
call.enqueue(object : Callback { call.enqueue(
override fun onFailure(call: Call, e: IOException) { object : Callback {
synchronized(failures) { override fun onFailure(
failures.add(e) call: Call,
e: IOException,
) {
synchronized(failures) {
failures.add(e)
}
latch.countDown()
} }
latch.countDown()
}
override fun onResponse(call: Call, response: Response) { override fun onResponse(
processResponse(response, hostname, responses, failures) call: Call,
latch.countDown() response: Response,
} ) {
}) processResponse(response, hostname, responses, failures)
latch.countDown()
}
},
)
} }
try { try {
@ -138,7 +147,7 @@ class DnsOverHttps internal constructor(
response: Response, response: Response,
hostname: String, hostname: String,
results: MutableList<InetAddress>, results: MutableList<InetAddress>,
failures: MutableList<Exception> failures: MutableList<Exception>,
) { ) {
try { try {
val addresses = readResponse(hostname, response) val addresses = readResponse(hostname, response)
@ -153,7 +162,10 @@ class DnsOverHttps internal constructor(
} }
@Throws(UnknownHostException::class) @Throws(UnknownHostException::class)
private fun throwBestFailure(hostname: String, failures: List<Exception>): List<InetAddress> { private fun throwBestFailure(
hostname: String,
failures: List<Exception>,
): List<InetAddress> {
if (failures.isEmpty()) { if (failures.isEmpty()) {
throw UnknownHostException(hostname) throw UnknownHostException(hostname)
} }
@ -179,7 +191,8 @@ class DnsOverHttps internal constructor(
try { try {
// Use the cache without hitting the network first // Use the cache without hitting the network first
// 504 code indicates that the Cache is stale // 504 code indicates that the Cache is stale
val preferCache = CacheControl.Builder() val preferCache =
CacheControl.Builder()
.onlyIfCached() .onlyIfCached()
.build() .build()
val cacheRequest = request.newBuilder().cacheControl(preferCache).build() val cacheRequest = request.newBuilder().cacheControl(preferCache).build()
@ -199,7 +212,10 @@ class DnsOverHttps internal constructor(
} }
@Throws(Exception::class) @Throws(Exception::class)
private fun readResponse(hostname: String, response: Response): List<InetAddress> { private fun readResponse(
hostname: String,
response: Response,
): List<InetAddress> {
if (response.cacheResponse == null && response.protocol !== Protocol.HTTP_2) { if (response.cacheResponse == null && response.protocol !== Protocol.HTTP_2) {
Platform.get().log("Incorrect protocol: ${response.protocol}", Platform.WARN) Platform.get().log("Incorrect protocol: ${response.protocol}", Platform.WARN)
} }
@ -213,7 +229,7 @@ class DnsOverHttps internal constructor(
if (body.contentLength() > MAX_RESPONSE_SIZE) { if (body.contentLength() > MAX_RESPONSE_SIZE) {
throw IOException( throw IOException(
"response size exceeds limit ($MAX_RESPONSE_SIZE bytes): ${body.contentLength()} bytes" "response size exceeds limit ($MAX_RESPONSE_SIZE bytes): ${body.contentLength()} bytes",
) )
} }
@ -223,19 +239,22 @@ class DnsOverHttps internal constructor(
} }
} }
private fun buildRequest(hostname: String, type: Int): Request = private fun buildRequest(
Request.Builder().header("Accept", DNS_MESSAGE.toString()).apply { hostname: String,
val query = DnsRecordCodec.encodeQuery(hostname, type) type: Int,
): Request =
Request.Builder().header("Accept", DNS_MESSAGE.toString()).apply {
val query = DnsRecordCodec.encodeQuery(hostname, type)
if (post) { if (post) {
url(url).post(query.toRequestBody(DNS_MESSAGE)) url(url).post(query.toRequestBody(DNS_MESSAGE))
} else { } else {
val encoded = query.base64Url().replace("=", "") val encoded = query.base64Url().replace("=", "")
val requestUrl = url.newBuilder().addQueryParameter("dns", encoded).build() val requestUrl = url.newBuilder().addQueryParameter("dns", encoded).build()
url(requestUrl) url(requestUrl)
} }
}.build() }.build()
class Builder { class Builder {
internal var client: OkHttpClient? = null internal var client: OkHttpClient? = null
@ -250,49 +269,56 @@ class DnsOverHttps internal constructor(
fun build(): DnsOverHttps { fun build(): DnsOverHttps {
val client = this.client ?: throw NullPointerException("client not set") val client = this.client ?: throw NullPointerException("client not set")
return DnsOverHttps( return DnsOverHttps(
client.newBuilder().dns(buildBootstrapClient(this)).build(), client.newBuilder().dns(buildBootstrapClient(this)).build(),
checkNotNull(url) { "url not set" }, checkNotNull(url) { "url not set" },
includeIPv6, includeIPv6,
post, post,
resolvePrivateAddresses, resolvePrivateAddresses,
resolvePublicAddresses resolvePublicAddresses,
) )
} }
fun client(client: OkHttpClient) = apply { fun client(client: OkHttpClient) =
this.client = client apply {
} this.client = client
}
fun url(url: HttpUrl) = apply { fun url(url: HttpUrl) =
this.url = url apply {
} this.url = url
}
fun includeIPv6(includeIPv6: Boolean) = apply { fun includeIPv6(includeIPv6: Boolean) =
this.includeIPv6 = includeIPv6 apply {
} this.includeIPv6 = includeIPv6
}
fun post(post: Boolean) = apply { fun post(post: Boolean) =
this.post = post apply {
} this.post = post
}
fun resolvePrivateAddresses(resolvePrivateAddresses: Boolean) = apply { fun resolvePrivateAddresses(resolvePrivateAddresses: Boolean) =
this.resolvePrivateAddresses = resolvePrivateAddresses apply {
} this.resolvePrivateAddresses = resolvePrivateAddresses
}
fun resolvePublicAddresses(resolvePublicAddresses: Boolean) = apply { fun resolvePublicAddresses(resolvePublicAddresses: Boolean) =
this.resolvePublicAddresses = resolvePublicAddresses apply {
} this.resolvePublicAddresses = resolvePublicAddresses
}
fun bootstrapDnsHosts(bootstrapDnsHosts: List<InetAddress>?) = apply { fun bootstrapDnsHosts(bootstrapDnsHosts: List<InetAddress>?) =
this.bootstrapDnsHosts = bootstrapDnsHosts apply {
} this.bootstrapDnsHosts = bootstrapDnsHosts
}
fun bootstrapDnsHosts(vararg bootstrapDnsHosts: InetAddress): Builder = fun bootstrapDnsHosts(vararg bootstrapDnsHosts: InetAddress): Builder = bootstrapDnsHosts(bootstrapDnsHosts.toList())
bootstrapDnsHosts(bootstrapDnsHosts.toList())
fun systemDns(systemDns: Dns) = apply { fun systemDns(systemDns: Dns) =
this.systemDns = systemDns apply {
} this.systemDns = systemDns
}
} }
companion object { companion object {

View File

@ -33,31 +33,38 @@ internal object DnsRecordCodec {
private const val TYPE_PTR = 0x000c private const val TYPE_PTR = 0x000c
private val ASCII = Charsets.US_ASCII private val ASCII = Charsets.US_ASCII
fun encodeQuery(host: String, type: Int): ByteString = Buffer().apply { fun encodeQuery(
writeShort(0) // query id host: String,
writeShort(256) // flags with recursion type: Int,
writeShort(1) // question count ): ByteString =
writeShort(0) // answerCount Buffer().apply {
writeShort(0) // authorityResourceCount writeShort(0) // query id
writeShort(0) // additional writeShort(256) // flags with recursion
writeShort(1) // question count
writeShort(0) // answerCount
writeShort(0) // authorityResourceCount
writeShort(0) // additional
val nameBuf = Buffer() val nameBuf = Buffer()
val labels = host.split('.').dropLastWhile { it.isEmpty() } val labels = host.split('.').dropLastWhile { it.isEmpty() }
for (label in labels) { for (label in labels) {
val utf8ByteCount = label.utf8Size() val utf8ByteCount = label.utf8Size()
require(utf8ByteCount == label.length.toLong()) { "non-ascii hostname: $host" } require(utf8ByteCount == label.length.toLong()) { "non-ascii hostname: $host" }
nameBuf.writeByte(utf8ByteCount.toInt()) nameBuf.writeByte(utf8ByteCount.toInt())
nameBuf.writeUtf8(label) nameBuf.writeUtf8(label)
} }
nameBuf.writeByte(0) // end nameBuf.writeByte(0) // end
nameBuf.copyTo(this, 0, nameBuf.size) nameBuf.copyTo(this, 0, nameBuf.size)
writeShort(type) writeShort(type)
writeShort(1) // CLASS_IN writeShort(1) // CLASS_IN
}.readByteString() }.readByteString()
@Throws(Exception::class) @Throws(Exception::class)
fun decodeAnswers(hostname: String, byteString: ByteString): List<InetAddress> { fun decodeAnswers(
hostname: String,
byteString: ByteString,
): List<InetAddress> {
val result = mutableListOf<InetAddress>() val result = mutableListOf<InetAddress>()
val buf = Buffer() val buf = Buffer()
@ -91,7 +98,8 @@ internal object DnsRecordCodec {
val type = buf.readShort().toInt() and 0xffff val type = buf.readShort().toInt() and 0xffff
buf.readShort() // class buf.readShort() // class
@Suppress("UNUSED_VARIABLE") val ttl = buf.readInt().toLong() and 0xffffffffL // ttl @Suppress("UNUSED_VARIABLE")
val ttl = buf.readInt().toLong() and 0xffffffffL // ttl
val length = buf.readShort().toInt() and 0xffff val length = buf.readShort().toInt() and 0xffff
if (type == TYPE_A || type == TYPE_AAAA) { if (type == TYPE_A || type == TYPE_AAAA) {

View File

@ -55,9 +55,10 @@ class DnsOverHttpsTest {
private lateinit var server: MockWebServer private lateinit var server: MockWebServer
private lateinit var dns: Dns private lateinit var dns: Dns
private val cacheFs = FakeFileSystem() private val cacheFs = FakeFileSystem()
private val bootstrapClient = OkHttpClient.Builder() private val bootstrapClient =
.protocols(listOf(Protocol.HTTP_2, Protocol.HTTP_1_1)) OkHttpClient.Builder()
.build() .protocols(listOf(Protocol.HTTP_2, Protocol.HTTP_1_1))
.build()
@BeforeEach @BeforeEach
fun setUp(server: MockWebServer) { fun setUp(server: MockWebServer) {
@ -72,8 +73,8 @@ class DnsOverHttpsTest {
dnsResponse( dnsResponse(
"0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c000500010" + "0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c000500010" +
"0000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010001000" + "0000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010001000" +
"0003b00049df00112" "0003b00049df00112",
) ),
) )
val result = dns.lookup("google.com") val result = dns.lookup("google.com")
assertThat(result).isEqualTo(listOf(address("157.240.1.18"))) assertThat(result).isEqualTo(listOf(address("157.240.1.18")))
@ -89,15 +90,15 @@ class DnsOverHttpsTest {
dnsResponse( dnsResponse(
"0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c0005000" + "0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c0005000" +
"100000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c0420001000" + "100000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c0420001000" +
"10000003b00049df00112" "10000003b00049df00112",
) ),
) )
server.enqueue( server.enqueue(
dnsResponse( dnsResponse(
"0000818000010003000000000567726170680866616365626f6f6b03636f6d00001c0001c00c0005000" + "0000818000010003000000000567726170680866616365626f6f6b03636f6d00001c0001c00c0005000" +
"100000a1b000603617069c012c0300005000100000b1f000c04737461720463313072c012c042001c000" + "100000a1b000603617069c012c0300005000100000b1f000c04737461720463313072c012c042001c000" +
"10000003b00102a032880f0290011faceb00c00000002" "10000003b00102a032880f0290011faceb00c00000002",
) ),
) )
dns = buildLocalhost(bootstrapClient, true) dns = buildLocalhost(bootstrapClient, true)
val result = dns.lookup("google.com") val result = dns.lookup("google.com")
@ -111,7 +112,7 @@ class DnsOverHttpsTest {
assertThat(listOf(request1.path, request2.path)) assertThat(listOf(request1.path, request2.path))
.containsExactlyInAnyOrder( .containsExactlyInAnyOrder(
"/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AAAEAAQ", "/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AAAEAAQ",
"/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AABwAAQ" "/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AABwAAQ",
) )
} }
@ -121,8 +122,8 @@ class DnsOverHttpsTest {
dnsResponse( dnsResponse(
"0000818300010000000100000e7364666c6b686673646c6b6a64660265650000010001c01b00060001000" + "0000818300010000000100000e7364666c6b686673646c6b6a64660265650000010001c01b00060001000" +
"007070038026e7303746c64c01b0a686f73746d61737465720d6565737469696e7465726e6574c01b5adb1" + "007070038026e7303746c64c01b0a686f73746d61737465720d6565737469696e7465726e6574c01b5adb1" +
"2c100000e10000003840012750000000e10" "2c100000e10000003840012750000000e10",
) ),
) )
try { try {
dns.lookup("google.com") dns.lookup("google.com")
@ -179,11 +180,11 @@ class DnsOverHttpsTest {
dnsResponse( dnsResponse(
"0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c000500010" + "0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c000500010" +
"0000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010001000" + "0000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010001000" +
"0003b00049df00112" "0003b00049df00112",
) )
.newBuilder() .newBuilder()
.setHeader("cache-control", "private, max-age=298") .setHeader("cache-control", "private, max-age=298")
.build() .build(),
) )
var result = cachedDns.lookup("google.com") var result = cachedDns.lookup("google.com")
assertThat(result).containsExactly(address("157.240.1.18")) assertThat(result).containsExactly(address("157.240.1.18"))
@ -204,29 +205,29 @@ class DnsOverHttpsTest {
dnsResponse( dnsResponse(
"0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c000500010" + "0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c000500010" +
"0000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010001000" + "0000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010001000" +
"0003b00049df00112" "0003b00049df00112",
) )
.newBuilder() .newBuilder()
.setHeader("cache-control", "max-age=1") .setHeader("cache-control", "max-age=1")
.build() .build(),
) )
var result = cachedDns.lookup("google.com") var result = cachedDns.lookup("google.com")
assertThat(result).containsExactly(address("157.240.1.18")) assertThat(result).containsExactly(address("157.240.1.18"))
var recordedRequest = server.takeRequest(0, TimeUnit.SECONDS) var recordedRequest = server.takeRequest(0, TimeUnit.SECONDS)
assertThat(recordedRequest!!.method).isEqualTo("GET") assertThat(recordedRequest!!.method).isEqualTo("GET")
assertThat(recordedRequest.path).isEqualTo( assertThat(recordedRequest.path).isEqualTo(
"/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AAAEAAQ" "/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AAAEAAQ",
) )
Thread.sleep(2000) Thread.sleep(2000)
server.enqueue( server.enqueue(
dnsResponse( dnsResponse(
"0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c000500010" + "0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c000500010" +
"0000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010001000" + "0000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010001000" +
"0003b00049df00112" "0003b00049df00112",
) )
.newBuilder() .newBuilder()
.setHeader("cache-control", "max-age=1") .setHeader("cache-control", "max-age=1")
.build() .build(),
) )
result = cachedDns.lookup("google.com") result = cachedDns.lookup("google.com")
assertThat(result).isEqualTo(listOf(address("157.240.1.18"))) assertThat(result).isEqualTo(listOf(address("157.240.1.18")))
@ -244,7 +245,10 @@ class DnsOverHttpsTest {
.build() .build()
} }
private fun buildLocalhost(bootstrapClient: OkHttpClient, includeIPv6: Boolean): DnsOverHttps { private fun buildLocalhost(
bootstrapClient: OkHttpClient,
includeIPv6: Boolean,
): DnsOverHttps {
val url = server.url("/lookup?ct") val url = server.url("/lookup?ct")
return DnsOverHttps.Builder().client(bootstrapClient) return DnsOverHttps.Builder().client(bootstrapClient)
.includeIPv6(includeIPv6) .includeIPv6(includeIPv6)

View File

@ -14,12 +14,12 @@
* limitations under the License. * limitations under the License.
*/ */
@file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE") @file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")
package okhttp3.dnsoverhttps package okhttp3.dnsoverhttps
import assertk.assertThat import assertk.assertThat
import assertk.assertions.containsExactly import assertk.assertions.containsExactly
import assertk.assertions.isEqualTo import assertk.assertions.isEqualTo
import assertk.fail
import java.net.InetAddress import java.net.InetAddress
import java.net.UnknownHostException import java.net.UnknownHostException
import kotlin.test.assertFailsWith import kotlin.test.assertFailsWith
@ -36,7 +36,10 @@ class DnsRecordCodecTest {
assertThat(encoded).isEqualTo("AAABAAABAAAAAAAABmdvb2dsZQNjb20AAAEAAQ") assertThat(encoded).isEqualTo("AAABAAABAAAAAAAABmdvb2dsZQNjb20AAAEAAQ")
} }
private fun encodeQuery(host: String, type: Int): String { private fun encodeQuery(
host: String,
type: Int,
): String {
return DnsRecordCodec.encodeQuery(host, type).base64Url().replace("=", "") return DnsRecordCodec.encodeQuery(host, type).base64Url().replace("=", "")
} }
@ -48,33 +51,45 @@ class DnsRecordCodecTest {
@Test @Test
fun testGoogleDotComDecodingFromCloudflare() { fun testGoogleDotComDecodingFromCloudflare() {
val encoded = decodeAnswers( val encoded =
hostname = "test.com", decodeAnswers(
byteString = ("00008180000100010000000006676f6f676c6503636f6d0000010001c00c0001000100000043" + hostname = "test.com",
"0004d83ad54e").decodeHex() byteString =
) (
"00008180000100010000000006676f6f676c6503636f6d0000010001c00c0001000100000043" +
"0004d83ad54e"
).decodeHex(),
)
assertThat(encoded).containsExactly(InetAddress.getByName("216.58.213.78")) assertThat(encoded).containsExactly(InetAddress.getByName("216.58.213.78"))
} }
@Test @Test
fun testGoogleDotComDecodingFromGoogle() { fun testGoogleDotComDecodingFromGoogle() {
val decoded = decodeAnswers( val decoded =
hostname = "test.com", decodeAnswers(
byteString = ("0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c" + hostname = "test.com",
"0005000100000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c0420001" + byteString =
"00010000003b00049df00112").decodeHex() (
) "0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c" +
"0005000100000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c0420001" +
"00010000003b00049df00112"
).decodeHex(),
)
assertThat(decoded).containsExactly(InetAddress.getByName("157.240.1.18")) assertThat(decoded).containsExactly(InetAddress.getByName("157.240.1.18"))
} }
@Test @Test
fun testGoogleDotComDecodingFromGoogleIPv6() { fun testGoogleDotComDecodingFromGoogleIPv6() {
val decoded = decodeAnswers( val decoded =
hostname = "test.com", decodeAnswers(
byteString = ("0000818000010003000000000567726170680866616365626f6f6b03636f6d00001c0001c00c" + hostname = "test.com",
"0005000100000a1b000603617069c012c0300005000100000b1f000c04737461720463313072c012c042001c" + byteString =
"00010000003b00102a032880f0290011faceb00c00000002").decodeHex() (
) "0000818000010003000000000567726170680866616365626f6f6b03636f6d00001c0001c00c" +
"0005000100000a1b000603617069c012c0300005000100000b1f000c04737461720463313072c012c042001c" +
"00010000003b00102a032880f0290011faceb00c00000002"
).decodeHex(),
)
assertThat(decoded) assertThat(decoded)
.containsExactly(InetAddress.getByName("2a03:2880:f029:11:face:b00c:0:2")) .containsExactly(InetAddress.getByName("2a03:2880:f029:11:face:b00c:0:2"))
} }
@ -84,9 +99,12 @@ class DnsRecordCodecTest {
assertFailsWith<UnknownHostException> { assertFailsWith<UnknownHostException> {
decodeAnswers( decodeAnswers(
hostname = "sdflkhfsdlkjdf.ee", hostname = "sdflkhfsdlkjdf.ee",
byteString = ("0000818300010000000100000e7364666c6b686673646c6b6a64660265650000010001c01b" + byteString =
"00060001000007070038026e7303746c64c01b0a686f73746d61737465720d6565737469696e7465726e65" + (
"74c01b5adb12c100000e10000003840012750000000e10").decodeHex() "0000818300010000000100000e7364666c6b686673646c6b6a64660265650000010001c01b" +
"00060001000007070038026e7303746c64c01b0a686f73746d61737465720d6565737469696e7465726e65" +
"74c01b5adb12c100000e10000003840012750000000e10"
).decodeHex(),
) )
}.also { expected -> }.also { expected ->
assertThat(expected.message).isEqualTo("sdflkhfsdlkjdf.ee: NXDOMAIN") assertThat(expected.message).isEqualTo("sdflkhfsdlkjdf.ee: NXDOMAIN")

View File

@ -24,7 +24,10 @@ import okhttp3.OkHttpClient
import okhttp3.dnsoverhttps.DohProviders.providers import okhttp3.dnsoverhttps.DohProviders.providers
import org.conscrypt.OpenSSLProvider import org.conscrypt.OpenSSLProvider
private fun runBatch(dnsProviders: List<DnsOverHttps>, names: List<String>) { private fun runBatch(
dnsProviders: List<DnsOverHttps>,
names: List<String>,
) {
var time = System.currentTimeMillis() var time = System.currentTimeMillis()
for (dns in dnsProviders) { for (dns in dnsProviders) {
println("Testing ${dns.url}") println("Testing ${dns.url}")
@ -54,46 +57,52 @@ fun main() {
var names = listOf("google.com", "graph.facebook.com", "sdflkhfsdlkjdf.ee") var names = listOf("google.com", "graph.facebook.com", "sdflkhfsdlkjdf.ee")
try { try {
println("uncached\n********\n") println("uncached\n********\n")
var dnsProviders = providers( var dnsProviders =
client = bootstrapClient, providers(
http2Only = false, client = bootstrapClient,
workingOnly = false, http2Only = false,
getOnly = false, workingOnly = false,
) getOnly = false,
)
runBatch(dnsProviders, names) runBatch(dnsProviders, names)
val dnsCache = Cache( val dnsCache =
directory = File("./target/TestDohMain.cache.${System.currentTimeMillis()}"), Cache(
maxSize = 10L * 1024 * 1024 directory = File("./target/TestDohMain.cache.${System.currentTimeMillis()}"),
) maxSize = 10L * 1024 * 1024,
)
println("Bad targets\n***********\n") println("Bad targets\n***********\n")
val url = "https://dns.cloudflare.com/.not-so-well-known/run-dmc-query".toHttpUrl() val url = "https://dns.cloudflare.com/.not-so-well-known/run-dmc-query".toHttpUrl()
val badProviders = listOf( val badProviders =
DnsOverHttps.Builder() listOf(
.client(bootstrapClient) DnsOverHttps.Builder()
.url(url) .client(bootstrapClient)
.post(true) .url(url)
.build() .post(true)
) .build(),
)
runBatch(badProviders, names) runBatch(badProviders, names)
println("cached first run\n****************\n") println("cached first run\n****************\n")
names = listOf("google.com", "graph.facebook.com") names = listOf("google.com", "graph.facebook.com")
bootstrapClient = bootstrapClient.newBuilder() bootstrapClient =
.cache(dnsCache) bootstrapClient.newBuilder()
.build() .cache(dnsCache)
dnsProviders = providers( .build()
client = bootstrapClient, dnsProviders =
http2Only = true, providers(
workingOnly = true, client = bootstrapClient,
getOnly = true, http2Only = true,
) workingOnly = true,
getOnly = true,
)
runBatch(dnsProviders, names) runBatch(dnsProviders, names)
println("cached second run\n*****************\n") println("cached second run\n*****************\n")
dnsProviders = providers( dnsProviders =
client = bootstrapClient, providers(
http2Only = true, client = bootstrapClient,
workingOnly = true, http2Only = true,
getOnly = true, workingOnly = true,
) getOnly = true,
)
runBatch(dnsProviders, names) runBatch(dnsProviders, names)
} finally { } finally {
bootstrapClient.connectionPool.evictAll() bootstrapClient.connectionPool.evictAll()

View File

@ -28,7 +28,7 @@ class HpackDecodeInteropTest : HpackDecodeTestBase() {
fun testGoodDecoderInterop(story: Story) { fun testGoodDecoderInterop(story: Story) {
assumeFalse( assumeFalse(
story === Story.MISSING, story === Story.MISSING,
"Test stories missing, checkout git submodule" "Test stories missing, checkout git submodule",
) )
testDecoder(story) testDecoder(story)
} }

View File

@ -36,7 +36,7 @@ open class HpackDecodeTestBase {
assertSetEquals( assertSetEquals(
"seqno=$testCase.seqno", "seqno=$testCase.seqno",
testCase.headersList, testCase.headersList,
hpackReader.getAndResetHeaderList() hpackReader.getAndResetHeaderList(),
) )
} }
} }

View File

@ -43,7 +43,7 @@ class HpackRoundTripTest : HpackDecodeTestBase() {
fun testRoundTrip(story: Story) { fun testRoundTrip(story: Story) {
assumeFalse( assumeFalse(
story === Story.MISSING, story === Story.MISSING,
"Test stories missing, checkout git submodule" "Test stories missing, checkout git submodule",
) )
val newCases = mutableListOf<Case>() val newCases = mutableListOf<Case>()

View File

@ -35,13 +35,17 @@ import okio.source
*/ */
object HpackJsonUtil { object HpackJsonUtil {
@Suppress("unused") @Suppress("unused")
private val MOSHI = Moshi.Builder() private val MOSHI =
.add(object : Any() { Moshi.Builder()
@ToJson fun byteStringToJson(byteString: ByteString) = byteString.hex() .add(
@FromJson fun byteStringFromJson(json: String) = json.decodeHex() object : Any() {
}) @ToJson fun byteStringToJson(byteString: ByteString) = byteString.hex()
.add(KotlinJsonAdapterFactory())
.build() @FromJson fun byteStringFromJson(json: String) = json.decodeHex()
},
)
.add(KotlinJsonAdapterFactory())
.build()
private val STORY_JSON_ADAPTER = MOSHI.adapter(Story::class.java) private val STORY_JSON_ADAPTER = MOSHI.adapter(Story::class.java)
private val fileSystem = FileSystem.SYSTEM private val fileSystem = FileSystem.SYSTEM
@ -58,8 +62,9 @@ object HpackJsonUtil {
/** Iterate through the hpack-test-case resources, only picking stories for the current draft. */ /** Iterate through the hpack-test-case resources, only picking stories for the current draft. */
fun storiesForCurrentDraft(): Array<String> { fun storiesForCurrentDraft(): Array<String> {
val resource = HpackJsonUtil::class.java.getResource("/hpack-test-case") val resource =
?: return arrayOf() HpackJsonUtil::class.java.getResource("/hpack-test-case")
?: return arrayOf()
val testCaseDirectory = File(resource.toURI()).toOkioPath() val testCaseDirectory = File(resource.toURI()).toOkioPath()
val result = mutableListOf<String>() val result = mutableListOf<String>()
@ -83,17 +88,20 @@ object HpackJsonUtil {
val result = mutableListOf<Story>() val result = mutableListOf<Story>()
var i = 0 var i = 0
while (true) { // break after last test. while (true) { // break after last test.
val storyResourceName = String.format( val storyResourceName =
"/hpack-test-case/%s/story_%02d.json", String.format(
testFolderName, "/hpack-test-case/%s/story_%02d.json",
i, testFolderName,
) i,
val storyInputStream = HpackJsonUtil::class.java.getResourceAsStream(storyResourceName) )
?: break val storyInputStream =
HpackJsonUtil::class.java.getResourceAsStream(storyResourceName)
?: break
try { try {
storyInputStream.use { storyInputStream.use {
val story = readStory(storyInputStream.source().buffer()) val story =
.copy(fileName = storyResourceName) readStory(storyInputStream.source().buffer())
.copy(fileName = storyResourceName)
result.add(story) result.add(story)
i++ i++
} }

View File

@ -24,7 +24,6 @@ data class Story(
val cases: List<Case>, val cases: List<Case>,
val fileName: String? = null, val fileName: String? = null,
) { ) {
// Used as the test name. // Used as the test name.
override fun toString() = fileName ?: "?" override fun toString() = fileName ?: "?"

View File

@ -33,9 +33,10 @@ fun main(vararg args: String) {
fun loadIdnaMappingTableData(): IdnaMappingTableData { fun loadIdnaMappingTableData(): IdnaMappingTableData {
val path = "/okhttp3/internal/idna/IdnaMappingTable.txt".toPath() val path = "/okhttp3/internal/idna/IdnaMappingTable.txt".toPath()
val table = FileSystem.RESOURCES.read(path) { val table =
readPlainTextIdnaMappingTable() FileSystem.RESOURCES.read(path) {
} readPlainTextIdnaMappingTable()
}
return buildIdnaMappingTableData(table) return buildIdnaMappingTableData(table)
} }
@ -60,18 +61,18 @@ fun generateMappingTableFile(data: IdnaMappingTableData): FileSpec {
.addModifiers(KModifier.INTERNAL) .addModifiers(KModifier.INTERNAL)
.initializer( .initializer(
""" """
|%T(⇥ |%T(⇥
|sections = "%L", |sections = "%L",
|ranges = "%L", |ranges = "%L",
|mappings = "%L", |mappings = "%L",
|⇤) |⇤)
""".trimMargin(), """.trimMargin(),
idnaMappingTable, idnaMappingTable,
data.sections.escapeDataString(), data.sections.escapeDataString(),
data.ranges.escapeDataString(), data.ranges.escapeDataString(),
data.mappings.escapeDataString(), data.mappings.escapeDataString(),
) )
.build() .build(),
) )
.build() .build()
} }
@ -89,7 +90,8 @@ fun String.escapeDataString(): String {
'$'.code, '$'.code,
'\\'.code, '\\'.code,
'·'.code, '·'.code,
127 -> append(String.format("\\u%04x", codePoint)) 127,
-> append(String.format("\\u%04x", codePoint))
else -> appendCodePoint(codePoint) else -> appendCodePoint(codePoint)
} }

View File

@ -23,22 +23,22 @@ internal sealed interface MappedRange {
data class Constant( data class Constant(
override val rangeStart: Int, override val rangeStart: Int,
val type: Int val type: Int,
) : MappedRange { ) : MappedRange {
val b1: Int val b1: Int
get() = when (type) { get() =
TYPE_IGNORED -> 119 when (type) {
TYPE_VALID -> 120 TYPE_IGNORED -> 119
TYPE_DISALLOWED -> 121 TYPE_VALID -> 120
else -> error("unexpected type: $type") TYPE_DISALLOWED -> 121
} else -> error("unexpected type: $type")
}
} }
data class Inline1( data class Inline1(
override val rangeStart: Int, override val rangeStart: Int,
private val mappedTo: ByteString private val mappedTo: ByteString,
) : MappedRange { ) : MappedRange {
val b1: Int val b1: Int
get() { get() {
val b3bit8 = mappedTo[0] and 0x80 != 0 val b3bit8 = mappedTo[0] and 0x80 != 0
@ -51,9 +51,8 @@ internal sealed interface MappedRange {
data class Inline2( data class Inline2(
override val rangeStart: Int, override val rangeStart: Int,
private val mappedTo: ByteString private val mappedTo: ByteString,
) : MappedRange { ) : MappedRange {
val b1: Int val b1: Int
get() { get() {
val b2bit8 = mappedTo[0] and 0x80 != 0 val b2bit8 = mappedTo[0] and 0x80 != 0
@ -75,17 +74,17 @@ internal sealed interface MappedRange {
data class InlineDelta( data class InlineDelta(
override val rangeStart: Int, override val rangeStart: Int,
val codepointDelta: Int val codepointDelta: Int,
) : MappedRange { ) : MappedRange {
private val absoluteDelta = abs(codepointDelta) private val absoluteDelta = abs(codepointDelta)
val b1: Int val b1: Int
get() = when { get() =
codepointDelta < 0 -> 0x40 or (absoluteDelta shr 14) when {
codepointDelta > 0 -> 0x50 or (absoluteDelta shr 14) codepointDelta < 0 -> 0x40 or (absoluteDelta shr 14)
else -> error("Unexpected codepointDelta of 0") codepointDelta > 0 -> 0x50 or (absoluteDelta shr 14)
} else -> error("Unexpected codepointDelta of 0")
}
val b2: Int val b2: Int
get() = absoluteDelta shr 7 and 0x7f get() = absoluteDelta shr 7 and 0x7f
@ -100,6 +99,6 @@ internal sealed interface MappedRange {
data class External( data class External(
override val rangeStart: Int, override val rangeStart: Int,
val mappedTo: ByteString val mappedTo: ByteString,
) : MappedRange ) : MappedRange
} }

View File

@ -135,26 +135,28 @@ internal fun sections(mappings: List<Mapping>): Map<Int, List<MappedRange>> {
val sectionList = result.getOrPut(section) { mutableListOf() } val sectionList = result.getOrPut(section) { mutableListOf() }
sectionList += when (mapping.type) { sectionList +=
TYPE_MAPPED -> run { when (mapping.type) {
val deltaMapping = inlineDeltaOrNull(mapping) TYPE_MAPPED ->
if (deltaMapping != null) { run {
return@run deltaMapping val deltaMapping = inlineDeltaOrNull(mapping)
if (deltaMapping != null) {
return@run deltaMapping
}
when (mapping.mappedTo.size) {
1 -> MappedRange.Inline1(rangeStart, mapping.mappedTo)
2 -> MappedRange.Inline2(rangeStart, mapping.mappedTo)
else -> MappedRange.External(rangeStart, mapping.mappedTo)
}
}
TYPE_IGNORED, TYPE_VALID, TYPE_DISALLOWED -> {
MappedRange.Constant(rangeStart, mapping.type)
} }
when (mapping.mappedTo.size) { else -> error("unexpected mapping type: ${mapping.type}")
1 -> MappedRange.Inline1(rangeStart, mapping.mappedTo)
2 -> MappedRange.Inline2(rangeStart, mapping.mappedTo)
else -> MappedRange.External(rangeStart, mapping.mappedTo)
}
} }
TYPE_IGNORED, TYPE_VALID, TYPE_DISALLOWED -> {
MappedRange.Constant(rangeStart, mapping.type)
}
else -> error("unexpected mapping type: ${mapping.type}")
}
} }
for (sectionList in result.values) { for (sectionList in result.values) {
@ -202,18 +204,20 @@ internal fun withoutSectionSpans(mappings: List<Mapping>): List<Mapping> {
while (true) { while (true) {
if (current.spansSections) { if (current.spansSections) {
result += Mapping( result +=
current.sourceCodePoint0, Mapping(
current.section + 0x7f, current.sourceCodePoint0,
current.type, current.section + 0x7f,
current.mappedTo, current.type,
) current.mappedTo,
current = Mapping( )
current.section + 0x80, current =
current.sourceCodePoint1, Mapping(
current.type, current.section + 0x80,
current.mappedTo, current.sourceCodePoint1,
) current.type,
current.mappedTo,
)
} else { } else {
result += current result += current
current = if (i.hasNext()) i.next() else break current = if (i.hasNext()) i.next() else break
@ -246,12 +250,13 @@ internal fun mergeAdjacentRanges(mappings: List<Mapping>): List<Mapping> {
index++ index++
} }
result += Mapping( result +=
sourceCodePoint0 = mapping.sourceCodePoint0, Mapping(
sourceCodePoint1 = unionWith.sourceCodePoint1, sourceCodePoint0 = mapping.sourceCodePoint0,
type = type, sourceCodePoint1 = unionWith.sourceCodePoint1,
mappedTo = mappedTo, type = type,
) mappedTo = mappedTo,
)
} }
return result return result
@ -262,11 +267,13 @@ internal fun canonicalizeType(type: Int): Int {
TYPE_IGNORED -> TYPE_IGNORED TYPE_IGNORED -> TYPE_IGNORED
TYPE_MAPPED, TYPE_MAPPED,
TYPE_DISALLOWED_STD3_MAPPED -> TYPE_MAPPED TYPE_DISALLOWED_STD3_MAPPED,
-> TYPE_MAPPED
TYPE_DEVIATION, TYPE_DEVIATION,
TYPE_DISALLOWED_STD3_VALID, TYPE_DISALLOWED_STD3_VALID,
TYPE_VALID -> TYPE_VALID TYPE_VALID,
-> TYPE_VALID
TYPE_DISALLOWED -> TYPE_DISALLOWED TYPE_DISALLOWED -> TYPE_DISALLOWED
@ -279,4 +286,3 @@ internal infix fun Byte.and(mask: Int): Int = toInt() and mask
internal infix fun Short.and(mask: Int): Int = toInt() and mask internal infix fun Short.and(mask: Int): Int = toInt() and mask
internal infix fun Int.and(mask: Long): Long = toLong() and mask internal infix fun Int.and(mask: Long): Long = toLong() and mask

View File

@ -42,14 +42,18 @@ class SimpleIdnaMappingTable internal constructor(
/** /**
* Returns true if the [codePoint] was applied successfully. Returns false if it was disallowed. * Returns true if the [codePoint] was applied successfully. Returns false if it was disallowed.
*/ */
fun map(codePoint: Int, sink: BufferedSink): Boolean { fun map(
val index = mappings.binarySearch { codePoint: Int,
when { sink: BufferedSink,
it.sourceCodePoint1 < codePoint -> -1 ): Boolean {
it.sourceCodePoint0 > codePoint -> 1 val index =
else -> 0 mappings.binarySearch {
when {
it.sourceCodePoint1 < codePoint -> -1
it.sourceCodePoint0 > codePoint -> 1
else -> 0
}
} }
}
// Code points must be in 0..0x10ffff. // Code points must be in 0..0x10ffff.
require(index in mappings.indices) { "unexpected code point: $codePoint" } require(index in mappings.indices) { "unexpected code point: $codePoint" }
@ -77,24 +81,25 @@ class SimpleIdnaMappingTable internal constructor(
} }
} }
private val optionsDelimiter =
Options.of(
// 0.
".".encodeUtf8(),
// 1.
" ".encodeUtf8(),
// 2.
";".encodeUtf8(),
// 3.
"#".encodeUtf8(),
// 4.
"\n".encodeUtf8(),
)
private val optionsDelimiter = Options.of( private val optionsDot =
// 0. Options.of(
".".encodeUtf8(), // 0.
// 1. ".".encodeUtf8(),
" ".encodeUtf8(), )
// 2.
";".encodeUtf8(),
// 3.
"#".encodeUtf8(),
// 4.
"\n".encodeUtf8(),
)
private val optionsDot = Options.of(
// 0.
".".encodeUtf8(),
)
private const val DELIMITER_DOT = 0 private const val DELIMITER_DOT = 0
private const val DELIMITER_SPACE = 1 private const val DELIMITER_SPACE = 1
@ -102,22 +107,23 @@ private const val DELIMITER_SEMICOLON = 2
private const val DELIMITER_HASH = 3 private const val DELIMITER_HASH = 3
private const val DELIMITER_NEWLINE = 4 private const val DELIMITER_NEWLINE = 4
private val optionsType = Options.of( private val optionsType =
// 0. Options.of(
"deviation ".encodeUtf8(), // 0.
// 1. "deviation ".encodeUtf8(),
"disallowed ".encodeUtf8(), // 1.
// 2. "disallowed ".encodeUtf8(),
"disallowed_STD3_mapped ".encodeUtf8(), // 2.
// 3. "disallowed_STD3_mapped ".encodeUtf8(),
"disallowed_STD3_valid ".encodeUtf8(), // 3.
// 4. "disallowed_STD3_valid ".encodeUtf8(),
"ignored ".encodeUtf8(), // 4.
// 5. "ignored ".encodeUtf8(),
"mapped ".encodeUtf8(), // 5.
// 6. "mapped ".encodeUtf8(),
"valid ".encodeUtf8(), // 6.
) "valid ".encodeUtf8(),
)
internal const val TYPE_DEVIATION = 0 internal const val TYPE_DEVIATION = 0
internal const val TYPE_DISALLOWED = 1 internal const val TYPE_DISALLOWED = 1
@ -182,14 +188,15 @@ fun BufferedSource.readPlainTextIdnaMappingTable(): SimpleIdnaMappingTable {
// "002F" or "0000..002C" // "002F" or "0000..002C"
val sourceCodePoint0 = readHexadecimalUnsignedLong() val sourceCodePoint0 = readHexadecimalUnsignedLong()
val sourceCodePoint1 = when (select(optionsDot)) { val sourceCodePoint1 =
DELIMITER_DOT -> { when (select(optionsDot)) {
if (readByte() != '.'.code.toByte()) throw IOException("expected '..'") DELIMITER_DOT -> {
readHexadecimalUnsignedLong() if (readByte() != '.'.code.toByte()) throw IOException("expected '..'")
} readHexadecimalUnsignedLong()
}
else -> sourceCodePoint0 else -> sourceCodePoint0
} }
skipWhitespace() skipWhitespace()
if (readByte() != ';'.code.toByte()) throw IOException("expected ';'") if (readByte() != ';'.code.toByte()) throw IOException("expected ';'")
@ -228,12 +235,13 @@ fun BufferedSource.readPlainTextIdnaMappingTable(): SimpleIdnaMappingTable {
skipRestOfLine() skipRestOfLine()
result += Mapping( result +=
sourceCodePoint0.toInt(), Mapping(
sourceCodePoint1.toInt(), sourceCodePoint0.toInt(),
type, sourceCodePoint1.toInt(),
mappedTo.readByteString(), type,
) mappedTo.readByteString(),
)
} }
return SimpleIdnaMappingTable(result) return SimpleIdnaMappingTable(result)

View File

@ -34,8 +34,8 @@ class MappingTablesTest {
Mapping(0x0234, 0x0236, TYPE_VALID, ByteString.EMPTY), Mapping(0x0234, 0x0236, TYPE_VALID, ByteString.EMPTY),
Mapping(0x0237, 0x0239, TYPE_VALID, ByteString.EMPTY), Mapping(0x0237, 0x0239, TYPE_VALID, ByteString.EMPTY),
Mapping(0x023a, 0x023a, TYPE_MAPPED, "b".encodeUtf8()), Mapping(0x023a, 0x023a, TYPE_MAPPED, "b".encodeUtf8()),
) ),
) ),
).containsExactly( ).containsExactly(
Mapping(0x0232, 0x0232, TYPE_MAPPED, "a".encodeUtf8()), Mapping(0x0232, 0x0232, TYPE_MAPPED, "a".encodeUtf8()),
Mapping(0x0233, 0x0239, TYPE_VALID, ByteString.EMPTY), Mapping(0x0233, 0x0239, TYPE_VALID, ByteString.EMPTY),
@ -49,8 +49,8 @@ class MappingTablesTest {
listOf( listOf(
Mapping(0x0041, 0x0041, TYPE_MAPPED, "a".encodeUtf8()), Mapping(0x0041, 0x0041, TYPE_MAPPED, "a".encodeUtf8()),
Mapping(0x0042, 0x0042, TYPE_MAPPED, "b".encodeUtf8()), Mapping(0x0042, 0x0042, TYPE_MAPPED, "b".encodeUtf8()),
) ),
) ),
).containsExactly( ).containsExactly(
Mapping(0x0041, 0x0041, TYPE_MAPPED, "a".encodeUtf8()), Mapping(0x0041, 0x0041, TYPE_MAPPED, "a".encodeUtf8()),
Mapping(0x0042, 0x0042, TYPE_MAPPED, "b".encodeUtf8()), Mapping(0x0042, 0x0042, TYPE_MAPPED, "b".encodeUtf8()),
@ -62,8 +62,8 @@ class MappingTablesTest {
mergeAdjacentRanges( mergeAdjacentRanges(
listOf( listOf(
Mapping(0x0000, 0x002c, TYPE_DISALLOWED_STD3_VALID, ByteString.EMPTY), Mapping(0x0000, 0x002c, TYPE_DISALLOWED_STD3_VALID, ByteString.EMPTY),
) ),
) ),
).containsExactly( ).containsExactly(
Mapping(0x0000, 0x002c, TYPE_VALID, ByteString.EMPTY), Mapping(0x0000, 0x002c, TYPE_VALID, ByteString.EMPTY),
) )
@ -74,9 +74,9 @@ class MappingTablesTest {
mergeAdjacentRanges( mergeAdjacentRanges(
listOf( listOf(
Mapping(0x0000, 0x002c, TYPE_DISALLOWED_STD3_VALID, ByteString.EMPTY), Mapping(0x0000, 0x002c, TYPE_DISALLOWED_STD3_VALID, ByteString.EMPTY),
Mapping(0x002d, 0x002e, TYPE_VALID, ByteString.EMPTY) Mapping(0x002d, 0x002e, TYPE_VALID, ByteString.EMPTY),
) ),
) ),
).containsExactly( ).containsExactly(
Mapping(0x0000, 0x002e, TYPE_VALID, ByteString.EMPTY), Mapping(0x0000, 0x002e, TYPE_VALID, ByteString.EMPTY),
) )
@ -87,8 +87,8 @@ class MappingTablesTest {
withoutSectionSpans( withoutSectionSpans(
listOf( listOf(
Mapping(0x40000, 0x40180, TYPE_DISALLOWED, ByteString.EMPTY), Mapping(0x40000, 0x40180, TYPE_DISALLOWED, ByteString.EMPTY),
) ),
) ),
).containsExactly( ).containsExactly(
Mapping(0x40000, 0x4007f, TYPE_DISALLOWED, ByteString.EMPTY), Mapping(0x40000, 0x4007f, TYPE_DISALLOWED, ByteString.EMPTY),
Mapping(0x40080, 0x400ff, TYPE_DISALLOWED, ByteString.EMPTY), Mapping(0x40080, 0x400ff, TYPE_DISALLOWED, ByteString.EMPTY),
@ -103,8 +103,8 @@ class MappingTablesTest {
listOf( listOf(
Mapping(0x40000, 0x4007f, TYPE_DISALLOWED, ByteString.EMPTY), Mapping(0x40000, 0x4007f, TYPE_DISALLOWED, ByteString.EMPTY),
Mapping(0x40080, 0x400ff, TYPE_DISALLOWED, ByteString.EMPTY), Mapping(0x40080, 0x400ff, TYPE_DISALLOWED, ByteString.EMPTY),
) ),
) ),
).containsExactly( ).containsExactly(
Mapping(0x40000, 0x4007f, TYPE_DISALLOWED, ByteString.EMPTY), Mapping(0x40000, 0x4007f, TYPE_DISALLOWED, ByteString.EMPTY),
Mapping(0x40080, 0x400ff, TYPE_DISALLOWED, ByteString.EMPTY), Mapping(0x40080, 0x400ff, TYPE_DISALLOWED, ByteString.EMPTY),
@ -119,8 +119,8 @@ class MappingTablesTest {
InlineDelta(2, 5), InlineDelta(2, 5),
InlineDelta(3, 5), InlineDelta(3, 5),
MappedRange.External(4, "a".encodeUtf8()), MappedRange.External(4, "a".encodeUtf8()),
) ),
) ),
).containsExactly( ).containsExactly(
InlineDelta(1, 5), InlineDelta(1, 5),
MappedRange.External(4, "a".encodeUtf8()), MappedRange.External(4, "a".encodeUtf8()),
@ -134,8 +134,8 @@ class MappingTablesTest {
InlineDelta(1, 5), InlineDelta(1, 5),
InlineDelta(2, 5), InlineDelta(2, 5),
InlineDelta(3, 1), InlineDelta(3, 1),
) ),
) ),
).containsExactly( ).containsExactly(
InlineDelta(1, 5), InlineDelta(1, 5),
InlineDelta(3, 1), InlineDelta(3, 1),
@ -148,9 +148,9 @@ class MappingTablesTest {
mappingOf( mappingOf(
sourceCodePoint0 = 1, sourceCodePoint0 = 1,
sourceCodePoint1 = 1, sourceCodePoint1 = 1,
mappedToCodePoints = listOf(2) mappedToCodePoints = listOf(2),
) ),
) ),
).isEqualTo(InlineDelta(1, 1)) ).isEqualTo(InlineDelta(1, 1))
assertThat( assertThat(
@ -158,9 +158,9 @@ class MappingTablesTest {
mappingOf( mappingOf(
sourceCodePoint0 = 2, sourceCodePoint0 = 2,
sourceCodePoint1 = 2, sourceCodePoint1 = 2,
mappedToCodePoints = listOf(1) mappedToCodePoints = listOf(1),
) ),
) ),
).isEqualTo(InlineDelta(2, -1)) ).isEqualTo(InlineDelta(2, -1))
} }
@ -170,9 +170,9 @@ class MappingTablesTest {
mappingOf( mappingOf(
sourceCodePoint0 = 2, sourceCodePoint0 = 2,
sourceCodePoint1 = 3, sourceCodePoint1 = 3,
mappedToCodePoints = listOf(2) mappedToCodePoints = listOf(2),
) ),
) ),
).isEqualTo(null) ).isEqualTo(null)
} }
@ -182,9 +182,9 @@ class MappingTablesTest {
mappingOf( mappingOf(
sourceCodePoint0 = 1, sourceCodePoint0 = 1,
sourceCodePoint1 = 1, sourceCodePoint1 = 1,
mappedToCodePoints = listOf(2, 3) mappedToCodePoints = listOf(2, 3),
) ),
) ),
).isEqualTo(null) ).isEqualTo(null)
} }
@ -194,14 +194,14 @@ class MappingTablesTest {
mappingOf( mappingOf(
sourceCodePoint0 = 0, sourceCodePoint0 = 0,
sourceCodePoint1 = 0, sourceCodePoint1 = 0,
mappedToCodePoints = listOf((1 shl 18) - 1) mappedToCodePoints = listOf((1 shl 18) - 1),
) ),
) ),
).isEqualTo( ).isEqualTo(
InlineDelta( InlineDelta(
rangeStart = 0, rangeStart = 0,
codepointDelta = InlineDelta.MAX_VALUE codepointDelta = InlineDelta.MAX_VALUE,
) ),
) )
assertThat( assertThat(
@ -209,24 +209,26 @@ class MappingTablesTest {
mappingOf( mappingOf(
sourceCodePoint0 = 0, sourceCodePoint0 = 0,
sourceCodePoint1 = 0, sourceCodePoint1 = 0,
mappedToCodePoints = listOf(1 shl 18) mappedToCodePoints = listOf(1 shl 18),
) ),
) ),
).isEqualTo(null) ).isEqualTo(null)
} }
private fun mappingOf( private fun mappingOf(
sourceCodePoint0: Int, sourceCodePoint0: Int,
sourceCodePoint1: Int, sourceCodePoint1: Int,
mappedToCodePoints: List<Int> mappedToCodePoints: List<Int>,
): Mapping = Mapping( ): Mapping =
sourceCodePoint0 = sourceCodePoint0, Mapping(
sourceCodePoint1 = sourceCodePoint1, sourceCodePoint0 = sourceCodePoint0,
type = TYPE_MAPPED, sourceCodePoint1 = sourceCodePoint1,
mappedTo = Buffer().also { type = TYPE_MAPPED,
for (cp in mappedToCodePoints) { mappedTo =
it.writeUtf8CodePoint(cp) Buffer().also {
} for (cp in mappedToCodePoints) {
}.readByteString() it.writeUtf8CodePoint(cp)
) }
}.readByteString(),
)
} }

View File

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
@file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE") @file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")
package okhttp3.java.net.cookiejar
package okhttp3.java.net.cookiejar
import java.io.IOException import java.io.IOException
import java.net.CookieHandler import java.net.CookieHandler
@ -32,8 +32,10 @@ import okhttp3.internal.trimSubstring
/** A cookie jar that delegates to a [java.net.CookieHandler]. */ /** A cookie jar that delegates to a [java.net.CookieHandler]. */
class JavaNetCookieJar(private val cookieHandler: CookieHandler) : CookieJar { class JavaNetCookieJar(private val cookieHandler: CookieHandler) : CookieJar {
override fun saveFromResponse(
override fun saveFromResponse(url: HttpUrl, cookies: List<Cookie>) { url: HttpUrl,
cookies: List<Cookie>,
) {
val cookieStrings = mutableListOf<String>() val cookieStrings = mutableListOf<String>()
for (cookie in cookies) { for (cookie in cookies) {
cookieStrings.add(cookieToString(cookie, true)) cookieStrings.add(cookieToString(cookie, true))
@ -47,18 +49,20 @@ class JavaNetCookieJar(private val cookieHandler: CookieHandler) : CookieJar {
} }
override fun loadForRequest(url: HttpUrl): List<Cookie> { override fun loadForRequest(url: HttpUrl): List<Cookie> {
val cookieHeaders = try { val cookieHeaders =
// The RI passes all headers. We don't have 'em, so we don't pass 'em! try {
cookieHandler.get(url.toUri(), emptyMap<String, List<String>>()) // The RI passes all headers. We don't have 'em, so we don't pass 'em!
} catch (e: IOException) { cookieHandler.get(url.toUri(), emptyMap<String, List<String>>())
Platform.get().log("Loading cookies failed for " + url.resolve("/...")!!, WARN, e) } catch (e: IOException) {
return emptyList() Platform.get().log("Loading cookies failed for " + url.resolve("/...")!!, WARN, e)
} return emptyList()
}
var cookies: MutableList<Cookie>? = null var cookies: MutableList<Cookie>? = null
for ((key, value) in cookieHeaders) { for ((key, value) in cookieHeaders) {
if (("Cookie".equals(key, ignoreCase = true) || "Cookie2".equals(key, ignoreCase = true)) && if (("Cookie".equals(key, ignoreCase = true) || "Cookie2".equals(key, ignoreCase = true)) &&
value.isNotEmpty()) { value.isNotEmpty()
) {
for (header in value) { for (header in value) {
if (cookies == null) cookies = mutableListOf() if (cookies == null) cookies = mutableListOf()
cookies.addAll(decodeHeaderAsJavaNetCookies(url, header)) cookies.addAll(decodeHeaderAsJavaNetCookies(url, header))
@ -77,7 +81,10 @@ class JavaNetCookieJar(private val cookieHandler: CookieHandler) : CookieJar {
* Convert a request header to OkHttp's cookies via [HttpCookie]. That extra step handles * Convert a request header to OkHttp's cookies via [HttpCookie]. That extra step handles
* multiple cookies in a single request header, which [Cookie.parse] doesn't support. * multiple cookies in a single request header, which [Cookie.parse] doesn't support.
*/ */
private fun decodeHeaderAsJavaNetCookies(url: HttpUrl, header: String): List<Cookie> { private fun decodeHeaderAsJavaNetCookies(
url: HttpUrl,
header: String,
): List<Cookie> {
val result = mutableListOf<Cookie>() val result = mutableListOf<Cookie>()
var pos = 0 var pos = 0
val limit = header.length val limit = header.length
@ -92,22 +99,25 @@ class JavaNetCookieJar(private val cookieHandler: CookieHandler) : CookieJar {
} }
// We have either name=value or just a name. // We have either name=value or just a name.
var value = if (equalsSign < pairEnd) { var value =
header.trimSubstring(equalsSign + 1, pairEnd) if (equalsSign < pairEnd) {
} else { header.trimSubstring(equalsSign + 1, pairEnd)
"" } else {
} ""
}
// If the value is "quoted", drop the quotes. // If the value is "quoted", drop the quotes.
if (value.startsWith("\"") && value.endsWith("\"") && value.length >= 2) { if (value.startsWith("\"") && value.endsWith("\"") && value.length >= 2) {
value = value.substring(1, value.length - 1) value = value.substring(1, value.length - 1)
} }
result.add(Cookie.Builder() result.add(
Cookie.Builder()
.name(name) .name(name)
.value(value) .value(value)
.domain(url.host) .domain(url.host)
.build()) .build(),
)
pos = pairEnd + 1 pos = pairEnd + 1
} }
return result return result

View File

@ -14,6 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
@file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE") @file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")
package okhttp3.logging package okhttp3.logging
import java.io.IOException import java.io.IOException
@ -38,288 +39,295 @@ import okio.GzipSource
* The format of the logs created by this class should not be considered stable and may * The format of the logs created by this class should not be considered stable and may
* change slightly between releases. If you need a stable logging format, use your own interceptor. * change slightly between releases. If you need a stable logging format, use your own interceptor.
*/ */
class HttpLoggingInterceptor @JvmOverloads constructor( class HttpLoggingInterceptor
private val logger: Logger = Logger.DEFAULT @JvmOverloads
) : Interceptor { constructor(
private val logger: Logger = Logger.DEFAULT,
) : Interceptor {
@Volatile private var headersToRedact = emptySet<String>()
@Volatile private var headersToRedact = emptySet<String>() @set:JvmName("level")
@Volatile
var level = Level.NONE
@set:JvmName("level") enum class Level {
@Volatile var level = Level.NONE /** No logs. */
NONE,
enum class Level { /**
/** No logs. */ * Logs request and response lines.
NONE, *
* Example:
* ```
* --> POST /greeting http/1.1 (3-byte body)
*
* <-- 200 OK (22ms, 6-byte body)
* ```
*/
BASIC,
/** /**
* Logs request and response lines. * Logs request and response lines and their respective headers.
* *
* Example: * Example:
* ``` * ```
* --> POST /greeting http/1.1 (3-byte body) * --> POST /greeting http/1.1
* * Host: example.com
* <-- 200 OK (22ms, 6-byte body) * Content-Type: plain/text
* ``` * Content-Length: 3
*/ * --> END POST
BASIC, *
* <-- 200 OK (22ms)
* Content-Type: plain/text
* Content-Length: 6
* <-- END HTTP
* ```
*/
HEADERS,
/** /**
* Logs request and response lines and their respective headers. * Logs request and response lines and their respective headers and bodies (if present).
* *
* Example: * Example:
* ``` * ```
* --> POST /greeting http/1.1 * --> POST /greeting http/1.1
* Host: example.com * Host: example.com
* Content-Type: plain/text * Content-Type: plain/text
* Content-Length: 3 * Content-Length: 3
* --> END POST *
* * Hi?
* <-- 200 OK (22ms) * --> END POST
* Content-Type: plain/text *
* Content-Length: 6 * <-- 200 OK (22ms)
* <-- END HTTP * Content-Type: plain/text
* ``` * Content-Length: 6
*/ *
HEADERS, * Hello!
* <-- END HTTP
* ```
*/
BODY,
}
/** fun interface Logger {
* Logs request and response lines and their respective headers and bodies (if present). fun log(message: String)
*
* Example:
* ```
* --> POST /greeting http/1.1
* Host: example.com
* Content-Type: plain/text
* Content-Length: 3
*
* Hi?
* --> END POST
*
* <-- 200 OK (22ms)
* Content-Type: plain/text
* Content-Length: 6
*
* Hello!
* <-- END HTTP
* ```
*/
BODY
}
fun interface Logger { companion object {
fun log(message: String) /** A [Logger] defaults output appropriate for the current platform. */
@JvmField
val DEFAULT: Logger = DefaultLogger()
companion object { private class DefaultLogger : Logger {
/** A [Logger] defaults output appropriate for the current platform. */ override fun log(message: String) {
@JvmField Platform.get().log(message)
val DEFAULT: Logger = DefaultLogger() }
private class DefaultLogger : Logger {
override fun log(message: String) {
Platform.get().log(message)
} }
} }
} }
}
fun redactHeader(name: String) { fun redactHeader(name: String) {
val newHeadersToRedact = TreeSet(String.CASE_INSENSITIVE_ORDER) val newHeadersToRedact = TreeSet(String.CASE_INSENSITIVE_ORDER)
newHeadersToRedact += headersToRedact newHeadersToRedact += headersToRedact
newHeadersToRedact += name newHeadersToRedact += name
headersToRedact = newHeadersToRedact headersToRedact = newHeadersToRedact
} }
/** /**
* Sets the level and returns this. * Sets the level and returns this.
* *
* This was deprecated in OkHttp 4.0 in favor of the [level] val. In OkHttp 4.3 it is * This was deprecated in OkHttp 4.0 in favor of the [level] val. In OkHttp 4.3 it is
* un-deprecated because Java callers can't chain when assigning Kotlin vals. (The getter remains * un-deprecated because Java callers can't chain when assigning Kotlin vals. (The getter remains
* deprecated). * deprecated).
*/ */
fun setLevel(level: Level) = apply { fun setLevel(level: Level) =
this.level = level apply {
} this.level = level
}
@JvmName("-deprecated_level") @JvmName("-deprecated_level")
@Deprecated( @Deprecated(
message = "moved to var", message = "moved to var",
replaceWith = ReplaceWith(expression = "level"), replaceWith = ReplaceWith(expression = "level"),
level = DeprecationLevel.ERROR level = DeprecationLevel.ERROR,
)
fun getLevel(): Level = level
@Throws(IOException::class)
override fun intercept(chain: Interceptor.Chain): Response {
val level = this.level
val request = chain.request()
if (level == Level.NONE) {
return chain.proceed(request)
}
val logBody = level == Level.BODY
val logHeaders = logBody || level == Level.HEADERS
val requestBody = request.body
val connection = chain.connection()
var requestStartMessage =
("--> ${request.method} ${request.url}${if (connection != null) " " + connection.protocol() else ""}")
if (!logHeaders && requestBody != null) {
requestStartMessage += " (${requestBody.contentLength()}-byte body)"
}
logger.log(requestStartMessage)
if (logHeaders) {
val headers = request.headers
if (requestBody != null) {
// Request body headers are only present when installed as a network interceptor. When not
// already present, force them to be included (if available) so their values are known.
requestBody.contentType()?.let {
if (headers["Content-Type"] == null) {
logger.log("Content-Type: $it")
}
}
if (requestBody.contentLength() != -1L) {
if (headers["Content-Length"] == null) {
logger.log("Content-Length: ${requestBody.contentLength()}")
}
}
}
for (i in 0 until headers.size) {
logHeader(headers, i)
}
if (!logBody || requestBody == null) {
logger.log("--> END ${request.method}")
} else if (bodyHasUnknownEncoding(request.headers)) {
logger.log("--> END ${request.method} (encoded body omitted)")
} else if (requestBody.isDuplex()) {
logger.log("--> END ${request.method} (duplex request body omitted)")
} else if (requestBody.isOneShot()) {
logger.log("--> END ${request.method} (one-shot body omitted)")
} else {
var buffer = Buffer()
requestBody.writeTo(buffer)
var gzippedLength: Long? = null
if ("gzip".equals(headers["Content-Encoding"], ignoreCase = true)) {
gzippedLength = buffer.size
GzipSource(buffer).use { gzippedResponseBody ->
buffer = Buffer()
buffer.writeAll(gzippedResponseBody)
}
}
val charset: Charset = requestBody.contentType().charsetOrUtf8()
logger.log("")
if (!buffer.isProbablyUtf8()) {
logger.log(
"--> END ${request.method} (binary ${requestBody.contentLength()}-byte body omitted)"
)
} else if (gzippedLength != null) {
logger.log("--> END ${request.method} (${buffer.size}-byte, $gzippedLength-gzipped-byte body)")
} else {
logger.log(buffer.readString(charset))
logger.log("--> END ${request.method} (${requestBody.contentLength()}-byte body)")
}
}
}
val startNs = System.nanoTime()
val response: Response
try {
response = chain.proceed(request)
} catch (e: Exception) {
logger.log("<-- HTTP FAILED: $e")
throw e
}
val tookMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNs)
val responseBody = response.body!!
val contentLength = responseBody.contentLength()
val bodySize = if (contentLength != -1L) "$contentLength-byte" else "unknown-length"
logger.log(
buildString {
append("<-- ${response.code}")
if (response.message.isNotEmpty()) append(" ${response.message}")
append(" ${response.request.url} (${tookMs}ms")
if (!logHeaders) append(", $bodySize body")
append(")")
}
) )
fun getLevel(): Level = level
if (logHeaders) { @Throws(IOException::class)
val headers = response.headers override fun intercept(chain: Interceptor.Chain): Response {
for (i in 0 until headers.size) { val level = this.level
logHeader(headers, i)
val request = chain.request()
if (level == Level.NONE) {
return chain.proceed(request)
} }
if (!logBody || !response.promisesBody()) { val logBody = level == Level.BODY
logger.log("<-- END HTTP") val logHeaders = logBody || level == Level.HEADERS
} else if (bodyHasUnknownEncoding(response.headers)) {
logger.log("<-- END HTTP (encoded body omitted)")
} else if (bodyIsStreaming(response)) {
logger.log("<-- END HTTP (streaming)")
} else {
val source = responseBody.source()
source.request(Long.MAX_VALUE) // Buffer the entire body.
val totalMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNs) val requestBody = request.body
var buffer = source.buffer val connection = chain.connection()
var requestStartMessage =
("--> ${request.method} ${request.url}${if (connection != null) " " + connection.protocol() else ""}")
if (!logHeaders && requestBody != null) {
requestStartMessage += " (${requestBody.contentLength()}-byte body)"
}
logger.log(requestStartMessage)
var gzippedLength: Long? = null if (logHeaders) {
if ("gzip".equals(headers["Content-Encoding"], ignoreCase = true)) { val headers = request.headers
gzippedLength = buffer.size
GzipSource(buffer.clone()).use { gzippedResponseBody -> if (requestBody != null) {
buffer = Buffer() // Request body headers are only present when installed as a network interceptor. When not
buffer.writeAll(gzippedResponseBody) // already present, force them to be included (if available) so their values are known.
requestBody.contentType()?.let {
if (headers["Content-Type"] == null) {
logger.log("Content-Type: $it")
}
}
if (requestBody.contentLength() != -1L) {
if (headers["Content-Length"] == null) {
logger.log("Content-Length: ${requestBody.contentLength()}")
}
} }
} }
val charset: Charset = responseBody.contentType().charsetOrUtf8() for (i in 0 until headers.size) {
logHeader(headers, i)
if (!buffer.isProbablyUtf8()) {
logger.log("")
logger.log("<-- END HTTP (${totalMs}ms, binary ${buffer.size}-byte body omitted)")
return response
} }
if (contentLength != 0L) { if (!logBody || requestBody == null) {
logger.log("") logger.log("--> END ${request.method}")
logger.log(buffer.clone().readString(charset)) } else if (bodyHasUnknownEncoding(request.headers)) {
} logger.log("--> END ${request.method} (encoded body omitted)")
} else if (requestBody.isDuplex()) {
logger.log("--> END ${request.method} (duplex request body omitted)")
} else if (requestBody.isOneShot()) {
logger.log("--> END ${request.method} (one-shot body omitted)")
} else {
var buffer = Buffer()
requestBody.writeTo(buffer)
logger.log( var gzippedLength: Long? = null
buildString { if ("gzip".equals(headers["Content-Encoding"], ignoreCase = true)) {
append("<-- END HTTP (${totalMs}ms, ${buffer.size}-byte") gzippedLength = buffer.size
if (gzippedLength != null) append(", $gzippedLength-gzipped-byte") GzipSource(buffer).use { gzippedResponseBody ->
append(" body)") buffer = Buffer()
buffer.writeAll(gzippedResponseBody)
}
} }
)
val charset: Charset = requestBody.contentType().charsetOrUtf8()
logger.log("")
if (!buffer.isProbablyUtf8()) {
logger.log(
"--> END ${request.method} (binary ${requestBody.contentLength()}-byte body omitted)",
)
} else if (gzippedLength != null) {
logger.log("--> END ${request.method} (${buffer.size}-byte, $gzippedLength-gzipped-byte body)")
} else {
logger.log(buffer.readString(charset))
logger.log("--> END ${request.method} (${requestBody.contentLength()}-byte body)")
}
}
} }
val startNs = System.nanoTime()
val response: Response
try {
response = chain.proceed(request)
} catch (e: Exception) {
logger.log("<-- HTTP FAILED: $e")
throw e
}
val tookMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNs)
val responseBody = response.body!!
val contentLength = responseBody.contentLength()
val bodySize = if (contentLength != -1L) "$contentLength-byte" else "unknown-length"
logger.log(
buildString {
append("<-- ${response.code}")
if (response.message.isNotEmpty()) append(" ${response.message}")
append(" ${response.request.url} (${tookMs}ms")
if (!logHeaders) append(", $bodySize body")
append(")")
},
)
if (logHeaders) {
val headers = response.headers
for (i in 0 until headers.size) {
logHeader(headers, i)
}
if (!logBody || !response.promisesBody()) {
logger.log("<-- END HTTP")
} else if (bodyHasUnknownEncoding(response.headers)) {
logger.log("<-- END HTTP (encoded body omitted)")
} else if (bodyIsStreaming(response)) {
logger.log("<-- END HTTP (streaming)")
} else {
val source = responseBody.source()
source.request(Long.MAX_VALUE) // Buffer the entire body.
val totalMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNs)
var buffer = source.buffer
var gzippedLength: Long? = null
if ("gzip".equals(headers["Content-Encoding"], ignoreCase = true)) {
gzippedLength = buffer.size
GzipSource(buffer.clone()).use { gzippedResponseBody ->
buffer = Buffer()
buffer.writeAll(gzippedResponseBody)
}
}
val charset: Charset = responseBody.contentType().charsetOrUtf8()
if (!buffer.isProbablyUtf8()) {
logger.log("")
logger.log("<-- END HTTP (${totalMs}ms, binary ${buffer.size}-byte body omitted)")
return response
}
if (contentLength != 0L) {
logger.log("")
logger.log(buffer.clone().readString(charset))
}
logger.log(
buildString {
append("<-- END HTTP (${totalMs}ms, ${buffer.size}-byte")
if (gzippedLength != null) append(", $gzippedLength-gzipped-byte")
append(" body)")
},
)
}
}
return response
} }
return response private fun logHeader(
} headers: Headers,
i: Int,
) {
val value = if (headers.name(i) in headersToRedact) "██" else headers.value(i)
logger.log(headers.name(i) + ": " + value)
}
private fun logHeader(headers: Headers, i: Int) { private fun bodyHasUnknownEncoding(headers: Headers): Boolean {
val value = if (headers.name(i) in headersToRedact) "██" else headers.value(i) val contentEncoding = headers["Content-Encoding"] ?: return false
logger.log(headers.name(i) + ": " + value) return !contentEncoding.equals("identity", ignoreCase = true) &&
}
private fun bodyHasUnknownEncoding(headers: Headers): Boolean {
val contentEncoding = headers["Content-Encoding"] ?: return false
return !contentEncoding.equals("identity", ignoreCase = true) &&
!contentEncoding.equals("gzip", ignoreCase = true) !contentEncoding.equals("gzip", ignoreCase = true)
} }
private fun bodyIsStreaming(response: Response): Boolean { private fun bodyIsStreaming(response: Response): Boolean {
val contentType = response.body.contentType() val contentType = response.body.contentType()
return contentType != null && contentType.type == "text" && contentType.subtype == "event-stream" return contentType != null && contentType.type == "text" && contentType.subtype == "event-stream"
}
} }
}

View File

@ -38,7 +38,7 @@ import okhttp3.Response
* slightly between releases. If you need a stable logging format, use your own event listener. * slightly between releases. If you need a stable logging format, use your own event listener.
*/ */
class LoggingEventListener private constructor( class LoggingEventListener private constructor(
private val logger: HttpLoggingInterceptor.Logger private val logger: HttpLoggingInterceptor.Logger,
) : EventListener() { ) : EventListener() {
private var startNs: Long = 0 private var startNs: Long = 0
@ -48,23 +48,41 @@ class LoggingEventListener private constructor(
logWithTime("callStart: ${call.request()}") logWithTime("callStart: ${call.request()}")
} }
override fun proxySelectStart(call: Call, url: HttpUrl) { override fun proxySelectStart(
call: Call,
url: HttpUrl,
) {
logWithTime("proxySelectStart: $url") logWithTime("proxySelectStart: $url")
} }
override fun proxySelectEnd(call: Call, url: HttpUrl, proxies: List<Proxy>) { override fun proxySelectEnd(
call: Call,
url: HttpUrl,
proxies: List<Proxy>,
) {
logWithTime("proxySelectEnd: $proxies") logWithTime("proxySelectEnd: $proxies")
} }
override fun dnsStart(call: Call, domainName: String) { override fun dnsStart(
call: Call,
domainName: String,
) {
logWithTime("dnsStart: $domainName") logWithTime("dnsStart: $domainName")
} }
override fun dnsEnd(call: Call, domainName: String, inetAddressList: List<InetAddress>) { override fun dnsEnd(
call: Call,
domainName: String,
inetAddressList: List<InetAddress>,
) {
logWithTime("dnsEnd: $inetAddressList") logWithTime("dnsEnd: $inetAddressList")
} }
override fun connectStart(call: Call, inetSocketAddress: InetSocketAddress, proxy: Proxy) { override fun connectStart(
call: Call,
inetSocketAddress: InetSocketAddress,
proxy: Proxy,
) {
logWithTime("connectStart: $inetSocketAddress $proxy") logWithTime("connectStart: $inetSocketAddress $proxy")
} }
@ -72,7 +90,10 @@ class LoggingEventListener private constructor(
logWithTime("secureConnectStart") logWithTime("secureConnectStart")
} }
override fun secureConnectEnd(call: Call, handshake: Handshake?) { override fun secureConnectEnd(
call: Call,
handshake: Handshake?,
) {
logWithTime("secureConnectEnd: $handshake") logWithTime("secureConnectEnd: $handshake")
} }
@ -80,7 +101,7 @@ class LoggingEventListener private constructor(
call: Call, call: Call,
inetSocketAddress: InetSocketAddress, inetSocketAddress: InetSocketAddress,
proxy: Proxy, proxy: Proxy,
protocol: Protocol? protocol: Protocol?,
) { ) {
logWithTime("connectEnd: $protocol") logWithTime("connectEnd: $protocol")
} }
@ -90,16 +111,22 @@ class LoggingEventListener private constructor(
inetSocketAddress: InetSocketAddress, inetSocketAddress: InetSocketAddress,
proxy: Proxy, proxy: Proxy,
protocol: Protocol?, protocol: Protocol?,
ioe: IOException ioe: IOException,
) { ) {
logWithTime("connectFailed: $protocol $ioe") logWithTime("connectFailed: $protocol $ioe")
} }
override fun connectionAcquired(call: Call, connection: Connection) { override fun connectionAcquired(
call: Call,
connection: Connection,
) {
logWithTime("connectionAcquired: $connection") logWithTime("connectionAcquired: $connection")
} }
override fun connectionReleased(call: Call, connection: Connection) { override fun connectionReleased(
call: Call,
connection: Connection,
) {
logWithTime("connectionReleased") logWithTime("connectionReleased")
} }
@ -107,7 +134,10 @@ class LoggingEventListener private constructor(
logWithTime("requestHeadersStart") logWithTime("requestHeadersStart")
} }
override fun requestHeadersEnd(call: Call, request: Request) { override fun requestHeadersEnd(
call: Call,
request: Request,
) {
logWithTime("requestHeadersEnd") logWithTime("requestHeadersEnd")
} }
@ -115,11 +145,17 @@ class LoggingEventListener private constructor(
logWithTime("requestBodyStart") logWithTime("requestBodyStart")
} }
override fun requestBodyEnd(call: Call, byteCount: Long) { override fun requestBodyEnd(
call: Call,
byteCount: Long,
) {
logWithTime("requestBodyEnd: byteCount=$byteCount") logWithTime("requestBodyEnd: byteCount=$byteCount")
} }
override fun requestFailed(call: Call, ioe: IOException) { override fun requestFailed(
call: Call,
ioe: IOException,
) {
logWithTime("requestFailed: $ioe") logWithTime("requestFailed: $ioe")
} }
@ -127,7 +163,10 @@ class LoggingEventListener private constructor(
logWithTime("responseHeadersStart") logWithTime("responseHeadersStart")
} }
override fun responseHeadersEnd(call: Call, response: Response) { override fun responseHeadersEnd(
call: Call,
response: Response,
) {
logWithTime("responseHeadersEnd: $response") logWithTime("responseHeadersEnd: $response")
} }
@ -135,11 +174,17 @@ class LoggingEventListener private constructor(
logWithTime("responseBodyStart") logWithTime("responseBodyStart")
} }
override fun responseBodyEnd(call: Call, byteCount: Long) { override fun responseBodyEnd(
call: Call,
byteCount: Long,
) {
logWithTime("responseBodyEnd: byteCount=$byteCount") logWithTime("responseBodyEnd: byteCount=$byteCount")
} }
override fun responseFailed(call: Call, ioe: IOException) { override fun responseFailed(
call: Call,
ioe: IOException,
) {
logWithTime("responseFailed: $ioe") logWithTime("responseFailed: $ioe")
} }
@ -147,7 +192,10 @@ class LoggingEventListener private constructor(
logWithTime("callEnd") logWithTime("callEnd")
} }
override fun callFailed(call: Call, ioe: IOException) { override fun callFailed(
call: Call,
ioe: IOException,
) {
logWithTime("callFailed: $ioe") logWithTime("callFailed: $ioe")
} }
@ -155,11 +203,17 @@ class LoggingEventListener private constructor(
logWithTime("canceled") logWithTime("canceled")
} }
override fun satisfactionFailure(call: Call, response: Response) { override fun satisfactionFailure(
call: Call,
response: Response,
) {
logWithTime("satisfactionFailure: $response") logWithTime("satisfactionFailure: $response")
} }
override fun cacheHit(call: Call, response: Response) { override fun cacheHit(
call: Call,
response: Response,
) {
logWithTime("cacheHit: $response") logWithTime("cacheHit: $response")
} }
@ -167,7 +221,10 @@ class LoggingEventListener private constructor(
logWithTime("cacheMiss") logWithTime("cacheMiss")
} }
override fun cacheConditionalHit(call: Call, cachedResponse: Response) { override fun cacheConditionalHit(
call: Call,
cachedResponse: Response,
) {
logWithTime("cacheConditionalHit: $cachedResponse") logWithTime("cacheConditionalHit: $cachedResponse")
} }
@ -176,9 +233,11 @@ class LoggingEventListener private constructor(
logger.log("[$timeMs ms] $message") logger.log("[$timeMs ms] $message")
} }
open class Factory @JvmOverloads constructor( open class Factory
private val logger: HttpLoggingInterceptor.Logger = HttpLoggingInterceptor.Logger.DEFAULT @JvmOverloads
) : EventListener.Factory { constructor(
override fun create(call: Call): EventListener = LoggingEventListener(logger) private val logger: HttpLoggingInterceptor.Logger = HttpLoggingInterceptor.Logger.DEFAULT,
} ) : EventListener.Factory {
override fun create(call: Call): EventListener = LoggingEventListener(logger)
}
} }

View File

@ -72,21 +72,24 @@ class HttpLoggingInterceptorTest {
@BeforeEach @BeforeEach
fun setUp(server: MockWebServer) { fun setUp(server: MockWebServer) {
this.server = server this.server = server
client = OkHttpClient.Builder() client =
.addNetworkInterceptor(Interceptor { chain -> OkHttpClient.Builder()
when { .addNetworkInterceptor(
extraNetworkInterceptor != null -> extraNetworkInterceptor!!.intercept(chain) Interceptor { chain ->
else -> chain.proceed(chain.request()) when {
} extraNetworkInterceptor != null -> extraNetworkInterceptor!!.intercept(chain)
}) else -> chain.proceed(chain.request())
.addNetworkInterceptor(networkInterceptor) }
.addInterceptor(applicationInterceptor) },
.sslSocketFactory( )
handshakeCertificates.sslSocketFactory(), .addNetworkInterceptor(networkInterceptor)
handshakeCertificates.trustManager, .addInterceptor(applicationInterceptor)
) .sslSocketFactory(
.hostnameVerifier(hostnameVerifier) handshakeCertificates.sslSocketFactory(),
.build() handshakeCertificates.trustManager,
)
.hostnameVerifier(hostnameVerifier)
.build()
host = "${server.hostName}:${server.port}" host = "${server.hostName}:${server.port}"
url = server.url("/") url = server.url("/")
} }
@ -153,7 +156,7 @@ class HttpLoggingInterceptorTest {
MockResponse.Builder() MockResponse.Builder()
.body("Hello!") .body("Hello!")
.setHeader("Content-Type", PLAIN) .setHeader("Content-Type", PLAIN)
.build() .build(),
) )
val response = client.newCall(request().build()).execute() val response = client.newCall(request().build()).execute()
response.body.close() response.body.close()
@ -174,7 +177,7 @@ class HttpLoggingInterceptorTest {
MockResponse.Builder() MockResponse.Builder()
.chunkedBody("Hello!", 2) .chunkedBody("Hello!", 2)
.setHeader("Content-Type", PLAIN) .setHeader("Content-Type", PLAIN)
.build() .build(),
) )
val response = client.newCall(request().build()).execute() val response = client.newCall(request().build()).execute()
response.body.close() response.body.close()
@ -278,13 +281,14 @@ class HttpLoggingInterceptorTest {
fun headersPostNoLength() { fun headersPostNoLength() {
setLevel(Level.HEADERS) setLevel(Level.HEADERS)
server.enqueue(MockResponse()) server.enqueue(MockResponse())
val body: RequestBody = object : RequestBody() { val body: RequestBody =
override fun contentType() = PLAIN object : RequestBody() {
override fun contentType() = PLAIN
override fun writeTo(sink: BufferedSink) { override fun writeTo(sink: BufferedSink) {
sink.writeUtf8("Hi!") sink.writeUtf8("Hi!")
}
} }
}
val response = client.newCall(request().post(body).build()).execute() val response = client.newCall(request().post(body).build()).execute()
response.body.close() response.body.close()
applicationLogs applicationLogs
@ -313,20 +317,21 @@ class HttpLoggingInterceptorTest {
@Test @Test
fun headersPostWithHeaderOverrides() { fun headersPostWithHeaderOverrides() {
setLevel(Level.HEADERS) setLevel(Level.HEADERS)
extraNetworkInterceptor = Interceptor { chain: Interceptor.Chain -> extraNetworkInterceptor =
chain.proceed( Interceptor { chain: Interceptor.Chain ->
chain.request() chain.proceed(
.newBuilder() chain.request()
.header("Content-Length", "2") .newBuilder()
.header("Content-Type", "text/plain-ish") .header("Content-Length", "2")
.build() .header("Content-Type", "text/plain-ish")
) .build(),
} )
}
server.enqueue(MockResponse()) server.enqueue(MockResponse())
client.newCall( client.newCall(
request() request()
.post("Hi?".toRequestBody(PLAIN)) .post("Hi?".toRequestBody(PLAIN))
.build() .build(),
).execute() ).execute()
applicationLogs applicationLogs
.assertLogEqual("--> POST $url") .assertLogEqual("--> POST $url")
@ -359,7 +364,7 @@ class HttpLoggingInterceptorTest {
MockResponse.Builder() MockResponse.Builder()
.body("Hello!") .body("Hello!")
.setHeader("Content-Type", PLAIN) .setHeader("Content-Type", PLAIN)
.build() .build(),
) )
val response = client.newCall(request().build()).execute() val response = client.newCall(request().build()).execute()
response.body.close() response.body.close()
@ -427,7 +432,7 @@ class HttpLoggingInterceptorTest {
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse.Builder()
.status("HTTP/1.1 $code No Content") .status("HTTP/1.1 $code No Content")
.build() .build(),
) )
val response = client.newCall(request().build()).execute() val response = client.newCall(request().build()).execute()
response.body.close() response.body.close()
@ -493,7 +498,7 @@ class HttpLoggingInterceptorTest {
MockResponse.Builder() MockResponse.Builder()
.body("Hello!") .body("Hello!")
.setHeader("Content-Type", PLAIN) .setHeader("Content-Type", PLAIN)
.build() .build(),
) )
val response = client.newCall(request().build()).execute() val response = client.newCall(request().build()).execute()
response.body.close() response.body.close()
@ -530,7 +535,7 @@ class HttpLoggingInterceptorTest {
MockResponse.Builder() MockResponse.Builder()
.chunkedBody("Hello!", 2) .chunkedBody("Hello!", 2)
.setHeader("Content-Type", PLAIN) .setHeader("Content-Type", PLAIN)
.build() .build(),
) )
val response = client.newCall(request().build()).execute() val response = client.newCall(request().build()).execute()
response.body.close() response.body.close()
@ -567,14 +572,15 @@ class HttpLoggingInterceptorTest {
MockResponse.Builder() MockResponse.Builder()
.setHeader("Content-Type", PLAIN) .setHeader("Content-Type", PLAIN)
.body(Buffer().writeUtf8("Uncompressed")) .body(Buffer().writeUtf8("Uncompressed"))
.build() .build(),
) )
val response = client.newCall( val response =
request() client.newCall(
.addHeader("Content-Encoding", "gzip") request()
.post("Uncompressed".toRequestBody().gzip()) .addHeader("Content-Encoding", "gzip")
.build() .post("Uncompressed".toRequestBody().gzip())
).execute() .build(),
).execute()
val responseBody = response.body val responseBody = response.body
assertThat(responseBody.string(), "Expected response body to be valid") assertThat(responseBody.string(), "Expected response body to be valid")
.isEqualTo("Uncompressed") .isEqualTo("Uncompressed")
@ -606,7 +612,7 @@ class HttpLoggingInterceptorTest {
.setHeader("Content-Encoding", "gzip") .setHeader("Content-Encoding", "gzip")
.setHeader("Content-Type", PLAIN) .setHeader("Content-Type", PLAIN)
.body(Buffer().write("H4sIAAAAAAAAAPNIzcnJ11HwQKIAdyO+9hMAAAA=".decodeBase64()!!)) .body(Buffer().write("H4sIAAAAAAAAAPNIzcnJ11HwQKIAdyO+9hMAAAA=".decodeBase64()!!))
.build() .build(),
) )
val response = client.newCall(request().build()).execute() val response = client.newCall(request().build()).execute()
val responseBody = response.body val responseBody = response.body
@ -647,7 +653,7 @@ class HttpLoggingInterceptorTest {
.setHeader("Content-Encoding", "br") .setHeader("Content-Encoding", "br")
.setHeader("Content-Type", PLAIN) .setHeader("Content-Type", PLAIN)
.body(Buffer().write("iwmASGVsbG8sIEhlbGxvLCBIZWxsbwoD".decodeBase64()!!)) .body(Buffer().write("iwmASGVsbG8sIEhlbGxvLCBIZWxsbwoD".decodeBase64()!!))
.build() .build(),
) )
val response = client.newCall(request().build()).execute() val response = client.newCall(request().build()).execute()
response.body.close() response.body.close()
@ -693,9 +699,10 @@ class HttpLoggingInterceptorTest {
|data: 113411 |data: 113411
| |
| |
""".trimMargin(), 8 """.trimMargin(),
8,
) )
.build() .build(),
) )
val response = client.newCall(request().build()).execute() val response = client.newCall(request().build()).execute()
response.body.close() response.body.close()
@ -728,7 +735,7 @@ class HttpLoggingInterceptorTest {
MockResponse.Builder() MockResponse.Builder()
.setHeader("Content-Type", "text/html; charset=0") .setHeader("Content-Type", "text/html; charset=0")
.body("Body with unknown charset") .body("Body with unknown charset")
.build() .build(),
) )
val response = client.newCall(request().build()).execute() val response = client.newCall(request().build()).execute()
response.body.close() response.body.close()
@ -774,7 +781,7 @@ class HttpLoggingInterceptorTest {
MockResponse.Builder() MockResponse.Builder()
.body(buffer) .body(buffer)
.setHeader("Content-Type", "image/png; charset=utf-8") .setHeader("Content-Type", "image/png; charset=utf-8")
.build() .build(),
) )
val response = client.newCall(request().build()).execute() val response = client.newCall(request().build()).execute()
response.body.close() response.body.close()
@ -805,10 +812,11 @@ class HttpLoggingInterceptorTest {
@Test @Test
fun connectFail() { fun connectFail() {
setLevel(Level.BASIC) setLevel(Level.BASIC)
client = OkHttpClient.Builder() client =
.dns { hostname: String? -> throw UnknownHostException("reason") } OkHttpClient.Builder()
.addInterceptor(applicationInterceptor) .dns { hostname: String? -> throw UnknownHostException("reason") }
.build() .addInterceptor(applicationInterceptor)
.build()
try { try {
client.newCall(request().build()).execute() client.newCall(request().build()).execute()
fail<Any>() fail<Any>()
@ -840,31 +848,35 @@ class HttpLoggingInterceptorTest {
@Test @Test
fun headersAreRedacted() { fun headersAreRedacted() {
val networkInterceptor = HttpLoggingInterceptor(networkLogs).setLevel( val networkInterceptor =
Level.HEADERS HttpLoggingInterceptor(networkLogs).setLevel(
) Level.HEADERS,
)
networkInterceptor.redactHeader("sEnSiTiVe") networkInterceptor.redactHeader("sEnSiTiVe")
val applicationInterceptor = HttpLoggingInterceptor(applicationLogs).setLevel( val applicationInterceptor =
Level.HEADERS HttpLoggingInterceptor(applicationLogs).setLevel(
) Level.HEADERS,
)
applicationInterceptor.redactHeader("sEnSiTiVe") applicationInterceptor.redactHeader("sEnSiTiVe")
client = OkHttpClient.Builder() client =
.addNetworkInterceptor(networkInterceptor) OkHttpClient.Builder()
.addInterceptor(applicationInterceptor) .addNetworkInterceptor(networkInterceptor)
.build() .addInterceptor(applicationInterceptor)
.build()
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse.Builder()
.addHeader("SeNsItIvE", "Value").addHeader("Not-Sensitive", "Value") .addHeader("SeNsItIvE", "Value").addHeader("Not-Sensitive", "Value")
.build() .build(),
) )
val response = client val response =
.newCall( client
request() .newCall(
.addHeader("SeNsItIvE", "Value") request()
.addHeader("Not-Sensitive", "Value") .addHeader("SeNsItIvE", "Value")
.build() .addHeader("Not-Sensitive", "Value")
) .build(),
.execute() )
.execute()
response.body.close() response.body.close()
applicationLogs applicationLogs
.assertLogEqual("--> GET $url") .assertLogEqual("--> GET $url")
@ -903,25 +915,27 @@ class HttpLoggingInterceptorTest {
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse.Builder()
.body("Hello response!") .body("Hello response!")
.build() .build(),
) )
val asyncRequestBody: RequestBody = object : RequestBody() { val asyncRequestBody: RequestBody =
override fun contentType(): MediaType? { object : RequestBody() {
return null override fun contentType(): MediaType? {
} return null
}
override fun writeTo(sink: BufferedSink) { override fun writeTo(sink: BufferedSink) {
sink.writeUtf8("Hello request!") sink.writeUtf8("Hello request!")
sink.close() sink.close()
} }
override fun isDuplex(): Boolean { override fun isDuplex(): Boolean {
return true return true
}
} }
} val request =
val request = request() request()
.post(asyncRequestBody) .post(asyncRequestBody)
.build() .build()
val response = client.newCall(request).execute() val response = client.newCall(request).execute()
Assumptions.assumeTrue(response.protocol == Protocol.HTTP_2) Assumptions.assumeTrue(response.protocol == Protocol.HTTP_2)
assertThat(response.body.string()).isEqualTo("Hello response!") assertThat(response.body.string()).isEqualTo("Hello response!")
@ -943,25 +957,27 @@ class HttpLoggingInterceptorTest {
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse.Builder()
.body("Hello response!") .body("Hello response!")
.build() .build(),
) )
val asyncRequestBody: RequestBody = object : RequestBody() { val asyncRequestBody: RequestBody =
var counter = 0 object : RequestBody() {
var counter = 0
override fun contentType() = null override fun contentType() = null
override fun writeTo(sink: BufferedSink) { override fun writeTo(sink: BufferedSink) {
counter++ counter++
assertThat(counter).isLessThanOrEqualTo(1) assertThat(counter).isLessThanOrEqualTo(1)
sink.writeUtf8("Hello request!") sink.writeUtf8("Hello request!")
sink.close() sink.close()
}
override fun isOneShot() = true
} }
val request =
override fun isOneShot() = true request()
} .post(asyncRequestBody)
val request = request() .build()
.post(asyncRequestBody)
.build()
val response = client.newCall(request).execute() val response = client.newCall(request).execute()
assertThat(response.body.string()).isEqualTo("Hello response!") assertThat(response.body.string()).isEqualTo("Hello response!")
applicationLogs applicationLogs
@ -985,19 +1001,21 @@ class HttpLoggingInterceptorTest {
private val logs = mutableListOf<String>() private val logs = mutableListOf<String>()
private var index = 0 private var index = 0
fun assertLogEqual(expected: String) = apply { fun assertLogEqual(expected: String) =
assertThat(index, "No more messages found") apply {
.isLessThan(logs.size) assertThat(index, "No more messages found")
assertThat(logs[index++]).isEqualTo(expected) .isLessThan(logs.size)
return this assertThat(logs[index++]).isEqualTo(expected)
} return this
}
fun assertLogMatch(regex: Regex) = apply { fun assertLogMatch(regex: Regex) =
assertThat(index, "No more messages found") apply {
.isLessThan(logs.size) assertThat(index, "No more messages found")
assertThat(logs[index++]) .isLessThan(logs.size)
.matches(Regex(prefix.pattern + regex.pattern, RegexOption.DOT_MATCHES_ALL)) assertThat(logs[index++])
} .matches(Regex(prefix.pattern + regex.pattern, RegexOption.DOT_MATCHES_ALL))
}
fun assertNoMoreLogs() { fun assertNoMoreLogs() {
assertThat(logs.size, "More messages remain: ${logs.subList(index, logs.size)}") assertThat(logs.size, "More messages remain: ${logs.subList(index, logs.size)}")

View File

@ -50,9 +50,10 @@ class LoggingEventListenerTest {
val clientTestRule = OkHttpClientTestRule() val clientTestRule = OkHttpClientTestRule()
private lateinit var server: MockWebServer private lateinit var server: MockWebServer
private val handshakeCertificates = platform.localhostHandshakeCertificates() private val handshakeCertificates = platform.localhostHandshakeCertificates()
private val logRecorder = HttpLoggingInterceptorTest.LogRecorder( private val logRecorder =
prefix = Regex("""\[\d+ ms] """) HttpLoggingInterceptorTest.LogRecorder(
) prefix = Regex("""\[\d+ ms] """),
)
private val loggingEventListenerFactory = LoggingEventListener.Factory(logRecorder) private val loggingEventListenerFactory = LoggingEventListener.Factory(logRecorder)
private lateinit var client: OkHttpClient private lateinit var client: OkHttpClient
private lateinit var url: HttpUrl private lateinit var url: HttpUrl
@ -60,14 +61,15 @@ class LoggingEventListenerTest {
@BeforeEach @BeforeEach
fun setUp(server: MockWebServer) { fun setUp(server: MockWebServer) {
this.server = server this.server = server
client = clientTestRule.newClientBuilder() client =
.eventListenerFactory(loggingEventListenerFactory) clientTestRule.newClientBuilder()
.sslSocketFactory( .eventListenerFactory(loggingEventListenerFactory)
handshakeCertificates.sslSocketFactory(), .sslSocketFactory(
handshakeCertificates.trustManager handshakeCertificates.sslSocketFactory(),
) handshakeCertificates.trustManager,
.retryOnConnectionFailure(false) )
.build() .retryOnConnectionFailure(false)
.build()
url = server.url("/") url = server.url("/")
} }
@ -78,7 +80,7 @@ class LoggingEventListenerTest {
MockResponse.Builder() MockResponse.Builder()
.body("Hello!") .body("Hello!")
.setHeader("Content-Type", PLAIN) .setHeader("Content-Type", PLAIN)
.build() .build(),
) )
val response = client.newCall(request().build()).execute() val response = client.newCall(request().build()).execute()
assertThat(response.body).isNotNull() assertThat(response.body).isNotNull()
@ -91,7 +93,11 @@ class LoggingEventListenerTest {
.assertLogMatch(Regex("""dnsEnd: \[.+]""")) .assertLogMatch(Regex("""dnsEnd: \[.+]"""))
.assertLogMatch(Regex("""connectStart: ${url.host}/.+ DIRECT""")) .assertLogMatch(Regex("""connectStart: ${url.host}/.+ DIRECT"""))
.assertLogMatch(Regex("""connectEnd: http/1.1""")) .assertLogMatch(Regex("""connectEnd: http/1.1"""))
.assertLogMatch(Regex("""connectionAcquired: Connection\{${url.host}:\d+, proxy=DIRECT hostAddress=${url.host}/.+ cipherSuite=none protocol=http/1\.1\}""")) .assertLogMatch(
Regex(
"""connectionAcquired: Connection\{${url.host}:\d+, proxy=DIRECT hostAddress=${url.host}/.+ cipherSuite=none protocol=http/1\.1\}""",
),
)
.assertLogMatch(Regex("""requestHeadersStart""")) .assertLogMatch(Regex("""requestHeadersStart"""))
.assertLogMatch(Regex("""requestHeadersEnd""")) .assertLogMatch(Regex("""requestHeadersEnd"""))
.assertLogMatch(Regex("""responseHeadersStart""")) .assertLogMatch(Regex("""responseHeadersStart"""))
@ -116,7 +122,11 @@ class LoggingEventListenerTest {
.assertLogMatch(Regex("""dnsEnd: \[.+]""")) .assertLogMatch(Regex("""dnsEnd: \[.+]"""))
.assertLogMatch(Regex("""connectStart: ${url.host}/.+ DIRECT""")) .assertLogMatch(Regex("""connectStart: ${url.host}/.+ DIRECT"""))
.assertLogMatch(Regex("""connectEnd: http/1.1""")) .assertLogMatch(Regex("""connectEnd: http/1.1"""))
.assertLogMatch(Regex("""connectionAcquired: Connection\{${url.host}:\d+, proxy=DIRECT hostAddress=${url.host}/.+ cipherSuite=none protocol=http/1\.1\}""")) .assertLogMatch(
Regex(
"""connectionAcquired: Connection\{${url.host}:\d+, proxy=DIRECT hostAddress=${url.host}/.+ cipherSuite=none protocol=http/1\.1\}""",
),
)
.assertLogMatch(Regex("""requestHeadersStart""")) .assertLogMatch(Regex("""requestHeadersStart"""))
.assertLogMatch(Regex("""requestHeadersEnd""")) .assertLogMatch(Regex("""requestHeadersEnd"""))
.assertLogMatch(Regex("""requestBodyStart""")) .assertLogMatch(Regex("""requestBodyStart"""))
@ -148,9 +158,15 @@ class LoggingEventListenerTest {
.assertLogMatch(Regex("""dnsEnd: \[.+]""")) .assertLogMatch(Regex("""dnsEnd: \[.+]"""))
.assertLogMatch(Regex("""connectStart: ${url.host}/.+ DIRECT""")) .assertLogMatch(Regex("""connectStart: ${url.host}/.+ DIRECT"""))
.assertLogMatch(Regex("""secureConnectStart""")) .assertLogMatch(Regex("""secureConnectStart"""))
.assertLogMatch(Regex("""secureConnectEnd: Handshake\{tlsVersion=TLS_1_[23] cipherSuite=TLS_.* peerCertificates=\[CN=localhost] localCertificates=\[]\}""")) .assertLogMatch(
Regex(
"""secureConnectEnd: Handshake\{tlsVersion=TLS_1_[23] cipherSuite=TLS_.* peerCertificates=\[CN=localhost] localCertificates=\[]\}""",
),
)
.assertLogMatch(Regex("""connectEnd: h2""")) .assertLogMatch(Regex("""connectEnd: h2"""))
.assertLogMatch(Regex("""connectionAcquired: Connection\{${url.host}:\d+, proxy=DIRECT hostAddress=${url.host}/.+ cipherSuite=.+ protocol=h2\}""")) .assertLogMatch(
Regex("""connectionAcquired: Connection\{${url.host}:\d+, proxy=DIRECT hostAddress=${url.host}/.+ cipherSuite=.+ protocol=h2\}"""),
)
.assertLogMatch(Regex("""requestHeadersStart""")) .assertLogMatch(Regex("""requestHeadersStart"""))
.assertLogMatch(Regex("""requestHeadersEnd""")) .assertLogMatch(Regex("""requestHeadersEnd"""))
.assertLogMatch(Regex("""responseHeadersStart""")) .assertLogMatch(Regex("""responseHeadersStart"""))
@ -164,10 +180,11 @@ class LoggingEventListenerTest {
@Test @Test
fun dnsFail() { fun dnsFail() {
client = OkHttpClient.Builder() client =
.dns { _ -> throw UnknownHostException("reason") } OkHttpClient.Builder()
.eventListenerFactory(loggingEventListenerFactory) .dns { _ -> throw UnknownHostException("reason") }
.build() .eventListenerFactory(loggingEventListenerFactory)
.build()
try { try {
client.newCall(request().build()).execute() client.newCall(request().build()).execute()
fail<Any>() fail<Any>()
@ -190,7 +207,7 @@ class LoggingEventListenerTest {
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse.Builder()
.socketPolicy(FailHandshake) .socketPolicy(FailHandshake)
.build() .build(),
) )
url = server.url("/") url = server.url("/")
try { try {
@ -206,8 +223,16 @@ class LoggingEventListenerTest {
.assertLogMatch(Regex("""dnsEnd: \[.+]""")) .assertLogMatch(Regex("""dnsEnd: \[.+]"""))
.assertLogMatch(Regex("""connectStart: ${url.host}/.+ DIRECT""")) .assertLogMatch(Regex("""connectStart: ${url.host}/.+ DIRECT"""))
.assertLogMatch(Regex("""secureConnectStart""")) .assertLogMatch(Regex("""secureConnectStart"""))
.assertLogMatch(Regex("""connectFailed: null \S+(?:SSLProtocolException|SSLHandshakeException|TlsFatalAlert): (?:Unexpected handshake message: client_hello|Handshake message sequence violation, 1|Read error|Handshake failed|unexpected_message\(10\)).*""")) .assertLogMatch(
.assertLogMatch(Regex("""callFailed: \S+(?:SSLProtocolException|SSLHandshakeException|TlsFatalAlert): (?:Unexpected handshake message: client_hello|Handshake message sequence violation, 1|Read error|Handshake failed|unexpected_message\(10\)).*""")) Regex(
"""connectFailed: null \S+(?:SSLProtocolException|SSLHandshakeException|TlsFatalAlert): (?:Unexpected handshake message: client_hello|Handshake message sequence violation, 1|Read error|Handshake failed|unexpected_message\(10\)).*""",
),
)
.assertLogMatch(
Regex(
"""callFailed: \S+(?:SSLProtocolException|SSLHandshakeException|TlsFatalAlert): (?:Unexpected handshake message: client_hello|Handshake message sequence violation, 1|Read error|Handshake failed|unexpected_message\(10\)).*""",
),
)
.assertNoMoreLogs() .assertNoMoreLogs()
} }

View File

@ -33,6 +33,9 @@ interface EventSource {
* asynchronous process to connect the socket. Once that succeeds or fails, `listener` will be * asynchronous process to connect the socket. Once that succeeds or fails, `listener` will be
* notified. The caller must cancel the returned event source when it is no longer in use. * notified. The caller must cancel the returned event source when it is no longer in use.
*/ */
fun newEventSource(request: Request, listener: EventSourceListener): EventSource fun newEventSource(
request: Request,
listener: EventSourceListener,
): EventSource
} }
} }

View File

@ -22,13 +22,21 @@ abstract class EventSourceListener {
* Invoked when an event source has been accepted by the remote peer and may begin transmitting * Invoked when an event source has been accepted by the remote peer and may begin transmitting
* events. * events.
*/ */
open fun onOpen(eventSource: EventSource, response: Response) { open fun onOpen(
eventSource: EventSource,
response: Response,
) {
} }
/** /**
* TODO description. * TODO description.
*/ */
open fun onEvent(eventSource: EventSource, id: String?, type: String?, data: String) { open fun onEvent(
eventSource: EventSource,
id: String?,
type: String?,
data: String,
) {
} }
/** /**
@ -43,6 +51,10 @@ abstract class EventSourceListener {
* Invoked when an event source has been closed due to an error reading from or writing to the * Invoked when an event source has been closed due to an error reading from or writing to the
* network. Incoming events may have been lost. No further calls to this listener will be made. * network. Incoming events may have been lost. No further calls to this listener will be made.
*/ */
open fun onFailure(eventSource: EventSource, t: Throwable?, response: Response?) { open fun onFailure(
eventSource: EventSource,
t: Throwable?,
response: Response?,
) {
} }
} }

View File

@ -45,7 +45,10 @@ object EventSources {
} }
@JvmStatic @JvmStatic
fun processResponse(response: Response, listener: EventSourceListener) { fun processResponse(
response: Response,
listener: EventSourceListener,
) {
val eventSource = RealEventSource(response.request, listener) val eventSource = RealEventSource(response.request, listener)
eventSource.processResponse(response) eventSource.processResponse(response)
} }

View File

@ -27,18 +27,23 @@ import okhttp3.sse.EventSourceListener
internal class RealEventSource( internal class RealEventSource(
private val request: Request, private val request: Request,
private val listener: EventSourceListener private val listener: EventSourceListener,
) : EventSource, ServerSentEventReader.Callback, Callback { ) : EventSource, ServerSentEventReader.Callback, Callback {
private var call: Call? = null private var call: Call? = null
@Volatile private var canceled = false @Volatile private var canceled = false
fun connect(callFactory: Call.Factory) { fun connect(callFactory: Call.Factory) {
call = callFactory.newCall(request).apply { call =
enqueue(this@RealEventSource) callFactory.newCall(request).apply {
} enqueue(this@RealEventSource)
}
} }
override fun onResponse(call: Call, response: Response) { override fun onResponse(
call: Call,
response: Response,
) {
processResponse(response) processResponse(response)
} }
@ -52,8 +57,11 @@ internal class RealEventSource(
val body = response.body val body = response.body
if (!body.isEventStream()) { if (!body.isEventStream()) {
listener.onFailure(this, listener.onFailure(
IllegalStateException("Invalid content-type: ${body.contentType()}"), response) this,
IllegalStateException("Invalid content-type: ${body.contentType()}"),
response,
)
return return
} }
@ -71,10 +79,11 @@ internal class RealEventSource(
} }
} }
} catch (e: Exception) { } catch (e: Exception) {
val exception = when { val exception =
canceled -> IOException("canceled", e) when {
else -> e canceled -> IOException("canceled", e)
} else -> e
}
listener.onFailure(this, exception, response) listener.onFailure(this, exception, response)
return return
} }
@ -91,7 +100,10 @@ internal class RealEventSource(
return contentType.type == "text" && contentType.subtype == "event-stream" return contentType.type == "text" && contentType.subtype == "event-stream"
} }
override fun onFailure(call: Call, e: IOException) { override fun onFailure(
call: Call,
e: IOException,
) {
listener.onFailure(this, e, null) listener.onFailure(this, e, null)
} }
@ -102,7 +114,11 @@ internal class RealEventSource(
call?.cancel() call?.cancel()
} }
override fun onEvent(id: String?, type: String?, data: String) { override fun onEvent(
id: String?,
type: String?,
data: String,
) {
listener.onEvent(this, id, type, data) listener.onEvent(this, id, type, data)
} }

View File

@ -24,12 +24,17 @@ import okio.Options
class ServerSentEventReader( class ServerSentEventReader(
private val source: BufferedSource, private val source: BufferedSource,
private val callback: Callback private val callback: Callback,
) { ) {
private var lastId: String? = null private var lastId: String? = null
interface Callback { interface Callback {
fun onEvent(id: String?, type: String?, data: String) fun onEvent(
id: String?,
type: String?,
data: String,
)
fun onRetryChange(timeMs: Long) fun onRetryChange(timeMs: Long)
} }
@ -101,7 +106,11 @@ class ServerSentEventReader(
} }
@Throws(IOException::class) @Throws(IOException::class)
private fun completeEvent(id: String?, type: String?, data: Buffer) { private fun completeEvent(
id: String?,
type: String?,
data: Buffer,
) {
if (data.size != 0L) { if (data.size != 0L) {
lastId = id lastId = id
data.skip(1L) // Leading newline. data.skip(1L) // Leading newline.
@ -110,35 +119,49 @@ class ServerSentEventReader(
} }
companion object { companion object {
val options = Options.of( val options =
/* 0 */ "\r\n".encodeUtf8(), Options.of(
/* 1 */ "\r".encodeUtf8(), // 0
/* 2 */ "\n".encodeUtf8(), "\r\n".encodeUtf8(),
// 1
/* 3 */ "data: ".encodeUtf8(), "\r".encodeUtf8(),
/* 4 */ "data:".encodeUtf8(), // 2
"\n".encodeUtf8(),
/* 5 */ "data\r\n".encodeUtf8(), // 3
/* 6 */ "data\r".encodeUtf8(), "data: ".encodeUtf8(),
/* 7 */ "data\n".encodeUtf8(), // 4
"data:".encodeUtf8(),
/* 8 */ "id: ".encodeUtf8(), // 5
/* 9 */ "id:".encodeUtf8(), "data\r\n".encodeUtf8(),
// 6
/* 10 */ "id\r\n".encodeUtf8(), "data\r".encodeUtf8(),
/* 11 */ "id\r".encodeUtf8(), // 7
/* 12 */ "id\n".encodeUtf8(), "data\n".encodeUtf8(),
// 8
/* 13 */ "event: ".encodeUtf8(), "id: ".encodeUtf8(),
/* 14 */ "event:".encodeUtf8(), // 9
"id:".encodeUtf8(),
/* 15 */ "event\r\n".encodeUtf8(), // 10
/* 16 */ "event\r".encodeUtf8(), "id\r\n".encodeUtf8(),
/* 17 */ "event\n".encodeUtf8(), // 11
"id\r".encodeUtf8(),
/* 18 */ "retry: ".encodeUtf8(), // 12
/* 19 */ "retry:".encodeUtf8() "id\n".encodeUtf8(),
) // 13
"event: ".encodeUtf8(),
// 14
"event:".encodeUtf8(),
// 15
"event\r\n".encodeUtf8(),
// 16
"event\r".encodeUtf8(),
// 17
"event\n".encodeUtf8(),
// 18
"retry: ".encodeUtf8(),
// 19
"retry:".encodeUtf8(),
)
private val CRLF = "\r\n".encodeUtf8() private val CRLF = "\r\n".encodeUtf8()

View File

@ -47,9 +47,10 @@ class EventSourceHttpTest {
val clientTestRule = OkHttpClientTestRule() val clientTestRule = OkHttpClientTestRule()
private val eventListener = RecordingEventListener() private val eventListener = RecordingEventListener()
private val listener = EventSourceRecorder() private val listener = EventSourceRecorder()
private var client = clientTestRule.newClientBuilder() private var client =
.eventListenerFactory(clientTestRule.wrap(eventListener)) clientTestRule.newClientBuilder()
.build() .eventListenerFactory(clientTestRule.wrap(eventListener))
.build()
@BeforeEach @BeforeEach
fun before(server: MockWebServer) { fun before(server: MockWebServer) {
@ -70,9 +71,9 @@ class EventSourceHttpTest {
|data: hey |data: hey
| |
| |
""".trimMargin() """.trimMargin(),
).setHeader("content-type", "text/event-stream") ).setHeader("content-type", "text/event-stream")
.build() .build(),
) )
val source = newEventSource() val source = newEventSource()
assertThat(source.request().url.encodedPath).isEqualTo("/") assertThat(source.request().url.encodedPath).isEqualTo("/")
@ -90,9 +91,9 @@ class EventSourceHttpTest {
|data: hey |data: hey
| |
| |
""".trimMargin() """.trimMargin(),
).setHeader("content-type", "text/event-stream") ).setHeader("content-type", "text/event-stream")
.build() .build(),
) )
listener.enqueueCancel() // Will cancel in onOpen(). listener.enqueueCancel() // Will cancel in onOpen().
newEventSource() newEventSource()
@ -109,9 +110,9 @@ class EventSourceHttpTest {
|data: hey |data: hey
| |
| |
""".trimMargin() """.trimMargin(),
).setHeader("content-type", "text/plain") ).setHeader("content-type", "text/plain")
.build() .build(),
) )
newEventSource() newEventSource()
listener.assertFailure("Invalid content-type: text/plain") listener.assertFailure("Invalid content-type: text/plain")
@ -126,11 +127,11 @@ class EventSourceHttpTest {
|data: hey |data: hey
| |
| |
""".trimMargin() """.trimMargin(),
) )
.setHeader("content-type", "text/event-stream") .setHeader("content-type", "text/event-stream")
.code(401) .code(401)
.build() .build(),
) )
newEventSource() newEventSource()
listener.assertFailure(null) listener.assertFailure(null)
@ -138,15 +139,16 @@ class EventSourceHttpTest {
@Test @Test
fun fullCallTimeoutDoesNotApplyOnceConnected() { fun fullCallTimeoutDoesNotApplyOnceConnected() {
client = client.newBuilder() client =
.callTimeout(250, TimeUnit.MILLISECONDS) client.newBuilder()
.build() .callTimeout(250, TimeUnit.MILLISECONDS)
.build()
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse.Builder()
.bodyDelay(500, TimeUnit.MILLISECONDS) .bodyDelay(500, TimeUnit.MILLISECONDS)
.setHeader("content-type", "text/event-stream") .setHeader("content-type", "text/event-stream")
.body("data: hey\n\n") .body("data: hey\n\n")
.build() .build(),
) )
val source = newEventSource() val source = newEventSource()
assertThat(source.request().url.encodedPath).isEqualTo("/") assertThat(source.request().url.encodedPath).isEqualTo("/")
@ -157,15 +159,16 @@ class EventSourceHttpTest {
@Test @Test
fun fullCallTimeoutAppliesToSetup() { fun fullCallTimeoutAppliesToSetup() {
client = client.newBuilder() client =
.callTimeout(250, TimeUnit.MILLISECONDS) client.newBuilder()
.build() .callTimeout(250, TimeUnit.MILLISECONDS)
.build()
server.enqueue( server.enqueue(
MockResponse.Builder() MockResponse.Builder()
.headersDelay(500, TimeUnit.MILLISECONDS) .headersDelay(500, TimeUnit.MILLISECONDS)
.setHeader("content-type", "text/event-stream") .setHeader("content-type", "text/event-stream")
.body("data: hey\n\n") .body("data: hey\n\n")
.build() .build(),
) )
newEventSource() newEventSource()
listener.assertFailure("timeout") listener.assertFailure("timeout")
@ -180,10 +183,10 @@ class EventSourceHttpTest {
|data: hey |data: hey
| |
| |
""".trimMargin() """.trimMargin(),
) )
.setHeader("content-type", "text/event-stream") .setHeader("content-type", "text/event-stream")
.build() .build(),
) )
newEventSource("text/plain") newEventSource("text/plain")
listener.assertOpen() listener.assertOpen()
@ -201,9 +204,9 @@ class EventSourceHttpTest {
|data: hey |data: hey
| |
| |
""".trimMargin() """.trimMargin(),
).setHeader("content-type", "text/event-stream") ).setHeader("content-type", "text/event-stream")
.build() .build(),
) )
newEventSource() newEventSource()
listener.assertOpen() listener.assertOpen()
@ -222,9 +225,9 @@ class EventSourceHttpTest {
|data: hey |data: hey
| |
| |
""".trimMargin() """.trimMargin(),
).setHeader("content-type", "text/event-stream") ).setHeader("content-type", "text/event-stream")
.build() .build(),
) )
val source = newEventSource() val source = newEventSource()
assertThat(source.request().url.encodedPath).isEqualTo("/") assertThat(source.request().url.encodedPath).isEqualTo("/")
@ -247,13 +250,14 @@ class EventSourceHttpTest {
"ResponseBodyStart", "ResponseBodyStart",
"ResponseBodyEnd", "ResponseBodyEnd",
"ConnectionReleased", "ConnectionReleased",
"CallEnd" "CallEnd",
) )
} }
private fun newEventSource(accept: String? = null): EventSource { private fun newEventSource(accept: String? = null): EventSource {
val builder = Request.Builder() val builder =
.url(server.url("/")) Request.Builder()
.url(server.url("/"))
if (accept != null) { if (accept != null) {
builder.header("Accept", accept) builder.header("Accept", accept)
} }

View File

@ -35,7 +35,10 @@ class EventSourceRecorder : EventSourceListener() {
cancel = true cancel = true
} }
override fun onOpen(eventSource: EventSource, response: Response) { override fun onOpen(
eventSource: EventSource,
response: Response,
) {
get().log("[ES] onOpen", Platform.INFO, null) get().log("[ES] onOpen", Platform.INFO, null)
events.add(Open(eventSource, response)) events.add(Open(eventSource, response))
drainCancelQueue(eventSource) drainCancelQueue(eventSource)
@ -52,9 +55,7 @@ class EventSourceRecorder : EventSourceListener() {
drainCancelQueue(eventSource) drainCancelQueue(eventSource)
} }
override fun onClosed( override fun onClosed(eventSource: EventSource) {
eventSource: EventSource,
) {
get().log("[ES] onClosed", Platform.INFO, null) get().log("[ES] onClosed", Platform.INFO, null)
events.add(Closed) events.add(Closed)
drainCancelQueue(eventSource) drainCancelQueue(eventSource)
@ -70,9 +71,7 @@ class EventSourceRecorder : EventSourceListener() {
drainCancelQueue(eventSource) drainCancelQueue(eventSource)
} }
private fun drainCancelQueue( private fun drainCancelQueue(eventSource: EventSource) {
eventSource: EventSource,
) {
if (cancel) { if (cancel) {
cancel = false cancel = false
eventSource.cancel() eventSource.cancel()

View File

@ -61,13 +61,14 @@ class EventSourcesHttpTest {
|data: hey |data: hey
| |
| |
""".trimMargin() """.trimMargin(),
).setHeader("content-type", "text/event-stream") ).setHeader("content-type", "text/event-stream")
.build() .build(),
) )
val request = Request.Builder() val request =
.url(server.url("/")) Request.Builder()
.build() .url(server.url("/"))
.build()
val response = client.newCall(request).execute() val response = client.newCall(request).execute()
processResponse(response, listener) processResponse(response, listener)
listener.assertOpen() listener.assertOpen()
@ -84,14 +85,15 @@ class EventSourcesHttpTest {
|data: hey |data: hey
| |
| |
""".trimMargin() """.trimMargin(),
).setHeader("content-type", "text/event-stream") ).setHeader("content-type", "text/event-stream")
.build() .build(),
) )
listener.enqueueCancel() // Will cancel in onOpen(). listener.enqueueCancel() // Will cancel in onOpen().
val request = Request.Builder() val request =
.url(server.url("/")) Request.Builder()
.build() .url(server.url("/"))
.build()
val response = client.newCall(request).execute() val response = client.newCall(request).execute()
processResponse(response, listener) processResponse(response, listener)
listener.assertOpen() listener.assertOpen()

View File

@ -42,7 +42,7 @@ class ServerSentEventIteratorTest {
|data: 10 |data: 10
| |
| |
""".trimMargin() """.trimMargin(),
) )
assertThat(callbacks.remove()).isEqualTo(Event(null, null, "YHOO\n+2\n10")) assertThat(callbacks.remove()).isEqualTo(Event(null, null, "YHOO\n+2\n10"))
} }
@ -56,7 +56,7 @@ class ServerSentEventIteratorTest {
|data: 10 |data: 10
| |
| |
""".trimMargin().replace("\n", "\r") """.trimMargin().replace("\n", "\r"),
) )
assertThat(callbacks.remove()).isEqualTo(Event(null, null, "YHOO\n+2\n10")) assertThat(callbacks.remove()).isEqualTo(Event(null, null, "YHOO\n+2\n10"))
} }
@ -70,7 +70,7 @@ class ServerSentEventIteratorTest {
|data: 10 |data: 10
| |
| |
""".trimMargin().replace("\n", "\r\n") """.trimMargin().replace("\n", "\r\n"),
) )
assertThat(callbacks.remove()).isEqualTo(Event(null, null, "YHOO\n+2\n10")) assertThat(callbacks.remove()).isEqualTo(Event(null, null, "YHOO\n+2\n10"))
} }
@ -89,7 +89,7 @@ class ServerSentEventIteratorTest {
|data: 113411 |data: 113411
| |
| |
""".trimMargin() """.trimMargin(),
) )
assertThat(callbacks.remove()).isEqualTo(Event(null, "add", "73857293")) assertThat(callbacks.remove()).isEqualTo(Event(null, "add", "73857293"))
assertThat(callbacks.remove()).isEqualTo(Event(null, "remove", "2153")) assertThat(callbacks.remove()).isEqualTo(Event(null, "remove", "2153"))
@ -106,7 +106,7 @@ class ServerSentEventIteratorTest {
|id: 1 |id: 1
| |
| |
""".trimMargin() """.trimMargin(),
) )
assertThat(callbacks.remove()).isEqualTo(Event("1", null, "first event")) assertThat(callbacks.remove()).isEqualTo(Event("1", null, "first event"))
} }
@ -124,7 +124,7 @@ class ServerSentEventIteratorTest {
|data: third event |data: third event
| |
| |
""".trimMargin() """.trimMargin(),
) )
assertThat(callbacks.remove()).isEqualTo(Event("1", null, "first event")) assertThat(callbacks.remove()).isEqualTo(Event("1", null, "first event"))
assertThat(callbacks.remove()).isEqualTo(Event(null, null, "second event")) assertThat(callbacks.remove()).isEqualTo(Event(null, null, "second event"))
@ -142,7 +142,7 @@ class ServerSentEventIteratorTest {
| |
|data: |data:
| |
""".trimMargin() """.trimMargin(),
) )
assertThat(callbacks.remove()).isEqualTo(Event(null, null, "")) assertThat(callbacks.remove()).isEqualTo(Event(null, null, ""))
assertThat(callbacks.remove()).isEqualTo(Event(null, null, "\n")) assertThat(callbacks.remove()).isEqualTo(Event(null, null, "\n"))
@ -157,7 +157,7 @@ class ServerSentEventIteratorTest {
|data: test |data: test
| |
| |
""".trimMargin() """.trimMargin(),
) )
assertThat(callbacks.remove()).isEqualTo(Event(null, null, "test")) assertThat(callbacks.remove()).isEqualTo(Event(null, null, "test"))
assertThat(callbacks.remove()).isEqualTo(Event(null, null, "test")) assertThat(callbacks.remove()).isEqualTo(Event(null, null, "test"))
@ -170,7 +170,7 @@ class ServerSentEventIteratorTest {
|data: test |data: test
| |
| |
""".trimMargin() """.trimMargin(),
) )
assertThat(callbacks.remove()).isEqualTo(Event(null, null, " test")) assertThat(callbacks.remove()).isEqualTo(Event(null, null, " test"))
} }
@ -188,7 +188,7 @@ class ServerSentEventIteratorTest {
|data: third event |data: third event
| |
| |
""".trimMargin() """.trimMargin(),
) )
assertThat(callbacks.remove()).isEqualTo(Event("1", null, "first event")) assertThat(callbacks.remove()).isEqualTo(Event("1", null, "first event"))
assertThat(callbacks.remove()).isEqualTo(Event("1", null, "second event")) assertThat(callbacks.remove()).isEqualTo(Event("1", null, "second event"))
@ -207,7 +207,7 @@ class ServerSentEventIteratorTest {
|data: second event |data: second event
| |
| |
""".trimMargin() """.trimMargin(),
) )
assertThat(callbacks.remove()).isEqualTo(Event("1", null, "first event")) assertThat(callbacks.remove()).isEqualTo(Event("1", null, "first event"))
assertThat(callbacks.remove()).isEqualTo(Event("1", null, "second event")) assertThat(callbacks.remove()).isEqualTo(Event("1", null, "second event"))
@ -223,7 +223,7 @@ class ServerSentEventIteratorTest {
|id: 1 |id: 1
| |
| |
""".trimMargin() """.trimMargin(),
) )
assertThat(callbacks.remove()).isEqualTo(22L) assertThat(callbacks.remove()).isEqualTo(22L)
assertThat(callbacks.remove()).isEqualTo(Event("1", null, "first event")) assertThat(callbacks.remove()).isEqualTo(Event("1", null, "first event"))
@ -237,7 +237,7 @@ class ServerSentEventIteratorTest {
| |
|retry: hey |retry: hey
| |
""".trimMargin() """.trimMargin(),
) )
assertThat(callbacks.remove()).isEqualTo(22L) assertThat(callbacks.remove()).isEqualTo(22L)
} }
@ -253,7 +253,7 @@ class ServerSentEventIteratorTest {
|retrying |retrying
| |
| |
""".trimMargin() """.trimMargin(),
) )
assertThat(callbacks.remove()).isEqualTo(Event(null, null, "a")) assertThat(callbacks.remove()).isEqualTo(Event(null, null, "a"))
} }
@ -270,7 +270,7 @@ class ServerSentEventIteratorTest {
|data |data
| |
| |
""".trimMargin() """.trimMargin(),
) )
assertThat(callbacks.remove()).isEqualTo(Event(null, null, "c\n")) assertThat(callbacks.remove()).isEqualTo(Event(null, null, "c\n"))
} }
@ -281,21 +281,26 @@ class ServerSentEventIteratorTest {
""" """
|retry |retry
| |
""".trimMargin() """.trimMargin(),
) )
assertThat(callbacks).isEmpty() assertThat(callbacks).isEmpty()
} }
private fun consumeEvents(source: String) { private fun consumeEvents(source: String) {
val callback: ServerSentEventReader.Callback = object : ServerSentEventReader.Callback { val callback: ServerSentEventReader.Callback =
override fun onEvent(id: String?, type: String?, data: String) { object : ServerSentEventReader.Callback {
callbacks.add(Event(id, type, data)) override fun onEvent(
} id: String?,
type: String?,
data: String,
) {
callbacks.add(Event(id, type, data))
}
override fun onRetryChange(timeMs: Long) { override fun onRetryChange(timeMs: Long) {
callbacks.add(timeMs) callbacks.add(timeMs)
}
} }
}
val buffer = Buffer().writeUtf8(source) val buffer = Buffer().writeUtf8(source)
val reader = ServerSentEventReader(buffer, callback) val reader = ServerSentEventReader(buffer, callback)
while (reader.processNextEvent()) { while (reader.processNextEvent()) {

View File

@ -14,6 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
@file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE") @file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")
package okhttp3 package okhttp3
import java.io.IOException import java.io.IOException
@ -37,40 +38,38 @@ sealed class CallEvent {
data class ProxySelectStart( data class ProxySelectStart(
override val timestampNs: Long, override val timestampNs: Long,
override val call: Call, override val call: Call,
val url: HttpUrl val url: HttpUrl,
) : CallEvent() ) : CallEvent()
data class ProxySelectEnd( data class ProxySelectEnd(
override val timestampNs: Long, override val timestampNs: Long,
override val call: Call, override val call: Call,
val url: HttpUrl, val url: HttpUrl,
val proxies: List<Proxy>? val proxies: List<Proxy>?,
) : CallEvent() { ) : CallEvent() {
override fun closes(event: CallEvent): Boolean = override fun closes(event: CallEvent): Boolean = event is ProxySelectStart && call == event.call && url == event.url
event is ProxySelectStart && call == event.call && url == event.url
} }
data class DnsStart( data class DnsStart(
override val timestampNs: Long, override val timestampNs: Long,
override val call: Call, override val call: Call,
val domainName: String val domainName: String,
) : CallEvent() ) : CallEvent()
data class DnsEnd( data class DnsEnd(
override val timestampNs: Long, override val timestampNs: Long,
override val call: Call, override val call: Call,
val domainName: String, val domainName: String,
val inetAddressList: List<InetAddress> val inetAddressList: List<InetAddress>,
) : CallEvent() { ) : CallEvent() {
override fun closes(event: CallEvent): Boolean = override fun closes(event: CallEvent): Boolean = event is DnsStart && call == event.call && domainName == event.domainName
event is DnsStart && call == event.call && domainName == event.domainName
} }
data class ConnectStart( data class ConnectStart(
override val timestampNs: Long, override val timestampNs: Long,
override val call: Call, override val call: Call,
val inetSocketAddress: InetSocketAddress, val inetSocketAddress: InetSocketAddress,
val proxy: Proxy? val proxy: Proxy?,
) : CallEvent() ) : CallEvent()
data class ConnectEnd( data class ConnectEnd(
@ -78,7 +77,7 @@ sealed class CallEvent {
override val call: Call, override val call: Call,
val inetSocketAddress: InetSocketAddress, val inetSocketAddress: InetSocketAddress,
val proxy: Proxy?, val proxy: Proxy?,
val protocol: Protocol? val protocol: Protocol?,
) : CallEvent() { ) : CallEvent() {
override fun closes(event: CallEvent): Boolean = override fun closes(event: CallEvent): Boolean =
event is ConnectStart && call == event.call && inetSocketAddress == event.inetSocketAddress && proxy == event.proxy event is ConnectStart && call == event.call && inetSocketAddress == event.inetSocketAddress && proxy == event.proxy
@ -90,7 +89,7 @@ sealed class CallEvent {
val inetSocketAddress: InetSocketAddress, val inetSocketAddress: InetSocketAddress,
val proxy: Proxy, val proxy: Proxy,
val protocol: Protocol?, val protocol: Protocol?,
val ioe: IOException val ioe: IOException,
) : CallEvent() { ) : CallEvent() {
override fun closes(event: CallEvent): Boolean = override fun closes(event: CallEvent): Boolean =
event is ConnectStart && call == event.call && inetSocketAddress == event.inetSocketAddress && proxy == event.proxy event is ConnectStart && call == event.call && inetSocketAddress == event.inetSocketAddress && proxy == event.proxy
@ -98,149 +97,139 @@ sealed class CallEvent {
data class SecureConnectStart( data class SecureConnectStart(
override val timestampNs: Long, override val timestampNs: Long,
override val call: Call override val call: Call,
) : CallEvent() ) : CallEvent()
data class SecureConnectEnd( data class SecureConnectEnd(
override val timestampNs: Long, override val timestampNs: Long,
override val call: Call, override val call: Call,
val handshake: Handshake? val handshake: Handshake?,
) : CallEvent() { ) : CallEvent() {
override fun closes(event: CallEvent): Boolean = override fun closes(event: CallEvent): Boolean = event is SecureConnectStart && call == event.call
event is SecureConnectStart && call == event.call
} }
data class ConnectionAcquired( data class ConnectionAcquired(
override val timestampNs: Long, override val timestampNs: Long,
override val call: Call, override val call: Call,
val connection: Connection val connection: Connection,
) : CallEvent() ) : CallEvent()
data class ConnectionReleased( data class ConnectionReleased(
override val timestampNs: Long, override val timestampNs: Long,
override val call: Call, override val call: Call,
val connection: Connection val connection: Connection,
) : CallEvent() { ) : CallEvent() {
override fun closes(event: CallEvent): Boolean = override fun closes(event: CallEvent): Boolean = event is ConnectionAcquired && call == event.call && connection == event.connection
event is ConnectionAcquired && call == event.call && connection == event.connection
} }
data class CallStart( data class CallStart(
override val timestampNs: Long, override val timestampNs: Long,
override val call: Call override val call: Call,
) : CallEvent() ) : CallEvent()
data class CallEnd( data class CallEnd(
override val timestampNs: Long, override val timestampNs: Long,
override val call: Call override val call: Call,
) : CallEvent() { ) : CallEvent() {
override fun closes(event: CallEvent): Boolean = override fun closes(event: CallEvent): Boolean = event is CallStart && call == event.call
event is CallStart && call == event.call
} }
data class CallFailed( data class CallFailed(
override val timestampNs: Long, override val timestampNs: Long,
override val call: Call, override val call: Call,
val ioe: IOException val ioe: IOException,
) : CallEvent() { ) : CallEvent() {
override fun closes(event: CallEvent): Boolean = override fun closes(event: CallEvent): Boolean = event is CallStart && call == event.call
event is CallStart && call == event.call
} }
data class Canceled( data class Canceled(
override val timestampNs: Long, override val timestampNs: Long,
override val call: Call override val call: Call,
) : CallEvent() ) : CallEvent()
data class RequestHeadersStart( data class RequestHeadersStart(
override val timestampNs: Long, override val timestampNs: Long,
override val call: Call override val call: Call,
) : CallEvent() ) : CallEvent()
data class RequestHeadersEnd( data class RequestHeadersEnd(
override val timestampNs: Long, override val timestampNs: Long,
override val call: Call, override val call: Call,
val headerLength: Long val headerLength: Long,
) : CallEvent() { ) : CallEvent() {
override fun closes(event: CallEvent): Boolean = override fun closes(event: CallEvent): Boolean = event is RequestHeadersStart && call == event.call
event is RequestHeadersStart && call == event.call
} }
data class RequestBodyStart( data class RequestBodyStart(
override val timestampNs: Long, override val timestampNs: Long,
override val call: Call override val call: Call,
) : CallEvent() ) : CallEvent()
data class RequestBodyEnd( data class RequestBodyEnd(
override val timestampNs: Long, override val timestampNs: Long,
override val call: Call, override val call: Call,
val bytesWritten: Long val bytesWritten: Long,
) : CallEvent() { ) : CallEvent() {
override fun closes(event: CallEvent): Boolean = override fun closes(event: CallEvent): Boolean = event is RequestBodyStart && call == event.call
event is RequestBodyStart && call == event.call
} }
data class RequestFailed( data class RequestFailed(
override val timestampNs: Long, override val timestampNs: Long,
override val call: Call, override val call: Call,
val ioe: IOException val ioe: IOException,
) : CallEvent() { ) : CallEvent() {
override fun closes(event: CallEvent): Boolean = event is RequestHeadersStart && call == event.call
override fun closes(event: CallEvent): Boolean =
event is RequestHeadersStart && call == event.call
} }
data class ResponseHeadersStart( data class ResponseHeadersStart(
override val timestampNs: Long, override val timestampNs: Long,
override val call: Call override val call: Call,
) : CallEvent() ) : CallEvent()
data class ResponseHeadersEnd( data class ResponseHeadersEnd(
override val timestampNs: Long, override val timestampNs: Long,
override val call: Call, override val call: Call,
val headerLength: Long val headerLength: Long,
) : CallEvent() { ) : CallEvent() {
override fun closes(event: CallEvent): Boolean = override fun closes(event: CallEvent): Boolean = event is ResponseHeadersStart && call == event.call
event is ResponseHeadersStart && call == event.call
} }
data class ResponseBodyStart( data class ResponseBodyStart(
override val timestampNs: Long, override val timestampNs: Long,
override val call: Call override val call: Call,
) : CallEvent() ) : CallEvent()
data class ResponseBodyEnd( data class ResponseBodyEnd(
override val timestampNs: Long, override val timestampNs: Long,
override val call: Call, override val call: Call,
val bytesRead: Long val bytesRead: Long,
) : CallEvent() { ) : CallEvent() {
override fun closes(event: CallEvent): Boolean = override fun closes(event: CallEvent): Boolean = event is ResponseBodyStart && call == event.call
event is ResponseBodyStart && call == event.call
} }
data class ResponseFailed( data class ResponseFailed(
override val timestampNs: Long, override val timestampNs: Long,
override val call: Call, override val call: Call,
val ioe: IOException val ioe: IOException,
) : CallEvent() ) : CallEvent()
data class SatisfactionFailure( data class SatisfactionFailure(
override val timestampNs: Long, override val timestampNs: Long,
override val call: Call override val call: Call,
) : CallEvent() ) : CallEvent()
data class CacheHit( data class CacheHit(
override val timestampNs: Long, override val timestampNs: Long,
override val call: Call override val call: Call,
) : CallEvent() ) : CallEvent()
data class CacheMiss( data class CacheMiss(
override val timestampNs: Long, override val timestampNs: Long,
override val call: Call override val call: Call,
) : CallEvent() ) : CallEvent()
data class CacheConditionalHit( data class CacheConditionalHit(
override val timestampNs: Long, override val timestampNs: Long,
override val call: Call override val call: Call,
) : CallEvent() ) : CallEvent()
} }

View File

@ -23,9 +23,9 @@ import java.util.concurrent.TimeUnit
class ClientRuleEventListener( class ClientRuleEventListener(
val delegate: EventListener = NONE, val delegate: EventListener = NONE,
var logger: (String) -> Unit var logger: (String) -> Unit,
) : EventListener(), ) : EventListener(),
EventListener.Factory { EventListener.Factory {
private var startNs: Long? = null private var startNs: Long? = null
override fun create(call: Call): EventListener = this override fun create(call: Call): EventListener = this
@ -40,7 +40,7 @@ class ClientRuleEventListener(
override fun proxySelectStart( override fun proxySelectStart(
call: Call, call: Call,
url: HttpUrl url: HttpUrl,
) { ) {
logWithTime("proxySelectStart: $url") logWithTime("proxySelectStart: $url")
@ -50,7 +50,7 @@ class ClientRuleEventListener(
override fun proxySelectEnd( override fun proxySelectEnd(
call: Call, call: Call,
url: HttpUrl, url: HttpUrl,
proxies: List<Proxy> proxies: List<Proxy>,
) { ) {
logWithTime("proxySelectEnd: $proxies") logWithTime("proxySelectEnd: $proxies")
@ -59,7 +59,7 @@ class ClientRuleEventListener(
override fun dnsStart( override fun dnsStart(
call: Call, call: Call,
domainName: String domainName: String,
) { ) {
logWithTime("dnsStart: $domainName") logWithTime("dnsStart: $domainName")
@ -69,7 +69,7 @@ class ClientRuleEventListener(
override fun dnsEnd( override fun dnsEnd(
call: Call, call: Call,
domainName: String, domainName: String,
inetAddressList: List<InetAddress> inetAddressList: List<InetAddress>,
) { ) {
logWithTime("dnsEnd: $inetAddressList") logWithTime("dnsEnd: $inetAddressList")
@ -79,7 +79,7 @@ class ClientRuleEventListener(
override fun connectStart( override fun connectStart(
call: Call, call: Call,
inetSocketAddress: InetSocketAddress, inetSocketAddress: InetSocketAddress,
proxy: Proxy proxy: Proxy,
) { ) {
logWithTime("connectStart: $inetSocketAddress $proxy") logWithTime("connectStart: $inetSocketAddress $proxy")
@ -94,7 +94,7 @@ class ClientRuleEventListener(
override fun secureConnectEnd( override fun secureConnectEnd(
call: Call, call: Call,
handshake: Handshake? handshake: Handshake?,
) { ) {
logWithTime("secureConnectEnd: $handshake") logWithTime("secureConnectEnd: $handshake")
@ -105,7 +105,7 @@ class ClientRuleEventListener(
call: Call, call: Call,
inetSocketAddress: InetSocketAddress, inetSocketAddress: InetSocketAddress,
proxy: Proxy, proxy: Proxy,
protocol: Protocol? protocol: Protocol?,
) { ) {
logWithTime("connectEnd: $protocol") logWithTime("connectEnd: $protocol")
@ -117,7 +117,7 @@ class ClientRuleEventListener(
inetSocketAddress: InetSocketAddress, inetSocketAddress: InetSocketAddress,
proxy: Proxy, proxy: Proxy,
protocol: Protocol?, protocol: Protocol?,
ioe: IOException ioe: IOException,
) { ) {
logWithTime("connectFailed: $protocol $ioe") logWithTime("connectFailed: $protocol $ioe")
@ -126,7 +126,7 @@ class ClientRuleEventListener(
override fun connectionAcquired( override fun connectionAcquired(
call: Call, call: Call,
connection: Connection connection: Connection,
) { ) {
logWithTime("connectionAcquired: $connection") logWithTime("connectionAcquired: $connection")
@ -135,7 +135,7 @@ class ClientRuleEventListener(
override fun connectionReleased( override fun connectionReleased(
call: Call, call: Call,
connection: Connection connection: Connection,
) { ) {
logWithTime("connectionReleased") logWithTime("connectionReleased")
@ -150,7 +150,7 @@ class ClientRuleEventListener(
override fun requestHeadersEnd( override fun requestHeadersEnd(
call: Call, call: Call,
request: Request request: Request,
) { ) {
logWithTime("requestHeadersEnd") logWithTime("requestHeadersEnd")
@ -165,7 +165,7 @@ class ClientRuleEventListener(
override fun requestBodyEnd( override fun requestBodyEnd(
call: Call, call: Call,
byteCount: Long byteCount: Long,
) { ) {
logWithTime("requestBodyEnd: byteCount=$byteCount") logWithTime("requestBodyEnd: byteCount=$byteCount")
@ -174,7 +174,7 @@ class ClientRuleEventListener(
override fun requestFailed( override fun requestFailed(
call: Call, call: Call,
ioe: IOException ioe: IOException,
) { ) {
logWithTime("requestFailed: $ioe") logWithTime("requestFailed: $ioe")
@ -189,7 +189,7 @@ class ClientRuleEventListener(
override fun responseHeadersEnd( override fun responseHeadersEnd(
call: Call, call: Call,
response: Response response: Response,
) { ) {
logWithTime("responseHeadersEnd: $response") logWithTime("responseHeadersEnd: $response")
@ -204,7 +204,7 @@ class ClientRuleEventListener(
override fun responseBodyEnd( override fun responseBodyEnd(
call: Call, call: Call,
byteCount: Long byteCount: Long,
) { ) {
logWithTime("responseBodyEnd: byteCount=$byteCount") logWithTime("responseBodyEnd: byteCount=$byteCount")
@ -213,7 +213,7 @@ class ClientRuleEventListener(
override fun responseFailed( override fun responseFailed(
call: Call, call: Call,
ioe: IOException ioe: IOException,
) { ) {
logWithTime("responseFailed: $ioe") logWithTime("responseFailed: $ioe")
@ -228,7 +228,7 @@ class ClientRuleEventListener(
override fun callFailed( override fun callFailed(
call: Call, call: Call,
ioe: IOException ioe: IOException,
) { ) {
logWithTime("callFailed: $ioe") logWithTime("callFailed: $ioe")
@ -241,7 +241,10 @@ class ClientRuleEventListener(
delegate.canceled(call) delegate.canceled(call)
} }
override fun satisfactionFailure(call: Call, response: Response) { override fun satisfactionFailure(
call: Call,
response: Response,
) {
logWithTime("satisfactionFailure") logWithTime("satisfactionFailure")
delegate.satisfactionFailure(call, response) delegate.satisfactionFailure(call, response)
@ -253,13 +256,19 @@ class ClientRuleEventListener(
delegate.cacheMiss(call) delegate.cacheMiss(call)
} }
override fun cacheHit(call: Call, response: Response) { override fun cacheHit(
call: Call,
response: Response,
) {
logWithTime("cacheHit") logWithTime("cacheHit")
delegate.cacheHit(call, response) delegate.cacheHit(call, response)
} }
override fun cacheConditionalHit(call: Call, cachedResponse: Response) { override fun cacheConditionalHit(
call: Call,
cachedResponse: Response,
) {
logWithTime("cacheConditionalHit") logWithTime("cacheConditionalHit")
delegate.cacheConditionalHit(call, cachedResponse) delegate.cacheConditionalHit(call, cachedResponse)
@ -267,12 +276,13 @@ class ClientRuleEventListener(
private fun logWithTime(message: String) { private fun logWithTime(message: String) {
val startNs = startNs val startNs = startNs
val timeMs = if (startNs == null) { val timeMs =
// Event occurred before start, for an example an early cancel. if (startNs == null) {
0L // Event occurred before start, for an example an early cancel.
} else { 0L
TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNs) } else {
} TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNs)
}
logger.invoke("[$timeMs ms] $message") logger.invoke("[$timeMs ms] $message")
} }

View File

@ -14,6 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
@file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE") @file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")
package okhttp3 package okhttp3
import java.io.IOException import java.io.IOException
@ -35,17 +36,16 @@ sealed class ConnectionEvent {
data class ConnectStart( data class ConnectStart(
override val timestampNs: Long, override val timestampNs: Long,
val route: Route, val route: Route,
val call: Call val call: Call,
) : ConnectionEvent() ) : ConnectionEvent()
data class ConnectFailed( data class ConnectFailed(
override val timestampNs: Long, override val timestampNs: Long,
val route: Route, val route: Route,
val call: Call, val call: Call,
val exception: IOException val exception: IOException,
) : ConnectionEvent() { ) : ConnectionEvent() {
override fun closes(event: ConnectionEvent): Boolean = override fun closes(event: ConnectionEvent): Boolean = event is ConnectStart && call == event.call && route == event.route
event is ConnectStart && call == event.call && route == event.route
} }
data class ConnectEnd( data class ConnectEnd(
@ -54,8 +54,7 @@ sealed class ConnectionEvent {
val route: Route, val route: Route,
val call: Call, val call: Call,
) : ConnectionEvent() { ) : ConnectionEvent() {
override fun closes(event: ConnectionEvent): Boolean = override fun closes(event: ConnectionEvent): Boolean = event is ConnectStart && call == event.call && route == event.route
event is ConnectStart && call == event.call && route == event.route
} }
data class ConnectionClosed( data class ConnectionClosed(
@ -66,15 +65,14 @@ sealed class ConnectionEvent {
data class ConnectionAcquired( data class ConnectionAcquired(
override val timestampNs: Long, override val timestampNs: Long,
override val connection: Connection, override val connection: Connection,
val call: Call val call: Call,
) : ConnectionEvent() ) : ConnectionEvent()
data class ConnectionReleased( data class ConnectionReleased(
override val timestampNs: Long, override val timestampNs: Long,
override val connection: Connection, override val connection: Connection,
val call: Call val call: Call,
) : ConnectionEvent() { ) : ConnectionEvent() {
override fun closes(event: ConnectionEvent): Boolean = override fun closes(event: ConnectionEvent): Boolean =
event is ConnectionAcquired && connection == event.connection && call == event.call event is ConnectionAcquired && connection == event.connection && call == event.call
} }

View File

@ -24,7 +24,6 @@ import javax.security.cert.X509Certificate
/** An [SSLSession] that delegates all calls. */ /** An [SSLSession] that delegates all calls. */
abstract class DelegatingSSLSession(protected val delegate: SSLSession?) : SSLSession { abstract class DelegatingSSLSession(protected val delegate: SSLSession?) : SSLSession {
override fun getId(): ByteArray { override fun getId(): ByteArray {
return delegate!!.id return delegate!!.id
} }
@ -49,7 +48,10 @@ abstract class DelegatingSSLSession(protected val delegate: SSLSession?) : SSLSe
return delegate!!.isValid return delegate!!.isValid
} }
override fun putValue(s: String, o: Any) { override fun putValue(
s: String,
o: Any,
) {
delegate!!.putValue(s, o) delegate!!.putValue(s, o)
} }

View File

@ -198,7 +198,10 @@ abstract class DelegatingSSLSocket(protected val delegate: SSLSocket?) : SSLSock
} }
@Throws(SocketException::class) @Throws(SocketException::class)
override fun setSoLinger(on: Boolean, timeout: Int) { override fun setSoLinger(
on: Boolean,
timeout: Int,
) {
delegate!!.setSoLinger(on, timeout) delegate!!.setSoLinger(on, timeout)
} }
@ -247,7 +250,10 @@ abstract class DelegatingSSLSocket(protected val delegate: SSLSocket?) : SSLSock
} }
@Throws(IOException::class) @Throws(IOException::class)
override fun connect(remoteAddr: SocketAddress, timeout: Int) { override fun connect(
remoteAddr: SocketAddress,
timeout: Int,
) {
delegate!!.connect(remoteAddr, timeout) delegate!!.connect(remoteAddr, timeout)
} }

View File

@ -33,28 +33,40 @@ open class DelegatingSSLSocketFactory(private val delegate: SSLSocketFactory) :
} }
@Throws(IOException::class) @Throws(IOException::class)
override fun createSocket(host: String, port: Int): SSLSocket { override fun createSocket(
host: String,
port: Int,
): SSLSocket {
val sslSocket = delegate.createSocket(host, port) as SSLSocket val sslSocket = delegate.createSocket(host, port) as SSLSocket
return configureSocket(sslSocket) return configureSocket(sslSocket)
} }
@Throws(IOException::class) @Throws(IOException::class)
override fun createSocket( override fun createSocket(
host: String, port: Int, localAddress: InetAddress, localPort: Int host: String,
port: Int,
localAddress: InetAddress,
localPort: Int,
): SSLSocket { ): SSLSocket {
val sslSocket = delegate.createSocket(host, port, localAddress, localPort) as SSLSocket val sslSocket = delegate.createSocket(host, port, localAddress, localPort) as SSLSocket
return configureSocket(sslSocket) return configureSocket(sslSocket)
} }
@Throws(IOException::class) @Throws(IOException::class)
override fun createSocket(host: InetAddress, port: Int): SSLSocket { override fun createSocket(
host: InetAddress,
port: Int,
): SSLSocket {
val sslSocket = delegate.createSocket(host, port) as SSLSocket val sslSocket = delegate.createSocket(host, port) as SSLSocket
return configureSocket(sslSocket) return configureSocket(sslSocket)
} }
@Throws(IOException::class) @Throws(IOException::class)
override fun createSocket( override fun createSocket(
host: InetAddress, port: Int, localAddress: InetAddress, localPort: Int host: InetAddress,
port: Int,
localAddress: InetAddress,
localPort: Int,
): SSLSocket { ): SSLSocket {
val sslSocket = delegate.createSocket(host, port, localAddress, localPort) as SSLSocket val sslSocket = delegate.createSocket(host, port, localAddress, localPort) as SSLSocket
return configureSocket(sslSocket) return configureSocket(sslSocket)
@ -70,7 +82,10 @@ open class DelegatingSSLSocketFactory(private val delegate: SSLSocketFactory) :
@Throws(IOException::class) @Throws(IOException::class)
override fun createSocket( override fun createSocket(
socket: Socket, host: String, port: Int, autoClose: Boolean socket: Socket,
host: String,
port: Int,
autoClose: Boolean,
): SSLSocket { ): SSLSocket {
val sslSocket = delegate.createSocket(socket, host, port, autoClose) as SSLSocket val sslSocket = delegate.createSocket(socket, host, port, autoClose) as SSLSocket
return configureSocket(sslSocket) return configureSocket(sslSocket)

View File

@ -38,13 +38,20 @@ open class DelegatingServerSocketFactory(private val delegate: ServerSocketFacto
} }
@Throws(IOException::class) @Throws(IOException::class)
override fun createServerSocket(port: Int, backlog: Int): ServerSocket { override fun createServerSocket(
port: Int,
backlog: Int,
): ServerSocket {
val serverSocket = delegate.createServerSocket(port, backlog) val serverSocket = delegate.createServerSocket(port, backlog)
return configureServerSocket(serverSocket) return configureServerSocket(serverSocket)
} }
@Throws(IOException::class) @Throws(IOException::class)
override fun createServerSocket(port: Int, backlog: Int, ifAddress: InetAddress): ServerSocket { override fun createServerSocket(
port: Int,
backlog: Int,
ifAddress: InetAddress,
): ServerSocket {
val serverSocket = delegate.createServerSocket(port, backlog, ifAddress) val serverSocket = delegate.createServerSocket(port, backlog, ifAddress)
return configureServerSocket(serverSocket) return configureServerSocket(serverSocket)
} }

View File

@ -32,30 +32,40 @@ open class DelegatingSocketFactory(private val delegate: SocketFactory) : Socket
} }
@Throws(IOException::class) @Throws(IOException::class)
override fun createSocket(host: String, port: Int): Socket { override fun createSocket(
host: String,
port: Int,
): Socket {
val socket = delegate.createSocket(host, port) val socket = delegate.createSocket(host, port)
return configureSocket(socket) return configureSocket(socket)
} }
@Throws(IOException::class) @Throws(IOException::class)
override fun createSocket( override fun createSocket(
host: String, port: Int, localAddress: InetAddress, host: String,
localPort: Int port: Int,
localAddress: InetAddress,
localPort: Int,
): Socket { ): Socket {
val socket = delegate.createSocket(host, port, localAddress, localPort) val socket = delegate.createSocket(host, port, localAddress, localPort)
return configureSocket(socket) return configureSocket(socket)
} }
@Throws(IOException::class) @Throws(IOException::class)
override fun createSocket(host: InetAddress, port: Int): Socket { override fun createSocket(
host: InetAddress,
port: Int,
): Socket {
val socket = delegate.createSocket(host, port) val socket = delegate.createSocket(host, port)
return configureSocket(socket) return configureSocket(socket)
} }
@Throws(IOException::class) @Throws(IOException::class)
override fun createSocket( override fun createSocket(
host: InetAddress, port: Int, localAddress: InetAddress, host: InetAddress,
localPort: Int port: Int,
localAddress: InetAddress,
localPort: Int,
): Socket { ): Socket {
val socket = delegate.createSocket(host, port, localAddress, localPort) val socket = delegate.createSocket(host, port, localAddress, localPort)
return configureSocket(socket) return configureSocket(socket)

View File

@ -29,7 +29,7 @@ class FakeDns : Dns {
/** Sets the results for `hostname`. */ /** Sets the results for `hostname`. */
operator fun set( operator fun set(
hostname: String, hostname: String,
addresses: List<InetAddress> addresses: List<InetAddress>,
): FakeDns { ): FakeDns {
hostAddresses[hostname] = addresses hostAddresses[hostname] = addresses
return this return this
@ -41,9 +41,10 @@ class FakeDns : Dns {
return this return this
} }
@Throws(UnknownHostException::class) fun lookup( @Throws(UnknownHostException::class)
fun lookup(
hostname: String, hostname: String,
index: Int index: Int,
): InetAddress { ): InetAddress {
return hostAddresses[hostname]!![index] return hostAddresses[hostname]!![index]
} }
@ -66,7 +67,7 @@ class FakeDns : Dns {
return (from until nextAddress) return (from until nextAddress)
.map { .map {
return@map InetAddress.getByAddress( return@map InetAddress.getByAddress(
Buffer().writeInt(it.toInt()).readByteArray() Buffer().writeInt(it.toInt()).readByteArray(),
) )
} }
} }
@ -78,7 +79,7 @@ class FakeDns : Dns {
return (from until nextAddress) return (from until nextAddress)
.map { .map {
return@map InetAddress.getByAddress( return@map InetAddress.getByAddress(
Buffer().writeLong(0L).writeLong(it).readByteArray() Buffer().writeLong(0L).writeLong(it).readByteArray(),
) )
} }
} }

View File

@ -37,7 +37,7 @@ class FakeProxySelector : ProxySelector() {
override fun connectFailed( override fun connectFailed(
uri: URI, uri: URI,
sa: SocketAddress, sa: SocketAddress,
ioe: IOException ioe: IOException,
) { ) {
} }
} }

View File

@ -67,7 +67,7 @@ class FakeSSLSession(vararg val certificates: Certificate) : SSLSession {
} }
@Throws( @Throws(
SSLPeerUnverifiedException::class SSLPeerUnverifiedException::class,
) )
override fun getPeerCertificateChain(): Array<X509Certificate> { override fun getPeerCertificateChain(): Array<X509Certificate> {
throw UnsupportedOperationException() throw UnsupportedOperationException()
@ -96,7 +96,7 @@ class FakeSSLSession(vararg val certificates: Certificate) : SSLSession {
override fun putValue( override fun putValue(
s: String, s: String,
obj: Any obj: Any,
) { ) {
throw UnsupportedOperationException() throw UnsupportedOperationException()
} }

View File

@ -20,6 +20,7 @@ import okio.BufferedSink
open class ForwardingRequestBody(delegate: RequestBody?) : RequestBody() { open class ForwardingRequestBody(delegate: RequestBody?) : RequestBody() {
private val delegate: RequestBody private val delegate: RequestBody
fun delegate(): RequestBody { fun delegate(): RequestBody {
return delegate return delegate
} }
@ -28,7 +29,8 @@ open class ForwardingRequestBody(delegate: RequestBody?) : RequestBody() {
return delegate.contentType() return delegate.contentType()
} }
@Throws(IOException::class) override fun contentLength(): Long { @Throws(IOException::class)
override fun contentLength(): Long {
return delegate.contentLength() return delegate.contentLength()
} }

View File

@ -19,6 +19,7 @@ import okio.BufferedSource
open class ForwardingResponseBody(delegate: ResponseBody?) : ResponseBody() { open class ForwardingResponseBody(delegate: ResponseBody?) : ResponseBody() {
private val delegate: ResponseBody private val delegate: ResponseBody
fun delegate(): ResponseBody { fun delegate(): ResponseBody {
return delegate return delegate
} }

View File

@ -22,25 +22,30 @@ import java.util.logging.LogRecord
object JsseDebugLogging { object JsseDebugLogging {
data class JsseDebugMessage(val message: String, val param: String?) { data class JsseDebugMessage(val message: String, val param: String?) {
enum class Type { enum class Type {
Handshake, Plaintext, Encrypted, Setup, Unknown Handshake,
Plaintext,
Encrypted,
Setup,
Unknown,
} }
val type: Type val type: Type
get() = when { get() =
message == "adding as trusted certificates" -> Type.Setup when {
message == "Raw read" || message == "Raw write" -> Type.Encrypted message == "adding as trusted certificates" -> Type.Setup
message == "Plaintext before ENCRYPTION" || message == "Plaintext after DECRYPTION" -> Type.Plaintext message == "Raw read" || message == "Raw write" -> Type.Encrypted
message.startsWith("System property ") -> Type.Setup message == "Plaintext before ENCRYPTION" || message == "Plaintext after DECRYPTION" -> Type.Plaintext
message.startsWith("Reload ") -> Type.Setup message.startsWith("System property ") -> Type.Setup
message == "No session to resume." -> Type.Handshake message.startsWith("Reload ") -> Type.Setup
message.startsWith("Consuming ") -> Type.Handshake message == "No session to resume." -> Type.Handshake
message.startsWith("Produced ") -> Type.Handshake message.startsWith("Consuming ") -> Type.Handshake
message.startsWith("Negotiated ") -> Type.Handshake message.startsWith("Produced ") -> Type.Handshake
message.startsWith("Found resumable session") -> Type.Handshake message.startsWith("Negotiated ") -> Type.Handshake
message.startsWith("Resuming session") -> Type.Handshake message.startsWith("Found resumable session") -> Type.Handshake
message.startsWith("Using PSK to derive early secret") -> Type.Handshake message.startsWith("Resuming session") -> Type.Handshake
else -> Type.Unknown message.startsWith("Using PSK to derive early secret") -> Type.Handshake
} else -> Type.Unknown
}
override fun toString(): String { override fun toString(): String {
return if (param != null) { return if (param != null) {
@ -66,17 +71,20 @@ object JsseDebugLogging {
fun enableJsseDebugLogging(debugHandler: (JsseDebugMessage) -> Unit = this::quietDebug): Closeable { fun enableJsseDebugLogging(debugHandler: (JsseDebugMessage) -> Unit = this::quietDebug): Closeable {
System.setProperty("javax.net.debug", "") System.setProperty("javax.net.debug", "")
return OkHttpDebugLogging.enable("javax.net.ssl", object : Handler() { return OkHttpDebugLogging.enable(
override fun publish(record: LogRecord) { "javax.net.ssl",
val param = record.parameters?.firstOrNull() as? String object : Handler() {
debugHandler(JsseDebugMessage(record.message, param)) override fun publish(record: LogRecord) {
} val param = record.parameters?.firstOrNull() as? String
debugHandler(JsseDebugMessage(record.message, param))
}
override fun flush() { override fun flush() {
} }
override fun close() { override fun close() {
} }
}) },
)
} }
} }

View File

@ -14,6 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
@file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE") @file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")
package okhttp3 package okhttp3
import android.annotation.SuppressLint import android.annotation.SuppressLint
@ -62,57 +63,60 @@ class OkHttpClientTestRule : BeforeEachCallback, AfterEachCallback {
var recordFrames = false var recordFrames = false
var recordSslDebug = false var recordSslDebug = false
private val sslExcludeFilter = Regex( private val sslExcludeFilter =
buildString { Regex(
append("^(?:") buildString {
append( append("^(?:")
listOf( append(
"Inaccessible trust store", listOf(
"trustStore is", "Inaccessible trust store",
"Reload the trust store", "trustStore is",
"Reload trust certs", "Reload the trust store",
"Reloaded", "Reload trust certs",
"adding as trusted certificates", "Reloaded",
"Ignore disabled cipher suite", "adding as trusted certificates",
"Ignore unsupported cipher suite", "Ignore disabled cipher suite",
).joinToString(separator = "|") "Ignore unsupported cipher suite",
) ).joinToString(separator = "|"),
append(").*") )
} append(").*")
) },
)
private val testLogHandler = object : Handler() { private val testLogHandler =
override fun publish(record: LogRecord) { object : Handler() {
val recorded = when (record.loggerName) { override fun publish(record: LogRecord) {
TaskRunner::class.java.name -> recordTaskRunner val recorded =
Http2::class.java.name -> recordFrames when (record.loggerName) {
"javax.net.ssl" -> recordSslDebug && !sslExcludeFilter.matches(record.message) TaskRunner::class.java.name -> recordTaskRunner
else -> false Http2::class.java.name -> recordFrames
} "javax.net.ssl" -> recordSslDebug && !sslExcludeFilter.matches(record.message)
else -> false
}
if (recorded) { if (recorded) {
synchronized(clientEventsList) { synchronized(clientEventsList) {
clientEventsList.add(record.message) clientEventsList.add(record.message)
if (record.loggerName == "javax.net.ssl") { if (record.loggerName == "javax.net.ssl") {
val parameters = record.parameters val parameters = record.parameters
if (parameters != null) { if (parameters != null) {
clientEventsList.add(parameters.first().toString()) clientEventsList.add(parameters.first().toString())
}
} }
} }
} }
} }
}
override fun flush() { override fun flush() {
} }
override fun close() { override fun close() {
}
}.apply {
level = Level.FINEST
} }
}.apply {
level = Level.FINEST
}
private fun applyLogger(fn: Logger.() -> Unit) { private fun applyLogger(fn: Logger.() -> Unit) {
Logger.getLogger(OkHttpClient::class.java.`package`.name).fn() Logger.getLogger(OkHttpClient::class.java.`package`.name).fn()
@ -122,8 +126,7 @@ class OkHttpClientTestRule : BeforeEachCallback, AfterEachCallback {
Logger.getLogger("javax.net.ssl").fn() Logger.getLogger("javax.net.ssl").fn()
} }
fun wrap(eventListener: EventListener) = fun wrap(eventListener: EventListener) = EventListener.Factory { ClientRuleEventListener(eventListener, ::addEvent) }
EventListener.Factory { ClientRuleEventListener(eventListener, ::addEvent) }
fun wrap(eventListenerFactory: EventListener.Factory) = fun wrap(eventListenerFactory: EventListener.Factory) =
EventListener.Factory { call -> ClientRuleEventListener(eventListenerFactory.create(call), ::addEvent) } EventListener.Factory { call -> ClientRuleEventListener(eventListenerFactory.create(call), ::addEvent) }
@ -140,10 +143,11 @@ class OkHttpClientTestRule : BeforeEachCallback, AfterEachCallback {
fun newClient(): OkHttpClient { fun newClient(): OkHttpClient {
var client = testClient var client = testClient
if (client == null) { if (client == null) {
client = initialClientBuilder() client =
.dns(SINGLE_INET_ADDRESS_DNS) // Prevent unexpected fallback addresses. initialClientBuilder()
.eventListenerFactory { ClientRuleEventListener(logger = ::addEvent) } .dns(SINGLE_INET_ADDRESS_DNS) // Prevent unexpected fallback addresses.
.build() .eventListenerFactory { ClientRuleEventListener(logger = ::addEvent) }
.build()
connectionListener.forbidLock(RealConnectionPool.get(client.connectionPool)) connectionListener.forbidLock(RealConnectionPool.get(client.connectionPool))
connectionListener.forbidLock(client.dispatcher) connectionListener.forbidLock(client.dispatcher)
testClient = client testClient = client
@ -151,23 +155,24 @@ class OkHttpClientTestRule : BeforeEachCallback, AfterEachCallback {
return client return client
} }
private fun initialClientBuilder(): OkHttpClient.Builder = if (isLoom()) { private fun initialClientBuilder(): OkHttpClient.Builder =
val backend = TaskRunner.RealBackend(loomThreadFactory()) if (isLoom()) {
val taskRunner = TaskRunner(backend) val backend = TaskRunner.RealBackend(loomThreadFactory())
val taskRunner = TaskRunner(backend)
OkHttpClient.Builder() OkHttpClient.Builder()
.connectionPool( .connectionPool(
buildConnectionPool( buildConnectionPool(
connectionListener = connectionListener, connectionListener = connectionListener,
taskRunner = taskRunner, taskRunner = taskRunner,
),
) )
) .dispatcher(Dispatcher(backend.executor))
.dispatcher(Dispatcher(backend.executor)) .taskRunnerInternal(taskRunner)
.taskRunnerInternal(taskRunner) } else {
} else { OkHttpClient.Builder()
OkHttpClient.Builder() .connectionPool(ConnectionPool(connectionListener = connectionListener))
.connectionPool(ConnectionPool(connectionListener = connectionListener)) }
}
private fun loomThreadFactory(): ThreadFactory { private fun loomThreadFactory(): ThreadFactory {
val ofVirtual = Thread::class.java.getMethod("ofVirtual").invoke(null) val ofVirtual = Thread::class.java.getMethod("ofVirtual").invoke(null)
@ -322,10 +327,11 @@ class OkHttpClientTestRule : BeforeEachCallback, AfterEachCallback {
* A network that resolves only one IP address per host. Use this when testing route selection * A network that resolves only one IP address per host. Use this when testing route selection
* fallbacks to prevent the host machine's various IP addresses from interfering. * fallbacks to prevent the host machine's various IP addresses from interfering.
*/ */
private val SINGLE_INET_ADDRESS_DNS = Dns { hostname -> private val SINGLE_INET_ADDRESS_DNS =
val addresses = Dns.SYSTEM.lookup(hostname) Dns { hostname ->
listOf(addresses[0]) val addresses = Dns.SYSTEM.lookup(hostname)
} listOf(addresses[0])
}
private operator fun Throwable?.plus(throwable: Throwable): Throwable { private operator fun Throwable?.plus(throwable: Throwable): Throwable {
if (this != null) { if (this != null) {

Some files were not shown because too many files have changed in this diff Show More