chiark / gitweb /
1ab5f8b0fa47e421e7a5593612c6f10d06541562
[stgit] / stgit / lib / transaction.py
1 """The L{StackTransaction} class makes it possible to make complex
2 updates to an StGit stack in a safe and convenient way."""
3
4 import atexit
5 import itertools as it
6
7 from stgit import exception, utils
8 from stgit.utils import any, all
9 from stgit.out import *
10 from stgit.lib import git, log
11
class TransactionException(exception.StgException):
    """Base error for L{StackTransaction} failures.

    Raised when a transaction cannot be carried out, for example when
    it has to be aborted and rolled back."""
15
class TransactionHalted(TransactionException):
    """Signals that a L{StackTransaction} stopped part-way through.

    It serves as a non-local jump out of the transaction setup code;
    the caller catches it and then proceeds to the part of the code
    where the transaction is run."""
20
def _print_current_patch(old_applied, new_applied):
    """Report the topmost applied patch if it changed.

    Nothing is printed when the top patch is the same before and after,
    or when no patch was applied either before or after.

    @param old_applied: Applied patch names before the operation
    @param new_applied: Applied patch names after the operation"""
    if not new_applied:
        # Only worth mentioning if something used to be applied.
        if old_applied:
            out.info('No patch applied')
        return
    top = new_applied[-1]
    if old_applied and old_applied[-1] == top:
        return
    out.info('Now at patch "%s"' % top)
34
class _TransPatchMap(dict):
    """Maps patch names to the commits a transaction assigned to them.

    Entries stored in the map itself take precedence (a stored C{None}
    marks a patch for deletion); looking up a name that has no entry
    reads through to the patch's current commit in the stack."""

    def __init__(self, stack):
        super(_TransPatchMap, self).__init__()
        self.__stack = stack

    def __getitem__(self, pn):
        try:
            commit = super(_TransPatchMap, self).__getitem__(pn)
        except KeyError:
            # Not touched by the transaction: ask the stack instead.
            commit = self.__stack.patches.get(pn).commit
        return commit
45
class StackTransaction(object):
    """A stack transaction, used for making complex updates to an StGit
    stack in one single operation that will either succeed or fail
    cleanly.

    The basic theory of operation is the following:

      1. Create a transaction object.

      2. Inside a::

         try:
           ...
         except TransactionHalted:
           pass

      block, update the transaction with e.g. methods like
      L{pop_patches} and L{push_patch}. This may create new git
      objects such as commits, but will not write any refs; this means
      that in case of a fatal error we can just walk away, no clean-up
      required.

      (Some operations may need to touch your index and working tree,
      though. But they are cleaned up when needed.)

      3. After the C{try} block -- whether or not the setup ran to
      completion or halted part-way through by raising a
      L{TransactionHalted} exception -- call the transaction's L{run}
      method. This will either succeed in writing the updated state to
      your refs and index+worktree, or fail without having done
      anything."""

    def __init__(self, stack, msg, discard_changes=False,
                 allow_conflicts=False):
        """Create a new L{StackTransaction}.

        @param stack: The stack to update; its current patch order,
            head tree and base seed the transaction state
        @param msg: Message passed along when L{run} sets the head,
            writes the patches and adds the log entry
        @param discard_changes: Discard any changes in index+worktree
        @type discard_changes: bool
        @param allow_conflicts: Whether to allow pre-existing conflicts
        @type allow_conflicts: bool or function of L{StackTransaction}"""
        self.__stack = stack
        self.__msg = msg
        self.__patches = _TransPatchMap(stack)
        # Work on private copies; the stack is only written by run().
        self.__applied = list(self.__stack.patchorder.applied)
        self.__unapplied = list(self.__stack.patchorder.unapplied)
        self.__hidden = list(self.__stack.patchorder.hidden)
        self.__conflicting_push = None
        self.__error = None
        self.__current_tree = self.__stack.head.data.tree
        self.__base = self.__stack.base
        self.__discard_changes = discard_changes
        if isinstance(allow_conflicts, bool):
            # Normalize to a callable so that callers of
            # __allow_conflicts never have to care which form was given.
            self.__allow_conflicts = lambda trans: allow_conflicts
        else:
            self.__allow_conflicts = allow_conflicts
        self.__temp_index = self.temp_index_tree = None

    stack = property(lambda self: self.__stack)
    patches = property(lambda self: self.__patches)

    # The setters below always store a fresh list, so the transaction
    # never shares its patch lists with the caller.
    def __set_applied(self, val):
        self.__applied = list(val)
    applied = property(lambda self: self.__applied, __set_applied)

    def __set_unapplied(self, val):
        self.__unapplied = list(val)
    unapplied = property(lambda self: self.__unapplied, __set_unapplied)

    def __set_hidden(self, val):
        self.__hidden = list(val)
    hidden = property(lambda self: self.__hidden, __set_hidden)

    all_patches = property(lambda self: (self.__applied + self.__unapplied
                                         + self.__hidden))

    def __set_base(self, val):
        # A new base must be the parent of the bottommost applied patch
        # (if there is one).
        assert (not self.__applied
                or self.patches[self.applied[0]].data.parent == val)
        self.__base = val
    base = property(lambda self: self.__base, __set_base)

    @property
    def temp_index(self):
        """Temporary index obtained from the stack's repository. It is
        created on first access, and its C{delete} method is
        registered to run at interpreter exit."""
        if not self.__temp_index:
            self.__temp_index = self.__stack.repository.temp_index()
            atexit.register(self.__temp_index.delete)
        return self.__temp_index

    def __checkout(self, tree, iw):
        """Make index+worktree C{iw} reflect C{tree}.

        Aborts the transaction (L{TransactionException}) if HEAD and
        the stack top differ, or if the tree does not change but
        unresolved conflicts exist that are not allowed. C{iw} may
        only be C{None} when no actual checkout is needed."""
        if not self.__stack.head_top_equal():
            out.error(
                'HEAD and top are not the same.',
                'This can happen if you modify a branch with git.',
                '"stg repair --help" explains more about what to do next.')
            self.__abort()
        if self.__current_tree == tree and not self.__discard_changes:
            # No tree change, but we still want to make sure that
            # there are no unresolved conflicts. Conflicts
            # conceptually "belong" to the topmost patch, and just
            # carrying them along to another patch is confusing.
            if (self.__allow_conflicts(self) or iw is None
                or not iw.index.conflicts()):
                return
            out.error('Need to resolve conflicts first')
            self.__abort()
        assert iw is not None
        if self.__discard_changes:
            iw.checkout_hard(tree)
        else:
            iw.checkout(self.__current_tree, tree)
        self.__current_tree = tree

    @staticmethod
    def __abort():
        """Unconditionally raise L{TransactionException}."""
        raise TransactionException(
            'Command aborted (all changes rolled back)')

    def __check_consistency(self):
        """Assert that every patch touched by the transaction is either
        marked for deletion and present in the stack, or is part of
        the new patch order."""
        remaining = set(self.all_patches)
        for pn, commit in self.__patches.items():
            if commit is None:
                assert self.__stack.patches.exists(pn)
            else:
                assert pn in remaining

    @property
    def __head(self):
        """Commit of the topmost applied patch, or the base if no patch
        is applied."""
        if self.__applied:
            return self.__patches[self.__applied[-1]]
        else:
            return self.__base

    def abort(self, iw=None):
        """Abandon the transaction, restoring index+worktree C{iw} (if
        given) to the tree of the stack's head."""
        # The only state we need to restore is index+worktree.
        if iw:
            self.__checkout(self.__stack.head.data.tree, iw)

    def run(self, iw=None, set_head=True):
        """Execute the transaction. Will either succeed, or fail (with an
        exception) and do nothing.

        @param iw: Index+worktree to check the new head out into, if any
        @param set_head: Whether to check out and set the branch head
        @return: C{utils.STGIT_CONFLICT} if the transaction was halted
            with an error, C{utils.STGIT_SUCCESS} otherwise
        @raise TransactionException: If the checkout of the new head
            fails and the transaction has to be aborted"""
        self.__check_consistency()
        new_head = self.__head

        # Set branch head.
        if set_head:
            if iw:
                try:
                    self.__checkout(new_head.data.tree, iw)
                except git.CheckoutException:
                    # We have to abort the transaction.
                    self.abort(iw)
                    self.__abort()
            self.__stack.set_head(new_head, self.__msg)

        if self.__error:
            out.error(self.__error)

        # Write patches.
        def write(msg):
            for pn, commit in self.__patches.items():
                if self.__stack.patches.exists(pn):
                    p = self.__stack.patches.get(pn)
                    if commit is None:
                        p.delete()
                    else:
                        p.set_commit(commit, msg)
                else:
                    self.__stack.patches.new(pn, commit, msg)
            self.__stack.patchorder.applied = self.__applied
            self.__stack.patchorder.unapplied = self.__unapplied
            self.__stack.patchorder.hidden = self.__hidden
            log.log_entry(self.__stack, msg)
        old_applied = self.__stack.patchorder.applied
        write(self.__msg)
        if self.__conflicting_push is not None:
            # The conflicting push is applied on top of the state just
            # written and then written separately, starting from a
            # fresh patch map.
            self.__patches = _TransPatchMap(self.__stack)
            self.__conflicting_push()
            write(self.__msg + ' (CONFLICT)')
        _print_current_patch(old_applied, self.__applied)

        if self.__error:
            return utils.STGIT_CONFLICT
        else:
            return utils.STGIT_SUCCESS

    def __halt(self, msg):
        """Remember C{msg} as the transaction's error and raise
        L{TransactionHalted}."""
        self.__error = msg
        raise TransactionHalted(msg)

    @staticmethod
    def __print_popped(popped):
        """Report the popped patches: the single name, or the two end
        points of the popped range. Prints nothing for an empty list."""
        if len(popped) == 1:
            out.info('Popped %s' % popped[0])
        elif popped:
            out.info('Popped %s -- %s' % (popped[-1], popped[0]))

    def pop_patches(self, p):
        """Pop all patches pn for which p(pn) is true. Return the list of
        other patches that had to be popped to accomplish this. Always
        succeeds."""
        popped = []
        for i, pn in enumerate(self.applied):
            if p(pn):
                # Everything from the first match upwards has to go.
                popped = self.applied[i:]
                del self.applied[i:]
                break
        popped1 = [pn for pn in popped if not p(pn)]
        popped2 = [pn for pn in popped if p(pn)]
        # Collateral pops are placed first so they are next to be pushed.
        self.unapplied = popped1 + popped2 + self.unapplied
        self.__print_popped(popped)
        return popped1

    def delete_patches(self, p):
        """Delete all patches pn for which p(pn) is true. Return the list of
        other patches that had to be popped to accomplish this. Always
        succeeds."""
        popped = []
        # Snapshot the names before the lists are filtered below.
        all_patches = self.applied + self.unapplied + self.hidden
        for i, pn in enumerate(self.applied):
            if p(pn):
                popped = self.applied[i:]
                del self.applied[i:]
                break
        popped = [pn for pn in popped if not p(pn)]
        self.unapplied = popped + [pn for pn in self.unapplied if not p(pn)]
        self.hidden = [pn for pn in self.hidden if not p(pn)]
        self.__print_popped(popped)
        for pn in all_patches:
            if p(pn):
                if self.patches[pn].data.is_nochange():
                    s = ' (empty)'
                else:
                    s = ''
                # None marks the patch for deletion when run() writes.
                self.patches[pn] = None
                out.info('Deleted %s%s' % (pn, s))
        return popped

    def push_patch(self, pn, iw=None):
        """Attempt to push the named patch. If this results in conflicts,
        halts the transaction. If index+worktree are given, spill any
        conflicts to them.

        @param pn: Name of the patch to push
        @param iw: Index+worktree used when the merge cannot be done in
            the temporary index
        @raise TransactionHalted: If the patch does not apply cleanly
            and no C{iw} was given, if the checkout or merge in C{iw}
            fails, or if the merge leaves conflicts"""
        orig_cd = self.patches[pn].data
        cd = orig_cd.set_committer(None)
        if cd.is_nochange():
            s = ' (empty)'
        else:
            s = ''
        oldparent = cd.parent
        cd = cd.set_parent(self.__head)
        # Three-way merge input: the patch's old parent is the common
        # ancestor, the new parent is "ours", the patch is "theirs".
        base = oldparent.data.tree
        ours = cd.parent.data.tree
        theirs = cd.tree
        tree, self.temp_index_tree = self.temp_index.merge(
            base, ours, theirs, self.temp_index_tree)
        merge_conflict = False
        if not tree:
            # The temporary index gave no result; retry in the real
            # index+worktree if we have one.
            if iw is None:
                self.__halt('%s does not apply cleanly' % pn)
            try:
                self.__checkout(ours, iw)
            except git.CheckoutException:
                self.__halt('Index/worktree dirty')
            try:
                iw.merge(base, ours, theirs)
                tree = iw.index.write_tree()
                self.__current_tree = tree
                s = ' (modified)'
            except git.MergeConflictException:
                tree = ours
                merge_conflict = True
                s = ' (conflict)'
            except git.MergeException as e:
                self.__halt(str(e))
        cd = cd.set_tree(tree)
        if any(getattr(cd, a) != getattr(orig_cd, a) for a in
               ['parent', 'tree', 'author', 'message']):
            comm = self.__stack.repository.commit(cd)
        else:
            # Nothing relevant changed, so the existing commit is kept.
            comm = None
            s = ' (unmodified)'
        out.info('Pushed %s%s' % (pn, s))

        def update():
            if comm:
                self.patches[pn] = comm
            if pn in self.hidden:
                x = self.hidden
            else:
                x = self.unapplied
            x.remove(pn)
            self.applied.append(pn)
        if merge_conflict:
            # We've just caused conflicts, so we must allow them in
            # the final checkout.
            self.__allow_conflicts = lambda trans: True

            # Save this update so that we can run it a little later.
            self.__conflicting_push = update
            self.__halt('Merge conflict')
        else:
            # Update immediately.
            update()

    def reorder_patches(self, applied, unapplied, hidden, iw=None):
        """Push and pop patches to attain the given ordering.

        @param applied: Desired list of applied patch names
        @param unapplied: Desired list of unapplied patch names
        @param hidden: Desired list of hidden patch names
        @param iw: Index+worktree handed to L{push_patch}, if any"""
        # Length of the prefix shared by the current and desired applied
        # lists; those patches can stay where they are.
        common = 0
        for old, new in zip(self.applied, applied):
            if old != new:
                break
            common += 1
        to_pop = set(self.applied[common:])
        self.pop_patches(lambda pn: pn in to_pop)
        for pn in applied[common:]:
            self.push_patch(pn, iw)
        assert self.applied == applied
        assert set(self.unapplied + self.hidden) == set(unapplied + hidden)
        self.unapplied = unapplied
        self.hidden = hidden