chiark / gitweb /
Implement "stg refresh --edit" again
[stgit] / stgit / commands / refresh.py
index 384cfb958e84d9f8c24de9c6257448db130001c7..3c82906ada48be49b9776c72b603dd3e7ea3ca02 100644 (file)
@@ -20,9 +20,9 @@ Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
 
 from stgit.argparse import opt
 from stgit.commands import common
-from stgit.lib import git, transaction
+from stgit.lib import git, transaction, edit
 from stgit.out import out
-from stgit import utils
+from stgit import argparse, utils
 
 help = 'Generate a new commit for the current patch'
 kind = 'patch'
@@ -56,7 +56,11 @@ options = [
         Instead of setting the patch top to the current contents of
         the worktree, set it to the current contents of the index."""),
     opt('-p', '--patch',
-        short = 'Refresh (applied) PATCH instead of the top patch')]
+        short = 'Refresh (applied) PATCH instead of the top patch'),
+    opt('-e', '--edit', action = 'store_true',
+        short = 'Invoke an editor for the patch description'),
+    ] + (argparse.message_options(save_template = False) +
+         argparse.sign_options() + argparse.author_options())
 
 directory = common.DirectoryHasRepositoryLib()
 
@@ -121,9 +125,12 @@ def make_temp_patch(stack, patch_name, paths, temp_index):
     return trans.run(stack.repository.default_iw,
                      print_current_patch = False), temp_name
 
-def absorb_applied(trans, iw, patch_name, temp_name):
+def absorb_applied(trans, iw, patch_name, temp_name, edit_fun):
     """Absorb the temp patch (C{temp_name}) into the given patch
-    (C{patch_name}), which must be applied.
+    (C{patch_name}), which must be applied. If the absorption
+    succeeds, call C{edit_fun} on the resulting
+    L{CommitData<stgit.lib.git.CommitData>} before committing it and
+    commit the return value.
 
     @return: C{True} if we managed to absorb the temp patch, C{False}
              if we had to leave it for the user to deal with."""
@@ -141,7 +148,7 @@ def absorb_applied(trans, iw, patch_name, temp_name):
         temp_cd = trans.patches[temp_name].data
         assert trans.patches[patch_name] == temp_cd.parent
         trans.patches[patch_name] = trans.stack.repository.commit(
-            trans.patches[patch_name].data.set_tree(temp_cd.tree))
+            edit_fun(trans.patches[patch_name].data.set_tree(temp_cd.tree)))
         popped = trans.delete_patches(lambda pn: pn == temp_name, quiet = True)
         assert not popped # the temp patch was topmost
         temp_absorbed = True
@@ -153,9 +160,12 @@ def absorb_applied(trans, iw, patch_name, temp_name):
         pass
     return temp_absorbed
 
-def absorb_unapplied(trans, iw, patch_name, temp_name):
+def absorb_unapplied(trans, iw, patch_name, temp_name, edit_fun):
     """Absorb the temp patch (C{temp_name}) into the given patch
-    (C{patch_name}), which must be unapplied.
+    (C{patch_name}), which must be unapplied. If the absorption
+    succeeds, call C{edit_fun} on the resulting
+    L{CommitData<stgit.lib.git.CommitData>} before committing it and
+    commit the return value.
 
     @param iw: Not used.
     @return: C{True} if we managed to absorb the temp patch, C{False}
@@ -179,7 +189,7 @@ def absorb_unapplied(trans, iw, patch_name, temp_name):
         # It worked. Refresh the patch with the new tree, and delete
         # the temp patch.
         trans.patches[patch_name] = trans.stack.repository.commit(
-            patch_cd.set_tree(new_tree))
+            edit_fun(patch_cd.set_tree(new_tree)))
         popped = trans.delete_patches(lambda pn: pn == temp_name, quiet = True)
         assert not popped # the temp patch was not applied
         return True
@@ -188,13 +198,13 @@ def absorb_unapplied(trans, iw, patch_name, temp_name):
         # leave the temp patch for the user.
         return False
 
-def absorb(stack, patch_name, temp_name):
+def absorb(stack, patch_name, temp_name, edit_fun):
     """Absorb the temp patch into the target patch."""
     trans = transaction.StackTransaction(stack, 'refresh')
     iw = stack.repository.default_iw
     f = { True: absorb_applied, False: absorb_unapplied
           }[patch_name in trans.applied]
-    if f(trans, iw, patch_name, temp_name):
+    if f(trans, iw, patch_name, temp_name, edit_fun):
         def info_msg(): pass
     else:
         def info_msg():
@@ -228,4 +238,16 @@ def func(parser, options, args):
         stack, patch_name, paths, temp_index = path_limiting)
     if retval != utils.STGIT_SUCCESS:
         return retval
-    return absorb(stack, patch_name, temp_name)
+    def edit_fun(cd):
+        cd, failed_diff = edit.auto_edit_patch(
+            stack.repository, cd, msg = options.message, contains_diff = False,
+            author = options.author, committer = lambda p: p,
+            sign_str = options.sign_str)
+        assert not failed_diff
+        if options.edit:
+            cd, failed_diff = edit.interactive_edit_patch(
+                stack.repository, cd, edit_diff = False,
+                diff_flags = [], replacement_diff = None)
+            assert not failed_diff
+        return cd
+    return absorb(stack, patch_name, temp_name, edit_fun)