]> granicus.if.org Git - python/commitdiff
test_socket: use context managers to close directly the socket
authorVictor Stinner <victor.stinner@haypocalc.com>
Mon, 3 Jan 2011 14:30:46 +0000 (14:30 +0000)
committerVictor Stinner <victor.stinner@haypocalc.com>
Mon, 3 Jan 2011 14:30:46 +0000 (14:30 +0000)
Fix ResourceWarning(unclosed socket) warnings. Patch written by Nadeem Vawda.

Lib/test/test_socket.py

index 6bdb6c919e70438acc7e4d68f4eb0b1817e36d72..4c8c7d666ad980ce18abaef3541cb796491c8149 100644 (file)
@@ -1678,25 +1678,25 @@ class TestLinuxAbstractNamespace(unittest.TestCase):
 
     def testLinuxAbstractNamespace(self):
         address = b"\x00python-test-hello\x00\xff"
-        s1 = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
-        s1.bind(address)
-        s1.listen(1)
-        s2 = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
-        s2.connect(s1.getsockname())
-        s1.accept()
-        self.assertEqual(s1.getsockname(), address)
-        self.assertEqual(s2.getpeername(), address)
+        with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s1:
+            s1.bind(address)
+            s1.listen(1)
+            with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s2:
+                s2.connect(s1.getsockname())
+                with s1.accept()[0] as s3:
+                    self.assertEqual(s1.getsockname(), address)
+                    self.assertEqual(s2.getpeername(), address)
 
     def testMaxName(self):
         address = b"\x00" + b"h" * (self.UNIX_PATH_MAX - 1)
-        s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
-        s.bind(address)
-        self.assertEqual(s.getsockname(), address)
+        with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s:
+            s.bind(address)
+            self.assertEqual(s.getsockname(), address)
 
     def testNameOverflow(self):
         address = "\x00" + "h" * self.UNIX_PATH_MAX
-        s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
-        self.assertRaises(socket.error, s.bind, address)
+        with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s:
+            self.assertRaises(socket.error, s.bind, address)
 
 
 @unittest.skipUnless(thread, 'Threading required for this test.')
@@ -1898,10 +1898,10 @@ class CloexecConstantTest(unittest.TestCase):
         if v < (2, 6, 28):
             self.skipTest("Linux kernel 2.6.28 or higher required, not %s"
                           % ".".join(map(str, v)))
-        s = socket.socket(socket.AF_INET,
-                          socket.SOCK_STREAM | socket.SOCK_CLOEXEC)
-        self.assertTrue(s.type & socket.SOCK_CLOEXEC)
-        self.assertTrue(fcntl.fcntl(s, fcntl.F_GETFD) & fcntl.FD_CLOEXEC)
+        with socket.socket(socket.AF_INET,
+                           socket.SOCK_STREAM | socket.SOCK_CLOEXEC) as s:
+            self.assertTrue(s.type & socket.SOCK_CLOEXEC)
+            self.assertTrue(fcntl.fcntl(s, fcntl.F_GETFD) & fcntl.FD_CLOEXEC)
 
 
 @unittest.skipUnless(hasattr(socket, "SOCK_NONBLOCK"),
@@ -1922,29 +1922,33 @@ class NonblockConstantTest(unittest.TestCase):
                           % ".".join(map(str, v)))
         # a lot of it seems silly and redundant, but I wanted to test that
         # changing back and forth worked ok
-        s = socket.socket(socket.AF_INET,
-                          socket.SOCK_STREAM | socket.SOCK_NONBLOCK)
-        self.checkNonblock(s)
-        s.setblocking(1)
-        self.checkNonblock(s, False)
-        s.setblocking(0)
-        self.checkNonblock(s)
-        s.settimeout(None)
-        self.checkNonblock(s, False)
-        s.settimeout(2.0)
-        self.checkNonblock(s, timeout=2.0)
-        s.setblocking(1)
-        self.checkNonblock(s, False)
+        with socket.socket(socket.AF_INET,
+                           socket.SOCK_STREAM | socket.SOCK_NONBLOCK) as s:
+            self.checkNonblock(s)
+            s.setblocking(1)
+            self.checkNonblock(s, False)
+            s.setblocking(0)
+            self.checkNonblock(s)
+            s.settimeout(None)
+            self.checkNonblock(s, False)
+            s.settimeout(2.0)
+            self.checkNonblock(s, timeout=2.0)
+            s.setblocking(1)
+            self.checkNonblock(s, False)
         # defaulttimeout
         t = socket.getdefaulttimeout()
         socket.setdefaulttimeout(0.0)
-        self.checkNonblock(socket.socket())
+        with socket.socket() as s:
+            self.checkNonblock(s)
         socket.setdefaulttimeout(None)
-        self.checkNonblock(socket.socket(), False)
+        with socket.socket() as s:
+            self.checkNonblock(s, False)
         socket.setdefaulttimeout(2.0)
-        self.checkNonblock(socket.socket(), timeout=2.0)
+        with socket.socket() as s:
+            self.checkNonblock(s, timeout=2.0)
         socket.setdefaulttimeout(None)
-        self.checkNonblock(socket.socket(), False)
+        with socket.socket() as s:
+            self.checkNonblock(s, False)
         socket.setdefaulttimeout(t)