]> granicus.if.org Git - esp-idf/blobdiff - components/tcp_transport/transport_ssl.c
esp_http_client: Add support for non-blocking feature in esp_http_client_perform...
[esp-idf] / components / tcp_transport / transport_ssl.c
index 9ccaf4028ed2f34ff3e86c8ea9ee062782db8959..8e995764e6ee7b35b1c4a8211f8469d531633567 100644 (file)
 #include "transport_utils.h"
 
 static const char *TAG = "TRANS_SSL";
+
+typedef enum {
+    TRANS_SSL_INIT = 0,
+    TRANS_SSL_CONNECTING,
+} transport_ssl_conn_state_t;
+
 /**
  *  mbedtls specific transport data
  */
 typedef struct {
     esp_tls_t                *tls;
-    void                     *cert_pem_data;
-    int                      cert_pem_len;
+    esp_tls_cfg_t            cfg;
     bool                     ssl_initialized;
     bool                     verify_server;
+    transport_ssl_conn_state_t conn_state;
 } transport_ssl_t;
 
 transport_handle_t transport_get_handle(transport_handle_t t);
 
 static int ssl_close(transport_handle_t t);
 
+static int ssl_connect_async(transport_handle_t t, const char *host, int port, int timeout_ms)
+{
+    transport_ssl_t *ssl = transport_get_context_data(t);
+    if (ssl->conn_state == TRANS_SSL_INIT) {
+        if (ssl->cfg.cacert_pem_buf) {
+            ssl->verify_server = true;
+        }
+        ssl->cfg.timeout_ms = timeout_ms;
+        ssl->cfg.non_block = true;
+        ssl->ssl_initialized = true;
+        ssl->tls = calloc(1, sizeof(esp_tls_t));
+        if (!ssl->tls) {
+            return -1;
+        }
+        ssl->conn_state = TRANS_SSL_CONNECTING;
+    }
+    if (ssl->conn_state == TRANS_SSL_CONNECTING) {
+        return esp_tls_conn_new_async(host, strlen(host), port, &ssl->cfg, ssl->tls);
+    }
+    return 0;
+}
+
 static int ssl_connect(transport_handle_t t, const char *host, int port, int timeout_ms)
 {
     transport_ssl_t *ssl = transport_get_context_data(t);
-    esp_tls_cfg_t cfg = { 0 };
-    if (ssl->cert_pem_data) {
+    if (ssl->cfg.cacert_pem_buf) {
         ssl->verify_server = true;
-        cfg.cacert_pem_buf = ssl->cert_pem_data;
-        cfg.cacert_pem_bytes = ssl->cert_pem_len + 1;
     }
-    cfg.timeout_ms = timeout_ms;
+    ssl->cfg.timeout_ms = timeout_ms;
     ssl->ssl_initialized = true;
-    ssl->tls = esp_tls_conn_new(host, strlen(host), port, &cfg);
+    ssl->tls = esp_tls_conn_new(host, strlen(host), port, &ssl->cfg);
     if (!ssl->tls) {
         ESP_LOGE(TAG, "Failed to open a new connection");
         return -1;
@@ -94,7 +119,7 @@ static int ssl_write(transport_handle_t t, const char *buffer, int len, int time
     }
     ret = esp_tls_conn_write(ssl->tls, (const unsigned char *) buffer, len);
     if (ret <= 0) {
-        ESP_LOGE(TAG, "mbedtls_ssl_write error, errno=%s", strerror(errno));
+        ESP_LOGE(TAG, "esp_tls_conn_write error, errno=%s", strerror(errno));
     }
     return ret;
 }
@@ -111,7 +136,7 @@ static int ssl_read(transport_handle_t t, char *buffer, int len, int timeout_ms)
     }
     ret = esp_tls_conn_read(ssl->tls, (unsigned char *)buffer, len);
     if (ret <= 0) {
-        ESP_LOGE(TAG, "mbedtls_ssl_read error, errno=%s", strerror(errno));
+        ESP_LOGE(TAG, "esp_tls_conn_read error, errno=%s", strerror(errno));
     }
     return ret;
 }
@@ -140,8 +165,8 @@ void transport_ssl_set_cert_data(transport_handle_t t, const char *data, int len
 {
     transport_ssl_t *ssl = transport_get_context_data(t);
     if (t && ssl) {
-        ssl->cert_pem_data = (void *)data;
-        ssl->cert_pem_len = len;
+        ssl->cfg.cacert_pem_buf = (void *)data;
+        ssl->cfg.cacert_pem_bytes = len + 1;
     }
 }
 
@@ -152,6 +177,7 @@ transport_handle_t transport_ssl_init()
     TRANSPORT_MEM_CHECK(TAG, ssl, return NULL);
     transport_set_context_data(t, ssl);
     transport_set_func(t, ssl_connect, ssl_read, ssl_write, ssl_close, ssl_poll_read, ssl_poll_write, ssl_destroy, transport_get_handle);
+    transport_set_async_connect_func(t, ssl_connect_async);
     return t;
 }