chiark / gitweb /
Use standard GNU uppercase for metavariables in usage strings. Some manpage
[tripe] / pkstream.c
1 /* -*-c-*-
2  *
3  * $Id: pkstream.c,v 1.3 2004/04/08 01:36:17 mdw Exp $
4  *
5  * Forwarding UDP packets over a stream
6  *
7  * (c) 2003 Straylight/Edgeware
8  */
9
10 /*----- Licensing notice --------------------------------------------------* 
11  *
12  * This file is part of Trivial IP Encryption (TrIPE).
13  *
14  * TrIPE is free software; you can redistribute it and/or modify
15  * it under the terms of the GNU General Public License as published by
16  * the Free Software Foundation; either version 2 of the License, or
17  * (at your option) any later version.
18  * 
19  * TrIPE is distributed in the hope that it will be useful,
20  * but WITHOUT ANY WARRANTY; without even the implied warranty of
21  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
22  * GNU General Public License for more details.
23  * 
24  * You should have received a copy of the GNU General Public License
25  * along with TrIPE; if not, write to the Free Software Foundation,
26  * Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
27  */
28
29 /*----- Header files ------------------------------------------------------*/
30
31 #include "config.h"
32
33 #include <ctype.h>
34 #include <errno.h>
35 #include <stdio.h>
36 #include <stdlib.h>
37 #include <string.h>
38
39 #include <sys/time.h>
40 #include <sys/types.h>
41 #include <unistd.h>
42 #include <fcntl.h>
43 #include <sys/uio.h>
44 #include <sys/socket.h>
45 #include <netinet/in.h>
46 #include <arpa/inet.h>
47 #include <netdb.h>
48
49 #include <mLib/alloc.h>
50 #include <mLib/bits.h>
51 #include <mLib/dstr.h>
52 #include <mLib/fdflags.h>
53 #include <mLib/mdwopt.h>
54 #include <mLib/quis.h>
55 #include <mLib/report.h>
56 #include <mLib/sel.h>
57 #include <mLib/selpk.h>
58
59 /*----- Data structures ---------------------------------------------------*/
60
61 typedef struct pk {
62   struct pk *next;                      /* Next packet in the chain */
63   octet *p, *o;                         /* Buffer start and current posn */
64   size_t n;                             /* Size of packet remaining */
65 } pk;
66
67 typedef struct pkstream {
68   unsigned f;                           /* Flags... */
69 #define PKF_FULL 1u                     /*   Buffer is full: stop reading */
70   sel_file r, w;                        /* Read and write selectors */
71   pk *pks, **pk_tail;                   /* Packet queue */
72   size_t npk, szpk;                     /* Number and size of data */
73   selpk p;                              /* Packet parser */
74 } pkstream;
75
76 typedef struct connwait {
77   sel_file a;                           /* Selector */
78   struct sockaddr_in me;                /* Who I'm meant to be */
79   struct in_addr peer;                  /* Who my peer is */
80 } connwait;
81
82 /*----- Static variables --------------------------------------------------*/
83
84 static sel_state sel;
85 static connwait cw;
86 static int fd_udp;
87 static size_t pk_nmax = 128, pk_szmax = 1024 * 1024;
88
89 /*----- Main code ---------------------------------------------------------*/
90
91 static int nonblockify(int fd)
92 {
93   return (fdflags(fd, O_NONBLOCK, O_NONBLOCK, 0, 0));
94 }
95
96 static int cloexec(int fd)
97 {
98   return (fdflags(fd, 0, 0, FD_CLOEXEC, FD_CLOEXEC));
99 }
100
101 static void dolisten(void);
102
103 static void doclose(pkstream *p)
104 {
105   pk *pk, *ppk;
106   close(p->w.fd);
107   close(p->p.reader.fd);
108   selpk_destroy(&p->p);
109   if (!(p->f & PKF_FULL))
110     sel_rmfile(&p->r);
111   if (p->npk)
112     sel_rmfile(&p->w);
113   for (pk = p->pks; pk; pk = ppk) {
114     ppk = pk->next;
115     xfree(pk->p);
116     xfree(pk);
117   }
118   xfree(p);
119   if (cw.me.sin_port != 0)
120     dolisten();
121   else
122     exit(0);
123 }
124
125 static void rdtcp(octet *b, size_t sz, pkbuf *pk, size_t *k, void *vp)
126 {
127   pkstream *p = vp;
128   size_t pksz;
129
130   if (!sz) {
131     doclose(p);
132     return;
133   }
134   pksz = LOAD16(b);
135   if (pksz + 2 == sz) {
136     write(fd_udp, b + 2, pksz);
137     selpk_want(&p->p, 2);
138   } else {
139     selpk_want(&p->p, pksz + 2);
140     *k = sz;
141   }
142 }
143
144 static void wrtcp(int fd, unsigned mode, void *vp)
145 {
146 #define NPK 16
147   struct iovec iov[NPK];
148   pkstream *p = vp;
149   size_t i;
150   ssize_t n;
151   pk *pk, *ppk;
152
153   for (i = 0, pk = p->pks; i < NPK && pk; i++, pk = pk->next) {
154     iov[i].iov_base = pk->o;
155     iov[i].iov_len = pk->n;
156   }
157
158   if ((n = writev(fd, iov, i)) < 0) {
159     if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR)
160       return;
161     moan("couldn't write to TCP socket: %s", strerror(errno));
162     doclose(p);
163     return;
164   }
165
166   p->szpk -= n;
167   for (pk = p->pks; n && pk; pk = ppk) {
168     ppk = pk->next;
169     if (pk->n <= n) {
170       p->npk--;
171       n -= pk->n;
172       xfree(pk->p);
173       xfree(pk);
174     } else {
175       pk->n -= n;
176       pk->o += n;
177       break;
178     }
179   }
180   p->pks = pk;
181   if (!pk) {
182     p->pk_tail = &p->pks;
183     sel_rmfile(&p->w);
184   }
185   if ((p->f & PKF_FULL) && p->npk < pk_nmax && p->szpk < pk_szmax) {
186     p->f &= ~PKF_FULL;
187     sel_addfile(&p->r);
188   }
189 }
190
191 static void rdudp(int fd, unsigned mode, void *vp)
192 {
193   octet buf[65536];
194   ssize_t n;
195   pkstream *p = vp;
196   pk *pk;
197
198   if ((n = read(fd, buf, sizeof(buf))) < 0) {
199     if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR)
200       return;
201     moan("couldn't read from UDP socket: %s", strerror(errno));
202     return;
203   }
204   pk = xmalloc(sizeof(*pk));
205   pk->next = 0;
206   pk->p = xmalloc(n + 2);
207   STORE16(pk->p, n);
208   memcpy(pk->p + 2, buf, n);
209   pk->o = pk->p;
210   pk->n = n + 2;
211   *p->pk_tail = pk;
212   p->pk_tail = &pk->next;
213   if (!p->npk)
214     sel_addfile(&p->w);
215   sel_force(&p->w);
216   p->npk++;
217   p->szpk += n + 2;
218   if (p->npk >= pk_nmax || p->szpk >= pk_szmax) {
219     sel_rmfile(&p->r);
220     p->f |= PKF_FULL;
221   }
222 }
223
224 static void dofwd(int fd_in, int fd_out)
225 {
226   pkstream *p = xmalloc(sizeof(*p));
227   sel_initfile(&sel, &p->r, fd_udp, SEL_READ, rdudp, p);
228   sel_initfile(&sel, &p->w, fd_out, SEL_WRITE, wrtcp, p);
229   selpk_init(&p->p, &sel, fd_in, rdtcp, p);
230   selpk_want(&p->p, 2);
231   p->pks = 0;
232   p->pk_tail = &p->pks;
233   p->npk = p->szpk = 0;
234   p->f = 0;
235   sel_addfile(&p->r);
236 }
237
238 static void doaccept(int fd_s, unsigned mode, void *p)
239 {
240   int fd;
241   struct sockaddr_in sin;
242   socklen_t sz = sizeof(sin);
243
244   if ((fd = accept(fd_s, (struct sockaddr *)&sin, &sz)) < 0) {
245     if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR)
246       return;
247     moan("couldn't accept incoming connection: %s", strerror(errno));
248     return;
249   }
250   if (cw.peer.s_addr != INADDR_ANY &&
251       cw.peer.s_addr != sin.sin_addr.s_addr) {
252     close(fd);
253     moan("rejecting connection from %s", inet_ntoa(sin.sin_addr));
254     return;
255   }
256   if (nonblockify(fd) || cloexec(fd)) {
257     close(fd);
258     moan("couldn't accept incoming connection: %s", strerror(errno));
259     return;
260   }
261   dofwd(fd, fd);
262   close(fd_s);
263   sel_rmfile(&cw.a);
264 }
265
266 static void dolisten(void)
267 {
268   int fd;
269   int opt = 1;
270
271   if ((fd = socket(PF_INET, SOCK_STREAM, 0)) < 0 ||
272       setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) ||
273       bind(fd, (struct sockaddr *)&cw.me, sizeof(cw.me)) ||
274       listen(fd, 1) || nonblockify(fd) || cloexec(fd))
275     die(1, "couldn't set up listening socket: %s", strerror(errno));
276   sel_initfile(&sel, &cw.a, fd, SEL_READ, doaccept, 0);
277   sel_addfile(&cw.a);
278 }
279
280 static void parseaddr(const char *pp, struct in_addr *a, unsigned short *pt)
281 {
282   char *p = xstrdup(pp);
283   char *q = 0;
284   if (a && pt) {
285     strtok(p, ":");
286     q = strtok(0, "");
287     if (!q)
288       die(1, "missing port number in address `%s'", p);
289   } else if (pt) {
290     q = p;
291   }
292
293   if (a) {
294     struct hostent *h;
295     if ((h = gethostbyname(p)) == 0)
296       die(1, "unknown host `%s'", p);
297     memcpy(a, h->h_addr, sizeof(*a));
298   }
299
300   if (pt) {
301     struct servent *s;
302     char *qq;
303     unsigned long n;
304     if ((s = getservbyname(q, "tcp")) != 0)
305       *pt = s->s_port;
306     else if ((n = strtoul(q, &qq, 0)) == 0 || *qq || n > 0xffff)
307       die(1, "bad port number `%s'", q);
308     else
309       *pt = htons(n);
310   }
311 }
312
313 static void usage(FILE *fp)
314 {
315   pquis(fp,
316         "Usage: $ [-l PORT] [-b ADDR] [-p ADDR] [-c ADDR:PORT]\n\
317         ADDR:PORT ADDR:PORT\n");
318 }
319
320 static void version(FILE *fp)
321 {
322   pquis(fp, "$, tripe version " VERSION "\n");
323 }
324
325 static void help(FILE *fp)
326 {
327   version(fp);
328   fputc('\n', fp);
329   usage(fp);
330   fputs("\n\
331 Options:\n\
332 \n\
333 -h, --help              Display this help text.\n\
334 -v, --version           Display version number.\n\
335 -u, --usage             Display pointless usage message.\n\
336 \n\
337 -l, --listen=PORT       Listen for connections to TCP PORT.\n\
338 -p, --peer=ADDR         Only accept connections from IP ADDR.\n\
339 -b, --bind=ADDR         Bind to ADDR before connecting.\n\
340 -c, --connect=ADDR:PORT Connect to IP ADDR, TCP PORT.\n\
341 \n\
342 Forwards UDP packets over a reliable stream.  By default, uses stdin and\n\
343 stdout; though it can use TCP sockets instead.\n\
344 ", fp);
345 }
346
347 int main(int argc, char *argv[])
348 {
349   unsigned f = 0;
350   unsigned short pt;
351   struct sockaddr_in connaddr, bindaddr;
352   struct sockaddr_in udp_me, udp_peer;
353   int len = 65536;
354
355 #define f_bogus 1u
356
357   ego(argv[0]);
358   bindaddr.sin_family = AF_INET;
359   bindaddr.sin_addr.s_addr = INADDR_ANY;
360   bindaddr.sin_port = 0;
361   connaddr.sin_family = AF_INET;
362   connaddr.sin_addr.s_addr = INADDR_ANY;
363   cw.me.sin_family = AF_INET;
364   cw.me.sin_addr.s_addr = INADDR_ANY;
365   cw.me.sin_port = 0;
366   cw.peer.s_addr = INADDR_ANY;
367   sel_init(&sel);
368   for (;;) {
369     static struct option opt[] = {
370       { "help",                 0,              0,      'h' },
371       { "version",              0,              0,      'v' },
372       { "usage",                0,              0,      'u' },
373       { "listen",               OPTF_ARGREQ,    0,      'l' },
374       { "peer",                 OPTF_ARGREQ,    0,      'p' },
375       { "bind",                 OPTF_ARGREQ,    0,      'b' },
376       { "connect",              OPTF_ARGREQ,    0,      'c' },
377       { 0,                      0,              0,      0 }
378     };
379     int i;
380
381     i = mdwopt(argc, argv, "hvul:p:b:c:", opt, 0, 0, 0);
382     if (i < 0)
383       break;
384     switch (i) {
385       case 'h':
386         help(stdout);
387         exit(0);
388       case 'v':
389         version(stdout);
390         exit(0);
391       case 'u':
392         usage(stdout);
393         exit(0);
394       case 'l':
395         parseaddr(optarg, 0, &pt);
396         cw.me.sin_port = pt;
397         break;
398       case 'p':
399         parseaddr(optarg, &cw.peer, 0);
400         break;
401       case 'b':
402         parseaddr(optarg, &bindaddr.sin_addr, 0);
403         break;
404       case 'c':
405         parseaddr(optarg, &connaddr.sin_addr, &pt);
406         connaddr.sin_port = pt;
407         break;
408       default:
409         f |= f_bogus;
410         break;
411     }
412   }
413   if (optind + 2 != argc || (f & f_bogus)) {
414     usage(stderr);
415     exit(1);
416   }
417
418   udp_me.sin_family = udp_peer.sin_family = AF_INET;
419   parseaddr(argv[optind], &udp_me.sin_addr, &pt);
420   udp_me.sin_port = pt; 
421   parseaddr(argv[optind + 1], &udp_peer.sin_addr, &pt);
422   udp_peer.sin_port = pt;
423
424   if ((fd_udp = socket(PF_INET, SOCK_DGRAM, 0)) < 0 ||
425       bind(fd_udp, (struct sockaddr *)&udp_me, sizeof(udp_me)) ||
426       connect(fd_udp, (struct sockaddr *)&udp_peer, sizeof(udp_peer)) ||
427       setsockopt(fd_udp, SOL_SOCKET, SO_RCVBUF, &len, sizeof(len)) ||
428       setsockopt(fd_udp, SOL_SOCKET, SO_SNDBUF, &len, sizeof(len)) ||
429       nonblockify(fd_udp) || cloexec(fd_udp))
430     die(1, "couldn't set up UDP socket: %s", strerror(errno));
431
432   if (cw.me.sin_port != 0)
433     dolisten();
434   else if (connaddr.sin_addr.s_addr != INADDR_ANY) {
435     int fd;
436     if ((fd = socket(PF_INET, SOCK_STREAM, 0)) < 0 ||
437         bind(fd, (struct sockaddr *)&bindaddr, sizeof(bindaddr)) ||
438         connect(fd, (struct sockaddr *)&connaddr, sizeof(connaddr)) ||
439         nonblockify(fd) || cloexec(fd))
440       die(1, "couldn't connect to TCP server: %s", strerror(errno));
441     dofwd(fd, fd);
442   } else
443     dofwd(STDIN_FILENO, STDOUT_FILENO);
444
445   for (;;)
446     sel_select(&sel);
447   return (0);
448 }
449
450 /*----- That's all, folks -------------------------------------------------*/