]> granicus.if.org Git - python/commitdiff
Fixes Issue #14635: telnetlib will use poll() rather than select() when possible
authorGregory P. Smith <greg@krypto.org>
Mon, 16 Jul 2012 06:42:26 +0000 (23:42 -0700)
committerGregory P. Smith <greg@krypto.org>
Mon, 16 Jul 2012 06:42:26 +0000 (23:42 -0700)
to avoid failing due to the select() file descriptor limit.

Lib/telnetlib.py
Lib/test/test_telnetlib.py
Misc/ACKS
Misc/NEWS

index 82b5e8fc1b11b2bb52f015e5154dfc5e88d1dd02..a59693e746315de8babd832f8c42c91d124fcf62 100644 (file)
@@ -34,6 +34,7 @@ To do:
 
 
 # Imported modules
+import errno
 import sys
 import socket
 import select
@@ -205,6 +206,7 @@ class Telnet:
         self.sb = 0 # flag for SB and SE sequence.
         self.sbdataq = b''
         self.option_callback = None
+        self._has_poll = hasattr(select, 'poll')
         if host is not None:
             self.open(host, port, timeout)
 
@@ -286,6 +288,61 @@ class Telnet:
         possibly the empty string.  Raise EOFError if the connection
         is closed and no cooked data is available.
 
+        """
+        if self._has_poll:
+            return self._read_until_with_poll(match, timeout)
+        else:
+            return self._read_until_with_select(match, timeout)
+
+    def _read_until_with_poll(self, match, timeout):
+        """Read until a given string is encountered or until timeout.
+
+        This method uses select.poll() to implement the timeout.
+        """
+        n = len(match)
+        call_timeout = timeout
+        if timeout is not None:
+            from time import time
+            time_start = time()
+        self.process_rawq()
+        i = self.cookedq.find(match)
+        if i < 0:
+            poller = select.poll()
+            poll_in_or_priority_flags = select.POLLIN | select.POLLPRI
+            poller.register(self, poll_in_or_priority_flags)
+            while i < 0 and not self.eof:
+                try:
+                    ready = poller.poll(call_timeout)
+                except select.error as e:
+                    if e.errno == errno.EINTR:
+                        if timeout is not None:
+                            elapsed = time() - time_start
+                            call_timeout = timeout-elapsed
+                        continue
+                    raise
+                for fd, mode in ready:
+                    if mode & poll_in_or_priority_flags:
+                        i = max(0, len(self.cookedq)-n)
+                        self.fill_rawq()
+                        self.process_rawq()
+                        i = self.cookedq.find(match, i)
+                if timeout is not None:
+                    elapsed = time() - time_start
+                    if elapsed >= timeout:
+                        break
+                    call_timeout = timeout-elapsed
+            poller.unregister(self)
+        if i >= 0:
+            i = i + n
+            buf = self.cookedq[:i]
+            self.cookedq = self.cookedq[i:]
+            return buf
+        return self.read_very_lazy()
+
+    def _read_until_with_select(self, match, timeout=None):
+        """Read until a given string is encountered or until timeout.
+
+        The timeout is implemented using select.select().
         """
         n = len(match)
         self.process_rawq()
@@ -588,6 +645,79 @@ class Telnet:
         or if more than one expression can match the same input, the
         results are undeterministic, and may depend on the I/O timing.
 
+        """
+        if self._has_poll:
+            return self._expect_with_poll(list, timeout)
+        else:
+            return self._expect_with_select(list, timeout)
+
+    def _expect_with_poll(self, expect_list, timeout=None):
+        """Read until one from a list of a regular expressions matches.
+
+        This method uses select.poll() to implement the timeout.
+        """
+        re = None
+        expect_list = expect_list[:]
+        indices = range(len(expect_list))
+        for i in indices:
+            if not hasattr(expect_list[i], "search"):
+                if not re: import re
+                expect_list[i] = re.compile(expect_list[i])
+        call_timeout = timeout
+        if timeout is not None:
+            from time import time
+            time_start = time()
+        self.process_rawq()
+        m = None
+        for i in indices:
+            m = expect_list[i].search(self.cookedq)
+            if m:
+                e = m.end()
+                text = self.cookedq[:e]
+                self.cookedq = self.cookedq[e:]
+                break
+        if not m:
+            poller = select.poll()
+            poll_in_or_priority_flags = select.POLLIN | select.POLLPRI
+            poller.register(self, poll_in_or_priority_flags)
+            while not m and not self.eof:
+                try:
+                    ready = poller.poll(call_timeout)
+                except select.error as e:
+                    if e.errno == errno.EINTR:
+                        if timeout is not None:
+                            elapsed = time() - time_start
+                            call_timeout = timeout-elapsed
+                        continue
+                    raise
+                for fd, mode in ready:
+                    if mode & poll_in_or_priority_flags:
+                        self.fill_rawq()
+                        self.process_rawq()
+                        for i in indices:
+                            m = expect_list[i].search(self.cookedq)
+                            if m:
+                                e = m.end()
+                                text = self.cookedq[:e]
+                                self.cookedq = self.cookedq[e:]
+                                break
+                if timeout is not None:
+                    elapsed = time() - time_start
+                    if elapsed >= timeout:
+                        break
+                    call_timeout = timeout-elapsed
+            poller.unregister(self)
+        if m:
+            return (i, m, text)
+        text = self.read_very_lazy()
+        if not text and self.eof:
+            raise EOFError
+        return (-1, None, text)
+
+    def _expect_with_select(self, list, timeout=None):
+        """Read until one from a list of a regular expressions matches.
+
+        The timeout is implemented using select.select().
         """
         re = None
         list = list[:]
index 87418f514e037c4b1971297759aebb637e8710c5..38da08c72832843c46b85472672d40f4f894f7a9 100644 (file)
@@ -75,8 +75,8 @@ class GeneralTests(TestCase):
 
 class SocketStub(object):
     ''' a socket proxy that re-defines sendall() '''
-    def __init__(self, reads=[]):
-        self.reads = reads
+    def __init__(self, reads=()):
+        self.reads = list(reads)  # Intentionally make a copy.
         self.writes = []
         self.block = False
     def sendall(self, data):
@@ -102,7 +102,7 @@ class TelnetAlike(telnetlib.Telnet):
         self._messages += out.getvalue()
         return
 
-def new_select(*s_args):
+def mock_select(*s_args):
     block = False
     for l in s_args:
         for fob in l:
@@ -113,6 +113,30 @@ def new_select(*s_args):
     else:
         return s_args
 
+class MockPoller(object):
+    test_case = None  # Set during TestCase setUp.
+
+    def __init__(self):
+        self._file_objs = []
+
+    def register(self, fd, eventmask):
+        self.test_case.assertTrue(hasattr(fd, 'fileno'), fd)
+        self.test_case.assertEqual(eventmask, select.POLLIN|select.POLLPRI)
+        self._file_objs.append(fd)
+
+    def poll(self, timeout=None):
+        block = False
+        for fob in self._file_objs:
+            if isinstance(fob, TelnetAlike):
+                block = fob.sock.block
+        if block:
+            return []
+        else:
+            return zip(self._file_objs, [select.POLLIN]*len(self._file_objs))
+
+    def unregister(self, fd):
+        self._file_objs.remove(fd)
+
 @contextlib.contextmanager
 def test_socket(reads):
     def new_conn(*ignored):
@@ -125,7 +149,7 @@ def test_socket(reads):
         socket.create_connection = old_conn
     return
 
-def test_telnet(reads=[], cls=TelnetAlike):
+def test_telnet(reads=(), cls=TelnetAlike, use_poll=None):
     ''' return a telnetlib.Telnet object that uses a SocketStub with
         reads queued up to be read '''
     for x in reads:
@@ -133,15 +157,28 @@ def test_telnet(reads=[], cls=TelnetAlike):
     with test_socket(reads):
         telnet = cls('dummy', 0)
         telnet._messages = '' # debuglevel output
+        if use_poll is not None:
+            if use_poll and not telnet._has_poll:
+                raise unittest.SkipTest('select.poll() required.')
+            telnet._has_poll = use_poll
     return telnet
 
-class ReadTests(TestCase):
+
+class ExpectAndReadTestCase(TestCase):
     def setUp(self):
         self.old_select = select.select
-        select.select = new_select
+        self.old_poll = select.poll
+        select.select = mock_select
+        select.poll = MockPoller
+        MockPoller.test_case = self
+
     def tearDown(self):
+        MockPoller.test_case = None
+        select.poll = self.old_poll
         select.select = self.old_select
 
+
+class ReadTests(ExpectAndReadTestCase):
     def test_read_until(self):
         """
         read_until(expected, timeout=None)
@@ -158,6 +195,21 @@ class ReadTests(TestCase):
         data = telnet.read_until(b'match')
         self.assertEqual(data, expect)
 
+    def test_read_until_with_poll(self):
+        """Use select.poll() to implement telnet.read_until()."""
+        want = [b'x' * 10, b'match', b'y' * 10]
+        telnet = test_telnet(want, use_poll=True)
+        select.select = lambda *_: self.fail('unexpected select() call.')
+        data = telnet.read_until(b'match')
+        self.assertEqual(data, b''.join(want[:-1]))
+
+    def test_read_until_with_select(self):
+        """Use select.select() to implement telnet.read_until()."""
+        want = [b'x' * 10, b'match', b'y' * 10]
+        telnet = test_telnet(want, use_poll=False)
+        select.poll = lambda *_: self.fail('unexpected poll() call.')
+        data = telnet.read_until(b'match')
+        self.assertEqual(data, b''.join(want[:-1]))
 
     def test_read_all(self):
         """
@@ -349,8 +401,38 @@ class OptionTests(TestCase):
         self.assertRegex(telnet._messages, r'0.*test')
 
 
+class ExpectTests(ExpectAndReadTestCase):
+    def test_expect(self):
+        """
+        expect(expected, [timeout])
+          Read until the expected string has been seen, or a timeout is
+          hit (default is no timeout); may block.
+        """
+        want = [b'x' * 10, b'match', b'y' * 10]
+        telnet = test_telnet(want)
+        (_,_,data) = telnet.expect([b'match'])
+        self.assertEqual(data, b''.join(want[:-1]))
+
+    def test_expect_with_poll(self):
+        """Use select.poll() to implement telnet.expect()."""
+        want = [b'x' * 10, b'match', b'y' * 10]
+        telnet = test_telnet(want, use_poll=True)
+        select.select = lambda *_: self.fail('unexpected select() call.')
+        (_,_,data) = telnet.expect([b'match'])
+        self.assertEqual(data, b''.join(want[:-1]))
+
+    def test_expect_with_select(self):
+        """Use select.select() to implement telnet.expect()."""
+        want = [b'x' * 10, b'match', b'y' * 10]
+        telnet = test_telnet(want, use_poll=False)
+        select.poll = lambda *_: self.fail('unexpected poll() call.')
+        (_,_,data) = telnet.expect([b'match'])
+        self.assertEqual(data, b''.join(want[:-1]))
+
+
 def test_main(verbose=None):
-    support.run_unittest(GeneralTests, ReadTests, WriteTests, OptionTests)
+    support.run_unittest(GeneralTests, ReadTests, WriteTests, OptionTests,
+                         ExpectTests)
 
 if __name__ == '__main__':
     test_main()
index 26fe1d06a727280b55e9424b68de92bd998aaeb4..3bf81a266bccc997391d03f7dff23735c6676542 100644 (file)
--- a/Misc/ACKS
+++ b/Misc/ACKS
@@ -410,6 +410,7 @@ Chris Hoffman
 Albert Hofkamp
 Tomas Hoger
 Jonathan Hogg
+Akintayo Holder
 Gerrit Holl
 Shane Holloway
 Rune Holm
index 74e40387a2f93a417f102578bdb5a6ae042a7980..1d5353a20ea02b21588c78df87e63b37ead03774 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -87,6 +87,9 @@ Core and Builtins
 Library
 -------
 
+- Issue #14635: telnetlib will use poll() rather than select() when possible
+  to avoid failing due to the select() file descriptor limit.
+
 - Issue #15180: Clarify posixpath.join() error message when mixing str & bytes
 
 - Issue #15230: runpy.run_path now correctly sets __package__ as described