chiark / gitweb /
better config errors
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7 import os
8
9 from zope.interface import implementer
10
11 import twisted
12 from twisted.internet import reactor
13 import twisted.internet.endpoints
14 import twisted.logger
15 from twisted.logger import LogLevel
16 import twisted.python.constants
17 from twisted.python.constants import NamedConstant
18
19 import ipaddress
20 from ipaddress import AddressValueError
21
22 from optparse import OptionParser
23 import configparser
24 from configparser import ConfigParser
25 from configparser import NoOptionError
26
27 from functools import partial
28
29 import collections
30 import time
31 import codecs
32 import traceback
33
34 import re as regexp
35
36 import hippotat.slip as slip
37
38 class DBG(twisted.python.constants.Names):
39   INIT = NamedConstant()
40   CONFIG = NamedConstant()
41   ROUTE = NamedConstant()
42   DROP = NamedConstant()
43   FLOW = NamedConstant()
44   HTTP = NamedConstant()
45   TWISTED = NamedConstant()
46   QUEUE = NamedConstant()
47   HTTP_CTRL = NamedConstant()
48   QUEUE_CTRL = NamedConstant()
49   HTTP_FULL = NamedConstant()
50   CTRL_DUMP = NamedConstant()
51   SLIP_FULL = NamedConstant()
52   DATA_COMPLETE = NamedConstant()
53
54 _hex_codec = codecs.getencoder('hex_codec')
55
56 #---------- logging ----------
57
58 org_stderr = sys.stderr
59
60 log = twisted.logger.Logger()
61
62 debug_set = set()
63 debug_def_detail = DBG.HTTP
64
65 def log_debug(dflag, msg, idof=None, d=None):
66   if dflag not in debug_set: return
67   #print('---------------->',repr((dflag, msg, idof, d)), file=sys.stderr)
68   if idof is not None:
69     msg = '[%#x] %s' % (id(idof), msg)
70   if d is not None:
71     trunc = ''
72     if not DBG.DATA_COMPLETE in debug_set:
73       if len(d) > 64:
74         d = d[0:64]
75         trunc = '...'
76     d = _hex_codec(d)[0].decode('ascii')
77     msg += ' ' + d + trunc
78   log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg)
79
80 @implementer(twisted.logger.ILogFilterPredicate)
81 class LogNotBoringTwisted:
82   def __call__(self, event):
83     yes = twisted.logger.PredicateResult.yes
84     no  = twisted.logger.PredicateResult.no
85     try:
86       if event.get('log_level') != LogLevel.info:
87         return yes
88       dflag = event.get('dflag')
89       if dflag                         in debug_set: return yes
90       if dflag is None and DBG.TWISTED in debug_set: return yes
91       return no
92     except Exception:
93       print(traceback.format_exc(), file=org_stderr)
94       return yes
95
96 #---------- default config ----------
97
98 defcfg = '''
99 [DEFAULT]
100 #[<client>] overrides
101 max_batch_down = 65536           # used by server, subject to [limits]
102 max_queue_time = 10              # used by server, subject to [limits]
103 target_requests_outstanding = 3  # must match; subject to [limits] on server
104 http_timeout = 30                # used by both } must be
105 http_timeout_grace = 5           # used by both }  compatible
106 max_requests_outstanding = 4     # used by client
107 max_batch_up = 4000              # used by client
108 http_retry = 5                   # used by client
109
110 #[server] or [<client>] overrides
111 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
112 # extra interpolations:  %(local)s        %(peer)s          %(rnet)s
113 #  obtained   on server  [virtual]server  [virtual]relay    [virtual]network
114 #      from   on client  <client>         [virtual]server   [virtual]routes
115
116 [virtual]
117 mtu = 1500
118 routes = ''
119 # network = <prefix>/<len>  # mandatory for server
120 # server  = <ipaddr>   # used by both, default is computed from `network'
121 # relay   = <ipaddr>   # used by server, default from `network' and `server'
122 #  default server is first host in network
123 #  default relay is first host which is not server
124
125 [server]
126 # addrs = 127.0.0.1 ::1    # mandatory for server
127 port = 80                  # used by server
128 # url              # used by client; default from first `addrs' and `port'
129
130 # [<client-ip4-or-ipv6-address>]
131 # password = <password>    # used by both, must match
132
133 [limits]
134 max_batch_down = 262144           # used by server
135 max_queue_time = 121              # used by server
136 http_timeout = 121                # used by server
137 target_requests_outstanding = 10  # used by server
138 '''
139
140 # these need to be defined here so that they can be imported by import *
141 cfg = ConfigParser(strict=False)
142 optparser = OptionParser()
143
144 _mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')
145 def mime_translate(s):
146   # SLIP-encoded packets cannot contain ESC ESC.
147   # Swap `-' and ESC.  The result cannot contain `--'
148   return s.translate(_mimetrans)
149
150 class ConfigResults:
151   def __init__(self, d = { }):
152     self.__dict__ = d
153   def __repr__(self):
154     return 'ConfigResults('+repr(self.__dict__)+')'
155
156 c = ConfigResults()
157
158 def log_discard(packet, iface, saddr, daddr, why):
159   log_debug(DBG.DROP,
160             'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why),
161             d=packet)
162
163 #---------- packet parsing ----------
164
165 def packet_addrs(packet):
166   version = packet[0] >> 4
167   if version == 4:
168     addrlen = 4
169     saddroff = 3*4
170     factory = ipaddress.IPv4Address
171   elif version == 6:
172     addrlen = 16
173     saddroff = 2*4
174     factory = ipaddress.IPv6Address
175   else:
176     raise ValueError('unsupported IP version %d' % version)
177   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
178   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
179   return (saddr, daddr)
180
181 #---------- address handling ----------
182
183 def ipaddr(input):
184   try:
185     r = ipaddress.IPv4Address(input)
186   except AddressValueError:
187     r = ipaddress.IPv6Address(input)
188   return r
189
190 def ipnetwork(input):
191   try:
192     r = ipaddress.IPv4Network(input)
193   except NetworkValueError:
194     r = ipaddress.IPv6Network(input)
195   return r
196
197 #---------- ipif (SLIP) subprocess ----------
198
199 class SlipStreamDecoder():
200   def __init__(self, desc, on_packet):
201     self._buffer = b''
202     self._on_packet = on_packet
203     self._desc = desc
204     self._log('__init__')
205
206   def _log(self, msg, **kwargs):
207     log_debug(DBG.SLIP_FULL, 'slip %s: %s' % (self._desc, msg), **kwargs)
208
209   def inputdata(self, data):
210     self._log('inputdata', d=data)
211     packets = slip.decode(data)
212     packets[0] = self._buffer + packets[0]
213     self._buffer = packets.pop()
214     for packet in packets:
215       self._maybe_packet(packet)
216     self._log('bufremain', d=self._buffer)
217
218   def _maybe_packet(self, packet):
219     self._log('maybepacket', d=packet)
220     if len(packet):
221       self._on_packet(packet)
222
223   def flush(self):
224     self._log('flush')
225     self._maybe_packet(self._buffer)
226     self._buffer = b''
227
228 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
229   def __init__(self, router):
230     self._router = router
231     self._decoder = SlipStreamDecoder('ipif', self.slip_on_packet)
232   def connectionMade(self): pass
233   def outReceived(self, data):
234     self._decoder.inputdata(data)
235   def slip_on_packet(self, packet):
236     (saddr, daddr) = packet_addrs(packet)
237     if saddr.is_link_local or daddr.is_link_local:
238       log_discard(packet, 'ipif', saddr, daddr, 'link-local')
239       return
240     self._router(packet, saddr, daddr)
241   def processEnded(self, status):
242     status.raiseException()
243
244 def start_ipif(command, router):
245   global ipif
246   ipif = _IpifProcessProtocol(router)
247   reactor.spawnProcess(ipif,
248                        '/bin/sh',['sh','-xc', command],
249                        childFDs={0:'w', 1:'r', 2:2},
250                        env=None)
251
252 def queue_inbound(packet):
253   log_debug(DBG.FLOW, "queue_inbound", d=packet)
254   ipif.transport.write(slip.delimiter)
255   ipif.transport.write(slip.encode(packet))
256   ipif.transport.write(slip.delimiter)
257
258 #---------- packet queue ----------
259
260 class PacketQueue():
261   def __init__(self, desc, max_queue_time):
262     self._desc = desc
263     assert(desc + '')
264     self._max_queue_time = max_queue_time
265     self._pq = collections.deque() # packets
266
267   def _log(self, dflag, msg, **kwargs):
268     log_debug(dflag, self._desc+' pq: '+msg, **kwargs)
269
270   def append(self, packet):
271     self._log(DBG.QUEUE, 'append', d=packet)
272     self._pq.append((time.monotonic(), packet))
273
274   def nonempty(self):
275     self._log(DBG.QUEUE, 'nonempty ?')
276     while True:
277       try: (queuetime, packet) = self._pq[0]
278       except IndexError:
279         self._log(DBG.QUEUE, 'nonempty ? empty.')
280         return False
281
282       age = time.monotonic() - queuetime
283       if age > self._max_queue_time:
284         # strip old packets off the front
285         self._log(DBG.QUEUE, 'dropping (old)', d=packet)
286         self._pq.popleft()
287         continue
288
289       self._log(DBG.QUEUE, 'nonempty ? nonempty.')
290       return True
291
292   def process(self, sizequery, moredata, max_batch):
293     # sizequery() should return size of batch so far
294     # moredata(s) should add s to batch
295     self._log(DBG.QUEUE, 'process...')
296     while True:
297       try: (dummy, packet) = self._pq[0]
298       except IndexError:
299         self._log(DBG.QUEUE, 'process... empty')
300         break
301
302       self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)
303
304       encoded = slip.encode(packet)
305       sofar = sizequery()  
306
307       self._log(DBG.QUEUE_CTRL,
308                 'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
309                 d=encoded)
310
311       if sofar > 0:
312         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
313           self._log(DBG.QUEUE_CTRL, 'process... overflow')
314           break
315         moredata(slip.delimiter)
316
317       moredata(encoded)
318       self._pq.popleft()
319
320 #---------- error handling ----------
321
322 _crashing = False
323
324 def crash(err):
325   global _crashing
326   _crashing = True
327   print('========== CRASH ==========', err,
328         '===========================', file=sys.stderr)
329   try: reactor.stop()
330   except twisted.internet.error.ReactorNotRunning: pass
331
332 def crash_on_defer(defer):
333   defer.addErrback(lambda err: crash(err))
334
335 def crash_on_critical(event):
336   if event.get('log_level') >= LogLevel.critical:
337     crash(twisted.logger.formatEvent(event))
338
339 #---------- config processing ----------
340
341 def process_cfg_common_always():
342   global mtu
343   c.mtu = cfg.get('virtual','mtu')
344
345 def process_cfg_ipif(section, varmap):
346   for d, s in varmap:
347     try: v = getattr(c, s)
348     except AttributeError: continue
349     setattr(c, d, v)
350
351   #print(repr(c))
352
353   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
354
355 def process_cfg_network():
356   c.network = ipnetwork(cfg.get('virtual','network'))
357   if c.network.num_addresses < 3 + 2:
358     raise ValueError('network needs at least 2^3 addresses')
359
360 def process_cfg_server():
361   try:
362     c.server = cfg.get('virtual','server')
363   except NoOptionError:
364     process_cfg_network()
365     c.server = next(c.network.hosts())
366
367 class ServerAddr():
368   def __init__(self, port, addrspec):
369     self.port = port
370     # also self.addr
371     try:
372       self.addr = ipaddress.IPv4Address(addrspec)
373       self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
374       self._inurl = b'%s'
375     except AddressValueError:
376       self.addr = ipaddress.IPv6Address(addrspec)
377       self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
378       self._inurl = b'[%s]'
379   def make_endpoint(self):
380     return self._endpointfactory(reactor, self.port, self.addr)
381   def url(self):
382     url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
383     if self.port != 80: url += b':%d' % self.port
384     url += b'/'
385     return url
386     
387 def process_cfg_saddrs():
388   try: port = cfg.getint('server','port')
389   except NoOptionError: port = 80
390
391   c.saddrs = [ ]
392   for addrspec in cfg.get('server','addrs').split():
393     sa = ServerAddr(port, addrspec)
394     c.saddrs.append(sa)
395
396 def process_cfg_clients(constructor):
397   c.clients = [ ]
398   for cs in cfg.sections():
399     if not (':' in cs or '.' in cs): continue
400     ci = ipaddr(cs)
401     pw = cfg.get(cs, 'password')
402     pw = pw.encode('utf-8')
403     constructor(ci,cs,pw)
404
405 #---------- startup ----------
406
407 def common_startup(process_cfg):
408   re = regexp.compile('#.*')
409   cfg.read_string(re.sub('', defcfg))
410   need_defcfg = True
411
412   def readconfig(pathname, mandatory=True):
413     def log(m, p=pathname):
414       if not DBG.CONFIG in debug_set: return
415       print('DBG.CONFIG: %s: %s' % (m, pathname))
416
417     try:
418       files = os.listdir(pathname)
419
420     except FileNotFoundError:
421       if mandatory: raise
422       log('skipped')
423       return
424
425     except NotADirectoryError:
426       cfg.read(pathname)
427       log('read file')
428       return
429
430     # is a directory
431     log('directory')
432     re = regexp.compile('[^-A-Za-z0-9_]')
433     for f in os.listdir(cdir):
434       if re.search(f): continue
435       subpath = pathname + '/' + f
436       try:
437         os.stat(subpath)
438       except FileNotFoundError:
439         log('entry skipped', subpath)
440         continue
441       cfg.read(subpath)
442       log('entry read', subpath)
443       
444   def oc_config(od,os, value, op):
445     nonlocal need_defcfg
446     need_defcfg = False
447     readconfig(value)
448
449   def dfs_less_detailed(dl):
450     return [df for df in DBG.iterconstants() if df <= dl]
451
452   def ds_default(od,os,dl,op):
453     global debug_set
454     debug_set = set(dfs_less_detailed(debug_def_detail))
455
456   def ds_select(od,os, spec, op):
457     for it in spec.split(','):
458
459       if it.startswith('-'):
460         mutator = debug_set.discard
461         it = it[1:]
462       else:
463         mutator = debug_set.add
464
465       if it == '+':
466         dfs = DBG.iterconstants()
467
468       else:
469         if it.endswith('+'):
470           mapper = dfs_less_detailed
471           it = it[0:len(it)-1]
472         else:
473           mapper = lambda x: [x]
474
475           try:
476             dfspec = DBG.lookupByName(it)
477           except ValueError:
478             optparser.error('unknown debug flag %s in --debug-select' % it)
479
480         dfs = mapper(dfspec)
481
482       for df in dfs:
483         mutator(df)
484
485   optparser.add_option('-D', '--debug',
486                        nargs=0,
487                        action='callback',
488                        help='enable default debug (to stdout)',
489                        callback= ds_default)
490
491   optparser.add_option('--debug-select',
492                        nargs=1,
493                        type='string',
494                        metavar='[-]DFLAG[+]|[-]+,...',
495                        help=
496 '''enable (`-': disable) each specified DFLAG;
497 `+': do same for all "more interesting" DFLAGSs;
498 just `+': all DFLAGs.
499   DFLAGS: ''' + ' '.join([df.name for df in DBG.iterconstants()]),
500                        action='callback',
501                        callback= ds_select)
502
503   optparser.add_option('-c', '--config',
504                        nargs=1,
505                        type='string',
506                        metavar='CONFIGFILE',
507                        dest='configfile',
508                        action='callback',
509                        callback= oc_config)
510
511   (opts, args) = optparser.parse_args()
512   if len(args): optparser.error('no non-option arguments please')
513
514   if need_defcfg:
515     readconfig('/etc/hippotat/config',   False)
516     readconfig('/etc/hippotat/config.d', False)
517
518   try: process_cfg()
519   except (configparser.Error, ValueError):
520     traceback.print_exc(file=sys.stderr)
521     print('\nInvalid configuration, giving up.', file=sys.stderr)
522     sys.exit(12)
523
524   #print(repr(debug_set), file=sys.stderr)
525
526   log_formatter = twisted.logger.formatEventAsClassicLogText
527   stdout_obs = twisted.logger.FileLogObserver(sys.stdout, log_formatter)
528   stderr_obs = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
529   pred = twisted.logger.LogLevelFilterPredicate(LogLevel.error)
530   stdsomething_obs = twisted.logger.FilteringLogObserver(
531     stderr_obs, [pred], stdout_obs
532   )
533   log_observer = twisted.logger.FilteringLogObserver(
534     stdsomething_obs, [LogNotBoringTwisted()]
535   )
536   #log_observer = stdsomething_obs
537   twisted.logger.globalLogBeginner.beginLoggingTo(
538     [ log_observer, crash_on_critical ]
539     )
540
541 def common_run():
542   log_debug(DBG.INIT, 'entering reactor')
543   if not _crashing: reactor.run()
544   print('CRASHED (end)', file=sys.stderr)