chiark / gitweb /
make-secnet-sites: Crash if complain() is called too late
[secnet.git] / make-secnet-sites
1 #! /usr/bin/env python3
2 #
3 # This file is part of secnet.
4 # See README for full list of copyright holders.
5 #
6 # secnet is free software; you can redistribute it and/or modify it
7 # under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
10
11 # secnet is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15
16 # You should have received a copy of the GNU General Public License
17 # version 3 along with secnet; if not, see
18 # https://www.gnu.org/licenses/gpl.html.
19
20 """VPN sites file manipulation.
21
22 This program enables VPN site descriptions to be submitted for
23 inclusion in a central database, and allows the resulting database to
24 be turned into a secnet configuration file.
25
26 A database file can be turned into a secnet configuration file simply:
27 make-secnet-sites.py [infile [outfile]]
28
29 It would be wise to run secnet with the "--just-check-config" option
30 before installing the output on a live system.
31
32 The program expects to be invoked via userv to manage the database; it
33 relies on the USERV_USER and USERV_GROUP environment variables. The
34 command line arguments for this invocation are:
35
36 make-secnet-sites.py -u header-filename groupfiles-directory output-file \
37   group
38
39 All but the last argument are expected to be set by userv; the 'group'
40 argument is provided by the user. A suitable userv configuration file
41 fragment is:
42
43 reset
44 no-disconnect-hup
45 no-suppress-args
46 cd ~/secnet/sites-test/
47 execute ~/secnet/make-secnet-sites.py -u vpnheader groupfiles sites
48
49 This program is part of secnet.
50
51 """
52
53 from __future__ import print_function
54 from __future__ import unicode_literals
55 from builtins import int
56
57 import string
58 import time
59 import sys
60 import os
61 import getopt
62 import re
63 import argparse
64
65 import ipaddress
66
67 # entry 0 is "near the executable", or maybe from PYTHONPATH=.,
68 # which we don't want to preempt
69 sys.path.insert(1,"/usr/local/share/secnet")
70 sys.path.insert(1,"/usr/share/secnet")
71 import ipaddrset
72
# Version string written into the headers of generated files.
VERSION="0.1.18"

from sys import version_info
# Python 2 only: wrap stdin/stdout and replace open() so that all text
# I/O done by this script is decoded/encoded as UTF-8.
if version_info.major == 2:  # for python2
        import codecs
        sys.stdin = codecs.getreader('utf-8')(sys.stdin)
        sys.stdout = codecs.getwriter('utf-8')(sys.stdout)
        import io
        open=lambda f,m='r': io.open(f,m,encoding='utf-8')
82
def parse_args():
        """Parse the command line and set the corresponding module globals.

        Two modes are supported:

        * userv service mode (-u/--userv): exactly four positional
          arguments, header-filename groupfiles-directory output-file
          group.  Sets header, groupfiledir, sitesfile and group, sets user
          from USERV_USER, and checks that group is listed in USERV_GROUP.
        * conversion mode (default): at most two positional arguments,
          [infile [outfile]].  Sets inputfile (None means read stdin) and
          of, the stream the generated configuration is written to
          (stdout when no outfile is given).

        service is set to 1 in service mode and 0 otherwise.  On any usage
        error a message is printed and the program exits with status 1.
        """
        global service
        global inputfile
        global header
        global groupfiledir
        global sitesfile
        global group
        global user
        global of

        ap = argparse.ArgumentParser(description='process secnet sites files')
        ap.add_argument('--userv', '-u', action='store_true',
                        help='userv service fragment update mode')
        # NOTE(review): av.prefix is never read here, and the module-level
        # global prefix used by outputsites() is assigned '' after this
        # function runs, so -P currently has no effect.
        ap.add_argument('--prefix', '-P', nargs=1,
                        help='set prefix')
        ap.add_argument('arg',nargs=argparse.REMAINDER)
        av = ap.parse_args()
        service = 1 if av.userv else 0
        if service:
                if len(av.arg)!=4:
                        print("Wrong number of arguments")
                        sys.exit(1)
                (header, groupfiledir, sitesfile, group) = av.arg
                if "USERV_USER" not in os.environ:
                        print("Environment variable USERV_USER not found")
                        sys.exit(1)
                user=os.environ["USERV_USER"]
                # Check that group is in USERV_GROUP
                if "USERV_GROUP" not in os.environ:
                        print("Environment variable USERV_GROUP not found")
                        sys.exit(1)
                if group not in os.environ["USERV_GROUP"].split():
                        print("caller not in group %s"%group)
                        sys.exit(1)
        else:
                # Only [infile [outfile]] is accepted; av.arg excludes the
                # program name, so more than two words is an error.
                if len(av.arg)>2:
                        print("Too many arguments")
                        sys.exit(1)
                (inputfile, outputfile) = (av.arg + [None]*2)[0:2]
                # Use the parsed name: indexing sys.argv directly picks the
                # wrong word as soon as options such as -P are present.
                if outputfile is None: of=sys.stdout
                else: of=open(outputfile,'w')
129
# Parse the command line straight away: the globals it sets (service,
# group, inputfile, of, ...) are used by the module-level code below.
parse_args()
131
# Classes describing possible datatypes in the configuration file

class basetype:
        """Protocol shared by the single-valued property types.

        A second definition of the same property on one node is reported
        with complain(), via add().
        """
        def add(self,obj,w):
                prop_name=w[0]
                complain("%s %s already has property %s defined"%
                        (obj.type,obj.name,prop_name))

class conflist:
        """A property that may be given repeatedly, collecting every value.

        Each value is built by calling subtype on that line's word list.
        """
        def __init__(self,subtype,w):
                self.subtype=subtype
                first=subtype(w)
                self.list=[first]
        def add(self,obj,w):
                # obj is not needed: repeats simply accumulate.
                item=self.subtype(w)
                self.list.append(item)
        def __str__(self):
                parts=[str(item) for item in self.list]
                return ', '.join(parts)

def listof(subtype):
        "Return a constructor that makes a conflist of subtype values."
        def make(w):
                return conflist(subtype,w)
        return make
151
class single_ipaddr (basetype):
        "A single IP address (the type of the tunnel peer property)."
        def __init__(self,w):
                text=w[1]
                self.addr=ipaddress.ip_address(text)
        def __str__(self):
                return '"%s"'%(self.addr,)

class networks (basetype):
        "A set of IP addresses, given on the line as one or more networks."
        def __init__(self,w):
                self.set=ipaddrset.IPAddressSet()
                for spec in w[1:]:
                        # strict=True: a network with host bits set is an error
                        self.set.append([ipaddress.ip_network(spec,strict=True)])
        def __str__(self):
                quoted=['"%s"'%net for net in self.set.networks()]
                return ",".join(quoted)

class dhgroup (basetype):
        "A Diffie-Hellman group, given as its two parameters mod and gen."
        def __init__(self,w):
                self.mod=w[1]
                self.gen=w[2]
        def __str__(self):
                params=(self.mod,self.gen)
                return 'diffie-hellman("%s","%s")'%params

class hash (basetype):
        "A hash function name; only md5 and sha1 are accepted."
        def __init__(self,w):
                self.ht=w[1]
                if self.ht not in ('md5','sha1'):
                        complain("unknown hash type %s"%(self.ht))
        def __str__(self):
                return '%s'%(self.ht)

class email (basetype):
        "An email address, written out as <address>."
        def __init__(self,w):
                self.addr=w[1]
        def __str__(self):
                return '<%s>'%(self.addr)

class boolean (basetype):
        """A yes/no flag.

        Only the first character counts: T/t/Y/y/1 give True and
        F/f/N/n/0 give False; anything else is complained about and leaves
        the value unset.
        """
        def __init__(self,w):
                word=w[1]
                if re.match('[TtYy1]',word):
                        self.b=True
                elif re.match('[FfNn0]',word):
                        self.b=False
                else:
                        complain("invalid boolean value")
        def __str__(self):
                return 'True' if self.b else 'False'

class num (basetype):
        "A whole number, written in decimal."
        def __init__(self,w):
                self.n=int(w[1])
        def __str__(self):
                return '%d'%self.n

class address (basetype):
        "A DNS name and UDP port number (the port must be 1-65535)."
        def __init__(self,w):
                self.adr=w[1]
                self.port=int(w[2])
                if not 1<=self.port<=65535:
                        complain("invalid port number")
        def __str__(self):
                return '"%s"; port %d'%(self.adr,self.port)

class rsakey (basetype):
        "An RSA public key: an integer field followed by e and n."
        def __init__(self,w):
                # The first field must parse as an integer but is not output.
                self.l=int(w[1])
                self.e=w[2]
                self.n=w[3]
        def __str__(self):
                return 'rsa-public("%s","%s")'%(self.e,self.n)
231
# Possible properties of configuration nodes.
# Maps each property keyword to (constructor, description).  The
# constructor is called with the line's whole word list to build the
# property's value (see set_property()); the description is a
# human-readable summary of the property.
keywords={
 'contact':(email,"Contact address"),
 'dh':(dhgroup,"Diffie-Hellman group"),
 'hash':(hash,"Hash function"),
 'key-lifetime':(num,"Maximum key lifetime (ms)"),
 'setup-timeout':(num,"Key setup timeout (ms)"),
 'setup-retries':(num,"Maximum key setup packet retries"),
 'wait-time':(num,"Time to wait after unsuccessful key setup (ms)"),
 'renegotiate-time':(num,"Time after key setup to begin renegotiation (ms)"),
 'restrict-nets':(networks,"Allowable networks"),
 'networks':(networks,"Claimed networks"),
 'pubkey':(rsakey,"RSA public site key"),
 'peer':(single_ipaddr,"Tunnel peer IP address"),
 'address':(address,"External contact address and port"),
 'mobile':(boolean,"Site is mobile"),
}
249
def sp(name,value):
        """Format a property as a plain setting line, "NAME VALUE;".

        This is the default formatter in the allow_properties tables.
        """
        text="%s %s;\n"%(name,value)
        return text
253
# All levels support these properties.
# Maps property name -> formatter(name, value_string), which returns the
# text level.output_props() writes for that property.  contact and
# restrict-nets are written only as comments in the generated output.
global_properties={
        'contact':(lambda name,value:"# Contact email address: %s\n"%(value)),
        'dh':sp,
        'hash':sp,
        'key-lifetime':sp,
        'setup-timeout':sp,
        'setup-retries':sp,
        'wait-time':sp,
        'renegotiate-time':sp,
        'restrict-nets':(lambda name,value:"# restrict-nets %s\n"%value),
}
266
class level:
        """A node in the configuration hierarchy (the root, a vpn, a
        location or a site).

        Subclasses override these class attributes:
          depth              -- distance from the root (the root is 0)
          leaf               -- 1 for leaf nodes (sites), else 0
          allow_properties   -- maps each permitted property name to its
                                output formatter, called as
                                formatter(name, str(value)); a formatter of
                                None means the property is accepted but not
                                written by output_props()
          require_properties -- maps each mandatory property name to a
                                description of it
        """
        depth=0
        leaf=0
        allow_properties={}
        require_properties={}
        def __init__(self,w):
                "Create a node from its definition line w = [type, name, ...]."
                self.type=w[0]
                self.name=w[1]
                self.properties={}   # property name -> value object
                self.children={}     # child name -> child node
        def indent(self,w,t):
                "Write t spaces of indentation to the stream w."
                # Build exactly t spaces; slicing a fixed-width string of
                # spaces would silently cap the indentation depth.
                w.write(" "*t)
        def prop_out(self,n):
                "Return the output text for property n, via its formatter."
                return self.allow_properties[n](n,str(self.properties[n]))
        def output_props(self,w,ind):
                "Write every property that has a formatter, indented by ind."
                for i in self.properties.keys():
                        if self.allow_properties[i]:
                                self.indent(w,ind)
                                w.write("%s"%self.prop_out(i))
        def output_data(self,w,ind,np):
                """Write this node and its subtree as a "NAME { ... };" block.

                ind is the indentation of this node's own lines.  np is the
                path prefix for this node (empty, or ending in "/"); the
                children receive np + this node's name + "/".
                """
                self.indent(w,ind)
                w.write("%s {\n"%(self.name))
                self.output_props(w,ind+2)
                # A blank line separates a VPN's properties from its contents
                if self.depth==1: w.write("\n")
                for c in self.children.values():
                        c.output_data(w,ind+2,np+self.name+"/")
                self.indent(w,ind)
                w.write("};\n")
296
class vpnlevel(level):
        "A VPN: the first level below the root; its children are locations."
        depth=1
        leaf=0
        type="vpn"
        allow_properties=dict(global_properties)
        require_properties={
         'contact':"VPN admin contact address"
        }
        def output_vpnflat(self,w,ind,h):
                """Write this VPN's flattened site lists as a "NAME { ... };" block.

                Each location writes its own list of site paths, and an
                all-sites line then names every location.  h is the path
                under which this VPN's full data was written.
                """
                self.indent(w,ind)
                w.write("%s {\n"%(self.name))
                base=h+"/"+self.name+"/"
                for loc_name,loc in self.children.items():
                        loc.output_vpnflat(w,ind+2,base+loc_name)
                w.write("\n")
                self.indent(w,ind+2)
                w.write("all-sites %s;\n"%','.join(self.children.keys()))
                self.indent(w,ind)
                w.write("};\n")
321
class locationlevel(level):
        """A location: the middle level of the hierarchy, holding sites.

        Its definition line is "location NAME GROUP"; GROUP is kept in
        self.group and, in service mode, compared against the group given
        on the command line when the location is reopened.
        """
        depth=2
        leaf=0
        type="location"
        allow_properties=dict(global_properties)
        require_properties={
         'contact':"Location admin contact address",
        }
        def __init__(self,w):
                level.__init__(self,w)
                self.group=w[2]
        def output_vpnflat(self,w,ind,h):
                """Write "NAME path,path,...;" listing this location's sites.

                Each site is referred to as h + "/" + its name.
                """
                self.indent(w,ind)
                site_paths=','.join(h+"/"+site for site in self.children.keys())
                w.write("%s %s;\n"%(self.name,site_paths))
341
class sitelevel(level):
        """A site: a leaf of the hierarchy, describing one peer.

        On top of the global properties a site may carry address,
        networks, peer, pubkey and mobile.  networks and peer have a None
        formatter; output_data() writes them inside the site's
        "link netlink" block instead.
        """
        depth=3
        leaf=1
        type="site"
        allow_properties=dict(global_properties)
        allow_properties.update({
         'address':sp,
         'networks':None,
         'peer':None,
         'pubkey':(lambda n,v:"key %s;\n"%v),
         'mobile':sp,
        })
        require_properties={
         'dh':"Diffie-Hellman group",
         'contact':"Site admin contact address",
         'networks':"Networks claimed by the site",
         'hash':"hash function",
         'peer':"Gateway address of the site",
         'pubkey':"RSA public key of the site",
        }
        def output_data(self,w,ind,np):
                """Write this site's configuration block.

                The block holds the site's full name (np + name), its
                formatted properties, and a netlink block whose routes are
                the site's networks and whose ptp-address is its peer.
                """
                body=ind+2
                self.indent(w,ind)
                w.write("%s {\n"%(self.name))
                self.indent(w,body)
                w.write("name \"%s\";\n"%(np+self.name))
                self.output_props(w,body)
                self.indent(w,body)
                w.write("link netlink {\n")
                for template,prop in (("routes %s;\n","networks"),
                                      ("ptp-address %s;\n","peer")):
                        self.indent(w,ind+4)
                        w.write(template%str(self.properties[prop]))
                self.indent(w,body)
                w.write("};\n")
                self.indent(w,ind)
                w.write("};\n")
381
# Levels in the configuration file: maps each level keyword to the class
# implementing it (the class carries the depth and the allowed and
# required properties).
levels={'vpn':vpnlevel, 'location':locationlevel, 'site':sitelevel}

# Reserved vpn/location/site names
# NOTE(review): reserved is not consulted anywhere in this file (pline()
# does not check new names against it); confirm before relying on these
# names being rejected.
reserved={'all-sites':None}
reserved.update(keywords)
reserved.update(levels)
390
def complain(msg):
        """Report a problem with the input line being parsed, and count it.

        The message is prefixed with the current file name and line number
        (the globals file and line).  Once complaints has been set to None
        the increment raises TypeError, so a call made too late crashes
        rather than being silently lost.
        """
        global complaints
        where="%s line %d: "%(file,line)
        print(where+msg)
        complaints=complaints+1
def moan(msg):
        """Report a problem not tied to a particular input line, and count it.

        Like complain(), this raises TypeError once complaints is None.
        """
        global complaints
        print(msg)
        complaints=complaints+1
401
root=level(['root','root'])   # All vpns are children of this node
# Chain of nodes from root down to the node currently being defined;
# property lines apply to obstack[-1].
obstack=[root]
# New nodes shallower than this depth are refused (raised by end-definitions)
allow_defs=0
# Prefix for the names written by outputsites().
# NOTE(review): the --prefix/-P option parsed by parse_args() is never
# copied here, so as far as this file shows it always stays ''.
prefix=''
406
def set_property(obj,w):
        """Record the property line w (keyword first) on the node obj.

        The first occurrence builds the value with the keyword's
        constructor.  Later occurrences are passed to the existing value's
        add(), which either accumulates or complains about a duplicate.
        """
        prop=w[0]
        if prop not in obj.properties:
                constructor=keywords[prop][0]
                obj.properties[prop]=constructor(w)
                return
        obj.properties[prop].add(obj,w)
414
def pline(i,allow_include=False):
        """Process one configuration file line.

        i is the raw input line.  The parse state lives in the globals
        obstack (the chain of nodes from root to the node being defined)
        and allow_defs (new nodes shallower than this depth are refused).

        Returns the lines that stand for this line in the processed
        output: [i] for an accepted line, [] for a line rejected with a
        complaint, or the processed lines of the included file for an
        include directive.
        """
        global allow_defs, obstack, root
        w=i.rstrip('\n').split()
        if len(w)==0: return [i]
        keyword=w[0]
        current=obstack[-1]
        if keyword=='end-definitions':
                # From now on only new sites may be created
                allow_defs=sitelevel.depth
                obstack=[root]
                return [i]
        if keyword=='include':
                if not allow_include:
                        complain("include not permitted here")
                        return []
                if len(w) != 2:
                        complain("include requires one argument")
                        return []
                # Relative paths are taken from the including file's directory
                newfile=os.path.join(os.path.dirname(file),w[1])
                return pfilepath(newfile,allow_include=allow_include)
        if keyword in levels:
                # We may go up any number of levels, but only down by one
                newdepth=levels[keyword].depth
                currentdepth=len(obstack) # actually +1...
                if newdepth<=currentdepth:
                        obstack=obstack[:newdepth]
                if newdepth>currentdepth:
                        complain("May not go from level %d to level %d"%
                                (currentdepth-1,newdepth))
                # See if it's a new one (and whether that's permitted)
                # or an existing one
                current=obstack[-1]
                tname=w[1]
                if tname in current.children:
                        # Not new
                        current=current.children[tname]
                        # In service mode a reopened location must belong to
                        # the group given on the command line
                        if service and group and current.depth==2:
                                if group!=current.group:
                                        complain("Incorrect group!")
                else:
                        # New
                        # Ignore depth check for now
                        nl=levels[keyword](w)
                        if nl.depth<allow_defs:
                                complain("New definitions not allowed at "
                                        "level %d"%nl.depth)
                                # we risk crashing if we continue
                                sys.exit(1)
                        current.children[tname]=nl
                        current=nl
                obstack.append(current)
                return [i]
        # Check this before allow_properties, so that an unrecognised
        # keyword is reported as such rather than as "not allowed here".
        if keyword not in keywords:
                complain("unknown keyword '%s'"%(keyword))
                return []
        if keyword not in current.allow_properties:
                complain("Property %s not allowed at %s level"%
                        (keyword,current.type))
                return []
        elif current.depth == vpnlevel.depth < allow_defs:
                # VPN-level properties are frozen after end-definitions
                complain("Not allowed to set VPN properties here")
                return []
        else:
                set_property(current,w)
                return [i]
479
def pfilepath(pathname,allow_include=False):
        """Read the file at pathname and process its lines with pfile().

        Returns the list of lines pfile() keeps.  The file is read in full
        and closed before processing, so it is closed even if processing
        raises or exits, and is not held open while includes are read.
        """
        with open(pathname) as f:
                lines=f.readlines()
        return pfile(pathname,lines,allow_include=allow_include)
485
def pfile(name,lines,allow_include=False):
        """Process the lines of one input file.

        name is reported in complaints (through the globals file and
        line) and is the base for relative include paths.  Returns the
        concatenation of pline()'s results; lines starting with '#' are
        skipped and dropped from the output.

        The previous file and line are restored afterwards.  Otherwise,
        after a nested include, complaints in the including file would
        name the included file and continue its line count, and later
        includes would resolve relative to the wrong directory.
        """
        global file,line
        saved=(file,line)
        file=name
        line=0
        outlines=[]
        try:
                for i in lines:
                        line=line+1
                        if i.startswith('#'): continue
                        outlines += pline(i,allow_include=allow_include)
        finally:
                file,line=saved
        return outlines
497
def outputsites(w):
        """Write the generated secnet sites configuration to the stream w.

        After a comment header, three sections follow, each name preceded
        by the global prefix:
          vpn-data  -- the full nested vpn/location/site definitions;
          vpn       -- per VPN, the flattened site lists per location plus
                       an all-sites list;
          all-sites -- a list of every VPN's all-sites entry.
        """
        # Header: generator version, generation time and command line
        w.write("# secnet sites file autogenerated by make-secnet-sites "
                +"version %s\n"%VERSION)
        w.write("# %s\n"%time.asctime(time.localtime(time.time())))
        w.write("# Command line: %s\n\n"%' '.join(sys.argv))

        vpns=list(root.children.values())

        w.write(prefix+"vpn-data {\n")
        for vpn in vpns:
                vpn.output_data(w,2,"")
        w.write("};\n")

        data_path=prefix+"vpn-data"
        w.write(prefix+"vpn {\n")
        for vpn in vpns:
                vpn.output_vpnflat(w,2,data_path)
        w.write("};\n")

        vpn_lists=["%svpn/%s/all-sites"%(prefix,vpn_name)
                   for vpn_name in root.children.keys()]
        w.write(prefix+"all-sites %s;\n"%",".join(vpn_lists))
521
# Parse state used by complain(): the name of the file being read and the
# current line number within it (both maintained by pfile()), plus the
# running count of problems reported by complain() and moan().
line=0
file=None
complaints=0
525
# Sanity check section
# Delete nodes where leaf=0 that have no leaf nodes below them

def live(n):
        "Return 1 if n is, or has somewhere below it, a leaf node; else 0."
        if n.leaf or any(live(child) for child in n.children.values()):
                return 1
        return 0
def delempty(n):
        "Prune, bottom-up, every subtree of n that contains no leaf node."
        for child_name in list(n.children):
                child=n.children[child_name]
                delempty(child)
                if not live(child):
                        del n.children[child_name]
541
# Check that all constraints are met (as far as I can tell
# restrict-nets/networks/peer are the only special cases)

def checkconstraints(n,p,ra):
        """Recursively check node n and its subtree, moaning about problems.

        p holds the properties inherited from n's ancestors (it is not
        modified).  ra is the address set n's networks must lie within;
        it is narrowed by any restrict-nets on n before being passed down.
        Reports missing required properties, properties not allowed at
        n's level (inherited ones included), networks outside the allowed
        range, and a peer address outside the node's own networks.
        """
        effective=dict(p)
        effective.update(n.properties)
        for required in n.require_properties:
                if required not in effective:
                        moan("%s %s is missing property %s"%
                                (n.type,n.name,required))
        for present in effective:
                if present not in n.allow_properties:
                        moan("%s %s has forbidden property %s"%
                                (n.type,n.name,present))
        props=n.properties
        allowed_range=ra
        if "restrict-nets" in props:
                allowed_range=ra.intersection(props["restrict-nets"].set)
        if "networks" in props:
                nets=props["networks"].set
                if not nets <= allowed_range:
                        moan("%s %s networks out of bounds"%(n.type,n.name))
                if "peer" in props and not nets.contains(props["peer"].addr):
                        moan("%s %s peer not in networks"%(n.type,n.name))
        for child in n.children.values():
                checkconstraints(child,effective,allowed_range)
570
# Read and parse the input.  In service mode the header file may use
# include, but the user's submission (read from stdin) may not, since
# pfile() is called without allow_include.  headerinput and userinput
# are kept for writing the updated files below.
if service:
        headerinput=pfilepath(header,allow_include=True)
        userinput=sys.stdin.readlines()
        pfile("user input",userinput)
else:
        if inputfile is None:
                pfile("stdin",sys.stdin.readlines())
        else:
                pfilepath(inputfile)

# Drop VPNs/locations with no sites, then check the whole tree, starting
# from ipaddrset.complete_set() (presumably "no address restriction").
delempty(root)
checkconstraints(root,{},ipaddrset.complete_set())

# Refuse to write anything if any problem was found.
if complaints>0:
        if complaints==1: print("There was 1 problem.")
        else: print("There were %d problems."%(complaints))
        sys.exit(1)
complaints=None # arranges to crash if we complain later
589
if service:
        # Put the user's input into their group file, and rebuild the main
        # sites file.  Each file is written under a temporary name and then
        # renamed into place, so the final name only ever refers to a
        # completely written file.
        tmpgroupfile=groupfiledir+"/T"+group
        with open(tmpgroupfile,'w') as f:
                f.write("# Section submitted by user %s, %s\n"%
                        (user,time.asctime(time.localtime(time.time()))))
                f.write("# Checked by make-secnet-sites version %s\n\n"%VERSION)
                f.writelines(userinput)
                f.write("\n")
        os.rename(tmpgroupfile,groupfiledir+"/R"+group)
        with open(sitesfile+"-tmp",'w') as f:
                f.write("# sites file autogenerated by make-secnet-sites\n")
                f.write("# generated %s, invoked by %s\n"%
                        (time.asctime(time.localtime(time.time())),user))
                f.write("# use make-secnet-sites to turn this file into a\n")
                f.write("# valid /etc/secnet/sites.conf file\n\n")
                f.writelines(headerinput)
                # Append every group's committed ("R") file.  os.listdir()
                # returns names in arbitrary order; sort them so that the
                # generated sites file is the same from run to run.
                for groupname in sorted(os.listdir(groupfiledir)):
                        if groupname.startswith('R'):
                                with open(groupfiledir+"/"+groupname) as j:
                                        f.write(j.read())
                f.write("# end of sites file\n")
        os.rename(sitesfile+"-tmp",sitesfile)
else:
        outputsites(of)
        # Close a named output file explicitly, so that a failure of the
        # final flush (e.g. a full disk) raises an error here.  Otherwise
        # it could be ignored during interpreter shutdown.
        if of is not sys.stdout:
                of.close()