chiark / gitweb /
d25ff06d9a1a6580e8ce9a1888938bbd2b789f5a
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
import signal
signal.signal(signal.SIGINT, signal.SIG_DFL)

import collections
import ipaddress
import time
from ipaddress import AddressValueError

import twisted
from twisted.internet import reactor

import hippotat.slip as slip
13
14 #---------- packet parsing ----------
15
def packet_addrs(packet):
  """Extract the (source, destination) addresses from a raw IP packet.

  The IP version is taken from the high nibble of the first octet.  For
  IPv4 the 4-byte source address starts at byte 12; for IPv6 the 16-byte
  source address starts at byte 8.  In both cases the destination address
  immediately follows the source.

  Returns a pair of ipaddress.IPv4Address or ipaddress.IPv6Address.
  Raises ValueError for any other IP version, and IndexError for an empty
  packet.
  """
  ip_version = packet[0] >> 4
  if ip_version == 4:
    width, src_start, make_addr = 4, 3*4, ipaddress.IPv4Address
  elif ip_version == 6:
    width, src_start, make_addr = 16, 2*4, ipaddress.IPv6Address
  else:
    raise ValueError('unsupported IP version %d' % ip_version)
  dst_start = src_start + width
  src = make_addr(packet[src_start:dst_start])
  dst = make_addr(packet[dst_start:dst_start + width])
  return (src, dst)
31
32 #---------- address handling ----------
33
def ipaddr(input):
  """Parse input as an IP address, preferring IPv4 and falling back to IPv6.

  Returns an ipaddress.IPv4Address if input is acceptable as IPv4,
  otherwise an ipaddress.IPv6Address.  If neither parse succeeds, the
  IPv6 constructor's exception propagates, chained to the IPv4 failure.
  """
  try:
    return ipaddress.IPv4Address(input)
  except AddressValueError:
    # Fallback stays inside the handler so the IPv4 error is kept as context.
    return ipaddress.IPv6Address(input)
40
def ipnetwork(input):
  """Parse input as an IP network, preferring IPv4 and falling back to IPv6.

  Returns an ipaddress.IPv4Network if the address part is valid IPv4,
  otherwise an ipaddress.IPv6Network.

  Raises ValueError (or a subclass such as AddressValueError or
  NetmaskValueError) if the input is not a valid network.  Only an
  AddressValueError from the IPv4 attempt triggers the IPv6 fallback.
  Other IPv4 errors, such as a bad prefix length or host bits set,
  propagate directly, so the message describes the IPv4 problem rather
  than an unrelated IPv6 parse failure.
  """
  try:
    r = ipaddress.IPv4Network(input)
  except AddressValueError:
    # The original caught NetworkValueError, which does not exist, so any
    # failed IPv4 parse (including every IPv6 input) raised NameError here.
    r = ipaddress.IPv6Network(input)
  return r
47
48 #---------- ipif (SLIP) subprocess ----------
49
class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
  """Process protocol for the ipif (SLIP) child process.

  Bytes the child writes to its stdout are accumulated and SLIP-decoded.
  Each complete, non-empty packet is passed to the router callable as
  router(packet, saddr, daddr).
  """

  def __init__(self, router):
    self._router = router
    # Undecoded stdout bytes carried over between outReceived calls.
    self._buffer = b''

  def connectionMade(self):
    # Nothing to do when the child starts.
    pass

  def outReceived(self, data):
    self._buffer += data
    frames = slip.decode(self._buffer)
    # The final element is kept back as the (possibly incomplete) remainder.
    # NOTE(review): assumes slip.decode always returns a non-empty list;
    # confirm against hippotat.slip.
    self._buffer = frames.pop()
    for frame in frames:
      if len(frame) == 0:
        continue
      src, dst = packet_addrs(frame)
      self._router(frame, src, dst)

  def processEnded(self, status):
    # Re-raise whatever ended the child (presumably a twisted Failure), so
    # the child exiting is never silently ignored.
    status.raiseException()
66
def start_ipif(command, router):
  """Run command under /bin/sh as the ipif child and route its packets.

  Rebinds the module-level ipif to the new _IpifProcessProtocol so that
  queue_inbound() can write to the child.  The shell is started with -x,
  so it traces each command to stderr.  The child's stdin is a pipe we
  write to, its stdout a pipe we read from, and its stderr is our fd 2.
  """
  global ipif
  shell_argv = ['sh', '-xc', command]
  fd_map = {0: 'w', 1: 'r', 2: 2}
  ipif = _IpifProcessProtocol(router)
  reactor.spawnProcess(ipif, '/bin/sh', shell_argv, childFDs=fd_map)
73
def queue_inbound(packet):
  """Write packet to the ipif child as a SLIP frame.

  The encoded packet is bracketed by slip.delimiter on both sides, using
  three separate writes to ipif's transport.  Requires start_ipif() to
  have been called first, since that is what binds ipif.
  NOTE(review): the leading delimiter presumably discards any partial
  frame buffered by the receiver, as is common SLIP practice; confirm.
  """
  write = ipif.transport.write
  write(slip.delimiter)
  write(slip.encode(packet))
  write(slip.delimiter)
78
79 #---------- packet queue ----------
80
class PacketQueue():
  """FIFO of packets that expire after a maximum time in the queue.

  Each packet is stamped with time.monotonic() when appended.  Packets
  older than max_queue_time seconds are discarded from the front of the
  queue when nonempty() is called.
  """

  def __init__(self, max_queue_time):
    self._max_queue_time = max_queue_time
    self._pq = collections.deque() # (queue time, packet) pairs, oldest first

  def append(self, packet):
    """Add packet to the back of the queue, stamped with the current time."""
    self._pq.append((time.monotonic(), packet))

  def nonempty(self):
    """Return True iff at least one unexpired packet is queued.

    As a side effect, expired packets are removed from the front.  Entries
    are appended in non-decreasing monotonic-time order, so once the front
    entry is fresh, every later entry is fresh too.
    """
    while self._pq:
      queuetime, _ = self._pq[0]
      age = time.monotonic() - queuetime
      if age <= self._max_queue_time:
        return True
      # strip old packets off the front
      self._pq.popleft()
    return False

  def popleft(self):
    """Remove and return the oldest queued packet, or None if empty.

    Callers should check nonempty() first, so that expired packets have
    already been discarded.
    """
    try:
      (_, packet) = self._pq.popleft()
    except IndexError:
      return None
    return packet