chiark / gitweb /
util: replace close_pipe() with new safe_close_pair()
[elogind.git] / src / socket-proxy / socket-proxyd.c
index 362e8aae9f3760dbe0245a266d655a532c5557bf..ac47c851507b2516e01835bdac6dd32d675189c1 100644 (file)
@@ -66,7 +66,6 @@ typedef struct Connection {
 } Connection;
 
 static const char *arg_remote_host = NULL;
-static int arg_listener = -1;
 
 static void connection_free(Connection *c) {
         assert(c);
@@ -77,13 +76,11 @@ static void connection_free(Connection *c) {
         sd_event_source_unref(c->server_event_source);
         sd_event_source_unref(c->client_event_source);
 
-        if (c->server_fd >= 0)
-                close_nointr_nofail(c->server_fd);
-        if (c->client_fd >= 0)
-                close_nointr_nofail(c->client_fd);
+        safe_close(c->server_fd);
+        safe_close(c->client_fd);
 
-        close_pipe(c->server_to_client_buffer);
-        close_pipe(c->client_to_server_buffer);
+        safe_close_pair(c->server_to_client_buffer);
+        safe_close_pair(c->client_to_server_buffer);
 
         free(c);
 }
@@ -225,8 +222,7 @@ static int connection_shovel(
                                 shoveled = true;
                         } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
                                 *from_source = sd_event_source_unref(*from_source);
-                                close_nointr_nofail(*from);
-                                *from = -1;
+                                *from = safe_close(*from);
                         } else if (errno != EAGAIN && errno != EINTR) {
                                 log_error("Failed to splice: %m");
                                 return -errno;
@@ -240,8 +236,7 @@ static int connection_shovel(
                                 shoveled = true;
                         } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
                                 *to_source = sd_event_source_unref(*to_source);
-                                close_nointr_nofail(*to);
-                                *to = -1;
+                                *to = safe_close(*to);
                         } else if (errno != EAGAIN && errno != EINTR) {
                                 log_error("Failed to splice: %m");
                                 return -errno;
@@ -319,7 +314,7 @@ static int connection_enable_event_sources(Connection *c, sd_event *event) {
         if (c->server_event_source)
                 r = sd_event_source_set_io_events(c->server_event_source, a);
         else if (c->server_fd >= 0)
-                r = sd_event_add_io(event, c->server_fd, a, traffic_cb, c, &c->server_event_source);
+                r = sd_event_add_io(event, &c->server_event_source, c->server_fd, a, traffic_cb, c);
         else
                 r = 0;
 
@@ -331,7 +326,7 @@ static int connection_enable_event_sources(Connection *c, sd_event *event) {
         if (c->client_event_source)
                 r = sd_event_source_set_io_events(c->client_event_source, b);
         else if (c->client_fd >= 0)
-                r = sd_event_add_io(event, c->client_fd, b, traffic_cb, c, &c->client_event_source);
+                r = sd_event_add_io(event, &c->client_event_source, c->client_fd, b, traffic_cb, c);
         else
                 r = 0;
 
@@ -397,7 +392,7 @@ static int add_connection_socket(Context *context, sd_event *event, int fd) {
 
         if (set_size(context->connections) > CONNECTIONS_MAX) {
                 log_warning("Hit connection limit, refusing connection.");
-                close_nointr_nofail(fd);
+                safe_close(fd);
                 return 0;
         }
 
@@ -434,7 +429,7 @@ static int add_connection_socket(Context *context, sd_event *event, int fd) {
         r = connect(c->client_fd, &sa.sa, salen);
         if (r < 0) {
                 if (errno == EINPROGRESS) {
-                        r = sd_event_add_io(event, c->client_fd, EPOLLOUT, connect_cb, c, &c->client_event_source);
+                        r = sd_event_add_io(event, &c->client_event_source, c->client_fd, EPOLLOUT, connect_cb, c);
                         if (r < 0) {
                                 log_error("Failed to add connection socket: %s", strerror(-r));
                                 goto fail;
@@ -463,6 +458,7 @@ fail:
 }
 
 static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
+        _cleanup_free_ char *peer = NULL;
         Context *context = userdata;
         int nfd = -1, r;
 
@@ -472,24 +468,24 @@ static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdat
         assert(context);
 
         nfd = accept4(fd, NULL, NULL, SOCK_NONBLOCK|SOCK_CLOEXEC);
-        if (nfd >= 0) {
-                _cleanup_free_ char *peer = NULL;
-
+        if (nfd < 0) {
+                if (errno != -EAGAIN)
+                        log_warning("Failed to accept() socket: %m");
+        } else {
                 getpeername_pretty(nfd, &peer);
                 log_debug("New connection from %s", strna(peer));
 
                 r = add_connection_socket(context, sd_event_source_get_event(s), nfd);
                 if (r < 0) {
-                        close_nointr_nofail(fd);
-                        return r;
+                        log_error("Failed to accept connection, ignoring: %s", strerror(-r));
+                        safe_close(fd);
                 }
-
-        } else if (errno != -EAGAIN)
-                log_warning("Failed to accept() socket: %m");
+        }
 
         r = sd_event_source_set_enabled(s, SD_EVENT_ONESHOT);
         if (r < 0) {
-                log_error("Error %d while re-enabling listener with ONESHOT: %s", r, strerror(-r));
+                log_error("Error while re-enabling listener with ONESHOT: %s", strerror(-r));
+                sd_event_exit(sd_event_source_get_event(s), r);
                 return r;
         }
 
@@ -526,7 +522,7 @@ static int add_listen_socket(Context *context, sd_event *event, int fd) {
                 return r;
         }
 
-        r = sd_event_add_io(event, fd, EPOLLIN, accept_cb, context, &source);
+        r = sd_event_add_io(event, &source, fd, EPOLLIN, accept_cb, context);
         if (r < 0) {
                 log_error("Failed to add event source: %s", strerror(-r));
                 return r;
@@ -555,9 +551,8 @@ static int help(void) {
         printf("%s [HOST:PORT]\n"
                "%s [SOCKET]\n\n"
                "Bidirectionally proxy local sockets to another (possibly remote) socket.\n\n"
-               "  -l --listener=FD  Listen on a specific, single file descriptor.\n"
-               "  -h --help         Show this help\n"
-               "     --version      Show package version\n",
+               "  -h --help              Show this help\n"
+               "     --version           Show package version\n",
                program_invocation_short_name,
                program_invocation_short_name);
 
@@ -567,22 +562,22 @@ static int help(void) {
 static int parse_argv(int argc, char *argv[]) {
 
         enum {
-                ARG_VERSION = 0x100
+                ARG_VERSION = 0x100,
+                ARG_IGNORE_ENV
         };
 
         static const struct option options[] = {
-                { "help",     no_argument,       NULL, 'h'         },
-                { "version",  no_argument,       NULL, ARG_VERSION },
-                { "listener", required_argument, NULL, 'l'         },
+                { "help",       no_argument, NULL, 'h'           },
+                { "version",    no_argument, NULL, ARG_VERSION   },
                 {}
         };
 
-        int c, fd;
+        int c;
 
         assert(argc >= 0);
         assert(argv);
 
-        while ((c = getopt_long(argc, argv, "hl:", options, NULL)) >= 0) {
+        while ((c = getopt_long(argc, argv, "h", options, NULL)) >= 0) {
 
                 switch (c) {
 
@@ -594,18 +589,6 @@ static int parse_argv(int argc, char *argv[]) {
                         puts(SYSTEMD_FEATURES);
                         return 0;
 
-                case 'l':
-                        if (safe_atoi(optarg, &fd) < 0) {
-                                log_error("Failed to parse listener file descriptor: %s", optarg);
-                                return -EINVAL;
-                        }
-                        if (fd < SD_LISTEN_FDS_START) {
-                                log_error("Listener file descriptor must be at least %d.", SD_LISTEN_FDS_START);
-                                return -EINVAL;
-                        }
-                        arg_listener = fd;
-                        break;
-
                 case '?':
                         return -EINVAL;
 
@@ -646,26 +629,21 @@ int main(int argc, char *argv[]) {
                 goto finish;
         }
 
-        if (arg_listener == -1) {
-                n = sd_listen_fds(1);
-                if (n < 0) {
-                        log_error("Failed to receive sockets from parent.");
-                        r = n;
-                        goto finish;
-                } else if (n == 0) {
-                        log_error("Didn't get any sockets passed in.");
-                        r = -EINVAL;
-                        goto finish;
-                }
-                log_info("Listening on %d inherited socket(s), starting with fd=%d.", n, SD_LISTEN_FDS_START);
-                for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + n; fd++) {
-                        r = add_listen_socket(&context, event, fd);
-                        if (r < 0)
-                                goto finish;
-                }
-        } else {
-                log_info("Listening on single inherited socket fd=%d.", arg_listener);
-                r = add_listen_socket(&context, event, arg_listener);
+        sd_event_set_watchdog(event, true);
+
+        n = sd_listen_fds(1);
+        if (n < 0) {
+                log_error("Failed to receive sockets from parent.");
+                r = n;
+                goto finish;
+        } else if (n == 0) {
+                log_error("Didn't get any sockets passed in.");
+                r = -EINVAL;
+                goto finish;
+        }
+
+        for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + n; fd++) {
+                r = add_listen_socket(&context, event, fd);
                 if (r < 0)
                         goto finish;
         }