diff --git a/quic/server/QuicServerTransport.h b/quic/server/QuicServerTransport.h index ae57a05f9..e06e472f6 100644 --- a/quic/server/QuicServerTransport.h +++ b/quic/server/QuicServerTransport.h @@ -107,12 +107,6 @@ class QuicServerTransport } virtual void accept(); - void setShedConnection() { - shedConnection_ = true; - } - bool shouldShedConnection() { - return shedConnection_; - } protected: // From ServerHandshake::HandshakeCallback @@ -131,7 +125,6 @@ class QuicServerTransport bool notifiedRouting_{false}; bool notifiedConnIdBound_{false}; bool newSessionTicketWritten_{false}; - bool shedConnection_{false}; bool connectionIdsIssued_{false}; QuicServerConnectionState* serverConn_; }; diff --git a/quic/server/QuicServerWorker.cpp b/quic/server/QuicServerWorker.cpp index cce87e707..35121488a 100644 --- a/quic/server/QuicServerWorker.cpp +++ b/quic/server/QuicServerWorker.cpp @@ -384,44 +384,50 @@ void QuicServerWorker::dispatchPacketData( auto sock = makeSocket(getEventBase()); auto trans = transportFactory_->make( getEventBase(), std::move(sock), client, ctx_); - trans->setPacingTimer(pacingTimer_); - trans->setRoutingCallback(this); - trans->setSupportedVersions(supportedVersions_); - trans->setOriginalPeerAddress(client); - trans->setCongestionControllerFactory(ccFactory_); - if (transportSettingsOverrideFn_) { - folly::Optional overridenTransportSettings = - transportSettingsOverrideFn_( - transportSettings_, client.getIPAddress()); - if (overridenTransportSettings) { - trans->setTransportSettings(*overridenTransportSettings); + if (!trans) { + LOG(ERROR) << "Transport factory failed to make new transport"; + dropPacket = true; + } else { + CHECK(trans); + trans->setPacingTimer(pacingTimer_); + trans->setRoutingCallback(this); + trans->setSupportedVersions(supportedVersions_); + trans->setOriginalPeerAddress(client); + trans->setCongestionControllerFactory(ccFactory_); + if (transportSettingsOverrideFn_) { + folly::Optional overridenTransportSettings = + transportSettingsOverrideFn_( + transportSettings_, client.getIPAddress()); + if (overridenTransportSettings) { + trans->setTransportSettings(*overridenTransportSettings); + } else { + trans->setTransportSettings(transportSettings_); + } } else { trans->setTransportSettings(transportSettings_); } - } else { - trans->setTransportSettings(transportSettings_); + trans->setConnectionIdAlgo(connIdAlgo_.get()); + if (routingData.sourceConnId) { + trans->setClientConnectionId(*routingData.sourceConnId); + } + trans->setClientChosenDestConnectionId(routingData.destinationConnId); + // parameters to create server chosen connection id + ServerConnectionIdParams serverConnIdParams( + hostId_, static_cast(processId_), workerId_); + trans->setServerConnectionIdParams(std::move(serverConnIdParams)); + if (infoCallback_) { + trans->setTransportInfoCallback(infoCallback_.get()); + } + trans->accept(); + auto result = sourceAddressMap_.emplace(std::make_pair( + std::make_pair(client, routingData.destinationConnId), trans)); + if (!result.second) { + LOG(ERROR) << "Routing entry already exists for client=" << client + << ", dest CID=" << routingData.destinationConnId.hex(); + dropPacket = true; + } + transport = trans; } - trans->setConnectionIdAlgo(connIdAlgo_.get()); - if (routingData.sourceConnId) { - trans->setClientConnectionId(*routingData.sourceConnId); - } - trans->setClientChosenDestConnectionId(routingData.destinationConnId); - // parameters to create server chosen connection id - ServerConnectionIdParams serverConnIdParams( - hostId_, static_cast(processId_), workerId_); - trans->setServerConnectionIdParams(std::move(serverConnIdParams)); - if (infoCallback_) { - trans->setTransportInfoCallback(infoCallback_.get()); - } - trans->accept(); - auto result = sourceAddressMap_.emplace(std::make_pair( - std::make_pair(client, routingData.destinationConnId), trans)); - if (!result.second) { - LOG(ERROR) << "Routing entry already exists for client=" << client - << ", dest CID=" << routingData.destinationConnId.hex(); - dropPacket = true; - } - transport = trans; } } else { transport = sit->second; @@ -674,12 +680,6 @@ void QuicServerWorker::onConnectionIdBound( LOG(ERROR) << "Transport not match, client=" << *transport; } else { sourceAddressMap_.erase(source); - if (transport->shouldShedConnection()) { - VLOG_EVERY_N(1, 100) << "Shedding connection"; - transport->closeNow(std::make_pair( - QuicErrorCode(TransportErrorCode::SERVER_BUSY), - std::string("shedding under load"))); - } } } diff --git a/quic/server/test/QuicServerTest.cpp b/quic/server/test/QuicServerTest.cpp index 822629774..2293206f2 100644 --- a/quic/server/test/QuicServerTest.cpp +++ b/quic/server/test/QuicServerTest.cpp @@ -167,6 +167,11 @@ class QuicServerWorkerTest : public Test { ShortHeader shortHeader, QuicTransportStatsCallback::PacketDropReason dropReason); + void expectConnCreateRefused(); + void createQuicConnectionDuringShedding( + const folly::SocketAddress& addr, + ConnectionId connId); + protected: folly::SocketAddress fakeAddress_; std::unique_ptr worker_; @@ -206,6 +211,38 @@ void QuicServerWorkerTest::expectConnectionCreation( EXPECT_CALL(*transport, setTransportInfoCallback(transportInfoCb_)); } +void QuicServerWorkerTest::expectConnCreateRefused() { + MockQuicTransport::Ptr transport = transport_; + EXPECT_CALL(*factory_, _make(_, _, _, _)).WillOnce(Return(nullptr)); + EXPECT_CALL(*transport, setSupportedVersions(_)).Times(0); + EXPECT_CALL(*transport, setOriginalPeerAddress(_)).Times(0); + EXPECT_CALL(*transport, setRoutingCallback(worker_.get())).Times(0); + EXPECT_CALL(*transport, setConnectionIdAlgo(_)).Times(0); + EXPECT_CALL(*transport, setServerConnectionIdParams(_)).Times(0); + EXPECT_CALL(*transport, setTransportSettings(_)).Times(0); + EXPECT_CALL(*transport, accept()).Times(0); + EXPECT_CALL(*transport, setTransportInfoCallback(transportInfoCb_)).Times(0); + EXPECT_CALL(*transport, onNetworkData(_, _)).Times(0); +} + +void QuicServerWorkerTest::createQuicConnectionDuringShedding( + const folly::SocketAddress& addr, + ConnectionId connId) { + PacketNum num = 1; + QuicVersion version = QuicVersion::MVFST; + LongHeader header(LongHeader::Types::Initial, connId, connId, num, version); + RoutingData routingData(HeaderForm::Long, true, true, connId, connId); + + auto data = createData(kMinInitialPacketSize + 10); + expectConnCreateRefused(); + worker_->dispatchPacketData( + addr, std::move(routingData), NetworkData(data->clone(), Clock::now())); + + const auto& addrMap = worker_->getSrcToTransportMap(); + EXPECT_EQ(0, addrMap.count(std::make_pair(addr, connId))); + eventbase_.loop(); +} + void QuicServerWorkerTest::createQuicConnection( const folly::SocketAddress& addr, ConnectionId connId, @@ -552,23 +589,7 @@ TEST_F(QuicServerWorkerTest, InitialPacketTooSmall) { TEST_F(QuicServerWorkerTest, QuicShedTest) { auto connId = getTestConnectionId(hostId_); - createQuicConnection(kClientAddr, connId); - - worker_->onConnectionIdAvailable(transport_, getTestConnectionId(hostId_)); - EXPECT_CALL(*transport_, getClientChosenDestConnectionId()) - .WillRepeatedly(Return(connId)); - transport_->setShedConnection(); - EXPECT_CALL( - *transport_, - closeNow(Eq(std::make_pair( - QuicErrorCode(TransportErrorCode::SERVER_BUSY), - std::string("shedding under load"))))); - worker_->onConnectionIdBound(transport_); - EXPECT_CALL(*transport_, setRoutingCallback(nullptr)); - worker_->onConnectionUnbound( - transport_.get(), - std::make_pair(kClientAddr, connId), - std::vector{ConnectionIdData{connId, 0}}); + createQuicConnectionDuringShedding(kClientAddr, connId); } TEST_F(QuicServerWorkerTest, ZeroLengthConnectionId) {