chiark / gitweb /
new config definition
[hippotat] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7 import os
8
9 from zope.interface import implementer
10
11 import twisted
12 from twisted.internet import reactor
13 import twisted.internet.endpoints
14 import twisted.logger
15 from twisted.logger import LogLevel
16 import twisted.python.constants
17 from twisted.python.constants import NamedConstant
18
19 import ipaddress
20 from ipaddress import AddressValueError
21
22 from optparse import OptionParser
23 import configparser
24 from configparser import ConfigParser
25 from configparser import NoOptionError
26
27 from functools import partial
28
29 import collections
30 import time
31 import codecs
32 import traceback
33
34 import re as regexp
35
36 import hippotat.slip as slip
37
38 class DBG(twisted.python.constants.Names):
39   INIT = NamedConstant()
40   CONFIG = NamedConstant()
41   ROUTE = NamedConstant()
42   DROP = NamedConstant()
43   FLOW = NamedConstant()
44   HTTP = NamedConstant()
45   TWISTED = NamedConstant()
46   QUEUE = NamedConstant()
47   HTTP_CTRL = NamedConstant()
48   QUEUE_CTRL = NamedConstant()
49   HTTP_FULL = NamedConstant()
50   CTRL_DUMP = NamedConstant()
51   SLIP_FULL = NamedConstant()
52   DATA_COMPLETE = NamedConstant()
53
54 _hex_codec = codecs.getencoder('hex_codec')
55
56 #---------- logging ----------
57
58 org_stderr = sys.stderr
59
60 log = twisted.logger.Logger()
61
62 debug_set = set()
63 debug_def_detail = DBG.HTTP
64
65 def log_debug(dflag, msg, idof=None, d=None):
66   if dflag not in debug_set: return
67   #print('---------------->',repr((dflag, msg, idof, d)), file=sys.stderr)
68   if idof is not None:
69     msg = '[%#x] %s' % (id(idof), msg)
70   if d is not None:
71     trunc = ''
72     if not DBG.DATA_COMPLETE in debug_set:
73       if len(d) > 64:
74         d = d[0:64]
75         trunc = '...'
76     d = _hex_codec(d)[0].decode('ascii')
77     msg += ' ' + d + trunc
78   log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg)
79
80 @implementer(twisted.logger.ILogFilterPredicate)
81 class LogNotBoringTwisted:
82   def __call__(self, event):
83     yes = twisted.logger.PredicateResult.yes
84     no  = twisted.logger.PredicateResult.no
85     try:
86       if event.get('log_level') != LogLevel.info:
87         return yes
88       dflag = event.get('dflag')
89       if dflag                         in debug_set: return yes
90       if dflag is None and DBG.TWISTED in debug_set: return yes
91       return no
92     except Exception:
93       print(traceback.format_exc(), file=org_stderr)
94       return yes
95
96 #---------- default config ----------
97
98 defcfg = '''
99 [DEFAULT]
100 max_batch_down = 65536
101 max_queue_time = 10
102 target_requests_outstanding = 3
103 http_timeout = 30
104 http_timeout_grace = 5
105 max_requests_outstanding = 6
106 max_batch_up = 4000
107 http_retry = 5
108
109 #[server] or [<client>] overrides
110 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
111
112 # relating to virtual network
113 mtu = 1500
114
115 [server]
116 # addrs = 127.0.0.1 ::1
117 port = 80
118 # url
119
120 # relating to virtual network
121 routes = ''
122 vnetwork = 172.24.230.192
123 # network = <prefix>/<len>
124 # server  = <ipaddr>
125 # relay   = <ipaddr>
126
127
128 # [<client-ip4-or-ipv6-address>]
129 # password = <password>    # used by both, must match
130
131 [limits]
132 max_batch_down = 262144
133 max_queue_time = 121
134 http_timeout = 121
135 target_requests_outstanding = 10
136 '''
137
138 # these need to be defined here so that they can be imported by import *
139 cfg = ConfigParser(strict=False)
140 optparser = OptionParser()
141
142 _mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')
143 def mime_translate(s):
144   # SLIP-encoded packets cannot contain ESC ESC.
145   # Swap `-' and ESC.  The result cannot contain `--'
146   return s.translate(_mimetrans)
147
148 class ConfigResults:
149   def __init__(self, d = { }):
150     self.__dict__ = d
151   def __repr__(self):
152     return 'ConfigResults('+repr(self.__dict__)+')'
153
154 c = ConfigResults()
155
156 def log_discard(packet, iface, saddr, daddr, why):
157   log_debug(DBG.DROP,
158             'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why),
159             d=packet)
160
161 #---------- packet parsing ----------
162
163 def packet_addrs(packet):
164   version = packet[0] >> 4
165   if version == 4:
166     addrlen = 4
167     saddroff = 3*4
168     factory = ipaddress.IPv4Address
169   elif version == 6:
170     addrlen = 16
171     saddroff = 2*4
172     factory = ipaddress.IPv6Address
173   else:
174     raise ValueError('unsupported IP version %d' % version)
175   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
176   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
177   return (saddr, daddr)
178
179 #---------- address handling ----------
180
181 def ipaddr(input):
182   try:
183     r = ipaddress.IPv4Address(input)
184   except AddressValueError:
185     r = ipaddress.IPv6Address(input)
186   return r
187
188 def ipnetwork(input):
189   try:
190     r = ipaddress.IPv4Network(input)
191   except NetworkValueError:
192     r = ipaddress.IPv6Network(input)
193   return r
194
195 #---------- ipif (SLIP) subprocess ----------
196
197 class SlipStreamDecoder():
198   def __init__(self, desc, on_packet):
199     self._buffer = b''
200     self._on_packet = on_packet
201     self._desc = desc
202     self._log('__init__')
203
204   def _log(self, msg, **kwargs):
205     log_debug(DBG.SLIP_FULL, 'slip %s: %s' % (self._desc, msg), **kwargs)
206
207   def inputdata(self, data):
208     self._log('inputdata', d=data)
209     packets = slip.decode(data)
210     packets[0] = self._buffer + packets[0]
211     self._buffer = packets.pop()
212     for packet in packets:
213       self._maybe_packet(packet)
214     self._log('bufremain', d=self._buffer)
215
216   def _maybe_packet(self, packet):
217     self._log('maybepacket', d=packet)
218     if len(packet):
219       self._on_packet(packet)
220
221   def flush(self):
222     self._log('flush')
223     self._maybe_packet(self._buffer)
224     self._buffer = b''
225
226 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
227   def __init__(self, router):
228     self._router = router
229     self._decoder = SlipStreamDecoder('ipif', self.slip_on_packet)
230   def connectionMade(self): pass
231   def outReceived(self, data):
232     self._decoder.inputdata(data)
233   def slip_on_packet(self, packet):
234     (saddr, daddr) = packet_addrs(packet)
235     if saddr.is_link_local or daddr.is_link_local:
236       log_discard(packet, 'ipif', saddr, daddr, 'link-local')
237       return
238     self._router(packet, saddr, daddr)
239   def processEnded(self, status):
240     status.raiseException()
241
242 def start_ipif(command, router):
243   global ipif
244   ipif = _IpifProcessProtocol(router)
245   reactor.spawnProcess(ipif,
246                        '/bin/sh',['sh','-xc', command],
247                        childFDs={0:'w', 1:'r', 2:2},
248                        env=None)
249
250 def queue_inbound(packet):
251   log_debug(DBG.FLOW, "queue_inbound", d=packet)
252   ipif.transport.write(slip.delimiter)
253   ipif.transport.write(slip.encode(packet))
254   ipif.transport.write(slip.delimiter)
255
256 #---------- packet queue ----------
257
258 class PacketQueue():
259   def __init__(self, desc, max_queue_time):
260     self._desc = desc
261     assert(desc + '')
262     self._max_queue_time = max_queue_time
263     self._pq = collections.deque() # packets
264
265   def _log(self, dflag, msg, **kwargs):
266     log_debug(dflag, self._desc+' pq: '+msg, **kwargs)
267
268   def append(self, packet):
269     self._log(DBG.QUEUE, 'append', d=packet)
270     self._pq.append((time.monotonic(), packet))
271
272   def nonempty(self):
273     self._log(DBG.QUEUE, 'nonempty ?')
274     while True:
275       try: (queuetime, packet) = self._pq[0]
276       except IndexError:
277         self._log(DBG.QUEUE, 'nonempty ? empty.')
278         return False
279
280       age = time.monotonic() - queuetime
281       if age > self._max_queue_time:
282         # strip old packets off the front
283         self._log(DBG.QUEUE, 'dropping (old)', d=packet)
284         self._pq.popleft()
285         continue
286
287       self._log(DBG.QUEUE, 'nonempty ? nonempty.')
288       return True
289
290   def process(self, sizequery, moredata, max_batch):
291     # sizequery() should return size of batch so far
292     # moredata(s) should add s to batch
293     self._log(DBG.QUEUE, 'process...')
294     while True:
295       try: (dummy, packet) = self._pq[0]
296       except IndexError:
297         self._log(DBG.QUEUE, 'process... empty')
298         break
299
300       self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)
301
302       encoded = slip.encode(packet)
303       sofar = sizequery()  
304
305       self._log(DBG.QUEUE_CTRL,
306                 'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
307                 d=encoded)
308
309       if sofar > 0:
310         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
311           self._log(DBG.QUEUE_CTRL, 'process... overflow')
312           break
313         moredata(slip.delimiter)
314
315       moredata(encoded)
316       self._pq.popleft()
317
318 #---------- error handling ----------
319
320 _crashing = False
321
322 def crash(err):
323   global _crashing
324   _crashing = True
325   print('========== CRASH ==========', err,
326         '===========================', file=sys.stderr)
327   try: reactor.stop()
328   except twisted.internet.error.ReactorNotRunning: pass
329
330 def crash_on_defer(defer):
331   defer.addErrback(lambda err: crash(err))
332
333 def crash_on_critical(event):
334   if event.get('log_level') >= LogLevel.critical:
335     crash(twisted.logger.formatEvent(event))
336
337 #---------- config processing ----------
338
339 def process_cfg_common_always():
340   global mtu
341   c.mtu = cfg.get('virtual','mtu')
342
343 def process_cfg_ipif(section, varmap):
344   for d, s in varmap:
345     try: v = getattr(c, s)
346     except AttributeError: continue
347     setattr(c, d, v)
348
349   #print(repr(c))
350
351   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
352
353 def process_cfg_network():
354   c.network = ipnetwork(cfg.get('virtual','network'))
355   if c.network.num_addresses < 3 + 2:
356     raise ValueError('network needs at least 2^3 addresses')
357
358 def process_cfg_server():
359   try:
360     c.server = cfg.get('virtual','server')
361   except NoOptionError:
362     process_cfg_network()
363     c.server = next(c.network.hosts())
364
365 class ServerAddr():
366   def __init__(self, port, addrspec):
367     self.port = port
368     # also self.addr
369     try:
370       self.addr = ipaddress.IPv4Address(addrspec)
371       self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
372       self._inurl = b'%s'
373     except AddressValueError:
374       self.addr = ipaddress.IPv6Address(addrspec)
375       self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
376       self._inurl = b'[%s]'
377   def make_endpoint(self):
378     return self._endpointfactory(reactor, self.port, self.addr)
379   def url(self):
380     url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
381     if self.port != 80: url += b':%d' % self.port
382     url += b'/'
383     return url
384     
385 def process_cfg_saddrs():
386   try: port = cfg.getint('server','port')
387   except NoOptionError: port = 80
388
389   c.saddrs = [ ]
390   for addrspec in cfg.get('server','addrs').split():
391     sa = ServerAddr(port, addrspec)
392     c.saddrs.append(sa)
393
394 def process_cfg_clients(constructor):
395   c.clients = [ ]
396   for cs in cfg.sections():
397     if not (':' in cs or '.' in cs): continue
398     ci = ipaddr(cs)
399     pw = cfg.get(cs, 'password')
400     pw = pw.encode('utf-8')
401     constructor(ci,cs,pw)
402
403 #---------- startup ----------
404
405 def common_startup(process_cfg):
406   # ConfigParser hates #-comments after values
407   trailingcomments_re = regexp.compile('#.*')
408   cfg.read_string(trailingcomments_re.sub('', defcfg))
409   need_defcfg = True
410
411   def readconfig(pathname, mandatory=True):
412     def log(m, p=pathname):
413       if not DBG.CONFIG in debug_set: return
414       print('DBG.CONFIG: %s: %s' % (m, pathname))
415
416     try:
417       files = os.listdir(pathname)
418
419     except FileNotFoundError:
420       if mandatory: raise
421       log('skipped')
422       return
423
424     except NotADirectoryError:
425       cfg.read(pathname)
426       log('read file')
427       return
428
429     # is a directory
430     log('directory')
431     re = regexp.compile('[^-A-Za-z0-9_]')
432     for f in os.listdir(cdir):
433       if re.search(f): continue
434       subpath = pathname + '/' + f
435       try:
436         os.stat(subpath)
437       except FileNotFoundError:
438         log('entry skipped', subpath)
439         continue
440       cfg.read(subpath)
441       log('entry read', subpath)
442       
443   def oc_config(od,os, value, op):
444     nonlocal need_defcfg
445     need_defcfg = False
446     readconfig(value)
447
448   def dfs_less_detailed(dl):
449     return [df for df in DBG.iterconstants() if df <= dl]
450
451   def ds_default(od,os,dl,op):
452     global debug_set
453     debug_set = set(dfs_less_detailed(debug_def_detail))
454
455   def ds_select(od,os, spec, op):
456     for it in spec.split(','):
457
458       if it.startswith('-'):
459         mutator = debug_set.discard
460         it = it[1:]
461       else:
462         mutator = debug_set.add
463
464       if it == '+':
465         dfs = DBG.iterconstants()
466
467       else:
468         if it.endswith('+'):
469           mapper = dfs_less_detailed
470           it = it[0:len(it)-1]
471         else:
472           mapper = lambda x: [x]
473
474           try:
475             dfspec = DBG.lookupByName(it)
476           except ValueError:
477             optparser.error('unknown debug flag %s in --debug-select' % it)
478
479         dfs = mapper(dfspec)
480
481       for df in dfs:
482         mutator(df)
483
484   optparser.add_option('-D', '--debug',
485                        nargs=0,
486                        action='callback',
487                        help='enable default debug (to stdout)',
488                        callback= ds_default)
489
490   optparser.add_option('--debug-select',
491                        nargs=1,
492                        type='string',
493                        metavar='[-]DFLAG[+]|[-]+,...',
494                        help=
495 '''enable (`-': disable) each specified DFLAG;
496 `+': do same for all "more interesting" DFLAGSs;
497 just `+': all DFLAGs.
498   DFLAGS: ''' + ' '.join([df.name for df in DBG.iterconstants()]),
499                        action='callback',
500                        callback= ds_select)
501
502   optparser.add_option('-c', '--config',
503                        nargs=1,
504                        type='string',
505                        metavar='CONFIGFILE',
506                        dest='configfile',
507                        action='callback',
508                        callback= oc_config)
509
510   (opts, args) = optparser.parse_args()
511   if len(args): optparser.error('no non-option arguments please')
512
513   if need_defcfg:
514     readconfig('/etc/hippotat/config',   False)
515     readconfig('/etc/hippotat/config.d', False)
516
517   try: process_cfg()
518   except (configparser.Error, ValueError):
519     traceback.print_exc(file=sys.stderr)
520     print('\nInvalid configuration, giving up.', file=sys.stderr)
521     sys.exit(12)
522
523   #print(repr(debug_set), file=sys.stderr)
524
525   log_formatter = twisted.logger.formatEventAsClassicLogText
526   stdout_obs = twisted.logger.FileLogObserver(sys.stdout, log_formatter)
527   stderr_obs = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
528   pred = twisted.logger.LogLevelFilterPredicate(LogLevel.error)
529   stdsomething_obs = twisted.logger.FilteringLogObserver(
530     stderr_obs, [pred], stdout_obs
531   )
532   log_observer = twisted.logger.FilteringLogObserver(
533     stdsomething_obs, [LogNotBoringTwisted()]
534   )
535   #log_observer = stdsomething_obs
536   twisted.logger.globalLogBeginner.beginLoggingTo(
537     [ log_observer, crash_on_critical ]
538     )
539
540 def common_run():
541   log_debug(DBG.INIT, 'entering reactor')
542   if not _crashing: reactor.run()
543   print('CRASHED (end)', file=sys.stderr)