chiark / gitweb /
874f81bd582d7cca1eb1dbe429848618fcd14d81
[stgit] / stgit / lib / transaction.py
1 from stgit import exception, utils
2 from stgit.utils import any, all
3 from stgit.out import *
4 from stgit.lib import git
5
class TransactionException(exception.StgException):
    """Base class of the exceptions raised by stack transactions."""
8
class TransactionHalted(TransactionException):
    """Raised when a transaction has to stop partway (for example on a
    merge conflict); the transaction remembers the message."""
11
def _print_current_patch(old_applied, new_applied):
    """Tell the user which patch ends up on top of the stack.

    Nothing is printed when the topmost applied patch did not change,
    including the case where no patch was applied before or after."""
    if not new_applied:
        if old_applied:
            out.info('No patch applied')
        return
    new_top = new_applied[-1]
    if old_applied and old_applied[-1] == new_top:
        return
    out.info('Now at patch "%s"' % new_top)
25
class _TransPatchMap(dict):
    """Patch name -> commit map holding a transaction's pending changes.

    Looking up a name that has not been assigned in this map falls back to
    the commit currently recorded for that patch in the underlying stack.
    A value of None marks a patch scheduled for deletion."""
    def __init__(self, stack):
        dict.__init__(self)
        self.__stack = stack
    def __getitem__(self, pn):
        if pn in self:
            return dict.__getitem__(self, pn)
        # Not touched by the transaction: read through to the stack.
        return self.__stack.patches.get(pn).commit
35
36 class StackTransaction(object):
37     def __init__(self, stack, msg, allow_conflicts = False):
38         self.__stack = stack
39         self.__msg = msg
40         self.__patches = _TransPatchMap(stack)
41         self.__applied = list(self.__stack.patchorder.applied)
42         self.__unapplied = list(self.__stack.patchorder.unapplied)
43         self.__error = None
44         self.__current_tree = self.__stack.head.data.tree
45         self.__base = self.__stack.base
46         if isinstance(allow_conflicts, bool):
47             self.__allow_conflicts = lambda trans: allow_conflicts
48         else:
49             self.__allow_conflicts = allow_conflicts
50     stack = property(lambda self: self.__stack)
51     patches = property(lambda self: self.__patches)
52     def __set_applied(self, val):
53         self.__applied = list(val)
54     applied = property(lambda self: self.__applied, __set_applied)
55     def __set_unapplied(self, val):
56         self.__unapplied = list(val)
57     unapplied = property(lambda self: self.__unapplied, __set_unapplied)
58     def __set_base(self, val):
59         assert (not self.__applied
60                 or self.patches[self.applied[0]].data.parent == val)
61         self.__base = val
62     base = property(lambda self: self.__base, __set_base)
63     def __checkout(self, tree, iw):
64         if not self.__stack.head_top_equal():
65             out.error(
66                 'HEAD and top are not the same.',
67                 'This can happen if you modify a branch with git.',
68                 '"stg repair --help" explains more about what to do next.')
69             self.__abort()
70         if self.__current_tree == tree:
71             # No tree change, but we still want to make sure that
72             # there are no unresolved conflicts. Conflicts
73             # conceptually "belong" to the topmost patch, and just
74             # carrying them along to another patch is confusing.
75             if (self.__allow_conflicts(self) or iw == None
76                 or not iw.index.conflicts()):
77                 return
78             out.error('Need to resolve conflicts first')
79             self.__abort()
80         assert iw != None
81         iw.checkout(self.__current_tree, tree)
82         self.__current_tree = tree
83     @staticmethod
84     def __abort():
85         raise TransactionException(
86             'Command aborted (all changes rolled back)')
87     def __check_consistency(self):
88         remaining = set(self.__applied + self.__unapplied)
89         for pn, commit in self.__patches.iteritems():
90             if commit == None:
91                 assert self.__stack.patches.exists(pn)
92             else:
93                 assert pn in remaining
94     @property
95     def __head(self):
96         if self.__applied:
97             return self.__patches[self.__applied[-1]]
98         else:
99             return self.__base
100     def abort(self, iw = None):
101         # The only state we need to restore is index+worktree.
102         if iw:
103             self.__checkout(self.__stack.head.data.tree, iw)
104     def run(self, iw = None):
105         self.__check_consistency()
106         new_head = self.__head
107
108         # Set branch head.
109         if iw:
110             try:
111                 self.__checkout(new_head.data.tree, iw)
112             except git.CheckoutException:
113                 # We have to abort the transaction.
114                 self.abort(iw)
115                 self.__abort()
116         self.__stack.set_head(new_head, self.__msg)
117
118         if self.__error:
119             out.error(self.__error)
120
121         # Write patches.
122         for pn, commit in self.__patches.iteritems():
123             if self.__stack.patches.exists(pn):
124                 p = self.__stack.patches.get(pn)
125                 if commit == None:
126                     p.delete()
127                 else:
128                     p.set_commit(commit, self.__msg)
129             else:
130                 self.__stack.patches.new(pn, commit, self.__msg)
131         _print_current_patch(self.__stack.patchorder.applied, self.__applied)
132         self.__stack.patchorder.applied = self.__applied
133         self.__stack.patchorder.unapplied = self.__unapplied
134
135         if self.__error:
136             return utils.STGIT_CONFLICT
137         else:
138             return utils.STGIT_SUCCESS
139
140     def __halt(self, msg):
141         self.__error = msg
142         raise TransactionHalted(msg)
143
144     @staticmethod
145     def __print_popped(popped):
146         if len(popped) == 0:
147             pass
148         elif len(popped) == 1:
149             out.info('Popped %s' % popped[0])
150         else:
151             out.info('Popped %s -- %s' % (popped[-1], popped[0]))
152
153     def pop_patches(self, p):
154         """Pop all patches pn for which p(pn) is true. Return the list of
155         other patches that had to be popped to accomplish this."""
156         popped = []
157         for i in xrange(len(self.applied)):
158             if p(self.applied[i]):
159                 popped = self.applied[i:]
160                 del self.applied[i:]
161                 break
162         popped1 = [pn for pn in popped if not p(pn)]
163         popped2 = [pn for pn in popped if p(pn)]
164         self.unapplied = popped1 + popped2 + self.unapplied
165         self.__print_popped(popped)
166         return popped1
167
168     def delete_patches(self, p):
169         """Delete all patches pn for which p(pn) is true. Return the list of
170         other patches that had to be popped to accomplish this."""
171         popped = []
172         all_patches = self.applied + self.unapplied
173         for i in xrange(len(self.applied)):
174             if p(self.applied[i]):
175                 popped = self.applied[i:]
176                 del self.applied[i:]
177                 break
178         popped = [pn for pn in popped if not p(pn)]
179         self.unapplied = popped + [pn for pn in self.unapplied if not p(pn)]
180         self.__print_popped(popped)
181         for pn in all_patches:
182             if p(pn):
183                 s = ['', ' (empty)'][self.patches[pn].data.is_nochange()]
184                 self.patches[pn] = None
185                 out.info('Deleted %s%s' % (pn, s))
186         return popped
187
188     def push_patch(self, pn, iw = None):
189         """Attempt to push the named patch. If this results in conflicts,
190         halts the transaction. If index+worktree are given, spill any
191         conflicts to them."""
192         orig_cd = self.patches[pn].data
193         cd = orig_cd.set_committer(None)
194         s = ['', ' (empty)'][cd.is_nochange()]
195         oldparent = cd.parent
196         cd = cd.set_parent(self.__head)
197         base = oldparent.data.tree
198         ours = cd.parent.data.tree
199         theirs = cd.tree
200         tree = self.__stack.repository.simple_merge(base, ours, theirs)
201         merge_conflict = False
202         if not tree:
203             if iw == None:
204                 self.__halt('%s does not apply cleanly' % pn)
205             try:
206                 self.__checkout(ours, iw)
207             except git.CheckoutException:
208                 self.__halt('Index/worktree dirty')
209             try:
210                 iw.merge(base, ours, theirs)
211                 tree = iw.index.write_tree()
212                 self.__current_tree = tree
213                 s = ' (modified)'
214             except git.MergeConflictException:
215                 tree = ours
216                 merge_conflict = True
217                 s = ' (conflict)'
218             except git.MergeException, e:
219                 self.__halt(str(e))
220         cd = cd.set_tree(tree)
221         if any(getattr(cd, a) != getattr(orig_cd, a) for a in
222                ['parent', 'tree', 'author', 'message']):
223             self.patches[pn] = self.__stack.repository.commit(cd)
224         else:
225             s = ' (unmodified)'
226         del self.unapplied[self.unapplied.index(pn)]
227         self.applied.append(pn)
228         out.info('Pushed %s%s' % (pn, s))
229         if merge_conflict:
230             # We've just caused conflicts, so we must allow them in
231             # the final checkout.
232             self.__allow_conflicts = lambda trans: True
233
234             self.__halt('Merge conflict')