1
0
mirror of https://github.com/square/okhttp.git synced 2025-04-19 07:42:15 +03:00

Reformat with Spotless (#8180)

* Enable spotless

* Run spotlessApply

* Fixup trimMargin

* Re-run spotlessApply
This commit is contained in:
Jesse Wilson 2024-01-07 20:13:22 -05:00 committed by GitHub
parent 0e312d7804
commit a228fd64cc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
442 changed files with 24992 additions and 18542 deletions

View File

@ -2,9 +2,10 @@ root = true
[*]
indent_size = 2
ij_continuation_indent_size = 2
charset = utf-8
trim_trailing_whitespace = true
insert_final_newline = true
[*.{kt, kts}]
kotlin_imports_layout = ascii
ij_kotlin_imports_layout = *

View File

@ -24,7 +24,6 @@ import org.junit.Test
* Run with "./gradlew :android-test-app:connectedCheck -PandroidBuild=true" and make sure ANDROID_SDK_ROOT is set.
*/
class PublicSuffixDatabaseTest {
@Test
fun testTopLevelDomain() {
assertThat("https://www.google.com/robots.txt".toHttpUrl().topPrivateDomain()).isEqualTo("google.com")

View File

@ -19,20 +19,13 @@ import android.os.Bundle
import androidx.activity.ComponentActivity
import okhttp3.Call
import okhttp3.Callback
import okhttp3.CookieJar
import okhttp3.HttpUrl.Companion.toHttpUrl
import okhttp3.OkHttpClient
import okhttp3.Request
import okhttp3.Response
import okhttp3.internal.publicsuffix.PublicSuffixDatabase
import okio.FileSystem
import okio.IOException
import okio.Path.Companion.toPath
import java.net.CookieHandler
import java.net.URI
class MainActivity : ComponentActivity() {
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
@ -41,15 +34,23 @@ class MainActivity : ComponentActivity() {
val url = "https://github.com/square/okhttp".toHttpUrl()
println(url.topPrivateDomain())
client.newCall(Request(url)).enqueue(object : Callback {
override fun onFailure(call: Call, e: IOException) {
println("failed: $e")
}
client.newCall(Request(url)).enqueue(
object : Callback {
override fun onFailure(
call: Call,
e: IOException,
) {
println("failed: $e")
}
override fun onResponse(call: Call, response: Response) {
println("response: ${response.code}")
response.close()
}
})
override fun onResponse(
call: Call,
response: Response,
) {
println("response: ${response.code}")
response.close()
}
},
)
}
}

View File

@ -17,5 +17,4 @@ package okhttp.android.testapp
import android.app.Application
class TestApplication: Application() {
}
class TestApplication : Application()

View File

@ -21,6 +21,25 @@ import com.google.android.gms.common.GooglePlayServicesNotAvailableException
import com.google.android.gms.security.ProviderInstaller
import com.squareup.moshi.Moshi
import com.squareup.moshi.kotlin.reflect.KotlinJsonAdapterFactory
import java.io.IOException
import java.net.InetAddress
import java.net.UnknownHostException
import java.security.KeyStore
import java.security.SecureRandom
import java.security.Security
import java.security.cert.Certificate
import java.security.cert.CertificateException
import java.security.cert.X509Certificate
import java.util.concurrent.atomic.AtomicInteger
import java.util.logging.Handler
import java.util.logging.Level
import java.util.logging.LogRecord
import java.util.logging.Logger
import javax.net.ssl.HostnameVerifier
import javax.net.ssl.SSLPeerUnverifiedException
import javax.net.ssl.SSLSocket
import javax.net.ssl.TrustManagerFactory
import javax.net.ssl.X509TrustManager
import mockwebserver3.MockResponse
import mockwebserver3.MockWebServer
import mockwebserver3.junit5.internal.MockWebServerExtension
@ -31,6 +50,7 @@ import okhttp3.Connection
import okhttp3.DelegatingSSLSocket
import okhttp3.DelegatingSSLSocketFactory
import okhttp3.EventListener
import okhttp3.Headers
import okhttp3.HttpUrl.Companion.toHttpUrl
import okhttp3.OkHttpClient
import okhttp3.OkHttpClientTestRule
@ -59,33 +79,13 @@ import org.junit.jupiter.api.Assertions.assertNull
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Assertions.fail
import org.junit.jupiter.api.Assumptions.assumeTrue
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Disabled
import org.junit.jupiter.api.Tag
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.extension.ExtendWith
import org.junit.jupiter.api.extension.RegisterExtension
import org.opentest4j.TestAbortedException
import java.io.IOException
import java.net.InetAddress
import java.net.UnknownHostException
import java.security.KeyStore
import java.security.SecureRandom
import java.security.Security
import java.security.cert.Certificate
import java.security.cert.CertificateException
import java.security.cert.X509Certificate
import java.util.concurrent.atomic.AtomicInteger
import java.util.logging.Handler
import java.util.logging.Level
import java.util.logging.LogRecord
import java.util.logging.Logger
import javax.net.ssl.HostnameVerifier
import javax.net.ssl.SSLPeerUnverifiedException
import javax.net.ssl.SSLSocket
import javax.net.ssl.TrustManagerFactory
import javax.net.ssl.X509TrustManager
import okhttp3.Headers
import org.junit.jupiter.api.BeforeEach
/**
* Run with "./gradlew :android-test:connectedCheck -PandroidBuild=true" and make sure ANDROID_SDK_ROOT is set.
@ -93,22 +93,25 @@ import org.junit.jupiter.api.BeforeEach
@ExtendWith(MockWebServerExtension::class)
@Tag("Slow")
class OkHttpTest {
@Suppress("RedundantVisibilityModifier")
@JvmField
@RegisterExtension
public val platform = PlatformRule()
@Suppress("RedundantVisibilityModifier")
@JvmField
@RegisterExtension public val platform = PlatformRule()
@Suppress("RedundantVisibilityModifier")
@JvmField
@RegisterExtension public val clientTestRule = OkHttpClientTestRule().apply {
logger = Logger.getLogger(OkHttpTest::class.java.name)
}
@RegisterExtension
public val clientTestRule =
OkHttpClientTestRule().apply {
logger = Logger.getLogger(OkHttpTest::class.java.name)
}
private var client: OkHttpClient = clientTestRule.newClient()
private val moshi = Moshi.Builder()
.add(KotlinJsonAdapterFactory())
.build()
private val moshi =
Moshi.Builder()
.add(KotlinJsonAdapterFactory())
.build()
private val handshakeCertificates = localhost()
@ -136,18 +139,20 @@ class OkHttpTest {
val request = Request.Builder().url("https://api.twitter.com/robots.txt").build()
val clientCertificates = HandshakeCertificates.Builder()
.addPlatformTrustedCertificates()
.apply {
if (Build.VERSION.SDK_INT >= 24) {
addInsecureHost(server.hostName)
val clientCertificates =
HandshakeCertificates.Builder()
.addPlatformTrustedCertificates()
.apply {
if (Build.VERSION.SDK_INT >= 24) {
addInsecureHost(server.hostName)
}
}
}
.build()
.build()
client = client.newBuilder()
.sslSocketFactory(clientCertificates.sslSocketFactory(), clientCertificates.trustManager)
.build()
client =
client.newBuilder()
.sslSocketFactory(clientCertificates.sslSocketFactory(), clientCertificates.trustManager)
.build()
val response = client.newCall(request).execute()
@ -184,22 +189,29 @@ class OkHttpTest {
var socketClass: String? = null
val clientCertificates = HandshakeCertificates.Builder()
.addPlatformTrustedCertificates()
.addInsecureHost(server.hostName)
.build()
val clientCertificates =
HandshakeCertificates.Builder()
.addPlatformTrustedCertificates()
.addInsecureHost(server.hostName)
.build()
// Need fresh client to reset sslSocketFactoryOrNull
client = OkHttpClient.Builder()
.eventListenerFactory(
clientTestRule.wrap(object : EventListener() {
override fun connectionAcquired(call: Call, connection: Connection) {
socketClass = connection.socket().javaClass.name
}
})
)
.sslSocketFactory(clientCertificates.sslSocketFactory(), clientCertificates.trustManager)
.build()
client =
OkHttpClient.Builder()
.eventListenerFactory(
clientTestRule.wrap(
object : EventListener() {
override fun connectionAcquired(
call: Call,
connection: Connection,
) {
socketClass = connection.socket().javaClass.name
}
},
),
)
.sslSocketFactory(clientCertificates.sslSocketFactory(), clientCertificates.trustManager)
.build()
val response = client.newCall(request).execute()
@ -240,26 +252,33 @@ class OkHttpTest {
throw TestAbortedException("Google Play Services not available", gpsnae)
}
val clientCertificates = HandshakeCertificates.Builder()
.addPlatformTrustedCertificates()
.addInsecureHost(server.hostName)
.build()
val clientCertificates =
HandshakeCertificates.Builder()
.addPlatformTrustedCertificates()
.addInsecureHost(server.hostName)
.build()
val request = Request.Builder().url("https://facebook.com/robots.txt").build()
var socketClass: String? = null
// Need fresh client to reset sslSocketFactoryOrNull
client = OkHttpClient.Builder()
.eventListenerFactory(
clientTestRule.wrap(object : EventListener() {
override fun connectionAcquired(call: Call, connection: Connection) {
socketClass = connection.socket().javaClass.name
}
})
)
.sslSocketFactory(clientCertificates.sslSocketFactory(), clientCertificates.trustManager)
.build()
client =
OkHttpClient.Builder()
.eventListenerFactory(
clientTestRule.wrap(
object : EventListener() {
override fun connectionAcquired(
call: Call,
connection: Connection,
) {
socketClass = connection.socket().javaClass.name
}
},
),
)
.sslSocketFactory(clientCertificates.sslSocketFactory(), clientCertificates.trustManager)
.build()
val response = client.newCall(request).execute()
@ -298,24 +317,31 @@ class OkHttpTest {
var socketClass: String? = null
val clientCertificates = HandshakeCertificates.Builder()
.addPlatformTrustedCertificates().apply {
if (Build.VERSION.SDK_INT >= 24) {
addInsecureHost(server.hostName)
}
}
.build()
client = client.newBuilder()
.eventListenerFactory(
clientTestRule.wrap(object : EventListener() {
override fun connectionAcquired(call: Call, connection: Connection) {
socketClass = connection.socket().javaClass.name
val clientCertificates =
HandshakeCertificates.Builder()
.addPlatformTrustedCertificates().apply {
if (Build.VERSION.SDK_INT >= 24) {
addInsecureHost(server.hostName)
}
})
)
.sslSocketFactory(clientCertificates.sslSocketFactory(), clientCertificates.trustManager)
.build()
}
.build()
client =
client.newBuilder()
.eventListenerFactory(
clientTestRule.wrap(
object : EventListener() {
override fun connectionAcquired(
call: Call,
connection: Connection,
) {
socketClass = connection.socket().javaClass.name
}
},
),
)
.sslSocketFactory(clientCertificates.sslSocketFactory(), clientCertificates.trustManager)
.build()
val response = client.newCall(request).execute()
@ -372,7 +398,7 @@ class OkHttpTest {
val tls_version: String,
val able_to_detect_n_minus_one_splitting: Boolean,
val insecure_cipher_suites: Map<String, List<String>>,
val given_cipher_suites: List<String>?
val given_cipher_suites: List<String>?,
)
@Test
@ -384,9 +410,10 @@ class OkHttpTest {
val response = client.newCall(request).execute()
val results = response.use {
moshi.adapter(HowsMySslResults::class.java).fromJson(response.body.string())!!
}
val results =
response.use {
moshi.adapter(HowsMySslResults::class.java).fromJson(response.body.string())!!
}
Platform.get().log("results $results", Platform.WARN)
@ -417,7 +444,7 @@ class OkHttpTest {
assertTrue(tlsVersion == TlsVersion.TLS_1_2 || tlsVersion == TlsVersion.TLS_1_3)
assertEquals(
"CN=localhost",
(response.handshake!!.peerCertificates.first() as X509Certificate).subjectDN.name
(response.handshake!!.peerCertificates.first() as X509Certificate).subjectDN.name,
)
}
}
@ -426,9 +453,10 @@ class OkHttpTest {
fun testCertificatePinningFailure() {
enableTls()
val certificatePinner = CertificatePinner.Builder()
.add(server.hostName, "sha256/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=")
.build()
val certificatePinner =
CertificatePinner.Builder()
.add(server.hostName, "sha256/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=")
.build()
client = client.newBuilder().certificatePinner(certificatePinner).build()
server.enqueue(MockResponse(body = "abc"))
@ -446,12 +474,13 @@ class OkHttpTest {
fun testCertificatePinningSuccess() {
enableTls()
val certificatePinner = CertificatePinner.Builder()
.add(
server.hostName,
CertificatePinner.pin(handshakeCertificates.trustManager.acceptedIssuers[0])
)
.build()
val certificatePinner =
CertificatePinner.Builder()
.add(
server.hostName,
CertificatePinner.pin(handshakeCertificates.trustManager.acceptedIssuers[0]),
)
.build()
client = client.newBuilder().certificatePinner(certificatePinner).build()
server.enqueue(MockResponse(body = "abc"))
@ -488,9 +517,9 @@ class OkHttpTest {
"ConnectStart", "SecureConnectStart", "SecureConnectEnd", "ConnectEnd",
"ConnectionAcquired", "RequestHeadersStart", "RequestHeadersEnd", "ResponseHeadersStart",
"ResponseHeadersEnd", "ResponseBodyStart", "ResponseBodyEnd", "ConnectionReleased",
"CallEnd"
"CallEnd",
),
eventListener.recordedEventTypes()
eventListener.recordedEventTypes(),
)
eventListener.clearAllEvents()
@ -504,9 +533,9 @@ class OkHttpTest {
"CallStart",
"ConnectionAcquired", "RequestHeadersStart", "RequestHeadersEnd", "ResponseHeadersStart",
"ResponseHeadersEnd", "ResponseBodyStart", "ResponseBodyEnd", "ConnectionReleased",
"CallEnd"
"CallEnd",
),
eventListener.recordedEventTypes()
eventListener.recordedEventTypes(),
)
}
@ -516,15 +545,21 @@ class OkHttpTest {
enableTls()
client = client.newBuilder().eventListenerFactory(
clientTestRule.wrap(object : EventListener() {
override fun connectionAcquired(call: Call, connection: Connection) {
val sslSocket = connection.socket() as SSLSocket
client =
client.newBuilder().eventListenerFactory(
clientTestRule.wrap(
object : EventListener() {
override fun connectionAcquired(
call: Call,
connection: Connection,
) {
val sslSocket = connection.socket() as SSLSocket
sessionIds.add(sslSocket.session.id.toByteString().hex())
}
})
).build()
sessionIds.add(sslSocket.session.id.toByteString().hex())
}
},
),
).build()
server.enqueue(MockResponse(body = "abc1"))
server.enqueue(MockResponse(body = "abc2"))
@ -550,9 +585,10 @@ class OkHttpTest {
fun testDnsOverHttps() {
assumeNetwork()
client = client.newBuilder()
.eventListenerFactory(clientTestRule.wrap(LoggingEventListener.Factory()))
.build()
client =
client.newBuilder()
.eventListenerFactory(clientTestRule.wrap(LoggingEventListener.Factory()))
.build()
val dohDns = buildCloudflareIp(client)
val dohEnabledClient =
@ -566,25 +602,34 @@ class OkHttpTest {
fun testCustomTrustManager() {
assumeNetwork()
val trustManager = object : X509TrustManager {
override fun checkClientTrusted(chain: Array<out X509Certificate>?, authType: String?) {}
val trustManager =
object : X509TrustManager {
override fun checkClientTrusted(
chain: Array<out X509Certificate>?,
authType: String?,
) {}
override fun checkServerTrusted(chain: Array<out X509Certificate>?, authType: String?) {}
override fun checkServerTrusted(
chain: Array<out X509Certificate>?,
authType: String?,
) {}
override fun getAcceptedIssuers(): Array<X509Certificate> = arrayOf()
}
override fun getAcceptedIssuers(): Array<X509Certificate> = arrayOf()
}
val sslContext = Platform.get().newSSLContext().apply {
init(null, arrayOf(trustManager), null)
}
val sslContext =
Platform.get().newSSLContext().apply {
init(null, arrayOf(trustManager), null)
}
val sslSocketFactory = sslContext.socketFactory
val hostnameVerifier = HostnameVerifier { _, _ -> true }
client = client.newBuilder()
.sslSocketFactory(sslSocketFactory, trustManager)
.hostnameVerifier(hostnameVerifier)
.build()
client =
client.newBuilder()
.sslSocketFactory(sslSocketFactory, trustManager)
.hostnameVerifier(hostnameVerifier)
.build()
client.get("https://www.facebook.com/robots.txt")
}
@ -598,19 +643,21 @@ class OkHttpTest {
val sslSocketFactory = client.sslSocketFactory
val trustManager = client.x509TrustManager!!
val delegatingSocketFactory = object : DelegatingSSLSocketFactory(sslSocketFactory) {
override fun configureSocket(sslSocket: SSLSocket): SSLSocket {
return object : DelegatingSSLSocket(sslSocket) {
override fun getApplicationProtocol(): String {
throw UnsupportedOperationException()
val delegatingSocketFactory =
object : DelegatingSSLSocketFactory(sslSocketFactory) {
override fun configureSocket(sslSocket: SSLSocket): SSLSocket {
return object : DelegatingSSLSocket(sslSocket) {
override fun getApplicationProtocol(): String {
throw UnsupportedOperationException()
}
}
}
}
}
client = client.newBuilder()
.sslSocketFactory(delegatingSocketFactory, trustManager)
.build()
client =
client.newBuilder()
.sslSocketFactory(delegatingSocketFactory, trustManager)
.build()
val request = Request.Builder().url(server.url("/").toString()).build()
val response = client.newCall(request).execute()
@ -626,34 +673,47 @@ class OkHttpTest {
var withHostCalled = false
var withoutHostCalled = false
val trustManager = object : X509TrustManager {
override fun checkClientTrusted(chain: Array<out X509Certificate>?, authType: String?) {}
val trustManager =
object : X509TrustManager {
override fun checkClientTrusted(
chain: Array<out X509Certificate>?,
authType: String?,
) {}
override fun checkServerTrusted(chain: Array<out X509Certificate>?, authType: String?) {
withoutHostCalled = true
override fun checkServerTrusted(
chain: Array<out X509Certificate>?,
authType: String?,
) {
withoutHostCalled = true
}
// called by Android via reflection in X509TrustManagerExtensions
@Suppress("unused", "UNUSED_PARAMETER")
fun checkServerTrusted(
chain: Array<out X509Certificate>,
authType: String,
hostname: String,
): List<X509Certificate> {
withHostCalled = true
return chain.toList()
}
override fun getAcceptedIssuers(): Array<X509Certificate> = arrayOf()
}
@Suppress("unused", "UNUSED_PARAMETER")
// called by Android via reflection in X509TrustManagerExtensions
fun checkServerTrusted(chain: Array<out X509Certificate>, authType: String, hostname: String): List<X509Certificate> {
withHostCalled = true
return chain.toList()
val sslContext =
Platform.get().newSSLContext().apply {
init(null, arrayOf(trustManager), null)
}
override fun getAcceptedIssuers(): Array<X509Certificate> = arrayOf()
}
val sslContext = Platform.get().newSSLContext().apply {
init(null, arrayOf(trustManager), null)
}
val sslSocketFactory = sslContext.socketFactory
val hostnameVerifier = HostnameVerifier { _, _ -> true }
client = client.newBuilder()
.sslSocketFactory(sslSocketFactory, trustManager)
.hostnameVerifier(hostnameVerifier)
.build()
client =
client.newBuilder()
.sslSocketFactory(sslSocketFactory, trustManager)
.hostnameVerifier(hostnameVerifier)
.build()
client.get("https://www.facebook.com/robots.txt")
@ -702,25 +762,33 @@ class OkHttpTest {
var socketClass: String? = null
val trustManager = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()).apply {
init(null as KeyStore?)
}.trustManagers.first() as X509TrustManager
val trustManager =
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()).apply {
init(null as KeyStore?)
}.trustManagers.first() as X509TrustManager
val sslContext = Platform.get().newSSLContext().apply {
// TODO remove most of this code after https://github.com/bcgit/bc-java/issues/686
init(null, arrayOf(trustManager), SecureRandom())
}
val sslContext =
Platform.get().newSSLContext().apply {
// TODO remove most of this code after https://github.com/bcgit/bc-java/issues/686
init(null, arrayOf(trustManager), SecureRandom())
}
client = client.newBuilder()
.sslSocketFactory(sslContext.socketFactory, trustManager)
.eventListenerFactory(
clientTestRule.wrap(object : EventListener() {
override fun connectionAcquired(call: Call, connection: Connection) {
socketClass = connection.socket().javaClass.name
}
})
)
.build()
client =
client.newBuilder()
.sslSocketFactory(sslContext.socketFactory, trustManager)
.eventListenerFactory(
clientTestRule.wrap(
object : EventListener() {
override fun connectionAcquired(
call: Call,
connection: Connection,
) {
socketClass = connection.socket().javaClass.name
}
},
),
)
.build()
val request = Request.Builder().url("https://facebook.com/robots.txt").build()
@ -743,22 +811,23 @@ class OkHttpTest {
fun testLoggingLevels() {
enableTls()
val testHandler = object : Handler() {
val calls = mutableMapOf<String, AtomicInteger>()
val testHandler =
object : Handler() {
val calls = mutableMapOf<String, AtomicInteger>()
override fun publish(record: LogRecord) {
calls.getOrPut(record.loggerName) { AtomicInteger(0) }
.incrementAndGet()
}
override fun publish(record: LogRecord) {
calls.getOrPut(record.loggerName) { AtomicInteger(0) }
.incrementAndGet()
}
override fun flush() {
}
override fun flush() {
}
override fun close() {
override fun close() {
}
}.apply {
level = Level.FINEST
}
}.apply {
level = Level.FINEST
}
Logger.getLogger("")
.addHandler(testHandler)
@ -773,12 +842,14 @@ class OkHttpTest {
server.enqueue(MockResponse(body = "abc"))
val request = Request.Builder()
.url(server.url("/"))
.build()
val request =
Request.Builder()
.url(server.url("/"))
.build()
val response = client.newCall(request)
.execute()
val response =
client.newCall(request)
.execute()
response.use {
assertEquals(200, response.code)
@ -802,13 +873,15 @@ class OkHttpTest {
val cache = Cache(ctxt.cacheDir.resolve("testCache"), cacheSize)
try {
client = client.newBuilder()
.cache(cache)
.build()
client =
client.newBuilder()
.cache(cache)
.build()
val request = Request.Builder()
.url(server.url("/"))
.build()
val request =
Request.Builder()
.url(server.url("/"))
.build()
client.newCall(request)
.execute()
@ -853,11 +926,12 @@ class OkHttpTest {
}
private fun enableTls() {
client = client.newBuilder()
.sslSocketFactory(
handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager
)
.build()
client =
client.newBuilder()
.sslSocketFactory(
handshakeCertificates.sslSocketFactory(), handshakeCertificates.trustManager,
)
.build()
server.useHttps(handshakeCertificates.sslSocketFactory())
}

View File

@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package okhttp.android.test.alpn;
package okhttp.android.test.alpn
import android.os.Build
import android.util.Log
@ -36,9 +36,8 @@ import org.junit.jupiter.api.Test
*/
@Tag("Remote")
class AlpnOverrideTest {
class CustomSSLSocketFactory(
delegate: SSLSocketFactory
delegate: SSLSocketFactory,
) : DelegatingSSLSocketFactory(delegate) {
override fun configureSocket(sslSocket: SSLSocket): SSLSocket {
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) {
@ -56,25 +55,34 @@ class AlpnOverrideTest {
@Test
fun getWithCustomSocketFactory() {
client = client.newBuilder()
.sslSocketFactory(CustomSSLSocketFactory(client.sslSocketFactory), client.x509TrustManager!!)
.connectionSpecs(listOf(
ConnectionSpec.Builder(ConnectionSpec.MODERN_TLS)
.supportsTlsExtensions(false)
.build()
))
.eventListener(object : EventListener() {
override fun connectionAcquired(call: Call, connection: Connection) {
val sslSocket = connection.socket() as SSLSocket
println("Requested " + sslSocket.sslParameters.applicationProtocols.joinToString())
println("Negotiated " + sslSocket.applicationProtocol)
}
})
.build()
client =
client.newBuilder()
.sslSocketFactory(CustomSSLSocketFactory(client.sslSocketFactory), client.x509TrustManager!!)
.connectionSpecs(
listOf(
ConnectionSpec.Builder(ConnectionSpec.MODERN_TLS)
.supportsTlsExtensions(false)
.build(),
),
)
.eventListener(
object : EventListener() {
override fun connectionAcquired(
call: Call,
connection: Connection,
) {
val sslSocket = connection.socket() as SSLSocket
println("Requested " + sslSocket.sslParameters.applicationProtocols.joinToString())
println("Negotiated " + sslSocket.applicationProtocol)
}
},
)
.build()
val request = Request.Builder()
.url("https://www.google.com")
.build()
val request =
Request.Builder()
.url("https://www.google.com")
.build()
client.newCall(request).execute().use { response ->
assertThat(response.code).isEqualTo(200)
}

View File

@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package okhttp.android.test.letsencrypt;
package okhttp.android.test.letsencrypt
import android.os.Build
import assertk.assertThat
@ -41,56 +41,61 @@ class LetsEncryptClientTest {
val clientBuilder = OkHttpClient.Builder()
if (androidMorEarlier) {
val cert: X509Certificate = """
-----BEGIN CERTIFICATE-----
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
-----END CERTIFICATE-----
""".trimIndent().decodeCertificatePem()
val cert: X509Certificate =
"""
-----BEGIN CERTIFICATE-----
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
-----END CERTIFICATE-----
""".trimIndent().decodeCertificatePem()
val handshakeCertificates = HandshakeCertificates.Builder()
// TODO reenable in official answers
val handshakeCertificates =
HandshakeCertificates.Builder()
// TODO reenable in official answers
// .addPlatformTrustedCertificates()
.addTrustedCertificate(cert)
.build()
.addTrustedCertificate(cert)
.build()
clientBuilder
.sslSocketFactory(handshakeCertificates.sslSocketFactory(),
handshakeCertificates.trustManager)
.sslSocketFactory(
handshakeCertificates.sslSocketFactory(),
handshakeCertificates.trustManager,
)
}
val client = clientBuilder.build()
val request = Request.Builder()
.url("https://valid-isrgrootx1.letsencrypt.org/robots.txt")
.build()
val request =
Request.Builder()
.url("https://valid-isrgrootx1.letsencrypt.org/robots.txt")
.build()
client.newCall(request).execute().use { response ->
assertThat(response.code).isEqualTo(404)
assertThat(response.protocol).isEqualTo(Protocol.HTTP_2)

View File

@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package okhttp.android.test.sni;
package okhttp.android.test.sni
import android.os.Build
import android.util.Log
@ -39,15 +39,16 @@ import org.junit.jupiter.api.Test
*/
@Tag("Remote")
class SniOverrideTest {
var client = OkHttpClient.Builder()
.build()
var client =
OkHttpClient.Builder()
.build()
@Test
fun getWithCustomSocketFactory() {
assumeTrue(Build.VERSION.SDK_INT >= 24)
class CustomSSLSocketFactory(
delegate: SSLSocketFactory
delegate: SSLSocketFactory,
) : DelegatingSSLSocketFactory(delegate) {
override fun configureSocket(sslSocket: SSLSocket): SSLSocket {
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) {
@ -62,49 +63,52 @@ class SniOverrideTest {
}
}
client = client.newBuilder()
.sslSocketFactory(CustomSSLSocketFactory(client.sslSocketFactory), client.x509TrustManager!!)
.hostnameVerifier { hostname, session ->
val s = "hostname: $hostname peerHost:${session.peerHost}"
Log.d("SniOverrideTest", s)
try {
val cert = session.peerCertificates[0] as X509Certificate
for (name in cert.subjectAlternativeNames) {
if (name[0] as Int == 2) {
Log.d("SniOverrideTest", "cert: " + name[1])
client =
client.newBuilder()
.sslSocketFactory(CustomSSLSocketFactory(client.sslSocketFactory), client.x509TrustManager!!)
.hostnameVerifier { hostname, session ->
val s = "hostname: $hostname peerHost:${session.peerHost}"
Log.d("SniOverrideTest", s)
try {
val cert = session.peerCertificates[0] as X509Certificate
for (name in cert.subjectAlternativeNames) {
if (name[0] as Int == 2) {
Log.d("SniOverrideTest", "cert: " + name[1])
}
}
true
} catch (e: Exception) {
false
}
true
} catch (e: Exception) {
false
}
}
.build()
.build()
val request = Request.Builder()
.url("https://sni.cloudflaressl.com/cdn-cgi/trace")
.header("Host", "cloudflare-dns.com")
.build()
val request =
Request.Builder()
.url("https://sni.cloudflaressl.com/cdn-cgi/trace")
.header("Host", "cloudflare-dns.com")
.build()
client.newCall(request).execute().use { response ->
assertThat(response.code).isEqualTo(200)
assertThat(response.protocol).isEqualTo(Protocol.HTTP_2)
assertThat(response.body.string()).contains("h=cloudflare-dns.com")
}
}
@Test
fun getWithDns() {
client = client.newBuilder()
.dns {
Dns.SYSTEM.lookup("sni.cloudflaressl.com")
}
.build()
client =
client.newBuilder()
.dns {
Dns.SYSTEM.lookup("sni.cloudflaressl.com")
}
.build()
val request = Request.Builder()
.url("https://cloudflare-dns.com/cdn-cgi/trace")
.build()
val request =
Request.Builder()
.url("https://cloudflare-dns.com/cdn-cgi/trace")
.build()
client.newCall(request).execute().use { response ->
assertThat(response.code).isEqualTo(200)
assertThat(response.protocol).isEqualTo(Protocol.HTTP_2)

View File

@ -1,14 +1,14 @@
@file:Suppress("UnstableApiUsage")
import com.diffplug.gradle.spotless.SpotlessExtension
import com.vanniktech.maven.publish.MavenPublishBaseExtension
import com.vanniktech.maven.publish.SonatypeHost
import java.net.URL
import java.net.URI
import kotlinx.validation.ApiValidationExtension
import org.gradle.api.tasks.testing.logging.TestExceptionFormat
import org.jetbrains.dokka.gradle.DokkaTaskPartial
import org.jetbrains.kotlin.gradle.tasks.KotlinCompile
import ru.vyarus.gradle.plugin.animalsniffer.AnimalSnifferExtension
import java.net.URI
buildscript {
dependencies {
@ -35,6 +35,14 @@ buildscript {
}
apply(plugin = "org.jetbrains.dokka")
apply(plugin = "com.diffplug.spotless")
configure<SpotlessExtension> {
kotlin {
target("**/*.kt")
ktlint()
}
}
allprojects {
group = "com.squareup.okhttp3"

View File

@ -34,21 +34,23 @@ fun Project.applyOsgi(vararg bndProperties: String) {
private fun Project.applyOsgi(
jarTaskName: String,
osgiApiConfigurationName: String,
bndProperties: Array<out String>
bndProperties: Array<out String>,
) {
val osgi = project.sourceSets.create("osgi")
val osgiApi = project.configurations.getByName(osgiApiConfigurationName)
val kotlinOsgi = extensions.getByType(VersionCatalogsExtension::class.java).named("libs")
.findLibrary("kotlin.stdlib.osgi").get().get()
val kotlinOsgi =
extensions.getByType(VersionCatalogsExtension::class.java).named("libs")
.findLibrary("kotlin.stdlib.osgi").get().get()
project.dependencies {
osgiApi(kotlinOsgi)
}
val jarTask = tasks.getByName<Jar>(jarTaskName)
val bundleExtension = jarTask.extensions.findByType() ?: jarTask.extensions.create(
BundleTaskExtension.NAME, BundleTaskExtension::class.java, jarTask
)
val bundleExtension =
jarTask.extensions.findByType() ?: jarTask.extensions.create(
BundleTaskExtension.NAME, BundleTaskExtension::class.java, jarTask,
)
bundleExtension.run {
setClasspath(osgi.compileClasspath + sourceSets["main"].compileClasspath)
bnd(*bndProperties)

View File

@ -36,9 +36,7 @@ internal fun Dispatcher.wrap(): mockwebserver3.Dispatcher {
val delegate = this
return object : mockwebserver3.Dispatcher() {
override fun dispatch(
request: mockwebserver3.RecordedRequest
): mockwebserver3.MockResponse {
override fun dispatch(request: mockwebserver3.RecordedRequest): mockwebserver3.MockResponse {
return delegate.dispatch(request.unwrap()).wrap()
}
@ -70,17 +68,18 @@ internal fun MockResponse.wrap(): mockwebserver3.MockResponse {
result.status = status
result.headers(headers)
result.trailers(trailers)
result.socketPolicy = when (socketPolicy) {
SocketPolicy.EXPECT_CONTINUE, SocketPolicy.CONTINUE_ALWAYS -> {
result.add100Continue()
KeepOpen
result.socketPolicy =
when (socketPolicy) {
SocketPolicy.EXPECT_CONTINUE, SocketPolicy.CONTINUE_ALWAYS -> {
result.add100Continue()
KeepOpen
}
SocketPolicy.UPGRADE_TO_SSL_AT_END -> {
result.inTunnel()
KeepOpen
}
else -> wrapSocketPolicy()
}
SocketPolicy.UPGRADE_TO_SSL_AT_END -> {
result.inTunnel()
KeepOpen
}
else -> wrapSocketPolicy()
}
result.throttleBody(throttleBytesPerPeriod, getThrottlePeriod(MILLISECONDS), MILLISECONDS)
result.bodyDelay(getBodyDelay(MILLISECONDS), MILLISECONDS)
result.headersDelay(getHeadersDelay(MILLISECONDS), MILLISECONDS)
@ -92,7 +91,7 @@ private fun PushPromise.wrap(): mockwebserver3.PushPromise {
method = method,
path = path,
headers = headers,
response = response.wrap()
response = response.wrap(),
)
}
@ -108,7 +107,7 @@ internal fun mockwebserver3.RecordedRequest.unwrap(): RecordedRequest {
method = method,
path = path,
handshake = handshake,
requestUrl = requestUrl
requestUrl = requestUrl,
)
}

View File

@ -14,14 +14,15 @@
* limitations under the License.
*/
@file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")
package okhttp3.mockwebserver
import java.util.concurrent.TimeUnit
import okhttp3.Headers
import okhttp3.WebSocketListener
import okhttp3.internal.addHeaderLenient
import okhttp3.internal.http2.Settings
import okio.Buffer
import java.util.concurrent.TimeUnit
class MockResponse : Cloneable {
@set:JvmName("status")
@ -86,62 +87,81 @@ class MockResponse : Cloneable {
@JvmName("-deprecated_getStatus")
@Deprecated(
message = "moved to var",
replaceWith = ReplaceWith(expression = "status"),
level = DeprecationLevel.ERROR)
message = "moved to var",
replaceWith = ReplaceWith(expression = "status"),
level = DeprecationLevel.ERROR,
)
fun getStatus(): String = status
fun setStatus(status: String) = apply {
this.status = status
}
fun setStatus(status: String) =
apply {
this.status = status
}
fun setResponseCode(code: Int): MockResponse {
val reason = when (code) {
in 100..199 -> "Informational"
in 200..299 -> "OK"
in 300..399 -> "Redirection"
in 400..499 -> "Client Error"
in 500..599 -> "Server Error"
else -> "Mock Response"
}
val reason =
when (code) {
in 100..199 -> "Informational"
in 200..299 -> "OK"
in 300..399 -> "Redirection"
in 400..499 -> "Client Error"
in 500..599 -> "Server Error"
else -> "Mock Response"
}
return apply { status = "HTTP/1.1 $code $reason" }
}
fun clearHeaders() = apply {
headersBuilder = Headers.Builder()
}
fun clearHeaders() =
apply {
headersBuilder = Headers.Builder()
}
fun addHeader(header: String) = apply {
headersBuilder.add(header)
}
fun addHeader(header: String) =
apply {
headersBuilder.add(header)
}
fun addHeader(name: String, value: Any) = apply {
fun addHeader(
name: String,
value: Any,
) = apply {
headersBuilder.add(name, value.toString())
}
fun addHeaderLenient(name: String, value: Any) = apply {
fun addHeaderLenient(
name: String,
value: Any,
) = apply {
addHeaderLenient(headersBuilder, name, value.toString())
}
fun setHeader(name: String, value: Any) = apply {
fun setHeader(
name: String,
value: Any,
) = apply {
removeHeader(name)
addHeader(name, value)
}
fun removeHeader(name: String) = apply {
headersBuilder.removeAll(name)
}
fun removeHeader(name: String) =
apply {
headersBuilder.removeAll(name)
}
fun getBody(): Buffer? = body?.clone()
fun setBody(body: Buffer) = apply {
setHeader("Content-Length", body.size)
this.body = body.clone() // Defensive copy.
}
fun setBody(body: Buffer) =
apply {
setHeader("Content-Length", body.size)
this.body = body.clone() // Defensive copy.
}
fun setBody(body: String): MockResponse = setBody(Buffer().writeUtf8(body))
fun setChunkedBody(body: Buffer, maxChunkSize: Int) = apply {
fun setChunkedBody(
body: Buffer,
maxChunkSize: Int,
) = apply {
removeHeader("Content-Length")
headersBuilder.add(CHUNKED_BODY_HEADER)
@ -157,89 +177,107 @@ class MockResponse : Cloneable {
this.body = bytesOut
}
fun setChunkedBody(body: String, maxChunkSize: Int): MockResponse =
setChunkedBody(Buffer().writeUtf8(body), maxChunkSize)
fun setChunkedBody(
body: String,
maxChunkSize: Int,
): MockResponse = setChunkedBody(Buffer().writeUtf8(body), maxChunkSize)
@JvmName("-deprecated_getHeaders")
@Deprecated(
message = "moved to var",
replaceWith = ReplaceWith(expression = "headers"),
level = DeprecationLevel.ERROR)
message = "moved to var",
replaceWith = ReplaceWith(expression = "headers"),
level = DeprecationLevel.ERROR,
)
fun getHeaders(): Headers = headers
fun setHeaders(headers: Headers) = apply { this.headers = headers }
@JvmName("-deprecated_getTrailers")
@Deprecated(
message = "moved to var",
replaceWith = ReplaceWith(expression = "trailers"),
level = DeprecationLevel.ERROR)
message = "moved to var",
replaceWith = ReplaceWith(expression = "trailers"),
level = DeprecationLevel.ERROR,
)
fun getTrailers(): Headers = trailers
fun setTrailers(trailers: Headers) = apply { this.trailers = trailers }
@JvmName("-deprecated_getSocketPolicy")
@Deprecated(
message = "moved to var",
replaceWith = ReplaceWith(expression = "socketPolicy"),
level = DeprecationLevel.ERROR)
message = "moved to var",
replaceWith = ReplaceWith(expression = "socketPolicy"),
level = DeprecationLevel.ERROR,
)
fun getSocketPolicy(): SocketPolicy = socketPolicy
fun setSocketPolicy(socketPolicy: SocketPolicy) = apply {
this.socketPolicy = socketPolicy
}
fun setSocketPolicy(socketPolicy: SocketPolicy) =
apply {
this.socketPolicy = socketPolicy
}
@JvmName("-deprecated_getHttp2ErrorCode")
@Deprecated(
message = "moved to var",
replaceWith = ReplaceWith(expression = "http2ErrorCode"),
level = DeprecationLevel.ERROR)
message = "moved to var",
replaceWith = ReplaceWith(expression = "http2ErrorCode"),
level = DeprecationLevel.ERROR,
)
fun getHttp2ErrorCode(): Int = http2ErrorCode
fun setHttp2ErrorCode(http2ErrorCode: Int) = apply {
this.http2ErrorCode = http2ErrorCode
}
fun setHttp2ErrorCode(http2ErrorCode: Int) =
apply {
this.http2ErrorCode = http2ErrorCode
}
fun throttleBody(bytesPerPeriod: Long, period: Long, unit: TimeUnit) = apply {
fun throttleBody(
bytesPerPeriod: Long,
period: Long,
unit: TimeUnit,
) = apply {
throttleBytesPerPeriod = bytesPerPeriod
throttlePeriodAmount = period
throttlePeriodUnit = unit
}
fun getThrottlePeriod(unit: TimeUnit): Long =
unit.convert(throttlePeriodAmount, throttlePeriodUnit)
fun getThrottlePeriod(unit: TimeUnit): Long = unit.convert(throttlePeriodAmount, throttlePeriodUnit)
fun setBodyDelay(delay: Long, unit: TimeUnit) = apply {
fun setBodyDelay(
delay: Long,
unit: TimeUnit,
) = apply {
bodyDelayAmount = delay
bodyDelayUnit = unit
}
fun getBodyDelay(unit: TimeUnit): Long =
unit.convert(bodyDelayAmount, bodyDelayUnit)
fun getBodyDelay(unit: TimeUnit): Long = unit.convert(bodyDelayAmount, bodyDelayUnit)
fun setHeadersDelay(delay: Long, unit: TimeUnit) = apply {
fun setHeadersDelay(
delay: Long,
unit: TimeUnit,
) = apply {
headersDelayAmount = delay
headersDelayUnit = unit
}
fun getHeadersDelay(unit: TimeUnit): Long =
unit.convert(headersDelayAmount, headersDelayUnit)
fun getHeadersDelay(unit: TimeUnit): Long = unit.convert(headersDelayAmount, headersDelayUnit)
fun withPush(promise: PushPromise) = apply {
promises.add(promise)
}
fun withPush(promise: PushPromise) =
apply {
promises.add(promise)
}
fun withSettings(settings: Settings) = apply {
this.settings = settings
}
fun withSettings(settings: Settings) =
apply {
this.settings = settings
}
fun withWebSocketUpgrade(listener: WebSocketListener) = apply {
status = "HTTP/1.1 101 Switching Protocols"
setHeader("Connection", "Upgrade")
setHeader("Upgrade", "websocket")
body = null
webSocketListener = listener
}
fun withWebSocketUpgrade(listener: WebSocketListener) =
apply {
status = "HTTP/1.1 101 Switching Protocols"
setHeader("Connection", "Upgrade")
setHeader("Upgrade", "websocket")
body = null
webSocketListener = listener
}
override fun toString(): String = status

View File

@ -15,9 +15,6 @@
*/
package okhttp3.mockwebserver
import okhttp3.HttpUrl
import okhttp3.Protocol
import org.junit.rules.ExternalResource
import java.io.Closeable
import java.io.IOException
import java.net.InetAddress
@ -27,6 +24,9 @@ import java.util.logging.Level
import java.util.logging.Logger
import javax.net.ServerSocketFactory
import javax.net.ssl.SSLSocketFactory
import okhttp3.HttpUrl
import okhttp3.Protocol
import org.junit.rules.ExternalResource
class MockWebServer : ExternalResource(), Closeable {
val delegate = mockwebserver3.MockWebServer()
@ -57,7 +57,8 @@ class MockWebServer : ExternalResource(), Closeable {
var protocolNegotiationEnabled: Boolean by delegate::protocolNegotiationEnabled
@get:JvmName("protocols") var protocols: List<Protocol>
@get:JvmName("protocols")
var protocols: List<Protocol>
get() = delegate.protocols
set(value) {
delegate.protocols = value
@ -80,9 +81,10 @@ class MockWebServer : ExternalResource(), Closeable {
@JvmName("-deprecated_port")
@Deprecated(
message = "moved to val",
replaceWith = ReplaceWith(expression = "port"),
level = DeprecationLevel.ERROR)
message = "moved to val",
replaceWith = ReplaceWith(expression = "port"),
level = DeprecationLevel.ERROR,
)
fun getPort(): Int = port
fun toProxyAddress(): Proxy {
@ -92,11 +94,13 @@ class MockWebServer : ExternalResource(), Closeable {
@JvmName("-deprecated_serverSocketFactory")
@Deprecated(
message = "moved to var",
replaceWith = ReplaceWith(
expression = "run { this.serverSocketFactory = serverSocketFactory }"
message = "moved to var",
replaceWith =
ReplaceWith(
expression = "run { this.serverSocketFactory = serverSocketFactory }",
),
level = DeprecationLevel.ERROR)
level = DeprecationLevel.ERROR,
)
fun setServerSocketFactory(serverSocketFactory: ServerSocketFactory) {
delegate.serverSocketFactory = serverSocketFactory
}
@ -108,43 +112,52 @@ class MockWebServer : ExternalResource(), Closeable {
@JvmName("-deprecated_bodyLimit")
@Deprecated(
message = "moved to var",
replaceWith = ReplaceWith(
expression = "run { this.bodyLimit = bodyLimit }"
message = "moved to var",
replaceWith =
ReplaceWith(
expression = "run { this.bodyLimit = bodyLimit }",
),
level = DeprecationLevel.ERROR)
level = DeprecationLevel.ERROR,
)
fun setBodyLimit(bodyLimit: Long) {
delegate.bodyLimit = bodyLimit
}
@JvmName("-deprecated_protocolNegotiationEnabled")
@Deprecated(
message = "moved to var",
replaceWith = ReplaceWith(
expression = "run { this.protocolNegotiationEnabled = protocolNegotiationEnabled }"
message = "moved to var",
replaceWith =
ReplaceWith(
expression = "run { this.protocolNegotiationEnabled = protocolNegotiationEnabled }",
),
level = DeprecationLevel.ERROR)
level = DeprecationLevel.ERROR,
)
fun setProtocolNegotiationEnabled(protocolNegotiationEnabled: Boolean) {
delegate.protocolNegotiationEnabled = protocolNegotiationEnabled
}
@JvmName("-deprecated_protocols")
@Deprecated(
message = "moved to var",
replaceWith = ReplaceWith(expression = "run { this.protocols = protocols }"),
level = DeprecationLevel.ERROR)
message = "moved to var",
replaceWith = ReplaceWith(expression = "run { this.protocols = protocols }"),
level = DeprecationLevel.ERROR,
)
fun setProtocols(protocols: List<Protocol>) {
delegate.protocols = protocols
}
@JvmName("-deprecated_protocols")
@Deprecated(
message = "moved to var",
replaceWith = ReplaceWith(expression = "protocols"),
level = DeprecationLevel.ERROR)
message = "moved to var",
replaceWith = ReplaceWith(expression = "protocols"),
level = DeprecationLevel.ERROR,
)
fun protocols(): List<Protocol> = delegate.protocols
fun useHttps(sslSocketFactory: SSLSocketFactory, tunnelProxy: Boolean) {
fun useHttps(
sslSocketFactory: SSLSocketFactory,
tunnelProxy: Boolean,
) {
delegate.useHttps(sslSocketFactory)
}
@ -166,15 +179,19 @@ class MockWebServer : ExternalResource(), Closeable {
}
@Throws(InterruptedException::class)
fun takeRequest(timeout: Long, unit: TimeUnit): RecordedRequest? {
fun takeRequest(
timeout: Long,
unit: TimeUnit,
): RecordedRequest? {
return delegate.takeRequest(timeout, unit)?.unwrap()
}
@JvmName("-deprecated_requestCount")
@Deprecated(
message = "moved to val",
replaceWith = ReplaceWith(expression = "requestCount"),
level = DeprecationLevel.ERROR)
message = "moved to val",
replaceWith = ReplaceWith(expression = "requestCount"),
level = DeprecationLevel.ERROR,
)
fun getRequestCount(): Int = delegate.requestCount
fun enqueue(response: MockResponse) {
@ -182,13 +199,17 @@ class MockWebServer : ExternalResource(), Closeable {
}
@Throws(IOException::class)
@JvmOverloads fun start(port: Int = 0) {
@JvmOverloads
fun start(port: Int = 0) {
started = true
delegate.start(port)
}
@Throws(IOException::class)
fun start(inetAddress: InetAddress, port: Int) {
fun start(
inetAddress: InetAddress,
port: Int,
) {
started = true
delegate.start(inetAddress, port)
}

View File

@ -21,34 +21,37 @@ class PushPromise(
@get:JvmName("method") val method: String,
@get:JvmName("path") val path: String,
@get:JvmName("headers") val headers: Headers,
@get:JvmName("response") val response: MockResponse
@get:JvmName("response") val response: MockResponse,
) {
@JvmName("-deprecated_method")
@Deprecated(
message = "moved to val",
replaceWith = ReplaceWith(expression = "method"),
level = DeprecationLevel.ERROR)
message = "moved to val",
replaceWith = ReplaceWith(expression = "method"),
level = DeprecationLevel.ERROR,
)
fun method(): String = method
@JvmName("-deprecated_path")
@Deprecated(
message = "moved to val",
replaceWith = ReplaceWith(expression = "path"),
level = DeprecationLevel.ERROR)
message = "moved to val",
replaceWith = ReplaceWith(expression = "path"),
level = DeprecationLevel.ERROR,
)
fun path(): String = path
@JvmName("-deprecated_headers")
@Deprecated(
message = "moved to val",
replaceWith = ReplaceWith(expression = "headers"),
level = DeprecationLevel.ERROR)
message = "moved to val",
replaceWith = ReplaceWith(expression = "headers"),
level = DeprecationLevel.ERROR,
)
fun headers(): Headers = headers
@JvmName("-deprecated_response")
@Deprecated(
message = "moved to val",
replaceWith = ReplaceWith(expression = "response"),
level = DeprecationLevel.ERROR)
message = "moved to val",
replaceWith = ReplaceWith(expression = "response"),
level = DeprecationLevel.ERROR,
)
fun response(): MockResponse = response
}

View File

@ -15,6 +15,10 @@
*/
package okhttp3.mockwebserver
import java.io.IOException
import java.net.Inet6Address
import java.net.Socket
import javax.net.ssl.SSLSocket
import okhttp3.Handshake
import okhttp3.Handshake.Companion.handshake
import okhttp3.Headers
@ -22,10 +26,6 @@ import okhttp3.HttpUrl
import okhttp3.HttpUrl.Companion.toHttpUrlOrNull
import okhttp3.TlsVersion
import okio.Buffer
import java.io.IOException
import java.net.Inet6Address
import java.net.Socket
import javax.net.ssl.SSLSocket
class RecordedRequest {
val requestLine: String
@ -42,9 +42,10 @@ class RecordedRequest {
@get:JvmName("-deprecated_utf8Body")
@Deprecated(
message = "Use body.readUtf8()",
replaceWith = ReplaceWith("body.readUtf8()"),
level = DeprecationLevel.ERROR)
message = "Use body.readUtf8()",
replaceWith = ReplaceWith("body.readUtf8()"),
level = DeprecationLevel.ERROR,
)
val utf8Body: String
get() = body.readUtf8()
@ -62,7 +63,7 @@ class RecordedRequest {
method: String?,
path: String?,
handshake: Handshake?,
requestUrl: HttpUrl?
requestUrl: HttpUrl?,
) {
this.requestLine = requestLine
this.headers = headers
@ -86,7 +87,7 @@ class RecordedRequest {
body: Buffer,
sequenceNumber: Int,
socket: Socket,
failure: IOException? = null
failure: IOException? = null,
) {
this.requestLine = requestLine
this.headers = headers
@ -139,9 +140,10 @@ class RecordedRequest {
}
@Deprecated(
message = "Use body.readUtf8()",
replaceWith = ReplaceWith("body.readUtf8()"),
level = DeprecationLevel.WARNING)
message = "Use body.readUtf8()",
replaceWith = ReplaceWith("body.readUtf8()"),
level = DeprecationLevel.WARNING,
)
fun getUtf8Body(): String = body.readUtf8()
fun getHeader(name: String): String? = headers.values(name).firstOrNull()

View File

@ -32,5 +32,5 @@ enum class SocketPolicy {
NO_RESPONSE,
RESET_STREAM_AT_START,
EXPECT_CONTINUE,
CONTINUE_ALWAYS
CONTINUE_ALWAYS,
}

View File

@ -15,6 +15,12 @@
*/
package okhttp3.mockwebserver
import java.net.InetAddress
import java.net.Proxy
import java.net.Socket
import java.util.concurrent.TimeUnit
import javax.net.ServerSocketFactory
import javax.net.ssl.SSLSocketFactory
import okhttp3.Handshake
import okhttp3.Headers
import okhttp3.Headers.Companion.headersOf
@ -26,35 +32,32 @@ import okhttp3.internal.http2.Settings
import okio.Buffer
import org.junit.Ignore
import org.junit.Test
import java.net.InetAddress
import java.net.Proxy
import java.net.Socket
import java.util.concurrent.TimeUnit
import javax.net.ServerSocketFactory
import javax.net.ssl.SSLSocketFactory
/**
* Access every type, function, and property from Kotlin to defend against unexpected regressions in
* modern 4.0.x kotlin source-compatibility.
*/
@Suppress(
"ASSIGNED_BUT_NEVER_ACCESSED_VARIABLE",
"UNUSED_ANONYMOUS_PARAMETER",
"UNUSED_VALUE",
"UNUSED_VARIABLE",
"VARIABLE_WITH_REDUNDANT_INITIALIZER",
"RedundantLambdaArrow",
"RedundantExplicitType",
"IMPLICIT_NOTHING_AS_TYPE_PARAMETER"
"ASSIGNED_BUT_NEVER_ACCESSED_VARIABLE",
"UNUSED_ANONYMOUS_PARAMETER",
"UNUSED_VALUE",
"UNUSED_VARIABLE",
"VARIABLE_WITH_REDUNDANT_INITIALIZER",
"RedundantLambdaArrow",
"RedundantExplicitType",
"IMPLICIT_NOTHING_AS_TYPE_PARAMETER",
)
class KotlinSourceModernTest {
@Test @Ignore
fun dispatcherFromMockWebServer() {
val dispatcher = object : Dispatcher() {
override fun dispatch(request: RecordedRequest): MockResponse = TODO()
override fun peek(): MockResponse = TODO()
override fun shutdown() = TODO()
}
val dispatcher =
object : Dispatcher() {
override fun dispatch(request: RecordedRequest): MockResponse = TODO()
override fun peek(): MockResponse = TODO()
override fun shutdown() = TODO()
}
}
@Test @Ignore
@ -96,8 +99,11 @@ class KotlinSourceModernTest {
mockResponse = mockResponse.withSettings(Settings())
var settings: Settings = mockResponse.settings
settings = mockResponse.settings
mockResponse = mockResponse.withWebSocketUpgrade(object : WebSocketListener() {
})
mockResponse =
mockResponse.withWebSocketUpgrade(
object : WebSocketListener() {
},
)
var webSocketListener: WebSocketListener? = mockResponse.webSocketListener
webSocketListener = mockResponse.webSocketListener
}
@ -146,8 +152,10 @@ class KotlinSourceModernTest {
@Test @Ignore
fun queueDispatcher() {
val queueDispatcher: QueueDispatcher = QueueDispatcher()
var mockResponse: MockResponse = queueDispatcher.dispatch(
RecordedRequest("", headersOf(), listOf(), 0L, Buffer(), 0, Socket()))
var mockResponse: MockResponse =
queueDispatcher.dispatch(
RecordedRequest("", headersOf(), listOf(), 0L, Buffer(), 0, Socket()),
)
mockResponse = queueDispatcher.peek()
queueDispatcher.enqueueResponse(MockResponse())
queueDispatcher.shutdown()
@ -157,8 +165,16 @@ class KotlinSourceModernTest {
@Test @Ignore
fun recordedRequest() {
var recordedRequest: RecordedRequest = RecordedRequest(
"", headersOf(), listOf(), 0L, Buffer(), 0, Socket())
var recordedRequest: RecordedRequest =
RecordedRequest(
"",
headersOf(),
listOf(),
0L,
Buffer(),
0,
Socket(),
)
recordedRequest = RecordedRequest("", headersOf(), listOf(), 0L, Buffer(), 0, Socket())
var requestUrl: HttpUrl? = recordedRequest.requestUrl
var requestLine: String = recordedRequest.requestLine

View File

@ -86,15 +86,16 @@ class MockWebServerTest {
@Test
fun setResponseMockReason() {
val reasons = arrayOf(
"Mock Response",
"Informational",
"OK",
"Redirection",
"Client Error",
"Server Error",
"Mock Response"
)
val reasons =
arrayOf(
"Mock Response",
"Informational",
"OK",
"Redirection",
"Client Error",
"Server Error",
"Mock Response",
)
for (i in 0..599) {
val response = MockResponse().setResponseCode(i)
val expectedReason = reasons[i / 100]
@ -119,21 +120,23 @@ class MockWebServerTest {
@Test
fun mockResponseAddHeader() {
val response = MockResponse()
.clearHeaders()
.addHeader("Cookie: s=square")
.addHeader("Cookie", "a=android")
val response =
MockResponse()
.clearHeaders()
.addHeader("Cookie: s=square")
.addHeader("Cookie", "a=android")
assertThat(headersToList(response))
.containsExactly("Cookie: s=square", "Cookie: a=android")
}
@Test
fun mockResponseSetHeader() {
val response = MockResponse()
.clearHeaders()
.addHeader("Cookie: s=square")
.addHeader("Cookie: a=android")
.addHeader("Cookies: delicious")
val response =
MockResponse()
.clearHeaders()
.addHeader("Cookie: s=square")
.addHeader("Cookie: a=android")
.addHeader("Cookies: delicious")
response.setHeader("cookie", "r=robot")
assertThat(headersToList(response))
.containsExactly("Cookies: delicious", "cookie: r=robot")
@ -141,10 +144,11 @@ class MockWebServerTest {
@Test
fun mockResponseSetHeaders() {
val response = MockResponse()
.clearHeaders()
.addHeader("Cookie: s=square")
.addHeader("Cookies: delicious")
val response =
MockResponse()
.clearHeaders()
.addHeader("Cookie: s=square")
.addHeader("Cookies: delicious")
response.setHeaders(Headers.Builder().add("Cookie", "a=android").build())
assertThat(headersToList(response)).containsExactly("Cookie: a=android")
}
@ -173,7 +177,7 @@ class MockWebServerTest {
MockResponse()
.setResponseCode(HttpURLConnection.HTTP_MOVED_TEMP)
.addHeader("Location: " + server.url("/new-path"))
.setBody("This page has moved!")
.setBody("This page has moved!"),
)
server.enqueue(MockResponse().setBody("This is the new location!"))
val connection = server.url("/").toUrl().openConnection()
@ -211,7 +215,7 @@ class MockWebServerTest {
MockResponse()
.setBody("G\r\nxxxxxxxxxxxxxxxx\r\n0\r\n\r\n")
.clearHeaders()
.addHeader("Transfer-encoding: chunked")
.addHeader("Transfer-encoding: chunked"),
)
val connection = server.url("/").toUrl().openConnection()
val inputStream = connection.getInputStream()
@ -228,7 +232,7 @@ class MockWebServerTest {
MockResponse()
.setBody("ABC")
.clearHeaders()
.addHeader("Content-Length: 4")
.addHeader("Content-Length: 4"),
)
server.enqueue(MockResponse().setBody("DEF"))
val urlConnection = server.url("/").toUrl().openConnection()
@ -257,7 +261,7 @@ class MockWebServerTest {
fun disconnectAtStart() {
server.enqueue(
MockResponse()
.setSocketPolicy(SocketPolicy.DISCONNECT_AT_START)
.setSocketPolicy(SocketPolicy.DISCONNECT_AT_START),
)
server.enqueue(MockResponse()) // The jdk's HttpUrlConnection is a bastard.
server.enqueue(MockResponse())
@ -278,7 +282,7 @@ class MockWebServerTest {
assumeNotWindows()
server.enqueue(
MockResponse()
.throttleBody(3, 500, TimeUnit.MILLISECONDS)
.throttleBody(3, 500, TimeUnit.MILLISECONDS),
)
val startNanos = System.nanoTime()
val connection = server.url("/").toUrl().openConnection()
@ -301,7 +305,7 @@ class MockWebServerTest {
server.enqueue(
MockResponse()
.setBody("ABCDEF")
.throttleBody(3, 500, TimeUnit.MILLISECONDS)
.throttleBody(3, 500, TimeUnit.MILLISECONDS),
)
val startNanos = System.nanoTime()
val connection = server.url("/").toUrl().openConnection()
@ -325,7 +329,7 @@ class MockWebServerTest {
server.enqueue(
MockResponse()
.setBody("ABCDEF")
.setBodyDelay(1, TimeUnit.SECONDS)
.setBodyDelay(1, TimeUnit.SECONDS),
)
val startNanos = System.nanoTime()
val connection = server.url("/").toUrl().openConnection()
@ -373,7 +377,7 @@ class MockWebServerTest {
server.enqueue(
MockResponse()
.setBody("ab")
.setSocketPolicy(SocketPolicy.DISCONNECT_DURING_RESPONSE_BODY)
.setSocketPolicy(SocketPolicy.DISCONNECT_DURING_RESPONSE_BODY),
)
val connection = server.url("/").toUrl().openConnection()
assertThat(connection.getContentLength()).isEqualTo(2)
@ -443,12 +447,16 @@ class MockWebServerTest {
@Test
fun statementStartsAndStops() {
val called = AtomicBoolean()
val statement = server.apply(object : Statement() {
override fun evaluate() {
called.set(true)
server.url("/").toUrl().openConnection().connect()
}
}, Description.EMPTY)
val statement =
server.apply(
object : Statement() {
override fun evaluate() {
called.set(true)
server.url("/").toUrl().openConnection().connect()
}
},
Description.EMPTY,
)
statement.evaluate()
assertThat(called.get()).isTrue()
try {
@ -484,7 +492,7 @@ class MockWebServerTest {
assertThat(reader.readLine()).isEqualTo("hello world")
val request = server.takeRequest()
assertThat(request.requestLine).isEqualTo(
"GET /a/deep/path?key=foo%20bar HTTP/1.1"
"GET /a/deep/path?key=foo%20bar HTTP/1.1",
)
val requestUrl = request.requestUrl
assertThat(requestUrl!!.scheme).isEqualTo("http")
@ -530,8 +538,8 @@ class MockWebServerTest {
fail<Any>()
} catch (expected: IllegalArgumentException) {
assertThat(expected.message).isEqualTo(
"protocols containing h2_prior_knowledge cannot use other protocols: "
+ "[h2_prior_knowledge, http/1.1]"
"protocols containing h2_prior_knowledge cannot use other protocols: " +
"[h2_prior_knowledge, http/1.1]",
)
}
}
@ -544,8 +552,8 @@ class MockWebServerTest {
fail<Any>()
} catch (expected: IllegalArgumentException) {
assertThat(expected.message).isEqualTo(
"protocols containing h2_prior_knowledge cannot use other protocols: "
+ "[h2_prior_knowledge, h2_prior_knowledge]"
"protocols containing h2_prior_knowledge cannot use other protocols: " +
"[h2_prior_knowledge, h2_prior_knowledge]",
)
}
}
@ -584,30 +592,36 @@ class MockWebServerTest {
fun httpsWithClientAuth() {
platform.assumeNotBouncyCastle()
platform.assumeNotConscrypt()
val clientCa = HeldCertificate.Builder()
.certificateAuthority(0)
.build()
val serverCa = HeldCertificate.Builder()
.certificateAuthority(0)
.build()
val serverCertificate = HeldCertificate.Builder()
.signedBy(serverCa)
.addSubjectAlternativeName(server.hostName)
.build()
val serverHandshakeCertificates = HandshakeCertificates.Builder()
.addTrustedCertificate(clientCa.certificate)
.heldCertificate(serverCertificate)
.build()
val clientCa =
HeldCertificate.Builder()
.certificateAuthority(0)
.build()
val serverCa =
HeldCertificate.Builder()
.certificateAuthority(0)
.build()
val serverCertificate =
HeldCertificate.Builder()
.signedBy(serverCa)
.addSubjectAlternativeName(server.hostName)
.build()
val serverHandshakeCertificates =
HandshakeCertificates.Builder()
.addTrustedCertificate(clientCa.certificate)
.heldCertificate(serverCertificate)
.build()
server.useHttps(serverHandshakeCertificates.sslSocketFactory(), false)
server.enqueue(MockResponse().setBody("abc"))
server.requestClientAuth()
val clientCertificate = HeldCertificate.Builder()
.signedBy(clientCa)
.build()
val clientHandshakeCertificates = HandshakeCertificates.Builder()
.addTrustedCertificate(serverCa.certificate)
.heldCertificate(clientCertificate)
.build()
val clientCertificate =
HeldCertificate.Builder()
.signedBy(clientCa)
.build()
val clientHandshakeCertificates =
HandshakeCertificates.Builder()
.addTrustedCertificate(serverCa.certificate)
.heldCertificate(clientCertificate)
.build()
val url = server.url("/")
val connection = url.toUrl().openConnection() as HttpsURLConnection
connection.setSSLSocketFactory(clientHandshakeCertificates.sslSocketFactory())

View File

@ -15,11 +15,11 @@
*/
package mockwebserver3.junit4
import mockwebserver3.MockWebServer
import org.junit.rules.ExternalResource
import java.io.IOException
import java.util.logging.Level
import java.util.logging.Logger
import mockwebserver3.MockWebServer
import org.junit.rules.ExternalResource
/**
* Runs MockWebServer for the duration of a single test method.

View File

@ -28,12 +28,16 @@ class MockWebServerRuleTest {
@Test fun statementStartsAndStops() {
val rule = MockWebServerRule()
val called = AtomicBoolean()
val statement: Statement = rule.apply(object : Statement() {
override fun evaluate() {
called.set(true)
rule.server.url("/").toUrl().openConnection().connect()
}
}, Description.EMPTY)
val statement: Statement =
rule.apply(
object : Statement() {
override fun evaluate() {
called.set(true)
rule.server.url("/").toUrl().openConnection().connect()
}
},
Description.EMPTY,
)
statement.evaluate()
assertThat(called.get()).isTrue()
try {

View File

@ -38,12 +38,13 @@ import org.junit.jupiter.api.extension.ParameterResolver
* - The test lifecycle default (passed into test method, plus @BeforeEach, @AfterEach)
* - named instances with @MockWebServerInstance.
*/
class MockWebServerExtension
: BeforeEachCallback, AfterEachCallback, ParameterResolver {
class MockWebServerExtension :
BeforeEachCallback, AfterEachCallback, ParameterResolver {
private val ExtensionContext.resource: ServersForTest
get() = getStore(namespace).getOrComputeIfAbsent(this.uniqueId) {
ServersForTest()
} as ServersForTest
get() =
getStore(namespace).getOrComputeIfAbsent(this.uniqueId) {
ServersForTest()
} as ServersForTest
private class ServersForTest {
private val servers = mutableMapOf<String, MockWebServer>()
@ -80,24 +81,25 @@ class MockWebServerExtension
@IgnoreJRERequirement
override fun supportsParameter(
parameterContext: ParameterContext,
extensionContext: ExtensionContext
extensionContext: ExtensionContext,
): Boolean {
// Not supported on constructors, or static contexts
return parameterContext.parameter.type === MockWebServer::class.java
&& extensionContext.testMethod.isPresent
return parameterContext.parameter.type === MockWebServer::class.java &&
extensionContext.testMethod.isPresent
}
@Suppress("NewApi")
override fun resolveParameter(
parameterContext: ParameterContext,
extensionContext: ExtensionContext
extensionContext: ExtensionContext,
): Any {
val nameAnnotation = parameterContext.findAnnotation(MockWebServerInstance::class.java)
val name = if (nameAnnotation.isPresent) {
nameAnnotation.get().name
} else {
defaultName
}
val name =
if (nameAnnotation.isPresent) {
nameAnnotation.get().name
} else {
defaultName
}
return extensionContext.resource.server(name)
}

View File

@ -16,5 +16,5 @@
package mockwebserver3.junit5.internal
annotation class MockWebServerInstance(
val name: String
val name: String,
)

View File

@ -35,7 +35,7 @@ class ExtensionMultipleInstancesTest {
fun setup(
defaultInstance: MockWebServer,
@MockWebServerInstance("A") instanceA: MockWebServer,
@MockWebServerInstance("B") instanceB: MockWebServer
@MockWebServerInstance("B") instanceB: MockWebServer,
) {
defaultInstancePort = defaultInstance.port
instanceAPort = instanceA.port
@ -51,7 +51,7 @@ class ExtensionMultipleInstancesTest {
fun tearDown(
defaultInstance: MockWebServer,
@MockWebServerInstance("A") instanceA: MockWebServer,
@MockWebServerInstance("B") instanceB: MockWebServer
@MockWebServerInstance("B") instanceB: MockWebServer,
) {
assertThat(defaultInstance.port).isEqualTo(defaultInstancePort)
assertThat(instanceA.port).isEqualTo(instanceAPort)
@ -62,7 +62,7 @@ class ExtensionMultipleInstancesTest {
fun testClient(
defaultInstance: MockWebServer,
@MockWebServerInstance("A") instanceA: MockWebServer,
@MockWebServerInstance("B") instanceB: MockWebServer
@MockWebServerInstance("B") instanceB: MockWebServer,
) {
assertThat(defaultInstance.port).isEqualTo(defaultInstancePort)
assertThat(instanceA.port).isEqualTo(instanceAPort)

View File

@ -14,6 +14,7 @@
* limitations under the License.
*/
@file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")
package mockwebserver3
import java.util.concurrent.TimeUnit
@ -28,7 +29,6 @@ import okio.Buffer
/** A scripted response to be replayed by the mock web server. */
class MockResponse {
/** Returns the HTTP response line, such as "HTTP/1.1 200 OK". */
val status: String
@ -76,14 +76,15 @@ class MockResponse {
body: String = "",
inTunnel: Boolean = false,
socketPolicy: SocketPolicy = KeepOpen,
) : this(Builder()
.apply {
this.code = code
this.headers.addAll(headers)
if (inTunnel) inTunnel()
this.body(body)
this.socketPolicy = socketPolicy
}
) : this(
Builder()
.apply {
this.code = code
this.headers.addAll(headers)
if (inTunnel) inTunnel()
this.body(body)
this.socketPolicy = socketPolicy
},
)
private constructor(builder: Builder) {
@ -101,9 +102,10 @@ class MockResponse {
this.bodyDelayNanos = builder.bodyDelayNanos
this.headersDelayNanos = builder.headersDelayNanos
this.pushPromises = builder.pushPromises.toList()
this.settings = Settings().apply {
merge(builder.settings)
}
this.settings =
Settings().apply {
merge(builder.settings)
}
}
fun newBuilder(): Builder = Builder(this)
@ -125,14 +127,15 @@ class MockResponse {
return statusParts[1].toInt()
}
set(value) {
val reason = when (value) {
in 100..199 -> "Informational"
in 200..299 -> "OK"
in 300..399 -> "Redirection"
in 400..499 -> "Client Error"
in 500..599 -> "Server Error"
else -> "Mock Response"
}
val reason =
when (value) {
in 100..199 -> "Informational"
in 200..299 -> "OK"
in 300..399 -> "Redirection"
in 400..499 -> "Client Error"
in 500..599 -> "Server Error"
else -> "Mock Response"
}
status = "HTTP/1.1 $value $reason"
}
@ -189,8 +192,9 @@ class MockResponse {
this.bodyVar = null
this.streamHandlerVar = null
this.webSocketListenerVar = null
this.headers = Headers.Builder()
.add("Content-Length", "0")
this.headers =
Headers.Builder()
.add("Content-Length", "0")
this.trailers = Headers.Builder()
this.throttleBytesPerPeriod = Long.MAX_VALUE
this.throttlePeriodNanos = 0L
@ -216,41 +220,49 @@ class MockResponse {
this.bodyDelayNanos = mockResponse.bodyDelayNanos
this.headersDelayNanos = mockResponse.headersDelayNanos
this.pushPromises = mockResponse.pushPromises.toMutableList()
this.settings = Settings().apply {
merge(mockResponse.settings)
}
this.settings =
Settings().apply {
merge(mockResponse.settings)
}
}
fun code(code: Int) = apply {
this.code = code
}
fun code(code: Int) =
apply {
this.code = code
}
/** Sets the status and returns this. */
fun status(status: String) = apply {
this.status = status
}
fun status(status: String) =
apply {
this.status = status
}
/**
* Removes all HTTP headers including any "Content-Length" and "Transfer-encoding" headers that
* were added by default.
*/
fun clearHeaders() = apply {
headers = Headers.Builder()
}
fun clearHeaders() =
apply {
headers = Headers.Builder()
}
/**
* Adds [header] as an HTTP header. For well-formed HTTP [header] should contain a name followed
* by a colon and a value.
*/
fun addHeader(header: String) = apply {
headers.add(header)
}
fun addHeader(header: String) =
apply {
headers.add(header)
}
/**
* Adds a new header with the name and value. This may be used to add multiple headers with the
* same name.
*/
fun addHeader(name: String, value: Any) = apply {
fun addHeader(
name: String,
value: Any,
) = apply {
headers.add(name, value.toString())
}
@ -259,39 +271,51 @@ class MockResponse {
* same name. Unlike [addHeader] this does not validate the name and
* value.
*/
fun addHeaderLenient(name: String, value: Any) = apply {
fun addHeaderLenient(
name: String,
value: Any,
) = apply {
addHeaderLenient(headers, name, value.toString())
}
/** Removes all headers named [name], then adds a new header with the name and value. */
fun setHeader(name: String, value: Any) = apply {
fun setHeader(
name: String,
value: Any,
) = apply {
removeHeader(name)
addHeader(name, value)
}
/** Removes all headers named [name]. */
fun removeHeader(name: String) = apply {
headers.removeAll(name)
}
fun removeHeader(name: String) =
apply {
headers.removeAll(name)
}
fun body(body: Buffer) = body(body.toMockResponseBody())
fun body(body: MockResponseBody) = apply {
setHeader("Content-Length", body.contentLength)
this.body = body
}
fun body(body: MockResponseBody) =
apply {
setHeader("Content-Length", body.contentLength)
this.body = body
}
/** Sets the response body to the UTF-8 encoded bytes of [body]. */
fun body(body: String): Builder = body(Buffer().writeUtf8(body))
fun streamHandler(streamHandler: StreamHandler) = apply {
this.streamHandler = streamHandler
}
fun streamHandler(streamHandler: StreamHandler) =
apply {
this.streamHandler = streamHandler
}
/**
* Sets the response body to [body], chunked every [maxChunkSize] bytes.
*/
fun chunkedBody(body: Buffer, maxChunkSize: Int) = apply {
fun chunkedBody(
body: Buffer,
maxChunkSize: Int,
) = apply {
removeHeader("Content-Length")
headers.add(CHUNKED_BODY_HEADER)
@ -311,29 +335,38 @@ class MockResponse {
* Sets the response body to the UTF-8 encoded bytes of [body],
* chunked every [maxChunkSize] bytes.
*/
fun chunkedBody(body: String, maxChunkSize: Int): Builder =
chunkedBody(Buffer().writeUtf8(body), maxChunkSize)
fun chunkedBody(
body: String,
maxChunkSize: Int,
): Builder = chunkedBody(Buffer().writeUtf8(body), maxChunkSize)
/** Sets the headers and returns this. */
fun headers(headers: Headers) = apply {
this.headers = headers.newBuilder()
}
fun headers(headers: Headers) =
apply {
this.headers = headers.newBuilder()
}
/** Sets the trailers and returns this. */
fun trailers(trailers: Headers) = apply {
this.trailers = trailers.newBuilder()
}
fun trailers(trailers: Headers) =
apply {
this.trailers = trailers.newBuilder()
}
/** Sets the socket policy and returns this. */
fun socketPolicy(socketPolicy: SocketPolicy) = apply {
this.socketPolicy = socketPolicy
}
fun socketPolicy(socketPolicy: SocketPolicy) =
apply {
this.socketPolicy = socketPolicy
}
/**
* Throttles the request reader and response writer to sleep for the given period after each
* series of [bytesPerPeriod] bytes are transferred. Use this to simulate network behavior.
*/
fun throttleBody(bytesPerPeriod: Long, period: Long, unit: TimeUnit) = apply {
fun throttleBody(
bytesPerPeriod: Long,
period: Long,
unit: TimeUnit,
) = apply {
throttleBytesPerPeriod = bytesPerPeriod
throttlePeriodNanos = unit.toNanos(period)
}
@ -342,11 +375,17 @@ class MockResponse {
* Set the delayed time of the response body to [delay]. This applies to the response body
* only; response headers are not affected.
*/
fun bodyDelay(delay: Long, unit: TimeUnit) = apply {
fun bodyDelay(
delay: Long,
unit: TimeUnit,
) = apply {
bodyDelayNanos = unit.toNanos(delay)
}
fun headersDelay(delay: Long, unit: TimeUnit) = apply {
fun headersDelay(
delay: Long,
unit: TimeUnit,
) = apply {
headersDelayNanos = unit.toNanos(delay)
}
@ -354,29 +393,32 @@ class MockResponse {
* When [protocols][MockWebServer.protocols] include [HTTP_2][okhttp3.Protocol], this attaches a
* pushed stream to this response.
*/
fun addPush(promise: PushPromise) = apply {
this.pushPromises += promise
}
fun addPush(promise: PushPromise) =
apply {
this.pushPromises += promise
}
/**
* When [protocols][MockWebServer.protocols] include [HTTP_2][okhttp3.Protocol], this pushes
* [settings] before writing the response.
*/
fun settings(settings: Settings) = apply {
this.settings.clear()
this.settings.merge(settings)
}
fun settings(settings: Settings) =
apply {
this.settings.clear()
this.settings.merge(settings)
}
/**
* Attempts to perform a web socket upgrade on the connection.
* This will overwrite any previously set status, body, or streamHandler.
*/
fun webSocketUpgrade(listener: WebSocketListener) = apply {
status = "HTTP/1.1 101 Switching Protocols"
setHeader("Connection", "Upgrade")
setHeader("Upgrade", "websocket")
webSocketListener = listener
}
fun webSocketUpgrade(listener: WebSocketListener) =
apply {
status = "HTTP/1.1 101 Switching Protocols"
setHeader("Connection", "Upgrade")
setHeader("Upgrade", "websocket")
webSocketListener = listener
}
/**
* Configures this response to be served as a response to an HTTP CONNECT request, either for
@ -385,23 +427,26 @@ class MockResponse {
* When a new connection is received, all in-tunnel responses are served before the connection is
* upgraded to HTTPS or HTTP/2.
*/
fun inTunnel() = apply {
removeHeader("Content-Length")
inTunnel = true
}
fun inTunnel() =
apply {
removeHeader("Content-Length")
inTunnel = true
}
/**
* Adds an HTTP 1xx response to precede this response. Note that this response's
* [headers delay][headersDelay] applies after this response is transmitted. Set a
* headers delay on that response to delay its transmission.
*/
fun addInformationalResponse(response: MockResponse) = apply {
informationalResponses += response
}
fun addInformationalResponse(response: MockResponse) =
apply {
informationalResponses += response
}
fun add100Continue() = apply {
addInformationalResponse(MockResponse(code = 100))
}
fun add100Continue() =
apply {
addInformationalResponse(MockResponse(code = 100))
}
public override fun clone(): Builder = build().newBuilder()

View File

@ -15,6 +15,7 @@
* limitations under the License.
*/
@file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")
package mockwebserver3
import java.io.Closeable
@ -97,9 +98,10 @@ import okio.source
* in sequence.
*/
class MockWebServer : Closeable {
private val taskRunnerBackend = TaskRunner.RealBackend(
threadFactory("MockWebServer TaskRunner", daemon = false)
)
private val taskRunnerBackend =
TaskRunner.RealBackend(
threadFactory("MockWebServer TaskRunner", daemon = false),
)
private val taskRunner = TaskRunner(taskRunnerBackend)
private val requestQueue = LinkedBlockingQueue<RecordedRequest>()
private val openClientSockets =
@ -126,6 +128,7 @@ class MockWebServer : Closeable {
}
return field
}
@Synchronized set(value) {
check(!started) { "serverSocketFactory must not be set after start()" }
field = value
@ -280,8 +283,10 @@ class MockWebServer : Closeable {
* @return the head of the request queue
*/
@Throws(InterruptedException::class)
fun takeRequest(timeout: Long, unit: TimeUnit): RecordedRequest? =
requestQueue.poll(timeout, unit)
fun takeRequest(
timeout: Long,
unit: TimeUnit,
): RecordedRequest? = requestQueue.poll(timeout, unit)
/**
* Scripts [response] to be returned to a request made in sequence. The first request is
@ -291,8 +296,7 @@ class MockWebServer : Closeable {
* @throws ClassCastException if the default dispatcher has been
* replaced with [setDispatcher][dispatcher].
*/
fun enqueue(response: MockResponse) =
(dispatcher as QueueDispatcher).enqueueResponse(response)
fun enqueue(response: MockResponse) = (dispatcher as QueueDispatcher).enqueueResponse(response)
/**
* Starts the server on the loopback interface for the given port.
@ -301,7 +305,8 @@ class MockWebServer : Closeable {
* use port 0 to avoid flakiness when a specific port is unavailable.
*/
@Throws(IOException::class)
@JvmOverloads fun start(port: Int = 0) = start(InetAddress.getByName("localhost"), port)
@JvmOverloads
fun start(port: Int = 0) = start(InetAddress.getByName("localhost"), port)
/**
* Starts the server on the given address and port.
@ -311,14 +316,18 @@ class MockWebServer : Closeable {
* use port 0 to avoid flakiness when a specific port is unavailable.
*/
@Throws(IOException::class)
fun start(inetAddress: InetAddress, port: Int) = start(InetSocketAddress(inetAddress, port))
fun start(
inetAddress: InetAddress,
port: Int,
) = start(InetSocketAddress(inetAddress, port))
/**
* Starts the server and binds to the given socket address.
*
* @param inetSocketAddress the socket address to bind the server on
*/
@Synchronized @Throws(IOException::class)
@Synchronized
@Throws(IOException::class)
private fun start(inetSocketAddress: InetSocketAddress) {
check(!shutdown) { "shutdown() already called" }
if (started) return
@ -432,9 +441,10 @@ class MockWebServer : Closeable {
processHandshakeFailure(raw)
return
}
socket = sslSocketFactory!!.createSocket(
raw, raw.inetAddress.hostAddress, raw.port, true
)
socket =
sslSocketFactory!!.createSocket(
raw, raw.inetAddress.hostAddress, raw.port, true,
)
val sslSocket = socket as SSLSocket
sslSocket.useClientMode = false
if (clientAuth == CLIENT_AUTH_REQUIRED) {
@ -452,10 +462,11 @@ class MockWebServer : Closeable {
if (protocolNegotiationEnabled) {
val protocolString = Platform.get().getSelectedProtocol(sslSocket)
protocol = when {
protocolString != null -> Protocol.get(protocolString)
else -> Protocol.HTTP_1_1
}
protocol =
when {
protocolString != null -> Protocol.get(protocolString)
else -> Protocol.HTTP_1_1
}
Platform.get().afterHandshake(sslSocket)
} else {
protocol = Protocol.HTTP_1_1
@ -463,10 +474,11 @@ class MockWebServer : Closeable {
openClientSockets.remove(raw)
}
else -> {
protocol = when {
Protocol.H2_PRIOR_KNOWLEDGE in protocols -> Protocol.H2_PRIOR_KNOWLEDGE
else -> Protocol.HTTP_1_1
}
protocol =
when {
Protocol.H2_PRIOR_KNOWLEDGE in protocols -> Protocol.H2_PRIOR_KNOWLEDGE
else -> Protocol.HTTP_1_1
}
socket = raw
}
}
@ -478,10 +490,11 @@ class MockWebServer : Closeable {
if (protocol === Protocol.HTTP_2 || protocol === Protocol.H2_PRIOR_KNOWLEDGE) {
val http2SocketHandler = Http2SocketHandler(socket, protocol)
val connection = Http2Connection.Builder(false, taskRunner)
.socket(socket)
.listener(http2SocketHandler)
.build()
val connection =
Http2Connection.Builder(false, taskRunner)
.socket(socket)
.listener(http2SocketHandler)
.build()
connection.start()
openConnections.add(connection)
openClientSockets.remove(socket)
@ -498,7 +511,7 @@ class MockWebServer : Closeable {
if (sequenceNumber == 0) {
logger.warning(
"${this@MockWebServer} connection from ${raw.inetAddress} didn't make a request"
"${this@MockWebServer} connection from ${raw.inetAddress} didn't make a request",
)
}
@ -538,7 +551,7 @@ class MockWebServer : Closeable {
private fun processOneRequest(
socket: Socket,
source: BufferedSource,
sink: BufferedSink
sink: BufferedSink,
): Boolean {
if (source.exhausted()) {
return false // No more requests on this socket.
@ -581,7 +594,7 @@ class MockWebServer : Closeable {
if (logger.isLoggable(Level.FINE)) {
logger.fine(
"${this@MockWebServer} received request: $request and responded: $response"
"${this@MockWebServer} received request: $request and responded: $response",
)
}
@ -607,9 +620,13 @@ class MockWebServer : Closeable {
val context = SSLContext.getInstance("TLS")
context.init(null, arrayOf<TrustManager>(UNTRUSTED_TRUST_MANAGER), SecureRandom())
val sslSocketFactory = context.socketFactory
val socket = sslSocketFactory.createSocket(
raw, raw.inetAddress.hostAddress, raw.port, true
) as SSLSocket
val socket =
sslSocketFactory.createSocket(
raw,
raw.inetAddress.hostAddress,
raw.port,
true,
) as SSLSocket
try {
socket.startHandshake() // we're testing a handshake failure
throw AssertionError()
@ -619,10 +636,20 @@ class MockWebServer : Closeable {
}
@Throws(InterruptedException::class)
private fun dispatchBookkeepingRequest(sequenceNumber: Int, socket: Socket) {
val request = RecordedRequest(
"", headersOf(), emptyList(), 0L, Buffer(), sequenceNumber, socket
)
private fun dispatchBookkeepingRequest(
sequenceNumber: Int,
socket: Socket,
) {
val request =
RecordedRequest(
"",
headersOf(),
emptyList(),
0L,
Buffer(),
sequenceNumber,
socket,
)
atomicRequestCount.incrementAndGet()
requestQueue.add(request)
dispatcher.dispatch(request)
@ -634,7 +661,7 @@ class MockWebServer : Closeable {
socket: Socket,
source: BufferedSource,
sink: BufferedSink,
sequenceNumber: Int
sequenceNumber: Int,
): RecordedRequest {
var request = ""
val headers = Headers.Builder()
@ -674,12 +701,13 @@ class MockWebServer : Closeable {
var hasBody = false
val policy = dispatcher.peek()
val requestBodySink = requestBody.withThrottlingAndSocketPolicy(
policy = policy,
disconnectHalfway = policy.socketPolicy == DisconnectDuringRequestBody,
expectedByteCount = contentLength,
socket = socket,
).buffer()
val requestBodySink =
requestBody.withThrottlingAndSocketPolicy(
policy = policy,
disconnectHalfway = policy.socketPolicy == DisconnectDuringRequestBody,
expectedByteCount = contentLength,
socket = socket,
).buffer()
requestBodySink.use {
when {
policy.socketPolicy is DoNotReadRequestBody -> {
@ -725,7 +753,7 @@ class MockWebServer : Closeable {
body = requestBody.buffer,
sequenceNumber = sequenceNumber,
socket = socket,
failure = failure
failure = failure,
)
}
@ -735,47 +763,52 @@ class MockWebServer : Closeable {
source: BufferedSource,
sink: BufferedSink,
request: RecordedRequest,
response: MockResponse
response: MockResponse,
) {
val key = request.headers["Sec-WebSocket-Key"]
val webSocketResponse = response.newBuilder()
.setHeader("Sec-WebSocket-Accept", WebSocketProtocol.acceptHeader(key!!))
.build()
val webSocketResponse =
response.newBuilder()
.setHeader("Sec-WebSocket-Accept", WebSocketProtocol.acceptHeader(key!!))
.build()
writeHttpResponse(socket, sink, webSocketResponse)
// Adapt the request and response into our Request and Response domain model.
val scheme = if (request.handshake != null) "https" else "http"
val authority = request.headers["Host"] // Has host and port.
val fancyRequest = Request.Builder()
.url("$scheme://$authority/")
.headers(request.headers)
.build()
val fancyResponse = Response.Builder()
.code(webSocketResponse.code)
.message(webSocketResponse.message)
.headers(webSocketResponse.headers)
.request(fancyRequest)
.protocol(Protocol.HTTP_1_1)
.build()
val fancyRequest =
Request.Builder()
.url("$scheme://$authority/")
.headers(request.headers)
.build()
val fancyResponse =
Response.Builder()
.code(webSocketResponse.code)
.message(webSocketResponse.message)
.headers(webSocketResponse.headers)
.request(fancyRequest)
.protocol(Protocol.HTTP_1_1)
.build()
val connectionClose = CountDownLatch(1)
val streams = object : RealWebSocket.Streams(false, source, sink) {
override fun close() = connectionClose.countDown()
val streams =
object : RealWebSocket.Streams(false, source, sink) {
override fun close() = connectionClose.countDown()
override fun cancel() {
socket.closeQuietly()
override fun cancel() {
socket.closeQuietly()
}
}
}
val webSocket = RealWebSocket(
taskRunner = taskRunner,
originalRequest = fancyRequest,
listener = webSocketResponse.webSocketListener!!,
random = SecureRandom(),
pingIntervalMillis = 0,
extensions = WebSocketExtensions.parse(webSocketResponse.headers),
// Compress all messages if compression is enabled.
minimumDeflateSize = 0L,
)
val webSocket =
RealWebSocket(
taskRunner = taskRunner,
originalRequest = fancyRequest,
listener = webSocketResponse.webSocketListener!!,
random = SecureRandom(),
pingIntervalMillis = 0,
extensions = WebSocketExtensions.parse(webSocketResponse.headers),
// Compress all messages if compression is enabled.
minimumDeflateSize = 0L,
)
val name = "MockWebServer WebSocket ${request.path!!}"
webSocket.initReaderAndWriter(name, streams)
try {
@ -789,7 +822,11 @@ class MockWebServer : Closeable {
}
@Throws(IOException::class)
private fun writeHttpResponse(socket: Socket, sink: BufferedSink, response: MockResponse) {
private fun writeHttpResponse(
socket: Socket,
sink: BufferedSink,
response: MockResponse,
) {
sleepNanos(response.headersDelayNanos)
sink.writeUtf8(response.status)
sink.writeUtf8("\r\n")
@ -798,12 +835,13 @@ class MockWebServer : Closeable {
val body = response.body ?: return
sleepNanos(response.bodyDelayNanos)
val responseBodySink = sink.withThrottlingAndSocketPolicy(
policy = response,
disconnectHalfway = response.socketPolicy == DisconnectDuringResponseBody,
expectedByteCount = body.contentLength,
socket = socket,
).buffer()
val responseBodySink =
sink.withThrottlingAndSocketPolicy(
policy = response,
disconnectHalfway = response.socketPolicy == DisconnectDuringResponseBody,
expectedByteCount = body.contentLength,
socket = socket,
).buffer()
body.writeTo(responseBodySink)
responseBodySink.emit()
@ -813,7 +851,10 @@ class MockWebServer : Closeable {
}
@Throws(IOException::class)
private fun writeHeaders(sink: BufferedSink, headers: Headers) {
private fun writeHeaders(
sink: BufferedSink,
headers: Headers,
) {
for ((name, value) in headers) {
sink.writeUtf8(name)
sink.writeUtf8(": ")
@ -834,25 +875,28 @@ class MockWebServer : Closeable {
var result: Sink = this
if (policy.throttlePeriodNanos > 0L) {
result = ThrottledSink(
delegate = result,
bytesPerPeriod = policy.throttleBytesPerPeriod,
periodDelayNanos = policy.throttlePeriodNanos,
)
result =
ThrottledSink(
delegate = result,
bytesPerPeriod = policy.throttleBytesPerPeriod,
periodDelayNanos = policy.throttlePeriodNanos,
)
}
if (disconnectHalfway) {
val halfwayByteCount = when {
expectedByteCount != -1L -> expectedByteCount / 2
else -> 0L
}
result = TriggerSink(
delegate = result,
triggerByteCount = halfwayByteCount,
) {
result.flush()
socket.close()
}
val halfwayByteCount =
when {
expectedByteCount != -1L -> expectedByteCount / 2
else -> 0L
}
result =
TriggerSink(
delegate = result,
triggerByteCount = halfwayByteCount,
) {
result.flush()
socket.close()
}
}
return result
@ -871,13 +915,16 @@ class MockWebServer : Closeable {
/** A buffer wrapper that drops data after [bodyLimit] bytes. */
private class TruncatingBuffer(
private var remainingByteCount: Long
private var remainingByteCount: Long,
) : Sink {
val buffer = Buffer()
var receivedByteCount = 0L
@Throws(IOException::class)
override fun write(source: Buffer, byteCount: Long) {
override fun write(
source: Buffer,
byteCount: Long,
) {
val toRead = minOf(remainingByteCount, byteCount)
if (toRead > 0L) {
source.read(buffer, toRead)
@ -904,7 +951,7 @@ class MockWebServer : Closeable {
/** Processes HTTP requests layered over HTTP/2. */
private inner class Http2SocketHandler constructor(
private val socket: Socket,
private val protocol: Protocol
private val protocol: Protocol,
) : Http2Connection.Listener() {
private val sequenceNumber = AtomicInteger()
@ -935,7 +982,7 @@ class MockWebServer : Closeable {
if (logger.isLoggable(Level.FINE)) {
logger.fine(
"${this@MockWebServer} received request: $request " +
"and responded: $response protocol is $protocol"
"and responded: $response protocol is $protocol",
)
}
@ -990,12 +1037,13 @@ class MockWebServer : Closeable {
if (readBody && peek.streamHandler == null && peek.socketPolicy !is DoNotReadRequestBody) {
try {
val contentLengthString = headers["content-length"]
val requestBodySink = body.withThrottlingAndSocketPolicy(
policy = peek,
disconnectHalfway = peek.socketPolicy == DisconnectDuringRequestBody,
expectedByteCount = contentLengthString?.toLong() ?: Long.MAX_VALUE,
socket = socket,
).buffer()
val requestBodySink =
body.withThrottlingAndSocketPolicy(
policy = peek,
disconnectHalfway = peek.socketPolicy == DisconnectDuringRequestBody,
expectedByteCount = contentLengthString?.toLong() ?: Long.MAX_VALUE,
socket = socket,
).buffer()
requestBodySink.use {
it.writeAll(stream.getSource())
}
@ -1012,7 +1060,7 @@ class MockWebServer : Closeable {
body = body,
sequenceNumber = sequenceNumber.getAndIncrement(),
socket = socket,
failure = exception
failure = exception,
)
}
@ -1029,7 +1077,7 @@ class MockWebServer : Closeable {
private fun writeResponse(
stream: Http2Stream,
request: RecordedRequest,
response: MockResponse
response: MockResponse,
) {
val settings = response.settings
stream.connection.setSettings(settings)
@ -1042,9 +1090,11 @@ class MockWebServer : Closeable {
val trailers = response.trailers
val body = response.body
val streamHandler = response.streamHandler
val outFinished = (body == null &&
response.pushPromises.isEmpty() &&
streamHandler == null)
val outFinished = (
body == null &&
response.pushPromises.isEmpty() &&
streamHandler == null
)
val flushHeaders = body == null || bodyDelayNanos != 0L
require(!outFinished || trailers.size == 0) {
"unsupported: no body and non-empty trailers $trailers"
@ -1059,12 +1109,13 @@ class MockWebServer : Closeable {
pushPromises(stream, request, response.pushPromises)
if (body != null) {
sleepNanos(bodyDelayNanos)
val responseBodySink = stream.getSink().withThrottlingAndSocketPolicy(
policy = response,
disconnectHalfway = response.socketPolicy == DisconnectDuringResponseBody,
expectedByteCount = body.contentLength,
socket = socket
).buffer()
val responseBodySink =
stream.getSink().withThrottlingAndSocketPolicy(
policy = response,
disconnectHalfway = response.socketPolicy == DisconnectDuringResponseBody,
expectedByteCount = body.contentLength,
socket = socket,
).buffer()
responseBodySink.use {
body.writeTo(responseBodySink)
}
@ -1079,7 +1130,7 @@ class MockWebServer : Closeable {
private fun pushPromises(
stream: Http2Stream,
request: RecordedRequest,
promises: List<PushPromise>
promises: List<PushPromise>,
) {
for (pushPromise in promises) {
val pushedHeaders = mutableListOf<Header>()
@ -1100,8 +1151,8 @@ class MockWebServer : Closeable {
bodySize = 0,
body = Buffer(),
sequenceNumber = sequenceNumber.getAndIncrement(),
socket = socket
)
socket = socket,
),
)
val hasBody = pushPromise.response.body != null
val pushedStream = stream.connection.pushStream(stream.id, pushedHeaders, hasBody)
@ -1115,20 +1166,21 @@ class MockWebServer : Closeable {
private const val CLIENT_AUTH_REQUESTED = 1
private const val CLIENT_AUTH_REQUIRED = 2
private val UNTRUSTED_TRUST_MANAGER = object : X509TrustManager {
@Throws(CertificateException::class)
override fun checkClientTrusted(
chain: Array<X509Certificate>,
authType: String
) = throw CertificateException()
private val UNTRUSTED_TRUST_MANAGER =
object : X509TrustManager {
@Throws(CertificateException::class)
override fun checkClientTrusted(
chain: Array<X509Certificate>,
authType: String,
) = throw CertificateException()
override fun checkServerTrusted(
chain: Array<X509Certificate>,
authType: String
) = throw AssertionError()
override fun checkServerTrusted(
chain: Array<X509Certificate>,
authType: String,
) = throw AssertionError()
override fun getAcceptedIssuers(): Array<X509Certificate> = throw AssertionError()
}
override fun getAcceptedIssuers(): Array<X509Certificate> = throw AssertionError()
}
private val logger = Logger.getLogger(MockWebServer::class.java.name)
}

View File

@ -68,11 +68,12 @@ open class QueueDispatcher : Dispatcher() {
}
open fun setFailFast(failFast: Boolean) {
val failFastResponse = if (failFast) {
MockResponse(code = HttpURLConnection.HTTP_NOT_FOUND)
} else {
null
}
val failFastResponse =
if (failFast) {
MockResponse(code = HttpURLConnection.HTTP_NOT_FOUND)
} else {
null
}
setFailFast(failFastResponse)
}

View File

@ -31,34 +31,28 @@ import okio.Buffer
/** An HTTP request that came into the mock web server. */
class RecordedRequest(
val requestLine: String,
/** All headers. */
val headers: Headers,
/**
* The sizes of the chunks of this request's body, or an empty list if the request's body
* was empty or unchunked.
*/
val chunkSizes: List<Int>,
/** The total size of the body of this POST request (before truncation).*/
val bodySize: Long,
/** The body of this POST request. This may be truncated. */
val body: Buffer,
/**
* The index of this request on its HTTP connection. Since a single HTTP connection may serve
* multiple requests, each request is assigned its own sequence number.
*/
val sequenceNumber: Int,
socket: Socket,
/**
* The failure MockWebServer recorded when attempting to decode this request. If, for example,
* the inbound request was truncated, this exception will be non-null.
*/
val failure: IOException? = null
val failure: IOException? = null,
) {
val method: String?
val path: String?
@ -102,12 +96,13 @@ class RecordedRequest(
val scheme = if (socket is SSLSocket) "https" else "http"
val localPort = socket.localPort
val hostAndPort = headers[":authority"]
?: headers["Host"]
?: when (val inetAddress = socket.localAddress) {
is Inet6Address -> "[${inetAddress.hostAddress}]:$localPort"
else -> "${inetAddress.hostAddress}:$localPort"
}
val hostAndPort =
headers[":authority"]
?: headers["Host"]
?: when (val inetAddress = socket.localAddress) {
is Inet6Address -> "[${inetAddress.hostAddress}]:$localPort"
else -> "${inetAddress.hostAddress}:$localPort"
}
// Allow null in failure case to allow for testing bad requests
this.requestUrl = "$scheme://$hostAndPort$path".toHttpUrlOrNull()

View File

@ -30,7 +30,6 @@ package mockwebserver3
* server has closed the socket before follow up requests are made.
*/
sealed interface SocketPolicy {
/**
* Shutdown [MockWebServer] after writing response.
*/

View File

@ -29,7 +29,11 @@ internal class ThrottledSink(
private val periodDelayNanos: Long,
) : Sink by delegate {
private var bytesWrittenSinceLastDelay = 0L
override fun write(source: Buffer, byteCount: Long) {
override fun write(
source: Buffer,
byteCount: Long,
) {
var bytesLeft = byteCount
while (bytesLeft > 0) {

View File

@ -30,7 +30,10 @@ internal class TriggerSink(
) : Sink by delegate {
private var bytesWritten = 0L
override fun write(source: Buffer, byteCount: Long) {
override fun write(
source: Buffer,
byteCount: Long,
) {
if (byteCount == 0L) return // Avoid double-triggering.
if (bytesWritten == triggerByteCount) {

View File

@ -34,37 +34,41 @@ class MockStreamHandler : StreamHandler {
private val actions = LinkedBlockingQueue<Action>()
private val results = LinkedBlockingQueue<FutureTask<Void>>()
fun receiveRequest(expected: String) = apply {
actions += { stream ->
val actual = stream.requestBody.readUtf8(expected.utf8Size())
if (actual != expected) throw AssertionError("$actual != $expected")
}
}
fun exhaustRequest() = apply {
actions += { stream ->
if (!stream.requestBody.exhausted()) throw AssertionError("expected exhausted")
}
}
fun cancelStream() = apply {
actions += { stream -> stream.cancel() }
}
fun requestIOException() = apply {
actions += { stream ->
try {
stream.requestBody.exhausted()
throw AssertionError("expected IOException")
} catch (expected: IOException) {
fun receiveRequest(expected: String) =
apply {
actions += { stream ->
val actual = stream.requestBody.readUtf8(expected.utf8Size())
if (actual != expected) throw AssertionError("$actual != $expected")
}
}
fun exhaustRequest() =
apply {
actions += { stream ->
if (!stream.requestBody.exhausted()) throw AssertionError("expected exhausted")
}
}
fun cancelStream() =
apply {
actions += { stream -> stream.cancel() }
}
fun requestIOException() =
apply {
actions += { stream ->
try {
stream.requestBody.exhausted()
throw AssertionError("expected IOException")
} catch (expected: IOException) {
}
}
}
}
@JvmOverloads
fun sendResponse(
s: String,
responseSent: CountDownLatch = CountDownLatch(0)
responseSent: CountDownLatch = CountDownLatch(0),
) = apply {
actions += { stream ->
stream.responseBody.writeUtf8(s)
@ -73,11 +77,15 @@ class MockStreamHandler : StreamHandler {
}
}
fun exhaustResponse() = apply {
actions += { stream -> stream.responseBody.close() }
}
fun exhaustResponse() =
apply {
actions += { stream -> stream.responseBody.close() }
}
fun sleep(duration: Long, unit: TimeUnit) = apply {
fun sleep(
duration: Long,
unit: TimeUnit,
) = apply {
actions += { Thread.sleep(unit.toMillis(duration)) }
}
@ -88,9 +96,7 @@ class MockStreamHandler : StreamHandler {
}
/** Returns a task that processes both request and response from [stream]. */
private fun serviceStreamTask(
stream: Stream,
): FutureTask<Void> {
private fun serviceStreamTask(stream: Stream): FutureTask<Void> {
return FutureTask<Void> {
stream.requestBody.use {
stream.responseBody.use {
@ -106,8 +112,9 @@ class MockStreamHandler : StreamHandler {
/** Returns once all stream actions complete successfully. */
fun awaitSuccess() {
val futureTask = results.poll(5, TimeUnit.SECONDS)
?: throw AssertionError("no onRequest call received")
val futureTask =
results.poll(5, TimeUnit.SECONDS)
?: throw AssertionError("no onRequest call received")
futureTask.get(5, TimeUnit.SECONDS)
}
}

View File

@ -37,12 +37,13 @@ class CustomDispatcherTest {
@Test
fun simpleDispatch() {
val requestsMade = mutableListOf<RecordedRequest>()
val dispatcher: Dispatcher = object : Dispatcher() {
override fun dispatch(request: RecordedRequest): MockResponse {
requestsMade.add(request)
return MockResponse()
val dispatcher: Dispatcher =
object : Dispatcher() {
override fun dispatch(request: RecordedRequest): MockResponse {
requestsMade.add(request)
return MockResponse()
}
}
}
assertThat(requestsMade.size).isEqualTo(0)
mockWebServer.dispatcher = dispatcher
val url = mockWebServer.url("/").toUrl()
@ -59,14 +60,15 @@ class CustomDispatcherTest {
val secondRequest = "/bar"
val firstRequest = "/foo"
val latch = CountDownLatch(1)
val dispatcher: Dispatcher = object : Dispatcher() {
override fun dispatch(request: RecordedRequest): MockResponse {
if (request.path == firstRequest) {
latch.await()
val dispatcher: Dispatcher =
object : Dispatcher() {
override fun dispatch(request: RecordedRequest): MockResponse {
if (request.path == firstRequest) {
latch.await()
}
return MockResponse()
}
return MockResponse()
}
}
mockWebServer.dispatcher = dispatcher
val startsFirst = buildRequestThread(firstRequest, firstResponseCode)
startsFirst.start()
@ -85,7 +87,10 @@ class CustomDispatcherTest {
assertThat(secondResponseCode.get()).isEqualTo(200)
}
private fun buildRequestThread(path: String, responseCode: AtomicInteger): Thread {
private fun buildRequestThread(
path: String,
responseCode: AtomicInteger,
): Thread {
return Thread {
val url = mockWebServer.url(path).toUrl()
val conn: HttpURLConnection

View File

@ -55,17 +55,19 @@ class MockResponseSniTest {
val handshakeCertificates = localhost()
server.useHttps(handshakeCertificates.sslSocketFactory())
val dns = Dns {
Dns.SYSTEM.lookup(server.hostName)
}
val dns =
Dns {
Dns.SYSTEM.lookup(server.hostName)
}
val client = clientTestRule.newClientBuilder()
.sslSocketFactory(
handshakeCertificates.sslSocketFactory(),
handshakeCertificates.trustManager
)
.dns(dns)
.build()
val client =
clientTestRule.newClientBuilder()
.sslSocketFactory(
handshakeCertificates.sslSocketFactory(),
handshakeCertificates.trustManager,
)
.dns(dns)
.build()
server.enqueue(MockResponse())
@ -84,36 +86,41 @@ class MockResponseSniTest {
*/
@Test
fun domainFronting() {
val heldCertificate = HeldCertificate.Builder()
.commonName("server name")
.addSubjectAlternativeName("url-host.com")
.build()
val handshakeCertificates = HandshakeCertificates.Builder()
.heldCertificate(heldCertificate)
.addTrustedCertificate(heldCertificate.certificate)
.build()
val heldCertificate =
HeldCertificate.Builder()
.commonName("server name")
.addSubjectAlternativeName("url-host.com")
.build()
val handshakeCertificates =
HandshakeCertificates.Builder()
.heldCertificate(heldCertificate)
.addTrustedCertificate(heldCertificate.certificate)
.build()
server.useHttps(handshakeCertificates.sslSocketFactory())
val dns = Dns {
Dns.SYSTEM.lookup(server.hostName)
}
val dns =
Dns {
Dns.SYSTEM.lookup(server.hostName)
}
val client = clientTestRule.newClientBuilder()
.sslSocketFactory(
handshakeCertificates.sslSocketFactory(),
handshakeCertificates.trustManager
)
.dns(dns)
.build()
val client =
clientTestRule.newClientBuilder()
.sslSocketFactory(
handshakeCertificates.sslSocketFactory(),
handshakeCertificates.trustManager,
)
.dns(dns)
.build()
server.enqueue(MockResponse())
val call = client.newCall(
Request(
url = "https://url-host.com:${server.port}/".toHttpUrl(),
headers = headersOf("Host", "header-host"),
val call =
client.newCall(
Request(
url = "https://url-host.com:${server.port}/".toHttpUrl(),
headers = headersOf("Host", "header-host"),
),
)
)
val response = call.execute()
assertThat(response.isSuccessful).isTrue()
@ -150,34 +157,39 @@ class MockResponseSniTest {
* tell MockWebServer to act as a proxy.
*/
private fun requestToHostnameViaProxy(hostnameOrIpAddress: String): RecordedRequest {
val heldCertificate = HeldCertificate.Builder()
.commonName("server name")
.addSubjectAlternativeName(hostnameOrIpAddress)
.build()
val handshakeCertificates = HandshakeCertificates.Builder()
.heldCertificate(heldCertificate)
.addTrustedCertificate(heldCertificate.certificate)
.build()
val heldCertificate =
HeldCertificate.Builder()
.commonName("server name")
.addSubjectAlternativeName(hostnameOrIpAddress)
.build()
val handshakeCertificates =
HandshakeCertificates.Builder()
.heldCertificate(heldCertificate)
.addTrustedCertificate(heldCertificate.certificate)
.build()
server.useHttps(handshakeCertificates.sslSocketFactory())
val client = clientTestRule.newClientBuilder()
.sslSocketFactory(
handshakeCertificates.sslSocketFactory(),
handshakeCertificates.trustManager
)
.proxy(server.toProxyAddress())
.build()
val client =
clientTestRule.newClientBuilder()
.sslSocketFactory(
handshakeCertificates.sslSocketFactory(),
handshakeCertificates.trustManager,
)
.proxy(server.toProxyAddress())
.build()
server.enqueue(MockResponse(inTunnel = true))
server.enqueue(MockResponse())
val call = client.newCall(
Request(
url = server.url("/").newBuilder()
.host(hostnameOrIpAddress)
.build()
val call =
client.newCall(
Request(
url =
server.url("/").newBuilder()
.host(hostnameOrIpAddress)
.build(),
),
)
)
val response = call.execute()
assertThat(response.isSuccessful).isTrue()

View File

@ -92,15 +92,16 @@ class MockWebServerTest {
@Test
fun setResponseMockReason() {
val reasons = arrayOf<String?>(
"Mock Response",
"Informational",
"OK",
"Redirection",
"Client Error",
"Server Error",
"Mock Response"
)
val reasons =
arrayOf<String?>(
"Mock Response",
"Informational",
"OK",
"Redirection",
"Client Error",
"Server Error",
"Mock Response",
)
for (i in 0..599) {
val builder = MockResponse.Builder().code(i)
val expectedReason = reasons[i / 100]
@ -128,30 +129,33 @@ class MockWebServerTest {
@Test
fun mockResponseAddHeader() {
val builder = MockResponse.Builder()
.clearHeaders()
.addHeader("Cookie: s=square")
.addHeader("Cookie", "a=android")
val builder =
MockResponse.Builder()
.clearHeaders()
.addHeader("Cookie: s=square")
.addHeader("Cookie", "a=android")
assertThat(headersToList(builder)).containsExactly("Cookie: s=square", "Cookie: a=android")
}
@Test
fun mockResponseSetHeader() {
val builder = MockResponse.Builder()
.clearHeaders()
.addHeader("Cookie: s=square")
.addHeader("Cookie: a=android")
.addHeader("Cookies: delicious")
val builder =
MockResponse.Builder()
.clearHeaders()
.addHeader("Cookie: s=square")
.addHeader("Cookie: a=android")
.addHeader("Cookies: delicious")
builder.setHeader("cookie", "r=robot")
assertThat(headersToList(builder)).containsExactly("Cookies: delicious", "cookie: r=robot")
}
@Test
fun mockResponseSetHeaders() {
val builder = MockResponse.Builder()
.clearHeaders()
.addHeader("Cookie: s=square")
.addHeader("Cookies: delicious")
val builder =
MockResponse.Builder()
.clearHeaders()
.addHeader("Cookie: s=square")
.addHeader("Cookies: delicious")
builder.headers(Headers.Builder().add("Cookie", "a=android").build())
assertThat(headersToList(builder)).containsExactly("Cookie: a=android")
}
@ -175,14 +179,18 @@ class MockWebServerTest {
@Test
fun redirect() {
server.enqueue(MockResponse.Builder()
.code(HttpURLConnection.HTTP_MOVED_TEMP)
.addHeader("Location: " + server.url("/new-path"))
.body("This page has moved!")
.build())
server.enqueue(MockResponse.Builder()
.body("This is the new location!")
.build())
server.enqueue(
MockResponse.Builder()
.code(HttpURLConnection.HTTP_MOVED_TEMP)
.addHeader("Location: " + server.url("/new-path"))
.body("This page has moved!")
.build(),
)
server.enqueue(
MockResponse.Builder()
.body("This is the new location!")
.build(),
)
val connection = server.url("/").toUrl().openConnection()
val reader = BufferedReader(InputStreamReader(connection!!.getInputStream(), UTF_8))
assertThat(reader.readLine()).isEqualTo("This is the new location!")
@ -203,9 +211,11 @@ class MockWebServerTest {
Thread.sleep(1000)
} catch (ignored: InterruptedException) {
}
server.enqueue(MockResponse.Builder()
.body("enqueued in the background")
.build())
server.enqueue(
MockResponse.Builder()
.body("enqueued in the background")
.build(),
)
}.start()
val connection = server.url("/").toUrl().openConnection()
val reader = BufferedReader(InputStreamReader(connection!!.getInputStream(), UTF_8))
@ -214,11 +224,13 @@ class MockWebServerTest {
@Test
fun nonHexadecimalChunkSize() {
server.enqueue(MockResponse.Builder()
.body("G\r\nxxxxxxxxxxxxxxxx\r\n0\r\n\r\n")
.clearHeaders()
.addHeader("Transfer-encoding: chunked")
.build())
server.enqueue(
MockResponse.Builder()
.body("G\r\nxxxxxxxxxxxxxxxx\r\n0\r\n\r\n")
.clearHeaders()
.addHeader("Transfer-encoding: chunked")
.build(),
)
val connection = server.url("/").toUrl().openConnection()
try {
connection.getInputStream().read()
@ -230,14 +242,18 @@ class MockWebServerTest {
@Test
fun responseTimeout() {
server.enqueue(MockResponse.Builder()
.body("ABC")
.clearHeaders()
.addHeader("Content-Length: 4")
.build())
server.enqueue(MockResponse.Builder()
.body("DEF")
.build())
server.enqueue(
MockResponse.Builder()
.body("ABC")
.clearHeaders()
.addHeader("Content-Length: 4")
.build(),
)
server.enqueue(
MockResponse.Builder()
.body("DEF")
.build(),
)
val urlConnection = server.url("/").toUrl().openConnection()
urlConnection!!.readTimeout = 1000
val inputStream = urlConnection.getInputStream()
@ -263,9 +279,11 @@ class MockWebServerTest {
@Disabled("Not actually failing where expected")
@Test
fun disconnectAtStart() {
server.enqueue(MockResponse.Builder()
.socketPolicy(DisconnectAtStart)
.build())
server.enqueue(
MockResponse.Builder()
.socketPolicy(DisconnectAtStart)
.build(),
)
server.enqueue(MockResponse()) // The jdk's HttpUrlConnection is a bastard.
server.enqueue(MockResponse())
try {
@ -293,9 +311,11 @@ class MockWebServerTest {
@Test
fun throttleRequest() {
assumeNotWindows()
server.enqueue(MockResponse.Builder()
.throttleBody(3, 500, TimeUnit.MILLISECONDS)
.build())
server.enqueue(
MockResponse.Builder()
.throttleBody(3, 500, TimeUnit.MILLISECONDS)
.build(),
)
val startNanos = System.nanoTime()
val connection = server.url("/").toUrl().openConnection()
connection.doOutput = true
@ -314,10 +334,12 @@ class MockWebServerTest {
@Test
fun throttleResponse() {
assumeNotWindows()
server.enqueue(MockResponse.Builder()
.body("ABCDEF")
.throttleBody(3, 500, TimeUnit.MILLISECONDS)
.build())
server.enqueue(
MockResponse.Builder()
.body("ABCDEF")
.throttleBody(3, 500, TimeUnit.MILLISECONDS)
.build(),
)
val startNanos = System.nanoTime()
val connection = server.url("/").toUrl().openConnection()
val inputStream = connection!!.getInputStream()
@ -337,10 +359,12 @@ class MockWebServerTest {
@Test
fun delayResponse() {
assumeNotWindows()
server.enqueue(MockResponse.Builder()
.body("ABCDEF")
.bodyDelay(1, TimeUnit.SECONDS)
.build())
server.enqueue(
MockResponse.Builder()
.body("ABCDEF")
.bodyDelay(1, TimeUnit.SECONDS)
.build(),
)
val startNanos = System.nanoTime()
val connection = server.url("/").toUrl().openConnection()
val inputStream = connection!!.getInputStream()
@ -353,9 +377,11 @@ class MockWebServerTest {
@Test
fun disconnectRequestHalfway() {
server.enqueue(MockResponse.Builder()
.socketPolicy(DisconnectDuringRequestBody)
.build())
server.enqueue(
MockResponse.Builder()
.socketPolicy(DisconnectDuringRequestBody)
.build(),
)
// Limit the size of the request body that the server holds in memory to an arbitrary
// 3.5 MBytes so this test can pass on devices with little memory.
server.bodyLimit = 7 * 512 * 1024
@ -386,10 +412,12 @@ class MockWebServerTest {
@Test
fun disconnectResponseHalfway() {
server.enqueue(MockResponse.Builder()
.body("ab")
.socketPolicy(DisconnectDuringResponseBody)
.build())
server.enqueue(
MockResponse.Builder()
.body("ab")
.socketPolicy(DisconnectDuringResponseBody)
.build(),
)
val connection = server.url("/").toUrl().openConnection()
assertThat(connection!!.contentLength).isEqualTo(2)
val inputStream = connection.getInputStream()
@ -468,9 +496,11 @@ class MockWebServerTest {
@Test
fun requestUrlReconstructed() {
server.enqueue(MockResponse.Builder()
.body("hello world")
.build())
server.enqueue(
MockResponse.Builder()
.body("hello world")
.build(),
)
val url = server.url("/a/deep/path?key=foo%20bar").toUrl()
val connection = url.openConnection() as HttpURLConnection
val inputStream = connection.inputStream
@ -479,7 +509,8 @@ class MockWebServerTest {
assertThat(reader.readLine()).isEqualTo("hello world")
val request = server.takeRequest()
assertThat(request.requestLine).isEqualTo(
"GET /a/deep/path?key=foo%20bar HTTP/1.1")
"GET /a/deep/path?key=foo%20bar HTTP/1.1",
)
val requestUrl = request.requestUrl
assertThat(requestUrl!!.scheme).isEqualTo("http")
assertThat(requestUrl.host).isEqualTo(server.hostName)
@ -490,9 +521,11 @@ class MockWebServerTest {
@Test
fun shutdownServerAfterRequest() {
server.enqueue(MockResponse.Builder()
.socketPolicy(ShutdownServerAfterResponse)
.build())
server.enqueue(
MockResponse.Builder()
.socketPolicy(ShutdownServerAfterResponse)
.build(),
)
val url = server.url("/").toUrl()
val connection = url.openConnection() as HttpURLConnection
assertThat(connection.responseCode).isEqualTo(HttpURLConnection.HTTP_OK)
@ -506,9 +539,11 @@ class MockWebServerTest {
@Test
fun http100Continue() {
server.enqueue(MockResponse.Builder()
.body("response")
.build())
server.enqueue(
MockResponse.Builder()
.body("response")
.build(),
)
val url = server.url("/").toUrl()
val connection = url.openConnection() as HttpURLConnection
connection.doOutput = true
@ -523,11 +558,13 @@ class MockWebServerTest {
@Test
fun multiple1xxResponses() {
server.enqueue(MockResponse.Builder()
.add100Continue()
.add100Continue()
.body("response")
.build())
server.enqueue(
MockResponse.Builder()
.add100Continue()
.add100Continue()
.body("response")
.build(),
)
val url = server.url("/").toUrl()
val connection = url.openConnection() as HttpURLConnection
connection.doOutput = true
@ -546,8 +583,9 @@ class MockWebServerTest {
fail<Unit>()
} catch (expected: IllegalArgumentException) {
assertThat(expected.message).isEqualTo(
"protocols containing h2_prior_knowledge cannot use other protocols: "
+ "[h2_prior_knowledge, http/1.1]")
"protocols containing h2_prior_knowledge cannot use other protocols: " +
"[h2_prior_knowledge, http/1.1]",
)
}
}
@ -559,8 +597,9 @@ class MockWebServerTest {
fail<Unit>()
} catch (expected: IllegalArgumentException) {
assertThat(expected.message).isEqualTo(
"protocols containing h2_prior_knowledge cannot use other protocols: "
+ "[h2_prior_knowledge, h2_prior_knowledge]")
"protocols containing h2_prior_knowledge cannot use other protocols: " +
"[h2_prior_knowledge, h2_prior_knowledge]",
)
}
}
@ -575,9 +614,11 @@ class MockWebServerTest {
fun https() {
val handshakeCertificates = platform.localhostHandshakeCertificates()
server.useHttps(handshakeCertificates.sslSocketFactory())
server.enqueue(MockResponse.Builder()
.body("abc")
.build())
server.enqueue(
MockResponse.Builder()
.body("abc")
.build(),
)
val url = server.url("/")
val connection = url.toUrl().openConnection() as HttpsURLConnection
connection.sslSocketFactory = handshakeCertificates.sslSocketFactory()
@ -601,32 +642,40 @@ class MockWebServerTest {
platform.assumeNotBouncyCastle()
platform.assumeNotConscrypt()
val clientCa = HeldCertificate.Builder()
.certificateAuthority(0)
.build()
val serverCa = HeldCertificate.Builder()
.certificateAuthority(0)
.build()
val serverCertificate = HeldCertificate.Builder()
.signedBy(serverCa)
.addSubjectAlternativeName(server.hostName)
.build()
val serverHandshakeCertificates = HandshakeCertificates.Builder()
.addTrustedCertificate(clientCa.certificate)
.heldCertificate(serverCertificate)
.build()
val clientCa =
HeldCertificate.Builder()
.certificateAuthority(0)
.build()
val serverCa =
HeldCertificate.Builder()
.certificateAuthority(0)
.build()
val serverCertificate =
HeldCertificate.Builder()
.signedBy(serverCa)
.addSubjectAlternativeName(server.hostName)
.build()
val serverHandshakeCertificates =
HandshakeCertificates.Builder()
.addTrustedCertificate(clientCa.certificate)
.heldCertificate(serverCertificate)
.build()
server.useHttps(serverHandshakeCertificates.sslSocketFactory())
server.enqueue(MockResponse.Builder()
.body("abc")
.build())
server.enqueue(
MockResponse.Builder()
.body("abc")
.build(),
)
server.requestClientAuth()
val clientCertificate = HeldCertificate.Builder()
.signedBy(clientCa)
.build()
val clientHandshakeCertificates = HandshakeCertificates.Builder()
.addTrustedCertificate(serverCa.certificate)
.heldCertificate(clientCertificate)
.build()
val clientCertificate =
HeldCertificate.Builder()
.signedBy(clientCa)
.build()
val clientHandshakeCertificates =
HandshakeCertificates.Builder()
.addTrustedCertificate(serverCa.certificate)
.heldCertificate(clientCertificate)
.build()
val url = server.url("/")
val connection = url.toUrl().openConnection() as HttpsURLConnection
connection.sslSocketFactory = clientHandshakeCertificates.sslSocketFactory()
@ -647,13 +696,16 @@ class MockWebServerTest {
@Test
fun proxiedRequestGetsCorrectRequestUrl() {
server.enqueue(MockResponse.Builder()
.body("Result")
.build())
val proxiedClient = OkHttpClient.Builder()
.proxy(server.toProxyAddress())
.readTimeout(Duration.ofMillis(100))
.build()
server.enqueue(
MockResponse.Builder()
.body("Result")
.build(),
)
val proxiedClient =
OkHttpClient.Builder()
.proxy(server.toProxyAddress())
.readTimeout(Duration.ofMillis(100))
.build()
val request = Request.Builder().url("http://android.com/").build()
proxiedClient.newCall(request).execute().use { response ->
assertThat(response.body.string()).isEqualTo("Result")

View File

@ -14,6 +14,7 @@
* limitations under the License.
*/
@file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")
package mockwebserver3
import assertk.assertThat
@ -32,44 +33,50 @@ class RecordedRequestTest {
private val headers: Headers = EMPTY_HEADERS
@Test fun testIPv4() {
val socket = FakeSocket(
localAddress = InetAddress.getByAddress("127.0.0.1", byteArrayOf(127, 0, 0, 1)),
localPort = 80
)
val socket =
FakeSocket(
localAddress = InetAddress.getByAddress("127.0.0.1", byteArrayOf(127, 0, 0, 1)),
localPort = 80,
)
val request = RecordedRequest("GET / HTTP/1.1", headers, emptyList(), 0, Buffer(), 0, socket)
assertThat(request.requestUrl.toString()).isEqualTo("http://127.0.0.1/")
}
@Test fun testIpv6() {
val socket = FakeSocket(
localAddress = InetAddress.getByAddress(
"::1",
byteArrayOf(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1)
),
localPort = 80
)
val socket =
FakeSocket(
localAddress =
InetAddress.getByAddress(
"::1",
byteArrayOf(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1),
),
localPort = 80,
)
val request = RecordedRequest("GET / HTTP/1.1", headers, emptyList(), 0, Buffer(), 0, socket)
assertThat(request.requestUrl.toString()).isEqualTo("http://[::1]/")
}
@Test fun testUsesLocal() {
val socket = FakeSocket(
localAddress = InetAddress.getByAddress("127.0.0.1", byteArrayOf(127, 0, 0, 1)),
localPort = 80
)
val socket =
FakeSocket(
localAddress = InetAddress.getByAddress("127.0.0.1", byteArrayOf(127, 0, 0, 1)),
localPort = 80,
)
val request = RecordedRequest("GET / HTTP/1.1", headers, emptyList(), 0, Buffer(), 0, socket)
assertThat(request.requestUrl.toString()).isEqualTo("http://127.0.0.1/")
}
@Test fun testHostname() {
val headers = headersOf("Host", "host-from-header.com")
val socket = FakeSocket(
localAddress = InetAddress.getByAddress(
"host-from-address.com",
byteArrayOf(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1)
),
localPort = 80
)
val socket =
FakeSocket(
localAddress =
InetAddress.getByAddress(
"host-from-address.com",
byteArrayOf(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1),
),
localPort = 80,
)
val request = RecordedRequest("GET / HTTP/1.1", headers, emptyList(), 0, Buffer(), 0, socket)
assertThat(request.requestUrl.toString()).isEqualTo("http://host-from-header.com/")
}
@ -78,11 +85,14 @@ class RecordedRequestTest {
private val localAddress: InetAddress,
private val localPort: Int,
private val remoteAddress: InetAddress = localAddress,
private val remotePort: Int = 1234
private val remotePort: Int = 1234,
) : Socket() {
override fun getInetAddress() = remoteAddress
override fun getLocalAddress() = localAddress
override fun getLocalPort() = localPort
override fun getPort() = remotePort
}
}

View File

@ -39,7 +39,7 @@ import okio.source
/** A basic HTTP/2 server that serves the contents of a local directory. */
class Http2Server(
private val baseDirectory: File,
private val sslSocketFactory: SSLSocketFactory
private val sslSocketFactory: SSLSocketFactory,
) : Http2Connection.Listener() {
private fun run() {
val serverSocket = ServerSocket(8888)
@ -54,10 +54,11 @@ class Http2Server(
if (protocol != Protocol.HTTP_2) {
throw ProtocolException("Protocol $protocol unsupported")
}
val connection = Http2Connection.Builder(false, TaskRunner.INSTANCE)
.socket(sslSocket)
.listener(this)
.build()
val connection =
Http2Connection.Builder(false, TaskRunner.INSTANCE)
.socket(sslSocket)
.listener(this)
.build()
connection.start()
} catch (e: IOException) {
logger.log(Level.INFO, "Http2Server connection failure: $e")
@ -70,11 +71,13 @@ class Http2Server(
}
private fun doSsl(socket: Socket): SSLSocket {
val sslSocket = sslSocketFactory.createSocket(
socket, socket.inetAddress.hostAddress,
socket.port,
true
) as SSLSocket
val sslSocket =
sslSocketFactory.createSocket(
socket,
socket.inetAddress.hostAddress,
socket.port,
true,
) as SSLSocket
sslSocket.useClientMode = false
Platform.get().configureTlsExtensions(sslSocket, null, listOf(Protocol.HTTP_2))
sslSocket.startHandshake()
@ -111,32 +114,40 @@ class Http2Server(
}
}
private fun send404(stream: Http2Stream, path: String) {
val responseHeaders = listOf(
Header(":status", "404"),
Header(":version", "HTTP/1.1"),
Header("content-type", "text/plain")
)
private fun send404(
stream: Http2Stream,
path: String,
) {
val responseHeaders =
listOf(
Header(":status", "404"),
Header(":version", "HTTP/1.1"),
Header("content-type", "text/plain"),
)
stream.writeHeaders(
responseHeaders = responseHeaders,
outFinished = false,
flushHeaders = false
flushHeaders = false,
)
val out = stream.getSink().buffer()
out.writeUtf8("Not found: $path")
out.close()
}
private fun serveDirectory(stream: Http2Stream, files: Array<File>) {
val responseHeaders = listOf(
Header(":status", "200"),
Header(":version", "HTTP/1.1"),
Header("content-type", "text/html; charset=UTF-8")
)
private fun serveDirectory(
stream: Http2Stream,
files: Array<File>,
) {
val responseHeaders =
listOf(
Header(":status", "200"),
Header(":version", "HTTP/1.1"),
Header("content-type", "text/html; charset=UTF-8"),
)
stream.writeHeaders(
responseHeaders = responseHeaders,
outFinished = false,
flushHeaders = false
flushHeaders = false,
)
val out = stream.getSink().buffer()
for (file in files) {
@ -146,16 +157,20 @@ class Http2Server(
out.close()
}
private fun serveFile(stream: Http2Stream, file: File) {
val responseHeaders = listOf(
Header(":status", "200"),
Header(":version", "HTTP/1.1"),
Header("content-type", contentType(file))
)
private fun serveFile(
stream: Http2Stream,
file: File,
) {
val responseHeaders =
listOf(
Header(":status", "200"),
Header(":version", "HTTP/1.1"),
Header("content-type", contentType(file)),
)
stream.writeHeaders(
responseHeaders = responseHeaders,
outFinished = false,
flushHeaders = false
flushHeaders = false,
)
file.source().use { source ->
stream.getSink().buffer().use { sink ->
@ -186,10 +201,11 @@ class Http2Server(
println("Usage: Http2Server <base directory>")
return
}
val server = Http2Server(
File(args[0]),
localhost().sslContext().socketFactory
)
val server =
Http2Server(
File(args[0]),
localhost().sslContext().socketFactory,
)
server.run()
}
}

View File

@ -15,19 +15,22 @@
*/
package okhttp3
import java.io.OutputStream
import java.io.PrintStream
import org.junit.platform.engine.TestExecutionResult
import org.junit.platform.launcher.TestExecutionListener
import org.junit.platform.launcher.TestIdentifier
import org.junit.platform.launcher.TestPlan
import java.io.OutputStream
import java.io.PrintStream
object DotListener: TestExecutionListener {
object DotListener : TestExecutionListener {
private var originalSystemErr: PrintStream? = null
private var originalSystemOut: PrintStream? = null
private var testCount = 0
override fun executionSkipped(testIdentifier: TestIdentifier, reason: String) {
override fun executionSkipped(
testIdentifier: TestIdentifier,
reason: String,
) {
printStatus("-")
}
@ -40,7 +43,7 @@ object DotListener: TestExecutionListener {
override fun executionFinished(
testIdentifier: TestIdentifier,
testExecutionResult: TestExecutionResult
testExecutionResult: TestExecutionResult,
) {
if (!testIdentifier.isContainer) {
when (testExecutionResult.status!!) {
@ -59,8 +62,8 @@ object DotListener: TestExecutionListener {
originalSystemOut = System.out
originalSystemErr = System.err
System.setOut(object: PrintStream(OutputStream.nullOutputStream()) {})
System.setErr(object: PrintStream(OutputStream.nullOutputStream()) {})
System.setOut(object : PrintStream(OutputStream.nullOutputStream()) {})
System.setErr(object : PrintStream(OutputStream.nullOutputStream()) {})
}
fun uninstall() {
@ -71,4 +74,4 @@ object DotListener: TestExecutionListener {
System.setErr(it)
}
}
}
}

View File

@ -15,26 +15,27 @@
*/
package okhttp3
import java.io.File
import org.junit.jupiter.engine.descriptor.ClassBasedTestDescriptor
import org.junit.platform.engine.discovery.DiscoverySelectors
import java.io.File
// TODO move to junit5 tags
val avoidedTests = setOf(
"okhttp3.BouncyCastleTest",
"okhttp3.ConscryptTest",
"okhttp3.CorrettoTest",
"okhttp3.OpenJSSETest",
"okhttp3.internal.platform.Jdk8WithJettyBootPlatformTest",
"okhttp3.internal.platform.Jdk9PlatformTest",
"okhttp3.internal.platform.PlatformTest",
"okhttp3.internal.platform.android.AndroidSocketAdapterTest",
"okhttp3.osgi.OsgiTest",
// Hanging.
"okhttp3.CookiesTest",
// Hanging.
"okhttp3.WholeOperationTimeoutTest",
)
val avoidedTests =
setOf(
"okhttp3.BouncyCastleTest",
"okhttp3.ConscryptTest",
"okhttp3.CorrettoTest",
"okhttp3.OpenJSSETest",
"okhttp3.internal.platform.Jdk8WithJettyBootPlatformTest",
"okhttp3.internal.platform.Jdk9PlatformTest",
"okhttp3.internal.platform.PlatformTest",
"okhttp3.internal.platform.android.AndroidSocketAdapterTest",
"okhttp3.osgi.OsgiTest",
// Hanging.
"okhttp3.CookiesTest",
// Hanging.
"okhttp3.WholeOperationTimeoutTest",
)
/**
* Run periodically to refresh the known set of working tests.
@ -44,11 +45,12 @@ val avoidedTests = setOf(
fun main() {
val knownTestFile = File("native-image-tests/src/main/resources/testlist.txt")
val testSelector = DiscoverySelectors.selectPackage("okhttp3")
val testClasses = findTests(listOf(testSelector))
.filter { it.isContainer }
.mapNotNull { (it as? ClassBasedTestDescriptor)?.testClass?.name }
.filterNot { it in avoidedTests }
.sorted()
.distinct()
val testClasses =
findTests(listOf(testSelector))
.filter { it.isContainer }
.mapNotNull { (it as? ClassBasedTestDescriptor)?.testClass?.name }
.filterNot { it in avoidedTests }
.sorted()
.distinct()
knownTestFile.writeText(testClasses.joinToString("\n"))
}

View File

@ -15,6 +15,9 @@
*/
package okhttp3
import java.io.File
import java.io.PrintWriter
import kotlin.system.exitProcess
import org.junit.jupiter.engine.JupiterTestEngine
import org.junit.platform.console.options.Theme
import org.junit.platform.engine.DiscoverySelector
@ -30,9 +33,6 @@ import org.junit.platform.launcher.core.LauncherConfig
import org.junit.platform.launcher.core.LauncherDiscoveryRequestBuilder
import org.junit.platform.launcher.core.LauncherFactory
import org.junit.platform.launcher.listeners.SummaryGeneratingListener
import java.io.File
import java.io.PrintWriter
import kotlin.system.exitProcess
/**
* Graal main method to run tests with minimal reflection and automatic settings.
@ -49,13 +49,14 @@ fun main(vararg args: String) {
val jupiterTestEngine = buildTestEngine()
val config = LauncherConfig.builder()
.enableTestExecutionListenerAutoRegistration(false)
.enableTestEngineAutoRegistration(false)
.enablePostDiscoveryFilterAutoRegistration(false)
.addTestEngines(jupiterTestEngine)
.addTestExecutionListeners(DotListener, summaryListener, treeListener)
.build()
val config =
LauncherConfig.builder()
.enableTestExecutionListenerAutoRegistration(false)
.enableTestEngineAutoRegistration(false)
.enablePostDiscoveryFilterAutoRegistration(false)
.addTestEngines(jupiterTestEngine)
.addTestExecutionListeners(DotListener, summaryListener, treeListener)
.build()
val launcher: Launcher = LauncherFactory.create(config)
val request: LauncherDiscoveryRequest = buildRequest(selectors)
@ -89,8 +90,9 @@ fun testSelectors(inputFile: File? = null): List<DiscoverySelector> {
val lines =
inputFile?.readLines() ?: sampleTestClass.getResource("/testlist.txt").readText().lines()
val flatClassnameList = lines
.filter { it.isNotBlank() }
val flatClassnameList =
lines
.filter { it.isNotBlank() }
return flatClassnameList
.mapNotNull {
@ -107,11 +109,12 @@ fun testSelectors(inputFile: File? = null): List<DiscoverySelector> {
* Builds a Junit Test Plan request for a fixed set of classes, or potentially a recursive package.
*/
fun buildRequest(selectors: List<DiscoverySelector>): LauncherDiscoveryRequest {
val request: LauncherDiscoveryRequest = LauncherDiscoveryRequestBuilder.request()
// TODO replace junit.jupiter.extensions.autodetection.enabled with API approach.
val request: LauncherDiscoveryRequest =
LauncherDiscoveryRequestBuilder.request()
// TODO replace junit.jupiter.extensions.autodetection.enabled with API approach.
// .enableImplicitConfigurationParameters(false)
.selectors(selectors)
.build()
.selectors(selectors)
.build()
return request
}
@ -136,11 +139,13 @@ fun findTests(selectors: List<DiscoverySelector>): List<TestDescriptor> {
* https://github.com/junit-team/junit5/issues/2469
*/
fun treeListener(): TestExecutionListener {
val colorPalette = Class.forName("org.junit.platform.console.tasks.ColorPalette").getField("DEFAULT").apply {
isAccessible = true
}.get(null)
val colorPalette =
Class.forName("org.junit.platform.console.tasks.ColorPalette").getField("DEFAULT").apply {
isAccessible = true
}.get(null)
return Class.forName(
"org.junit.platform.console.tasks.TreePrintingListener").declaredConstructors.first()
"org.junit.platform.console.tasks.TreePrintingListener",
).declaredConstructors.first()
.apply {
isAccessible = true
}

View File

@ -26,7 +26,8 @@ import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.ArgumentsSource
class SampleTest {
@JvmField @RegisterExtension val clientRule = OkHttpClientTestRule()
@JvmField @RegisterExtension
val clientRule = OkHttpClientTestRule()
@Test
fun passingTest() {
@ -59,6 +60,6 @@ class SampleTest {
}
}
class SampleTestProvider: SimpleProvider() {
class SampleTestProvider : SimpleProvider() {
override fun arguments() = listOf("A", "B")
}

View File

@ -16,11 +16,11 @@
package okhttp3
import com.oracle.svm.core.annotate.AutomaticFeature
import java.io.File
import java.lang.IllegalStateException
import org.graalvm.nativeimage.hosted.Feature
import org.graalvm.nativeimage.hosted.RuntimeClassInitialization
import org.graalvm.nativeimage.hosted.RuntimeReflection
import java.io.File
import java.lang.IllegalStateException
@AutomaticFeature
class TestRegistration : Feature {
@ -40,7 +40,10 @@ class TestRegistration : Feature {
registerParamProvider(access, "okhttp3.WebPlatformUrlTest\$TestDataParamProvider")
}
private fun registerParamProvider(access: Feature.BeforeAnalysisAccess, provider: String) {
private fun registerParamProvider(
access: Feature.BeforeAnalysisAccess,
provider: String,
) {
val providerClass = access.findClassByName(provider)
if (providerClass != null) {
registerTest(access, providerClass)
@ -54,7 +57,10 @@ class TestRegistration : Feature {
registerStandardClass(access, "org.junit.platform.console.tasks.TreePrintingListener")
}
private fun registerStandardClass(access: Feature.BeforeAnalysisAccess, name: String) {
private fun registerStandardClass(
access: Feature.BeforeAnalysisAccess,
name: String,
) {
val clazz: Class<*> = access.findClassByName(name) ?: throw IllegalStateException("Missing class $name")
RuntimeReflection.register(clazz)
clazz.declaredConstructors.forEach {
@ -81,7 +87,10 @@ class TestRegistration : Feature {
}
}
private fun registerTest(access: Feature.BeforeAnalysisAccess, java: Class<*>) {
private fun registerTest(
access: Feature.BeforeAnalysisAccess,
java: Class<*>,
) {
access.registerAsUsed(java)
RuntimeReflection.register(java)
java.constructors.forEach {

View File

@ -55,4 +55,4 @@ class NativeImageTestsTest {
assertNotNull(listener)
}
}
}

View File

@ -40,35 +40,41 @@ import okhttp3.logging.HttpLoggingInterceptor
import okhttp3.logging.LoggingEventListener
class Main : CliktCommand(name = NAME, help = "A curl for the next-generation web.") {
val method: String? by option("-X", "--request", help="Specify request command to use")
val method: String? by option("-X", "--request", help = "Specify request command to use")
val data: String? by option("-d", "--data", help="HTTP POST data")
val data: String? by option("-d", "--data", help = "HTTP POST data")
val headers: List<String>? by option("-H", "--header", help="Custom header to pass to server").multiple()
val headers: List<String>? by option("-H", "--header", help = "Custom header to pass to server").multiple()
val userAgent: String by option("-A", "--user-agent", help="User-Agent to send to server").default(NAME + "/" + versionString())
val userAgent: String by option("-A", "--user-agent", help = "User-Agent to send to server").default(NAME + "/" + versionString())
val connectTimeout: Int by option("--connect-timeout", help="Maximum time allowed for connection (seconds)").int().default(DEFAULT_TIMEOUT)
val connectTimeout: Int by option(
"--connect-timeout",
help = "Maximum time allowed for connection (seconds)",
).int().default(DEFAULT_TIMEOUT)
val readTimeout: Int by option("--read-timeout", help="Maximum time allowed for reading data (seconds)").int().default(DEFAULT_TIMEOUT)
val readTimeout: Int by option("--read-timeout", help = "Maximum time allowed for reading data (seconds)").int().default(DEFAULT_TIMEOUT)
val callTimeout: Int by option("--call-timeout", help="Maximum time allowed for the entire call (seconds)").int().default(DEFAULT_TIMEOUT)
val callTimeout: Int by option(
"--call-timeout",
help = "Maximum time allowed for the entire call (seconds)",
).int().default(DEFAULT_TIMEOUT)
val followRedirects: Boolean by option("-L", "--location", help="Follow redirects").flag()
val followRedirects: Boolean by option("-L", "--location", help = "Follow redirects").flag()
val allowInsecure: Boolean by option("-k", "--insecure", help="Allow connections to SSL sites without certs").flag()
val allowInsecure: Boolean by option("-k", "--insecure", help = "Allow connections to SSL sites without certs").flag()
val showHeaders: Boolean by option("-i", "--include", help="Include protocol headers in the output").flag()
val showHeaders: Boolean by option("-i", "--include", help = "Include protocol headers in the output").flag()
val showHttp2Frames: Boolean by option("--frames", help="Log HTTP/2 frames to STDERR").flag()
val showHttp2Frames: Boolean by option("--frames", help = "Log HTTP/2 frames to STDERR").flag()
val referer: String? by option("-e", "--referer", help="Referer URL")
val referer: String? by option("-e", "--referer", help = "Referer URL")
val verbose: Boolean by option("-v", "--verbose", help="Makes $NAME verbose during the operation").flag()
val verbose: Boolean by option("-v", "--verbose", help = "Makes $NAME verbose during the operation").flag()
val sslDebug: Boolean by option(help="Output SSL Debug").flag()
val sslDebug: Boolean by option(help = "Output SSL Debug").flag()
val url: String? by argument(name = "url", help="Remote resource URL")
val url: String? by argument(name = "url", help = "Remote resource URL")
var client: Call.Factory? = null
@ -123,20 +129,26 @@ class Main : CliktCommand(name = NAME, help = "A curl for the next-generation we
return prop.getProperty("version", "dev")
}
private fun createInsecureTrustManager(): X509TrustManager = object : X509TrustManager {
override fun checkClientTrusted(chain: Array<X509Certificate>, authType: String) {}
private fun createInsecureTrustManager(): X509TrustManager =
object : X509TrustManager {
override fun checkClientTrusted(
chain: Array<X509Certificate>,
authType: String,
) {}
override fun checkServerTrusted(chain: Array<X509Certificate>, authType: String) {}
override fun checkServerTrusted(
chain: Array<X509Certificate>,
authType: String,
) {}
override fun getAcceptedIssuers(): Array<X509Certificate> = arrayOf()
}
override fun getAcceptedIssuers(): Array<X509Certificate> = arrayOf()
}
private fun createInsecureSslSocketFactory(trustManager: TrustManager): SSLSocketFactory =
Platform.get().newSSLContext().apply {
init(null, arrayOf(trustManager), null)
}.socketFactory
Platform.get().newSSLContext().apply {
init(null, arrayOf(trustManager), null)
}.socketFactory
private fun createInsecureHostnameVerifier(): HostnameVerifier =
HostnameVerifier { _, _ -> true }
private fun createInsecureHostnameVerifier(): HostnameVerifier = HostnameVerifier { _, _ -> true }
}
}

View File

@ -14,6 +14,7 @@
* limitations under the License.
*/
@file:Suppress("ktlint:standard:filename")
package okhttp3.curl.internal
import java.io.IOException
@ -53,15 +54,16 @@ internal fun Main.commonCreateRequest(): Request {
}
private fun Main.mediaType(): MediaType? {
val mimeType = headers?.let {
for (header in it) {
val parts = header.split(':', limit = 2)
if ("Content-Type".equals(parts[0], ignoreCase = true)) {
return@let parts[1].trim()
val mimeType =
headers?.let {
for (header in it) {
val parts = header.split(':', limit = 2)
if ("Content-Type".equals(parts[0], ignoreCase = true)) {
return@let parts[1].trim()
}
}
}
return@let null
} ?: "application/x-www-form-urlencoded"
return@let null
} ?: "application/x-www-form-urlencoded"
return mimeType.toMediaTypeOrNull()
}

View File

@ -1,32 +1,37 @@
package okhttp3.curl.logging
import okhttp3.internal.http2.Http2
import java.util.logging.ConsoleHandler
import java.util.logging.Level
import java.util.logging.LogManager
import java.util.logging.LogRecord
import java.util.logging.Logger
import okhttp3.internal.http2.Http2
class LoggingUtil {
companion object {
private val activeLoggers = mutableListOf<Logger>()
fun configureLogging(debug: Boolean, showHttp2Frames: Boolean, sslDebug: Boolean) {
fun configureLogging(
debug: Boolean,
showHttp2Frames: Boolean,
sslDebug: Boolean,
) {
if (debug || showHttp2Frames || sslDebug) {
if (sslDebug) {
System.setProperty("javax.net.debug", "")
}
LogManager.getLogManager().reset()
val handler = object : ConsoleHandler() {
override fun publish(record: LogRecord) {
super.publish(record)
val handler =
object : ConsoleHandler() {
override fun publish(record: LogRecord) {
super.publish(record)
val parameters = record.parameters
if (sslDebug && record.loggerName == "javax.net.ssl" && parameters != null) {
System.err.println(parameters[0])
val parameters = record.parameters
if (sslDebug && record.loggerName == "javax.net.ssl" && parameters != null) {
System.err.println(parameters[0])
}
}
}
}
if (debug) {
handler.level = Level.ALL

View File

@ -19,16 +19,17 @@ import java.util.logging.LogRecord
* Why so much construction?
*/
class OneLineLogFormat : Formatter() {
private val d = DateTimeFormatterBuilder()
.appendValue(HOUR_OF_DAY, 2)
.appendLiteral(':')
.appendValue(MINUTE_OF_HOUR, 2)
.optionalStart()
.appendLiteral(':')
.appendValue(SECOND_OF_MINUTE, 2)
.optionalStart()
.appendFraction(NANO_OF_SECOND, 3, 3, true)
.toFormatter()
private val d =
DateTimeFormatterBuilder()
.appendValue(HOUR_OF_DAY, 2)
.appendLiteral(':')
.appendValue(MINUTE_OF_HOUR, 2)
.optionalStart()
.appendLiteral(':')
.appendValue(SECOND_OF_MINUTE, 2)
.optionalStart()
.appendFraction(NANO_OF_SECOND, 3, 3, true)
.toFormatter()
private val offset = ZoneOffset.systemDefault()

View File

@ -15,14 +15,14 @@
*/
package okhttp3.curl
import java.io.IOException
import okhttp3.RequestBody
import okio.Buffer
import assertk.assertThat
import assertk.assertions.isEqualTo
import assertk.assertions.isNull
import assertk.assertions.startsWith
import java.io.IOException
import kotlin.test.Test
import okhttp3.RequestBody
import okio.Buffer
class MainTest {
@Test
@ -34,7 +34,8 @@ class MainTest {
}
@Test
@Throws(IOException::class) fun put() {
@Throws(IOException::class)
fun put() {
val request = fromArgs("-X", "PUT", "-d", "foo", "http://example.com").createRequest()
assertThat(request.method).isEqualTo("PUT")
assertThat(request.url.toString()).isEqualTo("http://example.com/")
@ -48,7 +49,7 @@ class MainTest {
assertThat(request.method).isEqualTo("POST")
assertThat(request.url.toString()).isEqualTo("http://example.com/")
assertThat(body!!.contentType().toString()).isEqualTo(
"application/x-www-form-urlencoded; charset=utf-8"
"application/x-www-form-urlencoded; charset=utf-8",
)
assertThat(bodyAsString(body)).isEqualTo("foo")
}
@ -60,17 +61,21 @@ class MainTest {
assertThat(request.method).isEqualTo("PUT")
assertThat(request.url.toString()).isEqualTo("http://example.com/")
assertThat(body!!.contentType().toString()).isEqualTo(
"application/x-www-form-urlencoded; charset=utf-8"
"application/x-www-form-urlencoded; charset=utf-8",
)
assertThat(bodyAsString(body)).isEqualTo("foo")
}
@Test
fun contentTypeHeader() {
val request = fromArgs(
"-d", "foo", "-H", "Content-Type: application/json",
"http://example.com"
).createRequest()
val request =
fromArgs(
"-d",
"foo",
"-H",
"Content-Type: application/json",
"http://example.com",
).createRequest()
val body = request.body
assertThat(request.method).isEqualTo("POST")
assertThat(request.url.toString()).isEqualTo("http://example.com/")
@ -105,12 +110,14 @@ class MainTest {
@Test
fun headerSplitWithDate() {
val request = fromArgs(
"-H", "If-Modified-Since: Mon, 18 Aug 2014 15:16:06 GMT",
"http://example.com"
).createRequest()
val request =
fromArgs(
"-H",
"If-Modified-Since: Mon, 18 Aug 2014 15:16:06 GMT",
"http://example.com",
).createRequest()
assertThat(request.header("If-Modified-Since")).isEqualTo(
"Mon, 18 Aug 2014 15:16:06 GMT"
"Mon, 18 Aug 2014 15:16:06 GMT",
)
}

View File

@ -50,14 +50,16 @@ import org.junit.Test
* Run with "./gradlew :android-test:connectedCheck -PandroidBuild=true" and make sure ANDROID_SDK_ROOT is set.
*/
class AndroidAsyncDnsTest {
@JvmField @Rule val serverRule = MockWebServerRule()
@JvmField @Rule
val serverRule = MockWebServerRule()
private lateinit var client: OkHttpClient
private val localhost: HandshakeCertificates by lazy {
// Generate a self-signed cert for the server to serve and the client to trust.
val heldCertificate = HeldCertificate.Builder()
.addSubjectAlternativeName("localhost")
.build()
val heldCertificate =
HeldCertificate.Builder()
.addSubjectAlternativeName("localhost")
.build()
return@lazy HandshakeCertificates.Builder()
.addPlatformTrustedCertificates()
.heldCertificate(heldCertificate)
@ -69,10 +71,11 @@ class AndroidAsyncDnsTest {
fun init() {
assumeTrue("Supported on API 29+", Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q)
client = OkHttpClient.Builder()
.dns(AsyncDns.toDns(AndroidAsyncDns.IPv4, AndroidAsyncDns.IPv6))
.sslSocketFactory(localhost.sslSocketFactory(), localhost.trustManager)
.build()
client =
OkHttpClient.Builder()
.dns(AsyncDns.toDns(AndroidAsyncDns.IPv4, AndroidAsyncDns.IPv6))
.sslSocketFactory(localhost.sslSocketFactory(), localhost.trustManager)
.build()
serverRule.server.useHttps(localhost.sslSocketFactory())
}
@ -127,17 +130,26 @@ class AndroidAsyncDnsTest {
val latch = CountDownLatch(1)
// assumes an IPv4 address
AndroidAsyncDns.IPv4.query(hostname, object : AsyncDns.Callback {
override fun onResponse(hostname: String, addresses: List<InetAddress>) {
allAddresses.addAll(addresses)
latch.countDown()
}
AndroidAsyncDns.IPv4.query(
hostname,
object : AsyncDns.Callback {
override fun onResponse(
hostname: String,
addresses: List<InetAddress>,
) {
allAddresses.addAll(addresses)
latch.countDown()
}
override fun onFailure(hostname: String, e: IOException) {
exception = e
latch.countDown()
}
})
override fun onFailure(
hostname: String,
e: IOException,
) {
exception = e
latch.countDown()
}
},
)
latch.await()
@ -173,10 +185,11 @@ class AndroidAsyncDnsTest {
val network =
connectivityManager.activeNetwork ?: throw AssumptionViolatedException("No active network")
val client = OkHttpClient.Builder()
.dns(AsyncDns.toDns(AndroidAsyncDns.IPv4, AndroidAsyncDns.IPv6))
.socketFactory(network.socketFactory)
.build()
val client =
OkHttpClient.Builder()
.dns(AsyncDns.toDns(AndroidAsyncDns.IPv4, AndroidAsyncDns.IPv6))
.socketFactory(network.socketFactory)
.build()
val call =
client.newCall(Request("https://google.com/robots.txt".toHttpUrl()))

View File

@ -43,20 +43,34 @@ class AndroidAsyncDns(
private val resolver = DnsResolver.getInstance()
private val executor = Executors.newSingleThreadExecutor()
override fun query(hostname: String, callback: AsyncDns.Callback) {
override fun query(
hostname: String,
callback: AsyncDns.Callback,
) {
resolver.query(
network, hostname, dnsClass.type, DnsResolver.FLAG_EMPTY, executor, null,
network,
hostname,
dnsClass.type,
DnsResolver.FLAG_EMPTY,
executor,
null,
object : DnsResolver.Callback<List<InetAddress>> {
override fun onAnswer(addresses: List<InetAddress>, rCode: Int) {
override fun onAnswer(
addresses: List<InetAddress>,
rCode: Int,
) {
callback.onResponse(hostname, addresses)
}
override fun onError(e: DnsResolver.DnsException) {
callback.onFailure(hostname, UnknownHostException(e.message).apply {
initCause(e)
})
callback.onFailure(
hostname,
UnknownHostException(e.message).apply {
initCause(e)
},
)
}
}
},
)
}

View File

@ -43,16 +43,16 @@ import org.robolectric.annotation.Config
sdk = [30],
)
class RobolectricOkHttpClientTest {
private lateinit var context: Context
private lateinit var client: OkHttpClient
@Before
fun setUp() {
context = ApplicationProvider.getApplicationContext<Application>()
client = OkHttpClient.Builder()
.cache(Cache("/cache".toPath(), 10_000_000, FakeFileSystem()))
.build()
client =
OkHttpClient.Builder()
.cache(Cache("/cache".toPath(), 10_000_000, FakeFileSystem()))
.build()
}
@Test
@ -61,8 +61,9 @@ class RobolectricOkHttpClientTest {
val request = Request("https://www.google.com/robots.txt".toHttpUrl())
val networkRequest = request.newBuilder()
.build()
val networkRequest =
request.newBuilder()
.build()
val call = client.newCall(networkRequest)

View File

@ -28,9 +28,10 @@ import okhttp3.brotli.internal.uncompress
object BrotliInterceptor : Interceptor {
override fun intercept(chain: Interceptor.Chain): Response {
return if (chain.request().header("Accept-Encoding") == null) {
val request = chain.request().newBuilder()
.header("Accept-Encoding", "br,gzip")
.build()
val request =
chain.request().newBuilder()
.header("Accept-Encoding", "br,gzip")
.build()
val response = chain.proceed(request)

View File

@ -30,13 +30,14 @@ fun uncompress(response: Response): Response {
val body = response.body
val encoding = response.header("Content-Encoding") ?: return response
val decompressedSource = when {
encoding.equals("br", ignoreCase = true) ->
BrotliInputStream(body.source().inputStream()).source().buffer()
encoding.equals("gzip", ignoreCase = true) ->
GzipSource(body.source()).buffer()
else -> return response
}
val decompressedSource =
when {
encoding.equals("br", ignoreCase = true) ->
BrotliInputStream(body.source().inputStream()).source().buffer()
encoding.equals("gzip", ignoreCase = true) ->
GzipSource(body.source()).buffer()
else -> return response
}
return response.newBuilder()
.removeHeader("Content-Encoding")

View File

@ -38,14 +38,15 @@ class BrotliInterceptorTest {
@Test
fun testUncompressBrotli() {
val s =
"1bce00009c05ceb9f028d14e416230f718960a537b0922d2f7b6adef56532c08dff44551516690131494db" +
"6021c7e3616c82c1bc2416abb919aaa06e8d30d82cc2981c2f5c900bfb8ee29d5c03deb1c0dacff80e" +
"abe82ba64ed250a497162006824684db917963ecebe041b352a3e62d629cc97b95cac24265b175171e" +
"5cb384cd0912aeb5b5dd9555f2dd1a9b20688201"
"1bce00009c05ceb9f028d14e416230f718960a537b0922d2f7b6adef56532c08dff44551516690131494db" +
"6021c7e3616c82c1bc2416abb919aaa06e8d30d82cc2981c2f5c900bfb8ee29d5c03deb1c0dacff80e" +
"abe82ba64ed250a497162006824684db917963ecebe041b352a3e62d629cc97b95cac24265b175171e" +
"5cb384cd0912aeb5b5dd9555f2dd1a9b20688201"
val response = response("https://httpbin.org/brotli", s.decodeHex()) {
header("Content-Encoding", "br")
}
val response =
response("https://httpbin.org/brotli", s.decodeHex()) {
header("Content-Encoding", "br")
}
val uncompressed = uncompress(response)
@ -57,15 +58,16 @@ class BrotliInterceptorTest {
@Test
fun testUncompressGzip() {
val s =
"1f8b0800968f215d02ff558ec10e82301044ef7c45b3e75269d0c478e340e4a426e007086c4a636c9bb65e" +
"24fcbb5b484c3cec61deccecee9c3106eaa39dc3114e2cfa377296d8848f117d20369324500d03ba98" +
"d766b0a3368a0ce83d4f55581b14696c88894f31ba5e1b61bdfa79f7803eaf149a35619f29b3db0b29" +
"8abcbd54b7b6b97640c965bbfec238d9f4109ceb6edb01d66ba54d6247296441531e445970f627215b" +
"b22f1017320dd5000000"
"1f8b0800968f215d02ff558ec10e82301044ef7c45b3e75269d0c478e340e4a426e007086c4a636c9bb65e" +
"24fcbb5b484c3cec61deccecee9c3106eaa39dc3114e2cfa377296d8848f117d20369324500d03ba98" +
"d766b0a3368a0ce83d4f55581b14696c88894f31ba5e1b61bdfa79f7803eaf149a35619f29b3db0b29" +
"8abcbd54b7b6b97640c965bbfec238d9f4109ceb6edb01d66ba54d6247296441531e445970f627215b" +
"b22f1017320dd5000000"
val response = response("https://httpbin.org/gzip", s.decodeHex()) {
header("Content-Encoding", "gzip")
}
val response =
response("https://httpbin.org/gzip", s.decodeHex()) {
header("Content-Encoding", "gzip")
}
val uncompressed = uncompress(response)
@ -86,9 +88,10 @@ class BrotliInterceptorTest {
@Test
fun testFailsUncompress() {
val response = response("https://httpbin.org/brotli", "bb919aaa06e8".decodeHex()) {
header("Content-Encoding", "br")
}
val response =
response("https://httpbin.org/brotli", "bb919aaa06e8".decodeHex()) {
header("Content-Encoding", "br")
}
assertFailsWith<IOException> {
val failingResponse = uncompress(response)
@ -101,11 +104,12 @@ class BrotliInterceptorTest {
@Test
fun testSkipUncompressNoContentResponse() {
val response = response("https://httpbin.org/brotli", EMPTY) {
header("Content-Encoding", "br")
code(204)
message("NO CONTENT")
}
val response =
response("https://httpbin.org/brotli", EMPTY) {
header("Content-Encoding", "br")
code(204)
message("NO CONTENT")
}
val same = uncompress(response)
@ -116,15 +120,15 @@ class BrotliInterceptorTest {
private fun response(
url: String,
bodyHex: ByteString,
fn: Response.Builder.() -> Unit = {}
fn: Response.Builder.() -> Unit = {},
): Response {
return Response.Builder()
.body(bodyHex.toResponseBody("text/plain".toMediaType()))
.code(200)
.message("OK")
.request(Request.Builder().url(url).build())
.protocol(Protocol.HTTP_2)
.apply(fn)
.build()
.body(bodyHex.toResponseBody("text/plain".toMediaType()))
.code(200)
.message("OK")
.request(Request.Builder().url(url).build())
.protocol(Protocol.HTTP_2)
.apply(fn)
.build()
}
}

View File

@ -19,7 +19,8 @@ import okhttp3.OkHttpClient
import okhttp3.Request
fun main() {
val client = OkHttpClient.Builder()
val client =
OkHttpClient.Builder()
.addInterceptor(BrotliInterceptor)
.build()
@ -27,7 +28,10 @@ fun main() {
sendRequest("https://httpbin.org/gzip", client)
}
private fun sendRequest(url: String, client: OkHttpClient) {
private fun sendRequest(
url: String,
client: OkHttpClient,
) {
val req = Request.Builder().url(url).build()
client.newCall(req).execute().use {

View File

@ -17,23 +17,32 @@
package okhttp3
import kotlin.coroutines.resumeWithException
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.suspendCancellableCoroutine
import okio.IOException
import kotlin.coroutines.resumeWithException
@OptIn(ExperimentalCoroutinesApi::class)
suspend fun Call.executeAsync(): Response = suspendCancellableCoroutine { continuation ->
continuation.invokeOnCancellation {
this.cancel()
}
this.enqueue(object : Callback {
override fun onFailure(call: Call, e: IOException) {
continuation.resumeWithException(e)
suspend fun Call.executeAsync(): Response =
suspendCancellableCoroutine { continuation ->
continuation.invokeOnCancellation {
this.cancel()
}
this.enqueue(
object : Callback {
override fun onFailure(
call: Call,
e: IOException,
) {
continuation.resumeWithException(e)
}
override fun onResponse(call: Call, response: Response) {
continuation.resume(value = response, onCancellation = { call.cancel() })
}
})
}
override fun onResponse(
call: Call,
response: Response,
) {
continuation.resume(value = response, onCancellation = { call.cancel() })
}
},
)
}

View File

@ -80,7 +80,7 @@ class SuspendCallTest {
MockResponse.Builder()
.bodyDelay(5, TimeUnit.SECONDS)
.body("abc")
.build()
.build(),
)
val call = client.newCall(request)
@ -109,7 +109,7 @@ class SuspendCallTest {
MockResponse.Builder()
.bodyDelay(5, TimeUnit.SECONDS)
.body("abc")
.build()
.build(),
)
val call = client.newCall(request)
@ -137,7 +137,7 @@ class SuspendCallTest {
MockResponse(
body = "abc",
socketPolicy = DisconnectAfterRequest,
)
),
)
val call = client.newCall(request)

View File

@ -26,13 +26,13 @@ import okhttp3.Dns
*/
internal class BootstrapDns(
private val dnsHostname: String,
private val dnsServers: List<InetAddress>
private val dnsServers: List<InetAddress>,
) : Dns {
@Throws(UnknownHostException::class)
override fun lookup(hostname: String): List<InetAddress> {
if (this.dnsHostname != hostname) {
throw UnknownHostException(
"BootstrapDns called for $hostname instead of $dnsHostname"
"BootstrapDns called for $hostname instead of $dnsHostname",
)
}

View File

@ -51,7 +51,7 @@ class DnsOverHttps internal constructor(
@get:JvmName("includeIPv6") val includeIPv6: Boolean,
@get:JvmName("post") val post: Boolean,
@get:JvmName("resolvePrivateAddresses") val resolvePrivateAddresses: Boolean,
@get:JvmName("resolvePublicAddresses") val resolvePublicAddresses: Boolean
@get:JvmName("resolvePublicAddresses") val resolvePublicAddresses: Boolean,
) : Dns {
@Throws(UnknownHostException::class)
override fun lookup(hostname: String): List<InetAddress> {
@ -94,37 +94,46 @@ class DnsOverHttps internal constructor(
networkRequests: MutableList<Call>,
results: MutableList<InetAddress>,
failures: MutableList<Exception>,
type: Int
type: Int,
) {
val request = buildRequest(hostname, type)
val response = getCacheOnlyResponse(request)
response?.let { processResponse(it, hostname, results, failures) } ?: networkRequests.add(
client.newCall(request))
client.newCall(request),
)
}
private fun executeRequests(
hostname: String,
networkRequests: List<Call>,
responses: MutableList<InetAddress>,
failures: MutableList<Exception>
failures: MutableList<Exception>,
) {
val latch = CountDownLatch(networkRequests.size)
for (call in networkRequests) {
call.enqueue(object : Callback {
override fun onFailure(call: Call, e: IOException) {
synchronized(failures) {
failures.add(e)
call.enqueue(
object : Callback {
override fun onFailure(
call: Call,
e: IOException,
) {
synchronized(failures) {
failures.add(e)
}
latch.countDown()
}
latch.countDown()
}
override fun onResponse(call: Call, response: Response) {
processResponse(response, hostname, responses, failures)
latch.countDown()
}
})
override fun onResponse(
call: Call,
response: Response,
) {
processResponse(response, hostname, responses, failures)
latch.countDown()
}
},
)
}
try {
@ -138,7 +147,7 @@ class DnsOverHttps internal constructor(
response: Response,
hostname: String,
results: MutableList<InetAddress>,
failures: MutableList<Exception>
failures: MutableList<Exception>,
) {
try {
val addresses = readResponse(hostname, response)
@ -153,7 +162,10 @@ class DnsOverHttps internal constructor(
}
@Throws(UnknownHostException::class)
private fun throwBestFailure(hostname: String, failures: List<Exception>): List<InetAddress> {
private fun throwBestFailure(
hostname: String,
failures: List<Exception>,
): List<InetAddress> {
if (failures.isEmpty()) {
throw UnknownHostException(hostname)
}
@ -179,7 +191,8 @@ class DnsOverHttps internal constructor(
try {
// Use the cache without hitting the network first
// 504 code indicates that the Cache is stale
val preferCache = CacheControl.Builder()
val preferCache =
CacheControl.Builder()
.onlyIfCached()
.build()
val cacheRequest = request.newBuilder().cacheControl(preferCache).build()
@ -199,7 +212,10 @@ class DnsOverHttps internal constructor(
}
@Throws(Exception::class)
private fun readResponse(hostname: String, response: Response): List<InetAddress> {
private fun readResponse(
hostname: String,
response: Response,
): List<InetAddress> {
if (response.cacheResponse == null && response.protocol !== Protocol.HTTP_2) {
Platform.get().log("Incorrect protocol: ${response.protocol}", Platform.WARN)
}
@ -213,7 +229,7 @@ class DnsOverHttps internal constructor(
if (body.contentLength() > MAX_RESPONSE_SIZE) {
throw IOException(
"response size exceeds limit ($MAX_RESPONSE_SIZE bytes): ${body.contentLength()} bytes"
"response size exceeds limit ($MAX_RESPONSE_SIZE bytes): ${body.contentLength()} bytes",
)
}
@ -223,19 +239,22 @@ class DnsOverHttps internal constructor(
}
}
private fun buildRequest(hostname: String, type: Int): Request =
Request.Builder().header("Accept", DNS_MESSAGE.toString()).apply {
val query = DnsRecordCodec.encodeQuery(hostname, type)
private fun buildRequest(
hostname: String,
type: Int,
): Request =
Request.Builder().header("Accept", DNS_MESSAGE.toString()).apply {
val query = DnsRecordCodec.encodeQuery(hostname, type)
if (post) {
url(url).post(query.toRequestBody(DNS_MESSAGE))
} else {
val encoded = query.base64Url().replace("=", "")
val requestUrl = url.newBuilder().addQueryParameter("dns", encoded).build()
if (post) {
url(url).post(query.toRequestBody(DNS_MESSAGE))
} else {
val encoded = query.base64Url().replace("=", "")
val requestUrl = url.newBuilder().addQueryParameter("dns", encoded).build()
url(requestUrl)
}
}.build()
url(requestUrl)
}
}.build()
class Builder {
internal var client: OkHttpClient? = null
@ -250,49 +269,56 @@ class DnsOverHttps internal constructor(
fun build(): DnsOverHttps {
val client = this.client ?: throw NullPointerException("client not set")
return DnsOverHttps(
client.newBuilder().dns(buildBootstrapClient(this)).build(),
checkNotNull(url) { "url not set" },
includeIPv6,
post,
resolvePrivateAddresses,
resolvePublicAddresses
client.newBuilder().dns(buildBootstrapClient(this)).build(),
checkNotNull(url) { "url not set" },
includeIPv6,
post,
resolvePrivateAddresses,
resolvePublicAddresses,
)
}
fun client(client: OkHttpClient) = apply {
this.client = client
}
fun client(client: OkHttpClient) =
apply {
this.client = client
}
fun url(url: HttpUrl) = apply {
this.url = url
}
fun url(url: HttpUrl) =
apply {
this.url = url
}
fun includeIPv6(includeIPv6: Boolean) = apply {
this.includeIPv6 = includeIPv6
}
fun includeIPv6(includeIPv6: Boolean) =
apply {
this.includeIPv6 = includeIPv6
}
fun post(post: Boolean) = apply {
this.post = post
}
fun post(post: Boolean) =
apply {
this.post = post
}
fun resolvePrivateAddresses(resolvePrivateAddresses: Boolean) = apply {
this.resolvePrivateAddresses = resolvePrivateAddresses
}
fun resolvePrivateAddresses(resolvePrivateAddresses: Boolean) =
apply {
this.resolvePrivateAddresses = resolvePrivateAddresses
}
fun resolvePublicAddresses(resolvePublicAddresses: Boolean) = apply {
this.resolvePublicAddresses = resolvePublicAddresses
}
fun resolvePublicAddresses(resolvePublicAddresses: Boolean) =
apply {
this.resolvePublicAddresses = resolvePublicAddresses
}
fun bootstrapDnsHosts(bootstrapDnsHosts: List<InetAddress>?) = apply {
this.bootstrapDnsHosts = bootstrapDnsHosts
}
fun bootstrapDnsHosts(bootstrapDnsHosts: List<InetAddress>?) =
apply {
this.bootstrapDnsHosts = bootstrapDnsHosts
}
fun bootstrapDnsHosts(vararg bootstrapDnsHosts: InetAddress): Builder =
bootstrapDnsHosts(bootstrapDnsHosts.toList())
fun bootstrapDnsHosts(vararg bootstrapDnsHosts: InetAddress): Builder = bootstrapDnsHosts(bootstrapDnsHosts.toList())
fun systemDns(systemDns: Dns) = apply {
this.systemDns = systemDns
}
fun systemDns(systemDns: Dns) =
apply {
this.systemDns = systemDns
}
}
companion object {

View File

@ -33,31 +33,38 @@ internal object DnsRecordCodec {
private const val TYPE_PTR = 0x000c
private val ASCII = Charsets.US_ASCII
fun encodeQuery(host: String, type: Int): ByteString = Buffer().apply {
writeShort(0) // query id
writeShort(256) // flags with recursion
writeShort(1) // question count
writeShort(0) // answerCount
writeShort(0) // authorityResourceCount
writeShort(0) // additional
fun encodeQuery(
host: String,
type: Int,
): ByteString =
Buffer().apply {
writeShort(0) // query id
writeShort(256) // flags with recursion
writeShort(1) // question count
writeShort(0) // answerCount
writeShort(0) // authorityResourceCount
writeShort(0) // additional
val nameBuf = Buffer()
val labels = host.split('.').dropLastWhile { it.isEmpty() }
for (label in labels) {
val utf8ByteCount = label.utf8Size()
require(utf8ByteCount == label.length.toLong()) { "non-ascii hostname: $host" }
nameBuf.writeByte(utf8ByteCount.toInt())
nameBuf.writeUtf8(label)
}
nameBuf.writeByte(0) // end
val nameBuf = Buffer()
val labels = host.split('.').dropLastWhile { it.isEmpty() }
for (label in labels) {
val utf8ByteCount = label.utf8Size()
require(utf8ByteCount == label.length.toLong()) { "non-ascii hostname: $host" }
nameBuf.writeByte(utf8ByteCount.toInt())
nameBuf.writeUtf8(label)
}
nameBuf.writeByte(0) // end
nameBuf.copyTo(this, 0, nameBuf.size)
writeShort(type)
writeShort(1) // CLASS_IN
}.readByteString()
nameBuf.copyTo(this, 0, nameBuf.size)
writeShort(type)
writeShort(1) // CLASS_IN
}.readByteString()
@Throws(Exception::class)
fun decodeAnswers(hostname: String, byteString: ByteString): List<InetAddress> {
fun decodeAnswers(
hostname: String,
byteString: ByteString,
): List<InetAddress> {
val result = mutableListOf<InetAddress>()
val buf = Buffer()
@ -91,7 +98,8 @@ internal object DnsRecordCodec {
val type = buf.readShort().toInt() and 0xffff
buf.readShort() // class
@Suppress("UNUSED_VARIABLE") val ttl = buf.readInt().toLong() and 0xffffffffL // ttl
@Suppress("UNUSED_VARIABLE")
val ttl = buf.readInt().toLong() and 0xffffffffL // ttl
val length = buf.readShort().toInt() and 0xffff
if (type == TYPE_A || type == TYPE_AAAA) {

View File

@ -55,9 +55,10 @@ class DnsOverHttpsTest {
private lateinit var server: MockWebServer
private lateinit var dns: Dns
private val cacheFs = FakeFileSystem()
private val bootstrapClient = OkHttpClient.Builder()
.protocols(listOf(Protocol.HTTP_2, Protocol.HTTP_1_1))
.build()
private val bootstrapClient =
OkHttpClient.Builder()
.protocols(listOf(Protocol.HTTP_2, Protocol.HTTP_1_1))
.build()
@BeforeEach
fun setUp(server: MockWebServer) {
@ -72,8 +73,8 @@ class DnsOverHttpsTest {
dnsResponse(
"0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c000500010" +
"0000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010001000" +
"0003b00049df00112"
)
"0003b00049df00112",
),
)
val result = dns.lookup("google.com")
assertThat(result).isEqualTo(listOf(address("157.240.1.18")))
@ -89,15 +90,15 @@ class DnsOverHttpsTest {
dnsResponse(
"0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c0005000" +
"100000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c0420001000" +
"10000003b00049df00112"
)
"10000003b00049df00112",
),
)
server.enqueue(
dnsResponse(
"0000818000010003000000000567726170680866616365626f6f6b03636f6d00001c0001c00c0005000" +
"100000a1b000603617069c012c0300005000100000b1f000c04737461720463313072c012c042001c000" +
"10000003b00102a032880f0290011faceb00c00000002"
)
"10000003b00102a032880f0290011faceb00c00000002",
),
)
dns = buildLocalhost(bootstrapClient, true)
val result = dns.lookup("google.com")
@ -111,7 +112,7 @@ class DnsOverHttpsTest {
assertThat(listOf(request1.path, request2.path))
.containsExactlyInAnyOrder(
"/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AAAEAAQ",
"/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AABwAAQ"
"/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AABwAAQ",
)
}
@ -121,8 +122,8 @@ class DnsOverHttpsTest {
dnsResponse(
"0000818300010000000100000e7364666c6b686673646c6b6a64660265650000010001c01b00060001000" +
"007070038026e7303746c64c01b0a686f73746d61737465720d6565737469696e7465726e6574c01b5adb1" +
"2c100000e10000003840012750000000e10"
)
"2c100000e10000003840012750000000e10",
),
)
try {
dns.lookup("google.com")
@ -179,11 +180,11 @@ class DnsOverHttpsTest {
dnsResponse(
"0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c000500010" +
"0000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010001000" +
"0003b00049df00112"
"0003b00049df00112",
)
.newBuilder()
.setHeader("cache-control", "private, max-age=298")
.build()
.build(),
)
var result = cachedDns.lookup("google.com")
assertThat(result).containsExactly(address("157.240.1.18"))
@ -204,29 +205,29 @@ class DnsOverHttpsTest {
dnsResponse(
"0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c000500010" +
"0000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010001000" +
"0003b00049df00112"
"0003b00049df00112",
)
.newBuilder()
.setHeader("cache-control", "max-age=1")
.build()
.build(),
)
var result = cachedDns.lookup("google.com")
assertThat(result).containsExactly(address("157.240.1.18"))
var recordedRequest = server.takeRequest(0, TimeUnit.SECONDS)
assertThat(recordedRequest!!.method).isEqualTo("GET")
assertThat(recordedRequest.path).isEqualTo(
"/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AAAEAAQ"
"/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AAAEAAQ",
)
Thread.sleep(2000)
server.enqueue(
dnsResponse(
"0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c000500010" +
"0000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010001000" +
"0003b00049df00112"
"0003b00049df00112",
)
.newBuilder()
.setHeader("cache-control", "max-age=1")
.build()
.build(),
)
result = cachedDns.lookup("google.com")
assertThat(result).isEqualTo(listOf(address("157.240.1.18")))
@ -244,7 +245,10 @@ class DnsOverHttpsTest {
.build()
}
private fun buildLocalhost(bootstrapClient: OkHttpClient, includeIPv6: Boolean): DnsOverHttps {
private fun buildLocalhost(
bootstrapClient: OkHttpClient,
includeIPv6: Boolean,
): DnsOverHttps {
val url = server.url("/lookup?ct")
return DnsOverHttps.Builder().client(bootstrapClient)
.includeIPv6(includeIPv6)

View File

@ -14,12 +14,12 @@
* limitations under the License.
*/
@file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")
package okhttp3.dnsoverhttps
import assertk.assertThat
import assertk.assertions.containsExactly
import assertk.assertions.isEqualTo
import assertk.fail
import java.net.InetAddress
import java.net.UnknownHostException
import kotlin.test.assertFailsWith
@ -36,7 +36,10 @@ class DnsRecordCodecTest {
assertThat(encoded).isEqualTo("AAABAAABAAAAAAAABmdvb2dsZQNjb20AAAEAAQ")
}
private fun encodeQuery(host: String, type: Int): String {
private fun encodeQuery(
host: String,
type: Int,
): String {
return DnsRecordCodec.encodeQuery(host, type).base64Url().replace("=", "")
}
@ -48,33 +51,45 @@ class DnsRecordCodecTest {
@Test
fun testGoogleDotComDecodingFromCloudflare() {
val encoded = decodeAnswers(
hostname = "test.com",
byteString = ("00008180000100010000000006676f6f676c6503636f6d0000010001c00c0001000100000043" +
"0004d83ad54e").decodeHex()
)
val encoded =
decodeAnswers(
hostname = "test.com",
byteString =
(
"00008180000100010000000006676f6f676c6503636f6d0000010001c00c0001000100000043" +
"0004d83ad54e"
).decodeHex(),
)
assertThat(encoded).containsExactly(InetAddress.getByName("216.58.213.78"))
}
@Test
fun testGoogleDotComDecodingFromGoogle() {
val decoded = decodeAnswers(
hostname = "test.com",
byteString = ("0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c" +
"0005000100000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c0420001" +
"00010000003b00049df00112").decodeHex()
)
val decoded =
decodeAnswers(
hostname = "test.com",
byteString =
(
"0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c" +
"0005000100000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c0420001" +
"00010000003b00049df00112"
).decodeHex(),
)
assertThat(decoded).containsExactly(InetAddress.getByName("157.240.1.18"))
}
@Test
fun testGoogleDotComDecodingFromGoogleIPv6() {
val decoded = decodeAnswers(
hostname = "test.com",
byteString = ("0000818000010003000000000567726170680866616365626f6f6b03636f6d00001c0001c00c" +
"0005000100000a1b000603617069c012c0300005000100000b1f000c04737461720463313072c012c042001c" +
"00010000003b00102a032880f0290011faceb00c00000002").decodeHex()
)
val decoded =
decodeAnswers(
hostname = "test.com",
byteString =
(
"0000818000010003000000000567726170680866616365626f6f6b03636f6d00001c0001c00c" +
"0005000100000a1b000603617069c012c0300005000100000b1f000c04737461720463313072c012c042001c" +
"00010000003b00102a032880f0290011faceb00c00000002"
).decodeHex(),
)
assertThat(decoded)
.containsExactly(InetAddress.getByName("2a03:2880:f029:11:face:b00c:0:2"))
}
@ -84,9 +99,12 @@ class DnsRecordCodecTest {
assertFailsWith<UnknownHostException> {
decodeAnswers(
hostname = "sdflkhfsdlkjdf.ee",
byteString = ("0000818300010000000100000e7364666c6b686673646c6b6a64660265650000010001c01b" +
"00060001000007070038026e7303746c64c01b0a686f73746d61737465720d6565737469696e7465726e65" +
"74c01b5adb12c100000e10000003840012750000000e10").decodeHex()
byteString =
(
"0000818300010000000100000e7364666c6b686673646c6b6a64660265650000010001c01b" +
"00060001000007070038026e7303746c64c01b0a686f73746d61737465720d6565737469696e7465726e65" +
"74c01b5adb12c100000e10000003840012750000000e10"
).decodeHex(),
)
}.also { expected ->
assertThat(expected.message).isEqualTo("sdflkhfsdlkjdf.ee: NXDOMAIN")

View File

@ -24,7 +24,10 @@ import okhttp3.OkHttpClient
import okhttp3.dnsoverhttps.DohProviders.providers
import org.conscrypt.OpenSSLProvider
private fun runBatch(dnsProviders: List<DnsOverHttps>, names: List<String>) {
private fun runBatch(
dnsProviders: List<DnsOverHttps>,
names: List<String>,
) {
var time = System.currentTimeMillis()
for (dns in dnsProviders) {
println("Testing ${dns.url}")
@ -54,46 +57,52 @@ fun main() {
var names = listOf("google.com", "graph.facebook.com", "sdflkhfsdlkjdf.ee")
try {
println("uncached\n********\n")
var dnsProviders = providers(
client = bootstrapClient,
http2Only = false,
workingOnly = false,
getOnly = false,
)
var dnsProviders =
providers(
client = bootstrapClient,
http2Only = false,
workingOnly = false,
getOnly = false,
)
runBatch(dnsProviders, names)
val dnsCache = Cache(
directory = File("./target/TestDohMain.cache.${System.currentTimeMillis()}"),
maxSize = 10L * 1024 * 1024
)
val dnsCache =
Cache(
directory = File("./target/TestDohMain.cache.${System.currentTimeMillis()}"),
maxSize = 10L * 1024 * 1024,
)
println("Bad targets\n***********\n")
val url = "https://dns.cloudflare.com/.not-so-well-known/run-dmc-query".toHttpUrl()
val badProviders = listOf(
DnsOverHttps.Builder()
.client(bootstrapClient)
.url(url)
.post(true)
.build()
)
val badProviders =
listOf(
DnsOverHttps.Builder()
.client(bootstrapClient)
.url(url)
.post(true)
.build(),
)
runBatch(badProviders, names)
println("cached first run\n****************\n")
names = listOf("google.com", "graph.facebook.com")
bootstrapClient = bootstrapClient.newBuilder()
.cache(dnsCache)
.build()
dnsProviders = providers(
client = bootstrapClient,
http2Only = true,
workingOnly = true,
getOnly = true,
)
bootstrapClient =
bootstrapClient.newBuilder()
.cache(dnsCache)
.build()
dnsProviders =
providers(
client = bootstrapClient,
http2Only = true,
workingOnly = true,
getOnly = true,
)
runBatch(dnsProviders, names)
println("cached second run\n*****************\n")
dnsProviders = providers(
client = bootstrapClient,
http2Only = true,
workingOnly = true,
getOnly = true,
)
dnsProviders =
providers(
client = bootstrapClient,
http2Only = true,
workingOnly = true,
getOnly = true,
)
runBatch(dnsProviders, names)
} finally {
bootstrapClient.connectionPool.evictAll()

View File

@ -28,7 +28,7 @@ class HpackDecodeInteropTest : HpackDecodeTestBase() {
fun testGoodDecoderInterop(story: Story) {
assumeFalse(
story === Story.MISSING,
"Test stories missing, checkout git submodule"
"Test stories missing, checkout git submodule",
)
testDecoder(story)
}

View File

@ -36,7 +36,7 @@ open class HpackDecodeTestBase {
assertSetEquals(
"seqno=$testCase.seqno",
testCase.headersList,
hpackReader.getAndResetHeaderList()
hpackReader.getAndResetHeaderList(),
)
}
}

View File

@ -43,7 +43,7 @@ class HpackRoundTripTest : HpackDecodeTestBase() {
fun testRoundTrip(story: Story) {
assumeFalse(
story === Story.MISSING,
"Test stories missing, checkout git submodule"
"Test stories missing, checkout git submodule",
)
val newCases = mutableListOf<Case>()

View File

@ -35,13 +35,17 @@ import okio.source
*/
object HpackJsonUtil {
@Suppress("unused")
private val MOSHI = Moshi.Builder()
.add(object : Any() {
@ToJson fun byteStringToJson(byteString: ByteString) = byteString.hex()
@FromJson fun byteStringFromJson(json: String) = json.decodeHex()
})
.add(KotlinJsonAdapterFactory())
.build()
private val MOSHI =
Moshi.Builder()
.add(
object : Any() {
@ToJson fun byteStringToJson(byteString: ByteString) = byteString.hex()
@FromJson fun byteStringFromJson(json: String) = json.decodeHex()
},
)
.add(KotlinJsonAdapterFactory())
.build()
private val STORY_JSON_ADAPTER = MOSHI.adapter(Story::class.java)
private val fileSystem = FileSystem.SYSTEM
@ -58,8 +62,9 @@ object HpackJsonUtil {
/** Iterate through the hpack-test-case resources, only picking stories for the current draft. */
fun storiesForCurrentDraft(): Array<String> {
val resource = HpackJsonUtil::class.java.getResource("/hpack-test-case")
?: return arrayOf()
val resource =
HpackJsonUtil::class.java.getResource("/hpack-test-case")
?: return arrayOf()
val testCaseDirectory = File(resource.toURI()).toOkioPath()
val result = mutableListOf<String>()
@ -83,17 +88,20 @@ object HpackJsonUtil {
val result = mutableListOf<Story>()
var i = 0
while (true) { // break after last test.
val storyResourceName = String.format(
"/hpack-test-case/%s/story_%02d.json",
testFolderName,
i,
)
val storyInputStream = HpackJsonUtil::class.java.getResourceAsStream(storyResourceName)
?: break
val storyResourceName =
String.format(
"/hpack-test-case/%s/story_%02d.json",
testFolderName,
i,
)
val storyInputStream =
HpackJsonUtil::class.java.getResourceAsStream(storyResourceName)
?: break
try {
storyInputStream.use {
val story = readStory(storyInputStream.source().buffer())
.copy(fileName = storyResourceName)
val story =
readStory(storyInputStream.source().buffer())
.copy(fileName = storyResourceName)
result.add(story)
i++
}

View File

@ -24,7 +24,6 @@ data class Story(
val cases: List<Case>,
val fileName: String? = null,
) {
// Used as the test name.
override fun toString() = fileName ?: "?"

View File

@ -33,9 +33,10 @@ fun main(vararg args: String) {
fun loadIdnaMappingTableData(): IdnaMappingTableData {
val path = "/okhttp3/internal/idna/IdnaMappingTable.txt".toPath()
val table = FileSystem.RESOURCES.read(path) {
readPlainTextIdnaMappingTable()
}
val table =
FileSystem.RESOURCES.read(path) {
readPlainTextIdnaMappingTable()
}
return buildIdnaMappingTableData(table)
}
@ -60,18 +61,18 @@ fun generateMappingTableFile(data: IdnaMappingTableData): FileSpec {
.addModifiers(KModifier.INTERNAL)
.initializer(
"""
|%T(
|sections = "%L",
|ranges = "%L",
|mappings = "%L",
|)
""".trimMargin(),
|%T(
|sections = "%L",
|ranges = "%L",
|mappings = "%L",
|)
""".trimMargin(),
idnaMappingTable,
data.sections.escapeDataString(),
data.ranges.escapeDataString(),
data.mappings.escapeDataString(),
)
.build()
.build(),
)
.build()
}
@ -89,7 +90,8 @@ fun String.escapeDataString(): String {
'$'.code,
'\\'.code,
'·'.code,
127 -> append(String.format("\\u%04x", codePoint))
127,
-> append(String.format("\\u%04x", codePoint))
else -> appendCodePoint(codePoint)
}

View File

@ -23,22 +23,22 @@ internal sealed interface MappedRange {
data class Constant(
override val rangeStart: Int,
val type: Int
val type: Int,
) : MappedRange {
val b1: Int
get() = when (type) {
TYPE_IGNORED -> 119
TYPE_VALID -> 120
TYPE_DISALLOWED -> 121
else -> error("unexpected type: $type")
}
get() =
when (type) {
TYPE_IGNORED -> 119
TYPE_VALID -> 120
TYPE_DISALLOWED -> 121
else -> error("unexpected type: $type")
}
}
data class Inline1(
override val rangeStart: Int,
private val mappedTo: ByteString
private val mappedTo: ByteString,
) : MappedRange {
val b1: Int
get() {
val b3bit8 = mappedTo[0] and 0x80 != 0
@ -51,9 +51,8 @@ internal sealed interface MappedRange {
data class Inline2(
override val rangeStart: Int,
private val mappedTo: ByteString
private val mappedTo: ByteString,
) : MappedRange {
val b1: Int
get() {
val b2bit8 = mappedTo[0] and 0x80 != 0
@ -75,17 +74,17 @@ internal sealed interface MappedRange {
data class InlineDelta(
override val rangeStart: Int,
val codepointDelta: Int
val codepointDelta: Int,
) : MappedRange {
private val absoluteDelta = abs(codepointDelta)
val b1: Int
get() = when {
codepointDelta < 0 -> 0x40 or (absoluteDelta shr 14)
codepointDelta > 0 -> 0x50 or (absoluteDelta shr 14)
else -> error("Unexpected codepointDelta of 0")
}
get() =
when {
codepointDelta < 0 -> 0x40 or (absoluteDelta shr 14)
codepointDelta > 0 -> 0x50 or (absoluteDelta shr 14)
else -> error("Unexpected codepointDelta of 0")
}
val b2: Int
get() = absoluteDelta shr 7 and 0x7f
@ -100,6 +99,6 @@ internal sealed interface MappedRange {
data class External(
override val rangeStart: Int,
val mappedTo: ByteString
val mappedTo: ByteString,
) : MappedRange
}

View File

@ -135,26 +135,28 @@ internal fun sections(mappings: List<Mapping>): Map<Int, List<MappedRange>> {
val sectionList = result.getOrPut(section) { mutableListOf() }
sectionList += when (mapping.type) {
TYPE_MAPPED -> run {
val deltaMapping = inlineDeltaOrNull(mapping)
if (deltaMapping != null) {
return@run deltaMapping
sectionList +=
when (mapping.type) {
TYPE_MAPPED ->
run {
val deltaMapping = inlineDeltaOrNull(mapping)
if (deltaMapping != null) {
return@run deltaMapping
}
when (mapping.mappedTo.size) {
1 -> MappedRange.Inline1(rangeStart, mapping.mappedTo)
2 -> MappedRange.Inline2(rangeStart, mapping.mappedTo)
else -> MappedRange.External(rangeStart, mapping.mappedTo)
}
}
TYPE_IGNORED, TYPE_VALID, TYPE_DISALLOWED -> {
MappedRange.Constant(rangeStart, mapping.type)
}
when (mapping.mappedTo.size) {
1 -> MappedRange.Inline1(rangeStart, mapping.mappedTo)
2 -> MappedRange.Inline2(rangeStart, mapping.mappedTo)
else -> MappedRange.External(rangeStart, mapping.mappedTo)
}
else -> error("unexpected mapping type: ${mapping.type}")
}
TYPE_IGNORED, TYPE_VALID, TYPE_DISALLOWED -> {
MappedRange.Constant(rangeStart, mapping.type)
}
else -> error("unexpected mapping type: ${mapping.type}")
}
}
for (sectionList in result.values) {
@ -202,18 +204,20 @@ internal fun withoutSectionSpans(mappings: List<Mapping>): List<Mapping> {
while (true) {
if (current.spansSections) {
result += Mapping(
current.sourceCodePoint0,
current.section + 0x7f,
current.type,
current.mappedTo,
)
current = Mapping(
current.section + 0x80,
current.sourceCodePoint1,
current.type,
current.mappedTo,
)
result +=
Mapping(
current.sourceCodePoint0,
current.section + 0x7f,
current.type,
current.mappedTo,
)
current =
Mapping(
current.section + 0x80,
current.sourceCodePoint1,
current.type,
current.mappedTo,
)
} else {
result += current
current = if (i.hasNext()) i.next() else break
@ -246,12 +250,13 @@ internal fun mergeAdjacentRanges(mappings: List<Mapping>): List<Mapping> {
index++
}
result += Mapping(
sourceCodePoint0 = mapping.sourceCodePoint0,
sourceCodePoint1 = unionWith.sourceCodePoint1,
type = type,
mappedTo = mappedTo,
)
result +=
Mapping(
sourceCodePoint0 = mapping.sourceCodePoint0,
sourceCodePoint1 = unionWith.sourceCodePoint1,
type = type,
mappedTo = mappedTo,
)
}
return result
@ -262,11 +267,13 @@ internal fun canonicalizeType(type: Int): Int {
TYPE_IGNORED -> TYPE_IGNORED
TYPE_MAPPED,
TYPE_DISALLOWED_STD3_MAPPED -> TYPE_MAPPED
TYPE_DISALLOWED_STD3_MAPPED,
-> TYPE_MAPPED
TYPE_DEVIATION,
TYPE_DISALLOWED_STD3_VALID,
TYPE_VALID -> TYPE_VALID
TYPE_VALID,
-> TYPE_VALID
TYPE_DISALLOWED -> TYPE_DISALLOWED
@ -279,4 +286,3 @@ internal infix fun Byte.and(mask: Int): Int = toInt() and mask
internal infix fun Short.and(mask: Int): Int = toInt() and mask
internal infix fun Int.and(mask: Long): Long = toLong() and mask

View File

@ -42,14 +42,18 @@ class SimpleIdnaMappingTable internal constructor(
/**
* Returns true if the [codePoint] was applied successfully. Returns false if it was disallowed.
*/
fun map(codePoint: Int, sink: BufferedSink): Boolean {
val index = mappings.binarySearch {
when {
it.sourceCodePoint1 < codePoint -> -1
it.sourceCodePoint0 > codePoint -> 1
else -> 0
fun map(
codePoint: Int,
sink: BufferedSink,
): Boolean {
val index =
mappings.binarySearch {
when {
it.sourceCodePoint1 < codePoint -> -1
it.sourceCodePoint0 > codePoint -> 1
else -> 0
}
}
}
// Code points must be in 0..0x10ffff.
require(index in mappings.indices) { "unexpected code point: $codePoint" }
@ -77,24 +81,25 @@ class SimpleIdnaMappingTable internal constructor(
}
}
private val optionsDelimiter =
Options.of(
// 0.
".".encodeUtf8(),
// 1.
" ".encodeUtf8(),
// 2.
";".encodeUtf8(),
// 3.
"#".encodeUtf8(),
// 4.
"\n".encodeUtf8(),
)
private val optionsDelimiter = Options.of(
// 0.
".".encodeUtf8(),
// 1.
" ".encodeUtf8(),
// 2.
";".encodeUtf8(),
// 3.
"#".encodeUtf8(),
// 4.
"\n".encodeUtf8(),
)
private val optionsDot = Options.of(
// 0.
".".encodeUtf8(),
)
private val optionsDot =
Options.of(
// 0.
".".encodeUtf8(),
)
private const val DELIMITER_DOT = 0
private const val DELIMITER_SPACE = 1
@ -102,22 +107,23 @@ private const val DELIMITER_SEMICOLON = 2
private const val DELIMITER_HASH = 3
private const val DELIMITER_NEWLINE = 4
private val optionsType = Options.of(
// 0.
"deviation ".encodeUtf8(),
// 1.
"disallowed ".encodeUtf8(),
// 2.
"disallowed_STD3_mapped ".encodeUtf8(),
// 3.
"disallowed_STD3_valid ".encodeUtf8(),
// 4.
"ignored ".encodeUtf8(),
// 5.
"mapped ".encodeUtf8(),
// 6.
"valid ".encodeUtf8(),
)
private val optionsType =
Options.of(
// 0.
"deviation ".encodeUtf8(),
// 1.
"disallowed ".encodeUtf8(),
// 2.
"disallowed_STD3_mapped ".encodeUtf8(),
// 3.
"disallowed_STD3_valid ".encodeUtf8(),
// 4.
"ignored ".encodeUtf8(),
// 5.
"mapped ".encodeUtf8(),
// 6.
"valid ".encodeUtf8(),
)
internal const val TYPE_DEVIATION = 0
internal const val TYPE_DISALLOWED = 1
@ -182,14 +188,15 @@ fun BufferedSource.readPlainTextIdnaMappingTable(): SimpleIdnaMappingTable {
// "002F" or "0000..002C"
val sourceCodePoint0 = readHexadecimalUnsignedLong()
val sourceCodePoint1 = when (select(optionsDot)) {
DELIMITER_DOT -> {
if (readByte() != '.'.code.toByte()) throw IOException("expected '..'")
readHexadecimalUnsignedLong()
}
val sourceCodePoint1 =
when (select(optionsDot)) {
DELIMITER_DOT -> {
if (readByte() != '.'.code.toByte()) throw IOException("expected '..'")
readHexadecimalUnsignedLong()
}
else -> sourceCodePoint0
}
else -> sourceCodePoint0
}
skipWhitespace()
if (readByte() != ';'.code.toByte()) throw IOException("expected ';'")
@ -228,12 +235,13 @@ fun BufferedSource.readPlainTextIdnaMappingTable(): SimpleIdnaMappingTable {
skipRestOfLine()
result += Mapping(
sourceCodePoint0.toInt(),
sourceCodePoint1.toInt(),
type,
mappedTo.readByteString(),
)
result +=
Mapping(
sourceCodePoint0.toInt(),
sourceCodePoint1.toInt(),
type,
mappedTo.readByteString(),
)
}
return SimpleIdnaMappingTable(result)

View File

@ -34,8 +34,8 @@ class MappingTablesTest {
Mapping(0x0234, 0x0236, TYPE_VALID, ByteString.EMPTY),
Mapping(0x0237, 0x0239, TYPE_VALID, ByteString.EMPTY),
Mapping(0x023a, 0x023a, TYPE_MAPPED, "b".encodeUtf8()),
)
)
),
),
).containsExactly(
Mapping(0x0232, 0x0232, TYPE_MAPPED, "a".encodeUtf8()),
Mapping(0x0233, 0x0239, TYPE_VALID, ByteString.EMPTY),
@ -49,8 +49,8 @@ class MappingTablesTest {
listOf(
Mapping(0x0041, 0x0041, TYPE_MAPPED, "a".encodeUtf8()),
Mapping(0x0042, 0x0042, TYPE_MAPPED, "b".encodeUtf8()),
)
)
),
),
).containsExactly(
Mapping(0x0041, 0x0041, TYPE_MAPPED, "a".encodeUtf8()),
Mapping(0x0042, 0x0042, TYPE_MAPPED, "b".encodeUtf8()),
@ -62,8 +62,8 @@ class MappingTablesTest {
mergeAdjacentRanges(
listOf(
Mapping(0x0000, 0x002c, TYPE_DISALLOWED_STD3_VALID, ByteString.EMPTY),
)
)
),
),
).containsExactly(
Mapping(0x0000, 0x002c, TYPE_VALID, ByteString.EMPTY),
)
@ -74,9 +74,9 @@ class MappingTablesTest {
mergeAdjacentRanges(
listOf(
Mapping(0x0000, 0x002c, TYPE_DISALLOWED_STD3_VALID, ByteString.EMPTY),
Mapping(0x002d, 0x002e, TYPE_VALID, ByteString.EMPTY)
)
)
Mapping(0x002d, 0x002e, TYPE_VALID, ByteString.EMPTY),
),
),
).containsExactly(
Mapping(0x0000, 0x002e, TYPE_VALID, ByteString.EMPTY),
)
@ -87,8 +87,8 @@ class MappingTablesTest {
withoutSectionSpans(
listOf(
Mapping(0x40000, 0x40180, TYPE_DISALLOWED, ByteString.EMPTY),
)
)
),
),
).containsExactly(
Mapping(0x40000, 0x4007f, TYPE_DISALLOWED, ByteString.EMPTY),
Mapping(0x40080, 0x400ff, TYPE_DISALLOWED, ByteString.EMPTY),
@ -103,8 +103,8 @@ class MappingTablesTest {
listOf(
Mapping(0x40000, 0x4007f, TYPE_DISALLOWED, ByteString.EMPTY),
Mapping(0x40080, 0x400ff, TYPE_DISALLOWED, ByteString.EMPTY),
)
)
),
),
).containsExactly(
Mapping(0x40000, 0x4007f, TYPE_DISALLOWED, ByteString.EMPTY),
Mapping(0x40080, 0x400ff, TYPE_DISALLOWED, ByteString.EMPTY),
@ -119,8 +119,8 @@ class MappingTablesTest {
InlineDelta(2, 5),
InlineDelta(3, 5),
MappedRange.External(4, "a".encodeUtf8()),
)
)
),
),
).containsExactly(
InlineDelta(1, 5),
MappedRange.External(4, "a".encodeUtf8()),
@ -134,8 +134,8 @@ class MappingTablesTest {
InlineDelta(1, 5),
InlineDelta(2, 5),
InlineDelta(3, 1),
)
)
),
),
).containsExactly(
InlineDelta(1, 5),
InlineDelta(3, 1),
@ -148,9 +148,9 @@ class MappingTablesTest {
mappingOf(
sourceCodePoint0 = 1,
sourceCodePoint1 = 1,
mappedToCodePoints = listOf(2)
)
)
mappedToCodePoints = listOf(2),
),
),
).isEqualTo(InlineDelta(1, 1))
assertThat(
@ -158,9 +158,9 @@ class MappingTablesTest {
mappingOf(
sourceCodePoint0 = 2,
sourceCodePoint1 = 2,
mappedToCodePoints = listOf(1)
)
)
mappedToCodePoints = listOf(1),
),
),
).isEqualTo(InlineDelta(2, -1))
}
@ -170,9 +170,9 @@ class MappingTablesTest {
mappingOf(
sourceCodePoint0 = 2,
sourceCodePoint1 = 3,
mappedToCodePoints = listOf(2)
)
)
mappedToCodePoints = listOf(2),
),
),
).isEqualTo(null)
}
@ -182,9 +182,9 @@ class MappingTablesTest {
mappingOf(
sourceCodePoint0 = 1,
sourceCodePoint1 = 1,
mappedToCodePoints = listOf(2, 3)
)
)
mappedToCodePoints = listOf(2, 3),
),
),
).isEqualTo(null)
}
@ -194,14 +194,14 @@ class MappingTablesTest {
mappingOf(
sourceCodePoint0 = 0,
sourceCodePoint1 = 0,
mappedToCodePoints = listOf((1 shl 18) - 1)
)
)
mappedToCodePoints = listOf((1 shl 18) - 1),
),
),
).isEqualTo(
InlineDelta(
rangeStart = 0,
codepointDelta = InlineDelta.MAX_VALUE
)
codepointDelta = InlineDelta.MAX_VALUE,
),
)
assertThat(
@ -209,24 +209,26 @@ class MappingTablesTest {
mappingOf(
sourceCodePoint0 = 0,
sourceCodePoint1 = 0,
mappedToCodePoints = listOf(1 shl 18)
)
)
mappedToCodePoints = listOf(1 shl 18),
),
),
).isEqualTo(null)
}
private fun mappingOf(
sourceCodePoint0: Int,
sourceCodePoint1: Int,
mappedToCodePoints: List<Int>
): Mapping = Mapping(
sourceCodePoint0 = sourceCodePoint0,
sourceCodePoint1 = sourceCodePoint1,
type = TYPE_MAPPED,
mappedTo = Buffer().also {
for (cp in mappedToCodePoints) {
it.writeUtf8CodePoint(cp)
}
}.readByteString()
)
mappedToCodePoints: List<Int>,
): Mapping =
Mapping(
sourceCodePoint0 = sourceCodePoint0,
sourceCodePoint1 = sourceCodePoint1,
type = TYPE_MAPPED,
mappedTo =
Buffer().also {
for (cp in mappedToCodePoints) {
it.writeUtf8CodePoint(cp)
}
}.readByteString(),
)
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
@file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")
package okhttp3.java.net.cookiejar
package okhttp3.java.net.cookiejar
import java.io.IOException
import java.net.CookieHandler
@ -32,8 +32,10 @@ import okhttp3.internal.trimSubstring
/** A cookie jar that delegates to a [java.net.CookieHandler]. */
class JavaNetCookieJar(private val cookieHandler: CookieHandler) : CookieJar {
override fun saveFromResponse(url: HttpUrl, cookies: List<Cookie>) {
override fun saveFromResponse(
url: HttpUrl,
cookies: List<Cookie>,
) {
val cookieStrings = mutableListOf<String>()
for (cookie in cookies) {
cookieStrings.add(cookieToString(cookie, true))
@ -47,18 +49,20 @@ class JavaNetCookieJar(private val cookieHandler: CookieHandler) : CookieJar {
}
override fun loadForRequest(url: HttpUrl): List<Cookie> {
val cookieHeaders = try {
// The RI passes all headers. We don't have 'em, so we don't pass 'em!
cookieHandler.get(url.toUri(), emptyMap<String, List<String>>())
} catch (e: IOException) {
Platform.get().log("Loading cookies failed for " + url.resolve("/...")!!, WARN, e)
return emptyList()
}
val cookieHeaders =
try {
// The RI passes all headers. We don't have 'em, so we don't pass 'em!
cookieHandler.get(url.toUri(), emptyMap<String, List<String>>())
} catch (e: IOException) {
Platform.get().log("Loading cookies failed for " + url.resolve("/...")!!, WARN, e)
return emptyList()
}
var cookies: MutableList<Cookie>? = null
for ((key, value) in cookieHeaders) {
if (("Cookie".equals(key, ignoreCase = true) || "Cookie2".equals(key, ignoreCase = true)) &&
value.isNotEmpty()) {
value.isNotEmpty()
) {
for (header in value) {
if (cookies == null) cookies = mutableListOf()
cookies.addAll(decodeHeaderAsJavaNetCookies(url, header))
@ -77,7 +81,10 @@ class JavaNetCookieJar(private val cookieHandler: CookieHandler) : CookieJar {
* Convert a request header to OkHttp's cookies via [HttpCookie]. That extra step handles
* multiple cookies in a single request header, which [Cookie.parse] doesn't support.
*/
private fun decodeHeaderAsJavaNetCookies(url: HttpUrl, header: String): List<Cookie> {
private fun decodeHeaderAsJavaNetCookies(
url: HttpUrl,
header: String,
): List<Cookie> {
val result = mutableListOf<Cookie>()
var pos = 0
val limit = header.length
@ -92,22 +99,25 @@ class JavaNetCookieJar(private val cookieHandler: CookieHandler) : CookieJar {
}
// We have either name=value or just a name.
var value = if (equalsSign < pairEnd) {
header.trimSubstring(equalsSign + 1, pairEnd)
} else {
""
}
var value =
if (equalsSign < pairEnd) {
header.trimSubstring(equalsSign + 1, pairEnd)
} else {
""
}
// If the value is "quoted", drop the quotes.
if (value.startsWith("\"") && value.endsWith("\"") && value.length >= 2) {
value = value.substring(1, value.length - 1)
}
result.add(Cookie.Builder()
result.add(
Cookie.Builder()
.name(name)
.value(value)
.domain(url.host)
.build())
.build(),
)
pos = pairEnd + 1
}
return result

View File

@ -14,6 +14,7 @@
* limitations under the License.
*/
@file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")
package okhttp3.logging
import java.io.IOException
@ -38,288 +39,295 @@ import okio.GzipSource
* The format of the logs created by this class should not be considered stable and may
* change slightly between releases. If you need a stable logging format, use your own interceptor.
*/
class HttpLoggingInterceptor @JvmOverloads constructor(
private val logger: Logger = Logger.DEFAULT
) : Interceptor {
class HttpLoggingInterceptor
@JvmOverloads
constructor(
private val logger: Logger = Logger.DEFAULT,
) : Interceptor {
@Volatile private var headersToRedact = emptySet<String>()
@Volatile private var headersToRedact = emptySet<String>()
@set:JvmName("level")
@Volatile
var level = Level.NONE
@set:JvmName("level")
@Volatile var level = Level.NONE
enum class Level {
/** No logs. */
NONE,
enum class Level {
/** No logs. */
NONE,
/**
* Logs request and response lines.
*
* Example:
* ```
* --> POST /greeting http/1.1 (3-byte body)
*
* <-- 200 OK (22ms, 6-byte body)
* ```
*/
BASIC,
/**
* Logs request and response lines.
*
* Example:
* ```
* --> POST /greeting http/1.1 (3-byte body)
*
* <-- 200 OK (22ms, 6-byte body)
* ```
*/
BASIC,
/**
* Logs request and response lines and their respective headers.
*
* Example:
* ```
* --> POST /greeting http/1.1
* Host: example.com
* Content-Type: plain/text
* Content-Length: 3
* --> END POST
*
* <-- 200 OK (22ms)
* Content-Type: plain/text
* Content-Length: 6
* <-- END HTTP
* ```
*/
HEADERS,
/**
* Logs request and response lines and their respective headers.
*
* Example:
* ```
* --> POST /greeting http/1.1
* Host: example.com
* Content-Type: plain/text
* Content-Length: 3
* --> END POST
*
* <-- 200 OK (22ms)
* Content-Type: plain/text
* Content-Length: 6
* <-- END HTTP
* ```
*/
HEADERS,
/**
* Logs request and response lines and their respective headers and bodies (if present).
*
* Example:
* ```
* --> POST /greeting http/1.1
* Host: example.com
* Content-Type: plain/text
* Content-Length: 3
*
* Hi?
* --> END POST
*
* <-- 200 OK (22ms)
* Content-Type: plain/text
* Content-Length: 6
*
* Hello!
* <-- END HTTP
* ```
*/
BODY,
}
/**
* Logs request and response lines and their respective headers and bodies (if present).
*
* Example:
* ```
* --> POST /greeting http/1.1
* Host: example.com
* Content-Type: plain/text
* Content-Length: 3
*
* Hi?
* --> END POST
*
* <-- 200 OK (22ms)
* Content-Type: plain/text
* Content-Length: 6
*
* Hello!
* <-- END HTTP
* ```
*/
BODY
}
fun interface Logger {
fun log(message: String)
fun interface Logger {
fun log(message: String)
companion object {
/** A [Logger] defaults output appropriate for the current platform. */
@JvmField
val DEFAULT: Logger = DefaultLogger()
companion object {
/** A [Logger] defaults output appropriate for the current platform. */
@JvmField
val DEFAULT: Logger = DefaultLogger()
private class DefaultLogger : Logger {
override fun log(message: String) {
Platform.get().log(message)
private class DefaultLogger : Logger {
override fun log(message: String) {
Platform.get().log(message)
}
}
}
}
}
fun redactHeader(name: String) {
val newHeadersToRedact = TreeSet(String.CASE_INSENSITIVE_ORDER)
newHeadersToRedact += headersToRedact
newHeadersToRedact += name
headersToRedact = newHeadersToRedact
}
fun redactHeader(name: String) {
val newHeadersToRedact = TreeSet(String.CASE_INSENSITIVE_ORDER)
newHeadersToRedact += headersToRedact
newHeadersToRedact += name
headersToRedact = newHeadersToRedact
}
/**
* Sets the level and returns this.
*
* This was deprecated in OkHttp 4.0 in favor of the [level] val. In OkHttp 4.3 it is
* un-deprecated because Java callers can't chain when assigning Kotlin vals. (The getter remains
* deprecated).
*/
fun setLevel(level: Level) = apply {
this.level = level
}
/**
* Sets the level and returns this.
*
* This was deprecated in OkHttp 4.0 in favor of the [level] val. In OkHttp 4.3 it is
* un-deprecated because Java callers can't chain when assigning Kotlin vals. (The getter remains
* deprecated).
*/
fun setLevel(level: Level) =
apply {
this.level = level
}
@JvmName("-deprecated_level")
@Deprecated(
@JvmName("-deprecated_level")
@Deprecated(
message = "moved to var",
replaceWith = ReplaceWith(expression = "level"),
level = DeprecationLevel.ERROR
)
fun getLevel(): Level = level
@Throws(IOException::class)
override fun intercept(chain: Interceptor.Chain): Response {
val level = this.level
val request = chain.request()
if (level == Level.NONE) {
return chain.proceed(request)
}
val logBody = level == Level.BODY
val logHeaders = logBody || level == Level.HEADERS
val requestBody = request.body
val connection = chain.connection()
var requestStartMessage =
("--> ${request.method} ${request.url}${if (connection != null) " " + connection.protocol() else ""}")
if (!logHeaders && requestBody != null) {
requestStartMessage += " (${requestBody.contentLength()}-byte body)"
}
logger.log(requestStartMessage)
if (logHeaders) {
val headers = request.headers
if (requestBody != null) {
// Request body headers are only present when installed as a network interceptor. When not
// already present, force them to be included (if available) so their values are known.
requestBody.contentType()?.let {
if (headers["Content-Type"] == null) {
logger.log("Content-Type: $it")
}
}
if (requestBody.contentLength() != -1L) {
if (headers["Content-Length"] == null) {
logger.log("Content-Length: ${requestBody.contentLength()}")
}
}
}
for (i in 0 until headers.size) {
logHeader(headers, i)
}
if (!logBody || requestBody == null) {
logger.log("--> END ${request.method}")
} else if (bodyHasUnknownEncoding(request.headers)) {
logger.log("--> END ${request.method} (encoded body omitted)")
} else if (requestBody.isDuplex()) {
logger.log("--> END ${request.method} (duplex request body omitted)")
} else if (requestBody.isOneShot()) {
logger.log("--> END ${request.method} (one-shot body omitted)")
} else {
var buffer = Buffer()
requestBody.writeTo(buffer)
var gzippedLength: Long? = null
if ("gzip".equals(headers["Content-Encoding"], ignoreCase = true)) {
gzippedLength = buffer.size
GzipSource(buffer).use { gzippedResponseBody ->
buffer = Buffer()
buffer.writeAll(gzippedResponseBody)
}
}
val charset: Charset = requestBody.contentType().charsetOrUtf8()
logger.log("")
if (!buffer.isProbablyUtf8()) {
logger.log(
"--> END ${request.method} (binary ${requestBody.contentLength()}-byte body omitted)"
)
} else if (gzippedLength != null) {
logger.log("--> END ${request.method} (${buffer.size}-byte, $gzippedLength-gzipped-byte body)")
} else {
logger.log(buffer.readString(charset))
logger.log("--> END ${request.method} (${requestBody.contentLength()}-byte body)")
}
}
}
val startNs = System.nanoTime()
val response: Response
try {
response = chain.proceed(request)
} catch (e: Exception) {
logger.log("<-- HTTP FAILED: $e")
throw e
}
val tookMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNs)
val responseBody = response.body!!
val contentLength = responseBody.contentLength()
val bodySize = if (contentLength != -1L) "$contentLength-byte" else "unknown-length"
logger.log(
buildString {
append("<-- ${response.code}")
if (response.message.isNotEmpty()) append(" ${response.message}")
append(" ${response.request.url} (${tookMs}ms")
if (!logHeaders) append(", $bodySize body")
append(")")
}
level = DeprecationLevel.ERROR,
)
fun getLevel(): Level = level
if (logHeaders) {
val headers = response.headers
for (i in 0 until headers.size) {
logHeader(headers, i)
@Throws(IOException::class)
override fun intercept(chain: Interceptor.Chain): Response {
val level = this.level
val request = chain.request()
if (level == Level.NONE) {
return chain.proceed(request)
}
if (!logBody || !response.promisesBody()) {
logger.log("<-- END HTTP")
} else if (bodyHasUnknownEncoding(response.headers)) {
logger.log("<-- END HTTP (encoded body omitted)")
} else if (bodyIsStreaming(response)) {
logger.log("<-- END HTTP (streaming)")
} else {
val source = responseBody.source()
source.request(Long.MAX_VALUE) // Buffer the entire body.
val logBody = level == Level.BODY
val logHeaders = logBody || level == Level.HEADERS
val totalMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNs)
val requestBody = request.body
var buffer = source.buffer
val connection = chain.connection()
var requestStartMessage =
("--> ${request.method} ${request.url}${if (connection != null) " " + connection.protocol() else ""}")
if (!logHeaders && requestBody != null) {
requestStartMessage += " (${requestBody.contentLength()}-byte body)"
}
logger.log(requestStartMessage)
var gzippedLength: Long? = null
if ("gzip".equals(headers["Content-Encoding"], ignoreCase = true)) {
gzippedLength = buffer.size
GzipSource(buffer.clone()).use { gzippedResponseBody ->
buffer = Buffer()
buffer.writeAll(gzippedResponseBody)
if (logHeaders) {
val headers = request.headers
if (requestBody != null) {
// Request body headers are only present when installed as a network interceptor. When not
// already present, force them to be included (if available) so their values are known.
requestBody.contentType()?.let {
if (headers["Content-Type"] == null) {
logger.log("Content-Type: $it")
}
}
if (requestBody.contentLength() != -1L) {
if (headers["Content-Length"] == null) {
logger.log("Content-Length: ${requestBody.contentLength()}")
}
}
}
val charset: Charset = responseBody.contentType().charsetOrUtf8()
if (!buffer.isProbablyUtf8()) {
logger.log("")
logger.log("<-- END HTTP (${totalMs}ms, binary ${buffer.size}-byte body omitted)")
return response
for (i in 0 until headers.size) {
logHeader(headers, i)
}
if (contentLength != 0L) {
logger.log("")
logger.log(buffer.clone().readString(charset))
}
if (!logBody || requestBody == null) {
logger.log("--> END ${request.method}")
} else if (bodyHasUnknownEncoding(request.headers)) {
logger.log("--> END ${request.method} (encoded body omitted)")
} else if (requestBody.isDuplex()) {
logger.log("--> END ${request.method} (duplex request body omitted)")
} else if (requestBody.isOneShot()) {
logger.log("--> END ${request.method} (one-shot body omitted)")
} else {
var buffer = Buffer()
requestBody.writeTo(buffer)
logger.log(
buildString {
append("<-- END HTTP (${totalMs}ms, ${buffer.size}-byte")
if (gzippedLength != null) append(", $gzippedLength-gzipped-byte")
append(" body)")
var gzippedLength: Long? = null
if ("gzip".equals(headers["Content-Encoding"], ignoreCase = true)) {
gzippedLength = buffer.size
GzipSource(buffer).use { gzippedResponseBody ->
buffer = Buffer()
buffer.writeAll(gzippedResponseBody)
}
}
)
val charset: Charset = requestBody.contentType().charsetOrUtf8()
logger.log("")
if (!buffer.isProbablyUtf8()) {
logger.log(
"--> END ${request.method} (binary ${requestBody.contentLength()}-byte body omitted)",
)
} else if (gzippedLength != null) {
logger.log("--> END ${request.method} (${buffer.size}-byte, $gzippedLength-gzipped-byte body)")
} else {
logger.log(buffer.readString(charset))
logger.log("--> END ${request.method} (${requestBody.contentLength()}-byte body)")
}
}
}
val startNs = System.nanoTime()
val response: Response
try {
response = chain.proceed(request)
} catch (e: Exception) {
logger.log("<-- HTTP FAILED: $e")
throw e
}
val tookMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNs)
val responseBody = response.body!!
val contentLength = responseBody.contentLength()
val bodySize = if (contentLength != -1L) "$contentLength-byte" else "unknown-length"
logger.log(
buildString {
append("<-- ${response.code}")
if (response.message.isNotEmpty()) append(" ${response.message}")
append(" ${response.request.url} (${tookMs}ms")
if (!logHeaders) append(", $bodySize body")
append(")")
},
)
if (logHeaders) {
val headers = response.headers
for (i in 0 until headers.size) {
logHeader(headers, i)
}
if (!logBody || !response.promisesBody()) {
logger.log("<-- END HTTP")
} else if (bodyHasUnknownEncoding(response.headers)) {
logger.log("<-- END HTTP (encoded body omitted)")
} else if (bodyIsStreaming(response)) {
logger.log("<-- END HTTP (streaming)")
} else {
val source = responseBody.source()
source.request(Long.MAX_VALUE) // Buffer the entire body.
val totalMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNs)
var buffer = source.buffer
var gzippedLength: Long? = null
if ("gzip".equals(headers["Content-Encoding"], ignoreCase = true)) {
gzippedLength = buffer.size
GzipSource(buffer.clone()).use { gzippedResponseBody ->
buffer = Buffer()
buffer.writeAll(gzippedResponseBody)
}
}
val charset: Charset = responseBody.contentType().charsetOrUtf8()
if (!buffer.isProbablyUtf8()) {
logger.log("")
logger.log("<-- END HTTP (${totalMs}ms, binary ${buffer.size}-byte body omitted)")
return response
}
if (contentLength != 0L) {
logger.log("")
logger.log(buffer.clone().readString(charset))
}
logger.log(
buildString {
append("<-- END HTTP (${totalMs}ms, ${buffer.size}-byte")
if (gzippedLength != null) append(", $gzippedLength-gzipped-byte")
append(" body)")
},
)
}
}
return response
}
return response
}
private fun logHeader(
headers: Headers,
i: Int,
) {
val value = if (headers.name(i) in headersToRedact) "██" else headers.value(i)
logger.log(headers.name(i) + ": " + value)
}
private fun logHeader(headers: Headers, i: Int) {
val value = if (headers.name(i) in headersToRedact) "██" else headers.value(i)
logger.log(headers.name(i) + ": " + value)
}
private fun bodyHasUnknownEncoding(headers: Headers): Boolean {
val contentEncoding = headers["Content-Encoding"] ?: return false
return !contentEncoding.equals("identity", ignoreCase = true) &&
private fun bodyHasUnknownEncoding(headers: Headers): Boolean {
val contentEncoding = headers["Content-Encoding"] ?: return false
return !contentEncoding.equals("identity", ignoreCase = true) &&
!contentEncoding.equals("gzip", ignoreCase = true)
}
}
private fun bodyIsStreaming(response: Response): Boolean {
val contentType = response.body.contentType()
return contentType != null && contentType.type == "text" && contentType.subtype == "event-stream"
private fun bodyIsStreaming(response: Response): Boolean {
val contentType = response.body.contentType()
return contentType != null && contentType.type == "text" && contentType.subtype == "event-stream"
}
}
}

View File

@ -38,7 +38,7 @@ import okhttp3.Response
* slightly between releases. If you need a stable logging format, use your own event listener.
*/
class LoggingEventListener private constructor(
private val logger: HttpLoggingInterceptor.Logger
private val logger: HttpLoggingInterceptor.Logger,
) : EventListener() {
private var startNs: Long = 0
@ -48,23 +48,41 @@ class LoggingEventListener private constructor(
logWithTime("callStart: ${call.request()}")
}
override fun proxySelectStart(call: Call, url: HttpUrl) {
override fun proxySelectStart(
call: Call,
url: HttpUrl,
) {
logWithTime("proxySelectStart: $url")
}
override fun proxySelectEnd(call: Call, url: HttpUrl, proxies: List<Proxy>) {
override fun proxySelectEnd(
call: Call,
url: HttpUrl,
proxies: List<Proxy>,
) {
logWithTime("proxySelectEnd: $proxies")
}
override fun dnsStart(call: Call, domainName: String) {
override fun dnsStart(
call: Call,
domainName: String,
) {
logWithTime("dnsStart: $domainName")
}
override fun dnsEnd(call: Call, domainName: String, inetAddressList: List<InetAddress>) {
override fun dnsEnd(
call: Call,
domainName: String,
inetAddressList: List<InetAddress>,
) {
logWithTime("dnsEnd: $inetAddressList")
}
override fun connectStart(call: Call, inetSocketAddress: InetSocketAddress, proxy: Proxy) {
override fun connectStart(
call: Call,
inetSocketAddress: InetSocketAddress,
proxy: Proxy,
) {
logWithTime("connectStart: $inetSocketAddress $proxy")
}
@ -72,7 +90,10 @@ class LoggingEventListener private constructor(
logWithTime("secureConnectStart")
}
override fun secureConnectEnd(call: Call, handshake: Handshake?) {
override fun secureConnectEnd(
call: Call,
handshake: Handshake?,
) {
logWithTime("secureConnectEnd: $handshake")
}
@ -80,7 +101,7 @@ class LoggingEventListener private constructor(
call: Call,
inetSocketAddress: InetSocketAddress,
proxy: Proxy,
protocol: Protocol?
protocol: Protocol?,
) {
logWithTime("connectEnd: $protocol")
}
@ -90,16 +111,22 @@ class LoggingEventListener private constructor(
inetSocketAddress: InetSocketAddress,
proxy: Proxy,
protocol: Protocol?,
ioe: IOException
ioe: IOException,
) {
logWithTime("connectFailed: $protocol $ioe")
}
override fun connectionAcquired(call: Call, connection: Connection) {
override fun connectionAcquired(
call: Call,
connection: Connection,
) {
logWithTime("connectionAcquired: $connection")
}
override fun connectionReleased(call: Call, connection: Connection) {
override fun connectionReleased(
call: Call,
connection: Connection,
) {
logWithTime("connectionReleased")
}
@ -107,7 +134,10 @@ class LoggingEventListener private constructor(
logWithTime("requestHeadersStart")
}
override fun requestHeadersEnd(call: Call, request: Request) {
override fun requestHeadersEnd(
call: Call,
request: Request,
) {
logWithTime("requestHeadersEnd")
}
@ -115,11 +145,17 @@ class LoggingEventListener private constructor(
logWithTime("requestBodyStart")
}
override fun requestBodyEnd(call: Call, byteCount: Long) {
override fun requestBodyEnd(
call: Call,
byteCount: Long,
) {
logWithTime("requestBodyEnd: byteCount=$byteCount")
}
override fun requestFailed(call: Call, ioe: IOException) {
override fun requestFailed(
call: Call,
ioe: IOException,
) {
logWithTime("requestFailed: $ioe")
}
@ -127,7 +163,10 @@ class LoggingEventListener private constructor(
logWithTime("responseHeadersStart")
}
override fun responseHeadersEnd(call: Call, response: Response) {
override fun responseHeadersEnd(
call: Call,
response: Response,
) {
logWithTime("responseHeadersEnd: $response")
}
@ -135,11 +174,17 @@ class LoggingEventListener private constructor(
logWithTime("responseBodyStart")
}
override fun responseBodyEnd(call: Call, byteCount: Long) {
override fun responseBodyEnd(
call: Call,
byteCount: Long,
) {
logWithTime("responseBodyEnd: byteCount=$byteCount")
}
override fun responseFailed(call: Call, ioe: IOException) {
override fun responseFailed(
call: Call,
ioe: IOException,
) {
logWithTime("responseFailed: $ioe")
}
@ -147,7 +192,10 @@ class LoggingEventListener private constructor(
logWithTime("callEnd")
}
override fun callFailed(call: Call, ioe: IOException) {
override fun callFailed(
call: Call,
ioe: IOException,
) {
logWithTime("callFailed: $ioe")
}
@ -155,11 +203,17 @@ class LoggingEventListener private constructor(
logWithTime("canceled")
}
override fun satisfactionFailure(call: Call, response: Response) {
override fun satisfactionFailure(
call: Call,
response: Response,
) {
logWithTime("satisfactionFailure: $response")
}
override fun cacheHit(call: Call, response: Response) {
override fun cacheHit(
call: Call,
response: Response,
) {
logWithTime("cacheHit: $response")
}
@ -167,7 +221,10 @@ class LoggingEventListener private constructor(
logWithTime("cacheMiss")
}
override fun cacheConditionalHit(call: Call, cachedResponse: Response) {
override fun cacheConditionalHit(
call: Call,
cachedResponse: Response,
) {
logWithTime("cacheConditionalHit: $cachedResponse")
}
@ -176,9 +233,11 @@ class LoggingEventListener private constructor(
logger.log("[$timeMs ms] $message")
}
open class Factory @JvmOverloads constructor(
private val logger: HttpLoggingInterceptor.Logger = HttpLoggingInterceptor.Logger.DEFAULT
) : EventListener.Factory {
override fun create(call: Call): EventListener = LoggingEventListener(logger)
}
open class Factory
@JvmOverloads
constructor(
private val logger: HttpLoggingInterceptor.Logger = HttpLoggingInterceptor.Logger.DEFAULT,
) : EventListener.Factory {
override fun create(call: Call): EventListener = LoggingEventListener(logger)
}
}

View File

@ -72,21 +72,24 @@ class HttpLoggingInterceptorTest {
@BeforeEach
fun setUp(server: MockWebServer) {
this.server = server
client = OkHttpClient.Builder()
.addNetworkInterceptor(Interceptor { chain ->
when {
extraNetworkInterceptor != null -> extraNetworkInterceptor!!.intercept(chain)
else -> chain.proceed(chain.request())
}
})
.addNetworkInterceptor(networkInterceptor)
.addInterceptor(applicationInterceptor)
.sslSocketFactory(
handshakeCertificates.sslSocketFactory(),
handshakeCertificates.trustManager,
)
.hostnameVerifier(hostnameVerifier)
.build()
client =
OkHttpClient.Builder()
.addNetworkInterceptor(
Interceptor { chain ->
when {
extraNetworkInterceptor != null -> extraNetworkInterceptor!!.intercept(chain)
else -> chain.proceed(chain.request())
}
},
)
.addNetworkInterceptor(networkInterceptor)
.addInterceptor(applicationInterceptor)
.sslSocketFactory(
handshakeCertificates.sslSocketFactory(),
handshakeCertificates.trustManager,
)
.hostnameVerifier(hostnameVerifier)
.build()
host = "${server.hostName}:${server.port}"
url = server.url("/")
}
@ -153,7 +156,7 @@ class HttpLoggingInterceptorTest {
MockResponse.Builder()
.body("Hello!")
.setHeader("Content-Type", PLAIN)
.build()
.build(),
)
val response = client.newCall(request().build()).execute()
response.body.close()
@ -174,7 +177,7 @@ class HttpLoggingInterceptorTest {
MockResponse.Builder()
.chunkedBody("Hello!", 2)
.setHeader("Content-Type", PLAIN)
.build()
.build(),
)
val response = client.newCall(request().build()).execute()
response.body.close()
@ -278,13 +281,14 @@ class HttpLoggingInterceptorTest {
fun headersPostNoLength() {
setLevel(Level.HEADERS)
server.enqueue(MockResponse())
val body: RequestBody = object : RequestBody() {
override fun contentType() = PLAIN
val body: RequestBody =
object : RequestBody() {
override fun contentType() = PLAIN
override fun writeTo(sink: BufferedSink) {
sink.writeUtf8("Hi!")
override fun writeTo(sink: BufferedSink) {
sink.writeUtf8("Hi!")
}
}
}
val response = client.newCall(request().post(body).build()).execute()
response.body.close()
applicationLogs
@ -313,20 +317,21 @@ class HttpLoggingInterceptorTest {
@Test
fun headersPostWithHeaderOverrides() {
setLevel(Level.HEADERS)
extraNetworkInterceptor = Interceptor { chain: Interceptor.Chain ->
chain.proceed(
chain.request()
.newBuilder()
.header("Content-Length", "2")
.header("Content-Type", "text/plain-ish")
.build()
)
}
extraNetworkInterceptor =
Interceptor { chain: Interceptor.Chain ->
chain.proceed(
chain.request()
.newBuilder()
.header("Content-Length", "2")
.header("Content-Type", "text/plain-ish")
.build(),
)
}
server.enqueue(MockResponse())
client.newCall(
request()
.post("Hi?".toRequestBody(PLAIN))
.build()
.build(),
).execute()
applicationLogs
.assertLogEqual("--> POST $url")
@ -359,7 +364,7 @@ class HttpLoggingInterceptorTest {
MockResponse.Builder()
.body("Hello!")
.setHeader("Content-Type", PLAIN)
.build()
.build(),
)
val response = client.newCall(request().build()).execute()
response.body.close()
@ -427,7 +432,7 @@ class HttpLoggingInterceptorTest {
server.enqueue(
MockResponse.Builder()
.status("HTTP/1.1 $code No Content")
.build()
.build(),
)
val response = client.newCall(request().build()).execute()
response.body.close()
@ -493,7 +498,7 @@ class HttpLoggingInterceptorTest {
MockResponse.Builder()
.body("Hello!")
.setHeader("Content-Type", PLAIN)
.build()
.build(),
)
val response = client.newCall(request().build()).execute()
response.body.close()
@ -530,7 +535,7 @@ class HttpLoggingInterceptorTest {
MockResponse.Builder()
.chunkedBody("Hello!", 2)
.setHeader("Content-Type", PLAIN)
.build()
.build(),
)
val response = client.newCall(request().build()).execute()
response.body.close()
@ -567,14 +572,15 @@ class HttpLoggingInterceptorTest {
MockResponse.Builder()
.setHeader("Content-Type", PLAIN)
.body(Buffer().writeUtf8("Uncompressed"))
.build()
.build(),
)
val response = client.newCall(
request()
.addHeader("Content-Encoding", "gzip")
.post("Uncompressed".toRequestBody().gzip())
.build()
).execute()
val response =
client.newCall(
request()
.addHeader("Content-Encoding", "gzip")
.post("Uncompressed".toRequestBody().gzip())
.build(),
).execute()
val responseBody = response.body
assertThat(responseBody.string(), "Expected response body to be valid")
.isEqualTo("Uncompressed")
@ -606,7 +612,7 @@ class HttpLoggingInterceptorTest {
.setHeader("Content-Encoding", "gzip")
.setHeader("Content-Type", PLAIN)
.body(Buffer().write("H4sIAAAAAAAAAPNIzcnJ11HwQKIAdyO+9hMAAAA=".decodeBase64()!!))
.build()
.build(),
)
val response = client.newCall(request().build()).execute()
val responseBody = response.body
@ -647,7 +653,7 @@ class HttpLoggingInterceptorTest {
.setHeader("Content-Encoding", "br")
.setHeader("Content-Type", PLAIN)
.body(Buffer().write("iwmASGVsbG8sIEhlbGxvLCBIZWxsbwoD".decodeBase64()!!))
.build()
.build(),
)
val response = client.newCall(request().build()).execute()
response.body.close()
@ -693,9 +699,10 @@ class HttpLoggingInterceptorTest {
|data: 113411
|
|
""".trimMargin(), 8
""".trimMargin(),
8,
)
.build()
.build(),
)
val response = client.newCall(request().build()).execute()
response.body.close()
@ -728,7 +735,7 @@ class HttpLoggingInterceptorTest {
MockResponse.Builder()
.setHeader("Content-Type", "text/html; charset=0")
.body("Body with unknown charset")
.build()
.build(),
)
val response = client.newCall(request().build()).execute()
response.body.close()
@ -774,7 +781,7 @@ class HttpLoggingInterceptorTest {
MockResponse.Builder()
.body(buffer)
.setHeader("Content-Type", "image/png; charset=utf-8")
.build()
.build(),
)
val response = client.newCall(request().build()).execute()
response.body.close()
@ -805,10 +812,11 @@ class HttpLoggingInterceptorTest {
@Test
fun connectFail() {
setLevel(Level.BASIC)
client = OkHttpClient.Builder()
.dns { hostname: String? -> throw UnknownHostException("reason") }
.addInterceptor(applicationInterceptor)
.build()
client =
OkHttpClient.Builder()
.dns { hostname: String? -> throw UnknownHostException("reason") }
.addInterceptor(applicationInterceptor)
.build()
try {
client.newCall(request().build()).execute()
fail<Any>()
@ -840,31 +848,35 @@ class HttpLoggingInterceptorTest {
@Test
fun headersAreRedacted() {
val networkInterceptor = HttpLoggingInterceptor(networkLogs).setLevel(
Level.HEADERS
)
val networkInterceptor =
HttpLoggingInterceptor(networkLogs).setLevel(
Level.HEADERS,
)
networkInterceptor.redactHeader("sEnSiTiVe")
val applicationInterceptor = HttpLoggingInterceptor(applicationLogs).setLevel(
Level.HEADERS
)
val applicationInterceptor =
HttpLoggingInterceptor(applicationLogs).setLevel(
Level.HEADERS,
)
applicationInterceptor.redactHeader("sEnSiTiVe")
client = OkHttpClient.Builder()
.addNetworkInterceptor(networkInterceptor)
.addInterceptor(applicationInterceptor)
.build()
client =
OkHttpClient.Builder()
.addNetworkInterceptor(networkInterceptor)
.addInterceptor(applicationInterceptor)
.build()
server.enqueue(
MockResponse.Builder()
.addHeader("SeNsItIvE", "Value").addHeader("Not-Sensitive", "Value")
.build()
.build(),
)
val response = client
.newCall(
request()
.addHeader("SeNsItIvE", "Value")
.addHeader("Not-Sensitive", "Value")
.build()
)
.execute()
val response =
client
.newCall(
request()
.addHeader("SeNsItIvE", "Value")
.addHeader("Not-Sensitive", "Value")
.build(),
)
.execute()
response.body.close()
applicationLogs
.assertLogEqual("--> GET $url")
@ -903,25 +915,27 @@ class HttpLoggingInterceptorTest {
server.enqueue(
MockResponse.Builder()
.body("Hello response!")
.build()
.build(),
)
val asyncRequestBody: RequestBody = object : RequestBody() {
override fun contentType(): MediaType? {
return null
}
val asyncRequestBody: RequestBody =
object : RequestBody() {
override fun contentType(): MediaType? {
return null
}
override fun writeTo(sink: BufferedSink) {
sink.writeUtf8("Hello request!")
sink.close()
}
override fun writeTo(sink: BufferedSink) {
sink.writeUtf8("Hello request!")
sink.close()
}
override fun isDuplex(): Boolean {
return true
override fun isDuplex(): Boolean {
return true
}
}
}
val request = request()
.post(asyncRequestBody)
.build()
val request =
request()
.post(asyncRequestBody)
.build()
val response = client.newCall(request).execute()
Assumptions.assumeTrue(response.protocol == Protocol.HTTP_2)
assertThat(response.body.string()).isEqualTo("Hello response!")
@ -943,25 +957,27 @@ class HttpLoggingInterceptorTest {
server.enqueue(
MockResponse.Builder()
.body("Hello response!")
.build()
.build(),
)
val asyncRequestBody: RequestBody = object : RequestBody() {
var counter = 0
val asyncRequestBody: RequestBody =
object : RequestBody() {
var counter = 0
override fun contentType() = null
override fun contentType() = null
override fun writeTo(sink: BufferedSink) {
counter++
assertThat(counter).isLessThanOrEqualTo(1)
sink.writeUtf8("Hello request!")
sink.close()
override fun writeTo(sink: BufferedSink) {
counter++
assertThat(counter).isLessThanOrEqualTo(1)
sink.writeUtf8("Hello request!")
sink.close()
}
override fun isOneShot() = true
}
override fun isOneShot() = true
}
val request = request()
.post(asyncRequestBody)
.build()
val request =
request()
.post(asyncRequestBody)
.build()
val response = client.newCall(request).execute()
assertThat(response.body.string()).isEqualTo("Hello response!")
applicationLogs
@ -985,19 +1001,21 @@ class HttpLoggingInterceptorTest {
private val logs = mutableListOf<String>()
private var index = 0
fun assertLogEqual(expected: String) = apply {
assertThat(index, "No more messages found")
.isLessThan(logs.size)
assertThat(logs[index++]).isEqualTo(expected)
return this
}
fun assertLogEqual(expected: String) =
apply {
assertThat(index, "No more messages found")
.isLessThan(logs.size)
assertThat(logs[index++]).isEqualTo(expected)
return this
}
fun assertLogMatch(regex: Regex) = apply {
assertThat(index, "No more messages found")
.isLessThan(logs.size)
assertThat(logs[index++])
.matches(Regex(prefix.pattern + regex.pattern, RegexOption.DOT_MATCHES_ALL))
}
fun assertLogMatch(regex: Regex) =
apply {
assertThat(index, "No more messages found")
.isLessThan(logs.size)
assertThat(logs[index++])
.matches(Regex(prefix.pattern + regex.pattern, RegexOption.DOT_MATCHES_ALL))
}
fun assertNoMoreLogs() {
assertThat(logs.size, "More messages remain: ${logs.subList(index, logs.size)}")

View File

@ -50,9 +50,10 @@ class LoggingEventListenerTest {
val clientTestRule = OkHttpClientTestRule()
private lateinit var server: MockWebServer
private val handshakeCertificates = platform.localhostHandshakeCertificates()
private val logRecorder = HttpLoggingInterceptorTest.LogRecorder(
prefix = Regex("""\[\d+ ms] """)
)
private val logRecorder =
HttpLoggingInterceptorTest.LogRecorder(
prefix = Regex("""\[\d+ ms] """),
)
private val loggingEventListenerFactory = LoggingEventListener.Factory(logRecorder)
private lateinit var client: OkHttpClient
private lateinit var url: HttpUrl
@ -60,14 +61,15 @@ class LoggingEventListenerTest {
@BeforeEach
fun setUp(server: MockWebServer) {
this.server = server
client = clientTestRule.newClientBuilder()
.eventListenerFactory(loggingEventListenerFactory)
.sslSocketFactory(
handshakeCertificates.sslSocketFactory(),
handshakeCertificates.trustManager
)
.retryOnConnectionFailure(false)
.build()
client =
clientTestRule.newClientBuilder()
.eventListenerFactory(loggingEventListenerFactory)
.sslSocketFactory(
handshakeCertificates.sslSocketFactory(),
handshakeCertificates.trustManager,
)
.retryOnConnectionFailure(false)
.build()
url = server.url("/")
}
@ -78,7 +80,7 @@ class LoggingEventListenerTest {
MockResponse.Builder()
.body("Hello!")
.setHeader("Content-Type", PLAIN)
.build()
.build(),
)
val response = client.newCall(request().build()).execute()
assertThat(response.body).isNotNull()
@ -91,7 +93,11 @@ class LoggingEventListenerTest {
.assertLogMatch(Regex("""dnsEnd: \[.+]"""))
.assertLogMatch(Regex("""connectStart: ${url.host}/.+ DIRECT"""))
.assertLogMatch(Regex("""connectEnd: http/1.1"""))
.assertLogMatch(Regex("""connectionAcquired: Connection\{${url.host}:\d+, proxy=DIRECT hostAddress=${url.host}/.+ cipherSuite=none protocol=http/1\.1\}"""))
.assertLogMatch(
Regex(
"""connectionAcquired: Connection\{${url.host}:\d+, proxy=DIRECT hostAddress=${url.host}/.+ cipherSuite=none protocol=http/1\.1\}""",
),
)
.assertLogMatch(Regex("""requestHeadersStart"""))
.assertLogMatch(Regex("""requestHeadersEnd"""))
.assertLogMatch(Regex("""responseHeadersStart"""))
@ -116,7 +122,11 @@ class LoggingEventListenerTest {
.assertLogMatch(Regex("""dnsEnd: \[.+]"""))
.assertLogMatch(Regex("""connectStart: ${url.host}/.+ DIRECT"""))
.assertLogMatch(Regex("""connectEnd: http/1.1"""))
.assertLogMatch(Regex("""connectionAcquired: Connection\{${url.host}:\d+, proxy=DIRECT hostAddress=${url.host}/.+ cipherSuite=none protocol=http/1\.1\}"""))
.assertLogMatch(
Regex(
"""connectionAcquired: Connection\{${url.host}:\d+, proxy=DIRECT hostAddress=${url.host}/.+ cipherSuite=none protocol=http/1\.1\}""",
),
)
.assertLogMatch(Regex("""requestHeadersStart"""))
.assertLogMatch(Regex("""requestHeadersEnd"""))
.assertLogMatch(Regex("""requestBodyStart"""))
@ -148,9 +158,15 @@ class LoggingEventListenerTest {
.assertLogMatch(Regex("""dnsEnd: \[.+]"""))
.assertLogMatch(Regex("""connectStart: ${url.host}/.+ DIRECT"""))
.assertLogMatch(Regex("""secureConnectStart"""))
.assertLogMatch(Regex("""secureConnectEnd: Handshake\{tlsVersion=TLS_1_[23] cipherSuite=TLS_.* peerCertificates=\[CN=localhost] localCertificates=\[]\}"""))
.assertLogMatch(
Regex(
"""secureConnectEnd: Handshake\{tlsVersion=TLS_1_[23] cipherSuite=TLS_.* peerCertificates=\[CN=localhost] localCertificates=\[]\}""",
),
)
.assertLogMatch(Regex("""connectEnd: h2"""))
.assertLogMatch(Regex("""connectionAcquired: Connection\{${url.host}:\d+, proxy=DIRECT hostAddress=${url.host}/.+ cipherSuite=.+ protocol=h2\}"""))
.assertLogMatch(
Regex("""connectionAcquired: Connection\{${url.host}:\d+, proxy=DIRECT hostAddress=${url.host}/.+ cipherSuite=.+ protocol=h2\}"""),
)
.assertLogMatch(Regex("""requestHeadersStart"""))
.assertLogMatch(Regex("""requestHeadersEnd"""))
.assertLogMatch(Regex("""responseHeadersStart"""))
@ -164,10 +180,11 @@ class LoggingEventListenerTest {
@Test
fun dnsFail() {
client = OkHttpClient.Builder()
.dns { _ -> throw UnknownHostException("reason") }
.eventListenerFactory(loggingEventListenerFactory)
.build()
client =
OkHttpClient.Builder()
.dns { _ -> throw UnknownHostException("reason") }
.eventListenerFactory(loggingEventListenerFactory)
.build()
try {
client.newCall(request().build()).execute()
fail<Any>()
@ -190,7 +207,7 @@ class LoggingEventListenerTest {
server.enqueue(
MockResponse.Builder()
.socketPolicy(FailHandshake)
.build()
.build(),
)
url = server.url("/")
try {
@ -206,8 +223,16 @@ class LoggingEventListenerTest {
.assertLogMatch(Regex("""dnsEnd: \[.+]"""))
.assertLogMatch(Regex("""connectStart: ${url.host}/.+ DIRECT"""))
.assertLogMatch(Regex("""secureConnectStart"""))
.assertLogMatch(Regex("""connectFailed: null \S+(?:SSLProtocolException|SSLHandshakeException|TlsFatalAlert): (?:Unexpected handshake message: client_hello|Handshake message sequence violation, 1|Read error|Handshake failed|unexpected_message\(10\)).*"""))
.assertLogMatch(Regex("""callFailed: \S+(?:SSLProtocolException|SSLHandshakeException|TlsFatalAlert): (?:Unexpected handshake message: client_hello|Handshake message sequence violation, 1|Read error|Handshake failed|unexpected_message\(10\)).*"""))
.assertLogMatch(
Regex(
"""connectFailed: null \S+(?:SSLProtocolException|SSLHandshakeException|TlsFatalAlert): (?:Unexpected handshake message: client_hello|Handshake message sequence violation, 1|Read error|Handshake failed|unexpected_message\(10\)).*""",
),
)
.assertLogMatch(
Regex(
"""callFailed: \S+(?:SSLProtocolException|SSLHandshakeException|TlsFatalAlert): (?:Unexpected handshake message: client_hello|Handshake message sequence violation, 1|Read error|Handshake failed|unexpected_message\(10\)).*""",
),
)
.assertNoMoreLogs()
}

View File

@ -33,6 +33,9 @@ interface EventSource {
* asynchronous process to connect the socket. Once that succeeds or fails, `listener` will be
* notified. The caller must cancel the returned event source when it is no longer in use.
*/
fun newEventSource(request: Request, listener: EventSourceListener): EventSource
fun newEventSource(
request: Request,
listener: EventSourceListener,
): EventSource
}
}

View File

@ -22,13 +22,21 @@ abstract class EventSourceListener {
* Invoked when an event source has been accepted by the remote peer and may begin transmitting
* events.
*/
open fun onOpen(eventSource: EventSource, response: Response) {
open fun onOpen(
eventSource: EventSource,
response: Response,
) {
}
/**
* TODO description.
*/
open fun onEvent(eventSource: EventSource, id: String?, type: String?, data: String) {
open fun onEvent(
eventSource: EventSource,
id: String?,
type: String?,
data: String,
) {
}
/**
@ -43,6 +51,10 @@ abstract class EventSourceListener {
* Invoked when an event source has been closed due to an error reading from or writing to the
* network. Incoming events may have been lost. No further calls to this listener will be made.
*/
open fun onFailure(eventSource: EventSource, t: Throwable?, response: Response?) {
open fun onFailure(
eventSource: EventSource,
t: Throwable?,
response: Response?,
) {
}
}

View File

@ -45,7 +45,10 @@ object EventSources {
}
@JvmStatic
fun processResponse(response: Response, listener: EventSourceListener) {
fun processResponse(
response: Response,
listener: EventSourceListener,
) {
val eventSource = RealEventSource(response.request, listener)
eventSource.processResponse(response)
}

View File

@ -27,18 +27,23 @@ import okhttp3.sse.EventSourceListener
internal class RealEventSource(
private val request: Request,
private val listener: EventSourceListener
private val listener: EventSourceListener,
) : EventSource, ServerSentEventReader.Callback, Callback {
private var call: Call? = null
@Volatile private var canceled = false
fun connect(callFactory: Call.Factory) {
call = callFactory.newCall(request).apply {
enqueue(this@RealEventSource)
}
call =
callFactory.newCall(request).apply {
enqueue(this@RealEventSource)
}
}
override fun onResponse(call: Call, response: Response) {
override fun onResponse(
call: Call,
response: Response,
) {
processResponse(response)
}
@ -52,8 +57,11 @@ internal class RealEventSource(
val body = response.body
if (!body.isEventStream()) {
listener.onFailure(this,
IllegalStateException("Invalid content-type: ${body.contentType()}"), response)
listener.onFailure(
this,
IllegalStateException("Invalid content-type: ${body.contentType()}"),
response,
)
return
}
@ -71,10 +79,11 @@ internal class RealEventSource(
}
}
} catch (e: Exception) {
val exception = when {
canceled -> IOException("canceled", e)
else -> e
}
val exception =
when {
canceled -> IOException("canceled", e)
else -> e
}
listener.onFailure(this, exception, response)
return
}
@ -91,7 +100,10 @@ internal class RealEventSource(
return contentType.type == "text" && contentType.subtype == "event-stream"
}
override fun onFailure(call: Call, e: IOException) {
override fun onFailure(
call: Call,
e: IOException,
) {
listener.onFailure(this, e, null)
}
@ -102,7 +114,11 @@ internal class RealEventSource(
call?.cancel()
}
override fun onEvent(id: String?, type: String?, data: String) {
override fun onEvent(
id: String?,
type: String?,
data: String,
) {
listener.onEvent(this, id, type, data)
}

View File

@ -24,12 +24,17 @@ import okio.Options
class ServerSentEventReader(
private val source: BufferedSource,
private val callback: Callback
private val callback: Callback,
) {
private var lastId: String? = null
interface Callback {
fun onEvent(id: String?, type: String?, data: String)
fun onEvent(
id: String?,
type: String?,
data: String,
)
fun onRetryChange(timeMs: Long)
}
@ -101,7 +106,11 @@ class ServerSentEventReader(
}
@Throws(IOException::class)
private fun completeEvent(id: String?, type: String?, data: Buffer) {
private fun completeEvent(
id: String?,
type: String?,
data: Buffer,
) {
if (data.size != 0L) {
lastId = id
data.skip(1L) // Leading newline.
@ -110,35 +119,49 @@ class ServerSentEventReader(
}
companion object {
val options = Options.of(
/* 0 */ "\r\n".encodeUtf8(),
/* 1 */ "\r".encodeUtf8(),
/* 2 */ "\n".encodeUtf8(),
/* 3 */ "data: ".encodeUtf8(),
/* 4 */ "data:".encodeUtf8(),
/* 5 */ "data\r\n".encodeUtf8(),
/* 6 */ "data\r".encodeUtf8(),
/* 7 */ "data\n".encodeUtf8(),
/* 8 */ "id: ".encodeUtf8(),
/* 9 */ "id:".encodeUtf8(),
/* 10 */ "id\r\n".encodeUtf8(),
/* 11 */ "id\r".encodeUtf8(),
/* 12 */ "id\n".encodeUtf8(),
/* 13 */ "event: ".encodeUtf8(),
/* 14 */ "event:".encodeUtf8(),
/* 15 */ "event\r\n".encodeUtf8(),
/* 16 */ "event\r".encodeUtf8(),
/* 17 */ "event\n".encodeUtf8(),
/* 18 */ "retry: ".encodeUtf8(),
/* 19 */ "retry:".encodeUtf8()
)
val options =
Options.of(
// 0
"\r\n".encodeUtf8(),
// 1
"\r".encodeUtf8(),
// 2
"\n".encodeUtf8(),
// 3
"data: ".encodeUtf8(),
// 4
"data:".encodeUtf8(),
// 5
"data\r\n".encodeUtf8(),
// 6
"data\r".encodeUtf8(),
// 7
"data\n".encodeUtf8(),
// 8
"id: ".encodeUtf8(),
// 9
"id:".encodeUtf8(),
// 10
"id\r\n".encodeUtf8(),
// 11
"id\r".encodeUtf8(),
// 12
"id\n".encodeUtf8(),
// 13
"event: ".encodeUtf8(),
// 14
"event:".encodeUtf8(),
// 15
"event\r\n".encodeUtf8(),
// 16
"event\r".encodeUtf8(),
// 17
"event\n".encodeUtf8(),
// 18
"retry: ".encodeUtf8(),
// 19
"retry:".encodeUtf8(),
)
private val CRLF = "\r\n".encodeUtf8()

View File

@ -47,9 +47,10 @@ class EventSourceHttpTest {
val clientTestRule = OkHttpClientTestRule()
private val eventListener = RecordingEventListener()
private val listener = EventSourceRecorder()
private var client = clientTestRule.newClientBuilder()
.eventListenerFactory(clientTestRule.wrap(eventListener))
.build()
private var client =
clientTestRule.newClientBuilder()
.eventListenerFactory(clientTestRule.wrap(eventListener))
.build()
@BeforeEach
fun before(server: MockWebServer) {
@ -70,9 +71,9 @@ class EventSourceHttpTest {
|data: hey
|
|
""".trimMargin()
""".trimMargin(),
).setHeader("content-type", "text/event-stream")
.build()
.build(),
)
val source = newEventSource()
assertThat(source.request().url.encodedPath).isEqualTo("/")
@ -90,9 +91,9 @@ class EventSourceHttpTest {
|data: hey
|
|
""".trimMargin()
""".trimMargin(),
).setHeader("content-type", "text/event-stream")
.build()
.build(),
)
listener.enqueueCancel() // Will cancel in onOpen().
newEventSource()
@ -109,9 +110,9 @@ class EventSourceHttpTest {
|data: hey
|
|
""".trimMargin()
""".trimMargin(),
).setHeader("content-type", "text/plain")
.build()
.build(),
)
newEventSource()
listener.assertFailure("Invalid content-type: text/plain")
@ -126,11 +127,11 @@ class EventSourceHttpTest {
|data: hey
|
|
""".trimMargin()
""".trimMargin(),
)
.setHeader("content-type", "text/event-stream")
.code(401)
.build()
.build(),
)
newEventSource()
listener.assertFailure(null)
@ -138,15 +139,16 @@ class EventSourceHttpTest {
@Test
fun fullCallTimeoutDoesNotApplyOnceConnected() {
client = client.newBuilder()
.callTimeout(250, TimeUnit.MILLISECONDS)
.build()
client =
client.newBuilder()
.callTimeout(250, TimeUnit.MILLISECONDS)
.build()
server.enqueue(
MockResponse.Builder()
.bodyDelay(500, TimeUnit.MILLISECONDS)
.setHeader("content-type", "text/event-stream")
.body("data: hey\n\n")
.build()
.build(),
)
val source = newEventSource()
assertThat(source.request().url.encodedPath).isEqualTo("/")
@ -157,15 +159,16 @@ class EventSourceHttpTest {
@Test
fun fullCallTimeoutAppliesToSetup() {
client = client.newBuilder()
.callTimeout(250, TimeUnit.MILLISECONDS)
.build()
client =
client.newBuilder()
.callTimeout(250, TimeUnit.MILLISECONDS)
.build()
server.enqueue(
MockResponse.Builder()
.headersDelay(500, TimeUnit.MILLISECONDS)
.setHeader("content-type", "text/event-stream")
.body("data: hey\n\n")
.build()
.build(),
)
newEventSource()
listener.assertFailure("timeout")
@ -180,10 +183,10 @@ class EventSourceHttpTest {
|data: hey
|
|
""".trimMargin()
""".trimMargin(),
)
.setHeader("content-type", "text/event-stream")
.build()
.build(),
)
newEventSource("text/plain")
listener.assertOpen()
@ -201,9 +204,9 @@ class EventSourceHttpTest {
|data: hey
|
|
""".trimMargin()
""".trimMargin(),
).setHeader("content-type", "text/event-stream")
.build()
.build(),
)
newEventSource()
listener.assertOpen()
@ -222,9 +225,9 @@ class EventSourceHttpTest {
|data: hey
|
|
""".trimMargin()
""".trimMargin(),
).setHeader("content-type", "text/event-stream")
.build()
.build(),
)
val source = newEventSource()
assertThat(source.request().url.encodedPath).isEqualTo("/")
@ -247,13 +250,14 @@ class EventSourceHttpTest {
"ResponseBodyStart",
"ResponseBodyEnd",
"ConnectionReleased",
"CallEnd"
"CallEnd",
)
}
private fun newEventSource(accept: String? = null): EventSource {
val builder = Request.Builder()
.url(server.url("/"))
val builder =
Request.Builder()
.url(server.url("/"))
if (accept != null) {
builder.header("Accept", accept)
}

View File

@ -35,7 +35,10 @@ class EventSourceRecorder : EventSourceListener() {
cancel = true
}
override fun onOpen(eventSource: EventSource, response: Response) {
override fun onOpen(
eventSource: EventSource,
response: Response,
) {
get().log("[ES] onOpen", Platform.INFO, null)
events.add(Open(eventSource, response))
drainCancelQueue(eventSource)
@ -52,9 +55,7 @@ class EventSourceRecorder : EventSourceListener() {
drainCancelQueue(eventSource)
}
override fun onClosed(
eventSource: EventSource,
) {
override fun onClosed(eventSource: EventSource) {
get().log("[ES] onClosed", Platform.INFO, null)
events.add(Closed)
drainCancelQueue(eventSource)
@ -70,9 +71,7 @@ class EventSourceRecorder : EventSourceListener() {
drainCancelQueue(eventSource)
}
private fun drainCancelQueue(
eventSource: EventSource,
) {
private fun drainCancelQueue(eventSource: EventSource) {
if (cancel) {
cancel = false
eventSource.cancel()

View File

@ -61,13 +61,14 @@ class EventSourcesHttpTest {
|data: hey
|
|
""".trimMargin()
""".trimMargin(),
).setHeader("content-type", "text/event-stream")
.build()
.build(),
)
val request = Request.Builder()
.url(server.url("/"))
.build()
val request =
Request.Builder()
.url(server.url("/"))
.build()
val response = client.newCall(request).execute()
processResponse(response, listener)
listener.assertOpen()
@ -84,14 +85,15 @@ class EventSourcesHttpTest {
|data: hey
|
|
""".trimMargin()
""".trimMargin(),
).setHeader("content-type", "text/event-stream")
.build()
.build(),
)
listener.enqueueCancel() // Will cancel in onOpen().
val request = Request.Builder()
.url(server.url("/"))
.build()
val request =
Request.Builder()
.url(server.url("/"))
.build()
val response = client.newCall(request).execute()
processResponse(response, listener)
listener.assertOpen()

View File

@ -42,7 +42,7 @@ class ServerSentEventIteratorTest {
|data: 10
|
|
""".trimMargin()
""".trimMargin(),
)
assertThat(callbacks.remove()).isEqualTo(Event(null, null, "YHOO\n+2\n10"))
}
@ -56,7 +56,7 @@ class ServerSentEventIteratorTest {
|data: 10
|
|
""".trimMargin().replace("\n", "\r")
""".trimMargin().replace("\n", "\r"),
)
assertThat(callbacks.remove()).isEqualTo(Event(null, null, "YHOO\n+2\n10"))
}
@ -70,7 +70,7 @@ class ServerSentEventIteratorTest {
|data: 10
|
|
""".trimMargin().replace("\n", "\r\n")
""".trimMargin().replace("\n", "\r\n"),
)
assertThat(callbacks.remove()).isEqualTo(Event(null, null, "YHOO\n+2\n10"))
}
@ -89,7 +89,7 @@ class ServerSentEventIteratorTest {
|data: 113411
|
|
""".trimMargin()
""".trimMargin(),
)
assertThat(callbacks.remove()).isEqualTo(Event(null, "add", "73857293"))
assertThat(callbacks.remove()).isEqualTo(Event(null, "remove", "2153"))
@ -106,7 +106,7 @@ class ServerSentEventIteratorTest {
|id: 1
|
|
""".trimMargin()
""".trimMargin(),
)
assertThat(callbacks.remove()).isEqualTo(Event("1", null, "first event"))
}
@ -124,7 +124,7 @@ class ServerSentEventIteratorTest {
|data: third event
|
|
""".trimMargin()
""".trimMargin(),
)
assertThat(callbacks.remove()).isEqualTo(Event("1", null, "first event"))
assertThat(callbacks.remove()).isEqualTo(Event(null, null, "second event"))
@ -142,7 +142,7 @@ class ServerSentEventIteratorTest {
|
|data:
|
""".trimMargin()
""".trimMargin(),
)
assertThat(callbacks.remove()).isEqualTo(Event(null, null, ""))
assertThat(callbacks.remove()).isEqualTo(Event(null, null, "\n"))
@ -157,7 +157,7 @@ class ServerSentEventIteratorTest {
|data: test
|
|
""".trimMargin()
""".trimMargin(),
)
assertThat(callbacks.remove()).isEqualTo(Event(null, null, "test"))
assertThat(callbacks.remove()).isEqualTo(Event(null, null, "test"))
@ -170,7 +170,7 @@ class ServerSentEventIteratorTest {
|data: test
|
|
""".trimMargin()
""".trimMargin(),
)
assertThat(callbacks.remove()).isEqualTo(Event(null, null, " test"))
}
@ -188,7 +188,7 @@ class ServerSentEventIteratorTest {
|data: third event
|
|
""".trimMargin()
""".trimMargin(),
)
assertThat(callbacks.remove()).isEqualTo(Event("1", null, "first event"))
assertThat(callbacks.remove()).isEqualTo(Event("1", null, "second event"))
@ -207,7 +207,7 @@ class ServerSentEventIteratorTest {
|data: second event
|
|
""".trimMargin()
""".trimMargin(),
)
assertThat(callbacks.remove()).isEqualTo(Event("1", null, "first event"))
assertThat(callbacks.remove()).isEqualTo(Event("1", null, "second event"))
@ -223,7 +223,7 @@ class ServerSentEventIteratorTest {
|id: 1
|
|
""".trimMargin()
""".trimMargin(),
)
assertThat(callbacks.remove()).isEqualTo(22L)
assertThat(callbacks.remove()).isEqualTo(Event("1", null, "first event"))
@ -237,7 +237,7 @@ class ServerSentEventIteratorTest {
|
|retry: hey
|
""".trimMargin()
""".trimMargin(),
)
assertThat(callbacks.remove()).isEqualTo(22L)
}
@ -253,7 +253,7 @@ class ServerSentEventIteratorTest {
|retrying
|
|
""".trimMargin()
""".trimMargin(),
)
assertThat(callbacks.remove()).isEqualTo(Event(null, null, "a"))
}
@ -270,7 +270,7 @@ class ServerSentEventIteratorTest {
|data
|
|
""".trimMargin()
""".trimMargin(),
)
assertThat(callbacks.remove()).isEqualTo(Event(null, null, "c\n"))
}
@ -281,21 +281,26 @@ class ServerSentEventIteratorTest {
"""
|retry
|
""".trimMargin()
""".trimMargin(),
)
assertThat(callbacks).isEmpty()
}
private fun consumeEvents(source: String) {
val callback: ServerSentEventReader.Callback = object : ServerSentEventReader.Callback {
override fun onEvent(id: String?, type: String?, data: String) {
callbacks.add(Event(id, type, data))
}
val callback: ServerSentEventReader.Callback =
object : ServerSentEventReader.Callback {
override fun onEvent(
id: String?,
type: String?,
data: String,
) {
callbacks.add(Event(id, type, data))
}
override fun onRetryChange(timeMs: Long) {
callbacks.add(timeMs)
override fun onRetryChange(timeMs: Long) {
callbacks.add(timeMs)
}
}
}
val buffer = Buffer().writeUtf8(source)
val reader = ServerSentEventReader(buffer, callback)
while (reader.processNextEvent()) {

View File

@ -14,6 +14,7 @@
* limitations under the License.
*/
@file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")
package okhttp3
import java.io.IOException
@ -37,40 +38,38 @@ sealed class CallEvent {
data class ProxySelectStart(
override val timestampNs: Long,
override val call: Call,
val url: HttpUrl
val url: HttpUrl,
) : CallEvent()
data class ProxySelectEnd(
override val timestampNs: Long,
override val call: Call,
val url: HttpUrl,
val proxies: List<Proxy>?
val proxies: List<Proxy>?,
) : CallEvent() {
override fun closes(event: CallEvent): Boolean =
event is ProxySelectStart && call == event.call && url == event.url
override fun closes(event: CallEvent): Boolean = event is ProxySelectStart && call == event.call && url == event.url
}
data class DnsStart(
override val timestampNs: Long,
override val call: Call,
val domainName: String
val domainName: String,
) : CallEvent()
data class DnsEnd(
override val timestampNs: Long,
override val call: Call,
val domainName: String,
val inetAddressList: List<InetAddress>
val inetAddressList: List<InetAddress>,
) : CallEvent() {
override fun closes(event: CallEvent): Boolean =
event is DnsStart && call == event.call && domainName == event.domainName
override fun closes(event: CallEvent): Boolean = event is DnsStart && call == event.call && domainName == event.domainName
}
data class ConnectStart(
override val timestampNs: Long,
override val call: Call,
val inetSocketAddress: InetSocketAddress,
val proxy: Proxy?
val proxy: Proxy?,
) : CallEvent()
data class ConnectEnd(
@ -78,7 +77,7 @@ sealed class CallEvent {
override val call: Call,
val inetSocketAddress: InetSocketAddress,
val proxy: Proxy?,
val protocol: Protocol?
val protocol: Protocol?,
) : CallEvent() {
override fun closes(event: CallEvent): Boolean =
event is ConnectStart && call == event.call && inetSocketAddress == event.inetSocketAddress && proxy == event.proxy
@ -90,7 +89,7 @@ sealed class CallEvent {
val inetSocketAddress: InetSocketAddress,
val proxy: Proxy,
val protocol: Protocol?,
val ioe: IOException
val ioe: IOException,
) : CallEvent() {
override fun closes(event: CallEvent): Boolean =
event is ConnectStart && call == event.call && inetSocketAddress == event.inetSocketAddress && proxy == event.proxy
@ -98,149 +97,139 @@ sealed class CallEvent {
data class SecureConnectStart(
override val timestampNs: Long,
override val call: Call
override val call: Call,
) : CallEvent()
data class SecureConnectEnd(
override val timestampNs: Long,
override val call: Call,
val handshake: Handshake?
val handshake: Handshake?,
) : CallEvent() {
override fun closes(event: CallEvent): Boolean =
event is SecureConnectStart && call == event.call
override fun closes(event: CallEvent): Boolean = event is SecureConnectStart && call == event.call
}
data class ConnectionAcquired(
override val timestampNs: Long,
override val call: Call,
val connection: Connection
val connection: Connection,
) : CallEvent()
data class ConnectionReleased(
override val timestampNs: Long,
override val call: Call,
val connection: Connection
val connection: Connection,
) : CallEvent() {
override fun closes(event: CallEvent): Boolean =
event is ConnectionAcquired && call == event.call && connection == event.connection
override fun closes(event: CallEvent): Boolean = event is ConnectionAcquired && call == event.call && connection == event.connection
}
data class CallStart(
override val timestampNs: Long,
override val call: Call
override val call: Call,
) : CallEvent()
data class CallEnd(
override val timestampNs: Long,
override val call: Call
override val call: Call,
) : CallEvent() {
override fun closes(event: CallEvent): Boolean =
event is CallStart && call == event.call
override fun closes(event: CallEvent): Boolean = event is CallStart && call == event.call
}
data class CallFailed(
override val timestampNs: Long,
override val call: Call,
val ioe: IOException
val ioe: IOException,
) : CallEvent() {
override fun closes(event: CallEvent): Boolean =
event is CallStart && call == event.call
override fun closes(event: CallEvent): Boolean = event is CallStart && call == event.call
}
data class Canceled(
override val timestampNs: Long,
override val call: Call
override val call: Call,
) : CallEvent()
data class RequestHeadersStart(
override val timestampNs: Long,
override val call: Call
override val call: Call,
) : CallEvent()
data class RequestHeadersEnd(
override val timestampNs: Long,
override val call: Call,
val headerLength: Long
val headerLength: Long,
) : CallEvent() {
override fun closes(event: CallEvent): Boolean =
event is RequestHeadersStart && call == event.call
override fun closes(event: CallEvent): Boolean = event is RequestHeadersStart && call == event.call
}
data class RequestBodyStart(
override val timestampNs: Long,
override val call: Call
override val call: Call,
) : CallEvent()
data class RequestBodyEnd(
override val timestampNs: Long,
override val call: Call,
val bytesWritten: Long
val bytesWritten: Long,
) : CallEvent() {
override fun closes(event: CallEvent): Boolean =
event is RequestBodyStart && call == event.call
override fun closes(event: CallEvent): Boolean = event is RequestBodyStart && call == event.call
}
data class RequestFailed(
override val timestampNs: Long,
override val call: Call,
val ioe: IOException
val ioe: IOException,
) : CallEvent() {
override fun closes(event: CallEvent): Boolean =
event is RequestHeadersStart && call == event.call
override fun closes(event: CallEvent): Boolean = event is RequestHeadersStart && call == event.call
}
data class ResponseHeadersStart(
override val timestampNs: Long,
override val call: Call
override val call: Call,
) : CallEvent()
data class ResponseHeadersEnd(
override val timestampNs: Long,
override val call: Call,
val headerLength: Long
val headerLength: Long,
) : CallEvent() {
override fun closes(event: CallEvent): Boolean =
event is ResponseHeadersStart && call == event.call
override fun closes(event: CallEvent): Boolean = event is ResponseHeadersStart && call == event.call
}
data class ResponseBodyStart(
override val timestampNs: Long,
override val call: Call
override val call: Call,
) : CallEvent()
data class ResponseBodyEnd(
override val timestampNs: Long,
override val call: Call,
val bytesRead: Long
val bytesRead: Long,
) : CallEvent() {
override fun closes(event: CallEvent): Boolean =
event is ResponseBodyStart && call == event.call
override fun closes(event: CallEvent): Boolean = event is ResponseBodyStart && call == event.call
}
data class ResponseFailed(
override val timestampNs: Long,
override val call: Call,
val ioe: IOException
val ioe: IOException,
) : CallEvent()
data class SatisfactionFailure(
override val timestampNs: Long,
override val call: Call
override val call: Call,
) : CallEvent()
data class CacheHit(
override val timestampNs: Long,
override val call: Call
override val call: Call,
) : CallEvent()
data class CacheMiss(
override val timestampNs: Long,
override val call: Call
override val call: Call,
) : CallEvent()
data class CacheConditionalHit(
override val timestampNs: Long,
override val call: Call
override val call: Call,
) : CallEvent()
}

View File

@ -23,9 +23,9 @@ import java.util.concurrent.TimeUnit
class ClientRuleEventListener(
val delegate: EventListener = NONE,
var logger: (String) -> Unit
var logger: (String) -> Unit,
) : EventListener(),
EventListener.Factory {
EventListener.Factory {
private var startNs: Long? = null
override fun create(call: Call): EventListener = this
@ -40,7 +40,7 @@ class ClientRuleEventListener(
override fun proxySelectStart(
call: Call,
url: HttpUrl
url: HttpUrl,
) {
logWithTime("proxySelectStart: $url")
@ -50,7 +50,7 @@ class ClientRuleEventListener(
override fun proxySelectEnd(
call: Call,
url: HttpUrl,
proxies: List<Proxy>
proxies: List<Proxy>,
) {
logWithTime("proxySelectEnd: $proxies")
@ -59,7 +59,7 @@ class ClientRuleEventListener(
override fun dnsStart(
call: Call,
domainName: String
domainName: String,
) {
logWithTime("dnsStart: $domainName")
@ -69,7 +69,7 @@ class ClientRuleEventListener(
override fun dnsEnd(
call: Call,
domainName: String,
inetAddressList: List<InetAddress>
inetAddressList: List<InetAddress>,
) {
logWithTime("dnsEnd: $inetAddressList")
@ -79,7 +79,7 @@ class ClientRuleEventListener(
override fun connectStart(
call: Call,
inetSocketAddress: InetSocketAddress,
proxy: Proxy
proxy: Proxy,
) {
logWithTime("connectStart: $inetSocketAddress $proxy")
@ -94,7 +94,7 @@ class ClientRuleEventListener(
override fun secureConnectEnd(
call: Call,
handshake: Handshake?
handshake: Handshake?,
) {
logWithTime("secureConnectEnd: $handshake")
@ -105,7 +105,7 @@ class ClientRuleEventListener(
call: Call,
inetSocketAddress: InetSocketAddress,
proxy: Proxy,
protocol: Protocol?
protocol: Protocol?,
) {
logWithTime("connectEnd: $protocol")
@ -117,7 +117,7 @@ class ClientRuleEventListener(
inetSocketAddress: InetSocketAddress,
proxy: Proxy,
protocol: Protocol?,
ioe: IOException
ioe: IOException,
) {
logWithTime("connectFailed: $protocol $ioe")
@ -126,7 +126,7 @@ class ClientRuleEventListener(
override fun connectionAcquired(
call: Call,
connection: Connection
connection: Connection,
) {
logWithTime("connectionAcquired: $connection")
@ -135,7 +135,7 @@ class ClientRuleEventListener(
override fun connectionReleased(
call: Call,
connection: Connection
connection: Connection,
) {
logWithTime("connectionReleased")
@ -150,7 +150,7 @@ class ClientRuleEventListener(
override fun requestHeadersEnd(
call: Call,
request: Request
request: Request,
) {
logWithTime("requestHeadersEnd")
@ -165,7 +165,7 @@ class ClientRuleEventListener(
override fun requestBodyEnd(
call: Call,
byteCount: Long
byteCount: Long,
) {
logWithTime("requestBodyEnd: byteCount=$byteCount")
@ -174,7 +174,7 @@ class ClientRuleEventListener(
override fun requestFailed(
call: Call,
ioe: IOException
ioe: IOException,
) {
logWithTime("requestFailed: $ioe")
@ -189,7 +189,7 @@ class ClientRuleEventListener(
override fun responseHeadersEnd(
call: Call,
response: Response
response: Response,
) {
logWithTime("responseHeadersEnd: $response")
@ -204,7 +204,7 @@ class ClientRuleEventListener(
override fun responseBodyEnd(
call: Call,
byteCount: Long
byteCount: Long,
) {
logWithTime("responseBodyEnd: byteCount=$byteCount")
@ -213,7 +213,7 @@ class ClientRuleEventListener(
override fun responseFailed(
call: Call,
ioe: IOException
ioe: IOException,
) {
logWithTime("responseFailed: $ioe")
@ -228,7 +228,7 @@ class ClientRuleEventListener(
override fun callFailed(
call: Call,
ioe: IOException
ioe: IOException,
) {
logWithTime("callFailed: $ioe")
@ -241,7 +241,10 @@ class ClientRuleEventListener(
delegate.canceled(call)
}
override fun satisfactionFailure(call: Call, response: Response) {
override fun satisfactionFailure(
call: Call,
response: Response,
) {
logWithTime("satisfactionFailure")
delegate.satisfactionFailure(call, response)
@ -253,13 +256,19 @@ class ClientRuleEventListener(
delegate.cacheMiss(call)
}
override fun cacheHit(call: Call, response: Response) {
override fun cacheHit(
call: Call,
response: Response,
) {
logWithTime("cacheHit")
delegate.cacheHit(call, response)
}
override fun cacheConditionalHit(call: Call, cachedResponse: Response) {
override fun cacheConditionalHit(
call: Call,
cachedResponse: Response,
) {
logWithTime("cacheConditionalHit")
delegate.cacheConditionalHit(call, cachedResponse)
@ -267,12 +276,13 @@ class ClientRuleEventListener(
private fun logWithTime(message: String) {
val startNs = startNs
val timeMs = if (startNs == null) {
// Event occurred before start, for an example an early cancel.
0L
} else {
TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNs)
}
val timeMs =
if (startNs == null) {
// Event occurred before start, for an example an early cancel.
0L
} else {
TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNs)
}
logger.invoke("[$timeMs ms] $message")
}

View File

@ -14,6 +14,7 @@
* limitations under the License.
*/
@file:Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE")
package okhttp3
import java.io.IOException
@ -35,17 +36,16 @@ sealed class ConnectionEvent {
data class ConnectStart(
override val timestampNs: Long,
val route: Route,
val call: Call
val call: Call,
) : ConnectionEvent()
data class ConnectFailed(
override val timestampNs: Long,
val route: Route,
val call: Call,
val exception: IOException
val exception: IOException,
) : ConnectionEvent() {
override fun closes(event: ConnectionEvent): Boolean =
event is ConnectStart && call == event.call && route == event.route
override fun closes(event: ConnectionEvent): Boolean = event is ConnectStart && call == event.call && route == event.route
}
data class ConnectEnd(
@ -54,8 +54,7 @@ sealed class ConnectionEvent {
val route: Route,
val call: Call,
) : ConnectionEvent() {
override fun closes(event: ConnectionEvent): Boolean =
event is ConnectStart && call == event.call && route == event.route
override fun closes(event: ConnectionEvent): Boolean = event is ConnectStart && call == event.call && route == event.route
}
data class ConnectionClosed(
@ -66,15 +65,14 @@ sealed class ConnectionEvent {
data class ConnectionAcquired(
override val timestampNs: Long,
override val connection: Connection,
val call: Call
val call: Call,
) : ConnectionEvent()
data class ConnectionReleased(
override val timestampNs: Long,
override val connection: Connection,
val call: Call
val call: Call,
) : ConnectionEvent() {
override fun closes(event: ConnectionEvent): Boolean =
event is ConnectionAcquired && connection == event.connection && call == event.call
}

View File

@ -24,7 +24,6 @@ import javax.security.cert.X509Certificate
/** An [SSLSession] that delegates all calls. */
abstract class DelegatingSSLSession(protected val delegate: SSLSession?) : SSLSession {
override fun getId(): ByteArray {
return delegate!!.id
}
@ -49,7 +48,10 @@ abstract class DelegatingSSLSession(protected val delegate: SSLSession?) : SSLSe
return delegate!!.isValid
}
override fun putValue(s: String, o: Any) {
override fun putValue(
s: String,
o: Any,
) {
delegate!!.putValue(s, o)
}

View File

@ -198,7 +198,10 @@ abstract class DelegatingSSLSocket(protected val delegate: SSLSocket?) : SSLSock
}
@Throws(SocketException::class)
override fun setSoLinger(on: Boolean, timeout: Int) {
override fun setSoLinger(
on: Boolean,
timeout: Int,
) {
delegate!!.setSoLinger(on, timeout)
}
@ -247,7 +250,10 @@ abstract class DelegatingSSLSocket(protected val delegate: SSLSocket?) : SSLSock
}
@Throws(IOException::class)
override fun connect(remoteAddr: SocketAddress, timeout: Int) {
override fun connect(
remoteAddr: SocketAddress,
timeout: Int,
) {
delegate!!.connect(remoteAddr, timeout)
}

View File

@ -33,28 +33,40 @@ open class DelegatingSSLSocketFactory(private val delegate: SSLSocketFactory) :
}
@Throws(IOException::class)
override fun createSocket(host: String, port: Int): SSLSocket {
override fun createSocket(
host: String,
port: Int,
): SSLSocket {
val sslSocket = delegate.createSocket(host, port) as SSLSocket
return configureSocket(sslSocket)
}
@Throws(IOException::class)
override fun createSocket(
host: String, port: Int, localAddress: InetAddress, localPort: Int
host: String,
port: Int,
localAddress: InetAddress,
localPort: Int,
): SSLSocket {
val sslSocket = delegate.createSocket(host, port, localAddress, localPort) as SSLSocket
return configureSocket(sslSocket)
}
@Throws(IOException::class)
override fun createSocket(host: InetAddress, port: Int): SSLSocket {
override fun createSocket(
host: InetAddress,
port: Int,
): SSLSocket {
val sslSocket = delegate.createSocket(host, port) as SSLSocket
return configureSocket(sslSocket)
}
@Throws(IOException::class)
override fun createSocket(
host: InetAddress, port: Int, localAddress: InetAddress, localPort: Int
host: InetAddress,
port: Int,
localAddress: InetAddress,
localPort: Int,
): SSLSocket {
val sslSocket = delegate.createSocket(host, port, localAddress, localPort) as SSLSocket
return configureSocket(sslSocket)
@ -70,7 +82,10 @@ open class DelegatingSSLSocketFactory(private val delegate: SSLSocketFactory) :
@Throws(IOException::class)
override fun createSocket(
socket: Socket, host: String, port: Int, autoClose: Boolean
socket: Socket,
host: String,
port: Int,
autoClose: Boolean,
): SSLSocket {
val sslSocket = delegate.createSocket(socket, host, port, autoClose) as SSLSocket
return configureSocket(sslSocket)

View File

@ -38,13 +38,20 @@ open class DelegatingServerSocketFactory(private val delegate: ServerSocketFacto
}
@Throws(IOException::class)
override fun createServerSocket(port: Int, backlog: Int): ServerSocket {
override fun createServerSocket(
port: Int,
backlog: Int,
): ServerSocket {
val serverSocket = delegate.createServerSocket(port, backlog)
return configureServerSocket(serverSocket)
}
@Throws(IOException::class)
override fun createServerSocket(port: Int, backlog: Int, ifAddress: InetAddress): ServerSocket {
override fun createServerSocket(
port: Int,
backlog: Int,
ifAddress: InetAddress,
): ServerSocket {
val serverSocket = delegate.createServerSocket(port, backlog, ifAddress)
return configureServerSocket(serverSocket)
}

View File

@ -32,30 +32,40 @@ open class DelegatingSocketFactory(private val delegate: SocketFactory) : Socket
}
@Throws(IOException::class)
override fun createSocket(host: String, port: Int): Socket {
override fun createSocket(
host: String,
port: Int,
): Socket {
val socket = delegate.createSocket(host, port)
return configureSocket(socket)
}
@Throws(IOException::class)
override fun createSocket(
host: String, port: Int, localAddress: InetAddress,
localPort: Int
host: String,
port: Int,
localAddress: InetAddress,
localPort: Int,
): Socket {
val socket = delegate.createSocket(host, port, localAddress, localPort)
return configureSocket(socket)
}
@Throws(IOException::class)
override fun createSocket(host: InetAddress, port: Int): Socket {
override fun createSocket(
host: InetAddress,
port: Int,
): Socket {
val socket = delegate.createSocket(host, port)
return configureSocket(socket)
}
@Throws(IOException::class)
override fun createSocket(
host: InetAddress, port: Int, localAddress: InetAddress,
localPort: Int
host: InetAddress,
port: Int,
localAddress: InetAddress,
localPort: Int,
): Socket {
val socket = delegate.createSocket(host, port, localAddress, localPort)
return configureSocket(socket)

View File

@ -29,7 +29,7 @@ class FakeDns : Dns {
/** Sets the results for `hostname`. */
operator fun set(
hostname: String,
addresses: List<InetAddress>
addresses: List<InetAddress>,
): FakeDns {
hostAddresses[hostname] = addresses
return this
@ -41,9 +41,10 @@ class FakeDns : Dns {
return this
}
@Throws(UnknownHostException::class) fun lookup(
@Throws(UnknownHostException::class)
fun lookup(
hostname: String,
index: Int
index: Int,
): InetAddress {
return hostAddresses[hostname]!![index]
}
@ -66,7 +67,7 @@ class FakeDns : Dns {
return (from until nextAddress)
.map {
return@map InetAddress.getByAddress(
Buffer().writeInt(it.toInt()).readByteArray()
Buffer().writeInt(it.toInt()).readByteArray(),
)
}
}
@ -78,7 +79,7 @@ class FakeDns : Dns {
return (from until nextAddress)
.map {
return@map InetAddress.getByAddress(
Buffer().writeLong(0L).writeLong(it).readByteArray()
Buffer().writeLong(0L).writeLong(it).readByteArray(),
)
}
}

View File

@ -37,7 +37,7 @@ class FakeProxySelector : ProxySelector() {
override fun connectFailed(
uri: URI,
sa: SocketAddress,
ioe: IOException
ioe: IOException,
) {
}
}

View File

@ -67,7 +67,7 @@ class FakeSSLSession(vararg val certificates: Certificate) : SSLSession {
}
@Throws(
SSLPeerUnverifiedException::class
SSLPeerUnverifiedException::class,
)
override fun getPeerCertificateChain(): Array<X509Certificate> {
throw UnsupportedOperationException()
@ -96,7 +96,7 @@ class FakeSSLSession(vararg val certificates: Certificate) : SSLSession {
override fun putValue(
s: String,
obj: Any
obj: Any,
) {
throw UnsupportedOperationException()
}

View File

@ -20,6 +20,7 @@ import okio.BufferedSink
open class ForwardingRequestBody(delegate: RequestBody?) : RequestBody() {
private val delegate: RequestBody
fun delegate(): RequestBody {
return delegate
}
@ -28,7 +29,8 @@ open class ForwardingRequestBody(delegate: RequestBody?) : RequestBody() {
return delegate.contentType()
}
@Throws(IOException::class) override fun contentLength(): Long {
@Throws(IOException::class)
override fun contentLength(): Long {
return delegate.contentLength()
}

View File

@ -19,6 +19,7 @@ import okio.BufferedSource
open class ForwardingResponseBody(delegate: ResponseBody?) : ResponseBody() {
private val delegate: ResponseBody
fun delegate(): ResponseBody {
return delegate
}

View File

@ -22,25 +22,30 @@ import java.util.logging.LogRecord
object JsseDebugLogging {
data class JsseDebugMessage(val message: String, val param: String?) {
enum class Type {
Handshake, Plaintext, Encrypted, Setup, Unknown
Handshake,
Plaintext,
Encrypted,
Setup,
Unknown,
}
val type: Type
get() = when {
message == "adding as trusted certificates" -> Type.Setup
message == "Raw read" || message == "Raw write" -> Type.Encrypted
message == "Plaintext before ENCRYPTION" || message == "Plaintext after DECRYPTION" -> Type.Plaintext
message.startsWith("System property ") -> Type.Setup
message.startsWith("Reload ") -> Type.Setup
message == "No session to resume." -> Type.Handshake
message.startsWith("Consuming ") -> Type.Handshake
message.startsWith("Produced ") -> Type.Handshake
message.startsWith("Negotiated ") -> Type.Handshake
message.startsWith("Found resumable session") -> Type.Handshake
message.startsWith("Resuming session") -> Type.Handshake
message.startsWith("Using PSK to derive early secret") -> Type.Handshake
else -> Type.Unknown
}
get() =
when {
message == "adding as trusted certificates" -> Type.Setup
message == "Raw read" || message == "Raw write" -> Type.Encrypted
message == "Plaintext before ENCRYPTION" || message == "Plaintext after DECRYPTION" -> Type.Plaintext
message.startsWith("System property ") -> Type.Setup
message.startsWith("Reload ") -> Type.Setup
message == "No session to resume." -> Type.Handshake
message.startsWith("Consuming ") -> Type.Handshake
message.startsWith("Produced ") -> Type.Handshake
message.startsWith("Negotiated ") -> Type.Handshake
message.startsWith("Found resumable session") -> Type.Handshake
message.startsWith("Resuming session") -> Type.Handshake
message.startsWith("Using PSK to derive early secret") -> Type.Handshake
else -> Type.Unknown
}
override fun toString(): String {
return if (param != null) {
@ -66,17 +71,20 @@ object JsseDebugLogging {
fun enableJsseDebugLogging(debugHandler: (JsseDebugMessage) -> Unit = this::quietDebug): Closeable {
System.setProperty("javax.net.debug", "")
return OkHttpDebugLogging.enable("javax.net.ssl", object : Handler() {
override fun publish(record: LogRecord) {
val param = record.parameters?.firstOrNull() as? String
debugHandler(JsseDebugMessage(record.message, param))
}
return OkHttpDebugLogging.enable(
"javax.net.ssl",
object : Handler() {
override fun publish(record: LogRecord) {
val param = record.parameters?.firstOrNull() as? String
debugHandler(JsseDebugMessage(record.message, param))
}
override fun flush() {
}
override fun flush() {
}
override fun close() {
}
})
override fun close() {
}
},
)
}
}
}

Some files were not shown because too many files have changed in this diff Show More