chiark / gitweb /
0f414d893812eeddf037757c7784eeb09c3cf236
[stgit] / stgit / lib / transaction.py
1 """The L{StackTransaction} class makes it possible to make complex
2 updates to an StGit stack in a safe and convenient way."""
3
4 import atexit
5 import itertools as it
6
7 from stgit import exception, utils
8 from stgit.utils import any, all
9 from stgit.out import *
10 from stgit.lib import git, log
11
class TransactionException(exception.StgException):
    """Raised when a L{StackTransaction} cannot be carried out, for
    example when a command is aborted and its changes rolled back."""
    pass
15
class TransactionHalted(TransactionException):
    """Raised when a L{StackTransaction} stops before all of its setup
    steps have been performed.

    It acts as a non-local jump out of the transaction setup code to
    the place where the transaction is run."""
    pass
20
21 def _print_current_patch(old_applied, new_applied):
22     def now_at(pn):
23         out.info('Now at patch "%s"' % pn)
24     if not old_applied and not new_applied:
25         pass
26     elif not old_applied:
27         now_at(new_applied[-1])
28     elif not new_applied:
29         out.info('No patch applied')
30     elif old_applied[-1] == new_applied[-1]:
31         pass
32     else:
33         now_at(new_applied[-1])
34
35 class _TransPatchMap(dict):
36     """Maps patch names to sha1 strings."""
37     def __init__(self, stack):
38         dict.__init__(self)
39         self.__stack = stack
40     def __getitem__(self, pn):
41         try:
42             return dict.__getitem__(self, pn)
43         except KeyError:
44             return self.__stack.patches.get(pn).commit
45
46 class StackTransaction(object):
47     """A stack transaction, used for making complex updates to an StGit
48     stack in one single operation that will either succeed or fail
49     cleanly.
50
51     The basic theory of operation is the following:
52
53       1. Create a transaction object.
54
55       2. Inside a::
56
57          try
58            ...
59          except TransactionHalted:
60            pass
61
62       block, update the transaction with e.g. methods like
63       L{pop_patches} and L{push_patch}. This may create new git
64       objects such as commits, but will not write any refs; this means
65       that in case of a fatal error we can just walk away, no clean-up
66       required.
67
68       (Some operations may need to touch your index and working tree,
69       though. But they are cleaned up when needed.)
70
71       3. After the C{try} block -- wheher or not the setup ran to
72       completion or halted part-way through by raising a
73       L{TransactionHalted} exception -- call the transaction's L{run}
74       method. This will either succeed in writing the updated state to
75       your refs and index+worktree, or fail without having done
76       anything."""
77     def __init__(self, stack, msg, discard_changes = False,
78                  allow_conflicts = False, allow_bad_head = False):
79         """Create a new L{StackTransaction}.
80
81         @param discard_changes: Discard any changes in index+worktree
82         @type discard_changes: bool
83         @param allow_conflicts: Whether to allow pre-existing conflicts
84         @type allow_conflicts: bool or function of L{StackTransaction}"""
85         self.__stack = stack
86         self.__msg = msg
87         self.__patches = _TransPatchMap(stack)
88         self.__applied = list(self.__stack.patchorder.applied)
89         self.__unapplied = list(self.__stack.patchorder.unapplied)
90         self.__hidden = list(self.__stack.patchorder.hidden)
91         self.__conflicting_push = None
92         self.__error = None
93         self.__current_tree = self.__stack.head.data.tree
94         self.__base = self.__stack.base
95         self.__discard_changes = discard_changes
96         self.__bad_head = None
97         if isinstance(allow_conflicts, bool):
98             self.__allow_conflicts = lambda trans: allow_conflicts
99         else:
100             self.__allow_conflicts = allow_conflicts
101         self.__temp_index = self.temp_index_tree = None
102         if not allow_bad_head:
103             self.__assert_head_top_equal()
104     stack = property(lambda self: self.__stack)
105     patches = property(lambda self: self.__patches)
106     def __set_applied(self, val):
107         self.__applied = list(val)
108     applied = property(lambda self: self.__applied, __set_applied)
109     def __set_unapplied(self, val):
110         self.__unapplied = list(val)
111     unapplied = property(lambda self: self.__unapplied, __set_unapplied)
112     def __set_hidden(self, val):
113         self.__hidden = list(val)
114     hidden = property(lambda self: self.__hidden, __set_hidden)
115     all_patches = property(lambda self: (self.__applied + self.__unapplied
116                                          + self.__hidden))
117     def __set_base(self, val):
118         assert (not self.__applied
119                 or self.patches[self.applied[0]].data.parent == val)
120         self.__base = val
121     base = property(lambda self: self.__base, __set_base)
122     @property
123     def temp_index(self):
124         if not self.__temp_index:
125             self.__temp_index = self.__stack.repository.temp_index()
126             atexit.register(self.__temp_index.delete)
127         return self.__temp_index
128     @property
129     def top(self):
130         if self.__applied:
131             return self.__patches[self.__applied[-1]]
132         else:
133             return self.__base
134     def __get_head(self):
135         if self.__bad_head:
136             return self.__bad_head
137         else:
138             return self.top
139     def __set_head(self, val):
140         self.__bad_head = val
141     head = property(__get_head, __set_head)
142     def __assert_head_top_equal(self):
143         if not self.__stack.head_top_equal():
144             out.error(
145                 'HEAD and top are not the same.',
146                 'This can happen if you modify a branch with git.',
147                 '"stg repair --help" explains more about what to do next.')
148             self.__abort()
149     def __checkout(self, tree, iw, allow_bad_head):
150         if not allow_bad_head:
151             self.__assert_head_top_equal()
152         if self.__current_tree == tree and not self.__discard_changes:
153             # No tree change, but we still want to make sure that
154             # there are no unresolved conflicts. Conflicts
155             # conceptually "belong" to the topmost patch, and just
156             # carrying them along to another patch is confusing.
157             if (self.__allow_conflicts(self) or iw == None
158                 or not iw.index.conflicts()):
159                 return
160             out.error('Need to resolve conflicts first')
161             self.__abort()
162         assert iw != None
163         if self.__discard_changes:
164             iw.checkout_hard(tree)
165         else:
166             iw.checkout(self.__current_tree, tree)
167         self.__current_tree = tree
168     @staticmethod
169     def __abort():
170         raise TransactionException(
171             'Command aborted (all changes rolled back)')
172     def __check_consistency(self):
173         remaining = set(self.all_patches)
174         for pn, commit in self.__patches.iteritems():
175             if commit == None:
176                 assert self.__stack.patches.exists(pn)
177             else:
178                 assert pn in remaining
179     def abort(self, iw = None):
180         # The only state we need to restore is index+worktree.
181         if iw:
182             self.__checkout(self.__stack.head.data.tree, iw,
183                             allow_bad_head = True)
184     def run(self, iw = None, set_head = True, allow_bad_head = False,
185             print_current_patch = True):
186         """Execute the transaction. Will either succeed, or fail (with an
187         exception) and do nothing."""
188         self.__check_consistency()
189         log.log_external_mods(self.__stack)
190         new_head = self.head
191
192         # Set branch head.
193         if set_head:
194             if iw:
195                 try:
196                     self.__checkout(new_head.data.tree, iw, allow_bad_head)
197                 except git.CheckoutException:
198                     # We have to abort the transaction.
199                     self.abort(iw)
200                     self.__abort()
201             self.__stack.set_head(new_head, self.__msg)
202
203         if self.__error:
204             out.error(self.__error)
205
206         # Write patches.
207         def write(msg):
208             for pn, commit in self.__patches.iteritems():
209                 if self.__stack.patches.exists(pn):
210                     p = self.__stack.patches.get(pn)
211                     if commit == None:
212                         p.delete()
213                     else:
214                         p.set_commit(commit, msg)
215                 else:
216                     self.__stack.patches.new(pn, commit, msg)
217             self.__stack.patchorder.applied = self.__applied
218             self.__stack.patchorder.unapplied = self.__unapplied
219             self.__stack.patchorder.hidden = self.__hidden
220             log.log_entry(self.__stack, msg)
221         old_applied = self.__stack.patchorder.applied
222         write(self.__msg)
223         if self.__conflicting_push != None:
224             self.__patches = _TransPatchMap(self.__stack)
225             self.__conflicting_push()
226             write(self.__msg + ' (CONFLICT)')
227         if print_current_patch:
228             _print_current_patch(old_applied, self.__applied)
229
230         if self.__error:
231             return utils.STGIT_CONFLICT
232         else:
233             return utils.STGIT_SUCCESS
234
235     def __halt(self, msg):
236         self.__error = msg
237         raise TransactionHalted(msg)
238
239     @staticmethod
240     def __print_popped(popped):
241         if len(popped) == 0:
242             pass
243         elif len(popped) == 1:
244             out.info('Popped %s' % popped[0])
245         else:
246             out.info('Popped %s -- %s' % (popped[-1], popped[0]))
247
248     def pop_patches(self, p):
249         """Pop all patches pn for which p(pn) is true. Return the list of
250         other patches that had to be popped to accomplish this. Always
251         succeeds."""
252         popped = []
253         for i in xrange(len(self.applied)):
254             if p(self.applied[i]):
255                 popped = self.applied[i:]
256                 del self.applied[i:]
257                 break
258         popped1 = [pn for pn in popped if not p(pn)]
259         popped2 = [pn for pn in popped if p(pn)]
260         self.unapplied = popped1 + popped2 + self.unapplied
261         self.__print_popped(popped)
262         return popped1
263
264     def delete_patches(self, p, quiet = False):
265         """Delete all patches pn for which p(pn) is true. Return the list of
266         other patches that had to be popped to accomplish this. Always
267         succeeds."""
268         popped = []
269         all_patches = self.applied + self.unapplied + self.hidden
270         for i in xrange(len(self.applied)):
271             if p(self.applied[i]):
272                 popped = self.applied[i:]
273                 del self.applied[i:]
274                 break
275         popped = [pn for pn in popped if not p(pn)]
276         self.unapplied = popped + [pn for pn in self.unapplied if not p(pn)]
277         self.hidden = [pn for pn in self.hidden if not p(pn)]
278         self.__print_popped(popped)
279         for pn in all_patches:
280             if p(pn):
281                 s = ['', ' (empty)'][self.patches[pn].data.is_nochange()]
282                 self.patches[pn] = None
283                 if not quiet:
284                     out.info('Deleted %s%s' % (pn, s))
285         return popped
286
287     def push_patch(self, pn, iw = None):
288         """Attempt to push the named patch. If this results in conflicts,
289         halts the transaction. If index+worktree are given, spill any
290         conflicts to them."""
291         orig_cd = self.patches[pn].data
292         cd = orig_cd.set_committer(None)
293         oldparent = cd.parent
294         cd = cd.set_parent(self.top)
295         base = oldparent.data.tree
296         ours = cd.parent.data.tree
297         theirs = cd.tree
298         tree, self.temp_index_tree = self.temp_index.merge(
299             base, ours, theirs, self.temp_index_tree)
300         s = ''
301         merge_conflict = False
302         if not tree:
303             if iw == None:
304                 self.__halt('%s does not apply cleanly' % pn)
305             try:
306                 self.__checkout(ours, iw, allow_bad_head = False)
307             except git.CheckoutException:
308                 self.__halt('Index/worktree dirty')
309             try:
310                 iw.merge(base, ours, theirs)
311                 tree = iw.index.write_tree()
312                 self.__current_tree = tree
313                 s = ' (modified)'
314             except git.MergeConflictException:
315                 tree = ours
316                 merge_conflict = True
317                 s = ' (conflict)'
318             except git.MergeException, e:
319                 self.__halt(str(e))
320         cd = cd.set_tree(tree)
321         if any(getattr(cd, a) != getattr(orig_cd, a) for a in
322                ['parent', 'tree', 'author', 'message']):
323             comm = self.__stack.repository.commit(cd)
324             self.head = comm
325         else:
326             comm = None
327             s = ' (unmodified)'
328         if not merge_conflict and cd.is_nochange():
329             s = ' (empty)'
330         out.info('Pushed %s%s' % (pn, s))
331         def update():
332             if comm:
333                 self.patches[pn] = comm
334             if pn in self.hidden:
335                 x = self.hidden
336             else:
337                 x = self.unapplied
338             del x[x.index(pn)]
339             self.applied.append(pn)
340         if merge_conflict:
341             # We've just caused conflicts, so we must allow them in
342             # the final checkout.
343             self.__allow_conflicts = lambda trans: True
344
345             # Save this update so that we can run it a little later.
346             self.__conflicting_push = update
347             self.__halt('Merge conflict')
348         else:
349             # Update immediately.
350             update()
351
352     def reorder_patches(self, applied, unapplied, hidden, iw = None):
353         """Push and pop patches to attain the given ordering."""
354         common = len(list(it.takewhile(lambda (a, b): a == b,
355                                        zip(self.applied, applied))))
356         to_pop = set(self.applied[common:])
357         self.pop_patches(lambda pn: pn in to_pop)
358         for pn in applied[common:]:
359             self.push_patch(pn, iw)
360         assert self.applied == applied
361         assert set(self.unapplied + self.hidden) == set(unapplied + hidden)
362         self.unapplied = unapplied
363         self.hidden = hidden