chiark / gitweb /
1ece01ea3e55996d8ccb82351a85b016d58fa865
[stgit] / stgit / lib / transaction.py
1 from stgit import exception, utils
2 from stgit.utils import any, all
3 from stgit.out import *
4 from stgit.lib import git
5
class TransactionException(exception.StgException):
    """Raised when a stack transaction cannot proceed and must be given up.

    StackTransaction raises it with the message 'Command aborted (all
    changes rolled back)' when HEAD and the top patch disagree or when
    checking out the new head fails."""
8
class TransactionHalted(TransactionException):
    """Raised to stop a transaction part-way through (e.g. on a conflict).

    Before raising it, StackTransaction records the message so that a
    subsequent run() can print it and return utils.STGIT_CONFLICT."""
11
def _print_current_patch(old_applied, new_applied):
    """Tell the user which patch is now on top, if that changed.

    Nothing is printed when no patch is applied before or after, or when
    the topmost applied patch is the same. 'No patch applied' is printed
    when patches were applied before but none are now."""
    if not new_applied:
        if old_applied:
            out.info('No patch applied')
        return
    # Some patch is applied now; stay quiet if the top did not move.
    if old_applied and old_applied[-1] == new_applied[-1]:
        return
    out.info('Now at patch "%s"' % new_applied[-1])
25
class _TransPatchMap(dict):
    """Patch-name -> commit mapping for the changes a transaction makes.

    Only names that were explicitly assigned are stored (and iterated).
    A stored value of None marks a patch to be deleted. Looking up a name
    that was never assigned yields that patch's current commit from the
    stack instead."""
    def __init__(self, stack):
        super(_TransPatchMap, self).__init__()
        self.__stack = stack

    def __getitem__(self, pn):
        if dict.__contains__(self, pn):
            return dict.__getitem__(self, pn)
        # Not touched by the transaction: read through to the stack.
        return self.__stack.patches.get(pn).commit
35
class StackTransaction(object):
    """A batch of changes to a patch stack that are written out by run().

    Edits to the patch order (applied/unapplied lists) and new patch
    commits are kept inside the transaction. The branch head, patch refs
    and stored patch order are only updated when run() is called. abort()
    restores the index and worktree if they were touched."""
    def __init__(self, stack, msg):
        self.__stack = stack
        # Message passed along with every ref update made by run().
        self.__msg = msg
        # Pending name -> commit changes; unassigned names read through
        # to the stack.
        self.__patches = _TransPatchMap(stack)
        # Private copies of the patch order that the transaction may edit.
        self.__applied = list(stack.patchorder.applied)
        self.__unapplied = list(stack.patchorder.unapplied)
        # Set by __halt(); makes run() report a conflict.
        self.__error = None
        # Tree last checked out into the index/worktree.
        self.__current_tree = stack.head.data.tree
        self.__base = stack.base
46     stack = property(lambda self: self.__stack)
47     patches = property(lambda self: self.__patches)
48     def __set_applied(self, val):
49         self.__applied = list(val)
50     applied = property(lambda self: self.__applied, __set_applied)
51     def __set_unapplied(self, val):
52         self.__unapplied = list(val)
53     unapplied = property(lambda self: self.__unapplied, __set_unapplied)
54     def __set_base(self, val):
55         assert (not self.__applied
56                 or self.patches[self.applied[0]].data.parent == val)
57         self.__base = val
58     base = property(lambda self: self.__base, __set_base)
59     def __checkout(self, tree, iw):
60         if not self.__stack.head_top_equal():
61             out.error(
62                 'HEAD and top are not the same.',
63                 'This can happen if you modify a branch with git.',
64                 '"stg repair --help" explains more about what to do next.')
65             self.__abort()
66         if self.__current_tree != tree:
67             assert iw != None
68             iw.checkout(self.__current_tree, tree)
69             self.__current_tree = tree
70     @staticmethod
71     def __abort():
72         raise TransactionException(
73             'Command aborted (all changes rolled back)')
74     def __check_consistency(self):
75         remaining = set(self.__applied + self.__unapplied)
76         for pn, commit in self.__patches.iteritems():
77             if commit == None:
78                 assert self.__stack.patches.exists(pn)
79             else:
80                 assert pn in remaining
81     @property
82     def __head(self):
83         if self.__applied:
84             return self.__patches[self.__applied[-1]]
85         else:
86             return self.__base
87     def abort(self, iw = None):
88         # The only state we need to restore is index+worktree.
89         if iw:
90             self.__checkout(self.__stack.head.data.tree, iw)
91     def run(self, iw = None):
92         self.__check_consistency()
93         new_head = self.__head
94
95         # Set branch head.
96         if iw:
97             try:
98                 self.__checkout(new_head.data.tree, iw)
99             except git.CheckoutException:
100                 # We have to abort the transaction.
101                 self.abort(iw)
102                 self.__abort()
103         self.__stack.set_head(new_head, self.__msg)
104
105         if self.__error:
106             out.error(self.__error)
107
108         # Write patches.
109         for pn, commit in self.__patches.iteritems():
110             if self.__stack.patches.exists(pn):
111                 p = self.__stack.patches.get(pn)
112                 if commit == None:
113                     p.delete()
114                 else:
115                     p.set_commit(commit, self.__msg)
116             else:
117                 self.__stack.patches.new(pn, commit, self.__msg)
118         _print_current_patch(self.__stack.patchorder.applied, self.__applied)
119         self.__stack.patchorder.applied = self.__applied
120         self.__stack.patchorder.unapplied = self.__unapplied
121
122         if self.__error:
123             return utils.STGIT_CONFLICT
124         else:
125             return utils.STGIT_SUCCESS
126
127     def __halt(self, msg):
128         self.__error = msg
129         raise TransactionHalted(msg)
130
131     @staticmethod
132     def __print_popped(popped):
133         if len(popped) == 0:
134             pass
135         elif len(popped) == 1:
136             out.info('Popped %s' % popped[0])
137         else:
138             out.info('Popped %s -- %s' % (popped[-1], popped[0]))
139
140     def pop_patches(self, p):
141         """Pop all patches pn for which p(pn) is true. Return the list of
142         other patches that had to be popped to accomplish this."""
143         popped = []
144         for i in xrange(len(self.applied)):
145             if p(self.applied[i]):
146                 popped = self.applied[i:]
147                 del self.applied[i:]
148                 break
149         popped1 = [pn for pn in popped if not p(pn)]
150         popped2 = [pn for pn in popped if p(pn)]
151         self.unapplied = popped1 + popped2 + self.unapplied
152         self.__print_popped(popped)
153         return popped1
154
155     def delete_patches(self, p):
156         """Delete all patches pn for which p(pn) is true. Return the list of
157         other patches that had to be popped to accomplish this."""
158         popped = []
159         all_patches = self.applied + self.unapplied
160         for i in xrange(len(self.applied)):
161             if p(self.applied[i]):
162                 popped = self.applied[i:]
163                 del self.applied[i:]
164                 break
165         popped = [pn for pn in popped if not p(pn)]
166         self.unapplied = popped + [pn for pn in self.unapplied if not p(pn)]
167         self.__print_popped(popped)
168         for pn in all_patches:
169             if p(pn):
170                 s = ['', ' (empty)'][self.patches[pn].data.is_nochange()]
171                 self.patches[pn] = None
172                 out.info('Deleted %s%s' % (pn, s))
173         return popped
174
175     def push_patch(self, pn, iw = None):
176         """Attempt to push the named patch. If this results in conflicts,
177         halts the transaction. If index+worktree are given, spill any
178         conflicts to them."""
179         orig_cd = self.patches[pn].data
180         cd = orig_cd.set_committer(None)
181         s = ['', ' (empty)'][cd.is_nochange()]
182         oldparent = cd.parent
183         cd = cd.set_parent(self.__head)
184         base = oldparent.data.tree
185         ours = cd.parent.data.tree
186         theirs = cd.tree
187         tree = self.__stack.repository.simple_merge(base, ours, theirs)
188         merge_conflict = False
189         if not tree:
190             if iw == None:
191                 self.__halt('%s does not apply cleanly' % pn)
192             try:
193                 self.__checkout(ours, iw)
194             except git.CheckoutException:
195                 self.__halt('Index/worktree dirty')
196             try:
197                 iw.merge(base, ours, theirs)
198                 tree = iw.index.write_tree()
199                 self.__current_tree = tree
200                 s = ' (modified)'
201             except git.MergeConflictException:
202                 tree = ours
203                 merge_conflict = True
204                 s = ' (conflict)'
205             except git.MergeException, e:
206                 self.__halt(str(e))
207         cd = cd.set_tree(tree)
208         if any(getattr(cd, a) != getattr(orig_cd, a) for a in
209                ['parent', 'tree', 'author', 'message']):
210             self.patches[pn] = self.__stack.repository.commit(cd)
211         else:
212             s = ' (unmodified)'
213         del self.unapplied[self.unapplied.index(pn)]
214         self.applied.append(pn)
215         out.info('Pushed %s%s' % (pn, s))
216         if merge_conflict:
217             self.__halt('Merge conflict')