]> granicus.if.org Git - esp-idf/blob - components/tcp_transport/transport_ssl.c
esp_http_client: Add support for non-blocking feature in esp_http_client_perform...
[esp-idf] / components / tcp_transport / transport_ssl.c
1 // Copyright 2015-2018 Espressif Systems (Shanghai) PTE LTD
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include <string.h>
16 #include <stdlib.h>
17
18 #include "freertos/FreeRTOS.h"
19 #include "freertos/task.h"
20 #include "esp_tls.h"
21 #include "esp_log.h"
22 #include "esp_system.h"
23
24 #include "transport.h"
25 #include "transport_ssl.h"
26 #include "transport_utils.h"
27
28 static const char *TAG = "TRANS_SSL";
29
30 typedef enum {
31     TRANS_SSL_INIT = 0,
32     TRANS_SSL_CONNECTING,
33 } transport_ssl_conn_state_t;
34
35 /**
36  *  mbedtls specific transport data
37  */
38 typedef struct {
39     esp_tls_t                *tls;
40     esp_tls_cfg_t            cfg;
41     bool                     ssl_initialized;
42     bool                     verify_server;
43     transport_ssl_conn_state_t conn_state;
44 } transport_ssl_t;
45
46 transport_handle_t transport_get_handle(transport_handle_t t);
47
48 static int ssl_close(transport_handle_t t);
49
50 static int ssl_connect_async(transport_handle_t t, const char *host, int port, int timeout_ms)
51 {
52     transport_ssl_t *ssl = transport_get_context_data(t);
53     if (ssl->conn_state == TRANS_SSL_INIT) {
54         if (ssl->cfg.cacert_pem_buf) {
55             ssl->verify_server = true;
56         }
57         ssl->cfg.timeout_ms = timeout_ms;
58         ssl->cfg.non_block = true;
59         ssl->ssl_initialized = true;
60         ssl->tls = calloc(1, sizeof(esp_tls_t));
61         if (!ssl->tls) {
62             return -1;
63         }
64         ssl->conn_state = TRANS_SSL_CONNECTING;
65     }
66     if (ssl->conn_state == TRANS_SSL_CONNECTING) {
67         return esp_tls_conn_new_async(host, strlen(host), port, &ssl->cfg, ssl->tls);
68     }
69     return 0;
70 }
71
72 static int ssl_connect(transport_handle_t t, const char *host, int port, int timeout_ms)
73 {
74     transport_ssl_t *ssl = transport_get_context_data(t);
75     if (ssl->cfg.cacert_pem_buf) {
76         ssl->verify_server = true;
77     }
78     ssl->cfg.timeout_ms = timeout_ms;
79     ssl->ssl_initialized = true;
80     ssl->tls = esp_tls_conn_new(host, strlen(host), port, &ssl->cfg);
81     if (!ssl->tls) {
82         ESP_LOGE(TAG, "Failed to open a new connection");
83         return -1;
84     }
85     return 0;
86 }
87
88 static int ssl_poll_read(transport_handle_t t, int timeout_ms)
89 {
90     transport_ssl_t *ssl = transport_get_context_data(t);
91     fd_set readset;
92     FD_ZERO(&readset);
93     FD_SET(ssl->tls->sockfd, &readset);
94     struct timeval timeout;
95     transport_utils_ms_to_timeval(timeout_ms, &timeout);
96
97     return select(ssl->tls->sockfd + 1, &readset, NULL, NULL, &timeout);
98 }
99
100 static int ssl_poll_write(transport_handle_t t, int timeout_ms)
101 {
102     transport_ssl_t *ssl = transport_get_context_data(t);
103     fd_set writeset;
104     FD_ZERO(&writeset);
105     FD_SET(ssl->tls->sockfd, &writeset);
106     struct timeval timeout;
107     transport_utils_ms_to_timeval(timeout_ms, &timeout);
108     return select(ssl->tls->sockfd + 1, NULL, &writeset, NULL, &timeout);
109 }
110
111 static int ssl_write(transport_handle_t t, const char *buffer, int len, int timeout_ms)
112 {
113     int poll, ret;
114     transport_ssl_t *ssl = transport_get_context_data(t);
115
116     if ((poll = transport_poll_write(t, timeout_ms)) <= 0) {
117         ESP_LOGW(TAG, "Poll timeout or error, errno=%s, fd=%d, timeout_ms=%d", strerror(errno), ssl->tls->sockfd, timeout_ms);
118         return poll;
119     }
120     ret = esp_tls_conn_write(ssl->tls, (const unsigned char *) buffer, len);
121     if (ret <= 0) {
122         ESP_LOGE(TAG, "esp_tls_conn_write error, errno=%s", strerror(errno));
123     }
124     return ret;
125 }
126
127 static int ssl_read(transport_handle_t t, char *buffer, int len, int timeout_ms)
128 {
129     int poll, ret;
130     transport_ssl_t *ssl = transport_get_context_data(t);
131
132     if (esp_tls_get_bytes_avail(ssl->tls) <= 0) {
133         if ((poll = transport_poll_read(t, timeout_ms)) <= 0) {
134             return poll;
135         }
136     }
137     ret = esp_tls_conn_read(ssl->tls, (unsigned char *)buffer, len);
138     if (ret <= 0) {
139         ESP_LOGE(TAG, "esp_tls_conn_read error, errno=%s", strerror(errno));
140     }
141     return ret;
142 }
143
144 static int ssl_close(transport_handle_t t)
145 {
146     int ret = -1;
147     transport_ssl_t *ssl = transport_get_context_data(t);
148     if (ssl->ssl_initialized) {
149         esp_tls_conn_delete(ssl->tls);
150         ssl->ssl_initialized = false;
151         ssl->verify_server = false;
152     }
153     return ret;
154 }
155
156 static int ssl_destroy(transport_handle_t t)
157 {
158     transport_ssl_t *ssl = transport_get_context_data(t);
159     transport_close(t);
160     free(ssl);
161     return 0;
162 }
163
164 void transport_ssl_set_cert_data(transport_handle_t t, const char *data, int len)
165 {
166     transport_ssl_t *ssl = transport_get_context_data(t);
167     if (t && ssl) {
168         ssl->cfg.cacert_pem_buf = (void *)data;
169         ssl->cfg.cacert_pem_bytes = len + 1;
170     }
171 }
172
173 transport_handle_t transport_ssl_init()
174 {
175     transport_handle_t t = transport_init();
176     transport_ssl_t *ssl = calloc(1, sizeof(transport_ssl_t));
177     TRANSPORT_MEM_CHECK(TAG, ssl, return NULL);
178     transport_set_context_data(t, ssl);
179     transport_set_func(t, ssl_connect, ssl_read, ssl_write, ssl_close, ssl_poll_read, ssl_poll_write, ssl_destroy, transport_get_handle);
180     transport_set_async_connect_func(t, ssl_connect_async);
181     return t;
182 }
183