1 from stgit import exception, utils
2 from stgit.utils import any, all
3 from stgit.out import *
4 from stgit.lib import git
6 class TransactionException(exception.StgException):
9 class TransactionHalted(TransactionException):
12 def _print_current_patch(old_applied, new_applied):
14 out.info('Now at patch "%s"' % pn)
15 if not old_applied and not new_applied:
18 now_at(new_applied[-1])
20 out.info('No patch applied')
21 elif old_applied[-1] == new_applied[-1]:
24 now_at(new_applied[-1])
26 class _TransPatchMap(dict):
27 def __init__(self, stack):
30 def __getitem__(self, pn):
32 return dict.__getitem__(self, pn)
34 return self.__stack.patches.get(pn).commit
36 class StackTransaction(object):
37 def __init__(self, stack, msg, allow_conflicts = False):
40 self.__patches = _TransPatchMap(stack)
41 self.__applied = list(self.__stack.patchorder.applied)
42 self.__unapplied = list(self.__stack.patchorder.unapplied)
44 self.__current_tree = self.__stack.head.data.tree
45 self.__base = self.__stack.base
46 if isinstance(allow_conflicts, bool):
47 self.__allow_conflicts = lambda trans: allow_conflicts
49 self.__allow_conflicts = allow_conflicts
50 stack = property(lambda self: self.__stack)
51 patches = property(lambda self: self.__patches)
52 def __set_applied(self, val):
53 self.__applied = list(val)
54 applied = property(lambda self: self.__applied, __set_applied)
55 def __set_unapplied(self, val):
56 self.__unapplied = list(val)
57 unapplied = property(lambda self: self.__unapplied, __set_unapplied)
58 def __set_base(self, val):
59 assert (not self.__applied
60 or self.patches[self.applied[0]].data.parent == val)
62 base = property(lambda self: self.__base, __set_base)
63 def __checkout(self, tree, iw):
64 if not self.__stack.head_top_equal():
66 'HEAD and top are not the same.',
67 'This can happen if you modify a branch with git.',
68 '"stg repair --help" explains more about what to do next.')
70 if self.__current_tree == tree:
71 # No tree change, but we still want to make sure that
72 # there are no unresolved conflicts. Conflicts
73 # conceptually "belong" to the topmost patch, and just
74 # carrying them along to another patch is confusing.
75 if (self.__allow_conflicts(self) or iw == None
76 or not iw.index.conflicts()):
78 out.error('Need to resolve conflicts first')
81 iw.checkout(self.__current_tree, tree)
82 self.__current_tree = tree
85 raise TransactionException(
86 'Command aborted (all changes rolled back)')
87 def __check_consistency(self):
88 remaining = set(self.__applied + self.__unapplied)
89 for pn, commit in self.__patches.iteritems():
91 assert self.__stack.patches.exists(pn)
93 assert pn in remaining
97 return self.__patches[self.__applied[-1]]
100 def abort(self, iw = None):
101 # The only state we need to restore is index+worktree.
103 self.__checkout(self.__stack.head.data.tree, iw)
104 def run(self, iw = None):
105 self.__check_consistency()
106 new_head = self.__head
111 self.__checkout(new_head.data.tree, iw)
112 except git.CheckoutException:
113 # We have to abort the transaction.
116 self.__stack.set_head(new_head, self.__msg)
119 out.error(self.__error)
122 for pn, commit in self.__patches.iteritems():
123 if self.__stack.patches.exists(pn):
124 p = self.__stack.patches.get(pn)
128 p.set_commit(commit, self.__msg)
130 self.__stack.patches.new(pn, commit, self.__msg)
131 _print_current_patch(self.__stack.patchorder.applied, self.__applied)
132 self.__stack.patchorder.applied = self.__applied
133 self.__stack.patchorder.unapplied = self.__unapplied
136 return utils.STGIT_CONFLICT
138 return utils.STGIT_SUCCESS
140 def __halt(self, msg):
142 raise TransactionHalted(msg)
145 def __print_popped(popped):
148 elif len(popped) == 1:
149 out.info('Popped %s' % popped[0])
151 out.info('Popped %s -- %s' % (popped[-1], popped[0]))
153 def pop_patches(self, p):
154 """Pop all patches pn for which p(pn) is true. Return the list of
155 other patches that had to be popped to accomplish this."""
157 for i in xrange(len(self.applied)):
158 if p(self.applied[i]):
159 popped = self.applied[i:]
162 popped1 = [pn for pn in popped if not p(pn)]
163 popped2 = [pn for pn in popped if p(pn)]
164 self.unapplied = popped1 + popped2 + self.unapplied
165 self.__print_popped(popped)
168 def delete_patches(self, p):
169 """Delete all patches pn for which p(pn) is true. Return the list of
170 other patches that had to be popped to accomplish this."""
172 all_patches = self.applied + self.unapplied
173 for i in xrange(len(self.applied)):
174 if p(self.applied[i]):
175 popped = self.applied[i:]
178 popped = [pn for pn in popped if not p(pn)]
179 self.unapplied = popped + [pn for pn in self.unapplied if not p(pn)]
180 self.__print_popped(popped)
181 for pn in all_patches:
183 s = ['', ' (empty)'][self.patches[pn].data.is_nochange()]
184 self.patches[pn] = None
185 out.info('Deleted %s%s' % (pn, s))
188 def push_patch(self, pn, iw = None):
189 """Attempt to push the named patch. If this results in conflicts,
190 halts the transaction. If index+worktree are given, spill any
191 conflicts to them."""
192 orig_cd = self.patches[pn].data
193 cd = orig_cd.set_committer(None)
194 s = ['', ' (empty)'][cd.is_nochange()]
195 oldparent = cd.parent
196 cd = cd.set_parent(self.__head)
197 base = oldparent.data.tree
198 ours = cd.parent.data.tree
200 tree = self.__stack.repository.simple_merge(base, ours, theirs)
201 merge_conflict = False
204 self.__halt('%s does not apply cleanly' % pn)
206 self.__checkout(ours, iw)
207 except git.CheckoutException:
208 self.__halt('Index/worktree dirty')
210 iw.merge(base, ours, theirs)
211 tree = iw.index.write_tree()
212 self.__current_tree = tree
214 except git.MergeConflictException:
216 merge_conflict = True
218 except git.MergeException, e:
220 cd = cd.set_tree(tree)
221 if any(getattr(cd, a) != getattr(orig_cd, a) for a in
222 ['parent', 'tree', 'author', 'message']):
223 self.patches[pn] = self.__stack.repository.commit(cd)
226 del self.unapplied[self.unapplied.index(pn)]
227 self.applied.append(pn)
228 out.info('Pushed %s%s' % (pn, s))
230 # We've just caused conflicts, so we must allow them in
231 # the final checkout.
232 self.__allow_conflicts = lambda trans: True
234 self.__halt('Merge conflict')