chiark / gitweb /
6af49b6da754e46f88a3e65b663aec32df46c5bd
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import twisted
7 from twisted.internet import reactor
8 from twisted.logger import LogLevel
9
10 import ipaddress
11 from ipaddress import AddressValueError
12
13 import hippotat.slip as slip
14
15 from optparse import OptionParser
16 from configparser import ConfigParser
17 from configparser import NoOptionError
18
19 import collections
20
21 cfg = ConfigParser()
22 optparser = OptionParser()
23
24 #---------- packet parsing ----------
25
26 def packet_addrs(packet):
27   version = packet[0] >> 4
28   if version == 4:
29     addrlen = 4
30     saddroff = 3*4
31     factory = ipaddress.IPv4Address
32   elif version == 6:
33     addrlen = 16
34     saddroff = 2*4
35     factory = ipaddress.IPv6Address
36   else:
37     raise ValueError('unsupported IP version %d' % version)
38   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
39   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
40   return (saddr, daddr)
41
42 #---------- address handling ----------
43
44 def ipaddr(input):
45   try:
46     r = ipaddress.IPv4Address(input)
47   except AddressValueError:
48     r = ipaddress.IPv6Address(input)
49   return r
50
51 def ipnetwork(input):
52   try:
53     r = ipaddress.IPv4Network(input)
54   except NetworkValueError:
55     r = ipaddress.IPv6Network(input)
56   return r
57
58 #---------- ipif (SLIP) subprocess ----------
59
60 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
61   def __init__(self, router):
62     self._buffer = b''
63     self._router = router
64   def connectionMade(self): pass
65   def outReceived(self, data):
66     #print('RECV ', repr(data))
67     self._buffer += data
68     packets = slip.decode(self._buffer)
69     self._buffer = packets.pop()
70     for packet in packets:
71       if not len(packet): continue
72       (saddr, daddr) = packet_addrs(packet)
73       self._router(packet, saddr, daddr)
74   def processEnded(self, status):
75     status.raiseException()
76
77 def start_ipif(command, router):
78   global ipif
79   ipif = _IpifProcessProtocol(router)
80   reactor.spawnProcess(ipif,
81                        '/bin/sh',['sh','-xc', command],
82                        childFDs={0:'w', 1:'r', 2:2})
83
84 def queue_inbound(packet):
85   ipif.transport.write(slip.delimiter)
86   ipif.transport.write(slip.encode(packet))
87   ipif.transport.write(slip.delimiter)
88
89 #---------- packet queue ----------
90
91 class PacketQueue():
92   def __init__(self, max_queue_time):
93     self._max_queue_time = max_queue_time
94     self._pq = collections.deque() # packets
95
96   def append(self, packet):
97     self._pq.append((time.monotonic(), packet))
98
99   def nonempty(self):
100     while True:
101       try: (queuetime, packet) = self._pq[0]
102       except IndexError: return False
103
104       age = time.monotonic() - queuetime
105       if age > self.max_queue_time:
106         # strip old packets off the front
107         self._pq.popleft()
108         continue
109
110       return True
111
112   def popleft(self):
113     # caller must have checked nonempty
114     try: (dummy, packet) = self._pq[0]
115     except IndexError: return None
116     return packet
117
118 #---------- error handling ----------
119
120 def crash(err):
121   print('CRASH ', err, file=sys.stderr)
122   try: reactor.stop()
123   except twisted.internet.error.ReactorNotRunning: pass
124
125 def crash_on_defer(defer):
126   defer.addErrback(lambda err: crash(err))
127
128 def crash_on_critical(event):
129   if event.get('log_level') >= LogLevel.critical:
130     crash(twisted.logger.formatEvent(event))
131
132 #---------- startup ----------
133
134 def common_startup(defcfg):
135   twisted.logger.globalLogPublisher.addObserver(crash_on_critical)
136
137   optparser.add_option('-c', '--config', dest='configfile',
138                        default='/etc/hippotat/config')
139   (opts, args) = optparser.parse_args()
140   if len(args): optparser.error('no non-option arguments please')
141
142   cfg.read_string(defcfg)
143   cfg.read(opts.configfile)
144
145 def common_run():
146   reactor.run()
147   print('CRASHED (end)', file=sys.stderr)