chiark / gitweb /
663d393f1860debbe38b64ae7cc2106fddc87a18
[stgit] / stgit / lib / transaction.py
1 from stgit import exception
2 from stgit.out import *
3 from stgit.lib import git
4
class TransactionException(exception.StgException):
    """Base error for stack transactions; raised directly when a
    transaction has to be aborted and all its changes rolled back."""
7
class TransactionHalted(TransactionException):
    """Raised by StackTransaction when an operation halts the transaction
    part-way; the exception message is the reason for halting."""
10
11 def _print_current_patch(old_applied, new_applied):
12     def now_at(pn):
13         out.info('Now at patch "%s"' % pn)
14     if not old_applied and not new_applied:
15         pass
16     elif not old_applied:
17         now_at(new_applied[-1])
18     elif not new_applied:
19         out.info('No patch applied')
20     elif old_applied[-1] == new_applied[-1]:
21         pass
22     else:
23         now_at(new_applied[-1])
24
25 class _TransPatchMap(dict):
26     def __init__(self, stack):
27         dict.__init__(self)
28         self.__stack = stack
29     def __getitem__(self, pn):
30         try:
31             return dict.__getitem__(self, pn)
32         except KeyError:
33             return self.__stack.patches.get(pn).commit
34
class StackTransaction(object):
    """Collect changes to a stack's patches and apply them in one go.

    Patch commits and the applied/unapplied order are edited in memory
    through this object; the stack itself is only written by run(), which
    checks out the new head tree (when index+worktree are given), moves
    the branch head, and then stores the patch commits and patch order.
    """

    def __init__(self, stack, msg):
        self.__stack = stack
        self.__msg = msg
        # Pending patch commits. Names not set here fall through to the
        # stack's current commits; a value of None marks a deletion.
        self.__patches = _TransPatchMap(stack)
        self.__applied = list(self.__stack.patchorder.applied)
        self.__unapplied = list(self.__stack.patchorder.unapplied)
        # Message recorded by __halt(); reported by run().
        self.__error = None
        # Tree we believe index+worktree currently hold.
        self.__current_tree = self.__stack.head.data.tree

    stack = property(lambda self: self.__stack)
    patches = property(lambda self: self.__patches)

    def __set_applied(self, val):
        self.__applied = list(val)
    applied = property(lambda self: self.__applied, __set_applied)

    def __set_unapplied(self, val):
        self.__unapplied = list(val)
    unapplied = property(lambda self: self.__unapplied, __set_unapplied)

    def __checkout(self, tree, iw):
        """Make index+worktree hold `tree`.

        Aborts the transaction if HEAD and the stack top differ. When
        `tree` equals the tree we believe is checked out, nothing is
        touched and `iw` may be None.
        """
        if not self.__stack.head_top_equal():
            out.error(
                'HEAD and top are not the same.',
                'This can happen if you modify a branch with git.',
                '"stg repair --help" explains more about what to do next.')
            self.__abort()
        if self.__current_tree != tree:
            assert iw is not None
            iw.checkout(self.__current_tree, tree)
            self.__current_tree = tree

    @staticmethod
    def __abort():
        """Raise TransactionException to abort the whole transaction."""
        raise TransactionException(
            'Command aborted (all changes rolled back)')

    def __check_consistency(self):
        """Assert that every pending deletion names an existing patch and
        every pending commit belongs to an applied or unapplied patch."""
        remaining = set(self.__applied + self.__unapplied)
        for pn, commit in self.__patches.iteritems():
            if commit is None:
                assert self.__stack.patches.exists(pn)
            else:
                assert pn in remaining

    @property
    def __head(self):
        """Commit that should become the branch head: the topmost applied
        patch's commit, or the stack base when nothing is applied."""
        if self.__applied:
            return self.__patches[self.__applied[-1]]
        return self.__stack.base

    def abort(self, iw = None):
        """Roll index+worktree back to the tree of the stack's head.

        That is the only state to restore, since the stack itself is only
        written by run(). Does nothing when `iw` is None.
        """
        if iw is not None:
            self.__checkout(self.__stack.head.data.tree, iw)

    def run(self, iw = None):
        """Write the transaction to the stack.

        Checks out the new head tree into `iw`; if that raises
        git.CheckoutException, index+worktree are rolled back and
        TransactionException is raised. Otherwise moves the branch head,
        reports any message recorded by a halt, writes (or deletes) patch
        commits, and stores the new applied/unapplied order.
        """
        self.__check_consistency()
        new_head = self.__head

        # Set branch head.
        try:
            self.__checkout(new_head.data.tree, iw)
        except git.CheckoutException:
            # We have to abort the transaction.
            self.abort(iw)
            self.__abort()
        self.__stack.set_head(new_head, self.__msg)

        if self.__error:
            out.error(self.__error)

        # Write patches.
        for pn, commit in self.__patches.iteritems():
            if self.__stack.patches.exists(pn):
                p = self.__stack.patches.get(pn)
                if commit is None:
                    p.delete()
                else:
                    p.set_commit(commit, self.__msg)
            else:
                self.__stack.patches.new(pn, commit, self.__msg)
        _print_current_patch(self.__stack.patchorder.applied, self.__applied)
        self.__stack.patchorder.applied = self.__applied
        self.__stack.patchorder.unapplied = self.__unapplied

    def __halt(self, msg):
        """Record `msg` for run() to report, then raise TransactionHalted."""
        self.__error = msg
        raise TransactionHalted(msg)

    @staticmethod
    def __print_popped(popped):
        """Report popped patches; `popped` is bottom-first, as in the
        applied list, so the message shows top -- bottom."""
        if len(popped) == 1:
            out.info('Popped %s' % popped[0])
        elif popped:
            out.info('Popped %s -- %s' % (popped[-1], popped[0]))

    def __pop_applied_from_first(self, p):
        """Cut the applied list at the lowest patch pn with p(pn) true.

        Returns the removed tail (that patch and every patch above it,
        bottom-first), or [] when no applied patch matches.
        """
        for i, pn in enumerate(self.__applied):
            if p(pn):
                popped = self.__applied[i:]
                del self.__applied[i:]
                return popped
        return []

    def pop_patches(self, p):
        """Pop all patches pn for which p(pn) is true. Return the list of
        other patches that had to be popped to accomplish this."""
        popped = self.__pop_applied_from_first(p)
        # Patches popped only because they sat above a target go first in
        # the unapplied list, followed by the targets themselves.
        popped1 = [pn for pn in popped if not p(pn)]
        popped2 = [pn for pn in popped if p(pn)]
        self.unapplied = popped1 + popped2 + self.unapplied
        self.__print_popped(popped)
        return popped1

    def delete_patches(self, p):
        """Delete all patches pn for which p(pn) is true. Return the list of
        other patches that had to be popped to accomplish this."""
        # Snapshot before the lists are modified, so every target is seen.
        all_patches = self.applied + self.unapplied
        popped = [pn for pn in self.__pop_applied_from_first(p)
                  if not p(pn)]
        self.unapplied = popped + [pn for pn in self.unapplied if not p(pn)]
        self.__print_popped(popped)
        for pn in all_patches:
            if p(pn):
                s = ['', ' (empty)'][self.patches[pn].data.is_nochange()]
                # None marks the patch for deletion when run() is called.
                self.patches[pn] = None
                out.info('Deleted %s%s' % (pn, s))
        return popped

    def push_patch(self, pn, iw = None):
        """Attempt to push the named patch. If this results in conflicts,
        halts the transaction. If index+worktree are given, spill any
        conflicts to them."""
        i = self.unapplied.index(pn)
        cd = self.patches[pn].data
        s = ['', ' (empty)'][cd.is_nochange()]
        oldparent = cd.parent
        cd = cd.set_parent(self.__head)
        base = oldparent.data.tree
        ours = cd.parent.data.tree
        theirs = cd.tree
        tree = self.__stack.repository.simple_merge(base, ours, theirs)
        merge_conflict = False
        if not tree:
            # simple_merge produced no tree; fall back to merging in
            # index+worktree, which requires them to be given.
            if iw is None:
                self.__halt('%s does not apply cleanly' % pn)
            try:
                self.__checkout(ours, iw)
            except git.CheckoutException:
                self.__halt('Index/worktree dirty')
            try:
                iw.merge(base, ours, theirs)
                tree = iw.index.write_tree()
                self.__current_tree = tree
                s = ' (modified)'
            except git.MergeException:
                # Store the patch with the new parent's tree; the
                # conflicts are left for the user in index+worktree.
                tree = ours
                merge_conflict = True
                s = ' (conflict)'
        cd = cd.set_tree(tree)
        self.patches[pn] = self.__stack.repository.commit(cd)
        del self.unapplied[i]
        self.applied.append(pn)
        out.info('Pushed %s%s' % (pn, s))
        if merge_conflict:
            self.__halt('Merge conflict')