]> granicus.if.org Git - python/commitdiff
Update test_ssl.py to reflect the new approach for writing network-oriented tests...
authorTrent Nelson <trent.nelson@snakebite.org>
Thu, 10 Apr 2008 20:54:35 +0000 (20:54 +0000)
committerTrent Nelson <trent.nelson@snakebite.org>
Thu, 10 Apr 2008 20:54:35 +0000 (20:54 +0000)
Lib/test/test_ssl.py

index 520f440d2ab7e6830ed800b30e9e8cc43d3b8d93..59bf57dd4d69f5b348bebc3e81650e1e895f2b6c 100644 (file)
@@ -25,11 +25,10 @@ try:
 except ImportError:
     skip_expected = True
 
+HOST = test_support.HOST
 CERTFILE = None
 SVN_PYTHON_ORG_ROOT_CERT = None
 
-TESTPORT = 10025
-
 def handle_error(prefix):
     exc_format = ' '.join(traceback.format_exception(*sys.exc_info()))
     if test_support.verbose:
@@ -299,7 +298,7 @@ else:
                     except:
                         handle_error('')
 
-        def __init__(self, port, certificate, ssl_version=None,
+        def __init__(self, certificate, ssl_version=None,
                      certreqs=None, cacerts=None, expect_bad_connects=False,
                      chatty=True, connectionchatty=False, starttls_server=False):
             if ssl_version is None:
@@ -315,12 +314,8 @@ else:
             self.connectionchatty = connectionchatty
             self.starttls_server = starttls_server
             self.sock = socket.socket()
+            self.port = test_support.bind_port(self.sock)
             self.flag = None
-            if hasattr(socket, 'SO_REUSEADDR'):
-                self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-            if hasattr(socket, 'SO_REUSEPORT'):
-                self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
-            self.sock.bind(('127.0.0.1', port))
             self.active = False
             threading.Thread.__init__(self)
             self.setDaemon(False)
@@ -471,12 +466,13 @@ else:
                                       format%args))
 
 
-        def __init__(self, port, certfile):
+        def __init__(self, certfile):
             self.flag = None
             self.active = False
             self.RootedHTTPRequestHandler.root = os.path.split(CERTFILE)[0]
+            self.port = test_support.find_unused_port()
             self.server = self.HTTPSServer(
-                ('', port), self.RootedHTTPRequestHandler, certfile)
+                (HOST, self.port), self.RootedHTTPRequestHandler, certfile)
             threading.Thread.__init__(self)
             self.setDaemon(True)
 
@@ -557,10 +553,11 @@ else:
             def handle_error(self):
                 raise
 
-        def __init__(self, port, certfile):
+        def __init__(self, certfile):
             self.flag = None
             self.active = False
-            self.server = self.EchoServer(port, certfile)
+            self.port = test_support.find_unused_port()
+            self.server = self.EchoServer(self.port, certfile)
             threading.Thread.__init__(self)
             self.setDaemon(True)
 
@@ -586,7 +583,7 @@ else:
             self.server.close()
 
     def badCertTest (certfile):
-        server = ThreadedEchoServer(TESTPORT, CERTFILE,
+        server = ThreadedEchoServer(CERTFILE,
                                     certreqs=ssl.CERT_REQUIRED,
                                     cacerts=CERTFILE, chatty=False,
                                     connectionchatty=False)
@@ -600,7 +597,7 @@ else:
                 s = ssl.wrap_socket(socket.socket(),
                                     certfile=certfile,
                                     ssl_version=ssl.PROTOCOL_TLSv1)
-                s.connect(('127.0.0.1', TESTPORT))
+                s.connect((HOST, server.port))
             except ssl.SSLError as x:
                 if test_support.verbose:
                     sys.stdout.write("\nSSLError is %s\n" % x)
@@ -616,7 +613,7 @@ else:
                           indata="FOO\n",
                           chatty=False, connectionchatty=False):
 
-        server = ThreadedEchoServer(TESTPORT, certfile,
+        server = ThreadedEchoServer(certfile,
                                     certreqs=certreqs,
                                     ssl_version=protocol,
                                     cacerts=cacertsfile,
@@ -636,7 +633,7 @@ else:
                                 ca_certs=cacertsfile,
                                 cert_reqs=certreqs,
                                 ssl_version=client_protocol)
-            s.connect(('127.0.0.1', TESTPORT))
+            s.connect((HOST, server.port))
         except ssl.SSLError as x:
             raise test_support.TestFailed("Unexpected SSL error:  " + str(x))
         except Exception as x:
@@ -716,7 +713,7 @@ else:
             if test_support.verbose:
                 sys.stdout.write("\n")
             s2 = socket.socket()
-            server = ThreadedEchoServer(TESTPORT, CERTFILE,
+            server = ThreadedEchoServer(CERTFILE,
                                         certreqs=ssl.CERT_NONE,
                                         ssl_version=ssl.PROTOCOL_SSLv23,
                                         cacerts=CERTFILE,
@@ -733,7 +730,7 @@ else:
                                         ca_certs=CERTFILE,
                                         cert_reqs=ssl.CERT_REQUIRED,
                                         ssl_version=ssl.PROTOCOL_SSLv23)
-                    s.connect(('127.0.0.1', TESTPORT))
+                    s.connect((HOST, server.port))
                 except ssl.SSLError as x:
                     raise test_support.TestFailed(
                         "Unexpected SSL error:  " + str(x))
@@ -780,6 +777,7 @@ else:
 
             listener_ready = threading.Event()
             listener_gone = threading.Event()
+            port = test_support.find_unused_port()
 
             # `listener` runs in a thread.  It opens a socket listening on
             # PORT, and sits in an accept() until the main thread connects.
@@ -787,11 +785,7 @@ else:
             # to let the main thread know the socket is gone.
             def listener():
                 s = socket.socket()
-                if hasattr(socket, 'SO_REUSEADDR'):
-                    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-                if hasattr(socket, 'SO_REUSEPORT'):
-                    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
-                s.bind(('127.0.0.1', TESTPORT))
+                s.bind((HOST, port))
                 s.listen(5)
                 listener_ready.set()
                 s.accept()
@@ -801,7 +795,7 @@ else:
             def connector():
                 listener_ready.wait()
                 s = socket.socket()
-                s.connect(('127.0.0.1', TESTPORT))
+                s.connect((HOST, port))
                 listener_gone.wait()
                 try:
                     ssl_sock = ssl.wrap_socket(s)
@@ -873,7 +867,7 @@ else:
 
             msgs = ("msg 1", "MSG 2", "STARTTLS", "MSG 3", "msg 4")
 
-            server = ThreadedEchoServer(TESTPORT, CERTFILE,
+            server = ThreadedEchoServer(CERTFILE,
                                         ssl_version=ssl.PROTOCOL_TLSv1,
                                         starttls_server=True,
                                         chatty=True,
@@ -888,7 +882,7 @@ else:
                 try:
                     s = socket.socket()
                     s.setblocking(1)
-                    s.connect(('127.0.0.1', TESTPORT))
+                    s.connect((HOST, server.port))
                 except Exception as x:
                     raise test_support.TestFailed("Unexpected exception:  " + str(x))
                 else:
@@ -936,7 +930,7 @@ else:
 
         def testSocketServer(self):
 
-            server = OurHTTPSServer(TESTPORT, CERTFILE)
+            server = OurHTTPSServer(CERTFILE)
             flag = threading.Event()
             server.start(flag)
             # wait for it to start
@@ -948,8 +942,8 @@ else:
                 d1 = open(CERTFILE, 'rb').read()
                 d2 = ''
                 # now fetch the same data from the HTTPS server
-                url = 'https://127.0.0.1:%d/%s' % (
-                    TESTPORT, os.path.split(CERTFILE)[1])
+                url = 'https://%s:%d/%s' % (
+                    HOST, server.port, os.path.split(CERTFILE)[1])
                 f = urllib.urlopen(url)
                 dlen = f.info().getheader("content-length")
                 if dlen and (int(dlen) > 0):
@@ -984,7 +978,7 @@ else:
                 sys.stdout.write("\n")
 
             indata="FOO\n"
-            server = AsyncoreEchoServer(TESTPORT, CERTFILE)
+            server = AsyncoreEchoServer(CERTFILE)
             flag = threading.Event()
             server.start(flag)
             # wait for it to start
@@ -992,7 +986,7 @@ else:
             # try to connect
             try:
                 s = ssl.wrap_socket(socket.socket())
-                s.connect(('127.0.0.1', TESTPORT))
+                s.connect((HOST, server.port))
             except ssl.SSLError as x:
                 raise test_support.TestFailed("Unexpected SSL error:  " + str(x))
             except Exception as x:
@@ -1019,30 +1013,11 @@ else:
                 server.stop()
                 server.join()
 
-
-def findtestsocket(start, end):
-    def testbind(i):
-        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-        try:
-            s.bind(("127.0.0.1", i))
-        except:
-            return 0
-        else:
-            return 1
-        finally:
-            s.close()
-
-    for i in range(start, end):
-        if testbind(i) and testbind(i+1):
-            return i
-    return 0
-
-
 def test_main(verbose=False):
     if skip_expected:
         raise test_support.TestSkipped("No SSL support")
 
-    global CERTFILE, TESTPORT, SVN_PYTHON_ORG_ROOT_CERT
+    global CERTFILE, SVN_PYTHON_ORG_ROOT_CERT
     CERTFILE = os.path.join(os.path.dirname(__file__) or os.curdir,
                             "keycert.pem")
     SVN_PYTHON_ORG_ROOT_CERT = os.path.join(
@@ -1053,10 +1028,6 @@ def test_main(verbose=False):
         not os.path.exists(SVN_PYTHON_ORG_ROOT_CERT)):
         raise test_support.TestFailed("Can't read certificate files!")
 
-    TESTPORT = findtestsocket(10025, 12000)
-    if not TESTPORT:
-        raise test_support.TestFailed("Can't find open port to test servers on!")
-
     tests = [BasicTests]
 
     if test_support.is_resource_enabled('network'):