chiark / gitweb /
94856b8a5210d639d2973bbc8e7420b0a4f1e96e
[stgit] / stgit / stack.py
1 """Basic quilt-like functionality
2 """
3
4 __copyright__ = """
5 Copyright (C) 2005, Catalin Marinas <catalin.marinas@gmail.com>
6
7 This program is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License version 2 as
9 published by the Free Software Foundation.
10
11 This program is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with this program; if not, write to the Free Software
18 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
19 """
20
21 import sys, os, re
22 from email.Utils import formatdate
23
24 from stgit.exception import *
25 from stgit.utils import *
26 from stgit.out import *
27 from stgit.run import *
28 from stgit import git, basedir, templates
29 from stgit.config import config
30 from shutil import copyfile
31
32
33 # stack exception class
class StackException(StgException):
    """Error raised by the stack, series and patch handling code of
    this module.
    """
    pass
36
class FilterUntil:
    """Stateful predicate used to filter a sequence of lines.

    Each call answers "should this line be kept?".  A line is kept
    when it does not start with the given prefix, but only as long as
    no earlier (or the current) line satisfied until_test; from that
    point on every line is rejected.
    """
    def __init__(self):
        # flips to False, permanently, once until_test has matched
        self.should_print = True

    def __call__(self, x, until_test, prefix):
        """Return True if the line x should be kept, False otherwise.
        """
        if until_test(x):
            self.should_print = False
        if not self.should_print:
            return False
        return not x.startswith(prefix)
46
47 #
48 # Functions
49 #
# Lines of the edited message that start with this marker are removed
# by __clean_comments().
__comment_prefix = 'STG:'
# A line consisting of exactly this marker introduces the diff that
# edit_file() appends; __clean_comments() drops that line and
# everything after it.
__patch_prefix = 'STG_PATCH:'
52
def __clean_comments(f):
    """Rewrite the open file f in place without the editing helpers.

    Lines starting with the comment prefix are dropped, as is the
    patch marker line together with everything that follows it.
    Empty lines left at the end of the remaining text are removed too.
    """
    f.seek(0)

    marker_line = __patch_prefix + '\n'
    is_marker = lambda t: t == marker_line
    keep = FilterUntil()

    kept = []
    for text in f.readlines():
        if keep(text, is_marker, __comment_prefix):
            kept.append(text)

    # strip the trailing empty lines
    while kept and kept[-1] == '\n':
        kept.pop()

    f.seek(0)
    f.truncate()
    f.writelines(kept)
71
72 # TODO: move this out of the stgit.stack module, it is really for
73 # higher level commands to handle the user interaction
def edit_file(series, line, comment, show_patch = True):
    """Let the user edit a message in an editor and return the result.

    A temporary file '.stgitmsg.txt' is created in the current
    directory and filled with, in this order:
      - 'line' if given, otherwise the 'patchdescr.tmpl' template if
        there is one, otherwise an empty line;
      - 'comment' and two usage hints, each marked with the comment
        prefix;
      - if show_patch is true, the patch marker followed by the diff
        from the bottom of the current patch of 'series';
      - an editor modeline, marked with the comment prefix.

    After call_editor() returns, the helper lines are stripped with
    __clean_comments(), the remaining text is returned and the
    temporary file is deleted.

    The file handles are closed even when an exception is raised
    half-way; in that case the temporary file itself is left on disk.
    """
    fname = '.stgitmsg.txt'
    tmpl = templates.get_template('patchdescr.tmpl')

    f = open(fname, 'w+')
    try:
        if line:
            print >> f, line
        elif tmpl:
            print >> f, tmpl,
        else:
            print >> f
        print >> f, __comment_prefix, comment
        print >> f, __comment_prefix, \
              'Lines prefixed with "%s" will be automatically removed.' \
              % __comment_prefix
        print >> f, __comment_prefix, \
              'Trailing empty lines will be automatically removed.'

        if show_patch:
            print >> f, __patch_prefix
            # series.get_patch(series.get_current()).get_top()
            diff_str = git.diff(rev1 = series.get_patch(series.get_current()).get_bottom())
            f.write(diff_str)

        #Vim modeline must be near the end.
        print >> f, __comment_prefix, 'vi: set textwidth=75 filetype=diff nobackup:'
    finally:
        # git.diff() or a write may fail; do not leak the handle
        f.close()

    call_editor(fname)

    f = open(fname, 'r+')
    try:
        __clean_comments(f)
        f.seek(0)
        result = f.read()
    finally:
        f.close()

    os.remove(fname)

    return result
114
115 #
116 # Classes
117 #
118
class StgitObject:
    """Base class for objects whose properties ("fields") are kept as
    individual files inside one directory.
    """
    def _set_dir(self, dir):
        """Set the directory holding the field files."""
        self.__dir = dir

    def _dir(self):
        """Return the directory holding the field files."""
        return self.__dir

    def create_empty_field(self, name):
        """Create the field called name with empty content."""
        create_empty_file(os.path.join(self.__dir, name))

    def _get_field(self, name, multiline = False):
        """Return the content of the field called name.

        None is returned when the field file does not exist or when
        its content is the empty string.
        """
        path = os.path.join(self.__dir, name)
        if not os.path.isfile(path):
            return None
        content = read_string(path, multiline)
        if content == '':
            return None
        return content

    def _set_field(self, name, value, multiline = False):
        """Store value in the field called name.

        An empty (false) value removes the field file if it exists
        instead of writing anything.
        """
        path = os.path.join(self.__dir, name)
        if value:
            write_string(path, value, multiline)
        elif os.path.isfile(path):
            os.remove(path)
147
148     
class Patch(StgitObject):
    """Basic patch implementation

    The metadata of a patch lives in field files under
    <series_dir>/<name>.  Two git refs are derived from refs_base: the
    top ref '<refs_base>/<name>' and the log ref
    '<refs_base>/<name>.log'.
    """
    def __init_refs(self):
        # Both ref names contain the patch name, so they have to be
        # recomputed whenever the name changes (see rename()).
        self.__top_ref = self.__refs_base + '/' + self.__name
        self.__log_ref = self.__top_ref + '.log'

    def __init__(self, name, series_dir, refs_base):
        """Set up the object for the patch called name; nothing is
        created on disk.
        """
        self.__series_dir = series_dir
        self.__name = name
        self._set_dir(os.path.join(self.__series_dir, self.__name))
        self.__refs_base = refs_base
        self.__init_refs()

    def create(self):
        """Create the patch directory with empty 'bottom' and 'top'
        fields.
        """
        os.mkdir(self._dir())
        self.create_empty_field('bottom')
        self.create_empty_field('top')

    def delete(self):
        """Remove the patch directory with its files, the top ref and,
        if present, the log ref.
        """
        for f in os.listdir(self._dir()):
            os.remove(os.path.join(self._dir(), f))
        os.rmdir(self._dir())
        git.delete_ref(self.__top_ref)
        if git.ref_exists(self.__log_ref):
            git.delete_ref(self.__log_ref)

    def get_name(self):
        """Return the patch name."""
        return self.__name

    def rename(self, newname):
        """Rename the patch: its refs first, then its directory."""
        olddir = self._dir()
        old_top_ref = self.__top_ref
        old_log_ref = self.__log_ref
        self.__name = newname
        self._set_dir(os.path.join(self.__series_dir, self.__name))
        self.__init_refs()

        git.rename_ref(old_top_ref, self.__top_ref)
        if git.ref_exists(old_log_ref):
            git.rename_ref(old_log_ref, self.__log_ref)
        os.rename(olddir, self._dir())

    def __update_top_ref(self, ref):
        git.set_ref(self.__top_ref, ref)

    def __update_log_ref(self, ref):
        git.set_ref(self.__log_ref, ref)

    def update_top_ref(self):
        """Point the top ref at the stored 'top' field, if that field
        has a value.
        """
        top = self.get_top()
        if top:
            self.__update_top_ref(top)

    def get_old_bottom(self):
        """Return the backed-up bottom ('bottom.old'), or None."""
        return self._get_field('bottom.old')

    def get_bottom(self):
        """Return the 'bottom' field, or None."""
        return self._get_field('bottom')

    def set_bottom(self, value, backup = False):
        """Store value as the bottom; with backup, the current bottom
        is saved in 'bottom.old' first.
        """
        if backup:
            curr = self._get_field('bottom')
            self._set_field('bottom.old', curr)
        self._set_field('bottom', value)

    def get_old_top(self):
        """Return the backed-up top ('top.old'), or None."""
        return self._get_field('top.old')

    def get_top(self):
        """Return the 'top' field, or None."""
        return self._get_field('top')

    def set_top(self, value, backup = False):
        """Store value as the top and update the top ref; with backup,
        the current top is saved in 'top.old' first.
        """
        if backup:
            curr = self._get_field('top')
            self._set_field('top.old', curr)
        self._set_field('top', value)
        self.__update_top_ref(value)

    def restore_old_boundaries(self):
        """Restore bottom and top from their '.old' backups.

        Returns True if both backups existed and were restored, False
        (nothing changed) otherwise.
        """
        bottom = self._get_field('bottom.old')
        top = self._get_field('top.old')

        if top and bottom:
            self._set_field('bottom', bottom)
            self._set_field('top', top)
            self.__update_top_ref(top)
            return True
        else:
            return False

    def get_description(self):
        """Return the (multi-line) 'description' field, or None."""
        return self._get_field('description', True)

    def set_description(self, line):
        """Store line as the (multi-line) 'description' field."""
        self._set_field('description', line, True)

    def get_authname(self):
        return self._get_field('authname')

    def set_authname(self, name):
        """Store the author name, defaulting to git.author().name."""
        self._set_field('authname', name or git.author().name)

    def get_authemail(self):
        return self._get_field('authemail')

    def set_authemail(self, email):
        """Store the author e-mail, defaulting to git.author().email."""
        self._set_field('authemail', email or git.author().email)

    def get_authdate(self):
        """Return the author date, or None if it is not set.

        A value stored as '<unix seconds> <+/-hhmm>' is converted to
        the textual form produced by formatdate(), showing the local
        time in the stored time zone.  Any other value is returned
        unchanged.
        """
        date = self._get_field('authdate')
        if not date:
            return date

        if re.match(r'[0-9]+\s+[+-][0-9]+', date):
            # Unix time (seconds) + time zone
            secs_tz = date.split()
            secs = int(secs_tz[0])
            tz = secs_tz[1]

            # formatdate() renders the time in UTC (suffix "-0000").
            # Shift the seconds by the zone offset before formatting,
            # so that the wall-clock time matches the zone appended
            # below and the result denotes the same instant.
            hours, minutes = divmod(int(tz[1:]), 100)
            offset = (hours * 60 + minutes) * 60
            if tz[0] == '-':
                offset = -offset
            date = formatdate(secs + offset)[:-5] + tz

        return date

    def set_authdate(self, date):
        """Store the author date, defaulting to git.author().date."""
        self._set_field('authdate', date or git.author().date)

    def get_commname(self):
        return self._get_field('commname')

    def set_commname(self, name):
        """Store the committer name, defaulting to
        git.committer().name.
        """
        self._set_field('commname', name or git.committer().name)

    def get_commemail(self):
        return self._get_field('commemail')

    def set_commemail(self, email):
        """Store the committer e-mail, defaulting to
        git.committer().email.
        """
        self._set_field('commemail', email or git.committer().email)

    def get_log(self):
        """Return the 'log' field, or None."""
        return self._get_field('log')

    def set_log(self, value, backup = False):
        """Store value in the 'log' field and update the log ref.

        The backup argument is accepted but not used.
        """
        self._set_field('log', value)
        self.__update_log_ref(value)
291
# The current StGIT metadata format version.  Series.init() records it
# for new branches and Series.update_to_current_format_version()
# upgrades older branches to it.
FORMAT_VERSION = 2
294
295 class PatchSet(StgitObject):
296     def __init__(self, name = None):
297         try:
298             if name:
299                 self.set_name (name)
300             else:
301                 self.set_name (git.get_head_file())
302             self.__base_dir = basedir.get()
303         except git.GitException, ex:
304             raise StackException, 'GIT tree not initialised: %s' % ex
305
306         self._set_dir(os.path.join(self.__base_dir, 'patches', self.get_name()))
307
308     def get_name(self):
309         return self.__name
310     def set_name(self, name):
311         self.__name = name
312
313     def _basedir(self):
314         return self.__base_dir
315
316     def get_head(self):
317         """Return the head of the branch
318         """
319         crt = self.get_current_patch()
320         if crt:
321             return crt.get_top()
322         else:
323             return self.get_base()
324
325     def get_protected(self):
326         return os.path.isfile(os.path.join(self._dir(), 'protected'))
327
328     def protect(self):
329         protect_file = os.path.join(self._dir(), 'protected')
330         if not os.path.isfile(protect_file):
331             create_empty_file(protect_file)
332
333     def unprotect(self):
334         protect_file = os.path.join(self._dir(), 'protected')
335         if os.path.isfile(protect_file):
336             os.remove(protect_file)
337
338     def __branch_descr(self):
339         return 'branch.%s.description' % self.get_name()
340
341     def get_description(self):
342         return config.get(self.__branch_descr()) or ''
343
344     def set_description(self, line):
345         if line:
346             config.set(self.__branch_descr(), line)
347         else:
348             config.unset(self.__branch_descr())
349
350     def head_top_equal(self):
351         """Return true if the head and the top are the same
352         """
353         crt = self.get_current_patch()
354         if not crt:
355             # we don't care, no patches applied
356             return True
357         return git.get_head() == crt.get_top()
358
359     def is_initialised(self):
360         """Checks if series is already initialised
361         """
362         return bool(config.get(self.format_version_key()))
363
364
def shortlog(patches):
    """Return the 'git-shortlog' output for the given patches.

    For every patch the 'git-log --pretty=short' output of the range
    between its bottom and its top is collected; the concatenation is
    then fed to 'git-shortlog'.
    """
    chunks = []
    for p in patches:
        git_log = Run('git-log', '--pretty=short',
                      p.get_top(), '^%s' % p.get_bottom())
        chunks.append(git_log.raw_output())
    return Run('git-shortlog').raw_input(''.join(chunks)).raw_output()
370
371 class Series(PatchSet):
372     """Class including the operations on series
373     """
374     def __init__(self, name = None):
375         """Takes a series name as the parameter.
376         """
377         PatchSet.__init__(self, name)
378
379         # Update the branch to the latest format version if it is
380         # initialized, but don't touch it if it isn't.
381         self.update_to_current_format_version()
382
383         self.__refs_base = 'refs/patches/%s' % self.get_name()
384
385         self.__applied_file = os.path.join(self._dir(), 'applied')
386         self.__unapplied_file = os.path.join(self._dir(), 'unapplied')
387         self.__hidden_file = os.path.join(self._dir(), 'hidden')
388
389         # where this series keeps its patches
390         self.__patch_dir = os.path.join(self._dir(), 'patches')
391
392         # trash directory
393         self.__trash_dir = os.path.join(self._dir(), 'trash')
394
395     def format_version_key(self):
396         return 'branch.%s.stgit.stackformatversion' % self.get_name()
397
    def update_to_current_format_version(self):
        """Update a potentially older StGIT directory structure to the
        latest version. Note: This function should depend as little as
        possible on external functions that may change during a format
        version bump, since it must remain able to process older formats.

        The upgrade steps run one after the other (0 -> 1, then 1 -> 2),
        each re-reading the version, so a version 0 branch is taken all
        the way to the latest version in a single call.  A branch
        without any metadata is left untouched.

        Raises StackException if, after the upgrade steps, the branch
        has metadata of a version other than FORMAT_VERSION."""

        # Computed locally: __init__ calls this method before it sets
        # the path attributes of the series.
        branch_dir = os.path.join(self._basedir(), 'patches', self.get_name())
        def get_format_version():
            """Return the integer format version number, or None if the
            branch doesn't have any StGIT metadata at all, of any version."""
            fv = config.get(self.format_version_key())
            ofv = config.get('branch.%s.stgitformatversion' % self.get_name())
            if fv:
                # Great, there's an explicitly recorded format version
                # number, which means that the branch is initialized and
                # of that exact version.
                return int(fv)
            elif ofv:
                # Old name for the version info, upgrade it
                config.set(self.format_version_key(), ofv)
                config.unset('branch.%s.stgitformatversion' % self.get_name())
                return int(ofv)
            elif os.path.isdir(os.path.join(branch_dir, 'patches')):
                # There's a .git/patches/<branch>/patches directory, which
                # means this is an initialized version 1 branch.
                return 1
            elif os.path.isdir(branch_dir):
                # There's a .git/patches/<branch> directory, which means
                # this is an initialized version 0 branch.
                return 0
            else:
                # The branch doesn't seem to be initialized at all.
                return None
        # Record the new version in the configuration and tell the user.
        def set_format_version(v):
            out.info('Upgraded branch %s to format version %d' % (self.get_name(), v))
            config.set(self.format_version_key(), '%d' % v)
        # The three helpers below tolerate an already-done step, so a
        # partially upgraded branch does not make them fail.
        def mkdir(d):
            if not os.path.isdir(d):
                os.makedirs(d)
        def rm(f):
            if os.path.exists(f):
                os.remove(f)
        def rm_ref(ref):
            if git.ref_exists(ref):
                git.delete_ref(ref)

        # Update 0 -> 1.
        # Version 1 adds a 'trash' directory, moves every patch
        # directory below 'patches' and sets the top ref of each patch.
        if get_format_version() == 0:
            mkdir(os.path.join(branch_dir, 'trash'))
            patch_dir = os.path.join(branch_dir, 'patches')
            mkdir(patch_dir)
            refs_base = 'refs/patches/%s' % self.get_name()
            for patch in (file(os.path.join(branch_dir, 'unapplied')).readlines()
                          + file(os.path.join(branch_dir, 'applied')).readlines()):
                patch = patch.strip()
                os.rename(os.path.join(branch_dir, patch),
                          os.path.join(patch_dir, patch))
                Patch(patch, patch_dir, refs_base).update_top_ref()
            set_format_version(1)

        # Update 1 -> 2.
        # Version 2 keeps the description in the configuration and
        # drops the 'current' file and the refs/bases/<branch> ref.
        if get_format_version() == 1:
            desc_file = os.path.join(branch_dir, 'description')
            if os.path.isfile(desc_file):
                desc = read_string(desc_file)
                if desc:
                    config.set('branch.%s.description' % self.get_name(), desc)
                rm(desc_file)
            rm(os.path.join(branch_dir, 'current'))
            rm_ref('refs/bases/%s' % self.get_name())
            set_format_version(2)

        # Make sure we're at the latest version.
        if not get_format_version() in [None, FORMAT_VERSION]:
            raise StackException('Branch %s is at format version %d, expected %d'
                                 % (self.get_name(), get_format_version(), FORMAT_VERSION))
474
475     def __patch_name_valid(self, name):
476         """Raise an exception if the patch name is not valid.
477         """
478         if not name or re.search('[^\w.-]', name):
479             raise StackException, 'Invalid patch name: "%s"' % name
480
481     def get_patch(self, name):
482         """Return a Patch object for the given name
483         """
484         return Patch(name, self.__patch_dir, self.__refs_base)
485
486     def get_current_patch(self):
487         """Return a Patch object representing the topmost patch, or
488         None if there is no such patch."""
489         crt = self.get_current()
490         if not crt:
491             return None
492         return self.get_patch(crt)
493
494     def get_current(self):
495         """Return the name of the topmost patch, or None if there is
496         no such patch."""
497         try:
498             applied = self.get_applied()
499         except StackException:
500             # No "applied" file: branch is not initialized.
501             return None
502         try:
503             return applied[-1]
504         except IndexError:
505             # No patches applied.
506             return None
507
508     def get_applied(self):
509         if not os.path.isfile(self.__applied_file):
510             raise StackException, 'Branch "%s" not initialised' % self.get_name()
511         return read_strings(self.__applied_file)
512
513     def set_applied(self, applied):
514         write_strings(self.__applied_file, applied)
515
516     def get_unapplied(self):
517         if not os.path.isfile(self.__unapplied_file):
518             raise StackException, 'Branch "%s" not initialised' % self.get_name()
519         return read_strings(self.__unapplied_file)
520
521     def set_unapplied(self, unapplied):
522         write_strings(self.__unapplied_file, unapplied)
523
524     def get_hidden(self):
525         if not os.path.isfile(self.__hidden_file):
526             return []
527         return read_strings(self.__hidden_file)
528
529     def get_base(self):
530         # Return the parent of the bottommost patch, if there is one.
531         if os.path.isfile(self.__applied_file):
532             bottommost = file(self.__applied_file).readline().strip()
533             if bottommost:
534                 return self.get_patch(bottommost).get_bottom()
535         # No bottommost patch, so just return HEAD
536         return git.get_head()
537
538     def get_parent_remote(self):
539         value = config.get('branch.%s.remote' % self.get_name())
540         if value:
541             return value
542         elif 'origin' in git.remotes_list():
543             out.note(('No parent remote declared for stack "%s",'
544                       ' defaulting to "origin".' % self.get_name()),
545                      ('Consider setting "branch.%s.remote" and'
546                       ' "branch.%s.merge" with "git config".'
547                       % (self.get_name(), self.get_name())))
548             return 'origin'
549         else:
550             raise StackException, 'Cannot find a parent remote for "%s"' % self.get_name()
551
552     def __set_parent_remote(self, remote):
553         value = config.set('branch.%s.remote' % self.get_name(), remote)
554
555     def get_parent_branch(self):
556         value = config.get('branch.%s.stgit.parentbranch' % self.get_name())
557         if value:
558             return value
559         elif git.rev_parse('heads/origin'):
560             out.note(('No parent branch declared for stack "%s",'
561                       ' defaulting to "heads/origin".' % self.get_name()),
562                      ('Consider setting "branch.%s.stgit.parentbranch"'
563                       ' with "git config".' % self.get_name()))
564             return 'heads/origin'
565         else:
566             raise StackException, 'Cannot find a parent branch for "%s"' % self.get_name()
567
568     def __set_parent_branch(self, name):
569         if config.get('branch.%s.remote' % self.get_name()):
570             # Never set merge if remote is not set to avoid
571             # possibly-erroneous lookups into 'origin'
572             config.set('branch.%s.merge' % self.get_name(), name)
573         config.set('branch.%s.stgit.parentbranch' % self.get_name(), name)
574
575     def set_parent(self, remote, localbranch):
576         if localbranch:
577             if remote:
578                 self.__set_parent_remote(remote)
579             self.__set_parent_branch(localbranch)
580         # We'll enforce this later
581 #         else:
582 #             raise StackException, 'Parent branch (%s) should be specified for %s' % localbranch, self.get_name()
583
584     def __patch_is_current(self, patch):
585         return patch.get_name() == self.get_current()
586
587     def patch_applied(self, name):
588         """Return true if the patch exists in the applied list
589         """
590         return name in self.get_applied()
591
592     def patch_unapplied(self, name):
593         """Return true if the patch exists in the unapplied list
594         """
595         return name in self.get_unapplied()
596
597     def patch_hidden(self, name):
598         """Return true if the patch is hidden.
599         """
600         return name in self.get_hidden()
601
602     def patch_exists(self, name):
603         """Return true if there is a patch with the given name, false
604         otherwise."""
605         return self.patch_applied(name) or self.patch_unapplied(name) \
606                or self.patch_hidden(name)
607
608     def init(self, create_at=False, parent_remote=None, parent_branch=None):
609         """Initialises the stgit series
610         """
611         if self.is_initialised():
612             raise StackException, '%s already initialized' % self.get_name()
613         for d in [self._dir()]:
614             if os.path.exists(d):
615                 raise StackException, '%s already exists' % d
616
617         if (create_at!=False):
618             git.create_branch(self.get_name(), create_at)
619
620         os.makedirs(self.__patch_dir)
621
622         self.set_parent(parent_remote, parent_branch)
623
624         self.create_empty_field('applied')
625         self.create_empty_field('unapplied')
626
627         config.set(self.format_version_key(), str(FORMAT_VERSION))
628
629     def rename(self, to_name):
630         """Renames a series
631         """
632         to_stack = Series(to_name)
633
634         if to_stack.is_initialised():
635             raise StackException, '"%s" already exists' % to_stack.get_name()
636
637         patches = self.get_applied() + self.get_unapplied()
638
639         git.rename_branch(self.get_name(), to_name)
640
641         for patch in patches:
642             git.rename_ref('refs/patches/%s/%s' % (self.get_name(), patch),
643                            'refs/patches/%s/%s' % (to_name, patch))
644             git.rename_ref('refs/patches/%s/%s.log' % (self.get_name(), patch),
645                            'refs/patches/%s/%s.log' % (to_name, patch))
646         if os.path.isdir(self._dir()):
647             rename(os.path.join(self._basedir(), 'patches'),
648                    self.get_name(), to_stack.get_name())
649
650         # Rename the config section
651         for k in ['branch.%s', 'branch.%s.stgit']:
652             config.rename_section(k % self.get_name(), k % to_name)
653
654         self.__init__(to_name)
655
656     def clone(self, target_series):
657         """Clones a series
658         """
659         try:
660             # allow cloning of branches not under StGIT control
661             base = self.get_base()
662         except:
663             base = git.get_head()
664         Series(target_series).init(create_at = base)
665         new_series = Series(target_series)
666
667         # generate an artificial description file
668         new_series.set_description('clone of "%s"' % self.get_name())
669
670         # clone self's entire series as unapplied patches
671         try:
672             # allow cloning of branches not under StGIT control
673             applied = self.get_applied()
674             unapplied = self.get_unapplied()
675             patches = applied + unapplied
676             patches.reverse()
677         except:
678             patches = applied = unapplied = []
679         for p in patches:
680             patch = self.get_patch(p)
681             newpatch = new_series.new_patch(p, message = patch.get_description(),
682                                             can_edit = False, unapplied = True,
683                                             bottom = patch.get_bottom(),
684                                             top = patch.get_top(),
685                                             author_name = patch.get_authname(),
686                                             author_email = patch.get_authemail(),
687                                             author_date = patch.get_authdate())
688             if patch.get_log():
689                 out.info('Setting log to %s' %  patch.get_log())
690                 newpatch.set_log(patch.get_log())
691             else:
692                 out.info('No log for %s' % p)
693
694         # fast forward the cloned series to self's top
695         new_series.forward_patches(applied)
696
697         # Clone parent informations
698         value = config.get('branch.%s.remote' % self.get_name())
699         if value:
700             config.set('branch.%s.remote' % target_series, value)
701
702         value = config.get('branch.%s.merge' % self.get_name())
703         if value:
704             config.set('branch.%s.merge' % target_series, value)
705
706         value = config.get('branch.%s.stgit.parentbranch' % self.get_name())
707         if value:
708             config.set('branch.%s.stgit.parentbranch' % target_series, value)
709
710     def delete(self, force = False):
711         """Deletes an stgit series
712         """
713         if self.is_initialised():
714             patches = self.get_unapplied() + self.get_applied()
715             if not force and patches:
716                 raise StackException, \
717                       'Cannot delete: the series still contains patches'
718             for p in patches:
719                 self.get_patch(p).delete()
720
721             # remove the trash directory if any
722             if os.path.exists(self.__trash_dir):
723                 for fname in os.listdir(self.__trash_dir):
724                     os.remove(os.path.join(self.__trash_dir, fname))
725                 os.rmdir(self.__trash_dir)
726
727             # FIXME: find a way to get rid of those manual removals
728             # (move functionality to StgitObject ?)
729             if os.path.exists(self.__applied_file):
730                 os.remove(self.__applied_file)
731             if os.path.exists(self.__unapplied_file):
732                 os.remove(self.__unapplied_file)
733             if os.path.exists(self.__hidden_file):
734                 os.remove(self.__hidden_file)
735             if os.path.exists(self._dir()+'/orig-base'):
736                 os.remove(self._dir()+'/orig-base')
737
738             if not os.listdir(self.__patch_dir):
739                 os.rmdir(self.__patch_dir)
740             else:
741                 out.warn('Patch directory %s is not empty' % self.__patch_dir)
742
743             try:
744                 os.removedirs(self._dir())
745             except OSError:
746                 raise StackException('Series directory %s is not empty'
747                                      % self._dir())
748
749             try:
750                 git.delete_branch(self.get_name())
751             except GitException:
752                 out.warn('Could not delete branch "%s"' % self.get_name())
753
754         config.remove_section('branch.%s' % self.get_name())
755         config.remove_section('branch.%s.stgit' % self.get_name())
756
757     def refresh_patch(self, files = None, message = None, edit = False,
758                       show_patch = False,
759                       cache_update = True,
760                       author_name = None, author_email = None,
761                       author_date = None,
762                       committer_name = None, committer_email = None,
763                       backup = False, sign_str = None, log = 'refresh',
764                       notes = None, bottom = None):
765         """Generates a new commit for the topmost patch
766         """
767         patch = self.get_current_patch()
768         if not patch:
769             raise StackException, 'No patches applied'
770
771         descr = patch.get_description()
772         if not (message or descr):
773             edit = True
774             descr = ''
775         elif message:
776             descr = message
777
778         # TODO: move this out of the stgit.stack module, it is really
779         # for higher level commands to handle the user interaction
780         if not message and edit:
781             descr = edit_file(self, descr.rstrip(), \
782                               'Please edit the description for patch "%s" ' \
783                               'above.' % patch.get_name(), show_patch)
784
785         if not author_name:
786             author_name = patch.get_authname()
787         if not author_email:
788             author_email = patch.get_authemail()
789         if not author_date:
790             author_date = patch.get_authdate()
791         if not committer_name:
792             committer_name = patch.get_commname()
793         if not committer_email:
794             committer_email = patch.get_commemail()
795
796         descr = add_sign_line(descr, sign_str, committer_name, committer_email)
797
798         if not bottom:
799             bottom = patch.get_bottom()
800
801         commit_id = git.commit(files = files,
802                                message = descr, parents = [bottom],
803                                cache_update = cache_update,
804                                allowempty = True,
805                                author_name = author_name,
806                                author_email = author_email,
807                                author_date = author_date,
808                                committer_name = committer_name,
809                                committer_email = committer_email)
810
811         patch.set_bottom(bottom, backup = backup)
812         patch.set_top(commit_id, backup = backup)
813         patch.set_description(descr)
814         patch.set_authname(author_name)
815         patch.set_authemail(author_email)
816         patch.set_authdate(author_date)
817         patch.set_commname(committer_name)
818         patch.set_commemail(committer_email)
819
820         if log:
821             self.log_patch(patch, log, notes)
822
823         return commit_id
824
825     def undo_refresh(self):
826         """Undo the patch boundaries changes caused by 'refresh'
827         """
828         name = self.get_current()
829         assert(name)
830
831         patch = self.get_patch(name)
832         old_bottom = patch.get_old_bottom()
833         old_top = patch.get_old_top()
834
835         # the bottom of the patch is not changed by refresh. If the
836         # old_bottom is different, there wasn't any previous 'refresh'
837         # command (probably only a 'push')
838         if old_bottom != patch.get_bottom() or old_top == patch.get_top():
839             raise StackException, 'No undo information available'
840
841         git.reset(tree_id = old_top, check_out = False)
842         if patch.restore_old_boundaries():
843             self.log_patch(patch, 'undo')
844
    def new_patch(self, name, message = None, can_edit = True,
                  unapplied = False, show_patch = False,
                  top = None, bottom = None, commit = True,
                  author_name = None, author_email = None, author_date = None,
                  committer_name = None, committer_email = None,
                  before_existing = False, sign_str = None):
        """Creates a new patch, either pointing to an existing commit object,
        or by creating a new commit object.

        Arguments:
        name -- the patch name; it is validated and must not exist
            already. When None, a name is generated from the
            description with make_patch_name()
        message -- the patch description. When it is empty and
            can_edit is true, the description is obtained through
            edit_file() instead
        can_edit -- allow the call to edit_file() described above
        unapplied -- put the patch at the front of the unapplied list
            instead of appending it to the applied one (ignored when
            before_existing is true)
        show_patch -- passed through to edit_file()
        top, bottom -- commit ids delimiting the patch; either both or
            none must be given and bottom must be the parent of top
        commit -- when true, a new commit object is created for the
            patch; when false the patch simply points to the given
            top and bottom (which must differ)
        author_*, committer_* -- stored in the patch and passed to
            git.commit()
        before_existing -- add the patch name to the applied file with
            insert_string(); requires top and bottom and commit to be
            false
        sign_str -- passed to add_sign_line() together with the
            committer name and e-mail (git.committer() is the fallback)

        Returns the new patch object. Raises StackException if a patch
        called name already exists.
        """

        # sanity checks on the argument combinations accepted
        assert commit or (top and bottom)
        assert not before_existing or (top and bottom)
        assert not (commit and before_existing)
        assert (top and bottom) or (not top and not bottom)
        assert not top or (bottom == git.get_commit(top).get_parent())

        if name != None:
            self.__patch_name_valid(name)
            if self.patch_exists(name):
                raise StackException, 'Patch "%s" already exists' % name

        # TODO: move this out of the stgit.stack module, it is really
        # for higher level commands to handle the user interaction
        def sign(msg):
            # add the sign line, falling back to the git committer
            # identity when no committer name/e-mail was passed in
            return add_sign_line(msg, sign_str,
                                 committer_name or git.committer().name,
                                 committer_email or git.committer().email)
        if not message and can_edit:
            descr = edit_file(
                self, sign(''),
                'Please enter the description for the patch above.',
                show_patch)
        else:
            descr = sign(message)

        # read the HEAD before any patch state is written
        head = git.get_head()

        if name == None:
            name = make_patch_name(descr, self.patch_exists)

        patch = self.get_patch(name)
        patch.create()

        # store the patch metadata
        patch.set_description(descr)
        patch.set_authname(author_name)
        patch.set_authemail(author_email)
        patch.set_authdate(author_date)
        patch.set_commname(committer_name)
        patch.set_commemail(committer_email)

        # register the patch name in the applied or unapplied list.
        # set_head is only needed when committing, and before_existing
        # excludes committing (see the assertions above)
        if before_existing:
            insert_string(self.__applied_file, patch.get_name())
        elif unapplied:
            patches = [patch.get_name()] + self.get_unapplied()
            write_strings(self.__unapplied_file, patches)
            set_head = False
        else:
            append_string(self.__applied_file, patch.get_name())
            set_head = True

        if commit:
            if top:
                top_commit = git.get_commit(top)
            else:
                # no boundaries given: the patch starts at the current
                # HEAD and reuses its tree
                bottom = head
                top_commit = git.get_commit(head)

            # create a commit for the patch (may be empty if top == bottom);
            # only commit on top of the current branch
            assert(unapplied or bottom == head)
            commit_id = git.commit(message = descr, parents = [bottom],
                                   cache_update = False,
                                   tree_id = top_commit.get_tree(),
                                   allowempty = True, set_head = set_head,
                                   author_name = author_name,
                                   author_email = author_email,
                                   author_date = author_date,
                                   committer_name = committer_name,
                                   committer_email = committer_email)
            # set the patch top to the new commit
            patch.set_bottom(bottom)
            patch.set_top(commit_id)
        else:
            # point the patch to the existing commit objects
            assert top != bottom
            patch.set_bottom(bottom)
            patch.set_top(top)

        self.log_patch(patch, 'new')

        return patch
935
936     def delete_patch(self, name):
937         """Deletes a patch
938         """
939         self.__patch_name_valid(name)
940         patch = self.get_patch(name)
941
942         if self.__patch_is_current(patch):
943             self.pop_patch(name)
944         elif self.patch_applied(name):
945             raise StackException, 'Cannot remove an applied patch, "%s", ' \
946                   'which is not current' % name
947         elif not name in self.get_unapplied():
948             raise StackException, 'Unknown patch "%s"' % name
949
950         # save the commit id to a trash file
951         write_string(os.path.join(self.__trash_dir, name), patch.get_top())
952
953         patch.delete()
954
955         unapplied = self.get_unapplied()
956         unapplied.remove(name)
957         write_strings(self.__unapplied_file, unapplied)
958
    def forward_patches(self, names):
        """Try to fast-forward an array of patches.

        Each patch in names must be in the unapplied list. Patches
        are taken in order and pushed without merging for as long as
        the bottom of the patch is the current top of the stack, or
        has the same tree as it (in which case the patch is re-committed
        on the new parent). The loop stops at the first patch needing
        a real merge.

        On return, patches in names[0:returned_value] have been pushed on the
        stack. Apply the rest with push_patch
        """
        unapplied = self.get_unapplied()

        forwarded = 0
        # 'top' tracks the commit the next patch has to sit on
        top = git.get_head()

        for name in names:
            assert(name in unapplied)

            patch = self.get_patch(name)

            # the top of the previously forwarded patch (or the initial
            # HEAD) becomes the head this patch is checked against
            head = top
            bottom = patch.get_bottom()
            top = patch.get_top()

            # top != bottom always since we have a commit for each patch
            if head == bottom:
                # reset the backup information. No logging since the
                # patch hasn't changed
                patch.set_bottom(head, backup = True)
                patch.set_top(top, backup = True)

            else:
                head_tree = git.get_commit(head).get_tree()
                bottom_tree = git.get_commit(bottom).get_tree()
                if head_tree == bottom_tree:
                    # We must just reparent this patch and create a new commit
                    # for it
                    descr = patch.get_description()
                    author_name = patch.get_authname()
                    author_email = patch.get_authemail()
                    author_date = patch.get_authdate()
                    committer_name = patch.get_commname()
                    committer_email = patch.get_commemail()

                    top_tree = git.get_commit(top).get_tree()

                    # same tree as the old top, new parent
                    top = git.commit(message = descr, parents = [head],
                                     cache_update = False,
                                     tree_id = top_tree,
                                     allowempty = True,
                                     author_name = author_name,
                                     author_email = author_email,
                                     author_date = author_date,
                                     committer_name = committer_name,
                                     committer_email = committer_email)

                    patch.set_bottom(head, backup = True)
                    patch.set_top(top, backup = True)

                    self.log_patch(patch, 'push(f)')
                else:
                    # restore 'top' to the last forwarded commit since
                    # this patch is not going to be pushed here
                    top = head
                    # stop the fast-forwarding, must do a real merge
                    break

            forwarded+=1
            unapplied.remove(name)

        # nothing forwarded, leave the branch and the series files alone
        if forwarded == 0:
            return 0

        git.switch(top)

        append_strings(self.__applied_file, names[0:forwarded])
        write_strings(self.__unapplied_file, unapplied)

        return forwarded
1032
1033     def merged_patches(self, names):
1034         """Test which patches were merged upstream by reverse-applying
1035         them in reverse order. The function returns the list of
1036         patches detected to have been applied. The state of the tree
1037         is restored to the original one
1038         """
1039         patches = [self.get_patch(name) for name in names]
1040         patches.reverse()
1041
1042         merged = []
1043         for p in patches:
1044             if git.apply_diff(p.get_top(), p.get_bottom()):
1045                 merged.append(p.get_name())
1046         merged.reverse()
1047
1048         git.reset()
1049
1050         return merged
1051
1052     def push_empty_patch(self, name):
1053         """Pushes an empty patch on the stack
1054         """
1055         unapplied = self.get_unapplied()
1056         assert(name in unapplied)
1057
1058         # patch = self.get_patch(name)
1059         head = git.get_head()
1060
1061         append_string(self.__applied_file, name)
1062
1063         unapplied.remove(name)
1064         write_strings(self.__unapplied_file, unapplied)
1065
1066         self.refresh_patch(bottom = head, cache_update = False, log = 'push(m)')
1067
    def push_patch(self, name):
        """Pushes a patch on the stack

        The patch must be in the unapplied list. If its bottom is the
        current HEAD, the push is a fast-forward and False is
        returned. Otherwise the patch diff is applied on top of HEAD,
        falling back to a three-way merge if that fails, and the patch
        is refreshed with HEAD as its new bottom.

        Returns True if the three-way merge was needed (the patch is
        reported as modified), False otherwise. Raises StackException
        if the merge failed; the patch is left applied and logged as
        'push(c)' in that case.
        """
        unapplied = self.get_unapplied()
        assert(name in unapplied)

        patch = self.get_patch(name)

        head = git.get_head()
        bottom = patch.get_bottom()
        top = patch.get_top()
        # top != bottom always since we have a commit for each patch

        if head == bottom:
            # A fast-forward push. Just reset the backup
            # information. No need for logging
            patch.set_bottom(bottom, backup = True)
            patch.set_top(top, backup = True)

            git.switch(top)
            append_string(self.__applied_file, name)

            unapplied.remove(name)
            write_strings(self.__unapplied_file, unapplied)
            return False

        # Need to create a new commit and merge in the old patch
        # 'ex' keeps the merge exception, if any, to be re-raised as a
        # StackException once the patch was recorded as applied
        ex = None
        modified = False

        # Try the fast applying first. If this fails, fall back to the
        # three-way merge
        if not git.apply_diff(bottom, top):
            # if git.apply_diff() fails, the patch requires a diff3
            # merge and can be reported as modified
            modified = True

            # merge can fail but the patch needs to be pushed
            try:
                git.merge(bottom, head, top, recursive = True)
            except git.GitException, ex:
                out.error('The merge failed during "push".',
                          'Use "refresh" after fixing the conflicts or'
                          ' revert the operation with "push --undo".')

        # the patch is marked as applied even if the merge failed
        append_string(self.__applied_file, name)

        unapplied.remove(name)
        write_strings(self.__unapplied_file, unapplied)

        if not ex:
            # if the merge was OK and no conflicts, just refresh the patch
            # The GIT cache was already updated by the merge operation
            if modified:
                log = 'push(m)'
            else:
                log = 'push'
            self.refresh_patch(bottom = head, cache_update = False, log = log)
        else:
            # we store the correctly merged files only for
            # tracking the conflict history. Note that the
            # git.merge() operations should always leave the index
            # in a valid state (i.e. only stage 0 files)
            self.refresh_patch(bottom = head, cache_update = False,
                               log = 'push(c)')
            raise StackException, str(ex)

        return modified
1136
1137     def undo_push(self):
1138         name = self.get_current()
1139         assert(name)
1140
1141         patch = self.get_patch(name)
1142         old_bottom = patch.get_old_bottom()
1143         old_top = patch.get_old_top()
1144
1145         # the top of the patch is changed by a push operation only
1146         # together with the bottom (otherwise the top was probably
1147         # modified by 'refresh'). If they are both unchanged, there
1148         # was a fast forward
1149         if old_bottom == patch.get_bottom() and old_top != patch.get_top():
1150             raise StackException, 'No undo information available'
1151
1152         git.reset()
1153         self.pop_patch(name)
1154         ret = patch.restore_old_boundaries()
1155         if ret:
1156             self.log_patch(patch, 'undo')
1157
1158         return ret
1159
1160     def pop_patch(self, name, keep = False):
1161         """Pops the top patch from the stack
1162         """
1163         applied = self.get_applied()
1164         applied.reverse()
1165         assert(name in applied)
1166
1167         patch = self.get_patch(name)
1168
1169         if git.get_head_file() == self.get_name():
1170             if keep and not git.apply_diff(git.get_head(), patch.get_bottom()):
1171                 raise StackException(
1172                     'Failed to pop patches while preserving the local changes')
1173             git.switch(patch.get_bottom(), keep)
1174         else:
1175             git.set_branch(self.get_name(), patch.get_bottom())
1176
1177         # save the new applied list
1178         idx = applied.index(name) + 1
1179
1180         popped = applied[:idx]
1181         popped.reverse()
1182         unapplied = popped + self.get_unapplied()
1183         write_strings(self.__unapplied_file, unapplied)
1184
1185         del applied[:idx]
1186         applied.reverse()
1187         write_strings(self.__applied_file, applied)
1188
1189     def empty_patch(self, name):
1190         """Returns True if the patch is empty
1191         """
1192         self.__patch_name_valid(name)
1193         patch = self.get_patch(name)
1194         bottom = patch.get_bottom()
1195         top = patch.get_top()
1196
1197         if bottom == top:
1198             return True
1199         elif git.get_commit(top).get_tree() \
1200                  == git.get_commit(bottom).get_tree():
1201             return True
1202
1203         return False
1204
1205     def rename_patch(self, oldname, newname):
1206         self.__patch_name_valid(newname)
1207
1208         applied = self.get_applied()
1209         unapplied = self.get_unapplied()
1210
1211         if oldname == newname:
1212             raise StackException, '"To" name and "from" name are the same'
1213
1214         if newname in applied or newname in unapplied:
1215             raise StackException, 'Patch "%s" already exists' % newname
1216
1217         if oldname in unapplied:
1218             self.get_patch(oldname).rename(newname)
1219             unapplied[unapplied.index(oldname)] = newname
1220             write_strings(self.__unapplied_file, unapplied)
1221         elif oldname in applied:
1222             self.get_patch(oldname).rename(newname)
1223
1224             applied[applied.index(oldname)] = newname
1225             write_strings(self.__applied_file, applied)
1226         else:
1227             raise StackException, 'Unknown patch "%s"' % oldname
1228
1229     def log_patch(self, patch, message, notes = None):
1230         """Generate a log commit for a patch
1231         """
1232         top = git.get_commit(patch.get_top())
1233         old_log = patch.get_log()
1234
1235         if message is None:
1236             # replace the current log entry
1237             if not old_log:
1238                 raise StackException, \
1239                       'No log entry to annotate for patch "%s"' \
1240                       % patch.get_name()
1241             replace = True
1242             log_commit = git.get_commit(old_log)
1243             msg = log_commit.get_log().split('\n')[0]
1244             log_parent = log_commit.get_parent()
1245             if log_parent:
1246                 parents = [log_parent]
1247             else:
1248                 parents = []
1249         else:
1250             # generate a new log entry
1251             replace = False
1252             msg = '%s\t%s' % (message, top.get_id_hash())
1253             if old_log:
1254                 parents = [old_log]
1255             else:
1256                 parents = []
1257
1258         if notes:
1259             msg += '\n\n' + notes
1260
1261         log = git.commit(message = msg, parents = parents,
1262                          cache_update = False, tree_id = top.get_tree(),
1263                          allowempty = True)
1264         patch.set_log(log)
1265
1266     def hide_patch(self, name):
1267         """Add the patch to the hidden list.
1268         """
1269         unapplied = self.get_unapplied()
1270         if name not in unapplied:
1271             # keep the checking order for backward compatibility with
1272             # the old hidden patches functionality
1273             if self.patch_applied(name):
1274                 raise StackException, 'Cannot hide applied patch "%s"' % name
1275             elif self.patch_hidden(name):
1276                 raise StackException, 'Patch "%s" already hidden' % name
1277             else:
1278                 raise StackException, 'Unknown patch "%s"' % name
1279
1280         if not self.patch_hidden(name):
1281             # check needed for backward compatibility with the old
1282             # hidden patches functionality
1283             append_string(self.__hidden_file, name)
1284
1285         unapplied.remove(name)
1286         write_strings(self.__unapplied_file, unapplied)
1287
1288     def unhide_patch(self, name):
1289         """Remove the patch from the hidden list.
1290         """
1291         hidden = self.get_hidden()
1292         if not name in hidden:
1293             if self.patch_applied(name) or self.patch_unapplied(name):
1294                 raise StackException, 'Patch "%s" not hidden' % name
1295             else:
1296                 raise StackException, 'Unknown patch "%s"' % name
1297
1298         hidden.remove(name)
1299         write_strings(self.__hidden_file, hidden)
1300
1301         if not self.patch_applied(name) and not self.patch_unapplied(name):
1302             # check needed for backward compatibility with the old
1303             # hidden patches functionality
1304             append_string(self.__unapplied_file, name)