chiark / gitweb /
rtnl: complain if used after fork
[elogind.git] / src / libsystemd-rtnl / sd-rtnl.c
index abf45d0eed109b971e0c9c94700c6e693f666a56..9c1f40e48a7b9be95846eec78062a3e486ebfd88 100644 (file)
@@ -43,12 +43,23 @@ static int sd_rtnl_new(sd_rtnl **ret) {
 
         rtnl->sockaddr.nl.nl_family = AF_NETLINK;
 
+        rtnl->original_pid = getpid();
+
         *ret = rtnl;
         return 0;
 }
 
-int sd_rtnl_open(__u32 groups, sd_rtnl **ret) {
-        sd_rtnl *rtnl;
+static bool rtnl_pid_changed(sd_rtnl *rtnl) {
+        assert(rtnl);
+
+        /* We don't support people creating an rtnl connection and
+         * keeping it around over a fork(). Let's complain. */
+
+        return rtnl->original_pid != getpid();
+}
+
+int sd_rtnl_open(uint32_t groups, sd_rtnl **ret) {
+        _cleanup_sd_rtnl_unref_ sd_rtnl *rtnl = NULL;
         int r;
 
         r = sd_rtnl_new(&rtnl);
@@ -56,22 +67,17 @@ int sd_rtnl_open(__u32 groups, sd_rtnl **ret) {
                 return r;
 
         rtnl->fd = socket(PF_NETLINK, SOCK_RAW|SOCK_CLOEXEC|SOCK_NONBLOCK, NETLINK_ROUTE);
-        if (rtnl->fd < 0) {
-                r = -errno;
-                sd_rtnl_unref(rtnl);
-                return r;
-        }
+        if (rtnl->fd < 0)
+                return -errno;
 
         rtnl->sockaddr.nl.nl_groups = groups;
 
         r = bind(rtnl->fd, &rtnl->sockaddr.sa, sizeof(rtnl->sockaddr));
-        if (r < 0) {
-                r = -errno;
-                sd_rtnl_unref(rtnl);
-                return r;
-        }
+        if (r < 0)
+                return -errno;
 
         *ret = rtnl;
+        rtnl = NULL;
 
         return 0;
 }
@@ -94,16 +100,16 @@ sd_rtnl *sd_rtnl_unref(sd_rtnl *rtnl) {
 }
 
 int sd_rtnl_send_with_reply_and_block(sd_rtnl *nl,
-                                        sd_rtnl_message *message,
-                                        uint64_t usec,
-                                        sd_rtnl_message **ret) {
+                sd_rtnl_message *message,
+                uint64_t usec,
+                sd_rtnl_message **ret) {
         struct pollfd p[1] = {};
-        sd_rtnl_message *reply;
         struct timespec left;
         usec_t timeout;
         int r, serial;
 
         assert_return(nl, -EINVAL);
+        assert_return(!rtnl_pid_changed(nl), -ECHILD);
         assert_return(message, -EINVAL);
 
         r = message_seal(nl, message);
@@ -115,10 +121,15 @@ int sd_rtnl_send_with_reply_and_block(sd_rtnl *nl,
         p[0].fd = nl->fd;
         p[0].events = POLLOUT;
 
-        timeout = now(CLOCK_MONOTONIC) + usec;
+        if (usec == (uint64_t) -1)
+                timeout = 0;
+        else if (usec == 0)
+                timeout = now(CLOCK_MONOTONIC) + RTNL_DEFAULT_TIMEOUT;
+        else
+                timeout = now(CLOCK_MONOTONIC) + usec;
 
         for (;;) {
-                if (usec != (uint64_t) -1) {
+                if (timeout) {
                         usec_t n;
 
                         n = now(CLOCK_MONOTONIC);
@@ -128,7 +139,7 @@ int sd_rtnl_send_with_reply_and_block(sd_rtnl *nl,
                         timespec_store(&left, timeout - n);
                 }
 
-                r = ppoll(p, 1, usec == (uint64_t) -1 ? NULL : &left, NULL);
+                r = ppoll(p, 1, timeout ? &left : NULL, NULL);
                 if (r < 0)
                         return 0;
 
@@ -144,7 +155,9 @@ int sd_rtnl_send_with_reply_and_block(sd_rtnl *nl,
         p[0].events = POLLIN;
 
         for (;;) {
-                if (usec != (uint64_t) -1) {
+                _cleanup_sd_rtnl_message_unref_ sd_rtnl_message *reply = NULL;
+
+                if (timeout) {
                         usec_t n;
 
                         n = now(CLOCK_MONOTONIC);
@@ -154,7 +167,7 @@ int sd_rtnl_send_with_reply_and_block(sd_rtnl *nl,
                         timespec_store(&left, timeout - n);
                 }
 
-                r = ppoll(p, 1, usec == (uint64_t) -1 ? NULL : &left, NULL);
+                r = ppoll(p, 1, timeout ? &left : NULL, NULL);
                 if (r < 0)
                         return r;
 
@@ -167,20 +180,16 @@ int sd_rtnl_send_with_reply_and_block(sd_rtnl *nl,
 
                         if (received_serial == serial) {
                                 r = message_get_errno(reply);
-                                if (r < 0) {
-                                        sd_rtnl_message_unref(reply);
+                                if (r < 0)
                                         return r;
-                                }
 
-                                if (ret)
+                                if (ret) {
                                         *ret = reply;
-                                else
-                                        reply = sd_rtnl_message_unref(reply);
+                                        reply = NULL;
+                                }
 
                                 break;;
                         }
-
-                        reply = sd_rtnl_message_unref(reply);
                 }
         }