]> granicus.if.org Git - python/commitdiff
bpo-32622: Native sendfile on windows (GH-5565)
authorMiss Islington (bot) <31488909+miss-islington@users.noreply.github.com>
Sun, 25 Feb 2018 17:10:58 +0000 (09:10 -0800)
committerGitHub <noreply@github.com>
Sun, 25 Feb 2018 17:10:58 +0000 (09:10 -0800)
* Support sendfile on Windows Proactor event loop naively.
(cherry picked from commit a19fb3c6aaa7632410d1d9dcb395d7101d124da4)

Co-authored-by: Andrew Svetlov <andrew.svetlov@gmail.com>
Lib/asyncio/proactor_events.py
Lib/asyncio/windows_events.py
Lib/test/test_asyncio/test_events.py
Lib/test/test_asyncio/test_proactor_events.py
Lib/test/test_asyncio/test_unix_events.py
Misc/NEWS.d/next/Library/2018-02-06-17-58-15.bpo-32622.AE0Jz7.rst [new file with mode: 0644]
Modules/overlapped.c

index 10ca6f8967fb0571dd190d5f6dd4c552f54d7ecb..b675c8200ce30d8c0e62811b4ef0f79baca695df 100644 (file)
@@ -6,11 +6,14 @@ proactor is only implemented on Windows with IOCP.
 
 __all__ = 'BaseProactorEventLoop',
 
+import io
+import os
 import socket
 import warnings
 
 from . import base_events
 from . import constants
+from . import events
 from . import futures
 from . import protocols
 from . import sslproto
@@ -107,6 +110,11 @@ class _ProactorBasePipeTransport(transports._FlowControlMixin,
             self._force_close(exc)
 
     def _force_close(self, exc):
+        if self._empty_waiter is not None:
+            if exc is None:
+                self._empty_waiter.set_result(None)
+            else:
+                self._empty_waiter.set_exception(exc)
         if self._closing:
             return
         self._closing = True
@@ -327,6 +335,10 @@ class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport,
 
     _start_tls_compatible = True
 
+    def __init__(self, *args, **kw):
+        super().__init__(*args, **kw)
+        self._empty_waiter = None
+
     def write(self, data):
         if not isinstance(data, (bytes, bytearray, memoryview)):
             raise TypeError(
@@ -334,6 +346,8 @@ class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport,
                 f"not {type(data).__name__}")
         if self._eof_written:
             raise RuntimeError('write_eof() already called')
+        if self._empty_waiter is not None:
+            raise RuntimeError('unable to write; sendfile is in progress')
 
         if not data:
             return
@@ -393,6 +407,8 @@ class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport,
                     self._maybe_pause_protocol()
                 else:
                     self._write_fut.add_done_callback(self._loop_writing)
+            if self._empty_waiter is not None and self._write_fut is None:
+                self._empty_waiter.set_result(None)
         except ConnectionResetError as exc:
             self._force_close(exc)
         except OSError as exc:
@@ -407,6 +423,17 @@ class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport,
     def abort(self):
         self._force_close(None)
 
+    def _make_empty_waiter(self):
+        if self._empty_waiter is not None:
+            raise RuntimeError("Empty waiter is already set")
+        self._empty_waiter = self._loop.create_future()
+        if self._write_fut is None:
+            self._empty_waiter.set_result(None)
+        return self._empty_waiter
+
+    def _reset_empty_waiter(self):
+        self._empty_waiter = None
+
 
 class _ProactorWritePipeTransport(_ProactorBaseWritePipeTransport):
     def __init__(self, *args, **kw):
@@ -447,7 +474,7 @@ class _ProactorSocketTransport(_ProactorReadPipeTransport,
                                transports.Transport):
     """Transport for connected sockets."""
 
-    _sendfile_compatible = constants._SendfileMode.FALLBACK
+    _sendfile_compatible = constants._SendfileMode.TRY_NATIVE
 
     def _set_extra(self, sock):
         self._extra['socket'] = sock
@@ -556,6 +583,47 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
     async def sock_accept(self, sock):
         return await self._proactor.accept(sock)
 
+    async def _sock_sendfile_native(self, sock, file, offset, count):
+        try:
+            fileno = file.fileno()
+        except (AttributeError, io.UnsupportedOperation) as err:
+            raise events.SendfileNotAvailableError("not a regular file")
+        try:
+            fsize = os.fstat(fileno).st_size
+        except OSError as err:
+            raise events.SendfileNotAvailableError("not a regular file")
+        blocksize = count if count else fsize
+        if not blocksize:
+            return 0  # empty file
+
+        blocksize = min(blocksize, 0xffff_ffff)
+        end_pos = min(offset + count, fsize) if count else fsize
+        offset = min(offset, fsize)
+        total_sent = 0
+        try:
+            while True:
+                blocksize = min(end_pos - offset, blocksize)
+                if blocksize <= 0:
+                    return total_sent
+                await self._proactor.sendfile(sock, file, offset, blocksize)
+                offset += blocksize
+                total_sent += blocksize
+        finally:
+            if total_sent > 0:
+                file.seek(offset)
+
+    async def _sendfile_native(self, transp, file, offset, count):
+        resume_reading = transp.is_reading()
+        transp.pause_reading()
+        await transp._make_empty_waiter()
+        try:
+            return await self.sock_sendfile(transp._sock, file, offset, count,
+                                            fallback=False)
+        finally:
+            transp._reset_empty_waiter()
+            if resume_reading:
+                transp.resume_reading()
+
     def _close_self_pipe(self):
         if self._self_reading_future is not None:
             self._self_reading_future.cancel()
index f91fcddb2aad32596d106fa5fdc22f2776e27c66..d22edec51efc1fbde3c1a01a3569a7b4ec0d0789 100644 (file)
@@ -4,6 +4,7 @@ import _overlapped
 import _winapi
 import errno
 import math
+import msvcrt
 import socket
 import struct
 import weakref
@@ -527,6 +528,27 @@ class IocpProactor:
 
         return self._register(ov, conn, finish_connect)
 
+    def sendfile(self, sock, file, offset, count):
+        self._register_with_iocp(sock)
+        ov = _overlapped.Overlapped(NULL)
+        offset_low = offset & 0xffff_ffff
+        offset_high = (offset >> 32) & 0xffff_ffff
+        ov.TransmitFile(sock.fileno(),
+                        msvcrt.get_osfhandle(file.fileno()),
+                        offset_low, offset_high,
+                        count, 0, 0)
+
+        def finish_sendfile(trans, key, ov):
+            try:
+                return ov.getresult()
+            except OSError as exc:
+                if exc.winerror in (_overlapped.ERROR_NETNAME_DELETED,
+                                    _overlapped.ERROR_OPERATION_ABORTED):
+                    raise ConnectionResetError(*exc.args)
+                else:
+                    raise
+        return self._register(ov, sock, finish_sendfile)
+
     def accept_pipe(self, pipe):
         self._register_with_iocp(pipe)
         ov = _overlapped.Overlapped(NULL)
index f5995974c683a6a9549576cdbbb9e86ff9f93a0d..6accbdae8b3e90a36c70df6a51c6e6794ca5f0f5 100644 (file)
@@ -15,6 +15,7 @@ except ImportError:
     ssl = None
 import subprocess
 import sys
+import tempfile
 import threading
 import time
 import errno
@@ -2092,22 +2093,7 @@ class SubprocessTestsMixin:
             self.loop.run_until_complete(connect(shell=False))
 
 
-class MySendfileProto(MyBaseProto):
-
-    def __init__(self, loop=None, close_after=0):
-        super().__init__(loop)
-        self.data = bytearray()
-        self.close_after = close_after
-
-    def data_received(self, data):
-        self.data.extend(data)
-        super().data_received(data)
-        if self.close_after and self.nbytes >= self.close_after:
-            self.transport.close()
-
-
-class SendfileMixin:
-    # Note: sendfile via SSL transport is equal to sendfile fallback
+class SendfileBase:
 
     DATA = b"12345abcde" * 160 * 1024  # 160 KiB
 
@@ -2130,9 +2116,134 @@ class SendfileMixin:
     def run_loop(self, coro):
         return self.loop.run_until_complete(coro)
 
-    def prepare(self, *, is_ssl=False, close_after=0):
+
+class SockSendfileMixin(SendfileBase):
+
+    class MyProto(asyncio.Protocol):
+
+        def __init__(self, loop):
+            self.started = False
+            self.closed = False
+            self.data = bytearray()
+            self.fut = loop.create_future()
+            self.transport = None
+
+        def connection_made(self, transport):
+            self.started = True
+            self.transport = transport
+
+        def data_received(self, data):
+            self.data.extend(data)
+
+        def connection_lost(self, exc):
+            self.closed = True
+            self.fut.set_result(None)
+
+        async def wait_closed(self):
+            await self.fut
+
+    def make_socket(self, cleanup=True):
+        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sock.setblocking(False)
+        sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
+        sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024)
+        if cleanup:
+            self.addCleanup(sock.close)
+        return sock
+
+    def prepare_socksendfile(self):
+        sock = self.make_socket()
+        proto = self.MyProto(self.loop)
+        port = support.find_unused_port()
+        srv_sock = self.make_socket(cleanup=False)
+        srv_sock.bind((support.HOST, port))
+        server = self.run_loop(self.loop.create_server(
+            lambda: proto, sock=srv_sock))
+        self.run_loop(self.loop.sock_connect(sock, ('127.0.0.1', port)))
+
+        def cleanup():
+            if proto.transport is not None:
+                # can be None if the task was cancelled before
+                # connection_made callback
+                proto.transport.close()
+                self.run_loop(proto.wait_closed())
+
+            server.close()
+            self.run_loop(server.wait_closed())
+
+        self.addCleanup(cleanup)
+
+        return sock, proto
+
+    def test_sock_sendfile_success(self):
+        sock, proto = self.prepare_socksendfile()
+        ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
+        sock.close()
+        self.run_loop(proto.wait_closed())
+
+        self.assertEqual(ret, len(self.DATA))
+        self.assertEqual(proto.data, self.DATA)
+        self.assertEqual(self.file.tell(), len(self.DATA))
+
+    def test_sock_sendfile_with_offset_and_count(self):
+        sock, proto = self.prepare_socksendfile()
+        ret = self.run_loop(self.loop.sock_sendfile(sock, self.file,
+                                                    1000, 2000))
+        sock.close()
+        self.run_loop(proto.wait_closed())
+
+        self.assertEqual(proto.data, self.DATA[1000:3000])
+        self.assertEqual(self.file.tell(), 3000)
+        self.assertEqual(ret, 2000)
+
+    def test_sock_sendfile_zero_size(self):
+        sock, proto = self.prepare_socksendfile()
+        with tempfile.TemporaryFile() as f:
+            ret = self.run_loop(self.loop.sock_sendfile(sock, f,
+                                                        0, None))
+        sock.close()
+        self.run_loop(proto.wait_closed())
+
+        self.assertEqual(ret, 0)
+        self.assertEqual(self.file.tell(), 0)
+
+    def test_sock_sendfile_mix_with_regular_send(self):
+        buf = b'1234567890' * 1024 * 1024  # 10 MB
+        sock, proto = self.prepare_socksendfile()
+        self.run_loop(self.loop.sock_sendall(sock, buf))
+        ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
+        self.run_loop(self.loop.sock_sendall(sock, buf))
+        sock.close()
+        self.run_loop(proto.wait_closed())
+
+        self.assertEqual(ret, len(self.DATA))
+        expected = buf + self.DATA + buf
+        self.assertEqual(proto.data, expected)
+        self.assertEqual(self.file.tell(), len(self.DATA))
+
+
+class SendfileMixin(SendfileBase):
+
+    class MySendfileProto(MyBaseProto):
+
+        def __init__(self, loop=None, close_after=0):
+            super().__init__(loop)
+            self.data = bytearray()
+            self.close_after = close_after
+
+        def data_received(self, data):
+            self.data.extend(data)
+            super().data_received(data)
+            if self.close_after and self.nbytes >= self.close_after:
+                self.transport.close()
+
+
+    # Note: sendfile via SSL transport is equal to sendfile fallback
+
+    def prepare_sendfile(self, *, is_ssl=False, close_after=0):
         port = support.find_unused_port()
-        srv_proto = MySendfileProto(loop=self.loop, close_after=close_after)
+        srv_proto = self.MySendfileProto(loop=self.loop,
+                                         close_after=close_after)
         if is_ssl:
             if not ssl:
                 self.skipTest("No ssl module")
@@ -2156,7 +2267,7 @@ class SendfileMixin:
         # reduce send socket buffer size to test on relative small data sets
         cli_sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
         cli_sock.connect((support.HOST, port))
-        cli_proto = MySendfileProto(loop=self.loop)
+        cli_proto = self.MySendfileProto(loop=self.loop)
         tr, pr = self.run_loop(self.loop.create_connection(
             lambda: cli_proto, sock=cli_sock,
             ssl=cli_ctx, server_hostname=server_hostname))
@@ -2189,7 +2300,7 @@ class SendfileMixin:
             tr.close()
 
     def test_sendfile(self):
-        srv_proto, cli_proto = self.prepare()
+        srv_proto, cli_proto = self.prepare_sendfile()
         ret = self.run_loop(
             self.loop.sendfile(cli_proto.transport, self.file))
         cli_proto.transport.close()
@@ -2200,7 +2311,7 @@ class SendfileMixin:
         self.assertEqual(self.file.tell(), len(self.DATA))
 
     def test_sendfile_force_fallback(self):
-        srv_proto, cli_proto = self.prepare()
+        srv_proto, cli_proto = self.prepare_sendfile()
 
         def sendfile_native(transp, file, offset, count):
             # to raise SendfileNotAvailableError
@@ -2222,7 +2333,7 @@ class SendfileMixin:
         if sys.platform == 'win32':
             if isinstance(self.loop, asyncio.ProactorEventLoop):
                 self.skipTest("Fails on proactor event loop")
-        srv_proto, cli_proto = self.prepare()
+        srv_proto, cli_proto = self.prepare_sendfile()
 
         def sendfile_native(transp, file, offset, count):
             # to raise SendfileNotAvailableError
@@ -2243,7 +2354,7 @@ class SendfileMixin:
         self.assertEqual(self.file.tell(), 0)
 
     def test_sendfile_ssl(self):
-        srv_proto, cli_proto = self.prepare(is_ssl=True)
+        srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
         ret = self.run_loop(
             self.loop.sendfile(cli_proto.transport, self.file))
         cli_proto.transport.close()
@@ -2254,7 +2365,7 @@ class SendfileMixin:
         self.assertEqual(self.file.tell(), len(self.DATA))
 
     def test_sendfile_for_closing_transp(self):
-        srv_proto, cli_proto = self.prepare()
+        srv_proto, cli_proto = self.prepare_sendfile()
         cli_proto.transport.close()
         with self.assertRaisesRegex(RuntimeError, "is closing"):
             self.run_loop(self.loop.sendfile(cli_proto.transport, self.file))
@@ -2263,7 +2374,7 @@ class SendfileMixin:
         self.assertEqual(self.file.tell(), 0)
 
     def test_sendfile_pre_and_post_data(self):
-        srv_proto, cli_proto = self.prepare()
+        srv_proto, cli_proto = self.prepare_sendfile()
         PREFIX = b'zxcvbnm' * 1024
         SUFFIX = b'0987654321' * 1024
         cli_proto.transport.write(PREFIX)
@@ -2277,7 +2388,7 @@ class SendfileMixin:
         self.assertEqual(self.file.tell(), len(self.DATA))
 
     def test_sendfile_ssl_pre_and_post_data(self):
-        srv_proto, cli_proto = self.prepare(is_ssl=True)
+        srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
         PREFIX = b'zxcvbnm' * 1024
         SUFFIX = b'0987654321' * 1024
         cli_proto.transport.write(PREFIX)
@@ -2291,7 +2402,7 @@ class SendfileMixin:
         self.assertEqual(self.file.tell(), len(self.DATA))
 
     def test_sendfile_partial(self):
-        srv_proto, cli_proto = self.prepare()
+        srv_proto, cli_proto = self.prepare_sendfile()
         ret = self.run_loop(
             self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
         cli_proto.transport.close()
@@ -2302,7 +2413,7 @@ class SendfileMixin:
         self.assertEqual(self.file.tell(), 1100)
 
     def test_sendfile_ssl_partial(self):
-        srv_proto, cli_proto = self.prepare(is_ssl=True)
+        srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
         ret = self.run_loop(
             self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
         cli_proto.transport.close()
@@ -2313,7 +2424,8 @@ class SendfileMixin:
         self.assertEqual(self.file.tell(), 1100)
 
     def test_sendfile_close_peer_after_receiving(self):
-        srv_proto, cli_proto = self.prepare(close_after=len(self.DATA))
+        srv_proto, cli_proto = self.prepare_sendfile(
+            close_after=len(self.DATA))
         ret = self.run_loop(
             self.loop.sendfile(cli_proto.transport, self.file))
         cli_proto.transport.close()
@@ -2324,8 +2436,8 @@ class SendfileMixin:
         self.assertEqual(self.file.tell(), len(self.DATA))
 
     def test_sendfile_ssl_close_peer_after_receiving(self):
-        srv_proto, cli_proto = self.prepare(is_ssl=True,
-                                            close_after=len(self.DATA))
+        srv_proto, cli_proto = self.prepare_sendfile(
+            is_ssl=True, close_after=len(self.DATA))
         ret = self.run_loop(
             self.loop.sendfile(cli_proto.transport, self.file))
         self.run_loop(srv_proto.done)
@@ -2335,7 +2447,7 @@ class SendfileMixin:
         self.assertEqual(self.file.tell(), len(self.DATA))
 
     def test_sendfile_close_peer_in_middle_of_receiving(self):
-        srv_proto, cli_proto = self.prepare(close_after=1024)
+        srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
         with self.assertRaises(ConnectionError):
             self.run_loop(
                 self.loop.sendfile(cli_proto.transport, self.file))
@@ -2345,6 +2457,7 @@ class SendfileMixin:
                         srv_proto.nbytes)
         self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
                         self.file.tell())
+        self.assertTrue(cli_proto.transport.is_closing())
 
     def test_sendfile_fallback_close_peer_in_middle_of_receiving(self):
 
@@ -2355,7 +2468,7 @@ class SendfileMixin:
 
         self.loop._sendfile_native = sendfile_native
 
-        srv_proto, cli_proto = self.prepare(close_after=1024)
+        srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
         with self.assertRaises(ConnectionError):
             self.run_loop(
                 self.loop.sendfile(cli_proto.transport, self.file))
@@ -2369,7 +2482,7 @@ class SendfileMixin:
     @unittest.skipIf(not hasattr(os, 'sendfile'),
                      "Don't have native sendfile support")
     def test_sendfile_prevents_bare_write(self):
-        srv_proto, cli_proto = self.prepare()
+        srv_proto, cli_proto = self.prepare_sendfile()
         fut = self.loop.create_future()
 
         async def coro():
@@ -2397,6 +2510,7 @@ if sys.platform == 'win32':
 
     class SelectEventLoopTests(EventLoopTestsMixin,
                                SendfileMixin,
+                               SockSendfileMixin,
                                test_utils.TestCase):
 
         def create_event_loop(self):
@@ -2404,6 +2518,7 @@ if sys.platform == 'win32':
 
     class ProactorEventLoopTests(EventLoopTestsMixin,
                                  SendfileMixin,
+                                 SockSendfileMixin,
                                  SubprocessTestsMixin,
                                  test_utils.TestCase):
 
@@ -2431,7 +2546,9 @@ if sys.platform == 'win32':
 else:
     import selectors
 
-    class UnixEventLoopTestsMixin(EventLoopTestsMixin, SendfileMixin):
+    class UnixEventLoopTestsMixin(EventLoopTestsMixin,
+                                  SendfileMixin,
+                                  SockSendfileMixin):
         def setUp(self):
             super().setUp()
             watcher = asyncio.SafeChildWatcher()
index f627dfce0e17d11728d65214d24d20773ad827c5..98e698983eab5c1f7cb65c4b379595b78a1e2136 100644 (file)
@@ -1,14 +1,18 @@
 """Tests for proactor_events.py"""
 
+import io
 import socket
 import unittest
+import sys
 from unittest import mock
 
 import asyncio
+from asyncio import events
 from asyncio.proactor_events import BaseProactorEventLoop
 from asyncio.proactor_events import _ProactorSocketTransport
 from asyncio.proactor_events import _ProactorWritePipeTransport
 from asyncio.proactor_events import _ProactorDuplexPipeTransport
+from test import support
 from test.test_asyncio import utils as test_utils
 
 
@@ -775,5 +779,117 @@ class BaseProactorEventLoopTests(test_utils.TestCase):
         self.assertFalse(future2.cancel.called)
 
 
+@unittest.skipIf(sys.platform != 'win32',
+                 'Proactor is supported on Windows only')
+class ProactorEventLoopUnixSockSendfileTests(test_utils.TestCase):
+    DATA = b"12345abcde" * 16 * 1024  # 160 KiB
+
+    class MyProto(asyncio.Protocol):
+
+        def __init__(self, loop):
+            self.started = False
+            self.closed = False
+            self.data = bytearray()
+            self.fut = loop.create_future()
+            self.transport = None
+
+        def connection_made(self, transport):
+            self.started = True
+            self.transport = transport
+
+        def data_received(self, data):
+            self.data.extend(data)
+
+        def connection_lost(self, exc):
+            self.closed = True
+            self.fut.set_result(None)
+
+        async def wait_closed(self):
+            await self.fut
+
+    @classmethod
+    def setUpClass(cls):
+        with open(support.TESTFN, 'wb') as fp:
+            fp.write(cls.DATA)
+        super().setUpClass()
+
+    @classmethod
+    def tearDownClass(cls):
+        support.unlink(support.TESTFN)
+        super().tearDownClass()
+
+    def setUp(self):
+        self.loop = asyncio.ProactorEventLoop()
+        self.set_event_loop(self.loop)
+        self.addCleanup(self.loop.close)
+        self.file = open(support.TESTFN, 'rb')
+        self.addCleanup(self.file.close)
+        super().setUp()
+
+    def make_socket(self, cleanup=True):
+        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sock.setblocking(False)
+        sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
+        sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024)
+        if cleanup:
+            self.addCleanup(sock.close)
+        return sock
+
+    def run_loop(self, coro):
+        return self.loop.run_until_complete(coro)
+
+    def prepare(self):
+        sock = self.make_socket()
+        proto = self.MyProto(self.loop)
+        port = support.find_unused_port()
+        srv_sock = self.make_socket(cleanup=False)
+        srv_sock.bind(('127.0.0.1', port))
+        server = self.run_loop(self.loop.create_server(
+            lambda: proto, sock=srv_sock))
+        self.run_loop(self.loop.sock_connect(sock, srv_sock.getsockname()))
+
+        def cleanup():
+            if proto.transport is not None:
+                # can be None if the task was cancelled before
+                # connection_made callback
+                proto.transport.close()
+                self.run_loop(proto.wait_closed())
+
+            server.close()
+            self.run_loop(server.wait_closed())
+
+        self.addCleanup(cleanup)
+
+        return sock, proto
+
+    def test_sock_sendfile_not_a_file(self):
+        sock, proto = self.prepare()
+        f = object()
+        with self.assertRaisesRegex(events.SendfileNotAvailableError,
+                                    "not a regular file"):
+            self.run_loop(self.loop._sock_sendfile_native(sock, f,
+                                                          0, None))
+        self.assertEqual(self.file.tell(), 0)
+
+    def test_sock_sendfile_iobuffer(self):
+        sock, proto = self.prepare()
+        f = io.BytesIO()
+        with self.assertRaisesRegex(events.SendfileNotAvailableError,
+                                    "not a regular file"):
+            self.run_loop(self.loop._sock_sendfile_native(sock, f,
+                                                          0, None))
+        self.assertEqual(self.file.tell(), 0)
+
+    def test_sock_sendfile_not_regular_file(self):
+        sock, proto = self.prepare()
+        f = mock.Mock()
+        f.fileno.return_value = -1
+        with self.assertRaisesRegex(events.SendfileNotAvailableError,
+                                    "not a regular file"):
+            self.run_loop(self.loop._sock_sendfile_native(sock, f,
+                                                          0, None))
+        self.assertEqual(self.file.tell(), 0)
+
+
 if __name__ == '__main__':
     unittest.main()
index 5bd76d30d2db03971d6c14916ad8847d6c1d2c90..104f995937972e60672029919547aa77a8a22b35 100644 (file)
@@ -466,10 +466,13 @@ class SelectorEventLoopUnixSockSendfileTests(test_utils.TestCase):
         self.addCleanup(self.file.close)
         super().setUp()
 
-    def make_socket(self, blocking=False):
+    def make_socket(self, cleanup=True):
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-        sock.setblocking(blocking)
-        self.addCleanup(sock.close)
+        sock.setblocking(False)
+        sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
+        sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024)
+        if cleanup:
+            self.addCleanup(sock.close)
         return sock
 
     def run_loop(self, coro):
@@ -479,8 +482,10 @@ class SelectorEventLoopUnixSockSendfileTests(test_utils.TestCase):
         sock = self.make_socket()
         proto = self.MyProto(self.loop)
         port = support.find_unused_port()
+        srv_sock = self.make_socket(cleanup=False)
+        srv_sock.bind((support.HOST, port))
         server = self.run_loop(self.loop.create_server(
-            lambda: proto, support.HOST, port))
+            lambda: proto, sock=srv_sock))
         self.run_loop(self.loop.sock_connect(sock, (support.HOST, port)))
 
         def cleanup():
@@ -497,27 +502,6 @@ class SelectorEventLoopUnixSockSendfileTests(test_utils.TestCase):
 
         return sock, proto
 
-    def test_sock_sendfile_success(self):
-        sock, proto = self.prepare()
-        ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
-        sock.close()
-        self.run_loop(proto.wait_closed())
-
-        self.assertEqual(ret, len(self.DATA))
-        self.assertEqual(proto.data, self.DATA)
-        self.assertEqual(self.file.tell(), len(self.DATA))
-
-    def test_sock_sendfile_with_offset_and_count(self):
-        sock, proto = self.prepare()
-        ret = self.run_loop(self.loop.sock_sendfile(sock, self.file,
-                                                    1000, 2000))
-        sock.close()
-        self.run_loop(proto.wait_closed())
-
-        self.assertEqual(proto.data, self.DATA[1000:3000])
-        self.assertEqual(self.file.tell(), 3000)
-        self.assertEqual(ret, 2000)
-
     def test_sock_sendfile_not_available(self):
         sock, proto = self.prepare()
         with mock.patch('asyncio.unix_events.os', spec=[]):
@@ -555,36 +539,6 @@ class SelectorEventLoopUnixSockSendfileTests(test_utils.TestCase):
                                                           0, None))
         self.assertEqual(self.file.tell(), 0)
 
-    def test_sock_sendfile_zero_size(self):
-        sock, proto = self.prepare()
-        fname = support.TESTFN + '.suffix'
-        with open(fname, 'wb') as f:
-            pass  # make zero sized file
-        f = open(fname, 'rb')
-        self.addCleanup(f.close)
-        self.addCleanup(support.unlink, fname)
-        ret = self.run_loop(self.loop._sock_sendfile_native(sock, f,
-                                                            0, None))
-        sock.close()
-        self.run_loop(proto.wait_closed())
-
-        self.assertEqual(ret, 0)
-        self.assertEqual(self.file.tell(), 0)
-
-    def test_sock_sendfile_mix_with_regular_send(self):
-        buf = b'1234567890' * 1024 * 1024  # 10 MB
-        sock, proto = self.prepare()
-        self.run_loop(self.loop.sock_sendall(sock, buf))
-        ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
-        self.run_loop(self.loop.sock_sendall(sock, buf))
-        sock.close()
-        self.run_loop(proto.wait_closed())
-
-        self.assertEqual(ret, len(self.DATA))
-        expected = buf + self.DATA + buf
-        self.assertEqual(proto.data, expected)
-        self.assertEqual(self.file.tell(), len(self.DATA))
-
     def test_sock_sendfile_cancel1(self):
         sock, proto = self.prepare()
 
diff --git a/Misc/NEWS.d/next/Library/2018-02-06-17-58-15.bpo-32622.AE0Jz7.rst b/Misc/NEWS.d/next/Library/2018-02-06-17-58-15.bpo-32622.AE0Jz7.rst
new file mode 100644 (file)
index 0000000..456a6dc
--- /dev/null
@@ -0,0 +1 @@
+Implement native fast sendfile for Windows proactor event loop.
index 447a337fdd1fc3cf5d335f663c033333d63b6c97..ae7cddadd02df73648250fefe108a8bb3c3ec5c5 100644 (file)
@@ -39,7 +39,7 @@
 
 enum {TYPE_NONE, TYPE_NOT_STARTED, TYPE_READ, TYPE_READINTO, TYPE_WRITE,
       TYPE_ACCEPT, TYPE_CONNECT, TYPE_DISCONNECT, TYPE_CONNECT_NAMED_PIPE,
-      TYPE_WAIT_NAMED_PIPE_AND_CONNECT};
+      TYPE_WAIT_NAMED_PIPE_AND_CONNECT, TYPE_TRANSMIT_FILE};
 
 typedef struct {
     PyObject_HEAD
@@ -89,6 +89,7 @@ SetFromWindowsErr(DWORD err)
 static LPFN_ACCEPTEX Py_AcceptEx = NULL;
 static LPFN_CONNECTEX Py_ConnectEx = NULL;
 static LPFN_DISCONNECTEX Py_DisconnectEx = NULL;
+static LPFN_TRANSMITFILE Py_TransmitFile = NULL;
 static BOOL (CALLBACK *Py_CancelIoEx)(HANDLE, LPOVERLAPPED) = NULL;
 
 #define GET_WSA_POINTER(s, x)                                           \
@@ -102,6 +103,7 @@ initialize_function_pointers(void)
     GUID GuidAcceptEx = WSAID_ACCEPTEX;
     GUID GuidConnectEx = WSAID_CONNECTEX;
     GUID GuidDisconnectEx = WSAID_DISCONNECTEX;
+    GUID GuidTransmitFile = WSAID_TRANSMITFILE;
     HINSTANCE hKernel32;
     SOCKET s;
     DWORD dwBytes;
@@ -114,7 +116,8 @@ initialize_function_pointers(void)
 
     if (!GET_WSA_POINTER(s, AcceptEx) ||
         !GET_WSA_POINTER(s, ConnectEx) ||
-        !GET_WSA_POINTER(s, DisconnectEx))
+        !GET_WSA_POINTER(s, DisconnectEx) ||
+        !GET_WSA_POINTER(s, TransmitFile))
     {
         closesocket(s);
         SetFromWindowsErr(WSAGetLastError());
@@ -1194,6 +1197,61 @@ Overlapped_DisconnectEx(OverlappedObject *self, PyObject *args)
     }
 }
 
+PyDoc_STRVAR(
+    Overlapped_TransmitFile_doc,
+    "TransmitFile(socket, file, offset, offset_high, "
+    "count_to_write, count_per_send, flags) "
+    "-> Overlapped[None]\n\n"
+    "Transmit file data over a connected socket.");
+
+static PyObject *
+Overlapped_TransmitFile(OverlappedObject *self, PyObject *args)
+{
+    SOCKET Socket;
+    HANDLE File;
+    DWORD offset;
+    DWORD offset_high;
+    DWORD count_to_write;
+    DWORD count_per_send;
+    DWORD flags;
+    BOOL ret;
+    DWORD err;
+
+    if (!PyArg_ParseTuple(args,
+                          F_HANDLE F_HANDLE F_DWORD F_DWORD
+                          F_DWORD F_DWORD F_DWORD,
+                          &Socket, &File, &offset, &offset_high,
+                          &count_to_write, &count_per_send,
+                          &flags))
+        return NULL;
+
+    if (self->type != TYPE_NONE) {
+        PyErr_SetString(PyExc_ValueError, "operation already attempted");
+        return NULL;
+    }
+
+    self->type = TYPE_TRANSMIT_FILE;
+    self->handle = (HANDLE)Socket;
+    self->overlapped.Offset = offset;
+    self->overlapped.OffsetHigh = offset_high;
+
+    Py_BEGIN_ALLOW_THREADS
+    ret = Py_TransmitFile(Socket, File, count_to_write, count_per_send,
+                          &self->overlapped,
+                          NULL, flags);
+    Py_END_ALLOW_THREADS
+
+    self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError();
+    switch (err) {
+        case ERROR_SUCCESS:
+        case ERROR_IO_PENDING:
+            Py_RETURN_NONE;
+        default:
+            self->type = TYPE_NOT_STARTED;
+            return SetFromWindowsErr(err);
+    }
+}
+
 PyDoc_STRVAR(
     Overlapped_ConnectNamedPipe_doc,
     "ConnectNamedPipe(handle) -> Overlapped[None]\n\n"
@@ -1303,6 +1361,8 @@ static PyMethodDef Overlapped_methods[] = {
      METH_VARARGS, Overlapped_ConnectEx_doc},
     {"DisconnectEx", (PyCFunction) Overlapped_DisconnectEx,
      METH_VARARGS, Overlapped_DisconnectEx_doc},
+    {"TransmitFile", (PyCFunction) Overlapped_TransmitFile,
+     METH_VARARGS, Overlapped_TransmitFile_doc},
     {"ConnectNamedPipe", (PyCFunction) Overlapped_ConnectNamedPipe,
      METH_VARARGS, Overlapped_ConnectNamedPipe_doc},
     {NULL}