chiark / gitweb /
7db3559515f411eca917588d7fb69298765fa48a
[secnet.git] / netlink.c
1 /* User-kernel network link */
2
3 /* We support a variety of methods: userv-ipif, ipif on its own (when
4    we run as root), SLIP to a pty, an external netlink daemon. There
5    is a performance/security tradeoff. */
6
7 /* When dealing with SLIP (to a pty, or ipif) we have separate rx, tx
8    and client buffers. When receiving we may read() any amount, not
9    just whole packets. When transmitting we need to bytestuff anyway,
10    and may be part-way through receiving. */
11
12 /* Each netlink device is actually a router, with its own IP
13    address. We should eventually do things like decreasing the TTL and
14    recalculating the header checksum, generating ICMP, responding to
15    pings, etc. but for now we can get away without them. We should
16    implement this stuff no matter how we get the packets to/from the
17    kernel. */
18
19 /* This is where we have the anti-spoofing paranoia - before sending a
20    packet to the kernel we check that the tunnel it came over could
21    reasonably have produced it. */
22
23 #include <stdio.h>
24 #include <string.h>
25 #include <unistd.h>
26 #include <fcntl.h>
27
28 #include "secnet.h"
29 #include "util.h"
30
31 #define DEFAULT_BUFSIZE 2048
32
33 #define SLIP_END    192
34 #define SLIP_ESC    219
35 #define SLIP_ESCEND 220
36 #define SLIP_ESCESC 221
37
38 struct netlink_client {
39     struct subnet_list *networks;
40     netlink_deliver_fn *deliver;
41     void *dst;
42     struct netlink_client *next;
43 };
44
45 struct userv {
46     closure_t cl;
47     struct netlink_if ops;
48     uint32_t max_start_pad;
49     uint32_t max_end_pad;
50     int txfd; /* We transmit to userv */
51     int rxfd; /* We receive from userv */
52     struct netlink_client *clients;
53     string_t name;
54     string_t userv_path;
55     string_t service_user;
56     string_t service_name;
57     struct subnet_list networks;
58     uint32_t local_address;
59     uint32_t secnet_address;
60     uint32_t mtu;
61     uint32_t txbuflen;
62     struct buffer_if *buff; /* We unstuff received packets into here
63                                and send them to the site code. */
64     bool_t pending_esc;
65 };
66
67 static int userv_beforepoll(void *sst, struct pollfd *fds, int *nfds_io,
68                             int *timeout_io, const struct timeval *tv_now,
69                             uint64_t *now)
70 {
71     struct userv *st=sst;
72     *nfds_io=2;
73     fds[0].fd=st->txfd;
74     fds[0].events=POLLERR; /* Might want to pick up POLLOUT sometime */
75     fds[1].fd=st->rxfd;
76     fds[1].events=POLLIN|POLLERR|POLLHUP;
77     return 0;
78 }
79
80 static void process_local_packet(struct userv *st)
81 {
82     uint32_t source,dest;
83     struct netlink_client *c;
84
85     source=ntohl(*(uint32_t *)(st->buff->start+12));
86     dest=ntohl(*(uint32_t *)(st->buff->start+16));
87
88 /*    printf("process_local_packet source=%s dest=%s len=%d\n",
89       ipaddr_to_string(source),ipaddr_to_string(dest),
90       st->buff->size); */
91     if (!subnet_match(&st->networks,source)) {
92         string_t s,d;
93         s=ipaddr_to_string(source);
94         d=ipaddr_to_string(dest);
95         Message(M_WARNING,"%s: outgoing packet with bad source address "
96                 "(s=%s,d=%s)\n",st->name,s,d);
97         free(s); free(d);
98         return;
99     }
100     for (c=st->clients; c; c=c->next) {
101         if (subnet_match(c->networks,dest)) {
102             c->deliver(c->dst,c,st->buff);
103             BUF_ALLOC(st->buff,"netlink:process_local_packet");
104             return;
105         }
106     }
107     if (dest==st->secnet_address) {
108         printf("%s: secnet received packet of len %d from %s\n",st->name,
109                st->buff->size,ipaddr_to_string(source));
110         return;
111     }
112     {
113         string_t s,d;
114         s=ipaddr_to_string(source);
115         d=ipaddr_to_string(dest);
116         Message(M_WARNING,"%s: outgoing packet with bad destination address "
117                           "(s=%s,d=%s)\n",st->name,s,d);
118         free(s); free(d);
119         return;
120     }
121 }
122
123 static void userv_afterpoll(void *sst, struct pollfd *fds, int nfds,
124                             const struct timeval *tv_now, uint64_t *now)
125 {
126     struct userv *st=sst;
127     uint8_t rxbuf[DEFAULT_BUFSIZE];
128     int l,i;
129
130     if (fds[1].revents&POLLERR) {
131         printf("userv_afterpoll: hup!\n");
132     }
133     if (fds[1].revents&POLLIN) {
134         l=read(st->rxfd,rxbuf,DEFAULT_BUFSIZE);
135         if (l<0) {
136             fatal_perror("userv_afterpoll: read(rxfd)");
137         }
138         if (l==0) {
139             fatal("userv_afterpoll: read(rxfd)=0; userv gone away?\n");
140         }
141         /* XXX really crude unstuff code */
142         /* XXX check for buffer overflow */
143         for (i=0; i<l; i++) {
144             if (st->pending_esc) {
145                 st->pending_esc=False;
146                 switch(rxbuf[i]) {
147                 case SLIP_ESCEND:
148                     *(uint8_t *)buf_append(st->buff,1)=SLIP_END;
149                     break;
150                 case SLIP_ESCESC:
151                     *(uint8_t *)buf_append(st->buff,1)=SLIP_ESC;
152                     break;
153                 default:
154                     fatal("userv_afterpoll: bad SLIP escape character\n");
155                 }
156             } else {
157                 switch (rxbuf[i]) {
158                 case SLIP_END:
159                     if (st->buff->size>0) process_local_packet(st);
160                     BUF_ASSERT_USED(st->buff);
161                     buffer_init(st->buff,st->max_start_pad);
162                     break;
163                 case SLIP_ESC:
164                     st->pending_esc=True;
165                     break;
166                 default:
167                     *(uint8_t *)buf_append(st->buff,1)=rxbuf[i];
168                     break;
169                 }
170             }
171         }
172     }
173     return;
174 }
175
176 static void userv_phase_hook(void *sst, uint32_t newphase)
177 {
178     struct userv *st=sst;
179     pid_t child;
180     int c_stdin[2];
181     int c_stdout[2];
182     string_t addrs;
183     string_t nets;
184     string_t s;
185     struct netlink_client *c;
186     int i;
187
188     /* This is where we actually invoke userv - all the networks we'll
189        be using should already have been registered. */
190
191     addrs=safe_malloc(512,"userv_phase_hook:addrs");
192     snprintf(addrs,512,"%s,%s,%d,slip",ipaddr_to_string(st->local_address),
193              ipaddr_to_string(st->secnet_address),st->mtu);
194
195     nets=safe_malloc(1024,"userv_phase_hook:nets");
196     *nets=0;
197     for (c=st->clients; c; c=c->next) {
198         for (i=0; i<c->networks->entries; i++) {
199             s=subnet_to_string(&c->networks->list[i]);
200             strcat(nets,s);
201             strcat(nets,",");
202             free(s);
203         }
204     }
205     nets[strlen(nets)-1]=0;
206
207     Message(M_INFO,"\nuserv_phase_hook: %s %s %s %s %s\n",st->userv_path,
208            st->service_user,st->service_name,addrs,nets);
209
210     /* Allocate buffer, plus space for padding. Make sure we end up
211        with the start of the packet well-aligned. */
212     /* ALIGN(st->max_start_pad,16); */
213     /* ALIGN(st->max_end_pad,16); */
214
215     st->pending_esc=False;
216
217     /* Invoke userv */
218     if (pipe(c_stdin)!=0) {
219         fatal_perror("userv_phase_hook: pipe(c_stdin)");
220     }
221     if (pipe(c_stdout)!=0) {
222         fatal_perror("userv_phase_hook: pipe(c_stdout)");
223     }
224     st->txfd=c_stdin[1];
225     st->rxfd=c_stdout[0];
226
227     child=fork();
228     if (child==-1) {
229         fatal_perror("userv_phase_hook: fork()");
230     }
231     if (child==0) {
232         char **argv;
233
234         /* We are the child. Modify our stdin and stdout, then exec userv */
235         dup2(c_stdin[0],0);
236         dup2(c_stdout[1],1);
237         close(c_stdin[1]);
238         close(c_stdout[0]);
239
240         /* The arguments are:
241            userv
242            service-user
243            service-name
244            local-addr,secnet-addr,mtu,protocol
245            route1,route2,... */
246         argv=malloc(sizeof(*argv)*6);
247         argv[0]=st->userv_path;
248         argv[1]=st->service_user;
249         argv[2]=st->service_name;
250         argv[3]=addrs;
251         argv[4]=nets;
252         argv[5]=NULL;
253         execvp(st->userv_path,argv);
254         perror("netlink-userv-ipif: execvp");
255
256         exit(1);
257     }
258     /* We are the parent... */
259            
260     /* Register for poll() */
261     register_for_poll(st, userv_beforepoll, userv_afterpoll, 2, "netlink");
262 }
263
264 static void *userv_regnets(void *sst, struct subnet_list *nets,
265                            netlink_deliver_fn *deliver, void *dst,
266                            uint32_t max_start_pad, uint32_t max_end_pad)
267 {
268     struct userv *st=sst;
269     struct netlink_client *c;
270
271     Message(M_DEBUG_CONFIG,"userv_regnets: request for %d networks, "
272             "max_start_pad=%d, max_end_pad=%d\n",
273             nets->entries,max_start_pad,max_end_pad);
274
275     c=safe_malloc(sizeof(*c),"userv_regnets");
276     c->networks=nets;
277     c->deliver=deliver;
278     c->dst=dst;
279     c->next=st->clients;
280     st->clients=c;
281     if (max_start_pad > st->max_start_pad) st->max_start_pad=max_start_pad;
282     if (max_end_pad > st->max_end_pad) st->max_end_pad=max_end_pad;
283
284     return c;
285 }
286
287 static void userv_deliver(void *sst, void *cid, struct buffer_if *buf)
288 {
289     struct userv *st=sst;
290     struct netlink_client *client=cid;
291     uint8_t txbuf[DEFAULT_BUFSIZE];
292
293     uint32_t source,dest;
294     uint8_t *i;
295     uint32_t j;
296
297     source=ntohl(*(uint32_t *)(buf->start+12));
298     dest=ntohl(*(uint32_t *)(buf->start+16));
299
300     /* Check that the packet source is in 'nets' and its destination is
301        in client->networks */
302     if (!subnet_match(client->networks,source)) {
303         string_t s,d;
304         s=ipaddr_to_string(source);
305         d=ipaddr_to_string(dest);
306         Message(M_WARNING,"%s: incoming packet with bad source address "
307                 "(s=%s,d=%s)\n",st->name,s,d);
308         free(s); free(d);
309         return;
310     }
311     if (!subnet_match(&st->networks,dest)) {
312         string_t s,d;
313         s=ipaddr_to_string(source);
314         d=ipaddr_to_string(dest);
315         Message(M_WARNING,"%s: incoming packet with bad destination address "
316                 "(s=%s,d=%s)\n",st->name,s,d);
317         free(s); free(d);
318         return;
319     }
320
321     /* Really we should decrease TTL, check it's above zero, and
322        recalculate header checksum here. If it gets down to zero,
323        generate an ICMP time-exceeded and send the new packet back to
324        the originating tunnel. XXX check buffer usage! */
325
326     /* (Basically do full IP packet forwarding stuff. Except that we
327        know any packet passed in here is destined for the local
328        machine; only exception is if it's destined for us.) */
329
330     if (dest==st->secnet_address) {
331         printf("%s: incoming tunneled packet for secnet!\n",st->name);
332         return;
333     }
334
335     /* Now spit the packet at userv-ipif: SLIP start marker, then
336        bytestuff the packet, then SLIP end marker */
337     /* XXX crunchy bytestuff code */
338     j=0;
339     txbuf[j++]=SLIP_END;
340     for (i=buf->start; i<(buf->start+buf->size); i++) {
341         switch (*i) {
342         case SLIP_END:
343             txbuf[j++]=SLIP_ESC;
344             txbuf[j++]=SLIP_ESCEND;
345             break;
346         case SLIP_ESC:
347             txbuf[j++]=SLIP_ESC;
348             txbuf[j++]=SLIP_ESCESC;
349             break;
350         default:
351             txbuf[j++]=*i;
352             break;
353         }
354     }
355     txbuf[j++]=SLIP_END;
356     if (write(st->txfd,txbuf,j)<0) {
357         fatal_perror("userv_deliver: write()");
358     }
359
360     return;
361 }
362
363 static list_t *userv_apply(closure_t *self, struct cloc loc, dict_t *context,
364                            list_t *args)
365 {
366     struct userv *st;
367     item_t *item;
368     dict_t *dict;
369
370     st=safe_malloc(sizeof(*st),"userv_apply (netlink)");
371     st->cl.description="userv-netlink";
372     st->cl.type=CL_NETLINK;
373     st->cl.apply=NULL;
374     st->cl.interface=&st->ops;
375     st->ops.st=st;
376     st->ops.regnets=userv_regnets;
377     st->ops.deliver=userv_deliver;
378     st->max_start_pad=0;
379     st->max_end_pad=0;
380     st->rxfd=-1; st->txfd=-1;
381     st->clients=NULL;
382
383     /* First parameter must be a dict */
384     item=list_elem(args,0);
385     if (!item || item->type!=t_dict)
386         cfgfatal(loc,"userv-ipif","parameter must be a dictionary\n");
387     
388     dict=item->data.dict;
389     st->name=dict_read_string(dict,"name",False,"userv-netlink",loc);
390     st->userv_path=dict_read_string(dict,"userv-path",False,"userv-netlink",
391                                     loc);
392     st->service_user=dict_read_string(dict,"service-user",False,
393                                       "userv-netlink",loc);
394     st->service_name=dict_read_string(dict,"service-name",False,
395                                       "userv-netlink",loc);
396     if (!st->name) st->name="netlink-userv-ipif";
397     if (!st->userv_path) st->userv_path="userv";
398     if (!st->service_user) st->service_user="root";
399     if (!st->service_name) st->service_name="ipif";
400     dict_read_subnet_list(dict, "networks", True, "userv-netlink", loc,
401                           &st->networks);
402     st->local_address=string_to_ipaddr(
403         dict_find_item(dict,"local-address", True, "userv-netlink", loc),
404         "userv-netlink");
405     st->secnet_address=string_to_ipaddr(
406         dict_find_item(dict,"secnet-address", True, "userv-netlink", loc),
407         "userv-netlink");
408     if (!subnet_match(&st->networks,st->local_address)) {
409         cfgfatal(loc,"netlink-userv-ipif","local-address must be in "
410               "local networks\n");
411     }
412     st->mtu=dict_read_number(dict, "mtu", False, "userv-netlink", loc, 1000);
413     st->buff=find_cl_if(dict,"buffer",CL_BUFFER,True,"userv-netlink",loc);
414     BUF_ALLOC(st->buff,"netlink:userv_apply");
415
416     add_hook(PHASE_DROPPRIV,userv_phase_hook,st);
417
418     return new_closure(&st->cl);
419 }
420
421 struct null {
422     closure_t cl;
423     struct netlink_if ops;
424 };
425
426 static void *null_regnets(void *sst, struct subnet_list *nets,
427                           netlink_deliver_fn *deliver, void *dst,
428                           uint32_t max_start_pad, uint32_t max_end_pad)
429 {
430     Message(M_DEBUG_CONFIG,"null_regnets: request for %d networks, "
431             "max_start_pad=%d, max_end_pad=%d\n",
432             nets->entries,max_start_pad,max_end_pad);
433     return NULL;
434 }
435
436 static void null_deliver(void *sst, void *cid, struct buffer_if *buf)
437 {
438     return;
439 }
440
441 static list_t *null_apply(closure_t *self, struct cloc loc, dict_t *context,
442                           list_t *args)
443 {
444     struct null *st;
445
446     st=safe_malloc(sizeof(*st),"null_apply (netlink)");
447     st->cl.description="null-netlink";
448     st->cl.type=CL_NETLINK;
449     st->cl.apply=NULL;
450     st->cl.interface=&st->ops;
451     st->ops.st=st;
452     st->ops.regnets=null_regnets;
453     st->ops.deliver=null_deliver;
454
455     return new_closure(&st->cl);
456 }
457
458 init_module netlink_module;
459 void netlink_module(dict_t *dict)
460 {
461     add_closure(dict,"userv-ipif",userv_apply);
462 #if 0
463     add_closure(dict,"pty-slip",ptyslip_apply);
464     add_closure(dict,"slipd",slipd_apply);
465 #endif /* 0 */
466     add_closure(dict,"null-netlink",null_apply);
467 }