chiark / gitweb /
wip, new log stuff
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 import twisted
9 from twisted.internet import reactor
10 import twisted.internet.endpoints
11 import twisted.logger
12 from twisted.logger import LogLevel
13 import twisted.python.constants
14 from twisted.python.constants import NamedConstant
15
16 import ipaddress
17 from ipaddress import AddressValueError
18
19 from optparse import OptionParser
20 from configparser import ConfigParser
21 from configparser import NoOptionError
22
23 import collections
24 import time
25 import codecs
26
27 import re as regexp
28
29 import hippotat.slip as slip
30
31 class DBG(twisted.python.constants.Names):
32   ROUTE = NamedConstant()
33   FLOW = NamedConstant()
34   HTTP = NamedConstant()
35   HTTP_CTRL = NamedConstant()
36   INIT = NamedConstant()
37   QUEUE = NamedConstant()
38   QUEUE_CTRL = NamedConstant()
39
40 _hexcodec = codecs.getencoder('hex_codec')
41
42 log = twisted.logger.Logger()
43
44 def log_debug(dflag, msg, idof=None, d=None):
45   if idof is not None:
46     msg = '[%d] %s' % (id(idof), msg)
47   if d is not None:
48     d = d[0:64]
49     d = _hex_codec(d)[0]
50     msg += ' ' + d
51   log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg)
52
53 defcfg = '''
54 [DEFAULT]
55 #[<client>] overrides
56 max_batch_down = 65536           # used by server, subject to [limits]
57 max_queue_time = 10              # used by server, subject to [limits]
58 max_request_time = 54            # used by server, subject to [limits]
59 target_requests_outstanding = 3  # must match; subject to [limits] on server
60 max_requests_outstanding = 4     # used by client
61 max_batch_up = 4000              # used by client
62 http_timeout = 30                # used by client
63 http_retry = 5                   # used by client
64
65 #[server] or [<client>] overrides
66 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
67 # extra interpolations:  %(local)s        %(peer)s          %(rnet)s
68 #  obtained   on server  [virtual]server  [virtual]relay    [virtual]network
69 #      from   on client  <client>         [virtual]server   [virtual]routes
70
71 [virtual]
72 mtu = 1500
73 routes = ''
74 # network = <prefix>/<len>  # mandatory for server
75 # server  = <ipaddr>   # used by both, default is computed from `network'
76 # relay   = <ipaddr>   # used by server, default from `network' and `server'
77 #  default server is first host in network
78 #  default relay is first host which is not server
79
80 [server]
81 # addrs = 127.0.0.1 ::1    # mandatory for server
82 port = 80                  # used by server
83 # url              # used by client; default from first `addrs' and `port'
84
85 # [<client-ip4-or-ipv6-address>]
86 # password = <password>    # used by both, must match
87
88 [limits]
89 max_batch_down = 262144           # used by server
90 max_queue_time = 121              # used by server
91 max_request_time = 121            # used by server
92 target_requests_outstanding = 10  # used by server
93 '''
94
95 # these need to be defined here so that they can be imported by import *
96 cfg = ConfigParser()
97 optparser = OptionParser()
98
99 _mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')
100 def mime_translate(s):
101   # SLIP-encoded packets cannot contain ESC ESC.
102   # Swap `-' and ESC.  The result cannot contain `--'
103   return s.translate(_mimetrans)
104
105 class ConfigResults:
106   def __init__(self, d = { }):
107     self.__dict__ = d
108   def __repr__(self):
109     return 'ConfigResults('+repr(self.__dict__)+')'
110
111 c = ConfigResults()
112
113 def log_discard(packet, saddr, daddr, why):
114   
115   Print('Drop ', Saddr, Daddr, why)
116 #  syslog.syslog(syslog.LOG_DEBUG,
117 #                'discarded packet %s -> %s (%s)' % (saddr, daddr, why))
118
119 #---------- packet parsing ----------
120
121 def packet_addrs(packet):
122   version = packet[0] >> 4
123   if version == 4:
124     addrlen = 4
125     saddroff = 3*4
126     factory = ipaddress.IPv4Address
127   elif version == 6:
128     addrlen = 16
129     saddroff = 2*4
130     factory = ipaddress.IPv6Address
131   else:
132     raise ValueError('unsupported IP version %d' % version)
133   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
134   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
135   return (saddr, daddr)
136
137 #---------- address handling ----------
138
139 def ipaddr(input):
140   try:
141     r = ipaddress.IPv4Address(input)
142   except AddressValueError:
143     r = ipaddress.IPv6Address(input)
144   return r
145
146 def ipnetwork(input):
147   try:
148     r = ipaddress.IPv4Network(input)
149   except NetworkValueError:
150     r = ipaddress.IPv6Network(input)
151   return r
152
153 #---------- ipif (SLIP) subprocess ----------
154
155 class SlipStreamDecoder():
156   def __init__(self, on_packet):
157     # we will call packet(<packet>)
158     self._buffer = b''
159     self._on_packet = on_packet
160
161   def inputdata(self, data):
162     #print('SLIP-GOT ', repr(data))
163     self._buffer += data
164     packets = slip.decode(self._buffer)
165     self._buffer = packets.pop()
166     for packet in packets:
167       self._maybe_packet(packet)
168
169   def _maybe_packet(self, packet):
170       if len(packet):
171         self._on_packet(packet)
172
173   def flush(self):
174     self._maybe_packet(self._buffer)
175     self._buffer = b''
176
177 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
178   def __init__(self, router):
179     self._router = router
180     self._decoder = SlipStreamDecoder(self.slip_on_packet)
181   def connectionMade(self): pass
182   def outReceived(self, data):
183     self._decoder.inputdata(data)
184   def slip_on_packet(self, packet):
185     (saddr, daddr) = packet_addrs(packet)
186     if saddr.is_link_local or daddr.is_link_local:
187       log_discard(packet, saddr, daddr, 'link-local')
188       return
189     self._router(packet, saddr, daddr)
190   def processEnded(self, status):
191     status.raiseException()
192
193 def start_ipif(command, router):
194   global ipif
195   ipif = _IpifProcessProtocol(router)
196   reactor.spawnProcess(ipif,
197                        '/bin/sh',['sh','-xc', command],
198                        childFDs={0:'w', 1:'r', 2:2},
199                        env=None)
200
201 def queue_inbound(packet):
202   ipif.transport.write(slip.delimiter)
203   ipif.transport.write(slip.encode(packet))
204   ipif.transport.write(slip.delimiter)
205
206 #---------- packet queue ----------
207
208 class PacketQueue():
209   def __init__(self, desc, max_queue_time):
210     self._desc = desc
211     self._max_queue_time = max_queue_time
212     self._pq = collections.deque() # packets
213
214   def _log(self, dflag, msg, **kwargs)
215     log_debug(dflag, self._desc+' pq: '+msg, **kwargs)
216
217   def append(self, packet):
218     self._log(DBG.QUEUE, 'append', d=packet)
219     self._pq.append((time.monotonic(), packet))
220
221   def nonempty(self):
222     self._log(DBG.QUEUE, 'nonempty ?')
223     while True:
224       try: (queuetime, packet) = self._pq[0]
225       except IndexError:
226         self._log(DBG.QUEUE, 'nonempty ? empty.')
227         return False
228
229       age = time.monotonic() - queuetime
230       if age > self._max_queue_time:
231         # strip old packets off the front
232         self._log(DBG.QUEUE, 'dropping (old)', d=packet)
233         self._pq.popleft()
234         continue
235
236       self._log(DBG.QUEUE, 'nonempty ? nonempty.')
237       return True
238
239   def process(self, sizequery, moredata, max_batch):
240     # sizequery() should return size of batch so far
241     # moredata(s) should add s to batch
242     self._log(DBG.QUEUE, 'process...')
243     while True:
244       try: (dummy, packet) = self._pq[0]
245       except IndexError:
246         self._log(DBG.QUEUE, 'process... empty')
247         break
248
249       self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)
250
251       encoded = slip.encode(packet)
252       sofar = sizequery()  
253
254       self._log(DBG.QUEUE_CTRL,
255                 'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
256                 d=encoded
257
258       if sofar > 0:
259         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
260           self._log(DBG.QUEUE_CTRL, 'process... overflow')
261           break
262         moredata(slip.delimiter)
263
264       moredata(encoded)
265       self._pq.popleft()
266
267 #---------- error handling ----------
268
269 def crash(err):
270   print('CRASH ', err, file=sys.stderr)
271   try: reactor.stop()
272   except twisted.internet.error.ReactorNotRunning: pass
273
274 def crash_on_defer(defer):
275   defer.addErrback(lambda err: crash(err))
276
277 def crash_on_critical(event):
278   if event.get('log_level') >= LogLevel.critical:
279     crash(twisted.logger.formatEvent(event))
280
281 #---------- config processing ----------
282
283 def process_cfg_common_always():
284   global mtu
285   c.mtu = cfg.get('virtual','mtu')
286
287 def process_cfg_ipif(section, varmap):
288   for d, s in varmap:
289     try: v = getattr(c, s)
290     except AttributeError: continue
291     setattr(c, d, v)
292
293   print(repr(c))
294
295   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
296
297 def process_cfg_network():
298   c.network = ipnetwork(cfg.get('virtual','network'))
299   if c.network.num_addresses < 3 + 2:
300     raise ValueError('network needs at least 2^3 addresses')
301
302 def process_cfg_server():
303   try:
304     c.server = cfg.get('virtual','server')
305   except NoOptionError:
306     process_cfg_network()
307     c.server = next(c.network.hosts())
308
309 class ServerAddr():
310   def __init__(self, port, addrspec):
311     self.port = port
312     # also self.addr
313     try:
314       self.addr = ipaddress.IPv4Address(addrspec)
315       self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
316       self._inurl = b'%s'
317     except AddressValueError:
318       self.addr = ipaddress.IPv6Address(addrspec)
319       self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
320       self._inurl = b'[%s]'
321   def make_endpoint(self):
322     return self._endpointfactory(reactor, self.port, self.addr)
323   def url(self):
324     url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
325     if self.port != 80: url += b':%d' % self.port
326     url += b'/'
327     return url
328     
329 def process_cfg_saddrs():
330   try: port = cfg.getint('server','port')
331   except NoOptionError: port = 80
332
333   c.saddrs = [ ]
334   for addrspec in cfg.get('server','addrs').split():
335     sa = ServerAddr(port, addrspec)
336     c.saddrs.append(sa)
337
338 def process_cfg_clients(constructor):
339   c.clients = [ ]
340   for cs in cfg.sections():
341     if not (':' in cs or '.' in cs): continue
342     ci = ipaddr(cs)
343     pw = cfg.get(cs, 'password')
344     pw = pw.encode('utf-8')
345     constructor(ci,cs,pw)
346
347 #---------- startup ----------
348
349 def common_startup():
350   log_formatter = twisted.logger.formatEventAsClassicLogText
351   log_observer = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
352   twisted.logger.globalLogBeginner.beginLoggingTo(
353     [ log_observer, crash_on_critical ]
354     )
355
356   optparser.add_option('-c', '--config', dest='configfile',
357                        default='/etc/hippotat/config')
358   (opts, args) = optparser.parse_args()
359   if len(args): optparser.error('no non-option arguments please')
360
361   re = regexp.compile('#.*')
362   cfg.read_string(re.sub('', defcfg))
363   cfg.read(opts.configfile)
364
365 def common_run():
366   log_debug(DBG.INIT, 'ready')
367   reactor.run()
368   print('CRASHED (end)', file=sys.stderr)