chiark / gitweb /
Split git.merge into two functions
[stgit] / stgit / stack.py
1 """Basic quilt-like functionality
2 """
3
4 __copyright__ = """
5 Copyright (C) 2005, Catalin Marinas <catalin.marinas@gmail.com>
6
7 This program is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License version 2 as
9 published by the Free Software Foundation.
10
11 This program is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with this program; if not, write to the Free Software
18 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
19 """
20
21 import sys, os, re
22 from email.Utils import formatdate
23
24 from stgit.exception import *
25 from stgit.utils import *
26 from stgit.out import *
27 from stgit.run import *
28 from stgit import git, basedir, templates
29 from stgit.config import config
30 from shutil import copyfile
31
32
33 # stack exception class
class StackException(StgException):
    """Error raised by the stack module for series and patch operations."""
36
class FilterUntil:
    """Stateful line predicate for use in list comprehensions.

    Each call answers "keep this line?".  Lines beginning with ``prefix``
    are rejected.  As soon as ``until_test`` accepts a line, that line and
    every later one are rejected too.
    """
    def __init__(self):
        # Latches to False for good once the "until" line has been seen.
        self.should_print = True

    def __call__(self, x, until_test, prefix):
        if until_test(x):
            self.should_print = False
        return self.should_print and x[:len(prefix)] != prefix
46
47 #
48 # Functions
49 #
# Lines of the edited message file that start with this prefix are status
# comments; __clean_comments() drops them after the editor exits.
__comment_prefix = 'STG:'
# A line consisting of exactly this prefix marks where the appended diff
# starts; __clean_comments() drops that line and everything after it.
__patch_prefix = 'STG_PATCH:'
52
def __clean_comments(f):
    """Strip StGIT status lines from the commit message file ``f``.

    Removes every line starting with the comment prefix, the patch marker
    line together with everything after it, and any trailing blank lines,
    then rewrites ``f`` in place with what is left.
    """
    f.seek(0)

    marker = __patch_prefix + '\n'
    def is_marker(text):
        return text == marker

    keep = FilterUntil()
    kept = [line for line in f.readlines()
            if keep(line, is_marker, __comment_prefix)]

    # Trim blank lines left dangling at the end of the message.
    while kept and kept[-1] == '\n':
        kept.pop()

    f.seek(0)
    f.truncate()
    f.writelines(kept)
71
72 # TODO: move this out of the stgit.stack module, it is really for
73 # higher level commands to handle the user interaction
def edit_file(series, line, comment, show_patch = True):
    """Let the user edit a message in an editor and return the result.

    Writes '.stgitmsg.txt' in the current directory.  It starts with
    ``line`` or, when ``line`` is empty, the 'patchdescr.tmpl' template if
    one is found, else a blank line.  Next come ``comment`` and usage hints
    as comment-prefixed lines.  With ``show_patch``, the patch marker line
    and the diff from the bottom of the series' current patch are appended.
    After call_editor() returns, __clean_comments() strips comment lines,
    the marker and everything after it, and trailing blank lines.  The
    remaining text is returned and the file is removed.

    NOTE(review): with show_patch the code assumes a patch is applied
    (series.get_current() is passed straight to get_patch()); confirm
    that callers guarantee this.
    """
    fname = '.stgitmsg.txt'
    tmpl = templates.get_template('patchdescr.tmpl')

    f = file(fname, 'w+')
    if line:
        print >> f, line
    elif tmpl:
        # Trailing comma: print adds no newline after the template text.
        print >> f, tmpl,
    else:
        print >> f
    print >> f, __comment_prefix, comment
    print >> f, __comment_prefix, \
          'Lines prefixed with "%s" will be automatically removed.' \
          % __comment_prefix
    print >> f, __comment_prefix, \
          'Trailing empty lines will be automatically removed.'

    if show_patch:
       print >> f, __patch_prefix
       # The diff after the marker is for reference only; __clean_comments()
       # discards it together with the marker line.
       diff_str = git.diff(rev1 = series.get_patch(series.get_current()).get_bottom())
       f.write(diff_str)

    #Vim modeline must be near the end.
    print >> f, __comment_prefix, 'vi: set textwidth=75 filetype=diff nobackup:'
    f.close()

    call_editor(fname)

    f = file(fname, 'r+')

    # Rewrites the file in place without the STG: lines and the diff.
    __clean_comments(f)
    f.seek(0)
    result = f.read()

    f.close()
    os.remove(fname)

    return result
114
115 #
116 # Classes
117 #
118
class StgitObject:
    """Base for objects whose properties are stored one-per-file inside a
    directory (set with _set_dir()).
    """
    def _set_dir(self, dir):
        self.__dir = dir

    def _dir(self):
        return self.__dir

    def create_empty_field(self, name):
        """Create an empty file for field ``name``."""
        create_empty_file(os.path.join(self.__dir, name))

    def _get_field(self, name, multiline = False):
        """Return the content of field ``name``, or None when the field
        file is missing or empty."""
        path = os.path.join(self.__dir, name)
        if not os.path.isfile(path):
            return None
        value = read_string(path, multiline)
        if value == '':
            return None
        return value

    def _set_field(self, name, value, multiline = False):
        """Write ``value`` to field ``name``.  A false value (None or '')
        removes the field file instead, if it exists."""
        path = os.path.join(self.__dir, name)
        if value:
            write_string(path, value, multiline)
        elif os.path.isfile(path):
            os.remove(path)
147
148
149 class Patch(StgitObject):
150     """Basic patch implementation
151     """
152     def __init_refs(self):
153         self.__top_ref = self.__refs_base + '/' + self.__name
154         self.__log_ref = self.__top_ref + '.log'
155
156     def __init__(self, name, series_dir, refs_base):
157         self.__series_dir = series_dir
158         self.__name = name
159         self._set_dir(os.path.join(self.__series_dir, self.__name))
160         self.__refs_base = refs_base
161         self.__init_refs()
162
163     def create(self):
164         os.mkdir(self._dir())
165
166     def delete(self, keep_log = False):
167         if os.path.isdir(self._dir()):
168             for f in os.listdir(self._dir()):
169                 os.remove(os.path.join(self._dir(), f))
170             os.rmdir(self._dir())
171         else:
172             out.warn('Patch directory "%s" does not exist' % self._dir())
173         try:
174             # the reference might not exist if the repository was corrupted
175             git.delete_ref(self.__top_ref)
176         except git.GitException, e:
177             out.warn(str(e))
178         if not keep_log and git.ref_exists(self.__log_ref):
179             git.delete_ref(self.__log_ref)
180
181     def get_name(self):
182         return self.__name
183
184     def rename(self, newname):
185         olddir = self._dir()
186         old_top_ref = self.__top_ref
187         old_log_ref = self.__log_ref
188         self.__name = newname
189         self._set_dir(os.path.join(self.__series_dir, self.__name))
190         self.__init_refs()
191
192         git.rename_ref(old_top_ref, self.__top_ref)
193         if git.ref_exists(old_log_ref):
194             git.rename_ref(old_log_ref, self.__log_ref)
195         os.rename(olddir, self._dir())
196
197     def __update_top_ref(self, ref):
198         git.set_ref(self.__top_ref, ref)
199         self._set_field('top', ref)
200         self._set_field('bottom', git.get_commit(ref).get_parent())
201
202     def __update_log_ref(self, ref):
203         git.set_ref(self.__log_ref, ref)
204
205     def get_old_bottom(self):
206         return git.get_commit(self.get_old_top()).get_parent()
207
208     def get_bottom(self):
209         return git.get_commit(self.get_top()).get_parent()
210
211     def get_old_top(self):
212         return self._get_field('top.old')
213
214     def get_top(self):
215         return git.rev_parse(self.__top_ref)
216
217     def set_top(self, value, backup = False):
218         if backup:
219             curr_top = self.get_top()
220             self._set_field('top.old', curr_top)
221             self._set_field('bottom.old', git.get_commit(curr_top).get_parent())
222         self.__update_top_ref(value)
223
224     def restore_old_boundaries(self):
225         top = self._get_field('top.old')
226
227         if top:
228             self.__update_top_ref(top)
229             return True
230         else:
231             return False
232
233     def get_description(self):
234         return self._get_field('description', True)
235
236     def set_description(self, line):
237         self._set_field('description', line, True)
238
239     def get_authname(self):
240         return self._get_field('authname')
241
242     def set_authname(self, name):
243         self._set_field('authname', name or git.author().name)
244
245     def get_authemail(self):
246         return self._get_field('authemail')
247
248     def set_authemail(self, email):
249         self._set_field('authemail', email or git.author().email)
250
251     def get_authdate(self):
252         date = self._get_field('authdate')
253         if not date:
254             return date
255
256         if re.match('[0-9]+\s+[+-][0-9]+', date):
257             # Unix time (seconds) + time zone
258             secs_tz = date.split()
259             date = formatdate(int(secs_tz[0]))[:-5] + secs_tz[1]
260
261         return date
262
263     def set_authdate(self, date):
264         self._set_field('authdate', date or git.author().date)
265
266     def get_commname(self):
267         return self._get_field('commname')
268
269     def set_commname(self, name):
270         self._set_field('commname', name or git.committer().name)
271
272     def get_commemail(self):
273         return self._get_field('commemail')
274
275     def set_commemail(self, email):
276         self._set_field('commemail', email or git.committer().email)
277
278     def get_log(self):
279         return self._get_field('log')
280
281     def set_log(self, value, backup = False):
282         self._set_field('log', value)
283         self.__update_log_ref(value)
284
# The current StGIT metadata format version.  Series.init() records it in
# the branch config, and Series.update_to_current_format_version() upgrades
# older branches step by step until they reach it.
FORMAT_VERSION = 2
287
288 class PatchSet(StgitObject):
289     def __init__(self, name = None):
290         try:
291             if name:
292                 self.set_name (name)
293             else:
294                 self.set_name (git.get_head_file())
295             self.__base_dir = basedir.get()
296         except git.GitException, ex:
297             raise StackException, 'GIT tree not initialised: %s' % ex
298
299         self._set_dir(os.path.join(self.__base_dir, 'patches', self.get_name()))
300
301     def get_name(self):
302         return self.__name
303     def set_name(self, name):
304         self.__name = name
305
306     def _basedir(self):
307         return self.__base_dir
308
309     def get_head(self):
310         """Return the head of the branch
311         """
312         crt = self.get_current_patch()
313         if crt:
314             return crt.get_top()
315         else:
316             return self.get_base()
317
318     def get_protected(self):
319         return os.path.isfile(os.path.join(self._dir(), 'protected'))
320
321     def protect(self):
322         protect_file = os.path.join(self._dir(), 'protected')
323         if not os.path.isfile(protect_file):
324             create_empty_file(protect_file)
325
326     def unprotect(self):
327         protect_file = os.path.join(self._dir(), 'protected')
328         if os.path.isfile(protect_file):
329             os.remove(protect_file)
330
331     def __branch_descr(self):
332         return 'branch.%s.description' % self.get_name()
333
334     def get_description(self):
335         return config.get(self.__branch_descr()) or ''
336
337     def set_description(self, line):
338         if line:
339             config.set(self.__branch_descr(), line)
340         else:
341             config.unset(self.__branch_descr())
342
343     def head_top_equal(self):
344         """Return true if the head and the top are the same
345         """
346         crt = self.get_current_patch()
347         if not crt:
348             # we don't care, no patches applied
349             return True
350         return git.get_head() == crt.get_top()
351
352     def is_initialised(self):
353         """Checks if series is already initialised
354         """
355         return bool(config.get(self.format_version_key()))
356
357
def shortlog(patches):
    """Return 'git shortlog' output summarising ``patches``.

    Each patch's commit range (bottom..top) is logged in short format; the
    concatenated logs are fed to 'git shortlog' on standard input.
    """
    logs = []
    for p in patches:
        logs.append(Run('git', 'log', '--pretty=short',
                        p.get_top(), '^%s' % p.get_bottom()).raw_output())
    return Run('git', 'shortlog').raw_input(''.join(logs)).raw_output()
363
364 class Series(PatchSet):
365     """Class including the operations on series
366     """
367     def __init__(self, name = None):
368         """Takes a series name as the parameter.
369         """
370         PatchSet.__init__(self, name)
371
372         # Update the branch to the latest format version if it is
373         # initialized, but don't touch it if it isn't.
374         self.update_to_current_format_version()
375
376         self.__refs_base = 'refs/patches/%s' % self.get_name()
377
378         self.__applied_file = os.path.join(self._dir(), 'applied')
379         self.__unapplied_file = os.path.join(self._dir(), 'unapplied')
380         self.__hidden_file = os.path.join(self._dir(), 'hidden')
381
382         # where this series keeps its patches
383         self.__patch_dir = os.path.join(self._dir(), 'patches')
384
385         # trash directory
386         self.__trash_dir = os.path.join(self._dir(), 'trash')
387
388     def format_version_key(self):
389         return 'branch.%s.stgit.stackformatversion' % self.get_name()
390
    def update_to_current_format_version(self):
        """Update a potentially older StGIT directory structure to the
        latest version. Note: This function should depend as little as
        possible on external functions that may change during a format
        version bump, since it must remain able to process older formats.

        Upgrades run one step at a time (0 -> 1 -> 2), re-detecting the
        version before each step.  A branch with no StGIT metadata is left
        alone.  Raises StackException if the branch ends up at any version
        other than FORMAT_VERSION (i.e. one this code cannot upgrade).
        """

        branch_dir = os.path.join(self._basedir(), 'patches', self.get_name())
        def get_format_version():
            """Return the integer format version number, or None if the
            branch doesn't have any StGIT metadata at all, of any version."""
            fv = config.get(self.format_version_key())
            ofv = config.get('branch.%s.stgitformatversion' % self.get_name())
            if fv:
                # Great, there's an explicitly recorded format version
                # number, which means that the branch is initialized and
                # of that exact version.
                return int(fv)
            elif ofv:
                # Old name for the version info, upgrade it.  Note that this
                # "getter" rewrites the config as a side effect.
                config.set(self.format_version_key(), ofv)
                config.unset('branch.%s.stgitformatversion' % self.get_name())
                return int(ofv)
            elif os.path.isdir(os.path.join(branch_dir, 'patches')):
                # There's a .git/patches/<branch>/patches directory, which
                # means this is an initialized version 1 branch.
                return 1
            elif os.path.isdir(branch_dir):
                # There's a .git/patches/<branch> directory, which means
                # this is an initialized version 0 branch.
                return 0
            else:
                # The branch doesn't seem to be initialized at all.
                return None
        def set_format_version(v):
            # Record version ``v`` in the config and tell the user.
            out.info('Upgraded branch %s to format version %d' % (self.get_name(), v))
            config.set(self.format_version_key(), '%d' % v)
        def mkdir(d):
            # Create ``d`` (and parents) unless it already exists.
            if not os.path.isdir(d):
                os.makedirs(d)
        def rm(f):
            # Remove ``f`` if present.
            if os.path.exists(f):
                os.remove(f)
        def rm_ref(ref):
            # Delete ``ref`` if present.
            if git.ref_exists(ref):
                git.delete_ref(ref)

        # Update 0 -> 1.  Version 0 kept each patch's directory directly in
        # branch_dir; move every listed patch into patches/ and create its
        # top ref from the patch's 'top' field.
        if get_format_version() == 0:
            mkdir(os.path.join(branch_dir, 'trash'))
            patch_dir = os.path.join(branch_dir, 'patches')
            mkdir(patch_dir)
            refs_base = 'refs/patches/%s' % self.get_name()
            for patch in (file(os.path.join(branch_dir, 'unapplied')).readlines()
                          + file(os.path.join(branch_dir, 'applied')).readlines()):
                patch = patch.strip()
                os.rename(os.path.join(branch_dir, patch),
                          os.path.join(patch_dir, patch))
                topfield = os.path.join(patch_dir, patch, 'top')
                if os.path.isfile(topfield):
                    top = read_string(topfield, False)
                else:
                    top = None
                if top:
                    git.set_ref(refs_base + '/' + patch, top)
            set_format_version(1)

        # Update 1 -> 2.  Move the branch description into the git config
        # and drop the 'current' file and the refs/bases/<branch> ref.
        if get_format_version() == 1:
            desc_file = os.path.join(branch_dir, 'description')
            if os.path.isfile(desc_file):
                desc = read_string(desc_file)
                if desc:
                    config.set('branch.%s.description' % self.get_name(), desc)
                rm(desc_file)
            rm(os.path.join(branch_dir, 'current'))
            rm_ref('refs/bases/%s' % self.get_name())
            set_format_version(2)

        # Make sure we're at the latest version.
        if not get_format_version() in [None, FORMAT_VERSION]:
            raise StackException('Branch %s is at format version %d, expected %d'
                                 % (self.get_name(), get_format_version(), FORMAT_VERSION))
473
474     def __patch_name_valid(self, name):
475         """Raise an exception if the patch name is not valid.
476         """
477         if not name or re.search('[^\w.-]', name):
478             raise StackException, 'Invalid patch name: "%s"' % name
479
480     def get_patch(self, name):
481         """Return a Patch object for the given name
482         """
483         return Patch(name, self.__patch_dir, self.__refs_base)
484
485     def get_current_patch(self):
486         """Return a Patch object representing the topmost patch, or
487         None if there is no such patch."""
488         crt = self.get_current()
489         if not crt:
490             return None
491         return self.get_patch(crt)
492
493     def get_current(self):
494         """Return the name of the topmost patch, or None if there is
495         no such patch."""
496         try:
497             applied = self.get_applied()
498         except StackException:
499             # No "applied" file: branch is not initialized.
500             return None
501         try:
502             return applied[-1]
503         except IndexError:
504             # No patches applied.
505             return None
506
507     def get_applied(self):
508         if not os.path.isfile(self.__applied_file):
509             raise StackException, 'Branch "%s" not initialised' % self.get_name()
510         return read_strings(self.__applied_file)
511
512     def set_applied(self, applied):
513         write_strings(self.__applied_file, applied)
514
515     def get_unapplied(self):
516         if not os.path.isfile(self.__unapplied_file):
517             raise StackException, 'Branch "%s" not initialised' % self.get_name()
518         return read_strings(self.__unapplied_file)
519
520     def set_unapplied(self, unapplied):
521         write_strings(self.__unapplied_file, unapplied)
522
523     def get_hidden(self):
524         if not os.path.isfile(self.__hidden_file):
525             return []
526         return read_strings(self.__hidden_file)
527
528     def get_base(self):
529         # Return the parent of the bottommost patch, if there is one.
530         if os.path.isfile(self.__applied_file):
531             bottommost = file(self.__applied_file).readline().strip()
532             if bottommost:
533                 return self.get_patch(bottommost).get_bottom()
534         # No bottommost patch, so just return HEAD
535         return git.get_head()
536
537     def get_parent_remote(self):
538         value = config.get('branch.%s.remote' % self.get_name())
539         if value:
540             return value
541         elif 'origin' in git.remotes_list():
542             out.note(('No parent remote declared for stack "%s",'
543                       ' defaulting to "origin".' % self.get_name()),
544                      ('Consider setting "branch.%s.remote" and'
545                       ' "branch.%s.merge" with "git config".'
546                       % (self.get_name(), self.get_name())))
547             return 'origin'
548         else:
549             raise StackException, 'Cannot find a parent remote for "%s"' % self.get_name()
550
551     def __set_parent_remote(self, remote):
552         value = config.set('branch.%s.remote' % self.get_name(), remote)
553
554     def get_parent_branch(self):
555         value = config.get('branch.%s.stgit.parentbranch' % self.get_name())
556         if value:
557             return value
558         elif git.rev_parse('heads/origin'):
559             out.note(('No parent branch declared for stack "%s",'
560                       ' defaulting to "heads/origin".' % self.get_name()),
561                      ('Consider setting "branch.%s.stgit.parentbranch"'
562                       ' with "git config".' % self.get_name()))
563             return 'heads/origin'
564         else:
565             raise StackException, 'Cannot find a parent branch for "%s"' % self.get_name()
566
567     def __set_parent_branch(self, name):
568         if config.get('branch.%s.remote' % self.get_name()):
569             # Never set merge if remote is not set to avoid
570             # possibly-erroneous lookups into 'origin'
571             config.set('branch.%s.merge' % self.get_name(), name)
572         config.set('branch.%s.stgit.parentbranch' % self.get_name(), name)
573
574     def set_parent(self, remote, localbranch):
575         if localbranch:
576             if remote:
577                 self.__set_parent_remote(remote)
578             self.__set_parent_branch(localbranch)
579         # We'll enforce this later
580 #         else:
581 #             raise StackException, 'Parent branch (%s) should be specified for %s' % localbranch, self.get_name()
582
583     def __patch_is_current(self, patch):
584         return patch.get_name() == self.get_current()
585
586     def patch_applied(self, name):
587         """Return true if the patch exists in the applied list
588         """
589         return name in self.get_applied()
590
591     def patch_unapplied(self, name):
592         """Return true if the patch exists in the unapplied list
593         """
594         return name in self.get_unapplied()
595
596     def patch_hidden(self, name):
597         """Return true if the patch is hidden.
598         """
599         return name in self.get_hidden()
600
601     def patch_exists(self, name):
602         """Return true if there is a patch with the given name, false
603         otherwise."""
604         return self.patch_applied(name) or self.patch_unapplied(name) \
605                or self.patch_hidden(name)
606
607     def init(self, create_at=False, parent_remote=None, parent_branch=None):
608         """Initialises the stgit series
609         """
610         if self.is_initialised():
611             raise StackException, '%s already initialized' % self.get_name()
612         for d in [self._dir()]:
613             if os.path.exists(d):
614                 raise StackException, '%s already exists' % d
615
616         if (create_at!=False):
617             git.create_branch(self.get_name(), create_at)
618
619         os.makedirs(self.__patch_dir)
620
621         self.set_parent(parent_remote, parent_branch)
622
623         self.create_empty_field('applied')
624         self.create_empty_field('unapplied')
625
626         config.set(self.format_version_key(), str(FORMAT_VERSION))
627
628     def rename(self, to_name):
629         """Renames a series
630         """
631         to_stack = Series(to_name)
632
633         if to_stack.is_initialised():
634             raise StackException, '"%s" already exists' % to_stack.get_name()
635
636         patches = self.get_applied() + self.get_unapplied()
637
638         git.rename_branch(self.get_name(), to_name)
639
640         for patch in patches:
641             git.rename_ref('refs/patches/%s/%s' % (self.get_name(), patch),
642                            'refs/patches/%s/%s' % (to_name, patch))
643             git.rename_ref('refs/patches/%s/%s.log' % (self.get_name(), patch),
644                            'refs/patches/%s/%s.log' % (to_name, patch))
645         if os.path.isdir(self._dir()):
646             rename(os.path.join(self._basedir(), 'patches'),
647                    self.get_name(), to_stack.get_name())
648
649         # Rename the config section
650         for k in ['branch.%s', 'branch.%s.stgit']:
651             config.rename_section(k % self.get_name(), k % to_name)
652
653         self.__init__(to_name)
654
655     def clone(self, target_series):
656         """Clones a series
657         """
658         try:
659             # allow cloning of branches not under StGIT control
660             base = self.get_base()
661         except:
662             base = git.get_head()
663         Series(target_series).init(create_at = base)
664         new_series = Series(target_series)
665
666         # generate an artificial description file
667         new_series.set_description('clone of "%s"' % self.get_name())
668
669         # clone self's entire series as unapplied patches
670         try:
671             # allow cloning of branches not under StGIT control
672             applied = self.get_applied()
673             unapplied = self.get_unapplied()
674             patches = applied + unapplied
675             patches.reverse()
676         except:
677             patches = applied = unapplied = []
678         for p in patches:
679             patch = self.get_patch(p)
680             newpatch = new_series.new_patch(p, message = patch.get_description(),
681                                             can_edit = False, unapplied = True,
682                                             bottom = patch.get_bottom(),
683                                             top = patch.get_top(),
684                                             author_name = patch.get_authname(),
685                                             author_email = patch.get_authemail(),
686                                             author_date = patch.get_authdate())
687             if patch.get_log():
688                 out.info('Setting log to %s' %  patch.get_log())
689                 newpatch.set_log(patch.get_log())
690             else:
691                 out.info('No log for %s' % p)
692
693         # fast forward the cloned series to self's top
694         new_series.forward_patches(applied)
695
696         # Clone parent informations
697         value = config.get('branch.%s.remote' % self.get_name())
698         if value:
699             config.set('branch.%s.remote' % target_series, value)
700
701         value = config.get('branch.%s.merge' % self.get_name())
702         if value:
703             config.set('branch.%s.merge' % target_series, value)
704
705         value = config.get('branch.%s.stgit.parentbranch' % self.get_name())
706         if value:
707             config.set('branch.%s.stgit.parentbranch' % target_series, value)
708
709     def delete(self, force = False):
710         """Deletes an stgit series
711         """
712         if self.is_initialised():
713             patches = self.get_unapplied() + self.get_applied()
714             if not force and patches:
715                 raise StackException, \
716                       'Cannot delete: the series still contains patches'
717             for p in patches:
718                 self.get_patch(p).delete()
719
720             # remove the trash directory if any
721             if os.path.exists(self.__trash_dir):
722                 for fname in os.listdir(self.__trash_dir):
723                     os.remove(os.path.join(self.__trash_dir, fname))
724                 os.rmdir(self.__trash_dir)
725
726             # FIXME: find a way to get rid of those manual removals
727             # (move functionality to StgitObject ?)
728             if os.path.exists(self.__applied_file):
729                 os.remove(self.__applied_file)
730             if os.path.exists(self.__unapplied_file):
731                 os.remove(self.__unapplied_file)
732             if os.path.exists(self.__hidden_file):
733                 os.remove(self.__hidden_file)
734             if os.path.exists(self._dir()+'/orig-base'):
735                 os.remove(self._dir()+'/orig-base')
736
737             if not os.listdir(self.__patch_dir):
738                 os.rmdir(self.__patch_dir)
739             else:
740                 out.warn('Patch directory %s is not empty' % self.__patch_dir)
741
742             try:
743                 os.removedirs(self._dir())
744             except OSError:
745                 raise StackException('Series directory %s is not empty'
746                                      % self._dir())
747
748         try:
749             git.delete_branch(self.get_name())
750         except GitException:
751             out.warn('Could not delete branch "%s"' % self.get_name())
752
753         config.remove_section('branch.%s' % self.get_name())
754         config.remove_section('branch.%s.stgit' % self.get_name())
755
756     def refresh_patch(self, files = None, message = None, edit = False,
757                       show_patch = False,
758                       cache_update = True,
759                       author_name = None, author_email = None,
760                       author_date = None,
761                       committer_name = None, committer_email = None,
762                       backup = True, sign_str = None, log = 'refresh',
763                       notes = None, bottom = None):
764         """Generates a new commit for the topmost patch
765         """
766         patch = self.get_current_patch()
767         if not patch:
768             raise StackException, 'No patches applied'
769
770         descr = patch.get_description()
771         if not (message or descr):
772             edit = True
773             descr = ''
774         elif message:
775             descr = message
776
777         # TODO: move this out of the stgit.stack module, it is really
778         # for higher level commands to handle the user interaction
779         if not message and edit:
780             descr = edit_file(self, descr.rstrip(), \
781                               'Please edit the description for patch "%s" ' \
782                               'above.' % patch.get_name(), show_patch)
783
784         if not author_name:
785             author_name = patch.get_authname()
786         if not author_email:
787             author_email = patch.get_authemail()
788         if not author_date:
789             author_date = patch.get_authdate()
790         if not committer_name:
791             committer_name = patch.get_commname()
792         if not committer_email:
793             committer_email = patch.get_commemail()
794
795         descr = add_sign_line(descr, sign_str, committer_name, committer_email)
796
797         if not bottom:
798             bottom = patch.get_bottom()
799
800         commit_id = git.commit(files = files,
801                                message = descr, parents = [bottom],
802                                cache_update = cache_update,
803                                allowempty = True,
804                                author_name = author_name,
805                                author_email = author_email,
806                                author_date = author_date,
807                                committer_name = committer_name,
808                                committer_email = committer_email)
809
810         patch.set_top(commit_id, backup = backup)
811         patch.set_description(descr)
812         patch.set_authname(author_name)
813         patch.set_authemail(author_email)
814         patch.set_authdate(author_date)
815         patch.set_commname(committer_name)
816         patch.set_commemail(committer_email)
817
818         if log:
819             self.log_patch(patch, log, notes)
820
821         return commit_id
822
823     def undo_refresh(self):
824         """Undo the patch boundaries changes caused by 'refresh'
825         """
826         name = self.get_current()
827         assert(name)
828
829         patch = self.get_patch(name)
830         old_bottom = patch.get_old_bottom()
831         old_top = patch.get_old_top()
832
833         # the bottom of the patch is not changed by refresh. If the
834         # old_bottom is different, there wasn't any previous 'refresh'
835         # command (probably only a 'push')
836         if old_bottom != patch.get_bottom() or old_top == patch.get_top():
837             raise StackException, 'No undo information available'
838
839         git.reset(tree_id = old_top, check_out = False)
840         if patch.restore_old_boundaries():
841             self.log_patch(patch, 'undo')
842
843     def new_patch(self, name, message = None, can_edit = True,
844                   unapplied = False, show_patch = False,
845                   top = None, bottom = None, commit = True,
846                   author_name = None, author_email = None, author_date = None,
847                   committer_name = None, committer_email = None,
848                   before_existing = False, sign_str = None):
849         """Creates a new patch, either pointing to an existing commit object,
850         or by creating a new commit object.
851         """
852
853         assert commit or (top and bottom)
854         assert not before_existing or (top and bottom)
855         assert not (commit and before_existing)
856         assert (top and bottom) or (not top and not bottom)
857         assert commit or (not top or (bottom == git.get_commit(top).get_parent()))
858
859         if name != None:
860             self.__patch_name_valid(name)
861             if self.patch_exists(name):
862                 raise StackException, 'Patch "%s" already exists' % name
863
864         # TODO: move this out of the stgit.stack module, it is really
865         # for higher level commands to handle the user interaction
866         def sign(msg):
867             return add_sign_line(msg, sign_str,
868                                  committer_name or git.committer().name,
869                                  committer_email or git.committer().email)
870         if not message and can_edit:
871             descr = edit_file(
872                 self, sign(''),
873                 'Please enter the description for the patch above.',
874                 show_patch)
875         else:
876             descr = sign(message)
877
878         head = git.get_head()
879
880         if name == None:
881             name = make_patch_name(descr, self.patch_exists)
882
883         patch = self.get_patch(name)
884         patch.create()
885
886         patch.set_description(descr)
887         patch.set_authname(author_name)
888         patch.set_authemail(author_email)
889         patch.set_authdate(author_date)
890         patch.set_commname(committer_name)
891         patch.set_commemail(committer_email)
892
893         if before_existing:
894             insert_string(self.__applied_file, patch.get_name())
895         elif unapplied:
896             patches = [patch.get_name()] + self.get_unapplied()
897             write_strings(self.__unapplied_file, patches)
898             set_head = False
899         else:
900             append_string(self.__applied_file, patch.get_name())
901             set_head = True
902
903         if commit:
904             if top:
905                 top_commit = git.get_commit(top)
906             else:
907                 bottom = head
908                 top_commit = git.get_commit(head)
909
910             # create a commit for the patch (may be empty if top == bottom);
911             # only commit on top of the current branch
912             assert(unapplied or bottom == head)
913             commit_id = git.commit(message = descr, parents = [bottom],
914                                    cache_update = False,
915                                    tree_id = top_commit.get_tree(),
916                                    allowempty = True, set_head = set_head,
917                                    author_name = author_name,
918                                    author_email = author_email,
919                                    author_date = author_date,
920                                    committer_name = committer_name,
921                                    committer_email = committer_email)
922             # set the patch top to the new commit
923             patch.set_top(commit_id)
924         else:
925             patch.set_top(top)
926
927         self.log_patch(patch, 'new')
928
929         return patch
930
931     def delete_patch(self, name, keep_log = False):
932         """Deletes a patch
933         """
934         self.__patch_name_valid(name)
935         patch = self.get_patch(name)
936
937         if self.__patch_is_current(patch):
938             self.pop_patch(name)
939         elif self.patch_applied(name):
940             raise StackException, 'Cannot remove an applied patch, "%s", ' \
941                   'which is not current' % name
942         elif not name in self.get_unapplied():
943             raise StackException, 'Unknown patch "%s"' % name
944
945         # save the commit id to a trash file
946         write_string(os.path.join(self.__trash_dir, name), patch.get_top())
947
948         patch.delete(keep_log = keep_log)
949
950         unapplied = self.get_unapplied()
951         unapplied.remove(name)
952         write_strings(self.__unapplied_file, unapplied)
953
954     def forward_patches(self, names):
955         """Try to fast-forward an array of patches.
956
957         On return, patches in names[0:returned_value] have been pushed on the
958         stack. Apply the rest with push_patch
959         """
960         unapplied = self.get_unapplied()
961
962         forwarded = 0
963         top = git.get_head()
964
965         for name in names:
966             assert(name in unapplied)
967
968             patch = self.get_patch(name)
969
970             head = top
971             bottom = patch.get_bottom()
972             top = patch.get_top()
973
974             # top != bottom always since we have a commit for each patch
975             if head == bottom:
976                 # reset the backup information. No logging since the
977                 # patch hasn't changed
978                 patch.set_top(top, backup = True)
979
980             else:
981                 head_tree = git.get_commit(head).get_tree()
982                 bottom_tree = git.get_commit(bottom).get_tree()
983                 if head_tree == bottom_tree:
984                     # We must just reparent this patch and create a new commit
985                     # for it
986                     descr = patch.get_description()
987                     author_name = patch.get_authname()
988                     author_email = patch.get_authemail()
989                     author_date = patch.get_authdate()
990                     committer_name = patch.get_commname()
991                     committer_email = patch.get_commemail()
992
993                     top_tree = git.get_commit(top).get_tree()
994
995                     top = git.commit(message = descr, parents = [head],
996                                      cache_update = False,
997                                      tree_id = top_tree,
998                                      allowempty = True,
999                                      author_name = author_name,
1000                                      author_email = author_email,
1001                                      author_date = author_date,
1002                                      committer_name = committer_name,
1003                                      committer_email = committer_email)
1004
1005                     patch.set_top(top, backup = True)
1006
1007                     self.log_patch(patch, 'push(f)')
1008                 else:
1009                     top = head
1010                     # stop the fast-forwarding, must do a real merge
1011                     break
1012
1013             forwarded+=1
1014             unapplied.remove(name)
1015
1016         if forwarded == 0:
1017             return 0
1018
1019         git.switch(top)
1020
1021         append_strings(self.__applied_file, names[0:forwarded])
1022         write_strings(self.__unapplied_file, unapplied)
1023
1024         return forwarded
1025
1026     def merged_patches(self, names):
1027         """Test which patches were merged upstream by reverse-applying
1028         them in reverse order. The function returns the list of
1029         patches detected to have been applied. The state of the tree
1030         is restored to the original one
1031         """
1032         patches = [self.get_patch(name) for name in names]
1033         patches.reverse()
1034
1035         merged = []
1036         for p in patches:
1037             if git.apply_diff(p.get_top(), p.get_bottom()):
1038                 merged.append(p.get_name())
1039         merged.reverse()
1040
1041         git.reset()
1042
1043         return merged
1044
1045     def push_empty_patch(self, name):
1046         """Pushes an empty patch on the stack
1047         """
1048         unapplied = self.get_unapplied()
1049         assert(name in unapplied)
1050
1051         # patch = self.get_patch(name)
1052         head = git.get_head()
1053
1054         append_string(self.__applied_file, name)
1055
1056         unapplied.remove(name)
1057         write_strings(self.__unapplied_file, unapplied)
1058
1059         self.refresh_patch(bottom = head, cache_update = False, log = 'push(m)')
1060
1061     def push_patch(self, name):
1062         """Pushes a patch on the stack
1063         """
1064         unapplied = self.get_unapplied()
1065         assert(name in unapplied)
1066
1067         patch = self.get_patch(name)
1068
1069         head = git.get_head()
1070         bottom = patch.get_bottom()
1071         top = patch.get_top()
1072         # top != bottom always since we have a commit for each patch
1073
1074         if head == bottom:
1075             # A fast-forward push. Just reset the backup
1076             # information. No need for logging
1077             patch.set_top(top, backup = True)
1078
1079             git.switch(top)
1080             append_string(self.__applied_file, name)
1081
1082             unapplied.remove(name)
1083             write_strings(self.__unapplied_file, unapplied)
1084             return False
1085
1086         # Need to create a new commit an merge in the old patch
1087         ex = None
1088         modified = False
1089
1090         # Try the fast applying first. If this fails, fall back to the
1091         # three-way merge
1092         if not git.apply_diff(bottom, top):
1093             # if git.apply_diff() fails, the patch requires a diff3
1094             # merge and can be reported as modified
1095             modified = True
1096
1097             # merge can fail but the patch needs to be pushed
1098             try:
1099                 git.merge_recursive(bottom, head, top)
1100             except git.GitException, ex:
1101                 out.error('The merge failed during "push".',
1102                           'Use "refresh" after fixing the conflicts or'
1103                           ' revert the operation with "push --undo".')
1104
1105         append_string(self.__applied_file, name)
1106
1107         unapplied.remove(name)
1108         write_strings(self.__unapplied_file, unapplied)
1109
1110         if not ex:
1111             # if the merge was OK and no conflicts, just refresh the patch
1112             # The GIT cache was already updated by the merge operation
1113             if modified:
1114                 log = 'push(m)'
1115             else:
1116                 log = 'push'
1117             self.refresh_patch(bottom = head, cache_update = False, log = log)
1118         else:
1119             # we store the correctly merged files only for
1120             # tracking the conflict history. Note that the
1121             # git.merge() operations should always leave the index
1122             # in a valid state (i.e. only stage 0 files)
1123             self.refresh_patch(bottom = head, cache_update = False,
1124                                log = 'push(c)')
1125             raise StackException, str(ex)
1126
1127         return modified
1128
1129     def undo_push(self):
1130         name = self.get_current()
1131         assert(name)
1132
1133         patch = self.get_patch(name)
1134         old_bottom = patch.get_old_bottom()
1135         old_top = patch.get_old_top()
1136
1137         # the top of the patch is changed by a push operation only
1138         # together with the bottom (otherwise the top was probably
1139         # modified by 'refresh'). If they are both unchanged, there
1140         # was a fast forward
1141         if old_bottom == patch.get_bottom() and old_top != patch.get_top():
1142             raise StackException, 'No undo information available'
1143
1144         git.reset()
1145         self.pop_patch(name)
1146         ret = patch.restore_old_boundaries()
1147         if ret:
1148             self.log_patch(patch, 'undo')
1149
1150         return ret
1151
1152     def pop_patch(self, name, keep = False):
1153         """Pops the top patch from the stack
1154         """
1155         applied = self.get_applied()
1156         applied.reverse()
1157         assert(name in applied)
1158
1159         patch = self.get_patch(name)
1160
1161         if git.get_head_file() == self.get_name():
1162             if keep and not git.apply_diff(git.get_head(), patch.get_bottom(),
1163                                            check_index = False):
1164                 raise StackException(
1165                     'Failed to pop patches while preserving the local changes')
1166             git.switch(patch.get_bottom(), keep)
1167         else:
1168             git.set_branch(self.get_name(), patch.get_bottom())
1169
1170         # save the new applied list
1171         idx = applied.index(name) + 1
1172
1173         popped = applied[:idx]
1174         popped.reverse()
1175         unapplied = popped + self.get_unapplied()
1176         write_strings(self.__unapplied_file, unapplied)
1177
1178         del applied[:idx]
1179         applied.reverse()
1180         write_strings(self.__applied_file, applied)
1181
1182     def empty_patch(self, name):
1183         """Returns True if the patch is empty
1184         """
1185         self.__patch_name_valid(name)
1186         patch = self.get_patch(name)
1187         bottom = patch.get_bottom()
1188         top = patch.get_top()
1189
1190         if bottom == top:
1191             return True
1192         elif git.get_commit(top).get_tree() \
1193                  == git.get_commit(bottom).get_tree():
1194             return True
1195
1196         return False
1197
1198     def rename_patch(self, oldname, newname):
1199         self.__patch_name_valid(newname)
1200
1201         applied = self.get_applied()
1202         unapplied = self.get_unapplied()
1203
1204         if oldname == newname:
1205             raise StackException, '"To" name and "from" name are the same'
1206
1207         if newname in applied or newname in unapplied:
1208             raise StackException, 'Patch "%s" already exists' % newname
1209
1210         if oldname in unapplied:
1211             self.get_patch(oldname).rename(newname)
1212             unapplied[unapplied.index(oldname)] = newname
1213             write_strings(self.__unapplied_file, unapplied)
1214         elif oldname in applied:
1215             self.get_patch(oldname).rename(newname)
1216
1217             applied[applied.index(oldname)] = newname
1218             write_strings(self.__applied_file, applied)
1219         else:
1220             raise StackException, 'Unknown patch "%s"' % oldname
1221
1222     def log_patch(self, patch, message, notes = None):
1223         """Generate a log commit for a patch
1224         """
1225         top = git.get_commit(patch.get_top())
1226         old_log = patch.get_log()
1227
1228         if message is None:
1229             # replace the current log entry
1230             if not old_log:
1231                 raise StackException, \
1232                       'No log entry to annotate for patch "%s"' \
1233                       % patch.get_name()
1234             replace = True
1235             log_commit = git.get_commit(old_log)
1236             msg = log_commit.get_log().split('\n')[0]
1237             log_parent = log_commit.get_parent()
1238             if log_parent:
1239                 parents = [log_parent]
1240             else:
1241                 parents = []
1242         else:
1243             # generate a new log entry
1244             replace = False
1245             msg = '%s\t%s' % (message, top.get_id_hash())
1246             if old_log:
1247                 parents = [old_log]
1248             else:
1249                 parents = []
1250
1251         if notes:
1252             msg += '\n\n' + notes
1253
1254         log = git.commit(message = msg, parents = parents,
1255                          cache_update = False, tree_id = top.get_tree(),
1256                          allowempty = True)
1257         patch.set_log(log)
1258
1259     def hide_patch(self, name):
1260         """Add the patch to the hidden list.
1261         """
1262         unapplied = self.get_unapplied()
1263         if name not in unapplied:
1264             # keep the checking order for backward compatibility with
1265             # the old hidden patches functionality
1266             if self.patch_applied(name):
1267                 raise StackException, 'Cannot hide applied patch "%s"' % name
1268             elif self.patch_hidden(name):
1269                 raise StackException, 'Patch "%s" already hidden' % name
1270             else:
1271                 raise StackException, 'Unknown patch "%s"' % name
1272
1273         if not self.patch_hidden(name):
1274             # check needed for backward compatibility with the old
1275             # hidden patches functionality
1276             append_string(self.__hidden_file, name)
1277
1278         unapplied.remove(name)
1279         write_strings(self.__unapplied_file, unapplied)
1280
1281     def unhide_patch(self, name):
1282         """Remove the patch from the hidden list.
1283         """
1284         hidden = self.get_hidden()
1285         if not name in hidden:
1286             if self.patch_applied(name) or self.patch_unapplied(name):
1287                 raise StackException, 'Patch "%s" not hidden' % name
1288             else:
1289                 raise StackException, 'Unknown patch "%s"' % name
1290
1291         hidden.remove(name)
1292         write_strings(self.__hidden_file, hidden)
1293
1294         if not self.patch_applied(name) and not self.patch_unapplied(name):
1295             # check needed for backward compatibility with the old
1296             # hidden patches functionality
1297             append_string(self.__unapplied_file, name)