chiark / gitweb /
make-secnet-sites: Use argparse rather than ad-hoc parser
[secnet.git] / make-secnet-sites
1 #! /usr/bin/env python3
2 #
3 # This file is part of secnet.
4 # See README for full list of copyright holders.
5 #
6 # secnet is free software; you can redistribute it and/or modify it
7 # under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
10
11 # secnet is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15
16 # You should have received a copy of the GNU General Public License
17 # version 3 along with secnet; if not, see
18 # https://www.gnu.org/licenses/gpl.html.
19
20 """VPN sites file manipulation.
21
22 This program enables VPN site descriptions to be submitted for
23 inclusion in a central database, and allows the resulting database to
24 be turned into a secnet configuration file.
25
26 A database file can be turned into a secnet configuration file simply:
27 make-secnet-sites.py [infile [outfile]]
28
29 It would be wise to run secnet with the "--just-check-config" option
30 before installing the output on a live system.
31
32 The program expects to be invoked via userv to manage the database; it
33 relies on the USERV_USER and USERV_GROUP environment variables. The
34 command line arguments for this invocation are:
35
36 make-secnet-sites.py -u header-filename groupfiles-directory output-file \
37   group
38
39 All but the last argument are expected to be set by userv; the 'group'
40 argument is provided by the user. A suitable userv configuration file
41 fragment is:
42
43 reset
44 no-disconnect-hup
45 no-suppress-args
46 cd ~/secnet/sites-test/
47 execute ~/secnet/make-secnet-sites.py -u vpnheader groupfiles sites
48
49 This program is part of secnet.
50
51 """
52
53 from __future__ import print_function
54 from __future__ import unicode_literals
55 from builtins import int
56
57 import string
58 import time
59 import sys
60 import os
61 import getopt
62 import re
63 import argparse
64
65 import ipaddress
66
67 # entry 0 is "near the executable", or maybe from PYTHONPATH=.,
68 # which we don't want to preempt
69 sys.path.insert(1,"/usr/local/share/secnet")
70 sys.path.insert(1,"/usr/share/secnet")
71 import ipaddrset
72
73 VERSION="0.1.18"
74
75 from sys import version_info
76 if version_info.major == 2:  # for python2
77     import codecs
78     sys.stdin = codecs.getreader('utf-8')(sys.stdin)
79     sys.stdout = codecs.getwriter('utf-8')(sys.stdout)
80     import io
81     open=lambda f,m='r': io.open(f,m,encoding='utf-8')
82
83 def parse_args():
84         global service
85         global inputfile
86         global header
87         global groupfiledir
88         global sitesfile
89         global group
90         global user
91         global of
92
93         ap = argparse.ArgumentParser(description='process secnet sites files')
94         ap.add_argument('--userv', '-u', action='store_true',
95                         help='userv service fragment update mode')
96         ap.add_argument('--prefix', '-P', nargs=1,
97                         help='set prefix')
98         ap.add_argument('arg',nargs=argparse.REMAINDER)
99         av = ap.parse_args()
100         #print(repr(av), file=sys.stderr)
101         service = 1 if av.userv else 0
102         if service:
103                 if len(av.arg)!=4:
104                         print("Wrong number of arguments")
105                         sys.exit(1)
106                 (header, groupfiledir, sitesfile, group) = av.arg
107                 if "USERV_USER" not in os.environ:
108                         print("Environment variable USERV_USER not found")
109                         sys.exit(1)
110                 user=os.environ["USERV_USER"]
111                 # Check that group is in USERV_GROUP
112                 if "USERV_GROUP" not in os.environ:
113                         print("Environment variable USERV_GROUP not found")
114                         sys.exit(1)
115                 ugs=os.environ["USERV_GROUP"]
116                 ok=0
117                 for i in ugs.split():
118                         if group==i: ok=1
119                 if not ok:
120                         print("caller not in group %s"%group)
121                         sys.exit(1)
122         else:
123                 if len(av.arg)>3:
124                         print("Too many arguments")
125                         sys.exit(1)
126                 (inputfile, outputfile) = (av.arg + [None]*2)[0:2]
127                 if outputfile is None: of=sys.stdout
128                 else: of=open(sys.argv[2],'w')
129
130 parse_args()
131
132 # Classes describing possible datatypes in the configuration file
133
134 class basetype:
135         "Common protocol for configuration types."
136         def add(self,obj,w):
137                 complain("%s %s already has property %s defined"%
138                         (obj.type,obj.name,w[0]))
139
140 class conflist:
141         "A list of some kind of configuration type."
142         def __init__(self,subtype,w):
143                 self.subtype=subtype
144                 self.list=[subtype(w)]
145         def add(self,obj,w):
146                 self.list.append(self.subtype(w))
147         def __str__(self):
148                 return ', '.join(map(str, self.list))
149 def listof(subtype):
150         return lambda w: conflist(subtype, w)
151
152 class single_ipaddr (basetype):
153         "An IP address"
154         def __init__(self,w):
155                 self.addr=ipaddress.ip_address(w[1])
156         def __str__(self):
157                 return '"%s"'%self.addr
158
159 class networks (basetype):
160         "A set of IP addresses specified as a list of networks"
161         def __init__(self,w):
162                 self.set=ipaddrset.IPAddressSet()
163                 for i in w[1:]:
164                         x=ipaddress.ip_network(i,strict=True)
165                         self.set.append([x])
166         def __str__(self):
167                 return ",".join(map((lambda n: '"%s"'%n), self.set.networks()))
168
169 class dhgroup (basetype):
170         "A Diffie-Hellman group"
171         def __init__(self,w):
172                 self.mod=w[1]
173                 self.gen=w[2]
174         def __str__(self):
175                 return 'diffie-hellman("%s","%s")'%(self.mod,self.gen)
176
177 class hash (basetype):
178         "A choice of hash function"
179         def __init__(self,w):
180                 self.ht=w[1]
181                 if (self.ht!='md5' and self.ht!='sha1'):
182                         complain("unknown hash type %s"%(self.ht))
183         def __str__(self):
184                 return '%s'%(self.ht)
185
186 class email (basetype):
187         "An email address"
188         def __init__(self,w):
189                 self.addr=w[1]
190         def __str__(self):
191                 return '<%s>'%(self.addr)
192
193 class boolean (basetype):
194         "A boolean"
195         def __init__(self,w):
196                 if re.match('[TtYy1]',w[1]):
197                         self.b=True
198                 elif re.match('[FfNn0]',w[1]):
199                         self.b=False
200                 else:
201                         complain("invalid boolean value");
202         def __str__(self):
203                 return ['False','True'][self.b]
204
205 class num (basetype):
206         "A decimal number"
207         def __init__(self,w):
208                 self.n=int(w[1])
209         def __str__(self):
210                 return '%d'%(self.n)
211
212 class address (basetype):
213         "A DNS name and UDP port number"
214         def __init__(self,w):
215                 self.adr=w[1]
216                 self.port=int(w[2])
217                 if (self.port<1 or self.port>65535):
218                         complain("invalid port number")
219         def __str__(self):
220                 return '"%s"; port %d'%(self.adr,self.port)
221
222 class rsakey (basetype):
223         "An RSA public key"
224         def __init__(self,w):
225                 self.l=int(w[1])
226                 self.e=w[2]
227                 self.n=w[3]
228         def __str__(self):
229                 return 'rsa-public("%s","%s")'%(self.e,self.n)
230
231 # Possible properties of configuration nodes
232 keywords={
233  'contact':(email,"Contact address"),
234  'dh':(dhgroup,"Diffie-Hellman group"),
235  'hash':(hash,"Hash function"),
236  'key-lifetime':(num,"Maximum key lifetime (ms)"),
237  'setup-timeout':(num,"Key setup timeout (ms)"),
238  'setup-retries':(num,"Maximum key setup packet retries"),
239  'wait-time':(num,"Time to wait after unsuccessful key setup (ms)"),
240  'renegotiate-time':(num,"Time after key setup to begin renegotiation (ms)"),
241  'restrict-nets':(networks,"Allowable networks"),
242  'networks':(networks,"Claimed networks"),
243  'pubkey':(rsakey,"RSA public site key"),
244  'peer':(single_ipaddr,"Tunnel peer IP address"),
245  'address':(address,"External contact address and port"),
246  'mobile':(boolean,"Site is mobile"),
247 }
248
249 def sp(name,value):
250         "Simply output a property - the default case"
251         return "%s %s;\n"%(name,value)
252
253 # All levels support these properties
254 global_properties={
255         'contact':(lambda name,value:"# Contact email address: %s\n"%(value)),
256         'dh':sp,
257         'hash':sp,
258         'key-lifetime':sp,
259         'setup-timeout':sp,
260         'setup-retries':sp,
261         'wait-time':sp,
262         'renegotiate-time':sp,
263         'restrict-nets':(lambda name,value:"# restrict-nets %s\n"%value),
264 }
265
266 class level:
267         "A level in the configuration hierarchy"
268         depth=0
269         leaf=0
270         allow_properties={}
271         require_properties={}
272         def __init__(self,w):
273                 self.type=w[0]
274                 self.name=w[1]
275                 self.properties={}
276                 self.children={}
277         def indent(self,w,t):
278                 w.write("                 "[:t])
279         def prop_out(self,n):
280                 return self.allow_properties[n](n,str(self.properties[n]))
281         def output_props(self,w,ind):
282                 for i in self.properties.keys():
283                         if self.allow_properties[i]:
284                                 self.indent(w,ind)
285                                 w.write("%s"%self.prop_out(i))
286         def output_data(self,w,ind,np):
287                 self.indent(w,ind)
288                 w.write("%s {\n"%(self.name))
289                 self.output_props(w,ind+2)
290                 if self.depth==1: w.write("\n");
291                 for c in self.children.values():
292                         c.output_data(w,ind+2,np+self.name+"/")
293                 self.indent(w,ind)
294                 w.write("};\n")
295
296 class vpnlevel(level):
297         "VPN level in the configuration hierarchy"
298         depth=1
299         leaf=0
300         type="vpn"
301         allow_properties=global_properties.copy()
302         require_properties={
303          'contact':"VPN admin contact address"
304         }
305         def __init__(self,w):
306                 level.__init__(self,w)
307         def output_vpnflat(self,w,ind,h):
308                 "Output flattened list of site names for this VPN"
309                 self.indent(w,ind)
310                 w.write("%s {\n"%(self.name))
311                 for i in self.children.keys():
312                         self.children[i].output_vpnflat(w,ind+2,
313                                 h+"/"+self.name+"/"+i)
314                 w.write("\n")
315                 self.indent(w,ind+2)
316                 w.write("all-sites %s;\n"%
317                         ','.join(self.children.keys()))
318                 self.indent(w,ind)
319                 w.write("};\n")
320
321 class locationlevel(level):
322         "Location level in the configuration hierarchy"
323         depth=2
324         leaf=0
325         type="location"
326         allow_properties=global_properties.copy()
327         require_properties={
328          'contact':"Location admin contact address",
329         }
330         def __init__(self,w):
331                 level.__init__(self,w)
332                 self.group=w[2]
333         def output_vpnflat(self,w,ind,h):
334                 self.indent(w,ind)
335                 # The "h=h,self=self" abomination below exists because
336                 # Python didn't support nested_scopes until version 2.1
337                 w.write("%s %s;\n"%(self.name,','.join(
338                         map(lambda x,h=h,self=self:
339                                 h+"/"+x,self.children.keys()))))
340
341 class sitelevel(level):
342         "Site level (i.e. a leafnode) in the configuration hierarchy"
343         depth=3
344         leaf=1
345         type="site"
346         allow_properties=global_properties.copy()
347         allow_properties.update({
348          'address':sp,
349          'networks':None,
350          'peer':None,
351          'pubkey':(lambda n,v:"key %s;\n"%v),
352          'mobile':sp,
353         })
354         require_properties={
355          'dh':"Diffie-Hellman group",
356          'contact':"Site admin contact address",
357          'networks':"Networks claimed by the site",
358          'hash':"hash function",
359          'peer':"Gateway address of the site",
360          'pubkey':"RSA public key of the site",
361         }
362         def __init__(self,w):
363                 level.__init__(self,w)
364         def output_data(self,w,ind,np):
365                 self.indent(w,ind)
366                 w.write("%s {\n"%(self.name))
367                 self.indent(w,ind+2)
368                 w.write("name \"%s\";\n"%(np+self.name))
369                 self.output_props(w,ind+2)
370                 self.indent(w,ind+2)
371                 w.write("link netlink {\n");
372                 self.indent(w,ind+4)
373                 w.write("routes %s;\n"%str(self.properties["networks"]))
374                 self.indent(w,ind+4)
375                 w.write("ptp-address %s;\n"%str(self.properties["peer"]))
376                 self.indent(w,ind+2)
377                 w.write("};\n")
378                 self.indent(w,ind)
379                 w.write("};\n")
380
381 # Levels in the configuration file
382 # (depth,properties)
383 levels={'vpn':vpnlevel, 'location':locationlevel, 'site':sitelevel}
384
385 # Reserved vpn/location/site names
386 reserved={'all-sites':None}
387 reserved.update(keywords)
388 reserved.update(levels)
389
390 def complain(msg):
391         "Complain about a particular input line"
392         global complaints
393         print(("%s line %d: "%(file,line))+msg)
394         complaints=complaints+1
395 def moan(msg):
396         "Complain about something in general"
397         global complaints
398         print(msg);
399         complaints=complaints+1
400
401 root=level(['root','root'])   # All vpns are children of this node
402 obstack=[root]
403 allow_defs=0   # Level above which new definitions are permitted
404 prefix=''
405
406 def set_property(obj,w):
407         "Set a property on a configuration node"
408         if w[0] in obj.properties:
409                 obj.properties[w[0]].add(obj,w)
410         else:
411                 obj.properties[w[0]]=keywords[w[0]][0](w)
412
413 def pline(i,allow_include=False):
414         "Process a configuration file line"
415         global allow_defs, obstack, root
416         w=i.rstrip('\n').split()
417         if len(w)==0: return [i]
418         keyword=w[0]
419         current=obstack[len(obstack)-1]
420         if keyword=='end-definitions':
421                 allow_defs=sitelevel.depth
422                 obstack=[root]
423                 return [i]
424         if keyword=='include':
425                 if not allow_include:
426                         complain("include not permitted here")
427                         return []
428                 if len(w) != 2:
429                         complain("include requires one argument")
430                         return []
431                 newfile=os.path.join(os.path.dirname(file),w[1])
432                 return pfilepath(newfile,allow_include=allow_include)
433         if keyword in levels:
434                 # We may go up any number of levels, but only down by one
435                 newdepth=levels[keyword].depth
436                 currentdepth=len(obstack) # actually +1...
437                 if newdepth<=currentdepth:
438                         obstack=obstack[:newdepth]
439                 if newdepth>currentdepth:
440                         complain("May not go from level %d to level %d"%
441                                 (currentdepth-1,newdepth))
442                 # See if it's a new one (and whether that's permitted)
443                 # or an existing one
444                 current=obstack[len(obstack)-1]
445                 if w[1] in current.children:
446                         # Not new
447                         current=current.children[w[1]]
448                         if service and group and current.depth==2:
449                                 if group!=current.group:
450                                         complain("Incorrect group!")
451                 else:
452                         # New
453                         # Ignore depth check for now
454                         nl=levels[keyword](w)
455                         if nl.depth<allow_defs:
456                                 complain("New definitions not allowed at "
457                                         "level %d"%nl.depth)
458                                 # we risk crashing if we continue
459                                 sys.exit(1)
460                         current.children[w[1]]=nl
461                         current=nl
462                 obstack.append(current)
463                 return [i]
464         if keyword not in current.allow_properties:
465                 complain("Property %s not allowed at %s level"%
466                         (keyword,current.type))
467                 return []
468         elif current.depth == vpnlevel.depth < allow_defs:
469                 complain("Not allowed to set VPN properties here")
470                 return []
471         else:
472                 set_property(current,w)
473                 return [i]
474
475         complain("unknown keyword '%s'"%(keyword))
476
477 def pfilepath(pathname,allow_include=False):
478         f=open(pathname)
479         outlines=pfile(pathname,f.readlines(),allow_include=allow_include)
480         f.close()
481         return outlines
482
483 def pfile(name,lines,allow_include=False):
484         "Process a file"
485         global file,line
486         file=name
487         line=0
488         outlines=[]
489         for i in lines:
490                 line=line+1
491                 if (i[0]=='#'): continue
492                 outlines += pline(i,allow_include=allow_include)
493         return outlines
494
495 def outputsites(w):
496         "Output include file for secnet configuration"
497         w.write("# secnet sites file autogenerated by make-secnet-sites "
498                 +"version %s\n"%VERSION)
499         w.write("# %s\n"%time.asctime(time.localtime(time.time())))
500         w.write("# Command line: %s\n\n"%' '.join(sys.argv))
501
502         # Raw VPN data section of file
503         w.write(prefix+"vpn-data {\n")
504         for i in root.children.values():
505                 i.output_data(w,2,"")
506         w.write("};\n")
507
508         # Per-VPN flattened lists
509         w.write(prefix+"vpn {\n")
510         for i in root.children.values():
511                 i.output_vpnflat(w,2,prefix+"vpn-data")
512         w.write("};\n")
513
514         # Flattened list of sites
515         w.write(prefix+"all-sites %s;\n"%",".join(
516                 map(lambda x:"%svpn/%s/all-sites"%(prefix,x),
517                         root.children.keys())))
518
519 line=0
520 file=None
521 complaints=0
522
523 # Sanity check section
524 # Delete nodes where leaf=0 that have no children
525
526 def live(n):
527         "Number of leafnodes below node n"
528         if n.leaf: return 1
529         for i in n.children.keys():
530                 if live(n.children[i]): return 1
531         return 0
532 def delempty(n):
533         "Delete nodes that have no leafnode children"
534         for i in list(n.children.keys()):
535                 delempty(n.children[i])
536                 if not live(n.children[i]):
537                         del n.children[i]
538
539 # Check that all constraints are met (as far as I can tell
540 # restrict-nets/networks/peer are the only special cases)
541
542 def checkconstraints(n,p,ra):
543         new_p=p.copy()
544         new_p.update(n.properties)
545         for i in n.require_properties.keys():
546                 if i not in new_p:
547                         moan("%s %s is missing property %s"%
548                                 (n.type,n.name,i))
549         for i in new_p.keys():
550                 if i not in n.allow_properties:
551                         moan("%s %s has forbidden property %s"%
552                                 (n.type,n.name,i))
553         # Check address range restrictions
554         if "restrict-nets" in n.properties:
555                 new_ra=ra.intersection(n.properties["restrict-nets"].set)
556         else:
557                 new_ra=ra
558         if "networks" in n.properties:
559                 if not n.properties["networks"].set <= new_ra:
560                         moan("%s %s networks out of bounds"%(n.type,n.name))
561                 if "peer" in n.properties:
562                         if not n.properties["networks"].set.contains(
563                                 n.properties["peer"].addr):
564                                 moan("%s %s peer not in networks"%(n.type,n.name))
565         for i in n.children.keys():
566                 checkconstraints(n.children[i],new_p,new_ra)
567
568 if service:
569         headerinput=pfilepath(header,allow_include=True)
570         userinput=sys.stdin.readlines()
571         pfile("user input",userinput)
572 else:
573         if inputfile is None:
574                 pfile("stdin",sys.stdin.readlines())
575         else:
576                 pfilepath(inputfile)
577
578 delempty(root)
579 checkconstraints(root,{},ipaddrset.complete_set())
580
581 if complaints>0:
582         if complaints==1: print("There was 1 problem.")
583         else: print("There were %d problems."%(complaints))
584         sys.exit(1)
585
586 if service:
587         # Put the user's input into their group file, and rebuild the main
588         # sites file
589         f=open(groupfiledir+"/T"+group,'w')
590         f.write("# Section submitted by user %s, %s\n"%
591                 (user,time.asctime(time.localtime(time.time()))))
592         f.write("# Checked by make-secnet-sites version %s\n\n"%VERSION)
593         for i in userinput: f.write(i)
594         f.write("\n")
595         f.close()
596         os.rename(groupfiledir+"/T"+group,groupfiledir+"/R"+group)
597         f=open(sitesfile+"-tmp",'w')
598         f.write("# sites file autogenerated by make-secnet-sites\n")
599         f.write("# generated %s, invoked by %s\n"%
600                 (time.asctime(time.localtime(time.time())),user))
601         f.write("# use make-secnet-sites to turn this file into a\n")
602         f.write("# valid /etc/secnet/sites.conf file\n\n")
603         for i in headerinput: f.write(i)
604         files=os.listdir(groupfiledir)
605         for i in files:
606                 if i[0]=='R':
607                         j=open(groupfiledir+"/"+i)
608                         f.write(j.read())
609                         j.close()
610         f.write("# end of sites file\n")
611         f.close()
612         os.rename(sitesfile+"-tmp",sitesfile)
613 else:
614         outputsites(of)