From 080a0a38ea18faa2b8313a2bf822846dc1854938 Mon Sep 17 00:00:00 2001 From: Jesse Wilson Date: Sat, 1 Jun 2019 12:34:32 -0400 Subject: [PATCH] Adopt idiomatic Kotlin in MockWebServer --- .../okhttp3/mockwebserver/MockResponse.kt | 8 +- .../okhttp3/mockwebserver/MockWebServer.kt | 192 ++++++++++++------ .../java/okhttp3/KotlinSourceModernTest.kt | 19 +- 3 files changed, 136 insertions(+), 83 deletions(-) diff --git a/mockwebserver/src/main/java/okhttp3/mockwebserver/MockResponse.kt b/mockwebserver/src/main/java/okhttp3/mockwebserver/MockResponse.kt index ae8e3f63a..6a3851d4e 100644 --- a/mockwebserver/src/main/java/okhttp3/mockwebserver/MockResponse.kt +++ b/mockwebserver/src/main/java/okhttp3/mockwebserver/MockResponse.kt @@ -276,16 +276,16 @@ class MockResponse : Cloneable { unit.convert(headersDelayAmount, headersDelayUnit) /** - * When [protocols][MockWebServer.setProtocols] include [HTTP_2][okhttp3.Protocol], - * this attaches a pushed stream to this response. + * When [protocols][MockWebServer.protocols] include [HTTP_2][okhttp3.Protocol], this attaches a + * pushed stream to this response. */ fun withPush(promise: PushPromise) = apply { promises.add(promise) } /** - * When [protocols][MockWebServer.setProtocols] include [HTTP_2][okhttp3.Protocol], - * this pushes [settings] before writing the response. + * When [protocols][MockWebServer.protocols] include [HTTP_2][okhttp3.Protocol], this pushes + * [settings] before writing the response. */ fun withSettings(settings: Settings) = apply { this.settings = settings diff --git a/mockwebserver/src/main/java/okhttp3/mockwebserver/MockWebServer.kt b/mockwebserver/src/main/java/okhttp3/mockwebserver/MockWebServer.kt index 29a5a3f2d..2e498b2f0 100644 --- a/mockwebserver/src/main/java/okhttp3/mockwebserver/MockWebServer.kt +++ b/mockwebserver/src/main/java/okhttp3/mockwebserver/MockWebServer.kt @@ -104,29 +104,52 @@ class MockWebServer : ExternalResource(), Closeable { Collections.newSetFromMap(ConcurrentHashMap()) private val openConnections = Collections.newSetFromMap(ConcurrentHashMap()) - private val requestCount = AtomicInteger() - private var bodyLimit = Long.MAX_VALUE - private var serverSocketFactory: ServerSocketFactory? = null + + private val atomicRequestCount = AtomicInteger() + + /** + * The number of HTTP requests received thus far by this server. This may exceed the number of + * HTTP connections when connection reuse is in practice. + */ + val requestCount: Int + get() = atomicRequestCount.get() + + /** The number of bytes of the POST body to keep in memory to the given limit. */ + var bodyLimit = Long.MAX_VALUE + + var serverSocketFactory: ServerSocketFactory? = null + get() { + if (field == null && started) { + field = ServerSocketFactory.getDefault() // Build the default value lazily. + } + return field + } + set(value) { + check(!started) { "serverSocketFactory must not be set after start()" } + field = value + } + private var serverSocket: ServerSocket? = null private var sslSocketFactory: SSLSocketFactory? = null private var executor: ExecutorService? = null private var tunnelProxy: Boolean = false private var clientAuth = CLIENT_AUTH_NONE + /** - * Returns the dispatcher used to respond to HTTP requests. - * The default dispatcher is a [QueueDispatcher] but other dispatchers can be configured. + * The dispatcher used to respond to HTTP requests. The default dispatcher is a [QueueDispatcher], + * which serves a fixed sequence of responses from a [queue][enqueue]. * - * Sets the dispatcher used to match incoming requests to mock responses. The default dispatcher - * simply serves a fixed sequence of responses from a [queue][enqueue]; custom - * dispatchers can vary the response based on timing or the content of the request. + * Other dispatchers can be configured. They can vary the response based on timing or the content + * of the request. */ var dispatcher: Dispatcher = QueueDispatcher() - private var port = -1 - private var inetSocketAddress: InetSocketAddress? = null - private var protocolNegotiationEnabled = true - private var protocols = immutableListOf(Protocol.HTTP_2, Protocol.HTTP_1_1) - private var started: Boolean = false + private var portField: Int = -1 + val port: Int + get() { + before() + return portField + } val hostName: String get() { @@ -134,6 +157,37 @@ class MockWebServer : ExternalResource(), Closeable { return inetSocketAddress!!.address.canonicalHostName } + private var inetSocketAddress: InetSocketAddress? = null + + /** + * True if ALPN is used on incoming HTTPS connections to negotiate a protocol like HTTP/1.1 or + * HTTP/2. This is true by default; set to false to disable negotiation and restrict connections + * to HTTP/1.1. + */ + var protocolNegotiationEnabled = true + + /** + * The protocols supported by ALPN on incoming HTTPS connections in order of preference. The list + * must contain [Protocol.HTTP_1_1]. It must not contain null. + * + * This list is ignored when [negotiation is disabled][protocolNegotiationEnabled]. + */ + @get:JvmName("protocols") var protocols: List = + immutableListOf(Protocol.HTTP_2, Protocol.HTTP_1_1) + set(value) { + val protocolList = value.toImmutableList() + require(Protocol.H2_PRIOR_KNOWLEDGE !in protocolList || protocolList.size == 1) { + "protocols containing h2_prior_knowledge cannot use other protocols: $protocolList" + } + require(Protocol.HTTP_1_1 in protocolList || Protocol.H2_PRIOR_KNOWLEDGE in protocolList) { + "protocols doesn't contain http/1.1: $protocolList" + } + require(null !in protocolList as List) { "protocols must not contain null" } + field = protocolList + } + + private var started: Boolean = false + @Synchronized override fun before() { if (started) return try { @@ -143,20 +197,27 @@ class MockWebServer : ExternalResource(), Closeable { } } - fun getPort(): Int { - before() - return port - } + @JvmName("-deprecated_port") + @Deprecated( + message = "moved to val", + replaceWith = ReplaceWith(expression = "port"), + level = DeprecationLevel.WARNING) + fun getPort(): Int = port fun toProxyAddress(): Proxy { before() - val address = InetSocketAddress(inetSocketAddress!!.address - .canonicalHostName, port) + val address = InetSocketAddress(inetSocketAddress!!.address.canonicalHostName, port) return Proxy(Proxy.Type.HTTP, address) } - fun setServerSocketFactory(serverSocketFactory: ServerSocketFactory) { - check(executor == null) { "setServerSocketFactory() must be called before start()" } + @JvmName("-deprecated_serverSocketFactory") + @Deprecated( + message = "moved to var", + replaceWith = ReplaceWith( + expression = "run { this.serverSocketFactory = serverSocketFactory }" + ), + level = DeprecationLevel.WARNING) + fun setServerSocketFactory(serverSocketFactory: ServerSocketFactory) = run { this.serverSocketFactory = serverSocketFactory } @@ -174,40 +235,38 @@ class MockWebServer : ExternalResource(), Closeable { .resolve(path)!! } - /** - * Sets the number of bytes of the POST body to keep in memory to the given limit. - */ - fun setBodyLimit(maxBodyLength: Long) { - this.bodyLimit = maxBodyLength - } + @JvmName("-deprecated_bodyLimit") + @Deprecated( + message = "moved to var", + replaceWith = ReplaceWith( + expression = "run { this.bodyLimit = bodyLimit }" + ), + level = DeprecationLevel.WARNING) + fun setBodyLimit(bodyLimit: Long) = run { this.bodyLimit = bodyLimit } - /** - * Sets whether ALPN is used on incoming HTTPS connections to negotiate a protocol like HTTP/1.1 - * or HTTP/2. Call this method to disable negotiation and restrict connections to HTTP/1.1. - */ - fun setProtocolNegotiationEnabled(protocolNegotiationEnabled: Boolean) { + @JvmName("-deprecated_protocolNegotiationEnabled") + @Deprecated( + message = "moved to var", + replaceWith = ReplaceWith( + expression = "run { this.protocolNegotiationEnabled = protocolNegotiationEnabled }" + ), + level = DeprecationLevel.WARNING) + fun setProtocolNegotiationEnabled(protocolNegotiationEnabled: Boolean) = run { this.protocolNegotiationEnabled = protocolNegotiationEnabled } - /** - * Indicates the protocols supported by ALPN on incoming HTTPS connections. This list is ignored - * when [negotiation is disabled][setProtocolNegotiationEnabled]. - * - * @param protocols the protocols to use, in order of preference. The list must contain - * [Protocol.HTTP_1_1]. It must not contain null. - */ - fun setProtocols(protocols: List) { - val protocolList = protocols.toImmutableList() - require(Protocol.H2_PRIOR_KNOWLEDGE !in protocolList || protocolList.size == 1) { - "protocols containing h2_prior_knowledge cannot use other protocols: $protocolList" - } - require(Protocol.HTTP_1_1 in protocolList || Protocol.H2_PRIOR_KNOWLEDGE in protocolList) { - "protocols doesn't contain http/1.1: $protocolList" - } - require(null !in protocolList as List) { "protocols must not contain null" } - this.protocols = protocolList - } + @JvmName("-deprecated_protocols") + @Deprecated( + message = "moved to var", + replaceWith = ReplaceWith(expression = "run { this.protocols = protocols }"), + level = DeprecationLevel.WARNING) + fun setProtocols(protocols: List) = run { this.protocols = protocols } + @JvmName("-deprecated_protocols") + @Deprecated( + message = "moved to var", + replaceWith = ReplaceWith(expression = "protocols"), + level = DeprecationLevel.WARNING) fun protocols(): List = protocols /** @@ -273,11 +332,12 @@ class MockWebServer : ExternalResource(), Closeable { fun takeRequest(timeout: Long, unit: TimeUnit): RecordedRequest? = requestQueue.poll(timeout, unit) - /** - * Returns the number of HTTP requests received thus far by this server. This may exceed the - * number of HTTP connections when connection reuse is in practice. - */ - fun getRequestCount(): Int = requestCount.get() + @JvmName("-deprecated_requestCount") + @Deprecated( + message = "moved to val", + replaceWith = ReplaceWith(expression = "requestCount"), + level = DeprecationLevel.WARNING) + fun getRequestCount(): Int = requestCount /** * Scripts [response] to be returned to a request made in sequence. The first request is @@ -322,17 +382,14 @@ class MockWebServer : ExternalResource(), Closeable { executor = Executors.newCachedThreadPool(threadFactory("MockWebServer", false)) this.inetSocketAddress = inetSocketAddress - if (serverSocketFactory == null) { - serverSocketFactory = ServerSocketFactory.getDefault() - } serverSocket = serverSocketFactory!!.createServerSocket() // Reuse if the user specified a port serverSocket!!.reuseAddress = inetSocketAddress.port != 0 serverSocket!!.bind(inetSocketAddress, 50) - port = serverSocket!!.localPort - executor!!.execute("MockWebServer $port") { + portField = serverSocket!!.localPort + executor!!.execute("MockWebServer $portField") { try { logger.info("${this@MockWebServer} starting to accept connections") acceptConnections() @@ -421,7 +478,7 @@ class MockWebServer : ExternalResource(), Closeable { } } - inner class SocketHandler(private val raw: Socket) { + internal inner class SocketHandler(private val raw: Socket) { private var sequenceNumber = 0 @Throws(Exception::class) @@ -497,7 +554,7 @@ class MockWebServer : ExternalResource(), Closeable { if (sequenceNumber == 0) { logger.warning( - "${this@MockWebServer} connection from ${raw.inetAddress} didn't make a request") + "${this@MockWebServer} connection from ${raw.inetAddress} didn't make a request") } socket.close() @@ -531,7 +588,7 @@ class MockWebServer : ExternalResource(), Closeable { ): Boolean { val request = readRequest(socket, source, sink, sequenceNumber) ?: return false - requestCount.incrementAndGet() + atomicRequestCount.incrementAndGet() requestQueue.add(request) val response = dispatcher.dispatch(request) @@ -571,7 +628,8 @@ class MockWebServer : ExternalResource(), Closeable { SHUTDOWN_INPUT_AT_END -> socket.shutdownInput() SHUTDOWN_OUTPUT_AT_END -> socket.shutdownOutput() SHUTDOWN_SERVER_AFTER_RESPONSE -> shutdown() - else -> {} + else -> { + } } sequenceNumber++ return reuseSocket @@ -597,7 +655,7 @@ class MockWebServer : ExternalResource(), Closeable { private fun dispatchBookkeepingRequest(sequenceNumber: Int, socket: Socket) { val request = RecordedRequest( "", headersOf(), emptyList(), 0L, Buffer(), sequenceNumber, socket) - requestCount.incrementAndGet() + atomicRequestCount.incrementAndGet() requestQueue.add(request) dispatcher.dispatch(request) } @@ -832,7 +890,7 @@ class MockWebServer : ExternalResource(), Closeable { check(line.isEmpty()) { "Expected empty but was: $line" } } - override fun toString(): String = "MockWebServer[$port]" + override fun toString(): String = "MockWebServer[$portField]" @Throws(IOException::class) override fun close() = shutdown() @@ -886,7 +944,7 @@ class MockWebServer : ExternalResource(), Closeable { } val request = readRequest(stream) - requestCount.incrementAndGet() + atomicRequestCount.incrementAndGet() requestQueue.add(request) val response: MockResponse = dispatcher.dispatch(request) diff --git a/okhttp/src/test/java/okhttp3/KotlinSourceModernTest.kt b/okhttp/src/test/java/okhttp3/KotlinSourceModernTest.kt index 3ab35067a..6fff7623d 100644 --- a/okhttp/src/test/java/okhttp3/KotlinSourceModernTest.kt +++ b/okhttp/src/test/java/okhttp3/KotlinSourceModernTest.kt @@ -684,7 +684,6 @@ class KotlinSourceModernTest { var mockResponse: MockResponse = MockResponse() var status: String = mockResponse.status status = mockResponse.status - mockResponse = mockResponse.apply { mockResponse.status = "" } mockResponse.status = "" mockResponse = mockResponse.setResponseCode(0) var headers: Headers = mockResponse.getHeaders() @@ -707,11 +706,9 @@ class KotlinSourceModernTest { mockResponse = mockResponse.setChunkedBody("", 0) var socketPolicy: SocketPolicy = mockResponse.socketPolicy socketPolicy = mockResponse.socketPolicy - mockResponse = mockResponse.apply { mockResponse.socketPolicy = SocketPolicy.KEEP_OPEN } mockResponse.socketPolicy = SocketPolicy.KEEP_OPEN var http2ErrorCode: Int = mockResponse.http2ErrorCode http2ErrorCode = mockResponse.http2ErrorCode - mockResponse = mockResponse.apply { mockResponse.http2ErrorCode = 0 } mockResponse.http2ErrorCode = 0 mockResponse = mockResponse.throttleBody(0L, 0L, TimeUnit.SECONDS) var throttleBytesPerPeriod: Long = mockResponse.throttleBytesPerPeriod @@ -736,25 +733,23 @@ class KotlinSourceModernTest { @Test @Ignore fun mockWebServer() { val mockWebServer: MockWebServer = MockWebServer() - var port: Int = mockWebServer.getPort() - port = mockWebServer.getPort() + var port: Int = mockWebServer.port var hostName: String = mockWebServer.hostName hostName = mockWebServer.hostName val toProxyAddress: Proxy = mockWebServer.toProxyAddress() - mockWebServer.setServerSocketFactory(ServerSocketFactory.getDefault()) + mockWebServer.serverSocketFactory = ServerSocketFactory.getDefault() val url: HttpUrl = mockWebServer.url("") - mockWebServer.setBodyLimit(0L) - mockWebServer.setProtocolNegotiationEnabled(false) - mockWebServer.setProtocols(listOf()) - val protocols: List = mockWebServer.protocols() + mockWebServer.bodyLimit = 0L + mockWebServer.protocolNegotiationEnabled = false + mockWebServer.protocols = listOf() + val protocols: List = mockWebServer.protocols mockWebServer.useHttps(SSLSocketFactory.getDefault() as SSLSocketFactory, false) mockWebServer.noClientAuth() mockWebServer.requestClientAuth() mockWebServer.requireClientAuth() var request: RecordedRequest? = mockWebServer.takeRequest() request = mockWebServer.takeRequest(0L, TimeUnit.SECONDS) - var requestCount: Int = mockWebServer.getRequestCount() - requestCount = mockWebServer.getRequestCount() + var requestCount: Int = mockWebServer.requestCount mockWebServer.enqueue(MockResponse()) mockWebServer.start() mockWebServer.start(0)