diff --git a/include/polarssl/ssl.h b/include/polarssl/ssl.h index 24e60f7934..5b4e36f2ff 100644 --- a/include/polarssl/ssl.h +++ b/include/polarssl/ssl.h @@ -117,7 +117,8 @@ #define SSL_MINOR_VERSION_2 2 /*!< TLS v1.1 */ #define SSL_MINOR_VERSION_3 3 /*!< TLS v1.2 */ -/* RFC 6066 section 4 */ +/* RFC 6066 section 4, see also mfl_code_to_length in ssl_tls.c + * NONE must be zero so that memset()ing session to zero works */ #define SSL_MAX_FRAG_LEN_NONE 0 /*!< don't use this extension */ #define SSL_MAX_FRAG_LEN_512 1 /*!< MaxFragmentLength 2^9 */ #define SSL_MAX_FRAG_LEN_1024 2 /*!< MaxFragmentLength 2^10 */ @@ -509,7 +510,6 @@ struct _ssl_context /* Maximum fragment length extension (RFC 6066 section 4) */ unsigned char mfl_code; /*!< numerical code for MaxFragmentLength */ - uint16_t max_frag_len; /*!< value of MaxFragmentLength */ /* * PKI layer diff --git a/library/ssl_tls.c b/library/ssl_tls.c index d6be987334..0374ee8186 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -59,6 +59,23 @@ #define strcasecmp _stricmp #endif +/* + * Convert max_fragment_length codes to length. + * RFC 6066 says: + * enum{ + * 2^9(1), 2^10(2), 2^11(3), 2^12(4), (255) + * } MaxFragmentLength; + * and we add 0 -> extension unused + */ +static unsigned int mfl_code_to_length[] = +{ + SSL_MAX_CONTENT_LEN, /* SSL_MAX_FRAG_LEN_NONE */ + 512, /* SSL_MAX_FRAG_LEN_512 */ + 1024, /* SSL_MAX_FRAG_LEN_1024 */ + 2048, /* SSL_MAX_FRAG_LEN_2048 */ + 4096, /* SSL_MAX_FRAG_LEN_4096 */ +}; + #if defined(POLARSSL_SSL_HW_RECORD_ACCEL) int (*ssl_hw_record_init)(ssl_context *ssl, const unsigned char *key_enc, const unsigned char *key_dec, @@ -2827,7 +2844,6 @@ int ssl_init( ssl_context *ssl ) memset( ssl->out_ctr, 0, SSL_BUFFER_LEN ); ssl->mfl_code = SSL_MAX_FRAG_LEN_NONE; - ssl->max_frag_len = SSL_MAX_CONTENT_LEN; ssl->hostname = NULL; ssl->hostname_len = 0; @@ -2871,7 +2887,6 @@ int ssl_session_reset( ssl_context *ssl ) ssl->out_left = 0; ssl->mfl_code = SSL_MAX_FRAG_LEN_NONE; - ssl->max_frag_len = SSL_MAX_CONTENT_LEN; ssl->transform_in = NULL; ssl->transform_out = NULL; @@ -3119,35 +3134,13 @@ void ssl_set_min_version( ssl_context *ssl, int major, int minor ) int ssl_set_max_frag_len( ssl_context *ssl, unsigned char mfl_code ) { - uint16_t max_frag_len; - - switch( mfl_code ) + if( mfl_code >= sizeof( mfl_code_to_length ) || + mfl_code_to_length[mfl_code] > SSL_MAX_CONTENT_LEN ) { - case SSL_MAX_FRAG_LEN_512: - max_frag_len = 512; - break; - - case SSL_MAX_FRAG_LEN_1024: - max_frag_len = 1024; - break; - - case SSL_MAX_FRAG_LEN_2048: - max_frag_len = 2048; - break; - - case SSL_MAX_FRAG_LEN_4096: - max_frag_len = 4096; - break; - - default: - return( POLARSSL_ERR_SSL_BAD_INPUT_DATA ); + return( POLARSSL_ERR_SSL_BAD_INPUT_DATA ); } - if( max_frag_len > SSL_MAX_CONTENT_LEN ) - return( POLARSSL_ERR_SSL_BAD_INPUT_DATA ); - ssl->mfl_code = mfl_code; - ssl->max_frag_len = max_frag_len; return( 0 ); } @@ -3413,6 +3406,7 @@ int ssl_write( ssl_context *ssl, const unsigned char *buf, size_t len ) { int ret; size_t n; + unsigned int max_len; SSL_DEBUG_MSG( 2, ( "=> write" ) ); @@ -3425,8 +3419,12 @@ int ssl_write( ssl_context *ssl, const unsigned char *buf, size_t len ) } } - n = ( len < ssl->max_frag_len ) - ? len : ssl->max_frag_len; + /* + * Assume mfl_code is correct since it was checked when set + */ + max_len = mfl_code_to_length[ssl->mfl_code]; + + n = ( len < max_len) ? len : max_len; if( ssl->out_left != 0 ) {