chiark / gitweb /
update TODO
[elogind.git] / libudev / libudev-ctrl.c
index cea1b7f55b0d359a4bfa62916d23e0a028abd6de..e0ec2fa3d7bcc69e1093b4595fad15844cba8144 100644 (file)
@@ -60,6 +60,7 @@ struct udev_ctrl {
        int sock;
        struct sockaddr_un saddr;
        socklen_t addrlen;
+       bool bound;
        bool connected;
 };
 
@@ -81,7 +82,7 @@ static struct udev_ctrl *udev_ctrl_new(struct udev *udev)
        return uctrl;
 }
 
-struct udev_ctrl *udev_ctrl_new_from_socket(struct udev *udev, const char *socket_path)
+struct udev_ctrl *udev_ctrl_new_from_socket_fd(struct udev *udev, const char *socket_path, int fd)
 {
        struct udev_ctrl *uctrl;
 
@@ -89,11 +90,16 @@ struct udev_ctrl *udev_ctrl_new_from_socket(struct udev *udev, const char *socke
        if (uctrl == NULL)
                return NULL;
 
-       uctrl->sock = socket(AF_LOCAL, SOCK_SEQPACKET, 0);
-       if (uctrl->sock < 0) {
-               err(udev, "error getting socket: %m\n");
-               udev_ctrl_unref(uctrl);
-               return NULL;
+       if (fd < 0) {
+               uctrl->sock = socket(AF_LOCAL, SOCK_SEQPACKET|SOCK_NONBLOCK|SOCK_CLOEXEC, 0);
+               if (uctrl->sock < 0) {
+                       err(udev, "error getting socket: %m\n");
+                       udev_ctrl_unref(uctrl);
+                       return NULL;
+               }
+       } else {
+               uctrl->bound = true;
+               uctrl->sock = fd;
        }
 
        uctrl->saddr.sun_family = AF_LOCAL;
@@ -105,35 +111,31 @@ struct udev_ctrl *udev_ctrl_new_from_socket(struct udev *udev, const char *socke
        return uctrl;
 }
 
-struct udev_ctrl *udev_ctrl_new_from_fd(struct udev *udev, int fd)
+struct udev_ctrl *udev_ctrl_new_from_socket(struct udev *udev, const char *socket_path)
 {
-       struct udev_ctrl *uctrl;
-
-       uctrl = udev_ctrl_new(udev);
-       if (uctrl == NULL)
-               return NULL;
-       uctrl->sock = fd;
-
-       return uctrl;
+       return udev_ctrl_new_from_socket_fd(udev, socket_path, -1);
 }
 
 int udev_ctrl_enable_receiving(struct udev_ctrl *uctrl)
 {
        int err;
 
-       if (uctrl->addrlen > 0) {
+       if (!uctrl->bound) {
                err = bind(uctrl->sock, (struct sockaddr *)&uctrl->saddr, uctrl->addrlen);
                if (err < 0) {
                        err = -errno;
                        err(uctrl->udev, "bind failed: %m\n");
                        return err;
                }
+
                err = listen(uctrl->sock, 0);
                if (err < 0) {
                        err = -errno;
                        err(uctrl->udev, "listen failed: %m\n");
                        return err;
                }
+
+               uctrl->bound = true;
        }
        return 0;
 }
@@ -174,6 +176,8 @@ int udev_ctrl_get_fd(struct udev_ctrl *uctrl)
 struct udev_ctrl_connection *udev_ctrl_get_connection(struct udev_ctrl *uctrl)
 {
        struct udev_ctrl_connection *conn;
+       struct ucred ucred;
+       socklen_t slen;
        const int on = 1;
 
        conn = calloc(1, sizeof(struct udev_ctrl_connection));
@@ -182,16 +186,33 @@ struct udev_ctrl_connection *udev_ctrl_get_connection(struct udev_ctrl *uctrl)
        conn->refcount = 1;
        conn->uctrl = uctrl;
 
-       conn->sock = accept4(uctrl->sock, NULL, NULL, SOCK_CLOEXEC);
+       conn->sock = accept4(uctrl->sock, NULL, NULL, SOCK_CLOEXEC|SOCK_NONBLOCK);
        if (conn->sock < 0) {
-               free(conn);
-               return NULL;
+               if (errno != EINTR)
+                       err(uctrl->udev, "unable to receive ctrl connection: %m\n");
+               goto err;
+       }
+
+       /* check peer credential of connection */
+       slen = sizeof(ucred);
+       if (getsockopt(conn->sock, SOL_SOCKET, SO_PEERCRED, &ucred, &slen) < 0) {
+               err(uctrl->udev, "unable to receive credentials of ctrl connection: %m\n");
+               goto err;
+       }
+       if (ucred.uid > 0) {
+               err(uctrl->udev, "sender uid=%i, message ignored\n", ucred.uid);
+               goto err;
        }
 
-       /* enable receiving of the sender credentials */
+       /* enable receiving of the sender credentials in the messages */
        setsockopt(conn->sock, SOL_SOCKET, SO_PASSCRED, &on, sizeof(on));
        udev_ctrl_ref(uctrl);
        return conn;
+err:
+       if (conn->sock >= 0)
+               close(conn->sock);
+       free(conn);
+       return NULL;
 }
 
 struct udev_ctrl_connection *udev_ctrl_connection_ref(struct udev_ctrl_connection *conn)
@@ -327,19 +348,44 @@ struct udev_ctrl_msg *udev_ctrl_receive_msg(struct udev_ctrl_connection *conn)
                return NULL;
        uctrl_msg->refcount = 1;
        uctrl_msg->conn = conn;
+       udev_ctrl_connection_ref(conn);
+
+       /* wait for the incoming message */
+       for(;;) {
+               struct pollfd pfd[1];
+               int r;
+
+               pfd[0].fd = conn->sock;
+               pfd[0].events = POLLIN;
+
+               r = poll(pfd, 1, 10000);
+               if (r  < 0) {
+                       if (errno == EINTR)
+                               continue;
+                       goto err;
+               } else if (r == 0) {
+                       err(udev, "timeout waiting for ctrl message\n");
+                       goto err;
+               } else {
+                       if (!(pfd[0].revents & POLLIN)) {
+                               err(udev, "ctrl connection error: %m\n");
+                               goto err;
+                       }
+               }
+
+               break;
+       }
 
        iov.iov_base = &uctrl_msg->ctrl_msg_wire;
        iov.iov_len = sizeof(struct udev_ctrl_msg_wire);
-
        memset(&smsg, 0x00, sizeof(struct msghdr));
        smsg.msg_iov = &iov;
        smsg.msg_iovlen = 1;
        smsg.msg_control = cred_msg;
        smsg.msg_controllen = sizeof(cred_msg);
-
        size = recvmsg(conn->sock, &smsg, 0);
        if (size <  0) {
-               err(udev, "unable to receive user udevd message: %m\n");
+               err(udev, "unable to receive ctrl message: %m\n");
                goto err;
        }
        cmsg = CMSG_FIRSTHDR(&smsg);
@@ -361,7 +407,6 @@ struct udev_ctrl_msg *udev_ctrl_receive_msg(struct udev_ctrl_connection *conn)
        }
 
        dbg(udev, "created ctrl_msg %p (%i)\n", uctrl_msg, uctrl_msg->ctrl_msg_wire.type);
-       udev_ctrl_connection_ref(conn);
        return uctrl_msg;
 err:
        udev_ctrl_msg_unref(uctrl_msg);