chiark / gitweb /
reorg SlipStreamDecoder again
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 import twisted
9 from twisted.internet import reactor
10 from twisted.logger import LogLevel
11 import twisted.internet.endpoints
12
13 import ipaddress
14 from ipaddress import AddressValueError
15
16 from optparse import OptionParser
17 from configparser import ConfigParser
18 from configparser import NoOptionError
19
20 import collections
21
22 import re as regexp
23
24 import hippotat.slip as slip
25
26 defcfg = '''
27 [DEFAULT]
28 #[<client>] overrides
29 max_batch_down = 65536           # used by server, subject to [limits]
30 max_queue_time = 10              # used by server, subject to [limits]
31 max_request_time = 54            # used by server, subject to [limits]
32 target_requests_outstanding = 3  # must match; subject to [limits] on server
33 max_requests_outstanding = 4     # used by client
34 max_batch_up = 4000              # used by client
35 http_timeout = 30                # used by client
36
37 #[server] or [<client>] overrides
38 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
39 # extra interpolations:  %(local)s        %(peer)s          %(rnet)s
40 #  obtained   on server  [virtual]server  [virtual]relay    [virtual]network
41 #      from   on client  <client>         [virtual]server   [virtual]routes
42
43 [virtual]
44 mtu = 1500
45 routes = ''
46 # network = <prefix>/<len>  # mandatory for server
47 # server  = <ipaddr>   # used by both, default is computed from `network'
48 # relay   = <ipaddr>   # used by server, default from `network' and `server'
49 #  default server is first host in network
50 #  default relay is first host which is not server
51
52 [server]
53 # addrs = 127.0.0.1 ::1    # mandatory for server
54 port = 80                  # used by server
55 # url              # used by client; default from first `addrs' and `port'
56
57 # [<client-ip4-or-ipv6-address>]
58 # password = <password>    # used by both, must match
59
60 [limits]
61 max_batch_down = 262144           # used by server
62 max_queue_time = 121              # used by server
63 max_request_time = 121            # used by server
64 target_requests_outstanding = 10  # used by server
65 '''
66
67 # these need to be defined here so that they can be imported by import *
68 cfg = ConfigParser()
69 optparser = OptionParser()
70
71 _mimetrans = str.maketrans(b'-'+slip.esc, slip.esc+'-')
72 def mime_translate(s):
73   # SLIP-encoded packets cannot contain ESC ESC.
74   # Swap `-' and ESC.  The result cannot contain `--'
75   return s.translate(_mimetrans)
76
77 class ConfigResults:
78   def __init__(self, d = { }):
79     self.__dict__ = d
80   def __repr__(self):
81     return 'ConfigResults('+repr(self.__dict__)+')'
82
83 c = ConfigResults()
84
85 def log_discard(packet, saddr, daddr, why):
86   print('DROP ', saddr, daddr, why)
87 #  syslog.syslog(syslog.LOG_DEBUG,
88 #                'discarded packet %s -> %s (%s)' % (saddr, daddr, why))
89
90 #---------- packet parsing ----------
91
92 def packet_addrs(packet):
93   version = packet[0] >> 4
94   if version == 4:
95     addrlen = 4
96     saddroff = 3*4
97     factory = ipaddress.IPv4Address
98   elif version == 6:
99     addrlen = 16
100     saddroff = 2*4
101     factory = ipaddress.IPv6Address
102   else:
103     raise ValueError('unsupported IP version %d' % version)
104   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
105   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
106   return (saddr, daddr)
107
108 #---------- address handling ----------
109
110 def ipaddr(input):
111   try:
112     r = ipaddress.IPv4Address(input)
113   except AddressValueError:
114     r = ipaddress.IPv6Address(input)
115   return r
116
117 def ipnetwork(input):
118   try:
119     r = ipaddress.IPv4Network(input)
120   except NetworkValueError:
121     r = ipaddress.IPv6Network(input)
122   return r
123
124 #---------- ipif (SLIP) subprocess ----------
125
126 class SlipStreamDecoder():
127   def __init__(self, on_packet):
128     # we will call packet(<packet>)
129     self._buffer = b''
130     self._on_packet = on_packet
131
132   def inputdata(self, data):
133     #print('SLIP-GOT ', repr(data))
134     self._buffer += data
135     packets = slip.decode(self._buffer)
136     self._buffer = packets.pop()
137     for packet in packets:
138       self._maybe_packet(packet)
139
140   def _maybe_packet(self, packet):
141       if len(packet):
142         self._on_packet(packet)
143
144   def flush(self):
145     self._maybe_packet(self._buffer)
146     self._buffer = b''
147
148 class _IpifProcessProtocol(SlipProtocol):
149   def __init__(self, router):
150     self._router = router
151     self._decoder = SlipStreamDecoder(self.slip_on_packet)
152   def connectionMade(self): pass
153   def outReceived(self, data):
154     self._decoder.inputdata(data)
155   def slip_on_packet(self, packet):
156     (saddr, daddr) = packet_addrs(packet)
157     if saddr.is_link_local or daddr.is_link_local:
158       log_discard(packet, saddr, daddr, 'link-local')
159       return
160     self._router(packet, saddr, daddr)
161   def processEnded(self, status):
162     status.raiseException()
163
164 def start_ipif(command, router):
165   global ipif
166   ipif = _IpifProcessProtocol(router)
167   reactor.spawnProcess(ipif,
168                        '/bin/sh',['sh','-xc', command],
169                        childFDs={0:'w', 1:'r', 2:2})
170
171 def queue_inbound(packet):
172   ipif.transport.write(slip.delimiter)
173   ipif.transport.write(slip.encode(packet))
174   ipif.transport.write(slip.delimiter)
175
176 #---------- packet queue ----------
177
178 class PacketQueue():
179   def __init__(self, max_queue_time):
180     self._max_queue_time = max_queue_time
181     self._pq = collections.deque() # packets
182
183   def append(self, packet):
184     self._pq.append((time.monotonic(), packet))
185
186   def nonempty(self):
187     while True:
188       try: (queuetime, packet) = self._pq[0]
189       except IndexError: return False
190
191       age = time.monotonic() - queuetime
192       if age > self.max_queue_time:
193         # strip old packets off the front
194         self._pq.popleft()
195         continue
196
197       return True
198
199   def process(self, sizequery, moredata, max_batch):
200     # sizequery() should return size of batch so far
201     # moredata(s) should add s to batch
202     while True:
203       try: (dummy, packet) = self._pq[0]
204       except IndexError: break
205
206       encoded = slip.encode(packet)
207       sofar = sizequery()  
208
209       if sofar > 0:
210         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
211           break
212         moredata(slip.delimiter)
213
214       moredata(encoded)
215       self._pq.popLeft()
216
217 #---------- error handling ----------
218
219 def crash(err):
220   print('CRASH ', err, file=sys.stderr)
221   try: reactor.stop()
222   except twisted.internet.error.ReactorNotRunning: pass
223
224 def crash_on_defer(defer):
225   defer.addErrback(lambda err: crash(err))
226
227 vdef crash_on_critical(event):
228   if event.get('log_level') >= LogLevel.critical:
229     crash(twisted.logger.formatEvent(event))
230
231 #---------- config processing ----------
232
233 def process_cfg_common_always():
234   global mtu
235   c.mtu = cfg.get('virtual','mtu')
236
237 def process_cfg_ipif(section, varmap):
238   for d, s in varmap:
239     try: v = getattr(c, s)
240     except AttributeError: continue
241     setattr(c, d, v)
242
243   print(repr(c))
244
245   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
246
247 def process_cfg_network():
248   c.network = ipnetwork(cfg.get('virtual','network'))
249   if c.network.num_addresses < 3 + 2:
250     raise ValueError('network needs at least 2^3 addresses')
251
252 def process_cfg_server():
253   try:
254     c.server = cfg.get('virtual','server')
255   except NoOptionError:
256     process_cfg_network()
257     c.server = next(c.network.hosts())
258
259 class ServerAddr():
260   def __init__(self, port, addrspec):
261     self.port = port
262     # also self.addr
263     try:
264       self.addr = ipaddress.IPv4Address(addrspec)
265       self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
266       self._inurl = '%s'
267     except AddressValueError:
268       self.addr = ipaddress.IPv6Address(addrspec)
269       self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
270       self._inurl = '[%s]'
271   def make_endpoint(self):
272     return self._endpointfactory(reactor, self.port, self.addr)
273   def url(self):
274     url = 'http://' + (self._inurl % self.addr)
275     if self.port != 80: url += ':%d' % self.port
276     url += '/'
277     return url
278     
279 def process_cfg_saddrs():
280   try: port = cfg.getint('server','port')
281   except NoOptionError: port = 80
282
283   c.saddrs = [ ]
284   for addrspec in cfg.get('server','addrs').split():
285     sa = ServerAddr(port, addrspec)
286     c.saddrs.append(sa)
287
288 def process_cfg_clients(constructor):
289   c.clients = [ ]
290   for cs in cfg.sections():
291     if not (':' in cs or '.' in cs): continue
292     ci = ipaddr(cs)
293     pw = cfg.get(cs, 'password')
294     constructor(ci,cs,pw)
295
296 #---------- startup ----------
297
298 def common_startup():
299   twisted.logger.globalLogPublisher.addObserver(crash_on_critical)
300
301   optparser.add_option('-c', '--config', dest='configfile',
302                        default='/etc/hippotat/config')
303   (opts, args) = optparser.parse_args()
304   if len(args): optparser.error('no non-option arguments please')
305
306   re = regexp.compile('#.*')
307   cfg.read_string(re.sub('', defcfg))
308   cfg.read(opts.configfile)
309
310 def common_run():
311   reactor.run()
312   print('CRASHED (end)', file=sys.stderr)