chiark / gitweb /
wip
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 import twisted
9 from twisted.internet import reactor
10 from twisted.logger import LogLevel
11 import twisted.internet.endpoints
12
13 import ipaddress
14 from ipaddress import AddressValueError
15
16 from optparse import OptionParser
17 from configparser import ConfigParser
18 from configparser import NoOptionError
19
20 import collections
21
22 import re as regexp
23
24 import hippotat.slip as slip
25
26 defcfg = '''
27 [DEFAULT]
28 #[<client>] overrides
29 max_batch_down = 65536           # used by server, subject to [limits]
30 max_queue_time = 10              # used by server, subject to [limits]
31 max_request_time = 54            # used by server, subject to [limits]
32 target_requests_outstanding = 3  # must match; subject to [limits] on server
33 max_requests_outstanding = 4     # used by client
34 max_batch_up = 4000              # used by client
35 http_timeout = 30                # used by client
36
37 #[server] or [<client>] overrides
38 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
39 # extra interpolations:  %(local)s        %(peer)s          %(rnet)s
40 #  obtained   on server  [virtual]server  [virtual]relay    [virtual]network
41 #      from   on client  <client>         [virtual]server   [virtual]routes
42
43 [virtual]
44 mtu = 1500
45 routes = ''
46 # network = <prefix>/<len>  # mandatory for server
47 # server  = <ipaddr>   # used by both, default is computed from `network'
48 # relay   = <ipaddr>   # used by server, default from `network' and `server'
49 #  default server is first host in network
50 #  default relay is first host which is not server
51
52 [server]
53 # addrs = 127.0.0.1 ::1    # mandatory for server
54 port = 80                  # used by server
55 # url              # used by client; default from first `addrs' and `port'
56
57 # [<client-ip4-or-ipv6-address>]
58 # password = <password>    # used by both, must match
59
60 [limits]
61 max_batch_down = 262144           # used by server
62 max_queue_time = 121              # used by server
63 max_request_time = 121            # used by server
64 target_requests_outstanding = 10  # used by server
65 '''
66
67 # these need to be defined here so that they can be imported by import *
68 cfg = ConfigParser()
69 optparser = OptionParser()
70
71 _mimetrans = str.maketrans(b'-'+slip.esc, slip.esc+'-')
72 def mime_translate(s):
73   # SLIP-encoded packets cannot contain ESC ESC.
74   # Swap `-' and ESC.  The result cannot contain `--'
75   return s.translate(_mimetrans)
76
77 class ConfigResults:
78   def __init__(self, d = { }):
79     self.__dict__ = d
80   def __repr__(self):
81     return 'ConfigResults('+repr(self.__dict__)+')'
82
83 c = ConfigResults()
84
85 def log_discard(packet, saddr, daddr, why):
86   print('DROP ', saddr, daddr, why)
87 #  syslog.syslog(syslog.LOG_DEBUG,
88 #                'discarded packet %s -> %s (%s)' % (saddr, daddr, why))
89
90 #---------- packet parsing ----------
91
92 def packet_addrs(packet):
93   version = packet[0] >> 4
94   if version == 4:
95     addrlen = 4
96     saddroff = 3*4
97     factory = ipaddress.IPv4Address
98   elif version == 6:
99     addrlen = 16
100     saddroff = 2*4
101     factory = ipaddress.IPv6Address
102   else:
103     raise ValueError('unsupported IP version %d' % version)
104   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
105   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
106   return (saddr, daddr)
107
108 #---------- address handling ----------
109
110 def ipaddr(input):
111   try:
112     r = ipaddress.IPv4Address(input)
113   except AddressValueError:
114     r = ipaddress.IPv6Address(input)
115   return r
116
117 def ipnetwork(input):
118   try:
119     r = ipaddress.IPv4Network(input)
120   except NetworkValueError:
121     r = ipaddress.IPv6Network(input)
122   return r
123
124 #---------- ipif (SLIP) subprocess ----------
125
126 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
127   def __init__(self, router):
128     self._buffer = b''
129     self._router = router
130   def connectionMade(self): pass
131   def outReceived(self, data):
132     #print('IPIF-GOT ', repr(data))
133     self._buffer += data
134     packets = slip.decode(self._buffer)
135     self._buffer = packets.pop()
136     for packet in packets:
137       if not len(packet): continue
138       (saddr, daddr) = packet_addrs(packet)
139       if saddr.is_link_local or daddr.is_link_local:
140         log_discard(packet, saddr, daddr, 'link-local')
141         continue
142       self._router(packet, saddr, daddr)
143   def processEnded(self, status):
144     status.raiseException()
145
146 def start_ipif(command, router):
147   global ipif
148   ipif = _IpifProcessProtocol(router)
149   reactor.spawnProcess(ipif,
150                        '/bin/sh',['sh','-xc', command],
151                        childFDs={0:'w', 1:'r', 2:2})
152
153 def queue_inbound(packet):
154   ipif.transport.write(slip.delimiter)
155   ipif.transport.write(slip.encode(packet))
156   ipif.transport.write(slip.delimiter)
157
158 #---------- packet queue ----------
159
160 class PacketQueue():
161   def __init__(self, max_queue_time):
162     self._max_queue_time = max_queue_time
163     self._pq = collections.deque() # packets
164
165   def append(self, packet):
166     self._pq.append((time.monotonic(), packet))
167
168   def nonempty(self):
169     while True:
170       try: (queuetime, packet) = self._pq[0]
171       except IndexError: return False
172
173       age = time.monotonic() - queuetime
174       if age > self.max_queue_time:
175         # strip old packets off the front
176         self._pq.popleft()
177         continue
178
179       return True
180
181   def process(self, sizequery, moredata, max_batch):
182     # sizequery() should return size of batch so far
183     # moredata(s) should add s to batch
184     while True:
185       try: (dummy, packet) = self._pq[0]
186       except IndexError: break
187
188       encoded = slip.encode(packet)
189       sofar = sizequery()  
190
191       if sofar > 0:
192         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
193           break
194         moredata(slip.delimiter)
195
196       moredata(encoded)
197       self._pq.popLeft()
198
199 #---------- error handling ----------
200
201 def crash(err):
202   print('CRASH ', err, file=sys.stderr)
203   try: reactor.stop()
204   except twisted.internet.error.ReactorNotRunning: pass
205
206 def crash_on_defer(defer):
207   defer.addErrback(lambda err: crash(err))
208
209 vdef crash_on_critical(event):
210   if event.get('log_level') >= LogLevel.critical:
211     crash(twisted.logger.formatEvent(event))
212
213 #---------- config processing ----------
214
215 def process_cfg_common_always():
216   global mtu
217   c.mtu = cfg.get('virtual','mtu')
218
219 def process_cfg_ipif(section, varmap):
220   for d, s in varmap:
221     try: v = getattr(c, s)
222     except AttributeError: continue
223     setattr(c, d, v)
224
225   print(repr(c))
226
227   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
228
229 def process_cfg_network():
230   c.network = ipnetwork(cfg.get('virtual','network'))
231   if c.network.num_addresses < 3 + 2:
232     raise ValueError('network needs at least 2^3 addresses')
233
234 def process_cfg_server():
235   try:
236     c.server = cfg.get('virtual','server')
237   except NoOptionError:
238     process_cfg_network()
239     c.server = next(c.network.hosts())
240
241 class ServerAddr():
242   def __init__(self, port, addrspec):
243     self.port = port
244     # also self.addr
245     try:
246       self.addr = ipaddress.IPv4Address(addrspec)
247       self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
248       self._inurl = '%s'
249     except AddressValueError:
250       self.addr = ipaddress.IPv6Address(addrspec)
251       self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
252       self._inurl = '[%s]'
253   def make_endpoint(self):
254     return self._endpointfactory(reactor, self.port, self.addr)
255   def url(self):
256     url = 'http://' + (self._inurl % self.addr)
257     if self.port != 80: url += ':%d' % self.port
258     url += '/'
259     return url
260     
261 def process_cfg_saddrs():
262   try: port = cfg.getint('server','port')
263   except NoOptionError: port = 80
264
265   c.saddrs = [ ]
266   for addrspec in cfg.get('server','addrs').split():
267     sa = ServerAddr(port, addrspec)
268     c.saddrs.append(sa)
269
270 def process_cfg_clients(constructor):
271   c.clients = [ ]
272   for cs in cfg.sections():
273     if not (':' in cs or '.' in cs): continue
274     ci = ipaddr(cs)
275     pw = cfg.get(cs, 'password')
276     constructor(ci,cs,pw)
277
278 #---------- startup ----------
279
280 def common_startup():
281   twisted.logger.globalLogPublisher.addObserver(crash_on_critical)
282
283   optparser.add_option('-c', '--config', dest='configfile',
284                        default='/etc/hippotat/config')
285   (opts, args) = optparser.parse_args()
286   if len(args): optparser.error('no non-option arguments please')
287
288   re = regexp.compile('#.*')
289   cfg.read_string(re.sub('', defcfg))
290   cfg.read(opts.configfile)
291
292 def common_run():
293   reactor.run()
294   print('CRASHED (end)', file=sys.stderr)