X-Git-Url: http://www.chiark.greenend.org.uk/ucgi/~ianmdlvl/git?p=elogind.git;a=blobdiff_plain;f=src%2Fsocket-proxy%2Fsocket-proxyd.c;h=f6e6672cdfc5da87a5feff8deb3b49c5587d30b1;hp=1c64c0e2e5765732ee52ba034d5cb646cc694147;hb=e70bc43cdf75b36e7ad3d29e9a6f8ee1461e7d5e;hpb=8569a77629949b7818d00eba8eea1d05e2d1fc32 diff --git a/src/socket-proxy/socket-proxyd.c b/src/socket-proxy/socket-proxyd.c index 1c64c0e2e..f6e6672cd 100644 --- a/src/socket-proxy/socket-proxyd.c +++ b/src/socket-proxy/socket-proxyd.c @@ -33,6 +33,7 @@ #include "sd-daemon.h" #include "sd-event.h" +#include "sd-resolve.h" #include "log.h" #include "socket-util.h" #include "util.h" @@ -44,15 +45,19 @@ #define BUFFER_SIZE (256 * 1024) #define CONNECTIONS_MAX 256 -#define _cleanup_freeaddrinfo_ _cleanup_(freeaddrinfop) -DEFINE_TRIVIAL_CLEANUP_FUNC(struct addrinfo *, freeaddrinfo); +static const char *arg_remote_host = NULL; typedef struct Context { + sd_event *event; + sd_resolve *resolve; + Set *listen; Set *connections; } Context; typedef struct Connection { + Context *context; + int server_fd, client_fd; int server_to_client_buffer[2]; /* a pipe */ int client_to_server_buffer[2]; /* a pipe */ @@ -61,31 +66,26 @@ typedef struct Connection { size_t server_to_client_buffer_size, client_to_server_buffer_size; sd_event_source *server_event_source, *client_event_source; -} Connection; - -union sockaddr_any { - struct sockaddr sa; - struct sockaddr_un un; - struct sockaddr_in in; - struct sockaddr_in6 in6; - struct sockaddr_storage storage; -}; -static const char *arg_remote_host = NULL; + sd_resolve_query *resolve_query; +} Connection; static void connection_free(Connection *c) { assert(c); + if (c->context) + set_remove(c->context->connections, c); + sd_event_source_unref(c->server_event_source); sd_event_source_unref(c->client_event_source); - if (c->server_fd >= 0) - close_nointr_nofail(c->server_fd); - if (c->client_fd >= 0) - close_nointr_nofail(c->client_fd); + safe_close(c->server_fd); + safe_close(c->client_fd); + + safe_close_pair(c->server_to_client_buffer); + safe_close_pair(c->client_to_server_buffer); - close_pipe(c->server_to_client_buffer); - close_pipe(c->client_to_server_buffer); + sd_resolve_query_unref(c->resolve_query); free(c); } @@ -99,71 +99,14 @@ static void context_free(Context *context) { while ((es = set_steal_first(context->listen))) sd_event_source_unref(es); - while ((c = set_steal_first(context->connections))) + while ((c = set_first(context->connections))) connection_free(c); set_free(context->listen); set_free(context->connections); -} - -static int get_remote_sockaddr(union sockaddr_any *sa, socklen_t *salen) { - int r; - - assert(sa); - assert(salen); - - if (path_is_absolute(arg_remote_host)) { - sa->un.sun_family = AF_UNIX; - strncpy(sa->un.sun_path, arg_remote_host, sizeof(sa->un.sun_path)-1); - sa->un.sun_path[sizeof(sa->un.sun_path)-1] = 0; - - *salen = offsetof(union sockaddr_any, un.sun_path) + strlen(sa->un.sun_path); - - } else if (arg_remote_host[0] == '@') { - sa->un.sun_family = AF_UNIX; - sa->un.sun_path[0] = 0; - strncpy(sa->un.sun_path+1, arg_remote_host+1, sizeof(sa->un.sun_path)-2); - sa->un.sun_path[sizeof(sa->un.sun_path)-1] = 0; - - *salen = offsetof(union sockaddr_any, un.sun_path) + 1 + strlen(sa->un.sun_path + 1); - - } else { - _cleanup_freeaddrinfo_ struct addrinfo *result = NULL; - const char *node, *service; - - struct addrinfo hints = { - .ai_family = AF_UNSPEC, - .ai_socktype = SOCK_STREAM, - .ai_flags = AI_ADDRCONFIG - }; - - service = strrchr(arg_remote_host, ':'); - if (service) { - node = strndupa(arg_remote_host, service - arg_remote_host); - service ++; - } else { - node = arg_remote_host; - service = "80"; - } - log_debug("Looking up address info for %s:%s", node, service); - r = getaddrinfo(node, service, &hints, &result); - if (r != 0) { - log_error("Failed to resolve host %s:%s: %s", node, service, gai_strerror(r)); - return -EHOSTUNREACH; - } - - assert(result); - if (result->ai_addrlen > sizeof(union sockaddr_any)) { - log_error("Address too long."); - return -E2BIG; - } - - memcpy(sa, result->ai_addr, result->ai_addrlen); - *salen = result->ai_addrlen; - } - - return 0; + sd_event_unref(context->event); + sd_resolve_unref(context->resolve); } static int connection_create_pipes(Connection *c, int buffer[2], size_t *sz) { @@ -227,8 +170,7 @@ static int connection_shovel( shoveled = true; } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) { *from_source = sd_event_source_unref(*from_source); - close_nointr_nofail(*from); - *from = -1; + *from = safe_close(*from); } else if (errno != EAGAIN && errno != EINTR) { log_error("Failed to splice: %m"); return -errno; @@ -242,8 +184,7 @@ static int connection_shovel( shoveled = true; } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) { *to_source = sd_event_source_unref(*to_source); - close_nointr_nofail(*to); - *to = -1; + *to = safe_close(*to); } else if (errno != EAGAIN && errno != EINTR) { log_error("Failed to splice: %m"); return -errno; @@ -254,7 +195,7 @@ static int connection_shovel( return 0; } -static int connection_enable_event_sources(Connection *c, sd_event *event); +static int connection_enable_event_sources(Connection *c); static int traffic_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) { Connection *c = userdata; @@ -290,7 +231,7 @@ static int traffic_cb(sd_event_source *s, int fd, uint32_t revents, void *userda if (c->client_fd == -1 && c->client_to_server_buffer_full <= 0) goto quit; - r = connection_enable_event_sources(c, sd_event_get(s)); + r = connection_enable_event_sources(c); if (r < 0) goto quit; @@ -301,12 +242,11 @@ quit: return 0; /* ignore errors, continue serving */ } -static int connection_enable_event_sources(Connection *c, sd_event *event) { +static int connection_enable_event_sources(Connection *c) { uint32_t a = 0, b = 0; int r; assert(c); - assert(event); if (c->server_to_client_buffer_full > 0) b |= EPOLLOUT; @@ -321,7 +261,7 @@ static int connection_enable_event_sources(Connection *c, sd_event *event) { if (c->server_event_source) r = sd_event_source_set_io_events(c->server_event_source, a); else if (c->server_fd >= 0) - r = sd_event_add_io(event, c->server_fd, a, traffic_cb, c, &c->server_event_source); + r = sd_event_add_io(c->context->event, &c->server_event_source, c->server_fd, a, traffic_cb, c); else r = 0; @@ -333,18 +273,42 @@ static int connection_enable_event_sources(Connection *c, sd_event *event) { if (c->client_event_source) r = sd_event_source_set_io_events(c->client_event_source, b); else if (c->client_fd >= 0) - r = sd_event_add_io(event, c->client_fd, b, traffic_cb, c, &c->client_event_source); + r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, b, traffic_cb, c); else r = 0; if (r < 0) { - log_error("Failed to set up server event source: %s", strerror(-r)); + log_error("Failed to set up client event source: %s", strerror(-r)); return r; } return 0; } +static int connection_complete(Connection *c) { + int r; + + assert(c); + + r = connection_create_pipes(c, c->server_to_client_buffer, &c->server_to_client_buffer_size); + if (r < 0) + goto fail; + + r = connection_create_pipes(c, c->client_to_server_buffer, &c->client_to_server_buffer_size); + if (r < 0) + goto fail; + + r = connection_enable_event_sources(c); + if (r < 0) + goto fail; + + return 0; + +fail: + connection_free(c); + return 0; /* ignore errors, continue serving */ +} + static int connect_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) { Connection *c = userdata; socklen_t solen; @@ -368,17 +332,126 @@ static int connect_cb(sd_event_source *s, int fd, uint32_t revents, void *userda c->client_event_source = sd_event_source_unref(c->client_event_source); - r = connection_create_pipes(c, c->server_to_client_buffer, &c->server_to_client_buffer_size); - if (r < 0) + return connection_complete(c); + +fail: + connection_free(c); + return 0; /* ignore errors, continue serving */ +} + +static int connection_start(Connection *c, struct sockaddr *sa, socklen_t salen) { + int r; + + assert(c); + assert(sa); + assert(salen); + + c->client_fd = socket(sa->sa_family, SOCK_STREAM|SOCK_NONBLOCK|SOCK_CLOEXEC, 0); + if (c->client_fd < 0) { + log_error("Failed to get remote socket: %m"); goto fail; + } - r = connection_create_pipes(c, c->client_to_server_buffer, &c->client_to_server_buffer_size); - if (r < 0) + r = connect(c->client_fd, sa, salen); + if (r < 0) { + if (errno == EINPROGRESS) { + r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, EPOLLOUT, connect_cb, c); + if (r < 0) { + log_error("Failed to add connection socket: %s", strerror(-r)); + goto fail; + } + + r = sd_event_source_set_enabled(c->client_event_source, SD_EVENT_ONESHOT); + if (r < 0) { + log_error("Failed to enable oneshot event source: %s", strerror(-r)); + goto fail; + } + } else { + log_error("Failed to connect to remote host: %m"); + goto fail; + } + } else { + r = connection_complete(c); + if (r < 0) + goto fail; + } + + return 0; + +fail: + connection_free(c); + return 0; /* ignore errors, continue serving */ +} + +static int resolve_cb(sd_resolve_query *q, int ret, const struct addrinfo *ai, void *userdata) { + Connection *c = userdata; + + assert(q); + assert(c); + + if (ret != 0) { + log_error("Failed to resolve host: %s", gai_strerror(ret)); goto fail; + } - r = connection_enable_event_sources(c, sd_event_get(s)); - if (r < 0) + c->resolve_query = sd_resolve_query_unref(c->resolve_query); + + return connection_start(c, ai->ai_addr, ai->ai_addrlen); + +fail: + connection_free(c); + return 0; /* ignore errors, continue serving */ +} + +static int resolve_remote(Connection *c) { + + static const struct addrinfo hints = { + .ai_family = AF_UNSPEC, + .ai_socktype = SOCK_STREAM, + .ai_flags = AI_ADDRCONFIG + }; + + union sockaddr_union sa = {}; + const char *node, *service; + socklen_t salen; + int r; + + if (path_is_absolute(arg_remote_host)) { + sa.un.sun_family = AF_UNIX; + strncpy(sa.un.sun_path, arg_remote_host, sizeof(sa.un.sun_path)-1); + sa.un.sun_path[sizeof(sa.un.sun_path)-1] = 0; + + salen = offsetof(union sockaddr_union, un.sun_path) + strlen(sa.un.sun_path); + + return connection_start(c, &sa.sa, salen); + } + + if (arg_remote_host[0] == '@') { + sa.un.sun_family = AF_UNIX; + sa.un.sun_path[0] = 0; + strncpy(sa.un.sun_path+1, arg_remote_host+1, sizeof(sa.un.sun_path)-2); + sa.un.sun_path[sizeof(sa.un.sun_path)-1] = 0; + + salen = offsetof(union sockaddr_union, un.sun_path) + 1 + strlen(sa.un.sun_path + 1); + + return connection_start(c, &sa.sa, salen); + } + + service = strrchr(arg_remote_host, ':'); + if (service) { + node = strndupa(arg_remote_host, service - arg_remote_host); + service ++; + } else { + node = arg_remote_host; + service = "80"; + } + + log_debug("Looking up address info for %s:%s", node, service); + r = sd_resolve_getaddrinfo(c->context->resolve, &c->resolve_query, node, service, &hints, resolve_cb, c); + if (r < 0) { + log_error("Failed to resolve remote host: %s", strerror(-r)); goto fail; + } return 0; @@ -387,71 +460,49 @@ fail: return 0; /* ignore errors, continue serving */ } -static int add_connection_socket(Context *context, sd_event *event, int fd) { - union sockaddr_any sa = {}; - socklen_t salen; +static int add_connection_socket(Context *context, int fd) { Connection *c; int r; assert(context); - assert(event); assert(fd >= 0); if (set_size(context->connections) > CONNECTIONS_MAX) { log_warning("Hit connection limit, refusing connection."); - close_nointr_nofail(fd); + safe_close(fd); return 0; } r = set_ensure_allocated(&context->connections, trivial_hash_func, trivial_compare_func); - if (r < 0) - return log_oom(); + if (r < 0) { + log_oom(); + return 0; + } c = new0(Connection, 1); - if (!c) - return log_oom(); + if (!c) { + log_oom(); + return 0; + } + c->context = context; c->server_fd = fd; c->client_fd = -1; c->server_to_client_buffer[0] = c->server_to_client_buffer[1] = -1; c->client_to_server_buffer[0] = c->client_to_server_buffer[1] = -1; - r = get_remote_sockaddr(&sa, &salen); - if (r < 0) - goto fail; - - c->client_fd = socket(sa.sa.sa_family, SOCK_STREAM|SOCK_NONBLOCK|SOCK_CLOEXEC, 0); - if (c->client_fd < 0) { - log_error("Failed to get remote socket: %m"); - goto fail; - } - - r = connect(c->client_fd, &sa.sa, salen); + r = set_put(context->connections, c); if (r < 0) { - if (errno == EINPROGRESS) { - r = sd_event_add_io(event, c->client_fd, EPOLLOUT, connect_cb, c, &c->client_event_source); - if (r < 0) { - log_error("Failed to add connection socket: %s", strerror(-r)); - goto fail; - } - } else { - log_error("Failed to connect to remote host: %m"); - goto fail; - } - } else { - r = connection_enable_event_sources(c, event); - if (r < 0) - goto fail; + free(c); + log_oom(); + return 0; } - return 0; - -fail: - connection_free(c); - return 0; /* ignore non-OOM errors, continue serving */ + return resolve_remote(c); } static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) { + _cleanup_free_ char *peer = NULL; Context *context = userdata; int nfd = -1, r; @@ -461,40 +512,37 @@ static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdat assert(context); nfd = accept4(fd, NULL, NULL, SOCK_NONBLOCK|SOCK_CLOEXEC); - if (nfd >= 0) { - _cleanup_free_ char *peer = NULL; - + if (nfd < 0) { + if (errno != -EAGAIN) + log_warning("Failed to accept() socket: %m"); + } else { getpeername_pretty(nfd, &peer); log_debug("New connection from %s", strna(peer)); - r = add_connection_socket(context, sd_event_get(s), nfd); + r = add_connection_socket(context, nfd); if (r < 0) { - close_nointr_nofail(fd); - return r; + log_error("Failed to accept connection, ignoring: %s", strerror(-r)); + safe_close(fd); } - - } else if (errno != -EAGAIN) - log_warning("Failed to accept() socket: %m"); + } r = sd_event_source_set_enabled(s, SD_EVENT_ONESHOT); if (r < 0) { - log_error("Error %d while re-enabling listener with ONESHOT: %s", r, strerror(-r)); + log_error("Error while re-enabling listener with ONESHOT: %s", strerror(-r)); + sd_event_exit(context->event, r); return r; } return 1; } -static int add_listen_socket(Context *context, sd_event *event, int fd) { +static int add_listen_socket(Context *context, int fd) { sd_event_source *source; int r; assert(context); - assert(event); assert(fd >= 0); - log_info("Listening on %i", fd); - r = set_ensure_allocated(&context->listen, trivial_hash_func, trivial_compare_func); if (r < 0) { log_oom(); @@ -517,7 +565,7 @@ static int add_listen_socket(Context *context, sd_event *event, int fd) { return r; } - r = sd_event_add_io(event, fd, EPOLLIN, accept_cb, context, &source); + r = sd_event_add_io(context->event, &source, fd, EPOLLIN, accept_cb, context); if (r < 0) { log_error("Failed to add event source: %s", strerror(-r)); return r; @@ -541,17 +589,13 @@ static int add_listen_socket(Context *context, sd_event *event, int fd) { return 0; } -static int help(void) { - - printf("%s [HOST:PORT]\n" - "%s [SOCKET]\n\n" +static void help(void) { + printf("%1$s [HOST:PORT]\n" + "%1$s [SOCKET]\n\n" "Bidirectionally proxy local sockets to another (possibly remote) socket.\n\n" " -h --help Show this help\n" " --version Show package version\n", - program_invocation_short_name, program_invocation_short_name); - - return 0; } static int parse_argv(int argc, char *argv[]) { @@ -572,12 +616,13 @@ static int parse_argv(int argc, char *argv[]) { assert(argc >= 0); assert(argv); - while ((c = getopt_long(argc, argv, "h", options, NULL)) >= 0) { + while ((c = getopt_long(argc, argv, "h", options, NULL)) >= 0) switch (c) { case 'h': - return help(); + help(); + return 0; case ARG_VERSION: puts(PACKAGE_STRING); @@ -590,7 +635,6 @@ static int parse_argv(int argc, char *argv[]) { default: assert_not_reached("Unhandled option"); } - } if (optind >= argc) { log_error("Not enough parameters."); @@ -607,7 +651,6 @@ static int parse_argv(int argc, char *argv[]) { } int main(int argc, char *argv[]) { - _cleanup_event_unref_ sd_event *event = NULL; Context context = {}; int r, n, fd; @@ -618,12 +661,26 @@ int main(int argc, char *argv[]) { if (r <= 0) goto finish; - r = sd_event_new(&event); + r = sd_event_default(&context.event); if (r < 0) { log_error("Failed to allocate event loop: %s", strerror(-r)); goto finish; } + r = sd_resolve_default(&context.resolve); + if (r < 0) { + log_error("Failed to allocate resolver: %s", strerror(-r)); + goto finish; + } + + r = sd_resolve_attach_event(context.resolve, context.event, 0); + if (r < 0) { + log_error("Failed to attach resolver: %s", strerror(-r)); + goto finish; + } + + sd_event_set_watchdog(context.event, true); + n = sd_listen_fds(1); if (n < 0) { log_error("Failed to receive sockets from parent."); @@ -636,12 +693,12 @@ int main(int argc, char *argv[]) { } for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + n; fd++) { - r = add_listen_socket(&context, event, fd); + r = add_listen_socket(&context, fd); if (r < 0) goto finish; } - r = sd_event_loop(event); + r = sd_event_loop(context.event); if (r < 0) { log_error("Failed to run event loop: %s", strerror(-r)); goto finish;