chiark / gitweb /
cryptomail: Implement info and revoke commands.
[cryptomail] / bin / cryptomail
1 #! /usr/bin/python
2 ### -*-python-*-
3 ###
4 ### Encrypted email address handling
5 ###
6 ### (c) 2006 Mark Wooding
7 ###
8
9 ###----- Licensing notice ---------------------------------------------------
10 ###
11 ### This program is free software; you can redistribute it and/or modify
12 ### it under the terms of the GNU General Public License as published by
13 ### the Free Software Foundation; either version 2 of the License, or
14 ### (at your option) any later version.
15 ### 
16 ### This program is distributed in the hope that it will be useful,
17 ### but WITHOUT ANY WARRANTY; without even the implied warranty of
18 ### MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
19 ### GNU General Public License for more details.
20 ### 
21 ### You should have received a copy of the GNU General Public License
22 ### along with this program; if not, write to the Free Software Foundation,
23 ### Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
24
25 ###----- External dependencies ----------------------------------------------
26
27 import catacomb as C
28 import mLib as M
29 from pysqlite2 import dbapi2 as sqlite
30 from UserDict import DictMixin
31 from getopt import getopt, GetoptError
32 from getdate import getdate 
33 from sys import stdin, stdout, stderr, exit, argv, exc_info
34 from email import Parser as EP
35 import os as OS
36 import time as T
37 import sre as RX
38 import traceback as TB
39
40 ###----- Database messing ---------------------------------------------------
41
42 class AttrDB (object):
43   def __init__(me, dbfile):
44     me.db = sqlite.connect(dbfile)
45   def setup(me):
46     cur = me.db.cursor()
47     cur.execute('''CREATE TABLE attr
48                            (id INTEGER PRIMARY KEY,
49                             key VARCHAR(64) NOT NULL,
50                             value VARCHAR(256) NOT NULL)''')
51     cur.execute('''CREATE TABLE attrset
52                            (id INTEGER NOT NULL,
53                            attr INTEGER NOT NULL)''')
54     cur.execute('''CREATE TABLE uniq
55                            (id INTEGER PRIMARY KEY AUTOINCREMENT,
56                             dummy INTEGER NOT NULL)''')
57     cur.execute('CREATE UNIQUE INDEX attr_bykv ON attr (key, value)')
58     cur.execute('CREATE INDEX attrset_byid ON attrset (id)')
59     cur.execute('CREATE INDEX attrset_byattr ON attrset (attr)')
60     cur.execute('CREATE UNIQUE INDEX attrset_all ON attrset (id, attr)')
61   def uniqueid(me):
62     cur = me.db.cursor()
63     cur.execute('INSERT INTO uniq (dummy) VALUES (0)')
64     cur.execute('SELECT MAX(id) FROM uniq')
65     id = cur.fetchone()[0]
66     cur.execute('DELETE FROM uniq')
67     me.commit()
68     return id
69   def select(me, expr, args = [], cur = None):
70     if cur is None: cur = me.db.cursor()
71     cur.execute(expr, args)
72     while True:
73       r = cur.fetchone()
74       if r is None: break
75       yield r
76   def cleanup(me):
77     cur = me.db.cursor()
78     cur.execute('''DELETE FROM attr WHERE id IN
79                            (SELECT attr.id
80                             FROM attr LEFT JOIN attrset
81                             ON attr.id = attrset.attr
82                             WHERE attrset.id ISNULL)''')
83   def check(me, cleanp = False):
84     toclean = {}
85     cur = me.db.cursor()
86     for set, attr in me.select('''SELECT attrset.id, attrset.attr
87                                           FROM attrset LEFT JOIN attr
88                                           ON attrset.attr = attr.id
89                                           WHERE attr.id ISNULL''',
90                                [], cur):
91       print "attrset %d missing attr %d" % (set, attr)
92       toclean[set] = True
93     if cleanp:
94       for set in toclean:
95         cur.execute('DELETE FROM attrset WHERE id = ?', [set])
96       me.cleanup()
97   def commit(me):
98     me.db.commit()
99
100 class AttrSet (object):
101   def __init__(me, db, id = None):
102     if id is None: id = db.uniqueid()
103     me.id = id
104     me.db = db
105   def insert(me, key, value):
106     cur = me.db.db.cursor()
107     try:
108       cur.execute('INSERT INTO attr (key, value) VALUES (?, ?)',
109                   [key, value])
110     except sqlite.OperationalError:
111       pass
112     cur.execute('SELECT id FROM attr WHERE key = ? AND value = ?',
113                 [key, value])
114     r = cur.fetchone()
115     attr = r[0]
116     try:
117       cur.execute('INSERT INTO attrset VALUES (?, ?)',
118                   [me.id, attr])
119     except sqlite.OperationalError:
120       pass
121   def fetch(me):
122     for r in me.db.select('''SELECT attr.key, attr.value
123                                FROM attr, attrset ON attr.id = attrset.attr
124                                WHERE attrset.id = ?''',
125                           [me.id]):
126       yield r
127   def delete(me):
128     cur = me.db.db.cursor()
129     cur.execute('DELETE FROM attrset WHERE id = ?', [me.id])
130     me.db.cleanup()
131
132 class AttrMap (AttrSet, DictMixin):
133   def __getitem__(me, key):
134     it = None
135     for v, in me.db.select('''SELECT attr.value
136                                 FROM attr, attrset ON attr.id = attrset.attr
137                                 WHERE attrset.id = ? AND attr.key = ?''',
138                            [me.id, key]):
139       if it is None:
140         it = v
141       else:
142         raise ValueError, 'multiple values for key %s' % key
143     if it is None:
144       raise KeyError, key
145     return it
146   def __delitem__(me, key):
147     cur = me.db.db.cursor()
148     cur.execute('''DELETE FROM attrset
149                            WHERE id = ? AND
150                                  attr in
151                                    (SELECT id FROM attr WHERE key = ?)''',
152                 [me.id, key])
153     me.db.cleanup()
154   def __setitem__(me, key, value):
155     me.__delitem__(key)
156     me.insert(key, value)
157   def __iter__(me):
158     set = {}
159     for k, v in me.fetch():
160       if k in set:
161         continue
162       set[k] = True
163       yield k
164   def keys(me):
165     return [k for k in me]
166
167 class AttrMultiMap (AttrMap):
168   def __getitem__(me, key):
169     them = []
170     for v, in me.db.select('''SELECT attr.value
171                                 FROM attr, attrset ON attr.id = attrset.attr
172                                 WHERE attrset.id = ? AND attr.key = ?''',
173                            [me.id, key]):
174       them.append(v)
175     if not them:
176       raise KeyError, key
177     return them
178   def __setitem__(me, key, values):
179     me.__delitem__(key)
180     for it in values:
181       me.insert(key, it)
182
183 ###----- Miscellaneous utilities --------------------------------------------
184
185 def time_format(t = None):
186   if t is None:
187     t = T.time()
188   tm = T.gmtime(t)
189   return T.strftime('%Y-%m-%d %H:%M:%S', tm)
190
191 def any(pred, list):
192   for i in list:
193     if pred(i): return True
194   return False
195 def every(pred, list):
196   for i in list:
197     if not pred(i): return False
198   return True
199
200 prog = RX.sub(r'^.*[/\\]', '', argv[0])
201 def moan(msg):
202   print >>stderr, '%s: %s' % (prog, msg)
203 def die(msg):
204   moan(msg)
205   exit(111)
206
207 ###----- My actual database -------------------------------------------------
208
209 class CMDB (AttrDB):
210   def setup(me):
211     AttrDB.setup(me)
212     cur = me.db.cursor()
213     cur.execute('''CREATE TABLE expiry
214                            (attrset INTEGER PRIMARY KEY,
215                             time CHAR(20) NOT NULL)''')
216     cur.execute('CREATE INDEX expiry_bytime ON expiry (time)')
217   def cleanup(me):
218     cur = me.db.cursor()
219     now = time_format()
220     cur.execute('''DELETE FROM attrset WHERE id IN
221                            (SELECT attrset FROM expiry WHERE time < ?)''',
222                 [now])
223     cur.execute('DELETE FROM expiry WHERE time < ?', [now])
224     cur.execute('''DELETE FROM expiry WHERE attrset IN
225                            (SELECT attrset
226                             FROM expiry LEFT JOIN attrset
227                             ON expiry.attrset = attrset.id
228                             WHERE attrset.id ISNULL)''')
229     AttrDB.cleanup(me)
230   def expiry(me, id):
231     for t, in me.select('SELECT time FROM expiry WHERE attrset = ?', [id]):
232       return t
233     return None
234   def expiredp(me, id):
235     t = me.expiry(id)
236     if t is not None and t < time_format():
237       return True
238     else:
239       return False
240   def setexpire(me, id, when):
241     if when != C.KEXP_FOREVER:
242       cur = me.db.cursor()
243       cur.execute('INSERT INTO expiry VALUES (?, ?)',
244                   [id, time_format(when)])
245
246 ###----- Crypto messing about -----------------------------------------------
247
248 ## Very vague security arguments...
249 ##
250 ## If the block size n of the PRP is large enough (128 bits) then we encrypt
251 ## id || 0^{n - 64}.  Decryption checks we have the right thing.  The
252 ## security proofs for secrecy and integrity are trivial.
253 ##
254 ## If the block size is small, then we encrypt two blocks:
255 ##   C_0 = E_K(0^{n - 64} || id)
256 ##   C_1 = E_K(C_0)
257 ## The proofs are a little more complicated, but essentially work like this.
258 ## If no 0^{n - 64} || id is ever seen as a C_0 then an adversary can't tell
259 ## the difference between this and a similar construction using independent
260 ## keys.  This other construction must provide secrecy (pushing a
261 ## nonrepeating thing through a PRF) and integrity (PRF on noncolliding
262 ## inputs).  So we win, give or take a birthday term.
263 class Crypto (object):
264   def __init__(me, key):
265     me.prp = C.gcprps[key.attr.get('prp', 'blowfish')](key.data.bin)
266   def encrypt(me, id):
267     blksz = type(me.prp).blksz
268     p = C.MP(id).storeb(blksz)
269     c = me.prp.encrypt(p)
270     if blksz < 16:
271       c += me.prp.encrypt(c)
272     return c
273   def decrypt(me, c):
274     bad = False
275     blksz = type(me.prp).blksz
276     if blksz < 16:
277       if len(c) != blksz * 2:
278         return None
279       c, c1 = c[:blksz], c[blksz:]
280       if c1 != me.prp.encrypt(c):
281         bad = True
282     else:
283       if len(c) != blksz:
284         return None
285     p = me.prp.decrypt(c)
286     id = C.MP.loadb(p)
287     if id >> 64:
288       bad = True
289     if bad:
290       return None
291     return long(id)
292
293 ###----- Canonification -----------------------------------------------------
294
295 rx_prefix = RX.compile(r'''(?x) ^ (
296   \[ \S+ \] \s* |
297   \S{,4} : \s* |
298   \s+
299 )   
300 ''')
301 rx_suffix = RX.compile(r'''(?ix) (
302   \( \s* was \s* : .* \) \s* |
303   \s+
304 ) $''')
305 rx_punct = RX.compile(r'(?x) [^\w]+ ')
306
307 def canon_sender(addr):
308   return addr.lower()
309
310 def canon_subject(subject):
311   subject = subject.lower()
312   while True:
313     m = rx_prefix.match(subject)
314     if not m: break
315     subject = subject[m.end():]
316   while True:
317     m = rx_suffix.search(subject)
318     if not m: break
319     subject = subject[:m.start()]
320   subject = rx_punct.sub('', subject)
321   return subject
322
323 ###----- Checking a message for validity ------------------------------------
324
325 class Reject (Exception): pass
326
327 class MessageInfo (object):
328   __slots__ = '''
329     sender msg
330   '''.split()
331
332 constraints = {}
333
334 def check_sender(mi, vv):
335   if mi.sender is None:
336     raise Reject, 'no sender'
337   sender = canon_sender(mi.sender)
338   if not any(lambda pat: M.match(pat.lower(), sender), vv):
339     raise Reject, 'unmatched sender'
340 constraints['sender'] = check_sender
341
342 def check_subject(mi, vv):
343   if mi.msg is None:
344     return
345   subj = mi.msg['subject']
346   if subj is None:
347     raise Reject, 'no subject'
348   subj = canon_subject(subj)
349   if not any(lambda pat: M.match(pat.lower(), subj), vv):
350     raise Reject, 'unmatched subject'
351 constraints['subject'] = check_subject
352
353 def check_nothing(me, vv):
354   pass
355   
356 def check(db, id, sender = None, msgfile = None):
357   mi = MessageInfo()
358   a = AttrMultiMap(db, id)
359   try:
360     addr = a['addr'][0]
361   except KeyError:
362     raise Reject, 'unknown id'
363   if db.expiredp(id):
364     raise Reject, 'expired'
365   if msgfile is None:
366     mi.msg = None
367   else:
368     try:
369       mi.msg = EP.HeaderParser().parse(msgfile)
370     except EP.Errors.HeaderParseError:
371       raise Reject, 'unparseable header'
372   mi.sender = sender
373   for k, vv in a.iteritems():
374     constraints.get(k, check_nothing)(mi, vv)
375   return a['addr'][0]
376
377 ###----- Commands -----------------------------------------------------------
378
379 keyfile = 'db/keyring'
380 tag = 'cryptomail'
381 dbfile = 'db/cryptomail.db'
382 user = None
383 commands = {}
384
385 def timecmp(x, y):
386   if x == y:
387     return 0
388   elif x == C.KEXP_FOREVER or y == C.KEXP_EXPIRE:
389     return +1
390   elif y == C.KEXP_FOREVER or x == C.KEXP_EXPIRE:
391     return +1
392   else:
393     return cmp(x, y)
394
395 def cmd_generate(argv):
396   try:
397     opts, argv = getopt(argv, 't:c:f:i:',
398                         ['expire=', 'timeout=', 'constraint=',
399                          'info=', 'format='])
400   except GetoptError:
401     return 1
402   kr = C.KeyFile(keyfile, C.KOPEN_WRITE)
403   k = kr[tag]
404   db = CMDB(dbfile)
405   map = {}
406   expwhen = C.KEXP_FOREVER
407   format = '%'
408   for o, a in opts:
409     if o in ('-t', '--expire', '--timeout'):
410       if a == 'forever':
411         expwhen = C.KEXP_FOREVER
412       else:
413         expwhen = getdate(a)
414     elif o in ('-c', '--constraint'):
415       c, v = a.split('=', 1)
416       if c not in constraints:
417         die("unknown constraint `%s'", c)
418       map.setdefault(c, []).append(v)
419     elif o in ('-f', '--format'):
420       format = a
421     elif o in ('-i', '--info'):
422       map['info'] = [a]
423     else:
424       raise 'Barf!'
425   if timecmp(expwhen, k.deltime) > 0:
426     k.deltime = expwhen
427   if len(argv) != 1:
428     return 1
429   addr = argv[0]
430   a = AttrMultiMap(db)
431   a.update(map)
432   a['addr'] = [addr]
433   if user is not None:
434     a['user'] = [user]
435   c = Crypto(k).encrypt(a.id)
436   db.setexpire(a.id, expwhen)
437   print format.replace('%', M.base32_encode(Crypto(k).encrypt(a.id)).
438                        strip('=').lower())
439   db.commit()
440   kr.save()
441 commands['generate'] = \
442   (cmd_generate, '[-t TIME] [-c TYPE=VALUE] ADDR', """
443 Generate a new encrypted email address token forwarding to ADDR.
444
445 Subcommand options:
446   -t, --timeout=TIME            Address should expire at TIME.
447   -c, --constraint=TYPE=VALUE   Apply constraint on the use of the address.
448   -f, --format=STRING           Substitute token for `%' in STRING.
449
450 Constraint types:
451   sender                        Envelope sender must match glob pattern.
452   subject                       Message subject must match glob pattern.""")
453
454 def cmd_initdb(argv):
455   try:
456     opts, argv = getopt(argv, '', [])
457   except GetoptError:
458     return 1
459   try:
460     OS.unlink(dbfile)
461   except OSError:
462     pass
463   CMDB(dbfile).setup()
464 commands['initdb'] = \
465   (cmd_initdb, '', """
466 Initialize an attribute database.""")
467
468 def getid(local):
469   k = C.KeyFile(keyfile, C.KOPEN_READ)[tag]
470   id = Crypto(k).decrypt(M.base32_decode(local))
471   if id is None:
472     raise Reject, 'decrypt failed'
473   return id
474
475 def cmd_addrcheck(argv):
476   try:
477     opts, argv = getopt(argv, '', [])
478   except GetoptError:
479     return 1
480   local, sender = (lambda addr, sender = None, *hunoz: (addr, sender))(*argv)
481   db = CMDB(dbfile)
482   try:
483     id = getid(local)
484     addr = check(db, id, sender)
485   except Reject, msg:
486     print '-%s' % msg
487     return
488   print '+%s' % addr
489 commands['addrcheck'] = \
490   (cmd_addrcheck, 'LOCAL [SENDER [IGNORED ...]]', """
491 Check address token LOCAL, and report `-REASON' for failure or `+ADDR' for
492 success.""")
493
494 def cmd_fwaddr(argv):
495   try:
496     opts, argv = getopt(argv, '', [])
497   except GetoptError:
498     return 1
499   if len(argv) not in (1, 2):
500     return 1
501   local, sender = (lambda addr, sender = None: (addr, sender))(*argv)
502   db = CMDB(dbfile)
503   try:
504     id = getid(local)
505     if id is None:
506       raise Reject, 'decrypt failed'
507     addr = check(db, id, sender, stdin)
508   except Reject, msg:
509     print >>stderr, '%s rejected message: %s' % (prog, msg)
510     exit(100)
511   stdin.seek(0)
512   print addr
513 commands['fwaddr'] = \
514   (cmd_fwaddr, 'LOCAL [SENDER]', """
515 Check address token LOCAL.  On failure, report reason to stderr and exit
516 111.  On success, write forwarding address to stdout and exit 0.  Expects
517 the message on standard input, as a seekable file.""")
518
519 def cmd_info(argv):
520   try:
521     opts, argv = getopt(argv, '', [])
522   except GetoptError:
523     return 1
524   if len(argv) != 1:
525     return 1
526   local = argv[0]
527   db = CMDB(dbfile)
528   try:
529     id = getid(local)
530     a = AttrMultiMap(db, id)
531     if user is not None and user != a.get('user', [None])[0]:
532       raise Reject, 'not your token'
533     if 'addr' not in a:
534       die('unknown token (expired?)')
535     keys = a.keys()
536     keys.sort()
537     for k in keys:
538       for v in a[k]:
539         print '%s: %s' % (k, v)
540     expwhen = db.expiry(id)
541     if expwhen:
542       print 'expires: %s'
543     else:
544       print 'no-expiry'
545   except Reject, msg:
546     die('invalid token')
547 commands['info'] = \
548   (cmd_info, 'LOCAL', """
549 Exaimne the address token LOCAL, and print information about it to standard
550 output.""")
551
552 def cmd_revoke(argv):
553   try:
554     opts, argv = getopt(argv, '', [])
555   except GetoptError:
556     return 1
557   if len(argv) != 1:
558     return 1
559   local = argv[0]
560   db = CMDB(dbfile)
561   try:
562     id = getid(local)
563     a = AttrMultiMap(db, id)
564     if user is not None and user != a.get('user', [None])[0]:
565       raise Reject, 'not your token'
566     if 'addr' not in a:
567       die('unknown token (expired?)')
568     a.clear()
569     db.cleanup()
570     db.commit()
571   except Reject, msg:
572     die('invalid token')
573 commands['revoke'] = \
574   (cmd_revoke, 'LOCAL', """
575 Revoke the token LOCAL.""")
576
577 def cmd_cleanup(argv):
578   try:
579     opts, argv = getopt(argv, '', [])
580   except GetoptError:
581     return 1
582   db = CMDB(dbfile)
583   db.cleanup()
584   cur = db.db.cursor()
585   cur.execute('VACUUM')
586   db.commit()
587 commands['cleanup'] = \
588   (cmd_cleanup, '', """
589 Cleans up the attribute database, disposing of old records and compatifying
590 the file.""")
591
592 def cmd_help(argv):
593   try:
594     opts, argv = getopt(argv, '', [])
595   except GetoptError:
596     return 1
597   if len(argv) == 0:
598     cmd = None
599   elif len(argv) == 1:
600     try:
601       cmd = argv[0]
602       ci = commands[cmd]
603     except KeyError:
604       die("unknown command `%s'" % cmd)
605   else:
606     return 1    
607   version()
608   print
609   if cmd:
610     print 'Usage: %s [-OPTIONS] %s %s' % (prog, cmd, ci[1])
611     print ci[2]
612   else:
613     usage(stdout)
614     print """
615 Handle encrypted email addresses.
616
617 Help options:
618   -h, --help                    Show this help text.
619   -v, --version                 Show version number.
620   -u, --usage                   Show a usage message.
621
622 Global options:
623   -d, --database=FILE           Use FILE as the attribute database.
624   -k, --keyring=KEYRING         Use KEYRING as the keyring.
625   -t, --tag=TAG                 Use TAG as the key tag.
626   -U, --user=USER               Claim to be USER.
627 """
628     cmds = commands.keys()
629     cmds.sort()
630     print 'Subcommands:'
631     for c in cmds:
632       print '  %s %s' % (c, commands[c][1])
633 commands['help'] = \
634   (cmd_help, '[COMMAND]', """
635 Show help for subcommand COMMAND.
636 """)
637
638 ###----- Main program -------------------------------------------------------
639
640 def usage(file):
641   print >>file, \
642     'Usage: %s [-d FILE] [-k KEYRING] [-t TAG] COMMAND [ARGS...]' % prog
643 def version():
644   print '%s version 1.0.0' % prog
645 def help():
646   cmd_help()  
647
648 def main():
649   global argv, user, keyfile, dbfile, tag
650   try:
651     opts, argv = getopt(argv[1:],
652                         'hvud:k:t:U:',
653                         ['help', 'version', 'usage',
654                          'database=', 'keyring=', 'tag=', 'user='])
655   except GetoptError:
656     usage(stderr)
657     exit(111)
658   for o, a in opts:
659     if o in ('-h', '--help'):
660       help()
661       exit(0)
662     elif o in ('-v', '--version'):
663       version()
664       exit(0)
665     elif o in ('-u', '--usage'):
666       usage(stdout)
667       exit(0)
668     elif o in ('-d', '--database'):
669       dbfile = a
670     elif o in ('-k', '--keyring'):
671       keyfile = a
672     elif o in ('-t', '--tag'):
673       tag = a
674     elif o in ('-U', '--user'):
675       user = a
676     else:
677       raise 'Barf!'
678   if len(argv) < 1:
679     usage(stderr)
680     exit(111)
681
682   if argv[0] in commands:
683     c = argv[0]
684     argv = argv[1:]
685   else:
686     usage(stderr)
687     exit(111)
688   cmd = commands[c]
689   if cmd[0](argv):
690     print >>stderr, 'Usage: %s %s %s' % (prog, c, cmd[1])
691     exit(111)
692
693 try:
694   main()
695 except SystemExit:
696   raise
697 except:
698   ty, exc, tb = exc_info()
699   moan('unhandled %s exception' % ty.__name__)
700   for file, line, func, text in TB.extract_tb(tb):
701     print >>stderr, \
702           '  %-35s -- %.38s' % ('%s:%d (%s)' % (file, line, func), text)
703   die('%s: %s' % (ty.__name__, exc[0]))