__all__ = 'BaseProactorEventLoop',
+import io
+import os
import socket
import warnings
from . import base_events
from . import constants
+from . import events
from . import futures
from . import protocols
from . import sslproto
self._force_close(exc)
def _force_close(self, exc):
+ if self._empty_waiter is not None:
+ if exc is None:
+ self._empty_waiter.set_result(None)
+ else:
+ self._empty_waiter.set_exception(exc)
if self._closing:
return
self._closing = True
_start_tls_compatible = True
+ def __init__(self, *args, **kw):
+ super().__init__(*args, **kw)
+ self._empty_waiter = None
+
def write(self, data):
if not isinstance(data, (bytes, bytearray, memoryview)):
raise TypeError(
f"not {type(data).__name__}")
if self._eof_written:
raise RuntimeError('write_eof() already called')
+ if self._empty_waiter is not None:
+ raise RuntimeError('unable to write; sendfile is in progress')
if not data:
return
self._maybe_pause_protocol()
else:
self._write_fut.add_done_callback(self._loop_writing)
+ if self._empty_waiter is not None and self._write_fut is None:
+ self._empty_waiter.set_result(None)
except ConnectionResetError as exc:
self._force_close(exc)
except OSError as exc:
def abort(self):
self._force_close(None)
+ def _make_empty_waiter(self):
+ if self._empty_waiter is not None:
+ raise RuntimeError("Empty waiter is already set")
+ self._empty_waiter = self._loop.create_future()
+ if self._write_fut is None:
+ self._empty_waiter.set_result(None)
+ return self._empty_waiter
+
+ def _reset_empty_waiter(self):
+ self._empty_waiter = None
+
class _ProactorWritePipeTransport(_ProactorBaseWritePipeTransport):
def __init__(self, *args, **kw):
transports.Transport):
"""Transport for connected sockets."""
- _sendfile_compatible = constants._SendfileMode.FALLBACK
+ _sendfile_compatible = constants._SendfileMode.TRY_NATIVE
def _set_extra(self, sock):
self._extra['socket'] = sock
async def sock_accept(self, sock):
return await self._proactor.accept(sock)
+ async def _sock_sendfile_native(self, sock, file, offset, count):
+ try:
+ fileno = file.fileno()
+ except (AttributeError, io.UnsupportedOperation) as err:
+ raise events.SendfileNotAvailableError("not a regular file")
+ try:
+ fsize = os.fstat(fileno).st_size
+ except OSError as err:
+ raise events.SendfileNotAvailableError("not a regular file")
+ blocksize = count if count else fsize
+ if not blocksize:
+ return 0 # empty file
+
+ blocksize = min(blocksize, 0xffff_ffff)
+ end_pos = min(offset + count, fsize) if count else fsize
+ offset = min(offset, fsize)
+ total_sent = 0
+ try:
+ while True:
+ blocksize = min(end_pos - offset, blocksize)
+ if blocksize <= 0:
+ return total_sent
+ await self._proactor.sendfile(sock, file, offset, blocksize)
+ offset += blocksize
+ total_sent += blocksize
+ finally:
+ if total_sent > 0:
+ file.seek(offset)
+
+ async def _sendfile_native(self, transp, file, offset, count):
+ resume_reading = transp.is_reading()
+ transp.pause_reading()
+ await transp._make_empty_waiter()
+ try:
+ return await self.sock_sendfile(transp._sock, file, offset, count,
+ fallback=False)
+ finally:
+ transp._reset_empty_waiter()
+ if resume_reading:
+ transp.resume_reading()
+
def _close_self_pipe(self):
if self._self_reading_future is not None:
self._self_reading_future.cancel()
import _winapi
import errno
import math
+import msvcrt
import socket
import struct
import weakref
return self._register(ov, conn, finish_connect)
+ def sendfile(self, sock, file, offset, count):
+ self._register_with_iocp(sock)
+ ov = _overlapped.Overlapped(NULL)
+ offset_low = offset & 0xffff_ffff
+ offset_high = (offset >> 32) & 0xffff_ffff
+ ov.TransmitFile(sock.fileno(),
+ msvcrt.get_osfhandle(file.fileno()),
+ offset_low, offset_high,
+ count, 0, 0)
+
+ def finish_sendfile(trans, key, ov):
+ try:
+ return ov.getresult()
+ except OSError as exc:
+ if exc.winerror in (_overlapped.ERROR_NETNAME_DELETED,
+ _overlapped.ERROR_OPERATION_ABORTED):
+ raise ConnectionResetError(*exc.args)
+ else:
+ raise
+ return self._register(ov, sock, finish_sendfile)
+
def accept_pipe(self, pipe):
self._register_with_iocp(pipe)
ov = _overlapped.Overlapped(NULL)
ssl = None
import subprocess
import sys
+import tempfile
import threading
import time
import errno
self.loop.run_until_complete(connect(shell=False))
-class MySendfileProto(MyBaseProto):
-
- def __init__(self, loop=None, close_after=0):
- super().__init__(loop)
- self.data = bytearray()
- self.close_after = close_after
-
- def data_received(self, data):
- self.data.extend(data)
- super().data_received(data)
- if self.close_after and self.nbytes >= self.close_after:
- self.transport.close()
-
-
-class SendfileMixin:
- # Note: sendfile via SSL transport is equal to sendfile fallback
+class SendfileBase:
DATA = b"12345abcde" * 160 * 1024 # 160 KiB
def run_loop(self, coro):
return self.loop.run_until_complete(coro)
- def prepare(self, *, is_ssl=False, close_after=0):
+
+class SockSendfileMixin(SendfileBase):
+
+ class MyProto(asyncio.Protocol):
+
+ def __init__(self, loop):
+ self.started = False
+ self.closed = False
+ self.data = bytearray()
+ self.fut = loop.create_future()
+ self.transport = None
+
+ def connection_made(self, transport):
+ self.started = True
+ self.transport = transport
+
+ def data_received(self, data):
+ self.data.extend(data)
+
+ def connection_lost(self, exc):
+ self.closed = True
+ self.fut.set_result(None)
+
+ async def wait_closed(self):
+ await self.fut
+
+ def make_socket(self, cleanup=True):
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.setblocking(False)
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024)
+ if cleanup:
+ self.addCleanup(sock.close)
+ return sock
+
+ def prepare_socksendfile(self):
+ sock = self.make_socket()
+ proto = self.MyProto(self.loop)
+ port = support.find_unused_port()
+ srv_sock = self.make_socket(cleanup=False)
+ srv_sock.bind((support.HOST, port))
+ server = self.run_loop(self.loop.create_server(
+ lambda: proto, sock=srv_sock))
+ self.run_loop(self.loop.sock_connect(sock, ('127.0.0.1', port)))
+
+ def cleanup():
+ if proto.transport is not None:
+ # can be None if the task was cancelled before
+ # connection_made callback
+ proto.transport.close()
+ self.run_loop(proto.wait_closed())
+
+ server.close()
+ self.run_loop(server.wait_closed())
+
+ self.addCleanup(cleanup)
+
+ return sock, proto
+
+ def test_sock_sendfile_success(self):
+ sock, proto = self.prepare_socksendfile()
+ ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
+ sock.close()
+ self.run_loop(proto.wait_closed())
+
+ self.assertEqual(ret, len(self.DATA))
+ self.assertEqual(proto.data, self.DATA)
+ self.assertEqual(self.file.tell(), len(self.DATA))
+
+ def test_sock_sendfile_with_offset_and_count(self):
+ sock, proto = self.prepare_socksendfile()
+ ret = self.run_loop(self.loop.sock_sendfile(sock, self.file,
+ 1000, 2000))
+ sock.close()
+ self.run_loop(proto.wait_closed())
+
+ self.assertEqual(proto.data, self.DATA[1000:3000])
+ self.assertEqual(self.file.tell(), 3000)
+ self.assertEqual(ret, 2000)
+
+ def test_sock_sendfile_zero_size(self):
+ sock, proto = self.prepare_socksendfile()
+ with tempfile.TemporaryFile() as f:
+ ret = self.run_loop(self.loop.sock_sendfile(sock, f,
+ 0, None))
+ sock.close()
+ self.run_loop(proto.wait_closed())
+
+ self.assertEqual(ret, 0)
+ self.assertEqual(self.file.tell(), 0)
+
+ def test_sock_sendfile_mix_with_regular_send(self):
+ buf = b'1234567890' * 1024 * 1024 # 10 MB
+ sock, proto = self.prepare_socksendfile()
+ self.run_loop(self.loop.sock_sendall(sock, buf))
+ ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
+ self.run_loop(self.loop.sock_sendall(sock, buf))
+ sock.close()
+ self.run_loop(proto.wait_closed())
+
+ self.assertEqual(ret, len(self.DATA))
+ expected = buf + self.DATA + buf
+ self.assertEqual(proto.data, expected)
+ self.assertEqual(self.file.tell(), len(self.DATA))
+
+
+class SendfileMixin(SendfileBase):
+
+ class MySendfileProto(MyBaseProto):
+
+ def __init__(self, loop=None, close_after=0):
+ super().__init__(loop)
+ self.data = bytearray()
+ self.close_after = close_after
+
+ def data_received(self, data):
+ self.data.extend(data)
+ super().data_received(data)
+ if self.close_after and self.nbytes >= self.close_after:
+ self.transport.close()
+
+
+ # Note: sendfile via SSL transport is equal to sendfile fallback
+
+ def prepare_sendfile(self, *, is_ssl=False, close_after=0):
port = support.find_unused_port()
- srv_proto = MySendfileProto(loop=self.loop, close_after=close_after)
+ srv_proto = self.MySendfileProto(loop=self.loop,
+ close_after=close_after)
if is_ssl:
if not ssl:
self.skipTest("No ssl module")
# reduce send socket buffer size to test on relative small data sets
cli_sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
cli_sock.connect((support.HOST, port))
- cli_proto = MySendfileProto(loop=self.loop)
+ cli_proto = self.MySendfileProto(loop=self.loop)
tr, pr = self.run_loop(self.loop.create_connection(
lambda: cli_proto, sock=cli_sock,
ssl=cli_ctx, server_hostname=server_hostname))
tr.close()
def test_sendfile(self):
- srv_proto, cli_proto = self.prepare()
+ srv_proto, cli_proto = self.prepare_sendfile()
ret = self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file))
cli_proto.transport.close()
self.assertEqual(self.file.tell(), len(self.DATA))
def test_sendfile_force_fallback(self):
- srv_proto, cli_proto = self.prepare()
+ srv_proto, cli_proto = self.prepare_sendfile()
def sendfile_native(transp, file, offset, count):
# to raise SendfileNotAvailableError
if sys.platform == 'win32':
if isinstance(self.loop, asyncio.ProactorEventLoop):
self.skipTest("Fails on proactor event loop")
- srv_proto, cli_proto = self.prepare()
+ srv_proto, cli_proto = self.prepare_sendfile()
def sendfile_native(transp, file, offset, count):
# to raise SendfileNotAvailableError
self.assertEqual(self.file.tell(), 0)
def test_sendfile_ssl(self):
- srv_proto, cli_proto = self.prepare(is_ssl=True)
+ srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
ret = self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file))
cli_proto.transport.close()
self.assertEqual(self.file.tell(), len(self.DATA))
def test_sendfile_for_closing_transp(self):
- srv_proto, cli_proto = self.prepare()
+ srv_proto, cli_proto = self.prepare_sendfile()
cli_proto.transport.close()
with self.assertRaisesRegex(RuntimeError, "is closing"):
self.run_loop(self.loop.sendfile(cli_proto.transport, self.file))
self.assertEqual(self.file.tell(), 0)
def test_sendfile_pre_and_post_data(self):
- srv_proto, cli_proto = self.prepare()
+ srv_proto, cli_proto = self.prepare_sendfile()
PREFIX = b'zxcvbnm' * 1024
SUFFIX = b'0987654321' * 1024
cli_proto.transport.write(PREFIX)
self.assertEqual(self.file.tell(), len(self.DATA))
def test_sendfile_ssl_pre_and_post_data(self):
- srv_proto, cli_proto = self.prepare(is_ssl=True)
+ srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
PREFIX = b'zxcvbnm' * 1024
SUFFIX = b'0987654321' * 1024
cli_proto.transport.write(PREFIX)
self.assertEqual(self.file.tell(), len(self.DATA))
def test_sendfile_partial(self):
- srv_proto, cli_proto = self.prepare()
+ srv_proto, cli_proto = self.prepare_sendfile()
ret = self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
cli_proto.transport.close()
self.assertEqual(self.file.tell(), 1100)
def test_sendfile_ssl_partial(self):
- srv_proto, cli_proto = self.prepare(is_ssl=True)
+ srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
ret = self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
cli_proto.transport.close()
self.assertEqual(self.file.tell(), 1100)
def test_sendfile_close_peer_after_receiving(self):
- srv_proto, cli_proto = self.prepare(close_after=len(self.DATA))
+ srv_proto, cli_proto = self.prepare_sendfile(
+ close_after=len(self.DATA))
ret = self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file))
cli_proto.transport.close()
self.assertEqual(self.file.tell(), len(self.DATA))
def test_sendfile_ssl_close_peer_after_receiving(self):
- srv_proto, cli_proto = self.prepare(is_ssl=True,
- close_after=len(self.DATA))
+ srv_proto, cli_proto = self.prepare_sendfile(
+ is_ssl=True, close_after=len(self.DATA))
ret = self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file))
self.run_loop(srv_proto.done)
self.assertEqual(self.file.tell(), len(self.DATA))
def test_sendfile_close_peer_in_middle_of_receiving(self):
- srv_proto, cli_proto = self.prepare(close_after=1024)
+ srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
with self.assertRaises(ConnectionError):
self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file))
srv_proto.nbytes)
self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
self.file.tell())
+ self.assertTrue(cli_proto.transport.is_closing())
def test_sendfile_fallback_close_peer_in_middle_of_receiving(self):
self.loop._sendfile_native = sendfile_native
- srv_proto, cli_proto = self.prepare(close_after=1024)
+ srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
with self.assertRaises(ConnectionError):
self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file))
@unittest.skipIf(not hasattr(os, 'sendfile'),
"Don't have native sendfile support")
def test_sendfile_prevents_bare_write(self):
- srv_proto, cli_proto = self.prepare()
+ srv_proto, cli_proto = self.prepare_sendfile()
fut = self.loop.create_future()
async def coro():
class SelectEventLoopTests(EventLoopTestsMixin,
SendfileMixin,
+ SockSendfileMixin,
test_utils.TestCase):
def create_event_loop(self):
class ProactorEventLoopTests(EventLoopTestsMixin,
SendfileMixin,
+ SockSendfileMixin,
SubprocessTestsMixin,
test_utils.TestCase):
else:
import selectors
- class UnixEventLoopTestsMixin(EventLoopTestsMixin, SendfileMixin):
+ class UnixEventLoopTestsMixin(EventLoopTestsMixin,
+ SendfileMixin,
+ SockSendfileMixin):
def setUp(self):
super().setUp()
watcher = asyncio.SafeChildWatcher()
"""Tests for proactor_events.py"""
+import io
import socket
import unittest
+import sys
from unittest import mock
import asyncio
+from asyncio import events
from asyncio.proactor_events import BaseProactorEventLoop
from asyncio.proactor_events import _ProactorSocketTransport
from asyncio.proactor_events import _ProactorWritePipeTransport
from asyncio.proactor_events import _ProactorDuplexPipeTransport
+from test import support
from test.test_asyncio import utils as test_utils
self.assertFalse(future2.cancel.called)
+@unittest.skipIf(sys.platform != 'win32',
+ 'Proactor is supported on Windows only')
+class ProactorEventLoopUnixSockSendfileTests(test_utils.TestCase):
+ DATA = b"12345abcde" * 16 * 1024 # 160 KiB
+
+ class MyProto(asyncio.Protocol):
+
+ def __init__(self, loop):
+ self.started = False
+ self.closed = False
+ self.data = bytearray()
+ self.fut = loop.create_future()
+ self.transport = None
+
+ def connection_made(self, transport):
+ self.started = True
+ self.transport = transport
+
+ def data_received(self, data):
+ self.data.extend(data)
+
+ def connection_lost(self, exc):
+ self.closed = True
+ self.fut.set_result(None)
+
+ async def wait_closed(self):
+ await self.fut
+
+ @classmethod
+ def setUpClass(cls):
+ with open(support.TESTFN, 'wb') as fp:
+ fp.write(cls.DATA)
+ super().setUpClass()
+
+ @classmethod
+ def tearDownClass(cls):
+ support.unlink(support.TESTFN)
+ super().tearDownClass()
+
+ def setUp(self):
+ self.loop = asyncio.ProactorEventLoop()
+ self.set_event_loop(self.loop)
+ self.addCleanup(self.loop.close)
+ self.file = open(support.TESTFN, 'rb')
+ self.addCleanup(self.file.close)
+ super().setUp()
+
+ def make_socket(self, cleanup=True):
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.setblocking(False)
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024)
+ if cleanup:
+ self.addCleanup(sock.close)
+ return sock
+
+ def run_loop(self, coro):
+ return self.loop.run_until_complete(coro)
+
+ def prepare(self):
+ sock = self.make_socket()
+ proto = self.MyProto(self.loop)
+ port = support.find_unused_port()
+ srv_sock = self.make_socket(cleanup=False)
+ srv_sock.bind(('127.0.0.1', port))
+ server = self.run_loop(self.loop.create_server(
+ lambda: proto, sock=srv_sock))
+ self.run_loop(self.loop.sock_connect(sock, srv_sock.getsockname()))
+
+ def cleanup():
+ if proto.transport is not None:
+ # can be None if the task was cancelled before
+ # connection_made callback
+ proto.transport.close()
+ self.run_loop(proto.wait_closed())
+
+ server.close()
+ self.run_loop(server.wait_closed())
+
+ self.addCleanup(cleanup)
+
+ return sock, proto
+
+ def test_sock_sendfile_not_a_file(self):
+ sock, proto = self.prepare()
+ f = object()
+ with self.assertRaisesRegex(events.SendfileNotAvailableError,
+ "not a regular file"):
+ self.run_loop(self.loop._sock_sendfile_native(sock, f,
+ 0, None))
+ self.assertEqual(self.file.tell(), 0)
+
+ def test_sock_sendfile_iobuffer(self):
+ sock, proto = self.prepare()
+ f = io.BytesIO()
+ with self.assertRaisesRegex(events.SendfileNotAvailableError,
+ "not a regular file"):
+ self.run_loop(self.loop._sock_sendfile_native(sock, f,
+ 0, None))
+ self.assertEqual(self.file.tell(), 0)
+
+ def test_sock_sendfile_not_regular_file(self):
+ sock, proto = self.prepare()
+ f = mock.Mock()
+ f.fileno.return_value = -1
+ with self.assertRaisesRegex(events.SendfileNotAvailableError,
+ "not a regular file"):
+ self.run_loop(self.loop._sock_sendfile_native(sock, f,
+ 0, None))
+ self.assertEqual(self.file.tell(), 0)
+
+
if __name__ == '__main__':
unittest.main()
self.addCleanup(self.file.close)
super().setUp()
- def make_socket(self, blocking=False):
+ def make_socket(self, cleanup=True):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- sock.setblocking(blocking)
- self.addCleanup(sock.close)
+ sock.setblocking(False)
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024)
+ if cleanup:
+ self.addCleanup(sock.close)
return sock
def run_loop(self, coro):
sock = self.make_socket()
proto = self.MyProto(self.loop)
port = support.find_unused_port()
+ srv_sock = self.make_socket(cleanup=False)
+ srv_sock.bind((support.HOST, port))
server = self.run_loop(self.loop.create_server(
- lambda: proto, support.HOST, port))
+ lambda: proto, sock=srv_sock))
self.run_loop(self.loop.sock_connect(sock, (support.HOST, port)))
def cleanup():
return sock, proto
- def test_sock_sendfile_success(self):
- sock, proto = self.prepare()
- ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
- sock.close()
- self.run_loop(proto.wait_closed())
-
- self.assertEqual(ret, len(self.DATA))
- self.assertEqual(proto.data, self.DATA)
- self.assertEqual(self.file.tell(), len(self.DATA))
-
- def test_sock_sendfile_with_offset_and_count(self):
- sock, proto = self.prepare()
- ret = self.run_loop(self.loop.sock_sendfile(sock, self.file,
- 1000, 2000))
- sock.close()
- self.run_loop(proto.wait_closed())
-
- self.assertEqual(proto.data, self.DATA[1000:3000])
- self.assertEqual(self.file.tell(), 3000)
- self.assertEqual(ret, 2000)
-
def test_sock_sendfile_not_available(self):
sock, proto = self.prepare()
with mock.patch('asyncio.unix_events.os', spec=[]):
0, None))
self.assertEqual(self.file.tell(), 0)
- def test_sock_sendfile_zero_size(self):
- sock, proto = self.prepare()
- fname = support.TESTFN + '.suffix'
- with open(fname, 'wb') as f:
- pass # make zero sized file
- f = open(fname, 'rb')
- self.addCleanup(f.close)
- self.addCleanup(support.unlink, fname)
- ret = self.run_loop(self.loop._sock_sendfile_native(sock, f,
- 0, None))
- sock.close()
- self.run_loop(proto.wait_closed())
-
- self.assertEqual(ret, 0)
- self.assertEqual(self.file.tell(), 0)
-
- def test_sock_sendfile_mix_with_regular_send(self):
- buf = b'1234567890' * 1024 * 1024 # 10 MB
- sock, proto = self.prepare()
- self.run_loop(self.loop.sock_sendall(sock, buf))
- ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
- self.run_loop(self.loop.sock_sendall(sock, buf))
- sock.close()
- self.run_loop(proto.wait_closed())
-
- self.assertEqual(ret, len(self.DATA))
- expected = buf + self.DATA + buf
- self.assertEqual(proto.data, expected)
- self.assertEqual(self.file.tell(), len(self.DATA))
-
def test_sock_sendfile_cancel1(self):
sock, proto = self.prepare()
--- /dev/null
+Implement native fast sendfile for Windows proactor event loop.
enum {TYPE_NONE, TYPE_NOT_STARTED, TYPE_READ, TYPE_READINTO, TYPE_WRITE,
TYPE_ACCEPT, TYPE_CONNECT, TYPE_DISCONNECT, TYPE_CONNECT_NAMED_PIPE,
- TYPE_WAIT_NAMED_PIPE_AND_CONNECT};
+ TYPE_WAIT_NAMED_PIPE_AND_CONNECT, TYPE_TRANSMIT_FILE};
typedef struct {
PyObject_HEAD
static LPFN_ACCEPTEX Py_AcceptEx = NULL;
static LPFN_CONNECTEX Py_ConnectEx = NULL;
static LPFN_DISCONNECTEX Py_DisconnectEx = NULL;
+static LPFN_TRANSMITFILE Py_TransmitFile = NULL;
static BOOL (CALLBACK *Py_CancelIoEx)(HANDLE, LPOVERLAPPED) = NULL;
#define GET_WSA_POINTER(s, x) \
GUID GuidAcceptEx = WSAID_ACCEPTEX;
GUID GuidConnectEx = WSAID_CONNECTEX;
GUID GuidDisconnectEx = WSAID_DISCONNECTEX;
+ GUID GuidTransmitFile = WSAID_TRANSMITFILE;
HINSTANCE hKernel32;
SOCKET s;
DWORD dwBytes;
if (!GET_WSA_POINTER(s, AcceptEx) ||
!GET_WSA_POINTER(s, ConnectEx) ||
- !GET_WSA_POINTER(s, DisconnectEx))
+ !GET_WSA_POINTER(s, DisconnectEx) ||
+ !GET_WSA_POINTER(s, TransmitFile))
{
closesocket(s);
SetFromWindowsErr(WSAGetLastError());
}
}
+PyDoc_STRVAR(
+ Overlapped_TransmitFile_doc,
+ "TransmitFile(socket, file, offset, offset_high, "
+ "count_to_write, count_per_send, flags) "
+ "-> Overlapped[None]\n\n"
+ "Transmit file data over a connected socket.");
+
+static PyObject *
+Overlapped_TransmitFile(OverlappedObject *self, PyObject *args)
+{
+ SOCKET Socket;
+ HANDLE File;
+ DWORD offset;
+ DWORD offset_high;
+ DWORD count_to_write;
+ DWORD count_per_send;
+ DWORD flags;
+ BOOL ret;
+ DWORD err;
+
+ if (!PyArg_ParseTuple(args,
+ F_HANDLE F_HANDLE F_DWORD F_DWORD
+ F_DWORD F_DWORD F_DWORD,
+ &Socket, &File, &offset, &offset_high,
+ &count_to_write, &count_per_send,
+ &flags))
+ return NULL;
+
+ if (self->type != TYPE_NONE) {
+ PyErr_SetString(PyExc_ValueError, "operation already attempted");
+ return NULL;
+ }
+
+ self->type = TYPE_TRANSMIT_FILE;
+ self->handle = (HANDLE)Socket;
+ self->overlapped.Offset = offset;
+ self->overlapped.OffsetHigh = offset_high;
+
+ Py_BEGIN_ALLOW_THREADS
+ ret = Py_TransmitFile(Socket, File, count_to_write, count_per_send,
+ &self->overlapped,
+ NULL, flags);
+ Py_END_ALLOW_THREADS
+
+ self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError();
+ switch (err) {
+ case ERROR_SUCCESS:
+ case ERROR_IO_PENDING:
+ Py_RETURN_NONE;
+ default:
+ self->type = TYPE_NOT_STARTED;
+ return SetFromWindowsErr(err);
+ }
+}
+
PyDoc_STRVAR(
Overlapped_ConnectNamedPipe_doc,
"ConnectNamedPipe(handle) -> Overlapped[None]\n\n"
METH_VARARGS, Overlapped_ConnectEx_doc},
{"DisconnectEx", (PyCFunction) Overlapped_DisconnectEx,
METH_VARARGS, Overlapped_DisconnectEx_doc},
+ {"TransmitFile", (PyCFunction) Overlapped_TransmitFile,
+ METH_VARARGS, Overlapped_TransmitFile_doc},
{"ConnectNamedPipe", (PyCFunction) Overlapped_ConnectNamedPipe,
METH_VARARGS, Overlapped_ConnectNamedPipe_doc},
{NULL}