diff --git a/ssl/tls1.c b/ssl/tls1.c index 569d8ad0c..19bf86d69 100644 --- a/ssl/tls1.c +++ b/ssl/tls1.c @@ -50,7 +50,7 @@ static const char * client_finished = "client finished"; static int do_handshake(SSL *ssl, uint8_t *buf, int read_len); static int set_key_block(SSL *ssl, int is_write); static int verify_digest(SSL *ssl, int mode, const uint8_t *buf, int read_len); -static void *crypt_new(SSL *ssl, uint8_t *key, uint8_t *iv, int is_decrypt); +static void *crypt_new(SSL *ssl, uint8_t *key, uint8_t *iv, int is_decrypt, void* cached); static int send_raw_packet(SSL *ssl, uint8_t protocol); /** @@ -591,6 +591,9 @@ SSL *ssl_new(SSL_CTX *ssl_ctx, int client_fd) ssl_ctx->tail = ssl; } + ssl->encrypt_ctx = malloc(sizeof(AES_CTX)); + ssl->decrypt_ctx = malloc(sizeof(AES_CTX)); + SSL_CTX_UNLOCK(ssl_ctx->mutex); return ssl; } @@ -917,14 +920,18 @@ void finished_digest(SSL *ssl, const char *label, uint8_t *digest) /** * Retrieve (and initialise) the context of a cipher. */ -static void *crypt_new(SSL *ssl, uint8_t *key, uint8_t *iv, int is_decrypt) +static void *crypt_new(SSL *ssl, uint8_t *key, uint8_t *iv, int is_decrypt, void* cached) { switch (ssl->cipher) { #ifndef CONFIG_SSL_SKELETON_MODE case SSL_AES128_SHA: { - AES_CTX *aes_ctx = (AES_CTX *)malloc(sizeof(AES_CTX)); + AES_CTX *aes_ctx; + if (cached) + aes_ctx = (AES_CTX*) cached; + else + aes_ctx = (AES_CTX*) malloc(sizeof(AES_CTX)); AES_set_key(aes_ctx, key, iv, AES_MODE_128); if (is_decrypt) @@ -937,7 +944,12 @@ static void *crypt_new(SSL *ssl, uint8_t *key, uint8_t *iv, int is_decrypt) case SSL_AES256_SHA: { - AES_CTX *aes_ctx = (AES_CTX *)malloc(sizeof(AES_CTX)); + AES_CTX *aes_ctx; + if (cached) + aes_ctx = (AES_CTX*) cached; + else + aes_ctx = (AES_CTX*) malloc(sizeof(AES_CTX)); + AES_set_key(aes_ctx, key, iv, AES_MODE_256); if (is_decrypt) @@ -952,7 +964,12 @@ static void *crypt_new(SSL *ssl, uint8_t *key, uint8_t *iv, int is_decrypt) #endif case SSL_RC4_128_SHA: { - RC4_CTX *rc4_ctx = (RC4_CTX *)malloc(sizeof(RC4_CTX)); + RC4_CTX* rc4_ctx; + if (cached) + rc4_ctx = (RC4_CTX*) cached; + else + rc4_ctx = (RC4_CTX*) malloc(sizeof(RC4_CTX)); + RC4_setup(rc4_ctx, key, 16); return (void *)rc4_ctx; } @@ -1184,7 +1201,7 @@ static int set_key_block(SSL *ssl, int is_write) } #endif - free(is_write ? ssl->encrypt_ctx : ssl->decrypt_ctx); + // free(is_write ? ssl->encrypt_ctx : ssl->decrypt_ctx); /* now initialise the ciphers */ if (is_client) @@ -1192,18 +1209,18 @@ static int set_key_block(SSL *ssl, int is_write) finished_digest(ssl, server_finished, ssl->dc->final_finish_mac); if (is_write) - ssl->encrypt_ctx = crypt_new(ssl, client_key, client_iv, 0); + ssl->encrypt_ctx = crypt_new(ssl, client_key, client_iv, 0, ssl->encrypt_ctx); else - ssl->decrypt_ctx = crypt_new(ssl, server_key, server_iv, 1); + ssl->decrypt_ctx = crypt_new(ssl, server_key, server_iv, 1, ssl->decrypt_ctx); } else { finished_digest(ssl, client_finished, ssl->dc->final_finish_mac); if (is_write) - ssl->encrypt_ctx = crypt_new(ssl, server_key, server_iv, 0); + ssl->encrypt_ctx = crypt_new(ssl, server_key, server_iv, 0, ssl->encrypt_ctx); else - ssl->decrypt_ctx = crypt_new(ssl, client_key, client_iv, 1); + ssl->decrypt_ctx = crypt_new(ssl, client_key, client_iv, 1, ssl->decrypt_ctx); } ssl->cipher_info = ciph_info;