chiark / gitweb /
c31bdb786d37aeda5a65d6d57a810074dcf7e008
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 from zope.interface import implementer
9
10 import twisted
11 from twisted.internet import reactor
12 import twisted.internet.endpoints
13 import twisted.logger
14 from twisted.logger import LogLevel
15 import twisted.python.constants
16 from twisted.python.constants import NamedConstant
17
18 import ipaddress
19 from ipaddress import AddressValueError
20
21 from optparse import OptionParser
22 from configparser import ConfigParser
23 from configparser import NoOptionError
24
25 from functools import partial
26
27 import collections
28 import time
29 import codecs
30 import traceback
31
32 import re as regexp
33
34 import hippotat.slip as slip
35
36 class DBG(twisted.python.constants.Names):
37   INIT = NamedConstant()
38   ROUTE = NamedConstant()
39   DROP = NamedConstant()
40   FLOW = NamedConstant()
41   HTTP = NamedConstant()
42   TWISTED = NamedConstant()
43   QUEUE = NamedConstant()
44   HTTP_CTRL = NamedConstant()
45   QUEUE_CTRL = NamedConstant()
46   HTTP_FULL = NamedConstant()
47   CTRL_DUMP = NamedConstant()
48   SLIP_FULL = NamedConstant()
49
50 _hex_codec = codecs.getencoder('hex_codec')
51
52 #---------- logging ----------
53
54 org_stderr = sys.stderr
55
56 log = twisted.logger.Logger()
57
58 debug_set = set()
59 debug_def_detail = DBG.HTTP
60
61 def log_debug(dflag, msg, idof=None, d=None):
62   if dflag not in debug_set: return
63   #print('---------------->',repr((dflag, msg, idof, d)), file=sys.stderr)
64   if idof is not None:
65     msg = '[%#x] %s' % (id(idof), msg)
66   if d is not None:
67     #d = d[0:64]
68     d = _hex_codec(d)[0].decode('ascii')
69     msg += ' ' + d
70   log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg)
71
72 @implementer(twisted.logger.ILogFilterPredicate)
73 class LogNotBoringTwisted:
74   def __call__(self, event):
75     yes = twisted.logger.PredicateResult.yes
76     no  = twisted.logger.PredicateResult.no
77     return yes
78     try:
79       if event.get('log_level') != LogLevel.info:
80         return yes
81       try:
82         dflag = event.get('dflag')
83       except KeyError:
84         dflag = DBG.TWISTED
85       return yes if (dflag in debug_set) else no
86     except Exception:
87       print(traceback.format_exc(), file=org_stderr)
88       return yes
89
90 #---------- default config ----------
91
92 defcfg = '''
93 [DEFAULT]
94 #[<client>] overrides
95 max_batch_down = 65536           # used by server, subject to [limits]
96 max_queue_time = 10              # used by server, subject to [limits]
97 target_requests_outstanding = 3  # must match; subject to [limits] on server
98 http_timeout = 30                # used by both } must be
99 http_timeout_grace = 5           # used by both }  compatible
100 max_requests_outstanding = 4     # used by client
101 max_batch_up = 4000              # used by client
102 http_retry = 5                   # used by client
103
104 #[server] or [<client>] overrides
105 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
106 # extra interpolations:  %(local)s        %(peer)s          %(rnet)s
107 #  obtained   on server  [virtual]server  [virtual]relay    [virtual]network
108 #      from   on client  <client>         [virtual]server   [virtual]routes
109
110 [virtual]
111 mtu = 1500
112 routes = ''
113 # network = <prefix>/<len>  # mandatory for server
114 # server  = <ipaddr>   # used by both, default is computed from `network'
115 # relay   = <ipaddr>   # used by server, default from `network' and `server'
116 #  default server is first host in network
117 #  default relay is first host which is not server
118
119 [server]
120 # addrs = 127.0.0.1 ::1    # mandatory for server
121 port = 80                  # used by server
122 # url              # used by client; default from first `addrs' and `port'
123
124 # [<client-ip4-or-ipv6-address>]
125 # password = <password>    # used by both, must match
126
127 [limits]
128 max_batch_down = 262144           # used by server
129 max_queue_time = 121              # used by server
130 http_timeout = 121                # used by server
131 target_requests_outstanding = 10  # used by server
132 '''
133
134 # these need to be defined here so that they can be imported by import *
135 cfg = ConfigParser()
136 optparser = OptionParser()
137
138 _mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')
139 def mime_translate(s):
140   # SLIP-encoded packets cannot contain ESC ESC.
141   # Swap `-' and ESC.  The result cannot contain `--'
142   return s.translate(_mimetrans)
143
144 class ConfigResults:
145   def __init__(self, d = { }):
146     self.__dict__ = d
147   def __repr__(self):
148     return 'ConfigResults('+repr(self.__dict__)+')'
149
150 c = ConfigResults()
151
152 def log_discard(packet, iface, saddr, daddr, why):
153   log_debug(DBG.DROP,
154             'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why),
155             d=packet)
156
157 #---------- packet parsing ----------
158
159 def packet_addrs(packet):
160   version = packet[0] >> 4
161   if version == 4:
162     addrlen = 4
163     saddroff = 3*4
164     factory = ipaddress.IPv4Address
165   elif version == 6:
166     addrlen = 16
167     saddroff = 2*4
168     factory = ipaddress.IPv6Address
169   else:
170     raise ValueError('unsupported IP version %d' % version)
171   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
172   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
173   return (saddr, daddr)
174
175 #---------- address handling ----------
176
177 def ipaddr(input):
178   try:
179     r = ipaddress.IPv4Address(input)
180   except AddressValueError:
181     r = ipaddress.IPv6Address(input)
182   return r
183
184 def ipnetwork(input):
185   try:
186     r = ipaddress.IPv4Network(input)
187   except NetworkValueError:
188     r = ipaddress.IPv6Network(input)
189   return r
190
191 #---------- ipif (SLIP) subprocess ----------
192
193 class SlipStreamDecoder():
194   def __init__(self, desc, on_packet):
195     self._buffer = b''
196     self._on_packet = on_packet
197     self._desc = desc
198     self._log('__init__')
199
200   def _log(self, msg, **kwargs):
201     log_debug(DBG.SLIP_FULL, 'slip %s: %s' % (self._desc, msg), **kwargs)
202
203   def inputdata(self, data):
204     self._log('inputdata', d=data)
205     packets = slip.decode(data)
206     packets[0] = self._buffer + packets[0]
207     self._buffer = packets.pop()
208     for packet in packets:
209       self._maybe_packet(packet)
210     self._log('bufremain', d=self._buffer)
211
212   def _maybe_packet(self, packet):
213     self._log('maybepacket', d=packet)
214     if len(packet):
215       self._on_packet(packet)
216
217   def flush(self):
218     self._log('flush')
219     self._maybe_packet(self._buffer)
220     self._buffer = b''
221
222 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
223   def __init__(self, router):
224     self._router = router
225     self._decoder = SlipStreamDecoder('ipif', self.slip_on_packet)
226   def connectionMade(self): pass
227   def outReceived(self, data):
228     self._decoder.inputdata(data)
229   def slip_on_packet(self, packet):
230     (saddr, daddr) = packet_addrs(packet)
231     if saddr.is_link_local or daddr.is_link_local:
232       log_discard(packet, 'ipif', saddr, daddr, 'link-local')
233       return
234     self._router(packet, saddr, daddr)
235   def processEnded(self, status):
236     status.raiseException()
237
238 def start_ipif(command, router):
239   global ipif
240   ipif = _IpifProcessProtocol(router)
241   reactor.spawnProcess(ipif,
242                        '/bin/sh',['sh','-xc', command],
243                        childFDs={0:'w', 1:'r', 2:2},
244                        env=None)
245
246 def queue_inbound(packet):
247   log_debug(DBG.FLOW, "queue_inbound", d=packet)
248   ipif.transport.write(slip.delimiter)
249   ipif.transport.write(slip.encode(packet))
250   ipif.transport.write(slip.delimiter)
251
252 #---------- packet queue ----------
253
254 class PacketQueue():
255   def __init__(self, desc, max_queue_time):
256     self._desc = desc
257     assert(desc + '')
258     self._max_queue_time = max_queue_time
259     self._pq = collections.deque() # packets
260
261   def _log(self, dflag, msg, **kwargs):
262     log_debug(dflag, self._desc+' pq: '+msg, **kwargs)
263
264   def append(self, packet):
265     self._log(DBG.QUEUE, 'append', d=packet)
266     self._pq.append((time.monotonic(), packet))
267
268   def nonempty(self):
269     self._log(DBG.QUEUE, 'nonempty ?')
270     while True:
271       try: (queuetime, packet) = self._pq[0]
272       except IndexError:
273         self._log(DBG.QUEUE, 'nonempty ? empty.')
274         return False
275
276       age = time.monotonic() - queuetime
277       if age > self._max_queue_time:
278         # strip old packets off the front
279         self._log(DBG.QUEUE, 'dropping (old)', d=packet)
280         self._pq.popleft()
281         continue
282
283       self._log(DBG.QUEUE, 'nonempty ? nonempty.')
284       return True
285
286   def process(self, sizequery, moredata, max_batch):
287     # sizequery() should return size of batch so far
288     # moredata(s) should add s to batch
289     self._log(DBG.QUEUE, 'process...')
290     while True:
291       try: (dummy, packet) = self._pq[0]
292       except IndexError:
293         self._log(DBG.QUEUE, 'process... empty')
294         break
295
296       self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)
297
298       encoded = slip.encode(packet)
299       sofar = sizequery()  
300
301       self._log(DBG.QUEUE_CTRL,
302                 'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
303                 d=encoded)
304
305       if sofar > 0:
306         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
307           self._log(DBG.QUEUE_CTRL, 'process... overflow')
308           break
309         moredata(slip.delimiter)
310
311       moredata(encoded)
312       self._pq.popleft()
313
314 #---------- error handling ----------
315
316 _crashing = False
317
318 def crash(err):
319   global _crashing
320   _crashing = True
321   print('========== CRASH ==========', err,
322         '===========================', file=sys.stderr)
323   try: reactor.stop()
324   except twisted.internet.error.ReactorNotRunning: pass
325
326 def crash_on_defer(defer):
327   defer.addErrback(lambda err: crash(err))
328
329 def crash_on_critical(event):
330   if event.get('log_level') >= LogLevel.critical:
331     crash(twisted.logger.formatEvent(event))
332
333 #---------- config processing ----------
334
335 def process_cfg_common_always():
336   global mtu
337   c.mtu = cfg.get('virtual','mtu')
338
339 def process_cfg_ipif(section, varmap):
340   for d, s in varmap:
341     try: v = getattr(c, s)
342     except AttributeError: continue
343     setattr(c, d, v)
344
345   #print(repr(c))
346
347   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
348
349 def process_cfg_network():
350   c.network = ipnetwork(cfg.get('virtual','network'))
351   if c.network.num_addresses < 3 + 2:
352     raise ValueError('network needs at least 2^3 addresses')
353
354 def process_cfg_server():
355   try:
356     c.server = cfg.get('virtual','server')
357   except NoOptionError:
358     process_cfg_network()
359     c.server = next(c.network.hosts())
360
361 class ServerAddr():
362   def __init__(self, port, addrspec):
363     self.port = port
364     # also self.addr
365     try:
366       self.addr = ipaddress.IPv4Address(addrspec)
367       self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
368       self._inurl = b'%s'
369     except AddressValueError:
370       self.addr = ipaddress.IPv6Address(addrspec)
371       self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
372       self._inurl = b'[%s]'
373   def make_endpoint(self):
374     return self._endpointfactory(reactor, self.port, self.addr)
375   def url(self):
376     url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
377     if self.port != 80: url += b':%d' % self.port
378     url += b'/'
379     return url
380     
381 def process_cfg_saddrs():
382   try: port = cfg.getint('server','port')
383   except NoOptionError: port = 80
384
385   c.saddrs = [ ]
386   for addrspec in cfg.get('server','addrs').split():
387     sa = ServerAddr(port, addrspec)
388     c.saddrs.append(sa)
389
390 def process_cfg_clients(constructor):
391   c.clients = [ ]
392   for cs in cfg.sections():
393     if not (':' in cs or '.' in cs): continue
394     ci = ipaddr(cs)
395     pw = cfg.get(cs, 'password')
396     pw = pw.encode('utf-8')
397     constructor(ci,cs,pw)
398
399 #---------- startup ----------
400
401 def common_startup():
402   optparser.add_option('-c', '--config', dest='configfile',
403                        default='/etc/hippotat/config')
404
405   def ds_by_detail(od,os,detail_level,op):
406     global debug_set
407     debug_set = set([df for df in DBG.iterconstants() if df <= detail_level])
408
409   def ds_one(mutator,df, od,os,value,op):
410     mutator(df)
411
412   optparser.add_option('-D', '--debug',
413                        default=debug_def_detail.name,
414                        type='choice',
415                        choices=[dl.name for dl in DBG.iterconstants()],
416                        action='callback',
417                        callback= ds_by_detail)
418
419   optparser.add_option('--no-debug',
420                        nargs=0,
421                        action='callback',
422                        callback= partial(ds_by_detail,DBG.INIT))
423
424   for df in DBG.iterconstants():
425     optparser.add_option('--debug-'+df.name,
426                          action='callback',
427                          callback= partial(ds_one, debug_set.add, df))
428     optparser.add_option('--no-debug-'+df.name,
429                          action='callback',
430                          callback= partial(ds_one, debug_set.discard, df))
431
432   (opts, args) = optparser.parse_args()
433   if len(args): optparser.error('no non-option arguments please')
434
435   re = regexp.compile('#.*')
436   cfg.read_string(re.sub('', defcfg))
437   cfg.read(opts.configfile)
438
439   log_formatter = twisted.logger.formatEventAsClassicLogText
440   stdout_obs = twisted.logger.FileLogObserver(sys.stdout, log_formatter)
441   stderr_obs = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
442   pred = twisted.logger.LogLevelFilterPredicate(LogLevel.error)
443   stdsomething_obs = twisted.logger.FilteringLogObserver(
444     stderr_obs, [pred], stdout_obs
445   )
446   log_observer = twisted.logger.FilteringLogObserver(
447     stdsomething_obs, [LogNotBoringTwisted()]
448   )
449   #log_observer = stdsomething_obs
450   twisted.logger.globalLogBeginner.beginLoggingTo(
451     [ log_observer, crash_on_critical ]
452     )
453
454 def common_run():
455   log_debug(DBG.INIT, 'entering reactor')
456   if not _crashing: reactor.run()
457   print('CRASHED (end)', file=sys.stderr)