chiark / gitweb /
b6a84d9546903868edf9d6830f4b8df2de92b08c
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 import twisted
9 from twisted.internet import reactor
10 import twisted.internet.endpoints
11 import twisted.logger
12 from twisted.logger import LogLevel
13 import twisted.python.constants
14 from twisted.python.constants import NamedConstant
15
16 import ipaddress
17 from ipaddress import AddressValueError
18
19 from optparse import OptionParser
20 from configparser import ConfigParser
21 from configparser import NoOptionError
22
23 from functools import partial
24
25 import collections
26 import time
27 import codecs
28 import traceback
29
30 import re as regexp
31
32 import hippotat.slip as slip
33
34 class DBG(twisted.python.constants.Names):
35   INIT = NamedConstant()
36   ROUTE = NamedConstant()
37   DROP = NamedConstant()
38   FLOW = NamedConstant()
39   HTTP = NamedConstant()
40   TWISTED = NamedConstant()
41   QUEUE = NamedConstant()
42   HTTP_CTRL = NamedConstant()
43   QUEUE_CTRL = NamedConstant()
44   HTTP_FULL = NamedConstant()
45   CTRL_DUMP = NamedConstant()
46   SLIP_FULL = NamedConstant()
47
48 _hex_codec = codecs.getencoder('hex_codec')
49
50 log = twisted.logger.Logger()
51
52 debug_set = set([x for x in DBG.iterconstants() if x <= DBG.HTTP])
53
54 def log_debug(dflag, msg, idof=None, d=None):
55   if dflag not in debug_set: return
56   #print('---------------->',repr((dflag, msg, idof, d)), file=sys.stderr)
57   if idof is not None:
58     msg = '[%#x] %s' % (id(idof), msg)
59   if d is not None:
60     #d = d[0:64]
61     d = _hex_codec(d)[0].decode('ascii')
62     msg += ' ' + d
63   log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg)
64
65 defcfg = '''
66 [DEFAULT]
67 #[<client>] overrides
68 max_batch_down = 65536           # used by server, subject to [limits]
69 max_queue_time = 10              # used by server, subject to [limits]
70 target_requests_outstanding = 3  # must match; subject to [limits] on server
71 http_timeout = 30                # used by both } must be
72 http_timeout_grace = 5           # used by both }  compatible
73 max_requests_outstanding = 4     # used by client
74 max_batch_up = 4000              # used by client
75 http_retry = 5                   # used by client
76
77 #[server] or [<client>] overrides
78 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
79 # extra interpolations:  %(local)s        %(peer)s          %(rnet)s
80 #  obtained   on server  [virtual]server  [virtual]relay    [virtual]network
81 #      from   on client  <client>         [virtual]server   [virtual]routes
82
83 [virtual]
84 mtu = 1500
85 routes = ''
86 # network = <prefix>/<len>  # mandatory for server
87 # server  = <ipaddr>   # used by both, default is computed from `network'
88 # relay   = <ipaddr>   # used by server, default from `network' and `server'
89 #  default server is first host in network
90 #  default relay is first host which is not server
91
92 [server]
93 # addrs = 127.0.0.1 ::1    # mandatory for server
94 port = 80                  # used by server
95 # url              # used by client; default from first `addrs' and `port'
96
97 # [<client-ip4-or-ipv6-address>]
98 # password = <password>    # used by both, must match
99
100 [limits]
101 max_batch_down = 262144           # used by server
102 max_queue_time = 121              # used by server
103 http_timeout = 121                # used by server
104 target_requests_outstanding = 10  # used by server
105 '''
106
107 # these need to be defined here so that they can be imported by import *
108 cfg = ConfigParser()
109 optparser = OptionParser()
110
111 _mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')
112 def mime_translate(s):
113   # SLIP-encoded packets cannot contain ESC ESC.
114   # Swap `-' and ESC.  The result cannot contain `--'
115   return s.translate(_mimetrans)
116
117 class ConfigResults:
118   def __init__(self, d = { }):
119     self.__dict__ = d
120   def __repr__(self):
121     return 'ConfigResults('+repr(self.__dict__)+')'
122
123 c = ConfigResults()
124
125 def log_discard(packet, iface, saddr, daddr, why):
126   log_debug(DBG.DROP,
127             'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why),
128             d=packet)
129
130 #---------- packet parsing ----------
131
132 def packet_addrs(packet):
133   version = packet[0] >> 4
134   if version == 4:
135     addrlen = 4
136     saddroff = 3*4
137     factory = ipaddress.IPv4Address
138   elif version == 6:
139     addrlen = 16
140     saddroff = 2*4
141     factory = ipaddress.IPv6Address
142   else:
143     raise ValueError('unsupported IP version %d' % version)
144   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
145   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
146   return (saddr, daddr)
147
148 #---------- address handling ----------
149
150 def ipaddr(input):
151   try:
152     r = ipaddress.IPv4Address(input)
153   except AddressValueError:
154     r = ipaddress.IPv6Address(input)
155   return r
156
157 def ipnetwork(input):
158   try:
159     r = ipaddress.IPv4Network(input)
160   except NetworkValueError:
161     r = ipaddress.IPv6Network(input)
162   return r
163
164 #---------- ipif (SLIP) subprocess ----------
165
166 class SlipStreamDecoder():
167   def __init__(self, desc, on_packet):
168     self._buffer = b''
169     self._on_packet = on_packet
170     self._desc = desc
171     self._log('__init__')
172
173   def _log(self, msg, **kwargs):
174     log_debug(DBG.SLIP_FULL, 'slip %s: %s' % (self._desc, msg), **kwargs)
175
176   def inputdata(self, data):
177     self._log('inputdata', d=data)
178     packets = slip.decode(data)
179     packets[0] = self._buffer + packets[0]
180     self._buffer = packets.pop()
181     for packet in packets:
182       self._maybe_packet(packet)
183     self._log('bufremain', d=self._buffer)
184
185   def _maybe_packet(self, packet):
186     self._log('maybepacket', d=packet)
187     if len(packet):
188       self._on_packet(packet)
189
190   def flush(self):
191     self._log('flush')
192     self._maybe_packet(self._buffer)
193     self._buffer = b''
194
195 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
196   def __init__(self, router):
197     self._router = router
198     self._decoder = SlipStreamDecoder('ipif', self.slip_on_packet)
199   def connectionMade(self): pass
200   def outReceived(self, data):
201     self._decoder.inputdata(data)
202   def slip_on_packet(self, packet):
203     (saddr, daddr) = packet_addrs(packet)
204     if saddr.is_link_local or daddr.is_link_local:
205       log_discard(packet, 'ipif', saddr, daddr, 'link-local')
206       return
207     self._router(packet, saddr, daddr)
208   def processEnded(self, status):
209     status.raiseException()
210
211 def start_ipif(command, router):
212   global ipif
213   ipif = _IpifProcessProtocol(router)
214   reactor.spawnProcess(ipif,
215                        '/bin/sh',['sh','-xc', command],
216                        childFDs={0:'w', 1:'r', 2:2},
217                        env=None)
218
219 def queue_inbound(packet):
220   log_debug(DBG.FLOW, "queue_inbound", d=packet)
221   ipif.transport.write(slip.delimiter)
222   ipif.transport.write(slip.encode(packet))
223   ipif.transport.write(slip.delimiter)
224
225 #---------- packet queue ----------
226
227 class PacketQueue():
228   def __init__(self, desc, max_queue_time):
229     self._desc = desc
230     assert(desc + '')
231     self._max_queue_time = max_queue_time
232     self._pq = collections.deque() # packets
233
234   def _log(self, dflag, msg, **kwargs):
235     log_debug(dflag, self._desc+' pq: '+msg, **kwargs)
236
237   def append(self, packet):
238     self._log(DBG.QUEUE, 'append', d=packet)
239     self._pq.append((time.monotonic(), packet))
240
241   def nonempty(self):
242     self._log(DBG.QUEUE, 'nonempty ?')
243     while True:
244       try: (queuetime, packet) = self._pq[0]
245       except IndexError:
246         self._log(DBG.QUEUE, 'nonempty ? empty.')
247         return False
248
249       age = time.monotonic() - queuetime
250       if age > self._max_queue_time:
251         # strip old packets off the front
252         self._log(DBG.QUEUE, 'dropping (old)', d=packet)
253         self._pq.popleft()
254         continue
255
256       self._log(DBG.QUEUE, 'nonempty ? nonempty.')
257       return True
258
259   def process(self, sizequery, moredata, max_batch):
260     # sizequery() should return size of batch so far
261     # moredata(s) should add s to batch
262     self._log(DBG.QUEUE, 'process...')
263     while True:
264       try: (dummy, packet) = self._pq[0]
265       except IndexError:
266         self._log(DBG.QUEUE, 'process... empty')
267         break
268
269       self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)
270
271       encoded = slip.encode(packet)
272       sofar = sizequery()  
273
274       self._log(DBG.QUEUE_CTRL,
275                 'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
276                 d=encoded)
277
278       if sofar > 0:
279         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
280           self._log(DBG.QUEUE_CTRL, 'process... overflow')
281           break
282         moredata(slip.delimiter)
283
284       moredata(encoded)
285       self._pq.popleft()
286
287 #---------- error handling ----------
288
289 _crashing = False
290
291 def crash(err):
292   global _crashing
293   _crashing = True
294   print('========== CRASH ==========', err,
295         '===========================', file=sys.stderr)
296   try: reactor.stop()
297   except twisted.internet.error.ReactorNotRunning: pass
298
299 def crash_on_defer(defer):
300   defer.addErrback(lambda err: crash(err))
301
302 def crash_on_critical(event):
303   if event.get('log_level') >= LogLevel.critical:
304     crash(twisted.logger.formatEvent(event))
305
306 #---------- config processing ----------
307
308 def process_cfg_common_always():
309   global mtu
310   c.mtu = cfg.get('virtual','mtu')
311
312 def process_cfg_ipif(section, varmap):
313   for d, s in varmap:
314     try: v = getattr(c, s)
315     except AttributeError: continue
316     setattr(c, d, v)
317
318   #print(repr(c))
319
320   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
321
322 def process_cfg_network():
323   c.network = ipnetwork(cfg.get('virtual','network'))
324   if c.network.num_addresses < 3 + 2:
325     raise ValueError('network needs at least 2^3 addresses')
326
327 def process_cfg_server():
328   try:
329     c.server = cfg.get('virtual','server')
330   except NoOptionError:
331     process_cfg_network()
332     c.server = next(c.network.hosts())
333
334 class ServerAddr():
335   def __init__(self, port, addrspec):
336     self.port = port
337     # also self.addr
338     try:
339       self.addr = ipaddress.IPv4Address(addrspec)
340       self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
341       self._inurl = b'%s'
342     except AddressValueError:
343       self.addr = ipaddress.IPv6Address(addrspec)
344       self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
345       self._inurl = b'[%s]'
346   def make_endpoint(self):
347     return self._endpointfactory(reactor, self.port, self.addr)
348   def url(self):
349     url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
350     if self.port != 80: url += b':%d' % self.port
351     url += b'/'
352     return url
353     
354 def process_cfg_saddrs():
355   try: port = cfg.getint('server','port')
356   except NoOptionError: port = 80
357
358   c.saddrs = [ ]
359   for addrspec in cfg.get('server','addrs').split():
360     sa = ServerAddr(port, addrspec)
361     c.saddrs.append(sa)
362
363 def process_cfg_clients(constructor):
364   c.clients = [ ]
365   for cs in cfg.sections():
366     if not (':' in cs or '.' in cs): continue
367     ci = ipaddr(cs)
368     pw = cfg.get(cs, 'password')
369     pw = pw.encode('utf-8')
370     constructor(ci,cs,pw)
371
372 #---------- startup ----------
373
374 def common_startup():
375   log_formatter = twisted.logger.formatEventAsClassicLogText
376   stdout_obs = twisted.logger.FileLogObserver(sys.stdout, log_formatter)
377   stderr_obs = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
378   pred = twisted.logger.LogLevelFilterPredicate(LogLevel.error)
379   log_observer = twisted.logger.FilteringLogObserver(
380     stderr_obs, [pred], stdout_obs
381   )
382   twisted.logger.globalLogBeginner.beginLoggingTo(
383     [ log_observer, crash_on_critical ]
384     )
385
386   optparser.add_option('-c', '--config', dest='configfile',
387                        default='/etc/hippotat/config')
388   (opts, args) = optparser.parse_args()
389   if len(args): optparser.error('no non-option arguments please')
390
391   re = regexp.compile('#.*')
392   cfg.read_string(re.sub('', defcfg))
393   cfg.read(opts.configfile)
394
395 def common_run():
396   log_debug(DBG.INIT, 'entering reactor')
397   if not _crashing: reactor.run()
398   print('CRASHED (end)', file=sys.stderr)