1
0
mirror of https://github.com/facebookincubator/mvfst.git synced 2025-11-09 10:00:57 +03:00

Implement handshake done and cipher dropping.

Summary: This implements the handshake done signal and also cipher dropping.

Reviewed By: yangchi

Differential Revision: D19584922

fbshipit-source-id: a98bec8f1076393b051ff65a2d8aae7d572b42f5
This commit is contained in:
Matt Joras
2020-02-27 12:16:23 -08:00
committed by Facebook Github Bot
parent 49d262c84c
commit 472e40a902
26 changed files with 405 additions and 311 deletions

View File

@@ -62,7 +62,6 @@ constexpr uint64_t kDefaultBufferSpaceAvailable =
constexpr std::chrono::microseconds kDefaultMinRtt = constexpr std::chrono::microseconds kDefaultMinRtt =
std::chrono::microseconds::max(); std::chrono::microseconds::max();
// Frames types with values defines in Quic Draft 15+
enum class FrameType : uint8_t { enum class FrameType : uint8_t {
PADDING = 0x00, PADDING = 0x00,
PING = 0x01, PING = 0x01,
@@ -97,8 +96,9 @@ enum class FrameType : uint8_t {
CONNECTION_CLOSE = 0x1C, CONNECTION_CLOSE = 0x1C,
// CONNECTION_CLOSE_APP_ERR frametype is use to indicate application errors // CONNECTION_CLOSE_APP_ERR frametype is use to indicate application errors
CONNECTION_CLOSE_APP_ERR = 0x1D, CONNECTION_CLOSE_APP_ERR = 0x1D,
MIN_STREAM_DATA = 0xFE, // subject to change (https://fburl.com/qpr) HANDSHAKE_DONE = 0x1E,
EXPIRED_STREAM_DATA = 0xFF, // subject to change (https://fburl.com/qpr) MIN_STREAM_DATA = 0xFE, // subject to change
EXPIRED_STREAM_DATA = 0xFF, // subject to change
}; };
inline constexpr uint16_t toFrameError(FrameType frame) { inline constexpr uint16_t toFrameError(FrameType frame) {
@@ -187,7 +187,7 @@ enum class QuicVersion : uint32_t {
MVFST = 0xfaceb001, MVFST = 0xfaceb001,
QUIC_DRAFT_22 = 0xFF000016, QUIC_DRAFT_22 = 0xFF000016,
QUIC_DRAFT_23 = 0xFF000017, QUIC_DRAFT_23 = 0xFF000017,
QUIC_DRAFT = 0xFF000018, // Draft-24 QUIC_DRAFT = 0xFF000019, // Draft-25
MVFST_INVALID = 0xfaceb00f, MVFST_INVALID = 0xfaceb00f,
}; };
@@ -384,10 +384,6 @@ constexpr std::chrono::milliseconds kHappyEyeballsConnAttemptDelayWithCache =
constexpr size_t kMaxNumTokenSourceAddresses = 3; constexpr size_t kMaxNumTokenSourceAddresses = 3;
// Amount of time to retain initial keys until they are dropped after handshake
// completion.
constexpr std::chrono::seconds kTimeToRetainInitialKeys = 20s;
// Amount of time to retain zero rtt keys until they are dropped after handshake // Amount of time to retain zero rtt keys until they are dropped after handshake
// completion. // completion.
constexpr std::chrono::seconds kTimeToRetainZeroRttKeys = 20s; constexpr std::chrono::seconds kTimeToRetainZeroRttKeys = 20s;

View File

@@ -306,13 +306,13 @@ void QuicClientTransport::processPacketData(
// If we received an ack for data that we sent in 1-rtt from // If we received an ack for data that we sent in 1-rtt from
// the server, we can assume that the server had successfully // the server, we can assume that the server had successfully
// derived the 1-rtt keys and hence received the client // derived the 1-rtt keys and hence received the client
// finished message. Thus we don't need to retransmit any of // finished message. Thus we can drop the ciphers and cancel
// the crypto data any longer. // the handshake stream.
// DCHECK(conn_->oneRttWriteCipher);
// This will not cancel oneRttStream. DCHECK(conn_->oneRttWriteHeaderCipher);
// if (conn_->handshakeWriteCipher) {
// TODO: replace this with a better solution later. handshakeConfirmed(*conn_);
cancelHandshakeCryptoStreamRetransmissions(*conn_->cryptoState); }
} }
switch (packetFrame.type()) { switch (packetFrame.type()) {
case QuicWriteFrame::Type::WriteAckFrame_E: { case QuicWriteFrame::Type::WriteAckFrame_E: {
@@ -516,43 +516,54 @@ void QuicClientTransport::processPacketData(
handshakeLayer->doHandshake(std::move(cryptoData), encryptionLevel); handshakeLayer->doHandshake(std::move(cryptoData), encryptionLevel);
auto handshakeWriteCipher = handshakeLayer->getHandshakeWriteCipher(); auto handshakeWriteCipher = handshakeLayer->getHandshakeWriteCipher();
auto handshakeReadCipher = handshakeLayer->getHandshakeReadCipher(); auto handshakeReadCipher = handshakeLayer->getHandshakeReadCipher();
auto handshakeReadHeaderCipher =
handshakeLayer->getHandshakeReadHeaderCipher();
auto handshakeWriteHeaderCipher = auto handshakeWriteHeaderCipher =
handshakeLayer->getHandshakeWriteHeaderCipher(); handshakeLayer->getHandshakeWriteHeaderCipher();
auto handshakeReadHeaderCipher =
handshakeLayer->getHandshakeReadHeaderCipher();
if (handshakeWriteCipher) { if (handshakeWriteCipher) {
CHECK(handshakeWriteHeaderCipher);
conn_->handshakeWriteCipher = std::move(handshakeWriteCipher); conn_->handshakeWriteCipher = std::move(handshakeWriteCipher);
}
if (handshakeWriteHeaderCipher) {
conn_->handshakeWriteHeaderCipher = std::move(handshakeWriteHeaderCipher); conn_->handshakeWriteHeaderCipher = std::move(handshakeWriteHeaderCipher);
} }
if (handshakeReadCipher) { if (handshakeReadCipher) {
CHECK(handshakeReadHeaderCipher);
conn_->readCodec->setHandshakeReadCipher(std::move(handshakeReadCipher)); conn_->readCodec->setHandshakeReadCipher(std::move(handshakeReadCipher));
}
if (handshakeReadHeaderCipher) {
conn_->readCodec->setHandshakeHeaderCipher( conn_->readCodec->setHandshakeHeaderCipher(
std::move(handshakeReadHeaderCipher)); std::move(handshakeReadHeaderCipher));
} }
if (conn_->handshakeWriteCipher &&
conn_->readCodec->getHandshakeReadCipher()) {
// We can now drop the initial ciphers.
conn_->initialWriteCipher.reset();
conn_->initialHeaderCipher.reset();
conn_->readCodec->setInitialReadCipher(nullptr);
conn_->readCodec->setInitialHeaderCipher(nullptr);
cancelCryptoStream(conn_->cryptoState->initialStream);
}
auto oneRttWriteCipher = handshakeLayer->getOneRttWriteCipher(); auto oneRttWriteCipher = handshakeLayer->getOneRttWriteCipher();
auto oneRttReadCipher = handshakeLayer->getOneRttReadCipher(); auto oneRttReadCipher = handshakeLayer->getOneRttReadCipher();
auto oneRttReadHeaderCipher = handshakeLayer->getOneRttReadHeaderCipher(); auto oneRttReadHeaderCipher = handshakeLayer->getOneRttReadHeaderCipher();
auto oneRttWriteHeaderCipher = handshakeLayer->getOneRttWriteHeaderCipher(); auto oneRttWriteHeaderCipher = handshakeLayer->getOneRttWriteHeaderCipher();
bool oneRttKeyDerivationTriggered = false; bool oneRttKeyDerivationTriggered = false;
if (oneRttWriteCipher) { if (oneRttWriteCipher) {
CHECK(oneRttWriteHeaderCipher);
conn_->oneRttWriteCipher = std::move(oneRttWriteCipher); conn_->oneRttWriteCipher = std::move(oneRttWriteCipher);
conn_->oneRttWriteHeaderCipher = std::move(oneRttWriteHeaderCipher);
oneRttKeyDerivationTriggered = true; oneRttKeyDerivationTriggered = true;
updatePacingOnKeyEstablished(*conn_); updatePacingOnKeyEstablished(*conn_);
} }
if (oneRttWriteHeaderCipher) {
conn_->oneRttWriteHeaderCipher = std::move(oneRttWriteHeaderCipher);
}
if (oneRttReadCipher) { if (oneRttReadCipher) {
CHECK(oneRttReadHeaderCipher);
conn_->readCodec->setOneRttReadCipher(std::move(oneRttReadCipher)); conn_->readCodec->setOneRttReadCipher(std::move(oneRttReadCipher));
}
if (oneRttReadHeaderCipher) {
conn_->readCodec->setOneRttHeaderCipher( conn_->readCodec->setOneRttHeaderCipher(
std::move(oneRttReadHeaderCipher)); std::move(oneRttReadHeaderCipher));
} }
if (oneRttWriteCipher && oneRttReadCipher) {
conn_->zeroRttWriteCipher.reset();
conn_->zeroRttWriteHeaderCipher.reset();
conn_->readCodec->setZeroRttReadCipher(nullptr);
conn_->readCodec->setZeroRttHeaderCipher(nullptr);
}
bool zeroRttRejected = handshakeLayer->getZeroRttRejected().value_or(false); bool zeroRttRejected = handshakeLayer->getZeroRttRejected().value_or(false);
if (zeroRttRejected) { if (zeroRttRejected) {
if (conn_->qLogger) { if (conn_->qLogger) {
@@ -643,12 +654,6 @@ void QuicClientTransport::processPacketData(
markZeroRttPacketsLost(*conn_, markPacketLoss); markZeroRttPacketsLost(*conn_, markPacketLoss);
} }
} }
if (protectionLevel == ProtectionType::KeyPhaseZero ||
protectionLevel == ProtectionType::KeyPhaseOne) {
DCHECK(conn_->oneRttWriteCipher);
clientConn_->clientHandshakeLayer->onRecvOneRttProtectedData();
conn_->readCodec->onHandshakeDone(receiveTimePoint);
}
updateAckSendStateOnRecvPacket( updateAckSendStateOnRecvPacket(
*conn_, *conn_,
ackState, ackState,
@@ -696,7 +701,6 @@ void QuicClientTransport::writeData() {
// TODO: replace with write in state machine. // TODO: replace with write in state machine.
// TODO: change to draining when we move the client to have a draining state // TODO: change to draining when we move the client to have a draining state
// as well. // as well.
auto phase = clientConn_->clientHandshakeLayer->getPhase();
QuicVersion version = conn_->version.value_or(*conn_->originalVersion); QuicVersion version = conn_->version.value_or(*conn_->originalVersion);
const ConnectionId& srcConnId = *conn_->clientConnectionId; const ConnectionId& srcConnId = *conn_->clientConnectionId;
const ConnectionId* destConnId = const ConnectionId* destConnId =
@@ -705,29 +709,39 @@ void QuicClientTransport::writeData() {
destConnId = &(*conn_->serverConnectionId); destConnId = &(*conn_->serverConnectionId);
} }
if (closeState_ == CloseState::CLOSED) { if (closeState_ == CloseState::CLOSED) {
// TODO: get rid of phase if (conn_->initialWriteCipher) {
if (phase == ClientHandshake::Phase::Established &&
conn_->oneRttWriteCipher) {
CHECK(conn_->oneRttWriteHeaderCipher);
writeShortClose(
*socket_,
*conn_,
*destConnId /* dst */,
conn_->localConnectionError,
*conn_->oneRttWriteCipher,
*conn_->oneRttWriteHeaderCipher);
} else if (conn_->initialWriteCipher) {
CHECK(conn_->initialHeaderCipher); CHECK(conn_->initialHeaderCipher);
writeLongClose( writeLongClose(
*socket_, *socket_,
*conn_, *conn_,
srcConnId /* src */, srcConnId,
*destConnId /* dst */, *destConnId,
LongHeader::Types::Initial, LongHeader::Types::Initial,
conn_->localConnectionError, conn_->localConnectionError,
*conn_->initialWriteCipher, *conn_->initialWriteCipher,
*conn_->initialHeaderCipher, *conn_->initialHeaderCipher,
version); version);
} else if (conn_->handshakeWriteCipher) {
CHECK(conn_->handshakeWriteHeaderCipher);
writeLongClose(
*socket_,
*conn_,
srcConnId,
*destConnId,
LongHeader::Types::Handshake,
conn_->localConnectionError,
*conn_->handshakeWriteCipher,
*conn_->handshakeWriteHeaderCipher,
version);
} else if (conn_->oneRttWriteCipher) {
CHECK(conn_->oneRttWriteHeaderCipher);
writeShortClose(
*socket_,
*conn_,
*destConnId,
conn_->localConnectionError,
*conn_->oneRttWriteCipher,
*conn_->oneRttWriteHeaderCipher);
} }
return; return;
} }
@@ -736,15 +750,14 @@ void QuicClientTransport::writeData() {
(isConnectionPaced(*conn_) (isConnectionPaced(*conn_)
? conn_->pacer->updateAndGetWriteBatchSize(Clock::now()) ? conn_->pacer->updateAndGetWriteBatchSize(Clock::now())
: conn_->transportSettings.writeConnectionDataPacketsLimit); : conn_->transportSettings.writeConnectionDataPacketsLimit);
if (conn_->initialWriteCipher) {
CryptoStreamScheduler initialScheduler( CryptoStreamScheduler initialScheduler(
*conn_, *getCryptoStream(*conn_->cryptoState, EncryptionLevel::Initial));
CryptoStreamScheduler handshakeScheduler(
*conn_, *conn_,
*getCryptoStream(*conn_->cryptoState, EncryptionLevel::Handshake)); *getCryptoStream(*conn_->cryptoState, EncryptionLevel::Initial));
if (initialScheduler.hasData() || if (initialScheduler.hasData() ||
(conn_->ackStates.initialAckState.needsToSendAckImmediately && (conn_->ackStates.initialAckState.needsToSendAckImmediately &&
hasAcksToSchedule(conn_->ackStates.initialAckState))) { hasAcksToSchedule(conn_->ackStates.initialAckState))) {
CHECK(conn_->initialWriteCipher);
CHECK(conn_->initialHeaderCipher); CHECK(conn_->initialHeaderCipher);
packetLimit -= writeCryptoAndAckDataToSocket( packetLimit -= writeCryptoAndAckDataToSocket(
*socket_, *socket_,
@@ -761,10 +774,14 @@ void QuicClientTransport::writeData() {
if (!packetLimit) { if (!packetLimit) {
return; return;
} }
}
if (conn_->handshakeWriteCipher) {
CryptoStreamScheduler handshakeScheduler(
*conn_,
*getCryptoStream(*conn_->cryptoState, EncryptionLevel::Handshake));
if (handshakeScheduler.hasData() || if (handshakeScheduler.hasData() ||
(conn_->ackStates.handshakeAckState.needsToSendAckImmediately && (conn_->ackStates.handshakeAckState.needsToSendAckImmediately &&
hasAcksToSchedule(conn_->ackStates.handshakeAckState))) { hasAcksToSchedule(conn_->ackStates.handshakeAckState))) {
CHECK(conn_->handshakeWriteCipher);
CHECK(conn_->handshakeWriteHeaderCipher); CHECK(conn_->handshakeWriteHeaderCipher);
packetLimit -= writeCryptoAndAckDataToSocket( packetLimit -= writeCryptoAndAckDataToSocket(
*socket_, *socket_,
@@ -780,6 +797,7 @@ void QuicClientTransport::writeData() {
if (!packetLimit) { if (!packetLimit) {
return; return;
} }
}
if (clientConn_->zeroRttWriteCipher && !conn_->oneRttWriteCipher) { if (clientConn_->zeroRttWriteCipher && !conn_->oneRttWriteCipher) {
CHECK(clientConn_->zeroRttWriteHeaderCipher); CHECK(clientConn_->zeroRttWriteHeaderCipher);
packetLimit -= writeZeroRttDataToSocket( packetLimit -= writeZeroRttDataToSocket(

View File

@@ -120,15 +120,9 @@ ClientHandshake::getZeroRttWriteHeaderCipher() {
return std::move(zeroRttWriteHeaderCipher_); return std::move(zeroRttWriteHeaderCipher_);
} }
/** void ClientHandshake::handshakeConfirmed() {
* Notify the crypto layer that we received one rtt protected data.
* This allows us to know that the peer has implicitly acked the 1-rtt keys.
*/
void ClientHandshake::onRecvOneRttProtectedData() {
if (phase_ != Phase::Established) {
phase_ = Phase::Established; phase_ = Phase::Established;
} }
}
ClientHandshake::Phase ClientHandshake::getPhase() const { ClientHandshake::Phase ClientHandshake::getPhase() const {
return phase_; return phase_;

View File

@@ -122,10 +122,9 @@ class ClientHandshake : public Handshake {
virtual const CryptoFactory& getCryptoFactory() const = 0; virtual const CryptoFactory& getCryptoFactory() const = 0;
/** /**
* Notify the crypto layer that we received one rtt protected data. * Triggered when we have received a handshake done frame from the server.
* This allows us to know that the peer has implicitly acked the 1-rtt keys.
*/ */
void onRecvOneRttProtectedData(); void handshakeConfirmed() override;
Phase getPhase() const; Phase getPhase() const;

View File

@@ -294,7 +294,7 @@ TEST_F(ClientHandshakeTest, TestHandshakeSuccess) {
EXPECT_EQ(handshake->getPhase(), ClientHandshake::Phase::OneRttKeysDerived); EXPECT_EQ(handshake->getPhase(), ClientHandshake::Phase::OneRttKeysDerived);
handshake->onRecvOneRttProtectedData(); handshake->handshakeConfirmed();
EXPECT_EQ(handshake->getPhase(), ClientHandshake::Phase::Established); EXPECT_EQ(handshake->getPhase(), ClientHandshake::Phase::Established);
EXPECT_FALSE(zeroRttRejected.has_value()); EXPECT_FALSE(zeroRttRejected.has_value());
EXPECT_TRUE(handshakeSuccess); EXPECT_TRUE(handshakeSuccess);
@@ -497,7 +497,7 @@ TEST_F(ClientHandshakeZeroRttTest, TestZeroRttSuccess) {
EXPECT_FALSE(zeroRttRejected.has_value()); EXPECT_FALSE(zeroRttRejected.has_value());
expectZeroRttCipher(true, true); expectZeroRttCipher(true, true);
clientServerRound(); clientServerRound();
handshake->onRecvOneRttProtectedData(); handshake->handshakeConfirmed();
EXPECT_EQ(handshake->getPhase(), ClientHandshake::Phase::Established); EXPECT_EQ(handshake->getPhase(), ClientHandshake::Phase::Established);
EXPECT_EQ(handshake->getApplicationProtocol(), "h1q-fb"); EXPECT_EQ(handshake->getApplicationProtocol(), "h1q-fb");
} }
@@ -521,7 +521,7 @@ TEST_F(ClientHandshakeZeroRttReject, TestZeroRttRejection) {
// We will still keep the zero rtt key lying around. // We will still keep the zero rtt key lying around.
expectZeroRttCipher(true, true); expectZeroRttCipher(true, true);
clientServerRound(); clientServerRound();
handshake->onRecvOneRttProtectedData(); handshake->handshakeConfirmed();
EXPECT_EQ(handshake->getPhase(), ClientHandshake::Phase::Established); EXPECT_EQ(handshake->getPhase(), ClientHandshake::Phase::Established);
} }

View File

@@ -372,18 +372,14 @@ QuicClientTransportIntegrationTest::sendRequestAndResponse(
auto streamData = new StreamData(streamId); auto streamData = new StreamData(streamId);
auto dataCopy = std::shared_ptr<folly::IOBuf>(std::move(data)); auto dataCopy = std::shared_ptr<folly::IOBuf>(std::move(data));
EXPECT_CALL(*readCallback, readAvailable(streamId)) EXPECT_CALL(*readCallback, readAvailable(streamId))
.WillRepeatedly(Invoke([c = client.get(), .WillRepeatedly(
id = streamId, Invoke([c = client.get(), id = streamId, streamData, dataCopy](
streamData, auto) mutable {
dataCopy](auto) mutable {
EXPECT_EQ(
dynamic_cast<ClientHandshake*>(c->getConn().handshakeLayer.get())
->getPhase(),
ClientHandshake::Phase::Established);
auto readData = c->read(id, 1000); auto readData = c->read(id, 1000);
auto copy = readData->first->clone(); auto copy = readData->first->clone();
LOG(INFO) << "Client received data=" LOG(INFO) << "Client received data="
<< copy->moveToFbString().toStdString() << " on stream=" << id << copy->moveToFbString().toStdString()
<< " on stream=" << id
<< " read=" << readData->first->computeChainDataLength() << " read=" << readData->first->computeChainDataLength()
<< " sent=" << dataCopy->computeChainDataLength(); << " sent=" << dataCopy->computeChainDataLength();
streamData->append(std::move(readData->first), readData->second); streamData->append(std::move(readData->first), readData->second);
@@ -428,6 +424,16 @@ TEST_P(QuicClientTransportIntegrationTest, NetworkTest) {
auto expected = std::shared_ptr<IOBuf>(IOBuf::copyBuffer("echo ")); auto expected = std::shared_ptr<IOBuf>(IOBuf::copyBuffer("echo "));
expected->prependChain(data->clone()); expected->prependChain(data->clone());
sendRequestAndResponseAndWait(*expected, data->clone(), streamId, &readCb); sendRequestAndResponseAndWait(*expected, data->clone(), streamId, &readCb);
if (getVersion() == QuicVersion::QUIC_DRAFT) {
EXPECT_EQ(client->getConn().initialWriteCipher, nullptr);
EXPECT_EQ(client->getConn().initialHeaderCipher, nullptr);
EXPECT_EQ(client->getConn().handshakeWriteCipher, nullptr);
EXPECT_EQ(client->getConn().handshakeWriteHeaderCipher, nullptr);
EXPECT_EQ(client->getConn().readCodec->getInitialCipher(), nullptr);
EXPECT_EQ(client->getConn().readCodec->getInitialHeaderCipher(), nullptr);
EXPECT_EQ(client->getConn().readCodec->getHandshakeReadCipher(), nullptr);
EXPECT_EQ(client->getConn().readCodec->getHandshakeHeaderCipher(), nullptr);
}
} }
TEST_P(QuicClientTransportIntegrationTest, FlowControlLimitedTest) { TEST_P(QuicClientTransportIntegrationTest, FlowControlLimitedTest) {
@@ -1223,6 +1229,10 @@ class FakeOneRttHandshakeLayer : public ClientHandshake {
void doHandshake(std::unique_ptr<folly::IOBuf>, EncryptionLevel) override { void doHandshake(std::unique_ptr<folly::IOBuf>, EncryptionLevel) override {
EXPECT_EQ(writeBuf.get(), nullptr); EXPECT_EQ(writeBuf.get(), nullptr);
if (getPhase() == Phase::Initial) { if (getPhase() == Phase::Initial) {
handshakeWriteCipher_ = test::createNoOpAead();
handshakeWriteHeaderCipher_ = test::createNoOpHeaderCipher();
handshakeReadCipher_ = test::createNoOpAead();
handshakeReadHeaderCipher_ = test::createNoOpHeaderCipher();
writeDataToQuicStream( writeDataToQuicStream(
conn_->cryptoState->handshakeStream, conn_->cryptoState->handshakeStream,
IOBuf::copyBuffer("ClientFinished")); IOBuf::copyBuffer("ClientFinished"));
@@ -1371,17 +1381,13 @@ class QuicClientTransportTest : public Test {
virtual void setFakeHandshakeCiphers() { virtual void setFakeHandshakeCiphers() {
auto readAead = test::createNoOpAead(); auto readAead = test::createNoOpAead();
auto writeAead = test::createNoOpAead(); auto writeAead = test::createNoOpAead();
auto handshakeReadAead = test::createNoOpAead(); mockClientHandshake->setHandshakeReadCipher(nullptr);
auto handshakeWriteAead = test::createNoOpAead(); mockClientHandshake->setHandshakeWriteCipher(nullptr);
mockClientHandshake->setHandshakeReadCipher(std::move(handshakeReadAead));
mockClientHandshake->setHandshakeWriteCipher(std::move(handshakeWriteAead));
mockClientHandshake->setOneRttReadCipher(std::move(readAead)); mockClientHandshake->setOneRttReadCipher(std::move(readAead));
mockClientHandshake->setOneRttWriteCipher(std::move(writeAead)); mockClientHandshake->setOneRttWriteCipher(std::move(writeAead));
mockClientHandshake->setHandshakeReadHeaderCipher( mockClientHandshake->setHandshakeReadHeaderCipher(nullptr);
test::createNoOpHeaderCipher()); mockClientHandshake->setHandshakeWriteHeaderCipher(nullptr);
mockClientHandshake->setHandshakeWriteHeaderCipher(
test::createNoOpHeaderCipher());
mockClientHandshake->setOneRttWriteHeaderCipher( mockClientHandshake->setOneRttWriteHeaderCipher(
test::createNoOpHeaderCipher()); test::createNoOpHeaderCipher());
mockClientHandshake->setOneRttReadHeaderCipher( mockClientHandshake->setOneRttReadHeaderCipher(
@@ -1512,12 +1518,16 @@ class QuicClientTransportTest : public Test {
void verifyCiphers() { void verifyCiphers() {
EXPECT_NE(client->getConn().oneRttWriteCipher, nullptr); EXPECT_NE(client->getConn().oneRttWriteCipher, nullptr);
EXPECT_NE(client->getConn().oneRttWriteHeaderCipher, nullptr);
EXPECT_NE(client->getConn().handshakeWriteCipher, nullptr); EXPECT_NE(client->getConn().handshakeWriteCipher, nullptr);
EXPECT_NE(client->getConn().handshakeWriteHeaderCipher, nullptr); EXPECT_NE(client->getConn().handshakeWriteHeaderCipher, nullptr);
EXPECT_NE(client->getConn().oneRttWriteHeaderCipher, nullptr); EXPECT_EQ(client->getConn().initialWriteCipher, nullptr);
EXPECT_EQ(client->getConn().initialHeaderCipher, nullptr);
EXPECT_NE(client->getConn().readCodec->getHandshakeHeaderCipher(), nullptr); EXPECT_NE(client->getConn().readCodec->getOneRttReadCipher(), nullptr);
EXPECT_NE(client->getConn().readCodec->getOneRttHeaderCipher(), nullptr); EXPECT_NE(client->getConn().readCodec->getOneRttHeaderCipher(), nullptr);
EXPECT_NE(client->getConn().readCodec->getHandshakeReadCipher(), nullptr);
EXPECT_NE(client->getConn().readCodec->getHandshakeHeaderCipher(), nullptr);
} }
void deliverDataWithoutErrorCheck( void deliverDataWithoutErrorCheck(
@@ -1587,12 +1597,12 @@ class QuicClientTransportTest : public Test {
} }
} }
if (shortHeader) { if (shortHeader) {
EXPECT_GT(numShort, 0); ASSERT_GT(numShort, 0);
} }
if (longHeader) { if (longHeader) {
EXPECT_GT(numLong, 0); CHECK_GT(numLong, 0);
} }
EXPECT_EQ(numOthers, 0); ASSERT_EQ(numOthers, 0);
} }
RegularQuicPacket* parseRegularQuicPacket(CodecResult& codecResult) { RegularQuicPacket* parseRegularQuicPacket(CodecResult& codecResult) {
@@ -1787,6 +1797,8 @@ TEST_F(QuicClientTransportTest, AddNewPeerAddressSetsPacketSize) {
TEST_F(QuicClientTransportTest, onNetworkSwitchNoReplace) { TEST_F(QuicClientTransportTest, onNetworkSwitchNoReplace) {
client->getNonConstConn().oneRttWriteCipher = test::createNoOpAead(); client->getNonConstConn().oneRttWriteCipher = test::createNoOpAead();
client->getNonConstConn().oneRttWriteHeaderCipher =
test::createNoOpHeaderCipher();
auto mockQLogger = std::make_shared<MockQLogger>(VantagePoint::Client); auto mockQLogger = std::make_shared<MockQLogger>(VantagePoint::Client);
client->setQLogger(mockQLogger); client->setQLogger(mockQLogger);
@@ -1797,6 +1809,8 @@ TEST_F(QuicClientTransportTest, onNetworkSwitchNoReplace) {
TEST_F(QuicClientTransportTest, onNetworkSwitchReplaceAfterHandshake) { TEST_F(QuicClientTransportTest, onNetworkSwitchReplaceAfterHandshake) {
client->getNonConstConn().oneRttWriteCipher = test::createNoOpAead(); client->getNonConstConn().oneRttWriteCipher = test::createNoOpAead();
client->getNonConstConn().oneRttWriteHeaderCipher =
test::createNoOpHeaderCipher();
auto mockQLogger = std::make_shared<MockQLogger>(VantagePoint::Client); auto mockQLogger = std::make_shared<MockQLogger>(VantagePoint::Client);
client->setQLogger(mockQLogger); client->setQLogger(mockQLogger);
@@ -3126,6 +3140,7 @@ TEST_P(QuicClientTransportAfterStartTest, ReadStreamCoalesced) {
auto garbage = IOBuf::copyBuffer("garbage"); auto garbage = IOBuf::copyBuffer("garbage");
auto initialCipher = cryptoFactory.getServerInitialCipher( auto initialCipher = cryptoFactory.getServerInitialCipher(
*serverChosenConnId, QuicVersion::MVFST); *serverChosenConnId, QuicVersion::MVFST);
auto initialHeaderCipher = test::createNoOpHeaderCipher();
auto firstPacketNum = appDataPacketNum++; auto firstPacketNum = appDataPacketNum++;
auto packet1 = packetToBufCleartext( auto packet1 = packetToBufCleartext(
createStreamPacket( createStreamPacket(
@@ -3138,7 +3153,7 @@ TEST_P(QuicClientTransportAfterStartTest, ReadStreamCoalesced) {
0 /* largestAcked */, 0 /* largestAcked */,
std::make_pair(LongHeader::Types::Initial, QuicVersion::MVFST)), std::make_pair(LongHeader::Types::Initial, QuicVersion::MVFST)),
*initialCipher, *initialCipher,
getInitialHeaderCipher(), *initialHeaderCipher,
firstPacketNum); firstPacketNum);
packet1->coalesce(); packet1->coalesce();
auto packet2 = packetToBuf(createStreamPacket( auto packet2 = packetToBuf(createStreamPacket(
@@ -3177,6 +3192,7 @@ TEST_F(QuicClientTransportAfterStartTest, ReadStreamCoalescedMany) {
auto garbage = IOBuf::copyBuffer("garbage"); auto garbage = IOBuf::copyBuffer("garbage");
auto initialCipher = cryptoFactory.getServerInitialCipher( auto initialCipher = cryptoFactory.getServerInitialCipher(
*serverChosenConnId, QuicVersion::MVFST); *serverChosenConnId, QuicVersion::MVFST);
auto initialHeaderCipher = test::createNoOpHeaderCipher();
auto packetNum = appDataPacketNum++; auto packetNum = appDataPacketNum++;
auto packet1 = packetToBufCleartext( auto packet1 = packetToBufCleartext(
createStreamPacket( createStreamPacket(
@@ -3189,7 +3205,7 @@ TEST_F(QuicClientTransportAfterStartTest, ReadStreamCoalescedMany) {
0 /* largestAcked */, 0 /* largestAcked */,
std::make_pair(LongHeader::Types::Initial, QuicVersion::MVFST)), std::make_pair(LongHeader::Types::Initial, QuicVersion::MVFST)),
*initialCipher, *initialCipher,
getInitialHeaderCipher(), *initialHeaderCipher,
packetNum); packetNum);
packets.append(std::move(packet1)); packets.append(std::move(packet1));
} }
@@ -3261,6 +3277,32 @@ TEST_F(QuicClientTransportAfterStartTest, RecvPathChallengeAvailablePeerId) {
EXPECT_EQ(pathResponse.pathData, pathChallenge.pathData); EXPECT_EQ(pathResponse.pathData, pathChallenge.pathData);
} }
TEST_F(QuicClientTransportAfterStartTest, HandshakeDoneDrop) {
auto& conn = client->getNonConstConn();
conn.handshakeWriteCipher = test::createNoOpAead();
conn.handshakeWriteHeaderCipher = test::createNoOpHeaderCipher();
conn.readCodec->setHandshakeReadCipher(test::createNoOpAead());
conn.readCodec->setHandshakeHeaderCipher(test::createNoOpHeaderCipher());
conn.cryptoState->handshakeStream.writeBuffer.append(
folly::IOBuf::copyBuffer("blah"));
ShortHeader header(ProtectionType::KeyPhaseZero, *conn.clientConnectionId, 1);
RegularQuicPacketBuilder builder(
conn.udpSendPacketLen, std::move(header), 0 /* largestAcked */);
ASSERT_TRUE(builder.canBuildPacket());
writeSimpleFrame(QuicSimpleFrame(HandshakeDoneFrame()), builder);
auto packet = std::move(builder).buildPacket();
auto data = packetToBuf(packet);
deliverData(data->coalesce(), false);
EXPECT_EQ(conn.handshakeWriteCipher, nullptr);
EXPECT_EQ(conn.handshakeWriteHeaderCipher, nullptr);
EXPECT_EQ(conn.readCodec->getHandshakeReadCipher(), nullptr);
EXPECT_EQ(conn.readCodec->getHandshakeHeaderCipher(), nullptr);
EXPECT_EQ(conn.cryptoState->handshakeStream.writeBuffer.chainLength(), 0);
}
bool verifyFramePresent( bool verifyFramePresent(
std::vector<std::unique_ptr<folly::IOBuf>>& socketWrites, std::vector<std::unique_ptr<folly::IOBuf>>& socketWrites,
QuicReadCodec& readCodec, QuicReadCodec& readCodec,
@@ -3493,31 +3535,9 @@ TEST_F(QuicClientTransportAfterStartTest, RecvRetransmittedHandshakeData) {
TEST_F(QuicClientTransportAfterStartTest, RecvAckOfCryptoStream) { TEST_F(QuicClientTransportAfterStartTest, RecvAckOfCryptoStream) {
// Simulate ack from server // Simulate ack from server
auto& cryptoState = client->getConn().cryptoState; auto& cryptoState = client->getConn().cryptoState;
EXPECT_GT(cryptoState->initialStream.retransmissionBuffer.size(), 0);
EXPECT_GT(cryptoState->handshakeStream.retransmissionBuffer.size(), 0); EXPECT_GT(cryptoState->handshakeStream.retransmissionBuffer.size(), 0);
EXPECT_EQ(cryptoState->oneRttStream.retransmissionBuffer.size(), 0); EXPECT_EQ(cryptoState->oneRttStream.retransmissionBuffer.size(), 0);
auto& aead = getInitialCipher();
auto& headerCipher = getInitialHeaderCipher();
// initial
{
AckBlocks acks;
auto start = getFirstOutstandingPacket(
client->getNonConstConn(), PacketNumberSpace::Initial)
->packet.header.getPacketSequenceNum();
auto end = getLastOutstandingPacket(
client->getNonConstConn(), PacketNumberSpace::Initial)
->packet.header.getPacketSequenceNum();
acks.insert(start, end);
auto pn = initialPacketNum++;
auto ackPkt = createAckPacket(
client->getNonConstConn(), pn, acks, PacketNumberSpace::Initial, &aead);
deliverData(
packetToBufCleartext(ackPkt, aead, headerCipher, pn)->coalesce());
EXPECT_EQ(cryptoState->initialStream.retransmissionBuffer.size(), 0);
EXPECT_GT(cryptoState->handshakeStream.retransmissionBuffer.size(), 0);
EXPECT_EQ(cryptoState->oneRttStream.retransmissionBuffer.size(), 0);
}
// handshake // handshake
{ {
AckBlocks acks; AckBlocks acks;
@@ -3539,9 +3559,6 @@ TEST_F(QuicClientTransportAfterStartTest, RecvAckOfCryptoStream) {
} }
TEST_F(QuicClientTransportAfterStartTest, RecvOneRttAck) { TEST_F(QuicClientTransportAfterStartTest, RecvOneRttAck) {
EXPECT_GT(
client->getConn().cryptoState->initialStream.retransmissionBuffer.size(),
0);
EXPECT_GT( EXPECT_GT(
client->getConn() client->getConn()
.cryptoState->handshakeStream.retransmissionBuffer.size(), .cryptoState->handshakeStream.retransmissionBuffer.size(),
@@ -3570,9 +3587,6 @@ TEST_F(QuicClientTransportAfterStartTest, RecvOneRttAck) {
deliverData(ackPacket->coalesce()); deliverData(ackPacket->coalesce());
// Should have canceled retransmissions // Should have canceled retransmissions
EXPECT_EQ(
client->getConn().cryptoState->initialStream.retransmissionBuffer.size(),
0);
EXPECT_EQ( EXPECT_EQ(
client->getConn() client->getConn()
.cryptoState->handshakeStream.retransmissionBuffer.size(), .cryptoState->handshakeStream.retransmissionBuffer.size(),
@@ -3604,39 +3618,17 @@ TEST_P(QuicClientTransportAfterStartTestClose, CloseConnectionWithError) {
std::string("stopping"))); std::string("stopping")));
EXPECT_TRUE(verifyFramePresent( EXPECT_TRUE(verifyFramePresent(
socketWrites, socketWrites,
*makeEncryptedCodec(), *makeHandshakeCodec(),
QuicFrame::Type::ConnectionCloseFrame_E)); QuicFrame::Type::ConnectionCloseFrame_E));
} else { } else {
client->close(folly::none); client->close(folly::none);
EXPECT_TRUE(verifyFramePresent( EXPECT_TRUE(verifyFramePresent(
socketWrites, socketWrites,
*makeEncryptedCodec(), *makeHandshakeCodec(),
QuicFrame::Type::ConnectionCloseFrame_E)); QuicFrame::Type::ConnectionCloseFrame_E));
} }
} }
TEST_F(
QuicClientTransportAfterStartTest,
HandshakeCipherTimeoutAfterFirstData) {
StreamId streamId = client->createBidirectionalStream().value();
EXPECT_NE(client->getConn().readCodec->getInitialCipher(), nullptr);
auto expected = IOBuf::copyBuffer("hello");
auto packet = packetToBuf(createStreamPacket(
*serverChosenConnId /* src */,
*originalConnId /* dest */,
appDataPacketNum++,
streamId,
*expected,
0 /* cipherOverhead */,
0 /* largestAcked */,
folly::none,
true));
deliverData(packet->coalesce());
EXPECT_NE(client->getConn().readCodec->getInitialCipher(), nullptr);
EXPECT_TRUE(client->getConn().readCodec->getHandshakeDoneTime().has_value());
}
TEST_F(QuicClientTransportAfterStartTest, IdleTimerResetOnRecvNewData) { TEST_F(QuicClientTransportAfterStartTest, IdleTimerResetOnRecvNewData) {
// spend some time looping the evb // spend some time looping the evb
for (int i = 0; i < 10; ++i) { for (int i = 0; i < 10; ++i) {
@@ -3860,6 +3852,7 @@ TEST_F(QuicClientTransportAfterStartTest, WrongCleartextCipher) {
auto initialCipher = cryptoFactory.getServerInitialCipher( auto initialCipher = cryptoFactory.getServerInitialCipher(
*serverChosenConnId, QuicVersion::MVFST); *serverChosenConnId, QuicVersion::MVFST);
auto initialHeaderCipher = test::createNoOpHeaderCipher();
auto packet = packetToBufCleartext( auto packet = packetToBufCleartext(
createStreamPacket( createStreamPacket(
*serverChosenConnId /* src */, *serverChosenConnId /* src */,
@@ -3871,7 +3864,7 @@ TEST_F(QuicClientTransportAfterStartTest, WrongCleartextCipher) {
0 /* largestAcked */, 0 /* largestAcked */,
std::make_pair(LongHeader::Types::Initial, QuicVersion::MVFST)), std::make_pair(LongHeader::Types::Initial, QuicVersion::MVFST)),
*initialCipher, *initialCipher,
getInitialHeaderCipher(), *initialHeaderCipher,
nextPacketNum); nextPacketNum);
deliverData(packet->coalesce()); deliverData(packet->coalesce());
} }

View File

@@ -678,6 +678,10 @@ ExpiredStreamDataFrame decodeExpiredStreamDataFrame(folly::io::Cursor& cursor) {
folly::to<StreamId>(streamId->first), minimumStreamOffset->first); folly::to<StreamId>(streamId->first), minimumStreamOffset->first);
} }
HandshakeDoneFrame decodeHandshakeDoneFrame(folly::io::Cursor& /*cursor*/) {
return HandshakeDoneFrame();
}
QuicFrame parseFrame( QuicFrame parseFrame(
BufQueue& queue, BufQueue& queue,
const PacketHeader& header, const PacketHeader& header,
@@ -766,6 +770,8 @@ QuicFrame parseFrame(
return QuicFrame(decodeMinStreamDataFrame(cursor)); return QuicFrame(decodeMinStreamDataFrame(cursor));
case FrameType::EXPIRED_STREAM_DATA: case FrameType::EXPIRED_STREAM_DATA:
return QuicFrame(decodeExpiredStreamDataFrame(cursor)); return QuicFrame(decodeExpiredStreamDataFrame(cursor));
case FrameType::HANDSHAKE_DONE:
return QuicFrame(decodeHandshakeDoneFrame(cursor));
} }
} catch (const std::exception&) { } catch (const std::exception&) {
error = true; error = true;

View File

@@ -139,6 +139,8 @@ ReadCryptoFrame decodeCryptoFrame(folly::io::Cursor& cursor);
ReadNewTokenFrame decodeNewTokenFrame(folly::io::Cursor& cursor); ReadNewTokenFrame decodeNewTokenFrame(folly::io::Cursor& cursor);
HandshakeDoneFrame decodeHandshakeDoneFrame(folly::io::Cursor& cursor);
/** /**
* Parse the Invariant fields in Long Header. * Parse the Invariant fields in Long Header.
* *

View File

@@ -125,15 +125,12 @@ CodecResult QuicReadCodec::parseLongHeaderPacket(
auto protectionType = longHeader.getProtectionType(); auto protectionType = longHeader.getProtectionType();
switch (protectionType) { switch (protectionType) {
case ProtectionType::Initial: case ProtectionType::Initial:
if (handshakeDoneTime_) { if (!initialHeaderCipher_) {
auto timeBetween = Clock::now() - *handshakeDoneTime_;
if (timeBetween > kTimeToRetainZeroRttKeys) {
VLOG(4) << nodeToString(nodeType_) VLOG(4) << nodeToString(nodeType_)
<< " dropping initial packet for exceeding key timeout" << " dropping initial packet after initial keys dropped"
<< connIdToHex(); << connIdToHex();
return CodecResult(Nothing()); return CodecResult(Nothing());
} }
}
headerCipher = initialHeaderCipher_.get(); headerCipher = initialHeaderCipher_.get();
cipher = initialReadCipher_.get(); cipher = initialReadCipher_.get();
break; break;
@@ -143,6 +140,7 @@ CodecResult QuicReadCodec::parseLongHeaderPacket(
break; break;
case ProtectionType::ZeroRtt: case ProtectionType::ZeroRtt:
if (handshakeDoneTime_) { if (handshakeDoneTime_) {
// TODO actually drop the 0-rtt keys in addition to dropping packets.
auto timeBetween = Clock::now() - *handshakeDoneTime_; auto timeBetween = Clock::now() - *handshakeDoneTime_;
if (timeBetween > kTimeToRetainZeroRttKeys) { if (timeBetween > kTimeToRetainZeroRttKeys) {
VLOG(4) << nodeToString(nodeType_) VLOG(4) << nodeToString(nodeType_)

View File

@@ -494,6 +494,18 @@ size_t writeSimpleFrame(
// no space left in packet // no space left in packet
return size_t(0); return size_t(0);
} }
case QuicSimpleFrame::Type::HandshakeDoneFrame_E: {
const HandshakeDoneFrame& handshakeDoneFrame =
*frame.asHandshakeDoneFrame();
QuicInteger intFrameType(static_cast<uint8_t>(FrameType::HANDSHAKE_DONE));
if (packetSpaceCheck(spaceLeft, intFrameType.getSize())) {
builder.write(intFrameType);
builder.appendFrame(QuicSimpleFrame(handshakeDoneFrame));
return intFrameType.getSize();
}
// no space left in packet
return size_t(0);
}
} }
folly::assume_unreachable(); folly::assume_unreachable();
} }

View File

@@ -396,6 +396,8 @@ std::string toString(FrameType frame) {
return "MIN_STREAM_DATA"; return "MIN_STREAM_DATA";
case FrameType::EXPIRED_STREAM_DATA: case FrameType::EXPIRED_STREAM_DATA:
return "EXPIRED_STREAM_DATA"; return "EXPIRED_STREAM_DATA";
case FrameType::HANDSHAKE_DONE:
return "HANDSHAKE_DONE";
} }
LOG(WARNING) << "toString has unhandled frame type"; LOG(WARNING) << "toString has unhandled frame type";
return "UNKNOWN"; return "UNKNOWN";

View File

@@ -549,6 +549,12 @@ struct ConnectionCloseFrame {
} }
}; };
struct HandshakeDoneFrame {
bool operator==(const HandshakeDoneFrame& /*rhs*/) const {
return true;
}
};
// Frame to represent ones we skip // Frame to represent ones we skip
struct NoopFrame { struct NoopFrame {
bool operator==(const NoopFrame&) const { bool operator==(const NoopFrame&) const {
@@ -572,7 +578,8 @@ struct StatelessReset {
F(NewConnectionIdFrame, __VA_ARGS__) \ F(NewConnectionIdFrame, __VA_ARGS__) \
F(MaxStreamsFrame, __VA_ARGS__) \ F(MaxStreamsFrame, __VA_ARGS__) \
F(RetireConnectionIdFrame, __VA_ARGS__) \ F(RetireConnectionIdFrame, __VA_ARGS__) \
F(PingFrame, __VA_ARGS__) F(PingFrame, __VA_ARGS__) \
F(HandshakeDoneFrame, __VA_ARGS__)
DECLARE_VARIANT_TYPE(QuicSimpleFrame, QUIC_SIMPLE_FRAME) DECLARE_VARIANT_TYPE(QuicSimpleFrame, QUIC_SIMPLE_FRAME)

View File

@@ -534,7 +534,7 @@ TEST_F(QuicReadCodecTest, TestHandshakeDone) {
auto packetQueue = auto packetQueue =
bufToQueue(packetToBufCleartext(packet, *aead, *headerCipher, packetNum)); bufToQueue(packetToBufCleartext(packet, *aead, *headerCipher, packetNum));
EXPECT_TRUE(parseSuccess(codec->parsePacket(packetQueue, ackStates))); EXPECT_TRUE(parseSuccess(codec->parsePacket(packetQueue, ackStates)));
codec->onHandshakeDone(Clock::now() - kTimeToRetainInitialKeys * 2); codec->onHandshakeDone(Clock::now());
EXPECT_FALSE(parseSuccess(codec->parsePacket(packetQueue, ackStates))); EXPECT_FALSE(parseSuccess(codec->parsePacket(packetQueue, ackStates)));
} }

View File

@@ -24,6 +24,10 @@ class Handshake {
virtual const folly::Optional<std::string>& getApplicationProtocol() virtual const folly::Optional<std::string>& getApplicationProtocol()
const = 0; const = 0;
virtual void handshakeConfirmed() {
LOG(FATAL) << "Not implemented";
}
}; };
constexpr folly::StringPiece kQuicDraft17Salt = constexpr folly::StringPiece kQuicDraft17Salt =

View File

@@ -65,6 +65,10 @@ void addQuicSimpleFrameToEvent(
frame.sequenceNumber)); frame.sequenceNumber));
break; break;
} }
case quic::QuicSimpleFrame::Type::HandshakeDoneFrame_E: {
event->frames.push_back(std::make_unique<quic::HandshakeDoneFrameLog>());
break;
}
} }
} }
} // namespace } // namespace

View File

@@ -219,6 +219,12 @@ folly::dynamic ReadNewTokenFrameLog::toDynamic() const {
return d; return d;
} }
folly::dynamic HandshakeDoneFrameLog::toDynamic() const {
folly::dynamic d = folly::dynamic::object();
d["frame_type"] = toString(FrameType::HANDSHAKE_DONE);
return d;
}
folly::dynamic VersionNegotiationLog::toDynamic() const { folly::dynamic VersionNegotiationLog::toDynamic() const {
folly::dynamic d = folly::dynamic::object(); folly::dynamic d = folly::dynamic::object();
d = folly::dynamic::array(); d = folly::dynamic::array();

View File

@@ -267,8 +267,7 @@ class RetireConnectionIdFrameLog : public QLogFrame {
public: public:
uint64_t sequence; uint64_t sequence;
RetireConnectionIdFrameLog(uint64_t sequenceIn) RetireConnectionIdFrameLog(uint64_t sequenceIn) : sequence(sequenceIn) {}
: sequence(sequenceIn) {}
~RetireConnectionIdFrameLog() override = default; ~RetireConnectionIdFrameLog() override = default;
folly::dynamic toDynamic() const override; folly::dynamic toDynamic() const override;
@@ -281,6 +280,13 @@ class ReadNewTokenFrameLog : public QLogFrame {
folly::dynamic toDynamic() const override; folly::dynamic toDynamic() const override;
}; };
class HandshakeDoneFrameLog : public QLogFrame {
public:
HandshakeDoneFrameLog() = default;
~HandshakeDoneFrameLog() override = default;
folly::dynamic toDynamic() const override;
};
class VersionNegotiationLog { class VersionNegotiationLog {
public: public:
std::vector<QuicVersion> versions; std::vector<QuicVersion> versions;

View File

@@ -542,7 +542,7 @@ TEST_F(QuicLossFunctionsTest, TestMarkCryptoLostAfterCancelRetransmission) {
EXPECT_GT(conn->cryptoState->handshakeStream.retransmissionBuffer.size(), 0); EXPECT_GT(conn->cryptoState->handshakeStream.retransmissionBuffer.size(), 0);
auto& packet = conn->outstandingPackets.front().packet; auto& packet = conn->outstandingPackets.front().packet;
auto packetNum = packet.header.getPacketSequenceNum(); auto packetNum = packet.header.getPacketSequenceNum();
cancelHandshakeCryptoStreamRetransmissions(*conn->cryptoState); cancelCryptoStream(conn->cryptoState->handshakeStream);
markPacketLoss(*conn, packet, false, packetNum); markPacketLoss(*conn, packet, false, packetNum);
EXPECT_EQ(conn->cryptoState->handshakeStream.retransmissionBuffer.size(), 0); EXPECT_EQ(conn->cryptoState->handshakeStream.retransmissionBuffer.size(), 0);
EXPECT_EQ(conn->cryptoState->handshakeStream.lossBuffer.size(), 0); EXPECT_EQ(conn->cryptoState->handshakeStream.lossBuffer.size(), 0);
@@ -579,7 +579,7 @@ TEST_F(QuicLossFunctionsTest, TestMarkCryptoLostCancel) {
markPacketLoss(*conn, packet, false, packetNum); markPacketLoss(*conn, packet, false, packetNum);
EXPECT_EQ(conn->cryptoState->handshakeStream.retransmissionBuffer.size(), 0); EXPECT_EQ(conn->cryptoState->handshakeStream.retransmissionBuffer.size(), 0);
EXPECT_EQ(conn->cryptoState->handshakeStream.lossBuffer.size(), 1); EXPECT_EQ(conn->cryptoState->handshakeStream.lossBuffer.size(), 1);
cancelHandshakeCryptoStreamRetransmissions(*conn->cryptoState); cancelCryptoStream(conn->cryptoState->handshakeStream);
EXPECT_EQ(conn->cryptoState->handshakeStream.retransmissionBuffer.size(), 0); EXPECT_EQ(conn->cryptoState->handshakeStream.retransmissionBuffer.size(), 0);
EXPECT_EQ(conn->cryptoState->handshakeStream.lossBuffer.size(), 0); EXPECT_EQ(conn->cryptoState->handshakeStream.lossBuffer.size(), 0);
} }

View File

@@ -157,51 +157,50 @@ void QuicServerTransport::writeData() {
return; return;
} }
updateLargestReceivedPacketsAtLastCloseSent(*conn_); updateLargestReceivedPacketsAtLastCloseSent(*conn_);
if (conn_->oneRttWriteCipher && conn_->readCodec->getOneRttReadCipher()) { if (conn_->initialWriteCipher) {
CHECK(conn_->oneRttWriteHeaderCipher);
// We do not process handshake data after we are closed. It is
// possible that we closed the transport while handshake data was
// pending in which case we would not derive the 1-RTT keys. We
// shouldn't send a long header at this point, because the client may
// have already dropped its handshake keys.
writeShortClose(
*socket_,
*conn_,
destConnId /* dst */,
conn_->localConnectionError,
*conn_->oneRttWriteCipher,
*conn_->oneRttWriteHeaderCipher);
} else if (conn_->initialWriteCipher) {
CHECK(conn_->initialHeaderCipher); CHECK(conn_->initialHeaderCipher);
writeLongClose( writeLongClose(
*socket_, *socket_,
*conn_, *conn_,
srcConnId /* src */, srcConnId,
destConnId /* dst */, destConnId,
LongHeader::Types::Initial, LongHeader::Types::Initial,
conn_->localConnectionError, conn_->localConnectionError,
*conn_->initialWriteCipher, *conn_->initialWriteCipher,
*conn_->initialHeaderCipher, *conn_->initialHeaderCipher,
version); version);
} else if (conn_->handshakeWriteCipher) {
CHECK(conn_->handshakeWriteHeaderCipher);
writeLongClose(
*socket_,
*conn_,
srcConnId,
destConnId,
LongHeader::Types::Initial,
conn_->localConnectionError,
*conn_->handshakeWriteCipher,
*conn_->handshakeWriteHeaderCipher,
version);
} else if (conn_->oneRttWriteCipher) {
CHECK(conn_->oneRttWriteHeaderCipher);
writeShortClose(
*socket_,
*conn_,
destConnId,
conn_->localConnectionError,
*conn_->oneRttWriteCipher,
*conn_->oneRttWriteHeaderCipher);
} }
return; return;
} }
if (!conn_->initialWriteCipher) {
// This would be possible if we read a packet from the network which
// could not be parsed later.
return;
}
uint64_t packetLimit = uint64_t packetLimit =
(isConnectionPaced(*conn_) (isConnectionPaced(*conn_)
? conn_->pacer->updateAndGetWriteBatchSize(Clock::now()) ? conn_->pacer->updateAndGetWriteBatchSize(Clock::now())
: conn_->transportSettings.writeConnectionDataPacketsLimit); : conn_->transportSettings.writeConnectionDataPacketsLimit);
if (conn_->initialWriteCipher) {
CryptoStreamScheduler initialScheduler( CryptoStreamScheduler initialScheduler(
*conn_, *getCryptoStream(*conn_->cryptoState, EncryptionLevel::Initial));
CryptoStreamScheduler handshakeScheduler(
*conn_, *conn_,
*getCryptoStream(*conn_->cryptoState, EncryptionLevel::Handshake)); *getCryptoStream(*conn_->cryptoState, EncryptionLevel::Initial));
if (initialScheduler.hasData() || if (initialScheduler.hasData() ||
(conn_->ackStates.initialAckState.needsToSendAckImmediately && (conn_->ackStates.initialAckState.needsToSendAckImmediately &&
hasAcksToSchedule(conn_->ackStates.initialAckState))) { hasAcksToSchedule(conn_->ackStates.initialAckState))) {
@@ -221,6 +220,11 @@ void QuicServerTransport::writeData() {
if (!packetLimit) { if (!packetLimit) {
return; return;
} }
}
if (conn_->handshakeWriteCipher) {
CryptoStreamScheduler handshakeScheduler(
*conn_,
*getCryptoStream(*conn_->cryptoState, EncryptionLevel::Handshake));
if (handshakeScheduler.hasData() || if (handshakeScheduler.hasData() ||
(conn_->ackStates.handshakeAckState.needsToSendAckImmediately && (conn_->ackStates.handshakeAckState.needsToSendAckImmediately &&
hasAcksToSchedule(conn_->ackStates.handshakeAckState))) { hasAcksToSchedule(conn_->ackStates.handshakeAckState))) {
@@ -240,6 +244,7 @@ void QuicServerTransport::writeData() {
if (!packetLimit) { if (!packetLimit) {
return; return;
} }
}
if (conn_->oneRttWriteCipher) { if (conn_->oneRttWriteCipher) {
CHECK(conn_->oneRttWriteHeaderCipher); CHECK(conn_->oneRttWriteHeaderCipher);
writeQuicDataToSocket( writeQuicDataToSocket(

View File

@@ -242,25 +242,28 @@ void updateHandshakeState(QuicServerConnectionState& conn) {
} }
auto handshakeWriteCipher = handshakeLayer->getHandshakeWriteCipher(); auto handshakeWriteCipher = handshakeLayer->getHandshakeWriteCipher();
auto handshakeReadCipher = handshakeLayer->getHandshakeReadCipher(); auto handshakeReadCipher = handshakeLayer->getHandshakeReadCipher();
if (handshakeWriteCipher) {
conn.handshakeWriteCipher = std::move(handshakeWriteCipher);
}
if (handshakeReadCipher) {
conn.readCodec->setHandshakeReadCipher(std::move(handshakeReadCipher));
}
auto handshakeWriteHeaderCipher = auto handshakeWriteHeaderCipher =
handshakeLayer->getHandshakeWriteHeaderCipher(); handshakeLayer->getHandshakeWriteHeaderCipher();
auto handshakeReadHeaderCipher = auto handshakeReadHeaderCipher =
handshakeLayer->getHandshakeReadHeaderCipher(); handshakeLayer->getHandshakeReadHeaderCipher();
if (handshakeWriteHeaderCipher) { if (handshakeWriteCipher) {
CHECK(
handshakeReadCipher && handshakeWriteHeaderCipher &&
handshakeReadHeaderCipher);
conn.handshakeWriteCipher = std::move(handshakeWriteCipher);
conn.handshakeWriteHeaderCipher = std::move(handshakeWriteHeaderCipher); conn.handshakeWriteHeaderCipher = std::move(handshakeWriteHeaderCipher);
} conn.readCodec->setHandshakeReadCipher(std::move(handshakeReadCipher));
if (handshakeReadHeaderCipher) {
conn.readCodec->setHandshakeHeaderCipher( conn.readCodec->setHandshakeHeaderCipher(
std::move(handshakeReadHeaderCipher)); std::move(handshakeReadHeaderCipher));
} }
if (handshakeLayer->isHandshakeDone()) { if (handshakeLayer->isHandshakeDone()) {
conn.readCodec->onHandshakeDone(Clock::now()); CHECK(conn.oneRttWriteCipher);
if (conn.handshakeWriteCipher) {
handshakeConfirmed(conn);
if (conn.version == QuicVersion::QUIC_DRAFT) {
sendSimpleFrame(conn, HandshakeDoneFrame());
}
}
} }
} }
@@ -969,6 +972,14 @@ void onServerReadDataFromOpen(
} }
} }
} }
// If we've processed a handshake packet, we can dicard the initial cipher.
if (encryptionLevel == EncryptionLevel::Handshake) {
conn.initialWriteCipher.reset();
conn.initialHeaderCipher.reset();
conn.readCodec->setInitialReadCipher(nullptr);
conn.readCodec->setInitialHeaderCipher(nullptr);
cancelCryptoStream(conn.cryptoState->initialStream);
}
// Update writable limit before processing the handshake data. This is so // Update writable limit before processing the handshake data. This is so
// that if we haven't decided whether or not to validate the peer, we won't // that if we haven't decided whether or not to validate the peer, we won't

View File

@@ -541,36 +541,13 @@ class QuicServerTransportTest : public Test {
EXPECT_TRUE(getCryptoStream( EXPECT_TRUE(getCryptoStream(
*server->getConn().cryptoState, EncryptionLevel::Initial) *server->getConn().cryptoState, EncryptionLevel::Initial)
->readBuffer.empty()); ->readBuffer.empty());
EXPECT_NE(server->getConn().initialWriteCipher, nullptr);
EXPECT_FALSE(server->getConn().localConnectionError.has_value()); EXPECT_FALSE(server->getConn().localConnectionError.has_value());
verifyTransportParameters(kDefaultIdleTimeout); verifyTransportParameters(kDefaultIdleTimeout);
serverWrites.clear(); serverWrites.clear();
// Simulate ack from client
auto& cryptoState = server->getConn().cryptoState; auto& cryptoState = server->getConn().cryptoState;
EXPECT_GT(cryptoState->initialStream.retransmissionBuffer.size(), 0);
EXPECT_EQ(cryptoState->handshakeStream.retransmissionBuffer.size(), 0); EXPECT_EQ(cryptoState->handshakeStream.retransmissionBuffer.size(), 0);
EXPECT_EQ(cryptoState->oneRttStream.retransmissionBuffer.size(), 0); EXPECT_EQ(cryptoState->oneRttStream.retransmissionBuffer.size(), 0);
auto aead = getInitialCipher();
auto headerCipher = getInitialHeaderCipher();
AckBlocks acks;
auto start = getFirstOutstandingPacket(
server->getNonConstConn(), PacketNumberSpace::Initial)
->packet.header.getPacketSequenceNum();
auto end = getLastOutstandingPacket(
server->getNonConstConn(), PacketNumberSpace::Initial)
->packet.header.getPacketSequenceNum();
acks.insert(start, end);
auto pn = clientNextInitialPacketNum++;
auto ackPkt = createAckPacket(
server->getNonConstConn(),
pn,
acks,
PacketNumberSpace::Initial,
aead.get());
deliverData(packetToBufCleartext(ackPkt, *aead, *headerCipher, pn));
EXPECT_EQ(cryptoState->initialStream.retransmissionBuffer.size(), 0);
} }
void verifyTransportParameters(std::chrono::milliseconds idleTimeout) { void verifyTransportParameters(std::chrono::milliseconds idleTimeout) {
@@ -829,6 +806,7 @@ TEST_F(QuicServerTransportTest, IdleTimerNotResetOnDuplicatePacket) {
TEST_F(QuicServerTransportTest, IdleTimerNotResetWhenDataOutstanding) { TEST_F(QuicServerTransportTest, IdleTimerNotResetWhenDataOutstanding) {
// Clear the receivedNewPacketBeforeWrite flag, since we may reveice from // Clear the receivedNewPacketBeforeWrite flag, since we may reveice from
// client during the SetUp of the test case. // client during the SetUp of the test case.
server->getNonConstConn().outstandingPackets.clear();
server->getNonConstConn().receivedNewPacketBeforeWrite = false; server->getNonConstConn().receivedNewPacketBeforeWrite = false;
StreamId streamId = server->createBidirectionalStream().value(); StreamId streamId = server->createBidirectionalStream().value();
@@ -841,18 +819,18 @@ TEST_F(QuicServerTransportTest, IdleTimerNotResetWhenDataOutstanding) {
false); false);
loopForWrites(); loopForWrites();
// It was the first packet // It was the first packet
ASSERT_TRUE(server->idleTimeout().isScheduled()); EXPECT_TRUE(server->idleTimeout().isScheduled());
// cancel it and write something else. This time idle timer shouldn't set. // cancel it and write something else. This time idle timer shouldn't set.
server->idleTimeout().cancelTimeout(); server->idleTimeout().cancelTimeout();
ASSERT_FALSE(server->idleTimeout().isScheduled()); EXPECT_FALSE(server->idleTimeout().isScheduled());
server->writeChain( server->writeChain(
streamId, streamId,
IOBuf::copyBuffer("And if the daylight feels like it's a long way off"), IOBuf::copyBuffer("And if the daylight feels like it's a long way off"),
false, false,
false); false);
loopForWrites(); loopForWrites();
ASSERT_FALSE(server->idleTimeout().isScheduled()); EXPECT_FALSE(server->idleTimeout().isScheduled());
} }
TEST_F(QuicServerTransportTest, TimeoutsNotSetAfterClose) { TEST_F(QuicServerTransportTest, TimeoutsNotSetAfterClose) {
@@ -2991,6 +2969,22 @@ TEST_F(
TEST_F(QuicServerTransportTest, ClientPortChangeNATRebinding) { TEST_F(QuicServerTransportTest, ClientPortChangeNATRebinding) {
server->getNonConstConn().transportSettings.disableMigration = false; server->getNonConstConn().transportSettings.disableMigration = false;
StreamId streamId = server->createBidirectionalStream().value();
auto data1 = IOBuf::copyBuffer("Aloha");
server->writeChain(streamId, data1->clone(), false, false);
loopForWrites();
PacketNum packetNum1 =
getFirstOutstandingPacket(
server->getNonConstConn(), PacketNumberSpace::AppData)
->packet.header.getPacketSequenceNum();
AckBlocks acks = {{packetNum1, packetNum1}};
auto packet1 = createAckPacket(
server->getNonConstConn(),
++clientNextAppDataPacketNum,
acks,
PacketNumberSpace::AppData);
deliverData(packetToBuf(packet1));
auto data = IOBuf::copyBuffer("bad data"); auto data = IOBuf::copyBuffer("bad data");
auto packetData = packetToBuf(createStreamPacket( auto packetData = packetToBuf(createStreamPacket(
*clientConnectionId, *clientConnectionId,
@@ -3022,6 +3016,21 @@ TEST_F(QuicServerTransportTest, ClientPortChangeNATRebinding) {
TEST_F(QuicServerTransportTest, ClientAddressChangeNATRebinding) { TEST_F(QuicServerTransportTest, ClientAddressChangeNATRebinding) {
server->getNonConstConn().transportSettings.disableMigration = false; server->getNonConstConn().transportSettings.disableMigration = false;
StreamId streamId = server->createBidirectionalStream().value();
auto data1 = IOBuf::copyBuffer("Aloha");
server->writeChain(streamId, data1->clone(), false, false);
loopForWrites();
PacketNum packetNum1 =
getFirstOutstandingPacket(
server->getNonConstConn(), PacketNumberSpace::AppData)
->packet.header.getPacketSequenceNum();
AckBlocks acks = {{packetNum1, packetNum1}};
auto packet1 = createAckPacket(
server->getNonConstConn(),
++clientNextAppDataPacketNum,
acks,
PacketNumberSpace::AppData);
deliverData(packetToBuf(packet1));
auto data = IOBuf::copyBuffer("bad data"); auto data = IOBuf::copyBuffer("bad data");
auto packetData = packetToBuf(createStreamPacket( auto packetData = packetToBuf(createStreamPacket(

View File

@@ -7,6 +7,7 @@
*/ */
#include <quic/state/QuicStateFunctions.h> #include <quic/state/QuicStateFunctions.h>
#include <quic/state/QuicStreamFunctions.h>
#include <quic/common/TimeUtil.h> #include <quic/common/TimeUtil.h>
#include <quic/logging/QuicLogger.h> #include <quic/logging/QuicLogger.h>
@@ -311,4 +312,16 @@ std::pair<folly::Optional<TimePoint>, PacketNumberSpace> earliestTimeAndSpace(
return res; return res;
} }
void handshakeConfirmed(QuicConnectionStateBase& conn) {
if (conn.nodeType == QuicNodeType::Client) {
conn.handshakeLayer->handshakeConfirmed();
}
conn.readCodec->onHandshakeDone(Clock::now());
conn.handshakeWriteCipher.reset();
conn.handshakeWriteHeaderCipher.reset();
conn.readCodec->setHandshakeReadCipher(nullptr);
conn.readCodec->setHandshakeHeaderCipher(nullptr);
cancelCryptoStream(conn.cryptoState->handshakeStream);
}
} // namespace quic } // namespace quic

View File

@@ -113,4 +113,7 @@ std::pair<folly::Optional<TimePoint>, PacketNumberSpace> earliestLossTimer(
std::pair<folly::Optional<TimePoint>, PacketNumberSpace> earliestTimeAndSpace( std::pair<folly::Optional<TimePoint>, PacketNumberSpace> earliestTimeAndSpace(
const EnumArray<PacketNumberSpace, folly::Optional<TimePoint>>& times, const EnumArray<PacketNumberSpace, folly::Optional<TimePoint>>& times,
bool considerAppData) noexcept; bool considerAppData) noexcept;
void handshakeConfirmed(QuicConnectionStateBase& conn);
} // namespace quic } // namespace quic

View File

@@ -398,14 +398,10 @@ uint64_t getStreamNextOffsetToDeliver(const QuicStreamState& stream) {
return minOffsetToDeliver; return minOffsetToDeliver;
} }
void cancelHandshakeCryptoStreamRetransmissions(QuicCryptoState& cryptoState) { void cancelCryptoStream(QuicCryptoStream& cryptoStream) {
// Cancel any retransmissions we might want to do for the crypto stream. cryptoStream.retransmissionBuffer.clear();
// This does not include data that is already deemed as lost, or data that cryptoStream.lossBuffer.clear();
// is pending in the write buffer. cryptoStream.writeBuffer.move();
cryptoState.initialStream.retransmissionBuffer.clear();
cryptoState.initialStream.lossBuffer.clear();
cryptoState.handshakeStream.retransmissionBuffer.clear();
cryptoState.handshakeStream.lossBuffer.clear();
} }
QuicCryptoStream* getCryptoStream( QuicCryptoStream* getCryptoStream(

View File

@@ -118,11 +118,9 @@ std::pair<Buf, bool> readDataInOrderFromReadBuffer(
bool sinkData = false); bool sinkData = false);
/** /**
* Cancel the retransmissions of the crypto stream data. * Cancel retransmissions and writes for a crypto stream.
* TODO: remove this when we can deal with cleartext data after handshake done
* correctly.
*/ */
void cancelHandshakeCryptoStreamRetransmissions(QuicCryptoState& cryptoStream); void cancelCryptoStream(QuicCryptoStream& cryptoStream);
/** /**
* Returns the appropriate crypto stream for the protection type of the packet. * Returns the appropriate crypto stream for the protection type of the packet.

View File

@@ -21,7 +21,6 @@ void sendSimpleFrame(QuicConnectionStateBase& conn, QuicSimpleFrame frame) {
void updateSimpleFrameOnAck( void updateSimpleFrameOnAck(
QuicConnectionStateBase& conn, QuicConnectionStateBase& conn,
const QuicSimpleFrame& frame) { const QuicSimpleFrame& frame) {
// TODO implement.
switch (frame.type()) { switch (frame.type()) {
case QuicSimpleFrame::Type::PingFrame_E: { case QuicSimpleFrame::Type::PingFrame_E: {
conn.pendingEvents.cancelPingTimeout = true; conn.pendingEvents.cancelPingTimeout = true;
@@ -72,6 +71,8 @@ folly::Optional<QuicSimpleFrame> updateSimpleFrameOnPacketClone(
case QuicSimpleFrame::Type::RetireConnectionIdFrame_E: case QuicSimpleFrame::Type::RetireConnectionIdFrame_E:
// TODO junqiw // TODO junqiw
return QuicSimpleFrame(frame); return QuicSimpleFrame(frame);
case QuicSimpleFrame::Type::HandshakeDoneFrame_E:
return QuicSimpleFrame(frame);
} }
folly::assume_unreachable(); folly::assume_unreachable();
} }
@@ -144,6 +145,7 @@ void updateSimpleFrameOnPacketLoss(
case QuicSimpleFrame::Type::NewConnectionIdFrame_E: case QuicSimpleFrame::Type::NewConnectionIdFrame_E:
case QuicSimpleFrame::Type::MaxStreamsFrame_E: case QuicSimpleFrame::Type::MaxStreamsFrame_E:
case QuicSimpleFrame::Type::RetireConnectionIdFrame_E: case QuicSimpleFrame::Type::RetireConnectionIdFrame_E:
case QuicSimpleFrame::Type::HandshakeDoneFrame_E:
conn.pendingEvents.frames.push_back(frame); conn.pendingEvents.frames.push_back(frame);
break; break;
} }
@@ -289,6 +291,16 @@ bool updateSimpleFrameOnPacketReceived(
// TODO junqiw // TODO junqiw
return false; return false;
} }
case QuicSimpleFrame::Type::HandshakeDoneFrame_E: {
if (conn.nodeType == QuicNodeType::Server) {
throw QuicTransportException(
"Received HANDSHAKE_DONE from client.",
TransportErrorCode::PROTOCOL_VIOLATION,
FrameType::HANDSHAKE_DONE);
}
handshakeConfirmed(conn);
return true;
}
} }
folly::assume_unreachable(); folly::assume_unreachable();
} }