]> granicus.if.org Git - esp-idf/blob - components/tcp_transport/transport_ssl.c
Merge branch 'bugfix/fix_mesh_proxy_adv_with_wrong_dev_name' into 'master'
[esp-idf] / components / tcp_transport / transport_ssl.c
1 // Copyright 2015-2018 Espressif Systems (Shanghai) PTE LTD
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include <string.h>
16 #include <stdlib.h>
17
18 #include "freertos/FreeRTOS.h"
19 #include "freertos/task.h"
20 #include "esp_tls.h"
21 #include "esp_log.h"
22 #include "esp_system.h"
23
24 #include "esp_transport.h"
25 #include "esp_transport_ssl.h"
26 #include "esp_transport_utils.h"
27 #include "esp_transport_ssl_internal.h"
28
29 static const char *TAG = "TRANS_SSL";
30
31 typedef enum {
32     TRANS_SSL_INIT = 0,
33     TRANS_SSL_CONNECTING,
34 } transport_ssl_conn_state_t;
35
36 /**
37  *  mbedtls specific transport data
38  */
39 typedef struct {
40     esp_tls_t                *tls;
41     esp_tls_cfg_t            cfg;
42     bool                     ssl_initialized;
43     transport_ssl_conn_state_t conn_state;
44 } transport_ssl_t;
45
46 static int ssl_close(esp_transport_handle_t t);
47
48 static int ssl_connect_async(esp_transport_handle_t t, const char *host, int port, int timeout_ms)
49 {
50     transport_ssl_t *ssl = esp_transport_get_context_data(t);
51     if (ssl->conn_state == TRANS_SSL_INIT) {
52         ssl->cfg.timeout_ms = timeout_ms;
53         ssl->cfg.non_block = true;
54         ssl->ssl_initialized = true;
55         ssl->tls = esp_tls_init();
56         if (!ssl->tls) {
57             return -1;
58         }
59         ssl->conn_state = TRANS_SSL_CONNECTING;
60     }
61     if (ssl->conn_state == TRANS_SSL_CONNECTING) {
62         return esp_tls_conn_new_async(host, strlen(host), port, &ssl->cfg, ssl->tls);
63     }
64     return 0;
65 }
66
67 static int ssl_connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms)
68 {
69     transport_ssl_t *ssl = esp_transport_get_context_data(t);
70
71     ssl->cfg.timeout_ms = timeout_ms;
72     ssl->ssl_initialized = true;
73     ssl->tls = esp_tls_init();
74     if (esp_tls_conn_new_sync(host, strlen(host), port, &ssl->cfg, ssl->tls) < 0) {
75         ESP_LOGE(TAG, "Failed to open a new connection");
76         esp_transport_set_errors(t, ssl->tls->error_handle);
77         esp_tls_conn_delete(ssl->tls);
78         ssl->tls = NULL;
79         return -1;
80     }
81
82     return 0;
83 }
84
85 static int ssl_poll_read(esp_transport_handle_t t, int timeout_ms)
86 {
87     transport_ssl_t *ssl = esp_transport_get_context_data(t);
88     int ret = -1;
89     fd_set readset;
90     fd_set errset;
91     FD_ZERO(&readset);
92     FD_ZERO(&errset);
93     FD_SET(ssl->tls->sockfd, &readset);
94     FD_SET(ssl->tls->sockfd, &errset);
95     struct timeval timeout;
96     esp_transport_utils_ms_to_timeval(timeout_ms, &timeout);
97
98     ret = select(ssl->tls->sockfd + 1, &readset, NULL, &errset, &timeout);
99     if (ret > 0 && FD_ISSET(ssl->tls->sockfd, &errset)) {
100         int sock_errno = 0;
101         uint32_t optlen = sizeof(sock_errno);
102         getsockopt(ssl->tls->sockfd, SOL_SOCKET, SO_ERROR, &sock_errno, &optlen);
103         ESP_LOGE(TAG, "ssl_poll_read select error %d, errno = %s, fd = %d", sock_errno, strerror(sock_errno), ssl->tls->sockfd);
104         ret = -1;
105     }
106     return ret;
107 }
108
109 static int ssl_poll_write(esp_transport_handle_t t, int timeout_ms)
110 {
111     transport_ssl_t *ssl = esp_transport_get_context_data(t);
112     int ret = -1;
113     fd_set writeset;
114     fd_set errset;
115     FD_ZERO(&writeset);
116     FD_ZERO(&errset);
117     FD_SET(ssl->tls->sockfd, &writeset);
118     FD_SET(ssl->tls->sockfd, &errset);
119     struct timeval timeout;
120     esp_transport_utils_ms_to_timeval(timeout_ms, &timeout);
121     ret = select(ssl->tls->sockfd + 1, NULL, &writeset, &errset, &timeout);
122     if (ret > 0 && FD_ISSET(ssl->tls->sockfd, &errset)) {
123         int sock_errno = 0;
124         uint32_t optlen = sizeof(sock_errno);
125         getsockopt(ssl->tls->sockfd, SOL_SOCKET, SO_ERROR, &sock_errno, &optlen);
126         ESP_LOGE(TAG, "ssl_poll_write select error %d, errno = %s, fd = %d", sock_errno, strerror(sock_errno), ssl->tls->sockfd);
127         ret = -1;
128     }
129     return ret;
130 }
131
132 static int ssl_write(esp_transport_handle_t t, const char *buffer, int len, int timeout_ms)
133 {
134     int poll, ret;
135     transport_ssl_t *ssl = esp_transport_get_context_data(t);
136
137     if ((poll = esp_transport_poll_write(t, timeout_ms)) <= 0) {
138         ESP_LOGW(TAG, "Poll timeout or error, errno=%s, fd=%d, timeout_ms=%d", strerror(errno), ssl->tls->sockfd, timeout_ms);
139         return poll;
140     }
141     ret = esp_tls_conn_write(ssl->tls, (const unsigned char *) buffer, len);
142     if (ret < 0) {
143         ESP_LOGE(TAG, "esp_tls_conn_write error, errno=%s", strerror(errno));
144         esp_transport_set_errors(t, ssl->tls->error_handle);
145     }
146     return ret;
147 }
148
149 static int ssl_read(esp_transport_handle_t t, char *buffer, int len, int timeout_ms)
150 {
151     int poll, ret;
152     transport_ssl_t *ssl = esp_transport_get_context_data(t);
153
154     if (esp_tls_get_bytes_avail(ssl->tls) <= 0) {
155         if ((poll = esp_transport_poll_read(t, timeout_ms)) <= 0) {
156             return poll;
157         }
158     }
159     ret = esp_tls_conn_read(ssl->tls, (unsigned char *)buffer, len);
160     if (ret < 0) {
161         ESP_LOGE(TAG, "esp_tls_conn_read error, errno=%s", strerror(errno));
162         esp_transport_set_errors(t, ssl->tls->error_handle);
163     }
164     if (ret == 0) {
165         ret = -1;
166     }
167     return ret;
168 }
169
170 static int ssl_close(esp_transport_handle_t t)
171 {
172     int ret = -1;
173     transport_ssl_t *ssl = esp_transport_get_context_data(t);
174     if (ssl->ssl_initialized) {
175         esp_tls_conn_delete(ssl->tls);
176         ssl->ssl_initialized = false;
177     }
178     return ret;
179 }
180
181 static int ssl_destroy(esp_transport_handle_t t)
182 {
183     transport_ssl_t *ssl = esp_transport_get_context_data(t);
184     esp_transport_close(t);
185     free(ssl);
186     return 0;
187 }
188
189 void esp_transport_ssl_enable_global_ca_store(esp_transport_handle_t t)
190 {
191     transport_ssl_t *ssl = esp_transport_get_context_data(t);
192     if (t && ssl) {
193         ssl->cfg.use_global_ca_store = true;
194     }
195 }
196
197 void esp_transport_ssl_set_psk_key_hint(esp_transport_handle_t t, const psk_hint_key_t* psk_hint_key)
198 {
199     transport_ssl_t *ssl = esp_transport_get_context_data(t);
200     if (t && ssl) {
201         ssl->cfg.psk_hint_key =  psk_hint_key;
202     }
203 }
204
205 void esp_transport_ssl_set_cert_data(esp_transport_handle_t t, const char *data, int len)
206 {
207     transport_ssl_t *ssl = esp_transport_get_context_data(t);
208     if (t && ssl) {
209         ssl->cfg.cacert_pem_buf = (void *)data;
210         ssl->cfg.cacert_pem_bytes = len + 1;
211     }
212 }
213
214 void esp_transport_ssl_set_cert_data_der(esp_transport_handle_t t, const char *data, int len)
215 {
216     transport_ssl_t *ssl = esp_transport_get_context_data(t);
217     if (t && ssl) {
218         ssl->cfg.cacert_buf = (void *)data;
219         ssl->cfg.cacert_bytes = len;
220     }
221 }
222
223 void esp_transport_ssl_set_client_cert_data(esp_transport_handle_t t, const char *data, int len)
224 {
225     transport_ssl_t *ssl = esp_transport_get_context_data(t);
226     if (t && ssl) {
227         ssl->cfg.clientcert_pem_buf = (void *)data;
228         ssl->cfg.clientcert_pem_bytes = len + 1;
229     }
230 }
231
232 void esp_transport_ssl_set_client_cert_data_der(esp_transport_handle_t t, const char *data, int len)
233 {
234     transport_ssl_t *ssl = esp_transport_get_context_data(t);
235     if (t && ssl) {
236         ssl->cfg.clientcert_buf = (void *)data;
237         ssl->cfg.clientcert_bytes = len;
238     }
239 }
240
241 void esp_transport_ssl_set_client_key_data(esp_transport_handle_t t, const char *data, int len)
242 {
243     transport_ssl_t *ssl = esp_transport_get_context_data(t);
244     if (t && ssl) {
245         ssl->cfg.clientkey_pem_buf = (void *)data;
246         ssl->cfg.clientkey_pem_bytes = len + 1;
247     }
248 }
249
250 void esp_transport_ssl_set_client_key_data_der(esp_transport_handle_t t, const char *data, int len)
251 {
252     transport_ssl_t *ssl = esp_transport_get_context_data(t);
253     if (t && ssl) {
254         ssl->cfg.clientkey_buf = (void *)data;
255         ssl->cfg.clientkey_bytes = len;
256     }
257 }
258
259 void esp_transport_ssl_skip_common_name_check(esp_transport_handle_t t)
260 {
261     transport_ssl_t *ssl = esp_transport_get_context_data(t);
262     if (t && ssl) {
263         ssl->cfg.skip_common_name = true;
264     }
265 }
266
267 esp_transport_handle_t esp_transport_ssl_init(void)
268 {
269     esp_transport_handle_t t = esp_transport_init();
270     transport_ssl_t *ssl = calloc(1, sizeof(transport_ssl_t));
271     ESP_TRANSPORT_MEM_CHECK(TAG, ssl, return NULL);
272     esp_transport_set_context_data(t, ssl);
273     esp_transport_set_func(t, ssl_connect, ssl_read, ssl_write, ssl_close, ssl_poll_read, ssl_poll_write, ssl_destroy);
274     esp_transport_set_async_connect_func(t, ssl_connect_async);
275     return t;
276 }
277