diff --git a/libraries/ESP8266WiFi/src/WiFiClientSecure.cpp b/libraries/ESP8266WiFi/src/WiFiClientSecure.cpp index 325b4c1b5d..8a7d71e99f 100644 --- a/libraries/ESP8266WiFi/src/WiFiClientSecure.cpp +++ b/libraries/ESP8266WiFi/src/WiFiClientSecure.cpp @@ -74,37 +74,47 @@ typedef std::list BufferList; class SSLContext { public: - SSLContext() + SSLContext(bool isServer = false) { - if (_ssl_ctx_refcnt == 0) { - _ssl_ctx = ssl_ctx_new(SSL_SERVER_VERIFY_LATER | SSL_DEBUG_OPTS | SSL_CONNECT_IN_PARTS | SSL_READ_BLOCKING | SSL_NO_DEFAULT_KEY, 0); + _isServer = isServer; + if (!_isServer) { + if (_ssl_client_ctx_refcnt == 0) { + _ssl_client_ctx = ssl_ctx_new(SSL_SERVER_VERIFY_LATER | SSL_DEBUG_OPTS | SSL_CONNECT_IN_PARTS | SSL_READ_BLOCKING | SSL_NO_DEFAULT_KEY, 0); + } + ++_ssl_client_ctx_refcnt; + } else { + if (_ssl_svr_ctx_refcnt == 0) { + _ssl_svr_ctx = ssl_ctx_new(SSL_SERVER_VERIFY_LATER | SSL_DEBUG_OPTS | SSL_CONNECT_IN_PARTS | SSL_READ_BLOCKING | SSL_NO_DEFAULT_KEY, 0); + } + ++_ssl_svr_ctx_refcnt; } - ++_ssl_ctx_refcnt; } ~SSLContext() { - if (_ssl) { - ssl_free(_ssl); - _ssl = nullptr; + if (io_ctx) { + io_ctx->unref(); + io_ctx = nullptr; } - - --_ssl_ctx_refcnt; - if (_ssl_ctx_refcnt == 0) { - ssl_ctx_free(_ssl_ctx); + _ssl = nullptr; + if (!_isServer) { + --_ssl_client_ctx_refcnt; + if (_ssl_client_ctx_refcnt == 0) { + ssl_ctx_free(_ssl_client_ctx); + _ssl_client_ctx = nullptr; + } + } else { + --_ssl_svr_ctx_refcnt; + if (_ssl_svr_ctx_refcnt == 0) { + ssl_ctx_free(_ssl_svr_ctx); + _ssl_svr_ctx = nullptr; + } } } - void ref() - { - ++_refcnt; - } - - void unref() + static void _delete_shared_SSL(SSL *_to_del) { - if (--_refcnt == 0) { - delete this; - } + ssl_free(_to_del); } void connect(ClientContext* ctx, const char* hostName, uint32_t timeout_ms) @@ -116,17 +126,23 @@ class SSLContext ssl_free will want to send a close notify alert, but the old TCP connection is already gone at this point, so reset io_ctx. */ io_ctx = nullptr; - ssl_free(_ssl); + _ssl = nullptr; _available = 0; _read_ptr = nullptr; } io_ctx = ctx; - _ssl = ssl_client_new(_ssl_ctx, reinterpret_cast(this), nullptr, 0, ext); + ctx->ref(); + + // Wrap the new SSL with a smart pointer, custom deleter to call ssl_free + SSL *_new_ssl = ssl_client_new(_ssl_client_ctx, reinterpret_cast(this), nullptr, 0, ext); + std::shared_ptr _new_ssl_shared(_new_ssl, _delete_shared_SSL); + _ssl = _new_ssl_shared; + uint32_t t = millis(); - while (millis() - t < timeout_ms && ssl_handshake_status(_ssl) != SSL_OK) { + while (millis() - t < timeout_ms && ssl_handshake_status(_ssl.get()) != SSL_OK) { uint8_t* data; - int rc = ssl_read(_ssl, &data); + int rc = ssl_read(_ssl.get(), &data); if (rc < SSL_OK) { ssl_display_error(rc); break; @@ -134,18 +150,23 @@ class SSLContext } } - void connectServer(ClientContext *ctx) { + void connectServer(ClientContext *ctx, uint32_t timeout_ms) + { io_ctx = ctx; - _ssl = ssl_server_new(_ssl_ctx, reinterpret_cast(this)); - _isServer = true; + ctx->ref(); + + // Wrap the new SSL with a smart pointer, custom deleter to call ssl_free + SSL *_new_ssl = ssl_server_new(_ssl_svr_ctx, reinterpret_cast(this)); + std::shared_ptr _new_ssl_shared(_new_ssl, _delete_shared_SSL); + _ssl = _new_ssl_shared; - uint32_t timeout_ms = 5000; uint32_t t = millis(); - while (millis() - t < timeout_ms && ssl_handshake_status(_ssl) != SSL_OK) { + while (millis() - t < timeout_ms && ssl_handshake_status(_ssl.get()) != SSL_OK) { uint8_t* data; - int rc = ssl_read(_ssl, &data); + int rc = ssl_read(_ssl.get(), &data); if (rc < SSL_OK) { + ssl_display_error(rc); break; } } @@ -153,13 +174,19 @@ class SSLContext void stop() { + if (io_ctx) { + io_ctx->unref(); + } io_ctx = nullptr; } bool connected() { - if (_isServer) return _ssl != nullptr; - else return _ssl != nullptr && ssl_handshake_status(_ssl) == SSL_OK; + if (_isServer) { + return _ssl != nullptr; + } else { + return _ssl != nullptr && ssl_handshake_status(_ssl.get()) == SSL_OK; + } } int read(uint8_t* dst, size_t size) @@ -289,10 +316,9 @@ class SSLContext return loadObject(type, buf.get(), size); } - bool loadObject(int type, const uint8_t* data, size_t size) { - int rc = ssl_obj_memory_load(_ssl_ctx, type, data, static_cast(size), nullptr); + int rc = ssl_obj_memory_load(_isServer?_ssl_svr_ctx:_ssl_client_ctx, type, data, static_cast(size), nullptr); if (rc != SSL_OK) { DEBUGV("loadObject: ssl_obj_memory_load returned %d\n", rc); return false; @@ -302,7 +328,7 @@ class SSLContext bool verifyCert() { - int rc = ssl_verify_cert(_ssl); + int rc = ssl_verify_cert(_ssl.get()); if (_allowSelfSignedCerts && rc == SSL_X509_ERROR(X509_VFY_ERROR_SELF_SIGNED)) { DEBUGV("Allowing self-signed certificate\n"); return true; @@ -321,12 +347,16 @@ class SSLContext operator SSL*() { - return _ssl; + return _ssl.get(); } static ClientContext* getIOContext(int fd) { - return reinterpret_cast(fd)->io_ctx; + if (fd) { + SSLContext *thisSSL = reinterpret_cast(fd); + return thisSSL->io_ctx; + } + return nullptr; } protected: @@ -339,10 +369,9 @@ class SSLContext optimistic_yield(100); uint8_t* data; - int rc = ssl_read(_ssl, &data); + int rc = ssl_read(_ssl.get(), &data); if (rc <= 0) { if (rc < SSL_OK && rc != SSL_CLOSE_NOTIFY && rc != SSL_ERROR_CONN_LOST) { - ssl_free(_ssl); _ssl = nullptr; } return 0; @@ -359,7 +388,7 @@ class SSLContext return 0; } - int rc = ssl_write(_ssl, src, size); + int rc = ssl_write(_ssl.get(), src, size); if (rc >= 0) { return rc; } @@ -404,10 +433,11 @@ class SSLContext } bool _isServer = false; - static SSL_CTX* _ssl_ctx; - static int _ssl_ctx_refcnt; - SSL* _ssl = nullptr; - int _refcnt = 0; + static SSL_CTX* _ssl_client_ctx; + static int _ssl_client_ctx_refcnt; + static SSL_CTX* _ssl_svr_ctx; + static int _ssl_svr_ctx_refcnt; + std::shared_ptr _ssl = nullptr; const uint8_t* _read_ptr = nullptr; size_t _available = 0; BufferList _writeBuffers; @@ -415,8 +445,10 @@ class SSLContext ClientContext* io_ctx = nullptr; }; -SSL_CTX* SSLContext::_ssl_ctx = nullptr; -int SSLContext::_ssl_ctx_refcnt = 0; +SSL_CTX* SSLContext::_ssl_client_ctx = nullptr; +int SSLContext::_ssl_client_ctx_refcnt = 0; +SSL_CTX* SSLContext::_ssl_svr_ctx = nullptr; +int SSLContext::_ssl_svr_ctx_refcnt = 0; WiFiClientSecure::WiFiClientSecure() { @@ -426,41 +458,25 @@ WiFiClientSecure::WiFiClientSecure() WiFiClientSecure::~WiFiClientSecure() { - if (_ssl) { - _ssl->unref(); - } -} - -WiFiClientSecure::WiFiClientSecure(const WiFiClientSecure& other) - : WiFiClient(static_cast(other)) -{ - _ssl = other._ssl; - if (_ssl) { - _ssl->ref(); - } -} - -WiFiClientSecure& WiFiClientSecure::operator=(const WiFiClientSecure& rhs) -{ - (WiFiClient&) *this = rhs; - _ssl = rhs._ssl; - if (_ssl) { - _ssl->ref(); - } - return *this; + _ssl = nullptr; } // Only called by the WifiServerSecure, need to get the keys/certs loaded before beginning -WiFiClientSecure::WiFiClientSecure(ClientContext* client, bool usePMEM, const uint8_t *rsakey, int rsakeyLen, const uint8_t *cert, int certLen) +WiFiClientSecure::WiFiClientSecure(ClientContext* client, bool usePMEM, + const uint8_t *rsakey, int rsakeyLen, + const uint8_t *cert, int certLen) { + // TLS handshake may take more than the 5 second default timeout + _timeout = 15000; + + // We've been given the client context from the available() call _client = client; - if (_ssl) { - _ssl->unref(); - _ssl = nullptr; - } + _client->ref(); - _ssl = new SSLContext; - _ssl->ref(); + // Make the "_ssl" SSLContext, in the constructor there should be none yet + SSLContext *_new_ssl = new SSLContext(true); + std::shared_ptr _new_ssl_shared(_new_ssl); + _ssl = _new_ssl_shared; if (usePMEM) { if (rsakey && rsakeyLen) { @@ -477,8 +493,7 @@ WiFiClientSecure::WiFiClientSecure(ClientContext* client, bool usePMEM, const ui _ssl->loadObject(SSL_OBJ_X509_CERT, cert, certLen); } } - _client->ref(); - _ssl->connectServer(client); + _ssl->connectServer(client, _timeout); } int WiFiClientSecure::connect(IPAddress ip, uint16_t port) @@ -510,14 +525,12 @@ int WiFiClientSecure::connect(const String host, uint16_t port) int WiFiClientSecure::_connectSSL(const char* hostName) { if (!_ssl) { - _ssl = new SSLContext; - _ssl->ref(); + _ssl = std::make_shared(); } _ssl->connect(_client, hostName, _timeout); auto status = ssl_handshake_status(*_ssl); if (status != SSL_OK) { - _ssl->unref(); _ssl = nullptr; return 0; } @@ -537,7 +550,6 @@ size_t WiFiClientSecure::write(const uint8_t *buf, size_t size) } if (rc != SSL_CLOSE_NOTIFY) { - _ssl->unref(); _ssl = nullptr; } @@ -640,8 +652,6 @@ void WiFiClientSecure::stop() { if (_ssl) { _ssl->stop(); - _ssl->unref(); - _ssl = nullptr; } WiFiClient::stop(); } @@ -723,9 +733,9 @@ bool WiFiClientSecure::_verifyDN(const char* domain_name) String domain_name_str(domain_name); domain_name_str.toLowerCase(); - const char* san = NULL; + const char* san = nullptr; int i = 0; - while ((san = ssl_get_cert_subject_alt_dnsname(*_ssl, i)) != NULL) { + while ((san = ssl_get_cert_subject_alt_dnsname(*_ssl, i)) != nullptr) { String san_str(san); san_str.toLowerCase(); if (matchName(san_str, domain_name_str)) { @@ -759,8 +769,7 @@ bool WiFiClientSecure::verifyCertChain(const char* domain_name) void WiFiClientSecure::_initSSLContext() { if (!_ssl) { - _ssl = new SSLContext; - _ssl->ref(); + _ssl = std::make_shared(); } } diff --git a/libraries/ESP8266WiFi/src/WiFiClientSecure.h b/libraries/ESP8266WiFi/src/WiFiClientSecure.h index 9b7cf8df10..73ec587f1d 100644 --- a/libraries/ESP8266WiFi/src/WiFiClientSecure.h +++ b/libraries/ESP8266WiFi/src/WiFiClientSecure.h @@ -32,8 +32,6 @@ class WiFiClientSecure : public WiFiClient { public: WiFiClientSecure(); ~WiFiClientSecure() override; - WiFiClientSecure(const WiFiClientSecure&); - WiFiClientSecure& operator=(const WiFiClientSecure&); int connect(IPAddress ip, uint16_t port) override; int connect(const String host, uint16_t port) override; @@ -91,7 +89,7 @@ friend class WiFiServerSecure; // Needs access to custom constructor below int _connectSSL(const char* hostName); bool _verifyDN(const char* name); - SSLContext* _ssl = nullptr; + std::shared_ptr _ssl = nullptr; }; #endif //wificlientsecure_h