chiark / gitweb /
wip debug
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 import twisted
9 from twisted.internet import reactor
10 import twisted.internet.endpoints
11 import twisted.logger
12 from twisted.logger import LogLevel
13 import twisted.python.constants
14 from twisted.python.constants import NamedConstant
15
16 import ipaddress
17 from ipaddress import AddressValueError
18
19 from optparse import OptionParser
20 from configparser import ConfigParser
21 from configparser import NoOptionError
22
23 from functools import partial
24
25 import collections
26 import time
27 import codecs
28 import traceback
29
30 import re as regexp
31
32 import hippotat.slip as slip
33
34 class DBG(twisted.python.constants.Names):
35   INIT = NamedConstant()
36   ROUTE = NamedConstant()
37   DROP = NamedConstant()
38   FLOW = NamedConstant()
39   HTTP = NamedConstant()
40   TWISTED = NamedConstant()
41   QUEUE = NamedConstant()
42   HTTP_CTRL = NamedConstant()
43   QUEUE_CTRL = NamedConstant()
44   HTTP_FULL = NamedConstant()
45   CTRL_DUMP = NamedConstant()
46   SLIP_FULL = NamedConstant()
47
48 _hex_codec = codecs.getencoder('hex_codec')
49
50 log = twisted.logger.Logger()
51
52 def log_debug(dflag, msg, idof=None, d=None):
53   if dflag > DBG.HTTP: return
54   #print('---------------->',repr((dflag, msg, idof, d)), file=sys.stderr)
55   if idof is not None:
56     msg = '[%#x] %s' % (id(idof), msg)
57   if d is not None:
58     #d = d[0:64]
59     d = _hex_codec(d)[0].decode('ascii')
60     msg += ' ' + d
61   log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg)
62
63 defcfg = '''
64 [DEFAULT]
65 #[<client>] overrides
66 max_batch_down = 65536           # used by server, subject to [limits]
67 max_queue_time = 10              # used by server, subject to [limits]
68 target_requests_outstanding = 3  # must match; subject to [limits] on server
69 http_timeout = 30                # used by both } must be
70 http_timeout_grace = 5           # used by both }  compatible
71 max_requests_outstanding = 4     # used by client
72 max_batch_up = 4000              # used by client
73 http_retry = 5                   # used by client
74
75 #[server] or [<client>] overrides
76 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
77 # extra interpolations:  %(local)s        %(peer)s          %(rnet)s
78 #  obtained   on server  [virtual]server  [virtual]relay    [virtual]network
79 #      from   on client  <client>         [virtual]server   [virtual]routes
80
81 [virtual]
82 mtu = 1500
83 routes = ''
84 # network = <prefix>/<len>  # mandatory for server
85 # server  = <ipaddr>   # used by both, default is computed from `network'
86 # relay   = <ipaddr>   # used by server, default from `network' and `server'
87 #  default server is first host in network
88 #  default relay is first host which is not server
89
90 [server]
91 # addrs = 127.0.0.1 ::1    # mandatory for server
92 port = 80                  # used by server
93 # url              # used by client; default from first `addrs' and `port'
94
95 # [<client-ip4-or-ipv6-address>]
96 # password = <password>    # used by both, must match
97
98 [limits]
99 max_batch_down = 262144           # used by server
100 max_queue_time = 121              # used by server
101 http_timeout = 121                # used by server
102 target_requests_outstanding = 10  # used by server
103 '''
104
105 # these need to be defined here so that they can be imported by import *
106 cfg = ConfigParser()
107 optparser = OptionParser()
108
109 _mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')
110 def mime_translate(s):
111   # SLIP-encoded packets cannot contain ESC ESC.
112   # Swap `-' and ESC.  The result cannot contain `--'
113   return s.translate(_mimetrans)
114
115 class ConfigResults:
116   def __init__(self, d = { }):
117     self.__dict__ = d
118   def __repr__(self):
119     return 'ConfigResults('+repr(self.__dict__)+')'
120
121 c = ConfigResults()
122
123 def log_discard(packet, iface, saddr, daddr, why):
124   log_debug(DBG.DROP,
125             'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why),
126             d=packet)
127
128 #---------- packet parsing ----------
129
130 def packet_addrs(packet):
131   version = packet[0] >> 4
132   if version == 4:
133     addrlen = 4
134     saddroff = 3*4
135     factory = ipaddress.IPv4Address
136   elif version == 6:
137     addrlen = 16
138     saddroff = 2*4
139     factory = ipaddress.IPv6Address
140   else:
141     raise ValueError('unsupported IP version %d' % version)
142   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
143   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
144   return (saddr, daddr)
145
146 #---------- address handling ----------
147
148 def ipaddr(input):
149   try:
150     r = ipaddress.IPv4Address(input)
151   except AddressValueError:
152     r = ipaddress.IPv6Address(input)
153   return r
154
155 def ipnetwork(input):
156   try:
157     r = ipaddress.IPv4Network(input)
158   except NetworkValueError:
159     r = ipaddress.IPv6Network(input)
160   return r
161
162 #---------- ipif (SLIP) subprocess ----------
163
164 class SlipStreamDecoder():
165   def __init__(self, desc, on_packet):
166     self._buffer = b''
167     self._on_packet = on_packet
168     self._desc = desc
169     self._log('__init__')
170
171   def _log(self, msg, **kwargs):
172     log_debug(DBG.SLIP_FULL, 'slip %s: %s' % (self._desc, msg), **kwargs)
173
174   def inputdata(self, data):
175     self._log('inputdata', d=data)
176     packets = slip.decode(data)
177     packets[0] = self._buffer + packets[0]
178     self._buffer = packets.pop()
179     for packet in packets:
180       self._maybe_packet(packet)
181     self._log('bufremain', d=self._buffer)
182
183   def _maybe_packet(self, packet):
184     self._log('maybepacket', d=packet)
185     if len(packet):
186       self._on_packet(packet)
187
188   def flush(self):
189     self._log('flush')
190     self._maybe_packet(self._buffer)
191     self._buffer = b''
192
193 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
194   def __init__(self, router):
195     self._router = router
196     self._decoder = SlipStreamDecoder('ipif', self.slip_on_packet)
197   def connectionMade(self): pass
198   def outReceived(self, data):
199     self._decoder.inputdata(data)
200   def slip_on_packet(self, packet):
201     (saddr, daddr) = packet_addrs(packet)
202     if saddr.is_link_local or daddr.is_link_local:
203       log_discard(packet, 'ipif', saddr, daddr, 'link-local')
204       return
205     self._router(packet, saddr, daddr)
206   def processEnded(self, status):
207     status.raiseException()
208
209 def start_ipif(command, router):
210   global ipif
211   ipif = _IpifProcessProtocol(router)
212   reactor.spawnProcess(ipif,
213                        '/bin/sh',['sh','-xc', command],
214                        childFDs={0:'w', 1:'r', 2:2},
215                        env=None)
216
217 def queue_inbound(packet):
218   log_debug(DBG.FLOW, "queue_inbound", d=packet)
219   ipif.transport.write(slip.delimiter)
220   ipif.transport.write(slip.encode(packet))
221   ipif.transport.write(slip.delimiter)
222
223 #---------- packet queue ----------
224
225 class PacketQueue():
226   def __init__(self, desc, max_queue_time):
227     self._desc = desc
228     assert(desc + '')
229     self._max_queue_time = max_queue_time
230     self._pq = collections.deque() # packets
231
232   def _log(self, dflag, msg, **kwargs):
233     log_debug(dflag, self._desc+' pq: '+msg, **kwargs)
234
235   def append(self, packet):
236     self._log(DBG.QUEUE, 'append', d=packet)
237     self._pq.append((time.monotonic(), packet))
238
239   def nonempty(self):
240     self._log(DBG.QUEUE, 'nonempty ?')
241     while True:
242       try: (queuetime, packet) = self._pq[0]
243       except IndexError:
244         self._log(DBG.QUEUE, 'nonempty ? empty.')
245         return False
246
247       age = time.monotonic() - queuetime
248       if age > self._max_queue_time:
249         # strip old packets off the front
250         self._log(DBG.QUEUE, 'dropping (old)', d=packet)
251         self._pq.popleft()
252         continue
253
254       self._log(DBG.QUEUE, 'nonempty ? nonempty.')
255       return True
256
257   def process(self, sizequery, moredata, max_batch):
258     # sizequery() should return size of batch so far
259     # moredata(s) should add s to batch
260     self._log(DBG.QUEUE, 'process...')
261     while True:
262       try: (dummy, packet) = self._pq[0]
263       except IndexError:
264         self._log(DBG.QUEUE, 'process... empty')
265         break
266
267       self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)
268
269       encoded = slip.encode(packet)
270       sofar = sizequery()  
271
272       self._log(DBG.QUEUE_CTRL,
273                 'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
274                 d=encoded)
275
276       if sofar > 0:
277         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
278           self._log(DBG.QUEUE_CTRL, 'process... overflow')
279           break
280         moredata(slip.delimiter)
281
282       moredata(encoded)
283       self._pq.popleft()
284
285 #---------- error handling ----------
286
287 _crashing = False
288
289 def crash(err):
290   global _crashing
291   _crashing = True
292   print('========== CRASH ==========', err,
293         '===========================', file=sys.stderr)
294   try: reactor.stop()
295   except twisted.internet.error.ReactorNotRunning: pass
296
297 def crash_on_defer(defer):
298   defer.addErrback(lambda err: crash(err))
299
300 def crash_on_critical(event):
301   if event.get('log_level') >= LogLevel.critical:
302     crash(twisted.logger.formatEvent(event))
303
304 #---------- config processing ----------
305
306 def process_cfg_common_always():
307   global mtu
308   c.mtu = cfg.get('virtual','mtu')
309
310 def process_cfg_ipif(section, varmap):
311   for d, s in varmap:
312     try: v = getattr(c, s)
313     except AttributeError: continue
314     setattr(c, d, v)
315
316   #print(repr(c))
317
318   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
319
320 def process_cfg_network():
321   c.network = ipnetwork(cfg.get('virtual','network'))
322   if c.network.num_addresses < 3 + 2:
323     raise ValueError('network needs at least 2^3 addresses')
324
325 def process_cfg_server():
326   try:
327     c.server = cfg.get('virtual','server')
328   except NoOptionError:
329     process_cfg_network()
330     c.server = next(c.network.hosts())
331
332 class ServerAddr():
333   def __init__(self, port, addrspec):
334     self.port = port
335     # also self.addr
336     try:
337       self.addr = ipaddress.IPv4Address(addrspec)
338       self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
339       self._inurl = b'%s'
340     except AddressValueError:
341       self.addr = ipaddress.IPv6Address(addrspec)
342       self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
343       self._inurl = b'[%s]'
344   def make_endpoint(self):
345     return self._endpointfactory(reactor, self.port, self.addr)
346   def url(self):
347     url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
348     if self.port != 80: url += b':%d' % self.port
349     url += b'/'
350     return url
351     
352 def process_cfg_saddrs():
353   try: port = cfg.getint('server','port')
354   except NoOptionError: port = 80
355
356   c.saddrs = [ ]
357   for addrspec in cfg.get('server','addrs').split():
358     sa = ServerAddr(port, addrspec)
359     c.saddrs.append(sa)
360
361 def process_cfg_clients(constructor):
362   c.clients = [ ]
363   for cs in cfg.sections():
364     if not (':' in cs or '.' in cs): continue
365     ci = ipaddr(cs)
366     pw = cfg.get(cs, 'password')
367     pw = pw.encode('utf-8')
368     constructor(ci,cs,pw)
369
370 #---------- startup ----------
371
372 def common_startup():
373   log_formatter = twisted.logger.formatEventAsClassicLogText
374   stdout_obs = twisted.logger.FileLogObserver(sys.stdout, log_formatter)
375   stderr_obs = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
376   pred = twisted.logger.LogLevelFilterPredicate(LogLevel.error)
377   log_observer = twisted.logger.FilteringLogObserver(
378     stderr_obs, [pred], stdout_obs
379   )
380   twisted.logger.globalLogBeginner.beginLoggingTo(
381     [ log_observer, crash_on_critical ]
382     )
383
384   optparser.add_option('-c', '--config', dest='configfile',
385                        default='/etc/hippotat/config')
386   (opts, args) = optparser.parse_args()
387   if len(args): optparser.error('no non-option arguments please')
388
389   re = regexp.compile('#.*')
390   cfg.read_string(re.sub('', defcfg))
391   cfg.read(opts.configfile)
392
393 def common_run():
394   log_debug(DBG.INIT, 'entering reactor')
395   if not _crashing: reactor.run()
396   print('CRASHED (end)', file=sys.stderr)