1
0
mirror of https://github.com/square/okhttp.git synced 2025-07-31 05:04:26 +03:00

Make separate IPv4 and IPv6 requests for DNS over HTTPS (#4234)

Make separate requests because DNS in practice does not support multiple questions A + AAAA in a single message.
This commit is contained in:
Yuri Schimke
2018-09-02 19:27:32 +01:00
committed by GitHub
parent 077281796e
commit 0b2486f7f3
4 changed files with 162 additions and 67 deletions

View File

@ -18,11 +18,14 @@ package okhttp3.dnsoverhttps;
import java.io.IOException;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.CountDownLatch;
import javax.annotation.Nullable;
import okhttp3.CacheControl;
import okhttp3.Call;
import okhttp3.Callback;
import okhttp3.Dns;
import okhttp3.HttpUrl;
import okhttp3.MediaType;
@ -32,6 +35,7 @@ import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import okhttp3.ResponseBody;
import okhttp3.internal.Util;
import okhttp3.internal.platform.Platform;
import okhttp3.internal.publicsuffix.PublicSuffixDatabase;
import okio.ByteString;
@ -113,8 +117,6 @@ public class DnsOverHttps implements Dns {
}
@Override public List<InetAddress> lookup(String hostname) throws UnknownHostException {
UnknownHostException firstUhe = null;
if (!resolvePrivateAddresses || !resolvePublicAddresses) {
boolean privateHost = isPrivateHost(hostname);
@ -131,37 +133,117 @@ public class DnsOverHttps implements Dns {
}
private List<InetAddress> lookupHttps(String hostname) throws UnknownHostException {
try {
ByteString query = DnsRecordCodec.encodeQuery(hostname, includeIPv6);
List<Call> networkRequests = new ArrayList<>(2);
List<Exception> failures = new ArrayList<>(2);
List<InetAddress> results = new ArrayList<>(5);
Request request = buildRequest(query);
Response response = executeRequest(request);
buildRequest(hostname, networkRequests, results, failures, DnsRecordCodec.TYPE_A);
return readResponse(hostname, response);
} catch (UnknownHostException uhe) {
throw uhe;
} catch (Exception e) {
UnknownHostException unknownHostException = new UnknownHostException(hostname);
unknownHostException.initCause(e);
throw unknownHostException;
if (includeIPv6) {
buildRequest(hostname, networkRequests, results, failures, DnsRecordCodec.TYPE_AAAA);
}
executeRequests(hostname, networkRequests, results, failures);
if (!results.isEmpty()) {
return results;
}
return throwBestFailure(hostname, failures);
}
private void buildRequest(String hostname, List<Call> networkRequests, List<InetAddress> results,
List<Exception> failures, int type) {
Request request = buildRequest(hostname, type);
Response response = getCacheOnlyResponse(request);
if (response != null) {
processResponse(response, hostname, results, failures);
} else {
networkRequests.add(client.newCall(request));
}
}
private Response executeRequest(Request request) throws IOException {
// cached request
private void executeRequests(final String hostname, List<Call> networkRequests,
final List<InetAddress> responses, final List<Exception> failures) {
final CountDownLatch latch = new CountDownLatch(networkRequests.size());
for (Call call : networkRequests) {
call.enqueue(new Callback() {
@Override public void onFailure(Call call, IOException e) {
synchronized (failures) {
failures.add(e);
}
latch.countDown();
}
@Override public void onResponse(Call call, Response response) {
processResponse(response, hostname, responses, failures);
latch.countDown();
}
});
}
try {
latch.await();
} catch (InterruptedException e) {
failures.add(e);
}
}
private void processResponse(Response response, String hostname, List<InetAddress> results,
List<Exception> failures) {
try {
List<InetAddress> addresses = readResponse(hostname, response);
synchronized (results) {
results.addAll(addresses);
}
} catch (Exception e) {
synchronized (failures) {
failures.add(e);
}
}
}
private List<InetAddress> throwBestFailure(String hostname, List<Exception> failures)
throws UnknownHostException {
if (failures.size() == 0) {
throw new UnknownHostException(hostname);
}
Exception failure = failures.get(0);
if (failure instanceof UnknownHostException) {
throw (UnknownHostException) failure;
}
UnknownHostException unknownHostException = new UnknownHostException(hostname);
unknownHostException.initCause(failure);
for (int i = 1; i < failures.size(); i++) {
Util.addSuppressedIfPossible(unknownHostException, failures.get(i));
}
throw unknownHostException;
}
private @Nullable Response getCacheOnlyResponse(Request request) {
if (!post && client.cache() != null) {
CacheControl cacheControl =
new CacheControl.Builder().maxStale(Integer.MAX_VALUE, TimeUnit.SECONDS).build();
Request cacheRequest = request.newBuilder().cacheControl(cacheControl).build();
try {
Request cacheRequest = request.newBuilder().cacheControl(CacheControl.FORCE_CACHE).build();
Response response = client.newCall(cacheRequest).execute();
Response cacheResponse = client.newCall(cacheRequest).execute();
if (response.isSuccessful()) {
return response;
if (cacheResponse.code() != 504) {
return cacheResponse;
}
} catch (IOException ioe) {
// Failures are ignored as we can fallback to the network
// and hopefully repopulate the cache.
}
}
return client.newCall(request).execute();
return null;
}
private List<InetAddress> readResponse(String hostname, Response response) throws Exception {
@ -192,9 +274,11 @@ public class DnsOverHttps implements Dns {
}
}
private Request buildRequest(ByteString query) {
private Request buildRequest(String hostname, int type) {
Request.Builder requestBuilder = new Request.Builder().header("Accept", DNS_MESSAGE.toString());
ByteString query = DnsRecordCodec.encodeQuery(hostname, type);
if (post) {
requestBuilder = requestBuilder.url(url).post(RequestBody.create(DNS_MESSAGE, query));
} else {
@ -216,12 +300,14 @@ public class DnsOverHttps implements Dns {
@Nullable HttpUrl url = null;
boolean includeIPv6 = true;
boolean post = false;
MediaType contentType = DNS_MESSAGE;
Dns systemDns = Dns.SYSTEM;
@Nullable List<InetAddress> bootstrapDnsHosts = null;
boolean resolvePrivateAddresses = false;
boolean resolvePublicAddresses = true;
public Builder() {
}
public DnsOverHttps build() {
return new DnsOverHttps(this);
}
@ -246,11 +332,6 @@ public class DnsOverHttps implements Dns {
return this;
}
public Builder contentType(MediaType contentType) {
this.contentType = contentType;
return this;
}
public Builder resolvePrivateAddresses(boolean resolvePrivateAddresses) {
this.resolvePrivateAddresses = resolvePrivateAddresses;
return this;

View File

@ -31,20 +31,20 @@ import okio.Utf8;
class DnsRecordCodec {
private static final byte SERVFAIL = 2;
private static final byte NXDOMAIN = 3;
private static final int TYPE_A = 0x0001;
private static final int TYPE_AAAA = 0x001c;
public static final int TYPE_A = 0x0001;
public static final int TYPE_AAAA = 0x001c;
private static final int TYPE_PTR = 0x000c;
private static final Charset ASCII = Charset.forName("ASCII");
private DnsRecordCodec() {
}
public static ByteString encodeQuery(String host, boolean includeIPv6) {
public static ByteString encodeQuery(String host, int type) {
Buffer buf = new Buffer();
buf.writeShort(0); // query id
buf.writeShort(256); // flags with recursion
buf.writeShort(includeIPv6 ? 2 : 1); // question count
buf.writeShort(1); // question count
buf.writeShort(0); // answerCount
buf.writeShort(0); // authorityResourceCount
buf.writeShort(0); // additional
@ -62,15 +62,9 @@ class DnsRecordCodec {
nameBuf.writeByte(0); // end
nameBuf.copyTo(buf, 0, nameBuf.size());
buf.writeShort(TYPE_A);
buf.writeShort(type);
buf.writeShort(1); // CLASS_IN
if (includeIPv6) {
nameBuf.copyTo(buf, 0, nameBuf.size());
buf.writeShort(TYPE_AAAA);
buf.writeShort(1); // CLASS_IN
}
return buf.readByteString();
}

View File

@ -20,6 +20,8 @@ import java.io.IOException;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import okhttp3.Cache;
import okhttp3.Dns;
@ -35,6 +37,7 @@ import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
@ -44,9 +47,8 @@ public class DnsOverHttpsTest {
@Rule public final MockWebServer server = new MockWebServer();
private final OkHttpClient bootstrapClient =
new OkHttpClient.Builder().protocols(Arrays.asList(Protocol.HTTP_2, Protocol.HTTP_1_1))
.build();
private final Dns dns = buildLocalhost(bootstrapClient);
new OkHttpClient.Builder().protocols(asList(Protocol.HTTP_2, Protocol.HTTP_1_1)).build();
private Dns dns = buildLocalhost(bootstrapClient, false);
@Before public void setUp() {
server.setProtocols(bootstrapClient.protocols());
@ -64,24 +66,38 @@ public class DnsOverHttpsTest {
RecordedRequest recordedRequest = server.takeRequest();
assertEquals("GET", recordedRequest.getMethod());
assertEquals("/lookup?ct&dns=AAABAAACAAAAAAAABmdvb2dsZQNjb20AAAEAAQZnb29nbGUDY29t"
+ "AAAcAAE", recordedRequest.getPath());
assertEquals("/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AAAEAAQ",
recordedRequest.getPath());
}
@Test public void getIpv6() throws Exception {
server.enqueue(dnsResponse(
"0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c00050001"
+ "00000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010"
+ "0010000003b00049df00112"));
server.enqueue(dnsResponse(
"0000818000010003000000000567726170680866616365626f6f6b03636f6d00001c0001c00c00050001"
+ "00000a1b000603617069c012c0300005000100000b1f000c04737461720463313072c012c042001c0"
+ "0010000003b00102a032880f0290011faceb00c00000002"));
dns = buildLocalhost(bootstrapClient, true);
List<InetAddress> result = dns.lookup("google.com");
assertEquals(singletonList(address("2a03:2880:f029:11:face:b00c:0:2")), result);
assertEquals(2, result.size());
assertTrue(result.contains(address("157.240.1.18")));
assertTrue(result.contains(address("2a03:2880:f029:11:face:b00c:0:2")));
RecordedRequest recordedRequest = server.takeRequest();
assertEquals("GET", recordedRequest.getMethod());
assertEquals("/lookup?ct&dns=AAABAAACAAAAAAAABmdvb2dsZQNjb20AAAEAAQZnb29nbGUDY29t"
+ "AAAcAAE", recordedRequest.getPath());
RecordedRequest request1 = server.takeRequest();
assertEquals("GET", request1.getMethod());
RecordedRequest request2 = server.takeRequest();
assertEquals("GET", request2.getMethod());
assertEquals(new HashSet<>(
Arrays.asList("/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AAAEAAQ",
"/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AABwAAQ")),
new LinkedHashSet<>(Arrays.asList(request1.getPath(), request2.getPath())));
}
@Test public void failure() throws Exception {
@ -100,8 +116,8 @@ public class DnsOverHttpsTest {
RecordedRequest recordedRequest = server.takeRequest();
assertEquals("GET", recordedRequest.getMethod());
assertEquals("/lookup?ct&dns=AAABAAACAAAAAAAABmdvb2dsZQNjb20AAAEAAQZnb29nbGUDY29t"
+ "AAAcAAE", recordedRequest.getPath());
assertEquals("/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AAAEAAQ",
recordedRequest.getPath());
}
@Test public void failOnExcessiveResponse() {
@ -145,13 +161,12 @@ public class DnsOverHttpsTest {
@Test public void usesCache() throws Exception {
Cache cache = new Cache(new File("./target/DnsOverHttpsTest.cache"), 100 * 1024);
OkHttpClient cachedClient = bootstrapClient.newBuilder().cache(cache).build();
DnsOverHttps cachedDns = buildLocalhost(cachedClient);
DnsOverHttps cachedDns = buildLocalhost(cachedClient, false);
server.enqueue(dnsResponse(
"0000818000010003000000000567726170680866616365626f6f6b03636f6d0000010001c00c00050001"
+ "00000a6d000603617069c012c0300005000100000cde000c04737461720463313072c012c04200010"
+ "0010000003b00049df00112")
.setHeader("cache-control", "private, max-age=298"));
+ "0010000003b00049df00112").setHeader("cache-control", "private, max-age=298"));
List<InetAddress> result = cachedDns.lookup("google.com");
@ -159,23 +174,26 @@ public class DnsOverHttpsTest {
RecordedRequest recordedRequest = server.takeRequest();
assertEquals("GET", recordedRequest.getMethod());
assertEquals("/lookup?ct&dns=AAABAAACAAAAAAAABmdvb2dsZQNjb20AAAEAAQZnb29nbGUDY29t"
+ "AAAcAAE", recordedRequest.getPath());
assertEquals("/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AAAEAAQ",
recordedRequest.getPath());
result = cachedDns.lookup("google.com");
assertEquals(singletonList(address("157.240.1.18")), result);
}
private MockResponse dnsResponse(String s) {
return new MockResponse()
.setBody(new Buffer().write(ByteString.decodeHex(s)))
return new MockResponse().setBody(new Buffer().write(ByteString.decodeHex(s)))
.addHeader("content-type", "application/dns-message")
.addHeader("content-length", s.length() / 2);
}
private DnsOverHttps buildLocalhost(OkHttpClient bootstrapClient) {
private DnsOverHttps buildLocalhost(OkHttpClient bootstrapClient, boolean includeIPv6) {
HttpUrl url = server.url("/lookup?ct");
return new DnsOverHttps.Builder().client(bootstrapClient).resolvePrivateAddresses(true).url(url).build();
return new DnsOverHttps.Builder().client(bootstrapClient)
.includeIPv6(includeIPv6)
.resolvePrivateAddresses(true)
.url(url)
.build();
}
private static InetAddress address(String host) {

View File

@ -22,24 +22,26 @@ import java.util.List;
import okio.ByteString;
import org.junit.Test;
import static okhttp3.dnsoverhttps.DnsRecordCodec.TYPE_A;
import static okhttp3.dnsoverhttps.DnsRecordCodec.TYPE_AAAA;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
public class DnsRecordCodecTest {
@Test public void testGoogleDotComEncoding() {
String encoded = encodeQuery("google.com", false);
String encoded = encodeQuery("google.com", TYPE_A);
assertEquals("AAABAAABAAAAAAAABmdvb2dsZQNjb20AAAEAAQ", encoded);
}
private String encodeQuery(String host, boolean includeIpv6) {
return DnsRecordCodec.encodeQuery(host, includeIpv6).base64Url().replace("=", "");
private String encodeQuery(String host, int type) {
return DnsRecordCodec.encodeQuery(host, type).base64Url().replace("=", "");
}
@Test public void testGoogleDotComEncodingWithIPv6() {
String encoded = encodeQuery("google.com", true);
String encoded = encodeQuery("google.com", TYPE_AAAA);
assertEquals("AAABAAACAAAAAAAABmdvb2dsZQNjb20AAAEAAQZnb29nbGUDY29tAAAcAAE", encoded);
assertEquals("AAABAAABAAAAAAAABmdvb2dsZQNjb20AABwAAQ", encoded);
}
@Test public void testGoogleDotComDecodingFromCloudflare() throws Exception {