1
0
mirror of https://github.com/square/okhttp.git synced 2026-01-17 08:42:25 +03:00

Merge pull request #2285 from square/jwilson_0127_extract_trust_manager

Teach OkHttp to lookup the X509TrustManagerFactory.
This commit is contained in:
Jake Wharton
2016-01-27 01:00:35 -05:00
3 changed files with 123 additions and 42 deletions

View File

@@ -34,51 +34,43 @@ public class DelegatingSSLSocketFactory extends SSLSocketFactory {
this.delegate = delegate;
}
@Override
public SSLSocket createSocket() throws IOException {
@Override public SSLSocket createSocket() throws IOException {
SSLSocket sslSocket = (SSLSocket) delegate.createSocket();
return configureSocket(sslSocket);
}
@Override
public SSLSocket createSocket(String host, int port) throws IOException, UnknownHostException {
@Override public SSLSocket createSocket(String host, int port) throws IOException {
SSLSocket sslSocket = (SSLSocket) delegate.createSocket(host, port);
return configureSocket(sslSocket);
}
@Override
public SSLSocket createSocket(String host, int port, InetAddress localAddress, int localPort)
throws IOException, UnknownHostException {
@Override public SSLSocket createSocket(
String host, int port, InetAddress localAddress, int localPort) throws IOException {
SSLSocket sslSocket = (SSLSocket) delegate.createSocket(host, port, localAddress, localPort);
return configureSocket(sslSocket);
}
@Override
public SSLSocket createSocket(InetAddress host, int port) throws IOException {
@Override public SSLSocket createSocket(InetAddress host, int port) throws IOException {
SSLSocket sslSocket = (SSLSocket) delegate.createSocket(host, port);
return configureSocket(sslSocket);
}
@Override
public SSLSocket createSocket(InetAddress host, int port, InetAddress localAddress, int localPort)
throws IOException {
@Override public SSLSocket createSocket(
InetAddress host, int port, InetAddress localAddress, int localPort) throws IOException {
SSLSocket sslSocket = (SSLSocket) delegate.createSocket(host, port, localAddress, localPort);
return configureSocket(sslSocket);
}
@Override
public String[] getDefaultCipherSuites() {
@Override public String[] getDefaultCipherSuites() {
return delegate.getDefaultCipherSuites();
}
@Override
public String[] getSupportedCipherSuites() {
@Override public String[] getSupportedCipherSuites() {
return delegate.getSupportedCipherSuites();
}
@Override
public SSLSocket createSocket(Socket socket, String host, int port, boolean autoClose)
throws IOException {
@Override public SSLSocket createSocket(
Socket socket, String host, int port, boolean autoClose) throws IOException {
SSLSocket sslSocket = (SSLSocket) delegate.createSocket(socket, host, port, autoClose);
return configureSocket(sslSocket);
}

View File

@@ -28,8 +28,10 @@ import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.X509TrustManager;
import okhttp3.internal.Internal;
import okhttp3.internal.InternalCache;
import okhttp3.internal.Platform;
import okhttp3.internal.RouteDatabase;
import okhttp3.internal.Util;
import okhttp3.internal.http.StreamAllocation;
@@ -130,6 +132,7 @@ public final class OkHttpClient implements Cloneable, Call.Factory {
final InternalCache internalCache;
final SocketFactory socketFactory;
final SSLSocketFactory sslSocketFactory;
final X509TrustManager trustManager;
final HostnameVerifier hostnameVerifier;
final CertificatePinner certificatePinner;
final Authenticator proxyAuthenticator;
@@ -160,7 +163,7 @@ public final class OkHttpClient implements Cloneable, Call.Factory {
this.internalCache = builder.internalCache;
this.socketFactory = builder.socketFactory;
boolean isTLS = true;
boolean isTLS = false;
for (ConnectionSpec spec : connectionSpecs) {
isTLS = isTLS || spec.isTls();
}
@@ -176,6 +179,16 @@ public final class OkHttpClient implements Cloneable, Call.Factory {
throw new AssertionError(); // The system has no TLS. Just give up.
}
}
if (this.sslSocketFactory != null) {
this.trustManager = Platform.get().trustManager(sslSocketFactory);
if (trustManager == null) {
throw new IllegalStateException("Unable to extract the trust manager on " + Platform.get()
+ ", sslSocketFactory is " + sslSocketFactory.getClass());
}
} else {
this.trustManager = null;
}
this.hostnameVerifier = builder.hostnameVerifier;
this.certificatePinner = builder.certificatePinner;
this.proxyAuthenticator = builder.proxyAuthenticator;

View File

@@ -18,6 +18,7 @@ package okhttp3.internal;
import android.util.Log;
import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
@@ -29,6 +30,8 @@ import java.util.ArrayList;
import java.util.List;
import java.util.logging.Level;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.X509TrustManager;
import okhttp3.Protocol;
import okio.Buffer;
@@ -55,6 +58,11 @@ import static okhttp3.internal.Internal.logger;
* unstable.
*
* Supported on OpenJDK 7 and 8 (via the JettyALPN-boot library).
*
* <h3>Trust Manager Extraction</h3>
*
* <p>Supported on Android 2.3+ and OpenJDK 7+. There are no public APIs to recover the trust
* manager that was used to create an {@link SSLSocketFactory}.
*/
public class Platform {
private static final Platform PLATFORM = findPlatform();
@@ -78,6 +86,10 @@ public class Platform {
public void untagSocket(Socket socket) throws SocketException {
}
public X509TrustManager trustManager(SSLSocketFactory sslSocketFactory) {
return null;
}
/**
* Configure TLS extensions on {@code sslSocket} for {@code route}.
*
@@ -112,11 +124,13 @@ public class Platform {
private static Platform findPlatform() {
// Attempt to find Android 2.3+ APIs.
try {
Class<?> sslParametersClass;
try {
Class.forName("com.android.org.conscrypt.OpenSSLSocketImpl");
sslParametersClass = Class.forName("com.android.org.conscrypt.SSLParametersImpl");
} catch (ClassNotFoundException e) {
// Older platform before being unbundled.
Class.forName("org.apache.harmony.xnet.provider.jsse.OpenSSLSocketImpl");
sslParametersClass = Class.forName(
"org.apache.harmony.xnet.provider.jsse.SSLParametersImpl");
}
OptionalMethod<Socket> setUseSessionTickets
@@ -144,25 +158,34 @@ public class Platform {
} catch (ClassNotFoundException | NoSuchMethodException ignored) {
}
return new Android(setUseSessionTickets, setHostname, trafficStatsTagSocket,
trafficStatsUntagSocket, getAlpnSelectedProtocol, setAlpnProtocols);
return new Android(sslParametersClass, setUseSessionTickets, setHostname,
trafficStatsTagSocket, trafficStatsUntagSocket, getAlpnSelectedProtocol,
setAlpnProtocols);
} catch (ClassNotFoundException ignored) {
// This isn't an Android runtime.
}
// Find Jetty's ALPN extension for OpenJDK.
// Find an Oracle JDK.
try {
String negoClassName = "org.eclipse.jetty.alpn.ALPN";
Class<?> negoClass = Class.forName(negoClassName);
Class<?> providerClass = Class.forName(negoClassName + "$Provider");
Class<?> clientProviderClass = Class.forName(negoClassName + "$ClientProvider");
Class<?> serverProviderClass = Class.forName(negoClassName + "$ServerProvider");
Method putMethod = negoClass.getMethod("put", SSLSocket.class, providerClass);
Method getMethod = negoClass.getMethod("get", SSLSocket.class);
Method removeMethod = negoClass.getMethod("remove", SSLSocket.class);
return new JdkWithJettyBootPlatform(
putMethod, getMethod, removeMethod, clientProviderClass, serverProviderClass);
} catch (ClassNotFoundException | NoSuchMethodException ignored) {
Class<?> sslContextClass = Class.forName("sun.security.ssl.SSLContextImpl");
// Find Jetty's ALPN extension for OpenJDK.
try {
String negoClassName = "org.eclipse.jetty.alpn.ALPN";
Class<?> negoClass = Class.forName(negoClassName);
Class<?> providerClass = Class.forName(negoClassName + "$Provider");
Class<?> clientProviderClass = Class.forName(negoClassName + "$ClientProvider");
Class<?> serverProviderClass = Class.forName(negoClassName + "$ServerProvider");
Method putMethod = negoClass.getMethod("put", SSLSocket.class, providerClass);
Method getMethod = negoClass.getMethod("get", SSLSocket.class);
Method removeMethod = negoClass.getMethod("remove", SSLSocket.class);
return new JdkWithJettyBootPlatform(sslContextClass,
putMethod, getMethod, removeMethod, clientProviderClass, serverProviderClass);
} catch (ClassNotFoundException | NoSuchMethodException ignored) {
}
return new JdkPlatform(sslContextClass);
} catch (ClassNotFoundException ignored) {
}
return new Platform();
@@ -172,6 +195,7 @@ public class Platform {
private static class Android extends Platform {
private static final int MAX_LOG_LENGTH = 4000;
private final Class<?> sslParametersClass;
private final OptionalMethod<Socket> setUseSessionTickets;
private final OptionalMethod<Socket> setHostname;
@@ -183,9 +207,11 @@ public class Platform {
private final OptionalMethod<Socket> getAlpnSelectedProtocol;
private final OptionalMethod<Socket> setAlpnProtocols;
public Android(OptionalMethod<Socket> setUseSessionTickets, OptionalMethod<Socket> setHostname,
Method trafficStatsTagSocket, Method trafficStatsUntagSocket,
OptionalMethod<Socket> getAlpnSelectedProtocol, OptionalMethod<Socket> setAlpnProtocols) {
public Android(Class<?> sslParametersClass, OptionalMethod<Socket> setUseSessionTickets,
OptionalMethod<Socket> setHostname, Method trafficStatsTagSocket,
Method trafficStatsUntagSocket, OptionalMethod<Socket> getAlpnSelectedProtocol,
OptionalMethod<Socket> setAlpnProtocols) {
this.sslParametersClass = sslParametersClass;
this.setUseSessionTickets = setUseSessionTickets;
this.setHostname = setHostname;
this.trafficStatsTagSocket = trafficStatsTagSocket;
@@ -210,6 +236,17 @@ public class Platform {
}
}
@Override public X509TrustManager trustManager(SSLSocketFactory sslSocketFactory) {
Object context = readFieldOrNull(sslSocketFactory, sslParametersClass, "sslParameters");
if (context == null) return null;
X509TrustManager x509TrustManager = readFieldOrNull(
context, X509TrustManager.class, "x509TrustManager");
if (x509TrustManager != null) return x509TrustManager;
return readFieldOrNull(context, X509TrustManager.class, "trustManager");
}
@Override public void configureTlsExtensions(
SSLSocket sslSocket, String hostname, List<Protocol> protocols) {
// Enable SNI and session tickets.
@@ -271,18 +308,34 @@ public class Platform {
}
}
/** JDK 1.7 or better. */
private static class JdkPlatform extends Platform {
private final Class<?> sslContextClass;
public JdkPlatform(Class<?> sslContextClass) {
this.sslContextClass = sslContextClass;
}
@Override public X509TrustManager trustManager(SSLSocketFactory sslSocketFactory) {
Object context = readFieldOrNull(sslSocketFactory, sslContextClass, "context");
if (context == null) return null;
return readFieldOrNull(context, X509TrustManager.class, "trustManager");
}
}
/**
* OpenJDK 7+ with {@code org.mortbay.jetty.alpn/alpn-boot} in the boot class path.
*/
private static class JdkWithJettyBootPlatform extends Platform {
private static class JdkWithJettyBootPlatform extends JdkPlatform {
private final Method putMethod;
private final Method getMethod;
private final Method removeMethod;
private final Class<?> clientProviderClass;
private final Class<?> serverProviderClass;
public JdkWithJettyBootPlatform(Method putMethod, Method getMethod, Method removeMethod,
Class<?> clientProviderClass, Class<?> serverProviderClass) {
public JdkWithJettyBootPlatform(Class<?> sslContextClass, Method putMethod, Method getMethod,
Method removeMethod, Class<?> clientProviderClass, Class<?> serverProviderClass) {
super(sslContextClass);
this.putMethod = putMethod;
this.getMethod = getMethod;
this.removeMethod = removeMethod;
@@ -394,4 +447,27 @@ public class Platform {
}
return result.readByteArray();
}
static <T> T readFieldOrNull(Object instance, Class<T> fieldType, String fieldName) {
for (Class<?> c = instance.getClass(); c != Object.class; c = c.getSuperclass()) {
try {
Field field = c.getDeclaredField(fieldName);
field.setAccessible(true);
Object value = field.get(instance);
if (value == null || !fieldType.isInstance(value)) return null;
return fieldType.cast(value);
} catch (NoSuchFieldException ignored) {
} catch (IllegalAccessException e) {
throw new AssertionError();
}
}
// Didn't find the field we wanted. As a last gasp attempt, try to find the value on a delegate.
if (!fieldName.equals("delegate")) {
Object delegate = readFieldOrNull(instance, Object.class, "delegate");
if (delegate != null) return readFieldOrNull(delegate, fieldType, fieldName);
}
return null;
}
}