]> granicus.if.org Git - python/commitdiff
incorporate fixes from issue 3162; SSL doc patch
authorBill Janssen <janssen@parc.com>
Mon, 8 Sep 2008 16:37:24 +0000 (16:37 +0000)
committerBill Janssen <janssen@parc.com>
Mon, 8 Sep 2008 16:37:24 +0000 (16:37 +0000)
Doc/library/ssl.rst
Lib/ssl.py
Lib/test/test_ssl.py

index df161fd3c032856c6528da527bd7b04e1b2f0c9a..4fcc2f40df6fd5452fce9b0d5f66cf939fed530e 100644 (file)
@@ -327,9 +327,10 @@ SSLSocket Objects
    Performs the SSL shutdown handshake, which removes the TLS layer
    from the underlying socket, and returns the underlying socket
    object.  This can be used to go from encrypted operation over a
-   connection to unencrypted.  The returned socket should always be
+   connection to unencrypted.  The socket instance returned should always be
    used for further communication with the other side of the
-   connection, rather than the original socket
+   connection, rather than the original socket instance (which may
+   not function properly after the unwrap).
 
 .. index:: single: certificates
 
index c9ee71a08505ba700546a7393ef53b517771aebd..8a799bcc56743ba6d80209189e374812f55ff15b 100644 (file)
@@ -91,10 +91,12 @@ class SSLSocket (socket):
                  suppress_ragged_eofs=True):
         socket.__init__(self, _sock=sock._sock)
         # the initializer for socket trashes the methods (tsk, tsk), so...
-        self.send = lambda x, flags=0: SSLSocket.send(self, x, flags)
-        self.recv = lambda x, flags=0: SSLSocket.recv(self, x, flags)
+        self.send = lambda data, flags=0: SSLSocket.send(self, data, flags)
         self.sendto = lambda data, addr, flags=0: SSLSocket.sendto(self, data, addr, flags)
-        self.recvfrom = lambda addr, buflen, flags: SSLSocket.recvfrom(self, addr, buflen, flags)
+        self.recv = lambda buflen=1024, flags=0: SSLSocket.recv(self, buflen, flags)
+        self.recvfrom = lambda addr, buflen=1024, flags=0: SSLSocket.recvfrom(self, addr, buflen, flags)
+        self.recv_into = lambda buffer, nbytes=None, flags=0: SSLSocket.recv_into(self, buffer, nbytes, flags)
+        self.recvfrom_into = lambda buffer, nbytes=None, flags=0: SSLSocket.recvfrom_into(self, buffer, nbytes, flags)
 
         if certfile and not keyfile:
             keyfile = certfile
@@ -221,6 +223,30 @@ class SSLSocket (socket):
         else:
             return socket.recv(self, buflen, flags)
 
+    def recv_into (self, buffer, nbytes=None, flags=0):
+        if buffer and (nbytes is None):
+            nbytes = len(buffer)
+        elif nbytes is None:
+            nbytes = 1024
+        if self._sslobj:
+            if flags != 0:
+                raise ValueError(
+                  "non-zero flags not allowed in calls to recv_into() on %s" %
+                  self.__class__)
+            while True:
+                try:
+                    tmp_buffer = self.read(nbytes)
+                    v = len(tmp_buffer)
+                    buffer[:v] = tmp_buffer
+                    return v
+                except SSLError as x:
+                    if x.args[0] == SSL_ERROR_WANT_READ:
+                        continue
+                    else:
+                        raise x
+        else:
+            return socket.recv_into(self, buffer, nbytes, flags)
+
     def recvfrom (self, addr, buflen=1024, flags=0):
         if self._sslobj:
             raise ValueError("recvfrom not allowed on instances of %s" %
@@ -228,6 +254,13 @@ class SSLSocket (socket):
         else:
             return socket.recvfrom(self, addr, buflen, flags)
 
+    def recvfrom_into (self, buffer, nbytes=None, flags=0):
+        if self._sslobj:
+            raise ValueError("recvfrom_into not allowed on instances of %s" %
+                             self.__class__)
+        else:
+            return socket.recvfrom_into(self, buffer, nbytes, flags)
+
     def pending (self):
         if self._sslobj:
             return self._sslobj.pending()
@@ -295,8 +328,9 @@ class SSLSocket (socket):
 
     def makefile(self, mode='r', bufsize=-1):
 
-        """Ouch.  Need to make and return a file-like object that
-        works with the SSL connection."""
+        """Make and return a file-like object that
+        works with the SSL connection.  Just use the code
+        from the socket module."""
 
         self._makefile_refs += 1
         return _fileobject(self, mode, bufsize)
index 05a9b5787e67f427b8ad1346703c2f47508d269c..98681f4fe13b7169aba755f47f8391da1933d2e3 100644 (file)
@@ -1030,6 +1030,127 @@ else:
                 server.join()
 
 
+        def testAllRecvAndSendMethods(self):
+
+            if test_support.verbose:
+                sys.stdout.write("\n")
+
+            server = ThreadedEchoServer(CERTFILE,
+                                        certreqs=ssl.CERT_NONE,
+                                        ssl_version=ssl.PROTOCOL_TLSv1,
+                                        cacerts=CERTFILE,
+                                        chatty=True,
+                                        connectionchatty=False)
+            flag = threading.Event()
+            server.start(flag)
+            # wait for it to start
+            flag.wait()
+            # try to connect
+            try:
+                s = ssl.wrap_socket(socket.socket(),
+                                    server_side=False,
+                                    certfile=CERTFILE,
+                                    ca_certs=CERTFILE,
+                                    cert_reqs=ssl.CERT_NONE,
+                                    ssl_version=ssl.PROTOCOL_TLSv1)
+                s.connect((HOST, server.port))
+            except ssl.SSLError as x:
+                raise support.TestFailed("Unexpected SSL error:  " + str(x))
+            except Exception as x:
+                raise support.TestFailed("Unexpected exception:  " + str(x))
+            else:
+                # helper methods for standardising recv* method signatures
+                def _recv_into():
+                    b = bytearray("\0"*100)
+                    count = s.recv_into(b)
+                    return b[:count]
+
+                def _recvfrom_into():
+                    b = bytearray("\0"*100)
+                    count, addr = s.recvfrom_into(b)
+                    return b[:count]
+
+                # (name, method, whether to expect success, *args)
+                send_methods = [
+                    ('send', s.send, True, []),
+                    ('sendto', s.sendto, False, ["some.address"]),
+                    ('sendall', s.sendall, True, []),
+                ]
+                recv_methods = [
+                    ('recv', s.recv, True, []),
+                    ('recvfrom', s.recvfrom, False, ["some.address"]),
+                    ('recv_into', _recv_into, True, []),
+                    ('recvfrom_into', _recvfrom_into, False, []),
+                ]
+                data_prefix = u"PREFIX_"
+
+                for meth_name, send_meth, expect_success, args in send_methods:
+                    indata = data_prefix + meth_name
+                    try:
+                        send_meth(indata.encode('ASCII', 'strict'), *args)
+                        outdata = s.read()
+                        outdata = outdata.decode('ASCII', 'strict')
+                        if outdata != indata.lower():
+                            raise support.TestFailed(
+                                "While sending with <<%s>> bad data "
+                                "<<%r>> (%d) received; "
+                                "expected <<%r>> (%d)\n" % (
+                                    meth_name, outdata[:20], len(outdata),
+                                    indata[:20], len(indata)
+                                )
+                            )
+                    except ValueError as e:
+                        if expect_success:
+                            raise support.TestFailed(
+                                "Failed to send with method <<%s>>; "
+                                "expected to succeed.\n" % (meth_name,)
+                            )
+                        if not str(e).startswith(meth_name):
+                            raise support.TestFailed(
+                                "Method <<%s>> failed with unexpected "
+                                "exception message: %s\n" % (
+                                    meth_name, e
+                                )
+                            )
+
+                for meth_name, recv_meth, expect_success, args in recv_methods:
+                    indata = data_prefix + meth_name
+                    try:
+                        s.send(indata.encode('ASCII', 'strict'))
+                        outdata = recv_meth(*args)
+                        outdata = outdata.decode('ASCII', 'strict')
+                        if outdata != indata.lower():
+                            raise support.TestFailed(
+                                "While receiving with <<%s>> bad data "
+                                "<<%r>> (%d) received; "
+                                "expected <<%r>> (%d)\n" % (
+                                    meth_name, outdata[:20], len(outdata),
+                                    indata[:20], len(indata)
+                                )
+                            )
+                    except ValueError as e:
+                        if expect_success:
+                            raise support.TestFailed(
+                                "Failed to receive with method <<%s>>; "
+                                "expected to succeed.\n" % (meth_name,)
+                            )
+                        if not str(e).startswith(meth_name):
+                            raise support.TestFailed(
+                                "Method <<%s>> failed with unexpected "
+                                "exception message: %s\n" % (
+                                    meth_name, e
+                                )
+                            )
+                        # consume data
+                        s.read()
+
+                s.write("over\n".encode("ASCII", "strict"))
+                s.close()
+            finally:
+                server.stop()
+                server.join()
+
+
 def test_main(verbose=False):
     if skip_expected:
         raise test_support.TestSkipped("No SSL support")