]> granicus.if.org Git - python/commitdiff
#14758: add IPv6 support to smtpd.
authorR David Murray <rdmurray@bitdance.com>
Wed, 11 Jun 2014 17:48:58 +0000 (13:48 -0400)
committerR David Murray <rdmurray@bitdance.com>
Wed, 11 Jun 2014 17:48:58 +0000 (13:48 -0400)
Patch by Milan Oberkirch.

Doc/library/smtpd.rst
Doc/whatsnew/3.5.rst
Lib/smtpd.py
Lib/test/mock_socket.py
Lib/test/test_smtpd.py

index e6625df74fd50cde9b4e573e970d2babf2301d2a..0f4a0bf512273b0f379b30693f9a8fd3ccb09b8f 100644 (file)
@@ -68,8 +68,8 @@ SMTPServer Objects
    .. versionchanged:: 3.4
       The *map* argument was added.
 
-   .. versionchanged:: 3.5
-      the *decode_data* argument was added.
+   .. versionchanged:: 3.5 the *decode_data* argument was added, and *localaddr*
+      and *remoteaddr* may now contain IPv6 addresses.
 
 
 DebuggingServer Objects
index 68106c160b85ad0fd7be5de9a04d6d16355c2fca..846a41661af327903ce440e74285af5c2e182b20 100644 (file)
@@ -194,6 +194,10 @@ smtpd
   is ``True`` for backward compatibility reasons, but will change to ``False``
   in Python 3.6.  (Contributed by Maciej Szulik in :issue:`19662`.)
 
+* It is now possible to provide, directly or via name resolution, IPv6
+  addresses in the :class:`~smtpd.SMTPServer` constructor, and have it
+  successfully connect.  (Contributed by Milan Oberkirch in :issue:`14758`.)
+
 socket
 ------
 
index 569b42e2228eb01f5ae546ba99c5a014017da9bd..d828c5f12fee85daaf91b5a26a8930e97a7c4254 100755 (executable)
@@ -610,7 +610,8 @@ class SMTPServer(asyncore.dispatcher):
         self._decode_data = decode_data
         asyncore.dispatcher.__init__(self, map=map)
         try:
-            self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
+            gai_results = socket.getaddrinfo(*localaddr)
+            self.create_socket(gai_results[0][0], gai_results[0][1])
             # try to re-use a server port if possible
             self.set_reuse_addr()
             self.bind(localaddr)
index e36724f54bce0d52ec2af99d554d7ffc751855c2..a4fbca6369c4f2a50dd4257fb47dffc8ab8fe9e0 100644 (file)
@@ -35,8 +35,9 @@ class MockFile:
 class MockSocket:
     """Mock socket object used by smtpd and smtplib tests.
     """
-    def __init__(self):
+    def __init__(self, family=None):
         global _reply_data
+        self.family = family
         self.output = []
         self.lines = []
         if _reply_data:
@@ -108,8 +109,7 @@ class MockSocket:
 
 
 def socket(family=None, type=None, proto=None):
-    return MockSocket()
-
+    return MockSocket(family)
 
 def create_connection(address, timeout=socket_module._GLOBAL_DEFAULT_TIMEOUT,
                       source_address=None):
@@ -144,13 +144,16 @@ def gethostname():
 def gethostbyname(name):
     return ""
 
+def getaddrinfo(host, port):
+    return socket_module.getaddrinfo(host, port)
 
 gaierror = socket_module.gaierror
 error = socket_module.error
 
 
 # Constants
-AF_INET = None
+AF_INET = socket_module.AF_INET
+AF_INET6 = socket_module.AF_INET6
 SOCK_STREAM = None
 SOL_SOCKET = None
 SO_REUSEADDR = None
index db1f52b7b4aeaceb5a7f317857481696984519bb..caeb79763907e110038393a8463257dec28ab911 100644 (file)
@@ -36,7 +36,8 @@ class SMTPDServerTest(unittest.TestCase):
         smtpd.socket = asyncore.socket = mock_socket
 
     def test_process_message_unimplemented(self):
-        server = smtpd.SMTPServer('a', 'b', decode_data=True)
+        server = smtpd.SMTPServer((support.HOST, 0), ('b', 0),
+                                  decode_data=True)
         conn, addr = server.accept()
         channel = smtpd.SMTPChannel(server, conn, addr, decode_data=True)
 
@@ -52,19 +53,39 @@ class SMTPDServerTest(unittest.TestCase):
 
     def test_decode_data_default_warns(self):
         with self.assertWarns(DeprecationWarning):
-            smtpd.SMTPServer('a', 'b')
+            smtpd.SMTPServer((support.HOST, 0), ('b', 0))
 
     def tearDown(self):
         asyncore.close_all()
         asyncore.socket = smtpd.socket = socket
 
 
+class TestFamilyDetection(unittest.TestCase):
+    def setUp(self):
+        smtpd.socket = asyncore.socket = mock_socket
+
+    def tearDown(self):
+        asyncore.close_all()
+        asyncore.socket = smtpd.socket = socket
+
+    @unittest.skipUnless(support.IPV6_ENABLED, "IPv6 not enabled")
+    def test_socket_uses_IPv6(self):
+        server = smtpd.SMTPServer((support.HOSTv6, 0), (support.HOST, 0),
+                                  decode_data=False)
+        self.assertEqual(server.socket.family, socket.AF_INET6)
+
+    def test_socket_uses_IPv4(self):
+        server = smtpd.SMTPServer((support.HOST, 0), (support.HOSTv6, 0),
+                                  decode_data=False)
+        self.assertEqual(server.socket.family, socket.AF_INET)
+
+
 class SMTPDChannelTest(unittest.TestCase):
     def setUp(self):
         smtpd.socket = asyncore.socket = mock_socket
         self.old_debugstream = smtpd.DEBUGSTREAM
         self.debug = smtpd.DEBUGSTREAM = io.StringIO()
-        self.server = DummyServer('a', 'b')
+        self.server = DummyServer((support.HOST, 0), ('b', 0))
         conn, addr = self.server.accept()
         self.channel = smtpd.SMTPChannel(self.server, conn, addr,
                                          decode_data=True)
@@ -79,7 +100,9 @@ class SMTPDChannelTest(unittest.TestCase):
         self.channel.handle_read()
 
     def test_broken_connect(self):
-        self.assertRaises(DummyDispatcherBroken, BrokenDummyServer, 'a', 'b')
+        self.assertRaises(
+            DummyDispatcherBroken, BrokenDummyServer,
+            (support.HOST, 0), ('b', 0))
 
     def test_server_accept(self):
         self.server.handle_accept()
@@ -513,11 +536,21 @@ class SMTPDChannelTest(unittest.TestCase):
             self.channel._SMTPChannel__addr = 'spam'
 
     def test_decode_data_default_warning(self):
-        server = DummyServer('a', 'b')
+        server = DummyServer((support.HOST, 0), ('b', 0))
         conn, addr = self.server.accept()
         with self.assertWarns(DeprecationWarning):
             smtpd.SMTPChannel(server, conn, addr)
 
+@unittest.skipUnless(support.IPV6_ENABLED, "IPv6 not enabled")
+class SMTPDChannelIPv6Test(SMTPDChannelTest):
+    def setUp(self):
+        smtpd.socket = asyncore.socket = mock_socket
+        self.old_debugstream = smtpd.DEBUGSTREAM
+        self.debug = smtpd.DEBUGSTREAM = io.StringIO()
+        self.server = DummyServer((support.HOSTv6, 0), ('b', 0))
+        conn, addr = self.server.accept()
+        self.channel = smtpd.SMTPChannel(self.server, conn, addr,
+                                         decode_data=True)
 
 class SMTPDChannelWithDataSizeLimitTest(unittest.TestCase):
 
@@ -525,7 +558,7 @@ class SMTPDChannelWithDataSizeLimitTest(unittest.TestCase):
         smtpd.socket = asyncore.socket = mock_socket
         self.old_debugstream = smtpd.DEBUGSTREAM
         self.debug = smtpd.DEBUGSTREAM = io.StringIO()
-        self.server = DummyServer('a', 'b')
+        self.server = DummyServer((support.HOST, 0), ('b', 0))
         conn, addr = self.server.accept()
         # Set DATA size limit to 32 bytes for easy testing
         self.channel = smtpd.SMTPChannel(self.server, conn, addr, 32,
@@ -576,7 +609,8 @@ class SMTPDChannelWithDecodeDataFalse(unittest.TestCase):
         smtpd.socket = asyncore.socket = mock_socket
         self.old_debugstream = smtpd.DEBUGSTREAM
         self.debug = smtpd.DEBUGSTREAM = io.StringIO()
-        self.server = DummyServer('a', 'b', decode_data=False)
+        self.server = DummyServer((support.HOST, 0), ('b', 0),
+                                  decode_data=False)
         conn, addr = self.server.accept()
         # Set decode_data to False
         self.channel = smtpd.SMTPChannel(self.server, conn, addr,
@@ -620,7 +654,8 @@ class SMTPDChannelWithDecodeDataTrue(unittest.TestCase):
         smtpd.socket = asyncore.socket = mock_socket
         self.old_debugstream = smtpd.DEBUGSTREAM
         self.debug = smtpd.DEBUGSTREAM = io.StringIO()
-        self.server = DummyServer('a', 'b')
+        self.server = DummyServer((support.HOST, 0), ('b', 0),
+                                  decode_data=True)
         conn, addr = self.server.accept()
         # Set decode_data to True
         self.channel = smtpd.SMTPChannel(self.server, conn, addr,