1
0
mirror of https://github.com/facebookincubator/mvfst.git synced 2025-11-24 04:01:07 +03:00

Reset congestion controller after setting factory

Summary:
Before the change, there's no good way to recreate Cubic CC instance with custom CC factory, because Cubic is created by default.

On client side this requires calling setCongestionControl() or setTransportSettings() after calling setCongestionControllerFactory(), which is normally the case.

Reviewed By: yangchi

Differential Revision: D26401996

fbshipit-source-id: dfda39be835c67b9db42f726b3ac64c7b3d37c2f
This commit is contained in:
Andrii Vasylevskyi
2021-02-19 17:53:31 -08:00
committed by Facebook GitHub Bot
parent 1303281a51
commit 10a6feed49
7 changed files with 71 additions and 25 deletions

View File

@@ -71,7 +71,9 @@ void QuicTransportBase::setPacingTimer(
void QuicTransportBase::setCongestionControllerFactory(
std::shared_ptr<CongestionControllerFactory> ccFactory) {
CHECK(ccFactory);
ccFactory_ = ccFactory;
CHECK(conn_);
conn_->congestionControllerFactory = ccFactory;
conn_->congestionController.reset();
}
folly::EventBase* QuicTransportBase::getEventBase() const {
@@ -3004,10 +3006,11 @@ void QuicTransportBase::setCongestionControl(CongestionControlType type) {
DCHECK(conn_);
if (!conn_->congestionController ||
type != conn_->congestionController->type()) {
CHECK(ccFactory_);
CHECK(conn_->congestionControllerFactory);
validateCongestionAndPacing(type);
conn_->congestionController =
ccFactory_->makeCongestionController(*conn_, type);
conn_->congestionControllerFactory->makeCongestionController(
*conn_, type);
}
}

View File

@@ -259,7 +259,9 @@ class QuicTransportBase : public QuicSocket {
/**
* Set factory to create specific congestion controller instances
* for a given connection
* for a given connection.
* Deletes current congestion controller instance, to create new controller
* call setCongestionControl() or setTransportSettings().
*/
virtual void setCongestionControllerFactory(
std::shared_ptr<CongestionControllerFactory> factory);
@@ -826,8 +828,6 @@ class QuicTransportBase : public QuicSocket {
// TODO: This is silly. We need a better solution.
// Uninitialied local address as a fallback answer when socket isn't bound.
folly::SocketAddress localFallbackAddress;
// CongestionController factory
std::shared_ptr<CongestionControllerFactory> ccFactory_{nullptr};
folly::Optional<std::string> exceptionCloseWhat_;

View File

@@ -11,6 +11,7 @@
#include <folly/io/async/AsyncSocketException.h>
#include <quic/client/handshake/CachedServerTransportParameters.h>
#include <quic/common/TimeUtil.h>
#include <quic/congestion_control/CongestionControllerFactory.h>
#include <quic/congestion_control/QuicCubic.h>
#include <quic/flowcontrol/QuicFlowController.h>
#include <quic/handshake/TransportParameters.h>
@@ -56,6 +57,16 @@ std::unique_ptr<QuicClientConnectionState> undoAllClientStateForRetry(
std::move(conn->earlyDataAppParamsValidator);
newConn->earlyDataAppParamsGetter = std::move(conn->earlyDataAppParamsGetter);
newConn->happyEyeballsState = std::move(conn->happyEyeballsState);
if (conn->congestionControllerFactory) {
newConn->congestionControllerFactory = conn->congestionControllerFactory;
if (conn->congestionController) {
// we have to recreate congestion controler
// because it holds referencs to the old state
newConn->congestionController =
newConn->congestionControllerFactory->makeCongestionController(
*newConn, conn->congestionController->type());
}
}
return newConn;
}

View File

@@ -2811,8 +2811,6 @@ class QuicClientTransportAfterStartTestBase : public QuicClientTransportTest {
void SetUpChild() override {
client->addNewPeerAddress(serverAddr);
client->setHostname(hostname_);
client->setCongestionControllerFactory(
std::make_shared<DefaultCongestionControllerFactory>());
ON_CALL(*sock, write(_, _))
.WillByDefault(Invoke([&](const SocketAddress&,
const std::unique_ptr<folly::IOBuf>& buf) {
@@ -4523,6 +4521,10 @@ TEST_F(QuicClientTransportVersionAndRetryTest, RetryPacket) {
client->getNonConstConn().initialDestinationConnectionId = initialDstConnId;
client->getNonConstConn().originalDestinationConnectionId = initialDstConnId;
client->setCongestionControllerFactory(
std::make_shared<DefaultCongestionControllerFactory>());
client->setCongestionControl(CongestionControlType::NewReno);
StreamId streamId = *client->createBidirectionalStream();
auto write = IOBuf::copyBuffer("ice cream");
client->writeChain(streamId, write->clone(), true, nullptr);
@@ -4540,6 +4542,12 @@ TEST_F(QuicClientTransportVersionAndRetryTest, RetryPacket) {
auto serverCid = recvServerRetry(serverAddr);
ASSERT_TRUE(bytesWrittenToNetwork);
// Check CC is kept after retry recreates QuicClientConnectionState
EXPECT_TRUE(client->getConn().congestionControllerFactory);
EXPECT_EQ(
client->getConn().congestionController->type(),
CongestionControlType::NewReno);
// Check to see that the server receives an initial packet with the following
// properties:
// 1. The token in the initial packet matches the token sent in the retry
@@ -5169,11 +5177,49 @@ TEST_F(QuicClientTransportAfterStartTest, DestroyEvbWhileLossTimeoutActive) {
eventbase_.reset();
}
class TestCCFactory : public CongestionControllerFactory {
public:
std::unique_ptr<CongestionController> makeCongestionController(
QuicConnectionStateBase& conn,
CongestionControlType type) override {
EXPECT_EQ(type, CongestionControlType::Cubic);
createdControllers++;
return std::make_unique<Cubic>(conn);
}
int createdControllers{0};
};
TEST_F(
QuicClientTransportAfterStartTest,
CongestionControlRecreatedWithNewFactory) {
// Default: Cubic
auto cc = client->getConn().congestionController.get();
EXPECT_EQ(CongestionControlType::Cubic, cc->type());
// Check Cubic CC instance is recreated with new CC factory
auto factory = std::make_shared<TestCCFactory>();
client->setCongestionControllerFactory(factory);
client->setCongestionControl(CongestionControlType::Cubic);
auto newCC = client->getConn().congestionController.get();
EXPECT_NE(cc, newCC);
EXPECT_EQ(factory->createdControllers, 1);
}
TEST_F(QuicClientTransportAfterStartTest, SetCongestionControl) {
// Default: Cubic
auto cc = client->getConn().congestionController.get();
EXPECT_EQ(CongestionControlType::Cubic, cc->type());
// Setting CC factory resets CC controller
client->setCongestionControllerFactory(
std::make_shared<DefaultCongestionControllerFactory>());
EXPECT_FALSE(client->getConn().congestionController);
// Set to Cubic explicitly this time
client->setCongestionControl(CongestionControlType::Cubic);
cc = client->getConn().congestionController.get();
EXPECT_EQ(CongestionControlType::Cubic, cc->type());
// Change to Reno
client->setCongestionControl(CongestionControlType::NewReno);
cc = client->getConn().congestionController.get();
@@ -5191,6 +5237,8 @@ TEST_F(QuicClientTransportAfterStartTest, SetCongestionControlBbr) {
EXPECT_EQ(CongestionControlType::Cubic, cc->type());
// Change to BBR, which requires enable pacing first
client->setCongestionControllerFactory(
std::make_shared<DefaultCongestionControllerFactory>());
client->setPacingTimer(TimerHighRes::newTimer(eventbase_.get(), 1ms));
client->getNonConstConn().transportSettings.pacingEnabled = true;
client->setCongestionControl(CongestionControlType::BBR);

View File

@@ -105,15 +105,6 @@ void QuicServerTransport::setServerConnectionIdRejector(
}
}
void QuicServerTransport::setCongestionControllerFactory(
std::shared_ptr<CongestionControllerFactory> ccFactory) {
CHECK(ccFactory);
ccFactory_ = ccFactory;
if (conn_) {
conn_->congestionControllerFactory = ccFactory_;
}
}
void QuicServerTransport::onReadData(
const folly::SocketAddress& peer,
NetworkDataSingle&& networkData) {

View File

@@ -87,14 +87,6 @@ class QuicServerTransport
void setServerConnectionIdRejector(
ServerConnectionIdRejector* connIdRejector) noexcept;
/**
* Set factory to create specific congestion controller instances
* for a given connection
* This must be set before the server is started.
*/
void setCongestionControllerFactory(
std::shared_ptr<CongestionControllerFactory> factory) override;
virtual void setClientConnectionId(const ConnectionId& clientConnectionId);
void setClientChosenDestConnectionId(const ConnectionId& serverCid);

View File

@@ -322,6 +322,7 @@ class QuicServerTransportTest : public Test {
server = std::make_shared<TestingQuicServerTransport>(
&evb, std::move(sock), connCallback, serverCtx);
server->setCongestionControllerFactory(ccFactory_);
server->setCongestionControl(CongestionControlType::Cubic);
server->setRoutingCallback(&routingCallback);
server->setSupportedVersions(supportedVersions);
server->setOriginalPeerAddress(clientAddr);