X-Git-Url: https://www.chiark.greenend.org.uk/ucgi/~ianmdlvl/git?p=elogind.git;a=blobdiff_plain;f=src%2Fsocket-proxy%2Fsocket-proxyd.c;h=3041903757648e53eb2c5bcdf5d4cee7b54cea10;hp=d64b0d286785827c462a999e0b7599bf3aafdeb4;hb=25dbe4f50f93fb6398844ba67ea197f76adc237a;hpb=6298945d5c4b9a8116f2b1d1f9c7f6c0ff644a05 diff --git a/src/socket-proxy/socket-proxyd.c b/src/socket-proxy/socket-proxyd.c index d64b0d286..304190375 100644 --- a/src/socket-proxy/socket-proxyd.c +++ b/src/socket-proxy/socket-proxyd.c @@ -26,452 +26,579 @@ #include #include #include -#include +#include #include #include #include #include "sd-daemon.h" #include "sd-event.h" +#include "sd-resolve.h" #include "log.h" #include "socket-util.h" #include "util.h" #include "event-util.h" +#include "build.h" +#include "set.h" +#include "path-util.h" + +#define BUFFER_SIZE (256 * 1024) +#define CONNECTIONS_MAX 256 + +static const char *arg_remote_host = NULL; + +typedef struct Context { + sd_event *event; + sd_resolve *resolve; + + Set *listen; + Set *connections; +} Context; + +typedef struct Connection { + Context *context; + + int server_fd, client_fd; + int server_to_client_buffer[2]; /* a pipe */ + int client_to_server_buffer[2]; /* a pipe */ + + size_t server_to_client_buffer_full, client_to_server_buffer_full; + size_t server_to_client_buffer_size, client_to_server_buffer_size; + + sd_event_source *server_event_source, *client_event_source; + + sd_resolve_query *resolve_query; +} Connection; + +static void connection_free(Connection *c) { + assert(c); + + if (c->context) + set_remove(c->context->connections, c); + + sd_event_source_unref(c->server_event_source); + sd_event_source_unref(c->client_event_source); + + safe_close(c->server_fd); + safe_close(c->client_fd); + + safe_close_pair(c->server_to_client_buffer); + safe_close_pair(c->client_to_server_buffer); + + sd_resolve_query_unref(c->resolve_query); -#define BUFFER_SIZE 16384 -#define _cleanup_freeaddrinfo_ _cleanup_(freeaddrinfop) - -unsigned int total_clients = 0; - -DEFINE_TRIVIAL_CLEANUP_FUNC(struct addrinfo *, freeaddrinfo); - -struct proxy { - int listen_fd; - bool ignore_env; - bool remote_is_inet; - const char *remote_host; - const char *remote_service; -}; - -struct connection { - int fd; - uint32_t events; - sd_event_source *w; - struct connection *c_destination; - size_t buffer_filled_len; - size_t buffer_sent_len; - char buffer[BUFFER_SIZE]; -}; - -static void free_connection(struct connection *c) { - log_debug("Freeing fd=%d (conn %p).", c->fd, c); - sd_event_source_unref(c->w); - close_nointr_nofail(c->fd); free(c); } -static int add_event_to_connection(struct connection *c, uint32_t events) { - int r; +static void context_free(Context *context) { + sd_event_source *es; + Connection *c; - log_debug("Have revents=%d. Adding revents=%d.", c->events, events); + assert(context); - c->events |= events; + while ((es = set_steal_first(context->listen))) + sd_event_source_unref(es); - r = sd_event_source_set_io_events(c->w, c->events); - if (r < 0) { - log_error("Error %d setting revents: %s", r, strerror(-r)); - return r; - } + while ((c = set_first(context->connections))) + connection_free(c); - r = sd_event_source_set_enabled(c->w, SD_EVENT_ON); - if (r < 0) { - log_error("Error %d enabling source: %s", r, strerror(-r)); - return r; - } + set_free(context->listen); + set_free(context->connections); - return 0; + sd_event_unref(context->event); + sd_resolve_unref(context->resolve); } -static int remove_event_from_connection(struct connection *c, uint32_t events) { +static int connection_create_pipes(Connection *c, int buffer[2], size_t *sz) { int r; - log_debug("Have revents=%d. Removing revents=%d.", c->events, events); + assert(c); + assert(buffer); + assert(sz); - c->events &= ~events; + if (buffer[0] >= 0) + return 0; - r = sd_event_source_set_io_events(c->w, c->events); + r = pipe2(buffer, O_CLOEXEC|O_NONBLOCK); if (r < 0) { - log_error("Error %d setting revents: %s", r, strerror(-r)); - return r; + log_error("Failed to allocate pipe buffer: %m"); + return -errno; } - if (c->events == 0) { - r = sd_event_source_set_enabled(c->w, SD_EVENT_OFF); - if (r < 0) { - log_error("Error %d disabling source: %s", r, strerror(-r)); - return r; - } + (void) fcntl(buffer[0], F_SETPIPE_SZ, BUFFER_SIZE); + + r = fcntl(buffer[0], F_GETPIPE_SZ); + if (r < 0) { + log_error("Failed to get pipe buffer size: %m"); + return -errno; } + assert(r > 0); + *sz = r; + return 0; } -static int send_buffer(struct connection *sender) { - struct connection *receiver = sender->c_destination; - ssize_t len; - int r = 0; - - /* We cannot assume that even a partial send() indicates that - * the next send() will return EAGAIN or EWOULDBLOCK. Loop until - * it does. */ - while (sender->buffer_filled_len > sender->buffer_sent_len) { - len = send(receiver->fd, sender->buffer + sender->buffer_sent_len, sender->buffer_filled_len - sender->buffer_sent_len, 0); - log_debug("send(%d, ...)=%ld", receiver->fd, len); - if (len < 0) { - if (errno != EWOULDBLOCK && errno != EAGAIN) { - log_error("Error %d in send to fd=%d: %m", errno, receiver->fd); +static int connection_shovel( + Connection *c, + int *from, int buffer[2], int *to, + size_t *full, size_t *sz, + sd_event_source **from_source, sd_event_source **to_source) { + + bool shoveled; + + assert(c); + assert(from); + assert(buffer); + assert(buffer[0] >= 0); + assert(buffer[1] >= 0); + assert(to); + assert(full); + assert(sz); + assert(from_source); + assert(to_source); + + do { + ssize_t z; + + shoveled = false; + + if (*full < *sz && *from >= 0 && *to >= 0) { + z = splice(*from, NULL, buffer[1], NULL, *sz - *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK); + if (z > 0) { + *full += z; + shoveled = true; + } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) { + *from_source = sd_event_source_unref(*from_source); + *from = safe_close(*from); + } else if (errno != EAGAIN && errno != EINTR) { + log_error("Failed to splice: %m"); return -errno; } - else { - /* send() is in a would-block state. */ - break; + } + + if (*full > 0 && *to >= 0) { + z = splice(buffer[0], NULL, *to, NULL, *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK); + if (z > 0) { + *full -= z; + shoveled = true; + } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) { + *to_source = sd_event_source_unref(*to_source); + *to = safe_close(*to); + } else if (errno != EAGAIN && errno != EINTR) { + log_error("Failed to splice: %m"); + return -errno; } } + } while (shoveled); - /* len < 0 can't occur here. len == 0 is possible but - * undefined behavior for nonblocking send(). */ - assert(len > 0); - sender->buffer_sent_len += len; - } + return 0; +} - log_debug("send(%d, ...) completed with %lu bytes still buffered.", receiver->fd, sender->buffer_filled_len - sender->buffer_sent_len); +static int connection_enable_event_sources(Connection *c); - /* Detect a would-block state or partial send. */ - if (sender->buffer_filled_len > sender->buffer_sent_len) { +static int traffic_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) { + Connection *c = userdata; + int r; - /* If the buffer is full, disable events coming for recv. */ - if (sender->buffer_filled_len == BUFFER_SIZE) { - r = remove_event_from_connection(sender, EPOLLIN); - if (r < 0) { - log_error("Error %d disabling EPOLLIN for fd=%d: %s", r, sender->fd, strerror(-r)); - return r; - } - } + assert(s); + assert(fd >= 0); + assert(c); - /* Watch for when the recipient can be sent data again. */ - r = add_event_to_connection(receiver, EPOLLOUT); - if (r < 0) { - log_error("Error %d enabling EPOLLOUT for fd=%d: %s", r, receiver->fd, strerror(-r)); - return r; - } - log_debug("Done with recv for fd=%d. Waiting on send for fd=%d.", sender->fd, receiver->fd); - return r; - } + r = connection_shovel(c, + &c->server_fd, c->server_to_client_buffer, &c->client_fd, + &c->server_to_client_buffer_full, &c->server_to_client_buffer_size, + &c->server_event_source, &c->client_event_source); + if (r < 0) + goto quit; + + r = connection_shovel(c, + &c->client_fd, c->client_to_server_buffer, &c->server_fd, + &c->client_to_server_buffer_full, &c->client_to_server_buffer_size, + &c->client_event_source, &c->server_event_source); + if (r < 0) + goto quit; + + /* EOF on both sides? */ + if (c->server_fd == -1 && c->client_fd == -1) + goto quit; + + /* Server closed, and all data written to client? */ + if (c->server_fd == -1 && c->server_to_client_buffer_full <= 0) + goto quit; + + /* Client closed, and all data written to server? */ + if (c->client_fd == -1 && c->client_to_server_buffer_full <= 0) + goto quit; + + r = connection_enable_event_sources(c); + if (r < 0) + goto quit; + + return 1; + +quit: + connection_free(c); + return 0; /* ignore errors, continue serving */ +} - /* If we sent everything without any issues (would-block or - * partial send), the buffer is now empty. */ - sender->buffer_filled_len = 0; - sender->buffer_sent_len = 0; +static int connection_enable_event_sources(Connection *c) { + uint32_t a = 0, b = 0; + int r; + + assert(c); + + if (c->server_to_client_buffer_full > 0) + b |= EPOLLOUT; + if (c->server_to_client_buffer_full < c->server_to_client_buffer_size) + a |= EPOLLIN; + + if (c->client_to_server_buffer_full > 0) + a |= EPOLLOUT; + if (c->client_to_server_buffer_full < c->client_to_server_buffer_size) + b |= EPOLLIN; + + if (c->server_event_source) + r = sd_event_source_set_io_events(c->server_event_source, a); + else if (c->server_fd >= 0) + r = sd_event_add_io(c->context->event, &c->server_event_source, c->server_fd, a, traffic_cb, c); + else + r = 0; - /* Enable the sender's receive watcher, in case the buffer was - * full and we disabled it. */ - r = add_event_to_connection(sender, EPOLLIN); if (r < 0) { - log_error("Error %d enabling EPOLLIN for fd=%d: %s", r, sender->fd, strerror(-r)); + log_error("Failed to set up server event source: %s", strerror(-r)); return r; } - /* Disable the other side's send watcher, as we have no data to send now. */ - r = remove_event_from_connection(receiver, EPOLLOUT); + if (c->client_event_source) + r = sd_event_source_set_io_events(c->client_event_source, b); + else if (c->client_fd >= 0) + r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, b, traffic_cb, c); + else + r = 0; + if (r < 0) { - log_error("Error %d disabling EPOLLOUT for fd=%d: %s", r, receiver->fd, strerror(-r)); + log_error("Failed to set up client event source: %s", strerror(-r)); return r; } return 0; } -static int transfer_data_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) { - struct connection *c = (struct connection *) userdata; - int r = 0; - ssize_t len; - - assert(revents & (EPOLLIN | EPOLLOUT)); - assert(fd == c->fd); - assert(s == c->w); - - log_debug("Got event revents=%d from fd=%d (conn %p).", revents, fd, c); - - if (revents & EPOLLIN) { - log_debug("About to recv up to %lu bytes from fd=%d (%lu/BUFFER_SIZE).", BUFFER_SIZE - c->buffer_filled_len, fd, c->buffer_filled_len); - - /* Receive until the buffer's full, there's no more data, - * or the client/server disconnects. */ - while (c->buffer_filled_len < BUFFER_SIZE) { - len = recv(fd, c->buffer + c->buffer_filled_len, BUFFER_SIZE - c->buffer_filled_len, 0); - log_debug("recv(%d, ...)=%ld", fd, len); - if (len < 0) { - if (errno != EWOULDBLOCK && errno != EAGAIN) { - log_error("Error %d in recv from fd=%d: %m", errno, fd); - return -errno; - } - else { - /* recv() is in a blocking state. */ - break; - } - } - else if (len == 0) { - log_debug("Clean disconnection from fd=%d", fd); - total_clients--; - free_connection(c->c_destination); - free_connection(c); - return 0; - } +static int connection_complete(Connection *c) { + int r; - assert(len > 0); - log_debug("Recording that the buffer got %ld more bytes full.", len); - c->buffer_filled_len += len; - log_debug("Buffer now has %ld bytes full.", c->buffer_filled_len); - } + assert(c); - /* Try sending the data immediately. */ - return send_buffer(c); - } - else { - return send_buffer(c->c_destination); - } + r = connection_create_pipes(c, c->server_to_client_buffer, &c->server_to_client_buffer_size); + if (r < 0) + goto fail; - return r; -} + r = connection_create_pipes(c, c->client_to_server_buffer, &c->client_to_server_buffer_size); + if (r < 0) + goto fail; -/* Once sending to the server is ready, set up the real watchers. */ -static int connected_to_server_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) { - struct sd_event *e = NULL; - struct connection *c_server_to_client = (struct connection *) userdata; - struct connection *c_client_to_server = c_server_to_client->c_destination; - int r; + r = connection_enable_event_sources(c); + if (r < 0) + goto fail; - assert(revents & EPOLLOUT); + return 0; - e = sd_event_get(s); +fail: + connection_free(c); + return 0; /* ignore errors, continue serving */ +} - /* Cancel the initial write watcher for the server. */ - sd_event_source_unref(s); +static int connect_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) { + Connection *c = userdata; + socklen_t solen; + int error, r; - log_debug("Connected to server. Initializing watchers for receiving data."); + assert(s); + assert(fd >= 0); + assert(c); - /* A recv watcher for the server. */ - r = sd_event_add_io(e, c_server_to_client->fd, EPOLLIN, transfer_data_cb, c_server_to_client, &c_server_to_client->w); + solen = sizeof(error); + r = getsockopt(fd, SOL_SOCKET, SO_ERROR, &error, &solen); if (r < 0) { - log_error("Error %d creating recv watcher for fd=%d: %s", r, c_server_to_client->fd, strerror(-r)); + log_error("Failed to issue SO_ERROR: %m"); goto fail; } - c_server_to_client->events = EPOLLIN; - /* A recv watcher for the client. */ - r = sd_event_add_io(e, c_client_to_server->fd, EPOLLIN, transfer_data_cb, c_client_to_server, &c_client_to_server->w); - if (r < 0) { - log_error("Error %d creating recv watcher for fd=%d: %s", r, c_client_to_server->fd, strerror(-r)); + if (error != 0) { + log_error("Failed to connect to remote host: %s", strerror(error)); goto fail; } - c_client_to_server->events = EPOLLIN; -goto finish; + c->client_event_source = sd_event_source_unref(c->client_event_source); -fail: - free_connection(c_client_to_server); - free_connection(c_server_to_client); + return connection_complete(c); -finish: - return r; +fail: + connection_free(c); + return 0; /* ignore errors, continue serving */ } -static int get_server_connection_fd(const struct proxy *proxy) { - int server_fd; - int r = -EBADF; - int len; - - if (proxy->remote_is_inet) { - int s; - _cleanup_freeaddrinfo_ struct addrinfo *result = NULL; - struct addrinfo hints = {.ai_family = AF_UNSPEC, - .ai_socktype = SOCK_STREAM, - .ai_flags = AI_PASSIVE}; - - log_debug("Looking up address info for %s:%s", proxy->remote_host, proxy->remote_service); - s = getaddrinfo(proxy->remote_host, proxy->remote_service, &hints, &result); - if (s != 0) { - log_error("getaddrinfo error (%d): %s", s, gai_strerror(s)); - return r; - } - - if (result == NULL) { - log_error("getaddrinfo: no result"); - return r; - } +static int connection_start(Connection *c, struct sockaddr *sa, socklen_t salen) { + int r; - /* @TODO: Try connecting to all results instead of just the first. */ - server_fd = socket(result->ai_family, result->ai_socktype | SOCK_NONBLOCK, result->ai_protocol); - if (server_fd < 0) { - log_error("Error %d creating socket: %m", errno); - return r; - } + assert(c); + assert(sa); + assert(salen); - r = connect(server_fd, result->ai_addr, result->ai_addrlen); - /* Ignore EINPROGRESS errors because they're expected for a nonblocking socket. */ - if (r < 0 && errno != EINPROGRESS) { - log_error("Error %d while connecting to socket %s:%s: %m", errno, proxy->remote_host, proxy->remote_service); - return r; - } + c->client_fd = socket(sa->sa_family, SOCK_STREAM|SOCK_NONBLOCK|SOCK_CLOEXEC, 0); + if (c->client_fd < 0) { + log_error("Failed to get remote socket: %m"); + goto fail; } - else { - struct sockaddr_un remote; - server_fd = socket(AF_UNIX, SOCK_STREAM | SOCK_NONBLOCK, 0); - if (server_fd < 0) { - log_error("Error %d creating socket: %m", errno); - return -EBADFD; - } + r = connect(c->client_fd, sa, salen); + if (r < 0) { + if (errno == EINPROGRESS) { + r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, EPOLLOUT, connect_cb, c); + if (r < 0) { + log_error("Failed to add connection socket: %s", strerror(-r)); + goto fail; + } - remote.sun_family = AF_UNIX; - strncpy(remote.sun_path, proxy->remote_host, sizeof(remote.sun_path)); - len = strlen(remote.sun_path) + sizeof(remote.sun_family); - r = connect(server_fd, (struct sockaddr *) &remote, len); - if (r < 0 && errno != EINPROGRESS) { - log_error("Error %d while connecting to Unix domain socket %s: %m", errno, proxy->remote_host); - return -EBADFD; + r = sd_event_source_set_enabled(c->client_event_source, SD_EVENT_ONESHOT); + if (r < 0) { + log_error("Failed to enable oneshot event source: %s", strerror(-r)); + goto fail; + } + } else { + log_error("Failed to connect to remote host: %m"); + goto fail; } + } else { + r = connection_complete(c); + if (r < 0) + goto fail; } - log_debug("Server connection is fd=%d", server_fd); - return server_fd; + return 0; + +fail: + connection_free(c); + return 0; /* ignore errors, continue serving */ } -static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) { - struct proxy *proxy = (struct proxy *) userdata; - struct connection *c_server_to_client; - struct connection *c_client_to_server = NULL; - int r = 0; - union sockaddr_union sa; - socklen_t salen = sizeof(sa); +static int resolve_cb(sd_resolve_query *q, int ret, const struct addrinfo *ai, void *userdata) { + Connection *c = userdata; - assert(revents & EPOLLIN); + assert(q); + assert(c); - c_server_to_client = new0(struct connection, 1); - if (c_server_to_client == NULL) { - log_oom(); + if (ret != 0) { + log_error("Failed to resolve host: %s", gai_strerror(ret)); goto fail; } - c_client_to_server = new0(struct connection, 1); - if (c_client_to_server == NULL) { - log_oom(); - goto fail; + c->resolve_query = sd_resolve_query_unref(c->resolve_query); + + return connection_start(c, ai->ai_addr, ai->ai_addrlen); + +fail: + connection_free(c); + return 0; /* ignore errors, continue serving */ +} + +static int resolve_remote(Connection *c) { + + static const struct addrinfo hints = { + .ai_family = AF_UNSPEC, + .ai_socktype = SOCK_STREAM, + .ai_flags = AI_ADDRCONFIG + }; + + union sockaddr_union sa = {}; + const char *node, *service; + socklen_t salen; + int r; + + if (path_is_absolute(arg_remote_host)) { + sa.un.sun_family = AF_UNIX; + strncpy(sa.un.sun_path, arg_remote_host, sizeof(sa.un.sun_path)-1); + sa.un.sun_path[sizeof(sa.un.sun_path)-1] = 0; + + salen = offsetof(union sockaddr_union, un.sun_path) + strlen(sa.un.sun_path); + + return connection_start(c, &sa.sa, salen); } - c_server_to_client->fd = get_server_connection_fd(proxy); - if (c_server_to_client->fd < 0) { - log_error("Error initiating server connection."); - goto fail; + if (arg_remote_host[0] == '@') { + sa.un.sun_family = AF_UNIX; + sa.un.sun_path[0] = 0; + strncpy(sa.un.sun_path+1, arg_remote_host+1, sizeof(sa.un.sun_path)-2); + sa.un.sun_path[sizeof(sa.un.sun_path)-1] = 0; + + salen = offsetof(union sockaddr_union, un.sun_path) + 1 + strlen(sa.un.sun_path + 1); + + return connection_start(c, &sa.sa, salen); + } + + service = strrchr(arg_remote_host, ':'); + if (service) { + node = strndupa(arg_remote_host, service - arg_remote_host); + service ++; + } else { + node = arg_remote_host; + service = "80"; } - c_client_to_server->fd = accept4(fd, (struct sockaddr *) &sa, &salen, SOCK_NONBLOCK|SOCK_CLOEXEC); - if (c_client_to_server->fd < 0) { - log_error("Error accepting client connection."); + log_debug("Looking up address info for %s:%s", node, service); + r = sd_resolve_getaddrinfo(c->context->resolve, &c->resolve_query, node, service, &hints, resolve_cb, c); + if (r < 0) { + log_error("Failed to resolve remote host: %s", strerror(-r)); goto fail; } + return 0; + +fail: + connection_free(c); + return 0; /* ignore errors, continue serving */ +} + +static int add_connection_socket(Context *context, int fd) { + Connection *c; + int r; + + assert(context); + assert(fd >= 0); - if (sa.sa.sa_family == AF_INET || sa.sa.sa_family == AF_INET6) { - char sa_str[INET6_ADDRSTRLEN]; - const char *success; + if (set_size(context->connections) > CONNECTIONS_MAX) { + log_warning("Hit connection limit, refusing connection."); + safe_close(fd); + return 0; + } - success = inet_ntop(sa.sa.sa_family, &sa.in6.sin6_addr, sa_str, INET6_ADDRSTRLEN); - if (success == NULL) - log_warning("Error %d calling inet_ntop: %m", errno); - else - log_debug("Accepted client connection from %s as fd=%d", sa_str, c_client_to_server->fd); + r = set_ensure_allocated(&context->connections, NULL); + if (r < 0) { + log_oom(); + return 0; } - else { - log_debug("Accepted client connection (non-IP) as fd=%d", c_client_to_server->fd); + + c = new0(Connection, 1); + if (!c) { + log_oom(); + return 0; } - total_clients++; - log_debug("Client fd=%d (conn %p) successfully connected. Total clients: %u", c_client_to_server->fd, c_client_to_server, total_clients); - log_debug("Server fd=%d (conn %p) successfully initialized.", c_server_to_client->fd, c_server_to_client); + c->context = context; + c->server_fd = fd; + c->client_fd = -1; + c->server_to_client_buffer[0] = c->server_to_client_buffer[1] = -1; + c->client_to_server_buffer[0] = c->client_to_server_buffer[1] = -1; - /* Initialize watcher for send to server; this shows connectivity. */ - r = sd_event_add_io(sd_event_get(s), c_server_to_client->fd, EPOLLOUT, connected_to_server_cb, c_server_to_client, &c_server_to_client->w); + r = set_put(context->connections, c); if (r < 0) { - log_error("Error %d creating connectivity watcher for fd=%d: %s", r, c_server_to_client->fd, strerror(-r)); - goto fail; + free(c); + log_oom(); + return 0; } - /* Allow lookups of the opposite connection. */ - c_server_to_client->c_destination = c_client_to_server; - c_client_to_server->c_destination = c_server_to_client; + return resolve_remote(c); +} - goto finish; +static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) { + _cleanup_free_ char *peer = NULL; + Context *context = userdata; + int nfd = -1, r; -fail: - log_warning("Accepting a client connection or connecting to the server failed."); - free_connection(c_client_to_server); - free_connection(c_server_to_client); + assert(s); + assert(fd >= 0); + assert(revents & EPOLLIN); + assert(context); + + nfd = accept4(fd, NULL, NULL, SOCK_NONBLOCK|SOCK_CLOEXEC); + if (nfd < 0) { + if (errno != -EAGAIN) + log_warning("Failed to accept() socket: %m"); + } else { + getpeername_pretty(nfd, &peer); + log_debug("New connection from %s", strna(peer)); + + r = add_connection_socket(context, nfd); + if (r < 0) { + log_error("Failed to accept connection, ignoring: %s", strerror(-r)); + safe_close(fd); + } + } + + r = sd_event_source_set_enabled(s, SD_EVENT_ONESHOT); + if (r < 0) { + log_error("Error while re-enabling listener with ONESHOT: %s", strerror(-r)); + sd_event_exit(context->event, r); + return r; + } -finish: - /* Preserve the main loop even if a single proxy setup fails. */ return 1; } -static int run_main_loop(struct proxy *proxy) { - _cleanup_event_source_unref_ sd_event_source *w_accept = NULL; - _cleanup_event_unref_ sd_event *e = NULL; - int r = EXIT_SUCCESS; +static int add_listen_socket(Context *context, int fd) { + sd_event_source *source; + int r; + + assert(context); + assert(fd >= 0); - r = sd_event_new(&e); + r = set_ensure_allocated(&context->listen, NULL); if (r < 0) { - log_error("Failed to allocate event loop: %s", strerror(-r)); + log_oom(); return r; } - r = fd_nonblock(proxy->listen_fd, true); + r = sd_is_socket(fd, 0, SOCK_STREAM, 1); if (r < 0) { - log_error("Failed to make listen file descriptor nonblocking: %s", strerror(-r)); + log_error("Failed to determine socket type: %s", strerror(-r)); return r; } + if (r == 0) { + log_error("Passed in socket is not a stream socket."); + return -EINVAL; + } - log_debug("Initializing main listener fd=%d", proxy->listen_fd); - - r = sd_event_add_io(e, proxy->listen_fd, EPOLLIN, accept_cb, proxy, &w_accept); + r = fd_nonblock(fd, true); if (r < 0) { - log_error("Failed to add event IO source: %s", strerror(-r)); + log_error("Failed to mark file descriptor non-blocking: %s", strerror(-r)); return r; } - log_debug("Initialized main listener. Entering loop."); - - return sd_event_loop(e); -} + r = sd_event_add_io(context->event, &source, fd, EPOLLIN, accept_cb, context); + if (r < 0) { + log_error("Failed to add event source: %s", strerror(-r)); + return r; + } -static int help(void) { + r = set_put(context->listen, source); + if (r < 0) { + log_error("Failed to add source to set: %s", strerror(-r)); + sd_event_source_unref(source); + return r; + } - printf("%s hostname-or-ip port-or-service\n" - "%s unix-domain-socket-path\n\n" - "Inherit a socket. Bidirectionally proxy.\n\n" - " -h --help Show this help\n" - " --version Print version and exit\n" - " --ignore-env Ignore expected systemd environment\n", - program_invocation_short_name, - program_invocation_short_name); + /* Set the watcher to oneshot in case other processes are also + * watching to accept(). */ + r = sd_event_source_set_enabled(source, SD_EVENT_ONESHOT); + if (r < 0) { + log_error("Failed to enable oneshot mode: %s", strerror(-r)); + return r; + } return 0; } -static void version(void) { - puts(PACKAGE_STRING " socket-proxyd"); +static void help(void) { + printf("%1$s [HOST:PORT]\n" + "%1$s [SOCKET]\n\n" + "Bidirectionally proxy local sockets to another (possibly remote) socket.\n\n" + " -h --help Show this help\n" + " --version Show package version\n", + program_invocation_short_name); } -static int parse_argv(int argc, char *argv[], struct proxy *p) { +static int parse_argv(int argc, char *argv[]) { enum { ARG_VERSION = 0x100, @@ -481,8 +608,7 @@ static int parse_argv(int argc, char *argv[], struct proxy *p) { static const struct option options[] = { { "help", no_argument, NULL, 'h' }, { "version", no_argument, NULL, ARG_VERSION }, - { "ignore-env", no_argument, NULL, ARG_IGNORE_ENV}, - { NULL, 0, NULL, 0 } + {} }; int c; @@ -490,7 +616,7 @@ static int parse_argv(int argc, char *argv[], struct proxy *p) { assert(argc >= 0); assert(argv); - while ((c = getopt_long(argc, argv, "h", options, NULL)) >= 0) { + while ((c = getopt_long(argc, argv, "h", options, NULL)) >= 0) switch (c) { @@ -498,92 +624,88 @@ static int parse_argv(int argc, char *argv[], struct proxy *p) { help(); return 0; - case '?': - return -EINVAL; - case ARG_VERSION: - version(); + puts(PACKAGE_STRING); + puts(SYSTEMD_FEATURES); return 0; - case ARG_IGNORE_ENV: - p->ignore_env = true; - continue; + case '?': + return -EINVAL; default: - log_error("Unknown option code %c", c); - return -EINVAL; + assert_not_reached("Unhandled option"); } - } - if (optind + 1 != argc && optind + 2 != argc) { - log_error("Incorrect number of positional arguments."); - help(); + if (optind >= argc) { + log_error("Not enough parameters."); return -EINVAL; } - p->remote_host = argv[optind]; - assert(p->remote_host); - - p->remote_is_inet = p->remote_host[0] != '/'; - - if (optind == argc - 2) { - if (!p->remote_is_inet) { - log_error("A port or service is not allowed for Unix socket destinations."); - help(); - return -EINVAL; - } - p->remote_service = argv[optind + 1]; - assert(p->remote_service); - } else if (p->remote_is_inet) { - log_error("A port or service is required for IP destinations."); - help(); + if (argc != optind+1) { + log_error("Too many parameters."); return -EINVAL; } + arg_remote_host = argv[optind]; return 1; } int main(int argc, char *argv[]) { - struct proxy p = {}; - int r; + Context context = {}; + int r, n, fd; log_parse_environment(); log_open(); - r = parse_argv(argc, argv, &p); + r = parse_argv(argc, argv); if (r <= 0) goto finish; - p.listen_fd = SD_LISTEN_FDS_START; + r = sd_event_default(&context.event); + if (r < 0) { + log_error("Failed to allocate event loop: %s", strerror(-r)); + goto finish; + } - if (!p.ignore_env) { - int n; - n = sd_listen_fds(1); - if (n == 0) { - log_error("Found zero inheritable sockets. Are you sure this is running as a socket-activated service?"); - r = EXIT_FAILURE; - goto finish; - } else if (n < 0) { - log_error("Error %d while finding inheritable sockets: %s", n, strerror(-n)); - r = EXIT_FAILURE; - goto finish; - } else if (n > 1) { - log_error("Can't listen on more than one socket."); - r = EXIT_FAILURE; - goto finish; - } + r = sd_resolve_default(&context.resolve); + if (r < 0) { + log_error("Failed to allocate resolver: %s", strerror(-r)); + goto finish; } - r = sd_is_socket(p.listen_fd, 0, SOCK_STREAM, 1); + r = sd_resolve_attach_event(context.resolve, context.event, 0); if (r < 0) { - log_error("Error %d while checking inherited socket: %s", r, strerror(-r)); + log_error("Failed to attach resolver: %s", strerror(-r)); goto finish; } - log_info("Starting the socket activation proxy with listener fd=%d.", p.listen_fd); + sd_event_set_watchdog(context.event, true); - r = run_main_loop(&p); + n = sd_listen_fds(1); + if (n < 0) { + log_error("Failed to receive sockets from parent."); + r = n; + goto finish; + } else if (n == 0) { + log_error("Didn't get any sockets passed in."); + r = -EINVAL; + goto finish; + } + + for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + n; fd++) { + r = add_listen_socket(&context, fd); + if (r < 0) + goto finish; + } + + r = sd_event_loop(context.event); + if (r < 0) { + log_error("Failed to run event loop: %s", strerror(-r)); + goto finish; + } finish: + context_free(&context); + return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS; }