diff --git a/okhttp/src/main/java/com/squareup/okhttp/internal/FaultRecoveringOutputStream.java b/okhttp/src/main/java/com/squareup/okhttp/internal/FaultRecoveringOutputStream.java new file mode 100644 index 000000000..d4e7507ab --- /dev/null +++ b/okhttp/src/main/java/com/squareup/okhttp/internal/FaultRecoveringOutputStream.java @@ -0,0 +1,133 @@ +/* + * Copyright (C) 2013 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.squareup.okhttp.internal; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.OutputStream; + +import static com.squareup.okhttp.internal.Util.checkOffsetAndCount; + +/** + * An output stream wrapper that recovers from failures in the underlying stream + * by replacing it with another stream. This class buffers a fixed amount of + * data under the assumption that failures occur early in a stream's life. + * If a failure occurs after the buffer has been exhausted, no recovery is + * attempted. + * + *

Subclasses must override {@link #replacementStream} which will request a + * replacement stream each time an {@link IOException} is encountered on the + * current stream. + */ +public abstract class FaultRecoveringOutputStream extends OutputStream { + private final int maxReplayBufferLength; + + /** Bytes to transmit on the replacement stream, or null if no recovery is possible. */ + private ByteArrayOutputStream replayBuffer; + private OutputStream out; + private boolean closed; + + /** + * @param maxReplayBufferLength the maximum number of successfully written + * bytes to buffer so they can be replayed in the event of an error. + * Failure recoveries are not possible once this limit has been exceeded. + */ + public FaultRecoveringOutputStream(int maxReplayBufferLength, OutputStream out) { + if (maxReplayBufferLength < 0) throw new IllegalArgumentException(); + this.maxReplayBufferLength = maxReplayBufferLength; + this.replayBuffer = new ByteArrayOutputStream(maxReplayBufferLength); + this.out = out; + } + + @Override public final void write(int data) throws IOException { + write(new byte[] { (byte) data }); + } + + @Override public final void write(byte[] buffer, int offset, int count) throws IOException { + if (closed) throw new IOException("stream closed"); + checkOffsetAndCount(buffer.length, offset, count); + + while (true) { + try { + out.write(buffer, offset, count); + + if (replayBuffer != null) { + if (count + replayBuffer.size() > maxReplayBufferLength) { + // Failure recovery is no longer possible once we overflow the replay buffer. + replayBuffer = null; + } else { + // Remember the written bytes to the replay buffer. + replayBuffer.write(buffer, offset, count); + } + } + return; + } catch (IOException e) { + if (!recover(e)) throw e; + } + } + } + + @Override public final void flush() throws IOException { + if (closed) { + return; // don't throw; this stream might have been closed on the caller's behalf + } + out.flush(); + } + + @Override public final void close() throws IOException { + if (closed) { + return; + } + out.close(); + closed = true; + } + + /** + * Attempt to replace {@code out} with another equivalent stream. Returns true + * if a suitable replacement stream was found. + */ + private boolean recover(IOException e) { + if (replayBuffer == null) { + return false; // Can't recover because we've dropped data that we would need to replay. + } + + while (true) { + OutputStream replacementStream = replacementStream(e); + if (replacementStream == null) { + return false; + } + try { + replayBuffer.writeTo(replacementStream); + // We've found a replacement that works! + Util.closeQuietly(out); + out = replacementStream; + return true; + } catch (IOException replacementStreamFailure) { + // The replacement was also broken. Loop to ask for another replacement. + Util.closeQuietly(replacementStream); + e = replacementStreamFailure; + } + } + } + + /** + * Returns a replacement output stream to recover from {@code e} thrown by the + * previous stream. Returns a new OutputStream if recovery was successful, in + * which case all previously-written data will be replayed. Returns null if + * the failure cannot be recovered. + */ + protected abstract OutputStream replacementStream(IOException e); +} diff --git a/okhttp/src/test/java/com/squareup/okhttp/internal/FaultRecoveringOutputStreamTest.java b/okhttp/src/test/java/com/squareup/okhttp/internal/FaultRecoveringOutputStreamTest.java new file mode 100644 index 000000000..5b3a621e4 --- /dev/null +++ b/okhttp/src/test/java/com/squareup/okhttp/internal/FaultRecoveringOutputStreamTest.java @@ -0,0 +1,178 @@ +/* + * Copyright (C) 2013 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.squareup.okhttp.internal; + +import java.io.IOException; +import java.io.OutputStream; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Deque; +import java.util.List; +import org.junit.Test; + +import static com.squareup.okhttp.internal.Util.UTF_8; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public final class FaultRecoveringOutputStreamTest { + @Test public void noRecoveryWithoutReplacement() throws Exception { + FaultingOutputStream faulting = new FaultingOutputStream(); + TestFaultRecoveringOutputStream recovering = new TestFaultRecoveringOutputStream(10, faulting); + + recovering.write('a'); + faulting.nextFault = "system on fire"; + try { + recovering.write('b'); + fail(); + } catch (IOException e) { + assertEquals(Arrays.asList("system on fire"), recovering.exceptionMessages); + assertEquals("ab", faulting.receivedUtf8); + assertFalse(faulting.closed); + } + } + + @Test public void successfulRecovery() throws Exception { + FaultingOutputStream faulting1 = new FaultingOutputStream(); + FaultingOutputStream faulting2 = new FaultingOutputStream(); + TestFaultRecoveringOutputStream recovering = new TestFaultRecoveringOutputStream(10, faulting1); + recovering.replacements.addLast(faulting2); + + recovering.write('a'); + assertEquals("a", faulting1.receivedUtf8); + assertEquals("", faulting2.receivedUtf8); + faulting1.nextFault = "system under water"; + recovering.write('b'); + assertEquals(Arrays.asList("system under water"), recovering.exceptionMessages); + assertEquals("ab", faulting1.receivedUtf8); + assertEquals("ab", faulting2.receivedUtf8); + assertTrue(faulting1.closed); + assertFalse(faulting2.closed); + + // Confirm that new data goes to the new stream. + recovering.write('c'); + assertEquals("ab", faulting1.receivedUtf8); + assertEquals("abc", faulting2.receivedUtf8); + } + + @Test public void replacementStreamFaultsImmediately() throws Exception { + FaultingOutputStream faulting1 = new FaultingOutputStream(); + FaultingOutputStream faulting2 = new FaultingOutputStream(); + FaultingOutputStream faulting3 = new FaultingOutputStream(); + TestFaultRecoveringOutputStream recovering = new TestFaultRecoveringOutputStream(10, faulting1); + recovering.replacements.addLast(faulting2); + recovering.replacements.addLast(faulting3); + + recovering.write('a'); + assertEquals("a", faulting1.receivedUtf8); + assertEquals("", faulting2.receivedUtf8); + assertEquals("", faulting3.receivedUtf8); + faulting1.nextFault = "offline"; + faulting2.nextFault = "slow"; + recovering.write('b'); + assertEquals(Arrays.asList("offline", "slow"), recovering.exceptionMessages); + assertEquals("ab", faulting1.receivedUtf8); + assertEquals("a", faulting2.receivedUtf8); + assertEquals("ab", faulting3.receivedUtf8); + assertTrue(faulting1.closed); + assertTrue(faulting2.closed); + assertFalse(faulting3.closed); + + // Confirm that new data goes to the new stream. + recovering.write('c'); + assertEquals("ab", faulting1.receivedUtf8); + assertEquals("a", faulting2.receivedUtf8); + assertEquals("abc", faulting3.receivedUtf8); + } + + @Test public void recoverWithFullBuffer() throws Exception { + FaultingOutputStream faulting1 = new FaultingOutputStream(); + FaultingOutputStream faulting2 = new FaultingOutputStream(); + TestFaultRecoveringOutputStream recovering = new TestFaultRecoveringOutputStream(10, faulting1); + recovering.replacements.addLast(faulting2); + + recovering.write("abcdefghij".getBytes(UTF_8)); // 10 bytes. + faulting1.nextFault = "unlucky"; + recovering.write('k'); + assertEquals("abcdefghijk", faulting1.receivedUtf8); + assertEquals("abcdefghijk", faulting2.receivedUtf8); + assertEquals(Arrays.asList("unlucky"), recovering.exceptionMessages); + assertTrue(faulting1.closed); + assertFalse(faulting2.closed); + + // Confirm that new data goes to the new stream. + recovering.write('l'); + assertEquals("abcdefghijk", faulting1.receivedUtf8); + assertEquals("abcdefghijkl", faulting2.receivedUtf8); + } + + @Test public void noRecoveryWithOverfullBuffer() throws Exception { + FaultingOutputStream faulting1 = new FaultingOutputStream(); + FaultingOutputStream faulting2 = new FaultingOutputStream(); + TestFaultRecoveringOutputStream recovering = new TestFaultRecoveringOutputStream(10, faulting1); + recovering.replacements.addLast(faulting2); + + recovering.write("abcdefghijk".getBytes(UTF_8)); // 11 bytes. + faulting1.nextFault = "out to lunch"; + try { + recovering.write('l'); + fail(); + } catch (IOException expected) { + assertEquals("out to lunch", expected.getMessage()); + } + + assertEquals(Arrays.asList(), recovering.exceptionMessages); + assertEquals("abcdefghijkl", faulting1.receivedUtf8); + assertEquals("", faulting2.receivedUtf8); + assertFalse(faulting1.closed); + assertFalse(faulting2.closed); + } + + static class FaultingOutputStream extends OutputStream { + String receivedUtf8 = ""; + String nextFault; + boolean closed; + + @Override public final void write(int data) throws IOException { + write(new byte[] { (byte) data }); + } + + @Override public void write(byte[] buffer, int offset, int count) throws IOException { + receivedUtf8 += new String(buffer, offset, count, UTF_8); + if (nextFault != null) throw new IOException(nextFault); + } + + @Override public void close() throws IOException { + closed = true; + } + } + + static class TestFaultRecoveringOutputStream extends FaultRecoveringOutputStream { + final List exceptionMessages = new ArrayList(); + final Deque replacements = new ArrayDeque(); + + TestFaultRecoveringOutputStream(int maxReplayBufferLength, OutputStream first) { + super(maxReplayBufferLength, first); + } + + @Override protected OutputStream replacementStream(IOException e) { + exceptionMessages.add(e.getMessage()); + return replacements.poll(); + } + } +}