diff --git a/mockwebserverwrapper/build.gradle b/mockwebserverwrapper/build.gradle new file mode 100644 index 000000000..90e371ed2 --- /dev/null +++ b/mockwebserverwrapper/build.gradle @@ -0,0 +1,22 @@ +jar { + manifest { + attributes('Automatic-Module-Name': 'okhttp3.mockwebserverwrapper') + } +} + +dependencies { + api project(':okhttp') + api project(':mockwebserver') + api deps.junit + + testImplementation project(':okhttp-testing-support') + testImplementation project(':okhttp-tls') + testImplementation deps.assertj +} + +afterEvaluate { project -> + project.tasks.dokka { + outputDirectory = "$rootDir/docs/4.x" + outputFormat = 'gfm' + } +} diff --git a/mockwebserverwrapper/src/main/kotlin/okhttp3/mockwebserverwrapper/Dispatcher.kt b/mockwebserverwrapper/src/main/kotlin/okhttp3/mockwebserverwrapper/Dispatcher.kt new file mode 100644 index 000000000..6a5ac1cc0 --- /dev/null +++ b/mockwebserverwrapper/src/main/kotlin/okhttp3/mockwebserverwrapper/Dispatcher.kt @@ -0,0 +1,27 @@ +/* + * Copyright (C) 2020 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3.mockwebserverwrapper + +abstract class Dispatcher { + @Throws(InterruptedException::class) + abstract fun dispatch(request: RecordedRequest): MockResponse + + open fun peek(): MockResponse { + return MockResponse().apply { this.socketPolicy = SocketPolicy.KEEP_OPEN } + } + + open fun shutdown() {} +} diff --git a/mockwebserverwrapper/src/main/kotlin/okhttp3/mockwebserverwrapper/MockResponse.kt b/mockwebserverwrapper/src/main/kotlin/okhttp3/mockwebserverwrapper/MockResponse.kt new file mode 100644 index 000000000..38cd94d93 --- /dev/null +++ b/mockwebserverwrapper/src/main/kotlin/okhttp3/mockwebserverwrapper/MockResponse.kt @@ -0,0 +1,248 @@ +/* + * Copyright (C) 2020 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3.mockwebserverwrapper + +import okhttp3.Headers +import okhttp3.WebSocketListener +import okhttp3.internal.addHeaderLenient +import okhttp3.internal.http2.Settings +import okio.Buffer +import java.util.concurrent.TimeUnit + +class MockResponse : Cloneable { + @set:JvmName("status") + var status: String = "" + + private var headersBuilder = Headers.Builder() + private var trailersBuilder = Headers.Builder() + + @set:JvmName("headers") + var headers: Headers + get() = headersBuilder.build() + set(value) { + this.headersBuilder = value.newBuilder() + } + + @set:JvmName("trailers") + var trailers: Headers + get() = trailersBuilder.build() + set(value) { + this.trailersBuilder = value.newBuilder() + } + + private var body: Buffer? = null + + var throttleBytesPerPeriod = Long.MAX_VALUE + private set + private var throttlePeriodAmount = 1L + private var throttlePeriodUnit = TimeUnit.SECONDS + + @set:JvmName("socketPolicy") + var socketPolicy = SocketPolicy.KEEP_OPEN + + @set:JvmName("http2ErrorCode") + var http2ErrorCode = -1 + + private var bodyDelayAmount = 0L + private var bodyDelayUnit = TimeUnit.MILLISECONDS + + private var headersDelayAmount = 0L + private var headersDelayUnit = TimeUnit.MILLISECONDS + + private var promises = mutableListOf() + var settings: Settings = Settings() + private set + var webSocketListener: WebSocketListener? = null + private set + + val pushPromises: List + get() = promises + + init { + setResponseCode(200) + setHeader("Content-Length", 0L) + } + + public override fun clone(): MockResponse { + val result = super.clone() as MockResponse + result.headersBuilder = headersBuilder.build().newBuilder() + result.promises = promises.toMutableList() + return result + } + + @JvmName("-deprecated_getStatus") + @Deprecated( + message = "moved to var", + replaceWith = ReplaceWith(expression = "status"), + level = DeprecationLevel.ERROR) + fun getStatus(): String = status + + fun setStatus(status: String) = apply { + this.status = status + } + + fun setResponseCode(code: Int): MockResponse { + val reason = when (code) { + in 100..199 -> "Informational" + in 200..299 -> "OK" + in 300..399 -> "Redirection" + in 400..499 -> "Client Error" + in 500..599 -> "Server Error" + else -> "Mock Response" + } + return apply { status = "HTTP/1.1 $code $reason" } + } + + fun clearHeaders() = apply { + headersBuilder = Headers.Builder() + } + + fun addHeader(header: String) = apply { + headersBuilder.add(header) + } + + fun addHeader(name: String, value: Any) = apply { + headersBuilder.add(name, value.toString()) + } + + fun addHeaderLenient(name: String, value: Any) = apply { + addHeaderLenient(headersBuilder, name, value.toString()) + } + + fun setHeader(name: String, value: Any) = apply { + removeHeader(name) + addHeader(name, value) + } + + fun removeHeader(name: String) = apply { + headersBuilder.removeAll(name) + } + + fun getBody(): Buffer? = body?.clone() + + fun setBody(body: Buffer) = apply { + setHeader("Content-Length", body.size) + this.body = body.clone() // Defensive copy. + } + + fun setBody(body: String): MockResponse = setBody(Buffer().writeUtf8(body)) + + fun setChunkedBody(body: Buffer, maxChunkSize: Int) = apply { + removeHeader("Content-Length") + headersBuilder.add(CHUNKED_BODY_HEADER) + + val bytesOut = Buffer() + while (!body.exhausted()) { + val chunkSize = minOf(body.size, maxChunkSize.toLong()) + bytesOut.writeHexadecimalUnsignedLong(chunkSize) + bytesOut.writeUtf8("\r\n") + bytesOut.write(body, chunkSize) + bytesOut.writeUtf8("\r\n") + } + bytesOut.writeUtf8("0\r\n") // Last chunk. Trailers follow! + this.body = bytesOut + } + + fun setChunkedBody(body: String, maxChunkSize: Int): MockResponse = + setChunkedBody(Buffer().writeUtf8(body), maxChunkSize) + + @JvmName("-deprecated_getHeaders") + @Deprecated( + message = "moved to var", + replaceWith = ReplaceWith(expression = "headers"), + level = DeprecationLevel.ERROR) + fun getHeaders(): Headers = headers + + fun setHeaders(headers: Headers) = apply { this.headers = headers } + + @JvmName("-deprecated_getTrailers") + @Deprecated( + message = "moved to var", + replaceWith = ReplaceWith(expression = "trailers"), + level = DeprecationLevel.ERROR) + fun getTrailers(): Headers = trailers + + fun setTrailers(trailers: Headers) = apply { this.trailers = trailers } + + @JvmName("-deprecated_getSocketPolicy") + @Deprecated( + message = "moved to var", + replaceWith = ReplaceWith(expression = "socketPolicy"), + level = DeprecationLevel.ERROR) + fun getSocketPolicy() = socketPolicy + + fun setSocketPolicy(socketPolicy: SocketPolicy) = apply { + this.socketPolicy = socketPolicy + } + + @JvmName("-deprecated_getHttp2ErrorCode") + @Deprecated( + message = "moved to var", + replaceWith = ReplaceWith(expression = "http2ErrorCode"), + level = DeprecationLevel.ERROR) + fun getHttp2ErrorCode() = http2ErrorCode + + fun setHttp2ErrorCode(http2ErrorCode: Int) = apply { + this.http2ErrorCode = http2ErrorCode + } + + fun throttleBody(bytesPerPeriod: Long, period: Long, unit: TimeUnit) = apply { + throttleBytesPerPeriod = bytesPerPeriod + throttlePeriodAmount = period + throttlePeriodUnit = unit + } + + fun getThrottlePeriod(unit: TimeUnit): Long = + unit.convert(throttlePeriodAmount, throttlePeriodUnit) + + fun setBodyDelay(delay: Long, unit: TimeUnit) = apply { + bodyDelayAmount = delay + bodyDelayUnit = unit + } + + fun getBodyDelay(unit: TimeUnit): Long = + unit.convert(bodyDelayAmount, bodyDelayUnit) + + fun setHeadersDelay(delay: Long, unit: TimeUnit) = apply { + headersDelayAmount = delay + headersDelayUnit = unit + } + + fun getHeadersDelay(unit: TimeUnit): Long = + unit.convert(headersDelayAmount, headersDelayUnit) + + fun withPush(promise: PushPromise) = apply { + promises.add(promise) + } + + fun withSettings(settings: Settings) = apply { + this.settings = settings + } + + fun withWebSocketUpgrade(listener: WebSocketListener) = apply { + status = "HTTP/1.1 101 Switching Protocols" + setHeader("Connection", "Upgrade") + setHeader("Upgrade", "websocket") + body = null + webSocketListener = listener + } + + override fun toString() = status + + companion object { + private const val CHUNKED_BODY_HEADER = "Transfer-encoding: chunked" + } +} diff --git a/mockwebserverwrapper/src/main/kotlin/okhttp3/mockwebserverwrapper/MockWebServer.kt b/mockwebserverwrapper/src/main/kotlin/okhttp3/mockwebserverwrapper/MockWebServer.kt new file mode 100644 index 000000000..b174d0aea --- /dev/null +++ b/mockwebserverwrapper/src/main/kotlin/okhttp3/mockwebserverwrapper/MockWebServer.kt @@ -0,0 +1,218 @@ +/* + * Copyright (C) 2020 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3.mockwebserverwrapper + +import okhttp3.HttpUrl +import okhttp3.Protocol +import org.junit.rules.ExternalResource +import java.io.Closeable +import java.io.IOException +import java.net.InetAddress +import java.net.Proxy +import java.util.concurrent.TimeUnit +import java.util.logging.Level +import java.util.logging.Logger +import javax.net.ServerSocketFactory +import javax.net.ssl.SSLSocketFactory + +class MockWebServer : ExternalResource(), Closeable { + val delegate = okhttp3.mockwebserver.MockWebServer() + + val requestCount: Int by delegate::requestCount + + var bodyLimit: Long by delegate::bodyLimit + + var serverSocketFactory: ServerSocketFactory? by delegate::serverSocketFactory + + var dispatcher: Dispatcher = QueueDispatcher() + set(value) { + field = value + delegate.dispatcher = value.wrap() + } + + val port: Int + get() { + before() // This implicitly starts the delegate. + return delegate.port + } + + val hostName: String + get() { + before() // This implicitly starts the delegate. + return delegate.hostName + } + + var protocolNegotiationEnabled: Boolean by delegate::protocolNegotiationEnabled + + @get:JvmName("protocols") var protocols: List + get() = delegate.protocols + set(value) { + delegate.protocols = value + } + + init { + delegate.dispatcher = dispatcher.wrap() + } + + private var started: Boolean = false + + @Synchronized override fun before() { + if (started) return + try { + start() + } catch (e: IOException) { + throw RuntimeException(e) + } + } + + @JvmName("-deprecated_port") + @Deprecated( + message = "moved to val", + replaceWith = ReplaceWith(expression = "port"), + level = DeprecationLevel.ERROR) + fun getPort(): Int = port + + fun toProxyAddress(): Proxy { + before() // This implicitly starts the delegate. + return delegate.toProxyAddress() + } + + @JvmName("-deprecated_serverSocketFactory") + @Deprecated( + message = "moved to var", + replaceWith = ReplaceWith( + expression = "run { this.serverSocketFactory = serverSocketFactory }" + ), + level = DeprecationLevel.ERROR) + fun setServerSocketFactory(serverSocketFactory: ServerSocketFactory) { + delegate.serverSocketFactory = serverSocketFactory + } + + fun url(path: String): HttpUrl { + before() // This implicitly starts the delegate. + return delegate.url(path) + } + + @JvmName("-deprecated_bodyLimit") + @Deprecated( + message = "moved to var", + replaceWith = ReplaceWith( + expression = "run { this.bodyLimit = bodyLimit }" + ), + level = DeprecationLevel.ERROR) + fun setBodyLimit(bodyLimit: Long) { + delegate.bodyLimit = bodyLimit + } + + @JvmName("-deprecated_protocolNegotiationEnabled") + @Deprecated( + message = "moved to var", + replaceWith = ReplaceWith( + expression = "run { this.protocolNegotiationEnabled = protocolNegotiationEnabled }" + ), + level = DeprecationLevel.ERROR) + fun setProtocolNegotiationEnabled(protocolNegotiationEnabled: Boolean) { + delegate.protocolNegotiationEnabled = protocolNegotiationEnabled + } + + @JvmName("-deprecated_protocols") + @Deprecated( + message = "moved to var", + replaceWith = ReplaceWith(expression = "run { this.protocols = protocols }"), + level = DeprecationLevel.ERROR) + fun setProtocols(protocols: List) { + delegate.protocols = protocols + } + + @JvmName("-deprecated_protocols") + @Deprecated( + message = "moved to var", + replaceWith = ReplaceWith(expression = "protocols"), + level = DeprecationLevel.ERROR) + fun protocols(): List = delegate.protocols + + fun useHttps(sslSocketFactory: SSLSocketFactory, tunnelProxy: Boolean) { + delegate.useHttps(sslSocketFactory, tunnelProxy) + } + + fun noClientAuth() { + delegate.noClientAuth() + } + + fun requestClientAuth() { + delegate.requestClientAuth() + } + + fun requireClientAuth() { + delegate.requireClientAuth() + } + + @Throws(InterruptedException::class) + fun takeRequest(): RecordedRequest { + return delegate.takeRequest().unwrap() + } + + @Throws(InterruptedException::class) + fun takeRequest(timeout: Long, unit: TimeUnit): RecordedRequest? { + return delegate.takeRequest(timeout, unit)?.unwrap() + } + + @JvmName("-deprecated_requestCount") + @Deprecated( + message = "moved to val", + replaceWith = ReplaceWith(expression = "requestCount"), + level = DeprecationLevel.ERROR) + fun getRequestCount(): Int = delegate.requestCount + + fun enqueue(response: MockResponse) { + delegate.enqueue(response.wrap()) + } + + @Throws(IOException::class) + @JvmOverloads fun start(port: Int = 0) { + started = true + delegate.start(port) + } + + @Throws(IOException::class) + fun start(inetAddress: InetAddress, port: Int) { + started = true + delegate.start(inetAddress, port) + } + + @Synchronized + @Throws(IOException::class) + fun shutdown() { + delegate.shutdown() + } + + @Synchronized override fun after() { + try { + shutdown() + } catch (e: IOException) { + logger.log(Level.WARNING, "MockWebServer shutdown failed", e) + } + } + + override fun toString(): String = delegate.toString() + + @Throws(IOException::class) + override fun close() = delegate.close() + + companion object { + private val logger = Logger.getLogger(MockWebServer::class.java.name) + } +} diff --git a/mockwebserverwrapper/src/main/kotlin/okhttp3/mockwebserverwrapper/PushPromise.kt b/mockwebserverwrapper/src/main/kotlin/okhttp3/mockwebserverwrapper/PushPromise.kt new file mode 100644 index 000000000..818439dad --- /dev/null +++ b/mockwebserverwrapper/src/main/kotlin/okhttp3/mockwebserverwrapper/PushPromise.kt @@ -0,0 +1,54 @@ +/* + * Copyright (C) 2020 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3.mockwebserverwrapper + +import okhttp3.Headers + +class PushPromise( + @get:JvmName("method") val method: String, + @get:JvmName("path") val path: String, + @get:JvmName("headers") val headers: Headers, + @get:JvmName("response") val response: MockResponse +) { + + @JvmName("-deprecated_method") + @Deprecated( + message = "moved to val", + replaceWith = ReplaceWith(expression = "method"), + level = DeprecationLevel.ERROR) + fun method() = method + + @JvmName("-deprecated_path") + @Deprecated( + message = "moved to val", + replaceWith = ReplaceWith(expression = "path"), + level = DeprecationLevel.ERROR) + fun path() = path + + @JvmName("-deprecated_headers") + @Deprecated( + message = "moved to val", + replaceWith = ReplaceWith(expression = "headers"), + level = DeprecationLevel.ERROR) + fun headers() = headers + + @JvmName("-deprecated_response") + @Deprecated( + message = "moved to val", + replaceWith = ReplaceWith(expression = "response"), + level = DeprecationLevel.ERROR) + fun response() = response +} diff --git a/mockwebserverwrapper/src/main/kotlin/okhttp3/mockwebserverwrapper/QueueDispatcher.kt b/mockwebserverwrapper/src/main/kotlin/okhttp3/mockwebserverwrapper/QueueDispatcher.kt new file mode 100644 index 000000000..23baf8ee2 --- /dev/null +++ b/mockwebserverwrapper/src/main/kotlin/okhttp3/mockwebserverwrapper/QueueDispatcher.kt @@ -0,0 +1,45 @@ +/* + * Copyright (C) 2020 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3.mockwebserverwrapper + +class QueueDispatcher : Dispatcher() { + internal val delegate = okhttp3.mockwebserver.QueueDispatcher() + + @Throws(InterruptedException::class) + override fun dispatch(request: RecordedRequest): MockResponse { + throw UnsupportedOperationException("unexpected call") + } + + override fun peek(): MockResponse { + throw UnsupportedOperationException("unexpected call") + } + + fun enqueueResponse(response: MockResponse) { + delegate.enqueueResponse(response.wrap()) + } + + override fun shutdown() { + delegate.shutdown() + } + + fun setFailFast(failFast: Boolean) { + delegate.setFailFast(failFast) + } + + fun setFailFast(failFastResponse: MockResponse?) { + delegate.setFailFast(failFastResponse?.wrap()) + } +} diff --git a/mockwebserverwrapper/src/main/kotlin/okhttp3/mockwebserverwrapper/RecordedRequest.kt b/mockwebserverwrapper/src/main/kotlin/okhttp3/mockwebserverwrapper/RecordedRequest.kt new file mode 100644 index 000000000..c25e0bae7 --- /dev/null +++ b/mockwebserverwrapper/src/main/kotlin/okhttp3/mockwebserverwrapper/RecordedRequest.kt @@ -0,0 +1,150 @@ +/* + * Copyright (C) 2020 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3.mockwebserverwrapper + +import okhttp3.Handshake +import okhttp3.Handshake.Companion.handshake +import okhttp3.Headers +import okhttp3.HttpUrl +import okhttp3.HttpUrl.Companion.toHttpUrlOrNull +import okhttp3.TlsVersion +import okio.Buffer +import java.io.IOException +import java.net.Inet6Address +import java.net.Socket +import javax.net.ssl.SSLSocket + +class RecordedRequest { + val requestLine: String + val headers: Headers + val chunkSizes: List + val bodySize: Long + val body: Buffer + val sequenceNumber: Int + val failure: IOException? + val method: String? + val path: String? + val handshake: Handshake? + val requestUrl: HttpUrl? + + @get:JvmName("-deprecated_utf8Body") + @Deprecated( + message = "Use body.readUtf8()", + replaceWith = ReplaceWith("body.readUtf8()"), + level = DeprecationLevel.ERROR) + val utf8Body: String + get() = body.readUtf8() + + val tlsVersion: TlsVersion? + get() = handshake?.tlsVersion + + internal constructor( + requestLine: String, + headers: Headers, + chunkSizes: List, + bodySize: Long, + body: Buffer, + sequenceNumber: Int, + failure: IOException?, + method: String?, + path: String?, + handshake: Handshake?, + requestUrl: HttpUrl? + ) { + this.requestLine = requestLine + this.headers = headers + this.chunkSizes = chunkSizes + this.bodySize = bodySize + this.body = body + this.sequenceNumber = sequenceNumber + this.failure = failure + this.method = method + this.path = path + this.handshake = handshake + this.requestUrl = requestUrl + } + + @JvmOverloads + constructor( + requestLine: String, + headers: Headers, + chunkSizes: List, + bodySize: Long, + body: Buffer, + sequenceNumber: Int, + socket: Socket, + failure: IOException? = null + ) { + this.requestLine = requestLine; + this.headers = headers + this.chunkSizes = chunkSizes + this.bodySize = bodySize + this.body = body + this.sequenceNumber = sequenceNumber + this.failure = failure + + if (socket is SSLSocket) { + try { + this.handshake = socket.session.handshake() + } catch (e: IOException) { + throw IllegalArgumentException(e) + } + } else { + this.handshake = null + } + + if (requestLine.isNotEmpty()) { + val methodEnd = requestLine.indexOf(' ') + val pathEnd = requestLine.indexOf(' ', methodEnd + 1) + this.method = requestLine.substring(0, methodEnd) + var path = requestLine.substring(methodEnd + 1, pathEnd) + if (!path.startsWith("/")) { + path = "/" + } + this.path = path + + val scheme = if (socket is SSLSocket) "https" else "http" + val inetAddress = socket.localAddress + + var hostname = inetAddress.hostName + if (inetAddress is Inet6Address && hostname.contains(':')) { + // hostname is likely some form representing the IPv6 bytes + // 2001:0db8:85a3:0000:0000:8a2e:0370:7334 + // 2001:db8:85a3::8a2e:370:7334 + // ::1 + hostname = "[$hostname]" + } + + val localPort = socket.localPort + // Allow null in failure case to allow for testing bad requests + this.requestUrl = "$scheme://$hostname:$localPort$path".toHttpUrlOrNull() + } else { + this.requestUrl = null + this.method = null + this.path = null + } + } + + @Deprecated( + message = "Use body.readUtf8()", + replaceWith = ReplaceWith("body.readUtf8()"), + level = DeprecationLevel.WARNING) + fun getUtf8Body(): String = body.readUtf8() + + fun getHeader(name: String): String? = headers.values(name).firstOrNull() + + override fun toString(): String = requestLine +} diff --git a/mockwebserverwrapper/src/main/kotlin/okhttp3/mockwebserverwrapper/SocketPolicy.kt b/mockwebserverwrapper/src/main/kotlin/okhttp3/mockwebserverwrapper/SocketPolicy.kt new file mode 100644 index 000000000..f45dd2fb8 --- /dev/null +++ b/mockwebserverwrapper/src/main/kotlin/okhttp3/mockwebserverwrapper/SocketPolicy.kt @@ -0,0 +1,36 @@ +/* + * Copyright (C) 2020 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3.mockwebserverwrapper + +enum class SocketPolicy { + SHUTDOWN_SERVER_AFTER_RESPONSE, + KEEP_OPEN, + DISCONNECT_AT_END, + UPGRADE_TO_SSL_AT_END, + DISCONNECT_AT_START, + DISCONNECT_AFTER_REQUEST, + DISCONNECT_DURING_REQUEST_BODY, + DISCONNECT_DURING_RESPONSE_BODY, + DO_NOT_READ_REQUEST_BODY, + FAIL_HANDSHAKE, + SHUTDOWN_INPUT_AT_END, + SHUTDOWN_OUTPUT_AT_END, + STALL_SOCKET_AT_START, + NO_RESPONSE, + RESET_STREAM_AT_START, + EXPECT_CONTINUE, + CONTINUE_ALWAYS +} diff --git a/mockwebserverwrapper/src/main/kotlin/okhttp3/mockwebserverwrapper/bridge.kt b/mockwebserverwrapper/src/main/kotlin/okhttp3/mockwebserverwrapper/bridge.kt new file mode 100644 index 000000000..779090f6e --- /dev/null +++ b/mockwebserverwrapper/src/main/kotlin/okhttp3/mockwebserverwrapper/bridge.kt @@ -0,0 +1,94 @@ +/* + * Copyright (C) 2020 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3.mockwebserverwrapper + +import java.util.concurrent.TimeUnit.MILLISECONDS + +internal fun Dispatcher.wrap(): okhttp3.mockwebserver.Dispatcher { + if (this is QueueDispatcher) return this.delegate + + val delegate = this + return object : okhttp3.mockwebserver.Dispatcher() { + override fun dispatch( + request: okhttp3.mockwebserver.RecordedRequest + ): okhttp3.mockwebserver.MockResponse { + return delegate.dispatch(request.unwrap()).wrap() + } + + override fun peek(): okhttp3.mockwebserver.MockResponse { + return delegate.peek().wrap() + } + + override fun shutdown() { + delegate.shutdown() + } + } +} + +internal fun MockResponse.wrap(): okhttp3.mockwebserver.MockResponse { + val result = okhttp3.mockwebserver.MockResponse() + val copyFromWebSocketListener = webSocketListener + if (copyFromWebSocketListener != null) { + result.withWebSocketUpgrade(copyFromWebSocketListener) + } + + val body = getBody() + if (body != null) result.setBody(body) + + for (pushPromise in pushPromises) { + result.withPush(pushPromise.wrap()) + } + + result.withSettings(settings) + result.status = status + result.headers = headers + result.trailers = trailers + result.socketPolicy = socketPolicy.wrap() + result.http2ErrorCode = http2ErrorCode + result.throttleBody(throttleBytesPerPeriod, getThrottlePeriod(MILLISECONDS), MILLISECONDS) + result.setBodyDelay(getBodyDelay(MILLISECONDS), MILLISECONDS) + result.setHeadersDelay(getHeadersDelay(MILLISECONDS), MILLISECONDS) + return result +} + +private fun PushPromise.wrap(): okhttp3.mockwebserver.PushPromise { + return okhttp3.mockwebserver.PushPromise( + method = method, + path = path, + headers = headers, + response = response.wrap() + ) +} + +internal fun okhttp3.mockwebserver.RecordedRequest.unwrap(): RecordedRequest { + return RecordedRequest( + requestLine = requestLine, + headers = headers, + chunkSizes = chunkSizes, + bodySize = bodySize, + body = body, + sequenceNumber = sequenceNumber, + failure = failure, + method = method, + path = path, + handshake = handshake, + requestUrl = requestUrl + ) +} + +private fun SocketPolicy.wrap(): okhttp3.mockwebserver.SocketPolicy { + return okhttp3.mockwebserver.SocketPolicy.valueOf(name) +} diff --git a/mockwebserverwrapper/src/test/java/okhttp3/mockwebserverwrapper/MockWebServerTest.java b/mockwebserverwrapper/src/test/java/okhttp3/mockwebserverwrapper/MockWebServerTest.java new file mode 100644 index 000000000..001f8416e --- /dev/null +++ b/mockwebserverwrapper/src/test/java/okhttp3/mockwebserverwrapper/MockWebServerTest.java @@ -0,0 +1,629 @@ +/* + * Copyright (C) 2011 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3.mockwebserverwrapper; + +import java.io.BufferedReader; +import java.io.Closeable; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.net.ConnectException; +import java.net.HttpURLConnection; +import java.net.ProtocolException; +import java.net.SocketTimeoutException; +import java.net.URL; +import java.net.URLConnection; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import javax.net.ssl.HttpsURLConnection; +import okhttp3.Handshake; +import okhttp3.Headers; +import okhttp3.HttpUrl; +import okhttp3.Protocol; +import okhttp3.RecordingHostnameVerifier; +import okhttp3.TestUtil; +import okhttp3.testing.PlatformRule; +import okhttp3.tls.HandshakeCertificates; +import okhttp3.tls.HeldCertificate; +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; +import org.junit.runner.Description; +import org.junit.runners.model.Statement; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.Arrays.asList; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.NANOSECONDS; +import static java.util.concurrent.TimeUnit.SECONDS; +import static okhttp3.tls.internal.TlsUtil.localhost; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.data.Offset.offset; +import static org.junit.Assert.fail; +import static org.junit.Assume.assumeFalse; + +@SuppressWarnings({"ArraysAsListWithZeroOrOneArgument", "deprecation"}) +public final class MockWebServerTest { + @Rule public PlatformRule platform = new PlatformRule(); + + @Rule public final MockWebServer server = new MockWebServer(); + + @Rule public Timeout globalTimeout = Timeout.seconds(30); + + @Before public void checkPlatforms() { + platform.assumeNotBouncyCastle(); + } + + @Test public void defaultMockResponse() { + MockResponse response = new MockResponse(); + assertThat(headersToList(response)).containsExactly("Content-Length: 0"); + assertThat(response.getStatus()).isEqualTo("HTTP/1.1 200 OK"); + } + + @Test public void setResponseMockReason() { + String[] reasons = { + "Mock Response", + "Informational", + "OK", + "Redirection", + "Client Error", + "Server Error", + "Mock Response" + }; + for (int i = 0; i < 600; i++) { + MockResponse response = new MockResponse().setResponseCode(i); + String expectedReason = reasons[i / 100]; + assertThat(response.getStatus()).isEqualTo(("HTTP/1.1 " + i + " " + expectedReason)); + assertThat(headersToList(response)).containsExactly("Content-Length: 0"); + } + } + + @Test public void setStatusControlsWholeStatusLine() { + MockResponse response = new MockResponse().setStatus("HTTP/1.1 202 That'll do pig"); + assertThat(headersToList(response)).containsExactly("Content-Length: 0"); + assertThat(response.getStatus()).isEqualTo("HTTP/1.1 202 That'll do pig"); + } + + @Test public void setBodyAdjustsHeaders() throws IOException { + MockResponse response = new MockResponse().setBody("ABC"); + assertThat(headersToList(response)).containsExactly("Content-Length: 3"); + assertThat(response.getBody().readUtf8()).isEqualTo("ABC"); + } + + @Test public void mockResponseAddHeader() { + MockResponse response = new MockResponse() + .clearHeaders() + .addHeader("Cookie: s=square") + .addHeader("Cookie", "a=android"); + assertThat(headersToList(response)).containsExactly("Cookie: s=square", "Cookie: a=android"); + } + + @Test public void mockResponseSetHeader() { + MockResponse response = new MockResponse() + .clearHeaders() + .addHeader("Cookie: s=square") + .addHeader("Cookie: a=android") + .addHeader("Cookies: delicious"); + response.setHeader("cookie", "r=robot"); + assertThat(headersToList(response)).containsExactly("Cookies: delicious", "cookie: r=robot"); + } + + @Test public void mockResponseSetHeaders() { + MockResponse response = new MockResponse() + .clearHeaders() + .addHeader("Cookie: s=square") + .addHeader("Cookies: delicious"); + + response.setHeaders(new Headers.Builder().add("Cookie", "a=android").build()); + + assertThat(headersToList(response)).containsExactly("Cookie: a=android"); + } + + @Test public void regularResponse() throws Exception { + server.enqueue(new MockResponse().setBody("hello world")); + + URL url = server.url("/").url(); + HttpURLConnection connection = (HttpURLConnection) url.openConnection(); + connection.setRequestProperty("Accept-Language", "en-US"); + InputStream in = connection.getInputStream(); + BufferedReader reader = new BufferedReader(new InputStreamReader(in, UTF_8)); + assertThat(connection.getResponseCode()).isEqualTo(HttpURLConnection.HTTP_OK); + assertThat(reader.readLine()).isEqualTo("hello world"); + + RecordedRequest request = server.takeRequest(); + assertThat(request.getRequestLine()).isEqualTo("GET / HTTP/1.1"); + assertThat(request.getHeader("Accept-Language")).isEqualTo("en-US"); + + // Server has no more requests. + assertThat(server.takeRequest(100, MILLISECONDS)).isNull(); + } + + @Test public void redirect() throws Exception { + server.enqueue(new MockResponse() + .setResponseCode(HttpURLConnection.HTTP_MOVED_TEMP) + .addHeader("Location: " + server.url("/new-path")) + .setBody("This page has moved!")); + server.enqueue(new MockResponse().setBody("This is the new location!")); + + URLConnection connection = server.url("/").url().openConnection(); + InputStream in = connection.getInputStream(); + BufferedReader reader = new BufferedReader(new InputStreamReader(in, UTF_8)); + assertThat(reader.readLine()).isEqualTo("This is the new location!"); + + RecordedRequest first = server.takeRequest(); + assertThat(first.getRequestLine()).isEqualTo("GET / HTTP/1.1"); + RecordedRequest redirect = server.takeRequest(); + assertThat(redirect.getRequestLine()).isEqualTo("GET /new-path HTTP/1.1"); + } + + /** + * Test that MockWebServer blocks for a call to enqueue() if a request is made before a mock + * response is ready. + */ + @Test public void dispatchBlocksWaitingForEnqueue() throws Exception { + new Thread(() -> { + try { + Thread.sleep(1000); + } catch (InterruptedException ignored) { + } + server.enqueue(new MockResponse().setBody("enqueued in the background")); + }).start(); + + URLConnection connection = server.url("/").url().openConnection(); + InputStream in = connection.getInputStream(); + BufferedReader reader = new BufferedReader(new InputStreamReader(in, UTF_8)); + assertThat(reader.readLine()).isEqualTo("enqueued in the background"); + } + + @Test public void nonHexadecimalChunkSize() throws Exception { + server.enqueue(new MockResponse() + .setBody("G\r\nxxxxxxxxxxxxxxxx\r\n0\r\n\r\n") + .clearHeaders() + .addHeader("Transfer-encoding: chunked")); + + URLConnection connection = server.url("/").url().openConnection(); + InputStream in = connection.getInputStream(); + try { + in.read(); + fail(); + } catch (IOException expected) { + } + } + + @Test public void responseTimeout() throws Exception { + server.enqueue(new MockResponse() + .setBody("ABC") + .clearHeaders() + .addHeader("Content-Length: 4")); + server.enqueue(new MockResponse().setBody("DEF")); + + URLConnection urlConnection = server.url("/").url().openConnection(); + urlConnection.setReadTimeout(1000); + InputStream in = urlConnection.getInputStream(); + assertThat(in.read()).isEqualTo('A'); + assertThat(in.read()).isEqualTo('B'); + assertThat(in.read()).isEqualTo('C'); + try { + in.read(); // if Content-Length was accurate, this would return -1 immediately + fail(); + } catch (SocketTimeoutException expected) { + } + + URLConnection urlConnection2 = server.url("/").url().openConnection(); + InputStream in2 = urlConnection2.getInputStream(); + assertThat(in2.read()).isEqualTo('D'); + assertThat(in2.read()).isEqualTo('E'); + assertThat(in2.read()).isEqualTo('F'); + assertThat(in2.read()).isEqualTo(-1); + + assertThat(server.takeRequest().getSequenceNumber()).isEqualTo(0); + assertThat(server.takeRequest().getSequenceNumber()).isEqualTo(0); + } + + @Ignore("Not actually failing where expected") + @Test public void disconnectAtStart() throws Exception { + server.enqueue(new MockResponse().setSocketPolicy(SocketPolicy.DISCONNECT_AT_START)); + server.enqueue(new MockResponse()); // The jdk's HttpUrlConnection is a bastard. + server.enqueue(new MockResponse()); + try { + server.url("/a").url().openConnection().getInputStream(); + fail(); + } catch (IOException expected) { + } + server.url("/b").url().openConnection().getInputStream(); // Should succeed. + } + + /** + * Throttle the request body by sleeping 500ms after every 3 bytes. With a 6-byte request, this + * should yield one sleep for a total delay of 500ms. + */ + @Test public void throttleRequest() throws Exception { + TestUtil.assumeNotWindows(); + + server.enqueue(new MockResponse() + .throttleBody(3, 500, TimeUnit.MILLISECONDS)); + + long startNanos = System.nanoTime(); + URLConnection connection = server.url("/").url().openConnection(); + connection.setDoOutput(true); + connection.getOutputStream().write("ABCDEF".getBytes(UTF_8)); + InputStream in = connection.getInputStream(); + assertThat(in.read()).isEqualTo(-1); + long elapsedNanos = System.nanoTime() - startNanos; + long elapsedMillis = NANOSECONDS.toMillis(elapsedNanos); + assertThat(elapsedMillis).isBetween(500L, 1000L); + } + + /** + * Throttle the response body by sleeping 500ms after every 3 bytes. With a 6-byte response, this + * should yield one sleep for a total delay of 500ms. + */ + @Test public void throttleResponse() throws Exception { + TestUtil.assumeNotWindows(); + + server.enqueue(new MockResponse() + .setBody("ABCDEF") + .throttleBody(3, 500, TimeUnit.MILLISECONDS)); + + long startNanos = System.nanoTime(); + URLConnection connection = server.url("/").url().openConnection(); + InputStream in = connection.getInputStream(); + assertThat(in.read()).isEqualTo('A'); + assertThat(in.read()).isEqualTo('B'); + assertThat(in.read()).isEqualTo('C'); + assertThat(in.read()).isEqualTo('D'); + assertThat(in.read()).isEqualTo('E'); + assertThat(in.read()).isEqualTo('F'); + assertThat(in.read()).isEqualTo(-1); + long elapsedNanos = System.nanoTime() - startNanos; + long elapsedMillis = NANOSECONDS.toMillis(elapsedNanos); + assertThat(elapsedMillis).isBetween(500L, 1000L); + } + + /** Delay the response body by sleeping 1s. */ + @Test public void delayResponse() throws IOException { + TestUtil.assumeNotWindows(); + + server.enqueue(new MockResponse() + .setBody("ABCDEF") + .setBodyDelay(1, SECONDS)); + + long startNanos = System.nanoTime(); + URLConnection connection = server.url("/").url().openConnection(); + InputStream in = connection.getInputStream(); + assertThat(in.read()).isEqualTo('A'); + long elapsedNanos = System.nanoTime() - startNanos; + long elapsedMillis = NANOSECONDS.toMillis(elapsedNanos); + assertThat(elapsedMillis).isGreaterThanOrEqualTo(1000L); + + in.close(); + } + + @Test public void disconnectRequestHalfway() throws Exception { + server.enqueue(new MockResponse().setSocketPolicy(SocketPolicy.DISCONNECT_DURING_REQUEST_BODY)); + // Limit the size of the request body that the server holds in memory to an arbitrary + // 3.5 MBytes so this test can pass on devices with little memory. + server.setBodyLimit(7 * 512 * 1024); + + HttpURLConnection connection = (HttpURLConnection) server.url("/").url().openConnection(); + connection.setRequestMethod("POST"); + connection.setDoOutput(true); + connection.setFixedLengthStreamingMode(1024 * 1024 * 1024); // 1 GB + connection.connect(); + OutputStream out = connection.getOutputStream(); + + byte[] data = new byte[1024 * 1024]; + int i; + for (i = 0; i < 1024; i++) { + try { + out.write(data); + out.flush(); + if (i == 513) { + // pause slightly after half way to make result more predictable + Thread.sleep(100); + } + } catch (IOException e) { + break; + } + } + // Halfway +/- 0.5% + assertThat((float) i).isCloseTo(512f, offset(5f)); + } + + @Test public void disconnectResponseHalfway() throws IOException { + server.enqueue(new MockResponse() + .setBody("ab") + .setSocketPolicy(SocketPolicy.DISCONNECT_DURING_RESPONSE_BODY)); + + URLConnection connection = server.url("/").url().openConnection(); + assertThat(connection.getContentLength()).isEqualTo(2); + InputStream in = connection.getInputStream(); + assertThat(in.read()).isEqualTo('a'); + try { + int byteRead = in.read(); + // OpenJDK behavior: end of stream. + assertThat(byteRead).isEqualTo(-1); + } catch (ProtocolException e) { + // On Android, HttpURLConnection is implemented by OkHttp v2. OkHttp + // treats an incomplete response body as a ProtocolException. + } + } + + private List headersToList(MockResponse response) { + Headers headers = response.getHeaders(); + int size = headers.size(); + List headerList = new ArrayList<>(size); + for (int i = 0; i < size; i++) { + headerList.add(headers.name(i) + ": " + headers.value(i)); + } + return headerList; + } + + @Test public void shutdownWithoutStart() throws IOException { + MockWebServer server = new MockWebServer(); + server.shutdown(); + } + + @Test public void closeViaClosable() throws IOException { + Closeable server = new MockWebServer(); + server.close(); + } + + @Test public void shutdownWithoutEnqueue() throws IOException { + MockWebServer server = new MockWebServer(); + server.start(); + server.shutdown(); + } + + @Test public void portImplicitlyStarts() { + assertThat(server.getPort()).isGreaterThan(0); + } + + @Test public void hostnameImplicitlyStarts() { + assertThat(server.getHostName()).isNotNull(); + } + + @Test public void toProxyAddressImplicitlyStarts() { + assertThat(server.toProxyAddress()).isNotNull(); + } + + @Test public void differentInstancesGetDifferentPorts() throws IOException { + MockWebServer other = new MockWebServer(); + assertThat(other.getPort()).isNotEqualTo(server.getPort()); + other.shutdown(); + } + + @Test public void statementStartsAndStops() throws Throwable { + final AtomicBoolean called = new AtomicBoolean(); + Statement statement = server.apply(new Statement() { + @Override public void evaluate() throws Throwable { + called.set(true); + server.url("/").url().openConnection().connect(); + } + }, Description.EMPTY); + + statement.evaluate(); + + assertThat(called.get()).isTrue(); + try { + server.url("/").url().openConnection().connect(); + fail(); + } catch (ConnectException expected) { + } + } + + @Test public void shutdownWhileBlockedDispatching() throws Exception { + // Enqueue a request that'll cause MockWebServer to hang on QueueDispatcher.dispatch(). + HttpURLConnection connection = (HttpURLConnection) server.url("/").url().openConnection(); + connection.setReadTimeout(500); + try { + connection.getResponseCode(); + fail(); + } catch (SocketTimeoutException expected) { + } + + // Shutting down the server should unblock the dispatcher. + server.shutdown(); + } + + @Test public void requestUrlReconstructed() throws Exception { + server.enqueue(new MockResponse().setBody("hello world")); + + URL url = server.url("/a/deep/path?key=foo%20bar").url(); + HttpURLConnection connection = (HttpURLConnection) url.openConnection(); + InputStream in = connection.getInputStream(); + BufferedReader reader = new BufferedReader(new InputStreamReader(in, UTF_8)); + assertThat(connection.getResponseCode()).isEqualTo(HttpURLConnection.HTTP_OK); + assertThat(reader.readLine()).isEqualTo("hello world"); + + RecordedRequest request = server.takeRequest(); + assertThat(request.getRequestLine()).isEqualTo( + "GET /a/deep/path?key=foo%20bar HTTP/1.1"); + + HttpUrl requestUrl = request.getRequestUrl(); + assertThat(requestUrl.scheme()).isEqualTo("http"); + assertThat(requestUrl.host()).isEqualTo(server.getHostName()); + assertThat(requestUrl.port()).isEqualTo(server.getPort()); + assertThat(requestUrl.encodedPath()).isEqualTo("/a/deep/path"); + assertThat(requestUrl.queryParameter("key")).isEqualTo("foo bar"); + } + + @Test public void shutdownServerAfterRequest() throws Exception { + server.enqueue(new MockResponse().setSocketPolicy(SocketPolicy.SHUTDOWN_SERVER_AFTER_RESPONSE)); + + URL url = server.url("/").url(); + + HttpURLConnection connection = (HttpURLConnection) url.openConnection(); + assertThat(connection.getResponseCode()).isEqualTo(HttpURLConnection.HTTP_OK); + + HttpURLConnection refusedConnection = (HttpURLConnection) url.openConnection(); + + try { + refusedConnection.getResponseCode(); + fail("Second connection should be refused"); + } catch (ConnectException e) { + assertThat(e.getMessage()).contains("refused"); + } + } + + @Test public void http100Continue() throws Exception { + server.enqueue(new MockResponse().setBody("response")); + + URL url = server.url("/").url(); + HttpURLConnection connection = (HttpURLConnection) url.openConnection(); + connection.setDoOutput(true); + connection.setRequestProperty("Expect", "100-Continue"); + connection.getOutputStream().write("request".getBytes(UTF_8)); + + InputStream in = connection.getInputStream(); + BufferedReader reader = new BufferedReader(new InputStreamReader(in, UTF_8)); + assertThat(reader.readLine()).isEqualTo("response"); + + RecordedRequest request = server.takeRequest(); + assertThat(request.getBody().readUtf8()).isEqualTo("request"); + } + + @Test public void testH2PriorKnowledgeServerFallback() { + try { + server.setProtocols(asList(Protocol.H2_PRIOR_KNOWLEDGE, Protocol.HTTP_1_1)); + fail(); + } catch (IllegalArgumentException expected) { + assertThat(expected.getMessage()).isEqualTo( + ("protocols containing h2_prior_knowledge cannot use other protocols: " + + "[h2_prior_knowledge, http/1.1]")); + } + } + + @Test public void testH2PriorKnowledgeServerDuplicates() { + try { + // Treating this use case as user error + server.setProtocols(asList(Protocol.H2_PRIOR_KNOWLEDGE, Protocol.H2_PRIOR_KNOWLEDGE)); + fail(); + } catch (IllegalArgumentException expected) { + assertThat(expected.getMessage()).isEqualTo( + ("protocols containing h2_prior_knowledge cannot use other protocols: " + + "[h2_prior_knowledge, h2_prior_knowledge]")); + } + } + + @Test public void testMockWebServerH2PriorKnowledgeProtocol() { + server.setProtocols(asList(Protocol.H2_PRIOR_KNOWLEDGE)); + + assertThat(server.protocols().size()).isEqualTo(1); + assertThat(server.protocols().get(0)).isEqualTo(Protocol.H2_PRIOR_KNOWLEDGE); + } + + @Test public void https() throws Exception { + HandshakeCertificates handshakeCertificates = localhost(); + server.useHttps(handshakeCertificates.sslSocketFactory(), false); + server.enqueue(new MockResponse().setBody("abc")); + + HttpUrl url = server.url("/"); + HttpsURLConnection connection = (HttpsURLConnection) url.url().openConnection(); + connection.setSSLSocketFactory(handshakeCertificates.sslSocketFactory()); + connection.setHostnameVerifier(new RecordingHostnameVerifier()); + + assertThat(connection.getResponseCode()).isEqualTo(HttpURLConnection.HTTP_OK); + BufferedReader reader = + new BufferedReader(new InputStreamReader(connection.getInputStream(), UTF_8)); + assertThat(reader.readLine()).isEqualTo("abc"); + + RecordedRequest request = server.takeRequest(); + assertThat(request.getRequestUrl().scheme()).isEqualTo("https"); + Handshake handshake = request.getHandshake(); + assertThat(handshake.tlsVersion()).isNotNull(); + assertThat(handshake.cipherSuite()).isNotNull(); + assertThat(handshake.localPrincipal()).isNotNull(); + assertThat(handshake.localCertificates().size()).isEqualTo(1); + assertThat(handshake.peerPrincipal()).isNull(); + assertThat(handshake.peerCertificates().size()).isEqualTo(0); + } + + @Test public void httpsWithClientAuth() throws Exception { + assumeFalse(getPlatform().equals("conscrypt")); + + HeldCertificate clientCa = new HeldCertificate.Builder() + .certificateAuthority(0) + .build(); + HeldCertificate serverCa = new HeldCertificate.Builder() + .certificateAuthority(0) + .build(); + HeldCertificate serverCertificate = new HeldCertificate.Builder() + .signedBy(serverCa) + .addSubjectAlternativeName(server.getHostName()) + .build(); + HandshakeCertificates serverHandshakeCertificates = new HandshakeCertificates.Builder() + .addTrustedCertificate(clientCa.certificate()) + .heldCertificate(serverCertificate) + .build(); + + server.useHttps(serverHandshakeCertificates.sslSocketFactory(), false); + server.enqueue(new MockResponse().setBody("abc")); + server.requestClientAuth(); + + HeldCertificate clientCertificate = new HeldCertificate.Builder() + .signedBy(clientCa) + .build(); + HandshakeCertificates clientHandshakeCertificates = new HandshakeCertificates.Builder() + .addTrustedCertificate(serverCa.certificate()) + .heldCertificate(clientCertificate) + .build(); + + HttpUrl url = server.url("/"); + HttpsURLConnection connection = (HttpsURLConnection) url.url().openConnection(); + connection.setSSLSocketFactory(clientHandshakeCertificates.sslSocketFactory()); + connection.setHostnameVerifier(new RecordingHostnameVerifier()); + + assertThat(connection.getResponseCode()).isEqualTo(HttpURLConnection.HTTP_OK); + BufferedReader reader = + new BufferedReader(new InputStreamReader(connection.getInputStream(), UTF_8)); + assertThat(reader.readLine()).isEqualTo("abc"); + + RecordedRequest request = server.takeRequest(); + assertThat(request.getRequestUrl().scheme()).isEqualTo("https"); + Handshake handshake = request.getHandshake(); + assertThat(handshake.tlsVersion()).isNotNull(); + assertThat(handshake.cipherSuite()).isNotNull(); + assertThat(handshake.localPrincipal()).isNotNull(); + assertThat(handshake.localCertificates().size()).isEqualTo(1); + assertThat(handshake.peerPrincipal()).isNotNull(); + assertThat(handshake.peerCertificates().size()).isEqualTo(1); + } + + @Test + public void shutdownTwice() throws IOException { + MockWebServer server2 = new MockWebServer(); + + server2.start(); + server2.shutdown(); + try { + server2.start(); + fail(); + } catch (IllegalArgumentException iae) { + // expected + } + server2.shutdown(); + } + + public static String getPlatform() { + return System.getProperty("okhttp.platform", "jdk8"); + } +} diff --git a/settings.gradle b/settings.gradle index a90c9b8bc..1c8275806 100644 --- a/settings.gradle +++ b/settings.gradle @@ -1,6 +1,7 @@ rootProject.name = 'okhttp-parent' include ':mockwebserver' +include ':mockwebserverwrapper' if (properties.containsKey('android.injected.invoked.from.ide') || System.getenv('ANDROID_SDK_ROOT') != null) {