chiark / gitweb /
bb1f606f137327b016e94f72fb2ed86b186c98f3
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 from zope.interface import implementer
9
10 import twisted
11 from twisted.internet import reactor
12 import twisted.internet.endpoints
13 import twisted.logger
14 from twisted.logger import LogLevel
15 import twisted.python.constants
16 from twisted.python.constants import NamedConstant
17
18 import ipaddress
19 from ipaddress import AddressValueError
20
21 from optparse import OptionParser
22 from configparser import ConfigParser
23 from configparser import NoOptionError
24
25 from functools import partial
26
27 import collections
28 import time
29 import codecs
30 import traceback
31
32 import re as regexp
33
34 import hippotat.slip as slip
35
36 class DBG(twisted.python.constants.Names):
37   INIT = NamedConstant()
38   ROUTE = NamedConstant()
39   DROP = NamedConstant()
40   FLOW = NamedConstant()
41   HTTP = NamedConstant()
42   TWISTED = NamedConstant()
43   QUEUE = NamedConstant()
44   HTTP_CTRL = NamedConstant()
45   QUEUE_CTRL = NamedConstant()
46   HTTP_FULL = NamedConstant()
47   CTRL_DUMP = NamedConstant()
48   SLIP_FULL = NamedConstant()
49   DATA_COMPLETE = NamedConstant()
50
51 _hex_codec = codecs.getencoder('hex_codec')
52
53 #---------- logging ----------
54
55 org_stderr = sys.stderr
56
57 log = twisted.logger.Logger()
58
59 debug_set = set()
60 debug_def_detail = DBG.HTTP
61
62 def log_debug(dflag, msg, idof=None, d=None):
63   if dflag not in debug_set: return
64   #print('---------------->',repr((dflag, msg, idof, d)), file=sys.stderr)
65   if idof is not None:
66     msg = '[%#x] %s' % (id(idof), msg)
67   if d is not None:
68     trunc = ''
69     if not DBG.DATA_COMPLETE in debug_set:
70       if len(d) > 64:
71         d = d[0:64]
72         trunc = '...'
73     d = _hex_codec(d)[0].decode('ascii')
74     msg += ' ' + d + trunc
75   log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg)
76
77 @implementer(twisted.logger.ILogFilterPredicate)
78 class LogNotBoringTwisted:
79   def __call__(self, event):
80     yes = twisted.logger.PredicateResult.yes
81     no  = twisted.logger.PredicateResult.no
82     try:
83       if event.get('log_level') != LogLevel.info:
84         return yes
85       dflag = event.get('dflag')
86       if dflag                         in debug_set: return yes
87       if dflag is None and DBG.TWISTED in debug_set: return yes
88       return no
89     except Exception:
90       print(traceback.format_exc(), file=org_stderr)
91       return yes
92
93 #---------- default config ----------
94
95 defcfg = '''
96 [DEFAULT]
97 #[<client>] overrides
98 max_batch_down = 65536           # used by server, subject to [limits]
99 max_queue_time = 10              # used by server, subject to [limits]
100 target_requests_outstanding = 3  # must match; subject to [limits] on server
101 http_timeout = 30                # used by both } must be
102 http_timeout_grace = 5           # used by both }  compatible
103 max_requests_outstanding = 4     # used by client
104 max_batch_up = 4000              # used by client
105 http_retry = 5                   # used by client
106
107 #[server] or [<client>] overrides
108 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
109 # extra interpolations:  %(local)s        %(peer)s          %(rnet)s
110 #  obtained   on server  [virtual]server  [virtual]relay    [virtual]network
111 #      from   on client  <client>         [virtual]server   [virtual]routes
112
113 [virtual]
114 mtu = 1500
115 routes = ''
116 # network = <prefix>/<len>  # mandatory for server
117 # server  = <ipaddr>   # used by both, default is computed from `network'
118 # relay   = <ipaddr>   # used by server, default from `network' and `server'
119 #  default server is first host in network
120 #  default relay is first host which is not server
121
122 [server]
123 # addrs = 127.0.0.1 ::1    # mandatory for server
124 port = 80                  # used by server
125 # url              # used by client; default from first `addrs' and `port'
126
127 # [<client-ip4-or-ipv6-address>]
128 # password = <password>    # used by both, must match
129
130 [limits]
131 max_batch_down = 262144           # used by server
132 max_queue_time = 121              # used by server
133 http_timeout = 121                # used by server
134 target_requests_outstanding = 10  # used by server
135 '''
136
137 # these need to be defined here so that they can be imported by import *
138 cfg = ConfigParser()
139 optparser = OptionParser()
140
141 _mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')
142 def mime_translate(s):
143   # SLIP-encoded packets cannot contain ESC ESC.
144   # Swap `-' and ESC.  The result cannot contain `--'
145   return s.translate(_mimetrans)
146
147 class ConfigResults:
148   def __init__(self, d = { }):
149     self.__dict__ = d
150   def __repr__(self):
151     return 'ConfigResults('+repr(self.__dict__)+')'
152
153 c = ConfigResults()
154
155 def log_discard(packet, iface, saddr, daddr, why):
156   log_debug(DBG.DROP,
157             'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why),
158             d=packet)
159
160 #---------- packet parsing ----------
161
162 def packet_addrs(packet):
163   version = packet[0] >> 4
164   if version == 4:
165     addrlen = 4
166     saddroff = 3*4
167     factory = ipaddress.IPv4Address
168   elif version == 6:
169     addrlen = 16
170     saddroff = 2*4
171     factory = ipaddress.IPv6Address
172   else:
173     raise ValueError('unsupported IP version %d' % version)
174   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
175   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
176   return (saddr, daddr)
177
178 #---------- address handling ----------
179
180 def ipaddr(input):
181   try:
182     r = ipaddress.IPv4Address(input)
183   except AddressValueError:
184     r = ipaddress.IPv6Address(input)
185   return r
186
187 def ipnetwork(input):
188   try:
189     r = ipaddress.IPv4Network(input)
190   except NetworkValueError:
191     r = ipaddress.IPv6Network(input)
192   return r
193
194 #---------- ipif (SLIP) subprocess ----------
195
196 class SlipStreamDecoder():
197   def __init__(self, desc, on_packet):
198     self._buffer = b''
199     self._on_packet = on_packet
200     self._desc = desc
201     self._log('__init__')
202
203   def _log(self, msg, **kwargs):
204     log_debug(DBG.SLIP_FULL, 'slip %s: %s' % (self._desc, msg), **kwargs)
205
206   def inputdata(self, data):
207     self._log('inputdata', d=data)
208     packets = slip.decode(data)
209     packets[0] = self._buffer + packets[0]
210     self._buffer = packets.pop()
211     for packet in packets:
212       self._maybe_packet(packet)
213     self._log('bufremain', d=self._buffer)
214
215   def _maybe_packet(self, packet):
216     self._log('maybepacket', d=packet)
217     if len(packet):
218       self._on_packet(packet)
219
220   def flush(self):
221     self._log('flush')
222     self._maybe_packet(self._buffer)
223     self._buffer = b''
224
225 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
226   def __init__(self, router):
227     self._router = router
228     self._decoder = SlipStreamDecoder('ipif', self.slip_on_packet)
229   def connectionMade(self): pass
230   def outReceived(self, data):
231     self._decoder.inputdata(data)
232   def slip_on_packet(self, packet):
233     (saddr, daddr) = packet_addrs(packet)
234     if saddr.is_link_local or daddr.is_link_local:
235       log_discard(packet, 'ipif', saddr, daddr, 'link-local')
236       return
237     self._router(packet, saddr, daddr)
238   def processEnded(self, status):
239     status.raiseException()
240
241 def start_ipif(command, router):
242   global ipif
243   ipif = _IpifProcessProtocol(router)
244   reactor.spawnProcess(ipif,
245                        '/bin/sh',['sh','-xc', command],
246                        childFDs={0:'w', 1:'r', 2:2},
247                        env=None)
248
249 def queue_inbound(packet):
250   log_debug(DBG.FLOW, "queue_inbound", d=packet)
251   ipif.transport.write(slip.delimiter)
252   ipif.transport.write(slip.encode(packet))
253   ipif.transport.write(slip.delimiter)
254
255 #---------- packet queue ----------
256
257 class PacketQueue():
258   def __init__(self, desc, max_queue_time):
259     self._desc = desc
260     assert(desc + '')
261     self._max_queue_time = max_queue_time
262     self._pq = collections.deque() # packets
263
264   def _log(self, dflag, msg, **kwargs):
265     log_debug(dflag, self._desc+' pq: '+msg, **kwargs)
266
267   def append(self, packet):
268     self._log(DBG.QUEUE, 'append', d=packet)
269     self._pq.append((time.monotonic(), packet))
270
271   def nonempty(self):
272     self._log(DBG.QUEUE, 'nonempty ?')
273     while True:
274       try: (queuetime, packet) = self._pq[0]
275       except IndexError:
276         self._log(DBG.QUEUE, 'nonempty ? empty.')
277         return False
278
279       age = time.monotonic() - queuetime
280       if age > self._max_queue_time:
281         # strip old packets off the front
282         self._log(DBG.QUEUE, 'dropping (old)', d=packet)
283         self._pq.popleft()
284         continue
285
286       self._log(DBG.QUEUE, 'nonempty ? nonempty.')
287       return True
288
289   def process(self, sizequery, moredata, max_batch):
290     # sizequery() should return size of batch so far
291     # moredata(s) should add s to batch
292     self._log(DBG.QUEUE, 'process...')
293     while True:
294       try: (dummy, packet) = self._pq[0]
295       except IndexError:
296         self._log(DBG.QUEUE, 'process... empty')
297         break
298
299       self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)
300
301       encoded = slip.encode(packet)
302       sofar = sizequery()  
303
304       self._log(DBG.QUEUE_CTRL,
305                 'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
306                 d=encoded)
307
308       if sofar > 0:
309         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
310           self._log(DBG.QUEUE_CTRL, 'process... overflow')
311           break
312         moredata(slip.delimiter)
313
314       moredata(encoded)
315       self._pq.popleft()
316
317 #---------- error handling ----------
318
319 _crashing = False
320
321 def crash(err):
322   global _crashing
323   _crashing = True
324   print('========== CRASH ==========', err,
325         '===========================', file=sys.stderr)
326   try: reactor.stop()
327   except twisted.internet.error.ReactorNotRunning: pass
328
329 def crash_on_defer(defer):
330   defer.addErrback(lambda err: crash(err))
331
332 def crash_on_critical(event):
333   if event.get('log_level') >= LogLevel.critical:
334     crash(twisted.logger.formatEvent(event))
335
336 #---------- config processing ----------
337
338 def process_cfg_common_always():
339   global mtu
340   c.mtu = cfg.get('virtual','mtu')
341
342 def process_cfg_ipif(section, varmap):
343   for d, s in varmap:
344     try: v = getattr(c, s)
345     except AttributeError: continue
346     setattr(c, d, v)
347
348   #print(repr(c))
349
350   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
351
352 def process_cfg_network():
353   c.network = ipnetwork(cfg.get('virtual','network'))
354   if c.network.num_addresses < 3 + 2:
355     raise ValueError('network needs at least 2^3 addresses')
356
357 def process_cfg_server():
358   try:
359     c.server = cfg.get('virtual','server')
360   except NoOptionError:
361     process_cfg_network()
362     c.server = next(c.network.hosts())
363
364 class ServerAddr():
365   def __init__(self, port, addrspec):
366     self.port = port
367     # also self.addr
368     try:
369       self.addr = ipaddress.IPv4Address(addrspec)
370       self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
371       self._inurl = b'%s'
372     except AddressValueError:
373       self.addr = ipaddress.IPv6Address(addrspec)
374       self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
375       self._inurl = b'[%s]'
376   def make_endpoint(self):
377     return self._endpointfactory(reactor, self.port, self.addr)
378   def url(self):
379     url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
380     if self.port != 80: url += b':%d' % self.port
381     url += b'/'
382     return url
383     
384 def process_cfg_saddrs():
385   try: port = cfg.getint('server','port')
386   except NoOptionError: port = 80
387
388   c.saddrs = [ ]
389   for addrspec in cfg.get('server','addrs').split():
390     sa = ServerAddr(port, addrspec)
391     c.saddrs.append(sa)
392
393 def process_cfg_clients(constructor):
394   c.clients = [ ]
395   for cs in cfg.sections():
396     if not (':' in cs or '.' in cs): continue
397     ci = ipaddr(cs)
398     pw = cfg.get(cs, 'password')
399     pw = pw.encode('utf-8')
400     constructor(ci,cs,pw)
401
402 #---------- startup ----------
403
404 def common_startup():
405   optparser.add_option('-c', '--config', dest='configfile',
406                        default='/etc/hippotat/config')
407
408   def dfs_less_detailed(dl):
409     return [df for df in DBG.iterconstants() if df <= dl]
410
411   def ds_default(od,os,dl,op):
412     global debug_set
413     debug_set = set(dfs_less_detailed(debug_def_detail))
414
415   def ds_select(od,os, spec, op):
416     for it in spec.split(','):
417
418       if it.startswith('-'):
419         mutator = debug_set.discard
420         it = it[1:]
421       else:
422         mutator = debug_set.add
423
424       if it == '+':
425         dfs = DBG.iterconstants()
426
427       else:
428         if it.endswith('+'):
429           mapper = dfs_less_detailed
430           it = it[0:len(it)-1]
431         else:
432           mapper = lambda x: [x]
433
434           try:
435             dfspec = DBG.lookupByName(it)
436           except ValueError:
437             optparser.error('unknown debug flag %s in --debug-select' % it)
438
439         dfs = mapper(dfspec)
440
441       for df in dfs:
442         mutator(df)
443
444   optparser.add_option('-D', '--debug',
445                        nargs=0,
446                        action='callback',
447                        help='enable default debug (to stdout)',
448                        callback= ds_default)
449
450   optparser.add_option('--debug-select',
451                        nargs=1,
452                        type='string',
453                        metavar='[-]DFLAG[+]|[-]+,...',
454                        help=
455 '''enable (`-': disable) each specified DFLAG;
456 `+': do same for all "more interesting" DFLAGSs;
457 just `+': all DFLAGs.
458   DFLAGS: ''' + ' '.join([df.name for df in DBG.iterconstants()]),
459                        action='callback',
460                        callback= ds_select)
461
462   (opts, args) = optparser.parse_args()
463   if len(args): optparser.error('no non-option arguments please')
464
465   print(repr(debug_set), file=sys.stderr)
466
467   re = regexp.compile('#.*')
468   cfg.read_string(re.sub('', defcfg))
469   cfg.read(opts.configfile)
470
471   log_formatter = twisted.logger.formatEventAsClassicLogText
472   stdout_obs = twisted.logger.FileLogObserver(sys.stdout, log_formatter)
473   stderr_obs = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
474   pred = twisted.logger.LogLevelFilterPredicate(LogLevel.error)
475   stdsomething_obs = twisted.logger.FilteringLogObserver(
476     stderr_obs, [pred], stdout_obs
477   )
478   log_observer = twisted.logger.FilteringLogObserver(
479     stdsomething_obs, [LogNotBoringTwisted()]
480   )
481   #log_observer = stdsomething_obs
482   twisted.logger.globalLogBeginner.beginLoggingTo(
483     [ log_observer, crash_on_critical ]
484     )
485
486 def common_run():
487   log_debug(DBG.INIT, 'entering reactor')
488   if not _crashing: reactor.run()
489   print('CRASHED (end)', file=sys.stderr)