chiark / gitweb /
comment re comments re
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7 import os
8
9 from zope.interface import implementer
10
11 import twisted
12 from twisted.internet import reactor
13 import twisted.internet.endpoints
14 import twisted.logger
15 from twisted.logger import LogLevel
16 import twisted.python.constants
17 from twisted.python.constants import NamedConstant
18
19 import ipaddress
20 from ipaddress import AddressValueError
21
22 from optparse import OptionParser
23 import configparser
24 from configparser import ConfigParser
25 from configparser import NoOptionError
26
27 from functools import partial
28
29 import collections
30 import time
31 import codecs
32 import traceback
33
34 import re as regexp
35
36 import hippotat.slip as slip
37
38 class DBG(twisted.python.constants.Names):
39   INIT = NamedConstant()
40   CONFIG = NamedConstant()
41   ROUTE = NamedConstant()
42   DROP = NamedConstant()
43   FLOW = NamedConstant()
44   HTTP = NamedConstant()
45   TWISTED = NamedConstant()
46   QUEUE = NamedConstant()
47   HTTP_CTRL = NamedConstant()
48   QUEUE_CTRL = NamedConstant()
49   HTTP_FULL = NamedConstant()
50   CTRL_DUMP = NamedConstant()
51   SLIP_FULL = NamedConstant()
52   DATA_COMPLETE = NamedConstant()
53
54 _hex_codec = codecs.getencoder('hex_codec')
55
56 #---------- logging ----------
57
58 org_stderr = sys.stderr
59
60 log = twisted.logger.Logger()
61
62 debug_set = set()
63 debug_def_detail = DBG.HTTP
64
65 def log_debug(dflag, msg, idof=None, d=None):
66   if dflag not in debug_set: return
67   #print('---------------->',repr((dflag, msg, idof, d)), file=sys.stderr)
68   if idof is not None:
69     msg = '[%#x] %s' % (id(idof), msg)
70   if d is not None:
71     trunc = ''
72     if not DBG.DATA_COMPLETE in debug_set:
73       if len(d) > 64:
74         d = d[0:64]
75         trunc = '...'
76     d = _hex_codec(d)[0].decode('ascii')
77     msg += ' ' + d + trunc
78   log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg)
79
80 @implementer(twisted.logger.ILogFilterPredicate)
81 class LogNotBoringTwisted:
82   def __call__(self, event):
83     yes = twisted.logger.PredicateResult.yes
84     no  = twisted.logger.PredicateResult.no
85     try:
86       if event.get('log_level') != LogLevel.info:
87         return yes
88       dflag = event.get('dflag')
89       if dflag                         in debug_set: return yes
90       if dflag is None and DBG.TWISTED in debug_set: return yes
91       return no
92     except Exception:
93       print(traceback.format_exc(), file=org_stderr)
94       return yes
95
96 #---------- default config ----------
97
98 defcfg = '''
99 [DEFAULT]
100 #[<client>] overrides
101 max_batch_down = 65536           # used by server, subject to [limits]
102 max_queue_time = 10              # used by server, subject to [limits]
103 target_requests_outstanding = 3  # must match; subject to [limits] on server
104 http_timeout = 30                # used by both } must be
105 http_timeout_grace = 5           # used by both }  compatible
106 max_requests_outstanding = 4     # used by client
107 max_batch_up = 4000              # used by client
108 http_retry = 5                   # used by client
109
110 #[server] or [<client>] overrides
111 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
112 # extra interpolations:  %(local)s        %(peer)s          %(rnet)s
113 #  obtained   on server  [virtual]server  [virtual]relay    [virtual]network
114 #      from   on client  <client>         [virtual]server   [virtual]routes
115
116 [virtual]
117 mtu = 1500
118 routes = ''
119 # network = <prefix>/<len>  # mandatory for server
120 # server  = <ipaddr>   # used by both, default is computed from `network'
121 # relay   = <ipaddr>   # used by server, default from `network' and `server'
122 #  default server is first host in network
123 #  default relay is first host which is not server
124
125 [server]
126 # addrs = 127.0.0.1 ::1    # mandatory for server
127 port = 80                  # used by server
128 # url              # used by client; default from first `addrs' and `port'
129
130 # [<client-ip4-or-ipv6-address>]
131 # password = <password>    # used by both, must match
132
133 [limits]
134 max_batch_down = 262144           # used by server
135 max_queue_time = 121              # used by server
136 http_timeout = 121                # used by server
137 target_requests_outstanding = 10  # used by server
138 '''
139
140 # these need to be defined here so that they can be imported by import *
141 cfg = ConfigParser(strict=False)
142 optparser = OptionParser()
143
144 _mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')
145 def mime_translate(s):
146   # SLIP-encoded packets cannot contain ESC ESC.
147   # Swap `-' and ESC.  The result cannot contain `--'
148   return s.translate(_mimetrans)
149
150 class ConfigResults:
151   def __init__(self, d = { }):
152     self.__dict__ = d
153   def __repr__(self):
154     return 'ConfigResults('+repr(self.__dict__)+')'
155
156 c = ConfigResults()
157
158 def log_discard(packet, iface, saddr, daddr, why):
159   log_debug(DBG.DROP,
160             'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why),
161             d=packet)
162
163 #---------- packet parsing ----------
164
165 def packet_addrs(packet):
166   version = packet[0] >> 4
167   if version == 4:
168     addrlen = 4
169     saddroff = 3*4
170     factory = ipaddress.IPv4Address
171   elif version == 6:
172     addrlen = 16
173     saddroff = 2*4
174     factory = ipaddress.IPv6Address
175   else:
176     raise ValueError('unsupported IP version %d' % version)
177   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
178   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
179   return (saddr, daddr)
180
181 #---------- address handling ----------
182
183 def ipaddr(input):
184   try:
185     r = ipaddress.IPv4Address(input)
186   except AddressValueError:
187     r = ipaddress.IPv6Address(input)
188   return r
189
190 def ipnetwork(input):
191   try:
192     r = ipaddress.IPv4Network(input)
193   except NetworkValueError:
194     r = ipaddress.IPv6Network(input)
195   return r
196
197 #---------- ipif (SLIP) subprocess ----------
198
199 class SlipStreamDecoder():
200   def __init__(self, desc, on_packet):
201     self._buffer = b''
202     self._on_packet = on_packet
203     self._desc = desc
204     self._log('__init__')
205
206   def _log(self, msg, **kwargs):
207     log_debug(DBG.SLIP_FULL, 'slip %s: %s' % (self._desc, msg), **kwargs)
208
209   def inputdata(self, data):
210     self._log('inputdata', d=data)
211     packets = slip.decode(data)
212     packets[0] = self._buffer + packets[0]
213     self._buffer = packets.pop()
214     for packet in packets:
215       self._maybe_packet(packet)
216     self._log('bufremain', d=self._buffer)
217
218   def _maybe_packet(self, packet):
219     self._log('maybepacket', d=packet)
220     if len(packet):
221       self._on_packet(packet)
222
223   def flush(self):
224     self._log('flush')
225     self._maybe_packet(self._buffer)
226     self._buffer = b''
227
228 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
229   def __init__(self, router):
230     self._router = router
231     self._decoder = SlipStreamDecoder('ipif', self.slip_on_packet)
232   def connectionMade(self): pass
233   def outReceived(self, data):
234     self._decoder.inputdata(data)
235   def slip_on_packet(self, packet):
236     (saddr, daddr) = packet_addrs(packet)
237     if saddr.is_link_local or daddr.is_link_local:
238       log_discard(packet, 'ipif', saddr, daddr, 'link-local')
239       return
240     self._router(packet, saddr, daddr)
241   def processEnded(self, status):
242     status.raiseException()
243
244 def start_ipif(command, router):
245   global ipif
246   ipif = _IpifProcessProtocol(router)
247   reactor.spawnProcess(ipif,
248                        '/bin/sh',['sh','-xc', command],
249                        childFDs={0:'w', 1:'r', 2:2},
250                        env=None)
251
252 def queue_inbound(packet):
253   log_debug(DBG.FLOW, "queue_inbound", d=packet)
254   ipif.transport.write(slip.delimiter)
255   ipif.transport.write(slip.encode(packet))
256   ipif.transport.write(slip.delimiter)
257
258 #---------- packet queue ----------
259
260 class PacketQueue():
261   def __init__(self, desc, max_queue_time):
262     self._desc = desc
263     assert(desc + '')
264     self._max_queue_time = max_queue_time
265     self._pq = collections.deque() # packets
266
267   def _log(self, dflag, msg, **kwargs):
268     log_debug(dflag, self._desc+' pq: '+msg, **kwargs)
269
270   def append(self, packet):
271     self._log(DBG.QUEUE, 'append', d=packet)
272     self._pq.append((time.monotonic(), packet))
273
274   def nonempty(self):
275     self._log(DBG.QUEUE, 'nonempty ?')
276     while True:
277       try: (queuetime, packet) = self._pq[0]
278       except IndexError:
279         self._log(DBG.QUEUE, 'nonempty ? empty.')
280         return False
281
282       age = time.monotonic() - queuetime
283       if age > self._max_queue_time:
284         # strip old packets off the front
285         self._log(DBG.QUEUE, 'dropping (old)', d=packet)
286         self._pq.popleft()
287         continue
288
289       self._log(DBG.QUEUE, 'nonempty ? nonempty.')
290       return True
291
292   def process(self, sizequery, moredata, max_batch):
293     # sizequery() should return size of batch so far
294     # moredata(s) should add s to batch
295     self._log(DBG.QUEUE, 'process...')
296     while True:
297       try: (dummy, packet) = self._pq[0]
298       except IndexError:
299         self._log(DBG.QUEUE, 'process... empty')
300         break
301
302       self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)
303
304       encoded = slip.encode(packet)
305       sofar = sizequery()  
306
307       self._log(DBG.QUEUE_CTRL,
308                 'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
309                 d=encoded)
310
311       if sofar > 0:
312         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
313           self._log(DBG.QUEUE_CTRL, 'process... overflow')
314           break
315         moredata(slip.delimiter)
316
317       moredata(encoded)
318       self._pq.popleft()
319
320 #---------- error handling ----------
321
322 _crashing = False
323
324 def crash(err):
325   global _crashing
326   _crashing = True
327   print('========== CRASH ==========', err,
328         '===========================', file=sys.stderr)
329   try: reactor.stop()
330   except twisted.internet.error.ReactorNotRunning: pass
331
332 def crash_on_defer(defer):
333   defer.addErrback(lambda err: crash(err))
334
335 def crash_on_critical(event):
336   if event.get('log_level') >= LogLevel.critical:
337     crash(twisted.logger.formatEvent(event))
338
339 #---------- config processing ----------
340
341 def process_cfg_common_always():
342   global mtu
343   c.mtu = cfg.get('virtual','mtu')
344
345 def process_cfg_ipif(section, varmap):
346   for d, s in varmap:
347     try: v = getattr(c, s)
348     except AttributeError: continue
349     setattr(c, d, v)
350
351   #print(repr(c))
352
353   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
354
355 def process_cfg_network():
356   c.network = ipnetwork(cfg.get('virtual','network'))
357   if c.network.num_addresses < 3 + 2:
358     raise ValueError('network needs at least 2^3 addresses')
359
360 def process_cfg_server():
361   try:
362     c.server = cfg.get('virtual','server')
363   except NoOptionError:
364     process_cfg_network()
365     c.server = next(c.network.hosts())
366
367 class ServerAddr():
368   def __init__(self, port, addrspec):
369     self.port = port
370     # also self.addr
371     try:
372       self.addr = ipaddress.IPv4Address(addrspec)
373       self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
374       self._inurl = b'%s'
375     except AddressValueError:
376       self.addr = ipaddress.IPv6Address(addrspec)
377       self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
378       self._inurl = b'[%s]'
379   def make_endpoint(self):
380     return self._endpointfactory(reactor, self.port, self.addr)
381   def url(self):
382     url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
383     if self.port != 80: url += b':%d' % self.port
384     url += b'/'
385     return url
386     
387 def process_cfg_saddrs():
388   try: port = cfg.getint('server','port')
389   except NoOptionError: port = 80
390
391   c.saddrs = [ ]
392   for addrspec in cfg.get('server','addrs').split():
393     sa = ServerAddr(port, addrspec)
394     c.saddrs.append(sa)
395
396 def process_cfg_clients(constructor):
397   c.clients = [ ]
398   for cs in cfg.sections():
399     if not (':' in cs or '.' in cs): continue
400     ci = ipaddr(cs)
401     pw = cfg.get(cs, 'password')
402     pw = pw.encode('utf-8')
403     constructor(ci,cs,pw)
404
405 #---------- startup ----------
406
407 def common_startup(process_cfg):
408   # ConfigParser hates #-comments after values
409   trailingcomments_re = regexp.compile('#.*')
410   cfg.read_string(trailingcomments_re.sub('', defcfg))
411   need_defcfg = True
412
413   def readconfig(pathname, mandatory=True):
414     def log(m, p=pathname):
415       if not DBG.CONFIG in debug_set: return
416       print('DBG.CONFIG: %s: %s' % (m, pathname))
417
418     try:
419       files = os.listdir(pathname)
420
421     except FileNotFoundError:
422       if mandatory: raise
423       log('skipped')
424       return
425
426     except NotADirectoryError:
427       cfg.read(pathname)
428       log('read file')
429       return
430
431     # is a directory
432     log('directory')
433     re = regexp.compile('[^-A-Za-z0-9_]')
434     for f in os.listdir(cdir):
435       if re.search(f): continue
436       subpath = pathname + '/' + f
437       try:
438         os.stat(subpath)
439       except FileNotFoundError:
440         log('entry skipped', subpath)
441         continue
442       cfg.read(subpath)
443       log('entry read', subpath)
444       
445   def oc_config(od,os, value, op):
446     nonlocal need_defcfg
447     need_defcfg = False
448     readconfig(value)
449
450   def dfs_less_detailed(dl):
451     return [df for df in DBG.iterconstants() if df <= dl]
452
453   def ds_default(od,os,dl,op):
454     global debug_set
455     debug_set = set(dfs_less_detailed(debug_def_detail))
456
457   def ds_select(od,os, spec, op):
458     for it in spec.split(','):
459
460       if it.startswith('-'):
461         mutator = debug_set.discard
462         it = it[1:]
463       else:
464         mutator = debug_set.add
465
466       if it == '+':
467         dfs = DBG.iterconstants()
468
469       else:
470         if it.endswith('+'):
471           mapper = dfs_less_detailed
472           it = it[0:len(it)-1]
473         else:
474           mapper = lambda x: [x]
475
476           try:
477             dfspec = DBG.lookupByName(it)
478           except ValueError:
479             optparser.error('unknown debug flag %s in --debug-select' % it)
480
481         dfs = mapper(dfspec)
482
483       for df in dfs:
484         mutator(df)
485
486   optparser.add_option('-D', '--debug',
487                        nargs=0,
488                        action='callback',
489                        help='enable default debug (to stdout)',
490                        callback= ds_default)
491
492   optparser.add_option('--debug-select',
493                        nargs=1,
494                        type='string',
495                        metavar='[-]DFLAG[+]|[-]+,...',
496                        help=
497 '''enable (`-': disable) each specified DFLAG;
498 `+': do same for all "more interesting" DFLAGSs;
499 just `+': all DFLAGs.
500   DFLAGS: ''' + ' '.join([df.name for df in DBG.iterconstants()]),
501                        action='callback',
502                        callback= ds_select)
503
504   optparser.add_option('-c', '--config',
505                        nargs=1,
506                        type='string',
507                        metavar='CONFIGFILE',
508                        dest='configfile',
509                        action='callback',
510                        callback= oc_config)
511
512   (opts, args) = optparser.parse_args()
513   if len(args): optparser.error('no non-option arguments please')
514
515   if need_defcfg:
516     readconfig('/etc/hippotat/config',   False)
517     readconfig('/etc/hippotat/config.d', False)
518
519   try: process_cfg()
520   except (configparser.Error, ValueError):
521     traceback.print_exc(file=sys.stderr)
522     print('\nInvalid configuration, giving up.', file=sys.stderr)
523     sys.exit(12)
524
525   #print(repr(debug_set), file=sys.stderr)
526
527   log_formatter = twisted.logger.formatEventAsClassicLogText
528   stdout_obs = twisted.logger.FileLogObserver(sys.stdout, log_formatter)
529   stderr_obs = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
530   pred = twisted.logger.LogLevelFilterPredicate(LogLevel.error)
531   stdsomething_obs = twisted.logger.FilteringLogObserver(
532     stderr_obs, [pred], stdout_obs
533   )
534   log_observer = twisted.logger.FilteringLogObserver(
535     stdsomething_obs, [LogNotBoringTwisted()]
536   )
537   #log_observer = stdsomething_obs
538   twisted.logger.globalLogBeginner.beginLoggingTo(
539     [ log_observer, crash_on_critical ]
540     )
541
542 def common_run():
543   log_debug(DBG.INIT, 'entering reactor')
544   if not _crashing: reactor.run()
545   print('CRASHED (end)', file=sys.stderr)