1
0
mirror of https://github.com/square/okhttp.git synced 2025-11-24 18:41:06 +03:00

Adopt idiomatic Kotlin in MockWebServer

This commit is contained in:
Jesse Wilson
2019-06-01 12:34:32 -04:00
parent 327763a642
commit 080a0a38ea
3 changed files with 136 additions and 83 deletions

View File

@@ -276,16 +276,16 @@ class MockResponse : Cloneable {
unit.convert(headersDelayAmount, headersDelayUnit)
/**
* When [protocols][MockWebServer.setProtocols] include [HTTP_2][okhttp3.Protocol],
* this attaches a pushed stream to this response.
* When [protocols][MockWebServer.protocols] include [HTTP_2][okhttp3.Protocol], this attaches a
* pushed stream to this response.
*/
fun withPush(promise: PushPromise) = apply {
promises.add(promise)
}
/**
* When [protocols][MockWebServer.setProtocols] include [HTTP_2][okhttp3.Protocol],
* this pushes [settings] before writing the response.
* When [protocols][MockWebServer.protocols] include [HTTP_2][okhttp3.Protocol], this pushes
* [settings] before writing the response.
*/
fun withSettings(settings: Settings) = apply {
this.settings = settings

View File

@@ -104,29 +104,52 @@ class MockWebServer : ExternalResource(), Closeable {
Collections.newSetFromMap(ConcurrentHashMap<Socket, Boolean>())
private val openConnections =
Collections.newSetFromMap(ConcurrentHashMap<Http2Connection, Boolean>())
private val requestCount = AtomicInteger()
private var bodyLimit = Long.MAX_VALUE
private var serverSocketFactory: ServerSocketFactory? = null
private val atomicRequestCount = AtomicInteger()
/**
* The number of HTTP requests received thus far by this server. This may exceed the number of
* HTTP connections when connection reuse is in practice.
*/
val requestCount: Int
get() = atomicRequestCount.get()
/** The number of bytes of the POST body to keep in memory to the given limit. */
var bodyLimit = Long.MAX_VALUE
var serverSocketFactory: ServerSocketFactory? = null
get() {
if (field == null && started) {
field = ServerSocketFactory.getDefault() // Build the default value lazily.
}
return field
}
set(value) {
check(!started) { "serverSocketFactory must not be set after start()" }
field = value
}
private var serverSocket: ServerSocket? = null
private var sslSocketFactory: SSLSocketFactory? = null
private var executor: ExecutorService? = null
private var tunnelProxy: Boolean = false
private var clientAuth = CLIENT_AUTH_NONE
/**
* Returns the dispatcher used to respond to HTTP requests.
* The default dispatcher is a [QueueDispatcher] but other dispatchers can be configured.
* The dispatcher used to respond to HTTP requests. The default dispatcher is a [QueueDispatcher],
* which serves a fixed sequence of responses from a [queue][enqueue].
*
* Sets the dispatcher used to match incoming requests to mock responses. The default dispatcher
* simply serves a fixed sequence of responses from a [queue][enqueue]; custom
* dispatchers can vary the response based on timing or the content of the request.
* Other dispatchers can be configured. They can vary the response based on timing or the content
* of the request.
*/
var dispatcher: Dispatcher = QueueDispatcher()
private var port = -1
private var inetSocketAddress: InetSocketAddress? = null
private var protocolNegotiationEnabled = true
private var protocols = immutableListOf(Protocol.HTTP_2, Protocol.HTTP_1_1)
private var started: Boolean = false
private var portField: Int = -1
val port: Int
get() {
before()
return portField
}
val hostName: String
get() {
@@ -134,6 +157,37 @@ class MockWebServer : ExternalResource(), Closeable {
return inetSocketAddress!!.address.canonicalHostName
}
private var inetSocketAddress: InetSocketAddress? = null
/**
* True if ALPN is used on incoming HTTPS connections to negotiate a protocol like HTTP/1.1 or
* HTTP/2. This is true by default; set to false to disable negotiation and restrict connections
* to HTTP/1.1.
*/
var protocolNegotiationEnabled = true
/**
* The protocols supported by ALPN on incoming HTTPS connections in order of preference. The list
* must contain [Protocol.HTTP_1_1]. It must not contain null.
*
* This list is ignored when [negotiation is disabled][protocolNegotiationEnabled].
*/
@get:JvmName("protocols") var protocols: List<Protocol> =
immutableListOf(Protocol.HTTP_2, Protocol.HTTP_1_1)
set(value) {
val protocolList = value.toImmutableList()
require(Protocol.H2_PRIOR_KNOWLEDGE !in protocolList || protocolList.size == 1) {
"protocols containing h2_prior_knowledge cannot use other protocols: $protocolList"
}
require(Protocol.HTTP_1_1 in protocolList || Protocol.H2_PRIOR_KNOWLEDGE in protocolList) {
"protocols doesn't contain http/1.1: $protocolList"
}
require(null !in protocolList as List<Protocol?>) { "protocols must not contain null" }
field = protocolList
}
private var started: Boolean = false
@Synchronized override fun before() {
if (started) return
try {
@@ -143,20 +197,27 @@ class MockWebServer : ExternalResource(), Closeable {
}
}
fun getPort(): Int {
before()
return port
}
@JvmName("-deprecated_port")
@Deprecated(
message = "moved to val",
replaceWith = ReplaceWith(expression = "port"),
level = DeprecationLevel.WARNING)
fun getPort(): Int = port
fun toProxyAddress(): Proxy {
before()
val address = InetSocketAddress(inetSocketAddress!!.address
.canonicalHostName, port)
val address = InetSocketAddress(inetSocketAddress!!.address.canonicalHostName, port)
return Proxy(Proxy.Type.HTTP, address)
}
fun setServerSocketFactory(serverSocketFactory: ServerSocketFactory) {
check(executor == null) { "setServerSocketFactory() must be called before start()" }
@JvmName("-deprecated_serverSocketFactory")
@Deprecated(
message = "moved to var",
replaceWith = ReplaceWith(
expression = "run { this.serverSocketFactory = serverSocketFactory }"
),
level = DeprecationLevel.WARNING)
fun setServerSocketFactory(serverSocketFactory: ServerSocketFactory) = run {
this.serverSocketFactory = serverSocketFactory
}
@@ -174,40 +235,38 @@ class MockWebServer : ExternalResource(), Closeable {
.resolve(path)!!
}
/**
* Sets the number of bytes of the POST body to keep in memory to the given limit.
*/
fun setBodyLimit(maxBodyLength: Long) {
this.bodyLimit = maxBodyLength
}
@JvmName("-deprecated_bodyLimit")
@Deprecated(
message = "moved to var",
replaceWith = ReplaceWith(
expression = "run { this.bodyLimit = bodyLimit }"
),
level = DeprecationLevel.WARNING)
fun setBodyLimit(bodyLimit: Long) = run { this.bodyLimit = bodyLimit }
/**
* Sets whether ALPN is used on incoming HTTPS connections to negotiate a protocol like HTTP/1.1
* or HTTP/2. Call this method to disable negotiation and restrict connections to HTTP/1.1.
*/
fun setProtocolNegotiationEnabled(protocolNegotiationEnabled: Boolean) {
@JvmName("-deprecated_protocolNegotiationEnabled")
@Deprecated(
message = "moved to var",
replaceWith = ReplaceWith(
expression = "run { this.protocolNegotiationEnabled = protocolNegotiationEnabled }"
),
level = DeprecationLevel.WARNING)
fun setProtocolNegotiationEnabled(protocolNegotiationEnabled: Boolean) = run {
this.protocolNegotiationEnabled = protocolNegotiationEnabled
}
/**
* Indicates the protocols supported by ALPN on incoming HTTPS connections. This list is ignored
* when [negotiation is disabled][setProtocolNegotiationEnabled].
*
* @param protocols the protocols to use, in order of preference. The list must contain
* [Protocol.HTTP_1_1]. It must not contain null.
*/
fun setProtocols(protocols: List<Protocol>) {
val protocolList = protocols.toImmutableList()
require(Protocol.H2_PRIOR_KNOWLEDGE !in protocolList || protocolList.size == 1) {
"protocols containing h2_prior_knowledge cannot use other protocols: $protocolList"
}
require(Protocol.HTTP_1_1 in protocolList || Protocol.H2_PRIOR_KNOWLEDGE in protocolList) {
"protocols doesn't contain http/1.1: $protocolList"
}
require(null !in protocolList as List<Protocol?>) { "protocols must not contain null" }
this.protocols = protocolList
}
@JvmName("-deprecated_protocols")
@Deprecated(
message = "moved to var",
replaceWith = ReplaceWith(expression = "run { this.protocols = protocols }"),
level = DeprecationLevel.WARNING)
fun setProtocols(protocols: List<Protocol>) = run { this.protocols = protocols }
@JvmName("-deprecated_protocols")
@Deprecated(
message = "moved to var",
replaceWith = ReplaceWith(expression = "protocols"),
level = DeprecationLevel.WARNING)
fun protocols(): List<Protocol> = protocols
/**
@@ -273,11 +332,12 @@ class MockWebServer : ExternalResource(), Closeable {
fun takeRequest(timeout: Long, unit: TimeUnit): RecordedRequest? =
requestQueue.poll(timeout, unit)
/**
* Returns the number of HTTP requests received thus far by this server. This may exceed the
* number of HTTP connections when connection reuse is in practice.
*/
fun getRequestCount(): Int = requestCount.get()
@JvmName("-deprecated_requestCount")
@Deprecated(
message = "moved to val",
replaceWith = ReplaceWith(expression = "requestCount"),
level = DeprecationLevel.WARNING)
fun getRequestCount(): Int = requestCount
/**
* Scripts [response] to be returned to a request made in sequence. The first request is
@@ -322,17 +382,14 @@ class MockWebServer : ExternalResource(), Closeable {
executor = Executors.newCachedThreadPool(threadFactory("MockWebServer", false))
this.inetSocketAddress = inetSocketAddress
if (serverSocketFactory == null) {
serverSocketFactory = ServerSocketFactory.getDefault()
}
serverSocket = serverSocketFactory!!.createServerSocket()
// Reuse if the user specified a port
serverSocket!!.reuseAddress = inetSocketAddress.port != 0
serverSocket!!.bind(inetSocketAddress, 50)
port = serverSocket!!.localPort
executor!!.execute("MockWebServer $port") {
portField = serverSocket!!.localPort
executor!!.execute("MockWebServer $portField") {
try {
logger.info("${this@MockWebServer} starting to accept connections")
acceptConnections()
@@ -421,7 +478,7 @@ class MockWebServer : ExternalResource(), Closeable {
}
}
inner class SocketHandler(private val raw: Socket) {
internal inner class SocketHandler(private val raw: Socket) {
private var sequenceNumber = 0
@Throws(Exception::class)
@@ -531,7 +588,7 @@ class MockWebServer : ExternalResource(), Closeable {
): Boolean {
val request = readRequest(socket, source, sink, sequenceNumber) ?: return false
requestCount.incrementAndGet()
atomicRequestCount.incrementAndGet()
requestQueue.add(request)
val response = dispatcher.dispatch(request)
@@ -571,7 +628,8 @@ class MockWebServer : ExternalResource(), Closeable {
SHUTDOWN_INPUT_AT_END -> socket.shutdownInput()
SHUTDOWN_OUTPUT_AT_END -> socket.shutdownOutput()
SHUTDOWN_SERVER_AFTER_RESPONSE -> shutdown()
else -> {}
else -> {
}
}
sequenceNumber++
return reuseSocket
@@ -597,7 +655,7 @@ class MockWebServer : ExternalResource(), Closeable {
private fun dispatchBookkeepingRequest(sequenceNumber: Int, socket: Socket) {
val request = RecordedRequest(
"", headersOf(), emptyList(), 0L, Buffer(), sequenceNumber, socket)
requestCount.incrementAndGet()
atomicRequestCount.incrementAndGet()
requestQueue.add(request)
dispatcher.dispatch(request)
}
@@ -832,7 +890,7 @@ class MockWebServer : ExternalResource(), Closeable {
check(line.isEmpty()) { "Expected empty but was: $line" }
}
override fun toString(): String = "MockWebServer[$port]"
override fun toString(): String = "MockWebServer[$portField]"
@Throws(IOException::class)
override fun close() = shutdown()
@@ -886,7 +944,7 @@ class MockWebServer : ExternalResource(), Closeable {
}
val request = readRequest(stream)
requestCount.incrementAndGet()
atomicRequestCount.incrementAndGet()
requestQueue.add(request)
val response: MockResponse = dispatcher.dispatch(request)

View File

@@ -684,7 +684,6 @@ class KotlinSourceModernTest {
var mockResponse: MockResponse = MockResponse()
var status: String = mockResponse.status
status = mockResponse.status
mockResponse = mockResponse.apply { mockResponse.status = "" }
mockResponse.status = ""
mockResponse = mockResponse.setResponseCode(0)
var headers: Headers = mockResponse.getHeaders()
@@ -707,11 +706,9 @@ class KotlinSourceModernTest {
mockResponse = mockResponse.setChunkedBody("", 0)
var socketPolicy: SocketPolicy = mockResponse.socketPolicy
socketPolicy = mockResponse.socketPolicy
mockResponse = mockResponse.apply { mockResponse.socketPolicy = SocketPolicy.KEEP_OPEN }
mockResponse.socketPolicy = SocketPolicy.KEEP_OPEN
var http2ErrorCode: Int = mockResponse.http2ErrorCode
http2ErrorCode = mockResponse.http2ErrorCode
mockResponse = mockResponse.apply { mockResponse.http2ErrorCode = 0 }
mockResponse.http2ErrorCode = 0
mockResponse = mockResponse.throttleBody(0L, 0L, TimeUnit.SECONDS)
var throttleBytesPerPeriod: Long = mockResponse.throttleBytesPerPeriod
@@ -736,25 +733,23 @@ class KotlinSourceModernTest {
@Test @Ignore
fun mockWebServer() {
val mockWebServer: MockWebServer = MockWebServer()
var port: Int = mockWebServer.getPort()
port = mockWebServer.getPort()
var port: Int = mockWebServer.port
var hostName: String = mockWebServer.hostName
hostName = mockWebServer.hostName
val toProxyAddress: Proxy = mockWebServer.toProxyAddress()
mockWebServer.setServerSocketFactory(ServerSocketFactory.getDefault())
mockWebServer.serverSocketFactory = ServerSocketFactory.getDefault()
val url: HttpUrl = mockWebServer.url("")
mockWebServer.setBodyLimit(0L)
mockWebServer.setProtocolNegotiationEnabled(false)
mockWebServer.setProtocols(listOf())
val protocols: List<Protocol> = mockWebServer.protocols()
mockWebServer.bodyLimit = 0L
mockWebServer.protocolNegotiationEnabled = false
mockWebServer.protocols = listOf()
val protocols: List<Protocol> = mockWebServer.protocols
mockWebServer.useHttps(SSLSocketFactory.getDefault() as SSLSocketFactory, false)
mockWebServer.noClientAuth()
mockWebServer.requestClientAuth()
mockWebServer.requireClientAuth()
var request: RecordedRequest? = mockWebServer.takeRequest()
request = mockWebServer.takeRequest(0L, TimeUnit.SECONDS)
var requestCount: Int = mockWebServer.getRequestCount()
requestCount = mockWebServer.getRequestCount()
var requestCount: Int = mockWebServer.requestCount
mockWebServer.enqueue(MockResponse())
mockWebServer.start()
mockWebServer.start(0)