chiark / gitweb /
site: Log about crossed MSG1 ignored only once
[secnet.git] / site.c
diff --git a/site.c b/site.c
index 1c57a8a66b6296ca305b7efab2d4cf9035838c56..11fc28be2443973ab84c30490a7d79dd09efabf0 100644 (file)
--- a/site.c
+++ b/site.c
@@ -1,5 +1,24 @@
 /* site.c - manage communication with a remote network site */
 
+/*
+ * This file is part of secnet.
+ * See README for full list of copyright holders.
+ *
+ * secnet is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 3 of the License, or
+ * (at your option) any later version.
+ * 
+ * secnet is distributed in the hope that it will be useful, but
+ * WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * General Public License for more details.
+ * 
+ * You should have received a copy of the GNU General Public License
+ * version 3 along with secnet; if not, see
+ * https://www.gnu.org/licenses/gpl.html.
+ */
+
 /* The 'site' code doesn't know anything about the structure of the
    packets it's transmitting.  In fact, under the new netlink
    configuration scheme it doesn't need to know anything at all about
@@ -277,6 +296,7 @@ struct site {
 /* configuration information */
     string_t localname;
     string_t remotename;
+    bool_t keepalive;
     bool_t local_mobile, peer_mobile; /* Mobile client support */
     int32_t transport_peers_max;
     string_t tunname; /* localname<->remotename by default, used in logs */
@@ -285,6 +305,7 @@ struct site {
     uint32_t mtu_target;
     struct netlink_if *netlink;
     struct comm_if **comms;
+    struct comm_clientinfo **commclientinfos;
     int ncomms;
     struct resolver_if *resolver;
     struct log_if *log;
@@ -315,6 +336,7 @@ struct site {
     uint32_t state;
     uint64_t now; /* Most recently seen time */
     bool_t allow_send_prod;
+    bool_t msg1_crossed_logged;
     int resolving_count;
     int resolving_n_results_all;
     int resolving_n_results_stored;
@@ -513,8 +535,10 @@ struct msg {
     char *sig;
 };
 
-static void set_new_transform(struct site *st, char *pk)
+static _Bool set_new_transform(struct site *st, char *pk)
 {
+    _Bool ok;
+
     /* Make room for the shared key */
     st->sharedsecretlen=st->chosen_transform->keylen?:st->dh->ceil_len;
     assert(st->sharedsecretlen);
@@ -532,15 +556,18 @@ static void set_new_transform(struct site *st, char *pk)
     /* Set up the transform */
     struct transform_if *generator=st->chosen_transform;
     struct transform_inst_if *generated=generator->create(generator->st);
-    generated->setkey(generated->st,st->sharedsecret,
-                     st->sharedsecretlen,st->setup_priority);
+    ok = generated->setkey(generated->st,st->sharedsecret,
+                          st->sharedsecretlen,st->setup_priority);
+
     dispose_transform(&st->new_transform);
+    if (!ok) return False;
     st->new_transform=generated;
 
     slog(st,LOG_SETUP_INIT,"key exchange negotiated transform"
         " %d (capabilities ours=%#"PRIx32" theirs=%#"PRIx32")",
         st->chosen_transform->capab_transformnum,
         st->local_capabilities, st->remote_capabilities);
+    return True;
 }
 
 struct xinfoadd {
@@ -692,6 +719,13 @@ static bool_t unpick_msg(struct site *st, uint32_t type,
     CHECK_AVAIL(msg,m->siglen);
     m->sig=buf_unprepend(msg,m->siglen);
     CHECK_EMPTY(msg);
+
+    /* In `process_msg3_msg4' below, we assume that we can write a nul
+     * terminator following the signature.  Make sure there's enough space.
+     */
+    if (msg->start >= msg->base + msg->alloclen)
+       return False;
+
     return True;
 }
 
@@ -825,7 +859,7 @@ static bool_t process_msg3_msg4(struct site *st, struct msg *m)
     hst=st->hash->init();
     st->hash->update(hst,m->hashstart,m->hashlen);
     st->hash->final(hst,hash);
-    /* Terminate signature with a '0' - cheating, but should be ok */
+    /* Terminate signature with a '0' - already checked that this will fit */
     m->sig[m->siglen]=0;
     if (!st->pubkey->check(st->pubkey->st,hash,st->hash->len,m->sig)) {
        slog(st,LOG_SEC,"msg3/msg4 signature failed check!");
@@ -884,7 +918,7 @@ static bool_t process_msg3(struct site *st, struct buffer_if *msg3,
     st->random->generate(st->random->st,st->dh->len,st->dhsecret);
 
     /* Generate the shared key and set up the transform */
-    set_new_transform(st,m.pk);
+    if (!set_new_transform(st,m.pk)) return False;
 
     return True;
 }
@@ -915,7 +949,7 @@ static bool_t process_msg4(struct site *st, struct buffer_if *msg4,
     m.pk[m.pklen]=0;
 
     /* Generate the shared key and set up the transform */
-    set_new_transform(st,m.pk);
+    if (!set_new_transform(st,m.pk)) return False;
 
     return True;
 }
@@ -1110,7 +1144,8 @@ static bool_t decrypt_msg0(struct site *st, struct buffer_if *msg0,
 
  skew:
     slog(st,LOG_DROP,"transform: %s (merely skew)",transform_err);
-    return False;
+    assert(problem);
+    return problem;
 }
 
 static bool_t process_msg0(struct site *st, struct buffer_if *msg0,
@@ -1127,6 +1162,10 @@ static bool_t process_msg0(struct site *st, struct buffer_if *msg0,
     case LABEL_MSG7:
        /* We must forget about the current session. */
        delete_keys(st,"request from peer",LOG_SEC);
+       /* probably, the peer is shutting down, and this is going to fail,
+        * but we need to be trying to bring the link up again */
+       if (st->keepalive)
+           initiate_key_setup(st,"peer requested key teardown",0);
        return True;
     case LABEL_MSG9:
        /* Deliver to netlink layer */
@@ -1145,16 +1184,34 @@ static bool_t process_msg0(struct site *st, struct buffer_if *msg0,
 }
 
 static void dump_packet(struct site *st, struct buffer_if *buf,
-                       const struct comm_addr *addr, bool_t incoming)
+                       const struct comm_addr *addr, bool_t incoming,
+                       bool_t ok)
 {
     uint32_t dest=get_uint32(buf->start);
     uint32_t source=get_uint32(buf->start+4);
     uint32_t msgtype=get_uint32(buf->start+8);
 
     if (st->log_events & LOG_DUMP)
-       slilog(st->log,M_DEBUG,"%s: %s: %08x<-%08x: %08x:",
+       slilog(st->log,M_DEBUG,"%s: %s: %08x<-%08x: %08x: %s%s",
               st->tunname,incoming?"incoming":"outgoing",
-              dest,source,msgtype);
+              dest,source,msgtype,comm_addr_to_string(addr),
+              ok?"":" - fail");
+}
+
+static bool_t comm_addr_sendmsg(struct site *st,
+                               const struct comm_addr *dest,
+                               struct buffer_if *buf)
+{
+    int i;
+    struct comm_clientinfo *commclientinfo = 0;
+
+    for (i=0; i < st->ncomms; i++) {
+       if (st->comms[i] == dest->comm) {
+           commclientinfo = st->commclientinfos[i];
+           break;
+       }
+    }
+    return dest->comm->sendmsg(dest->comm->st, buf, dest, commclientinfo);
 }
 
 static uint32_t site_status(void *st)
@@ -1240,7 +1297,7 @@ static void decrement_resolving_count(struct site *st, int by)
                 st->resolving_n_results_all, naddrs);
        }
        slog(st,LOG_STATE,"resolution completed, %d addrs, eg: %s",
-            naddrs, comm_addr_to_string(&addrs[0]));;
+            naddrs, iaddr_to_string(&addrs[0].ia));;
     }
 
     switch (st->state) {
@@ -1406,8 +1463,11 @@ static void enter_state_run(struct site *st)
     FILLZERO(st->remoteN);
     dispose_transform(&st->new_transform);
     memset(st->dhsecret,0,st->dh->len);
-    memset(st->sharedsecret,0,st->sharedsecretlen);
+    if (st->sharedsecret) memset(st->sharedsecret,0,st->sharedsecretlen);
     set_link_quality(st);
+
+    if (st->keepalive && !current_valid(st))
+       initiate_key_setup(st, "keepalive", 0);
 }
 
 static bool_t ensure_resolving(struct site *st)
@@ -1466,6 +1526,7 @@ static bool_t enter_new_state(struct site *st, uint32_t next)
     case SITE_SENTMSG1:
        state_assert(st,st->state==SITE_RUN || st->state==SITE_RESOLVE);
        gen=generate_msg1;
+       st->msg1_crossed_logged = False;
        break;
     case SITE_SENTMSG2:
        state_assert(st,st->state==SITE_RUN || st->state==SITE_RESOLVE ||
@@ -1581,8 +1642,8 @@ static void generate_send_prod(struct site *st,
     slog(st,LOG_SETUP_INIT,"prodding peer for key exchange");
     st->allow_send_prod=0;
     generate_prod(st,&st->scratch);
-    dump_packet(st,&st->scratch,source,False);
-    source->comm->sendmsg(source->comm->st, &st->scratch, source);
+    bool_t ok = comm_addr_sendmsg(st, source, &st->scratch);
+    dump_packet(st,&st->scratch,source,False,ok);
 }
 
 static inline void site_settimeout(uint64_t timeout, int *timeout_io)
@@ -1715,7 +1776,7 @@ static bool_t site_incoming(void *sst, struct buffer_if *buf,
        if (!named_for_us(st,buf,msgtype,&named_msg))
            return False;
        /* It's a MSG1 addressed to us. Decide what to do about it. */
-       dump_packet(st,buf,source,True);
+       dump_packet(st,buf,source,True,True);
        if (st->state==SITE_RUN || st->state==SITE_RESOLVE ||
            st->state==SITE_WAIT) {
            /* We should definitely process it */
@@ -1738,8 +1799,9 @@ static bool_t site_incoming(void *sst, struct buffer_if *buf,
               incoming one, otherwise we process it as usual. */
            if (st->setup_priority) {
                BUF_FREE(buf);
-               slog(st,LOG_DUMP,"crossed msg1s; we are higher "
-                    "priority => ignore incoming msg1");
+               if (!st->msg1_crossed_logged++)
+                   slog(st,LOG_DUMP,"crossed msg1s; we are higher "
+                        "priority => ignore incoming msg1");
                return True;
            } else {
                slog(st,LOG_DUMP,"crossed msg1s; we are lower "
@@ -1765,7 +1827,7 @@ static bool_t site_incoming(void *sst, struct buffer_if *buf,
     if (msgtype==LABEL_PROD) {
        if (!named_for_us(st,buf,msgtype,&named_msg))
            return False;
-       dump_packet(st,buf,source,True);
+       dump_packet(st,buf,source,True,True);
        if (st->state!=SITE_RUN) {
            slog(st,LOG_DROP,"ignoring PROD when not in state RUN");
        } else if (current_valid(st)) {
@@ -1778,7 +1840,7 @@ static bool_t site_incoming(void *sst, struct buffer_if *buf,
     }
     if (dest==st->index) {
        /* Explicitly addressed to us */
-       if (msgtype!=LABEL_MSG0) dump_packet(st,buf,source,True);
+       if (msgtype!=LABEL_MSG0) dump_packet(st,buf,source,True,True);
        switch (msgtype) {
        case LABEL_NAK:
            /* If the source is our current peer then initiate a key setup,
@@ -1927,7 +1989,7 @@ static list_t *site_apply(closure_t *self, struct cloc loc, dict_t *context,
     dict_t *dict;
     int i;
 
-    st=safe_malloc(sizeof(*st),"site_apply");
+    NEW(st);
 
     st->cl.description="site";
     st->cl.type=CL_SITE;
@@ -1946,6 +2008,8 @@ static list_t *site_apply(closure_t *self, struct cloc loc, dict_t *context,
     st->localname=dict_read_string(dict, "local-name", True, "site", loc);
     st->remotename=dict_read_string(dict, "name", True, "site", loc);
 
+    st->keepalive=dict_read_bool(dict,"keepalive",False,"site",loc,False);
+
     st->peer_mobile=dict_read_bool(dict,"mobile",False,"site",loc,False);
     st->local_mobile=
        dict_read_bool(dict,"local-mobile",False,"site",loc,False);
@@ -1980,7 +2044,7 @@ static list_t *site_apply(closure_t *self, struct cloc loc, dict_t *context,
     if (!things##_cfg)                                                 \
        cfgfatal(loc,"site","closure list \"%s\" not found\n",dictkey); \
     st->nthings=list_length(things##_cfg);                             \
-    st->things=safe_malloc_ary(sizeof(*st->things),st->nthings,dictkey "s"); \
+    NEW_ARY(st->things,st->nthings);                                   \
     assert(st->nthings);                                               \
     for (i=0; i<st->nthings; i++) {                                    \
        item_t *item=list_elem(things##_cfg,i);                         \
@@ -1995,6 +2059,14 @@ static list_t *site_apply(closure_t *self, struct cloc loc, dict_t *context,
 
     GET_CLOSURE_LIST("comm",comms,ncomms,CL_COMM);
 
+    NEW_ARY(st->commclientinfos, st->ncomms);
+    dict_t *comminfo = dict_read_dict(dict,"comm-info",False,"site",loc);
+    for (i=0; i<st->ncomms; i++) {
+       st->commclientinfos[i] =
+           !comminfo ? 0 :
+           st->comms[i]->clientinfo(st->comms[i],comminfo,loc);
+    }
+
     st->resolver=find_cl_if(dict,"resolver",CL_RESOLVER,True,"site",loc);
     st->log=find_cl_if(dict,"log",CL_LOG,True,"site",loc);
     st->random=find_cl_if(dict,"random",CL_RANDOMSRC,True,"site",loc);
@@ -2193,11 +2265,18 @@ static void transport_record_peers(struct site *st, transport_peers *peers,
      *
      * Caller must first call transport_peers_expire. */
 
-    if (naddrs==1 && peers->npeers>=1 &&
-       comm_addr_equal(&addrs[0], &peers->peers[0].addr)) {
-       /* optimisation, also avoids debug for trivial updates */
-       peers->peers[0].last = *tv_now;
-       return;
+    if (naddrs==1) {
+       /* avoids debug for uninteresting updates */
+       int i;
+       for (i=0; i<peers->npeers; i++) {
+           if (comm_addr_equal(&addrs[0], &peers->peers[i].addr)) {
+               memmove(peers->peers+1, peers->peers,
+                       sizeof(peers->peers[0]) * i);
+               peers->peers[0].addr = addrs[0];
+               peers->peers[0].last = *tv_now;
+               return;
+           }
+       }
     }
 
     int old_npeers=peers->npeers;
@@ -2328,10 +2407,9 @@ void transport_xmit(struct site *st, transport_peers *peers,
     int nfailed=0;
     for (slot=0; slot<peers->npeers; slot++) {
        transport_peer *peer=&peers->peers[slot];
+       bool_t ok = comm_addr_sendmsg(st, &peer->addr, buf);
        if (candebug)
-           dump_packet(st, buf, &peer->addr, False);
-       bool_t ok =
-           peer->addr.comm->sendmsg(peer->addr.comm->st, buf, &peer->addr);
+           dump_packet(st, buf, &peer->addr, False, ok);
        if (!ok) {
            failed |= 1U << slot;
            nfailed++;
@@ -2353,12 +2431,14 @@ void transport_xmit(struct site *st, transport_peers *peers,
            transport_peers__copy_by_mask(peers->peers,&wslot,~failed,peers);
            assert(wslot+nfailed == peers->npeers);
            COPY_ARRAY(peers->peers+wslot, failedpeers, nfailed);
+           transport_peers_debug(st,peers,"mobile failure reorder",0,0,0);
        }
     } else {
        if (failed && peers->npeers > 1) {
            int wslot=0;
            transport_peers__copy_by_mask(peers->peers,&wslot,~failed,peers);
            peers->npeers=wslot;
+           transport_peers_debug(st,peers,"non-mobile failure cleanup",0,0,0);
        }
     }
 }