chiark / gitweb /
make-secnet-sites: OpConf: Move positional_args to OpBase
[secnet.git] / make-secnet-sites
1 #! /usr/bin/env python3
2 #
3 # This file is part of secnet.
4 # See README for full list of copyright holders.
5 #
6 # secnet is free software; you can redistribute it and/or modify it
7 # under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
10
11 # secnet is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15
16 # You should have received a copy of the GNU General Public License
17 # version 3 along with secnet; if not, see
18 # https://www.gnu.org/licenses/gpl.html.
19
20 """VPN sites file manipulation.
21
22 This program enables VPN site descriptions to be submitted for
23 inclusion in a central database, and allows the resulting database to
24 be turned into a secnet configuration file.
25
26 A database file can be turned into a secnet configuration file simply:
27 make-secnet-sites.py [infile [outfile]]
28
29 It would be wise to run secnet with the "--just-check-config" option
30 before installing the output on a live system.
31
32 The program expects to be invoked via userv to manage the database; it
33 relies on the USERV_USER and USERV_GROUP environment variables. The
34 command line arguments for this invocation are:
35
36 make-secnet-sites.py -u header-filename groupfiles-directory output-file \
37   group
38
39 All but the last argument are expected to be set by userv; the 'group'
40 argument is provided by the user. A suitable userv configuration file
41 fragment is:
42
43 reset
44 no-disconnect-hup
45 no-suppress-args
46 cd ~/secnet/sites-test/
47 execute ~/secnet/make-secnet-sites.py -u vpnheader groupfiles sites
48
49 This program is part of secnet.
50
51 """
52
53 from __future__ import print_function
54 from __future__ import unicode_literals
55 from builtins import int
56
57 import string
58 import time
59 import sys
60 import os
61 import getopt
62 import re
63 import argparse
64 import math
65
66 import ipaddress
67
68 # entry 0 is "near the executable", or maybe from PYTHONPATH=.,
69 # which we don't want to preempt
70 sys.path.insert(1,"/usr/local/share/secnet")
71 sys.path.insert(1,"/usr/share/secnet")
72 import ipaddrset
73 import base91
74
75 from argparseactionnoyes import ActionNoYes
76
# Version of this program; recorded in the group files written in userv mode.
VERSION="0.1.18"

# Highest sites-file output version supported; also the default for
# --output-version.
max_version = 2

from sys import version_info
if version_info.major == 2:  # for python2
    # Under Python 2, wrap stdin/stdout in UTF-8 codecs and make open()
    # return UTF-8 text files.
    import codecs
    sys.stdin = codecs.getreader('utf-8')(sys.stdin)
    sys.stdout = codecs.getwriter('utf-8')(sys.stdout)
    import io
    open=lambda f,m='r': io.open(f,m,encoding='utf-8')

# Size limits for untrusted input, used by Tainted's checks: the *_bits
# entries are maximum bit lengths of RSA/DH numbers, the others maximum
# lengths in characters.  NB: this shadows the builtin max().
max={'rsa_bits':8200,'name':33,'dh_bits':8200,'algname':127}
90
def debugrepr(*args):
        """Print repr() of the argument tuple to stderr if debugging is on.

        Debugging is on when the module-level debug_level (set from
        --debug/-D) is greater than zero.
        """
        if debug_level <= 0:
                return
        print(repr(args), file=sys.stderr)
94
def base91s_encode(bindata):
        """base91-encode bindata, writing '-' wherever base91 would emit '"'.

        This keeps the result free of double quotes (presumably so it can
        be embedded in quoted config strings); base91s_decode() reverses
        the substitution.
        """
        encoded = base91.encode(bindata)
        return encoded.replace('"', "-")
97
def base91s_decode(string):
        """Inverse of base91s_encode(): turn '-' back into '"', then base91-decode."""
        restored = string.replace("-", '"')
        return base91.decode(restored)
100
class Tainted:
        """A string from untrusted input, which must be checked before use.

        A Tainted wraps the raw text (self._s) together with the file and
        line it came from.  The checking methods (name(), number(),
        host(), ...) validate the text, report problems through
        complain(), and return either the accepted value or a harmless
        placeholder.  self._ok records the result: None (unchecked), True
        (accepted) or False (rejected).  str() deliberately raises, so
        unchecked text cannot reach output by accident; use raw() or
        output() instead.
        """
        def __init__(self,s,tline=None,tfile=None):
                self._s=s
                self._ok=None
                # The location defaults to the module-level `line` and
                # `file` globals (presumably the parser's current position).
                self._line=line if tline is None else tline
                self._file=file if tfile is None else tfile
        def __eq__(self,e):
                return self._s==e
        def __ne__(self,e):
                # for Python2
                return not self.__eq__(e)
        def __str__(self):
                raise RuntimeError('direct use of Tainted value')
        def __repr__(self):
                return 'Tainted(%s)' % repr(self._s)

        def _bad(self,what,why):
                """Mark this value rejected, complain about it, return False."""
                assert(self._ok is not True)
                self._ok=False
                complain('bad parameter: %s: %s' % (what, why))
                return False

        def _max_ok(self,what,maxlen):
                """Reject the value if it is longer than maxlen characters."""
                if len(self._s) > maxlen:
                        return self._bad(what,'too long (max %d)' % maxlen)
                return True

        def _re_ok(self,bad,what,maxlen=None):
                """Check the length, and that regexp `bad` finds no match.

                maxlen defaults to the module-level max[what].  Returns
                False if this value is (or already was) rejected.
                """
                if maxlen is None: maxlen=max[what]
                self._max_ok(what,maxlen)
                if self._ok is False: return False
                if bad.search(self._s):
                        return self._bad(what,'bad syntax')
                return True

        def _rtnval(self, is_ok, ifgood, ifbad=''):
                """Record the check result; return ifgood or ifbad accordingly."""
                if is_ok:
                        assert(self._ok is not False)
                        self._ok=True
                        return ifgood
                else:
                        assert(self._ok is not True)
                        self._ok=False
                        return ifbad

        def _rtn(self, is_ok, ifbad=''):
                """Like _rtnval(), returning the raw text when accepted."""
                return self._rtnval(is_ok, self._s, ifbad)

        def raw(self):
                """Return the text without checking it or marking it checked."""
                return self._s
        def raw_mark_ok(self):
                # caller promises to throw if syntax was dangerous
                return self._rtn(True)

        def output(self):
                """Return the text for output.

                Rejected values yield ''.  A value that was never checked
                is reported with its location and the program exits(1).
                """
                if self._ok is False: return ''
                if self._ok is True: return self._s
                print('%s:%d: unchecked/unknown additional data "%s"' %
                      (self._file,self._line,self._s),
                      file=sys.stderr)
                sys.exit(1)

        bad_name=re.compile(r'^[^a-zA-Z]|[^-_0-9a-zA-Z]')
        # secnet accepts _ at start of names, but we reserve that
        bad_name_counter=0
        def name(self,what='name'):
                """Check a name; a rejected name is replaced by a unique
                reserved placeholder beginning with '_line'."""
                ok=self._re_ok(Tainted.bad_name,what)
                return self._rtn(ok,
                                 '_line%d_%s' % (self._line, id(self)))

        def keyword(self):
                """Check that the text is a known property or level keyword."""
                ok=self._s in keywords or self._s in levels
                if not ok:
                        complain('unknown keyword %s' % self._s)
                return self._rtn(ok)

        bad_hex=re.compile(r'[^0-9a-fA-F]')
        def bignum_16(self,kind,what):
                """Check a hex number of at most max[kind+'_bits'] bits."""
                # Four bits per hex digit, rounded up.
                maxlen=(max[kind+'_bits']+3)//4
                ok=self._re_ok(Tainted.bad_hex,what,maxlen)
                return self._rtn(ok)

        bad_num=re.compile(r'[^0-9]')
        def bignum_10(self,kind,what):
                """Check a decimal number of at most max[kind+'_bits'] bits."""
                # An n-bit number has at most ceil(n*log10(2)) decimal
                # digits.  (Dividing by log10(2) instead would admit about
                # 11 times as many digits as can ever be valid.)
                maxlen=int(math.ceil(max[kind+'_bits'] * math.log10(2)))
                ok=self._re_ok(Tainted.bad_num,what,maxlen)
                return self._rtn(ok)

        def number(self,minn,maxx,what='number'):
                """Check a decimal number in minn..maxx and return it as an int.

                On any failure, complain and return minn.
                """
                # not for bignums
                v=minn  # placeholder, returned if any check fails
                ok=self._re_ok(Tainted.bad_num,what,10)
                if ok and not self._s:
                        # bad_num cannot match an empty string, but int('')
                        # would raise.
                        ok=self._bad(what,'empty')
                if ok:
                        v=int(self._s)
                        if v<minn or v>maxx:
                                ok=self._bad(what,'out of range %d..%d'
                                             % (minn,maxx))
                return self._rtnval(ok,v,minn)

        def hexid(self,byteslen,what):
                """Check an id of exactly byteslen bytes in hex; zeros if bad."""
                ok=self._re_ok(Tainted.bad_hex,what,byteslen*2)
                if ok:
                        if len(self._s) < byteslen*2:
                                ok=self._bad(what,'too short')
                return self._rtn(ok,ifbad='00'*byteslen)

        bad_host=re.compile(r'[^-\][_.:0-9a-zA-Z]')
        # We permit _ so we can refer to special non-host domains
        # which have A and AAAA RRs.  This is a crude check and we may
        # still produce config files with syntactically invalid
        # domains or addresses, but that is OK.
        def host(self):
                ok=self._re_ok(Tainted.bad_host,'host/address',255)
                return self._rtn(ok)

        bad_email=re.compile(r'[^-._0-9a-z@!$%^&*=+~/]')
        # ^ This does not accept all valid email addresses.  That's
        # not really possible with this input syntax.  It accepts
        # all ones that don't require quoting anywhere in email
        # protocols (and also accepts some invalid ones).
        def email(self):
                ok=self._re_ok(Tainted.bad_email,'email address',1023)
                return self._rtn(ok)

        bad_groupname=re.compile(r'^[^_A-Za-z]|[^-+_0-9A-Za-z]')
        def groupname(self):
                ok=self._re_ok(Tainted.bad_groupname,'group name',64)
                return self._rtn(ok)

        bad_base91=re.compile(r'[^!-~]|[\'\"\\]')
        def base91(self,what='base91'):
                ok=self._re_ok(Tainted.bad_base91,what,4096)
                return self._rtn(ok)
234
class ArgActionLambda(argparse.Action):
        """argparse action that delegates to a callback.

        When the option is seen, fn(values, namespace, parser,
        option_string) is called.
        """
        def __init__(self, fn, **kwargs):
                super(ArgActionLambda, self).__init__(**kwargs)
                self.fn = fn
        def __call__(self,ap,ns,values,option_string):
                callback = self.fn
                callback(values, ns, ap, option_string)
241
class PkmBase():
        """Base public-key mode: per-site hooks that by default do nothing.

        site_start() remembers the site's public key file path and gives
        the site a fresh FilterState; the other hooks are no-ops for
        subclasses to override.
        """
        def site_start(self,pubkeys_path):
                self._pa = pubkeys_path
                self._fs = FilterState()
        def site_serial(self,serial):
                return None
        def write_key(self,k):
                return None
        def site_finish(self,confw):
                return None
249
class PkmSingle(PkmBase):
        """Public-key mode 'single': exactly one key per site in sites.conf.

        write_key() collects the site's keys that are usable on their own
        at the selected output version; site_finish() writes the single
        such key as 'key ...;', or complains if there are none or several.
        """
        opt = 'single'
        help = 'write one public key per site to sites.conf'
        def site_start(self,pubkeys_path):
                PkmBase.site_start(self,pubkeys_path)
                self._outk = []
        def write_key(self,k):
                if k.okforonlykey(output_version,self._fs):
                        self._outk.append(k)
        def site_finish(self,confw):
                if len(self._outk) == 0:
                        complain("site with no public key")
                elif len(self._outk) != 1:
                        debugrepr('outk ', self._outk)
                        complain("site with multiple public keys,"
                                 " without --pubkeys-install"
                                 " (maybe --output-version=1 would help)")
                else:
                        confw.write("key %s;\n"%str(self._outk[0]))
266                         )
267                 else:
268                         confw.write("key %s;\n"%str(self._outk[0]))
269
270 class PkmInstall(PkmBase):
271         opt = 'install'
272         help = 'install public keys in public key directory'
273         def site_start(self,pubkeys_path):
274                 PkmBase.site_start(self,pubkeys_path)
275                 self._pw=open(self._pa+'~tmp','w')
276         def site_serial(self,serial):
277                 self._pw.write('serial %s\n' % serial)
278         def write_key(self,k):
279                 wout=k.forpub(output_version,self._fs)
280                 self._pw.write(' '.join(wout))
281                 self._pw.write('\n')
282         def site_finish(self,confw):
283                 self._pw.close()
284                 os.rename(self._pa+'~tmp',self._pa+'~update')
285                 PkmElide.site_finish(self,confw)
286
287 class PkmElide(PkmBase):
288         opt = 'elide'
289         help = 'no public keys in sites.conf output nor in directory'
290         def site_finish(self,confw):
291                 confw.write("peer-keys \"%s\";\n"%self._pa);
292
293 class OpBase():
294         # Base case is reading a sites file from self.inputfilee.
295         # And writing a sites file to self.sitesfile.
296         def positional_args(self, av):
297                 if len(av.arg)>3:
298                         print("Too many arguments")
299                         sys.exit(1)
300                 (self.inputfile, self.outputfile) = (av.arg + [None]*2)[0:2]
301         def read_in(self):
302                 if self.inputfile is None:
303                         self.inputlines = pfile("stdin",sys.stdin.readlines())
304                 else:
305                         self.inputlines = pfilepath(self.inputfile)
306         def write_out(self):
307                 if self.outputfile is None:
308                         f=sys.stdout
309                 else:
310                         f=open(self.outputfile+"-tmp",'w')
311                 f.write("# sites file autogenerated by make-secnet-sites\n")
312                 self.write_out_heading(f)
313                 f.write("# use make-secnet-sites to turn this file into a\n")
314                 f.write("# valid /etc/secnet/sites.conf file\n\n")
315                 self.write_out_contents(f)
316                 f.write("# end of sites file\n")
317                 if self.outputfile is not None:
318                         f.close()
319                         os.rename(self.outputfile+"-tmp",self.outputfile)
320
321 class OpConf(OpBase):
322         opts = ['--conf']
323         help = 'sites.conf generation mode (default)'
324         def check_group(self,group,w): pass
325         def write_out(self):
326                 if self.outputfile is None:
327                         of=sys.stdout
328                 else:
329                         tmp_outputfile=self.outputfile+'~tmp~'
330                         of=open(tmp_outputfile,'w')
331                 outputsites(of)
332                 if self.outputfile is not None:
333                         os.rename(tmp_outputfile,self.outputfile)
334
335 class OpUserv(OpBase):
336         opts = ['--userv','-u']
337         help = 'userv service fragment update mode'
338         def positional_args(self, av):
339                 if len(av.arg)!=4:
340                         print("Wrong number of arguments")
341                         sys.exit(1)
342                 (self.header, self.groupfiledir,
343                  self.outputfile, self.group) = av.arg
344                 self.group = Tainted(self.group,0,'command line')
345                 # untrusted argument from caller
346                 if "USERV_USER" not in os.environ:
347                         print("Environment variable USERV_USER not found")
348                         sys.exit(1)
349                 self.user=os.environ["USERV_USER"]
350                 # Check that group is in USERV_GROUP
351                 if "USERV_GROUP" not in os.environ:
352                         print("Environment variable USERV_GROUP not found")
353                         sys.exit(1)
354                 ugs=os.environ["USERV_GROUP"]
355                 ok=0
356                 for i in ugs.split():
357                         if self.group==i: ok=1
358                 if not ok:
359                         print("caller not in group %s"%group)
360                         sys.exit(1)
361         def check_group(self,group,w):
362                 if group!=self.group: complain("Incorrect group!")
363                 w[2].groupname()
364         def read_in(self):
365                 self.headerinput=pfilepath(self.header,allow_include=True)
366                 self.userinput=sys.stdin.readlines()
367                 pfile("user input",self.userinput)
368         def write_out(self):
369                 # Put the user's input into their group file, and
370                 # rebuild the main sites file
371                 f=open(self.groupfiledir+"/T"+self.group.groupname(),'w')
372                 f.write("# Section submitted by user %s, %s\n"%
373                         (self.user,time.asctime(time.localtime(time.time()))))
374                 f.write("# Checked by make-secnet-sites version %s\n\n"
375                         %VERSION)
376                 for i in self.userinput: f.write(i)
377                 f.write("\n")
378                 f.close()
379                 os.rename(self.groupfiledir+"/T"+self.group.groupname(),
380                           self.groupfiledir+"/R"+self.group.groupname())
381                 OpBase.write_out(self)
382         def write_out_heading(self,f):
383                 f.write("# generated %s, invoked by %s\n"%
384                         (time.asctime(time.localtime(time.time())),
385                          self.user))
386         def write_out_contents(self,f):
387                 for i in self.headerinput: f.write(i)
388                 files=os.listdir(self.groupfiledir)
389                 for i in files:
390                         if i[0]=='R':
391                                 j=open(self.groupfiledir+"/"+i)
392                                 f.write(j.read())
393                                 j.close()
394
395 def parse_args():
396         global opmode
397         global prefix
398         global key_prefix
399         global debug_level
400         global output_version
401         global pubkeys_dir
402         global pubkeys_mode
403
404         ap = argparse.ArgumentParser(description='process secnet sites files')
405         def add_opmode(how):
406                 ap.add_argument(*how().opts, action=ArgActionLambda,
407                         nargs=0,
408                         fn=(lambda v,ns,*x: setattr(ns,'opmode',how)),
409                         help=how().help)
410         add_opmode(OpConf)
411         add_opmode(OpUserv)
412         ap.add_argument('--conf-key-prefix', action=ActionNoYes,
413                         default=True,
414                  help='prefix conf file key names derived from sites data')
415         def add_pkm(how):
416                 ap.add_argument('--pubkeys-'+how().opt, action=ArgActionLambda,
417                         nargs=0,
418                         fn=(lambda v,ns,*x: setattr(ns,'pkm',how)),
419                         help=how().help)
420         add_pkm(PkmInstall)
421         add_pkm(PkmSingle)
422         add_pkm(PkmElide)
423         ap.add_argument('--pubkeys-dir',  nargs=1,
424                         help='public key directory',
425                         default=['/var/lib/secnet/pubkeys'])
426         ap.add_argument('--output-version', nargs=1, type=int,
427                         help='sites file output version',
428                         default=[max_version])
429         ap.add_argument('--prefix', '-P', nargs=1,
430                         help='set prefix')
431         ap.add_argument('--debug', '-D', action='count', default=0)
432         ap.add_argument('arg',nargs=argparse.REMAINDER)
433         av = ap.parse_args()
434         debug_level = av.debug
435         debugrepr('av',av)
436         opmode = getattr(av,'opmode',OpConf)()
437         prefix = '' if av.prefix is None else av.prefix[0]
438         key_prefix = av.conf_key_prefix
439         output_version = av.output_version[0]
440         pubkeys_dir = av.pubkeys_dir[0]
441         pubkeys_mode = getattr(av,'pkm',PkmSingle)
442         opmode.positional_args(av)
443
444 parse_args()
445
446 # Classes describing possible datatypes in the configuration file
447
448 class basetype:
449         "Common protocol for configuration types."
450         def add(self,obj,w):
451                 complain("%s %s already has property %s defined"%
452                         (obj.type,obj.name,w[0].raw()))
453         def forsites(self,version,copy,fs):
454                 return copy
455
456 class conflist:
457         "A list of some kind of configuration type."
458         def __init__(self,subtype,w):
459                 self.subtype=subtype
460                 self.list=[subtype(w)]
461         def add(self,obj,w):
462                 self.list.append(self.subtype(w))
463         def __str__(self):
464                 return ', '.join(map(str, self.list))
465         def forsites(self,version,copy,fs):
466                 most_recent=self.list[len(self.list)-1]
467                 return most_recent.forsites(version,copy,fs)
468 def listof(subtype):
469         return lambda w: conflist(subtype, w)
470
471 class single_ipaddr (basetype):
472         "An IP address"
473         def __init__(self,w):
474                 self.addr=ipaddress.ip_address(w[1].raw_mark_ok())
475         def __str__(self):
476                 return '"%s"'%self.addr
477
478 class networks (basetype):
479         "A set of IP addresses specified as a list of networks"
480         def __init__(self,w):
481                 self.set=ipaddrset.IPAddressSet()
482                 for i in w[1:]:
483                         x=ipaddress.ip_network(i.raw_mark_ok(),strict=True)
484                         self.set.append([x])
485         def __str__(self):
486                 return ",".join(map((lambda n: '"%s"'%n), self.set.networks()))
487
488 class dhgroup (basetype):
489         "A Diffie-Hellman group"
490         def __init__(self,w):
491                 self.mod=w[1].bignum_16('dh','dh mod')
492                 self.gen=w[2].bignum_16('dh','dh gen')
493         def __str__(self):
494                 return 'diffie-hellman("%s","%s")'%(self.mod,self.gen)
495
496 class hash (basetype):
497         "A choice of hash function"
498         def __init__(self,w):
499                 hname=w[1]
500                 self.ht=hname.raw()
501                 if (self.ht!='md5' and self.ht!='sha1'):
502                         complain("unknown hash type %s"%(self.ht))
503                         self.ht=None
504                 else:
505                         hname.raw_mark_ok()
506         def __str__(self):
507                 return '%s'%(self.ht)
508
509 class email (basetype):
510         "An email address"
511         def __init__(self,w):
512                 self.addr=w[1].email()
513         def __str__(self):
514                 return '<%s>'%(self.addr)
515
516 class boolean (basetype):
517         "A boolean"
518         def __init__(self,w):
519                 v=w[1]
520                 if re.match('[TtYy1]',v.raw()):
521                         self.b=True
522                         v.raw_mark_ok()
523                 elif re.match('[FfNn0]',v.raw()):
524                         self.b=False
525                         v.raw_mark_ok()
526                 else:
527                         complain("invalid boolean value");
528         def __str__(self):
529                 return ['False','True'][self.b]
530
531 class num (basetype):
532         "A decimal number"
533         def __init__(self,w):
534                 self.n=w[1].number(0,0x7fffffff)
535         def __str__(self):
536                 return '%d'%(self.n)
537
538 class serial (basetype):
539         def __init__(self,w):
540                 self.i=w[1].hexid(4,'serial')
541         def __str__(self):
542                 return self.i
543         def forsites(self,version,copy,fs):
544                 if version < 2: return []
545                 return copy
546
547 class address (basetype):
548         "A DNS name and UDP port number"
549         def __init__(self,w):
550                 self.adr=w[1].host()
551                 self.port=w[2].number(1,65536,'port')
552         def __str__(self):
553                 return '"%s"; port %d'%(self.adr,self.port)
554
555 class inpub (basetype):
556         def forsites(self,version,xcopy,fs):
557                 return self.forpub(version,fs)
558
559 class pubkey (inpub):
560         "Some kind of publie key"
561         def __init__(self,w):
562                 self.a=w[1].name('algname')
563                 self.d=w[2].base91();
564         def __str__(self):
565                 return 'make-public("%s","%s")'%(self.a,self.d)
566         def forpub(self,version,fs):
567                 if version < 2: return []
568                 return ['pub', self.a, self.d]
569         def okforonlykey(self,version,fs):
570                 return len(self.forpub(version,fs)) != 0
571
572 class rsakey (pubkey):
573         "An old-style RSA public key"
574         def __init__(self,w):
575                 self.l=w[1].number(0,max['rsa_bits'],'rsa len')
576                 self.e=w[2].bignum_10('rsa','rsa e')
577                 self.n=w[3].bignum_10('rsa','rsa n')
578                 if len(w) >= 5: w[4].email()
579                 self.a='rsa1'
580                 self.d=base91s_encode(b'%d %s %s' %
581                                       (self.l,
582                                        self.e.encode('ascii'),
583                                        self.n.encode('ascii')))
584                 # ^ this allows us to use the pubkey.forsites()
585                 # method for output in versions>=2
586         def __str__(self):
587                 return 'rsa-public("%s","%s")'%(self.e,self.n)
588                 # this specialisation means we can generate files
589                 # compatible with old secnet executables
590         def forpub(self,version,fs):
591                 if version < 2:
592                         if fs.pkg != '00000000': return []
593                         return ['pubkey', str(self.l), self.e, self.n]
594                 return pubkey.forpub(self,version,fs)
595
596 class rsakey_newfmt(rsakey):
597         "An old-style RSA public key in new-style sites format"
598         # This is its own class simply to have its own constructor.
599         def __init__(self,w):
600                 self.a=w[1].name()
601                 assert(self.a == 'rsa1')
602                 self.d=w[2].base91()
603                 try:
604                         w_inner=list(map(Tainted,
605                                         ['X-PUB-RSA1'] +
606                                         base91s_decode(self.d)
607                                         .decode('ascii')
608                                         .split(' ')))
609                 except UnicodeDecodeError:
610                         complain('rsa1 key in new format has bad base91')
611                 #print(repr(w_inner), file=sys.stderr)
612                 rsakey.__init__(self,w_inner)
613
614 class pubkey_group(inpub):
615         "Public key group introducer"
616         # appears in the site's list of keys mixed in with the keys
617         def __init__(self,w,fallback):
618                 self.i=w[1].hexid(4,'pkg-id')
619                 self.fallback=fallback
620         def forpub(self,version,fs):
621                 fs.pkg=self.i
622                 if version < 2: return []
623                 return ['pkgf' if self.fallback else 'pkg', self.i]
624         def okforonlykey(self,version,fs):
625                 self.forpub(version,fs)
626                 return False
627         
628 def somepubkey(w):
629         #print(repr(w), file=sys.stderr)
630         if w[0]=='pubkey':
631                 return rsakey(w)
632         elif w[0]=='pub' and w[1]=='rsa1':
633                 return rsakey_newfmt(w)
634         elif w[0]=='pub':
635                 return pubkey(w)
636         elif w[0]=='pkg':
637                 return pubkey_group(w,False)
638         elif w[0]=='pkgf':
639                 return pubkey_group(w,True)
640         else:
641                 assert(False)
642
643 # Possible properties of configuration nodes
644 keywords={
645  'contact':(email,"Contact address"),
646  'dh':(dhgroup,"Diffie-Hellman group"),
647  'hash':(hash,"Hash function"),
648  'key-lifetime':(num,"Maximum key lifetime (ms)"),
649  'setup-timeout':(num,"Key setup timeout (ms)"),
650  'setup-retries':(num,"Maximum key setup packet retries"),
651  'wait-time':(num,"Time to wait after unsuccessful key setup (ms)"),
652  'renegotiate-time':(num,"Time after key setup to begin renegotiation (ms)"),
653  'restrict-nets':(networks,"Allowable networks"),
654  'networks':(networks,"Claimed networks"),
655  'serial':(serial,"public key set serial"),
656  'pkg':(listof(somepubkey),"start of public key group",'pub'),
657  'pkgf':(listof(somepubkey),"start of fallback public key group",'pub'),
658  'pub':(listof(somepubkey),"new style public site key"),
659  'pubkey':(listof(somepubkey),"Old-style RSA public site key",'pub'),
660  'peer':(single_ipaddr,"Tunnel peer IP address"),
661  'address':(address,"External contact address and port"),
662  'mobile':(boolean,"Site is mobile"),
663 }
664
665 def sp(name,value):
666         "Simply output a property - the default case"
667         return "%s %s;\n"%(name,value)
668
669 # All levels support these properties
670 global_properties={
671         'contact':(lambda name,value:"# Contact email address: %s\n"%(value)),
672         'dh':sp,
673         'hash':sp,
674         'key-lifetime':sp,
675         'setup-timeout':sp,
676         'setup-retries':sp,
677         'wait-time':sp,
678         'renegotiate-time':sp,
679         'restrict-nets':(lambda name,value:"# restrict-nets %s\n"%value),
680 }
681
682 class level:
683         "A level in the configuration hierarchy"
684         depth=0
685         leaf=0
686         allow_properties={}
687         require_properties={}
688         def __init__(self,w):
689                 self.type=w[0].keyword()
690                 self.name=w[1].name()
691                 self.properties={}
692                 self.children={}
693         def indent(self,w,t):
694                 w.write("                 "[:t])
695         def prop_out(self,n):
696                 return self.allow_properties[n](n,str(self.properties[n]))
697         def output_props(self,w,ind):
698                 for i in sorted(self.properties.keys()):
699                         if self.allow_properties[i]:
700                                 self.indent(w,ind)
701                                 w.write("%s"%self.prop_out(i))
702         def kname(self):
703                 return ((self.type[0].upper() if key_prefix else '')
704                         + self.name)
705         def output_data(self,w,path):
706                 ind = 2*len(path)
707                 self.indent(w,ind)
708                 w.write("%s {\n"%(self.kname()))
709                 self.output_props(w,ind+2)
710                 if self.depth==1: w.write("\n");
711                 for k in sorted(self.children.keys()):
712                         c=self.children[k]
713                         c.output_data(w,path+(c,))
714                 self.indent(w,ind)
715                 w.write("};\n")
716
717 class vpnlevel(level):
718         "VPN level in the configuration hierarchy"
719         depth=1
720         leaf=0
721         type="vpn"
722         allow_properties=global_properties.copy()
723         require_properties={
724          'contact':"VPN admin contact address"
725         }
726         def __init__(self,w):
727                 level.__init__(self,w)
728         def output_vpnflat(self,w,path):
729                 "Output flattened list of site names for this VPN"
730                 ind=2*(len(path)+1)
731                 self.indent(w,ind)
732                 w.write("%s {\n"%(self.kname()))
733                 for i in self.children.keys():
734                         self.children[i].output_vpnflat(w,path+(self,))
735                 w.write("\n")
736                 self.indent(w,ind+2)
737                 w.write("all-sites %s;\n"%
738                         ','.join(map(lambda i: i.kname(),
739                                      self.children.values())))
740                 self.indent(w,ind)
741                 w.write("};\n")
742
743 class locationlevel(level):
744         "Location level in the configuration hierarchy"
745         depth=2
746         leaf=0
747         type="location"
748         allow_properties=global_properties.copy()
749         require_properties={
750          'contact':"Location admin contact address",
751         }
752         def __init__(self,w):
753                 level.__init__(self,w)
754                 self.group=w[2].groupname()
755         def output_vpnflat(self,w,path):
756                 ind=2*(len(path)+1)
757                 self.indent(w,ind)
758                 # The "path=path,self=self" abomination below exists because
759                 # Python didn't support nested_scopes until version 2.1
760                 #
761                 #"/"+self.name+"/"+i
762                 w.write("%s %s;\n"%(self.kname(),','.join(
763                         map(lambda x,path=path,self=self:
764                             '/'.join([prefix+"vpn-data"] + list(map(
765                                     lambda i: i.kname(),
766                                     path+(self,x)))),
767                             self.children.values()))))
768
class sitelevel(level):
        """Site level (i.e. a leafnode) in the configuration hierarchy.

        Sites live at depth 3, below a location, and are the only leaf
        nodes.  output_data() writes the site's secnet configuration block.
        """
        depth=3
        leaf=1
        type="site"
        # Sites accept every global property plus the site-specific ones below
        allow_properties=global_properties.copy()
        allow_properties.update({
         'address':sp,
         'networks':None,
         'peer':None,
         'serial':None,
         'pkg':None,
         'pkgf':None,
         'pub':None,
         'pubkey':None,
         'mobile':sp,
        })
        # Properties that must be present on the site, either set directly
        # or inherited from above (enforced by checkconstraints)
        require_properties={
         'dh':"Diffie-Hellman group",
         'contact':"Site admin contact address",
         'networks':"Networks claimed by the site",
         'hash':"hash function",
         'peer':"Gateway address of the site",
        }
        def mangle_name(self):
                # Replace '/' so the name can be used as a single file name
                # component in pubkeys_path()
                return self.name.replace('/',',')
        def pubkeys_path(self):
                "Path of this site's peer key file under pubkeys_dir"
                return pubkeys_dir + '/peer.' + self.mangle_name()
        def __init__(self,w):
                level.__init__(self,w)
        def output_data(self,w,path):
                """Write this site's configuration block to w.

                path is the tuple of nodes from the VPN down to this site.
                Its length sets the indentation, and the nodes' names joined
                with '/' become the site's "name".  The public keys are
                handed to a pubkeys_mode() object, then the properties are
                written, followed by a "link netlink" block giving routes
                (the networks property) and ptp-address (the peer property).
                """
                ind=2*len(path)
                np='/'.join(map(lambda i: i.name, path))
                self.indent(w,ind)
                w.write("%s {\n"%(self.kname()))
                self.indent(w,ind+2)
                w.write("name \"%s\";\n"%(np,))
                # NOTE(review): this indent is written with no text after it,
                # presumably so that whatever pkm.site_finish(w) writes below
                # starts at this column -- confirm against pubkeys_mode
                self.indent(w,ind+2)

                pkm = pubkeys_mode()
                debugrepr('pkm ',pkm)
                pkm.site_start(self.pubkeys_path())
                if 'serial' in self.properties:
                        pkm.site_serial(self.properties['serial'])

                # NOTE(review): 'pub' is not in require_properties, so a site
                # with no public key would raise KeyError here -- confirm
                # that earlier processing guarantees it is always present
                for k in self.properties["pub"].list:
                        debugrepr('pubkeys ', k)
                        pkm.write_key(k)

                pkm.site_finish(w)

                self.output_props(w,ind+2)
                self.indent(w,ind+2)
                w.write("link netlink {\n");
                self.indent(w,ind+4)
                w.write("routes %s;\n"%str(self.properties["networks"]))
                self.indent(w,ind+4)
                w.write("ptp-address %s;\n"%str(self.properties["peer"]))
                self.indent(w,ind+2)
                w.write("};\n")
                self.indent(w,ind)
                w.write("};\n")
831
# Levels in the configuration file: maps the keyword that opens each
# level to the class implementing it (each class carries the level's
# depth and its allowed/required property tables)
levels={'vpn':vpnlevel, 'location':locationlevel, 'site':sitelevel}
835
def complain(msg):
        """Report a problem with the input line currently being processed.

        The message is prefixed with the globals file and line, then
        passed on to moan().
        """
        where="%s line %d: "%(file,line)
        moan(where+msg)
def moan(msg):
        """Print a general complaint and count it.

        Once the global complaints has been set to None, any further
        complaint terminates the program with exit status 1.
        """
        global complaints
        print(msg)
        if complaints is None:
                sys.exit(1)
        complaints+=1
845
class UntaintedRoot():
        """Trusted stand-in for an input word when building the root node.

        Offers name() and keyword(), both returning the string it was
        constructed with.
        """
        def __init__(self,s):
                self._s=s
        def name(self):
                return self._s
        keyword=name
850
# Root of the configuration tree, built from two trusted 'root' words.
# All vpns are children of this node.
root=level([UntaintedRoot('root') for _ in range(2)])
# Stack of the currently open nodes, from root down to the current one
obstack=[root]
# New definitions are only permitted at depth >= allow_defs
# (pline raises it to sitelevel.depth on "end-definitions")
allow_defs=0
855
def set_property(obj,w):
        """Apply the property line w to the configuration node obj.

        w[0] names the property.  If its keywords entry has a third
        element, that element is the canonical name the alias maps to.
        An existing property is extended through its add() method;
        otherwise the entry's constructor builds a new one from w.
        Returns the property object now stored on obj.
        """
        propname=w[0].raw_mark_ok()
        kw=keywords[propname]
        if len(kw) >= 3:
                propname=kw[2] # alias -> canonical name
        props=obj.properties
        if propname not in props:
                props[propname]=kw[0](w)
        else:
                props[propname].add(obj,w)
        return props[propname]
867
class FilterState:
        "Mutable state carried from line to line while reading one file"
        def __init__(self):
                # begin exactly as if a new node had just been entered
                self.reset()
        def reset(self):
                """Restore the state expected at the start of a node.

                pline calls this whenever a level line opens a node, so in
                particular at the start of each site.
                """
                self.pkg = '0' * 8
875
def pline(il,filterstate,allow_include=False):
        """Process a configuration file line.

        il is the raw line text and filterstate the FilterState of the
        file being read; it is reset whenever a vpn/location/site line
        opens a node.  The function updates the global node stack
        obstack (and the tree below root) and returns the list of strings
        to emit for this line.  Problems are reported with complain(),
        and such lines yield [].
        """
        global allow_defs, obstack, root, file, line
        w=il.rstrip('\n').split()
        if len(w)==0: return ['']
        w=list([Tainted(x) for x in w])
        keyword=w[0]
        current=obstack[len(obstack)-1]
        copyout_core=lambda: ' '.join([ww.output() for ww in w])
        indent='    '*len(obstack)
        copyout=lambda: [indent + copyout_core() + '\n']
        if keyword=='end-definitions':
                keyword.raw_mark_ok()
                allow_defs=sitelevel.depth
                obstack=[root]
                return copyout()
        if keyword=='include':
                if not allow_include:
                        complain("include not permitted here")
                        return []
                if len(w) != 2:
                        complain("include requires one argument")
                        return []
                newfile=os.path.join(os.path.dirname(file),w[1].raw_mark_ok())
                # ^ user of "include" is trusted so raw_mark_ok is good
                # pfile() overwrites the globals file and line that
                # complain() reports.  Put them back afterwards so that
                # problems later in this file are attributed correctly.
                saved_position=(file,line)
                try:
                        return pfilepath(newfile,allow_include=allow_include)
                finally:
                        file,line=saved_position
        if keyword.raw() in levels:
                if len(w) < 2:
                        # guard the w[1] lookup below against a bare keyword
                        complain("%s requires a name"%(keyword.raw()))
                        return []
                # We may go up any number of levels, but only down by one
                newdepth=levels[keyword.raw_mark_ok()].depth
                currentdepth=len(obstack) # actually +1...
                if newdepth<=currentdepth:
                        obstack=obstack[:newdepth]
                if newdepth>currentdepth:
                        complain("May not go from level %d to level %d"%
                                (currentdepth-1,newdepth))
                # See if it's a new one (and whether that's permitted)
                # or an existing one
                current=obstack[len(obstack)-1]
                tname=w[1].name()
                if tname in current.children:
                        # Not new
                        current=current.children[tname]
                        if current.depth==2:
                                opmode.check_group(current.group, w)
                else:
                        # New
                        # Ignore depth check for now
                        nl=levels[keyword.raw()](w)
                        if nl.depth<allow_defs:
                                complain("New definitions not allowed at "
                                        "level %d"%nl.depth)
                                # we risk crashing if we continue
                                sys.exit(1)
                        current.children[tname]=nl
                        current=nl
                filterstate.reset()
                obstack.append(current)
                return copyout()
        if keyword.raw() not in keywords:
                # Checked before the per-level test so that an unrecognised
                # keyword gets this message rather than "Property ... not
                # allowed" (the branches below always return).
                complain("unknown keyword '%s'"%(keyword.raw()))
                return []
        if keyword.raw() not in current.allow_properties:
                complain("Property %s not allowed at %s level"%
                        (keyword.raw(),current.type))
                return []
        elif current.depth == vpnlevel.depth < allow_defs:
                # i.e. we are at a vpn node and definitions have ended
                complain("Not allowed to set VPN properties here")
                return []
        else:
                prop=set_property(current,w)
                out=[copyout_core()]
                out=prop.forsites(output_version,out,filterstate)
                if len(out)==0: return [indent + '#', copyout_core(), '\n']
                return [indent + ' '.join(out) + '\n']
949
def pfilepath(pathname,allow_include=False):
        """Read the file at pathname and process its lines with pfile().

        The file is read and closed before processing begins.  As a
        result it is not leaked if pfile() raises or exits, and nested
        includes do not keep every enclosing file open.  Returns pfile()'s
        output lines.
        """
        with open(pathname) as f:
                lines=f.readlines()
        return pfile(pathname,lines,allow_include=allow_include)
955
def pfile(name,lines,allow_include=False):
        """Process the lines of one input file.

        Records name in the global file and counts lines in the global
        line, which complain() uses for diagnostics.  Lines whose first
        character is '#' are skipped.  Every other line goes to pline(),
        which shares a single FilterState for the whole file.  Returns
        the concatenated output lines.
        """
        global file,line
        file=name
        line=0
        state=FilterState()
        result=[]
        for text in lines:
                line+=1
                if text[0]=='#':
                        continue
                result.extend(pline(text,state,allow_include=allow_include))
        return result
968
def outputsites(w):
        """Write the secnet sites include file for the whole tree to w.

        Output is a comment header, then the raw prefix+"vpn-data" section,
        then the per-VPN flattened lists under prefix+"vpn", and finally
        a prefix+"all-sites" line referencing every VPN's all-sites list.
        """
        w.write("# secnet sites file autogenerated by make-secnet-sites "
                "version %s\n"%VERSION)
        w.write("# %s\n"%time.asctime())
        w.write("# Command line: %s\n\n"%' '.join(sys.argv))

        # Raw VPN data section of file
        w.write(prefix+"vpn-data {\n")
        for vpn in root.children.values():
                vpn.output_data(w,(vpn,))
        w.write("};\n")

        # Per-VPN flattened lists
        w.write(prefix+"vpn {\n")
        for vpn in root.children.values():
                vpn.output_vpnflat(w,())
        w.write("};\n")

        # Flattened list of sites
        refs=["%svpn/%s/all-sites"%(prefix,vpn.kname())
              for vpn in root.children.values()]
        w.write(prefix+"all-sites %s;\n"%",".join(refs))
992
# Current input position (reported by complain()) and the running count
# of problems found (maintained by moan())
line,file,complaints=0,None,0
996
997 # Sanity check section
998 # Delete nodes where leaf=0 that have no children
999
def live(n):
        """Return 1 if n is a leaf node or has one somewhere below it, else 0.

        The search stops at the first live child found.
        """
        if n.leaf:
                return 1
        return 1 if any(live(child) for child in n.children.values()) else 0
def delempty(n):
        "Recursively remove every subtree of n that contains no leaf node"
        for key, child in list(n.children.items()):
                delempty(child)
                if not live(child):
                        del n.children[key]
1012
1013 # Check that all constraints are met (as far as I can tell
1014 # restrict-nets/networks/peer are the only special cases)
1015
def checkconstraints(n,p,ra):
        """Recursively check the tree rooted at n against its constraints.

        p holds the properties inherited from n's ancestors, and ra is the
        address set that n's networks must lie within.  Problems reported
        via moan(): missing required properties, properties not allowed
        at n's level, networks outside ra (narrowed by any restrict-nets
        on n), and a peer address outside n's own networks.
        """
        props=n.properties
        effective=dict(p)
        effective.update(props)
        for name in n.require_properties:
                if name not in effective:
                        moan("%s %s is missing property %s"%
                                (n.type,n.name,name))
        for name in effective:
                if name not in n.allow_properties:
                        moan("%s %s has forbidden property %s"%
                                (n.type,n.name,name))
        # Check address range restrictions
        allowed=(ra.intersection(props["restrict-nets"].set)
                 if "restrict-nets" in props else ra)
        if "networks" in props:
                nets=props["networks"].set
                if not nets <= allowed:
                        moan("%s %s networks out of bounds"%(n.type,n.name))
                if "peer" in props and not nets.contains(props["peer"].addr):
                        moan("%s %s peer not in networks"%(n.type,n.name))
        for child in n.children.values():
                checkconstraints(child,effective,allowed)
1041
# Read the input, prune branches without sites, and validate the rest
opmode.read_in()

delempty(root)
checkconstraints(root,{},ipaddrset.complete_set())

if complaints>0:
        print("There was 1 problem." if complaints==1
              else "There were %d problems."%(complaints))
        sys.exit(1)
# From here on, moan() exits immediately instead of counting
complaints=None

opmode.write_out()