mirror of
https://github.com/Mbed-TLS/mbedtls.git
synced 2025-08-08 17:42:09 +03:00
Merged support for the ALPN extension
This commit is contained in:
@@ -383,6 +383,54 @@ static void ssl_write_session_ticket_ext( ssl_context *ssl,
|
||||
}
|
||||
#endif /* POLARSSL_SSL_SESSION_TICKETS */
|
||||
|
||||
#if defined(POLARSSL_SSL_ALPN)
|
||||
static void ssl_write_alpn_ext( ssl_context *ssl,
|
||||
unsigned char *buf, size_t *olen )
|
||||
{
|
||||
unsigned char *p = buf;
|
||||
const char **cur;
|
||||
|
||||
if( ssl->alpn_list == NULL )
|
||||
{
|
||||
*olen = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
SSL_DEBUG_MSG( 3, ( "client hello, adding alpn extension" ) );
|
||||
|
||||
*p++ = (unsigned char)( ( TLS_EXT_ALPN >> 8 ) & 0xFF );
|
||||
*p++ = (unsigned char)( ( TLS_EXT_ALPN ) & 0xFF );
|
||||
|
||||
/*
|
||||
* opaque ProtocolName<1..2^8-1>;
|
||||
*
|
||||
* struct {
|
||||
* ProtocolName protocol_name_list<2..2^16-1>
|
||||
* } ProtocolNameList;
|
||||
*/
|
||||
|
||||
/* Skip writing extension and list length for now */
|
||||
p += 4;
|
||||
|
||||
for( cur = ssl->alpn_list; *cur != NULL; cur++ )
|
||||
{
|
||||
*p = (unsigned char)( strlen( *cur ) & 0xFF );
|
||||
memcpy( p + 1, *cur, *p );
|
||||
p += 1 + *p;
|
||||
}
|
||||
|
||||
*olen = p - buf;
|
||||
|
||||
/* List length = olen - 2 (ext_type) - 2 (ext_len) - 2 (list_len) */
|
||||
buf[4] = (unsigned char)( ( ( *olen - 6 ) >> 8 ) & 0xFF );
|
||||
buf[5] = (unsigned char)( ( ( *olen - 6 ) ) & 0xFF );
|
||||
|
||||
/* Extension length = olen - 2 (ext_type) - 2 (ext_len) */
|
||||
buf[2] = (unsigned char)( ( ( *olen - 4 ) >> 8 ) & 0xFF );
|
||||
buf[3] = (unsigned char)( ( ( *olen - 4 ) ) & 0xFF );
|
||||
}
|
||||
#endif /* POLARSSL_SSL_ALPN */
|
||||
|
||||
static int ssl_write_client_hello( ssl_context *ssl )
|
||||
{
|
||||
int ret;
|
||||
@@ -595,6 +643,11 @@ static int ssl_write_client_hello( ssl_context *ssl )
|
||||
ext_len += olen;
|
||||
#endif
|
||||
|
||||
#if defined(POLARSSL_SSL_ALPN)
|
||||
ssl_write_alpn_ext( ssl, p + 2 + ext_len, &olen );
|
||||
ext_len += olen;
|
||||
#endif
|
||||
|
||||
SSL_DEBUG_MSG( 3, ( "client hello, total extension length: %d",
|
||||
ext_len ) );
|
||||
|
||||
@@ -753,6 +806,54 @@ static int ssl_parse_supported_point_formats_ext( ssl_context *ssl,
|
||||
}
|
||||
#endif /* POLARSSL_ECDH_C || POLARSSL_ECDSA_C */
|
||||
|
||||
#if defined(POLARSSL_SSL_ALPN)
|
||||
static int ssl_parse_alpn_ext( ssl_context *ssl,
|
||||
const unsigned char *buf, size_t len )
|
||||
{
|
||||
size_t list_len, name_len;
|
||||
const char **p;
|
||||
|
||||
/* If we didn't send it, the server shouldn't send it */
|
||||
if( ssl->alpn_list == NULL )
|
||||
return( POLARSSL_ERR_SSL_BAD_HS_SERVER_HELLO );
|
||||
|
||||
/*
|
||||
* opaque ProtocolName<1..2^8-1>;
|
||||
*
|
||||
* struct {
|
||||
* ProtocolName protocol_name_list<2..2^16-1>
|
||||
* } ProtocolNameList;
|
||||
*
|
||||
* the "ProtocolNameList" MUST contain exactly one "ProtocolName"
|
||||
*/
|
||||
|
||||
/* Min length is 2 (list_len) + 1 (name_len) + 1 (name) */
|
||||
if( len < 4 )
|
||||
return( POLARSSL_ERR_SSL_BAD_HS_SERVER_HELLO );
|
||||
|
||||
list_len = ( buf[0] << 8 ) | buf[1];
|
||||
if( list_len != len - 2 )
|
||||
return( POLARSSL_ERR_SSL_BAD_HS_SERVER_HELLO );
|
||||
|
||||
name_len = buf[2];
|
||||
if( name_len != list_len - 1 )
|
||||
return( POLARSSL_ERR_SSL_BAD_HS_SERVER_HELLO );
|
||||
|
||||
/* Check that the server chosen protocol was in our list and save it */
|
||||
for( p = ssl->alpn_list; *p != NULL; p++ )
|
||||
{
|
||||
if( name_len == strlen( *p ) &&
|
||||
memcmp( buf + 3, *p, name_len ) == 0 )
|
||||
{
|
||||
ssl->alpn_chosen = *p;
|
||||
return( 0 );
|
||||
}
|
||||
}
|
||||
|
||||
return( POLARSSL_ERR_SSL_BAD_HS_SERVER_HELLO );
|
||||
}
|
||||
#endif /* POLARSSL_SSL_ALPN */
|
||||
|
||||
static int ssl_parse_server_hello( ssl_context *ssl )
|
||||
{
|
||||
int ret, i, comp;
|
||||
@@ -1023,6 +1124,16 @@ static int ssl_parse_server_hello( ssl_context *ssl )
|
||||
break;
|
||||
#endif /* POLARSSL_ECDH_C || POLARSSL_ECDSA_C */
|
||||
|
||||
#if defined(POLARSSL_SSL_ALPN)
|
||||
case TLS_EXT_ALPN:
|
||||
SSL_DEBUG_MSG( 3, ( "found alpn extension" ) );
|
||||
|
||||
if( ( ret = ssl_parse_alpn_ext( ssl, ext + 4, ext_size ) ) != 0 )
|
||||
return( ret );
|
||||
|
||||
break;
|
||||
#endif /* POLARSSL_SSL_ALPN */
|
||||
|
||||
default:
|
||||
SSL_DEBUG_MSG( 3, ( "unknown extension found: %d (ignoring)",
|
||||
ext_id ) );
|
||||
|
@@ -683,6 +683,69 @@ static int ssl_parse_session_ticket_ext( ssl_context *ssl,
|
||||
}
|
||||
#endif /* POLARSSL_SSL_SESSION_TICKETS */
|
||||
|
||||
#if defined(POLARSSL_SSL_ALPN)
|
||||
static int ssl_parse_alpn_ext( ssl_context *ssl,
|
||||
unsigned char *buf, size_t len )
|
||||
{
|
||||
size_t list_len, cur_len;
|
||||
const unsigned char *theirs, *start, *end;
|
||||
const char **ours;
|
||||
|
||||
/* If ALPN not configured, just ignore the extension */
|
||||
if( ssl->alpn_list == NULL )
|
||||
return( 0 );
|
||||
|
||||
/*
|
||||
* opaque ProtocolName<1..2^8-1>;
|
||||
*
|
||||
* struct {
|
||||
* ProtocolName protocol_name_list<2..2^16-1>
|
||||
* } ProtocolNameList;
|
||||
*/
|
||||
|
||||
/* Min length is 2 (list_len) + 1 (name_len) + 1 (name) */
|
||||
if( len < 4 )
|
||||
return( POLARSSL_ERR_SSL_BAD_HS_CLIENT_HELLO );
|
||||
|
||||
list_len = ( buf[0] << 8 ) | buf[1];
|
||||
if( list_len != len - 2 )
|
||||
return( POLARSSL_ERR_SSL_BAD_HS_CLIENT_HELLO );
|
||||
|
||||
/*
|
||||
* Use our order of preference
|
||||
*/
|
||||
start = buf + 2;
|
||||
end = buf + len;
|
||||
for( ours = ssl->alpn_list; *ours != NULL; ours++ )
|
||||
{
|
||||
for( theirs = start; theirs != end; theirs += cur_len )
|
||||
{
|
||||
/* If the list is well formed, we should get equality first */
|
||||
if( theirs > end )
|
||||
return( POLARSSL_ERR_SSL_BAD_HS_CLIENT_HELLO );
|
||||
|
||||
cur_len = *theirs++;
|
||||
|
||||
/* Empty strings MUST NOT be included */
|
||||
if( cur_len == 0 )
|
||||
return( POLARSSL_ERR_SSL_BAD_HS_CLIENT_HELLO );
|
||||
|
||||
if( cur_len == strlen( *ours ) &&
|
||||
memcmp( theirs, *ours, cur_len ) == 0 )
|
||||
{
|
||||
ssl->alpn_chosen = *ours;
|
||||
return( 0 );
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* If we get there, no match was found */
|
||||
ssl_send_alert_message( ssl, SSL_ALERT_LEVEL_FATAL,
|
||||
SSL_ALERT_MSG_NO_APPLICATION_PROTOCOL );
|
||||
return( POLARSSL_ERR_SSL_BAD_HS_CLIENT_HELLO );
|
||||
}
|
||||
#endif /* POLARSSL_SSL_ALPN */
|
||||
|
||||
/*
|
||||
* Auxiliary functions for ServerHello parsing and related actions
|
||||
*/
|
||||
@@ -1385,6 +1448,16 @@ static int ssl_parse_client_hello( ssl_context *ssl )
|
||||
break;
|
||||
#endif /* POLARSSL_SSL_SESSION_TICKETS */
|
||||
|
||||
#if defined(POLARSSL_SSL_ALPN)
|
||||
case TLS_EXT_ALPN:
|
||||
SSL_DEBUG_MSG( 3, ( "found alpn extension" ) );
|
||||
|
||||
ret = ssl_parse_alpn_ext( ssl, ext + 4, ext_size );
|
||||
if( ret != 0 )
|
||||
return( ret );
|
||||
break;
|
||||
#endif /* POLARSSL_SSL_SESSION_TICKETS */
|
||||
|
||||
default:
|
||||
SSL_DEBUG_MSG( 3, ( "unknown extension found: %d (ignoring)",
|
||||
ext_id ) );
|
||||
@@ -1625,6 +1698,42 @@ static void ssl_write_supported_point_formats_ext( ssl_context *ssl,
|
||||
}
|
||||
#endif /* POLARSSL_ECDH_C || POLARSSL_ECDSA_C */
|
||||
|
||||
#if defined(POLARSSL_SSL_ALPN )
|
||||
static void ssl_write_alpn_ext( ssl_context *ssl,
|
||||
unsigned char *buf, size_t *olen )
|
||||
{
|
||||
if( ssl->alpn_chosen == NULL )
|
||||
{
|
||||
*olen = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
SSL_DEBUG_MSG( 3, ( "server hello, adding alpn extension" ) );
|
||||
|
||||
/*
|
||||
* 0 . 1 ext identifier
|
||||
* 2 . 3 ext length
|
||||
* 4 . 5 protocol list length
|
||||
* 6 . 6 protocol name length
|
||||
* 7 . 7+n protocol name
|
||||
*/
|
||||
buf[0] = (unsigned char)( ( TLS_EXT_ALPN >> 8 ) & 0xFF );
|
||||
buf[1] = (unsigned char)( ( TLS_EXT_ALPN ) & 0xFF );
|
||||
|
||||
*olen = 7 + strlen( ssl->alpn_chosen );
|
||||
|
||||
buf[2] = (unsigned char)( ( ( *olen - 4 ) >> 8 ) & 0xFF );
|
||||
buf[3] = (unsigned char)( ( ( *olen - 4 ) ) & 0xFF );
|
||||
|
||||
buf[4] = (unsigned char)( ( ( *olen - 6 ) >> 8 ) & 0xFF );
|
||||
buf[5] = (unsigned char)( ( ( *olen - 6 ) ) & 0xFF );
|
||||
|
||||
buf[6] = (unsigned char)( ( ( *olen - 7 ) ) & 0xFF );
|
||||
|
||||
memcpy( buf + 7, ssl->alpn_chosen, *olen - 7 );
|
||||
}
|
||||
#endif /* POLARSSL_ECDH_C || POLARSSL_ECDSA_C */
|
||||
|
||||
static int ssl_write_server_hello( ssl_context *ssl )
|
||||
{
|
||||
#if defined(POLARSSL_HAVE_TIME)
|
||||
@@ -1791,6 +1900,11 @@ static int ssl_write_server_hello( ssl_context *ssl )
|
||||
ext_len += olen;
|
||||
#endif
|
||||
|
||||
#if defined(POLARSSL_SSL_ALPN)
|
||||
ssl_write_alpn_ext( ssl, p + 2 + ext_len, &olen );
|
||||
ext_len += olen;
|
||||
#endif
|
||||
|
||||
SSL_DEBUG_MSG( 3, ( "server hello, total extension length: %d", ext_len ) );
|
||||
|
||||
*p++ = (unsigned char)( ( ext_len >> 8 ) & 0xFF );
|
||||
|
@@ -3521,6 +3521,10 @@ int ssl_session_reset( ssl_context *ssl )
|
||||
ssl->session = NULL;
|
||||
}
|
||||
|
||||
#if defined(POLARSSL_SSL_ALPN)
|
||||
ssl->alpn_chosen = NULL;
|
||||
#endif
|
||||
|
||||
if( ( ret = ssl_handshake_init( ssl ) ) != 0 )
|
||||
return( ret );
|
||||
|
||||
@@ -3915,6 +3919,37 @@ void ssl_set_sni( ssl_context *ssl,
|
||||
}
|
||||
#endif /* POLARSSL_SSL_SERVER_NAME_INDICATION */
|
||||
|
||||
#if defined(POLARSSL_SSL_ALPN)
|
||||
int ssl_set_alpn_protocols( ssl_context *ssl, const char **protos )
|
||||
{
|
||||
size_t cur_len, tot_len;
|
||||
const char **p;
|
||||
|
||||
/*
|
||||
* "Empty strings MUST NOT be included and byte strings MUST NOT be
|
||||
* truncated". Check lengths now rather than later.
|
||||
*/
|
||||
tot_len = 0;
|
||||
for( p = protos; *p != NULL; p++ )
|
||||
{
|
||||
cur_len = strlen( *p );
|
||||
tot_len += cur_len;
|
||||
|
||||
if( cur_len == 0 || cur_len > 255 || tot_len > 65535 )
|
||||
return( POLARSSL_ERR_SSL_BAD_INPUT_DATA );
|
||||
}
|
||||
|
||||
ssl->alpn_list = protos;
|
||||
|
||||
return( 0 );
|
||||
}
|
||||
|
||||
const char *ssl_get_alpn_protocol( const ssl_context *ssl )
|
||||
{
|
||||
return ssl->alpn_chosen;
|
||||
}
|
||||
#endif /* POLARSSL_SSL_ALPN */
|
||||
|
||||
void ssl_set_max_version( ssl_context *ssl, int major, int minor )
|
||||
{
|
||||
if( major >= SSL_MIN_MAJOR_VERSION && major <= SSL_MAX_MAJOR_VERSION &&
|
||||
|
Reference in New Issue
Block a user