chiark / gitweb /
sd-resolve: use different element of union to make code more readable
[elogind.git] / src / libsystemd / sd-resolve / sd-resolve.c
index cb8e34e3688e85ba53df331d9b7d329facd7f34b..8aec75aa47eaab996ec0c7ab6a828fbc716f4892 100644 (file)
   along with systemd; If not, see <http://www.gnu.org/licenses/>.
 ***/
 
-#include <assert.h>
-#include <fcntl.h>
 #include <signal.h>
 #include <unistd.h>
-#include <sys/select.h>
 #include <stdio.h>
 #include <string.h>
 #include <stdlib.h>
 #include <errno.h>
-#include <sys/wait.h>
-#include <sys/types.h>
-#include <pwd.h>
-#include <netinet/in.h>
-#include <arpa/nameser.h>
 #include <resolv.h>
-#include <dirent.h>
-#include <sys/time.h>
-#include <sys/resource.h>
 #include <stdint.h>
 #include <pthread.h>
 #include <sys/prctl.h>
-#include <sys/poll.h>
+#include <poll.h>
 
 #include "util.h"
 #include "list.h"
@@ -85,9 +74,9 @@ struct sd_resolve {
         pthread_t workers[WORKERS_MAX];
         unsigned n_valid_workers;
 
-        unsigned current_id, current_index;
-        sd_resolve_query* queries[QUERIES_MAX];
-        unsigned n_queries, n_done;
+        unsigned current_id;
+        sd_resolve_query* query_array[QUERIES_MAX];
+        unsigned n_queries, n_done, n_outstanding;
 
         sd_event_source *event_source;
         sd_event *event;
@@ -96,6 +85,8 @@ struct sd_resolve {
 
         sd_resolve **default_resolve_ptr;
         pid_t tid;
+
+        LIST_HEAD(sd_resolve_query, queries);
 };
 
 struct sd_resolve_query {
@@ -105,6 +96,7 @@ struct sd_resolve_query {
 
         QueryType type:4;
         bool done:1;
+        bool floating:1;
         unsigned id;
 
         int ret;
@@ -121,6 +113,8 @@ struct sd_resolve_query {
         };
 
         void *userdata;
+
+        LIST_FIELDS(sd_resolve_query, queries);
 };
 
 typedef struct RHeader {
@@ -200,6 +194,8 @@ static int getaddrinfo_done(sd_resolve_query* q);
 static int getnameinfo_done(sd_resolve_query *q);
 static int res_query_done(sd_resolve_query* q);
 
+static void resolve_query_disconnect(sd_resolve_query *q);
+
 #define RESOLVE_DONT_DESTROY(resolve) \
         _cleanup_resolve_unref_ _unused_ sd_resolve *_dont_destroy_##resolve = sd_resolve_ref(resolve)
 
@@ -454,7 +450,7 @@ static int handle_request(int out_fd, const Packet *packet, size_t length) {
                  assert(length >= sizeof(ResRequest));
                  assert(length == sizeof(ResRequest) + res_req->dname_len);
 
-                 dname = (const char *) req + sizeof(ResRequest);
+                 dname = (const char *) res_req + sizeof(ResRequest);
 
                  if (req->type == REQUEST_RES_QUERY)
                          ret = res_query(dname, res_req->class, res_req->type, (unsigned char *) &answer, BUFSIZE);
@@ -519,7 +515,7 @@ static int start_threads(sd_resolve *resolve, unsigned extra) {
         unsigned n;
         int r;
 
-        n = resolve->n_queries + extra - resolve->n_done;
+        n = resolve->n_outstanding + extra;
         n = CLAMP(n, WORKERS_MIN, WORKERS_MAX);
 
         while (resolve->n_valid_workers < n) {
@@ -630,10 +626,17 @@ _public_ int sd_resolve_get_tid(sd_resolve *resolve, pid_t *tid) {
 
 static void resolve_free(sd_resolve *resolve) {
         PROTECT_ERRNO;
+        sd_resolve_query *q;
         unsigned i;
 
         assert(resolve);
 
+        while ((q = resolve->queries)) {
+                assert(q->floating);
+                resolve_query_disconnect(q);
+                sd_resolve_query_unref(q);
+        }
+
         if (resolve->default_resolve_ptr)
                 *(resolve->default_resolve_ptr) = NULL;
 
@@ -719,7 +722,7 @@ static sd_resolve_query *lookup_query(sd_resolve *resolve, unsigned id) {
 
         assert(resolve);
 
-        q = resolve->queries[id % QUERIES_MAX];
+        q = resolve->query_array[id % QUERIES_MAX];
         if (q)
                 if (q->id == id)
                         return q;
@@ -737,7 +740,7 @@ static int complete_query(sd_resolve *resolve, sd_resolve_query *q) {
         q->done = true;
         resolve->n_done ++;
 
-        resolve->current = q;
+        resolve->current = sd_resolve_query_ref(q);
 
         switch (q->type) {
 
@@ -760,6 +763,13 @@ static int complete_query(sd_resolve *resolve, sd_resolve_query *q) {
 
         resolve->current = NULL;
 
+        if (q->floating) {
+                resolve_query_disconnect(q);
+                sd_resolve_query_unref(q);
+        }
+
+        sd_resolve_query_unref(q);
+
         return r;
 }
 
@@ -833,6 +843,9 @@ static int handle_response(sd_resolve *resolve, const Packet *packet, size_t len
                 return 0;
         }
 
+        assert(resolve->n_outstanding > 0);
+        resolve->n_outstanding--;
+
         q = lookup_query(resolve, resp->id);
         if (!q)
                 return 0;
@@ -989,7 +1002,7 @@ _public_ int sd_resolve_wait(sd_resolve *resolve, uint64_t timeout_usec) {
         return sd_resolve_process(resolve);
 }
 
-static int alloc_query(sd_resolve *resolve, sd_resolve_query **_q) {
+static int alloc_query(sd_resolve *resolve, bool floating, sd_resolve_query **_q) {
         sd_resolve_query *q;
         int r;
 
@@ -1003,21 +1016,22 @@ static int alloc_query(sd_resolve *resolve, sd_resolve_query **_q) {
         if (r < 0)
                 return r;
 
-        while (resolve->queries[resolve->current_index]) {
-                resolve->current_index++;
+        while (resolve->query_array[resolve->current_id % QUERIES_MAX])
                 resolve->current_id++;
 
-                resolve->current_index %= QUERIES_MAX;
-        }
-
-        q = resolve->queries[resolve->current_index] = new0(sd_resolve_query, 1);
+        q = resolve->query_array[resolve->current_id % QUERIES_MAX] = new0(sd_resolve_query, 1);
         if (!q)
                 return -ENOMEM;
 
         q->n_ref = 1;
-        q->resolve = sd_resolve_ref(resolve);
-        q->id = resolve->current_id;
+        q->resolve = resolve;
+        q->floating = floating;
+        q->id = resolve->current_id++;
+
+        if (!floating)
+                sd_resolve_ref(resolve);
 
+        LIST_PREPEND(queries, resolve->queries, q);
         resolve->n_queries++;
 
         *_q = q;
@@ -1038,12 +1052,11 @@ _public_ int sd_resolve_getaddrinfo(
         int r;
 
         assert_return(resolve, -EINVAL);
-        assert_return(_q, -EINVAL);
         assert_return(node || service, -EINVAL);
         assert_return(callback, -EINVAL);
         assert_return(!resolve_pid_changed(resolve), -ECHILD);
 
-        r = alloc_query(resolve, &q);
+        r = alloc_query(resolve, !_q, &q);
         if (r < 0)
                 return r;
 
@@ -1078,7 +1091,11 @@ _public_ int sd_resolve_getaddrinfo(
                 return -errno;
         }
 
-        *_q = q;
+        resolve->n_outstanding++;
+
+        if (_q)
+                *_q = q;
+
         return 0;
 }
 
@@ -1109,7 +1126,6 @@ _public_ int sd_resolve_getnameinfo(
         int r;
 
         assert_return(resolve, -EINVAL);
-        assert_return(_q, -EINVAL);
         assert_return(sa, -EINVAL);
         assert_return(salen >= sizeof(struct sockaddr), -EINVAL);
         assert_return(salen <= sizeof(union sockaddr_union), -EINVAL);
@@ -1117,7 +1133,7 @@ _public_ int sd_resolve_getnameinfo(
         assert_return(callback, -EINVAL);
         assert_return(!resolve_pid_changed(resolve), -ECHILD);
 
-        r = alloc_query(resolve, &q);
+        r = alloc_query(resolve, !_q, &q);
         if (r < 0)
                 return r;
 
@@ -1145,7 +1161,11 @@ _public_ int sd_resolve_getnameinfo(
                 return -errno;
         }
 
-        *_q = q;
+        resolve->n_outstanding++;
+
+        if (_q)
+                *_q = q;
+
         return 0;
 }
 
@@ -1176,12 +1196,11 @@ static int resolve_res(
         int r;
 
         assert_return(resolve, -EINVAL);
-        assert_return(_q, -EINVAL);
         assert_return(dname, -EINVAL);
         assert_return(callback, -EINVAL);
         assert_return(!resolve_pid_changed(resolve), -ECHILD);
 
-        r = alloc_query(resolve, &q);
+        r = alloc_query(resolve, !_q, &q);
         if (r < 0)
                 return r;
 
@@ -1208,7 +1227,11 @@ static int resolve_res(
                 return -errno;
         }
 
-        *_q = q;
+        resolve->n_outstanding++;
+
+        if (_q)
+                *_q = q;
+
         return 0;
 }
 
@@ -1251,23 +1274,38 @@ static void resolve_freeaddrinfo(struct addrinfo *ai) {
         }
 }
 
-static void resolve_query_free(sd_resolve_query *q) {
+static void resolve_query_disconnect(sd_resolve_query *q) {
+        sd_resolve *resolve;
         unsigned i;
 
         assert(q);
-        assert(q->resolve);
-        assert(q->resolve->n_queries > 0);
+
+        if (!q->resolve)
+                return;
+
+        resolve = q->resolve;
+        assert(resolve->n_queries > 0);
 
         if (q->done) {
-                assert(q->resolve->n_done > 0);
-                q->resolve->n_done--;
+                assert(resolve->n_done > 0);
+                resolve->n_done--;
         }
 
         i = q->id % QUERIES_MAX;
-        assert(q->resolve->queries[i] == q);
-        q->resolve->queries[i] = NULL;
-        q->resolve->n_queries--;
-        sd_resolve_unref(q->resolve);
+        assert(resolve->query_array[i] == q);
+        resolve->query_array[i] = NULL;
+        LIST_REMOVE(queries, resolve->queries, q);
+        resolve->n_queries--;
+
+        q->resolve = NULL;
+        if (!q->floating)
+                sd_resolve_unref(resolve);
+}
+
+static void resolve_query_free(sd_resolve_query *q) {
+        assert(q);
+
+        resolve_query_disconnect(q);
 
         resolve_freeaddrinfo(q->addrinfo);
         free(q->host);