chiark / gitweb /
064909100fbe409fe30ee67b772c902140060443
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 import twisted
9 from twisted.internet import reactor
10 import twisted.internet.endpoints
11 import twisted.logger
12 from twisted.logger import LogLevel
13 import twisted.python.constants
14 from twisted.python.constants import NamedConstant
15
16 import ipaddress
17 from ipaddress import AddressValueError
18
19 from optparse import OptionParser
20 from configparser import ConfigParser
21 from configparser import NoOptionError
22
23 import collections
24 import time
25 import codecs
26
27 import re as regexp
28
29 import hippotat.slip as slip
30
31 class DBG(twisted.python.constants.Names):
32   ROUTE = NamedConstant()
33   DROP = NamedConstant()
34   FLOW = NamedConstant()
35   HTTP = NamedConstant()
36   HTTP_CTRL = NamedConstant()
37   INIT = NamedConstant()
38   QUEUE = NamedConstant()
39   QUEUE_CTRL = NamedConstant()
40
41 _hex_codec = codecs.getencoder('hex_codec')
42
43 log = twisted.logger.Logger()
44
45 def log_debug(dflag, msg, idof=None, d=None):
46   if idof is not None:
47     msg = '[%d] %s' % (id(idof), msg)
48   if d is not None:
49     d = d[0:64]
50     d = _hex_codec(d)[0].decode('ascii')
51     msg += ' ' + d
52   log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg)
53
54 defcfg = '''
55 [DEFAULT]
56 #[<client>] overrides
57 max_batch_down = 65536           # used by server, subject to [limits]
58 max_queue_time = 10              # used by server, subject to [limits]
59 max_request_time = 54            # used by server, subject to [limits]
60 target_requests_outstanding = 3  # must match; subject to [limits] on server
61 max_requests_outstanding = 4     # used by client
62 max_batch_up = 4000              # used by client
63 http_timeout = 30                # used by client
64 http_retry = 5                   # used by client
65
66 #[server] or [<client>] overrides
67 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
68 # extra interpolations:  %(local)s        %(peer)s          %(rnet)s
69 #  obtained   on server  [virtual]server  [virtual]relay    [virtual]network
70 #      from   on client  <client>         [virtual]server   [virtual]routes
71
72 [virtual]
73 mtu = 1500
74 routes = ''
75 # network = <prefix>/<len>  # mandatory for server
76 # server  = <ipaddr>   # used by both, default is computed from `network'
77 # relay   = <ipaddr>   # used by server, default from `network' and `server'
78 #  default server is first host in network
79 #  default relay is first host which is not server
80
81 [server]
82 # addrs = 127.0.0.1 ::1    # mandatory for server
83 port = 80                  # used by server
84 # url              # used by client; default from first `addrs' and `port'
85
86 # [<client-ip4-or-ipv6-address>]
87 # password = <password>    # used by both, must match
88
89 [limits]
90 max_batch_down = 262144           # used by server
91 max_queue_time = 121              # used by server
92 max_request_time = 121            # used by server
93 target_requests_outstanding = 10  # used by server
94 '''
95
96 # these need to be defined here so that they can be imported by import *
97 cfg = ConfigParser()
98 optparser = OptionParser()
99
100 _mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')
101 def mime_translate(s):
102   # SLIP-encoded packets cannot contain ESC ESC.
103   # Swap `-' and ESC.  The result cannot contain `--'
104   return s.translate(_mimetrans)
105
106 class ConfigResults:
107   def __init__(self, d = { }):
108     self.__dict__ = d
109   def __repr__(self):
110     return 'ConfigResults('+repr(self.__dict__)+')'
111
112 c = ConfigResults()
113
114 def log_discard(packet, iface, saddr, daddr, why):
115   log_debug(DBG.DROP,
116             'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why),
117             d=packet)
118
119 #---------- packet parsing ----------
120
121 def packet_addrs(packet):
122   version = packet[0] >> 4
123   if version == 4:
124     addrlen = 4
125     saddroff = 3*4
126     factory = ipaddress.IPv4Address
127   elif version == 6:
128     addrlen = 16
129     saddroff = 2*4
130     factory = ipaddress.IPv6Address
131   else:
132     raise ValueError('unsupported IP version %d' % version)
133   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
134   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
135   return (saddr, daddr)
136
137 #---------- address handling ----------
138
139 def ipaddr(input):
140   try:
141     r = ipaddress.IPv4Address(input)
142   except AddressValueError:
143     r = ipaddress.IPv6Address(input)
144   return r
145
146 def ipnetwork(input):
147   try:
148     r = ipaddress.IPv4Network(input)
149   except NetworkValueError:
150     r = ipaddress.IPv6Network(input)
151   return r
152
153 #---------- ipif (SLIP) subprocess ----------
154
155 class SlipStreamDecoder():
156   def __init__(self, on_packet):
157     # we will call packet(<packet>)
158     self._buffer = b''
159     self._on_packet = on_packet
160
161   def inputdata(self, data):
162     #print('SLIP-GOT ', repr(data))
163     self._buffer += data
164     packets = slip.decode(self._buffer)
165     self._buffer = packets.pop()
166     for packet in packets:
167       self._maybe_packet(packet)
168
169   def _maybe_packet(self, packet):
170       if len(packet):
171         self._on_packet(packet)
172
173   def flush(self):
174     self._maybe_packet(self._buffer)
175     self._buffer = b''
176
177 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
178   def __init__(self, router):
179     self._router = router
180     self._decoder = SlipStreamDecoder(self.slip_on_packet)
181   def connectionMade(self): pass
182   def outReceived(self, data):
183     self._decoder.inputdata(data)
184   def slip_on_packet(self, packet):
185     (saddr, daddr) = packet_addrs(packet)
186     if saddr.is_link_local or daddr.is_link_local:
187       log_discard(packet, 'ipif', saddr, daddr, 'link-local')
188       return
189     self._router(packet, saddr, daddr)
190   def processEnded(self, status):
191     status.raiseException()
192
193 def start_ipif(command, router):
194   global ipif
195   ipif = _IpifProcessProtocol(router)
196   reactor.spawnProcess(ipif,
197                        '/bin/sh',['sh','-xc', command],
198                        childFDs={0:'w', 1:'r', 2:2},
199                        env=None)
200
201 def queue_inbound(packet):
202   ipif.transport.write(slip.delimiter)
203   ipif.transport.write(slip.encode(packet))
204   ipif.transport.write(slip.delimiter)
205
206 #---------- packet queue ----------
207
208 class PacketQueue():
209   def __init__(self, desc, max_queue_time):
210     self._desc = desc
211     self._max_queue_time = max_queue_time
212     self._pq = collections.deque() # packets
213
214   def _log(self, dflag, msg, **kwargs):
215     log_debug(dflag, self._desc+' pq: '+msg, **kwargs)
216
217   def append(self, packet):
218     self._log(DBG.QUEUE, 'append', d=packet)
219     self._pq.append((time.monotonic(), packet))
220
221   def nonempty(self):
222     self._log(DBG.QUEUE, 'nonempty ?')
223     while True:
224       try: (queuetime, packet) = self._pq[0]
225       except IndexError:
226         self._log(DBG.QUEUE, 'nonempty ? empty.')
227         return False
228
229       age = time.monotonic() - queuetime
230       if age > self._max_queue_time:
231         # strip old packets off the front
232         self._log(DBG.QUEUE, 'dropping (old)', d=packet)
233         self._pq.popleft()
234         continue
235
236       self._log(DBG.QUEUE, 'nonempty ? nonempty.')
237       return True
238
239   def process(self, sizequery, moredata, max_batch):
240     # sizequery() should return size of batch so far
241     # moredata(s) should add s to batch
242     self._log(DBG.QUEUE, 'process...')
243     while True:
244       try: (dummy, packet) = self._pq[0]
245       except IndexError:
246         self._log(DBG.QUEUE, 'process... empty')
247         break
248
249       self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)
250
251       encoded = slip.encode(packet)
252       sofar = sizequery()  
253
254       self._log(DBG.QUEUE_CTRL,
255                 'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
256                 d=encoded)
257
258       if sofar > 0:
259         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
260           self._log(DBG.QUEUE_CTRL, 'process... overflow')
261           break
262         moredata(slip.delimiter)
263
264       moredata(encoded)
265       self._pq.popleft()
266
267 #---------- error handling ----------
268
269 _crashing = False
270
271 def crash(err):
272   global _crashing
273   _crashing = True
274   print('CRASH ', err, file=sys.stderr)
275   try: reactor.stop()
276   except twisted.internet.error.ReactorNotRunning: pass
277
278 def crash_on_defer(defer):
279   defer.addErrback(lambda err: crash(err))
280
281 def crash_on_critical(event):
282   if event.get('log_level') >= LogLevel.critical:
283     crash(twisted.logger.formatEvent(event))
284
285 #---------- config processing ----------
286
287 def process_cfg_common_always():
288   global mtu
289   c.mtu = cfg.get('virtual','mtu')
290
291 def process_cfg_ipif(section, varmap):
292   for d, s in varmap:
293     try: v = getattr(c, s)
294     except AttributeError: continue
295     setattr(c, d, v)
296
297   #print(repr(c))
298
299   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
300
301 def process_cfg_network():
302   c.network = ipnetwork(cfg.get('virtual','network'))
303   if c.network.num_addresses < 3 + 2:
304     raise ValueError('network needs at least 2^3 addresses')
305
306 def process_cfg_server():
307   try:
308     c.server = cfg.get('virtual','server')
309   except NoOptionError:
310     process_cfg_network()
311     c.server = next(c.network.hosts())
312
313 class ServerAddr():
314   def __init__(self, port, addrspec):
315     self.port = port
316     # also self.addr
317     try:
318       self.addr = ipaddress.IPv4Address(addrspec)
319       self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
320       self._inurl = b'%s'
321     except AddressValueError:
322       self.addr = ipaddress.IPv6Address(addrspec)
323       self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
324       self._inurl = b'[%s]'
325   def make_endpoint(self):
326     return self._endpointfactory(reactor, self.port, self.addr)
327   def url(self):
328     url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
329     if self.port != 80: url += b':%d' % self.port
330     url += b'/'
331     return url
332     
333 def process_cfg_saddrs():
334   try: port = cfg.getint('server','port')
335   except NoOptionError: port = 80
336
337   c.saddrs = [ ]
338   for addrspec in cfg.get('server','addrs').split():
339     sa = ServerAddr(port, addrspec)
340     c.saddrs.append(sa)
341
342 def process_cfg_clients(constructor):
343   c.clients = [ ]
344   for cs in cfg.sections():
345     if not (':' in cs or '.' in cs): continue
346     ci = ipaddr(cs)
347     pw = cfg.get(cs, 'password')
348     pw = pw.encode('utf-8')
349     constructor(ci,cs,pw)
350
351 #---------- startup ----------
352
353 def common_startup():
354   log_formatter = twisted.logger.formatEventAsClassicLogText
355   log_observer = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
356   twisted.logger.globalLogBeginner.beginLoggingTo(
357     [ log_observer, crash_on_critical ]
358     )
359
360   optparser.add_option('-c', '--config', dest='configfile',
361                        default='/etc/hippotat/config')
362   (opts, args) = optparser.parse_args()
363   if len(args): optparser.error('no non-option arguments please')
364
365   re = regexp.compile('#.*')
366   cfg.read_string(re.sub('', defcfg))
367   cfg.read(opts.configfile)
368
369 def common_run():
370   log_debug(DBG.INIT, 'entering reactor')
371   if not _crashing: reactor.run()
372   print('CRASHED (end)', file=sys.stderr)