diff --git a/okhttp/src/main/java/com/squareup/okhttp/internal/FaultRecoveringOutputStream.java b/okhttp/src/main/java/com/squareup/okhttp/internal/FaultRecoveringOutputStream.java new file mode 100644 index 000000000..d4e7507ab --- /dev/null +++ b/okhttp/src/main/java/com/squareup/okhttp/internal/FaultRecoveringOutputStream.java @@ -0,0 +1,133 @@ +/* + * Copyright (C) 2013 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.squareup.okhttp.internal; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.OutputStream; + +import static com.squareup.okhttp.internal.Util.checkOffsetAndCount; + +/** + * An output stream wrapper that recovers from failures in the underlying stream + * by replacing it with another stream. This class buffers a fixed amount of + * data under the assumption that failures occur early in a stream's life. + * If a failure occurs after the buffer has been exhausted, no recovery is + * attempted. + * + *
Subclasses must override {@link #replacementStream} which will request a
+ * replacement stream each time an {@link IOException} is encountered on the
+ * current stream.
+ */
+public abstract class FaultRecoveringOutputStream extends OutputStream {
+ private final int maxReplayBufferLength;
+
+ /** Bytes to transmit on the replacement stream, or null if no recovery is possible. */
+ private ByteArrayOutputStream replayBuffer;
+ private OutputStream out;
+ private boolean closed;
+
+ /**
+ * @param maxReplayBufferLength the maximum number of successfully written
+ * bytes to buffer so they can be replayed in the event of an error.
+ * Failure recoveries are not possible once this limit has been exceeded.
+ */
+ public FaultRecoveringOutputStream(int maxReplayBufferLength, OutputStream out) {
+ if (maxReplayBufferLength < 0) throw new IllegalArgumentException();
+ this.maxReplayBufferLength = maxReplayBufferLength;
+ this.replayBuffer = new ByteArrayOutputStream(maxReplayBufferLength);
+ this.out = out;
+ }
+
+ @Override public final void write(int data) throws IOException {
+ write(new byte[] { (byte) data });
+ }
+
+ @Override public final void write(byte[] buffer, int offset, int count) throws IOException {
+ if (closed) throw new IOException("stream closed");
+ checkOffsetAndCount(buffer.length, offset, count);
+
+ while (true) {
+ try {
+ out.write(buffer, offset, count);
+
+ if (replayBuffer != null) {
+ if (count + replayBuffer.size() > maxReplayBufferLength) {
+ // Failure recovery is no longer possible once we overflow the replay buffer.
+ replayBuffer = null;
+ } else {
+ // Remember the written bytes to the replay buffer.
+ replayBuffer.write(buffer, offset, count);
+ }
+ }
+ return;
+ } catch (IOException e) {
+ if (!recover(e)) throw e;
+ }
+ }
+ }
+
+ @Override public final void flush() throws IOException {
+ if (closed) {
+ return; // don't throw; this stream might have been closed on the caller's behalf
+ }
+ out.flush();
+ }
+
+ @Override public final void close() throws IOException {
+ if (closed) {
+ return;
+ }
+ out.close();
+ closed = true;
+ }
+
+ /**
+ * Attempt to replace {@code out} with another equivalent stream. Returns true
+ * if a suitable replacement stream was found.
+ */
+ private boolean recover(IOException e) {
+ if (replayBuffer == null) {
+ return false; // Can't recover because we've dropped data that we would need to replay.
+ }
+
+ while (true) {
+ OutputStream replacementStream = replacementStream(e);
+ if (replacementStream == null) {
+ return false;
+ }
+ try {
+ replayBuffer.writeTo(replacementStream);
+ // We've found a replacement that works!
+ Util.closeQuietly(out);
+ out = replacementStream;
+ return true;
+ } catch (IOException replacementStreamFailure) {
+ // The replacement was also broken. Loop to ask for another replacement.
+ Util.closeQuietly(replacementStream);
+ e = replacementStreamFailure;
+ }
+ }
+ }
+
+ /**
+ * Returns a replacement output stream to recover from {@code e} thrown by the
+ * previous stream. Returns a new OutputStream if recovery was successful, in
+ * which case all previously-written data will be replayed. Returns null if
+ * the failure cannot be recovered.
+ */
+ protected abstract OutputStream replacementStream(IOException e);
+}
diff --git a/okhttp/src/test/java/com/squareup/okhttp/internal/FaultRecoveringOutputStreamTest.java b/okhttp/src/test/java/com/squareup/okhttp/internal/FaultRecoveringOutputStreamTest.java
new file mode 100644
index 000000000..5b3a621e4
--- /dev/null
+++ b/okhttp/src/test/java/com/squareup/okhttp/internal/FaultRecoveringOutputStreamTest.java
@@ -0,0 +1,178 @@
+/*
+ * Copyright (C) 2013 Square, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.squareup.okhttp.internal;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Deque;
+import java.util.List;
+import org.junit.Test;
+
+import static com.squareup.okhttp.internal.Util.UTF_8;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+public final class FaultRecoveringOutputStreamTest {
+ @Test public void noRecoveryWithoutReplacement() throws Exception {
+ FaultingOutputStream faulting = new FaultingOutputStream();
+ TestFaultRecoveringOutputStream recovering = new TestFaultRecoveringOutputStream(10, faulting);
+
+ recovering.write('a');
+ faulting.nextFault = "system on fire";
+ try {
+ recovering.write('b');
+ fail();
+ } catch (IOException e) {
+ assertEquals(Arrays.asList("system on fire"), recovering.exceptionMessages);
+ assertEquals("ab", faulting.receivedUtf8);
+ assertFalse(faulting.closed);
+ }
+ }
+
+ @Test public void successfulRecovery() throws Exception {
+ FaultingOutputStream faulting1 = new FaultingOutputStream();
+ FaultingOutputStream faulting2 = new FaultingOutputStream();
+ TestFaultRecoveringOutputStream recovering = new TestFaultRecoveringOutputStream(10, faulting1);
+ recovering.replacements.addLast(faulting2);
+
+ recovering.write('a');
+ assertEquals("a", faulting1.receivedUtf8);
+ assertEquals("", faulting2.receivedUtf8);
+ faulting1.nextFault = "system under water";
+ recovering.write('b');
+ assertEquals(Arrays.asList("system under water"), recovering.exceptionMessages);
+ assertEquals("ab", faulting1.receivedUtf8);
+ assertEquals("ab", faulting2.receivedUtf8);
+ assertTrue(faulting1.closed);
+ assertFalse(faulting2.closed);
+
+ // Confirm that new data goes to the new stream.
+ recovering.write('c');
+ assertEquals("ab", faulting1.receivedUtf8);
+ assertEquals("abc", faulting2.receivedUtf8);
+ }
+
+ @Test public void replacementStreamFaultsImmediately() throws Exception {
+ FaultingOutputStream faulting1 = new FaultingOutputStream();
+ FaultingOutputStream faulting2 = new FaultingOutputStream();
+ FaultingOutputStream faulting3 = new FaultingOutputStream();
+ TestFaultRecoveringOutputStream recovering = new TestFaultRecoveringOutputStream(10, faulting1);
+ recovering.replacements.addLast(faulting2);
+ recovering.replacements.addLast(faulting3);
+
+ recovering.write('a');
+ assertEquals("a", faulting1.receivedUtf8);
+ assertEquals("", faulting2.receivedUtf8);
+ assertEquals("", faulting3.receivedUtf8);
+ faulting1.nextFault = "offline";
+ faulting2.nextFault = "slow";
+ recovering.write('b');
+ assertEquals(Arrays.asList("offline", "slow"), recovering.exceptionMessages);
+ assertEquals("ab", faulting1.receivedUtf8);
+ assertEquals("a", faulting2.receivedUtf8);
+ assertEquals("ab", faulting3.receivedUtf8);
+ assertTrue(faulting1.closed);
+ assertTrue(faulting2.closed);
+ assertFalse(faulting3.closed);
+
+ // Confirm that new data goes to the new stream.
+ recovering.write('c');
+ assertEquals("ab", faulting1.receivedUtf8);
+ assertEquals("a", faulting2.receivedUtf8);
+ assertEquals("abc", faulting3.receivedUtf8);
+ }
+
+ @Test public void recoverWithFullBuffer() throws Exception {
+ FaultingOutputStream faulting1 = new FaultingOutputStream();
+ FaultingOutputStream faulting2 = new FaultingOutputStream();
+ TestFaultRecoveringOutputStream recovering = new TestFaultRecoveringOutputStream(10, faulting1);
+ recovering.replacements.addLast(faulting2);
+
+ recovering.write("abcdefghij".getBytes(UTF_8)); // 10 bytes.
+ faulting1.nextFault = "unlucky";
+ recovering.write('k');
+ assertEquals("abcdefghijk", faulting1.receivedUtf8);
+ assertEquals("abcdefghijk", faulting2.receivedUtf8);
+ assertEquals(Arrays.asList("unlucky"), recovering.exceptionMessages);
+ assertTrue(faulting1.closed);
+ assertFalse(faulting2.closed);
+
+ // Confirm that new data goes to the new stream.
+ recovering.write('l');
+ assertEquals("abcdefghijk", faulting1.receivedUtf8);
+ assertEquals("abcdefghijkl", faulting2.receivedUtf8);
+ }
+
+ @Test public void noRecoveryWithOverfullBuffer() throws Exception {
+ FaultingOutputStream faulting1 = new FaultingOutputStream();
+ FaultingOutputStream faulting2 = new FaultingOutputStream();
+ TestFaultRecoveringOutputStream recovering = new TestFaultRecoveringOutputStream(10, faulting1);
+ recovering.replacements.addLast(faulting2);
+
+ recovering.write("abcdefghijk".getBytes(UTF_8)); // 11 bytes.
+ faulting1.nextFault = "out to lunch";
+ try {
+ recovering.write('l');
+ fail();
+ } catch (IOException expected) {
+ assertEquals("out to lunch", expected.getMessage());
+ }
+
+ assertEquals(Arrays.