]> granicus.if.org Git - python/commitdiff
add support for ALPN (closes #20188)
authorBenjamin Peterson <benjamin@python.org>
Fri, 23 Jan 2015 21:35:37 +0000 (16:35 -0500)
committerBenjamin Peterson <benjamin@python.org>
Fri, 23 Jan 2015 21:35:37 +0000 (16:35 -0500)
Doc/library/ssl.rst
Lib/ssl.py
Lib/test/test_ssl.py
Misc/NEWS
Modules/_ssl.c

index d77c0284b5721064cfb2611762942f677259758d..11b8aa908508a1e67f0d30ac465f66acb3018aee 100644 (file)
@@ -673,6 +673,13 @@ Constants
 
    .. versionadded:: 3.3
 
+.. data:: HAS_ALPN
+
+   Whether the OpenSSL library has built-in support for the *Application-Layer
+   Protocol Negotiation* TLS extension as described in :rfc:`7301`.
+
+   .. versionadded:: 3.5
+
 .. data:: HAS_ECDH
 
    Whether the OpenSSL library has built-in support for Elliptic Curve-based
@@ -959,9 +966,18 @@ SSL sockets also have the following additional methods and attributes:
 
    .. versionadded:: 3.3
 
+.. method:: SSLSocket.selected_alpn_protocol()
+
+   Return the protocol that was selected during the TLS handshake.  If
+   :meth:`SSLContext.set_alpn_protocols` was not called, if the other party does
+   not support ALPN, or if the handshake has not happened yet, ``None`` is
+   returned.
+
+   .. versionadded:: 3.5
+
 .. method:: SSLSocket.selected_npn_protocol()
 
-   Returns the higher-level protocol that was selected during the TLS/SSL
+   Return the higher-level protocol that was selected during the TLS/SSL
    handshake. If :meth:`SSLContext.set_npn_protocols` was not called, or
    if the other party does not support NPN, or if the handshake has not yet
    happened, this will return ``None``.
@@ -1160,6 +1176,20 @@ to speed up repeated connections from the same clients.
       when connected, the :meth:`SSLSocket.cipher` method of SSL sockets will
       give the currently selected cipher.
 
+.. method:: SSLContext.set_alpn_protocols(protocols)
+
+   Specify which protocols the socket should advertise during the SSL/TLS
+   handshake. It should be a list of ASCII strings, like ``['http/1.1',
+   'spdy/2']``, ordered by preference. The selection of a protocol will happen
+   during the handshake, and will play out according to :rfc:`7301`. After a
+   successful handshake, the :meth:`SSLSocket.selected_alpn_protocol` method will
+   return the agreed-upon protocol.
+
+   This method will raise :exc:`NotImplementedError` if :data:`HAS_ALPN` is
+   False.
+
+   .. versionadded:: 3.5
+
 .. method:: SSLContext.set_npn_protocols(protocols)
 
    Specify which protocols the socket should advertise during the SSL/TLS
@@ -1200,7 +1230,7 @@ to speed up repeated connections from the same clients.
 
    Due to the early negotiation phase of the TLS connection, only limited
    methods and attributes are usable like
-   :meth:`SSLSocket.selected_npn_protocol` and :attr:`SSLSocket.context`.
+   :meth:`SSLSocket.selected_alpn_protocol` and :attr:`SSLSocket.context`.
    :meth:`SSLSocket.getpeercert`, :meth:`SSLSocket.getpeercert`,
    :meth:`SSLSocket.cipher` and :meth:`SSLSocket.compress` methods require that
    the TLS connection has progressed beyond the TLS Client Hello and therefore
index 39019f9b1344616bfce1fb65f0971af9483493fa..807e9f2896d19f346f78f5ede21cca1105255b3e 100644 (file)
@@ -122,7 +122,7 @@ _import_symbols('OP_')
 _import_symbols('ALERT_DESCRIPTION_')
 _import_symbols('SSL_ERROR_')
 
-from _ssl import HAS_SNI, HAS_ECDH, HAS_NPN
+from _ssl import HAS_SNI, HAS_ECDH, HAS_NPN, HAS_ALPN
 
 from _ssl import _OPENSSL_API_VERSION
 
@@ -374,6 +374,17 @@ class SSLContext(_SSLContext):
 
         self._set_npn_protocols(protos)
 
+    def set_alpn_protocols(self, alpn_protocols):
+        protos = bytearray()
+        for protocol in alpn_protocols:
+            b = bytes(protocol, 'ascii')
+            if len(b) == 0 or len(b) > 255:
+                raise SSLError('ALPN protocols must be 1 to 255 in length')
+            protos.append(len(b))
+            protos.extend(b)
+
+        self._set_alpn_protocols(protos)
+
     def _load_windows_store_certs(self, storename, purpose):
         certs = bytearray()
         for cert, encoding, trust in enum_certificates(storename):
@@ -567,6 +578,13 @@ class SSLObject:
         if _ssl.HAS_NPN:
             return self._sslobj.selected_npn_protocol()
 
+    def selected_alpn_protocol(self):
+        """Return the currently selected ALPN protocol as a string, or ``None``
+        if a next protocol was not negotiated or if ALPN is not supported by one
+        of the peers."""
+        if _ssl.HAS_ALPN:
+            return self._sslobj.selected_alpn_protocol()
+
     def cipher(self):
         """Return the currently selected cipher as a 3-tuple ``(name,
         ssl_version, secret_bits)``."""
@@ -783,6 +801,13 @@ class SSLSocket(socket):
         else:
             return self._sslobj.selected_npn_protocol()
 
+    def selected_alpn_protocol(self):
+        self._checkClosed()
+        if not self._sslobj or not _ssl.HAS_ALPN:
+            return None
+        else:
+            return self._sslobj.selected_alpn_protocol()
+
     def cipher(self):
         self._checkClosed()
         if not self._sslobj:
index b7504c63f07d9075eb1207c4a56a0e7019670041..30af08d0b0d470d1cde0cc1ca825f71fdb120619 100644 (file)
@@ -1761,7 +1761,8 @@ else:
                 try:
                     self.sslconn = self.server.context.wrap_socket(
                         self.sock, server_side=True)
-                    self.server.selected_protocols.append(self.sslconn.selected_npn_protocol())
+                    self.server.selected_npn_protocols.append(self.sslconn.selected_npn_protocol())
+                    self.server.selected_alpn_protocols.append(self.sslconn.selected_alpn_protocol())
                 except (ssl.SSLError, ConnectionResetError) as e:
                     # We treat ConnectionResetError as though it were an
                     # SSLError - OpenSSL on Ubuntu abruptly closes the
@@ -1869,7 +1870,8 @@ else:
         def __init__(self, certificate=None, ssl_version=None,
                      certreqs=None, cacerts=None,
                      chatty=True, connectionchatty=False, starttls_server=False,
-                     npn_protocols=None, ciphers=None, context=None):
+                     npn_protocols=None, alpn_protocols=None,
+                     ciphers=None, context=None):
             if context:
                 self.context = context
             else:
@@ -1884,6 +1886,8 @@ else:
                     self.context.load_cert_chain(certificate)
                 if npn_protocols:
                     self.context.set_npn_protocols(npn_protocols)
+                if alpn_protocols:
+                    self.context.set_alpn_protocols(alpn_protocols)
                 if ciphers:
                     self.context.set_ciphers(ciphers)
             self.chatty = chatty
@@ -1893,7 +1897,8 @@ else:
             self.port = support.bind_port(self.sock)
             self.flag = None
             self.active = False
-            self.selected_protocols = []
+            self.selected_npn_protocols = []
+            self.selected_alpn_protocols = []
             self.shared_ciphers = []
             self.conn_errors = []
             threading.Thread.__init__(self)
@@ -2120,11 +2125,13 @@ else:
                     'compression': s.compression(),
                     'cipher': s.cipher(),
                     'peercert': s.getpeercert(),
+                    'client_alpn_protocol': s.selected_alpn_protocol(),
                     'client_npn_protocol': s.selected_npn_protocol(),
                     'version': s.version(),
                 })
                 s.close()
-            stats['server_npn_protocols'] = server.selected_protocols
+            stats['server_alpn_protocols'] = server.selected_alpn_protocols
+            stats['server_npn_protocols'] = server.selected_npn_protocols
             stats['server_shared_ciphers'] = server.shared_ciphers
         return stats
 
@@ -3022,6 +3029,55 @@ else:
             if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts:
                 self.fail("Non-DH cipher: " + cipher[0])
 
+        def test_selected_alpn_protocol(self):
+            # selected_alpn_protocol() is None unless ALPN is used.
+            context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+            context.load_cert_chain(CERTFILE)
+            stats = server_params_test(context, context,
+                                       chatty=True, connectionchatty=True)
+            self.assertIs(stats['client_alpn_protocol'], None)
+
+        @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support required")
+        def test_selected_alpn_protocol_if_server_uses_alpn(self):
+            # selected_alpn_protocol() is None unless ALPN is used by the client.
+            client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+            client_context.load_verify_locations(CERTFILE)
+            server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+            server_context.load_cert_chain(CERTFILE)
+            server_context.set_alpn_protocols(['foo', 'bar'])
+            stats = server_params_test(client_context, server_context,
+                                       chatty=True, connectionchatty=True)
+            self.assertIs(stats['client_alpn_protocol'], None)
+
+        @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support needed for this test")
+        def test_alpn_protocols(self):
+            server_protocols = ['foo', 'bar', 'milkshake']
+            protocol_tests = [
+                (['foo', 'bar'], 'foo'),
+                (['bar', 'foo'], 'bar'),
+                (['milkshake'], 'milkshake'),
+                (['http/3.0', 'http/4.0'], 'foo')
+            ]
+            for client_protocols, expected in protocol_tests:
+                server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+                server_context.load_cert_chain(CERTFILE)
+                server_context.set_alpn_protocols(server_protocols)
+                client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+                client_context.load_cert_chain(CERTFILE)
+                client_context.set_alpn_protocols(client_protocols)
+                stats = server_params_test(client_context, server_context,
+                                           chatty=True, connectionchatty=True)
+
+                msg = "failed trying %s (s) and %s (c).\n" \
+                      "was expecting %s, but got %%s from the %%s" \
+                          % (str(server_protocols), str(client_protocols),
+                             str(expected))
+                client_result = stats['client_alpn_protocol']
+                self.assertEqual(client_result, expected, msg % (client_result, "client"))
+                server_result = stats['server_alpn_protocols'][-1] \
+                    if len(stats['server_alpn_protocols']) else 'nothing'
+                self.assertEqual(server_result, expected, msg % (server_result, "server"))
+
         def test_selected_npn_protocol(self):
             # selected_npn_protocol() is None unless NPN is used
             context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
index 5cd2c61e8217ca9070496aa66547261a0527c485..91643f390e410162ca91ef88c4f746b14f8c52a7 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -203,6 +203,9 @@ Core and Builtins
 Library
 -------
 
+- Issue #20188: Support Application-Layer Protocol Negotiation (ALPN) in the ssl
+  module.
+
 - Issue #23133: Pickling of ipaddress objects now produces more compact and
   portable representation.
 
index 596966323e38c2b959348c60e7415f9606357278..2e19439366d665474c4c58ad69543129a465e16e 100644 (file)
@@ -109,6 +109,11 @@ struct py_ssl_library_code {
 # define HAVE_SNI 0
 #endif
 
+/* ALPN added in OpenSSL 1.0.2 */
+#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
+# define HAVE_ALPN
+#endif
+
 enum py_ssl_error {
     /* these mirror ssl.h */
     PY_SSL_ERROR_NONE,
@@ -180,9 +185,13 @@ typedef struct {
     PyObject_HEAD
     SSL_CTX *ctx;
 #ifdef OPENSSL_NPN_NEGOTIATED
-    char *npn_protocols;
+    unsigned char *npn_protocols;
     int npn_protocols_len;
 #endif
+#ifdef HAVE_ALPN
+    unsigned char *alpn_protocols;
+    int alpn_protocols_len;
+#endif
 #ifndef OPENSSL_NO_TLSEXT
     PyObject *set_hostname;
 #endif
@@ -1460,7 +1469,20 @@ static PyObject *PySSL_selected_npn_protocol(PySSLSocket *self) {
 
     if (out == NULL)
         Py_RETURN_NONE;
-    return PyUnicode_FromStringAndSize((char *) out, outlen);
+    return PyUnicode_FromStringAndSize((char *)out, outlen);
+}
+#endif
+
+#ifdef HAVE_ALPN
+static PyObject *PySSL_selected_alpn_protocol(PySSLSocket *self) {
+    const unsigned char *out;
+    unsigned int outlen;
+
+    SSL_get0_alpn_selected(self->ssl, &out, &outlen);
+
+    if (out == NULL)
+        Py_RETURN_NONE;
+    return PyUnicode_FromStringAndSize((char *)out, outlen);
 }
 #endif
 
@@ -2053,6 +2075,9 @@ static PyMethodDef PySSLMethods[] = {
     {"version", (PyCFunction)PySSL_version, METH_NOARGS},
 #ifdef OPENSSL_NPN_NEGOTIATED
     {"selected_npn_protocol", (PyCFunction)PySSL_selected_npn_protocol, METH_NOARGS},
+#endif
+#ifdef HAVE_ALPN
+    {"selected_alpn_protocol", (PyCFunction)PySSL_selected_alpn_protocol, METH_NOARGS},
 #endif
     {"compression", (PyCFunction)PySSL_compression, METH_NOARGS},
     {"shutdown", (PyCFunction)PySSL_SSLshutdown, METH_NOARGS,
@@ -2159,6 +2184,9 @@ context_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
 #ifdef OPENSSL_NPN_NEGOTIATED
     self->npn_protocols = NULL;
 #endif
+#ifdef HAVE_ALPN
+    self->alpn_protocols = NULL;
+#endif
 #ifndef OPENSSL_NO_TLSEXT
     self->set_hostname = NULL;
 #endif
@@ -2218,7 +2246,10 @@ context_dealloc(PySSLContext *self)
     context_clear(self);
     SSL_CTX_free(self->ctx);
 #ifdef OPENSSL_NPN_NEGOTIATED
-    PyMem_Free(self->npn_protocols);
+    PyMem_FREE(self->npn_protocols);
+#endif
+#ifdef HAVE_ALPN
+    PyMem_FREE(self->alpn_protocols);
 #endif
     Py_TYPE(self)->tp_free(self);
 }
@@ -2244,6 +2275,23 @@ set_ciphers(PySSLContext *self, PyObject *args)
     Py_RETURN_NONE;
 }
 
+static int
+do_protocol_selection(unsigned char **out, unsigned char *outlen,
+                      const unsigned char *remote_protocols, unsigned int remote_protocols_len,
+                      unsigned char *our_protocols, unsigned int our_protocols_len)
+{
+    if (our_protocols == NULL) {
+        our_protocols = (unsigned char*)"";
+        our_protocols_len = 0;
+    }
+
+    SSL_select_next_proto(out, outlen,
+                          remote_protocols, remote_protocols_len,
+                          our_protocols, our_protocols_len);
+
+    return SSL_TLSEXT_ERR_OK;
+}
+
 #ifdef OPENSSL_NPN_NEGOTIATED
 /* this callback gets passed to SSL_CTX_set_next_protos_advertise_cb */
 static int
@@ -2254,10 +2302,10 @@ _advertiseNPN_cb(SSL *s,
     PySSLContext *ssl_ctx = (PySSLContext *) args;
 
     if (ssl_ctx->npn_protocols == NULL) {
-        *data = (unsigned char *) "";
+        *data = (unsigned char *)"";
         *len = 0;
     } else {
-        *data = (unsigned char *) ssl_ctx->npn_protocols;
+        *data = ssl_ctx->npn_protocols;
         *len = ssl_ctx->npn_protocols_len;
     }
 
@@ -2270,23 +2318,9 @@ _selectNPN_cb(SSL *s,
               const unsigned char *server, unsigned int server_len,
               void *args)
 {
-    PySSLContext *ssl_ctx = (PySSLContext *) args;
-
-    unsigned char *client = (unsigned char *) ssl_ctx->npn_protocols;
-    int client_len;
-
-    if (client == NULL) {
-        client = (unsigned char *) "";
-        client_len = 0;
-    } else {
-        client_len = ssl_ctx->npn_protocols_len;
-    }
-
-    SSL_select_next_proto(out, outlen,
-                          server, server_len,
-                          client, client_len);
-
-    return SSL_TLSEXT_ERR_OK;
+    PySSLContext *ctx = (PySSLContext *)args;
+    return do_protocol_selection(out, outlen, server, server_len,
+                                 ctx->npn_protocols, ctx->npn_protocols_len);
 }
 #endif
 
@@ -2329,6 +2363,50 @@ _set_npn_protocols(PySSLContext *self, PyObject *args)
 #endif
 }
 
+#ifdef HAVE_ALPN
+static int
+_selectALPN_cb(SSL *s,
+              const unsigned char **out, unsigned char *outlen,
+              const unsigned char *client_protocols, unsigned int client_protocols_len,
+              void *args)
+{
+    PySSLContext *ctx = (PySSLContext *)args;
+    return do_protocol_selection((unsigned char **)out, outlen,
+                                 client_protocols, client_protocols_len,
+                                 ctx->alpn_protocols, ctx->alpn_protocols_len);
+}
+#endif
+
+static PyObject *
+_set_alpn_protocols(PySSLContext *self, PyObject *args)
+{
+#ifdef HAVE_ALPN
+    Py_buffer protos;
+
+    if (!PyArg_ParseTuple(args, "y*:set_npn_protocols", &protos))
+        return NULL;
+
+    PyMem_FREE(self->alpn_protocols);
+    self->alpn_protocols = PyMem_Malloc(protos.len);
+    if (!self->alpn_protocols)
+        return PyErr_NoMemory();
+    memcpy(self->alpn_protocols, protos.buf, protos.len);
+    self->alpn_protocols_len = protos.len;
+    PyBuffer_Release(&protos);
+
+    if (SSL_CTX_set_alpn_protos(self->ctx, self->alpn_protocols, self->alpn_protocols_len))
+        return PyErr_NoMemory();
+    SSL_CTX_set_alpn_select_cb(self->ctx, _selectALPN_cb, self);
+
+    PyBuffer_Release(&protos);
+    Py_RETURN_NONE;
+#else
+    PyErr_SetString(PyExc_NotImplementedError,
+                    "The ALPN extension requires OpenSSL 1.0.2 or later.");
+    return NULL;
+#endif
+}
+
 static PyObject *
 get_verify_mode(PySSLContext *self, void *c)
 {
@@ -3307,6 +3385,8 @@ static struct PyMethodDef context_methods[] = {
                   METH_VARARGS | METH_KEYWORDS, NULL},
     {"set_ciphers", (PyCFunction) set_ciphers,
                     METH_VARARGS, NULL},
+    {"_set_alpn_protocols", (PyCFunction) _set_alpn_protocols,
+                           METH_VARARGS, NULL},
     {"_set_npn_protocols", (PyCFunction) _set_npn_protocols,
                            METH_VARARGS, NULL},
     {"load_cert_chain", (PyCFunction) load_cert_chain,
@@ -4502,6 +4582,14 @@ PyInit__ssl(void)
     Py_INCREF(r);
     PyModule_AddObject(m, "HAS_NPN", r);
 
+#ifdef HAVE_ALPN
+    r = Py_True;
+#else
+    r = Py_False;
+#endif
+    Py_INCREF(r);
+    PyModule_AddObject(m, "HAS_ALPN", r);
+
     /* Mappings for error codes */
     err_codes_to_names = PyDict_New();
     err_names_to_codes = PyDict_New();