]> granicus.if.org Git - python/commitdiff
Close #22063: socket operations (socket,recv, sock_sendall, sock_connect,
authorVictor Stinner <victor.stinner@gmail.com>
Tue, 29 Jul 2014 21:08:17 +0000 (23:08 +0200)
committerVictor Stinner <victor.stinner@gmail.com>
Tue, 29 Jul 2014 21:08:17 +0000 (23:08 +0200)
sock_accept) now raise an exception in debug mode if sockets are in blocking
mode.

Lib/asyncio/proactor_events.py
Lib/asyncio/selector_events.py
Lib/test/test_asyncio/test_events.py

index ab566b32757e784ee75d62456fa69fd37d61c646..751155bfed336105d7646718c1c9fda00f684498 100644 (file)
@@ -385,12 +385,18 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
         self._selector = None
 
     def sock_recv(self, sock, n):
+        if self.get_debug() and sock.gettimeout() != 0:
+            raise ValueError("the socket must be non-blocking")
         return self._proactor.recv(sock, n)
 
     def sock_sendall(self, sock, data):
+        if self.get_debug() and sock.gettimeout() != 0:
+            raise ValueError("the socket must be non-blocking")
         return self._proactor.send(sock, data)
 
     def sock_connect(self, sock, address):
+        if self.get_debug() and sock.gettimeout() != 0:
+            raise ValueError("the socket must be non-blocking")
         try:
             base_events._check_resolved_address(sock, address)
         except ValueError as err:
@@ -401,6 +407,8 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
             return self._proactor.connect(sock, address)
 
     def sock_accept(self, sock):
+        if self.get_debug() and sock.gettimeout() != 0:
+            raise ValueError("the socket must be non-blocking")
         return self._proactor.accept(sock)
 
     def _socketpair(self):
index eca48b8e7c2bc587d1c40987a09899b85f8d2451..6b7bdf01c139ac486c9da15718a1368c0202c947 100644 (file)
@@ -256,6 +256,8 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
 
         This method is a coroutine.
         """
+        if self.get_debug() and sock.gettimeout() != 0:
+            raise ValueError("the socket must be non-blocking")
         fut = futures.Future(loop=self)
         self._sock_recv(fut, False, sock, n)
         return fut
@@ -292,6 +294,8 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
 
         This method is a coroutine.
         """
+        if self.get_debug() and sock.gettimeout() != 0:
+            raise ValueError("the socket must be non-blocking")
         fut = futures.Future(loop=self)
         if data:
             self._sock_sendall(fut, False, sock, data)
@@ -333,6 +337,8 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
 
         This method is a coroutine.
         """
+        if self.get_debug() and sock.gettimeout() != 0:
+            raise ValueError("the socket must be non-blocking")
         fut = futures.Future(loop=self)
         try:
             base_events._check_resolved_address(sock, address)
@@ -374,6 +380,8 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
 
         This method is a coroutine.
         """
+        if self.get_debug() and sock.gettimeout() != 0:
+            raise ValueError("the socket must be non-blocking")
         fut = futures.Future(loop=self)
         self._sock_accept(fut, False, sock)
         return fut
index b0657495ad4e6793e276972f979f5c6d49736fdf..0cff00ae6756c775e73ea2ed14a5ceec1b2110e4 100644 (file)
@@ -383,6 +383,24 @@ class EventLoopTestsMixin:
         self.assertEqual(read, data)
 
     def _basetest_sock_client_ops(self, httpd, sock):
+        # in debug mode, socket operations must fail
+        # if the socket is not in blocking mode
+        self.loop.set_debug(True)
+        sock.setblocking(True)
+        with self.assertRaises(ValueError):
+            self.loop.run_until_complete(
+                self.loop.sock_connect(sock, httpd.address))
+        with self.assertRaises(ValueError):
+            self.loop.run_until_complete(
+                self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
+        with self.assertRaises(ValueError):
+            self.loop.run_until_complete(
+                self.loop.sock_recv(sock, 1024))
+        with self.assertRaises(ValueError):
+            self.loop.run_until_complete(
+                self.loop.sock_accept(sock))
+
+        # test in non-blocking mode
         sock.setblocking(False)
         self.loop.run_until_complete(
             self.loop.sock_connect(sock, httpd.address))