chiark / gitweb /
Refactor --author/--committer options
[stgit] / stgit / commands / common.py
index 9e1f7cbc74d05ee23dd3c658d4559635a361b2bf..d6df813542b4486f06653cbdf3ca0906218420fe 100644 (file)
@@ -21,20 +21,20 @@ Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
 import sys, os, os.path, re
 from optparse import OptionParser, make_option
 
 import sys, os, os.path, re
 from optparse import OptionParser, make_option
 
+from stgit.exception import *
 from stgit.utils import *
 from stgit.out import *
 from stgit.utils import *
 from stgit.out import *
+from stgit.run import *
 from stgit import stack, git, basedir
 from stgit.config import config, file_extensions
 from stgit import stack, git, basedir
 from stgit.config import config, file_extensions
-
-crt_series = None
-
+from stgit.lib import stack as libstack
 
 # Command exception class
 
 # Command exception class
-class CmdException(Exception):
+class CmdException(StgException):
     pass
 
 # Utility functions
     pass
 
 # Utility functions
-class RevParseException(Exception):
+class RevParseException(StgException):
     """Revision spec parse error."""
     pass
 
     """Revision spec parse error."""
     pass
 
@@ -72,11 +72,19 @@ def parse_rev(rev):
     # No, we can't parse that.
     raise RevParseException
 
     # No, we can't parse that.
     raise RevParseException
 
-def git_id(rev):
+def git_id(crt_series, rev):
     """Return the GIT id
     """
     if not rev:
         return None
     """Return the GIT id
     """
     if not rev:
         return None
+
+    # try a GIT revision first
+    try:
+        return git.rev_parse(rev + '^{commit}')
+    except git.GitException:
+        pass
+
+    # try an StGIT patch name
     try:
         patch, branch, patch_id = parse_rev(rev)
         if branch == None:
     try:
         patch, branch, patch_id = parse_rev(rev)
         if branch == None:
@@ -103,27 +111,30 @@ def git_id(rev):
             return series.get_base()
     except RevParseException:
         pass
             return series.get_base()
     except RevParseException:
         pass
-    return git.rev_parse(rev + '^{commit}')
+    except stack.StackException:
+        pass
+
+    raise CmdException, 'Unknown patch or revision: %s' % rev
 
 def check_local_changes():
     if git.local_changes():
         raise CmdException, \
               'local changes in the tree. Use "refresh" or "status --reset"'
 
 
 def check_local_changes():
     if git.local_changes():
         raise CmdException, \
               'local changes in the tree. Use "refresh" or "status --reset"'
 
-def check_head_top_equal():
+def check_head_top_equal(crt_series):
     if not crt_series.head_top_equal():
         raise CmdException(
     if not crt_series.head_top_equal():
         raise CmdException(
-            'HEAD and top are not the same. You probably committed\n'
-            '  changes to the tree outside of StGIT. To bring them\n'
-            '  into StGIT, use the "assimilate" command')
+"""HEAD and top are not the same. This can happen if you
+   modify a branch with git. "stg repair --help" explains
+   more about what to do next.""")
 
 def check_conflicts():
 
 def check_conflicts():
-    if os.path.exists(os.path.join(basedir.get(), 'conflicts')):
+    if git.get_conflicts():
         raise CmdException, \
               'Unsolved conflicts. Please resolve them first or\n' \
               '  revert the changes with "status --reset"'
 
         raise CmdException, \
               'Unsolved conflicts. Please resolve them first or\n' \
               '  revert the changes with "status --reset"'
 
-def print_crt_patch(branch = None):
+def print_crt_patch(crt_series, branch = None):
     if not branch:
         patch = crt_series.get_current()
     else:
     if not branch:
         patch = crt_series.get_current()
     else:
@@ -134,29 +145,11 @@ def print_crt_patch(branch = None):
     else:
         out.info('No patches applied')
 
     else:
         out.info('No patches applied')
 
-def resolved(filename, reset = None):
-    if reset:
-        reset_file = filename + file_extensions()[reset]
-        if os.path.isfile(reset_file):
-            if os.path.isfile(filename):
-                os.remove(filename)
-            os.rename(reset_file, filename)
-
-    git.update_cache([filename], force = True)
-
-    for ext in file_extensions().values():
-        fn = filename + ext
-        if os.path.isfile(fn):
-            os.remove(fn)
-
 def resolved_all(reset = None):
     conflicts = git.get_conflicts()
 def resolved_all(reset = None):
     conflicts = git.get_conflicts()
-    if conflicts:
-        for filename in conflicts:
-            resolved(filename, reset)
-        os.remove(os.path.join(basedir.get(), 'conflicts'))
+    git.resolved(conflicts, reset)
 
 
-def push_patches(patches, check_merged = False):
+def push_patches(crt_series, patches, check_merged = False):
     """Push multiple patches onto the stack. This function is shared
     between the push and pull commands
     """
     """Push multiple patches onto the stack. This function is shared
     between the push and pull commands
     """
@@ -195,7 +188,7 @@ def push_patches(patches, check_merged = False):
             else:
                 out.done()
 
             else:
                 out.done()
 
-def pop_patches(patches, keep = False):
+def pop_patches(crt_series, patches, keep = False):
     """Pop the patches in the list from the stack. It is assumed that
     the patches are listed in the stack reverse order.
     """
     """Pop the patches in the list from the stack. It is assumed that
     the patches are listed in the stack reverse order.
     """
@@ -273,29 +266,19 @@ def parse_patches(patch_args, patch_list, boundary = 0, ordered = False):
     return patches
 
 def name_email(address):
     return patches
 
 def name_email(address):
-    """Return a tuple consisting of the name and email parsed from a
-    standard 'name <email>' or 'email (name)' string
-    """
-    address = re.sub('[\\\\"]', '\\\\\g<0>', address)
-    str_list = re.findall('^(.*)\s*<(.*)>\s*$', address)
-    if not str_list:
-        str_list = re.findall('^(.*)\s*\((.*)\)\s*$', address)
-        if not str_list:
-            raise CmdException, 'Incorrect "name <email>"/"email (name)" string: %s' % address
-        return ( str_list[0][1], str_list[0][0] )
-
-    return str_list[0]
+    p = parse_name_email(address)
+    if p:
+        return p
+    else:
+        raise CmdException('Incorrect "name <email>"/"email (name)" string: %s'
+                           % address)
 
 def name_email_date(address):
 
 def name_email_date(address):
-    """Return a tuple consisting of the name, email and date parsed
-    from a 'name <email> date' string
-    """
-    address = re.sub('[\\\\"]', '\\\\\g<0>', address)
-    str_list = re.findall('^(.*)\s*<(.*)>\s*(.*)\s*$', address)
-    if not str_list:
-        raise CmdException, 'Incorrect "name <email> date" string: %s' % address
-
-    return str_list[0]
+    p = parse_name_email_date(address)
+    if p:
+        return p
+    else:
+        raise CmdException('Incorrect "name <email> date" string: %s' % address)
 
 def address_or_alias(addr_str):
     """Return the address if it contains an e-mail address or look up
 
 def address_or_alias(addr_str):
     """Return the address if it contains an e-mail address or look up
@@ -317,14 +300,7 @@ def address_or_alias(addr_str):
                  for addr in addr_str.split(',')]
     return ', '.join([addr for addr in addr_list if addr])
 
                  for addr in addr_str.split(',')]
     return ', '.join([addr for addr in addr_list if addr])
 
-def prepare_rebase(force=None):
-    if not force:
-        # Be sure we won't loose results of stg-(un)commit by error.
-        # Do not require an existing orig-base for compatibility with 0.12 and earlier.
-        origbase = crt_series._get_field('orig-base')
-        if origbase and crt_series.get_base() != origbase:
-            raise CmdException, 'Rebasing would possibly lose data'
-
+def prepare_rebase(crt_series):
     # pop all patches
     applied = crt_series.get_applied()
     if len(applied) > 0:
     # pop all patches
     applied = crt_series.get_applied()
     if len(applied) > 0:
@@ -333,9 +309,9 @@ def prepare_rebase(force=None):
         out.done()
     return applied
 
         out.done()
     return applied
 
-def rebase(target):
+def rebase(crt_series, target):
     try:
     try:
-        tree_id = git_id(target)
+        tree_id = git_id(crt_series, target)
     except:
         # it might be that we use a custom rebase command with its own
         # target type
     except:
         # it might be that we use a custom rebase command with its own
         # target type
@@ -350,12 +326,12 @@ def rebase(target):
     git.rebase(tree_id = tree_id)
     out.done()
 
     git.rebase(tree_id = tree_id)
     out.done()
 
-def post_rebase(applied, nopush, merged):
+def post_rebase(crt_series, applied, nopush, merged):
     # memorize that we rebased to here
     crt_series._set_field('orig-base', git.get_head())
     # push the patches back
     if not nopush:
     # memorize that we rebased to here
     crt_series._set_field('orig-base', git.get_head())
     # push the patches back
     if not nopush:
-        push_patches(applied, merged)
+        push_patches(crt_series, applied, merged)
 
 #
 # Patch description/e-mail/diff parsing
 
 #
 # Patch description/e-mail/diff parsing
@@ -476,13 +452,90 @@ def parse_mail(msg):
 
     return (descr, authname, authemail, authdate, diff)
 
 
     return (descr, authname, authemail, authdate, diff)
 
-def parse_patch(fobj):
-    """Parse the input file and return (description, authname,
+def parse_patch(text):
+    """Parse the input text and return (description, authname,
     authemail, authdate, diff)
     """
     authemail, authdate, diff)
     """
-    descr, diff = __split_descr_diff(fobj.read())
+    descr, diff = __split_descr_diff(text)
     descr, authname, authemail, authdate = __parse_description(descr)
 
     # we don't yet have an agreed place for the creation date.
     # Just return None
     return (descr, authname, authemail, authdate, diff)
     descr, authname, authemail, authdate = __parse_description(descr)
 
     # we don't yet have an agreed place for the creation date.
     # Just return None
     return (descr, authname, authemail, authdate, diff)
+
+def readonly_constant_property(f):
+    """Decorator that converts a function that computes a value to an
+    attribute that returns the value. The value is computed only once,
+    the first time it is accessed."""
+    def new_f(self):
+        # The computed value is cached on the instance itself, under
+        # the attribute name '__' + f.__name__. Because the name is
+        # passed to hasattr/setattr/getattr as a string, it is used
+        # literally (no private-name mangling takes place).
+        n = '__' + f.__name__
+        if not hasattr(self, n):
+            # If f(self) raises, nothing is cached and the next access
+            # calls f again.
+            setattr(self, n, f(self))
+        return getattr(self, n)
+    # Only a getter is passed to property(), hence "readonly".
+    return property(new_f)
+
+class DirectoryException(StgException):
+    """Raised by the Directory* classes in this file when no git
+    repository is found or the current directory is not inside a git
+    worktree."""
+    pass
+
+class _Directory(object):
+    """Common base of the Directory* classes. Provides facts about the
+    git repository around the current working directory, each obtained
+    from a 'git rev-parse' invocation and computed at most once per
+    instance (see readonly_constant_property)."""
+    def __init__(self, needs_current_series = True):
+        # Only stored here; nothing in this class reads it back.
+        self.needs_current_series =  needs_current_series
+    @readonly_constant_property
+    def git_dir(self):
+        """The single line printed by 'git rev-parse --git-dir', with
+        the command's stderr discarded. A RunException from the command
+        is reported as DirectoryException('No git repository found')."""
+        try:
+            return Run('git', 'rev-parse', '--git-dir'
+                       ).discard_stderr().output_one_line()
+        except RunException:
+            raise DirectoryException('No git repository found')
+    @readonly_constant_property
+    def __topdir_path(self):
+        """Path to change into to reach the top of the worktree, taken
+        from 'git rev-parse --show-cdup'. Any RunException, including
+        the one raised below for unexpected output, is reported as
+        DirectoryException('No git repository found')."""
+        try:
+            lines = Run('git', 'rev-parse', '--show-cdup'
+                        ).discard_stderr().output_lines()
+            if len(lines) == 0:
+                # No output at all: treated as "already at the top".
+                return '.'
+            elif len(lines) == 1:
+                return lines[0]
+            else:
+                # More than one line is unexpected; raising RunException
+                # here routes it through the handler just below.
+                raise RunException('Too much output')
+        except RunException:
+            raise DirectoryException('No git repository found')
+    @readonly_constant_property
+    def is_inside_git_dir(self):
+        """Boolean form of 'git rev-parse --is-inside-git-dir'. Unlike
+        git_dir, stderr is not discarded and RunException is not caught
+        here; output other than 'true'/'false' raises KeyError."""
+        return { 'true': True, 'false': False
+                 }[Run('git', 'rev-parse', '--is-inside-git-dir'
+                       ).output_one_line()]
+    @readonly_constant_property
+    def is_inside_worktree(self):
+        """Boolean form of 'git rev-parse --is-inside-work-tree'. Unlike
+        git_dir, stderr is not discarded and RunException is not caught
+        here; output other than 'true'/'false' raises KeyError."""
+        return { 'true': True, 'false': False
+                 }[Run('git', 'rev-parse', '--is-inside-work-tree'
+                       ).output_one_line()]
+    def cd_to_topdir(self):
+        """Change the process's working directory to the path given by
+        the private __topdir_path property."""
+        os.chdir(self.__topdir_path)
+
+class DirectoryAnywhere(_Directory):
+    """Directory policy with no requirements: setup() does nothing."""
+    def setup(self):
+        pass
+
+class DirectoryHasRepository(_Directory):
+    """Directory policy that requires a git repository: setup() raises
+    DirectoryException if git_dir cannot be determined."""
+    def setup(self):
+        # Accessing the property is the whole check; its value is
+        # deliberately not used here.
+        self.git_dir # might throw an exception
+
+class DirectoryInWorktree(DirectoryHasRepository):
+    """Directory policy that requires a git repository and, on top of
+    that, a current directory inside a git worktree."""
+    def setup(self):
+        # Run the repository check first, then the worktree check.
+        DirectoryHasRepository.setup(self)
+        if not self.is_inside_worktree:
+            raise DirectoryException('Not inside a git worktree')
+
+class DirectoryGotoToplevel(DirectoryInWorktree):
+    """Directory policy that performs the DirectoryInWorktree checks
+    and then changes the working directory to the worktree's top."""
+    def setup(self):
+        DirectoryInWorktree.setup(self)
+        # Only reached when the checks above raised nothing.
+        self.cd_to_topdir()
+
+class DirectoryHasRepositoryLib(_Directory):
+    """For commands that use the new infrastructure in stgit.lib.*."""
+    def __init__(self):
+        # Does not call _Directory.__init__; the flag is set directly
+        # and is always False for this class.
+        self.needs_current_series = False
+    def setup(self):
+        # This will throw an exception if we don't have a repository.
+        # The result is kept on the instance as self.repository.
+        self.repository = libstack.Repository.default()