chiark / gitweb /
socket-proxyd: rework to support multiple sockets and splice()-based zero-copy network IO
authorLennart Poettering <lennart@poettering.net>
Wed, 6 Nov 2013 21:40:54 +0000 (22:40 +0100)
committerLennart Poettering <lennart@poettering.net>
Wed, 6 Nov 2013 22:03:12 +0000 (23:03 +0100)
This also drops --ignore-env, which can't really work anymore if we
allow multiple fds. Also adds support for pretty printing of peer
identities for debug purposes, and abstract namespace UNIX sockets. Also
ensures that we never take more connections than a certain limit.

man/systemd-socket-proxyd.xml
src/shared/socket-util.c
src/shared/socket-util.h
src/socket-proxy/socket-proxyd.c

index d17c86e479612c9e0aaf5ad0d5e4a2de8f4c87ce..fcf4aafd60d2d5df9081a5bdffb753c3df581404 100644 (file)
         </refmeta>
         <refnamediv>
                 <refname>systemd-socket-proxyd</refname>
-                <refpurpose>Inherit a socket. Bidirectionally
-                proxy.</refpurpose>
+                <refpurpose>Bidirectionally proxy local sockets to another (possibly remote) socket.</refpurpose>
         </refnamediv>
         <refsynopsisdiv>
                 <cmdsynopsis>
                         <command>systemd-socket-proxyd</command>
-                        <arg choice="opt" rep="repeat">OPTIONS</arg>
-                        <arg choice="plain"><replaceable>HOSTNAME-OR-IPADDR</replaceable></arg>
-                        <arg choice="plain"><replaceable>PORT-OR-SERVICE</replaceable></arg>
+                        <arg choice="opt" rep="repeat"><replaceable>OPTIONS</replaceable></arg>
+                        <arg choice="plain"><replaceable>HOST</replaceable>:<replaceable>PORT</replaceable></arg>
                 </cmdsynopsis>
                 <cmdsynopsis>
                         <command>systemd-socket-proxyd</command>
-                        <arg choice="opt" rep="repeat">OPTIONS</arg>
+                        <arg choice="opt" rep="repeat"><replaceable>OPTIONS</replaceable></arg>
                         <arg choice="plain"><replaceable>UNIX-DOMAIN-SOCKET-PATH</replaceable>
                         </arg>
                 </cmdsynopsis>
         <refsect1>
                 <title>Description</title>
                 <para>
-                <command>systemd-socket-proxyd</command> provides a proxy
-                to socket-activate services that do not yet support
-                native socket activation. On behalf of the daemon,
-                the proxy inherits the socket from systemd, accepts
-                each client connection, opens a connection to the server
-                for each client, and then bidirectionally forwards
-                data between the two.</para>
+                <command>systemd-socket-proxyd</command> is a generic
+                socket-activated network socket forwarder proxy daemon
+                for IPV4, IPv6 and UNIX stream sockets. It may be used
+                to bi-directionally forward traffic from a local listening socket to a
+                local or remote destination socket.</para>
+
+                <para>One use of this tool is to provide
+                socket-activation support for services that do not
+                natively support socket activation. On behalf of the
+                service to activate, the proxy inherits the socket
+                from systemd, accepts each client connection, opens a
+                connection to a configured server for each client, and
+                then bidirectionally forwards data between the
+                two.</para>
                 <para>This utility's behavior is similar to
                 <citerefentry><refentrytitle>socat</refentrytitle><manvolnum>1</manvolnum></citerefentry>.
                 The main differences for <command>systemd-socket-proxyd</command>
                                         string and exits.</para>
                                 </listitem>
                         </varlistentry>
-                        <varlistentry>
-                                <term><option>--ignore-env</option></term>
-                                <listitem>
-                                        <para>Skips verification of
-                                        the expected PID and file
-                                        descriptor numbers. Use this if
-                                        invoked indirectly, for
-                                        example, with a shell script
-                                        rather than with
-                                        <option>ExecStart=/usr/lib/systemd/systemd-socket-proxyd</option>
-                                        </para>
-                                </listitem>
-                        </varlistentry>
                 </variablelist>
         </refsect1>
         <refsect1>
@@ -205,7 +197,7 @@ while [ ! -f /tmp/nginx.pid ]
   do
      /usr/bin/inotifywait /tmp/nginx.pid
   done
-/usr/bin/systemd-socket-proxyd --ignore-env localhost 8080]]>
+exec /usr/bin/systemd-socket-proxyd localhost 8080]]>
 </programlisting>
                         </example>
                         <example label="nginx configuration">
@@ -232,23 +224,11 @@ $ curl http://localhost:80/]]>
         <refsect1>
                 <title>See Also</title>
                 <para>
-                <citerefentry>
-                        <refentrytitle>
-                        systemd.service</refentrytitle>
-                        <manvolnum>5</manvolnum>
-                </citerefentry>,
-                <citerefentry>
-                        <refentrytitle>
-                        systemd.socket</refentrytitle>
-                        <manvolnum>5</manvolnum>
-                </citerefentry>,
-                <citerefentry>
-                        <refentrytitle>systemctl</refentrytitle>
-                        <manvolnum>1</manvolnum>
-                </citerefentry>,
-                <citerefentry>
-                        <refentrytitle>socat</refentrytitle>
-                        <manvolnum>1</manvolnum>
-                </citerefentry></para>
+                        <citerefentry><refentrytitle>systemd</refentrytitle><manvolnum>1</manvolnum></citerefentry>,
+                        <citerefentry><refentrytitle>systemd.socket</refentrytitle><manvolnum>5</manvolnum></citerefentry>,
+                        <citerefentry><refentrytitle>systemd.service</refentrytitle><manvolnum>5</manvolnum></citerefentry>,
+                        <citerefentry><refentrytitle>systemctl</refentrytitle><manvolnum>1</manvolnum></citerefentry>,
+                        <citerefentry><refentrytitle>socat</refentrytitle><manvolnum>1</manvolnum></citerefentry>
+                </para>
         </refsect1>
 </refentry>
index 1175795d7c978748a4c867dd0b99612e7d74a974..0097f011bb1332741b20490c6ef6be102f2df9f0 100644 (file)
@@ -568,6 +568,89 @@ bool socket_address_matches_fd(const SocketAddress *a, int fd) {
         return false;
 }
 
+int getpeername_pretty(int fd, char **ret) {
+
+        union {
+                struct sockaddr sa;
+                struct sockaddr_un un;
+                struct sockaddr_in in;
+                struct sockaddr_in6 in6;
+                struct sockaddr_storage storage;
+        } sa;
+
+        socklen_t salen;
+        char *p;
+
+        assert(fd >= 0);
+        assert(ret);
+
+        salen = sizeof(sa);
+        if (getpeername(fd, &sa.sa, &salen) < 0)
+                return -errno;
+
+        switch (sa.sa.sa_family) {
+
+        case AF_INET: {
+                uint32_t a;
+
+                a = ntohl(sa.in.sin_addr.s_addr);
+
+                if (asprintf(&p,
+                             "%u.%u.%u.%u:%u",
+                             a >> 24, (a >> 16) & 0xFF, (a >> 8) & 0xFF, a & 0xFF,
+                             ntohs(sa.in.sin_port)) < 0)
+                        return -ENOMEM;
+
+                break;
+        }
+
+        case AF_INET6: {
+                static const unsigned char ipv4_prefix[] = {
+                        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF
+                };
+
+                if (memcmp(&sa.in6.sin6_addr, ipv4_prefix, sizeof(ipv4_prefix)) == 0) {
+                        const uint8_t *a = sa.in6.sin6_addr.s6_addr+12;
+
+                        if (asprintf(&p,
+                                     "%u.%u.%u.%u:%u",
+                                     a[0], a[1], a[2], a[3],
+                                     ntohs(sa.in6.sin6_port)) < 0)
+                                return -ENOMEM;
+                } else {
+                        char a[INET6_ADDRSTRLEN];
+
+                        if (asprintf(&p,
+                                     "%s:%u",
+                                     inet_ntop(AF_INET6, &sa.in6.sin6_addr, a, sizeof(a)),
+                                     ntohs(sa.in6.sin6_port)) < 0)
+                                return -ENOMEM;
+                }
+
+                break;
+        }
+
+        case AF_UNIX: {
+                struct ucred ucred;
+
+                salen = sizeof(ucred);
+                if (getsockopt(fd, SOL_SOCKET, SO_PEERCRED, &ucred, &salen) < 0)
+                        return -errno;
+
+                if (asprintf(&p, "PID %lu/UID %lu", (unsigned long) ucred.pid, (unsigned long) ucred.pid) < 0)
+                        return -ENOMEM;
+
+                break;
+        }
+
+        default:
+                return -ENOTSUP;
+        }
+
+        *ret = p;
+        return 0;
+}
+
 static const char* const netlink_family_table[] = {
         [NETLINK_ROUTE] = "route",
         [NETLINK_FIREWALL] = "firewall",
index 0b9bf2fefc50df0d8c498d9d36e9367d4d181aaa..13566f96917e614148abb3b7b138e17408ba318c 100644 (file)
@@ -99,3 +99,5 @@ int netlink_family_to_string_alloc(int b, char **s);
 int netlink_family_from_string(const char *s);
 
 bool socket_ipv6_is_supported(void);
+
+int getpeername_pretty(int fd, char **ret);
index a449b0eec42ba9d28af64c85b4f7642d4f264f5b..1c64c0e2e5765732ee52ba034d5cb646cc694147 100644 (file)
 #include "util.h"
 #include "event-util.h"
 #include "build.h"
+#include "set.h"
+#include "path-util.h"
+
+#define BUFFER_SIZE (256 * 1024)
+#define CONNECTIONS_MAX 256
 
-#define BUFFER_SIZE 16384
 #define _cleanup_freeaddrinfo_ _cleanup_(freeaddrinfop)
+DEFINE_TRIVIAL_CLEANUP_FUNC(struct addrinfo *, freeaddrinfo);
 
-unsigned int total_clients = 0;
+typedef struct Context {
+        Set *listen;
+        Set *connections;
+} Context;
 
-DEFINE_TRIVIAL_CLEANUP_FUNC(struct addrinfo *, freeaddrinfo);
+typedef struct Connection {
+        int server_fd, client_fd;
+        int server_to_client_buffer[2]; /* a pipe */
+        int client_to_server_buffer[2]; /* a pipe */
 
-struct proxy {
-        int listen_fd;
-        bool ignore_env;
-        bool remote_is_inet;
-        const char *remote_host;
-        const char *remote_service;
-};
+        size_t server_to_client_buffer_full, client_to_server_buffer_full;
+        size_t server_to_client_buffer_size, client_to_server_buffer_size;
+
+        sd_event_source *server_event_source, *client_event_source;
+} Connection;
 
-struct connection {
-        int fd;
-        uint32_t events;
-        sd_event_source *w;
-        struct connection *c_destination;
-        size_t buffer_filled_len;
-        size_t buffer_sent_len;
-        char buffer[BUFFER_SIZE];
+union sockaddr_any {
+        struct sockaddr sa;
+        struct sockaddr_un un;
+        struct sockaddr_in in;
+        struct sockaddr_in6 in6;
+        struct sockaddr_storage storage;
 };
 
-static void free_connection(struct connection *c) {
-        if (c != NULL) {
-                log_debug("Freeing fd=%d (conn %p).", c->fd, c);
-                sd_event_source_unref(c->w);
-                if (c->fd > 0)
-                        close_nointr_nofail(c->fd);
-                free(c);
-        }
-}
+static const char *arg_remote_host = NULL;
 
-static int add_event_to_connection(struct connection *c, uint32_t events) {
-        int r;
+static void connection_free(Connection *c) {
+        assert(c);
 
-        log_debug("Have revents=%d. Adding revents=%d.", c->events, events);
+        sd_event_source_unref(c->server_event_source);
+        sd_event_source_unref(c->client_event_source);
 
-        c->events |= events;
+        if (c->server_fd >= 0)
+                close_nointr_nofail(c->server_fd);
+        if (c->client_fd >= 0)
+                close_nointr_nofail(c->client_fd);
 
-        r = sd_event_source_set_io_events(c->w, c->events);
-        if (r < 0) {
-                log_error("Error %d setting revents: %s", r, strerror(-r));
-                return r;
-        }
+        close_pipe(c->server_to_client_buffer);
+        close_pipe(c->client_to_server_buffer);
 
-        r = sd_event_source_set_enabled(c->w, SD_EVENT_ON);
-        if (r < 0) {
-                log_error("Error %d enabling source: %s", r, strerror(-r));
-                return r;
-        }
+        free(c);
+}
 
-        return 0;
+static void context_free(Context *context) {
+        sd_event_source *es;
+        Connection *c;
+
+        assert(context);
+
+        while ((es = set_steal_first(context->listen)))
+                sd_event_source_unref(es);
+
+        while ((c = set_steal_first(context->connections)))
+                connection_free(c);
+
+        set_free(context->listen);
+        set_free(context->connections);
 }
 
-static int remove_event_from_connection(struct connection *c, uint32_t events) {
+static int get_remote_sockaddr(union sockaddr_any *sa, socklen_t *salen) {
         int r;
 
-        log_debug("Have revents=%d. Removing revents=%d.", c->events, events);
+        assert(sa);
+        assert(salen);
 
-        c->events &= ~events;
+        if (path_is_absolute(arg_remote_host)) {
+                sa->un.sun_family = AF_UNIX;
+                strncpy(sa->un.sun_path, arg_remote_host, sizeof(sa->un.sun_path)-1);
+                sa->un.sun_path[sizeof(sa->un.sun_path)-1] = 0;
 
-        r = sd_event_source_set_io_events(c->w, c->events);
-        if (r < 0) {
-                log_error("Error %d setting revents: %s", r, strerror(-r));
-                return r;
-        }
+                *salen = offsetof(union sockaddr_any, un.sun_path) + strlen(sa->un.sun_path);
 
-        if (c->events == 0) {
-                r = sd_event_source_set_enabled(c->w, SD_EVENT_OFF);
-                if (r < 0) {
-                        log_error("Error %d disabling source: %s", r, strerror(-r));
-                        return r;
+        } else if (arg_remote_host[0] == '@') {
+                sa->un.sun_family = AF_UNIX;
+                sa->un.sun_path[0] = 0;
+                strncpy(sa->un.sun_path+1, arg_remote_host+1, sizeof(sa->un.sun_path)-2);
+                sa->un.sun_path[sizeof(sa->un.sun_path)-1] = 0;
+
+                *salen = offsetof(union sockaddr_any, un.sun_path) + 1 + strlen(sa->un.sun_path + 1);
+
+        } else {
+                _cleanup_freeaddrinfo_ struct addrinfo *result = NULL;
+                const char *node, *service;
+
+                struct addrinfo hints = {
+                        .ai_family = AF_UNSPEC,
+                        .ai_socktype = SOCK_STREAM,
+                        .ai_flags = AI_ADDRCONFIG
+                };
+
+                service = strrchr(arg_remote_host, ':');
+                if (service) {
+                        node = strndupa(arg_remote_host, service - arg_remote_host);
+                        service ++;
+                } else {
+                        node = arg_remote_host;
+                        service = "80";
                 }
-        }
 
-        return 0;
-}
+                log_debug("Looking up address info for %s:%s", node, service);
+                r = getaddrinfo(node, service, &hints, &result);
+                if (r != 0) {
+                        log_error("Failed to resolve host %s:%s: %s", node, service, gai_strerror(r));
+                        return -EHOSTUNREACH;
+                }
 
-static int send_buffer(struct connection *sender) {
-        struct connection *receiver = sender->c_destination;
-        ssize_t len;
-        int r = 0;
-
-        /* We cannot assume that even a partial send() indicates that
-         * the next send() will return EAGAIN or EWOULDBLOCK. Loop until
-         * it does. */
-        while (sender->buffer_filled_len > sender->buffer_sent_len) {
-                len = send(receiver->fd, sender->buffer + sender->buffer_sent_len, sender->buffer_filled_len - sender->buffer_sent_len, 0);
-                log_debug("send(%d, ...)=%zd", receiver->fd, len);
-                if (len < 0) {
-                        if (errno != EWOULDBLOCK && errno != EAGAIN) {
-                                log_error("Error %d in send to fd=%d: %m", errno, receiver->fd);
-                                return -errno;
-                        }
-                        else {
-                                /* send() is in a would-block state. */
-                                break;
-                        }
+                assert(result);
+                if (result->ai_addrlen > sizeof(union sockaddr_any)) {
+                        log_error("Address too long.");
+                        return -E2BIG;
                 }
 
-                /* len < 0 can't occur here. len == 0 is possible but
-                 * undefined behavior for nonblocking send(). */
-                assert(len > 0);
-                sender->buffer_sent_len += len;
+                memcpy(sa, result->ai_addr, result->ai_addrlen);
+                *salen = result->ai_addrlen;
         }
 
-        log_debug("send(%d, ...) completed with %zu bytes still buffered.", receiver->fd, sender->buffer_filled_len - sender->buffer_sent_len);
-
-        /* Detect a would-block state or partial send. */
-        if (sender->buffer_filled_len > sender->buffer_sent_len) {
+        return 0;
+}
 
-                /* If the buffer is full, disable events coming for recv. */
-                if (sender->buffer_filled_len == BUFFER_SIZE) {
-                        r = remove_event_from_connection(sender, EPOLLIN);
-                        if (r < 0) {
-                                log_error("Error %d disabling EPOLLIN for fd=%d: %s", r, sender->fd, strerror(-r));
-                                return r;
-                        }
-                }
+static int connection_create_pipes(Connection *c, int buffer[2], size_t *sz) {
+        int r;
 
-                /* Watch for when the recipient can be sent data again. */
-                r = add_event_to_connection(receiver, EPOLLOUT);
-                if (r < 0) {
-                        log_error("Error %d enabling EPOLLOUT for fd=%d: %s", r, receiver->fd, strerror(-r));
-                        return r;
-                }
-                log_debug("Done with recv for fd=%d. Waiting on send for fd=%d.", sender->fd, receiver->fd);
-                return r;
-        }
+        assert(c);
+        assert(buffer);
+        assert(sz);
 
-        /* If we sent everything without any issues (would-block or
-         * partial send), the buffer is now empty. */
-        sender->buffer_filled_len = 0;
-        sender->buffer_sent_len = 0;
+        if (buffer[0] >= 0)
+                return 0;
 
-        /* Enable the sender's receive watcher, in case the buffer was
-         * full and we disabled it. */
-        r = add_event_to_connection(sender, EPOLLIN);
+        r = pipe2(buffer, O_CLOEXEC|O_NONBLOCK);
         if (r < 0) {
-                log_error("Error %d enabling EPOLLIN for fd=%d: %s", r, sender->fd, strerror(-r));
-                return r;
+                log_error("Failed to allocate pipe buffer: %m");
+                return -errno;
         }
 
-        /* Disable the other side's send watcher, as we have no data to send now. */
-        r = remove_event_from_connection(receiver, EPOLLOUT);
+        fcntl(buffer[0], F_SETPIPE_SZ, BUFFER_SIZE);
+
+        r = fcntl(buffer[0], F_GETPIPE_SZ);
         if (r < 0) {
-                log_error("Error %d disabling EPOLLOUT for fd=%d: %s", r, receiver->fd, strerror(-r));
-                return r;
+                log_error("Failed to get pipe buffer size: %m");
+                return -errno;
         }
 
+        assert(r > 0);
+        *sz = r;
+
         return 0;
 }
 
-static int transfer_data_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
-        struct connection *c = (struct connection *) userdata;
-        int r = 0;
-        ssize_t len;
-
-        assert(revents & (EPOLLIN | EPOLLOUT));
-        assert(fd == c->fd);
-        assert(s == c->w);
-
-        log_debug("Got event revents=%d from fd=%d (conn %p).", revents, fd, c);
-
-        if (revents & EPOLLIN) {
-                log_debug("About to recv up to %zu bytes from fd=%d (%zu/BUFFER_SIZE).", BUFFER_SIZE - c->buffer_filled_len, fd, c->buffer_filled_len);
-
-                /* Receive until the buffer's full, there's no more data,
-                 * or the client/server disconnects. */
-                while (c->buffer_filled_len < BUFFER_SIZE) {
-                        len = recv(fd, c->buffer + c->buffer_filled_len, BUFFER_SIZE - c->buffer_filled_len, 0);
-                        log_debug("recv(%d, ...)=%zd", fd, len);
-                        if (len < 0) {
-                                if (errno != EWOULDBLOCK && errno != EAGAIN) {
-                                        log_error("Error %d in recv from fd=%d: %m", errno, fd);
-                                        return -errno;
-                                } else {
-                                        /* recv() is in a blocking state. */
-                                        break;
-                                }
-                        }
-                        else if (len == 0) {
-                                log_debug("Clean disconnection from fd=%d", fd);
-                                total_clients--;
-                                free_connection(c->c_destination);
-                                free_connection(c);
-                                return 0;
+static int connection_shovel(
+                Connection *c,
+                int *from, int buffer[2], int *to,
+                size_t *full, size_t *sz,
+                sd_event_source **from_source, sd_event_source **to_source) {
+
+        bool shoveled;
+
+        assert(c);
+        assert(from);
+        assert(buffer);
+        assert(buffer[0] >= 0);
+        assert(buffer[1] >= 0);
+        assert(to);
+        assert(full);
+        assert(sz);
+        assert(from_source);
+        assert(to_source);
+
+        do {
+                ssize_t z;
+
+                shoveled = false;
+
+                if (*full < *sz && *from >= 0 && *to >= 0) {
+                        z = splice(*from, NULL, buffer[1], NULL, *sz - *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
+                        if (z > 0) {
+                                *full += z;
+                                shoveled = true;
+                        } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
+                                *from_source = sd_event_source_unref(*from_source);
+                                close_nointr_nofail(*from);
+                                *from = -1;
+                        } else if (errno != EAGAIN && errno != EINTR) {
+                                log_error("Failed to splice: %m");
+                                return -errno;
                         }
-
-                        assert(len > 0);
-                        log_debug("Recording that the buffer got %zd more bytes full.", len);
-                        c->buffer_filled_len += len;
-                        log_debug("Buffer now has %zu bytes full.", c->buffer_filled_len);
                 }
 
-                /* Try sending the data immediately. */
-                return send_buffer(c);
-        }
-        else {
-                return send_buffer(c->c_destination);
-        }
+                if (*full > 0 && *to >= 0) {
+                        z = splice(buffer[0], NULL, *to, NULL, *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
+                        if (z > 0) {
+                                *full -= z;
+                                shoveled = true;
+                        } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
+                                *to_source = sd_event_source_unref(*to_source);
+                                close_nointr_nofail(*to);
+                                *to = -1;
+                        } else if (errno != EAGAIN && errno != EINTR) {
+                                log_error("Failed to splice: %m");
+                                return -errno;
+                        }
+                }
+        } while (shoveled);
 
-        return r;
+        return 0;
 }
 
-/* Once sending to the server is ready, set up the real watchers. */
-static int connected_to_server_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
-        struct sd_event *e = NULL;
-        struct connection *c_server_to_client = (struct connection *) userdata;
-        struct connection *c_client_to_server = c_server_to_client->c_destination;
+static int connection_enable_event_sources(Connection *c, sd_event *event);
+
+static int traffic_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
+        Connection *c = userdata;
         int r;
 
-        assert(revents & EPOLLOUT);
+        assert(s);
+        assert(fd >= 0);
+        assert(c);
 
-        e = sd_event_get(s);
+        r = connection_shovel(c,
+                              &c->server_fd, c->server_to_client_buffer, &c->client_fd,
+                              &c->server_to_client_buffer_full, &c->server_to_client_buffer_size,
+                              &c->server_event_source, &c->client_event_source);
+        if (r < 0)
+                goto quit;
 
-        /* Cancel the initial write watcher for the server. */
-        sd_event_source_unref(s);
+        r = connection_shovel(c,
+                              &c->client_fd, c->client_to_server_buffer, &c->server_fd,
+                              &c->client_to_server_buffer_full, &c->client_to_server_buffer_size,
+                              &c->client_event_source, &c->server_event_source);
+        if (r < 0)
+                goto quit;
 
-        log_debug("Connected to server. Initializing watchers for receiving data.");
+        /* EOF on both sides? */
+        if (c->server_fd == -1 && c->client_fd == -1)
+                goto quit;
 
-        /* A recv watcher for the server. */
-        r = sd_event_add_io(e, c_server_to_client->fd, EPOLLIN, transfer_data_cb, c_server_to_client, &c_server_to_client->w);
-        if (r < 0) {
-                log_error("Error %d creating recv watcher for fd=%d: %s", r, c_server_to_client->fd, strerror(-r));
-                goto fail;
-        }
-        c_server_to_client->events = EPOLLIN;
+        /* Server closed, and all data written to client? */
+        if (c->server_fd == -1 && c->server_to_client_buffer_full <= 0)
+                goto quit;
 
-        /* A recv watcher for the client. */
-        r = sd_event_add_io(e, c_client_to_server->fd, EPOLLIN, transfer_data_cb, c_client_to_server, &c_client_to_server->w);
-        if (r < 0) {
-                log_error("Error %d creating recv watcher for fd=%d: %s", r, c_client_to_server->fd, strerror(-r));
-                goto fail;
-        }
-        c_client_to_server->events = EPOLLIN;
+        /* Client closed, and all data written to server? */
+        if (c->client_fd == -1 && c->client_to_server_buffer_full <= 0)
+                goto quit;
 
-        goto finish;
+        r = connection_enable_event_sources(c, sd_event_get(s));
+        if (r < 0)
+                goto quit;
 
-fail:
-        free_connection(c_client_to_server);
-        free_connection(c_server_to_client);
+        return 1;
 
-finish:
-        return r;
+quit:
+        connection_free(c);
+        return 0; /* ignore errors, continue serving */
 }
 
-static int get_server_connection_fd(const struct proxy *proxy) {
-        int server_fd;
-        int r = -EBADF;
-        int len;
+static int connection_enable_event_sources(Connection *c, sd_event *event) {
+        uint32_t a = 0, b = 0;
+        int r;
 
-        if (proxy->remote_is_inet) {
-                int s;
-                _cleanup_freeaddrinfo_ struct addrinfo *result = NULL;
-                struct addrinfo hints = {.ai_family = AF_UNSPEC,
-                                         .ai_socktype = SOCK_STREAM,
-                                         .ai_flags = AI_PASSIVE};
-
-                log_debug("Looking up address info for %s:%s", proxy->remote_host, proxy->remote_service);
-                s = getaddrinfo(proxy->remote_host, proxy->remote_service, &hints, &result);
-                if (s != 0) {
-                        log_error("getaddrinfo error (%d): %s", s, gai_strerror(s));
-                        return r;
-                }
+        assert(c);
+        assert(event);
 
-                if (result == NULL) {
-                        log_error("getaddrinfo: no result");
-                        return r;
-                }
+        if (c->server_to_client_buffer_full > 0)
+                b |= EPOLLOUT;
+        if (c->server_to_client_buffer_full < c->server_to_client_buffer_size)
+                a |= EPOLLIN;
 
-                /* @TODO: Try connecting to all results instead of just the first. */
-                server_fd = socket(result->ai_family, result->ai_socktype | SOCK_NONBLOCK, result->ai_protocol);
-                if (server_fd < 0) {
-                        log_error("Error %d creating socket: %m", errno);
-                        return r;
-                }
+        if (c->client_to_server_buffer_full > 0)
+                a |= EPOLLOUT;
+        if (c->client_to_server_buffer_full < c->client_to_server_buffer_size)
+                b |= EPOLLIN;
 
-                r = connect(server_fd, result->ai_addr, result->ai_addrlen);
-                /* Ignore EINPROGRESS errors because they're expected for a nonblocking socket. */
-                if (r < 0 && errno != EINPROGRESS) {
-                        log_error("Error %d while connecting to socket %s:%s: %m", errno, proxy->remote_host, proxy->remote_service);
-                        return r;
-                }
+        if (c->server_event_source)
+                r = sd_event_source_set_io_events(c->server_event_source, a);
+        else if (c->server_fd >= 0)
+                r = sd_event_add_io(event, c->server_fd, a, traffic_cb, c, &c->server_event_source);
+        else
+                r = 0;
+
+        if (r < 0) {
+                log_error("Failed to set up server event source: %s", strerror(-r));
+                return r;
         }
-        else {
-                struct sockaddr_un remote;
 
-                server_fd = socket(AF_UNIX, SOCK_STREAM | SOCK_NONBLOCK, 0);
-                if (server_fd < 0) {
-                        log_error("Error %d creating socket: %m", errno);
-                        return -EBADFD;
-                }
+        if (c->client_event_source)
+                r = sd_event_source_set_io_events(c->client_event_source, b);
+        else if (c->client_fd >= 0)
+                r = sd_event_add_io(event, c->client_fd, b, traffic_cb, c, &c->client_event_source);
+        else
+                r = 0;
 
-                remote.sun_family = AF_UNIX;
-                strncpy(remote.sun_path, proxy->remote_host, sizeof(remote.sun_path));
-                len = strlen(remote.sun_path) + sizeof(remote.sun_family);
-                r = connect(server_fd, (struct sockaddr *) &remote, len);
-                if (r < 0 && errno != EINPROGRESS) {
-                        log_error("Error %d while connecting to Unix domain socket %s: %m", errno, proxy->remote_host);
-                        return -EBADFD;
-                }
+        if (r < 0) {
+                log_error("Failed to set up server event source: %s", strerror(-r));
+                return r;
         }
 
-        log_debug("Server connection is fd=%d", server_fd);
-        return server_fd;
+        return 0;
 }
 
-static int do_accept(sd_event *e, struct proxy *p, int fd) {
-        struct connection *c_server_to_client = NULL;
-        struct connection *c_client_to_server = NULL;
-        int r = 0;
-        union sockaddr_union sa;
-        socklen_t salen = sizeof(sa);
-        int client_fd, server_fd;
-
-        client_fd = accept4(fd, (struct sockaddr *) &sa, &salen, SOCK_NONBLOCK|SOCK_CLOEXEC);
-        if (client_fd < 0) {
-                if (errno == EAGAIN || errno == EWOULDBLOCK)
-                    return -errno;
-                log_error("Error %d accepting client connection: %m", errno);
-                r = -errno;
+static int connect_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
+        Connection *c = userdata;
+        socklen_t solen;
+        int error, r;
+
+        assert(s);
+        assert(fd >= 0);
+        assert(c);
+
+        solen = sizeof(error);
+        r = getsockopt(fd, SOL_SOCKET, SO_ERROR, &error, &solen);
+        if (r < 0) {
+                log_error("Failed to issue SO_ERROR: %m");
                 goto fail;
         }
 
-        server_fd = get_server_connection_fd(p);
-        if (server_fd < 0) {
-                log_error("Error initiating server connection.");
-                r = server_fd;
+        if (error != 0) {
+                log_error("Failed to connect to remote host: %s", strerror(error));
                 goto fail;
         }
 
-        c_client_to_server = new0(struct connection, 1);
-        if (c_client_to_server == NULL) {
-                log_oom();
+        c->client_event_source = sd_event_source_unref(c->client_event_source);
+
+        r = connection_create_pipes(c, c->server_to_client_buffer, &c->server_to_client_buffer_size);
+        if (r < 0)
                 goto fail;
-        }
 
-        c_server_to_client = new0(struct connection, 1);
-        if (c_server_to_client == NULL) {
-                log_oom();
+        r = connection_create_pipes(c, c->client_to_server_buffer, &c->client_to_server_buffer_size);
+        if (r < 0)
                 goto fail;
-        }
 
-        c_client_to_server->fd = client_fd;
-        c_server_to_client->fd = server_fd;
+        r = connection_enable_event_sources(c, sd_event_get(s));
+        if (r < 0)
+                goto fail;
 
-        if (sa.sa.sa_family == AF_INET || sa.sa.sa_family == AF_INET6) {
-                char sa_str[INET6_ADDRSTRLEN];
-                const char *success;
+        return 0;
 
-                success = inet_ntop(sa.sa.sa_family, &sa.in6.sin6_addr, sa_str, INET6_ADDRSTRLEN);
-                if (success == NULL)
-                        log_warning("Error %d calling inet_ntop: %m", errno);
-                else
-                        log_debug("Accepted client connection from %s as fd=%d", sa_str, c_client_to_server->fd);
-        }
-        else {
-                log_debug("Accepted client connection (non-IP) as fd=%d", c_client_to_server->fd);
+fail:
+        connection_free(c);
+        return 0; /* ignore errors, continue serving */
+}
+
+static int add_connection_socket(Context *context, sd_event *event, int fd) {
+        union sockaddr_any sa = {};
+        socklen_t salen;
+        Connection *c;
+        int r;
+
+        assert(context);
+        assert(event);
+        assert(fd >= 0);
+
+        if (set_size(context->connections) > CONNECTIONS_MAX) {
+                log_warning("Hit connection limit, refusing connection.");
+                close_nointr_nofail(fd);
+                return 0;
         }
 
-        total_clients++;
-        log_debug("Client fd=%d (conn %p) successfully connected. Total clients: %u", c_client_to_server->fd, c_client_to_server, total_clients);
-        log_debug("Server fd=%d (conn %p) successfully initialized.", c_server_to_client->fd, c_server_to_client);
+        r = set_ensure_allocated(&context->connections, trivial_hash_func, trivial_compare_func);
+        if (r < 0)
+                return log_oom();
 
-        /* Initialize watcher for send to server; this shows connectivity. */
-        r = sd_event_add_io(e, c_server_to_client->fd, EPOLLOUT, connected_to_server_cb, c_server_to_client, &c_server_to_client->w);
-        if (r < 0) {
-                log_error("Error %d creating connectivity watcher for fd=%d: %s", r, c_server_to_client->fd, strerror(-r));
+        c = new0(Connection, 1);
+        if (!c)
+                return log_oom();
+
+        c->server_fd = fd;
+        c->client_fd = -1;
+        c->server_to_client_buffer[0] = c->server_to_client_buffer[1] = -1;
+        c->client_to_server_buffer[0] = c->client_to_server_buffer[1] = -1;
+
+        r = get_remote_sockaddr(&sa, &salen);
+        if (r < 0)
+                goto fail;
+
+        c->client_fd = socket(sa.sa.sa_family, SOCK_STREAM|SOCK_NONBLOCK|SOCK_CLOEXEC, 0);
+        if (c->client_fd < 0) {
+                log_error("Failed to get remote socket: %m");
                 goto fail;
         }
 
-        /* Allow lookups of the opposite connection. */
-        c_server_to_client->c_destination = c_client_to_server;
-        c_client_to_server->c_destination = c_server_to_client;
+        r = connect(c->client_fd, &sa.sa, salen);
+        if (r < 0) {
+                if (errno == EINPROGRESS) {
+                        r = sd_event_add_io(event, c->client_fd, EPOLLOUT, connect_cb, c, &c->client_event_source);
+                        if (r < 0) {
+                                log_error("Failed to add connection socket: %s", strerror(-r));
+                                goto fail;
+                        }
+                } else {
+                        log_error("Failed to connect to remote host: %m");
+                        goto fail;
+                }
+        } else {
+                r = connection_enable_event_sources(c, event);
+                if (r < 0)
+                        goto fail;
+        }
 
-        goto finish;
+        return 0;
 
 fail:
-        log_warning("Accepting a client connection or connecting to the server failed.");
-        free_connection(c_client_to_server);
-        free_connection(c_server_to_client);
-
-finish:
-        return r;
+        connection_free(c);
+        return 0; /* ignore non-OOM errors, continue serving */
 }
 
 static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
-        struct proxy *p = (struct proxy *) userdata;
-        sd_event *e = NULL;
-        int r = 0;
+        Context *context = userdata;
+        int nfd = -1, r;
 
+        assert(s);
+        assert(fd >= 0);
         assert(revents & EPOLLIN);
+        assert(context);
+
+        nfd = accept4(fd, NULL, NULL, SOCK_NONBLOCK|SOCK_CLOEXEC);
+        if (nfd >= 0) {
+                _cleanup_free_ char *peer = NULL;
 
-        e = sd_event_get(s);
+                getpeername_pretty(nfd, &peer);
+                log_debug("New connection from %s", strna(peer));
 
-        for (;;) {
-                r = do_accept(e, p, fd);
-                if (r == -EAGAIN || r == -EWOULDBLOCK)
-                        break;
+                r = add_connection_socket(context, sd_event_get(s), nfd);
                 if (r < 0) {
-                        log_error("Error %d while trying to accept: %s", r, strerror(-r));
-                        break;
+                        close_nointr_nofail(fd);
+                        return r;
                 }
-        }
 
-        /* Re-enable the watcher. */
+        } else if (errno != -EAGAIN)
+                log_warning("Failed to accept() socket: %m");
+
         r = sd_event_source_set_enabled(s, SD_EVENT_ONESHOT);
         if (r < 0) {
                 log_error("Error %d while re-enabling listener with ONESHOT: %s", r, strerror(-r));
                 return r;
         }
 
-        /* Preserve the main loop even if a single accept() fails. */
         return 1;
 }
 
-static int run_main_loop(struct proxy *proxy) {
-        _cleanup_event_source_unref_ sd_event_source *w_accept = NULL;
-        _cleanup_event_unref_ sd_event *e = NULL;
-        int r = EXIT_SUCCESS;
+static int add_listen_socket(Context *context, sd_event *event, int fd) {
+        sd_event_source *source;
+        int r;
+
+        assert(context);
+        assert(event);
+        assert(fd >= 0);
 
-        r = sd_event_new(&e);
+        log_info("Listening on %i", fd);
+
+        r = set_ensure_allocated(&context->listen, trivial_hash_func, trivial_compare_func);
         if (r < 0) {
-                log_error("Failed to allocate event loop: %s", strerror(-r));
+                log_oom();
+                return r;
+        }
+
+        r = sd_is_socket(fd, 0, SOCK_STREAM, 1);
+        if (r < 0) {
+                log_error("Failed to determine socket type: %s", strerror(-r));
                 return r;
         }
+        if (r == 0) {
+                log_error("Passed in socket is not a stream socket.");
+                return -EINVAL;
+        }
 
-        r = fd_nonblock(proxy->listen_fd, true);
+        r = fd_nonblock(fd, true);
         if (r < 0) {
-                log_error("Failed to make listen file descriptor nonblocking: %s", strerror(-r));
+                log_error("Failed to mark file descriptor non-blocking: %s", strerror(-r));
                 return r;
         }
 
-        log_debug("Initializing main listener fd=%d", proxy->listen_fd);
+        r = sd_event_add_io(event, fd, EPOLLIN, accept_cb, context, &source);
+        if (r < 0) {
+                log_error("Failed to add event source: %s", strerror(-r));
+                return r;
+        }
 
-        r = sd_event_add_io(e, proxy->listen_fd, EPOLLIN, accept_cb, proxy, &w_accept);
+        r = set_put(context->listen, source);
         if (r < 0) {
-                log_error("Error %d while adding event IO source: %s", r, strerror(-r));
+                log_error("Failed to add source to set: %s", strerror(-r));
+                sd_event_source_unref(source);
                 return r;
         }
 
         /* Set the watcher to oneshot in case other processes are also
          * watching to accept(). */
-        r = sd_event_source_set_enabled(w_accept, SD_EVENT_ONESHOT);
+        r = sd_event_source_set_enabled(source, SD_EVENT_ONESHOT);
         if (r < 0) {
-                log_error("Error %d while setting event IO source to ONESHOT: %s", r, strerror(-r));
+                log_error("Failed to enable oneshot mode: %s", strerror(-r));
                 return r;
         }
 
-        log_debug("Initialized main listener. Entering loop.");
-
-        return sd_event_loop(e);
+        return 0;
 }
 
 static int help(void) {
 
-        printf("%s hostname-or-ip port-or-service\n"
-               "%s unix-domain-socket-path\n\n"
-               "Inherit a socket. Bidirectionally proxy.\n\n"
-               "  -h --help       Show this help\n"
-               "  --version       Print version and exit\n"
-               "  --ignore-env    Ignore expected systemd environment\n",
+        printf("%s [HOST:PORT]\n"
+               "%s [SOCKET]\n\n"
+               "Bidirectionally proxy local sockets to another (possibly remote) socket.\n\n"
+               "  -h --help              Show this help\n"
+               "     --version           Show package version\n",
                program_invocation_short_name,
                program_invocation_short_name);
 
         return 0;
 }
 
-static int parse_argv(int argc, char *argv[], struct proxy *p) {
+static int parse_argv(int argc, char *argv[]) {
 
         enum {
                 ARG_VERSION = 0x100,
@@ -521,7 +564,6 @@ static int parse_argv(int argc, char *argv[], struct proxy *p) {
         static const struct option options[] = {
                 { "help",       no_argument, NULL, 'h'           },
                 { "version",    no_argument, NULL, ARG_VERSION   },
-                { "ignore-env", no_argument, NULL, ARG_IGNORE_ENV},
                 {}
         };
 
@@ -542,10 +584,6 @@ static int parse_argv(int argc, char *argv[], struct proxy *p) {
                         puts(SYSTEMD_FEATURES);
                         return 0;
 
-                case ARG_IGNORE_ENV:
-                        p->ignore_env = true;
-                        continue;
-
                 case '?':
                         return -EINVAL;
 
@@ -554,75 +592,63 @@ static int parse_argv(int argc, char *argv[], struct proxy *p) {
                 }
         }
 
-        if (optind + 1 != argc && optind + 2 != argc) {
-                log_error("Incorrect number of positional arguments.");
-                help();
+        if (optind >= argc) {
+                log_error("Not enough parameters.");
                 return -EINVAL;
         }
 
-        p->remote_host = argv[optind];
-        assert(p->remote_host);
-
-        p->remote_is_inet = p->remote_host[0] != '/';
-
-        if (optind == argc - 2) {
-                if (!p->remote_is_inet) {
-                        log_error("A port or service is not allowed for Unix socket destinations.");
-                        help();
-                        return -EINVAL;
-                }
-                p->remote_service = argv[optind + 1];
-                assert(p->remote_service);
-        } else if (p->remote_is_inet) {
-                log_error("A port or service is required for IP destinations.");
-                help();
+        if (argc != optind+1) {
+                log_error("Too many parameters.");
                 return -EINVAL;
         }
 
+        arg_remote_host = argv[optind];
         return 1;
 }
 
 int main(int argc, char *argv[]) {
-        struct proxy p = {};
-        int r;
+        _cleanup_event_unref_ sd_event *event = NULL;
+        Context context = {};
+        int r, n, fd;
 
         log_parse_environment();
         log_open();
 
-        r = parse_argv(argc, argv, &p);
+        r = parse_argv(argc, argv);
         if (r <= 0)
                 goto finish;
 
-        p.listen_fd = SD_LISTEN_FDS_START;
+        r = sd_event_new(&event);
+        if (r < 0) {
+                log_error("Failed to allocate event loop: %s", strerror(-r));
+                goto finish;
+        }
 
-        if (!p.ignore_env) {
-                int n;
-                n = sd_listen_fds(1);
-                if (n == 0) {
-                        log_error("Found zero inheritable sockets. Are you sure this is running as a socket-activated service?");
-                        r = EXIT_FAILURE;
-                        goto finish;
-                } else if (n < 0) {
-                        log_error("Error %d while finding inheritable sockets: %s", n, strerror(-n));
-                        r = EXIT_FAILURE;
-                        goto finish;
-                } else if (n > 1) {
-                        log_error("Can't listen on more than one socket.");
-                        r = EXIT_FAILURE;
+        n = sd_listen_fds(1);
+        if (n < 0) {
+                log_error("Failed to receive sockets from parent.");
+                r = n;
+                goto finish;
+        } else if (n == 0) {
+                log_error("Didn't get any sockets passed in.");
+                r = -EINVAL;
+                goto finish;
+        }
+
+        for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + n; fd++) {
+                r = add_listen_socket(&context, event, fd);
+                if (r < 0)
                         goto finish;
-                }
         }
 
-        r = sd_is_socket(p.listen_fd, 0, SOCK_STREAM, 1);
+        r = sd_event_loop(event);
         if (r < 0) {
-                log_error("Error %d while checking inherited socket: %s", r, strerror(-r));
+                log_error("Failed to run event loop: %s", strerror(-r));
                 goto finish;
         }
 
-        log_info("Starting the socket activation proxy with listener fd=%d.", p.listen_fd);
-
-        r = run_main_loop(&p);
-
 finish:
+        context_free(&context);
+
         return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS;
 }