chiark / gitweb /
wip
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 import twisted
9 from twisted.internet import reactor
10 import twisted.internet.endpoints
11 import twisted.logger
12 from twisted.logger import LogLevel
13 import twisted.python.constants
14 from twisted.python.constants import NamedConstant
15
16 import ipaddress
17 from ipaddress import AddressValueError
18
19 from optparse import OptionParser
20 from configparser import ConfigParser
21 from configparser import NoOptionError
22
23 import collections
24 import time
25 import codecs
26
27 import re as regexp
28
29 import hippotat.slip as slip
30
31 class DBG(twisted.python.constants.Names):
32   ROUTE = NamedConstant()
33   DROP = NamedConstant()
34   FLOW = NamedConstant()
35   HTTP = NamedConstant()
36   HTTP_CTRL = NamedConstant()
37   INIT = NamedConstant()
38   QUEUE = NamedConstant()
39   QUEUE_CTRL = NamedConstant()
40   HTTP_FULL = NamedConstant()
41
42 _hex_codec = codecs.getencoder('hex_codec')
43
44 log = twisted.logger.Logger()
45
46 def log_debug(dflag, msg, idof=None, d=None):
47   #print('---------------->',repr((dflag, msg, idof, d)), file=sys.stderr)
48   if idof is not None:
49     msg = '[%d] %s' % (id(idof), msg)
50   if d is not None:
51     #d = d[0:64]
52     d = _hex_codec(d)[0].decode('ascii')
53     msg += ' ' + d
54   log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg)
55
56 defcfg = '''
57 [DEFAULT]
58 #[<client>] overrides
59 max_batch_down = 65536           # used by server, subject to [limits]
60 max_queue_time = 10              # used by server, subject to [limits]
61 max_request_time = 54            # used by server, subject to [limits]
62 target_requests_outstanding = 3  # must match; subject to [limits] on server
63 max_requests_outstanding = 4     # used by client
64 max_batch_up = 4000              # used by client
65 http_timeout = 30                # used by client
66 http_retry = 5                   # used by client
67
68 #[server] or [<client>] overrides
69 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
70 # extra interpolations:  %(local)s        %(peer)s          %(rnet)s
71 #  obtained   on server  [virtual]server  [virtual]relay    [virtual]network
72 #      from   on client  <client>         [virtual]server   [virtual]routes
73
74 [virtual]
75 mtu = 1500
76 routes = ''
77 # network = <prefix>/<len>  # mandatory for server
78 # server  = <ipaddr>   # used by both, default is computed from `network'
79 # relay   = <ipaddr>   # used by server, default from `network' and `server'
80 #  default server is first host in network
81 #  default relay is first host which is not server
82
83 [server]
84 # addrs = 127.0.0.1 ::1    # mandatory for server
85 port = 80                  # used by server
86 # url              # used by client; default from first `addrs' and `port'
87
88 # [<client-ip4-or-ipv6-address>]
89 # password = <password>    # used by both, must match
90
91 [limits]
92 max_batch_down = 262144           # used by server
93 max_queue_time = 121              # used by server
94 max_request_time = 121            # used by server
95 target_requests_outstanding = 10  # used by server
96 '''
97
98 # these need to be defined here so that they can be imported by import *
99 cfg = ConfigParser()
100 optparser = OptionParser()
101
102 _mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')
103 def mime_translate(s):
104   # SLIP-encoded packets cannot contain ESC ESC.
105   # Swap `-' and ESC.  The result cannot contain `--'
106   return s.translate(_mimetrans)
107
108 class ConfigResults:
109   def __init__(self, d = { }):
110     self.__dict__ = d
111   def __repr__(self):
112     return 'ConfigResults('+repr(self.__dict__)+')'
113
114 c = ConfigResults()
115
116 def log_discard(packet, iface, saddr, daddr, why):
117   log_debug(DBG.DROP,
118             'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why),
119             d=packet)
120
121 #---------- packet parsing ----------
122
123 def packet_addrs(packet):
124   version = packet[0] >> 4
125   if version == 4:
126     addrlen = 4
127     saddroff = 3*4
128     factory = ipaddress.IPv4Address
129   elif version == 6:
130     addrlen = 16
131     saddroff = 2*4
132     factory = ipaddress.IPv6Address
133   else:
134     raise ValueError('unsupported IP version %d' % version)
135   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
136   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
137   return (saddr, daddr)
138
139 #---------- address handling ----------
140
141 def ipaddr(input):
142   try:
143     r = ipaddress.IPv4Address(input)
144   except AddressValueError:
145     r = ipaddress.IPv6Address(input)
146   return r
147
148 def ipnetwork(input):
149   try:
150     r = ipaddress.IPv4Network(input)
151   except NetworkValueError:
152     r = ipaddress.IPv6Network(input)
153   return r
154
155 #---------- ipif (SLIP) subprocess ----------
156
157 class SlipStreamDecoder():
158   def __init__(self, on_packet):
159     # we will call packet(<packet>)
160     self._buffer = b''
161     self._on_packet = on_packet
162
163   def inputdata(self, data):
164     #print('SLIP-GOT ', repr(data))
165     self._buffer += data
166     packets = slip.decode(self._buffer)
167     self._buffer = packets.pop()
168     for packet in packets:
169       self._maybe_packet(packet)
170
171   def _maybe_packet(self, packet):
172       if len(packet):
173         self._on_packet(packet)
174
175   def flush(self):
176     self._maybe_packet(self._buffer)
177     self._buffer = b''
178
179 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
180   def __init__(self, router):
181     self._router = router
182     self._decoder = SlipStreamDecoder(self.slip_on_packet)
183   def connectionMade(self): pass
184   def outReceived(self, data):
185     self._decoder.inputdata(data)
186   def slip_on_packet(self, packet):
187     (saddr, daddr) = packet_addrs(packet)
188     if saddr.is_link_local or daddr.is_link_local:
189       log_discard(packet, 'ipif', saddr, daddr, 'link-local')
190       return
191     self._router(packet, saddr, daddr)
192   def processEnded(self, status):
193     status.raiseException()
194
195 def start_ipif(command, router):
196   global ipif
197   ipif = _IpifProcessProtocol(router)
198   reactor.spawnProcess(ipif,
199                        '/bin/sh',['sh','-xc', command],
200                        childFDs={0:'w', 1:'r', 2:2},
201                        env=None)
202
203 def queue_inbound(packet):
204   ipif.transport.write(slip.delimiter)
205   ipif.transport.write(slip.encode(packet))
206   ipif.transport.write(slip.delimiter)
207
208 #---------- packet queue ----------
209
210 class PacketQueue():
211   def __init__(self, desc, max_queue_time):
212     self._desc = desc
213     assert(desc + '')
214     self._max_queue_time = max_queue_time
215     self._pq = collections.deque() # packets
216
217   def _log(self, dflag, msg, **kwargs):
218     log_debug(dflag, self._desc+' pq: '+msg, **kwargs)
219
220   def append(self, packet):
221     self._log(DBG.QUEUE, 'append', d=packet)
222     self._pq.append((time.monotonic(), packet))
223
224   def nonempty(self):
225     self._log(DBG.QUEUE, 'nonempty ?')
226     while True:
227       try: (queuetime, packet) = self._pq[0]
228       except IndexError:
229         self._log(DBG.QUEUE, 'nonempty ? empty.')
230         return False
231
232       age = time.monotonic() - queuetime
233       if age > self._max_queue_time:
234         # strip old packets off the front
235         self._log(DBG.QUEUE, 'dropping (old)', d=packet)
236         self._pq.popleft()
237         continue
238
239       self._log(DBG.QUEUE, 'nonempty ? nonempty.')
240       return True
241
242   def process(self, sizequery, moredata, max_batch):
243     # sizequery() should return size of batch so far
244     # moredata(s) should add s to batch
245     self._log(DBG.QUEUE, 'process...')
246     while True:
247       try: (dummy, packet) = self._pq[0]
248       except IndexError:
249         self._log(DBG.QUEUE, 'process... empty')
250         break
251
252       self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)
253
254       encoded = slip.encode(packet)
255       sofar = sizequery()  
256
257       self._log(DBG.QUEUE_CTRL,
258                 'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
259                 d=encoded)
260
261       if sofar > 0:
262         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
263           self._log(DBG.QUEUE_CTRL, 'process... overflow')
264           break
265         moredata(slip.delimiter)
266
267       moredata(encoded)
268       self._pq.popleft()
269
270 #---------- error handling ----------
271
272 _crashing = False
273
274 def crash(err):
275   global _crashing
276   _crashing = True
277   print('CRASH ', err, file=sys.stderr)
278   try: reactor.stop()
279   except twisted.internet.error.ReactorNotRunning: pass
280
281 def crash_on_defer(defer):
282   defer.addErrback(lambda err: crash(err))
283
284 def crash_on_critical(event):
285   if event.get('log_level') >= LogLevel.critical:
286     crash(twisted.logger.formatEvent(event))
287
288 #---------- config processing ----------
289
290 def process_cfg_common_always():
291   global mtu
292   c.mtu = cfg.get('virtual','mtu')
293
294 def process_cfg_ipif(section, varmap):
295   for d, s in varmap:
296     try: v = getattr(c, s)
297     except AttributeError: continue
298     setattr(c, d, v)
299
300   #print(repr(c))
301
302   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
303
304 def process_cfg_network():
305   c.network = ipnetwork(cfg.get('virtual','network'))
306   if c.network.num_addresses < 3 + 2:
307     raise ValueError('network needs at least 2^3 addresses')
308
309 def process_cfg_server():
310   try:
311     c.server = cfg.get('virtual','server')
312   except NoOptionError:
313     process_cfg_network()
314     c.server = next(c.network.hosts())
315
316 class ServerAddr():
317   def __init__(self, port, addrspec):
318     self.port = port
319     # also self.addr
320     try:
321       self.addr = ipaddress.IPv4Address(addrspec)
322       self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
323       self._inurl = b'%s'
324     except AddressValueError:
325       self.addr = ipaddress.IPv6Address(addrspec)
326       self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
327       self._inurl = b'[%s]'
328   def make_endpoint(self):
329     return self._endpointfactory(reactor, self.port, self.addr)
330   def url(self):
331     url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
332     if self.port != 80: url += b':%d' % self.port
333     url += b'/'
334     return url
335     
336 def process_cfg_saddrs():
337   try: port = cfg.getint('server','port')
338   except NoOptionError: port = 80
339
340   c.saddrs = [ ]
341   for addrspec in cfg.get('server','addrs').split():
342     sa = ServerAddr(port, addrspec)
343     c.saddrs.append(sa)
344
345 def process_cfg_clients(constructor):
346   c.clients = [ ]
347   for cs in cfg.sections():
348     if not (':' in cs or '.' in cs): continue
349     ci = ipaddr(cs)
350     pw = cfg.get(cs, 'password')
351     pw = pw.encode('utf-8')
352     constructor(ci,cs,pw)
353
354 #---------- startup ----------
355
356 def common_startup():
357   log_formatter = twisted.logger.formatEventAsClassicLogText
358   log_observer = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
359   twisted.logger.globalLogBeginner.beginLoggingTo(
360     [ log_observer, crash_on_critical ]
361     )
362
363   optparser.add_option('-c', '--config', dest='configfile',
364                        default='/etc/hippotat/config')
365   (opts, args) = optparser.parse_args()
366   if len(args): optparser.error('no non-option arguments please')
367
368   re = regexp.compile('#.*')
369   cfg.read_string(re.sub('', defcfg))
370   cfg.read(opts.configfile)
371
372 def common_run():
373   log_debug(DBG.INIT, 'entering reactor')
374   if not _crashing: reactor.run()
375   print('CRASHED (end)', file=sys.stderr)