diff --git a/okhttp/src/main/java/okhttp3/Handshake.kt b/okhttp/src/main/java/okhttp3/Handshake.kt index 8fb6a8a47..2f4505742 100644 --- a/okhttp/src/main/java/okhttp3/Handshake.kt +++ b/okhttp/src/main/java/okhttp3/Handshake.kt @@ -42,8 +42,7 @@ class Handshake internal constructor( @get:JvmName("cipherSuite") val cipherSuite: CipherSuite, /** Returns a possibly-empty list of certificates that identify this peer. */ - @get:JvmName( - "localCertificates") val localCertificates: List, + @get:JvmName("localCertificates") val localCertificates: List, // Delayed provider of peerCertificates, to allow lazy cleaning. peerCertificatesFn: () -> List @@ -141,10 +140,12 @@ class Handshake internal constructor( @JvmName("get") fun SSLSession.handshake(): Handshake { val cipherSuiteString = checkNotNull(cipherSuite) { "cipherSuite == null" } - if ("SSL_NULL_WITH_NULL_NULL" == cipherSuiteString) { - throw IOException("cipherSuite == SSL_NULL_WITH_NULL_NULL") + val cipherSuite = when (cipherSuiteString) { + "TLS_NULL_WITH_NULL_NULL", "SSL_NULL_WITH_NULL_NULL" -> { + throw IOException("cipherSuite == $cipherSuiteString") + } + else -> CipherSuite.forJavaName(cipherSuiteString) } - val cipherSuite = CipherSuite.forJavaName(cipherSuiteString) val tlsVersionString = checkNotNull(protocol) { "tlsVersion == null" } if ("NONE" == tlsVersionString) throw IOException("tlsVersion == NONE") @@ -184,8 +185,9 @@ class Handshake internal constructor( localCertificates: List ): Handshake { val peerCertificatesCopy = peerCertificates.toImmutableList() - return Handshake(tlsVersion, cipherSuite, localCertificates.toImmutableList() - ) { peerCertificatesCopy } + return Handshake(tlsVersion, cipherSuite, localCertificates.toImmutableList()) { + peerCertificatesCopy + } } } -} \ No newline at end of file +} diff --git a/okhttp/src/test/java/okhttp3/DelegatingSSLSession.java b/okhttp/src/test/java/okhttp3/DelegatingSSLSession.java new file mode 100644 index 000000000..0e15e825b --- /dev/null +++ b/okhttp/src/test/java/okhttp3/DelegatingSSLSession.java @@ -0,0 +1,116 @@ +/* + * Copyright 2019 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3; + +import java.security.Principal; +import java.security.cert.Certificate; +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLSession; +import javax.net.ssl.SSLSessionContext; +import javax.security.cert.X509Certificate; + +/** An {@link SSLSession} that delegates all calls. */ +public abstract class DelegatingSSLSession implements SSLSession { + protected final SSLSession delegate; + + public DelegatingSSLSession(SSLSession delegate) { + this.delegate = delegate; + } + + @Override public byte[] getId() { + return delegate.getId(); + } + + @Override public SSLSessionContext getSessionContext() { + return delegate.getSessionContext(); + } + + @Override public long getCreationTime() { + return delegate.getCreationTime(); + } + + @Override public long getLastAccessedTime() { + return delegate.getLastAccessedTime(); + } + + @Override public void invalidate() { + delegate.invalidate(); + } + + @Override public boolean isValid() { + return delegate.isValid(); + } + + @Override public void putValue(String s, Object o) { + delegate.putValue(s, o); + } + + @Override public Object getValue(String s) { + return delegate.getValue(s); + } + + @Override public void removeValue(String s) { + delegate.removeValue(s); + } + + @Override public String[] getValueNames() { + return delegate.getValueNames(); + } + + @Override public Certificate[] getPeerCertificates() throws SSLPeerUnverifiedException { + return delegate.getPeerCertificates(); + } + + @Override public Certificate[] getLocalCertificates() { + return delegate.getLocalCertificates(); + } + + @Override public X509Certificate[] getPeerCertificateChain() throws SSLPeerUnverifiedException { + return delegate.getPeerCertificateChain(); + } + + @Override public Principal getPeerPrincipal() throws SSLPeerUnverifiedException { + return delegate.getPeerPrincipal(); + } + + @Override public Principal getLocalPrincipal() { + return delegate.getLocalPrincipal(); + } + + @Override public String getCipherSuite() { + return delegate.getCipherSuite(); + } + + @Override public String getProtocol() { + return delegate.getProtocol(); + } + + @Override public String getPeerHost() { + return delegate.getPeerHost(); + } + + @Override public int getPeerPort() { + return delegate.getPeerPort(); + } + + @Override public int getPacketBufferSize() { + return delegate.getPacketBufferSize(); + } + + @Override public int getApplicationBufferSize() { + return delegate.getApplicationBufferSize(); + } +} diff --git a/okhttp/src/test/java/okhttp3/DelegatingSSLSocket.java b/okhttp/src/test/java/okhttp3/DelegatingSSLSocket.java index 6647dfc4d..8a772e7a1 100644 --- a/okhttp/src/test/java/okhttp3/DelegatingSSLSocket.java +++ b/okhttp/src/test/java/okhttp3/DelegatingSSLSocket.java @@ -31,9 +31,7 @@ import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLSession; import javax.net.ssl.SSLSocket; -/** - * An {@link javax.net.ssl.SSLSocket} that delegates all calls. - */ +/** An {@link SSLSocket} that delegates all calls. */ public abstract class DelegatingSSLSocket extends SSLSocket { protected final SSLSocket delegate; diff --git a/okhttp/src/test/java/okhttp3/HandshakeTest.kt b/okhttp/src/test/java/okhttp3/HandshakeTest.kt new file mode 100644 index 000000000..80e6bcc03 --- /dev/null +++ b/okhttp/src/test/java/okhttp3/HandshakeTest.kt @@ -0,0 +1,126 @@ +/* + * Copyright (C) 2019 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3 + +import okhttp3.Handshake.Companion.handshake +import okhttp3.tls.HeldCertificate +import org.assertj.core.api.Assertions.assertThat +import org.junit.Assert.fail +import org.junit.Test +import java.io.IOException +import java.security.cert.Certificate + +class HandshakeTest { + val serverRoot = HeldCertificate.Builder() + .certificateAuthority(1) + .build() + val serverIntermediate = HeldCertificate.Builder() + .certificateAuthority(0) + .signedBy(serverRoot) + .build() + val serverCertificate = HeldCertificate.Builder() + .signedBy(serverIntermediate) + .build() + + @Test + fun createFromParts() { + val handshake = Handshake.get( + tlsVersion = TlsVersion.TLS_1_3, + cipherSuite = CipherSuite.TLS_AES_128_GCM_SHA256, + peerCertificates = listOf(serverCertificate.certificate, serverIntermediate.certificate), + localCertificates = listOf() + ) + + assertThat(handshake.tlsVersion).isEqualTo(TlsVersion.TLS_1_3) + assertThat(handshake.cipherSuite).isEqualTo(CipherSuite.TLS_AES_128_GCM_SHA256) + assertThat(handshake.peerCertificates).containsExactly( + serverCertificate.certificate, serverIntermediate.certificate) + assertThat(handshake.localPrincipal).isNull() + assertThat(handshake.peerPrincipal) + .isEqualTo(serverCertificate.certificate.subjectX500Principal) + assertThat(handshake.localCertificates).isEmpty() + } + + @Test + fun createFromSslSession() { + val sslSession = FakeSSLSession( + "TLSv1.3", + "TLS_AES_128_GCM_SHA256", + arrayOf(serverCertificate.certificate, serverIntermediate.certificate), + null + ) + + val handshake = sslSession.handshake() + + assertThat(handshake.tlsVersion).isEqualTo(TlsVersion.TLS_1_3) + assertThat(handshake.cipherSuite).isEqualTo(CipherSuite.TLS_AES_128_GCM_SHA256) + assertThat(handshake.peerCertificates).containsExactly( + serverCertificate.certificate, serverIntermediate.certificate) + assertThat(handshake.localPrincipal).isNull() + assertThat(handshake.peerPrincipal) + .isEqualTo(serverCertificate.certificate.subjectX500Principal) + assertThat(handshake.localCertificates).isEmpty() + } + + @Test + fun sslWithNullNullNull() { + val sslSession = FakeSSLSession( + "TLSv1.3", + "SSL_NULL_WITH_NULL_NULL", + arrayOf(serverCertificate.certificate, serverIntermediate.certificate), + null + ) + + try { + sslSession.handshake() + fail() + } catch (expected: IOException) { + assertThat(expected).hasMessage("cipherSuite == SSL_NULL_WITH_NULL_NULL") + } + } + + @Test + fun tlsWithNullNullNull() { + val sslSession = FakeSSLSession( + "TLSv1.3", + "TLS_NULL_WITH_NULL_NULL", + arrayOf(serverCertificate.certificate, serverIntermediate.certificate), + null + ) + + try { + sslSession.handshake() + fail() + } catch (expected: IOException) { + assertThat(expected).hasMessage("cipherSuite == TLS_NULL_WITH_NULL_NULL") + } + } + + class FakeSSLSession( + private val protocol: String, + private val cipherSuite: String, + private val peerCertificates: Array?, + private val localCertificates: Array? + ) : DelegatingSSLSession(null) { + override fun getProtocol() = protocol + + override fun getCipherSuite() = cipherSuite + + override fun getPeerCertificates() = peerCertificates + + override fun getLocalCertificates() = localCertificates + } +}