chiark / gitweb /
029b22d7ddf245dcfa385459fc02a2762fc91cb2
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 import twisted
9 from twisted.internet import reactor
10 from twisted.logger import LogLevel
11 import twisted.internet.endpoints
12
13 import ipaddress
14 from ipaddress import AddressValueError
15
16 from optparse import OptionParser
17 from configparser import ConfigParser
18 from configparser import NoOptionError
19
20 import collections
21 import time
22
23 import re as regexp
24
25 import hippotat.slip as slip
26
27 defcfg = '''
28 [DEFAULT]
29 #[<client>] overrides
30 max_batch_down = 65536           # used by server, subject to [limits]
31 max_queue_time = 10              # used by server, subject to [limits]
32 max_request_time = 54            # used by server, subject to [limits]
33 target_requests_outstanding = 3  # must match; subject to [limits] on server
34 max_requests_outstanding = 4     # used by client
35 max_batch_up = 4000              # used by client
36 http_timeout = 30                # used by client
37
38 #[server] or [<client>] overrides
39 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
40 # extra interpolations:  %(local)s        %(peer)s          %(rnet)s
41 #  obtained   on server  [virtual]server  [virtual]relay    [virtual]network
42 #      from   on client  <client>         [virtual]server   [virtual]routes
43
44 [virtual]
45 mtu = 1500
46 routes = ''
47 # network = <prefix>/<len>  # mandatory for server
48 # server  = <ipaddr>   # used by both, default is computed from `network'
49 # relay   = <ipaddr>   # used by server, default from `network' and `server'
50 #  default server is first host in network
51 #  default relay is first host which is not server
52
53 [server]
54 # addrs = 127.0.0.1 ::1    # mandatory for server
55 port = 80                  # used by server
56 # url              # used by client; default from first `addrs' and `port'
57
58 # [<client-ip4-or-ipv6-address>]
59 # password = <password>    # used by both, must match
60
61 [limits]
62 max_batch_down = 262144           # used by server
63 max_queue_time = 121              # used by server
64 max_request_time = 121            # used by server
65 target_requests_outstanding = 10  # used by server
66 '''
67
68 # these need to be defined here so that they can be imported by import *
69 cfg = ConfigParser()
70 optparser = OptionParser()
71
72 _mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')
73 def mime_translate(s):
74   # SLIP-encoded packets cannot contain ESC ESC.
75   # Swap `-' and ESC.  The result cannot contain `--'
76   return s.translate(_mimetrans)
77
78 class ConfigResults:
79   def __init__(self, d = { }):
80     self.__dict__ = d
81   def __repr__(self):
82     return 'ConfigResults('+repr(self.__dict__)+')'
83
84 c = ConfigResults()
85
86 def log_discard(packet, saddr, daddr, why):
87   print('DROP ', saddr, daddr, why)
88 #  syslog.syslog(syslog.LOG_DEBUG,
89 #                'discarded packet %s -> %s (%s)' % (saddr, daddr, why))
90
91 #---------- packet parsing ----------
92
93 def packet_addrs(packet):
94   version = packet[0] >> 4
95   if version == 4:
96     addrlen = 4
97     saddroff = 3*4
98     factory = ipaddress.IPv4Address
99   elif version == 6:
100     addrlen = 16
101     saddroff = 2*4
102     factory = ipaddress.IPv6Address
103   else:
104     raise ValueError('unsupported IP version %d' % version)
105   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
106   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
107   return (saddr, daddr)
108
109 #---------- address handling ----------
110
111 def ipaddr(input):
112   try:
113     r = ipaddress.IPv4Address(input)
114   except AddressValueError:
115     r = ipaddress.IPv6Address(input)
116   return r
117
118 def ipnetwork(input):
119   try:
120     r = ipaddress.IPv4Network(input)
121   except NetworkValueError:
122     r = ipaddress.IPv6Network(input)
123   return r
124
125 #---------- ipif (SLIP) subprocess ----------
126
127 class SlipStreamDecoder():
128   def __init__(self, on_packet):
129     # we will call packet(<packet>)
130     self._buffer = b''
131     self._on_packet = on_packet
132
133   def inputdata(self, data):
134     #print('SLIP-GOT ', repr(data))
135     self._buffer += data
136     packets = slip.decode(self._buffer)
137     self._buffer = packets.pop()
138     for packet in packets:
139       self._maybe_packet(packet)
140
141   def _maybe_packet(self, packet):
142       if len(packet):
143         self._on_packet(packet)
144
145   def flush(self):
146     self._maybe_packet(self._buffer)
147     self._buffer = b''
148
149 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
150   def __init__(self, router):
151     self._router = router
152     self._decoder = SlipStreamDecoder(self.slip_on_packet)
153   def connectionMade(self): pass
154   def outReceived(self, data):
155     self._decoder.inputdata(data)
156   def slip_on_packet(self, packet):
157     (saddr, daddr) = packet_addrs(packet)
158     if saddr.is_link_local or daddr.is_link_local:
159       log_discard(packet, saddr, daddr, 'link-local')
160       return
161     self._router(packet, saddr, daddr)
162   def processEnded(self, status):
163     status.raiseException()
164
165 def start_ipif(command, router):
166   global ipif
167   ipif = _IpifProcessProtocol(router)
168   reactor.spawnProcess(ipif,
169                        '/bin/sh',['sh','-xc', command],
170                        childFDs={0:'w', 1:'r', 2:2},
171                        env=None)
172
173 def queue_inbound(packet):
174   ipif.transport.write(slip.delimiter)
175   ipif.transport.write(slip.encode(packet))
176   ipif.transport.write(slip.delimiter)
177
178 #---------- packet queue ----------
179
180 class PacketQueue():
181   def __init__(self, max_queue_time):
182     self._max_queue_time = max_queue_time
183     self._pq = collections.deque() # packets
184
185   def append(self, packet):
186     self._pq.append((time.monotonic(), packet))
187
188   def nonempty(self):
189     while True:
190       try: (queuetime, packet) = self._pq[0]
191       except IndexError: return False
192
193       age = time.monotonic() - queuetime
194       if age > self._max_queue_time:
195         # strip old packets off the front
196         self._pq.popleft()
197         continue
198
199       return True
200
201   def process(self, sizequery, moredata, max_batch):
202     # sizequery() should return size of batch so far
203     # moredata(s) should add s to batch
204     while True:
205       try: (dummy, packet) = self._pq[0]
206       except IndexError: break
207
208       encoded = slip.encode(packet)
209       sofar = sizequery()  
210
211       if sofar > 0:
212         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
213           break
214         moredata(slip.delimiter)
215
216       moredata(encoded)
217       self._pq.popleft()
218
219 #---------- error handling ----------
220
221 def crash(err):
222   print('CRASH ', err, file=sys.stderr)
223   try: reactor.stop()
224   except twisted.internet.error.ReactorNotRunning: pass
225
226 def crash_on_defer(defer):
227   defer.addErrback(lambda err: crash(err))
228
229 def crash_on_critical(event):
230   if event.get('log_level') >= LogLevel.critical:
231     crash(twisted.logger.formatEvent(event))
232
233 #---------- config processing ----------
234
235 def process_cfg_common_always():
236   global mtu
237   c.mtu = cfg.get('virtual','mtu')
238
239 def process_cfg_ipif(section, varmap):
240   for d, s in varmap:
241     try: v = getattr(c, s)
242     except AttributeError: continue
243     setattr(c, d, v)
244
245   print(repr(c))
246
247   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
248
249 def process_cfg_network():
250   c.network = ipnetwork(cfg.get('virtual','network'))
251   if c.network.num_addresses < 3 + 2:
252     raise ValueError('network needs at least 2^3 addresses')
253
254 def process_cfg_server():
255   try:
256     c.server = cfg.get('virtual','server')
257   except NoOptionError:
258     process_cfg_network()
259     c.server = next(c.network.hosts())
260
261 class ServerAddr():
262   def __init__(self, port, addrspec):
263     self.port = port
264     # also self.addr
265     try:
266       self.addr = ipaddress.IPv4Address(addrspec)
267       self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
268       self._inurl = b'%s'
269     except AddressValueError:
270       self.addr = ipaddress.IPv6Address(addrspec)
271       self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
272       self._inurl = b'[%s]'
273   def make_endpoint(self):
274     return self._endpointfactory(reactor, self.port, self.addr)
275   def url(self):
276     url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
277     if self.port != 80: url += b':%d' % self.port
278     url += b'/'
279     return url
280     
281 def process_cfg_saddrs():
282   try: port = cfg.getint('server','port')
283   except NoOptionError: port = 80
284
285   c.saddrs = [ ]
286   for addrspec in cfg.get('server','addrs').split():
287     sa = ServerAddr(port, addrspec)
288     c.saddrs.append(sa)
289
290 def process_cfg_clients(constructor):
291   c.clients = [ ]
292   for cs in cfg.sections():
293     if not (':' in cs or '.' in cs): continue
294     ci = ipaddr(cs)
295     pw = cfg.get(cs, 'password')
296     pw = pw.encode('utf-8')
297     constructor(ci,cs,pw)
298
299 #---------- startup ----------
300
301 def common_startup():
302   twisted.logger.globalLogPublisher.addObserver(crash_on_critical)
303
304   optparser.add_option('-c', '--config', dest='configfile',
305                        default='/etc/hippotat/config')
306   (opts, args) = optparser.parse_args()
307   if len(args): optparser.error('no non-option arguments please')
308
309   re = regexp.compile('#.*')
310   cfg.read_string(re.sub('', defcfg))
311   cfg.read(opts.configfile)
312
313 def common_run():
314   reactor.run()
315   print('CRASHED (end)', file=sys.stderr)