/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

// NOTE(review): in this copy the #include targets and most template argument
// lists (e.g. `std::shared_ptr;`, bare `Optional`, bare `quic::Expected`, the
// stray `>` in `Optional>`) appear to have been stripped, so the header will
// not compile as shown. Restore them from the upstream header; the code tokens
// below are otherwise unchanged.
#include
#include
#include

namespace quic {

// gmock test double for QuicSocket.
//
// Most QuicSocket methods are declared with MOCK_METHOD and are driven from
// tests with EXPECT_CALL / ON_CALL. A few overrides are written by hand:
// setKnob, writeChain and writeDatagram hand the pointer released from their
// BufPtr argument to a SharedBuf and forward to a mocked SharedBuf overload of
// the same name, and read forwards to the mocked readNaked. Presumably this is
// because move-only buffer types are awkward to match or return in gmock
// expectations — confirm before changing the pattern.
class MockQuicSocket : public QuicSocket {
 public:
  // Shared-ownership buffer handle taken by the mockable overloads.
  // NOTE(review): template argument stripped in this copy; presumably the
  // buffer type that BufPtr owns.
  using SharedBuf = std::shared_ptr;

  MockQuicSocket() = default;

  // The EventBase argument is accepted but ignored (its name is commented
  // out); the two callbacks are stored in the public setupCb_ / connCb_
  // members so tests can reach them.
  MockQuicSocket(
      folly::EventBase* /*eventBase*/,
      ConnectionSetupCallback* setupCb,
      ConnectionCallback* connCb)
      : setupCb_(setupCb), connCb_(connCb) {}

  MOCK_METHOD(bool, good, (), (const));
  MOCK_METHOD(bool, replaySafe, (), (const));
  MOCK_METHOD(bool, error, (), (const));
  MOCK_METHOD(void, close, (Optional));
  MOCK_METHOD(void, closeGracefully, ());
  MOCK_METHOD(void, closeNow, (Optional));
  MOCK_METHOD(Optional, getClientConnectionId, (), (const));
  MOCK_METHOD(const TransportSettings&, getTransportSettings, (), (const));
  MOCK_METHOD(Optional, getServerConnectionId, (), (const));
  MOCK_METHOD(
      Optional,
      getClientChosenDestConnectionId,
      (),
      (const));
  MOCK_METHOD(const folly::SocketAddress&, getPeerAddress, (), (const));
  MOCK_METHOD(
      const folly::SocketAddress&,
      getOriginalPeerAddress,
      (),
      (const));
  MOCK_METHOD(const folly::SocketAddress&, getLocalAddress, (), (const));
  MOCK_METHOD(
      Optional>,
      getPeerTransportParams,
      (),
      (const));
  MOCK_METHOD(
      (Optional>),
      getExportedKeyingMaterial,
      (const std::string&, const Optional&, uint16_t),
      (const));
  MOCK_METHOD(std::shared_ptr, getEventBase, (), (const));
  MOCK_METHOD(
      (quic::Expected),
      getStreamReadOffset,
      (StreamId),
      (const));
  MOCK_METHOD(
      (quic::Expected),
      getStreamWriteOffset,
      (StreamId),
      (const));
  MOCK_METHOD(
      (quic::Expected),
      getStreamWriteBufferedBytes,
      (StreamId),
      (const));
  MOCK_METHOD(QuicSocket::TransportInfo, getTransportInfo, (), (const));
  MOCK_METHOD(
      (quic::Expected),
      getStreamTransportInfo,
      (StreamId),
      (const));
  MOCK_METHOD(Optional, getAppProtocol, (), (const));
  MOCK_METHOD(void, setReceiveWindow, (StreamId, size_t));
  MOCK_METHOD(void, setSendBuffer, (StreamId, size_t, size_t));
  MOCK_METHOD(uint64_t, getConnectionBufferAvailable, (), (const));
  MOCK_METHOD(
      (quic::Expected),
      getConnectionFlowControl,
      (),
      (const));
  MOCK_METHOD(
      (quic::Expected),
      getStreamFlowControl,
      (StreamId),
      (const));
  MOCK_METHOD(
      (quic::Expected),
      getMaxWritableOnStream,
      (StreamId),
      (const));
  MOCK_METHOD(void, unsetAllReadCallbacks, ());
  MOCK_METHOD(void, unsetAllPeekCallbacks, ());
  MOCK_METHOD(void, unsetAllDeliveryCallbacks, ());
  MOCK_METHOD(void, cancelDeliveryCallbacksForStream, (StreamId));
  MOCK_METHOD(
      void,
      cancelDeliveryCallbacksForStream,
      (StreamId, uint64_t offset));
  MOCK_METHOD(
      (quic::Expected),
      setConnectionFlowControlWindow,
      (uint64_t));
  MOCK_METHOD(
      (quic::Expected),
      setStreamFlowControlWindow,
      (StreamId, uint64_t));
  MOCK_METHOD(void, setTransportSettings, (TransportSettings));
  MOCK_METHOD(
      (quic::Expected),
      setMaxPacingRate,
      (uint64_t));

  // Hand-written override: hands the pointer released from knobBlob to a
  // SharedBuf and forwards to the mocked SharedBuf overload of setKnob
  // declared right below, which is the one tests set expectations on.
  quic::Expected setKnob(uint64_t knobSpace, uint64_t knobId, BufPtr knobBlob)
      override {
    SharedBuf sharedBlob(knobBlob.release());
    return setKnob(knobSpace, knobId, sharedBlob);
  }
  MOCK_METHOD(
      (quic::Expected),
      setKnob,
      (uint64_t, uint64_t, SharedBuf));
  MOCK_METHOD(bool, isKnobSupported, (), (const));
  MOCK_METHOD(
      (quic::Expected),
      setStreamPriority,
      (StreamId, PriorityQueue::Priority));
  MOCK_METHOD(
      (quic::Expected),
      setPriorityQueue,
      (std::unique_ptr queue));
  MOCK_METHOD(
      (quic::Expected),
      getStreamPriority,
      (StreamId));
  MOCK_METHOD(
      (quic::Expected),
      setReadCallback,
      (StreamId, ReadCallback*, Optional err));
  MOCK_METHOD(
      void,
      setConnectionSetupCallback,
      (folly::MaybeManagedPtr));
  MOCK_METHOD(
      void,
      setConnectionCallback,
      (folly::MaybeManagedPtr));

  // Hand-written override (not mocked): stores both callbacks in the public
  // earlyDataAppParamsValidator_ / earlyDataAppParamsGetter_ members so a
  // test can invoke them directly.
  void setEarlyDataAppParamsFunctions(
      std::function&, const BufPtr&)> validator,
      std::function getter) override {
    earlyDataAppParamsValidator_ = std::move(validator);
    earlyDataAppParamsGetter_ = std::move(getter);
  }

  MOCK_METHOD((quic::Expected), pauseRead, (StreamId));
  MOCK_METHOD((quic::Expected), resumeRead, (StreamId));
  MOCK_METHOD(
      (quic::Expected),
      stopSending,
      (StreamId, ApplicationErrorCode));

  // Hand-written override: delegates to the mocked readNaked(). An error is
  // propagated unchanged via make_unexpected; on success the first element of
  // the result is wrapped in a BufPtr and the second is passed through.
  // NOTE(review): BufPtr(res.value().first) appears to take ownership of a raw
  // buffer pointer, so readNaked expectations presumably must return a heap
  // buffer the caller may free (e.g. one released from a BufPtr) — confirm.
  quic::Expected, LocalErrorCode> read(
      StreamId id,
      size_t maxRead) override {
    auto res = readNaked(id, maxRead);
    if (res.hasError()) {
      return quic::make_unexpected(res.error());
    } else {
      return std::pair(
          BufPtr(res.value().first), res.value().second);
    }
  }
  // Result type of the mockable readNaked() that read() forwards to.
  using ReadResult = quic::Expected, LocalErrorCode>;
  MOCK_METHOD(ReadResult, readNaked, (StreamId, size_t));

  MOCK_METHOD(
      (quic::Expected),
      createBidirectionalStream,
      (bool));
  MOCK_METHOD(
      (quic::Expected),
      createUnidirectionalStream,
      (bool));
  MOCK_METHOD(uint64_t, getNumOpenableBidirectionalStreams, (), (const));
  MOCK_METHOD(uint64_t, getNumOpenableUnidirectionalStreams, (), (const));
  MOCK_METHOD((bool), isClientStream, (StreamId), (noexcept));
  MOCK_METHOD((bool), isServerStream, (StreamId), (noexcept));
  MOCK_METHOD((StreamInitiator), getStreamInitiator, (StreamId), (noexcept));
  MOCK_METHOD((bool), isBidirectionalStream, (StreamId), (noexcept));
  MOCK_METHOD((bool), isUnidirectionalStream, (StreamId), (noexcept));
  MOCK_METHOD(
      (StreamDirectionality),
      getStreamDirectionality,
      (StreamId),
      (noexcept));
  MOCK_METHOD(
      (quic::Expected),
      notifyPendingWriteOnConnection,
      (ConnectionWriteCallback*));
  MOCK_METHOD(
      (quic::Expected),
      notifyPendingWriteOnStream,
      (StreamId, StreamWriteCallback*));
  MOCK_METHOD(
      (quic::Expected),
      unregisterStreamWriteCallback,
      (StreamId));
  MOCK_METHOD(
      (quic::Expected),
      registerTxCallback,
      (const StreamId, const uint64_t, ByteEventCallback*));
  MOCK_METHOD(
      (quic::Expected),
      registerByteEventCallback,
      (const ByteEvent::Type,
       const StreamId,
       const uint64_t,
       ByteEventCallback*));
  MOCK_METHOD(
      void,
      cancelByteEventCallbacksForStream,
      (const StreamId id, const Optional& offset));
  MOCK_METHOD(
      void,
      cancelByteEventCallbacksForStream,
      (const ByteEvent::Type, const StreamId id, const Optional& offset));
  MOCK_METHOD(void, cancelAllByteEventCallbacks, ());
  MOCK_METHOD(void, cancelByteEventCallbacks, (const ByteEvent::Type));
  MOCK_METHOD(
      size_t,
      getNumByteEventCallbacksForStream,
      (const StreamId id),
      (const));
  MOCK_METHOD(
      size_t,
      getNumByteEventCallbacksForStream,
      (const ByteEvent::Type, const StreamId),
      (const));

  // Hand-written override: hands the pointer released from data to a
  // SharedBuf and forwards to the mocked SharedBuf overload of writeChain
  // declared right below.
  quic::Expected writeChain(
      StreamId id,
      BufPtr data,
      bool eof,
      ByteEventCallback* cb) override {
    SharedBuf sharedData(data.release());
    return writeChain(id, sharedData, eof, cb);
  }
  MOCK_METHOD(
      WriteResult,
      writeChain,
      (StreamId, SharedBuf, bool, ByteEventCallback*));
  MOCK_METHOD(
      WriteResult,
      writeBufMeta,
      (StreamId, const BufferMeta&, bool, ByteEventCallback*));
  MOCK_METHOD(
      WriteResult,
      setDSRPacketizationRequestSender,
      (StreamId, std::unique_ptr));
  MOCK_METHOD(
      (quic::Expected),
      registerDeliveryCallback,
      (StreamId, uint64_t, ByteEventCallback*));
  MOCK_METHOD(Optional, shutdownWrite, (StreamId));
  MOCK_METHOD(
      (quic::Expected),
      resetStream,
      (StreamId, ApplicationErrorCode));
  MOCK_METHOD(
      (quic::Expected),
      updateReliableDeliveryCheckpoint,
      (StreamId));
  MOCK_METHOD(
      (quic::Expected),
      resetStreamReliably,
      (StreamId, ApplicationErrorCode));
  MOCK_METHOD(
      (quic::Expected),
      maybeResetStreamFromReadError,
      (StreamId, QuicErrorCode));
  MOCK_METHOD(
      (quic::Expected),
      setPingCallback,
      (PingCallback*));
  MOCK_METHOD(void, sendPing, (std::chrono::milliseconds));
  MOCK_METHOD(const QuicConnectionStateBase*, getState, (), (const));
  MOCK_METHOD(bool, isDetachable, ());
  MOCK_METHOD(void, attachEventBase, (std::shared_ptr));
  MOCK_METHOD(void, detachEventBase, ());
  MOCK_METHOD(Optional, setControlStream, (StreamId));
  MOCK_METHOD(
      (quic::Expected),
      setPeekCallback,
      (StreamId, PeekCallback*));
  MOCK_METHOD((quic::Expected), pausePeek, (StreamId));
  MOCK_METHOD((quic::Expected), resumePeek, (StreamId));
  MOCK_METHOD(
      (quic::Expected),
      peek,
      (StreamId,
       const std::function<
           void(StreamId, const folly::Range&)>&));
  MOCK_METHOD(
      (quic::Expected>>),
      consume,
      (StreamId, uint64_t, size_t));
  MOCK_METHOD(
      (quic::Expected),
      consume,
      (StreamId, size_t));
  MOCK_METHOD(void, setCongestionControl, (CongestionControlType));
  MOCK_METHOD(void, addPacketProcessor, (std::shared_ptr));
  MOCK_METHOD(
      void,
      setThrottlingSignalProvider,
      (std::shared_ptr));

  // Callbacks passed to the three-argument constructor; nullptr when the
  // default constructor is used.
  ConnectionSetupCallback* setupCb_{nullptr};
  ConnectionCallback* connCb_{nullptr};

  // Callbacks captured by setEarlyDataAppParamsFunctions(); empty until it is
  // called.
  std::function&, const BufPtr&)> earlyDataAppParamsValidator_;
  std::function earlyDataAppParamsGetter_;

  MOCK_METHOD(
      void,
      resetNonControlStreams,
      (ApplicationErrorCode, folly::StringPiece));
  MOCK_METHOD(QuicConnectionStats, getConnectionsStats, (), (const));
  MOCK_METHOD(
      (quic::Expected),
      setDatagramCallback,
      (DatagramCallback*));
  MOCK_METHOD(uint16_t, getDatagramSizeLimit, (), (const));

  // Hand-written override: hands the pointer released from data to a
  // SharedBuf and forwards to the mocked SharedBuf overload of writeDatagram
  // declared right below.
  quic::Expected writeDatagram(BufPtr data) override {
    SharedBuf sharedData(data.release());
    return writeDatagram(sharedData);
  }
  MOCK_METHOD(WriteResult, writeDatagram, (SharedBuf));
  MOCK_METHOD(
      (quic::Expected, LocalErrorCode>),
      readDatagrams,
      (size_t));
  MOCK_METHOD(
      (quic::Expected, LocalErrorCode>),
      readDatagramBufs,
      (size_t));
  MOCK_METHOD(
      SocketObserverContainer*,
      getSocketObserverContainer,
      (),
      (const));
  MOCK_METHOD(
      (quic::Expected),
      createBidirectionalStreamGroup,
      ());
  MOCK_METHOD(
      (quic::Expected),
      createUnidirectionalStreamGroup,
      ());
  MOCK_METHOD(
      (quic::Expected),
      createBidirectionalStreamInGroup,
      (StreamGroupId));
  MOCK_METHOD(
      (quic::Expected),
      createUnidirectionalStreamInGroup,
      (StreamGroupId));
  MOCK_METHOD(
      (quic::Expected),
      setStreamGroupRetransmissionPolicy,
      (StreamGroupId, std::optional),
      (noexcept));
  MOCK_METHOD(
      (const std::shared_ptr),
      getPeerCertificate,
      (),
      (const));
  MOCK_METHOD((uint64_t), maxWritableOnConn, (), (const));
};

} // namespace quic