]> granicus.if.org Git - python/commitdiff
Use context managers in test_ssl to simplify test writing.
authorAntoine Pitrou <solipsis@pitrou.net>
Wed, 21 Dec 2011 15:52:40 +0000 (16:52 +0100)
committerAntoine Pitrou <solipsis@pitrou.net>
Wed, 21 Dec 2011 15:52:40 +0000 (16:52 +0100)
Lib/test/test_ssl.py

index ba1d868ccced8de8ed8551cc6605cfcabf67e4fe..e5addf807778bc6b688db949bb88c7dfcdd016d6 100644 (file)
@@ -532,6 +532,14 @@ else:
             threading.Thread.__init__(self)
             self.daemon = True
 
+        def __enter__(self):
+            self.start(threading.Event())
+            self.flag.wait()
+
+        def __exit__(self, *args):
+            self.stop()
+            self.join()
+
         def start(self, flag=None):
             self.flag = flag
             threading.Thread.start(self)
@@ -638,6 +646,20 @@ else:
         def __str__(self):
             return "<%s %s>" % (self.__class__.__name__, self.server)
 
+        def __enter__(self):
+            self.start(threading.Event())
+            self.flag.wait()
+
+        def __exit__(self, *args):
+            if test_support.verbose:
+                sys.stdout.write(" cleanup: stopping server.\n")
+            self.stop()
+            if test_support.verbose:
+                sys.stdout.write(" cleanup: joining server thread.\n")
+            self.join()
+            if test_support.verbose:
+                sys.stdout.write(" cleanup: successfully joined.\n")
+
         def start(self, flag=None):
             self.flag = flag
             threading.Thread.start(self)
@@ -752,12 +774,7 @@ else:
         server = ThreadedEchoServer(CERTFILE,
                                     certreqs=ssl.CERT_REQUIRED,
                                     cacerts=CERTFILE, chatty=False)
-        flag = threading.Event()
-        server.start(flag)
-        # wait for it to start
-        flag.wait()
-        # try to connect
-        try:
+        with server:
             try:
                 s = ssl.wrap_socket(socket.socket(),
                                     certfile=certfile,
@@ -771,9 +788,6 @@ else:
                     sys.stdout.write("\nsocket.error is %s\n" % x[1])
             else:
                 raise AssertionError("Use of invalid cert should have failed!")
-        finally:
-            server.stop()
-            server.join()
 
     def server_params_test(certfile, protocol, certreqs, cacertsfile,
                            client_certfile, client_protocol=None, indata="FOO\n",
@@ -791,14 +805,10 @@ else:
                                     chatty=chatty,
                                     connectionchatty=connectionchatty,
                                     wrap_accepting_socket=wrap_accepting_socket)
-        flag = threading.Event()
-        server.start(flag)
-        # wait for it to start
-        flag.wait()
-        # try to connect
-        if client_protocol is None:
-            client_protocol = protocol
-        try:
+        with server:
+            # try to connect
+            if client_protocol is None:
+                client_protocol = protocol
             s = ssl.wrap_socket(socket.socket(),
                                 certfile=client_certfile,
                                 ca_certs=cacertsfile,
@@ -826,9 +836,6 @@ else:
                 if test_support.verbose:
                     sys.stdout.write(" client:  closing connection.\n")
             s.close()
-        finally:
-            server.stop()
-            server.join()
 
     def try_protocol_combo(server_protocol,
                            client_protocol,
@@ -930,12 +937,7 @@ else:
                                         ssl_version=ssl.PROTOCOL_SSLv23,
                                         cacerts=CERTFILE,
                                         chatty=False)
-            flag = threading.Event()
-            server.start(flag)
-            # wait for it to start
-            flag.wait()
-            # try to connect
-            try:
+            with server:
                 s = ssl.wrap_socket(socket.socket(),
                                     certfile=CERTFILE,
                                     ca_certs=CERTFILE,
@@ -957,9 +959,6 @@ else:
                         "Missing or invalid 'organizationName' field in certificate subject; "
                         "should be 'Python Software Foundation'.")
                 s.close()
-            finally:
-                server.stop()
-                server.join()
 
         def test_empty_cert(self):
             """Connecting with an empty cert file"""
@@ -1042,13 +1041,8 @@ else:
                                         starttls_server=True,
                                         chatty=True,
                                         connectionchatty=True)
-            flag = threading.Event()
-            server.start(flag)
-            # wait for it to start
-            flag.wait()
-            # try to connect
             wrapped = False
-            try:
+            with server:
                 s = socket.socket()
                 s.setblocking(1)
                 s.connect((HOST, server.port))
@@ -1093,9 +1087,6 @@ else:
                 else:
                     s.send("over\n")
                 s.close()
-            finally:
-                server.stop()
-                server.join()
 
         def test_socketserver(self):
             """Using a SocketServer to create and manage SSL connections."""
@@ -1145,12 +1136,7 @@ else:
             if test_support.verbose:
                 sys.stdout.write("\n")
             server = AsyncoreEchoServer(CERTFILE)
-            flag = threading.Event()
-            server.start(flag)
-            # wait for it to start
-            flag.wait()
-            # try to connect
-            try:
+            with server:
                 s = ssl.wrap_socket(socket.socket())
                 s.connect(('127.0.0.1', server.port))
                 if test_support.verbose:
@@ -1169,10 +1155,6 @@ else:
                 if test_support.verbose:
                     sys.stdout.write(" client:  closing connection.\n")
                 s.close()
-            finally:
-                server.stop()
-                # wait for server thread to end
-                server.join()
 
         def test_recv_send(self):
             """Test recv(), send() and friends."""
@@ -1185,19 +1167,14 @@ else:
                                         cacerts=CERTFILE,
                                         chatty=True,
                                         connectionchatty=False)
-            flag = threading.Event()
-            server.start(flag)
-            # wait for it to start
-            flag.wait()
-            # try to connect
-            s = ssl.wrap_socket(socket.socket(),
-                                server_side=False,
-                                certfile=CERTFILE,
-                                ca_certs=CERTFILE,
-                                cert_reqs=ssl.CERT_NONE,
-                                ssl_version=ssl.PROTOCOL_TLSv1)
-            s.connect((HOST, server.port))
-            try:
+            with server:
+                s = ssl.wrap_socket(socket.socket(),
+                                    server_side=False,
+                                    certfile=CERTFILE,
+                                    ca_certs=CERTFILE,
+                                    cert_reqs=ssl.CERT_NONE,
+                                    ssl_version=ssl.PROTOCOL_TLSv1)
+                s.connect((HOST, server.port))
                 # helper methods for standardising recv* method signatures
                 def _recv_into():
                     b = bytearray("\0"*100)
@@ -1285,9 +1262,6 @@ else:
 
                 s.write("over\n".encode("ASCII", "strict"))
                 s.close()
-            finally:
-                server.stop()
-                server.join()
 
         def test_handshake_timeout(self):
             # Issue #5103: SSL handshake must respect the socket timeout