chiark / gitweb /
comm: Introduce comm_addr_to_string
[secnet.git] / netlink.c
index b14d2ebe40420b324626b98ecc825fdfa4aac334..ea1221deaf9e67a10019f474001d99774c961872 100644 (file)
--- a/netlink.c
+++ b/netlink.c
@@ -180,7 +180,7 @@ static inline uint16_t ip_fast_csum(const uint8_t *iph, int32_t ihl) {
     return sum;
 }
 #else
-static inline uint16_t ip_fast_csum(uint8_t *iph, int32_t ihl)
+static inline uint16_t ip_fast_csum(const uint8_t *iph, int32_t ihl)
 {
     assert(ihl < INT_MAX/4);
     return ip_csum(iph,ihl*4);
@@ -237,6 +237,15 @@ struct icmphdr {
 
 static const union icmpinfofield icmp_noinfo;
     
+static void netlink_client_deliver(struct netlink *st,
+                                  struct netlink_client *client,
+                                  uint32_t source, uint32_t dest,
+                                  struct buffer_if *buf);
+static void netlink_host_deliver(struct netlink *st,
+                                struct netlink_client *sender,
+                                uint32_t source, uint32_t dest,
+                                struct buffer_if *buf);
+
 static const char *sender_name(struct netlink_client *sender /* or NULL */)
 {
     return sender?sender->name:"(local)";
@@ -254,7 +263,8 @@ static void netlink_packet_deliver(struct netlink *st,
    settable.
    */
 static struct icmphdr *netlink_icmp_tmpl(struct netlink *st,
-                                        uint32_t dest,uint16_t len)
+                                        uint32_t source, uint32_t dest,
+                                        uint16_t len)
 {
     struct icmphdr *h;
 
@@ -270,7 +280,7 @@ static struct icmphdr *netlink_icmp_tmpl(struct netlink *st,
     h->iph.frag=0;
     h->iph.ttl=255; /* XXX should be configurable */
     h->iph.protocol=1;
-    h->iph.saddr=htonl(st->secnet_address);
+    h->iph.saddr=htonl(source);
     h->iph.daddr=htonl(dest);
     h->iph.check=0;
     h->iph.check=ip_fast_csum((uint8_t *)&h->iph,h->iph.ihl);
@@ -392,12 +402,44 @@ static void netlink_icmp_simple(struct netlink *st,
 
     if (netlink_icmp_may_reply(buf)) {
        struct iphdr *iph=(struct iphdr *)buf->start;
+
+       uint32_t icmpdest = ntohl(iph->saddr);
+       uint32_t icmpsource;
+       const char *icmpsourcedebugprefix;
+       if (!st->ptp) {
+           icmpsource=st->secnet_address;
+           icmpsourcedebugprefix="";
+       } else if (origsender) {
+           /* was from peer, send reply as if from host */
+           icmpsource=st->local_address;
+           icmpsourcedebugprefix="L!";
+       } else {
+           /* was from host, send reply as if from peer */
+           icmpsource=st->secnet_address; /* actually, peer address */
+           icmpsourcedebugprefix="P!";
+       }
+       MDEBUG("%s: generating ICMP re %s[%s]->[%s]:"
+              " from %s%s type=%u code=%u\n",
+              st->name, sender_name(origsender),
+              ipaddr_to_string(ntohl(iph->saddr)),
+              ipaddr_to_string(ntohl(iph->daddr)),
+              icmpsourcedebugprefix,
+              ipaddr_to_string(icmpsource),
+              type, code);
+
        len=netlink_icmp_reply_len(buf);
-       h=netlink_icmp_tmpl(st,ntohl(iph->saddr),len);
+       h=netlink_icmp_tmpl(st,icmpsource,icmpdest,len);
        h->type=type; h->code=code; h->d=info;
        memcpy(buf_append(&st->icmp,len),buf->start,len);
        netlink_icmp_csum(h);
-       netlink_packet_deliver(st,NULL,&st->icmp);
+
+       if (!st->ptp) {
+           netlink_packet_deliver(st,NULL,&st->icmp);
+       } else if (origsender) {
+           netlink_client_deliver(st,origsender,icmpsource,icmpdest,&st->icmp);
+       } else {
+           netlink_host_deliver(st,NULL,icmpsource,icmpdest,&st->icmp);
+       }
        BUF_ASSERT_FREE(&st->icmp);
     }
 }
@@ -1198,6 +1240,8 @@ netlink_deliver_fn *netlink_init(struct netlink *st,
        st->remote_networks=ipset_complement(empty);
        ipset_free(empty);
     }
+    st->local_address=string_item_to_ipaddr(
+       dict_find_item(dict,"local-address", True, "netlink", loc),"netlink");
 
     sa=dict_find_item(dict,"secnet-address",False,"netlink",loc);
     ptpa=dict_find_item(dict,"ptp-address",False,"netlink",loc);