1
0
mirror of https://github.com/facebookincubator/mvfst.git synced 2025-04-18 17:24:03 +03:00

Initial commit of mvfst

This commit is contained in:
udippant 2019-04-22 23:34:59 -07:00
commit 50d4939e9e
287 changed files with 71535 additions and 0 deletions

1
.gitignore vendored Normal file
View File

@ -0,0 +1 @@
_build/

99
CMakeLists.txt Normal file
View File

@ -0,0 +1,99 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
cmake_minimum_required(VERSION 3.10)
project(
mvfst
)
list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
# QUIC_FBCODE_ROOT is where the top level quic/ directory resides, so
# an #include <quic/path/to/file> will resolve to
# $QUIC_FBCODE_ROOT/quic/path/to/file on disk
set(QUIC_FBCODE_ROOT ${CMAKE_CURRENT_SOURCE_DIR})
# Dependencies
find_package(Boost 1.58
REQUIRED COMPONENTS
iostreams
system
thread
filesystem
regex
context
)
find_package(folly REQUIRED)
find_package(Fizz REQUIRED)
find_package(Glog REQUIRED)
find_package(Threads)
SET(GFLAG_DEPENDENCIES "")
SET(QUIC_EXTRA_LINK_LIBRARIES "")
SET(QUIC_EXTRA_INCLUDE_DIRECTORIES "")
# find_package(gflags COMPONENTS static)
find_package(gflags CONFIG QUIET)
if (gflags_FOUND)
message("module path: ${CMAKE_MODULE_PATH}")
message(STATUS "Found gflags from package config")
if (TARGET gflags-shared)
list(APPEND GFLAG_DEPENDENCIES gflags-shared)
elseif (TARGET gflags)
list(APPEND GFLAG_DEPENDENCIES gflags)
else()
message(FATAL_ERROR "Unable to determine the target name for the GFlags package.")
endif()
list(APPEND CMAKE_REQUIRED_LIBRARIES ${GFLAGS_LIBRARIES})
list(APPEND CMAKE_REQUIRED_INCLUDES ${GFLAGS_INCLUDE_DIR})
else()
find_package(Gflags REQUIRED MODULE)
list(APPEND QUIC_EXTRA_LINK_LIBRARIES ${LIBGFLAGS_LIBRARY})
list(APPEND QUIC_EXTRA_INCLUDE_DIRECTORIES ${LIBGFLAGS_INCLUDE_DIR})
list(APPEND CMAKE_REQUIRED_LIBRARIES ${LIBGFLAGS_LIBRARY})
list(APPEND CMAKE_REQUIRED_INCLUDES ${LIBGFLAGS_INCLUDE_DIR})
endif()
list(APPEND
_QUIC_COMMON_COMPILE_OPTIONS
-std=c++14
-Wall
-Wextra
# more strict options
-Werror=sign-compare
-Werror=bool-compare
-Werror=unused-variable
-Woverloaded-virtual
-Wnon-virtual-dtor
)
include(QuicTest)
add_subdirectory(quic)
if(BUILD_TESTS)
enable_testing()
find_package(GTest REQUIRED)
find_package(GMock REQUIRED)
endif()
install(
EXPORT mvfst-exports
FILE mvfst-targets.cmake
NAMESPACE mvfst::
DESTINATION lib/cmake/mvfst/
)
include(CMakePackageConfigHelpers)
configure_package_config_file(
cmake/mvfst-config.cmake.in
${CMAKE_CURRENT_BINARY_DIR}/mvfst-config.cmake
INSTALL_DESTINATION lib/cmake/mvfst/
)
install(
FILES ${CMAKE_CURRENT_BINARY_DIR}/mvfst-config.cmake
DESTINATION lib/cmake/mvfst/
)

3
CODE_OF_CONDUCT.md Normal file
View File

@ -0,0 +1,3 @@
# Code of Conduct
Facebook has adopted a Code of Conduct that we expect project participants to adhere to. Please [read the full text](https://code.facebook.com/codeofconduct) so that you can understand what actions will and will not be tolerated.

44
CONTRIBUTING.md Normal file
View File

@ -0,0 +1,44 @@
# Contributing to MVFST
Here's a quick rundown of how to contribute to this project.
## Code of Conduct
The code of conduct is described in [`CODE_OF_CONDUCT.md`](CODE_OF_CONDUCT.md)
## Our Development Process
We develop on a private branch internally at Facebook. We regularly update
this github project with the changes from the internal repo. External pull
requests are cherry-picked into our repo and then pushed back out.
## Pull Requests
We actively welcome your pull requests.
1. Fork the repo and create your branch from `master`.
1. If you've added code that should be tested, add tests
1. If you've changed APIs, update the documentation.
1. Ensure the test suite passes.
1. Make sure your code lints.
1. If you haven't already, complete the Contributor License Agreement ("CLA").
## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You
only need
to do this once to work on any of Facebook's open source projects.
Complete your CLA here: <https://code.facebook.com/cla>
## Issues
We use GitHub issues to track public bugs. Please ensure your description
is clear and has sufficient instructions to be able to reproduce the issue.
Facebook has a [bounty program](https://www.facebook.com/whitehat/) for
the safe disclosure of security bugs. In those cases, please go through
the process outlined on that page and do not file a public issue.
## Coding Style
We use clang-format tool for coding style.
Please run `clang-format -i <filename>` on file, where changes has been made
before you commit them.
## License
By contributing to MVFST, you agree that your contributions will be
licensed under its BSD license.

21
LICENSE Normal file
View File

@ -0,0 +1,21 @@
MIT License
Copyright (c) Facebook, Inc. and its affiliates.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

33
README.md Normal file
View File

@ -0,0 +1,33 @@
## Building
### Ubuntu 16+
To begin, you should install the dependencies we need for build. This largely
consists of dependencies from [folly](https://github.com/facebook/folly) as well as
[fizz](https://github.com/facebookincubator/fizz).
```
sudo apt-get install \
g++ \
cmake \
libboost-all-dev \
libevent-dev \
libdouble-conversion-dev \
libgoogle-glog-dev \
libgflags-dev \
libiberty-dev \
liblz4-dev \
liblzma-dev \
libsnappy-dev \
make \
zlib1g-dev \
binutils-dev \
libjemalloc-dev \
libssl-dev \
pkg-config \
libsodium-dev
```
Then, build and install folly and fizz

93
cmake/CheckAtomic.cmake Normal file
View File

@ -0,0 +1,93 @@
# University of Illinois/NCSA
# Open Source License
#
# Copyright (c) 2003-2017 University of Illinois at Urbana-Champaign.
# All rights reserved.
#
# Developed by:
#
# LLVM Team
#
# University of Illinois at Urbana-Champaign
#
# http://llvm.org
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal with
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
# of the Software, and to permit persons to whom the Software is furnished to do
# so, subject to the following conditions:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimers.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimers in the
# documentation and/or other materials provided with the distribution.
#
# * Neither the names of the LLVM Team, University of Illinois at
# Urbana-Champaign, nor the names of its contributors may be used to
# endorse or promote products derived from this Software without specific
# prior written permission.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH THE
# SOFTWARE.
include(CheckCXXSourceCompiles)
# Sometimes linking against libatomic is required for atomic ops, if
# the platform doesn't support lock-free atomics.
function(check_working_cxx_atomics varname)
set(OLD_CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS})
get_directory_property(compile_options COMPILE_OPTIONS)
set(CMAKE_REQUIRED_FLAGS ${compile_options})
CHECK_CXX_SOURCE_COMPILES("
#include <atomic>
int main() {
struct Test { int val; };
std::atomic<Test> s;
s.is_lock_free();
}" ${varname})
set(CMAKE_REQUIRED_FLAGS ${OLD_CMAKE_REQUIRED_FLAGS})
endfunction(check_working_cxx_atomics)
if(NOT DEFINED PROXYGEN_COMPILER_IS_GCC_COMPATIBLE)
if(CMAKE_COMPILER_IS_GNUCXX)
set(PROXYGEN_COMPILER_IS_GCC_COMPATIBLE ON)
elseif(MSVC)
set(PROXYGEN_COMPILER_IS_GCC_COMPATIBLE OFF)
elseif("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")
set(PROXYGEN_COMPILER_IS_GCC_COMPATIBLE ON)
elseif("${CMAKE_CXX_COMPILER_ID}" MATCHES "Intel")
set(PROXYGEN_COMPILER_IS_GCC_COMPATIBLE ON)
endif()
endif()
# This isn't necessary on MSVC, so avoid command-line switch annoyance
# by only running on GCC-like hosts.
if(PROXYGEN_COMPILER_IS_GCC_COMPATIBLE)
# First check if atomics work without the library.
check_working_cxx_atomics(HAVE_CXX_ATOMICS_WITHOUT_LIB)
# If not, check if the library exists, and atomics work with it.
if(NOT HAVE_CXX_ATOMICS_WITHOUT_LIB)
check_library_exists(atomic __atomic_is_lock_free "" HAVE_LIBATOMIC)
if(HAVE_LIBATOMIC)
list(APPEND CMAKE_REQUIRED_LIBRARIES "atomic")
check_working_cxx_atomics(HAVE_CXX_ATOMICS_WITH_LIB)
if (NOT HAVE_CXX_ATOMICS_WITH_LIB)
message(FATAL_ERROR "Host compiler must support std::atomic!")
endif()
list(APPEND CMAKE_CXX_STANDARD_LIBRARIES -latomic)
else()
message(FATAL_ERROR "Host compiler appears to require libatomic, but cannot find it.")
endif()
endif()
endif()

View File

@ -0,0 +1,25 @@
# - Try to find double-conversion
# Once done, this will define
#
# DOUBLE_CONVERSION_FOUND - system has double-conversion
# DOUBLE_CONVERSION_INCLUDE_DIRS - the double-conversion include directories
# DOUBLE_CONVERSION_LIBRARIES - link these to use double-conversion
include(FindPackageHandleStandardArgs)
find_library(DOUBLE_CONVERSION_LIBRARY double-conversion
PATHS ${DOUBLE_CONVERSION_LIBRARYDIR})
find_path(DOUBLE_CONVERSION_INCLUDE_DIR double-conversion/double-conversion.h
PATHS ${DOUBLE_CONVERSION_INCLUDEDIR})
find_package_handle_standard_args(double_conversion DEFAULT_MSG
DOUBLE_CONVERSION_LIBRARY
DOUBLE_CONVERSION_INCLUDE_DIR)
mark_as_advanced(
DOUBLE_CONVERSION_LIBRARY
DOUBLE_CONVERSION_INCLUDE_DIR)
set(DOUBLE_CONVERSION_LIBRARIES ${DOUBLE_CONVERSION_LIBRARY})
set(DOUBLE_CONVERSION_INCLUDE_DIRS ${DOUBLE_CONVERSION_INCLUDE_DIR})

80
cmake/FindGMock.cmake Normal file
View File

@ -0,0 +1,80 @@
#
# Find libgmock
#
# LIBGMOCK_DEFINES - List of defines when using libgmock.
# LIBGMOCK_INCLUDE_DIR - where to find gmock/gmock.h, etc.
# LIBGMOCK_LIBRARIES - List of libraries when using libgmock.
# LIBGMOCK_FOUND - True if libgmock found.
IF (LIBGMOCK_INCLUDE_DIR)
# Already in cache, be silent
SET(LIBGMOCK_FIND_QUIETLY TRUE)
ENDIF ()
find_package(GTest CONFIG QUIET)
if (TARGET GTest::gmock)
get_target_property(LIBGMOCK_DEFINES GTest::gtest INTERFACE_COMPILE_DEFINITIONS)
if (NOT ${LIBGMOCK_DEFINES})
# Explicitly set to empty string if not found to avoid it being
# set to NOTFOUND and breaking compilation
set(LIBGMOCK_DEFINES "")
endif()
get_target_property(LIBGMOCK_INCLUDE_DIR GTest::gtest INTERFACE_INCLUDE_DIRECTORIES)
set(LIBGMOCK_LIBRARIES GTest::gmock_main GTest::gmock GTest::gtest)
set(LIBGMOCK_FOUND ON)
message(STATUS "Found gmock via config, defines=${LIBGMOCK_DEFINES}, include=${LIBGMOCK_INCLUDE_DIR}, libs=${LIBGMOCK_LIBRARIES}")
else()
FIND_PATH(LIBGMOCK_INCLUDE_DIR gmock/gmock.h)
FIND_LIBRARY(LIBGMOCK_MAIN_LIBRARY_DEBUG NAMES gmock_maind)
FIND_LIBRARY(LIBGMOCK_MAIN_LIBRARY_RELEASE NAMES gmock_main)
FIND_LIBRARY(LIBGMOCK_LIBRARY_DEBUG NAMES gmockd)
FIND_LIBRARY(LIBGMOCK_LIBRARY_RELEASE NAMES gmock)
FIND_LIBRARY(LIBGTEST_LIBRARY_DEBUG NAMES gtestd)
FIND_LIBRARY(LIBGTEST_LIBRARY_RELEASE NAMES gtest)
find_package(Threads REQUIRED)
INCLUDE(SelectLibraryConfigurations)
SELECT_LIBRARY_CONFIGURATIONS(LIBGMOCK_MAIN)
SELECT_LIBRARY_CONFIGURATIONS(LIBGMOCK)
SELECT_LIBRARY_CONFIGURATIONS(LIBGTEST)
set(LIBGMOCK_LIBRARIES
${LIBGMOCK_MAIN_LIBRARY}
${LIBGMOCK_LIBRARY}
${LIBGTEST_LIBRARY}
Threads::Threads
)
if(CMAKE_SYSTEM_NAME STREQUAL "Windows")
# The GTEST_LINKED_AS_SHARED_LIBRARY macro must be set properly on Windows.
#
# There isn't currently an easy way to determine if a library was compiled as
# a shared library on Windows, so just assume we've been built against a
# shared build of gmock for now.
SET(LIBGMOCK_DEFINES "GTEST_LINKED_AS_SHARED_LIBRARY=1" CACHE STRING "")
endif()
# handle the QUIETLY and REQUIRED arguments and set LIBGMOCK_FOUND to TRUE if
# all listed variables are TRUE
INCLUDE(FindPackageHandleStandardArgs)
FIND_PACKAGE_HANDLE_STANDARD_ARGS(
GMock
DEFAULT_MSG
LIBGMOCK_MAIN_LIBRARY
LIBGMOCK_LIBRARY
LIBGTEST_LIBRARY
LIBGMOCK_LIBRARIES
LIBGMOCK_INCLUDE_DIR
)
MARK_AS_ADVANCED(
LIBGMOCK_DEFINES
LIBGMOCK_MAIN_LIBRARY
LIBGMOCK_LIBRARY
LIBGTEST_LIBRARY
LIBGMOCK_LIBRARIES
LIBGMOCK_INCLUDE_DIR
)
endif()

32
cmake/FindGlog.cmake Normal file
View File

@ -0,0 +1,32 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# - Try to find Glog
# Once done, this will define
#
# GLOG_FOUND - system has Glog
# GLOG_INCLUDE_DIRS - the Glog include directories
# GLOG_LIBRARIES - link these to use Glog
include(FindPackageHandleStandardArgs)
find_library(GLOG_LIBRARY glog
PATHS ${GLOG_LIBRARYDIR})
find_path(GLOG_INCLUDE_DIR glog/logging.h
PATHS ${GLOG_INCLUDEDIR})
find_package_handle_standard_args(glog DEFAULT_MSG
GLOG_LIBRARY
GLOG_INCLUDE_DIR)
mark_as_advanced(
GLOG_LIBRARY
GLOG_INCLUDE_DIR)
set(GLOG_LIBRARIES ${GLOG_LIBRARY})
set(GLOG_INCLUDE_DIRS ${GLOG_INCLUDE_DIR})
if (NOT TARGET glog::glog)
add_library(glog::glog UNKNOWN IMPORTED)
set_target_properties(glog::glog PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${GLOG_INCLUDE_DIRS}")
set_target_properties(glog::glog PROPERTIES IMPORTED_LINK_INTERFACE_LANGUAGES "C" IMPORTED_LOCATION "${GLOG_LIBRARIES}")
endif()

25
cmake/FindLibevent.cmake Normal file
View File

@ -0,0 +1,25 @@
# - Try to find Libevent
# Once done, this will define
#
# LIBEVENT_FOUND - system has Libevent
# LIBEVENT_INCLUDE_DIRS - the Libevent include directories
# LIBEVENT_LIBRARIES - link these to use Libevent
include(FindPackageHandleStandardArgs)
find_library(LIBEVENT_LIBRARY event
PATHS ${LIBEVENT_LIBRARYDIR})
find_path(LIBEVENT_INCLUDE_DIR event.h
PATHS ${LIBEVENT_INCLUDEDIR})
find_package_handle_standard_args(libevent DEFAULT_MSG
LIBEVENT_LIBRARY
LIBEVENT_INCLUDE_DIR)
mark_as_advanced(
LIBEVENT_LIBRARY
LIBEVENT_INCLUDE_DIR)
set(LIBEVENT_LIBRARIES ${LIBEVENT_LIBRARY})
set(LIBEVENT_INCLUDE_DIRS ${LIBEVENT_INCLUDE_DIR})

16
cmake/FindLibrt.cmake Normal file
View File

@ -0,0 +1,16 @@
# - Try to find librt
# Once done, this will define
#
# LIBRT_FOUND - system has librt
# LIBRT_LIBRARIES - link these to use librt
include(FindPackageHandleStandardArgs)
find_library(LIBRT_LIBRARY rt
PATHS ${LIBRT_LIBRARYDIR})
find_package_handle_standard_args(librt DEFAULT_MSG LIBRT_LIBRARY)
mark_as_advanced(LIBRT_LIBRARY)
set(LIBRT_LIBRARIES ${LIBRT_LIBRARY})

61
cmake/QuicTest.cmake Normal file
View File

@ -0,0 +1,61 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
include(CTest)
if(BUILD_TESTS)
include(GoogleTest)
find_package(GMock MODULE REQUIRED)
endif()
function(quic_add_test)
if(NOT BUILD_TESTS)
return()
endif()
set(options)
set(one_value_args TARGET WORKING_DIRECTORY PREFIX)
set(multi_value_args SOURCES DEPENDS INCLUDES EXTRA_ARGS)
cmake_parse_arguments(PARSE_ARGV 0 QUIC_TEST "${options}" "${one_value_args}" "${multi_value_args}")
if(NOT QUIC_TEST_TARGET)
message(FATAL_ERROR "The TARGET parameter is mandatory.")
endif()
if(NOT QUIC_TEST_SOURCES)
set(QUIC_TEST_SOURCES "${QUIC_TEST_TARGET}.cpp")
endif()
add_executable(${QUIC_TEST_TARGET}
"${QUIC_TEST_SOURCES}"
# implementation of 'main()' that calls folly::init
"${QUIC_FBCODE_ROOT}/quic/common/test/TestMain.cpp"
)
target_compile_options(
${QUIC_TEST_TARGET} PRIVATE
${_QUIC_COMMON_COMPILE_OPTIONS}
)
target_link_libraries(${QUIC_TEST_TARGET} PRIVATE
"${QUIC_TEST_DEPENDS}"
)
target_include_directories(${QUIC_TEST_TARGET} PRIVATE
"${QUIC_TEST_INCLUDES}"
)
gtest_add_tests(TARGET ${QUIC_TEST_TARGET}
EXTRA_ARGS "${QUIC_TEST_EXTRA_ARGS}"
WORKING_DIRECTORY ${QUIC_TEST_WORKING_DIRECTORY}
TEST_PREFIX ${QUIC_TEST_PREFIX}
TEST_LIST QUIC_TEST_CASES)
target_link_libraries(${QUIC_TEST_TARGET} PRIVATE
${LIBGMOCK_LIBRARIES}
)
target_include_directories(${QUIC_TEST_TARGET} PRIVATE
${LIBGMOCK_INCLUDE_DIR}
${QUIC_EXTRA_INCLUDE_DIRECTORIES}
)
target_compile_definitions(${QUIC_TEST_TARGET} PRIVATE ${LIBGMOCK_DEFINES})
set_tests_properties(${QUIC_TEST_CASES} PROPERTIES TIMEOUT 120)
endfunction()

View File

@ -0,0 +1,32 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# This module sets the following variables:
# mvfst_FOUND
# mvfst_INCLUDE_DIRS
#
# This module exports the following target:
# mvfst::mvfst
#
# which can be used with target_link_libraries() to pull in the proxygen
# library.
@PACKAGE_INIT@
include(CMakeFindDependencyMacro)
find_dependency(folly)
find_dependency(Fizz)
find_dependency(Threads)
find_dependency(Boost COMPONENTS iostreams system thread filesystem regex context)
if(NOT TARGET mvfst::mvfst)
include("${CMAKE_CURRENT_LIST_DIR}/mvfst-targets.cmake")
get_target_property(mvfst_INCLUDE_DIRS mvfst::mvfst INTERFACE_INCLUDE_DIRECTORIES)
endif()
if(NOT mvfst_FIND_QUIETLY)
message(STATUS "Found mvfst: ${PACKAGE_PREFIX_DIR}")
endif()

3
quic/.clang-tidy Normal file
View File

@ -0,0 +1,3 @@
---
Checks: 'boost-*,bugprone-*,clang-analyzer-*,modernize-*,performance-*'
...

0
quic/AUTODEPS Normal file
View File

31
quic/BUCK Normal file
View File

@ -0,0 +1,31 @@
fb_xplat_cxx_library(
name = "quic",
srcs = glob([
"api/*.cpp",
"client/*.cpp",
"codec/*.cpp",
"common/*.cpp",
"congestion_control/*.cpp",
"flowcontrol/*.cpp",
"happyeyeballs/*.cpp",
"logging/*.cpp",
"loss/*.cpp",
"state/*.cpp",
]),
header_namespace = "",
exported_headers = subdir_glob(
[
("", "api/QuicSocket.h"),
("", "client/QuicClientTransport.h"),
("", "client/handshake/QuicPskCache.h"),
],
prefix = "quic",
),
apple_sdks = (IOS, MACOSX, APPLETVOS),
visibility = ["PUBLIC"],
deps = [
"fbsource//xplat/folly:extended",
"fbsource//xplat/folly:molly",
"fbsource//xplat/third-party/boost:boost",
],
)

46
quic/BUILD_MODE.bzl Normal file
View File

@ -0,0 +1,46 @@
# Copyright 2017 Facebook
""" build mode definitions for quic """
load("@fbcode_macros//build_defs:create_build_mode.bzl", "create_build_mode")
_extra_cflags = [
]
_common_flags = [
"-Wformat",
"-Wformat-security",
"-Wunused-variable",
"-Wsign-compare",
]
_extra_clang_flags = _common_flags + [
# Default value for clang (3.4) is 256, change it to GCC's default value
# (https://fburl.com/23278774).
"-ftemplate-depth=900",
"-Wmismatched-tags",
# Only check shadowing with Clang: gcc complains about constructor
# argument shadowing
"-Wshadow",
]
_extra_gcc_flags = _common_flags + [
"-Wall",
]
_mode = create_build_mode(
c_flags = _extra_cflags,
clang_flags = _extra_clang_flags,
gcc_flags = _extra_gcc_flags,
)
_modes = {
"dbg": _mode,
"dbgo": _mode,
"dev": _mode,
"opt": _mode,
}
def get_modes():
""" Return modes for this file """
return _modes

80
quic/CMakeLists.txt Normal file
View File

@ -0,0 +1,80 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
add_library(
mvfst_constants
QuicConstants.cpp
)
target_include_directories(
mvfst_constants PUBLIC
$<BUILD_INTERFACE:${QUIC_FBCODE_ROOT}>
$<INSTALL_INTERFACE:include/>
PRIVATE
${Boost_INCLUDE_DIR}
)
target_compile_options(
mvfst_constants
PRIVATE
${_QUIC_COMMON_COMPILE_OPTIONS}
)
target_link_libraries(
mvfst_constants PUBLIC
Folly::folly
${Boost_LIBRARIES}
)
install(FILES QuicConstants.h DESTINATION include/quic/)
install(
TARGETS mvfst_constants
EXPORT mvfst-exports
DESTINATION lib
)
add_library(
mvfst_exception
QuicException.cpp
)
target_include_directories(
mvfst_exception PUBLIC
$<BUILD_INTERFACE:${QUIC_FBCODE_ROOT}>
$<INSTALL_INTERFACE:include/>
)
target_compile_options(
mvfst_exception
PRIVATE
${_QUIC_COMMON_COMPILE_OPTIONS}
)
target_link_libraries(
mvfst_exception PUBLIC
Folly::folly
)
install(FILES QuicException.h DESTINATION include/quic/)
install(
TARGETS mvfst_exception
EXPORT mvfst-exports
DESTINATION lib
)
add_subdirectory(api)
add_subdirectory(client)
add_subdirectory(codec)
add_subdirectory(common)
add_subdirectory(congestion_control)
add_subdirectory(flowcontrol)
add_subdirectory(handshake)
add_subdirectory(happyeyeballs)
add_subdirectory(logging)
add_subdirectory(loss)
add_subdirectory(samples)
add_subdirectory(server)
add_subdirectory(state)

40
quic/QuicConstants.cpp Normal file
View File

@ -0,0 +1,40 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#include <quic/QuicConstants.h>
namespace quic {
QuicBatchingMode getQuicBatchingMode(uint32_t val) {
switch (val) {
case static_cast<uint32_t>(QuicBatchingMode::BATCHING_MODE_NONE):
return QuicBatchingMode::BATCHING_MODE_NONE;
case static_cast<uint32_t>(QuicBatchingMode::BATCHING_MODE_GSO):
return QuicBatchingMode::BATCHING_MODE_GSO;
case static_cast<uint32_t>(QuicBatchingMode::BATCHING_MODE_SENDMMSG):
return QuicBatchingMode::BATCHING_MODE_SENDMMSG;
// no default
}
return QuicBatchingMode::BATCHING_MODE_NONE;
}
std::vector<QuicVersion> filterSupportedVersions(
const std::vector<QuicVersion>& versions) {
std::vector<QuicVersion> filteredVersions;
std::copy_if(
versions.begin(),
versions.end(),
std::back_inserter(filteredVersions),
[](auto version) {
return version == QuicVersion::MVFST ||
version == QuicVersion::QUIC_DRAFT ||
version == QuicVersion::MVFST_INVALID;
});
return filteredVersions;
}
} // namespace quic

426
quic/QuicConstants.h Normal file
View File

@ -0,0 +1,426 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#pragma once
#include <boost/variant.hpp>
#include <folly/Range.h>
#include <folly/String.h>
#include <chrono>
#include <cstdint>
namespace quic {
using Clock = std::chrono::steady_clock;
using TimePoint = std::chrono::time_point<Clock>;
// Default QUIC packet size for both read and write.
constexpr uint64_t kDefaultV4UDPSendPacketLen = 1252;
constexpr uint64_t kDefaultV6UDPSendPacketLen = 1232;
// With Android NDK r15c for some apps we use gnu-libstdc++ instead of
// llvm-libc++. And gnu-libstdc++ doesn't like to make std::min constexpr.
// That's why we can't have nice things in the world.
constexpr uint16_t kDefaultUDPSendPacketLen =
(kDefaultV4UDPSendPacketLen < kDefaultV6UDPSendPacketLen
? kDefaultV4UDPSendPacketLen
: kDefaultV6UDPSendPacketLen);
// This is the default if the transport parameter for max packet size is missing
// or zero.
constexpr uint16_t kDefaultMaxUDPPayload = 65527;
// This is the minimum the max_packet_size transport parameter is allowed to be,
// per the spec. Note this actually refers to the max UDP payload size, not the
// maximum QUIC packet size.
constexpr uint16_t kMinMaxUDPPayload = 1200;
// How many bytes to reduce from udpSendPacketLen when socket write leads to
// EMSGSIZE.
constexpr uint16_t kDefaultMsgSizeBackOffSize = 50;
// Size of read buffer we provide to AsyncUDPSocket. The packet size cannot be
// larger than this, unless configured otherwise.
constexpr uint16_t kDefaultUDPReadBufferSize = 4096;
constexpr uint16_t kMaxNumCoalescedPackets = 5;
// As per version 17 of the spec, transport parameters for private use must
// have ids greater than 0x3fff.
constexpr uint16_t kCustomTransportParameterThreshold = 0x3fff;
// If the amount of data in the buffer of a QuicSocket equals or exceeds this
// threshold, then the callback registered through
// notifyPendingWriteOnConnection() will not be called
constexpr uint64_t kDefaultBufferSpaceAvailable =
std::numeric_limits<uint64_t>::max();
// Frames types with values defines in Quic Draft 15+
enum class FrameType : uint8_t {
PADDING = 0x00,
PING = 0x01,
ACK = 0x02,
ACK_ECN = 0x03,
RST_STREAM = 0x04,
STOP_SENDING = 0x05,
CRYPTO_FRAME = 0x06, // librtmp has a #define CRYPTO
NEW_TOKEN = 0x07,
// STREAM frame can have values from 0x08 to 0x0f
STREAM = 0x08,
MAX_DATA = 0x10,
MAX_STREAM_DATA = 0x11,
MAX_STREAMS_BIDI = 0x12,
MAX_STREAMS_UNI = 0x13,
DATA_BLOCKED = 0x14,
STREAM_DATA_BLOCKED = 0x15,
STREAMS_BLOCKED_BIDI = 0x16,
STREAMS_BLOCKED_UNI = 0x17,
NEW_CONNECTION_ID = 0x18,
RETIRE_CONNECTION_ID = 0x19,
PATH_CHALLENGE = 0x1A,
PATH_RESPONSE = 0x1B,
CONNECTION_CLOSE = 0x1C,
APPLICATION_CLOSE = 0x1D,
MIN_STREAM_DATA = 0xFE, // subject to change (https://fburl.com/qpr)
EXPIRED_STREAM_DATA = 0xFF, // subject to change (https://fburl.com/qpr)
};
inline constexpr uint16_t toFrameError(FrameType frame) {
return 0x0100 | static_cast<uint8_t>(frame);
}
enum class TransportErrorCode : uint16_t {
NO_ERROR = 0x0000,
INTERNAL_ERROR = 0x0001,
FLOW_CONTROL_ERROR = 0x0003,
STREAM_LIMIT_ERROR = 0x0004,
STREAM_STATE_ERROR = 0x0005,
FINAL_OFFSET_ERROR = 0x0006,
FRAME_ENCODING_ERROR = 0x0007,
TRANSPORT_PARAMETER_ERROR = 0x0008,
VERSION_NEGOTIATION_ERROR = 0x0009,
PROTOCOL_VIOLATION = 0x000A,
INVALID_MIGRATION = 0x000C,
TLS_HANDSHAKE_FAILED = 0x201,
TLS_FATAL_ALERT_GENERATED = 0x202,
TLS_FATAL_ALERT_RECEIVED = 0x203,
};
enum class ApplicationErrorCode : uint16_t {
STOPPING = 0x00,
// HTTP2/QUIC error codes
HTTP_NO_ERROR = 0x01,
HTTP_PUSH_REFUSED = 0x02,
HTTP_INTERNAL_ERROR = 0x03,
HTTP_PUSH_ALREADY_IN_CACHE = 0x04,
HTTP_REQUEST_CANCELLED = 0x05,
HTTP_INCOMPLETE_REQUEST = 0x06,
HTTP_CONNECT_ERROR = 0x07,
HTTP_EXCESSIVE_LOAD = 0x08,
HTTP_VERSION_FALLBACK = 0x09,
HTTP_WRONG_STREAM = 0x0A,
HTTP_PUSH_LIMIT_EXCEEDED = 0x0B,
HTTP_DUPLICATE_PUSH = 0x0C,
HTTP_UNKNOWN_STREAM_TYPE = 0x0D,
HTTP_WRONG_STREAM_COUNT = 0x0E,
HTTP_CLOSED_CRITICAL_STREAM = 0x0F,
HTTP_WRONG_STREAM_DIRECTION = 0x10,
HTTP_EARLY_RESPONSE = 0x11,
HTTP_MISSING_SETTINGS = 0x12,
HTTP_UNEXPECTED_FRAME = 0x13,
HTTP_REQUEST_REJECTED = 0x14,
HTTP_QPACK_DECOMPRESSION_FAILED = 0xE0,
HTTP_QPACK_DECODER_STREAM_ERROR = 0xE1,
HTTP_QPACK_ENCODER_STREAM_ERROR = 0xE2,
HTTP_GENERAL_PROTOCOL_ERROR = 0xFF,
HTTP_MALFORMED_FRAME_DATA = 0x0100,
HTTP_MALFORMED_FRAME_HEADERS = 0x0101,
HTTP_MALFORMED_FRAME_PRIORITY = 0x0102,
HTTP_MALFORMED_FRAME_CANCEL_PUSH = 0x0103,
HTTP_MALFORMED_FRAME_SETTINGS = 0x0104,
HTTP_MALFORMED_FRAME_PUSH_PROMISE = 0x0105,
HTTP_MALFORMED_FRAME_GOAWAY = 0x0107,
HTTP_MALFORMED_FRAME_MAX_PUSH_ID = 0x010D,
HTTP_MALFORMED_FRAME = 0x01FF,
// Internal use only
INTERNAL_ERROR = 0xF1,
GIVEUP_ZERO_RTT = 0xF2
};
enum class LocalErrorCode : uint32_t {
// Local errors
NO_ERROR = 0x00000000,
CONNECT_FAILED = 0x40000000,
CODEC_ERROR = 0x40000001,
STREAM_CLOSED = 0x40000002,
STREAM_NOT_EXISTS = 0x40000003,
CREATING_EXISTING_STREAM = 0x40000004,
SHUTTING_DOWN = 0x40000005,
RESET_CRYPTO_STREAM = 0x40000006,
CWND_OVERFLOW = 0x40000007,
INFLIGHT_BYTES_OVERFLOW = 0x40000008,
LOST_BYTES_OVERFLOW = 0x40000009,
// This is a retryable error. When encountering this error,
// the user should retry the request.
NEW_VERSION_NEGOTIATED = 0x4000000A,
INVALID_WRITE_CALLBACK = 0x4000000B,
TLS_HANDSHAKE_FAILED = 0x4000000C,
APP_ERROR = 0x4000000D,
INTERNAL_ERROR = 0x4000000E,
TRANSPORT_ERROR = 0x4000000F,
INVALID_WRITE_DATA = 0x40000010,
INVALID_STATE_TRANSITION = 0x40000011,
CONNECTION_CLOSED = 0x40000012,
EARLY_DATA_REJECTED = 0x40000013,
CONNECTION_RESET = 0x40000014,
IDLE_TIMEOUT = 0x40000015,
PACKET_NUMBER_ENCODING = 0x40000016,
INVALID_OPERATION = 0x40000017,
STREAM_LIMIT_EXCEEDED = 0x40000018,
};
using QuicErrorCode =
boost::variant<ApplicationErrorCode, LocalErrorCode, TransportErrorCode>;
enum class QuicNodeType : bool {
Client,
Server,
};
enum class QuicVersion : uint32_t {
VERSION_NEGOTIATION = 0x00000000,
MVFST = 0xfaceb000,
QUIC_DRAFT = 0xFF000011, // Draft-17
MVFST_INVALID = 0xfaceb00f,
};
using QuicVersionType = std::underlying_type<QuicVersion>::type;
using TransportPartialReliabilitySetting = bool;
constexpr uint16_t kPartialReliabilityParameterId = 0xFF00; // subject to change
constexpr uint32_t kDrainFactor = 3;
// batching mode
enum class QuicBatchingMode : uint32_t {
BATCHING_MODE_NONE = 0,
BATCHING_MODE_GSO = 1,
BATCHING_MODE_SENDMMSG = 2,
};
QuicBatchingMode getQuicBatchingMode(uint32_t val);
// default QUIC batching mode - currently used only
// by BATCHING_MODE_GSO
constexpr uint32_t kDefaultQuicBatchingNum = 64;
// rfc6298:
constexpr int kRttAlpha = 8;
constexpr int kRttBeta = 4;
// Draft-17 recommends 100ms as initial RTT. We delibrately ignore that
// recommendation. This is not a bug.
constexpr std::chrono::microseconds kDefaultInitialRtt =
std::chrono::microseconds(50 * 1000);
constexpr std::chrono::microseconds kMinTLPTimeout =
std::chrono::microseconds(10 * 1000);
// HHWheelTimer tick interval
constexpr std::chrono::microseconds kGranularity =
std::chrono::microseconds(10 * 1000);
constexpr uint32_t kReorderingThreshold = 3;
constexpr double kTimeReorderingFraction = 0.125;
constexpr auto kPacketToSendForRTO = 2;
// Maximum number of packets to write per writeConnectionDataToSocket call.
constexpr uint64_t kDefaultWriteConnectionDataPacketLimit = 5;
// Maximum number of packets to write per burst in pacing
constexpr uint64_t kDefaultMaxBurstPackets = 10;
// Default timer tick interval for pacing timer
// the microsecond timers are accurate to about 5 usec
// but the notifications can get delayed if the event loop is busy
// this is subject to testing but I would suggest a value >= 200usec
constexpr std::chrono::microseconds kDefaultPacingTimerTickInterval{1000};
// Congestion control:
enum class CongestionControlType : uint8_t { Cubic, NewReno, Copa, BBR, None };
constexpr uint64_t kInitCwndInMss = 10;
constexpr uint64_t kMinCwndInMss = 2;
constexpr uint64_t kDefaultMaxCwndInMss = 2000;
// When server receives early data attempt without valid source address token,
// server will limit bytes in flight to avoid amplification attack until CFIN
// is received which proves sender owns the address.
constexpr uint64_t kLimitedCwndInMss = 3;
/* Hybrid slow start: */
// The first kAckSampling Acks within a RTT round will be used to sample delays
constexpr uint8_t kAckSampling = 8;
// Hystart won't exit slow start if Cwnd < kLowSsthresh
constexpr uint64_t kLowSsthreshInMss = 16;
// ACKs within kAckCountingGap are considered closely spaced, i.e., AckTrain
constexpr std::chrono::microseconds kAckCountingGap(2);
// Hystart's upper bound for DelayIncrease
constexpr std::chrono::microseconds kDelayIncreaseUpperBound(8);
// Hystart's lower bound for DelayIncrease
constexpr std::chrono::microseconds kDelayIncreaseLowerBound(2);
/* Cubic */
// Default cwnd reduction factor:
constexpr double kDefaultCubicReductionFactor = 0.8;
// Time elapsed scaling factor
constexpr double kTimeScalingFactor = 0.4;
// Default emulated connection numbers for each real connection
constexpr uint8_t kDefaultEmulatedConnection = 2;
// Default W_max reduction factor when loss happens before Cwnd gets back to
// previous W_max:
constexpr float kDefaultLastMaxReductionFactor = 0.85f;
/* Flow Control */
// Default flow control window for HTTP/2 + 1K for headers
constexpr uint64_t kDefaultStreamWindowSize = (64 + 1) * 1024;
constexpr uint64_t kDefaultConnectionWindowSize = 1024 * 1024;
/* Stream Limits */
constexpr uint64_t kDefaultMaxStreamsBidirectional = 2048;
constexpr uint64_t kDefaultMaxStreamsUnidirectional = 2048;
constexpr uint64_t kMaxStreamId = 1ull << 62;
constexpr uint64_t kMaxMaxStreams = 1ull << 60;
/* Idle timeout parameters */
// Default idle timeout to advertise.
constexpr std::chrono::seconds kDefaultIdleTimeout = std::chrono::seconds(60);
constexpr std::chrono::seconds kMaxIdleTimeout = std::chrono::seconds(600);
// Time format related:
constexpr uint8_t kQuicTimeExpoBits = 5;
constexpr uint8_t kQuicTimeMantissaBits = 16 - kQuicTimeExpoBits;
// This is the largest possible value with a exponent = 0:
constexpr uint16_t kLargestQuicTimeWithoutExpo = 0xFFF;
// Largest possible value with a positive exponent:
constexpr uint64_t kLargestQuicTime = 0x0FFFull << (0x1F - 1);
// Limit of non-retransmittable packets received before an Ack has to be
// emitted.
constexpr uint8_t kNonRxPacketsPendingBeforeAckThresh = 20;
// Limit of retransmittable packets received before an Ack has to be emitted.
constexpr uint8_t kRxPacketsPendingBeforeAckThresh = 10;
/* Ack timer */
// TODO: These numbers are shamlessly taken from Chromium code. We have no idea
// how good/bad this is.
// Ack timeout = SRTT * kAckTimerFactor
constexpr double kAckTimerFactor = 0.25;
// max ack timeout: 25ms
constexpr std::chrono::microseconds kMaxAckTimeout =
std::chrono::microseconds(25 * 1000);
// min ack timeout: 10ms
constexpr std::chrono::microseconds kMinAckTimeout =
std::chrono::microseconds(10 * 1000);
constexpr uint64_t kAckPurgingThresh = 10;
// Default number of packets to buffer if keys are not present.
constexpr uint32_t kDefaultMaxBufferedPackets = 20;
// Default exponent to use while computing ack delay.
constexpr uint64_t kDefaultAckDelayExponent = 3;
constexpr uint64_t kMaxAckDelayExponent = 20;
// Default connection id size of the connection id we will send.
constexpr size_t kDefaultConnectionIdSize = 8;
// Minimum size of the health check token. This is used to reduce the impact of
// amplification attacks.
constexpr size_t kMinHealthCheckTokenSize = 5;
// Maximum size of the reason phrase.
constexpr size_t kMaxReasonPhraseLength = 1024;
// Minimum size of an initial packet
constexpr size_t kMinInitialPacketSize = 1200;
// Default maximum RTOs that will happen before tearing down the connection
constexpr uint16_t kDefaultMaxNumRTO = 7;
// Maximum early data size that we need to negotiate in TLS
constexpr uint32_t kRequiredMaxEarlyDataSize = 0xffffffff;
// The minimum size of a stateless reset packet. This is the short header size,
// and 16 bytes of the token and 16 bytes of randomness
constexpr uint16_t kMinStatelessPacketSize = 13 + 16 + 16;
// TODO: remove this when we have a stateless reset generator.
constexpr folly::StringPiece kTestStatelessResetToken = "aaaabbbbccccdddd";
constexpr std::chrono::milliseconds kHappyEyeballsV4Delay =
std::chrono::milliseconds(150);
constexpr std::chrono::milliseconds kHappyEyeballsConnAttemptDelayWithCache =
std::chrono::seconds(15);
constexpr size_t kMaxNumTokenSourceAddresses = 3;
// Amount of time to retain initial keys until they are dropped after handshake
// completion.
constexpr std::chrono::seconds kTimeToRetainInitialKeys =
std::chrono::seconds(20);
// Amount of time to retain zero rtt keys until they are dropped after handshake
// completion.
constexpr std::chrono::seconds kTimeToRetainZeroRttKeys =
std::chrono::seconds(20);
constexpr std::chrono::seconds kTimeToRetainLastCongestionAndRttState =
std::chrono::seconds(60);
constexpr uint32_t kMaxNumMigrationsAllowed = 6;
constexpr auto kExpectedNumOfParamsInTheTicket = 7;
// default capability of QUIC partial reliability
constexpr TransportPartialReliabilitySetting kDefaultPartialReliability = false;
enum class ZeroRttSourceTokenMatchingPolicy : uint8_t {
REJECT_IF_NO_EXACT_MATCH,
LIMIT_IF_NO_EXACT_MATCH,
// T33014230 Subnet matching
// REJECT_IF_NO_SUBNECT_MATCH,
// LIMIT_IF_NO_EXACT_MATCH
};
inline folly::StringPiece nodeToString(QuicNodeType node) {
if (node == QuicNodeType::Client) {
return "Client";
} else {
return "Server";
}
}
template <class T>
inline std::ostream& operator<<(std::ostream& os, const std::vector<T>& v) {
for (auto it = v.cbegin(); it != v.cend(); ++it) {
os << *it;
if (std::next(it) != v.cend()) {
os << ",";
}
}
return os;
}
inline std::ostream& operator<<(std::ostream& os, const QuicVersion& v) {
os << static_cast<std::underlying_type<QuicVersion>::type>(v);
return os;
}
/**
* Filter the versions that are currently supported.
*/
std::vector<QuicVersion> filterSupportedVersions(
const std::vector<QuicVersion>&);
} // namespace quic

258
quic/QuicException.cpp Normal file
View File

@ -0,0 +1,258 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
// Copyright 2004-present Facebook. All rights reserved.
#include <quic/QuicException.h>
#include <folly/Overload.h>
#include <glog/logging.h>
namespace quic {
QuicTransportException::QuicTransportException(
const std::string& msg,
TransportErrorCode errCode)
: std::runtime_error(msg), errCode_(errCode){};
QuicTransportException::QuicTransportException(
const char* msg,
TransportErrorCode errCode)
: std::runtime_error(msg), errCode_(errCode){};
QuicTransportException::QuicTransportException(
const std::string& msg,
TransportErrorCode errCode,
FrameType frameType)
: std::runtime_error(msg), errCode_(errCode), frameType_(frameType){};
QuicTransportException::QuicTransportException(
const char* msg,
TransportErrorCode errCode,
FrameType frameType)
: std::runtime_error(msg), errCode_(errCode), frameType_(frameType){};
QuicInternalException::QuicInternalException(
const std::string& msg,
LocalErrorCode errCode)
: std::runtime_error(msg), errorCode_(errCode){};
QuicInternalException::QuicInternalException(
const char* msg,
LocalErrorCode errCode)
: std::runtime_error(msg), errorCode_(errCode){};
QuicApplicationException::QuicApplicationException(
const std::string& msg,
ApplicationErrorCode errorCode)
: std::runtime_error(msg), errorCode_(errorCode){};
QuicApplicationException::QuicApplicationException(
const char* msg,
ApplicationErrorCode errorCode)
: std::runtime_error(msg), errorCode_(errorCode){};
std::string toString(LocalErrorCode code) {
switch (code) {
case LocalErrorCode::NO_ERROR:
return "No Error";
case LocalErrorCode::CONNECT_FAILED:
return "Connect failed";
case LocalErrorCode::CODEC_ERROR:
return "Codec Error";
case LocalErrorCode::STREAM_CLOSED:
return "Stream is closed";
case LocalErrorCode::STREAM_NOT_EXISTS:
return "Stream does not exist";
case LocalErrorCode::CREATING_EXISTING_STREAM:
return "Creating an existing stream";
case LocalErrorCode::SHUTTING_DOWN:
return "Shutting down";
case LocalErrorCode::RESET_CRYPTO_STREAM:
return "Reset the crypto stream";
case LocalErrorCode::CWND_OVERFLOW:
return "CWND overflow";
case LocalErrorCode::INFLIGHT_BYTES_OVERFLOW:
return "Inflight bytes overflow";
case LocalErrorCode::LOST_BYTES_OVERFLOW:
return "Lost bytes overflow";
case LocalErrorCode::NEW_VERSION_NEGOTIATED:
return "New version negotiatied";
case LocalErrorCode::INVALID_WRITE_CALLBACK:
return "Invalid write callback";
case LocalErrorCode::TLS_HANDSHAKE_FAILED:
return "TLS handshake failed";
case LocalErrorCode::APP_ERROR:
return "App error";
case LocalErrorCode::INTERNAL_ERROR:
return "Internal error";
case LocalErrorCode::TRANSPORT_ERROR:
return "Transport error";
case LocalErrorCode::INVALID_WRITE_DATA:
return "Invalid write data";
case LocalErrorCode::INVALID_STATE_TRANSITION:
return "Invalid state transition";
case LocalErrorCode::CONNECTION_CLOSED:
return "Connection closed";
case LocalErrorCode::EARLY_DATA_REJECTED:
return "Early data rejected";
case LocalErrorCode::CONNECTION_RESET:
return "Connection reset";
case LocalErrorCode::IDLE_TIMEOUT:
return "Idle timeout";
case LocalErrorCode::PACKET_NUMBER_ENCODING:
return "Packet number encoding";
case LocalErrorCode::INVALID_OPERATION:
return "Invalid operation";
case LocalErrorCode::STREAM_LIMIT_EXCEEDED:
return "Stream limit exceeded";
}
LOG(WARNING) << "toString has unhandled ErrorCode";
return "Unknown error";
}
std::string toString(TransportErrorCode code) {
switch (code) {
case TransportErrorCode::NO_ERROR:
return "No Error";
case TransportErrorCode::INTERNAL_ERROR:
return "Internal Error";
case TransportErrorCode::FLOW_CONTROL_ERROR:
return "Flow control error";
case TransportErrorCode::STREAM_LIMIT_ERROR:
return "Stream limit error";
case TransportErrorCode::STREAM_STATE_ERROR:
return "Stream State error";
case TransportErrorCode::FINAL_OFFSET_ERROR:
return "Final offset error";
case TransportErrorCode::FRAME_ENCODING_ERROR:
return "Frame format error";
case TransportErrorCode::TRANSPORT_PARAMETER_ERROR:
return "Transport parameter error";
case TransportErrorCode::VERSION_NEGOTIATION_ERROR:
return "Version negotiation error";
case TransportErrorCode::PROTOCOL_VIOLATION:
return "Protocol violation";
case TransportErrorCode::INVALID_MIGRATION:
return "Invalid migration";
case TransportErrorCode::TLS_HANDSHAKE_FAILED:
return "Handshake Failed";
case TransportErrorCode::TLS_FATAL_ALERT_GENERATED:
return "TLS Alert Sent";
case TransportErrorCode::TLS_FATAL_ALERT_RECEIVED:
return "TLS Alert Received";
}
LOG(WARNING) << "toString has unhandled ErrorCode";
return "Unknown error";
}
std::string toString(ApplicationErrorCode code) {
switch (code) {
case ApplicationErrorCode::STOPPING:
return "Stopping";
case ApplicationErrorCode::HTTP_NO_ERROR:
return "HTTP: No error";
case ApplicationErrorCode::HTTP_PUSH_REFUSED:
return "HTTP: Client refused pushed content";
case ApplicationErrorCode::HTTP_INTERNAL_ERROR:
return "HTTP: Internal error";
case ApplicationErrorCode::HTTP_PUSH_ALREADY_IN_CACHE:
return "HTTP: Pushed content already cached";
case ApplicationErrorCode::HTTP_REQUEST_CANCELLED:
return "HTTP: Data no longer needed";
case ApplicationErrorCode::HTTP_INCOMPLETE_REQUEST:
return "HTTP: Stream terminated early";
case ApplicationErrorCode::HTTP_CONNECT_ERROR:
return "HTTP: Reset or error on CONNECT request";
case ApplicationErrorCode::HTTP_EXCESSIVE_LOAD:
return "HTTP: Peer generating excessive load";
case ApplicationErrorCode::HTTP_VERSION_FALLBACK:
return "HTTP: Retry over HTTP/1.1";
case ApplicationErrorCode::HTTP_WRONG_STREAM:
return "HTTP: A frame was sent on the wrong stream";
case ApplicationErrorCode::HTTP_PUSH_LIMIT_EXCEEDED:
return "HTTP: Maximum Push ID exceeded";
case ApplicationErrorCode::HTTP_DUPLICATE_PUSH:
return "HTTP: Push ID was fulfilled multiple times";
case ApplicationErrorCode::HTTP_UNKNOWN_STREAM_TYPE:
return "HTTP: Unknown unidirectional stream type";
case ApplicationErrorCode::HTTP_WRONG_STREAM_COUNT:
return "HTTP: Too many unidirectional streams";
case ApplicationErrorCode::HTTP_CLOSED_CRITICAL_STREAM:
return "HTTP: Critical stream was closed";
case ApplicationErrorCode::HTTP_WRONG_STREAM_DIRECTION:
return "HTTP: Unidirectional stream in wrong direction";
case ApplicationErrorCode::HTTP_EARLY_RESPONSE:
return "HTTP: Remainder of request not needed";
case ApplicationErrorCode::HTTP_MISSING_SETTINGS:
return "HTTP: No SETTINGS frame received";
case ApplicationErrorCode::HTTP_UNEXPECTED_FRAME:
return "HTTP: Unexpected frame from client";
case ApplicationErrorCode::HTTP_REQUEST_REJECTED:
return "HTTP: Server did not process request";
case ApplicationErrorCode::HTTP_QPACK_DECOMPRESSION_FAILED:
return "HTTP: QPACK decompression failed";
case ApplicationErrorCode::HTTP_QPACK_DECODER_STREAM_ERROR:
return "HTTP: Error on QPACK decoder stream";
case ApplicationErrorCode::HTTP_QPACK_ENCODER_STREAM_ERROR:
return "HTTP: Error on QPACK encoder stream";
case ApplicationErrorCode::HTTP_GENERAL_PROTOCOL_ERROR:
return "HTTP: General protocol error";
case ApplicationErrorCode::HTTP_MALFORMED_FRAME_DATA:
return "HTTP: Malformed DATA frame";
case ApplicationErrorCode::HTTP_MALFORMED_FRAME_HEADERS:
return "HTTP: Malformed HEADERS frame";
case ApplicationErrorCode::HTTP_MALFORMED_FRAME_PRIORITY:
return "HTTP: Malformed PRIORITY frame";
case ApplicationErrorCode::HTTP_MALFORMED_FRAME_CANCEL_PUSH:
return "HTTP: Malformed CANCEL_PUSH frame";
case ApplicationErrorCode::HTTP_MALFORMED_FRAME_SETTINGS:
return "HTTP: Malformed SETTINGS frame";
case ApplicationErrorCode::HTTP_MALFORMED_FRAME_PUSH_PROMISE:
return "HTTP: Malformed PUSH_PROMISE frame";
case ApplicationErrorCode::HTTP_MALFORMED_FRAME_GOAWAY:
return "HTTP: Malformed GOAWAY frame";
case ApplicationErrorCode::HTTP_MALFORMED_FRAME_MAX_PUSH_ID:
return "HTTP: Malformed MAX_PUSH_ID frame";
case ApplicationErrorCode::HTTP_MALFORMED_FRAME:
return "HTTP: Malformed frame";
case ApplicationErrorCode::INTERNAL_ERROR:
return "Internal error";
case ApplicationErrorCode::GIVEUP_ZERO_RTT:
return "Give up Zero RTT";
}
LOG(WARNING) << "toString has unhandled ErrorCode";
return "Unknown error";
}
std::string toString(QuicErrorCode code) {
return folly::variant_match(
code,
[](ApplicationErrorCode errorCode) { return toString(errorCode); },
[](LocalErrorCode errorCode) { return toString(errorCode); },
[](TransportErrorCode errorCode) { return toString(errorCode); });
}
std::string toString(
const std::pair<QuicErrorCode, folly::Optional<folly::StringPiece>>&
error) {
return folly::to<std::string>(
folly::variant_match(
error.first,
[](ApplicationErrorCode errorCode) {
return "ApplicationError: " + toString(errorCode) + ", ";
},
[](LocalErrorCode errorCode) {
return "LocalError: " + toString(errorCode) + ", ";
},
[](TransportErrorCode errorCode) {
return "TransportError: " + toString(errorCode) + ", ";
}),
error.second.value_or(folly::StringPiece()).toString());
}
} // namespace quic

104
quic/QuicException.h Normal file
View File

@ -0,0 +1,104 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
// Copyright 2004-present Facebook. All rights reserved.
#pragma once
#include <stdexcept>
#include <string>
#include <quic/QuicConstants.h>
namespace quic {
class QuicTransportException : public std::runtime_error {
public:
explicit QuicTransportException(
const std::string& msg,
TransportErrorCode errCode);
explicit QuicTransportException(const char* msg, TransportErrorCode errCode);
explicit QuicTransportException(
const std::string& msg,
TransportErrorCode errCode,
FrameType frameType);
explicit QuicTransportException(
const char* msg,
TransportErrorCode errCode,
FrameType frameType);
TransportErrorCode errorCode() const noexcept {
return errCode_;
}
folly::Optional<FrameType> frameType() const noexcept {
return frameType_;
}
private:
TransportErrorCode errCode_;
folly::Optional<FrameType> frameType_;
};
class QuicInternalException : public std::runtime_error {
public:
explicit QuicInternalException(
const std::string& msg,
LocalErrorCode errorCode);
explicit QuicInternalException(const char* msg, LocalErrorCode errCode);
LocalErrorCode errorCode() const noexcept {
return errorCode_;
}
private:
LocalErrorCode errorCode_;
};
class QuicApplicationException : public std::runtime_error {
public:
explicit QuicApplicationException(
const std::string& msg,
ApplicationErrorCode errorCode);
explicit QuicApplicationException(
const char* msg,
ApplicationErrorCode errorCode);
ApplicationErrorCode errorCode() const noexcept {
return errorCode_;
}
private:
ApplicationErrorCode errorCode_;
};
/**
* Convert the error code to a string.
*/
std::string toString(TransportErrorCode code);
std::string toString(LocalErrorCode code);
std::string toString(ApplicationErrorCode code);
std::string toString(QuicErrorCode code);
std::string toString(
const std::pair<QuicErrorCode, folly::Optional<folly::StringPiece>>& error);
inline std::ostream& operator<<(std::ostream& os, const QuicErrorCode& error) {
os << toString(error);
return os;
}
inline std::ostream& operator<<(
std::ostream& os,
const std::pair<QuicErrorCode, folly::Optional<folly::StringPiece>>&
error) {
os << toString(error);
return os;
}
} // namespace quic

37
quic/TARGETS Normal file
View File

@ -0,0 +1,37 @@
# @autodeps
load("@fbcode_macros//build_defs:cpp_library.bzl", "cpp_library")
cpp_library(
name = "constants",
srcs = [
"QuicConstants.cpp",
],
headers = [
"QuicConstants.h",
],
deps = [
"//folly:range",
"//folly:string",
],
external_deps = [
"boost",
],
)
cpp_library(
name = "exception",
srcs = [
"QuicException.cpp",
],
headers = [
"QuicException.h",
],
deps = [
":constants",
"//folly:overload",
],
external_deps = [
"glog",
],
)

92
quic/api/CMakeLists.txt Normal file
View File

@ -0,0 +1,92 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
add_library(
mvfst_transport STATIC
IoBufQuicBatch.cpp
QuicBatchWriter.cpp
QuicPacketScheduler.cpp
QuicTransportBase.cpp
QuicTransportFunctions.cpp
)
target_include_directories(
mvfst_transport PUBLIC
$<BUILD_INTERFACE:${QUIC_FBCODE_ROOT}>
$<INSTALL_INTERFACE:include/>
)
target_compile_options(
mvfst_transport
PRIVATE
${_QUIC_COMMON_COMPILE_OPTIONS}
)
add_dependencies(
mvfst_transport
mvfst_cc_algo
mvfst_codec
mvfst_codec_pktbuilder
mvfst_codec_pktrebuilder
mvfst_codec_types
mvfst_constants
mvfst_exception
mvfst_flowcontrol
mvfst_happyeyeballs
mvfst_logging
mvfst_looper
mvfst_loss
mvfst_state_functions
mvfst_state_machine
mvfst_state_pacing_functions
mvfst_state_simple_frame_functions
mvfst_state_stream
mvfst_state_stream_functions
)
target_link_libraries(
mvfst_transport PUBLIC
Folly::folly
fizz::fizz
mvfst_cc_algo
mvfst_codec
mvfst_codec_pktbuilder
mvfst_codec_pktrebuilder
mvfst_codec_types
mvfst_constants
mvfst_exception
mvfst_flowcontrol
mvfst_happyeyeballs
mvfst_logging
mvfst_looper
mvfst_loss
mvfst_state_functions
mvfst_state_machine
mvfst_state_pacing_functions
mvfst_state_simple_frame_functions
mvfst_state_stream
mvfst_state_stream_functions
PRIVATE
${BOOST_LIBRARIES}
)
file(
GLOB_RECURSE QUIC_API_HEADERS_TOINSTALL
RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
*.h
)
list(FILTER QUIC_API_HEADERS_TOINSTALL EXCLUDE REGEX test/)
foreach(header ${QUIC_API_HEADERS_TOINSTALL})
get_filename_component(header_dir ${header} DIRECTORY)
install(FILES ${header} DESTINATION include/quic/api/${header_dir})
endforeach()
install(
TARGETS mvfst_transport
EXPORT mvfst-exports
DESTINATION lib
)
add_subdirectory(test)

128
quic/api/IoBufQuicBatch.cpp Normal file
View File

@ -0,0 +1,128 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#include <quic/api/IoBufQuicBatch.h>
#include <quic/happyeyeballs/QuicHappyEyeballsFunctions.h>
namespace quic {
IOBufQuicBatch::IOBufQuicBatch(
std::unique_ptr<BatchWriter>&& batchWriter,
folly::AsyncUDPSocket& sock,
folly::SocketAddress& peerAddress,
QuicConnectionStateBase::HappyEyeballsState& happyEyeballsState)
: batchWriter_(std::move(batchWriter)),
sock_(sock),
peerAddress_(peerAddress),
happyEyeballsState_(happyEyeballsState) {}
bool IOBufQuicBatch::write(
std::unique_ptr<folly::IOBuf>&& buf,
size_t encodedSize) {
pktSent_++;
// see if we need to flush the prev buffer(s)
if (batchWriter_->needsFlush(encodedSize)) {
// continue even if we get an error here
flush();
}
// try to append the new buffers
if (batchWriter_->append(std::move(buf), encodedSize)) {
// return if we get an error here
return flush();
}
return true;
}
bool IOBufQuicBatch::flush() {
bool ret = flushInternal();
reset();
return ret;
}
void IOBufQuicBatch::setContinueOnNetworkUnreachable(
bool continueOnNetworkUnreachable) {
continueOnNetworkUnreachable_ = continueOnNetworkUnreachable;
}
void IOBufQuicBatch::reset() {
batchWriter_->reset();
}
bool IOBufQuicBatch::isNetworkUnreachable(int err) {
return err == EHOSTUNREACH || err == ENETUNREACH;
}
bool IOBufQuicBatch::isRetriableError(int err) {
return err == EAGAIN || err == EWOULDBLOCK || err == ENOBUFS ||
err == EMSGSIZE ||
(continueOnNetworkUnreachable_ && isNetworkUnreachable(err));
}
bool IOBufQuicBatch::flushInternal() {
if (batchWriter_->empty()) {
return true;
}
auto consumed = batchWriter_->write(sock_, peerAddress_);
bool written = (consumed >= 0);
// If retriable error occured on first socket, kick off second socket
// immediately
// TODO I think any error on first socket should trigger this though.
if ((!written && isRetriableError(errno)) &&
happyEyeballsState_.connAttemptDelayTimeout &&
happyEyeballsState_.connAttemptDelayTimeout->isScheduled()) {
happyEyeballsState_.connAttemptDelayTimeout->cancelTimeout();
happyEyeballsStartSecondSocket(happyEyeballsState_);
}
// Write to second socket if there is no fatal error on first socket write
if ((written || isRetriableError(errno)) &&
happyEyeballsState_.shouldWriteToSecondSocket) {
// TODO: if the errno is EMSGSIZE, and we move on with the second socket,
// we actually miss the chance to fix our UDP packet size with the first
// socket.
consumed = batchWriter_->write(
*happyEyeballsState_.secondSocket,
happyEyeballsState_.secondPeerAddress);
// written is marked false if either socket write fails
// This causes write loop to exit early.
// I am not sure if this is necessary but at least it should be OK
written &= (consumed >= 0);
}
// TODO: handle ENOBUFS and backpressure the socket.
if (!written && !isRetriableError(errno)) {
int errnoCopy = errno;
std::string errorMsg = folly::to<std::string>(
folly::errnoStr(errnoCopy),
(errnoCopy == EMSGSIZE)
? folly::to<std::string>(", pktSize=", batchWriter_->size())
: "");
VLOG(4) << "Error writing to the socket " << errorMsg << " "
<< peerAddress_;
throw QuicTransportException(
folly::to<std::string>("Error on socket write ", errorMsg),
TransportErrorCode::INTERNAL_ERROR);
}
if (!written) {
// This can happen normally, so ignore for now. Now we treat EAGAIN same
// as a loss to avoid looping.
// TODO: Remove once we use write event from libevent.
return false; // done
}
return true; // success, not done yet
}
} // namespace quic

57
quic/api/IoBufQuicBatch.h Normal file
View File

@ -0,0 +1,57 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#pragma once
#include <quic/QuicException.h>
#include <quic/api/QuicBatchWriter.h>
#include <quic/state/StateData.h>
namespace quic {
class IOBufQuicBatch {
public:
IOBufQuicBatch(
std::unique_ptr<BatchWriter>&& batchWriter,
folly::AsyncUDPSocket& sock,
folly::SocketAddress& peerAddress,
QuicConnectionStateBase::HappyEyeballsState& happyEyeballsState);
~IOBufQuicBatch() = default;
// returns true if it succeeds and false if the loop should end
bool write(std::unique_ptr<folly::IOBuf>&& buf, size_t encodedSize);
bool flush();
FOLLY_ALWAYS_INLINE uint64_t getPktSent() const {
return pktSent_;
}
void setContinueOnNetworkUnreachable(bool continueOnNetworkUnreachable);
private:
void reset();
// flushes the internal buffers
bool flushInternal();
bool isNetworkUnreachable(int err);
/**
* Returns whether or not the errno can be retried later.
*/
bool isRetriableError(int err);
std::unique_ptr<BatchWriter> batchWriter_;
folly::AsyncUDPSocket& sock_;
folly::SocketAddress& peerAddress_;
QuicConnectionStateBase::HappyEyeballsState& happyEyeballsState_;
uint64_t pktSent_{0};
bool continueOnNetworkUnreachable_{false};
};
} // namespace quic

View File

@ -0,0 +1,176 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#include <quic/api/QuicBatchWriter.h>
namespace quic {
// BatchWriter
bool BatchWriter::needsFlush(size_t /*unused*/) {
return false;
}
// SinglePacketBatchWriter
void SinglePacketBatchWriter::reset() {
buf_.reset();
}
bool SinglePacketBatchWriter::append(
std::unique_ptr<folly::IOBuf>&& buf,
size_t /*unused*/) {
buf_ = std::move(buf);
// needs to be flushed
return true;
}
ssize_t SinglePacketBatchWriter::write(
folly::AsyncUDPSocket& sock,
const folly::SocketAddress& address) {
return sock.write(address, buf_);
}
// GSOPacketBatchWriter
GSOPacketBatchWriter::GSOPacketBatchWriter(size_t maxBufs)
: maxBufs_(maxBufs) {}
void GSOPacketBatchWriter::reset() {
buf_.reset(nullptr);
currBufs_ = 0;
prevSize_ = 0;
}
bool GSOPacketBatchWriter::needsFlush(size_t size) {
// if we get a buffer with a size that is greater
// than the prev one we need to flush
return (prevSize_ && (size > prevSize_));
}
bool GSOPacketBatchWriter::append(
std::unique_ptr<folly::IOBuf>&& buf,
size_t size) {
// first buffer
if (!buf_) {
DCHECK_EQ(currBufs_, 0);
buf_ = std::move(buf);
prevSize_ = size;
currBufs_ = 1;
return false; // continue
}
// now we've got an additional buffer
// append it to the chain
buf_->prependChain(std::move(buf));
currBufs_++;
// see if we've added a different size
if (size != prevSize_) {
CHECK_LT(size, prevSize_);
return true;
}
// reached max buffers
if (FOLLY_UNLIKELY(currBufs_ == maxBufs_)) {
return true;
}
// does not need to be flushed yet
return false;
}
ssize_t GSOPacketBatchWriter::write(
folly::AsyncUDPSocket& sock,
const folly::SocketAddress& address) {
return (currBufs_ > 1)
? sock.writeGSO(address, buf_, static_cast<int>(prevSize_))
: sock.write(address, buf_);
}
// SendmmsgPacketBatchWriter
SendmmsgPacketBatchWriter::SendmmsgPacketBatchWriter(size_t maxBufs)
: maxBufs_(maxBufs) {
bufs_.reserve(maxBufs);
}
bool SendmmsgPacketBatchWriter::empty() const {
return !currSize_;
}
size_t SendmmsgPacketBatchWriter::size() const {
return currSize_;
}
void SendmmsgPacketBatchWriter::reset() {
bufs_.clear();
currSize_ = 0;
}
bool SendmmsgPacketBatchWriter::append(
std::unique_ptr<folly::IOBuf>&& buf,
size_t size) {
CHECK_LT(bufs_.size(), maxBufs_);
bufs_.emplace_back(std::move(buf));
currSize_ += size;
// reached max buffers
if (FOLLY_UNLIKELY(bufs_.size() == maxBufs_)) {
return true;
}
// does not need to be flushed yet
return false;
}
ssize_t SendmmsgPacketBatchWriter::write(
folly::AsyncUDPSocket& sock,
const folly::SocketAddress& address) {
CHECK_GT(bufs_.size(), 0);
if (bufs_.size() == 1) {
return sock.write(address, bufs_[0]);
}
int ret = sock.writem(address, bufs_.data(), bufs_.size());
if (ret <= 0) {
return ret;
}
if (static_cast<size_t>(ret) == bufs_.size()) {
return currSize_;
}
// this is a partial write - we just need to
// return a different number than currSize_
return 0;
}
// BatchWriterFactory
std::unique_ptr<BatchWriter> BatchWriterFactory::makeBatchWriter(
folly::AsyncUDPSocket& sock,
const quic::QuicBatchingMode& batchingMode,
uint32_t batchingNum) {
switch (batchingMode) {
case quic::QuicBatchingMode::BATCHING_MODE_NONE:
return std::make_unique<SinglePacketBatchWriter>();
case quic::QuicBatchingMode::BATCHING_MODE_GSO: {
if (sock.getGSO() >= 0) {
return std::make_unique<GSOPacketBatchWriter>(batchingNum);
}
return std::make_unique<SinglePacketBatchWriter>();
}
case quic::QuicBatchingMode::BATCHING_MODE_SENDMMSG:
return std::make_unique<SendmmsgPacketBatchWriter>(batchingNum);
// no default so we can catch missing case at compile time
}
// should be unreachable
return std::make_unique<SinglePacketBatchWriter>();
}
} // namespace quic

124
quic/api/QuicBatchWriter.h Normal file
View File

@ -0,0 +1,124 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#pragma once
#include <folly/io/IOBuf.h>
#include <folly/io/async/AsyncUDPSocket.h>
#include <quic/QuicConstants.h>
namespace quic {
class BatchWriter {
public:
BatchWriter() = default;
virtual ~BatchWriter() = default;
// returns true if the batch does not contain any buffers
virtual bool empty() const = 0;
// returns the size in bytes of the batched buffers
virtual size_t size() const = 0;
// reset the internal state after a flush
virtual void reset() = 0;
// returns false if we need to flush before adding a new packet
virtual bool needsFlush(size_t /*unused*/);
/* append returns true if the
* writer need to be flushed
*/
virtual bool append(std::unique_ptr<folly::IOBuf>&& buf, size_t bufSize) = 0;
virtual ssize_t write(
folly::AsyncUDPSocket& sock,
const folly::SocketAddress& address) = 0;
};
class IOBufBatchWriter : public BatchWriter {
public:
IOBufBatchWriter() = default;
~IOBufBatchWriter() override = default;
bool empty() const override {
return !buf_;
}
size_t size() const override {
return buf_ ? buf_->computeChainDataLength() : 0;
}
protected:
std::unique_ptr<folly::IOBuf> buf_;
};
class SinglePacketBatchWriter : public IOBufBatchWriter {
public:
SinglePacketBatchWriter() = default;
~SinglePacketBatchWriter() override = default;
void reset() override;
bool append(std::unique_ptr<folly::IOBuf>&& buf, size_t /*unused*/) override;
ssize_t write(
folly::AsyncUDPSocket& sock,
const folly::SocketAddress& address) override;
};
class GSOPacketBatchWriter : public IOBufBatchWriter {
public:
explicit GSOPacketBatchWriter(size_t maxBufs);
~GSOPacketBatchWriter() override = default;
void reset() override;
bool needsFlush(size_t size) override;
bool append(std::unique_ptr<folly::IOBuf>&& buf, size_t size) override;
ssize_t write(
folly::AsyncUDPSocket& sock,
const folly::SocketAddress& address) override;
private:
// max number of buffer chains we can accumulate before we need to flush
size_t maxBufs_{1};
// current number of buffer chains appended the buf_
size_t currBufs_{0};
// size of the previous buffer chain appended to the buf_
size_t prevSize_{0};
};
class SendmmsgPacketBatchWriter : public BatchWriter {
public:
explicit SendmmsgPacketBatchWriter(size_t maxBufs);
~SendmmsgPacketBatchWriter() override = default;
bool empty() const override;
size_t size() const override;
void reset() override;
bool append(std::unique_ptr<folly::IOBuf>&& buf, size_t size) override;
ssize_t write(
folly::AsyncUDPSocket& sock,
const folly::SocketAddress& address) override;
private:
// max number of buffer chains we can accumulate before we need to flush
size_t maxBufs_{1};
// size of data in all the buffers
size_t currSize_{0};
// array of IOBufs
std::vector<std::unique_ptr<folly::IOBuf>> bufs_;
};
class BatchWriterFactory {
public:
static std::unique_ptr<BatchWriter> makeBatchWriter(
folly::AsyncUDPSocket& sock,
const quic::QuicBatchingMode& batchingMode,
uint32_t batchingNum);
};
} // namespace quic

View File

@ -0,0 +1,66 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#include <quic/codec/QuicPacketRebuilder.h>
#include <quic/state/QuicStateFunctions.h>
namespace quic {
template <typename ClockType>
inline folly::Optional<PacketNum> AckScheduler::writeNextAcks(
PacketBuilderInterface& builder,
AckMode mode) {
switch (mode) {
case AckMode::Immediate: {
return writeAcksImpl<ClockType>(builder);
}
case AckMode::Pending: {
return writeAcksIfPending<ClockType>(builder);
}
}
__builtin_unreachable();
}
template <typename ClockType>
inline folly::Optional<PacketNum> AckScheduler::writeAcksIfPending(
PacketBuilderInterface& builder) {
if (ackState_.needsToSendAckImmediately) {
return writeAcksImpl<ClockType>(builder);
}
return folly::none;
}
template <typename ClockType>
folly::Optional<PacketNum> AckScheduler::writeAcksImpl(
PacketBuilderInterface& builder) {
// Use default ack delay for long headers. Usually long headers are sent
// before crypto negotiation, so the peer might not know about the ack delay
// exponent yet, so we use the default.
uint8_t ackDelayExponentToUse = folly::variant_match(
builder.getPacketHeader(),
[](const LongHeader&) { return kDefaultAckDelayExponent; },
[&](const auto&) { return conn_.transportSettings.ackDelayExponent; });
auto largestAckedPacketNum = *largestAckToSend(ackState_);
auto ackingTime = ClockType::now();
DCHECK(ackState_.largestRecvdPacketTime.hasValue())
<< "Missing received time for the largest acked packet";
// assuming that we're going to ack the largest recived with hightest pri
auto receivedTime = *ackState_.largestRecvdPacketTime;
std::chrono::microseconds ackDelay =
(ackingTime > receivedTime
? std::chrono::duration_cast<std::chrono::microseconds>(
ackingTime - receivedTime)
: std::chrono::microseconds::zero());
AckFrameMetaData meta(ackState_.acks, ackDelay, ackDelayExponentToUse);
auto ackWriteResult = writeAckFrame(meta, builder);
if (!ackWriteResult) {
return folly::none;
}
return largestAckedPacketNum;
}
} // namespace quic

View File

@ -0,0 +1,623 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#include <quic/api/QuicPacketScheduler.h>
namespace {
quic::StreamFrameMetaData makeStreamFrameMetaDataFromStreamBuffer(
quic::StreamId id,
const quic::StreamBuffer& buffer);
quic::StreamFrameMetaData makeStreamFrameMetaDataFromStreamBuffer(
quic::StreamId id,
const quic::StreamBuffer& buffer) {
quic::StreamFrameMetaData streamMeta;
// It's very tricky to get the stream data without the length right,
// so don't support it for now.
streamMeta.hasMoreFrames = true;
streamMeta.id = id;
streamMeta.offset = buffer.offset;
streamMeta.fin = buffer.eof;
streamMeta.data = buffer.data.front() ? buffer.data.front()->clone()
: folly::IOBuf::create(0);
return streamMeta;
}
} // namespace
namespace quic {
bool hasAcksToSchedule(const AckState& ackState) {
if (ackState.acks.empty()) {
return false;
}
if (!ackState.largestAckScheduled) {
// Never scheduled an ack, we need to send
return true;
}
return *largestAckToSend(ackState) > *(ackState.largestAckScheduled);
}
bool neverWrittenAcksBefore(const QuicConnectionStateBase& conn) {
return (
!conn.ackStates.initialAckState.largestAckScheduled &&
!conn.ackStates.handshakeAckState.largestAckScheduled &&
!conn.ackStates.appDataAckState.largestAckScheduled);
}
bool hasAcksToSchedule(const QuicConnectionStateBase& conn) {
bool initialSpaceHasAcks = hasAcksToSchedule(conn.ackStates.initialAckState);
bool handshakeSpaceHasAcks =
hasAcksToSchedule(conn.ackStates.handshakeAckState);
bool appDataSpaceHasAcks = hasAcksToSchedule(conn.ackStates.appDataAckState);
bool cannotWriteInitialAcks =
!conn.initialWriteCipher || !initialSpaceHasAcks;
bool cannotWriteHandshakeAcks =
!conn.handshakeWriteCipher || !handshakeSpaceHasAcks;
bool cannotWriteAppDataAcks = !conn.oneRttWriteCipher || !appDataSpaceHasAcks;
if (cannotWriteInitialAcks && cannotWriteHandshakeAcks &&
cannotWriteAppDataAcks) {
return false;
}
if (neverWrittenAcksBefore(conn)) {
return true;
}
return initialSpaceHasAcks || handshakeSpaceHasAcks || appDataSpaceHasAcks;
}
folly::Optional<PacketNum> largestAckToSend(const AckState& ackState) {
if (ackState.acks.empty()) {
return folly::none;
}
return ackState.acks.back().end;
}
// Schedulers
FrameScheduler::Builder::Builder(
const QuicConnectionStateBase& conn,
fizz::EncryptionLevel encryptionLevel,
PacketNumberSpace packetNumberSpace,
const std::string& name)
: conn_(conn),
encryptionLevel_(encryptionLevel),
packetNumberSpace_(packetNumberSpace),
name_(name) {}
FrameScheduler::Builder& FrameScheduler::Builder::streamRetransmissions() {
retransmissionScheduler_ = true;
return *this;
}
FrameScheduler::Builder& FrameScheduler::Builder::streamFrames() {
streamFrameScheduler_ = true;
return *this;
}
FrameScheduler::Builder& FrameScheduler::Builder::ackFrames() {
ackScheduler_ = true;
return *this;
}
FrameScheduler::Builder& FrameScheduler::Builder::resetFrames() {
rstScheduler_ = true;
return *this;
}
FrameScheduler::Builder& FrameScheduler::Builder::windowUpdateFrames() {
windowUpdateScheduler_ = true;
return *this;
}
FrameScheduler::Builder& FrameScheduler::Builder::blockedFrames() {
blockedScheduler_ = true;
return *this;
}
FrameScheduler::Builder& FrameScheduler::Builder::cryptoFrames() {
cryptoStreamScheduler_ = true;
return *this;
}
FrameScheduler::Builder& FrameScheduler::Builder::simpleFrames() {
simpleFrameScheduler_ = true;
return *this;
}
FrameScheduler FrameScheduler::Builder::build() && {
auto scheduler = FrameScheduler(name_);
if (retransmissionScheduler_) {
scheduler.retransmissionScheduler_.emplace(RetransmissionScheduler(conn_));
}
if (streamFrameScheduler_) {
scheduler.streamFrameScheduler_.emplace(StreamFrameScheduler(conn_));
}
if (ackScheduler_) {
scheduler.ackScheduler_.emplace(
AckScheduler(conn_, getAckState(conn_, packetNumberSpace_)));
}
if (rstScheduler_) {
scheduler.rstScheduler_.emplace(RstStreamScheduler(conn_));
}
if (windowUpdateScheduler_) {
scheduler.windowUpdateScheduler_.emplace(WindowUpdateScheduler(conn_));
}
if (blockedScheduler_) {
scheduler.blockedScheduler_.emplace(BlockedScheduler(conn_));
}
if (cryptoStreamScheduler_) {
scheduler.cryptoStreamScheduler_.emplace(CryptoStreamScheduler(
conn_, *getCryptoStream(*conn_.cryptoState, encryptionLevel_)));
}
if (simpleFrameScheduler_) {
scheduler.simpleFrameScheduler_.emplace(SimpleFrameScheduler(conn_));
}
return scheduler;
}
FrameScheduler::FrameScheduler(const std::string& name) : name_(name) {}
std::pair<
folly::Optional<PacketEvent>,
folly::Optional<RegularQuicPacketBuilder::Packet>>
FrameScheduler::scheduleFramesForPacket(
RegularQuicPacketBuilder&& builder,
uint32_t writableBytes) {
// We need to keep track of writable bytes after writing header.
writableBytes = writableBytes > builder.getHeaderBytes()
? writableBytes - builder.getHeaderBytes()
: 0;
// We cannot return early if the writablyBytes dropps to 0 here, since pure
// acks can skip writableBytes entirely.
PacketBuilderWrapper wrapper(builder, writableBytes);
auto ackMode = hasImmediateData() ? AckMode::Immediate : AckMode::Pending;
bool cryptoDataWritten = false;
bool rstWritten = false;
if (cryptoStreamScheduler_ && cryptoStreamScheduler_->hasData()) {
cryptoDataWritten = cryptoStreamScheduler_->writeCryptoData(wrapper);
}
if (rstScheduler_ && rstScheduler_->hasPendingRsts()) {
rstWritten = rstScheduler_->writeRsts(wrapper);
}
if (ackScheduler_ && ackScheduler_->hasPendingAcks()) {
if (cryptoDataWritten || rstWritten) {
// If packet has non ack data, it is subject to congestion control. We
// need to use the wrapper/
ackScheduler_->writeNextAcks(wrapper, ackMode);
} else {
// If we start with writing acks, we will let the ack scheduler write
// up to the full packet space. If the ack bytes exceeds the writable
// bytes, this will be a pure ack packet and it will skip congestion
// controller. Otherwise, we will give other schedulers an opportunity to
// write up to writable bytes.
ackScheduler_->writeNextAcks(builder, ackMode);
}
}
if (windowUpdateScheduler_ &&
windowUpdateScheduler_->hasPendingWindowUpdates()) {
windowUpdateScheduler_->writeWindowUpdates(wrapper);
}
if (blockedScheduler_ && blockedScheduler_->hasPendingBlockedFrames()) {
blockedScheduler_->writeBlockedFrames(wrapper);
}
if (retransmissionScheduler_ && retransmissionScheduler_->hasPendingData()) {
retransmissionScheduler_->writeRetransmissionStreams(wrapper);
}
if (streamFrameScheduler_ && streamFrameScheduler_->hasPendingData()) {
streamFrameScheduler_->writeStreams(wrapper);
}
if (simpleFrameScheduler_ &&
simpleFrameScheduler_->hasPendingSimpleFrames()) {
simpleFrameScheduler_->writeSimpleFrames(wrapper);
}
return std::make_pair(folly::none, std::move(builder).buildPacket());
}
bool FrameScheduler::hasData() const {
return (ackScheduler_ && ackScheduler_->hasPendingAcks()) ||
hasImmediateData();
}
bool FrameScheduler::hasImmediateData() const {
return (cryptoStreamScheduler_ && cryptoStreamScheduler_->hasData()) ||
(retransmissionScheduler_ &&
retransmissionScheduler_->hasPendingData()) ||
(streamFrameScheduler_ && streamFrameScheduler_->hasPendingData()) ||
(rstScheduler_ && rstScheduler_->hasPendingRsts()) ||
(windowUpdateScheduler_ &&
windowUpdateScheduler_->hasPendingWindowUpdates()) ||
(blockedScheduler_ && blockedScheduler_->hasPendingBlockedFrames()) ||
(simpleFrameScheduler_ &&
simpleFrameScheduler_->hasPendingSimpleFrames());
}
std::string FrameScheduler::name() const {
return name_;
}
RetransmissionScheduler::RetransmissionScheduler(
const QuicConnectionStateBase& conn)
: conn_(conn) {}
void RetransmissionScheduler::writeRetransmissionStreams(
PacketBuilderInterface& builder) {
for (auto streamId : conn_.streamManager->lossStreams()) {
auto stream = conn_.streamManager->findStream(streamId);
CHECK(stream);
for (auto buffer = stream->lossBuffer.cbegin();
buffer != stream->lossBuffer.cend();
++buffer) {
auto streamMeta =
makeStreamFrameMetaDataFromStreamBuffer(stream->id, *buffer);
auto res = writeStreamFrame(streamMeta, builder);
if (!res) {
// Finish assembling a packet
break;
}
VLOG(4) << "Wrote retransmitted stream=" << streamMeta.id
<< " offset=" << streamMeta.offset
<< " bytes=" << res->bytesWritten << " fin=" << res->finWritten
<< " " << conn_;
}
}
}
bool RetransmissionScheduler::hasPendingData() const {
return !conn_.streamManager->lossStreams().empty();
}
StreamFrameScheduler::StreamFrameScheduler(const QuicConnectionStateBase& conn)
: conn_(conn) {}
void StreamFrameScheduler::writeStreams(PacketBuilderInterface& builder) {
uint64_t connWritableBytes = getSendConnFlowControlBytesWire(conn_);
MiddleStartingIterationWrapper wrapper(
conn_.streamManager->writableStreams(),
conn_.schedulingState.lastScheduledStream);
auto writableStreamItr = wrapper.cbegin();
while (writableStreamItr != wrapper.cend() && connWritableBytes > 0) {
auto res =
writeNextStreamFrame(builder, writableStreamItr, connWritableBytes);
if (!res) {
break;
}
}
}
bool StreamFrameScheduler::hasPendingData() const {
return conn_.streamManager->hasWritable() &&
getSendConnFlowControlBytesWire(conn_) > 0;
}
bool StreamFrameScheduler::writeNextStreamFrame(
PacketBuilderInterface& builder,
StreamFrameScheduler::WritableStreamItr& writableStreamItr,
uint64_t& connWritableBytes) {
auto stream = conn_.streamManager->findStream(*writableStreamItr);
CHECK(stream);
// hasWritableData is the condition which has to be satisfied for the
// stream to be in writableList
DCHECK(stream->hasWritableData());
auto streamMeta = makeStreamFrameMetaData(*stream, true, connWritableBytes);
auto res = writeStreamFrame(streamMeta, builder);
if (!res) {
// Finish assembling a packet
return false;
}
VLOG(4) << "Wrote stream frame stream=" << streamMeta.id
<< " offset=" << streamMeta.offset
<< " bytesWritten=" << res->bytesWritten
<< " finWritten=" << res->finWritten << " " << conn_;
connWritableBytes -= res->bytesWritten;
// bytesWritten < min(flowControlBytes, writeBuffer) means that we haven't
// written all writable bytes in this stream due to short of room in the
// packet.
if (res->bytesWritten ==
std::min<uint64_t>(
getSendStreamFlowControlBytesWire(*stream),
stream->writeBuffer.chainLength())) {
++writableStreamItr;
}
return true;
}
StreamFrameMetaData StreamFrameScheduler::makeStreamFrameMetaData(
const QuicStreamState& streamData,
bool /*hasMoreData*/,
uint64_t connWritableBytes) {
uint64_t writableBytes = std::min(
getSendStreamFlowControlBytesWire(streamData), connWritableBytes);
StreamFrameMetaData streamMeta;
streamMeta.hasMoreFrames = true;
streamMeta.id = streamData.id;
streamMeta.offset = streamData.currentWriteOffset;
if (streamData.writeBuffer.front()) {
folly::io::Cursor cursor(streamData.writeBuffer.front());
cursor.cloneAtMost(streamMeta.data, writableBytes);
}
streamMeta.fin = streamData.finalWriteOffset.hasValue() &&
streamData.writeBuffer.chainLength() <= writableBytes;
return streamMeta;
}
AckScheduler::AckScheduler(
const QuicConnectionStateBase& conn,
const AckState& ackState)
: conn_(conn), ackState_(ackState) {}
bool AckScheduler::hasPendingAcks() const {
return hasAcksToSchedule(ackState_);
}
RstStreamScheduler::RstStreamScheduler(const QuicConnectionStateBase& conn)
: conn_(conn) {}
bool RstStreamScheduler::hasPendingRsts() const {
return !conn_.pendingEvents.resets.empty();
}
bool RstStreamScheduler::writeRsts(PacketBuilderInterface& builder) {
bool rstWritten = false;
for (const auto& resetStream : conn_.pendingEvents.resets) {
// TODO: here, maybe coordinate scheduling of RST_STREAMS and streams.
auto bytesWritten = writeFrame(resetStream.second, builder);
if (!bytesWritten) {
break;
}
rstWritten = true;
}
return rstWritten;
}
SimpleFrameScheduler::SimpleFrameScheduler(const QuicConnectionStateBase& conn)
: conn_(conn) {}
bool SimpleFrameScheduler::hasPendingSimpleFrames() const {
return conn_.pendingEvents.pathChallenge ||
!conn_.pendingEvents.frames.empty();
}
bool SimpleFrameScheduler::writeSimpleFrames(PacketBuilderInterface& builder) {
auto& pathChallenge = conn_.pendingEvents.pathChallenge;
if (pathChallenge &&
!writeSimpleFrame(QuicSimpleFrame(*pathChallenge), builder)) {
return false;
}
bool framesWritten = false;
for (auto& frame : conn_.pendingEvents.frames) {
auto bytesWritten = writeSimpleFrame(QuicSimpleFrame(frame), builder);
if (!bytesWritten) {
break;
}
framesWritten = true;
}
return framesWritten;
}
WindowUpdateScheduler::WindowUpdateScheduler(
const QuicConnectionStateBase& conn)
: conn_(conn) {}
bool WindowUpdateScheduler::hasPendingWindowUpdates() const {
return conn_.streamManager->hasWindowUpdates() ||
conn_.pendingEvents.connWindowUpdate;
}
void WindowUpdateScheduler::writeWindowUpdates(
PacketBuilderInterface& builder) {
if (conn_.pendingEvents.connWindowUpdate) {
auto maxDataFrame = generateMaxDataFrame(conn_);
auto maximumData = maxDataFrame.maximumData;
auto bytes = writeFrame(std::move(maxDataFrame), builder);
if (bytes) {
VLOG(4) << "Wrote max_data=" << maximumData << " " << conn_;
}
}
for (const auto& windowUpdateStream : conn_.streamManager->windowUpdates()) {
auto stream = conn_.streamManager->findStream(windowUpdateStream);
if (!stream) {
continue;
}
auto maxStreamDataFrame = generateMaxStreamDataFrame(*stream);
auto maximumData = maxStreamDataFrame.maximumData;
auto bytes = writeFrame(std::move(maxStreamDataFrame), builder);
if (!bytes) {
break;
}
VLOG(4) << "Wrote max_stream_data stream=" << stream->id
<< " maximumData=" << maximumData << " " << conn_;
}
}
BlockedScheduler::BlockedScheduler(const QuicConnectionStateBase& conn)
: conn_(conn) {}
bool BlockedScheduler::hasPendingBlockedFrames() const {
return !conn_.streamManager->blockedStreams().empty();
}
void BlockedScheduler::writeBlockedFrames(PacketBuilderInterface& builder) {
for (const auto& blockedStream : conn_.streamManager->blockedStreams()) {
auto bytesWritten = writeFrame(blockedStream.second, builder);
if (!bytesWritten) {
break;
}
}
}
CryptoStreamScheduler::CryptoStreamScheduler(
const QuicConnectionStateBase& conn,
const QuicCryptoStream& cryptoStream)
: conn_(conn), cryptoStream_(cryptoStream) {}
bool CryptoStreamScheduler::writeCryptoData(PacketBuilderInterface& builder) {
bool cryptoDataWritten = false;
uint64_t writableData =
folly::to<uint64_t>(cryptoStream_.writeBuffer.chainLength());
// We use the crypto scheduler to reschedule the retransmissions of the
// crypto streams so that we know that retransmissions of the crypto data
// will always take precedence over the crypto data.
for (const auto& buffer : cryptoStream_.lossBuffer) {
auto res =
writeCryptoFrame(buffer.offset, buffer.data.front()->clone(), builder);
if (!res) {
return cryptoDataWritten;
}
VLOG(4) << "Wrote retransmitted crypto"
<< " offset=" << buffer.offset << " bytes=" << res->len << " "
<< conn_;
cryptoDataWritten = true;
}
if (writableData != 0) {
Buf data;
folly::io::Cursor cursor(cryptoStream_.writeBuffer.front());
cursor.cloneAtMost(data, writableData);
auto res = writeCryptoFrame(
cryptoStream_.currentWriteOffset, std::move(data), builder);
if (res) {
VLOG(4) << "Wrote crypto frame"
<< " offset=" << cryptoStream_.currentWriteOffset
<< " bytesWritten=" << res->len << " " << conn_;
cryptoDataWritten = true;
}
}
if (cryptoDataWritten && conn_.nodeType == QuicNodeType::Client) {
bool initialPacket = folly::variant_match(
builder.getPacketHeader(),
[](const LongHeader& header) {
return header.getHeaderType() == LongHeader::Types::Initial;
},
[](const auto&) { return false; });
if (initialPacket) {
// This is the initial packet, we need to fill er up.
while (builder.remainingSpaceInPkt() > 0) {
writeFrame(PaddingFrame(), builder);
}
}
}
return cryptoDataWritten;
}
bool CryptoStreamScheduler::hasData() const {
return !cryptoStream_.writeBuffer.empty() ||
!cryptoStream_.lossBuffer.empty();
}
std::pair<
folly::Optional<PacketEvent>,
folly::Optional<RegularQuicPacketBuilder::Packet>>
CryptoStreamScheduler::scheduleFramesForPacket(
RegularQuicPacketBuilder&& builder,
uint32_t writableBytes) {
// We need to keep track of writable bytes after writing header.
writableBytes = writableBytes > builder.getHeaderBytes()
? writableBytes - builder.getHeaderBytes()
: 0;
if (!writableBytes) {
return std::make_pair(folly::none, folly::none);
}
PacketBuilderWrapper wrapper(builder, writableBytes);
writeCryptoData(wrapper);
return std::make_pair(folly::none, std::move(builder).buildPacket());
}
CloningScheduler::CloningScheduler(
FrameScheduler& scheduler,
QuicConnectionStateBase& conn,
const std::string& name,
uint64_t cipherOverhead)
: frameScheduler_(scheduler),
conn_(conn),
name_(std::move(name)),
cipherOverhead_(cipherOverhead) {}
bool CloningScheduler::hasData() const {
return frameScheduler_.hasData() ||
(!conn_.outstandingPackets.empty() &&
(conn_.outstandingPackets.size() !=
conn_.outstandingHandshakePacketsCount +
conn_.outstandingPureAckPacketsCount));
}
std::pair<
folly::Optional<PacketEvent>,
folly::Optional<RegularQuicPacketBuilder::Packet>>
CloningScheduler::scheduleFramesForPacket(
RegularQuicPacketBuilder&& builder,
uint32_t writableBytes) {
// The writableBytes in this function shouldn't be limited by cwnd, since
// we only use CloningScheduler for the cases that we want to bypass cwnd for
// now.
if (frameScheduler_.hasData()) {
// Note that there is a possibility that we end up writing nothing here. But
// if frameScheduler_ hasData() to write, we shouldn't invoke the cloning
// path if the write fails.
return frameScheduler_.scheduleFramesForPacket(
std::move(builder), writableBytes);
}
// Look for an outstanding packet that's no larger than the writableBytes
for (auto iter = conn_.outstandingPackets.rbegin();
iter != conn_.outstandingPackets.rend();
++iter) {
auto opPnSpace = folly::variant_match(
iter->packet.header,
[](const auto& h) { return h.getPacketNumberSpace(); });
if (opPnSpace != PacketNumberSpace::AppData) {
continue;
}
// Reusing the RegularQuicPacketBuilder throughout loop bodies will lead to
// frames belong to different original packets being written into the same
// clone packet. So re-create a RegularQuicPacketBuilder every time.
// TODO: We can avoid the copy & rebuild of the header by creating an
// independent header builder.
auto builderPnSpace = folly::variant_match(
builder.getPacketHeader(),
[](const auto& h) { return h.getPacketNumberSpace(); });
CHECK_EQ(builderPnSpace, PacketNumberSpace::AppData);
RegularQuicPacketBuilder regularBuilder(
conn_.udpSendPacketLen,
builder.getPacketHeader(),
getAckState(conn_, builderPnSpace).largestAckedByPeer);
PacketRebuilder rebuilder(regularBuilder, conn_);
// We shouldn't clone Handshake packet. For PureAcks, cloning them bring
// perf down as shown by load test.
if (iter->isHandshake || iter->pureAck) {
continue;
}
// If the packet is already a clone that has been processed, we don't clone
// it again.
if (iter->associatedEvent &&
conn_.outstandingPacketEvents.count(*iter->associatedEvent) == 0) {
continue;
}
// The writableBytes here is an optimization. If the writableBytes is too
// small for this packet. rebuildFromPacket should fail anyway.
// TODO: This isn't the ideal way to solve the wrong writableBytes problem.
if (iter->encodedSize > writableBytes + cipherOverhead_) {
continue;
}
// Rebuilder will write the rest of frames
auto rebuildResult = rebuilder.rebuildFromPacket(*iter);
if (rebuildResult) {
return std::make_pair(
std::move(rebuildResult), std::move(regularBuilder).buildPacket());
}
}
return std::make_pair(folly::none, folly::none);
}
std::string CloningScheduler::name() const {
return name_;
}
} // namespace quic

View File

@ -0,0 +1,425 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#pragma once
#include <boost/iterator/iterator_facade.hpp>
#include <folly/Overload.h>
#include <quic/QuicConstants.h>
#include <quic/QuicException.h>
#include <quic/codec/QuicPacketBuilder.h>
#include <quic/codec/QuicWriteCodec.h>
#include <quic/codec/Types.h>
#include <quic/flowcontrol/QuicFlowController.h>
#include <quic/state/QuicStreamFunctions.h>
namespace quic {
/**
* Common interface for Quic packet schedulers
* used at the top level.
*/
class QuicPacketScheduler {
public:
virtual ~QuicPacketScheduler() = default;
/**
* Schedules frames and writes them to the builder and returns
* a pair of PacketEvent and the Packet that was built.
*
* Returns a optional PacketEvent which indicates if the built out packet is
* a clone and the associated PacketEvent for both origin and clone.
*/
virtual std::pair<
folly::Optional<PacketEvent>,
folly::Optional<RegularQuicPacketBuilder::Packet>>
scheduleFramesForPacket(
RegularQuicPacketBuilder&& builder,
uint32_t writableBytes) = 0;
/**
* Returns whether the scheduler has data to send.
*/
virtual bool hasData() const = 0;
/**
* Returns the name of the scheduler.
*/
virtual std::string name() const = 0;
};
// A tag to denote how we should schedule ack in this packet.
enum class AckMode { Pending, Immediate };
class RetransmissionScheduler {
public:
explicit RetransmissionScheduler(const QuicConnectionStateBase& conn);
void writeRetransmissionStreams(PacketBuilderInterface& builder);
bool hasPendingData() const;
private:
StreamFrameMetaData makeStreamFrameMetaDataFromStreamBufer(
StreamId id,
const StreamBuffer& buffer,
bool moreFrames) const;
const QuicConnectionStateBase& conn_;
};
class StreamFrameScheduler {
public:
explicit StreamFrameScheduler(const QuicConnectionStateBase& conn);
/**
* Return: the first boolean indicates if at least one Blocked frame
* is written into the packet by writeStreams function.
*/
void writeStreams(PacketBuilderInterface& builder);
bool hasPendingData() const;
private:
/**
* A helper iterator adaptor class that starts iteration of streams from a
* specific stream id.
*/
class MiddleStartingIterationWrapper {
public:
using MapType = std::set<StreamId>;
class MiddleStartingIterator
: public boost::iterator_facade<
MiddleStartingIterator,
const MiddleStartingIterationWrapper::MapType::value_type,
boost::forward_traversal_tag> {
friend class boost::iterator_core_access;
public:
using MapType = MiddleStartingIterationWrapper::MapType;
MiddleStartingIterator() = default;
MiddleStartingIterator(
const MapType* streams,
const MapType::key_type& start)
: streams_(streams) {
itr_ = streams_->lower_bound(start);
checkForWrapAround();
}
const MapType::value_type& dereference() const {
return *itr_;
}
bool equal(const MiddleStartingIterator& other) const {
return wrappedAround_ == other.wrappedAround_ && itr_ == other.itr_;
}
void increment() {
++itr_;
checkForWrapAround();
}
void checkForWrapAround() {
if (itr_ == streams_->cend()) {
wrappedAround_ = true;
itr_ = streams_->cbegin();
}
}
private:
friend class MiddleStartingIterationWrapper;
bool wrappedAround_{false};
const MapType* streams_{nullptr};
MapType::const_iterator itr_;
};
MiddleStartingIterationWrapper(
const MapType& streams,
const MapType::key_type& start)
: streams_(streams), start_(start) {}
MiddleStartingIterator cbegin() const {
return MiddleStartingIterator(&streams_, start_);
}
MiddleStartingIterator cend() const {
MiddleStartingIterator itr(&streams_, start_);
itr.wrappedAround_ = true;
return itr;
}
private:
const MapType& streams_;
const MapType::key_type& start_;
};
using WritableStreamItr =
MiddleStartingIterationWrapper::MiddleStartingIterator;
/**
* Helper function to write either stream data if stream is not flow
* controlled or a blocked frame otherwise.
*
* Return: boolean indicates if anything (either data, or Blocked frame) is
* written into the packet.
*
*/
bool writeNextStreamFrame(
PacketBuilderInterface& builder,
WritableStreamItr& writableStreamItr,
uint64_t& connWritableBytes);
StreamFrameMetaData makeStreamFrameMetaData(
const QuicStreamState& streamData,
bool hasMoreData,
uint64_t connWritableBytes);
const QuicConnectionStateBase& conn_;
};
class AckScheduler {
public:
AckScheduler(const QuicConnectionStateBase& conn, const AckState& ackState);
template <typename ClockType = Clock>
folly::Optional<PacketNum> writeNextAcks(
PacketBuilderInterface& builder,
AckMode mode);
bool hasPendingAcks() const;
private:
/* Write out pending acks if needsToSendAckImmeidately in the connection's
* pendingEvent is true.
*/
template <typename ClockType>
folly::Optional<PacketNum> writeAcksIfPending(
PacketBuilderInterface& builder);
// Write out pending acks
template <typename ClockType>
folly::Optional<PacketNum> writeAcksImpl(PacketBuilderInterface& builder);
const QuicConnectionStateBase& conn_;
const AckState& ackState_;
};
/**
* Returns whether or not the Ack scheduler has acks to schedule. This does not
* tell you when the ACKs can be written.
*/
bool hasAcksToSchedule(const QuicConnectionStateBase& conn);
bool hasAcksToSchedule(const AckState& ackState);
bool neverWrittenAcksBefore(const QuicConnectionStateBase& conn);
/**
* Returns the largest packet received which needs to be acked.
*/
folly::Optional<PacketNum> largestAckToSend(const AckState& ackState);
class RstStreamScheduler {
public:
explicit RstStreamScheduler(const QuicConnectionStateBase& conn);
bool hasPendingRsts() const;
bool writeRsts(PacketBuilderInterface& builder);
private:
const QuicConnectionStateBase& conn_;
};
/*
* Simple frames are those whose mechanics are "simple" wrt the send/receive
* mechanics. These frames are retransmitted regularly on loss.
*/
class SimpleFrameScheduler {
public:
explicit SimpleFrameScheduler(const QuicConnectionStateBase& conn);
bool hasPendingSimpleFrames() const;
bool writeSimpleFrames(PacketBuilderInterface& builder);
private:
const QuicConnectionStateBase& conn_;
};
class WindowUpdateScheduler {
public:
explicit WindowUpdateScheduler(const QuicConnectionStateBase& conn);
bool hasPendingWindowUpdates() const;
void writeWindowUpdates(PacketBuilderInterface& builder);
private:
const QuicConnectionStateBase& conn_;
};
class BlockedScheduler {
public:
explicit BlockedScheduler(const QuicConnectionStateBase& conn);
bool hasPendingBlockedFrames() const;
void writeBlockedFrames(PacketBuilderInterface& builder);
private:
const QuicConnectionStateBase& conn_;
};
class CryptoStreamScheduler {
public:
explicit CryptoStreamScheduler(
const QuicConnectionStateBase& conn,
const QuicCryptoStream& cryptoStream);
/**
* Returns whether or we could write data to the stream.
*/
bool writeCryptoData(PacketBuilderInterface& builder);
/**
* Returns a optional PacketEvent which indicates if the built out packet is a
* clone and the associated PacketEvent for both origin and clone. In the case
* of CryptoStreamScheduler, this will always return folly::none.
*/
std::pair<
folly::Optional<PacketEvent>,
folly::Optional<RegularQuicPacketBuilder::Packet>>
scheduleFramesForPacket(
RegularQuicPacketBuilder&& builder,
uint32_t writableBytes);
bool hasData() const;
std::string name() const {
return "CryptoScheduler";
}
private:
const QuicConnectionStateBase& conn_;
const QuicCryptoStream& cryptoStream_;
std::string name_;
};
class FrameScheduler : public QuicPacketScheduler {
public:
~FrameScheduler() override = default;
struct Builder {
Builder(
const QuicConnectionStateBase& conn,
fizz::EncryptionLevel encryptionLevel,
PacketNumberSpace packetNumberSpace,
const std::string& name);
Builder& streamRetransmissions();
Builder& streamFrames();
Builder& ackFrames();
Builder& resetFrames();
Builder& windowUpdateFrames();
Builder& blockedFrames();
Builder& cryptoFrames();
Builder& simpleFrames();
FrameScheduler build() &&;
private:
const QuicConnectionStateBase& conn_;
fizz::EncryptionLevel encryptionLevel_;
PacketNumberSpace packetNumberSpace_;
std::string name_;
// schedulers
bool retransmissionScheduler_{false};
bool streamFrameScheduler_{false};
bool ackScheduler_{false};
bool rstScheduler_{false};
bool windowUpdateScheduler_{false};
bool blockedScheduler_{false};
bool cryptoStreamScheduler_{false};
bool simpleFrameScheduler_{false};
};
explicit FrameScheduler(const std::string& name);
virtual std::pair<
folly::Optional<PacketEvent>,
folly::Optional<RegularQuicPacketBuilder::Packet>>
scheduleFramesForPacket(
RegularQuicPacketBuilder&& builder,
uint32_t writableBytes) override;
// If any scheduler, including AckScheduler, has pending data to send
virtual bool hasData() const override;
// If any of the non-Ack scheduler has pending data to send
virtual bool hasImmediateData() const;
virtual std::string name() const override;
private:
folly::Optional<RetransmissionScheduler> retransmissionScheduler_;
folly::Optional<StreamFrameScheduler> streamFrameScheduler_;
folly::Optional<AckScheduler> ackScheduler_;
folly::Optional<RstStreamScheduler> rstScheduler_;
folly::Optional<WindowUpdateScheduler> windowUpdateScheduler_;
folly::Optional<BlockedScheduler> blockedScheduler_;
folly::Optional<CryptoStreamScheduler> cryptoStreamScheduler_;
folly::Optional<SimpleFrameScheduler> simpleFrameScheduler_;
std::string name_;
};
/**
* A packet scheduler wrapping a normal FrameScheduler with the ability to clone
* exiting packets that are still outstanding. A CloningScheduler first trie to
* write new farmes with new data into a packet. If that fails due to the lack
* of new data, it falls back to cloning one inflight packet from a connection's
* oustanding packets if there is at least one outstanding packet that's smaller
* than the writableBytes limit, and isn't a Handshake packet.
*/
class CloningScheduler : public QuicPacketScheduler {
public:
// Normally a scheduler takes in a const conn, and update conn later. But for
// this one I want to update conn right inside this class itself.
// TODO: Passing cipherOverhead into the CloningScheduler to recalculate the
// correct writableBytes isn't ideal. But unblock me or others from quickly
// testing it on load test. :(
CloningScheduler(
FrameScheduler& scheduler,
QuicConnectionStateBase& conn,
const std::string& name,
uint64_t cipherOverhead);
bool hasData() const override;
/**
* Returns a optional PacketEvent which indicates if the built out packet is a
* clone and the associated PacketEvent for both origin and clone.
*/
std::pair<
folly::Optional<PacketEvent>,
folly::Optional<RegularQuicPacketBuilder::Packet>>
scheduleFramesForPacket(
RegularQuicPacketBuilder&& builder,
uint32_t writableBytes) override;
std::string name() const override;
private:
FrameScheduler& frameScheduler_;
QuicConnectionStateBase& conn_;
std::string name_;
uint64_t cipherOverhead_;
};
} // namespace quic
#include <quic/api/QuicPacketScheduler-inl.h>

851
quic/api/QuicSocket.h Normal file
View File

@ -0,0 +1,851 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#pragma once
#include <folly/Expected.h>
#include <folly/Optional.h>
#include <folly/io/IOBuf.h>
#include <quic/QuicConstants.h>
#include <quic/codec/Types.h>
#include <quic/state/StateData.h>
#include <chrono>
namespace folly {
class EventBase;
}
namespace quic {
class QuicSocket {
public:
/**
* Callback for connection level events. This callback must be set at all
* times.
*/
class ConnectionCallback {
public:
virtual ~ConnectionCallback() = default;
/**
* Invoked when stream id's flow control state changes. This is an edge
* triggred API and will be only invoked at the point that the flow control
* changes.
*/
virtual void onFlowControlUpdate(StreamId /*id*/) noexcept {}
/**
* Invoked when the peer creates a new bidirectional stream. The most
* common flow would be to set the ReadCallback from here
*/
virtual void onNewStream(StreamId id) noexcept = 0;
/**
* Invoked when the peer creates a new unidirectional stream. The most
* common flow would be to set the ReadCallback from here
*/
virtual void onNewUnidirectionalStream(StreamId id) noexcept = 0;
/**
* Invokved when a stream receives a StopSending frame from a peer.
*/
virtual void onStopSending(
StreamId id,
ApplicationErrorCode error) noexcept = 0;
/**
* Invoked when the transport initiates close. No callbacks will
* be delivered after this
*/
virtual void onConnectionEnd() noexcept = 0;
/**
* Invoked when the connection closed in error
*/
virtual void onConnectionError(
std::pair<QuicErrorCode, std::string> code) noexcept = 0;
/**
* Client only.
* Called when the transport becomes replay safe.
*/
virtual void onReplaySafe() noexcept {}
/**
* Called when the transport is ready to send/receive data.
*/
virtual void onTransportReady() noexcept {}
/**
* On server side:
* Called during handshake while negotiating early data.
* Returns wether application accepts parameters from besumption state.
*
* On client side:
* Called when transport is applying psk from cache.
* Returns whether application will attempt early data based on params.
*
* Default implementation is provided for applications that do not require
* additional parameters in order to use 0-RTT.
*/
virtual bool validateEarlyDataAppParams(
const folly::Optional<std::string>& /* alpn */,
const Buf& /* appParams */) noexcept {
return true;
}
/**
* On server side:
* Called when transport is writing NewSessionTicket.
* Returns application parameters that will be included in NewSessionTicket.
*
* On client side:
* Called when client receives NewSessionTicket and is going to write to
* cache.
* Returns application parameters that will be written to cache.
*
* Default implementation is provided for applications that do not require
* additional parameters in order to use 0-RTT.
*/
virtual Buf serializeEarlyDataAppParams() noexcept {
return nullptr;
}
};
/**
* Information about the transport, similar to what TCP has.
*/
struct TransportInfo {
std::chrono::microseconds srtt;
std::chrono::microseconds rttvar;
std::chrono::microseconds lrtt;
std::chrono::microseconds mrtt;
uint64_t writableBytes;
uint64_t congestionWindow;
uint32_t packetsRetransmitted;
uint32_t timeoutBasedLoss;
std::chrono::microseconds rto;
uint64_t bytesSent;
uint64_t bytesRecvd;
uint64_t totalBytesRetransmitted;
uint32_t rtoCount;
uint32_t totalRTOCount;
PacketNum largestPacketAckedByPeer;
PacketNum largestPacketSent;
};
/**
* Information about the stream level transport info. Specific to QUIC.
*/
struct StreamTransportInfo {
// Total time the stream has spent in head-of-line blocked state,
// in microseconds
std::chrono::microseconds totalHeadOfLineBlockedTime{
std::chrono::microseconds::zero()};
// How many times the stream has entered the "head-of-line blocked" state
uint32_t holbCount{0};
// Is the stream head-of-line blocked?
bool isHolb{false};
};
/**
* Sets connection callback, must be set BEFORE using the socket.
*/
virtual void setConnectionCallback(ConnectionCallback* callback) = 0;
virtual ~QuicSocket() = default;
/**
* ===== Generic Socket Methods =====
*/
/**
* Get the QUIC Client Connection ID
*/
virtual folly::Optional<ConnectionId> getClientConnectionId() const = 0;
/**
* Get the QUIC Server Connection ID
*/
virtual folly::Optional<ConnectionId> getServerConnectionId() const = 0;
/**
* Get the peer socket address
*/
virtual const folly::SocketAddress& getPeerAddress() const = 0;
/**
* Get the original peer socket address
*/
virtual const folly::SocketAddress& getOriginalPeerAddress() const = 0;
/**
* Get the local socket address
*/
virtual const folly::SocketAddress& getLocalAddress() const = 0;
/**
* Determine if transport is open and ready to read or write.
*
* return true iff the transport is open and ready, false otherwise.
*/
virtual bool good() const = 0;
virtual bool replaySafe() const = 0;
/**
* Determine if an error has occurred with this transport.
*/
virtual bool error() const = 0;
/**
* Close this socket with a drain period. If closing with an error, it may be
* specified.
*/
virtual void close(
folly::Optional<std::pair<QuicErrorCode, std::string>> errorCode) = 0;
/**
* Close this socket gracefully, by waiting for all the streams to be idle
* first.
*/
virtual void closeGracefully() = 0;
/**
* Close this socket without a drain period. If closing with an error, it may
* be specified.
*/
virtual void closeNow(
folly::Optional<std::pair<QuicErrorCode, std::string>> errorCode) = 0;
/**
* Returns the event base associated with this socket
*/
virtual folly::EventBase* getEventBase() const = 0;
/**
* Returns the current offset already read or written by the application on
* the given stream.
*/
virtual folly::Expected<size_t, LocalErrorCode> getStreamReadOffset(
StreamId id) const = 0;
virtual folly::Expected<size_t, LocalErrorCode> getStreamWriteOffset(
StreamId id) const = 0;
/**
* Returns the amount of data buffered by the transport waiting to be written
*/
virtual folly::Expected<size_t, LocalErrorCode> getStreamWriteBufferedBytes(
StreamId id) const = 0;
/**
* Get internal transport info similar to TCP information.
*/
virtual TransportInfo getTransportInfo() const = 0;
/**
* Get internal transport info similar to TCP information.
* Returns LocalErrorCode::STREAM_NOT_EXISTS if the stream is not found
*/
virtual folly::Expected<StreamTransportInfo, LocalErrorCode>
getStreamTransportInfo(StreamId id) const = 0;
/**
* Get the negotiated ALPN. If called before the transport is ready
* returns folly::none
*/
virtual folly::Optional<std::string> getAppProtocol() const = 0;
/**
* Sets the size of the given stream's receive window, or the connection
* receive window if stream id is 0. If the window size increases, a
* window update will be sent to the peer. If it decreases, the transport
* will delay future window updates until the sender's available window is
* <= recvWindowSize.
*/
virtual void setReceiveWindow(StreamId id, size_t recvWindowSize) = 0;
/**
* Set the size of the transport send buffer for the given stream.
* The maximum total amount of buffer space is the sum of maxUnacked and
* maxUnsent. Bytes passed to writeChain count against unsent until the
* transport flushes them to the wire, after which they count against unacked.
*/
virtual void
setSendBuffer(StreamId id, size_t maxUnacked, size_t maxUnsent) = 0;
/**
* Get the flow control settings for the given stream (or connection flow
* control by passing id=0). Settings include send and receive window
* capacity and available.
*/
struct FlowControlState {
// Number of bytes the peer has allowed me to send.
uint64_t sendWindowAvailable;
// The max offset provided by the peer.
uint64_t sendWindowMaxOffset;
// Number of bytes I have allowed the peer to send.
uint64_t receiveWindowAvailable;
// The max offset I have provided to the peer.
uint64_t receiveWindowMaxOffset;
FlowControlState(
uint64_t sendWindowAvailableIn,
uint64_t sendWindowMaxOffsetIn,
uint64_t receiveWindowAvailableIn,
uint64_t receiveWindowMaxOffsetIn)
: sendWindowAvailable(sendWindowAvailableIn),
sendWindowMaxOffset(sendWindowMaxOffsetIn),
receiveWindowAvailable(receiveWindowAvailableIn),
receiveWindowMaxOffset(receiveWindowMaxOffsetIn) {}
};
/**
* Returns the current flow control windows for the connection.
* Use getStreamFlowControl for stream flow control window.
*/
virtual folly::Expected<FlowControlState, LocalErrorCode>
getConnectionFlowControl() const = 0;
/**
* Returns the current flow control windows for the stream, id != 0.
* Use getConnectionFlowControl for connection flow control window.
*/
virtual folly::Expected<FlowControlState, LocalErrorCode>
getStreamFlowControl(StreamId id) const = 0;
/**
* Sets the flow control window for the connection.
* Use setStreamFlowControlWindow for per Stream flow control.
*/
virtual folly::Expected<folly::Unit, LocalErrorCode>
setConnectionFlowControlWindow(uint64_t windowSize) = 0;
/**
* Sets the flow control window for the stream.
* Use setConnectionFlowControlWindow for connection flow control.
*/
virtual folly::Expected<folly::Unit, LocalErrorCode>
setStreamFlowControlWindow(StreamId id, uint64_t windowSize) = 0;
/**
* Settings for the transport. This takes effect only before the transport
* is connected.
*/
virtual void setTransportSettings(TransportSettings transportSettings) = 0;
/**
* Is partial reliability supported.
*/
virtual bool isPartiallyReliableTransport() const = 0;
/**
* ===== Read API ====
*/
/**
* Callback class for receiving data on a stream
*/
class ReadCallback {
public:
virtual ~ReadCallback() = default;
/**
* Called from the transport layer when there is data, EOF or an error
* available to read on the given stream ID
*/
virtual void readAvailable(StreamId id) noexcept = 0;
/**
* Called from the transport layer when there is an error on the stream.
*/
virtual void readError(
StreamId id,
std::pair<QuicErrorCode, folly::Optional<folly::StringPiece>>
error) noexcept = 0;
};
/**
* Set the read callback for the given stream. Note that read callback is
* expected to be set all the time. Removing read callback indicates that
* stream is no longer intended to be read again.
*/
virtual folly::Expected<folly::Unit, LocalErrorCode> setReadCallback(
StreamId id,
ReadCallback* cb) = 0;
/**
* Convenience function that sets the read callbacks of all streams to be
* nullptr.
*/
virtual void unsetAllReadCallbacks() = 0;
/**
* Convenience function that sets the read callbacks of all streams to be
* nullptr.
*/
virtual void unsetAllPeekCallbacks() = 0;
/**
* Convenience function that sets the read callbacks of all streams to be
* nullptr.
*/
virtual void unsetAllDeliveryCallbacks() = 0;
/**
* Invoke onCanceled on all the delivery callbacks registered for streamId.
*/
virtual void cancelDeliveryCallbacksForStream(StreamId streamId) = 0;
/**
* Pause/Resume read callback being triggered when data is available.
*/
virtual folly::Expected<folly::Unit, LocalErrorCode> pauseRead(
StreamId id) = 0;
virtual folly::Expected<folly::Unit, LocalErrorCode> resumeRead(
StreamId id) = 0;
/**
* Initiates sending of a StopSending frame for a given stream to the peer.
* This is called a "solicited reset". On receipt of the StopSending frame
* the peer should, but may not, send a ResetStream frame for the requested
* stream. A caller can use this function when they are no longer processing
* received data on the stream.
*/
virtual folly::Expected<folly::Unit, LocalErrorCode> stopSending(
StreamId id,
ApplicationErrorCode error) = 0;
/**
* Read from the given stream, up to maxLen bytes. If maxLen is 0, transport
* will return all available bytes.
*
* The return value is Expected. If the value hasError(), then a read error
* occured and it can be obtained with error(). If the value hasValue(), then
* value() returns a pair of the data (if any) and the EOF marker.
*
* Calling read() when there is no data/eof to deliver will return an
* EAGAIN-like error code.
*/
virtual folly::Expected<std::pair<Buf, bool>, LocalErrorCode> read(
StreamId id,
size_t maxLen) = 0;
/**
* ===== Peek/Consume API =====
*/
/**
* Usage:
* class Application {
* void onNewStream(StreamId id) {
* socket_->setPeekCallback(id, this);
* }
*
* virtual void onDataAvailable(
* StreamId id,
* const folly::Range<PeekIterator>& peekData) noexcept override
* {
* auto amount = tryInterpret(peekData);
* if (amount) {
* socket_->consume(id, amount);
* }
* }
* };
*/
using PeekIterator = std::deque<StreamBuffer>::const_iterator;
class PeekCallback {
public:
virtual ~PeekCallback() = default;
/**
* Called from the transport layer when there is new data available to
* peek on a given stream.
* Callback can be called multiple times and it is up to application to
* de-dupe already peeked ranges.
*/
virtual void onDataAvailable(
StreamId id,
const folly::Range<PeekIterator>& peekData) noexcept = 0;
};
virtual folly::Expected<folly::Unit, LocalErrorCode> setPeekCallback(
StreamId id,
PeekCallback* cb) = 0;
/**
* Pause/Resume peek callback being triggered when data is available.
*/
virtual folly::Expected<folly::Unit, LocalErrorCode> pausePeek(
StreamId id) = 0;
virtual folly::Expected<folly::Unit, LocalErrorCode> resumePeek(
StreamId id) = 0;
/**
* Peek at the given stream.
*
* The return value is Expected. If the value hasError(), then a read error
* occured and it can be obtained with error(). If the value hasValue(),
* indicates that peekCallback has been called.
*
* The range that is passed to callback is only valid until callback returns,
* If caller need to preserve data that range points to - that data has to
* be copied.
*
* Calling peek() when there is no data/eof to deliver will return an
* EAGAIN-like error code.
*/
virtual folly::Expected<folly::Unit, LocalErrorCode> peek(
StreamId id,
const folly::Function<void(StreamId id, const folly::Range<PeekIterator>&)
const>& peekCallback) = 0;
/**
* Consumes data on the given stream, starting from currentReadOffset
*
* The return value is Expected. If the value hasError(), then a read error
* occured and it can be obtained with error().
*
* @offset - represents start of consumed range.
* Current implementation returns error and currentReadOffset if offset !=
* currentReadOffset
*
* Calling consume() when there is no data/eof to deliver
* will return an EAGAIN-like error code.
*
*/
virtual folly::Expected<
folly::Unit,
std::pair<LocalErrorCode, folly::Optional<uint64_t>>>
consume(StreamId id, uint64_t offset, size_t amount) = 0;
/**
* Equivalent of calling consume(id, stream->currentReadOffset, amount);
*/
virtual folly::Expected<folly::Unit, LocalErrorCode> consume(
StreamId id,
size_t amount) = 0;
/**
* ===== Expire/reject API =====
*
* Sender can "expire" stream data by advancing receiver's minimum
* retransmittable offset, effectively discarding some of the previously
* sent (or planned to be sent) data.
* Discarded data may or may not have already arrived on the receiver end.
*
* Received can "reject" stream data by advancing sender's minimum
* retransmittable offset, effectively discarding some of the previously
* sent (or planned to be sent) data.
* Rejected data may or may not have already arrived on the receiver end.
*/
class DataExpiredCallback {
public:
virtual ~DataExpiredCallback() = default;
/**
* Called from the transport layer when sender informes us that data is
* expired on a given stream.
*/
virtual void onDataExpired(StreamId id, uint64_t newOffset) noexcept = 0;
};
virtual folly::Expected<folly::Unit, LocalErrorCode> setDataExpiredCallback(
StreamId id,
DataExpiredCallback* cb) = 0;
/**
* Expire data.
*
* The return value is Expected. If the value hasError(), then an error
* occured and it can be obtained with error().
*/
virtual folly::Expected<folly::Optional<uint64_t>, LocalErrorCode>
sendDataExpired(StreamId id, uint64_t offset) = 0;
class DataRejectedCallback {
public:
virtual ~DataRejectedCallback() = default;
/**
* Called from the transport layer when receiver informes us that data is
* not needed anymore on a given stream.
*/
virtual void onDataRejected(StreamId id, uint64_t newOffset) noexcept = 0;
};
virtual folly::Expected<folly::Unit, LocalErrorCode> setDataRejectedCallback(
StreamId id,
DataRejectedCallback* cb) = 0;
/**
* Reject data.
*
* The return value is Expected. If the value hasError(), then an error
* occured and it can be obtained with error().
*/
virtual folly::Expected<folly::Optional<uint64_t>, LocalErrorCode>
sendDataRejected(StreamId id, uint64_t offset) = 0;
/**
* ===== Write API =====
*/
/**
* Creates a bidirectional stream. This assigns a stream ID but does not
* send anything to the peer.
*
* If replaySafe is false, the transport will buffer (up to the send buffer
* limits) any writes on this stream until the transport is replay safe.
*/
virtual folly::Expected<StreamId, LocalErrorCode> createBidirectionalStream(
bool replaySafe = true) = 0;
/**
* Creates a unidirectional stream. This assigns a stream ID but does not
* send anything to the peer.
*
* If replaySafe is false, the transport will buffer (up to the send buffer
* limits) any writes on this stream until the transport is replay safe.
*/
virtual folly::Expected<StreamId, LocalErrorCode> createUnidirectionalStream(
bool replaySafe = true) = 0;
/**
* Returns the number of bidirectional streams that can be opened.
*/
virtual uint64_t getNumOpenableBidirectionalStreams() const = 0;
/**
* Returns the number of unidirectional streams that can be opened.
*/
virtual uint64_t getNumOpenableUnidirectionalStreams() const = 0;
/**
* Returns whether a stream ID represents a client-initiated stream.
*/
virtual bool isClientStream(StreamId stream) noexcept = 0;
/**
* Returns whether a stream ID represents a server-initiated stream.
*/
virtual bool isServerStream(StreamId stream) noexcept = 0;
/**
* Returns whether a stream ID represents a unidirectional stream.
*/
virtual bool isUnidirectionalStream(StreamId stream) noexcept = 0;
/**
* Returns whether a stream ID represents a bidirectional stream.
*/
virtual bool isBidirectionalStream(StreamId stream) noexcept = 0;
/**
* Callback class for receiving write readiness notifications
*/
class WriteCallback {
public:
virtual ~WriteCallback() = default;
/**
* Invoked when stream is ready to write after notifyPendingWriteOnStream
* has previously been called.
*
* maxToSend represents the amount of data that the transport layer expects
* to write to the network during this event loop, eg:
* min(remaining flow control, remaining send buffer space)
*/
virtual void onStreamWriteReady(
StreamId /* id */,
uint64_t /* maxToSend */) noexcept {}
/**
* Invoked when connection is ready to write after
* notifyPendingWriteOnConnection has previously been called.
*
* maxToSend represents the amount of data that the transport layer expects
* to write to the network during this event loop, eg:
* min(remaining flow control, remaining send buffer space)
*/
virtual void onConnectionWriteReady(uint64_t /* maxToSend */) noexcept {}
/**
* Invoked when a connection is being torn down after
* notifyPendingWriteOnStream has been called
*/
virtual void onStreamWriteError(
StreamId /* id */,
std::pair<QuicErrorCode, folly::Optional<folly::StringPiece>>
/* error */) noexcept {}
/**
* Invoked when a connection is being torn down after
* notifyPendingWriteOnConnection has been called
*/
virtual void onConnectionWriteError(
std::pair<QuicErrorCode, folly::Optional<folly::StringPiece>>
/* error */) noexcept {}
};
/**
* Inform the transport that there is data to write on this connection
* An app shouldn't mix connection and stream calls to this API
* Use this if the app wants to do prioritization.
*/
virtual folly::Expected<folly::Unit, LocalErrorCode>
notifyPendingWriteOnConnection(WriteCallback* wcb) = 0;
/**
* Inform the transport that there is data to write on a given stream.
* An app shouldn't mix connection and stream calls to this API
* Use the Connection call if the app wants to do prioritization.
*/
virtual folly::Expected<folly::Unit, LocalErrorCode>
notifyPendingWriteOnStream(StreamId id, WriteCallback* wcb) = 0;
/**
* Callback class for receiving ack notifications
*/
class DeliveryCallback {
public:
virtual ~DeliveryCallback() = default;
/**
* Invoked when the peer has acknowledged the receipt of the specificed
* offset. rtt is the current RTT estimate for the connection.
*/
virtual void onDeliveryAck(
StreamId id,
uint64_t offset,
std::chrono::microseconds rtt) = 0;
/**
* Invoked on registered delivery callbacks when the bytes will never be
* delivered (due to a reset or other error).
*/
virtual void onCanceled(StreamId id, uint64_t offset) = 0;
};
/**
* Write data/eof to the given stream.
*
* cork indicates to the transport that the application expects to write
* more data soon. Passing a delivery callback registers a callback from the
* transport when the peer has acknowledged the receipt of all the data/eof
* passed to write.
*
* A returned IOBuf indicates that the passed data exceeded the transport
* flow control window or send buffer space. The application must call write
* with this data again later, and should be a signal to apply backpressure.
* If EOF was true or a delivery callback was set they also need to be
* passed again later. See notifyPendingWrite to register for a callback.
*
* An error code is present if there was an error with the write.
*/
using WriteResult = folly::Expected<Buf, LocalErrorCode>;
virtual WriteResult writeChain(
StreamId id,
Buf data,
bool eof,
bool cork,
DeliveryCallback* cb = nullptr) = 0;
/**
* Register a callback to be invoked when the peer has acknowledged the
* given offset on the given stream
*/
virtual folly::Expected<folly::Unit, LocalErrorCode> registerDeliveryCallback(
StreamId id,
uint64_t offset,
DeliveryCallback* cb) = 0;
/**
* Close the stream for writing. Equivalent to writeChain(id, nullptr, true).
*/
virtual folly::Optional<LocalErrorCode> shutdownWrite(StreamId id) = 0;
/**
* Cancel the given stream
*/
virtual folly::Expected<folly::Unit, LocalErrorCode> resetStream(
StreamId id,
ApplicationErrorCode error) = 0;
/**
* Callback class for pings
*/
class PingCallback {
public:
virtual ~PingCallback() = default;
/**
* Invoked when the ping is acknowledged
*/
virtual void pingAcknowledged() noexcept = 0;
/**
* Invoked if the ping times out
*/
virtual void pingTimeout() noexcept = 0;
};
/**
* Send a ping to the peer. When the ping is acknowledged by the peer or
* times out, the transport will invoke the callback. The callback may be
* nullptr.
*/
virtual void sendPing(
PingCallback* callback,
std::chrono::milliseconds pingTimeout) = 0;
/**
* Get information on the state of the quic connection. Should only be used
* for logging.
*/
virtual const QuicConnectionStateBase* getState() const = 0;
/**
* Detaches the eventbase from the socket. This must be called from the
* eventbase of socket.
* Normally this is invoked by an app when the connection is idle, i.e.
* there are no "active" streams on the connection, however an app might
* think that all the streams are closed because it wrote the FIN
* to the QuicSocket, however the QuicSocket might not have delivered the FIN
* to the peer yet. Apps SHOULD use the delivery callback to make sure that
* all writes for the closed stream are finished before detaching the
* connection from the eventbase.
*/
virtual void detachEventBase() = 0;
/**
* Attaches an eventbase to the socket. This must be called from the
* eventbase that needs to be attached and the caller must make sure that
* there is no eventbase already attached to the socket.
*/
virtual void attachEventBase(folly::EventBase* evb) = 0;
/**
* Returns whether or not the eventbase can currently be detached from the
* socket.
*/
virtual bool isDetachable() = 0;
/**
* Signal the transport that a certain stream is a control stream.
* A control stream outlives all the other streams in a connection, therefore,
* if the transport knows about it, can enable some optimizations.
* Applications should declare all their control streams after either calling
* createStream() or receiving onNewStream()
*/
virtual folly::Optional<LocalErrorCode> setControlStream(StreamId id) = 0;
};
} // namespace quic

View File

@ -0,0 +1,436 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#pragma once
#include <folly/io/Cursor.h>
#include <folly/io/async/AsyncTransport.h>
#include <quic/api/QuicSocket.h>
/**
* Adaptor for multiplexing over quic an existing use-case that
* expects and AsyncTransportWrapper
*/
namespace quic {
class QuicStreamAsyncTransport : public folly::AsyncTransportWrapper,
public QuicSocket::ReadCallback,
public QuicSocket::WriteCallback,
public folly::EventBase::LoopCallback {
public:
using UniquePtr = std::unique_ptr<
QuicStreamAsyncTransport,
folly::DelayedDestruction::Destructor>;
void setReadCB(AsyncTransportWrapper::ReadCallback* callback) override {
readCb_ = callback;
// It should be ok to do this immediately, rather than in the loop
handleRead();
}
AsyncTransportWrapper::ReadCallback* getReadCallback() const override {
return readCb_;
}
void write(
AsyncTransportWrapper::WriteCallback* callback,
const void* buf,
size_t bytes,
folly::WriteFlags /*flags*/ = folly::WriteFlags::NONE) override {
// TODO handle cork
auto streamWriteOffset = sock_->getStreamWriteOffset(id_);
if (streamWriteOffset.hasError()) {
folly::AsyncSocketException ex(
folly::AsyncSocketException::UNKNOWN,
folly::to<std::string>(
"Quic write error: ", toString(streamWriteOffset.error())));
callback->writeErr(0, ex);
return;
}
writeBuf_.append(folly::IOBuf::wrapBuffer(buf, bytes));
writeCallbacks_.emplace_back(*streamWriteOffset + bytes, callback);
sock_->notifyPendingWriteOnStream(id_, this);
}
void writev(
AsyncTransportWrapper::WriteCallback* callback,
const iovec* vec,
size_t count,
folly::WriteFlags /*flags*/ = folly::WriteFlags::NONE) override {
size_t totalBytes = 0;
for (size_t i = 0; i < count; i++) {
writeBuf_.append(
folly::IOBuf::wrapBuffer(vec[i].iov_base, vec[i].iov_len));
totalBytes += vec[i].iov_len;
}
auto streamWriteOffset = sock_->getStreamWriteOffset(id_);
if (streamWriteOffset.hasError()) {
folly::AsyncSocketException ex(
folly::AsyncSocketException::UNKNOWN,
folly::to<std::string>(
"Quic write error: ", toString(streamWriteOffset.error())));
callback->writeErr(0, ex);
return;
}
writeCallbacks_.emplace_back(*streamWriteOffset + totalBytes, callback);
sock_->notifyPendingWriteOnStream(id_, this);
}
void writeChain(
AsyncTransportWrapper::WriteCallback* callback,
std::unique_ptr<folly::IOBuf>&& buf,
folly::WriteFlags /*flags*/ = folly::WriteFlags::NONE) override {
auto streamWriteOffset = sock_->getStreamWriteOffset(id_);
if (streamWriteOffset.hasError()) {
folly::AsyncSocketException ex(
folly::AsyncSocketException::UNKNOWN,
folly::to<std::string>(
"Quic write error: ", toString(streamWriteOffset.error())));
callback->writeErr(0, ex);
return;
}
writeCallbacks_.emplace_back(
*streamWriteOffset + buf->computeChainDataLength(), callback);
writeBuf_.append(std::move(buf));
sock_->notifyPendingWriteOnStream(id_, this);
}
void close() override {
shutdownWrite();
if (readCb_ && readEOF_ != EOFState::DELIVERED) {
// This is such a bizarre operation. I almost think if we haven't seen
// a fin then we should readErr instead of readEOF, this mirrors
// AsyncSocket though
readEOF_ = EOFState::QUEUED;
handleRead();
}
}
void closeNow() override {
if (writeBuf_.empty()) {
close();
} else {
sock_->resetStream(id_, quic::ApplicationErrorCode::STOPPING);
VLOG(4) << "Reset stream from closeNow";
}
}
void closeWithReset() override {
sock_->resetStream(id_, quic::ApplicationErrorCode::STOPPING);
VLOG(4) << "Reset stream from closeWithReset";
}
void shutdownWrite() override {
if (writeEOF_ == EOFState::NOT_SEEN) {
writeEOF_ = EOFState::QUEUED;
sock_->notifyPendingWriteOnStream(id_, this);
}
}
void shutdownWriteNow() override {
if (readEOF_ == EOFState::DELIVERED) {
// writes already shutdown
return;
}
if (writeBuf_.empty()) {
shutdownWrite();
} else {
sock_->resetStream(id_, quic::ApplicationErrorCode::STOPPING);
VLOG(4) << "Reset stream from shutdownWriteNow";
}
}
bool good() const override {
return (
!ex_ &&
(readEOF_ == EOFState::NOT_SEEN || writeEOF_ == EOFState::NOT_SEEN));
}
bool readable() const override {
return !ex_ && readEOF_ == EOFState::NOT_SEEN;
}
bool writable() const override {
return !ex_ && writeEOF_ == EOFState::NOT_SEEN;
}
bool isPending() const override {
return false;
}
bool connecting() const override {
return false;
}
virtual bool error() const override {
return bool(ex_);
}
folly::EventBase* getEventBase() const override {
return sock_->getEventBase();
}
void attachEventBase(folly::EventBase* /*eventBase*/) override {
LOG(FATAL) << "Does QUICSocket support this?";
}
void detachEventBase() override {
LOG(FATAL) << "Does QUICSocket support this?";
}
bool isDetachable() const override {
return false; // ?
}
void setSendTimeout(uint32_t /*milliseconds*/) override {
// QuicSocket needs this
}
uint32_t getSendTimeout() const override {
return 0;
}
void getLocalAddress(folly::SocketAddress* /*address*/) const override {
// QuicSocket needs this
}
void getPeerAddress(folly::SocketAddress* /*address*/) const override {
// QuicSocket needs this
}
bool isEorTrackingEnabled() const override {
return false;
}
void setEorTracking(bool /*track*/) override {}
size_t getAppBytesWritten() const override {
auto res = sock_->getStreamWriteOffset(id_);
return res.hasError() ? 0 : res.value();
}
size_t getRawBytesWritten() const override {
auto res = sock_->getStreamWriteOffset(id_);
return res.hasError() ? 0 : res.value();
}
size_t getAppBytesReceived() const override {
auto res = sock_->getStreamReadOffset(id_);
return res.hasError() ? 0 : res.value();
}
size_t getRawBytesReceived() const override {
auto res = sock_->getStreamReadOffset(id_);
return res.hasError() ? 0 : res.value();
}
std::string getApplicationProtocol() const noexcept override {
return "h1q";
}
std::string getSecurityProtocol() const override {
return "quic/tls1.3";
}
QuicStreamAsyncTransport(
std::shared_ptr<quic::QuicSocket> sock,
quic::StreamId id)
: sock_(std::move(sock)), id_(id) {}
~QuicStreamAsyncTransport() {
sock_->setReadCallback(id_, nullptr);
closeNow();
}
private:
void readAvailable(quic::StreamId /*streamId*/) noexcept override {
CHECK(readCb_);
// defer the actual read until the loop callback. This prevents possible
// tail recursion with readAvailable -> setReadCallback -> readAvailable
sock_->getEventBase()->runInLoop(this, true);
}
void readError(
quic::StreamId /*streamId*/,
std::pair<quic::QuicErrorCode, folly::Optional<folly::StringPiece>>
error) noexcept override {
ex_ = folly::AsyncSocketException(
folly::AsyncSocketException::UNKNOWN,
folly::to<std::string>("Quic read error: ", toString(error)));
sock_->getEventBase()->runInLoop(this, true);
}
void runLoopCallback() noexcept override {
handleRead();
}
void handleRead() {
folly::DelayedDestruction::DestructorGuard dg(this);
bool emptyRead = false;
size_t numReads = 0;
while (readCb_ && !ex_ && readEOF_ == EOFState::NOT_SEEN && !emptyRead &&
++numReads < 16 /* max reads per event */) {
void* buf = nullptr;
size_t len = 0;
if (readCb_->isBufferMovable()) {
len = readCb_->maxBufferSize();
} else {
readCb_->getReadBuffer(&buf, &len);
if (buf == nullptr || len == 0) {
ex_ = folly::AsyncSocketException(
folly::AsyncSocketException::BAD_ARGS,
"ReadCallback::getReadBuffer() returned empty buffer");
break;
}
}
auto readData = sock_->read(id_, len);
if (readData.hasError()) {
ex_ = folly::AsyncSocketException(
folly::AsyncSocketException::UNKNOWN,
folly::to<std::string>("Quic read error: ", (int)readData.error()));
} else {
size_t readLen = 0;
if (readData->first) {
readLen = readData->first->computeChainDataLength();
emptyRead = (readLen == 0);
} else {
emptyRead = true;
}
if (!emptyRead) {
if (readCb_->isBufferMovable()) {
readCb_->readBufferAvailable(std::move(readData->first));
} else {
folly::io::Cursor c(readData->first.get());
c.pull(buf, readLen);
readCb_->readDataAvailable(readLen);
}
}
if (readData->second && readEOF_ == EOFState::NOT_SEEN) {
readEOF_ = EOFState::QUEUED;
}
}
}
if (readCb_) {
if (ex_) {
auto cb = readCb_;
readCb_ = nullptr;
cb->readErr(*ex_);
} else if (readEOF_ == EOFState::QUEUED) {
auto cb = readCb_;
readCb_ = nullptr;
cb->readEOF();
readEOF_ = EOFState::DELIVERED;
}
}
if (readCb_ && readEOF_ == EOFState::NOT_SEEN && !ex_) {
sock_->setReadCallback(id_, this);
} else {
sock_->setReadCallback(id_, nullptr);
}
}
void send(uint64_t maxToSend) {
// overkill until there are delivery cbs
folly::DelayedDestruction::DestructorGuard dg(this);
uint64_t toSend =
std::min(maxToSend, folly::to<uint64_t>(writeBuf_.chainLength()));
auto streamWriteOffset = sock_->getStreamWriteOffset(id_);
if (streamWriteOffset.hasError()) {
// handle error
folly::AsyncSocketException ex(
folly::AsyncSocketException::UNKNOWN,
folly::to<std::string>(
"Quic write error: ", toString(streamWriteOffset.error())));
failWrites(ex);
return;
}
uint64_t sentOffset = *streamWriteOffset + toSend;
bool writeEOF = (writeEOF_ == EOFState::QUEUED);
auto res = sock_->writeChain(
id_,
writeBuf_.split(toSend),
writeEOF,
false,
nullptr); // no delivery callbacks right now
if (res.hasValue()) {
if (res.value()) {
sentOffset -= res.value()->computeChainDataLength();
auto tail = writeBuf_.move();
writeBuf_.append(std::move(res.value()));
writeBuf_.append(std::move(tail));
} else if (writeEOF) {
writeEOF_ = EOFState::DELIVERED;
VLOG(4) << "Closed stream id_=" << id_;
}
// not actually sent. Mirrors AsyncSocket and invokes when data is in
// transport buffers
invokeWriteCallbacks(sentOffset);
} else {
// handle error
folly::AsyncSocketException ex(
folly::AsyncSocketException::UNKNOWN,
folly::to<std::string>("Quic write error: ", toString(res.error())));
failWrites(ex);
}
}
void invokeWriteCallbacks(size_t sentOffset) {
while (!writeCallbacks_.empty() &&
writeCallbacks_.front().first <= sentOffset) {
auto wcb = writeCallbacks_.front().second;
writeCallbacks_.pop_front();
wcb->writeSuccess();
}
}
void failWrites(folly::AsyncSocketException& ex) {
while (!writeCallbacks_.empty()) {
auto& front = writeCallbacks_.front();
auto wcb = front.second;
writeCallbacks_.pop_front();
wcb->writeErr(0, ex);
}
}
void onStreamWriteReady(
quic::StreamId /*id*/,
uint64_t maxToSend) noexcept override {
if (writeEOF_ == EOFState::DELIVERED && writeBuf_.empty()) {
// nothing left to write
return;
}
send(maxToSend);
}
void onStreamWriteError(
StreamId /*id*/,
std::pair<quic::QuicErrorCode, folly::Optional<folly::StringPiece>>
error) noexcept override {
folly::AsyncSocketException ex(
folly::AsyncSocketException::UNKNOWN,
folly::to<std::string>("Quic write error: ", toString(error)));
failWrites(ex);
}
std::shared_ptr<quic::QuicSocket> sock_;
quic::StreamId id_;
enum class EOFState { NOT_SEEN, QUEUED, DELIVERED };
EOFState readEOF_{EOFState::NOT_SEEN};
EOFState writeEOF_{EOFState::NOT_SEEN};
AsyncTransportWrapper::ReadCallback* readCb_{nullptr};
folly::IOBufQueue writeBuf_{folly::IOBufQueue::cacheChainLength()};
std::deque<std::pair<size_t, AsyncTransportWrapper::WriteCallback*>>
writeCallbacks_;
folly::Optional<folly::AsyncSocketException> ex_;
};
} // namespace quic

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,574 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#pragma once
#include <folly/ExceptionWrapper.h>
#include <folly/io/async/AsyncUDPSocket.h>
#include <folly/io/async/HHWheelTimer.h>
#include <quic/QuicException.h>
#include <quic/api/QuicSocket.h>
#include <quic/common/FunctionLooper.h>
#include <quic/common/Timers.h>
#include <quic/congestion_control/CongestionControllerFactory.h>
#include <quic/congestion_control/Copa.h>
#include <quic/congestion_control/NewReno.h>
#include <quic/congestion_control/QuicCubic.h>
#include <quic/state/StateData.h>
namespace quic {
enum class CloseState { OPEN, GRACEFUL_CLOSING, CLOSED };
/**
* Base class for the QUIC Transport. Implements common behavior for both
* clients and servers. QuicTransportBase assumes the following:
* 1. It is intended to be sub-classed and used via the subclass directly.
* 2. Assumes that the sub-class manages its ownership via a shared_ptr.
* This is needed in order for QUIC to be able to live beyond the lifetime
* of the object that holds it to send graceful close messages to the peer.
*/
class QuicTransportBase : public QuicSocket {
public:
QuicTransportBase(
folly::EventBase* evb,
std::unique_ptr<folly::AsyncUDPSocket> socket);
~QuicTransportBase() override;
void setPacingTimer(TimerHighRes::SharedPtr pacingTimer) noexcept;
folly::EventBase* getEventBase() const override;
folly::Optional<ConnectionId> getClientConnectionId() const override;
folly::Optional<ConnectionId> getServerConnectionId() const override;
const folly::SocketAddress& getPeerAddress() const override;
const folly::SocketAddress& getOriginalPeerAddress() const override;
const folly::SocketAddress& getLocalAddress() const override;
// QuicSocket interface
bool good() const override;
bool replaySafe() const override;
bool error() const override;
void close(
folly::Optional<std::pair<QuicErrorCode, std::string>> error) override;
void closeGracefully() override;
void closeNow(
folly::Optional<std::pair<QuicErrorCode, std::string>> error) override;
folly::Expected<size_t, LocalErrorCode> getStreamReadOffset(
StreamId id) const override;
folly::Expected<size_t, LocalErrorCode> getStreamWriteOffset(
StreamId id) const override;
folly::Expected<size_t, LocalErrorCode> getStreamWriteBufferedBytes(
StreamId id) const override;
TransportInfo getTransportInfo() const override;
folly::Expected<StreamTransportInfo, LocalErrorCode> getStreamTransportInfo(
StreamId id) const override;
folly::Optional<std::string> getAppProtocol() const override;
void setReceiveWindow(StreamId id, size_t recvWindowSize) override;
void setSendBuffer(StreamId id, size_t maxUnacked, size_t maxUnsent) override;
uint64_t bufferSpaceAvailable();
folly::Expected<QuicSocket::FlowControlState, LocalErrorCode>
getConnectionFlowControl() const override;
folly::Expected<QuicSocket::FlowControlState, LocalErrorCode>
getStreamFlowControl(StreamId id) const override;
folly::Expected<folly::Unit, LocalErrorCode> setConnectionFlowControlWindow(
uint64_t windowSize) override;
folly::Expected<folly::Unit, LocalErrorCode> setStreamFlowControlWindow(
StreamId id,
uint64_t windowSize) override;
folly::Expected<folly::Unit, LocalErrorCode> setReadCallback(
StreamId id,
ReadCallback* cb) override;
void unsetAllReadCallbacks() override;
void unsetAllPeekCallbacks() override;
void unsetAllDeliveryCallbacks() override;
folly::Expected<folly::Unit, LocalErrorCode> pauseRead(StreamId id) override;
folly::Expected<folly::Unit, LocalErrorCode> resumeRead(StreamId id) override;
folly::Expected<folly::Unit, LocalErrorCode> stopSending(
StreamId id,
ApplicationErrorCode error) override;
folly::Expected<std::pair<Buf, bool>, LocalErrorCode> read(
StreamId id,
size_t maxLen) override;
folly::Expected<folly::Unit, LocalErrorCode> setPeekCallback(
StreamId id,
PeekCallback* cb) override;
folly::Expected<folly::Unit, LocalErrorCode> pausePeek(StreamId id) override;
folly::Expected<folly::Unit, LocalErrorCode> resumePeek(StreamId id) override;
folly::Expected<folly::Unit, LocalErrorCode> peek(
StreamId id,
const folly::Function<void(StreamId id, const folly::Range<PeekIterator>&)
const>& peekCallback) override;
folly::Expected<folly::Unit, LocalErrorCode> consume(
StreamId id,
size_t amount) override;
folly::Expected<
folly::Unit,
std::pair<LocalErrorCode, folly::Optional<uint64_t>>>
consume(StreamId id, uint64_t offset, size_t amount) override;
folly::Expected<folly::Unit, LocalErrorCode> setDataExpiredCallback(
StreamId id,
DataExpiredCallback* cb) override;
folly::Expected<folly::Optional<uint64_t>, LocalErrorCode> sendDataExpired(
StreamId id,
uint64_t offset) override;
folly::Expected<folly::Unit, LocalErrorCode> setDataRejectedCallback(
StreamId id,
DataRejectedCallback* cb) override;
folly::Expected<folly::Optional<uint64_t>, LocalErrorCode> sendDataRejected(
StreamId id,
uint64_t offset) override;
folly::Expected<StreamId, LocalErrorCode> createBidirectionalStream(
bool replaySafe = true) override;
folly::Expected<StreamId, LocalErrorCode> createUnidirectionalStream(
bool replaySafe = true) override;
uint64_t getNumOpenableBidirectionalStreams() const override;
uint64_t getNumOpenableUnidirectionalStreams() const override;
bool isClientStream(StreamId stream) noexcept override;
bool isServerStream(StreamId stream) noexcept override;
bool isUnidirectionalStream(StreamId stream) noexcept override;
bool isBidirectionalStream(StreamId stream) noexcept override;
folly::Expected<folly::Unit, LocalErrorCode> notifyPendingWriteOnStream(
StreamId id,
WriteCallback* wcb) override;
folly::Expected<folly::Unit, LocalErrorCode> notifyPendingWriteOnConnection(
WriteCallback* wcb) override;
WriteResult writeChain(
StreamId id,
Buf data,
bool eof,
bool cork,
DeliveryCallback* cb = nullptr) override;
folly::Expected<folly::Unit, LocalErrorCode> registerDeliveryCallback(
StreamId id,
uint64_t offset,
DeliveryCallback* cb) override;
folly::Optional<LocalErrorCode> shutdownWrite(StreamId id) override;
folly::Expected<folly::Unit, LocalErrorCode> resetStream(
StreamId id,
ApplicationErrorCode errorCode) override;
void sendPing(PingCallback* callback, std::chrono::milliseconds pingTimeout)
override;
const QuicConnectionStateBase* getState() const override {
return conn_.get();
}
// Interface with the Transport layer when data is available.
// This is invoked when new data is received from the UDP socket.
virtual void onNetworkData(
const folly::SocketAddress& peer,
NetworkData&& data) noexcept;
// TODO: move only to the server api.
virtual void setSupportedVersions(const std::vector<QuicVersion>& versions);
virtual void setConnectionCallback(
ConnectionCallback* callback) override final;
bool isDetachable() override;
void detachEventBase() override;
void attachEventBase(folly::EventBase* evb) override;
folly::Optional<LocalErrorCode> setControlStream(StreamId id) override;
// Updates the congestion controller app limited state, after a change in the
// number of streams.
// App limited state is set to true if there was at least one non-control
// before the update and there are none after. It is set to false if instead
// there were no non-control streams before and there is at least one at the
// time of calling
void updateAppLimitedState(bool hadNonCtrlStreams);
/**
* Invoke onCanceled for all the delivery callbacks in the deliveryCallbacks
* passed in. This is supposed to be a copy of the real deque of the delivery
* callbacks for the stream, so there is no need to pop anything off of it.
*/
static void cancelDeliveryCallbacks(
StreamId id,
const std::deque<std::pair<uint64_t, QuicSocket::DeliveryCallback*>>&
deliveryCallbacks);
/**
* Invoke onCanceled for all the delivery callbacks in the deliveryCallbacks
* map. This is supposed to be a copy of the real map of the delivery
* callbacks of the transport, so there is no need to erase anything from it.
*/
static void cancelDeliveryCallbacks(
const std::unordered_map<
StreamId,
std::deque<std::pair<uint64_t, QuicSocket::DeliveryCallback*>>>&
deliveryCallbacks);
/**
* Set the initial flow control window for the connection.
*/
void setTransportSettings(TransportSettings transportSettings) override;
/**
* Set factory to create specific congestion controller instances
* for a given connection
*/
virtual void setCongestionControllerFactory(
std::shared_ptr<CongestionControllerFactory> factory);
// Subclass API.
/**
* Invoked when a new packet is read from the network.
* peer is the address of the peer that was in the packet.
* The sub-class may throw an exception if there was an error in processing
* the packet in which case the connection will be closed.
*/
virtual void onReadData(
const folly::SocketAddress& peer,
NetworkData&& networkData) = 0;
/**
* Invoked when we have to write some data to the wire.
* The subclass may use this to start writing data to the socket.
* It may also throw an exception in case of an error in which case the
* connection will be closed.
*/
virtual void writeData() = 0;
/**
* closeTransport is invoked on the sub-class when the transport is closed.
* The sub-class may clean up any state during this call. The transport
* may still be draining after this call.
*/
virtual void closeTransport() = 0;
/**
* Invoked after the drain timeout has exceeded and the connection state will
* be destroyed.
*/
virtual void unbindConnection() = 0;
/**
* Returns whether or not the connection has a write cipher. This will be used
* to decide to return the onTransportReady() callbacks.
*/
virtual bool hasWriteCipher() const = 0;
/**
* Returns a shared_ptr which can be used as a guard to keep this
* object alive.
*/
virtual std::shared_ptr<QuicTransportBase> sharedGuard() = 0;
bool isPartiallyReliableTransport() const override;
/**
* Invoke onCanceled on all the delivery callbacks registered for streamId.
*/
void cancelDeliveryCallbacksForStream(StreamId streamId) override;
// Timeout functions
class LossTimeout : public folly::HHWheelTimer::Callback {
public:
~LossTimeout() override = default;
explicit LossTimeout(QuicTransportBase* transport)
: transport_(transport) {}
void timeoutExpired() noexcept override {
transport_->lossTimeoutExpired();
}
virtual void callbackCanceled() noexcept override {
// ignore. this usually means that the eventbase is dying, so we will be
// canceled anyway
return;
}
private:
QuicTransportBase* transport_;
};
class AckTimeout : public folly::HHWheelTimer::Callback {
public:
~AckTimeout() override = default;
explicit AckTimeout(QuicTransportBase* transport) : transport_(transport) {}
void timeoutExpired() noexcept override {
transport_->ackTimeoutExpired();
}
virtual void callbackCanceled() noexcept override {
// ignore. this usually means that the eventbase is dying, so we will be
// canceled anyway
return;
}
private:
QuicTransportBase* transport_;
};
class PathValidationTimeout : public folly::HHWheelTimer::Callback {
public:
~PathValidationTimeout() override = default;
explicit PathValidationTimeout(QuicTransportBase* transport)
: transport_(transport) {}
void timeoutExpired() noexcept override {
transport_->pathValidationTimeoutExpired();
}
virtual void callbackCanceled() noexcept override {
// ignore. this usually means that the eventbase is dying, so we will be
// canceled anyway
return;
}
private:
QuicTransportBase* transport_;
};
class IdleTimeout : public folly::HHWheelTimer::Callback {
public:
~IdleTimeout() override = default;
explicit IdleTimeout(QuicTransportBase* transport)
: transport_(transport) {}
void timeoutExpired() noexcept override {
transport_->idleTimeoutExpired(true /* drain */);
}
void callbackCanceled() noexcept override {
// skip drain when canceling the timeout, to avoid scheduling a new
// drain timeout
transport_->idleTimeoutExpired(false /* drain */);
}
private:
QuicTransportBase* transport_;
};
// DrainTimeout is a bit different from other timeouts. It needs to hold a
// shared_ptr to the transport, since if a DrainTimeout is scheduled,
// transport cannot die.
class DrainTimeout : public folly::HHWheelTimer::Callback {
public:
~DrainTimeout() override = default;
explicit DrainTimeout(QuicTransportBase* transport)
: transport_(transport) {}
void timeoutExpired() noexcept override {
transport_->drainTimeoutExpired();
}
private:
QuicTransportBase* transport_;
};
void scheduleLossTimeout(std::chrono::milliseconds timeout);
void cancelLossTimeout();
bool isLossTimeoutScheduled() const;
// If you don't set it, the default is Cubic
void setCongestionControl(CongestionControlType type);
void describe(std::ostream& os) const;
void setLogger(std::shared_ptr<Logger> logger) {
conn_->logger = std::move(logger);
}
virtual void cancelAllAppCallbacks(
std::pair<QuicErrorCode, std::string> error) noexcept;
protected:
void processCallbacksAfterNetworkData();
void invokeReadDataAndCallbacks();
void invokePeekDataAndCallbacks();
void invokeDataExpiredCallbacks();
void invokeDataRejectedCallbacks();
void updateReadLooper();
void updatePeekLooper();
void updateWriteLooper(bool thisIteration);
void runOnEvbAsync(
folly::Function<void(std::shared_ptr<QuicTransportBase>)> func);
void closeImpl(
folly::Optional<std::pair<QuicErrorCode, std::string>> error,
bool drainConnection = true,
bool sendCloseImmediately = true);
folly::Expected<folly::Unit, LocalErrorCode> pauseOrResumeRead(
StreamId id,
bool resume);
folly::Expected<folly::Unit, LocalErrorCode> pauseOrResumePeek(
StreamId id,
bool resume);
void checkForClosedStream();
folly::Expected<folly::Unit, LocalErrorCode> setReadCallbackInternal(
StreamId id,
ReadCallback* cb) noexcept;
void setPeekCallbackInternal(StreamId id, PeekCallback* cb) noexcept;
folly::Expected<StreamId, LocalErrorCode> createStreamInternal(
bool bidirectional);
/**
* write data to socket
*
* At transport layer, this is the simplest form of write. It writes data
* out to the network, and schedule necessary timers (ack, idle, loss). It is
* both pacing oblivious and writeLooper oblivious. Caller needs to explicitly
* invoke updateWriteLooper afterwards if that's desired.
*/
void writeSocketData();
/**
* A wrapper around writeSocketData
*
* writeSocketDataAndCatch protects writeSocketData in a try-catch. It also
* dispatch the next write loop.
*/
void writeSocketDataAndCatch();
/**
* Paced write data to socket when connection is paced.
*
* Whether connection is based will be decided by TransportSettings and
* congection controller. When the connection is paced, this function writes
* out a burst size of packets and let the writeLooper schedule a callback to
* write another burst after a pacing interval if there are more data to
* write. When the connection isn't paced, this function do a normal write.
*/
void pacedWriteDataToSocket(bool fromTimer);
uint64_t maxWritableOnStream(const QuicStreamState&);
uint64_t maxWritableOnConn();
void lossTimeoutExpired() noexcept;
void ackTimeoutExpired() noexcept;
void pathValidationTimeoutExpired() noexcept;
void idleTimeoutExpired(bool drain) noexcept;
void drainTimeoutExpired() noexcept;
void setIdleTimer();
void scheduleAckTimeout();
void schedulePathValidationTimeout();
std::atomic<folly::EventBase*> evb_;
std::unique_ptr<folly::AsyncUDPSocket> socket_;
ConnectionCallback* connCallback_{nullptr};
std::unique_ptr<QuicConnectionStateBase> conn_;
struct ReadCallbackData {
ReadCallback* readCb;
bool resumed{true};
bool deliveredEOM{false};
ReadCallbackData(ReadCallback* readCallback) : readCb(readCallback) {}
};
struct PeekCallbackData {
PeekCallback* peekCb;
bool resumed{true};
PeekCallbackData(PeekCallback* peekCallback) : peekCb(peekCallback) {}
};
struct DataExpiredCallbackData {
DataExpiredCallback* dataExpiredCb;
bool resumed{true};
DataExpiredCallbackData(DataExpiredCallback* cb) : dataExpiredCb(cb) {}
};
struct DataRejectedCallbackData {
DataRejectedCallback* dataRejectedCb;
bool resumed{true};
DataRejectedCallbackData(DataRejectedCallback* cb) : dataRejectedCb(cb) {}
};
// Map of streamID to tupl
std::unordered_map<StreamId, ReadCallbackData> readCallbacks_;
std::unordered_map<StreamId, PeekCallbackData> peekCallbacks_;
std::unordered_map<
StreamId,
std::deque<std::pair<uint64_t, DeliveryCallback*>>>
deliveryCallbacks_;
std::unordered_map<StreamId, DataExpiredCallbackData> dataExpiredCallbacks_;
std::unordered_map<StreamId, DataRejectedCallbackData> dataRejectedCallbacks_;
WriteCallback* connWriteCallback_{nullptr};
std::map<StreamId, WriteCallback*> pendingWriteCallbacks_;
CloseState closeState_{CloseState::OPEN};
bool transportReadyNotified_{false};
LossTimeout lossTimeout_;
AckTimeout ackTimeout_;
PathValidationTimeout pathValidationTimeout_;
IdleTimeout idleTimeout_;
DrainTimeout drainTimeout_;
FunctionLooper::Ptr readLooper_;
FunctionLooper::Ptr peekLooper_;
FunctionLooper::Ptr writeLooper_;
// TODO: This is silly. We need a better solution.
// Uninitialied local address as a fallback answer when socket isn't bound.
folly::SocketAddress localFallbackAddress;
// CongestionController factory
std::shared_ptr<CongestionControllerFactory> ccFactory_{nullptr};
};
std::ostream& operator<<(std::ostream& os, const QuicTransportBase& qt);
} // namespace quic

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,234 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#pragma once
#include <folly/Expected.h>
#include <folly/io/async/AsyncUDPSocket.h>
#include <quic/QuicException.h>
#include <quic/api/QuicPacketScheduler.h>
#include <quic/api/QuicSocket.h>
#include <quic/state/StateData.h>
// Function to schedule writing data to socket. Return number of packets
// successfully scheduled
namespace quic {
using HeaderBuilder = std::function<PacketHeader(
const ConnectionId& srcConnId,
const ConnectionId& dstConnId,
PacketNum packetNum,
QuicVersion version)>;
using WritableBytesFunc =
std::function<uint64_t(const QuicConnectionStateBase& conn)>;
/**
* Attempts to write data from all frames in the QUIC connection into the UDP
* socket supplied with the aead and the headerCipher.
*/
uint64_t writeQuicDataToSocket(
folly::AsyncUDPSocket& sock,
QuicConnectionStateBase& connection,
const ConnectionId& srcConnId,
const ConnectionId& dstConnId,
const fizz::Aead& aead,
const PacketNumberCipher& headerCipher,
QuicVersion version,
uint64_t packetLimit);
/**
* Writes only the crypto and ack frames to the socket.
*
* return the number of packets written to socket.
*/
uint64_t writeCryptoAndAckDataToSocket(
folly::AsyncUDPSocket& sock,
QuicConnectionStateBase& connection,
const ConnectionId& srcConnId,
const ConnectionId& dstConnId,
LongHeader::Types packetType,
fizz::Aead& cleartextCipher,
const PacketNumberCipher& headerCipher,
QuicVersion version,
uint64_t packetLimit);
/**
* Writes out all the data streams without writing out crypto streams.
* This is useful when the crypto stream still needs to be sent in separate
* packets and cannot use the encryption of the data key.
*/
uint64_t writeQuicDataExceptCryptoStreamToSocket(
folly::AsyncUDPSocket& socket,
QuicConnectionStateBase& connection,
const ConnectionId& srcConnId,
const ConnectionId& dstConnId,
const fizz::Aead& aead,
const PacketNumberCipher& headerCipher,
QuicVersion version,
uint64_t packetLimit);
/**
* Writes frame data including zero rtt data to the socket with the supplied
* zero rtt cipher.
*/
uint64_t writeZeroRttDataToSocket(
folly::AsyncUDPSocket& socket,
QuicConnectionStateBase& connection,
const ConnectionId& srcConnId,
const ConnectionId& dstConnId,
const fizz::Aead& aead,
const PacketNumberCipher& headerCipher,
QuicVersion version,
uint64_t packetLimit);
/**
* Whether we should and can write data.
*
* TODO: We should probably split "should" and "can" into two APIs.
*/
bool shouldWriteData(const QuicConnectionStateBase& conn);
bool hasAckDataToWrite(const QuicConnectionStateBase& conn);
bool hasNonAckDataToWrite(const QuicConnectionStateBase& conn);
/**
* Invoked when the written stream data was new stream data.
*/
void handleNewStreamDataWritten(
QuicConnectionStateBase& conn,
QuicStreamLike& stream,
uint64_t frameLen,
bool frameFin,
PacketNum packetNum,
PacketNumberSpace packetNumberSpace);
/**
* Invoked when the stream data written was retransmitted data.
*/
void handleRetransmissionWritten(
QuicConnectionStateBase& conn,
QuicStreamLike& stream,
uint64_t frameOffset,
uint64_t frameLen,
bool frameFin,
PacketNum packetNum);
/**
* Update the connection and stream state after stream data is written and deal
* with new data, as well as retranmissions. Returns true if the data sent is
* new data.
*/
bool handleStreamWritten(
QuicConnectionStateBase& conn,
QuicStreamLike& stream,
uint64_t frameOffset,
uint64_t frameLen,
bool frameFin,
PacketNum packetNum,
PacketNumberSpace packetNumberSpace);
/**
* Update the connection state after sending a new packet.
*/
void updateConnection(
QuicConnectionStateBase& conn,
folly::Optional<PacketEvent> packetEvent,
RegularQuicWritePacket packet,
TimePoint time,
uint32_t encodedSize);
uint64_t congestionControlWritableBytes(const QuicConnectionStateBase& conn);
uint64_t unlimitedWritableBytes(const QuicConnectionStateBase&);
void writeCloseCommon(
folly::AsyncUDPSocket& sock,
QuicConnectionStateBase& connection,
PacketHeader&& header,
folly::Optional<std::pair<QuicErrorCode, std::string>> closeDetails,
const fizz::Aead& aead,
const PacketNumberCipher& headerCipher);
/**
* Writes a LongHeader packet with a close frame.
* The close frame type written depends on the type of error in closeDetails.
*/
void writeLongClose(
folly::AsyncUDPSocket& sock,
QuicConnectionStateBase& connection,
const ConnectionId& srcConnId,
const ConnectionId& dstConnId,
LongHeader::Types headerType,
folly::Optional<std::pair<QuicErrorCode, std::string>> closeDetails,
const fizz::Aead& aead,
const PacketNumberCipher& headerCipher,
QuicVersion);
/**
* Write a short header packet with a close frame.
* The close frame type written depends on the type of error in closeDetails.
*/
void writeShortClose(
folly::AsyncUDPSocket& sock,
QuicConnectionStateBase& connection,
const ConnectionId& connId,
folly::Optional<std::pair<QuicErrorCode, std::string>> closeDetails,
const fizz::Aead& aead,
const PacketNumberCipher& headerCipher);
/**
* Encrypts the packet header for the header type.
* This will overwrite the header with the encrypted header form. It will verify
* whether or not there are enough bytes to sample for the header encryption
* from the encryptedBody, and if there is not enough, it will throw an
* exception.
*/
void encryptPacketHeader(
HeaderForm headerType,
folly::IOBuf& header,
folly::IOBuf& encryptedBody,
const PacketNumberCipher& cipher);
/**
* Writes the connections data to the socket using the header
* builder as well as the scheduler. This will write the amount of
* data allowed by the writableBytesFunc and will only write a maximum
* number of packetLimit packets at each invocation.
*/
uint64_t writeConnectionDataToSocket(
folly::AsyncUDPSocket& sock,
QuicConnectionStateBase& connection,
const ConnectionId& srcConnId,
const ConnectionId& dstConnId,
const HeaderBuilder& builder,
PacketNumberSpace pnSpace,
QuicPacketScheduler& scheduler,
const WritableBytesFunc& writableBytesFunc,
uint64_t packetLimit,
const fizz::Aead& aead,
const PacketNumberCipher& headerCipher,
QuicVersion version);
uint64_t writeProbingDataToSocket(
folly::AsyncUDPSocket& sock,
QuicConnectionStateBase& connection,
const ConnectionId& srcConnId,
const ConnectionId& dstConnId,
const HeaderBuilder& builder,
PacketNumberSpace pnSpace,
FrameScheduler scheduler,
uint8_t probesToSend,
const fizz::Aead& aead,
const PacketNumberCipher& headerCipher,
QuicVersion version);
HeaderBuilder LongHeaderBuilder(LongHeader::Types packetType);
HeaderBuilder ShortHeaderBuilder();
} // namespace quic

71
quic/api/TARGETS Normal file
View File

@ -0,0 +1,71 @@
# @autodeps
load("@fbcode_macros//build_defs:cpp_library.bzl", "cpp_library")
cpp_library(
name = "transport",
srcs = [
"IoBufQuicBatch.cpp",
"QuicBatchWriter.cpp",
"QuicPacketScheduler.cpp",
"QuicTransportBase.cpp",
"QuicTransportFunctions.cpp",
],
headers = [
"IoBufQuicBatch.h",
"QuicBatchWriter.h",
"QuicPacketScheduler.h",
"QuicPacketScheduler-inl.h",
"QuicSocket.h",
"QuicTransportBase.h",
"QuicTransportFunctions.h",
],
deps = [
"//folly:exception_wrapper",
"//folly:expected",
"//folly:optional",
"//folly:overload",
"//folly:scope_guard",
"//folly/io:iobuf",
"//folly/io/async:async_base",
"//folly/io/async:async_udp_socket",
"//quic:constants",
"//quic:exception",
"//quic/codec:codec",
"//quic/codec:pktbuilder",
"//quic/codec:pktrebuilder",
"//quic/codec:types",
"//quic/common:looper",
"//quic/common:time_util",
"//quic/common:timers",
"//quic/congestion_control:congestion_controller_factory",
"//quic/congestion_control:copa",
"//quic/congestion_control:cubic",
"//quic/congestion_control:newreno",
"//quic/flowcontrol:flow_control",
"//quic/happyeyeballs:happyeyeballs",
"//quic/logging:logging",
"//quic/loss:loss",
"//quic/state:pacing_functions",
"//quic/state:simple_frame_functions",
"//quic/state:state_functions",
"//quic/state:state_machine",
"//quic/state:stream_functions",
"//quic/state/stream:stream",
],
external_deps = [
"boost",
],
)
cpp_library(
name = "stream_transport",
headers = [
"QuicStreamAsyncTransport.h",
],
deps = [
":transport",
"//folly/io:iobuf",
"//folly/io/async:async_transport",
],
)

View File

@ -0,0 +1,114 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
if(NOT BUILD_TESTS)
return()
endif()
add_library(mvfst_mock_socket STATIC
MockQuicSocket.h
MockQuicStats.h
Mocks.h
)
target_include_directories(
mvfst_mock_socket PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>
)
add_dependencies(
mvfst_mock_socket
mvfst_exception
mvfst_transport
mvfst_codec_types
mvfst_server
mvfst_state_machine
mvfst_state_stats_callback
)
target_link_libraries(
mvfst_mock_socket PUBLIC
Folly::folly
mvfst_exception
mvfst_transport
mvfst_codec_types
mvfst_server
mvfst_state_machine
mvfst_state_stats_callback
${LIBGMOCK_LIBRARIES}
)
quic_add_test(TARGET QuicTransportTest
SOURCES
QuicTransportTest.cpp
DEPENDS
Folly::folly
mvfst_transport
mvfst_mock_socket
mvfst_mock_state
mvfst_server
mvfst_state_stream_functions
mvfst_test_utils
${LIBGTEST_LIBRARY}
)
quic_add_test(TARGET QuicTransportBaseTest
SOURCES
QuicTransportBaseTest.cpp
DEPENDS
Folly::folly
mvfst_transport
mvfst_mock_socket
mvfst_mock_state
mvfst_codec_types
mvfst_test_utils
mvfst_state_stream_functions
mvfst_server
${LIBGTEST_LIBRARY}
)
quic_add_test(TARGET QuicTransportFunctionsTest
SOURCES
QuicTransportFunctionsTest.cpp
DEPENDS
Folly::folly
mvfst_transport
mvfst_mock_socket
mvfst_mock_state
mvfst_test_utils
mvfst_server
${LIBGTEST_LIBRARY}
)
quic_add_test(TARGET QuicPacketSchedulerTest
SOURCES
QuicPacketSchedulerTest.cpp
DEPENDS
Folly::folly
mvfst_client
mvfst_codec_pktbuilder
mvfst_transport
mvfst_mock_socket
mvfst_test_utils
mvfst_server
${LIBGTEST_LIBRARY}
)
quic_add_test(TARGET IoBufQuicBatchTest
SOURCES
IoBufQuicBatchTest.cpp
DEPENDS
Folly::folly
mvfst_transport
mvfst_state_machine
)
quic_add_test(TARGET QuicBatchWriterTest
SOURCES
QuicBatchWriterTest.cpp
DEPENDS
Folly::folly
mvfst_transport
)

View File

@ -0,0 +1,84 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#include <quic/api/IoBufQuicBatch.h>
#include <gtest/gtest.h>
#include <quic/state/StateData.h>
constexpr const auto kNumLoops = 64;
constexpr const auto kMaxBufs = 10;
namespace quic {
namespace testing {
class TestPacketBatchWriter : public IOBufBatchWriter {
public:
explicit TestPacketBatchWriter(int maxBufs) : maxBufs_(maxBufs) {}
~TestPacketBatchWriter() override {
CHECK_EQ(bufNum_, 0);
CHECK_EQ(bufSize_, 0);
}
void reset() override {
bufNum_ = 0;
bufSize_ = 0;
}
bool append(std::unique_ptr<folly::IOBuf>&& /*unused*/, size_t size)
override {
bufNum_++;
bufSize_ += size;
return ((maxBufs_ < 0) || (bufNum_ >= maxBufs_));
}
ssize_t write(
folly::AsyncUDPSocket& /*unused*/,
const folly::SocketAddress& /*unused*/) override {
return bufSize_;
}
private:
int maxBufs_{0};
int bufNum_{0};
size_t bufSize_{0};
};
void RunTest(int numBatch) {
folly::EventBase evb;
folly::AsyncUDPSocket sock(&evb);
auto batchWriter = std::make_unique<TestPacketBatchWriter>(numBatch);
folly::SocketAddress peerAddress{"127.0.0.1", 1234};
QuicConnectionStateBase::HappyEyeballsState happyEyeballsState;
IOBufQuicBatch ioBufBatch(
std::move(batchWriter), sock, peerAddress, happyEyeballsState);
std::string strTest("Test");
for (size_t i = 0; i < kNumLoops; i++) {
auto buf = folly::IOBuf::copyBuffer(strTest.c_str(), strTest.length());
CHECK(ioBufBatch.write(std::move(buf), strTest.length()));
}
// check flush is successful
CHECK(ioBufBatch.flush());
// check we sent all the packets
CHECK_EQ(ioBufBatch.getPktSent(), kNumLoops);
}
TEST(QuicBatch, TestBatchingNone) {
RunTest(1);
}
TEST(QuicBatch, TestBatchingNoFlush) {
RunTest(-1);
}
TEST(QuicBatch, TestBatching) {
RunTest(kMaxBufs);
}
} // namespace testing
} // namespace quic

View File

@ -0,0 +1,219 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#pragma once
#include <folly/portability/GMock.h>
#include <quic/api/QuicSocket.h>
namespace quic {
class MockQuicSocket : public QuicSocket {
public:
using SharedBuf = std::shared_ptr<folly::IOBuf>;
MockQuicSocket(folly::EventBase* /*eventBase*/, ConnectionCallback& cb)
: cb_(&cb) {}
MOCK_CONST_METHOD0(good, bool());
MOCK_CONST_METHOD0(replaySafe, bool());
MOCK_CONST_METHOD0(error, bool());
MOCK_METHOD1(
close,
void(folly::Optional<std::pair<QuicErrorCode, std::string>>));
MOCK_METHOD0(closeGracefully, void());
MOCK_METHOD1(
closeNow,
void(folly::Optional<std::pair<QuicErrorCode, std::string>>));
MOCK_CONST_METHOD0(
getClientConnectionId,
folly::Optional<quic::ConnectionId>());
MOCK_CONST_METHOD0(
getServerConnectionId,
folly::Optional<quic::ConnectionId>());
MOCK_CONST_METHOD0(getPeerAddress, const folly::SocketAddress&());
MOCK_CONST_METHOD0(getOriginalPeerAddress, const folly::SocketAddress&());
MOCK_CONST_METHOD0(getLocalAddress, const folly::SocketAddress&());
MOCK_CONST_METHOD0(getEventBase, folly::EventBase*());
MOCK_CONST_METHOD1(
getStreamReadOffset,
folly::Expected<size_t, LocalErrorCode>(StreamId));
MOCK_CONST_METHOD1(
getStreamWriteOffset,
folly::Expected<size_t, LocalErrorCode>(StreamId));
MOCK_CONST_METHOD1(
getStreamWriteBufferedBytes,
folly::Expected<size_t, LocalErrorCode>(StreamId));
MOCK_CONST_METHOD0(getTransportInfo, QuicSocket::TransportInfo());
MOCK_CONST_METHOD1(
getStreamTransportInfo,
folly::Expected<QuicSocket::StreamTransportInfo, LocalErrorCode>(
StreamId));
MOCK_CONST_METHOD0(getAppProtocol, folly::Optional<std::string>());
MOCK_METHOD2(setReceiveWindow, void(StreamId, size_t));
MOCK_METHOD3(setSendBuffer, void(StreamId, size_t, size_t));
MOCK_CONST_METHOD0(
getConnectionFlowControl,
folly::Expected<FlowControlState, LocalErrorCode>());
MOCK_CONST_METHOD1(
getStreamFlowControl,
folly::Expected<FlowControlState, LocalErrorCode>(StreamId));
MOCK_METHOD0(unsetAllReadCallbacks, void());
MOCK_METHOD0(unsetAllPeekCallbacks, void());
MOCK_METHOD0(unsetAllDeliveryCallbacks, void());
MOCK_METHOD1(cancelDeliveryCallbacksForStream, void(StreamId));
MOCK_METHOD1(
setConnectionFlowControlWindow,
folly::Expected<folly::Unit, LocalErrorCode>(uint64_t));
MOCK_METHOD2(
setStreamFlowControlWindow,
folly::Expected<folly::Unit, LocalErrorCode>(StreamId, uint64_t));
MOCK_METHOD1(setTransportSettings, void(TransportSettings));
MOCK_CONST_METHOD0(isPartiallyReliableTransport, bool());
MOCK_METHOD2(
setReadCallback,
folly::Expected<folly::Unit, LocalErrorCode>(StreamId, ReadCallback*));
MOCK_METHOD1(setConnectionCallback, void(ConnectionCallback*));
MOCK_METHOD1(
pauseRead,
folly::Expected<folly::Unit, LocalErrorCode>(StreamId));
MOCK_METHOD1(
resumeRead,
folly::Expected<folly::Unit, LocalErrorCode>(StreamId));
MOCK_METHOD2(
stopSending,
folly::Expected<folly::Unit, LocalErrorCode>(
StreamId,
ApplicationErrorCode));
folly::Expected<std::pair<Buf, bool>, LocalErrorCode> read(
StreamId id,
size_t maxRead) override {
auto res = readNaked(id, maxRead);
if (res.hasError()) {
return folly::makeUnexpected(res.error());
} else {
return std::pair<Buf, bool>(Buf(res.value().first), res.value().second);
}
}
using ReadResult =
folly::Expected<std::pair<folly::IOBuf*, bool>, LocalErrorCode>;
MOCK_METHOD2(readNaked, ReadResult(StreamId, size_t));
MOCK_METHOD1(
createBidirectionalStream,
folly::Expected<StreamId, LocalErrorCode>(bool));
MOCK_METHOD1(
createUnidirectionalStream,
folly::Expected<StreamId, LocalErrorCode>(bool));
MOCK_CONST_METHOD0(getNumOpenableBidirectionalStreams, uint64_t());
MOCK_CONST_METHOD0(getNumOpenableUnidirectionalStreams, uint64_t());
GMOCK_METHOD1_(, noexcept, , isClientStream, bool(StreamId));
GMOCK_METHOD1_(, noexcept, , isServerStream, bool(StreamId));
GMOCK_METHOD1_(, noexcept, , isBidirectionalStream, bool(StreamId));
GMOCK_METHOD1_(, noexcept, , isUnidirectionalStream, bool(StreamId));
MOCK_METHOD1(
notifyPendingWriteOnConnection,
folly::Expected<folly::Unit, LocalErrorCode>(WriteCallback*));
MOCK_METHOD2(
notifyPendingWriteOnStream,
folly::Expected<folly::Unit, LocalErrorCode>(StreamId, WriteCallback*));
folly::Expected<Buf, LocalErrorCode> writeChain(
StreamId id,
Buf data,
bool eof,
bool cork,
DeliveryCallback* cb) override {
SharedBuf sharedData(data.release());
auto res = writeChain(id, sharedData, eof, cork, cb);
if (res.hasError()) {
return folly::makeUnexpected(res.error());
} else {
return Buf(res.value());
}
}
using WriteResult = folly::Expected<folly::IOBuf*, LocalErrorCode>;
MOCK_METHOD5(
writeChain,
WriteResult(StreamId, SharedBuf, bool, bool, DeliveryCallback*));
MOCK_METHOD3(
registerDeliveryCallback,
folly::Expected<folly::Unit, LocalErrorCode>(
StreamId,
uint64_t,
DeliveryCallback*));
MOCK_METHOD1(shutdownWrite, folly::Optional<LocalErrorCode>(StreamId));
MOCK_METHOD2(
resetStream,
folly::Expected<folly::Unit, LocalErrorCode>(
StreamId,
ApplicationErrorCode));
MOCK_METHOD2(sendPing, void(PingCallback*, std::chrono::milliseconds));
MOCK_CONST_METHOD0(getState, const QuicConnectionStateBase*());
MOCK_METHOD0(isDetachable, bool());
MOCK_METHOD1(attachEventBase, void(folly::EventBase*));
MOCK_METHOD0(detachEventBase, void());
MOCK_METHOD1(setControlStream, folly::Optional<LocalErrorCode>(StreamId));
MOCK_METHOD2(
setPeekCallback,
folly::Expected<folly::Unit, LocalErrorCode>(StreamId, PeekCallback*));
MOCK_METHOD1(
pausePeek,
folly::Expected<folly::Unit, LocalErrorCode>(StreamId));
MOCK_METHOD1(
resumePeek,
folly::Expected<folly::Unit, LocalErrorCode>(StreamId));
MOCK_METHOD2(
peek,
folly::Expected<folly::Unit, LocalErrorCode>(
StreamId,
const folly::Function<
void(StreamId, const folly::Range<PeekIterator>&) const>&));
MOCK_METHOD3(
consume,
folly::Expected<
folly::Unit,
std::pair<LocalErrorCode, folly::Optional<uint64_t>>>(
StreamId,
uint64_t,
size_t));
MOCK_METHOD2(
consume,
folly::Expected<folly::Unit, LocalErrorCode>(StreamId, size_t));
MOCK_METHOD2(
setDataExpiredCallback,
folly::Expected<folly::Unit, LocalErrorCode>(
StreamId,
DataExpiredCallback*));
MOCK_METHOD2(
sendDataExpired,
folly::Expected<folly::Optional<uint64_t>, LocalErrorCode>(
StreamId,
uint64_t offset));
MOCK_METHOD2(
setDataRejectedCallback,
folly::Expected<folly::Unit, LocalErrorCode>(
StreamId,
DataRejectedCallback*));
MOCK_METHOD2(
sendDataRejected,
folly::Expected<folly::Optional<uint64_t>, LocalErrorCode>(
StreamId,
uint64_t offset));
ConnectionCallback* cb_;
};
} // namespace quic

View File

@ -0,0 +1,53 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#pragma once
#include <folly/io/async/EventBase.h>
#include <folly/portability/GMock.h>
#include <quic/state/QuicTransportStatsCallback.h>
namespace quic {
class MockQuicStats : public QuicTransportStatsCallback {
public:
MOCK_METHOD0(onPacketReceived, void());
MOCK_METHOD0(onDuplicatedPacketReceived, void());
MOCK_METHOD0(onOutOfOrderPacketReceived, void());
MOCK_METHOD0(onPacketProcessed, void());
MOCK_METHOD0(onPacketSent, void());
MOCK_METHOD0(onPacketRetransmission, void());
MOCK_METHOD1(onPacketDropped, void(PacketDropReason));
MOCK_METHOD0(onPacketForwarded, void());
MOCK_METHOD0(onForwardedPacketReceived, void());
MOCK_METHOD0(onForwardedPacketProcessed, void());
MOCK_METHOD0(onNewConnection, void());
MOCK_METHOD1(onConnectionClose, void(folly::Optional<ConnectionCloseReason>));
MOCK_METHOD0(onNewQuicStream, void());
MOCK_METHOD0(onQuicStreamClosed, void());
MOCK_METHOD0(onQuicStreamReset, void());
MOCK_METHOD0(onConnFlowControlUpdate, void());
MOCK_METHOD0(onConnFlowControlBlocked, void());
MOCK_METHOD0(onStreamFlowControlUpdate, void());
MOCK_METHOD0(onStreamFlowControlBlocked, void());
MOCK_METHOD0(onCwndBlocked, void());
MOCK_METHOD0(onRTO, void());
MOCK_METHOD1(onRead, void(size_t));
MOCK_METHOD1(onWrite, void(size_t));
};
class MockQuicStatsFactory : public QuicTransportStatsCallbackFactory {
public:
~MockQuicStatsFactory() override = default;
MOCK_METHOD1(
make,
std::unique_ptr<QuicTransportStatsCallback>(folly::EventBase*));
};
} // namespace quic

276
quic/api/test/Mocks.h Normal file
View File

@ -0,0 +1,276 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
// Copyright 2004-present Facebook. All rights reserved.
#pragma once
#include <folly/portability/GMock.h>
#include <folly/io/async/EventBase.h>
#include <quic/QuicException.h>
#include <quic/api/QuicSocket.h>
#include <quic/codec/QuicConnectionId.h>
#include <quic/common/Timers.h>
#include <quic/server/QuicServerTransport.h>
#include <quic/state/StateData.h>
namespace quic {
class MockFrameScheduler : public FrameScheduler {
public:
~MockFrameScheduler() override = default;
MockFrameScheduler() : FrameScheduler("mock") {}
// override methods accepting rvalue ref since gmock doesn't support it
std::pair<
folly::Optional<PacketEvent>,
folly::Optional<RegularQuicPacketBuilder::Packet>>
scheduleFramesForPacket(
RegularQuicPacketBuilder&& builderIn,
uint32_t writableBytes) override {
auto builder =
std::make_unique<RegularQuicPacketBuilder>(std::move(builderIn));
return _scheduleFramesForPacket(builder, writableBytes);
}
GMOCK_METHOD0_(, const, , hasData, bool());
MOCK_METHOD2(
_scheduleFramesForPacket,
std::pair<
folly::Optional<PacketEvent>,
folly::Optional<RegularQuicPacketBuilder::Packet>>(
std::unique_ptr<RegularQuicPacketBuilder>&,
uint32_t));
};
class MockReadCallback : public QuicSocket::ReadCallback {
public:
~MockReadCallback() override = default;
GMOCK_METHOD1_(, noexcept, , readAvailable, void(StreamId));
GMOCK_METHOD2_(
,
noexcept,
,
readError,
void(
StreamId,
std::pair<QuicErrorCode, folly::Optional<folly::StringPiece>>));
};
class MockPeekCallback : public QuicSocket::PeekCallback {
public:
~MockPeekCallback() override = default;
GMOCK_METHOD2_(
,
noexcept,
,
onDataAvailable,
void(StreamId, const folly::Range<PeekIterator>&));
};
class MockWriteCallback : public QuicSocket::WriteCallback {
public:
~MockWriteCallback() override = default;
GMOCK_METHOD2_(, noexcept, , onStreamWriteReady, void(StreamId, uint64_t));
GMOCK_METHOD1_(, noexcept, , onConnectionWriteReady, void(uint64_t));
GMOCK_METHOD2_(
,
noexcept,
,
onStreamWriteError,
void(
StreamId,
std::pair<QuicErrorCode, folly::Optional<folly::StringPiece>>));
GMOCK_METHOD1_(
,
noexcept,
,
onConnectionWriteError,
void(std::pair<QuicErrorCode, folly::Optional<folly::StringPiece>>));
};
class MockConnectionCallback : public QuicSocket::ConnectionCallback {
public:
~MockConnectionCallback() override = default;
GMOCK_METHOD1_(, noexcept, , onFlowControlUpdate, void(StreamId));
GMOCK_METHOD1_(, noexcept, , onNewStream, void(StreamId));
GMOCK_METHOD1_(, noexcept, , onNewUnidirectionalStream, void(StreamId));
GMOCK_METHOD2_(
,
noexcept,
,
onStopSending,
void(StreamId, ApplicationErrorCode));
GMOCK_METHOD0_(, noexcept, , onConnectionEnd, void());
GMOCK_METHOD1_(
,
noexcept,
,
onConnectionError,
void(std::pair<QuicErrorCode, std::string>));
GMOCK_METHOD0_(, noexcept, , onReplaySafe, void());
GMOCK_METHOD0_(, noexcept, , onTransportReady, void());
GMOCK_METHOD2_(
,
noexcept,
,
validateEarlyDataAppParams,
bool(const folly::Optional<std::string>&, const Buf&));
GMOCK_METHOD0_(, noexcept, , serializeEarlyDataAppParams, Buf());
};
class MockDeliveryCallback : public QuicSocket::DeliveryCallback {
public:
~MockDeliveryCallback() override = default;
GMOCK_METHOD3_(
,
,
,
onDeliveryAck,
void(StreamId, uint64_t, std::chrono::microseconds));
GMOCK_METHOD2_(, , , onCanceled, void(StreamId, size_t));
};
class MockDataExpiredCallback : public QuicSocket::DataExpiredCallback {
public:
~MockDataExpiredCallback() override = default;
GMOCK_METHOD2_(, noexcept, , onDataExpired, void(StreamId, uint64_t));
};
class MockDataRejectedCallback : public QuicSocket::DataRejectedCallback {
public:
~MockDataRejectedCallback() override = default;
GMOCK_METHOD2_(, noexcept, , onDataRejected, void(StreamId, uint64_t));
};
class MockQuicTransport : public QuicServerTransport {
public:
using Ptr = std::shared_ptr<MockQuicTransport>;
class RoutingCallback : public QuicServerTransport::RoutingCallback {
public:
virtual ~RoutingCallback() = default;
// Called when a connection id is available
virtual void onConnectionIdAvailable(
QuicServerTransport::Ptr transport,
ConnectionId id) noexcept = 0;
// Called when a connecton id is bound and ip address should not
// be used any more for routing.
virtual void onConnectionIdBound(
QuicServerTransport::Ptr transport) noexcept = 0;
// Called when the connection is finished and needs to be Unbound.
virtual void onConnectionUnbound(
const QuicServerTransport::SourceIdentity& address,
folly::Optional<ConnectionId> connectionId) noexcept = 0;
};
MockQuicTransport(
folly::EventBase* evb,
std::unique_ptr<folly::AsyncUDPSocket> sock,
ConnectionCallback& cb,
std::shared_ptr<const fizz::server::FizzServerContext> ctx)
: QuicServerTransport(evb, std::move(sock), cb, ctx) {}
virtual ~MockQuicTransport() = default;
GMOCK_METHOD0_(, const, , getEventBase, folly::EventBase*());
MOCK_CONST_METHOD0(getPeerAddress, const folly::SocketAddress&());
MOCK_CONST_METHOD0(getOriginalPeerAddress, const folly::SocketAddress&());
GMOCK_METHOD1_(
,
,
,
setOriginalPeerAddress,
void(const folly::SocketAddress&));
GMOCK_METHOD0_(, , , accept, void());
GMOCK_METHOD1_(, , , setTransportSettings, void(TransportSettings));
GMOCK_METHOD1_(, noexcept, , setPacingTimer, void(TimerHighRes::SharedPtr));
void onNetworkData(
const folly::SocketAddress& peer,
NetworkData&& networkData) noexcept override {
onNetworkData(peer, networkData.data.get());
}
GMOCK_METHOD2_(
,
noexcept,
,
onNetworkData,
void(const folly::SocketAddress&, const folly::IOBuf*));
GMOCK_METHOD1_(
,
noexcept,
,
setRoutingCallback,
void(QuicServerTransport::RoutingCallback*));
GMOCK_METHOD1_(
,
noexcept,
,
setSupportedVersions,
void(const std::vector<QuicVersion>&));
GMOCK_METHOD1_(
,
noexcept,
,
setServerConnectionIdParams,
void(ServerConnectionIdParams));
GMOCK_METHOD1_(
,
noexcept,
,
close,
void(folly::Optional<std::pair<QuicErrorCode, std::string>>));
GMOCK_METHOD1_(
,
noexcept,
,
closeNow,
void(folly::Optional<std::pair<QuicErrorCode, std::string>>));
GMOCK_METHOD0_(, const, , hasShutdown, bool());
GMOCK_METHOD0_(
,
const,
,
getClientConnectionId,
folly::Optional<ConnectionId>());
GMOCK_METHOD1_(
,
noexcept,
,
setTransportInfoCallback,
void(QuicTransportStatsCallback*));
GMOCK_METHOD1_(, noexcept, , setConnectionIdAlgo, void(ConnectionIdAlgo*));
};
inline std::ostream& operator<<(std::ostream& os, const MockQuicTransport&) {
return os;
}
} // namespace quic

View File

@ -0,0 +1,196 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#include <quic/api/QuicBatchWriter.h>
#include <gtest/gtest.h>
namespace quic {
namespace testing {
constexpr const auto kStrLen = 10;
constexpr const auto kStrLenGT = 20;
constexpr const auto kStrLenLT = 5;
constexpr const auto kBatchNum = 3;
constexpr const auto kNumLoops = 10;
TEST(QuicBatchWriter, TestBatchingNone) {
folly::EventBase evb;
folly::AsyncUDPSocket sock(&evb);
sock.setReuseAddr(false);
sock.bind(folly::SocketAddress("127.0.0.1", 0));
auto batchWriter = quic::BatchWriterFactory::makeBatchWriter(
sock, quic::QuicBatchingMode::BATCHING_MODE_NONE, kBatchNum);
CHECK(batchWriter);
std::string strTest('A', kStrLen);
// run multiple loops
for (size_t i = 0; i < kNumLoops; i++) {
CHECK(batchWriter->empty());
CHECK_EQ(batchWriter->size(), 0);
auto buf = folly::IOBuf::copyBuffer(strTest.c_str(), kStrLen);
CHECK(batchWriter->append(std::move(buf), kStrLen));
CHECK_EQ(batchWriter->size(), kStrLen);
batchWriter->reset();
}
}
TEST(QuicBatchWriter, TestBatchingGSOBase) {
folly::EventBase evb;
folly::AsyncUDPSocket sock(&evb);
sock.setReuseAddr(false);
sock.bind(folly::SocketAddress("127.0.0.1", 0));
auto batchWriter = quic::BatchWriterFactory::makeBatchWriter(
sock, quic::QuicBatchingMode::BATCHING_MODE_GSO, 1);
CHECK(batchWriter);
std::string strTest(kStrLen, 'A');
// if GSO is not available, just test we've got a regular
// batch writer
if (sock.getGSO() < 0) {
CHECK(batchWriter->empty());
CHECK_EQ(batchWriter->size(), 0);
auto buf = folly::IOBuf::copyBuffer(strTest);
CHECK(batchWriter->append(std::move(buf), strTest.size()));
EXPECT_FALSE(batchWriter->needsFlush(kStrLenLT));
}
}
TEST(QuicBatchWriter, TestBatchingGSOLastSmallPacket) {
folly::EventBase evb;
folly::AsyncUDPSocket sock(&evb);
sock.setReuseAddr(false);
sock.bind(folly::SocketAddress("127.0.0.1", 0));
auto batchWriter = quic::BatchWriterFactory::makeBatchWriter(
sock, quic::QuicBatchingMode::BATCHING_MODE_GSO, 1);
CHECK(batchWriter);
std::string strTest;
// only if GSO is available
if (sock.getGSO() >= 0) {
// run multiple loops
for (size_t i = 0; i < kNumLoops; i++) {
// batch kStrLen, kStrLenLT
CHECK(batchWriter->empty());
CHECK_EQ(batchWriter->size(), 0);
strTest = std::string(kStrLen, 'A');
auto buf = folly::IOBuf::copyBuffer(strTest);
EXPECT_FALSE(batchWriter->needsFlush(kStrLen));
EXPECT_FALSE(batchWriter->append(std::move(buf), kStrLen));
CHECK_EQ(batchWriter->size(), kStrLen);
strTest = std::string(kStrLenLT, 'A');
buf = folly::IOBuf::copyBuffer(strTest);
EXPECT_FALSE(batchWriter->needsFlush(kStrLenLT));
CHECK(batchWriter->append(std::move(buf), kStrLenLT));
CHECK_EQ(batchWriter->size(), kStrLen + kStrLenLT);
batchWriter->reset();
}
}
}
TEST(QuicBatchWriter, TestBatchingGSOLastBigPacket) {
folly::EventBase evb;
folly::AsyncUDPSocket sock(&evb);
sock.setReuseAddr(false);
sock.bind(folly::SocketAddress("127.0.0.1", 0));
auto batchWriter = quic::BatchWriterFactory::makeBatchWriter(
sock, quic::QuicBatchingMode::BATCHING_MODE_GSO, 1);
CHECK(batchWriter);
std::string strTest;
// only if GSO is available
if (sock.getGSO() >= 0) {
// run multiple loops
for (size_t i = 0; i < kNumLoops; i++) {
// try to batch kStrLen, kStrLenGT
CHECK(batchWriter->empty());
CHECK_EQ(batchWriter->size(), 0);
strTest = std::string(kStrLen, 'A');
auto buf = folly::IOBuf::copyBuffer(strTest);
EXPECT_FALSE(batchWriter->needsFlush(kStrLen));
EXPECT_FALSE(batchWriter->append(std::move(buf), kStrLen));
CHECK_EQ(batchWriter->size(), kStrLen);
CHECK(batchWriter->needsFlush(kStrLenGT));
batchWriter->reset();
}
}
}
TEST(QuicBatchWriter, TestBatchingGSOBatchNum) {
folly::EventBase evb;
folly::AsyncUDPSocket sock(&evb);
sock.setReuseAddr(false);
sock.bind(folly::SocketAddress("127.0.0.1", 0));
auto batchWriter = quic::BatchWriterFactory::makeBatchWriter(
sock, quic::QuicBatchingMode::BATCHING_MODE_GSO, kBatchNum);
CHECK(batchWriter);
std::string strTest(kStrLen, 'A');
// if GSO is not available, just test we've got a regular
// batch writer
if (sock.getGSO() >= 0) {
// run multiple loops
for (size_t i = 0; i < kNumLoops; i++) {
// try to batch up to kBatchNum
CHECK(batchWriter->empty());
CHECK_EQ(batchWriter->size(), 0);
size_t size = 0;
for (auto j = 0; j < kBatchNum - 1; j++) {
auto buf = folly::IOBuf::copyBuffer(strTest);
EXPECT_FALSE(batchWriter->append(std::move(buf), kStrLen));
size += kStrLen;
CHECK_EQ(batchWriter->size(), size);
}
// add the kBatchNum buf
auto buf = folly::IOBuf::copyBuffer(strTest.c_str(), kStrLen);
CHECK(batchWriter->append(std::move(buf), kStrLen));
size += kStrLen;
CHECK_EQ(batchWriter->size(), size);
batchWriter->reset();
}
}
}
TEST(QuicBatchWriter, TestBatchingSendmmsg) {
folly::EventBase evb;
folly::AsyncUDPSocket sock(&evb);
sock.setReuseAddr(false);
sock.bind(folly::SocketAddress("127.0.0.1", 0));
auto batchWriter = quic::BatchWriterFactory::makeBatchWriter(
sock, quic::QuicBatchingMode::BATCHING_MODE_SENDMMSG, kBatchNum);
CHECK(batchWriter);
std::string strTest(kStrLen, 'A');
// run multiple loops
for (size_t i = 0; i < kNumLoops; i++) {
// try to batch up to kBatchNum
CHECK(batchWriter->empty());
CHECK_EQ(batchWriter->size(), 0);
size_t size = 0;
for (auto j = 0; j < kBatchNum - 1; j++) {
auto buf = folly::IOBuf::copyBuffer(strTest);
EXPECT_FALSE(batchWriter->append(std::move(buf), kStrLen));
size += kStrLen;
CHECK_EQ(batchWriter->size(), size);
}
// add the kBatchNum buf
auto buf = folly::IOBuf::copyBuffer(strTest.c_str(), kStrLen);
CHECK(batchWriter->append(std::move(buf), kStrLen));
size += kStrLen;
CHECK_EQ(batchWriter->size(), size);
batchWriter->reset();
}
}
} // namespace testing
} // namespace quic

View File

@ -0,0 +1,701 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
// Copyright 2004-present Facebook. All rights reserved.
#include <quic/api/QuicPacketScheduler.h>
#include <folly/portability/GTest.h>
#include <quic/api/test/Mocks.h>
#include <quic/client/state/ClientStateMachine.h>
#include <quic/codec/QuicPacketBuilder.h>
#include <quic/common/test/TestUtils.h>
#include <quic/server/state/ServerStateMachine.h>
using namespace quic;
using namespace testing;
namespace {
PacketNum addInitialOutstandingPacket(QuicConnectionStateBase& conn) {
PacketNum nextPacketNum =
getNextPacketNum(conn, PacketNumberSpace::Handshake);
std::vector<uint8_t> zeroConnIdData(quic::kDefaultConnectionIdSize, 0);
ConnectionId srcConnId(zeroConnIdData);
LongHeader header(
LongHeader::Types::Initial,
srcConnId,
conn.clientConnectionId.value_or(quic::test::getTestConnectionId()),
nextPacketNum,
QuicVersion::QUIC_DRAFT);
RegularQuicWritePacket packet(std::move(header));
conn.outstandingPackets.emplace_back(packet, Clock::now(), 0, true, false, 0);
conn.outstandingHandshakePacketsCount++;
increaseNextPacketNum(conn, PacketNumberSpace::Handshake);
return nextPacketNum;
}
PacketNum addHandshakeOutstandingPacket(QuicConnectionStateBase& conn) {
PacketNum nextPacketNum =
getNextPacketNum(conn, PacketNumberSpace::Handshake);
std::vector<uint8_t> zeroConnIdData(quic::kDefaultConnectionIdSize, 0);
ConnectionId srcConnId(zeroConnIdData);
LongHeader header(
LongHeader::Types::Handshake,
srcConnId,
conn.clientConnectionId.value_or(quic::test::getTestConnectionId()),
nextPacketNum,
QuicVersion::QUIC_DRAFT);
RegularQuicWritePacket packet(std::move(header));
conn.outstandingPackets.emplace_back(packet, Clock::now(), 0, true, false, 0);
conn.outstandingHandshakePacketsCount++;
increaseNextPacketNum(conn, PacketNumberSpace::Handshake);
return nextPacketNum;
}
PacketNum addPureAckOutstandingPacket(QuicConnectionStateBase& conn) {
PacketNum nextPacketNum = getNextPacketNum(conn, PacketNumberSpace::AppData);
ShortHeader header(
ProtectionType::KeyPhaseOne,
conn.clientConnectionId.value_or(quic::test::getTestConnectionId()),
nextPacketNum);
RegularQuicWritePacket packet(std::move(header));
conn.outstandingPackets.emplace_back(packet, Clock::now(), 0, false, true, 0);
increaseNextPacketNum(conn, PacketNumberSpace::AppData);
return nextPacketNum;
}
PacketNum addOutstandingPacket(QuicConnectionStateBase& conn) {
PacketNum nextPacketNum = getNextPacketNum(conn, PacketNumberSpace::AppData);
ShortHeader header(
ProtectionType::KeyPhaseOne,
conn.clientConnectionId.value_or(quic::test::getTestConnectionId()),
nextPacketNum);
RegularQuicWritePacket packet(std::move(header));
conn.outstandingPackets.emplace_back(
packet, Clock::now(), 0, false, false, 0);
increaseNextPacketNum(conn, PacketNumberSpace::AppData);
return nextPacketNum;
}
} // namespace
namespace quic {
namespace test {
class QuicPacketSchedulerTest : public Test {
public:
QuicVersion version{QuicVersion::MVFST};
};
TEST_F(QuicPacketSchedulerTest, NoopScheduler) {
QuicConnectionStateBase conn(QuicNodeType::Client);
FrameScheduler scheduler("frame");
EXPECT_FALSE(scheduler.hasData());
LongHeader header(
LongHeader::Types::Initial,
getTestConnectionId(1),
getTestConnectionId(),
0x1356,
version);
RegularQuicPacketBuilder builder(
conn.udpSendPacketLen,
std::move(header),
conn.ackStates.initialAckState.largestAckedByPeer);
auto builtPacket = std::move(builder).buildPacket();
EXPECT_TRUE(builtPacket.packet.frames.empty());
}
TEST_F(QuicPacketSchedulerTest, CryptoPaddingInitialPacket) {
QuicClientConnectionState conn;
auto connId = getTestConnectionId();
LongHeader longHeader1(
LongHeader::Types::Initial,
getTestConnectionId(1),
connId,
getNextPacketNum(conn, PacketNumberSpace::Initial),
QuicVersion::MVFST);
increaseNextPacketNum(conn, PacketNumberSpace::Initial);
RegularQuicPacketBuilder builder1(
conn.udpSendPacketLen,
std::move(longHeader1),
conn.ackStates.initialAckState.largestAckedByPeer);
CryptoStreamScheduler scheduler(
conn,
*getCryptoStream(*conn.cryptoState, fizz::EncryptionLevel::Plaintext));
writeDataToQuicStream(
conn.cryptoState->initialStream, folly::IOBuf::copyBuffer("chlo"));
scheduler.writeCryptoData(builder1);
EXPECT_EQ(builder1.remainingSpaceInPkt(), 0);
LongHeader longHeader2(
LongHeader::Types::Handshake,
connId,
connId,
getNextPacketNum(conn, PacketNumberSpace::Handshake),
QuicVersion::MVFST);
RegularQuicPacketBuilder builder2(
conn.udpSendPacketLen,
std::move(longHeader2),
conn.ackStates.handshakeAckState.largestAckedByPeer);
writeDataToQuicStream(
conn.cryptoState->initialStream, folly::IOBuf::copyBuffer("finished"));
scheduler.writeCryptoData(builder2);
EXPECT_GT(builder2.remainingSpaceInPkt(), 0);
}
TEST_F(QuicPacketSchedulerTest, CryptoServerInitialNotPadded) {
QuicServerConnectionState conn;
auto connId = getTestConnectionId();
PacketNum nextPacketNum = getNextPacketNum(conn, PacketNumberSpace::Initial);
LongHeader longHeader1(
LongHeader::Types::Initial,
getTestConnectionId(1),
connId,
nextPacketNum,
QuicVersion::MVFST);
RegularQuicPacketBuilder builder1(
conn.udpSendPacketLen,
std::move(longHeader1),
conn.ackStates.initialAckState.largestAckedByPeer);
CryptoStreamScheduler scheduler(
conn,
*getCryptoStream(*conn.cryptoState, fizz::EncryptionLevel::Plaintext));
writeDataToQuicStream(
conn.cryptoState->initialStream, folly::IOBuf::copyBuffer("shlo"));
scheduler.writeCryptoData(builder1);
EXPECT_GT(builder1.remainingSpaceInPkt(), 0);
}
TEST_F(QuicPacketSchedulerTest, CryptoPaddingRetransmissionClientInitial) {
QuicClientConnectionState conn;
auto connId = getTestConnectionId();
LongHeader longHeader(
LongHeader::Types::Initial,
getTestConnectionId(1),
connId,
getNextPacketNum(conn, PacketNumberSpace::Initial),
QuicVersion::MVFST);
RegularQuicPacketBuilder builder(
conn.udpSendPacketLen,
std::move(longHeader),
conn.ackStates.initialAckState.largestAckedByPeer);
CryptoStreamScheduler scheduler(
conn,
*getCryptoStream(*conn.cryptoState, fizz::EncryptionLevel::Plaintext));
conn.cryptoState->initialStream.lossBuffer.push_back(
StreamBuffer{folly::IOBuf::copyBuffer("chlo"), 0, false});
scheduler.writeCryptoData(builder);
EXPECT_EQ(builder.remainingSpaceInPkt(), 0);
}
TEST_F(QuicPacketSchedulerTest, CryptoSchedulerOnlySingleLossFits) {
QuicServerConnectionState conn;
auto connId = getTestConnectionId();
LongHeader longHeader(
LongHeader::Types::Handshake,
connId,
connId,
getNextPacketNum(conn, PacketNumberSpace::Handshake),
QuicVersion::MVFST);
RegularQuicPacketBuilder builder(
conn.udpSendPacketLen,
std::move(longHeader),
conn.ackStates.handshakeAckState.largestAckedByPeer);
PacketBuilderWrapper builderWrapper(builder, 13);
CryptoStreamScheduler scheduler(
conn,
*getCryptoStream(*conn.cryptoState, fizz::EncryptionLevel::Handshake));
conn.cryptoState->handshakeStream.lossBuffer.push_back(
StreamBuffer{folly::IOBuf::copyBuffer("shlo"), 0, false});
conn.cryptoState->handshakeStream.lossBuffer.push_back(StreamBuffer{
folly::IOBuf::copyBuffer(
"certificatethatisverylongseriouslythisisextremelylongandcannotfitintoapacket"),
7,
false});
EXPECT_TRUE(scheduler.writeCryptoData(builderWrapper));
}
TEST_F(QuicPacketSchedulerTest, CryptoWritePartialLossBuffer) {
QuicClientConnectionState conn;
auto connId = getTestConnectionId();
LongHeader longHeader(
LongHeader::Types::Initial,
ConnectionId(std::vector<uint8_t>()),
connId,
getNextPacketNum(conn, PacketNumberSpace::Initial),
QuicVersion::MVFST);
RegularQuicPacketBuilder builder(
25,
std::move(longHeader),
conn.ackStates.initialAckState.largestAckedByPeer);
CryptoStreamScheduler scheduler(
conn,
*getCryptoStream(*conn.cryptoState, fizz::EncryptionLevel::Plaintext));
conn.cryptoState->initialStream.lossBuffer.push_back(StreamBuffer{
folly::IOBuf::copyBuffer("return the special duration value max"),
0,
false});
EXPECT_TRUE(scheduler.writeCryptoData(builder));
EXPECT_EQ(builder.remainingSpaceInPkt(), 0);
EXPECT_FALSE(conn.cryptoState->initialStream.lossBuffer.empty());
}
TEST_F(QuicPacketSchedulerTest, StreamFrameSchedulerExists) {
QuicServerConnectionState conn;
conn.streamManager->setMaxLocalBidirectionalStreams(10);
auto connId = getTestConnectionId();
auto stream = conn.streamManager->createNextBidirectionalStream().value();
WindowUpdateScheduler scheduler(conn);
ShortHeader shortHeader(
ProtectionType::KeyPhaseZero,
connId,
getNextPacketNum(conn, PacketNumberSpace::AppData));
RegularQuicPacketBuilder builder(
conn.udpSendPacketLen,
std::move(shortHeader),
conn.ackStates.appDataAckState.largestAckedByPeer);
auto originalSpace = builder.remainingSpaceInPkt();
conn.streamManager->queueWindowUpdate(stream->id);
scheduler.writeWindowUpdates(builder);
EXPECT_LT(builder.remainingSpaceInPkt(), originalSpace);
}
TEST_F(QuicPacketSchedulerTest, StreamFrameNoSpace) {
QuicServerConnectionState conn;
conn.streamManager->setMaxLocalBidirectionalStreams(10);
auto connId = getTestConnectionId();
auto stream = conn.streamManager->createNextBidirectionalStream().value();
WindowUpdateScheduler scheduler(conn);
ShortHeader shortHeader(
ProtectionType::KeyPhaseZero,
connId,
getNextPacketNum(conn, PacketNumberSpace::AppData));
RegularQuicPacketBuilder builder(
conn.udpSendPacketLen,
std::move(shortHeader),
conn.ackStates.appDataAckState.largestAckedByPeer);
PacketBuilderWrapper builderWrapper(builder, 2);
auto originalSpace = builder.remainingSpaceInPkt();
conn.streamManager->queueWindowUpdate(stream->id);
scheduler.writeWindowUpdates(builderWrapper);
EXPECT_EQ(builder.remainingSpaceInPkt(), originalSpace);
}
TEST_F(QuicPacketSchedulerTest, StreamFrameSchedulerStreamNotExists) {
QuicServerConnectionState conn;
auto connId = getTestConnectionId();
StreamId nonExistentStream = 11;
WindowUpdateScheduler scheduler(conn);
ShortHeader shortHeader(
ProtectionType::KeyPhaseZero,
connId,
getNextPacketNum(conn, PacketNumberSpace::AppData));
RegularQuicPacketBuilder builder(
conn.udpSendPacketLen,
std::move(shortHeader),
conn.ackStates.appDataAckState.largestAckedByPeer);
auto originalSpace = builder.remainingSpaceInPkt();
conn.streamManager->queueWindowUpdate(nonExistentStream);
scheduler.writeWindowUpdates(builder);
EXPECT_EQ(builder.remainingSpaceInPkt(), originalSpace);
}
TEST_F(QuicPacketSchedulerTest, CloningSchedulerTest) {
QuicClientConnectionState conn;
FrameScheduler noopScheduler("frame");
ASSERT_FALSE(noopScheduler.hasData());
CloningScheduler cloningScheduler(noopScheduler, conn, "CopyCat", 0);
EXPECT_FALSE(cloningScheduler.hasData());
auto packetNum = addOutstandingPacket(conn);
// There needs to have retransmittable frame for the rebuilder to work
conn.outstandingPackets.back().packet.frames.push_back(
MaxDataFrame(conn.flowControlState.advertisedMaxOffset));
EXPECT_TRUE(cloningScheduler.hasData());
ASSERT_FALSE(noopScheduler.hasData());
ShortHeader header(
ProtectionType::KeyPhaseOne,
conn.clientConnectionId.value_or(getTestConnectionId()),
getNextPacketNum(conn, PacketNumberSpace::AppData));
RegularQuicPacketBuilder builder(
conn.udpSendPacketLen,
std::move(header),
conn.ackStates.appDataAckState.largestAckedByPeer);
auto result = cloningScheduler.scheduleFramesForPacket(
std::move(builder), kDefaultUDPSendPacketLen);
EXPECT_TRUE(result.first.hasValue() && result.second.hasValue());
EXPECT_EQ(packetNum, *result.first);
}
TEST_F(QuicPacketSchedulerTest, WriteOnlyOutstandingPacketsTest) {
QuicClientConnectionState conn;
FrameScheduler noopScheduler("frame");
ASSERT_FALSE(noopScheduler.hasData());
CloningScheduler cloningScheduler(noopScheduler, conn, "CopyCat", 0);
EXPECT_FALSE(cloningScheduler.hasData());
auto packetNum = addOutstandingPacket(conn);
// There needs to have retransmittable frame for the rebuilder to work
conn.outstandingPackets.back().packet.frames.push_back(
MaxDataFrame(conn.flowControlState.advertisedMaxOffset));
EXPECT_TRUE(cloningScheduler.hasData());
ASSERT_FALSE(noopScheduler.hasData());
ShortHeader header(
ProtectionType::KeyPhaseOne,
conn.clientConnectionId.value_or(getTestConnectionId()),
getNextPacketNum(conn, PacketNumberSpace::AppData));
RegularQuicPacketBuilder regularBuilder(
conn.udpSendPacketLen,
std::move(header),
conn.ackStates.appDataAckState.largestAckedByPeer);
// Create few frames
ConnectionCloseFrame connCloseFrame(
TransportErrorCode::FRAME_ENCODING_ERROR, "The sun is in the sky.");
MaxStreamsFrame maxStreamIdFrame(0x1024, true);
PingFrame pingFrame;
IntervalSet<PacketNum> ackBlocks;
ackBlocks.insert(10, 100);
ackBlocks.insert(200, 1000);
AckFrameMetaData ackMeta(
ackBlocks, std::chrono::microseconds(0), kDefaultAckDelayExponent);
// Write those framses with a regular builder
writeFrame(connCloseFrame, regularBuilder);
writeFrame(maxStreamIdFrame, regularBuilder);
writeFrame(pingFrame, regularBuilder);
writeAckFrame(ackMeta, regularBuilder);
auto result = cloningScheduler.scheduleFramesForPacket(
std::move(regularBuilder), kDefaultUDPSendPacketLen);
EXPECT_TRUE(result.first.hasValue() && result.second.hasValue());
EXPECT_EQ(packetNum, *result.first);
// written packet (result.second) should not have any frame in the builder
auto& writtenPacket = *result.second;
auto shortHeader = boost::get<ShortHeader>(&writtenPacket.packet.header);
CHECK(shortHeader);
EXPECT_EQ(ProtectionType::KeyPhaseOne, shortHeader->getProtectionType());
EXPECT_EQ(
conn.ackStates.appDataAckState.nextPacketNum,
shortHeader->getPacketSequenceNum());
// Test that the only frame that's written is maxdataframe
EXPECT_GE(writtenPacket.packet.frames.size(), 1);
auto& writtenFrame = writtenPacket.packet.frames.at(0);
auto maxDataFrame = boost::get<MaxDataFrame>(&writtenFrame);
CHECK(maxDataFrame);
for (auto& frame : writtenPacket.packet.frames) {
bool present = false;
/* the next four frames should not be written */
present |= boost::get<ConnectionCloseFrame>(&frame) ? true : false;
present |= boost::get<MaxStreamsFrame>(&frame) ? true : false;
present |= boost::get<PingFrame>(&frame) ? true : false;
present |= boost::get<WriteAckFrame>(&frame) ? true : false;
ASSERT_FALSE(present);
}
}
TEST_F(QuicPacketSchedulerTest, DoNotCloneProcessedClonedPacket) {
QuicClientConnectionState conn;
FrameScheduler noopScheduler("frame");
CloningScheduler cloningScheduler(noopScheduler, conn, "CopyCat", 0);
// Add two outstanding packets, but then mark the second one processed by
// adding a PacketEvent that's missing from the outstandingPacketEvents set
PacketNum expected = addOutstandingPacket(conn);
// There needs to have retransmittable frame for the rebuilder to work
conn.outstandingPackets.back().packet.frames.push_back(
MaxDataFrame(conn.flowControlState.advertisedMaxOffset));
addOutstandingPacket(conn);
conn.outstandingPackets.back().associatedEvent = 1;
// There needs to have retransmittable frame for the rebuilder to work
conn.outstandingPackets.back().packet.frames.push_back(
MaxDataFrame(conn.flowControlState.advertisedMaxOffset));
ShortHeader header(
ProtectionType::KeyPhaseOne,
conn.clientConnectionId.value_or(getTestConnectionId()),
getNextPacketNum(conn, PacketNumberSpace::AppData));
RegularQuicPacketBuilder builder(
conn.udpSendPacketLen,
std::move(header),
conn.ackStates.initialAckState.largestAckedByPeer);
auto result = cloningScheduler.scheduleFramesForPacket(
std::move(builder), kDefaultUDPSendPacketLen);
EXPECT_TRUE(result.first.hasValue() && result.second.hasValue());
EXPECT_EQ(expected, *result.first);
}
TEST_F(QuicPacketSchedulerTest, DoNotClonePureAck) {
QuicClientConnectionState conn;
FrameScheduler noopScheduler("frame");
CloningScheduler cloningScheduler(noopScheduler, conn, "CopyCat", 0);
// Add two outstanding packets, with second one being pureAck
auto expected = addOutstandingPacket(conn);
// There needs to have retransmittable frame for the rebuilder to work
conn.outstandingPackets.back().packet.frames.push_back(
MaxDataFrame(conn.flowControlState.advertisedMaxOffset));
addPureAckOutstandingPacket(conn);
conn.outstandingPackets.back().packet.frames.push_back(WriteAckFrame());
ShortHeader header(
ProtectionType::KeyPhaseOne,
conn.clientConnectionId.value_or(getTestConnectionId()),
getNextPacketNum(conn, PacketNumberSpace::AppData));
RegularQuicPacketBuilder builder(
conn.udpSendPacketLen,
std::move(header),
conn.ackStates.appDataAckState.largestAckedByPeer);
auto result = cloningScheduler.scheduleFramesForPacket(
std::move(builder), kDefaultUDPSendPacketLen);
EXPECT_TRUE(result.first.hasValue() && result.second.hasValue());
EXPECT_EQ(expected, *result.first);
}
TEST_F(QuicPacketSchedulerTest, CloneSchedulerHasDataIgnoresNonAppData) {
QuicClientConnectionState conn;
FrameScheduler noopScheduler("frame");
CloningScheduler cloningScheduler(noopScheduler, conn, "CopyCat", 0);
EXPECT_FALSE(cloningScheduler.hasData());
addHandshakeOutstandingPacket(conn);
EXPECT_FALSE(cloningScheduler.hasData());
addInitialOutstandingPacket(conn);
EXPECT_FALSE(cloningScheduler.hasData());
addOutstandingPacket(conn);
EXPECT_TRUE(cloningScheduler.hasData());
}
TEST_F(QuicPacketSchedulerTest, DoNotCloneHandshake) {
QuicClientConnectionState conn;
FrameScheduler noopScheduler("frame");
CloningScheduler cloningScheduler(noopScheduler, conn, "CopyCat", 0);
// Add two outstanding packets, with second one being handshake
auto expected = addOutstandingPacket(conn);
// There needs to have retransmittable frame for the rebuilder to work
conn.outstandingPackets.back().packet.frames.push_back(
MaxDataFrame(conn.flowControlState.advertisedMaxOffset));
addHandshakeOutstandingPacket(conn);
conn.outstandingPackets.back().packet.frames.push_back(
MaxDataFrame(conn.flowControlState.advertisedMaxOffset));
ShortHeader header(
ProtectionType::KeyPhaseOne,
conn.clientConnectionId.value_or(getTestConnectionId()),
getNextPacketNum(conn, PacketNumberSpace::AppData));
RegularQuicPacketBuilder builder(
conn.udpSendPacketLen,
std::move(header),
conn.ackStates.appDataAckState.largestAckedByPeer);
auto result = cloningScheduler.scheduleFramesForPacket(
std::move(builder), kDefaultUDPSendPacketLen);
EXPECT_TRUE(result.first.hasValue() && result.second.hasValue());
EXPECT_EQ(expected, *result.first);
}
TEST_F(QuicPacketSchedulerTest, CloneSchedulerUseNormalSchedulerFirst) {
QuicClientConnectionState conn;
MockFrameScheduler mockScheduler;
CloningScheduler cloningScheduler(mockScheduler, conn, "Mocker", 0);
ShortHeader header(
ProtectionType::KeyPhaseOne,
conn.clientConnectionId.value_or(getTestConnectionId()),
getNextPacketNum(conn, PacketNumberSpace::AppData));
EXPECT_CALL(mockScheduler, hasData()).Times(1).WillOnce(Return(true));
EXPECT_CALL(mockScheduler, _scheduleFramesForPacket(_, _))
.Times(1)
.WillOnce(
Invoke([&, headerCopy = header](
std::unique_ptr<RegularQuicPacketBuilder>&, uint32_t) {
RegularQuicWritePacket packet(headerCopy);
packet.frames.push_back(MaxDataFrame(2832));
RegularQuicPacketBuilder::Packet builtPacket(
std::move(packet),
folly::IOBuf::copyBuffer("if you are the dealer"),
folly::IOBuf::copyBuffer("I'm out of the game"));
return std::make_pair(folly::none, std::move(builtPacket));
}));
RegularQuicPacketBuilder builder(
conn.udpSendPacketLen,
std::move(header),
conn.ackStates.appDataAckState.largestAckedByPeer);
auto result = cloningScheduler.scheduleFramesForPacket(
std::move(builder), kDefaultUDPSendPacketLen);
EXPECT_EQ(folly::none, result.first);
folly::variant_match(
result.second->packet.header,
[&](const ShortHeader& shortHeader) {
EXPECT_EQ(ProtectionType::KeyPhaseOne, shortHeader.getProtectionType());
EXPECT_EQ(
conn.ackStates.appDataAckState.nextPacketNum,
shortHeader.getPacketSequenceNum());
},
[&](const LongHeader&) {
ASSERT_FALSE(true); // should not happen
});
EXPECT_EQ(1, result.second->packet.frames.size());
folly::variant_match(
result.second->packet.frames.front(),
[&](const MaxDataFrame& frame) { EXPECT_EQ(2832, frame.maximumData); },
[&](const auto&) {
ASSERT_FALSE(true); // should not happen
});
EXPECT_TRUE(folly::IOBufEqualTo{}(
*folly::IOBuf::copyBuffer("if you are the dealer"),
*result.second->header));
EXPECT_TRUE(folly::IOBufEqualTo{}(
*folly::IOBuf::copyBuffer("I'm out of the game"), *result.second->body));
}
TEST_F(QuicPacketSchedulerTest, CloneWillGenerateNewWindowUpdate) {
QuicClientConnectionState conn;
conn.streamManager->setMaxLocalBidirectionalStreams(10);
auto stream = conn.streamManager->createNextBidirectionalStream().value();
FrameScheduler noopScheduler("frame");
CloningScheduler cloningScheduler(noopScheduler, conn, "GiantsShoulder", 0);
auto expectedPacketEvent = addOutstandingPacket(conn);
ASSERT_EQ(1, conn.outstandingPackets.size());
conn.outstandingPackets.back().packet.frames.push_back(MaxDataFrame(1000));
conn.outstandingPackets.back().packet.frames.push_back(
MaxStreamDataFrame(stream->id, 1000));
conn.flowControlState.advertisedMaxOffset = 1000;
stream->flowControlState.advertisedMaxOffset = 1000;
conn.flowControlState.sumCurReadOffset = 300;
conn.flowControlState.windowSize = 3000;
stream->currentReadOffset = 200;
stream->flowControlState.windowSize = 1500;
ShortHeader header(
ProtectionType::KeyPhaseOne,
conn.clientConnectionId.value_or(getTestConnectionId()),
getNextPacketNum(conn, PacketNumberSpace::AppData));
RegularQuicPacketBuilder builder(
conn.udpSendPacketLen,
std::move(header),
conn.ackStates.appDataAckState.largestAckedByPeer);
auto packetResult = cloningScheduler.scheduleFramesForPacket(
std::move(builder), conn.udpSendPacketLen);
EXPECT_EQ(expectedPacketEvent, *packetResult.first);
int32_t verifyConnWindowUpdate = 1, verifyStreamWindowUpdate = 1;
for (const auto& frame : packetResult.second->packet.frames) {
folly::variant_match(
frame,
[&](const MaxStreamDataFrame& maxStreamDataFrame) {
EXPECT_EQ(stream->id, maxStreamDataFrame.streamId);
verifyStreamWindowUpdate--;
},
[&](const MaxDataFrame&) { verifyConnWindowUpdate--; },
[&](const PaddingFrame&) {},
[&](const auto&) {
// should never happen
EXPECT_TRUE(false);
});
}
EXPECT_EQ(0, verifyStreamWindowUpdate);
EXPECT_EQ(0, verifyConnWindowUpdate);
// Verify the built out packet has refreshed window update values
EXPECT_GE(packetResult.second->packet.frames.size(), 2);
uint32_t streamWindowUpdateCounter = 0;
uint32_t connWindowUpdateCounter = 0;
for (auto& streamFlowControl :
all_frames<MaxStreamDataFrame>(packetResult.second->packet.frames)) {
streamWindowUpdateCounter++;
EXPECT_EQ(1700, streamFlowControl.maximumData);
}
for (auto& connFlowControl :
all_frames<MaxDataFrame>(packetResult.second->packet.frames)) {
connWindowUpdateCounter++;
EXPECT_EQ(3300, connFlowControl.maximumData);
}
EXPECT_EQ(1, connWindowUpdateCounter);
EXPECT_EQ(1, streamWindowUpdateCounter);
}
class AckSchedulingTest : public TestWithParam<PacketNumberSpace> {};
TEST_F(QuicPacketSchedulerTest, AckStateHasAcksToSchedule) {
QuicClientConnectionState conn;
EXPECT_FALSE(hasAcksToSchedule(conn.ackStates.initialAckState));
EXPECT_FALSE(hasAcksToSchedule(conn.ackStates.handshakeAckState));
EXPECT_FALSE(hasAcksToSchedule(conn.ackStates.appDataAckState));
conn.ackStates.initialAckState.acks.insert(0, 100);
EXPECT_TRUE(hasAcksToSchedule(conn.ackStates.initialAckState));
conn.ackStates.handshakeAckState.acks.insert(0, 100);
conn.ackStates.handshakeAckState.largestAckScheduled = 200;
EXPECT_FALSE(hasAcksToSchedule(conn.ackStates.handshakeAckState));
conn.ackStates.handshakeAckState.largestAckScheduled = folly::none;
EXPECT_TRUE(hasAcksToSchedule(conn.ackStates.handshakeAckState));
}
TEST_F(QuicPacketSchedulerTest, AckSchedulerHasAcksToSchedule) {
QuicClientConnectionState conn;
AckScheduler initialAckScheduler(
conn, getAckState(conn, PacketNumberSpace::Initial));
AckScheduler handshakeAckScheduler(
conn, getAckState(conn, PacketNumberSpace::Handshake));
AckScheduler appDataAckScheduler(
conn, getAckState(conn, PacketNumberSpace::AppData));
EXPECT_FALSE(initialAckScheduler.hasPendingAcks());
EXPECT_FALSE(handshakeAckScheduler.hasPendingAcks());
EXPECT_FALSE(appDataAckScheduler.hasPendingAcks());
conn.ackStates.initialAckState.acks.insert(0, 100);
EXPECT_TRUE(initialAckScheduler.hasPendingAcks());
conn.ackStates.handshakeAckState.acks.insert(0, 100);
conn.ackStates.handshakeAckState.largestAckScheduled = 200;
EXPECT_FALSE(handshakeAckScheduler.hasPendingAcks());
conn.ackStates.handshakeAckState.largestAckScheduled = folly::none;
EXPECT_TRUE(handshakeAckScheduler.hasPendingAcks());
}
TEST_F(QuicPacketSchedulerTest, ConnHasAcksToSchedule) {
QuicClientConnectionState conn;
EXPECT_FALSE(hasAcksToSchedule(conn));
conn.ackStates.initialAckState.acks.insert(0, 100);
EXPECT_FALSE(hasAcksToSchedule(conn));
conn.initialWriteCipher = createNoOpAead();
EXPECT_TRUE(hasAcksToSchedule(conn));
conn.ackStates.initialAckState.acks.clear();
EXPECT_FALSE(hasAcksToSchedule(conn));
conn.oneRttWriteCipher = createNoOpAead();
EXPECT_FALSE(hasAcksToSchedule(conn));
}
TEST_F(QuicPacketSchedulerTest, LargestAckToSend) {
QuicClientConnectionState conn;
EXPECT_EQ(folly::none, largestAckToSend(conn.ackStates.initialAckState));
EXPECT_EQ(folly::none, largestAckToSend(conn.ackStates.handshakeAckState));
EXPECT_EQ(folly::none, largestAckToSend(conn.ackStates.appDataAckState));
conn.ackStates.initialAckState.acks.insert(0, 50);
conn.ackStates.handshakeAckState.acks.insert(0, 50);
conn.ackStates.handshakeAckState.acks.insert(75, 150);
EXPECT_EQ(50, *largestAckToSend(conn.ackStates.initialAckState));
EXPECT_EQ(150, *largestAckToSend(conn.ackStates.handshakeAckState));
EXPECT_EQ(folly::none, largestAckToSend(conn.ackStates.appDataAckState));
}
} // namespace test
} // namespace quic

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

118
quic/api/test/TARGETS Normal file
View File

@ -0,0 +1,118 @@
# @autodeps
load("@fbcode_macros//build_defs:cpp_library.bzl", "cpp_library")
load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest")
cpp_library(
name = "mocks",
headers = [
"MockQuicSocket.h",
"MockQuicStats.h",
"Mocks.h",
],
deps = [
"//folly/io/async:async_base",
"//folly/portability:gmock",
"//quic:exception",
"//quic/api:transport",
"//quic/codec:types",
"//quic/common:timers",
"//quic/server:server",
"//quic/state:state_machine",
"//quic/state:stats_callback",
],
)
cpp_unittest(
name = "QuicTransportTest",
srcs = [
"QuicTransportTest.cpp",
],
deps = [
":mocks",
"//fizz/crypto/aead/test:mocks",
"//folly:random",
"//folly/io:iobuf",
"//folly/io/async/test:mocks",
"//quic/api:transport",
"//quic/common:timers",
"//quic/common/test:test_utils",
"//quic/server/state:server",
"//quic/state:stream_functions",
"//quic/state/test:mocks",
],
external_deps = [
("googletest", None, "gmock"),
],
)
cpp_unittest(
name = "QuicTransportBaseTest",
srcs = [
"QuicTransportBaseTest.cpp",
],
deps = [
":mocks",
"//folly/io/async/test:mocks",
"//folly/portability:gmock",
"//folly/portability:gtest",
"//quic/api:transport",
"//quic/codec:types",
"//quic/common/test:test_utils",
"//quic/server/state:server",
"//quic/state:stream_functions",
"//quic/state/test:mocks",
],
)
cpp_unittest(
name = "QuicTransportFunctionsTest",
srcs = [
"QuicTransportFunctionsTest.cpp",
],
deps = [
":mocks",
"//folly/io/async/test:mocks",
"//quic/api:transport",
"//quic/common/test:test_utils",
"//quic/server/state:server",
"//quic/state/test:mocks",
],
)
cpp_unittest(
name = "QuicPacketSchedulerTest",
srcs = [
"QuicPacketSchedulerTest.cpp",
],
deps = [
":mocks",
"//folly/portability:gtest",
"//quic/api:transport",
"//quic/client/state:client",
"//quic/codec:pktbuilder",
"//quic/common/test:test_utils",
"//quic/server/state:server",
],
)
cpp_unittest(
name = "IoBufQuicBatchTest",
srcs = [
"IoBufQuicBatchTest.cpp",
],
deps = [
"//quic/api:transport",
"//quic/state:state_machine",
],
)
cpp_unittest(
name = "QuicBatchWriterTest",
srcs = [
"QuicBatchWriterTest.cpp",
],
deps = [
"//quic/api:transport",
],
)

View File

@ -0,0 +1,63 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
add_library(
mvfst_client STATIC
QuicClientTransport.cpp
handshake/ClientHandshake.cpp
state/ClientStateMachine.cpp
)
target_include_directories(
mvfst_client PUBLIC
$<BUILD_INTERFACE:${QUIC_FBCODE_ROOT}>
$<INSTALL_INTERFACE:include/>
)
target_compile_options(
mvfst_client
PRIVATE
${_QUIC_COMMON_COMPILE_OPTIONS}
)
add_dependencies(
mvfst_client
mvfst_flowcontrol
mvfst_happyeyeballs
mvfst_loss
mvfst_state_ack_handler
mvfst_state_pacing_functions
mvfst_transport
)
target_link_libraries(
mvfst_client PUBLIC
Folly::folly
mvfst_flowcontrol
mvfst_happyeyeballs
mvfst_loss
mvfst_state_ack_handler
mvfst_state_pacing_functions
mvfst_transport
)
file(
GLOB_RECURSE QUIC_API_HEADERS_TOINSTALL
RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
*.h
)
list(FILTER QUIC_API_HEADERS_TOINSTALL EXCLUDE REGEX test/)
foreach(header ${QUIC_API_HEADERS_TOINSTALL})
get_filename_component(header_dir ${header} DIRECTORY)
install(FILES ${header} DESTINATION include/quic/client/${header_dir})
endforeach()
install(
TARGETS mvfst_client
EXPORT mvfst-exports
DESTINATION lib
)
add_subdirectory(test)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,204 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#pragma once
#include <folly/Random.h>
#include <folly/SocketAddress.h>
#include <folly/io/async/AsyncUDPSocket.h>
#include <quic/api/QuicTransportBase.h>
#include <quic/client/handshake/QuicPskCache.h>
#include <quic/client/state/ClientStateMachine.h>
#include <sys/socket.h>
namespace quic {
class QuicClientTransport
: public QuicTransportBase,
public folly::AsyncUDPSocket::ReadCallback,
public folly::AsyncUDPSocket::ErrMessageCallback,
public std::enable_shared_from_this<QuicClientTransport>,
private ClientHandshake::HandshakeCallback {
public:
QuicClientTransport(
folly::EventBase* evb,
std::unique_ptr<folly::AsyncUDPSocket> socket);
~QuicClientTransport() override;
/**
* Returns an un-connected QuicClientTransport which is self-owning.
* The transport is cleaned up when the app calls close() or closeNow() on the
* transport, or on receiving a terminal ConnectionCallback supplied on
* start().
* The transport is self owning in this case is to be able to
* deal with cases where the app wants to dispose of the transport, however
* the peer is still sending us packets. If we do not keep the transport alive
* for this period, the kernel will generate unwanted ICMP echo messages.
*/
template <class TransportType = QuicClientTransport>
static std::shared_ptr<TransportType> newClient(
folly::EventBase* evb,
std::unique_ptr<folly::AsyncUDPSocket> sock) {
auto client = std::make_shared<TransportType>(evb, std::move(sock));
client->setSelfOwning();
return client;
}
/**
* Supply the hostname to use to validate the server. Must be set before
* start().
*/
void setHostname(const std::string& hostname);
/**
* Set the client context for fizz. Must be set before start()
*/
void setFizzClientContext(
std::shared_ptr<const fizz::client::FizzClientContext> ctx);
/**
* Set a custom certificate verifier. Must be set before start().
*/
void setCertificateVerifier(
std::shared_ptr<const fizz::CertificateVerifier> verifier);
/**
* Supplies a new peer address to use for the connection. This must be called
* at least once before start().
*/
void addNewPeerAddress(folly::SocketAddress peerAddress);
void addNewSocket(std::unique_ptr<folly::AsyncUDPSocket> socket);
void setHappyEyeballsEnabled(bool happyEyeballsEnabled);
virtual void setHappyEyeballsCachedFamily(sa_family_t cachedFamily);
/**
* Set the cache that remembers psk and server transport parameters from
* last connection. This is useful for session resumption and 0-rtt.
*/
void setPskCache(std::shared_ptr<QuicPskCache> pskCache);
/**
* Starts the connection.
*/
virtual void start(ConnectionCallback* cb);
/**
* Returns whether or not TLS is resumed.
*/
bool isTLSResumed() const;
// From QuicTransportBase
void onReadData(const folly::SocketAddress& peer, NetworkData&& networkData)
override;
void writeData() override;
void closeTransport() override;
void unbindConnection() override;
bool hasWriteCipher() const override;
std::shared_ptr<QuicTransportBase> sharedGuard() override;
// folly::AsyncUDPSocket::ReadCallback
void onReadClosed() noexcept override {}
void onReadError(const folly::AsyncSocketException&) noexcept override {}
// folly::AsyncUDPSocket::ErrMessageCallback
void errMessage(const cmsghdr& cmsg) noexcept override;
void errMessageError(const folly::AsyncSocketException&) noexcept override {}
/**
* Make QuicClient transport self owning.
*/
void setSelfOwning();
/**
* Used to set private transport parameters that are not in the
* TransportParameterId enum.
* As per section 22.2 of the IETF QUIC draft version 17, private transport
* parameters must have IDs greater than 0x3fff.
*/
bool setCustomTransportParameter(
std::unique_ptr<CustomTransportParameter> customParam);
class HappyEyeballsConnAttemptDelayTimeout
: public folly::HHWheelTimer::Callback {
public:
explicit HappyEyeballsConnAttemptDelayTimeout(
QuicClientTransport* transport)
: transport_(transport) {}
void timeoutExpired() noexcept override {
transport_->happyEyeballsConnAttemptDelayTimeoutExpired();
}
void callbackCanceled() noexcept override {}
private:
QuicClientTransport* transport_;
};
protected:
void getReadBuffer(void** buf, size_t* len) noexcept override;
// From AsyncUDPSocket::ReadCallback
void onDataAvailable(
const folly::SocketAddress& server,
size_t len,
bool truncated) noexcept override;
void processUDPData(
const folly::SocketAddress& peer,
NetworkData&& networkData);
void processPacketData(
const folly::SocketAddress& peer,
TimePoint receiveTimePoint,
folly::IOBufQueue& packetQueue);
void startCryptoHandshake();
void happyEyeballsConnAttemptDelayTimeoutExpired() noexcept;
// From ClientHandshake::HandshakeCallback
void onNewCachedPsk(
fizz::client::NewCachedPsk& newCachedPsk) noexcept override;
Buf readBuffer_;
folly::Optional<std::string> hostname_;
std::shared_ptr<const fizz::client::FizzClientContext> ctx_;
std::shared_ptr<const fizz::CertificateVerifier> verifier_;
HappyEyeballsConnAttemptDelayTimeout happyEyeballsConnAttemptDelayTimeout_;
bool serverInitialParamsSet_{false};
uint64_t peerAdvertisedInitialMaxData_{0};
uint64_t peerAdvertisedInitialMaxStreamDataBidiLocal_{0};
uint64_t peerAdvertisedInitialMaxStreamDataBidiRemote_{0};
uint64_t peerAdvertisedInitialMaxStreamDataUni_{0};
private:
void cacheServerInitialParams(
uint64_t peerAdvertisedInitialMaxData,
uint64_t peerAdvertisedInitialMaxStreamDataBidiLocal,
uint64_t peerAdvertisedInitialMaxStreamDataBidiRemote,
uint64_t peerAdvertisedInitialMaxStreamDataUni);
folly::Optional<QuicCachedPsk> getPsk();
void removePsk();
void setPartialReliabilityTransportParameter();
private:
bool replaySafeNotified_{false};
// Set it QuicClientTransport is in a self owning mode. This will be cleaned
// up when the caller invokes a terminal call to the transport.
std::shared_ptr<QuicClientTransport> selfOwning_;
bool happyEyeballsEnabled_{false};
sa_family_t happyEyeballsCachedFamily_{AF_UNSPEC};
std::shared_ptr<QuicPskCache> pskCache_;
QuicClientConnectionState* clientConn_;
std::vector<TransportParameter> customTransportParameters_;
};
} // namespace quic

28
quic/client/TARGETS Normal file
View File

@ -0,0 +1,28 @@
# @autodeps
load("@fbcode_macros//build_defs:cpp_library.bzl", "cpp_library")
cpp_library(
name = "client",
srcs = [
"QuicClientTransport.cpp",
],
headers = [
"QuicClientTransport.h",
],
deps = [
"//folly:network_address",
"//folly:random",
"//folly/io/async:async_udp_socket",
"//folly/portability:sockets",
"//quic/api:transport",
"//quic/client/handshake:client_extension",
"//quic/client/handshake:psk_cache",
"//quic/client/state:client",
"//quic/flowcontrol:flow_control",
"//quic/happyeyeballs:happyeyeballs",
"//quic/loss:loss",
"//quic/state:ack_handler",
"//quic/state:pacing_functions",
],
)

View File

@ -0,0 +1,394 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#include <quic/client/handshake/ClientHandshake.h>
#include <fizz/protocol/Protocol.h>
#include <quic/state/QuicStreamFunctions.h>
namespace quic {
ClientHandshake::ClientHandshake(QuicCryptoState& cryptoState)
: cryptoState_(cryptoState), visitor_(*this) {}
void ClientHandshake::connect(
std::shared_ptr<const fizz::client::FizzClientContext> context,
std::shared_ptr<const fizz::CertificateVerifier> verifier,
folly::Optional<std::string> hostname,
folly::Optional<fizz::client::CachedPsk> cachedPsk,
const std::shared_ptr<ClientTransportParametersExtension>& transportParams,
HandshakeCallback* callback) {
transportParams_ = transportParams;
callback_ = callback;
auto ctx = std::make_shared<fizz::client::FizzClientContext>(*context);
ctx->setFactory(std::make_shared<QuicFizzFactory>());
ctx->setCompatibilityMode(false);
// Since Draft-17, EOED should not be sent
ctx->setOmitEarlyRecordLayer(true);
processActions(machine_.processConnect(
state_,
std::move(ctx),
std::move(verifier),
std::move(hostname),
std::move(cachedPsk),
transportParams));
}
void ClientHandshake::doHandshake(
std::unique_ptr<folly::IOBuf> data,
fizz::EncryptionLevel encryptionLevel) {
if (!data) {
return;
}
// TODO: deal with clear text alert messages. It's possible that a MITM who
// mucks with the finished messages could cause the decryption to be invalid
// on the server, which would result in a cleartext close or a cleartext
// alert. We currently switch to 1-rtt ciphers immediately for reads and
// throw away the cleartext cipher for reads, this would result in us
// dropping the alert and timing out instead.
if (phase_ == Phase::Initial) {
// This could be an HRR or a cleartext alert.
phase_ = Phase::Handshake;
}
// First add it to the right read buffer.
switch (encryptionLevel) {
case fizz::EncryptionLevel::Plaintext:
initialReadBuf_.append(std::move(data));
break;
case fizz::EncryptionLevel::Handshake:
handshakeReadBuf_.append(std::move(data));
break;
case fizz::EncryptionLevel::EarlyData:
case fizz::EncryptionLevel::AppTraffic:
appDataReadBuf_.append(std::move(data));
break;
}
// Get the current buffer type the transport is accepting.
waitForData_ = false;
while (!waitForData_) {
switch (state_.readRecordLayer()->getEncryptionLevel()) {
case fizz::EncryptionLevel::Plaintext:
processActions(machine_.processSocketData(state_, initialReadBuf_));
break;
case fizz::EncryptionLevel::Handshake:
processActions(machine_.processSocketData(state_, handshakeReadBuf_));
break;
case fizz::EncryptionLevel::EarlyData:
case fizz::EncryptionLevel::AppTraffic:
processActions(machine_.processSocketData(state_, appDataReadBuf_));
break;
}
if (error_) {
error_.throw_exception();
}
}
}
std::unique_ptr<fizz::Aead> ClientHandshake::getOneRttWriteCipher() {
if (error_) {
error_.throw_exception();
}
return std::move(oneRttWriteCipher_);
}
std::unique_ptr<fizz::Aead> ClientHandshake::getOneRttReadCipher() {
if (error_) {
error_.throw_exception();
}
return std::move(oneRttReadCipher_);
}
std::unique_ptr<fizz::Aead> ClientHandshake::getZeroRttWriteCipher() {
if (error_) {
error_.throw_exception();
}
return std::move(zeroRttWriteCipher_);
}
std::unique_ptr<fizz::Aead> ClientHandshake::getHandshakeReadCipher() {
if (error_) {
error_.throw_exception();
}
return std::move(handshakeReadCipher_);
}
std::unique_ptr<fizz::Aead> ClientHandshake::getHandshakeWriteCipher() {
if (error_) {
error_.throw_exception();
}
return std::move(handshakeWriteCipher_);
}
std::unique_ptr<PacketNumberCipher>
ClientHandshake::getOneRttReadHeaderCipher() {
if (error_) {
error_.throw_exception();
}
return std::move(oneRttReadHeaderCipher_);
}
std::unique_ptr<PacketNumberCipher>
ClientHandshake::getOneRttWriteHeaderCipher() {
if (error_) {
error_.throw_exception();
}
return std::move(oneRttWriteHeaderCipher_);
}
std::unique_ptr<PacketNumberCipher>
ClientHandshake::getHandshakeReadHeaderCipher() {
if (error_) {
error_.throw_exception();
}
return std::move(handshakeReadHeaderCipher_);
}
std::unique_ptr<PacketNumberCipher>
ClientHandshake::getHandshakeWriteHeaderCipher() {
if (error_) {
error_.throw_exception();
}
return std::move(handshakeWriteHeaderCipher_);
}
std::unique_ptr<PacketNumberCipher>
ClientHandshake::getZeroRttWriteHeaderCipher() {
if (error_) {
error_.throw_exception();
}
return std::move(zeroRttWriteHeaderCipher_);
}
/**
* Notify the crypto layer that we received one rtt protected data.
* This allows us to know that the peer has implicitly acked the 1-rtt keys.
*/
void ClientHandshake::onRecvOneRttProtectedData() {
if (phase_ != Phase::Established) {
phase_ = Phase::Established;
}
}
ClientHandshake::Phase ClientHandshake::getPhase() const {
return phase_;
}
folly::Optional<ServerTransportParameters>
ClientHandshake::getServerTransportParams() {
return transportParams_->getServerTransportParams();
}
bool ClientHandshake::isTLSResumed() const {
auto pskType = state_.pskType();
return pskType && *pskType == fizz::PskType::Resumption;
}
folly::Optional<bool> ClientHandshake::getZeroRttRejected() {
return std::move(zeroRttRejected_);
}
const fizz::client::State& ClientHandshake::getState() const {
return state_;
}
const folly::Optional<std::string>& ClientHandshake::getApplicationProtocol()
const {
auto& earlyDataParams = state_.earlyDataParams();
if (earlyDataParams) {
return earlyDataParams->alpn;
} else {
return state_.alpn();
}
}
void ClientHandshake::computeOneRttCipher(
const fizz::client::ReportHandshakeSuccess& handshakeSuccess) {
// The 1-rtt handshake should have succeeded if we know that the early
// write failed. We currently treat the data as lost.
// TODO: we need to deal with HRR based rejection as well, however we don't
// have an API right now.
if (earlyDataAttempted_ && !handshakeSuccess.earlyDataAccepted) {
if (fizz::client::earlyParametersMatch(state_)) {
zeroRttRejected_ = true;
} else {
// TODO: support app retry of zero rtt data.
error_ = folly::make_exception_wrapper<QuicInternalException>(
"Changing parameters when early data attempted not supported",
LocalErrorCode::EARLY_DATA_REJECTED);
return;
}
}
// After a successful handshake we should send packets with the type of
// ClientCleartext. We assume that by the time we get the data for the QUIC
// stream, the server would have also acked all the client initial packets.
phase_ = Phase::OneRttKeysDerived;
}
void ClientHandshake::computeZeroRttCipher() {
VLOG(10) << "Computing Client zero rtt keys";
auto& earlyDataParams = state_.earlyDataParams();
if (!earlyDataParams) {
error_ = folly::make_exception_wrapper<QuicTransportException>(
"Invalid early data params", TransportErrorCode::TLS_HANDSHAKE_FAILED);
return;
}
earlyDataAttempted_ = true;
CHECK(state_.earlyDataParams().hasValue());
}
void ClientHandshake::processActions(fizz::client::Actions actions) {
for (auto& action : actions) {
boost::apply_visitor(visitor_, action);
}
}
ClientHandshake::ActionMoveVisitor::ActionMoveVisitor(ClientHandshake& client)
: client_(client) {}
void ClientHandshake::ActionMoveVisitor::operator()(fizz::DeliverAppData&) {
client_.error_ = folly::make_exception_wrapper<QuicTransportException>(
"Invalid app data on crypto stream",
TransportErrorCode::PROTOCOL_VIOLATION);
}
void ClientHandshake::ActionMoveVisitor::operator()(
fizz::WriteToSocket& write) {
for (auto& content : write.contents) {
auto& cryptoState = client_.cryptoState_;
if (content.encryptionLevel == fizz::EncryptionLevel::AppTraffic) {
// Don't write 1-rtt handshake data on the client.
continue;
}
auto cryptoStream = getCryptoStream(cryptoState, content.encryptionLevel);
writeDataToQuicStream(*cryptoStream, std::move(content.data));
}
}
void ClientHandshake::ActionMoveVisitor::operator()(
fizz::client::ReportEarlyHandshakeSuccess&) {
client_.computeZeroRttCipher();
}
void ClientHandshake::ActionMoveVisitor::operator()(
fizz::client::ReportHandshakeSuccess& handshakeSuccess) {
client_.computeOneRttCipher(handshakeSuccess);
}
void ClientHandshake::ActionMoveVisitor::operator()(
fizz::client::ReportEarlyWriteFailed&) {
LOG(DFATAL) << "QUIC TLS app data write";
}
void ClientHandshake::ActionMoveVisitor::operator()(fizz::ReportError& err) {
auto errMsg = err.error.what();
if (errMsg.empty()) {
errMsg = "Error during handshake";
}
client_.error_ = folly::make_exception_wrapper<QuicTransportException>(
errMsg.toStdString(), TransportErrorCode::TLS_HANDSHAKE_FAILED);
}
void ClientHandshake::ActionMoveVisitor::operator()(fizz::WaitForData&) {
client_.waitForData_ = true;
}
void ClientHandshake::ActionMoveVisitor::operator()(
fizz::client::MutateState& mutator) {
mutator(client_.state_);
}
void ClientHandshake::ActionMoveVisitor::operator()(
fizz::client::NewCachedPsk& newCachedPsk) {
if (client_.callback_) {
client_.callback_->onNewCachedPsk(newCachedPsk);
}
}
void ClientHandshake::ActionMoveVisitor::operator()(fizz::EndOfData&) {
client_.error_ = folly::make_exception_wrapper<QuicTransportException>(
"unexpected close notify", TransportErrorCode::INTERNAL_ERROR);
}
void ClientHandshake::ActionMoveVisitor::operator()(
fizz::SecretAvailable& secretAvailable) {
QuicFizzFactory factory;
folly::variant_match(
secretAvailable.secret.type,
[&](fizz::EarlySecrets earlySecrets) {
switch (earlySecrets) {
case fizz::EarlySecrets::ClientEarlyTraffic: {
auto cipher = client_.state_.earlyDataParams()->cipher;
auto keyScheduler =
client_.state_.context()->getFactory()->makeKeyScheduler(
cipher);
client_.zeroRttWriteCipher_ =
fizz::Protocol::deriveRecordAeadWithLabel(
*client_.state_.context()->getFactory(),
*keyScheduler,
cipher,
folly::range(secretAvailable.secret.secret),
kQuicKeyLabel,
kQuicIVLabel);
client_.zeroRttWriteHeaderCipher_ = makePacketNumberCipher(
&factory, folly::range(secretAvailable.secret.secret), cipher);
break;
}
default:
break;
}
},
[&](fizz::HandshakeSecrets handshakeSecrets) {
auto aead = fizz::Protocol::deriveRecordAeadWithLabel(
*client_.state_.context()->getFactory(),
*client_.state_.keyScheduler(),
*client_.state_.cipher(),
folly::range(secretAvailable.secret.secret),
kQuicKeyLabel,
kQuicIVLabel);
auto headerCipher = makePacketNumberCipher(
&factory,
folly::range(secretAvailable.secret.secret),
*client_.state_.cipher());
switch (handshakeSecrets) {
case fizz::HandshakeSecrets::ClientHandshakeTraffic:
client_.handshakeWriteCipher_ = std::move(aead);
client_.handshakeWriteHeaderCipher_ = std::move(headerCipher);
break;
case fizz::HandshakeSecrets::ServerHandshakeTraffic:
client_.handshakeReadCipher_ = std::move(aead);
client_.handshakeReadHeaderCipher_ = std::move(headerCipher);
break;
}
},
[&](fizz::AppTrafficSecrets appSecrets) {
auto aead = fizz::Protocol::deriveRecordAeadWithLabel(
*client_.state_.context()->getFactory(),
*client_.state_.keyScheduler(),
*client_.state_.cipher(),
folly::range(secretAvailable.secret.secret),
kQuicKeyLabel,
kQuicIVLabel);
auto appHeaderCipher = makePacketNumberCipher(
&factory,
folly::range(secretAvailable.secret.secret),
*client_.state_.cipher());
switch (appSecrets) {
case fizz::AppTrafficSecrets::ClientAppTraffic:
client_.oneRttWriteCipher_ = std::move(aead);
client_.oneRttWriteHeaderCipher_ = std::move(appHeaderCipher);
break;
case fizz::AppTrafficSecrets::ServerAppTraffic:
client_.oneRttReadCipher_ = std::move(aead);
client_.oneRttReadHeaderCipher_ = std::move(appHeaderCipher);
break;
}
},
[&](auto) {});
}
} // namespace quic

View File

@ -0,0 +1,230 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#pragma once
#include <fizz/client/ClientProtocol.h>
#include <fizz/client/EarlyDataRejectionPolicy.h>
#include <fizz/client/FizzClientContext.h>
#include <fizz/client/PskCache.h>
#include <fizz/protocol/DefaultCertificateVerifier.h>
#include <folly/io/IOBufQueue.h>
#include <folly/io/async/DelayedDestruction.h>
#include <folly/ExceptionWrapper.h>
#include <quic/QuicConstants.h>
#include <quic/QuicException.h>
#include <quic/client/handshake/ClientTransportParametersExtension.h>
#include <quic/client/handshake/QuicPskCache.h>
#include <quic/handshake/HandshakeLayer.h>
#include <quic/state/StateData.h>
namespace quic {
class ClientHandshake : public Handshake {
public:
class HandshakeCallback {
public:
virtual ~HandshakeCallback() = default;
virtual void onNewCachedPsk(fizz::client::NewCachedPsk&) noexcept = 0;
};
enum class Phase { Initial, Handshake, OneRttKeysDerived, Established };
explicit ClientHandshake(QuicCryptoState& cryptoState);
/**
* Initiate the handshake with the supplied parameters.
*/
virtual void connect(
std::shared_ptr<const fizz::client::FizzClientContext> context,
std::shared_ptr<const fizz::CertificateVerifier> verifier,
folly::Optional<std::string> hostname,
folly::Optional<fizz::client::CachedPsk> cachedPsk,
const std::shared_ptr<ClientTransportParametersExtension>&
transportParams,
HandshakeCallback* callback);
/**
* Takes input bytes from the network and processes then in the handshake.
* This can change the state of the transport which may result in ciphers
* being initialized, bytes written out, or the write phase changing.
*/
virtual void doHandshake(
std::unique_ptr<folly::IOBuf> data,
fizz::EncryptionLevel encryptionLevel);
/**
* An edge triggered API to get the oneRttWriteCipher. Once you receive the
* write cipher subsequent calls will return null.
*/
std::unique_ptr<fizz::Aead> getOneRttWriteCipher();
/**
* An edge triggered API to get the oneRttReadCipher. Once you receive the
* read cipher subsequent calls will return null.
*/
std::unique_ptr<fizz::Aead> getOneRttReadCipher();
/**
* An edge triggered API to get the zeroRttWriteCipher. Once you receive the
* zero rtt write cipher subsequent calls will return null.
*/
std::unique_ptr<fizz::Aead> getZeroRttWriteCipher();
/**
* An edge triggered API to get the handshakeReadCipher. Once you
* receive the handshake read cipher subsequent calls will return null.
*/
std::unique_ptr<fizz::Aead> getHandshakeReadCipher();
/**
* An edge triggered API to get the handshakeWriteCipher. Once you
* receive the handshake write cipher subsequent calls will return null.
*/
std::unique_ptr<fizz::Aead> getHandshakeWriteCipher();
/**
* An edge triggered API to get the one rtt read header cpher. Once you
* receive the header cipher subsequent calls will return null.
*/
std::unique_ptr<PacketNumberCipher> getOneRttReadHeaderCipher();
/**
* An edge triggered API to get the one rtt write header cpher. Once you
* receive the header cipher subsequent calls will return null.
*/
std::unique_ptr<PacketNumberCipher> getOneRttWriteHeaderCipher();
/**
* An edge triggered API to get the handshake rtt read header cpher. Once you
* receive the header cipher subsequent calls will return null.
*/
std::unique_ptr<PacketNumberCipher> getHandshakeReadHeaderCipher();
/**
* An edge triggered API to get the handshake rtt write header cpher. Once you
* receive the header cipher subsequent calls will return null.
*/
std::unique_ptr<PacketNumberCipher> getHandshakeWriteHeaderCipher();
/**
* An edge triggered API to get the zero rtt write header cpher. Once you
* receive the header cipher subsequent calls will return null.
*/
std::unique_ptr<PacketNumberCipher> getZeroRttWriteHeaderCipher();
/**
* Notify the crypto layer that we received one rtt protected data.
* This allows us to know that the peer has implicitly acked the 1-rtt keys.
*/
void onRecvOneRttProtectedData();
Phase getPhase() const;
/**
* Was the TLS connection resumed or not.
*/
bool isTLSResumed() const;
/**
* Edge triggered api to obtain whether or not zero rtt data was rejected.
* If zero rtt was never attempted, then this will return folly::none. Once
* the result is obtained, the result is cleared out.
*/
folly::Optional<bool> getZeroRttRejected();
/**
* Returns the state of the TLS connection.
*/
const fizz::client::State& getState() const;
/**
* Returns the application protocol that was negotiated by the handshake.
*/
const folly::Optional<std::string>& getApplicationProtocol() const override;
/**
* Returns the negotiated transport parameters chosen by the server
*/
virtual folly::Optional<ServerTransportParameters> getServerTransportParams();
class ActionMoveVisitor : public boost::static_visitor<> {
public:
explicit ActionMoveVisitor(ClientHandshake& client);
void operator()(fizz::DeliverAppData&);
void operator()(fizz::WriteToSocket& write);
void operator()(fizz::client::ReportEarlyHandshakeSuccess&);
void operator()(fizz::client::ReportHandshakeSuccess& handshakeSuccess);
void operator()(fizz::client::ReportEarlyWriteFailed&);
void operator()(fizz::ReportError& err);
void operator()(fizz::WaitForData&);
void operator()(fizz::client::MutateState& mutator);
void operator()(fizz::client::NewCachedPsk& newCachedPsk);
void operator()(fizz::SecretAvailable& secret);
void operator()(fizz::EndOfData&);
private:
ClientHandshake& client_;
};
virtual ~ClientHandshake() = default;
protected:
// Represents the packet type that should be used to write the data currently
// in the stream.
Phase phase_{Phase::Initial};
std::unique_ptr<fizz::Aead> handshakeWriteCipher_;
std::unique_ptr<fizz::Aead> handshakeReadCipher_;
std::unique_ptr<fizz::Aead> oneRttReadCipher_;
std::unique_ptr<fizz::Aead> oneRttWriteCipher_;
std::unique_ptr<fizz::Aead> zeroRttWriteCipher_;
std::unique_ptr<PacketNumberCipher> oneRttReadHeaderCipher_;
std::unique_ptr<PacketNumberCipher> oneRttWriteHeaderCipher_;
std::unique_ptr<PacketNumberCipher> handshakeReadHeaderCipher_;
std::unique_ptr<PacketNumberCipher> handshakeWriteHeaderCipher_;
std::unique_ptr<PacketNumberCipher> zeroRttWriteHeaderCipher_;
folly::Optional<bool> zeroRttRejected_;
HandshakeCallback* callback_{nullptr};
QuicCryptoState& cryptoState_;
private:
void computeOneRttCipher(
const fizz::client::ReportHandshakeSuccess& handshakeSuccess);
void computeZeroRttCipher();
void processActions(fizz::client::Actions actions);
fizz::client::State state_;
fizz::client::ClientStateMachine machine_;
// Whether or not to wait for more data.
bool waitForData_{false};
folly::IOBufQueue initialReadBuf_{folly::IOBufQueue::cacheChainLength()};
folly::IOBufQueue handshakeReadBuf_{folly::IOBufQueue::cacheChainLength()};
folly::IOBufQueue appDataReadBuf_{folly::IOBufQueue::cacheChainLength()};
folly::exception_wrapper error_;
ActionMoveVisitor visitor_;
std::shared_ptr<const fizz::client::FizzClientContext> fizzClientContext_;
folly::Optional<std::string> pskIdentity_;
std::shared_ptr<ClientTransportParametersExtension> transportParams_;
bool earlyDataAttempted_{false};
};
} // namespace quic

View File

@ -0,0 +1,105 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#pragma once
#include <fizz/client/ClientExtensions.h>
#include <quic/handshake/TransportParameters.h>
namespace quic {
class ClientTransportParametersExtension : public fizz::ClientExtensions {
public:
ClientTransportParametersExtension(
QuicVersion initialVersion,
uint64_t initialMaxData,
uint64_t initialMaxStreamDataBidiLocal,
uint64_t initialMaxStreamDataBidiRemote,
uint64_t initialMaxStreamDataUni,
std::chrono::seconds idleTimeout,
uint64_t ackDelayExponent,
uint64_t maxRecvPacketSize,
std::vector<TransportParameter> customTransportParameters =
std::vector<TransportParameter>())
: initialVersion_(initialVersion),
initialMaxData_(initialMaxData),
initialMaxStreamDataBidiLocal_(initialMaxStreamDataBidiLocal),
initialMaxStreamDataBidiRemote_(initialMaxStreamDataBidiRemote),
initialMaxStreamDataUni_(initialMaxStreamDataUni),
idleTimeout_(idleTimeout),
ackDelayExponent_(ackDelayExponent),
maxRecvPacketSize_(maxRecvPacketSize),
customTransportParameters_(customTransportParameters) {}
~ClientTransportParametersExtension() override = default;
std::vector<fizz::Extension> getClientHelloExtensions() const override {
std::vector<fizz::Extension> exts;
ClientTransportParameters params;
params.initial_version = initialVersion_;
params.parameters.push_back(encodeIntegerParameter(
TransportParameterId::initial_max_stream_data_bidi_local,
initialMaxStreamDataBidiLocal_));
params.parameters.push_back(encodeIntegerParameter(
TransportParameterId::initial_max_stream_data_bidi_remote,
initialMaxStreamDataBidiRemote_));
params.parameters.push_back(encodeIntegerParameter(
TransportParameterId::initial_max_stream_data_uni,
initialMaxStreamDataUni_));
params.parameters.push_back(encodeIntegerParameter(
TransportParameterId::initial_max_data, initialMaxData_));
params.parameters.push_back(encodeIntegerParameter(
TransportParameterId::initial_max_streams_bidi,
std::numeric_limits<uint32_t>::max()));
params.parameters.push_back(encodeIntegerParameter(
TransportParameterId::initial_max_streams_uni,
std::numeric_limits<uint32_t>::max()));
params.parameters.push_back(encodeIntegerParameter(
TransportParameterId::idle_timeout, idleTimeout_.count()));
params.parameters.push_back(encodeIntegerParameter(
TransportParameterId::ack_delay_exponent, ackDelayExponent_));
params.parameters.push_back(encodeIntegerParameter(
TransportParameterId::max_packet_size, maxRecvPacketSize_));
for (const auto& customParameter : customTransportParameters_) {
params.parameters.push_back(customParameter);
}
exts.push_back(encodeExtension(params));
return exts;
}
void onEncryptedExtensions(
const std::vector<fizz::Extension>& exts) override {
auto serverParams = fizz::getExtension<ServerTransportParameters>(exts);
if (!serverParams) {
throw fizz::FizzException(
"missing server quic transport parameters extension",
fizz::AlertDescription::missing_extension);
}
serverTransportParameters_ = std::move(serverParams);
}
folly::Optional<ServerTransportParameters> getServerTransportParams() {
return std::move(serverTransportParameters_);
}
private:
QuicVersion initialVersion_;
uint64_t initialMaxData_;
uint64_t initialMaxStreamDataBidiLocal_;
uint64_t initialMaxStreamDataBidiRemote_;
uint64_t initialMaxStreamDataUni_;
std::chrono::seconds idleTimeout_;
uint64_t ackDelayExponent_;
uint64_t maxRecvPacketSize_;
folly::Optional<ServerTransportParameters> serverTransportParameters_;
std::vector<TransportParameter> customTransportParameters_;
};
} // namespace quic

View File

@ -0,0 +1,75 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#pragma once
#include <quic/QuicConstants.h>
#include <fizz/client/PskCache.h>
#include <folly/Optional.h>
#include <cstdint>
#include <string>
namespace quic {
struct CachedServerTransportParameters {
QuicVersion negotiatedVersion;
uint64_t initialMaxStreamDataBidiLocal;
uint64_t initialMaxStreamDataBidiRemote;
uint64_t initialMaxStreamDataUni;
uint64_t initialMaxData;
uint64_t idleTimeout;
uint64_t maxRecvPacketSize;
uint64_t ackDelayExponent;
};
struct QuicCachedPsk {
fizz::client::CachedPsk cachedPsk;
CachedServerTransportParameters transportParams;
std::string appParams;
};
class QuicPskCache {
public:
virtual ~QuicPskCache() = default;
virtual folly::Optional<QuicCachedPsk> getPsk(const std::string&) = 0;
virtual void putPsk(const std::string&, QuicCachedPsk) = 0;
virtual void removePsk(const std::string&) = 0;
};
/**
* Basic PSK cache that stores PSKs in a hash map. There is no bound on the size
* of this cache.
*/
class BasicQuicPskCache : public QuicPskCache {
public:
~BasicQuicPskCache() override = default;
folly::Optional<QuicCachedPsk> getPsk(const std::string& identity) override {
auto result = cache_.find(identity);
if (result != cache_.end()) {
return result->second;
}
return folly::none;
}
void putPsk(const std::string& identity, QuicCachedPsk psk) override {
cache_[identity] = std::move(psk);
}
void removePsk(const std::string& identity) override {
cache_.erase(identity);
}
private:
std::unordered_map<std::string, QuicCachedPsk> cache_;
};
} // namespace quic

View File

@ -0,0 +1,54 @@
# @autodeps
load("@fbcode_macros//build_defs:cpp_library.bzl", "cpp_library")
cpp_library(
name = "client_handshake",
srcs = [
"ClientHandshake.cpp",
],
headers = [
"ClientHandshake.h",
],
deps = [
":client_extension",
":psk_cache",
"//fizz/client:early_data_rejection",
"//fizz/client:fizz_client_context",
"//fizz/client:protocol",
"//fizz/client:psk_cache",
"//fizz/protocol:default_certificate_verifier",
"//fizz/protocol:protocol",
"//folly:exception_wrapper",
"//folly/io:iobuf",
"//folly/io/async:delayed_destruction",
"//quic:constants",
"//quic:exception",
"//quic/handshake:handshake",
"//quic/state:state_machine",
"//quic/state:stream_functions",
],
)
cpp_library(
name = "client_extension",
headers = [
"ClientTransportParametersExtension.h",
],
deps = [
"//fizz/client:client_extensions",
"//quic/handshake:transport_parameters",
],
)
cpp_library(
name = "psk_cache",
headers = [
"QuicPskCache.h",
],
deps = [
"//fizz/client:psk_cache",
"//folly:optional",
"//quic:constants",
],
)

View File

@ -0,0 +1,522 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <condition_variable>
#include <mutex>
#include <fizz/crypto/test/TestUtil.h>
#include <fizz/protocol/test/Mocks.h>
#include <fizz/server/Actions.h>
#include <fizz/server/test/Mocks.h>
#include <folly/io/async/SSLContext.h>
#include <folly/io/async/ScopedEventBaseThread.h>
#include <folly/io/async/test/MockAsyncTransport.h>
#include <folly/ssl/Init.h>
#include <quic/client/handshake/ClientHandshake.h>
#include <quic/client/handshake/test/MockQuicPskCache.h>
#include <quic/common/test/TestUtils.h>
#include <quic/state/QuicStreamFunctions.h>
#include <quic/state/StateData.h>
using namespace std;
using namespace quic;
using namespace folly;
using namespace folly::test;
using namespace folly::ssl;
using namespace testing;
namespace quic {
namespace test {
class ClientHandshakeTest : public Test, public boost::static_visitor<> {
public:
~ClientHandshakeTest() override = default;
ClientHandshakeTest() {}
virtual void setupClientAndServerContext() {
clientCtx = std::make_shared<fizz::client::FizzClientContext>();
}
QuicVersion getVersion() {
return QuicVersion::QUIC_DRAFT;
}
virtual void connect() {
handshake->connect(
clientCtx,
verifier,
hostname,
folly::none,
std::make_shared<ClientTransportParametersExtension>(
getVersion(),
folly::to<uint32_t>(kDefaultConnectionWindowSize),
folly::to<uint32_t>(kDefaultStreamWindowSize),
folly::to<uint32_t>(kDefaultStreamWindowSize),
folly::to<uint32_t>(kDefaultStreamWindowSize),
kDefaultIdleTimeout,
kDefaultAckDelayExponent,
kDefaultUDPSendPacketLen),
nullptr);
}
void SetUp() override {
folly::ssl::init();
dg.reset(new DelayedHolder());
serverCtx = ::quic::test::createServerCtx();
serverCtx->setOmitEarlyRecordLayer(true);
// Fizz is the name of the identity for our server certificate.
hostname = "Fizz";
setupClientAndServerContext();
verifier = std::make_shared<fizz::test::MockCertificateVerifier>();
handshake.reset(new ClientHandshake(cryptoState));
std::vector<QuicVersion> supportedVersions = {getVersion()};
auto serverTransportParameters =
std::make_shared<ServerTransportParametersExtension>(
getVersion(),
supportedVersions,
folly::to<uint32_t>(kDefaultConnectionWindowSize),
folly::to<uint32_t>(kDefaultStreamWindowSize),
folly::to<uint32_t>(kDefaultStreamWindowSize),
folly::to<uint32_t>(kDefaultStreamWindowSize),
kDefaultIdleTimeout,
kDefaultAckDelayExponent,
kDefaultUDPSendPacketLen,
kDefaultPartialReliability);
fizzServer.reset(
new fizz::server::
FizzServer<ClientHandshakeTest, fizz::server::ServerStateMachine>(
serverState, serverReadBuf, *this, dg.get()));
connect();
processHandshake();
fizzServer->accept(&evb, serverCtx, serverTransportParameters);
}
void clientServerRound() {
auto writableBytes = getHandshakeWriteBytes();
serverReadBuf.append(std::move(writableBytes));
fizzServer->newTransportData();
evb.loop();
}
void serverClientRound() {
evb.loop();
for (auto& write : serverOutput) {
for (auto& content : write.contents) {
handshake->doHandshake(
std::move(content.data), content.encryptionLevel);
}
}
processHandshake();
}
void processHandshake() {
auto oneRttWriteCipherTmp = handshake->getOneRttWriteCipher();
auto oneRttReadCipherTmp = handshake->getOneRttReadCipher();
auto zeroRttWriteCipherTmp = handshake->getZeroRttWriteCipher();
auto handshakeWriteCipherTmp = handshake->getHandshakeWriteCipher();
auto handshakeReadCipherTmp = handshake->getHandshakeReadCipher();
if (oneRttWriteCipherTmp) {
oneRttWriteCipher = std::move(oneRttWriteCipherTmp);
}
if (oneRttReadCipherTmp) {
oneRttReadCipher = std::move(oneRttReadCipherTmp);
}
if (zeroRttWriteCipherTmp) {
zeroRttWriteCipher = std::move(zeroRttWriteCipherTmp);
}
if (handshakeWriteCipherTmp) {
handshakeWriteCipher = std::move(handshakeWriteCipherTmp);
}
if (handshakeReadCipherTmp) {
handshakeReadCipher = std::move(handshakeReadCipherTmp);
}
auto rejected = handshake->getZeroRttRejected();
if (rejected) {
zeroRttRejected = std::move(rejected);
}
}
void expectHandshakeCipher(bool expected) {
EXPECT_EQ(handshakeReadCipher != nullptr, expected);
EXPECT_EQ(handshakeWriteCipher != nullptr, expected);
}
void expectOneRttCipher(bool expected, bool oneRttOnly = false) {
if (expected) {
EXPECT_NE(oneRttReadCipher.get(), nullptr);
EXPECT_NE(oneRttWriteCipher.get(), nullptr);
} else {
EXPECT_EQ(oneRttReadCipher.get(), nullptr);
EXPECT_EQ(oneRttWriteCipher.get(), nullptr);
}
if (!oneRttOnly) {
EXPECT_EQ(zeroRttWriteCipher.get(), nullptr);
}
}
void expectZeroRttCipher(bool expected, bool expectOneRtt) {
if (expected) {
EXPECT_NE(zeroRttWriteCipher.get(), nullptr);
} else {
EXPECT_EQ(zeroRttWriteCipher.get(), nullptr);
}
expectOneRttCipher(expectOneRtt, true);
}
Buf getHandshakeWriteBytes() {
auto buf = folly::IOBuf::create(0);
if (!cryptoState.initialStream.writeBuffer.empty()) {
buf->prependChain(cryptoState.initialStream.writeBuffer.move());
}
if (!cryptoState.handshakeStream.writeBuffer.empty()) {
buf->prependChain(cryptoState.handshakeStream.writeBuffer.move());
}
if (!cryptoState.oneRttStream.writeBuffer.empty()) {
buf->prependChain(cryptoState.oneRttStream.writeBuffer.move());
}
return buf;
}
void operator()(fizz::DeliverAppData&) {
// do nothing here.
}
void operator()(fizz::WriteToSocket& write) {
serverOutput.push_back(std::move(write));
}
void operator()(fizz::server::ReportEarlyHandshakeSuccess&) {
earlyHandshakeSuccess = true;
}
void operator()(fizz::server::ReportHandshakeSuccess&) {
handshakeSuccess = true;
}
void operator()(fizz::ReportError& error) {
handshakeError = std::move(error);
}
void operator()(fizz::WaitForData&) {
fizzServer->waitForData();
}
void operator()(fizz::server::MutateState& mutator) {
mutator(serverState);
}
void operator()(fizz::server::AttemptVersionFallback&) {}
void operator()(fizz::SecretAvailable&) {}
void operator()(fizz::EndOfData&) {}
class DelayedHolder : public folly::DelayedDestruction {};
folly::EventBase evb;
std::unique_ptr<ClientHandshake> handshake;
QuicCryptoState cryptoState;
std::string hostname;
fizz::server::ServerStateMachine machine;
fizz::server::State serverState;
std::unique_ptr<fizz::server::FizzServer<
ClientHandshakeTest,
fizz::server::ServerStateMachine>>
fizzServer;
std::vector<fizz::WriteToSocket> serverOutput;
bool handshakeSuccess{false};
bool earlyHandshakeSuccess{false};
folly::Optional<fizz::ReportError> handshakeError;
folly::IOBufQueue serverReadBuf{folly::IOBufQueue::cacheChainLength()};
std::unique_ptr<DelayedHolder, folly::DelayedDestruction::Destructor> dg;
std::unique_ptr<fizz::Aead> handshakeWriteCipher;
std::unique_ptr<fizz::Aead> handshakeReadCipher;
std::unique_ptr<fizz::Aead> oneRttWriteCipher;
std::unique_ptr<fizz::Aead> oneRttReadCipher;
std::unique_ptr<fizz::Aead> zeroRttWriteCipher;
folly::Optional<bool> zeroRttRejected;
std::shared_ptr<fizz::test::MockCertificateVerifier> verifier;
std::shared_ptr<fizz::client::FizzClientContext> clientCtx;
std::shared_ptr<fizz::server::FizzServerContext> serverCtx;
};
TEST_F(ClientHandshakeTest, TestHandshakeSuccess) {
EXPECT_CALL(*verifier, verify(_));
clientServerRound();
EXPECT_EQ(handshake->getPhase(), ClientHandshake::Phase::Initial);
expectHandshakeCipher(false);
serverClientRound();
expectHandshakeCipher(true);
EXPECT_FALSE(zeroRttRejected.hasValue());
EXPECT_EQ(handshake->getPhase(), ClientHandshake::Phase::OneRttKeysDerived);
clientServerRound();
expectOneRttCipher(true);
EXPECT_EQ(handshake->getPhase(), ClientHandshake::Phase::OneRttKeysDerived);
handshake->onRecvOneRttProtectedData();
EXPECT_EQ(handshake->getPhase(), ClientHandshake::Phase::Established);
EXPECT_FALSE(zeroRttRejected.hasValue());
EXPECT_TRUE(handshakeSuccess);
}
TEST_F(ClientHandshakeTest, TestNoErrorAfterAppClose) {
EXPECT_CALL(*verifier, verify(_));
clientServerRound();
serverClientRound();
clientServerRound();
fizzServer->appClose();
evb.loop();
// RTT 1/2 server -> client
EXPECT_NO_THROW(serverClientRound());
expectOneRttCipher(true);
EXPECT_FALSE(zeroRttRejected.hasValue());
EXPECT_TRUE(handshakeSuccess);
}
TEST_F(ClientHandshakeTest, TestAppBytesInterpretedAsHandshake) {
EXPECT_CALL(*verifier, verify(_));
clientServerRound();
serverClientRound();
clientServerRound();
fizz::AppWrite w;
w.data = IOBuf::copyBuffer("hey");
fizzServer->appWrite(std::move(w));
evb.loop();
// RTT 1/2 server -> client
serverClientRound();
expectOneRttCipher(true);
EXPECT_FALSE(zeroRttRejected.hasValue());
EXPECT_TRUE(handshakeSuccess);
}
class MockClientHandshakeCallback : public ClientHandshake::HandshakeCallback {
public:
GMOCK_METHOD1_(
,
noexcept,
,
onNewCachedPsk,
void(fizz::client::NewCachedPsk&));
};
class ClientHandshakeCallbackTest : public ClientHandshakeTest {
public:
void setupClientAndServerContext() override {
clientCtx = std::make_shared<fizz::client::FizzClientContext>();
clientCtx->setSupportedVersions({fizz::ProtocolVersion::tls_1_3});
serverCtx->setSupportedVersions({fizz::ProtocolVersion::tls_1_3});
setupZeroRttOnServerCtx(*serverCtx, psk_);
conn_.version = getVersion();
}
void connect() override {
handshake->connect(
clientCtx,
verifier,
hostname,
folly::none,
std::make_shared<ClientTransportParametersExtension>(
getVersion(),
folly::to<uint32_t>(kDefaultConnectionWindowSize),
folly::to<uint32_t>(kDefaultStreamWindowSize),
folly::to<uint32_t>(kDefaultStreamWindowSize),
folly::to<uint32_t>(kDefaultStreamWindowSize),
kDefaultIdleTimeout,
kDefaultAckDelayExponent,
kDefaultUDPSendPacketLen),
&mockClientHandshakeCallback_);
}
protected:
QuicCachedPsk psk_;
QuicConnectionStateBase conn_{QuicNodeType::Client};
MockClientHandshakeCallback mockClientHandshakeCallback_;
};
TEST_F(ClientHandshakeCallbackTest, TestHandshakeSuccess) {
clientServerRound();
serverClientRound();
clientServerRound();
EXPECT_CALL(mockClientHandshakeCallback_, onNewCachedPsk(_));
serverClientRound();
}
class ClientHandshakeHRRTest : public ClientHandshakeTest {
public:
~ClientHandshakeHRRTest() override = default;
void setupClientAndServerContext() override {
clientCtx = std::make_shared<fizz::client::FizzClientContext>();
clientCtx->setSupportedGroups(
{fizz::NamedGroup::secp256r1, fizz::NamedGroup::x25519});
clientCtx->setDefaultShares({fizz::NamedGroup::secp256r1});
serverCtx = std::make_shared<fizz::server::FizzServerContext>();
serverCtx->setFactory(std::make_shared<QuicFizzFactory>());
serverCtx->setSupportedGroups({fizz::NamedGroup::x25519});
setupCtxWithTestCert(*serverCtx);
}
};
TEST_F(ClientHandshakeHRRTest, TestFullHRR) {
EXPECT_CALL(*verifier, verify(_));
clientServerRound();
expectHandshakeCipher(false);
EXPECT_EQ(handshake->getPhase(), ClientHandshake::Phase::Initial);
serverClientRound();
EXPECT_EQ(handshake->getPhase(), ClientHandshake::Phase::Handshake);
clientServerRound();
expectOneRttCipher(false);
EXPECT_EQ(handshake->getPhase(), ClientHandshake::Phase::Handshake);
serverClientRound();
expectHandshakeCipher(true);
EXPECT_EQ(handshake->getPhase(), ClientHandshake::Phase::OneRttKeysDerived);
clientServerRound();
expectOneRttCipher(true);
EXPECT_EQ(handshake->getPhase(), ClientHandshake::Phase::OneRttKeysDerived);
EXPECT_FALSE(zeroRttRejected.hasValue());
EXPECT_TRUE(handshakeSuccess);
}
TEST_F(ClientHandshakeHRRTest, TestHRROnlyOneRound) {
EXPECT_CALL(*verifier, verify(_)).Times(0);
clientServerRound();
serverClientRound();
clientServerRound();
expectOneRttCipher(false);
EXPECT_FALSE(handshakeSuccess);
}
class ClientHandshakeZeroRttTest : public ClientHandshakeTest {
public:
~ClientHandshakeZeroRttTest() override = default;
void setupClientAndServerContext() override {
clientCtx = std::make_shared<fizz::client::FizzClientContext>();
clientCtx->setSupportedVersions({fizz::ProtocolVersion::tls_1_3});
clientCtx->setSupportedAlpns({"h1q-fb", "hq"});
serverCtx->setSupportedVersions({fizz::ProtocolVersion::tls_1_3});
serverCtx->setSupportedAlpns({"h1q-fb"});
setupCtxWithTestCert(*serverCtx);
psk = setupZeroRttOnClientCtx(*clientCtx, hostname, QuicVersion::MVFST);
setupZeroRttServer();
}
void connect() override {
handshake->connect(
clientCtx,
verifier,
hostname,
psk.cachedPsk,
std::make_shared<ClientTransportParametersExtension>(
getVersion(),
folly::to<uint32_t>(kDefaultConnectionWindowSize),
folly::to<uint32_t>(kDefaultStreamWindowSize),
folly::to<uint32_t>(kDefaultStreamWindowSize),
folly::to<uint32_t>(kDefaultStreamWindowSize),
kDefaultIdleTimeout,
kDefaultAckDelayExponent,
kDefaultUDPSendPacketLen),
nullptr);
}
virtual void setupZeroRttServer() {
setupZeroRttOnServerCtx(*serverCtx, psk);
}
QuicCachedPsk psk;
};
TEST_F(ClientHandshakeZeroRttTest, TestZeroRttSuccess) {
clientServerRound();
EXPECT_EQ(handshake->getPhase(), ClientHandshake::Phase::Initial);
expectZeroRttCipher(true, false);
expectHandshakeCipher(false);
serverClientRound();
expectHandshakeCipher(true);
EXPECT_EQ(handshake->getPhase(), ClientHandshake::Phase::OneRttKeysDerived);
EXPECT_FALSE(zeroRttRejected.hasValue());
expectZeroRttCipher(true, true);
clientServerRound();
handshake->onRecvOneRttProtectedData();
EXPECT_EQ(handshake->getPhase(), ClientHandshake::Phase::Established);
EXPECT_EQ(handshake->getApplicationProtocol(), "h1q-fb");
}
class ClientHandshakeZeroRttReject : public ClientHandshakeZeroRttTest {
public:
~ClientHandshakeZeroRttReject() override = default;
void setupZeroRttServer() override {}
};
TEST_F(ClientHandshakeZeroRttReject, TestZeroRttRejection) {
clientServerRound();
EXPECT_EQ(handshake->getPhase(), ClientHandshake::Phase::Initial);
expectZeroRttCipher(true, false);
expectHandshakeCipher(false);
serverClientRound();
expectHandshakeCipher(true);
EXPECT_EQ(handshake->getPhase(), ClientHandshake::Phase::OneRttKeysDerived);
EXPECT_TRUE(zeroRttRejected.value_or(false));
// We will still keep the zero rtt key lying around.
expectZeroRttCipher(true, true);
clientServerRound();
handshake->onRecvOneRttProtectedData();
EXPECT_EQ(handshake->getPhase(), ClientHandshake::Phase::Established);
}
class ClientHandshakeZeroRttRejectFail : public ClientHandshakeZeroRttTest {
public:
~ClientHandshakeZeroRttRejectFail() override = default;
void setupClientAndServerContext() override {
// set it up so that the identity will not match.
hostname = "foobar";
ClientHandshakeZeroRttTest::setupClientAndServerContext();
}
void setupZeroRttServer() override {}
};
TEST_F(ClientHandshakeZeroRttRejectFail, TestZeroRttRejectionParamsDontMatch) {
clientServerRound();
EXPECT_EQ(handshake->getPhase(), ClientHandshake::Phase::Initial);
expectHandshakeCipher(false);
expectZeroRttCipher(true, false);
EXPECT_THROW(serverClientRound(), QuicInternalException);
}
} // namespace test
} // namespace quic

View File

@ -0,0 +1,155 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#include <gtest/gtest.h>
#include <quic/client/handshake/ClientTransportParametersExtension.h>
#include <quic/common/test/TestUtils.h>
#include <fizz/protocol/test/TestMessages.h>
using namespace fizz;
using namespace fizz::test;
namespace quic {
namespace test {
static EncryptedExtensions getEncryptedExtensions() {
auto ee = TestMessages::encryptedExt();
ServerTransportParameters serverParams;
serverParams.supported_versions = {QuicVersion::MVFST};
ee.extensions.push_back(encodeExtension(std::move(serverParams)));
return ee;
}
TEST(ClientTransportParametersTest, TestGetChloExtensions) {
ClientTransportParametersExtension ext(
MVFST1,
kDefaultConnectionWindowSize,
kDefaultStreamWindowSize,
kDefaultStreamWindowSize,
kDefaultStreamWindowSize,
kDefaultIdleTimeout,
kDefaultAckDelayExponent,
kDefaultUDPSendPacketLen);
auto extensions = ext.getClientHelloExtensions();
EXPECT_EQ(extensions.size(), 1);
auto serverParams = getExtension<ClientTransportParameters>(extensions);
EXPECT_TRUE(serverParams.hasValue());
}
TEST(ClientTransportParametersTest, TestOnEE) {
ClientTransportParametersExtension ext(
MVFST1,
kDefaultConnectionWindowSize,
kDefaultStreamWindowSize,
kDefaultStreamWindowSize,
kDefaultStreamWindowSize,
kDefaultIdleTimeout,
kDefaultAckDelayExponent,
kDefaultUDPSendPacketLen);
ext.getClientHelloExtensions();
ext.onEncryptedExtensions(getEncryptedExtensions().extensions);
}
TEST(ClientTransportParametersTest, TestOnEEMissingServerParams) {
ClientTransportParametersExtension ext(
MVFST1,
kDefaultConnectionWindowSize,
kDefaultStreamWindowSize,
kDefaultStreamWindowSize,
kDefaultStreamWindowSize,
kDefaultIdleTimeout,
kDefaultAckDelayExponent,
kDefaultUDPSendPacketLen);
ext.getClientHelloExtensions();
EXPECT_THROW(
ext.onEncryptedExtensions(TestMessages::encryptedExt().extensions),
FizzException);
}
TEST(ClientTransportParametersTest, TestGetChloExtensionsCustomParams) {
std::vector<TransportParameter> customTransportParameters;
std::string randomBytes = "\x01\x00\x55\x12\xff";
std::unique_ptr<CustomTransportParameter> element1 =
std::make_unique<CustomIntegralTransportParameter>(0x4000, 12);
std::unique_ptr<CustomTransportParameter> element2 =
std::make_unique<CustomStringTransportParameter>(0x4001, "abc");
std::unique_ptr<CustomTransportParameter> element3 =
std::make_unique<CustomBlobTransportParameter>(
0x4002, folly::IOBuf::copyBuffer(randomBytes));
customTransportParameters.push_back(element1->encode());
customTransportParameters.push_back(element2->encode());
customTransportParameters.push_back(element3->encode());
ClientTransportParametersExtension ext(
MVFST1,
kDefaultConnectionWindowSize,
kDefaultStreamWindowSize,
kDefaultStreamWindowSize,
kDefaultStreamWindowSize,
kDefaultIdleTimeout,
kDefaultAckDelayExponent,
kDefaultUDPSendPacketLen,
customTransportParameters);
auto extensions = ext.getClientHelloExtensions();
EXPECT_EQ(extensions.size(), 1);
auto serverParams = getExtension<ClientTransportParameters>(extensions);
EXPECT_TRUE(serverParams.hasValue());
// check to see that the custom parameters are present
auto it1 = std::find_if(
serverParams->parameters.begin(),
serverParams->parameters.end(),
[](const TransportParameter& param) {
return static_cast<uint16_t>(param.parameter) == 0x4000;
});
EXPECT_NE(it1, serverParams->parameters.end());
auto it2 = std::find_if(
serverParams->parameters.begin(),
serverParams->parameters.end(),
[](const TransportParameter& param) {
return static_cast<uint16_t>(param.parameter) == 0x4001;
});
EXPECT_NE(it2, serverParams->parameters.end());
auto it3 = std::find_if(
serverParams->parameters.begin(),
serverParams->parameters.end(),
[](const TransportParameter& param) {
return static_cast<uint16_t>(param.parameter) == 0x4002;
});
EXPECT_NE(it3, serverParams->parameters.end());
// check that the values equal what we expect
folly::IOBufEqualTo eq;
folly::io::Cursor cursor1 = folly::io::Cursor(it1->value.get());
auto val = decodeQuicInteger(cursor1);
EXPECT_EQ(val->first, 12);
EXPECT_TRUE(eq(folly::IOBuf::copyBuffer("abc"), it2->value));
EXPECT_TRUE(eq(folly::IOBuf::copyBuffer(randomBytes), it3->value));
}
}
}

View File

@ -0,0 +1,25 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#pragma once
#include <quic/client/handshake/QuicPskCache.h>
#include <folly/Optional.h>
#include <folly/portability/GMock.h>
#include <string>
namespace quic {
class MockQuicPskCache : public QuicPskCache {
public:
MOCK_METHOD1(getPsk, folly::Optional<QuicCachedPsk>(const std::string&));
MOCK_METHOD2(putPsk, void(const std::string&, QuicCachedPsk));
MOCK_METHOD1(removePsk, void(const std::string&));
};
} // namespace quic

View File

@ -0,0 +1,53 @@
# @autodeps
load("@fbcode_macros//build_defs:cpp_library.bzl", "cpp_library")
load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest")
cpp_unittest(
name = "ClientHandshakeTest",
srcs = [
"ClientHandshakeTest.cpp",
],
deps = [
":mock_psk_cache",
"//fizz/crypto/test:TestUtil",
"//fizz/protocol/test:mocks",
"//fizz/server:protocol",
"//fizz/server/test:mocks",
"//folly/io/async:scoped_event_base_thread",
"//folly/io/async:ssl_context",
"//folly/io/async/test:mocks",
"//folly/ssl:init",
"//quic/client/handshake:client_handshake",
"//quic/common/test:test_utils",
"//quic/state:state_machine",
"//quic/state:stream_functions",
],
external_deps = [
("googletest", None, "gmock"),
],
)
cpp_unittest(
name = "ClientTransportParametersTest",
srcs = [
"ClientTransportParametersTest.cpp",
],
deps = [
"//fizz/protocol/test:test_messages",
"//quic/client/handshake:client_extension",
"//quic/common/test:test_utils",
],
)
cpp_library(
name = "mock_psk_cache",
headers = [
"MockQuicPskCache.h",
],
deps = [
"//folly:optional",
"//folly/portability:gmock",
"//quic/client/handshake:psk_cache",
],
)

View File

@ -0,0 +1,174 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#include <quic/client/state/ClientStateMachine.h>
#include <folly/io/async/AsyncSocketException.h>
#include <quic/client/handshake/ClientHandshake.h>
#include <quic/congestion_control/QuicCubic.h>
#include <quic/flowcontrol/QuicFlowController.h>
#include <quic/handshake/TransportParameters.h>
#include <quic/state/QuicStateFunctions.h>
#include <quic/state/StateData.h>
#include <quic/state/StateMachine.h>
namespace quic {
std::unique_ptr<QuicClientConnectionState> undoAllClientStateForVersionMismatch(
std::unique_ptr<QuicClientConnectionState> conn,
QuicVersion /* negotiatedVersion */) {
// Create a new connection state and copy over properties that don't change
// across version negotiation.
auto newConn = std::make_unique<QuicClientConnectionState>();
newConn->clientConnectionId = conn->clientConnectionId;
newConn->initialDestinationConnectionId =
conn->initialDestinationConnectionId;
// TODO: don't carry server connection id over to the new connection.
newConn->serverConnectionId = conn->serverConnectionId;
newConn->ackStates.initialAckState.nextPacketNum =
conn->ackStates.initialAckState.nextPacketNum;
newConn->ackStates.handshakeAckState.nextPacketNum =
conn->ackStates.handshakeAckState.nextPacketNum;
newConn->ackStates.appDataAckState.nextPacketNum =
conn->ackStates.appDataAckState.nextPacketNum;
newConn->version = conn->version;
newConn->originalVersion = conn->originalVersion;
newConn->originalPeerAddress = conn->originalPeerAddress;
newConn->peerAddress = conn->peerAddress;
newConn->udpSendPacketLen = conn->udpSendPacketLen;
newConn->supportedVersions = conn->supportedVersions;
newConn->transportSettings = conn->transportSettings;
newConn->initialWriteCipher = std::move(conn->initialWriteCipher);
newConn->versionNegotiationNeeded = true;
newConn->readCodec = std::make_unique<QuicReadCodec>(QuicNodeType::Client);
newConn->readCodec->setClientConnectionId(*conn->clientConnectionId);
newConn->readCodec->setCodecParameters(
CodecParameters(conn->peerAckDelayExponent));
return newConn;
}
void processServerInitialParams(
QuicClientConnectionState& conn,
ServerTransportParameters serverParams,
PacketNum packetNum) {
auto maxData = getIntegerParameter(
TransportParameterId::initial_max_data, serverParams.parameters);
auto maxStreamDataBidiLocal = getIntegerParameter(
TransportParameterId::initial_max_stream_data_bidi_local,
serverParams.parameters);
auto maxStreamDataBidiRemote = getIntegerParameter(
TransportParameterId::initial_max_stream_data_bidi_remote,
serverParams.parameters);
auto maxStreamDataUni = getIntegerParameter(
TransportParameterId::initial_max_stream_data_uni,
serverParams.parameters);
auto idleTimeout = getIntegerParameter(
TransportParameterId::idle_timeout, serverParams.parameters);
auto maxStreamsBidi = getIntegerParameter(
TransportParameterId::initial_max_streams_bidi, serverParams.parameters);
auto maxStreamsUni = getIntegerParameter(
TransportParameterId::initial_max_streams_uni, serverParams.parameters);
auto ackDelayExponent = getIntegerParameter(
TransportParameterId::ack_delay_exponent, serverParams.parameters);
auto packetSize = getIntegerParameter(
TransportParameterId::max_packet_size, serverParams.parameters);
auto statelessResetToken =
getStatelessResetTokenParameter(serverParams.parameters);
auto partialReliability = getIntegerParameter(
static_cast<TransportParameterId>(kPartialReliabilityParameterId),
serverParams.parameters);
if (!packetSize || *packetSize == 0) {
packetSize = kDefaultMaxUDPPayload;
}
if (*packetSize < kMinMaxUDPPayload) {
throw QuicTransportException(
folly::to<std::string>(
"Max packet size too small. received max_packetSize = ",
*packetSize),
TransportErrorCode::TRANSPORT_PARAMETER_ERROR);
}
VLOG(10) << "Client advertised flow control ";
VLOG(10) << "conn=" << maxData.value_or(0);
VLOG(10) << " stream bidi local=" << maxStreamDataBidiLocal.value_or(0)
<< " ";
VLOG(10) << " stream bidi remote=" << maxStreamDataBidiRemote.value_or(0)
<< " ";
VLOG(10) << " stream uni=" << maxStreamDataUni.value_or(0) << " ";
VLOG(10) << conn;
conn.flowControlState.peerAdvertisedMaxOffset = maxData.value_or(0);
conn.flowControlState.peerAdvertisedInitialMaxStreamOffsetBidiLocal =
maxStreamDataBidiLocal.value_or(0);
conn.flowControlState.peerAdvertisedInitialMaxStreamOffsetBidiRemote =
maxStreamDataBidiRemote.value_or(0);
conn.flowControlState.peerAdvertisedInitialMaxStreamOffsetUni =
maxStreamDataUni.value_or(0);
// TODO Make idleTimeout disableable via transport parameter.
conn.streamManager->setMaxLocalBidirectionalStreams(
maxStreamsBidi.value_or(0));
conn.streamManager->setMaxLocalUnidirectionalStreams(
maxStreamsUni.value_or(0));
conn.peerIdleTimeout = std::chrono::seconds(idleTimeout.value_or(0));
if (ackDelayExponent && *ackDelayExponent > kMaxAckDelayExponent) {
throw QuicTransportException(
"ack_delay_exponent too large",
TransportErrorCode::TRANSPORT_PARAMETER_ERROR);
}
conn.peerAckDelayExponent =
ackDelayExponent.value_or(kDefaultAckDelayExponent);
// TODO: udpSendPacketLen should also be limited by PMTU
if (conn.transportSettings.canIgnorePathMTU) {
conn.udpSendPacketLen = *packetSize;
}
if (partialReliability && *partialReliability != 0 &&
conn.transportSettings.partialReliabilityEnabled) {
conn.partialReliabilityEnabled = true;
}
VLOG(10) << "conn.partialReliabilityEnabled="
<< conn.partialReliabilityEnabled;
conn.statelessResetToken = std::move(statelessResetToken);
// Update the existing streams, because we allow streams to be created before
// the connection is established.
conn.streamManager->streamStateForEach([&conn,
&packetNum](QuicStreamState& s) {
auto windowSize = isUnidirectionalStream(s.id)
? conn.transportSettings.advertisedInitialUniStreamWindowSize
: isLocalStream(conn.nodeType, s.id)
? conn.transportSettings.advertisedInitialBidiLocalStreamWindowSize
: conn.transportSettings
.advertisedInitialBidiRemoteStreamWindowSize;
handleStreamWindowUpdate(s, windowSize, packetNum);
});
}
void updateTransportParamsFromCachedEarlyParams(
QuicClientConnectionState& conn,
const CachedServerTransportParameters& transportParams) {
conn.flowControlState.peerAdvertisedInitialMaxStreamOffsetBidiLocal =
transportParams.initialMaxStreamDataBidiLocal;
conn.flowControlState.peerAdvertisedInitialMaxStreamOffsetBidiRemote =
transportParams.initialMaxStreamDataBidiRemote;
conn.flowControlState.peerAdvertisedInitialMaxStreamOffsetUni =
transportParams.initialMaxStreamDataUni;
conn.flowControlState.peerAdvertisedMaxOffset =
transportParams.initialMaxData;
conn.peerIdleTimeout = std::chrono::seconds(transportParams.idleTimeout);
if (conn.transportSettings.canIgnorePathMTU) {
conn.udpSendPacketLen = transportParams.maxRecvPacketSize;
}
conn.peerAckDelayExponent = transportParams.ackDelayExponent;
}
void ClientInvalidStateHandler(QuicClientConnectionState& state) {
state.state = ClientStates::Error();
}
} // namespace quic

View File

@ -0,0 +1,102 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#pragma once
#include <folly/io/async/AsyncSocketException.h>
#include <quic/client/handshake/ClientHandshake.h>
#include <quic/congestion_control/QuicCubic.h>
#include <quic/flowcontrol/QuicFlowController.h>
#include <quic/handshake/TransportParameters.h>
#include <quic/state/QuicStateFunctions.h>
#include <quic/state/StateData.h>
#include <quic/state/StateMachine.h>
namespace quic {
struct ClientStates {
struct Handshaking {};
struct Error {};
};
struct ClientEvents {
struct ReadData {
folly::SocketAddress peer;
Buf buf;
folly::Optional<folly::AsyncSocketException> error;
};
};
using ClientState =
boost::variant<ClientStates::Handshaking, ClientStates::Error>;
struct QuicClientConnectionState : public QuicConnectionStateBase {
~QuicClientConnectionState() override = default;
ClientState state;
// Whether version negotiation was done. We might need to error out
// all the callbacks as a result.
bool versionNegotiationNeeded{false};
// The stateless reset token sent by the server.
folly::Optional<StatelessResetToken> statelessResetToken;
// Initial destination connection id.
folly::Optional<ConnectionId> initialDestinationConnectionId;
ClientHandshake* clientHandshakeLayer;
// Packet number in which client initial was sent. Receipt of data on the
// crypto stream from the server can implicitly ack the client initial packet.
// TODO: use this to get rid of the data in the crypto stream.
// folly::Optional<PacketNum> clientInitialPacketNum;
QuicClientConnectionState() : QuicConnectionStateBase(QuicNodeType::Client) {
state = ClientStates::Handshaking();
cryptoState = std::make_unique<QuicCryptoState>();
congestionController = std::make_unique<Cubic>(*this);
// TODO: this is wrong, it should be the handshake finish time. But i need
// a relatively sane time now to make the timestamps all sane.
connectionTime = Clock::now();
supportedVersions = {QuicVersion::MVFST, QuicVersion::QUIC_DRAFT};
originalVersion = QuicVersion::MVFST;
clientHandshakeLayer = new ClientHandshake(*cryptoState);
handshakeLayer.reset(clientHandshakeLayer);
// We shouldn't normally need to set this until we're starting the
// transport, however writing unit tests is much easier if we set this here.
updateFlowControlStateWithSettings(flowControlState, transportSettings);
streamManager = std::make_unique<QuicStreamManager>(*this, this->nodeType);
}
};
void ClientInvalidStateHandler(QuicClientConnectionState& state);
struct QuicClientStateMachine {
using StateData = QuicClientConnectionState;
static constexpr auto InvalidEventHandler = &ClientInvalidStateHandler;
};
/**
* Undos the clients state to be the original state of the client. This is
* intended to be used in the case version negotiation is performed.
*/
std::unique_ptr<QuicClientConnectionState> undoAllClientStateForVersionMismatch(
std::unique_ptr<QuicClientConnectionState> conn,
QuicVersion /* negotiatedVersion */);
void processServerInitialParams(
QuicClientConnectionState& conn,
ServerTransportParameters serverParams,
PacketNum packetNum);
void updateTransportParamsFromCachedEarlyParams(
QuicClientConnectionState& conn,
const CachedServerTransportParameters& transportParams);
} // namespace quic

22
quic/client/state/TARGETS Normal file
View File

@ -0,0 +1,22 @@
# @autodeps
load("@fbcode_macros//build_defs:cpp_library.bzl", "cpp_library")
cpp_library(
name = "client",
srcs = [
"ClientStateMachine.cpp",
],
headers = [
"ClientStateMachine.h",
],
deps = [
"//folly/io/async:async_socket_exception",
"//quic/client/handshake:client_handshake",
"//quic/congestion_control:cubic",
"//quic/flowcontrol:flow_control",
"//quic/handshake:transport_parameters",
"//quic/state:state_functions",
"//quic/state:state_machine",
],
)

View File

@ -0,0 +1,25 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
if(NOT BUILD_TESTS)
return()
endif()
quic_add_test(TARGET QuicClientTransportTest
SOURCES
QuicClientTransportTest.cpp
DEPENDS
Folly::folly
${LIBGMOCK_LIBRARIES}
mvfst_client
mvfst_codec_types
mvfst_echohandler
mvfst_handshake
mvfst_mock_socket
mvfst_mock_state
mvfst_server
mvfst_test_utils
mvfst_transport
)

File diff suppressed because it is too large Load Diff

29
quic/client/test/TARGETS Normal file
View File

@ -0,0 +1,29 @@
# @autodeps
load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest")
cpp_unittest(
name = "QuicClientTransportTest",
srcs = [
"QuicClientTransportTest.cpp",
],
deps = [
"//fizz/crypto/aead/test:mocks",
"//folly/futures:core",
"//folly/io:iobuf",
"//folly/io/async:scoped_event_base_thread",
"//folly/io/async/test:mocks",
"//folly/portability:gmock",
"//folly/portability:gtest",
"//quic/api/test:mocks",
"//quic/client:client",
"//quic/client/handshake/test:mock_psk_cache",
"//quic/codec:types",
"//quic/common/test:test_utils",
"//quic/congestion_control:congestion_controller_factory",
"//quic/handshake:transport_parameters",
"//quic/samples/echo:EchoHandler",
"//quic/server:server",
"//quic/state/test:mocks",
],
)

258
quic/codec/CMakeLists.txt Normal file
View File

@ -0,0 +1,258 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
add_library(
mvfst_codec_types STATIC
DefaultConnectionIdAlgo.cpp
PacketNumber.cpp
QuicConnectionId.cpp
QuicInteger.cpp
Types.cpp
)
target_include_directories(
mvfst_codec_types PUBLIC
$<BUILD_INTERFACE:${QUIC_FBCODE_ROOT}>
$<INSTALL_INTERFACE:include/>
)
target_compile_options(
mvfst_codec_types
PRIVATE
${_QUIC_COMMON_COMPILE_OPTIONS}
)
add_dependencies(
mvfst_codec_types
mvfst_constants
mvfst_exception
)
target_link_libraries(
mvfst_codec_types PUBLIC
Folly::folly
mvfst_constants
mvfst_exception
PRIVATE
${Boost_LIBRARIES}
)
add_library(
mvfst_codec_decode STATIC
Decode.cpp
)
target_include_directories(
mvfst_codec_decode PUBLIC
$<BUILD_INTERFACE:${QUIC_FBCODE_ROOT}>
$<INSTALL_INTERFACE:include/>
)
target_compile_options(
mvfst_codec_decode
PRIVATE
${_QUIC_COMMON_COMPILE_OPTIONS}
)
add_dependencies(
mvfst_codec_decode
mvfst_codec_types
mvfst_exception
)
target_link_libraries(
mvfst_codec_decode PUBLIC
Folly::folly
mvfst_codec_types
mvfst_exception
)
add_library(
mvfst_codec_packet_number_cipher STATIC
PacketNumberCipher.cpp
)
target_include_directories(
mvfst_codec_packet_number_cipher PUBLIC
$<BUILD_INTERFACE:${QUIC_FBCODE_ROOT}>
$<INSTALL_INTERFACE:include/>
)
target_compile_options(
mvfst_codec_packet_number_cipher
PRIVATE
${_QUIC_COMMON_COMPILE_OPTIONS}
)
add_dependencies(
mvfst_codec_packet_number_cipher
mvfst_codec_types
mvfst_codec_decode
)
target_link_libraries(
mvfst_codec_packet_number_cipher PUBLIC
Folly::folly
mvfst_codec_types
mvfst_codec_decode
)
add_library(
mvfst_codec_pktbuilder STATIC
QuicPacketBuilder.cpp
)
target_include_directories(
mvfst_codec_pktbuilder PUBLIC
$<BUILD_INTERFACE:${QUIC_FBCODE_ROOT}>
$<INSTALL_INTERFACE:include/>
)
target_compile_options(
mvfst_codec_pktbuilder
PRIVATE
${_QUIC_COMMON_COMPILE_OPTIONS}
)
add_dependencies(
mvfst_codec_pktbuilder
mvfst_codec_types
mvfst_handshake
)
target_link_libraries(
mvfst_codec_pktbuilder PUBLIC
Folly::folly
mvfst_codec_types
mvfst_handshake
)
add_library(
mvfst_codec_pktrebuilder STATIC
QuicPacketRebuilder.cpp
)
target_include_directories(
mvfst_codec_pktrebuilder PUBLIC
$<BUILD_INTERFACE:${QUIC_FBCODE_ROOT}>
$<INSTALL_INTERFACE:include/>
)
target_compile_options(
mvfst_codec_pktrebuilder
PRIVATE
${_QUIC_COMMON_COMPILE_OPTIONS}
)
add_dependencies(
mvfst_codec_pktrebuilder
mvfst_codec
mvfst_codec_pktbuilder
mvfst_flowcontrol
mvfst_state_machine
mvfst_state_simple_frame_functions
mvfst_state_stream_functions
)
target_link_libraries(
mvfst_codec_pktrebuilder PUBLIC
Folly::folly
mvfst_codec
mvfst_codec_pktbuilder
mvfst_flowcontrol
mvfst_state_machine
mvfst_state_simple_frame_functions
mvfst_state_stream_functions
)
add_library(
mvfst_codec STATIC
QuicHeaderCodec.cpp
QuicReadCodec.cpp
QuicWriteCodec.cpp
)
target_include_directories(
mvfst_codec PUBLIC
$<BUILD_INTERFACE:${QUIC_FBCODE_ROOT}>
$<INSTALL_INTERFACE:include/>
)
target_compile_options(
mvfst_codec
PRIVATE
${_QUIC_COMMON_COMPILE_OPTIONS}
)
add_dependencies(
mvfst_codec
mvfst_constants
mvfst_codec_decode
mvfst_codec_types
mvfst_codec_pktbuilder
mvfst_exception
mvfst_handshake
mvfst_state_ack_states
)
target_link_libraries(
mvfst_codec PUBLIC
Folly::folly
mvfst_constants
mvfst_codec_decode
mvfst_codec_types
mvfst_exception
mvfst_handshake
mvfst_state_ack_states
)
file(
GLOB_RECURSE QUIC_API_HEADERS_TOINSTALL
RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
*.h
)
list(FILTER QUIC_API_HEADERS_TOINSTALL EXCLUDE REGEX test/)
foreach(header ${QUIC_API_HEADERS_TOINSTALL})
get_filename_component(header_dir ${header} DIRECTORY)
install(FILES ${header} DESTINATION include/quic/codec/${header_dir})
endforeach()
install(
TARGETS mvfst_codec_types
EXPORT mvfst-exports
DESTINATION lib
)
install(
TARGETS mvfst_codec_decode
EXPORT mvfst-exports
DESTINATION lib
)
install(
TARGETS mvfst_codec_pktbuilder
EXPORT mvfst-exports
DESTINATION lib
)
install(
TARGETS mvfst_codec_pktrebuilder
EXPORT mvfst-exports
DESTINATION lib
)
install(
TARGETS mvfst_codec_packet_number_cipher
EXPORT mvfst-exports
DESTINATION lib
)
install(
TARGETS mvfst_codec
EXPORT mvfst-exports
DESTINATION lib
)
add_subdirectory(test)

View File

@ -0,0 +1,56 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#pragma once
#include <folly/Optional.h>
#include <quic/codec/QuicConnectionId.h>
namespace quic {
/**
* Interface to encode and decode algorithms for ConnectionId given routing
* info (embedded in ServerConnectionIdParams)
*
* NOTE: since several of these methods are called for every single packets,
* and every single connection, it is important to not do any
* blocking call in any of the implementation of these methods.
*/
class ConnectionIdAlgo {
public:
virtual ~ConnectionIdAlgo() = default;
/**
* Check if this implementation of algorithm can parse the given ConnectionId
*/
virtual bool canParse(const ConnectionId& id) const = 0;
/**
* Parses ServerConnectionIdParams from the given connection id.
*/
virtual ServerConnectionIdParams parseConnectionId(
const ConnectionId& id) = 0;
/**
* Encodes the given ServerConnectionIdParams into connection id
*/
virtual ConnectionId encodeConnectionId(
const ServerConnectionIdParams& params) = 0;
};
/**
* Factory interface to create ConnectionIdAlgo instance.
*/
class ConnectionIdAlgoFactory {
public:
virtual ~ConnectionIdAlgoFactory() = default;
virtual std::unique_ptr<ConnectionIdAlgo> make() = 0;
};
} // namespace quic

1010
quic/codec/Decode.cpp Normal file

File diff suppressed because it is too large Load Diff

200
quic/codec/Decode.h Normal file
View File

@ -0,0 +1,200 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
// Copyright 2004-present Facebook. All rights reserved.
#pragma once
#include <folly/io/Cursor.h>
#include <quic/codec/PacketNumber.h>
#include <quic/codec/Types.h>
namespace quic {
/**
* Connection level parameters needed by the codec to decode the packet
* successfully.
*/
struct CodecParameters {
// This must not be set to zero.
uint8_t peerAckDelayExponent{kDefaultAckDelayExponent};
CodecParameters() = default;
explicit CodecParameters(uint8_t peerAckDelayExponentIn)
: peerAckDelayExponent(peerAckDelayExponentIn) {}
};
struct ParsedLongHeaderInvariant {
uint8_t initialByte;
LongHeaderInvariant invariant;
size_t invariantLength;
ParsedLongHeaderInvariant(
uint8_t initialByteIn,
LongHeaderInvariant headerInvariant,
size_t length);
};
/**
* Decodes a version negotiation packet. Returns a folly::none, if it cannot
* decode the packet.
*/
folly::Optional<VersionNegotiationPacket> decodeVersionNegotiation(
const ParsedLongHeaderInvariant& longHeaderInvariant,
folly::io::Cursor& cursor);
/**
* Decodes a single regular QUIC packet from the cursor.
* The packet in the cursor must be at least 1 QUIC packet.
* Throws with a QuicException if the data in the cursor is not a complete QUIC
* packet or the packet could not be decoded correctly.
*/
RegularQuicPacket decodeRegularPacket(
PacketHeader&& header,
const CodecParameters& params,
folly::io::Cursor& cursor);
/**
* Parses a single frame from the cursor. Throws a QuicException if the frame
* could not be parsed.
*/
QuicFrame parseFrame(
folly::io::Cursor& cursor,
const PacketHeader& header,
const CodecParameters& params);
/**
* The following functions decode frames. They throw an QuicException when error
* occurs.
*/
PaddingFrame decodePaddingFrame(folly::io::Cursor&);
RstStreamFrame decodeRstStreamFrame(folly::io::Cursor& cursor);
ConnectionCloseFrame decodeConnectionCloseFrame(folly::io::Cursor& cursor);
ApplicationCloseFrame decodeApplicationCloseFrame(folly::io::Cursor& cursor);
MaxDataFrame decodeMaxDataFrame(folly::io::Cursor& cursor);
MaxStreamDataFrame decodeMaxStreamDataFrame(folly::io::Cursor& cursor);
ExpiredStreamDataFrame decodeExpiredStreamDataFrame(folly::io::Cursor& cursor);
MinStreamDataFrame decodeMinStreamDataFrame(folly::io::Cursor& cursor);
MaxStreamsFrame decodeBiDiMaxStreamsFrame(folly::io::Cursor& cursor);
MaxStreamsFrame decodeUniMaxStreamsFrame(folly::io::Cursor& cursor);
PingFrame decodePingFrame(folly::io::Cursor& cursor);
DataBlockedFrame decodeDataBlockedFrame(folly::io::Cursor& cursor);
StreamDataBlockedFrame decodeStreamDataBlockedFrame(folly::io::Cursor& cursor);
StreamsBlockedFrame decodeBiDiStreamsBlockedFrame(folly::io::Cursor& cursor);
StreamsBlockedFrame decodeUniStreamsBlockedFrame(folly::io::Cursor& cursor);
NewConnectionIdFrame decodeNewConnectionIdFrame(folly::io::Cursor& cursor);
NoopFrame decodeRetireConnectionIdFrame(folly::io::Cursor& cursor);
StopSendingFrame decodeStopSendingFrame(folly::io::Cursor& cursor);
PathChallengeFrame decodePathChallengeFrame(folly::io::Cursor& cursor);
PathResponseFrame decodePathResponseFrame(folly::io::Cursor& cursor);
ReadAckFrame decodeAckFrame(
folly::io::Cursor& cursor,
const PacketHeader& header,
const CodecParameters& params);
ReadAckFrame decodeAckFrameWithECN(
folly::io::Cursor& cursor,
const PacketHeader& header,
const CodecParameters& params);
ReadStreamFrame decodeStreamFrame(
folly::io::Cursor& cursor,
StreamTypeField frameTypeField);
ReadCryptoFrame decodeCryptoFrame(folly::io::Cursor& cursor);
ReadNewTokenFrame decodeNewTokenFrame(folly::io::Cursor& cursor);
/**
* Parse the Invariant fields in Long Header.
*
* cursor: points to the byte just past initialByte. After parsing, cursor will
* be moved to the byte right after Source Connection ID.
*/
folly::Expected<ParsedLongHeaderInvariant, TransportErrorCode>
parseLongHeaderInvariant(uint8_t initalByte, folly::io::Cursor& cursor);
struct PacketLength {
// The length of the packet payload (inlcuding packet number)
uint64_t packetLength;
// Length of the length field.
size_t lengthLength;
PacketLength(uint64_t packetLengthIn, size_t lengthLengthIn)
: packetLength(packetLengthIn), lengthLength(lengthLengthIn) {}
};
struct ParsedLongHeader {
LongHeader header;
PacketLength packetLength;
ParsedLongHeader(LongHeader headerIn, PacketLength packetLengthIn)
: header(std::move(headerIn)), packetLength(packetLengthIn) {}
};
struct ParsedLongHeaderResult {
bool isVersionNegotiation;
folly::Optional<ParsedLongHeader> parsedLongHeader;
ParsedLongHeaderResult(
bool isVersionNegotiationIn,
folly::Optional<ParsedLongHeader> parsedLongHeaderIn);
};
// Functions that operate on the initial byte
LongHeader::Types parseLongHeaderType(uint8_t initialByte);
size_t parsePacketNumberLength(uint8_t initialByte);
/**
* Returns the packet number and the length of the packet number.
* packetNumberRange should be kMaxPacketNumEncodingSize size.
*/
std::pair<PacketNum, size_t> parsePacketNumber(
uint8_t initialByte,
folly::ByteRange packetNumberRange,
PacketNum expectedNextPacketNum);
// cursor: has to be point to the byte just past initialByte
folly::Expected<ParsedLongHeaderResult, TransportErrorCode> parseLongHeader(
uint8_t initialByte,
folly::io::Cursor& cursor);
folly::Expected<ParsedLongHeader, TransportErrorCode> parseLongHeaderVariants(
LongHeader::Types type,
ParsedLongHeaderInvariant longHeaderInvariant,
folly::io::Cursor& cursor);
folly::Expected<ShortHeaderInvariant, TransportErrorCode>
parseShortHeaderInvariants(uint8_t initialByte, folly::io::Cursor& cursor);
folly::Expected<ShortHeader, TransportErrorCode> parseShortHeader(
uint8_t initialByte,
folly::io::Cursor& cursor);
} // namespace quic

View File

@ -0,0 +1,200 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#include <quic/codec/DefaultConnectionIdAlgo.h>
#include <folly/Random.h>
#include <quic/QuicConstants.h>
#include <quic/QuicException.h>
namespace {
// mask to extract process id bit from the connectionId
constexpr uint8_t kProcessIdBitMask = 0x20;
// mask to set the first 6 bits from host-id
constexpr uint8_t kHostIdFirstByteMask = 0x3f;
// mask to set the next 8 bits from host-id
constexpr uint8_t kHostIdSecondByteMask = 0xff;
// mask to set the last 2 bits from host-id
constexpr uint8_t kHostIdThirdByteMask = 0xc0;
// mask to set the first 6 bits from worker-id
constexpr uint8_t kWorkerIdFirstByteMask = 0xfc;
// mask to set the last 2 bits from host-id
constexpr uint8_t kWorkerIdSecondByteMask = 0x03;
// first 2 bits in the connection id is reserved for versioning of the conn id
constexpr uint8_t kShortVersionBitsMask = 0xc0;
/**
* Sets the short version id bits (0 - 3) into the given ConnectionId
*/
void setVersionBitsInConnId(quic::ConnectionId& connId, uint8_t version) {
if (connId.size() < quic::kMinConnectionIdSize) {
throw quic::QuicInternalException(
"ConnectionId is too small for version",
quic::LocalErrorCode::INTERNAL_ERROR);
}
// clear 0-1 bits
connId.data()[0] &= (~kShortVersionBitsMask);
connId.data()[0] |= (kShortVersionBitsMask & (version << 6));
}
/**
* Extract the version id bits (0 - 1) from the given ConnectionId
*/
uint8_t getVersionBitsFromConnId(const quic::ConnectionId& connId) {
if (connId.size() < quic::kMinConnectionIdSize) {
throw quic::QuicInternalException(
"ConnectionId is too small for version",
quic::LocalErrorCode::INTERNAL_ERROR);
}
uint8_t version = 0;
version = (kShortVersionBitsMask & connId.data()[0]) >> 6;
return version;
}
/**
* Sets the host id bits [2 - 17] bits into the given ConnectionId
*/
void setHostIdBitsInConnId(quic::ConnectionId& connId, uint16_t hostId) {
if (connId.size() < quic::kMinConnectionIdSize) {
throw quic::QuicInternalException(
"ConnectionId is too small for hostid",
quic::LocalErrorCode::INTERNAL_ERROR);
}
// clear 2-7 bits
connId.data()[0] &= ~kHostIdFirstByteMask;
// clear 8 - 15 bits
connId.data()[1] &= ~kHostIdSecondByteMask;
// clear 16 - 17 bits
connId.data()[2] &= ~kHostIdThirdByteMask;
// set 2 - 7 bits in the connId with the first 6 bits of the worker id
connId.data()[0] |= (kHostIdFirstByteMask & (hostId >> 10));
// set 8 - 15 bits in the connId with the next 8 bits of the worker id
connId.data()[1] |= (kHostIdSecondByteMask & (hostId >> 2));
// set 16 - 17 bits in the connId with the last 2 bits of the worker id
connId.data()[2] |= (kHostIdThirdByteMask & (hostId << 6));
}
/**
* Extract the host id bits [4 - 15] bits from the given ConnectionId
*/
uint16_t getHostIdBitsInConnId(const quic::ConnectionId& connId) {
if (connId.size() < quic::kMinConnectionIdSize) {
throw quic::QuicInternalException(
"ConnectionId is too small for hostid",
quic::LocalErrorCode::INTERNAL_ERROR);
}
uint16_t hostId = 0;
// get 2 - 7 bits from the connId and set first 6 bits of the host id
hostId = (kHostIdFirstByteMask & (connId.data()[0]));
// shift by 10 bits and make room for the last 10 bits
hostId = hostId << 10;
// get 8 - 15 bits from the connId
hostId |= (kHostIdSecondByteMask & connId.data()[1]) << 2;
// get 16 - 17 bits from the connId
hostId |= (kHostIdThirdByteMask & connId.data()[2]) >> 6;
return hostId;
}
/**
* Sets the given 8-bit workerId into the given connectionId's 16-23 bits
*/
void setWorkerIdBitsInConnId(quic::ConnectionId& connId, uint8_t workerId) {
if (connId.size() < quic::kMinConnectionIdSize) {
throw quic::QuicInternalException(
"ConnectionId is too small for workerid",
quic::LocalErrorCode::INTERNAL_ERROR);
}
// clear 18-23 bits
connId.data()[2] &= 0xc0;
// clear 24-25 bits
connId.data()[3] &= 0x3f;
// set 18 - 23 bits in the connId with first 6 bits of the worker id
connId.data()[2] |= (kWorkerIdFirstByteMask & workerId) >> 2;
// set 24 - 25 bits in the connId with the last 2 bits of the worker id
connId.data()[3] |= (kWorkerIdSecondByteMask & workerId) << 6;
}
/**
* Extracts the 'workerId' bits from the given ConnectionId
*/
uint8_t getWorkerIdFromConnId(const quic::ConnectionId& connId) {
if (connId.size() < quic::kMinConnectionIdSize) {
throw quic::QuicInternalException(
"ConnectionId is too small for workerid",
quic::LocalErrorCode::INTERNAL_ERROR);
}
// get 18 - 23 bits from the connId
uint8_t workerId = connId.data()[2] << 2;
// get 24 - 25 bits in the connId
workerId |= connId.data()[3] >> 6;
return workerId;
}
/**
* Sets the server id bit (at 24th bit) into the given ConnectionId
*/
void setProcessIdBitsInConnId(quic::ConnectionId& connId, uint8_t processId) {
if (connId.size() < quic::kMinConnectionIdSize) {
throw quic::QuicInternalException(
"ConnectionId is too small for processid",
quic::LocalErrorCode::INTERNAL_ERROR);
}
// clear the 26th bit
connId.data()[3] &= (~kProcessIdBitMask);
connId.data()[3] |= (kProcessIdBitMask & (processId << 5));
}
/**
* Extract the server id bit (at 26th bit) from the given ConnectionId
*/
uint8_t getProcessIdBitsFromConnId(const quic::ConnectionId& connId) {
if (connId.size() < quic::kMinConnectionIdSize) {
throw quic::QuicInternalException(
"ConnectionId is too small for processid",
quic::LocalErrorCode::INTERNAL_ERROR);
}
uint8_t processId = 0;
processId = (kProcessIdBitMask & connId.data()[3]) >> 5;
return processId;
}
} // namespace
namespace quic {
bool DefaultConnectionIdAlgo::canParse(const ConnectionId& id) const {
if (id.size() < kMinConnectionIdSize) {
return false;
}
return getVersionBitsFromConnId(id) == kShortVersionId;
}
ServerConnectionIdParams DefaultConnectionIdAlgo::parseConnectionId(
const ConnectionId& id) {
ServerConnectionIdParams serverConnIdParams(
getVersionBitsFromConnId(id),
getHostIdBitsInConnId(id),
getProcessIdBitsFromConnId(id),
getWorkerIdFromConnId(id));
serverConnIdParams.clientConnId.assign(id);
return serverConnIdParams;
}
ConnectionId DefaultConnectionIdAlgo::encodeConnectionId(
const ServerConnectionIdParams& params) {
// In case there is no client cid, create a random connection id.
std::vector<uint8_t> connIdData(kDefaultConnectionIdSize);
folly::Random::secureRandom(connIdData.data(), connIdData.size());
ConnectionId connId = ConnectionId(std::move(connIdData));
setVersionBitsInConnId(connId, params.version);
setHostIdBitsInConnId(connId, params.hostId);
setProcessIdBitsInConnId(connId, params.processId);
setWorkerIdBitsInConnId(connId, params.workerId);
return connId;
}
} // namespace quic

View File

@ -0,0 +1,66 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#pragma once
#include <folly/Optional.h>
#include <quic/codec/ConnectionIdAlgo.h>
#include <quic/codec/QuicConnectionId.h>
namespace quic {
/**
* Default implementation with algorithms to encode and decode for
* ConnectionId given routing info (embedded in ServerConnectionIdParams)
*
* The schema for connection id is defined as follows:
*
* First 2 (0 - 1) bits are reserved for short version id of the connection id
* If the load balancer (e.g. L4 lb) doesn't understand this version,
* it can fallback to default routing
* Next 16 (3 - 17) bits are reserved to be used for L4 LB
* Next eight bits (18 - 25) are reserved for worker id
* bit 26 is reserved for the Quic server id: server id is used to distinguish
* between the takeover instance and the taken over one
0 1 2 3 4 5 .. 17 18 19 20 .. 25 26 27 28 ... 63
|SHORT VERSION| For L4 LB | WORKER_ID | SERVER_ID | ..
*/
class DefaultConnectionIdAlgo : public ConnectionIdAlgo {
public:
~DefaultConnectionIdAlgo() override = default;
/**
* Check if this implementation of algorithm can parse the given ConnectionId
*/
bool canParse(const ConnectionId& id) const override;
/**
* Parses ServerConnectionIdParams from the given connection id.
*/
ServerConnectionIdParams parseConnectionId(const ConnectionId& id) override;
/**
* Encodes the given ServerConnectionIdParams into connection id
*/
ConnectionId encodeConnectionId(
const ServerConnectionIdParams& params) override;
};
/**
* Factory Interface to create ConnectionIdAlgo instance.
*/
class DefaultConnectionIdAlgoFactory : public ConnectionIdAlgoFactory {
public:
~DefaultConnectionIdAlgoFactory() override = default;
std::unique_ptr<ConnectionIdAlgo> make() override {
return std::make_unique<DefaultConnectionIdAlgo>();
}
};
} // namespace quic

View File

@ -0,0 +1,70 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#include <quic/codec/PacketNumber.h>
namespace quic {
PacketNumEncodingResult::PacketNumEncodingResult(
PacketNum resultIn,
size_t lengthIn)
: result(resultIn), length(lengthIn) {}
PacketNumEncodingResult encodePacketNumber(
PacketNum packetNum,
PacketNum largestAckedPacketNum) {
DCHECK(
(!packetNum && !largestAckedPacketNum) ||
packetNum > largestAckedPacketNum);
PacketNum twiceDistance = (packetNum - largestAckedPacketNum) * 2;
// The number of bits we need to mask all set bits in twiceDistance.
// This is 1 + floor(log2(x)).
size_t lengthInBits = folly::findLastSet(twiceDistance);
// Round up to bytes
size_t lengthInBytes = lengthInBits == 0 ? 1 : (lengthInBits + 7) >> 3;
if (lengthInBytes > 4) {
throw QuicInternalException(
folly::to<std::string>(
"Impossible to encode PacketNum=",
packetNum,
", largestAcked=",
largestAckedPacketNum),
LocalErrorCode::PACKET_NUMBER_ENCODING);
}
// We need a mask that's all 1 for lengthInBytes bytes. Left shift a 1 by that
// many bits and then -1 will give us that. Or if lengthInBytes is 8, then ~0
// will just do it.
DCHECK_NE(lengthInBytes, 8);
int64_t mask = (1ULL << lengthInBytes * 8) - 1;
return PacketNumEncodingResult(packetNum & mask, lengthInBytes);
}
/**
* This simply follows Draft-17 Appendix-A.
*/
PacketNum decodePacketNumber(
uint64_t encodedPacketNum,
size_t packetNumBytes,
PacketNum expectedNextPacketNum) {
size_t packetNumBits = 8 * packetNumBytes;
PacketNum packetNumWin = 1ULL << packetNumBits;
PacketNum packetNumHalfWin = packetNumWin >> 1;
PacketNum mask = packetNumWin - 1;
PacketNum candidate = (expectedNextPacketNum & ~mask) | encodedPacketNum;
if (expectedNextPacketNum > packetNumHalfWin &&
candidate <= expectedNextPacketNum - packetNumHalfWin) {
return candidate + packetNumWin;
}
if (candidate > expectedNextPacketNum + packetNumHalfWin &&
candidate > packetNumWin) {
return candidate - packetNumWin;
}
return candidate;
}
} // namespace quic

38
quic/codec/PacketNumber.h Normal file
View File

@ -0,0 +1,38 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
// Copyright 2004-present Facebook. All rights reserved.
#pragma once
#include <quic/codec/Types.h>
namespace quic {
/**
* Returns a decoded packet number by using the expectedNextPacketNum to
* search for the most probable packet number that could satisfy that condition.
*/
PacketNum decodePacketNumber(
uint64_t encodedPacketNum,
size_t packetNumBytes,
PacketNum expectedNextPacketNum);
struct PacketNumEncodingResult {
PacketNum result;
// This is packet number length in bytes
size_t length;
PacketNumEncodingResult(PacketNum resultIn, size_t lengthIn);
};
PacketNumEncodingResult encodePacketNumber(
PacketNum packetNum,
PacketNum largestAckedPacketNum);
} // namespace quic

View File

@ -0,0 +1,131 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#include <quic/codec/PacketNumberCipher.h>
#include <quic/codec/Decode.h>
#include <quic/codec/Types.h>
namespace quic {
constexpr size_t kAES128KeyLength = 16;
void PacketNumberCipher::decipherHeader(
folly::ByteRange sample,
folly::MutableByteRange initialByte,
folly::MutableByteRange packetNumberBytes,
uint8_t initialByteMask,
uint8_t /* packetNumLengthMask */) const {
CHECK_EQ(packetNumberBytes.size(), kMaxPacketNumEncodingSize);
HeaderProtectionMask headerMask = mask(sample);
// Mask size should be > packet number length + 1.
DCHECK_GE(headerMask.size(), 5);
initialByte.data()[0] ^= headerMask.data()[0] & initialByteMask;
size_t packetNumLength = parsePacketNumberLength(*initialByte.data());
for (size_t i = 0; i < packetNumLength; ++i) {
packetNumberBytes.data()[i] ^= headerMask.data()[i + 1];
}
}
void PacketNumberCipher::cipherHeader(
folly::ByteRange sample,
folly::MutableByteRange initialByte,
folly::MutableByteRange packetNumberBytes,
uint8_t initialByteMask,
uint8_t /* packetNumLengthMask */) const {
HeaderProtectionMask headerMask = mask(sample);
// Mask size should be > packet number length + 1.
DCHECK_GE(headerMask.size(), kMaxPacketNumEncodingSize + 1);
size_t packetNumLength = parsePacketNumberLength(*initialByte.data());
initialByte.data()[0] ^= headerMask.data()[0] & initialByteMask;
for (size_t i = 0; i < packetNumLength; ++i) {
packetNumberBytes.data()[i] ^= headerMask.data()[i + 1];
}
}
void PacketNumberCipher::decryptLongHeader(
folly::ByteRange sample,
folly::MutableByteRange initialByte,
folly::MutableByteRange packetNumberBytes) const {
decipherHeader(
sample,
initialByte,
packetNumberBytes,
LongHeader::kTypeBitsMask,
LongHeader::kPacketNumLenMask);
}
void PacketNumberCipher::decryptShortHeader(
folly::ByteRange sample,
folly::MutableByteRange initialByte,
folly::MutableByteRange packetNumberBytes) const {
decipherHeader(
sample,
initialByte,
packetNumberBytes,
ShortHeader::kTypeBitsMask,
ShortHeader::kPacketNumLenMask);
}
void PacketNumberCipher::encryptLongHeader(
folly::ByteRange sample,
folly::MutableByteRange initialByte,
folly::MutableByteRange packetNumberBytes) const {
cipherHeader(
sample,
initialByte,
packetNumberBytes,
LongHeader::kTypeBitsMask,
LongHeader::kPacketNumLenMask);
}
void PacketNumberCipher::encryptShortHeader(
folly::ByteRange sample,
folly::MutableByteRange initialByte,
folly::MutableByteRange packetNumberBytes) const {
cipherHeader(
sample,
initialByte,
packetNumberBytes,
ShortHeader::kTypeBitsMask,
ShortHeader::kPacketNumLenMask);
}
void Aes128PacketNumberCipher::setKey(folly::ByteRange key) {
encryptCtx_.reset(EVP_CIPHER_CTX_new());
if (encryptCtx_ == nullptr) {
throw std::runtime_error("Unable to allocate an EVP_CIPHER_CTX object");
}
if (EVP_EncryptInit_ex(
encryptCtx_.get(), EVP_aes_128_ecb(), nullptr, key.data(), nullptr) !=
1) {
throw std::runtime_error("Init error");
}
}
HeaderProtectionMask Aes128PacketNumberCipher::mask(
folly::ByteRange sample) const {
HeaderProtectionMask outMask;
CHECK_EQ(sample.size(), outMask.size());
int outLen = 0;
if (EVP_EncryptUpdate(
encryptCtx_.get(),
outMask.data(),
&outLen,
sample.data(),
sample.size()) != 1 ||
outLen != outMask.size()) {
throw std::runtime_error("Encryption error");
}
return outMask;
}
size_t Aes128PacketNumberCipher::keyLength() const {
return kAES128KeyLength;
}
} // namespace quic

View File

@ -0,0 +1,104 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#pragma once
#include <folly/Optional.h>
#include <folly/io/Cursor.h>
#include <folly/ssl/OpenSSLPtrTypes.h>
namespace quic {
using HeaderProtectionMask = std::array<uint8_t, 16>;
using Sample = std::array<uint8_t, 16>;
class PacketNumberCipher {
public:
virtual ~PacketNumberCipher() = default;
virtual void setKey(folly::ByteRange key) = 0;
virtual HeaderProtectionMask mask(folly::ByteRange sample) const = 0;
/**
* Decrypts a long header from a sample.
* sample should be 16 bytes long.
* initialByte is the initial byte.
* packetNumberBytes should be supplied with at least 4 bytes.
*/
virtual void decryptLongHeader(
folly::ByteRange sample,
folly::MutableByteRange initialByte,
folly::MutableByteRange packetNumberBytes) const;
/**
* Decrypts a short header from a sample.
* sample should be 16 bytes long.
* initialByte is the initial byte.
* packetNumberBytes should be supplied with at least 4 bytes.
*/
virtual void decryptShortHeader(
folly::ByteRange sample,
folly::MutableByteRange initialByte,
folly::MutableByteRange packetNumberBytes) const;
/**
* Encrypts a long header from a sample.
* sample should be 16 bytes long.
* initialByte is the initial byte.
*/
virtual void encryptLongHeader(
folly::ByteRange sample,
folly::MutableByteRange initialByte,
folly::MutableByteRange packetNumberBytes) const;
/**
* Encrypts a short header from a sample.
* sample should be 16 bytes long.
* initialByte is the initial byte.
*/
virtual void encryptShortHeader(
folly::ByteRange sample,
folly::MutableByteRange initialByte,
folly::MutableByteRange packetNumberBytes) const;
/**
* Returns the length of key needed for the pn cipher.
*/
virtual size_t keyLength() const = 0;
protected:
virtual void cipherHeader(
folly::ByteRange sample,
folly::MutableByteRange initialByte,
folly::MutableByteRange packetNumberBytes,
uint8_t initialByteMask,
uint8_t packetNumLengthMask) const;
virtual void decipherHeader(
folly::ByteRange sample,
folly::MutableByteRange initialByte,
folly::MutableByteRange packetNumberBytes,
uint8_t initialByteMask,
uint8_t packetNumLengthMask) const;
};
class Aes128PacketNumberCipher : public PacketNumberCipher {
public:
~Aes128PacketNumberCipher() override = default;
void setKey(folly::ByteRange key) override;
HeaderProtectionMask mask(folly::ByteRange sample) const override;
size_t keyLength() const override;
private:
folly::ssl::EvpCipherCtxUniquePtr encryptCtx_;
};
} // namespace quic

View File

@ -0,0 +1,110 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
// Copyright 2004-present Facebook. All rights reserved.
#include <quic/codec/QuicConnectionId.h>
#include <glog/logging.h>
#include <quic/QuicConstants.h>
namespace quic {
uint8_t* ConnectionId::data() {
return connid.data();
}
const uint8_t* ConnectionId::data() const {
return connid.data();
}
uint8_t ConnectionId::size() const {
return static_cast<uint8_t>(connid.size());
}
std::string ConnectionId::hex() const {
return folly::hexlify(connid);
}
ConnectionId::ConnectionId(const std::vector<uint8_t>& connidIn)
: connid(connidIn) {
static_assert(
std::numeric_limits<uint8_t>::max() > kMaxConnectionIdSize,
"Max connection size is too big");
if (connid.size() != 0 &&
(connid.size() < kMinConnectionIdSize ||
connid.size() > kMaxConnectionIdSize)) {
// We can't throw a transport error here because of the dependency. This is
// sad because this will cause an internal error downstream.
throw std::runtime_error("ConnectionId invalid size");
}
}
ConnectionId::ConnectionId(folly::io::Cursor& cursor, size_t len) {
// Zero is special case for connids.
if (len == 0) {
return;
}
if (len < kMinConnectionIdSize || len > kMaxConnectionIdSize) {
// We can't throw a transport error here because of the dependency. This is
// sad because this will cause an internal error downstream.
throw std::runtime_error("ConnectionId invalid size");
}
connid.resize(len);
cursor.pull(connid.data(), len);
}
ConnectionId ConnectionId::createWithoutChecks(
const std::vector<uint8_t>& connidIn) {
ConnectionId connid;
connid.connid = connidIn;
return connid;
}
bool ConnectionId::operator==(const ConnectionId& other) const {
return connid == other.connid;
}
bool ConnectionId::operator!=(const ConnectionId& other) const {
return !operator==(other);
}
void ServerConnectionIdParams::setVersion(uint8_t versionIn) {
version = versionIn;
}
void ServerConnectionIdParams::setHostId(uint16_t hostIdIn) {
hostId = hostIdIn;
}
void ServerConnectionIdParams::setProcessId(uint8_t processIdIn) {
processId = processIdIn;
}
void ServerConnectionIdParams::setWorkerId(uint8_t workerIdIn) {
workerId = workerIdIn;
}
std::pair<uint8_t, uint8_t> decodeConnectionIdLengths(uint8_t connIdSize) {
uint8_t dcidLen = (connIdSize >> 4);
uint8_t scidLen = connIdSize & 0x0F;
dcidLen = dcidLen == 0 ? 0 : dcidLen + 3;
scidLen = scidLen == 0 ? 0 : scidLen + 3;
return std::make_pair(dcidLen, scidLen);
}
uint8_t encodeConnectionIdLengths(
uint8_t destinationConnectionIdSize,
uint8_t sourceConnectionIdSize) {
DCHECK_LE(destinationConnectionIdSize, kMaxConnectionIdSize);
DCHECK_LE(sourceConnectionIdSize, kMaxConnectionIdSize);
uint8_t dstByte =
destinationConnectionIdSize == 0 ? 0 : destinationConnectionIdSize - 3;
uint8_t srcByte =
sourceConnectionIdSize == 0 ? 0 : sourceConnectionIdSize - 3;
return ((dstByte << 4)) | (srcByte);
}
} // namespace quic

View File

@ -0,0 +1,138 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
// Copyright 2004-present Facebook. All rights reserved.
#pragma once
#include <folly/Optional.h>
#include <folly/String.h>
#include <folly/hash/Hash.h>
#include <folly/io/Cursor.h>
#include <folly/io/IOBuf.h>
#include <array>
namespace quic {
constexpr size_t kMinConnectionIdSize = 4;
constexpr size_t kMaxConnectionIdSize = 18;
// set conn id version at the first 4 bits
constexpr uint8_t kShortVersionId = 0x1;
struct ConnectionId {
uint8_t* data();
const uint8_t* data() const;
uint8_t size() const;
explicit ConnectionId(const std::vector<uint8_t>& connidIn);
explicit ConnectionId(folly::io::Cursor& cursor, size_t len);
bool operator==(const ConnectionId& other) const;
bool operator!=(const ConnectionId& other) const;
std::string hex() const;
/**
* Create an connection without any checks for tests.
*/
static ConnectionId createWithoutChecks(const std::vector<uint8_t>& connidIn);
private:
ConnectionId() = default;
std::vector<uint8_t> connid;
};
struct ConnectionIdHash {
size_t operator()(const ConnectionId& connId) const {
return folly::hash::fnv32_buf(connId.data(), connId.size());
}
};
inline std::ostream& operator<<(std::ostream& os, const ConnectionId& connId) {
os << connId.hex();
return os;
}
inline folly::IOBuf toData(const ConnectionId& connId) {
return folly::IOBuf::wrapBufferAsValue(connId.data(), connId.size());
}
/**
* Encapsulate parameters to generate server chosen connection id
*/
struct ServerConnectionIdParams {
explicit ServerConnectionIdParams(
uint16_t hostIdIn,
uint8_t processIdIn,
uint8_t workerIdIn)
: ServerConnectionIdParams(
kShortVersionId,
hostIdIn,
processIdIn,
workerIdIn) {}
explicit ServerConnectionIdParams(
uint8_t versionIn,
uint16_t hostIdIn,
uint8_t processIdIn,
uint8_t workerIdIn) {
setVersion(versionIn);
setHostId(hostIdIn);
setProcessId(processIdIn);
setWorkerId(workerIdIn);
}
/**
* Set Quic connection-id short version
*/
void setVersion(uint8_t versionIn);
/**
* Set Quic Host id
*/
void setHostId(uint16_t hostIdIn);
/**
* Set Quic process id
*/
void setProcessId(uint8_t processIdIn);
/**
* Set Quic server worker Id
*/
void setWorkerId(uint8_t workerIdIn);
folly::Optional<ConnectionId> clientConnId;
// Quic connection-id short version
uint8_t version{0};
// Quic Host id
uint16_t hostId{0};
// Quic process id
uint8_t processId{0};
// Quic server worker Id
uint8_t workerId{0};
};
/**
* Returns a pair of length of the connection ids decoded from the long header.
* Returns (Destination connid length, Source connid length)
*/
std::pair<uint8_t, uint8_t> decodeConnectionIdLengths(uint8_t connIdSize);
/**
* Given 2 connection ids, encodes their lengths in the wire format for the Quic
* long header.
*/
uint8_t encodeConnectionIdLengths(
uint8_t destinationConnectionIdSize,
uint8_t sourceConnectionIdSize);
} // namespace quic

View File

@ -0,0 +1,55 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
// Copyright 2004-present Facebook. All rights reserved.
#include <quic/codec/QuicHeaderCodec.h>
#include <quic/QuicException.h>
#include <quic/codec/Decode.h>
namespace quic {
ParsedHeader::ParsedHeader(PacketHeader headerIn)
: header(std::move(headerIn)) {}
ParsedHeaderResult::ParsedHeaderResult(
bool isVersionNegotiationIn,
folly::Optional<ParsedHeader> parsedHeaderIn)
: isVersionNegotiation(isVersionNegotiationIn),
parsedHeader(std::move(parsedHeaderIn)) {
CHECK(isVersionNegotiation || parsedHeader);
}
folly::Expected<ParsedHeaderResult, TransportErrorCode> parseHeader(
const folly::IOBuf& data) {
folly::io::Cursor cursor(&data);
if (!cursor.canAdvance(sizeof(uint8_t))) {
return folly::makeUnexpected(TransportErrorCode::FRAME_ENCODING_ERROR);
}
uint8_t initialByte = cursor.readBE<uint8_t>();
if (getHeaderForm(initialByte) == HeaderForm::Long) {
return parseLongHeader(initialByte, cursor)
.then([](ParsedLongHeaderResult&& parsedLongHeaderResult) {
if (parsedLongHeaderResult.isVersionNegotiation) {
return ParsedHeaderResult(true, folly::none);
}
// We compensate for the type byte length by adding it back.
DCHECK(parsedLongHeaderResult.parsedLongHeader);
ParsedHeader parsedHeader(PacketHeader(
std::move(parsedLongHeaderResult.parsedLongHeader->header)));
return ParsedHeaderResult(false, parsedHeader);
});
} else {
return parseShortHeader(initialByte, cursor).then([](ShortHeader&& header) {
return ParsedHeaderResult(
false, ParsedHeader(PacketHeader(std::move(header))));
});
}
}
} // namespace quic

View File

@ -0,0 +1,33 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#pragma once
#include <folly/Optional.h>
#include <quic/codec/PacketNumber.h>
#include <quic/codec/Types.h>
namespace quic {
struct ParsedHeader {
PacketHeader header;
explicit ParsedHeader(PacketHeader headerIn);
};
struct ParsedHeaderResult {
bool isVersionNegotiation;
folly::Optional<ParsedHeader> parsedHeader;
ParsedHeaderResult(
bool isVersionNegotiationIn,
folly::Optional<ParsedHeader> parsedHeaderIn);
};
folly::Expected<ParsedHeaderResult, TransportErrorCode> parseHeader(
const folly::IOBuf& data);
} // namespace quic

131
quic/codec/QuicInteger.cpp Normal file
View File

@ -0,0 +1,131 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#include <quic/codec/QuicInteger.h>
#include <folly/Conv.h>
namespace quic {
folly::Expected<size_t, TransportErrorCode> getQuicIntegerSize(uint64_t value) {
if (value <= kOneByteLimit) {
return 1;
} else if (value <= kTwoByteLimit) {
return 2;
} else if (value <= kFourByteLimit) {
return 4;
} else if (value <= kEightByteLimit) {
return 8;
}
return folly::makeUnexpected(TransportErrorCode::INTERNAL_ERROR);
}
uint8_t decodeQuicIntegerLength(uint8_t firstByte) {
return (1 << ((firstByte >> 6) & 0x03));
}
folly::Expected<size_t, TransportErrorCode> encodeQuicInteger(
uint64_t value,
folly::io::QueueAppender& appender) {
if (value <= kOneByteLimit) {
uint8_t modified = static_cast<uint8_t>(value);
appender.writeBE(modified);
return sizeof(modified);
} else if (value <= kTwoByteLimit) {
uint16_t reduced = static_cast<uint16_t>(value);
uint16_t modified = reduced | 0x4000;
appender.writeBE(modified);
return sizeof(modified);
} else if (value <= kFourByteLimit) {
uint32_t reduced = static_cast<uint32_t>(value);
uint32_t modified = reduced | 0x80000000;
appender.writeBE(modified);
return sizeof(modified);
} else if (value <= kEightByteLimit) {
uint64_t modified = value | 0xC000000000000000;
appender.writeBE(modified);
return sizeof(modified);
}
return folly::makeUnexpected(TransportErrorCode::INTERNAL_ERROR);
}
folly::Optional<std::pair<uint64_t, size_t>> decodeQuicInteger(
folly::io::Cursor& cursor,
uint64_t atMost) {
size_t numBytes = 0;
size_t advanceLen = 0;
uint64_t result = 0;
if (atMost < 1 || !cursor.canAdvance(1)) {
VLOG(10) << "Not enough bytes to decode integer, cursor len="
<< cursor.totalLength();
return folly::none;
}
const uint8_t firstByte = *cursor.peekBytes().data();
const uint8_t varintType = (firstByte >> 6) & 0x03;
const uint8_t unmaskedFirstByte = firstByte & 0x3F;
uint8_t* resultPtr = reinterpret_cast<uint8_t*>(&result);
switch (varintType) {
case 0:
// short circuit for 1 byte.
cursor.skip(1);
return std::make_pair((uint64_t)unmaskedFirstByte, (size_t)1);
case 1:
advanceLen = 6;
numBytes = 1;
break;
case 2:
advanceLen = 4;
numBytes = 3;
break;
case 3:
numBytes = 7;
break;
}
if (atMost < (numBytes + 1) || !cursor.canAdvance(numBytes + 1)) {
VLOG(10) << "Could not decode integer numBytes="
<< static_cast<int>(numBytes + 1) << " firstByte=" << std::hex
<< static_cast<int>(firstByte);
return folly::none;
}
cursor.skip(1);
memcpy(resultPtr + advanceLen, &unmaskedFirstByte, 1);
cursor.pull(resultPtr + advanceLen + 1, numBytes);
// make the data dependency on resultPtr explicit to avoid strict
// aliasing issues.
return std::make_pair(
folly::Endian::big(*reinterpret_cast<uint64_t*>(resultPtr)),
numBytes + 1);
}
QuicInteger::QuicInteger(uint64_t value) : value_(value) {}
size_t QuicInteger::getSize() const {
auto size = getQuicIntegerSize(value_);
if (size.hasError()) {
LOG(ERROR) << "Value too large value=" << value_;
throw QuicTransportException(
folly::to<std::string>("Value too large ", value_), size.error());
}
return size.value();
}
size_t QuicInteger::encode(folly::io::QueueAppender& appender) const {
auto size = encodeQuicInteger(value_, appender);
if (size.hasError()) {
LOG(ERROR) << "Value too large value=" << value_;
throw QuicTransportException(
folly::to<std::string>("Value too large ", value_), size.error());
}
return size.value();
}
uint64_t QuicInteger::getValue() const {
return value_;
}
} // namespace quic

80
quic/codec/QuicInteger.h Normal file
View File

@ -0,0 +1,80 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#pragma once
#include <folly/Optional.h>
#include <folly/io/Cursor.h>
#include <folly/lang/Bits.h>
#include <quic/QuicException.h>
namespace quic {
constexpr uint64_t kOneByteLimit = 0x3F;
constexpr uint64_t kTwoByteLimit = 0x3FFF;
constexpr uint64_t kFourByteLimit = 0x3FFFFFFF;
constexpr uint64_t kEightByteLimit = 0x3FFFFFFFFFFFFFFF;
/**
* Encodes the integer and writes it out to appender. Returns the number of
* bytes written, or an error if value is too large to be represented with the
* variable length encoding.
*/
folly::Expected<size_t, TransportErrorCode> encodeQuicInteger(
uint64_t value,
folly::io::QueueAppender& appender);
/**
* Reads an integer out of the cursor and returns a pair with the integer and
* the numbers of bytes read, or folly::none if there are not enough bytes to
* read the int. It only advances the cursor in case of success.
*/
folly::Optional<std::pair<uint64_t, size_t>> decodeQuicInteger(
folly::io::Cursor& cursor,
uint64_t atMost = std::numeric_limits<uint64_t>::max());
/**
* Returns the length of a quic integer given the first byte
*/
uint8_t decodeQuicIntegerLength(uint8_t firstByte);
/**
* Returns number of bytes needed to encode value as a QUIC integer, or an error
* if value is too large to be represented with the variable
* length encoding
*/
folly::Expected<size_t, TransportErrorCode> getQuicIntegerSize(uint64_t value);
/**
* A better API for dealing with QUIC integers for encoding.
*/
class QuicInteger {
public:
explicit QuicInteger(uint64_t value);
/**
* Encodes a QUIC integer to the appender.
*/
size_t encode(folly::io::QueueAppender& appender) const;
/**
* Returns the number of bytes needed to represent the QUIC integer in
* its encoded form.
**/
size_t getSize() const;
/**
* Returns the real value of the QUIC integer that it was instantiated with.
* This should normally never be used.
*/
uint64_t getValue() const;
private:
uint64_t value_;
};
} // namespace quic

View File

@ -0,0 +1,337 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#include <quic/codec/QuicPacketBuilder.h>
#include <folly/Random.h>
#include <quic/codec/PacketNumber.h>
namespace {
// maximum length of packet length.
constexpr auto kMaxPacketLenSize = sizeof(uint16_t);
} // namespace
namespace quic {
PacketNumEncodingResult encodeLongHeaderHelper(
const LongHeader& longHeader,
folly::io::QueueAppender& appender,
uint32_t& spaceCounter,
PacketNum largestAckedPacketNum);
PacketNumEncodingResult encodeLongHeaderHelper(
const LongHeader& longHeader,
folly::io::QueueAppender& appender,
uint32_t& spaceCounter,
PacketNum largestAckedPacketNum) {
uint8_t initialByte = kHeaderFormMask | LongHeader::kFixedBitMask |
(static_cast<uint8_t>(longHeader.getHeaderType())
<< LongHeader::kTypeShift);
PacketNumEncodingResult encodedPacketNum = encodePacketNumber(
longHeader.getPacketSequenceNum(), largestAckedPacketNum);
initialByte &= ~LongHeader::kReservedBitsMask;
initialByte |= (encodedPacketNum.length - 1);
if (longHeader.getHeaderType() == LongHeader::Types::Retry) {
initialByte &= 0xF0;
auto odcidSize = longHeader.getOriginalDstConnId()->size();
initialByte |= (odcidSize == 0 ? 0 : odcidSize - 3);
}
appender.writeBE<uint8_t>(initialByte);
bool isInitial = longHeader.getHeaderType() == LongHeader::Types::Initial;
uint64_t tokenHeaderLength = 0;
auto token = longHeader.getToken();
if (isInitial) {
uint64_t tokenLength = token ? token->coalesce().size() : 0;
QuicInteger tokenLengthInt(tokenLength);
tokenHeaderLength = tokenLengthInt.getSize() + tokenLength;
}
auto longHeaderSize = sizeof(uint8_t) /* initialByte */ +
sizeof(QuicVersionType) + sizeof(uint8_t) /* DCIL | SCIL */ +
longHeader.getSourceConnId().size() +
longHeader.getDestinationConnId().size() + tokenHeaderLength +
kMaxPacketLenSize + encodedPacketNum.length;
if (spaceCounter < longHeaderSize) {
spaceCounter = 0;
} else {
spaceCounter -= longHeaderSize;
}
appender.writeBE<uint32_t>(folly::to<uint32_t>(longHeader.getVersion()));
auto connidSize = encodeConnectionIdLengths(
longHeader.getDestinationConnId().size(),
longHeader.getSourceConnId().size());
appender.writeBE<uint8_t>(connidSize);
appender.push(
longHeader.getDestinationConnId().data(),
longHeader.getDestinationConnId().size());
appender.push(
longHeader.getSourceConnId().data(), longHeader.getSourceConnId().size());
if (isInitial) {
uint64_t tokenLength = token ? token->coalesce().size() : 0;
QuicInteger tokenLengthInt(tokenLength);
tokenLengthInt.encode(appender);
if (tokenLength > 0) {
appender.push(token->coalesce());
}
}
if (longHeader.getHeaderType() == LongHeader::Types::Retry) {
auto& originalDstConnId = longHeader.getOriginalDstConnId();
appender.push(originalDstConnId->data(), originalDstConnId->size());
// Write the retry token
appender.insert(*token);
}
// defer write of the packet num and length till payload has been computed
return encodedPacketNum;
}
RegularQuicPacketBuilder::RegularQuicPacketBuilder(
uint32_t remainingBytes,
PacketHeader header,
PacketNum largestAckedPacketNum)
: remainingBytes_(remainingBytes),
packet_(std::move(header)),
headerAppender_(&header_, kLongHeaderHeaderSize),
bodyAppender_(&outputQueue_, kAppenderGrowthSize) {
writeHeaderBytes(largestAckedPacketNum);
}
uint32_t RegularQuicPacketBuilder::getHeaderBytes() const {
return folly::to<uint32_t>(header_.chainLength());
}
uint32_t RegularQuicPacketBuilder::remainingSpaceInPkt() const {
return remainingBytes_;
}
void RegularQuicPacketBuilder::writeBE(uint8_t data) {
bodyAppender_.writeBE<uint8_t>(data);
remainingBytes_ -= sizeof(data);
}
void RegularQuicPacketBuilder::writeBE(uint16_t data) {
bodyAppender_.writeBE<uint16_t>(data);
remainingBytes_ -= sizeof(data);
}
void RegularQuicPacketBuilder::writeBE(uint64_t data) {
bodyAppender_.writeBE<uint64_t>(data);
remainingBytes_ -= sizeof(data);
}
void RegularQuicPacketBuilder::write(const QuicInteger& quicInteger) {
remainingBytes_ -= quicInteger.encode(bodyAppender_);
}
void RegularQuicPacketBuilder::appendBytes(
PacketNum value,
uint8_t byteNumber) {
appendBytes(bodyAppender_, value, byteNumber);
}
void RegularQuicPacketBuilder::appendBytes(
folly::io::QueueAppender& appender,
PacketNum value,
uint8_t byteNumber) {
appender.ensure(byteNumber);
auto bigValue = folly::Endian::big(value);
appender.push(
(uint8_t*)&bigValue + sizeof(bigValue) - byteNumber, byteNumber);
remainingBytes_ -= byteNumber;
}
void RegularQuicPacketBuilder::insert(std::unique_ptr<folly::IOBuf> buf) {
remainingBytes_ -= buf->computeChainDataLength();
bodyAppender_.insert(std::move(buf));
}
void RegularQuicPacketBuilder::appendFrame(QuicWriteFrame frame) {
quicFrames_.push_back(std::move(frame));
}
RegularQuicPacketBuilder::Packet RegularQuicPacketBuilder::buildPacket() && {
// at this point everything should been set in the packet_
bool isLongHeader = folly::variant_match(
packet_.header,
[](const LongHeader&) { return true; },
[](const ShortHeader&) { return false; });
size_t minBodySize = kMaxPacketNumEncodingSize -
packetNumberEncoding_->length + sizeof(Sample);
while (outputQueue_.chainLength() + cipherOverhead_ < minBodySize &&
!quicFrames_.empty() && remainingBytes_ > kMaxPacketLenSize) {
quicFrames_.push_back(PaddingFrame());
QuicInteger paddingType(static_cast<uint8_t>(FrameType::PADDING));
write(paddingType);
}
packet_.frames = std::move(quicFrames_);
if (isLongHeader &&
boost::get<LongHeader>(packet_.header).getHeaderType() !=
LongHeader::Types::Retry) {
QuicInteger pktLen(
packetNumberEncoding_->length + outputQueue_.chainLength() +
cipherOverhead_);
pktLen.encode(headerAppender_);
appendBytes(
headerAppender_,
packetNumberEncoding_->result,
packetNumberEncoding_->length);
}
return Packet(std::move(packet_), header_.move(), outputQueue_.move());
}
void RegularQuicPacketBuilder::writeHeaderBytes(
PacketNum largestAckedPacketNum) {
if (packet_.header.type() == typeid(LongHeader)) {
LongHeader& longHeader = boost::get<LongHeader>(packet_.header);
encodeLongHeader(longHeader, largestAckedPacketNum);
} else {
ShortHeader& shortHeader = boost::get<ShortHeader>(packet_.header);
encodeShortHeader(shortHeader, largestAckedPacketNum);
}
}
void RegularQuicPacketBuilder::encodeLongHeader(
const LongHeader& longHeader,
PacketNum largestAckedPacketNum) {
packetNumberEncoding_ = encodeLongHeaderHelper(
longHeader, headerAppender_, remainingBytes_, largestAckedPacketNum);
}
void RegularQuicPacketBuilder::encodeShortHeader(
const ShortHeader& shortHeader,
PacketNum largestAckedPacketNum) {
packetNumberEncoding_ = encodePacketNumber(
shortHeader.getPacketSequenceNum(), largestAckedPacketNum);
if (remainingBytes_ < 1U + packetNumberEncoding_->length +
shortHeader.getConnectionId().size()) {
remainingBytes_ = 0;
return;
}
folly::io::QueueAppender appender(&header_, kAppenderGrowthSize);
uint8_t initialByte =
ShortHeader::kFixedBitMask | (packetNumberEncoding_->length - 1);
initialByte &= ~ShortHeader::kReservedBitsMask;
if (shortHeader.getProtectionType() == ProtectionType::KeyPhaseOne) {
initialByte |= ShortHeader::kKeyPhaseMask;
}
appender.writeBE<uint8_t>(initialByte);
--remainingBytes_;
appender.push(
shortHeader.getConnectionId().data(),
shortHeader.getConnectionId().size());
remainingBytes_ -= shortHeader.getConnectionId().size();
appendBytes(
appender, packetNumberEncoding_->result, packetNumberEncoding_->length);
}
void RegularQuicPacketBuilder::push(const uint8_t* data, size_t len) {
bodyAppender_.push(data, len);
remainingBytes_ -= len;
}
bool RegularQuicPacketBuilder::canBuildPacket() const noexcept {
return remainingBytes_ != 0;
}
const PacketHeader& RegularQuicPacketBuilder::getPacketHeader() const {
return packet_.header;
}
void RegularQuicPacketBuilder::setCipherOverhead(uint8_t overhead) noexcept {
cipherOverhead_ = overhead;
}
StatelessResetPacketBuilder::StatelessResetPacketBuilder(
uint16_t maxPacketSize,
const StatelessResetToken& resetToken) {
folly::io::QueueAppender appender(&outputQueue_, kAppenderGrowthSize);
// TODO: randomize the length
uint16_t randomOctetLength = maxPacketSize - resetToken.size() - 1;
uint8_t initialByte = ShortHeader::kFixedBitMask;
appender.writeBE<uint8_t>(initialByte);
auto randomOctets = folly::IOBuf::create(randomOctetLength);
folly::Random::secureRandom(randomOctets->writableData(), randomOctetLength);
appender.pushAtMost(randomOctets->data(), randomOctetLength);
appender.push(resetToken.data(), resetToken.size());
}
Buf StatelessResetPacketBuilder::buildPacket() && {
return outputQueue_.move();
}
VersionNegotiationPacketBuilder::VersionNegotiationPacketBuilder(
ConnectionId sourceConnectionId,
ConnectionId destinationConnectionId,
const std::vector<QuicVersion>& versions)
: remainingBytes_(kDefaultUDPSendPacketLen),
packet_(
generateRandomPacketType(),
sourceConnectionId,
destinationConnectionId),
appender_(&outputQueue_, kAppenderGrowthSize) {
writeVersionNegotiationPacket(versions);
}
uint32_t VersionNegotiationPacketBuilder::remainingSpaceInPkt() {
return remainingBytes_;
}
std::pair<VersionNegotiationPacket, Buf>
VersionNegotiationPacketBuilder::buildPacket() && {
return std::make_pair<VersionNegotiationPacket, Buf>(
std::move(packet_), outputQueue_.move());
}
void VersionNegotiationPacketBuilder::writeVersionNegotiationPacket(
const std::vector<QuicVersion>& versions) {
// Write header
appender_.writeBE<decltype(packet_.packetType)>(packet_.packetType);
remainingBytes_ -= sizeof(decltype(packet_.packetType));
appender_.writeBE(
static_cast<QuicVersionType>(QuicVersion::VERSION_NEGOTIATION));
remainingBytes_ -= sizeof(QuicVersionType);
auto connidSize = encodeConnectionIdLengths(
packet_.destinationConnectionId.size(),
packet_.sourceConnectionId.size());
appender_.writeBE<uint8_t>(connidSize);
remainingBytes_ -= sizeof(uint8_t);
appender_.push(
packet_.destinationConnectionId.data(),
packet_.destinationConnectionId.size());
remainingBytes_ -= packet_.destinationConnectionId.size();
appender_.push(
packet_.sourceConnectionId.data(), packet_.sourceConnectionId.size());
remainingBytes_ -= packet_.sourceConnectionId.size();
// Write versions
for (auto version : versions) {
if (remainingBytes_ < sizeof(QuicVersionType)) {
break;
}
appender_.writeBE<QuicVersionType>(static_cast<QuicVersionType>(version));
remainingBytes_ -= sizeof(QuicVersionType);
packet_.versions.push_back(version);
}
}
uint8_t VersionNegotiationPacketBuilder::generateRandomPacketType() const {
// TODO: change this back to generating random packet type after we rollout
// draft-13. For now the 0 packet type will make sure that the version
// negotiation packet is not interpreted as a long header.
// folly::Random::secureRandom<decltype(packet_.packetType)>();
return kHeaderFormMask;
}
bool VersionNegotiationPacketBuilder::canBuildPacket() const noexcept {
return remainingBytes_ != 0;
}
} // namespace quic

View File

@ -0,0 +1,253 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
// Copyright 2004-present Facebook. All rights reserved.
#pragma once
#include <quic/codec/PacketNumber.h>
#include <quic/codec/QuicInteger.h>
#include <quic/codec/Types.h>
#include <quic/handshake/HandshakeLayer.h>
namespace quic {
// We reserve 2 bytes for packet length in the long headers
constexpr auto kReservedPacketLenSize = sizeof(uint16_t);
// We reserve 4 bytes for packet number in the long headers.
constexpr auto kReservedPacketNumSize = kMaxPacketNumEncodingSize;
// Note a full PacketNum has 64 bits, but LongHeader only uses 32 bits of them
// This is based on Draft-15
constexpr auto kLongHeaderHeaderSize = sizeof(uint8_t) /* Type bytes */ +
sizeof(QuicVersionType) /* Version */ + sizeof(uint8_t) /* DCIL + SCIL */ +
kDefaultConnectionIdSize * 2 /* 2 connection IDs */ +
kReservedPacketLenSize /* minimal size of length */ +
kReservedPacketNumSize /* packet number */;
// A possible cipher overhead. The real overhead depends on the AEAD we will
// use. But we need a ball-park value when deciding if we should schedule a
// write.
constexpr auto kCipherOverheadHeuristic = 16;
// TODO: i'm sure this isn't the optimal value:
// IOBufQueue growth byte size for in PacketBuilder:
constexpr size_t kAppenderGrowthSize = 100;
class PacketBuilderInterface {
public:
virtual ~PacketBuilderInterface() = default;
virtual uint32_t remainingSpaceInPkt() const = 0;
// Functions to write bytes to the packet
virtual void writeBE(uint8_t data) = 0;
virtual void writeBE(uint16_t data) = 0;
virtual void writeBE(uint64_t data) = 0;
virtual void write(const QuicInteger& quicInteger) = 0;
virtual void appendBytes(PacketNum value, uint8_t byteNumber) = 0;
virtual void appendBytes(
folly::io::QueueAppender& appender,
PacketNum value,
uint8_t byteNumber) = 0;
virtual void insert(std::unique_ptr<folly::IOBuf> buf) = 0;
virtual void push(const uint8_t* data, size_t len) = 0;
// Append a frame to the packet.
virtual void appendFrame(QuicWriteFrame frame) = 0;
// Returns the packet header for the current packet.
virtual const PacketHeader& getPacketHeader() const = 0;
};
class RegularQuicPacketBuilder : public PacketBuilderInterface {
public:
~RegularQuicPacketBuilder() override = default;
RegularQuicPacketBuilder(RegularQuicPacketBuilder&&) = default;
struct Packet {
RegularQuicWritePacket packet;
Buf header;
Buf body;
Packet(RegularQuicWritePacket packetIn, Buf headerIn, Buf bodyIn)
: packet(std::move(packetIn)),
header(std::move(headerIn)),
body(std::move(bodyIn)) {}
};
RegularQuicPacketBuilder(
uint32_t remainingBytes,
PacketHeader header,
PacketNum largestAckedPacketNum);
uint32_t getHeaderBytes() const;
// PacketBuilderInterface
uint32_t remainingSpaceInPkt() const override;
void writeBE(uint8_t data) override;
void writeBE(uint16_t data) override;
void writeBE(uint64_t data) override;
void write(const QuicInteger& quicInteger) override;
void appendBytes(PacketNum value, uint8_t byteNumber) override;
void appendBytes(
folly::io::QueueAppender& appender,
PacketNum value,
uint8_t byteNumber) override;
void insert(std::unique_ptr<folly::IOBuf> buf) override;
void push(const uint8_t* data, size_t len) override;
void appendFrame(QuicWriteFrame frame) override;
const PacketHeader& getPacketHeader() const override;
Packet buildPacket() &&;
/**
* Whether the packet builder is able to build a packet. This should be
* checked right after the creation of a packet builder object.
*/
bool canBuildPacket() const noexcept;
void setCipherOverhead(uint8_t overhead) noexcept;
private:
void writeHeaderBytes(PacketNum largestAckedPacketNum);
void encodeLongHeader(
const LongHeader& longHeader,
PacketNum largestAckedPacketNum);
void encodeShortHeader(
const ShortHeader& shortHeader,
PacketNum largestAckedPacketNum);
private:
uint32_t remainingBytes_;
RegularQuicWritePacket packet_;
std::vector<QuicWriteFrame> quicFrames_;
folly::IOBufQueue header_{folly::IOBufQueue::cacheChainLength()};
folly::IOBufQueue outputQueue_{folly::IOBufQueue::cacheChainLength()};
folly::io::QueueAppender headerAppender_;
folly::io::QueueAppender bodyAppender_;
uint32_t cipherOverhead_{0};
folly::Optional<PacketNumEncodingResult> packetNumberEncoding_;
};
class VersionNegotiationPacketBuilder {
public:
explicit VersionNegotiationPacketBuilder(
ConnectionId sourceConnectionId,
ConnectionId destinationConnectionId,
const std::vector<QuicVersion>& versions);
virtual ~VersionNegotiationPacketBuilder() = default;
uint32_t remainingSpaceInPkt();
std::pair<VersionNegotiationPacket, Buf> buildPacket() &&;
/**
* Whether the packet builder is able to build a packet. This should be
* checked right after the creation of a packet builder object.
*/
bool canBuildPacket() const noexcept;
private:
void writeVersionNegotiationPacket(const std::vector<QuicVersion>& versions);
uint8_t generateRandomPacketType() const;
private:
uint32_t remainingBytes_;
VersionNegotiationPacket packet_;
folly::IOBufQueue outputQueue_{folly::IOBufQueue::cacheChainLength()};
folly::io::QueueAppender appender_;
};
class StatelessResetPacketBuilder {
public:
StatelessResetPacketBuilder(
uint16_t maxPacketSize,
const StatelessResetToken& resetToken);
Buf buildPacket() &&;
private:
folly::IOBufQueue outputQueue_{folly::IOBufQueue::cacheChainLength()};
};
/**
* A PacketBuilder that wraps in another PacketBuilder that may have a different
* writableBytes limit. The minimum between the limit will be used to limit the
* packet it can build.
*/
class PacketBuilderWrapper : public PacketBuilderInterface {
public:
~PacketBuilderWrapper() override = default;
PacketBuilderWrapper(
PacketBuilderInterface& builderIn,
uint32_t writableBytes)
: builder(builderIn),
diff(
writableBytes > builder.remainingSpaceInPkt()
? 0
: builder.remainingSpaceInPkt() - writableBytes) {}
uint32_t remainingSpaceInPkt() const override {
return builder.remainingSpaceInPkt() > diff
? builder.remainingSpaceInPkt() - diff
: 0;
}
void write(const QuicInteger& quicInteger) override {
builder.write(quicInteger);
}
void writeBE(uint8_t value) override {
builder.writeBE(value);
}
void writeBE(uint16_t value) override {
builder.writeBE(value);
}
void writeBE(uint64_t value) override {
builder.writeBE(value);
}
void appendBytes(PacketNum value, uint8_t byteNumber) override {
builder.appendBytes(value, byteNumber);
}
void appendBytes(
folly::io::QueueAppender& appender,
PacketNum value,
uint8_t byteNumber) override {
builder.appendBytes(appender, value, byteNumber);
}
void insert(std::unique_ptr<folly::IOBuf> buf) override {
builder.insert(std::move(buf));
}
void appendFrame(QuicWriteFrame frame) override {
builder.appendFrame(std::move(frame));
}
void push(const uint8_t* data, size_t len) override {
builder.push(data, len);
}
const PacketHeader& getPacketHeader() const override {
return builder.getPacketHeader();
}
private:
PacketBuilderInterface& builder;
uint32_t diff;
};
} // namespace quic

View File

@ -0,0 +1,240 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
// Copyright 2004-present Facebook. All rights reserved.
#include <quic/codec/QuicPacketRebuilder.h>
#include <quic/codec/QuicWriteCodec.h>
#include <quic/flowcontrol/QuicFlowController.h>
#include <quic/state/QuicStreamFunctions.h>
#include <quic/state/SimpleFrameFunctions.h>
namespace quic {
PacketRebuilder::PacketRebuilder(
RegularQuicPacketBuilder& regularBuilder,
QuicConnectionStateBase& conn)
: builder_(regularBuilder), conn_(conn) {}
uint64_t PacketRebuilder::getHeaderBytes() const {
return builder_.getHeaderBytes();
}
PacketEvent PacketRebuilder::cloneOutstandingPacket(OutstandingPacket& packet) {
// Either the packet has never been cloned before, or it's associatedEvent is
// still in the outstandingPacketEvents set.
DCHECK(
!packet.associatedEvent ||
conn_.outstandingPacketEvents.count(*packet.associatedEvent));
if (!packet.associatedEvent) {
auto packetNum = folly::variant_match(
packet.packet.header, [](auto& h) { return h.getPacketSequenceNum(); });
DCHECK(!conn_.outstandingPacketEvents.count(packetNum));
packet.associatedEvent = packetNum;
conn_.outstandingPacketEvents.insert(packetNum);
++conn_.outstandingClonedPacketsCount;
}
return *packet.associatedEvent;
}
folly::Optional<PacketEvent> PacketRebuilder::rebuildFromPacket(
OutstandingPacket& packet) {
// TODO: if PMTU changes between the transmission of the original packet and
// now, then we cannot clone everything in the packet.
// TODO: make sure this cannot be called on handshake packets.
bool writeSuccess = false;
bool windowUpdateWritten = false;
bool shouldWriteWindowUpdate = false;
bool notPureAck = false;
for (auto iter = packet.packet.frames.cbegin();
iter != packet.packet.frames.cend();
iter++) {
const QuicWriteFrame& frame = *iter;
writeSuccess = folly::variant_match(
frame,
[&](const WriteAckFrame& ackFrame) {
uint64_t ackDelayExponent = folly::variant_match(
builder_.getPacketHeader(),
[](const LongHeader&) { return kDefaultAckDelayExponent; },
[&](const auto&) {
return conn_.transportSettings.ackDelayExponent;
});
AckFrameMetaData meta(
ackFrame.ackBlocks, ackFrame.ackDelay, ackDelayExponent);
auto ackWriteResult = writeAckFrame(meta, builder_);
return ackWriteResult.hasValue();
},
[&](const WriteStreamFrame& streamFrame) {
auto stream = conn_.streamManager->getStream(streamFrame.streamId);
if (stream && retransmittable(*stream)) {
StreamFrameMetaData meta(
streamFrame.streamId,
streamFrame.offset,
streamFrame.fin,
cloneRetransmissionBuffer(streamFrame, stream),
true);
auto streamWriteResult = writeStreamFrame(meta, builder_);
bool ret = streamWriteResult.hasValue() &&
streamWriteResult->bytesWritten == streamFrame.len &&
streamWriteResult->finWritten == streamFrame.fin;
notPureAck |= ret;
return ret;
}
// If a stream is already Closed, or HalfClosedLocal, we should not
// clone and resend this stream data. But should we abort the cloning
// of this packet and move on to the next packet? I'm gonna err on the
// aggressive side for now and call it success.
return true;
},
[&](const WriteCryptoFrame& cryptoFrame) {
// initialStream and handshakeStream can only be in handshake packet,
// so they are not clonable
CHECK(!packet.isHandshake);
folly::variant_match(packet.packet.header, [](const auto& header) {
// key update not supported
CHECK(header.getProtectionType() == ProtectionType::KeyPhaseZero);
});
auto& stream = conn_.cryptoState->oneRttStream;
auto buf = cloneCryptoRetransmissionBuffer(cryptoFrame, stream);
// No crypto data found to be cloned, just skip
if (!buf) {
return true;
}
auto cryptoWriteResult =
writeCryptoFrame(cryptoFrame.offset, std::move(buf), builder_);
bool ret = cryptoWriteResult.hasValue() &&
cryptoWriteResult->offset == cryptoFrame.offset &&
cryptoWriteResult->len == cryptoFrame.len;
notPureAck |= ret;
return ret;
},
[&](const MaxDataFrame&) {
shouldWriteWindowUpdate = true;
auto ret = 0 != writeFrame(generateMaxDataFrame(conn_), builder_);
windowUpdateWritten |= ret;
notPureAck |= ret;
return true;
},
[&](const MaxStreamDataFrame& maxStreamDataFrame) {
auto stream =
conn_.streamManager->getStream(maxStreamDataFrame.streamId);
if (!stream || !stream->shouldSendFlowControl()) {
return true;
}
shouldWriteWindowUpdate = true;
auto ret =
0 != writeFrame(generateMaxStreamDataFrame(*stream), builder_);
windowUpdateWritten |= ret;
notPureAck |= ret;
return true;
},
[&](const PaddingFrame& paddingFrame) {
return writeFrame(paddingFrame, builder_) != 0;
},
[&](const QuicSimpleFrame& simpleFrame) {
auto updatedSimpleFrame =
updateSimpleFrameOnPacketClone(conn_, simpleFrame);
if (!updatedSimpleFrame) {
return true;
}
bool ret =
writeSimpleFrame(std::move(*updatedSimpleFrame), builder_) != 0;
notPureAck |= ret;
return ret;
},
[&](const auto& otherFrame) {
bool ret = writeFrame(otherFrame, builder_) != 0;
notPureAck |= ret;
return ret;
});
if (!writeSuccess) {
return folly::none;
}
}
// We shouldn't clone if:
// (1) we only end up cloning acks and paddings.
// (2) we should write window update, but didn't, and wrote nothing else.
if (!notPureAck ||
(shouldWriteWindowUpdate && !windowUpdateWritten && !writeSuccess)) {
return folly::none;
}
return cloneOutstandingPacket(packet);
}
Buf PacketRebuilder::cloneCryptoRetransmissionBuffer(
const WriteCryptoFrame& frame,
const QuicCryptoStream& stream) {
/**
* Crypto's StreamBuffer is removed from retransmissionBuffer in 2 cases.
* 1: Packet containing the buffer gets acked.
* 2: Packet containing the buffer is marked loss.
* They have to be covered by making sure we do not clone an already acked or
* lost packet.
*/
DCHECK(frame.len) << "WriteCryptoFrame cloning: frame is empty. " << conn_;
auto iter = std::lower_bound(
stream.retransmissionBuffer.begin(),
stream.retransmissionBuffer.end(),
frame.offset,
[](const auto& buffer, const auto& targetOffset) {
return buffer.offset < targetOffset;
});
// If the crypto stream is canceled somehow, just skip cloning this frame
if (iter == stream.retransmissionBuffer.end()) {
return nullptr;
}
DCHECK(iter->offset == frame.offset)
<< "WriteCryptoFrame cloning: offset mismatch. " << conn_;
DCHECK(iter->data.chainLength() == frame.len)
<< "WriteCryptoFrame cloning: Len mismatch. " << conn_;
return iter->data.front()->clone();
}
Buf PacketRebuilder::cloneRetransmissionBuffer(
const WriteStreamFrame& frame,
const QuicStreamState* stream) {
/**
* StreamBuffer is removed from retransmissionBuffer in 3 cases.
* 1: After send or receive RST.
* 2: Packet containing the buffer gets acked.
* 3: Packet containing the buffer is marked loss.
* Checking retransmittable() should cover first case. The latter two cases
* have to be covered by making sure we do not clone an already acked or lost
* packet.
*/
DCHECK(frame.len || frame.fin)
<< "WriteStreamFrame cloning: frame is empty and doesn't have FIN set. "
<< conn_;
DCHECK(stream);
DCHECK(retransmittable(*stream));
auto iter = std::lower_bound(
stream->retransmissionBuffer.begin(),
stream->retransmissionBuffer.end(),
frame.offset,
[](const auto& buffer, const auto& targetOffset) {
return buffer.offset < targetOffset;
});
DCHECK(iter != stream->retransmissionBuffer.end())
<< "WriteStreamFrame cloning: cannot find it in the retx buffer. "
<< conn_;
DCHECK(iter->offset == frame.offset)
<< "WriteStreamFrame cloning: offset mismatch. " << conn_;
DCHECK(iter->data.chainLength() == frame.len)
<< "WriteStreamFrame cloning: Len mismatch. " << conn_;
DCHECK(iter->eof == frame.fin)
<< "WriteStreamFrame cloning: fin mismatch. " << conn_;
DCHECK(!frame.len || !iter->data.empty())
<< "WriteStreamFrame cloning: frame is not empty but StreamBuffer has "
<< "empty data. " << conn_;
return (frame.len ? iter->data.front()->clone() : nullptr);
}
} // namespace quic

View File

@ -0,0 +1,65 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
// Copyright 2004-present Facebook. All rights reserved.
#pragma once
#include <quic/codec/QuicPacketBuilder.h>
#include <quic/state/StateData.h>
namespace quic {
/**
* A PacketRebuilder is a packet builder that takes in a list of frames, and
* pass them onto a wrapped builder to rebuild the packet. Note that you still
* buildPacket() from the wrapped in builder.
* TODO: The cloning builder only clones stream data that has not already been
* resset or closed. It is possible that a packet may have contained data for
* only closed streams. In that case we would only write out the header.
* This is a waste, so we should do something about this in the future.
*/
class PacketRebuilder {
public:
PacketRebuilder(
RegularQuicPacketBuilder& regularBuilder,
QuicConnectionStateBase& conn);
folly::Optional<PacketEvent> rebuildFromPacket(OutstandingPacket& packet);
// TODO: Same as passing cipherOverhead into the CloningScheduler, this really
// is a sad way to solve the writableBytes problem.
uint64_t getHeaderBytes() const;
private:
/**
* A helper function that takes a OutstandingPacket that's not processed, and
* return its associatedEvent. If this packet has never been cloned, then
* create the associatedEvent and add it into outstandingPacketEvents first.
*/
PacketEvent cloneOutstandingPacket(OutstandingPacket& packet);
bool retransmittable(const QuicStreamState& stream) const {
return matchesStates<
StreamStateData,
StreamStates::Open,
StreamStates::HalfClosedRemote>(stream.state);
}
Buf cloneCryptoRetransmissionBuffer(
const WriteCryptoFrame& frame,
const QuicCryptoStream& stream);
Buf cloneRetransmissionBuffer(
const WriteStreamFrame& frame,
const QuicStreamState* stream);
private:
RegularQuicPacketBuilder& builder_;
QuicConnectionStateBase& conn_;
};
} // namespace quic

View File

@ -0,0 +1,454 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
// Copyright 2004-present Facebook. All rights reserved.
// override-include-guard
#include <quic/codec/QuicReadCodec.h>
#include <folly/io/Cursor.h>
#include <quic/codec/Decode.h>
#include <quic/codec/PacketNumber.h>
namespace {
quic::ConnectionId zeroConnId() {
std::vector<uint8_t> zeroData(quic::kDefaultConnectionIdSize, 0);
return quic::ConnectionId(zeroData);
}
} // namespace
namespace quic {
QuicReadCodec::QuicReadCodec(QuicNodeType nodeType) : nodeType_(nodeType) {}
CodecResult QuicReadCodec::parseLongHeaderPacket(
folly::IOBufQueue& queue,
const AckStates& ackStates) {
folly::io::Cursor cursor(queue.front());
auto initialByte = cursor.readBE<uint8_t>();
auto longHeaderInvariant = parseLongHeaderInvariant(initialByte, cursor);
if (!longHeaderInvariant) {
VLOG(4) << "Dropping packet, failed to parse invariant " << connIdToHex();
// We've failed to parse the long header, so we have no idea where this
// packet ends. Clear the queue since no other data in this packet is
// parse-able.
queue.clear();
return folly::none;
}
if (longHeaderInvariant->invariant.version ==
QuicVersion::VERSION_NEGOTIATION) {
// Couldn't parse the packet as a regular packet, try parsing it as a
// version negotiation.
auto versionNegotiation =
decodeVersionNegotiation(*longHeaderInvariant, cursor);
if (versionNegotiation) {
return std::move(*versionNegotiation);
} else {
VLOG(4) << "Dropping version negotiation packet " << connIdToHex();
}
// Version negotiation is not allowed to be coalesced with any other packet.
queue.clear();
return CodecResult(folly::none);
}
auto type = parseLongHeaderType(initialByte);
auto parsedLongHeader =
parseLongHeaderVariants(type, std::move(*longHeaderInvariant), cursor);
if (!parsedLongHeader) {
VLOG(4) << "Dropping due to failed to parse header " << connIdToHex();
// We've failed to parse the long header, so we have no idea where this
// packet ends. Clear the queue since no other data in this packet is
// parse-able.
queue.clear();
return CodecResult(folly::none);
}
// As soon as we have parsed out the long header we can split off any
// coalesced packets. We do this early since the spec mandates that decryption
// failure must not stop the processing of subsequent coalesced packets.
auto longHeader = std::move(parsedLongHeader->header);
if (type == LongHeader::Types::Retry) {
return RegularQuicPacket(std::move(longHeader));
}
uint64_t packetNumberOffset = cursor.getCurrentPosition();
size_t currentPacketLen =
packetNumberOffset + parsedLongHeader->packetLength.packetLength;
if (queue.chainLength() < currentPacketLen) {
// Packet appears truncated, there's no parse-able data left.
queue.clear();
return CodecResult(folly::none);
}
auto currentPacketData = queue.split(currentPacketLen);
cursor.reset(currentPacketData.get());
cursor.skip(packetNumberOffset);
// Sample starts after the max packet number size. This ensures that we
// have enough bytes to skip before we can start reading the sample.
if (!cursor.canAdvance(kMaxPacketNumEncodingSize)) {
VLOG(4) << "Dropping packet, not enough for packet number "
<< connIdToHex();
// Packet appears truncated, there's no parse-able data left.
queue.clear();
return CodecResult(folly::none);
}
cursor.skip(kMaxPacketNumEncodingSize);
Sample sample;
if (!cursor.canAdvance(sample.size())) {
VLOG(4) << "Dropping packet, sample too small " << connIdToHex();
// Packet appears truncated, there's no parse-able data left.
queue.clear();
return CodecResult(folly::none);
}
cursor.pull(sample.data(), sample.size());
const PacketNumberCipher* headerCipher{nullptr};
const fizz::Aead* cipher{nullptr};
auto protectionType = longHeader.getProtectionType();
switch (protectionType) {
case ProtectionType::Initial:
if (handshakeDoneTime_) {
auto timeBetween = Clock::now() - *handshakeDoneTime_;
if (timeBetween > kTimeToRetainZeroRttKeys) {
VLOG(4) << nodeToString(nodeType_)
<< " dropping initial packet for exceeding key timeout"
<< connIdToHex();
return folly::none;
}
}
headerCipher = initialHeaderCipher_.get();
cipher = initialReadCipher_.get();
break;
case ProtectionType::Handshake:
headerCipher = handshakeHeaderCipher_.get();
cipher = handshakeReadCipher_.get();
break;
case ProtectionType::ZeroRtt:
if (handshakeDoneTime_) {
auto timeBetween = Clock::now() - *handshakeDoneTime_;
if (timeBetween > kTimeToRetainZeroRttKeys) {
VLOG(4) << nodeToString(nodeType_)
<< " dropping zero rtt packet for exceeding key timeout"
<< connIdToHex();
return folly::none;
}
}
headerCipher = zeroRttHeaderCipher_.get();
cipher = zeroRttReadCipher_.get();
break;
case ProtectionType::KeyPhaseZero:
case ProtectionType::KeyPhaseOne:
CHECK(false) << "one rtt protection type in long header";
}
if (!headerCipher || !cipher) {
// TODO: remove packet number here.
return CodecResult(
CipherUnavailable(std::move(currentPacketData), 0, protectionType));
}
// TODO: decrypt the long header.
PacketNum expectedNextPacketNum = 0;
folly::Optional<PacketNum> largestReceivedPacketNum;
switch (longHeaderTypeToProtectionType(type)) {
case ProtectionType::Initial:
largestReceivedPacketNum =
ackStates.initialAckState.largestReceivedPacketNum;
break;
case ProtectionType::Handshake:
largestReceivedPacketNum =
ackStates.handshakeAckState.largestReceivedPacketNum;
break;
case ProtectionType::ZeroRtt:
largestReceivedPacketNum =
ackStates.appDataAckState.largestReceivedPacketNum;
break;
default:
folly::assume_unreachable();
}
if (largestReceivedPacketNum) {
expectedNextPacketNum = 1 + *largestReceivedPacketNum;
}
folly::MutableByteRange initialByteRange(
currentPacketData->writableData(), 1);
folly::MutableByteRange packetNumberByteRange(
currentPacketData->writableData() + packetNumberOffset,
kMaxPacketNumEncodingSize);
headerCipher->decryptLongHeader(
folly::range(sample), initialByteRange, packetNumberByteRange);
std::pair<PacketNum, size_t> packetNum = parsePacketNumber(
initialByteRange.data()[0], packetNumberByteRange, expectedNextPacketNum);
longHeader.setPacketNumber(packetNum.first);
folly::IOBufQueue decryptQueue{folly::IOBufQueue::cacheChainLength()};
decryptQueue.append(std::move(currentPacketData));
size_t aadLen = packetNumberOffset + packetNum.second;
auto headerData = decryptQueue.split(aadLen);
// parsing verifies that packetLength >= packet number length.
auto encryptedData = decryptQueue.splitAtMost(
parsedLongHeader->packetLength.packetLength - packetNum.second);
if (!encryptedData) {
// There should normally be some integrity tag at least in the data,
// however allowing the aead to process the data even if the tag is not
// present helps with writing tests.
encryptedData = folly::IOBuf::create(0);
}
Buf decrypted;
auto decryptAttempt = cipher->tryDecrypt(
std::move(encryptedData), headerData.get(), packetNum.first);
if (!decryptAttempt) {
VLOG(4) << "Unable to decrypt packet=" << packetNum.first
<< " packetNumLen=" << parsePacketNumberLength(initialByte)
<< " protectionType=" << toString(protectionType) << " "
<< connIdToHex();
return CodecResult(folly::none);
}
decrypted = std::move(*decryptAttempt);
if (!decrypted) {
// TODO better way of handling this (tests break without this)
decrypted = folly::IOBuf::create(0);
}
folly::io::Cursor packetCursor(decrypted.get());
return decodeRegularPacket(std::move(longHeader), params_, packetCursor);
}
CodecResult QuicReadCodec::parsePacket(
folly::IOBufQueue& queue,
const AckStates& ackStates) {
if (queue.empty()) {
return CodecResult(folly::none);
}
DCHECK(!queue.front()->isChained());
folly::io::Cursor cursor(queue.front());
if (!cursor.canAdvance(sizeof(uint8_t))) {
return folly::none;
}
uint8_t initialByte = cursor.readBE<uint8_t>();
auto headerForm = getHeaderForm(initialByte);
if (headerForm == HeaderForm::Long) {
return parseLongHeaderPacket(queue, ackStates);
}
// Short header:
// TODO: support key phase one.
if (!oneRttReadCipher_ || !oneRttHeaderCipher_) {
VLOG(4) << nodeToString(nodeType_) << " cannot read key phase zero packet";
VLOG(20) << "cannot read data="
<< folly::hexlify(queue.front()->clone()->moveToFbString()) << " "
<< connIdToHex();
return CodecResult(
CipherUnavailable(queue.move(), 0, ProtectionType::KeyPhaseZero));
}
// TODO: allow other connid lengths from the state.
size_t packetNumberOffset = 1 + kDefaultConnectionIdSize;
PacketNum expectedNextPacketNum =
ackStates.appDataAckState.largestReceivedPacketNum
? (1 + *ackStates.appDataAckState.largestReceivedPacketNum)
: 0;
size_t sampleOffset = packetNumberOffset + kMaxPacketNumEncodingSize;
Sample sample;
if (queue.chainLength() < sampleOffset + sample.size()) {
VLOG(10) << "Dropping packet, too small for sample " << connIdToHex();
// There's not enough space for the short header packet, clear the queue
// to indicate there's no more parse-able data.
queue.clear();
return CodecResult(folly::none);
}
// Take it out of the queue so we can do some writing.
auto data = queue.move();
folly::MutableByteRange initialByteRange(data->writableData(), 1);
folly::MutableByteRange packetNumberByteRange(
data->writableData() + packetNumberOffset, kMaxPacketNumEncodingSize);
folly::ByteRange sampleByteRange(
data->writableData() + sampleOffset, sample.size());
oneRttHeaderCipher_->decryptShortHeader(
sampleByteRange, initialByteRange, packetNumberByteRange);
std::pair<PacketNum, size_t> packetNum = parsePacketNumber(
initialByteRange.data()[0], packetNumberByteRange, expectedNextPacketNum);
auto shortHeader = parseShortHeader(initialByteRange.data()[0], cursor);
if (!shortHeader) {
VLOG(10) << "Dropping packet, cannot parse " << connIdToHex();
return folly::none;
}
shortHeader->setPacketNumber(packetNum.first);
if (shortHeader->getProtectionType() == ProtectionType::KeyPhaseOne) {
VLOG(4) << nodeToString(nodeType_) << " cannot read key phase one packet "
<< connIdToHex();
return folly::none;
}
// Back in the queue so we can split.
// TODO: this will share the buffer. We should be able to supply an unshared
// buffer.
queue.append(std::move(data));
size_t aadLen = packetNumberOffset + packetNum.second;
auto headerData = queue.split(aadLen);
auto encryptedData = queue.move();
if (!encryptedData) {
// There should normally be some integrity tag at least in the data,
// however allowing the aead to process the data even if the tag is not
// present helps with writing tests.
encryptedData = folly::IOBuf::create(0);
}
Buf decrypted;
// TODO: small optimization we can do here: only read the token if
// decryption fails
folly::Optional<StatelessResetToken> token;
auto encryptedDataLength = encryptedData->computeChainDataLength();
if (statelessResetToken_ &&
encryptedDataLength > sizeof(StatelessResetToken)) {
token = StatelessResetToken();
// We want to avoid cloning the IOBuf which would prevent in-place
// decryption
folly::io::Cursor statelessTokenCursor(encryptedData.get());
// TODO: we could possibly use headroom or tailroom of the iobuf to avoid
// extra allocations
statelessTokenCursor.skip(
encryptedDataLength - sizeof(StatelessResetToken));
statelessTokenCursor.pull(token->data(), token->size());
}
auto decryptAttempt = oneRttReadCipher_->tryDecrypt(
std::move(encryptedData), headerData.get(), packetNum.first);
if (!decryptAttempt) {
// Can't return the data now, already consumed it to try decrypting it.
if (token) {
return StatelessReset(*token);
}
auto protectionType = shortHeader->getProtectionType();
VLOG(10) << "Unable to decrypt packet=" << packetNum.first
<< " protectionType=" << (int)protectionType << " "
<< connIdToHex();
return CodecResult(folly::none);
}
decrypted = std::move(*decryptAttempt);
if (!decrypted) {
// TODO better way of handling this (tests break without this)
decrypted = folly::IOBuf::create(0);
}
folly::io::Cursor packetCursor(decrypted.get());
return decodeRegularPacket(std::move(*shortHeader), params_, packetCursor);
}
const fizz::Aead* QuicReadCodec::getOneRttReadCipher() const {
return oneRttReadCipher_.get();
}
const fizz::Aead* QuicReadCodec::getZeroRttReadCipher() const {
return zeroRttReadCipher_.get();
}
const fizz::Aead* QuicReadCodec::getHandshakeReadCipher() const {
return handshakeReadCipher_.get();
}
const folly::Optional<StatelessResetToken>&
QuicReadCodec::getStatelessResetToken() const {
return statelessResetToken_;
}
void QuicReadCodec::setInitialReadCipher(
std::unique_ptr<fizz::Aead> initialReadCipher) {
initialReadCipher_ = std::move(initialReadCipher);
}
void QuicReadCodec::setOneRttReadCipher(
std::unique_ptr<fizz::Aead> oneRttReadCipher) {
oneRttReadCipher_ = std::move(oneRttReadCipher);
}
void QuicReadCodec::setZeroRttReadCipher(
std::unique_ptr<fizz::Aead> zeroRttReadCipher) {
if (nodeType_ == QuicNodeType::Client) {
throw QuicTransportException(
"Invalid cipher", TransportErrorCode::INTERNAL_ERROR);
}
zeroRttReadCipher_ = std::move(zeroRttReadCipher);
}
void QuicReadCodec::setHandshakeReadCipher(
std::unique_ptr<fizz::Aead> handshakeReadCipher) {
handshakeReadCipher_ = std::move(handshakeReadCipher);
}
void QuicReadCodec::setInitialHeaderCipher(
std::unique_ptr<PacketNumberCipher> initialHeaderCipher) {
initialHeaderCipher_ = std::move(initialHeaderCipher);
}
void QuicReadCodec::setOneRttHeaderCipher(
std::unique_ptr<PacketNumberCipher> oneRttHeaderCipher) {
oneRttHeaderCipher_ = std::move(oneRttHeaderCipher);
}
void QuicReadCodec::setZeroRttHeaderCipher(
std::unique_ptr<PacketNumberCipher> zeroRttHeaderCipher) {
zeroRttHeaderCipher_ = std::move(zeroRttHeaderCipher);
}
void QuicReadCodec::setHandshakeHeaderCipher(
std::unique_ptr<PacketNumberCipher> handshakeHeaderCipher) {
handshakeHeaderCipher_ = std::move(handshakeHeaderCipher);
}
void QuicReadCodec::setCodecParameters(CodecParameters params) {
params_ = std::move(params);
}
void QuicReadCodec::setClientConnectionId(ConnectionId connId) {
clientConnectionId_ = connId;
}
void QuicReadCodec::setServerConnectionId(ConnectionId connId) {
serverConnectionId_ = connId;
}
void QuicReadCodec::setStatelessResetToken(
StatelessResetToken statelessResetToken) {
statelessResetToken_ = std::move(statelessResetToken);
}
const fizz::Aead* QuicReadCodec::getInitialCipher() const {
return initialReadCipher_.get();
}
const PacketNumberCipher* QuicReadCodec::getInitialHeaderCipher() const {
return initialHeaderCipher_.get();
}
const PacketNumberCipher* QuicReadCodec::getOneRttHeaderCipher() const {
return oneRttHeaderCipher_.get();
}
const PacketNumberCipher* QuicReadCodec::getHandshakeHeaderCipher() const {
return handshakeHeaderCipher_.get();
}
const PacketNumberCipher* QuicReadCodec::getZeroRttHeaderCipher() const {
return zeroRttHeaderCipher_.get();
}
void QuicReadCodec::onHandshakeDone(TimePoint handshakeDoneTime) {
if (!handshakeDoneTime_) {
handshakeDoneTime_ = handshakeDoneTime;
}
}
folly::Optional<TimePoint> QuicReadCodec::getHandshakeDoneTime() {
return handshakeDoneTime_;
}
std::string QuicReadCodec::connIdToHex() {
static ConnectionId zeroConn = zeroConnId();
const auto& serverId = serverConnectionId_.value_or(zeroConn);
const auto& clientId = clientConnectionId_.value_or(zeroConn);
return folly::to<std::string>(
"server=", serverId.hex(), " ", "client=", clientId.hex());
}
} // namespace quic

129
quic/codec/QuicReadCodec.h Normal file
View File

@ -0,0 +1,129 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
// Copyright 2004-present Facebook. All rights reserved.
#pragma once
#include <folly/Optional.h>
#include <quic/codec/Decode.h>
#include <quic/codec/PacketNumber.h>
#include <quic/codec/Types.h>
#include <quic/handshake/HandshakeLayer.h>
#include <quic/state/AckStates.h>
namespace quic {
/**
* Structure which describes data which could not be processed by
* the read codec due to the required cipher being unavailable. The caller might
* use this to retry later once the cipher is available.
*/
struct CipherUnavailable {
Buf packet;
PacketNum packetNum;
ProtectionType protectionType;
CipherUnavailable(
Buf packetIn,
PacketNum packetNumIn,
ProtectionType protectionTypeIn)
: packet(std::move(packetIn)),
packetNum(packetNumIn),
protectionType(protectionTypeIn) {}
};
using CodecResult = boost::
variant<QuicPacket, folly::Optional<CipherUnavailable>, StatelessReset>;
class QuicReadCodec {
public:
virtual ~QuicReadCodec() = default;
explicit QuicReadCodec(QuicNodeType nodeType);
/**
* Tries to parse a packet from the buffer data.
* If it is able to parse the packet, then it returns
* a valid QUIC packet. If it is not able to parse a packet it might return a
* cipher unavailable structure. The caller can then retry when the cipher is
* available.
*/
virtual CodecResult parsePacket(
folly::IOBufQueue& queue,
const AckStates& ackStates);
const fizz::Aead* getOneRttReadCipher() const;
const fizz::Aead* getZeroRttReadCipher() const;
const fizz::Aead* getHandshakeReadCipher() const;
const fizz::Aead* getInitialCipher() const;
const PacketNumberCipher* getInitialHeaderCipher() const;
const PacketNumberCipher* getOneRttHeaderCipher() const;
const PacketNumberCipher* getHandshakeHeaderCipher() const;
const PacketNumberCipher* getZeroRttHeaderCipher() const;
const folly::Optional<StatelessResetToken>& getStatelessResetToken() const;
void setInitialReadCipher(std::unique_ptr<fizz::Aead> initialReadCipher);
void setOneRttReadCipher(std::unique_ptr<fizz::Aead> oneRttReadCipher);
void setZeroRttReadCipher(std::unique_ptr<fizz::Aead> zeroRttReadCipher);
void setHandshakeReadCipher(std::unique_ptr<fizz::Aead> handshakeReadCipher);
void setInitialHeaderCipher(
std::unique_ptr<PacketNumberCipher> initialHeaderCipher);
void setOneRttHeaderCipher(
std::unique_ptr<PacketNumberCipher> oneRttHeaderCipher);
void setZeroRttHeaderCipher(
std::unique_ptr<PacketNumberCipher> zeroRttHeaderCipher);
void setHandshakeHeaderCipher(
std::unique_ptr<PacketNumberCipher> handshakeHeaderCipher);
void setCodecParameters(CodecParameters params);
void setClientConnectionId(ConnectionId connId);
void setServerConnectionId(ConnectionId connId);
void setStatelessResetToken(StatelessResetToken statelessResetToken);
/**
* Should be invoked when the state machine believes that the handshake is
* complete.
*/
void onHandshakeDone(TimePoint handshakeDoneTime);
folly::Optional<TimePoint> getHandshakeDoneTime();
private:
CodecResult parseLongHeaderPacket(
folly::IOBufQueue& queue,
const AckStates& ackStates);
std::string connIdToHex();
QuicNodeType nodeType_;
CodecParameters params_;
folly::Optional<ConnectionId> clientConnectionId_;
folly::Optional<ConnectionId> serverConnectionId_;
// Cipher used to decrypt handshake packets.
std::unique_ptr<fizz::Aead> initialReadCipher_;
std::unique_ptr<fizz::Aead> oneRttReadCipher_;
std::unique_ptr<fizz::Aead> zeroRttReadCipher_;
std::unique_ptr<fizz::Aead> handshakeReadCipher_;
std::unique_ptr<PacketNumberCipher> initialHeaderCipher_;
std::unique_ptr<PacketNumberCipher> oneRttHeaderCipher_;
std::unique_ptr<PacketNumberCipher> zeroRttHeaderCipher_;
std::unique_ptr<PacketNumberCipher> handshakeHeaderCipher_;
folly::Optional<StatelessResetToken> statelessResetToken_;
folly::Optional<TimePoint> handshakeDoneTime_;
};
} // namespace quic

View File

@ -0,0 +1,598 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
// Copyright 2004-present Facebook. All rights reserved.
#include <quic/QuicConstants.h>
#include <quic/QuicException.h>
#include <quic/codec/QuicInteger.h>
#include <quic/codec/QuicWriteCodec.h>
#include <algorithm>
#include <limits>
namespace {
bool packetSpaceCheck(uint64_t limit, size_t require);
/**
* A helper function to check if there are enough space to write in the packet.
* Return: true if there is enough space, false otherwise
*/
bool packetSpaceCheck(uint64_t limit, size_t require) {
return (folly::to<uint64_t>(require) <= limit);
}
} // namespace
namespace quic {
folly::Optional<StreamFrameWriteResult> writeStreamFrame(
const StreamFrameMetaData& streamFrameMetaData,
PacketBuilderInterface& builder) {
if (!builder.remainingSpaceInPkt()) {
return folly::none;
}
if ((!streamFrameMetaData.data ||
streamFrameMetaData.data->computeChainDataLength() == 0) &&
!streamFrameMetaData.fin) {
VLOG(2) << "No data or FIN supplied while writing stream "
<< streamFrameMetaData.id;
throw QuicInternalException(
"No data or FIN supplied", LocalErrorCode::INVALID_WRITE_DATA);
}
StreamTypeField::Builder initialByte;
QuicInteger streamId(streamFrameMetaData.id);
QuicInteger offset(streamFrameMetaData.offset);
size_t headerSize = sizeof(FrameType::STREAM) + streamId.getSize();
if (LIKELY(streamFrameMetaData.hasMoreFrames)) {
initialByte.setLength();
// We use the size remaining for simplicity here. 2 bytes should be enough
// for almost anything. We could save 1 byte by deciding whether or not we
// have < 100 bytes to write, however that seems a bit overkill.
auto size = getQuicIntegerSize(builder.remainingSpaceInPkt());
if (size.hasError()) {
throw QuicTransportException(
folly::to<std::string>(
"Stream Frame: Value too large ", builder.remainingSpaceInPkt()),
size.error());
}
headerSize += *size;
}
if (streamFrameMetaData.offset != 0) {
initialByte.setOffset();
headerSize += offset.getSize();
}
uint64_t spaceLeftInPkt = builder.remainingSpaceInPkt();
if (spaceLeftInPkt < headerSize) {
// We don't have enough space, this can happen often so exception is too
// expensive for this. Just return empty result.
VLOG(4) << "No space in packet for stream header. stream="
<< streamFrameMetaData.id
<< " remaining=" << builder.remainingSpaceInPkt();
return folly::none;
}
spaceLeftInPkt -= headerSize;
uint64_t dataInStream = 0;
if (streamFrameMetaData.data) {
dataInStream = streamFrameMetaData.data->computeChainDataLength();
}
auto dataCanWrite = std::min<uint64_t>(spaceLeftInPkt, dataInStream);
bool canWrite = (dataInStream > 0 && dataCanWrite > 0) ||
(dataInStream == 0 && streamFrameMetaData.fin);
if (!canWrite) {
VLOG(4) << "No space in packet for stream=" << streamFrameMetaData.id
<< " dataInStream=" << dataInStream
<< " fin=" << streamFrameMetaData.fin
<< " remaining=" << spaceLeftInPkt;
return folly::none;
}
QuicInteger actualLength(dataCanWrite);
bool writtenFin = false;
if (streamFrameMetaData.fin && dataCanWrite == dataInStream) {
// We can only write a FIN if we ended up writing all the bytes
// in the input data.
initialByte.setFin();
writtenFin = true;
}
builder.writeBE(initialByte.build().fieldValue());
builder.write(streamId);
if (streamFrameMetaData.offset != 0) {
builder.write(offset);
}
if (LIKELY(streamFrameMetaData.hasMoreFrames)) {
builder.write(actualLength);
}
Buf bufToWrite;
if (dataCanWrite > 0) {
folly::io::Cursor cursor(streamFrameMetaData.data.get());
cursor.clone(bufToWrite, dataCanWrite);
} else {
bufToWrite = folly::IOBuf::create(0);
}
VLOG(4) << "writing frame stream=" << streamFrameMetaData.id
<< " offset=" << streamFrameMetaData.offset
<< " data=" << dataCanWrite << " fin=" << writtenFin;
builder.insert(bufToWrite->clone());
builder.appendFrame(WriteStreamFrame(
streamFrameMetaData.id,
streamFrameMetaData.offset,
dataCanWrite,
writtenFin));
StreamFrameWriteResult result(
dataCanWrite, writtenFin, std::move(bufToWrite));
return folly::make_optional(std::move(result));
}
folly::Optional<WriteCryptoFrame>
writeCryptoFrame(uint64_t offsetIn, Buf data, PacketBuilderInterface& builder) {
uint64_t spaceLeftInPkt = builder.remainingSpaceInPkt();
QuicInteger packetType(static_cast<uint8_t>(FrameType::CRYPTO_FRAME));
QuicInteger offsetInteger(offsetIn);
size_t lengthBytes = 2;
size_t cryptoFrameHeaderSize =
packetType.getSize() + offsetInteger.getSize() + lengthBytes;
if (spaceLeftInPkt <= cryptoFrameHeaderSize) {
VLOG(3) << "No space left in packet to write cryptoFrame header of size: "
<< cryptoFrameHeaderSize << ", space left=" << spaceLeftInPkt;
return folly::none;
}
size_t spaceRemaining = spaceLeftInPkt - cryptoFrameHeaderSize;
size_t dataLength = data->computeChainDataLength();
size_t writeableData = std::min(dataLength, spaceRemaining);
QuicInteger lengthVarInt(writeableData);
if (lengthVarInt.getSize() > lengthBytes) {
throw QuicInternalException(
std::string("Length bytes representation"),
LocalErrorCode::CODEC_ERROR);
}
data->coalesce();
data->trimEnd(dataLength - writeableData);
builder.write(packetType);
builder.write(offsetInteger);
builder.write(lengthVarInt);
builder.insert(std::move(data));
builder.appendFrame(WriteCryptoFrame(offsetIn, lengthVarInt.getValue()));
return WriteCryptoFrame(offsetIn, lengthVarInt.getValue());
}
size_t fillFrameWithAckBlocks(
const IntervalSet<PacketNum>& ackBlocks,
WriteAckFrame& ackFrame,
uint64_t bytesLimit);
size_t fillFrameWithAckBlocks(
const IntervalSet<PacketNum>& ackBlocks,
WriteAckFrame& ackFrame,
uint64_t bytesLimit) {
PacketNum currentSeqNum = ackBlocks.crbegin()->start;
// starts off with 0 which is what we assumed the initial ack block to be for
// the largest acked.
size_t numAdditionalAckBlocks = 0;
QuicInteger previousNumAckBlockInt(numAdditionalAckBlocks);
for (auto blockItr = ackBlocks.crbegin() + 1; blockItr != ackBlocks.crend();
++blockItr) {
const auto& currBlock = *blockItr;
// These must be true because of the properties of the interval set.
CHECK_GE(currentSeqNum, currBlock.end + 2);
PacketNum gap = currentSeqNum - currBlock.end - 2;
PacketNum currBlockLen = currBlock.end - currBlock.start;
QuicInteger gapInt(gap);
QuicInteger currentBlockLenInt(currBlockLen);
QuicInteger numAckBlocksInt(numAdditionalAckBlocks + 1);
size_t additionalSize = gapInt.getSize() + currentBlockLenInt.getSize() +
(numAckBlocksInt.getSize() - previousNumAckBlockInt.getSize());
if (bytesLimit < additionalSize) {
break;
}
numAdditionalAckBlocks++;
bytesLimit -= additionalSize;
previousNumAckBlockInt = numAckBlocksInt;
currentSeqNum = currBlock.start;
ackFrame.ackBlocks.insert(currBlock.start, currBlock.end);
}
return numAdditionalAckBlocks;
}
folly::Optional<AckFrameWriteResult> writeAckFrame(
const quic::AckFrameMetaData& ackFrameMetaData,
PacketBuilderInterface& builder) {
if (ackFrameMetaData.ackBlocks.empty()) {
return folly::none;
}
// The last block must be the largest block.
auto largestAckedPacket = ackFrameMetaData.ackBlocks.back().end;
// ackBlocks are already an interval set so each value is naturally
// non-overlapping.
auto firstAckBlockLength =
largestAckedPacket - ackFrameMetaData.ackBlocks.back().start;
WriteAckFrame ackFrame;
uint64_t spaceLeft = builder.remainingSpaceInPkt();
uint64_t beginningSpace = spaceLeft;
// We could technically split the range if the size of the representation of
// the integer is too large, but that gets super tricky and is of dubious
// value.
QuicInteger largestAckedPacketInt(largestAckedPacket);
QuicInteger firstAckBlockLengthInt(firstAckBlockLength);
// Convert ackDelay to unsigned value explicitly before right shifting to
// avoid issues with right shifting signed values.
uint64_t encodedAckDelay = ackFrameMetaData.ackDelay.count();
encodedAckDelay = encodedAckDelay >> ackFrameMetaData.ackDelayExponent;
QuicInteger ackDelayInt(encodedAckDelay);
QuicInteger minAdditionalAckBlockCount(0);
// Required fields are Type, LargestAcked, AckDelay, AckBlockCount,
// firstAckBlockLength
QuicInteger encodedPacketType(static_cast<uint8_t>(FrameType::ACK));
auto headerSize = encodedPacketType.getSize() +
largestAckedPacketInt.getSize() + ackDelayInt.getSize() +
minAdditionalAckBlockCount.getSize() + firstAckBlockLengthInt.getSize();
if (spaceLeft < headerSize) {
return folly::none;
}
spaceLeft -= headerSize;
auto numAdditionalAckBlocks =
fillFrameWithAckBlocks(ackFrameMetaData.ackBlocks, ackFrame, spaceLeft);
QuicInteger numAdditionalAckBlocksInt(numAdditionalAckBlocks);
builder.write(encodedPacketType);
builder.write(largestAckedPacketInt);
builder.write(ackDelayInt);
builder.write(numAdditionalAckBlocksInt);
builder.write(firstAckBlockLengthInt);
PacketNum currentSeqNum = ackFrameMetaData.ackBlocks.back().start;
for (auto it = ackFrame.ackBlocks.crbegin(); it != ackFrame.ackBlocks.crend();
++it) {
CHECK_GE(currentSeqNum, it->end + 2);
PacketNum gap = currentSeqNum - it->end - 2;
PacketNum currBlockLen = it->end - it->start;
QuicInteger gapInt(gap);
QuicInteger currentBlockLenInt(currBlockLen);
builder.write(gapInt);
builder.write(currentBlockLenInt);
currentSeqNum = it->start;
}
// also the largest ack block since we already accounted for the space to
// write to it.
ackFrame.ackBlocks.insert(
ackFrameMetaData.ackBlocks.back().start,
ackFrameMetaData.ackBlocks.back().end);
ackFrame.ackDelay = ackFrameMetaData.ackDelay;
builder.appendFrame(std::move(ackFrame));
return AckFrameWriteResult(
beginningSpace - builder.remainingSpaceInPkt(),
1 + numAdditionalAckBlocks);
}
size_t writeSimpleFrame(
QuicSimpleFrame&& frame,
PacketBuilderInterface& builder) {
using FrameTypeType = std::underlying_type<FrameType>::type;
uint64_t spaceLeft = builder.remainingSpaceInPkt();
return folly::variant_match(
frame,
[&](StopSendingFrame& stopSendingFrame) {
QuicInteger packetType(static_cast<uint8_t>(FrameType::STOP_SENDING));
QuicInteger streamId(stopSendingFrame.streamId);
auto stopSendingFrameSize = packetType.getSize() + streamId.getSize() +
sizeof(ApplicationErrorCode);
if (packetSpaceCheck(spaceLeft, stopSendingFrameSize)) {
builder.write(packetType);
builder.write(streamId);
builder.writeBE(
static_cast<std::underlying_type<ApplicationErrorCode>::type>(
stopSendingFrame.errorCode));
builder.appendFrame(std::move(stopSendingFrame));
return stopSendingFrameSize;
}
// no space left in packet
return size_t(0);
},
[&](MinStreamDataFrame& minStreamDataFrame) {
QuicInteger streamId(minStreamDataFrame.streamId);
QuicInteger maximumData(minStreamDataFrame.maximumData);
QuicInteger minimumStreamOffset(minStreamDataFrame.minimumStreamOffset);
QuicInteger frameType(
static_cast<FrameTypeType>(FrameType::MIN_STREAM_DATA));
auto minStreamDataFrameSize = frameType.getSize() + streamId.getSize() +
maximumData.getSize() + minimumStreamOffset.getSize();
if (packetSpaceCheck(spaceLeft, minStreamDataFrameSize)) {
builder.write(frameType);
builder.write(streamId);
builder.write(maximumData);
builder.write(minimumStreamOffset);
builder.appendFrame(std::move(minStreamDataFrame));
return minStreamDataFrameSize;
}
// no space left in packet
return size_t(0);
},
[&](ExpiredStreamDataFrame& expiredStreamDataFrame) {
QuicInteger frameType(
static_cast<FrameTypeType>(FrameType::EXPIRED_STREAM_DATA));
QuicInteger streamId(expiredStreamDataFrame.streamId);
QuicInteger minimumStreamOffset(
expiredStreamDataFrame.minimumStreamOffset);
auto expiredStreamDataFrameSize = frameType.getSize() +
streamId.getSize() + minimumStreamOffset.getSize();
if (packetSpaceCheck(spaceLeft, expiredStreamDataFrameSize)) {
builder.write(frameType);
builder.write(streamId);
builder.write(minimumStreamOffset);
builder.appendFrame(std::move(expiredStreamDataFrame));
return expiredStreamDataFrameSize;
}
// no space left in packet
return size_t(0);
},
[&](PathChallengeFrame& pathChallengeFrame) {
QuicInteger frameType(static_cast<uint8_t>(FrameType::PATH_CHALLENGE));
auto pathChallengeFrameSize =
frameType.getSize() + sizeof(pathChallengeFrame.pathData);
if (packetSpaceCheck(spaceLeft, pathChallengeFrameSize)) {
builder.write(frameType);
builder.writeBE(pathChallengeFrame.pathData);
builder.appendFrame(std::move(pathChallengeFrame));
return pathChallengeFrameSize;
}
// no space left in packet
return size_t(0);
},
[&](PathResponseFrame& pathResponseFrame) {
QuicInteger frameType(static_cast<uint8_t>(FrameType::PATH_RESPONSE));
auto pathResponseFrameSize =
frameType.getSize() + sizeof(pathResponseFrame.pathData);
if (packetSpaceCheck(spaceLeft, pathResponseFrameSize)) {
builder.write(frameType);
builder.writeBE(pathResponseFrame.pathData);
builder.appendFrame(std::move(pathResponseFrame));
return pathResponseFrameSize;
}
// no space left in packet
return size_t(0);
});
}
size_t writeFrame(QuicWriteFrame&& frame, PacketBuilderInterface& builder) {
using FrameTypeType = std::underlying_type<FrameType>::type;
uint64_t spaceLeft = builder.remainingSpaceInPkt();
return folly::variant_match(
frame,
[&](PaddingFrame& paddingFrame) {
QuicInteger packetType(static_cast<uint8_t>(FrameType::PADDING));
if (packetSpaceCheck(spaceLeft, packetType.getSize())) {
builder.write(packetType);
builder.appendFrame(std::move(paddingFrame));
return packetType.getSize();
}
return size_t(0);
},
[&](PingFrame& pingFrame) {
QuicInteger packetType(static_cast<uint8_t>(FrameType::PING));
if (packetSpaceCheck(spaceLeft, packetType.getSize())) {
builder.write(packetType);
builder.appendFrame(std::move(pingFrame));
return packetType.getSize();
}
// no space left in packet
return size_t(0);
},
[&](RstStreamFrame& rstStreamFrame) {
QuicInteger packetType(static_cast<uint8_t>(FrameType::RST_STREAM));
QuicInteger streamId(rstStreamFrame.streamId);
QuicInteger offset(rstStreamFrame.offset);
auto rstStreamFrameSize = packetType.getSize() +
sizeof(ApplicationErrorCode) + streamId.getSize() +
offset.getSize();
if (packetSpaceCheck(spaceLeft, rstStreamFrameSize)) {
builder.write(packetType);
builder.write(streamId);
builder.writeBE(
static_cast<std::underlying_type<ApplicationErrorCode>::type>(
rstStreamFrame.errorCode));
builder.write(offset);
builder.appendFrame(std::move(rstStreamFrame));
return rstStreamFrameSize;
}
// no space left in packet
return size_t(0);
},
[&](MaxDataFrame& maxDataFrame) {
QuicInteger packetType(static_cast<uint8_t>(FrameType::MAX_DATA));
QuicInteger maximumData(maxDataFrame.maximumData);
auto frameSize = packetType.getSize() + maximumData.getSize();
if (packetSpaceCheck(spaceLeft, frameSize)) {
builder.write(packetType);
builder.write(maximumData);
builder.appendFrame(std::move(maxDataFrame));
return frameSize;
}
// no space left in packet
return size_t(0);
},
[&](MaxStreamDataFrame& maxStreamDataFrame) {
QuicInteger packetType(
static_cast<uint8_t>(FrameType::MAX_STREAM_DATA));
QuicInteger streamId(maxStreamDataFrame.streamId);
QuicInteger maximumData(maxStreamDataFrame.maximumData);
auto maxStreamDataFrameSize =
packetType.getSize() + streamId.getSize() + maximumData.getSize();
if (packetSpaceCheck(spaceLeft, maxStreamDataFrameSize)) {
builder.write(packetType);
builder.write(streamId);
builder.write(maximumData);
builder.appendFrame(std::move(maxStreamDataFrame));
return maxStreamDataFrameSize;
}
// no space left in packet
return size_t(0);
},
[&](MaxStreamsFrame& maxStreamsFrame) {
auto frameType = maxStreamsFrame.isForBidirectionalStream()
? FrameType::MAX_STREAMS_BIDI
: FrameType::MAX_STREAMS_UNI;
QuicInteger packetType(static_cast<FrameTypeType>(frameType));
QuicInteger streamId(maxStreamsFrame.maxStreams);
auto maxStreamsFrameSize = packetType.getSize() + streamId.getSize();
if (packetSpaceCheck(spaceLeft, maxStreamsFrameSize)) {
builder.write(packetType);
builder.write(streamId);
builder.appendFrame(std::move(maxStreamsFrame));
return maxStreamsFrameSize;
}
// no space left in packet
return size_t(0);
},
[&](DataBlockedFrame& blockedFrame) {
QuicInteger packetType(static_cast<uint8_t>(FrameType::DATA_BLOCKED));
QuicInteger dataLimit(blockedFrame.dataLimit);
auto blockedFrameSize = packetType.getSize() + dataLimit.getSize();
if (packetSpaceCheck(spaceLeft, blockedFrameSize)) {
builder.write(packetType);
builder.write(dataLimit);
builder.appendFrame(std::move(blockedFrame));
return blockedFrameSize;
}
// no space left in packet
return size_t(0);
},
[&](StreamDataBlockedFrame& streamBlockedFrame) {
QuicInteger packetType(
static_cast<uint8_t>(FrameType::STREAM_DATA_BLOCKED));
QuicInteger streamId(streamBlockedFrame.streamId);
QuicInteger dataLimit(streamBlockedFrame.dataLimit);
auto blockedFrameSize =
packetType.getSize() + streamId.getSize() + dataLimit.getSize();
if (packetSpaceCheck(spaceLeft, blockedFrameSize)) {
builder.write(packetType);
builder.write(streamId);
builder.write(dataLimit);
builder.appendFrame(std::move(streamBlockedFrame));
return blockedFrameSize;
}
// no space left in packet
return size_t(0);
},
[&](StreamsBlockedFrame& streamsBlockedFrame) {
auto frameType = streamsBlockedFrame.isForBidirectionalStream()
? FrameType::STREAMS_BLOCKED_BIDI
: FrameType::STREAMS_BLOCKED_UNI;
QuicInteger packetType(static_cast<FrameTypeType>(frameType));
QuicInteger streamId(streamsBlockedFrame.streamLimit);
auto streamBlockedFrameSize = packetType.getSize() + streamId.getSize();
if (packetSpaceCheck(spaceLeft, streamBlockedFrameSize)) {
builder.write(packetType);
builder.write(streamId);
builder.appendFrame(std::move(streamsBlockedFrame));
return streamBlockedFrameSize;
}
// no space left in packet
return size_t(0);
},
[&](NewConnectionIdFrame& newConnectionIdFrame) {
QuicInteger packetType(
static_cast<uint8_t>(FrameType::NEW_CONNECTION_ID));
QuicInteger sequence(newConnectionIdFrame.sequence);
// Include an 8-bit unsigned integer containing the length of the connId
auto newConnectionIdFrameSize = packetType.getSize() + sizeof(uint8_t) +
sequence.getSize() + newConnectionIdFrame.connectionId.size() +
newConnectionIdFrame.token.size();
if (packetSpaceCheck(spaceLeft, newConnectionIdFrameSize)) {
builder.write(packetType);
builder.write(sequence);
builder.writeBE(newConnectionIdFrame.connectionId.size());
builder.push(
newConnectionIdFrame.connectionId.data(),
newConnectionIdFrame.connectionId.size());
builder.push(
newConnectionIdFrame.token.data(),
newConnectionIdFrame.token.size());
builder.appendFrame(std::move(newConnectionIdFrame));
return newConnectionIdFrameSize;
}
// no space left in packet
return size_t(0);
},
[&](ConnectionCloseFrame& connectionCloseFrame) {
QuicInteger packetType(
static_cast<uint8_t>(FrameType::CONNECTION_CLOSE));
QuicInteger reasonLength(connectionCloseFrame.reasonPhrase.size());
auto connCloseFrameSize = packetType.getSize() +
sizeof(TransportErrorCode) +
sizeof(connectionCloseFrame.closingFrameType) +
reasonLength.getSize() + connectionCloseFrame.reasonPhrase.size();
if (packetSpaceCheck(spaceLeft, connCloseFrameSize)) {
builder.write(packetType);
builder.writeBE(
static_cast<std::underlying_type<TransportErrorCode>::type>(
connectionCloseFrame.errorCode));
QuicInteger closingFrameType(static_cast<FrameTypeType>(
connectionCloseFrame.closingFrameType));
builder.write(closingFrameType);
builder.write(reasonLength);
builder.push(
(const uint8_t*)connectionCloseFrame.reasonPhrase.data(),
connectionCloseFrame.reasonPhrase.size());
builder.appendFrame(std::move(connectionCloseFrame));
return connCloseFrameSize;
}
// no space left in packet
return size_t(0);
},
[&](ApplicationCloseFrame& applicationCloseFrame) {
QuicInteger packetType(
static_cast<uint8_t>(FrameType::APPLICATION_CLOSE));
QuicInteger reasonLength(applicationCloseFrame.reasonPhrase.size());
auto applicationCloseFrameSize = packetType.getSize() +
sizeof(ApplicationErrorCode) + reasonLength.getSize() +
applicationCloseFrame.reasonPhrase.size();
if (packetSpaceCheck(spaceLeft, applicationCloseFrameSize)) {
builder.write(packetType);
builder.writeBE(
static_cast<std::underlying_type<ApplicationErrorCode>::type>(
applicationCloseFrame.errorCode));
builder.write(reasonLength);
builder.push(
(const uint8_t*)applicationCloseFrame.reasonPhrase.data(),
applicationCloseFrame.reasonPhrase.size());
builder.appendFrame(std::move(applicationCloseFrame));
return applicationCloseFrameSize;
}
// no space left in packet
return size_t(0);
},
[&](QuicSimpleFrame& simpleFrame) {
return writeSimpleFrame(std::move(simpleFrame), builder);
},
[&](auto&) -> size_t {
// TODO add support for: RETIRE_CONNECTION_ID and NEW_TOKEN frames
auto errorStr = folly::to<std::string>(
"Unknown / unsupported frame type received at ", __func__);
VLOG(2) << errorStr;
throw QuicTransportException(
errorStr, TransportErrorCode::FRAME_ENCODING_ERROR);
});
}
} // namespace quic

145
quic/codec/QuicWriteCodec.h Normal file
View File

@ -0,0 +1,145 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
// Copyright 2004-present Facebook. All rights reserved.
#pragma once
#include <quic/codec/QuicPacketBuilder.h>
#include <quic/codec/Types.h>
#include <quic/common/IntervalSet.h>
#include <chrono>
namespace quic {
struct AckFrameMetaData {
// Ack blocks. There must be at least 1 ACK block to send.
const IntervalSet<PacketNum>& ackBlocks;
// Delay in sending ack from time that packet was received.
std::chrono::microseconds ackDelay;
// The ack delay exponent to use.
uint8_t ackDelayExponent;
AckFrameMetaData(
const IntervalSet<PacketNum>& acksIn,
std::chrono::microseconds ackDelayIn,
uint8_t ackDelayExponentIn)
: ackBlocks(acksIn),
ackDelay(ackDelayIn),
ackDelayExponent(ackDelayExponentIn) {}
};
struct StreamFrameMetaData {
StreamId id{0};
uint64_t offset{0};
bool fin{false};
Buf data;
bool hasMoreFrames{true};
StreamFrameMetaData() = default;
StreamFrameMetaData(
StreamId idIn,
uint64_t offsetIn,
bool finIn,
Buf bufIn,
bool hasMoreFramesIn)
: id(idIn),
offset(offsetIn),
fin(finIn),
data(std::move(bufIn)),
hasMoreFrames(hasMoreFramesIn) {}
};
struct StreamFrameWriteResult {
// The number of bytes written to the stream data section
uint64_t bytesWritten;
bool finWritten;
Buf writtenData;
explicit StreamFrameWriteResult(
uint64_t bytesWrittenIn,
bool finWrittenIn,
Buf writtenDataIn)
: bytesWritten(bytesWrittenIn),
finWritten(finWrittenIn),
writtenData(std::move(writtenDataIn)) {}
};
struct AckFrameWriteResult {
uint64_t bytesWritten;
// This includes the first ack block
size_t ackBlocksWritten;
AckFrameWriteResult(uint64_t bytesWrittenIn, size_t ackBlocksWrittenIn)
: bytesWritten(bytesWrittenIn), ackBlocksWritten(ackBlocksWrittenIn) {}
};
/**
* Write a simple QuicFrame into builder
*
* The input parameter is the frame to be written to the output appender.
*
*/
size_t writeSimpleFrame(
QuicSimpleFrame&& frame,
PacketBuilderInterface& builder);
/**
* Write a (non-ACK, non-Stream) QuicFrame into builder
*
* The input parameter is the frame to be written to the output appender.
*
*/
size_t writeFrame(QuicWriteFrame&& frame, PacketBuilderInterface& builder);
/**
* Write a StreamFrame into builder
*
* WriteCodec will take as much as it can to append it to the appender. Input
* data, Stream id and offset are passed in via streamFrameMetaData.
*
* Return: A StreamFrameWriteResult to indicate how many bytes of data (not
* including other stream frame fields) are written to the stream
*/
folly::Optional<StreamFrameWriteResult> writeStreamFrame(
const StreamFrameMetaData& streamFrameMetaData,
PacketBuilderInterface& builder);
/**
* Write a CryptoFrame into builder. The builder may not be able to accept all
* the bytes that are supplied to writeCryptoFrame.
*
* offset is the offset of the crypto frame to write into the builder
* data is the actual data that needs to be written.
*
* Return: A WriteCryptoFrame which represents the crypto frame that was
* written. The caller should check the structure to confirm how many bytes were
* written.
*/
folly::Optional<WriteCryptoFrame>
writeCryptoFrame(uint64_t offsetIn, Buf data, PacketBuilderInterface& builder);
/**
* Write a AckFrame into builder
*
* Similar to writeStreamFrame, the codec will give a best effort to write as
* many as AckBlock as it can. The WriteCodec may not be able to write
* all of them though. A vector of AckBlocks, the largest acked bytes and other
* ACK frame specific info are passed via ackFrameMetaData.
*
* The ackBlocks are supposed to be sorted in descending order
* of the packet sequence numbers. Exception will be thrown if they are not
* sorted.
*
* Return: A AckFrameWriteResult to indicate how many bytes and ack blocks are
* written to the appender. Returns an empty optional if an ack block could not
* be written.
*/
folly::Optional<AckFrameWriteResult> writeAckFrame(
const AckFrameMetaData& ackFrameMetaData,
PacketBuilderInterface& builder);
} // namespace quic

149
quic/codec/TARGETS Normal file
View File

@ -0,0 +1,149 @@
# @autodeps
load("@fbcode_macros//build_defs:cpp_library.bzl", "cpp_library")
cpp_library(
name = "types",
srcs = [
"DefaultConnectionIdAlgo.cpp",
"PacketNumber.cpp",
"QuicConnectionId.cpp",
"QuicInteger.cpp",
"Types.cpp",
],
headers = [
"ConnectionIdAlgo.h",
"DefaultConnectionIdAlgo.h",
"PacketNumber.h",
"QuicConnectionId.h",
"QuicInteger.h",
"Types.h",
],
deps = [
"//folly:conv",
"//folly:optional",
"//folly:overload",
"//folly:random",
"//folly:string",
"//folly/container:array",
"//folly/hash:hash",
"//folly/io:iobuf",
"//folly/lang:bits",
"//quic:constants",
"//quic:exception",
"//quic/common:interval_set",
],
external_deps = [
"boost",
"glog",
],
)
cpp_library(
name = "packet_number_cipher",
srcs = [
"PacketNumberCipher.cpp",
],
deps = [
":decode",
":types",
"//folly:optional",
"//folly/io:iobuf",
"//folly/ssl:openssl_ptr_types",
],
)
cpp_library(
name = "decode",
srcs = [
"Decode.cpp",
],
headers = [
"Decode.h",
],
compiler_flags = [
"-fstrict-aliasing",
],
deps = [
":types",
"//folly:string",
"//folly/io:iobuf",
"//quic:exception",
],
)
cpp_library(
name = "pktbuilder",
srcs = [
"QuicPacketBuilder.cpp",
],
headers = [
"QuicPacketBuilder.h",
],
compiler_flags = [
"-fstrict-aliasing",
],
deps = [
":types",
"//folly:random",
"//quic/handshake:handshake",
],
)
cpp_library(
name = "pktrebuilder",
srcs = [
"QuicPacketRebuilder.cpp",
],
headers = [
"QuicPacketRebuilder.h",
],
deps = [
":codec",
":pktbuilder",
"//quic/flowcontrol:flow_control",
"//quic/state:simple_frame_functions",
"//quic/state:state_machine",
"//quic/state:stream_functions",
],
)
cpp_library(
name = "header_codec",
srcs = [
"QuicHeaderCodec.cpp",
],
headers = [
"QuicHeaderCodec.h",
],
deps = [
":decode",
":types",
"//folly:optional",
"//quic:exception",
],
)
cpp_library(
name = "codec",
srcs = [
"QuicReadCodec.cpp",
"QuicWriteCodec.cpp",
],
headers = [
"QuicReadCodec.h",
"QuicWriteCodec.h",
],
deps = [
":decode",
":pktbuilder",
":types",
"//folly:optional",
"//folly/io:iobuf",
"//quic:constants",
"//quic:exception",
"//quic/common:interval_set",
"//quic/handshake:handshake",
"//quic/state:ack_states",
],
)

270
quic/codec/Types.cpp Normal file
View File

@ -0,0 +1,270 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
// Copyright 2004-present Facebook. All rights reserved.
#include <quic/codec/Types.h>
#include <quic/QuicException.h>
namespace quic {
LongHeaderInvariant::LongHeaderInvariant(
QuicVersion ver,
ConnectionId scid,
ConnectionId dcid)
: version(ver), srcConnId(std::move(scid)), dstConnId(std::move(dcid)) {}
HeaderForm getHeaderForm(uint8_t headerValue) {
if (headerValue & kHeaderFormMask) {
return HeaderForm::Long;
}
return HeaderForm::Short;
}
LongHeader::LongHeader(
Types type,
LongHeaderInvariant invariant,
Buf token,
folly::Optional<ConnectionId> originalDstConnId)
: headerForm_(HeaderForm::Long),
longHeaderType_(type),
invariant_(std::move(invariant)),
token_(std::move(token)),
originalDstConnId_(originalDstConnId) {}
LongHeader::LongHeader(
Types type,
const ConnectionId& srcConnId,
const ConnectionId& dstConnId,
PacketNum packetNum,
QuicVersion version,
Buf token,
folly::Optional<ConnectionId> originalDstConnId)
: headerForm_(HeaderForm::Long),
longHeaderType_(type),
invariant_(LongHeaderInvariant(version, srcConnId, dstConnId)),
packetSequenceNum_(packetNum),
token_(token ? std::move(token) : nullptr),
originalDstConnId_(originalDstConnId) {}
LongHeader::LongHeader(const LongHeader& other)
: headerForm_(other.headerForm_),
longHeaderType_(other.longHeaderType_),
invariant_(other.invariant_),
packetSequenceNum_(other.packetSequenceNum_),
originalDstConnId_(other.originalDstConnId_) {
if (other.token_) {
token_ = other.token_->clone();
}
}
void LongHeader::setPacketNumber(PacketNum packetNum) {
packetSequenceNum_ = packetNum;
}
LongHeader& LongHeader::operator=(const LongHeader& other) {
headerForm_ = other.headerForm_;
longHeaderType_ = other.longHeaderType_;
invariant_ = other.invariant_;
packetSequenceNum_ = other.packetSequenceNum_;
originalDstConnId_ = other.originalDstConnId_;
if (other.token_) {
token_ = other.token_->clone();
}
return *this;
}
LongHeader::Types LongHeader::getHeaderType() const noexcept {
return longHeaderType_;
}
const ConnectionId& LongHeader::getSourceConnId() const {
return invariant_.srcConnId;
}
const ConnectionId& LongHeader::getDestinationConnId() const {
return invariant_.dstConnId;
}
const folly::Optional<ConnectionId>& LongHeader::getOriginalDstConnId() const {
return originalDstConnId_;
}
PacketNum LongHeader::getPacketSequenceNum() const {
return *packetSequenceNum_;
}
QuicVersion LongHeader::getVersion() const {
return invariant_.version;
}
bool LongHeader::hasToken() const {
return token_ ? true : false;
}
folly::IOBuf* LongHeader::getToken() const {
return token_.get();
}
ProtectionType LongHeader::getProtectionType() const {
return longHeaderTypeToProtectionType(getHeaderType());
}
PacketNumberSpace LongHeader::getPacketNumberSpace() const noexcept {
return longHeaderTypeToPacketNumberSpace(getHeaderType());
}
ProtectionType longHeaderTypeToProtectionType(
LongHeader::Types longHeaderType) {
switch (longHeaderType) {
case LongHeader::Types::Initial:
case LongHeader::Types::Retry:
return ProtectionType::Initial;
case LongHeader::Types::Handshake:
return ProtectionType::Handshake;
case LongHeader::Types::ZeroRtt:
return ProtectionType::ZeroRtt;
}
folly::assume_unreachable();
}
ShortHeaderInvariant::ShortHeaderInvariant(ConnectionId dcid)
: destinationConnId(std::move(dcid)) {}
PacketNumberSpace longHeaderTypeToPacketNumberSpace(
LongHeader::Types longHeaderType) {
switch (longHeaderType) {
case LongHeader::Types::Initial:
case LongHeader::Types::Retry:
return PacketNumberSpace::Initial;
case LongHeader::Types::Handshake:
return PacketNumberSpace::Handshake;
case LongHeader::Types::ZeroRtt:
return PacketNumberSpace::AppData;
}
folly::assume_unreachable();
}
ShortHeader::ShortHeader(
ProtectionType protectionType,
ConnectionId connId,
PacketNum packetNum)
: headerForm_(HeaderForm::Short),
protectionType_(protectionType),
connectionId_(std::move(connId)),
packetSequenceNum_(packetNum) {
if (protectionType_ != ProtectionType::KeyPhaseZero &&
protectionType_ != ProtectionType::KeyPhaseOne) {
throw QuicInternalException(
"bad short header protection type", LocalErrorCode::CODEC_ERROR);
}
}
ShortHeader::ShortHeader(ProtectionType protectionType, ConnectionId connId)
: headerForm_(HeaderForm::Short),
protectionType_(protectionType),
connectionId_(std::move(connId)) {
if (protectionType_ != ProtectionType::KeyPhaseZero &&
protectionType_ != ProtectionType::KeyPhaseOne) {
throw QuicInternalException(
"bad short header protection type", LocalErrorCode::CODEC_ERROR);
}
}
ProtectionType ShortHeader::getProtectionType() const noexcept {
return protectionType_;
}
PacketNumberSpace ShortHeader::getPacketNumberSpace() const noexcept {
return PacketNumberSpace::AppData;
}
const ConnectionId& ShortHeader::getConnectionId() const {
return connectionId_;
}
PacketNum ShortHeader::getPacketSequenceNum() const {
return *packetSequenceNum_;
}
void ShortHeader::setPacketNumber(PacketNum packetNum) {
packetSequenceNum_ = packetNum;
}
StreamTypeField::StreamTypeField(uint8_t field) : field_(field) {}
folly::Optional<StreamTypeField> StreamTypeField::tryStream(uint8_t field) {
if ((field & kStreamFrameMask) == kStreamFrameMask) {
return StreamTypeField(field);
}
return folly::none;
}
bool StreamTypeField::hasDataLength() const {
return field_ & kDataLengthBit;
}
bool StreamTypeField::hasFin() const {
return field_ & kFinBit;
}
bool StreamTypeField::hasOffset() const {
return field_ & kOffsetBit;
}
uint8_t StreamTypeField::fieldValue() const {
return field_;
}
StreamTypeField::Builder& StreamTypeField::Builder::setFin() {
field_ |= StreamTypeField::kFinBit;
return *this;
}
StreamTypeField::Builder& StreamTypeField::Builder::setOffset() {
field_ |= StreamTypeField::kOffsetBit;
return *this;
}
StreamTypeField::Builder& StreamTypeField::Builder::setLength() {
field_ |= StreamTypeField::kDataLengthBit;
return *this;
}
StreamTypeField StreamTypeField::Builder::build() {
return StreamTypeField(field_);
}
std::string toString(PacketNumberSpace pnSpace) {
switch (pnSpace) {
case PacketNumberSpace::Initial:
return "InitialSpace";
case PacketNumberSpace::Handshake:
return "HandshakeSpace";
case PacketNumberSpace::AppData:
return "AppDataSpace";
}
CHECK(false) << "Unknown packet number space";
}
std::string toString(ProtectionType protectionType) {
switch (protectionType) {
case ProtectionType::Initial:
return "Initial";
case ProtectionType::Handshake:
return "Handshake";
case ProtectionType::ZeroRtt:
return "ZeroRtt";
case ProtectionType::KeyPhaseZero:
return "KeyPhaseZero";
case ProtectionType::KeyPhaseOne:
return "KeyPhaseOne";
}
CHECK(false) << "Unknown protection type";
}
} // namespace quic

860
quic/codec/Types.h Normal file
View File

@ -0,0 +1,860 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
// Copyright 2004-present Facebook. All rights reserved.
#pragma once
#include <boost/variant.hpp>
#include <folly/Conv.h>
#include <folly/Optional.h>
#include <folly/Overload.h>
#include <folly/container/Array.h>
#include <folly/io/Cursor.h>
#include <folly/io/IOBuf.h>
#include <quic/QuicConstants.h>
#include <quic/QuicException.h>
#include <quic/codec/QuicConnectionId.h>
#include <quic/codec/QuicInteger.h>
#include <quic/common/IntervalSet.h>
/**
* This details the types of objects that can be serialized or deserialized
* over the wire.
*/
namespace quic {
using Buf = std::unique_ptr<folly::IOBuf>;
using StreamId = uint64_t;
using PacketNum = uint64_t;
enum class PacketNumberSpace : uint8_t {
Initial,
Handshake,
AppData,
};
using StatelessResetToken = std::array<uint8_t, 16>;
constexpr uint8_t kHeaderFormMask = 0x80;
constexpr auto kMaxPacketNumEncodingSize = 4;
struct PaddingFrame {
bool operator==(const PaddingFrame& /*rhs*/) const {
return true;
}
};
struct PingFrame {
PingFrame() = default;
bool operator==(const PingFrame& /*rhs*/) const {
return true;
}
};
/**
* AckBlock represents a series of continuous packet sequences from
* [startPacket, endPacket]
*/
struct AckBlock {
PacketNum startPacket;
PacketNum endPacket;
AckBlock(PacketNum start, PacketNum end)
: startPacket(start), endPacket(end) {}
};
/**
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Largest Acknowledged (i) ...
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| ACK Delay (i) ...
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| ACK Block Count (i) ...
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| ACK Blocks (*) ...
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| First ACK Block (i) ...
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Gap (i) ...
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Additional ACK Block (i) ...
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
*/
struct ReadAckFrame {
PacketNum largestAcked;
std::chrono::microseconds ackDelay;
// Should have at least 1 block.
// These are ordered in descending order by start packet.
std::vector<AckBlock> ackBlocks;
bool operator==(const ReadAckFrame& /*rhs*/) const {
// Can't compare ackBlocks, function is just here to appease compiler.
return false;
}
};
struct WriteAckFrame {
IntervalSet<PacketNum> ackBlocks;
// Delay in sending ack from time that packet was received.
std::chrono::microseconds ackDelay;
bool operator==(const WriteAckFrame& /*rhs*/) const {
// Can't compare ackBlocks, function is just here to appease compiler.
return false;
}
};
struct RstStreamFrame {
StreamId streamId;
ApplicationErrorCode errorCode;
uint64_t offset;
RstStreamFrame(
StreamId streamIdIn,
ApplicationErrorCode errorCodeIn,
uint64_t offsetIn)
: streamId(streamIdIn), errorCode(errorCodeIn), offset(offsetIn) {}
bool operator==(const RstStreamFrame& rhs) const {
return streamId == rhs.streamId && errorCode == rhs.errorCode &&
offset == rhs.offset;
}
};
struct StopSendingFrame {
StreamId streamId;
ApplicationErrorCode errorCode;
StopSendingFrame(StreamId streamIdIn, ApplicationErrorCode errorCodeIn)
: streamId(streamIdIn), errorCode(errorCodeIn) {}
bool operator==(const StopSendingFrame& rhs) const {
return streamId == rhs.streamId && errorCode == rhs.errorCode;
}
};
struct ReadCryptoFrame {
uint64_t offset;
Buf data;
ReadCryptoFrame(uint64_t offsetIn, Buf dataIn)
: offset(offsetIn), data(std::move(dataIn)) {}
explicit ReadCryptoFrame(uint64_t offsetIn)
: offset(offsetIn), data(folly::IOBuf::create(0)) {}
// Stuff stored in a variant type needs to be copyable.
// TODO: can we make this copyable only by the variant, but not
// by anyone else.
ReadCryptoFrame(const ReadCryptoFrame& other) {
offset = other.offset;
if (other.data) {
data = other.data->clone();
}
}
ReadCryptoFrame& operator=(const ReadCryptoFrame& other) {
offset = other.offset;
if (other.data) {
data = other.data->clone();
}
return *this;
}
bool operator==(const ReadCryptoFrame& other) const {
folly::IOBufEqualTo eq;
return offset == other.offset && eq(data, other.data);
}
};
struct WriteCryptoFrame {
uint64_t offset;
uint64_t len;
WriteCryptoFrame(uint64_t offsetIn, uint64_t lenIn)
: offset(offsetIn), len(lenIn) {}
bool operator==(const WriteCryptoFrame& rhs) const {
return offset == rhs.offset && len == rhs.len;
}
};
struct ReadNewTokenFrame {
Buf token;
ReadNewTokenFrame(Buf tokenIn) : token(std::move(tokenIn)) {}
// Stuff stored in a variant type needs to be copyable.
// TODO: can we make this copyable only by the variant, but not
// by anyone else.
ReadNewTokenFrame(const ReadNewTokenFrame& other) {
if (other.token) {
token = other.token->clone();
}
}
ReadNewTokenFrame& operator=(const ReadNewTokenFrame& other) {
if (other.token) {
token = other.token->clone();
}
return *this;
}
bool operator==(const ReadNewTokenFrame& other) const {
folly::IOBufEqualTo eq;
return eq(token, other.token);
}
};
/**
The structure of the stream frame used for writes.
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Stream ID (i) ...
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| [Offset (i)] ...
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| [Length (i)] ...
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Stream Data (*) ...
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
*/
struct WriteStreamFrame {
StreamId streamId;
uint64_t offset;
uint64_t len;
bool fin;
WriteStreamFrame(
StreamId streamIdIn,
uint64_t offsetIn,
uint64_t lenIn,
bool finIn)
: streamId(streamIdIn), offset(offsetIn), len(lenIn), fin(finIn) {}
bool operator==(const WriteStreamFrame& rhs) const {
return streamId == rhs.streamId && offset == rhs.offset && len == rhs.len &&
fin == rhs.fin;
}
};
/**
* The structure of the stream frame used for reads.
*/
struct ReadStreamFrame {
StreamId streamId;
uint64_t offset;
Buf data;
bool fin;
ReadStreamFrame(
StreamId streamIdIn,
uint64_t offsetIn,
Buf dataIn,
bool finIn)
: streamId(streamIdIn),
offset(offsetIn),
data(std::move(dataIn)),
fin(finIn) {}
ReadStreamFrame(StreamId streamIdIn, uint64_t offsetIn, bool finIn)
: streamId(streamIdIn),
offset(offsetIn),
data(folly::IOBuf::create(0)),
fin(finIn) {}
// Stuff stored in a variant type needs to be copyable.
// TODO: can we make this copyable only by the variant, but not
// by anyone else.
ReadStreamFrame(const ReadStreamFrame& other) {
streamId = other.streamId;
offset = other.offset;
if (other.data) {
data = other.data->clone();
}
fin = other.fin;
}
ReadStreamFrame& operator=(const ReadStreamFrame& other) {
streamId = other.streamId;
offset = other.offset;
if (other.data) {
data = other.data->clone();
}
fin = other.fin;
return *this;
}
bool operator==(const ReadStreamFrame& other) const {
folly::IOBufEqualTo eq;
return streamId == other.streamId && offset == other.offset &&
fin == other.fin && eq(data, other.data);
}
};
struct MaxDataFrame {
uint64_t maximumData;
explicit MaxDataFrame(uint64_t maximumDataIn) : maximumData(maximumDataIn) {}
bool operator==(const MaxDataFrame& rhs) const {
return maximumData == rhs.maximumData;
}
};
struct MaxStreamDataFrame {
StreamId streamId;
uint64_t maximumData;
MaxStreamDataFrame(StreamId streamIdIn, uint64_t maximumDataIn)
: streamId(streamIdIn), maximumData(maximumDataIn) {}
bool operator==(const MaxStreamDataFrame& rhs) const {
return streamId == rhs.streamId && maximumData == rhs.maximumData;
}
};
// The MinStreamDataFrame is used by a receiver to inform
// a sender of the maximum amount of data that can be sent on a stream
// (like MAX_STREAM_DATA frame) and to request an update to the minimum
// retransmittable offset for this stream.
struct MinStreamDataFrame {
StreamId streamId;
uint64_t maximumData;
uint64_t minimumStreamOffset;
MinStreamDataFrame(
StreamId streamIdIn,
uint64_t maximumDataIn,
uint64_t minimumStreamOffsetIn)
: streamId(streamIdIn),
maximumData(maximumDataIn),
minimumStreamOffset(minimumStreamOffsetIn) {}
bool operator==(const MinStreamDataFrame& rhs) const {
return streamId == rhs.streamId && maximumData == rhs.maximumData &&
minimumStreamOffset == rhs.minimumStreamOffset;
}
};
// The ExpiredStreamDataFrame is used by a sender to
// inform a receiver of the minimum retransmittable offset for a stream.
struct ExpiredStreamDataFrame {
StreamId streamId;
uint64_t minimumStreamOffset;
ExpiredStreamDataFrame(StreamId streamIdIn, uint64_t minimumStreamOffsetIn)
: streamId(streamIdIn), minimumStreamOffset(minimumStreamOffsetIn) {}
bool operator==(const ExpiredStreamDataFrame& rhs) const {
return streamId == rhs.streamId &&
minimumStreamOffset == rhs.minimumStreamOffset;
}
};
struct MaxStreamsFrame {
// A count of the cumulative number of streams
uint64_t maxStreams;
bool isForBidirectional{false};
explicit MaxStreamsFrame(uint64_t maxStreamsIn, bool isBidirectionalIn)
: maxStreams(maxStreamsIn), isForBidirectional(isBidirectionalIn) {}
bool isForBidirectionalStream() const {
return isForBidirectional;
}
bool isForUnidirectionalStream() {
return !isForBidirectional;
}
bool operator==(const MaxStreamsFrame& rhs) const {
return maxStreams == rhs.maxStreams &&
isForBidirectional == rhs.isForBidirectional;
}
};
struct DataBlockedFrame {
// the connection-level limit at which blocking occurred
uint64_t dataLimit;
explicit DataBlockedFrame(uint64_t dataLimitIn) : dataLimit(dataLimitIn) {}
bool operator==(const DataBlockedFrame& rhs) const {
return dataLimit == rhs.dataLimit;
}
};
struct StreamDataBlockedFrame {
StreamId streamId;
uint64_t dataLimit;
StreamDataBlockedFrame(StreamId streamIdIn, uint64_t dataLimitIn)
: streamId(streamIdIn), dataLimit(dataLimitIn) {}
bool operator==(const StreamDataBlockedFrame& rhs) const {
return streamId == rhs.streamId && dataLimit == rhs.dataLimit;
}
};
struct StreamsBlockedFrame {
uint64_t streamLimit;
bool isForBidirectional{false};
explicit StreamsBlockedFrame(uint64_t streamLimitIn, bool isBidirectionalIn)
: streamLimit(streamLimitIn), isForBidirectional(isBidirectionalIn) {}
bool isForBidirectionalStream() const {
return isForBidirectional;
}
bool isForUnidirectionalStream() const {
return !isForBidirectional;
}
bool operator==(const StreamsBlockedFrame& rhs) const {
return streamLimit == rhs.streamLimit;
}
};
struct NewConnectionIdFrame {
uint16_t sequence;
ConnectionId connectionId;
StatelessResetToken token;
NewConnectionIdFrame(
uint16_t sequenceIn,
ConnectionId connectionIdIn,
StatelessResetToken tokenIn)
: sequence(sequenceIn),
connectionId(connectionIdIn),
token(std::move(tokenIn)) {}
bool operator==(const NewConnectionIdFrame& rhs) const {
return sequence == rhs.sequence && connectionId == rhs.connectionId &&
token == rhs.token;
}
};
struct RetireConnectionIdFrame {
uint64_t sequenceId;
explicit RetireConnectionIdFrame(uint64_t sequenceIn)
: sequenceId(sequenceIn) {}
};
struct PathChallengeFrame {
uint64_t pathData;
explicit PathChallengeFrame(uint64_t pathDataIn) : pathData(pathDataIn) {}
bool operator==(const PathChallengeFrame& rhs) const {
return pathData == rhs.pathData;
}
bool operator!=(const PathChallengeFrame& rhs) const {
return !(*this == rhs);
}
};
struct PathResponseFrame {
uint64_t pathData;
explicit PathResponseFrame(uint64_t pathDataIn) : pathData(pathDataIn) {}
bool operator==(const PathResponseFrame& rhs) const {
return pathData == rhs.pathData;
}
};
struct ConnectionCloseFrame {
// Members are not const to allow this to be movable.
TransportErrorCode errorCode;
std::string reasonPhrase;
// Per QUIC specification: type of frame that triggered the (close) error.
// A value of 0 (PADDING frame) implies the frame type is unknown
FrameType closingFrameType;
ConnectionCloseFrame(
TransportErrorCode errorCodeIn,
std::string reasonPhraseIn,
FrameType closingFrameTypeIn = FrameType::PADDING)
: errorCode(errorCodeIn),
reasonPhrase(std::move(reasonPhraseIn)),
closingFrameType(closingFrameTypeIn) {}
FrameType getClosingFrameType() const noexcept {
return closingFrameType;
}
bool operator==(const ConnectionCloseFrame& rhs) const {
return errorCode == rhs.errorCode && reasonPhrase == rhs.reasonPhrase;
}
};
struct ApplicationCloseFrame {
// Members are not const to allow this to be movable.
ApplicationErrorCode errorCode;
std::string reasonPhrase;
ApplicationCloseFrame(
ApplicationErrorCode errorCodeIn,
std::string reasonPhraseIn)
: errorCode(errorCodeIn), reasonPhrase(std::move(reasonPhraseIn)) {}
bool operator==(const ApplicationCloseFrame& rhs) const {
return errorCode == rhs.errorCode && reasonPhrase == rhs.reasonPhrase;
}
};
// Frame to represent ones we skip
struct NoopFrame {};
constexpr uint8_t kStatelessResetTokenLength = 16;
using StatelessResetToken = std::array<uint8_t, kStatelessResetTokenLength>;
struct StatelessReset {
StatelessResetToken token;
explicit StatelessReset(StatelessResetToken tokenIn)
: token(std::move(tokenIn)) {}
};
using QuicSimpleFrame = boost::variant<
StopSendingFrame,
MinStreamDataFrame,
ExpiredStreamDataFrame,
PathChallengeFrame,
PathResponseFrame>;
// Types of frames that can be read.
using QuicFrame = boost::variant<
PaddingFrame,
RstStreamFrame,
ConnectionCloseFrame,
ApplicationCloseFrame,
MaxDataFrame,
MaxStreamDataFrame,
MaxStreamsFrame,
PingFrame,
DataBlockedFrame,
StreamDataBlockedFrame,
StreamsBlockedFrame,
NewConnectionIdFrame,
ReadAckFrame,
ReadStreamFrame,
ReadCryptoFrame,
ReadNewTokenFrame,
QuicSimpleFrame,
NoopFrame>;
// Types of frames which are written.
using QuicWriteFrame = boost::variant<
PaddingFrame,
RstStreamFrame,
ConnectionCloseFrame,
ApplicationCloseFrame,
MaxDataFrame,
MaxStreamDataFrame,
MaxStreamsFrame,
StreamsBlockedFrame,
PingFrame,
DataBlockedFrame,
StreamDataBlockedFrame,
NewConnectionIdFrame,
WriteAckFrame,
WriteStreamFrame,
WriteCryptoFrame,
QuicSimpleFrame>;
enum class HeaderForm : bool {
Long = 1,
Short = 0,
};
enum class ProtectionType {
Initial,
Handshake,
ZeroRtt,
KeyPhaseZero,
KeyPhaseOne,
};
struct LongHeaderInvariant {
QuicVersion version;
ConnectionId srcConnId;
ConnectionId dstConnId;
LongHeaderInvariant(QuicVersion ver, ConnectionId scid, ConnectionId dcid);
};
// TODO: split this into read and write types.
struct LongHeader {
public:
static constexpr uint8_t kFixedBitMask = 0x40;
static constexpr uint8_t kPacketTypeMask = 0x30;
static constexpr uint8_t kReservedBitsMask = 0x0c;
static constexpr uint8_t kPacketNumLenMask = 0x03;
static constexpr uint8_t kTypeBitsMask = 0x0F;
static constexpr uint8_t kTypeShift = 4;
enum class Types : uint8_t {
Initial = 0x0,
ZeroRtt = 0x1,
Handshake = 0x2,
Retry = 0x3,
};
LongHeader(
Types type,
const ConnectionId& srcConnId,
const ConnectionId& dstConnId,
PacketNum packetNum,
QuicVersion version,
Buf token = nullptr,
folly::Optional<ConnectionId> originalDstConnId = folly::none);
LongHeader(
Types type,
LongHeaderInvariant invariant,
Buf token = nullptr,
folly::Optional<ConnectionId> originalDstConnId = folly::none);
void setPacketNumber(PacketNum packetNum);
// Stuff stored in a variant type needs to be copyable.
// TODO: can we make this copyable only by the variant, but not
// by anyone else.
LongHeader(const LongHeader& other);
LongHeader& operator=(const LongHeader& other);
Types getHeaderType() const noexcept;
const ConnectionId& getSourceConnId() const;
const ConnectionId& getDestinationConnId() const;
const folly::Optional<ConnectionId>& getOriginalDstConnId() const;
PacketNum getPacketSequenceNum() const;
QuicVersion getVersion() const;
ProtectionType getProtectionType() const;
PacketNumberSpace getPacketNumberSpace() const noexcept;
bool hasToken() const;
folly::IOBuf* getToken() const;
private:
HeaderForm headerForm_;
Types longHeaderType_;
LongHeaderInvariant invariant_;
folly::Optional<PacketNum> packetSequenceNum_; // at most 32 bits on wire
Buf token_;
folly::Optional<ConnectionId> originalDstConnId_;
};
struct ShortHeaderInvariant {
ConnectionId destinationConnId;
explicit ShortHeaderInvariant(ConnectionId dcid);
};
struct ShortHeader {
public:
static constexpr uint8_t kFixedBitMask = 0x40;
// There is also a spin bit which is 0x20, but as a decent implementation of
// course we don't implement that.
static constexpr uint8_t kReservedBitsMask = 0x18;
static constexpr uint8_t kKeyPhaseMask = 0x04;
static constexpr uint8_t kPacketNumLenMask = 0x03;
static constexpr uint8_t kTypeBitsMask = 0x1F;
/**
* The constructor for reading a packet.
*/
ShortHeader(ProtectionType protectionType, ConnectionId connId);
/**
* The constructor for writing a packet.
*/
ShortHeader(
ProtectionType protectionType,
ConnectionId connId,
PacketNum packetNum);
ProtectionType getProtectionType() const noexcept;
PacketNumberSpace getPacketNumberSpace() const noexcept;
const ConnectionId& getConnectionId() const;
PacketNum getPacketSequenceNum() const;
void setPacketNumber(PacketNum packetNum);
private:
ShortHeader() = default;
bool readInitialByte(uint8_t initalByte);
bool readConnectionId(folly::io::Cursor& cursor);
bool readPacketNum(
PacketNum largestReceivedPacketNum,
folly::io::Cursor& cursor);
private:
HeaderForm headerForm_;
ProtectionType protectionType_;
ConnectionId connectionId_;
folly::Optional<PacketNum> packetSequenceNum_; // var-size 8/16/24/32 bits
};
ProtectionType longHeaderTypeToProtectionType(LongHeader::Types type);
PacketNumberSpace longHeaderTypeToPacketNumberSpace(LongHeader::Types type);
using PacketHeader = boost::variant<LongHeader, ShortHeader>;
struct StreamTypeField {
public:
/**
* Returns a StreamTypeField if the field is a stream type.
*/
static folly::Optional<StreamTypeField> tryStream(uint8_t field);
bool hasFin() const;
bool hasDataLength() const;
bool hasOffset() const;
uint8_t fieldValue() const;
struct Builder {
public:
Builder& setFin();
Builder& setOffset();
Builder& setLength();
StreamTypeField build();
private:
uint8_t field_{kStreamFrameMask};
};
private:
static constexpr uint8_t kStreamFrameMask = 0x08;
// Stream Frame specific:
static constexpr uint8_t kFinBit = 0x01;
static constexpr uint8_t kDataLengthBit = 0x02;
static constexpr uint8_t kOffsetBit = 0x04;
explicit StreamTypeField(uint8_t field);
uint8_t field_;
};
struct VersionNegotiationPacket {
uint8_t packetType;
ConnectionId sourceConnectionId;
ConnectionId destinationConnectionId;
std::vector<QuicVersion> versions;
VersionNegotiationPacket(
uint8_t packetTypeIn,
ConnectionId sourceConnectionIdIn,
ConnectionId destinationConnectionIdIn)
: packetType(packetTypeIn),
sourceConnectionId(sourceConnectionIdIn),
destinationConnectionId(destinationConnectionIdIn) {}
};
/**
* Common struct for regular read and write packets.
*/
struct RegularPacket {
PacketHeader header;
explicit RegularPacket(PacketHeader&& headerIn)
: header(std::move(headerIn)) {}
};
/**
* A representation of a regular packet that is read from the network.
* This could be either Cleartext or Encrypted packets in long or short form.
* Cleartext packets include Client Initial, Client Cleartext, Non-Final Server
* Cleartext packet or Final Server Cleartext packet. Encrypted packets
* include 0-RTT, 1-RTT Phase 0 and 1-RTT Phase 1 packets.
*/
struct RegularQuicPacket : public RegularPacket {
std::vector<QuicFrame> frames;
explicit RegularQuicPacket(PacketHeader&& headerIn)
: RegularPacket(std::move(headerIn)) {}
};
/**
* A representation of a regular packet that is written to the network.
*/
struct RegularQuicWritePacket : public RegularPacket {
std::vector<QuicWriteFrame> frames;
explicit RegularQuicWritePacket(PacketHeader&& headerIn)
: RegularPacket(std::move(headerIn)) {}
};
using QuicPacket = boost::variant<RegularQuicPacket, VersionNegotiationPacket>;
using QuicWritePacket =
boost::variant<RegularQuicWritePacket, VersionNegotiationPacket>;
/**
* Returns whether the header is long or short from the initial byte of
* the QUIC packet.
*
* This function is version invariant.
*/
HeaderForm getHeaderForm(uint8_t headerValue);
inline std::ostream& operator<<(
std::ostream& os,
const LongHeader::Types& type) {
switch (type) {
case LongHeader::Types::Initial:
os << "Initial";
break;
case LongHeader::Types::Retry:
os << "Retry";
break;
case LongHeader::Types::Handshake:
os << "Handshake";
break;
case LongHeader::Types::ZeroRtt:
os << "ZeroRtt";
break;
}
return os;
}
inline std::ostream& operator<<(std::ostream& os, const PacketHeader& header) {
folly::variant_match(
header,
[&os](const LongHeader& h) {
os << "header=long"
<< " protectionType=" << (int)h.getProtectionType()
<< " type=" << std::hex << (int)h.getHeaderType();
},
[&os](const ShortHeader& h) {
os << "header=short"
<< " protectionType=" << (int)h.getProtectionType();
});
return os;
}
std::string toString(PacketNumberSpace pnSpace);
inline std::ostream& operator<<(std::ostream& os, PacketNumberSpace pnSpace) {
return os << toString(pnSpace);
}
std::string toString(ProtectionType protectionType);
} // namespace quic

View File

@ -0,0 +1,149 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
if(NOT BUILD_TESTS)
return()
endif()
add_library(mvfst_mock_codec STATIC
Mocks.h
)
target_include_directories(
mvfst_mock_codec PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>
)
add_dependencies(
mvfst_mock_codec
mvfst_codec_pktbuilder
mvfst_test_utils
)
target_link_libraries(
mvfst_mock_codec PUBLIC
Folly::folly
mvfst_codec_pktbuilder
mvfst_test_utils
)
quic_add_test(TARGET QuicHeaderCodecTest
SOURCES
QuicHeaderCodecTest.cpp
DEPENDS
Folly::folly
mvfst_exception
mvfst_codec
mvfst_test_utils
${LIBGTEST_LIBRARY}
)
quic_add_test(TARGET QuicReadCodecTest
SOURCES
QuicReadCodecTest.cpp
DEPENDS
Folly::folly
mvfst_codec
mvfst_exception
mvfst_test_utils
${LIBGTEST_LIBRARY}
)
quic_add_test(TARGET QuicWriteCodecTest
SOURCES
QuicWriteCodecTest.cpp
DEPENDS
Folly::folly
mvfst_codec
mvfst_codec_decode
mvfst_codec_types
mvfst_exception
mvfst_mock_codec
mvfst_test_utils
${LIBGTEST_LIBRARY}
)
quic_add_test(TARGET QuicTypesTest
SOURCES
TypesTest.cpp
DEPENDS
Folly::folly
mvfst_codec_decode
mvfst_codec_types
mvfst_exception
mvfst_test_utils
${LIBGTEST_LIBRARY}
)
quic_add_test(TARGET PacketNumberTest
SOURCES
PacketNumberTest.cpp
DEPENDS
Folly::folly
mvfst_codec_types
${LIBGTEST_LIBRARY}
)
quic_add_test(TARGET PacketNumberCipherTest
SOURCES
PacketNumberCipherTest.cpp
DEPENDS
Folly::folly
mvfst_codec_packet_number_cipher
${LIBGTEST_LIBRARY}
)
quic_add_test(TARGET DecodeTest
SOURCES
DecodeTest.cpp
DEPENDS
Folly::folly
mvfst_codec
mvfst_codec_decode
mvfst_codec_types
mvfst_exception
mvfst_test_utils
${LIBGTEST_LIBRARY}
)
quic_add_test(TARGET QuicConnectionIdTest
SOURCES
QuicConnectionIdTest.cpp
DEPENDS
Folly::folly
mvfst_codec_types
${LIBGTEST_LIBRARY}
)
quic_add_test(TARGET QuicPacketBuilderTest
SOURCES
QuicPacketBuilderTest.cpp
QuicPacketRebuilderTest.cpp
DEPENDS
Folly::folly
mvfst_codec
mvfst_codec_pktbuilder
mvfst_codec_pktrebuilder
mvfst_codec_types
mvfst_handshake
mvfst_mock_codec
mvfst_server
mvfst_state_functions
mvfst_state_machine
mvfst_state_stream_functions
mvfst_test_utils
${LIBGTEST_LIBRARY}
)
quic_add_test(TARGET QuicIntegerTest
SOURCES
QuicIntegerTest.cpp
DEPENDS
Folly::folly
mvfst_codec_types
mvfst_exception
${LIBGTEST_LIBRARY}
)

View File

@ -0,0 +1,675 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
// Copyright 2004-present Facebook. All rights reserved.
#include <quic/codec/Decode.h>
#include <folly/Random.h>
#include <folly/container/Array.h>
#include <folly/io/IOBuf.h>
#include <folly/portability/GTest.h>
#include <quic/codec/QuicReadCodec.h>
#include <quic/codec/Types.h>
#include <quic/common/test/TestUtils.h>
using namespace quic;
using namespace testing;
namespace quic {
namespace test {
using UnderlyingFrameType = std::underlying_type<FrameType>::type;
class DecodeTest : public Test {};
ShortHeader makeHeader() {
PacketNum packetNum = 100;
return ShortHeader(
ProtectionType::KeyPhaseZero, getTestConnectionId(), packetNum);
}
// NormalizedAckBlocks are in order needed.
struct NormalizedAckBlock {
QuicInteger gap; // Gap to previous AckBlock
QuicInteger blockLen;
NormalizedAckBlock(QuicInteger gapIn, QuicInteger blockLenIn)
: gap(gapIn), blockLen(blockLenIn) {}
};
template <class LargestAckedType = uint64_t>
std::unique_ptr<folly::IOBuf> createAckFrame(
folly::Optional<QuicInteger> largestAcked,
folly::Optional<QuicInteger> ackDelay = folly::none,
folly::Optional<QuicInteger> numAdditionalBlocks = folly::none,
folly::Optional<QuicInteger> firstAckBlockLength = folly::none,
std::vector<NormalizedAckBlock> ackBlocks = {},
bool useRealValuesForLargestAcked = false,
bool useRealValuesForAckDelay = false) {
folly::IOBufQueue ackFrame;
folly::io::QueueAppender wcursor(&ackFrame, 10);
if (largestAcked) {
if (useRealValuesForLargestAcked) {
wcursor.writeBE<LargestAckedType>(largestAcked->getValue());
} else {
largestAcked->encode(wcursor);
}
}
if (ackDelay) {
if (useRealValuesForAckDelay) {
wcursor.writeBE(ackDelay->getValue());
} else {
ackDelay->encode(wcursor);
}
}
if (numAdditionalBlocks) {
numAdditionalBlocks->encode(wcursor);
}
if (firstAckBlockLength) {
firstAckBlockLength->encode(wcursor);
}
for (size_t i = 0; i < ackBlocks.size(); ++i) {
ackBlocks[i].gap.encode(wcursor);
ackBlocks[i].blockLen.encode(wcursor);
}
return ackFrame.move();
}
template <class StreamIdType = StreamId>
std::unique_ptr<folly::IOBuf> createStreamFrame(
folly::Optional<QuicInteger> streamId,
folly::Optional<QuicInteger> offset = folly::none,
folly::Optional<QuicInteger> dataLength = folly::none,
Buf data = nullptr,
bool useRealValuesForStreamId = false) {
folly::IOBufQueue streamFrame;
folly::io::QueueAppender wcursor(&streamFrame, 10);
if (streamId) {
if (useRealValuesForStreamId) {
wcursor.writeBE<StreamIdType>(streamId->getValue());
} else {
streamId->encode(wcursor);
}
}
if (offset) {
offset->encode(wcursor);
}
if (dataLength) {
dataLength->encode(wcursor);
}
if (data) {
wcursor.insert(std::move(data));
}
return streamFrame.move();
}
std::unique_ptr<folly::IOBuf> createCryptoFrame(
folly::Optional<QuicInteger> offset = folly::none,
folly::Optional<QuicInteger> dataLength = folly::none,
Buf data = nullptr) {
folly::IOBufQueue cryptoFrame;
folly::io::QueueAppender wcursor(&cryptoFrame, 10);
if (offset) {
offset->encode(wcursor);
}
if (dataLength) {
dataLength->encode(wcursor);
}
if (data) {
wcursor.insert(std::move(data));
}
return cryptoFrame.move();
}
TEST_F(DecodeTest, VersionNegotiationPacketDecodeTest) {
ConnectionId srcCid = getTestConnectionId(0),
destCid = getTestConnectionId(1);
std::vector<QuicVersion> versions{{static_cast<QuicVersion>(1234),
static_cast<QuicVersion>(4321),
static_cast<QuicVersion>(2341),
static_cast<QuicVersion>(3412),
static_cast<QuicVersion>(4123)}};
auto packet =
VersionNegotiationPacketBuilder(srcCid, destCid, versions).buildPacket();
auto codec = std::make_unique<QuicReadCodec>(QuicNodeType::Server);
AckStates ackStates;
auto packetQueue = bufToQueue(std::move(packet.second));
auto quicPacket = boost::get<QuicPacket>(
codec->parsePacket(packetQueue, ackStates));
auto versionPacket = boost::get<VersionNegotiationPacket>(quicPacket);
EXPECT_EQ(versionPacket.destinationConnectionId, destCid);
EXPECT_EQ(versionPacket.sourceConnectionId, srcCid);
EXPECT_EQ(versionPacket.versions.size(), versions.size());
EXPECT_EQ(versionPacket.versions, versions);
}
TEST_F(DecodeTest, DifferentCIDLength) {
ConnectionId sourceConnectionId = getTestConnectionId();
ConnectionId destinationConnectionId({1, 2, 3, 4, 5, 6});
std::vector<QuicVersion> versions{{static_cast<QuicVersion>(1234),
static_cast<QuicVersion>(4321),
static_cast<QuicVersion>(2341),
static_cast<QuicVersion>(3412),
static_cast<QuicVersion>(4123)}};
auto packet = VersionNegotiationPacketBuilder(
sourceConnectionId, destinationConnectionId, versions)
.buildPacket();
auto codec = std::make_unique<QuicReadCodec>(QuicNodeType::Server);
AckStates ackStates;
auto packetQueue = bufToQueue(std::move(packet.second));
auto quicPacket = boost::get<QuicPacket>(
codec->parsePacket(packetQueue, ackStates));
auto versionPacket = boost::get<VersionNegotiationPacket>(quicPacket);
EXPECT_EQ(versionPacket.sourceConnectionId, sourceConnectionId);
EXPECT_EQ(versionPacket.destinationConnectionId, destinationConnectionId);
EXPECT_EQ(versionPacket.versions.size(), versions.size());
EXPECT_EQ(versionPacket.versions, versions);
}
TEST_F(DecodeTest, VersionNegotiationPacketBadPacketTest) {
ConnectionId connId = getTestConnectionId();
QuicVersionType version = static_cast<QuicVersionType>(QuicVersion::MVFST);
auto buf = folly::IOBuf::create(10);
folly::io::Appender appender(buf.get(), 10);
appender.writeBE<uint8_t>(kHeaderFormMask);
appender.push(connId.data(), connId.size());
appender.writeBE<QuicVersionType>(
static_cast<QuicVersionType>(QuicVersion::VERSION_NEGOTIATION));
appender.push((uint8_t*)&version, sizeof(QuicVersion) - 1);
auto codec = std::make_unique<QuicReadCodec>(QuicNodeType::Server);
AckStates ackStates;
auto packetQueue = bufToQueue(std::move(buf));
auto packet = codec->parsePacket(packetQueue, ackStates);
EXPECT_THROW(boost::get<QuicPacket>(packet), boost::bad_get);
buf = folly::IOBuf::create(0);
packetQueue = bufToQueue(std::move(buf));
packet = codec->parsePacket(packetQueue, ackStates);
// Packet with empty versions
EXPECT_THROW(boost::get<QuicPacket>(packet), boost::bad_get);
}
TEST_F(DecodeTest, ValidAckFrame) {
QuicInteger largestAcked(1000);
QuicInteger ackDelay(100);
QuicInteger numAdditionalBlocks(1);
QuicInteger firstAckBlockLength(10);
std::vector<NormalizedAckBlock> ackBlocks;
ackBlocks.emplace_back(QuicInteger(10), QuicInteger(10));
auto result = createAckFrame(
largestAcked,
ackDelay,
numAdditionalBlocks,
firstAckBlockLength,
ackBlocks);
folly::io::Cursor cursor(result.get());
auto ackFrame = decodeAckFrame(
cursor, makeHeader(), CodecParameters(kDefaultAckDelayExponent));
EXPECT_EQ(ackFrame.ackBlocks.size(), 2);
EXPECT_EQ(ackFrame.largestAcked, 1000);
// Since 100 is the encoded value, we use the decoded value.
EXPECT_EQ(ackFrame.ackDelay.count(), 100 << kDefaultAckDelayExponent);
}
TEST_F(DecodeTest, AckFrameLargestAckExceedsRange) {
// An integer larger than the representable range of quic integer.
QuicInteger largestAcked(std::numeric_limits<uint64_t>::max());
QuicInteger ackDelay(10);
QuicInteger numAdditionalBlocks(0);
QuicInteger firstAckBlockLength(10);
auto result = createAckFrame(
largestAcked,
ackDelay,
numAdditionalBlocks,
firstAckBlockLength,
{},
true);
folly::io::Cursor cursor(result.get());
auto ackFrame = decodeAckFrame(
cursor, makeHeader(), CodecParameters(kDefaultAckDelayExponent));
// it will interpret this as a 8 byte range with the max value.
EXPECT_EQ(ackFrame.largestAcked, 4611686018427387903);
}
TEST_F(DecodeTest, AckFrameLargestAckInvalid) {
// An integer larger than the representable range of quic integer.
QuicInteger largestAcked(std::numeric_limits<uint64_t>::max());
QuicInteger ackDelay(10);
QuicInteger numAdditionalBlocks(0);
QuicInteger firstAckBlockLength(10);
auto result = createAckFrame<uint8_t>(
largestAcked,
ackDelay,
numAdditionalBlocks,
firstAckBlockLength,
{},
true);
folly::io::Cursor cursor(result.get());
EXPECT_THROW(
decodeAckFrame(
cursor, makeHeader(), CodecParameters(kDefaultAckDelayExponent)),
QuicTransportException);
}
TEST_F(DecodeTest, AckFrameDelayEncodingInvalid) {
QuicInteger largestAcked(1000);
// Maximal representable value by quic integer.
QuicInteger ackDelay(4611686018427387903);
QuicInteger numAdditionalBlocks(0);
QuicInteger firstAckBlockLength(10);
auto result = createAckFrame(
largestAcked,
ackDelay,
numAdditionalBlocks,
firstAckBlockLength,
{},
false,
true);
folly::io::Cursor cursor(result.get());
EXPECT_THROW(
decodeAckFrame(
cursor, makeHeader(), CodecParameters(kDefaultAckDelayExponent)),
QuicTransportException);
}
TEST_F(DecodeTest, AckFrameDelayExceedsRange) {
QuicInteger largestAcked(1000);
// Maximal representable value by quic integer.
QuicInteger ackDelay(4611686018427387903);
QuicInteger numAdditionalBlocks(0);
QuicInteger firstAckBlockLength(10);
auto result = createAckFrame(
largestAcked, ackDelay, numAdditionalBlocks, firstAckBlockLength);
folly::io::Cursor cursor(result.get());
EXPECT_THROW(
decodeAckFrame(
cursor, makeHeader(), CodecParameters(kDefaultAckDelayExponent)),
QuicTransportException);
}
TEST_F(DecodeTest, AckFrameAdditionalBlocksUnderflow) {
QuicInteger largestAcked(1000);
QuicInteger ackDelay(100);
QuicInteger numAdditionalBlocks(2);
QuicInteger firstAckBlockLength(10);
std::vector<NormalizedAckBlock> ackBlocks;
ackBlocks.emplace_back(QuicInteger(10), QuicInteger(10));
auto result = createAckFrame(
largestAcked,
ackDelay,
numAdditionalBlocks,
firstAckBlockLength,
ackBlocks);
folly::io::Cursor cursor(result.get());
EXPECT_THROW(
decodeAckFrame(
cursor, makeHeader(), CodecParameters(kDefaultAckDelayExponent)),
QuicTransportException);
}
TEST_F(DecodeTest, AckFrameAdditionalBlocksOverflow) {
QuicInteger largestAcked(1000);
QuicInteger ackDelay(100);
QuicInteger numAdditionalBlocks(2);
QuicInteger firstAckBlockLength(10);
std::vector<NormalizedAckBlock> ackBlocks;
ackBlocks.emplace_back(QuicInteger(10), QuicInteger(10));
ackBlocks.emplace_back(QuicInteger(10), QuicInteger(10));
ackBlocks.emplace_back(QuicInteger(10), QuicInteger(10));
auto result = createAckFrame(
largestAcked,
ackDelay,
numAdditionalBlocks,
firstAckBlockLength,
ackBlocks);
folly::io::Cursor cursor(result.get());
decodeAckFrame(
cursor, makeHeader(), CodecParameters(kDefaultAckDelayExponent));
}
TEST_F(DecodeTest, AckFrameMissingFields) {
QuicInteger largestAcked(1000);
QuicInteger ackDelay(100);
QuicInteger numAdditionalBlocks(2);
QuicInteger firstAckBlockLength(10);
std::vector<NormalizedAckBlock> ackBlocks;
ackBlocks.emplace_back(QuicInteger(10), QuicInteger(10));
ackBlocks.emplace_back(QuicInteger(10), QuicInteger(10));
auto header = makeHeader();
auto result1 = createAckFrame(
largestAcked,
folly::none,
numAdditionalBlocks,
firstAckBlockLength,
ackBlocks);
folly::io::Cursor cursor1(result1.get());
EXPECT_THROW(
decodeAckFrame(
cursor1, header, CodecParameters(kDefaultAckDelayExponent)),
QuicTransportException);
auto result2 = createAckFrame(
largestAcked, ackDelay, folly::none, firstAckBlockLength, ackBlocks);
folly::io::Cursor cursor2(result2.get());
EXPECT_THROW(
decodeAckFrame(
cursor2, header, CodecParameters(kDefaultAckDelayExponent)),
QuicTransportException);
auto result3 = createAckFrame(
largestAcked, ackDelay, folly::none, firstAckBlockLength, ackBlocks);
folly::io::Cursor cursor3(result3.get());
EXPECT_THROW(
decodeAckFrame(
cursor3, header, CodecParameters(kDefaultAckDelayExponent)),
QuicTransportException);
auto result4 = createAckFrame(
largestAcked, ackDelay, numAdditionalBlocks, folly::none, ackBlocks);
folly::io::Cursor cursor4(result4.get());
EXPECT_THROW(
decodeAckFrame(
cursor4, header, CodecParameters(kDefaultAckDelayExponent)),
QuicTransportException);
auto result5 = createAckFrame(
largestAcked, ackDelay, numAdditionalBlocks, firstAckBlockLength, {});
folly::io::Cursor cursor5(result5.get());
EXPECT_THROW(
decodeAckFrame(
cursor5, header, CodecParameters(kDefaultAckDelayExponent)),
QuicTransportException);
}
TEST_F(DecodeTest, AckFrameFirstBlockLengthInvalid) {
QuicInteger largestAcked(1000);
QuicInteger ackDelay(100);
QuicInteger numAdditionalBlocks(0);
QuicInteger firstAckBlockLength(2000);
auto result = createAckFrame(
largestAcked, ackDelay, numAdditionalBlocks, firstAckBlockLength);
folly::io::Cursor cursor(result.get());
EXPECT_THROW(
decodeAckFrame(
cursor, makeHeader(), CodecParameters(kDefaultAckDelayExponent)),
QuicTransportException);
}
TEST_F(DecodeTest, AckFrameBlockLengthInvalid) {
QuicInteger largestAcked(1000);
QuicInteger ackDelay(100);
QuicInteger numAdditionalBlocks(2);
QuicInteger firstAckBlockLength(10);
std::vector<NormalizedAckBlock> ackBlocks;
ackBlocks.emplace_back(QuicInteger(10), QuicInteger(10));
ackBlocks.emplace_back(QuicInteger(10), QuicInteger(1000));
auto result = createAckFrame(
largestAcked,
ackDelay,
numAdditionalBlocks,
firstAckBlockLength,
ackBlocks);
folly::io::Cursor cursor(result.get());
EXPECT_THROW(
decodeAckFrame(
cursor, makeHeader(), CodecParameters(kDefaultAckDelayExponent)),
QuicTransportException);
}
TEST_F(DecodeTest, AckFrameBlockGapInvalid) {
QuicInteger largestAcked(1000);
QuicInteger ackDelay(100);
QuicInteger numAdditionalBlocks(2);
QuicInteger firstAckBlockLength(10);
std::vector<NormalizedAckBlock> ackBlocks;
ackBlocks.emplace_back(QuicInteger(10), QuicInteger(10));
ackBlocks.emplace_back(QuicInteger(1000), QuicInteger(0));
auto result = createAckFrame(
largestAcked,
ackDelay,
numAdditionalBlocks,
firstAckBlockLength,
ackBlocks);
folly::io::Cursor cursor(result.get());
EXPECT_THROW(
decodeAckFrame(
cursor, makeHeader(), CodecParameters(kDefaultAckDelayExponent)),
QuicTransportException);
}
TEST_F(DecodeTest, AckFrameBlockLengthZero) {
QuicInteger largestAcked(1000);
QuicInteger ackDelay(100);
QuicInteger numAdditionalBlocks(3);
QuicInteger firstAckBlockLength(10);
std::vector<NormalizedAckBlock> ackBlocks;
ackBlocks.emplace_back(QuicInteger(10), QuicInteger(10));
ackBlocks.emplace_back(QuicInteger(10), QuicInteger(0));
ackBlocks.emplace_back(QuicInteger(0), QuicInteger(10));
auto result = createAckFrame(
largestAcked,
ackDelay,
numAdditionalBlocks,
firstAckBlockLength,
ackBlocks);
folly::io::Cursor cursor(result.get());
auto readAckFrame = decodeAckFrame(
cursor, makeHeader(), CodecParameters(kDefaultAckDelayExponent));
EXPECT_EQ(readAckFrame.ackBlocks[0].endPacket, 1000);
EXPECT_EQ(readAckFrame.ackBlocks[0].startPacket, 990);
EXPECT_EQ(readAckFrame.ackBlocks[1].endPacket, 978);
EXPECT_EQ(readAckFrame.ackBlocks[1].startPacket, 968);
EXPECT_EQ(readAckFrame.ackBlocks[2].endPacket, 956);
EXPECT_EQ(readAckFrame.ackBlocks[2].startPacket, 956);
EXPECT_EQ(readAckFrame.ackBlocks[3].endPacket, 954);
EXPECT_EQ(readAckFrame.ackBlocks[3].startPacket, 944);
}
TEST_F(DecodeTest, StreamDecodeSuccess) {
QuicInteger streamId(10);
QuicInteger offset(10);
QuicInteger length(1);
auto streamType =
StreamTypeField::Builder().setFin().setOffset().setLength().build();
auto streamFrame = createStreamFrame(
streamId, offset, length, folly::IOBuf::copyBuffer("a"));
folly::io::Cursor cursor(streamFrame.get());
auto decodedFrame = decodeStreamFrame(cursor, streamType);
EXPECT_EQ(decodedFrame.offset, 10);
EXPECT_EQ(decodedFrame.data->computeChainDataLength(), 1);
EXPECT_EQ(decodedFrame.streamId, 10);
EXPECT_TRUE(decodedFrame.fin);
}
TEST_F(DecodeTest, StreamLengthStreamIdInvalid) {
QuicInteger streamId(std::numeric_limits<uint64_t>::max());
auto streamType =
StreamTypeField::Builder().setFin().setOffset().setLength().build();
auto streamFrame = createStreamFrame<uint8_t>(
streamId, folly::none, folly::none, nullptr, true);
folly::io::Cursor cursor(streamFrame.get());
EXPECT_THROW(decodeStreamFrame(cursor, streamType), QuicTransportException);
}
TEST_F(DecodeTest, StreamOffsetNotPresent) {
QuicInteger streamId(10);
QuicInteger length(1);
auto streamType =
StreamTypeField::Builder().setFin().setOffset().setLength().build();
auto streamFrame = createStreamFrame(
streamId, folly::none, length, folly::IOBuf::copyBuffer("a"));
folly::io::Cursor cursor(streamFrame.get());
EXPECT_THROW(decodeStreamFrame(cursor, streamType), QuicTransportException);
}
TEST_F(DecodeTest, StreamIncorrectDataLength) {
QuicInteger streamId(10);
QuicInteger offset(10);
QuicInteger length(10);
auto streamType =
StreamTypeField::Builder().setFin().setOffset().setLength().build();
auto streamFrame = createStreamFrame(
streamId, offset, length, folly::IOBuf::copyBuffer("a"));
folly::io::Cursor cursor(streamFrame.get());
EXPECT_THROW(decodeStreamFrame(cursor, streamType), QuicTransportException);
}
TEST_F(DecodeTest, CryptoDecodeSuccess) {
QuicInteger offset(10);
QuicInteger length(1);
auto cryptoFrame =
createCryptoFrame(offset, length, folly::IOBuf::copyBuffer("a"));
folly::io::Cursor cursor(cryptoFrame.get());
auto decodedFrame = decodeCryptoFrame(cursor);
EXPECT_EQ(decodedFrame.offset, 10);
EXPECT_EQ(decodedFrame.data->computeChainDataLength(), 1);
}
TEST_F(DecodeTest, CryptoOffsetNotPresent) {
QuicInteger length(1);
auto cryptoFrame =
createCryptoFrame(folly::none, length, folly::IOBuf::copyBuffer("a"));
folly::io::Cursor cursor(cryptoFrame.get());
EXPECT_THROW(decodeCryptoFrame(cursor), QuicTransportException);
}
TEST_F(DecodeTest, CryptoLengthNotPresent) {
QuicInteger offset(0);
auto cryptoFrame = createCryptoFrame(offset, folly::none, nullptr);
folly::io::Cursor cursor(cryptoFrame.get());
EXPECT_THROW(decodeCryptoFrame(cursor), QuicTransportException);
}
TEST_F(DecodeTest, CryptoIncorrectDataLength) {
QuicInteger offset(10);
QuicInteger length(10);
auto cryptoFrame =
createCryptoFrame(offset, length, folly::IOBuf::copyBuffer("a"));
folly::io::Cursor cursor(cryptoFrame.get());
EXPECT_THROW(decodeCryptoFrame(cursor), QuicTransportException);
}
TEST_F(DecodeTest, PaddingFrameTest) {
auto buf = folly::IOBuf::create(sizeof(UnderlyingFrameType));
buf->append(1);
folly::io::RWPrivateCursor wcursor(buf.get());
folly::io::Cursor cursor(buf.get());
decodePaddingFrame(cursor);
}
std::unique_ptr<folly::IOBuf> createNewTokenFrame(
folly::Optional<QuicInteger> tokenLength = folly::none,
Buf token = nullptr) {
folly::IOBufQueue newTokenFrame;
folly::io::QueueAppender wcursor(&newTokenFrame, 10);
if (tokenLength) {
tokenLength->encode(wcursor);
}
if (token) {
wcursor.insert(std::move(token));
}
return newTokenFrame.move();
}
TEST_F(DecodeTest, NewTokenDecodeSuccess) {
QuicInteger length(1);
auto newTokenFrame =
createNewTokenFrame(length, folly::IOBuf::copyBuffer("a"));
folly::io::Cursor cursor(newTokenFrame.get());
auto decodedFrame = decodeNewTokenFrame(cursor);
EXPECT_EQ(decodedFrame.token->computeChainDataLength(), 1);
}
TEST_F(DecodeTest, NewTokenLengthNotPresent) {
auto newTokenFrame =
createNewTokenFrame(folly::none, folly::IOBuf::copyBuffer("a"));
folly::io::Cursor cursor(newTokenFrame.get());
EXPECT_THROW(decodeNewTokenFrame(cursor), QuicTransportException);
}
TEST_F(DecodeTest, NewTokenIncorrectDataLength) {
QuicInteger length(10);
auto newTokenFrame =
createNewTokenFrame(length, folly::IOBuf::copyBuffer("a"));
folly::io::Cursor cursor(newTokenFrame.get());
EXPECT_THROW(decodeNewTokenFrame(cursor), QuicTransportException);
}
std::unique_ptr<folly::IOBuf> createMinOrExpiredStreamDataFrame(
QuicInteger streamId,
folly::Optional<QuicInteger> maximumData = folly::none,
folly::Optional<QuicInteger> minimumStreamOffset = folly::none) {
folly::IOBufQueue bufQueue;
folly::io::QueueAppender wcursor(&bufQueue, 10);
streamId.encode(wcursor);
if (maximumData) {
maximumData->encode(wcursor);
}
if (minimumStreamOffset) {
minimumStreamOffset->encode(wcursor);
}
return bufQueue.move();
}
TEST_F(DecodeTest, DecodeMinStreamDataFrame) {
QuicInteger streamId(10);
QuicInteger maximumData(1000);
QuicInteger minimumStreamOffset(100);
auto noOffset = createMinOrExpiredStreamDataFrame(streamId, maximumData);
folly::io::Cursor cursor0(noOffset.get());
EXPECT_THROW(decodeMinStreamDataFrame(cursor0), QuicTransportException);
auto minStreamDataFrame = createMinOrExpiredStreamDataFrame(
streamId, maximumData, minimumStreamOffset);
folly::io::Cursor cursor(minStreamDataFrame.get());
auto result = decodeMinStreamDataFrame(cursor);
EXPECT_EQ(result.streamId, 10);
EXPECT_EQ(result.maximumData, 1000);
EXPECT_EQ(result.minimumStreamOffset, 100);
}
TEST_F(DecodeTest, DecodeExpiredStreamDataFrame) {
QuicInteger streamId(10);
QuicInteger offset(100);
auto noOffset = createMinOrExpiredStreamDataFrame(streamId);
folly::io::Cursor cursor0(noOffset.get());
EXPECT_THROW(decodeExpiredStreamDataFrame(cursor0), QuicTransportException);
auto expiredStreamDataFrame =
createMinOrExpiredStreamDataFrame(streamId, folly::none, offset);
folly::io::Cursor cursor(expiredStreamDataFrame.get());
auto result = decodeExpiredStreamDataFrame(cursor);
EXPECT_EQ(result.streamId, 10);
EXPECT_EQ(result.minimumStreamOffset, 100);
}
} // namespace test
} // namespace quic

90
quic/codec/test/Mocks.h Normal file
View File

@ -0,0 +1,90 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
// Copyright 2004-present Facebook. All rights reserved.
#pragma once
#include <folly/portability/GMock.h>
#include <folly/portability/GTest.h>
#include <quic/codec/QuicPacketBuilder.h>
#include <quic/common/test/TestUtils.h>
namespace quic {
// Otherwise you won't be able to mock QuicPacketBuidlerBase::appendFrame()
std::ostream& operator<<(std::ostream& out, const QuicWriteFrame& /*rhs*/) {
return out;
}
} // namespace quic
namespace quic {
namespace test {
class MockQuicPacketBuilder : public PacketBuilderInterface {
public:
// override method with unique_ptr since gmock doesn't support it
void insert(std::unique_ptr<folly::IOBuf> buf) override {
_insert(buf);
}
MOCK_METHOD1(appendFrame, void(QuicWriteFrame));
MOCK_METHOD1(_insert, void(std::unique_ptr<folly::IOBuf>&));
MOCK_METHOD2(push, void(const uint8_t*, size_t));
MOCK_METHOD1(write, void(const QuicInteger&));
GMOCK_METHOD0_(, const, , remainingSpaceInPkt, uint32_t());
GMOCK_METHOD0_(, const, , getPacketHeader, const PacketHeader&());
MOCK_METHOD1(writeBEUint8, void(uint8_t));
MOCK_METHOD1(writeBEUint16, void(uint16_t));
MOCK_METHOD1(writeBEUint64, void(uint16_t));
MOCK_METHOD2(appendBytes, void(PacketNum, uint8_t));
MOCK_METHOD3(
appendBytes,
void(folly::io::QueueAppender&, PacketNum, uint8_t));
void writeBE(uint8_t value) override {
writeBEUint8(value);
}
void writeBE(uint16_t value) override {
writeBEUint16(value);
}
void writeBE(uint64_t value) override {
writeBEUint64(value);
}
std::pair<RegularQuicWritePacket, Buf> buildPacket() && {
ShortHeader header(
ProtectionType::KeyPhaseZero, getTestConnectionId(), 0x01);
RegularQuicWritePacket regularPacket(std::move(header));
regularPacket.frames = std::move(frames_);
return std::make_pair(std::move(regularPacket), outputQueue_.move());
}
std::pair<RegularQuicWritePacket, Buf> buildLongHeaderPacket() && {
ConnectionId connId = getTestConnectionId();
PacketNum packetNum = 10;
LongHeader header(
LongHeader::Types::Handshake,
getTestConnectionId(1),
connId,
packetNum,
QuicVersion::MVFST);
RegularQuicWritePacket regularPacket(std::move(header));
regularPacket.frames = std::move(frames_);
return std::make_pair(std::move(regularPacket), outputQueue_.move());
}
std::vector<QuicWriteFrame> frames_;
uint32_t remaining_{kDefaultUDPSendPacketLen};
folly::IOBufQueue outputQueue_;
folly::io::QueueAppender appender_{&outputQueue_, 100};
};
} // namespace test
} // namespace quic

View File

@ -0,0 +1,132 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*
*/
#include <folly/portability/GTest.h>
#include <quic/codec/PacketNumberCipher.h>
#include <folly/String.h>
using namespace testing;
namespace quic {
namespace test {
struct HeaderParams {
folly::StringPiece key;
folly::StringPiece sample;
folly::StringPiece packetNumberBytes;
folly::StringPiece initialByte;
folly::StringPiece decryptedPacketNumberBytes;
folly::StringPiece decryptedInitialByte;
};
// TODO: add tests for short headers.
class LongPacketNumberCipherTest : public TestWithParam<HeaderParams> {
public:
void SetUp() override {}
protected:
Aes128PacketNumberCipher cipher_;
};
TEST_P(LongPacketNumberCipherTest, TestDecrypt) {
auto key = folly::unhexlify(GetParam().key);
cipher_.setKey(folly::range(key));
std::array<uint8_t, 1> initialByte;
std::array<uint8_t, 16> sample;
std::array<uint8_t, 4> packetNumberBytes;
auto initialByteString = folly::unhexlify(GetParam().initialByte);
auto sampleString = folly::unhexlify(GetParam().sample);
auto packetNumberBytesString = folly::unhexlify(GetParam().packetNumberBytes);
memcpy(initialByte.data(), initialByteString.data(), initialByte.size());
memcpy(sample.data(), sampleString.data(), sample.size());
memcpy(
packetNumberBytes.data(),
packetNumberBytesString.data(),
packetNumberBytes.size());
cipher_.decryptLongHeader(
sample, folly::range(initialByte), folly::range(packetNumberBytes));
EXPECT_EQ(
folly::hexlify(packetNumberBytes), GetParam().decryptedPacketNumberBytes);
EXPECT_EQ(folly::hexlify(initialByte), GetParam().decryptedInitialByte);
memcpy(initialByte.data(), initialByteString.data(), initialByte.size());
memcpy(sample.data(), sampleString.data(), sample.size());
memcpy(
packetNumberBytes.data(),
packetNumberBytesString.data(),
packetNumberBytes.size());
cipher_.decryptLongHeader(
sample, folly::range(initialByte), folly::range(packetNumberBytes));
EXPECT_EQ(
folly::hexlify(packetNumberBytes), GetParam().decryptedPacketNumberBytes);
EXPECT_EQ(folly::hexlify(initialByte), GetParam().decryptedInitialByte);
}
TEST_P(LongPacketNumberCipherTest, TestEncrypt) {
auto key = folly::unhexlify(GetParam().key);
cipher_.setKey(folly::range(key));
std::array<uint8_t, 1> initialByte;
std::array<uint8_t, 16> sample;
std::array<uint8_t, 4> packetNumberBytes;
auto initialByteString = folly::unhexlify(GetParam().decryptedInitialByte);
auto sampleString = folly::unhexlify(GetParam().sample);
auto packetNumberBytesString =
folly::unhexlify(GetParam().decryptedPacketNumberBytes);
memcpy(initialByte.data(), initialByteString.data(), initialByte.size());
memcpy(sample.data(), sampleString.data(), sample.size());
memcpy(
packetNumberBytes.data(),
packetNumberBytesString.data(),
packetNumberBytes.size());
cipher_.encryptLongHeader(
sample, folly::range(initialByte), folly::range(packetNumberBytes));
EXPECT_EQ(folly::hexlify(packetNumberBytes), GetParam().packetNumberBytes);
EXPECT_EQ(folly::hexlify(initialByte), GetParam().initialByte);
memcpy(initialByte.data(), initialByteString.data(), initialByte.size());
memcpy(sample.data(), sampleString.data(), sample.size());
memcpy(
packetNumberBytes.data(),
packetNumberBytesString.data(),
packetNumberBytes.size());
cipher_.encryptLongHeader(
sample, folly::range(initialByte), folly::range(packetNumberBytes));
EXPECT_EQ(folly::hexlify(packetNumberBytes), GetParam().packetNumberBytes);
EXPECT_EQ(folly::hexlify(initialByte), GetParam().initialByte);
}
INSTANTIATE_TEST_CASE_P(
LongPacketNumberCipherTests,
LongPacketNumberCipherTest,
::testing::Values(
HeaderParams{folly::StringPiece{"0edd982a6ac527f2eddcbb7348dea5d7"},
folly::StringPiece{"0000f3a694c75775b4e546172ce9e047"},
folly::StringPiece{"0dbc195a"},
folly::StringPiece{"c1"},
folly::StringPiece{"00000002"},
folly::StringPiece{"c3"}},
HeaderParams{folly::StringPiece{"94b9452d2b3c7c7f6da7fdd8593537fd"},
folly::StringPiece{"c4c2a2303d297e3c519bf6b22386e3d0"},
folly::StringPiece{"f7ed5f01"},
folly::StringPiece{"c4"},
folly::StringPiece{"00015f01"},
folly::StringPiece{"c1"}}));
} // namespace test
} // namespace quic

Some files were not shown because too many files have changed in this diff Show More