chiark / gitweb /
netlink: Break out netlink_client_deliver
[secnet.git] / netlink.c
index ecad16cb73e6d8cb18b724148912648a2482a219..54ad76fb908fa6d7383bb96422cb2fcf84d797a2 100644 (file)
--- a/netlink.c
+++ b/netlink.c
@@ -98,15 +98,14 @@ their use.
 */
 
 #include <string.h>
+#include <assert.h>
+#include <limits.h>
 #include "secnet.h"
 #include "util.h"
 #include "ipaddr.h"
 #include "netlink.h"
 #include "process.h"
 
-#define OPT_SOFTROUTE   1
-#define OPT_ALLOWROUTE  2
-
 #define ICMP_TYPE_ECHO_REPLY             0
 
 #define ICMP_TYPE_UNREACHABLE            3
@@ -121,7 +120,7 @@ their use.
 #define ICMP_CODE_TTL_EXCEEDED           0
 
 /* Generic IP checksum routine */
-static inline uint16_t ip_csum(uint8_t *iph,uint32_t count)
+static inline uint16_t ip_csum(uint8_t *iph,int32_t count)
 {
     register uint32_t sum=0;
 
@@ -145,7 +144,7 @@ static inline uint16_t ip_csum(uint8_t *iph,uint32_t count)
  *      By Jorge Cwik <jorge@laser.satlink.net>, adapted for linux by
  *      Arnt Gulbrandsen.
  */
-static inline uint16_t ip_fast_csum(uint8_t *iph, uint32_t ihl) {
+static inline uint16_t ip_fast_csum(uint8_t *iph, int32_t ihl) {
     uint32_t sum;
 
     __asm__ __volatile__(
@@ -175,8 +174,9 @@ static inline uint16_t ip_fast_csum(uint8_t *iph, uint32_t ihl) {
     return sum;
 }
 #else
-static inline uint16_t ip_fast_csum(uint8_t *iph, uint32_t ihl)
+static inline uint16_t ip_fast_csum(uint8_t *iph, int32_t ihl)
 {
+    assert(ihl < INT_MAX/4);
     return ip_csum(iph,ihl*4);
 }
 #endif
@@ -238,7 +238,7 @@ static struct icmphdr *netlink_icmp_tmpl(struct netlink *st,
     struct icmphdr *h;
 
     BUF_ALLOC(&st->icmp,"netlink_icmp_tmpl");
-    buffer_init(&st->icmp,st->max_start_pad);
+    buffer_init(&st->icmp,calculate_max_start_pad());
     h=buf_append(&st->icmp,sizeof(*h));
 
     h->iph.version=4;
@@ -262,7 +262,7 @@ static struct icmphdr *netlink_icmp_tmpl(struct netlink *st,
 /* Fill in the ICMP checksum field correctly */
 static void netlink_icmp_csum(struct icmphdr *h)
 {
-    uint32_t len;
+    int32_t len;
 
     len=ntohs(h->iph.tot_len)-(4*h->iph.ihl);
     h->check=0;
@@ -293,6 +293,7 @@ static bool_t netlink_icmp_may_reply(struct buffer_if *buf)
     struct icmphdr *icmph;
     uint32_t source;
 
+    if (buf->size < (int)sizeof(struct icmphdr)) return False;
     iph=(struct iphdr *)buf->start;
     icmph=(struct icmphdr *)buf->start;
     if (iph->protocol==1) {
@@ -338,6 +339,7 @@ static bool_t netlink_icmp_may_reply(struct buffer_if *buf)
    */
 static uint16_t netlink_icmp_reply_len(struct buffer_if *buf)
 {
+    if (buf->size < (int)sizeof(struct iphdr)) return 0;
     struct iphdr *iph=(struct iphdr *)buf->start;
     uint16_t hlen,plen;
 
@@ -354,11 +356,11 @@ static void netlink_icmp_simple(struct netlink *st, struct buffer_if *buf,
                                struct netlink_client *client,
                                uint8_t type, uint8_t code)
 {
-    struct iphdr *iph=(struct iphdr *)buf->start;
     struct icmphdr *h;
     uint16_t len;
 
     if (netlink_icmp_may_reply(buf)) {
+       struct iphdr *iph=(struct iphdr *)buf->start;
        len=netlink_icmp_reply_len(buf);
        h=netlink_icmp_tmpl(st,ntohl(iph->saddr),len);
        h->type=type; h->code=code;
@@ -381,19 +383,41 @@ static void netlink_icmp_simple(struct netlink *st, struct buffer_if *buf,
  * 3. Checksums correctly.
  * 4. Doesn't have a bogus length
  */
-static bool_t netlink_check(struct netlink *st, struct buffer_if *buf)
+static bool_t netlink_check(struct netlink *st, struct buffer_if *buf,
+                           char *errmsgbuf, int errmsgbuflen)
 {
+#define BAD(...) do{                                   \
+       snprintf(errmsgbuf,errmsgbuflen,__VA_ARGS__);   \
+       return False;                                   \
+    }while(0)
+
+    if (buf->size < (int)sizeof(struct iphdr)) BAD("len %"PRIu32"",buf->size);
     struct iphdr *iph=(struct iphdr *)buf->start;
-    uint32_t len;
+    int32_t len;
 
-    if (iph->ihl < 5 || iph->version != 4) return False;
-    if (buf->size < iph->ihl*4) return False;
-    if (ip_fast_csum((uint8_t *)iph, iph->ihl)!=0) return False;
+    if (iph->ihl < 5) BAD("ihl %u",iph->ihl);
+    if (iph->version != 4) BAD("version %u",iph->version);
+    if (buf->size < iph->ihl*4) BAD("size %"PRId32"<%u*4",buf->size,iph->ihl);
+    if (ip_fast_csum((uint8_t *)iph, iph->ihl)!=0) BAD("csum");
     len=ntohs(iph->tot_len);
     /* There should be no padding */
-    if (buf->size!=len || len<(iph->ihl<<2)) return False;
+    if (buf->size!=len) BAD("len %"PRId32"!=%"PRId32,buf->size,len);
+    if (len<(iph->ihl<<2)) BAD("len %"PRId32"<(%u<<2)",len,iph->ihl);
     /* XXX check that there's no source route specified */
     return True;
+
+#undef BAD
+}
+
+/* Deliver a packet _to_ client; used after we have decided
+ * what to do with it. */
+static void netlink_client_deliver(struct netlink *st,
+                                  struct netlink_client *client,
+                                  uint32_t source, uint32_t dest,
+                                  struct buffer_if *buf)
+{
+    client->deliver(client->dst, buf);
+    client->outcount++;
 }
 
 /* Deliver a packet. "client" is the _origin_ of the packet, not its
@@ -403,6 +427,13 @@ static void netlink_packet_deliver(struct netlink *st,
                                   struct netlink_client *client,
                                   struct buffer_if *buf)
 {
+    if (buf->size < (int)sizeof(struct iphdr)) {
+       Message(M_ERR,"%s: trying to deliver a too-short packet"
+               " from %s!\n",st->name, client?client->name:"(local)");
+       BUF_FREE(buf);
+       return;
+    }
+
     struct iphdr *iph=(struct iphdr *)buf->start;
     uint32_t dest=ntohl(iph->daddr);
     uint32_t source=ntohl(iph->saddr);
@@ -498,18 +529,18 @@ static void netlink_packet_deliver(struct netlink *st,
            netlink_icmp_simple(st,buf,client,ICMP_TYPE_UNREACHABLE,
                                ICMP_CODE_NET_PROHIBITED);
            BUF_FREE(buf);
-       }
-       if (best_quality>0) {
-           /* XXX Fragment if required */
-           st->routes[best_match]->deliver(
-               st->routes[best_match]->dst, buf);
-           st->routes[best_match]->outcount++;
-           BUF_ASSERT_FREE(buf);
        } else {
-           /* Generate ICMP destination unreachable */
-           netlink_icmp_simple(st,buf,client,ICMP_TYPE_UNREACHABLE,
-                               ICMP_CODE_NET_UNREACHABLE); /* client==NULL */
-           BUF_FREE(buf);
+           if (best_quality>0) {
+               /* XXX Fragment if required */
+               netlink_client_deliver(st,st->routes[best_match],
+                                      source,dest,buf);
+               BUF_ASSERT_FREE(buf);
+           } else {
+               /* Generate ICMP destination unreachable */
+               netlink_icmp_simple(st,buf,client,ICMP_TYPE_UNREACHABLE,
+                                   ICMP_CODE_NET_UNREACHABLE); /* client==NULL */
+               BUF_FREE(buf);
+           }
        }
     }
     BUF_ASSERT_FREE(buf);
@@ -519,6 +550,7 @@ static void netlink_packet_forward(struct netlink *st,
                                   struct netlink_client *client,
                                   struct buffer_if *buf)
 {
+    if (buf->size < (int)sizeof(struct iphdr)) return;
     struct iphdr *iph=(struct iphdr *)buf->start;
     
     BUF_ASSERT_USED(buf);
@@ -548,6 +580,12 @@ static void netlink_packet_local(struct netlink *st,
 
     st->localcount++;
 
+    if (buf->size < (int)sizeof(struct icmphdr)) {
+       Message(M_WARNING,"%s: short packet addressed to secnet; "
+               "ignoring it\n",st->name);
+       BUF_FREE(buf);
+       return;
+    }
     h=(struct icmphdr *)buf->start;
 
     if ((ntohs(h->iph.frag_off)&0xbfff)!=0) {
@@ -591,14 +629,19 @@ static void netlink_incoming(struct netlink *st, struct netlink_client *client,
 {
     uint32_t source,dest;
     struct iphdr *iph;
+    char errmsgbuf[50];
+    const char *sourcedesc=client?client->name:"host";
 
     BUF_ASSERT_USED(buf);
-    if (!netlink_check(st,buf)) {
-       Message(M_WARNING,"%s: bad IP packet from %s\n",
-               st->name,client?client->name:"host");
+
+    if (!netlink_check(st,buf,errmsgbuf,sizeof(errmsgbuf))) {
+       Message(M_WARNING,"%s: bad IP packet from %s: %s\n",
+               st->name,sourcedesc,
+               errmsgbuf);
        BUF_FREE(buf);
        return;
     }
+    assert(buf->size >= (int)sizeof(struct icmphdr));
     iph=(struct iphdr *)buf->start;
 
     source=ntohl(iph->saddr);
@@ -645,7 +688,7 @@ static void netlink_incoming(struct netlink *st, struct netlink_client *client,
        if (client) {
            st->deliver_to_host(st->dst,buf);
        } else {
-           st->clients->deliver(st->clients->dst,buf);
+           netlink_client_deliver(st,st->clients,source,dest,buf);
        }
        BUF_ASSERT_FREE(buf);
        return;
@@ -692,7 +735,7 @@ static void netlink_set_quality(void *sst, uint32_t quality)
 static void netlink_output_subnets(struct netlink *st, uint32_t loglevel,
                                   struct subnet_list *snets)
 {
-    uint32_t i;
+    int32_t i;
     string_t net;
 
     for (i=0; i<snets->entries; i++) {
@@ -721,14 +764,15 @@ static void netlink_dump_routes(struct netlink *st, bool_t requested)
        for (i=0; i<st->n_clients; i++) {
            netlink_output_subnets(st,c,st->routes[i]->subnets);
            Message(c,"-> tunnel %s (%s,mtu %d,%s routes,%s,"
-                   "quality %d,use %d)\n",
+                   "quality %d,use %d,pri %lu)\n",
                    st->routes[i]->name,
                    st->routes[i]->up?"up":"down",
                    st->routes[i]->mtu,
                    st->routes[i]->options&OPT_SOFTROUTE?"soft":"hard",
                    st->routes[i]->options&OPT_ALLOWROUTE?"free":"restricted",
                    st->routes[i]->link_quality,
-                   st->routes[i]->outcount);
+                   st->routes[i]->outcount,
+                   (unsigned long)st->routes[i]->priority);
        }
        net=ipaddr_to_string(st->secnet_address);
        Message(c,"%s/32 -> netlink \"%s\" (use %d)\n",
@@ -759,17 +803,19 @@ static void netlink_phase_hook(void *sst, uint32_t new_phase)
 {
     struct netlink *st=sst;
     struct netlink_client *c;
-    uint32_t i;
+    int32_t i;
 
     /* All the networks serviced by the various tunnels should now
      * have been registered.  We build a routing table by sorting the
      * clients by priority.  */
-    st->routes=safe_malloc(st->n_clients*sizeof(*st->routes),
-                          "netlink_phase_hook");
+    st->routes=safe_malloc_ary(sizeof(*st->routes),st->n_clients,
+                              "netlink_phase_hook");
     /* Fill the table */
     i=0;
-    for (c=st->clients; c; c=c->next)
+    for (c=st->clients; c; c=c->next) {
+       assert(i<INT_MAX);
        st->routes[i++]=c;
+    }
     /* Sort the table in descending order of priority */
     qsort(st->routes,st->n_clients,sizeof(*st->routes),
          netlink_compare_client_priority);
@@ -784,28 +830,7 @@ static void netlink_signal_handler(void *sst, int signum)
     netlink_dump_routes(st,True);
 }
 
-static void netlink_inst_output_config(void *sst, struct buffer_if *buf)
-{
-/*    struct netlink_client *c=sst; */
-/*    struct netlink *st=c->nst; */
-
-    /* For now we don't output anything */
-    BUF_ASSERT_USED(buf);
-}
-
-static bool_t netlink_inst_check_config(void *sst, struct buffer_if *buf)
-{
-/*    struct netlink_client *c=sst; */
-/*    struct netlink *st=c->nst; */
-
-    BUF_ASSERT_USED(buf);
-    /* We need to eat all of the configuration information from the buffer
-       for backward compatibility. */
-    buf->size=0;
-    return True;
-}
-
-static void netlink_inst_set_mtu(void *sst, uint32_t new_mtu)
+static void netlink_inst_set_mtu(void *sst, int32_t new_mtu)
 {
     struct netlink_client *c=sst;
 
@@ -813,14 +838,10 @@ static void netlink_inst_set_mtu(void *sst, uint32_t new_mtu)
 }
 
 static void netlink_inst_reg(void *sst, netlink_deliver_fn *deliver, 
-                            void *dst, uint32_t max_start_pad,
-                            uint32_t max_end_pad)
+                            void *dst)
 {
     struct netlink_client *c=sst;
-    struct netlink *st=c->nst;
 
-    if (max_start_pad > st->max_start_pad) st->max_start_pad=max_start_pad;
-    if (max_end_pad > st->max_end_pad) st->max_end_pad=max_end_pad;
     c->deliver=deliver;
     c->dst=dst;
 }
@@ -841,7 +862,8 @@ static closure_t *netlink_inst_create(struct netlink *st,
     struct netlink_client *c;
     string_t name;
     struct ipset *networks;
-    uint32_t options,priority,mtu;
+    uint32_t options,priority;
+    int32_t mtu;
     list_t *l;
 
     name=dict_read_string(dict, "name", True, st->name, loc);
@@ -890,8 +912,6 @@ static closure_t *netlink_inst_create(struct netlink *st,
     c->ops.reg=netlink_inst_reg;
     c->ops.deliver=netlink_inst_incoming;
     c->ops.set_quality=netlink_set_quality;
-    c->ops.output_config=netlink_inst_output_config;
-    c->ops.check_config=netlink_inst_check_config;
     c->ops.set_mtu=netlink_inst_set_mtu;
     c->nst=st;
 
@@ -901,7 +921,7 @@ static closure_t *netlink_inst_create(struct netlink *st,
     c->deliver=NULL;
     c->dst=NULL;
     c->name=name;
-    c->link_quality=LINK_QUALITY_DOWN;
+    c->link_quality=LINK_QUALITY_UNUSED;
     c->mtu=mtu?mtu:st->mtu;
     c->options=options;
     c->outcount=0;
@@ -909,6 +929,7 @@ static closure_t *netlink_inst_create(struct netlink *st,
     c->kup=False;
     c->next=st->clients;
     st->clients=c;
+    assert(st->n_clients < INT_MAX);
     st->n_clients++;
 
     return &c->cl;
@@ -948,8 +969,6 @@ netlink_deliver_fn *netlink_init(struct netlink *st,
     st->cl.type=CL_PURE;
     st->cl.apply=netlink_inst_apply;
     st->cl.interface=st;
-    st->max_start_pad=0;
-    st->max_end_pad=0;
     st->clients=NULL;
     st->routes=NULL;
     st->n_clients=0;
@@ -1065,7 +1084,6 @@ static list_t *null_apply(closure_t *self, struct cloc loc, dict_t *context,
     return new_closure(&st->nl.cl);
 }
 
-init_module netlink_module;
 void netlink_module(dict_t *dict)
 {
     add_closure(dict,"null-netlink",null_apply);