chiark / gitweb /
25de0f177da7339a4c0448cdcd8d79fd27ace026
[stgit] / stgit / stack.py
1 """Basic quilt-like functionality
2 """
3
4 __copyright__ = """
5 Copyright (C) 2005, Catalin Marinas <catalin.marinas@gmail.com>
6
7 This program is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License version 2 as
9 published by the Free Software Foundation.
10
11 This program is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with this program; if not, write to the Free Software
18 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
19 """
20
21 import sys, os, re
22 from email.Utils import formatdate
23
24 from stgit.exception import *
25 from stgit.utils import *
26 from stgit.out import *
27 from stgit.run import *
28 from stgit import git, basedir, templates
29 from stgit.config import config
30 from shutil import copyfile
31
32
33 # stack exception class
class StackException(StgException):
    """Error raised by the stack, series and patch code in this module."""
36
class FilterUntil:
    """Stateful line predicate that accepts lines until a marker is seen

    An instance starts out accepting lines. As soon as until_test()
    returns true for a line, that line and every later one is rejected.
    While still accepting, lines starting with the prefix are rejected.
    """
    def __init__(self):
        self.should_print = True

    def __call__(self, x, until_test, prefix):
        """Return True if the line x must be kept, False otherwise
        """
        if until_test(x):
            # marker reached: nothing is accepted from here on
            self.should_print = False
        if not self.should_print:
            return False
        return not x.startswith(prefix)
46
47 #
48 # Functions
49 #
# prefix of the informational lines written by edit_file(); lines starting
# with it are removed by __clean_comments()
__comment_prefix = 'STG:'
# marker line written before the diff by edit_file(); __clean_comments()
# drops this line and everything following it
__patch_prefix = 'STG_PATCH:'
52
def __clean_comments(f):
    """Strip the status information from an open commit message file

    The file object is rewritten in place without the lines starting
    with the comment prefix, without the patch marker line and whatever
    follows it, and without the empty lines at the end.
    """
    f.seek(0)
    marker = __patch_prefix + '\n'
    keep = FilterUntil()

    # drop the status-prefixed lines and the patch section
    kept = []
    for text in f.readlines():
        if keep(text, lambda t: t == marker, __comment_prefix):
            kept.append(text)

    # drop the empty lines at the end
    while len(kept) != 0 and kept[-1] == '\n':
        kept.pop()

    f.seek(0)
    f.truncate()
    f.writelines(kept)
71
72 # TODO: move this out of the stgit.stack module, it is really for
73 # higher level commands to handle the user interaction
def edit_file(series, line, comment, show_patch = True):
    """Let the user edit a patch description and return the edited text

    A temporary file '.stgitmsg.txt' is created in the current
    directory, opened with call_editor(), cleaned up with
    __clean_comments() and finally removed.

    series     -- the series whose current patch is being described
    line       -- initial text; if empty, the 'patchdescr.tmpl' template
                  is used when available, otherwise an empty line
    comment    -- text shown to the user on a status line
    show_patch -- if true, append the patch marker and the output of
                  git.diff() taken from the bottom of the current patch

    The file objects are closed even if writing, diffing or cleaning
    raises an exception. The temporary file itself is only removed on
    success, so the text is not lost if something fails.
    """
    fname = '.stgitmsg.txt'
    tmpl = templates.get_template('patchdescr.tmpl')

    f = file(fname, 'w+')
    try:
        if line:
            print >> f, line
        elif tmpl:
            print >> f, tmpl,
        else:
            print >> f
        print >> f, __comment_prefix, comment
        print >> f, __comment_prefix, \
              'Lines prefixed with "%s" will be automatically removed.' \
              % __comment_prefix
        print >> f, __comment_prefix, \
              'Trailing empty lines will be automatically removed.'

        if show_patch:
            print >> f, __patch_prefix
            diff_str = git.diff(rev1 = series.get_patch(series.get_current()).get_bottom())
            f.write(diff_str)

        #Vim modeline must be near the end.
        print >> f, __comment_prefix, 'vi: set textwidth=75 filetype=diff nobackup:'
    finally:
        f.close()

    call_editor(fname)

    f = file(fname, 'r+')
    try:
        __clean_comments(f)
        f.seek(0)
        result = f.read()
    finally:
        f.close()
    os.remove(fname)

    return result
114
115 #
116 # Classes
117 #
118
class StgitObject:
    """Base class for objects whose properties are kept as files in a
    directory
    """
    def _set_dir(self, dir):
        self.__dir = dir

    def _dir(self):
        return self.__dir

    def create_empty_field(self, name):
        """Create the property file called name with create_empty_file()
        """
        create_empty_file(os.path.join(self.__dir, name))

    def _get_field(self, name, multiline = False):
        """Return the value stored under name, or None if the property
        file is missing or the value read from it is the empty string
        """
        path = os.path.join(self.__dir, name)
        if not os.path.isfile(path):
            return None
        value = read_string(path, multiline)
        if value == '':
            return None
        return value

    def _set_field(self, name, value, multiline = False):
        """Store value under name; a false value (None, '') removes the
        property file instead, if there is one
        """
        path = os.path.join(self.__dir, name)
        if value:
            write_string(path, value, multiline)
        elif os.path.isfile(path):
            os.remove(path)
147
148
149 class Patch(StgitObject):
150     """Basic patch implementation
151     """
152     def __init_refs(self):
153         self.__top_ref = self.__refs_base + '/' + self.__name
154         self.__log_ref = self.__top_ref + '.log'
155
156     def __init__(self, name, series_dir, refs_base):
157         self.__series_dir = series_dir
158         self.__name = name
159         self._set_dir(os.path.join(self.__series_dir, self.__name))
160         self.__refs_base = refs_base
161         self.__init_refs()
162
163     def create(self):
164         os.mkdir(self._dir())
165         self.create_empty_field('top')
166
167     def delete(self, keep_log = False):
168         if os.path.isdir(self._dir()):
169             for f in os.listdir(self._dir()):
170                 os.remove(os.path.join(self._dir(), f))
171             os.rmdir(self._dir())
172         else:
173             out.warn('Patch directory "%s" does not exist' % self._dir())
174         try:
175             # the reference might not exist if the repository was corrupted
176             git.delete_ref(self.__top_ref)
177         except git.GitException, e:
178             out.warn(str(e))
179         if not keep_log and git.ref_exists(self.__log_ref):
180             git.delete_ref(self.__log_ref)
181
182     def get_name(self):
183         return self.__name
184
185     def rename(self, newname):
186         olddir = self._dir()
187         old_top_ref = self.__top_ref
188         old_log_ref = self.__log_ref
189         self.__name = newname
190         self._set_dir(os.path.join(self.__series_dir, self.__name))
191         self.__init_refs()
192
193         git.rename_ref(old_top_ref, self.__top_ref)
194         if git.ref_exists(old_log_ref):
195             git.rename_ref(old_log_ref, self.__log_ref)
196         os.rename(olddir, self._dir())
197
198     def __update_top_ref(self, ref):
199         git.set_ref(self.__top_ref, ref)
200
201     def __update_log_ref(self, ref):
202         git.set_ref(self.__log_ref, ref)
203
204     def update_top_ref(self):
205         top = self.get_top()
206         if top:
207             self.__update_top_ref(top)
208
209     def get_old_bottom(self):
210         return git.get_commit(self.get_old_top()).get_parent()
211
212     def get_bottom(self):
213         return git.get_commit(self.get_top()).get_parent()
214
215     def get_old_top(self):
216         return self._get_field('top.old')
217
218     def get_top(self):
219         top = self._get_field('top')
220         try:
221             ref = git.rev_parse(self.__top_ref)
222         except:
223             ref = None
224         assert not ref or top == ref
225         return top
226
227     def set_top(self, value, backup = False):
228         if backup:
229             curr = self._get_field('top')
230             self._set_field('top.old', curr)
231         self._set_field('top', value)
232         self.__update_top_ref(value)
233
234     def restore_old_boundaries(self):
235         top = self._get_field('top.old')
236
237         if top:
238             self._set_field('top', top)
239             self.__update_top_ref(top)
240             return True
241         else:
242             return False
243
244     def get_description(self):
245         return self._get_field('description', True)
246
247     def set_description(self, line):
248         self._set_field('description', line, True)
249
250     def get_authname(self):
251         return self._get_field('authname')
252
253     def set_authname(self, name):
254         self._set_field('authname', name or git.author().name)
255
256     def get_authemail(self):
257         return self._get_field('authemail')
258
259     def set_authemail(self, email):
260         self._set_field('authemail', email or git.author().email)
261
262     def get_authdate(self):
263         date = self._get_field('authdate')
264         if not date:
265             return date
266
267         if re.match('[0-9]+\s+[+-][0-9]+', date):
268             # Unix time (seconds) + time zone
269             secs_tz = date.split()
270             date = formatdate(int(secs_tz[0]))[:-5] + secs_tz[1]
271
272         return date
273
274     def set_authdate(self, date):
275         self._set_field('authdate', date or git.author().date)
276
277     def get_commname(self):
278         return self._get_field('commname')
279
280     def set_commname(self, name):
281         self._set_field('commname', name or git.committer().name)
282
283     def get_commemail(self):
284         return self._get_field('commemail')
285
286     def set_commemail(self, email):
287         self._set_field('commemail', email or git.committer().email)
288
289     def get_log(self):
290         return self._get_field('log')
291
292     def set_log(self, value, backup = False):
293         self._set_field('log', value)
294         self.__update_log_ref(value)
295
# The current StGIT metadata format version. It is recorded in the
# config when a stack is initialised, and older branches are upgraded
# to it by Series.update_to_current_format_version().
FORMAT_VERSION = 2
298
299 class PatchSet(StgitObject):
300     def __init__(self, name = None):
301         try:
302             if name:
303                 self.set_name (name)
304             else:
305                 self.set_name (git.get_head_file())
306             self.__base_dir = basedir.get()
307         except git.GitException, ex:
308             raise StackException, 'GIT tree not initialised: %s' % ex
309
310         self._set_dir(os.path.join(self.__base_dir, 'patches', self.get_name()))
311
312     def get_name(self):
313         return self.__name
314     def set_name(self, name):
315         self.__name = name
316
317     def _basedir(self):
318         return self.__base_dir
319
320     def get_head(self):
321         """Return the head of the branch
322         """
323         crt = self.get_current_patch()
324         if crt:
325             return crt.get_top()
326         else:
327             return self.get_base()
328
329     def get_protected(self):
330         return os.path.isfile(os.path.join(self._dir(), 'protected'))
331
332     def protect(self):
333         protect_file = os.path.join(self._dir(), 'protected')
334         if not os.path.isfile(protect_file):
335             create_empty_file(protect_file)
336
337     def unprotect(self):
338         protect_file = os.path.join(self._dir(), 'protected')
339         if os.path.isfile(protect_file):
340             os.remove(protect_file)
341
342     def __branch_descr(self):
343         return 'branch.%s.description' % self.get_name()
344
345     def get_description(self):
346         return config.get(self.__branch_descr()) or ''
347
348     def set_description(self, line):
349         if line:
350             config.set(self.__branch_descr(), line)
351         else:
352             config.unset(self.__branch_descr())
353
354     def head_top_equal(self):
355         """Return true if the head and the top are the same
356         """
357         crt = self.get_current_patch()
358         if not crt:
359             # we don't care, no patches applied
360             return True
361         return git.get_head() == crt.get_top()
362
363     def is_initialised(self):
364         """Checks if series is already initialised
365         """
366         return bool(config.get(self.format_version_key()))
367
368
def shortlog(patches):
    """Return the 'git shortlog' output for the commits of the patches

    The 'git log --pretty=short' output between the bottom and the top
    of each patch is collected and fed to 'git shortlog'.
    """
    chunks = []
    for p in patches:
        chunks.append(Run('git', 'log', '--pretty=short',
                          p.get_top(), '^%s' % p.get_bottom()).raw_output())
    log = ''.join(chunks)
    return Run('git', 'shortlog').raw_input(log).raw_output()
374
375 class Series(PatchSet):
376     """Class including the operations on series
377     """
378     def __init__(self, name = None):
379         """Takes a series name as the parameter.
380         """
381         PatchSet.__init__(self, name)
382
383         # Update the branch to the latest format version if it is
384         # initialized, but don't touch it if it isn't.
385         self.update_to_current_format_version()
386
387         self.__refs_base = 'refs/patches/%s' % self.get_name()
388
389         self.__applied_file = os.path.join(self._dir(), 'applied')
390         self.__unapplied_file = os.path.join(self._dir(), 'unapplied')
391         self.__hidden_file = os.path.join(self._dir(), 'hidden')
392
393         # where this series keeps its patches
394         self.__patch_dir = os.path.join(self._dir(), 'patches')
395
396         # trash directory
397         self.__trash_dir = os.path.join(self._dir(), 'trash')
398
399     def format_version_key(self):
400         return 'branch.%s.stgit.stackformatversion' % self.get_name()
401
402     def update_to_current_format_version(self):
403         """Update a potentially older StGIT directory structure to the
404         latest version. Note: This function should depend as little as
405         possible on external functions that may change during a format
406         version bump, since it must remain able to process older formats."""
407
408         branch_dir = os.path.join(self._basedir(), 'patches', self.get_name())
409         def get_format_version():
410             """Return the integer format version number, or None if the
411             branch doesn't have any StGIT metadata at all, of any version."""
412             fv = config.get(self.format_version_key())
413             ofv = config.get('branch.%s.stgitformatversion' % self.get_name())
414             if fv:
415                 # Great, there's an explicitly recorded format version
416                 # number, which means that the branch is initialized and
417                 # of that exact version.
418                 return int(fv)
419             elif ofv:
420                 # Old name for the version info, upgrade it
421                 config.set(self.format_version_key(), ofv)
422                 config.unset('branch.%s.stgitformatversion' % self.get_name())
423                 return int(ofv)
424             elif os.path.isdir(os.path.join(branch_dir, 'patches')):
425                 # There's a .git/patches/<branch>/patches dirctory, which
426                 # means this is an initialized version 1 branch.
427                 return 1
428             elif os.path.isdir(branch_dir):
429                 # There's a .git/patches/<branch> directory, which means
430                 # this is an initialized version 0 branch.
431                 return 0
432             else:
433                 # The branch doesn't seem to be initialized at all.
434                 return None
435         def set_format_version(v):
436             out.info('Upgraded branch %s to format version %d' % (self.get_name(), v))
437             config.set(self.format_version_key(), '%d' % v)
438         def mkdir(d):
439             if not os.path.isdir(d):
440                 os.makedirs(d)
441         def rm(f):
442             if os.path.exists(f):
443                 os.remove(f)
444         def rm_ref(ref):
445             if git.ref_exists(ref):
446                 git.delete_ref(ref)
447
448         # Update 0 -> 1.
449         if get_format_version() == 0:
450             mkdir(os.path.join(branch_dir, 'trash'))
451             patch_dir = os.path.join(branch_dir, 'patches')
452             mkdir(patch_dir)
453             refs_base = 'refs/patches/%s' % self.get_name()
454             for patch in (file(os.path.join(branch_dir, 'unapplied')).readlines()
455                           + file(os.path.join(branch_dir, 'applied')).readlines()):
456                 patch = patch.strip()
457                 os.rename(os.path.join(branch_dir, patch),
458                           os.path.join(patch_dir, patch))
459                 Patch(patch, patch_dir, refs_base).update_top_ref()
460             set_format_version(1)
461
462         # Update 1 -> 2.
463         if get_format_version() == 1:
464             desc_file = os.path.join(branch_dir, 'description')
465             if os.path.isfile(desc_file):
466                 desc = read_string(desc_file)
467                 if desc:
468                     config.set('branch.%s.description' % self.get_name(), desc)
469                 rm(desc_file)
470             rm(os.path.join(branch_dir, 'current'))
471             rm_ref('refs/bases/%s' % self.get_name())
472             set_format_version(2)
473
474         # Make sure we're at the latest version.
475         if not get_format_version() in [None, FORMAT_VERSION]:
476             raise StackException('Branch %s is at format version %d, expected %d'
477                                  % (self.get_name(), get_format_version(), FORMAT_VERSION))
478
479     def __patch_name_valid(self, name):
480         """Raise an exception if the patch name is not valid.
481         """
482         if not name or re.search('[^\w.-]', name):
483             raise StackException, 'Invalid patch name: "%s"' % name
484
485     def get_patch(self, name):
486         """Return a Patch object for the given name
487         """
488         return Patch(name, self.__patch_dir, self.__refs_base)
489
490     def get_current_patch(self):
491         """Return a Patch object representing the topmost patch, or
492         None if there is no such patch."""
493         crt = self.get_current()
494         if not crt:
495             return None
496         return self.get_patch(crt)
497
498     def get_current(self):
499         """Return the name of the topmost patch, or None if there is
500         no such patch."""
501         try:
502             applied = self.get_applied()
503         except StackException:
504             # No "applied" file: branch is not initialized.
505             return None
506         try:
507             return applied[-1]
508         except IndexError:
509             # No patches applied.
510             return None
511
512     def get_applied(self):
513         if not os.path.isfile(self.__applied_file):
514             raise StackException, 'Branch "%s" not initialised' % self.get_name()
515         return read_strings(self.__applied_file)
516
517     def set_applied(self, applied):
518         write_strings(self.__applied_file, applied)
519
520     def get_unapplied(self):
521         if not os.path.isfile(self.__unapplied_file):
522             raise StackException, 'Branch "%s" not initialised' % self.get_name()
523         return read_strings(self.__unapplied_file)
524
525     def set_unapplied(self, unapplied):
526         write_strings(self.__unapplied_file, unapplied)
527
528     def get_hidden(self):
529         if not os.path.isfile(self.__hidden_file):
530             return []
531         return read_strings(self.__hidden_file)
532
533     def get_base(self):
534         # Return the parent of the bottommost patch, if there is one.
535         if os.path.isfile(self.__applied_file):
536             bottommost = file(self.__applied_file).readline().strip()
537             if bottommost:
538                 return self.get_patch(bottommost).get_bottom()
539         # No bottommost patch, so just return HEAD
540         return git.get_head()
541
542     def get_parent_remote(self):
543         value = config.get('branch.%s.remote' % self.get_name())
544         if value:
545             return value
546         elif 'origin' in git.remotes_list():
547             out.note(('No parent remote declared for stack "%s",'
548                       ' defaulting to "origin".' % self.get_name()),
549                      ('Consider setting "branch.%s.remote" and'
550                       ' "branch.%s.merge" with "git config".'
551                       % (self.get_name(), self.get_name())))
552             return 'origin'
553         else:
554             raise StackException, 'Cannot find a parent remote for "%s"' % self.get_name()
555
556     def __set_parent_remote(self, remote):
557         value = config.set('branch.%s.remote' % self.get_name(), remote)
558
559     def get_parent_branch(self):
560         value = config.get('branch.%s.stgit.parentbranch' % self.get_name())
561         if value:
562             return value
563         elif git.rev_parse('heads/origin'):
564             out.note(('No parent branch declared for stack "%s",'
565                       ' defaulting to "heads/origin".' % self.get_name()),
566                      ('Consider setting "branch.%s.stgit.parentbranch"'
567                       ' with "git config".' % self.get_name()))
568             return 'heads/origin'
569         else:
570             raise StackException, 'Cannot find a parent branch for "%s"' % self.get_name()
571
572     def __set_parent_branch(self, name):
573         if config.get('branch.%s.remote' % self.get_name()):
574             # Never set merge if remote is not set to avoid
575             # possibly-erroneous lookups into 'origin'
576             config.set('branch.%s.merge' % self.get_name(), name)
577         config.set('branch.%s.stgit.parentbranch' % self.get_name(), name)
578
579     def set_parent(self, remote, localbranch):
580         if localbranch:
581             if remote:
582                 self.__set_parent_remote(remote)
583             self.__set_parent_branch(localbranch)
584         # We'll enforce this later
585 #         else:
586 #             raise StackException, 'Parent branch (%s) should be specified for %s' % localbranch, self.get_name()
587
588     def __patch_is_current(self, patch):
589         return patch.get_name() == self.get_current()
590
591     def patch_applied(self, name):
592         """Return true if the patch exists in the applied list
593         """
594         return name in self.get_applied()
595
596     def patch_unapplied(self, name):
597         """Return true if the patch exists in the unapplied list
598         """
599         return name in self.get_unapplied()
600
601     def patch_hidden(self, name):
602         """Return true if the patch is hidden.
603         """
604         return name in self.get_hidden()
605
606     def patch_exists(self, name):
607         """Return true if there is a patch with the given name, false
608         otherwise."""
609         return self.patch_applied(name) or self.patch_unapplied(name) \
610                or self.patch_hidden(name)
611
612     def init(self, create_at=False, parent_remote=None, parent_branch=None):
613         """Initialises the stgit series
614         """
615         if self.is_initialised():
616             raise StackException, '%s already initialized' % self.get_name()
617         for d in [self._dir()]:
618             if os.path.exists(d):
619                 raise StackException, '%s already exists' % d
620
621         if (create_at!=False):
622             git.create_branch(self.get_name(), create_at)
623
624         os.makedirs(self.__patch_dir)
625
626         self.set_parent(parent_remote, parent_branch)
627
628         self.create_empty_field('applied')
629         self.create_empty_field('unapplied')
630
631         config.set(self.format_version_key(), str(FORMAT_VERSION))
632
633     def rename(self, to_name):
634         """Renames a series
635         """
636         to_stack = Series(to_name)
637
638         if to_stack.is_initialised():
639             raise StackException, '"%s" already exists' % to_stack.get_name()
640
641         patches = self.get_applied() + self.get_unapplied()
642
643         git.rename_branch(self.get_name(), to_name)
644
645         for patch in patches:
646             git.rename_ref('refs/patches/%s/%s' % (self.get_name(), patch),
647                            'refs/patches/%s/%s' % (to_name, patch))
648             git.rename_ref('refs/patches/%s/%s.log' % (self.get_name(), patch),
649                            'refs/patches/%s/%s.log' % (to_name, patch))
650         if os.path.isdir(self._dir()):
651             rename(os.path.join(self._basedir(), 'patches'),
652                    self.get_name(), to_stack.get_name())
653
654         # Rename the config section
655         for k in ['branch.%s', 'branch.%s.stgit']:
656             config.rename_section(k % self.get_name(), k % to_name)
657
658         self.__init__(to_name)
659
660     def clone(self, target_series):
661         """Clones a series
662         """
663         try:
664             # allow cloning of branches not under StGIT control
665             base = self.get_base()
666         except:
667             base = git.get_head()
668         Series(target_series).init(create_at = base)
669         new_series = Series(target_series)
670
671         # generate an artificial description file
672         new_series.set_description('clone of "%s"' % self.get_name())
673
674         # clone self's entire series as unapplied patches
675         try:
676             # allow cloning of branches not under StGIT control
677             applied = self.get_applied()
678             unapplied = self.get_unapplied()
679             patches = applied + unapplied
680             patches.reverse()
681         except:
682             patches = applied = unapplied = []
683         for p in patches:
684             patch = self.get_patch(p)
685             newpatch = new_series.new_patch(p, message = patch.get_description(),
686                                             can_edit = False, unapplied = True,
687                                             bottom = patch.get_bottom(),
688                                             top = patch.get_top(),
689                                             author_name = patch.get_authname(),
690                                             author_email = patch.get_authemail(),
691                                             author_date = patch.get_authdate())
692             if patch.get_log():
693                 out.info('Setting log to %s' %  patch.get_log())
694                 newpatch.set_log(patch.get_log())
695             else:
696                 out.info('No log for %s' % p)
697
698         # fast forward the cloned series to self's top
699         new_series.forward_patches(applied)
700
701         # Clone parent informations
702         value = config.get('branch.%s.remote' % self.get_name())
703         if value:
704             config.set('branch.%s.remote' % target_series, value)
705
706         value = config.get('branch.%s.merge' % self.get_name())
707         if value:
708             config.set('branch.%s.merge' % target_series, value)
709
710         value = config.get('branch.%s.stgit.parentbranch' % self.get_name())
711         if value:
712             config.set('branch.%s.stgit.parentbranch' % target_series, value)
713
714     def delete(self, force = False):
715         """Deletes an stgit series
716         """
717         if self.is_initialised():
718             patches = self.get_unapplied() + self.get_applied()
719             if not force and patches:
720                 raise StackException, \
721                       'Cannot delete: the series still contains patches'
722             for p in patches:
723                 self.get_patch(p).delete()
724
725             # remove the trash directory if any
726             if os.path.exists(self.__trash_dir):
727                 for fname in os.listdir(self.__trash_dir):
728                     os.remove(os.path.join(self.__trash_dir, fname))
729                 os.rmdir(self.__trash_dir)
730
731             # FIXME: find a way to get rid of those manual removals
732             # (move functionality to StgitObject ?)
733             if os.path.exists(self.__applied_file):
734                 os.remove(self.__applied_file)
735             if os.path.exists(self.__unapplied_file):
736                 os.remove(self.__unapplied_file)
737             if os.path.exists(self.__hidden_file):
738                 os.remove(self.__hidden_file)
739             if os.path.exists(self._dir()+'/orig-base'):
740                 os.remove(self._dir()+'/orig-base')
741
742             if not os.listdir(self.__patch_dir):
743                 os.rmdir(self.__patch_dir)
744             else:
745                 out.warn('Patch directory %s is not empty' % self.__patch_dir)
746
747             try:
748                 os.removedirs(self._dir())
749             except OSError:
750                 raise StackException('Series directory %s is not empty'
751                                      % self._dir())
752
753         try:
754             git.delete_branch(self.get_name())
755         except GitException:
756             out.warn('Could not delete branch "%s"' % self.get_name())
757
758         config.remove_section('branch.%s' % self.get_name())
759         config.remove_section('branch.%s.stgit' % self.get_name())
760
761     def refresh_patch(self, files = None, message = None, edit = False,
762                       show_patch = False,
763                       cache_update = True,
764                       author_name = None, author_email = None,
765                       author_date = None,
766                       committer_name = None, committer_email = None,
767                       backup = True, sign_str = None, log = 'refresh',
768                       notes = None, bottom = None):
769         """Generates a new commit for the topmost patch
770         """
771         patch = self.get_current_patch()
772         if not patch:
773             raise StackException, 'No patches applied'
774
775         descr = patch.get_description()
776         if not (message or descr):
777             edit = True
778             descr = ''
779         elif message:
780             descr = message
781
782         # TODO: move this out of the stgit.stack module, it is really
783         # for higher level commands to handle the user interaction
784         if not message and edit:
785             descr = edit_file(self, descr.rstrip(), \
786                               'Please edit the description for patch "%s" ' \
787                               'above.' % patch.get_name(), show_patch)
788
789         if not author_name:
790             author_name = patch.get_authname()
791         if not author_email:
792             author_email = patch.get_authemail()
793         if not author_date:
794             author_date = patch.get_authdate()
795         if not committer_name:
796             committer_name = patch.get_commname()
797         if not committer_email:
798             committer_email = patch.get_commemail()
799
800         descr = add_sign_line(descr, sign_str, committer_name, committer_email)
801
802         if not bottom:
803             bottom = patch.get_bottom()
804
805         commit_id = git.commit(files = files,
806                                message = descr, parents = [bottom],
807                                cache_update = cache_update,
808                                allowempty = True,
809                                author_name = author_name,
810                                author_email = author_email,
811                                author_date = author_date,
812                                committer_name = committer_name,
813                                committer_email = committer_email)
814
815         patch.set_top(commit_id, backup = backup)
816         patch.set_description(descr)
817         patch.set_authname(author_name)
818         patch.set_authemail(author_email)
819         patch.set_authdate(author_date)
820         patch.set_commname(committer_name)
821         patch.set_commemail(committer_email)
822
823         if log:
824             self.log_patch(patch, log, notes)
825
826         return commit_id
827
828     def undo_refresh(self):
829         """Undo the patch boundaries changes caused by 'refresh'
830         """
831         name = self.get_current()
832         assert(name)
833
834         patch = self.get_patch(name)
835         old_bottom = patch.get_old_bottom()
836         old_top = patch.get_old_top()
837
838         # the bottom of the patch is not changed by refresh. If the
839         # old_bottom is different, there wasn't any previous 'refresh'
840         # command (probably only a 'push')
841         if old_bottom != patch.get_bottom() or old_top == patch.get_top():
842             raise StackException, 'No undo information available'
843
844         git.reset(tree_id = old_top, check_out = False)
845         if patch.restore_old_boundaries():
846             self.log_patch(patch, 'undo')
847
848     def new_patch(self, name, message = None, can_edit = True,
849                   unapplied = False, show_patch = False,
850                   top = None, bottom = None, commit = True,
851                   author_name = None, author_email = None, author_date = None,
852                   committer_name = None, committer_email = None,
853                   before_existing = False, sign_str = None):
854         """Creates a new patch, either pointing to an existing commit object,
855         or by creating a new commit object.
856         """
857
858         assert commit or (top and bottom)
859         assert not before_existing or (top and bottom)
860         assert not (commit and before_existing)
861         assert (top and bottom) or (not top and not bottom)
862         assert commit or (not top or (bottom == git.get_commit(top).get_parent()))
863
864         if name != None:
865             self.__patch_name_valid(name)
866             if self.patch_exists(name):
867                 raise StackException, 'Patch "%s" already exists' % name
868
869         # TODO: move this out of the stgit.stack module, it is really
870         # for higher level commands to handle the user interaction
871         def sign(msg):
872             return add_sign_line(msg, sign_str,
873                                  committer_name or git.committer().name,
874                                  committer_email or git.committer().email)
875         if not message and can_edit:
876             descr = edit_file(
877                 self, sign(''),
878                 'Please enter the description for the patch above.',
879                 show_patch)
880         else:
881             descr = sign(message)
882
883         head = git.get_head()
884
885         if name == None:
886             name = make_patch_name(descr, self.patch_exists)
887
888         patch = self.get_patch(name)
889         patch.create()
890
891         patch.set_description(descr)
892         patch.set_authname(author_name)
893         patch.set_authemail(author_email)
894         patch.set_authdate(author_date)
895         patch.set_commname(committer_name)
896         patch.set_commemail(committer_email)
897
898         if before_existing:
899             insert_string(self.__applied_file, patch.get_name())
900         elif unapplied:
901             patches = [patch.get_name()] + self.get_unapplied()
902             write_strings(self.__unapplied_file, patches)
903             set_head = False
904         else:
905             append_string(self.__applied_file, patch.get_name())
906             set_head = True
907
908         if commit:
909             if top:
910                 top_commit = git.get_commit(top)
911             else:
912                 bottom = head
913                 top_commit = git.get_commit(head)
914
915             # create a commit for the patch (may be empty if top == bottom);
916             # only commit on top of the current branch
917             assert(unapplied or bottom == head)
918             commit_id = git.commit(message = descr, parents = [bottom],
919                                    cache_update = False,
920                                    tree_id = top_commit.get_tree(),
921                                    allowempty = True, set_head = set_head,
922                                    author_name = author_name,
923                                    author_email = author_email,
924                                    author_date = author_date,
925                                    committer_name = committer_name,
926                                    committer_email = committer_email)
927             # set the patch top to the new commit
928             patch.set_top(commit_id)
929         else:
930             patch.set_top(top)
931
932         self.log_patch(patch, 'new')
933
934         return patch
935
936     def delete_patch(self, name, keep_log = False):
937         """Deletes a patch
938         """
939         self.__patch_name_valid(name)
940         patch = self.get_patch(name)
941
942         if self.__patch_is_current(patch):
943             self.pop_patch(name)
944         elif self.patch_applied(name):
945             raise StackException, 'Cannot remove an applied patch, "%s", ' \
946                   'which is not current' % name
947         elif not name in self.get_unapplied():
948             raise StackException, 'Unknown patch "%s"' % name
949
950         # save the commit id to a trash file
951         write_string(os.path.join(self.__trash_dir, name), patch.get_top())
952
953         patch.delete(keep_log = keep_log)
954
955         unapplied = self.get_unapplied()
956         unapplied.remove(name)
957         write_strings(self.__unapplied_file, unapplied)
958
959     def forward_patches(self, names):
960         """Try to fast-forward an array of patches.
961
962         On return, patches in names[0:returned_value] have been pushed on the
963         stack. Apply the rest with push_patch
964         """
965         unapplied = self.get_unapplied()
966
967         forwarded = 0
968         top = git.get_head()
969
970         for name in names:
971             assert(name in unapplied)
972
973             patch = self.get_patch(name)
974
975             head = top
976             bottom = patch.get_bottom()
977             top = patch.get_top()
978
979             # top != bottom always since we have a commit for each patch
980             if head == bottom:
981                 # reset the backup information. No logging since the
982                 # patch hasn't changed
983                 patch.set_top(top, backup = True)
984
985             else:
986                 head_tree = git.get_commit(head).get_tree()
987                 bottom_tree = git.get_commit(bottom).get_tree()
988                 if head_tree == bottom_tree:
989                     # We must just reparent this patch and create a new commit
990                     # for it
991                     descr = patch.get_description()
992                     author_name = patch.get_authname()
993                     author_email = patch.get_authemail()
994                     author_date = patch.get_authdate()
995                     committer_name = patch.get_commname()
996                     committer_email = patch.get_commemail()
997
998                     top_tree = git.get_commit(top).get_tree()
999
1000                     top = git.commit(message = descr, parents = [head],
1001                                      cache_update = False,
1002                                      tree_id = top_tree,
1003                                      allowempty = True,
1004                                      author_name = author_name,
1005                                      author_email = author_email,
1006                                      author_date = author_date,
1007                                      committer_name = committer_name,
1008                                      committer_email = committer_email)
1009
1010                     patch.set_top(top, backup = True)
1011
1012                     self.log_patch(patch, 'push(f)')
1013                 else:
1014                     top = head
1015                     # stop the fast-forwarding, must do a real merge
1016                     break
1017
1018             forwarded+=1
1019             unapplied.remove(name)
1020
1021         if forwarded == 0:
1022             return 0
1023
1024         git.switch(top)
1025
1026         append_strings(self.__applied_file, names[0:forwarded])
1027         write_strings(self.__unapplied_file, unapplied)
1028
1029         return forwarded
1030
1031     def merged_patches(self, names):
1032         """Test which patches were merged upstream by reverse-applying
1033         them in reverse order. The function returns the list of
1034         patches detected to have been applied. The state of the tree
1035         is restored to the original one
1036         """
1037         patches = [self.get_patch(name) for name in names]
1038         patches.reverse()
1039
1040         merged = []
1041         for p in patches:
1042             if git.apply_diff(p.get_top(), p.get_bottom()):
1043                 merged.append(p.get_name())
1044         merged.reverse()
1045
1046         git.reset()
1047
1048         return merged
1049
1050     def push_empty_patch(self, name):
1051         """Pushes an empty patch on the stack
1052         """
1053         unapplied = self.get_unapplied()
1054         assert(name in unapplied)
1055
1056         # patch = self.get_patch(name)
1057         head = git.get_head()
1058
1059         append_string(self.__applied_file, name)
1060
1061         unapplied.remove(name)
1062         write_strings(self.__unapplied_file, unapplied)
1063
1064         self.refresh_patch(bottom = head, cache_update = False, log = 'push(m)')
1065
1066     def push_patch(self, name):
1067         """Pushes a patch on the stack
1068         """
1069         unapplied = self.get_unapplied()
1070         assert(name in unapplied)
1071
1072         patch = self.get_patch(name)
1073
1074         head = git.get_head()
1075         bottom = patch.get_bottom()
1076         top = patch.get_top()
1077         # top != bottom always since we have a commit for each patch
1078
1079         if head == bottom:
1080             # A fast-forward push. Just reset the backup
1081             # information. No need for logging
1082             patch.set_top(top, backup = True)
1083
1084             git.switch(top)
1085             append_string(self.__applied_file, name)
1086
1087             unapplied.remove(name)
1088             write_strings(self.__unapplied_file, unapplied)
1089             return False
1090
1091         # Need to create a new commit an merge in the old patch
1092         ex = None
1093         modified = False
1094
1095         # Try the fast applying first. If this fails, fall back to the
1096         # three-way merge
1097         if not git.apply_diff(bottom, top):
1098             # if git.apply_diff() fails, the patch requires a diff3
1099             # merge and can be reported as modified
1100             modified = True
1101
1102             # merge can fail but the patch needs to be pushed
1103             try:
1104                 git.merge(bottom, head, top, recursive = True)
1105             except git.GitException, ex:
1106                 out.error('The merge failed during "push".',
1107                           'Use "refresh" after fixing the conflicts or'
1108                           ' revert the operation with "push --undo".')
1109
1110         append_string(self.__applied_file, name)
1111
1112         unapplied.remove(name)
1113         write_strings(self.__unapplied_file, unapplied)
1114
1115         if not ex:
1116             # if the merge was OK and no conflicts, just refresh the patch
1117             # The GIT cache was already updated by the merge operation
1118             if modified:
1119                 log = 'push(m)'
1120             else:
1121                 log = 'push'
1122             self.refresh_patch(bottom = head, cache_update = False, log = log)
1123         else:
1124             # we store the correctly merged files only for
1125             # tracking the conflict history. Note that the
1126             # git.merge() operations should always leave the index
1127             # in a valid state (i.e. only stage 0 files)
1128             self.refresh_patch(bottom = head, cache_update = False,
1129                                log = 'push(c)')
1130             raise StackException, str(ex)
1131
1132         return modified
1133
1134     def undo_push(self):
1135         name = self.get_current()
1136         assert(name)
1137
1138         patch = self.get_patch(name)
1139         old_bottom = patch.get_old_bottom()
1140         old_top = patch.get_old_top()
1141
1142         # the top of the patch is changed by a push operation only
1143         # together with the bottom (otherwise the top was probably
1144         # modified by 'refresh'). If they are both unchanged, there
1145         # was a fast forward
1146         if old_bottom == patch.get_bottom() and old_top != patch.get_top():
1147             raise StackException, 'No undo information available'
1148
1149         git.reset()
1150         self.pop_patch(name)
1151         ret = patch.restore_old_boundaries()
1152         if ret:
1153             self.log_patch(patch, 'undo')
1154
1155         return ret
1156
1157     def pop_patch(self, name, keep = False):
1158         """Pops the top patch from the stack
1159         """
1160         applied = self.get_applied()
1161         applied.reverse()
1162         assert(name in applied)
1163
1164         patch = self.get_patch(name)
1165
1166         if git.get_head_file() == self.get_name():
1167             if keep and not git.apply_diff(git.get_head(), patch.get_bottom(),
1168                                            check_index = False):
1169                 raise StackException(
1170                     'Failed to pop patches while preserving the local changes')
1171             git.switch(patch.get_bottom(), keep)
1172         else:
1173             git.set_branch(self.get_name(), patch.get_bottom())
1174
1175         # save the new applied list
1176         idx = applied.index(name) + 1
1177
1178         popped = applied[:idx]
1179         popped.reverse()
1180         unapplied = popped + self.get_unapplied()
1181         write_strings(self.__unapplied_file, unapplied)
1182
1183         del applied[:idx]
1184         applied.reverse()
1185         write_strings(self.__applied_file, applied)
1186
1187     def empty_patch(self, name):
1188         """Returns True if the patch is empty
1189         """
1190         self.__patch_name_valid(name)
1191         patch = self.get_patch(name)
1192         bottom = patch.get_bottom()
1193         top = patch.get_top()
1194
1195         if bottom == top:
1196             return True
1197         elif git.get_commit(top).get_tree() \
1198                  == git.get_commit(bottom).get_tree():
1199             return True
1200
1201         return False
1202
1203     def rename_patch(self, oldname, newname):
1204         self.__patch_name_valid(newname)
1205
1206         applied = self.get_applied()
1207         unapplied = self.get_unapplied()
1208
1209         if oldname == newname:
1210             raise StackException, '"To" name and "from" name are the same'
1211
1212         if newname in applied or newname in unapplied:
1213             raise StackException, 'Patch "%s" already exists' % newname
1214
1215         if oldname in unapplied:
1216             self.get_patch(oldname).rename(newname)
1217             unapplied[unapplied.index(oldname)] = newname
1218             write_strings(self.__unapplied_file, unapplied)
1219         elif oldname in applied:
1220             self.get_patch(oldname).rename(newname)
1221
1222             applied[applied.index(oldname)] = newname
1223             write_strings(self.__applied_file, applied)
1224         else:
1225             raise StackException, 'Unknown patch "%s"' % oldname
1226
1227     def log_patch(self, patch, message, notes = None):
1228         """Generate a log commit for a patch
1229         """
1230         top = git.get_commit(patch.get_top())
1231         old_log = patch.get_log()
1232
1233         if message is None:
1234             # replace the current log entry
1235             if not old_log:
1236                 raise StackException, \
1237                       'No log entry to annotate for patch "%s"' \
1238                       % patch.get_name()
1239             replace = True
1240             log_commit = git.get_commit(old_log)
1241             msg = log_commit.get_log().split('\n')[0]
1242             log_parent = log_commit.get_parent()
1243             if log_parent:
1244                 parents = [log_parent]
1245             else:
1246                 parents = []
1247         else:
1248             # generate a new log entry
1249             replace = False
1250             msg = '%s\t%s' % (message, top.get_id_hash())
1251             if old_log:
1252                 parents = [old_log]
1253             else:
1254                 parents = []
1255
1256         if notes:
1257             msg += '\n\n' + notes
1258
1259         log = git.commit(message = msg, parents = parents,
1260                          cache_update = False, tree_id = top.get_tree(),
1261                          allowempty = True)
1262         patch.set_log(log)
1263
1264     def hide_patch(self, name):
1265         """Add the patch to the hidden list.
1266         """
1267         unapplied = self.get_unapplied()
1268         if name not in unapplied:
1269             # keep the checking order for backward compatibility with
1270             # the old hidden patches functionality
1271             if self.patch_applied(name):
1272                 raise StackException, 'Cannot hide applied patch "%s"' % name
1273             elif self.patch_hidden(name):
1274                 raise StackException, 'Patch "%s" already hidden' % name
1275             else:
1276                 raise StackException, 'Unknown patch "%s"' % name
1277
1278         if not self.patch_hidden(name):
1279             # check needed for backward compatibility with the old
1280             # hidden patches functionality
1281             append_string(self.__hidden_file, name)
1282
1283         unapplied.remove(name)
1284         write_strings(self.__unapplied_file, unapplied)
1285
1286     def unhide_patch(self, name):
1287         """Remove the patch from the hidden list.
1288         """
1289         hidden = self.get_hidden()
1290         if not name in hidden:
1291             if self.patch_applied(name) or self.patch_unapplied(name):
1292                 raise StackException, 'Patch "%s" not hidden' % name
1293             else:
1294                 raise StackException, 'Unknown patch "%s"' % name
1295
1296         hidden.remove(name)
1297         write_strings(self.__hidden_file, hidden)
1298
1299         if not self.patch_applied(name) and not self.patch_unapplied(name):
1300             # check needed for backward compatibility with the old
1301             # hidden patches functionality
1302             append_string(self.__unapplied_file, name)