]> granicus.if.org Git - libnl/commitdiff
nl_recv(): Memory allocation errors are handled properly now
authorКоренберг Марк (дома) <socketpair@gmail.com>
Fri, 19 Oct 2012 16:58:58 +0000 (22:58 +0600)
committerКоренберг Марк (дома) <socketpair@gmail.com>
Fri, 19 Oct 2012 17:48:46 +0000 (23:48 +0600)
1. all cleanup actions (like free()) now located at the end of function
2. in case of error or EOF, *buf and *creds (if given) set to NULL
   This protect from invalid code at user's side, like:
   char *buf;
   x = nl_recv(..., &buf, ...);
   if (x<=0)
      goto cleanup;
   cleanup:
      free(buf);
3. all intermediate buffers are stored into local variables, and user's
   variables only touches at the end.

lib/nl.c

index 8426c5883899b86c08a609c817becd5a68fe239c..d08f7e15ad9d32d750ee4d6a21d177a7a9d6e718 100644 (file)
--- a/lib/nl.c
+++ b/lib/nl.c
@@ -437,9 +437,9 @@ int nl_recv(struct nl_sock *sk, struct sockaddr_nl *nla,
                .msg_controllen = 0,
                .msg_flags = 0,
        };
-       struct cmsghdr *cmsg;
-
        memset(nla, 0, sizeof(*nla));
+       struct ucred* tmpcreds = NULL;
+       int retval = 0;
 
        if (sk->s_flags & NL_MSG_PEEK)
                flags |= MSG_PEEK | MSG_TRUNC;
@@ -448,20 +448,29 @@ int nl_recv(struct nl_sock *sk, struct sockaddr_nl *nla,
                page_size = getpagesize();
 
        iov.iov_len = sk->s_bufsize ? : page_size;
-       iov.iov_base = *buf = malloc(iov.iov_len);
+       iov.iov_base = malloc(iov.iov_len);
+
+       if (!iov.iov_base) {
+           retval = -NLE_NOMEM;
+           goto abort;
+       }
 
        if (sk->s_flags & NL_SOCK_PASSCRED) {
                msg.msg_controllen = CMSG_SPACE(sizeof(struct ucred));
                msg.msg_control = calloc(1, msg.msg_controllen);
+               if (!msg.msg_control) {
+                       retval = -NLE_NOMEM;
+                       goto abort;
+               }
        }
 retry:
 
        n = recvmsg(sk->s_fd, &msg, flags);
-       if (!n)
+       if (!n) {
+               retval = 0;
                goto abort;
-
+       }
        if (n < 0) {
-
                if (errno == EINTR) {
                        NL_DBG(3, "recvmsg() returned EINTR, retrying\n");
                        goto retry;
@@ -469,26 +478,37 @@ retry:
 
                 if (errno == EAGAIN) {
                        NL_DBG(3, "recvmsg() returned EAGAIN, aborting\n");
+                       retval = 0;
                        goto abort;
                }
-
-               free(msg.msg_control);
-               free(*buf);
-               return -nl_syserr2nlerr(errno);
+               retval = -nl_syserr2nlerr(errno);
+               goto abort;
        }
 
        if (msg.msg_flags & MSG_CTRUNC) {
+               void *tmp;
                msg.msg_controllen *= 2;
-               msg.msg_control = realloc(msg.msg_control, msg.msg_controllen);
+               tmp = realloc(msg.msg_control, msg.msg_controllen);
+               if (!tmp) {
+                   retval = -NLE_NOMEM;
+                   goto abort;
+               }
+               msg.msg_control = tmp;
                goto retry;
        }
 
         if (iov.iov_len < n || msg.msg_flags & MSG_TRUNC) {
+               void *tmp;
                /* Provided buffer is not long enough, enlarge it
                 * to size of n (which should be total length of the message)
                 * and try again. */
                iov.iov_len = n;
-               iov.iov_base = *buf = realloc(*buf, iov.iov_len);
+               tmp = realloc(iov.iov_base, iov.iov_len);
+               if (!tmp) {
+                   retval = -NLE_NOMEM;
+                   goto abort;
+               }
+               iov.iov_base = tmp;
                flags = 0;
                goto retry;
        }
@@ -500,29 +520,40 @@ retry:
        }
 
        if (msg.msg_namelen != sizeof(struct sockaddr_nl)) {
-               free(msg.msg_control);
-               free(*buf);
-               return -NLE_NOADDR;
+               retval =  -NLE_NOADDR;
+               goto abort;
        }
 
        for (cmsg = CMSG_FIRSTHDR(&msg); cmsg; cmsg = CMSG_NXTHDR(&msg, cmsg)) {
                if (cmsg->cmsg_level == SOL_SOCKET &&
                    cmsg->cmsg_type == SCM_CREDENTIALS) {
                        if (creds) {
-                               *creds = calloc(1, sizeof(struct ucred));
-                               memcpy(*creds, CMSG_DATA(cmsg), sizeof(struct ucred));
+                               tmpcreds = malloc(sizeof(*tmpcreds));
+                               if (!tmpcreds) {
+                                      retval = -NLE_NOMEM;
+                                      goto abort;
+                               }
+                               memcpy(tmpcreds, CMSG_DATA(cmsg), sizeof(*tmpcreds));
                        }
                        break;
                }
        }
 
-       free(msg.msg_control);
-       return n;
-
+       retval = n;
 abort:
        free(msg.msg_control);
-       free(*buf);
-       return 0;
+
+       if (retval <= 0) {
+           free(iov.iov_base); iov.iov_base = NULL;
+           free(tmpcreds); tmpcreds = NULL;
+       }
+
+       *buf = iov.iov_base;
+
+       if (creds)
+           *creds = tmpcreds;
+
+       return retval;
 }
 
 /** @cond SKIP */