diff --git a/quic/api/QuicTransportBase.cpp b/quic/api/QuicTransportBase.cpp index b5e9bab25..387b05ba4 100644 --- a/quic/api/QuicTransportBase.cpp +++ b/quic/api/QuicTransportBase.cpp @@ -390,8 +390,11 @@ void QuicTransportBase::closeImpl( if (noError) { connCallback_->onConnectionEnd(); } else { + std::string closeStr = exceptionCloseWhat_ + ? std::move(exceptionCloseWhat_.value()) + : cancelCode.second.str(); connCallback_->onConnectionError( - std::make_pair(cancelCode.first, cancelCode.second.str())); + std::make_pair(cancelCode.first, closeStr)); } } @@ -1279,16 +1282,19 @@ folly::Expected, LocalErrorCode> QuicTransportBase::read( return folly::makeExpected(std::move(result)); } catch (const QuicTransportException& ex) { VLOG(4) << "read() error " << ex.what() << " " << *this; + exceptionCloseWhat_ = ex.what(); closeImpl(std::make_pair( QuicErrorCode(ex.errorCode()), std::string("read() error"))); return folly::makeUnexpected(LocalErrorCode::TRANSPORT_ERROR); } catch (const QuicInternalException& ex) { VLOG(4) << __func__ << " " << ex.what() << " " << *this; + exceptionCloseWhat_ = ex.what(); closeImpl(std::make_pair( QuicErrorCode(ex.errorCode()), std::string("read() error"))); return folly::makeUnexpected(ex.errorCode()); } catch (const std::exception& ex) { VLOG(4) << "read() error " << ex.what() << " " << *this; + exceptionCloseWhat_ = ex.what(); closeImpl(std::make_pair( QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), std::string("read() error"))); @@ -1405,17 +1411,20 @@ folly:: return folly::makeExpected(folly::Unit()); } catch (const QuicTransportException& ex) { VLOG(4) << "consume() error " << ex.what() << " " << *this; + exceptionCloseWhat_ = ex.what(); closeImpl(std::make_pair( QuicErrorCode(ex.errorCode()), std::string("consume() error"))); return folly::makeUnexpected( ConsumeError{LocalErrorCode::TRANSPORT_ERROR, readOffset}); } catch (const QuicInternalException& ex) { VLOG(4) << __func__ << " " << ex.what() << " " << *this; + exceptionCloseWhat_ = ex.what(); closeImpl(std::make_pair( QuicErrorCode(ex.errorCode()), std::string("consume() error"))); return folly::makeUnexpected(ConsumeError{ex.errorCode(), readOffset}); } catch (const std::exception& ex) { VLOG(4) << "consume() error " << ex.what() << " " << *this; + exceptionCloseWhat_ = ex.what(); closeImpl(std::make_pair( QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), std::string("consume() error"))); @@ -1635,18 +1644,22 @@ void QuicTransportBase::onNetworkData( } } catch (const QuicTransportException& ex) { VLOG(4) << __func__ << " " << ex.what() << " " << *this; + exceptionCloseWhat_ = ex.what(); return closeImpl( std::make_pair(QuicErrorCode(ex.errorCode()), std::string(ex.what()))); } catch (const QuicInternalException& ex) { VLOG(4) << __func__ << " " << ex.what() << " " << *this; + exceptionCloseWhat_ = ex.what(); return closeImpl( std::make_pair(QuicErrorCode(ex.errorCode()), std::string(ex.what()))); } catch (const QuicApplicationException& ex) { VLOG(4) << __func__ << " " << ex.what() << " " << *this; + exceptionCloseWhat_ = ex.what(); return closeImpl( std::make_pair(QuicErrorCode(ex.errorCode()), std::string(ex.what()))); } catch (const std::exception& ex) { VLOG(4) << __func__ << " " << ex.what() << " " << *this; + exceptionCloseWhat_ = ex.what(); return closeImpl(std::make_pair( QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), std::string("error onNetworkData()"))); @@ -1868,18 +1881,21 @@ QuicSocket::WriteResult QuicTransportBase::writeChain( } catch (const QuicTransportException& ex) { VLOG(4) << __func__ << " streamId=" << id << " " << ex.what() << " " << *this; + exceptionCloseWhat_ = ex.what(); closeImpl(std::make_pair( QuicErrorCode(ex.errorCode()), std::string("writeChain() error"))); return folly::makeUnexpected(LocalErrorCode::TRANSPORT_ERROR); } catch (const QuicInternalException& ex) { VLOG(4) << __func__ << " streamId=" << id << " " << ex.what() << " " << *this; + exceptionCloseWhat_ = ex.what(); closeImpl(std::make_pair( QuicErrorCode(ex.errorCode()), std::string("writeChain() error"))); return folly::makeUnexpected(ex.errorCode()); } catch (const std::exception& ex) { VLOG(4) << __func__ << " streamId=" << id << " " << ex.what() << " " << *this; + exceptionCloseWhat_ = ex.what(); closeImpl(std::make_pair( QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), std::string("writeChain() error"))); @@ -1994,18 +2010,21 @@ folly::Expected QuicTransportBase::resetStream( } catch (const QuicTransportException& ex) { VLOG(4) << __func__ << " streamId=" << id << " " << ex.what() << " " << *this; + exceptionCloseWhat_ = ex.what(); closeImpl(std::make_pair( QuicErrorCode(ex.errorCode()), std::string("resetStream() error"))); return folly::makeUnexpected(LocalErrorCode::TRANSPORT_ERROR); } catch (const QuicInternalException& ex) { VLOG(4) << __func__ << " streamId=" << id << " " << ex.what() << " " << *this; + exceptionCloseWhat_ = ex.what(); closeImpl(std::make_pair( QuicErrorCode(ex.errorCode()), std::string("resetStream() error"))); return folly::makeUnexpected(ex.errorCode()); } catch (const std::exception& ex) { VLOG(4) << __func__ << " streamId=" << id << " " << ex.what() << " " << *this; + exceptionCloseWhat_ = ex.what(); closeImpl(std::make_pair( QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), std::string("resetStream() error"))); @@ -2103,16 +2122,19 @@ void QuicTransportBase::lossTimeoutExpired() noexcept { pacedWriteDataToSocket(false); } catch (const QuicTransportException& ex) { VLOG(4) << __func__ << " " << ex.what() << " " << *this; + exceptionCloseWhat_ = ex.what(); closeImpl(std::make_pair( QuicErrorCode(ex.errorCode()), std::string("lossTimeoutExpired() error"))); } catch (const QuicInternalException& ex) { VLOG(4) << __func__ << " " << ex.what() << " " << *this; + exceptionCloseWhat_ = ex.what(); closeImpl(std::make_pair( QuicErrorCode(ex.errorCode()), std::string("lossTimeoutExpired() error"))); } catch (const std::exception& ex) { VLOG(4) << __func__ << " " << ex.what() << " " << *this; + exceptionCloseWhat_ = ex.what(); closeImpl(std::make_pair( QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), std::string("lossTimeoutExpired() error"))); @@ -2376,16 +2398,19 @@ void QuicTransportBase::writeSocketDataAndCatch() { writeSocketData(); } catch (const QuicTransportException& ex) { VLOG(4) << __func__ << ex.what() << " " << *this; + exceptionCloseWhat_ = ex.what(); closeImpl(std::make_pair( QuicErrorCode(ex.errorCode()), std::string("writeSocketDataAndCatch() error"))); } catch (const QuicInternalException& ex) { VLOG(4) << __func__ << ex.what() << " " << *this; + exceptionCloseWhat_ = ex.what(); closeImpl(std::make_pair( QuicErrorCode(ex.errorCode()), std::string("writeSocketDataAndCatch() error"))); } catch (const std::exception& ex) { VLOG(4) << __func__ << " error=" << ex.what() << " " << *this; + exceptionCloseWhat_ = ex.what(); closeImpl(std::make_pair( QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), std::string("writeSocketDataAndCatch() error"))); diff --git a/quic/api/QuicTransportBase.h b/quic/api/QuicTransportBase.h index 1b355b542..37c211d62 100644 --- a/quic/api/QuicTransportBase.h +++ b/quic/api/QuicTransportBase.h @@ -627,6 +627,8 @@ class QuicTransportBase : public QuicSocket { folly::SocketAddress localFallbackAddress; // CongestionController factory std::shared_ptr ccFactory_{nullptr}; + + folly::Optional exceptionCloseWhat_; }; std::ostream& operator<<(std::ostream& os, const QuicTransportBase& qt); diff --git a/quic/api/test/QuicTransportBaseTest.cpp b/quic/api/test/QuicTransportBaseTest.cpp index 914ef9c7f..c74c3fd75 100644 --- a/quic/api/test/QuicTransportBaseTest.cpp +++ b/quic/api/test/QuicTransportBaseTest.cpp @@ -1030,6 +1030,30 @@ TEST_F(QuicTransportImplTest, ConnectionErrorOnWrite) { QuicErrorCode(LocalErrorCode::CONNECTION_ABANDONED)); } +TEST_F(QuicTransportImplTest, ConnectionErrorUnhandledException) { + transport->transportConn->oneRttWriteCipher = test::createNoOpAead(); + auto stream = transport->createBidirectionalStream().value(); + EXPECT_CALL( + connCallback, + onConnectionError(std::make_pair( + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR), + std::string("Well there's your problem")))); + EXPECT_CALL(*socketPtr, write(_, _)).WillOnce(Invoke([](auto&, auto&) { + throw std::runtime_error("Well there's your problem"); + return 0; + })); + QuicSocket::WriteResult result = transport->writeChain( + stream, folly::IOBuf::copyBuffer("Hey"), true, false, nullptr); + transport->addDataToStream( + stream, StreamBuffer(folly::IOBuf::copyBuffer("Data"), 0)); + evb->loopOnce(); + + EXPECT_TRUE(transport->isClosed()); + EXPECT_EQ( + transport->getConnectionError(), + QuicErrorCode(TransportErrorCode::INTERNAL_ERROR)); +} + TEST_F(QuicTransportImplTest, LossTimeoutNoLessThanTickInterval) { auto tickInterval = evb->timer().getTickInterval(); transport->scheduleLossTimeout(tickInterval - 1ms);