1
0
mirror of https://github.com/square/okhttp.git synced 2026-01-14 07:22:20 +03:00

Hook up compression in WebSocketReader and WebSocketWriter

It isn't yet enabled in the calling RealWebSocket, so this is only
visible to tests right now.
This commit is contained in:
Jesse Wilson
2020-03-14 18:39:50 -04:00
parent 97a5e7a9e0
commit 8703126227
10 changed files with 337 additions and 71 deletions

View File

@@ -55,7 +55,8 @@ interface WebSocket {
/**
* Returns the size in bytes of all messages enqueued to be transmitted to the server. This
* doesn't include framing overhead. It also doesn't include any bytes buffered by the operating
* doesn't include framing overhead. If compression is enabled, uncompressed messages size
* is used to calculate this value. It also doesn't include any bytes buffered by the operating
* system or network intermediaries. This method returns 0 if no messages are waiting in the
* queue. If may return a nonzero value after the web socket has been canceled; this indicates
* that enqueued messages were not transmitted.

View File

@@ -215,7 +215,13 @@ class RealWebSocket(
synchronized(this) {
this.name = name
this.streams = streams
this.writer = WebSocketWriter(streams.client, streams.sink, random)
this.writer = WebSocketWriter(
isClient = streams.client,
sink = streams.sink,
random = random,
messageDeflater = null,
minimumDeflateSize = Long.MAX_VALUE
)
this.writerTask = WriterTask()
if (pingIntervalMillis != 0L) {
val pingIntervalNanos = MILLISECONDS.toNanos(pingIntervalMillis)
@@ -229,7 +235,12 @@ class RealWebSocket(
}
}
reader = WebSocketReader(streams.client, streams.source, this)
reader = WebSocketReader(
isClient = streams.client,
source = streams.source,
frameCallback = this,
messageInflater = null
)
}
/** Receive frames until there are no more. Invoked only by the reader thread. */
@@ -304,6 +315,8 @@ class RealWebSocket(
require(code != -1)
var toClose: Streams? = null
var readerToClose: WebSocketReader? = null
var writerToClose: WebSocketWriter? = null
synchronized(this) {
check(receivedCloseCode == -1) { "already closed" }
receivedCloseCode = code
@@ -311,6 +324,10 @@ class RealWebSocket(
if (enqueuedClose && messageAndCloseQueue.isEmpty()) {
toClose = this.streams
this.streams = null
readerToClose = this.reader
this.reader = null
writerToClose = this.writer
this.writer = null
this.taskQueue.shutdown()
}
}
@@ -323,6 +340,8 @@ class RealWebSocket(
}
} finally {
toClose?.closeQuietly()
readerToClose?.closeQuietly()
writerToClose?.closeQuietly()
}
}
@@ -422,6 +441,8 @@ class RealWebSocket(
var receivedCloseCode = -1
var receivedCloseReason: String? = null
var streamsToClose: Streams? = null
var readerToClose: WebSocketReader? = null
var writerToClose: WebSocketWriter? = null
synchronized(this@RealWebSocket) {
if (failed) {
@@ -438,6 +459,10 @@ class RealWebSocket(
if (receivedCloseCode != -1) {
streamsToClose = this.streams
this.streams = null
readerToClose = this.reader
this.reader = null
writerToClose = this.writer
this.writer = null
this.taskQueue.shutdown()
} else {
// When we request a graceful close also schedule a cancel of the web socket.
@@ -476,6 +501,8 @@ class RealWebSocket(
return true
} finally {
streamsToClose?.closeQuietly()
readerToClose?.closeQuietly()
writerToClose?.closeQuietly()
}
}
@@ -505,11 +532,17 @@ class RealWebSocket(
fun failWebSocket(e: Exception, response: Response?) {
val streamsToClose: Streams?
val readerToClose: WebSocketReader?
val writerToClose: WebSocketWriter?
synchronized(this) {
if (failed) return // Already failed.
failed = true
streamsToClose = this.streams
this.streams = null
readerToClose = this.reader
this.reader = null
writerToClose = this.writer
this.writer = null
taskQueue.shutdown()
}
@@ -517,6 +550,8 @@ class RealWebSocket(
listener.onFailure(this, e, response)
} finally {
streamsToClose?.closeQuietly()
readerToClose?.closeQuietly()
writerToClose?.closeQuietly()
}
}

View File

@@ -16,7 +16,7 @@
package okhttp3.internal.ws
import java.io.IOException
import okhttp3.Response
import okhttp3.Headers
import okhttp3.internal.delimiterOffset
import okhttp3.internal.trimSubstring
@@ -82,7 +82,7 @@ data class WebSocketExtensions(
private const val HEADER_WEB_SOCKET_EXTENSION = "Sec-WebSocket-Extensions"
@Throws(IOException::class)
fun parse(response: Response): WebSocketExtensions {
fun parse(responseHeaders: Headers): WebSocketExtensions {
// Note that this code does case-insensitive comparisons, even though the spec doesn't specify
// whether extension tokens and parameters are case-insensitive or not.
@@ -94,11 +94,11 @@ data class WebSocketExtensions(
var unexpectedValues = false
// Parse each header.
for (i in 0 until response.headers.size) {
if (!response.headers.name(i).equals(HEADER_WEB_SOCKET_EXTENSION, ignoreCase = true)) {
for (i in 0 until responseHeaders.size) {
if (!responseHeaders.name(i).equals(HEADER_WEB_SOCKET_EXTENSION, ignoreCase = true)) {
continue // Not a header we're interested in.
}
val header = response.headers.value(i)
val header = responseHeaders.value(i)
// Parse each extension.
var pos = 0

View File

@@ -15,6 +15,7 @@
*/
package okhttp3.internal.ws
import java.io.Closeable
import java.io.IOException
import java.net.ProtocolException
import java.util.concurrent.TimeUnit
@@ -50,19 +51,20 @@ import okio.ByteString
*
* [rfc_6455]: http://tools.ietf.org/html/rfc6455
*/
internal class WebSocketReader(
class WebSocketReader(
private val isClient: Boolean,
val source: BufferedSource,
private val frameCallback: FrameCallback
) {
var closed = false
private val frameCallback: FrameCallback,
private val messageInflater: MessageInflater?
) : Closeable {
private var closed = false
// Stateful data about the current frame.
private var opcode = 0
private var frameLength = 0L
private var isFinalFrame = false
private var isControlFrame = false
private var readingCompressedMessage = false
private val controlFrameBuffer = Buffer()
private val messageFrameBuffer = Buffer()
@@ -125,13 +127,26 @@ internal class WebSocketReader(
}
val reservedFlag1 = b0 and B0_FLAG_RSV1 != 0
val reservedFlag2 = b0 and B0_FLAG_RSV2 != 0
val reservedFlag3 = b0 and B0_FLAG_RSV3 != 0
if (reservedFlag1 || reservedFlag2 || reservedFlag3) {
// Reserved flags are for extensions which we currently do not support.
throw ProtocolException("Reserved flags are unsupported.")
when (opcode) {
OPCODE_TEXT, OPCODE_BINARY -> {
if (reservedFlag1) {
if (messageInflater == null) throw ProtocolException("Unexpected rsv1 flag")
readingCompressedMessage = true
} else {
readingCompressedMessage = false
}
}
else -> {
if (reservedFlag1) throw ProtocolException("Unexpected rsv1 flag")
}
}
val reservedFlag2 = b0 and B0_FLAG_RSV2 != 0
if (reservedFlag2) throw ProtocolException("Unexpected rsv2 flag")
val reservedFlag3 = b0 and B0_FLAG_RSV3 != 0
if (reservedFlag3) throw ProtocolException("Unexpected rsv3 flag")
val b1 = source.readByte() and 0xff
val isMasked = b1 and B1_FLAG_MASK != 0
@@ -216,6 +231,10 @@ internal class WebSocketReader(
readMessage()
if (readingCompressedMessage) {
messageInflater!!.inflate(messageFrameBuffer)
}
if (opcode == OPCODE_TEXT) {
frameCallback.onReadMessage(messageFrameBuffer.readUtf8())
} else {
@@ -264,4 +283,9 @@ internal class WebSocketReader(
}
}
}
@Throws(IOException::class)
override fun close() {
messageInflater?.close()
}
}

View File

@@ -15,9 +15,11 @@
*/
package okhttp3.internal.ws
import java.io.Closeable
import java.io.IOException
import java.util.Random
import okhttp3.internal.ws.WebSocketProtocol.B0_FLAG_FIN
import okhttp3.internal.ws.WebSocketProtocol.B0_FLAG_RSV1
import okhttp3.internal.ws.WebSocketProtocol.B1_FLAG_MASK
import okhttp3.internal.ws.WebSocketProtocol.OPCODE_CONTROL_CLOSE
import okhttp3.internal.ws.WebSocketProtocol.OPCODE_CONTROL_PING
@@ -39,11 +41,15 @@ import okio.ByteString
*
* [rfc_6455]: http://tools.ietf.org/html/rfc6455
*/
internal class WebSocketWriter(
class WebSocketWriter(
private val isClient: Boolean,
val sink: BufferedSink,
val random: Random
) {
val random: Random,
private val messageDeflater: MessageDeflater?,
private val minimumDeflateSize: Long
) : Closeable {
/** This buffer is holds outbound data for compression and masking. */
private val messageBuffer = Buffer()
/** The [Buffer] of [sink]. Write to this and then flush/emit [sink]. */
private val sinkBuffer: Buffer = sink.buffer
@@ -136,7 +142,15 @@ internal class WebSocketWriter(
fun writeMessageFrame(formatOpcode: Int, data: ByteString) {
if (writerClosed) throw IOException("closed")
val b0 = formatOpcode or B0_FLAG_FIN
messageBuffer.write(data)
var b0 = formatOpcode or B0_FLAG_FIN
val messageDeflater = this.messageDeflater
if (messageDeflater != null && data.size >= minimumDeflateSize) {
messageDeflater.deflate(messageBuffer)
b0 = b0 or B0_FLAG_RSV1
}
val dataSize = messageBuffer.size
sinkBuffer.writeByte(b0)
var b1 = 0
@@ -144,19 +158,19 @@ internal class WebSocketWriter(
b1 = b1 or B1_FLAG_MASK
}
when {
data.size <= PAYLOAD_BYTE_MAX -> {
b1 = b1 or data.size
dataSize <= PAYLOAD_BYTE_MAX -> {
b1 = b1 or dataSize.toInt()
sinkBuffer.writeByte(b1)
}
data.size <= PAYLOAD_SHORT_MAX -> {
dataSize <= PAYLOAD_SHORT_MAX -> {
b1 = b1 or PAYLOAD_SHORT
sinkBuffer.writeByte(b1)
sinkBuffer.writeShort(data.size)
sinkBuffer.writeShort(dataSize.toInt())
}
else -> {
b1 = b1 or PAYLOAD_LONG
sinkBuffer.writeByte(b1)
sinkBuffer.writeLong(data.size.toLong())
sinkBuffer.writeLong(dataSize)
}
}
@@ -164,19 +178,19 @@ internal class WebSocketWriter(
random.nextBytes(maskKey!!)
sinkBuffer.write(maskKey)
if (data.size > 0L) {
val bufferStart = sinkBuffer.size
sinkBuffer.write(data)
sinkBuffer.readAndWriteUnsafe(maskCursor!!)
maskCursor.seek(bufferStart)
if (dataSize > 0L) {
messageBuffer.readAndWriteUnsafe(maskCursor!!)
maskCursor.seek(0L)
toggleMask(maskCursor, maskKey)
maskCursor.close()
}
} else {
sinkBuffer.write(data)
}
sinkBuffer.write(messageBuffer, dataSize)
sink.emit()
}
override fun close() {
messageDeflater?.close()
}
}

View File

@@ -21,10 +21,12 @@ import java.net.ProtocolException;
import java.net.SocketTimeoutException;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import okhttp3.Headers;
import okhttp3.Protocol;
import okhttp3.Request;
import okhttp3.Response;
import okhttp3.internal.concurrent.TaskRunner;
import okio.Buffer;
import okio.ByteString;
import okio.Okio;
import okio.Pipe;
@@ -357,6 +359,48 @@ public final class RealWebSocketTest {
250d));
}
@Test public void writesUncompressedMessageIfCompressionDisabled() throws IOException {
server.initWebSocket(random, 0);
server.webSocket.send(ByteString.encodeUtf8("Hello"));
Buffer buffer = new Buffer();
server2client.source().read(buffer, Integer.MAX_VALUE);
assertThat(buffer.readByteString())
.isEqualTo(ByteString.decodeHex("820548656c6c6f")); // Uncompressed Hello
}
@Test public void writesUncompressedMessageIfMessageTooSmall()
throws IOException {
server.initWebSocket(random, 0,
Headers.of("Sec-WebSocket-Extensions", "permessage-deflate"));
// Length 5 is less than 10, our minimum compressed size.
server.webSocket.send(ByteString.encodeUtf8("Hello"));
Buffer buffer = new Buffer();
server2client.source().read(buffer, Integer.MAX_VALUE);
assertThat(buffer.readByteString())
.isEqualTo(ByteString.decodeHex("820548656c6c6f")); // Uncompressed
}
@Ignore
@Test public void writesCompressedMessage() throws IOException {
server.initWebSocket(random, 0,
Headers.of("Sec-WebSocket-Extensions", "permessage-deflate"));
// Length 35 is greater than 10, our minimum compressed size.
server.webSocket.send(ByteString.encodeUtf8("Hello Hello Hello Hello Hello Hello"));
Buffer buffer = new Buffer();
server2client.source().read(buffer, Integer.MAX_VALUE);
assertThat(buffer.readByteString())
.isEqualTo(ByteString.decodeHex("c20bf248cdc9c957f0c0470200")); // Compressed
}
/** One peer's streams, listener, and web socket in the test. */
private static class TestStreams extends RealWebSocket.Streams {
private final String name;
@@ -372,11 +416,17 @@ public final class RealWebSocketTest {
}
public void initWebSocket(Random random, int pingIntervalMillis) throws IOException {
initWebSocket(random, pingIntervalMillis, Headers.of());
}
public void initWebSocket(
Random random, int pingIntervalMillis, Headers responseHeaders) throws IOException {
String url = "http://example.com/websocket";
Response response = new Response.Builder()
.code(101)
.message("OK")
.request(new Request.Builder().url(url).build())
.headers(responseHeaders)
.protocol(Protocol.HTTP_1_1)
.build();
webSocket = new RealWebSocket(

View File

@@ -15,31 +15,19 @@
*/
package okhttp3.internal.ws
import okhttp3.Protocol
import okhttp3.Request
import okhttp3.Response
import okhttp3.Headers.Companion.headersOf
import org.assertj.core.api.Assertions.assertThat
import org.junit.Test
class WebSocketExtensionsTest {
private val minimalResponse = Response.Builder()
.protocol(Protocol.HTTP_1_1)
.code(200)
.message("OK")
.request(
Request.Builder()
.url("https://example.com/")
.build()
)
.build()
@Test
fun emptyHeader() {
assertThat(parse("")).isEqualTo(WebSocketExtensions())
}
@Test fun noExtensionHeader() {
assertThat(WebSocketExtensions.parse(minimalResponse))
@Test
fun noExtensionHeader() {
assertThat(WebSocketExtensions.parse(headersOf()))
.isEqualTo(WebSocketExtensions())
}
@@ -79,11 +67,10 @@ class WebSocketExtensionsTest {
@Test
fun multiplePerMessageDeflateHeaders() {
val response = minimalResponse.newBuilder()
.header("Sec-WebSocket-Extensions", "")
.header("Sec-WebSocket-Extensions", "permessage-deflate")
.build()
val extensions = WebSocketExtensions.parse(response)
val extensions = WebSocketExtensions.parse(headersOf(
"Sec-WebSocket-Extensions", "",
"Sec-WebSocket-Extensions", "permessage-deflate"
))
assertThat(extensions)
.isEqualTo(WebSocketExtensions(
perMessageDeflate = true
@@ -232,10 +219,6 @@ class WebSocketExtensionsTest {
))
}
private fun parse(extension: String): WebSocketExtensions {
val response = minimalResponse.newBuilder()
.header("Sec-WebSocket-Extensions", extension)
.build()
return WebSocketExtensions.parse(response)
}
private fun parse(extension: String) =
WebSocketExtensions.parse(headersOf("Sec-WebSocket-Extensions", extension))
}

View File

@@ -780,6 +780,33 @@ public final class WebSocketHttpTest {
webSocket.close(1000, null);
}
@Ignore
@Test public void compressedMessages() {
client = client.newBuilder()
.addInterceptor(chain -> {
assertThat(chain.request().header("Sec-WebSocket-Extensions"))
.isEqualTo("permessage-deflate");
return chain.proceed(chain.request());
})
.build();
webServer.enqueue(new MockResponse()
.addHeader("Sec-WebSocket-Extensions", "permessage-deflate")
.withWebSocketUpgrade(serverListener));
WebSocket client = newWebSocket();
clientListener.assertOpen();
WebSocket server = serverListener.assertOpen();
server.send("Hello this is a compressed message from the server!");
clientListener.assertTextMessage("Hello this is a compressed message from the server!");
client.send("Hello this is a compressed message from the client!");
serverListener.assertTextMessage("Hello this is a compressed message from the client!");
closeWebSockets(client, server);
}
private MockResponse upgradeResponse(RecordedRequest request) {
String key = request.getHeader("Sec-WebSocket-Key");
return new MockResponse()

View File

@@ -34,8 +34,14 @@ public final class WebSocketReaderTest {
private final Random random = new Random(0);
// Mutually exclusive. Use the one corresponding to the peer whose behavior you wish to test.
final WebSocketReader serverReader = new WebSocketReader(false, data, callback.asFrameCallback());
final WebSocketReader clientReader = new WebSocketReader(true, data, callback.asFrameCallback());
final WebSocketReader serverReader =
new WebSocketReader(false, data, callback.asFrameCallback(), null);
final WebSocketReader serverReaderWithCompression =
new WebSocketReader(false, data, callback.asFrameCallback(), new MessageInflater(false));
final WebSocketReader clientReader =
new WebSocketReader(true, data, callback.asFrameCallback(), null);
final WebSocketReader clientReaderWithCompression =
new WebSocketReader(true, data, callback.asFrameCallback(), new MessageInflater(false));
@After public void tearDown() {
callback.assertExhausted();
@@ -51,21 +57,43 @@ public final class WebSocketReaderTest {
}
}
@Test public void reservedFlagsAreUnsupported() throws IOException {
@Test public void reservedFlag1IsUnsupportedWithNoCompression() throws IOException {
data.write(ByteString.decodeHex("ca00")); // Empty pong, flag 1 set.
try {
clientReader.processNextFrame();
fail();
} catch (ProtocolException e) {
assertThat(e.getMessage()).isEqualTo("Reserved flags are unsupported.");
assertThat(e.getMessage()).isEqualTo("Unexpected rsv1 flag");
}
data.clear();
}
@Test public void reservedFlag1IsUnsupportedForControlFrames() throws IOException {
data.write(ByteString.decodeHex("ca00")); // Empty pong, flag 1 set.
try {
clientReaderWithCompression.processNextFrame();
fail();
} catch (ProtocolException e) {
assertThat(e.getMessage()).isEqualTo("Unexpected rsv1 flag");
}
}
@Test public void reservedFlag1IsUnsupportedForContinuationFrames() throws IOException {
data.write(ByteString.decodeHex("c000")); // Empty continuation, flag 1 set.
try {
clientReaderWithCompression.processNextFrame();
fail();
} catch (ProtocolException e) {
assertThat(e.getMessage()).isEqualTo("Unexpected rsv1 flag");
}
}
@Test public void reservedFlags2and3AreUnsupported() throws IOException {
data.write(ByteString.decodeHex("aa00")); // Empty pong, flag 2 set.
try {
clientReader.processNextFrame();
fail();
} catch (ProtocolException e) {
assertThat(e.getMessage()).isEqualTo("Reserved flags are unsupported.");
assertThat(e.getMessage()).isEqualTo("Unexpected rsv2 flag");
}
data.clear();
data.write(ByteString.decodeHex("9a00")); // Empty pong, flag 3 set.
@@ -73,7 +101,7 @@ public final class WebSocketReaderTest {
clientReader.processNextFrame();
fail();
} catch (ProtocolException e) {
assertThat(e.getMessage()).isEqualTo("Reserved flags are unsupported.");
assertThat(e.getMessage()).isEqualTo("Unexpected rsv3 flag");
}
}
@@ -113,12 +141,36 @@ public final class WebSocketReaderTest {
callback.assertTextMessage("Hello");
}
@Test public void clientWithCompressionSimpleUncompressedHello() throws IOException {
data.write(ByteString.decodeHex("810548656c6c6f")); // Hello
clientReaderWithCompression.processNextFrame();
callback.assertTextMessage("Hello");
}
@Test public void clientWithCompressionSimpleCompressedHello() throws IOException {
data.write(ByteString.decodeHex("c107f248cdc9c90700")); // Hello
clientReaderWithCompression.processNextFrame();
callback.assertTextMessage("Hello");
}
@Test public void serverSimpleHello() throws IOException {
data.write(ByteString.decodeHex("818537fa213d7f9f4d5158")); // Hello
serverReader.processNextFrame();
callback.assertTextMessage("Hello");
}
@Test public void serverWithCompressionSimpleUncompressedHello() throws IOException {
data.write(ByteString.decodeHex("818537fa213d7f9f4d5158")); // Hello
serverReaderWithCompression.processNextFrame();
callback.assertTextMessage("Hello");
}
@Test public void serverWithCompressionSimpleCompressedHello() throws IOException {
data.write(ByteString.decodeHex("c18760b420bb92fced72a9b320")); // Hello
serverReaderWithCompression.processNextFrame();
callback.assertTextMessage("Hello");
}
@Test public void clientFramePayloadShort() throws IOException {
data.write(ByteString.decodeHex("817E000548656c6c6f")); // Hello
clientReader.processNextFrame();
@@ -151,6 +203,23 @@ public final class WebSocketReaderTest {
callback.assertTextMessage("Hello");
}
@Test public void serverWithCompressionHelloTwoChunks() throws IOException {
data.write(ByteString.decodeHex("818537fa213d7f9f4d")); // Hel
data.write(ByteString.decodeHex("5158")); // lo
serverReaderWithCompression.processNextFrame();
callback.assertTextMessage("Hello");
}
@Test public void serverWithCompressionCompressedHelloTwoChunks() throws IOException {
data.write(ByteString.decodeHex("418460b420bb92fced72")); // first 4 bytes of compressed 'Hello'
data.write(ByteString.decodeHex("80833851d9d4f156d9")); // last 3 bytes of compressed 'Hello'
serverReaderWithCompression.processNextFrame();
callback.assertTextMessage("Hello");
}
@Test public void clientTwoFrameHello() throws IOException {
data.write(ByteString.decodeHex("010348656c")); // Hel
data.write(ByteString.decodeHex("80026c6f")); // lo
@@ -158,6 +227,20 @@ public final class WebSocketReaderTest {
callback.assertTextMessage("Hello");
}
@Test public void clientWithCompressionTwoFrameHello() throws IOException {
data.write(ByteString.decodeHex("010348656c")); // Hel
data.write(ByteString.decodeHex("80026c6f")); // lo
clientReaderWithCompression.processNextFrame();
callback.assertTextMessage("Hello");
}
@Test public void clientWithCompressionTwoFrameCompressedHello() throws IOException {
data.write(ByteString.decodeHex("4104f248cdc9")); // first 4 bytes of compressed 'Hello'
data.write(ByteString.decodeHex("8003c90700")); // last 3 bytes of compressed 'Hello'
clientReaderWithCompression.processNextFrame();
callback.assertTextMessage("Hello");
}
@Test public void clientTwoFrameHelloWithPongs() throws IOException {
data.write(ByteString.decodeHex("010348656c")); // Hel
data.write(ByteString.decodeHex("8a00")); // Pong
@@ -173,6 +256,21 @@ public final class WebSocketReaderTest {
callback.assertTextMessage("Hello");
}
@Test public void clientTwoFrameCompressedHelloWithPongs() throws IOException {
data.write(ByteString.decodeHex("4104f248cdc9")); // first 4 bytes of compressed 'Hello'
data.write(ByteString.decodeHex("8a00")); // Pong
data.write(ByteString.decodeHex("8a00")); // Pong
data.write(ByteString.decodeHex("8a00")); // Pong
data.write(ByteString.decodeHex("8a00")); // Pong
data.write(ByteString.decodeHex("8003c90700")); // last 3 bytes of compressed 'Hello'
clientReaderWithCompression.processNextFrame();
callback.assertPong(ByteString.EMPTY);
callback.assertPong(ByteString.EMPTY);
callback.assertPong(ByteString.EMPTY);
callback.assertPong(ByteString.EMPTY);
callback.assertTextMessage("Hello");
}
@Test public void clientIncompleteMessageBodyThrows() throws IOException {
data.write(ByteString.decodeHex("810548656c")); // Length = 5, "Hel"
try {
@@ -182,6 +280,15 @@ public final class WebSocketReaderTest {
}
}
@Test public void clientUncompressedMessageWithCompressedFlagThrows() throws IOException {
data.write(ByteString.decodeHex("c10548656c6c6f")); // Uncompressed 'Hello', flag 1 set
try {
clientReaderWithCompression.processNextFrame();
fail();
} catch (IOException ignored) {
}
}
@Test public void clientIncompleteControlFrameBodyThrows() throws IOException {
data.write(ByteString.decodeHex("8a0548656c")); // Length = 5, "Hel"
try {
@@ -315,6 +422,17 @@ public final class WebSocketReaderTest {
assertThat(count).isEqualTo(1988);
}
@Test public void clientWithCompressionCannotBeUsedAfterClose() throws IOException {
data.write(ByteString.decodeHex("c107f248cdc9c90700")); // Hello
clientReaderWithCompression.close();
try {
clientReaderWithCompression.processNextFrame();
fail();
} catch (Exception e) {
assertThat(e.getMessage()).contains("closed");
}
}
private byte[] binaryData(int length) {
byte[] junk = new byte[length];
random.nextBytes(junk);

View File

@@ -52,14 +52,21 @@ public final class WebSocketWriterTest {
};
// Mutually exclusive. Use the one corresponding to the peer whose behavior you wish to test.
private final WebSocketWriter serverWriter = new WebSocketWriter(false, data, random);
private final WebSocketWriter clientWriter = new WebSocketWriter(true, data, random);
private final WebSocketWriter serverWriter = new WebSocketWriter(false, data, random, null, 0L);
private final WebSocketWriter clientWriter = new WebSocketWriter(true, data, random, null, 0L);
@Test public void serverTextMessage() throws IOException {
serverWriter.writeMessageFrame(OPCODE_TEXT, ByteString.encodeUtf8("Hello"));
assertData("810548656c6c6f");
}
@Test public void serverCompressedTextMessage() throws IOException {
WebSocketWriter serverWriter = new WebSocketWriter(
false, data, random, new MessageDeflater(false), 0L);
serverWriter.writeMessageFrame(OPCODE_TEXT, ByteString.encodeUtf8("Hello"));
assertData("c107f248cdc9c90700");
}
@Test public void serverSmallBufferedPayloadWrittenAsOneFrame() throws IOException {
int length = 5;
ByteString payload = ByteString.of(binaryData(length));
@@ -85,6 +92,13 @@ public final class WebSocketWriterTest {
assertData("818560b420bb28d14cd70f");
}
@Test public void clientCompressedTextMessage() throws IOException {
WebSocketWriter clientWriter = new WebSocketWriter(
false, data, random, new MessageDeflater(false), 0L);
clientWriter.writeMessageFrame(OPCODE_TEXT, ByteString.encodeUtf8("Hello"));
assertData("c107f248cdc9c90700");
}
@Test public void serverBinaryMessage() throws IOException {
ByteString payload = ByteString.decodeHex(""
+ "60b420bb3851d9d47acb933dbe70399bf6c92da33af01d4fb7"