chiark / gitweb /
wip,
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import twisted
7 from twisted.internet import reactor
8 from twisted.logger import LogLevel
9 import twisted.internet.endpoints
10
11 import ipaddress
12 from ipaddress import AddressValueError
13
14 import hippotat.slip as slip
15
16 from optparse import OptionParser
17 from configparser import ConfigParser
18 from configparser import NoOptionError
19
20 import collections
21
22 # these need to be defined here so that they can be imported by import *
23 cfg = ConfigParser()
24 optparser = OptionParser()
25
26 class ConfigResults:
27   def __init__(self, d = { }):
28     self.__dict__ = d
29   def __repr__(self):
30     return 'ConfigResults('+repr(self.__dict__)+')'
31
32 c = ConfigResults()
33
34 #---------- packet parsing ----------
35
36 def packet_addrs(packet):
37   version = packet[0] >> 4
38   if version == 4:
39     addrlen = 4
40     saddroff = 3*4
41     factory = ipaddress.IPv4Address
42   elif version == 6:
43     addrlen = 16
44     saddroff = 2*4
45     factory = ipaddress.IPv6Address
46   else:
47     raise ValueError('unsupported IP version %d' % version)
48   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
49   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
50   return (saddr, daddr)
51
52 #---------- address handling ----------
53
54 def ipaddr(input):
55   try:
56     r = ipaddress.IPv4Address(input)
57   except AddressValueError:
58     r = ipaddress.IPv6Address(input)
59   return r
60
61 def ipnetwork(input):
62   try:
63     r = ipaddress.IPv4Network(input)
64   except NetworkValueError:
65     r = ipaddress.IPv6Network(input)
66   return r
67
68 #---------- ipif (SLIP) subprocess ----------
69
70 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
71   def __init__(self, router):
72     self._buffer = b''
73     self._router = router
74   def connectionMade(self): pass
75   def outReceived(self, data):
76     #print('RECV ', repr(data))
77     self._buffer += data
78     packets = slip.decode(self._buffer)
79     self._buffer = packets.pop()
80     for packet in packets:
81       if not len(packet): continue
82       (saddr, daddr) = packet_addrs(packet)
83       self._router(packet, saddr, daddr)
84   def processEnded(self, status):
85     status.raiseException()
86
87 def start_ipif(command, router):
88   global ipif
89   ipif = _IpifProcessProtocol(router)
90   reactor.spawnProcess(ipif,
91                        '/bin/sh',['sh','-xc', command],
92                        childFDs={0:'w', 1:'r', 2:2})
93
94 def queue_inbound(packet):
95   ipif.transport.write(slip.delimiter)
96   ipif.transport.write(slip.encode(packet))
97   ipif.transport.write(slip.delimiter)
98
99 #---------- packet queue ----------
100
101 class PacketQueue():
102   def __init__(self, max_queue_time):
103     self._max_queue_time = max_queue_time
104     self._pq = collections.deque() # packets
105
106   def append(self, packet):
107     self._pq.append((time.monotonic(), packet))
108
109   def nonempty(self):
110     while True:
111       try: (queuetime, packet) = self._pq[0]
112       except IndexError: return False
113
114       age = time.monotonic() - queuetime
115       if age > self.max_queue_time:
116         # strip old packets off the front
117         self._pq.popleft()
118         continue
119
120       return True
121
122   def popleft(self):
123     # caller must have checked nonempty
124     try: (dummy, packet) = self._pq[0]
125     except IndexError: return None
126     return packet
127
128 #---------- error handling ----------
129
130 def crash(err):
131   print('CRASH ', err, file=sys.stderr)
132   try: reactor.stop()
133   except twisted.internet.error.ReactorNotRunning: pass
134
135 def crash_on_defer(defer):
136   defer.addErrback(lambda err: crash(err))
137
138 def crash_on_critical(event):
139   if event.get('log_level') >= LogLevel.critical:
140     crash(twisted.logger.formatEvent(event))
141
142 #---------- config processing ----------
143
144 def process_cfg_common_always():
145   global mtu
146   c.mtu = cfg.get('virtual','mtu')
147
148 def process_cfg_ipif(section, varmap):
149   for d, s in varmap:
150     try: v = getattr(c, s)
151     except KeyError: pass
152     setattr(c, d, v)
153
154   print(repr(c))
155
156   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
157
158 def process_cfg_network():
159   c.network = ipnetwork(cfg.get('virtual','network'))
160   if c.network.num_addresses < 3 + 2:
161     raise ValueError('network needs at least 2^3 addresses')
162
163 def process_cfg_server():
164   try:
165     c.server = cfg.get('virtual','server')
166   except NoOptionError:
167     process_cfg_network()
168     c.server = next(c.network.hosts())
169
170 class ServerAddr():
171   def __init__(self, port, addrspec):
172     self.port = port
173     # also self.addr
174     try:
175       self.addr = ipaddress.IPv4Address(addrspec)
176       self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
177       self._inurl = '%s'
178     except AddressValueError:
179       self.addr = ipaddress.IPv6Address(addrspec)
180       self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
181       self._inurl = '[%s]'
182   def make_endpoint(self):
183     return self._endpointfactory(reactor, self.port, self.addr)
184   def url(self):
185     url = 'http://' + (self._inurl % self.addr)
186     if self.port != 80: url += ':%d' % self.port
187     url += '/'
188     return url
189     
190 def process_cfg_saddrs():
191   try: port = cfg.getint('server','port')
192   except NoOptionError: port = 80
193
194   c.saddrs = [ ]
195   for addrspec in cfg.get('server','addrs').split():
196     sa = ServerAddr(port, addrspec)
197     c.saddrs.append(sa)
198
199 def process_cfg_clients(constructor):
200   c.clients = [ ]
201   for cs in cfg.sections():
202     if not (':' in cs or '.' in cs): continue
203     ci = ipaddr(cs)
204     pw = cfg.get(cs, 'password')
205     constructor(ci,cs,pw)
206
207 #---------- startup ----------
208
209 def common_startup(defcfg):
210   twisted.logger.globalLogPublisher.addObserver(crash_on_critical)
211
212   optparser.add_option('-c', '--config', dest='configfile',
213                        default='/etc/hippotat/config')
214   (opts, args) = optparser.parse_args()
215   if len(args): optparser.error('no non-option arguments please')
216
217   cfg.read_string(defcfg)
218   cfg.read(opts.configfile)
219
220 def common_run():
221   reactor.run()
222   print('CRASHED (end)', file=sys.stderr)