chiark / gitweb /
dd7aad24302d6cad5147aa20e916aedfb2c8be26
[stgit] / stgit / lib / transaction.py
1 """The L{StackTransaction} class makes it possible to make complex
2 updates to an StGit stack in a safe and convenient way."""
3
4 import atexit
5
6 from stgit import exception, utils
7 from stgit.utils import any, all
8 from stgit.out import *
9 from stgit.lib import git
10
class TransactionException(exception.StgException):
    """Base exception for failures while building or running a
    L{StackTransaction}."""
    pass
14
class TransactionHalted(TransactionException):
    """Raised when a L{StackTransaction} stops before its setup is done.

    It serves as a non-local jump: it leaves the code that builds up the
    transaction and lands in the code that goes on to run it."""
    pass
19
20 def _print_current_patch(old_applied, new_applied):
21     def now_at(pn):
22         out.info('Now at patch "%s"' % pn)
23     if not old_applied and not new_applied:
24         pass
25     elif not old_applied:
26         now_at(new_applied[-1])
27     elif not new_applied:
28         out.info('No patch applied')
29     elif old_applied[-1] == new_applied[-1]:
30         pass
31     else:
32         now_at(new_applied[-1])
33
34 class _TransPatchMap(dict):
35     """Maps patch names to sha1 strings."""
36     def __init__(self, stack):
37         dict.__init__(self)
38         self.__stack = stack
39     def __getitem__(self, pn):
40         try:
41             return dict.__getitem__(self, pn)
42         except KeyError:
43             return self.__stack.patches.get(pn).commit
44
class StackTransaction(object):
    """A stack transaction, used for making complex updates to an StGit
    stack in one single operation that will either succeed or fail
    cleanly.

    The basic theory of operation is the following:

      1. Create a transaction object.

      2. Inside a::

         try
           ...
         except TransactionHalted:
           pass

      block, update the transaction with e.g. methods like
      L{pop_patches} and L{push_patch}. This may create new git
      objects such as commits, but will not write any refs; this means
      that in case of a fatal error we can just walk away, no clean-up
      required.

      (Some operations may need to touch your index and working tree,
      though. But they are cleaned up when needed.)

      3. After the C{try} block -- whether or not the setup ran to
      completion or halted part-way through by raising a
      L{TransactionHalted} exception -- call the transaction's L{run}
      method. This will either succeed in writing the updated state to
      your refs and index+worktree, or fail without having done
      anything."""

    def __init__(self, stack, msg, allow_conflicts = False):
        """Start a transaction on C{stack}.

        @param stack: the stack to update. Its applied/unapplied patch
            order, head tree and base are copied here and only written
            back by L{run}.
        @param msg: message passed along whenever L{run} writes a ref.
        @param allow_conflicts: either a bool, or a callable taking the
            transaction and returning whether unresolved index conflicts
            may be carried along by a checkout."""
        self.__stack = stack
        self.__msg = msg
        self.__patches = _TransPatchMap(stack)
        self.__applied = list(self.__stack.patchorder.applied)
        self.__unapplied = list(self.__stack.patchorder.unapplied)
        self.__error = None
        self.__current_tree = self.__stack.head.data.tree
        self.__base = self.__stack.base
        # Normalize allow_conflicts to a callable so callers can decide
        # dynamically, based on the transaction's state.
        if isinstance(allow_conflicts, bool):
            self.__allow_conflicts = lambda trans: allow_conflicts
        else:
            self.__allow_conflicts = allow_conflicts
        self.__temp_index = self.temp_index_tree = None

    stack = property(lambda self: self.__stack)
    patches = property(lambda self: self.__patches)

    def __set_applied(self, val):
        self.__applied = list(val)
    applied = property(lambda self: self.__applied, __set_applied)

    def __set_unapplied(self, val):
        self.__unapplied = list(val)
    unapplied = property(lambda self: self.__unapplied, __set_unapplied)

    def __set_base(self, val):
        # The new base must be the parent of the bottom-most applied
        # patch (if any); anything else would leave the stack inconsistent.
        assert (not self.__applied
                or self.patches[self.applied[0]].data.parent == val)
        self.__base = val
    base = property(lambda self: self.__base, __set_base)

    @property
    def temp_index(self):
        """Temporary index used by L{push_patch} for its first merge
        attempt. Created lazily on first use; its C{delete} method is
        registered with C{atexit} so it is removed when the process
        exits."""
        if not self.__temp_index:
            self.__temp_index = self.__stack.repository.temp_index()
            atexit.register(self.__temp_index.delete)
        return self.__temp_index

    def __checkout(self, tree, iw):
        """Check out C{tree} into index+worktree C{iw}, unless it is
        already the current tree.

        Aborts (raises L{TransactionException}) if HEAD and the stack top
        differ, or if C{tree} is already current but C{iw} still has
        unresolved conflicts that this transaction does not allow. The
        checkout itself may raise C{git.CheckoutException}, which callers
        catch."""
        if not self.__stack.head_top_equal():
            out.error(
                'HEAD and top are not the same.',
                'This can happen if you modify a branch with git.',
                '"stg repair --help" explains more about what to do next.')
            self.__abort()
        if self.__current_tree == tree:
            # No tree change, but we still want to make sure that
            # there are no unresolved conflicts. Conflicts
            # conceptually "belong" to the topmost patch, and just
            # carrying them along to another patch is confusing.
            if (self.__allow_conflicts(self) or iw is None
                or not iw.index.conflicts()):
                return
            out.error('Need to resolve conflicts first')
            self.__abort()
        assert iw is not None
        iw.checkout(self.__current_tree, tree)
        self.__current_tree = tree

    @staticmethod
    def __abort():
        """Raise the generic "everything rolled back" exception."""
        raise TransactionException(
            'Command aborted (all changes rolled back)')

    def __check_consistency(self):
        """Assert that every patch touched by the transaction is either
        marked for deletion (None) and exists in the stack, or is still
        listed as applied or unapplied."""
        remaining = set(self.__applied + self.__unapplied)
        for pn, commit in self.__patches.items():
            if commit is None:
                assert self.__stack.patches.exists(pn)
            else:
                assert pn in remaining

    @property
    def __head(self):
        """Commit that will become the branch head: the top applied
        patch's commit, or the stack base if nothing is applied."""
        if self.__applied:
            return self.__patches[self.__applied[-1]]
        else:
            return self.__base

    def abort(self, iw = None):
        """Roll back index+worktree changes by checking out the stack's
        current head tree into C{iw}. Does nothing if C{iw} is not given."""
        # The only state we need to restore is index+worktree; refs are
        # never written before run().
        if iw:
            self.__checkout(self.__stack.head.data.tree, iw)

    def run(self, iw = None, set_head = True):
        """Execute the transaction. Will either succeed, or fail (with an
        exception) and do nothing.

        @param iw: index+worktree to update to the new head's tree, if any.
        @param set_head: whether to move the branch head to the new top.
        @return: C{utils.STGIT_CONFLICT} if the transaction was halted
            with an error, else C{utils.STGIT_SUCCESS}."""
        self.__check_consistency()
        new_head = self.__head

        # Set branch head.
        if set_head:
            if iw:
                try:
                    self.__checkout(new_head.data.tree, iw)
                except git.CheckoutException:
                    # We have to abort the transaction.
                    self.abort(iw)
                    self.__abort()
            self.__stack.set_head(new_head, self.__msg)

        if self.__error:
            out.error(self.__error)

        # Write patches: a None commit means "delete this patch".
        for pn, commit in self.__patches.items():
            if self.__stack.patches.exists(pn):
                p = self.__stack.patches.get(pn)
                if commit is None:
                    p.delete()
                else:
                    p.set_commit(commit, self.__msg)
            else:
                self.__stack.patches.new(pn, commit, self.__msg)
        _print_current_patch(self.__stack.patchorder.applied, self.__applied)
        self.__stack.patchorder.applied = self.__applied
        self.__stack.patchorder.unapplied = self.__unapplied

        if self.__error:
            return utils.STGIT_CONFLICT
        else:
            return utils.STGIT_SUCCESS

    def __halt(self, msg):
        """Record C{msg} as the transaction's error and raise
        L{TransactionHalted} to leave the setup code early."""
        self.__error = msg
        raise TransactionHalted(msg)

    @staticmethod
    def __print_popped(popped):
        """Report the popped patches (bottom-most first in C{popped})."""
        if len(popped) == 0:
            pass
        elif len(popped) == 1:
            out.info('Popped %s' % popped[0])
        else:
            out.info('Popped %s -- %s' % (popped[-1], popped[0]))

    def __truncate_applied(self, p):
        """Remove from C{applied} the first patch pn for which p(pn) is
        true, together with every patch above it. Return the removed
        patches, bottom-most first (an empty list if none matched)."""
        for i, pn in enumerate(self.applied):
            if p(pn):
                popped = self.applied[i:]
                del self.applied[i:]
                return popped
        return []

    def pop_patches(self, p):
        """Pop all patches pn for which p(pn) is true. Return the list of
        other patches that had to be popped to accomplish this. Always
        succeeds.

        The popped patches are put first in C{unapplied}: the ones that
        had to be popped, then the ones p selected, each group in its
        original order."""
        popped = self.__truncate_applied(p)
        popped1 = [pn for pn in popped if not p(pn)]
        popped2 = [pn for pn in popped if p(pn)]
        self.unapplied = popped1 + popped2 + self.unapplied
        self.__print_popped(popped)
        return popped1

    def delete_patches(self, p):
        """Delete all patches pn for which p(pn) is true. Return the list of
        other patches that had to be popped to accomplish this. Always
        succeeds.

        Deleted patches disappear from C{applied}/C{unapplied} and get a
        None entry in C{patches}, which L{run} turns into a deletion."""
        all_patches = self.applied + self.unapplied
        popped = [pn for pn in self.__truncate_applied(p) if not p(pn)]
        self.unapplied = popped + [pn for pn in self.unapplied if not p(pn)]
        self.__print_popped(popped)
        for pn in all_patches:
            if p(pn):
                if self.patches[pn].data.is_nochange():
                    s = ' (empty)'
                else:
                    s = ''
                self.patches[pn] = None
                out.info('Deleted %s%s' % (pn, s))
        return popped

    def push_patch(self, pn, iw = None):
        """Attempt to push the named patch. If this results in conflicts,
        halts the transaction. If index+worktree are given, spill any
        conflicts to them.

        The merge is first attempted in the temporary index; C{iw} is
        only used when that attempt produces no tree."""
        orig_cd = self.patches[pn].data
        cd = orig_cd.set_committer(None)
        if cd.is_nochange():
            s = ' (empty)'
        else:
            s = ''
        oldparent = cd.parent
        cd = cd.set_parent(self.__head)
        # Replay the patch's change (old parent tree -> patch tree) on
        # top of the new parent's tree.
        base = oldparent.data.tree
        ours = cd.parent.data.tree
        theirs = cd.tree
        tree, self.temp_index_tree = self.temp_index.merge(
            base, ours, theirs, self.temp_index_tree)
        merge_conflict = False
        if not tree:
            # The temporary-index merge did not produce a tree; fall back
            # to merging in the real index+worktree, if we have one.
            if iw is None:
                self.__halt('%s does not apply cleanly' % pn)
            try:
                self.__checkout(ours, iw)
            except git.CheckoutException:
                self.__halt('Index/worktree dirty')
            try:
                iw.merge(base, ours, theirs)
                tree = iw.index.write_tree()
                self.__current_tree = tree
                s = ' (modified)'
            except git.MergeConflictException:
                tree = ours
                merge_conflict = True
                s = ' (conflict)'
            except git.MergeException:
                # Avoid "except X, e" (a syntax error on Python 3) and
                # "except X as e" (unavailable before Python 2.6).
                import sys
                self.__halt(str(sys.exc_info()[1]))
        cd = cd.set_tree(tree)
        if any(getattr(cd, a) != getattr(orig_cd, a) for a in
               ['parent', 'tree', 'author', 'message']):
            self.patches[pn] = self.__stack.repository.commit(cd)
        else:
            s = ' (unmodified)'
        self.unapplied.remove(pn)
        self.applied.append(pn)
        out.info('Pushed %s%s' % (pn, s))
        if merge_conflict:
            # We've just caused conflicts, so we must allow them in
            # the final checkout.
            self.__allow_conflicts = lambda trans: True

            self.__halt('Merge conflict')