chiark / gitweb /
config: fix client config (no SERVER)
[hippotat.git] / hippotatlib / __init__.py
1 # -*- python -*-
2 #
3 # Hippotat - Asinine IP Over HTTP program
4 # hippotatlib/__init__.py - common library code
5 #
6 # Copyright 2017 Ian Jackson
7 #
8 # GPLv3+
9 #
10 #    This program is free software: you can redistribute it and/or modify
11 #    it under the terms of the GNU General Public License as published by
12 #    the Free Software Foundation, either version 3 of the License, or
13 #    (at your option) any later version.
14 #
15 #    This program is distributed in the hope that it will be useful,
16 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
17 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
18 #    GNU General Public License for more details.
19 #
20 #    You should have received a copy of the GNU General Public License
21 #    along with this program, in the file GPLv3.  If not,
22 #    see <http://www.gnu.org/licenses/>.
23
24
25 import signal
26 signal.signal(signal.SIGINT, signal.SIG_DFL)
27
28 import sys
29 import os
30
31 from zope.interface import implementer
32
33 import twisted
34 from twisted.internet import reactor
35 import twisted.internet.endpoints
36 import twisted.logger
37 from twisted.logger import LogLevel
38 import twisted.python.constants
39 from twisted.python.constants import NamedConstant
40
41 import ipaddress
42 from ipaddress import AddressValueError
43
44 from optparse import OptionParser
45 import configparser
46 from configparser import ConfigParser
47 from configparser import NoOptionError
48
49 from functools import partial
50
51 import collections
52 import time
53 import codecs
54 import traceback
55
56 import re as regexp
57
58 import hippotatlib.slip as slip
59
60 class DBG(twisted.python.constants.Names):
61   INIT = NamedConstant()
62   CONFIG = NamedConstant()
63   ROUTE = NamedConstant()
64   DROP = NamedConstant()
65   OWNSOURCE = NamedConstant()
66   FLOW = NamedConstant()
67   HTTP = NamedConstant()
68   TWISTED = NamedConstant()
69   QUEUE = NamedConstant()
70   HTTP_CTRL = NamedConstant()
71   QUEUE_CTRL = NamedConstant()
72   HTTP_FULL = NamedConstant()
73   CTRL_DUMP = NamedConstant()
74   SLIP_FULL = NamedConstant()
75   DATA_COMPLETE = NamedConstant()
76
77 _hex_codec = codecs.getencoder('hex_codec')
78
79 #---------- logging ----------
80
81 org_stderr = sys.stderr
82
83 log = twisted.logger.Logger()
84
85 debug_set = set()
86 debug_def_detail = DBG.HTTP
87
88 def log_debug(dflag, msg, idof=None, d=None):
89   if dflag not in debug_set: return
90   #print('---------------->',repr((dflag, msg, idof, d)), file=sys.stderr)
91   if idof is not None:
92     msg = '[%#x] %s' % (id(idof), msg)
93   if d is not None:
94     trunc = ''
95     if not DBG.DATA_COMPLETE in debug_set:
96       if len(d) > 64:
97         d = d[0:64]
98         trunc = '...'
99     d = _hex_codec(d)[0].decode('ascii')
100     msg += ' ' + d + trunc
101   log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg)
102
103 def logevent_is_boringtwisted(event):
104   try:
105     if event.get('log_level') != LogLevel.info:
106       return False
107     dflag = event.get('dflag')
108     if dflag is False                            : return False
109     if dflag                         in debug_set: return False
110     if dflag is None and DBG.TWISTED in debug_set: return False
111     return True
112   except Exception:
113     print(traceback.format_exc(), file=org_stderr)
114     return False
115
116 @implementer(twisted.logger.ILogFilterPredicate)
117 class LogNotBoringTwisted:
118   def __call__(self, event):
119     return (
120       twisted.logger.PredicateResult.no
121       if logevent_is_boringtwisted(event) else
122       twisted.logger.PredicateResult.yes
123     )
124
125 #---------- default config ----------
126
127 defcfg = '''
128 [DEFAULT]
129 max_batch_down = 65536
130 max_queue_time = 10
131 target_requests_outstanding = 3
132 http_timeout = 30
133 http_timeout_grace = 5
134 max_requests_outstanding = 6
135 max_batch_up = 4000
136 http_retry = 5
137 port = 80
138 vroutes = ''
139
140 #[server] or [<client>] overrides
141 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
142
143 # relating to virtual network
144 mtu = 1500
145
146 # addrs = 127.0.0.1 ::1
147 # url
148
149 # relating to virtual network
150 vvnetwork = 172.24.230.192
151 # vnetwork = <prefix>/<len>
152 # vadd  r  = <ipaddr>
153 # vrelay   = <ipaddr>
154
155
156 # [<client-ip4-or-ipv6-address>]
157 # password = <password>    # used by both, must match
158
159 [LIMIT]
160 max_batch_down = 262144
161 max_queue_time = 121
162 http_timeout = 121
163 target_requests_outstanding = 10
164 '''
165
166 # these need to be defined here so that they can be imported by import *
167 cfg = ConfigParser(strict=False)
168 optparser = OptionParser()
169
170 _mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')
171 def mime_translate(s):
172   # SLIP-encoded packets cannot contain ESC ESC.
173   # Swap `-' and ESC.  The result cannot contain `--'
174   return s.translate(_mimetrans)
175
176 class ConfigResults:
177   def __init__(self):
178     pass
179   def __repr__(self):
180     return 'ConfigResults('+repr(self.__dict__)+')'
181
182 def log_discard(packet, iface, saddr, daddr, why):
183   log_debug(DBG.DROP,
184             'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why),
185             d=packet)
186
187 #---------- packet parsing ----------
188
189 def packet_addrs(packet):
190   version = packet[0] >> 4
191   if version == 4:
192     addrlen = 4
193     saddroff = 3*4
194     factory = ipaddress.IPv4Address
195   elif version == 6:
196     addrlen = 16
197     saddroff = 2*4
198     factory = ipaddress.IPv6Address
199   else:
200     raise ValueError('unsupported IP version %d' % version)
201   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
202   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
203   return (saddr, daddr)
204
205 #---------- address handling ----------
206
207 def ipaddr(input):
208   try:
209     r = ipaddress.IPv4Address(input)
210   except AddressValueError:
211     r = ipaddress.IPv6Address(input)
212   return r
213
214 def ipnetwork(input):
215   try:
216     r = ipaddress.IPv4Network(input)
217   except NetworkValueError:
218     r = ipaddress.IPv6Network(input)
219   return r
220
221 #---------- ipif (SLIP) subprocess ----------
222
223 class SlipStreamDecoder():
224   def __init__(self, desc, on_packet):
225     self._buffer = b''
226     self._on_packet = on_packet
227     self._desc = desc
228     self._log('__init__')
229
230   def _log(self, msg, **kwargs):
231     log_debug(DBG.SLIP_FULL, 'slip %s: %s' % (self._desc, msg), **kwargs)
232
233   def inputdata(self, data):
234     self._log('inputdata', d=data)
235     data = self._buffer + data
236     self._buffer = b''
237     packets = slip.decode(data, True)
238     self._buffer = packets.pop()
239     for packet in packets:
240       self._maybe_packet(packet)
241     self._log('bufremain', d=self._buffer)
242
243   def _maybe_packet(self, packet):
244     self._log('maybepacket', d=packet)
245     if len(packet):
246       self._on_packet(packet)
247
248   def flush(self):
249     self._log('flush')
250     data = self._buffer
251     self._buffer = b''
252     packets = slip.decode(data)
253     assert(len(packets) == 1)
254     self._maybe_packet(packets[0])
255
256 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
257   def __init__(self, router):
258     self._router = router
259     self._decoder = SlipStreamDecoder('ipif', self.slip_on_packet)
260   def connectionMade(self): pass
261   def outReceived(self, data):
262     self._decoder.inputdata(data)
263   def slip_on_packet(self, packet):
264     (saddr, daddr) = packet_addrs(packet)
265     if saddr.is_link_local or daddr.is_link_local:
266       log_discard(packet, 'ipif', saddr, daddr, 'link-local')
267       return
268     self._router(packet, saddr, daddr)
269   def processEnded(self, status):
270     status.raiseException()
271
272 def start_ipif(command, router):
273   ipif = _IpifProcessProtocol(router)
274   reactor.spawnProcess(ipif,
275                        '/bin/sh',['sh','-xc', command],
276                        childFDs={0:'w', 1:'r', 2:2},
277                        env=None)
278   return ipif
279
280 def queue_inbound(ipif, packet):
281   log_debug(DBG.FLOW, "queue_inbound", d=packet)
282   ipif.transport.write(slip.delimiter)
283   ipif.transport.write(slip.encode(packet))
284   ipif.transport.write(slip.delimiter)
285
286 #---------- packet queue ----------
287
288 class PacketQueue():
289   def __init__(self, desc, max_queue_time):
290     self._desc = desc
291     assert(desc + '')
292     self._max_queue_time = max_queue_time
293     self._pq = collections.deque() # packets
294
295   def _log(self, dflag, msg, **kwargs):
296     log_debug(dflag, self._desc+' pq: '+msg, **kwargs)
297
298   def append(self, packet):
299     self._log(DBG.QUEUE, 'append', d=packet)
300     self._pq.append((time.monotonic(), packet))
301
302   def nonempty(self):
303     self._log(DBG.QUEUE, 'nonempty ?')
304     while True:
305       try: (queuetime, packet) = self._pq[0]
306       except IndexError:
307         self._log(DBG.QUEUE, 'nonempty ? empty.')
308         return False
309
310       age = time.monotonic() - queuetime
311       if age > self._max_queue_time:
312         # strip old packets off the front
313         self._log(DBG.QUEUE, 'dropping (old)', d=packet)
314         self._pq.popleft()
315         continue
316
317       self._log(DBG.QUEUE, 'nonempty ? nonempty.')
318       return True
319
320   def process(self, sizequery, moredata, max_batch):
321     # sizequery() should return size of batch so far
322     # moredata(s) should add s to batch
323     self._log(DBG.QUEUE, 'process...')
324     while True:
325       try: (dummy, packet) = self._pq[0]
326       except IndexError:
327         self._log(DBG.QUEUE, 'process... empty')
328         break
329
330       self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)
331
332       encoded = slip.encode(packet)
333       sofar = sizequery()  
334
335       self._log(DBG.QUEUE_CTRL,
336                 'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
337                 d=encoded)
338
339       if sofar > 0:
340         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
341           self._log(DBG.QUEUE_CTRL, 'process... overflow')
342           break
343         moredata(slip.delimiter)
344
345       moredata(encoded)
346       self._pq.popleft()
347
348 #---------- error handling ----------
349
350 _crashing = False
351
352 def crash(err):
353   global _crashing
354   _crashing = True
355   print('========== CRASH ==========', err,
356         '===========================', file=sys.stderr)
357   try: reactor.stop()
358   except twisted.internet.error.ReactorNotRunning: pass
359
360 def crash_on_defer(defer):
361   defer.addErrback(lambda err: crash(err))
362
363 def crash_on_critical(event):
364   if event.get('log_level') >= LogLevel.critical:
365     crash(twisted.logger.formatEvent(event))
366
367 #---------- config processing ----------
368
369 def _cfg_process_putatives():
370   servers = { }
371   clients = { }
372   # maps from abstract object to canonical name for cs's
373
374   def putative(cmap, abstract, canoncs):
375     try:
376       current_canoncs = cmap[abstract]
377     except KeyError:
378       pass
379     else:
380       assert(current_canoncs == canoncs)
381     cmap[abstract] = canoncs
382
383   server_pat = r'[-.0-9A-Za-z]+'
384   client_pat = r'[.:0-9a-f]+'
385   server_re = regexp.compile(server_pat)
386   serverclient_re = regexp.compile(server_pat + r' ' + client_pat)
387
388   for cs in cfg.sections():
389     if cs == 'LIMIT':
390       # plan A "[LIMIT]"
391       continue
392
393     try:
394       # plan B "[<client>]" part 1
395       ci = ipaddr(cs)
396     except AddressValueError:
397
398       if server_re.fullmatch(cs):
399         # plan C "[<servername>]"
400         putative(servers, cs, cs)
401         continue
402
403       if serverclient_re.fullmatch(cs):
404         # plan D "[<servername> <client>]" part 1
405         (pss,pcs) = cs.split(' ')
406
407         if pcs == 'LIMIT':
408           # plan E "[<servername> LIMIT]"
409           continue
410
411         try:
412           # plan D "[<servername> <client>]" part 2
413           ci = ipaddr(pc)
414         except AddressValueError:
415           # plan F "[<some thing we do not understand>]"
416           # well, we ignore this
417           print('warning: ignoring config section %s' % cs, file=sys.stderr)
418           continue
419
420         else: # no AddressValueError
421           # plan D "[<servername> <client]" part 3
422           putative(clients, ci, pcs)
423           putative(servers, pss, pss)
424           continue
425
426     else: # no AddressValueError
427       # plan B "[<client>" part 2
428       putative(clients, ci, cs)
429       continue
430
431   return (servers, clients)
432
433 def cfg_process_common(c, ss):
434   c.mtu = cfg.getint(ss, 'mtu')
435
436 def cfg_process_saddrs(c, ss):
437   class ServerAddr():
438     def __init__(self, port, addrspec):
439       self.port = port
440       # also self.addr
441       try:
442         self.addr = ipaddress.IPv4Address(addrspec)
443         self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
444         self._inurl = b'%s'
445       except AddressValueError:
446         self.addr = ipaddress.IPv6Address(addrspec)
447         self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
448         self._inurl = b'[%s]'
449     def make_endpoint(self):
450       return self._endpointfactory(reactor, self.port,
451                                    interface= '%s' % self.addr)
452     def url(self):
453       url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
454       if self.port != 80: url += b':%d' % self.port
455       url += b'/'
456       return url
457     def __repr__(self):
458       return 'ServerAddr'+repr((self.port,self.addr))
459
460   c.port = cfg.getint(ss,'port')
461   c.saddrs = [ ]
462   for addrspec in cfg.get(ss, 'addrs').split():
463     sa = ServerAddr(c.port, addrspec)
464     c.saddrs.append(sa)
465
466 def cfg_process_vnetwork(c, ss):
467   c.vnetwork = ipnetwork(cfg.get(ss,'vnetwork'))
468   if c.vnetwork.num_addresses < 3 + 2:
469     raise ValueError('vnetwork needs at least 2^3 addresses')
470
471 def cfg_process_vaddr(c, ss):
472   try:
473     c.vaddr = cfg.get(ss,'vaddr')
474   except NoOptionError:
475     cfg_process_vnetwork(c, ss)
476     c.vaddr = next(c.vnetwork.hosts())
477
478 def cfg_search_section(key,sections):
479   for section in sections:
480     if cfg.has_option(section, key):
481       return section
482   raise NoOptionError(key, repr(sections))
483
484 def cfg_search(getter,key,sections):
485   section = cfg_search_section(key,sections)
486   return getter(section, key)
487
488 def cfg_process_client_limited(cc,ss,sections,key):
489   val = cfg_search(cfg.getint, key, sections)
490   lim = cfg_search(cfg.getint, key, ['%s LIMIT' % ss, 'LIMIT'])
491   cc.__dict__[key] = min(val,lim)
492
493 def cfg_process_client_common(cc,ss,cs,ci):
494   # returns sections to search in, iff password is defined, otherwise None
495   cc.ci = ci
496
497   sections = ['%s %s' % (ss,cs),
498               cs,
499               ss,
500               'DEFAULT']
501
502   try: pwsection = cfg_search_section('password', sections)
503   except NoOptionError: return None
504     
505   pw = cfg.get(pwsection, 'password')
506   cc.password = pw.encode('utf-8')
507
508   cfg_process_client_limited(cc,ss,sections,'target_requests_outstanding')
509   cfg_process_client_limited(cc,ss,sections,'http_timeout')
510
511   return sections
512
513 def cfg_process_ipif(c, sections, varmap):
514   for d, s in varmap:
515     try: v = getattr(c, s)
516     except AttributeError: continue
517     setattr(c, d, v)
518
519   #print('CFGIPIF',repr((varmap, sections, c.__dict__)),file=sys.stderr)
520
521   section = cfg_search_section('ipif', sections)
522   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
523
524 #---------- startup ----------
525
526 def log_debug_config(m):
527   if not DBG.CONFIG in debug_set: return
528   print('DBG.CONFIG:', m)
529
530 def common_startup(process_cfg):
531   # calls process_cfg(putative_clients, putative_servers)
532
533   # ConfigParser hates #-comments after values
534   trailingcomments_re = regexp.compile(r'#.*')
535   cfg.read_string(trailingcomments_re.sub('', defcfg))
536   need_defcfg = True
537
538   def readconfig(pathname, mandatory=True):
539     def log(m, p=pathname):
540       if not DBG.CONFIG in debug_set: return
541       log_debug_config('%s: %s' % (m, pathname))
542
543     try:
544       files = os.listdir(pathname)
545
546     except FileNotFoundError:
547       if mandatory: raise
548       log('skipped')
549       return
550
551     except NotADirectoryError:
552       cfg.read(pathname)
553       log('read file')
554       return
555
556     # is a directory
557     log('directory')
558     re = regexp.compile('[^-A-Za-z0-9_]')
559     for f in os.listdir(pathname):
560       if re.search(f): continue
561       subpath = pathname + '/' + f
562       try:
563         os.stat(subpath)
564       except FileNotFoundError:
565         log('entry skipped', subpath)
566         continue
567       cfg.read(subpath)
568       log('entry read', subpath)
569       
570   def oc_config(od,os, value, op):
571     nonlocal need_defcfg
572     need_defcfg = False
573     readconfig(value)
574
575   def oc_extra_config(od,os, value, op):
576     readconfig(value)
577
578   def read_defconfig():
579     readconfig('/etc/hippotat/config.d', False)
580     readconfig('/etc/hippotat/passwords.d', False)
581     readconfig('/etc/hippotat/master.cfg',   False)
582
583   def oc_defconfig(od,os, value, op):
584     nonlocal need_defcfg
585     need_defcfg = False
586     read_defconfig(value)
587
588   def dfs_less_detailed(dl):
589     return [df for df in DBG.iterconstants() if df <= dl]
590
591   def ds_default(od,os,dl,op):
592     global debug_set
593     debug_set.clear
594     debug_set |= set(dfs_less_detailed(debug_def_detail))
595
596   def ds_select(od,os, spec, op):
597     for it in spec.split(','):
598
599       if it.startswith('-'):
600         mutator = debug_set.discard
601         it = it[1:]
602       else:
603         mutator = debug_set.add
604
605       if it == '+':
606         dfs = DBG.iterconstants()
607
608       else:
609         if it.endswith('+'):
610           mapper = dfs_less_detailed
611           it = it[0:len(it)-1]
612         else:
613           mapper = lambda x: [x]
614
615           try:
616             dfspec = DBG.lookupByName(it)
617           except ValueError:
618             optparser.error('unknown debug flag %s in --debug-select' % it)
619
620         dfs = mapper(dfspec)
621
622       for df in dfs:
623         mutator(df)
624
625   optparser.add_option('-D', '--debug',
626                        nargs=0,
627                        action='callback',
628                        help='enable default debug (to stdout)',
629                        callback= ds_default)
630
631   optparser.add_option('--debug-select',
632                        nargs=1,
633                        type='string',
634                        metavar='[-]DFLAG[+]|[-]+,...',
635                        help=
636 '''enable (`-': disable) each specified DFLAG;
637 `+': do same for all "more interesting" DFLAGSs;
638 just `+': all DFLAGs.
639   DFLAGS: ''' + ' '.join([df.name for df in DBG.iterconstants()]),
640                        action='callback',
641                        callback= ds_select)
642
643   optparser.add_option('-c', '--config',
644                        nargs=1,
645                        type='string',
646                        metavar='CONFIGFILE',
647                        dest='configfile',
648                        action='callback',
649                        callback= oc_config)
650
651   optparser.add_option('--extra-config',
652                        nargs=1,
653                        type='string',
654                        metavar='CONFIGFILE',
655                        dest='configfile',
656                        action='callback',
657                        callback= oc_extra_config)
658
659   optparser.add_option('--default-config',
660                        action='callback',
661                        callback= oc_defconfig)
662
663   (opts, args) = optparser.parse_args()
664   if len(args): optparser.error('no non-option arguments please')
665
666   if need_defcfg:
667     read_defconfig()
668
669   try:
670     (pss, pcs) = _cfg_process_putatives()
671     process_cfg(opts, pss, pcs)
672   except (configparser.Error, ValueError):
673     traceback.print_exc(file=sys.stderr)
674     print('\nInvalid configuration, giving up.', file=sys.stderr)
675     sys.exit(12)
676
677
678   #print('X', debug_set, file=sys.stderr)
679
680   log_formatter = twisted.logger.formatEventAsClassicLogText
681   stdout_obs = twisted.logger.FileLogObserver(sys.stdout, log_formatter)
682   stderr_obs = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
683   pred = twisted.logger.LogLevelFilterPredicate(LogLevel.error)
684   stdsomething_obs = twisted.logger.FilteringLogObserver(
685     stderr_obs, [pred], stdout_obs
686   )
687   global file_log_observer
688   file_log_observer = twisted.logger.FilteringLogObserver(
689     stdsomething_obs, [LogNotBoringTwisted()]
690   )
691   #log_observer = stdsomething_obs
692   twisted.logger.globalLogBeginner.beginLoggingTo(
693     [ file_log_observer, crash_on_critical ]
694     )
695
696 def common_run():
697   log_debug(DBG.INIT, 'entering reactor')
698   if not _crashing: reactor.run()
699   print('ENDED', file=sys.stderr)
700   sys.exit(16)