chiark / gitweb /
make-secnet-sites: Refactor operational code into OpModes
authorIan Jackson <ijackson@chiark.greenend.org.uk>
Wed, 4 Dec 2019 16:19:23 +0000 (16:19 +0000)
committerIan Jackson <ijackson@chiark.greenend.org.uk>
Sat, 15 Feb 2020 21:56:53 +0000 (21:56 +0000)
Everywhere we had conditionals on `service', move the two arms of the
if into methods on OpConf and OpUserv (of which we make one
singleton).

Many global variables become instance variables on the OpMode object.

The read_in method of OpConf is in OpBase because we are going to want
to reuse it.

Signed-off-by: Ian Jackson <ijackson@chiark.greenend.org.uk>
make-secnet-sites

index 46a887c8b73c049da7548f43e8d24f2a12ab21d9..c484e4c69f158244cb70138bab5fffa362ed77b2 100755 (executable)
@@ -291,27 +291,100 @@ class PkmElide(PkmBase):
                confw.write("peer-keys \"%s\";\n"%self._pa);
 
 class OpBase():
-       pass
+       # Base case is reading a sites file from self.inputfilee.
+       def read_in(self):
+               if self.inputfile is None:
+                       pfile("stdin",sys.stdin.readlines())
+               else:
+                       pfilepath(self.inputfile)
 
 class OpConf(OpBase):
        def is_service(self): return 0
+       def positional_args(self, av):
+               if len(av.arg)>3:
+                       print("Too many arguments")
+                       sys.exit(1)
+               (self.inputfile, self.outputfile) = (av.arg + [None]*2)[0:2]
+       def check_group(self,group,w): pass
+       def write_out(self):
+               if self.outputfile is None:
+                       of=sys.stdout
+               else:
+                       tmp_outputfile=self.outputfile+'~tmp~'
+                       of=open(tmp_outputfile,'w')
+               outputsites(of)
+               if self.outputfile is not None:
+                       os.rename(tmp_outputfile,self.outputfile)
 
 class OpUserv(OpBase):
        opts = ['--userv','-u']
        help = 'userv service fragment update mode'
        def is_service(self): return 1
+       def positional_args(self, av):
+               if len(av.arg)!=4:
+                       print("Wrong number of arguments")
+                       sys.exit(1)
+               (self.header, self.groupfiledir,
+                self.sitesfile, self.group) = av.arg
+               self.group = Tainted(self.group,0,'command line')
+               # untrusted argument from caller
+               if "USERV_USER" not in os.environ:
+                       print("Environment variable USERV_USER not found")
+                       sys.exit(1)
+               self.user=os.environ["USERV_USER"]
+               # Check that group is in USERV_GROUP
+               if "USERV_GROUP" not in os.environ:
+                       print("Environment variable USERV_GROUP not found")
+                       sys.exit(1)
+               ugs=os.environ["USERV_GROUP"]
+               ok=0
+               for i in ugs.split():
+                       if self.group==i: ok=1
+               if not ok:
+                       print("caller not in group %s"%group)
+                       sys.exit(1)
+       def check_group(self,group,w):
+               if group!=self.group: complain("Incorrect group!")
+               w[2].groupname()
+       def read_in(self):
+               self.headerinput=pfilepath(self.header,allow_include=True)
+               self.userinput=sys.stdin.readlines()
+               pfile("user input",self.userinput)
+       def write_out(self):
+               # Put the user's input into their group file, and
+               # rebuild the main sites file
+               f=open(self.groupfiledir+"/T"+self.group.groupname(),'w')
+               f.write("# Section submitted by user %s, %s\n"%
+                       (self.user,time.asctime(time.localtime(time.time()))))
+               f.write("# Checked by make-secnet-sites version %s\n\n"
+                       %VERSION)
+               for i in self.userinput: f.write(i)
+               f.write("\n")
+               f.close()
+               os.rename(self.groupfiledir+"/T"+self.group.groupname(),
+                         self.groupfiledir+"/R"+self.group.groupname())
+               f=open(self.sitesfile+"-tmp",'w')
+               f.write("# sites file autogenerated by make-secnet-sites\n")
+               f.write("# generated %s, invoked by %s\n"%
+                       (time.asctime(time.localtime(time.time())),
+                        self.user))
+               f.write("# use make-secnet-sites to turn this file into a\n")
+               f.write("# valid /etc/secnet/sites.conf file\n\n")
+               for i in self.headerinput: f.write(i)
+               files=os.listdir(self.groupfiledir)
+               for i in files:
+                       if i[0]=='R':
+                               j=open(self.groupfiledir+"/"+i)
+                               f.write(j.read())
+                               j.close()
+               f.write("# end of sites file\n")
+               f.close()
+               os.rename(self.sitesfile+"-tmp",self.sitesfile)
+               
 
 def parse_args():
        global opmode
        global service
-       global inputfile
-       global header
-       global groupfiledir
-       global sitesfile
-       global outputfile
-       global group
-       global user
-       global of
        global prefix
        global key_prefix
        global debug_level
@@ -357,33 +430,7 @@ def parse_args():
        output_version = av.output_version[0]
        pubkeys_dir = av.pubkeys_dir[0]
        pubkeys_mode = getattr(av,'pkm',PkmSingle)
-       if service:
-               if len(av.arg)!=4:
-                       print("Wrong number of arguments")
-                       sys.exit(1)
-               (header, groupfiledir, sitesfile, group) = av.arg
-               group = Tainted(group,0,'command line')
-               # untrusted argument from caller
-               if "USERV_USER" not in os.environ:
-                       print("Environment variable USERV_USER not found")
-                       sys.exit(1)
-               user=os.environ["USERV_USER"]
-               # Check that group is in USERV_GROUP
-               if "USERV_GROUP" not in os.environ:
-                       print("Environment variable USERV_GROUP not found")
-                       sys.exit(1)
-               ugs=os.environ["USERV_GROUP"]
-               ok=0
-               for i in ugs.split():
-                       if group==i: ok=1
-               if not ok:
-                       print("caller not in group %s"%group)
-                       sys.exit(1)
-       else:
-               if len(av.arg)>3:
-                       print("Too many arguments")
-                       sys.exit(1)
-               (inputfile, outputfile) = (av.arg + [None]*2)[0:2]
+       opmode.positional_args(av)
 
 parse_args()
 
@@ -859,10 +906,8 @@ def pline(il,filterstate,allow_include=False):
                if tname in current.children:
                        # Not new
                        current=current.children[tname]
-                       if service and group and current.depth==2:
-                               if group!=current.group:
-                                       complain("Incorrect group!")
-                               w[2].groupname()
+                       if current.depth==2:
+                               opmode.check_group(current.group, w)
                else:
                        # New
                        # Ignore depth check for now
@@ -985,15 +1030,7 @@ def checkconstraints(n,p,ra):
        for i in n.children.keys():
                checkconstraints(n.children[i],new_p,new_ra)
 
-if service:
-       headerinput=pfilepath(header,allow_include=True)
-       userinput=sys.stdin.readlines()
-       pfile("user input",userinput)
-else:
-       if inputfile is None:
-               pfile("stdin",sys.stdin.readlines())
-       else:
-               pfilepath(inputfile)
+opmode.read_in()
 
 delempty(root)
 checkconstraints(root,{},ipaddrset.complete_set())
@@ -1004,40 +1041,4 @@ if complaints>0:
        sys.exit(1)
 complaints=None # arranges to crash if we complain later
 
-if service:
-       # Put the user's input into their group file, and rebuild the main
-       # sites file
-       f=open(groupfiledir+"/T"+group.groupname(),'w')
-       f.write("# Section submitted by user %s, %s\n"%
-               (user,time.asctime(time.localtime(time.time()))))
-       f.write("# Checked by make-secnet-sites version %s\n\n"%VERSION)
-       for i in userinput: f.write(i)
-       f.write("\n")
-       f.close()
-       os.rename(groupfiledir+"/T"+group.groupname(),
-                 groupfiledir+"/R"+group.groupname())
-       f=open(sitesfile+"-tmp",'w')
-       f.write("# sites file autogenerated by make-secnet-sites\n")
-       f.write("# generated %s, invoked by %s\n"%
-               (time.asctime(time.localtime(time.time())),user))
-       f.write("# use make-secnet-sites to turn this file into a\n")
-       f.write("# valid /etc/secnet/sites.conf file\n\n")
-       for i in headerinput: f.write(i)
-       files=os.listdir(groupfiledir)
-       for i in files:
-               if i[0]=='R':
-                       j=open(groupfiledir+"/"+i)
-                       f.write(j.read())
-                       j.close()
-       f.write("# end of sites file\n")
-       f.close()
-       os.rename(sitesfile+"-tmp",sitesfile)
-else:
-       if outputfile is None:
-               of=sys.stdout
-       else:
-               tmp_outputfile=outputfile+'~tmp~'
-               of=open(tmp_outputfile,'w')
-       outputsites(of)
-       if outputfile is not None:
-               os.rename(tmp_outputfile,outputfile)
+opmode.write_out()