chiark / gitweb /
netlink: Safely discard short packets
[secnet.git] / netlink.c
index f6d4e72920ab7b9a28422ae200353e4a231dcb24..af6434feb9002c12a938a2d2923d2f92ea8b9c53 100644 (file)
--- a/netlink.c
+++ b/netlink.c
@@ -106,9 +106,6 @@ their use.
 #include "netlink.h"
 #include "process.h"
 
-#define OPT_SOFTROUTE   1
-#define OPT_ALLOWROUTE  2
-
 #define ICMP_TYPE_ECHO_REPLY             0
 
 #define ICMP_TYPE_UNREACHABLE            3
@@ -123,7 +120,7 @@ their use.
 #define ICMP_CODE_TTL_EXCEEDED           0
 
 /* Generic IP checksum routine */
-static inline uint16_t ip_csum(uint8_t *iph,uint32_t count)
+static inline uint16_t ip_csum(uint8_t *iph,int32_t count)
 {
     register uint32_t sum=0;
 
@@ -147,7 +144,7 @@ static inline uint16_t ip_csum(uint8_t *iph,uint32_t count)
  *      By Jorge Cwik <jorge@laser.satlink.net>, adapted for linux by
  *      Arnt Gulbrandsen.
  */
-static inline uint16_t ip_fast_csum(uint8_t *iph, uint32_t ihl) {
+static inline uint16_t ip_fast_csum(uint8_t *iph, int32_t ihl) {
     uint32_t sum;
 
     __asm__ __volatile__(
@@ -177,8 +174,9 @@ static inline uint16_t ip_fast_csum(uint8_t *iph, uint32_t ihl) {
     return sum;
 }
 #else
-static inline uint16_t ip_fast_csum(uint8_t *iph, uint32_t ihl)
+static inline uint16_t ip_fast_csum(uint8_t *iph, int32_t ihl)
 {
+    assert(ihl < INT_MAX/4);
     return ip_csum(iph,ihl*4);
 }
 #endif
@@ -240,7 +238,7 @@ static struct icmphdr *netlink_icmp_tmpl(struct netlink *st,
     struct icmphdr *h;
 
     BUF_ALLOC(&st->icmp,"netlink_icmp_tmpl");
-    buffer_init(&st->icmp,st->max_start_pad);
+    buffer_init(&st->icmp,calculate_max_start_pad());
     h=buf_append(&st->icmp,sizeof(*h));
 
     h->iph.version=4;
@@ -264,7 +262,7 @@ static struct icmphdr *netlink_icmp_tmpl(struct netlink *st,
 /* Fill in the ICMP checksum field correctly */
 static void netlink_icmp_csum(struct icmphdr *h)
 {
-    uint32_t len;
+    int32_t len;
 
     len=ntohs(h->iph.tot_len)-(4*h->iph.ihl);
     h->check=0;
@@ -295,6 +293,7 @@ static bool_t netlink_icmp_may_reply(struct buffer_if *buf)
     struct icmphdr *icmph;
     uint32_t source;
 
+    if (buf->size < (int)sizeof(struct icmphdr)) return False;
     iph=(struct iphdr *)buf->start;
     icmph=(struct icmphdr *)buf->start;
     if (iph->protocol==1) {
@@ -340,6 +339,7 @@ static bool_t netlink_icmp_may_reply(struct buffer_if *buf)
    */
 static uint16_t netlink_icmp_reply_len(struct buffer_if *buf)
 {
+    if (buf->size < (int)sizeof(struct iphdr)) return 0;
     struct iphdr *iph=(struct iphdr *)buf->start;
     uint16_t hlen,plen;
 
@@ -356,11 +356,11 @@ static void netlink_icmp_simple(struct netlink *st, struct buffer_if *buf,
                                struct netlink_client *client,
                                uint8_t type, uint8_t code)
 {
-    struct iphdr *iph=(struct iphdr *)buf->start;
     struct icmphdr *h;
     uint16_t len;
 
     if (netlink_icmp_may_reply(buf)) {
+       struct iphdr *iph=(struct iphdr *)buf->start;
        len=netlink_icmp_reply_len(buf);
        h=netlink_icmp_tmpl(st,ntohl(iph->saddr),len);
        h->type=type; h->code=code;
@@ -383,19 +383,30 @@ static void netlink_icmp_simple(struct netlink *st, struct buffer_if *buf,
  * 3. Checksums correctly.
  * 4. Doesn't have a bogus length
  */
-static bool_t netlink_check(struct netlink *st, struct buffer_if *buf)
+static bool_t netlink_check(struct netlink *st, struct buffer_if *buf,
+                           char *errmsgbuf, int errmsgbuflen)
 {
+#define BAD(...) do{                                   \
+       snprintf(errmsgbuf,errmsgbuflen,__VA_ARGS__);   \
+       return False;                                   \
+    }while(0)
+
+    if (buf->size < (int)sizeof(struct iphdr)) BAD("len %"PRIu32"",buf->size);
     struct iphdr *iph=(struct iphdr *)buf->start;
-    uint32_t len;
+    int32_t len;
 
-    if (iph->ihl < 5 || iph->version != 4) return False;
-    if (buf->size < iph->ihl*4) return False;
-    if (ip_fast_csum((uint8_t *)iph, iph->ihl)!=0) return False;
+    if (iph->ihl < 5) BAD("ihl %u",iph->ihl);
+    if (iph->version != 4) BAD("version %u",iph->version);
+    if (buf->size < iph->ihl*4) BAD("size %"PRId32"<%u*4",buf->size,iph->ihl);
+    if (ip_fast_csum((uint8_t *)iph, iph->ihl)!=0) BAD("csum");
     len=ntohs(iph->tot_len);
     /* There should be no padding */
-    if (buf->size!=len || len<(iph->ihl<<2)) return False;
+    if (buf->size!=len) BAD("len %"PRId32"!=%"PRId32,buf->size,len);
+    if (len<(iph->ihl<<2)) BAD("len %"PRId32"<(%u<<2)",len,iph->ihl);
     /* XXX check that there's no source route specified */
     return True;
+
+#undef BAD
 }
 
 /* Deliver a packet. "client" is the _origin_ of the packet, not its
@@ -405,6 +416,13 @@ static void netlink_packet_deliver(struct netlink *st,
                                   struct netlink_client *client,
                                   struct buffer_if *buf)
 {
+    if (buf->size < (int)sizeof(struct iphdr)) {
+       Message(M_ERR,"%s: trying to deliver a too-short packet"
+               " from %s!\n",st->name, client?client->name:"(local)");
+       BUF_FREE(buf);
+       return;
+    }
+
     struct iphdr *iph=(struct iphdr *)buf->start;
     uint32_t dest=ntohl(iph->daddr);
     uint32_t source=ntohl(iph->saddr);
@@ -522,6 +540,7 @@ static void netlink_packet_forward(struct netlink *st,
                                   struct netlink_client *client,
                                   struct buffer_if *buf)
 {
+    if (buf->size < (int)sizeof(struct iphdr)) return;
     struct iphdr *iph=(struct iphdr *)buf->start;
     
     BUF_ASSERT_USED(buf);
@@ -551,6 +570,12 @@ static void netlink_packet_local(struct netlink *st,
 
     st->localcount++;
 
+    if (buf->size < (int)sizeof(struct icmphdr)) {
+       Message(M_WARNING,"%s: short packet addressed to secnet; "
+               "ignoring it\n",st->name);
+       BUF_FREE(buf);
+       return;
+    }
     h=(struct icmphdr *)buf->start;
 
     if ((ntohs(h->iph.frag_off)&0xbfff)!=0) {
@@ -594,14 +619,19 @@ static void netlink_incoming(struct netlink *st, struct netlink_client *client,
 {
     uint32_t source,dest;
     struct iphdr *iph;
+    char errmsgbuf[50];
+    const char *sourcedesc=client?client->name:"host";
 
     BUF_ASSERT_USED(buf);
-    if (!netlink_check(st,buf)) {
-       Message(M_WARNING,"%s: bad IP packet from %s\n",
-               st->name,client?client->name:"host");
+
+    if (!netlink_check(st,buf,errmsgbuf,sizeof(errmsgbuf))) {
+       Message(M_WARNING,"%s: bad IP packet from %s: %s\n",
+               st->name,sourcedesc,
+               errmsgbuf);
        BUF_FREE(buf);
        return;
     }
+    assert(buf->size >= (int)sizeof(struct icmphdr));
     iph=(struct iphdr *)buf->start;
 
     source=ntohl(iph->saddr);
@@ -695,7 +725,7 @@ static void netlink_set_quality(void *sst, uint32_t quality)
 static void netlink_output_subnets(struct netlink *st, uint32_t loglevel,
                                   struct subnet_list *snets)
 {
-    uint32_t i;
+    int32_t i;
     string_t net;
 
     for (i=0; i<snets->entries; i++) {
@@ -763,13 +793,13 @@ static void netlink_phase_hook(void *sst, uint32_t new_phase)
 {
     struct netlink *st=sst;
     struct netlink_client *c;
-    uint32_t i;
+    int32_t i;
 
     /* All the networks serviced by the various tunnels should now
      * have been registered.  We build a routing table by sorting the
      * clients by priority.  */
-    st->routes=safe_malloc(st->n_clients*sizeof(*st->routes),
-                          "netlink_phase_hook");
+    st->routes=safe_malloc_ary(sizeof(*st->routes),st->n_clients,
+                              "netlink_phase_hook");
     /* Fill the table */
     i=0;
     for (c=st->clients; c; c=c->next) {
@@ -790,28 +820,7 @@ static void netlink_signal_handler(void *sst, int signum)
     netlink_dump_routes(st,True);
 }
 
-static void netlink_inst_output_config(void *sst, struct buffer_if *buf)
-{
-/*    struct netlink_client *c=sst; */
-/*    struct netlink *st=c->nst; */
-
-    /* For now we don't output anything */
-    BUF_ASSERT_USED(buf);
-}
-
-static bool_t netlink_inst_check_config(void *sst, struct buffer_if *buf)
-{
-/*    struct netlink_client *c=sst; */
-/*    struct netlink *st=c->nst; */
-
-    BUF_ASSERT_USED(buf);
-    /* We need to eat all of the configuration information from the buffer
-       for backward compatibility. */
-    buf->size=0;
-    return True;
-}
-
-static void netlink_inst_set_mtu(void *sst, uint32_t new_mtu)
+static void netlink_inst_set_mtu(void *sst, int32_t new_mtu)
 {
     struct netlink_client *c=sst;
 
@@ -819,14 +828,10 @@ static void netlink_inst_set_mtu(void *sst, uint32_t new_mtu)
 }
 
 static void netlink_inst_reg(void *sst, netlink_deliver_fn *deliver, 
-                            void *dst, uint32_t max_start_pad,
-                            uint32_t max_end_pad)
+                            void *dst)
 {
     struct netlink_client *c=sst;
-    struct netlink *st=c->nst;
 
-    if (max_start_pad > st->max_start_pad) st->max_start_pad=max_start_pad;
-    if (max_end_pad > st->max_end_pad) st->max_end_pad=max_end_pad;
     c->deliver=deliver;
     c->dst=dst;
 }
@@ -847,7 +852,8 @@ static closure_t *netlink_inst_create(struct netlink *st,
     struct netlink_client *c;
     string_t name;
     struct ipset *networks;
-    uint32_t options,priority,mtu;
+    uint32_t options,priority;
+    int32_t mtu;
     list_t *l;
 
     name=dict_read_string(dict, "name", True, st->name, loc);
@@ -896,8 +902,6 @@ static closure_t *netlink_inst_create(struct netlink *st,
     c->ops.reg=netlink_inst_reg;
     c->ops.deliver=netlink_inst_incoming;
     c->ops.set_quality=netlink_set_quality;
-    c->ops.output_config=netlink_inst_output_config;
-    c->ops.check_config=netlink_inst_check_config;
     c->ops.set_mtu=netlink_inst_set_mtu;
     c->nst=st;
 
@@ -907,7 +911,7 @@ static closure_t *netlink_inst_create(struct netlink *st,
     c->deliver=NULL;
     c->dst=NULL;
     c->name=name;
-    c->link_quality=LINK_QUALITY_DOWN;
+    c->link_quality=LINK_QUALITY_UNUSED;
     c->mtu=mtu?mtu:st->mtu;
     c->options=options;
     c->outcount=0;
@@ -955,8 +959,6 @@ netlink_deliver_fn *netlink_init(struct netlink *st,
     st->cl.type=CL_PURE;
     st->cl.apply=netlink_inst_apply;
     st->cl.interface=st;
-    st->max_start_pad=0;
-    st->max_end_pad=0;
     st->clients=NULL;
     st->routes=NULL;
     st->n_clients=0;