1
0
mirror of https://github.com/Mbed-TLS/mbedtls.git synced 2025-10-21 14:53:42 +03:00

Add UDP support to the NET module

This commit is contained in:
Manuel Pégourié-Gonnard
2014-03-23 17:38:16 +01:00
committed by Paul Bakker
parent d6b721c7ee
commit f5a1312eaa
13 changed files with 97 additions and 38 deletions

View File

@@ -43,38 +43,48 @@
#define POLARSSL_NET_LISTEN_BACKLOG 10 /**< The backlog that listen() should use. */ #define POLARSSL_NET_LISTEN_BACKLOG 10 /**< The backlog that listen() should use. */
#define NET_PROTO_TCP 0 /**< The TCP transport protocol */
#define NET_PROTO_UDP 1 /**< The UDP transport protocol */
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
/** /**
* \brief Initiate a TCP connection with host:port * \brief Initiate a connection with host:port in the given protocol
* *
* \param fd Socket to use * \param fd Socket to use
* \param host Host to connect to * \param host Host to connect to
* \param port Port to connect to * \param port Port to connect to
* \param proto Protocol: NET_PROTO_TCP or NET_PROTO_UDP
* *
* \return 0 if successful, or one of: * \return 0 if successful, or one of:
* POLARSSL_ERR_NET_SOCKET_FAILED, * POLARSSL_ERR_NET_SOCKET_FAILED,
* POLARSSL_ERR_NET_UNKNOWN_HOST, * POLARSSL_ERR_NET_UNKNOWN_HOST,
* POLARSSL_ERR_NET_CONNECT_FAILED * POLARSSL_ERR_NET_CONNECT_FAILED
*
* \note Sets the socket in connected mode even with UDP.
*/ */
int net_connect( int *fd, const char *host, int port ); int net_connect( int *fd, const char *host, int port, int proto );
/** /**
* \brief Create a listening socket on bind_ip:port. * \brief Create a receiving socket on bind_ip:port in the chosen
* If bind_ip == NULL, all interfaces are binded. * protocol. If bind_ip == NULL, all interfaces are bound.
* *
* \param fd Socket to use * \param fd Socket to use
* \param bind_ip IP to bind to, can be NULL * \param bind_ip IP to bind to, can be NULL
* \param port Port number to use * \param port Port number to use
* \param proto Protocol: NET_PROTO_TCP or NET_PROTO_UDP
* *
* \return 0 if successful, or one of: * \return 0 if successful, or one of:
* POLARSSL_ERR_NET_SOCKET_FAILED, * POLARSSL_ERR_NET_SOCKET_FAILED,
* POLARSSL_ERR_NET_BIND_FAILED, * POLARSSL_ERR_NET_BIND_FAILED,
* POLARSSL_ERR_NET_LISTEN_FAILED * POLARSSL_ERR_NET_LISTEN_FAILED
*
* \note Regardless of the protocol, opens the sockets and binds it.
* In addition, make the socket listening if protocol is TCP.
*/ */
int net_bind( int *fd, const char *bind_ip, int port ); int net_bind( int *fd, const char *bind_ip, int port, int proto );
/** /**
* \brief Accept a connection from a remote client * \brief Accept a connection from a remote client
@@ -87,6 +97,10 @@ int net_bind( int *fd, const char *bind_ip, int port );
* \return 0 if successful, POLARSSL_ERR_NET_ACCEPT_FAILED, or * \return 0 if successful, POLARSSL_ERR_NET_ACCEPT_FAILED, or
* POLARSSL_ERR_NET_WANT_READ is bind_fd was set to * POLARSSL_ERR_NET_WANT_READ is bind_fd was set to
* non-blocking and accept() is blocking. * non-blocking and accept() is blocking.
*
* \note With UDP, connects the bind_fd to the client and just copy
* its descriptor to client_fd. New clients will not be able
* to connect until you close the socket and bind a new one.
*/ */
int net_accept( int bind_fd, int *client_fd, void *client_ip ); int net_accept( int bind_fd, int *client_fd, void *client_ip );

View File

@@ -160,9 +160,9 @@ static int net_prepare( void )
} }
/* /*
* Initiate a TCP connection with host:port * Initiate a TCP connection with host:port and the given protocol
*/ */
int net_connect( int *fd, const char *host, int port ) int net_connect( int *fd, const char *host, int port, int proto )
{ {
#if defined(POLARSSL_HAVE_IPV6) #if defined(POLARSSL_HAVE_IPV6)
int ret; int ret;
@@ -176,11 +176,11 @@ int net_connect( int *fd, const char *host, int port )
memset( port_str, 0, sizeof( port_str ) ); memset( port_str, 0, sizeof( port_str ) );
snprintf( port_str, sizeof( port_str ), "%d", port ); snprintf( port_str, sizeof( port_str ), "%d", port );
/* Do name resolution with both IPv6 and IPv4, but only TCP */ /* Do name resolution with both IPv6 and IPv4 */
memset( &hints, 0, sizeof( hints ) ); memset( &hints, 0, sizeof( hints ) );
hints.ai_family = AF_UNSPEC; hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM; hints.ai_socktype = proto == NET_PROTO_UDP ? SOCK_DGRAM : SOCK_STREAM;
hints.ai_protocol = IPPROTO_TCP; hints.ai_protocol = proto == NET_PROTO_UDP ? IPPROTO_UDP : IPPROTO_TCP;
if( getaddrinfo( host, port_str, &hints, &addr_list ) != 0 ) if( getaddrinfo( host, port_str, &hints, &addr_list ) != 0 )
return( POLARSSL_ERR_NET_UNKNOWN_HOST ); return( POLARSSL_ERR_NET_UNKNOWN_HOST );
@@ -224,7 +224,9 @@ int net_connect( int *fd, const char *host, int port )
if( ( server_host = gethostbyname( host ) ) == NULL ) if( ( server_host = gethostbyname( host ) ) == NULL )
return( POLARSSL_ERR_NET_UNKNOWN_HOST ); return( POLARSSL_ERR_NET_UNKNOWN_HOST );
if( ( *fd = (int) socket( AF_INET, SOCK_STREAM, IPPROTO_IP ) ) < 0 ) if( ( *fd = (int) socket( AF_INET,
proto == NET_PROTO_UDP ? SOCK_DGRAM : SOCK_STREAM,
proto == NET_PROTO_UDP ? IPPROTO_UDP : IPPROTO_TCP ) ) < 0 )
return( POLARSSL_ERR_NET_SOCKET_FAILED ); return( POLARSSL_ERR_NET_SOCKET_FAILED );
memcpy( (void *) &server_addr.sin_addr, memcpy( (void *) &server_addr.sin_addr,
@@ -248,7 +250,7 @@ int net_connect( int *fd, const char *host, int port )
/* /*
* Create a listening socket on bind_ip:port * Create a listening socket on bind_ip:port
*/ */
int net_bind( int *fd, const char *bind_ip, int port ) int net_bind( int *fd, const char *bind_ip, int port, int proto )
{ {
#if defined(POLARSSL_HAVE_IPV6) #if defined(POLARSSL_HAVE_IPV6)
int n, ret; int n, ret;
@@ -265,8 +267,8 @@ int net_bind( int *fd, const char *bind_ip, int port )
/* Bind to IPv6 and/or IPv4, but only in TCP */ /* Bind to IPv6 and/or IPv4, but only in TCP */
memset( &hints, 0, sizeof( hints ) ); memset( &hints, 0, sizeof( hints ) );
hints.ai_family = AF_UNSPEC; hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM; hints.ai_socktype = proto == NET_PROTO_UDP ? SOCK_DGRAM : SOCK_STREAM;
hints.ai_protocol = IPPROTO_TCP; hints.ai_protocol = proto == NET_PROTO_UDP ? IPPROTO_UDP : IPPROTO_TCP;
if( bind_ip == NULL ) if( bind_ip == NULL )
hints.ai_flags = AI_PASSIVE; hints.ai_flags = AI_PASSIVE;
@@ -301,12 +303,16 @@ int net_bind( int *fd, const char *bind_ip, int port )
continue; continue;
} }
/* Listen only makes sense for TCP */
if( proto == NET_PROTO_TCP )
{
if( listen( *fd, POLARSSL_NET_LISTEN_BACKLOG ) != 0 ) if( listen( *fd, POLARSSL_NET_LISTEN_BACKLOG ) != 0 )
{ {
close( *fd ); close( *fd );
ret = POLARSSL_ERR_NET_LISTEN_FAILED; ret = POLARSSL_ERR_NET_LISTEN_FAILED;
continue; continue;
} }
}
/* I we ever get there, it's a success */ /* I we ever get there, it's a success */
ret = 0; ret = 0;
@@ -326,7 +332,9 @@ int net_bind( int *fd, const char *bind_ip, int port )
if( ( ret = net_prepare() ) != 0 ) if( ( ret = net_prepare() ) != 0 )
return( ret ); return( ret );
if( ( *fd = (int) socket( AF_INET, SOCK_STREAM, IPPROTO_IP ) ) < 0 ) if( ( *fd = (int) socket( AF_INET,
proto == NET_PROTO_UDP ? SOCK_DGRAM : SOCK_STREAM,
proto == NET_PROTO_UDP ? IPPROTO_UDP : IPPROTO_TCP ) ) < 0 )
return( POLARSSL_ERR_NET_SOCKET_FAILED ); return( POLARSSL_ERR_NET_SOCKET_FAILED );
n = 1; n = 1;
@@ -361,11 +369,15 @@ int net_bind( int *fd, const char *bind_ip, int port )
return( POLARSSL_ERR_NET_BIND_FAILED ); return( POLARSSL_ERR_NET_BIND_FAILED );
} }
/* Listen only makes sense for TCP */
if( proto == NET_PROTO_TCP )
{
if( listen( *fd, POLARSSL_NET_LISTEN_BACKLOG ) != 0 ) if( listen( *fd, POLARSSL_NET_LISTEN_BACKLOG ) != 0 )
{ {
close( *fd ); close( *fd );
return( POLARSSL_ERR_NET_LISTEN_FAILED ); return( POLARSSL_ERR_NET_LISTEN_FAILED );
} }
}
return( 0 ); return( 0 );
#endif /* POLARSSL_HAVE_IPV6 */ #endif /* POLARSSL_HAVE_IPV6 */
@@ -416,6 +428,9 @@ static int net_would_block( int fd )
*/ */
int net_accept( int bind_fd, int *client_fd, void *client_ip ) int net_accept( int bind_fd, int *client_fd, void *client_ip )
{ {
int ret;
int type;
#if defined(POLARSSL_HAVE_IPV6) #if defined(POLARSSL_HAVE_IPV6)
struct sockaddr_storage client_addr; struct sockaddr_storage client_addr;
#else #else
@@ -425,14 +440,35 @@ int net_accept( int bind_fd, int *client_fd, void *client_ip )
#if defined(__socklen_t_defined) || defined(_SOCKLEN_T) || \ #if defined(__socklen_t_defined) || defined(_SOCKLEN_T) || \
defined(_SOCKLEN_T_DECLARED) defined(_SOCKLEN_T_DECLARED)
socklen_t n = (socklen_t) sizeof( client_addr ); socklen_t n = (socklen_t) sizeof( client_addr );
socklen_t type_len = (socklen_t) sizeof( type );
#else #else
int n = (int) sizeof( client_addr ); int n = (int) sizeof( client_addr );
int type_len = (int) sizeof( type );
#endif #endif
*client_fd = (int) accept( bind_fd, (struct sockaddr *) /* Is this a TCP or UDP socket? */
&client_addr, &n ); if( getsockopt( bind_fd, SOL_SOCKET, SO_TYPE, &type, &type_len ) != 0 ||
( type != SOCK_STREAM && type != SOCK_DGRAM ) )
{
return( POLARSSL_ERR_NET_ACCEPT_FAILED );
}
if( *client_fd < 0 ) if( type == SOCK_STREAM )
{
/* TCP: actual accept() */
ret = *client_fd = (int) accept( bind_fd,
(struct sockaddr *) &client_addr, &n );
}
else
{
/* UDP: wait for a message, but keep it in the queue */
char buf[1] = { 0 };
ret = recvfrom( bind_fd, buf, 0, MSG_PEEK,
(struct sockaddr *) &client_addr, &n );
}
if( ret < 0 )
{ {
if( net_would_block( bind_fd ) != 0 ) if( net_would_block( bind_fd ) != 0 )
return( POLARSSL_ERR_NET_WANT_READ ); return( POLARSSL_ERR_NET_WANT_READ );
@@ -440,6 +476,15 @@ int net_accept( int bind_fd, int *client_fd, void *client_ip )
return( POLARSSL_ERR_NET_ACCEPT_FAILED ); return( POLARSSL_ERR_NET_ACCEPT_FAILED );
} }
/* UDP: hijack the listening socket for communicating with the client */
if( type != SOCK_STREAM )
{
if( connect( bind_fd, (struct sockaddr *) &client_addr, n ) != 0 )
return( POLARSSL_ERR_NET_ACCEPT_FAILED );
*client_fd = bind_fd;
}
if( client_ip != NULL ) if( client_ip != NULL )
{ {
#if defined(POLARSSL_HAVE_IPV6) #if defined(POLARSSL_HAVE_IPV6)

View File

@@ -135,7 +135,7 @@ int main( int argc, char *argv[] )
fflush( stdout ); fflush( stdout );
if( ( ret = net_connect( &server_fd, SERVER_NAME, if( ( ret = net_connect( &server_fd, SERVER_NAME,
SERVER_PORT ) ) != 0 ) SERVER_PORT, NET_PROTO_TCP ) ) != 0 )
{ {
printf( " failed\n ! net_connect returned %d\n\n", ret ); printf( " failed\n ! net_connect returned %d\n\n", ret );
goto exit; goto exit;

View File

@@ -163,7 +163,7 @@ int main( int argc, char *argv[] )
printf( "\n . Waiting for a remote connection" ); printf( "\n . Waiting for a remote connection" );
fflush( stdout ); fflush( stdout );
if( ( ret = net_bind( &listen_fd, NULL, SERVER_PORT ) ) != 0 ) if( ( ret = net_bind( &listen_fd, NULL, SERVER_PORT, NET_PROTO_TCP ) ) != 0 )
{ {
printf( " failed\n ! net_bind returned %d\n\n", ret ); printf( " failed\n ! net_bind returned %d\n\n", ret );
goto exit; goto exit;

View File

@@ -140,7 +140,7 @@ int main( int argc, char *argv[] )
fflush( stdout ); fflush( stdout );
if( ( ret = net_connect( &server_fd, SERVER_NAME, if( ( ret = net_connect( &server_fd, SERVER_NAME,
SERVER_PORT ) ) != 0 ) SERVER_PORT, NET_PROTO_TCP ) ) != 0 )
{ {
printf( " failed\n ! net_connect returned %d\n\n", ret ); printf( " failed\n ! net_connect returned %d\n\n", ret );
goto exit; goto exit;

View File

@@ -844,7 +844,7 @@ int main( int argc, char *argv[] )
fflush( stdout ); fflush( stdout );
if( ( ret = net_connect( &server_fd, opt.server_addr, if( ( ret = net_connect( &server_fd, opt.server_addr,
opt.server_port ) ) != 0 ) opt.server_port, NET_PROTO_TCP ) ) != 0 )
{ {
printf( " failed\n ! net_connect returned -0x%x\n\n", -ret ); printf( " failed\n ! net_connect returned -0x%x\n\n", -ret );
goto exit; goto exit;
@@ -1260,7 +1260,7 @@ reconnect:
} }
if( ( ret = net_connect( &server_fd, opt.server_name, if( ( ret = net_connect( &server_fd, opt.server_name,
opt.server_port ) ) != 0 ) opt.server_port , NET_PROTO_TCP) ) != 0 )
{ {
printf( " failed\n ! net_connect returned -0x%x\n\n", -ret ); printf( " failed\n ! net_connect returned -0x%x\n\n", -ret );
goto exit; goto exit;

View File

@@ -179,7 +179,7 @@ int main( int argc, char *argv[] )
printf( " . Bind on https://localhost:4433/ ..." ); printf( " . Bind on https://localhost:4433/ ..." );
fflush( stdout ); fflush( stdout );
if( ( ret = net_bind( &listen_fd, NULL, 4433 ) ) != 0 ) if( ( ret = net_bind( &listen_fd, NULL, 4433, NET_PROTO_TCP ) ) != 0 )
{ {
printf( " failed\n ! net_bind returned %d\n\n", ret ); printf( " failed\n ! net_bind returned %d\n\n", ret );
goto exit; goto exit;

View File

@@ -574,7 +574,7 @@ int main( int argc, char *argv[] )
fflush( stdout ); fflush( stdout );
if( ( ret = net_connect( &server_fd, opt.server_name, if( ( ret = net_connect( &server_fd, opt.server_name,
opt.server_port ) ) != 0 ) opt.server_port, NET_PROTO_TCP ) ) != 0 )
{ {
printf( " failed\n ! net_connect returned %d\n\n", ret ); printf( " failed\n ! net_connect returned %d\n\n", ret );
goto exit; goto exit;

View File

@@ -445,7 +445,7 @@ int main( int argc, char *argv[] )
printf( " . Bind on https://localhost:4433/ ..." ); printf( " . Bind on https://localhost:4433/ ..." );
fflush( stdout ); fflush( stdout );
if( ( ret = net_bind( &listen_fd, NULL, 4433 ) ) != 0 ) if( ( ret = net_bind( &listen_fd, NULL, 4433, NET_PROTO_TCP ) ) != 0 )
{ {
printf( " failed\n ! net_bind returned %d\n\n", ret ); printf( " failed\n ! net_bind returned %d\n\n", ret );
goto exit; goto exit;

View File

@@ -159,7 +159,7 @@ int main( int argc, char *argv[] )
printf( " . Bind on https://localhost:4433/ ..." ); printf( " . Bind on https://localhost:4433/ ..." );
fflush( stdout ); fflush( stdout );
if( ( ret = net_bind( &listen_fd, NULL, 4433 ) ) != 0 ) if( ( ret = net_bind( &listen_fd, NULL, 4433, NET_PROTO_TCP ) ) != 0 )
{ {
printf( " failed\n ! net_bind returned %d\n\n", ret ); printf( " failed\n ! net_bind returned %d\n\n", ret );
goto exit; goto exit;

View File

@@ -1246,7 +1246,7 @@ int main( int argc, char *argv[] )
fflush( stdout ); fflush( stdout );
if( ( ret = net_bind( &listen_fd, opt.server_addr, if( ( ret = net_bind( &listen_fd, opt.server_addr,
opt.server_port ) ) != 0 ) opt.server_port, NET_PROTO_TCP ) ) != 0 )
{ {
printf( " failed\n ! net_bind returned -0x%x\n\n", -ret ); printf( " failed\n ! net_bind returned -0x%x\n\n", -ret );
goto exit; goto exit;

View File

@@ -193,7 +193,7 @@ static int ssl_test( struct options *opt )
if( opt->opmode == OPMODE_CLIENT ) if( opt->opmode == OPMODE_CLIENT )
{ {
if( ( ret = net_connect( &client_fd, opt->server_name, if( ( ret = net_connect( &client_fd, opt->server_name,
opt->server_port ) ) != 0 ) opt->server_port, NET_PROTO_TCP ) ) != 0 )
{ {
printf( " ! net_connect returned %d\n\n", ret ); printf( " ! net_connect returned %d\n\n", ret );
return( ret ); return( ret );
@@ -242,7 +242,7 @@ static int ssl_test( struct options *opt )
if( server_fd < 0 ) if( server_fd < 0 )
{ {
if( ( ret = net_bind( &server_fd, NULL, if( ( ret = net_bind( &server_fd, NULL,
opt->server_port ) ) != 0 ) opt->server_port, NET_PROTO_TCP ) ) != 0 )
{ {
printf( " ! net_bind returned %d\n\n", ret ); printf( " ! net_bind returned %d\n\n", ret );
return( ret ); return( ret );

View File

@@ -402,7 +402,7 @@ int main( int argc, char *argv[] )
fflush( stdout ); fflush( stdout );
if( ( ret = net_connect( &server_fd, opt.server_name, if( ( ret = net_connect( &server_fd, opt.server_name,
opt.server_port ) ) != 0 ) opt.server_port, NET_PROTO_TCP ) ) != 0 )
{ {
printf( " failed\n ! net_connect returned %d\n\n", ret ); printf( " failed\n ! net_connect returned %d\n\n", ret );
goto exit; goto exit;