]> granicus.if.org Git - python/commitdiff
This is roughly socket2.diff from issue 1378, with a few changes applied
authorGuido van Rossum <guido@python.org>
Fri, 16 Nov 2007 01:24:05 +0000 (01:24 +0000)
committerGuido van Rossum <guido@python.org>
Fri, 16 Nov 2007 01:24:05 +0000 (01:24 +0000)
to ssl.py (no need to test whether we can dup any more).
Regular sockets no longer have a _base, but we still have explicit
reference counting of socket objects for the benefit of makefile();
using duplicate sockets won't work for SSLSocket.

Include/longobject.h
Lib/socket.py
Lib/ssl.py
Lib/test/test_socket.py
Modules/socketmodule.c

index 6bf3409611214d6e6db2402b5a9475e6106e8798..16abd0e90331cf19209ee8fee645efd87708a01d 100644 (file)
@@ -26,6 +26,15 @@ PyAPI_FUNC(size_t) PyLong_AsSize_t(PyObject *);
 PyAPI_FUNC(unsigned long) PyLong_AsUnsignedLong(PyObject *);
 PyAPI_FUNC(unsigned long) PyLong_AsUnsignedLongMask(PyObject *);
 
+/* Used by socketmodule.c */
+#if SIZEOF_SOCKET_T <= SIZEOF_LONG
+#define PyLong_FromSocket_t(fd) PyLong_FromLong((SOCKET_T)(fd))
+#define PyLong_AsSocket_t(fd) (SOCKET_T)PyLong_AsLong(fd)
+#else
+#define PyLong_FromSocket_t(fd) PyLong_FromLongLong(((SOCKET_T)(fd));
+#define PyLong_AsSocket_t(fd) (SOCKET_T)PyLong_AsLongLong(fd)
+#endif
+
 /* For use by intobject.c only */
 PyAPI_DATA(int) _PyLong_DigitValue[256];
 
index 6a9a381b7488e095f96e88bbb742db04c2cd704c..62eb82dcd1889a410d4ddefc6d1b04b2ce06fd4f 100644 (file)
@@ -79,28 +79,14 @@ if sys.platform.lower().startswith("win"):
     __all__.append("errorTab")
 
 
-# True if os.dup() can duplicate socket descriptors.
-# (On Windows at least, os.dup only works on files)
-_can_dup_socket = hasattr(_socket.socket, "dup")
-
-if _can_dup_socket:
-    def fromfd(fd, family=AF_INET, type=SOCK_STREAM, proto=0):
-        nfd = os.dup(fd)
-        return socket(family, type, proto, fileno=nfd)
-
 class socket(_socket.socket):
 
     """A subclass of _socket.socket adding the makefile() method."""
 
     __slots__ = ["__weakref__", "_io_refs", "_closed"]
-    if not _can_dup_socket:
-        __slots__.append("_base")
 
     def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None):
-        if fileno is None:
-            _socket.socket.__init__(self, family, type, proto)
-        else:
-            _socket.socket.__init__(self, family, type, proto, fileno)
+        _socket.socket.__init__(self, family, type, proto, fileno)
         self._io_refs = 0
         self._closed = False
 
@@ -114,23 +100,29 @@ class socket(_socket.socket):
                                 s[7:])
         return s
 
+    def dup(self):
+        """dup() -> socket object
+
+        Return a new socket object connected to the same system resource.
+        """
+        fd = dup(self.fileno())
+        sock = self.__class__(self.family, self.type, self.proto, fileno=fd)
+        sock.settimeout(self.gettimeout())
+        return sock
+
     def accept(self):
-        """Wrap accept() to give the connection the right type."""
-        conn, addr = _socket.socket.accept(self)
-        fd = conn.fileno()
-        nfd = fd
-        if _can_dup_socket:
-            nfd = os.dup(fd)
-        wrapper = socket(self.family, self.type, self.proto, fileno=nfd)
-        if fd == nfd:
-            wrapper._base = conn  # Keep the base alive
-        else:
-            conn.close()
-        return wrapper, addr
+        """accept() -> (socket object, address info)
+
+        Wait for an incoming connection.  Return a new socket
+        representing the connection, and the address of the client.
+        For IP sockets, the address info is a pair (hostaddr, port).
+        """
+        fd, addr = self._accept()
+        return socket(self.family, self.type, self.proto, fileno=fd), addr
 
     def makefile(self, mode="r", buffering=None, *,
                  encoding=None, newline=None):
-        """Return an I/O stream connected to the socket.
+        """makefile(...) -> an I/O stream connected to the socket
 
         The arguments are as for io.open() after the filename,
         except the only mode characters supported are 'r', 'w' and 'b'.
@@ -184,21 +176,18 @@ class socket(_socket.socket):
 
     def close(self):
         self._closed = True
-        if self._io_refs < 1:
-            self._real_close()
+        if self._io_refs <= 0:
+            _socket.socket.close(self)
 
-    # _real_close calls close on the _socket.socket base class.
 
-    if not _can_dup_socket:
-        def _real_close(self):
-            _socket.socket.close(self)
-            base = getattr(self, "_base", None)
-            if base is not None:
-                self._base = None
-                base.close()
-    else:
-        def _real_close(self):
-            _socket.socket.close(self)
+def fromfd(fd, family, type, proto=0):
+    """ fromfd(fd, family, type[, proto]) -> socket object
+
+    Create a socket object from a duplicate of the given file
+    descriptor.  The remaining arguments are the same as for socket().
+    """
+    nfd = dup(fd)
+    return socket(family, type, proto, nfd)
 
 
 class SocketIO(io.RawIOBase):
index c2cfa31c4441f7892f229792d14b10a61e8e8de8..9d63d12ce3babfbf2e40329d885fd089e804d300 100644 (file)
@@ -78,8 +78,8 @@ from _ssl import (
 from socket import socket, AF_INET, SOCK_STREAM, error
 from socket import getnameinfo as _getnameinfo
 from socket import error as socket_error
+from socket import dup as _dup
 import base64        # for DER-to-PEM translation
-_can_dup_socket = hasattr(socket, "dup")
 
 class SSLSocket(socket):
 
@@ -99,20 +99,11 @@ class SSLSocket(socket):
         if sock is not None:
             # copied this code from socket.accept()
             fd = sock.fileno()
-            nfd = fd
-            if _can_dup_socket:
-                nfd = os.dup(fd)
-            try:
-                socket.__init__(self, family=sock.family, type=sock.type,
-                                proto=sock.proto, fileno=nfd)
-            except:
-                if nfd != fd:
-                    os.close(nfd)
-            else:
-                if fd != nfd:
-                    sock.close()
-                    sock = None
-
+            nfd = _dup(fd)
+            socket.__init__(self, family=sock.family, type=sock.type,
+                            proto=sock.proto, fileno=nfd)
+            sock.close()
+            sock = None
         elif fileno is not None:
             socket.__init__(self, fileno=fileno)
         else:
index 82eb6e7e3c36192834c2e3e4aa04ace712bcb195..c01d998ae0b39d4991044f502be5c488b68e0b98 100644 (file)
@@ -575,6 +575,15 @@ class BasicTCPTest(SocketConnectedTest):
     def _testFromFd(self):
         self.serv_conn.send(MSG)
 
+    def testDup(self):
+        # Testing dup()
+        sock = self.cli_conn.dup()
+        msg = sock.recv(1024)
+        self.assertEqual(msg, MSG)
+
+    def _testDup(self):
+        self.serv_conn.send(MSG)
+
     def testShutdown(self):
         # Testing shutdown()
         msg = self.cli_conn.recv(1024)
index 30e8d2271211b671c1f40cb1e72ec934dd5f6dc6..f5ad29286522912a5d7117d0d5e6fae48517ee38 100644 (file)
@@ -89,12 +89,12 @@ A socket object represents one endpoint of a network connection.\n\
 \n\
 Methods of socket objects (keyword arguments not allowed):\n\
 \n\
-accept() -- accept a connection, returning new socket and client address\n\
+_accept() -- accept connection, returning new socket fd and client address\n\
 bind(addr) -- bind the socket to a local address\n\
 close() -- close the socket\n\
 connect(addr) -- connect the socket to a remote address\n\
 connect_ex(addr) -- connect, return an error code instead of an exception\n\
-dup() -- return a new socket object identical to the current one [*]\n\
+_dup() -- return a new socket fd duplicated from fileno()\n\
 fileno() -- return underlying file descriptor\n\
 getpeername() -- return remote address [*]\n\
 getsockname() -- return local address\n\
@@ -327,10 +327,26 @@ const char *inet_ntop(int af, const void *src, char *dst, socklen_t size);
 #include "getnameinfo.c"
 #endif
 
-#if defined(MS_WINDOWS)
-/* seem to be a few differences in the API */
+#ifdef MS_WINDOWS
+/* On Windows a socket is really a handle not an fd */
+static SOCKET
+dup_socket(SOCKET handle)
+{
+       HANDLE newhandle;
+
+       if (!DuplicateHandle(GetCurrentProcess(), (HANDLE)handle,
+                            GetCurrentProcess(), &newhandle,
+                            0, FALSE, DUPLICATE_SAME_ACCESS))
+       {
+               WSASetLastError(GetLastError());
+               return INVALID_SOCKET;
+       }
+       return (SOCKET)newhandle;
+}
 #define SOCKETCLOSE closesocket
-#define NO_DUP /* Actually it exists on NT 3.5, but what the heck... */
+#else
+/* On Unix we can use dup to duplicate the file descriptor of a socket*/
+#define dup_socket(fd) dup(fd)
 #endif
 
 #ifdef MS_WIN32
@@ -628,7 +644,7 @@ internal_select(PySocketSockObject *s, int writing)
                pollfd.events = writing ? POLLOUT : POLLIN;
 
                /* s->sock_timeout is in seconds, timeout in ms */
-               timeout = (int)(s->sock_timeout * 1000 + 0.5); 
+               timeout = (int)(s->sock_timeout * 1000 + 0.5);
                n = poll(&pollfd, 1, timeout);
        }
 #else
@@ -648,7 +664,7 @@ internal_select(PySocketSockObject *s, int writing)
                        n = select(s->sock_fd+1, &fds, NULL, NULL, &tv);
        }
 #endif
-       
+
        if (n < 0)
                return -1;
        if (n == 0)
@@ -1423,7 +1439,7 @@ getsockaddrlen(PySocketSockObject *s, socklen_t *len_ret)
 }
 
 
-/* s.accept() method */
+/* s._accept() -> (fd, address) */
 
 static PyObject *
 sock_accept(PySocketSockObject *s)
@@ -1457,17 +1473,12 @@ sock_accept(PySocketSockObject *s)
        if (newfd == INVALID_SOCKET)
                return s->errorhandler();
 
-       /* Create the new object with unspecified family,
-          to avoid calls to bind() etc. on it. */
-       sock = (PyObject *) new_sockobject(newfd,
-                                          s->sock_family,
-                                          s->sock_type,
-                                          s->sock_proto);
-
+       sock = PyLong_FromSocket_t(newfd);
        if (sock == NULL) {
                SOCKETCLOSE(newfd);
                goto finally;
        }
+
        addr = makesockaddr(s->sock_fd, SAS2SA(&addrbuf),
                            addrlen, s->sock_proto);
        if (addr == NULL)
@@ -1482,11 +1493,11 @@ finally:
 }
 
 PyDoc_STRVAR(accept_doc,
-"accept() -> (socket object, address info)\n\
+"_accept() -> (integer, address info)\n\
 \n\
-Wait for an incoming connection.  Return a new socket representing the\n\
-connection, and the address of the client.  For IP sockets, the address\n\
-info is a pair (hostaddr, port).");
+Wait for an incoming connection.  Return a new socket file descriptor\n\
+representing the connection, and the address of the client.\n\
+For IP sockets, the address info is a pair (hostaddr, port).");
 
 /* s.setblocking(flag) method.  Argument:
    False -- non-blocking mode; same as settimeout(0)
@@ -1882,11 +1893,7 @@ instead of raising an exception when an error occurs.");
 static PyObject *
 sock_fileno(PySocketSockObject *s)
 {
-#if SIZEOF_SOCKET_T <= SIZEOF_LONG
-       return PyInt_FromLong((long) s->sock_fd);
-#else
-       return PyLong_FromLongLong((PY_LONG_LONG)s->sock_fd);
-#endif
+       return PyLong_FromSocket_t(s->sock_fd);
 }
 
 PyDoc_STRVAR(fileno_doc,
@@ -1895,35 +1902,6 @@ PyDoc_STRVAR(fileno_doc,
 Return the integer file descriptor of the socket.");
 
 
-#ifndef NO_DUP
-/* s.dup() method */
-
-static PyObject *
-sock_dup(PySocketSockObject *s)
-{
-       SOCKET_T newfd;
-       PyObject *sock;
-
-       newfd = dup(s->sock_fd);
-       if (newfd < 0)
-               return s->errorhandler();
-       sock = (PyObject *) new_sockobject(newfd,
-                                          s->sock_family,
-                                          s->sock_type,
-                                          s->sock_proto);
-       if (sock == NULL)
-               SOCKETCLOSE(newfd);
-       return sock;
-}
-
-PyDoc_STRVAR(dup_doc,
-"dup() -> socket object\n\
-\n\
-Return a new socket object connected to the same system resource.");
-
-#endif
-
-
 /* s.getsockname() method */
 
 static PyObject *
@@ -2542,7 +2520,7 @@ of the socket (flag == SHUT_WR), or both ends (flag == SHUT_RDWR).");
 /* List of methods for socket objects */
 
 static PyMethodDef sock_methods[] = {
-       {"accept",        (PyCFunction)sock_accept, METH_NOARGS,
+       {"_accept",       (PyCFunction)sock_accept, METH_NOARGS,
                          accept_doc},
        {"bind",          (PyCFunction)sock_bind, METH_O,
                          bind_doc},
@@ -2552,10 +2530,6 @@ static PyMethodDef sock_methods[] = {
                          connect_doc},
        {"connect_ex",    (PyCFunction)sock_connect_ex, METH_O,
                          connect_ex_doc},
-#ifndef NO_DUP
-       {"dup",           (PyCFunction)sock_dup, METH_NOARGS,
-                         dup_doc},
-#endif
        {"fileno",        (PyCFunction)sock_fileno, METH_NOARGS,
                          fileno_doc},
 #ifdef HAVE_GETPEERNAME
@@ -2672,8 +2646,8 @@ sock_initobj(PyObject *self, PyObject *args, PyObject *kwds)
                                         &family, &type, &proto, &fdobj))
                return -1;
 
-       if (fdobj != NULL) {
-               fd = PyLong_AsLongLong(fdobj);
+       if (fdobj != NULL && fdobj != Py_None) {
+               fd = PyLong_AsSocket_t(fdobj);
                if (fd == (SOCKET_T)(-1) && PyErr_Occurred())
                        return -1;
                if (fd == INVALID_SOCKET) {
@@ -3172,6 +3146,38 @@ PyDoc_STRVAR(getprotobyname_doc,
 Return the protocol number for the named protocol.  (Rarely used.)");
 
 
+#ifndef NO_DUP
+/* dup() function for socket fds */
+
+static PyObject *
+socket_dup(PyObject *self, PyObject *fdobj)
+{
+       SOCKET_T fd, newfd;
+       PyObject *newfdobj;
+
+
+       fd = PyLong_AsSocket_t(fdobj);
+       if (fd == (SOCKET_T)(-1) && PyErr_Occurred())
+               return NULL;
+
+       newfd = dup_socket(fd);
+       if (newfd == INVALID_SOCKET)
+               return set_error();
+
+       newfdobj = PyLong_FromSocket_t(newfd);
+       if (newfdobj == NULL)
+               SOCKETCLOSE(newfd);
+       return newfdobj;
+}
+
+PyDoc_STRVAR(dup_doc,
+"dup(integer) -> integer\n\
+\n\
+Duplicate an integer socket file descriptor.  This is like os.dup(), but for\n\
+sockets; on some platforms os.dup() won't work for socket file descriptors.");
+#endif
+
+
 #ifdef HAVE_SOCKETPAIR
 /* Create a pair of sockets using the socketpair() function.
    Arguments as for socket() except the default family is AF_UNIX if
@@ -3811,6 +3817,10 @@ static PyMethodDef socket_methods[] = {
         METH_VARARGS, getservbyport_doc},
        {"getprotobyname",      socket_getprotobyname,
         METH_VARARGS, getprotobyname_doc},
+#ifndef NO_DUP
+       {"dup",                 socket_dup,
+         METH_O, dup_doc},
+#endif
 #ifdef HAVE_SOCKETPAIR
        {"socketpair",          socket_socketpair,
         METH_VARARGS, socketpair_doc},
@@ -4105,7 +4115,7 @@ init_socket(void)
        PyModule_AddIntConstant(m, "NETLINK_IP6_FW", NETLINK_IP6_FW);
 #ifdef NETLINK_DNRTMSG
        PyModule_AddIntConstant(m, "NETLINK_DNRTMSG", NETLINK_DNRTMSG);
-#endif 
+#endif
 #ifdef NETLINK_TAPBASE
        PyModule_AddIntConstant(m, "NETLINK_TAPBASE", NETLINK_TAPBASE);
 #endif