Skip to content

Require explicit socket port reuse #8940

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions ports/espressif/common-hal/socketpool/Socket.c
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ socketpool_socket_obj_t *common_hal_socketpool_socket_accept(socketpool_socket_o
}
}

bool common_hal_socketpool_socket_bind(socketpool_socket_obj_t *self,
size_t common_hal_socketpool_socket_bind(socketpool_socket_obj_t *self,
const char *host, size_t hostlen, uint32_t port) {
struct sockaddr_in bind_addr;
const char *broadcast = "<broadcast>";
Expand All @@ -351,13 +351,11 @@ bool common_hal_socketpool_socket_bind(socketpool_socket_obj_t *self,
bind_addr.sin_family = AF_INET;
bind_addr.sin_port = htons(port);

int opt = 1;
int err = lwip_setsockopt(self->num, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt));
if (err != 0) {
mp_raise_RuntimeError(MP_ERROR_TEXT("Cannot set socket options"));
}
int result = lwip_bind(self->num, (struct sockaddr *)&bind_addr, sizeof(bind_addr));
return result == 0;
if (result == 0) {
return 0;
}
return errno;
}

void socketpool_socket_close(socketpool_socket_obj_t *self) {
Expand Down
7 changes: 3 additions & 4 deletions ports/raspberrypi/common-hal/socketpool/Socket.c
Original file line number Diff line number Diff line change
Expand Up @@ -865,7 +865,7 @@ socketpool_socket_obj_t *common_hal_socketpool_socket_accept(socketpool_socket_o
return MP_OBJ_FROM_PTR(accepted);
}

bool common_hal_socketpool_socket_bind(socketpool_socket_obj_t *socket,
size_t common_hal_socketpool_socket_bind(socketpool_socket_obj_t *socket,
const char *host, size_t hostlen, uint32_t port) {

// get address
Expand All @@ -876,7 +876,6 @@ bool common_hal_socketpool_socket_bind(socketpool_socket_obj_t *socket,
} else {
bind_addr_ptr = IP_ANY_TYPE;
}
ip_set_option(socket->pcb.ip, SOF_REUSEADDR);

err_t err = ERR_ARG;
switch (socket->type) {
Expand All @@ -891,10 +890,10 @@ bool common_hal_socketpool_socket_bind(socketpool_socket_obj_t *socket,
}

if (err != ERR_OK) {
mp_raise_OSError(error_lookup_table[-err]);
return error_lookup_table[-err];
}

return mp_const_none;
return 0;
}

STATIC err_t _lwip_tcp_close_poll(void *arg, struct tcp_pcb *pcb) {
Expand Down
6 changes: 3 additions & 3 deletions shared-bindings/socketpool/Socket.c
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ STATIC mp_obj_t socketpool_socket_bind(mp_obj_t self_in, mp_obj_t addr_in) {
mp_raise_ValueError(MP_ERROR_TEXT("port must be >= 0"));
}

bool ok = common_hal_socketpool_socket_bind(self, host, hostlen, (uint32_t)port);
if (!ok) {
mp_raise_ValueError(MP_ERROR_TEXT("Error: Failure to bind"));
size_t error = common_hal_socketpool_socket_bind(self, host, hostlen, (uint32_t)port);
if (error != 0) {
mp_raise_OSError(error);
}

return mp_const_none;
Expand Down
2 changes: 1 addition & 1 deletion shared-bindings/socketpool/Socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
extern const mp_obj_type_t socketpool_socket_type;

socketpool_socket_obj_t *common_hal_socketpool_socket_accept(socketpool_socket_obj_t *self, uint8_t *ip, uint32_t *port);
bool common_hal_socketpool_socket_bind(socketpool_socket_obj_t *self, const char *host, size_t hostlen, uint32_t port);
size_t common_hal_socketpool_socket_bind(socketpool_socket_obj_t *self, const char *host, size_t hostlen, uint32_t port);
void common_hal_socketpool_socket_close(socketpool_socket_obj_t *self);
void common_hal_socketpool_socket_connect(socketpool_socket_obj_t *self, const char *host, size_t hostlen, uint32_t port);
bool common_hal_socketpool_socket_get_closed(socketpool_socket_obj_t *self);
Expand Down
8 changes: 8 additions & 0 deletions shared-bindings/socketpool/SocketPool.c
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ MP_DEFINE_EXCEPTION(gaierror, OSError)
//| SOCK_RAW: int
//| EAI_NONAME: int
//|
//| SOL_SOCKET: int
//|
//| SO_REUSEADDR: int
//|
//| TCP_NODELAY: int
//|
//| IPPROTO_IP: int
Expand Down Expand Up @@ -196,6 +200,10 @@ STATIC const mp_rom_map_elem_t socketpool_socketpool_locals_dict_table[] = {
{ MP_ROM_QSTR(MP_QSTR_SOCK_DGRAM), MP_ROM_INT(SOCKETPOOL_SOCK_DGRAM) },
{ MP_ROM_QSTR(MP_QSTR_SOCK_RAW), MP_ROM_INT(SOCKETPOOL_SOCK_RAW) },

{ MP_ROM_QSTR(MP_QSTR_SOL_SOCKET), MP_ROM_INT(SOCKETPOOL_SOL_SOCKET) },

{ MP_ROM_QSTR(MP_QSTR_SO_REUSEADDR), MP_ROM_INT(SOCKETPOOL_SO_REUSEADDR) },

{ MP_ROM_QSTR(MP_QSTR_TCP_NODELAY), MP_ROM_INT(SOCKETPOOL_TCP_NODELAY) },

{ MP_ROM_QSTR(MP_QSTR_IPPROTO_IP), MP_ROM_INT(SOCKETPOOL_IPPROTO_IP) },
Expand Down
8 changes: 8 additions & 0 deletions shared-bindings/socketpool/SocketPool.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,14 @@ typedef enum {
SOCKETPOOL_TCP_NODELAY = 1,
} socketpool_socketpool_tcpopt_t;

typedef enum {
SOCKETPOOL_SOL_SOCKET = 0xfff,
} socketpool_socketpool_optlevel_t;

typedef enum {
SOCKETPOOL_SO_REUSEADDR = 0x0004,
} socketpool_socketpool_socketopt_t;

typedef enum {
SOCKETPOOL_IP_MULTICAST_TTL = 5,
} socketpool_socketpool_ipopt_t;
Expand Down
24 changes: 12 additions & 12 deletions shared-module/ssl/SSLSocket.c
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ STATIC void mbedtls_debug(void *ctx, int level, const char *file, int line, cons
(void)level;
mp_printf(&mp_plat_print, "DBG:%s:%04d: %s\n", file, line, str);
}
#define DEBUG(fmt, ...) mp_printf(&mp_plat_print, "DBG:%s:%04d: " fmt "\n", __FILE__, __LINE__,##__VA_ARGS__)
#define DEBUG_PRINT(fmt, ...) mp_printf(&mp_plat_print, "DBG:%s:%04d: " fmt "\n", __FILE__, __LINE__,##__VA_ARGS__)
#else
#define DEBUG(...) do {} while (0)
#define DEBUG_PRINT(...) do {} while (0)
#endif

STATIC NORETURN void mbedtls_raise_error(int err) {
Expand Down Expand Up @@ -107,10 +107,10 @@ STATIC int _mbedtls_ssl_send(void *ctx, const byte *buf, size_t len) {

// mp_uint_t out_sz = sock_stream->write(sock, buf, len, &err);
mp_int_t out_sz = socketpool_socket_send(sock, buf, len);
DEBUG("socket_send() -> %d", out_sz);
DEBUG_PRINT("socket_send() -> %d", out_sz);
if (out_sz < 0) {
int err = -out_sz;
DEBUG("sock_stream->write() -> %d nonblocking? %d", out_sz, mp_is_nonblocking_error(err));
DEBUG_PRINT("sock_stream->write() -> %d nonblocking? %d", out_sz, mp_is_nonblocking_error(err));
if (mp_is_nonblocking_error(err)) {
return MBEDTLS_ERR_SSL_WANT_WRITE;
}
Expand All @@ -125,7 +125,7 @@ STATIC int _mbedtls_ssl_recv(void *ctx, byte *buf, size_t len) {
mp_obj_t sock = *(mp_obj_t *)ctx;

mp_int_t out_sz = socketpool_socket_recv_into(sock, buf, len);
DEBUG("socket_recv() -> %d", out_sz);
DEBUG_PRINT("socket_recv() -> %d", out_sz);
if (out_sz < 0) {
int err = -out_sz;
if (mp_is_nonblocking_error(err)) {
Expand Down Expand Up @@ -261,14 +261,14 @@ ssl_sslsocket_obj_t *common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t

mp_uint_t common_hal_ssl_sslsocket_recv_into(ssl_sslsocket_obj_t *self, uint8_t *buf, uint32_t len) {
int ret = mbedtls_ssl_read(&self->ssl, buf, len);
DEBUG("recv_into mbedtls_ssl_read() -> %d\n", ret);
DEBUG_PRINT("recv_into mbedtls_ssl_read() -> %d\n", ret);
if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
DEBUG("returning %d\n", 0);
DEBUG_PRINT("returning %d\n", 0);
// end of stream
return 0;
}
if (ret >= 0) {
DEBUG("returning %d\n", ret);
DEBUG_PRINT("returning %d\n", ret);
return ret;
}
if (ret == MBEDTLS_ERR_SSL_WANT_READ) {
Expand All @@ -279,15 +279,15 @@ mp_uint_t common_hal_ssl_sslsocket_recv_into(ssl_sslsocket_obj_t *self, uint8_t
// renegotiation.
ret = MP_EWOULDBLOCK;
}
DEBUG("raising errno [error case] %d\n", ret);
DEBUG_PRINT("raising errno [error case] %d\n", ret);
mp_raise_OSError(ret);
}

mp_uint_t common_hal_ssl_sslsocket_send(ssl_sslsocket_obj_t *self, const uint8_t *buf, uint32_t len) {
int ret = mbedtls_ssl_write(&self->ssl, buf, len);
DEBUG("send mbedtls_ssl_write() -> %d\n", ret);
DEBUG_PRINT("send mbedtls_ssl_write() -> %d\n", ret);
if (ret >= 0) {
DEBUG("returning %d\n", ret);
DEBUG_PRINT("returning %d\n", ret);
return ret;
}
if (ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
Expand All @@ -298,7 +298,7 @@ mp_uint_t common_hal_ssl_sslsocket_send(ssl_sslsocket_obj_t *self, const uint8_t
// renegotiation.
ret = MP_EWOULDBLOCK;
}
DEBUG("raising errno [error case] %d\n", ret);
DEBUG_PRINT("raising errno [error case] %d\n", ret);
mp_raise_OSError(ret);
}

Expand Down