diff --git a/ssl/ssl.h b/ssl/ssl.h index 97e87d495..b287d5aa8 100644 --- a/ssl/ssl.h +++ b/ssl/ssl.h @@ -352,6 +352,16 @@ EXP_FUNC int STDCALL ssl_handshake_status(const SSL *ssl); */ EXP_FUNC int STDCALL ssl_get_config(int offset); +/** + * @brief Sets the hostname to be used for SNI + * @see https://en.wikipedia.org/wiki/Server_Name_Indication + * @param char* hostname + * @return success from the operation + * - 1 on success + * - 0 on failure + */ +EXP_FUNC int STDCALL ssl_set_hostname(SSL *ssl, const char* host_name); + /** * @brief Display why the handshake failed. * diff --git a/ssl/tls1.c b/ssl/tls1.c index c4a676b10..195cb0196 100644 --- a/ssl/tls1.c +++ b/ssl/tls1.c @@ -568,6 +568,8 @@ SSL *ssl_new(SSL_CTX *ssl_ctx, int client_fd) ssl->encrypt_ctx = malloc(sizeof(AES_CTX)); ssl->decrypt_ctx = malloc(sizeof(AES_CTX)); + ssl->host_name = NULL; + SSL_CTX_UNLOCK(ssl_ctx->mutex); return ssl; } @@ -1849,6 +1851,29 @@ EXP_FUNC int STDCALL ssl_get_config(int offset) } } +/** + * Sets the SNI hostname + */ +EXP_FUNC int STDCALL ssl_set_hostname(SSL *ssl, const char* host_name) { + if(host_name == NULL || strlen(host_name) == 0 || strlen(host_name) > 255 ) { + return 0; + } + + if(ssl->host_name != NULL) { + free(ssl->host_name); + } + + ssl->host_name = (char *)malloc(strlen(host_name)+1); + if(ssl->host_name == NULL) { + // most probably there was no memory available + return 0; + } + + strcpy(ssl->host_name, host_name); + + return 1; +} + #ifdef CONFIG_SSL_CERT_VERIFICATION /** * Authenticate a received certificate. diff --git a/ssl/tls1.h b/ssl/tls1.h index b7cd7f36e..c53ce6da0 100644 --- a/ssl/tls1.h +++ b/ssl/tls1.h @@ -198,6 +198,7 @@ struct _SSL uint8_t read_sequence[8]; /* 64 bit sequence number */ uint8_t write_sequence[8]; /* 64 bit sequence number */ uint8_t hmac_header[SSL_RECORD_SIZE]; /* rx hmac */ + char *host_name; /* Needed for the SNI support */ }; typedef struct _SSL SSL; diff --git a/ssl/tls1_clnt.c b/ssl/tls1_clnt.c index 5f2598922..b84877da7 100644 --- a/ssl/tls1_clnt.c +++ b/ssl/tls1_clnt.c @@ -220,6 +220,26 @@ static int send_client_hello(SSL *ssl) buf[offset++] = 1; /* no compression */ buf[offset++] = 0; + + if (ssl->host_name != NULL) { + unsigned int host_len = strlen(ssl->host_name); + + buf[offset++] = 0; + buf[offset++] = host_len+9; /* extensions length */ + + buf[offset++] = 0; + buf[offset++] = 0; /* server_name(0) (65535) */ + buf[offset++] = 0; + buf[offset++] = host_len+5; /* server_name length */ + buf[offset++] = 0; + buf[offset++] = host_len+3; /* server_list length */ + buf[offset++] = 0; /* host_name(0) (255) */ + buf[offset++] = 0; + buf[offset++] = host_len; /* host_name length */ + strncpy((char*) &buf[offset], ssl->host_name, host_len); + offset += host_len; + } + buf[3] = offset - 4; /* handshake size */ return send_packet(ssl, PT_HANDSHAKE_PROTOCOL, NULL, offset);