chiark / gitweb /
format.py: Document `#' as a format parameter.
[chopwood] / backend.py
index 54c5374d2b99ef1f84925b8b6cb04b379950389d..fda32e0b8e37d259090e71f4d6ac864bd407e069 100644 (file)
@@ -25,6 +25,9 @@
 
 from __future__ import with_statement
 
 
 from __future__ import with_statement
 
+from auto import HOME
+import errno as E
+import itertools as I
 import os as OS; ENV = OS.environ
 
 import config as CONF; CFG = CONF.CFG
 import os as OS; ENV = OS.environ
 
 import config as CONF; CFG = CONF.CFG
@@ -36,7 +39,63 @@ import util as U
 CONF.DEFAULTS.update(
 
   ## A directory in which we can create lockfiles.
 CONF.DEFAULTS.update(
 
   ## A directory in which we can create lockfiles.
-  LOCKDIR = OS.path.join(ENV['HOME'], 'var', 'lock', 'chpwd'))
+  LOCKDIR = OS.path.join(HOME, 'lock'))
+
+###--------------------------------------------------------------------------
+### Utilities.
+
+def fill_in_fields(fno_user, fno_passwd, fno_map, user, passwd, args):
+  """
+  Return a vector of filled-in fields.
+
+  The FNO_... arguments give field numbers: FNO_USER and FNO_PASSWD give the
+  positions for the username and password fields, respectively; and FNO_MAP
+  is a sequence of (NAME, POS) pairs.  The USER and PASSWD arguments give the
+  actual user name and password values; ARGS are the remaining arguments,
+  maybe in the form `NAME=VALUE'.
+  """
+
+  ## Prepare the result vector, and set up some data structures.
+  n = 2 + len(fno_map)
+  fmap = {}
+  rmap = map(int, xrange(n))
+  ok = True
+  if fno_user >= n or fno_passwd >= n: ok = False
+  for k, i in fno_map:
+    fmap[k] = i
+    rmap[i] = "`%s'" % k
+    if i >= n: ok = False
+  if not ok:
+    raise U.ExpectedError, \
+        (500, "Fields specified aren't contiguous")
+
+  ## Prepare the new record's fields.
+  f = [None]*n
+  f[fno_user] = user
+  f[fno_passwd] = passwd
+
+  for a in args:
+    if '=' in a:
+      k, v = a.split('=', 1)
+      try: i = fmap[k]
+      except KeyError: raise U.ExpectedError, (400, "Unknown field `%s'" % k)
+    else:
+      for i in xrange(n):
+        if f[i] is None: break
+      else:
+        raise U.ExpectedError, (500, "All fields already populated")
+      v = a
+    if f[i] is not None:
+      raise U.ExpectedError, (400, "Field %s is already set" % rmap[i])
+    f[i] = v
+
+  ## Check that the vector of fields is properly set up.
+  for i in xrange(n):
+    if f[i] is None:
+      raise U.ExpectedError, (500, "Field %s is unset" % rmap[i])
+
+  ## Done.
+  return f
 
 ###--------------------------------------------------------------------------
 ### Protocol.
 
 ###--------------------------------------------------------------------------
 ### Protocol.
@@ -76,6 +135,8 @@ class BasicRecord (object):
     me._be = backend
   def write(me):
     me._be._update(me)
     me._be = backend
   def write(me):
     me._be._update(me)
+  def remove(me):
+    me._be._remove(me)
 
 class TrivialRecord (BasicRecord):
   """
 
 class TrivialRecord (BasicRecord):
   """
@@ -147,10 +208,13 @@ class FlatFileBackend (object):
   specified by the DELIM constructor argument.
 
   The file is updated by writing a new version alongside, as `FILE.new', and
   specified by the DELIM constructor argument.
 
   The file is updated by writing a new version alongside, as `FILE.new', and
-  renaming it over the old version.  If a LOCK file is named then an
-  exclusive fcntl(2)-style lock is taken out on `LOCKDIR/LOCK' (creating the
-  file if necessary) during the update operation.  Use of a lockfile is
-  strongly recommended.
+  renaming it over the old version.  If a LOCK is provided then this is done
+  while holding a lock.  By default, an exclusive fcntl(2)-style lock is
+  taken out on `LOCKDIR/LOCK' (creating the file if necessary) during the
+  update operation, but subclasses can override the `dolocked' method to
+  provide alternative locking behaviour; the LOCK parameter is not
+  interpreted by any other methods.  Use of a lockfile is strongly
+  recommended.
 
   The DELIM constructor argument specifies the delimiter character used when
   splitting lines into fields.  The USER and PASSWD arguments give the field
 
   The DELIM constructor argument specifies the delimiter character used when
   splitting lines into fields.  The USER and PASSWD arguments give the field
@@ -181,6 +245,24 @@ class FlatFileBackend (object):
           return rec
     raise UnknownUser, user
 
           return rec
     raise UnknownUser, user
 
+  def create(me, user, passwd, args):
+    """
+    Create a new record for the USER.
+
+    The new record has the given PASSWD, and other fields are set from ARGS.
+    Those ARGS of the form `KEY=VALUE' set the appropriately named fields (as
+    set up by the constructor); other ARGS fill in unset fields, left to
+    right.
+    """
+
+    f = fill_in_fields(me._fmap['user'], me._fmap['passwd'],
+                       [(k[2:], i)
+                        for k, i in me._fmap.iteritems()
+                        if k.startswith('f_')],
+                       user, passwd, args)
+    r = FlatFileRecord(me._delim.join(f), me._delim, me._fmap, backend = me)
+    me._rewrite('create', r)
+
   def _rewrite(me, op, rec):
     """
     Rewrite the file, according to OP.
   def _rewrite(me, op, rec):
     """
     Rewrite the file, according to OP.
@@ -252,11 +334,21 @@ class FlatFileBackend (object):
 
     ## If there's a lockfile, then acquire it around the meat of this
     ## function; otherwise just do the job.
 
     ## If there's a lockfile, then acquire it around the meat of this
     ## function; otherwise just do the job.
-    if me._lock is None:
-      doit()
-    else:
-      with U.lockfile(OS.path.join(CFG.LOCKDIR, me._lock), 5):
-        doit()
+    if me._lock is None: doit()
+    else: me.dolocked(me._lock, doit)
+
+  def dolocked(me, lock, func):
+    """
+    Call FUNC with the LOCK held.
+
+    Subclasses can override this method in order to provide alternative
+    locking functionality.
+    """
+    try: OS.mkdir(CFG.LOCKDIR)
+    except OSError, e:
+      if e.errno != E.EEXIST: raise
+    with U.lockfile(OS.path.join(CFG.LOCKDIR, lock), 5):
+      func()
 
   def _parse(me, line):
     """Convenience function for constructing a record."""
 
   def _parse(me, line):
     """Convenience function for constructing a record."""
@@ -266,6 +358,10 @@ class FlatFileBackend (object):
     """Update the record REC in the file."""
     me._rewrite('update', rec)
 
     """Update the record REC in the file."""
     me._rewrite('update', rec)
 
+  def _remove(me, rec):
+    """Update the record REC in the file."""
+    me._rewrite('remove', rec)
+
 CONF.export('FlatFileBackend')
 
 ###--------------------------------------------------------------------------
 CONF.export('FlatFileBackend')
 
 ###--------------------------------------------------------------------------
@@ -322,6 +418,36 @@ class DatabaseBackend (object):
       setattr(rec, 'f_' + f, v)
     return rec
 
       setattr(rec, 'f_' + f, v)
     return rec
 
+  def create(me, user, passwd, args):
+    """
+    Create a new record for the named USER.
+
+    The new record has the given PASSWD, and other fields are set from ARGS.
+    Those ARGS of the form `KEY=VALUE' set the appropriately named fields (as
+    set up by the constructor); other ARGS fill in unset fields, left to
+    right, in the order given to the constructor.
+    """
+
+    tags = ['user', 'passwd'] + \
+        ['t_%d' % 0 for i in xrange(len(me._fields))]
+    f = fill_in_fields(0, 1, list(I.izip(me._fields, I.count(2))),
+                       user, passwd, args)
+    me._connect()
+    with me._db:
+      me._db.execute("INSERT INTO %s (%s) VALUES (%s)" %
+                     (me._table,
+                      ', '.join([me._user, me._passwd] + me._fields),
+                      ', '.join(['$%s' % t for t in tags])),
+                     **dict(I.izip(tags, f)))
+
+  def _remove(me, rec):
+    """Remove the record REC from the database."""
+    me._connect()
+    with me._db:
+      me._db.execute("DELETE FROM %s WHERE %s = $user" %
+                     (me._table, me._user),
+                     user = rec.user)
+
   def _update(me, rec):
     """Update the record REC in the database."""
     me._connect()
   def _update(me, rec):
     """Update the record REC in the database."""
     me._connect()