chiark / gitweb /
86cd3170388c2d9a66d32b7925616411a8a38874
[chiark-utils.git] / cprogs / rcopy-repeatedly.c
1 /*
2  * rcopy-repeatedly
3  */     
4   
5 /*
6  * protocol is:
7  *   server sends banner
8  *    - "#rcopy-repeatedly#\n"
9  *    - length of declaration, as 4 hex digits, zero prefixed,
10  *      and a space [5 bytes].  In this protocol version this
11  *      will be "0002" but client _must_ parse it.
12  *   server sends declaration
13  *    - one of "u " or "d" [1 byte]
14  *    - optionally, some more ascii text, reserved for future use
15  *      must be ignored by client (but not sent by server)
16  *    - a newline [1 byte]
17  *   client sends
18  *    - 0x01   go
19  * then for each update
20  *   sender sends one of
21  *    - 0x03   destination file should be deleted
22  *             but note that contents must be retained by receiver
23  *             as it may be used for rle updates
24  *    - 0x04   complete new destination file follows, 64-bit length
25  *        l    8 bytes big endian length
26  *        ...  l bytes data
27  *             receiver must then reply with 0x01 GO
28  */
29
30 #define _GNU_SOURCE
31
32 #include <stdio.h>
33 #include <time.h>
34 #include <stdarg.h>
35 #include <stdlib.h>
36 #include <stdint.h>
37 #include <string.h>
38 #include <errno.h>
39 #include <limits.h>
40 #include <assert.h>
41 #include <math.h>
42
43 #include <sys/types.h>
44 #include <sys/stat.h>
45 #include <unistd.h>
46
47 #include "myopt.h"
48
49 #define REPLMSG_GO     0x01
50 #define REPLMSG_RM     0x03
51 #define REPLMSG_FILE64 0x04
52
53 static const char banner[]= "#rcopy-repeatedly#\n";
54
55 static FILE *commsi, *commso;
56
57 static double max_bw_prop_mean= 0.2;
58 static double max_bw_prop_burst= 0.8;
59 static int txblocksz= INT_MAX, verbose=1;
60 static int min_interval_usec= 100000; /* 100ms */
61
62 static int nsargs;
63 static const char **sargs;
64
65 static const char *rsh_program= 0;
66 static const char *rcopy_repeatedly_program= "rcopy-repeatedly";
67 static int server_upcopy=-1; /* -1 means not yet known; 0 means download */
68   /* `up' means towards the client,
69    * since we regard the subprocess as `down' */
70
71 static double stream_allow_secsperbyte= 1/1e6; /* for initial transfer */
72
73 static char mainbuf[65536]; /* must be at least 2^16 */
74
75 #define NORETURN __attribute__((noreturn))
76
77 static void vdie(int ec, const char *pfx, int eno,
78                  const char *fmt, va_list al) NORETURN;
79 static void vdie(int ec, const char *pfx, int eno,
80                  const char *fmt, va_list al) {
81   fputs("rcopy-repeatedly: ",stderr);
82   if (server_upcopy>=0) fputs("server: ",stderr);
83   if (pfx) fprintf(stderr,"%s: ",pfx);
84   vfprintf(stderr,fmt,al);
85   if (eno!=-1) fprintf(stderr,": %s",strerror(eno));
86   fputc('\n',stderr);
87   exit(ec);
88 }
89 static void die(int ec, const char *pfx, int eno,
90                 const char *fmt, ...) NORETURN;
91 static void die(int ec, const char *pfx, int eno,
92                 const char *fmt, ...) {
93   va_list al;
94   va_start(al,fmt);
95   vdie(ec,pfx,eno,fmt,al);
96 }
97
98 static void diem(void) NORETURN;
99 static void diem(void) { die(16,0,errno,"malloc failed"); }
100 static void *xmalloc(size_t sz) {
101   assert(sz);
102   void *p= malloc(sz);
103   if (!p) diem();
104   return p;
105 }
106 static void *xrealloc(void *p, size_t sz) {
107   assert(sz);
108   p= realloc(p,sz);
109   if (!p) diem();
110   return p;
111 }
112
113 static void diee(const char *fmt, ...) NORETURN;
114 static void diee(const char *fmt, ...) {
115   va_list al;
116   va_start(al,fmt);
117   vdie(12,0,errno,fmt,al);
118 }
119 static void die_protocol(const char *fmt, ...) NORETURN;
120 static void die_protocol(const char *fmt, ...) {
121   va_list al;
122   va_start(al,fmt);
123   vdie(10,"protocol error",-1,fmt,al);
124 }
125
126 static void die_badrecv(const char *what) NORETURN;
127 static void die_badrecv(const char *what) {
128   if (ferror(commsi)) diee("communication failed while receiving %s", what);
129   if (feof(commsi)) die_protocol("receiver got unexpected EOF in %s", what);
130   abort();
131 }
132 static void die_badsend(void) NORETURN;
133 static void die_badsend(void) {
134   diee("transmission failed");
135 }
136
137 static void send_flush(void) {
138   if (ferror(commso) || fflush(commso))
139     die_badsend();
140 }
141 static void sendbyte(int c) {
142   if (putc(c,commso)==EOF)
143     die_badsend();
144 }
145
146 static void mpipe(int p[2]) { if (pipe(p)) diee("could not create pipe"); }
147 static void mdup2(int fd, int fd2) {
148   if (dup2(fd,fd2)!=fd2) diee("could not dup2(%d,%d)",fd,fd2);
149 }
150
151 typedef void copyfile_die_fn(FILE *f, const char *xi);
152
153 struct timespec ts_sendstart;
154
155 static void mgettime(struct timespec *ts) {
156   int r= clock_gettime(CLOCK_MONOTONIC, ts);
157   if (r) diee("clock_gettime failed");
158 }
159
160 static void bandlimit_sendstart(void) {
161   mgettime(&ts_sendstart);
162 }
163
164 static double mgettime_elapsed(struct timespec ts_base,
165                                struct timespec *ts_ret) {
166   mgettime(ts_ret);
167   return (ts_ret->tv_sec - ts_base.tv_sec) +
168          (ts_ret->tv_nsec - ts_base.tv_nsec)*1e-9;
169 }
170
171 static void flushstderr(void) {
172   if (ferror(stderr) || fflush(stderr))
173     diee("could not write progress to stderr");
174 }
175
176 static void verbosespinprintf(const char *fmt, ...) {
177   static const char spinnerchars[]= "/-\\";
178   static int spinnerchar;
179
180   if (!verbose)
181     return;
182
183   va_list al;
184   va_start(al,fmt);
185   fprintf(stderr,"      %c ",spinnerchars[spinnerchar]);
186   spinnerchar++; spinnerchar %= sizeof(spinnerchars)-1;
187   vfprintf(stderr,fmt,al);
188   putc('\r',stderr);
189   flushstderr();
190 }
191
192 static void bandlimit_sendend(uint64_t bytes, int *interval_usec_update) {
193   struct timespec ts_buf;
194   double elapsed= mgettime_elapsed(ts_sendstart, &ts_buf);
195   double secsperbyte_observed= elapsed / bytes;
196
197   stream_allow_secsperbyte=
198     secsperbyte_observed * max_bw_prop_mean / max_bw_prop_burst;
199
200   double min_update= elapsed / max_bw_prop_mean;
201   if (min_update > 1e3) min_update= 1e3;
202   int min_update_usec= min_update * 1e6;
203
204   if (*interval_usec_update > min_update_usec)
205     *interval_usec_update= min_update_usec;
206
207   verbosespinprintf("%12lluby %10.3fs %13.2fkby/s",
208                     (unsigned long long)bytes, elapsed,
209                     1e-3/secsperbyte_observed);
210 }
211  
212 static void copyfile(FILE *sf, copyfile_die_fn *sdie, const char *sxi,
213                      FILE *df, copyfile_die_fn *ddie, const char *dxi,
214                      uint64_t lstart, int amsender) {
215   struct timespec ts_last;
216   int now, r;
217   uint64_t l=lstart, done=0;
218
219   ts_last= ts_sendstart;
220
221   while (l>0) {
222     now= l < sizeof(mainbuf) ? l : sizeof(mainbuf);
223     if (now > txblocksz) now= txblocksz;
224
225     if (verbose) {
226       fprintf(stderr," %3d%% \r",
227               (int)(done*100.0/lstart));
228       flushstderr();
229     }
230
231     if (amsender) {
232       double elapsed_want= now * stream_allow_secsperbyte;
233       double elapsed= mgettime_elapsed(ts_last, &ts_last);
234       double needwait= elapsed_want - elapsed;
235       if (needwait > 1) needwait= 1;
236       if (needwait > 0) usleep(ceil(needwait * 1e6));
237     }
238
239     r= fread(mainbuf,1,now,sf);  if (r!=now) sdie(sf,sxi);
240     r= fwrite(mainbuf,1,now,df);  if (r!=now) ddie(df,dxi);
241     l -= now;
242     done += now;
243   }
244 }
245
246 static void copydie_inputfile(FILE *f, const char *filename) {
247   diee("read failed on source file `%s'", filename);
248 }
249 static void copydie_tmpwrite(FILE *f, const char *tmpfilename) {
250   diee("write failed to temporary receiving file `%s'", tmpfilename);
251 }
252 static void copydie_commsi(FILE *f, const char *what) {
253   die_badrecv(what);
254 }
255 static void copydie_commso(FILE *f, const char *what) {
256   die_badsend();
257 }
258   
259 static void receiver(const char *filename) {
260   FILE *newfile;
261   char *tmpfilename;
262   int r, c;
263
264   char *lastslash= strrchr(filename,'/');
265   if (!lastslash)
266     r= asprintf(&tmpfilename, ".rcopy-repeatedly.#%s#", filename);
267   else
268     r= asprintf(&tmpfilename, "%.*s/.rcopy-repeatedly.#%s#",
269                 (int)(lastslash-filename), filename, lastslash+1);
270   if (r==-1) diem();
271   
272   r= unlink(tmpfilename);
273   if (r && errno!=ENOENT)
274     diee("could not remove temporary receiving file `%s'", tmpfilename);
275   
276   for (;;) {
277     send_flush();
278     c= fgetc(commsi);
279
280     switch (c) {
281
282     case EOF:
283       if (ferror(commsi)) die_badrecv("transfer message code");
284       assert(feof(commsi));
285       return;
286
287     case REPLMSG_RM:
288       r= unlink(filename);
289       if (r && errno!=ENOENT)
290         diee("source file removed but could not remove destination file `%s'",
291              filename);
292       break;
293       
294     case REPLMSG_FILE64:
295       newfile= fopen(tmpfilename, "wb");
296       if (!newfile) diee("could not create temporary receiving file `%s'",
297                          tmpfilename);
298       uint8_t lbuf[8];
299       r= fread(lbuf,1,8,commsi);  if (r!=8) die_badrecv("FILE64 l");
300
301       uint64_t l=
302         (lbuf[0] << 28 << 28) |
303         (lbuf[1] << 24 << 24) |
304         (lbuf[2] << 16 << 24) |
305         (lbuf[3] <<  8 << 24) |
306         (lbuf[4]       << 24) |
307         (lbuf[5]       << 16) |
308         (lbuf[6]       <<  8) |
309         (lbuf[7]            ) ;
310
311       copyfile(commsi, copydie_commsi,"FILE64 file data",
312                newfile, copydie_tmpwrite,tmpfilename,
313                l, 0);
314
315       if (fclose(newfile)) diee("could not flush and close temporary"
316                                 " receiving file `%s'", tmpfilename);
317       if (rename(tmpfilename, filename))
318         diee("could not install new version of destination file `%s'",
319              filename);
320
321       sendbyte(REPLMSG_GO);
322       break;
323
324     default:
325       die_protocol("unknown transfer message code 0x%02x",c);
326
327     }
328   }
329 }
330
331 static void sender(const char *filename) {
332   FILE *f, *fold;
333   int interval_usec, r, c;
334   struct stat stabtest, stab;
335   enum { told_nothing, told_file, told_remove } told;
336
337   interval_usec= 0;
338   fold= 0;
339   told= told_nothing;
340   
341   for (;;) {
342     if (interval_usec) {
343       send_flush();
344       usleep(interval_usec);
345     }
346     interval_usec= min_interval_usec;
347
348     r= stat(filename, &stabtest);
349     if (r) {
350       f= 0;
351     } else {
352       if (told == told_file &&
353           stabtest.st_mode  == stab.st_mode  &&
354           stabtest.st_dev   == stab.st_dev   &&
355           stabtest.st_ino   == stab.st_ino   &&
356           stabtest.st_mtime == stab.st_mtime &&
357           stabtest.st_size  == stab.st_size)
358         continue;
359       f= fopen(filename, "rb");
360     }
361     
362     if (!f) {
363       if (errno!=ENOENT) diee("could not access source file `%s'",filename);
364       if (told != told_remove) {
365         verbosespinprintf(" ENOENT                                       ");
366         sendbyte(REPLMSG_RM);
367         told= told_remove;
368       }
369       continue;
370     }
371
372     if (fold) fclose(fold);
373     fold= 0;
374
375     r= fstat(fileno(f),&stab);
376     if (r) diee("could not fstat source file `%s'",filename);
377
378     if (!S_ISREG(stab.st_mode))
379       die(8,0,-1,"source file `%s' is not a plain file",filename);
380
381     uint8_t hbuf[9]= {
382       REPLMSG_FILE64,
383       stab.st_size >> 28 >> 28,
384       stab.st_size >> 24 >> 24,
385       stab.st_size >> 16 >> 24,
386       stab.st_size >>  8 >> 24,
387       stab.st_size       >> 24,
388       stab.st_size       >> 16,
389       stab.st_size       >>  8,
390       stab.st_size
391     };
392
393     bandlimit_sendstart();
394     
395     r= fwrite(hbuf,1,9,commso);  if (r!=9) die_badsend();
396
397     copyfile(f, copydie_inputfile,filename,
398              commso, copydie_commso,0,
399              stab.st_size, 1);
400
401     send_flush();
402
403     c= fgetc(commsi);  if (c==EOF) die_badrecv("ack");
404     if (c!=REPLMSG_GO) die_protocol("got %#02x instead of GO",c);
405
406     bandlimit_sendend(stab.st_size, &interval_usec);
407
408     fold= f;
409     told= told_file;
410   }
411 }
412
413 void usagemessage(void) {
414   puts("usage: rcopy-repeatedly [<options>] <file> <file>\n"
415        " <file> may be <local-file> or [<user>@]<host>:<file>\n"
416        " exactly one of each of the two forms must be provided\n"
417        " a file is taken as remote if it has a : before the first /\n"
418        "options\n"
419        " --help\n");
420 }
421
422 typedef struct {
423   const char *userhost, *path;
424 } FileSpecification;
425
426 static FileSpecification srcspec, dstspec;
427
428 static void of__server(const struct cmdinfo *ci, const char *val) {
429   int ncount= nsargs + 1 + !!val;
430   sargs= xrealloc(sargs, sizeof(*sargs) * ncount);
431   sargs[nsargs++]= ci->olong;
432   if (val)
433     sargs[nsargs++]= val;
434 }
435
436 static int of__server_int(const struct cmdinfo *ci, const char *val) {
437   of__server(ci,val);
438   long v;
439   char *ep;
440   errno= 0; v= strtol(val,&ep,10);
441   if (!*val || *ep || errno || v<INT_MIN || v>INT_MAX)
442     badusage("bad integer argument `%s' for --%s",val,ci->olong);
443   return v;
444 }
445
446 static void of_help(const struct cmdinfo *ci, const char *val) {
447   usagemessage();
448   if (ferror(stdout)) diee("could not write usage message to stdout");
449   exit(0);
450 }
451
452 static void of_bw(const struct cmdinfo *ci, const char *val) {
453   int pct= of__server_int(ci,val);
454   if (pct<1 || pct>100)
455     badusage("bandwidth percentage must be between 1 and 100 inclusive");
456   *(double*)ci->parg= pct * 0.01;
457 }
458
459 static void of_server_int(const struct cmdinfo *ci, const char *val) {
460   *(int*)ci->parg= of__server_int(ci,val);
461 }
462
463 static const struct cmdinfo cmdinfos[]= {
464   { "help",     .call= of_help },
465   { "max-bandwidth-percent-mean", 0,1,.call=of_bw,.parg=&max_bw_prop_mean  },
466   { "max-bandwidth-percent-burst",0,1,.call=of_bw,.parg=&max_bw_prop_burst },
467   { "tx-block-size",     0,1,.call=of_server_int, .parg=&txblocksz         },
468   { "min-interval-usec", 0,1,.call=of_server_int, .parg=&min_interval_usec },
469   { "rcopy-repeatedly",  0,1, .sassignto=&rcopy_repeatedly_program         },
470   { "ssh-program",       0,1, .sassignto=&rsh_program                      },
471   { "receiver", .iassignto=&server_upcopy, .arg=0                          },
472   { "sender",   .iassignto=&server_upcopy, .arg=1                          },
473   { 0 }
474 };
475
476 static void server(const char *filename) {
477   int c;
478   commsi= stdin;
479   commso= stdout;
480   fprintf(commso, "%s0002 %c\n", banner, server_upcopy?'u':'d');
481   send_flush();
482   c= fgetc(commsi);  if (c==EOF) die_badrecv("initial go");
483   if (c!=REPLMSG_GO) die_protocol("initial go message was %#02x instead",c);
484
485   if (server_upcopy)
486     sender(filename);
487   else
488     receiver(filename);
489 }
490
491 static void client(void) {
492   int uppipe[2], downpipe[2], r;
493   pid_t child;
494
495   mpipe(uppipe);
496   mpipe(downpipe);
497
498   FileSpecification *remotespec= srcspec.userhost ? &srcspec : &dstspec;
499   const char *remotemode= srcspec.userhost ? "--sender" : "--receiver";
500
501   sargs= xrealloc(sargs, sizeof(*sargs) * (7 + nsargs));
502   memmove(sargs+5, sargs, sizeof(*sargs) * nsargs);
503   sargs[0]= rsh_program;
504   sargs[1]= remotespec->userhost;
505   sargs[2]= rcopy_repeatedly_program;
506   sargs[3]= remotemode;
507   sargs[4]= "--";
508   sargs[5+nsargs]= remotespec->path;
509   sargs[6+nsargs]= 0;
510     
511   child= fork();
512   if (child==-1) diee("fork failed");
513   if (!child) {
514     mdup2(downpipe[0],0);
515     mdup2(uppipe[1],1);
516     close(uppipe[0]); close(downpipe[0]);
517     close(uppipe[1]); close(downpipe[1]);
518
519     execvp(rsh_program, (char**)sargs);
520     diee("failed to execute rsh program `%s'",rsh_program);
521   }
522
523   commso= fdopen(downpipe[1],"wb");
524   commsi= fdopen(uppipe[0],"rb");
525   if (!commso || !commsi) diee("fdopen failed");
526   close(downpipe[0]);
527   close(uppipe[1]);
528   
529   char banbuf[sizeof(banner)-1 + 5 + 1];
530   r= fread(banbuf,1,sizeof(banbuf)-1,commsi);
531   if (ferror(commsi)) die_badrecv("read banner");
532
533   if (r!=sizeof(banbuf)-1 ||
534       memcmp(banbuf,banner,sizeof(banner)-1) ||
535       banbuf[sizeof(banner)-1 + 4] != ' ') {
536     const char **sap;
537     int count=0;
538     for (count=0, sap=sargs; *sap; sap++) count+= strlen(*sap)+1;
539     char *cmdline= xmalloc(count+1);
540     cmdline[0]=' ';
541     for (sap=sargs; *sap; sap++) {
542       strcat(cmdline," ");
543       strcat(cmdline,*sap);
544     }
545     
546     die(8,0,-1,"did not receive banner as expected -"
547         " shell dirty? ssh broken?\n"
548         " try running\n"
549         "  %s\n"
550         " and expect the first line to be\n"
551         "  %s",
552         cmdline, banner);
553   }
554   
555   banbuf[sizeof(banbuf)-1]= 0;
556   char *ep;
557   long decllen= strtoul(banbuf + sizeof(banner)-1, &ep, 16);
558   if (ep!=banbuf + sizeof(banner)-1 + 4 || *ep!=' ')
559     die_protocol("declaration length syntax error (`%s')",ep);
560   assert(decllen <= sizeof(mainbuf));
561   if (decllen<2) die_protocol("declaration too short");
562
563   r= fread(mainbuf,1,decllen,commsi);
564   if (r!=decllen) die_badrecv("declaration");
565   if (mainbuf[decllen-1] != '\n')
566     die_protocol("declaration missing final newline");
567   if (mainbuf[0] != (remotespec==&srcspec ? 'u' : 'd'))
568     die_protocol("declaration incorrect direction indicator");
569
570   sendbyte(REPLMSG_GO);
571
572   if (remotespec==&srcspec)
573     receiver(dstspec.path);
574   else
575     sender(srcspec.path);
576 }
577
578 static void parse_file_specification(FileSpecification *fs, const char *arg,
579                                      const char *what) {
580   const char *colon;
581   
582   if (!arg) badusage("too few arguments - missing %s\n",what);
583
584   for (colon=arg; ; colon++) {
585     if (!*colon || *colon=='/') {
586       fs->userhost=0;
587       fs->path= arg;
588       return;
589     }
590     if (*colon==':') {
591       char *uh= xmalloc(colon-arg + 1);
592       memcpy(uh,arg, colon-arg);  uh[colon-arg]= 0;
593       fs->userhost= uh;
594       fs->path= colon+1;
595       return;
596     }
597   }
598 }
599
600 int main(int argc, const char *const *argv) {
601   setvbuf(stderr,0,_IOLBF,BUFSIZ);
602
603   myopt(&argv, cmdinfos);
604
605   if (!rsh_program) rsh_program= getenv("RCOPY_REPEATEDLY_RSH");
606   if (!rsh_program) rsh_program= getenv("RSYNC_RSH");
607   if (!rsh_program) rsh_program= "ssh";
608
609   if (max_bw_prop_burst / max_bw_prop_mean < 1.1)
610     badusage("max bandwidth prop burst must be at least 1.1x"
611              " max bandwidth prop mean");
612
613   if (txblocksz<1) badusage("transmit block size must be at least 1");
614   if (min_interval_usec<0) badusage("minimum update interval may not be -ve");
615
616   if (server_upcopy>=0) {
617     if (!argv[0] || argv[1])
618       badusage("server mode must have just the filename as non-option arg");
619     server(argv[0]);
620   } else {
621     parse_file_specification(&srcspec, argv[0], "source");
622     parse_file_specification(&dstspec, argv[1], "destination");
623     if (argv[2]) badusage("too many non-option arguments");
624     if (!!srcspec.userhost == !!dstspec.userhost)
625       badusage("need exactly one remote file argument");
626     client();
627   }
628   return 0;
629 }