From a31a5192d1604131d5239fbf2f5280f1ac5eeeaa Mon Sep 17 00:00:00 2001 From: jwilson Date: Sat, 4 Jan 2014 20:31:28 -0500 Subject: [PATCH] Implement simple limits in the dispatcher. This adds Dispatcher to the public API so that application code can tweak the policy. --- .../java/com/squareup/okhttp/Dispatcher.java | 156 +++++++++++++--- .../main/java/com/squareup/okhttp/Job.java | 28 ++- .../com/squareup/okhttp/OkHttpClient.java | 22 ++- .../com/squareup/okhttp/DispatcherTest.java | 167 ++++++++++++++++++ 4 files changed, 336 insertions(+), 37 deletions(-) create mode 100644 okhttp/src/test/java/com/squareup/okhttp/DispatcherTest.java diff --git a/okhttp/src/main/java/com/squareup/okhttp/Dispatcher.java b/okhttp/src/main/java/com/squareup/okhttp/Dispatcher.java index 915bb8757..9868b165b 100644 --- a/okhttp/src/main/java/com/squareup/okhttp/Dispatcher.java +++ b/okhttp/src/main/java/com/squareup/okhttp/Dispatcher.java @@ -15,42 +15,148 @@ */ package com.squareup.okhttp; -import java.util.ArrayList; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; +import com.squareup.okhttp.internal.Util; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.Iterator; +import java.util.concurrent.Executor; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; -final class Dispatcher { - // TODO: thread pool size should be configurable; possibly configurable per host. - private final ThreadPoolExecutor executorService = new ThreadPoolExecutor( - 8, 8, 60, TimeUnit.SECONDS, new LinkedBlockingQueue()); - private final Map> enqueuedJobs = new LinkedHashMap>(); +/** + * Policy on when async requests are executed. + * + *

Each dispatcher uses an {@link Executor} to run jobs internally. If you + * supply your own executor, it should be able to run {@link #getMaxRequests the + * configured maximum} number of jobs concurrently. + */ +public final class Dispatcher { + private int maxRequests = 64; + private int maxRequestsPerHost = 5; - public synchronized void enqueue( - OkHttpClient client, Request request, Response.Receiver responseReceiver) { - Job job = new Job(this, client, request, responseReceiver); - List jobsForTag = enqueuedJobs.get(request.tag()); - if (jobsForTag == null) { - jobsForTag = new ArrayList(2); - enqueuedJobs.put(request.tag(), jobsForTag); - } - jobsForTag.add(job); - executorService.execute(job); + /** Executes jobs. Created lazily. */ + private Executor executor; + + /** Ready jobs in the order they'll be run. */ + private final Deque readyJobs = new ArrayDeque(); + + /** Running jobs. Includes canceled jobs that haven't finished yet. */ + private final Deque runningJobs = new ArrayDeque(); + + public Dispatcher(Executor executor) { + this.executor = executor; } + public Dispatcher() { + } + + public synchronized Executor getExecutor() { + if (executor == null) { + // TODO: name these threads, either here or in the job. + executor = new ThreadPoolExecutor(0, Integer.MAX_VALUE, 60, TimeUnit.SECONDS, + new LinkedBlockingQueue()); + } + return executor; + } + + /** + * Set the maximum number of requests to execute concurrently. Above this + * requests queue in memory, waiting for the running jobs to complete. + * + *

If more than {@code maxRequests} requests are in flight when this is + * invoked, those requests will remain in flight. + */ + public synchronized void setMaxRequests(int maxRequests) { + if (maxRequests < 1) { + throw new IllegalArgumentException("max < 1: " + maxRequests); + } + this.maxRequests = maxRequests; + promoteJobs(); + } + + public synchronized int getMaxRequests() { + return maxRequests; + } + + /** + * Set the maximum number of requests for each host to execute concurrently. + * This limits requests by the URL's host name. Note that concurrent requests + * to a single IP address may still exceed this limit: multiple hostnames may + * share an IP address or be routed through the same HTTP proxy. + * + *

If more than {@code maxRequestsPerHost} requests are in flight when this + * is invoked, those requests will remain in flight. + */ + public synchronized void setMaxRequestsPerHost(int maxRequestsPerHost) { + if (maxRequestsPerHost < 1) { + throw new IllegalArgumentException("max < 1: " + maxRequestsPerHost); + } + this.maxRequestsPerHost = maxRequestsPerHost; + promoteJobs(); + } + + public synchronized int getMaxRequestsPerHost() { + return maxRequestsPerHost; + } + + synchronized void enqueue(OkHttpClient client, Request request, Response.Receiver receiver) { + // Copy the client. Otherwise changes (socket factory, redirect policy, + // etc.) may incorrectly be reflected in the request when it is executed. + client = client.copyWithDefaults(); + Job job = new Job(this, client, request, receiver); + + if (runningJobs.size() < maxRequests && runningJobsForHost(job) < maxRequestsPerHost) { + runningJobs.add(job); + getExecutor().execute(job); + } else { + readyJobs.add(job); + } + } + + /** + * Cancel all jobs with the tag {@code tag}. If a canceled job is running it + * may continue running until it reaches a safe point to finish. + */ public synchronized void cancel(Object tag) { - List jobs = enqueuedJobs.remove(tag); - if (jobs == null) return; - for (Job job : jobs) { - executorService.remove(job); + for (Iterator i = readyJobs.iterator(); i.hasNext(); ) { + if (Util.equal(tag, i.next().tag())) i.remove(); + } + + for (Job job : runningJobs) { + if (Util.equal(tag, job.tag())) job.canceled = true; } } + /** Used by {@code Job#run} to signal completion. */ synchronized void finished(Job job) { - List jobs = enqueuedJobs.get(job.tag()); - if (jobs != null) jobs.remove(job); + if (!runningJobs.remove(job)) throw new AssertionError("Job wasn't running!"); + promoteJobs(); + } + + private void promoteJobs() { + if (runningJobs.size() >= maxRequests) return; // Already running max capacity. + if (readyJobs.isEmpty()) return; // No ready jobs to promote. + + for (Iterator i = readyJobs.iterator(); i.hasNext(); ) { + Job job = i.next(); + + if (runningJobsForHost(job) < maxRequestsPerHost) { + i.remove(); + runningJobs.add(job); + getExecutor().execute(job); + } + + if (runningJobs.size() >= maxRequests) return; // Reached max capacity. + } + } + + /** Returns the number of running jobs that share a host with {@code job}. */ + private int runningJobsForHost(Job job) { + int result = 0; + for (Job j : runningJobs) { + if (j.host().equals(job.host())) result++; + } + return result; } } diff --git a/okhttp/src/main/java/com/squareup/okhttp/Job.java b/okhttp/src/main/java/com/squareup/okhttp/Job.java index a4f05c14b..7f4401a05 100644 --- a/okhttp/src/main/java/com/squareup/okhttp/Job.java +++ b/okhttp/src/main/java/com/squareup/okhttp/Job.java @@ -38,6 +38,8 @@ final class Job implements Runnable { private final OkHttpClient client; private final Response.Receiver responseReceiver; + volatile boolean canceled; + /** The request; possibly a consequence of redirects or auth headers. */ private Request request; @@ -49,6 +51,14 @@ final class Job implements Runnable { this.responseReceiver = responseReceiver; } + String host() { + return request.url().getHost(); + } + + Request request() { + return request; + } + Object tag() { return request.tag(); } @@ -56,7 +66,9 @@ final class Job implements Runnable { @Override public void run() { try { Response response = execute(); - responseReceiver.onResponse(response); + if (response != null && !canceled) { + responseReceiver.onResponse(response); + } } catch (IOException e) { responseReceiver.onFailure(new Failure.Builder() .request(request) @@ -64,16 +76,22 @@ final class Job implements Runnable { .build()); } finally { // TODO: close the response body - // TODO: release the HTTP engine (potentially multiple!) + // TODO: release the HTTP engine dispatcher.finished(this); } } + /** + * Performs the request and returns the response. May return null if this job + * was canceled. + */ private Response execute() throws IOException { Connection connection = null; Response redirectedBy = null; while (true) { + if (canceled) return null; + Request.Body body = request.body(); if (body != null) { MediaType contentType = body.contentType(); @@ -94,7 +112,7 @@ final class Job implements Runnable { request = requestBuilder.build(); } - HttpEngine engine = newEngine(connection); + HttpEngine engine = new HttpEngine(client, request, false, connection, null); engine.sendRequest(); if (body != null) { @@ -124,10 +142,6 @@ final class Job implements Runnable { } } - HttpEngine newEngine(Connection connection) throws IOException { - return new HttpEngine(client, request, false, connection, null); - } - /** * Figures out the HTTP request to make in response to receiving {@code * response}. This will either add authentication headers or follow diff --git a/okhttp/src/main/java/com/squareup/okhttp/OkHttpClient.java b/okhttp/src/main/java/com/squareup/okhttp/OkHttpClient.java index 37f83f7dc..764761bf8 100644 --- a/okhttp/src/main/java/com/squareup/okhttp/OkHttpClient.java +++ b/okhttp/src/main/java/com/squareup/okhttp/OkHttpClient.java @@ -42,7 +42,7 @@ public final class OkHttpClient implements URLStreamHandlerFactory, Cloneable { = Util.immutableList(Arrays.asList("spdy/3", "http/1.1")); private final RouteDatabase routeDatabase; - private final Dispatcher dispatcher; + private Dispatcher dispatcher; private Proxy proxy; private List transports; private ProxySelector proxySelector; @@ -281,6 +281,20 @@ public final class OkHttpClient implements URLStreamHandlerFactory, Cloneable { return routeDatabase; } + /** + * Sets the dispatcher used to set policy and execute asynchronous requests. + * Must not be null. + */ + public OkHttpClient setDispatcher(Dispatcher dispatcher) { + if (dispatcher == null) throw new IllegalArgumentException("dispatcher == null"); + this.dispatcher = dispatcher; + return this; + } + + public Dispatcher getDispatcher() { + return dispatcher; + } + /** * Configure the transports used by this client to communicate with remote * servers. By default this client will prefer the most efficient transport @@ -334,9 +348,7 @@ public final class OkHttpClient implements URLStreamHandlerFactory, Cloneable { * This method is in beta. APIs are subject to change! */ public void enqueue(Request request, Response.Receiver responseReceiver) { - // Copy this client. Otherwise changes (socket factory, redirect policy, - // etc.) may incorrectly be reflected in the request when it is dispatched. - dispatcher.enqueue(copyWithDefaults(), request, responseReceiver); + dispatcher.enqueue(this, request, responseReceiver); } /** @@ -368,7 +380,7 @@ public final class OkHttpClient implements URLStreamHandlerFactory, Cloneable { * Returns a shallow copy of this OkHttpClient that uses the system-wide * default for each field that hasn't been explicitly configured. */ - private OkHttpClient copyWithDefaults() { + OkHttpClient copyWithDefaults() { OkHttpClient result = clone(); if (result.proxySelector == null) { result.proxySelector = ProxySelector.getDefault(); diff --git a/okhttp/src/test/java/com/squareup/okhttp/DispatcherTest.java b/okhttp/src/test/java/com/squareup/okhttp/DispatcherTest.java new file mode 100644 index 000000000..034ed84b6 --- /dev/null +++ b/okhttp/src/test/java/com/squareup/okhttp/DispatcherTest.java @@ -0,0 +1,167 @@ +package com.squareup.okhttp; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.Executor; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +public final class DispatcherTest { + RecordingExecutor executor = new RecordingExecutor(); + RecordingReceiver receiver = new RecordingReceiver(); + Dispatcher dispatcher = new Dispatcher(executor); + OkHttpClient client = new OkHttpClient().setDispatcher(dispatcher); + + @Before public void setUp() throws Exception { + dispatcher.setMaxRequests(20); + dispatcher.setMaxRequestsPerHost(10); + } + + @Test public void maxRequestsZero() throws Exception { + try { + dispatcher.setMaxRequests(0); + fail(); + } catch (IllegalArgumentException expected) { + } + } + + @Test public void maxPerHostZero() throws Exception { + try { + dispatcher.setMaxRequestsPerHost(0); + fail(); + } catch (IllegalArgumentException expected) { + } + } + + @Test public void enqueuedJobsRunImmediately() throws Exception { + client.enqueue(newRequest("http://a/1"), receiver); + executor.assertJobs("http://a/1"); + } + + @Test public void maxRequestsEnforced() throws Exception { + dispatcher.setMaxRequests(3); + client.enqueue(newRequest("http://a/1"), receiver); + client.enqueue(newRequest("http://a/2"), receiver); + client.enqueue(newRequest("http://b/1"), receiver); + client.enqueue(newRequest("http://b/2"), receiver); + executor.assertJobs("http://a/1", "http://a/2", "http://b/1"); + } + + @Test public void maxPerHostEnforced() throws Exception { + dispatcher.setMaxRequestsPerHost(2); + client.enqueue(newRequest("http://a/1"), receiver); + client.enqueue(newRequest("http://a/2"), receiver); + client.enqueue(newRequest("http://a/3"), receiver); + executor.assertJobs("http://a/1", "http://a/2"); + } + + @Test public void increasingMaxRequestsPromotesJobsImmediately() throws Exception { + dispatcher.setMaxRequests(2); + client.enqueue(newRequest("http://a/1"), receiver); + client.enqueue(newRequest("http://b/1"), receiver); + client.enqueue(newRequest("http://c/1"), receiver); + client.enqueue(newRequest("http://a/2"), receiver); + client.enqueue(newRequest("http://b/2"), receiver); + dispatcher.setMaxRequests(4); + executor.assertJobs("http://a/1", "http://b/1", "http://c/1", "http://a/2"); + } + + @Test public void increasingMaxPerHostPromotesJobsImmediately() throws Exception { + dispatcher.setMaxRequestsPerHost(2); + client.enqueue(newRequest("http://a/1"), receiver); + client.enqueue(newRequest("http://a/2"), receiver); + client.enqueue(newRequest("http://a/3"), receiver); + client.enqueue(newRequest("http://a/4"), receiver); + client.enqueue(newRequest("http://a/5"), receiver); + dispatcher.setMaxRequestsPerHost(4); + executor.assertJobs("http://a/1", "http://a/2", "http://a/3", "http://a/4"); + } + + @Test public void oldJobFinishesNewJobCanRunDifferentHost() throws Exception { + dispatcher.setMaxRequests(1); + client.enqueue(newRequest("http://a/1"), receiver); + client.enqueue(newRequest("http://b/1"), receiver); + executor.finishJob("http://a/1"); + executor.assertJobs("http://b/1"); + } + + @Test public void oldJobFinishesNewJobWithSameHostStarts() throws Exception { + dispatcher.setMaxRequests(2); + dispatcher.setMaxRequestsPerHost(1); + client.enqueue(newRequest("http://a/1"), receiver); + client.enqueue(newRequest("http://b/1"), receiver); + client.enqueue(newRequest("http://b/2"), receiver); + client.enqueue(newRequest("http://a/2"), receiver); + executor.finishJob("http://a/1"); + executor.assertJobs("http://b/1", "http://a/2"); + } + + @Test public void oldJobFinishesNewJobCantRunDueToHostLimit() throws Exception { + dispatcher.setMaxRequestsPerHost(1); + client.enqueue(newRequest("http://a/1"), receiver); + client.enqueue(newRequest("http://b/1"), receiver); + client.enqueue(newRequest("http://a/2"), receiver); + executor.finishJob("http://b/1"); + executor.assertJobs("http://a/1"); + } + + @Test public void cancelingReadyJobPreventsItFromStarting() throws Exception { + dispatcher.setMaxRequestsPerHost(1); + client.enqueue(newRequest("http://a/1"), receiver); + client.enqueue(newRequest("http://a/2", "tag1"), receiver); + dispatcher.cancel("tag1"); + executor.finishJob("http://a/1"); + executor.assertJobs(); + } + + @Test public void cancelingRunningJobTakesNoEffectUntilJobFinishes() throws Exception { + dispatcher.setMaxRequests(1); + client.enqueue(newRequest("http://a/1", "tag1"), receiver); + client.enqueue(newRequest("http://a/2"), receiver); + dispatcher.cancel("tag1"); + executor.assertJobs("http://a/1"); + executor.finishJob("http://a/1"); + executor.assertJobs("http://a/2"); + } + + class RecordingExecutor implements Executor { + private List jobs = new ArrayList(); + + @Override public void execute(Runnable command) { + jobs.add((Job) command); + } + + public void assertJobs(String... expectedUrls) { + List actualUrls = new ArrayList(); + for (Job job : jobs) { + actualUrls.add(job.request().urlString()); + } + assertEquals(Arrays.asList(expectedUrls), actualUrls); + } + + public void finishJob(String url) { + for (Iterator i = jobs.iterator(); i.hasNext(); ) { + Job job = i.next(); + if (job.request().urlString().equals(url)) { + i.remove(); + dispatcher.finished(job); + return; + } + } + throw new AssertionError("No such job: " + url); + } + } + + private Request newRequest(String url) { + return new Request.Builder().url(url).build(); + } + + private Request newRequest(String url, String tag) { + return new Request.Builder().url(url).tag(tag).build(); + } +}