1
0
mirror of https://github.com/square/okhttp.git synced 2026-01-14 07:22:20 +03:00

Merge pull request #5871 from square/jwilson.0315.websockets

Turn on web socket compression
This commit is contained in:
Jesse Wilson
2020-03-17 21:19:16 -04:00
committed by GitHub
6 changed files with 166 additions and 33 deletions

View File

@@ -248,8 +248,6 @@ open class OkHttpClient internal constructor(
/** Uses [request] to connect a new web socket. */
override fun newWebSocket(request: Request, listener: WebSocketListener): WebSocket {
// Compress messages 1 KiB or larger.
val minimumDeflateSize = 1024L
val webSocket = RealWebSocket(
taskRunner = TaskRunner.INSTANCE,
originalRequest = request,
@@ -257,7 +255,7 @@ open class OkHttpClient internal constructor(
random = Random(),
pingIntervalMillis = pingIntervalMillis.toLong(),
extensions = null, // Always null for clients.
minimumDeflateSize = minimumDeflateSize
minimumDeflateSize = RealWebSocket.DEFAULT_MINIMUM_DEFLATE_SIZE
)
webSocket.connect(this)
return webSocket

View File

@@ -35,6 +35,14 @@ class MessageInflater(
fun inflate(buffer: Buffer) {
require(deflatedBytes.size == 0L)
// Handle the empty message special case. The compressed empty message is one byte, '0x00'. We
// can't use the normal flow here because inflaterSource.read() throws EOFException if the
// deflated stream isn't complete but there's no bytes to return.
if (buffer.size == 1L && buffer[0L] == 0.toByte()) {
buffer.skip(1L)
return
}
if (noContextTakeover) {
inflater.reset()
}

View File

@@ -144,6 +144,12 @@ class RealWebSocket(
}
fun connect(client: OkHttpClient) {
if (originalRequest.header("Sec-WebSocket-Extensions") != null) {
failWebSocket(ProtocolException(
"Request header not permitted: 'Sec-WebSocket-Extensions'"), null)
return
}
val webSocketClient = client.newBuilder()
.eventListener(EventListener.NONE)
.protocols(ONLY_HTTP1)
@@ -153,6 +159,7 @@ class RealWebSocket(
.header("Connection", "Upgrade")
.header("Sec-WebSocket-Key", key)
.header("Sec-WebSocket-Version", "13")
.header("Sec-WebSocket-Extensions", "permessage-deflate")
.build()
call = RealCall(webSocketClient, request, forWebSocket = true)
call!!.enqueue(object : Callback {
@@ -162,8 +169,6 @@ class RealWebSocket(
try {
checkUpgradeSuccess(response, exchange)
streams = exchange!!.newWebSocketStreams()
// TODO(jwilson): use request & response headers to negotiate extensions.
extensions = WebSocketExtensions()
} catch (e: IOException) {
exchange?.webSocketUpgradeFailed()
failWebSocket(e, response)
@@ -171,6 +176,17 @@ class RealWebSocket(
return
}
// Apply the extensions. If they're unacceptable initiate a graceful shut down.
// TODO(jwilson): Listeners should get onFailure() instead of onClosing() + onClosed(1010).
val extensions = WebSocketExtensions.parse(response.headers)
this@RealWebSocket.extensions = extensions
if (!extensions.isValid()) {
synchronized(this@RealWebSocket) {
messageAndCloseQueue.clear() // Don't transmit any messages.
close(1010, "unexpected Sec-WebSocket-Extensions in response header")
}
}
// Process all web socket messages.
try {
val name = "$okHttpName WebSocket ${request.url.redact()}"
@@ -188,6 +204,20 @@ class RealWebSocket(
})
}
private fun WebSocketExtensions.isValid(): Boolean {
// If the server returned parameters we don't understand, fail the web socket.
if (unknownValues) return false
// If the server returned a value for client_max_window_bits, fail the web socket.
if (clientMaxWindowBits != null) return false
// If the server returned an illegal server_max_window_bits, fail the web socket.
if (serverMaxWindowBits != null && serverMaxWindowBits !in 8..15) return false
// Success.
return true
}
@Throws(IOException::class)
internal fun checkUpgradeSuccess(response: Response, exchange: Exchange?) {
if (response.code != 101) {
@@ -607,5 +637,15 @@ class RealWebSocket(
* the server doesn't respond the web socket will be canceled.
*/
private const val CANCEL_AFTER_CLOSE_MILLIS = 60L * 1000
/**
* The smallest message that will be compressed. We use 1024 because smaller messages already
* fit comfortably within a single ethernet packet (1500 bytes) even with framing overhead.
*
* For tests this must be big enough to realize real compression on test messages like
* 'aaaaaaaaaa...'. Our tests check if compression was applied just by looking at the size if
* the inbound buffer.
*/
const val DEFAULT_MINIMUM_DEFLATE_SIZE = 1024L
}
}

View File

@@ -50,6 +50,19 @@ internal class MessageDeflaterInflaterTest {
assertThat(inflated).isEqualTo(goldenValue)
}
@Test fun `inflate deflate empty message`() {
val deflater = MessageDeflater(false)
val inflater = MessageInflater(false)
val goldenValue = "".encodeUtf8()
val deflated = deflater.deflate(goldenValue)
assertThat(deflated).isEqualTo("00".decodeHex())
val inflated = inflater.inflate(deflated)
assertThat(inflated).isEqualTo(goldenValue)
}
@Test fun `inflate deflate with context takeover`() {
val deflater = MessageDeflater(false)
val inflater = MessageInflater(false)

View File

@@ -35,6 +35,7 @@ import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import static okhttp3.internal.ws.RealWebSocket.DEFAULT_MINIMUM_DEFLATE_SIZE;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.data.Offset.offset;
import static org.junit.Assert.fail;
@@ -44,16 +45,9 @@ public final class RealWebSocketTest {
// zero effect on the behavior of the WebSocket API which is why tests are only written once
// from the perspective of a single peer.
/**
* Compress messages of length 10 bytes or longer. This should be big enough to realize real
* compression on a message like 'aaaaaaaaaa...'. We check if compression was applied just by
* looking at the size if the inbound buffer.
*/
private static final int MINIMUM_DEFLATE_SIZE = 10;
private final Random random = new Random(0);
private final Pipe client2Server = new Pipe(1024L);
private final Pipe server2client = new Pipe(1024L);
private final Pipe client2Server = new Pipe(8192L);
private final Pipe server2client = new Pipe(8192L);
private TestStreams client = new TestStreams(true, server2client, client2Server);
private TestStreams server = new TestStreams(false, client2Server, server2client);
@@ -367,7 +361,7 @@ public final class RealWebSocketTest {
}
@Test public void messagesNotCompressedWhenNotConfigured() throws IOException {
String message = TestUtil.repeat('a', MINIMUM_DEFLATE_SIZE);
String message = TestUtil.repeat('a', (int) DEFAULT_MINIMUM_DEFLATE_SIZE);
server.webSocket.send(message);
assertThat(client.clientSourceBufferSize()).isGreaterThan(message.length()); // Not compressed.
@@ -380,7 +374,7 @@ public final class RealWebSocketTest {
client.initWebSocket(random, 0, headers);
server.initWebSocket(random, 0, headers);
String message = TestUtil.repeat('a', MINIMUM_DEFLATE_SIZE);
String message = TestUtil.repeat('a', (int) DEFAULT_MINIMUM_DEFLATE_SIZE);
server.webSocket.send(message);
assertThat(client.clientSourceBufferSize()).isLessThan(message.length()); // Compressed!
@@ -393,7 +387,7 @@ public final class RealWebSocketTest {
client.initWebSocket(random, 0, headers);
server.initWebSocket(random, 0, headers);
String message = TestUtil.repeat('a', MINIMUM_DEFLATE_SIZE - 1);
String message = TestUtil.repeat('a', (int) DEFAULT_MINIMUM_DEFLATE_SIZE - 1);
server.webSocket.send(message);
assertThat(client.clientSourceBufferSize()).isGreaterThan(message.length()); // Not compressed.
@@ -431,7 +425,7 @@ public final class RealWebSocketTest {
.build();
webSocket = new RealWebSocket(TaskRunner.INSTANCE, response.request(), listener, random,
pingIntervalMillis, WebSocketExtensions.Companion.parse(responseHeaders),
MINIMUM_DEFLATE_SIZE);
DEFAULT_MINIMUM_DEFLATE_SIZE);
webSocket.initReaderAndWriter(name, this);
}
@@ -441,7 +435,7 @@ public final class RealWebSocketTest {
*/
public long clientSourceBufferSize() throws IOException {
getSource().request(1L);
return getSource().buffer().size();
return getSource().getBuffer().size();
}
public boolean processNextFrame() throws IOException {

View File

@@ -370,6 +370,16 @@ public final class WebSocketHttpTest {
"Expected 'Sec-WebSocket-Accept' header value 'ujmZX4KXZqjwy6vi1aQFH5p4Ygk=' but was 'magic'");
}
@Test public void clientIncludesForbiddenHeader() throws IOException {
newWebSocket(new Request.Builder()
.url(webServer.url("/"))
.header("Sec-WebSocket-Extensions", "permessage-deflate")
.build());
clientListener.assertFailure(ProtocolException.class,
"Request header not permitted: 'Sec-WebSocket-Extensions'");
}
@Test public void webSocketAndApplicationInterceptors() {
final AtomicInteger interceptedCount = new AtomicInteger();
@@ -780,31 +790,101 @@ public final class WebSocketHttpTest {
webSocket.close(1000, null);
}
@Ignore
@Test public void compressedMessages() {
client = client.newBuilder()
.addInterceptor(chain -> {
assertThat(chain.request().header("Sec-WebSocket-Extensions"))
.isEqualTo("permessage-deflate");
return chain.proceed(chain.request());
})
.build();
@Test public void compressedMessages() throws Exception {
successfulExtensions("permessage-deflate");
}
@Test public void compressedMessagesNoClientContextTakeover() throws Exception {
successfulExtensions("permessage-deflate; client_no_context_takeover");
}
@Test public void compressedMessagesNoServerContextTakeover() throws Exception {
successfulExtensions("permessage-deflate; server_no_context_takeover");
}
@Test public void unexpectedExtensionParameter() throws Exception {
extensionNegotiationFailure("permessage-deflate; unknown_parameter=15");
}
@Test public void clientMaxWindowBitsIncluded() throws Exception {
extensionNegotiationFailure("permessage-deflate; client_max_window_bits=15");
}
@Test public void serverMaxWindowBitsTooLow() throws Exception {
extensionNegotiationFailure("permessage-deflate; server_max_window_bits=7");
}
@Test public void serverMaxWindowBitsTooHigh() throws Exception {
extensionNegotiationFailure("permessage-deflate; server_max_window_bits=16");
}
@Test public void serverMaxWindowBitsJustRight() throws Exception {
successfulExtensions("permessage-deflate; server_max_window_bits=15");
}
private void successfulExtensions(String extensionsHeader) throws Exception {
webServer.enqueue(new MockResponse()
.addHeader("Sec-WebSocket-Extensions", "permessage-deflate")
.addHeader("Sec-WebSocket-Extensions", extensionsHeader)
.withWebSocketUpgrade(serverListener));
WebSocket client = newWebSocket();
clientListener.assertOpen();
WebSocket server = serverListener.assertOpen();
server.send("Hello this is a compressed message from the server!");
clientListener.assertTextMessage("Hello this is a compressed message from the server!");
// Server to client message big enough to be compressed.
String message1 = TestUtil.repeat('a', (int) RealWebSocket.DEFAULT_MINIMUM_DEFLATE_SIZE);
server.send(message1);
clientListener.assertTextMessage(message1);
client.send("Hello this is a compressed message from the client!");
serverListener.assertTextMessage("Hello this is a compressed message from the client!");
// Client to server message big enough to be compressed.
String message2 = TestUtil.repeat('b', (int) RealWebSocket.DEFAULT_MINIMUM_DEFLATE_SIZE);
client.send(message2);
serverListener.assertTextMessage(message2);
// Empty server to client message.
String message3 = "";
server.send(message3);
clientListener.assertTextMessage(message3);
// Empty client to server message.
String message4 = "";
client.send(message4);
serverListener.assertTextMessage(message4);
// Server to client message that shares context with message1.
String message5 = message1 + message1;
server.send(message5);
clientListener.assertTextMessage(message5);
// Client to server message that shares context with message2.
String message6 = message2 + message2;
client.send(message6);
serverListener.assertTextMessage(message6);
closeWebSockets(client, server);
RecordedRequest upgradeRequest = webServer.takeRequest();
assertThat(upgradeRequest.getHeader("Sec-WebSocket-Extensions"))
.isEqualTo("permessage-deflate");
}
private void extensionNegotiationFailure(String extensionsHeader) throws Exception {
webServer.enqueue(new MockResponse()
.addHeader("Sec-WebSocket-Extensions", extensionsHeader)
.withWebSocketUpgrade(serverListener));
newWebSocket();
clientListener.assertOpen();
WebSocket server = serverListener.assertOpen();
String clientReason = "unexpected Sec-WebSocket-Extensions in response header";
serverListener.assertClosing(1010, clientReason);
server.close(1010, "");
clientListener.assertClosing(1010, "");
clientListener.assertClosed(1010, "");
serverListener.assertClosed(1010, clientReason);
clientListener.assertExhausted();
serverListener.assertExhausted();
}
private MockResponse upgradeResponse(RecordedRequest request) {