chiark / gitweb /
e813ca0b650e24f9dfe6a263f6b0b12c01d5517f
[hippotat.git] / server
1 #!/usr/bin/python3
2
3 import sys
4 import os
5
6 import twisted
7 import twisted.internet
8 import twisted.internet.endpoints
9 from twisted.internet import reactor
10 from twisted.web.server import NOT_DONE_YET
11 from twisted.logger import LogLevel
12
13 import ipaddress
14 from ipaddress import AddressValueError
15
16 #import twisted.web.server import Site
17 #from twisted.web.resource import Resource
18
19 from optparse import OptionParser
20 from configparser import ConfigParser
21 from configparser import NoOptionError
22
23 import collections
24
25 import syslog
26
27 clients = { }
28
29 def ipaddr(input):
30   try:
31     r = ipaddress.IPv4Address(input)
32   except AddressValueError:
33     r = ipaddress.IPv6Address(input)
34   return r
35
36 def ipnetwork(input):
37   try:
38     r = ipaddress.IPv4Network(input)
39   except NetworkValueError:
40     r = ipaddress.IPv6Network(input)
41   return r
42
43 defcfg = '''
44 [DEFAULT]
45 max_batch_down = 65536
46 max_queue_time = 10
47 max_request_time = 54
48
49 [virtual]
50 mtu = 1500
51 # network
52 # [host]
53 # [relay]
54
55 [server]
56 ipif = userv root ipif %(host)s,%(relay)s,%(mtu)s,slip %(network)s
57 addrs = 127.0.0.1 ::1
58 port = 80
59
60 [limits]
61 max_batch_down = 262144
62 max_queue_time = 121
63 max_request_time = 121
64 '''
65
66 #---------- "router" ----------
67
68 def route(packet, saddr, daddr):
69   print('TRACE ', saddr, daddr, packet)
70   try: client = clients[daddr]
71   except KeyError: dclient = None
72   if dclient is not None:
73     dclient.queue_outbound(packet)
74   elif saddr.is_link_local or daddr.is_link_local:
75     log_discard(packet, saddr, daddr, 'link-local')
76   elif daddr == host or daddr not in network:
77     print('TRACE INBOUND ', saddr, daddr, packet)
78     queue_inbound(packet)
79   elif daddr == relay:
80     log_discard(packet, saddr, daddr, 'relay')
81   else:
82     log_discard(packet, saddr, daddr, 'no client')
83
84 def log_discard(packet, saddr, daddr, why):
85   print('DROP ', saddr, daddr, why)
86 #  syslog.syslog(syslog.LOG_DEBUG,
87 #                'discarded packet %s -> %s (%s)' % (saddr, daddr, why))
88
89 #---------- ipif (slip subprocess) ----------
90
91 class IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
92   def __init__(self):
93     self._buffer = b''
94   def connectionMade(self): pass
95   def outReceived(self, data):
96     #print('RECV ', repr(data))
97     self._buffer += data
98     packets = slip_decode(self._buffer)
99     self._buffer = packets.pop()
100     for packet in packets:
101       if not len(packet): continue
102       (saddr, daddr) = packet_addrs(packet)
103       route(packet, saddr, daddr)
104   def processEnded(self, status):
105     status.raiseException()
106
107 def start_ipif():
108   global ipif
109   ipif = IpifProcessProtocol()
110   reactor.spawnProcess(ipif,
111                        '/bin/sh',['sh','-xc', ipif_command],
112                        childFDs={0:'w', 1:'r', 2:2})
113
114 def queue_inbound(packet):
115   ipif.transport.write(slip_delimiter)
116   ipif.transport.write(slip_encode(packet))
117   ipif.transport.write(slip_delimiter)
118
119 #---------- SLIP handling ----------
120
121 slip_end = b'\300'
122 slip_esc = b'\333'
123 slip_esc_end = b'\334'
124 slip_esc_esc = b'\335'
125 slip_delimiter = slip_end
126
127 def slip_encode(packet):
128   return (packet
129           .replace(slip_esc, slip_esc + slip_esc_esc)
130           .replace(slip_end, slip_esc + slip_esc_end))
131
132 def slip_decode(data):
133   print('DECODE ', repr(data))
134   out = []
135   for packet in data.split(slip_end):
136     pdata = b''
137     while True:
138       eix = packet.find(slip_esc)
139       if eix == -1:
140         pdata += packet
141         break
142       #print('ESC ', repr((pdata, packet, eix)))
143       pdata += packet[0 : eix]
144       ck = packet[eix+1]
145       if   ck == slip_esc_esc: pdata += slip_esc
146       elif ck == slip_esc_end: pdata += slip_end
147       else: raise ValueError('invalid SLIP escape')
148       packet = packet[eix+2 : ]
149     out.append(pdata)
150   print('DECODED ', repr(out))
151   return out
152
153 #---------- packet parsing ----------
154
155 def packet_addrs(packet):
156   version = packet[0] >> 4
157   if version == 4:
158     addrlen = 4
159     saddroff = 3*4
160     factory = ipaddress.IPv4Address
161   elif version == 6:
162     addrlen = 16
163     saddroff = 2*4
164     factory = ipaddress.IPv6Address
165   else:
166     raise ValueError('unsupported IP version %d' % version)
167   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
168   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
169   return (saddr, daddr)
170
171 #---------- client ----------
172
173 class Client():
174   def __init__(self, ip, cs):
175     # instance data members
176     self._ip = ip
177     self._cs = cs
178     self.pw = cfg.get(cs, 'password')
179     self._rq = collections.deque() # requests
180     self._pq = collections.deque() # packets
181     # plus from config:
182     #  .max_batch_down
183     #  .max_queue_time
184     #  .max_request_time
185     for k in ('max_batch_down','max_queue_time','max_request_time'):
186       req = cfg.getint(cs, k)
187       limit = cfg.getint('limits',k)
188       self.__dict__[k] = min(req, limit)
189
190     def process_arriving_data(self, d):
191       for packet in slip_decode(d):
192         (saddr, daddr) = packet_addrs(packet)
193         if saddr != self._ip:
194           raise ValueError('wrong source address %s' % saddr)
195         route(packet, saddr, daddr)
196
197     def _req_cancel(self, request):
198       request.finish()
199
200     def _req_error(self, err, request):
201       self._req_cancel(request)
202
203     def queue_outbound(self, packet):
204       self._pq.append((time.monotonic(), packet))
205
206     def http_request(self, request):
207       request.setHeader('Content-Type','application/octet-stream')
208       reactor.callLater(self.max_request_time, self._req_cancel, request)
209       request.notifyFinish().addErrback(self._req_error, request)
210       self._rq.append(request)
211       self._check_outbound()
212
213     def _check_outbound(self):
214       while True:
215         try: request = self._rq[0]
216         except IndexError: request = None
217         if request and request.finished:
218           self._rq.popleft()
219           continue
220
221         # now request is an unfinished request, or None
222         try: (queuetime, packet) = self._pq[0]
223         except IndexError:
224           # no packets, oh well
225           break
226
227         age = time.monotonic() - queuetime
228         if age > self.max_queue_time:
229           self._pq.popleft()
230           continue
231
232         if request is None:
233           # no request
234           break
235
236         # request, and also some non-expired packets
237         while True:
238           try: (dummy, packet) = self._pq[0]
239           except IndexError: break
240
241           encoded = slip_encode(packet)
242           
243           if request.sentLength > 0:
244             if (request.sentLength + len(slip_delimiter)
245                 + len(encoded) > self.max_batch_down):
246               break
247             request.write(slip_delimiter)
248
249           request.write(encoded)
250           self._pq.popLeft()
251
252         assert(request.sentLength)
253         self._rq.popLeft()
254         request.finish()
255         # round again, looking for more to do
256
257 class IphttpResource(twisted.web.resource.Resource):
258   def render_POST(self, request):
259     # find client, update config, etc.
260     ci = ipaddr(request.args['i'])
261     c = clients[ci]
262     pw = request.args['pw']
263     if pw != c.pw: raise ValueError('bad password')
264
265     # update config
266     for r, w in (('mbd', 'max_batch_down'),
267                  ('mqt', 'max_queue_time'),
268                  ('mrt', 'max_request_time')):
269       try: v = request.args[r]
270       except KeyError: continue
271       v = int(v)
272       c.__dict__[w] = v
273
274     try: d = request.args['d']
275     except KeyError: d = ''
276
277     c.process_arriving_data(d)
278     c.new_request(request)
279
280 def start_http():
281   resource = IphttpResource()
282   sitefactory = twisted.web.server.Site(resource)
283   for addrspec in cfg.get('server','addrs').split():
284     try:
285       addr = ipaddress.IPv4Address(addrspec)
286       endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
287     except AddressValueError:
288       addr = ipaddress.IPv6Address(addrspec)
289       endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
290     ep = endpointfactory(reactor, cfg.getint('server','port'), addr)
291     defer = ep.listen(sitefactory)
292     defer.addErrback(lambda err: err.raiseException())
293
294 #---------- config and setup ----------
295         
296 def process_cfg():
297   global network
298   global host
299   global relay
300   global ipif_command
301
302   network = ipnetwork(cfg.get('virtual','network'))
303   if network.num_addresses < 3 + 2:
304     raise ValueError('network needs at least 2^3 addresses')
305
306   try:
307     host = cfg.get('virtual','host')
308   except NoOptionError:
309     host = next(network.hosts())
310
311   try:
312     relay = cfg.get('virtual','relay')
313   except NoOptionError:
314     for search in network.hosts():
315       if search == host: continue
316       relay = search
317       break
318
319   for cs in cfg.sections():
320     if not (':' in cs or '.' in cs): continue
321     ci = ipaddr(cs)
322     if ci not in network:
323       raise ValueError('client %s not in network' % ci)
324     if ci in clients:
325       raise ValueError('multiple client cfg sections for %s' % ci)
326     clients[ci] = Client(ci, cs)
327
328   global mtu
329   mtu = cfg.get('virtual','mtu')
330
331   iic_vars = { }
332   for k in ('host','relay','mtu','network'):
333     iic_vars[k] = globals()[k]
334
335   ipif_command = cfg.get('server','ipif', vars=iic_vars)
336
337 def crash_on_critical(event):
338   if event.get('log_level') >= LogLevel.critical:
339     print('crashing: ', twisted.logger.formatEvent(event), file=sys.stderr)
340     #print('crashing!', file=sys.stderr)
341     #os._exit(1)
342     try: reactor.stop()
343     except twisted.internet.error.ReactorNotRunning: pass
344
345 def startup():
346   global cfg
347
348   op = OptionParser()
349   op.add_option('-c', '--config', dest='configfile',
350                 default='/etc/hippottd/server.conf')
351   global opts
352   (opts, args) = op.parse_args()
353   if len(args): op.error('no non-option arguments please')
354
355   twisted.logger.globalLogPublisher.addObserver(crash_on_critical)
356
357   cfg = ConfigParser()
358   cfg.read_string(defcfg)
359   cfg.read(opts.configfile)
360   process_cfg()
361
362   start_ipif()
363   start_http()
364
365 startup()
366 reactor.run()