diff --git a/arduino/libraries/WiFi/src/WiFiClient.cpp b/arduino/libraries/WiFi/src/WiFiClient.cpp index e5f05cbe..59f1d121 100644 --- a/arduino/libraries/WiFi/src/WiFiClient.cpp +++ b/arduino/libraries/WiFi/src/WiFiClient.cpp @@ -26,6 +26,9 @@ #include "WiFiClient.h" +extern "C" { + #include "esp_log.h" +} WiFiClient::WiFiClient() : WiFiClient(-1) @@ -64,15 +67,59 @@ int WiFiClient::connect(IPAddress ip, uint16_t port) addr.sin_addr.s_addr = (uint32_t)ip; addr.sin_port = htons(port); + if (_connTimeout == 0) { if (lwip_connect_r(_socket, (struct sockaddr*)&addr, sizeof(addr)) < 0) { lwip_close_r(_socket); _socket = -1; return 0; } + } int nonBlocking = 1; lwip_ioctl_r(_socket, FIONBIO, &nonBlocking); + if (_connTimeout > 0) { + int res = lwip_connect_r(_socket, (struct sockaddr*)&addr, sizeof(addr)); + if (res < 0 && errno != EINPROGRESS) { + ESP_LOGW("WiFiClient", "connect on socket %d, errno: %d, \"%s\"", _socket, errno, strerror(errno)); + lwip_close_r(_socket); + _socket = -1; + return 0; + } + + struct timeval tv; + tv.tv_sec = _connTimeout / 1000; + tv.tv_usec = (_connTimeout % 1000) * 1000; + + fd_set fdset; + FD_ZERO(&fdset); + FD_SET(_socket, &fdset); + + res = select(_socket + 1, nullptr, &fdset, nullptr, &tv); + if (res < 0) { + ESP_LOGW("WiFiClient", "select on socket %d, errno: %d, \"%s\"", _socket, errno, strerror(errno)); + lwip_close_r(_socket); + return 0; + } + if (res == 0) { + ESP_LOGW("WiFiClient", "select returned due to timeout %d ms for socket %d", _connTimeout, _socket); + lwip_close_r(_socket); + return 0; + } + int sockerr; + socklen_t len = (socklen_t) sizeof(int); + res = lwip_getsockopt(_socket, SOL_SOCKET, SO_ERROR, &sockerr, &len); + if (res < 0) { + ESP_LOGW("WiFiClient", "getsockopt on socket %d, errno: %d, \"%s\"", _socket, errno, strerror(errno)); + lwip_close_r(_socket); + return 0; + } + if (sockerr != 0) { + ESP_LOGW("WiFiClient", "socket error on socket %d, errno: %d, \"%s\"", _socket, sockerr, strerror(sockerr)); + lwip_close_r(_socket); + return 0; + } + } return 1; } diff --git a/arduino/libraries/WiFi/src/WiFiClient.h b/arduino/libraries/WiFi/src/WiFiClient.h index 3840ed89..5ea6d535 100644 --- a/arduino/libraries/WiFi/src/WiFiClient.h +++ b/arduino/libraries/WiFi/src/WiFiClient.h @@ -50,6 +50,8 @@ class WiFiClient : public Client { virtual /*IPAddress*/uint32_t remoteIP(); virtual uint16_t remotePort(); + void setConnectionTimeout(uint16_t timeout) {_connTimeout = timeout;} + // using Print::write; protected: @@ -59,6 +61,7 @@ class WiFiClient : public Client { private: int _socket; + uint16_t _connTimeout = 0; }; #endif // WIFICLIENT_H diff --git a/arduino/libraries/WiFi/src/WiFiSSLClient.cpp b/arduino/libraries/WiFi/src/WiFiSSLClient.cpp index 3bc7b40e..f4ef17b2 100644 --- a/arduino/libraries/WiFi/src/WiFiSSLClient.cpp +++ b/arduino/libraries/WiFi/src/WiFiSSLClient.cpp @@ -23,6 +23,10 @@ #include "WiFiSSLClient.h" +extern "C" { + #include "esp_log.h" +} + class __Guard { public: __Guard(SemaphoreHandle_t handle) { @@ -50,6 +54,8 @@ WiFiSSLClient::WiFiSSLClient() : _mbedMutex = xSemaphoreCreateRecursiveMutex(); } +static int net_connect( mbedtls_net_context *ctx, const char *host, const char *port, int proto, uint16_t timeout); + int WiFiSSLClient::connect(const char* host, uint16_t port, bool sni) { synchronized { @@ -113,7 +119,8 @@ int WiFiSSLClient::connect(const char* host, uint16_t port, bool sni) char portStr[6]; itoa(port, portStr, 10); - if (mbedtls_net_connect(&_netContext, host, portStr, MBEDTLS_NET_PROTO_TCP) != 0) { + if (_connTimeout ? net_connect(&_netContext, host, portStr, MBEDTLS_NET_PROTO_TCP, _connTimeout) + : mbedtls_net_connect(&_netContext, host, portStr, MBEDTLS_NET_PROTO_TCP)) { stop(); return 0; } @@ -293,3 +300,79 @@ uint16_t WiFiSSLClient::remotePort() return ntohs(((struct sockaddr_in *)&addr)->sin_port); } + + +/* + * based on mbedtls_net_connect, but with timeout support + */ +int net_connect(mbedtls_net_context *ctx, const char *host, const char *port, int proto, uint16_t timeout) { + int ret; + struct addrinfo hints, *addr_list, *cur; + + /* Do name resolution with both IPv6 and IPv4 */ + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = proto == MBEDTLS_NET_PROTO_UDP ? SOCK_DGRAM : SOCK_STREAM; + hints.ai_protocol = + proto == MBEDTLS_NET_PROTO_UDP ? IPPROTO_UDP : IPPROTO_TCP; + + if ( getaddrinfo( host, port, &hints, &addr_list ) != 0) { + return ( MBEDTLS_ERR_NET_UNKNOWN_HOST); + } + + /* Try the sockaddrs until a connection succeeds */ + ret = MBEDTLS_ERR_NET_UNKNOWN_HOST; + for (cur = addr_list; cur != NULL; cur = cur->ai_next) { + int fd = socket(cur->ai_family, cur->ai_socktype, cur->ai_protocol); + + if (fd < 0) { + ret = MBEDTLS_ERR_NET_SOCKET_FAILED; + continue; + } + + mbedtls_net_context tmpCtx; + tmpCtx.fd = fd; + mbedtls_net_set_nonblock(&tmpCtx); + + int res = connect(fd, cur->ai_addr, cur->ai_addrlen); + if (res < 0 && errno != EINPROGRESS) { + ESP_LOGW("WiFiSSLClient", "connect on fd %d, errno: %d, \"%s\"", fd, errno, strerror(errno)); + } else { + struct timeval tv; + tv.tv_sec = timeout / 1000; + tv.tv_usec = (timeout % 1000) * 1000; + + fd_set fdset; + FD_ZERO(&fdset); + FD_SET(fd, &fdset); + + res = select(fd + 1, nullptr, &fdset, nullptr, &tv); + if (res < 0) { + ESP_LOGW("WiFiSSLClient", "select on fd %d, errno: %d, \"%s\"", fd, errno, strerror(errno)); + } else if (res == 0) { + ESP_LOGW("WiFiSSLClient", "select returned due to timeout %d ms for fd %d", timeout, fd); + } else { + int sockerr; + socklen_t len = (socklen_t) sizeof(int); + res = getsockopt(fd, SOL_SOCKET, SO_ERROR, &sockerr, &len); + if (res < 0) { + ESP_LOGW("WiFiSSLClient", "getsockopt on fd %d, errno: %d, \"%s\"", fd, errno, strerror(errno)); + } else if (sockerr != 0) { + ESP_LOGW("WiFiSSLClient", "socket error on fd %d, errno: %d, \"%s\"", fd, sockerr, strerror(sockerr)); + } else { + ctx->fd = fd; // connected! + ret = 0; + mbedtls_net_set_block(ctx); // back to blocking for SSL handshake + break; + } + } + } + close(fd); + ret = MBEDTLS_ERR_NET_CONNECT_FAILED; + } + + freeaddrinfo(addr_list); + + return (ret); +} + diff --git a/arduino/libraries/WiFi/src/WiFiSSLClient.h b/arduino/libraries/WiFi/src/WiFiSSLClient.h index 523e0427..41990a1a 100644 --- a/arduino/libraries/WiFi/src/WiFiSSLClient.h +++ b/arduino/libraries/WiFi/src/WiFiSSLClient.h @@ -55,6 +55,8 @@ class WiFiSSLClient /*: public Client*/ { virtual /*IPAddress*/uint32_t remoteIP(); virtual uint16_t remotePort(); + void setConnectionTimeout(uint16_t timeout) {_connTimeout = timeout;} + private: int connect(const char* host, uint16_t port, bool sni); @@ -69,6 +71,7 @@ class WiFiSSLClient /*: public Client*/ { mbedtls_x509_crt _caCrt; bool _connected; int _peek; + uint16_t _connTimeout = 0; SemaphoreHandle_t _mbedMutex; }; diff --git a/main/CommandHandler.cpp b/main/CommandHandler.cpp index e3435d1b..1d821624 100644 --- a/main/CommandHandler.cpp +++ b/main/CommandHandler.cpp @@ -595,6 +595,7 @@ int startClientTcp(const uint8_t command[], uint8_t response[]) uint16_t port; uint8_t socket; uint8_t type; + uint16_t timeout = 0; memset(host, 0x00, sizeof(host)); @@ -611,11 +612,16 @@ int startClientTcp(const uint8_t command[], uint8_t response[]) port = ntohs(port); socket = command[13 + command[3]]; type = command[15 + command[3]]; + if (command[2] == 6) { // optional sixth parameter + timeout = (uint16_t) command[17 + command[3]] << 8 | command[18 + command[3]]; + } } if (type == 0x00) { int result; + tcpClients[socket].setConnectionTimeout(timeout); + if (host[0] != '\0') { result = tcpClients[socket].connect(host, port); } else { @@ -660,6 +666,8 @@ int startClientTcp(const uint8_t command[], uint8_t response[]) } else if (type == 0x02) { int result; + tlsClients[socket].setConnectionTimeout(timeout); + if (host[0] != '\0') { result = tlsClients[socket].connect(host, port); } else { @@ -684,6 +692,8 @@ int startClientTcp(const uint8_t command[], uint8_t response[]) configureECCx08(); + static_cast(bearsslClient.getClient())->setConnectionTimeout(timeout); + if (host[0] != '\0') { result = bearsslClient.connect(host, port); } else {