chiark / gitweb /
sig: Move marshalling responsibility into sign function
[secnet.git] / site.c
diff --git a/site.c b/site.c
index 478ce4d05a9a50c32819604365a1713722707469..6e5c5a67d56ed0d8b2bf8a06662298c9039a6c27 100644 (file)
--- a/site.c
+++ b/site.c
 #define SITE_SENTMSG5 7
 #define SITE_WAIT     8
 
+#define CASES_MSG3_KNOWN LABEL_MSG3: case LABEL_MSG3BIS
+
 int32_t site_max_start_pad = 4*4;
 
 static cstring_t state_name(uint32_t state)
@@ -310,14 +312,15 @@ struct site {
     struct resolver_if *resolver;
     struct log_if *log;
     struct random_if *random;
-    struct rsaprivkey_if *privkey;
-    struct rsapubkey_if *pubkey;
+    struct sigprivkey_if *privkey;
+    struct sigpubkey_if *pubkey;
     struct transform_if **transforms;
     int ntransforms;
     struct dh_if *dh;
     struct hash_if *hash;
 
     uint32_t index; /* Index of this site */
+    uint32_t early_capabilities;
     uint32_t local_capabilities;
     int32_t setup_retries; /* How many times to send setup packets */
     int32_t setup_retry_interval; /* Initial timeout for setup packets */
@@ -505,10 +508,10 @@ static void dispose_transform(struct transform_inst_if **transform_var)
 
 static _Bool type_is_msg34(uint32_t type)
 {
-    return
-       type == LABEL_MSG3 ||
-       type == LABEL_MSG3BIS ||
-       type == LABEL_MSG4;
+    switch (type) {
+       case CASES_MSG3_KNOWN: case LABEL_MSG4: return True;
+       default: return False;
+    }
 }
 
 struct parsedname {
@@ -614,7 +617,8 @@ static bool_t generate_msg(struct site *st, uint32_t type, cstring_t what)
 {
     void *hst;
     uint8_t *hash;
-    string_t dhpub, sig;
+    string_t dhpub;
+    unsigned minor;
 
     st->retries=st->setup_retries;
     BUF_ALLOC(&st->buffer,what);
@@ -626,7 +630,8 @@ static bool_t generate_msg(struct site *st, uint32_t type, cstring_t what)
 
     struct xinfoadd xia;
     append_string_xinfo_start(&st->buffer,&xia,st->localname);
-    if ((st->local_capabilities & CAPAB_EARLY) || (type != LABEL_MSG1)) {
+    if ((st->local_capabilities & st->early_capabilities) ||
+       (type != LABEL_MSG1)) {
        buf_append_uint32(&st->buffer,st->local_capabilities);
     }
     if (type_is_msg34(type)) {
@@ -642,8 +647,11 @@ static bool_t generate_msg(struct site *st, uint32_t type, cstring_t what)
 
     if (hacky_par_mid_failnow()) return False;
 
-    if (type==LABEL_MSG3BIS)
+    if (MSGMAJOR(type) == 3) do {
+       minor = MSGMINOR(type);
+       if (minor < 1) break;
        buf_append_uint8(&st->buffer,st->chosen_transform->capab_bit);
+    } while (0);
 
     dhpub=st->dh->makepublic(st->dh->st,st->dhsecret,st->dh->len);
     buf_append_string(&st->buffer,dhpub);
@@ -652,11 +660,15 @@ static bool_t generate_msg(struct site *st, uint32_t type, cstring_t what)
     hst=st->hash->init();
     st->hash->update(hst,st->buffer.start,st->buffer.size);
     st->hash->final(hst,hash);
-    sig=st->privkey->sign(st->privkey->st,hash,st->hash->len);
-    buf_append_string(&st->buffer,sig);
-    free(sig);
+    bool_t ok=st->privkey->sign(st->privkey->st,hash,st->hash->len,
+                               &st->buffer);
+    if (!ok) goto fail;
     free(hash);
     return True;
+
+ fail:
+    free(hash);
+    return False;
 }
 
 static bool_t unpick_name(struct buffer_if *msg, struct parsedname *nm)
@@ -678,6 +690,8 @@ static bool_t unpick_name(struct buffer_if *msg, struct parsedname *nm)
 static bool_t unpick_msg(struct site *st, uint32_t type,
                         struct buffer_if *msg, struct msg *m)
 {
+    unsigned minor;
+
     m->capab_transformnum=-1;
     m->hashstart=msg->start;
     CHECK_AVAIL(msg,4);
@@ -713,12 +727,19 @@ static bool_t unpick_msg(struct site *st, uint32_t type,
        CHECK_EMPTY(msg);
        return True;
     }
-    if (type==LABEL_MSG3BIS) {
-       CHECK_AVAIL(msg,1);
-       m->capab_transformnum = buf_unprepend_uint8(msg);
-    } else {
-       m->capab_transformnum = CAPAB_BIT_ANCIENTTRANSFORM;
-    }
+    if (MSGMAJOR(type) == 3) do {
+       minor = MSGMINOR(type);
+#define MAYBE_READ_CAP(minminor, kind, dflt) do {                      \
+    if (minor < (minminor))                                            \
+       m->capab_##kind##num = (dflt);                                  \
+    else {                                                             \
+       CHECK_AVAIL(msg, 1);                                            \
+       m->capab_##kind##num = buf_unprepend_uint8(msg);                \
+    }                                                                  \
+} while (0)
+       MAYBE_READ_CAP(1, transform, CAPAB_BIT_ANCIENTTRANSFORM);
+#undef MAYBE_READ_CAP
+    } while (0);
     CHECK_AVAIL(msg,2);
     m->pklen=buf_unprepend_uint16(msg);
     CHECK_AVAIL(msg,m->pklen);
@@ -773,7 +794,7 @@ static bool_t check_msg(struct site *st, uint32_t type, struct msg *m,
     }
     /* MSG3 has complicated rules about capabilities, which are
      * handled in process_msg3. */
-    if (type==LABEL_MSG3 || type==LABEL_MSG3BIS) return True;
+    if (MSGMAJOR(type) == 3) return True;
     if (m->remote_capabilities!=st->remote_capabilities) {
        *error="remote capabilities changed";
        return False;
@@ -824,25 +845,32 @@ static bool_t process_msg2(struct site *st, struct buffer_if *msg2,
 
     /* Select the transform to use */
 
-    uint32_t remote_transforms = st->remote_capabilities & CAPAB_TRANSFORM_MASK;
-    if (!remote_transforms)
+    uint32_t remote_crypto_caps = st->remote_capabilities & CAPAB_TRANSFORM_MASK;
+    if (!remote_crypto_caps)
        /* old secnets only had this one transform */
-       remote_transforms = 1UL << CAPAB_BIT_ANCIENTTRANSFORM;
+       remote_crypto_caps = 1UL << CAPAB_BIT_ANCIENTTRANSFORM;
+
+#define CHOOSE_CRYPTO(kind, whats) do {                                        \
+    struct kind##_if *iface;                                           \
+    uint32_t bit, ours = 0;                                            \
+    int i;                                                             \
+    for (i= 0; i < st->n##kind##s; i++) {                              \
+       iface=st->kind##s[i];                                           \
+       bit = 1UL << iface->capab_bit;                                  \
+       if (bit & remote_crypto_caps) goto kind##_found;                \
+       ours |= bit;                                                    \
+    }                                                                  \
+    slog(st,LOG_ERROR,"no " whats " in common"                         \
+        " (us %#"PRIx32"; them: %#"PRIx32")",                          \
+        st->local_capabilities & ours, remote_crypto_caps);            \
+    return False;                                                      \
+kind##_found:                                                          \
+    st->chosen_##kind = iface;                                         \
+} while (0)
 
-    struct transform_if *ti;
-    int i;
-    for (i=0; i<st->ntransforms; i++) {
-       ti=st->transforms[i];
-       if ((1UL << ti->capab_bit) & remote_transforms)
-           goto transform_found;
-    }
-    slog(st,LOG_ERROR,"no transforms in common"
-        " (us %#"PRIx32"; them: %#"PRIx32")",
-        st->local_capabilities & CAPAB_TRANSFORM_MASK,
-        remote_transforms);
-    return False;
- transform_found:
-    st->chosen_transform=ti;
+    CHOOSE_CRYPTO(transform, "transforms");
+
+#undef CHOOSE_CRYPTO
 
     memcpy(st->remoteN,m.nR,NONCELEN);
     return True;
@@ -854,8 +882,9 @@ static bool_t generate_msg3(struct site *st)
        and create message number 3. */
     st->random->generate(st->random->st,st->dh->len,st->dhsecret);
     return generate_msg(st,
-                       (st->remote_capabilities & CAPAB_TRANSFORM_MASK
-                        ? LABEL_MSG3BIS : LABEL_MSG3),
+                       (st->remote_capabilities & CAPAB_TRANSFORM_MASK)
+                       ? LABEL_MSG3BIS
+                       : LABEL_MSG3,
                        "site:MSG3");
 }
 
@@ -889,7 +918,10 @@ static bool_t process_msg3(struct site *st, struct buffer_if *msg3,
     struct msg m;
     cstring_t err;
 
-    assert(msgtype==LABEL_MSG3 || msgtype==LABEL_MSG3BIS);
+    switch (msgtype) {
+       case CASES_MSG3_KNOWN: break;
+       default: assert(0);
+    }
 
     if (!unpick_msg(st,msgtype,msg3,&m)) return False;
     if (!check_msg(st,msgtype,&m,&err)) {
@@ -897,31 +929,46 @@ static bool_t process_msg3(struct site *st, struct buffer_if *msg3,
        return False;
     }
     uint32_t capab_adv_late = m.remote_capabilities
-       & ~st->remote_capabilities & CAPAB_EARLY;
+       & ~st->remote_capabilities & st->early_capabilities;
     if (capab_adv_late) {
        slog(st,LOG_SEC,"msg3 impermissibly adds early capability flag(s)"
             " %#"PRIx32" (was %#"PRIx32", now %#"PRIx32")",
             capab_adv_late, st->remote_capabilities, m.remote_capabilities);
        return False;
     }
-    st->remote_capabilities|=m.remote_capabilities;
 
-    struct transform_if *ti;
-    int i;
-    for (i=0; i<st->ntransforms; i++) {
-       ti=st->transforms[i];
-       if (ti->capab_bit == m.capab_transformnum)
-           goto transform_found;
-    }
-    slog(st,LOG_SEC,"peer chose unknown-to-us transform %d!",
-        m.capab_transformnum);
-    return False;
- transform_found:
-    st->chosen_transform=ti;
+#define CHOSE_CRYPTO(kind, what) do {                                  \
+    struct kind##_if *iface;                                           \
+    int i;                                                             \
+    for (i=0; i<st->n##kind##s; i++) {                                 \
+       iface=st->kind##s[i];                                           \
+       if (iface->capab_bit == m.capab_##kind##num)                    \
+           goto kind##_found;                                          \
+    }                                                                  \
+    slog(st,LOG_SEC,"peer chose unknown-to-us " what " %d!",           \
+        m.capab_##kind##num);                                                  \
+    return False;                                                      \
+kind##_found:                                                          \
+    st->chosen_##kind=iface;                                           \
+} while (0)
+
+    CHOSE_CRYPTO(transform, "transform");
+
+#undef CHOSE_CRYPTO
 
     if (!process_msg3_msg4(st,&m))
        return False;
 
+    /* Update our idea of the remote site's capabilities, now that we've
+     * verified that its message was authentic.
+     *
+     * Our previous idea of the remote site's capabilities came from the
+     * unauthenticated MSG1.  We've already checked that this new message
+     * doesn't change any of the bits we relied upon in the past, but it may
+     * also have set additional capability bits.  We simply throw those away
+     * now, and use the authentic capabilities from this MSG3. */
+    st->remote_capabilities=m.remote_capabilities;
+
     /* Terminate their DH public key with a '0' */
     m.pk[m.pklen]=0;
     /* Invent our DH secret key */
@@ -1941,8 +1988,7 @@ static bool_t site_incoming(void *sst, struct buffer_if *buf,
                slog(st,LOG_SEC,"invalid MSG2");
            }
            break;
-       case LABEL_MSG3:
-       case LABEL_MSG3BIS:
+       case CASES_MSG3_KNOWN:
            /* Setup packet: expected only in state SENTMSG2 */
            if (st->state!=SITE_SENTMSG2) {
                if ((st->state==SITE_SENTMSG4) &&
@@ -2112,6 +2158,7 @@ static list_t *site_apply(closure_t *self, struct cloc loc, dict_t *context,
     assert(index_sequence < 0xffffffffUL);
     st->index = ++index_sequence;
     st->local_capabilities = 0;
+    st->early_capabilities = CAPAB_PRIORITY_MOBILE;
     st->netlink=find_cl_if(dict,"link",CL_NETLINK,True,"site",loc);
 
 #define GET_CLOSURE_LIST(dictkey,things,nthings,CL_TYPE) do{           \
@@ -2146,12 +2193,12 @@ static list_t *site_apply(closure_t *self, struct cloc loc, dict_t *context,
     st->log=find_cl_if(dict,"log",CL_LOG,True,"site",loc);
     st->random=find_cl_if(dict,"random",CL_RANDOMSRC,True,"site",loc);
 
-    st->privkey=find_cl_if(dict,"local-key",CL_RSAPRIVKEY,True,"site",loc);
+    st->privkey=find_cl_if(dict,"local-key",CL_SIGPRIVKEY,True,"site",loc);
     st->addresses=dict_read_string_array(dict,"address",False,"site",loc,0);
     if (st->addresses)
        st->remoteport=dict_read_number(dict,"port",True,"site",loc,0);
     else st->remoteport=0;
-    st->pubkey=find_cl_if(dict,"key",CL_RSAPUBKEY,True,"site",loc);
+    st->pubkey=find_cl_if(dict,"key",CL_SIGPUBKEY,True,"site",loc);
 
     GET_CLOSURE_LIST("transform",transforms,ntransforms,CL_TRANSFORM);
 
@@ -2228,14 +2275,18 @@ static list_t *site_apply(closure_t *self, struct cloc loc, dict_t *context,
     st->sharedsecretlen=st->sharedsecretallocd=0;
     st->sharedsecret=0;
 
-    for (i=0; i<st->ntransforms; i++) {
-       struct transform_if *ti=st->transforms[i];
-       uint32_t capbit = 1UL << ti->capab_bit;
-       if (st->local_capabilities & capbit)
-           slog(st,LOG_ERROR,"bit capability bit"
-                " %d (%#"PRIx32") reused", ti->capab_bit, capbit);
-       st->local_capabilities |= capbit;
-    }
+#define SET_CAPBIT(bit) do {                                           \
+    uint32_t capflag = 1UL << (bit);                                   \
+    if (st->local_capabilities & capflag)                              \
+       slog(st,LOG_ERROR,"capability bit"                              \
+            " %d (%#"PRIx32") reused", (bit), capflag);                \
+    st->local_capabilities |= capflag;                                 \
+} while (0)
+
+    for (i=0; i<st->ntransforms; i++)
+       SET_CAPBIT(st->transforms[i]->capab_bit);
+
+#undef SET_CAPBIT
 
     if (st->local_mobile || st->peer_mobile)
        st->local_capabilities |= CAPAB_PRIORITY_MOBILE;