chiark / gitweb /
Add support for branch description files
[stgit] / stgit / stack.py
index 55c49a85c3312ab842faf89de4531735ba22740f..33c6d83d52f511e3fdbfecd7e81507ef21643b8e 100644 (file)
@@ -249,8 +249,7 @@ class Series:
     """Class including the operations on series
     """
     def __init__(self, name = None):
-        """Takes a series name as the parameter. A valid .git/patches/name
-        directory should exist
+        """Takes a series name as the parameter.
         """
         if name:
             self.__name = name
@@ -265,6 +264,7 @@ class Series:
             self.__applied_file = os.path.join(self.__patch_dir, 'applied')
             self.__unapplied_file = os.path.join(self.__patch_dir, 'unapplied')
             self.__current_file = os.path.join(self.__patch_dir, 'current')
+            self.__descr_file = os.path.join(self.__patch_dir, 'description')
 
     def get_branch(self):
         """Return the branch name for the Series object
@@ -297,12 +297,16 @@ class Series:
             return name
 
     def get_applied(self):
+        if not os.path.isfile(self.__applied_file):
+            raise StackException, 'Branch "%s" not initialised' % self.__name
         f = file(self.__applied_file)
         names = [line.strip() for line in f.readlines()]
         f.close()
         return names
 
     def get_unapplied(self):
+        if not os.path.isfile(self.__unapplied_file):
+            raise StackException, 'Branch "%s" not initialised' % self.__name
         f = file(self.__unapplied_file)
         names = [line.strip() for line in f.readlines()]
         f.close()
@@ -311,6 +315,38 @@ class Series:
     def get_base_file(self):
         return self.__base_file
 
+    def __get_protect_file(self):
+        """Return the path of the marker file whose presence means
+        that the series is protected
+        """
+        return os.path.join(self.__patch_dir, 'protected')
+
+    def get_protected(self):
+        """Return True if the series is marked as protected
+        """
+        return os.path.isfile(self.__get_protect_file())
+
+    def protect(self):
+        """Mark the series as protected; does nothing if it already is
+        """
+        if not self.get_protected():
+            create_empty_file(self.__get_protect_file())
+
+    def unprotect(self):
+        """Remove the protection marker; does nothing if there is none
+        """
+        if self.get_protected():
+            os.remove(self.__get_protect_file())
+
+    def get_description(self):
+        """Return the branch description, or '' if the series has no
+        description file
+        """
+        if os.path.isfile(self.__descr_file):
+            return read_string(self.__descr_file)
+        else:
+            return ''
+
     def __patch_is_current(self, patch):
         return patch.get_name() == read_string(self.__current_file)
 
@@ -363,14 +386,42 @@ class Series:
 
         create_empty_file(self.__applied_file)
         create_empty_file(self.__unapplied_file)
+        create_empty_file(self.__descr_file)
         self.__begin_stack_check()
 
+    def delete(self, force = False):
+        """Deletes an stgit series: its patches, its bookkeeping files,
+        the patch directory (if left empty) and the base file. Raises
+        StackException if patches remain and 'force' is not set.
+        """
+        if os.path.isdir(self.__patch_dir):
+            patches = self.get_unapplied() + self.get_applied()
+            if not force and patches:
+                raise StackException, \
+                      'Cannot delete: the series still contains patches'
+            # delete the most recently applied patch first
+            patches.reverse()
+            for p in patches:
+                self.delete_patch(p)
+
+            for f in [self.__applied_file, self.__unapplied_file,
+                      self.__current_file, self.__descr_file]:
+                if os.path.isfile(f):
+                    os.remove(f)
+            if not os.listdir(self.__patch_dir):
+                os.rmdir(self.__patch_dir)
+            else:
+                print >> sys.stderr, \
+                      'Series directory %s is not empty.' % self.__patch_dir
+
+        if os.path.isfile(self.__base_file):
+            os.remove(self.__base_file)
+
     def refresh_patch(self, message = None, edit = False, show_patch = False,
                       cache_update = True,
                       author_name = None, author_email = None,
                       author_date = None,
-                      committer_name = None, committer_email = None,
-                      commit_only = False):
+                      committer_name = None, committer_email = None):
         """Generates a new commit for the given patch
         """
         name = self.get_current()
@@ -411,14 +462,13 @@ class Series:
                                committer_name = committer_name,
                                committer_email = committer_email)
 
-        if not commit_only:
-            patch.set_top(commit_id)
-            patch.set_description(descr)
-            patch.set_authname(author_name)
-            patch.set_authemail(author_email)
-            patch.set_authdate(author_date)
-            patch.set_commname(committer_name)
-            patch.set_commemail(committer_email)
+        patch.set_top(commit_id)
+        patch.set_description(descr)
+        patch.set_authname(author_name)
+        patch.set_authemail(author_email)
+        patch.set_authdate(author_date)
+        patch.set_commname(committer_name)
+        patch.set_commemail(committer_email)
 
         return commit_id
 
@@ -517,13 +567,40 @@ class Series:
             # top != bottom always since we have a commit for each patch
             if head == bottom:
                 # reset the backup information
-                patch.set_bottom(bottom, backup = True)
+                patch.set_bottom(head, backup = True)
                 patch.set_top(top, backup = True)
 
             else:
-                top = head
-                # stop the fast-forwarding, must do a real merge
-                break
+                head_tree = git.get_commit(head).get_tree()
+                bottom_tree = git.get_commit(bottom).get_tree()
+                if head_tree == bottom_tree:
+                    # We must just reparent this patch and create a new commit
+                    # for it
+                    descr = patch.get_description()
+                    author_name = patch.get_authname()
+                    author_email = patch.get_authemail()
+                    author_date = patch.get_authdate()
+                    committer_name = patch.get_commname()
+                    committer_email = patch.get_commemail()
+
+                    top_tree = git.get_commit(top).get_tree()
+
+                    top = git.commit(message = descr, parents = [head],
+                                     cache_update = False,
+                                     tree_id = top_tree,
+                                     allowempty = True,
+                                     author_name = author_name,
+                                     author_email = author_email,
+                                     author_date = author_date,
+                                     committer_name = committer_name,
+                                     committer_email = committer_email)
+
+                    patch.set_bottom(head, backup = True)
+                    patch.set_top(top, backup = True)
+                else:
+                    top = head
+                    # stop the fast-forwarding, must do a real merge
+                    break
 
             forwarded+=1
             unapplied.remove(name)
@@ -568,14 +645,17 @@ class Series:
             # The current patch is empty after merge.
             patch.set_bottom(head, backup = True)
             patch.set_top(head, backup = True)
-            # merge/refresh can fail but the patch needs to be pushed
-            try:
-                git.merge(bottom, head, top)
-            except git.GitException, ex:
-                print >> sys.stderr, \
-                      'The merge failed during "push". ' \
-                      'Use "refresh" after fixing the conflicts'
-                pass
+
+            # Try the fast applying first. If this fails, fall back to the
+            # three-way merge
+            if not git.apply_diff(bottom, top):
+                # merge can fail but the patch needs to be pushed
+                try:
+                    git.merge(bottom, head, top)
+                except git.GitException, ex:
+                    print >> sys.stderr, \
+                          'The merge failed during "push". ' \
+                          'Use "refresh" after fixing the conflicts'
 
         append_string(self.__applied_file, name)