chiark / gitweb /
server/admin.c (a_vformat): Fix uses of `va_arg' to dereference `ap'.
[tripe] / pkstream / pkstream.c
1 /* -*-c-*-
2  *
3  * Forwarding UDP packets over a stream
4  *
5  * (c) 2003 Straylight/Edgeware
6  */
7
8 /*----- Licensing notice --------------------------------------------------*
9  *
10  * This file is part of Trivial IP Encryption (TrIPE).
11  *
12  * TrIPE is free software; you can redistribute it and/or modify
13  * it under the terms of the GNU General Public License as published by
14  * the Free Software Foundation; either version 2 of the License, or
15  * (at your option) any later version.
16  *
17  * TrIPE is distributed in the hope that it will be useful,
18  * but WITHOUT ANY WARRANTY; without even the implied warranty of
19  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
20  * GNU General Public License for more details.
21  *
22  * You should have received a copy of the GNU General Public License
23  * along with TrIPE; if not, write to the Free Software Foundation,
24  * Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
25  */
26
27 /*----- Header files ------------------------------------------------------*/
28
29 #include "config.h"
30
31 #include <ctype.h>
32 #include <errno.h>
33 #include <stdio.h>
34 #include <stdlib.h>
35 #include <string.h>
36
37 #include <sys/time.h>
38 #include <sys/types.h>
39 #include <unistd.h>
40 #include <fcntl.h>
41 #include <sys/uio.h>
42 #include <sys/socket.h>
43 #include <netinet/in.h>
44 #include <arpa/inet.h>
45 #include <netdb.h>
46
47 #include <mLib/alloc.h>
48 #include <mLib/bits.h>
49 #include <mLib/dstr.h>
50 #include <mLib/fdflags.h>
51 #include <mLib/mdwopt.h>
52 #include <mLib/quis.h>
53 #include <mLib/report.h>
54 #include <mLib/sel.h>
55 #include <mLib/selpk.h>
56
57 #include "util.h"
58
59 /*----- Data structures ---------------------------------------------------*/
60
61 typedef struct pk {
62   struct pk *next;                      /* Next packet in the chain */
63   octet *p, *o;                         /* Buffer start and current posn */
64   size_t n;                             /* Size of packet remaining */
65 } pk;
66
67 typedef struct pkstream {
68   unsigned f;                           /* Flags... */
69 #define PKF_FULL 1u                     /*   Buffer is full: stop reading */
70   sel_file r, w;                        /* Read and write selectors */
71   pk *pks, **pk_tail;                   /* Packet queue */
72   size_t npk, szpk;                     /* Number and size of data */
73   selpk p;                              /* Packet parser */
74 } pkstream;
75
76 typedef struct connwait {
77   sel_file a;                           /* Selector */
78   struct sockaddr_in me;                /* Who I'm meant to be */
79   struct in_addr peer;                  /* Who my peer is */
80 } connwait;
81
82 /*----- Static variables --------------------------------------------------*/
83
84 static sel_state sel;
85 static connwait cw;
86 static int fd_udp;
87 static size_t pk_nmax = 128, pk_szmax = 1024 * 1024;
88
89 /*----- Main code ---------------------------------------------------------*/
90
91 static int nonblockify(int fd)
92   { return (fdflags(fd, O_NONBLOCK, O_NONBLOCK, 0, 0)); }
93
94 static int cloexec(int fd)
95   { return (fdflags(fd, 0, 0, FD_CLOEXEC, FD_CLOEXEC)); }
96
97 static void dolisten(void);
98
99 static void doclose(pkstream *p)
100 {
101   pk *pk, *ppk;
102   close(p->w.fd);
103   close(p->p.reader.fd);
104   selpk_destroy(&p->p);
105   if (!(p->f & PKF_FULL))
106     sel_rmfile(&p->r);
107   if (p->npk)
108     sel_rmfile(&p->w);
109   for (pk = p->pks; pk; pk = ppk) {
110     ppk = pk->next;
111     xfree(pk->p);
112     xfree(pk);
113   }
114   xfree(p);
115   if (cw.me.sin_port != 0)
116     dolisten();
117   else
118     exit(0);
119 }
120
121 static void rdtcp(octet *b, size_t sz, pkbuf *pk, size_t *k, void *vp)
122 {
123   pkstream *p = vp;
124   size_t pksz;
125
126   if (!sz) {
127     doclose(p);
128     return;
129   }
130   pksz = LOAD16(b);
131   if (pksz + 2 == sz) {
132     DISCARD(write(fd_udp, b + 2, pksz));
133     selpk_want(&p->p, 2);
134   } else {
135     selpk_want(&p->p, pksz + 2);
136     *k = sz;
137   }
138 }
139
140 static void wrtcp(int fd, unsigned mode, void *vp)
141 {
142 #define NPK 16
143   struct iovec iov[NPK];
144   pkstream *p = vp;
145   size_t i;
146   ssize_t n;
147   pk *pk, *ppk;
148
149   for (i = 0, pk = p->pks; i < NPK && pk; i++, pk = pk->next) {
150     iov[i].iov_base = pk->o;
151     iov[i].iov_len = pk->n;
152   }
153
154   if ((n = writev(fd, iov, i)) < 0) {
155     if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR)
156       return;
157     moan("couldn't write to TCP socket: %s", strerror(errno));
158     doclose(p);
159     return;
160   }
161
162   p->szpk -= n;
163   for (pk = p->pks; n && pk; pk = ppk) {
164     ppk = pk->next;
165     if (pk->n <= n) {
166       p->npk--;
167       n -= pk->n;
168       xfree(pk->p);
169       xfree(pk);
170     } else {
171       pk->n -= n;
172       pk->o += n;
173       break;
174     }
175   }
176   p->pks = pk;
177   if (!pk) {
178     p->pk_tail = &p->pks;
179     sel_rmfile(&p->w);
180   }
181   if ((p->f & PKF_FULL) && p->npk < pk_nmax && p->szpk < pk_szmax) {
182     p->f &= ~PKF_FULL;
183     sel_addfile(&p->r);
184   }
185 }
186
187 static void rdudp(int fd, unsigned mode, void *vp)
188 {
189   octet buf[65536];
190   ssize_t n;
191   pkstream *p = vp;
192   pk *pk;
193
194   if ((n = read(fd, buf, sizeof(buf))) < 0) {
195     if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR)
196       return;
197     moan("couldn't read from UDP socket: %s", strerror(errno));
198     return;
199   }
200   pk = xmalloc(sizeof(*pk));
201   pk->next = 0;
202   pk->p = xmalloc(n + 2);
203   STORE16(pk->p, n);
204   memcpy(pk->p + 2, buf, n);
205   pk->o = pk->p;
206   pk->n = n + 2;
207   *p->pk_tail = pk;
208   p->pk_tail = &pk->next;
209   if (!p->npk)
210     sel_addfile(&p->w);
211   sel_force(&p->w);
212   p->npk++;
213   p->szpk += n + 2;
214   if (p->npk >= pk_nmax || p->szpk >= pk_szmax) {
215     sel_rmfile(&p->r);
216     p->f |= PKF_FULL;
217   }
218 }
219
220 static void dofwd(int fd_in, int fd_out)
221 {
222   pkstream *p = xmalloc(sizeof(*p));
223   sel_initfile(&sel, &p->r, fd_udp, SEL_READ, rdudp, p);
224   sel_initfile(&sel, &p->w, fd_out, SEL_WRITE, wrtcp, p);
225   selpk_init(&p->p, &sel, fd_in, rdtcp, p);
226   selpk_want(&p->p, 2);
227   p->pks = 0;
228   p->pk_tail = &p->pks;
229   p->npk = p->szpk = 0;
230   p->f = 0;
231   sel_addfile(&p->r);
232 }
233
234 static void doaccept(int fd_s, unsigned mode, void *p)
235 {
236   int fd;
237   struct sockaddr_in sin;
238   socklen_t sz = sizeof(sin);
239
240   if ((fd = accept(fd_s, (struct sockaddr *)&sin, &sz)) < 0) {
241     if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR)
242       return;
243     moan("couldn't accept incoming connection: %s", strerror(errno));
244     return;
245   }
246   if (cw.peer.s_addr != INADDR_ANY &&
247       cw.peer.s_addr != sin.sin_addr.s_addr) {
248     close(fd);
249     moan("rejecting connection from %s", inet_ntoa(sin.sin_addr));
250     return;
251   }
252   if (nonblockify(fd) || cloexec(fd)) {
253     close(fd);
254     moan("couldn't accept incoming connection: %s", strerror(errno));
255     return;
256   }
257   dofwd(fd, fd);
258   close(fd_s);
259   sel_rmfile(&cw.a);
260 }
261
262 static void dolisten(void)
263 {
264   int fd;
265   int opt = 1;
266
267   if ((fd = socket(PF_INET, SOCK_STREAM, 0)) < 0 ||
268       setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) ||
269       bind(fd, (struct sockaddr *)&cw.me, sizeof(cw.me)) ||
270       listen(fd, 1) || nonblockify(fd) || cloexec(fd))
271     die(1, "couldn't set up listening socket: %s", strerror(errno));
272   sel_initfile(&sel, &cw.a, fd, SEL_READ, doaccept, 0);
273   sel_addfile(&cw.a);
274 }
275
276 static void parseaddr(const char *pp, struct in_addr *a, unsigned short *pt)
277 {
278   char *p = xstrdup(pp);
279   char *q = 0;
280   if (a && pt) {
281     strtok(p, ":");
282     q = strtok(0, "");
283     if (!q)
284       die(1, "missing port number in address `%s'", p);
285   } else if (pt) {
286     q = p;
287   }
288
289   if (a) {
290     struct hostent *h;
291     if ((h = gethostbyname(p)) == 0)
292       die(1, "unknown host `%s'", p);
293     memcpy(a, h->h_addr, sizeof(*a));
294   }
295
296   if (pt) {
297     struct servent *s;
298     char *qq;
299     unsigned long n;
300     if ((s = getservbyname(q, "tcp")) != 0)
301       *pt = s->s_port;
302     else if ((n = strtoul(q, &qq, 0)) == 0 || *qq || n > 0xffff)
303       die(1, "bad port number `%s'", q);
304     else
305       *pt = htons(n);
306   }
307 }
308
309 static void usage(FILE *fp)
310 {
311   pquis(fp,
312         "Usage: $ [-l PORT] [-b ADDR] [-p ADDR] [-c ADDR:PORT]\n\
313         ADDR:PORT ADDR:PORT\n");
314 }
315
316 static void version(FILE *fp)
317   { pquis(fp, "$, tripe version " VERSION "\n"); }
318
319 static void help(FILE *fp)
320 {
321   version(fp);
322   fputc('\n', fp);
323   usage(fp);
324   fputs("\n\
325 Options:\n\
326 \n\
327 -h, --help              Display this help text.\n\
328 -v, --version           Display version number.\n\
329 -u, --usage             Display pointless usage message.\n\
330 \n\
331 -l, --listen=PORT       Listen for connections to TCP PORT.\n\
332 -p, --peer=ADDR         Only accept connections from IP ADDR.\n\
333 -b, --bind=ADDR         Bind to ADDR before connecting.\n\
334 -c, --connect=ADDR:PORT Connect to IP ADDR, TCP PORT.\n\
335 \n\
336 Forwards UDP packets over a reliable stream.  By default, uses stdin and\n\
337 stdout; though it can use TCP sockets instead.\n\
338 ", fp);
339 }
340
341 int main(int argc, char *argv[])
342 {
343   unsigned f = 0;
344   unsigned short pt;
345   struct sockaddr_in connaddr, bindaddr;
346   struct sockaddr_in udp_me, udp_peer;
347   int len = 65536;
348
349 #define f_bogus 1u
350
351   ego(argv[0]);
352   bindaddr.sin_family = AF_INET;
353   bindaddr.sin_addr.s_addr = INADDR_ANY;
354   bindaddr.sin_port = 0;
355   connaddr.sin_family = AF_INET;
356   connaddr.sin_addr.s_addr = INADDR_ANY;
357   cw.me.sin_family = AF_INET;
358   cw.me.sin_addr.s_addr = INADDR_ANY;
359   cw.me.sin_port = 0;
360   cw.peer.s_addr = INADDR_ANY;
361   sel_init(&sel);
362   for (;;) {
363     static struct option opt[] = {
364       { "help",                 0,              0,      'h' },
365       { "version",              0,              0,      'v' },
366       { "usage",                0,              0,      'u' },
367       { "listen",               OPTF_ARGREQ,    0,      'l' },
368       { "peer",                 OPTF_ARGREQ,    0,      'p' },
369       { "bind",                 OPTF_ARGREQ,    0,      'b' },
370       { "connect",              OPTF_ARGREQ,    0,      'c' },
371       { 0,                      0,              0,      0 }
372     };
373     int i;
374
375     i = mdwopt(argc, argv, "hvul:p:b:c:", opt, 0, 0, 0);
376     if (i < 0)
377       break;
378     switch (i) {
379       case 'h':
380         help(stdout);
381         exit(0);
382       case 'v':
383         version(stdout);
384         exit(0);
385       case 'u':
386         usage(stdout);
387         exit(0);
388       case 'l':
389         parseaddr(optarg, 0, &pt);
390         cw.me.sin_port = pt;
391         break;
392       case 'p':
393         parseaddr(optarg, &cw.peer, 0);
394         break;
395       case 'b':
396         parseaddr(optarg, &bindaddr.sin_addr, 0);
397         cw.me.sin_addr = bindaddr.sin_addr;
398         break;
399       case 'c':
400         parseaddr(optarg, &connaddr.sin_addr, &pt);
401         connaddr.sin_port = pt;
402         break;
403       default:
404         f |= f_bogus;
405         break;
406     }
407   }
408   if (optind + 2 != argc || (f & f_bogus)) {
409     usage(stderr);
410     exit(1);
411   }
412
413   udp_me.sin_family = udp_peer.sin_family = AF_INET;
414   parseaddr(argv[optind], &udp_me.sin_addr, &pt);
415   udp_me.sin_port = pt;
416   parseaddr(argv[optind + 1], &udp_peer.sin_addr, &pt);
417   udp_peer.sin_port = pt;
418
419   if ((fd_udp = socket(PF_INET, SOCK_DGRAM, 0)) < 0 ||
420       bind(fd_udp, (struct sockaddr *)&udp_me, sizeof(udp_me)) ||
421       connect(fd_udp, (struct sockaddr *)&udp_peer, sizeof(udp_peer)) ||
422       setsockopt(fd_udp, SOL_SOCKET, SO_RCVBUF, &len, sizeof(len)) ||
423       setsockopt(fd_udp, SOL_SOCKET, SO_SNDBUF, &len, sizeof(len)) ||
424       nonblockify(fd_udp) || cloexec(fd_udp))
425     die(1, "couldn't set up UDP socket: %s", strerror(errno));
426
427   if (cw.me.sin_port != 0)
428     dolisten();
429   else if (connaddr.sin_addr.s_addr != INADDR_ANY) {
430     int fd;
431     if ((fd = socket(PF_INET, SOCK_STREAM, 0)) < 0 ||
432         bind(fd, (struct sockaddr *)&bindaddr, sizeof(bindaddr)) ||
433         connect(fd, (struct sockaddr *)&connaddr, sizeof(connaddr)) ||
434         nonblockify(fd) || cloexec(fd))
435       die(1, "couldn't connect to TCP server: %s", strerror(errno));
436     dofwd(fd, fd);
437   } else
438     dofwd(STDIN_FILENO, STDOUT_FILENO);
439
440   for (;;)
441     sel_select(&sel);
442   return (0);
443 }
444
445 /*----- That's all, folks -------------------------------------------------*/