chiark / gitweb /
798f5d2bbce4f3f31a8eef9bbec3a83dcc7e250e
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 import twisted
9 from twisted.internet import reactor
10 import twisted.internet.endpoints
11 import twisted.logger
12 from twisted.logger import LogLevel
13 import twisted.python.constants
14 from twisted.python.constants import NamedConstant
15
16 import ipaddress
17 from ipaddress import AddressValueError
18
19 from optparse import OptionParser
20 from configparser import ConfigParser
21 from configparser import NoOptionError
22
23 import collections
24 import time
25 import codecs
26
27 import re as regexp
28
29 import hippotat.slip as slip
30
31 class DBG(twisted.python.constants.Names):
32   ROUTE = NamedConstant()
33   DROP = NamedConstant()
34   FLOW = NamedConstant()
35   HTTP = NamedConstant()
36   HTTP_CTRL = NamedConstant()
37   INIT = NamedConstant()
38   QUEUE = NamedConstant()
39   QUEUE_CTRL = NamedConstant()
40   HTTP_FULL = NamedConstant()
41
42 _hex_codec = codecs.getencoder('hex_codec')
43
44 log = twisted.logger.Logger()
45
46 def log_debug(dflag, msg, idof=None, d=None):
47   #print('---------------->',repr((dflag, msg, idof, d)), file=sys.stderr)
48   if idof is not None:
49     msg = '[%d] %s' % (id(idof), msg)
50   if d is not None:
51     #d = d[0:64]
52     d = _hex_codec(d)[0].decode('ascii')
53     msg += ' ' + d
54   log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg)
55
56 defcfg = '''
57 [DEFAULT]
58 #[<client>] overrides
59 max_batch_down = 65536           # used by server, subject to [limits]
60 max_queue_time = 10              # used by server, subject to [limits]
61 max_request_time = 54            # used by server, subject to [limits]
62 target_requests_outstanding = 3  # must match; subject to [limits] on server
63 max_requests_outstanding = 4     # used by client
64 max_batch_up = 4000              # used by client
65 http_timeout = 30                # used by client
66 http_retry = 5                   # used by client
67
68 #[server] or [<client>] overrides
69 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
70 # extra interpolations:  %(local)s        %(peer)s          %(rnet)s
71 #  obtained   on server  [virtual]server  [virtual]relay    [virtual]network
72 #      from   on client  <client>         [virtual]server   [virtual]routes
73
74 [virtual]
75 mtu = 1500
76 routes = ''
77 # network = <prefix>/<len>  # mandatory for server
78 # server  = <ipaddr>   # used by both, default is computed from `network'
79 # relay   = <ipaddr>   # used by server, default from `network' and `server'
80 #  default server is first host in network
81 #  default relay is first host which is not server
82
83 [server]
84 # addrs = 127.0.0.1 ::1    # mandatory for server
85 port = 80                  # used by server
86 # url              # used by client; default from first `addrs' and `port'
87
88 # [<client-ip4-or-ipv6-address>]
89 # password = <password>    # used by both, must match
90
91 [limits]
92 max_batch_down = 262144           # used by server
93 max_queue_time = 121              # used by server
94 max_request_time = 121            # used by server
95 target_requests_outstanding = 10  # used by server
96 '''
97
98 # these need to be defined here so that they can be imported by import *
99 cfg = ConfigParser()
100 optparser = OptionParser()
101
102 _mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')
103 def mime_translate(s):
104   # SLIP-encoded packets cannot contain ESC ESC.
105   # Swap `-' and ESC.  The result cannot contain `--'
106   return s.translate(_mimetrans)
107
108 class ConfigResults:
109   def __init__(self, d = { }):
110     self.__dict__ = d
111   def __repr__(self):
112     return 'ConfigResults('+repr(self.__dict__)+')'
113
114 c = ConfigResults()
115
116 def log_discard(packet, iface, saddr, daddr, why):
117   log_debug(DBG.DROP,
118             'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why),
119             d=packet)
120
121 #---------- packet parsing ----------
122
123 def packet_addrs(packet):
124   version = packet[0] >> 4
125   if version == 4:
126     addrlen = 4
127     saddroff = 3*4
128     factory = ipaddress.IPv4Address
129   elif version == 6:
130     addrlen = 16
131     saddroff = 2*4
132     factory = ipaddress.IPv6Address
133   else:
134     raise ValueError('unsupported IP version %d' % version)
135   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
136   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
137   return (saddr, daddr)
138
139 #---------- address handling ----------
140
141 def ipaddr(input):
142   try:
143     r = ipaddress.IPv4Address(input)
144   except AddressValueError:
145     r = ipaddress.IPv6Address(input)
146   return r
147
148 def ipnetwork(input):
149   try:
150     r = ipaddress.IPv4Network(input)
151   except NetworkValueError:
152     r = ipaddress.IPv6Network(input)
153   return r
154
155 #---------- ipif (SLIP) subprocess ----------
156
157 class SlipStreamDecoder():
158   def __init__(self, on_packet):
159     # we will call packet(<packet>)
160     self._buffer = b''
161     self._on_packet = on_packet
162
163   def inputdata(self, data):
164     #print('SLIP-GOT ', repr(data))
165     self._buffer += data
166     packets = slip.decode(self._buffer)
167     self._buffer = packets.pop()
168     for packet in packets:
169       self._maybe_packet(packet)
170
171   def _maybe_packet(self, packet):
172       if len(packet):
173         self._on_packet(packet)
174
175   def flush(self):
176     self._maybe_packet(self._buffer)
177     self._buffer = b''
178
179 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
180   def __init__(self, router):
181     self._router = router
182     self._decoder = SlipStreamDecoder(self.slip_on_packet)
183   def connectionMade(self): pass
184   def outReceived(self, data):
185     self._decoder.inputdata(data)
186   def slip_on_packet(self, packet):
187     (saddr, daddr) = packet_addrs(packet)
188     if saddr.is_link_local or daddr.is_link_local:
189       log_discard(packet, 'ipif', saddr, daddr, 'link-local')
190       return
191     self._router(packet, saddr, daddr)
192   def processEnded(self, status):
193     status.raiseException()
194
195 def start_ipif(command, router):
196   global ipif
197   ipif = _IpifProcessProtocol(router)
198   reactor.spawnProcess(ipif,
199                        '/bin/sh',['sh','-xc', command],
200                        childFDs={0:'w', 1:'r', 2:2},
201                        env=None)
202
203 def queue_inbound(packet):
204   log_debug(DBG.FLOW, "queue_inbound", d=packet)
205   ipif.transport.write(slip.delimiter)
206   ipif.transport.write(slip.encode(packet))
207   ipif.transport.write(slip.delimiter)
208
209 #---------- packet queue ----------
210
211 class PacketQueue():
212   def __init__(self, desc, max_queue_time):
213     self._desc = desc
214     assert(desc + '')
215     self._max_queue_time = max_queue_time
216     self._pq = collections.deque() # packets
217
218   def _log(self, dflag, msg, **kwargs):
219     log_debug(dflag, self._desc+' pq: '+msg, **kwargs)
220
221   def append(self, packet):
222     self._log(DBG.QUEUE, 'append', d=packet)
223     self._pq.append((time.monotonic(), packet))
224
225   def nonempty(self):
226     self._log(DBG.QUEUE, 'nonempty ?')
227     while True:
228       try: (queuetime, packet) = self._pq[0]
229       except IndexError:
230         self._log(DBG.QUEUE, 'nonempty ? empty.')
231         return False
232
233       age = time.monotonic() - queuetime
234       if age > self._max_queue_time:
235         # strip old packets off the front
236         self._log(DBG.QUEUE, 'dropping (old)', d=packet)
237         self._pq.popleft()
238         continue
239
240       self._log(DBG.QUEUE, 'nonempty ? nonempty.')
241       return True
242
243   def process(self, sizequery, moredata, max_batch):
244     # sizequery() should return size of batch so far
245     # moredata(s) should add s to batch
246     self._log(DBG.QUEUE, 'process...')
247     while True:
248       try: (dummy, packet) = self._pq[0]
249       except IndexError:
250         self._log(DBG.QUEUE, 'process... empty')
251         break
252
253       self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)
254
255       encoded = slip.encode(packet)
256       sofar = sizequery()  
257
258       self._log(DBG.QUEUE_CTRL,
259                 'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
260                 d=encoded)
261
262       if sofar > 0:
263         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
264           self._log(DBG.QUEUE_CTRL, 'process... overflow')
265           break
266         moredata(slip.delimiter)
267
268       moredata(encoded)
269       self._pq.popleft()
270
271 #---------- error handling ----------
272
273 _crashing = False
274
275 def crash(err):
276   global _crashing
277   _crashing = True
278   print('CRASH ', err, file=sys.stderr)
279   try: reactor.stop()
280   except twisted.internet.error.ReactorNotRunning: pass
281
282 def crash_on_defer(defer):
283   defer.addErrback(lambda err: crash(err))
284
285 def crash_on_critical(event):
286   if event.get('log_level') >= LogLevel.critical:
287     crash(twisted.logger.formatEvent(event))
288
289 #---------- config processing ----------
290
291 def process_cfg_common_always():
292   global mtu
293   c.mtu = cfg.get('virtual','mtu')
294
295 def process_cfg_ipif(section, varmap):
296   for d, s in varmap:
297     try: v = getattr(c, s)
298     except AttributeError: continue
299     setattr(c, d, v)
300
301   #print(repr(c))
302
303   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
304
305 def process_cfg_network():
306   c.network = ipnetwork(cfg.get('virtual','network'))
307   if c.network.num_addresses < 3 + 2:
308     raise ValueError('network needs at least 2^3 addresses')
309
310 def process_cfg_server():
311   try:
312     c.server = cfg.get('virtual','server')
313   except NoOptionError:
314     process_cfg_network()
315     c.server = next(c.network.hosts())
316
317 class ServerAddr():
318   def __init__(self, port, addrspec):
319     self.port = port
320     # also self.addr
321     try:
322       self.addr = ipaddress.IPv4Address(addrspec)
323       self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
324       self._inurl = b'%s'
325     except AddressValueError:
326       self.addr = ipaddress.IPv6Address(addrspec)
327       self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
328       self._inurl = b'[%s]'
329   def make_endpoint(self):
330     return self._endpointfactory(reactor, self.port, self.addr)
331   def url(self):
332     url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
333     if self.port != 80: url += b':%d' % self.port
334     url += b'/'
335     return url
336     
337 def process_cfg_saddrs():
338   try: port = cfg.getint('server','port')
339   except NoOptionError: port = 80
340
341   c.saddrs = [ ]
342   for addrspec in cfg.get('server','addrs').split():
343     sa = ServerAddr(port, addrspec)
344     c.saddrs.append(sa)
345
346 def process_cfg_clients(constructor):
347   c.clients = [ ]
348   for cs in cfg.sections():
349     if not (':' in cs or '.' in cs): continue
350     ci = ipaddr(cs)
351     pw = cfg.get(cs, 'password')
352     pw = pw.encode('utf-8')
353     constructor(ci,cs,pw)
354
355 #---------- startup ----------
356
357 def common_startup():
358   log_formatter = twisted.logger.formatEventAsClassicLogText
359   log_observer = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
360   twisted.logger.globalLogBeginner.beginLoggingTo(
361     [ log_observer, crash_on_critical ]
362     )
363
364   optparser.add_option('-c', '--config', dest='configfile',
365                        default='/etc/hippotat/config')
366   (opts, args) = optparser.parse_args()
367   if len(args): optparser.error('no non-option arguments please')
368
369   re = regexp.compile('#.*')
370   cfg.read_string(re.sub('', defcfg))
371   cfg.read(opts.configfile)
372
373 def common_run():
374   log_debug(DBG.INIT, 'entering reactor')
375   if not _crashing: reactor.run()
376   print('CRASHED (end)', file=sys.stderr)