]> granicus.if.org Git - python/commitdiff
Fix issue 9794: adds context manager protocol to socket.socket so that socket.create_...
authorGiampaolo Rodolà <g.rodola@gmail.com>
Wed, 8 Sep 2010 22:44:12 +0000 (22:44 +0000)
committerGiampaolo Rodolà <g.rodola@gmail.com>
Wed, 8 Sep 2010 22:44:12 +0000 (22:44 +0000)
Doc/library/socket.rst
Doc/whatsnew/3.2.rst
Lib/socket.py
Lib/test/test_socket.py

index 581756fe671e970daa0585a50dff496dc5bad0a4..a7656c1ec31527432b5ea06e93f1269899841e52 100644 (file)
@@ -213,6 +213,9 @@ The module :mod:`socket` exports the following constants and functions:
    .. versionchanged:: 3.2
       *source_address* was added.
 
+   .. versionchanged:: 3.2
+      support for the :keyword:`with` statement was added.
+
 
 .. function:: getaddrinfo(host, port, family=0, type=0, proto=0, flags=0)
 
index 49696233f98265bd84ed7fd74e9c403c84097aee..7d8970be8437df07f81e678043b5cca862b72f7f 100644 (file)
@@ -389,6 +389,12 @@ New, Improved, and Deprecated Modules
 
   (Contributed by Giampaolo Rodolà; :issue:`8807`.)
 
+* :func:`socket.create_connection` now supports the context manager protocol
+  to unconditionally consume :exc:`socket.error` exceptions and to close the
+  socket when done.
+
+  (Contributed by Giampaolo Rodolà; :issue:`9794`.)
+
 
 Multi-threading
 ===============
index 004d6a9445fe10765fe92813ab8ecfb6383007ec..bfc9a726554a5ce82b74469fec15b7c8387557f6 100644 (file)
@@ -93,6 +93,13 @@ class socket(_socket.socket):
         self._io_refs = 0
         self._closed = False
 
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args):
+        if not self._closed:
+            self.close()
+
     def __repr__(self):
         """Wrap __repr__() to reveal the real class name."""
         s = _socket.socket.__repr__(self)
index 19c494b127b4325731eab71c657230c519d9b7d3..81f9cdf7f213afbb9fc358a40d719584ab679935 100644 (file)
@@ -1595,6 +1595,49 @@ class TIPCThreadableTest (unittest.TestCase, ThreadableTest):
         self.cli.close()
 
 
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class ContextManagersTest(ThreadedTCPSocketTest):
+
+    def _testSocketClass(self):
+        # base test
+        with socket.socket() as sock:
+            self.assertFalse(sock._closed)
+        self.assertTrue(sock._closed)
+        # close inside with block
+        with socket.socket() as sock:
+            sock.close()
+        self.assertTrue(sock._closed)
+        # exception inside with block
+        with socket.socket() as sock:
+            self.assertRaises(socket.error, sock.sendall, b'foo')
+        self.assertTrue(sock._closed)
+
+    def testCreateConnectionBase(self):
+        conn, addr = self.serv.accept()
+        data = conn.recv(1024)
+        conn.sendall(data)
+
+    def _testCreateConnectionBase(self):
+        address = self.serv.getsockname()
+        with socket.create_connection(address) as sock:
+            self.assertFalse(sock._closed)
+            sock.sendall(b'foo')
+            self.assertEqual(sock.recv(1024), b'foo')
+        self.assertTrue(sock._closed)
+
+    def testCreateConnectionClose(self):
+        conn, addr = self.serv.accept()
+        data = conn.recv(1024)
+        conn.sendall(data)
+
+    def _testCreateConnectionClose(self):
+        address = self.serv.getsockname()
+        with socket.create_connection(address) as sock:
+            sock.close()
+        self.assertTrue(sock._closed)
+        self.assertRaises(socket.error, sock.sendall, b'foo')
+
+
 def test_main():
     tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest,
              TestExceptions, BufferIOTest, BasicTCPTest2, BasicUDPTest, UDPTimeoutTest ]
@@ -1609,6 +1652,7 @@ def test_main():
         NetworkConnectionNoServer,
         NetworkConnectionAttributesTest,
         NetworkConnectionBehaviourTest,
+        ContextManagersTest,
     ])
     if hasattr(socket, "socketpair"):
         tests.append(BasicSocketPairTest)