chiark / gitweb /
c57767ae1c2c3045def4875cdaa1efd7c7c61191
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 import twisted
9 from twisted.internet import reactor
10 from twisted.logger import LogLevel
11 import twisted.internet.endpoints
12
13 import ipaddress
14 from ipaddress import AddressValueError
15
16 from optparse import OptionParser
17 from configparser import ConfigParser
18 from configparser import NoOptionError
19
20 import collections
21
22 import re as regexp
23
24 import hippotat.slip as slip
25
26 defcfg = '''
27 [DEFAULT]
28 #[<client>] overrides
29 max_batch_down = 65536           # used by server, subject to [limits]
30 max_queue_time = 10              # used by server, subject to [limits]
31 max_request_time = 54            # used by server, subject to [limits]
32 target_requests_outstanding = 3  # must match; subject to [limits] on server
33 max_requests_outstanding = 4     # used by client
34 max_batch_up = 4000              # used by client
35
36 #[server] or [<client>] overrides
37 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
38 # extra interpolations:  %(local)s        %(peer)s          %(rnet)s
39 #  obtained   on server  [virtual]server  [virtual]relay    [virtual]network
40 #      from   on client  <client>         [virtual]server   [virtual]routes
41
42 [virtual]
43 mtu = 1500
44 routes = ''
45 # network = <prefix>/<len>  # mandatory for server
46 # server  = <ipaddr>   # used by both, default is computed from `network'
47 # relay   = <ipaddr>   # used by server, default from `network' and `server'
48 #  default server is first host in network
49 #  default relay is first host which is not server
50
51 [server]
52 # addrs = 127.0.0.1 ::1    # mandatory for server
53 port = 80                  # used by server
54 # url              # used by client; default from first `addrs' and `port'
55
56 # [<client-ip4-or-ipv6-address>]
57 # password = <password>    # used by both, must match
58
59 [limits]
60 max_batch_down = 262144           # used by server
61 max_queue_time = 121              # used by server
62 max_request_time = 121            # used by server
63 target_requests_outstanding = 10  # used by server
64 '''
65
66 # these need to be defined here so that they can be imported by import *
67 cfg = ConfigParser()
68 optparser = OptionParser()
69
70 class ConfigResults:
71   def __init__(self, d = { }):
72     self.__dict__ = d
73   def __repr__(self):
74     return 'ConfigResults('+repr(self.__dict__)+')'
75
76 c = ConfigResults()
77
78 def log_discard(packet, saddr, daddr, why):
79   print('DROP ', saddr, daddr, why)
80 #  syslog.syslog(syslog.LOG_DEBUG,
81 #                'discarded packet %s -> %s (%s)' % (saddr, daddr, why))
82
83 #---------- packet parsing ----------
84
85 def packet_addrs(packet):
86   version = packet[0] >> 4
87   if version == 4:
88     addrlen = 4
89     saddroff = 3*4
90     factory = ipaddress.IPv4Address
91   elif version == 6:
92     addrlen = 16
93     saddroff = 2*4
94     factory = ipaddress.IPv6Address
95   else:
96     raise ValueError('unsupported IP version %d' % version)
97   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
98   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
99   return (saddr, daddr)
100
101 #---------- address handling ----------
102
103 def ipaddr(input):
104   try:
105     r = ipaddress.IPv4Address(input)
106   except AddressValueError:
107     r = ipaddress.IPv6Address(input)
108   return r
109
110 def ipnetwork(input):
111   try:
112     r = ipaddress.IPv4Network(input)
113   except NetworkValueError:
114     r = ipaddress.IPv6Network(input)
115   return r
116
117 #---------- ipif (SLIP) subprocess ----------
118
119 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
120   def __init__(self, router):
121     self._buffer = b''
122     self._router = router
123   def connectionMade(self): pass
124   def outReceived(self, data):
125     #print('IPIF-GOT ', repr(data))
126     self._buffer += data
127     packets = slip.decode(self._buffer)
128     self._buffer = packets.pop()
129     for packet in packets:
130       if not len(packet): continue
131       (saddr, daddr) = packet_addrs(packet)
132       if saddr.is_link_local or daddr.is_link_local:
133         log_discard(packet, saddr, daddr, 'link-local')
134         continue
135       self._router(packet, saddr, daddr)
136   def processEnded(self, status):
137     status.raiseException()
138
139 def start_ipif(command, router):
140   global ipif
141   ipif = _IpifProcessProtocol(router)
142   reactor.spawnProcess(ipif,
143                        '/bin/sh',['sh','-xc', command],
144                        childFDs={0:'w', 1:'r', 2:2})
145
146 def queue_inbound(packet):
147   ipif.transport.write(slip.delimiter)
148   ipif.transport.write(slip.encode(packet))
149   ipif.transport.write(slip.delimiter)
150
151 #---------- packet queue ----------
152
153 class PacketQueue():
154   def __init__(self, max_queue_time):
155     self._max_queue_time = max_queue_time
156     self._pq = collections.deque() # packets
157
158   def append(self, packet):
159     self._pq.append((time.monotonic(), packet))
160
161   def nonempty(self):
162     while True:
163       try: (queuetime, packet) = self._pq[0]
164       except IndexError: return False
165
166       age = time.monotonic() - queuetime
167       if age > self.max_queue_time:
168         # strip old packets off the front
169         self._pq.popleft()
170         continue
171
172       return True
173
174   def popleft(self):
175     # caller must have checked nonempty
176     try: (dummy, packet) = self._pq[0]
177     except IndexError: return None
178     return packet
179
180 #---------- error handling ----------
181
182 def crash(err):
183   print('CRASH ', err, file=sys.stderr)
184   try: reactor.stop()
185   except twisted.internet.error.ReactorNotRunning: pass
186
187 def crash_on_defer(defer):
188   defer.addErrback(lambda err: crash(err))
189
190 def crash_on_critical(event):
191   if event.get('log_level') >= LogLevel.critical:
192     crash(twisted.logger.formatEvent(event))
193
194 #---------- config processing ----------
195
196 def process_cfg_common_always():
197   global mtu
198   c.mtu = cfg.get('virtual','mtu')
199
200 def process_cfg_ipif(section, varmap):
201   for d, s in varmap:
202     try: v = getattr(c, s)
203     except AttributeError: continue
204     setattr(c, d, v)
205
206   print(repr(c))
207
208   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
209
210 def process_cfg_network():
211   c.network = ipnetwork(cfg.get('virtual','network'))
212   if c.network.num_addresses < 3 + 2:
213     raise ValueError('network needs at least 2^3 addresses')
214
215 def process_cfg_server():
216   try:
217     c.server = cfg.get('virtual','server')
218   except NoOptionError:
219     process_cfg_network()
220     c.server = next(c.network.hosts())
221
222 class ServerAddr():
223   def __init__(self, port, addrspec):
224     self.port = port
225     # also self.addr
226     try:
227       self.addr = ipaddress.IPv4Address(addrspec)
228       self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
229       self._inurl = '%s'
230     except AddressValueError:
231       self.addr = ipaddress.IPv6Address(addrspec)
232       self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
233       self._inurl = '[%s]'
234   def make_endpoint(self):
235     return self._endpointfactory(reactor, self.port, self.addr)
236   def url(self):
237     url = 'http://' + (self._inurl % self.addr)
238     if self.port != 80: url += ':%d' % self.port
239     url += '/'
240     return url
241     
242 def process_cfg_saddrs():
243   try: port = cfg.getint('server','port')
244   except NoOptionError: port = 80
245
246   c.saddrs = [ ]
247   for addrspec in cfg.get('server','addrs').split():
248     sa = ServerAddr(port, addrspec)
249     c.saddrs.append(sa)
250
251 def process_cfg_clients(constructor):
252   c.clients = [ ]
253   for cs in cfg.sections():
254     if not (':' in cs or '.' in cs): continue
255     ci = ipaddr(cs)
256     pw = cfg.get(cs, 'password')
257     constructor(ci,cs,pw)
258
259 #---------- startup ----------
260
261 def common_startup():
262   twisted.logger.globalLogPublisher.addObserver(crash_on_critical)
263
264   optparser.add_option('-c', '--config', dest='configfile',
265                        default='/etc/hippotat/config')
266   (opts, args) = optparser.parse_args()
267   if len(args): optparser.error('no non-option arguments please')
268
269   re = regexp.compile('#.*')
270   cfg.read_string(re.sub('', defcfg))
271   cfg.read(opts.configfile)
272
273 def common_run():
274   reactor.run()
275   print('CRASHED (end)', file=sys.stderr)