#include "transport_utils.h"
static const char *TAG = "TRANS_SSL";
+
+typedef enum {
+ TRANS_SSL_INIT = 0,
+ TRANS_SSL_CONNECTING,
+} transport_ssl_conn_state_t;
+
/**
* mbedtls specific transport data
*/
typedef struct {
esp_tls_t *tls;
- void *cert_pem_data;
- int cert_pem_len;
+ esp_tls_cfg_t cfg;
bool ssl_initialized;
bool verify_server;
+ transport_ssl_conn_state_t conn_state;
} transport_ssl_t;
transport_handle_t transport_get_handle(transport_handle_t t);
static int ssl_close(transport_handle_t t);
+static int ssl_connect_async(transport_handle_t t, const char *host, int port, int timeout_ms)
+{
+ transport_ssl_t *ssl = transport_get_context_data(t);
+ if (ssl->conn_state == TRANS_SSL_INIT) {
+ if (ssl->cfg.cacert_pem_buf) {
+ ssl->verify_server = true;
+ }
+ ssl->cfg.timeout_ms = timeout_ms;
+ ssl->cfg.non_block = true;
+ ssl->ssl_initialized = true;
+ ssl->tls = calloc(1, sizeof(esp_tls_t));
+ if (!ssl->tls) {
+ return -1;
+ }
+ ssl->conn_state = TRANS_SSL_CONNECTING;
+ }
+ if (ssl->conn_state == TRANS_SSL_CONNECTING) {
+ return esp_tls_conn_new_async(host, strlen(host), port, &ssl->cfg, ssl->tls);
+ }
+ return 0;
+}
+
static int ssl_connect(transport_handle_t t, const char *host, int port, int timeout_ms)
{
transport_ssl_t *ssl = transport_get_context_data(t);
- esp_tls_cfg_t cfg = { 0 };
- if (ssl->cert_pem_data) {
+ if (ssl->cfg.cacert_pem_buf) {
ssl->verify_server = true;
- cfg.cacert_pem_buf = ssl->cert_pem_data;
- cfg.cacert_pem_bytes = ssl->cert_pem_len + 1;
}
- cfg.timeout_ms = timeout_ms;
+ ssl->cfg.timeout_ms = timeout_ms;
ssl->ssl_initialized = true;
- ssl->tls = esp_tls_conn_new(host, strlen(host), port, &cfg);
+ ssl->tls = esp_tls_conn_new(host, strlen(host), port, &ssl->cfg);
if (!ssl->tls) {
ESP_LOGE(TAG, "Failed to open a new connection");
return -1;
}
ret = esp_tls_conn_write(ssl->tls, (const unsigned char *) buffer, len);
if (ret <= 0) {
- ESP_LOGE(TAG, "mbedtls_ssl_write error, errno=%s", strerror(errno));
+ ESP_LOGE(TAG, "esp_tls_conn_write error, errno=%s", strerror(errno));
}
return ret;
}
}
ret = esp_tls_conn_read(ssl->tls, (unsigned char *)buffer, len);
if (ret <= 0) {
- ESP_LOGE(TAG, "mbedtls_ssl_read error, errno=%s", strerror(errno));
+ ESP_LOGE(TAG, "esp_tls_conn_read error, errno=%s", strerror(errno));
}
return ret;
}
{
transport_ssl_t *ssl = transport_get_context_data(t);
if (t && ssl) {
- ssl->cert_pem_data = (void *)data;
- ssl->cert_pem_len = len;
+ ssl->cfg.cacert_pem_buf = (void *)data;
+ ssl->cfg.cacert_pem_bytes = len + 1;
}
}
TRANSPORT_MEM_CHECK(TAG, ssl, return NULL);
transport_set_context_data(t, ssl);
transport_set_func(t, ssl_connect, ssl_read, ssl_write, ssl_close, ssl_poll_read, ssl_poll_write, ssl_destroy, transport_get_handle);
+ transport_set_async_connect_func(t, ssl_connect_async);
return t;
}