mirror of
https://github.com/square/okhttp.git
synced 2025-07-31 05:04:26 +03:00
Make separate IPv4 and IPv6 requests for DNS over HTTPS (#4234)
Make separate requests because DNS in practice does not support multiple questions A + AAAA in a single message.
This commit is contained in:
@ -18,11 +18,14 @@ package okhttp3.dnsoverhttps;
|
||||
import java.io.IOException;
|
||||
import java.net.InetAddress;
|
||||
import java.net.UnknownHostException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import javax.annotation.Nullable;
|
||||
import okhttp3.CacheControl;
|
||||
import okhttp3.Call;
|
||||
import okhttp3.Callback;
|
||||
import okhttp3.Dns;
|
||||
import okhttp3.HttpUrl;
|
||||
import okhttp3.MediaType;
|
||||
@ -32,6 +35,7 @@ import okhttp3.Request;
|
||||
import okhttp3.RequestBody;
|
||||
import okhttp3.Response;
|
||||
import okhttp3.ResponseBody;
|
||||
import okhttp3.internal.Util;
|
||||
import okhttp3.internal.platform.Platform;
|
||||
import okhttp3.internal.publicsuffix.PublicSuffixDatabase;
|
||||
import okio.ByteString;
|
||||
@ -113,8 +117,6 @@ public class DnsOverHttps implements Dns {
|
||||
}
|
||||
|
||||
@Override public List<InetAddress> lookup(String hostname) throws UnknownHostException {
|
||||
UnknownHostException firstUhe = null;
|
||||
|
||||
if (!resolvePrivateAddresses || !resolvePublicAddresses) {
|
||||
boolean privateHost = isPrivateHost(hostname);
|
||||
|
||||
@ -131,37 +133,117 @@ public class DnsOverHttps implements Dns {
|
||||
}
|
||||
|
||||
private List<InetAddress> lookupHttps(String hostname) throws UnknownHostException {
|
||||
try {
|
||||
ByteString query = DnsRecordCodec.encodeQuery(hostname, includeIPv6);
|
||||
List<Call> networkRequests = new ArrayList<>(2);
|
||||
List<Exception> failures = new ArrayList<>(2);
|
||||
List<InetAddress> results = new ArrayList<>(5);
|
||||
|
||||
Request request = buildRequest(query);
|
||||
Response response = executeRequest(request);
|
||||
buildRequest(hostname, networkRequests, results, failures, DnsRecordCodec.TYPE_A);
|
||||
|
||||
return readResponse(hostname, response);
|
||||
} catch (UnknownHostException uhe) {
|
||||
throw uhe;
|
||||
} catch (Exception e) {
|
||||
UnknownHostException unknownHostException = new UnknownHostException(hostname);
|
||||
unknownHostException.initCause(e);
|
||||
throw unknownHostException;
|
||||
if (includeIPv6) {
|
||||
buildRequest(hostname, networkRequests, results, failures, DnsRecordCodec.TYPE_AAAA);
|
||||
}
|
||||
|
||||
executeRequests(hostname, networkRequests, results, failures);
|
||||
|
||||
if (!results.isEmpty()) {
|
||||
return results;
|
||||
}
|
||||
|
||||
return throwBestFailure(hostname, failures);
|
||||
}
|
||||
|
||||
private void buildRequest(String hostname, List<Call> networkRequests, List<InetAddress> results,
|
||||
List<Exception> failures, int type) {
|
||||
Request request = buildRequest(hostname, type);
|
||||
Response response = getCacheOnlyResponse(request);
|
||||
|
||||
if (response != null) {
|
||||
processResponse(response, hostname, results, failures);
|
||||
} else {
|
||||
networkRequests.add(client.newCall(request));
|
||||
}
|
||||
}
|
||||
|
||||
private Response executeRequest(Request request) throws IOException {
|
||||
// cached request
|
||||
private void executeRequests(final String hostname, List<Call> networkRequests,
|
||||
final List<InetAddress> responses, final List<Exception> failures) {
|
||||
final CountDownLatch latch = new CountDownLatch(networkRequests.size());
|
||||
|
||||
for (Call call : networkRequests) {
|
||||
call.enqueue(new Callback() {
|
||||
@Override public void onFailure(Call call, IOException e) {
|
||||
synchronized (failures) {
|
||||
failures.add(e);
|
||||
}
|
||||
latch.countDown();
|
||||
}
|
||||
|
||||
@Override public void onResponse(Call call, Response response) {
|
||||
processResponse(response, hostname, responses, failures);
|
||||
latch.countDown();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
try {
|
||||
latch.await();
|
||||
} catch (InterruptedException e) {
|
||||
failures.add(e);
|
||||
}
|
||||
}
|
||||
|
||||
private void processResponse(Response response, String hostname, List<InetAddress> results,
|
||||
List<Exception> failures) {
|
||||
try {
|
||||
List<InetAddress> addresses = readResponse(hostname, response);
|
||||
synchronized (results) {
|
||||
results.addAll(addresses);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
synchronized (failures) {
|
||||
failures.add(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private List<InetAddress> throwBestFailure(String hostname, List<Exception> failures)
|
||||
throws UnknownHostException {
|
||||
if (failures.size() == 0) {
|
||||
throw new UnknownHostException(hostname);
|
||||
}
|
||||
|
||||
Exception failure = failures.get(0);
|
||||
|
||||
if (failure instanceof UnknownHostException) {
|
||||
throw (UnknownHostException) failure;
|
||||
}
|
||||
|
||||
UnknownHostException unknownHostException = new UnknownHostException(hostname);
|
||||
unknownHostException.initCause(failure);
|
||||
|
||||
for (int i = 1; i < failures.size(); i++) {
|
||||
Util.addSuppressedIfPossible(unknownHostException, failures.get(i));
|
||||
}
|
||||
|
||||
throw unknownHostException;
|
||||
}
|
||||
|
||||
private @Nullable Response getCacheOnlyResponse(Request request) {
|
||||
if (!post && client.cache() != null) {
|
||||
CacheControl cacheControl =
|
||||
new CacheControl.Builder().maxStale(Integer.MAX_VALUE, TimeUnit.SECONDS).build();
|
||||
Request cacheRequest = request.newBuilder().cacheControl(cacheControl).build();
|
||||
try {
|
||||
Request cacheRequest = request.newBuilder().cacheControl(CacheControl.FORCE_CACHE).build();
|
||||
|
||||
Response response = client.newCall(cacheRequest).execute();
|
||||
Response cacheResponse = client.newCall(cacheRequest).execute();
|
||||
|
||||
if (response.isSuccessful()) {
|
||||
return response;
|
||||
if (cacheResponse.code() != 504) {
|
||||
return cacheResponse;
|
||||
}
|
||||
} catch (IOException ioe) {
|
||||
// Failures are ignored as we can fallback to the network
|
||||
// and hopefully repopulate the cache.
|
||||
}
|
||||
}
|
||||
|
||||
return client.newCall(request).execute();
|
||||
return null;
|
||||
}
|
||||
|
||||
private List<InetAddress> readResponse(String hostname, Response response) throws Exception {
|
||||
@ -192,9 +274,11 @@ public class DnsOverHttps implements Dns {
|
||||
}
|
||||
}
|
||||
|
||||
private Request buildRequest(ByteString query) {
|
||||
private Request buildRequest(String hostname, int type) {
|
||||
Request.Builder requestBuilder = new Request.Builder().header("Accept", DNS_MESSAGE.toString());
|
||||
|
||||
ByteString query = DnsRecordCodec.encodeQuery(hostname, type);
|
||||
|
||||
if (post) {
|
||||
requestBuilder = requestBuilder.url(url).post(RequestBody.create(DNS_MESSAGE, query));
|
||||
} else {
|
||||
@ -216,12 +300,14 @@ public class DnsOverHttps implements Dns {
|
||||
@Nullable HttpUrl url = null;
|
||||
boolean includeIPv6 = true;
|
||||
boolean post = false;
|
||||
MediaType contentType = DNS_MESSAGE;
|
||||
Dns systemDns = Dns.SYSTEM;
|
||||
@Nullable List<InetAddress> bootstrapDnsHosts = null;
|
||||
boolean resolvePrivateAddresses = false;
|
||||
boolean resolvePublicAddresses = true;
|
||||
|
||||
public Builder() {
|
||||
}
|
||||
|
||||
public DnsOverHttps build() {
|
||||
return new DnsOverHttps(this);
|
||||
}
|
||||
@ -246,11 +332,6 @@ public class DnsOverHttps implements Dns {
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder contentType(MediaType contentType) {
|
||||
this.contentType = contentType;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder resolvePrivateAddresses(boolean resolvePrivateAddresses) {
|
||||
this.resolvePrivateAddresses = resolvePrivateAddresses;
|
||||
return this;
|
||||
|
@ -31,20 +31,20 @@ import okio.Utf8;
|
||||
class DnsRecordCodec {
|
||||
private static final byte SERVFAIL = 2;
|
||||
private static final byte NXDOMAIN = 3;
|
||||
private static final int TYPE_A = 0x0001;
|
||||
private static final int TYPE_AAAA = 0x001c;
|
||||
public static final int TYPE_A = 0x0001;
|
||||
public static final int TYPE_AAAA = 0x001c;
|
||||
private static final int TYPE_PTR = 0x000c;
|
||||
private static final Charset ASCII = Charset.forName("ASCII");
|
||||
|
||||
private DnsRecordCodec() {
|
||||
}
|
||||
|
||||
public static ByteString encodeQuery(String host, boolean includeIPv6) {
|
||||
public static ByteString encodeQuery(String host, int type) {
|
||||
Buffer buf = new Buffer();
|
||||
|
||||
buf.writeShort(0); // query id
|
||||
buf.writeShort(256); // flags with recursion
|
||||
buf.writeShort(includeIPv6 ? 2 : 1); // question count
|
||||
buf.writeShort(1); // question count
|
||||
buf.writeShort(0); // answerCount
|
||||
buf.writeShort(0); // authorityResourceCount
|
||||
buf.writeShort(0); // additional
|
||||
@ -62,15 +62,9 @@ class DnsRecordCodec {
|
||||
nameBuf.writeByte(0); // end
|
||||
|
||||
nameBuf.copyTo(buf, 0, nameBuf.size());
|
||||
buf.writeShort(TYPE_A);
|
||||
buf.writeShort(type);
|
||||
buf.writeShort(1); // CLASS_IN
|
||||
|
||||
if (includeIPv6) {
|
||||
nameBuf.copyTo(buf, 0, nameBuf.size());
|
||||
buf.writeShort(TYPE_AAAA);
|
||||
buf.writeShort(1); // CLASS_IN
|
||||
}
|
||||
|
||||
return buf.readByteString();
|
||||
}
|
||||
|
||||
|
@ -20,6 +20,8 @@ import java.io.IOException;
|
||||
import java.net.InetAddress;
|
||||
import java.net.UnknownHostException;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashSet;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.List;
|
||||
import okhttp3.Cache;
|
||||
import okhttp3.Dns;
|
||||
@ -35,6 +37,7 @@ import org.junit.Before;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
|
||||
import static java.util.Arrays.asList;
|
||||
import static java.util.Collections.singletonList;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
@ -44,9 +47,8 @@ public class DnsOverHttpsTest {
|
||||
@Rule public final MockWebServer server = new MockWebServer();
|
||||
|
||||
private final OkHttpClient bootstrapClient =
|
||||
new OkHttpClient.Builder().protocols(Arrays.asList(Protocol.HTTP_2, Protocol.HTTP_1_1))
|
||||
.build();
|
||||
private final Dns dns = buildLocalhost(bootstrapClient);
|
||||
new OkHttpClient.Builder().protocols(asList(Protocol.HTTP_2, Protocol.HTTP_1_1)).build();
|
||||
private Dns dns = buildLocalhost(bootstrapClient, false);
|
||||
|
||||
@Before public void setUp() {
|
||||
server.setProtocols(bootstrapClient.protocols());
|
||||
@ -64,24 +66,38 @@ public class DnsOverHttpsTest {
|
||||
|
||||
RecordedRequest recordedRequest = server.takeRequest();
|
||||
assertEquals("GET", recordedRequest.getMethod());
|
||||
assertEquals("/lookup?ct&dns=AAABAAACAAAAAAAABmdvb2dsZQNjb20AAAEAAQZnb29nbGUDY29t"
|
||||
+ "AAAcAAE", recordedRequest.getPath());
|
||||
assertEquals("/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AAAEAAQ",
|
||||
recordedRequest.getPath());
|
||||
}
|
||||
|
||||
@Test public void getIpv6() throws Exception {
|
||||
server.enqueue(dnsResponse(
|
||||
"0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c00050001"
|
||||
+ "00000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010"
|
||||
+ "0010000003b00049df00112"));
|
||||
server.enqueue(dnsResponse(
|
||||
"0000818000010003000000000567726170680866616365626f6f6b03636f6d00001c0001c00c00050001"
|
||||
+ "00000a1b000603617069c012c0300005000100000b1f000c04737461720463313072c012c042001c0"
|
||||
+ "0010000003b00102a032880f0290011faceb00c00000002"));
|
||||
|
||||
dns = buildLocalhost(bootstrapClient, true);
|
||||
|
||||
List<InetAddress> result = dns.lookup("google.com");
|
||||
|
||||
assertEquals(singletonList(address("2a03:2880:f029:11:face:b00c:0:2")), result);
|
||||
assertEquals(2, result.size());
|
||||
assertTrue(result.contains(address("157.240.1.18")));
|
||||
assertTrue(result.contains(address("2a03:2880:f029:11:face:b00c:0:2")));
|
||||
|
||||
RecordedRequest recordedRequest = server.takeRequest();
|
||||
assertEquals("GET", recordedRequest.getMethod());
|
||||
assertEquals("/lookup?ct&dns=AAABAAACAAAAAAAABmdvb2dsZQNjb20AAAEAAQZnb29nbGUDY29t"
|
||||
+ "AAAcAAE", recordedRequest.getPath());
|
||||
RecordedRequest request1 = server.takeRequest();
|
||||
assertEquals("GET", request1.getMethod());
|
||||
|
||||
RecordedRequest request2 = server.takeRequest();
|
||||
assertEquals("GET", request2.getMethod());
|
||||
|
||||
assertEquals(new HashSet<>(
|
||||
Arrays.asList("/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AAAEAAQ",
|
||||
"/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AABwAAQ")),
|
||||
new LinkedHashSet<>(Arrays.asList(request1.getPath(), request2.getPath())));
|
||||
}
|
||||
|
||||
@Test public void failure() throws Exception {
|
||||
@ -100,8 +116,8 @@ public class DnsOverHttpsTest {
|
||||
|
||||
RecordedRequest recordedRequest = server.takeRequest();
|
||||
assertEquals("GET", recordedRequest.getMethod());
|
||||
assertEquals("/lookup?ct&dns=AAABAAACAAAAAAAABmdvb2dsZQNjb20AAAEAAQZnb29nbGUDY29t"
|
||||
+ "AAAcAAE", recordedRequest.getPath());
|
||||
assertEquals("/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AAAEAAQ",
|
||||
recordedRequest.getPath());
|
||||
}
|
||||
|
||||
@Test public void failOnExcessiveResponse() {
|
||||
@ -145,13 +161,12 @@ public class DnsOverHttpsTest {
|
||||
@Test public void usesCache() throws Exception {
|
||||
Cache cache = new Cache(new File("./target/DnsOverHttpsTest.cache"), 100 * 1024);
|
||||
OkHttpClient cachedClient = bootstrapClient.newBuilder().cache(cache).build();
|
||||
DnsOverHttps cachedDns = buildLocalhost(cachedClient);
|
||||
DnsOverHttps cachedDns = buildLocalhost(cachedClient, false);
|
||||
|
||||
server.enqueue(dnsResponse(
|
||||
"0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c00050001"
|
||||
+ "00000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010"
|
||||
+ "0010000003b00049df00112")
|
||||
.setHeader("cache-control", "private, max-age=298"));
|
||||
+ "0010000003b00049df00112").setHeader("cache-control", "private, max-age=298"));
|
||||
|
||||
List<InetAddress> result = cachedDns.lookup("google.com");
|
||||
|
||||
@ -159,23 +174,26 @@ public class DnsOverHttpsTest {
|
||||
|
||||
RecordedRequest recordedRequest = server.takeRequest();
|
||||
assertEquals("GET", recordedRequest.getMethod());
|
||||
assertEquals("/lookup?ct&dns=AAABAAACAAAAAAAABmdvb2dsZQNjb20AAAEAAQZnb29nbGUDY29t"
|
||||
+ "AAAcAAE", recordedRequest.getPath());
|
||||
assertEquals("/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AAAEAAQ",
|
||||
recordedRequest.getPath());
|
||||
|
||||
result = cachedDns.lookup("google.com");
|
||||
assertEquals(singletonList(address("157.240.1.18")), result);
|
||||
}
|
||||
|
||||
private MockResponse dnsResponse(String s) {
|
||||
return new MockResponse()
|
||||
.setBody(new Buffer().write(ByteString.decodeHex(s)))
|
||||
return new MockResponse().setBody(new Buffer().write(ByteString.decodeHex(s)))
|
||||
.addHeader("content-type", "application/dns-message")
|
||||
.addHeader("content-length", s.length() / 2);
|
||||
}
|
||||
|
||||
private DnsOverHttps buildLocalhost(OkHttpClient bootstrapClient) {
|
||||
private DnsOverHttps buildLocalhost(OkHttpClient bootstrapClient, boolean includeIPv6) {
|
||||
HttpUrl url = server.url("/lookup?ct");
|
||||
return new DnsOverHttps.Builder().client(bootstrapClient).resolvePrivateAddresses(true).url(url).build();
|
||||
return new DnsOverHttps.Builder().client(bootstrapClient)
|
||||
.includeIPv6(includeIPv6)
|
||||
.resolvePrivateAddresses(true)
|
||||
.url(url)
|
||||
.build();
|
||||
}
|
||||
|
||||
private static InetAddress address(String host) {
|
||||
|
@ -22,24 +22,26 @@ import java.util.List;
|
||||
import okio.ByteString;
|
||||
import org.junit.Test;
|
||||
|
||||
import static okhttp3.dnsoverhttps.DnsRecordCodec.TYPE_A;
|
||||
import static okhttp3.dnsoverhttps.DnsRecordCodec.TYPE_AAAA;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.fail;
|
||||
|
||||
public class DnsRecordCodecTest {
|
||||
@Test public void testGoogleDotComEncoding() {
|
||||
String encoded = encodeQuery("google.com", false);
|
||||
String encoded = encodeQuery("google.com", TYPE_A);
|
||||
|
||||
assertEquals("AAABAAABAAAAAAAABmdvb2dsZQNjb20AAAEAAQ", encoded);
|
||||
}
|
||||
|
||||
private String encodeQuery(String host, boolean includeIpv6) {
|
||||
return DnsRecordCodec.encodeQuery(host, includeIpv6).base64Url().replace("=", "");
|
||||
private String encodeQuery(String host, int type) {
|
||||
return DnsRecordCodec.encodeQuery(host, type).base64Url().replace("=", "");
|
||||
}
|
||||
|
||||
@Test public void testGoogleDotComEncodingWithIPv6() {
|
||||
String encoded = encodeQuery("google.com", true);
|
||||
String encoded = encodeQuery("google.com", TYPE_AAAA);
|
||||
|
||||
assertEquals("AAABAAACAAAAAAAABmdvb2dsZQNjb20AAAEAAQZnb29nbGUDY29tAAAcAAE", encoded);
|
||||
assertEquals("AAABAAABAAAAAAAABmdvb2dsZQNjb20AABwAAQ", encoded);
|
||||
}
|
||||
|
||||
@Test public void testGoogleDotComDecodingFromCloudflare() throws Exception {
|
||||
|
Reference in New Issue
Block a user