chiark / gitweb /
20d87e156752b3872f7c3d528e8b361800b9bc40
[stgit] / stgit / lib / transaction.py
1 """The L{StackTransaction} class makes it possible to make complex
2 updates to an StGit stack in a safe and convenient way."""
3
4 import atexit
5 import itertools as it
6
7 from stgit import exception, utils
8 from stgit.utils import any, all
9 from stgit.out import *
10 from stgit.lib import git, log
11
class TransactionException(exception.StgException):
    """Raised when a L{StackTransaction} cannot be carried out, e.g.
    when a command has to be aborted and all of its changes rolled
    back."""
15
class TransactionHalted(TransactionException):
    """Raised to stop the setup of a L{StackTransaction} part-way
    through.

    It serves as a non-local jump out of the code that builds up the
    transaction, back to the place where the transaction is run; the
    transaction keeps the halt reason so that running it can report
    it."""
20
def _print_current_patch(old_applied, new_applied):
    """Tell the user which patch is now on top, if that changed.

    Nothing is printed when the topmost applied patch is the same
    before and after, or when no patch was applied either time."""
    if not new_applied:
        # Only worth mentioning if something used to be applied.
        if old_applied:
            out.info('No patch applied')
        return
    top = new_applied[-1]
    if old_applied and old_applied[-1] == top:
        return
    out.info('Now at patch "%s"' % top)
34
35 class _TransPatchMap(dict):
36     """Maps patch names to sha1 strings."""
37     def __init__(self, stack):
38         dict.__init__(self)
39         self.__stack = stack
40     def __getitem__(self, pn):
41         try:
42             return dict.__getitem__(self, pn)
43         except KeyError:
44             return self.__stack.patches.get(pn).commit
45
46 class StackTransaction(object):
47     """A stack transaction, used for making complex updates to an StGit
48     stack in one single operation that will either succeed or fail
49     cleanly.
50
51     The basic theory of operation is the following:
52
53       1. Create a transaction object.
54
55       2. Inside a::
56
57          try
58            ...
59          except TransactionHalted:
60            pass
61
62       block, update the transaction with e.g. methods like
63       L{pop_patches} and L{push_patch}. This may create new git
64       objects such as commits, but will not write any refs; this means
65       that in case of a fatal error we can just walk away, no clean-up
66       required.
67
68       (Some operations may need to touch your index and working tree,
69       though. But they are cleaned up when needed.)
70
71       3. After the C{try} block -- wheher or not the setup ran to
72       completion or halted part-way through by raising a
73       L{TransactionHalted} exception -- call the transaction's L{run}
74       method. This will either succeed in writing the updated state to
75       your refs and index+worktree, or fail without having done
76       anything."""
77     def __init__(self, stack, msg, allow_conflicts = False):
78         self.__stack = stack
79         self.__msg = msg
80         self.__patches = _TransPatchMap(stack)
81         self.__applied = list(self.__stack.patchorder.applied)
82         self.__unapplied = list(self.__stack.patchorder.unapplied)
83         self.__hidden = list(self.__stack.patchorder.hidden)
84         self.__conflicting_push = None
85         self.__error = None
86         self.__current_tree = self.__stack.head.data.tree
87         self.__base = self.__stack.base
88         if isinstance(allow_conflicts, bool):
89             self.__allow_conflicts = lambda trans: allow_conflicts
90         else:
91             self.__allow_conflicts = allow_conflicts
92         self.__temp_index = self.temp_index_tree = None
93     stack = property(lambda self: self.__stack)
94     patches = property(lambda self: self.__patches)
95     def __set_applied(self, val):
96         self.__applied = list(val)
97     applied = property(lambda self: self.__applied, __set_applied)
98     def __set_unapplied(self, val):
99         self.__unapplied = list(val)
100     unapplied = property(lambda self: self.__unapplied, __set_unapplied)
101     def __set_hidden(self, val):
102         self.__hidden = list(val)
103     hidden = property(lambda self: self.__hidden, __set_hidden)
104     def __set_base(self, val):
105         assert (not self.__applied
106                 or self.patches[self.applied[0]].data.parent == val)
107         self.__base = val
108     base = property(lambda self: self.__base, __set_base)
109     @property
110     def temp_index(self):
111         if not self.__temp_index:
112             self.__temp_index = self.__stack.repository.temp_index()
113             atexit.register(self.__temp_index.delete)
114         return self.__temp_index
115     def __checkout(self, tree, iw):
116         if not self.__stack.head_top_equal():
117             out.error(
118                 'HEAD and top are not the same.',
119                 'This can happen if you modify a branch with git.',
120                 '"stg repair --help" explains more about what to do next.')
121             self.__abort()
122         if self.__current_tree == tree:
123             # No tree change, but we still want to make sure that
124             # there are no unresolved conflicts. Conflicts
125             # conceptually "belong" to the topmost patch, and just
126             # carrying them along to another patch is confusing.
127             if (self.__allow_conflicts(self) or iw == None
128                 or not iw.index.conflicts()):
129                 return
130             out.error('Need to resolve conflicts first')
131             self.__abort()
132         assert iw != None
133         iw.checkout(self.__current_tree, tree)
134         self.__current_tree = tree
135     @staticmethod
136     def __abort():
137         raise TransactionException(
138             'Command aborted (all changes rolled back)')
139     def __check_consistency(self):
140         remaining = set(self.__applied + self.__unapplied + self.__hidden)
141         for pn, commit in self.__patches.iteritems():
142             if commit == None:
143                 assert self.__stack.patches.exists(pn)
144             else:
145                 assert pn in remaining
146     @property
147     def __head(self):
148         if self.__applied:
149             return self.__patches[self.__applied[-1]]
150         else:
151             return self.__base
152     def abort(self, iw = None):
153         # The only state we need to restore is index+worktree.
154         if iw:
155             self.__checkout(self.__stack.head.data.tree, iw)
156     def run(self, iw = None, set_head = True):
157         """Execute the transaction. Will either succeed, or fail (with an
158         exception) and do nothing."""
159         self.__check_consistency()
160         new_head = self.__head
161
162         # Set branch head.
163         if set_head:
164             if iw:
165                 try:
166                     self.__checkout(new_head.data.tree, iw)
167                 except git.CheckoutException:
168                     # We have to abort the transaction.
169                     self.abort(iw)
170                     self.__abort()
171             self.__stack.set_head(new_head, self.__msg)
172
173         if self.__error:
174             out.error(self.__error)
175
176         # Write patches.
177         def write(msg):
178             for pn, commit in self.__patches.iteritems():
179                 if self.__stack.patches.exists(pn):
180                     p = self.__stack.patches.get(pn)
181                     if commit == None:
182                         p.delete()
183                     else:
184                         p.set_commit(commit, msg)
185                 else:
186                     self.__stack.patches.new(pn, commit, msg)
187             self.__stack.patchorder.applied = self.__applied
188             self.__stack.patchorder.unapplied = self.__unapplied
189             self.__stack.patchorder.hidden = self.__hidden
190             log.log_entry(self.__stack, msg)
191         old_applied = self.__stack.patchorder.applied
192         write(self.__msg)
193         if self.__conflicting_push != None:
194             self.__patches = _TransPatchMap(self.__stack)
195             self.__conflicting_push()
196             write(self.__msg + ' (CONFLICT)')
197         _print_current_patch(old_applied, self.__applied)
198
199         if self.__error:
200             return utils.STGIT_CONFLICT
201         else:
202             return utils.STGIT_SUCCESS
203
204     def __halt(self, msg):
205         self.__error = msg
206         raise TransactionHalted(msg)
207
208     @staticmethod
209     def __print_popped(popped):
210         if len(popped) == 0:
211             pass
212         elif len(popped) == 1:
213             out.info('Popped %s' % popped[0])
214         else:
215             out.info('Popped %s -- %s' % (popped[-1], popped[0]))
216
217     def pop_patches(self, p):
218         """Pop all patches pn for which p(pn) is true. Return the list of
219         other patches that had to be popped to accomplish this. Always
220         succeeds."""
221         popped = []
222         for i in xrange(len(self.applied)):
223             if p(self.applied[i]):
224                 popped = self.applied[i:]
225                 del self.applied[i:]
226                 break
227         popped1 = [pn for pn in popped if not p(pn)]
228         popped2 = [pn for pn in popped if p(pn)]
229         self.unapplied = popped1 + popped2 + self.unapplied
230         self.__print_popped(popped)
231         return popped1
232
233     def delete_patches(self, p):
234         """Delete all patches pn for which p(pn) is true. Return the list of
235         other patches that had to be popped to accomplish this. Always
236         succeeds."""
237         popped = []
238         all_patches = self.applied + self.unapplied + self.hidden
239         for i in xrange(len(self.applied)):
240             if p(self.applied[i]):
241                 popped = self.applied[i:]
242                 del self.applied[i:]
243                 break
244         popped = [pn for pn in popped if not p(pn)]
245         self.unapplied = popped + [pn for pn in self.unapplied if not p(pn)]
246         self.hidden = [pn for pn in self.hidden if not p(pn)]
247         self.__print_popped(popped)
248         for pn in all_patches:
249             if p(pn):
250                 s = ['', ' (empty)'][self.patches[pn].data.is_nochange()]
251                 self.patches[pn] = None
252                 out.info('Deleted %s%s' % (pn, s))
253         return popped
254
255     def push_patch(self, pn, iw = None):
256         """Attempt to push the named patch. If this results in conflicts,
257         halts the transaction. If index+worktree are given, spill any
258         conflicts to them."""
259         orig_cd = self.patches[pn].data
260         cd = orig_cd.set_committer(None)
261         s = ['', ' (empty)'][cd.is_nochange()]
262         oldparent = cd.parent
263         cd = cd.set_parent(self.__head)
264         base = oldparent.data.tree
265         ours = cd.parent.data.tree
266         theirs = cd.tree
267         tree, self.temp_index_tree = self.temp_index.merge(
268             base, ours, theirs, self.temp_index_tree)
269         merge_conflict = False
270         if not tree:
271             if iw == None:
272                 self.__halt('%s does not apply cleanly' % pn)
273             try:
274                 self.__checkout(ours, iw)
275             except git.CheckoutException:
276                 self.__halt('Index/worktree dirty')
277             try:
278                 iw.merge(base, ours, theirs)
279                 tree = iw.index.write_tree()
280                 self.__current_tree = tree
281                 s = ' (modified)'
282             except git.MergeConflictException:
283                 tree = ours
284                 merge_conflict = True
285                 s = ' (conflict)'
286             except git.MergeException, e:
287                 self.__halt(str(e))
288         cd = cd.set_tree(tree)
289         if any(getattr(cd, a) != getattr(orig_cd, a) for a in
290                ['parent', 'tree', 'author', 'message']):
291             comm = self.__stack.repository.commit(cd)
292         else:
293             comm = None
294             s = ' (unmodified)'
295         out.info('Pushed %s%s' % (pn, s))
296         def update():
297             if comm:
298                 self.patches[pn] = comm
299             if pn in self.hidden:
300                 x = self.hidden
301             else:
302                 x = self.unapplied
303             del x[x.index(pn)]
304             self.applied.append(pn)
305         if merge_conflict:
306             # We've just caused conflicts, so we must allow them in
307             # the final checkout.
308             self.__allow_conflicts = lambda trans: True
309
310             # Save this update so that we can run it a little later.
311             self.__conflicting_push = update
312             self.__halt('Merge conflict')
313         else:
314             # Update immediately.
315             update()
316
317     def reorder_patches(self, applied, unapplied, hidden, iw = None):
318         """Push and pop patches to attain the given ordering."""
319         common = len(list(it.takewhile(lambda (a, b): a == b,
320                                        zip(self.applied, applied))))
321         to_pop = set(self.applied[common:])
322         self.pop_patches(lambda pn: pn in to_pop)
323         for pn in applied[common:]:
324             self.push_patch(pn, iw)
325         assert self.applied == applied
326         assert set(self.unapplied + self.hidden) == set(unapplied + hidden)
327         self.unapplied = unapplied
328         self.hidden = hidden