chiark / gitweb /
socket-proxyd: Add --listener option for listener/destination pairs.
[elogind.git] / src / socket-proxy / socket-proxyd.c
index 1c64c0e2e5765732ee52ba034d5cb646cc694147..362e8aae9f3760dbe0245a266d655a532c5557bf 100644 (file)
@@ -53,6 +53,8 @@ typedef struct Context {
 } Context;
 
 typedef struct Connection {
+        Context *context;
+
         int server_fd, client_fd;
         int server_to_client_buffer[2]; /* a pipe */
         int client_to_server_buffer[2]; /* a pipe */
@@ -63,19 +65,15 @@ typedef struct Connection {
         sd_event_source *server_event_source, *client_event_source;
 } Connection;
 
-union sockaddr_any {
-        struct sockaddr sa;
-        struct sockaddr_un un;
-        struct sockaddr_in in;
-        struct sockaddr_in6 in6;
-        struct sockaddr_storage storage;
-};
-
 static const char *arg_remote_host = NULL;
+static int arg_listener = -1;
 
 static void connection_free(Connection *c) {
         assert(c);
 
+        if (c->context)
+                set_remove(c->context->connections, c);
+
         sd_event_source_unref(c->server_event_source);
         sd_event_source_unref(c->client_event_source);
 
@@ -99,14 +97,14 @@ static void context_free(Context *context) {
         while ((es = set_steal_first(context->listen)))
                 sd_event_source_unref(es);
 
-        while ((c = set_steal_first(context->connections)))
+        while ((c = set_first(context->connections)))
                 connection_free(c);
 
         set_free(context->listen);
         set_free(context->connections);
 }
 
-static int get_remote_sockaddr(union sockaddr_any *sa, socklen_t *salen) {
+static int get_remote_sockaddr(union sockaddr_union *sa, socklen_t *salen) {
         int r;
 
         assert(sa);
@@ -117,7 +115,7 @@ static int get_remote_sockaddr(union sockaddr_any *sa, socklen_t *salen) {
                 strncpy(sa->un.sun_path, arg_remote_host, sizeof(sa->un.sun_path)-1);
                 sa->un.sun_path[sizeof(sa->un.sun_path)-1] = 0;
 
-                *salen = offsetof(union sockaddr_any, un.sun_path) + strlen(sa->un.sun_path);
+                *salen = offsetof(union sockaddr_union, un.sun_path) + strlen(sa->un.sun_path);
 
         } else if (arg_remote_host[0] == '@') {
                 sa->un.sun_family = AF_UNIX;
@@ -125,7 +123,7 @@ static int get_remote_sockaddr(union sockaddr_any *sa, socklen_t *salen) {
                 strncpy(sa->un.sun_path+1, arg_remote_host+1, sizeof(sa->un.sun_path)-2);
                 sa->un.sun_path[sizeof(sa->un.sun_path)-1] = 0;
 
-                *salen = offsetof(union sockaddr_any, un.sun_path) + 1 + strlen(sa->un.sun_path + 1);
+                *salen = offsetof(union sockaddr_union, un.sun_path) + 1 + strlen(sa->un.sun_path + 1);
 
         } else {
                 _cleanup_freeaddrinfo_ struct addrinfo *result = NULL;
@@ -154,7 +152,7 @@ static int get_remote_sockaddr(union sockaddr_any *sa, socklen_t *salen) {
                 }
 
                 assert(result);
-                if (result->ai_addrlen > sizeof(union sockaddr_any)) {
+                if (result->ai_addrlen > sizeof(union sockaddr_union)) {
                         log_error("Address too long.");
                         return -E2BIG;
                 }
@@ -290,7 +288,7 @@ static int traffic_cb(sd_event_source *s, int fd, uint32_t revents, void *userda
         if (c->client_fd == -1 && c->client_to_server_buffer_full <= 0)
                 goto quit;
 
-        r = connection_enable_event_sources(c, sd_event_get(s));
+        r = connection_enable_event_sources(c, sd_event_source_get_event(s));
         if (r < 0)
                 goto quit;
 
@@ -338,7 +336,7 @@ static int connection_enable_event_sources(Connection *c, sd_event *event) {
                 r = 0;
 
         if (r < 0) {
-                log_error("Failed to set up server event source: %s", strerror(-r));
+                log_error("Failed to set up client event source: %s", strerror(-r));
                 return r;
         }
 
@@ -376,7 +374,7 @@ static int connect_cb(sd_event_source *s, int fd, uint32_t revents, void *userda
         if (r < 0)
                 goto fail;
 
-        r = connection_enable_event_sources(c, sd_event_get(s));
+        r = connection_enable_event_sources(c, sd_event_source_get_event(s));
         if (r < 0)
                 goto fail;
 
@@ -388,7 +386,7 @@ fail:
 }
 
 static int add_connection_socket(Context *context, sd_event *event, int fd) {
-        union sockaddr_any sa = {};
+        union sockaddr_union sa = {};
         socklen_t salen;
         Connection *c;
         int r;
@@ -411,11 +409,18 @@ static int add_connection_socket(Context *context, sd_event *event, int fd) {
         if (!c)
                 return log_oom();
 
+        c->context = context;
         c->server_fd = fd;
         c->client_fd = -1;
         c->server_to_client_buffer[0] = c->server_to_client_buffer[1] = -1;
         c->client_to_server_buffer[0] = c->client_to_server_buffer[1] = -1;
 
+        r = set_put(context->connections, c);
+        if (r < 0) {
+                free(c);
+                return log_oom();
+        }
+
         r = get_remote_sockaddr(&sa, &salen);
         if (r < 0)
                 goto fail;
@@ -434,6 +439,12 @@ static int add_connection_socket(Context *context, sd_event *event, int fd) {
                                 log_error("Failed to add connection socket: %s", strerror(-r));
                                 goto fail;
                         }
+
+                        r = sd_event_source_set_enabled(c->client_event_source, SD_EVENT_ONESHOT);
+                        if (r < 0) {
+                                log_error("Failed to enable oneshot event source: %s", strerror(-r));
+                                goto fail;
+                        }
                 } else {
                         log_error("Failed to connect to remote host: %m");
                         goto fail;
@@ -467,7 +478,7 @@ static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdat
                 getpeername_pretty(nfd, &peer);
                 log_debug("New connection from %s", strna(peer));
 
-                r = add_connection_socket(context, sd_event_get(s), nfd);
+                r = add_connection_socket(context, sd_event_source_get_event(s), nfd);
                 if (r < 0) {
                         close_nointr_nofail(fd);
                         return r;
@@ -493,8 +504,6 @@ static int add_listen_socket(Context *context, sd_event *event, int fd) {
         assert(event);
         assert(fd >= 0);
 
-        log_info("Listening on %i", fd);
-
         r = set_ensure_allocated(&context->listen, trivial_hash_func, trivial_compare_func);
         if (r < 0) {
                 log_oom();
@@ -546,8 +555,9 @@ static int help(void) {
         printf("%s [HOST:PORT]\n"
                "%s [SOCKET]\n\n"
                "Bidirectionally proxy local sockets to another (possibly remote) socket.\n\n"
-               "  -h --help              Show this help\n"
-               "     --version           Show package version\n",
+               "  -l --listener=FD  Listen on a specific, single file descriptor.\n"
+               "  -h --help         Show this help\n"
+               "     --version      Show package version\n",
                program_invocation_short_name,
                program_invocation_short_name);
 
@@ -557,22 +567,22 @@ static int help(void) {
 static int parse_argv(int argc, char *argv[]) {
 
         enum {
-                ARG_VERSION = 0x100,
-                ARG_IGNORE_ENV
+                ARG_VERSION = 0x100
         };
 
         static const struct option options[] = {
-                { "help",       no_argument, NULL, 'h'           },
-                { "version",    no_argument, NULL, ARG_VERSION   },
+                { "help",     no_argument,       NULL, 'h'         },
+                { "version",  no_argument,       NULL, ARG_VERSION },
+                { "listener", required_argument, NULL, 'l'         },
                 {}
         };
 
-        int c;
+        int c, fd;
 
         assert(argc >= 0);
         assert(argv);
 
-        while ((c = getopt_long(argc, argv, "h", options, NULL)) >= 0) {
+        while ((c = getopt_long(argc, argv, "hl:", options, NULL)) >= 0) {
 
                 switch (c) {
 
@@ -584,6 +594,18 @@ static int parse_argv(int argc, char *argv[]) {
                         puts(SYSTEMD_FEATURES);
                         return 0;
 
+                case 'l':
+                        if (safe_atoi(optarg, &fd) < 0) {
+                                log_error("Failed to parse listener file descriptor: %s", optarg);
+                                return -EINVAL;
+                        }
+                        if (fd < SD_LISTEN_FDS_START) {
+                                log_error("Listener file descriptor must be at least %d.", SD_LISTEN_FDS_START);
+                                return -EINVAL;
+                        }
+                        arg_listener = fd;
+                        break;
+
                 case '?':
                         return -EINVAL;
 
@@ -618,25 +640,32 @@ int main(int argc, char *argv[]) {
         if (r <= 0)
                 goto finish;
 
-        r = sd_event_new(&event);
+        r = sd_event_default(&event);
         if (r < 0) {
                 log_error("Failed to allocate event loop: %s", strerror(-r));
                 goto finish;
         }
 
-        n = sd_listen_fds(1);
-        if (n < 0) {
-                log_error("Failed to receive sockets from parent.");
-                r = n;
-                goto finish;
-        } else if (n == 0) {
-                log_error("Didn't get any sockets passed in.");
-                r = -EINVAL;
-                goto finish;
-        }
-
-        for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + n; fd++) {
-                r = add_listen_socket(&context, event, fd);
+        if (arg_listener == -1) {
+                n = sd_listen_fds(1);
+                if (n < 0) {
+                        log_error("Failed to receive sockets from parent.");
+                        r = n;
+                        goto finish;
+                } else if (n == 0) {
+                        log_error("Didn't get any sockets passed in.");
+                        r = -EINVAL;
+                        goto finish;
+                }
+                log_info("Listening on %d inherited socket(s), starting with fd=%d.", n, SD_LISTEN_FDS_START);
+                for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + n; fd++) {
+                        r = add_listen_socket(&context, event, fd);
+                        if (r < 0)
+                                goto finish;
+                }
+        } else {
+                log_info("Listening on single inherited socket fd=%d.", arg_listener);
+                r = add_listen_socket(&context, event, arg_listener);
                 if (r < 0)
                         goto finish;
         }