chiark / gitweb /
change type for address family to "int"
[elogind.git] / src / libsystemd / sd-rtnl / rtnl-message.c
index d611207deab2c76bd157aa279b451c699fc8f89b..7f2e398b7494ab10cd45f77e74b9029cbae11bb4 100644 (file)
@@ -133,7 +133,7 @@ int sd_rtnl_message_route_set_scope(sd_rtnl_message *m, unsigned char scope) {
 }
 
 int sd_rtnl_message_new_route(sd_rtnl *rtnl, sd_rtnl_message **ret,
-                              uint16_t nlmsg_type, unsigned char rtm_family) {
+                              uint16_t nlmsg_type, int rtm_family) {
         struct rtmsg *rtm;
         int r;
 
@@ -275,7 +275,7 @@ int sd_rtnl_message_addr_set_scope(sd_rtnl_message *m, unsigned char scope) {
         return 0;
 }
 
-int sd_rtnl_message_addr_get_family(sd_rtnl_message *m, unsigned char *family) {
+int sd_rtnl_message_addr_get_family(sd_rtnl_message *m, int *family) {
         struct ifaddrmsg *ifa;
 
         assert_return(m, -EINVAL);
@@ -352,7 +352,7 @@ int sd_rtnl_message_addr_get_ifindex(sd_rtnl_message *m, int *ifindex) {
 
 int sd_rtnl_message_new_addr(sd_rtnl *rtnl, sd_rtnl_message **ret,
                              uint16_t nlmsg_type, int index,
-                             unsigned char family) {
+                             int family) {
         struct ifaddrmsg *ifa;
         int r;
 
@@ -383,7 +383,7 @@ int sd_rtnl_message_new_addr(sd_rtnl *rtnl, sd_rtnl_message **ret,
 }
 
 int sd_rtnl_message_new_addr_update(sd_rtnl *rtnl, sd_rtnl_message **ret,
-                             int index, unsigned char family) {
+                             int index, int family) {
         int r;
 
         r = sd_rtnl_message_new_addr(rtnl, ret, RTM_NEWADDR, index, family);
@@ -786,7 +786,7 @@ int rtnl_message_read_internal(sd_rtnl_message *m, unsigned short type, void **d
         return RTA_PAYLOAD(rta);
 }
 
-int sd_rtnl_message_read_string(sd_rtnl_message *m, unsigned short type, char **data) {
+int sd_rtnl_message_read_string(sd_rtnl_message *m, unsigned short type, const char **data) {
         int r;
         void *attr_data;
 
@@ -800,7 +800,7 @@ int sd_rtnl_message_read_string(sd_rtnl_message *m, unsigned short type, char **
         else if (strnlen(attr_data, r) >= (size_t) r)
                 return -EIO;
 
-        *data = (char *) attr_data;
+        *data = (const char *) attr_data;
 
         return 0;
 }
@@ -962,7 +962,7 @@ int sd_rtnl_message_enter_container(sd_rtnl_message *m, unsigned short type) {
                         return r;
         } else if (nl_type->type == NLA_UNION) {
                 const NLTypeSystemUnion *type_system_union;
-                char *key;
+                const char *key;
 
                 r = type_system_get_type_system_union(m->container_type_system[m->n_containers],
                                                       &type_system_union,
@@ -1098,6 +1098,60 @@ int socket_write_message(sd_rtnl *nl, sd_rtnl_message *m) {
         return k;
 }
 
+static int socket_recv_message(int fd, struct iovec *iov, uint32_t *_group, bool peek) {
+        uint8_t cred_buffer[CMSG_SPACE(sizeof(struct ucred)) +
+                            CMSG_SPACE(sizeof(struct nl_pktinfo))];
+        struct msghdr msg = {
+                .msg_iov = iov,
+                .msg_iovlen = 1,
+                .msg_control = cred_buffer,
+                .msg_controllen = sizeof(cred_buffer),
+        };
+        struct cmsghdr *cmsg;
+        uint32_t group = 0;
+        bool auth = false;
+        int r;
+
+        assert(fd >= 0);
+        assert(iov);
+
+        r = recvmsg(fd, &msg, MSG_TRUNC | (peek ? MSG_PEEK : 0));
+        if (r < 0)
+                /* no data */
+                return (errno == EAGAIN) ? 0 : -errno;
+        else if (r == 0)
+                /* connection was closed by the kernel */
+                return -ECONNRESET;
+
+        for (cmsg = CMSG_FIRSTHDR(&msg); cmsg; cmsg = CMSG_NXTHDR(&msg, cmsg)) {
+                if (cmsg->cmsg_level == SOL_SOCKET &&
+                    cmsg->cmsg_type == SCM_CREDENTIALS &&
+                    cmsg->cmsg_len == CMSG_LEN(sizeof(struct ucred))) {
+                        struct ucred *ucred = (void *)CMSG_DATA(cmsg);
+
+                        /* from the kernel */
+                        if (ucred->uid == 0 && ucred->pid == 0)
+                                auth = true;
+                } else if (cmsg->cmsg_level == SOL_NETLINK &&
+                           cmsg->cmsg_type == NETLINK_PKTINFO &&
+                           cmsg->cmsg_len == CMSG_LEN(sizeof(struct nl_pktinfo))) {
+                        struct nl_pktinfo *pktinfo = (void *)CMSG_DATA(cmsg);
+
+                        /* multi-cast group */
+                        group = pktinfo->group;
+                }
+        }
+
+        if (!auth)
+                /* not from the kernel, ignore */
+                return 0;
+
+        if (group)
+                *_group = group;
+
+        return r;
+}
+
 /* On success, the number of bytes received is returned and *ret points to the received message
  * which has a valid header and the correct size.
  * If nothing useful was received 0 is returned.
@@ -1105,16 +1159,9 @@ int socket_write_message(sd_rtnl *nl, sd_rtnl_message *m) {
  */
 int socket_read_message(sd_rtnl *rtnl) {
         _cleanup_rtnl_message_unref_ sd_rtnl_message *first = NULL;
-        uint8_t cred_buffer[CMSG_SPACE(sizeof(struct ucred))];
         struct iovec iov = {};
-        struct msghdr msg = {
-                .msg_iov = &iov,
-                .msg_iovlen = 1,
-                .msg_control = cred_buffer,
-                .msg_controllen = sizeof(cred_buffer),
-        };
-        struct cmsghdr *cmsg;
-        bool auth = false, multi_part = false, done = false;
+        uint32_t group = 0;
+        bool multi_part = false, done = false;
         struct nlmsghdr *new_msg;
         size_t len;
         int r;
@@ -1125,13 +1172,9 @@ int socket_read_message(sd_rtnl *rtnl) {
         assert(rtnl->rbuffer_allocated >= sizeof(struct nlmsghdr));
 
         /* read nothing, just get the pending message size */
-        r = recvmsg(rtnl->fd, &msg, MSG_PEEK | MSG_TRUNC);
-        if (r < 0)
-                /* no data */
-                return (errno == EAGAIN) ? 0 : -errno;
-        else if (r == 0)
-                /* connection was closed by the kernel */
-                return -ECONNRESET;
+        r = socket_recv_message(rtnl->fd, &iov, &group, true);
+        if (r <= 0)
+                return r;
         else
                 len = (size_t)r;
 
@@ -1144,13 +1187,10 @@ int socket_read_message(sd_rtnl *rtnl) {
         iov.iov_base = rtnl->rbuffer;
         iov.iov_len = rtnl->rbuffer_allocated;
 
-        r = recvmsg(rtnl->fd, &msg, MSG_TRUNC);
-        if (r < 0)
-                /* no data */
-                return (errno == EAGAIN) ? 0 : -errno;
-        else if (r == 0)
-                /* connection was closed by the kernel */
-                return -ECONNRESET;
+        /* read the pending message */
+        r = socket_recv_message(rtnl->fd, &iov, &group, false);
+        if (r <= 0)
+                return r;
         else
                 len = (size_t)r;
 
@@ -1158,24 +1198,6 @@ int socket_read_message(sd_rtnl *rtnl) {
                 /* message did not fit in read buffer */
                 return -EIO;
 
-        for (cmsg = CMSG_FIRSTHDR(&msg); cmsg; cmsg = CMSG_NXTHDR(&msg, cmsg)) {
-                if (cmsg->cmsg_level == SOL_SOCKET &&
-                    cmsg->cmsg_type == SCM_CREDENTIALS &&
-                    cmsg->cmsg_len == CMSG_LEN(sizeof(struct ucred))) {
-                        struct ucred *ucred = (void *)CMSG_DATA(cmsg);
-
-                        /* from the kernel */
-                        if (ucred->uid == 0 && ucred->pid == 0) {
-                                auth = true;
-                                break;
-                        }
-                }
-        }
-
-        if (!auth)
-                /* not from the kernel, ignore */
-                return 0;
-
         if (NLMSG_OK(rtnl->rbuffer, len) && rtnl->rbuffer->nlmsg_flags & NLM_F_MULTI) {
                 multi_part = true;
 
@@ -1192,7 +1214,7 @@ int socket_read_message(sd_rtnl *rtnl) {
                 _cleanup_rtnl_message_unref_ sd_rtnl_message *m = NULL;
                 const NLType *nl_type;
 
-                if (new_msg->nlmsg_pid && new_msg->nlmsg_pid != rtnl->sockaddr.nl.nl_pid)
+                if (!group && new_msg->nlmsg_pid != rtnl->sockaddr.nl.nl_pid)
                         /* not broadcast and not for us */
                         continue;