chiark / gitweb /
site: Move main `struct msg' into site_incoming
[secnet.git] / site.c
diff --git a/site.c b/site.c
index 73d5b64aae560f9b244a333e78d564ed5cc2aa58..9c320a5b31013077cfddaca9842726e3d79dc39e 100644 (file)
--- a/site.c
+++ b/site.c
@@ -317,7 +317,6 @@ struct site {
     struct transform_if **transforms;
     int ntransforms;
     struct dh_if *dh;
-    struct hash_if *hash;
 
     uint32_t index; /* Index of this site */
     uint32_t early_capabilities;
@@ -614,8 +613,6 @@ static void append_string_xinfo_done(struct buffer_if *buf,
    out using a transform of config data supplied by netlink */
 static bool_t generate_msg(struct site *st, uint32_t type, cstring_t what)
 {
-    void *hst;
-    uint8_t *hash;
     string_t dhpub;
     unsigned minor;
 
@@ -655,18 +652,15 @@ static bool_t generate_msg(struct site *st, uint32_t type, cstring_t what)
     dhpub=st->dh->makepublic(st->dh->st,st->dhsecret,st->dh->len);
     buf_append_string(&st->buffer,dhpub);
     free(dhpub);
-    hash=safe_malloc(st->hash->len, "generate_msg");
-    hst=st->hash->init();
-    st->hash->update(hst,st->buffer.start,st->buffer.size);
-    st->hash->final(hst,hash);
-    bool_t ok=st->privkey->sign(st->privkey->st,hash,st->hash->len,
+
+    bool_t ok=st->privkey->sign(st->privkey->st,
+                               st->buffer.start,
+                               st->buffer.size,
                                &st->buffer);
     if (!ok) goto fail;
-    free(hash);
     return True;
 
  fail:
-    free(hash);
     return False;
 }
 
@@ -824,18 +818,18 @@ static bool_t generate_msg2(struct site *st)
 }
 
 static bool_t process_msg2(struct site *st, struct buffer_if *msg2,
-                          const struct comm_addr *src)
+                          const struct comm_addr *src,
+                          struct msg *m /* returned */)
 {
-    struct msg m;
     cstring_t err;
 
-    if (!unpick_msg(st,LABEL_MSG2,msg2,&m)) return False;
-    if (!check_msg(st,LABEL_MSG2,&m,&err)) {
+    if (!unpick_msg(st,LABEL_MSG2,msg2,m)) return False;
+    if (!check_msg(st,LABEL_MSG2,m,&err)) {
        slog(st,LOG_SEC,"msg2: %s",err);
        return False;
     }
-    st->setup_session_id=m.source;
-    st->remote_capabilities=m.remote_capabilities;
+    st->setup_session_id=m->source;
+    st->remote_capabilities=m->remote_capabilities;
 
     /* Select the transform to use */
 
@@ -866,7 +860,7 @@ kind##_found:                                                               \
 
 #undef CHOOSE_CRYPTO
 
-    memcpy(st->remoteN,m.nR,NONCELEN);
+    memcpy(st->remoteN,m->nR,NONCELEN);
     return True;
 }
 
@@ -884,22 +878,13 @@ static bool_t generate_msg3(struct site *st)
 
 static bool_t process_msg3_msg4(struct site *st, struct msg *m)
 {
-    uint8_t *hash;
-    void *hst;
-
     /* Check signature and store g^x mod m */
-    hash=safe_malloc(st->hash->len, "process_msg3_msg4");
-    hst=st->hash->init();
-    st->hash->update(hst,m->hashstart,m->hashlen);
-    st->hash->final(hst,hash);
     if (!st->pubkey->check(st->pubkey->st,
-                          hash,st->hash->len,
+                          m->hashstart,m->hashlen,
                           &m->sig)) {
        slog(st,LOG_SEC,"msg3/msg4 signature failed check!");
-       free(hash);
        return False;
     }
-    free(hash);
 
     st->remote_adv_mtu=m->remote_mtu;
 
@@ -907,9 +892,9 @@ static bool_t process_msg3_msg4(struct site *st, struct msg *m)
 }
 
 static bool_t process_msg3(struct site *st, struct buffer_if *msg3,
-                          const struct comm_addr *src, uint32_t msgtype)
+                          const struct comm_addr *src, uint32_t msgtype,
+                          struct msg *m /* returned */)
 {
-    struct msg m;
     cstring_t err;
 
     switch (msgtype) {
@@ -917,17 +902,17 @@ static bool_t process_msg3(struct site *st, struct buffer_if *msg3,
        default: assert(0);
     }
 
-    if (!unpick_msg(st,msgtype,msg3,&m)) return False;
-    if (!check_msg(st,msgtype,&m,&err)) {
+    if (!unpick_msg(st,msgtype,msg3,m)) return False;
+    if (!check_msg(st,msgtype,m,&err)) {
        slog(st,LOG_SEC,"msg3: %s",err);
        return False;
     }
-    uint32_t capab_adv_late = m.remote_capabilities
+    uint32_t capab_adv_late = m->remote_capabilities
        & ~st->remote_capabilities & st->early_capabilities;
     if (capab_adv_late) {
        slog(st,LOG_SEC,"msg3 impermissibly adds early capability flag(s)"
             " %#"PRIx32" (was %#"PRIx32", now %#"PRIx32")",
-            capab_adv_late, st->remote_capabilities, m.remote_capabilities);
+            capab_adv_late, st->remote_capabilities, m->remote_capabilities);
        return False;
     }
 
@@ -936,11 +921,11 @@ static bool_t process_msg3(struct site *st, struct buffer_if *msg3,
     int i;                                                             \
     for (i=0; i<st->n##kind##s; i++) {                                 \
        iface=st->kind##s[i];                                           \
-       if (iface->capab_bit == m.capab_##kind##num)                    \
+       if (iface->capab_bit == m->capab_##kind##num)                   \
            goto kind##_found;                                          \
     }                                                                  \
     slog(st,LOG_SEC,"peer chose unknown-to-us " what " %d!",           \
-        m.capab_##kind##num);                                                  \
+        m->capab_##kind##num);                                                 \
     return False;                                                      \
 kind##_found:                                                          \
     st->chosen_##kind=iface;                                           \
@@ -950,7 +935,7 @@ kind##_found:                                                               \
 
 #undef CHOSE_CRYPTO
 
-    if (!process_msg3_msg4(st,&m))
+    if (!process_msg3_msg4(st,m))
        return False;
 
     /* Update our idea of the remote site's capabilities, now that we've
@@ -961,15 +946,15 @@ kind##_found:                                                             \
      * doesn't change any of the bits we relied upon in the past, but it may
      * also have set additional capability bits.  We simply throw those away
      * now, and use the authentic capabilities from this MSG3. */
-    st->remote_capabilities=m.remote_capabilities;
+    st->remote_capabilities=m->remote_capabilities;
 
     /* Terminate their DH public key with a '0' */
-    m.pk[m.pklen]=0;
+    m->pk[m->pklen]=0;
     /* Invent our DH secret key */
     st->random->generate(st->random->st,st->dh->len,st->dhsecret);
 
     /* Generate the shared key and set up the transform */
-    if (!set_new_transform(st,m.pk)) return False;
+    if (!set_new_transform(st,m->pk)) return False;
 
     return True;
 }
@@ -982,25 +967,25 @@ static bool_t generate_msg4(struct site *st)
 }
 
 static bool_t process_msg4(struct site *st, struct buffer_if *msg4,
-                          const struct comm_addr *src)
+                          const struct comm_addr *src,
+                          struct msg *m /* returned */)
 {
-    struct msg m;
     cstring_t err;
 
-    if (!unpick_msg(st,LABEL_MSG4,msg4,&m)) return False;
-    if (!check_msg(st,LABEL_MSG4,&m,&err)) {
+    if (!unpick_msg(st,LABEL_MSG4,msg4,m)) return False;
+    if (!check_msg(st,LABEL_MSG4,m,&err)) {
        slog(st,LOG_SEC,"msg4: %s",err);
        return False;
     }
     
-    if (!process_msg3_msg4(st,&m))
+    if (!process_msg3_msg4(st,m))
        return False;
 
     /* Terminate their DH public key with a '0' */
-    m.pk[m.pklen]=0;
+    m->pk[m->pklen]=0;
 
     /* Generate the shared key and set up the transform */
-    if (!set_new_transform(st,m.pk)) return False;
+    if (!set_new_transform(st,m->pk)) return False;
 
     return True;
 }
@@ -1833,17 +1818,17 @@ static bool_t we_have_priority(struct site *st, const struct msg *m) {
 static bool_t setup_late_msg_ok(struct site *st, 
                                const struct buffer_if *buf_in,
                                uint32_t msgtype,
-                               const struct comm_addr *source) {
+                               const struct comm_addr *source,
+                               struct msg *m /* returned */) {
     /* For setup packets which seem from their type like they are
      * late.  Maybe they came via a different path.  All we do is make
      * a note of the sending address, iff they look like they are part
      * of the current key setup attempt. */
-    struct msg m;
-    if (!named_for_us(st,buf_in,msgtype,&m))
+    if (!named_for_us(st,buf_in,msgtype,m))
        /* named_for_us calls unpick_msg which gets the nonces */
        return False;
-    if (!consttime_memeq(m.nR,st->remoteN,NONCELEN) ||
-       !consttime_memeq(m.nL,st->localN, NONCELEN))
+    if (!consttime_memeq(m->nR,st->remoteN,NONCELEN) ||
+       !consttime_memeq(m->nL,st->localN, NONCELEN))
        /* spoof ?  from stale run ?  who knows */
        return False;
     transport_setup_msgok(st,source);
@@ -1864,10 +1849,10 @@ static bool_t site_incoming(void *sst, struct buffer_if *buf,
 
     uint32_t dest=get_uint32(buf->start);
     uint32_t msgtype=get_uint32(buf->start+8);
-    struct msg named_msg;
+    struct msg msg;
 
     if (msgtype==LABEL_MSG1) {
-       if (!named_for_us(st,buf,msgtype,&named_msg))
+       if (!named_for_us(st,buf,msgtype,&msg))
            return False;
        /* It's a MSG1 addressed to us. Decide what to do about it. */
        dump_packet(st,buf,source,True,True);
@@ -1875,7 +1860,7 @@ static bool_t site_incoming(void *sst, struct buffer_if *buf,
            st->state==SITE_WAIT) {
            /* We should definitely process it */
            transport_compute_setupinit_peers(st,0,0,source);
-           if (process_msg1(st,buf,source,&named_msg)) {
+           if (process_msg1(st,buf,source,&msg)) {
                slog(st,LOG_SETUP_INIT,"key setup initiated by peer");
                bool_t entered=enter_new_state(st,SITE_SENTMSG2);
                if (entered && st->addresses && st->local_mobile)
@@ -1891,7 +1876,7 @@ static bool_t site_incoming(void *sst, struct buffer_if *buf,
            /* We've just sent a message 1! They may have crossed on
               the wire. If we have priority then we ignore the
               incoming one, otherwise we process it as usual. */
-           if (we_have_priority(st,&named_msg)) {
+           if (we_have_priority(st,&msg)) {
                BUF_FREE(buf);
                if (!st->msg1_crossed_logged++)
                    slog(st,LOG_SETUP_INIT,"crossed msg1s; we are higher "
@@ -1900,7 +1885,7 @@ static bool_t site_incoming(void *sst, struct buffer_if *buf,
            } else {
                slog(st,LOG_SETUP_INIT,"crossed msg1s; we are lower "
                     "priority => use incoming msg1");
-               if (process_msg1(st,buf,source,&named_msg)) {
+               if (process_msg1(st,buf,source,&msg)) {
                    BUF_FREE(&st->buffer); /* Free our old message 1 */
                    transport_setup_msgok(st,source);
                    enter_new_state(st,SITE_SENTMSG2);
@@ -1913,7 +1898,7 @@ static bool_t site_incoming(void *sst, struct buffer_if *buf,
            }
        } else if (st->state==SITE_SENTMSG2 ||
                   st->state==SITE_SENTMSG4) {
-           if (consttime_memeq(named_msg.nR,st->remoteN,NONCELEN)) {
+           if (consttime_memeq(msg.nR,st->remoteN,NONCELEN)) {
                /* We are ahead in the protocol, but that msg1 had the
                 * peer's nonce so presumably it is from this key
                 * exchange run, via a slower route */
@@ -1931,7 +1916,7 @@ static bool_t site_incoming(void *sst, struct buffer_if *buf,
        return True;
     }
     if (msgtype==LABEL_PROD) {
-       if (!named_for_us(st,buf,msgtype,&named_msg))
+       if (!named_for_us(st,buf,msgtype,&msg))
            return False;
        dump_packet(st,buf,source,True,True);
        if (st->state!=SITE_RUN) {
@@ -1972,10 +1957,10 @@ static bool_t site_incoming(void *sst, struct buffer_if *buf,
            if (st->state!=SITE_SENTMSG1) {
                if ((st->state==SITE_SENTMSG3 ||
                     st->state==SITE_SENTMSG5) &&
-                   setup_late_msg_ok(st,buf,msgtype,source))
+                   setup_late_msg_ok(st,buf,msgtype,source,&msg))
                    break;
                slog(st,LOG_UNEXPECTED,"unexpected MSG2");
-           } else if (process_msg2(st,buf,source)) {
+           } else if (process_msg2(st,buf,source,&msg)) {
                transport_setup_msgok(st,source);
                enter_new_state(st,SITE_SENTMSG3);
            } else {
@@ -1986,10 +1971,10 @@ static bool_t site_incoming(void *sst, struct buffer_if *buf,
            /* Setup packet: expected only in state SENTMSG2 */
            if (st->state!=SITE_SENTMSG2) {
                if ((st->state==SITE_SENTMSG4) &&
-                   setup_late_msg_ok(st,buf,msgtype,source))
+                   setup_late_msg_ok(st,buf,msgtype,source,&msg))
                    break;
                slog(st,LOG_UNEXPECTED,"unexpected MSG3");
-           } else if (process_msg3(st,buf,source,msgtype)) {
+           } else if (process_msg3(st,buf,source,msgtype,&msg)) {
                transport_setup_msgok(st,source);
                enter_new_state(st,SITE_SENTMSG4);
            } else {
@@ -2000,10 +1985,10 @@ static bool_t site_incoming(void *sst, struct buffer_if *buf,
            /* Setup packet: expected only in state SENTMSG3 */
            if (st->state!=SITE_SENTMSG3) {
                if ((st->state==SITE_SENTMSG5) &&
-                   setup_late_msg_ok(st,buf,msgtype,source))
+                   setup_late_msg_ok(st,buf,msgtype,source,&msg))
                    break;
                slog(st,LOG_UNEXPECTED,"unexpected MSG4");
-           } else if (process_msg4(st,buf,source)) {
+           } else if (process_msg4(st,buf,source,&msg)) {
                transport_setup_msgok(st,source);
                enter_new_state(st,SITE_SENTMSG5);
            } else {
@@ -2197,7 +2182,12 @@ static list_t *site_apply(closure_t *self, struct cloc loc, dict_t *context,
     GET_CLOSURE_LIST("transform",transforms,ntransforms,CL_TRANSFORM);
 
     st->dh=find_cl_if(dict,"dh",CL_DH,True,"site",loc);
-    st->hash=find_cl_if(dict,"hash",CL_HASH,True,"site",loc);
+
+    if (st->privkey->sethash || st->pubkey->sethash) {
+       struct hash_if *hash=find_cl_if(dict,"hash",CL_HASH,True,"site",loc);
+       if (st->privkey->sethash) st->privkey->sethash(st->privkey->st,hash);
+       if (st->pubkey->sethash) st->pubkey->sethash(st->pubkey->st,hash);
+    }
 
 #define DEFAULT(D) (st->peer_mobile || st->local_mobile        \
                     ? DEFAULT_MOBILE_##D : DEFAULT_##D)