chiark / gitweb /
f701434540d9d56b88237228e1fb195e88c4d54a
[chiark-utils.git] / fishdescriptor / py / fishdescriptor / indonor.py
1
2 # class for use inside gdb which is debugging the donor process
3
4 import gdb
5 import copy
6 import os
7 import sys
8 import socket
9 import re
10
11 def _string_bytearray(s):
12     # gets us bytes in py2 and py3
13     if not isinstance(s, bytes):
14         s = s.encode('utf-8') # sigh, python 2/3 compat
15     return bytearray(s)
16
17 def _string_escape_for_c(s):
18     out = ''
19     for c in _string_bytearray(s):
20         if c == ord('\\') or c == ord('"') or c < 32 or c > 126:
21             out += '\\x%02x' % c
22         else:
23             out += chr(c)
24     return out
25
26 # constructing values
27
28 def _lit_integer(v):
29     return '%d' % v
30
31 def _lit_aggregate_uncasted(val_lit_strs):
32     return '{' + ', '.join(['(%s)' % v for v in val_lit_strs]) + ' }'
33
34 def _lit_string_uncasted(s):
35     b = _string_bytearray(s)
36     return _lit_aggregate_uncasted([_lit_integer(x) for x in b] + [ '0' ])
37
38 def _lit_array(elemtype, val_lit_strs):
39     return (
40         '((%s[%d])%s)' %
41         (elemtype, len(val_lit_strs), _lit_aggregate_uncasted(val_lit_strs))
42     )
43
44 def _lit_addressof(v):
45     return '&(char[])(%s)' % v
46
47 def _make_lit(v):
48     if isinstance(v, int):
49         return _lit_integer(v)
50     else:
51         return v # should already be an integer
52
53 def parse_eval(expr):
54     sys.stderr.write("##  EVAL %s\n" % repr(expr))
55     x = gdb.parse_and_eval(expr)
56     sys.stderr.write('##  => %s\n' % x)
57     sys.stderr.flush()
58     return x
59
60 def parse_eval_via_print(expr):
61     # works only with things whose value is an int and where expr is simple
62     sys.stderr.write("##  EVAL-VIA-PRINT %s\n" % repr(expr))
63     x = gdb.execute('print %s' % expr, to_string=True)
64     m = re.match('\$\d+ = (\d+)\n$', x) # seriously !
65     r = int(m.group(1))
66     sys.stderr.write('##  => %s\n' % r)
67     return 4
68
69 class DonorStructLayout():
70     def __init__(l, typename):
71         x = gdb.lookup_type(typename)
72         l._typename = typename
73         l._template = [ ]
74         l._posns = { }
75         for f in x.fields():
76             l._posns[f.name] = len(l._template)
77             try: f.type.fields();  blank = '{ }'
78             except TypeError:      blank = '0'
79             except AttributeError: blank = '0'
80             l._template.append(blank)
81         sys.stderr.write('##  STRUCT %s template %s fields %s\n'
82                          % (typename, l._template, l._posns))
83
84     def substitute(l, values):
85         build = copy.deepcopy(l._template)
86         for (k,v) in values.items():
87             build[ l._posns[k] ] = _make_lit(v)
88         return '((%s)%s)' % (l._typename, _lit_aggregate_uncasted(build))
89
90 class DonorImplementation():
91     def __init__(di):
92         di._structs = { }
93         di._saved_errno = None
94         di._result_stream = os.fdopen(3, 'w')
95
96     # assembling structs
97     # sigh, we have to record the order of the arguments!
98     def _find_fields(di, typename):
99         try:
100             fields = di._structs[typename]
101         except KeyError:
102             fields = DonorStructLayout(typename)
103             di._structs[typename] = fields
104         return fields
105
106     def _make(di, typename, values):
107         fields = di._find_fields(typename)
108         return fields.substitute(values)
109
110     # calling functions (need to cast the function name to the right
111     # type in case maybe gdb doesn't know the type)
112
113     def _func(di, functype, funcname, realargs):
114         expr = '((%s) %s) %s' % (functype, funcname, realargs)
115         return parse_eval(expr)
116
117     def _must_func(di, functype, funcname, realargs):
118         retval = di._func(functype, funcname, realargs)
119         if retval < 0:
120             errnoval = parse_eval('errno')
121             raise RuntimeError("%s gave errno=%d `%s'" %
122                                (funcname, errnoval, os.strerror(errnoval)))
123         return retval
124
125     # wrappers for the syscalls that do what we want
126
127     def _sendmsg(di, carrier, control_msg):
128         iov_base = _lit_array('char', [1])
129         iov = di._make('struct iovec', {
130             'iov_base': iov_base,
131             'iov_len' : 1,
132         })
133
134         msg = di._make('struct msghdr', {
135             'msg_iov'       : _lit_addressof(iov),
136             'msg_iovlen'    : 1,
137             'msg_control'   : _lit_array('char', control_msg),
138             'msg_controllen': len(control_msg),
139         })
140
141         di._must_func(
142             'ssize_t (*)(int, const struct msghdr*, int)',
143             'sendmsg',
144             '(%s, %s, 0)' % (carrier, _lit_addressof(msg))
145         )
146
147     def _socket(di):
148         return di._must_func(
149             'int (*)(int, int, int)',
150             'socket',
151             '(%d, %d, 0)' % (socket.AF_UNIX, socket.SOCK_STREAM)
152         )
153
154     def _connect(di, fd, path):
155         addr = di._make('struct sockaddr_un', {
156             'sun_family' : _lit_integer(socket.AF_UNIX),
157             'sun_path'   : _lit_string_uncasted(path),
158         })
159
160         di._must_func(
161             'int (*)(int, const struct sockaddr*, socklen_t)',
162             'connect',
163             '(%d, (const struct sockaddr*)%s, sizeof(struct sockaddr_un))'
164             % (fd, _lit_addressof(addr))
165         )
166
167     def _close(di, fd):
168         di._must_func('int (*)(int)', 'close', '(%d)' % fd)
169
170     def _mkdir(di, path, mode):
171         r = di._func(
172             'int (*)(const char*, mode_t)',
173             'mkdir',
174             '("%s", %d)' % (_string_escape_for_c(path), mode)
175         )
176         if r < 0:
177             errnoval = parse_eval('errno')
178             if errnoval != os.errno.EEXIST:
179                 raise RuntimeError("mkdir %s failed: `%s'" %
180                                    (repr(path), os.strerror(errnoval)))
181             return 0
182         return 1
183
184     def _errno_save(di):
185         # incomprehensibly, gdb.parse_and_eval('errno') can sometimes
186         # fail with
187         #   gdb.error: Cannot find thread-local variables on this target
188         # even though plain gdb `print errno' works.
189         # OMG.  This may be related to:
190         #  https://github.com/cloudburst/libheap/issues/24
191         # although I can't find it in the gdb bug db (which is half-broken
192         # in my browser)
193         # Anyway:
194         di._saved_errno = parse_eval_via_print('errno')
195
196     def _errno_restore(di):
197         to_restore = di._saved_errno
198         di._saved_errno = None
199         if to_restore is not None:
200             parse_eval_via_print('errno = %d' % to_restore)
201
202     def _result(di, output):
203         sys.stderr.write("#> %s" % output)
204         di._result_stream.write(output)
205         di._result_stream.flush()
206
207     # main entrypoints
208
209     def donate(di, path, control_msg):
210         # control_msg is an array of integers being the ancillary data
211         # array ("control") for sendmsg, and hence specifies which fds
212         # to pass
213
214         carrier = None
215         try:
216             di._errno_save()
217             carrier = di._socket()
218             di._connect(carrier, path)
219             di._sendmsg(carrier, control_msg)
220             di._close(carrier)
221             carrier = None
222         finally:
223             if carrier is not None:
224                 try: di._close(carrier)
225                 except Exception: pass
226             di._errno_restore()
227
228         di._result('1\n')
229
230     def geteuid(di):
231         try:
232             di._errno_save()
233             val = di._must_func('uid_t (*)(void)', 'geteuid', '()')
234         finally:
235             di._errno_restore()
236         
237         di._result('%d\n' % val)
238
239     def mkdir(di, path):
240         try:
241             di._errno_save()
242             val = di._mkdir(path, int('0700', 8))
243         finally:
244             di._errno_restore()
245
246         di._result('%d\n' % val)
247
248     def _protocol_read(di):
249         input = sys.stdin.readline()
250         if input == '': return None
251         input = input.rstrip('\n')
252         sys.stderr.write("#< %s\n" % input)
253         return input
254
255     def eval_loop(di):
256         while True:
257             di._result('!\n')
258             cmd = di._protocol_read()
259             if cmd is None: break
260             eval(cmd)