diff --git a/okhttp-protocols/src/main/java/com/squareup/okhttp/internal/bytes/Deadline.java b/okhttp-protocols/src/main/java/com/squareup/okhttp/internal/bytes/Deadline.java index ed8572a4f..88188433d 100644 --- a/okhttp-protocols/src/main/java/com/squareup/okhttp/internal/bytes/Deadline.java +++ b/okhttp-protocols/src/main/java/com/squareup/okhttp/internal/bytes/Deadline.java @@ -15,6 +15,7 @@ */ package com.squareup.okhttp.internal.bytes; +import java.io.IOException; import java.util.concurrent.TimeUnit; /** @@ -45,4 +46,9 @@ public class Deadline { public boolean reached() { return System.nanoTime() - deadlineNanos >= 0; // Subtract to avoid overflow! } + + public void throwIfReached() throws IOException { + // TODO: a more catchable exception type? + if (reached()) throw new IOException("Deadline reached"); + } } diff --git a/okhttp-protocols/src/main/java/com/squareup/okhttp/internal/bytes/OkBuffer.java b/okhttp-protocols/src/main/java/com/squareup/okhttp/internal/bytes/OkBuffer.java index f07c1bc84..8a646337b 100644 --- a/okhttp-protocols/src/main/java/com/squareup/okhttp/internal/bytes/OkBuffer.java +++ b/okhttp-protocols/src/main/java/com/squareup/okhttp/internal/bytes/OkBuffer.java @@ -40,8 +40,8 @@ public final class OkBuffer implements Source, Sink { private static final char[] HEX_DIGITS = { '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f' }; - private Segment head; - private long byteCount; + Segment head; + long byteCount; public OkBuffer() { } @@ -62,10 +62,7 @@ public final class OkBuffer implements Source, Sink { } private byte[] readBytes(int byteCount) { - if (byteCount > this.byteCount) { - throw new IllegalArgumentException( - String.format("requested %s > available %s", byteCount, this.byteCount)); - } + checkByteCount(byteCount); int offset = 0; byte[] result = new byte[byteCount]; @@ -101,15 +98,7 @@ public final class OkBuffer implements Source, Sink { private void write(byte[] data) { int offset = 0; while (offset < data.length) { - if (head == null) { - head = SegmentPool.INSTANCE.take(); // Acquire a first segment. - head.next = head.prev = head; - } - - Segment tail = head.prev; - if (tail.limit == Segment.SIZE) { - tail = tail.push(SegmentPool.INSTANCE.take()); // Append a new empty segment to fill up. - } + Segment tail = writableSegment(); int toCopy = Math.min(data.length - offset, Segment.SIZE - tail.limit); System.arraycopy(data, offset, tail.data, tail.limit, toCopy); @@ -121,6 +110,20 @@ public final class OkBuffer implements Source, Sink { this.byteCount += data.length; } + /** Returns a tail segment that we can write bytes to, creating it if necessary. */ + Segment writableSegment() { + if (head == null) { + head = SegmentPool.INSTANCE.take(); // Acquire a first segment. + return head.next = head.prev = head; + } + + Segment tail = head.prev; + if (tail.limit == Segment.SIZE) { + tail = tail.push(SegmentPool.INSTANCE.take()); // Append a new empty segment to fill up. + } + return tail; + } + @Override public void write(OkBuffer source, long byteCount, Deadline deadline) { // Move bytes from the head of the source buffer to the tail of this buffer // while balancing two conflicting goals: don't waste CPU and don't waste @@ -173,10 +176,7 @@ public final class OkBuffer implements Source, Sink { // yielding sink [51%, 91%, 30%] and source [62%, 82%]. if (source == this) throw new IllegalArgumentException("source == this"); - if (byteCount > source.byteCount) { - throw new IllegalArgumentException( - String.format("requested %s > available %s", byteCount, this.byteCount)); - } + source.checkByteCount(byteCount); while (byteCount > 0) { // Is a prefix of the source's head segment all that we need to move? @@ -214,14 +214,17 @@ public final class OkBuffer implements Source, Sink { } @Override public long read(OkBuffer sink, long byteCount, Deadline deadline) throws IOException { - if (byteCount < 0) throw new IllegalArgumentException("byteCount < 0: " + byteCount); if (this.byteCount == 0) return -1L; if (byteCount > this.byteCount) byteCount = this.byteCount; sink.write(this, byteCount, deadline); return byteCount; } - @Override public long indexOf(byte b, Deadline deadline) throws IOException { + /** + * Returns the index of {@code b} in this, or -1 if this buffer does not + * contain {@code b}. + */ + public long indexOf(byte b) throws IOException { Segment s = head; if (s == null) return -1L; long offset = 0L; @@ -272,4 +275,15 @@ public final class OkBuffer implements Source, Sink { } return new String(result); } + + /** Throws if this has fewer bytes than {@code requested}. */ + void checkByteCount(long requested) { + if (requested < 0) { + throw new IllegalArgumentException("requested < 0: " + requested); + } + if (requested > this.byteCount) { + throw new IllegalArgumentException( + String.format("requested %s > available %s", requested, this.byteCount)); + } + } } diff --git a/okhttp-protocols/src/main/java/com/squareup/okhttp/internal/bytes/OkBuffers.java b/okhttp-protocols/src/main/java/com/squareup/okhttp/internal/bytes/OkBuffers.java new file mode 100644 index 000000000..0f8583953 --- /dev/null +++ b/okhttp-protocols/src/main/java/com/squareup/okhttp/internal/bytes/OkBuffers.java @@ -0,0 +1,88 @@ +/* + * Copyright (C) 2014 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.squareup.okhttp.internal.bytes; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +public final class OkBuffers { + private OkBuffers() { + } + + /** Returns a sink that writes to {@code out}. */ + public static Sink sink(final OutputStream out) { + return new Sink() { + @Override public void write(OkBuffer source, long byteCount, Deadline deadline) + throws IOException { + source.checkByteCount(byteCount); + while (byteCount > 0) { + deadline.throwIfReached(); + Segment head = source.head; + int toCopy = (int) Math.min(byteCount, head.limit - head.pos); + out.write(head.data, head.pos, toCopy); + + head.pos += toCopy; + byteCount -= toCopy; + source.byteCount -= toCopy; + + if (head.pos == head.limit) { + source.head = head.pop(); + SegmentPool.INSTANCE.recycle(head); + } + } + } + + @Override public void flush(Deadline deadline) throws IOException { + out.flush(); + } + + @Override public void close(Deadline deadline) throws IOException { + out.close(); + } + + @Override public String toString() { + return "sink(" + out + ")"; + } + }; + } + + /** Returns a source that reads from {@code in}. */ + public static Source source(final InputStream in) { + return new Source() { + @Override public long read( + OkBuffer sink, long byteCount, Deadline deadline) throws IOException { + if (byteCount < 0) throw new IllegalArgumentException("byteCount < 0: " + byteCount); + deadline.throwIfReached(); + Segment tail = sink.writableSegment(); + int maxToCopy = (int) Math.min(byteCount, Segment.SIZE - tail.limit); + int bytesRead = in.read(tail.data, tail.limit, maxToCopy); + if (bytesRead == -1) return -1; + tail.limit += bytesRead; + sink.byteCount += bytesRead; + return bytesRead; + } + + @Override public void close(Deadline deadline) throws IOException { + in.close(); + } + + @Override public String toString() { + return "source(" + in + ")"; + } + }; + } +} diff --git a/okhttp-protocols/src/main/java/com/squareup/okhttp/internal/bytes/Source.java b/okhttp-protocols/src/main/java/com/squareup/okhttp/internal/bytes/Source.java index 233cd5875..5b9a87a00 100644 --- a/okhttp-protocols/src/main/java/com/squareup/okhttp/internal/bytes/Source.java +++ b/okhttp-protocols/src/main/java/com/squareup/okhttp/internal/bytes/Source.java @@ -22,17 +22,12 @@ import java.io.IOException; */ public interface Source { /** - * Removes {@code byteCount} bytes from this and appends them to {@code sink}. - * Returns the number of bytes actually written. + * Removes at least 1, and up to {@code byteCount} bytes from this and appends + * them to {@code sink}. Returns the number of bytes read, or -1 if this + * source is exhausted. */ long read(OkBuffer sink, long byteCount, Deadline deadline) throws IOException; - /** - * Returns the index of {@code b} in this, or -1 if this source is exhausted - * first. This may cause this source to buffer a large number of bytes. - */ - long indexOf(byte b, Deadline deadline) throws IOException; - /** * Closes this source and releases the resources held by this source. It is an * error to read a closed source. It is safe to close a source more than once. diff --git a/okhttp-protocols/src/test/java/com/squareup/okhttp/internal/bytes/OkBufferTest.java b/okhttp-protocols/src/test/java/com/squareup/okhttp/internal/bytes/OkBufferTest.java index fbb1e97e3..f91c180d0 100644 --- a/okhttp-protocols/src/test/java/com/squareup/okhttp/internal/bytes/OkBufferTest.java +++ b/okhttp-protocols/src/test/java/com/squareup/okhttp/internal/bytes/OkBufferTest.java @@ -15,10 +15,14 @@ */ package com.squareup.okhttp.internal.bytes; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.InputStream; import java.util.Arrays; import java.util.List; import org.junit.Test; +import static com.squareup.okhttp.internal.Util.UTF_8; import static java.util.Arrays.asList; import static org.junit.Assert.assertEquals; import static org.junit.Assert.fail; @@ -236,42 +240,80 @@ public final class OkBufferTest { OkBuffer buffer = new OkBuffer(); // The segment is empty. - assertEquals(-1, buffer.indexOf((byte) 'a', Deadline.NONE)); + assertEquals(-1, buffer.indexOf((byte) 'a')); // The segment has one value. buffer.writeUtf8("a"); // a - assertEquals(0, buffer.indexOf((byte) 'a', Deadline.NONE)); - assertEquals(-1, buffer.indexOf((byte) 'b', Deadline.NONE)); + assertEquals(0, buffer.indexOf((byte) 'a')); + assertEquals(-1, buffer.indexOf((byte) 'b')); // The segment has lots of data. buffer.writeUtf8(repeat('b', Segment.SIZE - 2)); // ab...b - assertEquals(0, buffer.indexOf((byte) 'a', Deadline.NONE)); - assertEquals(1, buffer.indexOf((byte) 'b', Deadline.NONE)); - assertEquals(-1, buffer.indexOf((byte) 'c', Deadline.NONE)); + assertEquals(0, buffer.indexOf((byte) 'a')); + assertEquals(1, buffer.indexOf((byte) 'b')); + assertEquals(-1, buffer.indexOf((byte) 'c')); // The segment doesn't start at 0, it starts at 2. buffer.readUtf8(2); // b...b - assertEquals(-1, buffer.indexOf((byte) 'a', Deadline.NONE)); - assertEquals(0, buffer.indexOf((byte) 'b', Deadline.NONE)); - assertEquals(-1, buffer.indexOf((byte) 'c', Deadline.NONE)); + assertEquals(-1, buffer.indexOf((byte) 'a')); + assertEquals(0, buffer.indexOf((byte) 'b')); + assertEquals(-1, buffer.indexOf((byte) 'c')); // The segment is full. buffer.writeUtf8("c"); // b...bc - assertEquals(-1, buffer.indexOf((byte) 'a', Deadline.NONE)); - assertEquals(0, buffer.indexOf((byte) 'b', Deadline.NONE)); - assertEquals(Segment.SIZE - 3, buffer.indexOf((byte) 'c', Deadline.NONE)); + assertEquals(-1, buffer.indexOf((byte) 'a')); + assertEquals(0, buffer.indexOf((byte) 'b')); + assertEquals(Segment.SIZE - 3, buffer.indexOf((byte) 'c')); // The segment doesn't start at 2, it starts at 4. buffer.readUtf8(2); // b...bc - assertEquals(-1, buffer.indexOf((byte) 'a', Deadline.NONE)); - assertEquals(0, buffer.indexOf((byte) 'b', Deadline.NONE)); - assertEquals(Segment.SIZE - 5, buffer.indexOf((byte) 'c', Deadline.NONE)); + assertEquals(-1, buffer.indexOf((byte) 'a')); + assertEquals(0, buffer.indexOf((byte) 'b')); + assertEquals(Segment.SIZE - 5, buffer.indexOf((byte) 'c')); // Two segments. buffer.writeUtf8("d"); // b...bcd, d is in the 2nd segment. assertEquals(asList(Segment.SIZE - 4, 1), buffer.segmentSizes()); - assertEquals(Segment.SIZE - 4, buffer.indexOf((byte) 'd', Deadline.NONE)); - assertEquals(-1, buffer.indexOf((byte) 'e', Deadline.NONE)); + assertEquals(Segment.SIZE - 4, buffer.indexOf((byte) 'd')); + assertEquals(-1, buffer.indexOf((byte) 'e')); + } + + @Test public void sinkFromOutputStream() throws Exception { + OkBuffer data = new OkBuffer(); + data.writeUtf8("a"); + data.writeUtf8(repeat('b', 9998)); + data.writeUtf8("c"); + + ByteArrayOutputStream out = new ByteArrayOutputStream(); + Sink sink = OkBuffers.sink(out); + sink.write(data, 3, Deadline.NONE); + assertEquals("abb", out.toString("UTF-8")); + sink.write(data, data.byteCount(), Deadline.NONE); + assertEquals("a" + repeat('b', 9998) + "c", out.toString("UTF-8")); + } + + @Test public void sourceFromInputStream() throws Exception { + InputStream in = new ByteArrayInputStream( + ("a" + repeat('b', Segment.SIZE * 2) + "c").getBytes(UTF_8)); + + // Source: ab...bc + Source source = OkBuffers.source(in); + OkBuffer sink = new OkBuffer(); + + // Source: b...bc. Sink: abb. + assertEquals(3, source.read(sink, 3, Deadline.NONE)); + assertEquals("abb", sink.readUtf8(3)); + + // Source: b...bc. Sink: b...b. + assertEquals(Segment.SIZE, source.read(sink, 20000, Deadline.NONE)); + assertEquals(repeat('b', Segment.SIZE), sink.readUtf8((int) sink.byteCount())); + + // Source: b...bc. Sink: b...bc. + assertEquals(Segment.SIZE - 1, source.read(sink, 20000, Deadline.NONE)); + assertEquals(repeat('b', Segment.SIZE - 2) + "c", sink.readUtf8((int) sink.byteCount())); + + // Source and sink are empty. + assertEquals(-1, source.read(sink, 1, Deadline.NONE)); } private String repeat(char c, int count) {