chiark / gitweb /
abc2a9fd687129eecd9732e258ac68168ed24711
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 import twisted
9 from twisted.internet import reactor
10 from twisted.logger import LogLevel
11 import twisted.internet.endpoints
12
13 import ipaddress
14 from ipaddress import AddressValueError
15
16 from optparse import OptionParser
17 from configparser import ConfigParser
18 from configparser import NoOptionError
19
20 import collections
21 import time
22
23 import re as regexp
24
25 from twisted.python.constants import NamedConstant
26
27 import hippotat.slip as slip
28
29 class DBG(twisted.python.constants.Names):
30   ROUTE = NamedConstant()
31   FLOW = NamedConstant()
32   HTTP = NamedConstant()
33   HTTP_CTRL = NamedConstant()
34   INIT = NamedConstant()
35   QUEUE = NamedConstant()
36   QUEUE_CTRL = NamedConstant()
37
38 defcfg = '''
39 [DEFAULT]
40 #[<client>] overrides
41 max_batch_down = 65536           # used by server, subject to [limits]
42 max_queue_time = 10              # used by server, subject to [limits]
43 max_request_time = 54            # used by server, subject to [limits]
44 target_requests_outstanding = 3  # must match; subject to [limits] on server
45 max_requests_outstanding = 4     # used by client
46 max_batch_up = 4000              # used by client
47 http_timeout = 30                # used by client
48 http_retry = 5                   # used by client
49
50 #[server] or [<client>] overrides
51 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
52 # extra interpolations:  %(local)s        %(peer)s          %(rnet)s
53 #  obtained   on server  [virtual]server  [virtual]relay    [virtual]network
54 #      from   on client  <client>         [virtual]server   [virtual]routes
55
56 [virtual]
57 mtu = 1500
58 routes = ''
59 # network = <prefix>/<len>  # mandatory for server
60 # server  = <ipaddr>   # used by both, default is computed from `network'
61 # relay   = <ipaddr>   # used by server, default from `network' and `server'
62 #  default server is first host in network
63 #  default relay is first host which is not server
64
65 [server]
66 # addrs = 127.0.0.1 ::1    # mandatory for server
67 port = 80                  # used by server
68 # url              # used by client; default from first `addrs' and `port'
69
70 # [<client-ip4-or-ipv6-address>]
71 # password = <password>    # used by both, must match
72
73 [limits]
74 max_batch_down = 262144           # used by server
75 max_queue_time = 121              # used by server
76 max_request_time = 121            # used by server
77 target_requests_outstanding = 10  # used by server
78 '''
79
80 # these need to be defined here so that they can be imported by import *
81 cfg = ConfigParser()
82 optparser = OptionParser()
83
84 _mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')
85 def mime_translate(s):
86   # SLIP-encoded packets cannot contain ESC ESC.
87   # Swap `-' and ESC.  The result cannot contain `--'
88   return s.translate(_mimetrans)
89
90 class ConfigResults:
91   def __init__(self, d = { }):
92     self.__dict__ = d
93   def __repr__(self):
94     return 'ConfigResults('+repr(self.__dict__)+')'
95
96 c = ConfigResults()
97
98 def log_discard(packet, saddr, daddr, why):
99   print('DROP ', saddr, daddr, why)
100 #  syslog.syslog(syslog.LOG_DEBUG,
101 #                'discarded packet %s -> %s (%s)' % (saddr, daddr, why))
102
103 #---------- packet parsing ----------
104
105 def packet_addrs(packet):
106   version = packet[0] >> 4
107   if version == 4:
108     addrlen = 4
109     saddroff = 3*4
110     factory = ipaddress.IPv4Address
111   elif version == 6:
112     addrlen = 16
113     saddroff = 2*4
114     factory = ipaddress.IPv6Address
115   else:
116     raise ValueError('unsupported IP version %d' % version)
117   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
118   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
119   return (saddr, daddr)
120
121 #---------- address handling ----------
122
123 def ipaddr(input):
124   try:
125     r = ipaddress.IPv4Address(input)
126   except AddressValueError:
127     r = ipaddress.IPv6Address(input)
128   return r
129
130 def ipnetwork(input):
131   try:
132     r = ipaddress.IPv4Network(input)
133   except NetworkValueError:
134     r = ipaddress.IPv6Network(input)
135   return r
136
137 #---------- ipif (SLIP) subprocess ----------
138
139 class SlipStreamDecoder():
140   def __init__(self, on_packet):
141     # we will call packet(<packet>)
142     self._buffer = b''
143     self._on_packet = on_packet
144
145   def inputdata(self, data):
146     #print('SLIP-GOT ', repr(data))
147     self._buffer += data
148     packets = slip.decode(self._buffer)
149     self._buffer = packets.pop()
150     for packet in packets:
151       self._maybe_packet(packet)
152
153   def _maybe_packet(self, packet):
154       if len(packet):
155         self._on_packet(packet)
156
157   def flush(self):
158     self._maybe_packet(self._buffer)
159     self._buffer = b''
160
161 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
162   def __init__(self, router):
163     self._router = router
164     self._decoder = SlipStreamDecoder(self.slip_on_packet)
165   def connectionMade(self): pass
166   def outReceived(self, data):
167     self._decoder.inputdata(data)
168   def slip_on_packet(self, packet):
169     (saddr, daddr) = packet_addrs(packet)
170     if saddr.is_link_local or daddr.is_link_local:
171       log_discard(packet, saddr, daddr, 'link-local')
172       return
173     self._router(packet, saddr, daddr)
174   def processEnded(self, status):
175     status.raiseException()
176
177 def start_ipif(command, router):
178   global ipif
179   ipif = _IpifProcessProtocol(router)
180   reactor.spawnProcess(ipif,
181                        '/bin/sh',['sh','-xc', command],
182                        childFDs={0:'w', 1:'r', 2:2},
183                        env=None)
184
185 def queue_inbound(packet):
186   ipif.transport.write(slip.delimiter)
187   ipif.transport.write(slip.encode(packet))
188   ipif.transport.write(slip.delimiter)
189
190 #---------- packet queue ----------
191
192 class PacketQueue():
193   def __init__(self, desc, max_queue_time):
194     self._desc = desc
195     self._max_queue_time = max_queue_time
196     self._pq = collections.deque() # packets
197
198   def _log_debug(self, fn, pri, msg)
199     log_debug(pri, 
200
201   def append(self, packet):
202     log_data(DBG.QUEUE, packet, 'pq %s: append' % self._desc)
203     self._pq.append((time.monotonic(), packet))
204
205   def nonempty(self):
206     log_debug(DBG.QUEUE, 'pq %s: nonempty ?' % self._desc)
207     while True:
208       try: (queuetime, packet) = self._pq[0]
209       except IndexError: return False
210
211       age = time.monotonic() - queuetime
212       if age > self._max_queue_time:
213         # strip old packets off the front
214         self._pq.popleft()
215         continue
216
217       return True
218
219   def process(self, sizequery, moredata, max_batch):
220     # sizequery() should return size of batch so far
221     # moredata(s) should add s to batch
222     while True:
223       try: (dummy, packet) = self._pq[0]
224       except IndexError: break
225
226       encoded = slip.encode(packet)
227       sofar = sizequery()  
228
229       if sofar > 0:
230         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
231           break
232         moredata(slip.delimiter)
233
234       moredata(encoded)
235       self._pq.popleft()
236
237 #---------- error handling ----------
238
239 def crash(err):
240   print('CRASH ', err, file=sys.stderr)
241   try: reactor.stop()
242   except twisted.internet.error.ReactorNotRunning: pass
243
244 def crash_on_defer(defer):
245   defer.addErrback(lambda err: crash(err))
246
247 def crash_on_critical(event):
248   if event.get('log_level') >= LogLevel.critical:
249     crash(twisted.logger.formatEvent(event))
250
251 #---------- config processing ----------
252
253 def process_cfg_common_always():
254   global mtu
255   c.mtu = cfg.get('virtual','mtu')
256
257 def process_cfg_ipif(section, varmap):
258   for d, s in varmap:
259     try: v = getattr(c, s)
260     except AttributeError: continue
261     setattr(c, d, v)
262
263   print(repr(c))
264
265   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
266
267 def process_cfg_network():
268   c.network = ipnetwork(cfg.get('virtual','network'))
269   if c.network.num_addresses < 3 + 2:
270     raise ValueError('network needs at least 2^3 addresses')
271
272 def process_cfg_server():
273   try:
274     c.server = cfg.get('virtual','server')
275   except NoOptionError:
276     process_cfg_network()
277     c.server = next(c.network.hosts())
278
279 class ServerAddr():
280   def __init__(self, port, addrspec):
281     self.port = port
282     # also self.addr
283     try:
284       self.addr = ipaddress.IPv4Address(addrspec)
285       self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
286       self._inurl = b'%s'
287     except AddressValueError:
288       self.addr = ipaddress.IPv6Address(addrspec)
289       self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
290       self._inurl = b'[%s]'
291   def make_endpoint(self):
292     return self._endpointfactory(reactor, self.port, self.addr)
293   def url(self):
294     url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
295     if self.port != 80: url += b':%d' % self.port
296     url += b'/'
297     return url
298     
299 def process_cfg_saddrs():
300   try: port = cfg.getint('server','port')
301   except NoOptionError: port = 80
302
303   c.saddrs = [ ]
304   for addrspec in cfg.get('server','addrs').split():
305     sa = ServerAddr(port, addrspec)
306     c.saddrs.append(sa)
307
308 def process_cfg_clients(constructor):
309   c.clients = [ ]
310   for cs in cfg.sections():
311     if not (':' in cs or '.' in cs): continue
312     ci = ipaddr(cs)
313     pw = cfg.get(cs, 'password')
314     pw = pw.encode('utf-8')
315     constructor(ci,cs,pw)
316
317 #---------- startup ----------
318
319 def common_startup():
320   twisted.logger.globalLogPublisher.addObserver(crash_on_critical)
321
322   optparser.add_option('-c', '--config', dest='configfile',
323                        default='/etc/hippotat/config')
324   (opts, args) = optparser.parse_args()
325   if len(args): optparser.error('no non-option arguments please')
326
327   re = regexp.compile('#.*')
328   cfg.read_string(re.sub('', defcfg))
329   cfg.read(opts.configfile)
330
331 def common_run():
332   reactor.run()
333   print('CRASHED (end)', file=sys.stderr)