chiark / gitweb /
server/admin.h: Consolidate address construction during resolution.
[tripe] / pkstream / pkstream.c
1 /* -*-c-*-
2  *
3  * Forwarding UDP packets over a stream
4  *
5  * (c) 2003 Straylight/Edgeware
6  */
7
8 /*----- Licensing notice --------------------------------------------------*
9  *
10  * This file is part of Trivial IP Encryption (TrIPE).
11  *
12  * TrIPE is free software: you can redistribute it and/or modify it under
13  * the terms of the GNU General Public License as published by the Free
14  * Software Foundation; either version 3 of the License, or (at your
15  * option) any later version.
16  *
17  * TrIPE is distributed in the hope that it will be useful, but WITHOUT
18  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
19  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
20  * for more details.
21  *
22  * You should have received a copy of the GNU General Public License
23  * along with TrIPE.  If not, see <https://www.gnu.org/licenses/>.
24  */
25
26 /*----- Header files ------------------------------------------------------*/
27
28 #include "config.h"
29
30 #include <ctype.h>
31 #include <errno.h>
32 #include <stdio.h>
33 #include <stdlib.h>
34 #include <string.h>
35
36 #include <sys/time.h>
37 #include <sys/types.h>
38 #include <unistd.h>
39 #include <fcntl.h>
40 #include <sys/uio.h>
41 #include <sys/socket.h>
42 #include <netinet/in.h>
43 #include <arpa/inet.h>
44 #include <netdb.h>
45
46 #include <mLib/alloc.h>
47 #include <mLib/bits.h>
48 #include <mLib/darray.h>
49 #include <mLib/dstr.h>
50 #include <mLib/fdflags.h>
51 #include <mLib/mdwopt.h>
52 #include <mLib/quis.h>
53 #include <mLib/report.h>
54 #include <mLib/sel.h>
55 #include <mLib/selpk.h>
56
57 #include "util.h"
58
59 /*----- Data structures ---------------------------------------------------*/
60
61 typedef union addr {
62   struct sockaddr sa;
63   struct sockaddr_in sin;
64   struct sockaddr_in6 sin6;
65 } addr;
66
67 DA_DECL(addr_v, addr);
68 DA_DECL(str_v, const char *);
69
70 typedef struct pk {
71   struct pk *next;                      /* Next packet in the chain */
72   octet *p, *o;                         /* Buffer start and current posn */
73   size_t n;                             /* Size of packet remaining */
74 } pk;
75
76 typedef struct pkstream {
77   unsigned f;                           /* Flags... */
78 #define PKF_FULL 1u                     /*   Buffer is full: stop reading */
79   sel_file r, w;                        /* Read and write selectors */
80   pk *pks, **pk_tail;                   /* Packet queue */
81   size_t npk, szpk;                     /* Number and size of data */
82   selpk p;                              /* Packet parser */
83 } pkstream;
84
85 typedef struct connwait {
86   unsigned f;                           /* Various flags */
87 #define cwf_port 1u                     /*   Port is defined => listen */
88   sel_file *sfv;                        /* Selectors */
89   addr_v me, peer;                     /* Who I'm meant to be; who peer is */
90 } connwait;
91
92 /*----- Static variables --------------------------------------------------*/
93
94 static sel_state sel;
95 static connwait cw;
96 static int fd_udp;
97 static size_t pk_nmax = 128, pk_szmax = 1024*1024;
98
99 /*----- Main code ---------------------------------------------------------*/
100
101 static int nonblockify(int fd)
102   { return (fdflags(fd, O_NONBLOCK, O_NONBLOCK, 0, 0)); }
103
104 static int cloexec(int fd)
105   { return (fdflags(fd, 0, 0, FD_CLOEXEC, FD_CLOEXEC)); }
106
107 static socklen_t addrsz(const addr *a)
108 {
109   switch (a->sa.sa_family) {
110     case AF_INET: return sizeof(a->sin);
111     case AF_INET6: return sizeof(a->sin6);
112     default: abort();
113   }
114 }
115
116 static int knownafp(int af)
117 {
118   switch (af) {
119     case AF_INET: case AF_INET6: return (1);
120     default: return (0);
121   }
122 }
123
124 static int initsock(int fd, int af)
125 {
126   int yes = 1;
127
128   switch (af) {
129     case AF_INET: break;
130     case AF_INET6:
131       if (setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, &yes, sizeof(yes)))
132         return (-1);
133       break;
134     default: abort();
135   }
136   return (0);
137 }
138
139 static const char *addrstr(const addr *a)
140 {
141   static char buf[128];
142   socklen_t n = sizeof(buf);
143
144   if (getnameinfo(&a->sa, addrsz(a), buf, n, 0, 0, NI_NUMERICHOST))
145     return ("<addrstr failed>");
146   return (buf);
147 }
148
149 static int addreq(const addr *a, const addr *b)
150 {
151   if (a->sa.sa_family != b->sa.sa_family) return (0);
152   switch (a->sa.sa_family) {
153     case AF_INET:
154       return (a->sin.sin_addr.s_addr == b->sin.sin_addr.s_addr);
155     case AF_INET6:
156       return (!memcmp(a->sin6.sin6_addr.s6_addr,
157                       b->sin6.sin6_addr.s6_addr,
158                       16) &&
159               a->sin6.sin6_scope_id == b->sin6.sin6_scope_id);
160     default:
161       abort();
162   }
163 }
164
165 static void initaddr(addr *a, int af)
166 {
167   a->sa.sa_family = af;
168   switch (af) {
169     case AF_INET:
170       a->sin.sin_addr.s_addr = INADDR_ANY;
171       a->sin.sin_port = 0;
172       break;
173     case AF_INET6:
174       memset(a->sin6.sin6_addr.s6_addr, 0, 16);
175       a->sin6.sin6_port = 0;
176       a->sin6.sin6_flowinfo = 0;
177       a->sin6.sin6_scope_id = 0;
178       break;
179     default:
180       abort();
181   }
182 }
183
184 #define caf_addr 1u
185 #define caf_port 2u
186 static void copyaddr(addr *a, const struct sockaddr *sa, unsigned f)
187 {
188   const struct sockaddr_in *sin;
189   const struct sockaddr_in6 *sin6;
190
191   a->sa.sa_family = sa->sa_family;
192   switch (sa->sa_family) {
193     case AF_INET:
194       sin = (const struct sockaddr_in *)sa;
195       if (f&caf_addr) a->sin.sin_addr = sin->sin_addr;
196       if (f&caf_port) a->sin.sin_port = sin->sin_port;
197       break;
198     case AF_INET6:
199       sin6 = (const struct sockaddr_in6 *)sa;
200       if (f&caf_addr) {
201         a->sin6.sin6_addr = sin6->sin6_addr;
202         a->sin6.sin6_scope_id = sin6->sin6_scope_id;
203       }
204       if (f&caf_port) a->sin6.sin6_port = sin6->sin6_port;
205       /* ??? flowinfo? */
206       break;
207     default:
208       abort();
209   }
210 }
211
212 static void dolisten(void);
213
214 static void doclose(pkstream *p)
215 {
216   pk *pk, *ppk;
217   close(p->w.fd);
218   close(p->p.reader.fd);
219   selpk_destroy(&p->p);
220   if (!(p->f&PKF_FULL)) sel_rmfile(&p->r);
221   if (p->npk) sel_rmfile(&p->w);
222   for (pk = p->pks; pk; pk = ppk) {
223     ppk = pk->next;
224     xfree(pk->p);
225     xfree(pk);
226   }
227   xfree(p);
228   if (cw.f&cwf_port) dolisten();
229   else exit(0);
230 }
231
232 static void rdtcp(octet *b, size_t sz, pkbuf *pk, size_t *k, void *vp)
233 {
234   pkstream *p = vp;
235   size_t pksz;
236
237   if (!sz) { doclose(p); return; }
238   pksz = LOAD16(b);
239   if (pksz + 2 == sz) {
240     DISCARD(write(fd_udp, b + 2, pksz));
241     selpk_want(&p->p, 2);
242   } else {
243     selpk_want(&p->p, pksz + 2);
244     *k = sz;
245   }
246 }
247
248 static void wrtcp(int fd, unsigned mode, void *vp)
249 {
250 #define NPK 16
251   struct iovec iov[NPK];
252   pkstream *p = vp;
253   size_t i;
254   ssize_t n;
255   pk *pk, *ppk;
256
257   for (i = 0, pk = p->pks; i < NPK && pk; i++, pk = pk->next) {
258     iov[i].iov_base = pk->o;
259     iov[i].iov_len = pk->n;
260   }
261
262   if ((n = writev(fd, iov, i)) < 0) {
263     if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) return;
264     moan("couldn't write to TCP socket: %s", strerror(errno));
265     doclose(p);
266     return;
267   }
268
269   p->szpk -= n;
270   for (pk = p->pks; n && pk; pk = ppk) {
271     ppk = pk->next;
272     if (pk->n <= n) {
273       p->npk--;
274       n -= pk->n;
275       xfree(pk->p);
276       xfree(pk);
277     } else {
278       pk->n -= n;
279       pk->o += n;
280       break;
281     }
282   }
283   p->pks = pk;
284   if (!pk) { p->pk_tail = &p->pks; sel_rmfile(&p->w); }
285   if ((p->f&PKF_FULL) && p->npk < pk_nmax && p->szpk < pk_szmax)
286     { p->f &= ~PKF_FULL; sel_addfile(&p->r); }
287 }
288
289 static void rdudp(int fd, unsigned mode, void *vp)
290 {
291   octet buf[65536];
292   ssize_t n;
293   pkstream *p = vp;
294   pk *pk;
295
296   if ((n = read(fd, buf, sizeof(buf))) < 0) {
297     if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR)
298       return;
299     moan("couldn't read from UDP socket: %s", strerror(errno));
300     return;
301   }
302   pk = xmalloc(sizeof(*pk));
303   pk->next = 0;
304   pk->p = xmalloc(n + 2);
305   STORE16(pk->p, n);
306   memcpy(pk->p + 2, buf, n);
307   pk->o = pk->p;
308   pk->n = n + 2;
309   *p->pk_tail = pk;
310   p->pk_tail = &pk->next;
311   if (!p->npk) sel_addfile(&p->w);
312   sel_force(&p->w);
313   p->npk++;
314   p->szpk += n + 2;
315   if (p->npk >= pk_nmax || p->szpk >= pk_szmax)
316     { sel_rmfile(&p->r); p->f |= PKF_FULL; }
317 }
318
319 static void dofwd(int fd_in, int fd_out)
320 {
321   pkstream *p = xmalloc(sizeof(*p));
322   sel_initfile(&sel, &p->r, fd_udp, SEL_READ, rdudp, p);
323   sel_initfile(&sel, &p->w, fd_out, SEL_WRITE, wrtcp, p);
324   selpk_init(&p->p, &sel, fd_in, rdtcp, p);
325   selpk_want(&p->p, 2);
326   p->pks = 0;
327   p->pk_tail = &p->pks;
328   p->npk = p->szpk = 0;
329   p->f = 0;
330   sel_addfile(&p->r);
331 }
332
333 static void doaccept(int fd_s, unsigned mode, void *p)
334 {
335   int fd;
336   addr a;
337   socklen_t sz = sizeof(a);
338   size_t i, n;
339
340   if ((fd = accept(fd_s, &a.sa, &sz)) < 0) {
341     if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) return;
342     moan("couldn't accept incoming connection: %s", strerror(errno));
343     return;
344   }
345   n = DA_LEN(&cw.peer);
346   if (!n) goto match;
347   for (i = 0; i < n; i++) if (addreq(&a, &DA(&cw.peer)[i])) goto match;
348   moan("rejecting connection from %s", addrstr(&a));
349   close(fd); return;
350 match:
351   if (nonblockify(fd) || cloexec(fd)) {
352     moan("couldn't accept incoming connection: %s", strerror(errno));
353     close(fd); return;
354   }
355   dofwd(fd, fd);
356   n = DA_LEN(&cw.me);
357   for (i = 0; i < n; i++) { close(cw.sfv[i].fd); sel_rmfile(&cw.sfv[i]); }
358 }
359
360 static void dolisten1(const addr *a, sel_file *sf)
361 {
362   int fd;
363   int opt = 1;
364
365   if ((fd = socket(a->sa.sa_family, SOCK_STREAM, IPPROTO_TCP)) < 0 ||
366       setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) ||
367       initsock(fd, a->sa.sa_family) ||
368       bind(fd, &a->sa, addrsz(a)) ||
369       listen(fd, 1) || nonblockify(fd) || cloexec(fd))
370     die(1, "couldn't set up listening socket: %s", strerror(errno));
371   sel_initfile(&sel, sf, fd, SEL_READ, doaccept, 0);
372   sel_addfile(sf);
373 }
374
375 static void dolisten(void)
376 {
377   size_t i, n;
378
379   n = DA_LEN(&cw.me);
380   for (i = 0; i < n; i++)
381     dolisten1(&DA(&cw.me)[i], &cw.sfv[i]);
382 }
383
384 static void pushaddrs(addr_v *av, const struct addrinfo *ailist)
385 {
386   const struct addrinfo *ai;
387   size_t i, n;
388
389   for (ai = ailist, n = 0; ai; ai = ai->ai_next)
390     if (knownafp(ai->ai_family)) n++;
391   DA_ENSURE(av, n);
392   for (i = DA_LEN(av), ai = ailist; ai; ai = ai->ai_next) {
393     if (!knownafp(ai->ai_family)) continue;
394     initaddr(&DA(av)[i], ai->ai_family);
395     copyaddr(&DA(av)[i++], ai->ai_addr, caf_addr | caf_port);
396   }
397   DA_EXTEND(av, n);
398 }
399
400 #define paf_parse 1u
401 static void parseaddr(const struct addrinfo *aihint,
402                       const char *host, const char *svc, unsigned f,
403                       struct addrinfo **ai_out)
404 {
405   char *alloc = 0, *sep;
406   int err;
407
408   if (f&paf_parse) {
409     alloc = xstrdup(host);
410     if (alloc[0] != '[') {
411       if ((sep = strchr(alloc, ':')) == 0)
412         die(1, "missing port number in address `%s'", host);
413       host = alloc; *sep = 0; svc = sep + 1;
414     } else {
415       if ((sep = strchr(alloc, ']')) == 0 || sep[1] != ':')
416         die(1, "bad syntax in address `%s:'", host);
417       host = alloc + 1; *sep = 0; svc = sep + 2;
418     }
419   }
420
421   err = getaddrinfo(host, svc, aihint, ai_out);
422   if (err) {
423     if (host && svc) {
424       die(1, "failed to resolve hostname `%s', service `%s': %s",
425           host, svc, gai_strerror(err));
426     } else if (host)
427       die(1, "failed to resolve hostname `%s': %s", host, gai_strerror(err));
428     else
429       die(1, "failed to resolve service `%s': %s", svc, gai_strerror(err));
430   }
431
432   xfree(alloc);
433 }
434
435 static void usage(FILE *fp)
436 {
437   pquis(fp,
438         "Usage: $ [-46] [-l PORT] [-b ADDR] [-p ADDR] [-c ADDR:PORT]\n\
439         ADDR:PORT ADDR:PORT\n");
440 }
441
442 static void version(FILE *fp)
443   { pquis(fp, "$, tripe version " VERSION "\n"); }
444
445 static void help(FILE *fp)
446 {
447   version(fp);
448   fputc('\n', fp);
449   usage(fp);
450   fputs("\n\
451 Options:\n\
452 \n\
453 -h, --help              Display this help text.\n\
454 -v, --version           Display version number.\n\
455 -u, --usage             Display pointless usage message.\n\
456 \n\
457 -4, --ipv4              Restrict to IPv4 only.\n\
458 -6, --ipv6              Restrict to IPv6 only.\n\
459 -l, --listen=PORT       Listen for connections to TCP PORT.\n\
460 -p, --peer=ADDR         Only accept connections from IP ADDR.\n\
461 -b, --bind=ADDR         Bind to ADDR before connecting.\n\
462 -c, --connect=ADDR:PORT Connect to IP ADDR, TCP PORT.\n\
463 \n\
464 Forwards UDP packets over a reliable stream.  By default, uses stdin and\n\
465 stdout; though it can use TCP sockets instead.\n\
466 ", fp);
467 }
468
469 int main(int argc, char *argv[])
470 {
471   unsigned f = 0;
472   str_v bindhosts = DA_INIT, peerhosts = DA_INIT;
473   const char *bindsvc = 0;
474   addr bindaddr;
475   const char *connhost = 0;
476   struct addrinfo aihint = { 0 }, *ai, *ailist;
477   int af = AF_UNSPEC;
478   int fd = -1;
479   int len = 65536;
480   size_t i, n;
481
482 #define f_bogus 1u
483
484   cw.f = 0;
485
486   ego(argv[0]);
487   sel_init(&sel);
488   for (;;) {
489     static struct option opt[] = {
490       { "help",                 0,              0,      'h' },
491       { "version",              0,              0,      'v' },
492       { "usage",                0,              0,      'u' },
493       { "ipv4",                 0,              0,      '4' },
494       { "ipv6",                 0,              0,      '6' },
495       { "listen",               OPTF_ARGREQ,    0,      'l' },
496       { "peer",                 OPTF_ARGREQ,    0,      'p' },
497       { "bind",                 OPTF_ARGREQ,    0,      'b' },
498       { "connect",              OPTF_ARGREQ,    0,      'c' },
499       { 0,                      0,              0,      0 }
500     };
501     int i;
502
503     i = mdwopt(argc, argv, "hvu46l:p:b:c:", opt, 0, 0, 0);
504     if (i < 0)
505       break;
506     switch (i) {
507       case 'h': help(stdout); exit(0);
508       case 'v': version(stdout); exit(0);
509       case 'u': usage(stdout); exit(0);
510       case '4': af = AF_INET; break;
511       case '6': af = AF_INET6; break;
512       case 'l': bindsvc = optarg; break;
513       case 'p': DA_PUSH(&peerhosts, optarg); break;
514       case 'b': DA_PUSH(&bindhosts, optarg); break;
515       case 'c': connhost = optarg; break;
516       default: f |= f_bogus; break;
517     }
518   }
519   if (optind + 2 != argc || (f&f_bogus)) { usage(stderr); exit(1); }
520
521   if (DA_LEN(&bindhosts) && !bindsvc && !connhost)
522     die(1, "bind addr only makes sense when listening or connecting");
523   if (DA_LEN(&peerhosts) && !bindsvc)
524     die(1, "peer addr only makes sense when listening");
525   if (bindsvc && connhost)
526     die(1, "can't listen and connect");
527
528   aihint.ai_family = af;
529   DA_CREATE(&cw.me); DA_CREATE(&cw.peer);
530
531   n = DA_LEN(&bindhosts);
532   if (n || bindsvc) {
533     aihint.ai_socktype = SOCK_STREAM;
534     aihint.ai_protocol = IPPROTO_TCP;
535     aihint.ai_flags = AI_ADDRCONFIG | AI_PASSIVE;
536     if (!n) {
537       parseaddr(&aihint, 0, bindsvc, 0, &ailist);
538       pushaddrs(&cw.me, ailist);
539       freeaddrinfo(ailist);
540     } else if (!bindsvc) {
541       if (n != 1) die(1, "can only bind to one address as client");
542       parseaddr(&aihint, DA(&bindhosts)[0], 0, 0, &ailist);
543       for (ai = ailist; ai && !knownafp(ai->ai_family); ai = ai->ai_next);
544       if (!ai)
545         die(1, "no usable addresses returned for `%s'", DA(&bindhosts)[0]);
546       initaddr(&bindaddr, ai->ai_family);
547       copyaddr(&bindaddr, ai->ai_addr, caf_addr);
548       aihint.ai_family = ai->ai_family;
549       freeaddrinfo(ailist);
550     } else for (i = 0; i < n; i++) {
551       parseaddr(&aihint, DA(&bindhosts)[i], bindsvc, 0, &ailist);
552       pushaddrs(&cw.me, ailist);
553       freeaddrinfo(ailist);
554     }
555     if (bindsvc) {
556       cw.f |= cwf_port;
557       n = DA_LEN(&cw.me);
558       cw.sfv = xmalloc(n*sizeof(*cw.sfv));
559     }
560   }
561
562   n = DA_LEN(&peerhosts);
563   if (n) {
564     aihint.ai_socktype = SOCK_STREAM;
565     aihint.ai_protocol = IPPROTO_TCP;
566     aihint.ai_flags = AI_ADDRCONFIG;
567     for (i = 0; i < n; i++) {
568       parseaddr(&aihint, DA(&peerhosts)[i], 0, 0, &ailist);
569       pushaddrs(&cw.peer, ailist);
570       freeaddrinfo(ailist);
571     }
572     if (!DA_LEN(&cw.peer)) die(1, "no usable peer addresses");
573   }
574
575   if (connhost) {
576     aihint.ai_socktype = SOCK_STREAM;
577     aihint.ai_protocol = IPPROTO_TCP;
578     aihint.ai_flags = AI_ADDRCONFIG;
579     parseaddr(&aihint, connhost, 0, paf_parse, &ailist);
580
581     for (ai = ailist; ai; ai = ai->ai_next) {
582       if ((fd = socket(ai->ai_family, SOCK_STREAM, IPPROTO_TCP)) >= 0 &&
583           !initsock(fd, ai->ai_family) &&
584           (!DA_LEN(&bindhosts) ||
585            !bind(fd, &bindaddr.sa, addrsz(&bindaddr))) &&
586           !connect(fd, ai->ai_addr, ai->ai_addrlen))
587         goto conn_tcp;
588       if (fd >= 0) close(fd);
589     }
590     die(1, "couldn't connect to TCP server: %s", strerror(errno));
591   conn_tcp:
592     if (nonblockify(fd) || cloexec(fd))
593       die(1, "couldn't connect to TCP server: %s", strerror(errno));
594   }
595
596   aihint.ai_family = af;
597   aihint.ai_socktype = SOCK_DGRAM;
598   aihint.ai_protocol = IPPROTO_UDP;
599   aihint.ai_flags = AI_ADDRCONFIG | AI_PASSIVE;
600   parseaddr(&aihint, argv[optind], 0, paf_parse, &ailist);
601   for (ai = ailist; ai && !knownafp(ai->ai_family); ai = ai->ai_next);
602   if (!ai) die(1, "no usable addresses returned for `%s'", argv[optind]);
603   if ((fd_udp = socket(ai->ai_family, SOCK_DGRAM, IPPROTO_UDP)) < 0 ||
604       initsock(fd_udp, ai->ai_family) ||
605       nonblockify(fd_udp) || cloexec(fd_udp) ||
606       setsockopt(fd_udp, SOL_SOCKET, SO_RCVBUF, &len, sizeof(len)) ||
607       setsockopt(fd_udp, SOL_SOCKET, SO_SNDBUF, &len, sizeof(len)) ||
608       bind(fd_udp, ai->ai_addr, ai->ai_addrlen))
609     die(1, "couldn't set up UDP socket: %s", strerror(errno));
610   freeaddrinfo(ailist);
611   aihint.ai_family = ai->ai_family;
612   aihint.ai_flags = AI_ADDRCONFIG;
613   parseaddr(&aihint, argv[optind + 1], 0, paf_parse, &ailist);
614   for (ai = ailist; ai; ai = ai->ai_next)
615     if (!connect(fd_udp, ai->ai_addr, ai->ai_addrlen)) goto conn_udp;
616   die(1, "couldn't set up UDP socket: %s", strerror(errno));
617 conn_udp:
618
619   if (bindsvc) dolisten();
620   else if (connhost) dofwd(fd, fd);
621   else dofwd(STDIN_FILENO, STDOUT_FILENO);
622
623   for (;;) {
624     if (sel_select(&sel) && errno != EINTR)
625       die(1, "select failed: %s", strerror(errno));
626   }
627   return (0);
628 }
629
630 /*----- That's all, folks -------------------------------------------------*/