chiark / gitweb /
2946a67840be7c63fa90a9d5e233b567f30dbe26
[stgit] / stgit / lib / transaction.py
1 from stgit import exception, utils
2 from stgit.out import *
3 from stgit.lib import git
4
class TransactionException(exception.StgException):
    """Error raised by a stack transaction.

    StackTransaction raises this when a command is aborted and all
    changes have been rolled back."""
7
class TransactionHalted(TransactionException):
    """Raised when a transaction has to stop part-way through.

    StackTransaction records the message before raising this (for
    example on a merge conflict); run() later reports that message and
    returns a conflict status."""
10
11 def _print_current_patch(old_applied, new_applied):
12     def now_at(pn):
13         out.info('Now at patch "%s"' % pn)
14     if not old_applied and not new_applied:
15         pass
16     elif not old_applied:
17         now_at(new_applied[-1])
18     elif not new_applied:
19         out.info('No patch applied')
20     elif old_applied[-1] == new_applied[-1]:
21         pass
22     else:
23         now_at(new_applied[-1])
24
25 class _TransPatchMap(dict):
26     def __init__(self, stack):
27         dict.__init__(self)
28         self.__stack = stack
29     def __getitem__(self, pn):
30         try:
31             return dict.__getitem__(self, pn)
32         except KeyError:
33             return self.__stack.patches.get(pn).commit
34
class StackTransaction(object):
    """Collect a series of changes to a patch stack and write them out
    in one go.

    A new transaction starts out with copies of the stack's applied
    and unapplied patch lists. pop_patches(), delete_patches() and
    push_patch() update those copies and the transaction's own
    patch-name -> commit map; the stack's head, patches and patch
    order are only written by run()."""

    def __init__(self, stack, msg):
        """Start a transaction on the given stack.

        msg is passed along to the stack whenever run() updates the
        head or a patch."""
        self.__stack = stack
        self.__msg = msg
        self.__patches = _TransPatchMap(stack)
        self.__applied = list(self.__stack.patchorder.applied)
        self.__unapplied = list(self.__stack.patchorder.unapplied)
        self.__error = None
        # Tree that the index+worktree is currently assumed to hold.
        self.__current_tree = self.__stack.head.data.tree
        self.__base = self.__stack.base

    stack = property(lambda self: self.__stack)
    patches = property(lambda self: self.__patches)

    # The list-valued properties hand out the transaction's own list
    # (so callers may modify it in place); assigning stores a copy.
    def __set_applied(self, val):
        self.__applied = list(val)
    applied = property(lambda self: self.__applied, __set_applied)

    def __set_unapplied(self, val):
        self.__unapplied = list(val)
    unapplied = property(lambda self: self.__unapplied, __set_unapplied)

    def __set_base(self, val):
        # A new base must be the parent of the bottommost applied
        # patch, if there is one.
        assert (not self.__applied
                or self.patches[self.applied[0]].data.parent == val)
        self.__base = val
    base = property(lambda self: self.__base, __set_base)

    def __checkout(self, tree, iw):
        """Make the index+worktree iw hold the given tree.

        Aborts the transaction if the stack's HEAD and top differ.
        Nothing is checked out if tree is already the current tree."""
        if not self.__stack.head_top_equal():
            out.error(
                'HEAD and top are not the same.',
                'This can happen if you modify a branch with git.',
                '"stg repair --help" explains more about what to do next.')
            self.__abort()
        if self.__current_tree != tree:
            assert iw is not None
            iw.checkout(self.__current_tree, tree)
            self.__current_tree = tree

    @staticmethod
    def __abort():
        """Raise the exception that signals an aborted command."""
        raise TransactionException(
            'Command aborted (all changes rolled back)')

    def __check_consistency(self):
        """Assert that the patch map agrees with the patch lists: a
        patch mapped to None must exist in the stack, and any other
        patch in the map must be in the applied or unapplied list."""
        remaining = set(self.__applied + self.__unapplied)
        for pn, commit in self.__patches.items():
            if commit is None:
                assert self.__stack.patches.exists(pn)
            else:
                assert pn in remaining

    @property
    def __head(self):
        """Commit of the topmost applied patch, or the base if no
        patch is applied."""
        if self.__applied:
            return self.__patches[self.__applied[-1]]
        return self.__base

    def abort(self, iw = None):
        """Undo the transaction's effect on the index+worktree iw, if
        one is given, by checking out the stack's head tree again."""
        # The only state we need to restore is index+worktree.
        if iw:
            self.__checkout(self.__stack.head.data.tree, iw)

    def run(self, iw = None):
        """Write the transaction's result to the stack.

        If an index+worktree iw is given, the new head's tree is
        checked out first; if that checkout fails the transaction is
        aborted with a TransactionException. Returns
        utils.STGIT_CONFLICT if the transaction was halted with an
        error message, utils.STGIT_SUCCESS otherwise."""
        self.__check_consistency()
        new_head = self.__head

        # Set branch head.
        if iw:
            try:
                self.__checkout(new_head.data.tree, iw)
            except git.CheckoutException:
                # We have to abort the transaction.
                self.abort(iw)
                self.__abort()
        self.__stack.set_head(new_head, self.__msg)

        if self.__error:
            out.error(self.__error)

        # Write patches.
        for pn, commit in self.__patches.items():
            if self.__stack.patches.exists(pn):
                p = self.__stack.patches.get(pn)
                if commit is None:
                    # None in the map marks a deleted patch.
                    p.delete()
                else:
                    p.set_commit(commit, self.__msg)
            else:
                self.__stack.patches.new(pn, commit, self.__msg)
        _print_current_patch(self.__stack.patchorder.applied, self.__applied)
        self.__stack.patchorder.applied = self.__applied
        self.__stack.patchorder.unapplied = self.__unapplied

        if self.__error:
            return utils.STGIT_CONFLICT
        return utils.STGIT_SUCCESS

    def __halt(self, msg):
        """Record msg as the transaction's error and raise
        TransactionHalted."""
        self.__error = msg
        raise TransactionHalted(msg)

    @staticmethod
    def __print_popped(popped):
        """Tell the user which patches were popped. popped is in stack
        order (topmost last); a range is printed topmost first."""
        if not popped:
            return
        if len(popped) == 1:
            out.info('Popped %s' % popped[0])
        else:
            out.info('Popped %s -- %s' % (popped[-1], popped[0]))

    def __pop_from_first_match(self, p):
        """Remove from the applied list the bottommost patch pn for
        which p(pn) is true, together with every patch above it.
        Return the removed names in stack order, or an empty list if
        no applied patch matches."""
        for i, pn in enumerate(self.applied):
            if p(pn):
                popped = self.applied[i:]
                del self.applied[i:]
                return popped
        return []

    def pop_patches(self, p):
        """Pop all patches pn for which p(pn) is true. Return the list of
        other patches that had to be popped to accomplish this."""
        popped = self.__pop_from_first_match(p)
        popped1 = [pn for pn in popped if not p(pn)]
        popped2 = [pn for pn in popped if p(pn)]
        # Patches popped only because they sat on top of a requested
        # one go first in the unapplied list.
        self.unapplied = popped1 + popped2 + self.unapplied
        self.__print_popped(popped)
        return popped1

    def delete_patches(self, p):
        """Delete all patches pn for which p(pn) is true. Return the list of
        other patches that had to be popped to accomplish this."""
        # Remember every patch name before the lists are modified.
        all_patches = self.applied + self.unapplied
        popped = [pn for pn in self.__pop_from_first_match(p) if not p(pn)]
        self.unapplied = popped + [pn for pn in self.unapplied if not p(pn)]
        self.__print_popped(popped)
        for pn in all_patches:
            if p(pn):
                if self.patches[pn].data.is_nochange():
                    s = ' (empty)'
                else:
                    s = ''
                # None marks the patch for deletion when run() writes
                # the patches.
                self.patches[pn] = None
                out.info('Deleted %s%s' % (pn, s))
        return popped

    def push_patch(self, pn, iw = None):
        """Attempt to push the named patch. If this results in conflicts,
        halts the transaction. If index+worktree are given, spill any
        conflicts to them."""
        cd = self.patches[pn].data
        cd = cd.set_committer(None)
        if cd.is_nochange():
            s = ' (empty)'
        else:
            s = ''
        oldparent = cd.parent
        cd = cd.set_parent(self.__head)
        # Three-way merge: the patch's old parent is the common base,
        # the new parent is "ours" and the patch's tree is "theirs".
        base = oldparent.data.tree
        ours = cd.parent.data.tree
        theirs = cd.tree
        tree = self.__stack.repository.simple_merge(base, ours, theirs)
        merge_conflict = False
        if not tree:
            # The simple merge gave no tree; retry in the
            # index+worktree, if we have one.
            if iw is None:
                self.__halt('%s does not apply cleanly' % pn)
            try:
                self.__checkout(ours, iw)
            except git.CheckoutException:
                self.__halt('Index/worktree dirty')
            try:
                iw.merge(base, ours, theirs)
                tree = iw.index.write_tree()
                self.__current_tree = tree
                s = ' (modified)'
            except git.MergeConflictException:
                # Commit the patch with its new parent's tree, and
                # halt once the push has been recorded below.
                tree = ours
                merge_conflict = True
                s = ' (conflict)'
            except git.MergeException as e:
                self.__halt(str(e))
        cd = cd.set_tree(tree)
        self.patches[pn] = self.__stack.repository.commit(cd)
        self.unapplied.remove(pn)
        self.applied.append(pn)
        out.info('Pushed %s%s' % (pn, s))
        if merge_conflict:
            self.__halt('Merge conflict')