chiark / gitweb /
socket-proxy: actually properly keep track of connections
authorLennart Poettering <lennart@poettering.net>
Thu, 7 Nov 2013 15:53:14 +0000 (16:53 +0100)
committerLennart Poettering <lennart@poettering.net>
Thu, 7 Nov 2013 15:53:26 +0000 (16:53 +0100)
src/socket-proxy/socket-proxyd.c

index b6a7f1c1ba688ebba63fbd47cc0421f1376ad85f..56e660de57cb6b50cddd7b2d4c4a64ce22309c55 100644 (file)
@@ -53,6 +53,8 @@ typedef struct Context {
 } Context;
 
 typedef struct Connection {
 } Context;
 
 typedef struct Connection {
+        Context *context;
+
         int server_fd, client_fd;
         int server_to_client_buffer[2]; /* a pipe */
         int client_to_server_buffer[2]; /* a pipe */
         int server_fd, client_fd;
         int server_to_client_buffer[2]; /* a pipe */
         int client_to_server_buffer[2]; /* a pipe */
@@ -68,6 +70,9 @@ static const char *arg_remote_host = NULL;
 static void connection_free(Connection *c) {
         assert(c);
 
 static void connection_free(Connection *c) {
         assert(c);
 
+        if (c->context)
+                set_remove(c->context->connections, c);
+
         sd_event_source_unref(c->server_event_source);
         sd_event_source_unref(c->client_event_source);
 
         sd_event_source_unref(c->server_event_source);
         sd_event_source_unref(c->client_event_source);
 
@@ -91,7 +96,7 @@ static void context_free(Context *context) {
         while ((es = set_steal_first(context->listen)))
                 sd_event_source_unref(es);
 
         while ((es = set_steal_first(context->listen)))
                 sd_event_source_unref(es);
 
-        while ((c = set_steal_first(context->connections)))
+        while ((c = set_first(context->connections)))
                 connection_free(c);
 
         set_free(context->listen);
                 connection_free(c);
 
         set_free(context->listen);
@@ -403,11 +408,18 @@ static int add_connection_socket(Context *context, sd_event *event, int fd) {
         if (!c)
                 return log_oom();
 
         if (!c)
                 return log_oom();
 
+        c->context = context;
         c->server_fd = fd;
         c->client_fd = -1;
         c->server_to_client_buffer[0] = c->server_to_client_buffer[1] = -1;
         c->client_to_server_buffer[0] = c->client_to_server_buffer[1] = -1;
 
         c->server_fd = fd;
         c->client_fd = -1;
         c->server_to_client_buffer[0] = c->server_to_client_buffer[1] = -1;
         c->client_to_server_buffer[0] = c->client_to_server_buffer[1] = -1;
 
+        r = set_put(context->connections, c);
+        if (r < 0) {
+                free(c);
+                return log_oom();
+        }
+
         r = get_remote_sockaddr(&sa, &salen);
         if (r < 0)
                 goto fail;
         r = get_remote_sockaddr(&sa, &salen);
         if (r < 0)
                 goto fail;
@@ -491,8 +503,6 @@ static int add_listen_socket(Context *context, sd_event *event, int fd) {
         assert(event);
         assert(fd >= 0);
 
         assert(event);
         assert(fd >= 0);
 
-        log_info("Listening on %i", fd);
-
         r = set_ensure_allocated(&context->listen, trivial_hash_func, trivial_compare_func);
         if (r < 0) {
                 log_oom();
         r = set_ensure_allocated(&context->listen, trivial_hash_func, trivial_compare_func);
         if (r < 0) {
                 log_oom();