]> granicus.if.org Git - python/commitdiff
Use context managers in test_ssl to simplify test writing.
authorAntoine Pitrou <solipsis@pitrou.net>
Wed, 21 Dec 2011 15:52:40 +0000 (16:52 +0100)
committerAntoine Pitrou <solipsis@pitrou.net>
Wed, 21 Dec 2011 15:52:40 +0000 (16:52 +0100)
Lib/test/test_ssl.py

index f942d956587f1e9839e1bc9b8c740c87e413e158..e9fbc8afebb9e1fe5e53b13aa382d1a9291319d8 100644 (file)
@@ -881,6 +881,14 @@ else:
             threading.Thread.__init__(self)
             self.daemon = True
 
+        def __enter__(self):
+            self.start(threading.Event())
+            self.flag.wait()
+
+        def __exit__(self, *args):
+            self.stop()
+            self.join()
+
         def start(self, flag=None):
             self.flag = flag
             threading.Thread.start(self)
@@ -993,6 +1001,20 @@ else:
         def __str__(self):
             return "<%s %s>" % (self.__class__.__name__, self.server)
 
+        def __enter__(self):
+            self.start(threading.Event())
+            self.flag.wait()
+
+        def __exit__(self, *args):
+            if support.verbose:
+                sys.stdout.write(" cleanup: stopping server.\n")
+            self.stop()
+            if support.verbose:
+                sys.stdout.write(" cleanup: joining server thread.\n")
+            self.join()
+            if support.verbose:
+                sys.stdout.write(" cleanup: successfully joined.\n")
+
         def start (self, flag=None):
             self.flag = flag
             threading.Thread.start(self)
@@ -1020,12 +1042,7 @@ else:
                                     certreqs=ssl.CERT_REQUIRED,
                                     cacerts=CERTFILE, chatty=False,
                                     connectionchatty=False)
-        flag = threading.Event()
-        server.start(flag)
-        # wait for it to start
-        flag.wait()
-        # try to connect
-        try:
+        with server:
             try:
                 with socket.socket() as sock:
                     s = ssl.wrap_socket(sock,
@@ -1045,9 +1062,6 @@ else:
                     sys.stdout.write("\IOError is %s\n" % str(x))
             else:
                 raise AssertionError("Use of invalid cert should have failed!")
-        finally:
-            server.stop()
-            server.join()
 
     def server_params_test(client_context, server_context, indata=b"FOO\n",
                            chatty=True, connectionchatty=False):
@@ -1058,12 +1072,7 @@ else:
         server = ThreadedEchoServer(context=server_context,
                                     chatty=chatty,
                                     connectionchatty=False)
-        flag = threading.Event()
-        server.start(flag)
-        # wait for it to start
-        flag.wait()
-        # try to connect
-        try:
+        with server:
             s = client_context.wrap_socket(socket.socket())
             s.connect((HOST, server.port))
             for arg in [indata, bytearray(indata), memoryview(indata)]:
@@ -1086,9 +1095,6 @@ else:
                 if support.verbose:
                     sys.stdout.write(" client:  closing connection.\n")
             s.close()
-        finally:
-            server.stop()
-            server.join()
 
     def try_protocol_combo(server_protocol, client_protocol, expect_success,
                            certsreqs=None, server_options=0, client_options=0):
@@ -1157,12 +1163,7 @@ else:
             context.load_verify_locations(CERTFILE)
             context.load_cert_chain(CERTFILE)
             server = ThreadedEchoServer(context=context, chatty=False)
-            flag = threading.Event()
-            server.start(flag)
-            # wait for it to start
-            flag.wait()
-            # try to connect
-            try:
+            with server:
                 s = context.wrap_socket(socket.socket())
                 s.connect((HOST, server.port))
                 cert = s.getpeercert()
@@ -1185,9 +1186,6 @@ else:
                 after = ssl.cert_time_to_seconds(cert['notAfter'])
                 self.assertLess(before, after)
                 s.close()
-            finally:
-                server.stop()
-                server.join()
 
         def test_empty_cert(self):
             """Connecting with an empty cert file"""
@@ -1346,13 +1344,8 @@ else:
                                         starttls_server=True,
                                         chatty=True,
                                         connectionchatty=True)
-            flag = threading.Event()
-            server.start(flag)
-            # wait for it to start
-            flag.wait()
-            # try to connect
             wrapped = False
-            try:
+            with server:
                 s = socket.socket()
                 s.setblocking(1)
                 s.connect((HOST, server.port))
@@ -1399,9 +1392,6 @@ else:
                     conn.close()
                 else:
                     s.close()
-            finally:
-                server.stop()
-                server.join()
 
         def test_socketserver(self):
             """Using a SocketServer to create and manage SSL connections."""
@@ -1437,12 +1427,7 @@ else:
 
             indata = b"FOO\n"
             server = AsyncoreEchoServer(CERTFILE)
-            flag = threading.Event()
-            server.start(flag)
-            # wait for it to start
-            flag.wait()
-            # try to connect
-            try:
+            with server:
                 s = ssl.wrap_socket(socket.socket())
                 s.connect(('127.0.0.1', server.port))
                 if support.verbose:
@@ -1463,15 +1448,6 @@ else:
                 s.close()
                 if support.verbose:
                     sys.stdout.write(" client:  connection closed.\n")
-            finally:
-                if support.verbose:
-                    sys.stdout.write(" cleanup: stopping server.\n")
-                server.stop()
-                if support.verbose:
-                    sys.stdout.write(" cleanup: joining server thread.\n")
-                server.join()
-                if support.verbose:
-                    sys.stdout.write(" cleanup: successfully joined.\n")
 
         def test_recv_send(self):
             """Test recv(), send() and friends."""
@@ -1484,19 +1460,14 @@ else:
                                         cacerts=CERTFILE,
                                         chatty=True,
                                         connectionchatty=False)
-            flag = threading.Event()
-            server.start(flag)
-            # wait for it to start
-            flag.wait()
-            # try to connect
-            s = ssl.wrap_socket(socket.socket(),
-                                server_side=False,
-                                certfile=CERTFILE,
-                                ca_certs=CERTFILE,
-                                cert_reqs=ssl.CERT_NONE,
-                                ssl_version=ssl.PROTOCOL_TLSv1)
-            s.connect((HOST, server.port))
-            try:
+            with server:
+                s = ssl.wrap_socket(socket.socket(),
+                                    server_side=False,
+                                    certfile=CERTFILE,
+                                    ca_certs=CERTFILE,
+                                    cert_reqs=ssl.CERT_NONE,
+                                    ssl_version=ssl.PROTOCOL_TLSv1)
+                s.connect((HOST, server.port))
                 # helper methods for standardising recv* method signatures
                 def _recv_into():
                     b = bytearray(b"\0"*100)
@@ -1581,12 +1552,8 @@ else:
                             )
                         # consume data
                         s.read()
-
                 s.write(b"over\n")
                 s.close()
-            finally:
-                server.stop()
-                server.join()
 
         def test_handshake_timeout(self):
             # Issue #5103: SSL handshake must respect the socket timeout