]> granicus.if.org Git - esp-idf/blob - components/tcp_transport/transport_ssl.c
Merge branch 'feature/py3_espcoredump' into 'master'
[esp-idf] / components / tcp_transport / transport_ssl.c
1 // Copyright 2015-2018 Espressif Systems (Shanghai) PTE LTD
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include <string.h>
16 #include <stdlib.h>
17
18 #include "freertos/FreeRTOS.h"
19 #include "freertos/task.h"
20 #include "esp_tls.h"
21 #include "esp_log.h"
22 #include "esp_system.h"
23
24 #include "transport.h"
25 #include "transport_ssl.h"
26 #include "transport_utils.h"
27
28 static const char *TAG = "TRANS_SSL";
29 /**
30  *  mbedtls specific transport data
31  */
32 typedef struct {
33     esp_tls_t                *tls;
34     void                     *cert_pem_data;
35     int                      cert_pem_len;
36     bool                     ssl_initialized;
37     bool                     verify_server;
38 } transport_ssl_t;
39
40 transport_handle_t transport_get_handle(transport_handle_t t);
41
42 static int ssl_close(transport_handle_t t);
43
44 static int ssl_connect(transport_handle_t t, const char *host, int port, int timeout_ms)
45 {
46     transport_ssl_t *ssl = transport_get_context_data(t);
47     esp_tls_cfg_t cfg = { 0 };
48     if (ssl->cert_pem_data) {
49         ssl->verify_server = true;
50         cfg.cacert_pem_buf = ssl->cert_pem_data;
51         cfg.cacert_pem_bytes = ssl->cert_pem_len + 1;
52     }
53     cfg.timeout_ms = timeout_ms;
54     ssl->ssl_initialized = true;
55     ssl->tls = esp_tls_conn_new(host, strlen(host), port, &cfg);
56     if (!ssl->tls) {
57         ESP_LOGE(TAG, "Failed to open a new connection");
58         return -1;
59     }
60     return 0;
61 }
62
63 static int ssl_poll_read(transport_handle_t t, int timeout_ms)
64 {
65     transport_ssl_t *ssl = transport_get_context_data(t);
66     fd_set readset;
67     FD_ZERO(&readset);
68     FD_SET(ssl->tls->sockfd, &readset);
69     struct timeval timeout;
70     transport_utils_ms_to_timeval(timeout_ms, &timeout);
71
72     return select(ssl->tls->sockfd + 1, &readset, NULL, NULL, &timeout);
73 }
74
75 static int ssl_poll_write(transport_handle_t t, int timeout_ms)
76 {
77     transport_ssl_t *ssl = transport_get_context_data(t);
78     fd_set writeset;
79     FD_ZERO(&writeset);
80     FD_SET(ssl->tls->sockfd, &writeset);
81     struct timeval timeout;
82     transport_utils_ms_to_timeval(timeout_ms, &timeout);
83     return select(ssl->tls->sockfd + 1, NULL, &writeset, NULL, &timeout);
84 }
85
86 static int ssl_write(transport_handle_t t, const char *buffer, int len, int timeout_ms)
87 {
88     int poll, ret;
89     transport_ssl_t *ssl = transport_get_context_data(t);
90
91     if ((poll = transport_poll_write(t, timeout_ms)) <= 0) {
92         ESP_LOGW(TAG, "Poll timeout or error, errno=%s, fd=%d, timeout_ms=%d", strerror(errno), ssl->tls->sockfd, timeout_ms);
93         return poll;
94     }
95     ret = esp_tls_conn_write(ssl->tls, (const unsigned char *) buffer, len);
96     if (ret <= 0) {
97         ESP_LOGE(TAG, "mbedtls_ssl_write error, errno=%s", strerror(errno));
98     }
99     return ret;
100 }
101
102 static int ssl_read(transport_handle_t t, char *buffer, int len, int timeout_ms)
103 {
104     int poll, ret;
105     transport_ssl_t *ssl = transport_get_context_data(t);
106
107     if (esp_tls_get_bytes_avail(ssl->tls) <= 0) {
108         if ((poll = transport_poll_read(t, timeout_ms)) <= 0) {
109             return poll;
110         }
111     }
112     ret = esp_tls_conn_read(ssl->tls, (unsigned char *)buffer, len);
113     if (ret <= 0) {
114         ESP_LOGE(TAG, "mbedtls_ssl_read error, errno=%s", strerror(errno));
115     }
116     return ret;
117 }
118
119 static int ssl_close(transport_handle_t t)
120 {
121     int ret = -1;
122     transport_ssl_t *ssl = transport_get_context_data(t);
123     if (ssl->ssl_initialized) {
124         esp_tls_conn_delete(ssl->tls);
125         ssl->ssl_initialized = false;
126         ssl->verify_server = false;
127     }
128     return ret;
129 }
130
131 static int ssl_destroy(transport_handle_t t)
132 {
133     transport_ssl_t *ssl = transport_get_context_data(t);
134     transport_close(t);
135     free(ssl);
136     return 0;
137 }
138
139 void transport_ssl_set_cert_data(transport_handle_t t, const char *data, int len)
140 {
141     transport_ssl_t *ssl = transport_get_context_data(t);
142     if (t && ssl) {
143         ssl->cert_pem_data = (void *)data;
144         ssl->cert_pem_len = len;
145     }
146 }
147
148 transport_handle_t transport_ssl_init()
149 {
150     transport_handle_t t = transport_init();
151     transport_ssl_t *ssl = calloc(1, sizeof(transport_ssl_t));
152     TRANSPORT_MEM_CHECK(TAG, ssl, return NULL);
153     transport_set_context_data(t, ssl);
154     transport_set_func(t, ssl_connect, ssl_read, ssl_write, ssl_close, ssl_poll_read, ssl_poll_write, ssl_destroy, transport_get_handle);
155     return t;
156 }
157