chiark / gitweb /
break out SlipProtocol
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 import twisted
9 from twisted.internet import reactor
10 from twisted.logger import LogLevel
11 import twisted.internet.endpoints
12
13 import ipaddress
14 from ipaddress import AddressValueError
15
16 from optparse import OptionParser
17 from configparser import ConfigParser
18 from configparser import NoOptionError
19
20 import collections
21
22 import re as regexp
23
24 import hippotat.slip as slip
25
26 defcfg = '''
27 [DEFAULT]
28 #[<client>] overrides
29 max_batch_down = 65536           # used by server, subject to [limits]
30 max_queue_time = 10              # used by server, subject to [limits]
31 max_request_time = 54            # used by server, subject to [limits]
32 target_requests_outstanding = 3  # must match; subject to [limits] on server
33 max_requests_outstanding = 4     # used by client
34 max_batch_up = 4000              # used by client
35 http_timeout = 30                # used by client
36
37 #[server] or [<client>] overrides
38 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
39 # extra interpolations:  %(local)s        %(peer)s          %(rnet)s
40 #  obtained   on server  [virtual]server  [virtual]relay    [virtual]network
41 #      from   on client  <client>         [virtual]server   [virtual]routes
42
43 [virtual]
44 mtu = 1500
45 routes = ''
46 # network = <prefix>/<len>  # mandatory for server
47 # server  = <ipaddr>   # used by both, default is computed from `network'
48 # relay   = <ipaddr>   # used by server, default from `network' and `server'
49 #  default server is first host in network
50 #  default relay is first host which is not server
51
52 [server]
53 # addrs = 127.0.0.1 ::1    # mandatory for server
54 port = 80                  # used by server
55 # url              # used by client; default from first `addrs' and `port'
56
57 # [<client-ip4-or-ipv6-address>]
58 # password = <password>    # used by both, must match
59
60 [limits]
61 max_batch_down = 262144           # used by server
62 max_queue_time = 121              # used by server
63 max_request_time = 121            # used by server
64 target_requests_outstanding = 10  # used by server
65 '''
66
67 # these need to be defined here so that they can be imported by import *
68 cfg = ConfigParser()
69 optparser = OptionParser()
70
71 _mimetrans = str.maketrans(b'-'+slip.esc, slip.esc+'-')
72 def mime_translate(s):
73   # SLIP-encoded packets cannot contain ESC ESC.
74   # Swap `-' and ESC.  The result cannot contain `--'
75   return s.translate(_mimetrans)
76
77 class ConfigResults:
78   def __init__(self, d = { }):
79     self.__dict__ = d
80   def __repr__(self):
81     return 'ConfigResults('+repr(self.__dict__)+')'
82
83 c = ConfigResults()
84
85 def log_discard(packet, saddr, daddr, why):
86   print('DROP ', saddr, daddr, why)
87 #  syslog.syslog(syslog.LOG_DEBUG,
88 #                'discarded packet %s -> %s (%s)' % (saddr, daddr, why))
89
90 #---------- packet parsing ----------
91
92 def packet_addrs(packet):
93   version = packet[0] >> 4
94   if version == 4:
95     addrlen = 4
96     saddroff = 3*4
97     factory = ipaddress.IPv4Address
98   elif version == 6:
99     addrlen = 16
100     saddroff = 2*4
101     factory = ipaddress.IPv6Address
102   else:
103     raise ValueError('unsupported IP version %d' % version)
104   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
105   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
106   return (saddr, daddr)
107
108 #---------- address handling ----------
109
110 def ipaddr(input):
111   try:
112     r = ipaddress.IPv4Address(input)
113   except AddressValueError:
114     r = ipaddress.IPv6Address(input)
115   return r
116
117 def ipnetwork(input):
118   try:
119     r = ipaddress.IPv4Network(input)
120   except NetworkValueError:
121     r = ipaddress.IPv6Network(input)
122   return r
123
124 #---------- ipif (SLIP) subprocess ----------
125
126 class SlipProtocol(twisted.internet.protocol.ProcessProtocol):
127   # caller must define method receivedPacket(packet)
128   def __init__(self):
129     self._buffer = b''
130   def connectionMade(self): pass
131   def outReceived(self, data):
132     #print('SLIP-GOT ', repr(data))
133     self._buffer += data
134     packets = slip.decode(self._buffer)
135     self._buffer = packets.pop()
136     for packet in packets:
137       if not len(packet): continue
138       self.receivedPacket(packet)
139   def flush(self):
140     if len(self._buffer):
141       self.receivedPacket(self._buffer)
142       self._buffer = ''
143
144 class _IpifProcessProtocol(SlipProtocol):
145   def __init__(self, router):
146     self._router = router
147     super.__init__()
148   def receivedPacket(self, packet):
149     (saddr, daddr) = packet_addrs(packet)
150     if saddr.is_link_local or daddr.is_link_local:
151       log_discard(packet, saddr, daddr, 'link-local')
152       return
153     self._router(packet, saddr, daddr)
154   def processEnded(self, status):
155     status.raiseException()
156
157 def start_ipif(command, router):
158   global ipif
159   ipif = _IpifProcessProtocol(router)
160   reactor.spawnProcess(ipif,
161                        '/bin/sh',['sh','-xc', command],
162                        childFDs={0:'w', 1:'r', 2:2})
163
164 def queue_inbound(packet):
165   ipif.transport.write(slip.delimiter)
166   ipif.transport.write(slip.encode(packet))
167   ipif.transport.write(slip.delimiter)
168
169 #---------- packet queue ----------
170
171 class PacketQueue():
172   def __init__(self, max_queue_time):
173     self._max_queue_time = max_queue_time
174     self._pq = collections.deque() # packets
175
176   def append(self, packet):
177     self._pq.append((time.monotonic(), packet))
178
179   def nonempty(self):
180     while True:
181       try: (queuetime, packet) = self._pq[0]
182       except IndexError: return False
183
184       age = time.monotonic() - queuetime
185       if age > self.max_queue_time:
186         # strip old packets off the front
187         self._pq.popleft()
188         continue
189
190       return True
191
192   def process(self, sizequery, moredata, max_batch):
193     # sizequery() should return size of batch so far
194     # moredata(s) should add s to batch
195     while True:
196       try: (dummy, packet) = self._pq[0]
197       except IndexError: break
198
199       encoded = slip.encode(packet)
200       sofar = sizequery()  
201
202       if sofar > 0:
203         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
204           break
205         moredata(slip.delimiter)
206
207       moredata(encoded)
208       self._pq.popLeft()
209
210 #---------- error handling ----------
211
212 def crash(err):
213   print('CRASH ', err, file=sys.stderr)
214   try: reactor.stop()
215   except twisted.internet.error.ReactorNotRunning: pass
216
217 def crash_on_defer(defer):
218   defer.addErrback(lambda err: crash(err))
219
220 vdef crash_on_critical(event):
221   if event.get('log_level') >= LogLevel.critical:
222     crash(twisted.logger.formatEvent(event))
223
224 #---------- config processing ----------
225
226 def process_cfg_common_always():
227   global mtu
228   c.mtu = cfg.get('virtual','mtu')
229
230 def process_cfg_ipif(section, varmap):
231   for d, s in varmap:
232     try: v = getattr(c, s)
233     except AttributeError: continue
234     setattr(c, d, v)
235
236   print(repr(c))
237
238   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
239
240 def process_cfg_network():
241   c.network = ipnetwork(cfg.get('virtual','network'))
242   if c.network.num_addresses < 3 + 2:
243     raise ValueError('network needs at least 2^3 addresses')
244
245 def process_cfg_server():
246   try:
247     c.server = cfg.get('virtual','server')
248   except NoOptionError:
249     process_cfg_network()
250     c.server = next(c.network.hosts())
251
252 class ServerAddr():
253   def __init__(self, port, addrspec):
254     self.port = port
255     # also self.addr
256     try:
257       self.addr = ipaddress.IPv4Address(addrspec)
258       self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
259       self._inurl = '%s'
260     except AddressValueError:
261       self.addr = ipaddress.IPv6Address(addrspec)
262       self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
263       self._inurl = '[%s]'
264   def make_endpoint(self):
265     return self._endpointfactory(reactor, self.port, self.addr)
266   def url(self):
267     url = 'http://' + (self._inurl % self.addr)
268     if self.port != 80: url += ':%d' % self.port
269     url += '/'
270     return url
271     
272 def process_cfg_saddrs():
273   try: port = cfg.getint('server','port')
274   except NoOptionError: port = 80
275
276   c.saddrs = [ ]
277   for addrspec in cfg.get('server','addrs').split():
278     sa = ServerAddr(port, addrspec)
279     c.saddrs.append(sa)
280
281 def process_cfg_clients(constructor):
282   c.clients = [ ]
283   for cs in cfg.sections():
284     if not (':' in cs or '.' in cs): continue
285     ci = ipaddr(cs)
286     pw = cfg.get(cs, 'password')
287     constructor(ci,cs,pw)
288
289 #---------- startup ----------
290
291 def common_startup():
292   twisted.logger.globalLogPublisher.addObserver(crash_on_critical)
293
294   optparser.add_option('-c', '--config', dest='configfile',
295                        default='/etc/hippotat/config')
296   (opts, args) = optparser.parse_args()
297   if len(args): optparser.error('no non-option arguments please')
298
299   re = regexp.compile('#.*')
300   cfg.read_string(re.sub('', defcfg))
301   cfg.read(opts.configfile)
302
303 def common_run():
304   reactor.run()
305   print('CRASHED (end)', file=sys.stderr)