chiark / gitweb /
start splitting up
[hippotat.git] / server
#!/usr/bin/python3

import signal
signal.signal(signal.SIGINT, signal.SIG_DFL)

import sys
import os
import hmac
import ipaddress
import time

import twisted
import twisted.internet
import twisted.internet.endpoints
from twisted.internet import reactor
from twisted.web.server import NOT_DONE_YET
from twisted.logger import LogLevel

#import twisted.web.server import Site
#from twisted.web.resource import Resource

from optparse import OptionParser
from configparser import ConfigParser
from configparser import NoOptionError

import collections

import syslog

from hippotit import *
28
# Client objects, indexed by the client address (ipaddr of the config
# section name); filled in by process_cfg().
clients = {}

# Built-in defaults; startup() reads these before the configuration file,
# so the file can override any of them.
defcfg = '''
[DEFAULT]
max_batch_down = 65536
max_queue_time = 10
max_request_time = 54

[virtual]
mtu = 1500
# network
# [host]
# [relay]

[server]
ipif = userv root ipif %(host)s,%(relay)s,%(mtu)s,slip %(network)s
addrs = 127.0.0.1 ::1
port = 8099

[limits]
max_batch_down = 262144
max_queue_time = 121
max_request_time = 121
'''
53
54 #---------- error handling ----------
55
def crash(err):
  """Report *err* as fatal on stderr and stop the reactor.

  ReactorNotRunning from reactor.stop() is ignored, so calling this when
  the reactor is already stopped (or not yet started) is harmless.
  """
  print('CRASH ', err, file=sys.stderr)
  try:
    reactor.stop()
  except twisted.internet.error.ReactorNotRunning:
    pass
60
def crash_on_defer(defer):
  """Make any failure of *defer* (something with addErrback) fatal via crash()."""
  defer.addErrback(crash)
63
def crash_on_critical(event):
  """Log observer: crash the server on any event at LogLevel.critical or above.

  Events that carry no ``log_level`` are ignored. Comparing None against a
  LogLevel would raise TypeError inside the log observer.
  """
  level = event.get('log_level')
  if level is not None and level >= LogLevel.critical:
    crash(twisted.logger.formatEvent(event))
67
68 #---------- "router" ----------
69
def route(packet, saddr, daddr):
  """Deliver one packet according to its destination address.

  In order of precedence:
    - a configured client: queued for that client;
    - link-local source or destination: dropped;
    - our own host address, or anything outside the virtual network:
      handed to the ipif subprocess;
    - the relay address: dropped;
    - anything else inside the network: dropped (no such client).
  """
  print('TRACE ', saddr, daddr, packet)
  # The original bound the lookup result to `client` but tested `dclient`,
  # so a successful lookup raised UnboundLocalError.
  dclient = clients.get(daddr)
  if dclient is not None:
    dclient.queue_outbound(packet)
  elif saddr.is_link_local or daddr.is_link_local:
    log_discard(packet, saddr, daddr, 'link-local')
  elif daddr == host or daddr not in network:
    print('TRACE INBOUND ', saddr, daddr, packet)
    queue_inbound(packet)
  elif daddr == relay:
    log_discard(packet, saddr, daddr, 'relay')
  else:
    log_discard(packet, saddr, daddr, 'no client')
85
def log_discard(packet, saddr, daddr, why):
  """Report that a packet from *saddr* to *daddr* was dropped, and *why*.

  Only the addresses and reason are printed, not the packet contents.
  """
  details = (saddr, daddr, why)
  print('DROP ', *details)
  # TODO: report via syslog.syslog(syslog.LOG_DEBUG, ...) instead of stdout
90
91 #---------- ipif (slip subprocess) ----------
92
class IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
  """Protocol attached to the ipif subprocess, which speaks SLIP.

  Data read from the child's stdout is accumulated and split with
  slip.decode(); each complete, non-empty frame is routed as a packet.
  """

  def __init__(self):
    # received bytes not yet consumed as complete frames
    self._buffer = b''

  def connectionMade(self):
    pass

  def outReceived(self, data):
    self._buffer = self._buffer + data
    frames = slip.decode(self._buffer)
    # presumably the last element is the incomplete trailing frame
    # (possibly empty); keep it for next time
    self._buffer = frames.pop()
    for frame in frames:
      if len(frame) == 0:
        continue
      saddr, daddr = packet_addrs(frame)
      route(frame, saddr, daddr)

  def processEnded(self, status):
    # re-raise the child's exit reason
    status.raiseException()
108
def start_ipif():
  """Spawn ipif_command under ``sh -xc`` with an IpifProcessProtocol.

  The protocol is stored in the module global ``ipif`` (used by
  queue_inbound). The parent writes the child's stdin and reads its
  stdout; the child's stderr is our fd 2.
  """
  global ipif
  ipif = IpifProcessProtocol()
  argv = ['sh', '-xc', ipif_command]
  fds = {0: 'w', 1: 'r', 2: 2}
  reactor.spawnProcess(ipif, '/bin/sh', argv, childFDs=fds)
115
def queue_inbound(packet):
  """Write *packet* to the ipif subprocess as a SLIP frame.

  A SLIP delimiter is written both before and after the encoded packet.
  """
  out = ipif.transport
  out.write(slip.delimiter)
  out.write(slip.encode(packet))
  out.write(slip.delimiter)
120
121 #---------- client ----------
122
class Client():
  """State for one configured client.

  Holds the client's settings plus two FIFO queues: HTTP requests waiting
  to carry data down to the client, and packets waiting to go down.

  (In the original, every method after __init__ was indented inside
  __init__, so the class effectively had no methods.)
  """

  def __init__(self, ip, cs):
    # instance data members
    self._ip = ip
    self._cs = cs
    self.pw = cfg.get(cs, 'password')
    self._rq = collections.deque() # requests
    self._pq = collections.deque() # (time.monotonic() when queued, packet)
    # plus from config, each capped by the [limits] section:
    #  .max_batch_down
    #  .max_queue_time
    #  .max_request_time
    for k in ('max_batch_down','max_queue_time','max_request_time'):
      req = cfg.getint(cs, k)
      limit = cfg.getint('limits',k)
      setattr(self, k, min(req, limit))

  def process_arriving_data(self, d):
    """Decode SLIP data sent by the client and route each packet.

    Empty frames are skipped, as IpifProcessProtocol.outReceived does.
    Raises ValueError if a packet's source is not this client's address.
    """
    for packet in slip.decode(d):
      if not packet:
        continue
      (saddr, daddr) = packet_addrs(packet)
      if saddr != self._ip:
        raise ValueError('wrong source address %s' % saddr)
      route(packet, saddr, daddr)

  def _req_cancel(self, request):
    """Timeout for a pending request: finish it with whatever it holds."""
    if not request.finished:
      request.finish()

  def _req_done(self, result, request, timeout=None):
    """notifyFinish() callback: *request* is finished; forget it.

    Cancels the pending timeout (if still armed) and drops the request from
    the queue (it may already have been dequeued by _check_outbound).
    """
    if timeout is not None and timeout.active():
      timeout.cancel()
    try:
      self._rq.remove(request)
    except ValueError:
      pass

  def _req_error(self, err, request, timeout=None):
    """notifyFinish() errback: the connection was lost before we finished.

    The request must not be finished: Twisted raises RuntimeError for
    finish() after connection loss. Returning None consumes the failure so
    it is not reported as an unhandled (critical) Deferred error.
    """
    self._req_done(None, request, timeout)

  def queue_outbound(self, packet):
    """Queue *packet* for the client, sending it at once if a request waits."""
    self._pq.append((time.monotonic(), packet))
    self._check_outbound()

  def http_request(self, request):
    """Accept a client HTTP request to be used for carrying packets down.

    The request is held until there are packets for it or until
    max_request_time seconds have passed.
    """
    request.setHeader('Content-Type','application/octet-stream')
    timeout = reactor.callLater(self.max_request_time, self._req_cancel, request)
    request.notifyFinish().addCallbacks(
      self._req_done, self._req_error,
      callbackArgs=(request, timeout), errbackArgs=(request, timeout))
    self._rq.append(request)
    self._check_outbound()

  def _check_outbound(self):
    """Hand queued packets to waiting requests.

    Packets older than max_queue_time are discarded. Each waiting request
    gets as many queued packets as fit within max_batch_down (always at
    least one), SLIP-delimited, and is then finished.
    """
    while True:
      try: request = self._rq[0]
      except IndexError: request = None
      if request is not None and request.finished:
        self._rq.popleft()
        continue

      # now request is an unfinished request, or None
      try: (queuetime, packet) = self._pq[0]
      except IndexError:
        # no packets, oh well
        break

      age = time.monotonic() - queuetime
      if age > self.max_queue_time:
        self._pq.popleft()
        continue

      if request is None:
        # no request
        break

      # request, and also some non-expired packets
      while True:
        try: (_, packet) = self._pq[0]
        except IndexError: break

        encoded = slip.encode(packet)

        if request.sentLength > 0:
          if (request.sentLength + len(slip.delimiter)
              + len(encoded) > self.max_batch_down):
            break
          request.write(slip.delimiter)

        request.write(encoded)
        self._pq.popleft()

      assert(request.sentLength)
      self._rq.popleft()
      request.finish()
      # round again, looking for more to do
206
class IphttpResource(twisted.web.resource.Resource):
  """The HTTP interface used by clients.

  A POST names the client (``i``) and authenticates it (``pw``). It may
  adjust the client's tuning (``mbd``, ``mqt``, ``mrt``, each capped by
  the server's [limits]) and may carry SLIP data from the client (``d``).
  The request is then handed to the Client, which holds it open until
  there is data to send down or it times out.
  """
  isLeaf = True

  @staticmethod
  def _arg(request, name, default=None):
    """Return the first value of request argument *name*, or *default*.

    Twisted's request.args maps bytes names to lists of bytes values.
    """
    values = request.args.get(name.encode('ascii'))
    if not values:
      return default
    return values[0]

  def _client(self, request):
    """Return the authenticated Client for *request*.

    Raises KeyError for an unknown client, and ValueError for a missing or
    malformed address or a bad password.
    """
    i = self._arg(request, 'i')
    if i is None:
      raise ValueError('missing client address')
    c = clients[ipaddr(i.decode('ascii'))]
    pw = self._arg(request, 'pw')
    # constant-time comparison, so response timing does not leak the password
    if pw is None or not hmac.compare_digest(pw, c.pw.encode('utf-8')):
      raise ValueError('bad password')
    return c

  def render_POST(self, request):
    try:
      c = self._client(request)

      # update config, never beyond the server's [limits]
      for r, w in (('mbd', 'max_batch_down'),
                   ('mqt', 'max_queue_time'),
                   ('mrt', 'max_request_time')):
        v = self._arg(request, r)
        if v is None: continue
        setattr(c, w, min(int(v), cfg.getint('limits', w)))

      c.process_arriving_data(self._arg(request, 'd', b''))
    except (KeyError, ValueError) as e:
      # Bad client input gets an HTTP error rather than an exception.
      # An exception escaping render would be logged as a critical
      # failure, which crash_on_critical turns into a server shutdown.
      print('BAD REQUEST ', repr(e))
      request.setResponseCode(400)
      return b'bad request\n'

    c.http_request(request)
    # the Client finishes the request later
    return NOT_DONE_YET

  def render_GET(self, request):
    return b'<html><body>hippotit</body></html>'
233
def start_http():
  """Serve IphttpResource on each address in [server] addrs at [server] port.

  Each address is parsed as IPv4 if possible, otherwise as IPv6. A failure
  to listen is fatal (via crash_on_defer).
  """
  site = twisted.web.server.Site(IphttpResource())
  port = cfg.getint('server','port')
  for addrspec in cfg.get('server','addrs').split():
    try:
      addr = ipaddress.IPv4Address(addrspec)
      endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
    except ipaddress.AddressValueError:
      addr = ipaddress.IPv6Address(addrspec)
      endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
    # The third positional parameter of these endpoints is the listen
    # backlog, so the address must be passed as interface= (a string).
    ep = endpointfactory(reactor, port, interface=str(addr))
    crash_on_defer(ep.listen(site))
246
247 #---------- config and setup ----------
248         
def process_cfg():
  """Derive the virtual network layout and the client table from ``cfg``.

  Sets the module globals network, host, relay, mtu and ipif_command, and
  adds a Client to ``clients`` for every config section whose name contains
  ':' or '.'.

  Raises ValueError if the network is too small, a client lies outside it,
  or a client has more than one section.
  """
  global network
  global host
  global relay
  global mtu
  global ipif_command

  network = ipnetwork(cfg.get('virtual','network'))
  # Assuming network sizes are powers of two (as for ipaddress networks),
  # "< 3 + 2" rejects exactly the networks smaller than 2^3 addresses.
  if network.num_addresses < 3 + 2:
    raise ValueError('network needs at least 2^3 addresses')

  # Configured values are converted with ipaddr() so they compare equal to
  # the address objects used elsewhere (route() checks `daddr == host`).
  try:
    host = ipaddr(cfg.get('virtual','host'))
  except NoOptionError:
    host = next(network.hosts())

  try:
    relay = ipaddr(cfg.get('virtual','relay'))
  except NoOptionError:
    for search in network.hosts():
      if search == host: continue
      relay = search
      break

  for cs in cfg.sections():
    if not (':' in cs or '.' in cs): continue
    ci = ipaddr(cs)
    if ci not in network:
      raise ValueError('client %s not in network' % ci)
    if ci in clients:
      raise ValueError('multiple client cfg sections for %s' % ci)
    clients[ci] = Client(ci, cs)

  mtu = cfg.get('virtual','mtu')

  # ConfigParser converts these to str for %(...)s interpolation
  iic_vars = {
    'host': host,
    'relay': relay,
    'mtu': mtu,
    'network': network,
  }

  ipif_command = cfg.get('server','ipif', vars=iic_vars)
289
def startup():
  """Parse the command line, load configuration, start ipif and HTTP.

  Sets the module globals ``opts`` and ``cfg``. Exits with a usage error
  if there are non-option arguments or the config file cannot be read.
  """
  global cfg
  global opts

  op = OptionParser()
  op.add_option('-c', '--config', dest='configfile',
                default='/etc/hippottd/server.conf')
  (opts, args) = op.parse_args()
  if args: op.error('no non-option arguments please')

  twisted.logger.globalLogPublisher.addObserver(crash_on_critical)

  cfg = ConfigParser()
  cfg.read_string(defcfg)
  # ConfigParser.read() silently skips unreadable files. Without this check
  # a missing file only shows up later, as a confusing NoOptionError from
  # process_cfg().
  if not cfg.read(opts.configfile):
    op.error('cannot read config file %s' % opts.configfile)
  process_cfg()

  start_ipif()
  start_http()
309
# Only start the server when run as a script, not when imported.
if __name__ == '__main__':
  startup()
  reactor.run()
  # reactor.run() returns only after the reactor has stopped (e.g. via crash())
  print('CRASHED (end)', file=sys.stderr)