chiark / gitweb /
c1a6996bb862cb2a50cf08980ca3b35a5f679142
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 import twisted
9 from twisted.internet import reactor
10 import twisted.internet.endpoints
11 import twisted.logger
12 from twisted.logger import LogLevel
13 import twisted.python.constants
14 from twisted.python.constants import NamedConstant
15
16 import ipaddress
17 from ipaddress import AddressValueError
18
19 from optparse import OptionParser
20 from configparser import ConfigParser
21 from configparser import NoOptionError
22
23 import collections
24 import time
25 import codecs
26 import traceback
27
28 import re as regexp
29
30 import hippotat.slip as slip
31
32 class DBG(twisted.python.constants.Names):
33   ROUTE = NamedConstant()
34   DROP = NamedConstant()
35   FLOW = NamedConstant()
36   HTTP = NamedConstant()
37   HTTP_CTRL = NamedConstant()
38   INIT = NamedConstant()
39   QUEUE = NamedConstant()
40   QUEUE_CTRL = NamedConstant()
41   HTTP_FULL = NamedConstant()
42   SLIP_FULL = NamedConstant()
43
44 _hex_codec = codecs.getencoder('hex_codec')
45
46 log = twisted.logger.Logger()
47
48 def log_debug(dflag, msg, idof=None, d=None):
49   #print('---------------->',repr((dflag, msg, idof, d)), file=sys.stderr)
50   if idof is not None:
51     msg = '[%d] %s' % (id(idof), msg)
52   if d is not None:
53     d = d[0:64]
54     d = _hex_codec(d)[0].decode('ascii')
55     msg += ' ' + d
56   log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg)
57
58 defcfg = '''
59 [DEFAULT]
60 #[<client>] overrides
61 max_batch_down = 65536           # used by server, subject to [limits]
62 max_queue_time = 10              # used by server, subject to [limits]
63 max_request_time = 54            # used by server, subject to [limits]
64 target_requests_outstanding = 3  # must match; subject to [limits] on server
65 max_requests_outstanding = 4     # used by client
66 max_batch_up = 4000              # used by client
67 http_timeout = 30                # used by client
68 http_retry = 5                   # used by client
69
70 #[server] or [<client>] overrides
71 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
72 # extra interpolations:  %(local)s        %(peer)s          %(rnet)s
73 #  obtained   on server  [virtual]server  [virtual]relay    [virtual]network
74 #      from   on client  <client>         [virtual]server   [virtual]routes
75
76 [virtual]
77 mtu = 1500
78 routes = ''
79 # network = <prefix>/<len>  # mandatory for server
80 # server  = <ipaddr>   # used by both, default is computed from `network'
81 # relay   = <ipaddr>   # used by server, default from `network' and `server'
82 #  default server is first host in network
83 #  default relay is first host which is not server
84
85 [server]
86 # addrs = 127.0.0.1 ::1    # mandatory for server
87 port = 80                  # used by server
88 # url              # used by client; default from first `addrs' and `port'
89
90 # [<client-ip4-or-ipv6-address>]
91 # password = <password>    # used by both, must match
92
93 [limits]
94 max_batch_down = 262144           # used by server
95 max_queue_time = 121              # used by server
96 max_request_time = 121            # used by server
97 target_requests_outstanding = 10  # used by server
98 '''
99
100 # these need to be defined here so that they can be imported by import *
101 cfg = ConfigParser()
102 optparser = OptionParser()
103
104 _mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')
105 def mime_translate(s):
106   # SLIP-encoded packets cannot contain ESC ESC.
107   # Swap `-' and ESC.  The result cannot contain `--'
108   return s.translate(_mimetrans)
109
110 class ConfigResults:
111   def __init__(self, d = { }):
112     self.__dict__ = d
113   def __repr__(self):
114     return 'ConfigResults('+repr(self.__dict__)+')'
115
116 c = ConfigResults()
117
118 def log_discard(packet, iface, saddr, daddr, why):
119   log_debug(DBG.DROP,
120             'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why),
121             d=packet)
122
123 #---------- packet parsing ----------
124
125 def packet_addrs(packet):
126   version = packet[0] >> 4
127   if version == 4:
128     addrlen = 4
129     saddroff = 3*4
130     factory = ipaddress.IPv4Address
131   elif version == 6:
132     addrlen = 16
133     saddroff = 2*4
134     factory = ipaddress.IPv6Address
135   else:
136     raise ValueError('unsupported IP version %d' % version)
137   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
138   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
139   return (saddr, daddr)
140
141 #---------- address handling ----------
142
143 def ipaddr(input):
144   try:
145     r = ipaddress.IPv4Address(input)
146   except AddressValueError:
147     r = ipaddress.IPv6Address(input)
148   return r
149
150 def ipnetwork(input):
151   try:
152     r = ipaddress.IPv4Network(input)
153   except NetworkValueError:
154     r = ipaddress.IPv6Network(input)
155   return r
156
157 #---------- ipif (SLIP) subprocess ----------
158
159 class SlipStreamDecoder():
160   def __init__(self, desc, on_packet):
161     self._buffer = b''
162     self._on_packet = on_packet
163     self._desc = desc
164     self._log('__init__')
165
166   def _log(self, msg, **kwargs):
167     log_debug(DBG.SLIP_FULL, 'slip %s: %s' % (self._desc, msg), **kwargs)
168
169   def inputdata(self, data):
170     self._log('inputdata', d=data)
171     data = self._buffer + data
172     self._buffer = b''
173     packets = slip.decode(data)
174     self._buffer = packets.pop()
175     for packet in packets:
176       self._maybe_packet(packet)
177     self._log('bufremain', d=self._buffer)
178
179   def _maybe_packet(self, packet):
180     self._log('maybepacket', d=packet)
181     if len(packet):
182       self._on_packet(packet)
183
184   def flush(self):
185     self._log('flush')
186     self._maybe_packet(self._buffer)
187     self._buffer = b''
188
189 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
190   def __init__(self, router):
191     self._router = router
192     self._decoder = SlipStreamDecoder('ipif', self.slip_on_packet)
193   def connectionMade(self): pass
194   def outReceived(self, data):
195     self._decoder.inputdata(data)
196   def slip_on_packet(self, packet):
197     (saddr, daddr) = packet_addrs(packet)
198     if saddr.is_link_local or daddr.is_link_local:
199       log_discard(packet, 'ipif', saddr, daddr, 'link-local')
200       return
201     self._router(packet, saddr, daddr)
202   def processEnded(self, status):
203     status.raiseException()
204
205 def start_ipif(command, router):
206   global ipif
207   ipif = _IpifProcessProtocol(router)
208   reactor.spawnProcess(ipif,
209                        '/bin/sh',['sh','-xc', command],
210                        childFDs={0:'w', 1:'r', 2:2},
211                        env=None)
212
213 def queue_inbound(packet):
214   log_debug(DBG.FLOW, "queue_inbound", d=packet)
215   ipif.transport.write(slip.delimiter)
216   ipif.transport.write(slip.encode(packet))
217   ipif.transport.write(slip.delimiter)
218
219 #---------- packet queue ----------
220
221 class PacketQueue():
222   def __init__(self, desc, max_queue_time):
223     self._desc = desc
224     assert(desc + '')
225     self._max_queue_time = max_queue_time
226     self._pq = collections.deque() # packets
227
228   def _log(self, dflag, msg, **kwargs):
229     log_debug(dflag, self._desc+' pq: '+msg, **kwargs)
230
231   def append(self, packet):
232     self._log(DBG.QUEUE, 'append', d=packet)
233     self._pq.append((time.monotonic(), packet))
234
235   def nonempty(self):
236     self._log(DBG.QUEUE, 'nonempty ?')
237     while True:
238       try: (queuetime, packet) = self._pq[0]
239       except IndexError:
240         self._log(DBG.QUEUE, 'nonempty ? empty.')
241         return False
242
243       age = time.monotonic() - queuetime
244       if age > self._max_queue_time:
245         # strip old packets off the front
246         self._log(DBG.QUEUE, 'dropping (old)', d=packet)
247         self._pq.popleft()
248         continue
249
250       self._log(DBG.QUEUE, 'nonempty ? nonempty.')
251       return True
252
253   def process(self, sizequery, moredata, max_batch):
254     # sizequery() should return size of batch so far
255     # moredata(s) should add s to batch
256     self._log(DBG.QUEUE, 'process...')
257     while True:
258       try: (dummy, packet) = self._pq[0]
259       except IndexError:
260         self._log(DBG.QUEUE, 'process... empty')
261         break
262
263       self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)
264
265       encoded = slip.encode(packet)
266       sofar = sizequery()  
267
268       self._log(DBG.QUEUE_CTRL,
269                 'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
270                 d=encoded)
271
272       if sofar > 0:
273         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
274           self._log(DBG.QUEUE_CTRL, 'process... overflow')
275           break
276         moredata(slip.delimiter)
277
278       moredata(encoded)
279       self._pq.popleft()
280
281 #---------- error handling ----------
282
283 _crashing = False
284
285 def crash(err):
286   global _crashing
287   _crashing = True
288   print('CRASH ', err, file=sys.stderr)
289   try: reactor.stop()
290   except twisted.internet.error.ReactorNotRunning: pass
291
292 def crash_on_defer(defer):
293   defer.addErrback(lambda err: crash(err))
294
295 def crash_on_critical(event):
296   if event.get('log_level') >= LogLevel.critical:
297     crash(twisted.logger.formatEvent(event))
298
299 #---------- config processing ----------
300
301 def process_cfg_common_always():
302   global mtu
303   c.mtu = cfg.get('virtual','mtu')
304
305 def process_cfg_ipif(section, varmap):
306   for d, s in varmap:
307     try: v = getattr(c, s)
308     except AttributeError: continue
309     setattr(c, d, v)
310
311   #print(repr(c))
312
313   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
314
315 def process_cfg_network():
316   c.network = ipnetwork(cfg.get('virtual','network'))
317   if c.network.num_addresses < 3 + 2:
318     raise ValueError('network needs at least 2^3 addresses')
319
320 def process_cfg_server():
321   try:
322     c.server = cfg.get('virtual','server')
323   except NoOptionError:
324     process_cfg_network()
325     c.server = next(c.network.hosts())
326
327 class ServerAddr():
328   def __init__(self, port, addrspec):
329     self.port = port
330     # also self.addr
331     try:
332       self.addr = ipaddress.IPv4Address(addrspec)
333       self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
334       self._inurl = b'%s'
335     except AddressValueError:
336       self.addr = ipaddress.IPv6Address(addrspec)
337       self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
338       self._inurl = b'[%s]'
339   def make_endpoint(self):
340     return self._endpointfactory(reactor, self.port, self.addr)
341   def url(self):
342     url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
343     if self.port != 80: url += b':%d' % self.port
344     url += b'/'
345     return url
346     
347 def process_cfg_saddrs():
348   try: port = cfg.getint('server','port')
349   except NoOptionError: port = 80
350
351   c.saddrs = [ ]
352   for addrspec in cfg.get('server','addrs').split():
353     sa = ServerAddr(port, addrspec)
354     c.saddrs.append(sa)
355
356 def process_cfg_clients(constructor):
357   c.clients = [ ]
358   for cs in cfg.sections():
359     if not (':' in cs or '.' in cs): continue
360     ci = ipaddr(cs)
361     pw = cfg.get(cs, 'password')
362     pw = pw.encode('utf-8')
363     constructor(ci,cs,pw)
364
365 #---------- startup ----------
366
367 def common_startup():
368   log_formatter = twisted.logger.formatEventAsClassicLogText
369   log_observer = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
370   twisted.logger.globalLogBeginner.beginLoggingTo(
371     [ log_observer, crash_on_critical ]
372     )
373
374   optparser.add_option('-c', '--config', dest='configfile',
375                        default='/etc/hippotat/config')
376   (opts, args) = optparser.parse_args()
377   if len(args): optparser.error('no non-option arguments please')
378
379   re = regexp.compile('#.*')
380   cfg.read_string(re.sub('', defcfg))
381   cfg.read(opts.configfile)
382
383 def common_run():
384   log_debug(DBG.INIT, 'entering reactor')
385   if not _crashing: reactor.run()
386   print('CRASHED (end)', file=sys.stderr)