]> granicus.if.org Git - python/commitdiff
Revert r62242: trunk's test_ssl.py isn't as up-to-date as py3k's, and should't have...
authorTrent Nelson <trent.nelson@snakebite.org>
Thu, 10 Apr 2008 20:12:06 +0000 (20:12 +0000)
committerTrent Nelson <trent.nelson@snakebite.org>
Thu, 10 Apr 2008 20:12:06 +0000 (20:12 +0000)
Lib/test/test_ssl.py

index 99ed00fc001363d594946fc7296ce16012583ff0..520f440d2ab7e6830ed800b30e9e8cc43d3b8d93 100644 (file)
@@ -25,10 +25,11 @@ try:
 except ImportError:
     skip_expected = True
 
-HOST = test_support.HOST
 CERTFILE = None
 SVN_PYTHON_ORG_ROOT_CERT = None
 
+TESTPORT = 10025
+
 def handle_error(prefix):
     exc_format = ' '.join(traceback.format_exception(*sys.exc_info()))
     if test_support.verbose:
@@ -298,7 +299,7 @@ else:
                     except:
                         handle_error('')
 
-        def __init__(self, certificate, ssl_version=None,
+        def __init__(self, port, certificate, ssl_version=None,
                      certreqs=None, cacerts=None, expect_bad_connects=False,
                      chatty=True, connectionchatty=False, starttls_server=False):
             if ssl_version is None:
@@ -314,8 +315,12 @@ else:
             self.connectionchatty = connectionchatty
             self.starttls_server = starttls_server
             self.sock = socket.socket()
-            self.port = test_support.bind_port(self.sock)
             self.flag = None
+            if hasattr(socket, 'SO_REUSEADDR'):
+                self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+            if hasattr(socket, 'SO_REUSEPORT'):
+                self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+            self.sock.bind(('127.0.0.1', port))
             self.active = False
             threading.Thread.__init__(self)
             self.setDaemon(False)
@@ -466,13 +471,12 @@ else:
                                       format%args))
 
 
-        def __init__(self, certfile):
+        def __init__(self, port, certfile):
             self.flag = None
             self.active = False
             self.RootedHTTPRequestHandler.root = os.path.split(CERTFILE)[0]
-            self.port = test_support.find_unused_port()
             self.server = self.HTTPSServer(
-                (HOST, self.port), self.RootedHTTPRequestHandler, certfile)
+                ('', port), self.RootedHTTPRequestHandler, certfile)
             threading.Thread.__init__(self)
             self.setDaemon(True)
 
@@ -582,7 +586,7 @@ else:
             self.server.close()
 
     def badCertTest (certfile):
-        server = ThreadedEchoServer(CERTFILE,
+        server = ThreadedEchoServer(TESTPORT, CERTFILE,
                                     certreqs=ssl.CERT_REQUIRED,
                                     cacerts=CERTFILE, chatty=False,
                                     connectionchatty=False)
@@ -596,7 +600,7 @@ else:
                 s = ssl.wrap_socket(socket.socket(),
                                     certfile=certfile,
                                     ssl_version=ssl.PROTOCOL_TLSv1)
-                s.connect((HOST, server.port))
+                s.connect(('127.0.0.1', TESTPORT))
             except ssl.SSLError as x:
                 if test_support.verbose:
                     sys.stdout.write("\nSSLError is %s\n" % x)
@@ -612,7 +616,7 @@ else:
                           indata="FOO\n",
                           chatty=False, connectionchatty=False):
 
-        server = ThreadedEchoServer(certfile,
+        server = ThreadedEchoServer(TESTPORT, certfile,
                                     certreqs=certreqs,
                                     ssl_version=protocol,
                                     cacerts=cacertsfile,
@@ -627,11 +631,12 @@ else:
             client_protocol = protocol
         try:
             s = ssl.wrap_socket(socket.socket(),
+                                server_side=False,
                                 certfile=client_certfile,
                                 ca_certs=cacertsfile,
                                 cert_reqs=certreqs,
                                 ssl_version=client_protocol)
-            s.connect((HOST, server.port))
+            s.connect(('127.0.0.1', TESTPORT))
         except ssl.SSLError as x:
             raise test_support.TestFailed("Unexpected SSL error:  " + str(x))
         except Exception as x:
@@ -641,17 +646,18 @@ else:
                 if test_support.verbose:
                     sys.stdout.write(
                         " client:  sending %s...\n" % (repr(indata)))
-            s.write(indata)
+            s.write(indata.encode('ASCII', 'strict'))
             outdata = s.read()
             if connectionchatty:
                 if test_support.verbose:
                     sys.stdout.write(" client:  read %s\n" % repr(outdata))
+            outdata = str(outdata, 'ASCII', 'strict')
             if outdata != indata.lower():
                 raise test_support.TestFailed(
                     "bad data <<%s>> (%d) received; expected <<%s>> (%d)\n"
-                    % (outdata[:min(len(outdata),20)], len(outdata),
-                       indata[:min(len(indata),20)].lower(), len(indata)))
-            s.write("over\n")
+                    % (repr(outdata[:min(len(outdata),20)]), len(outdata),
+                       repr(indata[:min(len(indata),20)].lower()), len(indata)))
+            s.write("over\n".encode("ASCII", "strict"))
             if connectionchatty:
                 if test_support.verbose:
                     sys.stdout.write(" client:  closing connection.\n")
@@ -697,44 +703,7 @@ else:
 
     class ThreadedTests(unittest.TestCase):
 
-        def testRudeShutdown(self):
-
-            listener_ready = threading.Event()
-            listener_gone = threading.Event()
-            port = test_support.find_unused_port()
-
-            # `listener` runs in a thread.  It opens a socket listening on
-            # PORT, and sits in an accept() until the main thread connects.
-            # Then it rudely closes the socket, and sets Event `listener_gone`
-            # to let the main thread know the socket is gone.
-            def listener():
-                s = socket.socket()
-                s.bind((HOST, port))
-                s.listen(5)
-                listener_ready.set()
-                s.accept()
-                s = None # reclaim the socket object, which also closes it
-                listener_gone.set()
-
-            def connector():
-                listener_ready.wait()
-                s = socket.socket()
-                s.connect((HOST, port))
-                listener_gone.wait()
-                try:
-                    ssl_sock = ssl.wrap_socket(s)
-                except IOError:
-                    pass
-                else:
-                    raise test_support.TestFailed(
-                          'connecting to closed SSL socket should have failed')
-
-            t = threading.Thread(target=listener)
-            t.start()
-            connector()
-            t.join()
-
-        def testEcho(self):
+        def testEcho (self):
 
             if test_support.verbose:
                 sys.stdout.write("\n")
@@ -747,7 +716,7 @@ else:
             if test_support.verbose:
                 sys.stdout.write("\n")
             s2 = socket.socket()
-            server = ThreadedEchoServer(CERTFILE,
+            server = ThreadedEchoServer(TESTPORT, CERTFILE,
                                         certreqs=ssl.CERT_NONE,
                                         ssl_version=ssl.PROTOCOL_SSLv23,
                                         cacerts=CERTFILE,
@@ -764,7 +733,7 @@ else:
                                         ca_certs=CERTFILE,
                                         cert_reqs=ssl.CERT_REQUIRED,
                                         ssl_version=ssl.PROTOCOL_SSLv23)
-                    s.connect((HOST, server.port))
+                    s.connect(('127.0.0.1', TESTPORT))
                 except ssl.SSLError as x:
                     raise test_support.TestFailed(
                         "Unexpected SSL error:  " + str(x))
@@ -807,6 +776,46 @@ else:
             badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir,
                                      "badkey.pem"))
 
+        def testRudeShutdown(self):
+
+            listener_ready = threading.Event()
+            listener_gone = threading.Event()
+
+            # `listener` runs in a thread.  It opens a socket listening on
+            # PORT, and sits in an accept() until the main thread connects.
+            # Then it rudely closes the socket, and sets Event `listener_gone`
+            # to let the main thread know the socket is gone.
+            def listener():
+                s = socket.socket()
+                if hasattr(socket, 'SO_REUSEADDR'):
+                    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+                if hasattr(socket, 'SO_REUSEPORT'):
+                    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+                s.bind(('127.0.0.1', TESTPORT))
+                s.listen(5)
+                listener_ready.set()
+                s.accept()
+                s = None # reclaim the socket object, which also closes it
+                listener_gone.set()
+
+            def connector():
+                listener_ready.wait()
+                s = socket.socket()
+                s.connect(('127.0.0.1', TESTPORT))
+                listener_gone.wait()
+                try:
+                    ssl_sock = ssl.wrap_socket(s)
+                except IOError:
+                    pass
+                else:
+                    raise test_support.TestFailed(
+                          'connecting to closed SSL socket should have failed')
+
+            t = threading.Thread(target=listener)
+            t.start()
+            connector()
+            t.join()
+
         def testProtocolSSL2(self):
             if test_support.verbose:
                 sys.stdout.write("\n")
@@ -864,7 +873,7 @@ else:
 
             msgs = ("msg 1", "MSG 2", "STARTTLS", "MSG 3", "msg 4")
 
-            server = ThreadedEchoServer(CERTFILE,
+            server = ThreadedEchoServer(TESTPORT, CERTFILE,
                                         ssl_version=ssl.PROTOCOL_TLSv1,
                                         starttls_server=True,
                                         chatty=True,
@@ -879,7 +888,7 @@ else:
                 try:
                     s = socket.socket()
                     s.setblocking(1)
-                    s.connect((HOST, server.port))
+                    s.connect(('127.0.0.1', TESTPORT))
                 except Exception as x:
                     raise test_support.TestFailed("Unexpected exception:  " + str(x))
                 else:
@@ -927,8 +936,7 @@ else:
 
         def testSocketServer(self):
 
-
-            server = AsyncoreHTTPSServer(CERTFILE)
+            server = OurHTTPSServer(TESTPORT, CERTFILE)
             flag = threading.Event()
             server.start(flag)
             # wait for it to start
@@ -940,8 +948,8 @@ else:
                 d1 = open(CERTFILE, 'rb').read()
                 d2 = ''
                 # now fetch the same data from the HTTPS server
-                url = 'https://%s:%d/%s' % (
-                    HOST, server.port, os.path.split(CERTFILE)[1])
+                url = 'https://127.0.0.1:%d/%s' % (
+                    TESTPORT, os.path.split(CERTFILE)[1])
                 f = urllib.urlopen(url)
                 dlen = f.info().getheader("content-length")
                 if dlen and (int(dlen) > 0):
@@ -970,11 +978,71 @@ else:
                     sys.stdout.write('joining thread\n')
                 server.join()
 
+        def testAsyncoreServer(self):
+
+            if test_support.verbose:
+                sys.stdout.write("\n")
+
+            indata="FOO\n"
+            server = AsyncoreEchoServer(TESTPORT, CERTFILE)
+            flag = threading.Event()
+            server.start(flag)
+            # wait for it to start
+            flag.wait()
+            # try to connect
+            try:
+                s = ssl.wrap_socket(socket.socket())
+                s.connect(('127.0.0.1', TESTPORT))
+            except ssl.SSLError as x:
+                raise test_support.TestFailed("Unexpected SSL error:  " + str(x))
+            except Exception as x:
+                raise test_support.TestFailed("Unexpected exception:  " + str(x))
+            else:
+                if test_support.verbose:
+                    sys.stdout.write(
+                        " client:  sending %s...\n" % (repr(indata)))
+                s.sendall(indata.encode('ASCII', 'strict'))
+                outdata = s.recv()
+                if test_support.verbose:
+                    sys.stdout.write(" client:  read %s\n" % repr(outdata))
+                outdata = str(outdata, 'ASCII', 'strict')
+                if outdata != indata.lower():
+                    raise test_support.TestFailed(
+                        "bad data <<%s>> (%d) received; expected <<%s>> (%d)\n"
+                        % (repr(outdata[:min(len(outdata),20)]), len(outdata),
+                           repr(indata[:min(len(indata),20)].lower()), len(indata)))
+                s.write("over\n".encode("ASCII", "strict"))
+                if test_support.verbose:
+                    sys.stdout.write(" client:  closing connection.\n")
+                s.close()
+            finally:
+                server.stop()
+                server.join()
+
+
+def findtestsocket(start, end):
+    def testbind(i):
+        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        try:
+            s.bind(("127.0.0.1", i))
+        except:
+            return 0
+        else:
+            return 1
+        finally:
+            s.close()
+
+    for i in range(start, end):
+        if testbind(i) and testbind(i+1):
+            return i
+    return 0
+
+
 def test_main(verbose=False):
     if skip_expected:
         raise test_support.TestSkipped("No SSL support")
 
-    global CERTFILE, SVN_PYTHON_ORG_ROOT_CERT
+    global CERTFILE, TESTPORT, SVN_PYTHON_ORG_ROOT_CERT
     CERTFILE = os.path.join(os.path.dirname(__file__) or os.curdir,
                             "keycert.pem")
     SVN_PYTHON_ORG_ROOT_CERT = os.path.join(
@@ -985,6 +1053,10 @@ def test_main(verbose=False):
         not os.path.exists(SVN_PYTHON_ORG_ROOT_CERT)):
         raise test_support.TestFailed("Can't read certificate files!")
 
+    TESTPORT = findtestsocket(10025, 12000)
+    if not TESTPORT:
+        raise test_support.TestFailed("Can't find open port to test servers on!")
+
     tests = [BasicTests]
 
     if test_support.is_resource_enabled('network'):