chiark / gitweb /
make-secnet-sites: Refactor operational code into OpModes
[secnet.git] / make-secnet-sites
1 #! /usr/bin/env python3
2 #
3 # This file is part of secnet.
4 # See README for full list of copyright holders.
5 #
6 # secnet is free software; you can redistribute it and/or modify it
7 # under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
10
11 # secnet is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15
16 # You should have received a copy of the GNU General Public License
17 # version 3 along with secnet; if not, see
18 # https://www.gnu.org/licenses/gpl.html.
19
20 """VPN sites file manipulation.
21
22 This program enables VPN site descriptions to be submitted for
23 inclusion in a central database, and allows the resulting database to
24 be turned into a secnet configuration file.
25
26 A database file can be turned into a secnet configuration file simply:
27 make-secnet-sites.py [infile [outfile]]
28
29 It would be wise to run secnet with the "--just-check-config" option
30 before installing the output on a live system.
31
32 The program expects to be invoked via userv to manage the database; it
33 relies on the USERV_USER and USERV_GROUP environment variables. The
34 command line arguments for this invocation are:
35
36 make-secnet-sites.py -u header-filename groupfiles-directory output-file \
37   group
38
39 All but the last argument are expected to be set by userv; the 'group'
40 argument is provided by the user. A suitable userv configuration file
41 fragment is:
42
43 reset
44 no-disconnect-hup
45 no-suppress-args
46 cd ~/secnet/sites-test/
47 execute ~/secnet/make-secnet-sites.py -u vpnheader groupfiles sites
48
49 This program is part of secnet.
50
51 """
52
53 from __future__ import print_function
54 from __future__ import unicode_literals
55 from builtins import int
56
57 import string
58 import time
59 import sys
60 import os
61 import getopt
62 import re
63 import argparse
64 import math
65
66 import ipaddress
67
68 # entry 0 is "near the executable", or maybe from PYTHONPATH=.,
69 # which we don't want to preempt
70 sys.path.insert(1,"/usr/local/share/secnet")
71 sys.path.insert(1,"/usr/share/secnet")
72 import ipaddrset
73 import base91
74
75 from argparseactionnoyes import ActionNoYes
76
# Version string written into the "Checked by make-secnet-sites" header
# of group files.
VERSION="0.1.18"

# Highest sites-file output version; used as the default for
# --output-version.
max_version = 2

from sys import version_info
if version_info.major == 2:  # for python2
    # Under Python 2, wrap stdin/stdout in UTF-8 codecs and replace
    # open() with a version that always uses UTF-8 text mode.
    import codecs
    sys.stdin = codecs.getreader('utf-8')(sys.stdin)
    sys.stdout = codecs.getwriter('utf-8')(sys.stdout)
    import io
    open=lambda f,m='r': io.open(f,m,encoding='utf-8')

# Limits consulted by the Tainted validators: the '*_bits' entries bound
# the size of big numbers, 'name' and 'algname' are maximum string
# lengths.  Note that this deliberately shadows the builtin max().
max={'rsa_bits':8200,'name':33,'dh_bits':8200,'algname':127}
90
def debugrepr(*args):
        """Print the repr of the argument tuple on stderr when debugging.

        Does nothing unless the module-level debug_level is above zero.
        """
        if debug_level <= 0:
                return
        print(repr(args), file=sys.stderr)
94
def base91s_encode(bindata):
        """Encode bindata with base91, writing '-' in place of each '"'."""
        encoded = base91.encode(bindata)
        return encoded.replace('"', "-")
97
def base91s_decode(string):
        """Inverse of base91s_encode: turn '-' back into '"', then decode."""
        restored = string.replace("-", '"')
        return base91.decode(restored)
100
class Tainted:
        """An untrusted string which must be syntax-checked before use.

        The wrapped value can only be obtained through one of the
        checking methods (name, number, host, ...) or through raw().
        Each checking method records the outcome in self._ok:
        None = not yet checked, True = accepted, False = rejected.
        A rejected value is reported via complain() and replaced by a
        harmless substitute in the checker's return value.
        """
        def __init__(self,s,tline=None,tfile=None):
                """Wrap s; tline/tfile default to the module-level
                `line` and `file` (the current input position)."""
                self._s=s
                self._ok=None
                self._line=line if tline is None else tline
                self._file=file if tfile is None else tfile
        def __eq__(self,e):
                return self._s==e
        def __ne__(self,e):
                # for Python2
                return not self.__eq__(e)
        def __str__(self):
                # Formatting a Tainted directly would bypass the checks.
                raise RuntimeError('direct use of Tainted value')
        def __repr__(self):
                return 'Tainted(%s)' % repr(self._s)

        def _bad(self,what,why):
                """Mark the value rejected, complain, and return False."""
                assert(self._ok is not True)
                self._ok=False
                complain('bad parameter: %s: %s' % (what, why))
                return False

        def _max_ok(self,what,maxlen):
                """Reject the value if it is longer than maxlen."""
                if len(self._s) > maxlen:
                        return self._bad(what,'too long (max %d)' % maxlen)
                return True

        def _re_ok(self,bad,what,maxlen=None):
                """Check length, then reject if regexp `bad` matches.

                maxlen defaults to the module-level max[what].
                """
                if maxlen is None: maxlen=max[what]
                self._max_ok(what,maxlen)
                if self._ok is False: return False
                if bad.search(self._s):
                        return self._bad(what,'bad syntax')
                return True

        def _rtnval(self, is_ok, ifgood, ifbad=''):
                """Record the verdict and return ifgood or ifbad."""
                if is_ok:
                        assert(self._ok is not False)
                        self._ok=True
                        return ifgood
                else:
                        assert(self._ok is not True)
                        self._ok=False
                        return ifbad

        def _rtn(self, is_ok, ifbad=''):
                """Like _rtnval, with the wrapped string as the good value."""
                return self._rtnval(is_ok, self._s, ifbad)

        def raw(self):
                """Return the unchecked string without changing the verdict."""
                return self._s
        def raw_mark_ok(self):
                # caller promises to throw if syntax was dangerous
                return self._rtn(True)

        def output(self):
                """Return the value for output: '' if rejected, the string
                if accepted; exit with an error if it was never checked."""
                if self._ok is False: return ''
                if self._ok is True: return self._s
                print('%s:%d: unchecked/unknown additional data "%s"' %
                      (self._file,self._line,self._s),
                      file=sys.stderr)
                sys.exit(1)

        bad_name=re.compile(r'^[^a-zA-Z]|[^-_0-9a-zA-Z]')
        # secnet accepts _ at start of names, but we reserve that
        bad_name_counter=0
        def name(self,what='name'):
                """Check a name; a bad one is replaced by a unique
                placeholder starting with '_'."""
                ok=self._re_ok(Tainted.bad_name,what)
                return self._rtn(ok,
                                 '_line%d_%s' % (self._line, id(self)))

        def keyword(self):
                """Check that the value is a known keyword or level name."""
                ok=self._s in keywords or self._s in levels
                if not ok:
                        complain('unknown keyword %s' % self._s)
                return self._rtn(ok)

        bad_hex=re.compile(r'[^0-9a-fA-F]')
        def bignum_16(self,kind,what):
                """Check a hex bignum of at most max[kind+'_bits'] bits."""
                # Integer division: one hex digit per 4 bits, rounded up.
                maxlen=(max[kind+'_bits']+3)//4
                ok=self._re_ok(Tainted.bad_hex,what,maxlen)
                return self._rtn(ok)

        bad_num=re.compile(r'[^0-9]')
        def bignum_10(self,kind,what):
                """Check a decimal bignum of at most max[kind+'_bits'] bits."""
                # An n-bit number has at most ceil(n*log10(2)) decimal
                # digits.  (Dividing by log10(2) instead would allow
                # roughly eleven times too many digits.)
                maxlen=int(math.ceil(max[kind+'_bits'] * math.log10(2)))
                ok=self._re_ok(Tainted.bad_num,what,maxlen)
                return self._rtn(ok)

        def number(self,minn,maxx,what='number'):
                """Check a small decimal number in minn..maxx and return
                it as an int; a bad value yields minn."""
                # not for bignums
                v=minn  # placeholder so v is always bound on the bad path
                ok=self._re_ok(Tainted.bad_num,what,10)
                if ok and not self._s:
                        # bad_num cannot match an empty string, but
                        # int('') would raise ValueError.
                        ok=self._bad(what,'empty')
                if ok:
                        v=int(self._s)
                        if v<minn or v>maxx:
                                ok=self._bad(what,'out of range %d..%d'
                                             % (minn,maxx))
                return self._rtnval(ok,v,minn)

        def hexid(self,byteslen,what):
                """Check a hex id of exactly byteslen bytes; a bad value
                yields all zeroes."""
                ok=self._re_ok(Tainted.bad_hex,what,byteslen*2)
                if ok:
                        if len(self._s) < byteslen*2:
                                ok=self._bad(what,'too short')
                return self._rtn(ok,ifbad='00'*byteslen)

        bad_host=re.compile(r'[^-\][_.:0-9a-zA-Z]')
        # We permit _ so we can refer to special non-host domains
        # which have A and AAAA RRs.  This is a crude check and we may
        # still produce config files with syntactically invalid
        # domains or addresses, but that is OK.
        def host(self):
                """Check a host name or address literal."""
                ok=self._re_ok(Tainted.bad_host,'host/address',255)
                return self._rtn(ok)

        bad_email=re.compile(r'[^-._0-9a-z@!$%^&*=+~/]')
        # ^ This does not accept all valid email addresses.  That's
        # not really possible with this input syntax.  It accepts
        # all ones that don't require quoting anywhere in email
        # protocols (and also accepts some invalid ones).
        def email(self):
                """Check an email address."""
                ok=self._re_ok(Tainted.bad_email,'email address',1023)
                return self._rtn(ok)

        bad_groupname=re.compile(r'^[^_A-Za-z]|[^-+_0-9A-Za-z]')
        def groupname(self):
                """Check a group name."""
                ok=self._re_ok(Tainted.bad_groupname,'group name',64)
                return self._rtn(ok)

        bad_base91=re.compile(r'[^!-~]|[\'\"\\]')
        def base91(self,what='base91'):
                """Check a base91s-encoded string (printable ASCII without
                quotes or backslash)."""
                ok=self._re_ok(Tainted.bad_base91,what,4096)
                return self._rtn(ok)
234
class ArgActionLambda(argparse.Action):
        """argparse action which just forwards to a caller-supplied function.

        The function is called as fn(values, namespace, parser,
        option_string) every time the option is seen.
        """
        def __init__(self, fn, **kwargs):
                super(ArgActionLambda, self).__init__(**kwargs)
                self.fn = fn

        def __call__(self, ap, ns, values, option_string):
                callback = self.fn
                callback(values, ns, ap, option_string)
241
class PkmBase():
        """Base class for public-key output modes.

        site_start remembers the key file path and creates a fresh
        FilterState; the other hooks do nothing by default.
        """
        def site_start(self,pubkeys_path):
                self._pa=pubkeys_path
                self._fs = FilterState()

        def site_serial(self,serial):
                return None

        def write_key(self,k):
                return None

        def site_finish(self,confw):
                return None
249
class PkmSingle(PkmBase):
        """Public-key mode which writes exactly one key per site inline.

        Keys are collected during write_key (only those usable as the
        sole key for the selected output version); site_finish emits a
        single "key ...;" line, or complains when there is not exactly
        one candidate.
        """
        opt = 'single'
        help = 'write one public key per site to sites.conf'
        def site_start(self,pubkeys_path):
                PkmBase.site_start(self,pubkeys_path)
                self._outk = []
        def write_key(self,k):
                if k.okforonlykey(output_version,self._fs):
                        self._outk.append(k)
        def site_finish(self,confw):
                nkeys = len(self._outk)
                if nkeys == 0:
                        complain("site with no public key")
                elif nkeys != 1:
                        debugrepr('outk ', self._outk)
                        # Message previously lacked the closing parenthesis.
                        complain(
 "site with multiple public keys, without --pubkeys-install (maybe --output-version=1 would help)"
                        )
                else:
                        confw.write("key %s;\n"%str(self._outk[0]))
269
class PkmInstall(PkmBase):
        """Public-key mode which writes each site's keys to its own file.

        The keys go to "<path>~tmp", which is renamed to "<path>~update"
        when the site is finished; the config output then gets the same
        peer-keys line as PkmElide.
        """
        opt = 'install'
        help = 'install public keys in public key directory'

        def site_start(self,pubkeys_path):
                PkmBase.site_start(self,pubkeys_path)
                tmp_path = self._pa+'~tmp'
                self._pw=open(tmp_path,'w')

        def site_serial(self,serial):
                self._pw.write('serial %s\n' % serial)

        def write_key(self,k):
                fields = k.forpub(output_version,self._fs)
                self._pw.write(' '.join(fields) + '\n')

        def site_finish(self,confw):
                self._pw.close()
                base = self._pa
                os.rename(base+'~tmp', base+'~update')
                PkmElide.site_finish(self,confw)
286
class PkmElide(PkmBase):
        """Public-key mode which writes no keys, only a peer-keys reference."""
        opt = 'elide'
        help = 'no public keys in sites.conf output nor in directory'

        def site_finish(self,confw):
                line_out = 'peer-keys "%s";\n' % self._pa
                confw.write(line_out)
292
class OpBase():
        """Common behaviour of the operating modes."""
        # Base case is reading a sites file from self.inputfile.
        def read_in(self):
                """Parse self.inputfile, or stdin when it is None."""
                if self.inputfile is None:
                        pfile("stdin",sys.stdin.readlines())
                else:
                        pfilepath(self.inputfile)

class OpConf(OpBase):
        """Default mode: turn a sites file into a secnet config file."""
        def is_service(self): return 0
        def positional_args(self, av):
                """Take [infile [outfile]] from av.arg; missing ones are None.

                Exits with status 1 if more than two are given.
                """
                # av.arg holds only the positional arguments (no program
                # name), so anything beyond two is an error.
                if len(av.arg)>2:
                        print("Too many arguments")
                        sys.exit(1)
                (self.inputfile, self.outputfile) = (av.arg + [None]*2)[0:2]
        def check_group(self,group,w): pass
        def write_out(self):
                """Write the configuration to stdout, or atomically (via a
                temporary file and rename) to self.outputfile."""
                if self.outputfile is None:
                        outputsites(sys.stdout)
                        return
                tmp_outputfile=self.outputfile+'~tmp~'
                # Close (and so flush) the temporary file before it is
                # renamed into place.
                with open(tmp_outputfile,'w') as of:
                        outputsites(of)
                os.rename(tmp_outputfile,self.outputfile)
318
class OpUserv(OpBase):
        """userv service mode: accept one group's section of the sites file.

        The caller's input (stdin) is checked, stored as that group's
        file in the group files directory, and the combined sites file
        is rebuilt from the header and all group files.
        """
        opts = ['--userv','-u']
        help = 'userv service fragment update mode'
        def is_service(self): return 1
        def positional_args(self, av):
                """Take header, groupfiles dir, sites file and group from
                av.arg, and check the caller is in that group according
                to USERV_GROUP.  Exits with status 1 on any problem."""
                if len(av.arg)!=4:
                        print("Wrong number of arguments")
                        sys.exit(1)
                (self.header, self.groupfiledir,
                 self.sitesfile, self.group) = av.arg
                self.group = Tainted(self.group,0,'command line')
                # untrusted argument from caller
                if "USERV_USER" not in os.environ:
                        print("Environment variable USERV_USER not found")
                        sys.exit(1)
                self.user=os.environ["USERV_USER"]
                # Check that group is in USERV_GROUP
                if "USERV_GROUP" not in os.environ:
                        print("Environment variable USERV_GROUP not found")
                        sys.exit(1)
                ugs=os.environ["USERV_GROUP"]
                if not any(self.group==g for g in ugs.split()):
                        # self.group is Tainted and must not be formatted
                        # directly; echo back the caller's own argument.
                        print("caller not in group %s"%self.group.raw())
                        sys.exit(1)
        def check_group(self,group,w):
                """Complain unless `group` is the caller's group, and mark
                w[2] as checked group-name syntax."""
                if group!=self.group: complain("Incorrect group!")
                w[2].groupname()
        def read_in(self):
                """Parse the header file, then the user's input from stdin."""
                self.headerinput=pfilepath(self.header,allow_include=True)
                self.userinput=sys.stdin.readlines()
                pfile("user input",self.userinput)
        def write_out(self):
                """Install the user's section and regenerate the sites file."""
                # Put the user's input into their group file, and
                # rebuild the main sites file
                gname=self.group.groupname()
                tmp_group=self.groupfiledir+"/T"+gname
                with open(tmp_group,'w') as f:
                        f.write("# Section submitted by user %s, %s\n"%
                                (self.user,
                                 time.asctime(time.localtime(time.time()))))
                        f.write("# Checked by make-secnet-sites version %s\n\n"
                                %VERSION)
                        for i in self.userinput: f.write(i)
                        f.write("\n")
                os.rename(tmp_group,self.groupfiledir+"/R"+gname)
                tmp_sites=self.sitesfile+"-tmp"
                with open(tmp_sites,'w') as f:
                        f.write("# sites file autogenerated by make-secnet-sites\n")
                        f.write("# generated %s, invoked by %s\n"%
                                (time.asctime(time.localtime(time.time())),
                                 self.user))
                        f.write("# use make-secnet-sites to turn this file into a\n")
                        f.write("# valid /etc/secnet/sites.conf file\n\n")
                        for i in self.headerinput: f.write(i)
                        # os.listdir returns entries in arbitrary order;
                        # sort so the generated file is reproducible.
                        for i in sorted(os.listdir(self.groupfiledir)):
                                if not i.startswith('R'): continue
                                with open(self.groupfiledir+"/"+i) as j:
                                        f.write(j.read())
                        f.write("# end of sites file\n")
                os.rename(tmp_sites,self.sitesfile)
383                 
384
def parse_args():
        """Parse the command line and publish the results as module globals.

        Sets opmode, service, prefix, key_prefix, debug_level,
        output_version, pubkeys_dir and pubkeys_mode, then passes the
        parsed arguments to the selected operating mode so it can pick
        up its positional arguments.
        """
        global opmode, service, prefix, key_prefix
        global debug_level, output_version, pubkeys_dir, pubkeys_mode

        parser = argparse.ArgumentParser(
                description='process secnet sites files')

        def selector(attr, how):
                # Callback for ArgActionLambda: store class `how` on the
                # namespace under the name `attr`.
                def record(v, ns, *x):
                        setattr(ns, attr, how)
                return record

        def add_opmode(how):
                sample = how()
                parser.add_argument(*sample.opts, action=ArgActionLambda,
                        nargs=0,
                        fn=selector('opmode', how),
                        help=sample.help)

        def add_pkm(how):
                sample = how()
                parser.add_argument('--pubkeys-'+sample.opt,
                        action=ArgActionLambda,
                        nargs=0,
                        fn=selector('pkm', how),
                        help=sample.help)

        add_opmode(OpUserv)
        parser.add_argument('--conf-key-prefix', action=ActionNoYes,
                        default=True,
                 help='prefix conf file key names derived from sites data')
        for pkm_class in (PkmInstall, PkmSingle, PkmElide):
                add_pkm(pkm_class)
        parser.add_argument('--pubkeys-dir',  nargs=1,
                        help='public key directory',
                        default=['/var/lib/secnet/pubkeys'])
        parser.add_argument('--output-version', nargs=1, type=int,
                        help='sites file output version',
                        default=[max_version])
        parser.add_argument('--prefix', '-P', nargs=1,
                        help='set prefix')
        parser.add_argument('--debug', '-D', action='count', default=0)
        parser.add_argument('arg',nargs=argparse.REMAINDER)

        av = parser.parse_args()
        debug_level = av.debug
        debugrepr('av',av)

        # Without an explicit mode option we generate a config file.
        mode_class = getattr(av,'opmode',OpConf)
        opmode = mode_class()
        service = opmode.is_service()
        if av.prefix is None:
                prefix = ''
        else:
                prefix = av.prefix[0]
        key_prefix = av.conf_key_prefix
        output_version = av.output_version[0]
        pubkeys_dir = av.pubkeys_dir[0]
        pubkeys_mode = getattr(av,'pkm',PkmSingle)
        opmode.positional_args(av)
434
# Parse the command line immediately; this sets the module-level option
# globals (opmode, prefix, key_prefix, output_version, ...) which the
# code below reads.
parse_args()
436
437 # Classes describing possible datatypes in the configuration file
438
class basetype:
        "Common protocol for configuration types."

        def add(self,obj,w):
                # Default: a property of this type may be given only once.
                details = (obj.type,obj.name,w[0].raw())
                complain("%s %s already has property %s defined" % details)

        def forsites(self,version,copy,fs):
                # Default: pass the caller-supplied copy through unchanged.
                return copy
446
class conflist:
        "A list of some kind of configuration type."

        def __init__(self,subtype,w):
                first = subtype(w)
                self.subtype=subtype
                self.list=[first]

        def add(self,obj,w):
                # Each further occurrence appends another element.
                item = self.subtype(w)
                self.list.append(item)

        def __str__(self):
                return ', '.join(str(item) for item in self.list)

        def forsites(self,version,copy,fs):
                # Delegate to the element added last.
                return self.list[-1].forsites(version,copy,fs)
def listof(subtype):
        """Return a constructor making a conflist of subtype from words w."""
        def make(w):
                return conflist(subtype, w)
        return make
461
class single_ipaddr (basetype):
        "An IP address"

        def __init__(self,w):
                # ip_address() raises on bad syntax, which is what permits
                # marking the raw text as OK.
                text = w[1].raw_mark_ok()
                self.addr=ipaddress.ip_address(text)

        def __str__(self):
                return '"%s"' % (self.addr,)
468
class networks (basetype):
        "A set of IP addresses specified as a list of networks"

        def __init__(self,w):
                self.set=ipaddrset.IPAddressSet()
                for word in w[1:]:
                        # strict=True: reject networks with host bits set
                        net=ipaddress.ip_network(word.raw_mark_ok(),
                                                 strict=True)
                        self.set.append([net])

        def __str__(self):
                quoted = ['"%s"' % (n,) for n in self.set.networks()]
                return ",".join(quoted)
478
class dhgroup (basetype):
        "A Diffie-Hellman group"

        def __init__(self,w):
                modulus, generator = w[1], w[2]
                self.mod=modulus.bignum_16('dh','dh mod')
                self.gen=generator.bignum_16('dh','dh gen')

        def __str__(self):
                params = (self.mod,self.gen)
                return 'diffie-hellman("%s","%s")' % params
486
class hash (basetype):
        "A choice of hash function"

        def __init__(self,w):
                hname=w[1]
                self.ht=hname.raw()
                if self.ht in ('md5','sha1'):
                        hname.raw_mark_ok()
                else:
                        complain("unknown hash type %s"%(self.ht))
                        self.ht=None

        def __str__(self):
                return '%s' % (self.ht,)
499
class email (basetype):
        "An email address"

        def __init__(self,w):
                checked = w[1].email()
                self.addr=checked

        def __str__(self):
                return '<%s>' % (self.addr,)
506
class boolean (basetype):
        "A boolean"
        def __init__(self,w):
                v=w[1]
                # Default so that __str__ still works after an invalid
                # value has been complained about (previously self.b was
                # left unset, giving AttributeError later).
                self.b=False
                # Only the first character is examined.
                if re.match('[TtYy1]',v.raw()):
                        self.b=True
                        v.raw_mark_ok()
                elif re.match('[FfNn0]',v.raw()):
                        self.b=False
                        v.raw_mark_ok()
                else:
                        complain("invalid boolean value")
        def __str__(self):
                return ['False','True'][self.b]
521
class num (basetype):
        "A decimal number"

        def __init__(self,w):
                # Accepts 0 .. 2^31-1
                upper = 0x7fffffff
                self.n=w[1].number(0,upper)

        def __str__(self):
                return '%d' % (self.n,)
528
class serial (basetype):
        "A 4-byte hex serial number; omitted from output versions below 2"

        def __init__(self,w):
                self.i=w[1].hexid(4,'serial')

        def __str__(self):
                return self.i

        def forsites(self,version,copy,fs):
                return copy if version >= 2 else []
537
class address (basetype):
        "A DNS name and UDP port number"
        def __init__(self,w):
                self.adr=w[1].host()
                # UDP ports are 16 bits wide, so 65535 is the largest
                # valid value (the previous bound of 65536 let an
                # unusable port through).
                self.port=w[2].number(1,65535,'port')
        def __str__(self):
                return '"%s"; port %d'%(self.adr,self.port)
545
class inpub (basetype):
        "Base for items whose sites-file form is their forpub() form"

        def forsites(self,version,xcopy,fs):
                # xcopy is ignored; the subclass's forpub() decides the words.
                words = self.forpub(version,fs)
                return words
549
class pubkey (inpub):
        "Some kind of public key"

        def __init__(self,w):
                algorithm, data = w[1], w[2]
                self.a=algorithm.name('algname')
                self.d=data.base91()

        def __str__(self):
                return 'make-public("%s","%s")' % (self.a,self.d)

        def forpub(self,version,fs):
                # New-style keys cannot be expressed before version 2.
                return ['pub', self.a, self.d] if version >= 2 else []

        def okforonlykey(self,version,fs):
                words = self.forpub(version,fs)
                return len(words) != 0
562
class rsakey (pubkey):
        "An old-style RSA public key"

        def __init__(self,w):
                self.l=w[1].number(0,max['rsa_bits'],'rsa len')
                self.e=w[2].bignum_10('rsa','rsa e')
                self.n=w[3].bignum_10('rsa','rsa n')
                if len(w) >= 5:
                        w[4].email()
                self.a='rsa1'
                # Also keep the key in new-style (algorithm, base91 data)
                # form, so that pubkey.forpub() can be used for output in
                # versions>=2
                packed = b'%d %s %s' % (self.l,
                                        self.e.encode('ascii'),
                                        self.n.encode('ascii'))
                self.d=base91s_encode(packed)

        def __str__(self):
                # this specialisation means we can generate files
                # compatible with old secnet executables
                return 'rsa-public("%s","%s")' % (self.e,self.n)

        def forpub(self,version,fs):
                if version >= 2:
                        return pubkey.forpub(self,version,fs)
                if fs.pkg != '00000000':
                        return []
                return ['pubkey', str(self.l), self.e, self.n]
586
class rsakey_newfmt(rsakey):
        "An old-style RSA public key in new-style sites format"
        # This is its own class simply to have its own constructor.
        def __init__(self,w):
                """Unpack "pub rsa1 <base91>" into the old-style fields.

                The base91 data decodes to "<len> <e> <n>", which is
                re-validated by rsakey.__init__.  If it cannot be
                decoded, or has too few fields, the problem is reported
                with complain() and harmless placeholder values are
                stored instead of raising.
                """
                self.a=w[1].name()
                assert(self.a == 'rsa1')
                self.d=w[2].base91()
                fields=None
                try:
                        fields=(base91s_decode(self.d)
                                .decode('ascii')
                                .split(' '))
                except UnicodeDecodeError:
                        complain('rsa1 key in new format has bad base91')
                if fields is not None and len(fields) < 3:
                        # rsakey.__init__ indexes w[1..3]
                        complain('rsa1 key in new format has too few fields')
                        fields=None
                if fields is None:
                        # Already complained; keep the object well-formed.
                        self.l=0
                        self.e=''
                        self.n=''
                        return
                w_inner=list(map(Tainted, ['X-PUB-RSA1'] + fields))
                rsakey.__init__(self,w_inner)
604
class pubkey_group(inpub):
        "Public key group introducer"
        # appears in the site's list of keys mixed in with the keys

        def __init__(self,w,fallback):
                self.fallback=fallback
                self.i=w[1].hexid(4,'pkg-id')

        def forpub(self,version,fs):
                # Always record the current group in the filter state,
                # even when nothing is emitted.
                fs.pkg=self.i
                if version >= 2:
                        keyword = 'pkgf' if self.fallback else 'pkg'
                        return [keyword, self.i]
                return []

        def okforonlykey(self,version,fs):
                # Called for its effect on fs; a group introducer is never
                # itself a key.
                self.forpub(version,fs)
                return False
618         
def somepubkey(w):
        """Construct the right key-related object for the words w.

        w[0] selects the kind: 'pubkey', 'pub' (with 'rsa1' in w[1]
        handled specially), 'pkg' or 'pkgf'.
        """
        kw = w[0]
        if kw == 'pubkey':
                return rsakey(w)
        if kw == 'pub':
                if w[1] == 'rsa1':
                        return rsakey_newfmt(w)
                return pubkey(w)
        if kw == 'pkg':
                return pubkey_group(w,False)
        if kw == 'pkgf':
                return pubkey_group(w,True)
        assert(False)
633
634 # Possible properties of configuration nodes
# Each entry maps a keyword to a tuple:
#   (constructor taking the line's words, human-readable description
#    [, third element])
# NOTE(review): the optional third element ('pub' on every key-related
# keyword) looks like the property name the value is stored under --
# confirm against the parser, which is not in this part of the file.
keywords={
 'contact':(email,"Contact address"),
 'dh':(dhgroup,"Diffie-Hellman group"),
 'hash':(hash,"Hash function"),
 'key-lifetime':(num,"Maximum key lifetime (ms)"),
 'setup-timeout':(num,"Key setup timeout (ms)"),
 'setup-retries':(num,"Maximum key setup packet retries"),
 'wait-time':(num,"Time to wait after unsuccessful key setup (ms)"),
 'renegotiate-time':(num,"Time after key setup to begin renegotiation (ms)"),
 'restrict-nets':(networks,"Allowable networks"),
 'networks':(networks,"Claimed networks"),
 'serial':(serial,"public key set serial"),
 'pkg':(listof(somepubkey),"start of public key group",'pub'),
 'pkgf':(listof(somepubkey),"start of fallback public key group",'pub'),
 'pub':(listof(somepubkey),"new style public site key"),
 'pubkey':(listof(somepubkey),"Old-style RSA public site key",'pub'),
 'peer':(single_ipaddr,"Tunnel peer IP address"),
 'address':(address,"External contact address and port"),
 'mobile':(boolean,"Site is mobile"),
}
655
def sp(name,value):
        "Simply output a property - the default case"
        fields = (name,value)
        return "%s %s;\n" % fields
659
660 # All levels support these properties
# Maps property name to a formatter which is called as fn(name, value)
# and returns the text to write to the config file.  sp gives the plain
# "name value;" form; 'contact' and 'restrict-nets' are emitted only as
# comments.
global_properties={
        'contact':(lambda name,value:"# Contact email address: %s\n"%(value)),
        'dh':sp,
        'hash':sp,
        'key-lifetime':sp,
        'setup-timeout':sp,
        'setup-retries':sp,
        'wait-time':sp,
        'renegotiate-time':sp,
        'restrict-nets':(lambda name,value:"# restrict-nets %s\n"%value),
}
672
class level:
        "A level in the configuration hierarchy"
        depth=0
        leaf=0
        allow_properties={}
        require_properties={}

        def __init__(self,w):
                self.type=w[0].keyword()
                self.name=w[1].name()
                self.properties={}
                self.children={}

        def indent(self,w,t):
                # Write t spaces (limited by the length of the literal)
                w.write("                 "[:t])

        def prop_out(self,n):
                # Format property n with the formatter registered for it
                formatter = self.allow_properties[n]
                return formatter(n,str(self.properties[n]))

        def output_props(self,w,ind):
                # Properties whose formatter is None (or otherwise false)
                # are not written here.
                for pname in sorted(self.properties):
                        if not self.allow_properties[pname]:
                                continue
                        self.indent(w,ind)
                        w.write("%s"%self.prop_out(pname))

        def kname(self):
                # Config-file key: name, optionally prefixed with the
                # upper-cased first letter of the level type.
                if key_prefix:
                        return self.type[0].upper() + self.name
                return '' + self.name

        def output_data(self,w,path):
                ind = 2*len(path)
                self.indent(w,ind)
                w.write("%s {\n"%(self.kname()))
                self.output_props(w,ind+2)
                if self.depth==1:
                        w.write("\n")
                for key in sorted(self.children):
                        child=self.children[key]
                        child.output_data(w,path+(child,))
                self.indent(w,ind)
                w.write("};\n")
707
class vpnlevel(level):
        "VPN level in the configuration hierarchy"
        depth=1
        leaf=0
        type="vpn"
        allow_properties=global_properties.copy()
        require_properties={
         'contact':"VPN admin contact address"
        }

        def __init__(self,w):
                level.__init__(self,w)

        def output_vpnflat(self,w,path):
                "Output flattened list of site names for this VPN"
                ind=2*(len(path)+1)
                inner_path=path+(self,)
                self.indent(w,ind)
                w.write("%s {\n"%(self.kname()))
                for child in self.children.values():
                        child.output_vpnflat(w,inner_path)
                w.write("\n")
                self.indent(w,ind+2)
                child_names=[c.kname() for c in self.children.values()]
                w.write("all-sites %s;\n" % ','.join(child_names))
                self.indent(w,ind)
                w.write("};\n")
733
class locationlevel(level):
        "Location level in the configuration hierarchy"
        depth=2
        leaf=0
        type="location"
        allow_properties=global_properties.copy()
        require_properties={
         'contact':"Location admin contact address",
        }

        def __init__(self,w):
                level.__init__(self,w)
                self.group=w[2].groupname()

        def output_vpnflat(self,w,path):
                "Output this location as a list of paths to its sites"
                ind=2*(len(path)+1)
                self.indent(w,ind)

                def site_path(site):
                        # <prefix>vpn-data/<each enclosing level>/<site>
                        parts=[prefix+"vpn-data"]
                        parts.extend(i.kname() for i in path+(self,site))
                        return '/'.join(parts)

                site_paths=[site_path(x) for x in self.children.values()]
                w.write("%s %s;\n"%(self.kname(),','.join(site_paths)))
759
class sitelevel(level):
        "Site level (i.e. a leafnode) in the configuration hierarchy"
        depth=3
        leaf=1
        type="site"
        # Sites accept every global property plus the site-only ones below.
        allow_properties=global_properties.copy()
        allow_properties.update({
         'address': sp,
         'networks': None,
         'peer': None,
         'serial': None,
         'pkg': None,
         'pkgf': None,
         'pub': None,
         'pubkey': None,
         'mobile': sp,
        })
        require_properties={
         'dh': "Diffie-Hellman group",
         'contact': "Site admin contact address",
         'networks': "Networks claimed by the site",
         'hash': "hash function",
         'peer': "Gateway address of the site",
        }

        def mangle_name(self):
                "Return the site name with every '/' replaced by ','"
                return self.name.replace('/', ',')

        def pubkeys_path(self):
                "Return pubkeys_dir + '/peer.' + the mangled site name"
                return pubkeys_dir + '/peer.' + self.mangle_name()

        def __init__(self,w):
                "Set up the node by handing w unchanged to level.__init__"
                level.__init__(self, w)

        def output_data(self,w,path):
                """Write this site's "KNAME { ... };" block to w.

                The block holds a "name" line built from the names of the
                nodes in path, whatever the pubkeys_mode() object emits in
                site_finish, the output of output_props, and a
                "link netlink" section carrying the 'networks' and 'peer'
                properties.
                """
                outer=2*len(path)
                inner=outer+2
                full_name='/'.join(node.name for node in path)
                self.indent(w,outer)
                w.write("%s {\n"%(self.kname()))
                self.indent(w,inner)
                w.write('name "%s";\n'%(full_name,))
                self.indent(w,inner)

                # Feed the key path, the optional serial and every 'pub'
                # entry to the pubkeys handler, then let it write to w.
                pkm=pubkeys_mode()
                debugrepr('pkm ',pkm)
                pkm.site_start(self.pubkeys_path())
                if 'serial' in self.properties:
                        pkm.site_serial(self.properties['serial'])
                for key in self.properties["pub"].list:
                        debugrepr('pubkeys ', key)
                        pkm.write_key(key)
                pkm.site_finish(w)

                self.output_props(w,inner)
                self.indent(w,inner)
                w.write("link netlink {\n")
                self.indent(w,inner+2)
                w.write("routes %s;\n"%str(self.properties["networks"]))
                self.indent(w,inner+2)
                w.write("ptp-address %s;\n"%str(self.properties["peer"]))
                self.indent(w,inner)
                w.write("};\n")
                self.indent(w,outer)
                w.write("};\n")
822
# Levels in the configuration file: maps the keyword that opens a level
# to the class implementing that level.
levels={
        'vpn': vpnlevel,
        'location': locationlevel,
        'site': sitelevel,
}
826
def complain(msg):
        "Report msg via moan, prefixed with the current input file and line"
        where="%s line %d: "%(file,line)
        moan(where+msg)
def moan(msg):
        """Print msg and count it in the global complaints total.

        When complaints is None the program exits with status 1 right
        after printing instead of counting.
        """
        global complaints
        print(msg)
        if complaints is None:
                sys.exit(1)
        complaints+=1
836
class UntaintedRoot:
        "Wrapper around a string that returns it from name() and keyword()"

        def __init__(self,s):
                self._s=s

        def name(self):
                "Return the wrapped string"
                return self._s

        def keyword(self):
                "Return the wrapped string"
                return self._s
841
# All vpns are children of this node
root=level([UntaintedRoot('root'), UntaintedRoot('root')])
# Stack of nodes from the root down to the one currently being defined
obstack=[root]
# Level above which new definitions are permitted
allow_defs=0
846
def set_property(obj,w):
        """Set a property on a configuration node and return it.

        w[0] names the property.  Its entry in keywords supplies the
        constructor (element 0) and, when at least three elements long,
        the name to store the property under (element 2, for aliases).
        An existing property is extended with add(obj,w); otherwise a new
        one is constructed from w.
        """
        name=w[0].raw_mark_ok()
        kw=keywords[name]
        if len(kw) >= 3:
                name=kw[2] # for aliases
        if name not in obj.properties:
                obj.properties[name]=kw[0](w)
        else:
                obj.properties[name].add(obj,w)
        return obj.properties[name]
858
class FilterState:
        "State carried from line to line while processing one file"

        def __init__(self):
                self.reset()

        def reset(self):
                """Return to the initial state.

                Called when we enter a new node; in particular, at the
                start of each site.
                """
                self.pkg='00000000'
866
def pline(il,filterstate,allow_include=False):
        """Process a configuration file line and return its output lines.

        il is the raw line, filterstate the FilterState of the file being
        read, and allow_include says whether "include" may be used here.
        The return value is a list of strings: [''] for a blank line, []
        when the line was rejected with a complaint, the processed lines
        of the included file for "include", and otherwise the (possibly
        rewritten) line indented according to the nesting depth.

        Updates the globals allow_defs and obstack as levels are entered.
        """
        global allow_defs, obstack, root, file, line
        w=il.rstrip('\n').split()
        if len(w)==0: return ['']
        w=[Tainted(x) for x in w]
        keyword=w[0]
        current=obstack[-1]
        copyout_core=lambda: ' '.join([ww.output() for ww in w])
        # The indentation reflects the nesting depth before this line
        # takes effect.
        indent='    '*len(obstack)
        copyout=lambda: [indent + copyout_core() + '\n']
        if keyword=='end-definitions':
                keyword.raw_mark_ok()
                allow_defs=sitelevel.depth
                obstack=[root]
                return copyout()
        if keyword=='include':
                if not allow_include:
                        complain("include not permitted here")
                        return []
                if len(w) != 2:
                        complain("include requires one argument")
                        return []
                newfile=os.path.join(os.path.dirname(file),w[1].raw_mark_ok())
                # ^ user of "include" is trusted so raw_mark_ok is good
                # Processing the included file overwrites the globals
                # file and line; put them back so that later complaints
                # and later relative includes refer to this file again.
                saved_file,saved_line=file,line
                try:
                        return pfilepath(newfile,allow_include=allow_include)
                finally:
                        file,line=saved_file,saved_line
        if keyword.raw() in levels:
                if len(w)<2:
                        # Without this check w[1] below raises IndexError
                        complain("%s requires a name"%keyword.raw())
                        return []
                # We may go up any number of levels, but only down by one
                newdepth=levels[keyword.raw_mark_ok()].depth
                currentdepth=len(obstack) # actually +1...
                if newdepth<=currentdepth:
                        obstack=obstack[:newdepth]
                if newdepth>currentdepth:
                        complain("May not go from level %d to level %d"%
                                (currentdepth-1,newdepth))
                # See if it's a new one (and whether that's permitted)
                # or an existing one
                current=obstack[-1]
                tname=w[1].name()
                if tname in current.children:
                        # Not new
                        current=current.children[tname]
                        if current.depth==2:
                                opmode.check_group(current.group, w)
                else:
                        # New
                        # Ignore depth check for now
                        nl=levels[keyword.raw()](w)
                        if nl.depth<allow_defs:
                                complain("New definitions not allowed at "
                                        "level %d"%nl.depth)
                                # we risk crashing if we continue
                                sys.exit(1)
                        current.children[tname]=nl
                        current=nl
                filterstate.reset()
                obstack.append(current)
                return copyout()
        # Anything else must be a property of the current node.
        if keyword.raw() not in current.allow_properties:
                complain("Property %s not allowed at %s level"%
                        (keyword.raw(),current.type))
                return []
        if current.depth == vpnlevel.depth < allow_defs:
                complain("Not allowed to set VPN properties here")
                return []
        prop=set_property(current,w)
        out=[copyout_core()]
        out=prop.forsites(output_version,out,filterstate)
        # A property with nothing to output is kept as a comment.
        if len(out)==0: return [indent + '#', copyout_core(), '\n']
        return [indent + ' '.join(out) + '\n']
940
def pfilepath(pathname,allow_include=False):
        """Read the file at pathname and process it with pfile.

        Returns the list of output lines produced by pfile.  The file is
        closed as soon as it has been read, even if reading fails, so no
        handle is left open while its contents are processed.
        """
        with open(pathname) as f:
                lines=f.readlines()
        return pfile(pathname,lines,allow_include=allow_include)
946
def pfile(name,lines,allow_include=False):
        """Process the lines of a file and return the collected output.

        name is stored in the global file and the global line counts the
        lines as they are read, which is what complain() reports.  Lines
        whose first character is '#' are skipped; every other line goes
        through pline with a FilterState created for this call.
        """
        global file,line
        file=name
        line=0
        filterstate=FilterState()
        outlines=[]
        for text in lines:
                line+=1
                if text[0]!='#':
                        outlines.extend(
                                pline(text,filterstate,
                                      allow_include=allow_include))
        return outlines
959
def outputsites(w):
        """Output include file for secnet configuration to w.

        Writes a comment header, the vpn-data section, the vpn section
        with the flattened per-VPN lists, and a final all-sites line
        referring to the all-sites list of every VPN.
        """
        header=[
                "# secnet sites file autogenerated by make-secnet-sites "
                +"version %s\n"%VERSION,
                "# %s\n"%time.asctime(time.localtime(time.time())),
                "# Command line: %s\n\n"%' '.join(sys.argv),
        ]
        for text in header:
                w.write(text)

        # Raw VPN data section of file
        w.write(prefix+"vpn-data {\n")
        for vpn in root.children.values():
                vpn.output_data(w,(vpn,))
        w.write("};\n")

        # Per-VPN flattened lists
        w.write(prefix+"vpn {\n")
        for vpn in root.children.values():
                vpn.output_vpnflat(w,())
        w.write("};\n")

        # Flattened list of sites
        vpn_lists=["%svpn/%s/all-sites"%(prefix,vpn.kname())
                   for vpn in root.children.values()]
        w.write(prefix+("all-sites %s;\n"%",".join(vpn_lists)))
983
# Input position reported by complain(); pfile sets both as it reads.
line = 0
file = None
# Number of problems reported through moan() so far.
complaints = 0
987
988 # Sanity check section
989 # Delete nodes where leaf=0 that have no children
990
def live(n):
        "Return 1 if n is a leafnode or has one anywhere below it, else 0"
        if n.leaf:
                return 1
        return int(any(live(child) for child in n.children.values()))
def delempty(n):
        "Remove, at every depth below n, the children with no leafnode in them"
        # Iterate over a snapshot because entries are deleted on the way.
        for key,child in list(n.children.items()):
                delempty(child)
                if not live(child):
                        del n.children[key]
1003
1004 # Check that all constraints are met (as far as I can tell
1005 # restrict-nets/networks/peer are the only special cases)
1006
def checkconstraints(n,p,ra):
        """Check node n and everything below it, reporting problems via moan.

        p holds the properties inherited from the nodes above n and ra the
        address set those nodes allow.  The node's own properties are laid
        over p; the result must contain every required property and only
        allowed ones.  A "restrict-nets" property narrows ra, "networks"
        must lie within the narrowed set, and "peer" must be contained in
        "networks".
        """
        inherited=p.copy()
        inherited.update(n.properties)
        for needed in n.require_properties:
                if needed not in inherited:
                        moan("%s %s is missing property %s"%
                                (n.type,n.name,needed))
        for present in inherited:
                if present not in n.allow_properties:
                        moan("%s %s has forbidden property %s"%
                                (n.type,n.name,present))
        # Check address range restrictions
        own=n.properties
        allowed=(ra.intersection(own["restrict-nets"].set)
                 if "restrict-nets" in own else ra)
        if "networks" in own:
                claimed=own["networks"].set
                if not claimed <= allowed:
                        moan("%s %s networks out of bounds"%(n.type,n.name))
                if "peer" in own and not claimed.contains(own["peer"].addr):
                        moan("%s %s peer not in networks"%(n.type,n.name))
        for child in n.children.values():
                checkconstraints(child,inherited,allowed)
1032
opmode.read_in()

# Prune nodes without sites, then validate what is left.
delempty(root)
checkconstraints(root,{},ipaddrset.complete_set())

if complaints>0:
        if complaints==1:
                summary="There was 1 problem."
        else:
                summary="There were %d problems."%(complaints)
        print(summary)
        sys.exit(1)
# From here on moan() exits at once instead of counting.
complaints=None

opmode.write_out()