chiark / gitweb /
filter out twisted stuff
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 from zope.interface import implementer
9
10 import twisted
11 from twisted.internet import reactor
12 import twisted.internet.endpoints
13 import twisted.logger
14 from twisted.logger import LogLevel
15 import twisted.python.constants
16 from twisted.python.constants import NamedConstant
17
18 import ipaddress
19 from ipaddress import AddressValueError
20
21 from optparse import OptionParser
22 from configparser import ConfigParser
23 from configparser import NoOptionError
24
25 from functools import partial
26
27 import collections
28 import time
29 import codecs
30 import traceback
31
32 import re as regexp
33
34 import hippotat.slip as slip
35
36 class DBG(twisted.python.constants.Names):
37   INIT = NamedConstant()
38   ROUTE = NamedConstant()
39   DROP = NamedConstant()
40   FLOW = NamedConstant()
41   HTTP = NamedConstant()
42   TWISTED = NamedConstant()
43   QUEUE = NamedConstant()
44   HTTP_CTRL = NamedConstant()
45   QUEUE_CTRL = NamedConstant()
46   HTTP_FULL = NamedConstant()
47   CTRL_DUMP = NamedConstant()
48   SLIP_FULL = NamedConstant()
49
50 _hex_codec = codecs.getencoder('hex_codec')
51
52 #---------- logging ----------
53
54 org_stderr = sys.stderr
55
56 log = twisted.logger.Logger()
57
58 debug_set = set([x for x in DBG.iterconstants() if x <= DBG.HTTP])
59
60 def log_debug(dflag, msg, idof=None, d=None):
61   if dflag not in debug_set: return
62   #print('---------------->',repr((dflag, msg, idof, d)), file=sys.stderr)
63   if idof is not None:
64     msg = '[%#x] %s' % (id(idof), msg)
65   if d is not None:
66     #d = d[0:64]
67     d = _hex_codec(d)[0].decode('ascii')
68     msg += ' ' + d
69   log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg)
70
71 @implementer(twisted.logger.ILogFilterPredicate)
72 class LogNotBoringTwisted:
73   def __call__(self, event):
74     yes = twisted.logger.PredicateResult.yes
75     no  = twisted.logger.PredicateResult.no
76     try:
77       if event.get('log_level') != LogLevel.info:
78         return yes
79       try:
80         dflag = event.get('dflag')
81       except KeyError:
82         dflag = DBG.TWISTED
83       return yes if (dflag in debug_set) else no
84     except Exception:
85       print(traceback.format_exc(), file=org_stderr)
86       return yes
87
88 #---------- default config ----------
89
90 defcfg = '''
91 [DEFAULT]
92 #[<client>] overrides
93 max_batch_down = 65536           # used by server, subject to [limits]
94 max_queue_time = 10              # used by server, subject to [limits]
95 target_requests_outstanding = 3  # must match; subject to [limits] on server
96 http_timeout = 30                # used by both } must be
97 http_timeout_grace = 5           # used by both }  compatible
98 max_requests_outstanding = 4     # used by client
99 max_batch_up = 4000              # used by client
100 http_retry = 5                   # used by client
101
102 #[server] or [<client>] overrides
103 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
104 # extra interpolations:  %(local)s        %(peer)s          %(rnet)s
105 #  obtained   on server  [virtual]server  [virtual]relay    [virtual]network
106 #      from   on client  <client>         [virtual]server   [virtual]routes
107
108 [virtual]
109 mtu = 1500
110 routes = ''
111 # network = <prefix>/<len>  # mandatory for server
112 # server  = <ipaddr>   # used by both, default is computed from `network'
113 # relay   = <ipaddr>   # used by server, default from `network' and `server'
114 #  default server is first host in network
115 #  default relay is first host which is not server
116
117 [server]
118 # addrs = 127.0.0.1 ::1    # mandatory for server
119 port = 80                  # used by server
120 # url              # used by client; default from first `addrs' and `port'
121
122 # [<client-ip4-or-ipv6-address>]
123 # password = <password>    # used by both, must match
124
125 [limits]
126 max_batch_down = 262144           # used by server
127 max_queue_time = 121              # used by server
128 http_timeout = 121                # used by server
129 target_requests_outstanding = 10  # used by server
130 '''
131
132 # these need to be defined here so that they can be imported by import *
133 cfg = ConfigParser()
134 optparser = OptionParser()
135
136 _mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')
137 def mime_translate(s):
138   # SLIP-encoded packets cannot contain ESC ESC.
139   # Swap `-' and ESC.  The result cannot contain `--'
140   return s.translate(_mimetrans)
141
142 class ConfigResults:
143   def __init__(self, d = { }):
144     self.__dict__ = d
145   def __repr__(self):
146     return 'ConfigResults('+repr(self.__dict__)+')'
147
148 c = ConfigResults()
149
150 def log_discard(packet, iface, saddr, daddr, why):
151   log_debug(DBG.DROP,
152             'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why),
153             d=packet)
154
155 #---------- packet parsing ----------
156
157 def packet_addrs(packet):
158   version = packet[0] >> 4
159   if version == 4:
160     addrlen = 4
161     saddroff = 3*4
162     factory = ipaddress.IPv4Address
163   elif version == 6:
164     addrlen = 16
165     saddroff = 2*4
166     factory = ipaddress.IPv6Address
167   else:
168     raise ValueError('unsupported IP version %d' % version)
169   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
170   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
171   return (saddr, daddr)
172
173 #---------- address handling ----------
174
175 def ipaddr(input):
176   try:
177     r = ipaddress.IPv4Address(input)
178   except AddressValueError:
179     r = ipaddress.IPv6Address(input)
180   return r
181
182 def ipnetwork(input):
183   try:
184     r = ipaddress.IPv4Network(input)
185   except NetworkValueError:
186     r = ipaddress.IPv6Network(input)
187   return r
188
189 #---------- ipif (SLIP) subprocess ----------
190
191 class SlipStreamDecoder():
192   def __init__(self, desc, on_packet):
193     self._buffer = b''
194     self._on_packet = on_packet
195     self._desc = desc
196     self._log('__init__')
197
198   def _log(self, msg, **kwargs):
199     log_debug(DBG.SLIP_FULL, 'slip %s: %s' % (self._desc, msg), **kwargs)
200
201   def inputdata(self, data):
202     self._log('inputdata', d=data)
203     packets = slip.decode(data)
204     packets[0] = self._buffer + packets[0]
205     self._buffer = packets.pop()
206     for packet in packets:
207       self._maybe_packet(packet)
208     self._log('bufremain', d=self._buffer)
209
210   def _maybe_packet(self, packet):
211     self._log('maybepacket', d=packet)
212     if len(packet):
213       self._on_packet(packet)
214
215   def flush(self):
216     self._log('flush')
217     self._maybe_packet(self._buffer)
218     self._buffer = b''
219
220 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
221   def __init__(self, router):
222     self._router = router
223     self._decoder = SlipStreamDecoder('ipif', self.slip_on_packet)
224   def connectionMade(self): pass
225   def outReceived(self, data):
226     self._decoder.inputdata(data)
227   def slip_on_packet(self, packet):
228     (saddr, daddr) = packet_addrs(packet)
229     if saddr.is_link_local or daddr.is_link_local:
230       log_discard(packet, 'ipif', saddr, daddr, 'link-local')
231       return
232     self._router(packet, saddr, daddr)
233   def processEnded(self, status):
234     status.raiseException()
235
236 def start_ipif(command, router):
237   global ipif
238   ipif = _IpifProcessProtocol(router)
239   reactor.spawnProcess(ipif,
240                        '/bin/sh',['sh','-xc', command],
241                        childFDs={0:'w', 1:'r', 2:2},
242                        env=None)
243
244 def queue_inbound(packet):
245   log_debug(DBG.FLOW, "queue_inbound", d=packet)
246   ipif.transport.write(slip.delimiter)
247   ipif.transport.write(slip.encode(packet))
248   ipif.transport.write(slip.delimiter)
249
250 #---------- packet queue ----------
251
252 class PacketQueue():
253   def __init__(self, desc, max_queue_time):
254     self._desc = desc
255     assert(desc + '')
256     self._max_queue_time = max_queue_time
257     self._pq = collections.deque() # packets
258
259   def _log(self, dflag, msg, **kwargs):
260     log_debug(dflag, self._desc+' pq: '+msg, **kwargs)
261
262   def append(self, packet):
263     self._log(DBG.QUEUE, 'append', d=packet)
264     self._pq.append((time.monotonic(), packet))
265
266   def nonempty(self):
267     self._log(DBG.QUEUE, 'nonempty ?')
268     while True:
269       try: (queuetime, packet) = self._pq[0]
270       except IndexError:
271         self._log(DBG.QUEUE, 'nonempty ? empty.')
272         return False
273
274       age = time.monotonic() - queuetime
275       if age > self._max_queue_time:
276         # strip old packets off the front
277         self._log(DBG.QUEUE, 'dropping (old)', d=packet)
278         self._pq.popleft()
279         continue
280
281       self._log(DBG.QUEUE, 'nonempty ? nonempty.')
282       return True
283
284   def process(self, sizequery, moredata, max_batch):
285     # sizequery() should return size of batch so far
286     # moredata(s) should add s to batch
287     self._log(DBG.QUEUE, 'process...')
288     while True:
289       try: (dummy, packet) = self._pq[0]
290       except IndexError:
291         self._log(DBG.QUEUE, 'process... empty')
292         break
293
294       self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)
295
296       encoded = slip.encode(packet)
297       sofar = sizequery()  
298
299       self._log(DBG.QUEUE_CTRL,
300                 'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
301                 d=encoded)
302
303       if sofar > 0:
304         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
305           self._log(DBG.QUEUE_CTRL, 'process... overflow')
306           break
307         moredata(slip.delimiter)
308
309       moredata(encoded)
310       self._pq.popleft()
311
312 #---------- error handling ----------
313
314 _crashing = False
315
316 def crash(err):
317   global _crashing
318   _crashing = True
319   print('========== CRASH ==========', err,
320         '===========================', file=sys.stderr)
321   try: reactor.stop()
322   except twisted.internet.error.ReactorNotRunning: pass
323
324 def crash_on_defer(defer):
325   defer.addErrback(lambda err: crash(err))
326
327 def crash_on_critical(event):
328   if event.get('log_level') >= LogLevel.critical:
329     crash(twisted.logger.formatEvent(event))
330
331 #---------- config processing ----------
332
333 def process_cfg_common_always():
334   global mtu
335   c.mtu = cfg.get('virtual','mtu')
336
337 def process_cfg_ipif(section, varmap):
338   for d, s in varmap:
339     try: v = getattr(c, s)
340     except AttributeError: continue
341     setattr(c, d, v)
342
343   #print(repr(c))
344
345   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
346
347 def process_cfg_network():
348   c.network = ipnetwork(cfg.get('virtual','network'))
349   if c.network.num_addresses < 3 + 2:
350     raise ValueError('network needs at least 2^3 addresses')
351
352 def process_cfg_server():
353   try:
354     c.server = cfg.get('virtual','server')
355   except NoOptionError:
356     process_cfg_network()
357     c.server = next(c.network.hosts())
358
359 class ServerAddr():
360   def __init__(self, port, addrspec):
361     self.port = port
362     # also self.addr
363     try:
364       self.addr = ipaddress.IPv4Address(addrspec)
365       self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
366       self._inurl = b'%s'
367     except AddressValueError:
368       self.addr = ipaddress.IPv6Address(addrspec)
369       self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
370       self._inurl = b'[%s]'
371   def make_endpoint(self):
372     return self._endpointfactory(reactor, self.port, self.addr)
373   def url(self):
374     url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
375     if self.port != 80: url += b':%d' % self.port
376     url += b'/'
377     return url
378     
379 def process_cfg_saddrs():
380   try: port = cfg.getint('server','port')
381   except NoOptionError: port = 80
382
383   c.saddrs = [ ]
384   for addrspec in cfg.get('server','addrs').split():
385     sa = ServerAddr(port, addrspec)
386     c.saddrs.append(sa)
387
388 def process_cfg_clients(constructor):
389   c.clients = [ ]
390   for cs in cfg.sections():
391     if not (':' in cs or '.' in cs): continue
392     ci = ipaddr(cs)
393     pw = cfg.get(cs, 'password')
394     pw = pw.encode('utf-8')
395     constructor(ci,cs,pw)
396
397 #---------- startup ----------
398
399 def common_startup():
400   log_formatter = twisted.logger.formatEventAsClassicLogText
401   stdout_obs = twisted.logger.FileLogObserver(sys.stdout, log_formatter)
402   stderr_obs = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
403   pred = twisted.logger.LogLevelFilterPredicate(LogLevel.error)
404   stdsomething_obs = twisted.logger.FilteringLogObserver(
405     stderr_obs, [pred], stdout_obs
406   )
407   log_observer = twisted.logger.FilteringLogObserver(
408     stdsomething_obs, [LogNotBoringTwisted()]
409   )
410   #log_observer = stdsomething_obs
411   twisted.logger.globalLogBeginner.beginLoggingTo(
412     [ log_observer, crash_on_critical ]
413     )
414
415   optparser.add_option('-c', '--config', dest='configfile',
416                        default='/etc/hippotat/config')
417   (opts, args) = optparser.parse_args()
418   if len(args): optparser.error('no non-option arguments please')
419
420   re = regexp.compile('#.*')
421   cfg.read_string(re.sub('', defcfg))
422   cfg.read(opts.configfile)
423
424 def common_run():
425   log_debug(DBG.INIT, 'entering reactor')
426   if not _crashing: reactor.run()
427   print('CRASHED (end)', file=sys.stderr)