]> granicus.if.org Git - python/commitdiff
OpenSSL support. This is based on patches for a version of SSLeay by
authorGuido van Rossum <guido@python.org>
Tue, 7 Dec 1999 21:37:17 +0000 (21:37 +0000)
committerGuido van Rossum <guido@python.org>
Tue, 7 Dec 1999 21:37:17 +0000 (21:37 +0000)
Brian E Gallew, which were improved and adapted to OpenSSL 0.9.4 by
Laszlo Kovacs of HP.  Both have kindly given permission to include
the patches in the Python distribution.  Final formatting by GvR.

Lib/httplib.py
Lib/urllib.py
Modules/socketmodule.c

index 05289b3da16b99767164b61bb71a7fa1c706b8e3..fade6948f55ffd2eeb6ffd8dffdb2a6d675c780d 100644 (file)
@@ -28,17 +28,48 @@ second request to the same server, you create a new HTTP object.
 connection for each request.)
 """
 
+import os
 import socket
 import string
 import mimetools
 
+try:
+    from cStringIO import StringIO
+except:
+    from StringIO import StringIO
+
 HTTP_VERSION = 'HTTP/1.0'
 HTTP_PORT = 80
+HTTPS_PORT = 443
+
+class FakeSocket:
+    def __init__(self, sock, ssl):
+       self.__sock = sock
+       self.__ssl = ssl
+       return
+
+    def makefile(self, mode):          # hopefully, never have to write
+       msgbuf = ""
+       while 1:
+           try:
+               msgbuf = msgbuf + self.__ssl.read()
+           except socket.sslerror, msg:
+               break
+       return StringIO(msgbuf)
+
+    def send(self, stuff, flags = 0):
+       return self.__ssl.write(stuff)
+
+    def recv(self, len = 1024, flags = 0):
+       return self.__ssl.read(len)
+
+    def __getattr__(self, attr):
+       return getattr(self.__sock, attr)
 
 class HTTP:
     """This class manages a connection to an HTTP server."""
-    
-    def __init__(self, host = '', port = 0):
+
+    def __init__(self, host = '', port = 0, **x509):
         """Initialize a new instance.
 
         If specified, `host' is the name of the remote host to which
@@ -46,10 +77,12 @@ class HTTP:
         to connect.  By default, httplib.HTTP_PORT is used.
 
         """
+        self.key_file = x509.get('key_file')
+        self.cert_file = x509.get('cert_file')
         self.debuglevel = 0
         self.file = None
         if host: self.connect(host, port)
-    
+
     def set_debuglevel(self, debuglevel):
         """Set the debug output level.
 
@@ -58,10 +91,10 @@ class HTTP:
 
         """
         self.debuglevel = debuglevel
-    
+
     def connect(self, host, port = 0):
         """Connect to a host on a given port.
-        
+
         Note:  This method is automatically invoked by __init__,
         if a host is specified during instantiation.
 
@@ -77,12 +110,12 @@ class HTTP:
         self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         if self.debuglevel > 0: print 'connect:', (host, port)
         self.sock.connect(host, port)
-    
+
     def send(self, str):
         """Send `str' to the server."""
         if self.debuglevel > 0: print 'send:', `str`
         self.sock.send(str)
-    
+
     def putrequest(self, request, selector):
         """Send a request to the server.
 
@@ -94,7 +127,7 @@ class HTTP:
         if not selector: selector = '/'
         str = '%s %s %s\r\n' % (request, selector, HTTP_VERSION)
         self.send(str)
-    
+
     def putheader(self, header, *args):
         """Send a request header line to the server.
 
@@ -103,14 +136,14 @@ class HTTP:
         """
         str = '%s: %s\r\n' % (header, string.joinfields(args,'\r\n\t'))
         self.send(str)
-    
+
     def endheaders(self):
         """Indicate that the last header line has been sent to the server."""
         self.send('\r\n')
-    
+
     def getreply(self):
         """Get a reply from the server.
-        
+
         Returns a tuple consisting of:
         - server response code (e.g. '200' if all goes well)
         - server response string corresponding to response code
@@ -136,7 +169,7 @@ class HTTP:
         errmsg = string.strip(msg)
         self.headers = mimetools.Message(self.file, 0)
         return errcode, errmsg, self.headers
-    
+
     def getfile(self):
         """Get a file object from which to receive data from the HTTP server.
 
@@ -145,7 +178,7 @@ class HTTP:
 
         """
         return self.file
-    
+
     def close(self):
         """Close the connection to the HTTP server."""
         if self.file:
@@ -155,6 +188,31 @@ class HTTP:
             self.sock.close()
         self.sock = None
 
+if hasattr(socket, "ssl"):
+    class HTTPS(HTTP):
+        """This class allows communication via SSL."""
+
+        def connect(self, host, port = 0):
+            """Connect to a host on a given port.
+
+            Note:  This method is automatically invoked by __init__,
+            if a host is specified during instantiation.
+
+            """
+            if not port:
+                i = string.find(host, ':')
+                if i >= 0:
+                    host, port = host[:i], host[i+1:]
+                    try: port = string.atoi(port)
+                    except string.atoi_error:
+                        raise socket.error, "nonnumeric port"
+            if not port: port = HTTPS_PORT
+            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+            if self.debuglevel > 0: print 'connect:', (host, port)
+            sock.connect(host, port)
+            ssl = socket.ssl(sock, self.key_file, self.cert_file)
+            self.sock = FakeSocket(sock, ssl)
+
 
 def test():
     """Test this module.
@@ -170,6 +228,7 @@ def test():
     dl = 0
     for o, a in opts:
         if o == '-d': dl = dl + 1
+    print "testing HTTP..."
     host = 'www.python.org'
     selector = '/'
     if args[0:]: host = args[0]
@@ -187,6 +246,26 @@ def test():
         for header in headers.headers: print string.strip(header)
     print
     print h.getfile().read()
+    if hasattr(socket, "ssl"):
+        print "-"*40
+        print "testing HTTPS..."
+        host = 'synergy.as.cmu.edu'
+        selector = '/~geek/'
+        if args[0:]: host = args[0]
+        if args[1:]: selector = args[1]
+        h = HTTPS()
+        h.set_debuglevel(dl)
+        h.connect(host)
+        h.putrequest('GET', selector)
+        h.endheaders()
+        errcode, errmsg, headers = h.getreply()
+        print 'errcode =', errcode
+        print 'errmsg  =', errmsg
+        print
+        if headers:
+            for header in headers.headers: print string.strip(header)
+        print
+        print h.getfile().read()
 
 
 if __name__ == '__main__':
index 4bd329f2645dcbddb8f69aaa96fe12a8f446ddcc..ffbab2264d6f311d3170e91a843c272739822558 100644 (file)
@@ -27,7 +27,7 @@ import os
 import sys
 
 
-__version__ = '1.11'    # XXX This version is not always updated :-(
+__version__ = '1.12'    # XXX This version is not always updated :-(
 
 MAXFTPCACHE = 10        # Trim the ftp cache beyond this size
 
@@ -81,11 +81,13 @@ class URLopener:
     __tempfiles = None
 
     # Constructor
-    def __init__(self, proxies=None):
+    def __init__(self, proxies=None, **x509):
         if proxies is None:
             proxies = getproxies()
         assert hasattr(proxies, 'has_key'), "proxies must be a mapping"
         self.proxies = proxies
+        self.key_file = x509.get('key_file')
+        self.cert_file = x509.get('cert_file')
         server_version = "Python-urllib/%s" % __version__
         self.addheaders = [('User-agent', server_version)]
         self.__tempfiles = []
@@ -144,6 +146,7 @@ class URLopener:
             host, selector = splithost(proxy)
             url = (host, fullurl) # Signal special case to open_*()
         name = 'open_' + type
+        self.type = type
         if '-' in name:
             # replace - with _
             name = string.join(string.split(name, '-'), '_')
@@ -294,6 +297,42 @@ class URLopener:
         fp.close()
         raise IOError, ('http error', errcode, errmsg, headers)
 
+    # Use HTTPS protocol
+    if hasattr(socket, "ssl"):
+        def open_https(self, url):
+            import httplib
+            if type(url) is type(""):
+                host, selector = splithost(url)
+                user_passwd, host = splituser(host)
+            else:
+                host, selector = url
+                urltype, rest = splittype(selector)
+                if string.lower(urltype) == 'https':
+                    realhost, rest = splithost(rest)
+                    user_passwd, realhost = splituser(realhost)
+                    if user_passwd:
+                        selector = "%s://%s%s" % (urltype, realhost, rest)
+                print "proxy via https:", host, selector
+            if not host: raise IOError, ('https error', 'no host given')
+            if user_passwd:
+                import base64
+                auth = string.strip(base64.encodestring(user_passwd))
+            else:
+                auth = None
+            h = httplib.HTTPS(host, 0,
+                              key_file=self.key_file,
+                              cert_file=self.cert_file)
+            h.putrequest('GET', selector)
+            if auth: h.putheader('Authorization: Basic %s' % auth)
+            for args in self.addheaders: apply(h.putheader, args)
+            h.endheaders()
+            errcode, errmsg, headers = h.getreply()
+            fp = h.getfile()
+            if errcode == 200:
+                return addinfourl(fp, headers, url)
+            else:
+                return self.http_error(url, fp, errcode, errmsg, headers)
+  
     # Use Gopher protocol
     def open_gopher(self, url):
         import gopherlib
@@ -477,7 +516,8 @@ class FancyURLopener(URLopener):
             if match:
                 scheme, realm = match.groups()
                 if string.lower(scheme) == 'basic':
-                    return self.retry_http_basic_auth(url, realm, data)
+                   name = 'retry_' + self.type + '_basic_auth'
+                   return getattr(self,name)(url, realm)
 
     def retry_http_basic_auth(self, url, realm, data):
         host, selector = splithost(url)
@@ -488,6 +528,16 @@ class FancyURLopener(URLopener):
         host = user + ':' + passwd + '@' + host
         newurl = 'http://' + host + selector
         return self.open(newurl, data)
+   
+    def retry_https_basic_auth(self, url, realm):
+            host, selector = splithost(url)
+            i = string.find(host, '@') + 1
+            host = host[i:]
+            user, passwd = self.get_user_passwd(host, realm, i)
+            if not (user or passwd): return None
+            host = user + ':' + passwd + '@' + host
+            newurl = '//' + host + selector
+            return self.open_https(newurl)
 
     def get_user_passwd(self, host, realm, clear_cache = 0):
         key = realm + '@' + string.lower(host)
@@ -630,8 +680,8 @@ class addbase:
         self.fp = fp
         self.read = self.fp.read
         self.readline = self.fp.readline
-        self.readlines = self.fp.readlines
-        self.fileno = self.fp.fileno
+        if hasattr(self.fp, "readlines"): self.readlines = self.fp.readlines
+        if hasattr(self.fp, "fileno"): self.fileno = self.fp.fileno
     def __repr__(self):
         return '<%s at %s whose fp = %s>' % (self.__class__.__name__,
                                              `id(self)`, `self.fp`) 
@@ -1015,6 +1065,8 @@ def test(args=[]):
 ##          'gopher://gopher.micro.umn.edu/1/',
             'http://www.python.org/index.html',
             ]
+        if hasattr(URLopener, "open_https"):
+            args.append('https://synergy.as.cmu.edu/~geek/')
     try:
         for url in args:
             print '-'*10, url, '-'*10
index 4f63434bebf2238f4cb20b3f9e15844ac4e7f449..5e63e53910620e845b8ae16c4e3c013207ba7921 100644 (file)
@@ -31,6 +31,8 @@ PERFORMANCE OF THIS SOFTWARE.
 
 /* Socket module */
 
+/* SSL support based on patches by Brian E Gallew and Laszlo Kovacs */
+
 /*
 This module provides an interface to Berkeley socket IPC.
 
@@ -56,6 +58,7 @@ Module interface:
 - socket.AF_INET, socket.SOCK_STREAM, etc.: constants from <socket.h>
 - socket.inet_aton(IP address) -> 32-bit packed IP representation
 - socket.inet_ntoa(packed IP) -> IP address string
+- socket.ssl(socket, keyfile, certfile) -> new ssl object
 - an Internet socket address is a pair (hostname, port)
   where hostname can be anything recognized by gethostbyname()
   (including the dd.dd.dd.dd notation) and port is in host byte order
@@ -190,6 +193,14 @@ int shutdown( int, int );
 #include <GUSI.h>
 #endif
 
+#ifdef USE_SSL
+#include "rsa.h"
+#include "crypto.h"
+#include "x509.h"
+#include "pem.h"
+#include "ssl.h"
+#include "err.h"
+#endif /* USE_SSL */
 
 /* Here we have some hacks to choose between K&R or ANSI style function
    definitions.  For NT to build this as an extension module (ie, DLL)
@@ -253,6 +264,10 @@ fnname( arg1name, arg2name, arg3name, arg4name )   \
 
 static PyObject *PySocket_Error;
 
+#ifdef USE_SSL
+static PyObject *SSLErrorObject;
+#endif /* USE_SSL */
+
 
 /* Convenience function to raise an error according to errno
    and return a NULL pointer from a function. */
@@ -324,7 +339,30 @@ typedef struct {
        } sock_addr;
 } PySocketSockObject;
 
+#ifdef USE_SSL
+
+typedef struct {
+       PyObject_HEAD
+       PySocketSockObject *Socket;     /* Socket on which we're layered */
+       PyObject        *x_attr;        /* Attributes dictionary */
+       SSL_CTX*        ctx;
+       SSL*            ssl;
+       X509*           server_cert;
+       BIO*            sbio;
+       char            server[256];
+       char            issuer[256];
+
+} SSLObject;
 
+staticforward PyTypeObject SSL_Type;
+staticforward int SSL_setattr(SSLObject *self, char *name, PyObject *v);
+staticforward PyObject *SSL_SSLwrite(SSLObject *self, PyObject *args);
+staticforward PyObject *SSL_SSLread(SSLObject *self, PyObject *args);
+
+#define SSLObject_Check(v)     ((v)->ob_type == &SSL_Type)
+
+#endif /* USE_SSL */
 /* A forward reference to the Socktype type object.
    The Socktype variable contains pointers to various functions,
    some of which call newsockobject(), which uses Socktype, so
@@ -1874,6 +1912,242 @@ BUILD_FUNC_DEF_2(PySocket_inet_ntoa, PyObject *, self, PyObject *, args)
        return PyString_FromString(inet_ntoa(packed_addr));
 }
 
+
+#ifdef USE_SSL
+
+/* This is a C function to be called for new object initialization */
+static SSLObject *
+BUILD_FUNC_DEF_3(newSSLObject,
+                PySocketSockObject *,Sock, char*,key_file, char*,cert_file)
+{
+       SSLObject *self;
+       char *str;
+
+#if 0
+       meth=SSLv23_client_method();
+       meth=SSLv3_client_method();
+       meth=SSLv2_client_method();
+#endif
+
+       self = PyObject_NEW(SSLObject, &SSL_Type); /* Create new object */
+       if (self == NULL){
+               PyErr_SetObject(SSLErrorObject,
+                               PyString_FromString("newSSLObject error"));
+               return NULL;
+       }
+       memset(self->server, NULL, sizeof(char) * 256);
+       memset(self->issuer, NULL, sizeof(char) * 256);  
+  
+       self->x_attr = PyDict_New();
+       self->ctx = SSL_CTX_new(SSLv23_method()); /* Set up context */
+       if (self->ctx == NULL) {
+               PyErr_SetObject(SSLErrorObject,
+                               PyString_FromString("SSL_CTX_new error"));
+               PyMem_DEL(self);
+               return NULL;
+       }
+
+       if ( (key_file && !cert_file) || (!key_file && cert_file) )
+       {
+               PyErr_SetObject(SSLErrorObject,
+                     PyString_FromString(
+                       "Both the key & certificate files must be specified"));
+               PyMem_DEL(self);
+               return NULL;
+       }
+
+       if (key_file && cert_file)
+       {
+               if (SSL_CTX_use_PrivateKey_file(self->ctx, key_file,
+                                               SSL_FILETYPE_PEM) < 1)
+               {
+                       PyErr_SetObject(SSLErrorObject,
+                               PyString_FromString(
+                                 "SSL_CTX_use_PrivateKey_file error"));
+                       PyMem_DEL(self);
+                       return NULL;
+               }
+
+               if (SSL_CTX_use_certificate_chain_file(self->ctx,
+                                                      cert_file) < 1)
+               {
+                       PyErr_SetObject(SSLErrorObject,
+                               PyString_FromString(
+                                 "SSL_CTX_use_certificate_chain_file error"));
+                       PyMem_DEL(self);
+                       return NULL;
+               }
+       }
+
+       SSL_CTX_set_verify(self->ctx,
+                          SSL_VERIFY_NONE, NULL); /* set verify lvl */
+       self->ssl = SSL_new(self->ctx); /* New ssl struct */
+       SSL_set_fd(self->ssl, Sock->sock_fd);   /* Set the socket for SSL */
+       SSL_set_connect_state(self->ssl);
+
+       if ((SSL_connect(self->ssl)) == -1) {
+               /* Actually negotiate SSL connection */
+               PyErr_SetObject(SSLErrorObject,
+                               PyString_FromString("SSL_connect error"));
+               PyMem_DEL(self);
+               return NULL;
+       }
+       self->ssl->debug = 1;
+
+       if ((self->server_cert = SSL_get_peer_certificate(self->ssl))) {
+               X509_NAME_oneline(X509_get_subject_name(self->server_cert),
+                                 self->server, 256);
+               X509_NAME_oneline(X509_get_issuer_name(self->server_cert),
+                                 self->issuer, 256);
+       }
+       self->x_attr = NULL;
+       self->Socket = Sock;
+       Py_INCREF(self->Socket);
+       return self;
+}
+
+/* This is the Python function called for new object initialization */
+static PyObject *
+BUILD_FUNC_DEF_2(PySocket_ssl, PyObject *, self, PyObject *, args)
+{
+       SSLObject *rv;
+       PySocketSockObject *Sock;
+       char *key_file;
+       char *cert_file;
+  
+       if (!PyArg_ParseTuple(args, "O!zz",
+                             &PySocketSock_Type, (PyObject*)&Sock,
+                             &key_file, &cert_file) )
+               return NULL;
+  
+       rv = newSSLObject(Sock, key_file, cert_file);
+       if ( rv == NULL )
+               return NULL;
+       return (PyObject *)rv;
+}
+
+static char ssl_doc[] =
+"ssl(socket, keyfile, certfile) -> sslobject";
+
+static PyObject *
+BUILD_FUNC_DEF_2(SSL_server, SSLObject *, self, PyObject *, args)
+{
+       return PyString_FromString(self->server);
+}
+
+static PyObject *
+BUILD_FUNC_DEF_2(SSL_issuer, SSLObject *, self, PyObject *, args)
+{
+       return PyString_FromString(self->issuer);
+}
+
+
+/* SSL object methods */
+
+static PyMethodDef SSLMethods[] = {
+       { "write", (PyCFunction)SSL_SSLwrite, 1 },
+       { "read", (PyCFunction)SSL_SSLread, 1 },
+       { "server", (PyCFunction)SSL_server, 1 },
+       { "issuer", (PyCFunction)SSL_issuer, 1 },
+       { NULL, NULL}
+};
+
+static void SSL_dealloc(SSLObject *self)
+{
+       if (self->server_cert)  /* Possible not to have one? */
+               X509_free (self->server_cert);
+       SSL_CTX_free(self->ctx);
+       SSL_free(self->ssl);
+       Py_XDECREF(self->x_attr);
+       Py_XDECREF(self->Socket);
+       PyMem_DEL(self);
+}
+
+static PyObject *SSL_getattr(SSLObject *self, char *name)
+{
+       return Py_FindMethod(SSLMethods, (PyObject *)self, name);
+}
+
+staticforward PyTypeObject SSL_Type = {
+       PyObject_HEAD_INIT(&PyType_Type)
+       0,                              /*ob_size*/
+       "SSL",                  /*tp_name*/
+       sizeof(SSLObject),              /*tp_basicsize*/
+       0,                              /*tp_itemsize*/
+       /* methods */
+       (destructor)SSL_dealloc,        /*tp_dealloc*/
+       0,                              /*tp_print*/
+       (getattrfunc)SSL_getattr,       /*tp_getattr*/
+       0,                              /*tp_setattr*/
+       0,                              /*tp_compare*/
+       0,                              /*tp_repr*/
+       0,                              /*tp_as_number*/
+       0,                              /*tp_as_sequence*/
+       0,                              /*tp_as_mapping*/
+       0,                              /*tp_hash*/
+};
+
+
+
+static PyObject *SSL_SSLwrite(SSLObject *self, PyObject *args)
+{
+       char *data;
+       int len = 0;
+  
+       if (!PyArg_ParseTuple(args, "s|i", &data, &len))
+               return NULL;
+  
+       if (!len)
+               len = strlen(data);
+  
+       len = SSL_write(self->ssl, data, len);
+       return PyInt_FromLong((long)len);
+}
+
+static PyObject *SSL_SSLread(SSLObject *self, PyObject *args)
+{
+       PyObject *buf;
+       int count = 0;
+       int len = 1024;
+       int res;
+  
+       PyArg_ParseTuple(args, "|i", &len);
+  
+       if (!(buf = PyString_FromStringAndSize((char *) 0, len)))
+               return NULL;    /* Error object should already be set */
+  
+       count = SSL_read(self->ssl, PyString_AsString(buf), len);
+       res = SSL_get_error(self->ssl, count);
+
+       switch (res) {
+       case 0:                 /* Good return value! */
+               break;
+       case 6:
+               PyErr_SetString(SSLErrorObject, "EOF");
+               Py_DECREF(buf);
+               return NULL;
+               break;
+       case 5:
+       default:
+               return PyErr_SetFromErrno(SSLErrorObject);
+               break;
+       }
+  
+       fflush(stderr);
+         
+       if (count < 0) {
+               Py_DECREF(buf);
+               return PyErr_SetFromErrno(SSLErrorObject);
+       }
+  
+       if (count != len && _PyString_Resize(&buf, count) < 0)
+               return NULL;
+       return buf;
+}
+
+#endif /* USE_SSL */
+
+
 /* List of functions exported by this module. */
 
 static PyMethodDef PySocket_methods[] = {
@@ -1893,6 +2167,9 @@ static PyMethodDef PySocket_methods[] = {
        {"htonl",               PySocket_htonl, 0, htonl_doc},
        {"inet_aton",           PySocket_inet_aton, 0, inet_aton_doc},
        {"inet_ntoa",           PySocket_inet_ntoa, 0, inet_ntoa_doc},
+#ifdef USE_SSL
+       {"ssl",                 PySocket_ssl, 1, ssl_doc},
+#endif /* USE_SSL */
        {NULL,                  NULL}            /* Sentinel */
 };
 
@@ -1981,7 +2258,6 @@ OS2init()
 
 #endif /* PYOS_OS2 */
 
-
 /* Initialize this module.
  *   This is called when the first 'import socket' is done,
  *   via a table in config.c, if config.c is compiled with USE_SOCKET
@@ -2019,6 +2295,7 @@ ntohs(), ntohl() -- convert 16, 32 bit int from network to host byte order\n\
 htons(), htonl() -- convert 16, 32 bit int from host to network byte order\n\
 inet_aton() -- convert IP addr string (123.45.67.89) to 32-bit packed format\n\
 inet_ntoa() -- convert 32-bit packed format IP to string (123.45.67.89)\n\
+ssl() -- secure socket layer support (only available if configured)\n\
 \n\
 (*) not available on all platforms!)\n\
 \n\
@@ -2092,6 +2369,18 @@ initsocket()
        PySocket_Error = PyErr_NewException("socket.error", NULL, NULL);
        if (PySocket_Error == NULL)
                return;
+#ifdef USE_SSL
+       SSL_load_error_strings();
+       SSLeay_add_ssl_algorithms();
+       SSLErrorObject = PyErr_NewException("socket.sslerror", NULL, NULL);
+       if (SSLErrorObject == NULL)
+               return;
+       PyDict_SetItemString(d, "sslerror", SSLErrorObject);
+       Py_INCREF(&SSL_Type);
+       if (PyDict_SetItemString(d, "SSLType",
+                                (PyObject *)&SSL_Type) != 0)
+               return;
+#endif /* USE_SSL */
        PyDict_SetItemString(d, "error", PySocket_Error);
        PySocketSock_Type.ob_type = &PyType_Type;
        PySocketSock_Type.tp_doc = sockettype_doc;