]> granicus.if.org Git - python/commitdiff
Patch #1003700: Add socketpair function to socket module.
authorDave Cole <djc@object-craft.com.au>
Mon, 9 Aug 2004 04:51:41 +0000 (04:51 +0000)
committerDave Cole <djc@object-craft.com.au>
Mon, 9 Aug 2004 04:51:41 +0000 (04:51 +0000)
Doc/lib/libsocket.tex
Lib/socket.py
Lib/test/test_socket.py
Misc/NEWS
Modules/socketmodule.c
configure.in
pyconfig.h.in

index f78ad2fc2101959c2e81b55511431b64c8a5d119..06d264550669c3d0ec5bf38b55483086359ebab5 100644 (file)
@@ -303,6 +303,14 @@ success, a new \class{SSLObject} is returned.
 \warning{This does not do any certificate verification!}
 \end{funcdesc}
 
+\begin{funcdesc}{socketpair}{\optional{family\optional{, type\optional{, proto}}}}
+Build a pair of connected socket objects using the given address
+family, socket type and protocol number.  Address family, socket type
+and protocol number are as for the \function{socket()} function above.
+Availability: \UNIX.
+\versionadded{2.3}
+\end{funcdesc}
+
 \begin{funcdesc}{fromfd}{fd, family, type\optional{, proto}}
 Build a socket object from an existing file descriptor (an integer as
 returned by a file object's \method{fileno()} method).  Address family,
index e97ce5979f5df85849cf6d0b1f80a93f80be92f1..f96a14683715eabfec8936cea14e6f9a63a381cb 100644 (file)
@@ -10,6 +10,7 @@ socket are available as methods of the socket object.
 Functions:
 
 socket() -- create a new socket object
+socketpair() -- create a pair of new socket objects [*]
 fromfd() -- create a socket object from an open file descriptor [*]
 gethostname() -- return the current hostname
 gethostbyname() -- map a hostname to its IP number
index adeca569e46fb21b75a7000c8173e44d39cb27d8..2dc34aa7b571330a17bde623b99cdfb695d9b83c 100644 (file)
@@ -187,6 +187,28 @@ class SocketConnectedTest(ThreadedTCPSocketTest):
         self.serv_conn = None
         ThreadedTCPSocketTest.clientTearDown(self)
 
+class SocketPairTest(unittest.TestCase, ThreadableTest):
+
+    def __init__(self, methodName='runTest'):
+        unittest.TestCase.__init__(self, methodName=methodName)
+        ThreadableTest.__init__(self)
+
+    def setUp(self):
+        self.serv, self.cli = socket.socketpair()
+
+    def tearDown(self):
+        self.serv.close()
+        self.serv = None
+
+    def clientSetUp(self):
+        pass
+
+    def clientTearDown(self):
+        self.cli.close()
+        self.cli = None
+        ThreadableTest.clientTearDown(self)
+
+    
 #######################################################################
 ## Begin Tests
 
@@ -541,6 +563,25 @@ class BasicUDPTest(ThreadedUDPSocketTest):
     def _testRecvFrom(self):
         self.cli.sendto(MSG, 0, (HOST, PORT))
 
+class BasicSocketPairTest(SocketPairTest):
+
+    def __init__(self, methodName='runTest'):
+        SocketPairTest.__init__(self, methodName=methodName)
+
+    def testRecv(self):
+        msg = self.serv.recv(1024)
+        self.assertEqual(msg, MSG)
+
+    def _testRecv(self):
+        self.cli.send(MSG)
+
+    def testSend(self):
+        self.serv.send(MSG)
+
+    def _testSend(self):
+        msg = self.cli.recv(1024)
+        self.assertEqual(msg, MSG)
+
 class NonBlockingTCPTests(ThreadedTCPSocketTest):
 
     def __init__(self, methodName='runTest'):
@@ -786,6 +827,8 @@ def test_main():
         LineBufferedFileObjectClassTestCase,
         SmallBufferedFileObjectClassTestCase
     ])
+    if hasattr(socket, "socketpair"):
+        tests.append(BasicSocketPairTest)
     test_support.run_unittest(*tests)
 
 if __name__ == "__main__":
index f3e51c65b5dbfd21ecb595f201a4ed80137dffcb..f3ed6385a0bb276a781f3a7869f649497ee20c42 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -24,6 +24,8 @@ Core and builtins
 Extension modules
 -----------------
 
+- Added socket.socketpair().
+
 Library
 -------
 
index 30159d2c048efbf4e7c0e4851d807e0be92710ba..f06e253e7668a3abeb8439da0660cab4bee73d8c 100644 (file)
@@ -28,6 +28,7 @@ Module interface:
 - socket.getservbyname(servicename[, protocolname]) --> port number
 - socket.getservbyport(portnumber[, protocolname]) --> service name
 - socket.socket([family[, type [, proto]]]) --> new socket object
+- socket.socketpair([family[, type [, proto]]]) --> (socket, socket)
 - socket.ntohs(16 bit value) --> new int object
 - socket.ntohl(32 bit value) --> new int object
 - socket.htons(16 bit value) --> new int object
@@ -3009,6 +3010,63 @@ PyDoc_STRVAR(getprotobyname_doc,
 Return the protocol number for the named protocol.  (Rarely used.)");
 
 
+#ifdef HAVE_SOCKETPAIR
+/* Create a pair of sockets using the socketpair() function.
+   Arguments as for socket(). */
+
+/*ARGSUSED*/
+static PyObject *
+socket_socketpair(PyObject *self, PyObject *args)
+{
+       PySocketSockObject *s0 = NULL, *s1 = NULL;
+       SOCKET_T sv[2];
+       int family, type = SOCK_STREAM, proto = 0;
+       PyObject *res = NULL;
+
+#if defined(AF_UNIX)
+       family = AF_UNIX;
+#else
+       family = AF_INET;
+#endif
+       if (!PyArg_ParseTuple(args, "|iii:socketpair",
+                             &family, &type, &proto))
+               return NULL;
+       /* Create a pair of socket fds */
+       if (socketpair(family, type, proto, sv) < 0)
+               return set_error();
+#ifdef SIGPIPE
+       (void) signal(SIGPIPE, SIG_IGN);
+#endif
+       s0 = new_sockobject(sv[0], family, type, proto);
+       if (s0 == NULL)
+               goto finally;
+       s1 = new_sockobject(sv[1], family, type, proto);
+       if (s1 == NULL)
+               goto finally;
+       res = PyTuple_Pack(2, s0, s1);
+
+finally:
+       if (res == NULL) {
+               if (s0 == NULL)
+                       SOCKETCLOSE(sv[0]);
+               if (s1 == NULL)
+                       SOCKETCLOSE(sv[1]);
+       }
+       Py_XDECREF(s0);
+       Py_XDECREF(s1);
+       return res;
+}
+
+PyDoc_STRVAR(socketpair_doc,
+"socketpair([family[, type[, proto]]]) -> (socket object, socket object)\n\
+\n\
+Create a pair of socket objects from the sockets returned by the platform\n\
+socketpair() function.\n\
+The arguments are the same as for socket().");
+
+#endif /* HAVE_SOCKETPAIR */
+
+
 #ifndef NO_DUP
 /* Create a socket object from a numeric file description.
    Useful e.g. if stdin is a socket.
@@ -3607,6 +3665,10 @@ static PyMethodDef socket_methods[] = {
 #ifndef NO_DUP
        {"fromfd",              socket_fromfd,
         METH_VARARGS, fromfd_doc},
+#endif
+#ifdef HAVE_SOCKETPAIR
+       {"socketpair",          socket_socketpair,
+        METH_VARARGS, socketpair_doc},
 #endif
        {"ntohs",               socket_ntohs,
         METH_VARARGS, ntohs_doc},
index f73d1ceac1094c705eeb046ce896cb884e792eed..fbd436bb5b9b338823341c44f2116dc75b03af80 100644 (file)
@@ -2468,6 +2468,17 @@ int foo(int x, ...) {
 ])
 AC_MSG_RESULT($works)
 
+# check for socketpair
+AC_MSG_CHECKING(for socketpair)
+AC_TRY_COMPILE([
+#include <sys/types.h>
+#include <sys/socket.h>
+], void *x=socketpair,
+  AC_DEFINE(HAVE_SOCKETPAIR, 1, Define if you have the 'socketpair' function.)
+  AC_MSG_RESULT(yes),
+  AC_MSG_RESULT(no)
+)
+
 # check if sockaddr has sa_len member
 AC_MSG_CHECKING(if sockaddr has sa_len member)
 AC_TRY_COMPILE([#include <sys/types.h>
index 5e1a43c4b839e7e02222eb70d03797b8938ff039..19227e0e0db1ace80394b04b5e3f4b841dd5b1d0 100644 (file)
 /* Define to 1 if you have the `snprintf' function. */
 #undef HAVE_SNPRINTF
 
+/* Define to 1 if you have the `socketpair' function. */
+#undef HAVE_SOCKETPAIR
+
 /* Define if sockaddr has sa_len member */
 #undef HAVE_SOCKADDR_SA_LEN